numba-cuda 0.8.1__py3-none-any.whl → 0.10.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- _numba_cuda_redirector.py +17 -13
- numba_cuda/VERSION +1 -1
- numba_cuda/_version.py +4 -1
- numba_cuda/numba/cuda/__init__.py +6 -2
- numba_cuda/numba/cuda/api.py +129 -86
- numba_cuda/numba/cuda/api_util.py +3 -3
- numba_cuda/numba/cuda/args.py +12 -16
- numba_cuda/numba/cuda/cg.py +6 -6
- numba_cuda/numba/cuda/codegen.py +74 -43
- numba_cuda/numba/cuda/compiler.py +246 -114
- numba_cuda/numba/cuda/cpp_function_wrappers.cu +1 -2
- numba_cuda/numba/cuda/cuda_bf16.py +5155 -0
- numba_cuda/numba/cuda/cuda_paths.py +293 -99
- numba_cuda/numba/cuda/cudadecl.py +93 -79
- numba_cuda/numba/cuda/cudadrv/__init__.py +3 -1
- numba_cuda/numba/cuda/cudadrv/devicearray.py +185 -135
- numba_cuda/numba/cuda/cudadrv/devices.py +16 -11
- numba_cuda/numba/cuda/cudadrv/driver.py +460 -297
- numba_cuda/numba/cuda/cudadrv/drvapi.py +241 -207
- numba_cuda/numba/cuda/cudadrv/dummyarray.py +66 -54
- numba_cuda/numba/cuda/cudadrv/enums.py +1 -1
- numba_cuda/numba/cuda/cudadrv/error.py +6 -2
- numba_cuda/numba/cuda/cudadrv/libs.py +67 -63
- numba_cuda/numba/cuda/cudadrv/linkable_code.py +27 -3
- numba_cuda/numba/cuda/cudadrv/mappings.py +16 -14
- numba_cuda/numba/cuda/cudadrv/nvrtc.py +146 -30
- numba_cuda/numba/cuda/cudadrv/nvvm.py +296 -161
- numba_cuda/numba/cuda/cudadrv/rtapi.py +1 -1
- numba_cuda/numba/cuda/cudadrv/runtime.py +20 -8
- numba_cuda/numba/cuda/cudaimpl.py +296 -275
- numba_cuda/numba/cuda/cudamath.py +1 -1
- numba_cuda/numba/cuda/debuginfo.py +99 -7
- numba_cuda/numba/cuda/decorators.py +87 -45
- numba_cuda/numba/cuda/descriptor.py +1 -1
- numba_cuda/numba/cuda/device_init.py +68 -18
- numba_cuda/numba/cuda/deviceufunc.py +143 -98
- numba_cuda/numba/cuda/dispatcher.py +300 -213
- numba_cuda/numba/cuda/errors.py +13 -10
- numba_cuda/numba/cuda/extending.py +55 -1
- numba_cuda/numba/cuda/include/11/cuda_bf16.h +3749 -0
- numba_cuda/numba/cuda/include/11/cuda_bf16.hpp +2683 -0
- numba_cuda/numba/cuda/{cuda_fp16.h → include/11/cuda_fp16.h} +1090 -927
- numba_cuda/numba/cuda/{cuda_fp16.hpp → include/11/cuda_fp16.hpp} +468 -319
- numba_cuda/numba/cuda/include/12/cuda_bf16.h +5118 -0
- numba_cuda/numba/cuda/include/12/cuda_bf16.hpp +3865 -0
- numba_cuda/numba/cuda/include/12/cuda_fp16.h +5363 -0
- numba_cuda/numba/cuda/include/12/cuda_fp16.hpp +3483 -0
- numba_cuda/numba/cuda/initialize.py +5 -3
- numba_cuda/numba/cuda/intrinsic_wrapper.py +0 -39
- numba_cuda/numba/cuda/intrinsics.py +203 -28
- numba_cuda/numba/cuda/kernels/reduction.py +13 -13
- numba_cuda/numba/cuda/kernels/transpose.py +3 -6
- numba_cuda/numba/cuda/libdevice.py +317 -317
- numba_cuda/numba/cuda/libdeviceimpl.py +3 -2
- numba_cuda/numba/cuda/locks.py +16 -0
- numba_cuda/numba/cuda/lowering.py +43 -0
- numba_cuda/numba/cuda/mathimpl.py +62 -57
- numba_cuda/numba/cuda/models.py +1 -5
- numba_cuda/numba/cuda/nvvmutils.py +103 -88
- numba_cuda/numba/cuda/printimpl.py +9 -5
- numba_cuda/numba/cuda/random.py +46 -36
- numba_cuda/numba/cuda/reshape_funcs.cu +1 -1
- numba_cuda/numba/cuda/runtime/__init__.py +1 -1
- numba_cuda/numba/cuda/runtime/memsys.cu +1 -1
- numba_cuda/numba/cuda/runtime/memsys.cuh +1 -1
- numba_cuda/numba/cuda/runtime/nrt.cu +3 -3
- numba_cuda/numba/cuda/runtime/nrt.py +48 -43
- numba_cuda/numba/cuda/simulator/__init__.py +22 -12
- numba_cuda/numba/cuda/simulator/api.py +38 -22
- numba_cuda/numba/cuda/simulator/compiler.py +2 -2
- numba_cuda/numba/cuda/simulator/cudadrv/__init__.py +8 -2
- numba_cuda/numba/cuda/simulator/cudadrv/devicearray.py +63 -55
- numba_cuda/numba/cuda/simulator/cudadrv/devices.py +13 -11
- numba_cuda/numba/cuda/simulator/cudadrv/driver.py +5 -5
- numba_cuda/numba/cuda/simulator/cudadrv/drvapi.py +2 -2
- numba_cuda/numba/cuda/simulator/cudadrv/libs.py +1 -1
- numba_cuda/numba/cuda/simulator/cudadrv/nvvm.py +3 -3
- numba_cuda/numba/cuda/simulator/cudadrv/runtime.py +3 -3
- numba_cuda/numba/cuda/simulator/kernel.py +43 -34
- numba_cuda/numba/cuda/simulator/kernelapi.py +31 -26
- numba_cuda/numba/cuda/simulator/reduction.py +1 -0
- numba_cuda/numba/cuda/simulator/vector_types.py +13 -9
- numba_cuda/numba/cuda/simulator_init.py +2 -4
- numba_cuda/numba/cuda/stubs.py +134 -108
- numba_cuda/numba/cuda/target.py +92 -47
- numba_cuda/numba/cuda/testing.py +24 -19
- numba_cuda/numba/cuda/tests/__init__.py +14 -12
- numba_cuda/numba/cuda/tests/cudadrv/test_array_attr.py +16 -17
- numba_cuda/numba/cuda/tests/cudadrv/test_context_stack.py +7 -7
- numba_cuda/numba/cuda/tests/cudadrv/test_cuda_array_slicing.py +73 -54
- numba_cuda/numba/cuda/tests/cudadrv/test_cuda_auto_context.py +1 -1
- numba_cuda/numba/cuda/tests/cudadrv/test_cuda_devicerecord.py +48 -50
- numba_cuda/numba/cuda/tests/cudadrv/test_cuda_driver.py +47 -29
- numba_cuda/numba/cuda/tests/cudadrv/test_cuda_libraries.py +3 -3
- numba_cuda/numba/cuda/tests/cudadrv/test_cuda_memory.py +19 -19
- numba_cuda/numba/cuda/tests/cudadrv/test_cuda_ndarray.py +108 -103
- numba_cuda/numba/cuda/tests/cudadrv/test_deallocations.py +20 -11
- numba_cuda/numba/cuda/tests/cudadrv/test_detect.py +20 -17
- numba_cuda/numba/cuda/tests/cudadrv/test_emm_plugins.py +8 -6
- numba_cuda/numba/cuda/tests/cudadrv/test_events.py +1 -1
- numba_cuda/numba/cuda/tests/cudadrv/test_host_alloc.py +8 -7
- numba_cuda/numba/cuda/tests/cudadrv/test_init.py +13 -13
- numba_cuda/numba/cuda/tests/cudadrv/test_inline_ptx.py +12 -9
- numba_cuda/numba/cuda/tests/cudadrv/test_linker.py +36 -31
- numba_cuda/numba/cuda/tests/cudadrv/test_managed_alloc.py +8 -7
- numba_cuda/numba/cuda/tests/cudadrv/test_module_callbacks.py +294 -0
- numba_cuda/numba/cuda/tests/cudadrv/test_mvc.py +10 -7
- numba_cuda/numba/cuda/tests/cudadrv/test_nvjitlink.py +24 -15
- numba_cuda/numba/cuda/tests/cudadrv/test_nvvm_driver.py +43 -41
- numba_cuda/numba/cuda/tests/cudadrv/test_pinned.py +4 -5
- numba_cuda/numba/cuda/tests/cudadrv/test_profiler.py +2 -2
- numba_cuda/numba/cuda/tests/cudadrv/test_ptds.py +28 -17
- numba_cuda/numba/cuda/tests/cudadrv/test_reset_device.py +1 -2
- numba_cuda/numba/cuda/tests/cudadrv/test_runtime.py +22 -14
- numba_cuda/numba/cuda/tests/cudadrv/test_select_device.py +1 -1
- numba_cuda/numba/cuda/tests/cudadrv/test_streams.py +4 -3
- numba_cuda/numba/cuda/tests/cudapy/cache_usecases.py +10 -4
- numba_cuda/numba/cuda/tests/cudapy/cache_with_cpu_usecases.py +1 -0
- numba_cuda/numba/cuda/tests/cudapy/extensions_usecases.py +10 -7
- numba_cuda/numba/cuda/tests/cudapy/jitlink.ptx +0 -2
- numba_cuda/numba/cuda/tests/cudapy/recursion_usecases.py +1 -0
- numba_cuda/numba/cuda/tests/cudapy/test_alignment.py +6 -5
- numba_cuda/numba/cuda/tests/cudapy/test_array.py +52 -42
- numba_cuda/numba/cuda/tests/cudapy/test_array_args.py +5 -6
- numba_cuda/numba/cuda/tests/cudapy/test_array_methods.py +1 -1
- numba_cuda/numba/cuda/tests/cudapy/test_atomics.py +501 -304
- numba_cuda/numba/cuda/tests/cudapy/test_bfloat16_bindings.py +257 -0
- numba_cuda/numba/cuda/tests/cudapy/test_blackscholes.py +59 -23
- numba_cuda/numba/cuda/tests/cudapy/test_boolean.py +3 -3
- numba_cuda/numba/cuda/tests/cudapy/test_caching.py +50 -37
- numba_cuda/numba/cuda/tests/cudapy/test_casting.py +29 -24
- numba_cuda/numba/cuda/tests/cudapy/test_cffi.py +11 -6
- numba_cuda/numba/cuda/tests/cudapy/test_compiler.py +84 -50
- numba_cuda/numba/cuda/tests/cudapy/test_complex.py +144 -73
- numba_cuda/numba/cuda/tests/cudapy/test_complex_kernel.py +2 -2
- numba_cuda/numba/cuda/tests/cudapy/test_const_string.py +37 -27
- numba_cuda/numba/cuda/tests/cudapy/test_constmem.py +43 -45
- numba_cuda/numba/cuda/tests/cudapy/test_cooperative_groups.py +21 -14
- numba_cuda/numba/cuda/tests/cudapy/test_cuda_array_interface.py +60 -55
- numba_cuda/numba/cuda/tests/cudapy/test_cuda_jit_no_types.py +3 -2
- numba_cuda/numba/cuda/tests/cudapy/test_datetime.py +26 -22
- numba_cuda/numba/cuda/tests/cudapy/test_debug.py +29 -27
- numba_cuda/numba/cuda/tests/cudapy/test_debuginfo.py +77 -28
- numba_cuda/numba/cuda/tests/cudapy/test_device_func.py +52 -45
- numba_cuda/numba/cuda/tests/cudapy/test_dispatcher.py +55 -43
- numba_cuda/numba/cuda/tests/cudapy/test_enums.py +24 -7
- numba_cuda/numba/cuda/tests/cudapy/test_errors.py +30 -15
- numba_cuda/numba/cuda/tests/cudapy/test_exception.py +11 -12
- numba_cuda/numba/cuda/tests/cudapy/test_extending.py +21 -12
- numba_cuda/numba/cuda/tests/cudapy/test_fastmath.py +77 -66
- numba_cuda/numba/cuda/tests/cudapy/test_forall.py +5 -3
- numba_cuda/numba/cuda/tests/cudapy/test_freevar.py +5 -3
- numba_cuda/numba/cuda/tests/cudapy/test_frexp_ldexp.py +1 -1
- numba_cuda/numba/cuda/tests/cudapy/test_globals.py +3 -5
- numba_cuda/numba/cuda/tests/cudapy/test_gufunc.py +144 -126
- numba_cuda/numba/cuda/tests/cudapy/test_gufunc_scalar.py +23 -18
- numba_cuda/numba/cuda/tests/cudapy/test_gufunc_scheduling.py +16 -22
- numba_cuda/numba/cuda/tests/cudapy/test_idiv.py +1 -3
- numba_cuda/numba/cuda/tests/cudapy/test_inline.py +59 -0
- numba_cuda/numba/cuda/tests/cudapy/test_inspect.py +29 -20
- numba_cuda/numba/cuda/tests/cudapy/test_intrinsics.py +147 -99
- numba_cuda/numba/cuda/tests/cudapy/test_ipc.py +50 -36
- numba_cuda/numba/cuda/tests/cudapy/test_iterators.py +1 -2
- numba_cuda/numba/cuda/tests/cudapy/test_lang.py +4 -4
- numba_cuda/numba/cuda/tests/cudapy/test_laplace.py +7 -7
- numba_cuda/numba/cuda/tests/cudapy/test_libdevice.py +24 -20
- numba_cuda/numba/cuda/tests/cudapy/test_lineinfo.py +36 -31
- numba_cuda/numba/cuda/tests/cudapy/test_localmem.py +13 -13
- numba_cuda/numba/cuda/tests/cudapy/test_mandel.py +13 -6
- numba_cuda/numba/cuda/tests/cudapy/test_math.py +83 -66
- numba_cuda/numba/cuda/tests/cudapy/test_matmul.py +1 -3
- numba_cuda/numba/cuda/tests/cudapy/test_minmax.py +19 -58
- numba_cuda/numba/cuda/tests/cudapy/test_montecarlo.py +4 -4
- numba_cuda/numba/cuda/tests/cudapy/test_multigpu.py +9 -7
- numba_cuda/numba/cuda/tests/cudapy/test_multiprocessing.py +9 -8
- numba_cuda/numba/cuda/tests/cudapy/test_multithreads.py +12 -10
- numba_cuda/numba/cuda/tests/cudapy/test_nondet.py +1 -1
- numba_cuda/numba/cuda/tests/cudapy/test_operator.py +180 -96
- numba_cuda/numba/cuda/tests/cudapy/test_optimization.py +5 -5
- numba_cuda/numba/cuda/tests/cudapy/test_overload.py +37 -18
- numba_cuda/numba/cuda/tests/cudapy/test_powi.py +7 -7
- numba_cuda/numba/cuda/tests/cudapy/test_print.py +9 -7
- numba_cuda/numba/cuda/tests/cudapy/test_py2_div_issue.py +1 -1
- numba_cuda/numba/cuda/tests/cudapy/test_random.py +15 -10
- numba_cuda/numba/cuda/tests/cudapy/test_record_dtype.py +88 -87
- numba_cuda/numba/cuda/tests/cudapy/test_recursion.py +12 -10
- numba_cuda/numba/cuda/tests/cudapy/test_reduction.py +26 -11
- numba_cuda/numba/cuda/tests/cudapy/test_retrieve_autoconverted_arrays.py +7 -10
- numba_cuda/numba/cuda/tests/cudapy/test_serialize.py +4 -6
- numba_cuda/numba/cuda/tests/cudapy/test_slicing.py +1 -1
- numba_cuda/numba/cuda/tests/cudapy/test_sm.py +10 -9
- numba_cuda/numba/cuda/tests/cudapy/test_sm_creation.py +62 -43
- numba_cuda/numba/cuda/tests/cudapy/test_stream_api.py +7 -3
- numba_cuda/numba/cuda/tests/cudapy/test_sync.py +7 -5
- numba_cuda/numba/cuda/tests/cudapy/test_transpose.py +18 -11
- numba_cuda/numba/cuda/tests/cudapy/test_ufuncs.py +111 -88
- numba_cuda/numba/cuda/tests/cudapy/test_userexc.py +2 -3
- numba_cuda/numba/cuda/tests/cudapy/test_vector_type.py +305 -130
- numba_cuda/numba/cuda/tests/cudapy/test_vectorize.py +33 -36
- numba_cuda/numba/cuda/tests/cudapy/test_vectorize_complex.py +5 -5
- numba_cuda/numba/cuda/tests/cudapy/test_vectorize_decor.py +16 -12
- numba_cuda/numba/cuda/tests/cudapy/test_vectorize_device.py +7 -7
- numba_cuda/numba/cuda/tests/cudapy/test_vectorize_scalar_arg.py +6 -7
- numba_cuda/numba/cuda/tests/cudapy/test_warning.py +31 -29
- numba_cuda/numba/cuda/tests/cudapy/test_warp_ops.py +81 -30
- numba_cuda/numba/cuda/tests/cudasim/test_cudasim_issues.py +19 -13
- numba_cuda/numba/cuda/tests/data/jitlink.cu +1 -1
- numba_cuda/numba/cuda/tests/data/jitlink.ptx +0 -2
- numba_cuda/numba/cuda/tests/doc_examples/test_cg.py +15 -8
- numba_cuda/numba/cuda/tests/doc_examples/test_cpu_gpu_compat.py +4 -7
- numba_cuda/numba/cuda/tests/doc_examples/test_ffi.py +14 -9
- numba_cuda/numba/cuda/tests/doc_examples/test_laplace.py +22 -18
- numba_cuda/numba/cuda/tests/doc_examples/test_matmul.py +7 -4
- numba_cuda/numba/cuda/tests/doc_examples/test_montecarlo.py +2 -0
- numba_cuda/numba/cuda/tests/doc_examples/test_random.py +8 -4
- numba_cuda/numba/cuda/tests/doc_examples/test_reduction.py +2 -1
- numba_cuda/numba/cuda/tests/doc_examples/test_sessionize.py +94 -19
- numba_cuda/numba/cuda/tests/doc_examples/test_vecadd.py +2 -2
- numba_cuda/numba/cuda/tests/nocuda/test_dummyarray.py +91 -62
- numba_cuda/numba/cuda/tests/nocuda/test_function_resolution.py +14 -5
- numba_cuda/numba/cuda/tests/nocuda/test_import.py +25 -25
- numba_cuda/numba/cuda/tests/nocuda/test_library_lookup.py +40 -40
- numba_cuda/numba/cuda/tests/nocuda/test_nvvm.py +12 -10
- numba_cuda/numba/cuda/tests/nrt/test_nrt.py +16 -20
- numba_cuda/numba/cuda/tests/nrt/test_nrt_refct.py +12 -10
- numba_cuda/numba/cuda/tests/test_binary_generation/generate_raw_ltoir.py +2 -2
- numba_cuda/numba/cuda/types.py +5 -2
- numba_cuda/numba/cuda/ufuncs.py +382 -362
- numba_cuda/numba/cuda/utils.py +2 -2
- numba_cuda/numba/cuda/vector_types.py +5 -3
- numba_cuda/numba/cuda/vectorizers.py +38 -33
- {numba_cuda-0.8.1.dist-info → numba_cuda-0.10.0.dist-info}/METADATA +1 -1
- numba_cuda-0.10.0.dist-info/RECORD +263 -0
- {numba_cuda-0.8.1.dist-info → numba_cuda-0.10.0.dist-info}/WHEEL +1 -1
- numba_cuda-0.8.1.dist-info/RECORD +0 -251
- {numba_cuda-0.8.1.dist-info → numba_cuda-0.10.0.dist-info}/licenses/LICENSE +0 -0
- {numba_cuda-0.8.1.dist-info → numba_cuda-0.10.0.dist-info}/top_level.txt +0 -0
@@ -29,48 +29,49 @@ def initialize_dim3(builder, prefix):
|
|
29
29
|
return cgutils.pack_struct(builder, (x, y, z))
|
30
30
|
|
31
31
|
|
32
|
-
@lower_attr(types.Module(cuda),
|
32
|
+
@lower_attr(types.Module(cuda), "threadIdx")
|
33
33
|
def cuda_threadIdx(context, builder, sig, args):
|
34
|
-
return initialize_dim3(builder,
|
34
|
+
return initialize_dim3(builder, "tid")
|
35
35
|
|
36
36
|
|
37
|
-
@lower_attr(types.Module(cuda),
|
37
|
+
@lower_attr(types.Module(cuda), "blockDim")
|
38
38
|
def cuda_blockDim(context, builder, sig, args):
|
39
|
-
return initialize_dim3(builder,
|
39
|
+
return initialize_dim3(builder, "ntid")
|
40
40
|
|
41
41
|
|
42
|
-
@lower_attr(types.Module(cuda),
|
42
|
+
@lower_attr(types.Module(cuda), "blockIdx")
|
43
43
|
def cuda_blockIdx(context, builder, sig, args):
|
44
|
-
return initialize_dim3(builder,
|
44
|
+
return initialize_dim3(builder, "ctaid")
|
45
45
|
|
46
46
|
|
47
|
-
@lower_attr(types.Module(cuda),
|
47
|
+
@lower_attr(types.Module(cuda), "gridDim")
|
48
48
|
def cuda_gridDim(context, builder, sig, args):
|
49
|
-
return initialize_dim3(builder,
|
49
|
+
return initialize_dim3(builder, "nctaid")
|
50
50
|
|
51
51
|
|
52
|
-
@lower_attr(types.Module(cuda),
|
52
|
+
@lower_attr(types.Module(cuda), "laneid")
|
53
53
|
def cuda_laneid(context, builder, sig, args):
|
54
|
-
return nvvmutils.call_sreg(builder,
|
54
|
+
return nvvmutils.call_sreg(builder, "laneid")
|
55
55
|
|
56
56
|
|
57
|
-
@lower_attr(dim3,
|
57
|
+
@lower_attr(dim3, "x")
|
58
58
|
def dim3_x(context, builder, sig, args):
|
59
59
|
return builder.extract_value(args, 0)
|
60
60
|
|
61
61
|
|
62
|
-
@lower_attr(dim3,
|
62
|
+
@lower_attr(dim3, "y")
|
63
63
|
def dim3_y(context, builder, sig, args):
|
64
64
|
return builder.extract_value(args, 1)
|
65
65
|
|
66
66
|
|
67
|
-
@lower_attr(dim3,
|
67
|
+
@lower_attr(dim3, "z")
|
68
68
|
def dim3_z(context, builder, sig, args):
|
69
69
|
return builder.extract_value(args, 2)
|
70
70
|
|
71
71
|
|
72
72
|
# -----------------------------------------------------------------------------
|
73
73
|
|
74
|
+
|
74
75
|
@lower(cuda.const.array_like, types.Array)
|
75
76
|
def cuda_const_array_like(context, builder, sig, args):
|
76
77
|
# This is a no-op because CUDATargetContext.make_constant_array already
|
@@ -95,48 +96,68 @@ def _get_unique_smem_id(name):
|
|
95
96
|
def cuda_shared_array_integer(context, builder, sig, args):
|
96
97
|
length = sig.args[0].literal_value
|
97
98
|
dtype = parse_dtype(sig.args[1])
|
98
|
-
return _generic_array(
|
99
|
-
|
100
|
-
|
101
|
-
|
99
|
+
return _generic_array(
|
100
|
+
context,
|
101
|
+
builder,
|
102
|
+
shape=(length,),
|
103
|
+
dtype=dtype,
|
104
|
+
symbol_name=_get_unique_smem_id("_cudapy_smem"),
|
105
|
+
addrspace=nvvm.ADDRSPACE_SHARED,
|
106
|
+
can_dynsized=True,
|
107
|
+
)
|
102
108
|
|
103
109
|
|
104
110
|
@lower(cuda.shared.array, types.Tuple, types.Any)
|
105
111
|
@lower(cuda.shared.array, types.UniTuple, types.Any)
|
106
112
|
def cuda_shared_array_tuple(context, builder, sig, args):
|
107
|
-
shape = [
|
113
|
+
shape = [s.literal_value for s in sig.args[0]]
|
108
114
|
dtype = parse_dtype(sig.args[1])
|
109
|
-
return _generic_array(
|
110
|
-
|
111
|
-
|
112
|
-
|
115
|
+
return _generic_array(
|
116
|
+
context,
|
117
|
+
builder,
|
118
|
+
shape=shape,
|
119
|
+
dtype=dtype,
|
120
|
+
symbol_name=_get_unique_smem_id("_cudapy_smem"),
|
121
|
+
addrspace=nvvm.ADDRSPACE_SHARED,
|
122
|
+
can_dynsized=True,
|
123
|
+
)
|
113
124
|
|
114
125
|
|
115
126
|
@lower(cuda.local.array, types.IntegerLiteral, types.Any)
|
116
127
|
def cuda_local_array_integer(context, builder, sig, args):
|
117
128
|
length = sig.args[0].literal_value
|
118
129
|
dtype = parse_dtype(sig.args[1])
|
119
|
-
return _generic_array(
|
120
|
-
|
121
|
-
|
122
|
-
|
130
|
+
return _generic_array(
|
131
|
+
context,
|
132
|
+
builder,
|
133
|
+
shape=(length,),
|
134
|
+
dtype=dtype,
|
135
|
+
symbol_name="_cudapy_lmem",
|
136
|
+
addrspace=nvvm.ADDRSPACE_LOCAL,
|
137
|
+
can_dynsized=False,
|
138
|
+
)
|
123
139
|
|
124
140
|
|
125
141
|
@lower(cuda.local.array, types.Tuple, types.Any)
|
126
142
|
@lower(cuda.local.array, types.UniTuple, types.Any)
|
127
143
|
def ptx_lmem_alloc_array(context, builder, sig, args):
|
128
|
-
shape = [
|
144
|
+
shape = [s.literal_value for s in sig.args[0]]
|
129
145
|
dtype = parse_dtype(sig.args[1])
|
130
|
-
return _generic_array(
|
131
|
-
|
132
|
-
|
133
|
-
|
146
|
+
return _generic_array(
|
147
|
+
context,
|
148
|
+
builder,
|
149
|
+
shape=shape,
|
150
|
+
dtype=dtype,
|
151
|
+
symbol_name="_cudapy_lmem",
|
152
|
+
addrspace=nvvm.ADDRSPACE_LOCAL,
|
153
|
+
can_dynsized=False,
|
154
|
+
)
|
134
155
|
|
135
156
|
|
136
157
|
@lower(stubs.threadfence_block)
|
137
158
|
def ptx_threadfence_block(context, builder, sig, args):
|
138
159
|
assert not args
|
139
|
-
fname =
|
160
|
+
fname = "llvm.nvvm.membar.cta"
|
140
161
|
lmod = builder.module
|
141
162
|
fnty = ir.FunctionType(ir.VoidType(), ())
|
142
163
|
sync = cgutils.get_or_insert_function(lmod, fnty, fname)
|
@@ -147,7 +168,7 @@ def ptx_threadfence_block(context, builder, sig, args):
|
|
147
168
|
@lower(stubs.threadfence_system)
|
148
169
|
def ptx_threadfence_system(context, builder, sig, args):
|
149
170
|
assert not args
|
150
|
-
fname =
|
171
|
+
fname = "llvm.nvvm.membar.sys"
|
151
172
|
lmod = builder.module
|
152
173
|
fnty = ir.FunctionType(ir.VoidType(), ())
|
153
174
|
sync = cgutils.get_or_insert_function(lmod, fnty, fname)
|
@@ -158,7 +179,7 @@ def ptx_threadfence_system(context, builder, sig, args):
|
|
158
179
|
@lower(stubs.threadfence)
|
159
180
|
def ptx_threadfence_device(context, builder, sig, args):
|
160
181
|
assert not args
|
161
|
-
fname =
|
182
|
+
fname = "llvm.nvvm.membar.gl"
|
162
183
|
lmod = builder.module
|
163
184
|
fnty = ir.FunctionType(ir.VoidType(), ())
|
164
185
|
sync = cgutils.get_or_insert_function(lmod, fnty, fname)
|
@@ -175,7 +196,7 @@ def ptx_syncwarp(context, builder, sig, args):
|
|
175
196
|
|
176
197
|
@lower(stubs.syncwarp, types.i4)
|
177
198
|
def ptx_syncwarp_mask(context, builder, sig, args):
|
178
|
-
fname =
|
199
|
+
fname = "llvm.nvvm.bar.warp.sync"
|
179
200
|
lmod = builder.module
|
180
201
|
fnty = ir.FunctionType(ir.VoidType(), (ir.IntType(32),))
|
181
202
|
sync = cgutils.get_or_insert_function(lmod, fnty, fname)
|
@@ -183,68 +204,15 @@ def ptx_syncwarp_mask(context, builder, sig, args):
|
|
183
204
|
return context.get_dummy_value()
|
184
205
|
|
185
206
|
|
186
|
-
@lower(stubs.
|
187
|
-
|
188
|
-
|
189
|
-
types.i4)
|
190
|
-
@lower(stubs.shfl_sync_intrinsic, types.i4, types.i4, types.f4, types.i4,
|
191
|
-
types.i4)
|
192
|
-
@lower(stubs.shfl_sync_intrinsic, types.i4, types.i4, types.f8, types.i4,
|
193
|
-
types.i4)
|
194
|
-
def ptx_shfl_sync_i32(context, builder, sig, args):
|
195
|
-
"""
|
196
|
-
The NVVM intrinsic for shfl only supports i32, but the cuda intrinsic
|
197
|
-
function supports both 32 and 64 bit ints and floats, so for feature parity,
|
198
|
-
i64, f32, and f64 are implemented. Floats by way of bitcasting the float to
|
199
|
-
an int, then shuffling, then bitcasting back. And 64-bit values by packing
|
200
|
-
them into 2 32bit values, shuffling thoose, and then packing back together.
|
201
|
-
"""
|
202
|
-
mask, mode, value, index, clamp = args
|
203
|
-
value_type = sig.args[2]
|
204
|
-
if value_type in types.real_domain:
|
205
|
-
value = builder.bitcast(value, ir.IntType(value_type.bitwidth))
|
206
|
-
fname = 'llvm.nvvm.shfl.sync.i32'
|
207
|
+
@lower(stubs.vote_sync_intrinsic, types.i4, types.i4, types.boolean)
|
208
|
+
def ptx_vote_sync(context, builder, sig, args):
|
209
|
+
fname = "llvm.nvvm.vote.sync"
|
207
210
|
lmod = builder.module
|
208
211
|
fnty = ir.FunctionType(
|
209
212
|
ir.LiteralStructType((ir.IntType(32), ir.IntType(1))),
|
210
|
-
|
211
|
-
ir.IntType(32), ir.IntType(32))
|
213
|
+
(ir.IntType(32), ir.IntType(32), ir.IntType(1)),
|
212
214
|
)
|
213
215
|
func = cgutils.get_or_insert_function(lmod, fnty, fname)
|
214
|
-
if value_type.bitwidth == 32:
|
215
|
-
ret = builder.call(func, (mask, mode, value, index, clamp))
|
216
|
-
if value_type == types.float32:
|
217
|
-
rv = builder.extract_value(ret, 0)
|
218
|
-
pred = builder.extract_value(ret, 1)
|
219
|
-
fv = builder.bitcast(rv, ir.FloatType())
|
220
|
-
ret = cgutils.make_anonymous_struct(builder, (fv, pred))
|
221
|
-
else:
|
222
|
-
value1 = builder.trunc(value, ir.IntType(32))
|
223
|
-
value_lshr = builder.lshr(value, context.get_constant(types.i8, 32))
|
224
|
-
value2 = builder.trunc(value_lshr, ir.IntType(32))
|
225
|
-
ret1 = builder.call(func, (mask, mode, value1, index, clamp))
|
226
|
-
ret2 = builder.call(func, (mask, mode, value2, index, clamp))
|
227
|
-
rv1 = builder.extract_value(ret1, 0)
|
228
|
-
rv2 = builder.extract_value(ret2, 0)
|
229
|
-
pred = builder.extract_value(ret1, 1)
|
230
|
-
rv1_64 = builder.zext(rv1, ir.IntType(64))
|
231
|
-
rv2_64 = builder.zext(rv2, ir.IntType(64))
|
232
|
-
rv_shl = builder.shl(rv2_64, context.get_constant(types.i8, 32))
|
233
|
-
rv = builder.or_(rv_shl, rv1_64)
|
234
|
-
if value_type == types.float64:
|
235
|
-
rv = builder.bitcast(rv, ir.DoubleType())
|
236
|
-
ret = cgutils.make_anonymous_struct(builder, (rv, pred))
|
237
|
-
return ret
|
238
|
-
|
239
|
-
|
240
|
-
@lower(stubs.vote_sync_intrinsic, types.i4, types.i4, types.boolean)
|
241
|
-
def ptx_vote_sync(context, builder, sig, args):
|
242
|
-
fname = 'llvm.nvvm.vote.sync'
|
243
|
-
lmod = builder.module
|
244
|
-
fnty = ir.FunctionType(ir.LiteralStructType((ir.IntType(32),
|
245
|
-
ir.IntType(1))),
|
246
|
-
(ir.IntType(32), ir.IntType(32), ir.IntType(1)))
|
247
|
-
func = cgutils.get_or_insert_function(lmod, fnty, fname)
|
248
216
|
return builder.call(func, args)
|
249
217
|
|
250
218
|
|
@@ -257,7 +225,7 @@ def ptx_match_any_sync(context, builder, sig, args):
|
|
257
225
|
width = sig.args[1].bitwidth
|
258
226
|
if sig.args[1] in types.real_domain:
|
259
227
|
value = builder.bitcast(value, ir.IntType(width))
|
260
|
-
fname =
|
228
|
+
fname = "llvm.nvvm.match.any.sync.i{}".format(width)
|
261
229
|
lmod = builder.module
|
262
230
|
fnty = ir.FunctionType(ir.IntType(32), (ir.IntType(32), ir.IntType(width)))
|
263
231
|
func = cgutils.get_or_insert_function(lmod, fnty, fname)
|
@@ -273,27 +241,35 @@ def ptx_match_all_sync(context, builder, sig, args):
|
|
273
241
|
width = sig.args[1].bitwidth
|
274
242
|
if sig.args[1] in types.real_domain:
|
275
243
|
value = builder.bitcast(value, ir.IntType(width))
|
276
|
-
fname =
|
244
|
+
fname = "llvm.nvvm.match.all.sync.i{}".format(width)
|
277
245
|
lmod = builder.module
|
278
|
-
fnty = ir.FunctionType(
|
279
|
-
|
280
|
-
|
246
|
+
fnty = ir.FunctionType(
|
247
|
+
ir.LiteralStructType((ir.IntType(32), ir.IntType(1))),
|
248
|
+
(ir.IntType(32), ir.IntType(width)),
|
249
|
+
)
|
281
250
|
func = cgutils.get_or_insert_function(lmod, fnty, fname)
|
282
251
|
return builder.call(func, (mask, value))
|
283
252
|
|
284
253
|
|
285
254
|
@lower(stubs.activemask)
|
286
255
|
def ptx_activemask(context, builder, sig, args):
|
287
|
-
activemask = ir.InlineAsm(
|
288
|
-
|
256
|
+
activemask = ir.InlineAsm(
|
257
|
+
ir.FunctionType(ir.IntType(32), []),
|
258
|
+
"activemask.b32 $0;",
|
259
|
+
"=r",
|
260
|
+
side_effect=True,
|
261
|
+
)
|
289
262
|
return builder.call(activemask, [])
|
290
263
|
|
291
264
|
|
292
265
|
@lower(stubs.lanemask_lt)
|
293
266
|
def ptx_lanemask_lt(context, builder, sig, args):
|
294
|
-
activemask = ir.InlineAsm(
|
295
|
-
|
296
|
-
|
267
|
+
activemask = ir.InlineAsm(
|
268
|
+
ir.FunctionType(ir.IntType(32), []),
|
269
|
+
"mov.u32 $0, %lanemask_lt;",
|
270
|
+
"=r",
|
271
|
+
side_effect=True,
|
272
|
+
)
|
297
273
|
return builder.call(activemask, [])
|
298
274
|
|
299
275
|
|
@@ -308,7 +284,7 @@ def ptx_fma(context, builder, sig, args):
|
|
308
284
|
|
309
285
|
|
310
286
|
def float16_float_ty_constraint(bitwidth):
|
311
|
-
typemap = {32: (
|
287
|
+
typemap = {32: ("f32", "f"), 64: ("f64", "d")}
|
312
288
|
|
313
289
|
try:
|
314
290
|
return typemap[bitwidth]
|
@@ -342,7 +318,7 @@ def float_to_float16_cast(context, builder, fromty, toty, val):
|
|
342
318
|
|
343
319
|
|
344
320
|
def float16_int_constraint(bitwidth):
|
345
|
-
typemap = {
|
321
|
+
typemap = {8: "c", 16: "h", 32: "r", 64: "l"}
|
346
322
|
|
347
323
|
try:
|
348
324
|
return typemap[bitwidth]
|
@@ -355,12 +331,12 @@ def float16_int_constraint(bitwidth):
|
|
355
331
|
def float16_to_integer_cast(context, builder, fromty, toty, val):
|
356
332
|
bitwidth = toty.bitwidth
|
357
333
|
constraint = float16_int_constraint(bitwidth)
|
358
|
-
signedness =
|
334
|
+
signedness = "s" if toty.signed else "u"
|
359
335
|
|
360
336
|
fnty = ir.FunctionType(context.get_value_type(toty), [ir.IntType(16)])
|
361
|
-
asm = ir.InlineAsm(
|
362
|
-
|
363
|
-
|
337
|
+
asm = ir.InlineAsm(
|
338
|
+
fnty, f"cvt.rni.{signedness}{bitwidth}.f16 $0, $1;", f"={constraint},h"
|
339
|
+
)
|
364
340
|
return builder.call(asm, [val])
|
365
341
|
|
366
342
|
|
@@ -369,40 +345,38 @@ def float16_to_integer_cast(context, builder, fromty, toty, val):
|
|
369
345
|
def integer_to_float16_cast(context, builder, fromty, toty, val):
|
370
346
|
bitwidth = fromty.bitwidth
|
371
347
|
constraint = float16_int_constraint(bitwidth)
|
372
|
-
signedness =
|
348
|
+
signedness = "s" if fromty.signed else "u"
|
373
349
|
|
374
|
-
fnty = ir.FunctionType(ir.IntType(16),
|
375
|
-
|
376
|
-
|
377
|
-
|
378
|
-
f"=h,{constraint}")
|
350
|
+
fnty = ir.FunctionType(ir.IntType(16), [context.get_value_type(fromty)])
|
351
|
+
asm = ir.InlineAsm(
|
352
|
+
fnty, f"cvt.rn.f16.{signedness}{bitwidth} $0, $1;", f"=h,{constraint}"
|
353
|
+
)
|
379
354
|
return builder.call(asm, [val])
|
380
355
|
|
381
356
|
|
382
357
|
def lower_fp16_binary(fn, op):
|
383
358
|
@lower(fn, types.float16, types.float16)
|
384
359
|
def ptx_fp16_binary(context, builder, sig, args):
|
385
|
-
fnty = ir.FunctionType(ir.IntType(16),
|
386
|
-
|
387
|
-
asm = ir.InlineAsm(fnty, f'{op}.f16 $0,$1,$2;', '=h,h,h')
|
360
|
+
fnty = ir.FunctionType(ir.IntType(16), [ir.IntType(16), ir.IntType(16)])
|
361
|
+
asm = ir.InlineAsm(fnty, f"{op}.f16 $0,$1,$2;", "=h,h,h")
|
388
362
|
return builder.call(asm, args)
|
389
363
|
|
390
364
|
|
391
|
-
lower_fp16_binary(stubs.fp16.hadd,
|
392
|
-
lower_fp16_binary(operator.add,
|
393
|
-
lower_fp16_binary(operator.iadd,
|
394
|
-
lower_fp16_binary(stubs.fp16.hsub,
|
395
|
-
lower_fp16_binary(operator.sub,
|
396
|
-
lower_fp16_binary(operator.isub,
|
397
|
-
lower_fp16_binary(stubs.fp16.hmul,
|
398
|
-
lower_fp16_binary(operator.mul,
|
399
|
-
lower_fp16_binary(operator.imul,
|
365
|
+
lower_fp16_binary(stubs.fp16.hadd, "add")
|
366
|
+
lower_fp16_binary(operator.add, "add")
|
367
|
+
lower_fp16_binary(operator.iadd, "add")
|
368
|
+
lower_fp16_binary(stubs.fp16.hsub, "sub")
|
369
|
+
lower_fp16_binary(operator.sub, "sub")
|
370
|
+
lower_fp16_binary(operator.isub, "sub")
|
371
|
+
lower_fp16_binary(stubs.fp16.hmul, "mul")
|
372
|
+
lower_fp16_binary(operator.mul, "mul")
|
373
|
+
lower_fp16_binary(operator.imul, "mul")
|
400
374
|
|
401
375
|
|
402
376
|
@lower(stubs.fp16.hneg, types.float16)
|
403
377
|
def ptx_fp16_hneg(context, builder, sig, args):
|
404
378
|
fnty = ir.FunctionType(ir.IntType(16), [ir.IntType(16)])
|
405
|
-
asm = ir.InlineAsm(fnty,
|
379
|
+
asm = ir.InlineAsm(fnty, "neg.f16 $0, $1;", "=h,h")
|
406
380
|
return builder.call(asm, args)
|
407
381
|
|
408
382
|
|
@@ -414,7 +388,7 @@ def operator_hneg(context, builder, sig, args):
|
|
414
388
|
@lower(stubs.fp16.habs, types.float16)
|
415
389
|
def ptx_fp16_habs(context, builder, sig, args):
|
416
390
|
fnty = ir.FunctionType(ir.IntType(16), [ir.IntType(16)])
|
417
|
-
asm = ir.InlineAsm(fnty,
|
391
|
+
asm = ir.InlineAsm(fnty, "abs.f16 $0, $1;", "=h,h")
|
418
392
|
return builder.call(asm, args)
|
419
393
|
|
420
394
|
|
@@ -450,27 +424,28 @@ _fp16_cmp = """{{
|
|
450
424
|
def _gen_fp16_cmp(op):
|
451
425
|
def ptx_fp16_comparison(context, builder, sig, args):
|
452
426
|
fnty = ir.FunctionType(ir.IntType(16), [ir.IntType(16), ir.IntType(16)])
|
453
|
-
asm = ir.InlineAsm(fnty, _fp16_cmp.format(op=op),
|
427
|
+
asm = ir.InlineAsm(fnty, _fp16_cmp.format(op=op), "=h,h,h")
|
454
428
|
result = builder.call(asm, args)
|
455
429
|
|
456
430
|
zero = context.get_constant(types.int16, 0)
|
457
431
|
int_result = builder.bitcast(result, ir.IntType(16))
|
458
432
|
return builder.icmp_unsigned("!=", int_result, zero)
|
433
|
+
|
459
434
|
return ptx_fp16_comparison
|
460
435
|
|
461
436
|
|
462
|
-
lower(stubs.fp16.heq, types.float16, types.float16)(_gen_fp16_cmp(
|
463
|
-
lower(operator.eq, types.float16, types.float16)(_gen_fp16_cmp(
|
464
|
-
lower(stubs.fp16.hne, types.float16, types.float16)(_gen_fp16_cmp(
|
465
|
-
lower(operator.ne, types.float16, types.float16)(_gen_fp16_cmp(
|
466
|
-
lower(stubs.fp16.hge, types.float16, types.float16)(_gen_fp16_cmp(
|
467
|
-
lower(operator.ge, types.float16, types.float16)(_gen_fp16_cmp(
|
468
|
-
lower(stubs.fp16.hgt, types.float16, types.float16)(_gen_fp16_cmp(
|
469
|
-
lower(operator.gt, types.float16, types.float16)(_gen_fp16_cmp(
|
470
|
-
lower(stubs.fp16.hle, types.float16, types.float16)(_gen_fp16_cmp(
|
471
|
-
lower(operator.le, types.float16, types.float16)(_gen_fp16_cmp(
|
472
|
-
lower(stubs.fp16.hlt, types.float16, types.float16)(_gen_fp16_cmp(
|
473
|
-
lower(operator.lt, types.float16, types.float16)(_gen_fp16_cmp(
|
437
|
+
lower(stubs.fp16.heq, types.float16, types.float16)(_gen_fp16_cmp("eq"))
|
438
|
+
lower(operator.eq, types.float16, types.float16)(_gen_fp16_cmp("eq"))
|
439
|
+
lower(stubs.fp16.hne, types.float16, types.float16)(_gen_fp16_cmp("ne"))
|
440
|
+
lower(operator.ne, types.float16, types.float16)(_gen_fp16_cmp("ne"))
|
441
|
+
lower(stubs.fp16.hge, types.float16, types.float16)(_gen_fp16_cmp("ge"))
|
442
|
+
lower(operator.ge, types.float16, types.float16)(_gen_fp16_cmp("ge"))
|
443
|
+
lower(stubs.fp16.hgt, types.float16, types.float16)(_gen_fp16_cmp("gt"))
|
444
|
+
lower(operator.gt, types.float16, types.float16)(_gen_fp16_cmp("gt"))
|
445
|
+
lower(stubs.fp16.hle, types.float16, types.float16)(_gen_fp16_cmp("le"))
|
446
|
+
lower(operator.le, types.float16, types.float16)(_gen_fp16_cmp("le"))
|
447
|
+
lower(stubs.fp16.hlt, types.float16, types.float16)(_gen_fp16_cmp("lt"))
|
448
|
+
lower(operator.lt, types.float16, types.float16)(_gen_fp16_cmp("lt"))
|
474
449
|
|
475
450
|
|
476
451
|
def lower_fp16_minmax(fn, fname, op):
|
@@ -480,8 +455,8 @@ def lower_fp16_minmax(fn, fname, op):
|
|
480
455
|
return builder.select(choice, args[0], args[1])
|
481
456
|
|
482
457
|
|
483
|
-
lower_fp16_minmax(stubs.fp16.hmax,
|
484
|
-
lower_fp16_minmax(stubs.fp16.hmin,
|
458
|
+
lower_fp16_minmax(stubs.fp16.hmax, "max", "gt")
|
459
|
+
lower_fp16_minmax(stubs.fp16.hmin, "min", "lt")
|
485
460
|
|
486
461
|
# See:
|
487
462
|
# https://docs.nvidia.com/cuda/libdevice-users-guide/__nv_cbrt.html#__nv_cbrt
|
@@ -489,8 +464,8 @@ lower_fp16_minmax(stubs.fp16.hmin, 'min', 'lt')
|
|
489
464
|
|
490
465
|
|
491
466
|
cbrt_funcs = {
|
492
|
-
types.float32:
|
493
|
-
types.float64:
|
467
|
+
types.float32: "__nv_cbrtf",
|
468
|
+
types.float64: "__nv_cbrt",
|
494
469
|
}
|
495
470
|
|
496
471
|
|
@@ -514,7 +489,8 @@ def ptx_brev_u4(context, builder, sig, args):
|
|
514
489
|
fn = cgutils.get_or_insert_function(
|
515
490
|
builder.module,
|
516
491
|
ir.FunctionType(ir.IntType(32), (ir.IntType(32),)),
|
517
|
-
|
492
|
+
"__nv_brev",
|
493
|
+
)
|
518
494
|
return builder.call(fn, args)
|
519
495
|
|
520
496
|
|
@@ -526,15 +502,14 @@ def ptx_brev_u8(context, builder, sig, args):
|
|
526
502
|
fn = cgutils.get_or_insert_function(
|
527
503
|
builder.module,
|
528
504
|
ir.FunctionType(ir.IntType(64), (ir.IntType(64),)),
|
529
|
-
|
505
|
+
"__nv_brevll",
|
506
|
+
)
|
530
507
|
return builder.call(fn, args)
|
531
508
|
|
532
509
|
|
533
510
|
@lower(stubs.clz, types.Any)
|
534
511
|
def ptx_clz(context, builder, sig, args):
|
535
|
-
return builder.ctlz(
|
536
|
-
args[0],
|
537
|
-
context.get_constant(types.boolean, 0))
|
512
|
+
return builder.ctlz(args[0], context.get_constant(types.boolean, 0))
|
538
513
|
|
539
514
|
|
540
515
|
@lower(stubs.ffs, types.i4)
|
@@ -543,7 +518,8 @@ def ptx_ffs_32(context, builder, sig, args):
|
|
543
518
|
fn = cgutils.get_or_insert_function(
|
544
519
|
builder.module,
|
545
520
|
ir.FunctionType(ir.IntType(32), (ir.IntType(32),)),
|
546
|
-
|
521
|
+
"__nv_ffs",
|
522
|
+
)
|
547
523
|
return builder.call(fn, args)
|
548
524
|
|
549
525
|
|
@@ -553,7 +529,8 @@ def ptx_ffs_64(context, builder, sig, args):
|
|
553
529
|
fn = cgutils.get_or_insert_function(
|
554
530
|
builder.module,
|
555
531
|
ir.FunctionType(ir.IntType(32), (ir.IntType(64),)),
|
556
|
-
|
532
|
+
"__nv_ffsll",
|
533
|
+
)
|
557
534
|
return builder.call(fn, args)
|
558
535
|
|
559
536
|
|
@@ -567,10 +544,9 @@ def ptx_selp(context, builder, sig, args):
|
|
567
544
|
def ptx_max_f4(context, builder, sig, args):
|
568
545
|
fn = cgutils.get_or_insert_function(
|
569
546
|
builder.module,
|
570
|
-
ir.FunctionType(
|
571
|
-
|
572
|
-
|
573
|
-
'__nv_fmaxf')
|
547
|
+
ir.FunctionType(ir.FloatType(), (ir.FloatType(), ir.FloatType())),
|
548
|
+
"__nv_fmaxf",
|
549
|
+
)
|
574
550
|
return builder.call(fn, args)
|
575
551
|
|
576
552
|
|
@@ -580,25 +556,26 @@ def ptx_max_f4(context, builder, sig, args):
|
|
580
556
|
def ptx_max_f8(context, builder, sig, args):
|
581
557
|
fn = cgutils.get_or_insert_function(
|
582
558
|
builder.module,
|
583
|
-
ir.FunctionType(
|
584
|
-
|
585
|
-
|
586
|
-
'__nv_fmax')
|
559
|
+
ir.FunctionType(ir.DoubleType(), (ir.DoubleType(), ir.DoubleType())),
|
560
|
+
"__nv_fmax",
|
561
|
+
)
|
587
562
|
|
588
|
-
return builder.call(
|
589
|
-
|
590
|
-
|
591
|
-
|
563
|
+
return builder.call(
|
564
|
+
fn,
|
565
|
+
[
|
566
|
+
context.cast(builder, args[0], sig.args[0], types.double),
|
567
|
+
context.cast(builder, args[1], sig.args[1], types.double),
|
568
|
+
],
|
569
|
+
)
|
592
570
|
|
593
571
|
|
594
572
|
@lower(min, types.f4, types.f4)
|
595
573
|
def ptx_min_f4(context, builder, sig, args):
|
596
574
|
fn = cgutils.get_or_insert_function(
|
597
575
|
builder.module,
|
598
|
-
ir.FunctionType(
|
599
|
-
|
600
|
-
|
601
|
-
'__nv_fminf')
|
576
|
+
ir.FunctionType(ir.FloatType(), (ir.FloatType(), ir.FloatType())),
|
577
|
+
"__nv_fminf",
|
578
|
+
)
|
602
579
|
return builder.call(fn, args)
|
603
580
|
|
604
581
|
|
@@ -608,15 +585,17 @@ def ptx_min_f4(context, builder, sig, args):
|
|
608
585
|
def ptx_min_f8(context, builder, sig, args):
|
609
586
|
fn = cgutils.get_or_insert_function(
|
610
587
|
builder.module,
|
611
|
-
ir.FunctionType(
|
612
|
-
|
613
|
-
|
614
|
-
'__nv_fmin')
|
588
|
+
ir.FunctionType(ir.DoubleType(), (ir.DoubleType(), ir.DoubleType())),
|
589
|
+
"__nv_fmin",
|
590
|
+
)
|
615
591
|
|
616
|
-
return builder.call(
|
617
|
-
|
618
|
-
|
619
|
-
|
592
|
+
return builder.call(
|
593
|
+
fn,
|
594
|
+
[
|
595
|
+
context.cast(builder, args[0], sig.args[0], types.double),
|
596
|
+
context.cast(builder, args[1], sig.args[1], types.double),
|
597
|
+
],
|
598
|
+
)
|
620
599
|
|
621
600
|
|
622
601
|
@lower(round, types.f4)
|
@@ -624,19 +603,22 @@ def ptx_min_f8(context, builder, sig, args):
|
|
624
603
|
def ptx_round(context, builder, sig, args):
|
625
604
|
fn = cgutils.get_or_insert_function(
|
626
605
|
builder.module,
|
627
|
-
ir.FunctionType(
|
628
|
-
|
629
|
-
|
630
|
-
|
631
|
-
|
632
|
-
|
633
|
-
|
606
|
+
ir.FunctionType(ir.IntType(64), (ir.DoubleType(),)),
|
607
|
+
"__nv_llrint",
|
608
|
+
)
|
609
|
+
return builder.call(
|
610
|
+
fn,
|
611
|
+
[
|
612
|
+
context.cast(builder, args[0], sig.args[0], types.double),
|
613
|
+
],
|
614
|
+
)
|
634
615
|
|
635
616
|
|
636
617
|
# This rounding implementation follows the algorithm used in the "fallback
|
637
618
|
# version" of double_round in CPython.
|
638
619
|
# https://github.com/python/cpython/blob/a755410e054e1e2390de5830befc08fe80706c66/Objects/floatobject.c#L964-L1007
|
639
620
|
|
621
|
+
|
640
622
|
@lower(round, types.f4, types.Integer)
|
641
623
|
@lower(round, types.f8, types.Integer)
|
642
624
|
def round_to_impl(context, builder, sig, args):
|
@@ -651,7 +633,7 @@ def round_to_impl(context, builder, sig, args):
|
|
651
633
|
pow1 = 10.0 ** (ndigits - 22)
|
652
634
|
pow2 = 1e22
|
653
635
|
else:
|
654
|
-
pow1 = 10.0
|
636
|
+
pow1 = 10.0**ndigits
|
655
637
|
pow2 = 1.0
|
656
638
|
y = (x * pow1) * pow2
|
657
639
|
if math.isinf(y):
|
@@ -662,7 +644,7 @@ def round_to_impl(context, builder, sig, args):
|
|
662
644
|
y = x / pow1
|
663
645
|
|
664
646
|
z = round(y)
|
665
|
-
if
|
647
|
+
if math.fabs(y - z) == 0.5:
|
666
648
|
# halfway between two integers; use round-half-even
|
667
649
|
z = 2.0 * round(y / 2.0)
|
668
650
|
|
@@ -673,19 +655,25 @@ def round_to_impl(context, builder, sig, args):
|
|
673
655
|
|
674
656
|
return z
|
675
657
|
|
676
|
-
return context.compile_internal(
|
658
|
+
return context.compile_internal(
|
659
|
+
builder,
|
660
|
+
round_ndigits,
|
661
|
+
sig,
|
662
|
+
args,
|
663
|
+
)
|
677
664
|
|
678
665
|
|
679
666
|
def gen_deg_rad(const):
|
680
667
|
def impl(context, builder, sig, args):
|
681
|
-
argty, = sig.args
|
668
|
+
(argty,) = sig.args
|
682
669
|
factor = context.get_constant(argty, const)
|
683
670
|
return builder.fmul(factor, args[0])
|
671
|
+
|
684
672
|
return impl
|
685
673
|
|
686
674
|
|
687
|
-
_deg2rad = math.pi / 180.
|
688
|
-
_rad2deg = 180. / math.pi
|
675
|
+
_deg2rad = math.pi / 180.0
|
676
|
+
_rad2deg = 180.0 / math.pi
|
689
677
|
lower(math.radians, types.f4)(gen_deg_rad(_deg2rad))
|
690
678
|
lower(math.radians, types.f8)(gen_deg_rad(_deg2rad))
|
691
679
|
lower(math.degrees, types.f4)(gen_deg_rad(_rad2deg))
|
@@ -701,16 +689,18 @@ def _normalize_indices(context, builder, indty, inds, aryty, valty):
|
|
701
689
|
indices = [inds]
|
702
690
|
else:
|
703
691
|
indices = cgutils.unpack_tuple(builder, inds, count=len(indty))
|
704
|
-
indices = [
|
705
|
-
|
692
|
+
indices = [
|
693
|
+
context.cast(builder, i, t, types.intp) for t, i in zip(indty, indices)
|
694
|
+
]
|
706
695
|
|
707
696
|
dtype = aryty.dtype
|
708
697
|
if dtype != valty:
|
709
698
|
raise TypeError("expect %s but got %s" % (dtype, valty))
|
710
699
|
|
711
700
|
if aryty.ndim != len(indty):
|
712
|
-
raise TypeError(
|
713
|
-
|
701
|
+
raise TypeError(
|
702
|
+
"indexing %d-D array with %d-D index" % (aryty.ndim, len(indty))
|
703
|
+
)
|
714
704
|
|
715
705
|
return indty, indices
|
716
706
|
|
@@ -722,14 +712,17 @@ def _atomic_dispatcher(dispatch_fn):
|
|
722
712
|
ary, inds, val = args
|
723
713
|
dtype = aryty.dtype
|
724
714
|
|
725
|
-
indty, indices = _normalize_indices(
|
726
|
-
|
715
|
+
indty, indices = _normalize_indices(
|
716
|
+
context, builder, indty, inds, aryty, valty
|
717
|
+
)
|
727
718
|
|
728
719
|
lary = context.make_array(aryty)(context, builder, ary)
|
729
|
-
ptr = cgutils.get_item_pointer(
|
730
|
-
|
720
|
+
ptr = cgutils.get_item_pointer(
|
721
|
+
context, builder, aryty, lary, indices, wraparound=True
|
722
|
+
)
|
731
723
|
# dispatcher to implementation base on dtype
|
732
724
|
return dispatch_fn(context, builder, dtype, ptr, val)
|
725
|
+
|
733
726
|
return imp
|
734
727
|
|
735
728
|
|
@@ -740,14 +733,16 @@ def _atomic_dispatcher(dispatch_fn):
|
|
740
733
|
def ptx_atomic_add_tuple(context, builder, dtype, ptr, val):
|
741
734
|
if dtype == types.float32:
|
742
735
|
lmod = builder.module
|
743
|
-
return builder.call(
|
744
|
-
|
736
|
+
return builder.call(
|
737
|
+
nvvmutils.declare_atomic_add_float32(lmod), (ptr, val)
|
738
|
+
)
|
745
739
|
elif dtype == types.float64:
|
746
740
|
lmod = builder.module
|
747
|
-
return builder.call(
|
748
|
-
|
741
|
+
return builder.call(
|
742
|
+
nvvmutils.declare_atomic_add_float64(lmod), (ptr, val)
|
743
|
+
)
|
749
744
|
else:
|
750
|
-
return builder.atomic_rmw(
|
745
|
+
return builder.atomic_rmw("add", ptr, val, "monotonic")
|
751
746
|
|
752
747
|
|
753
748
|
@lower(stubs.atomic.sub, types.Array, types.intp, types.Any)
|
@@ -757,14 +752,16 @@ def ptx_atomic_add_tuple(context, builder, dtype, ptr, val):
|
|
757
752
|
def ptx_atomic_sub(context, builder, dtype, ptr, val):
|
758
753
|
if dtype == types.float32:
|
759
754
|
lmod = builder.module
|
760
|
-
return builder.call(
|
761
|
-
|
755
|
+
return builder.call(
|
756
|
+
nvvmutils.declare_atomic_sub_float32(lmod), (ptr, val)
|
757
|
+
)
|
762
758
|
elif dtype == types.float64:
|
763
759
|
lmod = builder.module
|
764
|
-
return builder.call(
|
765
|
-
|
760
|
+
return builder.call(
|
761
|
+
nvvmutils.declare_atomic_sub_float64(lmod), (ptr, val)
|
762
|
+
)
|
766
763
|
else:
|
767
|
-
return builder.atomic_rmw(
|
764
|
+
return builder.atomic_rmw("sub", ptr, val, "monotonic")
|
768
765
|
|
769
766
|
|
770
767
|
@lower(stubs.atomic.inc, types.Array, types.intp, types.Any)
|
@@ -775,10 +772,10 @@ def ptx_atomic_inc(context, builder, dtype, ptr, val):
|
|
775
772
|
if dtype in cuda.cudadecl.unsigned_int_numba_types:
|
776
773
|
bw = dtype.bitwidth
|
777
774
|
lmod = builder.module
|
778
|
-
fn = getattr(nvvmutils, f
|
775
|
+
fn = getattr(nvvmutils, f"declare_atomic_inc_int{bw}")
|
779
776
|
return builder.call(fn(lmod), (ptr, val))
|
780
777
|
else:
|
781
|
-
raise TypeError(f
|
778
|
+
raise TypeError(f"Unimplemented atomic inc with {dtype} array")
|
782
779
|
|
783
780
|
|
784
781
|
@lower(stubs.atomic.dec, types.Array, types.intp, types.Any)
|
@@ -789,27 +786,27 @@ def ptx_atomic_dec(context, builder, dtype, ptr, val):
|
|
789
786
|
if dtype in cuda.cudadecl.unsigned_int_numba_types:
|
790
787
|
bw = dtype.bitwidth
|
791
788
|
lmod = builder.module
|
792
|
-
fn = getattr(nvvmutils, f
|
789
|
+
fn = getattr(nvvmutils, f"declare_atomic_dec_int{bw}")
|
793
790
|
return builder.call(fn(lmod), (ptr, val))
|
794
791
|
else:
|
795
|
-
raise TypeError(f
|
792
|
+
raise TypeError(f"Unimplemented atomic dec with {dtype} array")
|
796
793
|
|
797
794
|
|
798
795
|
def ptx_atomic_bitwise(stub, op):
|
799
796
|
@_atomic_dispatcher
|
800
797
|
def impl_ptx_atomic(context, builder, dtype, ptr, val):
|
801
798
|
if dtype in (cuda.cudadecl.integer_numba_types):
|
802
|
-
return builder.atomic_rmw(op, ptr, val,
|
799
|
+
return builder.atomic_rmw(op, ptr, val, "monotonic")
|
803
800
|
else:
|
804
|
-
raise TypeError(f
|
801
|
+
raise TypeError(f"Unimplemented atomic {op} with {dtype} array")
|
805
802
|
|
806
803
|
for ty in (types.intp, types.UniTuple, types.Tuple):
|
807
804
|
lower(stub, types.Array, ty, types.Any)(impl_ptx_atomic)
|
808
805
|
|
809
806
|
|
810
|
-
ptx_atomic_bitwise(stubs.atomic.and_,
|
811
|
-
ptx_atomic_bitwise(stubs.atomic.or_,
|
812
|
-
ptx_atomic_bitwise(stubs.atomic.xor,
|
807
|
+
ptx_atomic_bitwise(stubs.atomic.and_, "and")
|
808
|
+
ptx_atomic_bitwise(stubs.atomic.or_, "or")
|
809
|
+
ptx_atomic_bitwise(stubs.atomic.xor, "xor")
|
813
810
|
|
814
811
|
|
815
812
|
@lower(stubs.atomic.exch, types.Array, types.intp, types.Any)
|
@@ -818,9 +815,9 @@ ptx_atomic_bitwise(stubs.atomic.xor, 'xor')
|
|
818
815
|
@_atomic_dispatcher
|
819
816
|
def ptx_atomic_exch(context, builder, dtype, ptr, val):
|
820
817
|
if dtype in (cuda.cudadecl.integer_numba_types):
|
821
|
-
return builder.atomic_rmw(
|
818
|
+
return builder.atomic_rmw("xchg", ptr, val, "monotonic")
|
822
819
|
else:
|
823
|
-
raise TypeError(f
|
820
|
+
raise TypeError(f"Unimplemented atomic exch with {dtype} array")
|
824
821
|
|
825
822
|
|
826
823
|
@lower(stubs.atomic.max, types.Array, types.intp, types.Any)
|
@@ -830,17 +827,19 @@ def ptx_atomic_exch(context, builder, dtype, ptr, val):
|
|
830
827
|
def ptx_atomic_max(context, builder, dtype, ptr, val):
|
831
828
|
lmod = builder.module
|
832
829
|
if dtype == types.float64:
|
833
|
-
return builder.call(
|
834
|
-
|
830
|
+
return builder.call(
|
831
|
+
nvvmutils.declare_atomic_max_float64(lmod), (ptr, val)
|
832
|
+
)
|
835
833
|
elif dtype == types.float32:
|
836
|
-
return builder.call(
|
837
|
-
|
834
|
+
return builder.call(
|
835
|
+
nvvmutils.declare_atomic_max_float32(lmod), (ptr, val)
|
836
|
+
)
|
838
837
|
elif dtype in (types.int32, types.int64):
|
839
|
-
return builder.atomic_rmw(
|
838
|
+
return builder.atomic_rmw("max", ptr, val, ordering="monotonic")
|
840
839
|
elif dtype in (types.uint32, types.uint64):
|
841
|
-
return builder.atomic_rmw(
|
840
|
+
return builder.atomic_rmw("umax", ptr, val, ordering="monotonic")
|
842
841
|
else:
|
843
|
-
raise TypeError(
|
842
|
+
raise TypeError("Unimplemented atomic max with %s array" % dtype)
|
844
843
|
|
845
844
|
|
846
845
|
@lower(stubs.atomic.min, types.Array, types.intp, types.Any)
|
@@ -850,17 +849,19 @@ def ptx_atomic_max(context, builder, dtype, ptr, val):
|
|
850
849
|
def ptx_atomic_min(context, builder, dtype, ptr, val):
|
851
850
|
lmod = builder.module
|
852
851
|
if dtype == types.float64:
|
853
|
-
return builder.call(
|
854
|
-
|
852
|
+
return builder.call(
|
853
|
+
nvvmutils.declare_atomic_min_float64(lmod), (ptr, val)
|
854
|
+
)
|
855
855
|
elif dtype == types.float32:
|
856
|
-
return builder.call(
|
857
|
-
|
856
|
+
return builder.call(
|
857
|
+
nvvmutils.declare_atomic_min_float32(lmod), (ptr, val)
|
858
|
+
)
|
858
859
|
elif dtype in (types.int32, types.int64):
|
859
|
-
return builder.atomic_rmw(
|
860
|
+
return builder.atomic_rmw("min", ptr, val, ordering="monotonic")
|
860
861
|
elif dtype in (types.uint32, types.uint64):
|
861
|
-
return builder.atomic_rmw(
|
862
|
+
return builder.atomic_rmw("umin", ptr, val, ordering="monotonic")
|
862
863
|
else:
|
863
|
-
raise TypeError(
|
864
|
+
raise TypeError("Unimplemented atomic min with %s array" % dtype)
|
864
865
|
|
865
866
|
|
866
867
|
@lower(stubs.atomic.nanmax, types.Array, types.intp, types.Any)
|
@@ -870,17 +871,19 @@ def ptx_atomic_min(context, builder, dtype, ptr, val):
|
|
870
871
|
def ptx_atomic_nanmax(context, builder, dtype, ptr, val):
|
871
872
|
lmod = builder.module
|
872
873
|
if dtype == types.float64:
|
873
|
-
return builder.call(
|
874
|
-
|
874
|
+
return builder.call(
|
875
|
+
nvvmutils.declare_atomic_nanmax_float64(lmod), (ptr, val)
|
876
|
+
)
|
875
877
|
elif dtype == types.float32:
|
876
|
-
return builder.call(
|
877
|
-
|
878
|
+
return builder.call(
|
879
|
+
nvvmutils.declare_atomic_nanmax_float32(lmod), (ptr, val)
|
880
|
+
)
|
878
881
|
elif dtype in (types.int32, types.int64):
|
879
|
-
return builder.atomic_rmw(
|
882
|
+
return builder.atomic_rmw("max", ptr, val, ordering="monotonic")
|
880
883
|
elif dtype in (types.uint32, types.uint64):
|
881
|
-
return builder.atomic_rmw(
|
884
|
+
return builder.atomic_rmw("umax", ptr, val, ordering="monotonic")
|
882
885
|
else:
|
883
|
-
raise TypeError(
|
886
|
+
raise TypeError("Unimplemented atomic max with %s array" % dtype)
|
884
887
|
|
885
888
|
|
886
889
|
@lower(stubs.atomic.nanmin, types.Array, types.intp, types.Any)
|
@@ -890,17 +893,19 @@ def ptx_atomic_nanmax(context, builder, dtype, ptr, val):
|
|
890
893
|
def ptx_atomic_nanmin(context, builder, dtype, ptr, val):
|
891
894
|
lmod = builder.module
|
892
895
|
if dtype == types.float64:
|
893
|
-
return builder.call(
|
894
|
-
|
896
|
+
return builder.call(
|
897
|
+
nvvmutils.declare_atomic_nanmin_float64(lmod), (ptr, val)
|
898
|
+
)
|
895
899
|
elif dtype == types.float32:
|
896
|
-
return builder.call(
|
897
|
-
|
900
|
+
return builder.call(
|
901
|
+
nvvmutils.declare_atomic_nanmin_float32(lmod), (ptr, val)
|
902
|
+
)
|
898
903
|
elif dtype in (types.int32, types.int64):
|
899
|
-
return builder.atomic_rmw(
|
904
|
+
return builder.atomic_rmw("min", ptr, val, ordering="monotonic")
|
900
905
|
elif dtype in (types.uint32, types.uint64):
|
901
|
-
return builder.atomic_rmw(
|
906
|
+
return builder.atomic_rmw("umin", ptr, val, ordering="monotonic")
|
902
907
|
else:
|
903
|
-
raise TypeError(
|
908
|
+
raise TypeError("Unimplemented atomic min with %s array" % dtype)
|
904
909
|
|
905
910
|
|
906
911
|
@lower(stubs.atomic.compare_and_swap, types.Array, types.Any, types.Any)
|
@@ -917,19 +922,21 @@ def ptx_atomic_cas(context, builder, sig, args):
|
|
917
922
|
aryty, indty, oldty, valty = sig.args
|
918
923
|
ary, inds, old, val = args
|
919
924
|
|
920
|
-
indty, indices = _normalize_indices(
|
921
|
-
|
925
|
+
indty, indices = _normalize_indices(
|
926
|
+
context, builder, indty, inds, aryty, valty
|
927
|
+
)
|
922
928
|
|
923
929
|
lary = context.make_array(aryty)(context, builder, ary)
|
924
|
-
ptr = cgutils.get_item_pointer(
|
925
|
-
|
930
|
+
ptr = cgutils.get_item_pointer(
|
931
|
+
context, builder, aryty, lary, indices, wraparound=True
|
932
|
+
)
|
926
933
|
|
927
934
|
if aryty.dtype in (cuda.cudadecl.integer_numba_types):
|
928
935
|
lmod = builder.module
|
929
936
|
bitwidth = aryty.dtype.bitwidth
|
930
937
|
return nvvmutils.atomic_cmpxchg(builder, lmod, bitwidth, ptr, old, val)
|
931
938
|
else:
|
932
|
-
raise TypeError(
|
939
|
+
raise TypeError("Unimplemented atomic cas with %s array" % aryty.dtype)
|
933
940
|
|
934
941
|
|
935
942
|
# -----------------------------------------------------------------------------
|
@@ -937,15 +944,20 @@ def ptx_atomic_cas(context, builder, sig, args):
|
|
937
944
|
|
938
945
|
@lower(breakpoint)
|
939
946
|
def ptx_brkpt(context, builder, sig, args):
|
940
|
-
brkpt = ir.InlineAsm(
|
941
|
-
|
947
|
+
brkpt = ir.InlineAsm(
|
948
|
+
ir.FunctionType(ir.VoidType(), []), "brkpt;", "", side_effect=True
|
949
|
+
)
|
942
950
|
builder.call(brkpt, ())
|
943
951
|
|
944
952
|
|
945
953
|
@lower(stubs.nanosleep, types.uint32)
|
946
954
|
def ptx_nanosleep(context, builder, sig, args):
|
947
|
-
nanosleep = ir.InlineAsm(
|
948
|
-
|
955
|
+
nanosleep = ir.InlineAsm(
|
956
|
+
ir.FunctionType(ir.VoidType(), [ir.IntType(32)]),
|
957
|
+
"nanosleep.u32 $0;",
|
958
|
+
"r",
|
959
|
+
side_effect=True,
|
960
|
+
)
|
949
961
|
ns = args[0]
|
950
962
|
builder.call(nanosleep, [ns])
|
951
963
|
|
@@ -953,8 +965,9 @@ def ptx_nanosleep(context, builder, sig, args):
|
|
953
965
|
# -----------------------------------------------------------------------------
|
954
966
|
|
955
967
|
|
956
|
-
def _generic_array(
|
957
|
-
|
968
|
+
def _generic_array(
|
969
|
+
context, builder, shape, dtype, symbol_name, addrspace, can_dynsized=False
|
970
|
+
):
|
958
971
|
elemcount = reduce(operator.mul, shape, 1)
|
959
972
|
|
960
973
|
# Check for valid shape for this type of allocation.
|
@@ -985,16 +998,17 @@ def _generic_array(context, builder, shape, dtype, symbol_name, addrspace,
|
|
985
998
|
lmod = builder.module
|
986
999
|
|
987
1000
|
# Create global variable in the requested address space
|
988
|
-
gvmem = cgutils.add_global_variable(
|
989
|
-
|
1001
|
+
gvmem = cgutils.add_global_variable(
|
1002
|
+
lmod, laryty, symbol_name, addrspace
|
1003
|
+
)
|
990
1004
|
# Specify alignment to avoid misalignment bug
|
991
1005
|
align = context.get_abi_sizeof(lldtype)
|
992
1006
|
# Alignment is required to be a power of 2 for shared memory. If it is
|
993
1007
|
# not a power of 2 (e.g. for a Record array) then round up accordingly.
|
994
|
-
gvmem.align = 1 << (align - 1
|
1008
|
+
gvmem.align = 1 << (align - 1).bit_length()
|
995
1009
|
|
996
1010
|
if dynamic_smem:
|
997
|
-
gvmem.linkage =
|
1011
|
+
gvmem.linkage = "external"
|
998
1012
|
else:
|
999
1013
|
## Comment out the following line to workaround a NVVM bug
|
1000
1014
|
## which generates a invalid symbol name when the linkage
|
@@ -1005,8 +1019,9 @@ def _generic_array(context, builder, shape, dtype, symbol_name, addrspace,
|
|
1005
1019
|
gvmem.initializer = ir.Constant(laryty, ir.Undefined)
|
1006
1020
|
|
1007
1021
|
# Convert to generic address-space
|
1008
|
-
dataptr = builder.addrspacecast(
|
1009
|
-
|
1022
|
+
dataptr = builder.addrspacecast(
|
1023
|
+
gvmem, ir.PointerType(ir.IntType(8)), "generic"
|
1024
|
+
)
|
1010
1025
|
|
1011
1026
|
targetdata = ll.create_target_data(nvvm.NVVM().data_layout)
|
1012
1027
|
lldtype = context.get_data_type(dtype)
|
@@ -1027,11 +1042,15 @@ def _generic_array(context, builder, shape, dtype, symbol_name, addrspace,
|
|
1027
1042
|
# Unfortunately NVVM does not provide an intrinsic for the
|
1028
1043
|
# %dynamic_smem_size register, so we must read it using inline
|
1029
1044
|
# assembly.
|
1030
|
-
get_dynshared_size = ir.InlineAsm(
|
1031
|
-
|
1032
|
-
|
1033
|
-
|
1034
|
-
|
1045
|
+
get_dynshared_size = ir.InlineAsm(
|
1046
|
+
ir.FunctionType(ir.IntType(32), []),
|
1047
|
+
"mov.u32 $0, %dynamic_smem_size;",
|
1048
|
+
"=r",
|
1049
|
+
side_effect=True,
|
1050
|
+
)
|
1051
|
+
dynsmem_size = builder.zext(
|
1052
|
+
builder.call(get_dynshared_size, []), ir.IntType(64)
|
1053
|
+
)
|
1035
1054
|
# Only 1-D dynamic shared memory is supported so the following is a
|
1036
1055
|
# sufficient construction of the shape
|
1037
1056
|
kitemsize = context.get_constant(types.intp, itemsize)
|
@@ -1041,15 +1060,17 @@ def _generic_array(context, builder, shape, dtype, symbol_name, addrspace,
|
|
1041
1060
|
|
1042
1061
|
# Create array object
|
1043
1062
|
ndim = len(shape)
|
1044
|
-
aryty = types.Array(dtype=dtype, ndim=ndim, layout=
|
1063
|
+
aryty = types.Array(dtype=dtype, ndim=ndim, layout="C")
|
1045
1064
|
ary = context.make_array(aryty)(context, builder)
|
1046
1065
|
|
1047
|
-
context.populate_array(
|
1048
|
-
|
1049
|
-
|
1050
|
-
|
1051
|
-
|
1052
|
-
|
1066
|
+
context.populate_array(
|
1067
|
+
ary,
|
1068
|
+
data=builder.bitcast(dataptr, ary.data.type),
|
1069
|
+
shape=kshape,
|
1070
|
+
strides=kstrides,
|
1071
|
+
itemsize=context.get_constant(types.intp, itemsize),
|
1072
|
+
meminfo=None,
|
1073
|
+
)
|
1053
1074
|
return ary._getvalue()
|
1054
1075
|
|
1055
1076
|
|