numba-cuda 0.8.0__py3-none-any.whl → 0.9.0__py3-none-any.whl

This diff shows the differences between the contents of two publicly available package versions that have been released to one of the supported registries. It is provided for informational purposes only and reflects the changes between those versions as they appear in their respective public registries.
Files changed (227)
  1. _numba_cuda_redirector.py +17 -13
  2. numba_cuda/VERSION +1 -1
  3. numba_cuda/_version.py +4 -1
  4. numba_cuda/numba/cuda/__init__.py +6 -2
  5. numba_cuda/numba/cuda/api.py +129 -86
  6. numba_cuda/numba/cuda/api_util.py +3 -3
  7. numba_cuda/numba/cuda/args.py +12 -16
  8. numba_cuda/numba/cuda/cg.py +6 -6
  9. numba_cuda/numba/cuda/codegen.py +74 -43
  10. numba_cuda/numba/cuda/compiler.py +232 -113
  11. numba_cuda/numba/cuda/cpp_function_wrappers.cu +1 -2
  12. numba_cuda/numba/cuda/cuda_fp16.h +661 -661
  13. numba_cuda/numba/cuda/cuda_fp16.hpp +3 -3
  14. numba_cuda/numba/cuda/cuda_paths.py +291 -99
  15. numba_cuda/numba/cuda/cudadecl.py +125 -69
  16. numba_cuda/numba/cuda/cudadrv/__init__.py +3 -1
  17. numba_cuda/numba/cuda/cudadrv/devicearray.py +185 -135
  18. numba_cuda/numba/cuda/cudadrv/devices.py +16 -11
  19. numba_cuda/numba/cuda/cudadrv/driver.py +463 -297
  20. numba_cuda/numba/cuda/cudadrv/drvapi.py +241 -207
  21. numba_cuda/numba/cuda/cudadrv/dummyarray.py +66 -54
  22. numba_cuda/numba/cuda/cudadrv/enums.py +1 -1
  23. numba_cuda/numba/cuda/cudadrv/error.py +6 -2
  24. numba_cuda/numba/cuda/cudadrv/libs.py +67 -63
  25. numba_cuda/numba/cuda/cudadrv/linkable_code.py +16 -1
  26. numba_cuda/numba/cuda/cudadrv/mappings.py +16 -14
  27. numba_cuda/numba/cuda/cudadrv/nvrtc.py +138 -29
  28. numba_cuda/numba/cuda/cudadrv/nvvm.py +296 -161
  29. numba_cuda/numba/cuda/cudadrv/rtapi.py +1 -1
  30. numba_cuda/numba/cuda/cudadrv/runtime.py +20 -8
  31. numba_cuda/numba/cuda/cudaimpl.py +317 -233
  32. numba_cuda/numba/cuda/cudamath.py +1 -1
  33. numba_cuda/numba/cuda/debuginfo.py +8 -6
  34. numba_cuda/numba/cuda/decorators.py +75 -45
  35. numba_cuda/numba/cuda/descriptor.py +1 -1
  36. numba_cuda/numba/cuda/device_init.py +69 -18
  37. numba_cuda/numba/cuda/deviceufunc.py +143 -98
  38. numba_cuda/numba/cuda/dispatcher.py +300 -213
  39. numba_cuda/numba/cuda/errors.py +13 -10
  40. numba_cuda/numba/cuda/extending.py +1 -1
  41. numba_cuda/numba/cuda/initialize.py +5 -3
  42. numba_cuda/numba/cuda/intrinsic_wrapper.py +3 -3
  43. numba_cuda/numba/cuda/intrinsics.py +31 -27
  44. numba_cuda/numba/cuda/kernels/reduction.py +13 -13
  45. numba_cuda/numba/cuda/kernels/transpose.py +3 -6
  46. numba_cuda/numba/cuda/libdevice.py +317 -317
  47. numba_cuda/numba/cuda/libdeviceimpl.py +3 -2
  48. numba_cuda/numba/cuda/locks.py +16 -0
  49. numba_cuda/numba/cuda/mathimpl.py +62 -57
  50. numba_cuda/numba/cuda/models.py +1 -5
  51. numba_cuda/numba/cuda/nvvmutils.py +103 -88
  52. numba_cuda/numba/cuda/printimpl.py +9 -5
  53. numba_cuda/numba/cuda/random.py +46 -36
  54. numba_cuda/numba/cuda/reshape_funcs.cu +1 -1
  55. numba_cuda/numba/cuda/runtime/__init__.py +1 -1
  56. numba_cuda/numba/cuda/runtime/memsys.cu +1 -1
  57. numba_cuda/numba/cuda/runtime/memsys.cuh +1 -1
  58. numba_cuda/numba/cuda/runtime/nrt.cu +3 -3
  59. numba_cuda/numba/cuda/runtime/nrt.py +48 -43
  60. numba_cuda/numba/cuda/simulator/__init__.py +22 -12
  61. numba_cuda/numba/cuda/simulator/api.py +38 -22
  62. numba_cuda/numba/cuda/simulator/compiler.py +2 -2
  63. numba_cuda/numba/cuda/simulator/cudadrv/__init__.py +8 -2
  64. numba_cuda/numba/cuda/simulator/cudadrv/devicearray.py +63 -55
  65. numba_cuda/numba/cuda/simulator/cudadrv/devices.py +13 -11
  66. numba_cuda/numba/cuda/simulator/cudadrv/driver.py +5 -5
  67. numba_cuda/numba/cuda/simulator/cudadrv/drvapi.py +2 -2
  68. numba_cuda/numba/cuda/simulator/cudadrv/libs.py +1 -1
  69. numba_cuda/numba/cuda/simulator/cudadrv/nvvm.py +3 -3
  70. numba_cuda/numba/cuda/simulator/cudadrv/runtime.py +3 -3
  71. numba_cuda/numba/cuda/simulator/kernel.py +43 -34
  72. numba_cuda/numba/cuda/simulator/kernelapi.py +31 -26
  73. numba_cuda/numba/cuda/simulator/reduction.py +1 -0
  74. numba_cuda/numba/cuda/simulator/vector_types.py +13 -9
  75. numba_cuda/numba/cuda/simulator_init.py +2 -4
  76. numba_cuda/numba/cuda/stubs.py +139 -102
  77. numba_cuda/numba/cuda/target.py +64 -47
  78. numba_cuda/numba/cuda/testing.py +24 -19
  79. numba_cuda/numba/cuda/tests/__init__.py +14 -12
  80. numba_cuda/numba/cuda/tests/cudadrv/test_array_attr.py +16 -17
  81. numba_cuda/numba/cuda/tests/cudadrv/test_context_stack.py +7 -7
  82. numba_cuda/numba/cuda/tests/cudadrv/test_cuda_array_slicing.py +73 -54
  83. numba_cuda/numba/cuda/tests/cudadrv/test_cuda_auto_context.py +1 -1
  84. numba_cuda/numba/cuda/tests/cudadrv/test_cuda_devicerecord.py +48 -50
  85. numba_cuda/numba/cuda/tests/cudadrv/test_cuda_driver.py +47 -29
  86. numba_cuda/numba/cuda/tests/cudadrv/test_cuda_libraries.py +3 -3
  87. numba_cuda/numba/cuda/tests/cudadrv/test_cuda_memory.py +19 -19
  88. numba_cuda/numba/cuda/tests/cudadrv/test_cuda_ndarray.py +108 -103
  89. numba_cuda/numba/cuda/tests/cudadrv/test_deallocations.py +20 -11
  90. numba_cuda/numba/cuda/tests/cudadrv/test_detect.py +20 -17
  91. numba_cuda/numba/cuda/tests/cudadrv/test_emm_plugins.py +8 -6
  92. numba_cuda/numba/cuda/tests/cudadrv/test_events.py +1 -1
  93. numba_cuda/numba/cuda/tests/cudadrv/test_host_alloc.py +8 -7
  94. numba_cuda/numba/cuda/tests/cudadrv/test_init.py +13 -13
  95. numba_cuda/numba/cuda/tests/cudadrv/test_inline_ptx.py +12 -9
  96. numba_cuda/numba/cuda/tests/cudadrv/test_linker.py +36 -31
  97. numba_cuda/numba/cuda/tests/cudadrv/test_managed_alloc.py +8 -7
  98. numba_cuda/numba/cuda/tests/cudadrv/test_module_callbacks.py +294 -0
  99. numba_cuda/numba/cuda/tests/cudadrv/test_mvc.py +10 -7
  100. numba_cuda/numba/cuda/tests/cudadrv/test_nvjitlink.py +24 -15
  101. numba_cuda/numba/cuda/tests/cudadrv/test_nvvm_driver.py +43 -41
  102. numba_cuda/numba/cuda/tests/cudadrv/test_pinned.py +4 -5
  103. numba_cuda/numba/cuda/tests/cudadrv/test_profiler.py +2 -2
  104. numba_cuda/numba/cuda/tests/cudadrv/test_ptds.py +28 -17
  105. numba_cuda/numba/cuda/tests/cudadrv/test_reset_device.py +1 -2
  106. numba_cuda/numba/cuda/tests/cudadrv/test_runtime.py +22 -14
  107. numba_cuda/numba/cuda/tests/cudadrv/test_select_device.py +1 -1
  108. numba_cuda/numba/cuda/tests/cudadrv/test_streams.py +4 -3
  109. numba_cuda/numba/cuda/tests/cudapy/cache_usecases.py +10 -4
  110. numba_cuda/numba/cuda/tests/cudapy/cache_with_cpu_usecases.py +1 -0
  111. numba_cuda/numba/cuda/tests/cudapy/extensions_usecases.py +7 -6
  112. numba_cuda/numba/cuda/tests/cudapy/jitlink.ptx +0 -2
  113. numba_cuda/numba/cuda/tests/cudapy/recursion_usecases.py +1 -0
  114. numba_cuda/numba/cuda/tests/cudapy/test_alignment.py +6 -5
  115. numba_cuda/numba/cuda/tests/cudapy/test_array.py +52 -42
  116. numba_cuda/numba/cuda/tests/cudapy/test_array_args.py +5 -6
  117. numba_cuda/numba/cuda/tests/cudapy/test_array_methods.py +1 -1
  118. numba_cuda/numba/cuda/tests/cudapy/test_atomics.py +501 -304
  119. numba_cuda/numba/cuda/tests/cudapy/test_blackscholes.py +57 -21
  120. numba_cuda/numba/cuda/tests/cudapy/test_boolean.py +3 -3
  121. numba_cuda/numba/cuda/tests/cudapy/test_caching.py +50 -37
  122. numba_cuda/numba/cuda/tests/cudapy/test_casting.py +29 -24
  123. numba_cuda/numba/cuda/tests/cudapy/test_cffi.py +11 -6
  124. numba_cuda/numba/cuda/tests/cudapy/test_compiler.py +84 -50
  125. numba_cuda/numba/cuda/tests/cudapy/test_complex.py +144 -73
  126. numba_cuda/numba/cuda/tests/cudapy/test_complex_kernel.py +2 -2
  127. numba_cuda/numba/cuda/tests/cudapy/test_const_string.py +37 -27
  128. numba_cuda/numba/cuda/tests/cudapy/test_constmem.py +43 -45
  129. numba_cuda/numba/cuda/tests/cudapy/test_cooperative_groups.py +21 -14
  130. numba_cuda/numba/cuda/tests/cudapy/test_cuda_array_interface.py +60 -55
  131. numba_cuda/numba/cuda/tests/cudapy/test_cuda_jit_no_types.py +3 -2
  132. numba_cuda/numba/cuda/tests/cudapy/test_datetime.py +26 -22
  133. numba_cuda/numba/cuda/tests/cudapy/test_debug.py +29 -27
  134. numba_cuda/numba/cuda/tests/cudapy/test_debuginfo.py +31 -28
  135. numba_cuda/numba/cuda/tests/cudapy/test_device_func.py +52 -45
  136. numba_cuda/numba/cuda/tests/cudapy/test_dispatcher.py +55 -43
  137. numba_cuda/numba/cuda/tests/cudapy/test_enums.py +6 -7
  138. numba_cuda/numba/cuda/tests/cudapy/test_errors.py +30 -15
  139. numba_cuda/numba/cuda/tests/cudapy/test_exception.py +11 -12
  140. numba_cuda/numba/cuda/tests/cudapy/test_extending.py +19 -12
  141. numba_cuda/numba/cuda/tests/cudapy/test_fastmath.py +77 -66
  142. numba_cuda/numba/cuda/tests/cudapy/test_forall.py +5 -3
  143. numba_cuda/numba/cuda/tests/cudapy/test_freevar.py +5 -3
  144. numba_cuda/numba/cuda/tests/cudapy/test_frexp_ldexp.py +1 -1
  145. numba_cuda/numba/cuda/tests/cudapy/test_globals.py +3 -5
  146. numba_cuda/numba/cuda/tests/cudapy/test_gufunc.py +144 -126
  147. numba_cuda/numba/cuda/tests/cudapy/test_gufunc_scalar.py +23 -18
  148. numba_cuda/numba/cuda/tests/cudapy/test_gufunc_scheduling.py +16 -22
  149. numba_cuda/numba/cuda/tests/cudapy/test_idiv.py +1 -3
  150. numba_cuda/numba/cuda/tests/cudapy/test_inspect.py +29 -20
  151. numba_cuda/numba/cuda/tests/cudapy/test_intrinsics.py +147 -99
  152. numba_cuda/numba/cuda/tests/cudapy/test_ipc.py +50 -36
  153. numba_cuda/numba/cuda/tests/cudapy/test_iterators.py +1 -2
  154. numba_cuda/numba/cuda/tests/cudapy/test_lang.py +4 -4
  155. numba_cuda/numba/cuda/tests/cudapy/test_laplace.py +6 -6
  156. numba_cuda/numba/cuda/tests/cudapy/test_libdevice.py +24 -20
  157. numba_cuda/numba/cuda/tests/cudapy/test_lineinfo.py +36 -31
  158. numba_cuda/numba/cuda/tests/cudapy/test_localmem.py +13 -13
  159. numba_cuda/numba/cuda/tests/cudapy/test_mandel.py +13 -6
  160. numba_cuda/numba/cuda/tests/cudapy/test_math.py +83 -66
  161. numba_cuda/numba/cuda/tests/cudapy/test_matmul.py +1 -3
  162. numba_cuda/numba/cuda/tests/cudapy/test_minmax.py +19 -58
  163. numba_cuda/numba/cuda/tests/cudapy/test_montecarlo.py +4 -4
  164. numba_cuda/numba/cuda/tests/cudapy/test_multigpu.py +9 -7
  165. numba_cuda/numba/cuda/tests/cudapy/test_multiprocessing.py +9 -8
  166. numba_cuda/numba/cuda/tests/cudapy/test_multithreads.py +12 -10
  167. numba_cuda/numba/cuda/tests/cudapy/test_nondet.py +1 -1
  168. numba_cuda/numba/cuda/tests/cudapy/test_operator.py +180 -96
  169. numba_cuda/numba/cuda/tests/cudapy/test_optimization.py +5 -5
  170. numba_cuda/numba/cuda/tests/cudapy/test_overload.py +37 -18
  171. numba_cuda/numba/cuda/tests/cudapy/test_powi.py +7 -7
  172. numba_cuda/numba/cuda/tests/cudapy/test_print.py +9 -7
  173. numba_cuda/numba/cuda/tests/cudapy/test_py2_div_issue.py +1 -1
  174. numba_cuda/numba/cuda/tests/cudapy/test_random.py +15 -10
  175. numba_cuda/numba/cuda/tests/cudapy/test_record_dtype.py +88 -87
  176. numba_cuda/numba/cuda/tests/cudapy/test_recursion.py +12 -10
  177. numba_cuda/numba/cuda/tests/cudapy/test_reduction.py +26 -11
  178. numba_cuda/numba/cuda/tests/cudapy/test_retrieve_autoconverted_arrays.py +7 -10
  179. numba_cuda/numba/cuda/tests/cudapy/test_serialize.py +4 -6
  180. numba_cuda/numba/cuda/tests/cudapy/test_slicing.py +1 -1
  181. numba_cuda/numba/cuda/tests/cudapy/test_sm.py +10 -9
  182. numba_cuda/numba/cuda/tests/cudapy/test_sm_creation.py +62 -43
  183. numba_cuda/numba/cuda/tests/cudapy/test_stream_api.py +7 -3
  184. numba_cuda/numba/cuda/tests/cudapy/test_sync.py +7 -5
  185. numba_cuda/numba/cuda/tests/cudapy/test_transpose.py +18 -11
  186. numba_cuda/numba/cuda/tests/cudapy/test_ufuncs.py +111 -88
  187. numba_cuda/numba/cuda/tests/cudapy/test_userexc.py +2 -3
  188. numba_cuda/numba/cuda/tests/cudapy/test_vector_type.py +305 -130
  189. numba_cuda/numba/cuda/tests/cudapy/test_vectorize.py +33 -36
  190. numba_cuda/numba/cuda/tests/cudapy/test_vectorize_complex.py +5 -5
  191. numba_cuda/numba/cuda/tests/cudapy/test_vectorize_decor.py +16 -12
  192. numba_cuda/numba/cuda/tests/cudapy/test_vectorize_device.py +7 -7
  193. numba_cuda/numba/cuda/tests/cudapy/test_vectorize_scalar_arg.py +6 -7
  194. numba_cuda/numba/cuda/tests/cudapy/test_warning.py +31 -29
  195. numba_cuda/numba/cuda/tests/cudapy/test_warp_ops.py +31 -25
  196. numba_cuda/numba/cuda/tests/cudasim/test_cudasim_issues.py +19 -13
  197. numba_cuda/numba/cuda/tests/data/jitlink.cu +1 -1
  198. numba_cuda/numba/cuda/tests/data/jitlink.ptx +0 -2
  199. numba_cuda/numba/cuda/tests/doc_examples/test_cg.py +15 -8
  200. numba_cuda/numba/cuda/tests/doc_examples/test_cpu_gpu_compat.py +4 -7
  201. numba_cuda/numba/cuda/tests/doc_examples/test_ffi.py +14 -9
  202. numba_cuda/numba/cuda/tests/doc_examples/test_laplace.py +22 -18
  203. numba_cuda/numba/cuda/tests/doc_examples/test_matmul.py +7 -4
  204. numba_cuda/numba/cuda/tests/doc_examples/test_montecarlo.py +2 -0
  205. numba_cuda/numba/cuda/tests/doc_examples/test_random.py +8 -4
  206. numba_cuda/numba/cuda/tests/doc_examples/test_reduction.py +2 -1
  207. numba_cuda/numba/cuda/tests/doc_examples/test_sessionize.py +94 -19
  208. numba_cuda/numba/cuda/tests/doc_examples/test_vecadd.py +2 -2
  209. numba_cuda/numba/cuda/tests/nocuda/test_dummyarray.py +91 -62
  210. numba_cuda/numba/cuda/tests/nocuda/test_function_resolution.py +14 -5
  211. numba_cuda/numba/cuda/tests/nocuda/test_import.py +25 -25
  212. numba_cuda/numba/cuda/tests/nocuda/test_library_lookup.py +40 -40
  213. numba_cuda/numba/cuda/tests/nocuda/test_nvvm.py +12 -10
  214. numba_cuda/numba/cuda/tests/nrt/test_nrt.py +16 -20
  215. numba_cuda/numba/cuda/tests/nrt/test_nrt_refct.py +12 -10
  216. numba_cuda/numba/cuda/tests/test_binary_generation/generate_raw_ltoir.py +2 -2
  217. numba_cuda/numba/cuda/types.py +5 -2
  218. numba_cuda/numba/cuda/ufuncs.py +382 -362
  219. numba_cuda/numba/cuda/utils.py +2 -2
  220. numba_cuda/numba/cuda/vector_types.py +2 -2
  221. numba_cuda/numba/cuda/vectorizers.py +37 -32
  222. {numba_cuda-0.8.0.dist-info → numba_cuda-0.9.0.dist-info}/METADATA +1 -1
  223. numba_cuda-0.9.0.dist-info/RECORD +253 -0
  224. {numba_cuda-0.8.0.dist-info → numba_cuda-0.9.0.dist-info}/WHEEL +1 -1
  225. numba_cuda-0.8.0.dist-info/RECORD +0 -251
  226. {numba_cuda-0.8.0.dist-info → numba_cuda-0.9.0.dist-info}/licenses/LICENSE +0 -0
  227. {numba_cuda-0.8.0.dist-info → numba_cuda-0.9.0.dist-info}/top_level.txt +0 -0
@@ -1,12 +1,12 @@
1
1
  """
2
2
  This is a direct translation of nvvm.h
3
3
  """
4
+
4
5
  import logging
5
6
  import re
6
7
  import sys
7
8
  import warnings
8
- from ctypes import (c_void_p, c_int, POINTER, c_char_p, c_size_t, byref,
9
- c_char)
9
+ from ctypes import c_void_p, c_int, POINTER, c_char_p, c_size_t, byref, c_char
10
10
 
11
11
  import threading
12
12
 
@@ -31,7 +31,7 @@ nvvm_program = c_void_p
31
31
  # Result code
32
32
  nvvm_result = c_int
33
33
 
34
- RESULT_CODE_NAMES = '''
34
+ RESULT_CODE_NAMES = """
35
35
  NVVM_SUCCESS
36
36
  NVVM_ERROR_OUT_OF_MEMORY
37
37
  NVVM_ERROR_PROGRAM_CREATION_FAILURE
@@ -42,19 +42,23 @@ NVVM_ERROR_INVALID_IR
42
42
  NVVM_ERROR_INVALID_OPTION
43
43
  NVVM_ERROR_NO_MODULE_IN_PROGRAM
44
44
  NVVM_ERROR_COMPILATION
45
- '''.split()
45
+ """.split()
46
46
 
47
47
  for i, k in enumerate(RESULT_CODE_NAMES):
48
48
  setattr(sys.modules[__name__], k, i)
49
49
 
50
50
  # Data layouts. NVVM IR 1.8 (CUDA 11.6) introduced 128-bit integer support.
51
51
 
52
- _datalayout_original = ('e-p:64:64:64-i1:8:8-i8:8:8-i16:16:16-i32:32:32-'
53
- 'i64:64:64-f32:32:32-f64:64:64-v16:16:16-v32:32:32-'
54
- 'v64:64:64-v128:128:128-n16:32:64')
55
- _datalayout_i128 = ('e-p:64:64:64-i1:8:8-i8:8:8-i16:16:16-i32:32:32-i64:64:64-'
56
- 'i128:128:128-f32:32:32-f64:64:64-v16:16:16-v32:32:32-'
57
- 'v64:64:64-v128:128:128-n16:32:64')
52
+ _datalayout_original = (
53
+ "e-p:64:64:64-i1:8:8-i8:8:8-i16:16:16-i32:32:32-"
54
+ "i64:64:64-f32:32:32-f64:64:64-v16:16:16-v32:32:32-"
55
+ "v64:64:64-v128:128:128-n16:32:64"
56
+ )
57
+ _datalayout_i128 = (
58
+ "e-p:64:64:64-i1:8:8-i8:8:8-i16:16:16-i32:32:32-i64:64:64-"
59
+ "i128:128:128-f32:32:32-f64:64:64-v16:16:16-v32:32:32-"
60
+ "v64:64:64-v128:128:128-n16:32:64"
61
+ )
58
62
 
59
63
 
60
64
  def is_available():
@@ -73,59 +77,74 @@ _nvvm_lock = threading.Lock()
73
77
 
74
78
 
75
79
  class NVVM(object):
76
- '''Process-wide singleton.
77
- '''
78
- _PROTOTYPES = {
80
+ """Process-wide singleton."""
79
81
 
82
+ _PROTOTYPES = {
80
83
  # nvvmResult nvvmVersion(int *major, int *minor)
81
- 'nvvmVersion': (nvvm_result, POINTER(c_int), POINTER(c_int)),
82
-
84
+ "nvvmVersion": (nvvm_result, POINTER(c_int), POINTER(c_int)),
83
85
  # nvvmResult nvvmCreateProgram(nvvmProgram *cu)
84
- 'nvvmCreateProgram': (nvvm_result, POINTER(nvvm_program)),
85
-
86
+ "nvvmCreateProgram": (nvvm_result, POINTER(nvvm_program)),
86
87
  # nvvmResult nvvmDestroyProgram(nvvmProgram *cu)
87
- 'nvvmDestroyProgram': (nvvm_result, POINTER(nvvm_program)),
88
-
88
+ "nvvmDestroyProgram": (nvvm_result, POINTER(nvvm_program)),
89
89
  # nvvmResult nvvmAddModuleToProgram(nvvmProgram cu, const char *buffer,
90
90
  # size_t size, const char *name)
91
- 'nvvmAddModuleToProgram': (
92
- nvvm_result, nvvm_program, c_char_p, c_size_t, c_char_p),
93
-
91
+ "nvvmAddModuleToProgram": (
92
+ nvvm_result,
93
+ nvvm_program,
94
+ c_char_p,
95
+ c_size_t,
96
+ c_char_p,
97
+ ),
94
98
  # nvvmResult nvvmLazyAddModuleToProgram(nvvmProgram cu,
95
99
  # const char* buffer,
96
100
  # size_t size,
97
101
  # const char *name)
98
- 'nvvmLazyAddModuleToProgram': (
99
- nvvm_result, nvvm_program, c_char_p, c_size_t, c_char_p),
100
-
102
+ "nvvmLazyAddModuleToProgram": (
103
+ nvvm_result,
104
+ nvvm_program,
105
+ c_char_p,
106
+ c_size_t,
107
+ c_char_p,
108
+ ),
101
109
  # nvvmResult nvvmCompileProgram(nvvmProgram cu, int numOptions,
102
110
  # const char **options)
103
- 'nvvmCompileProgram': (
104
- nvvm_result, nvvm_program, c_int, POINTER(c_char_p)),
105
-
111
+ "nvvmCompileProgram": (
112
+ nvvm_result,
113
+ nvvm_program,
114
+ c_int,
115
+ POINTER(c_char_p),
116
+ ),
106
117
  # nvvmResult nvvmGetCompiledResultSize(nvvmProgram cu,
107
118
  # size_t *bufferSizeRet)
108
- 'nvvmGetCompiledResultSize': (
109
- nvvm_result, nvvm_program, POINTER(c_size_t)),
110
-
119
+ "nvvmGetCompiledResultSize": (
120
+ nvvm_result,
121
+ nvvm_program,
122
+ POINTER(c_size_t),
123
+ ),
111
124
  # nvvmResult nvvmGetCompiledResult(nvvmProgram cu, char *buffer)
112
- 'nvvmGetCompiledResult': (nvvm_result, nvvm_program, c_char_p),
113
-
125
+ "nvvmGetCompiledResult": (nvvm_result, nvvm_program, c_char_p),
114
126
  # nvvmResult nvvmGetProgramLogSize(nvvmProgram cu,
115
127
  # size_t *bufferSizeRet)
116
- 'nvvmGetProgramLogSize': (nvvm_result, nvvm_program, POINTER(c_size_t)),
117
-
128
+ "nvvmGetProgramLogSize": (nvvm_result, nvvm_program, POINTER(c_size_t)),
118
129
  # nvvmResult nvvmGetProgramLog(nvvmProgram cu, char *buffer)
119
- 'nvvmGetProgramLog': (nvvm_result, nvvm_program, c_char_p),
120
-
130
+ "nvvmGetProgramLog": (nvvm_result, nvvm_program, c_char_p),
121
131
  # nvvmResult nvvmIRVersion (int* majorIR, int* minorIR, int* majorDbg,
122
132
  # int* minorDbg )
123
- 'nvvmIRVersion': (nvvm_result, POINTER(c_int), POINTER(c_int),
124
- POINTER(c_int), POINTER(c_int)),
133
+ "nvvmIRVersion": (
134
+ nvvm_result,
135
+ POINTER(c_int),
136
+ POINTER(c_int),
137
+ POINTER(c_int),
138
+ POINTER(c_int),
139
+ ),
125
140
  # nvvmResult nvvmVerifyProgram (nvvmProgram prog, int numOptions,
126
141
  # const char** options)
127
- 'nvvmVerifyProgram': (nvvm_result, nvvm_program, c_int,
128
- POINTER(c_char_p))
142
+ "nvvmVerifyProgram": (
143
+ nvvm_result,
144
+ nvvm_program,
145
+ c_int,
146
+ POINTER(c_char_p),
147
+ ),
129
148
  }
130
149
 
131
150
  # Singleton reference
@@ -136,11 +155,13 @@ class NVVM(object):
136
155
  if cls.__INSTANCE is None:
137
156
  cls.__INSTANCE = inst = object.__new__(cls)
138
157
  try:
139
- inst.driver = open_cudalib('nvvm')
158
+ inst.driver = open_cudalib("nvvm")
140
159
  except OSError as e:
141
160
  cls.__INSTANCE = None
142
- errmsg = ("libNVVM cannot be found. Do `conda install "
143
- "cudatoolkit`:\n%s")
161
+ errmsg = (
162
+ "libNVVM cannot be found. Do `conda install "
163
+ "cudatoolkit`:\n%s"
164
+ )
144
165
  raise NvvmSupportError(errmsg % e)
145
166
 
146
167
  # Find & populate functions
@@ -175,7 +196,7 @@ class NVVM(object):
175
196
  major = c_int()
176
197
  minor = c_int()
177
198
  err = self.nvvmVersion(byref(major), byref(minor))
178
- self.check_error(err, 'Failed to get version.')
199
+ self.check_error(err, "Failed to get version.")
179
200
  return major.value, minor.value
180
201
 
181
202
  def get_ir_version(self):
@@ -183,9 +204,10 @@ class NVVM(object):
183
204
  minorIR = c_int()
184
205
  majorDbg = c_int()
185
206
  minorDbg = c_int()
186
- err = self.nvvmIRVersion(byref(majorIR), byref(minorIR),
187
- byref(majorDbg), byref(minorDbg))
188
- self.check_error(err, 'Failed to get IR version.')
207
+ err = self.nvvmIRVersion(
208
+ byref(majorIR), byref(minorIR), byref(majorDbg), byref(minorDbg)
209
+ )
210
+ self.check_error(err, "Failed to get IR version.")
189
211
  return majorIR.value, minorIR.value, majorDbg.value, minorDbg.value
190
212
 
191
213
  def check_error(self, error, msg, exit=False):
@@ -223,18 +245,18 @@ class CompilationUnit(object):
223
245
  self.driver = NVVM()
224
246
  self._handle = nvvm_program()
225
247
  err = self.driver.nvvmCreateProgram(byref(self._handle))
226
- self.driver.check_error(err, 'Failed to create CU')
248
+ self.driver.check_error(err, "Failed to create CU")
227
249
 
228
250
  def stringify_option(k, v):
229
- k = k.replace('_', '-')
251
+ k = k.replace("_", "-")
230
252
 
231
253
  if v is None:
232
- return f'-{k}'.encode('utf-8')
254
+ return f"-{k}".encode("utf-8")
233
255
 
234
256
  if isinstance(v, bool):
235
257
  v = int(v)
236
258
 
237
- return f'-{k}={v}'.encode('utf-8')
259
+ return f"-{k}={v}".encode("utf-8")
238
260
 
239
261
  options = [stringify_option(k, v) for k, v in options.items()]
240
262
  option_ptrs = (c_char_p * len(options))(*[c_char_p(x) for x in options])
@@ -248,17 +270,18 @@ class CompilationUnit(object):
248
270
  def __del__(self):
249
271
  driver = NVVM()
250
272
  err = driver.nvvmDestroyProgram(byref(self._handle))
251
- driver.check_error(err, 'Failed to destroy CU', exit=True)
273
+ driver.check_error(err, "Failed to destroy CU", exit=True)
252
274
 
253
275
  def add_module(self, buffer):
254
276
  """
255
- Add a module level NVVM IR to a compilation unit.
256
- - The buffer should contain an NVVM module IR either in the bitcode
257
- representation (LLVM3.0) or in the text representation.
277
+ Add a module level NVVM IR to a compilation unit.
278
+ - The buffer should contain an NVVM module IR either in the bitcode
279
+ representation (LLVM3.0) or in the text representation.
258
280
  """
259
- err = self.driver.nvvmAddModuleToProgram(self._handle, buffer,
260
- len(buffer), None)
261
- self.driver.check_error(err, 'Failed to add module')
281
+ err = self.driver.nvvmAddModuleToProgram(
282
+ self._handle, buffer, len(buffer), None
283
+ )
284
+ self.driver.check_error(err, "Failed to add module")
262
285
 
263
286
  def lazy_add_module(self, buffer):
264
287
  """
@@ -266,37 +289,41 @@ class CompilationUnit(object):
266
289
  The buffer should contain NVVM module IR either in the bitcode
267
290
  representation or in the text representation.
268
291
  """
269
- err = self.driver.nvvmLazyAddModuleToProgram(self._handle, buffer,
270
- len(buffer), None)
271
- self.driver.check_error(err, 'Failed to add module')
292
+ err = self.driver.nvvmLazyAddModuleToProgram(
293
+ self._handle, buffer, len(buffer), None
294
+ )
295
+ self.driver.check_error(err, "Failed to add module")
272
296
 
273
297
  def verify(self):
274
298
  """
275
299
  Run the NVVM verifier on all code added to the compilation unit.
276
300
  """
277
- err = self.driver.nvvmVerifyProgram(self._handle, self.n_options,
278
- self.option_ptrs)
279
- self._try_error(err, 'Failed to verify\n')
301
+ err = self.driver.nvvmVerifyProgram(
302
+ self._handle, self.n_options, self.option_ptrs
303
+ )
304
+ self._try_error(err, "Failed to verify\n")
280
305
 
281
306
  def compile(self):
282
307
  """
283
308
  Compile all modules added to the compilation unit and return the
284
309
  resulting PTX or LTO-IR (depending on the options).
285
310
  """
286
- err = self.driver.nvvmCompileProgram(self._handle, self.n_options,
287
- self.option_ptrs)
288
- self._try_error(err, 'Failed to compile\n')
311
+ err = self.driver.nvvmCompileProgram(
312
+ self._handle, self.n_options, self.option_ptrs
313
+ )
314
+ self._try_error(err, "Failed to compile\n")
289
315
 
290
316
  # Get result
291
317
  result_size = c_size_t()
292
- err = self.driver.nvvmGetCompiledResultSize(self._handle,
293
- byref(result_size))
318
+ err = self.driver.nvvmGetCompiledResultSize(
319
+ self._handle, byref(result_size)
320
+ )
294
321
 
295
- self._try_error(err, 'Failed to get size of compiled result.')
322
+ self._try_error(err, "Failed to get size of compiled result.")
296
323
 
297
324
  output_buffer = (c_char * result_size.value)()
298
325
  err = self.driver.nvvmGetCompiledResult(self._handle, output_buffer)
299
- self._try_error(err, 'Failed to get compiled result.')
326
+ self._try_error(err, "Failed to get compiled result.")
300
327
 
301
328
  # Get log
302
329
  self.log = self.get_log()
@@ -311,31 +338,44 @@ class CompilationUnit(object):
311
338
  def get_log(self):
312
339
  reslen = c_size_t()
313
340
  err = self.driver.nvvmGetProgramLogSize(self._handle, byref(reslen))
314
- self.driver.check_error(err, 'Failed to get compilation log size.')
341
+ self.driver.check_error(err, "Failed to get compilation log size.")
315
342
 
316
343
  if reslen.value > 1:
317
344
  logbuf = (c_char * reslen.value)()
318
345
  err = self.driver.nvvmGetProgramLog(self._handle, logbuf)
319
- self.driver.check_error(err, 'Failed to get compilation log.')
346
+ self.driver.check_error(err, "Failed to get compilation log.")
320
347
 
321
- return logbuf.value.decode('utf8') # populate log attribute
348
+ return logbuf.value.decode("utf8") # populate log attribute
322
349
 
323
- return ''
350
+ return ""
324
351
 
325
352
 
326
353
  COMPUTE_CAPABILITIES = (
327
- (3, 5), (3, 7),
328
- (5, 0), (5, 2), (5, 3),
329
- (6, 0), (6, 1), (6, 2),
330
- (7, 0), (7, 2), (7, 5),
331
- (8, 0), (8, 6), (8, 7), (8, 9),
354
+ (3, 5),
355
+ (3, 7),
356
+ (5, 0),
357
+ (5, 2),
358
+ (5, 3),
359
+ (6, 0),
360
+ (6, 1),
361
+ (6, 2),
362
+ (7, 0),
363
+ (7, 2),
364
+ (7, 5),
365
+ (8, 0),
366
+ (8, 6),
367
+ (8, 7),
368
+ (8, 9),
332
369
  (9, 0),
333
- (10, 0), (10, 1),
370
+ (10, 0),
371
+ (10, 1),
334
372
  (12, 0),
335
373
  )
336
374
 
375
+
337
376
  # Maps CTK version -> (min supported cc, max supported cc) inclusive
338
- CTK_SUPPORTED = {
377
+ _CUDA_CC_MIN_MAX_SUPPORT = {
378
+ (11, 1): ((3, 5), (8, 0)),
339
379
  (11, 2): ((3, 5), (8, 6)),
340
380
  (11, 3): ((3, 5), (8, 6)),
341
381
  (11, 4): ((3, 5), (8, 7)),
@@ -357,34 +397,43 @@ CTK_SUPPORTED = {
357
397
  def ccs_supported_by_ctk(ctk_version):
358
398
  try:
359
399
  # For supported versions, we look up the range of supported CCs
360
- min_cc, max_cc = CTK_SUPPORTED[ctk_version]
361
- return tuple([cc for cc in COMPUTE_CAPABILITIES
362
- if min_cc <= cc <= max_cc])
400
+ min_cc, max_cc = _CUDA_CC_MIN_MAX_SUPPORT[ctk_version]
401
+ return tuple(
402
+ [cc for cc in COMPUTE_CAPABILITIES if min_cc <= cc <= max_cc]
403
+ )
363
404
  except KeyError:
364
405
  # For unsupported CUDA toolkit versions, all we can do is assume all
365
406
  # non-deprecated versions we are aware of are supported.
366
- return tuple([cc for cc in COMPUTE_CAPABILITIES
367
- if cc >= config.CUDA_DEFAULT_PTX_CC])
407
+ return tuple(
408
+ [
409
+ cc
410
+ for cc in COMPUTE_CAPABILITIES
411
+ if cc >= config.CUDA_DEFAULT_PTX_CC
412
+ ]
413
+ )
368
414
 
369
415
 
370
416
  def get_supported_ccs():
371
417
  try:
372
418
  from numba.cuda.cudadrv.runtime import runtime
419
+
373
420
  cudart_version = runtime.get_version()
374
- except: # noqa: E722
421
+ except: # noqa: E722
375
422
  # We can't support anything if there's an error getting the runtime
376
423
  # version (e.g. if it's not present or there's another issue)
377
424
  _supported_cc = ()
378
425
  return _supported_cc
379
426
 
380
427
  # Ensure the minimum CTK version requirement is met
381
- min_cudart = min(CTK_SUPPORTED)
428
+ min_cudart = min(_CUDA_CC_MIN_MAX_SUPPORT)
382
429
  if cudart_version < min_cudart:
383
430
  _supported_cc = ()
384
431
  ctk_ver = f"{cudart_version[0]}.{cudart_version[1]}"
385
- unsupported_ver = (f"CUDA Toolkit {ctk_ver} is unsupported by Numba - "
386
- f"{min_cudart[0]}.{min_cudart[1]} is the minimum "
387
- "required version.")
432
+ unsupported_ver = (
433
+ f"CUDA Toolkit {ctk_ver} is unsupported by Numba - "
434
+ f"{min_cudart[0]}.{min_cudart[1]} is the minimum "
435
+ "required version."
436
+ )
388
437
  warnings.warn(unsupported_ver)
389
438
  return _supported_cc
390
439
 
@@ -403,8 +452,10 @@ def find_closest_arch(mycc):
403
452
  supported_ccs = NVVM().supported_ccs
404
453
 
405
454
  if not supported_ccs:
406
- msg = "No supported GPU compute capabilities found. " \
407
- "Please check your cudatoolkit version matches your CUDA version."
455
+ msg = (
456
+ "No supported GPU compute capabilities found. "
457
+ "Please check your cudatoolkit version matches your CUDA version."
458
+ )
408
459
  raise NvvmSupportError(msg)
409
460
 
410
461
  for i, cc in enumerate(supported_ccs):
@@ -415,8 +466,10 @@ def find_closest_arch(mycc):
415
466
  # Exceeded
416
467
  if i == 0:
417
468
  # CC lower than supported
418
- msg = "GPU compute capability %d.%d is not supported" \
419
- "(requires >=%d.%d)" % (mycc + cc)
469
+ msg = (
470
+ "GPU compute capability %d.%d is not supported"
471
+ "(requires >=%d.%d)" % (mycc + cc)
472
+ )
420
473
  raise NvvmSupportError(msg)
421
474
  else:
422
475
  # return the previous CC
@@ -427,16 +480,15 @@ def find_closest_arch(mycc):
427
480
 
428
481
 
429
482
  def get_arch_option(major, minor):
430
- """Matches with the closest architecture option
431
- """
483
+ """Matches with the closest architecture option"""
432
484
  if config.FORCE_CUDA_CC:
433
485
  arch = config.FORCE_CUDA_CC
434
486
  else:
435
487
  arch = find_closest_arch((major, minor))
436
- return 'compute_%d%d' % arch
488
+ return "compute_%d%d" % arch
437
489
 
438
490
 
439
- MISSING_LIBDEVICE_FILE_MSG = '''Missing libdevice file.
491
+ MISSING_LIBDEVICE_FILE_MSG = """Missing libdevice file.
440
492
  Please ensure you have a CUDA Toolkit 11.2 or higher.
441
493
  For CUDA 12, ``cuda-nvcc`` and ``cuda-nvrtc`` are required:
442
494
 
@@ -445,7 +497,7 @@ For CUDA 12, ``cuda-nvcc`` and ``cuda-nvrtc`` are required:
445
497
  For CUDA 11, ``cudatoolkit`` is required:
446
498
 
447
499
  $ conda install -c conda-forge cudatoolkit "cuda-version>=11.2,<12.0"
448
- '''
500
+ """
449
501
 
450
502
 
451
503
  class LibDevice(object):
@@ -466,7 +518,7 @@ class LibDevice(object):
466
518
  cas_nvvm = """
467
519
  %cas_success = cmpxchg volatile {Ti}* %iptr, {Ti} %old, {Ti} %new monotonic monotonic
468
520
  %cas = extractvalue {{ {Ti}, i1 }} %cas_success, 0
469
- """ # noqa: E501
521
+ """ # noqa: E501
470
522
 
471
523
 
472
524
  # Translation of code from CUDA Programming Guide v6.5, section B.12
@@ -490,7 +542,7 @@ done:
490
542
  %result = bitcast {Ti} %old to {T}
491
543
  ret {T} %result
492
544
  }}
493
- """ # noqa: E501
545
+ """ # noqa: E501
494
546
 
495
547
  ir_numba_atomic_inc_template = """
496
548
  define internal {T} @___numba_atomic_{Tu}_inc({T}* %iptr, {T} %val) alwaysinline {{
@@ -510,7 +562,7 @@ attempt:
510
562
  done:
511
563
  ret {T} %old
512
564
  }}
513
- """ # noqa: E501
565
+ """ # noqa: E501
514
566
 
515
567
  ir_numba_atomic_dec_template = """
516
568
  define internal {T} @___numba_atomic_{Tu}_dec({T}* %iptr, {T} %val) alwaysinline {{
@@ -530,7 +582,7 @@ attempt:
530
582
  done:
531
583
  ret {T} %old
532
584
  }}
533
- """ # noqa: E501
585
+ """ # noqa: E501
534
586
 
535
587
  ir_numba_atomic_minmax_template = """
536
588
  define internal {T} @___numba_atomic_{T}_{NAN}{FUNC}({T}* %ptr, {T} %val) alwaysinline {{
@@ -561,7 +613,7 @@ attempt:
561
613
  done:
562
614
  ret {T} %ptrval
563
615
  }}
564
- """ # noqa: E501
616
+ """ # noqa: E501
565
617
 
566
618
 
567
619
  def ir_cas(Ti):
@@ -574,8 +626,15 @@ def ir_numba_atomic_binary(T, Ti, OP, FUNC):
574
626
 
575
627
 
576
628
  def ir_numba_atomic_minmax(T, Ti, NAN, OP, PTR_OR_VAL, FUNC):
577
- params = dict(T=T, Ti=Ti, NAN=NAN, OP=OP, PTR_OR_VAL=PTR_OR_VAL,
578
- FUNC=FUNC, CAS=ir_cas(Ti))
629
+ params = dict(
630
+ T=T,
631
+ Ti=Ti,
632
+ NAN=NAN,
633
+ OP=OP,
634
+ PTR_OR_VAL=PTR_OR_VAL,
635
+ FUNC=FUNC,
636
+ CAS=ir_cas(Ti),
637
+ )
579
638
 
580
639
  return ir_numba_atomic_minmax_template.format(**params)
581
640
 
@@ -590,41 +649,115 @@ def ir_numba_atomic_dec(T, Tu):
590
649
 
591
650
  def llvm_replace(llvmir):
592
651
  replacements = [
593
- ('declare double @"___numba_atomic_double_add"(double* %".1", double %".2")', # noqa: E501
594
- ir_numba_atomic_binary(T='double', Ti='i64', OP='fadd', FUNC='add')),
595
- ('declare float @"___numba_atomic_float_sub"(float* %".1", float %".2")', # noqa: E501
596
- ir_numba_atomic_binary(T='float', Ti='i32', OP='fsub', FUNC='sub')),
597
- ('declare double @"___numba_atomic_double_sub"(double* %".1", double %".2")', # noqa: E501
598
- ir_numba_atomic_binary(T='double', Ti='i64', OP='fsub', FUNC='sub')),
599
- ('declare i64 @"___numba_atomic_u64_inc"(i64* %".1", i64 %".2")',
600
- ir_numba_atomic_inc(T='i64', Tu='u64')),
601
- ('declare i64 @"___numba_atomic_u64_dec"(i64* %".1", i64 %".2")',
602
- ir_numba_atomic_dec(T='i64', Tu='u64')),
603
- ('declare float @"___numba_atomic_float_max"(float* %".1", float %".2")', # noqa: E501
604
- ir_numba_atomic_minmax(T='float', Ti='i32', NAN='', OP='nnan olt',
605
- PTR_OR_VAL='ptr', FUNC='max')),
606
- ('declare double @"___numba_atomic_double_max"(double* %".1", double %".2")', # noqa: E501
607
- ir_numba_atomic_minmax(T='double', Ti='i64', NAN='', OP='nnan olt',
608
- PTR_OR_VAL='ptr', FUNC='max')),
609
- ('declare float @"___numba_atomic_float_min"(float* %".1", float %".2")', # noqa: E501
610
- ir_numba_atomic_minmax(T='float', Ti='i32', NAN='', OP='nnan ogt',
611
- PTR_OR_VAL='ptr', FUNC='min')),
612
- ('declare double @"___numba_atomic_double_min"(double* %".1", double %".2")', # noqa: E501
613
- ir_numba_atomic_minmax(T='double', Ti='i64', NAN='', OP='nnan ogt',
614
- PTR_OR_VAL='ptr', FUNC='min')),
615
- ('declare float @"___numba_atomic_float_nanmax"(float* %".1", float %".2")', # noqa: E501
616
- ir_numba_atomic_minmax(T='float', Ti='i32', NAN='nan', OP='ult',
617
- PTR_OR_VAL='', FUNC='max')),
618
- ('declare double @"___numba_atomic_double_nanmax"(double* %".1", double %".2")', # noqa: E501
619
- ir_numba_atomic_minmax(T='double', Ti='i64', NAN='nan', OP='ult',
620
- PTR_OR_VAL='', FUNC='max')),
621
- ('declare float @"___numba_atomic_float_nanmin"(float* %".1", float %".2")', # noqa: E501
622
- ir_numba_atomic_minmax(T='float', Ti='i32', NAN='nan', OP='ugt',
623
- PTR_OR_VAL='', FUNC='min')),
624
- ('declare double @"___numba_atomic_double_nanmin"(double* %".1", double %".2")', # noqa: E501
625
- ir_numba_atomic_minmax(T='double', Ti='i64', NAN='nan', OP='ugt',
626
- PTR_OR_VAL='', FUNC='min')),
627
- ('immarg', '')
652
+ (
653
+ 'declare double @"___numba_atomic_double_add"(double* %".1", double %".2")', # noqa: E501
654
+ ir_numba_atomic_binary(T="double", Ti="i64", OP="fadd", FUNC="add"),
655
+ ),
656
+ (
657
+ 'declare float @"___numba_atomic_float_sub"(float* %".1", float %".2")', # noqa: E501
658
+ ir_numba_atomic_binary(T="float", Ti="i32", OP="fsub", FUNC="sub"),
659
+ ),
660
+ (
661
+ 'declare double @"___numba_atomic_double_sub"(double* %".1", double %".2")', # noqa: E501
662
+ ir_numba_atomic_binary(T="double", Ti="i64", OP="fsub", FUNC="sub"),
663
+ ),
664
+ (
665
+ 'declare i64 @"___numba_atomic_u64_inc"(i64* %".1", i64 %".2")',
666
+ ir_numba_atomic_inc(T="i64", Tu="u64"),
667
+ ),
668
+ (
669
+ 'declare i64 @"___numba_atomic_u64_dec"(i64* %".1", i64 %".2")',
670
+ ir_numba_atomic_dec(T="i64", Tu="u64"),
671
+ ),
672
+ (
673
+ 'declare float @"___numba_atomic_float_max"(float* %".1", float %".2")', # noqa: E501
674
+ ir_numba_atomic_minmax(
675
+ T="float",
676
+ Ti="i32",
677
+ NAN="",
678
+ OP="nnan olt",
679
+ PTR_OR_VAL="ptr",
680
+ FUNC="max",
681
+ ),
682
+ ),
683
+ (
684
+ 'declare double @"___numba_atomic_double_max"(double* %".1", double %".2")', # noqa: E501
685
+ ir_numba_atomic_minmax(
686
+ T="double",
687
+ Ti="i64",
688
+ NAN="",
689
+ OP="nnan olt",
690
+ PTR_OR_VAL="ptr",
691
+ FUNC="max",
692
+ ),
693
+ ),
694
+ (
695
+ 'declare float @"___numba_atomic_float_min"(float* %".1", float %".2")', # noqa: E501
696
+ ir_numba_atomic_minmax(
697
+ T="float",
698
+ Ti="i32",
699
+ NAN="",
700
+ OP="nnan ogt",
701
+ PTR_OR_VAL="ptr",
702
+ FUNC="min",
703
+ ),
704
+ ),
705
+ (
706
+ 'declare double @"___numba_atomic_double_min"(double* %".1", double %".2")', # noqa: E501
707
+ ir_numba_atomic_minmax(
708
+ T="double",
709
+ Ti="i64",
710
+ NAN="",
711
+ OP="nnan ogt",
712
+ PTR_OR_VAL="ptr",
713
+ FUNC="min",
714
+ ),
715
+ ),
716
+ (
717
+ 'declare float @"___numba_atomic_float_nanmax"(float* %".1", float %".2")', # noqa: E501
718
+ ir_numba_atomic_minmax(
719
+ T="float",
720
+ Ti="i32",
721
+ NAN="nan",
722
+ OP="ult",
723
+ PTR_OR_VAL="",
724
+ FUNC="max",
725
+ ),
726
+ ),
727
+ (
728
+ 'declare double @"___numba_atomic_double_nanmax"(double* %".1", double %".2")', # noqa: E501
729
+ ir_numba_atomic_minmax(
730
+ T="double",
731
+ Ti="i64",
732
+ NAN="nan",
733
+ OP="ult",
734
+ PTR_OR_VAL="",
735
+ FUNC="max",
736
+ ),
737
+ ),
738
+ (
739
+ 'declare float @"___numba_atomic_float_nanmin"(float* %".1", float %".2")', # noqa: E501
740
+ ir_numba_atomic_minmax(
741
+ T="float",
742
+ Ti="i32",
743
+ NAN="nan",
744
+ OP="ugt",
745
+ PTR_OR_VAL="",
746
+ FUNC="min",
747
+ ),
748
+ ),
749
+ (
750
+ 'declare double @"___numba_atomic_double_nanmin"(double* %".1", double %".2")', # noqa: E501
751
+ ir_numba_atomic_minmax(
752
+ T="double",
753
+ Ti="i64",
754
+ NAN="nan",
755
+ OP="ugt",
756
+ PTR_OR_VAL="",
757
+ FUNC="min",
758
+ ),
759
+ ),
760
+ ("immarg", ""),
628
761
  ]
629
762
 
630
763
  for decl, fn in replacements:
@@ -639,19 +772,21 @@ def compile_ir(llvmir, **options):
639
772
  if isinstance(llvmir, str):
640
773
  llvmir = [llvmir]
641
774
 
642
- if options.pop('fastmath', False):
643
- options.update({
644
- 'ftz': True,
645
- 'fma': True,
646
- 'prec_div': False,
647
- 'prec_sqrt': False,
648
- })
775
+ if options.pop("fastmath", False):
776
+ options.update(
777
+ {
778
+ "ftz": True,
779
+ "fma": True,
780
+ "prec_div": False,
781
+ "prec_sqrt": False,
782
+ }
783
+ )
649
784
 
650
785
  cu = CompilationUnit(options)
651
786
 
652
787
  for mod in llvmir:
653
788
  mod = llvm_replace(mod)
654
- cu.add_module(mod.encode('utf8'))
789
+ cu.add_module(mod.encode("utf8"))
655
790
  cu.verify()
656
791
 
657
792
  # We add libdevice following verification so that it is not subject to the
@@ -671,16 +806,16 @@ def llvm150_to_70_ir(ir):
671
806
  """
672
807
  buf = []
673
808
  for line in ir.splitlines():
674
- if line.startswith('attributes #'):
809
+ if line.startswith("attributes #"):
675
810
  # Remove function attributes unsupported by LLVM 7.0
676
811
  m = re_attributes_def.match(line)
677
812
  attrs = m.group(1).split()
678
- attrs = ' '.join(a for a in attrs if a != 'willreturn')
813
+ attrs = " ".join(a for a in attrs if a != "willreturn")
679
814
  line = line.replace(m.group(1), attrs)
680
815
 
681
816
  buf.append(line)
682
817
 
683
- return '\n'.join(buf)
818
+ return "\n".join(buf)
684
819
 
685
820
 
686
821
  def set_cuda_kernel(function):
@@ -704,7 +839,7 @@ def set_cuda_kernel(function):
704
839
  mdvalue = ir.Constant(ir.IntType(32), 1)
705
840
  md = module.add_metadata((function, mdstr, mdvalue))
706
841
 
707
- nmd = cgutils.get_or_insert_named_metadata(module, 'nvvm.annotations')
842
+ nmd = cgutils.get_or_insert_named_metadata(module, "nvvm.annotations")
708
843
  nmd.add(md)
709
844
 
710
845
  # Create the used list
@@ -713,13 +848,13 @@ def set_cuda_kernel(function):
713
848
 
714
849
  fnptr = function.bitcast(ptrty)
715
850
 
716
- llvm_used = ir.GlobalVariable(module, usedty, 'llvm.used')
717
- llvm_used.linkage = 'appending'
718
- llvm_used.section = 'llvm.metadata'
851
+ llvm_used = ir.GlobalVariable(module, usedty, "llvm.used")
852
+ llvm_used.linkage = "appending"
853
+ llvm_used.section = "llvm.metadata"
719
854
  llvm_used.initializer = ir.Constant(usedty, [fnptr])
720
855
 
721
856
  # Remove 'noinline' if it is present.
722
- function.attributes.discard('noinline')
857
+ function.attributes.discard("noinline")
723
858
 
724
859
 
725
860
  def add_ir_version(mod):
@@ -728,4 +863,4 @@ def add_ir_version(mod):
728
863
  i32 = ir.IntType(32)
729
864
  ir_versions = [i32(v) for v in NVVM().get_ir_version()]
730
865
  md_ver = mod.add_metadata(ir_versions)
731
- mod.add_named_metadata('nvvmir.version', md_ver)
866
+ mod.add_named_metadata("nvvmir.version", md_ver)