numba-cuda 0.8.1__py3-none-any.whl → 0.10.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (237) hide show
  1. _numba_cuda_redirector.py +17 -13
  2. numba_cuda/VERSION +1 -1
  3. numba_cuda/_version.py +4 -1
  4. numba_cuda/numba/cuda/__init__.py +6 -2
  5. numba_cuda/numba/cuda/api.py +129 -86
  6. numba_cuda/numba/cuda/api_util.py +3 -3
  7. numba_cuda/numba/cuda/args.py +12 -16
  8. numba_cuda/numba/cuda/cg.py +6 -6
  9. numba_cuda/numba/cuda/codegen.py +74 -43
  10. numba_cuda/numba/cuda/compiler.py +246 -114
  11. numba_cuda/numba/cuda/cpp_function_wrappers.cu +1 -2
  12. numba_cuda/numba/cuda/cuda_bf16.py +5155 -0
  13. numba_cuda/numba/cuda/cuda_paths.py +293 -99
  14. numba_cuda/numba/cuda/cudadecl.py +93 -79
  15. numba_cuda/numba/cuda/cudadrv/__init__.py +3 -1
  16. numba_cuda/numba/cuda/cudadrv/devicearray.py +185 -135
  17. numba_cuda/numba/cuda/cudadrv/devices.py +16 -11
  18. numba_cuda/numba/cuda/cudadrv/driver.py +460 -297
  19. numba_cuda/numba/cuda/cudadrv/drvapi.py +241 -207
  20. numba_cuda/numba/cuda/cudadrv/dummyarray.py +66 -54
  21. numba_cuda/numba/cuda/cudadrv/enums.py +1 -1
  22. numba_cuda/numba/cuda/cudadrv/error.py +6 -2
  23. numba_cuda/numba/cuda/cudadrv/libs.py +67 -63
  24. numba_cuda/numba/cuda/cudadrv/linkable_code.py +27 -3
  25. numba_cuda/numba/cuda/cudadrv/mappings.py +16 -14
  26. numba_cuda/numba/cuda/cudadrv/nvrtc.py +146 -30
  27. numba_cuda/numba/cuda/cudadrv/nvvm.py +296 -161
  28. numba_cuda/numba/cuda/cudadrv/rtapi.py +1 -1
  29. numba_cuda/numba/cuda/cudadrv/runtime.py +20 -8
  30. numba_cuda/numba/cuda/cudaimpl.py +296 -275
  31. numba_cuda/numba/cuda/cudamath.py +1 -1
  32. numba_cuda/numba/cuda/debuginfo.py +99 -7
  33. numba_cuda/numba/cuda/decorators.py +87 -45
  34. numba_cuda/numba/cuda/descriptor.py +1 -1
  35. numba_cuda/numba/cuda/device_init.py +68 -18
  36. numba_cuda/numba/cuda/deviceufunc.py +143 -98
  37. numba_cuda/numba/cuda/dispatcher.py +300 -213
  38. numba_cuda/numba/cuda/errors.py +13 -10
  39. numba_cuda/numba/cuda/extending.py +55 -1
  40. numba_cuda/numba/cuda/include/11/cuda_bf16.h +3749 -0
  41. numba_cuda/numba/cuda/include/11/cuda_bf16.hpp +2683 -0
  42. numba_cuda/numba/cuda/{cuda_fp16.h → include/11/cuda_fp16.h} +1090 -927
  43. numba_cuda/numba/cuda/{cuda_fp16.hpp → include/11/cuda_fp16.hpp} +468 -319
  44. numba_cuda/numba/cuda/include/12/cuda_bf16.h +5118 -0
  45. numba_cuda/numba/cuda/include/12/cuda_bf16.hpp +3865 -0
  46. numba_cuda/numba/cuda/include/12/cuda_fp16.h +5363 -0
  47. numba_cuda/numba/cuda/include/12/cuda_fp16.hpp +3483 -0
  48. numba_cuda/numba/cuda/initialize.py +5 -3
  49. numba_cuda/numba/cuda/intrinsic_wrapper.py +0 -39
  50. numba_cuda/numba/cuda/intrinsics.py +203 -28
  51. numba_cuda/numba/cuda/kernels/reduction.py +13 -13
  52. numba_cuda/numba/cuda/kernels/transpose.py +3 -6
  53. numba_cuda/numba/cuda/libdevice.py +317 -317
  54. numba_cuda/numba/cuda/libdeviceimpl.py +3 -2
  55. numba_cuda/numba/cuda/locks.py +16 -0
  56. numba_cuda/numba/cuda/lowering.py +43 -0
  57. numba_cuda/numba/cuda/mathimpl.py +62 -57
  58. numba_cuda/numba/cuda/models.py +1 -5
  59. numba_cuda/numba/cuda/nvvmutils.py +103 -88
  60. numba_cuda/numba/cuda/printimpl.py +9 -5
  61. numba_cuda/numba/cuda/random.py +46 -36
  62. numba_cuda/numba/cuda/reshape_funcs.cu +1 -1
  63. numba_cuda/numba/cuda/runtime/__init__.py +1 -1
  64. numba_cuda/numba/cuda/runtime/memsys.cu +1 -1
  65. numba_cuda/numba/cuda/runtime/memsys.cuh +1 -1
  66. numba_cuda/numba/cuda/runtime/nrt.cu +3 -3
  67. numba_cuda/numba/cuda/runtime/nrt.py +48 -43
  68. numba_cuda/numba/cuda/simulator/__init__.py +22 -12
  69. numba_cuda/numba/cuda/simulator/api.py +38 -22
  70. numba_cuda/numba/cuda/simulator/compiler.py +2 -2
  71. numba_cuda/numba/cuda/simulator/cudadrv/__init__.py +8 -2
  72. numba_cuda/numba/cuda/simulator/cudadrv/devicearray.py +63 -55
  73. numba_cuda/numba/cuda/simulator/cudadrv/devices.py +13 -11
  74. numba_cuda/numba/cuda/simulator/cudadrv/driver.py +5 -5
  75. numba_cuda/numba/cuda/simulator/cudadrv/drvapi.py +2 -2
  76. numba_cuda/numba/cuda/simulator/cudadrv/libs.py +1 -1
  77. numba_cuda/numba/cuda/simulator/cudadrv/nvvm.py +3 -3
  78. numba_cuda/numba/cuda/simulator/cudadrv/runtime.py +3 -3
  79. numba_cuda/numba/cuda/simulator/kernel.py +43 -34
  80. numba_cuda/numba/cuda/simulator/kernelapi.py +31 -26
  81. numba_cuda/numba/cuda/simulator/reduction.py +1 -0
  82. numba_cuda/numba/cuda/simulator/vector_types.py +13 -9
  83. numba_cuda/numba/cuda/simulator_init.py +2 -4
  84. numba_cuda/numba/cuda/stubs.py +134 -108
  85. numba_cuda/numba/cuda/target.py +92 -47
  86. numba_cuda/numba/cuda/testing.py +24 -19
  87. numba_cuda/numba/cuda/tests/__init__.py +14 -12
  88. numba_cuda/numba/cuda/tests/cudadrv/test_array_attr.py +16 -17
  89. numba_cuda/numba/cuda/tests/cudadrv/test_context_stack.py +7 -7
  90. numba_cuda/numba/cuda/tests/cudadrv/test_cuda_array_slicing.py +73 -54
  91. numba_cuda/numba/cuda/tests/cudadrv/test_cuda_auto_context.py +1 -1
  92. numba_cuda/numba/cuda/tests/cudadrv/test_cuda_devicerecord.py +48 -50
  93. numba_cuda/numba/cuda/tests/cudadrv/test_cuda_driver.py +47 -29
  94. numba_cuda/numba/cuda/tests/cudadrv/test_cuda_libraries.py +3 -3
  95. numba_cuda/numba/cuda/tests/cudadrv/test_cuda_memory.py +19 -19
  96. numba_cuda/numba/cuda/tests/cudadrv/test_cuda_ndarray.py +108 -103
  97. numba_cuda/numba/cuda/tests/cudadrv/test_deallocations.py +20 -11
  98. numba_cuda/numba/cuda/tests/cudadrv/test_detect.py +20 -17
  99. numba_cuda/numba/cuda/tests/cudadrv/test_emm_plugins.py +8 -6
  100. numba_cuda/numba/cuda/tests/cudadrv/test_events.py +1 -1
  101. numba_cuda/numba/cuda/tests/cudadrv/test_host_alloc.py +8 -7
  102. numba_cuda/numba/cuda/tests/cudadrv/test_init.py +13 -13
  103. numba_cuda/numba/cuda/tests/cudadrv/test_inline_ptx.py +12 -9
  104. numba_cuda/numba/cuda/tests/cudadrv/test_linker.py +36 -31
  105. numba_cuda/numba/cuda/tests/cudadrv/test_managed_alloc.py +8 -7
  106. numba_cuda/numba/cuda/tests/cudadrv/test_module_callbacks.py +294 -0
  107. numba_cuda/numba/cuda/tests/cudadrv/test_mvc.py +10 -7
  108. numba_cuda/numba/cuda/tests/cudadrv/test_nvjitlink.py +24 -15
  109. numba_cuda/numba/cuda/tests/cudadrv/test_nvvm_driver.py +43 -41
  110. numba_cuda/numba/cuda/tests/cudadrv/test_pinned.py +4 -5
  111. numba_cuda/numba/cuda/tests/cudadrv/test_profiler.py +2 -2
  112. numba_cuda/numba/cuda/tests/cudadrv/test_ptds.py +28 -17
  113. numba_cuda/numba/cuda/tests/cudadrv/test_reset_device.py +1 -2
  114. numba_cuda/numba/cuda/tests/cudadrv/test_runtime.py +22 -14
  115. numba_cuda/numba/cuda/tests/cudadrv/test_select_device.py +1 -1
  116. numba_cuda/numba/cuda/tests/cudadrv/test_streams.py +4 -3
  117. numba_cuda/numba/cuda/tests/cudapy/cache_usecases.py +10 -4
  118. numba_cuda/numba/cuda/tests/cudapy/cache_with_cpu_usecases.py +1 -0
  119. numba_cuda/numba/cuda/tests/cudapy/extensions_usecases.py +10 -7
  120. numba_cuda/numba/cuda/tests/cudapy/jitlink.ptx +0 -2
  121. numba_cuda/numba/cuda/tests/cudapy/recursion_usecases.py +1 -0
  122. numba_cuda/numba/cuda/tests/cudapy/test_alignment.py +6 -5
  123. numba_cuda/numba/cuda/tests/cudapy/test_array.py +52 -42
  124. numba_cuda/numba/cuda/tests/cudapy/test_array_args.py +5 -6
  125. numba_cuda/numba/cuda/tests/cudapy/test_array_methods.py +1 -1
  126. numba_cuda/numba/cuda/tests/cudapy/test_atomics.py +501 -304
  127. numba_cuda/numba/cuda/tests/cudapy/test_bfloat16_bindings.py +257 -0
  128. numba_cuda/numba/cuda/tests/cudapy/test_blackscholes.py +59 -23
  129. numba_cuda/numba/cuda/tests/cudapy/test_boolean.py +3 -3
  130. numba_cuda/numba/cuda/tests/cudapy/test_caching.py +50 -37
  131. numba_cuda/numba/cuda/tests/cudapy/test_casting.py +29 -24
  132. numba_cuda/numba/cuda/tests/cudapy/test_cffi.py +11 -6
  133. numba_cuda/numba/cuda/tests/cudapy/test_compiler.py +84 -50
  134. numba_cuda/numba/cuda/tests/cudapy/test_complex.py +144 -73
  135. numba_cuda/numba/cuda/tests/cudapy/test_complex_kernel.py +2 -2
  136. numba_cuda/numba/cuda/tests/cudapy/test_const_string.py +37 -27
  137. numba_cuda/numba/cuda/tests/cudapy/test_constmem.py +43 -45
  138. numba_cuda/numba/cuda/tests/cudapy/test_cooperative_groups.py +21 -14
  139. numba_cuda/numba/cuda/tests/cudapy/test_cuda_array_interface.py +60 -55
  140. numba_cuda/numba/cuda/tests/cudapy/test_cuda_jit_no_types.py +3 -2
  141. numba_cuda/numba/cuda/tests/cudapy/test_datetime.py +26 -22
  142. numba_cuda/numba/cuda/tests/cudapy/test_debug.py +29 -27
  143. numba_cuda/numba/cuda/tests/cudapy/test_debuginfo.py +77 -28
  144. numba_cuda/numba/cuda/tests/cudapy/test_device_func.py +52 -45
  145. numba_cuda/numba/cuda/tests/cudapy/test_dispatcher.py +55 -43
  146. numba_cuda/numba/cuda/tests/cudapy/test_enums.py +24 -7
  147. numba_cuda/numba/cuda/tests/cudapy/test_errors.py +30 -15
  148. numba_cuda/numba/cuda/tests/cudapy/test_exception.py +11 -12
  149. numba_cuda/numba/cuda/tests/cudapy/test_extending.py +21 -12
  150. numba_cuda/numba/cuda/tests/cudapy/test_fastmath.py +77 -66
  151. numba_cuda/numba/cuda/tests/cudapy/test_forall.py +5 -3
  152. numba_cuda/numba/cuda/tests/cudapy/test_freevar.py +5 -3
  153. numba_cuda/numba/cuda/tests/cudapy/test_frexp_ldexp.py +1 -1
  154. numba_cuda/numba/cuda/tests/cudapy/test_globals.py +3 -5
  155. numba_cuda/numba/cuda/tests/cudapy/test_gufunc.py +144 -126
  156. numba_cuda/numba/cuda/tests/cudapy/test_gufunc_scalar.py +23 -18
  157. numba_cuda/numba/cuda/tests/cudapy/test_gufunc_scheduling.py +16 -22
  158. numba_cuda/numba/cuda/tests/cudapy/test_idiv.py +1 -3
  159. numba_cuda/numba/cuda/tests/cudapy/test_inline.py +59 -0
  160. numba_cuda/numba/cuda/tests/cudapy/test_inspect.py +29 -20
  161. numba_cuda/numba/cuda/tests/cudapy/test_intrinsics.py +147 -99
  162. numba_cuda/numba/cuda/tests/cudapy/test_ipc.py +50 -36
  163. numba_cuda/numba/cuda/tests/cudapy/test_iterators.py +1 -2
  164. numba_cuda/numba/cuda/tests/cudapy/test_lang.py +4 -4
  165. numba_cuda/numba/cuda/tests/cudapy/test_laplace.py +7 -7
  166. numba_cuda/numba/cuda/tests/cudapy/test_libdevice.py +24 -20
  167. numba_cuda/numba/cuda/tests/cudapy/test_lineinfo.py +36 -31
  168. numba_cuda/numba/cuda/tests/cudapy/test_localmem.py +13 -13
  169. numba_cuda/numba/cuda/tests/cudapy/test_mandel.py +13 -6
  170. numba_cuda/numba/cuda/tests/cudapy/test_math.py +83 -66
  171. numba_cuda/numba/cuda/tests/cudapy/test_matmul.py +1 -3
  172. numba_cuda/numba/cuda/tests/cudapy/test_minmax.py +19 -58
  173. numba_cuda/numba/cuda/tests/cudapy/test_montecarlo.py +4 -4
  174. numba_cuda/numba/cuda/tests/cudapy/test_multigpu.py +9 -7
  175. numba_cuda/numba/cuda/tests/cudapy/test_multiprocessing.py +9 -8
  176. numba_cuda/numba/cuda/tests/cudapy/test_multithreads.py +12 -10
  177. numba_cuda/numba/cuda/tests/cudapy/test_nondet.py +1 -1
  178. numba_cuda/numba/cuda/tests/cudapy/test_operator.py +180 -96
  179. numba_cuda/numba/cuda/tests/cudapy/test_optimization.py +5 -5
  180. numba_cuda/numba/cuda/tests/cudapy/test_overload.py +37 -18
  181. numba_cuda/numba/cuda/tests/cudapy/test_powi.py +7 -7
  182. numba_cuda/numba/cuda/tests/cudapy/test_print.py +9 -7
  183. numba_cuda/numba/cuda/tests/cudapy/test_py2_div_issue.py +1 -1
  184. numba_cuda/numba/cuda/tests/cudapy/test_random.py +15 -10
  185. numba_cuda/numba/cuda/tests/cudapy/test_record_dtype.py +88 -87
  186. numba_cuda/numba/cuda/tests/cudapy/test_recursion.py +12 -10
  187. numba_cuda/numba/cuda/tests/cudapy/test_reduction.py +26 -11
  188. numba_cuda/numba/cuda/tests/cudapy/test_retrieve_autoconverted_arrays.py +7 -10
  189. numba_cuda/numba/cuda/tests/cudapy/test_serialize.py +4 -6
  190. numba_cuda/numba/cuda/tests/cudapy/test_slicing.py +1 -1
  191. numba_cuda/numba/cuda/tests/cudapy/test_sm.py +10 -9
  192. numba_cuda/numba/cuda/tests/cudapy/test_sm_creation.py +62 -43
  193. numba_cuda/numba/cuda/tests/cudapy/test_stream_api.py +7 -3
  194. numba_cuda/numba/cuda/tests/cudapy/test_sync.py +7 -5
  195. numba_cuda/numba/cuda/tests/cudapy/test_transpose.py +18 -11
  196. numba_cuda/numba/cuda/tests/cudapy/test_ufuncs.py +111 -88
  197. numba_cuda/numba/cuda/tests/cudapy/test_userexc.py +2 -3
  198. numba_cuda/numba/cuda/tests/cudapy/test_vector_type.py +305 -130
  199. numba_cuda/numba/cuda/tests/cudapy/test_vectorize.py +33 -36
  200. numba_cuda/numba/cuda/tests/cudapy/test_vectorize_complex.py +5 -5
  201. numba_cuda/numba/cuda/tests/cudapy/test_vectorize_decor.py +16 -12
  202. numba_cuda/numba/cuda/tests/cudapy/test_vectorize_device.py +7 -7
  203. numba_cuda/numba/cuda/tests/cudapy/test_vectorize_scalar_arg.py +6 -7
  204. numba_cuda/numba/cuda/tests/cudapy/test_warning.py +31 -29
  205. numba_cuda/numba/cuda/tests/cudapy/test_warp_ops.py +81 -30
  206. numba_cuda/numba/cuda/tests/cudasim/test_cudasim_issues.py +19 -13
  207. numba_cuda/numba/cuda/tests/data/jitlink.cu +1 -1
  208. numba_cuda/numba/cuda/tests/data/jitlink.ptx +0 -2
  209. numba_cuda/numba/cuda/tests/doc_examples/test_cg.py +15 -8
  210. numba_cuda/numba/cuda/tests/doc_examples/test_cpu_gpu_compat.py +4 -7
  211. numba_cuda/numba/cuda/tests/doc_examples/test_ffi.py +14 -9
  212. numba_cuda/numba/cuda/tests/doc_examples/test_laplace.py +22 -18
  213. numba_cuda/numba/cuda/tests/doc_examples/test_matmul.py +7 -4
  214. numba_cuda/numba/cuda/tests/doc_examples/test_montecarlo.py +2 -0
  215. numba_cuda/numba/cuda/tests/doc_examples/test_random.py +8 -4
  216. numba_cuda/numba/cuda/tests/doc_examples/test_reduction.py +2 -1
  217. numba_cuda/numba/cuda/tests/doc_examples/test_sessionize.py +94 -19
  218. numba_cuda/numba/cuda/tests/doc_examples/test_vecadd.py +2 -2
  219. numba_cuda/numba/cuda/tests/nocuda/test_dummyarray.py +91 -62
  220. numba_cuda/numba/cuda/tests/nocuda/test_function_resolution.py +14 -5
  221. numba_cuda/numba/cuda/tests/nocuda/test_import.py +25 -25
  222. numba_cuda/numba/cuda/tests/nocuda/test_library_lookup.py +40 -40
  223. numba_cuda/numba/cuda/tests/nocuda/test_nvvm.py +12 -10
  224. numba_cuda/numba/cuda/tests/nrt/test_nrt.py +16 -20
  225. numba_cuda/numba/cuda/tests/nrt/test_nrt_refct.py +12 -10
  226. numba_cuda/numba/cuda/tests/test_binary_generation/generate_raw_ltoir.py +2 -2
  227. numba_cuda/numba/cuda/types.py +5 -2
  228. numba_cuda/numba/cuda/ufuncs.py +382 -362
  229. numba_cuda/numba/cuda/utils.py +2 -2
  230. numba_cuda/numba/cuda/vector_types.py +5 -3
  231. numba_cuda/numba/cuda/vectorizers.py +38 -33
  232. {numba_cuda-0.8.1.dist-info → numba_cuda-0.10.0.dist-info}/METADATA +1 -1
  233. numba_cuda-0.10.0.dist-info/RECORD +263 -0
  234. {numba_cuda-0.8.1.dist-info → numba_cuda-0.10.0.dist-info}/WHEEL +1 -1
  235. numba_cuda-0.8.1.dist-info/RECORD +0 -251
  236. {numba_cuda-0.8.1.dist-info → numba_cuda-0.10.0.dist-info}/licenses/LICENSE +0 -0
  237. {numba_cuda-0.8.1.dist-info → numba_cuda-0.10.0.dist-info}/top_level.txt +0 -0
@@ -1,25 +1,46 @@
1
1
  from llvmlite import ir
2
2
  from numba.core.typing.templates import ConcreteTemplate
3
3
  from numba.core import ir as numba_ir
4
- from numba.core import (cgutils, types, typing, funcdesc, config, compiler,
5
- sigutils, utils)
6
- from numba.core.compiler import (sanitize_compile_result_entries, CompilerBase,
7
- DefaultPassBuilder, Flags, Option,
8
- CompileResult)
4
+ from numba.core import (
5
+ cgutils,
6
+ types,
7
+ typing,
8
+ funcdesc,
9
+ config,
10
+ compiler,
11
+ sigutils,
12
+ utils,
13
+ )
14
+ from numba.core.compiler import (
15
+ sanitize_compile_result_entries,
16
+ CompilerBase,
17
+ DefaultPassBuilder,
18
+ Flags,
19
+ Option,
20
+ CompileResult,
21
+ )
9
22
  from numba.core.compiler_lock import global_compiler_lock
10
- from numba.core.compiler_machinery import (FunctionPass, LoweringPass,
11
- PassManager, register_pass)
23
+ from numba.core.compiler_machinery import (
24
+ FunctionPass,
25
+ LoweringPass,
26
+ PassManager,
27
+ register_pass,
28
+ )
12
29
  from numba.core.interpreter import Interpreter
13
30
  from numba.core.errors import NumbaInvalidConfigWarning
14
31
  from numba.core.untyped_passes import TranslateByteCode
15
- from numba.core.typed_passes import (IRLegalization, NativeLowering,
16
- AnnotateTypes)
32
+ from numba.core.typed_passes import (
33
+ IRLegalization,
34
+ NativeLowering,
35
+ AnnotateTypes,
36
+ )
17
37
  from warnings import warn
18
38
  from numba.cuda import nvvmutils
19
39
  from numba.cuda.api import get_current_device
20
40
  from numba.cuda.cudadrv import nvvm
21
41
  from numba.cuda.descriptor import cuda_target
22
42
  from numba.cuda.target import CUDACABICallConv
43
+ from numba.cuda import lowering
23
44
 
24
45
 
25
46
  def _nvvm_options_type(x):
@@ -52,15 +73,9 @@ class CUDAFlags(Flags):
52
73
  doc="Compute Capability",
53
74
  )
54
75
  max_registers = Option(
55
- type=_optional_int_type,
56
- default=None,
57
- doc="Max registers"
58
- )
59
- lto = Option(
60
- type=bool,
61
- default=False,
62
- doc="Enable Link-time Optimization"
76
+ type=_optional_int_type, default=None, doc="Max registers"
63
77
  )
78
+ lto = Option(type=bool, default=False, doc="Enable Link-time Optimization")
64
79
 
65
80
 
66
81
  # The CUDACompileResult (CCR) has a specially-defined entry point equal to its
@@ -79,6 +94,7 @@ class CUDAFlags(Flags):
79
94
  # point will no longer need to be a synthetic value, but will instead be a
80
95
  # pointer to the compiled function as in the CPU target.
81
96
 
97
+
82
98
  class CUDACompileResult(CompileResult):
83
99
  @property
84
100
  def entry_point(self):
@@ -92,7 +108,6 @@ def cuda_compile_result(**entries):
92
108
 
93
109
  @register_pass(mutates_CFG=True, analysis_only=False)
94
110
  class CUDABackend(LoweringPass):
95
-
96
111
  _name = "cuda_backend"
97
112
 
98
113
  def __init__(self):
@@ -102,7 +117,7 @@ class CUDABackend(LoweringPass):
102
117
  """
103
118
  Back-end: Packages lowering output in a compile result
104
119
  """
105
- lowered = state['cr']
120
+ lowered = state["cr"]
106
121
  signature = typing.signature(state.return_type, *state.args)
107
122
 
108
123
  state.cr = cuda_compile_result(
@@ -137,15 +152,30 @@ class CreateLibrary(LoweringPass):
137
152
  nvvm_options = state.flags.nvvm_options
138
153
  max_registers = state.flags.max_registers
139
154
  lto = state.flags.lto
140
- state.library = codegen.create_library(name, nvvm_options=nvvm_options,
141
- max_registers=max_registers,
142
- lto=lto)
155
+ state.library = codegen.create_library(
156
+ name,
157
+ nvvm_options=nvvm_options,
158
+ max_registers=max_registers,
159
+ lto=lto,
160
+ )
143
161
  # Enable object caching upfront so that the library can be serialized.
144
162
  state.library.enable_object_caching()
145
163
 
146
164
  return True
147
165
 
148
166
 
167
+ @register_pass(mutates_CFG=True, analysis_only=False)
168
+ class CUDANativeLowering(NativeLowering):
169
+ """Lowering pass for a CUDA native function IR described solely in terms of
170
+ Numba's standard `numba.core.ir` nodes."""
171
+
172
+ _name = "cuda_native_lowering"
173
+
174
+ @property
175
+ def lowering_class(self):
176
+ return lowering.CUDALower
177
+
178
+
149
179
  class CUDABytecodeInterpreter(Interpreter):
150
180
  # Based on the superclass implementation, but names the resulting variable
151
181
  # "$bool<N>" instead of "bool<N>" - see Numba PR #9888:
@@ -165,13 +195,15 @@ class CUDABytecodeInterpreter(Interpreter):
165
195
  gv_fn = numba_ir.Global("bool", bool, loc=self.loc)
166
196
  self.store(value=gv_fn, name=name)
167
197
 
168
- callres = numba_ir.Expr.call(self.get(name), (self.get(pred),), (),
169
- loc=self.loc)
198
+ callres = numba_ir.Expr.call(
199
+ self.get(name), (self.get(pred),), (), loc=self.loc
200
+ )
170
201
 
171
202
  pname = "$%spred" % (inst.offset)
172
203
  predicate = self.store(value=callres, name=pname)
173
- bra = numba_ir.Branch(cond=predicate, truebr=truebr, falsebr=falsebr,
174
- loc=self.loc)
204
+ bra = numba_ir.Branch(
205
+ cond=predicate, truebr=truebr, falsebr=falsebr, loc=self.loc
206
+ )
175
207
  self.current_block.append(bra)
176
208
 
177
209
 
@@ -183,18 +215,18 @@ class CUDATranslateBytecode(FunctionPass):
183
215
  FunctionPass.__init__(self)
184
216
 
185
217
  def run_pass(self, state):
186
- func_id = state['func_id']
187
- bc = state['bc']
218
+ func_id = state["func_id"]
219
+ bc = state["bc"]
188
220
  interp = CUDABytecodeInterpreter(func_id)
189
221
  func_ir = interp.interpret(bc)
190
- state['func_ir'] = func_ir
222
+ state["func_ir"] = func_ir
191
223
  return True
192
224
 
193
225
 
194
226
  class CUDACompiler(CompilerBase):
195
227
  def define_pipelines(self):
196
228
  dpb = DefaultPassBuilder
197
- pm = PassManager('cuda')
229
+ pm = PassManager("cuda")
198
230
 
199
231
  untyped_passes = dpb.define_untyped_pipeline(self.state)
200
232
 
@@ -225,15 +257,14 @@ class CUDACompiler(CompilerBase):
225
257
  return [pm]
226
258
 
227
259
  def define_cuda_lowering_pipeline(self, state):
228
- pm = PassManager('cuda_lowering')
260
+ pm = PassManager("cuda_lowering")
229
261
  # legalise
230
- pm.add_pass(IRLegalization,
231
- "ensure IR is legal prior to lowering")
262
+ pm.add_pass(IRLegalization, "ensure IR is legal prior to lowering")
232
263
  pm.add_pass(AnnotateTypes, "annotate types")
233
264
 
234
265
  # lower
235
266
  pm.add_pass(CreateLibrary, "create library")
236
- pm.add_pass(NativeLowering, "native lowering")
267
+ pm.add_pass(CUDANativeLowering, "cuda native lowering")
237
268
  pm.add_pass(CUDABackend, "cuda backend")
238
269
 
239
270
  pm.finalize()
@@ -241,13 +272,24 @@ class CUDACompiler(CompilerBase):
241
272
 
242
273
 
243
274
  @global_compiler_lock
244
- def compile_cuda(pyfunc, return_type, args, debug=False, lineinfo=False,
245
- inline=False, fastmath=False, nvvm_options=None,
246
- cc=None, max_registers=None, lto=False):
275
+ def compile_cuda(
276
+ pyfunc,
277
+ return_type,
278
+ args,
279
+ debug=False,
280
+ lineinfo=False,
281
+ inline=False,
282
+ fastmath=False,
283
+ nvvm_options=None,
284
+ cc=None,
285
+ max_registers=None,
286
+ lto=False,
287
+ ):
247
288
  if cc is None:
248
- raise ValueError('Compute Capability must be supplied')
289
+ raise ValueError("Compute Capability must be supplied")
249
290
 
250
291
  from .descriptor import cuda_target
292
+
251
293
  typingctx = cuda_target.typing_context
252
294
  targetctx = cuda_target.target_context
253
295
 
@@ -269,10 +311,10 @@ def compile_cuda(pyfunc, return_type, args, debug=False, lineinfo=False,
269
311
  flags.dbg_directives_only = True
270
312
 
271
313
  if debug:
272
- flags.error_model = 'python'
314
+ flags.error_model = "python"
273
315
  flags.dbg_extend_lifetimes = True
274
316
  else:
275
- flags.error_model = 'numpy'
317
+ flags.error_model = "numpy"
276
318
 
277
319
  if inline:
278
320
  flags.forceinline = True
@@ -286,15 +328,18 @@ def compile_cuda(pyfunc, return_type, args, debug=False, lineinfo=False,
286
328
 
287
329
  # Run compilation pipeline
288
330
  from numba.core.target_extension import target_override
289
- with target_override('cuda'):
290
- cres = compiler.compile_extra(typingctx=typingctx,
291
- targetctx=targetctx,
292
- func=pyfunc,
293
- args=args,
294
- return_type=return_type,
295
- flags=flags,
296
- locals={},
297
- pipeline_class=CUDACompiler)
331
+
332
+ with target_override("cuda"):
333
+ cres = compiler.compile_extra(
334
+ typingctx=typingctx,
335
+ targetctx=targetctx,
336
+ func=pyfunc,
337
+ args=args,
338
+ return_type=return_type,
339
+ flags=flags,
340
+ locals={},
341
+ pipeline_class=CUDACompiler,
342
+ )
298
343
 
299
344
  library = cres.library
300
345
  library.finalize()
@@ -302,8 +347,9 @@ def compile_cuda(pyfunc, return_type, args, debug=False, lineinfo=False,
302
347
  return cres
303
348
 
304
349
 
305
- def cabi_wrap_function(context, lib, fndesc, wrapper_function_name,
306
- nvvm_options):
350
+ def cabi_wrap_function(
351
+ context, lib, fndesc, wrapper_function_name, nvvm_options
352
+ ):
307
353
  """
308
354
  Wrap a Numba ABI function in a C ABI wrapper at the NVVM IR level.
309
355
 
@@ -311,9 +357,11 @@ def cabi_wrap_function(context, lib, fndesc, wrapper_function_name,
311
357
  """
312
358
  # The wrapper will be contained in a new library that links to the wrapped
313
359
  # function's library
314
- library = lib.codegen.create_library(f'{lib.name}_function_',
315
- entry_name=wrapper_function_name,
316
- nvvm_options=nvvm_options)
360
+ library = lib.codegen.create_library(
361
+ f"{lib.name}_function_",
362
+ entry_name=wrapper_function_name,
363
+ nvvm_options=nvvm_options,
364
+ )
317
365
  library.add_linking_library(lib)
318
366
 
319
367
  # Determine the caller (C ABI) and wrapper (Numba ABI) function types
@@ -331,14 +379,15 @@ def cabi_wrap_function(context, lib, fndesc, wrapper_function_name,
331
379
  # its return value
332
380
 
333
381
  wrapfn = ir.Function(wrapper_module, wrapfnty, wrapper_function_name)
334
- builder = ir.IRBuilder(wrapfn.append_basic_block(''))
382
+ builder = ir.IRBuilder(wrapfn.append_basic_block(""))
335
383
 
336
384
  arginfo = context.get_arg_packer(argtypes)
337
385
  callargs = arginfo.from_arguments(builder, wrapfn.args)
338
386
  # We get (status, return_value), but we ignore the status since we
339
387
  # can't propagate it through the C ABI anyway
340
388
  _, return_value = context.call_conv.call_function(
341
- builder, func, restype, argtypes, callargs)
389
+ builder, func, restype, argtypes, callargs
390
+ )
342
391
  builder.ret(return_value)
343
392
 
344
393
  if config.DUMP_LLVM:
@@ -395,8 +444,10 @@ def kernel_fixup(kernel, debug):
395
444
 
396
445
  # Find all stores first
397
446
  for inst in block.instructions:
398
- if (isinstance(inst, ir.StoreInstr)
399
- and inst.operands[1] == return_value):
447
+ if (
448
+ isinstance(inst, ir.StoreInstr)
449
+ and inst.operands[1] == return_value
450
+ ):
400
451
  remove_list.append(inst)
401
452
 
402
453
  # Remove all stores
@@ -407,8 +458,9 @@ def kernel_fixup(kernel, debug):
407
458
  # value
408
459
 
409
460
  if isinstance(kernel.type, ir.PointerType):
410
- new_type = ir.PointerType(ir.FunctionType(ir.VoidType(),
411
- kernel.type.pointee.args[1:]))
461
+ new_type = ir.PointerType(
462
+ ir.FunctionType(ir.VoidType(), kernel.type.pointee.args[1:])
463
+ )
412
464
  else:
413
465
  new_type = ir.FunctionType(ir.VoidType(), kernel.type.args[1:])
414
466
 
@@ -418,13 +470,13 @@ def kernel_fixup(kernel, debug):
418
470
 
419
471
  # If debug metadata is present, remove the return value from it
420
472
 
421
- if kernel_metadata := getattr(kernel, 'metadata', None):
422
- if dbg_metadata := kernel_metadata.get('dbg', None):
473
+ if kernel_metadata := getattr(kernel, "metadata", None):
474
+ if dbg_metadata := kernel_metadata.get("dbg", None):
423
475
  for name, value in dbg_metadata.operands:
424
476
  if name == "type":
425
477
  type_metadata = value
426
478
  for tm_name, tm_value in type_metadata.operands:
427
- if tm_name == 'types':
479
+ if tm_name == "types":
428
480
  types = tm_value
429
481
  types.operands = types.operands[1:]
430
482
  if config.DUMP_LLVM:
@@ -435,26 +487,24 @@ def kernel_fixup(kernel, debug):
435
487
  nvvm.set_cuda_kernel(kernel)
436
488
 
437
489
  if config.DUMP_LLVM:
438
- print(f"LLVM DUMP: Post kernel fixup {kernel.name}".center(80, '-'))
490
+ print(f"LLVM DUMP: Post kernel fixup {kernel.name}".center(80, "-"))
439
491
  print(kernel.module)
440
- print('=' * 80)
492
+ print("=" * 80)
441
493
 
442
494
 
443
495
  def add_exception_store_helper(kernel):
444
-
445
496
  # Create global variables for exception state
446
497
 
447
498
  def define_error_gv(postfix):
448
499
  name = kernel.name + postfix
449
- gv = cgutils.add_global_variable(kernel.module, ir.IntType(32),
450
- name)
500
+ gv = cgutils.add_global_variable(kernel.module, ir.IntType(32), name)
451
501
  gv.initializer = ir.Constant(gv.type.pointee, None)
452
502
  return gv
453
503
 
454
504
  gv_exc = define_error_gv("__errcode__")
455
505
  gv_tid = []
456
506
  gv_ctaid = []
457
- for i in 'xyz':
507
+ for i in "xyz":
458
508
  gv_tid.append(define_error_gv("__tid%s__" % i))
459
509
  gv_ctaid.append(define_error_gv("__ctaid%s__" % i))
460
510
 
@@ -484,18 +534,25 @@ def add_exception_store_helper(kernel):
484
534
  # Use atomic cmpxchg to prevent rewriting the error status
485
535
  # Only the first error is recorded
486
536
 
487
- xchg = builder.cmpxchg(gv_exc, old, status.code,
488
- 'monotonic', 'monotonic')
537
+ xchg = builder.cmpxchg(
538
+ gv_exc, old, status.code, "monotonic", "monotonic"
539
+ )
489
540
  changed = builder.extract_value(xchg, 1)
490
541
 
491
542
  # If the xchange is successful, save the thread ID.
492
543
  sreg = nvvmutils.SRegBuilder(builder)
493
544
  with builder.if_then(changed):
494
- for dim, ptr, in zip("xyz", gv_tid):
545
+ for (
546
+ dim,
547
+ ptr,
548
+ ) in zip("xyz", gv_tid):
495
549
  val = sreg.tid(dim)
496
550
  builder.store(val, ptr)
497
551
 
498
- for dim, ptr, in zip("xyz", gv_ctaid):
552
+ for (
553
+ dim,
554
+ ptr,
555
+ ) in zip("xyz", gv_ctaid):
499
556
  val = sreg.ctaid(dim)
500
557
  builder.store(val, ptr)
501
558
 
@@ -505,9 +562,19 @@ def add_exception_store_helper(kernel):
505
562
 
506
563
 
507
564
  @global_compiler_lock
508
- def compile(pyfunc, sig, debug=None, lineinfo=False, device=True,
509
- fastmath=False, cc=None, opt=None, abi="c", abi_info=None,
510
- output='ptx'):
565
+ def compile(
566
+ pyfunc,
567
+ sig,
568
+ debug=None,
569
+ lineinfo=False,
570
+ device=True,
571
+ fastmath=False,
572
+ cc=None,
573
+ opt=None,
574
+ abi="c",
575
+ abi_info=None,
576
+ output="ptx",
577
+ ):
511
578
  """Compile a Python function to PTX or LTO-IR for a given set of argument
512
579
  types.
513
580
 
@@ -551,43 +618,49 @@ def compile(pyfunc, sig, debug=None, lineinfo=False, device=True,
551
618
  :rtype: tuple
552
619
  """
553
620
  if abi not in ("numba", "c"):
554
- raise NotImplementedError(f'Unsupported ABI: {abi}')
621
+ raise NotImplementedError(f"Unsupported ABI: {abi}")
555
622
 
556
- if abi == 'c' and not device:
557
- raise NotImplementedError('The C ABI is not supported for kernels')
623
+ if abi == "c" and not device:
624
+ raise NotImplementedError("The C ABI is not supported for kernels")
558
625
 
559
626
  if output not in ("ptx", "ltoir"):
560
- raise NotImplementedError(f'Unsupported output type: {output}')
627
+ raise NotImplementedError(f"Unsupported output type: {output}")
561
628
 
562
629
  debug = config.CUDA_DEBUGINFO_DEFAULT if debug is None else debug
563
630
  opt = (config.OPT != 0) if opt is None else opt
564
631
 
565
632
  if debug and opt:
566
- msg = ("debug=True with opt=True "
567
- "is not supported by CUDA. This may result in a crash"
568
- " - set debug=False or opt=False.")
633
+ msg = (
634
+ "debug=True with opt=True "
635
+ "is not supported by CUDA. This may result in a crash"
636
+ " - set debug=False or opt=False."
637
+ )
569
638
  warn(NumbaInvalidConfigWarning(msg))
570
639
 
571
- lto = (output == 'ltoir')
640
+ lto = output == "ltoir"
572
641
  abi_info = abi_info or dict()
573
642
 
574
- nvvm_options = {
575
- 'fastmath': fastmath,
576
- 'opt': 3 if opt else 0
577
- }
643
+ nvvm_options = {"fastmath": fastmath, "opt": 3 if opt else 0}
578
644
 
579
645
  if debug:
580
- nvvm_options['g'] = None
646
+ nvvm_options["g"] = None
581
647
 
582
648
  if lto:
583
- nvvm_options['gen-lto'] = None
649
+ nvvm_options["gen-lto"] = None
584
650
 
585
651
  args, return_type = sigutils.normalize_signature(sig)
586
652
 
587
653
  cc = cc or config.CUDA_DEFAULT_PTX_CC
588
- cres = compile_cuda(pyfunc, return_type, args, debug=debug,
589
- lineinfo=lineinfo, fastmath=fastmath,
590
- nvvm_options=nvvm_options, cc=cc)
654
+ cres = compile_cuda(
655
+ pyfunc,
656
+ return_type,
657
+ args,
658
+ debug=debug,
659
+ lineinfo=lineinfo,
660
+ fastmath=fastmath,
661
+ nvvm_options=nvvm_options,
662
+ cc=cc,
663
+ )
591
664
  resty = cres.signature.return_type
592
665
 
593
666
  if resty and not device and resty != types.void:
@@ -598,9 +671,10 @@ def compile(pyfunc, sig, debug=None, lineinfo=False, device=True,
598
671
  if device:
599
672
  lib = cres.library
600
673
  if abi == "c":
601
- wrapper_name = abi_info.get('abi_name', pyfunc.__name__)
602
- lib = cabi_wrap_function(tgt, lib, cres.fndesc, wrapper_name,
603
- nvvm_options)
674
+ wrapper_name = abi_info.get("abi_name", pyfunc.__name__)
675
+ lib = cabi_wrap_function(
676
+ tgt, lib, cres.fndesc, wrapper_name, nvvm_options
677
+ )
604
678
  else:
605
679
  lib = cres.library
606
680
  kernel = lib.get_function(cres.fndesc.llvm_func_name)
@@ -614,38 +688,94 @@ def compile(pyfunc, sig, debug=None, lineinfo=False, device=True,
614
688
  return code, resty
615
689
 
616
690
 
617
- def compile_for_current_device(pyfunc, sig, debug=None, lineinfo=False,
618
- device=True, fastmath=False, opt=None,
619
- abi="c", abi_info=None, output='ptx'):
691
+ def compile_for_current_device(
692
+ pyfunc,
693
+ sig,
694
+ debug=None,
695
+ lineinfo=False,
696
+ device=True,
697
+ fastmath=False,
698
+ opt=None,
699
+ abi="c",
700
+ abi_info=None,
701
+ output="ptx",
702
+ ):
620
703
  """Compile a Python function to PTX or LTO-IR for a given signature for the
621
704
  current device's compute capability. This calls :func:`compile` with an
622
705
  appropriate ``cc`` value for the current device."""
623
706
  cc = get_current_device().compute_capability
624
- return compile(pyfunc, sig, debug=debug, lineinfo=lineinfo, device=device,
625
- fastmath=fastmath, cc=cc, opt=opt, abi=abi,
626
- abi_info=abi_info, output=output)
707
+ return compile(
708
+ pyfunc,
709
+ sig,
710
+ debug=debug,
711
+ lineinfo=lineinfo,
712
+ device=device,
713
+ fastmath=fastmath,
714
+ cc=cc,
715
+ opt=opt,
716
+ abi=abi,
717
+ abi_info=abi_info,
718
+ output=output,
719
+ )
627
720
 
628
721
 
629
- def compile_ptx(pyfunc, sig, debug=None, lineinfo=False, device=False,
630
- fastmath=False, cc=None, opt=None, abi="numba", abi_info=None):
722
+ def compile_ptx(
723
+ pyfunc,
724
+ sig,
725
+ debug=None,
726
+ lineinfo=False,
727
+ device=False,
728
+ fastmath=False,
729
+ cc=None,
730
+ opt=None,
731
+ abi="numba",
732
+ abi_info=None,
733
+ ):
631
734
  """Compile a Python function to PTX for a given signature. See
632
735
  :func:`compile`. The defaults for this function are to compile a kernel
633
736
  with the Numba ABI, rather than :func:`compile`'s default of compiling a
634
737
  device function with the C ABI."""
635
- return compile(pyfunc, sig, debug=debug, lineinfo=lineinfo, device=device,
636
- fastmath=fastmath, cc=cc, opt=opt, abi=abi,
637
- abi_info=abi_info, output='ptx')
738
+ return compile(
739
+ pyfunc,
740
+ sig,
741
+ debug=debug,
742
+ lineinfo=lineinfo,
743
+ device=device,
744
+ fastmath=fastmath,
745
+ cc=cc,
746
+ opt=opt,
747
+ abi=abi,
748
+ abi_info=abi_info,
749
+ output="ptx",
750
+ )
638
751
 
639
752
 
640
- def compile_ptx_for_current_device(pyfunc, sig, debug=None, lineinfo=False,
641
- device=False, fastmath=False, opt=None,
642
- abi="numba", abi_info=None):
753
+ def compile_ptx_for_current_device(
754
+ pyfunc,
755
+ sig,
756
+ debug=None,
757
+ lineinfo=False,
758
+ device=False,
759
+ fastmath=False,
760
+ opt=None,
761
+ abi="numba",
762
+ abi_info=None,
763
+ ):
643
764
  """Compile a Python function to PTX for a given signature for the current
644
765
  device's compute capability. See :func:`compile_ptx`."""
645
766
  cc = get_current_device().compute_capability
646
- return compile_ptx(pyfunc, sig, debug=debug, lineinfo=lineinfo,
647
- device=device, fastmath=fastmath, cc=cc, opt=opt,
648
- abi=abi, abi_info=abi_info)
767
+ return compile_ptx(
768
+ pyfunc,
769
+ sig,
770
+ debug=debug,
771
+ lineinfo=lineinfo,
772
+ device=device,
773
+ fastmath=fastmath,
774
+ cc=cc,
775
+ opt=opt,
776
+ abi=abi,
777
+ abi_info=abi_info,
778
+ )
649
779
 
650
780
 
651
781
  def declare_device_function(name, restype, argtypes, link):
@@ -654,6 +784,7 @@ def declare_device_function(name, restype, argtypes, link):
654
784
 
655
785
  def declare_device_function_template(name, restype, argtypes, link):
656
786
  from .descriptor import cuda_target
787
+
657
788
  typingctx = cuda_target.typing_context
658
789
  targetctx = cuda_target.target_context
659
790
  sig = typing.signature(restype, *argtypes)
@@ -664,7 +795,8 @@ def declare_device_function_template(name, restype, argtypes, link):
664
795
  cases = [sig]
665
796
 
666
797
  fndesc = funcdesc.ExternalFunctionDescriptor(
667
- name=name, restype=restype, argtypes=argtypes)
798
+ name=name, restype=restype, argtypes=argtypes
799
+ )
668
800
  typingctx.insert_user_function(extfn, device_function_template)
669
801
  targetctx.insert_user_function(extfn, fndesc)
670
802
 
@@ -23,7 +23,7 @@ FNDEF(hdiv)(
23
23
  )
24
24
  {
25
25
  __half retval = __hdiv(__short_as_half (x), __short_as_half (y));
26
-
26
+
27
27
  *return_value = __half_as_short (retval);
28
28
  // Signal that no Python exception occurred
29
29
  return 0;
@@ -44,4 +44,3 @@ UNARY_FUNCTION(hceil)
44
44
  UNARY_FUNCTION(hrcp)
45
45
  UNARY_FUNCTION(hrint)
46
46
  UNARY_FUNCTION(htrunc)
47
-