1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
|
import unittest
from chroma.fileio import root
from chroma import event
import numpy as np
class TestFileIO(unittest.TestCase):
    """Round-trip test for the ROOT event writer/reader in ``chroma.fileio.root``."""

    def _make_single_photon(self):
        """Return a one-photon container filled with fixed, recognizable values.

        The container comes from ``root.make_photon_with_arrays(1)``; every
        per-photon array gets a distinct value in slot 0 so that a field
        mix-up during the write/read round trip is detectable.
        """
        photon = root.make_photon_with_arrays(1)
        photon.positions[0] = (1, 2, 3)
        photon.directions[0] = (4, 5, 6)
        photon.polarizations[0] = (7, 8, 9)
        photon.wavelengths[0] = 400.0
        photon.times[0] = 100.0
        photon.histories[0] = 20
        photon.last_hit_triangles[0] = 5
        return photon

    def test_file_write_and_read(self):
        """Write one event with ``RootWriter``, read it back, and compare.

        Steps:
          1. Build an event carrying start/stop photons, one subtrack and
             a two-channel ``Channels`` record.
          2. Write it to a ROOT file inside a private temporary directory.
          3. Exercise the ``RootReader`` navigation methods (``len``,
             ``prev``, ``next``, ``index``, ``jump_to``, ``current``).
          4. Check that the scalar, vector and photon-array attributes of
             the event read back match what was written.

        NOTE(review): the subtrack and channel data are written but not
        compared after read-back -- confirm whether that is intentional.
        """
        # Function-scope imports keep this block self-contained; only the
        # temporary-directory handling below needs them.
        import os
        import shutil
        import tempfile

        ev = event.Event(1, 'e-', (0, 0, 1), (1, 0, 0), 15)

        ev.photon_start = self._make_single_photon()
        ev.photon_stop = self._make_single_photon()
        ev.nphoton = 1

        ev.subtracks.append(
            event.Subtrack('e-', (40, 30, 20), (-1, -2, -3), 400, 800))

        channels = event.Channels(
            hit=np.array([True, False]),
            t=np.array([20.0, 1e9], dtype=np.float32),
            q=np.array([2.0, 0.0], dtype=np.float32),
            histories=np.array([8, 32], dtype=np.uint32))
        ev.channels = channels

        # Use a fresh private directory instead of a fixed path in /tmp so
        # concurrent or repeated runs cannot collide, and remove it when the
        # test finishes (pass or fail).  The third positional argument of
        # shutil.rmtree is ignore_errors.
        tmpdir = tempfile.mkdtemp(prefix='chroma-filewritertest-')
        self.addCleanup(shutil.rmtree, tmpdir, True)
        filename = os.path.join(tmpdir, 'chroma-filewritertest.root')

        writer = root.RootWriter(filename)
        writer.write_event(ev)
        writer.close()

        # Exercise the RootReader navigation methods.
        reader = root.RootReader(filename)
        self.assertEqual(len(reader), 1)

        # A freshly opened reader has nothing before it to step back to.
        self.assertRaises(StopIteration, reader.prev)

        reader.next()
        self.assertEqual(reader.index(), 0)

        # Only one event was written, so advancing again must fail.
        self.assertRaises(StopIteration, reader.next)

        reader.jump_to(0)

        # Fetch the single event stored in the file.
        newev = reader.current()

        # Scalar attributes must match exactly.
        for attribute in ['event_id', 'particle_name', 'gen_total_energy']:
            self.assertEqual(getattr(ev, attribute),
                             getattr(newev, attribute),
                             'compare %s' % attribute)

        # Vector attributes are compared with a floating-point tolerance.
        for attribute in ['gen_position', 'gen_direction']:
            self.assertTrue(
                np.allclose(getattr(ev, attribute),
                            getattr(newev, attribute)),
                'compare %s' % attribute)

        # Per-photon arrays, checked for both the start and stop photon sets.
        for attribute in ['positions', 'directions', 'wavelengths',
                          'polarizations', 'times', 'histories',
                          'last_hit_triangles']:
            self.assertTrue(
                np.allclose(getattr(ev.photon_start, attribute),
                            getattr(newev.photon_start, attribute)),
                'compare %s' % attribute)
            self.assertTrue(
                np.allclose(getattr(ev.photon_stop, attribute),
                            getattr(newev.photon_stop, attribute)),
                'compare %s' % attribute)
|