From ba3ef49e9822373733a15e083c41024ffc0cc3cf Mon Sep 17 00:00:00 2001 From: Stan Seibert Date: Fri, 16 Sep 2011 20:58:13 -0400 Subject: Rename chroma.fileio to chroma.io --- .rootlogon.C | 2 +- chroma/__init__.py | 2 +- chroma/camera.py | 2 +- chroma/fileio/__init__.py | 0 chroma/fileio/root.py | 199 ---------------------------------------------- chroma/io/__init__.py | 0 chroma/io/root.C | 147 ++++++++++++++++++++++++++++++++++ chroma/io/root.py | 199 ++++++++++++++++++++++++++++++++++++++++++++++ src/root.C | 147 ---------------------------------- test/test_fileio.py | 74 ----------------- test/test_io.py | 74 +++++++++++++++++ 11 files changed, 423 insertions(+), 423 deletions(-) delete mode 100644 chroma/fileio/__init__.py delete mode 100644 chroma/fileio/root.py create mode 100644 chroma/io/__init__.py create mode 100644 chroma/io/root.C create mode 100644 chroma/io/root.py delete mode 100644 src/root.C delete mode 100644 test/test_fileio.py create mode 100644 test/test_io.py diff --git a/.rootlogon.C b/.rootlogon.C index 17915c5..ce0159c 100644 --- a/.rootlogon.C +++ b/.rootlogon.C @@ -1,3 +1,3 @@ { - gROOT->ProcessLine(".L fileio/root.C+g"); + //gROOT->ProcessLine(".L fileio/root.C+g"); } diff --git a/chroma/__init__.py b/chroma/__init__.py index 3359bbc..bc6d5cc 100644 --- a/chroma/__init__.py +++ b/chroma/__init__.py @@ -1,7 +1,7 @@ from camera import Camera, EventViewer, view, build import geometry import event -from fileio import root +from io import root import generator import gpu import itertoolset diff --git a/chroma/camera.py b/chroma/camera.py index 470a1b9..575d13d 100644 --- a/chroma/camera.py +++ b/chroma/camera.py @@ -18,7 +18,7 @@ from chroma.transform import rotate, make_rotation_matrix from chroma.sample import uniform_sphere from chroma.optics import vacuum from chroma.project import from_film -from chroma.fileio.root import RootReader +from chroma.io.root import RootReader from chroma import make from chroma import gpu diff --git a/chroma/fileio/__init__.py b/chroma/fileio/__init__.py deleted file mode 100644 index e69de29..0000000 diff --git a/chroma/fileio/root.py b/chroma/fileio/root.py deleted file mode 100644 index af86deb..0000000 --- a/chroma/fileio/root.py +++ /dev/null @@ -1,199 +0,0 @@ -import ROOT -import os.path -import numpy as np - -ROOT.gROOT.ProcessLine('.L '+os.path.join(os.path.dirname(__file__), 'root.C+g')) - -import ROOT -import chroma.event as event - -def tvector3_to_ndarray(vec): - '''Convert a ROOT.TVector3 into a numpy np.float32 array''' - return np.array((vec.X(), vec.Y(), vec.Z()), dtype=np.float32) - -def make_photon_with_arrays(size): - '''Returns a new chroma.event.Photons object for `size` number of - photons with empty arrays set for all the photon attributes.''' - return event.Photons(pos=np.empty((size,3), dtype=np.float32), - dir=np.empty((size,3), dtype=np.float32), - pol=np.empty((size,3), dtype=np.float32), - wavelengths=np.empty(size, dtype=np.float32), - t=np.empty(size, dtype=np.float32), - flags=np.empty(size, dtype=np.uint32), - last_hit_triangles=np.empty(size, dtype=np.int32)) - -def root_vertex_to_python_vertex(vertex): - "Returns a chroma.event.Vertex object from a root Vertex object." - return event.Vertex(str(vertex.particle_name), - pos=tvector3_to_ndarray(vertex.pos), - dir=tvector3_to_ndarray(vertex.dir), - ke=vertex.ke, - t0=vertex.t0, - pol=tvector3_to_ndarray(vertex.pol)) - -def root_event_to_python_event(ev): - '''Returns a new chroma.event.Event object created from the - contents of the ROOT event `ev`.''' - pyev = event.Event(ev.id) - pyev.primary_vertex = root_vertex_to_python_vertex(ev.primary_vertex) - - for vertex in ev.vertices: - pyev.vertices.append(root_vertex_to_python_vertex(vertex)) - - # photon begin - if ev.photons_beg.size() > 0: - photons = make_photon_with_arrays(ev.photons_beg.size()) - ROOT.get_photons(ev.photons_beg, - photons.pos.ravel(), - photons.dir.ravel(), - photons.pol.ravel(), - photons.wavelengths, - photons.t, - photons.last_hit_triangles, - photons.flags) - pyev.photons_beg = photons - - # photon end - if ev.photons_end.size() > 0: - photons = make_photon_with_arrays(ev.photons_end.size()) - ROOT.get_photons(ev.photons_end, - photons.pos.ravel(), - photons.dir.ravel(), - photons.pol.ravel(), - photons.wavelengths, - photons.t, - photons.last_hit_triangles, - photons.flags) - pyev.photons_end = photons - - # channels - hit = np.empty(ev.nchannels, dtype=np.int32) - t = np.empty(ev.nchannels, dtype=np.float32) - q = np.empty(ev.nchannels, dtype=np.float32) - flags = np.empty(ev.nchannels, dtype=np.uint32) - - ROOT.get_channels(ev, hit, t, q, flags) - pyev.channels = event.Channels(hit.astype(bool), t, q, flags) - return pyev - -class RootReader(object): - '''Reader of Chroma events from a ROOT file. This class can be used to - navigate up and down the file linearly or in a random access fashion. - All returned events are instances of the chroma.event.Event class. - - It implements the iterator protocol, so you can do - - for ev in RootReader('electron.root'): - # process event here - ''' - - def __init__(self, filename): - '''Open ROOT file named `filename` containing TTree `T`.''' - self.f = ROOT.TFile(filename) - self.T = self.f.T - self.i = -1 - - def __len__(self): - '''Returns number of events in this file.''' - return self.T.GetEntries() - - def next(self): - '''Return the next event in the file. Raises StopIteration - when you get to the end.''' - if self.i + 1 >= len(self): - raise StopIteration - - self.i += 1 - self.T.GetEntry(self.i) - return root_event_to_python_event(self.T.ev) - - def prev(self): - '''Return the next event in the file. Raises StopIteration if - that would go past the beginning.''' - if self.i <= 0: - self.i = -1 - raise StopIteration - - self.i -= 1 - self.T.GetEntry(self.i) - return root_event_to_python_event(self.T.ev) - - def current(self): - '''Return the current event in the file.''' - self.T.GetEntry(self.i) # just in case? - return root_event_to_python_event(self.T.ev) - - def jump_to(self, index): - '''Return the event at `index`. Updates current location.''' - if index < 0 or index >= len(self): - raise IndexError - - self.T.GetEntry(self.i) - return root_event_to_python_event(self.T.ev) - - def index(self): - '''Return the current event index''' - return self.i - -class RootWriter(object): - def __init__(self, filename): - self.filename = filename - self.file = ROOT.TFile(filename, 'RECREATE') - - self.T = ROOT.TTree('T', 'Chroma events') - self.ev = ROOT.Event() - self.T.Branch('ev', self.ev) - - def write_event(self, pyev): - "Write an event.Event object to the ROOT tree as a ROOT.Event object." - self.ev.id = pyev.id - - if pyev.primary_vertex is not None: - self.ev.primary_vertex.particle_name = \ - pyev.primary_vertex.particle_name - self.ev.primary_vertex.pos.SetXYZ(*pyev.primary_vertex.pos) - self.ev.primary_vertex.dir.SetXYZ(*pyev.primary_vertex.dir) - if pyev.primary_vertex.pol is not None: - self.ev.primary_vertex.pol.SetXYZ(*pyev.primary_vertex.pol) - self.ev.primary_vertex.ke = pyev.primary_vertex.ke - - if pyev.photons_beg is not None: - photons = pyev.photons_beg - ROOT.fill_photons(self.ev.photons_beg, - len(photons.pos), - photons.pos.ravel(), - photons.dir.ravel(), - photons.pol.ravel(), - photons.wavelengths, photons.t, - photons.last_hit_triangles, photons.flags) - - if pyev.photons_end is not None: - photons = pyev.photons_end - ROOT.fill_photons(self.ev.photons_end, - len(photons.pos), - photons.pos.ravel(), - photons.dir.ravel(), - photons.pol.ravel(), - photons.wavelengths, photons.t, - photons.last_hit_triangles, photons.flags) - - self.ev.vertices.resize(0) - if pyev.vertices is not None: - self.ev.vertices.resize(len(pyev.vertices)) - for i, vertex in enumerate(pyev.vertices): - self.ev.vertices[i].particle_name = vertex.particle_name - self.ev.vertices[i].pos.SetXYZ(*vertex.pos) - self.ev.vertices[i].dir.SetXYZ(*vertex.dir) - if vertex.pol is not None: - self.ev.vertices[i].pol.SetXYZ(*vertex.pol) - self.ev.vertices[i].ke = vertex.ke - self.ev.vertices[i].t0 = vertex.t0 - - if pyev.channels is not None: - ROOT.fill_channels(self.ev, np.count_nonzero(pyev.channels.hit), np.arange(len(pyev.channels.t))[pyev.channels.hit].astype(np.int32), pyev.channels.t, pyev.channels.q, pyev.channels.flags, len(pyev.channels.hit)) - - self.T.Fill() - - def close(self): - self.T.Write() - self.file.Close() diff --git a/chroma/io/__init__.py b/chroma/io/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/chroma/io/root.C b/chroma/io/root.C new file mode 100644 index 0000000..2b39b1b --- /dev/null +++ b/chroma/io/root.C @@ -0,0 +1,147 @@ +#include +#include +#include +#include + +struct Vertex { + std::string particle_name; + TVector3 pos; + TVector3 dir; + TVector3 pol; + double ke; + double t0; + + ClassDef(Vertex, 1); +}; + +struct Photon { + double t; + TVector3 pos; + TVector3 dir; + TVector3 pol; + double wavelength; // nm + unsigned int flag; + int last_hit_triangle; + + ClassDef(Photon, 1); +}; + +struct Channel { + Channel() : id(-1), t(-1e9), q(-1e9) { }; + int id; + double t; + double q; + unsigned int flag; + + ClassDef(Channel, 1); +}; + +struct Event { + int id; + unsigned int nhit; + unsigned int nchannels; + + Vertex primary_vertex; + + std::vector vertices; + std::vector photons_beg; + std::vector photons_end; + std::vector channels; + + ClassDef(Event, 1); +}; + +void fill_channels(Event *ev, unsigned int nhit, unsigned int *ids, float *t, + float *q, unsigned int *flags, unsigned int nchannels) +{ + ev->nhit = 0; + ev->nchannels = nchannels; + ev->channels.resize(0); + + Channel ch; + unsigned int id; + for (unsigned int i=0; i < nhit; i++) { + ev->nhit++; + id = ids[i]; + ch.id = id; + ch.t = t[id]; + ch.q = q[id]; + ch.flag = flags[id]; + ev->channels.push_back(ch); + } +} + +void get_channels(Event *ev, int *hit, float *t, float *q, unsigned int *flags) +{ + for (unsigned int i=0; i < ev->nchannels; i++) { + hit[i] = 0; + t[i] = -1e9f; + q[i] = -1e9f; + flags[i] = 0; + } + + unsigned int id; + for (unsigned int i=0; i < ev->channels.size(); i++) { + id = ev->channels[i].id; + + if (id < ev->nchannels) { + hit[id] = 1; + t[id] = ev->channels[i].t; + q[id] = ev->channels[i].q; + flags[id] = ev->channels[i].flag; + } + } +} + +void get_photons(const std::vector &photons, float *pos, float *dir, + float *pol, float *wavelengths, float *t, + int *last_hit_triangles, unsigned int *flags) +{ + for (unsigned int i=0; i < photons.size(); i++) { + const Photon &photon = photons[i]; + pos[3*i] = photon.pos.X(); + pos[3*i+1] = photon.pos.Y(); + pos[3*i+2] = photon.pos.Z(); + + dir[3*i] = photon.dir.X(); + dir[3*i+1] = photon.dir.Y(); + dir[3*i+2] = photon.dir.Z(); + + pol[3*i] = photon.pol.X(); + pol[3*i+1] = photon.pol.Y(); + pol[3*i+2] = photon.pol.Z(); + + wavelengths[i] = photon.wavelength; + t[i] = photon.t; + flags[i] = photon.flag; + last_hit_triangles[i] = photon.last_hit_triangle; + } +} + +void fill_photons(std::vector &photons, + unsigned int nphotons, float *pos, float *dir, + float *pol, float *wavelengths, float *t, + int *last_hit_triangles, unsigned int *flags) +{ + photons.resize(nphotons); + + for (unsigned int i=0; i < nphotons; i++) { + Photon &photon = photons[i]; + photon.t = t[i]; + photon.pos.SetXYZ(pos[3*i], pos[3*i + 1], pos[3*i + 2]); + photon.dir.SetXYZ(dir[3*i], dir[3*i + 1], dir[3*i + 2]); + photon.pol.SetXYZ(pol[3*i], pol[3*i + 1], pol[3*i + 2]); + photon.wavelength = wavelengths[i]; + photon.last_hit_triangle = last_hit_triangles[i]; + photon.flag = flags[i]; + + } +} + +#ifdef __MAKECINT__ +#pragma link C++ class vector; +#pragma link C++ class vector; +#pragma link C++ class vector; +#endif + + diff --git a/chroma/io/root.py b/chroma/io/root.py new file mode 100644 index 0000000..bb4b39a --- /dev/null +++ b/chroma/io/root.py @@ -0,0 +1,199 @@ +import ROOT +import os.path +import numpy as np + +ROOT.gROOT.ProcessLine('.L '+os.path.join(os.path.dirname(__file__), 'root.C')) + +import ROOT +import chroma.event as event + +def tvector3_to_ndarray(vec): + '''Convert a ROOT.TVector3 into a numpy np.float32 array''' + return np.array((vec.X(), vec.Y(), vec.Z()), dtype=np.float32) + +def make_photon_with_arrays(size): + '''Returns a new chroma.event.Photons object for `size` number of + photons with empty arrays set for all the photon attributes.''' + return event.Photons(pos=np.empty((size,3), dtype=np.float32), + dir=np.empty((size,3), dtype=np.float32), + pol=np.empty((size,3), dtype=np.float32), + wavelengths=np.empty(size, dtype=np.float32), + t=np.empty(size, dtype=np.float32), + flags=np.empty(size, dtype=np.uint32), + last_hit_triangles=np.empty(size, dtype=np.int32)) + +def root_vertex_to_python_vertex(vertex): + "Returns a chroma.event.Vertex object from a root Vertex object." + return event.Vertex(str(vertex.particle_name), + pos=tvector3_to_ndarray(vertex.pos), + dir=tvector3_to_ndarray(vertex.dir), + ke=vertex.ke, + t0=vertex.t0, + pol=tvector3_to_ndarray(vertex.pol)) + +def root_event_to_python_event(ev): + '''Returns a new chroma.event.Event object created from the + contents of the ROOT event `ev`.''' + pyev = event.Event(ev.id) + pyev.primary_vertex = root_vertex_to_python_vertex(ev.primary_vertex) + + for vertex in ev.vertices: + pyev.vertices.append(root_vertex_to_python_vertex(vertex)) + + # photon begin + if ev.photons_beg.size() > 0: + photons = make_photon_with_arrays(ev.photons_beg.size()) + ROOT.get_photons(ev.photons_beg, + photons.pos.ravel(), + photons.dir.ravel(), + photons.pol.ravel(), + photons.wavelengths, + photons.t, + photons.last_hit_triangles, + photons.flags) + pyev.photons_beg = photons + + # photon end + if ev.photons_end.size() > 0: + photons = make_photon_with_arrays(ev.photons_end.size()) + ROOT.get_photons(ev.photons_end, + photons.pos.ravel(), + photons.dir.ravel(), + photons.pol.ravel(), + photons.wavelengths, + photons.t, + photons.last_hit_triangles, + photons.flags) + pyev.photons_end = photons + + # channels + hit = np.empty(ev.nchannels, dtype=np.int32) + t = np.empty(ev.nchannels, dtype=np.float32) + q = np.empty(ev.nchannels, dtype=np.float32) + flags = np.empty(ev.nchannels, dtype=np.uint32) + + ROOT.get_channels(ev, hit, t, q, flags) + pyev.channels = event.Channels(hit.astype(bool), t, q, flags) + return pyev + +class RootReader(object): + '''Reader of Chroma events from a ROOT file. This class can be used to + navigate up and down the file linearly or in a random access fashion. + All returned events are instances of the chroma.event.Event class. + + It implements the iterator protocol, so you can do + + for ev in RootReader('electron.root'): + # process event here + ''' + + def __init__(self, filename): + '''Open ROOT file named `filename` containing TTree `T`.''' + self.f = ROOT.TFile(filename) + self.T = self.f.T + self.i = -1 + + def __len__(self): + '''Returns number of events in this file.''' + return self.T.GetEntries() + + def next(self): + '''Return the next event in the file. Raises StopIteration + when you get to the end.''' + if self.i + 1 >= len(self): + raise StopIteration + + self.i += 1 + self.T.GetEntry(self.i) + return root_event_to_python_event(self.T.ev) + + def prev(self): + '''Return the next event in the file. Raises StopIteration if + that would go past the beginning.''' + if self.i <= 0: + self.i = -1 + raise StopIteration + + self.i -= 1 + self.T.GetEntry(self.i) + return root_event_to_python_event(self.T.ev) + + def current(self): + '''Return the current event in the file.''' + self.T.GetEntry(self.i) # just in case? + return root_event_to_python_event(self.T.ev) + + def jump_to(self, index): + '''Return the event at `index`. Updates current location.''' + if index < 0 or index >= len(self): + raise IndexError + + self.T.GetEntry(self.i) + return root_event_to_python_event(self.T.ev) + + def index(self): + '''Return the current event index''' + return self.i + +class RootWriter(object): + def __init__(self, filename): + self.filename = filename + self.file = ROOT.TFile(filename, 'RECREATE') + + self.T = ROOT.TTree('T', 'Chroma events') + self.ev = ROOT.Event() + self.T.Branch('ev', self.ev) + + def write_event(self, pyev): + "Write an event.Event object to the ROOT tree as a ROOT.Event object." + self.ev.id = pyev.id + + if pyev.primary_vertex is not None: + self.ev.primary_vertex.particle_name = \ + pyev.primary_vertex.particle_name + self.ev.primary_vertex.pos.SetXYZ(*pyev.primary_vertex.pos) + self.ev.primary_vertex.dir.SetXYZ(*pyev.primary_vertex.dir) + if pyev.primary_vertex.pol is not None: + self.ev.primary_vertex.pol.SetXYZ(*pyev.primary_vertex.pol) + self.ev.primary_vertex.ke = pyev.primary_vertex.ke + + if pyev.photons_beg is not None: + photons = pyev.photons_beg + ROOT.fill_photons(self.ev.photons_beg, + len(photons.pos), + photons.pos.ravel(), + photons.dir.ravel(), + photons.pol.ravel(), + photons.wavelengths, photons.t, + photons.last_hit_triangles, photons.flags) + + if pyev.photons_end is not None: + photons = pyev.photons_end + ROOT.fill_photons(self.ev.photons_end, + len(photons.pos), + photons.pos.ravel(), + photons.dir.ravel(), + photons.pol.ravel(), + photons.wavelengths, photons.t, + photons.last_hit_triangles, photons.flags) + + self.ev.vertices.resize(0) + if pyev.vertices is not None: + self.ev.vertices.resize(len(pyev.vertices)) + for i, vertex in enumerate(pyev.vertices): + self.ev.vertices[i].particle_name = vertex.particle_name + self.ev.vertices[i].pos.SetXYZ(*vertex.pos) + self.ev.vertices[i].dir.SetXYZ(*vertex.dir) + if vertex.pol is not None: + self.ev.vertices[i].pol.SetXYZ(*vertex.pol) + self.ev.vertices[i].ke = vertex.ke + self.ev.vertices[i].t0 = vertex.t0 + + if pyev.channels is not None: + ROOT.fill_channels(self.ev, np.count_nonzero(pyev.channels.hit), np.arange(len(pyev.channels.t))[pyev.channels.hit].astype(np.int32), pyev.channels.t, pyev.channels.q, pyev.channels.flags, len(pyev.channels.hit)) + + self.T.Fill() + + def close(self): + self.T.Write() + self.file.Close() diff --git a/src/root.C b/src/root.C deleted file mode 100644 index 2b39b1b..0000000 --- a/src/root.C +++ /dev/null @@ -1,147 +0,0 @@ -#include -#include -#include -#include - -struct Vertex { - std::string particle_name; - TVector3 pos; - TVector3 dir; - TVector3 pol; - double ke; - double t0; - - ClassDef(Vertex, 1); -}; - -struct Photon { - double t; - TVector3 pos; - TVector3 dir; - TVector3 pol; - double wavelength; // nm - unsigned int flag; - int last_hit_triangle; - - ClassDef(Photon, 1); -}; - -struct Channel { - Channel() : id(-1), t(-1e9), q(-1e9) { }; - int id; - double t; - double q; - unsigned int flag; - - ClassDef(Channel, 1); -}; - -struct Event { - int id; - unsigned int nhit; - unsigned int nchannels; - - Vertex primary_vertex; - - std::vector vertices; - std::vector photons_beg; - std::vector photons_end; - std::vector channels; - - ClassDef(Event, 1); -}; - -void fill_channels(Event *ev, unsigned int nhit, unsigned int *ids, float *t, - float *q, unsigned int *flags, unsigned int nchannels) -{ - ev->nhit = 0; - ev->nchannels = nchannels; - ev->channels.resize(0); - - Channel ch; - unsigned int id; - for (unsigned int i=0; i < nhit; i++) { - ev->nhit++; - id = ids[i]; - ch.id = id; - ch.t = t[id]; - ch.q = q[id]; - ch.flag = flags[id]; - ev->channels.push_back(ch); - } -} - -void get_channels(Event *ev, int *hit, float *t, float *q, unsigned int *flags) -{ - for (unsigned int i=0; i < ev->nchannels; i++) { - hit[i] = 0; - t[i] = -1e9f; - q[i] = -1e9f; - flags[i] = 0; - } - - unsigned int id; - for (unsigned int i=0; i < ev->channels.size(); i++) { - id = ev->channels[i].id; - - if (id < ev->nchannels) { - hit[id] = 1; - t[id] = ev->channels[i].t; - q[id] = ev->channels[i].q; - flags[id] = ev->channels[i].flag; - } - } -} - -void get_photons(const std::vector &photons, float *pos, float *dir, - float *pol, float *wavelengths, float *t, - int *last_hit_triangles, unsigned int *flags) -{ - for (unsigned int i=0; i < photons.size(); i++) { - const Photon &photon = photons[i]; - pos[3*i] = photon.pos.X(); - pos[3*i+1] = photon.pos.Y(); - pos[3*i+2] = photon.pos.Z(); - - dir[3*i] = photon.dir.X(); - dir[3*i+1] = photon.dir.Y(); - dir[3*i+2] = photon.dir.Z(); - - pol[3*i] = photon.pol.X(); - pol[3*i+1] = photon.pol.Y(); - pol[3*i+2] = photon.pol.Z(); - - wavelengths[i] = photon.wavelength; - t[i] = photon.t; - flags[i] = photon.flag; - last_hit_triangles[i] = photon.last_hit_triangle; - } -} - -void fill_photons(std::vector &photons, - unsigned int nphotons, float *pos, float *dir, - float *pol, float *wavelengths, float *t, - int *last_hit_triangles, unsigned int *flags) -{ - photons.resize(nphotons); - - for (unsigned int i=0; i < nphotons; i++) { - Photon &photon = photons[i]; - photon.t = t[i]; - photon.pos.SetXYZ(pos[3*i], pos[3*i + 1], pos[3*i + 2]); - photon.dir.SetXYZ(dir[3*i], dir[3*i + 1], dir[3*i + 2]); - photon.pol.SetXYZ(pol[3*i], pol[3*i + 1], pol[3*i + 2]); - photon.wavelength = wavelengths[i]; - photon.last_hit_triangle = last_hit_triangles[i]; - photon.flag = flags[i]; - - } -} - -#ifdef __MAKECINT__ -#pragma link C++ class vector; -#pragma link C++ class vector; -#pragma link C++ class vector; -#endif - - diff --git a/test/test_fileio.py b/test/test_fileio.py deleted file mode 100644 index 3869a9f..0000000 --- a/test/test_fileio.py +++ /dev/null @@ -1,74 +0,0 @@ -import unittest -from chroma.fileio import root -from chroma import event -import numpy as np - -class TestFileIO(unittest.TestCase): - def test_file_write_and_read(self): - ev = event.Event(1, event.Vertex('e-', pos=(0,0,1), dir=(1,0,0), - ke=15.0, pol=(0,1,0))) - - photons_beg = root.make_photon_with_arrays(1) - photons_beg.pos[0] = (1,2,3) - photons_beg.dir[0] = (4,5,6) - photons_beg.pol[0] = (7,8,9) - photons_beg.wavelengths[0] = 400.0 - photons_beg.t[0] = 100.0 - photons_beg.last_hit_triangles[0] = 5 - photons_beg.flags[0] = 20 - ev.photons_beg = photons_beg - - photons_end = root.make_photon_with_arrays(1) - photons_end.pos[0] = (1,2,3) - photons_end.dir[0] = (4,5,6) - photons_end.pol[0] = (7,8,9) - photons_end.wavelengths[0] = 400.0 - photons_end.t[0] = 100.0 - photons_end.last_hit_triangles[0] = 5 - photons_end.flags[0] = 20 - ev.photons_end = photons_end - - ev.vertices = [ev.primary_vertex] - - channels = event.Channels(hit=np.array([True, False]), - t=np.array([20.0, 1e9], dtype=np.float32), - q=np.array([2.0, 0.0], dtype=np.float32), - flags=np.array([8, 32], dtype=np.uint32)) - ev.channels = channels - - filename = '/tmp/chroma-filewritertest.root' - writer = root.RootWriter(filename) - writer.write_event(ev) - writer.close() - - # Exercise the RootReader methods - reader = root.RootReader(filename) - self.assertEquals(len(reader), 1) - - self.assertRaises(StopIteration, reader.prev) - - reader.next() - - self.assertEqual(reader.index(), 0) - self.assertRaises(StopIteration, reader.next) - - reader.jump_to(0) - - # Enough screwing around, let's get the one event in the file - newev = reader.current() - - # Now check if everything is correct in the event - for attribute in ['id']: - self.assertEqual(getattr(ev, attribute), getattr(newev, attribute), 'compare %s' % attribute) - - for attribute in ['pos', 'dir', 'pol', 'ke', 't0']: - self.assertTrue(np.allclose(getattr(ev.primary_vertex, attribute), getattr(newev.primary_vertex, attribute)), 'compare %s' % attribute) - - for i in range(len(ev.vertices)): - self.assertTrue(np.allclose(getattr(ev.vertices[i], attribute), getattr(newev.vertices[i], attribute)), 'compare %s' % attribute) - - for attribute in ['pos', 'dir', 'pol', 'wavelengths', 't', 'last_hit_triangles', 'flags']: - self.assertTrue(np.allclose(getattr(ev.photons_beg, attribute), - getattr(newev.photons_beg, attribute)), 'compare %s' % attribute) - self.assertTrue(np.allclose(getattr(ev.photons_end, attribute), - getattr(newev.photons_end, attribute)), 'compare %s' % attribute) diff --git a/test/test_io.py b/test/test_io.py new file mode 100644 index 0000000..3553058 --- /dev/null +++ b/test/test_io.py @@ -0,0 +1,74 @@ +import unittest +from chroma.io import root +from chroma import event +import numpy as np + +class TestRootIO(unittest.TestCase): + def test_file_write_and_read(self): + ev = event.Event(1, event.Vertex('e-', pos=(0,0,1), dir=(1,0,0), + ke=15.0, pol=(0,1,0))) + + photons_beg = root.make_photon_with_arrays(1) + photons_beg.pos[0] = (1,2,3) + photons_beg.dir[0] = (4,5,6) + photons_beg.pol[0] = (7,8,9) + photons_beg.wavelengths[0] = 400.0 + photons_beg.t[0] = 100.0 + photons_beg.last_hit_triangles[0] = 5 + photons_beg.flags[0] = 20 + ev.photons_beg = photons_beg + + photons_end = root.make_photon_with_arrays(1) + photons_end.pos[0] = (1,2,3) + photons_end.dir[0] = (4,5,6) + photons_end.pol[0] = (7,8,9) + photons_end.wavelengths[0] = 400.0 + photons_end.t[0] = 100.0 + photons_end.last_hit_triangles[0] = 5 + photons_end.flags[0] = 20 + ev.photons_end = photons_end + + ev.vertices = [ev.primary_vertex] + + channels = event.Channels(hit=np.array([True, False]), + t=np.array([20.0, 1e9], dtype=np.float32), + q=np.array([2.0, 0.0], dtype=np.float32), + flags=np.array([8, 32], dtype=np.uint32)) + ev.channels = channels + + filename = '/tmp/chroma-filewritertest.root' + writer = root.RootWriter(filename) + writer.write_event(ev) + writer.close() + + # Exercise the RootReader methods + reader = root.RootReader(filename) + self.assertEquals(len(reader), 1) + + self.assertRaises(StopIteration, reader.prev) + + reader.next() + + self.assertEqual(reader.index(), 0) + self.assertRaises(StopIteration, reader.next) + + reader.jump_to(0) + + # Enough screwing around, let's get the one event in the file + newev = reader.current() + + # Now check if everything is correct in the event + for attribute in ['id']: + self.assertEqual(getattr(ev, attribute), getattr(newev, attribute), 'compare %s' % attribute) + + for attribute in ['pos', 'dir', 'pol', 'ke', 't0']: + self.assertTrue(np.allclose(getattr(ev.primary_vertex, attribute), getattr(newev.primary_vertex, attribute)), 'compare %s' % attribute) + + for i in range(len(ev.vertices)): + self.assertTrue(np.allclose(getattr(ev.vertices[i], attribute), getattr(newev.vertices[i], attribute)), 'compare %s' % attribute) + + for attribute in ['pos', 'dir', 'pol', 'wavelengths', 't', 'last_hit_triangles', 'flags']: + self.assertTrue(np.allclose(getattr(ev.photons_beg, attribute), + getattr(newev.photons_beg, attribute)), 'compare %s' % attribute) + self.assertTrue(np.allclose(getattr(ev.photons_end, attribute), + getattr(newev.photons_end, attribute)), 'compare %s' % attribute) -- cgit