numba-cuda 0.0.1__py3-none-any.whl → 0.0.13__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- _numba_cuda_redirector.pth +1 -0
- _numba_cuda_redirector.py +74 -0
- numba_cuda/VERSION +1 -0
- numba_cuda/__init__.py +5 -0
- numba_cuda/_version.py +19 -0
- numba_cuda/numba/cuda/__init__.py +22 -0
- numba_cuda/numba/cuda/api.py +526 -0
- numba_cuda/numba/cuda/api_util.py +30 -0
- numba_cuda/numba/cuda/args.py +77 -0
- numba_cuda/numba/cuda/cg.py +62 -0
- numba_cuda/numba/cuda/codegen.py +378 -0
- numba_cuda/numba/cuda/compiler.py +422 -0
- numba_cuda/numba/cuda/cpp_function_wrappers.cu +47 -0
- numba_cuda/numba/cuda/cuda_fp16.h +3631 -0
- numba_cuda/numba/cuda/cuda_fp16.hpp +2465 -0
- numba_cuda/numba/cuda/cuda_paths.py +258 -0
- numba_cuda/numba/cuda/cudadecl.py +806 -0
- numba_cuda/numba/cuda/cudadrv/__init__.py +9 -0
- numba_cuda/numba/cuda/cudadrv/devicearray.py +904 -0
- numba_cuda/numba/cuda/cudadrv/devices.py +248 -0
- numba_cuda/numba/cuda/cudadrv/driver.py +3201 -0
- numba_cuda/numba/cuda/cudadrv/drvapi.py +398 -0
- numba_cuda/numba/cuda/cudadrv/dummyarray.py +452 -0
- numba_cuda/numba/cuda/cudadrv/enums.py +607 -0
- numba_cuda/numba/cuda/cudadrv/error.py +36 -0
- numba_cuda/numba/cuda/cudadrv/libs.py +176 -0
- numba_cuda/numba/cuda/cudadrv/ndarray.py +20 -0
- numba_cuda/numba/cuda/cudadrv/nvrtc.py +260 -0
- numba_cuda/numba/cuda/cudadrv/nvvm.py +707 -0
- numba_cuda/numba/cuda/cudadrv/rtapi.py +10 -0
- numba_cuda/numba/cuda/cudadrv/runtime.py +142 -0
- numba_cuda/numba/cuda/cudaimpl.py +1055 -0
- numba_cuda/numba/cuda/cudamath.py +140 -0
- numba_cuda/numba/cuda/decorators.py +189 -0
- numba_cuda/numba/cuda/descriptor.py +33 -0
- numba_cuda/numba/cuda/device_init.py +89 -0
- numba_cuda/numba/cuda/deviceufunc.py +908 -0
- numba_cuda/numba/cuda/dispatcher.py +1057 -0
- numba_cuda/numba/cuda/errors.py +59 -0
- numba_cuda/numba/cuda/extending.py +7 -0
- numba_cuda/numba/cuda/initialize.py +13 -0
- numba_cuda/numba/cuda/intrinsic_wrapper.py +77 -0
- numba_cuda/numba/cuda/intrinsics.py +198 -0
- numba_cuda/numba/cuda/kernels/__init__.py +0 -0
- numba_cuda/numba/cuda/kernels/reduction.py +262 -0
- numba_cuda/numba/cuda/kernels/transpose.py +65 -0
- numba_cuda/numba/cuda/libdevice.py +3382 -0
- numba_cuda/numba/cuda/libdevicedecl.py +17 -0
- numba_cuda/numba/cuda/libdevicefuncs.py +1057 -0
- numba_cuda/numba/cuda/libdeviceimpl.py +83 -0
- numba_cuda/numba/cuda/mathimpl.py +448 -0
- numba_cuda/numba/cuda/models.py +48 -0
- numba_cuda/numba/cuda/nvvmutils.py +235 -0
- numba_cuda/numba/cuda/printimpl.py +86 -0
- numba_cuda/numba/cuda/random.py +292 -0
- numba_cuda/numba/cuda/simulator/__init__.py +38 -0
- numba_cuda/numba/cuda/simulator/api.py +110 -0
- numba_cuda/numba/cuda/simulator/compiler.py +9 -0
- numba_cuda/numba/cuda/simulator/cudadrv/__init__.py +2 -0
- numba_cuda/numba/cuda/simulator/cudadrv/devicearray.py +432 -0
- numba_cuda/numba/cuda/simulator/cudadrv/devices.py +117 -0
- numba_cuda/numba/cuda/simulator/cudadrv/driver.py +62 -0
- numba_cuda/numba/cuda/simulator/cudadrv/drvapi.py +4 -0
- numba_cuda/numba/cuda/simulator/cudadrv/dummyarray.py +4 -0
- numba_cuda/numba/cuda/simulator/cudadrv/error.py +6 -0
- numba_cuda/numba/cuda/simulator/cudadrv/libs.py +2 -0
- numba_cuda/numba/cuda/simulator/cudadrv/nvvm.py +29 -0
- numba_cuda/numba/cuda/simulator/cudadrv/runtime.py +19 -0
- numba_cuda/numba/cuda/simulator/kernel.py +308 -0
- numba_cuda/numba/cuda/simulator/kernelapi.py +495 -0
- numba_cuda/numba/cuda/simulator/reduction.py +15 -0
- numba_cuda/numba/cuda/simulator/vector_types.py +58 -0
- numba_cuda/numba/cuda/simulator_init.py +17 -0
- numba_cuda/numba/cuda/stubs.py +902 -0
- numba_cuda/numba/cuda/target.py +440 -0
- numba_cuda/numba/cuda/testing.py +202 -0
- numba_cuda/numba/cuda/tests/__init__.py +58 -0
- numba_cuda/numba/cuda/tests/cudadrv/__init__.py +8 -0
- numba_cuda/numba/cuda/tests/cudadrv/test_array_attr.py +145 -0
- numba_cuda/numba/cuda/tests/cudadrv/test_context_stack.py +145 -0
- numba_cuda/numba/cuda/tests/cudadrv/test_cuda_array_slicing.py +375 -0
- numba_cuda/numba/cuda/tests/cudadrv/test_cuda_auto_context.py +21 -0
- numba_cuda/numba/cuda/tests/cudadrv/test_cuda_devicerecord.py +179 -0
- numba_cuda/numba/cuda/tests/cudadrv/test_cuda_driver.py +235 -0
- numba_cuda/numba/cuda/tests/cudadrv/test_cuda_libraries.py +22 -0
- numba_cuda/numba/cuda/tests/cudadrv/test_cuda_memory.py +193 -0
- numba_cuda/numba/cuda/tests/cudadrv/test_cuda_ndarray.py +547 -0
- numba_cuda/numba/cuda/tests/cudadrv/test_deallocations.py +249 -0
- numba_cuda/numba/cuda/tests/cudadrv/test_detect.py +81 -0
- numba_cuda/numba/cuda/tests/cudadrv/test_emm_plugins.py +192 -0
- numba_cuda/numba/cuda/tests/cudadrv/test_events.py +38 -0
- numba_cuda/numba/cuda/tests/cudadrv/test_host_alloc.py +65 -0
- numba_cuda/numba/cuda/tests/cudadrv/test_init.py +139 -0
- numba_cuda/numba/cuda/tests/cudadrv/test_inline_ptx.py +37 -0
- numba_cuda/numba/cuda/tests/cudadrv/test_is_fp16.py +12 -0
- numba_cuda/numba/cuda/tests/cudadrv/test_linker.py +317 -0
- numba_cuda/numba/cuda/tests/cudadrv/test_managed_alloc.py +127 -0
- numba_cuda/numba/cuda/tests/cudadrv/test_mvc.py +54 -0
- numba_cuda/numba/cuda/tests/cudadrv/test_nvvm_driver.py +199 -0
- numba_cuda/numba/cuda/tests/cudadrv/test_pinned.py +37 -0
- numba_cuda/numba/cuda/tests/cudadrv/test_profiler.py +20 -0
- numba_cuda/numba/cuda/tests/cudadrv/test_ptds.py +149 -0
- numba_cuda/numba/cuda/tests/cudadrv/test_reset_device.py +36 -0
- numba_cuda/numba/cuda/tests/cudadrv/test_runtime.py +85 -0
- numba_cuda/numba/cuda/tests/cudadrv/test_select_device.py +41 -0
- numba_cuda/numba/cuda/tests/cudadrv/test_streams.py +122 -0
- numba_cuda/numba/cuda/tests/cudapy/__init__.py +8 -0
- numba_cuda/numba/cuda/tests/cudapy/cache_usecases.py +234 -0
- numba_cuda/numba/cuda/tests/cudapy/cache_with_cpu_usecases.py +41 -0
- numba_cuda/numba/cuda/tests/cudapy/extensions_usecases.py +58 -0
- numba_cuda/numba/cuda/tests/cudapy/jitlink.ptx +30 -0
- numba_cuda/numba/cuda/tests/cudapy/recursion_usecases.py +100 -0
- numba_cuda/numba/cuda/tests/cudapy/test_alignment.py +42 -0
- numba_cuda/numba/cuda/tests/cudapy/test_array.py +260 -0
- numba_cuda/numba/cuda/tests/cudapy/test_array_args.py +201 -0
- numba_cuda/numba/cuda/tests/cudapy/test_array_methods.py +35 -0
- numba_cuda/numba/cuda/tests/cudapy/test_atomics.py +1620 -0
- numba_cuda/numba/cuda/tests/cudapy/test_blackscholes.py +120 -0
- numba_cuda/numba/cuda/tests/cudapy/test_boolean.py +24 -0
- numba_cuda/numba/cuda/tests/cudapy/test_caching.py +545 -0
- numba_cuda/numba/cuda/tests/cudapy/test_casting.py +257 -0
- numba_cuda/numba/cuda/tests/cudapy/test_cffi.py +33 -0
- numba_cuda/numba/cuda/tests/cudapy/test_compiler.py +276 -0
- numba_cuda/numba/cuda/tests/cudapy/test_complex.py +296 -0
- numba_cuda/numba/cuda/tests/cudapy/test_complex_kernel.py +20 -0
- numba_cuda/numba/cuda/tests/cudapy/test_const_string.py +129 -0
- numba_cuda/numba/cuda/tests/cudapy/test_constmem.py +176 -0
- numba_cuda/numba/cuda/tests/cudapy/test_cooperative_groups.py +147 -0
- numba_cuda/numba/cuda/tests/cudapy/test_cuda_array_interface.py +435 -0
- numba_cuda/numba/cuda/tests/cudapy/test_cuda_jit_no_types.py +90 -0
- numba_cuda/numba/cuda/tests/cudapy/test_datetime.py +94 -0
- numba_cuda/numba/cuda/tests/cudapy/test_debug.py +101 -0
- numba_cuda/numba/cuda/tests/cudapy/test_debuginfo.py +221 -0
- numba_cuda/numba/cuda/tests/cudapy/test_device_func.py +222 -0
- numba_cuda/numba/cuda/tests/cudapy/test_dispatcher.py +700 -0
- numba_cuda/numba/cuda/tests/cudapy/test_enums.py +121 -0
- numba_cuda/numba/cuda/tests/cudapy/test_errors.py +79 -0
- numba_cuda/numba/cuda/tests/cudapy/test_exception.py +174 -0
- numba_cuda/numba/cuda/tests/cudapy/test_extending.py +155 -0
- numba_cuda/numba/cuda/tests/cudapy/test_fastmath.py +244 -0
- numba_cuda/numba/cuda/tests/cudapy/test_forall.py +52 -0
- numba_cuda/numba/cuda/tests/cudapy/test_freevar.py +29 -0
- numba_cuda/numba/cuda/tests/cudapy/test_frexp_ldexp.py +66 -0
- numba_cuda/numba/cuda/tests/cudapy/test_globals.py +60 -0
- numba_cuda/numba/cuda/tests/cudapy/test_gufunc.py +456 -0
- numba_cuda/numba/cuda/tests/cudapy/test_gufunc_scalar.py +159 -0
- numba_cuda/numba/cuda/tests/cudapy/test_gufunc_scheduling.py +95 -0
- numba_cuda/numba/cuda/tests/cudapy/test_idiv.py +37 -0
- numba_cuda/numba/cuda/tests/cudapy/test_inspect.py +165 -0
- numba_cuda/numba/cuda/tests/cudapy/test_intrinsics.py +1106 -0
- numba_cuda/numba/cuda/tests/cudapy/test_ipc.py +318 -0
- numba_cuda/numba/cuda/tests/cudapy/test_iterators.py +99 -0
- numba_cuda/numba/cuda/tests/cudapy/test_lang.py +64 -0
- numba_cuda/numba/cuda/tests/cudapy/test_laplace.py +119 -0
- numba_cuda/numba/cuda/tests/cudapy/test_libdevice.py +187 -0
- numba_cuda/numba/cuda/tests/cudapy/test_lineinfo.py +199 -0
- numba_cuda/numba/cuda/tests/cudapy/test_localmem.py +164 -0
- numba_cuda/numba/cuda/tests/cudapy/test_mandel.py +37 -0
- numba_cuda/numba/cuda/tests/cudapy/test_math.py +786 -0
- numba_cuda/numba/cuda/tests/cudapy/test_matmul.py +74 -0
- numba_cuda/numba/cuda/tests/cudapy/test_minmax.py +113 -0
- numba_cuda/numba/cuda/tests/cudapy/test_montecarlo.py +22 -0
- numba_cuda/numba/cuda/tests/cudapy/test_multigpu.py +140 -0
- numba_cuda/numba/cuda/tests/cudapy/test_multiprocessing.py +46 -0
- numba_cuda/numba/cuda/tests/cudapy/test_multithreads.py +101 -0
- numba_cuda/numba/cuda/tests/cudapy/test_nondet.py +49 -0
- numba_cuda/numba/cuda/tests/cudapy/test_operator.py +401 -0
- numba_cuda/numba/cuda/tests/cudapy/test_optimization.py +86 -0
- numba_cuda/numba/cuda/tests/cudapy/test_overload.py +335 -0
- numba_cuda/numba/cuda/tests/cudapy/test_powi.py +124 -0
- numba_cuda/numba/cuda/tests/cudapy/test_print.py +128 -0
- numba_cuda/numba/cuda/tests/cudapy/test_py2_div_issue.py +33 -0
- numba_cuda/numba/cuda/tests/cudapy/test_random.py +104 -0
- numba_cuda/numba/cuda/tests/cudapy/test_record_dtype.py +610 -0
- numba_cuda/numba/cuda/tests/cudapy/test_recursion.py +125 -0
- numba_cuda/numba/cuda/tests/cudapy/test_reduction.py +76 -0
- numba_cuda/numba/cuda/tests/cudapy/test_retrieve_autoconverted_arrays.py +83 -0
- numba_cuda/numba/cuda/tests/cudapy/test_serialize.py +85 -0
- numba_cuda/numba/cuda/tests/cudapy/test_slicing.py +37 -0
- numba_cuda/numba/cuda/tests/cudapy/test_sm.py +444 -0
- numba_cuda/numba/cuda/tests/cudapy/test_sm_creation.py +205 -0
- numba_cuda/numba/cuda/tests/cudapy/test_sync.py +271 -0
- numba_cuda/numba/cuda/tests/cudapy/test_transpose.py +80 -0
- numba_cuda/numba/cuda/tests/cudapy/test_ufuncs.py +277 -0
- numba_cuda/numba/cuda/tests/cudapy/test_userexc.py +47 -0
- numba_cuda/numba/cuda/tests/cudapy/test_vector_type.py +307 -0
- numba_cuda/numba/cuda/tests/cudapy/test_vectorize.py +283 -0
- numba_cuda/numba/cuda/tests/cudapy/test_vectorize_complex.py +20 -0
- numba_cuda/numba/cuda/tests/cudapy/test_vectorize_decor.py +69 -0
- numba_cuda/numba/cuda/tests/cudapy/test_vectorize_device.py +36 -0
- numba_cuda/numba/cuda/tests/cudapy/test_vectorize_scalar_arg.py +37 -0
- numba_cuda/numba/cuda/tests/cudapy/test_warning.py +139 -0
- numba_cuda/numba/cuda/tests/cudapy/test_warp_ops.py +276 -0
- numba_cuda/numba/cuda/tests/cudasim/__init__.py +6 -0
- numba_cuda/numba/cuda/tests/cudasim/support.py +6 -0
- numba_cuda/numba/cuda/tests/cudasim/test_cudasim_issues.py +102 -0
- numba_cuda/numba/cuda/tests/data/__init__.py +0 -0
- numba_cuda/numba/cuda/tests/data/cuda_include.cu +5 -0
- numba_cuda/numba/cuda/tests/data/error.cu +7 -0
- numba_cuda/numba/cuda/tests/data/jitlink.cu +23 -0
- numba_cuda/numba/cuda/tests/data/jitlink.ptx +51 -0
- numba_cuda/numba/cuda/tests/data/warn.cu +7 -0
- numba_cuda/numba/cuda/tests/doc_examples/__init__.py +6 -0
- numba_cuda/numba/cuda/tests/doc_examples/ffi/__init__.py +0 -0
- numba_cuda/numba/cuda/tests/doc_examples/ffi/functions.cu +49 -0
- numba_cuda/numba/cuda/tests/doc_examples/test_cg.py +77 -0
- numba_cuda/numba/cuda/tests/doc_examples/test_cpu_gpu_compat.py +76 -0
- numba_cuda/numba/cuda/tests/doc_examples/test_ffi.py +82 -0
- numba_cuda/numba/cuda/tests/doc_examples/test_laplace.py +155 -0
- numba_cuda/numba/cuda/tests/doc_examples/test_matmul.py +173 -0
- numba_cuda/numba/cuda/tests/doc_examples/test_montecarlo.py +109 -0
- numba_cuda/numba/cuda/tests/doc_examples/test_random.py +59 -0
- numba_cuda/numba/cuda/tests/doc_examples/test_reduction.py +76 -0
- numba_cuda/numba/cuda/tests/doc_examples/test_sessionize.py +130 -0
- numba_cuda/numba/cuda/tests/doc_examples/test_ufunc.py +50 -0
- numba_cuda/numba/cuda/tests/doc_examples/test_vecadd.py +73 -0
- numba_cuda/numba/cuda/tests/nocuda/__init__.py +8 -0
- numba_cuda/numba/cuda/tests/nocuda/test_dummyarray.py +359 -0
- numba_cuda/numba/cuda/tests/nocuda/test_function_resolution.py +36 -0
- numba_cuda/numba/cuda/tests/nocuda/test_import.py +49 -0
- numba_cuda/numba/cuda/tests/nocuda/test_library_lookup.py +238 -0
- numba_cuda/numba/cuda/tests/nocuda/test_nvvm.py +54 -0
- numba_cuda/numba/cuda/types.py +37 -0
- numba_cuda/numba/cuda/ufuncs.py +662 -0
- numba_cuda/numba/cuda/vector_types.py +209 -0
- numba_cuda/numba/cuda/vectorizers.py +252 -0
- numba_cuda-0.0.13.dist-info/LICENSE +25 -0
- numba_cuda-0.0.13.dist-info/METADATA +69 -0
- numba_cuda-0.0.13.dist-info/RECORD +231 -0
- {numba_cuda-0.0.1.dist-info → numba_cuda-0.0.13.dist-info}/WHEEL +1 -1
- numba_cuda-0.0.1.dist-info/METADATA +0 -10
- numba_cuda-0.0.1.dist-info/RECORD +0 -5
- {numba_cuda-0.0.1.dist-info → numba_cuda-0.0.13.dist-info}/top_level.txt +0 -0
@@ -0,0 +1,83 @@
|
|
1
|
+
from llvmlite import ir
|
2
|
+
from numba.core import cgutils, types
|
3
|
+
from numba.core.imputils import Registry
|
4
|
+
from numba.cuda import libdevice, libdevicefuncs
|
5
|
+
|
6
|
+
registry = Registry()
lower = registry.lower


def libdevice_implement(func, retty, nbargs):
    """Register a lowering for libdevice function ``func`` (no pointer args).

    ``func`` is the libdevice symbol name. The Python-level key is the
    attribute of ``numba.cuda.libdevice`` named ``func[5:]`` (the symbol name
    minus its first five characters -- presumably an ``__nv_`` prefix).
    ``retty`` is the Numba return type. ``nbargs`` are the prototype
    arguments, each exposing ``.ty`` and ``.is_ptr``.

    The generated lowering declares (or reuses) ``func`` in the current LLVM
    module with the mapped LLVM types and emits a direct call with the
    incoming lowered arguments.
    """
    def core(context, builder, sig, args):
        lmod = builder.module
        fretty = context.get_value_type(retty)
        fargtys = [context.get_value_type(arg.ty) for arg in nbargs]
        fnty = ir.FunctionType(fretty, fargtys)
        fn = cgutils.get_or_insert_function(lmod, fnty, func)
        return builder.call(fn, args)

    key = getattr(libdevice, func[5:])

    # Build the typing signature from this call's own prototype (``nbargs``).
    # ``args`` is not a name in this scope, so referring to it here would
    # silently pick up whatever module-level ``args`` happens to exist.
    argtys = [arg.ty for arg in nbargs if not arg.is_ptr]
    lower(key, *argtys)(core)
|
23
|
+
|
24
|
+
|
25
|
+
def libdevice_implement_multiple_returns(func, retty, prototype_args):
    """Register a lowering for libdevice ``func`` that has out-parameters.

    Pointer arguments in ``prototype_args`` are not supplied by the caller.
    Instead, a stack slot is allocated for each one (``cgutils.alloca_once``)
    and passed to ``func``. After the call, the result is assembled from:

    * the direct return value, unless ``retty`` is ``types.void``;
    * followed by the loaded value of every out-parameter slot, in order.

    The result is packed with ``cgutils.pack_array`` when the Numba return
    type from ``libdevicefuncs.create_signature`` is a ``UniTuple``, and with
    ``cgutils.pack_struct`` otherwise.
    """
    nb_sig = libdevicefuncs.create_signature(retty, prototype_args)
    nb_retty = nb_sig.return_type

    def core(context, builder, sig, args):
        def llvm_param_type(proto):
            # Out-parameters are passed as pointers to their value type.
            llty = context.get_value_type(proto.ty)
            return llty.as_pointer() if proto.is_ptr else llty

        param_types = [llvm_param_type(proto) for proto in prototype_args]
        fnty = ir.FunctionType(context.get_value_type(retty), param_types)
        fn = cgutils.get_or_insert_function(builder.module, fnty, func)

        # Out-parameters each get a fresh stack slot. The other parameters
        # are taken, in order, from the incoming lowered arguments.
        call_args = []
        out_slots = []
        next_in = 0
        for proto in prototype_args:
            if proto.is_ptr:
                slot = cgutils.alloca_once(builder,
                                           context.get_value_type(proto.ty))
                call_args.append(slot)
                out_slots.append(slot)
            else:
                call_args.append(args[next_in])
                next_in += 1

        ret = builder.call(fn, call_args)

        # Direct return value first (if any), then each out-parameter value.
        results = [] if retty == types.void else [ret]
        results.extend(builder.load(slot) for slot in out_slots)

        if isinstance(nb_retty, types.UniTuple):
            return cgutils.pack_array(builder, results)
        return cgutils.pack_struct(builder, results)

    key = getattr(libdevice, func[5:])
    lower(key, *nb_sig.args)(core)
|
77
|
+
|
78
|
+
|
79
|
+
# Register a lowering for every entry of ``libdevicefuncs.functions``
# (symbol name -> (return type, prototype args)). Functions with at least one
# pointer (out-parameter) argument need the multiple-return lowering; the rest
# map directly onto a single call.
#
# The loop variable names are kept unchanged on purpose: at module level they
# become module globals, and other code in this module may read them.
for func, (retty, args) in libdevicefuncs.functions.items():
    if any(arg.is_ptr for arg in args):
        libdevice_implement_multiple_returns(func, retty, args)
    else:
        libdevice_implement(func, retty, args)
|
@@ -0,0 +1,448 @@
|
|
1
|
+
import math
|
2
|
+
import operator
|
3
|
+
from llvmlite import ir
|
4
|
+
from numba.core import types, typing, cgutils, targetconfig
|
5
|
+
from numba.core.imputils import Registry
|
6
|
+
from numba.types import float32, float64, int64, uint64
|
7
|
+
from numba.cuda import libdevice
|
8
|
+
from numba import cuda
|
9
|
+
|
10
|
+
# Lowering registry populated by this module. ``lower`` is its registration
# decorator, used below as ``@lower(key, *argument_types)``.
registry = Registry()
lower = registry.lower
|
12
|
+
|
13
|
+
|
14
|
+
# Predicate functions lowered to libdevice calls. Each entry is
# (float64 libdevice name, float32 libdevice name, Python ``math`` function).
booleans = [
    ('isnand', 'isnanf', math.isnan),
    ('isinfd', 'isinff', math.isinf),
    ('isfinited', 'finitef', math.isfinite),
]
|
18
|
+
|
19
|
+
# Unary functions lowered to libdevice calls. Each entry is
# (float64 libdevice name, float32 libdevice name, Python ``math`` function).
unarys = [
    ('ceil', 'ceilf', math.ceil),
    ('floor', 'floorf', math.floor),
    ('fabs', 'fabsf', math.fabs),
    ('exp', 'expf', math.exp),
    ('expm1', 'expm1f', math.expm1),
    ('erf', 'erff', math.erf),
    ('erfc', 'erfcf', math.erfc),
    ('tgamma', 'tgammaf', math.gamma),
    ('lgamma', 'lgammaf', math.lgamma),
    ('sqrt', 'sqrtf', math.sqrt),
    ('log', 'logf', math.log),
    ('log2', 'log2f', math.log2),
    ('log10', 'log10f', math.log10),
    ('log1p', 'log1pf', math.log1p),
    ('acosh', 'acoshf', math.acosh),
    ('acos', 'acosf', math.acos),
    ('cos', 'cosf', math.cos),
    ('cosh', 'coshf', math.cosh),
    ('asinh', 'asinhf', math.asinh),
    ('asin', 'asinf', math.asin),
    ('sin', 'sinf', math.sin),
    ('sinh', 'sinhf', math.sinh),
    ('atan', 'atanf', math.atan),
    ('atanh', 'atanhf', math.atanh),
    ('tan', 'tanf', math.tan),
    ('trunc', 'truncf', math.trunc),
]
|
46
|
+
|
47
|
+
# float32 libdevice unary function name -> fast-math replacement name, used
# when lowering with ``context.fastmath`` enabled.
unarys_fastmath = {
    'cosf': 'fast_cosf',
    'sinf': 'fast_sinf',
    'tanf': 'fast_tanf',
    'expf': 'fast_expf',
    'log2f': 'fast_log2f',
    'log10f': 'fast_log10f',
    'logf': 'fast_logf',
}
|
55
|
+
|
56
|
+
# Binary functions lowered to libdevice calls. Each entry is
# (float64 libdevice name, float32 libdevice name, Python ``math`` function).
binarys = [
    ('copysign', 'copysignf', math.copysign),
    ('atan2', 'atan2f', math.atan2),
    ('pow', 'powf', math.pow),
    ('fmod', 'fmodf', math.fmod),
    ('hypot', 'hypotf', math.hypot),
    ('remainder', 'remainderf', math.remainder),
]
|
63
|
+
|
64
|
+
# float32 libdevice binary function name -> fast-math replacement name.
binarys_fastmath = {'powf': 'fast_powf'}
|
66
|
+
|
67
|
+
|
68
|
+
@lower(math.isinf, types.Integer)
@lower(math.isnan, types.Integer)
def math_isinf_isnan_int(context, builder, sig, args):
    """Lower ``math.isinf`` / ``math.isnan`` on integer arguments.

    The result does not depend on the argument: it is the boolean constant 0
    (False).
    """
    always_false = context.get_constant(types.boolean, 0)
    return always_false
|
72
|
+
|
73
|
+
|
74
|
+
@lower(operator.truediv, types.float32, types.float32)
def maybe_fast_truediv(context, builder, sig, args):
    """Lower float32 ``/``.

    With ``context.fastmath`` set, this calls ``libdevice.fast_fdividef``.
    Otherwise, a zero divisor is first reported through the context's error
    model (``fp_zero_division``), and then an LLVM ``fdiv`` is emitted.
    """
    if not context.fastmath:
        divisor = args[1]
        with cgutils.if_zero(builder, divisor):
            context.error_model.fp_zero_division(builder,
                                                 ("division by zero",))
        return builder.fdiv(*args)

    fast_sig = typing.signature(float32, float32, float32)
    fast_div = context.get_function(libdevice.fast_fdividef, fast_sig)
    return fast_div(builder, args)
|
85
|
+
|
86
|
+
|
87
|
+
@lower(math.isfinite, types.Integer)
def math_isfinite_int(context, builder, sig, args):
    """Lower ``math.isfinite`` on integer arguments.

    The result does not depend on the argument: it is the boolean constant 1
    (True).
    """
    always_true = context.get_constant(types.boolean, 1)
    return always_true
|
90
|
+
|
91
|
+
|
92
|
+
# float16 ``math`` lowerings. Each one compiles a tiny Python wrapper around
# the matching ``cuda.fp16`` intrinsic via ``context.compile_internal``.
# NOTE(review): every lowering deliberately has its own wrapper function;
# check how compile_internal caches compiled helpers before folding these
# into a shared factory.


@lower(math.sin, types.float16)
def fp16_sin_impl(context, builder, sig, args):
    """Lower ``math.sin`` on float16 via ``cuda.fp16.hsin``."""
    def fp16_sin(h):
        return cuda.fp16.hsin(h)

    return context.compile_internal(builder, fp16_sin, sig, args)


@lower(math.cos, types.float16)
def fp16_cos_impl(context, builder, sig, args):
    """Lower ``math.cos`` on float16 via ``cuda.fp16.hcos``."""
    def fp16_cos(h):
        return cuda.fp16.hcos(h)

    return context.compile_internal(builder, fp16_cos, sig, args)


@lower(math.log, types.float16)
def fp16_log_impl(context, builder, sig, args):
    """Lower ``math.log`` on float16 via ``cuda.fp16.hlog``."""
    def fp16_log(h):
        return cuda.fp16.hlog(h)

    return context.compile_internal(builder, fp16_log, sig, args)


@lower(math.log10, types.float16)
def fp16_log10_impl(context, builder, sig, args):
    """Lower ``math.log10`` on float16 via ``cuda.fp16.hlog10``."""
    def fp16_log10(h):
        return cuda.fp16.hlog10(h)

    return context.compile_internal(builder, fp16_log10, sig, args)


@lower(math.log2, types.float16)
def fp16_log2_impl(context, builder, sig, args):
    """Lower ``math.log2`` on float16 via ``cuda.fp16.hlog2``."""
    def fp16_log2(h):
        return cuda.fp16.hlog2(h)

    return context.compile_internal(builder, fp16_log2, sig, args)


@lower(math.exp, types.float16)
def fp16_exp_impl(context, builder, sig, args):
    """Lower ``math.exp`` on float16 via ``cuda.fp16.hexp``."""
    def fp16_exp(h):
        return cuda.fp16.hexp(h)

    return context.compile_internal(builder, fp16_exp, sig, args)


@lower(math.floor, types.float16)
def fp16_floor_impl(context, builder, sig, args):
    """Lower ``math.floor`` on float16 via ``cuda.fp16.hfloor``."""
    def fp16_floor(h):
        return cuda.fp16.hfloor(h)

    return context.compile_internal(builder, fp16_floor, sig, args)


@lower(math.ceil, types.float16)
def fp16_ceil_impl(context, builder, sig, args):
    """Lower ``math.ceil`` on float16 via ``cuda.fp16.hceil``."""
    def fp16_ceil(h):
        return cuda.fp16.hceil(h)

    return context.compile_internal(builder, fp16_ceil, sig, args)


@lower(math.sqrt, types.float16)
def fp16_sqrt_impl(context, builder, sig, args):
    """Lower ``math.sqrt`` on float16 via ``cuda.fp16.hsqrt``."""
    def fp16_sqrt(h):
        return cuda.fp16.hsqrt(h)

    return context.compile_internal(builder, fp16_sqrt, sig, args)


@lower(math.fabs, types.float16)
def fp16_fabs_impl(context, builder, sig, args):
    """Lower ``math.fabs`` on float16 via ``cuda.fp16.habs``."""
    def fp16_fabs(h):
        return cuda.fp16.habs(h)

    return context.compile_internal(builder, fp16_fabs, sig, args)


@lower(math.trunc, types.float16)
def fp16_trunc_impl(context, builder, sig, args):
    """Lower ``math.trunc`` on float16 via ``cuda.fp16.htrunc``."""
    def fp16_trunc(h):
        return cuda.fp16.htrunc(h)

    return context.compile_internal(builder, fp16_trunc, sig, args)
|
178
|
+
|
179
|
+
|
180
|
+
def impl_boolean(key, ty, libfunc):
    """Register a lowering of predicate ``key`` for argument type ``ty``.

    ``libfunc`` is resolved with signature ``int32(ty)``. Its int32 result is
    then cast to ``types.boolean``.
    """
    def lower_boolean_impl(context, builder, sig, args):
        int_sig = typing.signature(types.int32, ty)
        as_int32 = context.get_function(libfunc, int_sig)(builder, args)
        return context.cast(builder, as_int32, types.int32, types.boolean)

    lower(key, ty)(lower_boolean_impl)
|
188
|
+
|
189
|
+
|
190
|
+
def get_lower_unary_impl(key, ty, libfunc):
    """Build (without registering) a lowering for a unary function on ``ty``.

    The callee is resolved with signature ``ty(ty)``. For float32 under
    ``context.fastmath``, if ``unarys_fastmath`` maps ``libfunc.__name__`` to
    a replacement name, that libdevice function is called instead of
    ``libfunc``. ``key`` is accepted for symmetry with the registering helpers
    but is not used here.
    """
    def lower_unary_impl(context, builder, sig, args):
        target = libfunc
        if ty == float32 and context.fastmath:
            fast_name = unarys_fastmath.get(libfunc.__name__)
            if fast_name is not None:
                target = getattr(libdevice, fast_name)

        unary_sig = typing.signature(ty, ty)
        return context.get_function(target, unary_sig)(builder, args)

    return lower_unary_impl
|
204
|
+
|
205
|
+
|
206
|
+
def get_unary_impl_for_fn_and_ty(fn, ty):
    """Return the unary lowering function for Python function ``fn`` on ``ty``.

    The function is looked up in ``unarys`` (plus ``math.tanh``). Only
    ``float32`` and ``float64`` are supported.

    Raises:
        RuntimeError: if ``fn`` is not a known unary function, or ``ty`` is
            neither float32 nor float64.
    """
    # tanh is a special case - because it is not registered like the other
    # unary implementations, it does not appear in the unarys list. However,
    # its implementation can be looked up by key like the other
    # implementations, so we add it to the list we search here.
    tanh_impls = ('tanh', 'tanhf', math.tanh)
    for fname64, fname32, key in unarys + [tanh_impls]:
        if fn != key:
            continue
        if ty == float32:
            impl = getattr(libdevice, fname32)
        elif ty == float64:
            impl = getattr(libdevice, fname64)
        else:
            # Unsupported type: fall through to the RuntimeError below rather
            # than referencing an unbound ``impl``.
            break
        return get_lower_unary_impl(key, ty, impl)

    raise RuntimeError(f"Implementation of {fn} for {ty} not found")
|
222
|
+
|
223
|
+
|
224
|
+
def impl_unary(key, ty, libfunc):
    """Register the unary lowering of ``key`` for argument type ``ty``."""
    impl = get_lower_unary_impl(key, ty, libfunc)
    register_for_ty = lower(key, ty)
    register_for_ty(impl)
|
227
|
+
|
228
|
+
|
229
|
+
def impl_unary_int(key, ty, libfunc):
    """Register a lowering of unary ``key`` for 64-bit integer type ``ty``.

    The argument is converted to double (``sitofp`` for int64, ``uitofp`` for
    uint64), and ``libfunc`` is called with signature ``float64(float64)``.
    Any other argument type raises ``TypeError`` at lowering time.
    """
    def lower_unary_int_impl(context, builder, sig, args):
        argty = sig.args[0]
        if argty == int64:
            to_double = builder.sitofp
        elif argty == uint64:
            to_double = builder.uitofp
        else:
            msg = ('Only 64-bit integers are supported for generic '
                   'unary int ops')
            raise TypeError(msg)

        as_double = to_double(args[0], ir.DoubleType())
        double_sig = typing.signature(float64, float64)
        double_impl = context.get_function(libfunc, double_sig)
        return double_impl(builder, [as_double])

    lower(key, ty)(lower_unary_int_impl)
|
245
|
+
|
246
|
+
|
247
|
+
def get_lower_binary_impl(key, ty, libfunc):
    """Build (without registering) a lowering for a binary function on ``ty``.

    The callee is resolved with signature ``ty(ty, ty)``. For float32 under
    ``context.fastmath``, if ``binarys_fastmath`` maps ``libfunc.__name__``
    to a replacement name, that libdevice function is called instead.
    ``key`` is accepted for symmetry but is not used here.
    """
    def lower_binary_impl(context, builder, sig, args):
        target = libfunc
        if ty == float32 and context.fastmath:
            fast_name = binarys_fastmath.get(libfunc.__name__)
            if fast_name is not None:
                target = getattr(libdevice, fast_name)

        binary_sig = typing.signature(ty, ty, ty)
        return context.get_function(target, binary_sig)(builder, args)

    return lower_binary_impl
|
261
|
+
|
262
|
+
|
263
|
+
def get_binary_impl_for_fn_and_ty(fn, ty):
    """Return the binary lowering function for Python function ``fn`` on ``ty``.

    The function is looked up in ``binarys``. Only ``float32`` and ``float64``
    are supported.

    Raises:
        RuntimeError: if ``fn`` is not a known binary function, or ``ty`` is
            neither float32 nor float64.
    """
    for fname64, fname32, key in binarys:
        if fn != key:
            continue
        if ty == float32:
            impl = getattr(libdevice, fname32)
        elif ty == float64:
            impl = getattr(libdevice, fname64)
        else:
            # Unsupported type: fall through to the RuntimeError below rather
            # than referencing an unbound ``impl``.
            break
        return get_lower_binary_impl(key, ty, impl)

    raise RuntimeError(f"Implementation of {fn} for {ty} not found")
|
274
|
+
|
275
|
+
|
276
|
+
def impl_binary(key, ty, libfunc):
    """Register the binary lowering of ``key`` for arguments ``(ty, ty)``."""
    impl = get_lower_binary_impl(key, ty, libfunc)
    register_for_ty = lower(key, ty, ty)
    register_for_ty(impl)
|
279
|
+
|
280
|
+
|
281
|
+
def impl_binary_int(key, ty, libfunc):
    """Register a lowering of binary ``key`` for 64-bit integer type ``ty``.

    Both arguments are converted to double (``sitofp`` when the first
    argument type is int64, ``uitofp`` when it is uint64), and ``libfunc`` is
    called with signature ``float64(float64, float64)``. Any other first
    argument type raises ``TypeError`` at lowering time.
    """
    def lower_binary_int_impl(context, builder, sig, args):
        argty = sig.args[0]
        if argty == int64:
            to_double = builder.sitofp
        elif argty == uint64:
            to_double = builder.uitofp
        else:
            msg = ('Only 64-bit integers are supported for generic '
                   'binary int ops')
            raise TypeError(msg)

        doubles = [to_double(value, ir.DoubleType()) for value in args]
        double_sig = typing.signature(float64, float64, float64)
        double_impl = context.get_function(libfunc, double_sig)
        return double_impl(builder, doubles)

    lower(key, ty, ty)(lower_binary_int_impl)
|
297
|
+
|
298
|
+
|
299
|
+
# Register every table entry.
# - Predicates are registered for float32 and float64.
# - Unary and binary functions are also registered for int64 and uint64;
#   those are lowered through the float64 libdevice function.
for fname64, fname32, key in booleans:
    impl32, impl64 = getattr(libdevice, fname32), getattr(libdevice, fname64)
    impl_boolean(key, float32, impl32)
    impl_boolean(key, float64, impl64)


for fname64, fname32, key in unarys:
    impl32, impl64 = getattr(libdevice, fname32), getattr(libdevice, fname64)
    impl_unary(key, float32, impl32)
    impl_unary(key, float64, impl64)
    impl_unary_int(key, int64, impl64)
    impl_unary_int(key, uint64, impl64)


for fname64, fname32, key in binarys:
    impl32, impl64 = getattr(libdevice, fname32), getattr(libdevice, fname64)
    impl_binary(key, float32, impl32)
    impl_binary(key, float64, impl64)
    impl_binary_int(key, int64, impl64)
    impl_binary_int(key, uint64, impl64)
|
322
|
+
|
323
|
+
|
324
|
+
def impl_pow_int(ty, libfunc):
    """Register ``math.pow(ty, int32)``.

    The lowering calls libdevice ``libfunc`` with signature ``ty(ty, int32)``.
    """
    def lower_pow_impl_int(context, builder, sig, args):
        powi = context.get_function(libfunc,
                                    typing.signature(ty, ty, types.int32))
        return powi(builder, args)

    lower(math.pow, ty, types.int32)(lower_pow_impl_int)


impl_pow_int(types.float32, libdevice.powif)
impl_pow_int(types.float64, libdevice.powi)
|
335
|
+
|
336
|
+
|
337
|
+
def impl_modf(ty, libfunc):
    """Register ``math.modf`` for ``ty``.

    The lowering calls libdevice ``libfunc`` with signature
    ``UniTuple(ty, 2)(ty)``.
    """
    retty = types.UniTuple(ty, 2)

    def lower_modf_impl(context, builder, sig, args):
        modf_impl = context.get_function(libfunc,
                                         typing.signature(retty, ty))
        return modf_impl(builder, args)

    lower(math.modf, ty)(lower_modf_impl)


impl_modf(types.float32, libdevice.modff)
impl_modf(types.float64, libdevice.modf)
|
350
|
+
|
351
|
+
|
352
|
+
def impl_frexp(ty, libfunc):
    """Register ``math.frexp`` for ``ty``.

    The lowering calls libdevice ``libfunc`` with signature
    ``Tuple(ty, int32)(ty)``.
    """
    retty = types.Tuple((ty, types.int32))

    def lower_frexp_impl(context, builder, sig, args):
        frexp_impl = context.get_function(libfunc,
                                          typing.signature(retty, ty))
        return frexp_impl(builder, args)

    lower(math.frexp, ty)(lower_frexp_impl)


impl_frexp(types.float32, libdevice.frexpf)
impl_frexp(types.float64, libdevice.frexp)
|
365
|
+
|
366
|
+
|
367
|
+
def impl_ldexp(ty, libfunc):
    """Register ``math.ldexp(ty, int32)``.

    The lowering calls libdevice ``libfunc`` with signature
    ``ty(ty, int32)``.
    """
    def lower_ldexp_impl(context, builder, sig, args):
        ldexp_impl = context.get_function(
            libfunc, typing.signature(ty, ty, types.int32))
        return ldexp_impl(builder, args)

    lower(math.ldexp, ty, types.int32)(lower_ldexp_impl)


impl_ldexp(types.float32, libdevice.ldexpf)
impl_ldexp(types.float64, libdevice.ldexp)
|
378
|
+
|
379
|
+
|
380
|
+
def impl_tanh(ty, libfunc):
    """Register ``math.tanh`` for ``ty``.

    For float32 under ``context.fastmath``, when the ``compute_capability``
    of the top target config is at least (7, 5), the lowering emits the
    inline PTX instruction ``tanh.approx.f32``. In every other case it calls
    libdevice ``libfunc`` with signature ``ty(ty)``.
    """
    def lower_tanh_impl(context, builder, sig, args):
        if ty == float32 and context.fastmath:
            cc = targetconfig.ConfigStack().top().compute_capability
            if cc >= (7, 5):
                f32 = ir.FloatType()
                fnty = ir.FunctionType(f32, [f32])
                asm = ir.InlineAsm(fnty, 'tanh.approx.f32 $0, $1;', '=f,f')
                return builder.call(asm, args)

        tanh_sig = typing.signature(ty, ty)
        return context.get_function(libfunc, tanh_sig)(builder, args)

    lower(math.tanh, ty)(lower_tanh_impl)


impl_tanh(types.float32, libdevice.tanhf)
impl_tanh(types.float64, libdevice.tanh)

# Integer tanh goes through the float64 libdevice function.
impl_unary_int(math.tanh, int64, libdevice.tanh)
impl_unary_int(math.tanh, uint64, libdevice.tanh)
|
411
|
+
|
412
|
+
# Complex power implementations - translations of _Py_c_pow from CPython
# https://github.com/python/cpython/blob/a755410e054e1e2390de5830befc08fe80706c66/Objects/complexobject.c#L123-L151
#
# The complex64 variant casts all constants and some variables to ensure that
# as much computation is done in single precision as possible. A small number
# of operations are still done in 64-bit, but these come from libdevice code.


def cpow_implement(fty, cty):
    """Register complex ``a ** b`` lowerings for complex type ``cty``.

    ``fty`` is the matching real type and is used to build constants.
    Registers ``operator.pow``, ``operator.ipow`` and the builtin ``pow`` for
    ``(cty, cty)``.
    """
    def core(context, builder, sig, args):
        def cpow_internal(a, b):
            # x ** 0 == 1 for every x, including a zero base.
            if b.real == fty(0.0) and b.imag == fty(0.0):
                return cty(1.0) + cty(0.0j)
            # A zero base (both components zero) yields zero, as in
            # _Py_c_pow. The linked CPython code additionally flags a
            # negative or complex exponent with EDOM; that is not
            # reproduced here.
            elif a.real == fty(0.0) and a.imag == fty(0.0):
                return cty(0.0) + cty(0.0j)

            vabs = math.hypot(a.real, a.imag)
            length = math.pow(vabs, b.real)
            at = math.atan2(a.imag, a.real)
            phase = at * b.real
            if b.imag != fty(0.0):
                length /= math.exp(at * b.imag)
                phase += b.imag * math.log(vabs)

            return length * (cty(math.cos(phase)) +
                             cty(math.sin(phase) * cty(1.0j)))

        return context.compile_internal(builder, cpow_internal, sig, args)

    lower(operator.pow, cty, cty)(core)
    lower(operator.ipow, cty, cty)(core)
    lower(pow, cty, cty)(core)


cpow_implement(types.float32, types.complex64)
cpow_implement(types.float64, types.complex128)
|
@@ -0,0 +1,48 @@
|
|
1
|
+
import functools
|
2
|
+
|
3
|
+
from llvmlite import ir
|
4
|
+
|
5
|
+
from numba.core.datamodel.registry import DataModelManager, register
|
6
|
+
from numba.core.extending import models
|
7
|
+
from numba.core import types
|
8
|
+
from numba.cuda.types import Dim3, GridGroup, CUDADispatcher
|
9
|
+
|
10
|
+
|
11
|
+
# Data model manager for the CUDA target. ``register_model`` is ``register``
# with this manager pre-bound, so ``@register_model(SomeType)`` registers a
# model class for ``SomeType`` into ``cuda_data_manager``.
cuda_data_manager = DataModelManager()

register_model = functools.partial(register, cuda_data_manager)


@register_model(Dim3)
class Dim3Model(models.StructModel):
    """Struct data model for ``Dim3``: int32 members ``x``, ``y`` and ``z``."""

    def __init__(self, dmm, fe_type):
        members = [(axis, types.int32) for axis in ('x', 'y', 'z')]
        super().__init__(dmm, fe_type, members)
|
25
|
+
|
26
|
+
|
27
|
+
@register_model(GridGroup)
class GridGroupModel(models.PrimitiveModel):
    """Primitive data model: ``GridGroup`` is represented as an LLVM ``i64``."""

    def __init__(self, dmm, fe_type):
        super().__init__(dmm, fe_type, ir.IntType(64))
|
32
|
+
|
33
|
+
|
34
|
+
@register_model(types.Float)
class FloatModel(models.PrimitiveModel):
    """Primitive data model for Numba floating-point types on CUDA.

    float32 and float64 map to LLVM ``float`` and ``double``. float16 is
    backed by a 16-bit integer (``i16``). Any other ``Float`` raises
    ``NotImplementedError``.
    """

    def __init__(self, dmm, fe_type):
        candidates = (
            (types.float16, ir.IntType(16)),
            (types.float32, ir.FloatType()),
            (types.float64, ir.DoubleType()),
        )
        for nb_type, be_type in candidates:
            if fe_type == nb_type:
                break
        else:
            raise NotImplementedError(fe_type)
        super().__init__(dmm, fe_type, be_type)


# CUDA dispatcher objects use an opaque data model.
register_model(CUDADispatcher)(models.OpaqueModel)
|