blob: aafbbcee3d2c2f289ad7151397c2f75a9e0d341a (
plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
|
# -*- coding: utf-8 -*-
# Public API of this module: the ctypes-wrapped CUDA function that is bound
# as a module-level name when the shared library is loaded.
__all__ = ["binary_search"]
from ctypes import *
import os.path as path
from numpy.ctypeslib import load_library, ndpointer
import platform
## Load the DLL
# The compiled CUDA library is expected to live in the same directory as this
# module, under a platform-specific file name.
_LIB_NAMES = {
    'Linux': "cu_binary_search.so",
    'Windows': "cu_binary_search.dll",
}
_system = platform.system()
if _system not in _LIB_NAMES:
    # Without this check an unsupported OS leaves ``cuda_lib`` unbound and the
    # import fails later with an unhelpful NameError; fail here instead, with
    # the same exception type load_library uses for loading problems.
    raise OSError(
        "cu_binary_search: unsupported platform {!r}; expected one of: {}".format(
            _system, ", ".join(sorted(_LIB_NAMES))))
# load_library raises OSError if the file is missing or cannot be loaded.
cuda_lib = load_library(_LIB_NAMES[_system], path.dirname(path.realpath(__file__)))
## Define argtypes for all functions to import
# Maps each exported C symbol name to the ctypes argument types declared for
# it. ``ndpointer("i8")`` accepts only NumPy ndarrays of dtype int64 (no
# dimensionality, shape or memory-layout restriction is imposed); ``c_int``
# arguments are plain C ints.
# NOTE(review): the meaning of the two int arguments (presumably element
# counts) is not visible here, and non-contiguous arrays are not rejected —
# confirm against the CUDA source whether flags="C_CONTIGUOUS" is required.
argtype_defs = dict(
    binary_search=[
        ndpointer("i8"), c_int,
        ndpointer("i8"), c_int,
    ],
)
## Import functions from DLL
# Bind each exported C function as a module-level name so that
# ``from <this module> import binary_search`` works. ``globals()`` is used
# explicitly: at module scope ``locals()`` happens to be the same dict, but
# modifying the ``locals()`` dict is documented as unreliable and would
# silently bind nothing if this code were ever moved into a function.
for func, argtypes in argtype_defs.items():
    c_func = cuda_lib[func]
    # Declaring argtypes makes ctypes check/convert every argument on each
    # call (ndpointer entries raise TypeError for non-ndarray or wrong-dtype
    # inputs).
    c_func.argtypes = argtypes
    # NOTE(review): restype is left at the ctypes default (c_int); confirm
    # this matches the C declaration of each exported function.
    globals()[func] = c_func
|