numba-cuda 0.15.1__py3-none-any.whl → 0.16.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (44):
  1. numba_cuda/VERSION +1 -1
  2. numba_cuda/numba/cuda/__init__.py +51 -16
  3. numba_cuda/numba/cuda/codegen.py +11 -9
  4. numba_cuda/numba/cuda/compiler.py +3 -39
  5. numba_cuda/numba/cuda/cuda_paths.py +20 -22
  6. numba_cuda/numba/cuda/cudadrv/driver.py +197 -286
  7. numba_cuda/numba/cuda/cudadrv/error.py +4 -0
  8. numba_cuda/numba/cuda/cudadrv/libs.py +1 -1
  9. numba_cuda/numba/cuda/cudadrv/mappings.py +8 -9
  10. numba_cuda/numba/cuda/cudadrv/nvrtc.py +153 -108
  11. numba_cuda/numba/cuda/cudadrv/nvvm.py +1 -197
  12. numba_cuda/numba/cuda/cudadrv/runtime.py +5 -136
  13. numba_cuda/numba/cuda/decorators.py +18 -0
  14. numba_cuda/numba/cuda/dispatcher.py +1 -0
  15. numba_cuda/numba/cuda/flags.py +36 -0
  16. numba_cuda/numba/cuda/memory_management/nrt.py +2 -2
  17. numba_cuda/numba/cuda/simulator/cudadrv/driver.py +6 -2
  18. numba_cuda/numba/cuda/target.py +55 -2
  19. numba_cuda/numba/cuda/testing.py +0 -22
  20. numba_cuda/numba/cuda/tests/__init__.py +0 -2
  21. numba_cuda/numba/cuda/tests/cudadrv/__init__.py +0 -2
  22. numba_cuda/numba/cuda/tests/cudadrv/test_context_stack.py +15 -1
  23. numba_cuda/numba/cuda/tests/cudadrv/test_linker.py +17 -6
  24. numba_cuda/numba/cuda/tests/cudadrv/test_nvjitlink.py +9 -167
  25. numba_cuda/numba/cuda/tests/cudadrv/test_nvrtc.py +27 -0
  26. numba_cuda/numba/cuda/tests/cudadrv/test_nvvm_driver.py +3 -19
  27. numba_cuda/numba/cuda/tests/cudadrv/test_runtime.py +1 -37
  28. numba_cuda/numba/cuda/tests/cudapy/__init__.py +0 -2
  29. numba_cuda/numba/cuda/tests/cudapy/test_compiler.py +1 -1
  30. numba_cuda/numba/cuda/tests/cudapy/test_debuginfo.py +0 -9
  31. numba_cuda/numba/cuda/tests/cudapy/test_errors.py +14 -0
  32. numba_cuda/numba/cuda/tests/cudapy/test_exception.py +0 -6
  33. numba_cuda/numba/cuda/tests/cudapy/test_extending.py +2 -1
  34. numba_cuda/numba/cuda/tests/cudapy/test_fastmath.py +0 -4
  35. numba_cuda/numba/cuda/tests/cudapy/test_lineinfo.py +18 -0
  36. numba_cuda/numba/cuda/tests/cudapy/test_userexc.py +0 -7
  37. numba_cuda/numba/cuda/tests/nocuda/__init__.py +0 -2
  38. numba_cuda/numba/cuda/tests/nrt/__init__.py +0 -2
  39. numba_cuda/numba/cuda/tests/nrt/test_nrt.py +10 -1
  40. {numba_cuda-0.15.1.dist-info → numba_cuda-0.16.0.dist-info}/METADATA +8 -10
  41. {numba_cuda-0.15.1.dist-info → numba_cuda-0.16.0.dist-info}/RECORD +44 -42
  42. {numba_cuda-0.15.1.dist-info → numba_cuda-0.16.0.dist-info}/WHEEL +0 -0
  43. {numba_cuda-0.15.1.dist-info → numba_cuda-0.16.0.dist-info}/licenses/LICENSE +0 -0
  44. {numba_cuda-0.15.1.dist-info → numba_cuda-0.16.0.dist-info}/top_level.txt +0 -0
@@ -1,6 +1,7 @@
1
1
  from ctypes import byref, c_char, c_char_p, c_int, c_size_t, c_void_p, POINTER
2
2
  from enum import IntEnum
3
3
  from numba.cuda.cudadrv.error import (
4
+ CCSupportError,
4
5
  NvrtcError,
5
6
  NvrtcBuiltinOperationFailure,
6
7
  NvrtcCompilationError,
@@ -27,6 +28,9 @@ nvrtc_program = c_void_p
27
28
  # Result code
28
29
  nvrtc_result = c_int
29
30
 
31
+ if config.CUDA_USE_NVIDIA_BINDING:
32
+ from cuda.core.experimental import Program, ProgramOptions
33
+
30
34
 
31
35
  class NvrtcResult(IntEnum):
32
36
  NVRTC_SUCCESS = 0
@@ -76,20 +80,6 @@ class NVRTC:
76
80
  (for Numba) open_cudalib function to load the NVRTC library.
77
81
  """
78
82
 
79
- _CU11_2ONLY_PROTOTYPES = {
80
- # nvrtcResult nvrtcGetNumSupportedArchs(int *numArchs);
81
- "nvrtcGetNumSupportedArchs": (nvrtc_result, POINTER(c_int)),
82
- # nvrtcResult nvrtcGetSupportedArchs(int *supportedArchs);
83
- "nvrtcGetSupportedArchs": (nvrtc_result, POINTER(c_int)),
84
- }
85
-
86
- _CU12ONLY_PROTOTYPES = {
87
- # nvrtcResult nvrtcGetLTOIRSize(nvrtcProgram prog, size_t *ltoSizeRet);
88
- "nvrtcGetLTOIRSize": (nvrtc_result, nvrtc_program, POINTER(c_size_t)),
89
- # nvrtcResult nvrtcGetLTOIR(nvrtcProgram prog, char *lto);
90
- "nvrtcGetLTOIR": (nvrtc_result, nvrtc_program, c_char_p),
91
- }
92
-
93
83
  _PROTOTYPES = {
94
84
  # nvrtcResult nvrtcVersion(int *major, int *minor)
95
85
  "nvrtcVersion": (nvrtc_result, POINTER(c_int), POINTER(c_int)),
@@ -137,6 +127,14 @@ class NVRTC:
137
127
  ),
138
128
  # nvrtcResult nvrtcGetProgramLog(nvrtcProgram prog, char *log);
139
129
  "nvrtcGetProgramLog": (nvrtc_result, nvrtc_program, c_char_p),
130
+ # nvrtcResult nvrtcGetNumSupportedArchs(int *numArchs);
131
+ "nvrtcGetNumSupportedArchs": (nvrtc_result, POINTER(c_int)),
132
+ # nvrtcResult nvrtcGetSupportedArchs(int *supportedArchs);
133
+ "nvrtcGetSupportedArchs": (nvrtc_result, POINTER(c_int)),
134
+ # nvrtcResult nvrtcGetLTOIRSize(nvrtcProgram prog, size_t *ltoSizeRet);
135
+ "nvrtcGetLTOIRSize": (nvrtc_result, nvrtc_program, POINTER(c_size_t)),
136
+ # nvrtcResult nvrtcGetLTOIR(nvrtcProgram prog, char *lto);
137
+ "nvrtcGetLTOIR": (nvrtc_result, nvrtc_program, c_char_p),
140
138
  }
141
139
 
142
140
  # Singleton reference
@@ -154,18 +152,18 @@ class NVRTC:
154
152
  cls.__INSTANCE = None
155
153
  raise NvrtcSupportError("NVRTC cannot be loaded") from e
156
154
 
157
- from numba.cuda.cudadrv.runtime import get_version
158
-
159
- if get_version() >= (11, 2):
160
- inst._PROTOTYPES |= inst._CU11_2ONLY_PROTOTYPES
161
- if get_version() >= (12, 0):
162
- inst._PROTOTYPES |= inst._CU12ONLY_PROTOTYPES
163
-
164
155
  # Find & populate functions
165
156
  for name, proto in inst._PROTOTYPES.items():
166
- func = getattr(lib, name)
167
- func.restype = proto[0]
168
- func.argtypes = proto[1:]
157
+ try:
158
+ func = getattr(lib, name)
159
+ func.restype = proto[0]
160
+ func.argtypes = proto[1:]
161
+ except AttributeError:
162
+ if "LTOIR" in name:
163
+ # CUDA 11 does not have LTOIR functions; ignore
164
+ continue
165
+ else:
166
+ raise
169
167
 
170
168
  @functools.wraps(func)
171
169
  def checked_call(*args, func=func, name=name):
@@ -192,52 +190,16 @@ class NVRTC:
192
190
 
193
191
  return cls.__INSTANCE
194
192
 
193
+ @functools.cache
195
194
  def get_supported_archs(self):
196
195
  """
197
196
  Get Supported Architectures by NVRTC as list of arch tuples.
198
197
  """
199
- ver = self.get_version()
200
- if ver < (11, 0):
201
- raise RuntimeError(
202
- "Unsupported CUDA version. CUDA 11.0 or higher is required."
203
- )
204
- elif ver == (11, 0):
205
- return [
206
- (3, 0),
207
- (3, 2),
208
- (3, 5),
209
- (3, 7),
210
- (5, 0),
211
- (5, 2),
212
- (5, 3),
213
- (6, 0),
214
- (6, 1),
215
- (6, 2),
216
- (7, 0),
217
- (7, 2),
218
- (7, 5),
219
- ]
220
- elif ver == (11, 1):
221
- return [
222
- (3, 5),
223
- (3, 7),
224
- (5, 0),
225
- (5, 2),
226
- (5, 3),
227
- (6, 0),
228
- (6, 1),
229
- (6, 2),
230
- (7, 0),
231
- (7, 2),
232
- (7, 5),
233
- (8, 0),
234
- ]
235
- else:
236
- num = c_int()
237
- self.nvrtcGetNumSupportedArchs(byref(num))
238
- archs = (c_int * num.value)()
239
- self.nvrtcGetSupportedArchs(archs)
240
- return [(archs[i] // 10, archs[i] % 10) for i in range(num.value)]
198
+ num = c_int()
199
+ self.nvrtcGetNumSupportedArchs(byref(num))
200
+ archs = (c_int * num.value)()
201
+ self.nvrtcGetSupportedArchs(archs)
202
+ return [(archs[i] // 10, archs[i] % 10) for i in range(num.value)]
241
203
 
242
204
  def get_version(self):
243
205
  """
@@ -346,9 +308,9 @@ def compile(src, name, cc, ltoir=False):
346
308
 
347
309
  version = nvrtc.get_version()
348
310
  ver_str = lambda v: ".".join(v)
349
- if version < (11, 0):
311
+ if version < (11, 2):
350
312
  raise RuntimeError(
351
- "Unsupported CUDA version. CUDA 11.0 or higher is required."
313
+ "Unsupported CUDA version. CUDA 11.2 or higher is required."
352
314
  )
353
315
  else:
354
316
  supported_arch = nvrtc.get_supported_archs()
@@ -374,10 +336,16 @@ def compile(src, name, cc, ltoir=False):
374
336
  # - Relocatable Device Code (rdc) is needed to prevent device functions
375
337
  # being optimized away.
376
338
  major, minor = found
377
- arch = f"--gpu-architecture=compute_{major}{minor}"
378
339
 
379
- cuda_include = [
380
- f"-I{get_cuda_paths()['include_dir'].info}",
340
+ if config.CUDA_USE_NVIDIA_BINDING:
341
+ arch = f"sm_{major}{minor}"
342
+ else:
343
+ arch = f"--gpu-architecture=compute_{major}{minor}"
344
+
345
+ cuda_include_dir = get_cuda_paths()["include_dir"].info
346
+ cuda_includes = [
347
+ f"{cuda_include_dir}",
348
+ f"{os.path.join(cuda_include_dir, 'cccl')}",
381
349
  ]
382
350
 
383
351
  nvrtc_version = nvrtc.get_version()
@@ -387,54 +355,131 @@ def compile(src, name, cc, ltoir=False):
387
355
  numba_cuda_path = os.path.dirname(cudadrv_path)
388
356
 
389
357
  if nvrtc_ver_major == 11:
390
- numba_include = f"-I{os.path.join(numba_cuda_path, 'include', '11')}"
358
+ numba_include = f"{os.path.join(numba_cuda_path, 'include', '11')}"
391
359
  else:
392
- numba_include = f"-I{os.path.join(numba_cuda_path, 'include', '12')}"
360
+ numba_include = f"{os.path.join(numba_cuda_path, 'include', '12')}"
393
361
 
394
362
  if config.CUDA_NVRTC_EXTRA_SEARCH_PATHS:
395
- extra_search_paths = config.CUDA_NVRTC_EXTRA_SEARCH_PATHS.split(":")
396
- extra_includes = [f"-I{p}" for p in extra_search_paths]
363
+ extra_includes = config.CUDA_NVRTC_EXTRA_SEARCH_PATHS.split(":")
397
364
  else:
398
365
  extra_includes = []
399
366
 
400
- nrt_path = os.path.join(numba_cuda_path, "memory_management")
401
- nrt_include = f"-I{nrt_path}"
402
-
403
- options = [
404
- arch,
405
- numba_include,
406
- *cuda_include,
407
- nrt_include,
408
- *extra_includes,
409
- "-rdc",
410
- "true",
411
- ]
367
+ nrt_include = os.path.join(numba_cuda_path, "memory_management")
368
+
369
+ includes = [numba_include, *cuda_includes, nrt_include, *extra_includes]
370
+
371
+ if config.CUDA_USE_NVIDIA_BINDING:
372
+ options = ProgramOptions(
373
+ arch=arch,
374
+ include_path=includes,
375
+ relocatable_device_code=True,
376
+ std="c++17" if nvrtc_version < (12, 0) else None,
377
+ link_time_optimization=ltoir,
378
+ name=name,
379
+ )
412
380
 
413
- if ltoir:
414
- options.append("-dlto")
381
+ class Logger:
382
+ def __init__(self):
383
+ self.log = []
415
384
 
416
- if nvrtc_version < (12, 0):
417
- options += ["-std=c++17"]
385
+ def write(self, msg):
386
+ self.log.append(msg)
418
387
 
419
- # Compile the program
420
- compile_error = nvrtc.compile_program(program, options)
388
+ logger = Logger()
389
+ if isinstance(src, bytes):
390
+ src = src.decode("utf8")
421
391
 
422
- # Get log from compilation
423
- log = nvrtc.get_compile_log(program)
392
+ prog = Program(src, "c++", options=options)
393
+ result = prog.compile("ltoir" if ltoir else "ptx", logs=logger)
394
+ log = ""
395
+ if logger.log:
396
+ log = logger.log
397
+ joined_logs = "\n".join(log)
398
+ warnings.warn(f"NVRTC log messages: {joined_logs}")
399
+ return result, log
424
400
 
425
- # If the compile failed, provide the log in an exception
426
- if compile_error:
427
- msg = f"NVRTC Compilation failure whilst compiling {name}:\n\n{log}"
428
- raise NvrtcError(msg)
401
+ else:
402
+ includes = [f"-I{path}" for path in includes]
403
+ options = [
404
+ arch,
405
+ *includes,
406
+ "-rdc",
407
+ "true",
408
+ ]
409
+
410
+ if ltoir:
411
+ options.append("-dlto")
412
+
413
+ if nvrtc_version < (12, 0):
414
+ options.append("-std=c++17")
415
+
416
+ # Compile the program
417
+ compile_error = nvrtc.compile_program(program, options)
418
+
419
+ # Get log from compilation
420
+ log = nvrtc.get_compile_log(program)
421
+
422
+ # If the compile failed, provide the log in an exception
423
+ if compile_error:
424
+ msg = f"NVRTC Compilation failure whilst compiling {name}:\n\n{log}"
425
+ raise NvrtcError(msg)
426
+
427
+ # Otherwise, if there's any content in the log, present it as a warning
428
+ if log:
429
+ msg = f"NVRTC log messages whilst compiling {name}:\n\n{log}"
430
+ warnings.warn(msg)
431
+
432
+ if ltoir:
433
+ ltoir = nvrtc.get_lto(program)
434
+ return ltoir, log
435
+ else:
436
+ ptx = nvrtc.get_ptx(program)
437
+ return ptx, log
429
438
 
430
- # Otherwise, if there's any content in the log, present it as a warning
431
- if log:
432
- msg = f"NVRTC log messages whilst compiling {name}:\n\n{log}"
433
- warnings.warn(msg)
434
439
 
435
- if ltoir:
436
- ltoir = nvrtc.get_lto(program)
437
- return ltoir, log
440
+ def find_closest_arch(mycc):
441
+ """
442
+ Given a compute capability, return the closest compute capability supported
443
+ by the CUDA toolkit.
444
+
445
+ :param mycc: Compute capability as a tuple ``(MAJOR, MINOR)``
446
+ :return: Closest supported CC as a tuple ``(MAJOR, MINOR)``
447
+ """
448
+ supported_ccs = get_supported_ccs()
449
+
450
+ for i, cc in enumerate(supported_ccs):
451
+ if cc == mycc:
452
+ # Matches
453
+ return cc
454
+ elif cc > mycc:
455
+ # Exceeded
456
+ if i == 0:
457
+ # CC lower than supported
458
+ msg = (
459
+ "GPU compute capability %d.%d is not supported"
460
+ "(requires >=%d.%d)" % (mycc + cc)
461
+ )
462
+ raise CCSupportError(msg)
463
+ else:
464
+ # return the previous CC
465
+ return supported_ccs[i - 1]
466
+
467
+ # CC higher than supported
468
+ return supported_ccs[-1] # Choose the highest
469
+
470
+
471
+ def get_arch_option(major, minor):
472
+ """Matches with the closest architecture option"""
473
+ if config.FORCE_CUDA_CC:
474
+ arch = config.FORCE_CUDA_CC
438
475
  else:
439
- ptx = nvrtc.get_ptx(program)
440
- return ptx, log
476
+ arch = find_closest_arch((major, minor))
477
+ return "compute_%d%d" % arch
478
+
479
+
480
+ def get_lowest_supported_cc():
481
+ return min(get_supported_ccs())
482
+
483
+
484
+ def get_supported_ccs():
485
+ return NVRTC().get_supported_archs()
@@ -14,7 +14,7 @@ from llvmlite import ir
14
14
 
15
15
  from .error import NvvmError, NvvmSupportError, NvvmWarning
16
16
  from .libs import get_libdevice, open_libdevice, open_cudalib
17
- from numba.core import cgutils, config
17
+ from numba.core import cgutils
18
18
 
19
19
 
20
20
  logger = logging.getLogger(__name__)
@@ -179,7 +179,6 @@ class NVVM(object):
179
179
  self._minorIR = ir_versions[1]
180
180
  self._majorDbg = ir_versions[2]
181
181
  self._minorDbg = ir_versions[3]
182
- self._supported_ccs = get_supported_ccs()
183
182
 
184
183
  @property
185
184
  def data_layout(self):
@@ -188,10 +187,6 @@ class NVVM(object):
188
187
  else:
189
188
  return _datalayout_i128
190
189
 
191
- @property
192
- def supported_ccs(self):
193
- return self._supported_ccs
194
-
195
190
  def get_version(self):
196
191
  major = c_int()
197
192
  minor = c_int()
@@ -350,197 +345,6 @@ class CompilationUnit(object):
350
345
  return ""
351
346
 
352
347
 
353
- COMPUTE_CAPABILITIES = (
354
- (3, 5),
355
- (3, 7),
356
- (5, 0),
357
- (5, 2),
358
- (5, 3),
359
- (6, 0),
360
- (6, 1),
361
- (6, 2),
362
- (7, 0),
363
- (7, 2),
364
- (7, 5),
365
- (8, 0),
366
- (8, 6),
367
- (8, 7),
368
- (8, 9),
369
- (9, 0),
370
- (10, 0),
371
- (10, 1),
372
- (10, 3),
373
- (12, 0),
374
- (12, 1),
375
- )
376
-
377
-
378
- # Maps CTK version -> (min supported cc, max supported cc) ranges, bounds inclusive
379
- _CUDA_CC_MIN_MAX_SUPPORT = {
380
- (11, 2): [
381
- ((3, 5), (8, 6)),
382
- ],
383
- (11, 3): [
384
- ((3, 5), (8, 6)),
385
- ],
386
- (11, 4): [
387
- ((3, 5), (8, 7)),
388
- ],
389
- (11, 5): [
390
- ((3, 5), (8, 7)),
391
- ],
392
- (11, 6): [
393
- ((3, 5), (8, 7)),
394
- ],
395
- (11, 7): [
396
- ((3, 5), (8, 7)),
397
- ],
398
- (11, 8): [
399
- ((3, 5), (9, 0)),
400
- ],
401
- (12, 0): [
402
- ((5, 0), (9, 0)),
403
- ],
404
- (12, 1): [
405
- ((5, 0), (9, 0)),
406
- ],
407
- (12, 2): [
408
- ((5, 0), (9, 0)),
409
- ],
410
- (12, 3): [
411
- ((5, 0), (9, 0)),
412
- ],
413
- (12, 4): [
414
- ((5, 0), (9, 0)),
415
- ],
416
- (12, 5): [
417
- ((5, 0), (9, 0)),
418
- ],
419
- (12, 6): [
420
- ((5, 0), (9, 0)),
421
- ],
422
- (12, 8): [
423
- ((5, 0), (10, 1)),
424
- ((12, 0), (12, 0)),
425
- ],
426
- (12, 9): [
427
- ((5, 0), (12, 1)),
428
- ],
429
- }
430
-
431
- # From CUDA 12.9 Release notes, Section 1.5.4, "Deprecated Architectures"
432
- # https://docs.nvidia.com/cuda/archive/12.9.0/cuda-toolkit-release-notes/index.html#deprecated-architectures
433
- #
434
- # "Maxwell, Pascal, and Volta architectures are now feature-complete with no
435
- # further enhancements planned. While CUDA Toolkit 12.x series will continue
436
- # to support building applications for these architectures, offline
437
- # compilation and library support will be removed in the next major CUDA
438
- # Toolkit version release. Users should plan migration to newer
439
- # architectures, as future toolkits will be unable to target Maxwell, Pascal,
440
- # and Volta GPUs."
441
- #
442
- # In order to maintain compatibility with future toolkits, we use Turing (7.5)
443
- # as the default CC if it is not otherwise specified.
444
- LOWEST_CURRENT_CC = (7, 5)
445
-
446
-
447
- def ccs_supported_by_ctk(ctk_version):
448
- try:
449
- # For supported versions, we look up the range of supported CCs
450
- return tuple(
451
- [
452
- cc
453
- for min_cc, max_cc in _CUDA_CC_MIN_MAX_SUPPORT[ctk_version]
454
- for cc in COMPUTE_CAPABILITIES
455
- if min_cc <= cc <= max_cc
456
- ]
457
- )
458
- except KeyError:
459
- # For unsupported CUDA toolkit versions, all we can do is assume all
460
- # non-deprecated versions we are aware of are supported.
461
- #
462
- # If the user has specified a non-default CC that is greater than the
463
- # lowest non-deprecated one, then we should assume that instead.
464
- MIN_CC = max(config.CUDA_DEFAULT_PTX_CC, LOWEST_CURRENT_CC)
465
-
466
- return tuple([cc for cc in COMPUTE_CAPABILITIES if cc >= MIN_CC])
467
-
468
-
469
- def get_supported_ccs():
470
- try:
471
- from numba.cuda.cudadrv.runtime import runtime
472
-
473
- cudart_version = runtime.get_version()
474
- except: # noqa: E722
475
- # We can't support anything if there's an error getting the runtime
476
- # version (e.g. if it's not present or there's another issue)
477
- _supported_cc = ()
478
- return _supported_cc
479
-
480
- # Ensure the minimum CTK version requirement is met
481
- min_cudart = min(_CUDA_CC_MIN_MAX_SUPPORT)
482
- if cudart_version < min_cudart:
483
- _supported_cc = ()
484
- ctk_ver = f"{cudart_version[0]}.{cudart_version[1]}"
485
- unsupported_ver = (
486
- f"CUDA Toolkit {ctk_ver} is unsupported by Numba - "
487
- f"{min_cudart[0]}.{min_cudart[1]} is the minimum "
488
- "required version."
489
- )
490
- warnings.warn(unsupported_ver)
491
- return _supported_cc
492
-
493
- _supported_cc = ccs_supported_by_ctk(cudart_version)
494
- return _supported_cc
495
-
496
-
497
- def find_closest_arch(mycc):
498
- """
499
- Given a compute capability, return the closest compute capability supported
500
- by the CUDA toolkit.
501
-
502
- :param mycc: Compute capability as a tuple ``(MAJOR, MINOR)``
503
- :return: Closest supported CC as a tuple ``(MAJOR, MINOR)``
504
- """
505
- supported_ccs = NVVM().supported_ccs
506
-
507
- if not supported_ccs:
508
- msg = (
509
- "No supported GPU compute capabilities found. "
510
- "Please check your cudatoolkit version matches your CUDA version."
511
- )
512
- raise NvvmSupportError(msg)
513
-
514
- for i, cc in enumerate(supported_ccs):
515
- if cc == mycc:
516
- # Matches
517
- return cc
518
- elif cc > mycc:
519
- # Exceeded
520
- if i == 0:
521
- # CC lower than supported
522
- msg = (
523
- "GPU compute capability %d.%d is not supported"
524
- "(requires >=%d.%d)" % (mycc + cc)
525
- )
526
- raise NvvmSupportError(msg)
527
- else:
528
- # return the previous CC
529
- return supported_ccs[i - 1]
530
-
531
- # CC higher than supported
532
- return supported_ccs[-1] # Choose the highest
533
-
534
-
535
- def get_arch_option(major, minor):
536
- """Matches with the closest architecture option"""
537
- if config.FORCE_CUDA_CC:
538
- arch = config.FORCE_CUDA_CC
539
- else:
540
- arch = find_closest_arch((major, minor))
541
- return "compute_%d%d" % arch
542
-
543
-
544
348
  MISSING_LIBDEVICE_FILE_MSG = """Missing libdevice file.
545
349
  Please ensure you have a CUDA Toolkit 11.2 or higher.
546
350
  For CUDA 12, ``cuda-nvcc`` and ``cuda-nvrtc`` are required: