~christopher-hunt08/maus/maus_integrated_kalman

« back to all changes in this revision

Viewing changes to src/common_cpp/Recon/Global/GlobalTools.cc

merging in changes in merge branch

Show diffs side-by-side

added added

removed removed

Lines of Context:
 
1
/* This file is part of MAUS: http://micewww.pp.rl.ac.uk:8080/projects/maus
 
2
 *
 
3
 * MAUS is free software: you can redistribute it and/or modify
 
4
 * it under the terms of the GNU General Public License as published by
 
5
 * the Free Software Foundation, either version 3 of the License, or
 
6
 * (at your option) any later version.
 
7
 *
 
8
 * MAUS is distributed in the hope that it will be useful,
 
9
 * but WITHOUT ANY WARRANTY; without even the implied warranty of
 
10
 * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
 
11
 * GNU General Public License for more details.
 
12
 *
 
13
 * You should have received a copy of the GNU General Public License
 
14
 * along with MAUS.  If not, see <http://www.gnu.org/licenses/>.
 
15
 *
 
16
 */
 
17
 
 
18
#include <algorithm>
 
19
#include <cmath>
 
20
 
 
21
#include "Geant4/G4Navigator.hh"
 
22
#include "Geant4/G4TransportationManager.hh"
 
23
#include "Geant4/G4NistManager.hh"
 
24
 
 
25
#include "gsl/gsl_odeiv.h"
 
26
#include "gsl/gsl_errno.h"
 
27
 
 
28
#include "src/common_cpp/DataStructure/ReconEvent.hh"
 
29
#include "src/common_cpp/Recon/Global/Particle.hh"
 
30
#include "src/common_cpp/Simulation/GeometryNavigator.hh"
 
31
#include "src/common_cpp/Utils/Exception.hh"
 
32
#include "src/common_cpp/Utils/Globals.hh"
 
33
 
 
34
#include "src/legacy/BeamTools/BTField.hh"
 
35
#include "src/legacy/Config/MiceModule.hh"
 
36
 
 
37
#include "src/common_cpp/Recon/Global/GlobalTools.hh"
 
38
 
 
39
namespace MAUS {
 
40
namespace GlobalTools {
 
41
 
 
42
std::map<DataStructure::Global::DetectorPoint, bool>
 
43
    GetReconDetectors(GlobalEvent* global_event) {
 
44
  std::map<DataStructure::Global::DetectorPoint, bool> recon_detectors;
 
45
  for (int i = 0; i < 27; i++) {
 
46
    recon_detectors[static_cast<DataStructure::Global::DetectorPoint>(i)] =
 
47
      false;
 
48
  }
 
49
  DataStructure::Global::TrackPArray* imported_tracks =
 
50
      global_event->get_tracks();
 
51
 
 
52
  for (auto imported_track_iter = imported_tracks->begin();
 
53
       imported_track_iter != imported_tracks->end();
 
54
       ++imported_track_iter) {
 
55
    std::vector<DataStructure::Global::DetectorPoint> track_detectors =
 
56
        (*imported_track_iter)->GetDetectorPoints();
 
57
    for (size_t i = 0; i < track_detectors.size(); i++) {
 
58
      recon_detectors[track_detectors[i]] = true;
 
59
    }
 
60
  }
 
61
 
 
62
  std::vector<DataStructure::Global::SpacePoint*>*
 
63
      imported_spacepoints = global_event->get_space_points();
 
64
  for (auto sp_iter = imported_spacepoints->begin();
 
65
       sp_iter != imported_spacepoints->end();
 
66
       ++sp_iter) {
 
67
    recon_detectors[(*sp_iter)->get_detector()] = true;
 
68
  }
 
69
  return recon_detectors;
 
70
}
 
71
 
 
72
std::vector<DataStructure::Global::Track*>* GetSpillDetectorTracks(Spill* spill,
 
73
    DataStructure::Global::DetectorPoint detector, std::string mapper_name) {
 
74
  std::vector<DataStructure::Global::Track*>* detector_tracks = new
 
75
      std::vector<DataStructure::Global::Track*>;
 
76
  ReconEventPArray* recon_events = spill->GetReconEvents();
 
77
  if (recon_events) {
 
78
    for (auto recon_event_iter = recon_events->begin();
 
79
         recon_event_iter != recon_events->end();
 
80
         ++recon_event_iter) {
 
81
      GlobalEvent* global_event = (*recon_event_iter)->GetGlobalEvent();
 
82
      if (global_event) {
 
83
        std::vector<DataStructure::Global::Track*>* global_tracks =
 
84
            global_event->get_tracks();
 
85
        for (auto track_iter = global_tracks->begin();
 
86
             track_iter != global_tracks->end();
 
87
             ++track_iter) {
 
88
          // The third condition is a bit of a dirty hack here to make sure that
 
89
          // if we select for EMR tracks, we only get primaries.
 
90
          if (((*track_iter)->HasDetector(detector)) and
 
91
              ((*track_iter)->get_mapper_name() == mapper_name) and
 
92
              ((*track_iter)->get_emr_range_secondary() < 0.001)) {
 
93
            detector_tracks->push_back((*track_iter));
 
94
          }
 
95
        }
 
96
      }
 
97
    }
 
98
  }
 
99
  return detector_tracks;
 
100
}
 
101
 
 
102
std::vector<DataStructure::Global::SpacePoint*>* GetSpillSpacePoints(
 
103
    Spill* spill, DataStructure::Global::DetectorPoint detector) {
 
104
  std::vector<DataStructure::Global::SpacePoint*>* spill_spacepoints =
 
105
      new std::vector<DataStructure::Global::SpacePoint*>;
 
106
  ReconEventPArray* recon_events = spill->GetReconEvents();
 
107
  if (recon_events) {
 
108
    for (auto recon_event_iter = recon_events->begin();
 
109
         recon_event_iter != recon_events->end();
 
110
         ++recon_event_iter) {
 
111
      GlobalEvent* global_event = (*recon_event_iter)->GetGlobalEvent();
 
112
      if (global_event) {
 
113
        std::vector<DataStructure::Global::SpacePoint*>* spacepoints =
 
114
            global_event->get_space_points();
 
115
        for (auto sp_iter = spacepoints->begin(); sp_iter != spacepoints->end();
 
116
             ++sp_iter) {
 
117
          if ((*sp_iter)->get_detector() == detector) {
 
118
            spill_spacepoints->push_back(*sp_iter);
 
119
          }
 
120
        }
 
121
      }
 
122
    }
 
123
  }
 
124
  if (spill_spacepoints->size() > 0) {
 
125
    return spill_spacepoints;
 
126
  } else {
 
127
    return 0;
 
128
  }
 
129
}
 
130
 
 
131
std::vector<DataStructure::Global::Track*>* GetTracksByMapperName(
 
132
    GlobalEvent* global_event, std::string mapper_name) {
 
133
  std::vector<DataStructure::Global::Track*>* global_tracks =
 
134
      global_event->get_tracks();
 
135
  std::vector<DataStructure::Global::Track*>* selected_tracks = new
 
136
      std::vector<DataStructure::Global::Track*>;
 
137
  for (auto global_track_iter = global_tracks->begin();
 
138
       global_track_iter != global_tracks->end();
 
139
       ++global_track_iter) {
 
140
    if ((*global_track_iter)->get_mapper_name() == mapper_name) {
 
141
      selected_tracks->push_back(*global_track_iter);
 
142
    }
 
143
  }
 
144
  return selected_tracks;
 
145
}
 
146
 
 
147
std::vector<DataStructure::Global::Track*>* GetTracksByMapperName(
 
148
    GlobalEvent* global_event, std::string mapper_name,
 
149
    DataStructure::Global::PID pid) {
 
150
  std::vector<DataStructure::Global::Track*>* global_tracks =
 
151
          global_event->get_tracks();
 
152
  std::vector<DataStructure::Global::Track*>* selected_tracks = new
 
153
      std::vector<DataStructure::Global::Track*>;
 
154
  for (auto global_track_iter = global_tracks->begin();
 
155
       global_track_iter != global_tracks->end();
 
156
       ++global_track_iter) {
 
157
    if ((*global_track_iter)->get_mapper_name() == mapper_name) {
 
158
      if ((*global_track_iter)->get_pid() == pid) {
 
159
        selected_tracks->push_back(*global_track_iter);
 
160
      }
 
161
    }
 
162
  }
 
163
  return selected_tracks;
 
164
}
 
165
 
 
166
std::vector<int> GetTrackerPlane(const DataStructure::Global::TrackPoint*
 
167
    track_point, std::vector<double> z_positions) {
 
168
  std::vector<int> tracker_plane(3, 0);
 
169
  double z = track_point->get_position().Z();
 
170
  int plane = 100;
 
171
  for (size_t i = 0; i < z_positions.size(); i++) {
 
172
    if (approx(z, z_positions[i], 0.25)) {
 
173
      plane = i;
 
174
      break;
 
175
    }
 
176
  }
 
177
  if (plane < 15) {
 
178
    tracker_plane[0] = 0;
 
179
    tracker_plane[1] = 5 - plane/3;
 
180
    tracker_plane[2] = 2 - plane%3;
 
181
  } else if (plane < 30) {
 
182
    tracker_plane[0] = 1;
 
183
    tracker_plane[1] = plane/3 - 4;
 
184
    tracker_plane[2] = plane%3;
 
185
  } else {
 
186
    // error output
 
187
    tracker_plane[0] = 99;
 
188
  }
 
189
  return tracker_plane;
 
190
}
 
191
 
 
192
std::vector<double> GetTrackerPlaneZPositions(std::string geo_filename) {
 
193
  MiceModule* geo_module = new MiceModule(geo_filename);
 
194
  std::vector<const MiceModule*> tracker_planes =
 
195
      geo_module->findModulesByPropertyString("SensitiveDetector", "SciFi");
 
196
  std::vector<double> z_positions;
 
197
  for (size_t i = 0; i < tracker_planes.size(); i++) {
 
198
    z_positions.push_back(tracker_planes.at(i)->globalPosition().getZ());
 
199
  }
 
200
  std::sort(z_positions.begin(), z_positions.end());
 
201
  return z_positions;
 
202
}
 
203
 
 
204
bool approx(double a, double b, double tolerance) {
 
205
  if (std::abs(a - b) > std::abs(tolerance)) {
 
206
    return false;
 
207
  } else {
 
208
    return true;
 
209
  }
 
210
}
 
211
 
 
212
DataStructure::Global::TrackPoint* GetNearestZTrackPoint(
 
213
    const DataStructure::Global::Track* track, double z_position) {
 
214
  std::vector<const DataStructure::Global::TrackPoint*> trackpoints =
 
215
      track->GetTrackPoints();
 
216
  size_t nearest_index = 0;
 
217
  double z_distance = 1.0e20;
 
218
  for (size_t i = 0; i < trackpoints.size(); i++) {
 
219
    double current_distance = std::abs(z_position -
 
220
                                       trackpoints.at(i)->get_position().Z());
 
221
    if (current_distance < z_distance) {
 
222
      nearest_index = i;
 
223
      z_distance = current_distance;
 
224
    }
 
225
  }
 
226
  DataStructure::Global::TrackPoint* nearest_track_point = new
 
227
      DataStructure::Global::TrackPoint(*trackpoints.at(nearest_index));
 
228
  return nearest_track_point;
 
229
}
 
230
 
 
231
double dEdx(const G4Material* material, double E, double m) {
 
232
  double constant = 2.54955123375e-23;
 
233
  double m_e = 0.510998928;
 
234
  double beta = std::sqrt(1 - (m*m)/(E*E));
 
235
  double beta2 = beta*beta;
 
236
  double gamma = 1/std::sqrt(1 - beta2);
 
237
  double bg = beta*gamma;
 
238
  double bg2 = bg*bg;
 
239
  double mRatio = m_e/m;
 
240
  double T_max = 2.0*m_e*bg2/(1.0 + 2.0*gamma*mRatio + mRatio*mRatio);
 
241
 
 
242
  double n_e = material->GetElectronDensity();
 
243
  double I = material->GetIonisation()->GetMeanExcitationEnergy();
 
244
  double x_0 = material->GetIonisation()->GetX0density();
 
245
  double x_1 = material->GetIonisation()->GetX1density();
 
246
  double C = material->GetIonisation()->GetCdensity();
 
247
  double a = material->GetIonisation()->GetAdensity();
 
248
  double k = material->GetIonisation()->GetMdensity();
 
249
 
 
250
  double logterm = std::log(2.0*m_e*bg2*T_max/(I*I));
 
251
  double x = std::log(bg)/std::log(10);
 
252
 
 
253
  // density correction
 
254
  double delta = 0.0;
 
255
  if (x > x_0) {
 
256
    delta = 2*std::log(10)*x - C;
 
257
    if (x < x_1) {
 
258
      delta += a*std::pow((x_1 - x), k);
 
259
    }
 
260
  }
 
261
  double dEdx = -constant*n_e/beta2*(logterm - 2*beta2 - delta);
 
262
  return dEdx;
 
263
}
 
264
 
 
265
// Need some global variables here
 
266
static const BTField* _field;
 
267
static int _charge;
 
268
 
 
269
void propagate(double* x, double target_z, const BTField* field,
 
270
               double step_size, DataStructure::Global::PID pid,
 
271
               bool energy_loss) {
 
272
  if (std::abs(target_z) > 100000) {
 
273
    throw(Exception(Exception::recoverable, "Extreme target z",
 
274
                    "GlobalTools::propagate"));
 
275
  }
 
276
  int prop_dir = 1;
 
277
  _field = field;
 
278
  _charge = recon::global::Particle::GetInstance().GetCharge(pid);
 
279
  double mass = recon::global::Particle::GetInstance().GetMass(pid);
 
280
  G4Navigator* g4navigator = G4TransportationManager::GetTransportationManager()
 
281
      ->GetNavigatorForTracking();
 
282
  G4NistManager* manager = G4NistManager::Instance();
 
283
  bool backwards = false;
 
284
  // If we propagate backwards, reverse momentum 4-vector
 
285
  if (target_z < x[3]) {
 
286
    backwards = true;
 
287
    _charge *= -1;
 
288
    prop_dir = -1;
 
289
    for (size_t i = 4; i < 8; i++) {
 
290
      x[i] *= -1;
 
291
    }
 
292
  }
 
293
  const gsl_odeiv_step_type * T = gsl_odeiv_step_rk4;
 
294
  double absolute_error = (*Globals::GetInstance()->GetConfigurationCards())
 
295
                           ["field_tracker_absolute_error"].asDouble();
 
296
  double relative_error = (*Globals::GetInstance()->GetConfigurationCards())
 
297
                           ["field_tracker_relative_error"].asDouble();
 
298
  gsl_odeiv_step    * step    = gsl_odeiv_step_alloc(T, 8);
 
299
  gsl_odeiv_control * control = gsl_odeiv_control_y_new(absolute_error,
 
300
                                                        relative_error);
 
301
  gsl_odeiv_evolve  * evolve  = gsl_odeiv_evolve_alloc(8);
 
302
  int (*FuncEqM)(double z, const double y[], double f[], void *params)=NULL;
 
303
  FuncEqM = z_equations_of_motion;
 
304
  gsl_odeiv_system system  = {FuncEqM, NULL, 8, NULL};
 
305
  double h = step_size*prop_dir;
 
306
  size_t max_steps = 10000000;
 
307
  size_t n_steps = 0;
 
308
  double z = x[3];
 
309
  while (fabs(z - target_z) > 1e-6) {
 
310
    n_steps++;
 
311
    h = step_size*prop_dir; // revert step size as large step size problematic for dEdx
 
312
    int status;
 
313
    if (energy_loss) {
 
314
      const CLHEP::Hep3Vector posvector(x[1], x[2], x[3]);
 
315
      double mommag = std::sqrt(x[5]*x[5] + x[6]*x[6] + x[7]*x[7]);
 
316
      const CLHEP::Hep3Vector momvector(x[5]/mommag, x[6]/mommag, x[7]/mommag);
 
317
      g4navigator->LocateGlobalPointAndSetup(posvector, &momvector);
 
318
      GeometryNavigator geometry_navigator;
 
319
      geometry_navigator.Initialise(g4navigator->GetWorldVolume());
 
320
      double safety = 10;
 
321
      double boundary_dist = g4navigator->ComputeStep(posvector, momvector, h, safety);
 
322
      if (boundary_dist > 1e6) {
 
323
        boundary_dist = safety;
 
324
      }
 
325
      double z_dist = boundary_dist*momvector.z();
 
326
      // Check if z distance to next material boundary is smaller than step size
 
327
      // if yes, we impose a tight limit on the step size to avoid issues
 
328
      // arising from the track not being straight
 
329
      if (std::abs(z_dist) < std::abs(h)) {
 
330
        if (std::abs(z_dist) > 2.0) {
 
331
          h = 2.0*prop_dir;
 
332
        } else {
 
333
          h = z_dist; // will have proper sign from momvector
 
334
        }
 
335
      }
 
336
      if (std::abs(h) < 0.1) {
 
337
        h = 0.1*prop_dir;
 
338
      }
 
339
      double x_prev[] = {x[0], x[1], x[2], x[3], x[4], x[5], x[6], x[7]};
 
340
      status = gsl_odeiv_evolve_apply(evolve, control, step, &system, &z,
 
341
                                        target_z, &h, x);
 
342
      // Calculate energy loss for the step
 
343
      geometry_navigator.SetPoint(ThreeVector((x[1] + x_prev[1])/2,
 
344
                                  (x[2] + x_prev[2])/2, (x[3] + x_prev[3])/2));
 
345
      double step_distance = std::sqrt((x[1]-x_prev[1])*(x[1]-x_prev[1]) +
 
346
                                       (x[2]-x_prev[2])*(x[2]-x_prev[2]) +
 
347
                                       (x[3]-x_prev[3])*(x[3]-x_prev[3]));
 
348
      G4Material* material = manager->FindOrBuildMaterial(geometry_navigator.GetMaterialName());
 
349
      double step_energy_loss = dEdx(material, x[4], mass)*step_distance;
 
350
      changeEnergy(x, step_energy_loss, mass);
 
351
    } else {
 
352
      status = gsl_odeiv_evolve_apply(evolve, control, step, &system, &z,
 
353
                                        target_z, &h, x);
 
354
    }
 
355
    if (status != GSL_SUCCESS) {
 
356
      throw(Exception(Exception::recoverable, "Propagation failed",
 
357
                            "GlobalTools::propagate"));
 
358
    }
 
359
 
 
360
    if (n_steps > max_steps) {
 
361
      std::stringstream ios;
 
362
      ios << "Stopping at step " << n_steps << " of " << max_steps << "\n"
 
363
          << "t: " << x[0] << " pos: " << x[1] << " " << x[2] << " " << x[3] << "\n"
 
364
          << "E: " << x[4] << " mom: " << x[5] << " " << x[6] << " " << x[7] << std::endl;
 
365
      throw(Exception(Exception::recoverable, ios.str()+
 
366
            "Exceeded maximum number of steps", "GlobalTools::propagate"));
 
367
      break;
 
368
    }
 
369
 
 
370
    // Need to catch the case where the particle is stopped
 
371
    if (std::abs(x[4]) < (mass + 0.01)) {
 
372
      std::stringstream ios;
 
373
      ios << "t: " << x[0] << " pos: " << x[1] << " " << x[2] << " " << x[3] << std::endl;
 
374
      throw(Exception(Exception::recoverable, ios.str()+
 
375
            "Particle terminated with 0 momentum", "GlobalTools::propagate"));
 
376
    }
 
377
  }
 
378
  gsl_odeiv_evolve_free(evolve);
 
379
  gsl_odeiv_control_free(control);
 
380
  gsl_odeiv_step_free(step);
 
381
 
 
382
  // If we propagate backwards, reverse momentum 4-vector back to original sign
 
383
  if (backwards) {
 
384
    for (size_t i = 4; i < 8; i++) {
 
385
      x[i] *= -1;
 
386
    }
 
387
  }
 
388
}
 
389
 
 
390
int z_equations_of_motion(double z, const double x[8], double dxdz[8],
 
391
                                   void* params) {
 
392
  if (fabs(x[7]) < 1e-9) {
 
393
  // z-momentum is 0
 
394
    return GSL_ERANGE;
 
395
  }
 
396
  const double c_l = 299.792458; // mm*ns^{-1}
 
397
  double field[6] = {0.0, 0.0, 0.0, 0.0, 0.0, 0.0};
 
398
  double xfield[4] = {x[1], x[2], x[3], x[0]};
 
399
  _field->GetFieldValue(xfield, field);
 
400
  double dtdz = x[4]/x[7];
 
401
  double dir = fabs(x[7])/x[7]; // direction of motion
 
402
  dxdz[0] = dtdz/c_l; // dt/dz
 
403
  dxdz[1] = x[5]/x[7]; // dx/dz = px/pz
 
404
  dxdz[2] = x[6]/x[7]; // dy/dz = py/pz
 
405
  dxdz[3] = 1.0; // dz/dz
 
406
  // dE/dz only contains electric field as B conserves energy, not relevant at
 
407
  // least in step 4 as all fields are static.
 
408
  dxdz[4] = (dxdz[1]*_charge*field[3] + dxdz[2]*_charge*field[4] +
 
409
             _charge*field[5])*dir; // dE/dz
 
410
  // dpx/dz = q*c*(dy/dz*Bz - dz/dz*By) + q*Ex*dt/dz
 
411
  dxdz[5] = _charge*c_l*(dxdz[2]*field[2] - dxdz[3]*field[1])
 
412
            + _charge*field[3]*dtdz*dir; // dpx/dz
 
413
  dxdz[6] = _charge*c_l*(dxdz[3]*field[0] - dxdz[1]*field[2])
 
414
            + _charge*field[4]*dtdz*dir; // dpy/dz
 
415
  dxdz[7] = _charge*c_l*(dxdz[1]*field[1] - dxdz[2]*field[0])
 
416
            + _charge*field[5]*dtdz*dir; // dpz/dz
 
417
  return GSL_SUCCESS;
 
418
}
 
419
 
 
420
void changeEnergy(double* x, double deltaE, double mass) {
 
421
  double old_momentum = std::sqrt(x[5]*x[5] + x[6]*x[6] + x[7]*x[7]);
 
422
  x[4] += deltaE;
 
423
  double new_momentum, momentum_ratio;
 
424
  if (std::abs(x[4]) > mass) {
 
425
    new_momentum = std::sqrt(x[4]*x[4] - mass*mass);
 
426
  } else {
 
427
    new_momentum = 0.0;
 
428
  }
 
429
  momentum_ratio = new_momentum / old_momentum;
 
430
  x[5] *= momentum_ratio;
 
431
  x[6] *= momentum_ratio;
 
432
  x[7] *= momentum_ratio;
 
433
}
 
434
 
 
435
bool TrackPointSort(const DataStructure::Global::TrackPoint* tp1,
 
436
                    const DataStructure::Global::TrackPoint* tp2) {
 
437
  return (tp1->get_position().Z() < tp2->get_position().Z());
 
438
}
 
439
 
 
440
} // ~namespace GlobalTools
 
441
} // ~namespace MAUS