diff options
Diffstat (limited to 'BS/baselines/gpu/cu_lib_import.py')
-rw-r--r-- | BS/baselines/gpu/cu_lib_import.py | 39 |
1 files changed, 39 insertions, 0 deletions
diff --git a/BS/baselines/gpu/cu_lib_import.py b/BS/baselines/gpu/cu_lib_import.py new file mode 100644 index 0000000..aafbbce --- /dev/null +++ b/BS/baselines/gpu/cu_lib_import.py @@ -0,0 +1,39 @@ +# -*- coding: utf-8 -*-
+
+__all__ = [
+ "binary_search",
+]
+
+
+from ctypes import *
+import os.path as path
+from numpy.ctypeslib import load_library, ndpointer
+import platform
+
+
+## Load the DLL
+if platform.system() == 'Linux':
+ cuda_lib = load_library("cu_binary_search.so", path.dirname(path.realpath(__file__)))
+elif platform.system() == 'Windows':
+ cuda_lib = load_library("cu_binary_search.dll", path.dirname(path.realpath(__file__)))
+
+
+
+
+## Define argtypes for all functions to import
+argtype_defs = {
+
+ "binary_search" : [ndpointer("i8"),
+ c_int,
+ ndpointer("i8"),
+ c_int],
+
+}
+
+
+
+
+## Import functions from DLL
+for func, argtypes in argtype_defs.items():
+ locals().update({func: cuda_lib[func]})
+ locals()[func].argtypes = argtypes
|