diff options
Diffstat (limited to 'view.py')
-rwxr-xr-x | view.py | 244 |
1 files changed, 244 insertions, 0 deletions
@@ -0,0 +1,244 @@ +#!/usr/bin/env python +import os +import sys +import numpy as np + +import pygame +from pygame.locals import * + +import src +from camera import * +from geometry import * +from transform import * + +import time + +from pycuda import autoinit +from pycuda.compiler import SourceModule +from pycuda import gpuarray + +def view(viewable, name='', bits=8): + """ + Render `viewable` in a pygame window. + + Movement: + - zoom: scroll the mouse wheel + - rotate: click and drag the mouse + - move: shift+click and drag the mouse + """ + + if isinstance(viewable, Geometry): + geometry = viewable + geometry.build(bits) + elif isinstance(viewable, Solid): + geometry = Geometry() + geometry.add_solid(viewable) + geometry.build(bits) + elif isinstance(viewable, Mesh): + geometry = Geometry() + geometry.add_solid(Solid(viewable)) + geometry.build(bits) + else: + sys.exit("can't display %s" % args[0]) + + lower_bound = np.array([np.min(geometry.mesh[:][:,:,0]), + np.min(geometry.mesh[:][:,:,1]), + np.min(geometry.mesh[:][:,:,2])]) + + upper_bound = np.array([np.max(geometry.mesh[:][:,:,0]), + np.max(geometry.mesh[:][:,:,1]), + np.max(geometry.mesh[:][:,:,2])]) + + scale = np.linalg.norm(upper_bound-lower_bound) + + print 'device %s' % autoinit.device.name() + + module = SourceModule(src.kernel, options=['-I' + src.dir], + no_extern_c=True, cache_dir=False) + texrefs = geometry.load(module, color=True) + cuda_raytrace = module.get_function('ray_trace') + cuda_rotate = module.get_function('rotate') + cuda_translate = module.get_function('translate') + + pygame.init() + size = width, height = 800, 600 + screen = pygame.display.set_mode(size) + pygame.display.set_caption(name) + + camera = Camera(size) + + diagonal = np.linalg.norm(upper_bound-lower_bound) + + point = np.array([0, diagonal*1.75, (lower_bound[2]+upper_bound[2])/2]) + axis1 = np.array([1,0,0], dtype=np.double) + axis2 = np.array([0,0,1], dtype=np.double) + + camera.position(point) + + origins, directions = camera.get_rays() + + origins_float3 = np.empty(origins.shape[0], dtype=gpuarray.vec.float3) + origins_float3['x'] = origins[:,0] + origins_float3['y'] = origins[:,1] + origins_float3['z'] = origins[:,2] + + directions_float3 = np.empty(directions.shape[0], dtype=gpuarray.vec.float3) + directions_float3['x'] = directions[:,0] + directions_float3['y'] = directions[:,1] + directions_float3['z'] = directions[:,2] + + origins_gpu = cuda.to_device(origins_float3) + directions_gpu = cuda.to_device(directions_float3) + + pixels = np.empty(width*height, dtype=np.int32) + pixels_gpu = cuda.to_device(pixels) + + nblocks = 64 + + gpu_kwargs = {'block': (nblocks,1,1), 'grid':(pixels.size/nblocks+1,1)} + + def render(): + """Render the mesh and display to screen.""" + t0 = time.time() + cuda_raytrace(np.int32(pixels.size), origins_gpu, directions_gpu, np.int32(geometry.node_map.size-1), np.int32(geometry.first_node), pixels_gpu, texrefs=texrefs, **gpu_kwargs) + cuda.Context.synchronize() + elapsed = time.time() - t0 + + print 'elapsed %f sec' % elapsed + + cuda.memcpy_dtoh(pixels, pixels_gpu) + pygame.surfarray.blit_array(screen, pixels.reshape(size)) + pygame.display.flip() + + render() + + done = False + clicked = False + shift = False + + while not done: + for event in pygame.event.get(): + if event.type == MOUSEBUTTONDOWN: + if event.button == 4: + v = scale*np.cross(axis1,axis2)/10.0 + + cuda_translate(np.int32(pixels.size), origins_gpu, gpuarray.vec.make_float3(*v), **gpu_kwargs) + + point += v + + render() + + if event.button == 5: + v = -scale*np.cross(axis1,axis2)/10.0 + + cuda_translate(np.int32(pixels.size), origins_gpu, gpuarray.vec.make_float3(*v), **gpu_kwargs) + + point += v + + render() + + if event.button == 1: + clicked = True + mouse_position = pygame.mouse.get_rel() + + if event.type == MOUSEBUTTONUP: + if event.button == 1: + clicked = False + + if event.type == MOUSEMOTION and clicked: + movement = np.array(pygame.mouse.get_rel()) + + if (movement == 0).all(): + continue + + length = np.linalg.norm(movement) + + mouse_direction = movement[0]*axis1 + movement[1]*axis2 + mouse_direction /= np.linalg.norm(mouse_direction) + + if shift: + v = mouse_direction*scale*length/float(width) + + cuda_translate(np.int32(pixels.size), origins_gpu, gpuarray.vec.make_float3(*v), **gpu_kwargs) + + point += v + + render() + else: + phi = np.float32(2*np.pi*length/float(width)) + n = rotate(mouse_direction, np.pi/2, \ + -np.cross(axis1,axis2)) + + cuda_rotate(np.int32(pixels.size), origins_gpu, phi, gpuarray.vec.make_float3(*n), **gpu_kwargs) + + cuda_rotate(np.int32(pixels.size), directions_gpu, phi, gpuarray.vec.make_float3(*n), **gpu_kwargs) + + point = rotate(point, phi, n) + axis1 = rotate(axis1, phi, n) + axis2 = rotate(axis2, phi, n) + + render() + + if event.type == KEYDOWN: + if event.key == K_LSHIFT or event.key == K_RSHIFT: + shift = True + + if event.key == K_ESCAPE: + done = True + break + + if event.key == K_F12: + if name == '': + root, ext = 'screenshot', 'png' + else: + root, ext = name, 'png' + + filename = '.'.join([root, ext]) + + i = 1 + while os.path.exists(filename): + filename = '.'.join([root + str(i), ext]) + i += 1 + + pygame.image.save(screen, filename) + print 'image saved to %s' % filename + + if event.type == KEYUP: + if event.key == K_LSHIFT or event.key == K_RSHIFT: + shift = False + + pygame.display.quit() + +if __name__ == '__main__': + import optparse + + from stl import mesh_from_stl + + import solids + import detectors + + parser = optparse.OptionParser('%prog filename.stl') + parser.add_option('-b', '--bits', type='int', dest='bits', + help='bits for z-ordering space axes', default=8) + options, args = parser.parse_args() + + if len(args) < 1: + sys.exit(parser.format_help()) + + head, tail = os.path.split(args[0]) + root, ext = os.path.splitext(tail) + + if ext.lower() == '.stl': + geometry = Geometry() + geometry.add_solid(Solid(mesh_from_stl(args[0]))) + geometry.build(options.bits) + view(geometry, tail) + else: + import inspect + + members = dict(inspect.getmembers(detectors) + inspect.getmembers(solids)) + + if args[0] in members: + view(members[args[0]], args[0], options.bits) + else: + sys.exit("couldn't find object %s" % args[0]) |