summaryrefslogtreecommitdiff
path: root/BS/baselines/gpu/cu_lib_import.py
diff options
context:
space:
mode:
Diffstat (limited to 'BS/baselines/gpu/cu_lib_import.py')
-rw-r--r--BS/baselines/gpu/cu_lib_import.py39
1 files changed, 39 insertions, 0 deletions
diff --git a/BS/baselines/gpu/cu_lib_import.py b/BS/baselines/gpu/cu_lib_import.py
new file mode 100644
index 0000000..aafbbce
--- /dev/null
+++ b/BS/baselines/gpu/cu_lib_import.py
@@ -0,0 +1,39 @@
+# -*- coding: utf-8 -*-
+
+__all__ = [
+ "binary_search",
+]
+
+
+from ctypes import *
+import os.path as path
+from numpy.ctypeslib import load_library, ndpointer
+import platform
+
+
+## Load the DLL
+if platform.system() == 'Linux':
+ cuda_lib = load_library("cu_binary_search.so", path.dirname(path.realpath(__file__)))
+elif platform.system() == 'Windows':
+ cuda_lib = load_library("cu_binary_search.dll", path.dirname(path.realpath(__file__)))
+
+
+
+
+## Define argtypes for all functions to import
+argtype_defs = {
+
+ "binary_search" : [ndpointer("i8"),
+ c_int,
+ ndpointer("i8"),
+ c_int],
+
+}
+
+
+
+
+## Import functions from DLL
+for func, argtypes in argtype_defs.items():
+ locals().update({func: cuda_lib[func]})
+ locals()[func].argtypes = argtypes