summaryrefslogtreecommitdiff
path: root/generator/g4gen.py
diff options
context:
space:
mode:
authorStan Seibert <stan@mtrr.org>2011-08-16 13:52:00 -0400
committerStan Seibert <stan@mtrr.org>2011-08-16 13:52:00 -0400
commit7d9b50e9e64c9d8d9a25942e2ffaca52142c6c2b (patch)
tree6eaf16ef125df0b02cff8198e6bece51a093c8fa /generator/g4gen.py
parent0dbfc2d7dc547452372d776ba74ec77838300a9a (diff)
downloadchroma-7d9b50e9e64c9d8d9a25942e2ffaca52142c6c2b.tar.gz
chroma-7d9b50e9e64c9d8d9a25942e2ffaca52142c6c2b.tar.bz2
chroma-7d9b50e9e64c9d8d9a25942e2ffaca52142c6c2b.zip
Epic restructuring of code to switch to a generator-based style of
event creation. Now we have vertex generators (that produce initial particles), photon generators (that create photons to propagate), and a standard data structure using Python class containers and numpy arrays to hand around the code. Also cleaned up some naming of things before they become conventions.
Diffstat (limited to 'generator/g4gen.py')
-rw-r--r--generator/g4gen.py147
1 files changed, 147 insertions, 0 deletions
diff --git a/generator/g4gen.py b/generator/g4gen.py
new file mode 100644
index 0000000..d4e79a1
--- /dev/null
+++ b/generator/g4gen.py
@@ -0,0 +1,147 @@
+from Geant4 import *
+import g4py.ezgeom
+import g4py.NISTmaterials
+import g4py.ParticleGun
+import pyublas
+import numpy as np
+import event
+
+try:
+ import G4chroma
+except:
+ # Try building the module
+ import subprocess
+ import sys, os
+ module_dir = os.path.split(os.path.realpath(__file__))[0]
+ print >>sys.stderr, 'Compiling G4chroma.so...'
+ retcode = subprocess.call('g++ -o \'%s/G4chroma.so\' -shared \'%s/G4chroma.cc\' -fPIC `geant4-config --cflags --libs` `python-config --cflags --libs --ldflags` -lboost_python' % (module_dir, module_dir), shell=True)
+ assert retcode == 0
+ import G4chroma
+
+class G4Generator(object):
+ def __init__(self, material, seed=None):
+ '''Create generator to produce photons inside the specified material.
+
+ material: chroma.geometry.Material object with density,
+ composition dict and refractive_index.
+
+ composition dictionary should be
+ { element_symbol : fraction_by_weight, ... }.
+ seed: Random number generator seed for HepRandom. If None,
+ generator is not seeded.
+ '''
+ if seed is not None:
+ HepRandom.setTheSeed(seed)
+
+ g4py.NISTmaterials.Construct()
+ g4py.ezgeom.Construct()
+ self.physics_list = G4chroma.ChromaPhysicsList()
+ gRunManager.SetUserInitialization(self.physics_list)
+ self.particle_gun = g4py.ParticleGun.Construct()
+
+ self.world_material = self.create_g4material(material)
+ g4py.ezgeom.SetWorldMaterial(self.world_material)
+
+ self.world = g4py.ezgeom.G4EzVolume('world')
+ self.world.CreateBoxVolume(self.world_material, 100*m, 100*m, 100*m)
+ self.world.PlaceIt(G4ThreeVector(0,0,0))
+
+ self.tracking_action = G4chroma.PhotonTrackingAction()
+ gRunManager.SetUserAction(self.tracking_action)
+ gRunManager.Initialize()
+
+ def create_g4material(self, material):
+ g4material = G4Material('world_material', material.density * g / cm3,
+ len(material.composition))
+
+ # Add elements
+ for element_name, element_frac_by_weight in material.composition.items():
+ g4material.AddElement(G4Element.GetElement(element_name, True),
+ element_frac_by_weight)
+
+ # Set index of refraction
+ prop_table = G4MaterialPropertiesTable()
+ # Reverse entries so they are in ascending energy order rather
+ # than wavelength
+ energy = list((2*pi*hbarc / (material.refractive_index[::-1,0] * nanometer)).astype(float))
+ values = list(material.refractive_index[::-1, 1].astype(float))
+ prop_table.AddProperty('RINDEX', energy, values)
+
+ # Load properties
+ g4material.SetMaterialPropertiesTable(prop_table)
+ return g4material
+
+ def _extract_photons_from_tracking_action(self):
+ n = self.tracking_action.GetNumPhotons()
+ pos = np.zeros(shape=(n,3), dtype=np.float32)
+ pos[:,0] = self.tracking_action.GetX()
+ pos[:,1] = self.tracking_action.GetY()
+ pos[:,2] = self.tracking_action.GetZ()
+
+ dir = np.zeros(shape=(n,3), dtype=np.float32)
+ dir[:,0] = self.tracking_action.GetDirX()
+ dir[:,1] = self.tracking_action.GetDirY()
+ dir[:,2] = self.tracking_action.GetDirZ()
+
+ pol = np.zeros(shape=(n,3), dtype=np.float32)
+ pol[:,0] = self.tracking_action.GetPolX()
+ pol[:,1] = self.tracking_action.GetPolY()
+ pol[:,2] = self.tracking_action.GetPolZ()
+
+ wavelengths = self.tracking_action.GetWavelength().astype(np.float32)
+ times = self.tracking_action.GetT0().astype(np.float32)
+
+ return event.Photons(positions=pos, directions=dir, polarizations=pol, times=times, wavelengths=wavelengths)
+
+ def generate_photons(self, ev):
+ '''Use GEANT4 to generate photons produced by the given particle.
+
+ ev: a generator.event.Event object with the particle
+ properties set. If it contains subtracks, those
+ will be used to create the photon vertices rather
+ than the main particle.
+
+ Returns an instance of event.Photons containing the
+ generated photon vertices for the primary particle or
+ all the subtracks, if present.
+ '''
+ photons = []
+ if ev.subtracks:
+ subtracks = ev.subtracks
+ else:
+ # Create temporary subtrack for single primary particle
+ subtracks = [event.Subtrack(particle_name=ev.particle_name,
+ position=ev.gen_pos,
+ direction=ev.gen_dir,
+ start_time=0.0,
+ total_energy=ev.gen_total_energy)]
+
+ for subtrack in subtracks:
+ self.particle_gun.SetParticleByName(subtrack.particle_name)
+ self.particle_gun.SetParticleEnergy(subtrack.total_energy * MeV)
+ self.particle_gun.SetParticlePosition(G4ThreeVector(*subtrack.position)*m)
+ self.particle_gun.SetParticleMomentumDirection(G4ThreeVector(*subtrack.direction).unit())
+
+ self.tracking_action.Clear()
+ gRunManager.BeamOn(1)
+ photons.append(self._extract_photons_from_tracking_action())
+
+ # Merge all photon lists into one big list
+ return event.concatenate_photons(photons)
+
+if __name__ == '__main__':
+ import time
+ import optics
+ gen = G4Generator(optics.water)
+ # prime things
+ gen.generate_photons(event.Event('e-', (0,0,0), (1,0,0), 1.0))
+
+ start = time.time()
+ n = 0
+ for i in xrange(100):
+ photons = gen.generate_photons(event.Event('mu-', (0,0,0), (1,0,0), 1.0))
+ n += len(photons.times)
+ print photons.positions[0].min(), photons.positions[0].max()
+ stop = time.time()
+ print stop - start, 'sec'
+ print n / (stop-start), 'photons/sec'