summaryrefslogtreecommitdiff
path: root/fileio/root.py
blob: 4c9d9bbbc729da426d72134d69f936dc20d07f46 (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
import ROOT
import os.path
import numpy as np

# Compile (ACLiC '+', with debug info 'g') and load the C++ helper file
# root.C located next to this module.  Presumably this is what defines
# ROOT.Event, ROOT.get_photons, ROOT.fill_photons and ROOT.fill_hits used
# below -- confirm against root.C.
ROOT.gROOT.ProcessLine('.L '+os.path.join(os.path.dirname(__file__), 'root.C+g'))

import ROOT
import chroma.event as event

def tvector3_to_ndarray(vec):
    '''Return the x, y and z components of a ROOT.TVector3 (anything with
    X(), Y(), Z() accessors) as a new length-3 numpy array of np.float32.'''
    xyz = [vec.X(), vec.Y(), vec.Z()]
    return np.asarray(xyz, dtype=np.float32)

def make_photon_with_arrays(size):
    '''Allocate a chroma.event.Photons object with room for `size` photons.

    Every per-photon array is created with np.empty, so its contents are
    uninitialised and must be filled in by the caller.  Positions,
    directions and polarizations are (size, 3) float32 arrays; wavelengths
    and times are float32, histories uint32 and last_hit_triangles int32,
    each of length `size`.
    '''
    def three_vectors():
        return np.empty((size, 3), dtype=np.float32)

    def per_photon(dtype):
        return np.empty(size, dtype=dtype)

    return event.Photons(
        positions=three_vectors(),
        directions=three_vectors(),
        polarizations=three_vectors(),
        wavelengths=per_photon(np.float32),
        times=per_photon(np.float32),
        histories=per_photon(np.uint32),
        last_hit_triangles=per_photon(np.int32),
    )


def _root_photons_to_python(root_photons):
    '''Copy the ROOT photon container `root_photons` into a new
    chroma.event.Photons object using the compiled ROOT.get_photons helper.'''
    photons = make_photon_with_arrays(root_photons.size())
    # The (size, 3) arrays are freshly allocated and C-contiguous, so ravel()
    # returns views and get_photons writes directly into the Photons arrays.
    ROOT.get_photons(root_photons, photons.positions.ravel(), photons.directions.ravel(),
                     photons.polarizations.ravel(), photons.wavelengths, photons.times,
                     photons.histories, photons.last_hit_triangles)
    return photons


def root_event_to_python_event(ev):
    '''Returns a new chroma.event.Event object created from the
    contents of the ROOT event `ev`.

    Copies the generator-level MC information, the subtracks, the photon
    start/stop lists (only when non-empty) and the per-channel hit arrays.
    '''
    pyev = event.Event(ev.event_id)

    # MC
    pyev.particle_name = str(ev.mc.particle)
    pyev.gen_position = tvector3_to_ndarray(ev.mc.gen_position)
    pyev.gen_direction = tvector3_to_ndarray(ev.mc.gen_direction)
    pyev.gen_total_energy = ev.mc.gen_total_energy

    pyev.nphoton = ev.mc.nphoton

    for subtrack in ev.mc.subtrack:
        # NOTE(review): RootWriter stores this field as `.name`, while it is
        # read here as `.particle`; confirm which matches root.C.
        pysubtrack = event.Subtrack(str(subtrack.particle),
                                    tvector3_to_ndarray(subtrack.position),
                                    tvector3_to_ndarray(subtrack.direction),
                                    subtrack.start_time,
                                    subtrack.total_energy)
        pyev.subtracks.append(pysubtrack)

    # Photon lists are only attached when the ROOT event actually has some.
    if ev.mc.photon_start.size() > 0:
        pyev.photon_start = _root_photons_to_python(ev.mc.photon_start)

    if ev.mc.photon_stop.size() > 0:
        pyev.photon_stop = _root_photons_to_python(ev.mc.photon_stop)

    # Hits: one slot per channel id 0..max_channel_id (presumably indexed by
    # channel id -- confirm against Event::get_channels in root.C).
    nchannels = ev.max_channel_id + 1
    hit = np.empty(nchannels, dtype=np.int32)
    t = np.empty(nchannels, dtype=np.float32)
    q = np.empty(nchannels, dtype=np.float32)
    histories = np.empty(nchannels, dtype=np.uint32)

    ev.get_channels(nchannels, hit, t, q, histories)
    pyev.channels = event.Channels(hit.astype(bool), t, q, histories)
    return pyev
    
class RootReader(object):
    '''Reader of Chroma events from a ROOT file.  This class can be used to
    navigate up and down the file linearly or in a random access fashion.
    All returned events are instances of the chroma.event.Event class.

    It implements the iterator protocol, so you can do

       for ev in RootReader('electron.root'):
           # process event here

    Iteration continues from the current position and advances it.
    '''

    def __init__(self, filename):
        '''Open ROOT file named `filename` containing TTree `T`.'''
        self.f = ROOT.TFile(filename)
        self.T = self.f.T
        # Index of the current event; -1 means "before the first event".
        self.i = -1

    def __len__(self):
        '''Returns number of events in this file.'''
        return self.T.GetEntries()

    def __iter__(self):
        '''Return self; iteration proceeds from the current position.'''
        return self

    def __next__(self):
        '''Python 3 iterator hook; delegates to next().'''
        return self.next()

    def next(self):
        '''Return the next event in the file. Raises StopIteration
        when you get to the end.'''
        if self.i + 1 >= len(self):
            raise StopIteration

        self.i += 1
        self.T.GetEntry(self.i)
        return root_event_to_python_event(self.T.ev)

    def prev(self):
        '''Return the previous event in the file. Raises StopIteration if
        that would go past the beginning; the position is then reset to
        before the first event.'''
        if self.i <= 0:
            self.i = -1
            raise StopIteration

        self.i -= 1
        self.T.GetEntry(self.i)
        return root_event_to_python_event(self.T.ev)

    def current(self):
        '''Return the current event in the file.'''
        # Re-read the entry in case the tree's buffer was changed by another
        # GetEntry call since the position last moved.
        self.T.GetEntry(self.i)
        return root_event_to_python_event(self.T.ev)

    def jump_to(self, index):
        '''Return the event at `index` and make it the current event.

        Raises IndexError if `index` is outside [0, len(self)); the current
        position is left unchanged in that case.'''
        if index < 0 or index >= len(self):
            raise IndexError(index)

        self.i = index
        self.T.GetEntry(self.i)
        return root_event_to_python_event(self.T.ev)

    def index(self):
        '''Return the current event index'''
        return self.i

class RootWriter(object):
    '''Writer of chroma.event.Event objects to a new ROOT file.

    Events are stored as entries of a TTree named `T` with one branch `ev`
    holding a ROOT.Event.  Call write_event() once per event and close()
    when finished; the tree is written to the file by close().
    '''

    def __init__(self, filename):
        '''Create the ROOT file `filename` (replacing any existing file)
        and set up the output tree.'''
        self.filename = filename
        self.file = ROOT.TFile(filename, 'RECREATE')

        # Created after the TFile so ROOT attaches the tree to that file.
        self.T = ROOT.TTree('T', 'Chroma events')
        # One ROOT.Event buffer is reused for every entry: write_event()
        # overwrites its contents and then calls Fill().
        self.ev = ROOT.Event()
        self.T.Branch('ev', self.ev)

    def _fill_photons(self, root_photons, photons):
        '''Replace the contents of the ROOT photon container `root_photons`
        with the chroma.event.Photons `photons`; leave it empty if None.'''
        # self.ev is reused across events, so the container must be emptied
        # explicitly; otherwise an event without photons would be written
        # out with the previous event's photons.
        root_photons.resize(0)
        if photons is None:
            return
        ROOT.fill_photons(root_photons,
                          len(photons.positions),
                          np.ravel(photons.positions),
                          np.ravel(photons.directions),
                          np.ravel(photons.polarizations),
                          photons.wavelengths, photons.times,
                          photons.histories, photons.last_hit_triangles)

    def write_event(self, pyev):
        '''Write an event.Event object to the ROOT tree as a ROOT.Event object.

        Copies the event id, generator-level MC information, photon
        start/stop lists, subtracks and channel data of `pyev` into the
        shared ROOT.Event buffer, then appends it to the tree with Fill().
        '''
        mc = self.ev.mc
        self.ev.event_id = pyev.event_id

        mc.particle = pyev.particle_name
        mc.gen_position.SetXYZ(*pyev.gen_position)
        mc.gen_direction.SetXYZ(*pyev.gen_direction)
        mc.gen_total_energy = pyev.gen_total_energy
        mc.nphoton = pyev.nphoton

        self._fill_photons(mc.photon_start, pyev.photon_start)
        self._fill_photons(mc.photon_stop, pyev.photon_stop)

        # Always clear first so stale subtracks from the previous event
        # are not carried over.
        mc.subtrack.resize(0)
        if pyev.subtracks is not None:
            mc.subtrack.resize(len(pyev.subtracks))
            for i, subtrack in enumerate(pyev.subtracks):
                root_subtrack = mc.subtrack[i]
                # NOTE(review): root_event_to_python_event reads this field
                # back as `.particle`, not `.name`; one of the two does not
                # match the C++ class in root.C -- confirm which.
                root_subtrack.name = subtrack.particle_name
                root_subtrack.position.SetXYZ(*subtrack.position)
                root_subtrack.direction.SetXYZ(*subtrack.direction)
                root_subtrack.start_time = subtrack.start_time
                root_subtrack.total_energy = subtrack.total_energy

        ROOT.fill_hits(self.ev, len(pyev.channels.t), pyev.channels.t, pyev.channels.q, pyev.channels.histories)
        self.T.Fill()

    def close(self):
        '''Write the tree to the file and close it.'''
        self.T.Write()
        self.file.Close()