numba-cuda 0.8.1__py3-none-any.whl → 0.10.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- _numba_cuda_redirector.py +17 -13
- numba_cuda/VERSION +1 -1
- numba_cuda/_version.py +4 -1
- numba_cuda/numba/cuda/__init__.py +6 -2
- numba_cuda/numba/cuda/api.py +129 -86
- numba_cuda/numba/cuda/api_util.py +3 -3
- numba_cuda/numba/cuda/args.py +12 -16
- numba_cuda/numba/cuda/cg.py +6 -6
- numba_cuda/numba/cuda/codegen.py +74 -43
- numba_cuda/numba/cuda/compiler.py +246 -114
- numba_cuda/numba/cuda/cpp_function_wrappers.cu +1 -2
- numba_cuda/numba/cuda/cuda_bf16.py +5155 -0
- numba_cuda/numba/cuda/cuda_paths.py +293 -99
- numba_cuda/numba/cuda/cudadecl.py +93 -79
- numba_cuda/numba/cuda/cudadrv/__init__.py +3 -1
- numba_cuda/numba/cuda/cudadrv/devicearray.py +185 -135
- numba_cuda/numba/cuda/cudadrv/devices.py +16 -11
- numba_cuda/numba/cuda/cudadrv/driver.py +460 -297
- numba_cuda/numba/cuda/cudadrv/drvapi.py +241 -207
- numba_cuda/numba/cuda/cudadrv/dummyarray.py +66 -54
- numba_cuda/numba/cuda/cudadrv/enums.py +1 -1
- numba_cuda/numba/cuda/cudadrv/error.py +6 -2
- numba_cuda/numba/cuda/cudadrv/libs.py +67 -63
- numba_cuda/numba/cuda/cudadrv/linkable_code.py +27 -3
- numba_cuda/numba/cuda/cudadrv/mappings.py +16 -14
- numba_cuda/numba/cuda/cudadrv/nvrtc.py +146 -30
- numba_cuda/numba/cuda/cudadrv/nvvm.py +296 -161
- numba_cuda/numba/cuda/cudadrv/rtapi.py +1 -1
- numba_cuda/numba/cuda/cudadrv/runtime.py +20 -8
- numba_cuda/numba/cuda/cudaimpl.py +296 -275
- numba_cuda/numba/cuda/cudamath.py +1 -1
- numba_cuda/numba/cuda/debuginfo.py +99 -7
- numba_cuda/numba/cuda/decorators.py +87 -45
- numba_cuda/numba/cuda/descriptor.py +1 -1
- numba_cuda/numba/cuda/device_init.py +68 -18
- numba_cuda/numba/cuda/deviceufunc.py +143 -98
- numba_cuda/numba/cuda/dispatcher.py +300 -213
- numba_cuda/numba/cuda/errors.py +13 -10
- numba_cuda/numba/cuda/extending.py +55 -1
- numba_cuda/numba/cuda/include/11/cuda_bf16.h +3749 -0
- numba_cuda/numba/cuda/include/11/cuda_bf16.hpp +2683 -0
- numba_cuda/numba/cuda/{cuda_fp16.h → include/11/cuda_fp16.h} +1090 -927
- numba_cuda/numba/cuda/{cuda_fp16.hpp → include/11/cuda_fp16.hpp} +468 -319
- numba_cuda/numba/cuda/include/12/cuda_bf16.h +5118 -0
- numba_cuda/numba/cuda/include/12/cuda_bf16.hpp +3865 -0
- numba_cuda/numba/cuda/include/12/cuda_fp16.h +5363 -0
- numba_cuda/numba/cuda/include/12/cuda_fp16.hpp +3483 -0
- numba_cuda/numba/cuda/initialize.py +5 -3
- numba_cuda/numba/cuda/intrinsic_wrapper.py +0 -39
- numba_cuda/numba/cuda/intrinsics.py +203 -28
- numba_cuda/numba/cuda/kernels/reduction.py +13 -13
- numba_cuda/numba/cuda/kernels/transpose.py +3 -6
- numba_cuda/numba/cuda/libdevice.py +317 -317
- numba_cuda/numba/cuda/libdeviceimpl.py +3 -2
- numba_cuda/numba/cuda/locks.py +16 -0
- numba_cuda/numba/cuda/lowering.py +43 -0
- numba_cuda/numba/cuda/mathimpl.py +62 -57
- numba_cuda/numba/cuda/models.py +1 -5
- numba_cuda/numba/cuda/nvvmutils.py +103 -88
- numba_cuda/numba/cuda/printimpl.py +9 -5
- numba_cuda/numba/cuda/random.py +46 -36
- numba_cuda/numba/cuda/reshape_funcs.cu +1 -1
- numba_cuda/numba/cuda/runtime/__init__.py +1 -1
- numba_cuda/numba/cuda/runtime/memsys.cu +1 -1
- numba_cuda/numba/cuda/runtime/memsys.cuh +1 -1
- numba_cuda/numba/cuda/runtime/nrt.cu +3 -3
- numba_cuda/numba/cuda/runtime/nrt.py +48 -43
- numba_cuda/numba/cuda/simulator/__init__.py +22 -12
- numba_cuda/numba/cuda/simulator/api.py +38 -22
- numba_cuda/numba/cuda/simulator/compiler.py +2 -2
- numba_cuda/numba/cuda/simulator/cudadrv/__init__.py +8 -2
- numba_cuda/numba/cuda/simulator/cudadrv/devicearray.py +63 -55
- numba_cuda/numba/cuda/simulator/cudadrv/devices.py +13 -11
- numba_cuda/numba/cuda/simulator/cudadrv/driver.py +5 -5
- numba_cuda/numba/cuda/simulator/cudadrv/drvapi.py +2 -2
- numba_cuda/numba/cuda/simulator/cudadrv/libs.py +1 -1
- numba_cuda/numba/cuda/simulator/cudadrv/nvvm.py +3 -3
- numba_cuda/numba/cuda/simulator/cudadrv/runtime.py +3 -3
- numba_cuda/numba/cuda/simulator/kernel.py +43 -34
- numba_cuda/numba/cuda/simulator/kernelapi.py +31 -26
- numba_cuda/numba/cuda/simulator/reduction.py +1 -0
- numba_cuda/numba/cuda/simulator/vector_types.py +13 -9
- numba_cuda/numba/cuda/simulator_init.py +2 -4
- numba_cuda/numba/cuda/stubs.py +134 -108
- numba_cuda/numba/cuda/target.py +92 -47
- numba_cuda/numba/cuda/testing.py +24 -19
- numba_cuda/numba/cuda/tests/__init__.py +14 -12
- numba_cuda/numba/cuda/tests/cudadrv/test_array_attr.py +16 -17
- numba_cuda/numba/cuda/tests/cudadrv/test_context_stack.py +7 -7
- numba_cuda/numba/cuda/tests/cudadrv/test_cuda_array_slicing.py +73 -54
- numba_cuda/numba/cuda/tests/cudadrv/test_cuda_auto_context.py +1 -1
- numba_cuda/numba/cuda/tests/cudadrv/test_cuda_devicerecord.py +48 -50
- numba_cuda/numba/cuda/tests/cudadrv/test_cuda_driver.py +47 -29
- numba_cuda/numba/cuda/tests/cudadrv/test_cuda_libraries.py +3 -3
- numba_cuda/numba/cuda/tests/cudadrv/test_cuda_memory.py +19 -19
- numba_cuda/numba/cuda/tests/cudadrv/test_cuda_ndarray.py +108 -103
- numba_cuda/numba/cuda/tests/cudadrv/test_deallocations.py +20 -11
- numba_cuda/numba/cuda/tests/cudadrv/test_detect.py +20 -17
- numba_cuda/numba/cuda/tests/cudadrv/test_emm_plugins.py +8 -6
- numba_cuda/numba/cuda/tests/cudadrv/test_events.py +1 -1
- numba_cuda/numba/cuda/tests/cudadrv/test_host_alloc.py +8 -7
- numba_cuda/numba/cuda/tests/cudadrv/test_init.py +13 -13
- numba_cuda/numba/cuda/tests/cudadrv/test_inline_ptx.py +12 -9
- numba_cuda/numba/cuda/tests/cudadrv/test_linker.py +36 -31
- numba_cuda/numba/cuda/tests/cudadrv/test_managed_alloc.py +8 -7
- numba_cuda/numba/cuda/tests/cudadrv/test_module_callbacks.py +294 -0
- numba_cuda/numba/cuda/tests/cudadrv/test_mvc.py +10 -7
- numba_cuda/numba/cuda/tests/cudadrv/test_nvjitlink.py +24 -15
- numba_cuda/numba/cuda/tests/cudadrv/test_nvvm_driver.py +43 -41
- numba_cuda/numba/cuda/tests/cudadrv/test_pinned.py +4 -5
- numba_cuda/numba/cuda/tests/cudadrv/test_profiler.py +2 -2
- numba_cuda/numba/cuda/tests/cudadrv/test_ptds.py +28 -17
- numba_cuda/numba/cuda/tests/cudadrv/test_reset_device.py +1 -2
- numba_cuda/numba/cuda/tests/cudadrv/test_runtime.py +22 -14
- numba_cuda/numba/cuda/tests/cudadrv/test_select_device.py +1 -1
- numba_cuda/numba/cuda/tests/cudadrv/test_streams.py +4 -3
- numba_cuda/numba/cuda/tests/cudapy/cache_usecases.py +10 -4
- numba_cuda/numba/cuda/tests/cudapy/cache_with_cpu_usecases.py +1 -0
- numba_cuda/numba/cuda/tests/cudapy/extensions_usecases.py +10 -7
- numba_cuda/numba/cuda/tests/cudapy/jitlink.ptx +0 -2
- numba_cuda/numba/cuda/tests/cudapy/recursion_usecases.py +1 -0
- numba_cuda/numba/cuda/tests/cudapy/test_alignment.py +6 -5
- numba_cuda/numba/cuda/tests/cudapy/test_array.py +52 -42
- numba_cuda/numba/cuda/tests/cudapy/test_array_args.py +5 -6
- numba_cuda/numba/cuda/tests/cudapy/test_array_methods.py +1 -1
- numba_cuda/numba/cuda/tests/cudapy/test_atomics.py +501 -304
- numba_cuda/numba/cuda/tests/cudapy/test_bfloat16_bindings.py +257 -0
- numba_cuda/numba/cuda/tests/cudapy/test_blackscholes.py +59 -23
- numba_cuda/numba/cuda/tests/cudapy/test_boolean.py +3 -3
- numba_cuda/numba/cuda/tests/cudapy/test_caching.py +50 -37
- numba_cuda/numba/cuda/tests/cudapy/test_casting.py +29 -24
- numba_cuda/numba/cuda/tests/cudapy/test_cffi.py +11 -6
- numba_cuda/numba/cuda/tests/cudapy/test_compiler.py +84 -50
- numba_cuda/numba/cuda/tests/cudapy/test_complex.py +144 -73
- numba_cuda/numba/cuda/tests/cudapy/test_complex_kernel.py +2 -2
- numba_cuda/numba/cuda/tests/cudapy/test_const_string.py +37 -27
- numba_cuda/numba/cuda/tests/cudapy/test_constmem.py +43 -45
- numba_cuda/numba/cuda/tests/cudapy/test_cooperative_groups.py +21 -14
- numba_cuda/numba/cuda/tests/cudapy/test_cuda_array_interface.py +60 -55
- numba_cuda/numba/cuda/tests/cudapy/test_cuda_jit_no_types.py +3 -2
- numba_cuda/numba/cuda/tests/cudapy/test_datetime.py +26 -22
- numba_cuda/numba/cuda/tests/cudapy/test_debug.py +29 -27
- numba_cuda/numba/cuda/tests/cudapy/test_debuginfo.py +77 -28
- numba_cuda/numba/cuda/tests/cudapy/test_device_func.py +52 -45
- numba_cuda/numba/cuda/tests/cudapy/test_dispatcher.py +55 -43
- numba_cuda/numba/cuda/tests/cudapy/test_enums.py +24 -7
- numba_cuda/numba/cuda/tests/cudapy/test_errors.py +30 -15
- numba_cuda/numba/cuda/tests/cudapy/test_exception.py +11 -12
- numba_cuda/numba/cuda/tests/cudapy/test_extending.py +21 -12
- numba_cuda/numba/cuda/tests/cudapy/test_fastmath.py +77 -66
- numba_cuda/numba/cuda/tests/cudapy/test_forall.py +5 -3
- numba_cuda/numba/cuda/tests/cudapy/test_freevar.py +5 -3
- numba_cuda/numba/cuda/tests/cudapy/test_frexp_ldexp.py +1 -1
- numba_cuda/numba/cuda/tests/cudapy/test_globals.py +3 -5
- numba_cuda/numba/cuda/tests/cudapy/test_gufunc.py +144 -126
- numba_cuda/numba/cuda/tests/cudapy/test_gufunc_scalar.py +23 -18
- numba_cuda/numba/cuda/tests/cudapy/test_gufunc_scheduling.py +16 -22
- numba_cuda/numba/cuda/tests/cudapy/test_idiv.py +1 -3
- numba_cuda/numba/cuda/tests/cudapy/test_inline.py +59 -0
- numba_cuda/numba/cuda/tests/cudapy/test_inspect.py +29 -20
- numba_cuda/numba/cuda/tests/cudapy/test_intrinsics.py +147 -99
- numba_cuda/numba/cuda/tests/cudapy/test_ipc.py +50 -36
- numba_cuda/numba/cuda/tests/cudapy/test_iterators.py +1 -2
- numba_cuda/numba/cuda/tests/cudapy/test_lang.py +4 -4
- numba_cuda/numba/cuda/tests/cudapy/test_laplace.py +7 -7
- numba_cuda/numba/cuda/tests/cudapy/test_libdevice.py +24 -20
- numba_cuda/numba/cuda/tests/cudapy/test_lineinfo.py +36 -31
- numba_cuda/numba/cuda/tests/cudapy/test_localmem.py +13 -13
- numba_cuda/numba/cuda/tests/cudapy/test_mandel.py +13 -6
- numba_cuda/numba/cuda/tests/cudapy/test_math.py +83 -66
- numba_cuda/numba/cuda/tests/cudapy/test_matmul.py +1 -3
- numba_cuda/numba/cuda/tests/cudapy/test_minmax.py +19 -58
- numba_cuda/numba/cuda/tests/cudapy/test_montecarlo.py +4 -4
- numba_cuda/numba/cuda/tests/cudapy/test_multigpu.py +9 -7
- numba_cuda/numba/cuda/tests/cudapy/test_multiprocessing.py +9 -8
- numba_cuda/numba/cuda/tests/cudapy/test_multithreads.py +12 -10
- numba_cuda/numba/cuda/tests/cudapy/test_nondet.py +1 -1
- numba_cuda/numba/cuda/tests/cudapy/test_operator.py +180 -96
- numba_cuda/numba/cuda/tests/cudapy/test_optimization.py +5 -5
- numba_cuda/numba/cuda/tests/cudapy/test_overload.py +37 -18
- numba_cuda/numba/cuda/tests/cudapy/test_powi.py +7 -7
- numba_cuda/numba/cuda/tests/cudapy/test_print.py +9 -7
- numba_cuda/numba/cuda/tests/cudapy/test_py2_div_issue.py +1 -1
- numba_cuda/numba/cuda/tests/cudapy/test_random.py +15 -10
- numba_cuda/numba/cuda/tests/cudapy/test_record_dtype.py +88 -87
- numba_cuda/numba/cuda/tests/cudapy/test_recursion.py +12 -10
- numba_cuda/numba/cuda/tests/cudapy/test_reduction.py +26 -11
- numba_cuda/numba/cuda/tests/cudapy/test_retrieve_autoconverted_arrays.py +7 -10
- numba_cuda/numba/cuda/tests/cudapy/test_serialize.py +4 -6
- numba_cuda/numba/cuda/tests/cudapy/test_slicing.py +1 -1
- numba_cuda/numba/cuda/tests/cudapy/test_sm.py +10 -9
- numba_cuda/numba/cuda/tests/cudapy/test_sm_creation.py +62 -43
- numba_cuda/numba/cuda/tests/cudapy/test_stream_api.py +7 -3
- numba_cuda/numba/cuda/tests/cudapy/test_sync.py +7 -5
- numba_cuda/numba/cuda/tests/cudapy/test_transpose.py +18 -11
- numba_cuda/numba/cuda/tests/cudapy/test_ufuncs.py +111 -88
- numba_cuda/numba/cuda/tests/cudapy/test_userexc.py +2 -3
- numba_cuda/numba/cuda/tests/cudapy/test_vector_type.py +305 -130
- numba_cuda/numba/cuda/tests/cudapy/test_vectorize.py +33 -36
- numba_cuda/numba/cuda/tests/cudapy/test_vectorize_complex.py +5 -5
- numba_cuda/numba/cuda/tests/cudapy/test_vectorize_decor.py +16 -12
- numba_cuda/numba/cuda/tests/cudapy/test_vectorize_device.py +7 -7
- numba_cuda/numba/cuda/tests/cudapy/test_vectorize_scalar_arg.py +6 -7
- numba_cuda/numba/cuda/tests/cudapy/test_warning.py +31 -29
- numba_cuda/numba/cuda/tests/cudapy/test_warp_ops.py +81 -30
- numba_cuda/numba/cuda/tests/cudasim/test_cudasim_issues.py +19 -13
- numba_cuda/numba/cuda/tests/data/jitlink.cu +1 -1
- numba_cuda/numba/cuda/tests/data/jitlink.ptx +0 -2
- numba_cuda/numba/cuda/tests/doc_examples/test_cg.py +15 -8
- numba_cuda/numba/cuda/tests/doc_examples/test_cpu_gpu_compat.py +4 -7
- numba_cuda/numba/cuda/tests/doc_examples/test_ffi.py +14 -9
- numba_cuda/numba/cuda/tests/doc_examples/test_laplace.py +22 -18
- numba_cuda/numba/cuda/tests/doc_examples/test_matmul.py +7 -4
- numba_cuda/numba/cuda/tests/doc_examples/test_montecarlo.py +2 -0
- numba_cuda/numba/cuda/tests/doc_examples/test_random.py +8 -4
- numba_cuda/numba/cuda/tests/doc_examples/test_reduction.py +2 -1
- numba_cuda/numba/cuda/tests/doc_examples/test_sessionize.py +94 -19
- numba_cuda/numba/cuda/tests/doc_examples/test_vecadd.py +2 -2
- numba_cuda/numba/cuda/tests/nocuda/test_dummyarray.py +91 -62
- numba_cuda/numba/cuda/tests/nocuda/test_function_resolution.py +14 -5
- numba_cuda/numba/cuda/tests/nocuda/test_import.py +25 -25
- numba_cuda/numba/cuda/tests/nocuda/test_library_lookup.py +40 -40
- numba_cuda/numba/cuda/tests/nocuda/test_nvvm.py +12 -10
- numba_cuda/numba/cuda/tests/nrt/test_nrt.py +16 -20
- numba_cuda/numba/cuda/tests/nrt/test_nrt_refct.py +12 -10
- numba_cuda/numba/cuda/tests/test_binary_generation/generate_raw_ltoir.py +2 -2
- numba_cuda/numba/cuda/types.py +5 -2
- numba_cuda/numba/cuda/ufuncs.py +382 -362
- numba_cuda/numba/cuda/utils.py +2 -2
- numba_cuda/numba/cuda/vector_types.py +5 -3
- numba_cuda/numba/cuda/vectorizers.py +38 -33
- {numba_cuda-0.8.1.dist-info → numba_cuda-0.10.0.dist-info}/METADATA +1 -1
- numba_cuda-0.10.0.dist-info/RECORD +263 -0
- {numba_cuda-0.8.1.dist-info → numba_cuda-0.10.0.dist-info}/WHEEL +1 -1
- numba_cuda-0.8.1.dist-info/RECORD +0 -251
- {numba_cuda-0.8.1.dist-info → numba_cuda-0.10.0.dist-info}/licenses/LICENSE +0 -0
- {numba_cuda-0.8.1.dist-info → numba_cuda-0.10.0.dist-info}/top_level.txt +0 -0
@@ -1,15 +1,23 @@
|
|
1
1
|
import operator
|
2
2
|
from numba.core import types
|
3
|
-
from numba.core.typing.npydecl import (
|
4
|
-
|
5
|
-
|
6
|
-
|
7
|
-
|
8
|
-
|
9
|
-
|
10
|
-
|
11
|
-
|
12
|
-
|
3
|
+
from numba.core.typing.npydecl import (
|
4
|
+
parse_dtype,
|
5
|
+
parse_shape,
|
6
|
+
register_number_classes,
|
7
|
+
register_numpy_ufunc,
|
8
|
+
trigonometric_functions,
|
9
|
+
comparison_functions,
|
10
|
+
math_operations,
|
11
|
+
bit_twiddling_functions,
|
12
|
+
)
|
13
|
+
from numba.core.typing.templates import (
|
14
|
+
AttributeTemplate,
|
15
|
+
ConcreteTemplate,
|
16
|
+
AbstractTemplate,
|
17
|
+
CallableTemplate,
|
18
|
+
signature,
|
19
|
+
Registry,
|
20
|
+
)
|
13
21
|
from numba.cuda.types import dim3
|
14
22
|
from numba.core.typeconv import Conversion
|
15
23
|
from numba import cuda
|
@@ -26,15 +34,15 @@ register_number_classes(register_global)
|
|
26
34
|
class Cuda_array_decl(CallableTemplate):
|
27
35
|
def generic(self):
|
28
36
|
def typer(shape, dtype):
|
29
|
-
|
30
37
|
# Only integer literals and tuples of integer literals are valid
|
31
38
|
# shapes
|
32
39
|
if isinstance(shape, types.Integer):
|
33
40
|
if not isinstance(shape, types.IntegerLiteral):
|
34
41
|
return None
|
35
42
|
elif isinstance(shape, (types.Tuple, types.UniTuple)):
|
36
|
-
if any(
|
37
|
-
|
43
|
+
if any(
|
44
|
+
[not isinstance(s, types.IntegerLiteral) for s in shape]
|
45
|
+
):
|
38
46
|
return None
|
39
47
|
else:
|
40
48
|
return None
|
@@ -42,7 +50,7 @@ class Cuda_array_decl(CallableTemplate):
|
|
42
50
|
ndim = parse_shape(shape)
|
43
51
|
nb_dtype = parse_dtype(dtype)
|
44
52
|
if nb_dtype is not None and ndim is not None:
|
45
|
-
return types.Array(dtype=nb_dtype, ndim=ndim, layout=
|
53
|
+
return types.Array(dtype=nb_dtype, ndim=ndim, layout="C")
|
46
54
|
|
47
55
|
return typer
|
48
56
|
|
@@ -64,6 +72,7 @@ class Cuda_const_array_like(CallableTemplate):
|
|
64
72
|
def generic(self):
|
65
73
|
def typer(ndarray):
|
66
74
|
return ndarray
|
75
|
+
|
67
76
|
return typer
|
68
77
|
|
69
78
|
|
@@ -91,26 +100,14 @@ class Cuda_syncwarp(ConcreteTemplate):
|
|
91
100
|
cases = [signature(types.none), signature(types.none, types.i4)]
|
92
101
|
|
93
102
|
|
94
|
-
@register
|
95
|
-
class Cuda_shfl_sync_intrinsic(ConcreteTemplate):
|
96
|
-
key = cuda.shfl_sync_intrinsic
|
97
|
-
cases = [
|
98
|
-
signature(types.Tuple((types.i4, types.b1)),
|
99
|
-
types.i4, types.i4, types.i4, types.i4, types.i4),
|
100
|
-
signature(types.Tuple((types.i8, types.b1)),
|
101
|
-
types.i4, types.i4, types.i8, types.i4, types.i4),
|
102
|
-
signature(types.Tuple((types.f4, types.b1)),
|
103
|
-
types.i4, types.i4, types.f4, types.i4, types.i4),
|
104
|
-
signature(types.Tuple((types.f8, types.b1)),
|
105
|
-
types.i4, types.i4, types.f8, types.i4, types.i4),
|
106
|
-
]
|
107
|
-
|
108
|
-
|
109
103
|
@register
|
110
104
|
class Cuda_vote_sync_intrinsic(ConcreteTemplate):
|
111
105
|
key = cuda.vote_sync_intrinsic
|
112
|
-
cases = [
|
113
|
-
|
106
|
+
cases = [
|
107
|
+
signature(
|
108
|
+
types.Tuple((types.i4, types.b1)), types.i4, types.i4, types.b1
|
109
|
+
)
|
110
|
+
]
|
114
111
|
|
115
112
|
|
116
113
|
@register
|
@@ -153,6 +150,7 @@ class Cuda_popc(ConcreteTemplate):
|
|
153
150
|
Supported types from `llvm.popc`
|
154
151
|
[here](http://docs.nvidia.com/cuda/nvvm-ir-spec/index.html#bit-manipulations-intrinics)
|
155
152
|
"""
|
153
|
+
|
156
154
|
key = cuda.popc
|
157
155
|
cases = [
|
158
156
|
signature(types.int8, types.int8),
|
@@ -172,6 +170,7 @@ class Cuda_fma(ConcreteTemplate):
|
|
172
170
|
Supported types from `llvm.fma`
|
173
171
|
[here](https://docs.nvidia.com/cuda/nvvm-ir-spec/index.html#standard-c-library-intrinics)
|
174
172
|
"""
|
173
|
+
|
175
174
|
key = cuda.fma
|
176
175
|
cases = [
|
177
176
|
signature(types.float32, types.float32, types.float32, types.float32),
|
@@ -189,7 +188,6 @@ class Cuda_hfma(ConcreteTemplate):
|
|
189
188
|
|
190
189
|
@register
|
191
190
|
class Cuda_cbrt(ConcreteTemplate):
|
192
|
-
|
193
191
|
key = cuda.cbrt
|
194
192
|
cases = [
|
195
193
|
signature(types.float32, types.float32),
|
@@ -212,6 +210,7 @@ class Cuda_clz(ConcreteTemplate):
|
|
212
210
|
Supported types from `llvm.ctlz`
|
213
211
|
[here](http://docs.nvidia.com/cuda/nvvm-ir-spec/index.html#bit-manipulations-intrinics)
|
214
212
|
"""
|
213
|
+
|
215
214
|
key = cuda.clz
|
216
215
|
cases = [
|
217
216
|
signature(types.int8, types.int8),
|
@@ -231,6 +230,7 @@ class Cuda_ffs(ConcreteTemplate):
|
|
231
230
|
Supported types from `llvm.cttz`
|
232
231
|
[here](http://docs.nvidia.com/cuda/nvvm-ir-spec/index.html#bit-manipulations-intrinics)
|
233
232
|
"""
|
233
|
+
|
234
234
|
key = cuda.ffs
|
235
235
|
cases = [
|
236
236
|
signature(types.uint32, types.int8),
|
@@ -254,10 +254,16 @@ class Cuda_selp(AbstractTemplate):
|
|
254
254
|
|
255
255
|
# per docs
|
256
256
|
# http://docs.nvidia.com/cuda/parallel-thread-execution/index.html#comparison-and-selection-instructions-selp
|
257
|
-
supported_types = (
|
258
|
-
|
259
|
-
|
260
|
-
|
257
|
+
supported_types = (
|
258
|
+
types.float64,
|
259
|
+
types.float32,
|
260
|
+
types.int16,
|
261
|
+
types.uint16,
|
262
|
+
types.int32,
|
263
|
+
types.uint32,
|
264
|
+
types.int64,
|
265
|
+
types.uint64,
|
266
|
+
)
|
261
267
|
|
262
268
|
if a != b or a not in supported_types:
|
263
269
|
return
|
@@ -298,7 +304,6 @@ def _genfp16_binary(l_key):
|
|
298
304
|
|
299
305
|
@register_global(float)
|
300
306
|
class Float(AbstractTemplate):
|
301
|
-
|
302
307
|
def generic(self, args, kws):
|
303
308
|
assert not kws
|
304
309
|
|
@@ -313,11 +318,11 @@ def _genfp16_binary_comparison(l_key):
|
|
313
318
|
class Cuda_fp16_cmp(ConcreteTemplate):
|
314
319
|
key = l_key
|
315
320
|
|
316
|
-
cases = [
|
317
|
-
|
318
|
-
]
|
321
|
+
cases = [signature(types.b1, types.float16, types.float16)]
|
322
|
+
|
319
323
|
return Cuda_fp16_cmp
|
320
324
|
|
325
|
+
|
321
326
|
# If multiple ConcreteTemplates provide typing for a single function, then
|
322
327
|
# function resolution will pick the first compatible typing it finds even if it
|
323
328
|
# involves inserting a cast that would be considered undesirable (in this
|
@@ -340,9 +345,10 @@ def _fp16_binary_operator(l_key, retty):
|
|
340
345
|
def generic(self, args, kws):
|
341
346
|
assert not kws
|
342
347
|
|
343
|
-
if len(args) == 2 and
|
344
|
-
|
345
|
-
|
348
|
+
if len(args) == 2 and (
|
349
|
+
args[0] == types.float16 or args[1] == types.float16
|
350
|
+
):
|
351
|
+
if args[0] == types.float16:
|
346
352
|
convertible = self.context.can_convert(args[1], args[0])
|
347
353
|
else:
|
348
354
|
convertible = self.context.can_convert(args[0], args[1])
|
@@ -355,9 +361,11 @@ def _fp16_binary_operator(l_key, retty):
|
|
355
361
|
# 3. fp16 to int8 (safe conversion) -
|
356
362
|
# - Conversion.safe
|
357
363
|
|
358
|
-
if (
|
359
|
-
|
360
|
-
|
364
|
+
if (
|
365
|
+
(convertible == Conversion.exact)
|
366
|
+
or (convertible == Conversion.promote)
|
367
|
+
or (convertible == Conversion.safe)
|
368
|
+
):
|
361
369
|
return signature(retty, types.float16, types.float16)
|
362
370
|
|
363
371
|
return Cuda_fp16_operator
|
@@ -404,38 +412,42 @@ _genfp16_binary_operator(operator.itruediv)
|
|
404
412
|
|
405
413
|
def _resolve_wrapped_unary(fname):
|
406
414
|
link = tuple()
|
407
|
-
decl = declare_device_function_template(
|
408
|
-
|
409
|
-
|
410
|
-
link)
|
415
|
+
decl = declare_device_function_template(
|
416
|
+
f"__numba_wrapper_{fname}", types.float16, (types.float16,), link
|
417
|
+
)
|
411
418
|
return types.Function(decl)
|
412
419
|
|
413
420
|
|
414
421
|
def _resolve_wrapped_binary(fname):
|
415
422
|
link = tuple()
|
416
|
-
decl = declare_device_function_template(
|
417
|
-
|
418
|
-
|
419
|
-
|
423
|
+
decl = declare_device_function_template(
|
424
|
+
f"__numba_wrapper_{fname}",
|
425
|
+
types.float16,
|
426
|
+
(
|
427
|
+
types.float16,
|
428
|
+
types.float16,
|
429
|
+
),
|
430
|
+
link,
|
431
|
+
)
|
420
432
|
return types.Function(decl)
|
421
433
|
|
422
434
|
|
423
|
-
hsin_device = _resolve_wrapped_unary(
|
424
|
-
hcos_device = _resolve_wrapped_unary(
|
425
|
-
hlog_device = _resolve_wrapped_unary(
|
426
|
-
hlog10_device = _resolve_wrapped_unary(
|
427
|
-
hlog2_device = _resolve_wrapped_unary(
|
428
|
-
hexp_device = _resolve_wrapped_unary(
|
429
|
-
hexp10_device = _resolve_wrapped_unary(
|
430
|
-
hexp2_device = _resolve_wrapped_unary(
|
431
|
-
hsqrt_device = _resolve_wrapped_unary(
|
432
|
-
hrsqrt_device = _resolve_wrapped_unary(
|
433
|
-
hfloor_device = _resolve_wrapped_unary(
|
434
|
-
hceil_device = _resolve_wrapped_unary(
|
435
|
-
hrcp_device = _resolve_wrapped_unary(
|
436
|
-
hrint_device = _resolve_wrapped_unary(
|
437
|
-
htrunc_device = _resolve_wrapped_unary(
|
438
|
-
hdiv_device = _resolve_wrapped_binary(
|
435
|
+
hsin_device = _resolve_wrapped_unary("hsin")
|
436
|
+
hcos_device = _resolve_wrapped_unary("hcos")
|
437
|
+
hlog_device = _resolve_wrapped_unary("hlog")
|
438
|
+
hlog10_device = _resolve_wrapped_unary("hlog10")
|
439
|
+
hlog2_device = _resolve_wrapped_unary("hlog2")
|
440
|
+
hexp_device = _resolve_wrapped_unary("hexp")
|
441
|
+
hexp10_device = _resolve_wrapped_unary("hexp10")
|
442
|
+
hexp2_device = _resolve_wrapped_unary("hexp2")
|
443
|
+
hsqrt_device = _resolve_wrapped_unary("hsqrt")
|
444
|
+
hrsqrt_device = _resolve_wrapped_unary("hrsqrt")
|
445
|
+
hfloor_device = _resolve_wrapped_unary("hfloor")
|
446
|
+
hceil_device = _resolve_wrapped_unary("hceil")
|
447
|
+
hrcp_device = _resolve_wrapped_unary("hrcp")
|
448
|
+
hrint_device = _resolve_wrapped_unary("hrint")
|
449
|
+
htrunc_device = _resolve_wrapped_unary("htrunc")
|
450
|
+
hdiv_device = _resolve_wrapped_binary("hdiv")
|
439
451
|
|
440
452
|
|
441
453
|
# generate atomic operations
|
@@ -455,15 +467,20 @@ def _gen(l_key, supported_types):
|
|
455
467
|
return signature(ary.dtype, ary, types.intp, ary.dtype)
|
456
468
|
elif ary.ndim > 1:
|
457
469
|
return signature(ary.dtype, ary, idx, ary.dtype)
|
470
|
+
|
458
471
|
return Cuda_atomic
|
459
472
|
|
460
473
|
|
461
|
-
all_numba_types = (
|
462
|
-
|
463
|
-
|
474
|
+
all_numba_types = (
|
475
|
+
types.float64,
|
476
|
+
types.float32,
|
477
|
+
types.int32,
|
478
|
+
types.uint32,
|
479
|
+
types.int64,
|
480
|
+
types.uint64,
|
481
|
+
)
|
464
482
|
|
465
|
-
integer_numba_types = (types.int32, types.uint32,
|
466
|
-
types.int64, types.uint64)
|
483
|
+
integer_numba_types = (types.int32, types.uint32, types.int64, types.uint64)
|
467
484
|
|
468
485
|
unsigned_int_numba_types = (types.uint32, types.uint64)
|
469
486
|
|
@@ -759,9 +776,6 @@ class CudaModuleTemplate(AttributeTemplate):
|
|
759
776
|
def resolve_syncwarp(self, mod):
|
760
777
|
return types.Function(Cuda_syncwarp)
|
761
778
|
|
762
|
-
def resolve_shfl_sync_intrinsic(self, mod):
|
763
|
-
return types.Function(Cuda_shfl_sync_intrinsic)
|
764
|
-
|
765
779
|
def resolve_vote_sync_intrinsic(self, mod):
|
766
780
|
return types.Function(Cuda_vote_sync_intrinsic)
|
767
781
|
|
@@ -811,5 +825,5 @@ for func in bit_twiddling_functions:
|
|
811
825
|
register_numpy_ufunc(func, register_global)
|
812
826
|
|
813
827
|
for func in math_operations:
|
814
|
-
if func in (
|
828
|
+
if func in ("log", "log2", "log10"):
|
815
829
|
register_numpy_ufunc(func, register_global)
|