numba-cuda: numba_cuda-0.8.1-py3-none-any.whl → numba_cuda-0.9.0-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (227)
  1. _numba_cuda_redirector.py +17 -13
  2. numba_cuda/VERSION +1 -1
  3. numba_cuda/_version.py +4 -1
  4. numba_cuda/numba/cuda/__init__.py +6 -2
  5. numba_cuda/numba/cuda/api.py +129 -86
  6. numba_cuda/numba/cuda/api_util.py +3 -3
  7. numba_cuda/numba/cuda/args.py +12 -16
  8. numba_cuda/numba/cuda/cg.py +6 -6
  9. numba_cuda/numba/cuda/codegen.py +74 -43
  10. numba_cuda/numba/cuda/compiler.py +232 -113
  11. numba_cuda/numba/cuda/cpp_function_wrappers.cu +1 -2
  12. numba_cuda/numba/cuda/cuda_fp16.h +661 -661
  13. numba_cuda/numba/cuda/cuda_fp16.hpp +3 -3
  14. numba_cuda/numba/cuda/cuda_paths.py +291 -99
  15. numba_cuda/numba/cuda/cudadecl.py +125 -69
  16. numba_cuda/numba/cuda/cudadrv/__init__.py +3 -1
  17. numba_cuda/numba/cuda/cudadrv/devicearray.py +185 -135
  18. numba_cuda/numba/cuda/cudadrv/devices.py +16 -11
  19. numba_cuda/numba/cuda/cudadrv/driver.py +460 -297
  20. numba_cuda/numba/cuda/cudadrv/drvapi.py +241 -207
  21. numba_cuda/numba/cuda/cudadrv/dummyarray.py +66 -54
  22. numba_cuda/numba/cuda/cudadrv/enums.py +1 -1
  23. numba_cuda/numba/cuda/cudadrv/error.py +6 -2
  24. numba_cuda/numba/cuda/cudadrv/libs.py +67 -63
  25. numba_cuda/numba/cuda/cudadrv/linkable_code.py +16 -1
  26. numba_cuda/numba/cuda/cudadrv/mappings.py +16 -14
  27. numba_cuda/numba/cuda/cudadrv/nvrtc.py +138 -29
  28. numba_cuda/numba/cuda/cudadrv/nvvm.py +296 -161
  29. numba_cuda/numba/cuda/cudadrv/rtapi.py +1 -1
  30. numba_cuda/numba/cuda/cudadrv/runtime.py +20 -8
  31. numba_cuda/numba/cuda/cudaimpl.py +317 -233
  32. numba_cuda/numba/cuda/cudamath.py +1 -1
  33. numba_cuda/numba/cuda/debuginfo.py +8 -6
  34. numba_cuda/numba/cuda/decorators.py +75 -45
  35. numba_cuda/numba/cuda/descriptor.py +1 -1
  36. numba_cuda/numba/cuda/device_init.py +69 -18
  37. numba_cuda/numba/cuda/deviceufunc.py +143 -98
  38. numba_cuda/numba/cuda/dispatcher.py +300 -213
  39. numba_cuda/numba/cuda/errors.py +13 -10
  40. numba_cuda/numba/cuda/extending.py +1 -1
  41. numba_cuda/numba/cuda/initialize.py +5 -3
  42. numba_cuda/numba/cuda/intrinsic_wrapper.py +3 -3
  43. numba_cuda/numba/cuda/intrinsics.py +31 -27
  44. numba_cuda/numba/cuda/kernels/reduction.py +13 -13
  45. numba_cuda/numba/cuda/kernels/transpose.py +3 -6
  46. numba_cuda/numba/cuda/libdevice.py +317 -317
  47. numba_cuda/numba/cuda/libdeviceimpl.py +3 -2
  48. numba_cuda/numba/cuda/locks.py +16 -0
  49. numba_cuda/numba/cuda/mathimpl.py +62 -57
  50. numba_cuda/numba/cuda/models.py +1 -5
  51. numba_cuda/numba/cuda/nvvmutils.py +103 -88
  52. numba_cuda/numba/cuda/printimpl.py +9 -5
  53. numba_cuda/numba/cuda/random.py +46 -36
  54. numba_cuda/numba/cuda/reshape_funcs.cu +1 -1
  55. numba_cuda/numba/cuda/runtime/__init__.py +1 -1
  56. numba_cuda/numba/cuda/runtime/memsys.cu +1 -1
  57. numba_cuda/numba/cuda/runtime/memsys.cuh +1 -1
  58. numba_cuda/numba/cuda/runtime/nrt.cu +3 -3
  59. numba_cuda/numba/cuda/runtime/nrt.py +48 -43
  60. numba_cuda/numba/cuda/simulator/__init__.py +22 -12
  61. numba_cuda/numba/cuda/simulator/api.py +38 -22
  62. numba_cuda/numba/cuda/simulator/compiler.py +2 -2
  63. numba_cuda/numba/cuda/simulator/cudadrv/__init__.py +8 -2
  64. numba_cuda/numba/cuda/simulator/cudadrv/devicearray.py +63 -55
  65. numba_cuda/numba/cuda/simulator/cudadrv/devices.py +13 -11
  66. numba_cuda/numba/cuda/simulator/cudadrv/driver.py +5 -5
  67. numba_cuda/numba/cuda/simulator/cudadrv/drvapi.py +2 -2
  68. numba_cuda/numba/cuda/simulator/cudadrv/libs.py +1 -1
  69. numba_cuda/numba/cuda/simulator/cudadrv/nvvm.py +3 -3
  70. numba_cuda/numba/cuda/simulator/cudadrv/runtime.py +3 -3
  71. numba_cuda/numba/cuda/simulator/kernel.py +43 -34
  72. numba_cuda/numba/cuda/simulator/kernelapi.py +31 -26
  73. numba_cuda/numba/cuda/simulator/reduction.py +1 -0
  74. numba_cuda/numba/cuda/simulator/vector_types.py +13 -9
  75. numba_cuda/numba/cuda/simulator_init.py +2 -4
  76. numba_cuda/numba/cuda/stubs.py +139 -102
  77. numba_cuda/numba/cuda/target.py +64 -47
  78. numba_cuda/numba/cuda/testing.py +24 -19
  79. numba_cuda/numba/cuda/tests/__init__.py +14 -12
  80. numba_cuda/numba/cuda/tests/cudadrv/test_array_attr.py +16 -17
  81. numba_cuda/numba/cuda/tests/cudadrv/test_context_stack.py +7 -7
  82. numba_cuda/numba/cuda/tests/cudadrv/test_cuda_array_slicing.py +73 -54
  83. numba_cuda/numba/cuda/tests/cudadrv/test_cuda_auto_context.py +1 -1
  84. numba_cuda/numba/cuda/tests/cudadrv/test_cuda_devicerecord.py +48 -50
  85. numba_cuda/numba/cuda/tests/cudadrv/test_cuda_driver.py +47 -29
  86. numba_cuda/numba/cuda/tests/cudadrv/test_cuda_libraries.py +3 -3
  87. numba_cuda/numba/cuda/tests/cudadrv/test_cuda_memory.py +19 -19
  88. numba_cuda/numba/cuda/tests/cudadrv/test_cuda_ndarray.py +108 -103
  89. numba_cuda/numba/cuda/tests/cudadrv/test_deallocations.py +20 -11
  90. numba_cuda/numba/cuda/tests/cudadrv/test_detect.py +20 -17
  91. numba_cuda/numba/cuda/tests/cudadrv/test_emm_plugins.py +8 -6
  92. numba_cuda/numba/cuda/tests/cudadrv/test_events.py +1 -1
  93. numba_cuda/numba/cuda/tests/cudadrv/test_host_alloc.py +8 -7
  94. numba_cuda/numba/cuda/tests/cudadrv/test_init.py +13 -13
  95. numba_cuda/numba/cuda/tests/cudadrv/test_inline_ptx.py +12 -9
  96. numba_cuda/numba/cuda/tests/cudadrv/test_linker.py +36 -31
  97. numba_cuda/numba/cuda/tests/cudadrv/test_managed_alloc.py +8 -7
  98. numba_cuda/numba/cuda/tests/cudadrv/test_module_callbacks.py +294 -0
  99. numba_cuda/numba/cuda/tests/cudadrv/test_mvc.py +10 -7
  100. numba_cuda/numba/cuda/tests/cudadrv/test_nvjitlink.py +24 -15
  101. numba_cuda/numba/cuda/tests/cudadrv/test_nvvm_driver.py +43 -41
  102. numba_cuda/numba/cuda/tests/cudadrv/test_pinned.py +4 -5
  103. numba_cuda/numba/cuda/tests/cudadrv/test_profiler.py +2 -2
  104. numba_cuda/numba/cuda/tests/cudadrv/test_ptds.py +28 -17
  105. numba_cuda/numba/cuda/tests/cudadrv/test_reset_device.py +1 -2
  106. numba_cuda/numba/cuda/tests/cudadrv/test_runtime.py +22 -14
  107. numba_cuda/numba/cuda/tests/cudadrv/test_select_device.py +1 -1
  108. numba_cuda/numba/cuda/tests/cudadrv/test_streams.py +4 -3
  109. numba_cuda/numba/cuda/tests/cudapy/cache_usecases.py +10 -4
  110. numba_cuda/numba/cuda/tests/cudapy/cache_with_cpu_usecases.py +1 -0
  111. numba_cuda/numba/cuda/tests/cudapy/extensions_usecases.py +7 -6
  112. numba_cuda/numba/cuda/tests/cudapy/jitlink.ptx +0 -2
  113. numba_cuda/numba/cuda/tests/cudapy/recursion_usecases.py +1 -0
  114. numba_cuda/numba/cuda/tests/cudapy/test_alignment.py +6 -5
  115. numba_cuda/numba/cuda/tests/cudapy/test_array.py +52 -42
  116. numba_cuda/numba/cuda/tests/cudapy/test_array_args.py +5 -6
  117. numba_cuda/numba/cuda/tests/cudapy/test_array_methods.py +1 -1
  118. numba_cuda/numba/cuda/tests/cudapy/test_atomics.py +501 -304
  119. numba_cuda/numba/cuda/tests/cudapy/test_blackscholes.py +57 -21
  120. numba_cuda/numba/cuda/tests/cudapy/test_boolean.py +3 -3
  121. numba_cuda/numba/cuda/tests/cudapy/test_caching.py +50 -37
  122. numba_cuda/numba/cuda/tests/cudapy/test_casting.py +29 -24
  123. numba_cuda/numba/cuda/tests/cudapy/test_cffi.py +11 -6
  124. numba_cuda/numba/cuda/tests/cudapy/test_compiler.py +84 -50
  125. numba_cuda/numba/cuda/tests/cudapy/test_complex.py +144 -73
  126. numba_cuda/numba/cuda/tests/cudapy/test_complex_kernel.py +2 -2
  127. numba_cuda/numba/cuda/tests/cudapy/test_const_string.py +37 -27
  128. numba_cuda/numba/cuda/tests/cudapy/test_constmem.py +43 -45
  129. numba_cuda/numba/cuda/tests/cudapy/test_cooperative_groups.py +21 -14
  130. numba_cuda/numba/cuda/tests/cudapy/test_cuda_array_interface.py +60 -55
  131. numba_cuda/numba/cuda/tests/cudapy/test_cuda_jit_no_types.py +3 -2
  132. numba_cuda/numba/cuda/tests/cudapy/test_datetime.py +26 -22
  133. numba_cuda/numba/cuda/tests/cudapy/test_debug.py +29 -27
  134. numba_cuda/numba/cuda/tests/cudapy/test_debuginfo.py +31 -28
  135. numba_cuda/numba/cuda/tests/cudapy/test_device_func.py +52 -45
  136. numba_cuda/numba/cuda/tests/cudapy/test_dispatcher.py +55 -43
  137. numba_cuda/numba/cuda/tests/cudapy/test_enums.py +6 -7
  138. numba_cuda/numba/cuda/tests/cudapy/test_errors.py +30 -15
  139. numba_cuda/numba/cuda/tests/cudapy/test_exception.py +11 -12
  140. numba_cuda/numba/cuda/tests/cudapy/test_extending.py +19 -12
  141. numba_cuda/numba/cuda/tests/cudapy/test_fastmath.py +77 -66
  142. numba_cuda/numba/cuda/tests/cudapy/test_forall.py +5 -3
  143. numba_cuda/numba/cuda/tests/cudapy/test_freevar.py +5 -3
  144. numba_cuda/numba/cuda/tests/cudapy/test_frexp_ldexp.py +1 -1
  145. numba_cuda/numba/cuda/tests/cudapy/test_globals.py +3 -5
  146. numba_cuda/numba/cuda/tests/cudapy/test_gufunc.py +144 -126
  147. numba_cuda/numba/cuda/tests/cudapy/test_gufunc_scalar.py +23 -18
  148. numba_cuda/numba/cuda/tests/cudapy/test_gufunc_scheduling.py +16 -22
  149. numba_cuda/numba/cuda/tests/cudapy/test_idiv.py +1 -3
  150. numba_cuda/numba/cuda/tests/cudapy/test_inspect.py +29 -20
  151. numba_cuda/numba/cuda/tests/cudapy/test_intrinsics.py +147 -99
  152. numba_cuda/numba/cuda/tests/cudapy/test_ipc.py +50 -36
  153. numba_cuda/numba/cuda/tests/cudapy/test_iterators.py +1 -2
  154. numba_cuda/numba/cuda/tests/cudapy/test_lang.py +4 -4
  155. numba_cuda/numba/cuda/tests/cudapy/test_laplace.py +6 -6
  156. numba_cuda/numba/cuda/tests/cudapy/test_libdevice.py +24 -20
  157. numba_cuda/numba/cuda/tests/cudapy/test_lineinfo.py +36 -31
  158. numba_cuda/numba/cuda/tests/cudapy/test_localmem.py +13 -13
  159. numba_cuda/numba/cuda/tests/cudapy/test_mandel.py +13 -6
  160. numba_cuda/numba/cuda/tests/cudapy/test_math.py +83 -66
  161. numba_cuda/numba/cuda/tests/cudapy/test_matmul.py +1 -3
  162. numba_cuda/numba/cuda/tests/cudapy/test_minmax.py +19 -58
  163. numba_cuda/numba/cuda/tests/cudapy/test_montecarlo.py +4 -4
  164. numba_cuda/numba/cuda/tests/cudapy/test_multigpu.py +9 -7
  165. numba_cuda/numba/cuda/tests/cudapy/test_multiprocessing.py +9 -8
  166. numba_cuda/numba/cuda/tests/cudapy/test_multithreads.py +12 -10
  167. numba_cuda/numba/cuda/tests/cudapy/test_nondet.py +1 -1
  168. numba_cuda/numba/cuda/tests/cudapy/test_operator.py +180 -96
  169. numba_cuda/numba/cuda/tests/cudapy/test_optimization.py +5 -5
  170. numba_cuda/numba/cuda/tests/cudapy/test_overload.py +37 -18
  171. numba_cuda/numba/cuda/tests/cudapy/test_powi.py +7 -7
  172. numba_cuda/numba/cuda/tests/cudapy/test_print.py +9 -7
  173. numba_cuda/numba/cuda/tests/cudapy/test_py2_div_issue.py +1 -1
  174. numba_cuda/numba/cuda/tests/cudapy/test_random.py +15 -10
  175. numba_cuda/numba/cuda/tests/cudapy/test_record_dtype.py +88 -87
  176. numba_cuda/numba/cuda/tests/cudapy/test_recursion.py +12 -10
  177. numba_cuda/numba/cuda/tests/cudapy/test_reduction.py +26 -11
  178. numba_cuda/numba/cuda/tests/cudapy/test_retrieve_autoconverted_arrays.py +7 -10
  179. numba_cuda/numba/cuda/tests/cudapy/test_serialize.py +4 -6
  180. numba_cuda/numba/cuda/tests/cudapy/test_slicing.py +1 -1
  181. numba_cuda/numba/cuda/tests/cudapy/test_sm.py +10 -9
  182. numba_cuda/numba/cuda/tests/cudapy/test_sm_creation.py +62 -43
  183. numba_cuda/numba/cuda/tests/cudapy/test_stream_api.py +7 -3
  184. numba_cuda/numba/cuda/tests/cudapy/test_sync.py +7 -5
  185. numba_cuda/numba/cuda/tests/cudapy/test_transpose.py +18 -11
  186. numba_cuda/numba/cuda/tests/cudapy/test_ufuncs.py +111 -88
  187. numba_cuda/numba/cuda/tests/cudapy/test_userexc.py +2 -3
  188. numba_cuda/numba/cuda/tests/cudapy/test_vector_type.py +305 -130
  189. numba_cuda/numba/cuda/tests/cudapy/test_vectorize.py +33 -36
  190. numba_cuda/numba/cuda/tests/cudapy/test_vectorize_complex.py +5 -5
  191. numba_cuda/numba/cuda/tests/cudapy/test_vectorize_decor.py +16 -12
  192. numba_cuda/numba/cuda/tests/cudapy/test_vectorize_device.py +7 -7
  193. numba_cuda/numba/cuda/tests/cudapy/test_vectorize_scalar_arg.py +6 -7
  194. numba_cuda/numba/cuda/tests/cudapy/test_warning.py +31 -29
  195. numba_cuda/numba/cuda/tests/cudapy/test_warp_ops.py +31 -25
  196. numba_cuda/numba/cuda/tests/cudasim/test_cudasim_issues.py +19 -13
  197. numba_cuda/numba/cuda/tests/data/jitlink.cu +1 -1
  198. numba_cuda/numba/cuda/tests/data/jitlink.ptx +0 -2
  199. numba_cuda/numba/cuda/tests/doc_examples/test_cg.py +15 -8
  200. numba_cuda/numba/cuda/tests/doc_examples/test_cpu_gpu_compat.py +4 -7
  201. numba_cuda/numba/cuda/tests/doc_examples/test_ffi.py +14 -9
  202. numba_cuda/numba/cuda/tests/doc_examples/test_laplace.py +22 -18
  203. numba_cuda/numba/cuda/tests/doc_examples/test_matmul.py +7 -4
  204. numba_cuda/numba/cuda/tests/doc_examples/test_montecarlo.py +2 -0
  205. numba_cuda/numba/cuda/tests/doc_examples/test_random.py +8 -4
  206. numba_cuda/numba/cuda/tests/doc_examples/test_reduction.py +2 -1
  207. numba_cuda/numba/cuda/tests/doc_examples/test_sessionize.py +94 -19
  208. numba_cuda/numba/cuda/tests/doc_examples/test_vecadd.py +2 -2
  209. numba_cuda/numba/cuda/tests/nocuda/test_dummyarray.py +91 -62
  210. numba_cuda/numba/cuda/tests/nocuda/test_function_resolution.py +14 -5
  211. numba_cuda/numba/cuda/tests/nocuda/test_import.py +25 -25
  212. numba_cuda/numba/cuda/tests/nocuda/test_library_lookup.py +40 -40
  213. numba_cuda/numba/cuda/tests/nocuda/test_nvvm.py +12 -10
  214. numba_cuda/numba/cuda/tests/nrt/test_nrt.py +16 -20
  215. numba_cuda/numba/cuda/tests/nrt/test_nrt_refct.py +12 -10
  216. numba_cuda/numba/cuda/tests/test_binary_generation/generate_raw_ltoir.py +2 -2
  217. numba_cuda/numba/cuda/types.py +5 -2
  218. numba_cuda/numba/cuda/ufuncs.py +382 -362
  219. numba_cuda/numba/cuda/utils.py +2 -2
  220. numba_cuda/numba/cuda/vector_types.py +2 -2
  221. numba_cuda/numba/cuda/vectorizers.py +37 -32
  222. {numba_cuda-0.8.1.dist-info → numba_cuda-0.9.0.dist-info}/METADATA +1 -1
  223. numba_cuda-0.9.0.dist-info/RECORD +253 -0
  224. {numba_cuda-0.8.1.dist-info → numba_cuda-0.9.0.dist-info}/WHEEL +1 -1
  225. numba_cuda-0.8.1.dist-info/RECORD +0 -251
  226. {numba_cuda-0.8.1.dist-info → numba_cuda-0.9.0.dist-info}/licenses/LICENSE +0 -0
  227. {numba_cuda-0.8.1.dist-info → numba_cuda-0.9.0.dist-info}/top_level.txt +0 -0
@@ -1,19 +1,39 @@
1
1
  from llvmlite import ir
2
2
  from numba.core.typing.templates import ConcreteTemplate
3
3
  from numba.core import ir as numba_ir
4
- from numba.core import (cgutils, types, typing, funcdesc, config, compiler,
5
- sigutils, utils)
6
- from numba.core.compiler import (sanitize_compile_result_entries, CompilerBase,
7
- DefaultPassBuilder, Flags, Option,
8
- CompileResult)
4
+ from numba.core import (
5
+ cgutils,
6
+ types,
7
+ typing,
8
+ funcdesc,
9
+ config,
10
+ compiler,
11
+ sigutils,
12
+ utils,
13
+ )
14
+ from numba.core.compiler import (
15
+ sanitize_compile_result_entries,
16
+ CompilerBase,
17
+ DefaultPassBuilder,
18
+ Flags,
19
+ Option,
20
+ CompileResult,
21
+ )
9
22
  from numba.core.compiler_lock import global_compiler_lock
10
- from numba.core.compiler_machinery import (FunctionPass, LoweringPass,
11
- PassManager, register_pass)
23
+ from numba.core.compiler_machinery import (
24
+ FunctionPass,
25
+ LoweringPass,
26
+ PassManager,
27
+ register_pass,
28
+ )
12
29
  from numba.core.interpreter import Interpreter
13
30
  from numba.core.errors import NumbaInvalidConfigWarning
14
31
  from numba.core.untyped_passes import TranslateByteCode
15
- from numba.core.typed_passes import (IRLegalization, NativeLowering,
16
- AnnotateTypes)
32
+ from numba.core.typed_passes import (
33
+ IRLegalization,
34
+ NativeLowering,
35
+ AnnotateTypes,
36
+ )
17
37
  from warnings import warn
18
38
  from numba.cuda import nvvmutils
19
39
  from numba.cuda.api import get_current_device
@@ -52,15 +72,9 @@ class CUDAFlags(Flags):
52
72
  doc="Compute Capability",
53
73
  )
54
74
  max_registers = Option(
55
- type=_optional_int_type,
56
- default=None,
57
- doc="Max registers"
58
- )
59
- lto = Option(
60
- type=bool,
61
- default=False,
62
- doc="Enable Link-time Optimization"
75
+ type=_optional_int_type, default=None, doc="Max registers"
63
76
  )
77
+ lto = Option(type=bool, default=False, doc="Enable Link-time Optimization")
64
78
 
65
79
 
66
80
  # The CUDACompileResult (CCR) has a specially-defined entry point equal to its
@@ -79,6 +93,7 @@ class CUDAFlags(Flags):
79
93
  # point will no longer need to be a synthetic value, but will instead be a
80
94
  # pointer to the compiled function as in the CPU target.
81
95
 
96
+
82
97
  class CUDACompileResult(CompileResult):
83
98
  @property
84
99
  def entry_point(self):
@@ -92,7 +107,6 @@ def cuda_compile_result(**entries):
92
107
 
93
108
  @register_pass(mutates_CFG=True, analysis_only=False)
94
109
  class CUDABackend(LoweringPass):
95
-
96
110
  _name = "cuda_backend"
97
111
 
98
112
  def __init__(self):
@@ -102,7 +116,7 @@ class CUDABackend(LoweringPass):
102
116
  """
103
117
  Back-end: Packages lowering output in a compile result
104
118
  """
105
- lowered = state['cr']
119
+ lowered = state["cr"]
106
120
  signature = typing.signature(state.return_type, *state.args)
107
121
 
108
122
  state.cr = cuda_compile_result(
@@ -137,9 +151,12 @@ class CreateLibrary(LoweringPass):
137
151
  nvvm_options = state.flags.nvvm_options
138
152
  max_registers = state.flags.max_registers
139
153
  lto = state.flags.lto
140
- state.library = codegen.create_library(name, nvvm_options=nvvm_options,
141
- max_registers=max_registers,
142
- lto=lto)
154
+ state.library = codegen.create_library(
155
+ name,
156
+ nvvm_options=nvvm_options,
157
+ max_registers=max_registers,
158
+ lto=lto,
159
+ )
143
160
  # Enable object caching upfront so that the library can be serialized.
144
161
  state.library.enable_object_caching()
145
162
 
@@ -165,13 +182,15 @@ class CUDABytecodeInterpreter(Interpreter):
165
182
  gv_fn = numba_ir.Global("bool", bool, loc=self.loc)
166
183
  self.store(value=gv_fn, name=name)
167
184
 
168
- callres = numba_ir.Expr.call(self.get(name), (self.get(pred),), (),
169
- loc=self.loc)
185
+ callres = numba_ir.Expr.call(
186
+ self.get(name), (self.get(pred),), (), loc=self.loc
187
+ )
170
188
 
171
189
  pname = "$%spred" % (inst.offset)
172
190
  predicate = self.store(value=callres, name=pname)
173
- bra = numba_ir.Branch(cond=predicate, truebr=truebr, falsebr=falsebr,
174
- loc=self.loc)
191
+ bra = numba_ir.Branch(
192
+ cond=predicate, truebr=truebr, falsebr=falsebr, loc=self.loc
193
+ )
175
194
  self.current_block.append(bra)
176
195
 
177
196
 
@@ -183,18 +202,18 @@ class CUDATranslateBytecode(FunctionPass):
183
202
  FunctionPass.__init__(self)
184
203
 
185
204
  def run_pass(self, state):
186
- func_id = state['func_id']
187
- bc = state['bc']
205
+ func_id = state["func_id"]
206
+ bc = state["bc"]
188
207
  interp = CUDABytecodeInterpreter(func_id)
189
208
  func_ir = interp.interpret(bc)
190
- state['func_ir'] = func_ir
209
+ state["func_ir"] = func_ir
191
210
  return True
192
211
 
193
212
 
194
213
  class CUDACompiler(CompilerBase):
195
214
  def define_pipelines(self):
196
215
  dpb = DefaultPassBuilder
197
- pm = PassManager('cuda')
216
+ pm = PassManager("cuda")
198
217
 
199
218
  untyped_passes = dpb.define_untyped_pipeline(self.state)
200
219
 
@@ -225,10 +244,9 @@ class CUDACompiler(CompilerBase):
225
244
  return [pm]
226
245
 
227
246
  def define_cuda_lowering_pipeline(self, state):
228
- pm = PassManager('cuda_lowering')
247
+ pm = PassManager("cuda_lowering")
229
248
  # legalise
230
- pm.add_pass(IRLegalization,
231
- "ensure IR is legal prior to lowering")
249
+ pm.add_pass(IRLegalization, "ensure IR is legal prior to lowering")
232
250
  pm.add_pass(AnnotateTypes, "annotate types")
233
251
 
234
252
  # lower
@@ -241,13 +259,24 @@ class CUDACompiler(CompilerBase):
241
259
 
242
260
 
243
261
  @global_compiler_lock
244
- def compile_cuda(pyfunc, return_type, args, debug=False, lineinfo=False,
245
- inline=False, fastmath=False, nvvm_options=None,
246
- cc=None, max_registers=None, lto=False):
262
+ def compile_cuda(
263
+ pyfunc,
264
+ return_type,
265
+ args,
266
+ debug=False,
267
+ lineinfo=False,
268
+ inline=False,
269
+ fastmath=False,
270
+ nvvm_options=None,
271
+ cc=None,
272
+ max_registers=None,
273
+ lto=False,
274
+ ):
247
275
  if cc is None:
248
- raise ValueError('Compute Capability must be supplied')
276
+ raise ValueError("Compute Capability must be supplied")
249
277
 
250
278
  from .descriptor import cuda_target
279
+
251
280
  typingctx = cuda_target.typing_context
252
281
  targetctx = cuda_target.target_context
253
282
 
@@ -269,10 +298,10 @@ def compile_cuda(pyfunc, return_type, args, debug=False, lineinfo=False,
269
298
  flags.dbg_directives_only = True
270
299
 
271
300
  if debug:
272
- flags.error_model = 'python'
301
+ flags.error_model = "python"
273
302
  flags.dbg_extend_lifetimes = True
274
303
  else:
275
- flags.error_model = 'numpy'
304
+ flags.error_model = "numpy"
276
305
 
277
306
  if inline:
278
307
  flags.forceinline = True
@@ -286,15 +315,18 @@ def compile_cuda(pyfunc, return_type, args, debug=False, lineinfo=False,
286
315
 
287
316
  # Run compilation pipeline
288
317
  from numba.core.target_extension import target_override
289
- with target_override('cuda'):
290
- cres = compiler.compile_extra(typingctx=typingctx,
291
- targetctx=targetctx,
292
- func=pyfunc,
293
- args=args,
294
- return_type=return_type,
295
- flags=flags,
296
- locals={},
297
- pipeline_class=CUDACompiler)
318
+
319
+ with target_override("cuda"):
320
+ cres = compiler.compile_extra(
321
+ typingctx=typingctx,
322
+ targetctx=targetctx,
323
+ func=pyfunc,
324
+ args=args,
325
+ return_type=return_type,
326
+ flags=flags,
327
+ locals={},
328
+ pipeline_class=CUDACompiler,
329
+ )
298
330
 
299
331
  library = cres.library
300
332
  library.finalize()
@@ -302,8 +334,9 @@ def compile_cuda(pyfunc, return_type, args, debug=False, lineinfo=False,
302
334
  return cres
303
335
 
304
336
 
305
- def cabi_wrap_function(context, lib, fndesc, wrapper_function_name,
306
- nvvm_options):
337
+ def cabi_wrap_function(
338
+ context, lib, fndesc, wrapper_function_name, nvvm_options
339
+ ):
307
340
  """
308
341
  Wrap a Numba ABI function in a C ABI wrapper at the NVVM IR level.
309
342
 
@@ -311,9 +344,11 @@ def cabi_wrap_function(context, lib, fndesc, wrapper_function_name,
311
344
  """
312
345
  # The wrapper will be contained in a new library that links to the wrapped
313
346
  # function's library
314
- library = lib.codegen.create_library(f'{lib.name}_function_',
315
- entry_name=wrapper_function_name,
316
- nvvm_options=nvvm_options)
347
+ library = lib.codegen.create_library(
348
+ f"{lib.name}_function_",
349
+ entry_name=wrapper_function_name,
350
+ nvvm_options=nvvm_options,
351
+ )
317
352
  library.add_linking_library(lib)
318
353
 
319
354
  # Determine the caller (C ABI) and wrapper (Numba ABI) function types
@@ -331,14 +366,15 @@ def cabi_wrap_function(context, lib, fndesc, wrapper_function_name,
331
366
  # its return value
332
367
 
333
368
  wrapfn = ir.Function(wrapper_module, wrapfnty, wrapper_function_name)
334
- builder = ir.IRBuilder(wrapfn.append_basic_block(''))
369
+ builder = ir.IRBuilder(wrapfn.append_basic_block(""))
335
370
 
336
371
  arginfo = context.get_arg_packer(argtypes)
337
372
  callargs = arginfo.from_arguments(builder, wrapfn.args)
338
373
  # We get (status, return_value), but we ignore the status since we
339
374
  # can't propagate it through the C ABI anyway
340
375
  _, return_value = context.call_conv.call_function(
341
- builder, func, restype, argtypes, callargs)
376
+ builder, func, restype, argtypes, callargs
377
+ )
342
378
  builder.ret(return_value)
343
379
 
344
380
  if config.DUMP_LLVM:
@@ -395,8 +431,10 @@ def kernel_fixup(kernel, debug):
395
431
 
396
432
  # Find all stores first
397
433
  for inst in block.instructions:
398
- if (isinstance(inst, ir.StoreInstr)
399
- and inst.operands[1] == return_value):
434
+ if (
435
+ isinstance(inst, ir.StoreInstr)
436
+ and inst.operands[1] == return_value
437
+ ):
400
438
  remove_list.append(inst)
401
439
 
402
440
  # Remove all stores
@@ -407,8 +445,9 @@ def kernel_fixup(kernel, debug):
407
445
  # value
408
446
 
409
447
  if isinstance(kernel.type, ir.PointerType):
410
- new_type = ir.PointerType(ir.FunctionType(ir.VoidType(),
411
- kernel.type.pointee.args[1:]))
448
+ new_type = ir.PointerType(
449
+ ir.FunctionType(ir.VoidType(), kernel.type.pointee.args[1:])
450
+ )
412
451
  else:
413
452
  new_type = ir.FunctionType(ir.VoidType(), kernel.type.args[1:])
414
453
 
@@ -418,13 +457,13 @@ def kernel_fixup(kernel, debug):
418
457
 
419
458
  # If debug metadata is present, remove the return value from it
420
459
 
421
- if kernel_metadata := getattr(kernel, 'metadata', None):
422
- if dbg_metadata := kernel_metadata.get('dbg', None):
460
+ if kernel_metadata := getattr(kernel, "metadata", None):
461
+ if dbg_metadata := kernel_metadata.get("dbg", None):
423
462
  for name, value in dbg_metadata.operands:
424
463
  if name == "type":
425
464
  type_metadata = value
426
465
  for tm_name, tm_value in type_metadata.operands:
427
- if tm_name == 'types':
466
+ if tm_name == "types":
428
467
  types = tm_value
429
468
  types.operands = types.operands[1:]
430
469
  if config.DUMP_LLVM:
@@ -435,26 +474,24 @@ def kernel_fixup(kernel, debug):
435
474
  nvvm.set_cuda_kernel(kernel)
436
475
 
437
476
  if config.DUMP_LLVM:
438
- print(f"LLVM DUMP: Post kernel fixup {kernel.name}".center(80, '-'))
477
+ print(f"LLVM DUMP: Post kernel fixup {kernel.name}".center(80, "-"))
439
478
  print(kernel.module)
440
- print('=' * 80)
479
+ print("=" * 80)
441
480
 
442
481
 
443
482
  def add_exception_store_helper(kernel):
444
-
445
483
  # Create global variables for exception state
446
484
 
447
485
  def define_error_gv(postfix):
448
486
  name = kernel.name + postfix
449
- gv = cgutils.add_global_variable(kernel.module, ir.IntType(32),
450
- name)
487
+ gv = cgutils.add_global_variable(kernel.module, ir.IntType(32), name)
451
488
  gv.initializer = ir.Constant(gv.type.pointee, None)
452
489
  return gv
453
490
 
454
491
  gv_exc = define_error_gv("__errcode__")
455
492
  gv_tid = []
456
493
  gv_ctaid = []
457
- for i in 'xyz':
494
+ for i in "xyz":
458
495
  gv_tid.append(define_error_gv("__tid%s__" % i))
459
496
  gv_ctaid.append(define_error_gv("__ctaid%s__" % i))
460
497
 
@@ -484,18 +521,25 @@ def add_exception_store_helper(kernel):
484
521
  # Use atomic cmpxchg to prevent rewriting the error status
485
522
  # Only the first error is recorded
486
523
 
487
- xchg = builder.cmpxchg(gv_exc, old, status.code,
488
- 'monotonic', 'monotonic')
524
+ xchg = builder.cmpxchg(
525
+ gv_exc, old, status.code, "monotonic", "monotonic"
526
+ )
489
527
  changed = builder.extract_value(xchg, 1)
490
528
 
491
529
  # If the xchange is successful, save the thread ID.
492
530
  sreg = nvvmutils.SRegBuilder(builder)
493
531
  with builder.if_then(changed):
494
- for dim, ptr, in zip("xyz", gv_tid):
532
+ for (
533
+ dim,
534
+ ptr,
535
+ ) in zip("xyz", gv_tid):
495
536
  val = sreg.tid(dim)
496
537
  builder.store(val, ptr)
497
538
 
498
- for dim, ptr, in zip("xyz", gv_ctaid):
539
+ for (
540
+ dim,
541
+ ptr,
542
+ ) in zip("xyz", gv_ctaid):
499
543
  val = sreg.ctaid(dim)
500
544
  builder.store(val, ptr)
501
545
 
@@ -505,9 +549,19 @@ def add_exception_store_helper(kernel):
505
549
 
506
550
 
507
551
  @global_compiler_lock
508
- def compile(pyfunc, sig, debug=None, lineinfo=False, device=True,
509
- fastmath=False, cc=None, opt=None, abi="c", abi_info=None,
510
- output='ptx'):
552
+ def compile(
553
+ pyfunc,
554
+ sig,
555
+ debug=None,
556
+ lineinfo=False,
557
+ device=True,
558
+ fastmath=False,
559
+ cc=None,
560
+ opt=None,
561
+ abi="c",
562
+ abi_info=None,
563
+ output="ptx",
564
+ ):
511
565
  """Compile a Python function to PTX or LTO-IR for a given set of argument
512
566
  types.
513
567
 
@@ -551,43 +605,49 @@ def compile(pyfunc, sig, debug=None, lineinfo=False, device=True,
551
605
  :rtype: tuple
552
606
  """
553
607
  if abi not in ("numba", "c"):
554
- raise NotImplementedError(f'Unsupported ABI: {abi}')
608
+ raise NotImplementedError(f"Unsupported ABI: {abi}")
555
609
 
556
- if abi == 'c' and not device:
557
- raise NotImplementedError('The C ABI is not supported for kernels')
610
+ if abi == "c" and not device:
611
+ raise NotImplementedError("The C ABI is not supported for kernels")
558
612
 
559
613
  if output not in ("ptx", "ltoir"):
560
- raise NotImplementedError(f'Unsupported output type: {output}')
614
+ raise NotImplementedError(f"Unsupported output type: {output}")
561
615
 
562
616
  debug = config.CUDA_DEBUGINFO_DEFAULT if debug is None else debug
563
617
  opt = (config.OPT != 0) if opt is None else opt
564
618
 
565
619
  if debug and opt:
566
- msg = ("debug=True with opt=True "
567
- "is not supported by CUDA. This may result in a crash"
568
- " - set debug=False or opt=False.")
620
+ msg = (
621
+ "debug=True with opt=True "
622
+ "is not supported by CUDA. This may result in a crash"
623
+ " - set debug=False or opt=False."
624
+ )
569
625
  warn(NumbaInvalidConfigWarning(msg))
570
626
 
571
- lto = (output == 'ltoir')
627
+ lto = output == "ltoir"
572
628
  abi_info = abi_info or dict()
573
629
 
574
- nvvm_options = {
575
- 'fastmath': fastmath,
576
- 'opt': 3 if opt else 0
577
- }
630
+ nvvm_options = {"fastmath": fastmath, "opt": 3 if opt else 0}
578
631
 
579
632
  if debug:
580
- nvvm_options['g'] = None
633
+ nvvm_options["g"] = None
581
634
 
582
635
  if lto:
583
- nvvm_options['gen-lto'] = None
636
+ nvvm_options["gen-lto"] = None
584
637
 
585
638
  args, return_type = sigutils.normalize_signature(sig)
586
639
 
587
640
  cc = cc or config.CUDA_DEFAULT_PTX_CC
588
- cres = compile_cuda(pyfunc, return_type, args, debug=debug,
589
- lineinfo=lineinfo, fastmath=fastmath,
590
- nvvm_options=nvvm_options, cc=cc)
641
+ cres = compile_cuda(
642
+ pyfunc,
643
+ return_type,
644
+ args,
645
+ debug=debug,
646
+ lineinfo=lineinfo,
647
+ fastmath=fastmath,
648
+ nvvm_options=nvvm_options,
649
+ cc=cc,
650
+ )
591
651
  resty = cres.signature.return_type
592
652
 
593
653
  if resty and not device and resty != types.void:
@@ -598,9 +658,10 @@ def compile(pyfunc, sig, debug=None, lineinfo=False, device=True,
598
658
  if device:
599
659
  lib = cres.library
600
660
  if abi == "c":
601
- wrapper_name = abi_info.get('abi_name', pyfunc.__name__)
602
- lib = cabi_wrap_function(tgt, lib, cres.fndesc, wrapper_name,
603
- nvvm_options)
661
+ wrapper_name = abi_info.get("abi_name", pyfunc.__name__)
662
+ lib = cabi_wrap_function(
663
+ tgt, lib, cres.fndesc, wrapper_name, nvvm_options
664
+ )
604
665
  else:
605
666
  lib = cres.library
606
667
  kernel = lib.get_function(cres.fndesc.llvm_func_name)
@@ -614,38 +675,94 @@ def compile(pyfunc, sig, debug=None, lineinfo=False, device=True,
614
675
  return code, resty
615
676
 
616
677
 
617
- def compile_for_current_device(pyfunc, sig, debug=None, lineinfo=False,
618
- device=True, fastmath=False, opt=None,
619
- abi="c", abi_info=None, output='ptx'):
678
+ def compile_for_current_device(
679
+ pyfunc,
680
+ sig,
681
+ debug=None,
682
+ lineinfo=False,
683
+ device=True,
684
+ fastmath=False,
685
+ opt=None,
686
+ abi="c",
687
+ abi_info=None,
688
+ output="ptx",
689
+ ):
620
690
  """Compile a Python function to PTX or LTO-IR for a given signature for the
621
691
  current device's compute capabilility. This calls :func:`compile` with an
622
692
  appropriate ``cc`` value for the current device."""
623
693
  cc = get_current_device().compute_capability
624
- return compile(pyfunc, sig, debug=debug, lineinfo=lineinfo, device=device,
625
- fastmath=fastmath, cc=cc, opt=opt, abi=abi,
626
- abi_info=abi_info, output=output)
694
+ return compile(
695
+ pyfunc,
696
+ sig,
697
+ debug=debug,
698
+ lineinfo=lineinfo,
699
+ device=device,
700
+ fastmath=fastmath,
701
+ cc=cc,
702
+ opt=opt,
703
+ abi=abi,
704
+ abi_info=abi_info,
705
+ output=output,
706
+ )
627
707
 
628
708
 
629
- def compile_ptx(pyfunc, sig, debug=None, lineinfo=False, device=False,
630
- fastmath=False, cc=None, opt=None, abi="numba", abi_info=None):
709
+ def compile_ptx(
710
+ pyfunc,
711
+ sig,
712
+ debug=None,
713
+ lineinfo=False,
714
+ device=False,
715
+ fastmath=False,
716
+ cc=None,
717
+ opt=None,
718
+ abi="numba",
719
+ abi_info=None,
720
+ ):
631
721
  """Compile a Python function to PTX for a given signature. See
632
722
  :func:`compile`. The defaults for this function are to compile a kernel
633
723
  with the Numba ABI, rather than :func:`compile`'s default of compiling a
634
724
  device function with the C ABI."""
635
- return compile(pyfunc, sig, debug=debug, lineinfo=lineinfo, device=device,
636
- fastmath=fastmath, cc=cc, opt=opt, abi=abi,
637
- abi_info=abi_info, output='ptx')
725
+ return compile(
726
+ pyfunc,
727
+ sig,
728
+ debug=debug,
729
+ lineinfo=lineinfo,
730
+ device=device,
731
+ fastmath=fastmath,
732
+ cc=cc,
733
+ opt=opt,
734
+ abi=abi,
735
+ abi_info=abi_info,
736
+ output="ptx",
737
+ )
638
738
 
639
739
 
640
- def compile_ptx_for_current_device(pyfunc, sig, debug=None, lineinfo=False,
641
- device=False, fastmath=False, opt=None,
642
- abi="numba", abi_info=None):
740
+ def compile_ptx_for_current_device(
741
+ pyfunc,
742
+ sig,
743
+ debug=None,
744
+ lineinfo=False,
745
+ device=False,
746
+ fastmath=False,
747
+ opt=None,
748
+ abi="numba",
749
+ abi_info=None,
750
+ ):
643
751
  """Compile a Python function to PTX for a given signature for the current
644
752
  device's compute capabilility. See :func:`compile_ptx`."""
645
753
  cc = get_current_device().compute_capability
646
- return compile_ptx(pyfunc, sig, debug=debug, lineinfo=lineinfo,
647
- device=device, fastmath=fastmath, cc=cc, opt=opt,
648
- abi=abi, abi_info=abi_info)
754
+ return compile_ptx(
755
+ pyfunc,
756
+ sig,
757
+ debug=debug,
758
+ lineinfo=lineinfo,
759
+ device=device,
760
+ fastmath=fastmath,
761
+ cc=cc,
762
+ opt=opt,
763
+ abi=abi,
764
+ abi_info=abi_info,
765
+ )
649
766
 
650
767
 
651
768
  def declare_device_function(name, restype, argtypes, link):
@@ -654,6 +771,7 @@ def declare_device_function(name, restype, argtypes, link):
654
771
 
655
772
  def declare_device_function_template(name, restype, argtypes, link):
656
773
  from .descriptor import cuda_target
774
+
657
775
  typingctx = cuda_target.typing_context
658
776
  targetctx = cuda_target.target_context
659
777
  sig = typing.signature(restype, *argtypes)
@@ -664,7 +782,8 @@ def declare_device_function_template(name, restype, argtypes, link):
664
782
  cases = [sig]
665
783
 
666
784
  fndesc = funcdesc.ExternalFunctionDescriptor(
667
- name=name, restype=restype, argtypes=argtypes)
785
+ name=name, restype=restype, argtypes=argtypes
786
+ )
668
787
  typingctx.insert_user_function(extfn, device_function_template)
669
788
  targetctx.insert_user_function(extfn, fndesc)
670
789
 
@@ -23,7 +23,7 @@ FNDEF(hdiv)(
23
23
  )
24
24
  {
25
25
  __half retval = __hdiv(__short_as_half (x), __short_as_half (y));
26
-
26
+
27
27
  *return_value = __half_as_short (retval);
28
28
  // Signal that no Python exception occurred
29
29
  return 0;
@@ -44,4 +44,3 @@ UNARY_FUNCTION(hceil)
44
44
  UNARY_FUNCTION(hrcp)
45
45
  UNARY_FUNCTION(hrint)
46
46
  UNARY_FUNCTION(htrunc)
47
-