summaryrefslogtreecommitdiff
path: root/tests/test_fileio.py
blob: d21c9ffb8bc7eee66adec17a03da051b1874f12d (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
import os
import shutil
import tempfile
import unittest

import numpy as np

from chroma import event
from chroma.fileio import root

class TestFileIO(unittest.TestCase):
    """Round-trip tests for ``chroma.fileio.root`` (RootWriter -> RootReader)."""

    def _make_test_photons(self):
        """Return a one-photon container filled with fixed, non-default values.

        The same values are used for both the start and stop photon sets so
        that every field compared after the round trip holds a known value.
        """
        photons = root.make_photon_with_arrays(1)
        photons.positions[0] = (1,2,3)
        photons.directions[0] = (4,5,6)
        photons.polarizations[0] = (7,8,9)
        photons.wavelengths[0] = 400.0
        photons.times[0] = 100.0
        photons.histories[0] = 20
        photons.last_hit_triangles[0] = 5
        return photons

    def test_file_write_and_read(self):
        """Write one event with RootWriter, read it back, and compare fields."""
        ev = event.Event(1, 'e-', (0,0,1), (1,0,0), 15)
        ev.photon_start = self._make_test_photons()
        ev.photon_stop = self._make_test_photons()
        ev.nphoton = 1

        ev.subtracks.append(event.Subtrack('e-', (40,30,20), (-1, -2, -3), 400, 800))

        channels = event.Channels(hit=np.array([True, False]),
                                  t=np.array([20.0, 1e9], dtype=np.float32),
                                  q=np.array([2.0, 0.0], dtype=np.float32),
                                  histories=np.array([8, 32], dtype=np.uint32))
        ev.channels = channels

        # Use a private temporary directory instead of a fixed /tmp path so
        # parallel runs cannot clobber each other and the output is removed
        # afterwards. Only the directory is pre-created; the writer creates
        # the file itself.
        tmpdir = tempfile.mkdtemp()
        self.addCleanup(shutil.rmtree, tmpdir, ignore_errors=True)
        filename = os.path.join(tmpdir, 'chroma-filewritertest.root')

        writer = root.RootWriter(filename)
        try:
            writer.write_event(ev)
        finally:
            # Close even if writing fails so the file handle is not leaked.
            writer.close()

        # Exercise the RootReader navigation methods
        reader = root.RootReader(filename)
        self.assertEqual(len(reader), 1)

        # Nothing to step back to before the first next().
        self.assertRaises(StopIteration, reader.prev)

        reader.next()

        self.assertEqual(reader.index(), 0)
        # Only one event was written, so advancing again must fail.
        self.assertRaises(StopIteration, reader.next)

        reader.jump_to(0)

        # Fetch the single event stored in the file
        newev = reader.current()

        # Now check if everything is correct in the event
        for attribute in ['event_id', 'particle_name', 'gen_total_energy']:
            self.assertEqual(getattr(ev, attribute), getattr(newev, attribute), 'compare %s' % attribute)
        for attribute in ['gen_position', 'gen_direction']:
            self.assertTrue(np.allclose(getattr(ev, attribute), getattr(newev, attribute)), 'compare %s' % attribute)

        for attribute in ['positions', 'directions', 'wavelengths', 'polarizations', 'times',
                          'histories', 'last_hit_triangles']:
            self.assertTrue(np.allclose(getattr(ev.photon_start, attribute),
                                        getattr(newev.photon_start, attribute)), 'compare %s' % attribute)
            self.assertTrue(np.allclose(getattr(ev.photon_stop, attribute),
                                        getattr(newev.photon_stop, attribute)), 'compare %s' % attribute)