numba-cuda 0.8.1__py3-none-any.whl → 0.10.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (237)
  1. _numba_cuda_redirector.py +17 -13
  2. numba_cuda/VERSION +1 -1
  3. numba_cuda/_version.py +4 -1
  4. numba_cuda/numba/cuda/__init__.py +6 -2
  5. numba_cuda/numba/cuda/api.py +129 -86
  6. numba_cuda/numba/cuda/api_util.py +3 -3
  7. numba_cuda/numba/cuda/args.py +12 -16
  8. numba_cuda/numba/cuda/cg.py +6 -6
  9. numba_cuda/numba/cuda/codegen.py +74 -43
  10. numba_cuda/numba/cuda/compiler.py +246 -114
  11. numba_cuda/numba/cuda/cpp_function_wrappers.cu +1 -2
  12. numba_cuda/numba/cuda/cuda_bf16.py +5155 -0
  13. numba_cuda/numba/cuda/cuda_paths.py +293 -99
  14. numba_cuda/numba/cuda/cudadecl.py +93 -79
  15. numba_cuda/numba/cuda/cudadrv/__init__.py +3 -1
  16. numba_cuda/numba/cuda/cudadrv/devicearray.py +185 -135
  17. numba_cuda/numba/cuda/cudadrv/devices.py +16 -11
  18. numba_cuda/numba/cuda/cudadrv/driver.py +460 -297
  19. numba_cuda/numba/cuda/cudadrv/drvapi.py +241 -207
  20. numba_cuda/numba/cuda/cudadrv/dummyarray.py +66 -54
  21. numba_cuda/numba/cuda/cudadrv/enums.py +1 -1
  22. numba_cuda/numba/cuda/cudadrv/error.py +6 -2
  23. numba_cuda/numba/cuda/cudadrv/libs.py +67 -63
  24. numba_cuda/numba/cuda/cudadrv/linkable_code.py +27 -3
  25. numba_cuda/numba/cuda/cudadrv/mappings.py +16 -14
  26. numba_cuda/numba/cuda/cudadrv/nvrtc.py +146 -30
  27. numba_cuda/numba/cuda/cudadrv/nvvm.py +296 -161
  28. numba_cuda/numba/cuda/cudadrv/rtapi.py +1 -1
  29. numba_cuda/numba/cuda/cudadrv/runtime.py +20 -8
  30. numba_cuda/numba/cuda/cudaimpl.py +296 -275
  31. numba_cuda/numba/cuda/cudamath.py +1 -1
  32. numba_cuda/numba/cuda/debuginfo.py +99 -7
  33. numba_cuda/numba/cuda/decorators.py +87 -45
  34. numba_cuda/numba/cuda/descriptor.py +1 -1
  35. numba_cuda/numba/cuda/device_init.py +68 -18
  36. numba_cuda/numba/cuda/deviceufunc.py +143 -98
  37. numba_cuda/numba/cuda/dispatcher.py +300 -213
  38. numba_cuda/numba/cuda/errors.py +13 -10
  39. numba_cuda/numba/cuda/extending.py +55 -1
  40. numba_cuda/numba/cuda/include/11/cuda_bf16.h +3749 -0
  41. numba_cuda/numba/cuda/include/11/cuda_bf16.hpp +2683 -0
  42. numba_cuda/numba/cuda/{cuda_fp16.h → include/11/cuda_fp16.h} +1090 -927
  43. numba_cuda/numba/cuda/{cuda_fp16.hpp → include/11/cuda_fp16.hpp} +468 -319
  44. numba_cuda/numba/cuda/include/12/cuda_bf16.h +5118 -0
  45. numba_cuda/numba/cuda/include/12/cuda_bf16.hpp +3865 -0
  46. numba_cuda/numba/cuda/include/12/cuda_fp16.h +5363 -0
  47. numba_cuda/numba/cuda/include/12/cuda_fp16.hpp +3483 -0
  48. numba_cuda/numba/cuda/initialize.py +5 -3
  49. numba_cuda/numba/cuda/intrinsic_wrapper.py +0 -39
  50. numba_cuda/numba/cuda/intrinsics.py +203 -28
  51. numba_cuda/numba/cuda/kernels/reduction.py +13 -13
  52. numba_cuda/numba/cuda/kernels/transpose.py +3 -6
  53. numba_cuda/numba/cuda/libdevice.py +317 -317
  54. numba_cuda/numba/cuda/libdeviceimpl.py +3 -2
  55. numba_cuda/numba/cuda/locks.py +16 -0
  56. numba_cuda/numba/cuda/lowering.py +43 -0
  57. numba_cuda/numba/cuda/mathimpl.py +62 -57
  58. numba_cuda/numba/cuda/models.py +1 -5
  59. numba_cuda/numba/cuda/nvvmutils.py +103 -88
  60. numba_cuda/numba/cuda/printimpl.py +9 -5
  61. numba_cuda/numba/cuda/random.py +46 -36
  62. numba_cuda/numba/cuda/reshape_funcs.cu +1 -1
  63. numba_cuda/numba/cuda/runtime/__init__.py +1 -1
  64. numba_cuda/numba/cuda/runtime/memsys.cu +1 -1
  65. numba_cuda/numba/cuda/runtime/memsys.cuh +1 -1
  66. numba_cuda/numba/cuda/runtime/nrt.cu +3 -3
  67. numba_cuda/numba/cuda/runtime/nrt.py +48 -43
  68. numba_cuda/numba/cuda/simulator/__init__.py +22 -12
  69. numba_cuda/numba/cuda/simulator/api.py +38 -22
  70. numba_cuda/numba/cuda/simulator/compiler.py +2 -2
  71. numba_cuda/numba/cuda/simulator/cudadrv/__init__.py +8 -2
  72. numba_cuda/numba/cuda/simulator/cudadrv/devicearray.py +63 -55
  73. numba_cuda/numba/cuda/simulator/cudadrv/devices.py +13 -11
  74. numba_cuda/numba/cuda/simulator/cudadrv/driver.py +5 -5
  75. numba_cuda/numba/cuda/simulator/cudadrv/drvapi.py +2 -2
  76. numba_cuda/numba/cuda/simulator/cudadrv/libs.py +1 -1
  77. numba_cuda/numba/cuda/simulator/cudadrv/nvvm.py +3 -3
  78. numba_cuda/numba/cuda/simulator/cudadrv/runtime.py +3 -3
  79. numba_cuda/numba/cuda/simulator/kernel.py +43 -34
  80. numba_cuda/numba/cuda/simulator/kernelapi.py +31 -26
  81. numba_cuda/numba/cuda/simulator/reduction.py +1 -0
  82. numba_cuda/numba/cuda/simulator/vector_types.py +13 -9
  83. numba_cuda/numba/cuda/simulator_init.py +2 -4
  84. numba_cuda/numba/cuda/stubs.py +134 -108
  85. numba_cuda/numba/cuda/target.py +92 -47
  86. numba_cuda/numba/cuda/testing.py +24 -19
  87. numba_cuda/numba/cuda/tests/__init__.py +14 -12
  88. numba_cuda/numba/cuda/tests/cudadrv/test_array_attr.py +16 -17
  89. numba_cuda/numba/cuda/tests/cudadrv/test_context_stack.py +7 -7
  90. numba_cuda/numba/cuda/tests/cudadrv/test_cuda_array_slicing.py +73 -54
  91. numba_cuda/numba/cuda/tests/cudadrv/test_cuda_auto_context.py +1 -1
  92. numba_cuda/numba/cuda/tests/cudadrv/test_cuda_devicerecord.py +48 -50
  93. numba_cuda/numba/cuda/tests/cudadrv/test_cuda_driver.py +47 -29
  94. numba_cuda/numba/cuda/tests/cudadrv/test_cuda_libraries.py +3 -3
  95. numba_cuda/numba/cuda/tests/cudadrv/test_cuda_memory.py +19 -19
  96. numba_cuda/numba/cuda/tests/cudadrv/test_cuda_ndarray.py +108 -103
  97. numba_cuda/numba/cuda/tests/cudadrv/test_deallocations.py +20 -11
  98. numba_cuda/numba/cuda/tests/cudadrv/test_detect.py +20 -17
  99. numba_cuda/numba/cuda/tests/cudadrv/test_emm_plugins.py +8 -6
  100. numba_cuda/numba/cuda/tests/cudadrv/test_events.py +1 -1
  101. numba_cuda/numba/cuda/tests/cudadrv/test_host_alloc.py +8 -7
  102. numba_cuda/numba/cuda/tests/cudadrv/test_init.py +13 -13
  103. numba_cuda/numba/cuda/tests/cudadrv/test_inline_ptx.py +12 -9
  104. numba_cuda/numba/cuda/tests/cudadrv/test_linker.py +36 -31
  105. numba_cuda/numba/cuda/tests/cudadrv/test_managed_alloc.py +8 -7
  106. numba_cuda/numba/cuda/tests/cudadrv/test_module_callbacks.py +294 -0
  107. numba_cuda/numba/cuda/tests/cudadrv/test_mvc.py +10 -7
  108. numba_cuda/numba/cuda/tests/cudadrv/test_nvjitlink.py +24 -15
  109. numba_cuda/numba/cuda/tests/cudadrv/test_nvvm_driver.py +43 -41
  110. numba_cuda/numba/cuda/tests/cudadrv/test_pinned.py +4 -5
  111. numba_cuda/numba/cuda/tests/cudadrv/test_profiler.py +2 -2
  112. numba_cuda/numba/cuda/tests/cudadrv/test_ptds.py +28 -17
  113. numba_cuda/numba/cuda/tests/cudadrv/test_reset_device.py +1 -2
  114. numba_cuda/numba/cuda/tests/cudadrv/test_runtime.py +22 -14
  115. numba_cuda/numba/cuda/tests/cudadrv/test_select_device.py +1 -1
  116. numba_cuda/numba/cuda/tests/cudadrv/test_streams.py +4 -3
  117. numba_cuda/numba/cuda/tests/cudapy/cache_usecases.py +10 -4
  118. numba_cuda/numba/cuda/tests/cudapy/cache_with_cpu_usecases.py +1 -0
  119. numba_cuda/numba/cuda/tests/cudapy/extensions_usecases.py +10 -7
  120. numba_cuda/numba/cuda/tests/cudapy/jitlink.ptx +0 -2
  121. numba_cuda/numba/cuda/tests/cudapy/recursion_usecases.py +1 -0
  122. numba_cuda/numba/cuda/tests/cudapy/test_alignment.py +6 -5
  123. numba_cuda/numba/cuda/tests/cudapy/test_array.py +52 -42
  124. numba_cuda/numba/cuda/tests/cudapy/test_array_args.py +5 -6
  125. numba_cuda/numba/cuda/tests/cudapy/test_array_methods.py +1 -1
  126. numba_cuda/numba/cuda/tests/cudapy/test_atomics.py +501 -304
  127. numba_cuda/numba/cuda/tests/cudapy/test_bfloat16_bindings.py +257 -0
  128. numba_cuda/numba/cuda/tests/cudapy/test_blackscholes.py +59 -23
  129. numba_cuda/numba/cuda/tests/cudapy/test_boolean.py +3 -3
  130. numba_cuda/numba/cuda/tests/cudapy/test_caching.py +50 -37
  131. numba_cuda/numba/cuda/tests/cudapy/test_casting.py +29 -24
  132. numba_cuda/numba/cuda/tests/cudapy/test_cffi.py +11 -6
  133. numba_cuda/numba/cuda/tests/cudapy/test_compiler.py +84 -50
  134. numba_cuda/numba/cuda/tests/cudapy/test_complex.py +144 -73
  135. numba_cuda/numba/cuda/tests/cudapy/test_complex_kernel.py +2 -2
  136. numba_cuda/numba/cuda/tests/cudapy/test_const_string.py +37 -27
  137. numba_cuda/numba/cuda/tests/cudapy/test_constmem.py +43 -45
  138. numba_cuda/numba/cuda/tests/cudapy/test_cooperative_groups.py +21 -14
  139. numba_cuda/numba/cuda/tests/cudapy/test_cuda_array_interface.py +60 -55
  140. numba_cuda/numba/cuda/tests/cudapy/test_cuda_jit_no_types.py +3 -2
  141. numba_cuda/numba/cuda/tests/cudapy/test_datetime.py +26 -22
  142. numba_cuda/numba/cuda/tests/cudapy/test_debug.py +29 -27
  143. numba_cuda/numba/cuda/tests/cudapy/test_debuginfo.py +77 -28
  144. numba_cuda/numba/cuda/tests/cudapy/test_device_func.py +52 -45
  145. numba_cuda/numba/cuda/tests/cudapy/test_dispatcher.py +55 -43
  146. numba_cuda/numba/cuda/tests/cudapy/test_enums.py +24 -7
  147. numba_cuda/numba/cuda/tests/cudapy/test_errors.py +30 -15
  148. numba_cuda/numba/cuda/tests/cudapy/test_exception.py +11 -12
  149. numba_cuda/numba/cuda/tests/cudapy/test_extending.py +21 -12
  150. numba_cuda/numba/cuda/tests/cudapy/test_fastmath.py +77 -66
  151. numba_cuda/numba/cuda/tests/cudapy/test_forall.py +5 -3
  152. numba_cuda/numba/cuda/tests/cudapy/test_freevar.py +5 -3
  153. numba_cuda/numba/cuda/tests/cudapy/test_frexp_ldexp.py +1 -1
  154. numba_cuda/numba/cuda/tests/cudapy/test_globals.py +3 -5
  155. numba_cuda/numba/cuda/tests/cudapy/test_gufunc.py +144 -126
  156. numba_cuda/numba/cuda/tests/cudapy/test_gufunc_scalar.py +23 -18
  157. numba_cuda/numba/cuda/tests/cudapy/test_gufunc_scheduling.py +16 -22
  158. numba_cuda/numba/cuda/tests/cudapy/test_idiv.py +1 -3
  159. numba_cuda/numba/cuda/tests/cudapy/test_inline.py +59 -0
  160. numba_cuda/numba/cuda/tests/cudapy/test_inspect.py +29 -20
  161. numba_cuda/numba/cuda/tests/cudapy/test_intrinsics.py +147 -99
  162. numba_cuda/numba/cuda/tests/cudapy/test_ipc.py +50 -36
  163. numba_cuda/numba/cuda/tests/cudapy/test_iterators.py +1 -2
  164. numba_cuda/numba/cuda/tests/cudapy/test_lang.py +4 -4
  165. numba_cuda/numba/cuda/tests/cudapy/test_laplace.py +7 -7
  166. numba_cuda/numba/cuda/tests/cudapy/test_libdevice.py +24 -20
  167. numba_cuda/numba/cuda/tests/cudapy/test_lineinfo.py +36 -31
  168. numba_cuda/numba/cuda/tests/cudapy/test_localmem.py +13 -13
  169. numba_cuda/numba/cuda/tests/cudapy/test_mandel.py +13 -6
  170. numba_cuda/numba/cuda/tests/cudapy/test_math.py +83 -66
  171. numba_cuda/numba/cuda/tests/cudapy/test_matmul.py +1 -3
  172. numba_cuda/numba/cuda/tests/cudapy/test_minmax.py +19 -58
  173. numba_cuda/numba/cuda/tests/cudapy/test_montecarlo.py +4 -4
  174. numba_cuda/numba/cuda/tests/cudapy/test_multigpu.py +9 -7
  175. numba_cuda/numba/cuda/tests/cudapy/test_multiprocessing.py +9 -8
  176. numba_cuda/numba/cuda/tests/cudapy/test_multithreads.py +12 -10
  177. numba_cuda/numba/cuda/tests/cudapy/test_nondet.py +1 -1
  178. numba_cuda/numba/cuda/tests/cudapy/test_operator.py +180 -96
  179. numba_cuda/numba/cuda/tests/cudapy/test_optimization.py +5 -5
  180. numba_cuda/numba/cuda/tests/cudapy/test_overload.py +37 -18
  181. numba_cuda/numba/cuda/tests/cudapy/test_powi.py +7 -7
  182. numba_cuda/numba/cuda/tests/cudapy/test_print.py +9 -7
  183. numba_cuda/numba/cuda/tests/cudapy/test_py2_div_issue.py +1 -1
  184. numba_cuda/numba/cuda/tests/cudapy/test_random.py +15 -10
  185. numba_cuda/numba/cuda/tests/cudapy/test_record_dtype.py +88 -87
  186. numba_cuda/numba/cuda/tests/cudapy/test_recursion.py +12 -10
  187. numba_cuda/numba/cuda/tests/cudapy/test_reduction.py +26 -11
  188. numba_cuda/numba/cuda/tests/cudapy/test_retrieve_autoconverted_arrays.py +7 -10
  189. numba_cuda/numba/cuda/tests/cudapy/test_serialize.py +4 -6
  190. numba_cuda/numba/cuda/tests/cudapy/test_slicing.py +1 -1
  191. numba_cuda/numba/cuda/tests/cudapy/test_sm.py +10 -9
  192. numba_cuda/numba/cuda/tests/cudapy/test_sm_creation.py +62 -43
  193. numba_cuda/numba/cuda/tests/cudapy/test_stream_api.py +7 -3
  194. numba_cuda/numba/cuda/tests/cudapy/test_sync.py +7 -5
  195. numba_cuda/numba/cuda/tests/cudapy/test_transpose.py +18 -11
  196. numba_cuda/numba/cuda/tests/cudapy/test_ufuncs.py +111 -88
  197. numba_cuda/numba/cuda/tests/cudapy/test_userexc.py +2 -3
  198. numba_cuda/numba/cuda/tests/cudapy/test_vector_type.py +305 -130
  199. numba_cuda/numba/cuda/tests/cudapy/test_vectorize.py +33 -36
  200. numba_cuda/numba/cuda/tests/cudapy/test_vectorize_complex.py +5 -5
  201. numba_cuda/numba/cuda/tests/cudapy/test_vectorize_decor.py +16 -12
  202. numba_cuda/numba/cuda/tests/cudapy/test_vectorize_device.py +7 -7
  203. numba_cuda/numba/cuda/tests/cudapy/test_vectorize_scalar_arg.py +6 -7
  204. numba_cuda/numba/cuda/tests/cudapy/test_warning.py +31 -29
  205. numba_cuda/numba/cuda/tests/cudapy/test_warp_ops.py +81 -30
  206. numba_cuda/numba/cuda/tests/cudasim/test_cudasim_issues.py +19 -13
  207. numba_cuda/numba/cuda/tests/data/jitlink.cu +1 -1
  208. numba_cuda/numba/cuda/tests/data/jitlink.ptx +0 -2
  209. numba_cuda/numba/cuda/tests/doc_examples/test_cg.py +15 -8
  210. numba_cuda/numba/cuda/tests/doc_examples/test_cpu_gpu_compat.py +4 -7
  211. numba_cuda/numba/cuda/tests/doc_examples/test_ffi.py +14 -9
  212. numba_cuda/numba/cuda/tests/doc_examples/test_laplace.py +22 -18
  213. numba_cuda/numba/cuda/tests/doc_examples/test_matmul.py +7 -4
  214. numba_cuda/numba/cuda/tests/doc_examples/test_montecarlo.py +2 -0
  215. numba_cuda/numba/cuda/tests/doc_examples/test_random.py +8 -4
  216. numba_cuda/numba/cuda/tests/doc_examples/test_reduction.py +2 -1
  217. numba_cuda/numba/cuda/tests/doc_examples/test_sessionize.py +94 -19
  218. numba_cuda/numba/cuda/tests/doc_examples/test_vecadd.py +2 -2
  219. numba_cuda/numba/cuda/tests/nocuda/test_dummyarray.py +91 -62
  220. numba_cuda/numba/cuda/tests/nocuda/test_function_resolution.py +14 -5
  221. numba_cuda/numba/cuda/tests/nocuda/test_import.py +25 -25
  222. numba_cuda/numba/cuda/tests/nocuda/test_library_lookup.py +40 -40
  223. numba_cuda/numba/cuda/tests/nocuda/test_nvvm.py +12 -10
  224. numba_cuda/numba/cuda/tests/nrt/test_nrt.py +16 -20
  225. numba_cuda/numba/cuda/tests/nrt/test_nrt_refct.py +12 -10
  226. numba_cuda/numba/cuda/tests/test_binary_generation/generate_raw_ltoir.py +2 -2
  227. numba_cuda/numba/cuda/types.py +5 -2
  228. numba_cuda/numba/cuda/ufuncs.py +382 -362
  229. numba_cuda/numba/cuda/utils.py +2 -2
  230. numba_cuda/numba/cuda/vector_types.py +5 -3
  231. numba_cuda/numba/cuda/vectorizers.py +38 -33
  232. {numba_cuda-0.8.1.dist-info → numba_cuda-0.10.0.dist-info}/METADATA +1 -1
  233. numba_cuda-0.10.0.dist-info/RECORD +263 -0
  234. {numba_cuda-0.8.1.dist-info → numba_cuda-0.10.0.dist-info}/WHEEL +1 -1
  235. numba_cuda-0.8.1.dist-info/RECORD +0 -251
  236. {numba_cuda-0.8.1.dist-info → numba_cuda-0.10.0.dist-info}/licenses/LICENSE +0 -0
  237. {numba_cuda-0.8.1.dist-info → numba_cuda-0.10.0.dist-info}/top_level.txt +0 -0
@@ -3,13 +3,47 @@ import re
3
3
  import os
4
4
  from collections import namedtuple
5
5
  import platform
6
-
6
+ import site
7
+ from pathlib import Path
7
8
  from numba.core.config import IS_WIN32
8
- from numba.misc.findlib import find_lib, find_file
9
+ from numba.misc.findlib import find_lib
9
10
  from numba import config
11
+ import ctypes
12
+
13
+ _env_path_tuple = namedtuple("_env_path_tuple", ["by", "info"])
14
+
15
+ SEARCH_PRIORITY = [
16
+ "Conda environment",
17
+ "Conda environment (NVIDIA package)",
18
+ "NVIDIA NVCC Wheel",
19
+ "CUDA_HOME",
20
+ "System",
21
+ "Debian package",
22
+ ]
23
+
24
+
25
+ def _priority_index(label):
26
+ if label in SEARCH_PRIORITY:
27
+ return SEARCH_PRIORITY.index(label)
28
+ else:
29
+ raise ValueError(f"Can't determine search priority for {label}")
30
+
10
31
 
32
+ def _find_first_valid_lazy(options):
33
+ sorted_options = sorted(options, key=lambda x: _priority_index(x[0]))
34
+ for label, fn in sorted_options:
35
+ value = fn()
36
+ if value:
37
+ return label, value
38
+ return "<unknown>", None
11
39
 
12
- _env_path_tuple = namedtuple('_env_path_tuple', ['by', 'info'])
40
+
41
+ def _build_options(pairs):
42
+ """Sorts and returns a list of (label, value) tuples according to SEARCH_PRIORITY."""
43
+ priority_index = {label: i for i, label in enumerate(SEARCH_PRIORITY)}
44
+ return sorted(
45
+ pairs, key=lambda pair: priority_index.get(pair[0], float("inf"))
46
+ )
13
47
 
14
48
 
15
49
  def _find_valid_path(options):
@@ -21,83 +55,212 @@ def _find_valid_path(options):
21
55
  if data is not None:
22
56
  return by, data
23
57
  else:
24
- return '<unknown>', None
58
+ return "<unknown>", None
25
59
 
26
60
 
27
61
  def _get_libdevice_path_decision():
28
- options = [
29
- ('Conda environment', get_conda_ctk()),
30
- ('Conda environment (NVIDIA package)', get_nvidia_libdevice_ctk()),
31
- ('CUDA_HOME', get_cuda_home('nvvm', 'libdevice')),
32
- ('System', get_system_ctk('nvvm', 'libdevice')),
33
- ('Debian package', get_debian_pkg_libdevice()),
34
- ]
35
- by, libdir = _find_valid_path(options)
36
- return by, libdir
62
+ options = _build_options(
63
+ [
64
+ ("Conda environment", get_conda_ctk),
65
+ ("Conda environment (NVIDIA package)", get_nvidia_libdevice_ctk),
66
+ ("CUDA_HOME", lambda: get_cuda_home("nvvm", "libdevice")),
67
+ ("NVIDIA NVCC Wheel", get_libdevice_wheel),
68
+ ("System", lambda: get_system_ctk("nvvm", "libdevice")),
69
+ ("Debian package", get_debian_pkg_libdevice),
70
+ ]
71
+ )
72
+ return _find_first_valid_lazy(options)
37
73
 
38
74
 
39
75
  def _nvvm_lib_dir():
40
76
  if IS_WIN32:
41
- return 'nvvm', 'bin'
77
+ return "nvvm", "bin"
42
78
  else:
43
- return 'nvvm', 'lib64'
79
+ return "nvvm", "lib64"
44
80
 
45
81
 
46
82
  def _get_nvvm_path_decision():
47
83
  options = [
48
- ('Conda environment', get_conda_ctk()),
49
- ('Conda environment (NVIDIA package)', get_nvidia_nvvm_ctk()),
50
- ('CUDA_HOME', get_cuda_home(*_nvvm_lib_dir())),
51
- ('System', get_system_ctk(*_nvvm_lib_dir())),
84
+ ("Conda environment", get_conda_ctk),
85
+ ("Conda environment (NVIDIA package)", get_nvidia_nvvm_ctk),
86
+ ("NVIDIA NVCC Wheel", _get_nvvm_wheel),
87
+ ("CUDA_HOME", lambda: get_cuda_home(*_nvvm_lib_dir())),
88
+ ("System", lambda: get_system_ctk(*_nvvm_lib_dir())),
52
89
  ]
53
- by, path = _find_valid_path(options)
54
- return by, path
90
+ return _find_first_valid_lazy(options)
91
+
92
+
93
+ def _get_nvrtc_system_ctk():
94
+ sys_path = get_system_ctk("bin" if IS_WIN32 else "lib64")
95
+ candidates = find_lib("nvrtc", sys_path)
96
+ if candidates:
97
+ return max(candidates)
98
+
99
+
100
+ def _get_nvrtc_path_decision():
101
+ options = _build_options(
102
+ [
103
+ ("CUDA_HOME", lambda: get_cuda_home("nvrtc")),
104
+ ("Conda environment", get_conda_ctk),
105
+ ("Conda environment (NVIDIA package)", get_nvidia_cudalib_ctk),
106
+ ("NVIDIA NVCC Wheel", _get_nvrtc_wheel),
107
+ ("System", _get_nvrtc_system_ctk),
108
+ ]
109
+ )
110
+ return _find_first_valid_lazy(options)
111
+
112
+
113
+ def _get_nvvm_wheel():
114
+ platform_map = {
115
+ "linux": ("lib64", "libnvvm.so"),
116
+ "win32": ("bin", "nvvm64_40_0.dll"),
117
+ }
118
+
119
+ for plat, (dso_dir, dso_path) in platform_map.items():
120
+ if sys.platform.startswith(plat):
121
+ break
122
+ else:
123
+ raise NotImplementedError("Unsupported platform")
124
+
125
+ site_paths = [site.getusersitepackages()] + site.getsitepackages()
126
+
127
+ for sp in filter(None, site_paths):
128
+ nvvm_path = Path(sp, "nvidia", "cuda_nvcc", "nvvm", dso_dir, dso_path)
129
+ if nvvm_path.exists():
130
+ return str(nvvm_path.parent)
131
+
132
+ return None
133
+
134
+
135
+ def get_major_cuda_version():
136
+ # TODO: remove once cuda-python is
137
+ # a hard dependency
138
+ from numba.cuda.cudadrv.runtime import get_version
139
+
140
+ return get_version()[0]
141
+
142
+
143
+ def get_nvrtc_dso_path():
144
+ site_paths = [site.getusersitepackages()] + site.getsitepackages()
145
+ for sp in site_paths:
146
+ lib_dir = os.path.join(
147
+ sp,
148
+ "nvidia",
149
+ "cuda_nvrtc",
150
+ ("bin" if IS_WIN32 else "lib") if sp else None,
151
+ )
152
+ if lib_dir and os.path.exists(lib_dir):
153
+ try:
154
+ major = get_major_cuda_version()
155
+ if major == 11:
156
+ cu_ver = "112" if IS_WIN32 else "11.2"
157
+ elif major == 12:
158
+ cu_ver = "120" if IS_WIN32 else "12"
159
+ else:
160
+ raise NotImplementedError(f"CUDA {major} is not supported")
161
+
162
+ return os.path.join(
163
+ lib_dir,
164
+ f"nvrtc64_{cu_ver}_0.dll"
165
+ if IS_WIN32
166
+ else f"libnvrtc.so.{cu_ver}",
167
+ )
168
+ except RuntimeError:
169
+ continue
170
+
171
+
172
+ def _get_nvrtc_wheel():
173
+ dso_path = get_nvrtc_dso_path()
174
+ if dso_path:
175
+ try:
176
+ result = ctypes.CDLL(dso_path, mode=ctypes.RTLD_GLOBAL)
177
+ except OSError:
178
+ pass
179
+ else:
180
+ if IS_WIN32:
181
+ import win32api
182
+
183
+ # This absolute path will
184
+ # always be correct regardless of the package source
185
+ nvrtc_path = win32api.GetModuleFileNameW(result._handle)
186
+ dso_dir = os.path.dirname(nvrtc_path)
187
+ builtins_path = os.path.join(
188
+ dso_dir,
189
+ [
190
+ f
191
+ for f in os.listdir(dso_dir)
192
+ if re.match("^nvrtc-builtins.*.dll$", f)
193
+ ][0],
194
+ )
195
+ if not os.path.exists(builtins_path):
196
+ raise RuntimeError(
197
+ f'Path does not exist: "{builtins_path}"'
198
+ )
199
+ return Path(dso_path)
55
200
 
56
201
 
57
202
  def _get_libdevice_paths():
58
203
  by, libdir = _get_libdevice_path_decision()
59
- # Search for pattern
60
- pat = r'libdevice(\.\d+)*\.bc$'
61
- candidates = find_file(re.compile(pat), libdir)
62
- # Keep only the max (most recent version) of the bitcode files.
63
- out = max(candidates, default=None)
204
+ if not libdir:
205
+ return _env_path_tuple(by, None)
206
+ out = os.path.join(libdir, "libdevice.10.bc")
64
207
  return _env_path_tuple(by, out)
65
208
 
66
209
 
67
210
  def _cudalib_path():
68
211
  if IS_WIN32:
69
- return 'bin'
212
+ return "bin"
70
213
  else:
71
- return 'lib64'
214
+ return "lib64"
72
215
 
73
216
 
74
217
  def _cuda_home_static_cudalib_path():
75
218
  if IS_WIN32:
76
- return ('lib', 'x64')
219
+ return ("lib", "x64")
77
220
  else:
78
- return ('lib64',)
221
+ return ("lib64",)
222
+
223
+
224
+ def _get_cudalib_wheel():
225
+ """Get the cudalib path from the NVCC wheel."""
226
+ site_paths = [site.getusersitepackages()] + site.getsitepackages()
227
+ libdir = "bin" if IS_WIN32 else "lib"
228
+ for sp in filter(None, site_paths):
229
+ cudalib_path = Path(sp, "nvidia", "cuda_runtime", libdir)
230
+ if cudalib_path.exists():
231
+ return str(cudalib_path)
232
+ return None
79
233
 
80
234
 
81
235
  def _get_cudalib_dir_path_decision():
82
- options = [
83
- ('Conda environment', get_conda_ctk()),
84
- ('Conda environment (NVIDIA package)', get_nvidia_cudalib_ctk()),
85
- ('CUDA_HOME', get_cuda_home(_cudalib_path())),
86
- ('System', get_system_ctk(_cudalib_path())),
87
- ]
88
- by, libdir = _find_valid_path(options)
89
- return by, libdir
236
+ options = _build_options(
237
+ [
238
+ ("Conda environment", get_conda_ctk),
239
+ ("Conda environment (NVIDIA package)", get_nvidia_cudalib_ctk),
240
+ ("NVIDIA NVCC Wheel", _get_cudalib_wheel),
241
+ ("CUDA_HOME", lambda: get_cuda_home(_cudalib_path())),
242
+ ("System", lambda: get_system_ctk(_cudalib_path())),
243
+ ]
244
+ )
245
+ return _find_first_valid_lazy(options)
90
246
 
91
247
 
92
248
  def _get_static_cudalib_dir_path_decision():
93
- options = [
94
- ('Conda environment', get_conda_ctk()),
95
- ('Conda environment (NVIDIA package)', get_nvidia_static_cudalib_ctk()),
96
- ('CUDA_HOME', get_cuda_home(*_cuda_home_static_cudalib_path())),
97
- ('System', get_system_ctk(_cudalib_path())),
98
- ]
99
- by, libdir = _find_valid_path(options)
100
- return by, libdir
249
+ options = _build_options(
250
+ [
251
+ ("Conda environment", get_conda_ctk),
252
+ (
253
+ "Conda environment (NVIDIA package)",
254
+ get_nvidia_static_cudalib_ctk,
255
+ ),
256
+ (
257
+ "CUDA_HOME",
258
+ lambda: get_cuda_home(*_cuda_home_static_cudalib_path()),
259
+ ),
260
+ ("System", lambda: get_system_ctk(_cudalib_path())),
261
+ ]
262
+ )
263
+ return _find_first_valid_lazy(options)
101
264
 
102
265
 
103
266
  def _get_cudalib_dir():
@@ -111,25 +274,23 @@ def _get_static_cudalib_dir():
111
274
 
112
275
 
113
276
  def get_system_ctk(*subdirs):
114
- """Return path to system-wide cudatoolkit; or, None if it doesn't exist.
115
- """
277
+ """Return path to system-wide cudatoolkit; or, None if it doesn't exist."""
116
278
  # Linux?
117
- if sys.platform.startswith('linux'):
279
+ if not IS_WIN32:
118
280
  # Is cuda alias to /usr/local/cuda?
119
281
  # We are intentionally not getting versioned cuda installation.
120
- base = '/usr/local/cuda'
121
- if os.path.exists(base):
122
- return os.path.join(base, *subdirs)
282
+ result = os.path.join("/usr/local/cuda", *subdirs)
283
+ if os.path.exists(result):
284
+ return result
123
285
 
124
286
 
125
287
  def get_conda_ctk():
126
- """Return path to directory containing the shared libraries of cudatoolkit.
127
- """
128
- is_conda_env = os.path.exists(os.path.join(sys.prefix, 'conda-meta'))
288
+ """Return path to directory containing the shared libraries of cudatoolkit."""
289
+ is_conda_env = os.path.exists(os.path.join(sys.prefix, "conda-meta"))
129
290
  if not is_conda_env:
130
291
  return
131
292
  # Assume the existence of NVVM to imply cudatoolkit installed
132
- paths = find_lib('nvvm')
293
+ paths = find_lib("nvvm")
133
294
  if not paths:
134
295
  return
135
296
  # Use the directory name of the max path
@@ -137,9 +298,8 @@ def get_conda_ctk():
137
298
 
138
299
 
139
300
  def get_nvidia_nvvm_ctk():
140
- """Return path to directory containing the NVVM shared library.
141
- """
142
- is_conda_env = os.path.exists(os.path.join(sys.prefix, 'conda-meta'))
301
+ """Return path to directory containing the NVVM shared library."""
302
+ is_conda_env = os.path.exists(os.path.join(sys.prefix, "conda-meta"))
143
303
  if not is_conda_env:
144
304
  return
145
305
 
@@ -147,16 +307,16 @@ def get_nvidia_nvvm_ctk():
147
307
  # conda package is installed.
148
308
 
149
309
  # First, try the location used on Linux and the Windows 11.x packages
150
- libdir = os.path.join(sys.prefix, 'nvvm', _cudalib_path())
310
+ libdir = os.path.join(sys.prefix, "nvvm", _cudalib_path())
151
311
  if not os.path.exists(libdir) or not os.path.isdir(libdir):
152
312
  # If that fails, try the location used for Windows 12.x packages
153
- libdir = os.path.join(sys.prefix, 'Library', 'nvvm', _cudalib_path())
313
+ libdir = os.path.join(sys.prefix, "Library", "nvvm", _cudalib_path())
154
314
  if not os.path.exists(libdir) or not os.path.isdir(libdir):
155
315
  # If that doesn't exist either, assume we don't have the NVIDIA
156
316
  # conda package
157
317
  return
158
318
 
159
- paths = find_lib('nvvm', libdir=libdir)
319
+ paths = find_lib("nvvm", libdir=libdir)
160
320
  if not paths:
161
321
  return
162
322
  # Use the directory name of the max path
@@ -164,39 +324,36 @@ def get_nvidia_nvvm_ctk():
164
324
 
165
325
 
166
326
  def get_nvidia_libdevice_ctk():
167
- """Return path to directory containing the libdevice library.
168
- """
327
+ """Return path to directory containing the libdevice library."""
169
328
  nvvm_ctk = get_nvidia_nvvm_ctk()
170
329
  if not nvvm_ctk:
171
330
  return
172
331
  nvvm_dir = os.path.dirname(nvvm_ctk)
173
- return os.path.join(nvvm_dir, 'libdevice')
332
+ return os.path.join(nvvm_dir, "libdevice")
174
333
 
175
334
 
176
335
  def get_nvidia_cudalib_ctk():
177
- """Return path to directory containing the shared libraries of cudatoolkit.
178
- """
336
+ """Return path to directory containing the shared libraries of cudatoolkit."""
179
337
  nvvm_ctk = get_nvidia_nvvm_ctk()
180
338
  if not nvvm_ctk:
181
339
  return
182
340
  env_dir = os.path.dirname(os.path.dirname(nvvm_ctk))
183
- subdir = 'bin' if IS_WIN32 else 'lib'
341
+ subdir = "bin" if IS_WIN32 else "lib"
184
342
  return os.path.join(env_dir, subdir)
185
343
 
186
344
 
187
345
  def get_nvidia_static_cudalib_ctk():
188
- """Return path to directory containing the static libraries of cudatoolkit.
189
- """
346
+ """Return path to directory containing the static libraries of cudatoolkit."""
190
347
  nvvm_ctk = get_nvidia_nvvm_ctk()
191
348
  if not nvvm_ctk:
192
349
  return
193
350
 
194
351
  if IS_WIN32 and ("Library" not in nvvm_ctk):
195
352
  # Location specific to CUDA 11.x packages on Windows
196
- dirs = ('Lib', 'x64')
353
+ dirs = ("Lib", "x64")
197
354
  else:
198
355
  # Linux, or Windows with CUDA 12.x packages
199
- dirs = ('lib',)
356
+ dirs = ("lib",)
200
357
 
201
358
  env_dir = os.path.dirname(os.path.dirname(nvvm_ctk))
202
359
  return os.path.join(env_dir, *dirs)
@@ -207,18 +364,45 @@ def get_cuda_home(*subdirs):
207
364
  If *subdirs* are the subdirectory name to be appended in the resulting
208
365
  path.
209
366
  """
210
- cuda_home = os.environ.get('CUDA_HOME')
367
+ cuda_home = os.environ.get("CUDA_HOME")
211
368
  if cuda_home is None:
212
369
  # Try Windows CUDA installation without Anaconda
213
- cuda_home = os.environ.get('CUDA_PATH')
370
+ cuda_home = os.environ.get("CUDA_PATH")
214
371
  if cuda_home is not None:
215
372
  return os.path.join(cuda_home, *subdirs)
216
373
 
217
374
 
218
375
  def _get_nvvm_path():
219
376
  by, path = _get_nvvm_path_decision()
220
- candidates = find_lib('nvvm', path)
221
- path = max(candidates) if candidates else None
377
+
378
+ if by == "NVIDIA NVCC Wheel":
379
+ platform_map = {
380
+ "linux": "libnvvm.so",
381
+ "win32": "nvvm64_40_0.dll",
382
+ }
383
+
384
+ for plat, dso_name in platform_map.items():
385
+ if sys.platform.startswith(plat):
386
+ break
387
+ else:
388
+ raise NotImplementedError("Unsupported platform")
389
+
390
+ path = os.path.join(path, dso_name)
391
+ else:
392
+ candidates = find_lib("nvvm", path)
393
+ path = max(candidates) if candidates else None
394
+ return _env_path_tuple(by, path)
395
+
396
+
397
+ def _get_nvrtc_path():
398
+ by, path = _get_nvrtc_path_decision()
399
+ if by == "NVIDIA NVCC Wheel":
400
+ path = str(path)
401
+ elif by == "System":
402
+ return _env_path_tuple(by, path)
403
+ else:
404
+ candidates = find_lib("nvrtc", path)
405
+ path = max(candidates) if candidates else None
222
406
  return _env_path_tuple(by, path)
223
407
 
224
408
 
@@ -234,16 +418,17 @@ def get_cuda_paths():
234
418
  Note: The result of the function is cached.
235
419
  """
236
420
  # Check cache
237
- if hasattr(get_cuda_paths, '_cached_result'):
421
+ if hasattr(get_cuda_paths, "_cached_result"):
238
422
  return get_cuda_paths._cached_result
239
423
  else:
240
424
  # Not in cache
241
425
  d = {
242
- 'nvvm': _get_nvvm_path(),
243
- 'libdevice': _get_libdevice_paths(),
244
- 'cudalib_dir': _get_cudalib_dir(),
245
- 'static_cudalib_dir': _get_static_cudalib_dir(),
246
- 'include_dir': _get_include_dir(),
426
+ "nvvm": _get_nvvm_path(),
427
+ "nvrtc": _get_nvrtc_path(),
428
+ "libdevice": _get_libdevice_paths(),
429
+ "cudalib_dir": _get_cudalib_dir(),
430
+ "static_cudalib_dir": _get_static_cudalib_dir(),
431
+ "include_dir": _get_include_dir(),
247
432
  }
248
433
  # Cache result
249
434
  get_cuda_paths._cached_result = d
@@ -255,12 +440,22 @@ def get_debian_pkg_libdevice():
255
440
  Return the Debian NVIDIA Maintainers-packaged libdevice location, if it
256
441
  exists.
257
442
  """
258
- pkg_libdevice_location = '/usr/lib/nvidia-cuda-toolkit/libdevice'
443
+ pkg_libdevice_location = "/usr/lib/nvidia-cuda-toolkit/libdevice"
259
444
  if not os.path.exists(pkg_libdevice_location):
260
445
  return None
261
446
  return pkg_libdevice_location
262
447
 
263
448
 
449
+ def get_libdevice_wheel():
450
+ nvvm_path = _get_nvvm_wheel()
451
+ if nvvm_path is None:
452
+ return None
453
+ nvvm_path = Path(nvvm_path)
454
+ libdevice_path = nvvm_path.parent / "libdevice"
455
+
456
+ return str(libdevice_path)
457
+
458
+
264
459
  def get_current_cuda_target_name():
265
460
  """Determine conda's CTK target folder based on system and machine arch.
266
461
 
@@ -274,13 +469,10 @@ def get_current_cuda_target_name():
274
469
  machine = platform.machine()
275
470
 
276
471
  if system == "Linux":
277
- arch_to_targets = {
278
- 'x86_64': 'x86_64-linux',
279
- 'aarch64': 'sbsa-linux'
280
- }
472
+ arch_to_targets = {"x86_64": "x86_64-linux", "aarch64": "sbsa-linux"}
281
473
  elif system == "Windows":
282
474
  arch_to_targets = {
283
- 'AMD64': 'x64',
475
+ "AMD64": "x64",
284
476
  }
285
477
  else:
286
478
  arch_to_targets = {}
@@ -293,26 +485,28 @@ def get_conda_include_dir():
293
485
  Return the include directory in the current conda environment, if one
294
486
  is active and it exists.
295
487
  """
296
- is_conda_env = os.path.exists(os.path.join(sys.prefix, 'conda-meta'))
488
+ is_conda_env = os.path.exists(os.path.join(sys.prefix, "conda-meta"))
297
489
  if not is_conda_env:
298
490
  return
299
491
 
300
492
  if platform.system() == "Windows":
301
- include_dir = os.path.join(
302
- sys.prefix, 'Library', 'include'
303
- )
493
+ include_dir = os.path.join(sys.prefix, "Library", "include")
304
494
  elif target_name := get_current_cuda_target_name():
305
495
  include_dir = os.path.join(
306
- sys.prefix, 'targets', target_name, 'include'
496
+ sys.prefix, "targets", target_name, "include"
307
497
  )
308
498
  else:
309
499
  # A fallback when target cannot determined
310
500
  # though usually it shouldn't.
311
- include_dir = os.path.join(sys.prefix, 'include')
501
+ include_dir = os.path.join(sys.prefix, "include")
312
502
 
313
- if (os.path.exists(include_dir) and os.path.isdir(include_dir)
314
- and os.path.exists(os.path.join(include_dir,
315
- 'cuda_device_runtime_api.h'))):
503
+ if (
504
+ os.path.exists(include_dir)
505
+ and os.path.isdir(include_dir)
506
+ and os.path.exists(
507
+ os.path.join(include_dir, "cuda_device_runtime_api.h")
508
+ )
509
+ ):
316
510
  return include_dir
317
511
  return
318
512
 
@@ -320,8 +514,8 @@ def get_conda_include_dir():
320
514
  def _get_include_dir():
321
515
  """Find the root include directory."""
322
516
  options = [
323
- ('Conda environment (NVIDIA package)', get_conda_include_dir()),
324
- ('CUDA_INCLUDE_PATH Config Entry', config.CUDA_INCLUDE_PATH),
517
+ ("Conda environment (NVIDIA package)", get_conda_include_dir()),
518
+ ("CUDA_INCLUDE_PATH Config Entry", config.CUDA_INCLUDE_PATH),
325
519
  # TODO: add others
326
520
  ]
327
521
  by, include_dir = _find_valid_path(options)