diff options
Diffstat (limited to 'sim.py')
-rwxr-xr-x | sim.py | 120 |
1 files changed, 27 insertions, 93 deletions
@@ -3,15 +3,14 @@ import sys import optparse import time import os -import multiprocessing +import numpy as np import detectors import optics +import generator +from generator import constant import gpu -import g4gen from fileio import root -import numpy as np -import math import ROOT def pick_seed(): @@ -32,56 +31,6 @@ def info(type, value, tb): # ...then start the debugger in post-mortem mode. pdb.pm() -class GeneratorProcess(multiprocessing.Process): - def __init__(self, particle, energy, position, direction, nevents, material, - queue, seed=None): - multiprocessing.Process.__init__(self) - - self.particle = particle - self.energy = energy - self.position = position - self.direction = direction - self.nevents = nevents - self.material = material - self.seed = seed - self.queue = queue - self.daemon = True - - def run(self): - print >>sys.stderr, 'Starting generator thread...' - generator = g4gen.G4Generator(self.material, seed=self.seed) - - for i in xrange(self.nevents): - if self.particle == 'pi0': - photons = generator.generate_pi0(total_energy=self.energy, - position=self.position, - direction=self.direction) - else: - photons = generator.generate_photons(particle_name=self.particle, - total_energy=self.energy, - position=self.position, - direction=self.direction) - self.queue.put(photons) - - -def partition(num, partitions): - '''Generator that returns num//partitions, with the last item including the remainder. - - Useful for partitioning a number into mostly equal parts while preserving the sum. - - >>> list(partition(800, 3)) - [266, 266, 268] - >>> sum(list(partition(800, 3))) - 800 - ''' - step = num // partitions - for i in xrange(partitions): - if i < partitions - 1: - yield step - else: - yield step + (num % partitions) - - # Allow profile decorator to exist, but do nothing if not running under kernprof try: @@ -132,25 +81,23 @@ def main(): print >>sys.stderr, 'RNG seed:', options.seed print >>sys.stderr, 'Creating generator...' + if options.particle == 'pi0': + vertex_generator = generator.vertex.pi0_gun(pi0_position=constant(position), + pi0_direction=constant(direction), + pi0_total_energy=constant(options.energy)) + else: + vertex_generator = generator.vertex.particle_gun(particle_name=constant(options.particle), + position=constant(position), + direction=constant(direction), + total_energy=constant(options.energy)) detector_material = optics.water_wcsim - queue = multiprocessing.Queue() - generators = [GeneratorProcess(particle=options.particle, - energy=options.energy, - position=position, - direction=direction, - nevents=nevents, - material=detector_material, - seed=options.seed + seed_offset, - queue=queue) - for seed_offset, nevents in - enumerate(partition(options.nevents, options.ngenerators))] - + photon_generator = generator.photon.G4ParallelGenerator(options.ngenerators, detector_material, + base_seed=options.seed) print >>sys.stderr, 'WARNING: ASSUMING DETECTOR IS WCSIM WATER!!' # Do this now so we can get ahead of the photon propagation print >>sys.stderr, 'Starting GEANT4 generators...' - for generator in generators: - generator.start() + event_iterator = photon_generator.generate_events(options.nevents, vertex_generator) print >>sys.stderr, 'Creating BVH for detector "%s" with %d bits...' % (options.detector, options.nbits) detector.build(bits=options.nbits) @@ -168,41 +115,28 @@ def main(): # Create output file writer = root.RootWriter(output_filename) - # Set generator info - writer.set_generated_particle(name=options.particle, position=position, - direction=direction, total_e=options.energy) - print >>sys.stderr, 'Starting simulation...' start_sim = time.time() nphotons = 0 - for i in xrange(options.nevents): - photons = queue.get() - assert len(photons['pos']) > 0, 'GEANT4 generated event with no photons!' + for i, ev in enumerate(event_iterator): + photons = ev.photon_start + assert len(photons.positions) > 0, 'GEANT4 generated event with no photons!' - nphotons += len(photons['pos']) + nphotons += len(photons.positions) + + gpu_worker.load_photons(photons) - gpu_worker.load_photons(pos=photons['pos'], dir=photons['dir'], pol=photons['pol'], - t0=photons['t0'], wavelength=photons['wavelength']) gpu_worker.propagate() gpu_worker.run_daq() - hits = gpu_worker.get_hits() - if options.save_photon_start: - photon_start = photons - else: - photon_start = None + ev.hits = gpu_worker.get_hits() + if not options.save_photon_start: + ev.photon_start = None if options.save_photon_stop: - photon_stop = gpu_worker.get_photons() - else: - photon_stop = None - - if 'subtracks' in photons: - subtracks = photons['subtracks'] - else: - subtracks = None - writer.write_event(i, hits, photon_start=photon_start, photon_stop=photon_stop, - subtracks=subtracks) + ev.photon_stop = gpu_worker.get_photons() + writer.write_event(ev) + if i % 10 == 0: print >>sys.stderr, "\rEvent:", i, |