numba-cuda 0.8.1__py3-none-any.whl → 0.10.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (237)
  1. _numba_cuda_redirector.py +17 -13
  2. numba_cuda/VERSION +1 -1
  3. numba_cuda/_version.py +4 -1
  4. numba_cuda/numba/cuda/__init__.py +6 -2
  5. numba_cuda/numba/cuda/api.py +129 -86
  6. numba_cuda/numba/cuda/api_util.py +3 -3
  7. numba_cuda/numba/cuda/args.py +12 -16
  8. numba_cuda/numba/cuda/cg.py +6 -6
  9. numba_cuda/numba/cuda/codegen.py +74 -43
  10. numba_cuda/numba/cuda/compiler.py +246 -114
  11. numba_cuda/numba/cuda/cpp_function_wrappers.cu +1 -2
  12. numba_cuda/numba/cuda/cuda_bf16.py +5155 -0
  13. numba_cuda/numba/cuda/cuda_paths.py +293 -99
  14. numba_cuda/numba/cuda/cudadecl.py +93 -79
  15. numba_cuda/numba/cuda/cudadrv/__init__.py +3 -1
  16. numba_cuda/numba/cuda/cudadrv/devicearray.py +185 -135
  17. numba_cuda/numba/cuda/cudadrv/devices.py +16 -11
  18. numba_cuda/numba/cuda/cudadrv/driver.py +460 -297
  19. numba_cuda/numba/cuda/cudadrv/drvapi.py +241 -207
  20. numba_cuda/numba/cuda/cudadrv/dummyarray.py +66 -54
  21. numba_cuda/numba/cuda/cudadrv/enums.py +1 -1
  22. numba_cuda/numba/cuda/cudadrv/error.py +6 -2
  23. numba_cuda/numba/cuda/cudadrv/libs.py +67 -63
  24. numba_cuda/numba/cuda/cudadrv/linkable_code.py +27 -3
  25. numba_cuda/numba/cuda/cudadrv/mappings.py +16 -14
  26. numba_cuda/numba/cuda/cudadrv/nvrtc.py +146 -30
  27. numba_cuda/numba/cuda/cudadrv/nvvm.py +296 -161
  28. numba_cuda/numba/cuda/cudadrv/rtapi.py +1 -1
  29. numba_cuda/numba/cuda/cudadrv/runtime.py +20 -8
  30. numba_cuda/numba/cuda/cudaimpl.py +296 -275
  31. numba_cuda/numba/cuda/cudamath.py +1 -1
  32. numba_cuda/numba/cuda/debuginfo.py +99 -7
  33. numba_cuda/numba/cuda/decorators.py +87 -45
  34. numba_cuda/numba/cuda/descriptor.py +1 -1
  35. numba_cuda/numba/cuda/device_init.py +68 -18
  36. numba_cuda/numba/cuda/deviceufunc.py +143 -98
  37. numba_cuda/numba/cuda/dispatcher.py +300 -213
  38. numba_cuda/numba/cuda/errors.py +13 -10
  39. numba_cuda/numba/cuda/extending.py +55 -1
  40. numba_cuda/numba/cuda/include/11/cuda_bf16.h +3749 -0
  41. numba_cuda/numba/cuda/include/11/cuda_bf16.hpp +2683 -0
  42. numba_cuda/numba/cuda/{cuda_fp16.h → include/11/cuda_fp16.h} +1090 -927
  43. numba_cuda/numba/cuda/{cuda_fp16.hpp → include/11/cuda_fp16.hpp} +468 -319
  44. numba_cuda/numba/cuda/include/12/cuda_bf16.h +5118 -0
  45. numba_cuda/numba/cuda/include/12/cuda_bf16.hpp +3865 -0
  46. numba_cuda/numba/cuda/include/12/cuda_fp16.h +5363 -0
  47. numba_cuda/numba/cuda/include/12/cuda_fp16.hpp +3483 -0
  48. numba_cuda/numba/cuda/initialize.py +5 -3
  49. numba_cuda/numba/cuda/intrinsic_wrapper.py +0 -39
  50. numba_cuda/numba/cuda/intrinsics.py +203 -28
  51. numba_cuda/numba/cuda/kernels/reduction.py +13 -13
  52. numba_cuda/numba/cuda/kernels/transpose.py +3 -6
  53. numba_cuda/numba/cuda/libdevice.py +317 -317
  54. numba_cuda/numba/cuda/libdeviceimpl.py +3 -2
  55. numba_cuda/numba/cuda/locks.py +16 -0
  56. numba_cuda/numba/cuda/lowering.py +43 -0
  57. numba_cuda/numba/cuda/mathimpl.py +62 -57
  58. numba_cuda/numba/cuda/models.py +1 -5
  59. numba_cuda/numba/cuda/nvvmutils.py +103 -88
  60. numba_cuda/numba/cuda/printimpl.py +9 -5
  61. numba_cuda/numba/cuda/random.py +46 -36
  62. numba_cuda/numba/cuda/reshape_funcs.cu +1 -1
  63. numba_cuda/numba/cuda/runtime/__init__.py +1 -1
  64. numba_cuda/numba/cuda/runtime/memsys.cu +1 -1
  65. numba_cuda/numba/cuda/runtime/memsys.cuh +1 -1
  66. numba_cuda/numba/cuda/runtime/nrt.cu +3 -3
  67. numba_cuda/numba/cuda/runtime/nrt.py +48 -43
  68. numba_cuda/numba/cuda/simulator/__init__.py +22 -12
  69. numba_cuda/numba/cuda/simulator/api.py +38 -22
  70. numba_cuda/numba/cuda/simulator/compiler.py +2 -2
  71. numba_cuda/numba/cuda/simulator/cudadrv/__init__.py +8 -2
  72. numba_cuda/numba/cuda/simulator/cudadrv/devicearray.py +63 -55
  73. numba_cuda/numba/cuda/simulator/cudadrv/devices.py +13 -11
  74. numba_cuda/numba/cuda/simulator/cudadrv/driver.py +5 -5
  75. numba_cuda/numba/cuda/simulator/cudadrv/drvapi.py +2 -2
  76. numba_cuda/numba/cuda/simulator/cudadrv/libs.py +1 -1
  77. numba_cuda/numba/cuda/simulator/cudadrv/nvvm.py +3 -3
  78. numba_cuda/numba/cuda/simulator/cudadrv/runtime.py +3 -3
  79. numba_cuda/numba/cuda/simulator/kernel.py +43 -34
  80. numba_cuda/numba/cuda/simulator/kernelapi.py +31 -26
  81. numba_cuda/numba/cuda/simulator/reduction.py +1 -0
  82. numba_cuda/numba/cuda/simulator/vector_types.py +13 -9
  83. numba_cuda/numba/cuda/simulator_init.py +2 -4
  84. numba_cuda/numba/cuda/stubs.py +134 -108
  85. numba_cuda/numba/cuda/target.py +92 -47
  86. numba_cuda/numba/cuda/testing.py +24 -19
  87. numba_cuda/numba/cuda/tests/__init__.py +14 -12
  88. numba_cuda/numba/cuda/tests/cudadrv/test_array_attr.py +16 -17
  89. numba_cuda/numba/cuda/tests/cudadrv/test_context_stack.py +7 -7
  90. numba_cuda/numba/cuda/tests/cudadrv/test_cuda_array_slicing.py +73 -54
  91. numba_cuda/numba/cuda/tests/cudadrv/test_cuda_auto_context.py +1 -1
  92. numba_cuda/numba/cuda/tests/cudadrv/test_cuda_devicerecord.py +48 -50
  93. numba_cuda/numba/cuda/tests/cudadrv/test_cuda_driver.py +47 -29
  94. numba_cuda/numba/cuda/tests/cudadrv/test_cuda_libraries.py +3 -3
  95. numba_cuda/numba/cuda/tests/cudadrv/test_cuda_memory.py +19 -19
  96. numba_cuda/numba/cuda/tests/cudadrv/test_cuda_ndarray.py +108 -103
  97. numba_cuda/numba/cuda/tests/cudadrv/test_deallocations.py +20 -11
  98. numba_cuda/numba/cuda/tests/cudadrv/test_detect.py +20 -17
  99. numba_cuda/numba/cuda/tests/cudadrv/test_emm_plugins.py +8 -6
  100. numba_cuda/numba/cuda/tests/cudadrv/test_events.py +1 -1
  101. numba_cuda/numba/cuda/tests/cudadrv/test_host_alloc.py +8 -7
  102. numba_cuda/numba/cuda/tests/cudadrv/test_init.py +13 -13
  103. numba_cuda/numba/cuda/tests/cudadrv/test_inline_ptx.py +12 -9
  104. numba_cuda/numba/cuda/tests/cudadrv/test_linker.py +36 -31
  105. numba_cuda/numba/cuda/tests/cudadrv/test_managed_alloc.py +8 -7
  106. numba_cuda/numba/cuda/tests/cudadrv/test_module_callbacks.py +294 -0
  107. numba_cuda/numba/cuda/tests/cudadrv/test_mvc.py +10 -7
  108. numba_cuda/numba/cuda/tests/cudadrv/test_nvjitlink.py +24 -15
  109. numba_cuda/numba/cuda/tests/cudadrv/test_nvvm_driver.py +43 -41
  110. numba_cuda/numba/cuda/tests/cudadrv/test_pinned.py +4 -5
  111. numba_cuda/numba/cuda/tests/cudadrv/test_profiler.py +2 -2
  112. numba_cuda/numba/cuda/tests/cudadrv/test_ptds.py +28 -17
  113. numba_cuda/numba/cuda/tests/cudadrv/test_reset_device.py +1 -2
  114. numba_cuda/numba/cuda/tests/cudadrv/test_runtime.py +22 -14
  115. numba_cuda/numba/cuda/tests/cudadrv/test_select_device.py +1 -1
  116. numba_cuda/numba/cuda/tests/cudadrv/test_streams.py +4 -3
  117. numba_cuda/numba/cuda/tests/cudapy/cache_usecases.py +10 -4
  118. numba_cuda/numba/cuda/tests/cudapy/cache_with_cpu_usecases.py +1 -0
  119. numba_cuda/numba/cuda/tests/cudapy/extensions_usecases.py +10 -7
  120. numba_cuda/numba/cuda/tests/cudapy/jitlink.ptx +0 -2
  121. numba_cuda/numba/cuda/tests/cudapy/recursion_usecases.py +1 -0
  122. numba_cuda/numba/cuda/tests/cudapy/test_alignment.py +6 -5
  123. numba_cuda/numba/cuda/tests/cudapy/test_array.py +52 -42
  124. numba_cuda/numba/cuda/tests/cudapy/test_array_args.py +5 -6
  125. numba_cuda/numba/cuda/tests/cudapy/test_array_methods.py +1 -1
  126. numba_cuda/numba/cuda/tests/cudapy/test_atomics.py +501 -304
  127. numba_cuda/numba/cuda/tests/cudapy/test_bfloat16_bindings.py +257 -0
  128. numba_cuda/numba/cuda/tests/cudapy/test_blackscholes.py +59 -23
  129. numba_cuda/numba/cuda/tests/cudapy/test_boolean.py +3 -3
  130. numba_cuda/numba/cuda/tests/cudapy/test_caching.py +50 -37
  131. numba_cuda/numba/cuda/tests/cudapy/test_casting.py +29 -24
  132. numba_cuda/numba/cuda/tests/cudapy/test_cffi.py +11 -6
  133. numba_cuda/numba/cuda/tests/cudapy/test_compiler.py +84 -50
  134. numba_cuda/numba/cuda/tests/cudapy/test_complex.py +144 -73
  135. numba_cuda/numba/cuda/tests/cudapy/test_complex_kernel.py +2 -2
  136. numba_cuda/numba/cuda/tests/cudapy/test_const_string.py +37 -27
  137. numba_cuda/numba/cuda/tests/cudapy/test_constmem.py +43 -45
  138. numba_cuda/numba/cuda/tests/cudapy/test_cooperative_groups.py +21 -14
  139. numba_cuda/numba/cuda/tests/cudapy/test_cuda_array_interface.py +60 -55
  140. numba_cuda/numba/cuda/tests/cudapy/test_cuda_jit_no_types.py +3 -2
  141. numba_cuda/numba/cuda/tests/cudapy/test_datetime.py +26 -22
  142. numba_cuda/numba/cuda/tests/cudapy/test_debug.py +29 -27
  143. numba_cuda/numba/cuda/tests/cudapy/test_debuginfo.py +77 -28
  144. numba_cuda/numba/cuda/tests/cudapy/test_device_func.py +52 -45
  145. numba_cuda/numba/cuda/tests/cudapy/test_dispatcher.py +55 -43
  146. numba_cuda/numba/cuda/tests/cudapy/test_enums.py +24 -7
  147. numba_cuda/numba/cuda/tests/cudapy/test_errors.py +30 -15
  148. numba_cuda/numba/cuda/tests/cudapy/test_exception.py +11 -12
  149. numba_cuda/numba/cuda/tests/cudapy/test_extending.py +21 -12
  150. numba_cuda/numba/cuda/tests/cudapy/test_fastmath.py +77 -66
  151. numba_cuda/numba/cuda/tests/cudapy/test_forall.py +5 -3
  152. numba_cuda/numba/cuda/tests/cudapy/test_freevar.py +5 -3
  153. numba_cuda/numba/cuda/tests/cudapy/test_frexp_ldexp.py +1 -1
  154. numba_cuda/numba/cuda/tests/cudapy/test_globals.py +3 -5
  155. numba_cuda/numba/cuda/tests/cudapy/test_gufunc.py +144 -126
  156. numba_cuda/numba/cuda/tests/cudapy/test_gufunc_scalar.py +23 -18
  157. numba_cuda/numba/cuda/tests/cudapy/test_gufunc_scheduling.py +16 -22
  158. numba_cuda/numba/cuda/tests/cudapy/test_idiv.py +1 -3
  159. numba_cuda/numba/cuda/tests/cudapy/test_inline.py +59 -0
  160. numba_cuda/numba/cuda/tests/cudapy/test_inspect.py +29 -20
  161. numba_cuda/numba/cuda/tests/cudapy/test_intrinsics.py +147 -99
  162. numba_cuda/numba/cuda/tests/cudapy/test_ipc.py +50 -36
  163. numba_cuda/numba/cuda/tests/cudapy/test_iterators.py +1 -2
  164. numba_cuda/numba/cuda/tests/cudapy/test_lang.py +4 -4
  165. numba_cuda/numba/cuda/tests/cudapy/test_laplace.py +7 -7
  166. numba_cuda/numba/cuda/tests/cudapy/test_libdevice.py +24 -20
  167. numba_cuda/numba/cuda/tests/cudapy/test_lineinfo.py +36 -31
  168. numba_cuda/numba/cuda/tests/cudapy/test_localmem.py +13 -13
  169. numba_cuda/numba/cuda/tests/cudapy/test_mandel.py +13 -6
  170. numba_cuda/numba/cuda/tests/cudapy/test_math.py +83 -66
  171. numba_cuda/numba/cuda/tests/cudapy/test_matmul.py +1 -3
  172. numba_cuda/numba/cuda/tests/cudapy/test_minmax.py +19 -58
  173. numba_cuda/numba/cuda/tests/cudapy/test_montecarlo.py +4 -4
  174. numba_cuda/numba/cuda/tests/cudapy/test_multigpu.py +9 -7
  175. numba_cuda/numba/cuda/tests/cudapy/test_multiprocessing.py +9 -8
  176. numba_cuda/numba/cuda/tests/cudapy/test_multithreads.py +12 -10
  177. numba_cuda/numba/cuda/tests/cudapy/test_nondet.py +1 -1
  178. numba_cuda/numba/cuda/tests/cudapy/test_operator.py +180 -96
  179. numba_cuda/numba/cuda/tests/cudapy/test_optimization.py +5 -5
  180. numba_cuda/numba/cuda/tests/cudapy/test_overload.py +37 -18
  181. numba_cuda/numba/cuda/tests/cudapy/test_powi.py +7 -7
  182. numba_cuda/numba/cuda/tests/cudapy/test_print.py +9 -7
  183. numba_cuda/numba/cuda/tests/cudapy/test_py2_div_issue.py +1 -1
  184. numba_cuda/numba/cuda/tests/cudapy/test_random.py +15 -10
  185. numba_cuda/numba/cuda/tests/cudapy/test_record_dtype.py +88 -87
  186. numba_cuda/numba/cuda/tests/cudapy/test_recursion.py +12 -10
  187. numba_cuda/numba/cuda/tests/cudapy/test_reduction.py +26 -11
  188. numba_cuda/numba/cuda/tests/cudapy/test_retrieve_autoconverted_arrays.py +7 -10
  189. numba_cuda/numba/cuda/tests/cudapy/test_serialize.py +4 -6
  190. numba_cuda/numba/cuda/tests/cudapy/test_slicing.py +1 -1
  191. numba_cuda/numba/cuda/tests/cudapy/test_sm.py +10 -9
  192. numba_cuda/numba/cuda/tests/cudapy/test_sm_creation.py +62 -43
  193. numba_cuda/numba/cuda/tests/cudapy/test_stream_api.py +7 -3
  194. numba_cuda/numba/cuda/tests/cudapy/test_sync.py +7 -5
  195. numba_cuda/numba/cuda/tests/cudapy/test_transpose.py +18 -11
  196. numba_cuda/numba/cuda/tests/cudapy/test_ufuncs.py +111 -88
  197. numba_cuda/numba/cuda/tests/cudapy/test_userexc.py +2 -3
  198. numba_cuda/numba/cuda/tests/cudapy/test_vector_type.py +305 -130
  199. numba_cuda/numba/cuda/tests/cudapy/test_vectorize.py +33 -36
  200. numba_cuda/numba/cuda/tests/cudapy/test_vectorize_complex.py +5 -5
  201. numba_cuda/numba/cuda/tests/cudapy/test_vectorize_decor.py +16 -12
  202. numba_cuda/numba/cuda/tests/cudapy/test_vectorize_device.py +7 -7
  203. numba_cuda/numba/cuda/tests/cudapy/test_vectorize_scalar_arg.py +6 -7
  204. numba_cuda/numba/cuda/tests/cudapy/test_warning.py +31 -29
  205. numba_cuda/numba/cuda/tests/cudapy/test_warp_ops.py +81 -30
  206. numba_cuda/numba/cuda/tests/cudasim/test_cudasim_issues.py +19 -13
  207. numba_cuda/numba/cuda/tests/data/jitlink.cu +1 -1
  208. numba_cuda/numba/cuda/tests/data/jitlink.ptx +0 -2
  209. numba_cuda/numba/cuda/tests/doc_examples/test_cg.py +15 -8
  210. numba_cuda/numba/cuda/tests/doc_examples/test_cpu_gpu_compat.py +4 -7
  211. numba_cuda/numba/cuda/tests/doc_examples/test_ffi.py +14 -9
  212. numba_cuda/numba/cuda/tests/doc_examples/test_laplace.py +22 -18
  213. numba_cuda/numba/cuda/tests/doc_examples/test_matmul.py +7 -4
  214. numba_cuda/numba/cuda/tests/doc_examples/test_montecarlo.py +2 -0
  215. numba_cuda/numba/cuda/tests/doc_examples/test_random.py +8 -4
  216. numba_cuda/numba/cuda/tests/doc_examples/test_reduction.py +2 -1
  217. numba_cuda/numba/cuda/tests/doc_examples/test_sessionize.py +94 -19
  218. numba_cuda/numba/cuda/tests/doc_examples/test_vecadd.py +2 -2
  219. numba_cuda/numba/cuda/tests/nocuda/test_dummyarray.py +91 -62
  220. numba_cuda/numba/cuda/tests/nocuda/test_function_resolution.py +14 -5
  221. numba_cuda/numba/cuda/tests/nocuda/test_import.py +25 -25
  222. numba_cuda/numba/cuda/tests/nocuda/test_library_lookup.py +40 -40
  223. numba_cuda/numba/cuda/tests/nocuda/test_nvvm.py +12 -10
  224. numba_cuda/numba/cuda/tests/nrt/test_nrt.py +16 -20
  225. numba_cuda/numba/cuda/tests/nrt/test_nrt_refct.py +12 -10
  226. numba_cuda/numba/cuda/tests/test_binary_generation/generate_raw_ltoir.py +2 -2
  227. numba_cuda/numba/cuda/types.py +5 -2
  228. numba_cuda/numba/cuda/ufuncs.py +382 -362
  229. numba_cuda/numba/cuda/utils.py +2 -2
  230. numba_cuda/numba/cuda/vector_types.py +5 -3
  231. numba_cuda/numba/cuda/vectorizers.py +38 -33
  232. {numba_cuda-0.8.1.dist-info → numba_cuda-0.10.0.dist-info}/METADATA +1 -1
  233. numba_cuda-0.10.0.dist-info/RECORD +263 -0
  234. {numba_cuda-0.8.1.dist-info → numba_cuda-0.10.0.dist-info}/WHEEL +1 -1
  235. numba_cuda-0.8.1.dist-info/RECORD +0 -251
  236. {numba_cuda-0.8.1.dist-info → numba_cuda-0.10.0.dist-info}/licenses/LICENSE +0 -0
  237. {numba_cuda-0.8.1.dist-info → numba_cuda-0.10.0.dist-info}/top_level.txt +0 -0
@@ -29,48 +29,49 @@ def initialize_dim3(builder, prefix):
     return cgutils.pack_struct(builder, (x, y, z))


-@lower_attr(types.Module(cuda), 'threadIdx')
+@lower_attr(types.Module(cuda), "threadIdx")
 def cuda_threadIdx(context, builder, sig, args):
-    return initialize_dim3(builder, 'tid')
+    return initialize_dim3(builder, "tid")


-@lower_attr(types.Module(cuda), 'blockDim')
+@lower_attr(types.Module(cuda), "blockDim")
 def cuda_blockDim(context, builder, sig, args):
-    return initialize_dim3(builder, 'ntid')
+    return initialize_dim3(builder, "ntid")


-@lower_attr(types.Module(cuda), 'blockIdx')
+@lower_attr(types.Module(cuda), "blockIdx")
 def cuda_blockIdx(context, builder, sig, args):
-    return initialize_dim3(builder, 'ctaid')
+    return initialize_dim3(builder, "ctaid")


-@lower_attr(types.Module(cuda), 'gridDim')
+@lower_attr(types.Module(cuda), "gridDim")
 def cuda_gridDim(context, builder, sig, args):
-    return initialize_dim3(builder, 'nctaid')
+    return initialize_dim3(builder, "nctaid")


-@lower_attr(types.Module(cuda), 'laneid')
+@lower_attr(types.Module(cuda), "laneid")
 def cuda_laneid(context, builder, sig, args):
-    return nvvmutils.call_sreg(builder, 'laneid')
+    return nvvmutils.call_sreg(builder, "laneid")


-@lower_attr(dim3, 'x')
+@lower_attr(dim3, "x")
 def dim3_x(context, builder, sig, args):
     return builder.extract_value(args, 0)


-@lower_attr(dim3, 'y')
+@lower_attr(dim3, "y")
 def dim3_y(context, builder, sig, args):
     return builder.extract_value(args, 1)


-@lower_attr(dim3, 'z')
+@lower_attr(dim3, "z")
 def dim3_z(context, builder, sig, args):
     return builder.extract_value(args, 2)


 # -----------------------------------------------------------------------------

+
 @lower(cuda.const.array_like, types.Array)
 def cuda_const_array_like(context, builder, sig, args):
     # This is a no-op because CUDATargetContext.make_constant_array already
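Note: the dim3 attributes above are lowered by reading NVVM special registers. initialize_dim3 reads the x, y and z members of one register family ("tid", "ntid", "ctaid", "nctaid") and packs them into a three-element struct with cgutils.pack_struct, which is why the dim3_x/y/z lowerings are plain extract_value calls. A minimal sketch of the underlying intrinsic call, using only llvmlite (read_sreg is a hypothetical helper, not part of this module):

    from llvmlite import ir

    def read_sreg(builder, family, axis):
        # e.g. family="tid", axis="x" reads the %tid.x PTX special register
        name = f"llvm.nvvm.read.ptx.sreg.{family}.{axis}"
        fnty = ir.FunctionType(ir.IntType(32), ())
        try:
            fn = builder.module.get_global(name)
        except KeyError:
            fn = ir.Function(builder.module, fnty, name)
        return builder.call(fn, ())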
@@ -95,48 +96,68 @@ def _get_unique_smem_id(name):
 def cuda_shared_array_integer(context, builder, sig, args):
     length = sig.args[0].literal_value
     dtype = parse_dtype(sig.args[1])
-    return _generic_array(context, builder, shape=(length,), dtype=dtype,
-                          symbol_name=_get_unique_smem_id('_cudapy_smem'),
-                          addrspace=nvvm.ADDRSPACE_SHARED,
-                          can_dynsized=True)
+    return _generic_array(
+        context,
+        builder,
+        shape=(length,),
+        dtype=dtype,
+        symbol_name=_get_unique_smem_id("_cudapy_smem"),
+        addrspace=nvvm.ADDRSPACE_SHARED,
+        can_dynsized=True,
+    )


 @lower(cuda.shared.array, types.Tuple, types.Any)
 @lower(cuda.shared.array, types.UniTuple, types.Any)
 def cuda_shared_array_tuple(context, builder, sig, args):
-    shape = [ s.literal_value for s in sig.args[0] ]
+    shape = [s.literal_value for s in sig.args[0]]
     dtype = parse_dtype(sig.args[1])
-    return _generic_array(context, builder, shape=shape, dtype=dtype,
-                          symbol_name=_get_unique_smem_id('_cudapy_smem'),
-                          addrspace=nvvm.ADDRSPACE_SHARED,
-                          can_dynsized=True)
+    return _generic_array(
+        context,
+        builder,
+        shape=shape,
+        dtype=dtype,
+        symbol_name=_get_unique_smem_id("_cudapy_smem"),
+        addrspace=nvvm.ADDRSPACE_SHARED,
+        can_dynsized=True,
+    )


 @lower(cuda.local.array, types.IntegerLiteral, types.Any)
 def cuda_local_array_integer(context, builder, sig, args):
     length = sig.args[0].literal_value
     dtype = parse_dtype(sig.args[1])
-    return _generic_array(context, builder, shape=(length,), dtype=dtype,
-                          symbol_name='_cudapy_lmem',
-                          addrspace=nvvm.ADDRSPACE_LOCAL,
-                          can_dynsized=False)
+    return _generic_array(
+        context,
+        builder,
+        shape=(length,),
+        dtype=dtype,
+        symbol_name="_cudapy_lmem",
+        addrspace=nvvm.ADDRSPACE_LOCAL,
+        can_dynsized=False,
+    )


 @lower(cuda.local.array, types.Tuple, types.Any)
 @lower(cuda.local.array, types.UniTuple, types.Any)
 def ptx_lmem_alloc_array(context, builder, sig, args):
-    shape = [ s.literal_value for s in sig.args[0] ]
+    shape = [s.literal_value for s in sig.args[0]]
     dtype = parse_dtype(sig.args[1])
-    return _generic_array(context, builder, shape=shape, dtype=dtype,
-                          symbol_name='_cudapy_lmem',
-                          addrspace=nvvm.ADDRSPACE_LOCAL,
-                          can_dynsized=False)
+    return _generic_array(
+        context,
+        builder,
+        shape=shape,
+        dtype=dtype,
+        symbol_name="_cudapy_lmem",
+        addrspace=nvvm.ADDRSPACE_LOCAL,
+        can_dynsized=False,
+    )


 @lower(stubs.threadfence_block)
 def ptx_threadfence_block(context, builder, sig, args):
     assert not args
-    fname = 'llvm.nvvm.membar.cta'
+    fname = "llvm.nvvm.membar.cta"
     lmod = builder.module
     fnty = ir.FunctionType(ir.VoidType(), ())
     sync = cgutils.get_or_insert_function(lmod, fnty, fname)
@@ -147,7 +168,7 @@ def ptx_threadfence_block(context, builder, sig, args):
 @lower(stubs.threadfence_system)
 def ptx_threadfence_system(context, builder, sig, args):
     assert not args
-    fname = 'llvm.nvvm.membar.sys'
+    fname = "llvm.nvvm.membar.sys"
     lmod = builder.module
     fnty = ir.FunctionType(ir.VoidType(), ())
     sync = cgutils.get_or_insert_function(lmod, fnty, fname)
@@ -158,7 +179,7 @@ def ptx_threadfence_system(context, builder, sig, args):
 @lower(stubs.threadfence)
 def ptx_threadfence_device(context, builder, sig, args):
     assert not args
-    fname = 'llvm.nvvm.membar.gl'
+    fname = "llvm.nvvm.membar.gl"
     lmod = builder.module
     fnty = ir.FunctionType(ir.VoidType(), ())
     sync = cgutils.get_or_insert_function(lmod, fnty, fname)
@@ -175,7 +196,7 @@ def ptx_syncwarp(context, builder, sig, args):
 @lower(stubs.syncwarp, types.i4)
 def ptx_syncwarp_mask(context, builder, sig, args):
-    fname = 'llvm.nvvm.bar.warp.sync'
+    fname = "llvm.nvvm.bar.warp.sync"
     lmod = builder.module
     fnty = ir.FunctionType(ir.VoidType(), (ir.IntType(32),))
     sync = cgutils.get_or_insert_function(lmod, fnty, fname)
@@ -183,68 +204,15 @@ def ptx_syncwarp_mask(context, builder, sig, args):
     return context.get_dummy_value()


-@lower(stubs.shfl_sync_intrinsic, types.i4, types.i4, types.i4, types.i4,
-       types.i4)
-@lower(stubs.shfl_sync_intrinsic, types.i4, types.i4, types.i8, types.i4,
-       types.i4)
-@lower(stubs.shfl_sync_intrinsic, types.i4, types.i4, types.f4, types.i4,
-       types.i4)
-@lower(stubs.shfl_sync_intrinsic, types.i4, types.i4, types.f8, types.i4,
-       types.i4)
-def ptx_shfl_sync_i32(context, builder, sig, args):
-    """
-    The NVVM intrinsic for shfl only supports i32, but the cuda intrinsic
-    function supports both 32 and 64 bit ints and floats, so for feature parity,
-    i64, f32, and f64 are implemented. Floats by way of bitcasting the float to
-    an int, then shuffling, then bitcasting back. And 64-bit values by packing
-    them into 2 32bit values, shuffling those, and then packing back together.
-    """
-    mask, mode, value, index, clamp = args
-    value_type = sig.args[2]
-    if value_type in types.real_domain:
-        value = builder.bitcast(value, ir.IntType(value_type.bitwidth))
-    fname = 'llvm.nvvm.shfl.sync.i32'
+@lower(stubs.vote_sync_intrinsic, types.i4, types.i4, types.boolean)
+def ptx_vote_sync(context, builder, sig, args):
+    fname = "llvm.nvvm.vote.sync"
     lmod = builder.module
     fnty = ir.FunctionType(
         ir.LiteralStructType((ir.IntType(32), ir.IntType(1))),
-        (ir.IntType(32), ir.IntType(32), ir.IntType(32),
-         ir.IntType(32), ir.IntType(32))
+        (ir.IntType(32), ir.IntType(32), ir.IntType(1)),
     )
     func = cgutils.get_or_insert_function(lmod, fnty, fname)
-    if value_type.bitwidth == 32:
-        ret = builder.call(func, (mask, mode, value, index, clamp))
-        if value_type == types.float32:
-            rv = builder.extract_value(ret, 0)
-            pred = builder.extract_value(ret, 1)
-            fv = builder.bitcast(rv, ir.FloatType())
-            ret = cgutils.make_anonymous_struct(builder, (fv, pred))
-    else:
-        value1 = builder.trunc(value, ir.IntType(32))
-        value_lshr = builder.lshr(value, context.get_constant(types.i8, 32))
-        value2 = builder.trunc(value_lshr, ir.IntType(32))
-        ret1 = builder.call(func, (mask, mode, value1, index, clamp))
-        ret2 = builder.call(func, (mask, mode, value2, index, clamp))
-        rv1 = builder.extract_value(ret1, 0)
-        rv2 = builder.extract_value(ret2, 0)
-        pred = builder.extract_value(ret1, 1)
-        rv1_64 = builder.zext(rv1, ir.IntType(64))
-        rv2_64 = builder.zext(rv2, ir.IntType(64))
-        rv_shl = builder.shl(rv2_64, context.get_constant(types.i8, 32))
-        rv = builder.or_(rv_shl, rv1_64)
-        if value_type == types.float64:
-            rv = builder.bitcast(rv, ir.DoubleType())
-        ret = cgutils.make_anonymous_struct(builder, (rv, pred))
-    return ret
-
-
-@lower(stubs.vote_sync_intrinsic, types.i4, types.i4, types.boolean)
-def ptx_vote_sync(context, builder, sig, args):
-    fname = 'llvm.nvvm.vote.sync'
-    lmod = builder.module
-    fnty = ir.FunctionType(ir.LiteralStructType((ir.IntType(32),
-                                                 ir.IntType(1))),
-                           (ir.IntType(32), ir.IntType(32), ir.IntType(1)))
-    func = cgutils.get_or_insert_function(lmod, fnty, fname)
     return builder.call(func, args)

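Note: the ptx_shfl_sync_i32 lowering removed above documented a technique worth keeping in mind: the NVVM shuffle intrinsic is i32-only, so wider and floating-point types are emulated. A pure-Python sketch of that strategy, where shfl32 stands in for the 32-bit hardware shuffle:

    import struct

    def shfl64(shfl32, value):
        # Split the 64-bit payload into two 32-bit halves, shuffle each
        # half independently, then reassemble. Both halves use the same
        # lane index, so the result comes from a single source lane.
        lo = value & 0xFFFFFFFF
        hi = (value >> 32) & 0xFFFFFFFF
        return (shfl32(hi) << 32) | shfl32(lo)

    def shfl_f64(shfl32, value):
        # Floats ride along by bitcasting to a same-width integer,
        # shuffling, then bitcasting back.
        bits = struct.unpack("<Q", struct.pack("<d", value))[0]
        return struct.unpack("<d", struct.pack("<Q", shfl64(shfl32, bits)))[0]

    identity = lambda x: x  # a "shuffle" that reads the calling lane
    assert shfl_f64(identity, 3.5) == 3.5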
 
@@ -257,7 +225,7 @@ def ptx_match_any_sync(context, builder, sig, args):
     width = sig.args[1].bitwidth
     if sig.args[1] in types.real_domain:
         value = builder.bitcast(value, ir.IntType(width))
-    fname = 'llvm.nvvm.match.any.sync.i{}'.format(width)
+    fname = "llvm.nvvm.match.any.sync.i{}".format(width)
     lmod = builder.module
     fnty = ir.FunctionType(ir.IntType(32), (ir.IntType(32), ir.IntType(width)))
     func = cgutils.get_or_insert_function(lmod, fnty, fname)
@@ -273,27 +241,35 @@ def ptx_match_all_sync(context, builder, sig, args):
     width = sig.args[1].bitwidth
     if sig.args[1] in types.real_domain:
         value = builder.bitcast(value, ir.IntType(width))
-    fname = 'llvm.nvvm.match.all.sync.i{}'.format(width)
+    fname = "llvm.nvvm.match.all.sync.i{}".format(width)
     lmod = builder.module
-    fnty = ir.FunctionType(ir.LiteralStructType((ir.IntType(32),
-                                                 ir.IntType(1))),
-                           (ir.IntType(32), ir.IntType(width)))
+    fnty = ir.FunctionType(
+        ir.LiteralStructType((ir.IntType(32), ir.IntType(1))),
+        (ir.IntType(32), ir.IntType(width)),
+    )
     func = cgutils.get_or_insert_function(lmod, fnty, fname)
     return builder.call(func, (mask, value))


 @lower(stubs.activemask)
 def ptx_activemask(context, builder, sig, args):
-    activemask = ir.InlineAsm(ir.FunctionType(ir.IntType(32), []),
-                              "activemask.b32 $0;", '=r', side_effect=True)
+    activemask = ir.InlineAsm(
+        ir.FunctionType(ir.IntType(32), []),
+        "activemask.b32 $0;",
+        "=r",
+        side_effect=True,
+    )
     return builder.call(activemask, [])


 @lower(stubs.lanemask_lt)
 def ptx_lanemask_lt(context, builder, sig, args):
-    activemask = ir.InlineAsm(ir.FunctionType(ir.IntType(32), []),
-                              "mov.u32 $0, %lanemask_lt;", '=r',
-                              side_effect=True)
+    activemask = ir.InlineAsm(
+        ir.FunctionType(ir.IntType(32), []),
+        "mov.u32 $0, %lanemask_lt;",
+        "=r",
+        side_effect=True,
+    )
     return builder.call(activemask, [])

@@ -308,7 +284,7 @@ def ptx_fma(context, builder, sig, args):
 def float16_float_ty_constraint(bitwidth):
-    typemap = {32: ('f32', 'f'), 64: ('f64', 'd')}
+    typemap = {32: ("f32", "f"), 64: ("f64", "d")}

     try:
         return typemap[bitwidth]
@@ -342,7 +318,7 @@ def float_to_float16_cast(context, builder, fromty, toty, val):
 def float16_int_constraint(bitwidth):
-    typemap = { 8: 'c', 16: 'h', 32: 'r', 64: 'l' }
+    typemap = {8: "c", 16: "h", 32: "r", 64: "l"}

     try:
         return typemap[bitwidth]
@@ -355,12 +331,12 @@ def float16_int_constraint(bitwidth):
 def float16_to_integer_cast(context, builder, fromty, toty, val):
     bitwidth = toty.bitwidth
     constraint = float16_int_constraint(bitwidth)
-    signedness = 's' if toty.signed else 'u'
+    signedness = "s" if toty.signed else "u"

     fnty = ir.FunctionType(context.get_value_type(toty), [ir.IntType(16)])
-    asm = ir.InlineAsm(fnty,
-                       f"cvt.rni.{signedness}{bitwidth}.f16 $0, $1;",
-                       f"={constraint},h")
+    asm = ir.InlineAsm(
+        fnty, f"cvt.rni.{signedness}{bitwidth}.f16 $0, $1;", f"={constraint},h"
+    )
     return builder.call(asm, [val])

@@ -369,40 +345,38 @@ def float16_to_integer_cast(context, builder, fromty, toty, val):
 def integer_to_float16_cast(context, builder, fromty, toty, val):
     bitwidth = fromty.bitwidth
     constraint = float16_int_constraint(bitwidth)
-    signedness = 's' if fromty.signed else 'u'
+    signedness = "s" if fromty.signed else "u"

-    fnty = ir.FunctionType(ir.IntType(16),
-                           [context.get_value_type(fromty)])
-    asm = ir.InlineAsm(fnty,
-                       f"cvt.rn.f16.{signedness}{bitwidth} $0, $1;",
-                       f"=h,{constraint}")
+    fnty = ir.FunctionType(ir.IntType(16), [context.get_value_type(fromty)])
+    asm = ir.InlineAsm(
+        fnty, f"cvt.rn.f16.{signedness}{bitwidth} $0, $1;", f"=h,{constraint}"
+    )
     return builder.call(asm, [val])


 def lower_fp16_binary(fn, op):
     @lower(fn, types.float16, types.float16)
     def ptx_fp16_binary(context, builder, sig, args):
-        fnty = ir.FunctionType(ir.IntType(16),
-                               [ir.IntType(16), ir.IntType(16)])
-        asm = ir.InlineAsm(fnty, f'{op}.f16 $0,$1,$2;', '=h,h,h')
+        fnty = ir.FunctionType(ir.IntType(16), [ir.IntType(16), ir.IntType(16)])
+        asm = ir.InlineAsm(fnty, f"{op}.f16 $0,$1,$2;", "=h,h,h")
         return builder.call(asm, args)


-lower_fp16_binary(stubs.fp16.hadd, 'add')
-lower_fp16_binary(operator.add, 'add')
-lower_fp16_binary(operator.iadd, 'add')
-lower_fp16_binary(stubs.fp16.hsub, 'sub')
-lower_fp16_binary(operator.sub, 'sub')
-lower_fp16_binary(operator.isub, 'sub')
-lower_fp16_binary(stubs.fp16.hmul, 'mul')
-lower_fp16_binary(operator.mul, 'mul')
-lower_fp16_binary(operator.imul, 'mul')
+lower_fp16_binary(stubs.fp16.hadd, "add")
+lower_fp16_binary(operator.add, "add")
+lower_fp16_binary(operator.iadd, "add")
+lower_fp16_binary(stubs.fp16.hsub, "sub")
+lower_fp16_binary(operator.sub, "sub")
+lower_fp16_binary(operator.isub, "sub")
+lower_fp16_binary(stubs.fp16.hmul, "mul")
+lower_fp16_binary(operator.mul, "mul")
+lower_fp16_binary(operator.imul, "mul")


 @lower(stubs.fp16.hneg, types.float16)
 def ptx_fp16_hneg(context, builder, sig, args):
     fnty = ir.FunctionType(ir.IntType(16), [ir.IntType(16)])
-    asm = ir.InlineAsm(fnty, 'neg.f16 $0, $1;', '=h,h')
+    asm = ir.InlineAsm(fnty, "neg.f16 $0, $1;", "=h,h")
     return builder.call(asm, args)

@@ -414,7 +388,7 @@ def operator_hneg(context, builder, sig, args):
 @lower(stubs.fp16.habs, types.float16)
 def ptx_fp16_habs(context, builder, sig, args):
     fnty = ir.FunctionType(ir.IntType(16), [ir.IntType(16)])
-    asm = ir.InlineAsm(fnty, 'abs.f16 $0, $1;', '=h,h')
+    asm = ir.InlineAsm(fnty, "abs.f16 $0, $1;", "=h,h")
     return builder.call(asm, args)

@@ -450,27 +424,28 @@ _fp16_cmp = """{{
 def _gen_fp16_cmp(op):
     def ptx_fp16_comparison(context, builder, sig, args):
         fnty = ir.FunctionType(ir.IntType(16), [ir.IntType(16), ir.IntType(16)])
-        asm = ir.InlineAsm(fnty, _fp16_cmp.format(op=op), '=h,h,h')
+        asm = ir.InlineAsm(fnty, _fp16_cmp.format(op=op), "=h,h,h")
         result = builder.call(asm, args)

         zero = context.get_constant(types.int16, 0)
         int_result = builder.bitcast(result, ir.IntType(16))
         return builder.icmp_unsigned("!=", int_result, zero)
+
     return ptx_fp16_comparison


-lower(stubs.fp16.heq, types.float16, types.float16)(_gen_fp16_cmp('eq'))
-lower(operator.eq, types.float16, types.float16)(_gen_fp16_cmp('eq'))
-lower(stubs.fp16.hne, types.float16, types.float16)(_gen_fp16_cmp('ne'))
-lower(operator.ne, types.float16, types.float16)(_gen_fp16_cmp('ne'))
-lower(stubs.fp16.hge, types.float16, types.float16)(_gen_fp16_cmp('ge'))
-lower(operator.ge, types.float16, types.float16)(_gen_fp16_cmp('ge'))
-lower(stubs.fp16.hgt, types.float16, types.float16)(_gen_fp16_cmp('gt'))
-lower(operator.gt, types.float16, types.float16)(_gen_fp16_cmp('gt'))
-lower(stubs.fp16.hle, types.float16, types.float16)(_gen_fp16_cmp('le'))
-lower(operator.le, types.float16, types.float16)(_gen_fp16_cmp('le'))
-lower(stubs.fp16.hlt, types.float16, types.float16)(_gen_fp16_cmp('lt'))
-lower(operator.lt, types.float16, types.float16)(_gen_fp16_cmp('lt'))
+lower(stubs.fp16.heq, types.float16, types.float16)(_gen_fp16_cmp("eq"))
+lower(operator.eq, types.float16, types.float16)(_gen_fp16_cmp("eq"))
+lower(stubs.fp16.hne, types.float16, types.float16)(_gen_fp16_cmp("ne"))
+lower(operator.ne, types.float16, types.float16)(_gen_fp16_cmp("ne"))
+lower(stubs.fp16.hge, types.float16, types.float16)(_gen_fp16_cmp("ge"))
+lower(operator.ge, types.float16, types.float16)(_gen_fp16_cmp("ge"))
+lower(stubs.fp16.hgt, types.float16, types.float16)(_gen_fp16_cmp("gt"))
+lower(operator.gt, types.float16, types.float16)(_gen_fp16_cmp("gt"))
+lower(stubs.fp16.hle, types.float16, types.float16)(_gen_fp16_cmp("le"))
+lower(operator.le, types.float16, types.float16)(_gen_fp16_cmp("le"))
+lower(stubs.fp16.hlt, types.float16, types.float16)(_gen_fp16_cmp("lt"))
+lower(operator.lt, types.float16, types.float16)(_gen_fp16_cmp("lt"))


 def lower_fp16_minmax(fn, fname, op):
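Note: all twelve registrations funnel through _gen_fp16_cmp, so operator.gt on two float16 values and cuda.fp16.hgt lower to identical inline PTX. The asm template (only its opening appears in this hunk) leaves zero or a nonzero pattern in a 16-bit register, and the lowering then normalizes that to a boolean. A sketch of that final step, assuming an llvmlite builder positioned inside a function, with result_i16 the inline-asm return value:

    from llvmlite import ir

    def nonzero_to_bool(builder, result_i16):
        # Compare the 16-bit comparison result against zero to obtain the
        # i1 value the rest of the lowering pipeline expects.
        zero = ir.Constant(ir.IntType(16), 0)
        return builder.icmp_unsigned("!=", result_i16, zero)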
@@ -480,8 +455,8 @@ def lower_fp16_minmax(fn, fname, op):
         return builder.select(choice, args[0], args[1])


-lower_fp16_minmax(stubs.fp16.hmax, 'max', 'gt')
-lower_fp16_minmax(stubs.fp16.hmin, 'min', 'lt')
+lower_fp16_minmax(stubs.fp16.hmax, "max", "gt")
+lower_fp16_minmax(stubs.fp16.hmin, "min", "lt")

 # See:
 # https://docs.nvidia.com/cuda/libdevice-users-guide/__nv_cbrt.html#__nv_cbrt
@@ -489,8 +464,8 @@ lower_fp16_minmax(stubs.fp16.hmin, 'min', 'lt')


 cbrt_funcs = {
-    types.float32: '__nv_cbrtf',
-    types.float64: '__nv_cbrt',
+    types.float32: "__nv_cbrtf",
+    types.float64: "__nv_cbrt",
 }

@@ -514,7 +489,8 @@ def ptx_brev_u4(context, builder, sig, args):
     fn = cgutils.get_or_insert_function(
         builder.module,
         ir.FunctionType(ir.IntType(32), (ir.IntType(32),)),
-        '__nv_brev')
+        "__nv_brev",
+    )
     return builder.call(fn, args)

@@ -526,15 +502,14 @@ def ptx_brev_u8(context, builder, sig, args):
     fn = cgutils.get_or_insert_function(
         builder.module,
         ir.FunctionType(ir.IntType(64), (ir.IntType(64),)),
-        '__nv_brevll')
+        "__nv_brevll",
+    )
     return builder.call(fn, args)


 @lower(stubs.clz, types.Any)
 def ptx_clz(context, builder, sig, args):
-    return builder.ctlz(
-        args[0],
-        context.get_constant(types.boolean, 0))
+    return builder.ctlz(args[0], context.get_constant(types.boolean, 0))


 @lower(stubs.ffs, types.i4)
@@ -543,7 +518,8 @@ def ptx_ffs_32(context, builder, sig, args):
     fn = cgutils.get_or_insert_function(
         builder.module,
         ir.FunctionType(ir.IntType(32), (ir.IntType(32),)),
-        '__nv_ffs')
+        "__nv_ffs",
+    )
     return builder.call(fn, args)

@@ -553,7 +529,8 @@ def ptx_ffs_64(context, builder, sig, args):
     fn = cgutils.get_or_insert_function(
         builder.module,
         ir.FunctionType(ir.IntType(32), (ir.IntType(64),)),
-        '__nv_ffsll')
+        "__nv_ffsll",
+    )
     return builder.call(fn, args)

@@ -567,10 +544,9 @@ def ptx_selp(context, builder, sig, args):
 def ptx_max_f4(context, builder, sig, args):
     fn = cgutils.get_or_insert_function(
         builder.module,
-        ir.FunctionType(
-            ir.FloatType(),
-            (ir.FloatType(), ir.FloatType())),
-        '__nv_fmaxf')
+        ir.FunctionType(ir.FloatType(), (ir.FloatType(), ir.FloatType())),
+        "__nv_fmaxf",
+    )
     return builder.call(fn, args)

@@ -580,25 +556,26 @@ def ptx_max_f4(context, builder, sig, args):
 def ptx_max_f8(context, builder, sig, args):
     fn = cgutils.get_or_insert_function(
         builder.module,
-        ir.FunctionType(
-            ir.DoubleType(),
-            (ir.DoubleType(), ir.DoubleType())),
-        '__nv_fmax')
+        ir.FunctionType(ir.DoubleType(), (ir.DoubleType(), ir.DoubleType())),
+        "__nv_fmax",
+    )

-    return builder.call(fn, [
-        context.cast(builder, args[0], sig.args[0], types.double),
-        context.cast(builder, args[1], sig.args[1], types.double),
-    ])
+    return builder.call(
+        fn,
+        [
+            context.cast(builder, args[0], sig.args[0], types.double),
+            context.cast(builder, args[1], sig.args[1], types.double),
+        ],
+    )


 @lower(min, types.f4, types.f4)
 def ptx_min_f4(context, builder, sig, args):
     fn = cgutils.get_or_insert_function(
         builder.module,
-        ir.FunctionType(
-            ir.FloatType(),
-            (ir.FloatType(), ir.FloatType())),
-        '__nv_fminf')
+        ir.FunctionType(ir.FloatType(), (ir.FloatType(), ir.FloatType())),
+        "__nv_fminf",
+    )
     return builder.call(fn, args)

@@ -608,15 +585,17 @@ def ptx_min_f4(context, builder, sig, args):
 def ptx_min_f8(context, builder, sig, args):
     fn = cgutils.get_or_insert_function(
         builder.module,
-        ir.FunctionType(
-            ir.DoubleType(),
-            (ir.DoubleType(), ir.DoubleType())),
-        '__nv_fmin')
+        ir.FunctionType(ir.DoubleType(), (ir.DoubleType(), ir.DoubleType())),
+        "__nv_fmin",
+    )

-    return builder.call(fn, [
-        context.cast(builder, args[0], sig.args[0], types.double),
-        context.cast(builder, args[1], sig.args[1], types.double),
-    ])
+    return builder.call(
+        fn,
+        [
+            context.cast(builder, args[0], sig.args[0], types.double),
+            context.cast(builder, args[1], sig.args[1], types.double),
+        ],
+    )


 @lower(round, types.f4)
@@ -624,19 +603,22 @@ def ptx_min_f8(context, builder, sig, args):
 def ptx_round(context, builder, sig, args):
     fn = cgutils.get_or_insert_function(
         builder.module,
-        ir.FunctionType(
-            ir.IntType(64),
-            (ir.DoubleType(),)),
-        '__nv_llrint')
-    return builder.call(fn, [
-        context.cast(builder, args[0], sig.args[0], types.double),
-    ])
+        ir.FunctionType(ir.IntType(64), (ir.DoubleType(),)),
+        "__nv_llrint",
+    )
+    return builder.call(
+        fn,
+        [
+            context.cast(builder, args[0], sig.args[0], types.double),
+        ],
+    )


 # This rounding implementation follows the algorithm used in the "fallback
 # version" of double_round in CPython.
 # https://github.com/python/cpython/blob/a755410e054e1e2390de5830befc08fe80706c66/Objects/floatobject.c#L964-L1007

+
 @lower(round, types.f4, types.Integer)
 @lower(round, types.f8, types.Integer)
 def round_to_impl(context, builder, sig, args):
@@ -651,7 +633,7 @@ def round_to_impl(context, builder, sig, args):
                 pow1 = 10.0 ** (ndigits - 22)
                 pow2 = 1e22
             else:
-                pow1 = 10.0 ** ndigits
+                pow1 = 10.0**ndigits
                 pow2 = 1.0
             y = (x * pow1) * pow2
             if math.isinf(y):
@@ -662,7 +644,7 @@ def round_to_impl(context, builder, sig, args):
             y = x / pow1

         z = round(y)
-        if (math.fabs(y - z) == 0.5):
+        if math.fabs(y - z) == 0.5:
            # halfway between two integers; use round-half-even
            z = 2.0 * round(y / 2.0)

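Note: the two rounding hunks above are cosmetic, but the algorithm deserves a gloss: scale x by a power of ten, round the scaled value half-to-even, then unscale. When ndigits > 22 the scale factor is split into pow1 * pow2 so that neither factor overflows on its own, even though 10.0 ** ndigits might. A standalone sketch of the non-negative-ndigits branch (the surrounding code, elided here, also handles negative ndigits and non-finite inputs):

    import math

    def round_ndigits(x, ndigits):
        if ndigits > 22:
            # 10.0 ** ndigits alone could overflow; split the scale factor.
            pow1 = 10.0 ** (ndigits - 22)
            pow2 = 1e22
        else:
            pow1 = 10.0 ** ndigits
            pow2 = 1.0
        y = (x * pow1) * pow2
        if math.isinf(y):
            return x
        z = round(y)
        if math.fabs(y - z) == 0.5:
            # halfway between two integers; use round-half-even
            z = 2.0 * round(y / 2.0)
        return (z / pow2) / pow1

    assert round_ndigits(0.125, 2) == 0.12  # the tie goes to the even digit
    assert round_ndigits(0.135, 2) == 0.14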
 
@@ -673,19 +655,25 @@ def round_to_impl(context, builder, sig, args):
         return z

-    return context.compile_internal(builder, round_ndigits, sig, args, )
+    return context.compile_internal(
+        builder,
+        round_ndigits,
+        sig,
+        args,
+    )


 def gen_deg_rad(const):
     def impl(context, builder, sig, args):
-        argty, = sig.args
+        (argty,) = sig.args
         factor = context.get_constant(argty, const)
         return builder.fmul(factor, args[0])
+
     return impl


-_deg2rad = math.pi / 180.
-_rad2deg = 180. / math.pi
+_deg2rad = math.pi / 180.0
+_rad2deg = 180.0 / math.pi
 lower(math.radians, types.f4)(gen_deg_rad(_deg2rad))
 lower(math.radians, types.f8)(gen_deg_rad(_deg2rad))
 lower(math.degrees, types.f4)(gen_deg_rad(_rad2deg))
@@ -701,16 +689,18 @@ def _normalize_indices(context, builder, indty, inds, aryty, valty):
         indices = [inds]
     else:
         indices = cgutils.unpack_tuple(builder, inds, count=len(indty))
-    indices = [context.cast(builder, i, t, types.intp)
-               for t, i in zip(indty, indices)]
+    indices = [
+        context.cast(builder, i, t, types.intp) for t, i in zip(indty, indices)
+    ]

     dtype = aryty.dtype
     if dtype != valty:
         raise TypeError("expect %s but got %s" % (dtype, valty))

     if aryty.ndim != len(indty):
-        raise TypeError("indexing %d-D array with %d-D index" %
-                        (aryty.ndim, len(indty)))
+        raise TypeError(
+            "indexing %d-D array with %d-D index" % (aryty.ndim, len(indty))
+        )

     return indty, indices

@@ -722,14 +712,17 @@ def _atomic_dispatcher(dispatch_fn):
         ary, inds, val = args
         dtype = aryty.dtype

-        indty, indices = _normalize_indices(context, builder, indty, inds,
-                                            aryty, valty)
+        indty, indices = _normalize_indices(
+            context, builder, indty, inds, aryty, valty
+        )

         lary = context.make_array(aryty)(context, builder, ary)
-        ptr = cgutils.get_item_pointer(context, builder, aryty, lary, indices,
-                                       wraparound=True)
+        ptr = cgutils.get_item_pointer(
+            context, builder, aryty, lary, indices, wraparound=True
+        )
         # dispatcher to implementation base on dtype
         return dispatch_fn(context, builder, dtype, ptr, val)
+
     return imp

@@ -740,14 +733,16 @@ def _atomic_dispatcher(dispatch_fn):
 def ptx_atomic_add_tuple(context, builder, dtype, ptr, val):
     if dtype == types.float32:
         lmod = builder.module
-        return builder.call(nvvmutils.declare_atomic_add_float32(lmod),
-                            (ptr, val))
+        return builder.call(
+            nvvmutils.declare_atomic_add_float32(lmod), (ptr, val)
+        )
     elif dtype == types.float64:
         lmod = builder.module
-        return builder.call(nvvmutils.declare_atomic_add_float64(lmod),
-                            (ptr, val))
+        return builder.call(
+            nvvmutils.declare_atomic_add_float64(lmod), (ptr, val)
+        )
     else:
-        return builder.atomic_rmw('add', ptr, val, 'monotonic')
+        return builder.atomic_rmw("add", ptr, val, "monotonic")


 @lower(stubs.atomic.sub, types.Array, types.intp, types.Any)
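Note: the dispatch above is invisible to users: cuda.atomic.add lowers to an NVVM helper declared in nvvmutils for float32/float64 and to a plain monotonic atomicrmw for integer types. A minimal kernel that exercises the integer path (standard numba-cuda API; needs a CUDA device to run):

    import numpy as np
    from numba import cuda

    @cuda.jit
    def histogram(samples, bins):
        i = cuda.grid(1)
        if i < samples.size:
            # Concurrent updates to the same bin serialize correctly.
            cuda.atomic.add(bins, samples[i], 1)

    samples = np.random.randint(0, 8, size=1024).astype(np.int32)
    bins = np.zeros(8, dtype=np.int32)
    histogram[8, 128](samples, bins)
    assert bins.sum() == samples.size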
@@ -757,14 +752,16 @@ def ptx_atomic_add_tuple(context, builder, dtype, ptr, val):
 def ptx_atomic_sub(context, builder, dtype, ptr, val):
     if dtype == types.float32:
         lmod = builder.module
-        return builder.call(nvvmutils.declare_atomic_sub_float32(lmod),
-                            (ptr, val))
+        return builder.call(
+            nvvmutils.declare_atomic_sub_float32(lmod), (ptr, val)
+        )
     elif dtype == types.float64:
         lmod = builder.module
-        return builder.call(nvvmutils.declare_atomic_sub_float64(lmod),
-                            (ptr, val))
+        return builder.call(
+            nvvmutils.declare_atomic_sub_float64(lmod), (ptr, val)
+        )
     else:
-        return builder.atomic_rmw('sub', ptr, val, 'monotonic')
+        return builder.atomic_rmw("sub", ptr, val, "monotonic")


 @lower(stubs.atomic.inc, types.Array, types.intp, types.Any)
@@ -775,10 +772,10 @@ def ptx_atomic_inc(context, builder, dtype, ptr, val):
     if dtype in cuda.cudadecl.unsigned_int_numba_types:
         bw = dtype.bitwidth
         lmod = builder.module
-        fn = getattr(nvvmutils, f'declare_atomic_inc_int{bw}')
+        fn = getattr(nvvmutils, f"declare_atomic_inc_int{bw}")
         return builder.call(fn(lmod), (ptr, val))
     else:
-        raise TypeError(f'Unimplemented atomic inc with {dtype} array')
+        raise TypeError(f"Unimplemented atomic inc with {dtype} array")


 @lower(stubs.atomic.dec, types.Array, types.intp, types.Any)
@@ -789,27 +786,27 @@ def ptx_atomic_dec(context, builder, dtype, ptr, val):
     if dtype in cuda.cudadecl.unsigned_int_numba_types:
         bw = dtype.bitwidth
         lmod = builder.module
-        fn = getattr(nvvmutils, f'declare_atomic_dec_int{bw}')
+        fn = getattr(nvvmutils, f"declare_atomic_dec_int{bw}")
         return builder.call(fn(lmod), (ptr, val))
     else:
-        raise TypeError(f'Unimplemented atomic dec with {dtype} array')
+        raise TypeError(f"Unimplemented atomic dec with {dtype} array")


 def ptx_atomic_bitwise(stub, op):
     @_atomic_dispatcher
     def impl_ptx_atomic(context, builder, dtype, ptr, val):
         if dtype in (cuda.cudadecl.integer_numba_types):
-            return builder.atomic_rmw(op, ptr, val, 'monotonic')
+            return builder.atomic_rmw(op, ptr, val, "monotonic")
         else:
-            raise TypeError(f'Unimplemented atomic {op} with {dtype} array')
+            raise TypeError(f"Unimplemented atomic {op} with {dtype} array")

     for ty in (types.intp, types.UniTuple, types.Tuple):
         lower(stub, types.Array, ty, types.Any)(impl_ptx_atomic)


-ptx_atomic_bitwise(stubs.atomic.and_, 'and')
-ptx_atomic_bitwise(stubs.atomic.or_, 'or')
-ptx_atomic_bitwise(stubs.atomic.xor, 'xor')
+ptx_atomic_bitwise(stubs.atomic.and_, "and")
+ptx_atomic_bitwise(stubs.atomic.or_, "or")
+ptx_atomic_bitwise(stubs.atomic.xor, "xor")


 @lower(stubs.atomic.exch, types.Array, types.intp, types.Any)
@@ -818,9 +815,9 @@ ptx_atomic_bitwise(stubs.atomic.xor, 'xor')
 @_atomic_dispatcher
 def ptx_atomic_exch(context, builder, dtype, ptr, val):
     if dtype in (cuda.cudadecl.integer_numba_types):
-        return builder.atomic_rmw('xchg', ptr, val, 'monotonic')
+        return builder.atomic_rmw("xchg", ptr, val, "monotonic")
     else:
-        raise TypeError(f'Unimplemented atomic exch with {dtype} array')
+        raise TypeError(f"Unimplemented atomic exch with {dtype} array")


 @lower(stubs.atomic.max, types.Array, types.intp, types.Any)
@@ -830,17 +827,19 @@ def ptx_atomic_exch(context, builder, dtype, ptr, val):
 def ptx_atomic_max(context, builder, dtype, ptr, val):
     lmod = builder.module
     if dtype == types.float64:
-        return builder.call(nvvmutils.declare_atomic_max_float64(lmod),
-                            (ptr, val))
+        return builder.call(
+            nvvmutils.declare_atomic_max_float64(lmod), (ptr, val)
+        )
     elif dtype == types.float32:
-        return builder.call(nvvmutils.declare_atomic_max_float32(lmod),
-                            (ptr, val))
+        return builder.call(
+            nvvmutils.declare_atomic_max_float32(lmod), (ptr, val)
+        )
     elif dtype in (types.int32, types.int64):
-        return builder.atomic_rmw('max', ptr, val, ordering='monotonic')
+        return builder.atomic_rmw("max", ptr, val, ordering="monotonic")
     elif dtype in (types.uint32, types.uint64):
-        return builder.atomic_rmw('umax', ptr, val, ordering='monotonic')
+        return builder.atomic_rmw("umax", ptr, val, ordering="monotonic")
     else:
-        raise TypeError('Unimplemented atomic max with %s array' % dtype)
+        raise TypeError("Unimplemented atomic max with %s array" % dtype)


 @lower(stubs.atomic.min, types.Array, types.intp, types.Any)
@@ -850,17 +849,19 @@ def ptx_atomic_max(context, builder, dtype, ptr, val):
 def ptx_atomic_min(context, builder, dtype, ptr, val):
     lmod = builder.module
     if dtype == types.float64:
-        return builder.call(nvvmutils.declare_atomic_min_float64(lmod),
-                            (ptr, val))
+        return builder.call(
+            nvvmutils.declare_atomic_min_float64(lmod), (ptr, val)
+        )
     elif dtype == types.float32:
-        return builder.call(nvvmutils.declare_atomic_min_float32(lmod),
-                            (ptr, val))
+        return builder.call(
+            nvvmutils.declare_atomic_min_float32(lmod), (ptr, val)
+        )
     elif dtype in (types.int32, types.int64):
-        return builder.atomic_rmw('min', ptr, val, ordering='monotonic')
+        return builder.atomic_rmw("min", ptr, val, ordering="monotonic")
     elif dtype in (types.uint32, types.uint64):
-        return builder.atomic_rmw('umin', ptr, val, ordering='monotonic')
+        return builder.atomic_rmw("umin", ptr, val, ordering="monotonic")
     else:
-        raise TypeError('Unimplemented atomic min with %s array' % dtype)
+        raise TypeError("Unimplemented atomic min with %s array" % dtype)


 @lower(stubs.atomic.nanmax, types.Array, types.intp, types.Any)
@@ -870,17 +871,19 @@ def ptx_atomic_min(context, builder, dtype, ptr, val):
 def ptx_atomic_nanmax(context, builder, dtype, ptr, val):
     lmod = builder.module
     if dtype == types.float64:
-        return builder.call(nvvmutils.declare_atomic_nanmax_float64(lmod),
-                            (ptr, val))
+        return builder.call(
+            nvvmutils.declare_atomic_nanmax_float64(lmod), (ptr, val)
+        )
     elif dtype == types.float32:
-        return builder.call(nvvmutils.declare_atomic_nanmax_float32(lmod),
-                            (ptr, val))
+        return builder.call(
+            nvvmutils.declare_atomic_nanmax_float32(lmod), (ptr, val)
+        )
     elif dtype in (types.int32, types.int64):
-        return builder.atomic_rmw('max', ptr, val, ordering='monotonic')
+        return builder.atomic_rmw("max", ptr, val, ordering="monotonic")
     elif dtype in (types.uint32, types.uint64):
-        return builder.atomic_rmw('umax', ptr, val, ordering='monotonic')
+        return builder.atomic_rmw("umax", ptr, val, ordering="monotonic")
     else:
-        raise TypeError('Unimplemented atomic max with %s array' % dtype)
+        raise TypeError("Unimplemented atomic max with %s array" % dtype)


 @lower(stubs.atomic.nanmin, types.Array, types.intp, types.Any)
@@ -890,17 +893,19 @@ def ptx_atomic_nanmax(context, builder, dtype, ptr, val):
 def ptx_atomic_nanmin(context, builder, dtype, ptr, val):
     lmod = builder.module
     if dtype == types.float64:
-        return builder.call(nvvmutils.declare_atomic_nanmin_float64(lmod),
-                            (ptr, val))
+        return builder.call(
+            nvvmutils.declare_atomic_nanmin_float64(lmod), (ptr, val)
+        )
     elif dtype == types.float32:
-        return builder.call(nvvmutils.declare_atomic_nanmin_float32(lmod),
-                            (ptr, val))
+        return builder.call(
+            nvvmutils.declare_atomic_nanmin_float32(lmod), (ptr, val)
+        )
     elif dtype in (types.int32, types.int64):
-        return builder.atomic_rmw('min', ptr, val, ordering='monotonic')
+        return builder.atomic_rmw("min", ptr, val, ordering="monotonic")
     elif dtype in (types.uint32, types.uint64):
-        return builder.atomic_rmw('umin', ptr, val, ordering='monotonic')
+        return builder.atomic_rmw("umin", ptr, val, ordering="monotonic")
     else:
-        raise TypeError('Unimplemented atomic min with %s array' % dtype)
+        raise TypeError("Unimplemented atomic min with %s array" % dtype)


 @lower(stubs.atomic.compare_and_swap, types.Array, types.Any, types.Any)
@@ -917,19 +922,21 @@ def ptx_atomic_cas(context, builder, sig, args):
     aryty, indty, oldty, valty = sig.args
     ary, inds, old, val = args

-    indty, indices = _normalize_indices(context, builder, indty, inds, aryty,
-                                        valty)
+    indty, indices = _normalize_indices(
+        context, builder, indty, inds, aryty, valty
+    )

     lary = context.make_array(aryty)(context, builder, ary)
-    ptr = cgutils.get_item_pointer(context, builder, aryty, lary, indices,
-                                   wraparound=True)
+    ptr = cgutils.get_item_pointer(
+        context, builder, aryty, lary, indices, wraparound=True
+    )

     if aryty.dtype in (cuda.cudadecl.integer_numba_types):
         lmod = builder.module
         bitwidth = aryty.dtype.bitwidth
         return nvvmutils.atomic_cmpxchg(builder, lmod, bitwidth, ptr, old, val)
     else:
-        raise TypeError('Unimplemented atomic cas with %s array' % aryty.dtype)
+        raise TypeError("Unimplemented atomic cas with %s array" % aryty.dtype)


 # -----------------------------------------------------------------------------
@@ -937,15 +944,20 @@ def ptx_atomic_cas(context, builder, sig, args):
 @lower(breakpoint)
 def ptx_brkpt(context, builder, sig, args):
-    brkpt = ir.InlineAsm(ir.FunctionType(ir.VoidType(), []),
-                         "brkpt;", '', side_effect=True)
+    brkpt = ir.InlineAsm(
+        ir.FunctionType(ir.VoidType(), []), "brkpt;", "", side_effect=True
+    )
     builder.call(brkpt, ())


 @lower(stubs.nanosleep, types.uint32)
 def ptx_nanosleep(context, builder, sig, args):
-    nanosleep = ir.InlineAsm(ir.FunctionType(ir.VoidType(), [ir.IntType(32)]),
-                             "nanosleep.u32 $0;", 'r', side_effect=True)
+    nanosleep = ir.InlineAsm(
+        ir.FunctionType(ir.VoidType(), [ir.IntType(32)]),
+        "nanosleep.u32 $0;",
+        "r",
+        side_effect=True,
+    )
     ns = args[0]
     builder.call(nanosleep, [ns])

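Note: both lowerings wrap a single PTX instruction in ir.InlineAsm with side_effect=True, which stops LLVM from deleting a call whose result is unused. The nanosleep stub surfaces in the public API as cuda.nanosleep; a minimal use (the underlying nanosleep.u32 instruction requires compute capability 7.0 or later, and the sleep duration is only approximate):

    import numpy as np
    from numba import cuda

    @cuda.jit
    def pause_then_write(out):
        i = cuda.grid(1)
        if i < out.size:
            cuda.nanosleep(100)  # back off ~100 ns, e.g. between polls
            out[i] = i

    out = np.zeros(32, dtype=np.int32)
    pause_then_write[1, 32](out)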
 
@@ -953,8 +965,9 @@ def ptx_nanosleep(context, builder, sig, args):
 # -----------------------------------------------------------------------------


-def _generic_array(context, builder, shape, dtype, symbol_name, addrspace,
-                   can_dynsized=False):
+def _generic_array(
+    context, builder, shape, dtype, symbol_name, addrspace, can_dynsized=False
+):
     elemcount = reduce(operator.mul, shape, 1)

     # Check for valid shape for this type of allocation.
@@ -985,16 +998,17 @@ def _generic_array(context, builder, shape, dtype, symbol_name, addrspace,
     lmod = builder.module

     # Create global variable in the requested address space
-    gvmem = cgutils.add_global_variable(lmod, laryty, symbol_name,
-                                        addrspace)
+    gvmem = cgutils.add_global_variable(
+        lmod, laryty, symbol_name, addrspace
+    )
     # Specify alignment to avoid misalignment bug
     align = context.get_abi_sizeof(lldtype)
     # Alignment is required to be a power of 2 for shared memory. If it is
     # not a power of 2 (e.g. for a Record array) then round up accordingly.
-    gvmem.align = 1 << (align - 1 ).bit_length()
+    gvmem.align = 1 << (align - 1).bit_length()

     if dynamic_smem:
-        gvmem.linkage = 'external'
+        gvmem.linkage = "external"
     else:
         ## Comment out the following line to workaround a NVVM bug
         ## which generates a invalid symbol name when the linkage
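Note: the alignment fix-up is a standard next-power-of-two round-up: powers of two map to themselves and anything else (e.g. an oddly sized Record) is bumped to the next power. The expression in isolation:

    def round_up_pow2(n):
        # 1 << (n - 1).bit_length(): identity on powers of two,
        # next power of two otherwise
        return 1 << (n - 1).bit_length()

    assert [round_up_pow2(n) for n in (1, 8, 12, 48)] == [1, 8, 16, 64]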
@@ -1005,8 +1019,9 @@ def _generic_array(context, builder, shape, dtype, symbol_name, addrspace,
         gvmem.initializer = ir.Constant(laryty, ir.Undefined)

     # Convert to generic address-space
-    dataptr = builder.addrspacecast(gvmem, ir.PointerType(ir.IntType(8)),
-                                    'generic')
+    dataptr = builder.addrspacecast(
+        gvmem, ir.PointerType(ir.IntType(8)), "generic"
+    )

     targetdata = ll.create_target_data(nvvm.NVVM().data_layout)
     lldtype = context.get_data_type(dtype)
@@ -1027,11 +1042,15 @@ def _generic_array(context, builder, shape, dtype, symbol_name, addrspace,
         # Unfortunately NVVM does not provide an intrinsic for the
         # %dynamic_smem_size register, so we must read it using inline
         # assembly.
-        get_dynshared_size = ir.InlineAsm(ir.FunctionType(ir.IntType(32), []),
-                                          "mov.u32 $0, %dynamic_smem_size;",
-                                          '=r', side_effect=True)
-        dynsmem_size = builder.zext(builder.call(get_dynshared_size, []),
-                                    ir.IntType(64))
+        get_dynshared_size = ir.InlineAsm(
+            ir.FunctionType(ir.IntType(32), []),
+            "mov.u32 $0, %dynamic_smem_size;",
+            "=r",
+            side_effect=True,
+        )
+        dynsmem_size = builder.zext(
+            builder.call(get_dynshared_size, []), ir.IntType(64)
+        )
         # Only 1-D dynamic shared memory is supported so the following is a
         # sufficient construction of the shape
         kitemsize = context.get_constant(types.intp, itemsize)
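Note: reading %dynamic_smem_size at run time is what lets a shared array declared with extent 0 take its actual size from the launch configuration. User-level usage with the standard numba-cuda API, where the fourth launch parameter is the dynamic shared memory size in bytes:

    import numpy as np
    from numba import cuda, float32

    @cuda.jit
    def doubled(out):
        buf = cuda.shared.array(0, dtype=float32)  # extent set at launch
        i = cuda.threadIdx.x
        buf[i] = out[i]
        cuda.syncthreads()
        out[i] = buf[i] * 2

    n = 64
    out = np.arange(n, dtype=np.float32)
    doubled[1, n, 0, n * out.itemsize](out)  # grid, block, stream, nbytes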
@@ -1041,15 +1060,17 @@ def _generic_array(context, builder, shape, dtype, symbol_name, addrspace,
     # Create array object
     ndim = len(shape)
-    aryty = types.Array(dtype=dtype, ndim=ndim, layout='C')
+    aryty = types.Array(dtype=dtype, ndim=ndim, layout="C")
     ary = context.make_array(aryty)(context, builder)

-    context.populate_array(ary,
-                           data=builder.bitcast(dataptr, ary.data.type),
-                           shape=kshape,
-                           strides=kstrides,
-                           itemsize=context.get_constant(types.intp, itemsize),
-                           meminfo=None)
+    context.populate_array(
+        ary,
+        data=builder.bitcast(dataptr, ary.data.type),
+        shape=kshape,
+        strides=kstrides,
+        itemsize=context.get_constant(types.intp, itemsize),
+        meminfo=None,
+    )
     return ary._getvalue()