numba-cuda 0.8.1__py3-none-any.whl → 0.9.0__py3-none-any.whl

This diff shows the changes between two publicly available versions of a package that have been released to one of the supported registries. It is provided for informational purposes only and reflects the package contents as they appear in their respective public registries.
Files changed (227):
  1. _numba_cuda_redirector.py +17 -13
  2. numba_cuda/VERSION +1 -1
  3. numba_cuda/_version.py +4 -1
  4. numba_cuda/numba/cuda/__init__.py +6 -2
  5. numba_cuda/numba/cuda/api.py +129 -86
  6. numba_cuda/numba/cuda/api_util.py +3 -3
  7. numba_cuda/numba/cuda/args.py +12 -16
  8. numba_cuda/numba/cuda/cg.py +6 -6
  9. numba_cuda/numba/cuda/codegen.py +74 -43
  10. numba_cuda/numba/cuda/compiler.py +232 -113
  11. numba_cuda/numba/cuda/cpp_function_wrappers.cu +1 -2
  12. numba_cuda/numba/cuda/cuda_fp16.h +661 -661
  13. numba_cuda/numba/cuda/cuda_fp16.hpp +3 -3
  14. numba_cuda/numba/cuda/cuda_paths.py +291 -99
  15. numba_cuda/numba/cuda/cudadecl.py +125 -69
  16. numba_cuda/numba/cuda/cudadrv/__init__.py +3 -1
  17. numba_cuda/numba/cuda/cudadrv/devicearray.py +185 -135
  18. numba_cuda/numba/cuda/cudadrv/devices.py +16 -11
  19. numba_cuda/numba/cuda/cudadrv/driver.py +460 -297
  20. numba_cuda/numba/cuda/cudadrv/drvapi.py +241 -207
  21. numba_cuda/numba/cuda/cudadrv/dummyarray.py +66 -54
  22. numba_cuda/numba/cuda/cudadrv/enums.py +1 -1
  23. numba_cuda/numba/cuda/cudadrv/error.py +6 -2
  24. numba_cuda/numba/cuda/cudadrv/libs.py +67 -63
  25. numba_cuda/numba/cuda/cudadrv/linkable_code.py +16 -1
  26. numba_cuda/numba/cuda/cudadrv/mappings.py +16 -14
  27. numba_cuda/numba/cuda/cudadrv/nvrtc.py +138 -29
  28. numba_cuda/numba/cuda/cudadrv/nvvm.py +296 -161
  29. numba_cuda/numba/cuda/cudadrv/rtapi.py +1 -1
  30. numba_cuda/numba/cuda/cudadrv/runtime.py +20 -8
  31. numba_cuda/numba/cuda/cudaimpl.py +317 -233
  32. numba_cuda/numba/cuda/cudamath.py +1 -1
  33. numba_cuda/numba/cuda/debuginfo.py +8 -6
  34. numba_cuda/numba/cuda/decorators.py +75 -45
  35. numba_cuda/numba/cuda/descriptor.py +1 -1
  36. numba_cuda/numba/cuda/device_init.py +69 -18
  37. numba_cuda/numba/cuda/deviceufunc.py +143 -98
  38. numba_cuda/numba/cuda/dispatcher.py +300 -213
  39. numba_cuda/numba/cuda/errors.py +13 -10
  40. numba_cuda/numba/cuda/extending.py +1 -1
  41. numba_cuda/numba/cuda/initialize.py +5 -3
  42. numba_cuda/numba/cuda/intrinsic_wrapper.py +3 -3
  43. numba_cuda/numba/cuda/intrinsics.py +31 -27
  44. numba_cuda/numba/cuda/kernels/reduction.py +13 -13
  45. numba_cuda/numba/cuda/kernels/transpose.py +3 -6
  46. numba_cuda/numba/cuda/libdevice.py +317 -317
  47. numba_cuda/numba/cuda/libdeviceimpl.py +3 -2
  48. numba_cuda/numba/cuda/locks.py +16 -0
  49. numba_cuda/numba/cuda/mathimpl.py +62 -57
  50. numba_cuda/numba/cuda/models.py +1 -5
  51. numba_cuda/numba/cuda/nvvmutils.py +103 -88
  52. numba_cuda/numba/cuda/printimpl.py +9 -5
  53. numba_cuda/numba/cuda/random.py +46 -36
  54. numba_cuda/numba/cuda/reshape_funcs.cu +1 -1
  55. numba_cuda/numba/cuda/runtime/__init__.py +1 -1
  56. numba_cuda/numba/cuda/runtime/memsys.cu +1 -1
  57. numba_cuda/numba/cuda/runtime/memsys.cuh +1 -1
  58. numba_cuda/numba/cuda/runtime/nrt.cu +3 -3
  59. numba_cuda/numba/cuda/runtime/nrt.py +48 -43
  60. numba_cuda/numba/cuda/simulator/__init__.py +22 -12
  61. numba_cuda/numba/cuda/simulator/api.py +38 -22
  62. numba_cuda/numba/cuda/simulator/compiler.py +2 -2
  63. numba_cuda/numba/cuda/simulator/cudadrv/__init__.py +8 -2
  64. numba_cuda/numba/cuda/simulator/cudadrv/devicearray.py +63 -55
  65. numba_cuda/numba/cuda/simulator/cudadrv/devices.py +13 -11
  66. numba_cuda/numba/cuda/simulator/cudadrv/driver.py +5 -5
  67. numba_cuda/numba/cuda/simulator/cudadrv/drvapi.py +2 -2
  68. numba_cuda/numba/cuda/simulator/cudadrv/libs.py +1 -1
  69. numba_cuda/numba/cuda/simulator/cudadrv/nvvm.py +3 -3
  70. numba_cuda/numba/cuda/simulator/cudadrv/runtime.py +3 -3
  71. numba_cuda/numba/cuda/simulator/kernel.py +43 -34
  72. numba_cuda/numba/cuda/simulator/kernelapi.py +31 -26
  73. numba_cuda/numba/cuda/simulator/reduction.py +1 -0
  74. numba_cuda/numba/cuda/simulator/vector_types.py +13 -9
  75. numba_cuda/numba/cuda/simulator_init.py +2 -4
  76. numba_cuda/numba/cuda/stubs.py +139 -102
  77. numba_cuda/numba/cuda/target.py +64 -47
  78. numba_cuda/numba/cuda/testing.py +24 -19
  79. numba_cuda/numba/cuda/tests/__init__.py +14 -12
  80. numba_cuda/numba/cuda/tests/cudadrv/test_array_attr.py +16 -17
  81. numba_cuda/numba/cuda/tests/cudadrv/test_context_stack.py +7 -7
  82. numba_cuda/numba/cuda/tests/cudadrv/test_cuda_array_slicing.py +73 -54
  83. numba_cuda/numba/cuda/tests/cudadrv/test_cuda_auto_context.py +1 -1
  84. numba_cuda/numba/cuda/tests/cudadrv/test_cuda_devicerecord.py +48 -50
  85. numba_cuda/numba/cuda/tests/cudadrv/test_cuda_driver.py +47 -29
  86. numba_cuda/numba/cuda/tests/cudadrv/test_cuda_libraries.py +3 -3
  87. numba_cuda/numba/cuda/tests/cudadrv/test_cuda_memory.py +19 -19
  88. numba_cuda/numba/cuda/tests/cudadrv/test_cuda_ndarray.py +108 -103
  89. numba_cuda/numba/cuda/tests/cudadrv/test_deallocations.py +20 -11
  90. numba_cuda/numba/cuda/tests/cudadrv/test_detect.py +20 -17
  91. numba_cuda/numba/cuda/tests/cudadrv/test_emm_plugins.py +8 -6
  92. numba_cuda/numba/cuda/tests/cudadrv/test_events.py +1 -1
  93. numba_cuda/numba/cuda/tests/cudadrv/test_host_alloc.py +8 -7
  94. numba_cuda/numba/cuda/tests/cudadrv/test_init.py +13 -13
  95. numba_cuda/numba/cuda/tests/cudadrv/test_inline_ptx.py +12 -9
  96. numba_cuda/numba/cuda/tests/cudadrv/test_linker.py +36 -31
  97. numba_cuda/numba/cuda/tests/cudadrv/test_managed_alloc.py +8 -7
  98. numba_cuda/numba/cuda/tests/cudadrv/test_module_callbacks.py +294 -0
  99. numba_cuda/numba/cuda/tests/cudadrv/test_mvc.py +10 -7
  100. numba_cuda/numba/cuda/tests/cudadrv/test_nvjitlink.py +24 -15
  101. numba_cuda/numba/cuda/tests/cudadrv/test_nvvm_driver.py +43 -41
  102. numba_cuda/numba/cuda/tests/cudadrv/test_pinned.py +4 -5
  103. numba_cuda/numba/cuda/tests/cudadrv/test_profiler.py +2 -2
  104. numba_cuda/numba/cuda/tests/cudadrv/test_ptds.py +28 -17
  105. numba_cuda/numba/cuda/tests/cudadrv/test_reset_device.py +1 -2
  106. numba_cuda/numba/cuda/tests/cudadrv/test_runtime.py +22 -14
  107. numba_cuda/numba/cuda/tests/cudadrv/test_select_device.py +1 -1
  108. numba_cuda/numba/cuda/tests/cudadrv/test_streams.py +4 -3
  109. numba_cuda/numba/cuda/tests/cudapy/cache_usecases.py +10 -4
  110. numba_cuda/numba/cuda/tests/cudapy/cache_with_cpu_usecases.py +1 -0
  111. numba_cuda/numba/cuda/tests/cudapy/extensions_usecases.py +7 -6
  112. numba_cuda/numba/cuda/tests/cudapy/jitlink.ptx +0 -2
  113. numba_cuda/numba/cuda/tests/cudapy/recursion_usecases.py +1 -0
  114. numba_cuda/numba/cuda/tests/cudapy/test_alignment.py +6 -5
  115. numba_cuda/numba/cuda/tests/cudapy/test_array.py +52 -42
  116. numba_cuda/numba/cuda/tests/cudapy/test_array_args.py +5 -6
  117. numba_cuda/numba/cuda/tests/cudapy/test_array_methods.py +1 -1
  118. numba_cuda/numba/cuda/tests/cudapy/test_atomics.py +501 -304
  119. numba_cuda/numba/cuda/tests/cudapy/test_blackscholes.py +57 -21
  120. numba_cuda/numba/cuda/tests/cudapy/test_boolean.py +3 -3
  121. numba_cuda/numba/cuda/tests/cudapy/test_caching.py +50 -37
  122. numba_cuda/numba/cuda/tests/cudapy/test_casting.py +29 -24
  123. numba_cuda/numba/cuda/tests/cudapy/test_cffi.py +11 -6
  124. numba_cuda/numba/cuda/tests/cudapy/test_compiler.py +84 -50
  125. numba_cuda/numba/cuda/tests/cudapy/test_complex.py +144 -73
  126. numba_cuda/numba/cuda/tests/cudapy/test_complex_kernel.py +2 -2
  127. numba_cuda/numba/cuda/tests/cudapy/test_const_string.py +37 -27
  128. numba_cuda/numba/cuda/tests/cudapy/test_constmem.py +43 -45
  129. numba_cuda/numba/cuda/tests/cudapy/test_cooperative_groups.py +21 -14
  130. numba_cuda/numba/cuda/tests/cudapy/test_cuda_array_interface.py +60 -55
  131. numba_cuda/numba/cuda/tests/cudapy/test_cuda_jit_no_types.py +3 -2
  132. numba_cuda/numba/cuda/tests/cudapy/test_datetime.py +26 -22
  133. numba_cuda/numba/cuda/tests/cudapy/test_debug.py +29 -27
  134. numba_cuda/numba/cuda/tests/cudapy/test_debuginfo.py +31 -28
  135. numba_cuda/numba/cuda/tests/cudapy/test_device_func.py +52 -45
  136. numba_cuda/numba/cuda/tests/cudapy/test_dispatcher.py +55 -43
  137. numba_cuda/numba/cuda/tests/cudapy/test_enums.py +6 -7
  138. numba_cuda/numba/cuda/tests/cudapy/test_errors.py +30 -15
  139. numba_cuda/numba/cuda/tests/cudapy/test_exception.py +11 -12
  140. numba_cuda/numba/cuda/tests/cudapy/test_extending.py +19 -12
  141. numba_cuda/numba/cuda/tests/cudapy/test_fastmath.py +77 -66
  142. numba_cuda/numba/cuda/tests/cudapy/test_forall.py +5 -3
  143. numba_cuda/numba/cuda/tests/cudapy/test_freevar.py +5 -3
  144. numba_cuda/numba/cuda/tests/cudapy/test_frexp_ldexp.py +1 -1
  145. numba_cuda/numba/cuda/tests/cudapy/test_globals.py +3 -5
  146. numba_cuda/numba/cuda/tests/cudapy/test_gufunc.py +144 -126
  147. numba_cuda/numba/cuda/tests/cudapy/test_gufunc_scalar.py +23 -18
  148. numba_cuda/numba/cuda/tests/cudapy/test_gufunc_scheduling.py +16 -22
  149. numba_cuda/numba/cuda/tests/cudapy/test_idiv.py +1 -3
  150. numba_cuda/numba/cuda/tests/cudapy/test_inspect.py +29 -20
  151. numba_cuda/numba/cuda/tests/cudapy/test_intrinsics.py +147 -99
  152. numba_cuda/numba/cuda/tests/cudapy/test_ipc.py +50 -36
  153. numba_cuda/numba/cuda/tests/cudapy/test_iterators.py +1 -2
  154. numba_cuda/numba/cuda/tests/cudapy/test_lang.py +4 -4
  155. numba_cuda/numba/cuda/tests/cudapy/test_laplace.py +6 -6
  156. numba_cuda/numba/cuda/tests/cudapy/test_libdevice.py +24 -20
  157. numba_cuda/numba/cuda/tests/cudapy/test_lineinfo.py +36 -31
  158. numba_cuda/numba/cuda/tests/cudapy/test_localmem.py +13 -13
  159. numba_cuda/numba/cuda/tests/cudapy/test_mandel.py +13 -6
  160. numba_cuda/numba/cuda/tests/cudapy/test_math.py +83 -66
  161. numba_cuda/numba/cuda/tests/cudapy/test_matmul.py +1 -3
  162. numba_cuda/numba/cuda/tests/cudapy/test_minmax.py +19 -58
  163. numba_cuda/numba/cuda/tests/cudapy/test_montecarlo.py +4 -4
  164. numba_cuda/numba/cuda/tests/cudapy/test_multigpu.py +9 -7
  165. numba_cuda/numba/cuda/tests/cudapy/test_multiprocessing.py +9 -8
  166. numba_cuda/numba/cuda/tests/cudapy/test_multithreads.py +12 -10
  167. numba_cuda/numba/cuda/tests/cudapy/test_nondet.py +1 -1
  168. numba_cuda/numba/cuda/tests/cudapy/test_operator.py +180 -96
  169. numba_cuda/numba/cuda/tests/cudapy/test_optimization.py +5 -5
  170. numba_cuda/numba/cuda/tests/cudapy/test_overload.py +37 -18
  171. numba_cuda/numba/cuda/tests/cudapy/test_powi.py +7 -7
  172. numba_cuda/numba/cuda/tests/cudapy/test_print.py +9 -7
  173. numba_cuda/numba/cuda/tests/cudapy/test_py2_div_issue.py +1 -1
  174. numba_cuda/numba/cuda/tests/cudapy/test_random.py +15 -10
  175. numba_cuda/numba/cuda/tests/cudapy/test_record_dtype.py +88 -87
  176. numba_cuda/numba/cuda/tests/cudapy/test_recursion.py +12 -10
  177. numba_cuda/numba/cuda/tests/cudapy/test_reduction.py +26 -11
  178. numba_cuda/numba/cuda/tests/cudapy/test_retrieve_autoconverted_arrays.py +7 -10
  179. numba_cuda/numba/cuda/tests/cudapy/test_serialize.py +4 -6
  180. numba_cuda/numba/cuda/tests/cudapy/test_slicing.py +1 -1
  181. numba_cuda/numba/cuda/tests/cudapy/test_sm.py +10 -9
  182. numba_cuda/numba/cuda/tests/cudapy/test_sm_creation.py +62 -43
  183. numba_cuda/numba/cuda/tests/cudapy/test_stream_api.py +7 -3
  184. numba_cuda/numba/cuda/tests/cudapy/test_sync.py +7 -5
  185. numba_cuda/numba/cuda/tests/cudapy/test_transpose.py +18 -11
  186. numba_cuda/numba/cuda/tests/cudapy/test_ufuncs.py +111 -88
  187. numba_cuda/numba/cuda/tests/cudapy/test_userexc.py +2 -3
  188. numba_cuda/numba/cuda/tests/cudapy/test_vector_type.py +305 -130
  189. numba_cuda/numba/cuda/tests/cudapy/test_vectorize.py +33 -36
  190. numba_cuda/numba/cuda/tests/cudapy/test_vectorize_complex.py +5 -5
  191. numba_cuda/numba/cuda/tests/cudapy/test_vectorize_decor.py +16 -12
  192. numba_cuda/numba/cuda/tests/cudapy/test_vectorize_device.py +7 -7
  193. numba_cuda/numba/cuda/tests/cudapy/test_vectorize_scalar_arg.py +6 -7
  194. numba_cuda/numba/cuda/tests/cudapy/test_warning.py +31 -29
  195. numba_cuda/numba/cuda/tests/cudapy/test_warp_ops.py +31 -25
  196. numba_cuda/numba/cuda/tests/cudasim/test_cudasim_issues.py +19 -13
  197. numba_cuda/numba/cuda/tests/data/jitlink.cu +1 -1
  198. numba_cuda/numba/cuda/tests/data/jitlink.ptx +0 -2
  199. numba_cuda/numba/cuda/tests/doc_examples/test_cg.py +15 -8
  200. numba_cuda/numba/cuda/tests/doc_examples/test_cpu_gpu_compat.py +4 -7
  201. numba_cuda/numba/cuda/tests/doc_examples/test_ffi.py +14 -9
  202. numba_cuda/numba/cuda/tests/doc_examples/test_laplace.py +22 -18
  203. numba_cuda/numba/cuda/tests/doc_examples/test_matmul.py +7 -4
  204. numba_cuda/numba/cuda/tests/doc_examples/test_montecarlo.py +2 -0
  205. numba_cuda/numba/cuda/tests/doc_examples/test_random.py +8 -4
  206. numba_cuda/numba/cuda/tests/doc_examples/test_reduction.py +2 -1
  207. numba_cuda/numba/cuda/tests/doc_examples/test_sessionize.py +94 -19
  208. numba_cuda/numba/cuda/tests/doc_examples/test_vecadd.py +2 -2
  209. numba_cuda/numba/cuda/tests/nocuda/test_dummyarray.py +91 -62
  210. numba_cuda/numba/cuda/tests/nocuda/test_function_resolution.py +14 -5
  211. numba_cuda/numba/cuda/tests/nocuda/test_import.py +25 -25
  212. numba_cuda/numba/cuda/tests/nocuda/test_library_lookup.py +40 -40
  213. numba_cuda/numba/cuda/tests/nocuda/test_nvvm.py +12 -10
  214. numba_cuda/numba/cuda/tests/nrt/test_nrt.py +16 -20
  215. numba_cuda/numba/cuda/tests/nrt/test_nrt_refct.py +12 -10
  216. numba_cuda/numba/cuda/tests/test_binary_generation/generate_raw_ltoir.py +2 -2
  217. numba_cuda/numba/cuda/types.py +5 -2
  218. numba_cuda/numba/cuda/ufuncs.py +382 -362
  219. numba_cuda/numba/cuda/utils.py +2 -2
  220. numba_cuda/numba/cuda/vector_types.py +2 -2
  221. numba_cuda/numba/cuda/vectorizers.py +37 -32
  222. {numba_cuda-0.8.1.dist-info → numba_cuda-0.9.0.dist-info}/METADATA +1 -1
  223. numba_cuda-0.9.0.dist-info/RECORD +253 -0
  224. {numba_cuda-0.8.1.dist-info → numba_cuda-0.9.0.dist-info}/WHEEL +1 -1
  225. numba_cuda-0.8.1.dist-info/RECORD +0 -251
  226. {numba_cuda-0.8.1.dist-info → numba_cuda-0.9.0.dist-info}/licenses/LICENSE +0 -0
  227. {numba_cuda-0.8.1.dist-info → numba_cuda-0.9.0.dist-info}/top_level.txt +0 -0
@@ -60,7 +60,7 @@
60
60
  # define __CPP_VERSION_AT_LEAST_11_FP16
61
61
  #endif
62
62
 
63
- /* C++11 header for std::move.
63
+ /* C++11 header for std::move.
64
64
  * In RTC mode, std::move is provided implicitly; don't include the header
65
65
  */
66
66
  #if defined(__CPP_VERSION_AT_LEAST_11_FP16) && !defined(__CUDACC_RTC__)
@@ -145,7 +145,7 @@
145
145
  * Types which allow static initialization of "half" and "half2" until
146
146
  * these become an actual builtin. Note this initialization is as a
147
147
  * bitfield representation of "half", and not a conversion from short->half.
148
- * Such a representation will be deprecated in a future version of CUDA.
148
+ * Such a representation will be deprecated in a future version of CUDA.
149
149
  * (Note these are visible to non-nvcc compilers, including C-only compilation)
150
150
  */
151
151
  typedef struct __CUDA_ALIGN__(2) {
@@ -2443,7 +2443,7 @@ __CUDA_FP16_DECL__ __half atomicAdd(__half *const address, const __half val) {
2443
2443
 
2444
2444
  #undef __CUDA_HOSTDEVICE_FP16_DECL__
2445
2445
  #undef __CUDA_FP16_DECL__
2446
-
2446
+
2447
2447
  /* Define first-class types "half" and "half2", unless user specifies otherwise via "#define CUDA_NO_HALF" */
2448
2448
  /* C cannot ever have these types defined here, because __half and __half2 are C++ classes */
2449
2449
  #if defined(__cplusplus) && !defined(CUDA_NO_HALF)
@@ -3,13 +3,47 @@ import re
3
3
  import os
4
4
  from collections import namedtuple
5
5
  import platform
6
-
6
+ import site
7
+ from pathlib import Path
7
8
  from numba.core.config import IS_WIN32
8
- from numba.misc.findlib import find_lib, find_file
9
+ from numba.misc.findlib import find_lib
9
10
  from numba import config
11
+ import ctypes
12
+
13
+ _env_path_tuple = namedtuple("_env_path_tuple", ["by", "info"])
14
+
15
+ SEARCH_PRIORITY = [
16
+ "Conda environment",
17
+ "Conda environment (NVIDIA package)",
18
+ "NVIDIA NVCC Wheel",
19
+ "CUDA_HOME",
20
+ "System",
21
+ "Debian package",
22
+ ]
23
+
24
+
25
+ def _priority_index(label):
26
+ if label in SEARCH_PRIORITY:
27
+ return SEARCH_PRIORITY.index(label)
28
+ else:
29
+ raise ValueError(f"Can't determine search priority for {label}")
30
+
10
31
 
32
+ def _find_first_valid_lazy(options):
33
+ sorted_options = sorted(options, key=lambda x: _priority_index(x[0]))
34
+ for label, fn in sorted_options:
35
+ value = fn()
36
+ if value:
37
+ return label, value
38
+ return "<unknown>", None
11
39
 
12
- _env_path_tuple = namedtuple('_env_path_tuple', ['by', 'info'])
40
+
41
+ def _build_options(pairs):
42
+ """Sorts and returns a list of (label, value) tuples according to SEARCH_PRIORITY."""
43
+ priority_index = {label: i for i, label in enumerate(SEARCH_PRIORITY)}
44
+ return sorted(
45
+ pairs, key=lambda pair: priority_index.get(pair[0], float("inf"))
46
+ )
13
47
 
14
48
 
15
49
  def _find_valid_path(options):
@@ -21,83 +55,210 @@ def _find_valid_path(options):
21
55
  if data is not None:
22
56
  return by, data
23
57
  else:
24
- return '<unknown>', None
58
+ return "<unknown>", None
25
59
 
26
60
 
27
61
  def _get_libdevice_path_decision():
28
- options = [
29
- ('Conda environment', get_conda_ctk()),
30
- ('Conda environment (NVIDIA package)', get_nvidia_libdevice_ctk()),
31
- ('CUDA_HOME', get_cuda_home('nvvm', 'libdevice')),
32
- ('System', get_system_ctk('nvvm', 'libdevice')),
33
- ('Debian package', get_debian_pkg_libdevice()),
34
- ]
35
- by, libdir = _find_valid_path(options)
36
- return by, libdir
62
+ options = _build_options(
63
+ [
64
+ ("Conda environment", get_conda_ctk),
65
+ ("Conda environment (NVIDIA package)", get_nvidia_libdevice_ctk),
66
+ ("CUDA_HOME", lambda: get_cuda_home("nvvm", "libdevice")),
67
+ ("NVIDIA NVCC Wheel", get_libdevice_wheel),
68
+ ("System", lambda: get_system_ctk("nvvm", "libdevice")),
69
+ ("Debian package", get_debian_pkg_libdevice),
70
+ ]
71
+ )
72
+ return _find_first_valid_lazy(options)
37
73
 
38
74
 
39
75
  def _nvvm_lib_dir():
40
76
  if IS_WIN32:
41
- return 'nvvm', 'bin'
77
+ return "nvvm", "bin"
42
78
  else:
43
- return 'nvvm', 'lib64'
79
+ return "nvvm", "lib64"
44
80
 
45
81
 
46
82
  def _get_nvvm_path_decision():
47
83
  options = [
48
- ('Conda environment', get_conda_ctk()),
49
- ('Conda environment (NVIDIA package)', get_nvidia_nvvm_ctk()),
50
- ('CUDA_HOME', get_cuda_home(*_nvvm_lib_dir())),
51
- ('System', get_system_ctk(*_nvvm_lib_dir())),
84
+ ("Conda environment", get_conda_ctk),
85
+ ("Conda environment (NVIDIA package)", get_nvidia_nvvm_ctk),
86
+ ("NVIDIA NVCC Wheel", _get_nvvm_wheel),
87
+ ("CUDA_HOME", lambda: get_cuda_home(*_nvvm_lib_dir())),
88
+ ("System", lambda: get_system_ctk(*_nvvm_lib_dir())),
52
89
  ]
53
- by, path = _find_valid_path(options)
54
- return by, path
90
+ return _find_first_valid_lazy(options)
91
+
92
+
93
+ def _get_nvrtc_system_ctk():
94
+ sys_path = get_system_ctk("bin" if IS_WIN32 else "lib64")
95
+ candidates = find_lib("nvrtc", sys_path)
96
+ if candidates:
97
+ return max(candidates)
98
+
99
+
100
+ def _get_nvrtc_path_decision():
101
+ options = _build_options(
102
+ [
103
+ ("CUDA_HOME", lambda: get_cuda_home("nvrtc")),
104
+ ("Conda environment", get_conda_ctk),
105
+ ("Conda environment (NVIDIA package)", get_nvidia_cudalib_ctk),
106
+ ("NVIDIA NVCC Wheel", _get_nvrtc_wheel),
107
+ ("System", _get_nvrtc_system_ctk),
108
+ ]
109
+ )
110
+ return _find_first_valid_lazy(options)
111
+
112
+
113
+ def _get_nvvm_wheel():
114
+ platform_map = {
115
+ "linux": ("lib64", "libnvvm.so"),
116
+ "win32": ("bin", "nvvm64_40_0.dll"),
117
+ }
118
+
119
+ for plat, (dso_dir, dso_path) in platform_map.items():
120
+ if sys.platform.startswith(plat):
121
+ break
122
+ else:
123
+ raise NotImplementedError("Unsupported platform")
124
+
125
+ site_paths = [site.getusersitepackages()] + site.getsitepackages()
126
+
127
+ for sp in filter(None, site_paths):
128
+ nvvm_path = Path(sp, "nvidia", "cuda_nvcc", "nvvm", dso_dir, dso_path)
129
+ if nvvm_path.exists():
130
+ return str(nvvm_path.parent)
131
+
132
+ return None
133
+
134
+
135
+ def get_major_cuda_version():
136
+ # TODO: remove once cuda-python is
137
+ # a hard dependency
138
+ from numba.cuda.cudadrv.runtime import get_version
139
+
140
+ return get_version()[0]
141
+
142
+
143
+ def get_nvrtc_dso_path():
144
+ site_paths = [site.getusersitepackages()] + site.getsitepackages()
145
+ for sp in site_paths:
146
+ lib_dir = os.path.join(
147
+ sp,
148
+ "nvidia",
149
+ "cuda_nvrtc",
150
+ ("bin" if IS_WIN32 else "lib") if sp else None,
151
+ )
152
+ if lib_dir and os.path.exists(lib_dir):
153
+ try:
154
+ major = get_major_cuda_version()
155
+ if major == 11:
156
+ cu_ver = "112" if IS_WIN32 else "11.2"
157
+ elif major == 12:
158
+ cu_ver = "120" if IS_WIN32 else "12"
159
+ else:
160
+ raise NotImplementedError(f"CUDA {major} is not supported")
161
+
162
+ return os.path.join(
163
+ lib_dir,
164
+ f"nvrtc64_{cu_ver}_0.dll"
165
+ if IS_WIN32
166
+ else f"libnvrtc.so.{cu_ver}",
167
+ )
168
+ except RuntimeError:
169
+ continue
170
+
171
+
172
+ def _get_nvrtc_wheel():
173
+ dso_path = get_nvrtc_dso_path()
174
+ if dso_path:
175
+ try:
176
+ result = ctypes.CDLL(dso_path, mode=ctypes.RTLD_GLOBAL)
177
+ except OSError:
178
+ pass
179
+ else:
180
+ if IS_WIN32:
181
+ import win32api
182
+
183
+ # This absolute path will
184
+ # always be correct regardless of the package source
185
+ nvrtc_path = win32api.GetModuleFileNameW(result._handle)
186
+ dso_dir = os.path.dirname(nvrtc_path)
187
+ builtins_path = os.path.join(
188
+ dso_dir,
189
+ [
190
+ f
191
+ for f in os.listdir(dso_dir)
192
+ if re.match("^nvrtc-builtins.*.dll$", f)
193
+ ][0],
194
+ )
195
+ if not os.path.exists(builtins_path):
196
+ raise RuntimeError(
197
+ f'Path does not exist: "{builtins_path}"'
198
+ )
199
+ return Path(dso_path)
55
200
 
56
201
 
57
202
  def _get_libdevice_paths():
58
203
  by, libdir = _get_libdevice_path_decision()
59
- # Search for pattern
60
- pat = r'libdevice(\.\d+)*\.bc$'
61
- candidates = find_file(re.compile(pat), libdir)
62
- # Keep only the max (most recent version) of the bitcode files.
63
- out = max(candidates, default=None)
204
+ out = os.path.join(libdir, "libdevice.10.bc")
64
205
  return _env_path_tuple(by, out)
65
206
 
66
207
 
67
208
  def _cudalib_path():
68
209
  if IS_WIN32:
69
- return 'bin'
210
+ return "bin"
70
211
  else:
71
- return 'lib64'
212
+ return "lib64"
72
213
 
73
214
 
74
215
  def _cuda_home_static_cudalib_path():
75
216
  if IS_WIN32:
76
- return ('lib', 'x64')
217
+ return ("lib", "x64")
77
218
  else:
78
- return ('lib64',)
219
+ return ("lib64",)
220
+
221
+
222
+ def _get_cudalib_wheel():
223
+ """Get the cudalib path from the NVCC wheel."""
224
+ site_paths = [site.getusersitepackages()] + site.getsitepackages()
225
+ libdir = "bin" if IS_WIN32 else "lib"
226
+ for sp in filter(None, site_paths):
227
+ cudalib_path = Path(sp, "nvidia", "cuda_runtime", libdir)
228
+ if cudalib_path.exists():
229
+ return str(cudalib_path)
230
+ return None
79
231
 
80
232
 
81
233
  def _get_cudalib_dir_path_decision():
82
- options = [
83
- ('Conda environment', get_conda_ctk()),
84
- ('Conda environment (NVIDIA package)', get_nvidia_cudalib_ctk()),
85
- ('CUDA_HOME', get_cuda_home(_cudalib_path())),
86
- ('System', get_system_ctk(_cudalib_path())),
87
- ]
88
- by, libdir = _find_valid_path(options)
89
- return by, libdir
234
+ options = _build_options(
235
+ [
236
+ ("Conda environment", get_conda_ctk),
237
+ ("Conda environment (NVIDIA package)", get_nvidia_cudalib_ctk),
238
+ ("NVIDIA NVCC Wheel", _get_cudalib_wheel),
239
+ ("CUDA_HOME", lambda: get_cuda_home(_cudalib_path())),
240
+ ("System", lambda: get_system_ctk(_cudalib_path())),
241
+ ]
242
+ )
243
+ return _find_first_valid_lazy(options)
90
244
 
91
245
 
92
246
  def _get_static_cudalib_dir_path_decision():
93
- options = [
94
- ('Conda environment', get_conda_ctk()),
95
- ('Conda environment (NVIDIA package)', get_nvidia_static_cudalib_ctk()),
96
- ('CUDA_HOME', get_cuda_home(*_cuda_home_static_cudalib_path())),
97
- ('System', get_system_ctk(_cudalib_path())),
98
- ]
99
- by, libdir = _find_valid_path(options)
100
- return by, libdir
247
+ options = _build_options(
248
+ [
249
+ ("Conda environment", get_conda_ctk),
250
+ (
251
+ "Conda environment (NVIDIA package)",
252
+ get_nvidia_static_cudalib_ctk,
253
+ ),
254
+ (
255
+ "CUDA_HOME",
256
+ lambda: get_cuda_home(*_cuda_home_static_cudalib_path()),
257
+ ),
258
+ ("System", lambda: get_system_ctk(_cudalib_path())),
259
+ ]
260
+ )
261
+ return _find_first_valid_lazy(options)
101
262
 
102
263
 
103
264
  def _get_cudalib_dir():
@@ -111,25 +272,23 @@ def _get_static_cudalib_dir():
111
272
 
112
273
 
113
274
  def get_system_ctk(*subdirs):
114
- """Return path to system-wide cudatoolkit; or, None if it doesn't exist.
115
- """
275
+ """Return path to system-wide cudatoolkit; or, None if it doesn't exist."""
116
276
  # Linux?
117
- if sys.platform.startswith('linux'):
277
+ if not IS_WIN32:
118
278
  # Is cuda alias to /usr/local/cuda?
119
279
  # We are intentionally not getting versioned cuda installation.
120
- base = '/usr/local/cuda'
121
- if os.path.exists(base):
122
- return os.path.join(base, *subdirs)
280
+ result = os.path.join("/usr/local/cuda", *subdirs)
281
+ if os.path.exists(result):
282
+ return result
123
283
 
124
284
 
125
285
  def get_conda_ctk():
126
- """Return path to directory containing the shared libraries of cudatoolkit.
127
- """
128
- is_conda_env = os.path.exists(os.path.join(sys.prefix, 'conda-meta'))
286
+ """Return path to directory containing the shared libraries of cudatoolkit."""
287
+ is_conda_env = os.path.exists(os.path.join(sys.prefix, "conda-meta"))
129
288
  if not is_conda_env:
130
289
  return
131
290
  # Assume the existence of NVVM to imply cudatoolkit installed
132
- paths = find_lib('nvvm')
291
+ paths = find_lib("nvvm")
133
292
  if not paths:
134
293
  return
135
294
  # Use the directory name of the max path
@@ -137,9 +296,8 @@ def get_conda_ctk():
137
296
 
138
297
 
139
298
  def get_nvidia_nvvm_ctk():
140
- """Return path to directory containing the NVVM shared library.
141
- """
142
- is_conda_env = os.path.exists(os.path.join(sys.prefix, 'conda-meta'))
299
+ """Return path to directory containing the NVVM shared library."""
300
+ is_conda_env = os.path.exists(os.path.join(sys.prefix, "conda-meta"))
143
301
  if not is_conda_env:
144
302
  return
145
303
 
@@ -147,16 +305,16 @@ def get_nvidia_nvvm_ctk():
147
305
  # conda package is installed.
148
306
 
149
307
  # First, try the location used on Linux and the Windows 11.x packages
150
- libdir = os.path.join(sys.prefix, 'nvvm', _cudalib_path())
308
+ libdir = os.path.join(sys.prefix, "nvvm", _cudalib_path())
151
309
  if not os.path.exists(libdir) or not os.path.isdir(libdir):
152
310
  # If that fails, try the location used for Windows 12.x packages
153
- libdir = os.path.join(sys.prefix, 'Library', 'nvvm', _cudalib_path())
311
+ libdir = os.path.join(sys.prefix, "Library", "nvvm", _cudalib_path())
154
312
  if not os.path.exists(libdir) or not os.path.isdir(libdir):
155
313
  # If that doesn't exist either, assume we don't have the NVIDIA
156
314
  # conda package
157
315
  return
158
316
 
159
- paths = find_lib('nvvm', libdir=libdir)
317
+ paths = find_lib("nvvm", libdir=libdir)
160
318
  if not paths:
161
319
  return
162
320
  # Use the directory name of the max path
@@ -164,39 +322,36 @@ def get_nvidia_nvvm_ctk():
164
322
 
165
323
 
166
324
  def get_nvidia_libdevice_ctk():
167
- """Return path to directory containing the libdevice library.
168
- """
325
+ """Return path to directory containing the libdevice library."""
169
326
  nvvm_ctk = get_nvidia_nvvm_ctk()
170
327
  if not nvvm_ctk:
171
328
  return
172
329
  nvvm_dir = os.path.dirname(nvvm_ctk)
173
- return os.path.join(nvvm_dir, 'libdevice')
330
+ return os.path.join(nvvm_dir, "libdevice")
174
331
 
175
332
 
176
333
  def get_nvidia_cudalib_ctk():
177
- """Return path to directory containing the shared libraries of cudatoolkit.
178
- """
334
+ """Return path to directory containing the shared libraries of cudatoolkit."""
179
335
  nvvm_ctk = get_nvidia_nvvm_ctk()
180
336
  if not nvvm_ctk:
181
337
  return
182
338
  env_dir = os.path.dirname(os.path.dirname(nvvm_ctk))
183
- subdir = 'bin' if IS_WIN32 else 'lib'
339
+ subdir = "bin" if IS_WIN32 else "lib"
184
340
  return os.path.join(env_dir, subdir)
185
341
 
186
342
 
187
343
  def get_nvidia_static_cudalib_ctk():
188
- """Return path to directory containing the static libraries of cudatoolkit.
189
- """
344
+ """Return path to directory containing the static libraries of cudatoolkit."""
190
345
  nvvm_ctk = get_nvidia_nvvm_ctk()
191
346
  if not nvvm_ctk:
192
347
  return
193
348
 
194
349
  if IS_WIN32 and ("Library" not in nvvm_ctk):
195
350
  # Location specific to CUDA 11.x packages on Windows
196
- dirs = ('Lib', 'x64')
351
+ dirs = ("Lib", "x64")
197
352
  else:
198
353
  # Linux, or Windows with CUDA 12.x packages
199
- dirs = ('lib',)
354
+ dirs = ("lib",)
200
355
 
201
356
  env_dir = os.path.dirname(os.path.dirname(nvvm_ctk))
202
357
  return os.path.join(env_dir, *dirs)
@@ -207,18 +362,45 @@ def get_cuda_home(*subdirs):
207
362
  If *subdirs* are the subdirectory name to be appended in the resulting
208
363
  path.
209
364
  """
210
- cuda_home = os.environ.get('CUDA_HOME')
365
+ cuda_home = os.environ.get("CUDA_HOME")
211
366
  if cuda_home is None:
212
367
  # Try Windows CUDA installation without Anaconda
213
- cuda_home = os.environ.get('CUDA_PATH')
368
+ cuda_home = os.environ.get("CUDA_PATH")
214
369
  if cuda_home is not None:
215
370
  return os.path.join(cuda_home, *subdirs)
216
371
 
217
372
 
218
373
  def _get_nvvm_path():
219
374
  by, path = _get_nvvm_path_decision()
220
- candidates = find_lib('nvvm', path)
221
- path = max(candidates) if candidates else None
375
+
376
+ if by == "NVIDIA NVCC Wheel":
377
+ platform_map = {
378
+ "linux": "libnvvm.so",
379
+ "win32": "nvvm64_40_0.dll",
380
+ }
381
+
382
+ for plat, dso_name in platform_map.items():
383
+ if sys.platform.startswith(plat):
384
+ break
385
+ else:
386
+ raise NotImplementedError("Unsupported platform")
387
+
388
+ path = os.path.join(path, dso_name)
389
+ else:
390
+ candidates = find_lib("nvvm", path)
391
+ path = max(candidates) if candidates else None
392
+ return _env_path_tuple(by, path)
393
+
394
+
395
+ def _get_nvrtc_path():
396
+ by, path = _get_nvrtc_path_decision()
397
+ if by == "NVIDIA NVCC Wheel":
398
+ path = str(path)
399
+ elif by == "System":
400
+ return _env_path_tuple(by, path)
401
+ else:
402
+ candidates = find_lib("nvrtc", path)
403
+ path = max(candidates) if candidates else None
222
404
  return _env_path_tuple(by, path)
223
405
 
224
406
 
@@ -234,16 +416,17 @@ def get_cuda_paths():
234
416
  Note: The result of the function is cached.
235
417
  """
236
418
  # Check cache
237
- if hasattr(get_cuda_paths, '_cached_result'):
419
+ if hasattr(get_cuda_paths, "_cached_result"):
238
420
  return get_cuda_paths._cached_result
239
421
  else:
240
422
  # Not in cache
241
423
  d = {
242
- 'nvvm': _get_nvvm_path(),
243
- 'libdevice': _get_libdevice_paths(),
244
- 'cudalib_dir': _get_cudalib_dir(),
245
- 'static_cudalib_dir': _get_static_cudalib_dir(),
246
- 'include_dir': _get_include_dir(),
424
+ "nvvm": _get_nvvm_path(),
425
+ "nvrtc": _get_nvrtc_path(),
426
+ "libdevice": _get_libdevice_paths(),
427
+ "cudalib_dir": _get_cudalib_dir(),
428
+ "static_cudalib_dir": _get_static_cudalib_dir(),
429
+ "include_dir": _get_include_dir(),
247
430
  }
248
431
  # Cache result
249
432
  get_cuda_paths._cached_result = d
@@ -255,12 +438,22 @@ def get_debian_pkg_libdevice():
255
438
  Return the Debian NVIDIA Maintainers-packaged libdevice location, if it
256
439
  exists.
257
440
  """
258
- pkg_libdevice_location = '/usr/lib/nvidia-cuda-toolkit/libdevice'
441
+ pkg_libdevice_location = "/usr/lib/nvidia-cuda-toolkit/libdevice"
259
442
  if not os.path.exists(pkg_libdevice_location):
260
443
  return None
261
444
  return pkg_libdevice_location
262
445
 
263
446
 
447
+ def get_libdevice_wheel():
448
+ nvvm_path = _get_nvvm_wheel()
449
+ if nvvm_path is None:
450
+ return None
451
+ nvvm_path = Path(nvvm_path)
452
+ libdevice_path = nvvm_path.parent / "libdevice"
453
+
454
+ return str(libdevice_path)
455
+
456
+
264
457
  def get_current_cuda_target_name():
265
458
  """Determine conda's CTK target folder based on system and machine arch.
266
459
 
@@ -274,13 +467,10 @@ def get_current_cuda_target_name():
274
467
  machine = platform.machine()
275
468
 
276
469
  if system == "Linux":
277
- arch_to_targets = {
278
- 'x86_64': 'x86_64-linux',
279
- 'aarch64': 'sbsa-linux'
280
- }
470
+ arch_to_targets = {"x86_64": "x86_64-linux", "aarch64": "sbsa-linux"}
281
471
  elif system == "Windows":
282
472
  arch_to_targets = {
283
- 'AMD64': 'x64',
473
+ "AMD64": "x64",
284
474
  }
285
475
  else:
286
476
  arch_to_targets = {}
@@ -293,26 +483,28 @@ def get_conda_include_dir():
293
483
  Return the include directory in the current conda environment, if one
294
484
  is active and it exists.
295
485
  """
296
- is_conda_env = os.path.exists(os.path.join(sys.prefix, 'conda-meta'))
486
+ is_conda_env = os.path.exists(os.path.join(sys.prefix, "conda-meta"))
297
487
  if not is_conda_env:
298
488
  return
299
489
 
300
490
  if platform.system() == "Windows":
301
- include_dir = os.path.join(
302
- sys.prefix, 'Library', 'include'
303
- )
491
+ include_dir = os.path.join(sys.prefix, "Library", "include")
304
492
  elif target_name := get_current_cuda_target_name():
305
493
  include_dir = os.path.join(
306
- sys.prefix, 'targets', target_name, 'include'
494
+ sys.prefix, "targets", target_name, "include"
307
495
  )
308
496
  else:
309
497
  # A fallback when target cannot determined
310
498
  # though usually it shouldn't.
311
- include_dir = os.path.join(sys.prefix, 'include')
499
+ include_dir = os.path.join(sys.prefix, "include")
312
500
 
313
- if (os.path.exists(include_dir) and os.path.isdir(include_dir)
314
- and os.path.exists(os.path.join(include_dir,
315
- 'cuda_device_runtime_api.h'))):
501
+ if (
502
+ os.path.exists(include_dir)
503
+ and os.path.isdir(include_dir)
504
+ and os.path.exists(
505
+ os.path.join(include_dir, "cuda_device_runtime_api.h")
506
+ )
507
+ ):
316
508
  return include_dir
317
509
  return
318
510
 
@@ -320,8 +512,8 @@ def get_conda_include_dir():
320
512
  def _get_include_dir():
321
513
  """Find the root include directory."""
322
514
  options = [
323
- ('Conda environment (NVIDIA package)', get_conda_include_dir()),
324
- ('CUDA_INCLUDE_PATH Config Entry', config.CUDA_INCLUDE_PATH),
515
+ ("Conda environment (NVIDIA package)", get_conda_include_dir()),
516
+ ("CUDA_INCLUDE_PATH Config Entry", config.CUDA_INCLUDE_PATH),
325
517
  # TODO: add others
326
518
  ]
327
519
  by, include_dir = _find_valid_path(options)