summaryrefslogtreecommitdiff
path: root/view.py
diff options
context:
space:
mode:
Diffstat (limited to 'view.py')
-rwxr-xr-xview.py244
1 files changed, 244 insertions, 0 deletions
diff --git a/view.py b/view.py
new file mode 100755
index 0000000..952964f
--- /dev/null
+++ b/view.py
@@ -0,0 +1,244 @@
+#!/usr/bin/env python
+import os
+import sys
+import numpy as np
+
+import pygame
+from pygame.locals import *
+
+import src
+from camera import *
+from geometry import *
+from transform import *
+
+import time
+
+from pycuda import autoinit
+from pycuda.compiler import SourceModule
+from pycuda import gpuarray
+
+def view(viewable, name='', bits=8):
+ """
+ Render `viewable` in a pygame window.
+
+ Movement:
+ - zoom: scroll the mouse wheel
+ - rotate: click and drag the mouse
+ - move: shift+click and drag the mouse
+ """
+
+ if isinstance(viewable, Geometry):
+ geometry = viewable
+ geometry.build(bits)
+ elif isinstance(viewable, Solid):
+ geometry = Geometry()
+ geometry.add_solid(viewable)
+ geometry.build(bits)
+ elif isinstance(viewable, Mesh):
+ geometry = Geometry()
+ geometry.add_solid(Solid(viewable))
+ geometry.build(bits)
+ else:
+ sys.exit("can't display %s" % args[0])
+
+ lower_bound = np.array([np.min(geometry.mesh[:][:,:,0]),
+ np.min(geometry.mesh[:][:,:,1]),
+ np.min(geometry.mesh[:][:,:,2])])
+
+ upper_bound = np.array([np.max(geometry.mesh[:][:,:,0]),
+ np.max(geometry.mesh[:][:,:,1]),
+ np.max(geometry.mesh[:][:,:,2])])
+
+ scale = np.linalg.norm(upper_bound-lower_bound)
+
+ print 'device %s' % autoinit.device.name()
+
+ module = SourceModule(src.kernel, options=['-I' + src.dir],
+ no_extern_c=True, cache_dir=False)
+ texrefs = geometry.load(module, color=True)
+ cuda_raytrace = module.get_function('ray_trace')
+ cuda_rotate = module.get_function('rotate')
+ cuda_translate = module.get_function('translate')
+
+ pygame.init()
+ size = width, height = 800, 600
+ screen = pygame.display.set_mode(size)
+ pygame.display.set_caption(name)
+
+ camera = Camera(size)
+
+ diagonal = np.linalg.norm(upper_bound-lower_bound)
+
+ point = np.array([0, diagonal*1.75, (lower_bound[2]+upper_bound[2])/2])
+ axis1 = np.array([1,0,0], dtype=np.double)
+ axis2 = np.array([0,0,1], dtype=np.double)
+
+ camera.position(point)
+
+ origins, directions = camera.get_rays()
+
+ origins_float3 = np.empty(origins.shape[0], dtype=gpuarray.vec.float3)
+ origins_float3['x'] = origins[:,0]
+ origins_float3['y'] = origins[:,1]
+ origins_float3['z'] = origins[:,2]
+
+ directions_float3 = np.empty(directions.shape[0], dtype=gpuarray.vec.float3)
+ directions_float3['x'] = directions[:,0]
+ directions_float3['y'] = directions[:,1]
+ directions_float3['z'] = directions[:,2]
+
+ origins_gpu = cuda.to_device(origins_float3)
+ directions_gpu = cuda.to_device(directions_float3)
+
+ pixels = np.empty(width*height, dtype=np.int32)
+ pixels_gpu = cuda.to_device(pixels)
+
+ nblocks = 64
+
+ gpu_kwargs = {'block': (nblocks,1,1), 'grid':(pixels.size/nblocks+1,1)}
+
+ def render():
+ """Render the mesh and display to screen."""
+ t0 = time.time()
+ cuda_raytrace(np.int32(pixels.size), origins_gpu, directions_gpu, np.int32(geometry.node_map.size-1), np.int32(geometry.first_node), pixels_gpu, texrefs=texrefs, **gpu_kwargs)
+ cuda.Context.synchronize()
+ elapsed = time.time() - t0
+
+ print 'elapsed %f sec' % elapsed
+
+ cuda.memcpy_dtoh(pixels, pixels_gpu)
+ pygame.surfarray.blit_array(screen, pixels.reshape(size))
+ pygame.display.flip()
+
+ render()
+
+ done = False
+ clicked = False
+ shift = False
+
+ while not done:
+ for event in pygame.event.get():
+ if event.type == MOUSEBUTTONDOWN:
+ if event.button == 4:
+ v = scale*np.cross(axis1,axis2)/10.0
+
+ cuda_translate(np.int32(pixels.size), origins_gpu, gpuarray.vec.make_float3(*v), **gpu_kwargs)
+
+ point += v
+
+ render()
+
+ if event.button == 5:
+ v = -scale*np.cross(axis1,axis2)/10.0
+
+ cuda_translate(np.int32(pixels.size), origins_gpu, gpuarray.vec.make_float3(*v), **gpu_kwargs)
+
+ point += v
+
+ render()
+
+ if event.button == 1:
+ clicked = True
+ mouse_position = pygame.mouse.get_rel()
+
+ if event.type == MOUSEBUTTONUP:
+ if event.button == 1:
+ clicked = False
+
+ if event.type == MOUSEMOTION and clicked:
+ movement = np.array(pygame.mouse.get_rel())
+
+ if (movement == 0).all():
+ continue
+
+ length = np.linalg.norm(movement)
+
+ mouse_direction = movement[0]*axis1 + movement[1]*axis2
+ mouse_direction /= np.linalg.norm(mouse_direction)
+
+ if shift:
+ v = mouse_direction*scale*length/float(width)
+
+ cuda_translate(np.int32(pixels.size), origins_gpu, gpuarray.vec.make_float3(*v), **gpu_kwargs)
+
+ point += v
+
+ render()
+ else:
+ phi = np.float32(2*np.pi*length/float(width))
+ n = rotate(mouse_direction, np.pi/2, \
+ -np.cross(axis1,axis2))
+
+ cuda_rotate(np.int32(pixels.size), origins_gpu, phi, gpuarray.vec.make_float3(*n), **gpu_kwargs)
+
+ cuda_rotate(np.int32(pixels.size), directions_gpu, phi, gpuarray.vec.make_float3(*n), **gpu_kwargs)
+
+ point = rotate(point, phi, n)
+ axis1 = rotate(axis1, phi, n)
+ axis2 = rotate(axis2, phi, n)
+
+ render()
+
+ if event.type == KEYDOWN:
+ if event.key == K_LSHIFT or event.key == K_RSHIFT:
+ shift = True
+
+ if event.key == K_ESCAPE:
+ done = True
+ break
+
+ if event.key == K_F12:
+ if name == '':
+ root, ext = 'screenshot', 'png'
+ else:
+ root, ext = name, 'png'
+
+ filename = '.'.join([root, ext])
+
+ i = 1
+ while os.path.exists(filename):
+ filename = '.'.join([root + str(i), ext])
+ i += 1
+
+ pygame.image.save(screen, filename)
+ print 'image saved to %s' % filename
+
+ if event.type == KEYUP:
+ if event.key == K_LSHIFT or event.key == K_RSHIFT:
+ shift = False
+
+ pygame.display.quit()
+
+if __name__ == '__main__':
+ import optparse
+
+ from stl import mesh_from_stl
+
+ import solids
+ import detectors
+
+ parser = optparse.OptionParser('%prog filename.stl')
+ parser.add_option('-b', '--bits', type='int', dest='bits',
+ help='bits for z-ordering space axes', default=8)
+ options, args = parser.parse_args()
+
+ if len(args) < 1:
+ sys.exit(parser.format_help())
+
+ head, tail = os.path.split(args[0])
+ root, ext = os.path.splitext(tail)
+
+ if ext.lower() == '.stl':
+ geometry = Geometry()
+ geometry.add_solid(Solid(mesh_from_stl(args[0])))
+ geometry.build(options.bits)
+ view(geometry, tail)
+ else:
+ import inspect
+
+ members = dict(inspect.getmembers(detectors) + inspect.getmembers(solids))
+
+ if args[0] in members:
+ view(members[args[0]], args[0], options.bits)
+ else:
+ sys.exit("couldn't find object %s" % args[0])