numba-cuda 0.8.1__py3-none-any.whl → 0.9.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- _numba_cuda_redirector.py +17 -13
- numba_cuda/VERSION +1 -1
- numba_cuda/_version.py +4 -1
- numba_cuda/numba/cuda/__init__.py +6 -2
- numba_cuda/numba/cuda/api.py +129 -86
- numba_cuda/numba/cuda/api_util.py +3 -3
- numba_cuda/numba/cuda/args.py +12 -16
- numba_cuda/numba/cuda/cg.py +6 -6
- numba_cuda/numba/cuda/codegen.py +74 -43
- numba_cuda/numba/cuda/compiler.py +232 -113
- numba_cuda/numba/cuda/cpp_function_wrappers.cu +1 -2
- numba_cuda/numba/cuda/cuda_fp16.h +661 -661
- numba_cuda/numba/cuda/cuda_fp16.hpp +3 -3
- numba_cuda/numba/cuda/cuda_paths.py +291 -99
- numba_cuda/numba/cuda/cudadecl.py +125 -69
- numba_cuda/numba/cuda/cudadrv/__init__.py +3 -1
- numba_cuda/numba/cuda/cudadrv/devicearray.py +185 -135
- numba_cuda/numba/cuda/cudadrv/devices.py +16 -11
- numba_cuda/numba/cuda/cudadrv/driver.py +460 -297
- numba_cuda/numba/cuda/cudadrv/drvapi.py +241 -207
- numba_cuda/numba/cuda/cudadrv/dummyarray.py +66 -54
- numba_cuda/numba/cuda/cudadrv/enums.py +1 -1
- numba_cuda/numba/cuda/cudadrv/error.py +6 -2
- numba_cuda/numba/cuda/cudadrv/libs.py +67 -63
- numba_cuda/numba/cuda/cudadrv/linkable_code.py +16 -1
- numba_cuda/numba/cuda/cudadrv/mappings.py +16 -14
- numba_cuda/numba/cuda/cudadrv/nvrtc.py +138 -29
- numba_cuda/numba/cuda/cudadrv/nvvm.py +296 -161
- numba_cuda/numba/cuda/cudadrv/rtapi.py +1 -1
- numba_cuda/numba/cuda/cudadrv/runtime.py +20 -8
- numba_cuda/numba/cuda/cudaimpl.py +317 -233
- numba_cuda/numba/cuda/cudamath.py +1 -1
- numba_cuda/numba/cuda/debuginfo.py +8 -6
- numba_cuda/numba/cuda/decorators.py +75 -45
- numba_cuda/numba/cuda/descriptor.py +1 -1
- numba_cuda/numba/cuda/device_init.py +69 -18
- numba_cuda/numba/cuda/deviceufunc.py +143 -98
- numba_cuda/numba/cuda/dispatcher.py +300 -213
- numba_cuda/numba/cuda/errors.py +13 -10
- numba_cuda/numba/cuda/extending.py +1 -1
- numba_cuda/numba/cuda/initialize.py +5 -3
- numba_cuda/numba/cuda/intrinsic_wrapper.py +3 -3
- numba_cuda/numba/cuda/intrinsics.py +31 -27
- numba_cuda/numba/cuda/kernels/reduction.py +13 -13
- numba_cuda/numba/cuda/kernels/transpose.py +3 -6
- numba_cuda/numba/cuda/libdevice.py +317 -317
- numba_cuda/numba/cuda/libdeviceimpl.py +3 -2
- numba_cuda/numba/cuda/locks.py +16 -0
- numba_cuda/numba/cuda/mathimpl.py +62 -57
- numba_cuda/numba/cuda/models.py +1 -5
- numba_cuda/numba/cuda/nvvmutils.py +103 -88
- numba_cuda/numba/cuda/printimpl.py +9 -5
- numba_cuda/numba/cuda/random.py +46 -36
- numba_cuda/numba/cuda/reshape_funcs.cu +1 -1
- numba_cuda/numba/cuda/runtime/__init__.py +1 -1
- numba_cuda/numba/cuda/runtime/memsys.cu +1 -1
- numba_cuda/numba/cuda/runtime/memsys.cuh +1 -1
- numba_cuda/numba/cuda/runtime/nrt.cu +3 -3
- numba_cuda/numba/cuda/runtime/nrt.py +48 -43
- numba_cuda/numba/cuda/simulator/__init__.py +22 -12
- numba_cuda/numba/cuda/simulator/api.py +38 -22
- numba_cuda/numba/cuda/simulator/compiler.py +2 -2
- numba_cuda/numba/cuda/simulator/cudadrv/__init__.py +8 -2
- numba_cuda/numba/cuda/simulator/cudadrv/devicearray.py +63 -55
- numba_cuda/numba/cuda/simulator/cudadrv/devices.py +13 -11
- numba_cuda/numba/cuda/simulator/cudadrv/driver.py +5 -5
- numba_cuda/numba/cuda/simulator/cudadrv/drvapi.py +2 -2
- numba_cuda/numba/cuda/simulator/cudadrv/libs.py +1 -1
- numba_cuda/numba/cuda/simulator/cudadrv/nvvm.py +3 -3
- numba_cuda/numba/cuda/simulator/cudadrv/runtime.py +3 -3
- numba_cuda/numba/cuda/simulator/kernel.py +43 -34
- numba_cuda/numba/cuda/simulator/kernelapi.py +31 -26
- numba_cuda/numba/cuda/simulator/reduction.py +1 -0
- numba_cuda/numba/cuda/simulator/vector_types.py +13 -9
- numba_cuda/numba/cuda/simulator_init.py +2 -4
- numba_cuda/numba/cuda/stubs.py +139 -102
- numba_cuda/numba/cuda/target.py +64 -47
- numba_cuda/numba/cuda/testing.py +24 -19
- numba_cuda/numba/cuda/tests/__init__.py +14 -12
- numba_cuda/numba/cuda/tests/cudadrv/test_array_attr.py +16 -17
- numba_cuda/numba/cuda/tests/cudadrv/test_context_stack.py +7 -7
- numba_cuda/numba/cuda/tests/cudadrv/test_cuda_array_slicing.py +73 -54
- numba_cuda/numba/cuda/tests/cudadrv/test_cuda_auto_context.py +1 -1
- numba_cuda/numba/cuda/tests/cudadrv/test_cuda_devicerecord.py +48 -50
- numba_cuda/numba/cuda/tests/cudadrv/test_cuda_driver.py +47 -29
- numba_cuda/numba/cuda/tests/cudadrv/test_cuda_libraries.py +3 -3
- numba_cuda/numba/cuda/tests/cudadrv/test_cuda_memory.py +19 -19
- numba_cuda/numba/cuda/tests/cudadrv/test_cuda_ndarray.py +108 -103
- numba_cuda/numba/cuda/tests/cudadrv/test_deallocations.py +20 -11
- numba_cuda/numba/cuda/tests/cudadrv/test_detect.py +20 -17
- numba_cuda/numba/cuda/tests/cudadrv/test_emm_plugins.py +8 -6
- numba_cuda/numba/cuda/tests/cudadrv/test_events.py +1 -1
- numba_cuda/numba/cuda/tests/cudadrv/test_host_alloc.py +8 -7
- numba_cuda/numba/cuda/tests/cudadrv/test_init.py +13 -13
- numba_cuda/numba/cuda/tests/cudadrv/test_inline_ptx.py +12 -9
- numba_cuda/numba/cuda/tests/cudadrv/test_linker.py +36 -31
- numba_cuda/numba/cuda/tests/cudadrv/test_managed_alloc.py +8 -7
- numba_cuda/numba/cuda/tests/cudadrv/test_module_callbacks.py +294 -0
- numba_cuda/numba/cuda/tests/cudadrv/test_mvc.py +10 -7
- numba_cuda/numba/cuda/tests/cudadrv/test_nvjitlink.py +24 -15
- numba_cuda/numba/cuda/tests/cudadrv/test_nvvm_driver.py +43 -41
- numba_cuda/numba/cuda/tests/cudadrv/test_pinned.py +4 -5
- numba_cuda/numba/cuda/tests/cudadrv/test_profiler.py +2 -2
- numba_cuda/numba/cuda/tests/cudadrv/test_ptds.py +28 -17
- numba_cuda/numba/cuda/tests/cudadrv/test_reset_device.py +1 -2
- numba_cuda/numba/cuda/tests/cudadrv/test_runtime.py +22 -14
- numba_cuda/numba/cuda/tests/cudadrv/test_select_device.py +1 -1
- numba_cuda/numba/cuda/tests/cudadrv/test_streams.py +4 -3
- numba_cuda/numba/cuda/tests/cudapy/cache_usecases.py +10 -4
- numba_cuda/numba/cuda/tests/cudapy/cache_with_cpu_usecases.py +1 -0
- numba_cuda/numba/cuda/tests/cudapy/extensions_usecases.py +7 -6
- numba_cuda/numba/cuda/tests/cudapy/jitlink.ptx +0 -2
- numba_cuda/numba/cuda/tests/cudapy/recursion_usecases.py +1 -0
- numba_cuda/numba/cuda/tests/cudapy/test_alignment.py +6 -5
- numba_cuda/numba/cuda/tests/cudapy/test_array.py +52 -42
- numba_cuda/numba/cuda/tests/cudapy/test_array_args.py +5 -6
- numba_cuda/numba/cuda/tests/cudapy/test_array_methods.py +1 -1
- numba_cuda/numba/cuda/tests/cudapy/test_atomics.py +501 -304
- numba_cuda/numba/cuda/tests/cudapy/test_blackscholes.py +57 -21
- numba_cuda/numba/cuda/tests/cudapy/test_boolean.py +3 -3
- numba_cuda/numba/cuda/tests/cudapy/test_caching.py +50 -37
- numba_cuda/numba/cuda/tests/cudapy/test_casting.py +29 -24
- numba_cuda/numba/cuda/tests/cudapy/test_cffi.py +11 -6
- numba_cuda/numba/cuda/tests/cudapy/test_compiler.py +84 -50
- numba_cuda/numba/cuda/tests/cudapy/test_complex.py +144 -73
- numba_cuda/numba/cuda/tests/cudapy/test_complex_kernel.py +2 -2
- numba_cuda/numba/cuda/tests/cudapy/test_const_string.py +37 -27
- numba_cuda/numba/cuda/tests/cudapy/test_constmem.py +43 -45
- numba_cuda/numba/cuda/tests/cudapy/test_cooperative_groups.py +21 -14
- numba_cuda/numba/cuda/tests/cudapy/test_cuda_array_interface.py +60 -55
- numba_cuda/numba/cuda/tests/cudapy/test_cuda_jit_no_types.py +3 -2
- numba_cuda/numba/cuda/tests/cudapy/test_datetime.py +26 -22
- numba_cuda/numba/cuda/tests/cudapy/test_debug.py +29 -27
- numba_cuda/numba/cuda/tests/cudapy/test_debuginfo.py +31 -28
- numba_cuda/numba/cuda/tests/cudapy/test_device_func.py +52 -45
- numba_cuda/numba/cuda/tests/cudapy/test_dispatcher.py +55 -43
- numba_cuda/numba/cuda/tests/cudapy/test_enums.py +6 -7
- numba_cuda/numba/cuda/tests/cudapy/test_errors.py +30 -15
- numba_cuda/numba/cuda/tests/cudapy/test_exception.py +11 -12
- numba_cuda/numba/cuda/tests/cudapy/test_extending.py +19 -12
- numba_cuda/numba/cuda/tests/cudapy/test_fastmath.py +77 -66
- numba_cuda/numba/cuda/tests/cudapy/test_forall.py +5 -3
- numba_cuda/numba/cuda/tests/cudapy/test_freevar.py +5 -3
- numba_cuda/numba/cuda/tests/cudapy/test_frexp_ldexp.py +1 -1
- numba_cuda/numba/cuda/tests/cudapy/test_globals.py +3 -5
- numba_cuda/numba/cuda/tests/cudapy/test_gufunc.py +144 -126
- numba_cuda/numba/cuda/tests/cudapy/test_gufunc_scalar.py +23 -18
- numba_cuda/numba/cuda/tests/cudapy/test_gufunc_scheduling.py +16 -22
- numba_cuda/numba/cuda/tests/cudapy/test_idiv.py +1 -3
- numba_cuda/numba/cuda/tests/cudapy/test_inspect.py +29 -20
- numba_cuda/numba/cuda/tests/cudapy/test_intrinsics.py +147 -99
- numba_cuda/numba/cuda/tests/cudapy/test_ipc.py +50 -36
- numba_cuda/numba/cuda/tests/cudapy/test_iterators.py +1 -2
- numba_cuda/numba/cuda/tests/cudapy/test_lang.py +4 -4
- numba_cuda/numba/cuda/tests/cudapy/test_laplace.py +6 -6
- numba_cuda/numba/cuda/tests/cudapy/test_libdevice.py +24 -20
- numba_cuda/numba/cuda/tests/cudapy/test_lineinfo.py +36 -31
- numba_cuda/numba/cuda/tests/cudapy/test_localmem.py +13 -13
- numba_cuda/numba/cuda/tests/cudapy/test_mandel.py +13 -6
- numba_cuda/numba/cuda/tests/cudapy/test_math.py +83 -66
- numba_cuda/numba/cuda/tests/cudapy/test_matmul.py +1 -3
- numba_cuda/numba/cuda/tests/cudapy/test_minmax.py +19 -58
- numba_cuda/numba/cuda/tests/cudapy/test_montecarlo.py +4 -4
- numba_cuda/numba/cuda/tests/cudapy/test_multigpu.py +9 -7
- numba_cuda/numba/cuda/tests/cudapy/test_multiprocessing.py +9 -8
- numba_cuda/numba/cuda/tests/cudapy/test_multithreads.py +12 -10
- numba_cuda/numba/cuda/tests/cudapy/test_nondet.py +1 -1
- numba_cuda/numba/cuda/tests/cudapy/test_operator.py +180 -96
- numba_cuda/numba/cuda/tests/cudapy/test_optimization.py +5 -5
- numba_cuda/numba/cuda/tests/cudapy/test_overload.py +37 -18
- numba_cuda/numba/cuda/tests/cudapy/test_powi.py +7 -7
- numba_cuda/numba/cuda/tests/cudapy/test_print.py +9 -7
- numba_cuda/numba/cuda/tests/cudapy/test_py2_div_issue.py +1 -1
- numba_cuda/numba/cuda/tests/cudapy/test_random.py +15 -10
- numba_cuda/numba/cuda/tests/cudapy/test_record_dtype.py +88 -87
- numba_cuda/numba/cuda/tests/cudapy/test_recursion.py +12 -10
- numba_cuda/numba/cuda/tests/cudapy/test_reduction.py +26 -11
- numba_cuda/numba/cuda/tests/cudapy/test_retrieve_autoconverted_arrays.py +7 -10
- numba_cuda/numba/cuda/tests/cudapy/test_serialize.py +4 -6
- numba_cuda/numba/cuda/tests/cudapy/test_slicing.py +1 -1
- numba_cuda/numba/cuda/tests/cudapy/test_sm.py +10 -9
- numba_cuda/numba/cuda/tests/cudapy/test_sm_creation.py +62 -43
- numba_cuda/numba/cuda/tests/cudapy/test_stream_api.py +7 -3
- numba_cuda/numba/cuda/tests/cudapy/test_sync.py +7 -5
- numba_cuda/numba/cuda/tests/cudapy/test_transpose.py +18 -11
- numba_cuda/numba/cuda/tests/cudapy/test_ufuncs.py +111 -88
- numba_cuda/numba/cuda/tests/cudapy/test_userexc.py +2 -3
- numba_cuda/numba/cuda/tests/cudapy/test_vector_type.py +305 -130
- numba_cuda/numba/cuda/tests/cudapy/test_vectorize.py +33 -36
- numba_cuda/numba/cuda/tests/cudapy/test_vectorize_complex.py +5 -5
- numba_cuda/numba/cuda/tests/cudapy/test_vectorize_decor.py +16 -12
- numba_cuda/numba/cuda/tests/cudapy/test_vectorize_device.py +7 -7
- numba_cuda/numba/cuda/tests/cudapy/test_vectorize_scalar_arg.py +6 -7
- numba_cuda/numba/cuda/tests/cudapy/test_warning.py +31 -29
- numba_cuda/numba/cuda/tests/cudapy/test_warp_ops.py +31 -25
- numba_cuda/numba/cuda/tests/cudasim/test_cudasim_issues.py +19 -13
- numba_cuda/numba/cuda/tests/data/jitlink.cu +1 -1
- numba_cuda/numba/cuda/tests/data/jitlink.ptx +0 -2
- numba_cuda/numba/cuda/tests/doc_examples/test_cg.py +15 -8
- numba_cuda/numba/cuda/tests/doc_examples/test_cpu_gpu_compat.py +4 -7
- numba_cuda/numba/cuda/tests/doc_examples/test_ffi.py +14 -9
- numba_cuda/numba/cuda/tests/doc_examples/test_laplace.py +22 -18
- numba_cuda/numba/cuda/tests/doc_examples/test_matmul.py +7 -4
- numba_cuda/numba/cuda/tests/doc_examples/test_montecarlo.py +2 -0
- numba_cuda/numba/cuda/tests/doc_examples/test_random.py +8 -4
- numba_cuda/numba/cuda/tests/doc_examples/test_reduction.py +2 -1
- numba_cuda/numba/cuda/tests/doc_examples/test_sessionize.py +94 -19
- numba_cuda/numba/cuda/tests/doc_examples/test_vecadd.py +2 -2
- numba_cuda/numba/cuda/tests/nocuda/test_dummyarray.py +91 -62
- numba_cuda/numba/cuda/tests/nocuda/test_function_resolution.py +14 -5
- numba_cuda/numba/cuda/tests/nocuda/test_import.py +25 -25
- numba_cuda/numba/cuda/tests/nocuda/test_library_lookup.py +40 -40
- numba_cuda/numba/cuda/tests/nocuda/test_nvvm.py +12 -10
- numba_cuda/numba/cuda/tests/nrt/test_nrt.py +16 -20
- numba_cuda/numba/cuda/tests/nrt/test_nrt_refct.py +12 -10
- numba_cuda/numba/cuda/tests/test_binary_generation/generate_raw_ltoir.py +2 -2
- numba_cuda/numba/cuda/types.py +5 -2
- numba_cuda/numba/cuda/ufuncs.py +382 -362
- numba_cuda/numba/cuda/utils.py +2 -2
- numba_cuda/numba/cuda/vector_types.py +2 -2
- numba_cuda/numba/cuda/vectorizers.py +37 -32
- {numba_cuda-0.8.1.dist-info → numba_cuda-0.9.0.dist-info}/METADATA +1 -1
- numba_cuda-0.9.0.dist-info/RECORD +253 -0
- {numba_cuda-0.8.1.dist-info → numba_cuda-0.9.0.dist-info}/WHEEL +1 -1
- numba_cuda-0.8.1.dist-info/RECORD +0 -251
- {numba_cuda-0.8.1.dist-info → numba_cuda-0.9.0.dist-info}/licenses/LICENSE +0 -0
- {numba_cuda-0.8.1.dist-info → numba_cuda-0.9.0.dist-info}/top_level.txt +0 -0
@@ -1,15 +1,23 @@
|
|
1
1
|
import operator
|
2
2
|
from numba.core import types
|
3
|
-
from numba.core.typing.npydecl import (
|
4
|
-
|
5
|
-
|
6
|
-
|
7
|
-
|
8
|
-
|
9
|
-
|
10
|
-
|
11
|
-
|
12
|
-
|
3
|
+
from numba.core.typing.npydecl import (
|
4
|
+
parse_dtype,
|
5
|
+
parse_shape,
|
6
|
+
register_number_classes,
|
7
|
+
register_numpy_ufunc,
|
8
|
+
trigonometric_functions,
|
9
|
+
comparison_functions,
|
10
|
+
math_operations,
|
11
|
+
bit_twiddling_functions,
|
12
|
+
)
|
13
|
+
from numba.core.typing.templates import (
|
14
|
+
AttributeTemplate,
|
15
|
+
ConcreteTemplate,
|
16
|
+
AbstractTemplate,
|
17
|
+
CallableTemplate,
|
18
|
+
signature,
|
19
|
+
Registry,
|
20
|
+
)
|
13
21
|
from numba.cuda.types import dim3
|
14
22
|
from numba.core.typeconv import Conversion
|
15
23
|
from numba import cuda
|
@@ -26,15 +34,15 @@ register_number_classes(register_global)
|
|
26
34
|
class Cuda_array_decl(CallableTemplate):
|
27
35
|
def generic(self):
|
28
36
|
def typer(shape, dtype):
|
29
|
-
|
30
37
|
# Only integer literals and tuples of integer literals are valid
|
31
38
|
# shapes
|
32
39
|
if isinstance(shape, types.Integer):
|
33
40
|
if not isinstance(shape, types.IntegerLiteral):
|
34
41
|
return None
|
35
42
|
elif isinstance(shape, (types.Tuple, types.UniTuple)):
|
36
|
-
if any(
|
37
|
-
|
43
|
+
if any(
|
44
|
+
[not isinstance(s, types.IntegerLiteral) for s in shape]
|
45
|
+
):
|
38
46
|
return None
|
39
47
|
else:
|
40
48
|
return None
|
@@ -42,7 +50,7 @@ class Cuda_array_decl(CallableTemplate):
|
|
42
50
|
ndim = parse_shape(shape)
|
43
51
|
nb_dtype = parse_dtype(dtype)
|
44
52
|
if nb_dtype is not None and ndim is not None:
|
45
|
-
return types.Array(dtype=nb_dtype, ndim=ndim, layout=
|
53
|
+
return types.Array(dtype=nb_dtype, ndim=ndim, layout="C")
|
46
54
|
|
47
55
|
return typer
|
48
56
|
|
@@ -64,6 +72,7 @@ class Cuda_const_array_like(CallableTemplate):
|
|
64
72
|
def generic(self):
|
65
73
|
def typer(ndarray):
|
66
74
|
return ndarray
|
75
|
+
|
67
76
|
return typer
|
68
77
|
|
69
78
|
|
@@ -95,22 +104,49 @@ class Cuda_syncwarp(ConcreteTemplate):
|
|
95
104
|
class Cuda_shfl_sync_intrinsic(ConcreteTemplate):
|
96
105
|
key = cuda.shfl_sync_intrinsic
|
97
106
|
cases = [
|
98
|
-
signature(
|
99
|
-
|
100
|
-
|
101
|
-
|
102
|
-
|
103
|
-
|
104
|
-
|
105
|
-
|
107
|
+
signature(
|
108
|
+
types.Tuple((types.i4, types.b1)),
|
109
|
+
types.i4,
|
110
|
+
types.i4,
|
111
|
+
types.i4,
|
112
|
+
types.i4,
|
113
|
+
types.i4,
|
114
|
+
),
|
115
|
+
signature(
|
116
|
+
types.Tuple((types.i8, types.b1)),
|
117
|
+
types.i4,
|
118
|
+
types.i4,
|
119
|
+
types.i8,
|
120
|
+
types.i4,
|
121
|
+
types.i4,
|
122
|
+
),
|
123
|
+
signature(
|
124
|
+
types.Tuple((types.f4, types.b1)),
|
125
|
+
types.i4,
|
126
|
+
types.i4,
|
127
|
+
types.f4,
|
128
|
+
types.i4,
|
129
|
+
types.i4,
|
130
|
+
),
|
131
|
+
signature(
|
132
|
+
types.Tuple((types.f8, types.b1)),
|
133
|
+
types.i4,
|
134
|
+
types.i4,
|
135
|
+
types.f8,
|
136
|
+
types.i4,
|
137
|
+
types.i4,
|
138
|
+
),
|
106
139
|
]
|
107
140
|
|
108
141
|
|
109
142
|
@register
|
110
143
|
class Cuda_vote_sync_intrinsic(ConcreteTemplate):
|
111
144
|
key = cuda.vote_sync_intrinsic
|
112
|
-
cases = [
|
113
|
-
|
145
|
+
cases = [
|
146
|
+
signature(
|
147
|
+
types.Tuple((types.i4, types.b1)), types.i4, types.i4, types.b1
|
148
|
+
)
|
149
|
+
]
|
114
150
|
|
115
151
|
|
116
152
|
@register
|
@@ -153,6 +189,7 @@ class Cuda_popc(ConcreteTemplate):
|
|
153
189
|
Supported types from `llvm.popc`
|
154
190
|
[here](http://docs.nvidia.com/cuda/nvvm-ir-spec/index.html#bit-manipulations-intrinics)
|
155
191
|
"""
|
192
|
+
|
156
193
|
key = cuda.popc
|
157
194
|
cases = [
|
158
195
|
signature(types.int8, types.int8),
|
@@ -172,6 +209,7 @@ class Cuda_fma(ConcreteTemplate):
|
|
172
209
|
Supported types from `llvm.fma`
|
173
210
|
[here](https://docs.nvidia.com/cuda/nvvm-ir-spec/index.html#standard-c-library-intrinics)
|
174
211
|
"""
|
212
|
+
|
175
213
|
key = cuda.fma
|
176
214
|
cases = [
|
177
215
|
signature(types.float32, types.float32, types.float32, types.float32),
|
@@ -189,7 +227,6 @@ class Cuda_hfma(ConcreteTemplate):
|
|
189
227
|
|
190
228
|
@register
|
191
229
|
class Cuda_cbrt(ConcreteTemplate):
|
192
|
-
|
193
230
|
key = cuda.cbrt
|
194
231
|
cases = [
|
195
232
|
signature(types.float32, types.float32),
|
@@ -212,6 +249,7 @@ class Cuda_clz(ConcreteTemplate):
|
|
212
249
|
Supported types from `llvm.ctlz`
|
213
250
|
[here](http://docs.nvidia.com/cuda/nvvm-ir-spec/index.html#bit-manipulations-intrinics)
|
214
251
|
"""
|
252
|
+
|
215
253
|
key = cuda.clz
|
216
254
|
cases = [
|
217
255
|
signature(types.int8, types.int8),
|
@@ -231,6 +269,7 @@ class Cuda_ffs(ConcreteTemplate):
|
|
231
269
|
Supported types from `llvm.cttz`
|
232
270
|
[here](http://docs.nvidia.com/cuda/nvvm-ir-spec/index.html#bit-manipulations-intrinics)
|
233
271
|
"""
|
272
|
+
|
234
273
|
key = cuda.ffs
|
235
274
|
cases = [
|
236
275
|
signature(types.uint32, types.int8),
|
@@ -254,10 +293,16 @@ class Cuda_selp(AbstractTemplate):
|
|
254
293
|
|
255
294
|
# per docs
|
256
295
|
# http://docs.nvidia.com/cuda/parallel-thread-execution/index.html#comparison-and-selection-instructions-selp
|
257
|
-
supported_types = (
|
258
|
-
|
259
|
-
|
260
|
-
|
296
|
+
supported_types = (
|
297
|
+
types.float64,
|
298
|
+
types.float32,
|
299
|
+
types.int16,
|
300
|
+
types.uint16,
|
301
|
+
types.int32,
|
302
|
+
types.uint32,
|
303
|
+
types.int64,
|
304
|
+
types.uint64,
|
305
|
+
)
|
261
306
|
|
262
307
|
if a != b or a not in supported_types:
|
263
308
|
return
|
@@ -298,7 +343,6 @@ def _genfp16_binary(l_key):
|
|
298
343
|
|
299
344
|
@register_global(float)
|
300
345
|
class Float(AbstractTemplate):
|
301
|
-
|
302
346
|
def generic(self, args, kws):
|
303
347
|
assert not kws
|
304
348
|
|
@@ -313,11 +357,11 @@ def _genfp16_binary_comparison(l_key):
|
|
313
357
|
class Cuda_fp16_cmp(ConcreteTemplate):
|
314
358
|
key = l_key
|
315
359
|
|
316
|
-
cases = [
|
317
|
-
|
318
|
-
]
|
360
|
+
cases = [signature(types.b1, types.float16, types.float16)]
|
361
|
+
|
319
362
|
return Cuda_fp16_cmp
|
320
363
|
|
364
|
+
|
321
365
|
# If multiple ConcreteTemplates provide typing for a single function, then
|
322
366
|
# function resolution will pick the first compatible typing it finds even if it
|
323
367
|
# involves inserting a cast that would be considered undesirable (in this
|
@@ -340,9 +384,10 @@ def _fp16_binary_operator(l_key, retty):
|
|
340
384
|
def generic(self, args, kws):
|
341
385
|
assert not kws
|
342
386
|
|
343
|
-
if len(args) == 2 and
|
344
|
-
|
345
|
-
|
387
|
+
if len(args) == 2 and (
|
388
|
+
args[0] == types.float16 or args[1] == types.float16
|
389
|
+
):
|
390
|
+
if args[0] == types.float16:
|
346
391
|
convertible = self.context.can_convert(args[1], args[0])
|
347
392
|
else:
|
348
393
|
convertible = self.context.can_convert(args[0], args[1])
|
@@ -355,9 +400,11 @@ def _fp16_binary_operator(l_key, retty):
|
|
355
400
|
# 3. fp16 to int8 (safe conversion) -
|
356
401
|
# - Conversion.safe
|
357
402
|
|
358
|
-
if (
|
359
|
-
|
360
|
-
|
403
|
+
if (
|
404
|
+
(convertible == Conversion.exact)
|
405
|
+
or (convertible == Conversion.promote)
|
406
|
+
or (convertible == Conversion.safe)
|
407
|
+
):
|
361
408
|
return signature(retty, types.float16, types.float16)
|
362
409
|
|
363
410
|
return Cuda_fp16_operator
|
@@ -404,38 +451,42 @@ _genfp16_binary_operator(operator.itruediv)
|
|
404
451
|
|
405
452
|
def _resolve_wrapped_unary(fname):
|
406
453
|
link = tuple()
|
407
|
-
decl = declare_device_function_template(
|
408
|
-
|
409
|
-
|
410
|
-
link)
|
454
|
+
decl = declare_device_function_template(
|
455
|
+
f"__numba_wrapper_{fname}", types.float16, (types.float16,), link
|
456
|
+
)
|
411
457
|
return types.Function(decl)
|
412
458
|
|
413
459
|
|
414
460
|
def _resolve_wrapped_binary(fname):
|
415
461
|
link = tuple()
|
416
|
-
decl = declare_device_function_template(
|
417
|
-
|
418
|
-
|
419
|
-
|
462
|
+
decl = declare_device_function_template(
|
463
|
+
f"__numba_wrapper_{fname}",
|
464
|
+
types.float16,
|
465
|
+
(
|
466
|
+
types.float16,
|
467
|
+
types.float16,
|
468
|
+
),
|
469
|
+
link,
|
470
|
+
)
|
420
471
|
return types.Function(decl)
|
421
472
|
|
422
473
|
|
423
|
-
hsin_device = _resolve_wrapped_unary(
|
424
|
-
hcos_device = _resolve_wrapped_unary(
|
425
|
-
hlog_device = _resolve_wrapped_unary(
|
426
|
-
hlog10_device = _resolve_wrapped_unary(
|
427
|
-
hlog2_device = _resolve_wrapped_unary(
|
428
|
-
hexp_device = _resolve_wrapped_unary(
|
429
|
-
hexp10_device = _resolve_wrapped_unary(
|
430
|
-
hexp2_device = _resolve_wrapped_unary(
|
431
|
-
hsqrt_device = _resolve_wrapped_unary(
|
432
|
-
hrsqrt_device = _resolve_wrapped_unary(
|
433
|
-
hfloor_device = _resolve_wrapped_unary(
|
434
|
-
hceil_device = _resolve_wrapped_unary(
|
435
|
-
hrcp_device = _resolve_wrapped_unary(
|
436
|
-
hrint_device = _resolve_wrapped_unary(
|
437
|
-
htrunc_device = _resolve_wrapped_unary(
|
438
|
-
hdiv_device = _resolve_wrapped_binary(
|
474
|
+
hsin_device = _resolve_wrapped_unary("hsin")
|
475
|
+
hcos_device = _resolve_wrapped_unary("hcos")
|
476
|
+
hlog_device = _resolve_wrapped_unary("hlog")
|
477
|
+
hlog10_device = _resolve_wrapped_unary("hlog10")
|
478
|
+
hlog2_device = _resolve_wrapped_unary("hlog2")
|
479
|
+
hexp_device = _resolve_wrapped_unary("hexp")
|
480
|
+
hexp10_device = _resolve_wrapped_unary("hexp10")
|
481
|
+
hexp2_device = _resolve_wrapped_unary("hexp2")
|
482
|
+
hsqrt_device = _resolve_wrapped_unary("hsqrt")
|
483
|
+
hrsqrt_device = _resolve_wrapped_unary("hrsqrt")
|
484
|
+
hfloor_device = _resolve_wrapped_unary("hfloor")
|
485
|
+
hceil_device = _resolve_wrapped_unary("hceil")
|
486
|
+
hrcp_device = _resolve_wrapped_unary("hrcp")
|
487
|
+
hrint_device = _resolve_wrapped_unary("hrint")
|
488
|
+
htrunc_device = _resolve_wrapped_unary("htrunc")
|
489
|
+
hdiv_device = _resolve_wrapped_binary("hdiv")
|
439
490
|
|
440
491
|
|
441
492
|
# generate atomic operations
|
@@ -455,15 +506,20 @@ def _gen(l_key, supported_types):
|
|
455
506
|
return signature(ary.dtype, ary, types.intp, ary.dtype)
|
456
507
|
elif ary.ndim > 1:
|
457
508
|
return signature(ary.dtype, ary, idx, ary.dtype)
|
509
|
+
|
458
510
|
return Cuda_atomic
|
459
511
|
|
460
512
|
|
461
|
-
all_numba_types = (
|
462
|
-
|
463
|
-
|
513
|
+
all_numba_types = (
|
514
|
+
types.float64,
|
515
|
+
types.float32,
|
516
|
+
types.int32,
|
517
|
+
types.uint32,
|
518
|
+
types.int64,
|
519
|
+
types.uint64,
|
520
|
+
)
|
464
521
|
|
465
|
-
integer_numba_types = (types.int32, types.uint32,
|
466
|
-
types.int64, types.uint64)
|
522
|
+
integer_numba_types = (types.int32, types.uint32, types.int64, types.uint64)
|
467
523
|
|
468
524
|
unsigned_int_numba_types = (types.uint32, types.uint64)
|
469
525
|
|
@@ -811,5 +867,5 @@ for func in bit_twiddling_functions:
|
|
811
867
|
register_numpy_ufunc(func, register_global)
|
812
868
|
|
813
869
|
for func in math_operations:
|
814
|
-
if func in (
|
870
|
+
if func in ("log", "log2", "log10"):
|
815
871
|
register_numpy_ufunc(func, register_global)
|