diff options
-rw-r--r-- | fileio/root.C | 50 | ||||
-rw-r--r-- | fileio/root.py | 137 | ||||
-rwxr-xr-x | sim.py | 3 | ||||
-rw-r--r-- | tests/test_fileio.py | 73 |
4 files changed, 250 insertions, 13 deletions
diff --git a/fileio/root.C b/fileio/root.C index 7052a71..07e50a1 100644 --- a/fileio/root.C +++ b/fileio/root.C @@ -11,15 +11,18 @@ struct Photon { double wavelength; // nm unsigned int history; int last_hit_triangle; + + ClassDef(Photon, 1); }; struct Track { std::string particle; - double time; TVector3 position; TVector3 direction; double start_time; double total_energy; + + ClassDef(Track, 1); }; struct MC { @@ -34,6 +37,7 @@ struct MC { std::vector<Photon> photon_start; std::vector<Photon> photon_stop; + ClassDef(MC, 1); }; struct Channel { @@ -42,23 +46,30 @@ struct Channel { double time; double charge; unsigned int mc_history; + + ClassDef(Channel, 1); }; struct Event { int event_id; MC mc; int nhit; + int max_channel_id; std::vector<Channel> channel; + ClassDef(Event, 1); + // Populate arrays of length nentries with hit, time, and charge // information, indexed by channel ID void get_channels(unsigned int nentries, int *hit, float *time, - float *charge) + float *charge, unsigned int *mc_history=0) { for (unsigned int i=0; i < nentries; i++) { hit[i] = 0; time[i] = -1e9f; charge[i] = -1e9f; + if (mc_history) + mc_history[i] = 0; } for (unsigned int i=0; i < channel.size(); i++) { @@ -68,18 +79,45 @@ struct Event { hit[channel_id] = 1; time[channel_id] = channel[i].time; charge[channel_id] = channel[i].charge; + if (mc_history) + mc_history[channel_id] = channel[i].mc_history; } } } }; -void fill_photons(Event *ev, bool start, +void get_photons(const std::vector<Photon> &photons, float *positions, + float *directions, float *polarizations, float *wavelengths, + float *times, unsigned int *histories, int *last_hit_triangles) +{ + for (unsigned int i=0; i < photons.size(); i++) { + const Photon &photon = photons[i]; + positions[3*i] = photon.position.X(); + positions[3*i+1] = photon.position.Y(); + positions[3*i+2] = photon.position.Z(); + + directions[3*i] = photon.direction.X(); + directions[3*i+1] = photon.direction.Y(); + directions[3*i+2] = photon.direction.Z(); + + polarizations[3*i] = photon.polarization.X(); + polarizations[3*i+1] = photon.polarization.Y(); + polarizations[3*i+2] = photon.polarization.Z(); + + wavelengths[i] = photon.wavelength; + times[i] = photon.time; + histories[i] = photon.history; + last_hit_triangles[i] = photon.last_hit_triangle; + } +} + + +void fill_photons(std::vector<Photon> &photons, unsigned int nphotons, float *pos, float *dir, float *pol, float *wavelength, float *t0, - int *histories=0, int *last_hit_triangle=0) + unsigned int *histories=0, int *last_hit_triangle=0) { - std::vector<Photon> &photons = start ? ev->mc.photon_start : ev->mc.photon_stop; photons.resize(nphotons); for (unsigned int i=0; i < nphotons; i++) { @@ -107,6 +145,8 @@ void fill_hits(Event *ev, unsigned int nchannels, float *time, { ev->channel.resize(0); ev->nhit = 0; + ev->max_channel_id = nchannels - 1; + Channel ch; for (unsigned int i=0; i < nchannels; i++) { if (time[i] < 1e8) { diff --git a/fileio/root.py b/fileio/root.py index 1fa2425..4c9d9bb 100644 --- a/fileio/root.py +++ b/fileio/root.py @@ -7,6 +7,129 @@ ROOT.gROOT.ProcessLine('.L '+os.path.join(os.path.dirname(__file__), 'root.C+g') import ROOT import chroma.event as event +def tvector3_to_ndarray(vec): + '''Convert a ROOT.TVector3 into a numpy np.float32 array''' + return np.array((vec.X(), vec.Y(), vec.Z()), dtype=np.float32) + +def make_photon_with_arrays(size): + '''Returns a new chroma.event.Photons object for `size` number of + photons with empty arrays set for all the photon attributes.''' + return event.Photons(positions=np.empty((size,3), dtype=np.float32), + directions=np.empty((size,3), dtype=np.float32), + polarizations=np.empty((size,3), dtype=np.float32), + wavelengths=np.empty(size, dtype=np.float32), + times=np.empty(size, dtype=np.float32), + histories=np.empty(size, dtype=np.uint32), + last_hit_triangles=np.empty(size, dtype=np.int32)) + + +def root_event_to_python_event(ev): + '''Returns a new chroma.event.Event object created from the + contents of the ROOT event `ev`.''' + pyev = event.Event(ev.event_id) + + # MC + pyev.particle_name = str(ev.mc.particle) + pyev.gen_position = tvector3_to_ndarray(ev.mc.gen_position) + pyev.gen_direction = tvector3_to_ndarray(ev.mc.gen_direction) + pyev.gen_total_energy = ev.mc.gen_total_energy + + pyev.nphoton = ev.mc.nphoton + + for subtrack in ev.mc.subtrack: + pysubtrack = event.Subtrack(str(subtrack.particle), + tvector3_to_ndarray(subtrack.position), + tvector3_to_ndarray(subtrack.direction), + subtrack.start_time, + subtrack.total_energy) + pyev.subtracks.append(pysubtrack) + + # photon start + if ev.mc.photon_start.size() > 0: + photons = make_photon_with_arrays(ev.mc.photon_start.size()) + ROOT.get_photons(ev.mc.photon_start, photons.positions.ravel(), photons.directions.ravel(), + photons.polarizations.ravel(), photons.wavelengths, photons.times, + photons.histories, photons.last_hit_triangles) + pyev.photon_start = photons + + # photon stop + if ev.mc.photon_stop.size() > 0: + photons = make_photon_with_arrays(ev.mc.photon_stop.size()) + ROOT.get_photons(ev.mc.photon_stop, photons.positions.ravel(), photons.directions.ravel(), + photons.polarizations.ravel(), photons.wavelengths, photons.times, + photons.histories, photons.last_hit_triangles) + pyev.photon_stop = photons + + # hits + max_channel_id = ev.max_channel_id + hit = np.empty(shape=max_channel_id+1, dtype=np.int32) + t = np.empty(shape=max_channel_id+1, dtype=np.float32) + q = np.empty(shape=max_channel_id+1, dtype=np.float32) + histories = np.empty(shape=max_channel_id+1, dtype=np.uint32) + + ev.get_channels(max_channel_id+1, hit, t, q, histories) + pyev.channels = event.Channels(hit.astype(bool), t, q, histories) + return pyev + +class RootReader(object): + '''Reader of Chroma events from a ROOT file. This class can be used to + navigate up and down the file linearly or in a random access fashion. + All returned events are instances of the chroma.event.Event class. + + It implements the iterator protocol, so you can do + + for ev in RootReader('electron.root'): + # process event here + ''' + + def __init__(self, filename): + '''Open ROOT file named `filename` containing TTree `T`.''' + self.f = ROOT.TFile(filename) + self.T = self.f.T + self.i = -1 + + def __len__(self): + '''Returns number of events in this file.''' + return self.T.GetEntries() + + def next(self): + '''Return the next event in the file. Raises StopIteration + when you get to the end.''' + if self.i + 1 >= len(self): + raise StopIteration + + self.i += 1 + self.T.GetEntry(self.i) + return root_event_to_python_event(self.T.ev) + + def prev(self): + '''Return the next event in the file. Raises StopIteration if + that would go past the beginning.''' + if self.i <= 0: + self.i = -1 + raise StopIteration + + self.i -= 1 + self.T.GetEntry(self.i) + return root_event_to_python_event(self.T.ev) + + def current(self): + '''Return the current event in the file.''' + self.T.GetEntry(self.i) # just in case? + return root_event_to_python_event(self.T.ev) + + def jump_to(self, index): + '''Return the event at `index`. Updates current location.''' + if index < 0 or index >= len(self): + raise IndexError + + self.T.GetEntry(self.i) + return root_event_to_python_event(self.T.ev) + + def index(self): + '''Return the current event index''' + return self.i + class RootWriter(object): def __init__(self, filename): self.filename = filename @@ -28,20 +151,22 @@ class RootWriter(object): if pyev.photon_start is not None: photons = pyev.photon_start - ROOT.fill_photons(self.ev, True, + ROOT.fill_photons(self.ev.mc.photon_start, len(photons.positions), np.ravel(photons.positions), np.ravel(photons.directions), np.ravel(photons.polarizations), - photons.wavelengths, photons.times) + photons.wavelengths, photons.times, + photons.histories, photons.last_hit_triangles) if pyev.photon_stop is not None: - photons = photon_stop - ROOT.fill_photons(self.ev, True, + photons = pyev.photon_stop + ROOT.fill_photons(self.ev.mc.photon_stop, len(photons.positions), np.ravel(photons.positions), np.ravel(photons.directions), np.ravel(photons.polarizations), - photons.wavelengths, photons.times) + photons.wavelengths, photons.times, + photons.histories, photons.last_hit_triangles) self.ev.mc.subtrack.resize(0) if pyev.subtracks is not None: @@ -53,7 +178,7 @@ class RootWriter(object): self.ev.mc.subtrack[i].start_time = subtrack.start_time self.ev.mc.subtrack[i].total_energy = subtrack.total_energy - ROOT.fill_hits(self.ev, len(pyev.hits.t), pyev.hits.t, pyev.hits.q, pyev.hits.histories) + ROOT.fill_hits(self.ev, len(pyev.channels.t), pyev.channels.t, pyev.channels.q, pyev.channels.histories) self.T.Fill() def close(self): @@ -82,8 +82,7 @@ class Simulation(object): ev.photon_stop = self.gpu_worker.get_photons() if run_daq: - ev.hits = self.gpu_worker.get_hits() - ev.channels = ev.hits + ev.channels = self.gpu_worker.get_hits() yield ev diff --git a/tests/test_fileio.py b/tests/test_fileio.py new file mode 100644 index 0000000..d21c9ff --- /dev/null +++ b/tests/test_fileio.py @@ -0,0 +1,73 @@ +import unittest +from chroma.fileio import root +from chroma import event +import numpy as np + +class TestFileIO(unittest.TestCase): + def test_file_write_and_read(self): + ev = event.Event(1, 'e-', (0,0,1), (1,0,0), 15) + + photon_start = root.make_photon_with_arrays(1) + photon_start.positions[0] = (1,2,3) + photon_start.directions[0] = (4,5,6) + photon_start.polarizations[0] = (7,8,9) + photon_start.wavelengths[0] = 400.0 + photon_start.times[0] = 100.0 + photon_start.histories[0] = 20 + photon_start.last_hit_triangles[0] = 5 + ev.photon_start = photon_start + + photon_stop = root.make_photon_with_arrays(1) + photon_stop.positions[0] = (1,2,3) + photon_stop.directions[0] = (4,5,6) + photon_stop.polarizations[0] = (7,8,9) + photon_stop.wavelengths[0] = 400.0 + photon_stop.times[0] = 100.0 + photon_stop.histories[0] = 20 + photon_stop.last_hit_triangles[0] = 5 + ev.photon_stop = photon_stop + + ev.nphoton = 1 + + ev.subtracks.append(event.Subtrack('e-', (40,30,20), (-1, -2, -3), 400, 800)) + + channels = event.Channels(hit=np.array([True, False]), + t=np.array([20.0, 1e9], dtype=np.float32), + q=np.array([2.0, 0.0], dtype=np.float32), + histories=np.array([8, 32], dtype=np.uint32)) + ev.channels = channels + + filename = '/tmp/chroma-filewritertest.root' + writer = root.RootWriter(filename) + writer.write_event(ev) + writer.close() + + # Exercise the RootReader methods + reader = root.RootReader(filename) + self.assertEquals(len(reader), 1) + + self.assertRaises(StopIteration, reader.prev) + + reader.next() + + self.assertEqual(reader.index(), 0) + self.assertRaises(StopIteration, reader.next) + + reader.jump_to(0) + + # Enough screwing around, let's get the one event in the file + newev = reader.current() + + + # Now check if everything is correct in the event + for attribute in ['event_id', 'particle_name','gen_total_energy']: + self.assertEqual(getattr(ev, attribute), getattr(newev, attribute), 'compare %s' % attribute) + for attribute in ['gen_position', 'gen_direction']: + self.assertTrue(np.allclose(getattr(ev, attribute), getattr(newev, attribute)), 'compare %s' % attribute) + + for attribute in ['positions', 'directions', 'wavelengths', 'polarizations', 'times', + 'histories', 'last_hit_triangles']: + self.assertTrue(np.allclose(getattr(ev.photon_start, attribute), + getattr(newev.photon_start, attribute)), 'compare %s' % attribute) + self.assertTrue(np.allclose(getattr(ev.photon_stop, attribute), + getattr(newev.photon_stop, attribute)), 'compare %s' % attribute) |