summaryrefslogtreecommitdiff
path: root/vector.py
blob: 7b36e07a14dda05775e05ef5dc2311f3cd1688c9 (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
import numpy as np
from pycuda import gpuarray

def make_vector(arr, dtype=gpuarray.vec.float3):
    """Pack an ``(n, 3)`` array of coordinates into a structured vector array.

    Each row ``(a, b, c)`` of *arr* becomes one element of the result with
    ``x = a``, ``y = b`` and ``z = c``.  The result is a host-side numpy
    array of *dtype*, which must have fields named ``x``, ``y`` and ``z``.

    Parameters
    ----------
    arr : array_like
        Two-dimensional input with exactly three columns.  Anything accepted
        by ``numpy.asarray`` (lists, tuples, ndarrays) works.
    dtype : numpy dtype, optional
        Structured dtype of the result; defaults to pycuda's
        ``gpuarray.vec.float3``.  Input values are cast to the field types
        on assignment.

    Returns
    -------
    numpy.ndarray
        One-dimensional array of length ``n`` with the given dtype.  Any
        fields of *dtype* other than ``x``, ``y`` and ``z`` (e.g. ``w`` of a
        four-component vector) are set to zero.

    Raises
    ------
    ValueError
        If *arr* is not two-dimensional with exactly three columns.
    """
    arr = np.asarray(arr)
    if arr.ndim != 2 or arr.shape[-1] != 3:
        raise ValueError('shape mismatch: expected an (n, 3) array, got shape %r'
                         % (arr.shape,))

    # zeros rather than empty: a dtype with extra fields (e.g. a 4-component
    # vector's 'w') would otherwise be left holding uninitialized memory.
    vec = np.zeros(arr.shape[0], dtype)
    for column, name in enumerate(('x', 'y', 'z')):
        vec[name] = arr[:, column]

    return vec