diff options
author | Anthony LaTorre <tlatorre9@gmail.com> | 2011-08-25 21:42:42 -0400 |
---|---|---|
committer | Anthony LaTorre <tlatorre9@gmail.com> | 2011-08-25 21:42:42 -0400 |
commit | e75eda8a637c01c34c71063b91a86845cc1c5beb (patch) | |
tree | d525cdaeda966c878e5387bf327f7e9aee643afc /fileio/root.py | |
parent | fa8d1082f9d989f2a3819540a9bf30dc67618709 (diff) | |
parent | b8e7b443242c716c12006442c2738e09ed77c0c9 (diff) | |
download | chroma-e75eda8a637c01c34c71063b91a86845cc1c5beb.tar.gz chroma-e75eda8a637c01c34c71063b91a86845cc1c5beb.tar.bz2 chroma-e75eda8a637c01c34c71063b91a86845cc1c5beb.zip |
merge
Diffstat (limited to 'fileio/root.py')
-rw-r--r-- | fileio/root.py | 137 |
1 files changed, 131 insertions, 6 deletions
diff --git a/fileio/root.py b/fileio/root.py index 1fa2425..4c9d9bb 100644 --- a/fileio/root.py +++ b/fileio/root.py @@ -7,6 +7,129 @@ ROOT.gROOT.ProcessLine('.L '+os.path.join(os.path.dirname(__file__), 'root.C+g') import ROOT import chroma.event as event +def tvector3_to_ndarray(vec): + '''Convert a ROOT.TVector3 into a numpy np.float32 array''' + return np.array((vec.X(), vec.Y(), vec.Z()), dtype=np.float32) + +def make_photon_with_arrays(size): + '''Returns a new chroma.event.Photons object for `size` number of + photons with empty arrays set for all the photon attributes.''' + return event.Photons(positions=np.empty((size,3), dtype=np.float32), + directions=np.empty((size,3), dtype=np.float32), + polarizations=np.empty((size,3), dtype=np.float32), + wavelengths=np.empty(size, dtype=np.float32), + times=np.empty(size, dtype=np.float32), + histories=np.empty(size, dtype=np.uint32), + last_hit_triangles=np.empty(size, dtype=np.int32)) + + +def root_event_to_python_event(ev): + '''Returns a new chroma.event.Event object created from the + contents of the ROOT event `ev`.''' + pyev = event.Event(ev.event_id) + + # MC + pyev.particle_name = str(ev.mc.particle) + pyev.gen_position = tvector3_to_ndarray(ev.mc.gen_position) + pyev.gen_direction = tvector3_to_ndarray(ev.mc.gen_direction) + pyev.gen_total_energy = ev.mc.gen_total_energy + + pyev.nphoton = ev.mc.nphoton + + for subtrack in ev.mc.subtrack: + pysubtrack = event.Subtrack(str(subtrack.particle), + tvector3_to_ndarray(subtrack.position), + tvector3_to_ndarray(subtrack.direction), + subtrack.start_time, + subtrack.total_energy) + pyev.subtracks.append(pysubtrack) + + # photon start + if ev.mc.photon_start.size() > 0: + photons = make_photon_with_arrays(ev.mc.photon_start.size()) + ROOT.get_photons(ev.mc.photon_start, photons.positions.ravel(), photons.directions.ravel(), + photons.polarizations.ravel(), photons.wavelengths, photons.times, + photons.histories, photons.last_hit_triangles) + pyev.photon_start = photons + + # photon stop + if ev.mc.photon_stop.size() > 0: + photons = make_photon_with_arrays(ev.mc.photon_stop.size()) + ROOT.get_photons(ev.mc.photon_stop, photons.positions.ravel(), photons.directions.ravel(), + photons.polarizations.ravel(), photons.wavelengths, photons.times, + photons.histories, photons.last_hit_triangles) + pyev.photon_stop = photons + + # hits + max_channel_id = ev.max_channel_id + hit = np.empty(shape=max_channel_id+1, dtype=np.int32) + t = np.empty(shape=max_channel_id+1, dtype=np.float32) + q = np.empty(shape=max_channel_id+1, dtype=np.float32) + histories = np.empty(shape=max_channel_id+1, dtype=np.uint32) + + ev.get_channels(max_channel_id+1, hit, t, q, histories) + pyev.channels = event.Channels(hit.astype(bool), t, q, histories) + return pyev + +class RootReader(object): + '''Reader of Chroma events from a ROOT file. This class can be used to + navigate up and down the file linearly or in a random access fashion. + All returned events are instances of the chroma.event.Event class. + + It implements the iterator protocol, so you can do + + for ev in RootReader('electron.root'): + # process event here + ''' + + def __init__(self, filename): + '''Open ROOT file named `filename` containing TTree `T`.''' + self.f = ROOT.TFile(filename) + self.T = self.f.T + self.i = -1 + + def __len__(self): + '''Returns number of events in this file.''' + return self.T.GetEntries() + + def next(self): + '''Return the next event in the file. Raises StopIteration + when you get to the end.''' + if self.i + 1 >= len(self): + raise StopIteration + + self.i += 1 + self.T.GetEntry(self.i) + return root_event_to_python_event(self.T.ev) + + def prev(self): + '''Return the next event in the file. Raises StopIteration if + that would go past the beginning.''' + if self.i <= 0: + self.i = -1 + raise StopIteration + + self.i -= 1 + self.T.GetEntry(self.i) + return root_event_to_python_event(self.T.ev) + + def current(self): + '''Return the current event in the file.''' + self.T.GetEntry(self.i) # just in case? + return root_event_to_python_event(self.T.ev) + + def jump_to(self, index): + '''Return the event at `index`. Updates current location.''' + if index < 0 or index >= len(self): + raise IndexError + + self.T.GetEntry(self.i) + return root_event_to_python_event(self.T.ev) + + def index(self): + '''Return the current event index''' + return self.i + class RootWriter(object): def __init__(self, filename): self.filename = filename @@ -28,20 +151,22 @@ class RootWriter(object): if pyev.photon_start is not None: photons = pyev.photon_start - ROOT.fill_photons(self.ev, True, + ROOT.fill_photons(self.ev.mc.photon_start, len(photons.positions), np.ravel(photons.positions), np.ravel(photons.directions), np.ravel(photons.polarizations), - photons.wavelengths, photons.times) + photons.wavelengths, photons.times, + photons.histories, photons.last_hit_triangles) if pyev.photon_stop is not None: - photons = photon_stop - ROOT.fill_photons(self.ev, True, + photons = pyev.photon_stop + ROOT.fill_photons(self.ev.mc.photon_stop, len(photons.positions), np.ravel(photons.positions), np.ravel(photons.directions), np.ravel(photons.polarizations), - photons.wavelengths, photons.times) + photons.wavelengths, photons.times, + photons.histories, photons.last_hit_triangles) self.ev.mc.subtrack.resize(0) if pyev.subtracks is not None: @@ -53,7 +178,7 @@ class RootWriter(object): self.ev.mc.subtrack[i].start_time = subtrack.start_time self.ev.mc.subtrack[i].total_energy = subtrack.total_energy - ROOT.fill_hits(self.ev, len(pyev.hits.t), pyev.hits.t, pyev.hits.q, pyev.hits.histories) + ROOT.fill_hits(self.ev, len(pyev.channels.t), pyev.channels.t, pyev.channels.q, pyev.channels.histories) self.T.Fill() def close(self): |