numba-cuda 0.19.1__py3-none-any.whl → 0.20.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of numba-cuda might be problematic. Click here for more details.

Files changed (172)
  1. numba_cuda/VERSION +1 -1
  2. numba_cuda/numba/cuda/__init__.py +1 -1
  3. numba_cuda/numba/cuda/_internal/cuda_bf16.py +12706 -1470
  4. numba_cuda/numba/cuda/_internal/cuda_fp16.py +2653 -8769
  5. numba_cuda/numba/cuda/api.py +6 -1
  6. numba_cuda/numba/cuda/bf16.py +285 -2
  7. numba_cuda/numba/cuda/cgutils.py +2 -2
  8. numba_cuda/numba/cuda/cloudpickle/__init__.py +21 -0
  9. numba_cuda/numba/cuda/cloudpickle/cloudpickle.py +1598 -0
  10. numba_cuda/numba/cuda/cloudpickle/cloudpickle_fast.py +17 -0
  11. numba_cuda/numba/cuda/codegen.py +1 -1
  12. numba_cuda/numba/cuda/compiler.py +373 -30
  13. numba_cuda/numba/cuda/core/analysis.py +319 -0
  14. numba_cuda/numba/cuda/core/annotations/__init__.py +0 -0
  15. numba_cuda/numba/cuda/core/annotations/type_annotations.py +304 -0
  16. numba_cuda/numba/cuda/core/base.py +1289 -0
  17. numba_cuda/numba/cuda/core/bytecode.py +727 -0
  18. numba_cuda/numba/cuda/core/caching.py +2 -2
  19. numba_cuda/numba/cuda/core/compiler.py +6 -14
  20. numba_cuda/numba/cuda/core/compiler_machinery.py +497 -0
  21. numba_cuda/numba/cuda/core/config.py +747 -0
  22. numba_cuda/numba/cuda/core/consts.py +124 -0
  23. numba_cuda/numba/cuda/core/cpu.py +370 -0
  24. numba_cuda/numba/cuda/core/environment.py +68 -0
  25. numba_cuda/numba/cuda/core/event.py +511 -0
  26. numba_cuda/numba/cuda/core/funcdesc.py +330 -0
  27. numba_cuda/numba/cuda/core/inline_closurecall.py +1889 -0
  28. numba_cuda/numba/cuda/core/interpreter.py +48 -26
  29. numba_cuda/numba/cuda/core/ir_utils.py +15 -26
  30. numba_cuda/numba/cuda/core/options.py +262 -0
  31. numba_cuda/numba/cuda/core/postproc.py +249 -0
  32. numba_cuda/numba/cuda/core/pythonapi.py +1868 -0
  33. numba_cuda/numba/cuda/core/rewrites/__init__.py +26 -0
  34. numba_cuda/numba/cuda/core/rewrites/ir_print.py +90 -0
  35. numba_cuda/numba/cuda/core/rewrites/registry.py +104 -0
  36. numba_cuda/numba/cuda/core/rewrites/static_binop.py +40 -0
  37. numba_cuda/numba/cuda/core/rewrites/static_getitem.py +187 -0
  38. numba_cuda/numba/cuda/core/rewrites/static_raise.py +98 -0
  39. numba_cuda/numba/cuda/core/ssa.py +496 -0
  40. numba_cuda/numba/cuda/core/targetconfig.py +329 -0
  41. numba_cuda/numba/cuda/core/tracing.py +231 -0
  42. numba_cuda/numba/cuda/core/transforms.py +952 -0
  43. numba_cuda/numba/cuda/core/typed_passes.py +738 -7
  44. numba_cuda/numba/cuda/core/typeinfer.py +1948 -0
  45. numba_cuda/numba/cuda/core/unsafe/__init__.py +0 -0
  46. numba_cuda/numba/cuda/core/unsafe/bytes.py +67 -0
  47. numba_cuda/numba/cuda/core/unsafe/eh.py +66 -0
  48. numba_cuda/numba/cuda/core/unsafe/refcount.py +98 -0
  49. numba_cuda/numba/cuda/core/untyped_passes.py +1983 -0
  50. numba_cuda/numba/cuda/cpython/cmathimpl.py +560 -0
  51. numba_cuda/numba/cuda/cpython/mathimpl.py +499 -0
  52. numba_cuda/numba/cuda/cpython/numbers.py +1474 -0
  53. numba_cuda/numba/cuda/cuda_paths.py +422 -246
  54. numba_cuda/numba/cuda/cudadecl.py +1 -1
  55. numba_cuda/numba/cuda/cudadrv/__init__.py +1 -1
  56. numba_cuda/numba/cuda/cudadrv/devicearray.py +2 -1
  57. numba_cuda/numba/cuda/cudadrv/driver.py +11 -140
  58. numba_cuda/numba/cuda/cudadrv/dummyarray.py +111 -24
  59. numba_cuda/numba/cuda/cudadrv/libs.py +5 -5
  60. numba_cuda/numba/cuda/cudadrv/mappings.py +1 -1
  61. numba_cuda/numba/cuda/cudadrv/nvrtc.py +19 -8
  62. numba_cuda/numba/cuda/cudadrv/nvvm.py +1 -4
  63. numba_cuda/numba/cuda/cudadrv/runtime.py +1 -1
  64. numba_cuda/numba/cuda/cudaimpl.py +5 -1
  65. numba_cuda/numba/cuda/debuginfo.py +85 -2
  66. numba_cuda/numba/cuda/decorators.py +3 -3
  67. numba_cuda/numba/cuda/descriptor.py +3 -4
  68. numba_cuda/numba/cuda/deviceufunc.py +66 -2
  69. numba_cuda/numba/cuda/dispatcher.py +18 -39
  70. numba_cuda/numba/cuda/flags.py +141 -1
  71. numba_cuda/numba/cuda/fp16.py +0 -2
  72. numba_cuda/numba/cuda/include/13/cuda_bf16.h +5118 -0
  73. numba_cuda/numba/cuda/include/13/cuda_bf16.hpp +3865 -0
  74. numba_cuda/numba/cuda/include/13/cuda_fp16.h +5363 -0
  75. numba_cuda/numba/cuda/include/13/cuda_fp16.hpp +3483 -0
  76. numba_cuda/numba/cuda/lowering.py +7 -144
  77. numba_cuda/numba/cuda/mathimpl.py +2 -1
  78. numba_cuda/numba/cuda/memory_management/nrt.py +43 -17
  79. numba_cuda/numba/cuda/misc/findlib.py +75 -0
  80. numba_cuda/numba/cuda/models.py +9 -1
  81. numba_cuda/numba/cuda/np/npdatetime_helpers.py +217 -0
  82. numba_cuda/numba/cuda/np/npyfuncs.py +1807 -0
  83. numba_cuda/numba/cuda/np/numpy_support.py +553 -0
  84. numba_cuda/numba/cuda/np/ufunc/ufuncbuilder.py +59 -0
  85. numba_cuda/numba/cuda/nvvmutils.py +1 -1
  86. numba_cuda/numba/cuda/printimpl.py +12 -1
  87. numba_cuda/numba/cuda/random.py +1 -1
  88. numba_cuda/numba/cuda/serialize.py +1 -1
  89. numba_cuda/numba/cuda/simulator/__init__.py +1 -1
  90. numba_cuda/numba/cuda/simulator/api.py +1 -1
  91. numba_cuda/numba/cuda/simulator/compiler.py +4 -0
  92. numba_cuda/numba/cuda/simulator/cudadrv/devicearray.py +1 -1
  93. numba_cuda/numba/cuda/simulator/kernelapi.py +1 -1
  94. numba_cuda/numba/cuda/simulator/memory_management/nrt.py +14 -2
  95. numba_cuda/numba/cuda/target.py +35 -17
  96. numba_cuda/numba/cuda/testing.py +7 -19
  97. numba_cuda/numba/cuda/tests/__init__.py +1 -1
  98. numba_cuda/numba/cuda/tests/cloudpickle_main_class.py +9 -0
  99. numba_cuda/numba/cuda/tests/core/test_serialize.py +4 -4
  100. numba_cuda/numba/cuda/tests/cudadrv/test_cuda_devicerecord.py +1 -1
  101. numba_cuda/numba/cuda/tests/cudadrv/test_cuda_libraries.py +1 -1
  102. numba_cuda/numba/cuda/tests/cudadrv/test_deallocations.py +1 -1
  103. numba_cuda/numba/cuda/tests/cudadrv/test_detect.py +6 -3
  104. numba_cuda/numba/cuda/tests/cudadrv/test_emm_plugins.py +1 -1
  105. numba_cuda/numba/cuda/tests/cudadrv/test_linker.py +18 -2
  106. numba_cuda/numba/cuda/tests/cudadrv/test_module_callbacks.py +2 -1
  107. numba_cuda/numba/cuda/tests/cudadrv/test_nvjitlink.py +1 -1
  108. numba_cuda/numba/cuda/tests/cudadrv/test_ptds.py +1 -1
  109. numba_cuda/numba/cuda/tests/cudapy/extensions_usecases.py +1 -1
  110. numba_cuda/numba/cuda/tests/cudapy/test_array.py +2 -1
  111. numba_cuda/numba/cuda/tests/cudapy/test_atomics.py +1 -1
  112. numba_cuda/numba/cuda/tests/cudapy/test_bfloat16.py +539 -2
  113. numba_cuda/numba/cuda/tests/cudapy/test_bfloat16_bindings.py +81 -1
  114. numba_cuda/numba/cuda/tests/cudapy/test_caching.py +1 -3
  115. numba_cuda/numba/cuda/tests/cudapy/test_complex.py +1 -1
  116. numba_cuda/numba/cuda/tests/cudapy/test_constmem.py +1 -1
  117. numba_cuda/numba/cuda/tests/cudapy/test_cooperative_groups.py +2 -3
  118. numba_cuda/numba/cuda/tests/cudapy/test_copy_propagate.py +130 -0
  119. numba_cuda/numba/cuda/tests/cudapy/test_datetime.py +1 -1
  120. numba_cuda/numba/cuda/tests/cudapy/test_debug.py +1 -1
  121. numba_cuda/numba/cuda/tests/cudapy/test_debuginfo.py +293 -4
  122. numba_cuda/numba/cuda/tests/cudapy/test_debuginfo_types.py +1 -1
  123. numba_cuda/numba/cuda/tests/cudapy/test_dispatcher.py +1 -1
  124. numba_cuda/numba/cuda/tests/cudapy/test_errors.py +1 -1
  125. numba_cuda/numba/cuda/tests/cudapy/test_exception.py +1 -1
  126. numba_cuda/numba/cuda/tests/cudapy/test_extending.py +2 -1
  127. numba_cuda/numba/cuda/tests/cudapy/test_inline.py +18 -8
  128. numba_cuda/numba/cuda/tests/cudapy/test_intrinsics.py +23 -21
  129. numba_cuda/numba/cuda/tests/cudapy/test_ir_utils.py +10 -37
  130. numba_cuda/numba/cuda/tests/cudapy/test_laplace.py +1 -1
  131. numba_cuda/numba/cuda/tests/cudapy/test_math.py +1 -1
  132. numba_cuda/numba/cuda/tests/cudapy/test_matmul.py +1 -1
  133. numba_cuda/numba/cuda/tests/cudapy/test_operator.py +1 -1
  134. numba_cuda/numba/cuda/tests/cudapy/test_print.py +20 -0
  135. numba_cuda/numba/cuda/tests/cudapy/test_record_dtype.py +1 -1
  136. numba_cuda/numba/cuda/tests/cudapy/test_reduction.py +1 -1
  137. numba_cuda/numba/cuda/tests/cudapy/test_serialize.py +1 -1
  138. numba_cuda/numba/cuda/tests/cudapy/test_sm.py +1 -1
  139. numba_cuda/numba/cuda/tests/cudapy/test_ssa.py +453 -0
  140. numba_cuda/numba/cuda/tests/cudapy/test_sync.py +1 -1
  141. numba_cuda/numba/cuda/tests/cudapy/test_typeinfer.py +538 -0
  142. numba_cuda/numba/cuda/tests/cudapy/test_ufuncs.py +263 -2
  143. numba_cuda/numba/cuda/tests/cudapy/test_userexc.py +1 -1
  144. numba_cuda/numba/cuda/tests/cudapy/test_vector_type.py +1 -1
  145. numba_cuda/numba/cuda/tests/cudapy/test_vectorize_decor.py +112 -6
  146. numba_cuda/numba/cuda/tests/cudapy/test_warning.py +1 -1
  147. numba_cuda/numba/cuda/tests/cudapy/test_warp_ops.py +1 -1
  148. numba_cuda/numba/cuda/tests/doc_examples/test_cg.py +0 -2
  149. numba_cuda/numba/cuda/tests/doc_examples/test_ffi.py +3 -2
  150. numba_cuda/numba/cuda/tests/doc_examples/test_laplace.py +0 -2
  151. numba_cuda/numba/cuda/tests/doc_examples/test_sessionize.py +0 -2
  152. numba_cuda/numba/cuda/tests/nocuda/test_import.py +3 -1
  153. numba_cuda/numba/cuda/tests/nocuda/test_library_lookup.py +24 -12
  154. numba_cuda/numba/cuda/tests/nrt/test_nrt.py +2 -1
  155. numba_cuda/numba/cuda/tests/support.py +55 -15
  156. numba_cuda/numba/cuda/tests/test_tracing.py +200 -0
  157. numba_cuda/numba/cuda/types.py +56 -0
  158. numba_cuda/numba/cuda/typing/__init__.py +9 -1
  159. numba_cuda/numba/cuda/typing/cffi_utils.py +55 -0
  160. numba_cuda/numba/cuda/typing/context.py +751 -0
  161. numba_cuda/numba/cuda/typing/enumdecl.py +74 -0
  162. numba_cuda/numba/cuda/typing/npydecl.py +658 -0
  163. numba_cuda/numba/cuda/typing/templates.py +7 -6
  164. numba_cuda/numba/cuda/ufuncs.py +3 -3
  165. numba_cuda/numba/cuda/utils.py +6 -112
  166. {numba_cuda-0.19.1.dist-info → numba_cuda-0.20.1.dist-info}/METADATA +4 -3
  167. {numba_cuda-0.19.1.dist-info → numba_cuda-0.20.1.dist-info}/RECORD +171 -116
  168. numba_cuda/numba/cuda/tests/cudadrv/test_mvc.py +0 -60
  169. {numba_cuda-0.19.1.dist-info → numba_cuda-0.20.1.dist-info}/WHEEL +0 -0
  170. {numba_cuda-0.19.1.dist-info → numba_cuda-0.20.1.dist-info}/licenses/LICENSE +0 -0
  171. {numba_cuda-0.19.1.dist-info → numba_cuda-0.20.1.dist-info}/licenses/LICENSE.numba +0 -0
  172. {numba_cuda-0.19.1.dist-info → numba_cuda-0.20.1.dist-info}/top_level.txt +0 -0
@@ -2,29 +2,33 @@
2
2
  # SPDX-License-Identifier: BSD-2-Clause
3
3
 
4
4
  import sys
5
- import re
6
5
  import os
7
6
  from collections import namedtuple
8
7
  import platform
9
- import site
10
- from pathlib import Path
11
- from numba.core.config import IS_WIN32
12
- from numba.misc.findlib import find_lib
13
- from numba import config
14
- import ctypes
8
+ import importlib.metadata
9
+ from numba.cuda.core.config import IS_WIN32
10
+ from numba.cuda.misc.findlib import find_lib
11
+ from numba.cuda import config
15
12
 
16
13
  _env_path_tuple = namedtuple("_env_path_tuple", ["by", "info"])
17
14
 
18
15
  SEARCH_PRIORITY = [
19
16
  "Conda environment",
20
- "Conda environment (NVIDIA package)",
21
17
  "NVIDIA NVCC Wheel",
22
18
  "CUDA_HOME",
23
19
  "System",
24
- "Debian package",
25
20
  ]
26
21
 
27
22
 
23
+ def _get_distribution(distribution_name):
24
+ """Get the distribution path using importlib.metadata, returning None if not found."""
25
+ try:
26
+ dist = importlib.metadata.distribution(distribution_name)
27
+ return dist
28
+ except importlib.metadata.PackageNotFoundError:
29
+ return None
30
+
31
+
28
32
  def _priority_index(label):
29
33
  if label in SEARCH_PRIORITY:
30
34
  return SEARCH_PRIORITY.index(label)
@@ -64,182 +68,183 @@ def _find_valid_path(options):
64
68
  def _get_libdevice_path_decision():
65
69
  options = _build_options(
66
70
  [
67
- ("Conda environment", get_conda_ctk),
68
- ("Conda environment (NVIDIA package)", get_nvidia_libdevice_ctk),
69
- ("CUDA_HOME", lambda: get_cuda_home("nvvm", "libdevice")),
70
- ("NVIDIA NVCC Wheel", get_libdevice_wheel),
71
- ("System", lambda: get_system_ctk("nvvm", "libdevice")),
72
- ("Debian package", get_debian_pkg_libdevice),
71
+ ("Conda environment", get_libdevice_conda_path),
72
+ ("NVIDIA NVCC Wheel", get_libdevice_wheel_path),
73
+ (
74
+ "CUDA_HOME",
75
+ lambda: get_cuda_home("nvvm", "libdevice", "libdevice.10.bc"),
76
+ ),
77
+ (
78
+ "System",
79
+ lambda: get_system_ctk("nvvm", "libdevice", "libdevice.10.bc"),
80
+ ),
73
81
  ]
74
82
  )
75
83
  return _find_first_valid_lazy(options)
76
84
 
77
85
 
78
- def _nvvm_lib_dir():
79
- if IS_WIN32:
80
- return "nvvm", "bin"
81
- else:
82
- return "nvvm", "lib64"
83
-
84
-
85
86
  def _get_nvvm_path_decision():
86
- options = [
87
- ("Conda environment", get_conda_ctk),
88
- ("Conda environment (NVIDIA package)", get_nvidia_nvvm_ctk),
89
- ("NVIDIA NVCC Wheel", _get_nvvm_wheel),
90
- ("CUDA_HOME", lambda: get_cuda_home(*_nvvm_lib_dir())),
91
- ("System", lambda: get_system_ctk(*_nvvm_lib_dir())),
92
- ]
87
+ options = _build_options(
88
+ [
89
+ ("Conda environment", _get_nvvm_conda_path),
90
+ ("NVIDIA NVCC Wheel", _get_nvvm_wheel_path),
91
+ ("CUDA_HOME", _get_nvvm_cuda_home_path),
92
+ ("System", _get_nvvm_system_path),
93
+ ]
94
+ )
93
95
  return _find_first_valid_lazy(options)
94
96
 
95
97
 
96
- def _get_nvrtc_system_ctk():
97
- sys_path = get_system_ctk("bin" if IS_WIN32 else "lib64")
98
- candidates = find_lib("nvrtc", sys_path)
99
- if candidates:
100
- return max(candidates)
101
-
102
-
103
98
  def _get_nvrtc_path_decision():
104
99
  options = _build_options(
105
100
  [
106
- ("CUDA_HOME", lambda: get_cuda_home(_cudalib_path())),
107
- ("Conda environment", get_conda_ctk),
108
- ("Conda environment (NVIDIA package)", get_nvidia_cudalib_ctk),
109
- ("NVIDIA NVCC Wheel", _get_nvrtc_wheel),
110
- ("System", _get_nvrtc_system_ctk),
101
+ ("Conda environment", get_conda_ctk_libdir),
102
+ ("NVIDIA NVCC Wheel", _get_nvrtc_wheel_libdir),
103
+ ("CUDA_HOME", get_cuda_home_libdir),
104
+ ("System", get_system_ctk_libdir),
111
105
  ]
112
106
  )
113
107
  return _find_first_valid_lazy(options)
114
108
 
115
109
 
116
- def _get_nvvm_wheel():
117
- platform_map = {
118
- "linux": ("lib64", "libnvvm.so"),
119
- "win32": ("bin", "nvvm64_40_0.dll"),
120
- }
121
-
122
- for plat, (dso_dir, dso_path) in platform_map.items():
123
- if sys.platform.startswith(plat):
124
- break
125
- else:
126
- raise NotImplementedError("Unsupported platform")
127
-
128
- site_paths = [site.getusersitepackages()] + site.getsitepackages()
110
+ def _get_nvvm_wheel_path():
111
+ dso_path = None
112
+ # CUDA 12
113
+ nvcc_distribution = _get_distribution("nvidia-cuda-nvcc-cu12")
114
+ if nvcc_distribution is not None:
115
+ site_packages_path = nvcc_distribution.locate_file("")
116
+ nvvm_lib_dir = os.path.join(
117
+ site_packages_path,
118
+ "nvidia",
119
+ "cuda_nvcc",
120
+ "nvvm",
121
+ "bin" if IS_WIN32 else "lib64",
122
+ )
123
+ dso_path = os.path.join(
124
+ nvvm_lib_dir, "nvvm64_40_0.dll" if IS_WIN32 else "libnvvm.so"
125
+ )
129
126
 
130
- for sp in filter(None, site_paths):
131
- nvvm_path = Path(sp, "nvidia", "cuda_nvcc", "nvvm", dso_dir, dso_path)
132
- if nvvm_path.exists():
133
- return str(nvvm_path.parent)
127
+ # CUDA 13
128
+ if dso_path is None:
129
+ nvcc_distribution = _get_distribution("nvidia-nvvm")
130
+ if (
131
+ nvcc_distribution is not None
132
+ and nvcc_distribution.version.startswith("13.")
133
+ ):
134
+ site_packages_path = nvcc_distribution.locate_file("")
135
+ nvvm_lib_dir = os.path.join(
136
+ site_packages_path,
137
+ "nvidia",
138
+ "cu13",
139
+ "bin" if IS_WIN32 else "lib",
140
+ "x86_64" if IS_WIN32 else "",
141
+ )
142
+ dso_path = os.path.join(
143
+ nvvm_lib_dir, "nvvm64_40_0.dll" if IS_WIN32 else "libnvvm.so.4"
144
+ )
134
145
 
146
+ if dso_path and os.path.isfile(dso_path):
147
+ return dso_path
135
148
  return None
136
149
 
137
150
 
138
- def get_nvrtc_dso_path():
139
- site_paths = [site.getusersitepackages()] + site.getsitepackages()
140
-
141
- for sp in site_paths:
142
- lib_dir = os.path.join(
143
- sp,
151
+ def _get_nvrtc_wheel_libdir():
152
+ dso_path = None
153
+ # CUDA 12
154
+ nvrtc_distribution = _get_distribution("nvidia-cuda-nvrtc-cu12")
155
+ if nvrtc_distribution is not None:
156
+ site_packages_path = nvrtc_distribution.locate_file("")
157
+ nvrtc_lib_dir = os.path.join(
158
+ site_packages_path,
144
159
  "nvidia",
145
160
  "cuda_nvrtc",
146
- ("bin" if IS_WIN32 else "lib") if sp else None,
161
+ "bin" if IS_WIN32 else "lib",
162
+ )
163
+ dso_path = os.path.join(
164
+ nvrtc_lib_dir, "nvrtc64_120_0.dll" if IS_WIN32 else "libnvrtc.so.12"
147
165
  )
148
- if lib_dir and os.path.exists(lib_dir):
149
- chosen_path = None
150
-
151
- # Check for each version of the NVRTC DLL, preferring the most
152
- # recent.
153
- versions = (
154
- "120" if IS_WIN32 else "12",
155
- "130" if IS_WIN32 else "13",
156
- )
157
-
158
- for version in versions:
159
- dso_path = os.path.join(
160
- lib_dir,
161
- f"nvrtc64_{version}_0.dll"
162
- if IS_WIN32
163
- else f"libnvrtc.so.{version}",
164
- )
165
166
 
166
- if os.path.exists(dso_path) and os.path.isfile(dso_path):
167
- chosen_path = dso_path
167
+ # CUDA 13
168
+ if dso_path is None:
169
+ nvrtc_distribution = _get_distribution("nvidia-cuda-nvrtc")
170
+ if (
171
+ nvrtc_distribution is not None
172
+ and nvrtc_distribution.version.startswith("13.")
173
+ ):
174
+ site_packages_path = nvrtc_distribution.locate_file("")
175
+ nvrtc_lib_dir = os.path.join(
176
+ site_packages_path,
177
+ "nvidia",
178
+ "cu13",
179
+ "bin" if IS_WIN32 else "lib",
180
+ "x86_64" if IS_WIN32 else "",
181
+ )
182
+ dso_path = os.path.join(
183
+ nvrtc_lib_dir,
184
+ "nvrtc64_130_0.dll" if IS_WIN32 else "libnvrtc.so.13",
185
+ )
168
186
 
169
- return chosen_path
187
+ if dso_path and os.path.isfile(dso_path):
188
+ return os.path.dirname(dso_path)
189
+ return None
170
190
 
171
191
 
172
- def _get_nvrtc_wheel():
173
- dso_path = get_nvrtc_dso_path()
174
- if dso_path:
175
- try:
176
- result = ctypes.CDLL(dso_path, mode=ctypes.RTLD_GLOBAL)
177
- except OSError:
178
- pass
179
- else:
180
- if IS_WIN32:
181
- import win32api
182
-
183
- # This absolute path will
184
- # always be correct regardless of the package source
185
- nvrtc_path = win32api.GetModuleFileNameW(result._handle)
186
- dso_dir = os.path.dirname(nvrtc_path)
187
- builtins_path = os.path.join(
188
- dso_dir,
189
- [
190
- f
191
- for f in os.listdir(dso_dir)
192
- if re.match("^nvrtc-builtins.*.dll$", f)
193
- ][0],
194
- )
195
- if not os.path.exists(builtins_path):
196
- raise RuntimeError(
197
- f'Path does not exist: "{builtins_path}"'
198
- )
199
- return Path(dso_path)
200
-
201
-
202
- def _get_libdevice_paths():
203
- by, libdir = _get_libdevice_path_decision()
204
- if not libdir:
192
+ def _get_libdevice_path():
193
+ by, out = _get_libdevice_path_decision()
194
+ if not out:
205
195
  return _env_path_tuple(by, None)
206
- out = os.path.join(libdir, "libdevice.10.bc")
207
196
  return _env_path_tuple(by, out)
208
197
 
209
198
 
210
- def _cudalib_path():
199
+ def _cuda_static_libdir():
211
200
  if IS_WIN32:
212
- return "bin"
201
+ return ("lib", "x64")
213
202
  else:
214
- return "lib64"
203
+ return ("lib64",)
215
204
 
216
205
 
217
- def _cuda_home_static_cudalib_path():
218
- if IS_WIN32:
219
- return ("lib", "x64")
206
+ def _get_cudalib_wheel_libdir():
207
+ """Get the cudalib path from the cudart wheel."""
208
+ cuda_module_lib_dir = None
209
+ cuda_runtime_distribution = _get_distribution("nvidia-cuda-runtime-cu12")
210
+ if cuda_runtime_distribution is not None:
211
+ site_packages_path = cuda_runtime_distribution.locate_file("")
212
+ cuda_module_lib_dir = os.path.join(
213
+ site_packages_path,
214
+ "nvidia",
215
+ "cuda_runtime",
216
+ "bin" if IS_WIN32 else "lib",
217
+ )
220
218
  else:
221
- return ("lib64",)
219
+ cuda_runtime_distribution = _get_distribution("nvidia-cuda-runtime")
220
+ if (
221
+ cuda_runtime_distribution is not None
222
+ and cuda_runtime_distribution.version.startswith("13.")
223
+ ):
224
+ site_packages_path = cuda_runtime_distribution.locate_file("")
225
+ cuda_module_lib_dir = os.path.join(
226
+ site_packages_path,
227
+ "nvidia",
228
+ "cu13",
229
+ "bin" if IS_WIN32 else "lib",
230
+ "x86_64" if IS_WIN32 else "",
231
+ )
222
232
 
233
+ if cuda_module_lib_dir is None:
234
+ return None
223
235
 
224
- def _get_cudalib_wheel():
225
- """Get the cudalib path from the NVCC wheel."""
226
- site_paths = [site.getusersitepackages()] + site.getsitepackages()
227
- libdir = "bin" if IS_WIN32 else "lib"
228
- for sp in filter(None, site_paths):
229
- cudalib_path = Path(sp, "nvidia", "cuda_runtime", libdir)
230
- if cudalib_path.exists():
231
- return str(cudalib_path)
236
+ if cuda_module_lib_dir and os.path.isdir(cuda_module_lib_dir):
237
+ return cuda_module_lib_dir
232
238
  return None
233
239
 
234
240
 
235
241
  def _get_cudalib_dir_path_decision():
236
242
  options = _build_options(
237
243
  [
238
- ("Conda environment", get_conda_ctk),
239
- ("Conda environment (NVIDIA package)", get_nvidia_cudalib_ctk),
240
- ("NVIDIA NVCC Wheel", _get_cudalib_wheel),
241
- ("CUDA_HOME", lambda: get_cuda_home(_cudalib_path())),
242
- ("System", lambda: get_system_ctk(_cudalib_path())),
244
+ ("Conda environment", get_conda_ctk_libdir),
245
+ ("NVIDIA NVCC Wheel", _get_cudalib_wheel_libdir),
246
+ ("CUDA_HOME", get_cuda_home_libdir),
247
+ ("System", get_system_ctk_libdir),
243
248
  ]
244
249
  )
245
250
  return _find_first_valid_lazy(options)
@@ -248,16 +253,13 @@ def _get_cudalib_dir_path_decision():
248
253
  def _get_static_cudalib_dir_path_decision():
249
254
  options = _build_options(
250
255
  [
251
- ("Conda environment", get_conda_ctk),
252
- (
253
- "Conda environment (NVIDIA package)",
254
- get_nvidia_static_cudalib_ctk,
255
- ),
256
+ ("Conda environment", get_conda_ctk_libdir),
257
+ ("NVIDIA NVCC Wheel", get_wheel_static_libdir),
256
258
  (
257
259
  "CUDA_HOME",
258
- lambda: get_cuda_home(*_cuda_home_static_cudalib_path()),
260
+ lambda: get_cuda_home(*_cuda_static_libdir()),
259
261
  ),
260
- ("System", lambda: get_system_ctk(_cudalib_path())),
262
+ ("System", lambda: get_system_ctk(*_cuda_static_libdir())),
261
263
  ]
262
264
  )
263
265
  return _find_first_valid_lazy(options)
@@ -282,74 +284,196 @@ def get_system_ctk(*subdirs):
282
284
  result = os.path.join("/usr/local/cuda", *subdirs)
283
285
  if os.path.exists(result):
284
286
  return result
287
+ return None
288
+ return None
289
+
290
+
291
+ def get_system_ctk_libdir():
292
+ """Return path to directory containing the shared libraries of cudatoolkit."""
293
+ system_ctk_dir = get_system_ctk()
294
+ if system_ctk_dir is None:
295
+ return None
296
+ libdir = os.path.join(
297
+ system_ctk_dir,
298
+ "Library" if IS_WIN32 else "lib64",
299
+ "bin" if IS_WIN32 else "",
300
+ )
301
+ # Windows CUDA 13 system CTK uses "bin\x64" directory
302
+ if IS_WIN32 and os.path.isdir(os.path.join(libdir, "x64")):
303
+ libdir = os.path.join(libdir, "x64")
285
304
 
305
+ if libdir and os.path.isdir(libdir):
306
+ return os.path.normpath(libdir)
307
+ return None
308
+
309
+
310
+ def get_system_ctk_include():
311
+ system_ctk_dir = get_system_ctk()
312
+ if system_ctk_dir is None:
313
+ return None
314
+ include_dir = os.path.join(system_ctk_dir, "include")
315
+
316
+ if include_dir and os.path.isdir(include_dir):
317
+ if os.path.isfile(
318
+ os.path.join(include_dir, "cuda_device_runtime_api.h")
319
+ ):
320
+ return include_dir
321
+ return None
322
+
323
+
324
+ def _get_nvvm_system_path():
325
+ nvvm_lib_dir = get_system_ctk("nvvm")
326
+ if nvvm_lib_dir is None:
327
+ return None
328
+ nvvm_lib_dir = os.path.join(nvvm_lib_dir, "bin" if IS_WIN32 else "lib64")
329
+ if IS_WIN32 and os.path.isdir(os.path.join(nvvm_lib_dir, "x64")):
330
+ nvvm_lib_dir = os.path.join(nvvm_lib_dir, "x64")
331
+
332
+ nvvm_path = os.path.join(
333
+ nvvm_lib_dir, "nvvm64_40_0.dll" if IS_WIN32 else "libnvvm.so.4"
334
+ )
335
+ # if os.path.isfile(nvvm_path):
336
+ # return nvvm_path
337
+ return nvvm_path
286
338
 
287
- def get_conda_ctk():
339
+
340
+ def get_conda_ctk_libdir():
288
341
  """Return path to directory containing the shared libraries of cudatoolkit."""
289
- is_conda_env = os.path.exists(os.path.join(sys.prefix, "conda-meta"))
342
+ is_conda_env = os.path.isdir(os.path.join(sys.prefix, "conda-meta"))
290
343
  if not is_conda_env:
291
- return
292
- # Assume the existence of NVVM to imply cudatoolkit installed
293
- paths = find_lib("nvvm")
344
+ return None
345
+ libdir = os.path.join(
346
+ sys.prefix,
347
+ "Library" if IS_WIN32 else "lib",
348
+ "bin" if IS_WIN32 else "",
349
+ )
350
+ # Windows CUDA 13.0.0 uses "bin\x64" directory but 13.0.1+ just uses "bin" directory
351
+ if IS_WIN32 and os.path.isdir(os.path.join(libdir, "x64")):
352
+ libdir = os.path.join(libdir, "x64")
353
+ # Assume the existence of nvrtc to imply needed CTK libraries are installed
354
+ paths = find_lib("nvrtc", libdir)
294
355
  if not paths:
295
- return
356
+ return None
296
357
  # Use the directory name of the max path
297
358
  return os.path.dirname(max(paths))
298
359
 
299
360
 
300
- def get_nvidia_nvvm_ctk():
301
- """Return path to directory containing the NVVM shared library."""
302
- is_conda_env = os.path.exists(os.path.join(sys.prefix, "conda-meta"))
361
+ def get_libdevice_conda_path():
362
+ """Return path to directory containing the libdevice bitcode library."""
363
+ is_conda_env = os.path.isdir(os.path.join(sys.prefix, "conda-meta"))
303
364
  if not is_conda_env:
304
- return
365
+ return None
305
366
 
306
- # Assume the existence of NVVM in the conda env implies that a CUDA toolkit
307
- # conda package is installed.
308
- if IS_WIN32:
309
- # The path used on Windows
310
- libdir = os.path.join(sys.prefix, "Library", "nvvm", _cudalib_path())
311
- else:
312
- # The path used on Linux is different to that on Windows
313
- libdir = os.path.join(sys.prefix, "nvvm", _cudalib_path())
367
+ # Linux: nvvm/libdevice/libdevice.10.bc
368
+ # Windows: Library/nvvm/libdevice/libdevice.10.bc
369
+ libdevice_path = os.path.join(
370
+ sys.prefix,
371
+ "Library" if IS_WIN32 else "",
372
+ "nvvm",
373
+ "libdevice",
374
+ "libdevice.10.bc",
375
+ )
376
+ if os.path.isfile(libdevice_path):
377
+ return libdevice_path
378
+ return None
314
379
 
315
- if not os.path.exists(libdir) or not os.path.isdir(libdir):
316
- # If the path doesn't exist, we didn't find the NVIDIA conda package
317
- return
318
380
 
319
- paths = find_lib("nvvm", libdir=libdir)
320
- if not paths:
321
- return
322
- # Use the directory name of the max path
323
- return os.path.dirname(max(paths))
381
+ def _get_nvvm_conda_path():
382
+ """Return path to directory containing the nvvm library."""
383
+ is_conda_env = os.path.isdir(os.path.join(sys.prefix, "conda-meta"))
384
+ if not is_conda_env:
385
+ return None
386
+ nvvm_dir = os.path.join(
387
+ sys.prefix,
388
+ "Library" if IS_WIN32 else "",
389
+ "nvvm",
390
+ "bin" if IS_WIN32 else "lib64",
391
+ )
392
+ # Windows CUDA 13.0.0 puts in "bin\x64" directory but 13.0.1+ just uses "bin" directory
393
+ if IS_WIN32 and os.path.isdir(os.path.join(nvvm_dir, "x64")):
394
+ nvvm_dir = os.path.join(nvvm_dir, "x64")
324
395
 
396
+ nvvm_path = os.path.join(
397
+ nvvm_dir, "nvvm64_40_0.dll" if IS_WIN32 else "libnvvm.so.4"
398
+ )
399
+ if os.path.isfile(nvvm_path):
400
+ return nvvm_path
401
+ return None
325
402
 
326
- def get_nvidia_libdevice_ctk():
327
- """Return path to directory containing the libdevice library."""
328
- nvvm_ctk = get_nvidia_nvvm_ctk()
329
- if not nvvm_ctk:
330
- return
331
- nvvm_dir = os.path.dirname(nvvm_ctk)
332
- return os.path.join(nvvm_dir, "libdevice")
333
403
 
404
+ def get_wheel_static_libdir():
405
+ cuda_module_static_lib_dir = None
406
+ # CUDA 12
407
+ cuda_runtime_distribution = _get_distribution("nvidia-cuda-runtime-cu12")
408
+ if cuda_runtime_distribution is not None:
409
+ site_packages_path = cuda_runtime_distribution.locate_file("")
410
+ cuda_module_static_lib_dir = os.path.join(
411
+ site_packages_path,
412
+ "nvidia",
413
+ "cuda_runtime",
414
+ "lib",
415
+ "x64" if IS_WIN32 else "",
416
+ )
417
+ else:
418
+ cuda_runtime_distribution = _get_distribution("nvidia-cuda-runtime")
419
+ if (
420
+ cuda_runtime_distribution is not None
421
+ and cuda_runtime_distribution.version.startswith("13.")
422
+ ):
423
+ site_packages_path = cuda_runtime_distribution.locate_file("")
424
+ cuda_module_static_lib_dir = os.path.join(
425
+ site_packages_path,
426
+ "nvidia",
427
+ "cu13",
428
+ "lib",
429
+ "x64" if IS_WIN32 else "",
430
+ )
431
+
432
+ if cuda_module_static_lib_dir is None:
433
+ return None
334
434
 
335
- def get_nvidia_cudalib_ctk():
336
- """Return path to directory containing the shared libraries of cudatoolkit."""
337
- nvvm_ctk = get_nvidia_nvvm_ctk()
338
- if not nvvm_ctk:
339
- return
340
- env_dir = os.path.dirname(os.path.dirname(nvvm_ctk))
341
- subdir = "bin" if IS_WIN32 else "lib"
342
- return os.path.join(env_dir, subdir)
435
+ cudadevrt_path = os.path.join(
436
+ cuda_module_static_lib_dir,
437
+ "cudadevrt.lib" if IS_WIN32 else "libcudadevrt.a",
438
+ )
439
+
440
+ if cudadevrt_path and os.path.isfile(cudadevrt_path):
441
+ return os.path.dirname(cudadevrt_path)
442
+ return None
343
443
 
344
444
 
345
- def get_nvidia_static_cudalib_ctk():
346
- """Return path to directory containing the static libraries of cudatoolkit."""
347
- nvvm_ctk = get_nvidia_nvvm_ctk()
348
- if not nvvm_ctk:
349
- return
445
+ def get_wheel_include():
446
+ cuda_module_include_dir = None
447
+ # CUDA 12
448
+ cuda_runtime_distribution = _get_distribution("nvidia-cuda-runtime-cu12")
449
+ if cuda_runtime_distribution is not None:
450
+ site_packages_path = cuda_runtime_distribution.locate_file("")
451
+ cuda_module_include_dir = os.path.join(
452
+ site_packages_path,
453
+ "nvidia",
454
+ "cuda_runtime",
455
+ "include",
456
+ )
457
+ else:
458
+ cuda_runtime_distribution = _get_distribution("nvidia-cuda-runtime")
459
+ if (
460
+ cuda_runtime_distribution is not None
461
+ and cuda_runtime_distribution.version.startswith("13.")
462
+ ):
463
+ site_packages_path = cuda_runtime_distribution.locate_file("")
464
+ cuda_module_include_dir = os.path.join(
465
+ site_packages_path,
466
+ "nvidia",
467
+ "cu13",
468
+ "include",
469
+ )
350
470
 
351
- env_dir = os.path.dirname(os.path.dirname(nvvm_ctk))
352
- return os.path.join(env_dir, "lib")
471
+ if cuda_module_include_dir and os.path.isdir(cuda_module_include_dir):
472
+ if os.path.isfile(
473
+ os.path.join(cuda_module_include_dir, "cuda_device_runtime_api.h")
474
+ ):
475
+ return cuda_module_include_dir
476
+ return None
353
477
 
354
478
 
355
479
  def get_cuda_home(*subdirs):
@@ -363,39 +487,74 @@ def get_cuda_home(*subdirs):
363
487
  cuda_home = os.environ.get("CUDA_PATH")
364
488
  if cuda_home is not None:
365
489
  return os.path.join(cuda_home, *subdirs)
490
+ return None
366
491
 
367
492
 
368
- def _get_nvvm_path():
369
- by, path = _get_nvvm_path_decision()
493
+ def get_cuda_home_libdir():
494
+ """Return path to directory containing the shared libraries of cudatoolkit."""
495
+ cuda_home_dir = get_cuda_home()
496
+ if cuda_home_dir is None:
497
+ return None
498
+ libdir = os.path.join(
499
+ cuda_home_dir,
500
+ "Library" if IS_WIN32 else "lib64",
501
+ "bin" if IS_WIN32 else "",
502
+ )
503
+ # Windows CUDA 13 system CTK uses "bin\x64" directory while conda just uses "bin" directory
504
+ if IS_WIN32 and os.path.isdir(os.path.join(libdir, "x64")):
505
+ libdir = os.path.join(libdir, "x64")
506
+ return os.path.normpath(libdir)
370
507
 
371
- if by == "NVIDIA NVCC Wheel":
372
- platform_map = {
373
- "linux": "libnvvm.so",
374
- "win32": "nvvm64_40_0.dll",
375
- }
376
508
 
377
- for plat, dso_name in platform_map.items():
378
- if sys.platform.startswith(plat):
379
- break
509
+ def get_cuda_home_include():
510
+ cuda_home_dir = get_cuda_home()
511
+ if cuda_home_dir is None:
512
+ return None
513
+ include_dir = cuda_home_dir
514
+ # For Windows, CTK puts it in $CTK/include but conda puts it in $CTK/Library/include
515
+ if IS_WIN32:
516
+ if os.path.isdir(os.path.join(include_dir, "Library")):
517
+ include_dir = os.path.join(include_dir, "Library", "include")
380
518
  else:
381
- raise NotImplementedError("Unsupported platform")
382
-
383
- path = os.path.join(path, dso_name)
519
+ include_dir = os.path.join(include_dir, "include")
384
520
  else:
385
- candidates = find_lib("nvvm", path)
386
- path = max(candidates) if candidates else None
387
- return _env_path_tuple(by, path)
521
+ include_dir = os.path.join(include_dir, "include")
522
+
523
+ if include_dir and os.path.isdir(include_dir):
524
+ if os.path.isfile(
525
+ os.path.join(include_dir, "cuda_device_runtime_api.h")
526
+ ):
527
+ return include_dir
528
+ return None
529
+
530
+
531
+ def _get_nvvm_cuda_home_path():
532
+ nvvm_lib_dir = get_cuda_home("nvvm")
533
+ if nvvm_lib_dir is None:
534
+ return
535
+ nvvm_lib_dir = os.path.join(nvvm_lib_dir, "bin" if IS_WIN32 else "lib64")
536
+ if IS_WIN32 and os.path.isdir(os.path.join(nvvm_lib_dir, "x64")):
537
+ nvvm_lib_dir = os.path.join(nvvm_lib_dir, "x64")
538
+
539
+ nvvm_path = os.path.join(
540
+ nvvm_lib_dir, "nvvm64_40_0.dll" if IS_WIN32 else "libnvvm.so.4"
541
+ )
542
+ # if os.path.isfile(nvvm_path):
543
+ # return nvvm_path
544
+ return nvvm_path
545
+
546
+
547
+ def _get_nvvm_path():
548
+ by, out = _get_nvvm_path_decision()
549
+ if not out:
550
+ return _env_path_tuple(by, None)
551
+ return _env_path_tuple(by, out)
388
552
 
389
553
 
390
554
  def _get_nvrtc_path():
391
555
  by, path = _get_nvrtc_path_decision()
392
- if by == "NVIDIA NVCC Wheel":
393
- path = str(path)
394
- elif by == "System":
395
- return _env_path_tuple(by, path)
396
- else:
397
- candidates = find_lib("nvrtc", path)
398
- path = max(candidates) if candidates else None
556
+ candidates = find_lib("nvrtc", libdir=path)
557
+ path = max(candidates) if candidates else None
399
558
  return _env_path_tuple(by, path)
400
559
 
401
560
 
@@ -405,8 +564,11 @@ def get_cuda_paths():
405
564
 
406
565
  The returned dictionary will have the following keys and infos:
407
566
  - "nvvm": file_path
408
- - "libdevice": List[Tuple[arch, file_path]]
567
+ - "nvrtc": file_path
568
+ - "libdevice": file_path
409
569
  - "cudalib_dir": directory_path
570
+ - "static_cudalib_dir": directory_path
571
+ - "include_dir": directory_path
410
572
 
411
573
  Note: The result of the function is cached.
412
574
  """
@@ -418,7 +580,7 @@ def get_cuda_paths():
418
580
  d = {
419
581
  "nvvm": _get_nvvm_path(),
420
582
  "nvrtc": _get_nvrtc_path(),
421
- "libdevice": _get_libdevice_paths(),
583
+ "libdevice": _get_libdevice_path(),
422
584
  "cudalib_dir": _get_cudalib_dir(),
423
585
  "static_cudalib_dir": _get_static_cudalib_dir(),
424
586
  "include_dir": _get_include_dir(),
@@ -428,25 +590,41 @@ def get_cuda_paths():
428
590
  return d
429
591
 
430
592
 
431
- def get_debian_pkg_libdevice():
432
- """
433
- Return the Debian NVIDIA Maintainers-packaged libdevice location, if it
434
- exists.
435
- """
436
- pkg_libdevice_location = "/usr/lib/nvidia-cuda-toolkit/libdevice"
437
- if not os.path.exists(pkg_libdevice_location):
438
- return None
439
- return pkg_libdevice_location
440
-
593
+ def get_libdevice_wheel_path():
594
+ libdevice_path = None
595
+ # CUDA 12
596
+ nvvm_distribution = _get_distribution("nvidia-cuda-nvcc-cu12")
597
+ if nvvm_distribution is not None:
598
+ site_packages_path = nvvm_distribution.locate_file("")
599
+ libdevice_path = os.path.join(
600
+ site_packages_path,
601
+ "nvidia",
602
+ "cuda_nvcc",
603
+ "nvvm",
604
+ "libdevice",
605
+ "libdevice.10.bc",
606
+ )
441
607
 
442
- def get_libdevice_wheel():
443
- nvvm_path = _get_nvvm_wheel()
444
- if nvvm_path is None:
445
- return None
446
- nvvm_path = Path(nvvm_path)
447
- libdevice_path = nvvm_path.parent / "libdevice"
608
+ # CUDA 13
609
+ if libdevice_path is None:
610
+ nvvm_distribution = _get_distribution("nvidia-nvvm")
611
+ if (
612
+ nvvm_distribution is not None
613
+ and nvvm_distribution.version.startswith("13.")
614
+ ):
615
+ site_packages_path = nvvm_distribution.locate_file("")
616
+ libdevice_path = os.path.join(
617
+ site_packages_path,
618
+ "nvidia",
619
+ "cu13",
620
+ "nvvm",
621
+ "libdevice",
622
+ "libdevice.10.bc",
623
+ )
448
624
 
449
- return str(libdevice_path)
625
+ if libdevice_path and os.path.isfile(libdevice_path):
626
+ return libdevice_path
627
+ return None
450
628
 
451
629
 
452
630
  def get_current_cuda_target_name():
@@ -478,11 +656,11 @@ def get_conda_include_dir():
478
656
  Return the include directory in the current conda environment, if one
479
657
  is active and it exists.
480
658
  """
481
- is_conda_env = os.path.exists(os.path.join(sys.prefix, "conda-meta"))
659
+ is_conda_env = os.path.isdir(os.path.join(sys.prefix, "conda-meta"))
482
660
  if not is_conda_env:
483
661
  return
484
662
 
485
- if platform.system() == "Windows":
663
+ if IS_WIN32:
486
664
  include_dir = os.path.join(sys.prefix, "Library", "include")
487
665
  elif target_name := get_current_cuda_target_name():
488
666
  include_dir = os.path.join(
@@ -493,23 +671,21 @@ def get_conda_include_dir():
493
671
  # though usually it shouldn't.
494
672
  include_dir = os.path.join(sys.prefix, "include")
495
673
 
496
- if (
497
- os.path.exists(include_dir)
498
- and os.path.isdir(include_dir)
499
- and os.path.exists(
500
- os.path.join(include_dir, "cuda_device_runtime_api.h")
501
- )
674
+ if os.path.isdir(include_dir) and os.path.isfile(
675
+ os.path.join(include_dir, "cuda_device_runtime_api.h")
502
676
  ):
503
677
  return include_dir
504
- return
678
+ return None
505
679
 
506
680
 
507
681
  def _get_include_dir():
508
682
  """Find the root include directory."""
509
683
  options = [
510
684
  ("Conda environment (NVIDIA package)", get_conda_include_dir()),
685
+ ("NVIDIA NVCC Wheel", get_wheel_include()),
686
+ ("CUDA_HOME", get_cuda_home_include()),
687
+ ("System", get_system_ctk_include()),
511
688
  ("CUDA_INCLUDE_PATH Config Entry", config.CUDA_INCLUDE_PATH),
512
- # TODO: add others
513
689
  ]
514
690
  by, include_dir = _find_valid_path(options)
515
691
  return _env_path_tuple(by, include_dir)