summaryrefslogtreecommitdiff
path: root/tests/test_fileio.py
diff options
context:
space:
mode:
Diffstat (limited to 'tests/test_fileio.py')
-rw-r--r--tests/test_fileio.py73
1 files changed, 73 insertions, 0 deletions
diff --git a/tests/test_fileio.py b/tests/test_fileio.py
new file mode 100644
index 0000000..d21c9ff
--- /dev/null
+++ b/tests/test_fileio.py
@@ -0,0 +1,73 @@
+import unittest
+from chroma.fileio import root
+from chroma import event
+import numpy as np
+
+class TestFileIO(unittest.TestCase):
+ def test_file_write_and_read(self):
+ ev = event.Event(1, 'e-', (0,0,1), (1,0,0), 15)
+
+ photon_start = root.make_photon_with_arrays(1)
+ photon_start.positions[0] = (1,2,3)
+ photon_start.directions[0] = (4,5,6)
+ photon_start.polarizations[0] = (7,8,9)
+ photon_start.wavelengths[0] = 400.0
+ photon_start.times[0] = 100.0
+ photon_start.histories[0] = 20
+ photon_start.last_hit_triangles[0] = 5
+ ev.photon_start = photon_start
+
+ photon_stop = root.make_photon_with_arrays(1)
+ photon_stop.positions[0] = (1,2,3)
+ photon_stop.directions[0] = (4,5,6)
+ photon_stop.polarizations[0] = (7,8,9)
+ photon_stop.wavelengths[0] = 400.0
+ photon_stop.times[0] = 100.0
+ photon_stop.histories[0] = 20
+ photon_stop.last_hit_triangles[0] = 5
+ ev.photon_stop = photon_stop
+
+ ev.nphoton = 1
+
+ ev.subtracks.append(event.Subtrack('e-', (40,30,20), (-1, -2, -3), 400, 800))
+
+ channels = event.Channels(hit=np.array([True, False]),
+ t=np.array([20.0, 1e9], dtype=np.float32),
+ q=np.array([2.0, 0.0], dtype=np.float32),
+ histories=np.array([8, 32], dtype=np.uint32))
+ ev.channels = channels
+
+ filename = '/tmp/chroma-filewritertest.root'
+ writer = root.RootWriter(filename)
+ writer.write_event(ev)
+ writer.close()
+
+ # Exercise the RootReader methods
+ reader = root.RootReader(filename)
+ self.assertEquals(len(reader), 1)
+
+ self.assertRaises(StopIteration, reader.prev)
+
+ reader.next()
+
+ self.assertEqual(reader.index(), 0)
+ self.assertRaises(StopIteration, reader.next)
+
+ reader.jump_to(0)
+
+ # Enough screwing around, let's get the one event in the file
+ newev = reader.current()
+
+
+ # Now check if everything is correct in the event
+ for attribute in ['event_id', 'particle_name','gen_total_energy']:
+ self.assertEqual(getattr(ev, attribute), getattr(newev, attribute), 'compare %s' % attribute)
+ for attribute in ['gen_position', 'gen_direction']:
+ self.assertTrue(np.allclose(getattr(ev, attribute), getattr(newev, attribute)), 'compare %s' % attribute)
+
+ for attribute in ['positions', 'directions', 'wavelengths', 'polarizations', 'times',
+ 'histories', 'last_hit_triangles']:
+ self.assertTrue(np.allclose(getattr(ev.photon_start, attribute),
+ getattr(newev.photon_start, attribute)), 'compare %s' % attribute)
+ self.assertTrue(np.allclose(getattr(ev.photon_stop, attribute),
+ getattr(newev.photon_stop, attribute)), 'compare %s' % attribute)