numba-cuda 0.8.0__py3-none-any.whl → 0.9.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- _numba_cuda_redirector.py +17 -13
- numba_cuda/VERSION +1 -1
- numba_cuda/_version.py +4 -1
- numba_cuda/numba/cuda/__init__.py +6 -2
- numba_cuda/numba/cuda/api.py +129 -86
- numba_cuda/numba/cuda/api_util.py +3 -3
- numba_cuda/numba/cuda/args.py +12 -16
- numba_cuda/numba/cuda/cg.py +6 -6
- numba_cuda/numba/cuda/codegen.py +74 -43
- numba_cuda/numba/cuda/compiler.py +232 -113
- numba_cuda/numba/cuda/cpp_function_wrappers.cu +1 -2
- numba_cuda/numba/cuda/cuda_fp16.h +661 -661
- numba_cuda/numba/cuda/cuda_fp16.hpp +3 -3
- numba_cuda/numba/cuda/cuda_paths.py +291 -99
- numba_cuda/numba/cuda/cudadecl.py +125 -69
- numba_cuda/numba/cuda/cudadrv/__init__.py +3 -1
- numba_cuda/numba/cuda/cudadrv/devicearray.py +185 -135
- numba_cuda/numba/cuda/cudadrv/devices.py +16 -11
- numba_cuda/numba/cuda/cudadrv/driver.py +463 -297
- numba_cuda/numba/cuda/cudadrv/drvapi.py +241 -207
- numba_cuda/numba/cuda/cudadrv/dummyarray.py +66 -54
- numba_cuda/numba/cuda/cudadrv/enums.py +1 -1
- numba_cuda/numba/cuda/cudadrv/error.py +6 -2
- numba_cuda/numba/cuda/cudadrv/libs.py +67 -63
- numba_cuda/numba/cuda/cudadrv/linkable_code.py +16 -1
- numba_cuda/numba/cuda/cudadrv/mappings.py +16 -14
- numba_cuda/numba/cuda/cudadrv/nvrtc.py +138 -29
- numba_cuda/numba/cuda/cudadrv/nvvm.py +296 -161
- numba_cuda/numba/cuda/cudadrv/rtapi.py +1 -1
- numba_cuda/numba/cuda/cudadrv/runtime.py +20 -8
- numba_cuda/numba/cuda/cudaimpl.py +317 -233
- numba_cuda/numba/cuda/cudamath.py +1 -1
- numba_cuda/numba/cuda/debuginfo.py +8 -6
- numba_cuda/numba/cuda/decorators.py +75 -45
- numba_cuda/numba/cuda/descriptor.py +1 -1
- numba_cuda/numba/cuda/device_init.py +69 -18
- numba_cuda/numba/cuda/deviceufunc.py +143 -98
- numba_cuda/numba/cuda/dispatcher.py +300 -213
- numba_cuda/numba/cuda/errors.py +13 -10
- numba_cuda/numba/cuda/extending.py +1 -1
- numba_cuda/numba/cuda/initialize.py +5 -3
- numba_cuda/numba/cuda/intrinsic_wrapper.py +3 -3
- numba_cuda/numba/cuda/intrinsics.py +31 -27
- numba_cuda/numba/cuda/kernels/reduction.py +13 -13
- numba_cuda/numba/cuda/kernels/transpose.py +3 -6
- numba_cuda/numba/cuda/libdevice.py +317 -317
- numba_cuda/numba/cuda/libdeviceimpl.py +3 -2
- numba_cuda/numba/cuda/locks.py +16 -0
- numba_cuda/numba/cuda/mathimpl.py +62 -57
- numba_cuda/numba/cuda/models.py +1 -5
- numba_cuda/numba/cuda/nvvmutils.py +103 -88
- numba_cuda/numba/cuda/printimpl.py +9 -5
- numba_cuda/numba/cuda/random.py +46 -36
- numba_cuda/numba/cuda/reshape_funcs.cu +1 -1
- numba_cuda/numba/cuda/runtime/__init__.py +1 -1
- numba_cuda/numba/cuda/runtime/memsys.cu +1 -1
- numba_cuda/numba/cuda/runtime/memsys.cuh +1 -1
- numba_cuda/numba/cuda/runtime/nrt.cu +3 -3
- numba_cuda/numba/cuda/runtime/nrt.py +48 -43
- numba_cuda/numba/cuda/simulator/__init__.py +22 -12
- numba_cuda/numba/cuda/simulator/api.py +38 -22
- numba_cuda/numba/cuda/simulator/compiler.py +2 -2
- numba_cuda/numba/cuda/simulator/cudadrv/__init__.py +8 -2
- numba_cuda/numba/cuda/simulator/cudadrv/devicearray.py +63 -55
- numba_cuda/numba/cuda/simulator/cudadrv/devices.py +13 -11
- numba_cuda/numba/cuda/simulator/cudadrv/driver.py +5 -5
- numba_cuda/numba/cuda/simulator/cudadrv/drvapi.py +2 -2
- numba_cuda/numba/cuda/simulator/cudadrv/libs.py +1 -1
- numba_cuda/numba/cuda/simulator/cudadrv/nvvm.py +3 -3
- numba_cuda/numba/cuda/simulator/cudadrv/runtime.py +3 -3
- numba_cuda/numba/cuda/simulator/kernel.py +43 -34
- numba_cuda/numba/cuda/simulator/kernelapi.py +31 -26
- numba_cuda/numba/cuda/simulator/reduction.py +1 -0
- numba_cuda/numba/cuda/simulator/vector_types.py +13 -9
- numba_cuda/numba/cuda/simulator_init.py +2 -4
- numba_cuda/numba/cuda/stubs.py +139 -102
- numba_cuda/numba/cuda/target.py +64 -47
- numba_cuda/numba/cuda/testing.py +24 -19
- numba_cuda/numba/cuda/tests/__init__.py +14 -12
- numba_cuda/numba/cuda/tests/cudadrv/test_array_attr.py +16 -17
- numba_cuda/numba/cuda/tests/cudadrv/test_context_stack.py +7 -7
- numba_cuda/numba/cuda/tests/cudadrv/test_cuda_array_slicing.py +73 -54
- numba_cuda/numba/cuda/tests/cudadrv/test_cuda_auto_context.py +1 -1
- numba_cuda/numba/cuda/tests/cudadrv/test_cuda_devicerecord.py +48 -50
- numba_cuda/numba/cuda/tests/cudadrv/test_cuda_driver.py +47 -29
- numba_cuda/numba/cuda/tests/cudadrv/test_cuda_libraries.py +3 -3
- numba_cuda/numba/cuda/tests/cudadrv/test_cuda_memory.py +19 -19
- numba_cuda/numba/cuda/tests/cudadrv/test_cuda_ndarray.py +108 -103
- numba_cuda/numba/cuda/tests/cudadrv/test_deallocations.py +20 -11
- numba_cuda/numba/cuda/tests/cudadrv/test_detect.py +20 -17
- numba_cuda/numba/cuda/tests/cudadrv/test_emm_plugins.py +8 -6
- numba_cuda/numba/cuda/tests/cudadrv/test_events.py +1 -1
- numba_cuda/numba/cuda/tests/cudadrv/test_host_alloc.py +8 -7
- numba_cuda/numba/cuda/tests/cudadrv/test_init.py +13 -13
- numba_cuda/numba/cuda/tests/cudadrv/test_inline_ptx.py +12 -9
- numba_cuda/numba/cuda/tests/cudadrv/test_linker.py +36 -31
- numba_cuda/numba/cuda/tests/cudadrv/test_managed_alloc.py +8 -7
- numba_cuda/numba/cuda/tests/cudadrv/test_module_callbacks.py +294 -0
- numba_cuda/numba/cuda/tests/cudadrv/test_mvc.py +10 -7
- numba_cuda/numba/cuda/tests/cudadrv/test_nvjitlink.py +24 -15
- numba_cuda/numba/cuda/tests/cudadrv/test_nvvm_driver.py +43 -41
- numba_cuda/numba/cuda/tests/cudadrv/test_pinned.py +4 -5
- numba_cuda/numba/cuda/tests/cudadrv/test_profiler.py +2 -2
- numba_cuda/numba/cuda/tests/cudadrv/test_ptds.py +28 -17
- numba_cuda/numba/cuda/tests/cudadrv/test_reset_device.py +1 -2
- numba_cuda/numba/cuda/tests/cudadrv/test_runtime.py +22 -14
- numba_cuda/numba/cuda/tests/cudadrv/test_select_device.py +1 -1
- numba_cuda/numba/cuda/tests/cudadrv/test_streams.py +4 -3
- numba_cuda/numba/cuda/tests/cudapy/cache_usecases.py +10 -4
- numba_cuda/numba/cuda/tests/cudapy/cache_with_cpu_usecases.py +1 -0
- numba_cuda/numba/cuda/tests/cudapy/extensions_usecases.py +7 -6
- numba_cuda/numba/cuda/tests/cudapy/jitlink.ptx +0 -2
- numba_cuda/numba/cuda/tests/cudapy/recursion_usecases.py +1 -0
- numba_cuda/numba/cuda/tests/cudapy/test_alignment.py +6 -5
- numba_cuda/numba/cuda/tests/cudapy/test_array.py +52 -42
- numba_cuda/numba/cuda/tests/cudapy/test_array_args.py +5 -6
- numba_cuda/numba/cuda/tests/cudapy/test_array_methods.py +1 -1
- numba_cuda/numba/cuda/tests/cudapy/test_atomics.py +501 -304
- numba_cuda/numba/cuda/tests/cudapy/test_blackscholes.py +57 -21
- numba_cuda/numba/cuda/tests/cudapy/test_boolean.py +3 -3
- numba_cuda/numba/cuda/tests/cudapy/test_caching.py +50 -37
- numba_cuda/numba/cuda/tests/cudapy/test_casting.py +29 -24
- numba_cuda/numba/cuda/tests/cudapy/test_cffi.py +11 -6
- numba_cuda/numba/cuda/tests/cudapy/test_compiler.py +84 -50
- numba_cuda/numba/cuda/tests/cudapy/test_complex.py +144 -73
- numba_cuda/numba/cuda/tests/cudapy/test_complex_kernel.py +2 -2
- numba_cuda/numba/cuda/tests/cudapy/test_const_string.py +37 -27
- numba_cuda/numba/cuda/tests/cudapy/test_constmem.py +43 -45
- numba_cuda/numba/cuda/tests/cudapy/test_cooperative_groups.py +21 -14
- numba_cuda/numba/cuda/tests/cudapy/test_cuda_array_interface.py +60 -55
- numba_cuda/numba/cuda/tests/cudapy/test_cuda_jit_no_types.py +3 -2
- numba_cuda/numba/cuda/tests/cudapy/test_datetime.py +26 -22
- numba_cuda/numba/cuda/tests/cudapy/test_debug.py +29 -27
- numba_cuda/numba/cuda/tests/cudapy/test_debuginfo.py +31 -28
- numba_cuda/numba/cuda/tests/cudapy/test_device_func.py +52 -45
- numba_cuda/numba/cuda/tests/cudapy/test_dispatcher.py +55 -43
- numba_cuda/numba/cuda/tests/cudapy/test_enums.py +6 -7
- numba_cuda/numba/cuda/tests/cudapy/test_errors.py +30 -15
- numba_cuda/numba/cuda/tests/cudapy/test_exception.py +11 -12
- numba_cuda/numba/cuda/tests/cudapy/test_extending.py +19 -12
- numba_cuda/numba/cuda/tests/cudapy/test_fastmath.py +77 -66
- numba_cuda/numba/cuda/tests/cudapy/test_forall.py +5 -3
- numba_cuda/numba/cuda/tests/cudapy/test_freevar.py +5 -3
- numba_cuda/numba/cuda/tests/cudapy/test_frexp_ldexp.py +1 -1
- numba_cuda/numba/cuda/tests/cudapy/test_globals.py +3 -5
- numba_cuda/numba/cuda/tests/cudapy/test_gufunc.py +144 -126
- numba_cuda/numba/cuda/tests/cudapy/test_gufunc_scalar.py +23 -18
- numba_cuda/numba/cuda/tests/cudapy/test_gufunc_scheduling.py +16 -22
- numba_cuda/numba/cuda/tests/cudapy/test_idiv.py +1 -3
- numba_cuda/numba/cuda/tests/cudapy/test_inspect.py +29 -20
- numba_cuda/numba/cuda/tests/cudapy/test_intrinsics.py +147 -99
- numba_cuda/numba/cuda/tests/cudapy/test_ipc.py +50 -36
- numba_cuda/numba/cuda/tests/cudapy/test_iterators.py +1 -2
- numba_cuda/numba/cuda/tests/cudapy/test_lang.py +4 -4
- numba_cuda/numba/cuda/tests/cudapy/test_laplace.py +6 -6
- numba_cuda/numba/cuda/tests/cudapy/test_libdevice.py +24 -20
- numba_cuda/numba/cuda/tests/cudapy/test_lineinfo.py +36 -31
- numba_cuda/numba/cuda/tests/cudapy/test_localmem.py +13 -13
- numba_cuda/numba/cuda/tests/cudapy/test_mandel.py +13 -6
- numba_cuda/numba/cuda/tests/cudapy/test_math.py +83 -66
- numba_cuda/numba/cuda/tests/cudapy/test_matmul.py +1 -3
- numba_cuda/numba/cuda/tests/cudapy/test_minmax.py +19 -58
- numba_cuda/numba/cuda/tests/cudapy/test_montecarlo.py +4 -4
- numba_cuda/numba/cuda/tests/cudapy/test_multigpu.py +9 -7
- numba_cuda/numba/cuda/tests/cudapy/test_multiprocessing.py +9 -8
- numba_cuda/numba/cuda/tests/cudapy/test_multithreads.py +12 -10
- numba_cuda/numba/cuda/tests/cudapy/test_nondet.py +1 -1
- numba_cuda/numba/cuda/tests/cudapy/test_operator.py +180 -96
- numba_cuda/numba/cuda/tests/cudapy/test_optimization.py +5 -5
- numba_cuda/numba/cuda/tests/cudapy/test_overload.py +37 -18
- numba_cuda/numba/cuda/tests/cudapy/test_powi.py +7 -7
- numba_cuda/numba/cuda/tests/cudapy/test_print.py +9 -7
- numba_cuda/numba/cuda/tests/cudapy/test_py2_div_issue.py +1 -1
- numba_cuda/numba/cuda/tests/cudapy/test_random.py +15 -10
- numba_cuda/numba/cuda/tests/cudapy/test_record_dtype.py +88 -87
- numba_cuda/numba/cuda/tests/cudapy/test_recursion.py +12 -10
- numba_cuda/numba/cuda/tests/cudapy/test_reduction.py +26 -11
- numba_cuda/numba/cuda/tests/cudapy/test_retrieve_autoconverted_arrays.py +7 -10
- numba_cuda/numba/cuda/tests/cudapy/test_serialize.py +4 -6
- numba_cuda/numba/cuda/tests/cudapy/test_slicing.py +1 -1
- numba_cuda/numba/cuda/tests/cudapy/test_sm.py +10 -9
- numba_cuda/numba/cuda/tests/cudapy/test_sm_creation.py +62 -43
- numba_cuda/numba/cuda/tests/cudapy/test_stream_api.py +7 -3
- numba_cuda/numba/cuda/tests/cudapy/test_sync.py +7 -5
- numba_cuda/numba/cuda/tests/cudapy/test_transpose.py +18 -11
- numba_cuda/numba/cuda/tests/cudapy/test_ufuncs.py +111 -88
- numba_cuda/numba/cuda/tests/cudapy/test_userexc.py +2 -3
- numba_cuda/numba/cuda/tests/cudapy/test_vector_type.py +305 -130
- numba_cuda/numba/cuda/tests/cudapy/test_vectorize.py +33 -36
- numba_cuda/numba/cuda/tests/cudapy/test_vectorize_complex.py +5 -5
- numba_cuda/numba/cuda/tests/cudapy/test_vectorize_decor.py +16 -12
- numba_cuda/numba/cuda/tests/cudapy/test_vectorize_device.py +7 -7
- numba_cuda/numba/cuda/tests/cudapy/test_vectorize_scalar_arg.py +6 -7
- numba_cuda/numba/cuda/tests/cudapy/test_warning.py +31 -29
- numba_cuda/numba/cuda/tests/cudapy/test_warp_ops.py +31 -25
- numba_cuda/numba/cuda/tests/cudasim/test_cudasim_issues.py +19 -13
- numba_cuda/numba/cuda/tests/data/jitlink.cu +1 -1
- numba_cuda/numba/cuda/tests/data/jitlink.ptx +0 -2
- numba_cuda/numba/cuda/tests/doc_examples/test_cg.py +15 -8
- numba_cuda/numba/cuda/tests/doc_examples/test_cpu_gpu_compat.py +4 -7
- numba_cuda/numba/cuda/tests/doc_examples/test_ffi.py +14 -9
- numba_cuda/numba/cuda/tests/doc_examples/test_laplace.py +22 -18
- numba_cuda/numba/cuda/tests/doc_examples/test_matmul.py +7 -4
- numba_cuda/numba/cuda/tests/doc_examples/test_montecarlo.py +2 -0
- numba_cuda/numba/cuda/tests/doc_examples/test_random.py +8 -4
- numba_cuda/numba/cuda/tests/doc_examples/test_reduction.py +2 -1
- numba_cuda/numba/cuda/tests/doc_examples/test_sessionize.py +94 -19
- numba_cuda/numba/cuda/tests/doc_examples/test_vecadd.py +2 -2
- numba_cuda/numba/cuda/tests/nocuda/test_dummyarray.py +91 -62
- numba_cuda/numba/cuda/tests/nocuda/test_function_resolution.py +14 -5
- numba_cuda/numba/cuda/tests/nocuda/test_import.py +25 -25
- numba_cuda/numba/cuda/tests/nocuda/test_library_lookup.py +40 -40
- numba_cuda/numba/cuda/tests/nocuda/test_nvvm.py +12 -10
- numba_cuda/numba/cuda/tests/nrt/test_nrt.py +16 -20
- numba_cuda/numba/cuda/tests/nrt/test_nrt_refct.py +12 -10
- numba_cuda/numba/cuda/tests/test_binary_generation/generate_raw_ltoir.py +2 -2
- numba_cuda/numba/cuda/types.py +5 -2
- numba_cuda/numba/cuda/ufuncs.py +382 -362
- numba_cuda/numba/cuda/utils.py +2 -2
- numba_cuda/numba/cuda/vector_types.py +2 -2
- numba_cuda/numba/cuda/vectorizers.py +37 -32
- {numba_cuda-0.8.0.dist-info → numba_cuda-0.9.0.dist-info}/METADATA +1 -1
- numba_cuda-0.9.0.dist-info/RECORD +253 -0
- {numba_cuda-0.8.0.dist-info → numba_cuda-0.9.0.dist-info}/WHEEL +1 -1
- numba_cuda-0.8.0.dist-info/RECORD +0 -251
- {numba_cuda-0.8.0.dist-info → numba_cuda-0.9.0.dist-info}/licenses/LICENSE +0 -0
- {numba_cuda-0.8.0.dist-info → numba_cuda-0.9.0.dist-info}/top_level.txt +0 -0
@@ -29,48 +29,49 @@ def initialize_dim3(builder, prefix):
|
|
29
29
|
return cgutils.pack_struct(builder, (x, y, z))
|
30
30
|
|
31
31
|
|
32
|
-
@lower_attr(types.Module(cuda),
|
32
|
+
@lower_attr(types.Module(cuda), "threadIdx")
|
33
33
|
def cuda_threadIdx(context, builder, sig, args):
|
34
|
-
return initialize_dim3(builder,
|
34
|
+
return initialize_dim3(builder, "tid")
|
35
35
|
|
36
36
|
|
37
|
-
@lower_attr(types.Module(cuda),
|
37
|
+
@lower_attr(types.Module(cuda), "blockDim")
|
38
38
|
def cuda_blockDim(context, builder, sig, args):
|
39
|
-
return initialize_dim3(builder,
|
39
|
+
return initialize_dim3(builder, "ntid")
|
40
40
|
|
41
41
|
|
42
|
-
@lower_attr(types.Module(cuda),
|
42
|
+
@lower_attr(types.Module(cuda), "blockIdx")
|
43
43
|
def cuda_blockIdx(context, builder, sig, args):
|
44
|
-
return initialize_dim3(builder,
|
44
|
+
return initialize_dim3(builder, "ctaid")
|
45
45
|
|
46
46
|
|
47
|
-
@lower_attr(types.Module(cuda),
|
47
|
+
@lower_attr(types.Module(cuda), "gridDim")
|
48
48
|
def cuda_gridDim(context, builder, sig, args):
|
49
|
-
return initialize_dim3(builder,
|
49
|
+
return initialize_dim3(builder, "nctaid")
|
50
50
|
|
51
51
|
|
52
|
-
@lower_attr(types.Module(cuda),
|
52
|
+
@lower_attr(types.Module(cuda), "laneid")
|
53
53
|
def cuda_laneid(context, builder, sig, args):
|
54
|
-
return nvvmutils.call_sreg(builder,
|
54
|
+
return nvvmutils.call_sreg(builder, "laneid")
|
55
55
|
|
56
56
|
|
57
|
-
@lower_attr(dim3,
|
57
|
+
@lower_attr(dim3, "x")
|
58
58
|
def dim3_x(context, builder, sig, args):
|
59
59
|
return builder.extract_value(args, 0)
|
60
60
|
|
61
61
|
|
62
|
-
@lower_attr(dim3,
|
62
|
+
@lower_attr(dim3, "y")
|
63
63
|
def dim3_y(context, builder, sig, args):
|
64
64
|
return builder.extract_value(args, 1)
|
65
65
|
|
66
66
|
|
67
|
-
@lower_attr(dim3,
|
67
|
+
@lower_attr(dim3, "z")
|
68
68
|
def dim3_z(context, builder, sig, args):
|
69
69
|
return builder.extract_value(args, 2)
|
70
70
|
|
71
71
|
|
72
72
|
# -----------------------------------------------------------------------------
|
73
73
|
|
74
|
+
|
74
75
|
@lower(cuda.const.array_like, types.Array)
|
75
76
|
def cuda_const_array_like(context, builder, sig, args):
|
76
77
|
# This is a no-op because CUDATargetContext.make_constant_array already
|
@@ -95,48 +96,68 @@ def _get_unique_smem_id(name):
|
|
95
96
|
def cuda_shared_array_integer(context, builder, sig, args):
|
96
97
|
length = sig.args[0].literal_value
|
97
98
|
dtype = parse_dtype(sig.args[1])
|
98
|
-
return _generic_array(
|
99
|
-
|
100
|
-
|
101
|
-
|
99
|
+
return _generic_array(
|
100
|
+
context,
|
101
|
+
builder,
|
102
|
+
shape=(length,),
|
103
|
+
dtype=dtype,
|
104
|
+
symbol_name=_get_unique_smem_id("_cudapy_smem"),
|
105
|
+
addrspace=nvvm.ADDRSPACE_SHARED,
|
106
|
+
can_dynsized=True,
|
107
|
+
)
|
102
108
|
|
103
109
|
|
104
110
|
@lower(cuda.shared.array, types.Tuple, types.Any)
|
105
111
|
@lower(cuda.shared.array, types.UniTuple, types.Any)
|
106
112
|
def cuda_shared_array_tuple(context, builder, sig, args):
|
107
|
-
shape = [
|
113
|
+
shape = [s.literal_value for s in sig.args[0]]
|
108
114
|
dtype = parse_dtype(sig.args[1])
|
109
|
-
return _generic_array(
|
110
|
-
|
111
|
-
|
112
|
-
|
115
|
+
return _generic_array(
|
116
|
+
context,
|
117
|
+
builder,
|
118
|
+
shape=shape,
|
119
|
+
dtype=dtype,
|
120
|
+
symbol_name=_get_unique_smem_id("_cudapy_smem"),
|
121
|
+
addrspace=nvvm.ADDRSPACE_SHARED,
|
122
|
+
can_dynsized=True,
|
123
|
+
)
|
113
124
|
|
114
125
|
|
115
126
|
@lower(cuda.local.array, types.IntegerLiteral, types.Any)
|
116
127
|
def cuda_local_array_integer(context, builder, sig, args):
|
117
128
|
length = sig.args[0].literal_value
|
118
129
|
dtype = parse_dtype(sig.args[1])
|
119
|
-
return _generic_array(
|
120
|
-
|
121
|
-
|
122
|
-
|
130
|
+
return _generic_array(
|
131
|
+
context,
|
132
|
+
builder,
|
133
|
+
shape=(length,),
|
134
|
+
dtype=dtype,
|
135
|
+
symbol_name="_cudapy_lmem",
|
136
|
+
addrspace=nvvm.ADDRSPACE_LOCAL,
|
137
|
+
can_dynsized=False,
|
138
|
+
)
|
123
139
|
|
124
140
|
|
125
141
|
@lower(cuda.local.array, types.Tuple, types.Any)
|
126
142
|
@lower(cuda.local.array, types.UniTuple, types.Any)
|
127
143
|
def ptx_lmem_alloc_array(context, builder, sig, args):
|
128
|
-
shape = [
|
144
|
+
shape = [s.literal_value for s in sig.args[0]]
|
129
145
|
dtype = parse_dtype(sig.args[1])
|
130
|
-
return _generic_array(
|
131
|
-
|
132
|
-
|
133
|
-
|
146
|
+
return _generic_array(
|
147
|
+
context,
|
148
|
+
builder,
|
149
|
+
shape=shape,
|
150
|
+
dtype=dtype,
|
151
|
+
symbol_name="_cudapy_lmem",
|
152
|
+
addrspace=nvvm.ADDRSPACE_LOCAL,
|
153
|
+
can_dynsized=False,
|
154
|
+
)
|
134
155
|
|
135
156
|
|
136
157
|
@lower(stubs.threadfence_block)
|
137
158
|
def ptx_threadfence_block(context, builder, sig, args):
|
138
159
|
assert not args
|
139
|
-
fname =
|
160
|
+
fname = "llvm.nvvm.membar.cta"
|
140
161
|
lmod = builder.module
|
141
162
|
fnty = ir.FunctionType(ir.VoidType(), ())
|
142
163
|
sync = cgutils.get_or_insert_function(lmod, fnty, fname)
|
@@ -147,7 +168,7 @@ def ptx_threadfence_block(context, builder, sig, args):
|
|
147
168
|
@lower(stubs.threadfence_system)
|
148
169
|
def ptx_threadfence_system(context, builder, sig, args):
|
149
170
|
assert not args
|
150
|
-
fname =
|
171
|
+
fname = "llvm.nvvm.membar.sys"
|
151
172
|
lmod = builder.module
|
152
173
|
fnty = ir.FunctionType(ir.VoidType(), ())
|
153
174
|
sync = cgutils.get_or_insert_function(lmod, fnty, fname)
|
@@ -158,7 +179,7 @@ def ptx_threadfence_system(context, builder, sig, args):
|
|
158
179
|
@lower(stubs.threadfence)
|
159
180
|
def ptx_threadfence_device(context, builder, sig, args):
|
160
181
|
assert not args
|
161
|
-
fname =
|
182
|
+
fname = "llvm.nvvm.membar.gl"
|
162
183
|
lmod = builder.module
|
163
184
|
fnty = ir.FunctionType(ir.VoidType(), ())
|
164
185
|
sync = cgutils.get_or_insert_function(lmod, fnty, fname)
|
@@ -175,7 +196,7 @@ def ptx_syncwarp(context, builder, sig, args):
|
|
175
196
|
|
176
197
|
@lower(stubs.syncwarp, types.i4)
|
177
198
|
def ptx_syncwarp_mask(context, builder, sig, args):
|
178
|
-
fname =
|
199
|
+
fname = "llvm.nvvm.bar.warp.sync"
|
179
200
|
lmod = builder.module
|
180
201
|
fnty = ir.FunctionType(ir.VoidType(), (ir.IntType(32),))
|
181
202
|
sync = cgutils.get_or_insert_function(lmod, fnty, fname)
|
@@ -183,14 +204,18 @@ def ptx_syncwarp_mask(context, builder, sig, args):
|
|
183
204
|
return context.get_dummy_value()
|
184
205
|
|
185
206
|
|
186
|
-
@lower(
|
187
|
-
|
188
|
-
|
189
|
-
|
190
|
-
|
191
|
-
|
192
|
-
@lower(
|
193
|
-
|
207
|
+
@lower(
|
208
|
+
stubs.shfl_sync_intrinsic, types.i4, types.i4, types.i4, types.i4, types.i4
|
209
|
+
)
|
210
|
+
@lower(
|
211
|
+
stubs.shfl_sync_intrinsic, types.i4, types.i4, types.i8, types.i4, types.i4
|
212
|
+
)
|
213
|
+
@lower(
|
214
|
+
stubs.shfl_sync_intrinsic, types.i4, types.i4, types.f4, types.i4, types.i4
|
215
|
+
)
|
216
|
+
@lower(
|
217
|
+
stubs.shfl_sync_intrinsic, types.i4, types.i4, types.f8, types.i4, types.i4
|
218
|
+
)
|
194
219
|
def ptx_shfl_sync_i32(context, builder, sig, args):
|
195
220
|
"""
|
196
221
|
The NVVM intrinsic for shfl only supports i32, but the cuda intrinsic
|
@@ -203,12 +228,17 @@ def ptx_shfl_sync_i32(context, builder, sig, args):
|
|
203
228
|
value_type = sig.args[2]
|
204
229
|
if value_type in types.real_domain:
|
205
230
|
value = builder.bitcast(value, ir.IntType(value_type.bitwidth))
|
206
|
-
fname =
|
231
|
+
fname = "llvm.nvvm.shfl.sync.i32"
|
207
232
|
lmod = builder.module
|
208
233
|
fnty = ir.FunctionType(
|
209
234
|
ir.LiteralStructType((ir.IntType(32), ir.IntType(1))),
|
210
|
-
|
211
|
-
|
235
|
+
(
|
236
|
+
ir.IntType(32),
|
237
|
+
ir.IntType(32),
|
238
|
+
ir.IntType(32),
|
239
|
+
ir.IntType(32),
|
240
|
+
ir.IntType(32),
|
241
|
+
),
|
212
242
|
)
|
213
243
|
func = cgutils.get_or_insert_function(lmod, fnty, fname)
|
214
244
|
if value_type.bitwidth == 32:
|
@@ -239,11 +269,12 @@ def ptx_shfl_sync_i32(context, builder, sig, args):
|
|
239
269
|
|
240
270
|
@lower(stubs.vote_sync_intrinsic, types.i4, types.i4, types.boolean)
|
241
271
|
def ptx_vote_sync(context, builder, sig, args):
|
242
|
-
fname =
|
272
|
+
fname = "llvm.nvvm.vote.sync"
|
243
273
|
lmod = builder.module
|
244
|
-
fnty = ir.FunctionType(
|
245
|
-
|
246
|
-
|
274
|
+
fnty = ir.FunctionType(
|
275
|
+
ir.LiteralStructType((ir.IntType(32), ir.IntType(1))),
|
276
|
+
(ir.IntType(32), ir.IntType(32), ir.IntType(1)),
|
277
|
+
)
|
247
278
|
func = cgutils.get_or_insert_function(lmod, fnty, fname)
|
248
279
|
return builder.call(func, args)
|
249
280
|
|
@@ -257,7 +288,7 @@ def ptx_match_any_sync(context, builder, sig, args):
|
|
257
288
|
width = sig.args[1].bitwidth
|
258
289
|
if sig.args[1] in types.real_domain:
|
259
290
|
value = builder.bitcast(value, ir.IntType(width))
|
260
|
-
fname =
|
291
|
+
fname = "llvm.nvvm.match.any.sync.i{}".format(width)
|
261
292
|
lmod = builder.module
|
262
293
|
fnty = ir.FunctionType(ir.IntType(32), (ir.IntType(32), ir.IntType(width)))
|
263
294
|
func = cgutils.get_or_insert_function(lmod, fnty, fname)
|
@@ -273,27 +304,35 @@ def ptx_match_all_sync(context, builder, sig, args):
|
|
273
304
|
width = sig.args[1].bitwidth
|
274
305
|
if sig.args[1] in types.real_domain:
|
275
306
|
value = builder.bitcast(value, ir.IntType(width))
|
276
|
-
fname =
|
307
|
+
fname = "llvm.nvvm.match.all.sync.i{}".format(width)
|
277
308
|
lmod = builder.module
|
278
|
-
fnty = ir.FunctionType(
|
279
|
-
|
280
|
-
|
309
|
+
fnty = ir.FunctionType(
|
310
|
+
ir.LiteralStructType((ir.IntType(32), ir.IntType(1))),
|
311
|
+
(ir.IntType(32), ir.IntType(width)),
|
312
|
+
)
|
281
313
|
func = cgutils.get_or_insert_function(lmod, fnty, fname)
|
282
314
|
return builder.call(func, (mask, value))
|
283
315
|
|
284
316
|
|
285
317
|
@lower(stubs.activemask)
|
286
318
|
def ptx_activemask(context, builder, sig, args):
|
287
|
-
activemask = ir.InlineAsm(
|
288
|
-
|
319
|
+
activemask = ir.InlineAsm(
|
320
|
+
ir.FunctionType(ir.IntType(32), []),
|
321
|
+
"activemask.b32 $0;",
|
322
|
+
"=r",
|
323
|
+
side_effect=True,
|
324
|
+
)
|
289
325
|
return builder.call(activemask, [])
|
290
326
|
|
291
327
|
|
292
328
|
@lower(stubs.lanemask_lt)
|
293
329
|
def ptx_lanemask_lt(context, builder, sig, args):
|
294
|
-
activemask = ir.InlineAsm(
|
295
|
-
|
296
|
-
|
330
|
+
activemask = ir.InlineAsm(
|
331
|
+
ir.FunctionType(ir.IntType(32), []),
|
332
|
+
"mov.u32 $0, %lanemask_lt;",
|
333
|
+
"=r",
|
334
|
+
side_effect=True,
|
335
|
+
)
|
297
336
|
return builder.call(activemask, [])
|
298
337
|
|
299
338
|
|
@@ -308,7 +347,7 @@ def ptx_fma(context, builder, sig, args):
|
|
308
347
|
|
309
348
|
|
310
349
|
def float16_float_ty_constraint(bitwidth):
|
311
|
-
typemap = {32: (
|
350
|
+
typemap = {32: ("f32", "f"), 64: ("f64", "d")}
|
312
351
|
|
313
352
|
try:
|
314
353
|
return typemap[bitwidth]
|
@@ -342,7 +381,7 @@ def float_to_float16_cast(context, builder, fromty, toty, val):
|
|
342
381
|
|
343
382
|
|
344
383
|
def float16_int_constraint(bitwidth):
|
345
|
-
typemap = {
|
384
|
+
typemap = {8: "c", 16: "h", 32: "r", 64: "l"}
|
346
385
|
|
347
386
|
try:
|
348
387
|
return typemap[bitwidth]
|
@@ -355,12 +394,12 @@ def float16_int_constraint(bitwidth):
|
|
355
394
|
def float16_to_integer_cast(context, builder, fromty, toty, val):
|
356
395
|
bitwidth = toty.bitwidth
|
357
396
|
constraint = float16_int_constraint(bitwidth)
|
358
|
-
signedness =
|
397
|
+
signedness = "s" if toty.signed else "u"
|
359
398
|
|
360
399
|
fnty = ir.FunctionType(context.get_value_type(toty), [ir.IntType(16)])
|
361
|
-
asm = ir.InlineAsm(
|
362
|
-
|
363
|
-
|
400
|
+
asm = ir.InlineAsm(
|
401
|
+
fnty, f"cvt.rni.{signedness}{bitwidth}.f16 $0, $1;", f"={constraint},h"
|
402
|
+
)
|
364
403
|
return builder.call(asm, [val])
|
365
404
|
|
366
405
|
|
@@ -369,40 +408,38 @@ def float16_to_integer_cast(context, builder, fromty, toty, val):
|
|
369
408
|
def integer_to_float16_cast(context, builder, fromty, toty, val):
|
370
409
|
bitwidth = fromty.bitwidth
|
371
410
|
constraint = float16_int_constraint(bitwidth)
|
372
|
-
signedness =
|
411
|
+
signedness = "s" if fromty.signed else "u"
|
373
412
|
|
374
|
-
fnty = ir.FunctionType(ir.IntType(16),
|
375
|
-
|
376
|
-
|
377
|
-
|
378
|
-
f"=h,{constraint}")
|
413
|
+
fnty = ir.FunctionType(ir.IntType(16), [context.get_value_type(fromty)])
|
414
|
+
asm = ir.InlineAsm(
|
415
|
+
fnty, f"cvt.rn.f16.{signedness}{bitwidth} $0, $1;", f"=h,{constraint}"
|
416
|
+
)
|
379
417
|
return builder.call(asm, [val])
|
380
418
|
|
381
419
|
|
382
420
|
def lower_fp16_binary(fn, op):
|
383
421
|
@lower(fn, types.float16, types.float16)
|
384
422
|
def ptx_fp16_binary(context, builder, sig, args):
|
385
|
-
fnty = ir.FunctionType(ir.IntType(16),
|
386
|
-
|
387
|
-
asm = ir.InlineAsm(fnty, f'{op}.f16 $0,$1,$2;', '=h,h,h')
|
423
|
+
fnty = ir.FunctionType(ir.IntType(16), [ir.IntType(16), ir.IntType(16)])
|
424
|
+
asm = ir.InlineAsm(fnty, f"{op}.f16 $0,$1,$2;", "=h,h,h")
|
388
425
|
return builder.call(asm, args)
|
389
426
|
|
390
427
|
|
391
|
-
lower_fp16_binary(stubs.fp16.hadd,
|
392
|
-
lower_fp16_binary(operator.add,
|
393
|
-
lower_fp16_binary(operator.iadd,
|
394
|
-
lower_fp16_binary(stubs.fp16.hsub,
|
395
|
-
lower_fp16_binary(operator.sub,
|
396
|
-
lower_fp16_binary(operator.isub,
|
397
|
-
lower_fp16_binary(stubs.fp16.hmul,
|
398
|
-
lower_fp16_binary(operator.mul,
|
399
|
-
lower_fp16_binary(operator.imul,
|
428
|
+
lower_fp16_binary(stubs.fp16.hadd, "add")
|
429
|
+
lower_fp16_binary(operator.add, "add")
|
430
|
+
lower_fp16_binary(operator.iadd, "add")
|
431
|
+
lower_fp16_binary(stubs.fp16.hsub, "sub")
|
432
|
+
lower_fp16_binary(operator.sub, "sub")
|
433
|
+
lower_fp16_binary(operator.isub, "sub")
|
434
|
+
lower_fp16_binary(stubs.fp16.hmul, "mul")
|
435
|
+
lower_fp16_binary(operator.mul, "mul")
|
436
|
+
lower_fp16_binary(operator.imul, "mul")
|
400
437
|
|
401
438
|
|
402
439
|
@lower(stubs.fp16.hneg, types.float16)
|
403
440
|
def ptx_fp16_hneg(context, builder, sig, args):
|
404
441
|
fnty = ir.FunctionType(ir.IntType(16), [ir.IntType(16)])
|
405
|
-
asm = ir.InlineAsm(fnty,
|
442
|
+
asm = ir.InlineAsm(fnty, "neg.f16 $0, $1;", "=h,h")
|
406
443
|
return builder.call(asm, args)
|
407
444
|
|
408
445
|
|
@@ -414,7 +451,7 @@ def operator_hneg(context, builder, sig, args):
|
|
414
451
|
@lower(stubs.fp16.habs, types.float16)
|
415
452
|
def ptx_fp16_habs(context, builder, sig, args):
|
416
453
|
fnty = ir.FunctionType(ir.IntType(16), [ir.IntType(16)])
|
417
|
-
asm = ir.InlineAsm(fnty,
|
454
|
+
asm = ir.InlineAsm(fnty, "abs.f16 $0, $1;", "=h,h")
|
418
455
|
return builder.call(asm, args)
|
419
456
|
|
420
457
|
|
@@ -450,27 +487,28 @@ _fp16_cmp = """{{
|
|
450
487
|
def _gen_fp16_cmp(op):
|
451
488
|
def ptx_fp16_comparison(context, builder, sig, args):
|
452
489
|
fnty = ir.FunctionType(ir.IntType(16), [ir.IntType(16), ir.IntType(16)])
|
453
|
-
asm = ir.InlineAsm(fnty, _fp16_cmp.format(op=op),
|
490
|
+
asm = ir.InlineAsm(fnty, _fp16_cmp.format(op=op), "=h,h,h")
|
454
491
|
result = builder.call(asm, args)
|
455
492
|
|
456
493
|
zero = context.get_constant(types.int16, 0)
|
457
494
|
int_result = builder.bitcast(result, ir.IntType(16))
|
458
495
|
return builder.icmp_unsigned("!=", int_result, zero)
|
496
|
+
|
459
497
|
return ptx_fp16_comparison
|
460
498
|
|
461
499
|
|
462
|
-
lower(stubs.fp16.heq, types.float16, types.float16)(_gen_fp16_cmp(
|
463
|
-
lower(operator.eq, types.float16, types.float16)(_gen_fp16_cmp(
|
464
|
-
lower(stubs.fp16.hne, types.float16, types.float16)(_gen_fp16_cmp(
|
465
|
-
lower(operator.ne, types.float16, types.float16)(_gen_fp16_cmp(
|
466
|
-
lower(stubs.fp16.hge, types.float16, types.float16)(_gen_fp16_cmp(
|
467
|
-
lower(operator.ge, types.float16, types.float16)(_gen_fp16_cmp(
|
468
|
-
lower(stubs.fp16.hgt, types.float16, types.float16)(_gen_fp16_cmp(
|
469
|
-
lower(operator.gt, types.float16, types.float16)(_gen_fp16_cmp(
|
470
|
-
lower(stubs.fp16.hle, types.float16, types.float16)(_gen_fp16_cmp(
|
471
|
-
lower(operator.le, types.float16, types.float16)(_gen_fp16_cmp(
|
472
|
-
lower(stubs.fp16.hlt, types.float16, types.float16)(_gen_fp16_cmp(
|
473
|
-
lower(operator.lt, types.float16, types.float16)(_gen_fp16_cmp(
|
500
|
+
lower(stubs.fp16.heq, types.float16, types.float16)(_gen_fp16_cmp("eq"))
|
501
|
+
lower(operator.eq, types.float16, types.float16)(_gen_fp16_cmp("eq"))
|
502
|
+
lower(stubs.fp16.hne, types.float16, types.float16)(_gen_fp16_cmp("ne"))
|
503
|
+
lower(operator.ne, types.float16, types.float16)(_gen_fp16_cmp("ne"))
|
504
|
+
lower(stubs.fp16.hge, types.float16, types.float16)(_gen_fp16_cmp("ge"))
|
505
|
+
lower(operator.ge, types.float16, types.float16)(_gen_fp16_cmp("ge"))
|
506
|
+
lower(stubs.fp16.hgt, types.float16, types.float16)(_gen_fp16_cmp("gt"))
|
507
|
+
lower(operator.gt, types.float16, types.float16)(_gen_fp16_cmp("gt"))
|
508
|
+
lower(stubs.fp16.hle, types.float16, types.float16)(_gen_fp16_cmp("le"))
|
509
|
+
lower(operator.le, types.float16, types.float16)(_gen_fp16_cmp("le"))
|
510
|
+
lower(stubs.fp16.hlt, types.float16, types.float16)(_gen_fp16_cmp("lt"))
|
511
|
+
lower(operator.lt, types.float16, types.float16)(_gen_fp16_cmp("lt"))
|
474
512
|
|
475
513
|
|
476
514
|
def lower_fp16_minmax(fn, fname, op):
|
@@ -480,8 +518,8 @@ def lower_fp16_minmax(fn, fname, op):
|
|
480
518
|
return builder.select(choice, args[0], args[1])
|
481
519
|
|
482
520
|
|
483
|
-
lower_fp16_minmax(stubs.fp16.hmax,
|
484
|
-
lower_fp16_minmax(stubs.fp16.hmin,
|
521
|
+
lower_fp16_minmax(stubs.fp16.hmax, "max", "gt")
|
522
|
+
lower_fp16_minmax(stubs.fp16.hmin, "min", "lt")
|
485
523
|
|
486
524
|
# See:
|
487
525
|
# https://docs.nvidia.com/cuda/libdevice-users-guide/__nv_cbrt.html#__nv_cbrt
|
@@ -489,8 +527,8 @@ lower_fp16_minmax(stubs.fp16.hmin, 'min', 'lt')
|
|
489
527
|
|
490
528
|
|
491
529
|
cbrt_funcs = {
|
492
|
-
types.float32:
|
493
|
-
types.float64:
|
530
|
+
types.float32: "__nv_cbrtf",
|
531
|
+
types.float64: "__nv_cbrt",
|
494
532
|
}
|
495
533
|
|
496
534
|
|
@@ -514,7 +552,8 @@ def ptx_brev_u4(context, builder, sig, args):
|
|
514
552
|
fn = cgutils.get_or_insert_function(
|
515
553
|
builder.module,
|
516
554
|
ir.FunctionType(ir.IntType(32), (ir.IntType(32),)),
|
517
|
-
|
555
|
+
"__nv_brev",
|
556
|
+
)
|
518
557
|
return builder.call(fn, args)
|
519
558
|
|
520
559
|
|
@@ -526,15 +565,14 @@ def ptx_brev_u8(context, builder, sig, args):
|
|
526
565
|
fn = cgutils.get_or_insert_function(
|
527
566
|
builder.module,
|
528
567
|
ir.FunctionType(ir.IntType(64), (ir.IntType(64),)),
|
529
|
-
|
568
|
+
"__nv_brevll",
|
569
|
+
)
|
530
570
|
return builder.call(fn, args)
|
531
571
|
|
532
572
|
|
533
573
|
@lower(stubs.clz, types.Any)
|
534
574
|
def ptx_clz(context, builder, sig, args):
|
535
|
-
return builder.ctlz(
|
536
|
-
args[0],
|
537
|
-
context.get_constant(types.boolean, 0))
|
575
|
+
return builder.ctlz(args[0], context.get_constant(types.boolean, 0))
|
538
576
|
|
539
577
|
|
540
578
|
@lower(stubs.ffs, types.i4)
|
@@ -543,7 +581,8 @@ def ptx_ffs_32(context, builder, sig, args):
|
|
543
581
|
fn = cgutils.get_or_insert_function(
|
544
582
|
builder.module,
|
545
583
|
ir.FunctionType(ir.IntType(32), (ir.IntType(32),)),
|
546
|
-
|
584
|
+
"__nv_ffs",
|
585
|
+
)
|
547
586
|
return builder.call(fn, args)
|
548
587
|
|
549
588
|
|
@@ -553,7 +592,8 @@ def ptx_ffs_64(context, builder, sig, args):
|
|
553
592
|
fn = cgutils.get_or_insert_function(
|
554
593
|
builder.module,
|
555
594
|
ir.FunctionType(ir.IntType(32), (ir.IntType(64),)),
|
556
|
-
|
595
|
+
"__nv_ffsll",
|
596
|
+
)
|
557
597
|
return builder.call(fn, args)
|
558
598
|
|
559
599
|
|
@@ -567,10 +607,9 @@ def ptx_selp(context, builder, sig, args):
|
|
567
607
|
def ptx_max_f4(context, builder, sig, args):
|
568
608
|
fn = cgutils.get_or_insert_function(
|
569
609
|
builder.module,
|
570
|
-
ir.FunctionType(
|
571
|
-
|
572
|
-
|
573
|
-
'__nv_fmaxf')
|
610
|
+
ir.FunctionType(ir.FloatType(), (ir.FloatType(), ir.FloatType())),
|
611
|
+
"__nv_fmaxf",
|
612
|
+
)
|
574
613
|
return builder.call(fn, args)
|
575
614
|
|
576
615
|
|
@@ -580,25 +619,26 @@ def ptx_max_f4(context, builder, sig, args):
|
|
580
619
|
def ptx_max_f8(context, builder, sig, args):
|
581
620
|
fn = cgutils.get_or_insert_function(
|
582
621
|
builder.module,
|
583
|
-
ir.FunctionType(
|
584
|
-
|
585
|
-
|
586
|
-
'__nv_fmax')
|
622
|
+
ir.FunctionType(ir.DoubleType(), (ir.DoubleType(), ir.DoubleType())),
|
623
|
+
"__nv_fmax",
|
624
|
+
)
|
587
625
|
|
588
|
-
return builder.call(
|
589
|
-
|
590
|
-
|
591
|
-
|
626
|
+
return builder.call(
|
627
|
+
fn,
|
628
|
+
[
|
629
|
+
context.cast(builder, args[0], sig.args[0], types.double),
|
630
|
+
context.cast(builder, args[1], sig.args[1], types.double),
|
631
|
+
],
|
632
|
+
)
|
592
633
|
|
593
634
|
|
594
635
|
@lower(min, types.f4, types.f4)
|
595
636
|
def ptx_min_f4(context, builder, sig, args):
|
596
637
|
fn = cgutils.get_or_insert_function(
|
597
638
|
builder.module,
|
598
|
-
ir.FunctionType(
|
599
|
-
|
600
|
-
|
601
|
-
'__nv_fminf')
|
639
|
+
ir.FunctionType(ir.FloatType(), (ir.FloatType(), ir.FloatType())),
|
640
|
+
"__nv_fminf",
|
641
|
+
)
|
602
642
|
return builder.call(fn, args)
|
603
643
|
|
604
644
|
|
@@ -608,15 +648,17 @@ def ptx_min_f4(context, builder, sig, args):
|
|
608
648
|
def ptx_min_f8(context, builder, sig, args):
|
609
649
|
fn = cgutils.get_or_insert_function(
|
610
650
|
builder.module,
|
611
|
-
ir.FunctionType(
|
612
|
-
|
613
|
-
|
614
|
-
'__nv_fmin')
|
651
|
+
ir.FunctionType(ir.DoubleType(), (ir.DoubleType(), ir.DoubleType())),
|
652
|
+
"__nv_fmin",
|
653
|
+
)
|
615
654
|
|
616
|
-
return builder.call(
|
617
|
-
|
618
|
-
|
619
|
-
|
655
|
+
return builder.call(
|
656
|
+
fn,
|
657
|
+
[
|
658
|
+
context.cast(builder, args[0], sig.args[0], types.double),
|
659
|
+
context.cast(builder, args[1], sig.args[1], types.double),
|
660
|
+
],
|
661
|
+
)
|
620
662
|
|
621
663
|
|
622
664
|
@lower(round, types.f4)
|
@@ -624,19 +666,22 @@ def ptx_min_f8(context, builder, sig, args):
|
|
624
666
|
def ptx_round(context, builder, sig, args):
|
625
667
|
fn = cgutils.get_or_insert_function(
|
626
668
|
builder.module,
|
627
|
-
ir.FunctionType(
|
628
|
-
|
629
|
-
|
630
|
-
|
631
|
-
|
632
|
-
|
633
|
-
|
669
|
+
ir.FunctionType(ir.IntType(64), (ir.DoubleType(),)),
|
670
|
+
"__nv_llrint",
|
671
|
+
)
|
672
|
+
return builder.call(
|
673
|
+
fn,
|
674
|
+
[
|
675
|
+
context.cast(builder, args[0], sig.args[0], types.double),
|
676
|
+
],
|
677
|
+
)
|
634
678
|
|
635
679
|
|
636
680
|
# This rounding implementation follows the algorithm used in the "fallback
|
637
681
|
# version" of double_round in CPython.
|
638
682
|
# https://github.com/python/cpython/blob/a755410e054e1e2390de5830befc08fe80706c66/Objects/floatobject.c#L964-L1007
|
639
683
|
|
684
|
+
|
640
685
|
@lower(round, types.f4, types.Integer)
|
641
686
|
@lower(round, types.f8, types.Integer)
|
642
687
|
def round_to_impl(context, builder, sig, args):
|
@@ -651,7 +696,7 @@ def round_to_impl(context, builder, sig, args):
|
|
651
696
|
pow1 = 10.0 ** (ndigits - 22)
|
652
697
|
pow2 = 1e22
|
653
698
|
else:
|
654
|
-
pow1 = 10.0
|
699
|
+
pow1 = 10.0**ndigits
|
655
700
|
pow2 = 1.0
|
656
701
|
y = (x * pow1) * pow2
|
657
702
|
if math.isinf(y):
|
@@ -662,7 +707,7 @@ def round_to_impl(context, builder, sig, args):
|
|
662
707
|
y = x / pow1
|
663
708
|
|
664
709
|
z = round(y)
|
665
|
-
if
|
710
|
+
if math.fabs(y - z) == 0.5:
|
666
711
|
# halfway between two integers; use round-half-even
|
667
712
|
z = 2.0 * round(y / 2.0)
|
668
713
|
|
@@ -673,19 +718,25 @@ def round_to_impl(context, builder, sig, args):
|
|
673
718
|
|
674
719
|
return z
|
675
720
|
|
676
|
-
return context.compile_internal(
|
721
|
+
return context.compile_internal(
|
722
|
+
builder,
|
723
|
+
round_ndigits,
|
724
|
+
sig,
|
725
|
+
args,
|
726
|
+
)
|
677
727
|
|
678
728
|
|
679
729
|
def gen_deg_rad(const):
|
680
730
|
def impl(context, builder, sig, args):
|
681
|
-
argty, = sig.args
|
731
|
+
(argty,) = sig.args
|
682
732
|
factor = context.get_constant(argty, const)
|
683
733
|
return builder.fmul(factor, args[0])
|
734
|
+
|
684
735
|
return impl
|
685
736
|
|
686
737
|
|
687
|
-
_deg2rad = math.pi / 180.
|
688
|
-
_rad2deg = 180. / math.pi
|
738
|
+
_deg2rad = math.pi / 180.0
|
739
|
+
_rad2deg = 180.0 / math.pi
|
689
740
|
lower(math.radians, types.f4)(gen_deg_rad(_deg2rad))
|
690
741
|
lower(math.radians, types.f8)(gen_deg_rad(_deg2rad))
|
691
742
|
lower(math.degrees, types.f4)(gen_deg_rad(_rad2deg))
|
@@ -701,16 +752,18 @@ def _normalize_indices(context, builder, indty, inds, aryty, valty):
|
|
701
752
|
indices = [inds]
|
702
753
|
else:
|
703
754
|
indices = cgutils.unpack_tuple(builder, inds, count=len(indty))
|
704
|
-
indices = [
|
705
|
-
|
755
|
+
indices = [
|
756
|
+
context.cast(builder, i, t, types.intp) for t, i in zip(indty, indices)
|
757
|
+
]
|
706
758
|
|
707
759
|
dtype = aryty.dtype
|
708
760
|
if dtype != valty:
|
709
761
|
raise TypeError("expect %s but got %s" % (dtype, valty))
|
710
762
|
|
711
763
|
if aryty.ndim != len(indty):
|
712
|
-
raise TypeError(
|
713
|
-
|
764
|
+
raise TypeError(
|
765
|
+
"indexing %d-D array with %d-D index" % (aryty.ndim, len(indty))
|
766
|
+
)
|
714
767
|
|
715
768
|
return indty, indices
|
716
769
|
|
@@ -722,14 +775,17 @@ def _atomic_dispatcher(dispatch_fn):
|
|
722
775
|
ary, inds, val = args
|
723
776
|
dtype = aryty.dtype
|
724
777
|
|
725
|
-
indty, indices = _normalize_indices(
|
726
|
-
|
778
|
+
indty, indices = _normalize_indices(
|
779
|
+
context, builder, indty, inds, aryty, valty
|
780
|
+
)
|
727
781
|
|
728
782
|
lary = context.make_array(aryty)(context, builder, ary)
|
729
|
-
ptr = cgutils.get_item_pointer(
|
730
|
-
|
783
|
+
ptr = cgutils.get_item_pointer(
|
784
|
+
context, builder, aryty, lary, indices, wraparound=True
|
785
|
+
)
|
731
786
|
# dispatcher to implementation base on dtype
|
732
787
|
return dispatch_fn(context, builder, dtype, ptr, val)
|
788
|
+
|
733
789
|
return imp
|
734
790
|
|
735
791
|
|
@@ -740,14 +796,16 @@ def _atomic_dispatcher(dispatch_fn):
|
|
740
796
|
def ptx_atomic_add_tuple(context, builder, dtype, ptr, val):
|
741
797
|
if dtype == types.float32:
|
742
798
|
lmod = builder.module
|
743
|
-
return builder.call(
|
744
|
-
|
799
|
+
return builder.call(
|
800
|
+
nvvmutils.declare_atomic_add_float32(lmod), (ptr, val)
|
801
|
+
)
|
745
802
|
elif dtype == types.float64:
|
746
803
|
lmod = builder.module
|
747
|
-
return builder.call(
|
748
|
-
|
804
|
+
return builder.call(
|
805
|
+
nvvmutils.declare_atomic_add_float64(lmod), (ptr, val)
|
806
|
+
)
|
749
807
|
else:
|
750
|
-
return builder.atomic_rmw(
|
808
|
+
return builder.atomic_rmw("add", ptr, val, "monotonic")
|
751
809
|
|
752
810
|
|
753
811
|
@lower(stubs.atomic.sub, types.Array, types.intp, types.Any)
|
@@ -757,14 +815,16 @@ def ptx_atomic_add_tuple(context, builder, dtype, ptr, val):
|
|
757
815
|
def ptx_atomic_sub(context, builder, dtype, ptr, val):
|
758
816
|
if dtype == types.float32:
|
759
817
|
lmod = builder.module
|
760
|
-
return builder.call(
|
761
|
-
|
818
|
+
return builder.call(
|
819
|
+
nvvmutils.declare_atomic_sub_float32(lmod), (ptr, val)
|
820
|
+
)
|
762
821
|
elif dtype == types.float64:
|
763
822
|
lmod = builder.module
|
764
|
-
return builder.call(
|
765
|
-
|
823
|
+
return builder.call(
|
824
|
+
nvvmutils.declare_atomic_sub_float64(lmod), (ptr, val)
|
825
|
+
)
|
766
826
|
else:
|
767
|
-
return builder.atomic_rmw(
|
827
|
+
return builder.atomic_rmw("sub", ptr, val, "monotonic")
|
768
828
|
|
769
829
|
|
770
830
|
@lower(stubs.atomic.inc, types.Array, types.intp, types.Any)
|
@@ -775,10 +835,10 @@ def ptx_atomic_inc(context, builder, dtype, ptr, val):
|
|
775
835
|
if dtype in cuda.cudadecl.unsigned_int_numba_types:
|
776
836
|
bw = dtype.bitwidth
|
777
837
|
lmod = builder.module
|
778
|
-
fn = getattr(nvvmutils, f
|
838
|
+
fn = getattr(nvvmutils, f"declare_atomic_inc_int{bw}")
|
779
839
|
return builder.call(fn(lmod), (ptr, val))
|
780
840
|
else:
|
781
|
-
raise TypeError(f
|
841
|
+
raise TypeError(f"Unimplemented atomic inc with {dtype} array")
|
782
842
|
|
783
843
|
|
784
844
|
@lower(stubs.atomic.dec, types.Array, types.intp, types.Any)
|
@@ -789,27 +849,27 @@ def ptx_atomic_dec(context, builder, dtype, ptr, val):
|
|
789
849
|
if dtype in cuda.cudadecl.unsigned_int_numba_types:
|
790
850
|
bw = dtype.bitwidth
|
791
851
|
lmod = builder.module
|
792
|
-
fn = getattr(nvvmutils, f
|
852
|
+
fn = getattr(nvvmutils, f"declare_atomic_dec_int{bw}")
|
793
853
|
return builder.call(fn(lmod), (ptr, val))
|
794
854
|
else:
|
795
|
-
raise TypeError(f
|
855
|
+
raise TypeError(f"Unimplemented atomic dec with {dtype} array")
|
796
856
|
|
797
857
|
|
798
858
|
def ptx_atomic_bitwise(stub, op):
|
799
859
|
@_atomic_dispatcher
|
800
860
|
def impl_ptx_atomic(context, builder, dtype, ptr, val):
|
801
861
|
if dtype in (cuda.cudadecl.integer_numba_types):
|
802
|
-
return builder.atomic_rmw(op, ptr, val,
|
862
|
+
return builder.atomic_rmw(op, ptr, val, "monotonic")
|
803
863
|
else:
|
804
|
-
raise TypeError(f
|
864
|
+
raise TypeError(f"Unimplemented atomic {op} with {dtype} array")
|
805
865
|
|
806
866
|
for ty in (types.intp, types.UniTuple, types.Tuple):
|
807
867
|
lower(stub, types.Array, ty, types.Any)(impl_ptx_atomic)
|
808
868
|
|
809
869
|
|
810
|
-
ptx_atomic_bitwise(stubs.atomic.and_,
|
811
|
-
ptx_atomic_bitwise(stubs.atomic.or_,
|
812
|
-
ptx_atomic_bitwise(stubs.atomic.xor,
|
870
|
+
ptx_atomic_bitwise(stubs.atomic.and_, "and")
|
871
|
+
ptx_atomic_bitwise(stubs.atomic.or_, "or")
|
872
|
+
ptx_atomic_bitwise(stubs.atomic.xor, "xor")
|
813
873
|
|
814
874
|
|
815
875
|
@lower(stubs.atomic.exch, types.Array, types.intp, types.Any)
|
@@ -818,9 +878,9 @@ ptx_atomic_bitwise(stubs.atomic.xor, 'xor')
|
|
818
878
|
@_atomic_dispatcher
|
819
879
|
def ptx_atomic_exch(context, builder, dtype, ptr, val):
|
820
880
|
if dtype in (cuda.cudadecl.integer_numba_types):
|
821
|
-
return builder.atomic_rmw(
|
881
|
+
return builder.atomic_rmw("xchg", ptr, val, "monotonic")
|
822
882
|
else:
|
823
|
-
raise TypeError(f
|
883
|
+
raise TypeError(f"Unimplemented atomic exch with {dtype} array")
|
824
884
|
|
825
885
|
|
826
886
|
@lower(stubs.atomic.max, types.Array, types.intp, types.Any)
|
@@ -830,17 +890,19 @@ def ptx_atomic_exch(context, builder, dtype, ptr, val):
|
|
830
890
|
def ptx_atomic_max(context, builder, dtype, ptr, val):
|
831
891
|
lmod = builder.module
|
832
892
|
if dtype == types.float64:
|
833
|
-
return builder.call(
|
834
|
-
|
893
|
+
return builder.call(
|
894
|
+
nvvmutils.declare_atomic_max_float64(lmod), (ptr, val)
|
895
|
+
)
|
835
896
|
elif dtype == types.float32:
|
836
|
-
return builder.call(
|
837
|
-
|
897
|
+
return builder.call(
|
898
|
+
nvvmutils.declare_atomic_max_float32(lmod), (ptr, val)
|
899
|
+
)
|
838
900
|
elif dtype in (types.int32, types.int64):
|
839
|
-
return builder.atomic_rmw(
|
901
|
+
return builder.atomic_rmw("max", ptr, val, ordering="monotonic")
|
840
902
|
elif dtype in (types.uint32, types.uint64):
|
841
|
-
return builder.atomic_rmw(
|
903
|
+
return builder.atomic_rmw("umax", ptr, val, ordering="monotonic")
|
842
904
|
else:
|
843
|
-
raise TypeError(
|
905
|
+
raise TypeError("Unimplemented atomic max with %s array" % dtype)
|
844
906
|
|
845
907
|
|
846
908
|
@lower(stubs.atomic.min, types.Array, types.intp, types.Any)
|
@@ -850,17 +912,19 @@ def ptx_atomic_max(context, builder, dtype, ptr, val):
|
|
850
912
|
def ptx_atomic_min(context, builder, dtype, ptr, val):
|
851
913
|
lmod = builder.module
|
852
914
|
if dtype == types.float64:
|
853
|
-
return builder.call(
|
854
|
-
|
915
|
+
return builder.call(
|
916
|
+
nvvmutils.declare_atomic_min_float64(lmod), (ptr, val)
|
917
|
+
)
|
855
918
|
elif dtype == types.float32:
|
856
|
-
return builder.call(
|
857
|
-
|
919
|
+
return builder.call(
|
920
|
+
nvvmutils.declare_atomic_min_float32(lmod), (ptr, val)
|
921
|
+
)
|
858
922
|
elif dtype in (types.int32, types.int64):
|
859
|
-
return builder.atomic_rmw(
|
923
|
+
return builder.atomic_rmw("min", ptr, val, ordering="monotonic")
|
860
924
|
elif dtype in (types.uint32, types.uint64):
|
861
|
-
return builder.atomic_rmw(
|
925
|
+
return builder.atomic_rmw("umin", ptr, val, ordering="monotonic")
|
862
926
|
else:
|
863
|
-
raise TypeError(
|
927
|
+
raise TypeError("Unimplemented atomic min with %s array" % dtype)
|
864
928
|
|
865
929
|
|
866
930
|
@lower(stubs.atomic.nanmax, types.Array, types.intp, types.Any)
|
@@ -870,17 +934,19 @@ def ptx_atomic_min(context, builder, dtype, ptr, val):
|
|
870
934
|
def ptx_atomic_nanmax(context, builder, dtype, ptr, val):
|
871
935
|
lmod = builder.module
|
872
936
|
if dtype == types.float64:
|
873
|
-
return builder.call(
|
874
|
-
|
937
|
+
return builder.call(
|
938
|
+
nvvmutils.declare_atomic_nanmax_float64(lmod), (ptr, val)
|
939
|
+
)
|
875
940
|
elif dtype == types.float32:
|
876
|
-
return builder.call(
|
877
|
-
|
941
|
+
return builder.call(
|
942
|
+
nvvmutils.declare_atomic_nanmax_float32(lmod), (ptr, val)
|
943
|
+
)
|
878
944
|
elif dtype in (types.int32, types.int64):
|
879
|
-
return builder.atomic_rmw(
|
945
|
+
return builder.atomic_rmw("max", ptr, val, ordering="monotonic")
|
880
946
|
elif dtype in (types.uint32, types.uint64):
|
881
|
-
return builder.atomic_rmw(
|
947
|
+
return builder.atomic_rmw("umax", ptr, val, ordering="monotonic")
|
882
948
|
else:
|
883
|
-
raise TypeError(
|
949
|
+
raise TypeError("Unimplemented atomic max with %s array" % dtype)
|
884
950
|
|
885
951
|
|
886
952
|
@lower(stubs.atomic.nanmin, types.Array, types.intp, types.Any)
|
@@ -890,17 +956,19 @@ def ptx_atomic_nanmax(context, builder, dtype, ptr, val):
|
|
890
956
|
def ptx_atomic_nanmin(context, builder, dtype, ptr, val):
|
891
957
|
lmod = builder.module
|
892
958
|
if dtype == types.float64:
|
893
|
-
return builder.call(
|
894
|
-
|
959
|
+
return builder.call(
|
960
|
+
nvvmutils.declare_atomic_nanmin_float64(lmod), (ptr, val)
|
961
|
+
)
|
895
962
|
elif dtype == types.float32:
|
896
|
-
return builder.call(
|
897
|
-
|
963
|
+
return builder.call(
|
964
|
+
nvvmutils.declare_atomic_nanmin_float32(lmod), (ptr, val)
|
965
|
+
)
|
898
966
|
elif dtype in (types.int32, types.int64):
|
899
|
-
return builder.atomic_rmw(
|
967
|
+
return builder.atomic_rmw("min", ptr, val, ordering="monotonic")
|
900
968
|
elif dtype in (types.uint32, types.uint64):
|
901
|
-
return builder.atomic_rmw(
|
969
|
+
return builder.atomic_rmw("umin", ptr, val, ordering="monotonic")
|
902
970
|
else:
|
903
|
-
raise TypeError(
|
971
|
+
raise TypeError("Unimplemented atomic min with %s array" % dtype)
|
904
972
|
|
905
973
|
|
906
974
|
@lower(stubs.atomic.compare_and_swap, types.Array, types.Any, types.Any)
|
@@ -917,19 +985,21 @@ def ptx_atomic_cas(context, builder, sig, args):
|
|
917
985
|
aryty, indty, oldty, valty = sig.args
|
918
986
|
ary, inds, old, val = args
|
919
987
|
|
920
|
-
indty, indices = _normalize_indices(
|
921
|
-
|
988
|
+
indty, indices = _normalize_indices(
|
989
|
+
context, builder, indty, inds, aryty, valty
|
990
|
+
)
|
922
991
|
|
923
992
|
lary = context.make_array(aryty)(context, builder, ary)
|
924
|
-
ptr = cgutils.get_item_pointer(
|
925
|
-
|
993
|
+
ptr = cgutils.get_item_pointer(
|
994
|
+
context, builder, aryty, lary, indices, wraparound=True
|
995
|
+
)
|
926
996
|
|
927
997
|
if aryty.dtype in (cuda.cudadecl.integer_numba_types):
|
928
998
|
lmod = builder.module
|
929
999
|
bitwidth = aryty.dtype.bitwidth
|
930
1000
|
return nvvmutils.atomic_cmpxchg(builder, lmod, bitwidth, ptr, old, val)
|
931
1001
|
else:
|
932
|
-
raise TypeError(
|
1002
|
+
raise TypeError("Unimplemented atomic cas with %s array" % aryty.dtype)
|
933
1003
|
|
934
1004
|
|
935
1005
|
# -----------------------------------------------------------------------------
|
@@ -937,15 +1007,20 @@ def ptx_atomic_cas(context, builder, sig, args):
|
|
937
1007
|
|
938
1008
|
@lower(breakpoint)
|
939
1009
|
def ptx_brkpt(context, builder, sig, args):
|
940
|
-
brkpt = ir.InlineAsm(
|
941
|
-
|
1010
|
+
brkpt = ir.InlineAsm(
|
1011
|
+
ir.FunctionType(ir.VoidType(), []), "brkpt;", "", side_effect=True
|
1012
|
+
)
|
942
1013
|
builder.call(brkpt, ())
|
943
1014
|
|
944
1015
|
|
945
1016
|
@lower(stubs.nanosleep, types.uint32)
|
946
1017
|
def ptx_nanosleep(context, builder, sig, args):
|
947
|
-
nanosleep = ir.InlineAsm(
|
948
|
-
|
1018
|
+
nanosleep = ir.InlineAsm(
|
1019
|
+
ir.FunctionType(ir.VoidType(), [ir.IntType(32)]),
|
1020
|
+
"nanosleep.u32 $0;",
|
1021
|
+
"r",
|
1022
|
+
side_effect=True,
|
1023
|
+
)
|
949
1024
|
ns = args[0]
|
950
1025
|
builder.call(nanosleep, [ns])
|
951
1026
|
|
@@ -953,8 +1028,9 @@ def ptx_nanosleep(context, builder, sig, args):
|
|
953
1028
|
# -----------------------------------------------------------------------------
|
954
1029
|
|
955
1030
|
|
956
|
-
def _generic_array(
|
957
|
-
|
1031
|
+
def _generic_array(
|
1032
|
+
context, builder, shape, dtype, symbol_name, addrspace, can_dynsized=False
|
1033
|
+
):
|
958
1034
|
elemcount = reduce(operator.mul, shape, 1)
|
959
1035
|
|
960
1036
|
# Check for valid shape for this type of allocation.
|
@@ -985,16 +1061,17 @@ def _generic_array(context, builder, shape, dtype, symbol_name, addrspace,
|
|
985
1061
|
lmod = builder.module
|
986
1062
|
|
987
1063
|
# Create global variable in the requested address space
|
988
|
-
gvmem = cgutils.add_global_variable(
|
989
|
-
|
1064
|
+
gvmem = cgutils.add_global_variable(
|
1065
|
+
lmod, laryty, symbol_name, addrspace
|
1066
|
+
)
|
990
1067
|
# Specify alignment to avoid misalignment bug
|
991
1068
|
align = context.get_abi_sizeof(lldtype)
|
992
1069
|
# Alignment is required to be a power of 2 for shared memory. If it is
|
993
1070
|
# not a power of 2 (e.g. for a Record array) then round up accordingly.
|
994
|
-
gvmem.align = 1 << (align - 1
|
1071
|
+
gvmem.align = 1 << (align - 1).bit_length()
|
995
1072
|
|
996
1073
|
if dynamic_smem:
|
997
|
-
gvmem.linkage =
|
1074
|
+
gvmem.linkage = "external"
|
998
1075
|
else:
|
999
1076
|
## Comment out the following line to workaround a NVVM bug
|
1000
1077
|
## which generates a invalid symbol name when the linkage
|
@@ -1005,8 +1082,9 @@ def _generic_array(context, builder, shape, dtype, symbol_name, addrspace,
|
|
1005
1082
|
gvmem.initializer = ir.Constant(laryty, ir.Undefined)
|
1006
1083
|
|
1007
1084
|
# Convert to generic address-space
|
1008
|
-
dataptr = builder.addrspacecast(
|
1009
|
-
|
1085
|
+
dataptr = builder.addrspacecast(
|
1086
|
+
gvmem, ir.PointerType(ir.IntType(8)), "generic"
|
1087
|
+
)
|
1010
1088
|
|
1011
1089
|
targetdata = ll.create_target_data(nvvm.NVVM().data_layout)
|
1012
1090
|
lldtype = context.get_data_type(dtype)
|
@@ -1027,11 +1105,15 @@ def _generic_array(context, builder, shape, dtype, symbol_name, addrspace,
|
|
1027
1105
|
# Unfortunately NVVM does not provide an intrinsic for the
|
1028
1106
|
# %dynamic_smem_size register, so we must read it using inline
|
1029
1107
|
# assembly.
|
1030
|
-
get_dynshared_size = ir.InlineAsm(
|
1031
|
-
|
1032
|
-
|
1033
|
-
|
1034
|
-
|
1108
|
+
get_dynshared_size = ir.InlineAsm(
|
1109
|
+
ir.FunctionType(ir.IntType(32), []),
|
1110
|
+
"mov.u32 $0, %dynamic_smem_size;",
|
1111
|
+
"=r",
|
1112
|
+
side_effect=True,
|
1113
|
+
)
|
1114
|
+
dynsmem_size = builder.zext(
|
1115
|
+
builder.call(get_dynshared_size, []), ir.IntType(64)
|
1116
|
+
)
|
1035
1117
|
# Only 1-D dynamic shared memory is supported so the following is a
|
1036
1118
|
# sufficient construction of the shape
|
1037
1119
|
kitemsize = context.get_constant(types.intp, itemsize)
|
@@ -1041,15 +1123,17 @@ def _generic_array(context, builder, shape, dtype, symbol_name, addrspace,
|
|
1041
1123
|
|
1042
1124
|
# Create array object
|
1043
1125
|
ndim = len(shape)
|
1044
|
-
aryty = types.Array(dtype=dtype, ndim=ndim, layout=
|
1126
|
+
aryty = types.Array(dtype=dtype, ndim=ndim, layout="C")
|
1045
1127
|
ary = context.make_array(aryty)(context, builder)
|
1046
1128
|
|
1047
|
-
context.populate_array(
|
1048
|
-
|
1049
|
-
|
1050
|
-
|
1051
|
-
|
1052
|
-
|
1129
|
+
context.populate_array(
|
1130
|
+
ary,
|
1131
|
+
data=builder.bitcast(dataptr, ary.data.type),
|
1132
|
+
shape=kshape,
|
1133
|
+
strides=kstrides,
|
1134
|
+
itemsize=context.get_constant(types.intp, itemsize),
|
1135
|
+
meminfo=None,
|
1136
|
+
)
|
1053
1137
|
return ary._getvalue()
|
1054
1138
|
|
1055
1139
|
|