numba-cuda 0.8.1__py3-none-any.whl → 0.9.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (227)
  1. _numba_cuda_redirector.py +17 -13
  2. numba_cuda/VERSION +1 -1
  3. numba_cuda/_version.py +4 -1
  4. numba_cuda/numba/cuda/__init__.py +6 -2
  5. numba_cuda/numba/cuda/api.py +129 -86
  6. numba_cuda/numba/cuda/api_util.py +3 -3
  7. numba_cuda/numba/cuda/args.py +12 -16
  8. numba_cuda/numba/cuda/cg.py +6 -6
  9. numba_cuda/numba/cuda/codegen.py +74 -43
  10. numba_cuda/numba/cuda/compiler.py +232 -113
  11. numba_cuda/numba/cuda/cpp_function_wrappers.cu +1 -2
  12. numba_cuda/numba/cuda/cuda_fp16.h +661 -661
  13. numba_cuda/numba/cuda/cuda_fp16.hpp +3 -3
  14. numba_cuda/numba/cuda/cuda_paths.py +291 -99
  15. numba_cuda/numba/cuda/cudadecl.py +125 -69
  16. numba_cuda/numba/cuda/cudadrv/__init__.py +3 -1
  17. numba_cuda/numba/cuda/cudadrv/devicearray.py +185 -135
  18. numba_cuda/numba/cuda/cudadrv/devices.py +16 -11
  19. numba_cuda/numba/cuda/cudadrv/driver.py +460 -297
  20. numba_cuda/numba/cuda/cudadrv/drvapi.py +241 -207
  21. numba_cuda/numba/cuda/cudadrv/dummyarray.py +66 -54
  22. numba_cuda/numba/cuda/cudadrv/enums.py +1 -1
  23. numba_cuda/numba/cuda/cudadrv/error.py +6 -2
  24. numba_cuda/numba/cuda/cudadrv/libs.py +67 -63
  25. numba_cuda/numba/cuda/cudadrv/linkable_code.py +16 -1
  26. numba_cuda/numba/cuda/cudadrv/mappings.py +16 -14
  27. numba_cuda/numba/cuda/cudadrv/nvrtc.py +138 -29
  28. numba_cuda/numba/cuda/cudadrv/nvvm.py +296 -161
  29. numba_cuda/numba/cuda/cudadrv/rtapi.py +1 -1
  30. numba_cuda/numba/cuda/cudadrv/runtime.py +20 -8
  31. numba_cuda/numba/cuda/cudaimpl.py +317 -233
  32. numba_cuda/numba/cuda/cudamath.py +1 -1
  33. numba_cuda/numba/cuda/debuginfo.py +8 -6
  34. numba_cuda/numba/cuda/decorators.py +75 -45
  35. numba_cuda/numba/cuda/descriptor.py +1 -1
  36. numba_cuda/numba/cuda/device_init.py +69 -18
  37. numba_cuda/numba/cuda/deviceufunc.py +143 -98
  38. numba_cuda/numba/cuda/dispatcher.py +300 -213
  39. numba_cuda/numba/cuda/errors.py +13 -10
  40. numba_cuda/numba/cuda/extending.py +1 -1
  41. numba_cuda/numba/cuda/initialize.py +5 -3
  42. numba_cuda/numba/cuda/intrinsic_wrapper.py +3 -3
  43. numba_cuda/numba/cuda/intrinsics.py +31 -27
  44. numba_cuda/numba/cuda/kernels/reduction.py +13 -13
  45. numba_cuda/numba/cuda/kernels/transpose.py +3 -6
  46. numba_cuda/numba/cuda/libdevice.py +317 -317
  47. numba_cuda/numba/cuda/libdeviceimpl.py +3 -2
  48. numba_cuda/numba/cuda/locks.py +16 -0
  49. numba_cuda/numba/cuda/mathimpl.py +62 -57
  50. numba_cuda/numba/cuda/models.py +1 -5
  51. numba_cuda/numba/cuda/nvvmutils.py +103 -88
  52. numba_cuda/numba/cuda/printimpl.py +9 -5
  53. numba_cuda/numba/cuda/random.py +46 -36
  54. numba_cuda/numba/cuda/reshape_funcs.cu +1 -1
  55. numba_cuda/numba/cuda/runtime/__init__.py +1 -1
  56. numba_cuda/numba/cuda/runtime/memsys.cu +1 -1
  57. numba_cuda/numba/cuda/runtime/memsys.cuh +1 -1
  58. numba_cuda/numba/cuda/runtime/nrt.cu +3 -3
  59. numba_cuda/numba/cuda/runtime/nrt.py +48 -43
  60. numba_cuda/numba/cuda/simulator/__init__.py +22 -12
  61. numba_cuda/numba/cuda/simulator/api.py +38 -22
  62. numba_cuda/numba/cuda/simulator/compiler.py +2 -2
  63. numba_cuda/numba/cuda/simulator/cudadrv/__init__.py +8 -2
  64. numba_cuda/numba/cuda/simulator/cudadrv/devicearray.py +63 -55
  65. numba_cuda/numba/cuda/simulator/cudadrv/devices.py +13 -11
  66. numba_cuda/numba/cuda/simulator/cudadrv/driver.py +5 -5
  67. numba_cuda/numba/cuda/simulator/cudadrv/drvapi.py +2 -2
  68. numba_cuda/numba/cuda/simulator/cudadrv/libs.py +1 -1
  69. numba_cuda/numba/cuda/simulator/cudadrv/nvvm.py +3 -3
  70. numba_cuda/numba/cuda/simulator/cudadrv/runtime.py +3 -3
  71. numba_cuda/numba/cuda/simulator/kernel.py +43 -34
  72. numba_cuda/numba/cuda/simulator/kernelapi.py +31 -26
  73. numba_cuda/numba/cuda/simulator/reduction.py +1 -0
  74. numba_cuda/numba/cuda/simulator/vector_types.py +13 -9
  75. numba_cuda/numba/cuda/simulator_init.py +2 -4
  76. numba_cuda/numba/cuda/stubs.py +139 -102
  77. numba_cuda/numba/cuda/target.py +64 -47
  78. numba_cuda/numba/cuda/testing.py +24 -19
  79. numba_cuda/numba/cuda/tests/__init__.py +14 -12
  80. numba_cuda/numba/cuda/tests/cudadrv/test_array_attr.py +16 -17
  81. numba_cuda/numba/cuda/tests/cudadrv/test_context_stack.py +7 -7
  82. numba_cuda/numba/cuda/tests/cudadrv/test_cuda_array_slicing.py +73 -54
  83. numba_cuda/numba/cuda/tests/cudadrv/test_cuda_auto_context.py +1 -1
  84. numba_cuda/numba/cuda/tests/cudadrv/test_cuda_devicerecord.py +48 -50
  85. numba_cuda/numba/cuda/tests/cudadrv/test_cuda_driver.py +47 -29
  86. numba_cuda/numba/cuda/tests/cudadrv/test_cuda_libraries.py +3 -3
  87. numba_cuda/numba/cuda/tests/cudadrv/test_cuda_memory.py +19 -19
  88. numba_cuda/numba/cuda/tests/cudadrv/test_cuda_ndarray.py +108 -103
  89. numba_cuda/numba/cuda/tests/cudadrv/test_deallocations.py +20 -11
  90. numba_cuda/numba/cuda/tests/cudadrv/test_detect.py +20 -17
  91. numba_cuda/numba/cuda/tests/cudadrv/test_emm_plugins.py +8 -6
  92. numba_cuda/numba/cuda/tests/cudadrv/test_events.py +1 -1
  93. numba_cuda/numba/cuda/tests/cudadrv/test_host_alloc.py +8 -7
  94. numba_cuda/numba/cuda/tests/cudadrv/test_init.py +13 -13
  95. numba_cuda/numba/cuda/tests/cudadrv/test_inline_ptx.py +12 -9
  96. numba_cuda/numba/cuda/tests/cudadrv/test_linker.py +36 -31
  97. numba_cuda/numba/cuda/tests/cudadrv/test_managed_alloc.py +8 -7
  98. numba_cuda/numba/cuda/tests/cudadrv/test_module_callbacks.py +294 -0
  99. numba_cuda/numba/cuda/tests/cudadrv/test_mvc.py +10 -7
  100. numba_cuda/numba/cuda/tests/cudadrv/test_nvjitlink.py +24 -15
  101. numba_cuda/numba/cuda/tests/cudadrv/test_nvvm_driver.py +43 -41
  102. numba_cuda/numba/cuda/tests/cudadrv/test_pinned.py +4 -5
  103. numba_cuda/numba/cuda/tests/cudadrv/test_profiler.py +2 -2
  104. numba_cuda/numba/cuda/tests/cudadrv/test_ptds.py +28 -17
  105. numba_cuda/numba/cuda/tests/cudadrv/test_reset_device.py +1 -2
  106. numba_cuda/numba/cuda/tests/cudadrv/test_runtime.py +22 -14
  107. numba_cuda/numba/cuda/tests/cudadrv/test_select_device.py +1 -1
  108. numba_cuda/numba/cuda/tests/cudadrv/test_streams.py +4 -3
  109. numba_cuda/numba/cuda/tests/cudapy/cache_usecases.py +10 -4
  110. numba_cuda/numba/cuda/tests/cudapy/cache_with_cpu_usecases.py +1 -0
  111. numba_cuda/numba/cuda/tests/cudapy/extensions_usecases.py +7 -6
  112. numba_cuda/numba/cuda/tests/cudapy/jitlink.ptx +0 -2
  113. numba_cuda/numba/cuda/tests/cudapy/recursion_usecases.py +1 -0
  114. numba_cuda/numba/cuda/tests/cudapy/test_alignment.py +6 -5
  115. numba_cuda/numba/cuda/tests/cudapy/test_array.py +52 -42
  116. numba_cuda/numba/cuda/tests/cudapy/test_array_args.py +5 -6
  117. numba_cuda/numba/cuda/tests/cudapy/test_array_methods.py +1 -1
  118. numba_cuda/numba/cuda/tests/cudapy/test_atomics.py +501 -304
  119. numba_cuda/numba/cuda/tests/cudapy/test_blackscholes.py +57 -21
  120. numba_cuda/numba/cuda/tests/cudapy/test_boolean.py +3 -3
  121. numba_cuda/numba/cuda/tests/cudapy/test_caching.py +50 -37
  122. numba_cuda/numba/cuda/tests/cudapy/test_casting.py +29 -24
  123. numba_cuda/numba/cuda/tests/cudapy/test_cffi.py +11 -6
  124. numba_cuda/numba/cuda/tests/cudapy/test_compiler.py +84 -50
  125. numba_cuda/numba/cuda/tests/cudapy/test_complex.py +144 -73
  126. numba_cuda/numba/cuda/tests/cudapy/test_complex_kernel.py +2 -2
  127. numba_cuda/numba/cuda/tests/cudapy/test_const_string.py +37 -27
  128. numba_cuda/numba/cuda/tests/cudapy/test_constmem.py +43 -45
  129. numba_cuda/numba/cuda/tests/cudapy/test_cooperative_groups.py +21 -14
  130. numba_cuda/numba/cuda/tests/cudapy/test_cuda_array_interface.py +60 -55
  131. numba_cuda/numba/cuda/tests/cudapy/test_cuda_jit_no_types.py +3 -2
  132. numba_cuda/numba/cuda/tests/cudapy/test_datetime.py +26 -22
  133. numba_cuda/numba/cuda/tests/cudapy/test_debug.py +29 -27
  134. numba_cuda/numba/cuda/tests/cudapy/test_debuginfo.py +31 -28
  135. numba_cuda/numba/cuda/tests/cudapy/test_device_func.py +52 -45
  136. numba_cuda/numba/cuda/tests/cudapy/test_dispatcher.py +55 -43
  137. numba_cuda/numba/cuda/tests/cudapy/test_enums.py +6 -7
  138. numba_cuda/numba/cuda/tests/cudapy/test_errors.py +30 -15
  139. numba_cuda/numba/cuda/tests/cudapy/test_exception.py +11 -12
  140. numba_cuda/numba/cuda/tests/cudapy/test_extending.py +19 -12
  141. numba_cuda/numba/cuda/tests/cudapy/test_fastmath.py +77 -66
  142. numba_cuda/numba/cuda/tests/cudapy/test_forall.py +5 -3
  143. numba_cuda/numba/cuda/tests/cudapy/test_freevar.py +5 -3
  144. numba_cuda/numba/cuda/tests/cudapy/test_frexp_ldexp.py +1 -1
  145. numba_cuda/numba/cuda/tests/cudapy/test_globals.py +3 -5
  146. numba_cuda/numba/cuda/tests/cudapy/test_gufunc.py +144 -126
  147. numba_cuda/numba/cuda/tests/cudapy/test_gufunc_scalar.py +23 -18
  148. numba_cuda/numba/cuda/tests/cudapy/test_gufunc_scheduling.py +16 -22
  149. numba_cuda/numba/cuda/tests/cudapy/test_idiv.py +1 -3
  150. numba_cuda/numba/cuda/tests/cudapy/test_inspect.py +29 -20
  151. numba_cuda/numba/cuda/tests/cudapy/test_intrinsics.py +147 -99
  152. numba_cuda/numba/cuda/tests/cudapy/test_ipc.py +50 -36
  153. numba_cuda/numba/cuda/tests/cudapy/test_iterators.py +1 -2
  154. numba_cuda/numba/cuda/tests/cudapy/test_lang.py +4 -4
  155. numba_cuda/numba/cuda/tests/cudapy/test_laplace.py +6 -6
  156. numba_cuda/numba/cuda/tests/cudapy/test_libdevice.py +24 -20
  157. numba_cuda/numba/cuda/tests/cudapy/test_lineinfo.py +36 -31
  158. numba_cuda/numba/cuda/tests/cudapy/test_localmem.py +13 -13
  159. numba_cuda/numba/cuda/tests/cudapy/test_mandel.py +13 -6
  160. numba_cuda/numba/cuda/tests/cudapy/test_math.py +83 -66
  161. numba_cuda/numba/cuda/tests/cudapy/test_matmul.py +1 -3
  162. numba_cuda/numba/cuda/tests/cudapy/test_minmax.py +19 -58
  163. numba_cuda/numba/cuda/tests/cudapy/test_montecarlo.py +4 -4
  164. numba_cuda/numba/cuda/tests/cudapy/test_multigpu.py +9 -7
  165. numba_cuda/numba/cuda/tests/cudapy/test_multiprocessing.py +9 -8
  166. numba_cuda/numba/cuda/tests/cudapy/test_multithreads.py +12 -10
  167. numba_cuda/numba/cuda/tests/cudapy/test_nondet.py +1 -1
  168. numba_cuda/numba/cuda/tests/cudapy/test_operator.py +180 -96
  169. numba_cuda/numba/cuda/tests/cudapy/test_optimization.py +5 -5
  170. numba_cuda/numba/cuda/tests/cudapy/test_overload.py +37 -18
  171. numba_cuda/numba/cuda/tests/cudapy/test_powi.py +7 -7
  172. numba_cuda/numba/cuda/tests/cudapy/test_print.py +9 -7
  173. numba_cuda/numba/cuda/tests/cudapy/test_py2_div_issue.py +1 -1
  174. numba_cuda/numba/cuda/tests/cudapy/test_random.py +15 -10
  175. numba_cuda/numba/cuda/tests/cudapy/test_record_dtype.py +88 -87
  176. numba_cuda/numba/cuda/tests/cudapy/test_recursion.py +12 -10
  177. numba_cuda/numba/cuda/tests/cudapy/test_reduction.py +26 -11
  178. numba_cuda/numba/cuda/tests/cudapy/test_retrieve_autoconverted_arrays.py +7 -10
  179. numba_cuda/numba/cuda/tests/cudapy/test_serialize.py +4 -6
  180. numba_cuda/numba/cuda/tests/cudapy/test_slicing.py +1 -1
  181. numba_cuda/numba/cuda/tests/cudapy/test_sm.py +10 -9
  182. numba_cuda/numba/cuda/tests/cudapy/test_sm_creation.py +62 -43
  183. numba_cuda/numba/cuda/tests/cudapy/test_stream_api.py +7 -3
  184. numba_cuda/numba/cuda/tests/cudapy/test_sync.py +7 -5
  185. numba_cuda/numba/cuda/tests/cudapy/test_transpose.py +18 -11
  186. numba_cuda/numba/cuda/tests/cudapy/test_ufuncs.py +111 -88
  187. numba_cuda/numba/cuda/tests/cudapy/test_userexc.py +2 -3
  188. numba_cuda/numba/cuda/tests/cudapy/test_vector_type.py +305 -130
  189. numba_cuda/numba/cuda/tests/cudapy/test_vectorize.py +33 -36
  190. numba_cuda/numba/cuda/tests/cudapy/test_vectorize_complex.py +5 -5
  191. numba_cuda/numba/cuda/tests/cudapy/test_vectorize_decor.py +16 -12
  192. numba_cuda/numba/cuda/tests/cudapy/test_vectorize_device.py +7 -7
  193. numba_cuda/numba/cuda/tests/cudapy/test_vectorize_scalar_arg.py +6 -7
  194. numba_cuda/numba/cuda/tests/cudapy/test_warning.py +31 -29
  195. numba_cuda/numba/cuda/tests/cudapy/test_warp_ops.py +31 -25
  196. numba_cuda/numba/cuda/tests/cudasim/test_cudasim_issues.py +19 -13
  197. numba_cuda/numba/cuda/tests/data/jitlink.cu +1 -1
  198. numba_cuda/numba/cuda/tests/data/jitlink.ptx +0 -2
  199. numba_cuda/numba/cuda/tests/doc_examples/test_cg.py +15 -8
  200. numba_cuda/numba/cuda/tests/doc_examples/test_cpu_gpu_compat.py +4 -7
  201. numba_cuda/numba/cuda/tests/doc_examples/test_ffi.py +14 -9
  202. numba_cuda/numba/cuda/tests/doc_examples/test_laplace.py +22 -18
  203. numba_cuda/numba/cuda/tests/doc_examples/test_matmul.py +7 -4
  204. numba_cuda/numba/cuda/tests/doc_examples/test_montecarlo.py +2 -0
  205. numba_cuda/numba/cuda/tests/doc_examples/test_random.py +8 -4
  206. numba_cuda/numba/cuda/tests/doc_examples/test_reduction.py +2 -1
  207. numba_cuda/numba/cuda/tests/doc_examples/test_sessionize.py +94 -19
  208. numba_cuda/numba/cuda/tests/doc_examples/test_vecadd.py +2 -2
  209. numba_cuda/numba/cuda/tests/nocuda/test_dummyarray.py +91 -62
  210. numba_cuda/numba/cuda/tests/nocuda/test_function_resolution.py +14 -5
  211. numba_cuda/numba/cuda/tests/nocuda/test_import.py +25 -25
  212. numba_cuda/numba/cuda/tests/nocuda/test_library_lookup.py +40 -40
  213. numba_cuda/numba/cuda/tests/nocuda/test_nvvm.py +12 -10
  214. numba_cuda/numba/cuda/tests/nrt/test_nrt.py +16 -20
  215. numba_cuda/numba/cuda/tests/nrt/test_nrt_refct.py +12 -10
  216. numba_cuda/numba/cuda/tests/test_binary_generation/generate_raw_ltoir.py +2 -2
  217. numba_cuda/numba/cuda/types.py +5 -2
  218. numba_cuda/numba/cuda/ufuncs.py +382 -362
  219. numba_cuda/numba/cuda/utils.py +2 -2
  220. numba_cuda/numba/cuda/vector_types.py +2 -2
  221. numba_cuda/numba/cuda/vectorizers.py +37 -32
  222. {numba_cuda-0.8.1.dist-info → numba_cuda-0.9.0.dist-info}/METADATA +1 -1
  223. numba_cuda-0.9.0.dist-info/RECORD +253 -0
  224. {numba_cuda-0.8.1.dist-info → numba_cuda-0.9.0.dist-info}/WHEEL +1 -1
  225. numba_cuda-0.8.1.dist-info/RECORD +0 -251
  226. {numba_cuda-0.8.1.dist-info → numba_cuda-0.9.0.dist-info}/licenses/LICENSE +0 -0
  227. {numba_cuda-0.8.1.dist-info → numba_cuda-0.9.0.dist-info}/top_level.txt +0 -0
@@ -1,8 +1,13 @@
1
1
  from ctypes import byref, c_char, c_char_p, c_int, c_size_t, c_void_p, POINTER
2
2
  from enum import IntEnum
3
- from numba.cuda.cudadrv.error import (NvrtcError, NvrtcCompilationError,
4
- NvrtcSupportError)
3
+ from numba.cuda.cudadrv.error import (
4
+ NvrtcError,
5
+ NvrtcBuiltinOperationFailure,
6
+ NvrtcCompilationError,
7
+ NvrtcSupportError,
8
+ )
5
9
  from numba.cuda.cuda_paths import get_cuda_paths
10
+
6
11
  import functools
7
12
  import os
8
13
  import threading
@@ -39,6 +44,7 @@ class NvrtcProgram:
39
44
  the class own an nvrtcProgram; when an instance is deleted, the underlying
40
45
  nvrtcProgram is destroyed using the appropriate NVRTC API.
41
46
  """
47
+
42
48
  def __init__(self, nvrtc, handle):
43
49
  self._nvrtc = nvrtc
44
50
  self._handle = handle
@@ -62,46 +68,67 @@ class NVRTC:
62
68
  (for Numba) open_cudalib function to load the NVRTC library.
63
69
  """
64
70
 
71
+ _CU11_2ONLY_PROTOTYPES = {
72
+ # nvrtcResult nvrtcGetNumSupportedArchs(int *numArchs);
73
+ "nvrtcGetNumSupportedArchs": (nvrtc_result, POINTER(c_int)),
74
+ # nvrtcResult nvrtcGetSupportedArchs(int *supportedArchs);
75
+ "nvrtcGetSupportedArchs": (nvrtc_result, POINTER(c_int)),
76
+ }
77
+
65
78
  _CU12ONLY_PROTOTYPES = {
66
79
  # nvrtcResult nvrtcGetLTOIRSize(nvrtcProgram prog, size_t *ltoSizeRet);
67
80
  "nvrtcGetLTOIRSize": (nvrtc_result, nvrtc_program, POINTER(c_size_t)),
68
81
  # nvrtcResult nvrtcGetLTOIR(nvrtcProgram prog, char *lto);
69
- "nvrtcGetLTOIR": (nvrtc_result, nvrtc_program, c_char_p)
82
+ "nvrtcGetLTOIR": (nvrtc_result, nvrtc_program, c_char_p),
70
83
  }
71
84
 
72
85
  _PROTOTYPES = {
73
86
  # nvrtcResult nvrtcVersion(int *major, int *minor)
74
- 'nvrtcVersion': (nvrtc_result, POINTER(c_int), POINTER(c_int)),
87
+ "nvrtcVersion": (nvrtc_result, POINTER(c_int), POINTER(c_int)),
75
88
  # nvrtcResult nvrtcCreateProgram(nvrtcProgram *prog,
76
89
  # const char *src,
77
90
  # const char *name,
78
91
  # int numHeaders,
79
92
  # const char * const *headers,
80
93
  # const char * const *includeNames)
81
- 'nvrtcCreateProgram': (nvrtc_result, nvrtc_program, c_char_p, c_char_p,
82
- c_int, POINTER(c_char_p), POINTER(c_char_p)),
94
+ "nvrtcCreateProgram": (
95
+ nvrtc_result,
96
+ nvrtc_program,
97
+ c_char_p,
98
+ c_char_p,
99
+ c_int,
100
+ POINTER(c_char_p),
101
+ POINTER(c_char_p),
102
+ ),
83
103
  # nvrtcResult nvrtcDestroyProgram(nvrtcProgram *prog);
84
- 'nvrtcDestroyProgram': (nvrtc_result, POINTER(nvrtc_program)),
104
+ "nvrtcDestroyProgram": (nvrtc_result, POINTER(nvrtc_program)),
85
105
  # nvrtcResult nvrtcCompileProgram(nvrtcProgram prog,
86
106
  # int numOptions,
87
107
  # const char * const *options)
88
- 'nvrtcCompileProgram': (nvrtc_result, nvrtc_program, c_int,
89
- POINTER(c_char_p)),
108
+ "nvrtcCompileProgram": (
109
+ nvrtc_result,
110
+ nvrtc_program,
111
+ c_int,
112
+ POINTER(c_char_p),
113
+ ),
90
114
  # nvrtcResult nvrtcGetPTXSize(nvrtcProgram prog, size_t *ptxSizeRet);
91
- 'nvrtcGetPTXSize': (nvrtc_result, nvrtc_program, POINTER(c_size_t)),
115
+ "nvrtcGetPTXSize": (nvrtc_result, nvrtc_program, POINTER(c_size_t)),
92
116
  # nvrtcResult nvrtcGetPTX(nvrtcProgram prog, char *ptx);
93
- 'nvrtcGetPTX': (nvrtc_result, nvrtc_program, c_char_p),
117
+ "nvrtcGetPTX": (nvrtc_result, nvrtc_program, c_char_p),
94
118
  # nvrtcResult nvrtcGetCUBINSize(nvrtcProgram prog,
95
119
  # size_t *cubinSizeRet);
96
- 'nvrtcGetCUBINSize': (nvrtc_result, nvrtc_program, POINTER(c_size_t)),
120
+ "nvrtcGetCUBINSize": (nvrtc_result, nvrtc_program, POINTER(c_size_t)),
97
121
  # nvrtcResult nvrtcGetCUBIN(nvrtcProgram prog, char *cubin);
98
- 'nvrtcGetCUBIN': (nvrtc_result, nvrtc_program, c_char_p),
122
+ "nvrtcGetCUBIN": (nvrtc_result, nvrtc_program, c_char_p),
99
123
  # nvrtcResult nvrtcGetProgramLogSize(nvrtcProgram prog,
100
124
  # size_t *logSizeRet);
101
- 'nvrtcGetProgramLogSize': (nvrtc_result, nvrtc_program,
102
- POINTER(c_size_t)),
125
+ "nvrtcGetProgramLogSize": (
126
+ nvrtc_result,
127
+ nvrtc_program,
128
+ POINTER(c_size_t),
129
+ ),
103
130
  # nvrtcResult nvrtcGetProgramLog(nvrtcProgram prog, char *log);
104
- 'nvrtcGetProgramLog': (nvrtc_result, nvrtc_program, c_char_p),
131
+ "nvrtcGetProgramLog": (nvrtc_result, nvrtc_program, c_char_p),
105
132
  }
106
133
 
107
134
  # Singleton reference
@@ -111,14 +138,18 @@ class NVRTC:
111
138
  with _nvrtc_lock:
112
139
  if cls.__INSTANCE is None:
113
140
  from numba.cuda.cudadrv.libs import open_cudalib
141
+
114
142
  cls.__INSTANCE = inst = object.__new__(cls)
115
143
  try:
116
- lib = open_cudalib('nvrtc')
144
+ lib = open_cudalib("nvrtc")
117
145
  except OSError as e:
118
146
  cls.__INSTANCE = None
119
147
  raise NvrtcSupportError("NVRTC cannot be loaded") from e
120
148
 
121
149
  from numba.cuda.cudadrv.runtime import get_version
150
+
151
+ if get_version() >= (11, 2):
152
+ inst._PROTOTYPES |= inst._CU11_2ONLY_PROTOTYPES
122
153
  if get_version() >= (12, 0):
123
154
  inst._PROTOTYPES |= inst._CU12ONLY_PROTOTYPES
124
155
 
@@ -133,19 +164,73 @@ class NVRTC:
133
164
  error = func(*args)
134
165
  if error == NvrtcResult.NVRTC_ERROR_COMPILATION:
135
166
  raise NvrtcCompilationError()
167
+ elif (
168
+ error
169
+ == NvrtcResult.NVRTC_ERROR_BUILTIN_OPERATION_FAILURE
170
+ ):
171
+ raise NvrtcBuiltinOperationFailure()
136
172
  elif error != NvrtcResult.NVRTC_SUCCESS:
137
173
  try:
138
174
  error_name = NvrtcResult(error).name
139
175
  except ValueError:
140
- error_name = ('Unknown nvrtc_result '
141
- f'(error code: {error})')
142
- msg = f'Failed to call {name}: {error_name}'
176
+ error_name = (
177
+ "Unknown nvrtc_result "
178
+ f"(error code: {error})"
179
+ )
180
+ msg = f"Failed to call {name}: {error_name}"
143
181
  raise NvrtcError(msg)
144
182
 
145
183
  setattr(inst, name, checked_call)
146
184
 
147
185
  return cls.__INSTANCE
148
186
 
187
+ def get_supported_archs(self):
188
+ """
189
+ Get Supported Architectures by NVRTC as list of arch tuples.
190
+ """
191
+ ver = self.get_version()
192
+ if ver < (11, 0):
193
+ raise RuntimeError(
194
+ "Unsupported CUDA version. CUDA 11.0 or higher is required."
195
+ )
196
+ elif ver == (11, 0):
197
+ return [
198
+ (3, 0),
199
+ (3, 2),
200
+ (3, 5),
201
+ (3, 7),
202
+ (5, 0),
203
+ (5, 2),
204
+ (5, 3),
205
+ (6, 0),
206
+ (6, 1),
207
+ (6, 2),
208
+ (7, 0),
209
+ (7, 2),
210
+ (7, 5),
211
+ ]
212
+ elif ver == (11, 1):
213
+ return [
214
+ (3, 5),
215
+ (3, 7),
216
+ (5, 0),
217
+ (5, 2),
218
+ (5, 3),
219
+ (6, 0),
220
+ (6, 1),
221
+ (6, 2),
222
+ (7, 0),
223
+ (7, 2),
224
+ (7, 5),
225
+ (8, 0),
226
+ ]
227
+ else:
228
+ num = c_int()
229
+ self.nvrtcGetNumSupportedArchs(byref(num))
230
+ archs = (c_int * num.value)()
231
+ self.nvrtcGetSupportedArchs(archs)
232
+ return [(archs[i] // 10, archs[i] % 10) for i in range(num.value)]
233
+
149
234
  def get_version(self):
150
235
  """
151
236
  Get the NVRTC version as a tuple (major, minor).
@@ -182,12 +267,12 @@ class NVRTC:
182
267
  # prior to the call to nvrtcCompileProgram
183
268
  encoded_options = [opt.encode() for opt in options]
184
269
  option_pointers = [c_char_p(opt) for opt in encoded_options]
185
- c_options_type = (c_char_p * len(options))
270
+ c_options_type = c_char_p * len(options)
186
271
  c_options = c_options_type(*option_pointers)
187
272
  try:
188
273
  self.nvrtcCompileProgram(program.handle, len(options), c_options)
189
274
  return False
190
- except NvrtcCompilationError:
275
+ except (NvrtcCompilationError, NvrtcBuiltinOperationFailure):
191
276
  return True
192
277
 
193
278
  def destroy_program(self, program):
@@ -251,13 +336,37 @@ def compile(src, name, cc, ltoir=False):
251
336
  nvrtc = NVRTC()
252
337
  program = nvrtc.create_program(src, name)
253
338
 
339
+ version = nvrtc.get_version()
340
+ ver_str = lambda v: ".".join(v)
341
+ if version < (11, 0):
342
+ raise RuntimeError(
343
+ "Unsupported CUDA version. CUDA 11.0 or higher is required."
344
+ )
345
+ else:
346
+ supported_arch = nvrtc.get_supported_archs()
347
+ try:
348
+ found = max(filter(lambda v: v <= cc, [v for v in supported_arch]))
349
+ except ValueError:
350
+ raise RuntimeError(
351
+ f"Device compute capability {ver_str(cc)} is less than the "
352
+ f"minimum supported by NVRTC {ver_str(version)}. Supported "
353
+ "compute capabilities are "
354
+ f"{', '.join([ver_str(v) for v in supported_arch])}."
355
+ )
356
+
357
+ if found != cc:
358
+ warnings.warn(
359
+ f"Device compute capability {ver_str(cc)} is not supported by "
360
+ f"NVRTC {ver_str(version)}. Using {ver_str(found)} instead."
361
+ )
362
+
254
363
  # Compilation options:
255
364
  # - Compile for the current device's compute capability.
256
365
  # - The CUDA include path is added.
257
366
  # - Relocatable Device Code (rdc) is needed to prevent device functions
258
367
  # being optimized away.
259
- major, minor = cc
260
- arch = f'--gpu-architecture=compute_{major}{minor}'
368
+ major, minor = found
369
+ arch = f"--gpu-architecture=compute_{major}{minor}"
261
370
 
262
371
  cuda_include = [
263
372
  f"-I{get_cuda_paths()['include_dir'].info}",
@@ -265,12 +374,12 @@ def compile(src, name, cc, ltoir=False):
265
374
 
266
375
  cudadrv_path = os.path.dirname(os.path.abspath(__file__))
267
376
  numba_cuda_path = os.path.dirname(cudadrv_path)
268
- numba_include = f'-I{numba_cuda_path}'
377
+ numba_include = f"-I{numba_cuda_path}"
269
378
 
270
379
  nrt_path = os.path.join(numba_cuda_path, "runtime")
271
- nrt_include = f'-I{nrt_path}'
380
+ nrt_include = f"-I{nrt_path}"
272
381
 
273
- options = [arch, *cuda_include, numba_include, nrt_include, '-rdc', 'true']
382
+ options = [arch, *cuda_include, numba_include, nrt_include, "-rdc", "true"]
274
383
 
275
384
  if ltoir:
276
385
  options.append("-dlto")
@@ -286,12 +395,12 @@ def compile(src, name, cc, ltoir=False):
286
395
 
287
396
  # If the compile failed, provide the log in an exception
288
397
  if compile_error:
289
- msg = (f'NVRTC Compilation failure whilst compiling {name}:\n\n{log}')
398
+ msg = f"NVRTC Compilation failure whilst compiling {name}:\n\n{log}"
290
399
  raise NvrtcError(msg)
291
400
 
292
401
  # Otherwise, if there's any content in the log, present it as a warning
293
402
  if log:
294
- msg = (f"NVRTC log messages whilst compiling {name}:\n\n{log}")
403
+ msg = f"NVRTC log messages whilst compiling {name}:\n\n{log}"
295
404
  warnings.warn(msg)
296
405
 
297
406
  if ltoir: