diff options
Diffstat (limited to 'generator/g4gen.py')
-rw-r--r-- | generator/g4gen.py | 96 |
1 files changed, 38 insertions, 58 deletions
diff --git a/generator/g4gen.py b/generator/g4gen.py index 9650b68..8dca086 100644 --- a/generator/g4gen.py +++ b/generator/g4gen.py @@ -4,7 +4,7 @@ import g4py.NISTmaterials import g4py.ParticleGun import pyublas import numpy as np -import chroma.event as event +from chroma.event import Photons, Vertex try: import G4chroma @@ -20,16 +20,18 @@ except: class G4Generator(object): def __init__(self, material, seed=None): - '''Create generator to produce photons inside the specified material. + """Create generator to produce photons inside the specified material. material: chroma.geometry.Material object with density, composition dict and refractive_index. composition dictionary should be { element_symbol : fraction_by_weight, ... }. - seed: Random number generator seed for HepRandom. If None, - generator is not seeded. - ''' + + seed: int, *optional* + Random number generator seed for HepRandom. If None, generator + is not seeded. + """ if seed is not None: HepRandom.setTheSeed(seed) @@ -51,11 +53,8 @@ class G4Generator(object): gRunManager.SetUserAction(self.tracking_action) gRunManager.Initialize() - #preinitialize the process by running a simple event - self.generate_photons(event.Event(event_id=0, particle_name='e-', - gen_position=(0,0,0), - gen_direction=(1,0,0), - gen_total_energy=1.0)) + # preinitialize the process by running a simple event + self.generate_photons([Vertex('e-', (0,0,0), (1,0,0), 0, 1.0)]) def create_g4material(self, material): g4material = G4Material('world_material', material.density * g / cm3, @@ -96,57 +95,38 @@ class G4Generator(object): pol[:,2] = self.tracking_action.GetPolZ() wavelengths = self.tracking_action.GetWavelength().astype(np.float32) - times = self.tracking_action.GetT0().astype(np.float32) - return event.Photons(positions=pos, directions=dir, polarizations=pol, times=times, wavelengths=wavelengths) + t0 = self.tracking_action.GetT0().astype(np.float32) + + return Photons(pos, dir, pol, wavelengths, t0) - def generate_photons(self, ev): - '''Use GEANT4 to generate photons produced by the given particle. + def generate_photons(self, vertices): + """Use GEANT4 to generate photons produced by propagating `vertices`. - ev: a generator.event.Event object with the particle - properties set. If it contains subtracks, those - will be used to create the photon vertices rather - than the main particle. - - Returns an instance of event.Photons containing the - generated photon vertices for the primary particle or - all the subtracks, if present. - ''' - photons = [] - if ev.subtracks: - subtracks = ev.subtracks - else: - # Create temporary subtrack for single primary particle - subtracks = [event.Subtrack(particle_name=ev.particle_name, - position=ev.gen_position, - direction=ev.gen_direction, - start_time=0.0, - total_energy=ev.gen_total_energy)] - - for subtrack in subtracks: - self.particle_gun.SetParticleByName(subtrack.particle_name) - self.particle_gun.SetParticleEnergy(subtrack.total_energy * MeV) - self.particle_gun.SetParticlePosition(G4ThreeVector(*subtrack.position)*m) - self.particle_gun.SetParticleMomentumDirection(G4ThreeVector(*subtrack.direction).unit()) + Args: + vertices: list of event.Vertex objects + List of initial particle vertices. + + Returns: + photons: event.Photons + Photon vertices generated by the propagation of `vertices`. + """ + photons = None + + for vertex in vertices: + self.particle_gun.SetParticleByName(vertex.particle_name) + mass = G4ParticleTable.GetParticleTable().FindParticle(vertex.particle_name).GetPDGMass() + total_energy = vertex.ke*MeV + mass + self.particle_gun.SetParticleEnergy(total_energy) + self.particle_gun.SetParticlePosition(G4ThreeVector(*vertex.pos)*m) + self.particle_gun.SetParticleMomentumDirection(G4ThreeVector(*vertex.dir).unit()) self.tracking_action.Clear() gRunManager.BeamOn(1) - photons.append(self._extract_photons_from_tracking_action()) - - # Merge all photon lists into one big list - return event.concatenate_photons(photons) - -if __name__ == '__main__': - import time - import optics - gen = G4Generator(optics.water) - - start = time.time() - n = 0 - for i in xrange(100): - photons = gen.generate_photons(event.Event('mu-', (0,0,0), (1,0,0), 1.0)) - n += len(photons.times) - print photons.positions[0].min(), photons.positions[0].max() - stop = time.time() - print stop - start, 'sec' - print n / (stop-start), 'photons/sec' + + if photons is None: + photons = self._extract_photons_from_tracking_action() + else: + photons += self._extract_photons_from_tracking_action() + + return photons |