summaryrefslogtreecommitdiff
path: root/BS/baselines/gpu/cu_lib_import.py
blob: aafbbcee3d2c2f289ad7151397c2f75a9e0d341a (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
# -*- coding: utf-8 -*-
"""ctypes bindings for the ``cu_binary_search`` shared library.

``__all__`` limits ``from ... import *`` to the bound native functions.
"""

__all__ = ["binary_search"]


from ctypes import *
import os.path as path
from numpy.ctypeslib import load_library, ndpointer
import platform


## Load the DLL
# The compiled library is looked up in the directory containing this file
# (symlinks resolved via realpath). Only Linux (.so) and Windows (.dll) file
# names are configured. Any other platform fails here with a clear error,
# instead of leaving ``cuda_lib`` undefined and crashing later with an
# unrelated NameError. OSError matches what ``load_library`` itself raises
# when the library file is missing or cannot be loaded.
_system = platform.system()
if _system == 'Linux':
    _lib_file = "cu_binary_search.so"
elif _system == 'Windows':
    _lib_file = "cu_binary_search.dll"
else:
    raise OSError(
        "cu_binary_search: no library file name is configured for platform "
        "{!r} (only Linux and Windows are handled)".format(_system)
    )
cuda_lib = load_library(_lib_file, path.dirname(path.realpath(__file__)))




## Define argtypes for all functions to import
# Maps each exported C symbol name to its ctypes ``argtypes`` list.
# Each int64 array is followed by a C ``int``. Presumably the int is that
# array's element count; NOTE(review): confirm against the CUDA source.
#
# ``flags="C_CONTIGUOUS"`` makes ctypes reject strided or non-contiguous views
# (e.g. ``a[::2]``). Without it, ``ndpointer`` checks only the dtype and
# passes the raw data pointer with no stride information, so the native code
# would silently read the wrong elements.
argtype_defs = {
    "binary_search": [
        ndpointer("i8", flags="C_CONTIGUOUS"),
        c_int,
        ndpointer("i8", flags="C_CONTIGUOUS"),
        c_int,
    ],
}




## Import functions from DLL
# Bind each symbol listed in ``argtype_defs`` as a module-level name, so that
# e.g. ``binary_search`` becomes importable, and attach its argtypes so ctypes
# checks and converts arguments on every call. ``restype`` is left at the
# ctypes default (C int); NOTE(review): confirm against the CUDA source.
for func, argtypes in argtype_defs.items():
    _c_func = cuda_lib[func]
    # Set argtypes before publishing, so the name is never visible unchecked.
    _c_func.argtypes = argtypes
    # globals() is the documented way to create module attributes dynamically.
    # The original wrote through locals(), which only works because
    # module-level locals() happens to be the globals dict.
    globals()[func] = _c_func