summaryrefslogtreecommitdiff
path: root/fileio/root.py
diff options
context:
space:
mode:
Diffstat (limited to 'fileio/root.py')
-rw-r--r--fileio/root.py137
1 files changed, 131 insertions, 6 deletions
diff --git a/fileio/root.py b/fileio/root.py
index 1fa2425..4c9d9bb 100644
--- a/fileio/root.py
+++ b/fileio/root.py
@@ -7,6 +7,129 @@ ROOT.gROOT.ProcessLine('.L '+os.path.join(os.path.dirname(__file__), 'root.C+g')
import ROOT
import chroma.event as event
+def tvector3_to_ndarray(vec):
+ '''Convert a ROOT.TVector3 into a numpy np.float32 array'''
+ return np.array((vec.X(), vec.Y(), vec.Z()), dtype=np.float32)
+
+def make_photon_with_arrays(size):
+ '''Returns a new chroma.event.Photons object for `size` number of
+ photons with empty arrays set for all the photon attributes.'''
+ return event.Photons(positions=np.empty((size,3), dtype=np.float32),
+ directions=np.empty((size,3), dtype=np.float32),
+ polarizations=np.empty((size,3), dtype=np.float32),
+ wavelengths=np.empty(size, dtype=np.float32),
+ times=np.empty(size, dtype=np.float32),
+ histories=np.empty(size, dtype=np.uint32),
+ last_hit_triangles=np.empty(size, dtype=np.int32))
+
+
+def root_event_to_python_event(ev):
+ '''Returns a new chroma.event.Event object created from the
+ contents of the ROOT event `ev`.'''
+ pyev = event.Event(ev.event_id)
+
+ # MC
+ pyev.particle_name = str(ev.mc.particle)
+ pyev.gen_position = tvector3_to_ndarray(ev.mc.gen_position)
+ pyev.gen_direction = tvector3_to_ndarray(ev.mc.gen_direction)
+ pyev.gen_total_energy = ev.mc.gen_total_energy
+
+ pyev.nphoton = ev.mc.nphoton
+
+ for subtrack in ev.mc.subtrack:
+ pysubtrack = event.Subtrack(str(subtrack.particle),
+ tvector3_to_ndarray(subtrack.position),
+ tvector3_to_ndarray(subtrack.direction),
+ subtrack.start_time,
+ subtrack.total_energy)
+ pyev.subtracks.append(pysubtrack)
+
+ # photon start
+ if ev.mc.photon_start.size() > 0:
+ photons = make_photon_with_arrays(ev.mc.photon_start.size())
+ ROOT.get_photons(ev.mc.photon_start, photons.positions.ravel(), photons.directions.ravel(),
+ photons.polarizations.ravel(), photons.wavelengths, photons.times,
+ photons.histories, photons.last_hit_triangles)
+ pyev.photon_start = photons
+
+ # photon stop
+ if ev.mc.photon_stop.size() > 0:
+ photons = make_photon_with_arrays(ev.mc.photon_stop.size())
+ ROOT.get_photons(ev.mc.photon_stop, photons.positions.ravel(), photons.directions.ravel(),
+ photons.polarizations.ravel(), photons.wavelengths, photons.times,
+ photons.histories, photons.last_hit_triangles)
+ pyev.photon_stop = photons
+
+ # hits
+ max_channel_id = ev.max_channel_id
+ hit = np.empty(shape=max_channel_id+1, dtype=np.int32)
+ t = np.empty(shape=max_channel_id+1, dtype=np.float32)
+ q = np.empty(shape=max_channel_id+1, dtype=np.float32)
+ histories = np.empty(shape=max_channel_id+1, dtype=np.uint32)
+
+ ev.get_channels(max_channel_id+1, hit, t, q, histories)
+ pyev.channels = event.Channels(hit.astype(bool), t, q, histories)
+ return pyev
+
+class RootReader(object):
+ '''Reader of Chroma events from a ROOT file. This class can be used to
+ navigate up and down the file linearly or in a random access fashion.
+ All returned events are instances of the chroma.event.Event class.
+
+ It implements the iterator protocol, so you can do
+
+ for ev in RootReader('electron.root'):
+ # process event here
+ '''
+
+ def __init__(self, filename):
+ '''Open ROOT file named `filename` containing TTree `T`.'''
+ self.f = ROOT.TFile(filename)
+ self.T = self.f.T
+ self.i = -1
+
+ def __len__(self):
+ '''Returns number of events in this file.'''
+ return self.T.GetEntries()
+
+ def next(self):
+ '''Return the next event in the file. Raises StopIteration
+ when you get to the end.'''
+ if self.i + 1 >= len(self):
+ raise StopIteration
+
+ self.i += 1
+ self.T.GetEntry(self.i)
+ return root_event_to_python_event(self.T.ev)
+
+ def prev(self):
+ '''Return the next event in the file. Raises StopIteration if
+ that would go past the beginning.'''
+ if self.i <= 0:
+ self.i = -1
+ raise StopIteration
+
+ self.i -= 1
+ self.T.GetEntry(self.i)
+ return root_event_to_python_event(self.T.ev)
+
+ def current(self):
+ '''Return the current event in the file.'''
+ self.T.GetEntry(self.i) # just in case?
+ return root_event_to_python_event(self.T.ev)
+
+ def jump_to(self, index):
+ '''Return the event at `index`. Updates current location.'''
+ if index < 0 or index >= len(self):
+ raise IndexError
+
+ self.T.GetEntry(self.i)
+ return root_event_to_python_event(self.T.ev)
+
+ def index(self):
+ '''Return the current event index'''
+ return self.i
+
class RootWriter(object):
def __init__(self, filename):
self.filename = filename
@@ -28,20 +151,22 @@ class RootWriter(object):
if pyev.photon_start is not None:
photons = pyev.photon_start
- ROOT.fill_photons(self.ev, True,
+ ROOT.fill_photons(self.ev.mc.photon_start,
len(photons.positions),
np.ravel(photons.positions),
np.ravel(photons.directions),
np.ravel(photons.polarizations),
- photons.wavelengths, photons.times)
+ photons.wavelengths, photons.times,
+ photons.histories, photons.last_hit_triangles)
if pyev.photon_stop is not None:
- photons = photon_stop
- ROOT.fill_photons(self.ev, True,
+ photons = pyev.photon_stop
+ ROOT.fill_photons(self.ev.mc.photon_stop,
len(photons.positions),
np.ravel(photons.positions),
np.ravel(photons.directions),
np.ravel(photons.polarizations),
- photons.wavelengths, photons.times)
+ photons.wavelengths, photons.times,
+ photons.histories, photons.last_hit_triangles)
self.ev.mc.subtrack.resize(0)
if pyev.subtracks is not None:
@@ -53,7 +178,7 @@ class RootWriter(object):
self.ev.mc.subtrack[i].start_time = subtrack.start_time
self.ev.mc.subtrack[i].total_energy = subtrack.total_energy
- ROOT.fill_hits(self.ev, len(pyev.hits.t), pyev.hits.t, pyev.hits.q, pyev.hits.histories)
+ ROOT.fill_hits(self.ev, len(pyev.channels.t), pyev.channels.t, pyev.channels.q, pyev.channels.histories)
self.T.Fill()
def close(self):