summaryrefslogtreecommitdiff
path: root/fileio
diff options
context:
space:
mode:
authorStan Seibert <stan@mtrr.org>2011-08-23 19:28:36 -0400
committerStan Seibert <stan@mtrr.org>2011-08-23 19:28:36 -0400
commit2f0b5cd9b42d64b50bd123b87a0c91207d674dfa (patch)
tree0bcf8a1c5a4adc3399a143cc685ebfc636f4cf64 /fileio
parent20e90ebfea255c09a7e49204e6d94fdea352340a (diff)
downloadchroma-2f0b5cd9b42d64b50bd123b87a0c91207d674dfa.tar.gz
chroma-2f0b5cd9b42d64b50bd123b87a0c91207d674dfa.tar.bz2
chroma-2f0b5cd9b42d64b50bd123b87a0c91207d674dfa.zip
Add a RootReader class that also functions as an iterator, also create
a simple unit test for event reading and writing. There were several minor I/O bugs that are now fixed! Always test your code, kids!
Diffstat (limited to 'fileio')
-rw-r--r--fileio/root.C50
-rw-r--r--fileio/root.py137
2 files changed, 176 insertions, 11 deletions
diff --git a/fileio/root.C b/fileio/root.C
index 7052a71..07e50a1 100644
--- a/fileio/root.C
+++ b/fileio/root.C
@@ -11,15 +11,18 @@ struct Photon {
double wavelength; // nm
unsigned int history;
int last_hit_triangle;
+
+ ClassDef(Photon, 1);
};
struct Track {
std::string particle;
- double time;
TVector3 position;
TVector3 direction;
double start_time;
double total_energy;
+
+ ClassDef(Track, 1);
};
struct MC {
@@ -34,6 +37,7 @@ struct MC {
std::vector<Photon> photon_start;
std::vector<Photon> photon_stop;
+ ClassDef(MC, 1);
};
struct Channel {
@@ -42,23 +46,30 @@ struct Channel {
double time;
double charge;
unsigned int mc_history;
+
+ ClassDef(Channel, 1);
};
struct Event {
int event_id;
MC mc;
int nhit;
+ int max_channel_id;
std::vector<Channel> channel;
+ ClassDef(Event, 1);
+
// Populate arrays of length nentries with hit, time, and charge
// information, indexed by channel ID
void get_channels(unsigned int nentries, int *hit, float *time,
- float *charge)
+ float *charge, unsigned int *mc_history=0)
{
for (unsigned int i=0; i < nentries; i++) {
hit[i] = 0;
time[i] = -1e9f;
charge[i] = -1e9f;
+ if (mc_history)
+ mc_history[i] = 0;
}
for (unsigned int i=0; i < channel.size(); i++) {
@@ -68,18 +79,45 @@ struct Event {
hit[channel_id] = 1;
time[channel_id] = channel[i].time;
charge[channel_id] = channel[i].charge;
+ if (mc_history)
+ mc_history[channel_id] = channel[i].mc_history;
}
}
}
};
-void fill_photons(Event *ev, bool start,
+void get_photons(const std::vector<Photon> &photons, float *positions,
+ float *directions, float *polarizations, float *wavelengths,
+ float *times, unsigned int *histories, int *last_hit_triangles)
+{
+ for (unsigned int i=0; i < photons.size(); i++) {
+ const Photon &photon = photons[i];
+ positions[3*i] = photon.position.X();
+ positions[3*i+1] = photon.position.Y();
+ positions[3*i+2] = photon.position.Z();
+
+ directions[3*i] = photon.direction.X();
+ directions[3*i+1] = photon.direction.Y();
+ directions[3*i+2] = photon.direction.Z();
+
+ polarizations[3*i] = photon.polarization.X();
+ polarizations[3*i+1] = photon.polarization.Y();
+ polarizations[3*i+2] = photon.polarization.Z();
+
+ wavelengths[i] = photon.wavelength;
+ times[i] = photon.time;
+ histories[i] = photon.history;
+ last_hit_triangles[i] = photon.last_hit_triangle;
+ }
+}
+
+
+void fill_photons(std::vector<Photon> &photons,
unsigned int nphotons, float *pos, float *dir,
float *pol, float *wavelength, float *t0,
- int *histories=0, int *last_hit_triangle=0)
+ unsigned int *histories=0, int *last_hit_triangle=0)
{
- std::vector<Photon> &photons = start ? ev->mc.photon_start : ev->mc.photon_stop;
photons.resize(nphotons);
for (unsigned int i=0; i < nphotons; i++) {
@@ -107,6 +145,8 @@ void fill_hits(Event *ev, unsigned int nchannels, float *time,
{
ev->channel.resize(0);
ev->nhit = 0;
+ ev->max_channel_id = nchannels - 1;
+
Channel ch;
for (unsigned int i=0; i < nchannels; i++) {
if (time[i] < 1e8) {
diff --git a/fileio/root.py b/fileio/root.py
index 1fa2425..4c9d9bb 100644
--- a/fileio/root.py
+++ b/fileio/root.py
@@ -7,6 +7,129 @@ ROOT.gROOT.ProcessLine('.L '+os.path.join(os.path.dirname(__file__), 'root.C+g')
import ROOT
import chroma.event as event
+def tvector3_to_ndarray(vec):
+ '''Convert a ROOT.TVector3 into a numpy np.float32 array'''
+ return np.array((vec.X(), vec.Y(), vec.Z()), dtype=np.float32)
+
+def make_photon_with_arrays(size):
+ '''Returns a new chroma.event.Photons object for `size` number of
+ photons with empty arrays set for all the photon attributes.'''
+ return event.Photons(positions=np.empty((size,3), dtype=np.float32),
+ directions=np.empty((size,3), dtype=np.float32),
+ polarizations=np.empty((size,3), dtype=np.float32),
+ wavelengths=np.empty(size, dtype=np.float32),
+ times=np.empty(size, dtype=np.float32),
+ histories=np.empty(size, dtype=np.uint32),
+ last_hit_triangles=np.empty(size, dtype=np.int32))
+
+
+def root_event_to_python_event(ev):
+ '''Returns a new chroma.event.Event object created from the
+ contents of the ROOT event `ev`.'''
+ pyev = event.Event(ev.event_id)
+
+ # MC
+ pyev.particle_name = str(ev.mc.particle)
+ pyev.gen_position = tvector3_to_ndarray(ev.mc.gen_position)
+ pyev.gen_direction = tvector3_to_ndarray(ev.mc.gen_direction)
+ pyev.gen_total_energy = ev.mc.gen_total_energy
+
+ pyev.nphoton = ev.mc.nphoton
+
+ for subtrack in ev.mc.subtrack:
+ pysubtrack = event.Subtrack(str(subtrack.particle),
+ tvector3_to_ndarray(subtrack.position),
+ tvector3_to_ndarray(subtrack.direction),
+ subtrack.start_time,
+ subtrack.total_energy)
+ pyev.subtracks.append(pysubtrack)
+
+ # photon start
+ if ev.mc.photon_start.size() > 0:
+ photons = make_photon_with_arrays(ev.mc.photon_start.size())
+ ROOT.get_photons(ev.mc.photon_start, photons.positions.ravel(), photons.directions.ravel(),
+ photons.polarizations.ravel(), photons.wavelengths, photons.times,
+ photons.histories, photons.last_hit_triangles)
+ pyev.photon_start = photons
+
+ # photon stop
+ if ev.mc.photon_stop.size() > 0:
+ photons = make_photon_with_arrays(ev.mc.photon_stop.size())
+ ROOT.get_photons(ev.mc.photon_stop, photons.positions.ravel(), photons.directions.ravel(),
+ photons.polarizations.ravel(), photons.wavelengths, photons.times,
+ photons.histories, photons.last_hit_triangles)
+ pyev.photon_stop = photons
+
+ # hits
+ max_channel_id = ev.max_channel_id
+ hit = np.empty(shape=max_channel_id+1, dtype=np.int32)
+ t = np.empty(shape=max_channel_id+1, dtype=np.float32)
+ q = np.empty(shape=max_channel_id+1, dtype=np.float32)
+ histories = np.empty(shape=max_channel_id+1, dtype=np.uint32)
+
+ ev.get_channels(max_channel_id+1, hit, t, q, histories)
+ pyev.channels = event.Channels(hit.astype(bool), t, q, histories)
+ return pyev
+
+class RootReader(object):
+ '''Reader of Chroma events from a ROOT file. This class can be used to
+ navigate up and down the file linearly or in a random access fashion.
+ All returned events are instances of the chroma.event.Event class.
+
+ It implements the iterator protocol, so you can do
+
+ for ev in RootReader('electron.root'):
+ # process event here
+ '''
+
+ def __init__(self, filename):
+ '''Open ROOT file named `filename` containing TTree `T`.'''
+ self.f = ROOT.TFile(filename)
+ self.T = self.f.T
+ self.i = -1
+
+ def __len__(self):
+ '''Returns number of events in this file.'''
+ return self.T.GetEntries()
+
+ def next(self):
+ '''Return the next event in the file. Raises StopIteration
+ when you get to the end.'''
+ if self.i + 1 >= len(self):
+ raise StopIteration
+
+ self.i += 1
+ self.T.GetEntry(self.i)
+ return root_event_to_python_event(self.T.ev)
+
+ def prev(self):
+ '''Return the next event in the file. Raises StopIteration if
+ that would go past the beginning.'''
+ if self.i <= 0:
+ self.i = -1
+ raise StopIteration
+
+ self.i -= 1
+ self.T.GetEntry(self.i)
+ return root_event_to_python_event(self.T.ev)
+
+ def current(self):
+ '''Return the current event in the file.'''
+ self.T.GetEntry(self.i) # just in case?
+ return root_event_to_python_event(self.T.ev)
+
+ def jump_to(self, index):
+ '''Return the event at `index`. Updates current location.'''
+ if index < 0 or index >= len(self):
+ raise IndexError
+
+ self.T.GetEntry(self.i)
+ return root_event_to_python_event(self.T.ev)
+
+ def index(self):
+ '''Return the current event index'''
+ return self.i
+
class RootWriter(object):
def __init__(self, filename):
self.filename = filename
@@ -28,20 +151,22 @@ class RootWriter(object):
if pyev.photon_start is not None:
photons = pyev.photon_start
- ROOT.fill_photons(self.ev, True,
+ ROOT.fill_photons(self.ev.mc.photon_start,
len(photons.positions),
np.ravel(photons.positions),
np.ravel(photons.directions),
np.ravel(photons.polarizations),
- photons.wavelengths, photons.times)
+ photons.wavelengths, photons.times,
+ photons.histories, photons.last_hit_triangles)
if pyev.photon_stop is not None:
- photons = photon_stop
- ROOT.fill_photons(self.ev, True,
+ photons = pyev.photon_stop
+ ROOT.fill_photons(self.ev.mc.photon_stop,
len(photons.positions),
np.ravel(photons.positions),
np.ravel(photons.directions),
np.ravel(photons.polarizations),
- photons.wavelengths, photons.times)
+ photons.wavelengths, photons.times,
+ photons.histories, photons.last_hit_triangles)
self.ev.mc.subtrack.resize(0)
if pyev.subtracks is not None:
@@ -53,7 +178,7 @@ class RootWriter(object):
self.ev.mc.subtrack[i].start_time = subtrack.start_time
self.ev.mc.subtrack[i].total_energy = subtrack.total_energy
- ROOT.fill_hits(self.ev, len(pyev.hits.t), pyev.hits.t, pyev.hits.q, pyev.hits.histories)
+ ROOT.fill_hits(self.ev, len(pyev.channels.t), pyev.channels.t, pyev.channels.q, pyev.channels.histories)
self.T.Fill()
def close(self):