diff options
Diffstat (limited to 'src/kernel.cu')
-rw-r--r-- | src/kernel.cu | 27 |
1 files changed, 17 insertions, 10 deletions
diff --git a/src/kernel.cu b/src/kernel.cu index 584c9e4..55fde70 100644 --- a/src/kernel.cu +++ b/src/kernel.cu @@ -141,26 +141,33 @@ __device__ int intersect_mesh(const float3 &origin, const float3& direction, con return triangle_index; } -__device__ curandState rng_states[100000]; - extern "C" { - __global__ void set_pointer(uint4 *triangle_ptr, float3 *vertex_ptr) +__global__ void fill_float(int nthreads, float *a, float value) { - triangles = triangle_ptr; - vertices = vertex_ptr; + int id = blockIdx.x*blockDim.x + threadIdx.x; + + if (id >= nthreads) + return; + + a[id] = value; } -/* Initialize random number states */ -__global__ void init_rng(int nthreads, unsigned long long seed, unsigned long long offset) +__global__ void fill_float3(int nthreads, float3 *a, float3 value) { int id = blockIdx.x*blockDim.x + threadIdx.x; if (id >= nthreads) return; - curand_init(seed, id, offset, rng_states+id); + a[id] = value; +} + +__global__ void set_pointer(uint4 *triangle_ptr, float3 *vertex_ptr) +{ + triangles = triangle_ptr; + vertices = vertex_ptr; } /* Translate `points` by the vector `v` */ @@ -223,7 +230,7 @@ __global__ void ray_trace(int nthreads, float3 *positions, float3 *directions, i } // ray_trace -__global__ void propagate(int nthreads, float3 *positions, float3 *directions, float *wavelengths, float3 *polarizations, float *times, int *states, int *last_hit_triangles, int start_node, int first_node, int max_steps) +__global__ void propagate(int nthreads, curandState *rng_states, float3 *positions, float3 *directions, float *wavelengths, float3 *polarizations, float *times, int *states, int *last_hit_triangles, int start_node, int first_node, int max_steps) { int id = blockIdx.x*blockDim.x + threadIdx.x; @@ -445,8 +452,8 @@ __global__ void propagate(int nthreads, float3 *positions, float3 *directions, f } // while(nsteps < max_steps) - states[id] = state; rng_states[id] = rng; + states[id] = state; positions[id] = position; directions[id] = direction; polarizations[id] = polarization; |