# This file is part of MAUS: http://micewww.pp.rl.ac.uk:8080/projects/maus
#
# MAUS is free software: you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# MAUS is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with MAUS. If not, see <http://www.gnu.org/licenses/>.
# pylint: disable=C0103

"""
Test maus_cpp.optics_model
"""
27
import maus_cpp.globals
28
import maus_cpp.mice_module
29
import maus_cpp.covariance_matrix
30
from maus_cpp.covariance_matrix import CovarianceMatrix
31
import maus_cpp.phase_space_vector
32
from maus_cpp.phase_space_vector import PhaseSpaceVector
33
import maus_cpp.optics_model
34
from maus_cpp.optics_model import OpticsModel
# Reference-particle entries; _set_geometry passes REF as the
# "simulation_reference_particle" configuration value.
# NOTE(review): this chunk is a corrupted extraction (bare line-number residue
# on every other line). The opening `REF = {` and closing `}` lines appear to
# be missing, so these entries are not currently inside any dict literal --
# TODO restore from version control.
# particle_id -13 is presumably the PDG code for mu+ -- confirm.
37
"position":{"x":0., "y":0., "z":1000.001},
38
"momentum":{"x":0., "y":0., "z":1.},
39
"particle_id":-13, "energy":226., "random_seed":0, "time":0.
42
class OpticsModelTestCase(unittest.TestCase): # pylint: disable=R0904
43
"""Test maus_cpp.optics_model"""
# NOTE(review): this class comes from a corrupted extraction. Each original
# line is preceded by a bare number (apparently its original line number),
# all indentation has been lost, and several original lines are absent (gaps
# in the numbering). The missing lines include the `def setUp(self):` header,
# try/except clauses, `for` loop headers and docstring quote lines.
# Code is kept byte-identical here -- TODO restore from version control.
45
# setUp (its `def` line is missing): stores the path of the geometry used by
# the "good" tests (self.good_mod, expanded from $MAUS_ROOT_DIR) and a
# geometry name, "Test.dat", used by the "no planes" tests.
"""import datacards"""
47
self.good_mod = os.path.expandvars(\
48
"${MAUS_ROOT_DIR}/tests/py_unit/test_maus_cpp/"+\
49
"test_optics/optics_model.dat")
50
self.no_planes_mod = os.path.expandvars("Test.dat")
52
# Tear down any existing maus_cpp.globals instance, then rebuild it from the
# default configuration (Configuration.Configuration().getConfigJSON()).
# The overrides are: simulation_geometry_filename=geom_filename,
# simulation_reference_particle=REF and physics_processes="none".
# Finally, install a MiceModule built from geom_filename as the Monte Carlo
# geometry. The final JSON string is kept in self.test_config.
# NOTE(review): Configuration, json and os are used in this class but their
# imports are not visible in this chunk.
def _set_geometry(self, geom_filename):
54
if maus_cpp.globals.has_instance():
55
maus_cpp.globals.death()
56
self.test_config = Configuration.Configuration().getConfigJSON()
57
json_config = json.loads(self.test_config)
58
json_config["simulation_geometry_filename"] = geom_filename
59
json_config["simulation_reference_particle"] = REF
60
json_config["physics_processes"] = "none"
61
self.test_config = json.dumps(json_config)
62
maus_cpp.globals.birth(self.test_config)
63
maus_cpp.globals.set_monte_carlo_mice_modules(
64
maus_cpp.mice_module.MiceModule(geom_filename))
66
# Destroys the globals instance (if any), then expects an exception to be
# thrown (per the assertTrue message below).
# NOTE(review): original lines 70-71 and 73-76 are missing. They presumably
# hold `try:`, an OpticsModel() call, the rest of the message string and the
# `except` clause -- confirm. The assertTrue(False) pattern could become
# self.assertRaises once restored.
def test_init_no_globals(self):
67
"""Test maus_cpp.optics_model.__init__() and deallocation"""
68
if maus_cpp.globals.has_instance():
69
maus_cpp.globals.death()
72
self.assertTrue(False, msg="Should throw an exception if globals "+\
77
# Builds an OpticsModel against the good geometry, then calls __init__ a
# second time on the same object (the inline comment says this is legal and
# reinitialises).
def test_init_all_okay(self):
78
"""Test maus_cpp.optics_model.__init__() and deallocation"""
79
self._set_geometry(self.good_mod)
80
optics = OpticsModel()
81
optics.__init__() # legal, should reinitialise
83
# Uses the geometry with no planes; transport_covariance_matrix is expected
# to throw (per the "Should throw when no virtuals" message).
# NOTE(review): the docstring quote/continuation lines (84, 86-87) and what
# appear to be the try/except lines (91, 94-96) are missing.
def test_transport_covariance_matrix_no_planes(self):
85
Test maus_cpp.optics_model.Optics().transport_covariance_matrix()
88
self._set_geometry(self.no_planes_mod)
89
optics = OpticsModel()
90
cm_in = CovarianceMatrix()
92
optics.transport_covariance_matrix(cm_in, 2000.)
93
self.assertTrue(False, "Should throw when no virtuals")
97
# Passes a str instead of a CovarianceMatrix; transport_covariance_matrix is
# expected to throw (per the "Should throw when wrong type passed" message).
# NOTE(review): the docstring quote/continuation lines and what appear to be
# the try/except lines (104, 107-109) are missing.
def test_transport_covariance_matrix_bad_type(self):
99
Test maus_cpp.optics_model.Optics().transport_covariance_matrix()
102
self._set_geometry(self.good_mod)
103
optics = OpticsModel()
105
optics.transport_covariance_matrix("should be a cm", 2000.)
106
self.assertTrue(False, "Should throw when wrong type passed")
110
# Builds a covariance matrix from Penn parameters and transports it to
# z=2000 and z=3000. It then checks:
#   * the diagonal elements get_element(i, i) are unchanged to 1 place;
#   * |get_element(2*i+1, 2*i+2)| of the z=2000 output exceeds 1;
#   * elements (3, 4) and (5, 6) of the z=2000 output agree to 2 places;
#   * the change in get_element(i+1, i) from input to z=2000 equals the
#     change from z=2000 to z=3000.
# REF starts at z=1000.001, so both steps are presumably ~1000 long.
# NOTE(review): the loop headers that define `i` (original lines 119, 124,
# 129, 134) and one argument line of the last assertAlmostEqual (138,
# presumably places/delta) are missing, so the index ranges cannot be
# confirmed here. The "growth of correlation" comment appears to describe
# the assertGreater above it, not the assertion below it.
def test_transport_covariance_matrix(self):
111
"""Test maus_cpp.optics_model.Optics().transport_covariance_matrix()"""
112
self._set_geometry(self.good_mod)
113
optics = OpticsModel()
114
cm_in = maus_cpp.covariance_matrix.create_from_penn_parameters(
115
mass=105.658, momentum=200., emittance_t=6., beta_t=333.,
116
emittance_l=1., beta_l=10., bz=0.)
117
# check energy, px RMS does not change (no fields)
118
cm_out_1 = optics.transport_covariance_matrix(cm_in, 2000.)
120
self.assertAlmostEqual(cm_in.get_element(i, i),
121
cm_out_1.get_element(i, i), 1,
122
msg="\nIN\n"+str(cm_in)+"\nOUT\n"+str(cm_out_1))
123
cm_out_2 = optics.transport_covariance_matrix(cm_in, 3000.)
125
self.assertGreater(abs(cm_out_1.get_element(2*i+1, 2*i+2)), 1.)
126
# check that we have some growth of correlation between e.g. x, px
127
self.assertAlmostEqual(cm_out_1.get_element(3, 4),
128
cm_out_1.get_element(5, 6), 2)
130
self.assertAlmostEqual(cm_in.get_element(i, i),
131
cm_out_2.get_element(i, i), 1,
132
msg="\nIN\n"+str(cm_in)+"\nOUT\n"+str(cm_out_2))
133
# check no coupling between phase space planes
135
self.assertAlmostEqual(
136
cm_in.get_element(i+1, i)-cm_out_1.get_element(i+1, i),
137
cm_out_1.get_element(i+1, i)-cm_out_2.get_element(i+1, i),
139
msg="\nIN\n"+str(cm_in)+"\nOUT\n"+str(cm_out_2))
142
# Uses the geometry with no planes; transport_phase_space_vector is expected
# to throw (per the "Should throw when no virtuals" message).
# NOTE(review): the docstring quote/continuation lines and what appear to be
# the try/except lines (150, 153-155) are missing.
def test_transport_phase_space_vector_no_planes(self):
144
Test maus_cpp.optics_model.Optics().transport_phase_space_vector() with
147
self._set_geometry(self.no_planes_mod)
148
optics = OpticsModel()
149
psv_in = PhaseSpaceVector()
151
optics.transport_phase_space_vector(psv_in, 2000.)
152
self.assertTrue(False, "Should throw when no virtuals")
156
# Passes a str instead of a PhaseSpaceVector; transport_phase_space_vector
# is expected to throw (per the "Should throw when wrong type passed"
# message).
# NOTE(review): the docstring quote/continuation lines and what appear to be
# the try/except lines (163, 166-168) are missing.
def test_transport_phase_space_vector_bad_type(self):
158
Test maus_cpp.optics_model.Optics().transport_phase_space_vector() with
161
self._set_geometry(self.good_mod)
162
optics = OpticsModel()
164
optics.transport_phase_space_vector("should be a psv", 2000.)
165
self.assertTrue(False, "Should throw when wrong type passed")
169
# Builds a phase space vector, transports it to z=2000 and compares it with a
# straight-line expectation. For t, x and y, the ratio output/expected must
# be 1 to 3 places; for energy, px and py, the ratio output/input must be 1
# to 5 places. The literals 2. and 4. in `pz` presumably correspond to the px
# and py arguments of create_from_coordinates -- confirm the argument order.
def test_transport_phase_space_vector(self):
170
"""Test maus_cpp.optics_model.Optics().transport_phase_space_vector()"""
171
self._set_geometry(self.good_mod)
172
optics = OpticsModel()
173
psv_in = maus_cpp.phase_space_vector.create_from_coordinates \
174
(0.1, 226., 1., 2., 3., 4.)
175
psv_out = optics.transport_phase_space_vector(psv_in, 2000.)
176
pz = (226.**2-105.6583715**2.-2.**2-4.**2)**0.5
178
# NOTE(review): `dz` is used below but its definition (original line 177)
# is missing from this extraction. t_expected uses psv_out's energy rather
# than psv_in's; the energy assertion below checks the two agree.
t_expected = psv_in.get_t()+psv_out.get_energy()/pz/300.*dz # c=300.
179
x_expected = psv_in.get_x()+psv_in.get_px()/pz*dz
180
y_expected = psv_in.get_y()+psv_in.get_py()/pz*dz
181
self.assertAlmostEqual(psv_out.get_t()/t_expected, 1, 3)
182
self.assertAlmostEqual(psv_out.get_energy()/psv_in.get_energy(), 1., 5)
183
self.assertAlmostEqual(psv_out.get_x()/x_expected, 1., 3)
184
self.assertAlmostEqual(psv_out.get_px()/psv_in.get_px(), 1., 5)
185
self.assertAlmostEqual(psv_out.get_y()/y_expected, 1., 3)
186
self.assertAlmostEqual(psv_out.get_py()/psv_in.get_py(), 1., 5)
188
# Script entry point.
# NOTE(review): the body of this guard is missing from this extraction
# (presumably `unittest.main()` -- confirm against version control).
if __name__ == "__main__":