~christopher-hunt08/maus/maus_integrated_kalman

« back to all changes in this revision

Viewing changes to src/common_cpp/Recon/Global/MinuitTrackFitter.cc

  • Committer: Durga Rajaram
  • Date: 2014-01-14 07:07:02 UTC
  • mfrom: (659.1.80 relcand)
  • Revision ID: durga@fnal.gov-20140114070702-2l1fuj1w6rraw7xe
Tags: MAUS-v0.7.6
MAUS-v0.7.6

Show diffs side-by-side

added added

removed removed

Lines of Context:
 
1
/* This file is part of MAUS: http://  micewww.pp.rl.ac.uk:8080/projects/maus
 
2
 * 
 
3
 * MAUS is free software: you can redistribute it and/or modify
 
4
 * it under the terms of the GNU General Public License as published by
 
5
 * the Free Software Foundation, either version 3 of the License, or
 
6
 * (at your option) any later version.
 
7
 * 
 
8
 * MAUS is distributed in the hope that it will be useful,
 
9
 * but WITHOUT ANY WARRANTY; without even the implied warranty of
 
10
 * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
 
11
 * GNU General Public License for more details.
 
12
 * 
 
13
 * You should have received a copy of the GNU General Public License
 
14
 * along with MAUS.  If not, see <http://  www.gnu.org/licenses/>.
 
15
 */
 
16
 
 
17
/* Author: Peter Lane
 
18
 */
 
19
 
 
20
#include "Recon/Global/MinuitTrackFitter.hh"
 
21
 
 
22
#include <cmath>
 
23
#include <iostream>
 
24
#include <limits>
 
25
#include <sstream>
 
26
#include <string>
 
27
#include <vector>
 
28
 
 
29
#include "TMinuit.h"
 
30
#include "json/json.h"
 
31
 
 
32
#include "CLHEP/Units/PhysicalConstants.h"
 
33
#include "DataStructure/Global/Track.hh"
 
34
#include "DataStructure/Global/TrackPoint.hh"
 
35
#include "Utils/Exception.hh"
 
36
#include "src/common_cpp/Optics/CovarianceMatrix.hh"
 
37
#include "src/common_cpp/Optics/OpticsModel.hh"
 
38
#include "src/common_cpp/Optics/PhaseSpaceVector.hh"
 
39
#include "src/common_cpp/Optics/TransferMap.hh"
 
40
#include "src/common_cpp/Simulation/MAUSGeant4Manager.hh"
 
41
#include "src/common_cpp/Simulation/MAUSPrimaryGeneratorAction.hh"
 
42
#include "Recon/Global/DataStructureHelper.hh"
 
43
#include "Recon/Global/Particle.hh"
 
44
#include "Recon/Global/ParticleOpticalVector.hh"
 
45
#include "src/common_cpp/Utils/JsonWrapper.hh"
 
46
 
 
47
namespace MAUS {
 
48
namespace recon {
 
49
namespace global {
 
50
 
 
51
using MAUS::DataStructure::Global::Track;
 
52
using MAUS::DataStructure::Global::TrackPoint;
 
53
namespace GlobalDS = MAUS::DataStructure::Global;
 
54
 
 
55
void common_cpp_optics_recon_minuit_track_fitter_score_function(
 
56
    Int_t &    number_of_parameters,
 
57
    Double_t * gradiants,
 
58
    Double_t & score,
 
59
    Double_t * phase_space_coordinate_values,
 
60
    Int_t      execution_stage_flag) {
 
61
  for (size_t index = 0; index < 6; ++index) {
 
62
    std::cout << "DEBUG common_..._score_function: coordinate[" << index << "] = "
 
63
              << phase_space_coordinate_values[index] << std::endl;
 
64
  }
 
65
  // common_cpp_optics_recon_minuit_track_fitter_minuit is defined
 
66
  // globally in the header file
 
67
  TMinuit * minuit
 
68
    = common_cpp_optics_recon_minuit_track_fitter_minuit;
 
69
 
 
70
  MinuitTrackFitter * track_fitter
 
71
    = static_cast<MinuitTrackFitter *>(minuit->GetObjectFit());
 
72
 
 
73
  score = MinuitTrackFitter::ScoreTrack(
 
74
            phase_space_coordinate_values,
 
75
            *track_fitter->optics_model_,
 
76
            Particle::GetInstance().GetMass(track_fitter->particle_id_),
 
77
            track_fitter->detector_events_);
 
78
}
 
79
 
 
80
MinuitTrackFitter::MinuitTrackFitter(
 
81
    OpticsModel const * optics_model,
 
82
    const double start_plane)
 
83
    : TrackFitter(optics_model, start_plane), rounds_(0) {
 
84
  // Setup *global* scope Minuit object
 
85
  common_cpp_optics_recon_minuit_track_fitter_minuit
 
86
    = new TMinuit(kPhaseSpaceDimension);
 
87
  TMinuit * minimizer
 
88
    = common_cpp_optics_recon_minuit_track_fitter_minuit;
 
89
 
 
90
  minimizer->SetObjectFit(this);
 
91
 
 
92
  minimizer->SetFCN(
 
93
    common_cpp_optics_recon_minuit_track_fitter_score_function);
 
94
 
 
95
  ResetParameters();
 
96
}
 
97
 
 
98
MinuitTrackFitter::~MinuitTrackFitter() {
 
99
  delete common_cpp_optics_recon_minuit_track_fitter_minuit;
 
100
}
 
101
 
 
102
void MinuitTrackFitter::ResetParameters() {
 
103
  TMinuit * minimizer
 
104
    = common_cpp_optics_recon_minuit_track_fitter_minuit;
 
105
 
 
106
  // setup the index, name, init value, step size, min, and max value for each
 
107
  // phase space variable (mins and maxes calculated from 800MeV/c ISIS beam)
 
108
  int error_flag = 0;
 
109
  Json::Value const * configuration = optics_model_->configuration();
 
110
  if (configuration == NULL) {
 
111
    throw(Exception(Exception::nonRecoverable,
 
112
          "Initialized with a null configuration.",
 
113
          "MAUS::MinuitTrackFitter::ResetParameters()"));
 
114
  }
 
115
  const Json::Value parameters = JsonWrapper::GetProperty(
 
116
      *configuration, "global_recon_minuit_parameters",
 
117
      JsonWrapper::arrayValue);
 
118
  const Json::Value::UInt parameter_count = parameters.size();
 
119
  if (parameter_count != 6) {
 
120
    std::stringstream message;
 
121
    message << "Expected 6 elements in \"global_recon_minuit_parameters\""
 
122
            << " but found " << parameter_count << "." << std::endl;
 
123
    throw(Exception(Exception::nonRecoverable,
 
124
          message.str(),
 
125
          "MAUS::MinuitTrackFitter::ResetParameters()"));
 
126
  }
 
127
  for (Json::Value::UInt index = 0; index < parameter_count; ++index) {
 
128
    const Json::Value parameter = parameters[index];
 
129
    const std::string name = JsonWrapper::GetProperty(
 
130
        parameter, "name", JsonWrapper::stringValue).asString();
 
131
    const bool fixed = JsonWrapper::GetProperty(
 
132
        parameter, "fixed", JsonWrapper::booleanValue).asBool();
 
133
    const double initial_value = JsonWrapper::GetProperty(
 
134
        parameter, "initial_value", JsonWrapper::realValue).asDouble();
 
135
    const double value_step = JsonWrapper::GetProperty(
 
136
        parameter, "value_step", JsonWrapper::realValue).asDouble();
 
137
    const double min_value = JsonWrapper::GetProperty(
 
138
        parameter, "min_value", JsonWrapper::realValue).asDouble();
 
139
    const double max_value = JsonWrapper::GetProperty(
 
140
        parameter, "max_value", JsonWrapper::realValue).asDouble();
 
141
 
 
142
    minimizer->mnparm(index, name, initial_value, value_step,
 
143
                       min_value, max_value, error_flag);
 
144
    if (fixed) {
 
145
      minimizer->FixParameter(index);
 
146
    }
 
147
  }
 
148
}
 
149
 
 
150
void MinuitTrackFitter::Fit(Track const * const raw_track, Track * const track,
 
151
                            const std::string mapper_name) {
 
152
  std::cout << "CHECKPOINT Fit(): BEGIN" << std::endl;
 
153
  std::cout.flush();
 
154
  detector_events_ = raw_track->GetTrackPoints();
 
155
  std::cout << "DEBUG MinuitTrackFitter::Fit(): CHECKPOINT 0" << std::endl;
 
156
  std::cout << "DEBUG MinuitTrackFitter::Fit(): Fitting track with "
 
157
            << detector_events_.size() << " track points." << std::endl;
 
158
  std::cout << "DEBUG MinuitTrackFitter::Fit(): CHECKPOINT 0.5" << std::endl;
 
159
  particle_id_ = raw_track->get_pid();
 
160
  std::cout << "DEBUG MinuitTrackFitter::Fit(): particle ID: "
 
161
            << particle_id_ << std::endl;
 
162
 
 
163
  std::cout << "DEBUG MinuitTrackFitter::Fit(): CHECKPOINT 1" << std::endl;
 
164
  if (detector_events_.size() < 2) {
 
165
    throw(Exception(Exception::recoverable,
 
166
                 "Not enough track points to fit track (need at least two).",
 
167
                 "MAUS::MinuitTrackFitter::Fit()"));
 
168
  }
 
169
  std::cout << "DEBUG MinuitTrackFitter::Fit(): CHECKPOINT 2" << std::endl;
 
170
 
 
171
  ResetParameters();
 
172
 
 
173
  TMinuit * minimizer
 
174
    = common_cpp_optics_recon_minuit_track_fitter_minuit;
 
175
 
 
176
  // Find the start plane coordinates that minimize the score for the calculated
 
177
  // track based off of this track point (i.e. best fits the measured track
 
178
  // points from the detectors).
 
179
  Json::Value const * const configuration = optics_model_->configuration();
 
180
  if (configuration == NULL) {
 
181
    throw(Exception(Exception::nonRecoverable,
 
182
          "Initialized with a null configuration.",
 
183
          "MAUS::MinuitTrackFitter::Fit()"));
 
184
  }
 
185
  const std::string method = JsonWrapper::GetProperty(
 
186
      *configuration, "global_recon_minuit_minimizer",
 
187
      JsonWrapper::stringValue).asString();
 
188
  const double max_iterations = JsonWrapper::GetProperty(
 
189
      *configuration, "global_recon_minuit_max_iterations",
 
190
      JsonWrapper::intValue).asInt();
 
191
  const double max_EDM = JsonWrapper::GetProperty(
 
192
      *configuration, "global_recon_minuit_max_edm",
 
193
      JsonWrapper::realValue).asDouble();
 
194
 
 
195
  Int_t err = 0;
 
196
  Double_t args[2] = {max_iterations, max_EDM};
 
197
  minimizer->mnexcm(method.c_str(), args, 2, err);
 
198
 
 
199
  // Get the particle event for this track
 
200
  GlobalDS::TrackPointCPArray raw_points = raw_track->GetTrackPoints();
 
201
  size_t particle_event = raw_points[0]->get_particle_event();
 
202
 
 
203
  DataStructureHelper helper = DataStructureHelper::GetInstance();
 
204
 
 
205
  // Add a TrackPoint to the recon track for the fit primary
 
206
  PhaseSpaceVector fit_primary;
 
207
  for (size_t index = 0; index < kPhaseSpaceDimension; ++index) {
 
208
    Double_t value, error;
 
209
    minimizer->GetParameter(index, value, error);
 
210
    fit_primary[index] = value;
 
211
  }
 
212
  std::cout << "DEBUG MinuitTrackFitter::Fit: Fit Primary: " << fit_primary
 
213
            << std::endl;
 
214
  try {
 
215
    TrackPoint track_point = helper.PhaseSpaceVector2TrackPoint(
 
216
        fit_primary,
 
217
        optics_model_->primary_plane(),
 
218
        particle_id_);
 
219
    track_point.set_mapper_name(mapper_name);
 
220
    track_point.set_detector(MAUS::DataStructure::Global::kUndefined);
 
221
    track_point.set_particle_event(particle_event);
 
222
    track->AddTrackPoint(new TrackPoint(track_point));
 
223
  } catch (Exception exc) {
 
224
      std::cerr << "DEBUG MinuitTrackFitter::ScoreTrack: "
 
225
                << "something bad happened during track fitting: "
 
226
                << exc.what() << std::endl;
 
227
      // FIXME(Lane) handle better by reporting horrible score or something
 
228
  }
 
229
 
 
230
 
 
231
  // Add the fit points to the recon track by transporting the fit primary
 
232
  GlobalDS::TrackPointCPArray::const_iterator raw_point = raw_points.begin();
 
233
  for (; raw_point != raw_points.end(); ++raw_point) {
 
234
    // transport the fit primary to the desired z-position
 
235
    const double z = (*raw_point)->get_position().Z();
 
236
    const PhaseSpaceVector point = optics_model_->Transport(fit_primary, z);
 
237
    std::cout << "DEBUG MinuitTrackFitter::Fit: track point: " << point << std::endl;
 
238
 
 
239
    TrackPoint track_point;
 
240
    try {
 
241
      track_point = helper.PhaseSpaceVector2TrackPoint(point, z, particle_id_);
 
242
    } catch (Exception exc) {
 
243
        std::cerr << "DEBUG MinuitTrackFitter::ScoreTrack: "
 
244
                  << "something bad happened during track fitting: "
 
245
                  << exc.what() << std::endl;
 
246
        // FIXME(Lane) handle better by reporting horrible score or something
 
247
    }
 
248
 
 
249
    track_point.set_particle_event(particle_event);
 
250
    track_point.set_mapper_name(mapper_name);
 
251
    track_point.set_detector((*raw_point)->get_detector());
 
252
    track->AddTrackPoint(new TrackPoint(track_point));
 
253
  }
 
254
  track->set_pid(raw_track->get_pid());
 
255
}
 
256
 
 
257
Double_t MinuitTrackFitter::ScoreTrack(
 
258
    Double_t const * const start_plane_track_coordinates,
 
259
    const MAUS::OpticsModel & optics_model,
 
260
    const double mass,
 
261
    const std::vector<const GlobalDS::TrackPoint *> & detector_events) {
 
262
  DataStructureHelper helper = DataStructureHelper::GetInstance();
 
263
 
 
264
  // Setup the start plane track point based on the Minuit initial conditions
 
265
  CovarianceMatrix null_uncertainties;
 
266
  const PhaseSpaceVector guess(start_plane_track_coordinates[0],
 
267
                               start_plane_track_coordinates[1],
 
268
                               start_plane_track_coordinates[2],
 
269
                               start_plane_track_coordinates[3],
 
270
                               start_plane_track_coordinates[4],
 
271
                               start_plane_track_coordinates[5]);
 
272
  // If the guess is not physical then return a horrible score
 
273
  if (!MinuitTrackFitter::ValidVector(guess, mass)) {
 
274
    return 1.0e+15;
 
275
  }
 
276
 
 
277
  std::vector<const TrackPoint*>::const_iterator event
 
278
    = detector_events.begin();
 
279
 
 
280
  // calculate chi^2
 
281
  Double_t chi_squared = 0.0;
 
282
  size_t index = 0;
 
283
  for (std::vector<const TrackPoint*>::const_iterator event
 
284
        = detector_events.begin();
 
285
       event != detector_events.end();
 
286
       ++event) {
 
287
    std::cout << "DEBUG MinuitTrackFitter::ScoreTrack(): Guess: "
 
288
              << guess << std::endl;
 
289
    std::cout << "DEBUG MinuitTrackFitter::ScoreTrack(): Measured: "
 
290
              << *event << std::endl;
 
291
    // calculate the next guess
 
292
    const double end_plane = (*event)->get_position().Z();
 
293
    PhaseSpaceVector point =
 
294
      optics_model.Transport(guess, end_plane);
 
295
    std::cout << "DEBUG MinuitTrackFitter::ScoreTrack(): Calculated: "
 
296
              << point << std::endl;
 
297
 
 
298
    TLorentzVector position_error = (*event)->get_position_error();
 
299
    TLorentzVector momentum_error = (*event)->get_momentum_error();
 
300
    const double errors[36] = {
 
301
      position_error.T(), 0., 0., 0., 0., 0.,
 
302
      0., momentum_error.E(), 0., 0., 0., 0.,
 
303
      0., 0., position_error.X(), 0., 0., 0.,
 
304
      0., 0., 0., momentum_error.Px(), 0., 0.,
 
305
      0., 0., 0., 0., position_error.Y(), 0.,
 
306
      0., 0., 0., 0., 0., momentum_error.Py(),
 
307
    };
 
308
    const Matrix<double> error_matrix(6, 6, errors);
 
309
    const CovarianceMatrix uncertainties(error_matrix*error_matrix);
 
310
 
 
311
    const double weights[36] = {
 
312
      1., 0., 0., 0., 0., 0.,
 
313
      0., 1., 0., 0., 0., 0.,
 
314
      0., 0., 1., 0., 0., 0.,
 
315
      0., 0., 0., 1., 0., 0.,
 
316
      0., 0., 0., 0., 1., 0.,
 
317
      0., 0., 0., 0., 0., 1.,
 
318
    };
 
319
    Matrix<double> weight_matrix(6, 6, weights);
 
320
 
 
321
    // Sum the squares of the differences between the calculated phase space
 
322
    // coordinates (point) and the measured coordinates (event).
 
323
    PhaseSpaceVector event_point = helper.TrackPoint2PhaseSpaceVector(**event);
 
324
    PhaseSpaceVector residual = PhaseSpaceVector(
 
325
      weight_matrix * (event_point - point));
 
326
    const double residual_squared = (transpose(residual)
 
327
                                     * inverse(uncertainties)
 
328
                                     * residual)[0];
 
329
    std::cout << "DEBUG MinuitTrackFitter::ScoreTrack(): Residual Squared: "
 
330
          << residual_squared << std::endl;
 
331
    chi_squared += residual_squared;
 
332
    std::cerr << residual << std::endl
 
333
              << " = " << event_point << std::endl
 
334
              << " - " << point << std::endl
 
335
              << " -- chi2: " << chi_squared << std::endl;
 
336
    ++index;
 
337
  }
 
338
  std::cerr << std::endl;
 
339
 
 
340
  return chi_squared;
 
341
}
 
342
 
 
343
bool MinuitTrackFitter::ValidVector(const PhaseSpaceVector & guess,
 
344
                                    const double mass) {
 
345
  const double E = guess.E();
 
346
  const double px = guess.Px();
 
347
  const double py = guess.Py();
 
348
 
 
349
  bool valid = true;
 
350
 
 
351
  if (guess != guess) {
 
352
    // No NaN guesses
 
353
    valid = false;
 
354
  } else if (::sqrt(px*px + py*py + mass*mass) > E) {
 
355
    // Energy cannot be greater than the sum of the squares of the transverse
 
356
    // momenta and mass
 
357
    valid = false;
 
358
  }
 
359
 
 
360
  return valid;
 
361
}
 
362
 
 
363
const size_t MinuitTrackFitter::kPhaseSpaceDimension = 6;
 
364
}  // namespace global
 
365
}  // namespace recon
 
366
}  // namespace MAUS