summaryrefslogtreecommitdiff
path: root/tests/test_fileio.py
diff options
context:
space:
mode:
Diffstat (limited to 'tests/test_fileio.py')
-rw-r--r--tests/test_fileio.py66
1 files changed, 33 insertions, 33 deletions
diff --git a/tests/test_fileio.py b/tests/test_fileio.py
index d21c9ff..2911976 100644
--- a/tests/test_fileio.py
+++ b/tests/test_fileio.py
@@ -5,36 +5,34 @@ import numpy as np
class TestFileIO(unittest.TestCase):
def test_file_write_and_read(self):
- ev = event.Event(1, 'e-', (0,0,1), (1,0,0), 15)
+ ev = event.Event(1, event.Vertex('e-', (0,0,1), (1,0,0), (0,1,0), 15))
- photon_start = root.make_photon_with_arrays(1)
- photon_start.positions[0] = (1,2,3)
- photon_start.directions[0] = (4,5,6)
- photon_start.polarizations[0] = (7,8,9)
- photon_start.wavelengths[0] = 400.0
- photon_start.times[0] = 100.0
- photon_start.histories[0] = 20
- photon_start.last_hit_triangles[0] = 5
- ev.photon_start = photon_start
+ photons_beg = root.make_photon_with_arrays(1)
+ photons_beg.pos[0] = (1,2,3)
+ photons_beg.dir[0] = (4,5,6)
+ photons_beg.pol[0] = (7,8,9)
+ photons_beg.wavelengths[0] = 400.0
+ photons_beg.t[0] = 100.0
+ photons_beg.last_hit_triangles[0] = 5
+ photons_beg.flags[0] = 20
+ ev.photons_beg = photons_beg
- photon_stop = root.make_photon_with_arrays(1)
- photon_stop.positions[0] = (1,2,3)
- photon_stop.directions[0] = (4,5,6)
- photon_stop.polarizations[0] = (7,8,9)
- photon_stop.wavelengths[0] = 400.0
- photon_stop.times[0] = 100.0
- photon_stop.histories[0] = 20
- photon_stop.last_hit_triangles[0] = 5
- ev.photon_stop = photon_stop
+ photons_end = root.make_photon_with_arrays(1)
+ photons_end.pos[0] = (1,2,3)
+ photons_end.dir[0] = (4,5,6)
+ photons_end.pol[0] = (7,8,9)
+ photons_end.wavelengths[0] = 400.0
+ photons_end.t[0] = 100.0
+ photons_end.last_hit_triangles[0] = 5
+ photons_end.flags[0] = 20
+ ev.photons_end = photons_end
- ev.nphoton = 1
-
- ev.subtracks.append(event.Subtrack('e-', (40,30,20), (-1, -2, -3), 400, 800))
+ ev.vertices = [ev.primary_vertex]
channels = event.Channels(hit=np.array([True, False]),
t=np.array([20.0, 1e9], dtype=np.float32),
q=np.array([2.0, 0.0], dtype=np.float32),
- histories=np.array([8, 32], dtype=np.uint32))
+ flags=np.array([8, 32], dtype=np.uint32))
ev.channels = channels
filename = '/tmp/chroma-filewritertest.root'
@@ -58,16 +56,18 @@ class TestFileIO(unittest.TestCase):
# Enough screwing around, let's get the one event in the file
newev = reader.current()
-
# Now check if everything is correct in the event
- for attribute in ['event_id', 'particle_name','gen_total_energy']:
+ for attribute in ['id']:
self.assertEqual(getattr(ev, attribute), getattr(newev, attribute), 'compare %s' % attribute)
- for attribute in ['gen_position', 'gen_direction']:
- self.assertTrue(np.allclose(getattr(ev, attribute), getattr(newev, attribute)), 'compare %s' % attribute)
- for attribute in ['positions', 'directions', 'wavelengths', 'polarizations', 'times',
- 'histories', 'last_hit_triangles']:
- self.assertTrue(np.allclose(getattr(ev.photon_start, attribute),
- getattr(newev.photon_start, attribute)), 'compare %s' % attribute)
- self.assertTrue(np.allclose(getattr(ev.photon_stop, attribute),
- getattr(newev.photon_stop, attribute)), 'compare %s' % attribute)
+ for attribute in ['pos', 'dir', 'pol', 'ke', 't0']:
+ self.assertTrue(np.allclose(getattr(ev.primary_vertex, attribute), getattr(newev.primary_vertex, attribute)), 'compare %s' % attribute)
+
+ for i in range(len(ev.vertices)):
+ self.assertTrue(np.allclose(getattr(ev.vertices[i], attribute), getattr(newev.vertices[i], attribute)), 'compare %s' % attribute)
+
+ for attribute in ['pos', 'dir', 'pol', 'wavelengths', 't', 'last_hit_triangles', 'flags']:
+ self.assertTrue(np.allclose(getattr(ev.photons_beg, attribute),
+ getattr(newev.photons_beg, attribute)), 'compare %s' % attribute)
+ self.assertTrue(np.allclose(getattr(ev.photons_end, attribute),
+ getattr(newev.photons_end, attribute)), 'compare %s' % attribute)