numba-cuda 0.8.0__py3-none-any.whl → 0.9.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (227) hide show
  1. _numba_cuda_redirector.py +17 -13
  2. numba_cuda/VERSION +1 -1
  3. numba_cuda/_version.py +4 -1
  4. numba_cuda/numba/cuda/__init__.py +6 -2
  5. numba_cuda/numba/cuda/api.py +129 -86
  6. numba_cuda/numba/cuda/api_util.py +3 -3
  7. numba_cuda/numba/cuda/args.py +12 -16
  8. numba_cuda/numba/cuda/cg.py +6 -6
  9. numba_cuda/numba/cuda/codegen.py +74 -43
  10. numba_cuda/numba/cuda/compiler.py +232 -113
  11. numba_cuda/numba/cuda/cpp_function_wrappers.cu +1 -2
  12. numba_cuda/numba/cuda/cuda_fp16.h +661 -661
  13. numba_cuda/numba/cuda/cuda_fp16.hpp +3 -3
  14. numba_cuda/numba/cuda/cuda_paths.py +291 -99
  15. numba_cuda/numba/cuda/cudadecl.py +125 -69
  16. numba_cuda/numba/cuda/cudadrv/__init__.py +3 -1
  17. numba_cuda/numba/cuda/cudadrv/devicearray.py +185 -135
  18. numba_cuda/numba/cuda/cudadrv/devices.py +16 -11
  19. numba_cuda/numba/cuda/cudadrv/driver.py +463 -297
  20. numba_cuda/numba/cuda/cudadrv/drvapi.py +241 -207
  21. numba_cuda/numba/cuda/cudadrv/dummyarray.py +66 -54
  22. numba_cuda/numba/cuda/cudadrv/enums.py +1 -1
  23. numba_cuda/numba/cuda/cudadrv/error.py +6 -2
  24. numba_cuda/numba/cuda/cudadrv/libs.py +67 -63
  25. numba_cuda/numba/cuda/cudadrv/linkable_code.py +16 -1
  26. numba_cuda/numba/cuda/cudadrv/mappings.py +16 -14
  27. numba_cuda/numba/cuda/cudadrv/nvrtc.py +138 -29
  28. numba_cuda/numba/cuda/cudadrv/nvvm.py +296 -161
  29. numba_cuda/numba/cuda/cudadrv/rtapi.py +1 -1
  30. numba_cuda/numba/cuda/cudadrv/runtime.py +20 -8
  31. numba_cuda/numba/cuda/cudaimpl.py +317 -233
  32. numba_cuda/numba/cuda/cudamath.py +1 -1
  33. numba_cuda/numba/cuda/debuginfo.py +8 -6
  34. numba_cuda/numba/cuda/decorators.py +75 -45
  35. numba_cuda/numba/cuda/descriptor.py +1 -1
  36. numba_cuda/numba/cuda/device_init.py +69 -18
  37. numba_cuda/numba/cuda/deviceufunc.py +143 -98
  38. numba_cuda/numba/cuda/dispatcher.py +300 -213
  39. numba_cuda/numba/cuda/errors.py +13 -10
  40. numba_cuda/numba/cuda/extending.py +1 -1
  41. numba_cuda/numba/cuda/initialize.py +5 -3
  42. numba_cuda/numba/cuda/intrinsic_wrapper.py +3 -3
  43. numba_cuda/numba/cuda/intrinsics.py +31 -27
  44. numba_cuda/numba/cuda/kernels/reduction.py +13 -13
  45. numba_cuda/numba/cuda/kernels/transpose.py +3 -6
  46. numba_cuda/numba/cuda/libdevice.py +317 -317
  47. numba_cuda/numba/cuda/libdeviceimpl.py +3 -2
  48. numba_cuda/numba/cuda/locks.py +16 -0
  49. numba_cuda/numba/cuda/mathimpl.py +62 -57
  50. numba_cuda/numba/cuda/models.py +1 -5
  51. numba_cuda/numba/cuda/nvvmutils.py +103 -88
  52. numba_cuda/numba/cuda/printimpl.py +9 -5
  53. numba_cuda/numba/cuda/random.py +46 -36
  54. numba_cuda/numba/cuda/reshape_funcs.cu +1 -1
  55. numba_cuda/numba/cuda/runtime/__init__.py +1 -1
  56. numba_cuda/numba/cuda/runtime/memsys.cu +1 -1
  57. numba_cuda/numba/cuda/runtime/memsys.cuh +1 -1
  58. numba_cuda/numba/cuda/runtime/nrt.cu +3 -3
  59. numba_cuda/numba/cuda/runtime/nrt.py +48 -43
  60. numba_cuda/numba/cuda/simulator/__init__.py +22 -12
  61. numba_cuda/numba/cuda/simulator/api.py +38 -22
  62. numba_cuda/numba/cuda/simulator/compiler.py +2 -2
  63. numba_cuda/numba/cuda/simulator/cudadrv/__init__.py +8 -2
  64. numba_cuda/numba/cuda/simulator/cudadrv/devicearray.py +63 -55
  65. numba_cuda/numba/cuda/simulator/cudadrv/devices.py +13 -11
  66. numba_cuda/numba/cuda/simulator/cudadrv/driver.py +5 -5
  67. numba_cuda/numba/cuda/simulator/cudadrv/drvapi.py +2 -2
  68. numba_cuda/numba/cuda/simulator/cudadrv/libs.py +1 -1
  69. numba_cuda/numba/cuda/simulator/cudadrv/nvvm.py +3 -3
  70. numba_cuda/numba/cuda/simulator/cudadrv/runtime.py +3 -3
  71. numba_cuda/numba/cuda/simulator/kernel.py +43 -34
  72. numba_cuda/numba/cuda/simulator/kernelapi.py +31 -26
  73. numba_cuda/numba/cuda/simulator/reduction.py +1 -0
  74. numba_cuda/numba/cuda/simulator/vector_types.py +13 -9
  75. numba_cuda/numba/cuda/simulator_init.py +2 -4
  76. numba_cuda/numba/cuda/stubs.py +139 -102
  77. numba_cuda/numba/cuda/target.py +64 -47
  78. numba_cuda/numba/cuda/testing.py +24 -19
  79. numba_cuda/numba/cuda/tests/__init__.py +14 -12
  80. numba_cuda/numba/cuda/tests/cudadrv/test_array_attr.py +16 -17
  81. numba_cuda/numba/cuda/tests/cudadrv/test_context_stack.py +7 -7
  82. numba_cuda/numba/cuda/tests/cudadrv/test_cuda_array_slicing.py +73 -54
  83. numba_cuda/numba/cuda/tests/cudadrv/test_cuda_auto_context.py +1 -1
  84. numba_cuda/numba/cuda/tests/cudadrv/test_cuda_devicerecord.py +48 -50
  85. numba_cuda/numba/cuda/tests/cudadrv/test_cuda_driver.py +47 -29
  86. numba_cuda/numba/cuda/tests/cudadrv/test_cuda_libraries.py +3 -3
  87. numba_cuda/numba/cuda/tests/cudadrv/test_cuda_memory.py +19 -19
  88. numba_cuda/numba/cuda/tests/cudadrv/test_cuda_ndarray.py +108 -103
  89. numba_cuda/numba/cuda/tests/cudadrv/test_deallocations.py +20 -11
  90. numba_cuda/numba/cuda/tests/cudadrv/test_detect.py +20 -17
  91. numba_cuda/numba/cuda/tests/cudadrv/test_emm_plugins.py +8 -6
  92. numba_cuda/numba/cuda/tests/cudadrv/test_events.py +1 -1
  93. numba_cuda/numba/cuda/tests/cudadrv/test_host_alloc.py +8 -7
  94. numba_cuda/numba/cuda/tests/cudadrv/test_init.py +13 -13
  95. numba_cuda/numba/cuda/tests/cudadrv/test_inline_ptx.py +12 -9
  96. numba_cuda/numba/cuda/tests/cudadrv/test_linker.py +36 -31
  97. numba_cuda/numba/cuda/tests/cudadrv/test_managed_alloc.py +8 -7
  98. numba_cuda/numba/cuda/tests/cudadrv/test_module_callbacks.py +294 -0
  99. numba_cuda/numba/cuda/tests/cudadrv/test_mvc.py +10 -7
  100. numba_cuda/numba/cuda/tests/cudadrv/test_nvjitlink.py +24 -15
  101. numba_cuda/numba/cuda/tests/cudadrv/test_nvvm_driver.py +43 -41
  102. numba_cuda/numba/cuda/tests/cudadrv/test_pinned.py +4 -5
  103. numba_cuda/numba/cuda/tests/cudadrv/test_profiler.py +2 -2
  104. numba_cuda/numba/cuda/tests/cudadrv/test_ptds.py +28 -17
  105. numba_cuda/numba/cuda/tests/cudadrv/test_reset_device.py +1 -2
  106. numba_cuda/numba/cuda/tests/cudadrv/test_runtime.py +22 -14
  107. numba_cuda/numba/cuda/tests/cudadrv/test_select_device.py +1 -1
  108. numba_cuda/numba/cuda/tests/cudadrv/test_streams.py +4 -3
  109. numba_cuda/numba/cuda/tests/cudapy/cache_usecases.py +10 -4
  110. numba_cuda/numba/cuda/tests/cudapy/cache_with_cpu_usecases.py +1 -0
  111. numba_cuda/numba/cuda/tests/cudapy/extensions_usecases.py +7 -6
  112. numba_cuda/numba/cuda/tests/cudapy/jitlink.ptx +0 -2
  113. numba_cuda/numba/cuda/tests/cudapy/recursion_usecases.py +1 -0
  114. numba_cuda/numba/cuda/tests/cudapy/test_alignment.py +6 -5
  115. numba_cuda/numba/cuda/tests/cudapy/test_array.py +52 -42
  116. numba_cuda/numba/cuda/tests/cudapy/test_array_args.py +5 -6
  117. numba_cuda/numba/cuda/tests/cudapy/test_array_methods.py +1 -1
  118. numba_cuda/numba/cuda/tests/cudapy/test_atomics.py +501 -304
  119. numba_cuda/numba/cuda/tests/cudapy/test_blackscholes.py +57 -21
  120. numba_cuda/numba/cuda/tests/cudapy/test_boolean.py +3 -3
  121. numba_cuda/numba/cuda/tests/cudapy/test_caching.py +50 -37
  122. numba_cuda/numba/cuda/tests/cudapy/test_casting.py +29 -24
  123. numba_cuda/numba/cuda/tests/cudapy/test_cffi.py +11 -6
  124. numba_cuda/numba/cuda/tests/cudapy/test_compiler.py +84 -50
  125. numba_cuda/numba/cuda/tests/cudapy/test_complex.py +144 -73
  126. numba_cuda/numba/cuda/tests/cudapy/test_complex_kernel.py +2 -2
  127. numba_cuda/numba/cuda/tests/cudapy/test_const_string.py +37 -27
  128. numba_cuda/numba/cuda/tests/cudapy/test_constmem.py +43 -45
  129. numba_cuda/numba/cuda/tests/cudapy/test_cooperative_groups.py +21 -14
  130. numba_cuda/numba/cuda/tests/cudapy/test_cuda_array_interface.py +60 -55
  131. numba_cuda/numba/cuda/tests/cudapy/test_cuda_jit_no_types.py +3 -2
  132. numba_cuda/numba/cuda/tests/cudapy/test_datetime.py +26 -22
  133. numba_cuda/numba/cuda/tests/cudapy/test_debug.py +29 -27
  134. numba_cuda/numba/cuda/tests/cudapy/test_debuginfo.py +31 -28
  135. numba_cuda/numba/cuda/tests/cudapy/test_device_func.py +52 -45
  136. numba_cuda/numba/cuda/tests/cudapy/test_dispatcher.py +55 -43
  137. numba_cuda/numba/cuda/tests/cudapy/test_enums.py +6 -7
  138. numba_cuda/numba/cuda/tests/cudapy/test_errors.py +30 -15
  139. numba_cuda/numba/cuda/tests/cudapy/test_exception.py +11 -12
  140. numba_cuda/numba/cuda/tests/cudapy/test_extending.py +19 -12
  141. numba_cuda/numba/cuda/tests/cudapy/test_fastmath.py +77 -66
  142. numba_cuda/numba/cuda/tests/cudapy/test_forall.py +5 -3
  143. numba_cuda/numba/cuda/tests/cudapy/test_freevar.py +5 -3
  144. numba_cuda/numba/cuda/tests/cudapy/test_frexp_ldexp.py +1 -1
  145. numba_cuda/numba/cuda/tests/cudapy/test_globals.py +3 -5
  146. numba_cuda/numba/cuda/tests/cudapy/test_gufunc.py +144 -126
  147. numba_cuda/numba/cuda/tests/cudapy/test_gufunc_scalar.py +23 -18
  148. numba_cuda/numba/cuda/tests/cudapy/test_gufunc_scheduling.py +16 -22
  149. numba_cuda/numba/cuda/tests/cudapy/test_idiv.py +1 -3
  150. numba_cuda/numba/cuda/tests/cudapy/test_inspect.py +29 -20
  151. numba_cuda/numba/cuda/tests/cudapy/test_intrinsics.py +147 -99
  152. numba_cuda/numba/cuda/tests/cudapy/test_ipc.py +50 -36
  153. numba_cuda/numba/cuda/tests/cudapy/test_iterators.py +1 -2
  154. numba_cuda/numba/cuda/tests/cudapy/test_lang.py +4 -4
  155. numba_cuda/numba/cuda/tests/cudapy/test_laplace.py +6 -6
  156. numba_cuda/numba/cuda/tests/cudapy/test_libdevice.py +24 -20
  157. numba_cuda/numba/cuda/tests/cudapy/test_lineinfo.py +36 -31
  158. numba_cuda/numba/cuda/tests/cudapy/test_localmem.py +13 -13
  159. numba_cuda/numba/cuda/tests/cudapy/test_mandel.py +13 -6
  160. numba_cuda/numba/cuda/tests/cudapy/test_math.py +83 -66
  161. numba_cuda/numba/cuda/tests/cudapy/test_matmul.py +1 -3
  162. numba_cuda/numba/cuda/tests/cudapy/test_minmax.py +19 -58
  163. numba_cuda/numba/cuda/tests/cudapy/test_montecarlo.py +4 -4
  164. numba_cuda/numba/cuda/tests/cudapy/test_multigpu.py +9 -7
  165. numba_cuda/numba/cuda/tests/cudapy/test_multiprocessing.py +9 -8
  166. numba_cuda/numba/cuda/tests/cudapy/test_multithreads.py +12 -10
  167. numba_cuda/numba/cuda/tests/cudapy/test_nondet.py +1 -1
  168. numba_cuda/numba/cuda/tests/cudapy/test_operator.py +180 -96
  169. numba_cuda/numba/cuda/tests/cudapy/test_optimization.py +5 -5
  170. numba_cuda/numba/cuda/tests/cudapy/test_overload.py +37 -18
  171. numba_cuda/numba/cuda/tests/cudapy/test_powi.py +7 -7
  172. numba_cuda/numba/cuda/tests/cudapy/test_print.py +9 -7
  173. numba_cuda/numba/cuda/tests/cudapy/test_py2_div_issue.py +1 -1
  174. numba_cuda/numba/cuda/tests/cudapy/test_random.py +15 -10
  175. numba_cuda/numba/cuda/tests/cudapy/test_record_dtype.py +88 -87
  176. numba_cuda/numba/cuda/tests/cudapy/test_recursion.py +12 -10
  177. numba_cuda/numba/cuda/tests/cudapy/test_reduction.py +26 -11
  178. numba_cuda/numba/cuda/tests/cudapy/test_retrieve_autoconverted_arrays.py +7 -10
  179. numba_cuda/numba/cuda/tests/cudapy/test_serialize.py +4 -6
  180. numba_cuda/numba/cuda/tests/cudapy/test_slicing.py +1 -1
  181. numba_cuda/numba/cuda/tests/cudapy/test_sm.py +10 -9
  182. numba_cuda/numba/cuda/tests/cudapy/test_sm_creation.py +62 -43
  183. numba_cuda/numba/cuda/tests/cudapy/test_stream_api.py +7 -3
  184. numba_cuda/numba/cuda/tests/cudapy/test_sync.py +7 -5
  185. numba_cuda/numba/cuda/tests/cudapy/test_transpose.py +18 -11
  186. numba_cuda/numba/cuda/tests/cudapy/test_ufuncs.py +111 -88
  187. numba_cuda/numba/cuda/tests/cudapy/test_userexc.py +2 -3
  188. numba_cuda/numba/cuda/tests/cudapy/test_vector_type.py +305 -130
  189. numba_cuda/numba/cuda/tests/cudapy/test_vectorize.py +33 -36
  190. numba_cuda/numba/cuda/tests/cudapy/test_vectorize_complex.py +5 -5
  191. numba_cuda/numba/cuda/tests/cudapy/test_vectorize_decor.py +16 -12
  192. numba_cuda/numba/cuda/tests/cudapy/test_vectorize_device.py +7 -7
  193. numba_cuda/numba/cuda/tests/cudapy/test_vectorize_scalar_arg.py +6 -7
  194. numba_cuda/numba/cuda/tests/cudapy/test_warning.py +31 -29
  195. numba_cuda/numba/cuda/tests/cudapy/test_warp_ops.py +31 -25
  196. numba_cuda/numba/cuda/tests/cudasim/test_cudasim_issues.py +19 -13
  197. numba_cuda/numba/cuda/tests/data/jitlink.cu +1 -1
  198. numba_cuda/numba/cuda/tests/data/jitlink.ptx +0 -2
  199. numba_cuda/numba/cuda/tests/doc_examples/test_cg.py +15 -8
  200. numba_cuda/numba/cuda/tests/doc_examples/test_cpu_gpu_compat.py +4 -7
  201. numba_cuda/numba/cuda/tests/doc_examples/test_ffi.py +14 -9
  202. numba_cuda/numba/cuda/tests/doc_examples/test_laplace.py +22 -18
  203. numba_cuda/numba/cuda/tests/doc_examples/test_matmul.py +7 -4
  204. numba_cuda/numba/cuda/tests/doc_examples/test_montecarlo.py +2 -0
  205. numba_cuda/numba/cuda/tests/doc_examples/test_random.py +8 -4
  206. numba_cuda/numba/cuda/tests/doc_examples/test_reduction.py +2 -1
  207. numba_cuda/numba/cuda/tests/doc_examples/test_sessionize.py +94 -19
  208. numba_cuda/numba/cuda/tests/doc_examples/test_vecadd.py +2 -2
  209. numba_cuda/numba/cuda/tests/nocuda/test_dummyarray.py +91 -62
  210. numba_cuda/numba/cuda/tests/nocuda/test_function_resolution.py +14 -5
  211. numba_cuda/numba/cuda/tests/nocuda/test_import.py +25 -25
  212. numba_cuda/numba/cuda/tests/nocuda/test_library_lookup.py +40 -40
  213. numba_cuda/numba/cuda/tests/nocuda/test_nvvm.py +12 -10
  214. numba_cuda/numba/cuda/tests/nrt/test_nrt.py +16 -20
  215. numba_cuda/numba/cuda/tests/nrt/test_nrt_refct.py +12 -10
  216. numba_cuda/numba/cuda/tests/test_binary_generation/generate_raw_ltoir.py +2 -2
  217. numba_cuda/numba/cuda/types.py +5 -2
  218. numba_cuda/numba/cuda/ufuncs.py +382 -362
  219. numba_cuda/numba/cuda/utils.py +2 -2
  220. numba_cuda/numba/cuda/vector_types.py +2 -2
  221. numba_cuda/numba/cuda/vectorizers.py +37 -32
  222. {numba_cuda-0.8.0.dist-info → numba_cuda-0.9.0.dist-info}/METADATA +1 -1
  223. numba_cuda-0.9.0.dist-info/RECORD +253 -0
  224. {numba_cuda-0.8.0.dist-info → numba_cuda-0.9.0.dist-info}/WHEEL +1 -1
  225. numba_cuda-0.8.0.dist-info/RECORD +0 -251
  226. {numba_cuda-0.8.0.dist-info → numba_cuda-0.9.0.dist-info}/licenses/LICENSE +0 -0
  227. {numba_cuda-0.8.0.dist-info → numba_cuda-0.9.0.dist-info}/top_level.txt +0 -0
@@ -29,48 +29,49 @@ def initialize_dim3(builder, prefix):
29
29
  return cgutils.pack_struct(builder, (x, y, z))
30
30
 
31
31
 
32
- @lower_attr(types.Module(cuda), 'threadIdx')
32
+ @lower_attr(types.Module(cuda), "threadIdx")
33
33
  def cuda_threadIdx(context, builder, sig, args):
34
- return initialize_dim3(builder, 'tid')
34
+ return initialize_dim3(builder, "tid")
35
35
 
36
36
 
37
- @lower_attr(types.Module(cuda), 'blockDim')
37
+ @lower_attr(types.Module(cuda), "blockDim")
38
38
  def cuda_blockDim(context, builder, sig, args):
39
- return initialize_dim3(builder, 'ntid')
39
+ return initialize_dim3(builder, "ntid")
40
40
 
41
41
 
42
- @lower_attr(types.Module(cuda), 'blockIdx')
42
+ @lower_attr(types.Module(cuda), "blockIdx")
43
43
  def cuda_blockIdx(context, builder, sig, args):
44
- return initialize_dim3(builder, 'ctaid')
44
+ return initialize_dim3(builder, "ctaid")
45
45
 
46
46
 
47
- @lower_attr(types.Module(cuda), 'gridDim')
47
+ @lower_attr(types.Module(cuda), "gridDim")
48
48
  def cuda_gridDim(context, builder, sig, args):
49
- return initialize_dim3(builder, 'nctaid')
49
+ return initialize_dim3(builder, "nctaid")
50
50
 
51
51
 
52
- @lower_attr(types.Module(cuda), 'laneid')
52
+ @lower_attr(types.Module(cuda), "laneid")
53
53
  def cuda_laneid(context, builder, sig, args):
54
- return nvvmutils.call_sreg(builder, 'laneid')
54
+ return nvvmutils.call_sreg(builder, "laneid")
55
55
 
56
56
 
57
- @lower_attr(dim3, 'x')
57
+ @lower_attr(dim3, "x")
58
58
  def dim3_x(context, builder, sig, args):
59
59
  return builder.extract_value(args, 0)
60
60
 
61
61
 
62
- @lower_attr(dim3, 'y')
62
+ @lower_attr(dim3, "y")
63
63
  def dim3_y(context, builder, sig, args):
64
64
  return builder.extract_value(args, 1)
65
65
 
66
66
 
67
- @lower_attr(dim3, 'z')
67
+ @lower_attr(dim3, "z")
68
68
  def dim3_z(context, builder, sig, args):
69
69
  return builder.extract_value(args, 2)
70
70
 
71
71
 
72
72
  # -----------------------------------------------------------------------------
73
73
 
74
+
74
75
  @lower(cuda.const.array_like, types.Array)
75
76
  def cuda_const_array_like(context, builder, sig, args):
76
77
  # This is a no-op because CUDATargetContext.make_constant_array already
@@ -95,48 +96,68 @@ def _get_unique_smem_id(name):
95
96
  def cuda_shared_array_integer(context, builder, sig, args):
96
97
  length = sig.args[0].literal_value
97
98
  dtype = parse_dtype(sig.args[1])
98
- return _generic_array(context, builder, shape=(length,), dtype=dtype,
99
- symbol_name=_get_unique_smem_id('_cudapy_smem'),
100
- addrspace=nvvm.ADDRSPACE_SHARED,
101
- can_dynsized=True)
99
+ return _generic_array(
100
+ context,
101
+ builder,
102
+ shape=(length,),
103
+ dtype=dtype,
104
+ symbol_name=_get_unique_smem_id("_cudapy_smem"),
105
+ addrspace=nvvm.ADDRSPACE_SHARED,
106
+ can_dynsized=True,
107
+ )
102
108
 
103
109
 
104
110
  @lower(cuda.shared.array, types.Tuple, types.Any)
105
111
  @lower(cuda.shared.array, types.UniTuple, types.Any)
106
112
  def cuda_shared_array_tuple(context, builder, sig, args):
107
- shape = [ s.literal_value for s in sig.args[0] ]
113
+ shape = [s.literal_value for s in sig.args[0]]
108
114
  dtype = parse_dtype(sig.args[1])
109
- return _generic_array(context, builder, shape=shape, dtype=dtype,
110
- symbol_name=_get_unique_smem_id('_cudapy_smem'),
111
- addrspace=nvvm.ADDRSPACE_SHARED,
112
- can_dynsized=True)
115
+ return _generic_array(
116
+ context,
117
+ builder,
118
+ shape=shape,
119
+ dtype=dtype,
120
+ symbol_name=_get_unique_smem_id("_cudapy_smem"),
121
+ addrspace=nvvm.ADDRSPACE_SHARED,
122
+ can_dynsized=True,
123
+ )
113
124
 
114
125
 
115
126
  @lower(cuda.local.array, types.IntegerLiteral, types.Any)
116
127
  def cuda_local_array_integer(context, builder, sig, args):
117
128
  length = sig.args[0].literal_value
118
129
  dtype = parse_dtype(sig.args[1])
119
- return _generic_array(context, builder, shape=(length,), dtype=dtype,
120
- symbol_name='_cudapy_lmem',
121
- addrspace=nvvm.ADDRSPACE_LOCAL,
122
- can_dynsized=False)
130
+ return _generic_array(
131
+ context,
132
+ builder,
133
+ shape=(length,),
134
+ dtype=dtype,
135
+ symbol_name="_cudapy_lmem",
136
+ addrspace=nvvm.ADDRSPACE_LOCAL,
137
+ can_dynsized=False,
138
+ )
123
139
 
124
140
 
125
141
  @lower(cuda.local.array, types.Tuple, types.Any)
126
142
  @lower(cuda.local.array, types.UniTuple, types.Any)
127
143
  def ptx_lmem_alloc_array(context, builder, sig, args):
128
- shape = [ s.literal_value for s in sig.args[0] ]
144
+ shape = [s.literal_value for s in sig.args[0]]
129
145
  dtype = parse_dtype(sig.args[1])
130
- return _generic_array(context, builder, shape=shape, dtype=dtype,
131
- symbol_name='_cudapy_lmem',
132
- addrspace=nvvm.ADDRSPACE_LOCAL,
133
- can_dynsized=False)
146
+ return _generic_array(
147
+ context,
148
+ builder,
149
+ shape=shape,
150
+ dtype=dtype,
151
+ symbol_name="_cudapy_lmem",
152
+ addrspace=nvvm.ADDRSPACE_LOCAL,
153
+ can_dynsized=False,
154
+ )
134
155
 
135
156
 
136
157
  @lower(stubs.threadfence_block)
137
158
  def ptx_threadfence_block(context, builder, sig, args):
138
159
  assert not args
139
- fname = 'llvm.nvvm.membar.cta'
160
+ fname = "llvm.nvvm.membar.cta"
140
161
  lmod = builder.module
141
162
  fnty = ir.FunctionType(ir.VoidType(), ())
142
163
  sync = cgutils.get_or_insert_function(lmod, fnty, fname)
@@ -147,7 +168,7 @@ def ptx_threadfence_block(context, builder, sig, args):
147
168
  @lower(stubs.threadfence_system)
148
169
  def ptx_threadfence_system(context, builder, sig, args):
149
170
  assert not args
150
- fname = 'llvm.nvvm.membar.sys'
171
+ fname = "llvm.nvvm.membar.sys"
151
172
  lmod = builder.module
152
173
  fnty = ir.FunctionType(ir.VoidType(), ())
153
174
  sync = cgutils.get_or_insert_function(lmod, fnty, fname)
@@ -158,7 +179,7 @@ def ptx_threadfence_system(context, builder, sig, args):
158
179
  @lower(stubs.threadfence)
159
180
  def ptx_threadfence_device(context, builder, sig, args):
160
181
  assert not args
161
- fname = 'llvm.nvvm.membar.gl'
182
+ fname = "llvm.nvvm.membar.gl"
162
183
  lmod = builder.module
163
184
  fnty = ir.FunctionType(ir.VoidType(), ())
164
185
  sync = cgutils.get_or_insert_function(lmod, fnty, fname)
@@ -175,7 +196,7 @@ def ptx_syncwarp(context, builder, sig, args):
175
196
 
176
197
  @lower(stubs.syncwarp, types.i4)
177
198
  def ptx_syncwarp_mask(context, builder, sig, args):
178
- fname = 'llvm.nvvm.bar.warp.sync'
199
+ fname = "llvm.nvvm.bar.warp.sync"
179
200
  lmod = builder.module
180
201
  fnty = ir.FunctionType(ir.VoidType(), (ir.IntType(32),))
181
202
  sync = cgutils.get_or_insert_function(lmod, fnty, fname)
@@ -183,14 +204,18 @@ def ptx_syncwarp_mask(context, builder, sig, args):
183
204
  return context.get_dummy_value()
184
205
 
185
206
 
186
- @lower(stubs.shfl_sync_intrinsic, types.i4, types.i4, types.i4, types.i4,
187
- types.i4)
188
- @lower(stubs.shfl_sync_intrinsic, types.i4, types.i4, types.i8, types.i4,
189
- types.i4)
190
- @lower(stubs.shfl_sync_intrinsic, types.i4, types.i4, types.f4, types.i4,
191
- types.i4)
192
- @lower(stubs.shfl_sync_intrinsic, types.i4, types.i4, types.f8, types.i4,
193
- types.i4)
207
+ @lower(
208
+ stubs.shfl_sync_intrinsic, types.i4, types.i4, types.i4, types.i4, types.i4
209
+ )
210
+ @lower(
211
+ stubs.shfl_sync_intrinsic, types.i4, types.i4, types.i8, types.i4, types.i4
212
+ )
213
+ @lower(
214
+ stubs.shfl_sync_intrinsic, types.i4, types.i4, types.f4, types.i4, types.i4
215
+ )
216
+ @lower(
217
+ stubs.shfl_sync_intrinsic, types.i4, types.i4, types.f8, types.i4, types.i4
218
+ )
194
219
  def ptx_shfl_sync_i32(context, builder, sig, args):
195
220
  """
196
221
  The NVVM intrinsic for shfl only supports i32, but the cuda intrinsic
@@ -203,12 +228,17 @@ def ptx_shfl_sync_i32(context, builder, sig, args):
203
228
  value_type = sig.args[2]
204
229
  if value_type in types.real_domain:
205
230
  value = builder.bitcast(value, ir.IntType(value_type.bitwidth))
206
- fname = 'llvm.nvvm.shfl.sync.i32'
231
+ fname = "llvm.nvvm.shfl.sync.i32"
207
232
  lmod = builder.module
208
233
  fnty = ir.FunctionType(
209
234
  ir.LiteralStructType((ir.IntType(32), ir.IntType(1))),
210
- (ir.IntType(32), ir.IntType(32), ir.IntType(32),
211
- ir.IntType(32), ir.IntType(32))
235
+ (
236
+ ir.IntType(32),
237
+ ir.IntType(32),
238
+ ir.IntType(32),
239
+ ir.IntType(32),
240
+ ir.IntType(32),
241
+ ),
212
242
  )
213
243
  func = cgutils.get_or_insert_function(lmod, fnty, fname)
214
244
  if value_type.bitwidth == 32:
@@ -239,11 +269,12 @@ def ptx_shfl_sync_i32(context, builder, sig, args):
239
269
 
240
270
  @lower(stubs.vote_sync_intrinsic, types.i4, types.i4, types.boolean)
241
271
  def ptx_vote_sync(context, builder, sig, args):
242
- fname = 'llvm.nvvm.vote.sync'
272
+ fname = "llvm.nvvm.vote.sync"
243
273
  lmod = builder.module
244
- fnty = ir.FunctionType(ir.LiteralStructType((ir.IntType(32),
245
- ir.IntType(1))),
246
- (ir.IntType(32), ir.IntType(32), ir.IntType(1)))
274
+ fnty = ir.FunctionType(
275
+ ir.LiteralStructType((ir.IntType(32), ir.IntType(1))),
276
+ (ir.IntType(32), ir.IntType(32), ir.IntType(1)),
277
+ )
247
278
  func = cgutils.get_or_insert_function(lmod, fnty, fname)
248
279
  return builder.call(func, args)
249
280
 
@@ -257,7 +288,7 @@ def ptx_match_any_sync(context, builder, sig, args):
257
288
  width = sig.args[1].bitwidth
258
289
  if sig.args[1] in types.real_domain:
259
290
  value = builder.bitcast(value, ir.IntType(width))
260
- fname = 'llvm.nvvm.match.any.sync.i{}'.format(width)
291
+ fname = "llvm.nvvm.match.any.sync.i{}".format(width)
261
292
  lmod = builder.module
262
293
  fnty = ir.FunctionType(ir.IntType(32), (ir.IntType(32), ir.IntType(width)))
263
294
  func = cgutils.get_or_insert_function(lmod, fnty, fname)
@@ -273,27 +304,35 @@ def ptx_match_all_sync(context, builder, sig, args):
273
304
  width = sig.args[1].bitwidth
274
305
  if sig.args[1] in types.real_domain:
275
306
  value = builder.bitcast(value, ir.IntType(width))
276
- fname = 'llvm.nvvm.match.all.sync.i{}'.format(width)
307
+ fname = "llvm.nvvm.match.all.sync.i{}".format(width)
277
308
  lmod = builder.module
278
- fnty = ir.FunctionType(ir.LiteralStructType((ir.IntType(32),
279
- ir.IntType(1))),
280
- (ir.IntType(32), ir.IntType(width)))
309
+ fnty = ir.FunctionType(
310
+ ir.LiteralStructType((ir.IntType(32), ir.IntType(1))),
311
+ (ir.IntType(32), ir.IntType(width)),
312
+ )
281
313
  func = cgutils.get_or_insert_function(lmod, fnty, fname)
282
314
  return builder.call(func, (mask, value))
283
315
 
284
316
 
285
317
  @lower(stubs.activemask)
286
318
  def ptx_activemask(context, builder, sig, args):
287
- activemask = ir.InlineAsm(ir.FunctionType(ir.IntType(32), []),
288
- "activemask.b32 $0;", '=r', side_effect=True)
319
+ activemask = ir.InlineAsm(
320
+ ir.FunctionType(ir.IntType(32), []),
321
+ "activemask.b32 $0;",
322
+ "=r",
323
+ side_effect=True,
324
+ )
289
325
  return builder.call(activemask, [])
290
326
 
291
327
 
292
328
  @lower(stubs.lanemask_lt)
293
329
  def ptx_lanemask_lt(context, builder, sig, args):
294
- activemask = ir.InlineAsm(ir.FunctionType(ir.IntType(32), []),
295
- "mov.u32 $0, %lanemask_lt;", '=r',
296
- side_effect=True)
330
+ activemask = ir.InlineAsm(
331
+ ir.FunctionType(ir.IntType(32), []),
332
+ "mov.u32 $0, %lanemask_lt;",
333
+ "=r",
334
+ side_effect=True,
335
+ )
297
336
  return builder.call(activemask, [])
298
337
 
299
338
 
@@ -308,7 +347,7 @@ def ptx_fma(context, builder, sig, args):
308
347
 
309
348
 
310
349
  def float16_float_ty_constraint(bitwidth):
311
- typemap = {32: ('f32', 'f'), 64: ('f64', 'd')}
350
+ typemap = {32: ("f32", "f"), 64: ("f64", "d")}
312
351
 
313
352
  try:
314
353
  return typemap[bitwidth]
@@ -342,7 +381,7 @@ def float_to_float16_cast(context, builder, fromty, toty, val):
342
381
 
343
382
 
344
383
  def float16_int_constraint(bitwidth):
345
- typemap = { 8: 'c', 16: 'h', 32: 'r', 64: 'l' }
384
+ typemap = {8: "c", 16: "h", 32: "r", 64: "l"}
346
385
 
347
386
  try:
348
387
  return typemap[bitwidth]
@@ -355,12 +394,12 @@ def float16_int_constraint(bitwidth):
355
394
  def float16_to_integer_cast(context, builder, fromty, toty, val):
356
395
  bitwidth = toty.bitwidth
357
396
  constraint = float16_int_constraint(bitwidth)
358
- signedness = 's' if toty.signed else 'u'
397
+ signedness = "s" if toty.signed else "u"
359
398
 
360
399
  fnty = ir.FunctionType(context.get_value_type(toty), [ir.IntType(16)])
361
- asm = ir.InlineAsm(fnty,
362
- f"cvt.rni.{signedness}{bitwidth}.f16 $0, $1;",
363
- f"={constraint},h")
400
+ asm = ir.InlineAsm(
401
+ fnty, f"cvt.rni.{signedness}{bitwidth}.f16 $0, $1;", f"={constraint},h"
402
+ )
364
403
  return builder.call(asm, [val])
365
404
 
366
405
 
@@ -369,40 +408,38 @@ def float16_to_integer_cast(context, builder, fromty, toty, val):
369
408
  def integer_to_float16_cast(context, builder, fromty, toty, val):
370
409
  bitwidth = fromty.bitwidth
371
410
  constraint = float16_int_constraint(bitwidth)
372
- signedness = 's' if fromty.signed else 'u'
411
+ signedness = "s" if fromty.signed else "u"
373
412
 
374
- fnty = ir.FunctionType(ir.IntType(16),
375
- [context.get_value_type(fromty)])
376
- asm = ir.InlineAsm(fnty,
377
- f"cvt.rn.f16.{signedness}{bitwidth} $0, $1;",
378
- f"=h,{constraint}")
413
+ fnty = ir.FunctionType(ir.IntType(16), [context.get_value_type(fromty)])
414
+ asm = ir.InlineAsm(
415
+ fnty, f"cvt.rn.f16.{signedness}{bitwidth} $0, $1;", f"=h,{constraint}"
416
+ )
379
417
  return builder.call(asm, [val])
380
418
 
381
419
 
382
420
  def lower_fp16_binary(fn, op):
383
421
  @lower(fn, types.float16, types.float16)
384
422
  def ptx_fp16_binary(context, builder, sig, args):
385
- fnty = ir.FunctionType(ir.IntType(16),
386
- [ir.IntType(16), ir.IntType(16)])
387
- asm = ir.InlineAsm(fnty, f'{op}.f16 $0,$1,$2;', '=h,h,h')
423
+ fnty = ir.FunctionType(ir.IntType(16), [ir.IntType(16), ir.IntType(16)])
424
+ asm = ir.InlineAsm(fnty, f"{op}.f16 $0,$1,$2;", "=h,h,h")
388
425
  return builder.call(asm, args)
389
426
 
390
427
 
391
- lower_fp16_binary(stubs.fp16.hadd, 'add')
392
- lower_fp16_binary(operator.add, 'add')
393
- lower_fp16_binary(operator.iadd, 'add')
394
- lower_fp16_binary(stubs.fp16.hsub, 'sub')
395
- lower_fp16_binary(operator.sub, 'sub')
396
- lower_fp16_binary(operator.isub, 'sub')
397
- lower_fp16_binary(stubs.fp16.hmul, 'mul')
398
- lower_fp16_binary(operator.mul, 'mul')
399
- lower_fp16_binary(operator.imul, 'mul')
428
+ lower_fp16_binary(stubs.fp16.hadd, "add")
429
+ lower_fp16_binary(operator.add, "add")
430
+ lower_fp16_binary(operator.iadd, "add")
431
+ lower_fp16_binary(stubs.fp16.hsub, "sub")
432
+ lower_fp16_binary(operator.sub, "sub")
433
+ lower_fp16_binary(operator.isub, "sub")
434
+ lower_fp16_binary(stubs.fp16.hmul, "mul")
435
+ lower_fp16_binary(operator.mul, "mul")
436
+ lower_fp16_binary(operator.imul, "mul")
400
437
 
401
438
 
402
439
  @lower(stubs.fp16.hneg, types.float16)
403
440
  def ptx_fp16_hneg(context, builder, sig, args):
404
441
  fnty = ir.FunctionType(ir.IntType(16), [ir.IntType(16)])
405
- asm = ir.InlineAsm(fnty, 'neg.f16 $0, $1;', '=h,h')
442
+ asm = ir.InlineAsm(fnty, "neg.f16 $0, $1;", "=h,h")
406
443
  return builder.call(asm, args)
407
444
 
408
445
 
@@ -414,7 +451,7 @@ def operator_hneg(context, builder, sig, args):
414
451
  @lower(stubs.fp16.habs, types.float16)
415
452
  def ptx_fp16_habs(context, builder, sig, args):
416
453
  fnty = ir.FunctionType(ir.IntType(16), [ir.IntType(16)])
417
- asm = ir.InlineAsm(fnty, 'abs.f16 $0, $1;', '=h,h')
454
+ asm = ir.InlineAsm(fnty, "abs.f16 $0, $1;", "=h,h")
418
455
  return builder.call(asm, args)
419
456
 
420
457
 
@@ -450,27 +487,28 @@ _fp16_cmp = """{{
450
487
  def _gen_fp16_cmp(op):
451
488
  def ptx_fp16_comparison(context, builder, sig, args):
452
489
  fnty = ir.FunctionType(ir.IntType(16), [ir.IntType(16), ir.IntType(16)])
453
- asm = ir.InlineAsm(fnty, _fp16_cmp.format(op=op), '=h,h,h')
490
+ asm = ir.InlineAsm(fnty, _fp16_cmp.format(op=op), "=h,h,h")
454
491
  result = builder.call(asm, args)
455
492
 
456
493
  zero = context.get_constant(types.int16, 0)
457
494
  int_result = builder.bitcast(result, ir.IntType(16))
458
495
  return builder.icmp_unsigned("!=", int_result, zero)
496
+
459
497
  return ptx_fp16_comparison
460
498
 
461
499
 
462
- lower(stubs.fp16.heq, types.float16, types.float16)(_gen_fp16_cmp('eq'))
463
- lower(operator.eq, types.float16, types.float16)(_gen_fp16_cmp('eq'))
464
- lower(stubs.fp16.hne, types.float16, types.float16)(_gen_fp16_cmp('ne'))
465
- lower(operator.ne, types.float16, types.float16)(_gen_fp16_cmp('ne'))
466
- lower(stubs.fp16.hge, types.float16, types.float16)(_gen_fp16_cmp('ge'))
467
- lower(operator.ge, types.float16, types.float16)(_gen_fp16_cmp('ge'))
468
- lower(stubs.fp16.hgt, types.float16, types.float16)(_gen_fp16_cmp('gt'))
469
- lower(operator.gt, types.float16, types.float16)(_gen_fp16_cmp('gt'))
470
- lower(stubs.fp16.hle, types.float16, types.float16)(_gen_fp16_cmp('le'))
471
- lower(operator.le, types.float16, types.float16)(_gen_fp16_cmp('le'))
472
- lower(stubs.fp16.hlt, types.float16, types.float16)(_gen_fp16_cmp('lt'))
473
- lower(operator.lt, types.float16, types.float16)(_gen_fp16_cmp('lt'))
500
+ lower(stubs.fp16.heq, types.float16, types.float16)(_gen_fp16_cmp("eq"))
501
+ lower(operator.eq, types.float16, types.float16)(_gen_fp16_cmp("eq"))
502
+ lower(stubs.fp16.hne, types.float16, types.float16)(_gen_fp16_cmp("ne"))
503
+ lower(operator.ne, types.float16, types.float16)(_gen_fp16_cmp("ne"))
504
+ lower(stubs.fp16.hge, types.float16, types.float16)(_gen_fp16_cmp("ge"))
505
+ lower(operator.ge, types.float16, types.float16)(_gen_fp16_cmp("ge"))
506
+ lower(stubs.fp16.hgt, types.float16, types.float16)(_gen_fp16_cmp("gt"))
507
+ lower(operator.gt, types.float16, types.float16)(_gen_fp16_cmp("gt"))
508
+ lower(stubs.fp16.hle, types.float16, types.float16)(_gen_fp16_cmp("le"))
509
+ lower(operator.le, types.float16, types.float16)(_gen_fp16_cmp("le"))
510
+ lower(stubs.fp16.hlt, types.float16, types.float16)(_gen_fp16_cmp("lt"))
511
+ lower(operator.lt, types.float16, types.float16)(_gen_fp16_cmp("lt"))
474
512
 
475
513
 
476
514
  def lower_fp16_minmax(fn, fname, op):
@@ -480,8 +518,8 @@ def lower_fp16_minmax(fn, fname, op):
480
518
  return builder.select(choice, args[0], args[1])
481
519
 
482
520
 
483
- lower_fp16_minmax(stubs.fp16.hmax, 'max', 'gt')
484
- lower_fp16_minmax(stubs.fp16.hmin, 'min', 'lt')
521
+ lower_fp16_minmax(stubs.fp16.hmax, "max", "gt")
522
+ lower_fp16_minmax(stubs.fp16.hmin, "min", "lt")
485
523
 
486
524
  # See:
487
525
  # https://docs.nvidia.com/cuda/libdevice-users-guide/__nv_cbrt.html#__nv_cbrt
@@ -489,8 +527,8 @@ lower_fp16_minmax(stubs.fp16.hmin, 'min', 'lt')
489
527
 
490
528
 
491
529
  cbrt_funcs = {
492
- types.float32: '__nv_cbrtf',
493
- types.float64: '__nv_cbrt',
530
+ types.float32: "__nv_cbrtf",
531
+ types.float64: "__nv_cbrt",
494
532
  }
495
533
 
496
534
 
@@ -514,7 +552,8 @@ def ptx_brev_u4(context, builder, sig, args):
514
552
  fn = cgutils.get_or_insert_function(
515
553
  builder.module,
516
554
  ir.FunctionType(ir.IntType(32), (ir.IntType(32),)),
517
- '__nv_brev')
555
+ "__nv_brev",
556
+ )
518
557
  return builder.call(fn, args)
519
558
 
520
559
 
@@ -526,15 +565,14 @@ def ptx_brev_u8(context, builder, sig, args):
526
565
  fn = cgutils.get_or_insert_function(
527
566
  builder.module,
528
567
  ir.FunctionType(ir.IntType(64), (ir.IntType(64),)),
529
- '__nv_brevll')
568
+ "__nv_brevll",
569
+ )
530
570
  return builder.call(fn, args)
531
571
 
532
572
 
533
573
  @lower(stubs.clz, types.Any)
534
574
  def ptx_clz(context, builder, sig, args):
535
- return builder.ctlz(
536
- args[0],
537
- context.get_constant(types.boolean, 0))
575
+ return builder.ctlz(args[0], context.get_constant(types.boolean, 0))
538
576
 
539
577
 
540
578
  @lower(stubs.ffs, types.i4)
@@ -543,7 +581,8 @@ def ptx_ffs_32(context, builder, sig, args):
543
581
  fn = cgutils.get_or_insert_function(
544
582
  builder.module,
545
583
  ir.FunctionType(ir.IntType(32), (ir.IntType(32),)),
546
- '__nv_ffs')
584
+ "__nv_ffs",
585
+ )
547
586
  return builder.call(fn, args)
548
587
 
549
588
 
@@ -553,7 +592,8 @@ def ptx_ffs_64(context, builder, sig, args):
553
592
  fn = cgutils.get_or_insert_function(
554
593
  builder.module,
555
594
  ir.FunctionType(ir.IntType(32), (ir.IntType(64),)),
556
- '__nv_ffsll')
595
+ "__nv_ffsll",
596
+ )
557
597
  return builder.call(fn, args)
558
598
 
559
599
 
@@ -567,10 +607,9 @@ def ptx_selp(context, builder, sig, args):
567
607
  def ptx_max_f4(context, builder, sig, args):
568
608
  fn = cgutils.get_or_insert_function(
569
609
  builder.module,
570
- ir.FunctionType(
571
- ir.FloatType(),
572
- (ir.FloatType(), ir.FloatType())),
573
- '__nv_fmaxf')
610
+ ir.FunctionType(ir.FloatType(), (ir.FloatType(), ir.FloatType())),
611
+ "__nv_fmaxf",
612
+ )
574
613
  return builder.call(fn, args)
575
614
 
576
615
 
@@ -580,25 +619,26 @@ def ptx_max_f4(context, builder, sig, args):
580
619
  def ptx_max_f8(context, builder, sig, args):
581
620
  fn = cgutils.get_or_insert_function(
582
621
  builder.module,
583
- ir.FunctionType(
584
- ir.DoubleType(),
585
- (ir.DoubleType(), ir.DoubleType())),
586
- '__nv_fmax')
622
+ ir.FunctionType(ir.DoubleType(), (ir.DoubleType(), ir.DoubleType())),
623
+ "__nv_fmax",
624
+ )
587
625
 
588
- return builder.call(fn, [
589
- context.cast(builder, args[0], sig.args[0], types.double),
590
- context.cast(builder, args[1], sig.args[1], types.double),
591
- ])
626
+ return builder.call(
627
+ fn,
628
+ [
629
+ context.cast(builder, args[0], sig.args[0], types.double),
630
+ context.cast(builder, args[1], sig.args[1], types.double),
631
+ ],
632
+ )
592
633
 
593
634
 
594
635
  @lower(min, types.f4, types.f4)
595
636
  def ptx_min_f4(context, builder, sig, args):
596
637
  fn = cgutils.get_or_insert_function(
597
638
  builder.module,
598
- ir.FunctionType(
599
- ir.FloatType(),
600
- (ir.FloatType(), ir.FloatType())),
601
- '__nv_fminf')
639
+ ir.FunctionType(ir.FloatType(), (ir.FloatType(), ir.FloatType())),
640
+ "__nv_fminf",
641
+ )
602
642
  return builder.call(fn, args)
603
643
 
604
644
 
@@ -608,15 +648,17 @@ def ptx_min_f4(context, builder, sig, args):
608
648
  def ptx_min_f8(context, builder, sig, args):
609
649
  fn = cgutils.get_or_insert_function(
610
650
  builder.module,
611
- ir.FunctionType(
612
- ir.DoubleType(),
613
- (ir.DoubleType(), ir.DoubleType())),
614
- '__nv_fmin')
651
+ ir.FunctionType(ir.DoubleType(), (ir.DoubleType(), ir.DoubleType())),
652
+ "__nv_fmin",
653
+ )
615
654
 
616
- return builder.call(fn, [
617
- context.cast(builder, args[0], sig.args[0], types.double),
618
- context.cast(builder, args[1], sig.args[1], types.double),
619
- ])
655
+ return builder.call(
656
+ fn,
657
+ [
658
+ context.cast(builder, args[0], sig.args[0], types.double),
659
+ context.cast(builder, args[1], sig.args[1], types.double),
660
+ ],
661
+ )
620
662
 
621
663
 
622
664
  @lower(round, types.f4)
@@ -624,19 +666,22 @@ def ptx_min_f8(context, builder, sig, args):
624
666
  def ptx_round(context, builder, sig, args):
625
667
  fn = cgutils.get_or_insert_function(
626
668
  builder.module,
627
- ir.FunctionType(
628
- ir.IntType(64),
629
- (ir.DoubleType(),)),
630
- '__nv_llrint')
631
- return builder.call(fn, [
632
- context.cast(builder, args[0], sig.args[0], types.double),
633
- ])
669
+ ir.FunctionType(ir.IntType(64), (ir.DoubleType(),)),
670
+ "__nv_llrint",
671
+ )
672
+ return builder.call(
673
+ fn,
674
+ [
675
+ context.cast(builder, args[0], sig.args[0], types.double),
676
+ ],
677
+ )
634
678
 
635
679
 
636
680
  # This rounding implementation follows the algorithm used in the "fallback
637
681
  # version" of double_round in CPython.
638
682
  # https://github.com/python/cpython/blob/a755410e054e1e2390de5830befc08fe80706c66/Objects/floatobject.c#L964-L1007
639
683
 
684
+
640
685
  @lower(round, types.f4, types.Integer)
641
686
  @lower(round, types.f8, types.Integer)
642
687
  def round_to_impl(context, builder, sig, args):
@@ -651,7 +696,7 @@ def round_to_impl(context, builder, sig, args):
651
696
  pow1 = 10.0 ** (ndigits - 22)
652
697
  pow2 = 1e22
653
698
  else:
654
- pow1 = 10.0 ** ndigits
699
+ pow1 = 10.0**ndigits
655
700
  pow2 = 1.0
656
701
  y = (x * pow1) * pow2
657
702
  if math.isinf(y):
@@ -662,7 +707,7 @@ def round_to_impl(context, builder, sig, args):
662
707
  y = x / pow1
663
708
 
664
709
  z = round(y)
665
- if (math.fabs(y - z) == 0.5):
710
+ if math.fabs(y - z) == 0.5:
666
711
  # halfway between two integers; use round-half-even
667
712
  z = 2.0 * round(y / 2.0)
668
713
 
@@ -673,19 +718,25 @@ def round_to_impl(context, builder, sig, args):
673
718
 
674
719
  return z
675
720
 
676
- return context.compile_internal(builder, round_ndigits, sig, args, )
721
+ return context.compile_internal(
722
+ builder,
723
+ round_ndigits,
724
+ sig,
725
+ args,
726
+ )
677
727
 
678
728
 
679
729
  def gen_deg_rad(const):
680
730
  def impl(context, builder, sig, args):
681
- argty, = sig.args
731
+ (argty,) = sig.args
682
732
  factor = context.get_constant(argty, const)
683
733
  return builder.fmul(factor, args[0])
734
+
684
735
  return impl
685
736
 
686
737
 
687
- _deg2rad = math.pi / 180.
688
- _rad2deg = 180. / math.pi
738
+ _deg2rad = math.pi / 180.0
739
+ _rad2deg = 180.0 / math.pi
689
740
  lower(math.radians, types.f4)(gen_deg_rad(_deg2rad))
690
741
  lower(math.radians, types.f8)(gen_deg_rad(_deg2rad))
691
742
  lower(math.degrees, types.f4)(gen_deg_rad(_rad2deg))
@@ -701,16 +752,18 @@ def _normalize_indices(context, builder, indty, inds, aryty, valty):
701
752
  indices = [inds]
702
753
  else:
703
754
  indices = cgutils.unpack_tuple(builder, inds, count=len(indty))
704
- indices = [context.cast(builder, i, t, types.intp)
705
- for t, i in zip(indty, indices)]
755
+ indices = [
756
+ context.cast(builder, i, t, types.intp) for t, i in zip(indty, indices)
757
+ ]
706
758
 
707
759
  dtype = aryty.dtype
708
760
  if dtype != valty:
709
761
  raise TypeError("expect %s but got %s" % (dtype, valty))
710
762
 
711
763
  if aryty.ndim != len(indty):
712
- raise TypeError("indexing %d-D array with %d-D index" %
713
- (aryty.ndim, len(indty)))
764
+ raise TypeError(
765
+ "indexing %d-D array with %d-D index" % (aryty.ndim, len(indty))
766
+ )
714
767
 
715
768
  return indty, indices
716
769
 
@@ -722,14 +775,17 @@ def _atomic_dispatcher(dispatch_fn):
722
775
  ary, inds, val = args
723
776
  dtype = aryty.dtype
724
777
 
725
- indty, indices = _normalize_indices(context, builder, indty, inds,
726
- aryty, valty)
778
+ indty, indices = _normalize_indices(
779
+ context, builder, indty, inds, aryty, valty
780
+ )
727
781
 
728
782
  lary = context.make_array(aryty)(context, builder, ary)
729
- ptr = cgutils.get_item_pointer(context, builder, aryty, lary, indices,
730
- wraparound=True)
783
+ ptr = cgutils.get_item_pointer(
784
+ context, builder, aryty, lary, indices, wraparound=True
785
+ )
731
786
  # dispatcher to implementation base on dtype
732
787
  return dispatch_fn(context, builder, dtype, ptr, val)
788
+
733
789
  return imp
734
790
 
735
791
 
@@ -740,14 +796,16 @@ def _atomic_dispatcher(dispatch_fn):
740
796
  def ptx_atomic_add_tuple(context, builder, dtype, ptr, val):
741
797
  if dtype == types.float32:
742
798
  lmod = builder.module
743
- return builder.call(nvvmutils.declare_atomic_add_float32(lmod),
744
- (ptr, val))
799
+ return builder.call(
800
+ nvvmutils.declare_atomic_add_float32(lmod), (ptr, val)
801
+ )
745
802
  elif dtype == types.float64:
746
803
  lmod = builder.module
747
- return builder.call(nvvmutils.declare_atomic_add_float64(lmod),
748
- (ptr, val))
804
+ return builder.call(
805
+ nvvmutils.declare_atomic_add_float64(lmod), (ptr, val)
806
+ )
749
807
  else:
750
- return builder.atomic_rmw('add', ptr, val, 'monotonic')
808
+ return builder.atomic_rmw("add", ptr, val, "monotonic")
751
809
 
752
810
 
753
811
  @lower(stubs.atomic.sub, types.Array, types.intp, types.Any)
@@ -757,14 +815,16 @@ def ptx_atomic_add_tuple(context, builder, dtype, ptr, val):
757
815
  def ptx_atomic_sub(context, builder, dtype, ptr, val):
758
816
  if dtype == types.float32:
759
817
  lmod = builder.module
760
- return builder.call(nvvmutils.declare_atomic_sub_float32(lmod),
761
- (ptr, val))
818
+ return builder.call(
819
+ nvvmutils.declare_atomic_sub_float32(lmod), (ptr, val)
820
+ )
762
821
  elif dtype == types.float64:
763
822
  lmod = builder.module
764
- return builder.call(nvvmutils.declare_atomic_sub_float64(lmod),
765
- (ptr, val))
823
+ return builder.call(
824
+ nvvmutils.declare_atomic_sub_float64(lmod), (ptr, val)
825
+ )
766
826
  else:
767
- return builder.atomic_rmw('sub', ptr, val, 'monotonic')
827
+ return builder.atomic_rmw("sub", ptr, val, "monotonic")
768
828
 
769
829
 
770
830
  @lower(stubs.atomic.inc, types.Array, types.intp, types.Any)
@@ -775,10 +835,10 @@ def ptx_atomic_inc(context, builder, dtype, ptr, val):
775
835
  if dtype in cuda.cudadecl.unsigned_int_numba_types:
776
836
  bw = dtype.bitwidth
777
837
  lmod = builder.module
778
- fn = getattr(nvvmutils, f'declare_atomic_inc_int{bw}')
838
+ fn = getattr(nvvmutils, f"declare_atomic_inc_int{bw}")
779
839
  return builder.call(fn(lmod), (ptr, val))
780
840
  else:
781
- raise TypeError(f'Unimplemented atomic inc with {dtype} array')
841
+ raise TypeError(f"Unimplemented atomic inc with {dtype} array")
782
842
 
783
843
 
784
844
  @lower(stubs.atomic.dec, types.Array, types.intp, types.Any)
@@ -789,27 +849,27 @@ def ptx_atomic_dec(context, builder, dtype, ptr, val):
789
849
  if dtype in cuda.cudadecl.unsigned_int_numba_types:
790
850
  bw = dtype.bitwidth
791
851
  lmod = builder.module
792
- fn = getattr(nvvmutils, f'declare_atomic_dec_int{bw}')
852
+ fn = getattr(nvvmutils, f"declare_atomic_dec_int{bw}")
793
853
  return builder.call(fn(lmod), (ptr, val))
794
854
  else:
795
- raise TypeError(f'Unimplemented atomic dec with {dtype} array')
855
+ raise TypeError(f"Unimplemented atomic dec with {dtype} array")
796
856
 
797
857
 
798
858
  def ptx_atomic_bitwise(stub, op):
799
859
  @_atomic_dispatcher
800
860
  def impl_ptx_atomic(context, builder, dtype, ptr, val):
801
861
  if dtype in (cuda.cudadecl.integer_numba_types):
802
- return builder.atomic_rmw(op, ptr, val, 'monotonic')
862
+ return builder.atomic_rmw(op, ptr, val, "monotonic")
803
863
  else:
804
- raise TypeError(f'Unimplemented atomic {op} with {dtype} array')
864
+ raise TypeError(f"Unimplemented atomic {op} with {dtype} array")
805
865
 
806
866
  for ty in (types.intp, types.UniTuple, types.Tuple):
807
867
  lower(stub, types.Array, ty, types.Any)(impl_ptx_atomic)
808
868
 
809
869
 
810
- ptx_atomic_bitwise(stubs.atomic.and_, 'and')
811
- ptx_atomic_bitwise(stubs.atomic.or_, 'or')
812
- ptx_atomic_bitwise(stubs.atomic.xor, 'xor')
870
+ ptx_atomic_bitwise(stubs.atomic.and_, "and")
871
+ ptx_atomic_bitwise(stubs.atomic.or_, "or")
872
+ ptx_atomic_bitwise(stubs.atomic.xor, "xor")
813
873
 
814
874
 
815
875
  @lower(stubs.atomic.exch, types.Array, types.intp, types.Any)
@@ -818,9 +878,9 @@ ptx_atomic_bitwise(stubs.atomic.xor, 'xor')
818
878
  @_atomic_dispatcher
819
879
  def ptx_atomic_exch(context, builder, dtype, ptr, val):
820
880
  if dtype in (cuda.cudadecl.integer_numba_types):
821
- return builder.atomic_rmw('xchg', ptr, val, 'monotonic')
881
+ return builder.atomic_rmw("xchg", ptr, val, "monotonic")
822
882
  else:
823
- raise TypeError(f'Unimplemented atomic exch with {dtype} array')
883
+ raise TypeError(f"Unimplemented atomic exch with {dtype} array")
824
884
 
825
885
 
826
886
  @lower(stubs.atomic.max, types.Array, types.intp, types.Any)
@@ -830,17 +890,19 @@ def ptx_atomic_exch(context, builder, dtype, ptr, val):
830
890
  def ptx_atomic_max(context, builder, dtype, ptr, val):
831
891
  lmod = builder.module
832
892
  if dtype == types.float64:
833
- return builder.call(nvvmutils.declare_atomic_max_float64(lmod),
834
- (ptr, val))
893
+ return builder.call(
894
+ nvvmutils.declare_atomic_max_float64(lmod), (ptr, val)
895
+ )
835
896
  elif dtype == types.float32:
836
- return builder.call(nvvmutils.declare_atomic_max_float32(lmod),
837
- (ptr, val))
897
+ return builder.call(
898
+ nvvmutils.declare_atomic_max_float32(lmod), (ptr, val)
899
+ )
838
900
  elif dtype in (types.int32, types.int64):
839
- return builder.atomic_rmw('max', ptr, val, ordering='monotonic')
901
+ return builder.atomic_rmw("max", ptr, val, ordering="monotonic")
840
902
  elif dtype in (types.uint32, types.uint64):
841
- return builder.atomic_rmw('umax', ptr, val, ordering='monotonic')
903
+ return builder.atomic_rmw("umax", ptr, val, ordering="monotonic")
842
904
  else:
843
- raise TypeError('Unimplemented atomic max with %s array' % dtype)
905
+ raise TypeError("Unimplemented atomic max with %s array" % dtype)
844
906
 
845
907
 
846
908
  @lower(stubs.atomic.min, types.Array, types.intp, types.Any)
@@ -850,17 +912,19 @@ def ptx_atomic_max(context, builder, dtype, ptr, val):
850
912
  def ptx_atomic_min(context, builder, dtype, ptr, val):
851
913
  lmod = builder.module
852
914
  if dtype == types.float64:
853
- return builder.call(nvvmutils.declare_atomic_min_float64(lmod),
854
- (ptr, val))
915
+ return builder.call(
916
+ nvvmutils.declare_atomic_min_float64(lmod), (ptr, val)
917
+ )
855
918
  elif dtype == types.float32:
856
- return builder.call(nvvmutils.declare_atomic_min_float32(lmod),
857
- (ptr, val))
919
+ return builder.call(
920
+ nvvmutils.declare_atomic_min_float32(lmod), (ptr, val)
921
+ )
858
922
  elif dtype in (types.int32, types.int64):
859
- return builder.atomic_rmw('min', ptr, val, ordering='monotonic')
923
+ return builder.atomic_rmw("min", ptr, val, ordering="monotonic")
860
924
  elif dtype in (types.uint32, types.uint64):
861
- return builder.atomic_rmw('umin', ptr, val, ordering='monotonic')
925
+ return builder.atomic_rmw("umin", ptr, val, ordering="monotonic")
862
926
  else:
863
- raise TypeError('Unimplemented atomic min with %s array' % dtype)
927
+ raise TypeError("Unimplemented atomic min with %s array" % dtype)
864
928
 
865
929
 
866
930
  @lower(stubs.atomic.nanmax, types.Array, types.intp, types.Any)
@@ -870,17 +934,19 @@ def ptx_atomic_min(context, builder, dtype, ptr, val):
870
934
  def ptx_atomic_nanmax(context, builder, dtype, ptr, val):
871
935
  lmod = builder.module
872
936
  if dtype == types.float64:
873
- return builder.call(nvvmutils.declare_atomic_nanmax_float64(lmod),
874
- (ptr, val))
937
+ return builder.call(
938
+ nvvmutils.declare_atomic_nanmax_float64(lmod), (ptr, val)
939
+ )
875
940
  elif dtype == types.float32:
876
- return builder.call(nvvmutils.declare_atomic_nanmax_float32(lmod),
877
- (ptr, val))
941
+ return builder.call(
942
+ nvvmutils.declare_atomic_nanmax_float32(lmod), (ptr, val)
943
+ )
878
944
  elif dtype in (types.int32, types.int64):
879
- return builder.atomic_rmw('max', ptr, val, ordering='monotonic')
945
+ return builder.atomic_rmw("max", ptr, val, ordering="monotonic")
880
946
  elif dtype in (types.uint32, types.uint64):
881
- return builder.atomic_rmw('umax', ptr, val, ordering='monotonic')
947
+ return builder.atomic_rmw("umax", ptr, val, ordering="monotonic")
882
948
  else:
883
- raise TypeError('Unimplemented atomic max with %s array' % dtype)
949
+ raise TypeError("Unimplemented atomic max with %s array" % dtype)
884
950
 
885
951
 
886
952
  @lower(stubs.atomic.nanmin, types.Array, types.intp, types.Any)
@@ -890,17 +956,19 @@ def ptx_atomic_nanmax(context, builder, dtype, ptr, val):
890
956
  def ptx_atomic_nanmin(context, builder, dtype, ptr, val):
891
957
  lmod = builder.module
892
958
  if dtype == types.float64:
893
- return builder.call(nvvmutils.declare_atomic_nanmin_float64(lmod),
894
- (ptr, val))
959
+ return builder.call(
960
+ nvvmutils.declare_atomic_nanmin_float64(lmod), (ptr, val)
961
+ )
895
962
  elif dtype == types.float32:
896
- return builder.call(nvvmutils.declare_atomic_nanmin_float32(lmod),
897
- (ptr, val))
963
+ return builder.call(
964
+ nvvmutils.declare_atomic_nanmin_float32(lmod), (ptr, val)
965
+ )
898
966
  elif dtype in (types.int32, types.int64):
899
- return builder.atomic_rmw('min', ptr, val, ordering='monotonic')
967
+ return builder.atomic_rmw("min", ptr, val, ordering="monotonic")
900
968
  elif dtype in (types.uint32, types.uint64):
901
- return builder.atomic_rmw('umin', ptr, val, ordering='monotonic')
969
+ return builder.atomic_rmw("umin", ptr, val, ordering="monotonic")
902
970
  else:
903
- raise TypeError('Unimplemented atomic min with %s array' % dtype)
971
+ raise TypeError("Unimplemented atomic min with %s array" % dtype)
904
972
 
905
973
 
906
974
  @lower(stubs.atomic.compare_and_swap, types.Array, types.Any, types.Any)
@@ -917,19 +985,21 @@ def ptx_atomic_cas(context, builder, sig, args):
917
985
  aryty, indty, oldty, valty = sig.args
918
986
  ary, inds, old, val = args
919
987
 
920
- indty, indices = _normalize_indices(context, builder, indty, inds, aryty,
921
- valty)
988
+ indty, indices = _normalize_indices(
989
+ context, builder, indty, inds, aryty, valty
990
+ )
922
991
 
923
992
  lary = context.make_array(aryty)(context, builder, ary)
924
- ptr = cgutils.get_item_pointer(context, builder, aryty, lary, indices,
925
- wraparound=True)
993
+ ptr = cgutils.get_item_pointer(
994
+ context, builder, aryty, lary, indices, wraparound=True
995
+ )
926
996
 
927
997
  if aryty.dtype in (cuda.cudadecl.integer_numba_types):
928
998
  lmod = builder.module
929
999
  bitwidth = aryty.dtype.bitwidth
930
1000
  return nvvmutils.atomic_cmpxchg(builder, lmod, bitwidth, ptr, old, val)
931
1001
  else:
932
- raise TypeError('Unimplemented atomic cas with %s array' % aryty.dtype)
1002
+ raise TypeError("Unimplemented atomic cas with %s array" % aryty.dtype)
933
1003
 
934
1004
 
935
1005
  # -----------------------------------------------------------------------------
@@ -937,15 +1007,20 @@ def ptx_atomic_cas(context, builder, sig, args):
937
1007
 
938
1008
  @lower(breakpoint)
939
1009
  def ptx_brkpt(context, builder, sig, args):
940
- brkpt = ir.InlineAsm(ir.FunctionType(ir.VoidType(), []),
941
- "brkpt;", '', side_effect=True)
1010
+ brkpt = ir.InlineAsm(
1011
+ ir.FunctionType(ir.VoidType(), []), "brkpt;", "", side_effect=True
1012
+ )
942
1013
  builder.call(brkpt, ())
943
1014
 
944
1015
 
945
1016
  @lower(stubs.nanosleep, types.uint32)
946
1017
  def ptx_nanosleep(context, builder, sig, args):
947
- nanosleep = ir.InlineAsm(ir.FunctionType(ir.VoidType(), [ir.IntType(32)]),
948
- "nanosleep.u32 $0;", 'r', side_effect=True)
1018
+ nanosleep = ir.InlineAsm(
1019
+ ir.FunctionType(ir.VoidType(), [ir.IntType(32)]),
1020
+ "nanosleep.u32 $0;",
1021
+ "r",
1022
+ side_effect=True,
1023
+ )
949
1024
  ns = args[0]
950
1025
  builder.call(nanosleep, [ns])
951
1026
 
@@ -953,8 +1028,9 @@ def ptx_nanosleep(context, builder, sig, args):
953
1028
  # -----------------------------------------------------------------------------
954
1029
 
955
1030
 
956
- def _generic_array(context, builder, shape, dtype, symbol_name, addrspace,
957
- can_dynsized=False):
1031
+ def _generic_array(
1032
+ context, builder, shape, dtype, symbol_name, addrspace, can_dynsized=False
1033
+ ):
958
1034
  elemcount = reduce(operator.mul, shape, 1)
959
1035
 
960
1036
  # Check for valid shape for this type of allocation.
@@ -985,16 +1061,17 @@ def _generic_array(context, builder, shape, dtype, symbol_name, addrspace,
985
1061
  lmod = builder.module
986
1062
 
987
1063
  # Create global variable in the requested address space
988
- gvmem = cgutils.add_global_variable(lmod, laryty, symbol_name,
989
- addrspace)
1064
+ gvmem = cgutils.add_global_variable(
1065
+ lmod, laryty, symbol_name, addrspace
1066
+ )
990
1067
  # Specify alignment to avoid misalignment bug
991
1068
  align = context.get_abi_sizeof(lldtype)
992
1069
  # Alignment is required to be a power of 2 for shared memory. If it is
993
1070
  # not a power of 2 (e.g. for a Record array) then round up accordingly.
994
- gvmem.align = 1 << (align - 1 ).bit_length()
1071
+ gvmem.align = 1 << (align - 1).bit_length()
995
1072
 
996
1073
  if dynamic_smem:
997
- gvmem.linkage = 'external'
1074
+ gvmem.linkage = "external"
998
1075
  else:
999
1076
  ## Comment out the following line to workaround a NVVM bug
1000
1077
  ## which generates a invalid symbol name when the linkage
@@ -1005,8 +1082,9 @@ def _generic_array(context, builder, shape, dtype, symbol_name, addrspace,
1005
1082
  gvmem.initializer = ir.Constant(laryty, ir.Undefined)
1006
1083
 
1007
1084
  # Convert to generic address-space
1008
- dataptr = builder.addrspacecast(gvmem, ir.PointerType(ir.IntType(8)),
1009
- 'generic')
1085
+ dataptr = builder.addrspacecast(
1086
+ gvmem, ir.PointerType(ir.IntType(8)), "generic"
1087
+ )
1010
1088
 
1011
1089
  targetdata = ll.create_target_data(nvvm.NVVM().data_layout)
1012
1090
  lldtype = context.get_data_type(dtype)
@@ -1027,11 +1105,15 @@ def _generic_array(context, builder, shape, dtype, symbol_name, addrspace,
1027
1105
  # Unfortunately NVVM does not provide an intrinsic for the
1028
1106
  # %dynamic_smem_size register, so we must read it using inline
1029
1107
  # assembly.
1030
- get_dynshared_size = ir.InlineAsm(ir.FunctionType(ir.IntType(32), []),
1031
- "mov.u32 $0, %dynamic_smem_size;",
1032
- '=r', side_effect=True)
1033
- dynsmem_size = builder.zext(builder.call(get_dynshared_size, []),
1034
- ir.IntType(64))
1108
+ get_dynshared_size = ir.InlineAsm(
1109
+ ir.FunctionType(ir.IntType(32), []),
1110
+ "mov.u32 $0, %dynamic_smem_size;",
1111
+ "=r",
1112
+ side_effect=True,
1113
+ )
1114
+ dynsmem_size = builder.zext(
1115
+ builder.call(get_dynshared_size, []), ir.IntType(64)
1116
+ )
1035
1117
  # Only 1-D dynamic shared memory is supported so the following is a
1036
1118
  # sufficient construction of the shape
1037
1119
  kitemsize = context.get_constant(types.intp, itemsize)
@@ -1041,15 +1123,17 @@ def _generic_array(context, builder, shape, dtype, symbol_name, addrspace,
1041
1123
 
1042
1124
  # Create array object
1043
1125
  ndim = len(shape)
1044
- aryty = types.Array(dtype=dtype, ndim=ndim, layout='C')
1126
+ aryty = types.Array(dtype=dtype, ndim=ndim, layout="C")
1045
1127
  ary = context.make_array(aryty)(context, builder)
1046
1128
 
1047
- context.populate_array(ary,
1048
- data=builder.bitcast(dataptr, ary.data.type),
1049
- shape=kshape,
1050
- strides=kstrides,
1051
- itemsize=context.get_constant(types.intp, itemsize),
1052
- meminfo=None)
1129
+ context.populate_array(
1130
+ ary,
1131
+ data=builder.bitcast(dataptr, ary.data.type),
1132
+ shape=kshape,
1133
+ strides=kstrides,
1134
+ itemsize=context.get_constant(types.intp, itemsize),
1135
+ meminfo=None,
1136
+ )
1053
1137
  return ary._getvalue()
1054
1138
 
1055
1139