import os
import shutil
import tempfile
import unittest
from chroma.fileio import root
from chroma import event
import numpy as np


def _make_single_photon():
    """Return a fresh 1-photon container filled with fixed test values.

    The container comes from ``root.make_photon_with_arrays(1)``; each call
    returns a new object, so start/stop photons never alias each other.
    """
    photon = root.make_photon_with_arrays(1)
    photon.positions[0] = (1, 2, 3)
    photon.directions[0] = (4, 5, 6)
    photon.polarizations[0] = (7, 8, 9)
    photon.wavelengths[0] = 400.0
    photon.times[0] = 100.0
    photon.histories[0] = 20
    photon.last_hit_triangles[0] = 5
    return photon


class TestFileIO(unittest.TestCase):
    """Round-trip tests for ``root.RootWriter`` / ``root.RootReader``."""

    def test_file_write_and_read(self):
        """Write a single event to a ROOT file, read it back, and compare.

        Also exercises the reader's navigation API (``len``, ``prev``,
        ``next``, ``index``, ``jump_to``, ``current``) on a one-event file.
        """
        ev = event.Event(1, 'e-', (0, 0, 1), (1, 0, 0), 15)
        ev.photon_start = _make_single_photon()
        ev.photon_stop = _make_single_photon()
        ev.nphoton = 1
        ev.subtracks.append(
            event.Subtrack('e-', (40, 30, 20), (-1, -2, -3), 400, 800))
        ev.channels = event.Channels(
            hit=np.array([True, False]),
            t=np.array([20.0, 1e9], dtype=np.float32),
            q=np.array([2.0, 0.0], dtype=np.float32),
            histories=np.array([8, 32], dtype=np.uint32))

        # A private temporary directory (instead of a fixed /tmp path) keeps
        # concurrent runs from clobbering each other and is removed when the
        # test finishes, pass or fail.
        tmpdir = tempfile.mkdtemp(prefix='chroma-filewritertest-')
        self.addCleanup(shutil.rmtree, tmpdir, True)
        filename = os.path.join(tmpdir, 'chroma-filewritertest.root')

        writer = root.RootWriter(filename)
        writer.write_event(ev)
        writer.close()

        # Exercise the RootReader methods
        reader = root.RootReader(filename)
        self.assertEqual(len(reader), 1)
        # Before the first next(), there is nothing to step back to.
        self.assertRaises(StopIteration, reader.prev)
        reader.next()
        self.assertEqual(reader.index(), 0)
        # Only one event in the file, so advancing again must stop.
        self.assertRaises(StopIteration, reader.next)
        reader.jump_to(0)

        # Enough screwing around, let's get the one event in the file
        newev = reader.current()

        # Now check if everything is correct in the event
        for attribute in ['event_id', 'particle_name', 'gen_total_energy']:
            self.assertEqual(getattr(ev, attribute),
                             getattr(newev, attribute),
                             'compare %s' % attribute)

        # Vector-valued generator attributes: compare with a float tolerance.
        for attribute in ['gen_position', 'gen_direction']:
            self.assertTrue(np.allclose(getattr(ev, attribute),
                                        getattr(newev, attribute)),
                            'compare %s' % attribute)

        for attribute in ['positions', 'directions', 'wavelengths',
                          'polarizations', 'times', 'histories',
                          'last_hit_triangles']:
            self.assertTrue(np.allclose(getattr(ev.photon_start, attribute),
                                        getattr(newev.photon_start, attribute)),
                            'compare %s' % attribute)
            self.assertTrue(np.allclose(getattr(ev.photon_stop, attribute),
                                        getattr(newev.photon_stop, attribute)),
                            'compare %s' % attribute)