summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--gpu.py27
1 files changed, 23 insertions, 4 deletions
diff --git a/gpu.py b/gpu.py
index de99e0a..671c872 100644
--- a/gpu.py
+++ b/gpu.py
@@ -2,7 +2,7 @@ import numpy as np
import numpy.ma as ma
from copy import copy
from itertools import izip
-from os.path import dirname
+import os
import sys
import pytools
@@ -34,7 +34,7 @@ def get_cu_module(name, options=None, include_source_directory=True):
else:
raise TypeError('`options` must be a tuple.')
- srcdir = dirname(chroma.src.__file__)
+ srcdir = os.path.dirname(os.path.abspath(chroma.src.__file__))
if include_source_directory:
options += ['-I' + srcdir]
@@ -46,7 +46,7 @@ def get_cu_module(name, options=None, include_source_directory=True):
no_extern_c=True)
def get_cu_source(name):
- srcdir = dirname(chroma.src.__file__)
+ srcdir = os.path.dirname(os.path.abspath(chroma.src.__file__))
with open('%s/%s' % (srcdir, name)) as f:
source = f.read()
return source
@@ -312,7 +312,7 @@ def format_array(name, array):
(name, format_size(len(array)), format_size(array.nbytes))
class GPUGeometry(object):
- def __init__(self, geometry, wavelengths=None, print_usage=True):
+ def __init__(self, geometry, wavelengths=None, print_usage=False):
if wavelengths is None:
wavelengths = standard_wavelengths
@@ -676,6 +676,25 @@ class GPUPDF(object):
return hitcount, pdf_value, pdf_value * pdf_frac_uncert
+def create_context(device_id=None):
+ """Initialize and return a GPU context on the specified device.
+ If device_id is None, the default device is used."""
+ try:
+ cuda.mem_get_info()
+ except cuda.LogicError:
+ # initialize cuda
+ cuda.init()
+
+ if device_id is None:
+ context = pycuda.tools.make_default_context()
+ else:
+ device = cuda.Device(device_id)
+ context = device.make_context()
+
+ context.set_cache_config(cuda.func_cache.PREFER_L1)
+
+ return context
+
class GPU(object):
def __init__(self, device_id=None):
"""Initialize a GPU context on the specified device.