summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--linalg.h10
-rw-r--r--matrix.h18
-rw-r--r--tests/linalg_test.cu16
-rw-r--r--tests/linalg_test.py22
4 files changed, 63 insertions, 3 deletions
diff --git a/linalg.h b/linalg.h
index 8362ab1..593aa9d 100644
--- a/linalg.h
+++ b/linalg.h
@@ -1,6 +1,11 @@
#ifndef __LINALG_H__
#define __LINALG_H__
+// Unary minus for float3: returns a vector whose x, y and z components
+// are the negated components of `a`. `a` itself is not modified.
+__device__ __host__ float3 operator- (const float3 &a)
+{
+    float3 negated;
+    negated.x = -a.x;
+    negated.y = -a.y;
+    negated.z = -a.z;
+    return negated;
+}
+
__device__ __host__ float3 operator+ (const float3 &a, const float3 &b)
{
return make_float3(a.x+b.x, a.y+b.y, a.z+b.z);
@@ -103,4 +108,9 @@ __device__ __host__ float3 cross(const float3 &a, const float3 &b)
return make_float3(a.y*b.z-a.z*b.y, a.z*b.x-a.x*b.z, a.x*b.y-a.y*b.x);
}
+// Returns the square root of dot(a, a), i.e. of the dot product of `a`
+// with itself.
+__device__ __host__ float norm(const float3 &a)
+{
+    const float self_dot = dot(a, a);
+    return sqrtf(self_dot);
+}
+
#endif
diff --git a/matrix.h b/matrix.h
index b363571..14e04a4 100644
--- a/matrix.h
+++ b/matrix.h
@@ -14,6 +14,11 @@ __device__ __host__ Matrix make_matrix(float a00, float a01, float a02,
return m;
}
+// Builds a Matrix from three vectors. The aggregate is filled in the
+// order u1.x, u2.x, u3.x, u1.y, u2.y, u3.y, u1.z, u2.z, u3.z, so
+// u1, u2 and u3 become the columns of the result.
+// NOTE(review): this assumes Matrix declares its members in the order
+// a00, a01, a02, a10, ..., a22 -- confirm against the struct definition.
+//
+// The vectors are taken by const reference (they are only read), which
+// also lets callers pass temporaries and const vectors.
+__device__ __host__ Matrix make_matrix(const float3 &u1, const float3 &u2, const float3 &u3)
+{
+    Matrix m = {u1.x, u2.x, u3.x, u1.y, u2.y, u3.y, u1.z, u2.z, u3.z};
+    // The original fell off the end without returning, which is
+    // undefined behaviour for a function with a non-void return type.
+    return m;
+}
+
__device__ __host__ float3 operator* (const Matrix &m, const float3 &a)
{
return make_float3(m.a00*a.x + m.a01*a.y + m.a02*a.z,
@@ -202,6 +207,19 @@ __device__ __host__ Matrix inv(const Matrix &m)
m.a00*m.a11 - m.a01*m.a10)/det(m);
}
+// Inverse of `m` computed from a caller-supplied determinant: the matrix
+// of cofactor expressions below is divided by `determinant` rather than
+// by a freshly computed det(m).
+// NOTE(review): the caller is presumably expected to pass det(m); no
+// check is made here that `determinant` is non-zero.
+__device__ __host__ Matrix inv(const Matrix&m, const float determinant)
+{
+    const Matrix adjugate =
+        make_matrix(m.a11*m.a22 - m.a12*m.a21,
+                    m.a02*m.a21 - m.a01*m.a22,
+                    m.a01*m.a12 - m.a02*m.a11,
+                    m.a12*m.a20 - m.a10*m.a22,
+                    m.a00*m.a22 - m.a02*m.a20,
+                    m.a02*m.a10 - m.a00*m.a12,
+                    m.a10*m.a21 - m.a11*m.a20,
+                    m.a01*m.a20 - m.a00*m.a21,
+                    m.a00*m.a11 - m.a01*m.a10);
+    return adjugate/determinant;
+}
+
__device__ __host__ Matrix outer(const float3 &a, const float3 &b)
{
return make_matrix(a.x*b.x, a.x*b.y, a.x*b.z,
diff --git a/tests/linalg_test.cu b/tests/linalg_test.cu
index bce5ea6..b61488f 100644
--- a/tests/linalg_test.cu
+++ b/tests/linalg_test.cu
@@ -99,16 +99,28 @@ __global__ void floatdivfloat3(float3 *a, float c, float3 *dest)
dest[idx] = c/a[idx];
}
-__global__ void dot(float3 *a, float3 *b, float* dest)
+__global__ void dot(float3 *a, float3 *b, float *dest)
{
int idx = blockIdx.x*blockDim.x + threadIdx.x;
dest[idx] = dot(a[idx],b[idx]);
}
-__global__ void cross(float3 *a, float3 *b, float3* dest)
+__global__ void cross(float3 *a, float3 *b, float3 *dest)
{
int idx = blockIdx.x*blockDim.x + threadIdx.x;
dest[idx] = cross(a[idx],b[idx]);
}
+// Test kernel: each thread writes norm(a[i]) to dest[i], where i is the
+// thread's flat index blockIdx.x*blockDim.x + threadIdx.x.
+// There is no bounds check, so the launch must not create more threads
+// than there are elements in `a` and `dest`.
+__global__ void norm(float3 *a, float *dest)
+{
+    const int i = blockIdx.x*blockDim.x + threadIdx.x;
+    const float length = norm(a[i]);
+    dest[i] = length;
+}
+
+// Test kernel: each thread writes -a[i] to dest[i], where i is the
+// thread's flat index blockIdx.x*blockDim.x + threadIdx.x.
+// There is no bounds check, so the launch must not create more threads
+// than there are elements in `a` and `dest`.
+__global__ void minusfloat3(float3 *a, float3 *dest)
+{
+    const int i = blockIdx.x*blockDim.x + threadIdx.x;
+    const float3 negated = -a[i];
+    dest[i] = negated;
+}
+
} // extern "c"
diff --git a/tests/linalg_test.py b/tests/linalg_test.py
index f5e947e..44c4b52 100644
--- a/tests/linalg_test.py
+++ b/tests/linalg_test.py
@@ -30,8 +30,10 @@ float3divfloatequal = mod.get_function('float3divfloatequal')
floatdivfloat3 = mod.get_function('floatdivfloat3')
dot = mod.get_function('dot')
cross = mod.get_function('cross')
+norm = mod.get_function('norm')
+minusfloat3 = mod.get_function('minusfloat3')
-size = {'block': (100,1,1), 'grid': (1,1)}
+size = {'block': (256,1,1), 'grid': (1,1)}
a = np.empty(size['block'][0], dtype=float3)
b = np.empty(size['block'][0], dtype=float3)
@@ -187,4 +189,22 @@ def testcross():
if not np.allclose(wdest['x'], w[0]) or \
not np.allclose(wdest['y'], w[1]) or \
not np.allclose(wdest['z'], w[2]):
+ print w
+ print wdest
assert False
+
+def testnorm():
+ dest = np.empty(a.size, dtype=np.float32)
+ norm(cuda.In(a), cuda.Out(dest), **size)
+
+ for i in range(len(dest)):
+ if not np.allclose(np.linalg.norm((a['x'][i],a['y'][i],a['z'][i])), dest[i]):
+ assert False
+
+def testminusfloat3():
+ dest = np.empty(a.size, dtype=float3)
+ minusfloat3(cuda.In(a), cuda.Out(dest), **size)
+ if not np.allclose(-a['x'], dest['x']) or \
+ not np.allclose(-a['y'], dest['y']) or \
+ not np.allclose(-a['z'], dest['z']):
+ assert False