numba-cuda 0.15.2__py3-none-any.whl → 0.17.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (49) hide show
  1. numba_cuda/VERSION +1 -1
  2. numba_cuda/numba/cuda/__init__.py +35 -1
  3. numba_cuda/numba/cuda/codegen.py +11 -9
  4. numba_cuda/numba/cuda/compiler.py +3 -39
  5. numba_cuda/numba/cuda/cuda_paths.py +21 -23
  6. numba_cuda/numba/cuda/cudadrv/devices.py +4 -6
  7. numba_cuda/numba/cuda/cudadrv/driver.py +300 -335
  8. numba_cuda/numba/cuda/cudadrv/error.py +4 -0
  9. numba_cuda/numba/cuda/cudadrv/libs.py +1 -1
  10. numba_cuda/numba/cuda/cudadrv/mappings.py +8 -9
  11. numba_cuda/numba/cuda/cudadrv/nvrtc.py +153 -108
  12. numba_cuda/numba/cuda/cudadrv/nvvm.py +1 -197
  13. numba_cuda/numba/cuda/cudadrv/runtime.py +5 -136
  14. numba_cuda/numba/cuda/decorators.py +18 -0
  15. numba_cuda/numba/cuda/dispatcher.py +3 -3
  16. numba_cuda/numba/cuda/flags.py +36 -0
  17. numba_cuda/numba/cuda/memory_management/nrt.py +3 -3
  18. numba_cuda/numba/cuda/simulator/cudadrv/driver.py +6 -2
  19. numba_cuda/numba/cuda/target.py +55 -2
  20. numba_cuda/numba/cuda/testing.py +0 -22
  21. numba_cuda/numba/cuda/tests/__init__.py +0 -2
  22. numba_cuda/numba/cuda/tests/cudadrv/__init__.py +0 -2
  23. numba_cuda/numba/cuda/tests/cudadrv/test_context_stack.py +31 -6
  24. numba_cuda/numba/cuda/tests/cudadrv/test_cuda_driver.py +5 -1
  25. numba_cuda/numba/cuda/tests/cudadrv/test_emm_plugins.py +1 -1
  26. numba_cuda/numba/cuda/tests/cudadrv/test_linker.py +17 -6
  27. numba_cuda/numba/cuda/tests/cudadrv/test_nvjitlink.py +9 -167
  28. numba_cuda/numba/cuda/tests/cudadrv/test_nvrtc.py +27 -0
  29. numba_cuda/numba/cuda/tests/cudadrv/test_nvvm_driver.py +3 -19
  30. numba_cuda/numba/cuda/tests/cudadrv/test_runtime.py +1 -37
  31. numba_cuda/numba/cuda/tests/cudapy/__init__.py +0 -2
  32. numba_cuda/numba/cuda/tests/cudapy/test_compiler.py +1 -1
  33. numba_cuda/numba/cuda/tests/cudapy/test_cuda_array_interface.py +1 -5
  34. numba_cuda/numba/cuda/tests/cudapy/test_debuginfo.py +0 -9
  35. numba_cuda/numba/cuda/tests/cudapy/test_errors.py +14 -0
  36. numba_cuda/numba/cuda/tests/cudapy/test_exception.py +0 -6
  37. numba_cuda/numba/cuda/tests/cudapy/test_extending.py +2 -1
  38. numba_cuda/numba/cuda/tests/cudapy/test_fastmath.py +0 -4
  39. numba_cuda/numba/cuda/tests/cudapy/test_lineinfo.py +18 -0
  40. numba_cuda/numba/cuda/tests/cudapy/test_stream_api.py +2 -5
  41. numba_cuda/numba/cuda/tests/cudapy/test_userexc.py +0 -7
  42. numba_cuda/numba/cuda/tests/nocuda/__init__.py +0 -2
  43. numba_cuda/numba/cuda/tests/nrt/__init__.py +0 -2
  44. numba_cuda/numba/cuda/tests/nrt/test_nrt.py +10 -1
  45. {numba_cuda-0.15.2.dist-info → numba_cuda-0.17.0.dist-info}/METADATA +5 -2
  46. {numba_cuda-0.15.2.dist-info → numba_cuda-0.17.0.dist-info}/RECORD +49 -47
  47. {numba_cuda-0.15.2.dist-info → numba_cuda-0.17.0.dist-info}/WHEEL +0 -0
  48. {numba_cuda-0.15.2.dist-info → numba_cuda-0.17.0.dist-info}/licenses/LICENSE +0 -0
  49. {numba_cuda-0.15.2.dist-info → numba_cuda-0.17.0.dist-info}/top_level.txt +0 -0
@@ -42,6 +42,7 @@ import importlib
42
42
  import numpy as np
43
43
  from collections import namedtuple, deque
44
44
 
45
+
45
46
  from numba import mviewbuf
46
47
  from numba.core import utils, serialize, config
47
48
  from .error import CudaSupportError, CudaDriverError
@@ -58,6 +59,22 @@ except ImportError:
58
59
  NvJitLinker, NvJitLinkError = None, None
59
60
 
60
61
 
62
+ USE_NV_BINDING = config.CUDA_USE_NVIDIA_BINDING
63
+
64
+ if USE_NV_BINDING:
65
+ from cuda.bindings import driver as binding
66
+ from cuda.core.experimental import (
67
+ Linker,
68
+ LinkerOptions,
69
+ ObjectCode,
70
+ )
71
+
72
+ # There is no definition of the default stream in the Nvidia bindings (nor
73
+ # is there at the C/C++ level), so we define it here so we don't need to
74
+ # use a magic number 0 in places where we want the default stream.
75
+ CU_STREAM_DEFAULT = 0
76
+
77
+
61
78
  MIN_REQUIRED_CC = (3, 5)
62
79
  SUPPORTS_IPC = sys.platform.startswith("linux")
63
80
 
@@ -108,6 +125,25 @@ def make_logger():
108
125
  return logger
109
126
 
110
127
 
128
+ @functools.cache
129
+ def _have_nvjitlink():
130
+ if not USE_NV_BINDING:
131
+ return False
132
+ try:
133
+ from cuda.bindings._internal import nvjitlink as nvjitlink_internal
134
+ from cuda.bindings._internal.utils import NotSupportedError
135
+ except ImportError:
136
+ return False
137
+ try:
138
+ return (
139
+ nvjitlink_internal._inspect_function_pointer("__nvJitLinkVersion")
140
+ != 0
141
+ )
142
+ except NotSupportedError:
143
+ # no driver
144
+ return False
145
+
146
+
111
147
  class DeadMemoryError(RuntimeError):
112
148
  pass
113
149
 
@@ -454,11 +490,11 @@ class Driver(object):
454
490
  with self.get_active_context() as ac:
455
491
  if ac.devnum is not None:
456
492
  if USE_NV_BINDING:
457
- return driver.cuCtxPopCurrent()
493
+ popped = drvapi.cu_context(int(driver.cuCtxPopCurrent()))
458
494
  else:
459
495
  popped = drvapi.cu_context()
460
496
  driver.cuCtxPopCurrent(byref(popped))
461
- return popped
497
+ return popped
462
498
 
463
499
  def get_active_context(self):
464
500
  """Returns an instance of ``_ActiveContext``."""
@@ -502,6 +538,8 @@ class _ActiveContext(object):
502
538
  hctx = driver.cuCtxGetCurrent()
503
539
  if int(hctx) == 0:
504
540
  hctx = None
541
+ else:
542
+ hctx = drvapi.cu_context(int(hctx))
505
543
  else:
506
544
  hctx = drvapi.cu_context(0)
507
545
  driver.cuCtxGetCurrent(byref(hctx))
@@ -680,6 +718,7 @@ class Device(object):
680
718
  # create primary context
681
719
  if USE_NV_BINDING:
682
720
  hctx = driver.cuDevicePrimaryCtxRetain(self.id)
721
+ hctx = drvapi.cu_context(int(hctx))
683
722
  else:
684
723
  hctx = drvapi.cu_context()
685
724
  driver.cuDevicePrimaryCtxRetain(byref(hctx), self.id)
@@ -1218,6 +1257,7 @@ class _PendingDeallocs(object):
1218
1257
  [dtor, handle, size] = self._cons.popleft()
1219
1258
  _logger.info("dealloc: %s %s bytes", dtor.__name__, size)
1220
1259
  dtor(handle)
1260
+
1221
1261
  self._size = 0
1222
1262
 
1223
1263
  @contextlib.contextmanager
@@ -1394,7 +1434,10 @@ class Context(object):
1394
1434
  """
1395
1435
  Pushes this context on the current CPU Thread.
1396
1436
  """
1397
- driver.cuCtxPushCurrent(self.handle)
1437
+ if USE_NV_BINDING:
1438
+ driver.cuCtxPushCurrent(self.handle.value)
1439
+ else:
1440
+ driver.cuCtxPushCurrent(self.handle)
1398
1441
  self.prepare_for_use()
1399
1442
 
1400
1443
  def pop(self):
@@ -1403,10 +1446,7 @@ class Context(object):
1403
1446
  must be at the top of the context stack, otherwise an error will occur.
1404
1447
  """
1405
1448
  popped = driver.pop_active_context()
1406
- if USE_NV_BINDING:
1407
- assert int(popped) == int(self.handle)
1408
- else:
1409
- assert popped.value == self.handle.value
1449
+ assert popped.value == self.handle.value
1410
1450
 
1411
1451
  def memalloc(self, bytesize):
1412
1452
  return self.memory_manager.memalloc(bytesize)
@@ -1472,7 +1512,7 @@ class Context(object):
1472
1512
  if isinstance(ptx, str):
1473
1513
  ptx = ptx.encode("utf8")
1474
1514
  if USE_NV_BINDING:
1475
- image = ptx
1515
+ image = ObjectCode.from_ptx(ptx)
1476
1516
  else:
1477
1517
  image = c_char_p(ptx)
1478
1518
  return self.create_module_image(image)
@@ -1499,21 +1539,25 @@ class Context(object):
1499
1539
 
1500
1540
  def get_default_stream(self):
1501
1541
  if USE_NV_BINDING:
1502
- handle = binding.CUstream(CU_STREAM_DEFAULT)
1542
+ handle = drvapi.cu_stream(int(binding.CUstream(CU_STREAM_DEFAULT)))
1503
1543
  else:
1504
1544
  handle = drvapi.cu_stream(drvapi.CU_STREAM_DEFAULT)
1505
1545
  return Stream(weakref.proxy(self), handle, None)
1506
1546
 
1507
1547
  def get_legacy_default_stream(self):
1508
1548
  if USE_NV_BINDING:
1509
- handle = binding.CUstream(binding.CU_STREAM_LEGACY)
1549
+ handle = drvapi.cu_stream(
1550
+ int(binding.CUstream(binding.CU_STREAM_LEGACY))
1551
+ )
1510
1552
  else:
1511
1553
  handle = drvapi.cu_stream(drvapi.CU_STREAM_LEGACY)
1512
1554
  return Stream(weakref.proxy(self), handle, None)
1513
1555
 
1514
1556
  def get_per_thread_default_stream(self):
1515
1557
  if USE_NV_BINDING:
1516
- handle = binding.CUstream(binding.CU_STREAM_PER_THREAD)
1558
+ handle = drvapi.cu_stream(
1559
+ int(binding.CUstream(binding.CU_STREAM_PER_THREAD))
1560
+ )
1517
1561
  else:
1518
1562
  handle = drvapi.cu_stream(drvapi.CU_STREAM_PER_THREAD)
1519
1563
  return Stream(weakref.proxy(self), handle, None)
@@ -1525,7 +1569,7 @@ class Context(object):
1525
1569
  # default stream, which we define also as CU_STREAM_DEFAULT when
1526
1570
  # the NV binding is in use).
1527
1571
  flags = binding.CUstream_flags.CU_STREAM_DEFAULT.value
1528
- handle = driver.cuStreamCreate(flags)
1572
+ handle = drvapi.cu_stream(int(driver.cuStreamCreate(flags)))
1529
1573
  else:
1530
1574
  handle = drvapi.cu_stream()
1531
1575
  driver.cuStreamCreate(byref(handle), 0)
@@ -1539,7 +1583,7 @@ class Context(object):
1539
1583
  if not isinstance(ptr, int):
1540
1584
  raise TypeError("ptr for external stream must be an int")
1541
1585
  if USE_NV_BINDING:
1542
- handle = binding.CUstream(ptr)
1586
+ handle = drvapi.cu_stream(int(binding.CUstream(ptr)))
1543
1587
  else:
1544
1588
  handle = drvapi.cu_stream(ptr)
1545
1589
  return Stream(weakref.proxy(self), handle, None, external=True)
@@ -1549,7 +1593,7 @@ class Context(object):
1549
1593
  if not timing:
1550
1594
  flags |= enums.CU_EVENT_DISABLE_TIMING
1551
1595
  if USE_NV_BINDING:
1552
- handle = driver.cuEventCreate(flags)
1596
+ handle = drvapi.cu_event(int(driver.cuEventCreate(flags)))
1553
1597
  else:
1554
1598
  handle = drvapi.cu_event()
1555
1599
  driver.cuEventCreate(byref(handle), flags)
@@ -1615,7 +1659,6 @@ def load_module_image_ctypes(
1615
1659
 
1616
1660
  option_keys = (drvapi.cu_jit_option * len(options))(*options.keys())
1617
1661
  option_vals = (c_void_p * len(options))(*options.values())
1618
-
1619
1662
  handle = drvapi.cu_module()
1620
1663
  try:
1621
1664
  driver.cuModuleLoadDataEx(
@@ -1662,7 +1705,7 @@ def load_module_image_cuda_python(
1662
1705
 
1663
1706
  try:
1664
1707
  handle = driver.cuModuleLoadDataEx(
1665
- image, len(options), option_keys, option_vals
1708
+ image.code, len(options), option_keys, option_vals
1666
1709
  )
1667
1710
  except CudaAPIError as e:
1668
1711
  err_string = jiterrors.decode("utf-8")
@@ -1741,14 +1784,14 @@ def _pin_finalizer(memory_manager, ptr, alloc_key, mapped):
1741
1784
 
1742
1785
  def _event_finalizer(deallocs, handle):
1743
1786
  def core():
1744
- deallocs.add_item(driver.cuEventDestroy, handle)
1787
+ deallocs.add_item(driver.cuEventDestroy, handle.value)
1745
1788
 
1746
1789
  return core
1747
1790
 
1748
1791
 
1749
1792
  def _stream_finalizer(deallocs, handle):
1750
1793
  def core():
1751
- deallocs.add_item(driver.cuStreamDestroy, handle)
1794
+ deallocs.add_item(driver.cuStreamDestroy, handle.value)
1752
1795
 
1753
1796
  return core
1754
1797
 
@@ -2019,6 +2062,9 @@ class MemoryPointer(object):
2019
2062
  __cuda_memory__ = True
2020
2063
 
2021
2064
  def __init__(self, context, pointer, size, owner=None, finalizer=None):
2065
+ if USE_NV_BINDING and isinstance(pointer, ctypes.c_void_p):
2066
+ pointer = binding.CUdeviceptr(pointer.value)
2067
+
2022
2068
  self.context = context
2023
2069
  self.device_pointer = pointer
2024
2070
  self.size = size
@@ -2051,9 +2097,11 @@ class MemoryPointer(object):
2051
2097
  def memset(self, byte, count=None, stream=0):
2052
2098
  count = self.size if count is None else count
2053
2099
  if stream:
2054
- driver.cuMemsetD8Async(
2055
- self.device_pointer, byte, count, stream.handle
2056
- )
2100
+ if USE_NV_BINDING:
2101
+ handle = stream.handle.value
2102
+ else:
2103
+ handle = stream.handle
2104
+ driver.cuMemsetD8Async(self.device_pointer, byte, count, handle)
2057
2105
  else:
2058
2106
  driver.cuMemsetD8(self.device_pointer, byte, count)
2059
2107
 
@@ -2291,27 +2339,16 @@ class Stream(object):
2291
2339
  weakref.finalize(self, finalizer)
2292
2340
 
2293
2341
  def __int__(self):
2294
- if USE_NV_BINDING:
2295
- return int(self.handle)
2296
- else:
2297
- # The default stream's handle.value is 0, which gives `None`
2298
- return self.handle.value or drvapi.CU_STREAM_DEFAULT
2342
+ # The default stream's handle.value is 0, which gives `None`
2343
+ return self.handle.value or drvapi.CU_STREAM_DEFAULT
2299
2344
 
2300
2345
  def __repr__(self):
2301
- if USE_NV_BINDING:
2302
- default_streams = {
2303
- CU_STREAM_DEFAULT: "<Default CUDA stream on %s>",
2304
- binding.CU_STREAM_LEGACY: "<Legacy default CUDA stream on %s>",
2305
- binding.CU_STREAM_PER_THREAD: "<Per-thread default CUDA stream on %s>",
2306
- }
2307
- ptr = int(self.handle) or 0
2308
- else:
2309
- default_streams = {
2310
- drvapi.CU_STREAM_DEFAULT: "<Default CUDA stream on %s>",
2311
- drvapi.CU_STREAM_LEGACY: "<Legacy default CUDA stream on %s>",
2312
- drvapi.CU_STREAM_PER_THREAD: "<Per-thread default CUDA stream on %s>",
2313
- }
2314
- ptr = self.handle.value or drvapi.CU_STREAM_DEFAULT
2346
+ default_streams = {
2347
+ drvapi.CU_STREAM_DEFAULT: "<Default CUDA stream on %s>",
2348
+ drvapi.CU_STREAM_LEGACY: "<Legacy default CUDA stream on %s>",
2349
+ drvapi.CU_STREAM_PER_THREAD: "<Per-thread default CUDA stream on %s>",
2350
+ }
2351
+ ptr = self.handle.value or drvapi.CU_STREAM_DEFAULT
2315
2352
 
2316
2353
  if ptr in default_streams:
2317
2354
  return default_streams[ptr] % self.context
@@ -2325,7 +2362,11 @@ class Stream(object):
2325
2362
  Wait for all commands in this stream to execute. This will commit any
2326
2363
  pending memory transfers.
2327
2364
  """
2328
- driver.cuStreamSynchronize(self.handle)
2365
+ if USE_NV_BINDING:
2366
+ handle = self.handle.value
2367
+ else:
2368
+ handle = self.handle
2369
+ driver.cuStreamSynchronize(handle)
2329
2370
 
2330
2371
  @contextlib.contextmanager
2331
2372
  def auto_synchronize(self):
@@ -2350,6 +2391,16 @@ class Stream(object):
2350
2391
  callback will block later work in the stream and may block other
2351
2392
  callbacks from being executed.
2352
2393
 
2394
+ .. warning::
2395
+ There is a potential for deadlock due to a lock ordering issue
2396
+ between the GIL and the CUDA driver lock when using libraries
2397
+ that call CUDA functions without releasing the GIL. This can
2398
+ occur when the callback function, which holds the CUDA driver lock,
2399
+ attempts to acquire the GIL while another thread that holds the GIL
2400
+ is waiting for the CUDA driver lock. Consider using libraries that
2401
+ properly release the GIL around CUDA operations or restructure
2402
+ your code to avoid this situation.
2403
+
2353
2404
  Note: The driver function underlying this method is marked for
2354
2405
  eventual deprecation and may be replaced in a future CUDA release.
2355
2406
 
@@ -2363,9 +2414,11 @@ class Stream(object):
2363
2414
  stream_callback = binding.CUstreamCallback(ptr)
2364
2415
  # The callback needs to receive a pointer to the data PyObject
2365
2416
  data = id(data)
2417
+ handle = self.handle.value
2366
2418
  else:
2367
2419
  stream_callback = self._stream_callback
2368
- driver.cuStreamAddCallback(self.handle, stream_callback, data, 0)
2420
+ handle = self.handle
2421
+ driver.cuStreamAddCallback(handle, stream_callback, data, 0)
2369
2422
 
2370
2423
  @staticmethod
2371
2424
  @cu_stream_callback_pyobj
@@ -2382,6 +2435,16 @@ class Stream(object):
2382
2435
  """
2383
2436
  Return an awaitable that resolves once all preceding stream operations
2384
2437
  are complete. The result of the awaitable is the current stream.
2438
+
2439
+ .. warning::
2440
+ There is a potential for deadlock due to a lock ordering issue
2441
+ between the GIL and the CUDA driver lock when using libraries
2442
+ that call CUDA functions without releasing the GIL. This can
2443
+ occur when the callback function (internally used by this method),
2444
+ which holds the CUDA driver lock, attempts to acquire the GIL
2445
+ while another thread that holds the GIL is waiting for the CUDA driver lock.
2446
+ Consider using libraries that properly release the GIL around
2447
+ CUDA operations or restructure your code to avoid this situation.
2385
2448
  """
2386
2449
  loop = asyncio.get_running_loop()
2387
2450
  future = loop.create_future()
@@ -2433,27 +2496,35 @@ class Event(object):
2433
2496
  completed.
2434
2497
  """
2435
2498
  if USE_NV_BINDING:
2436
- hstream = stream.handle if stream else binding.CUstream(0)
2499
+ hstream = stream.handle.value if stream else binding.CUstream(0)
2500
+ handle = self.handle.value
2437
2501
  else:
2438
2502
  hstream = stream.handle if stream else 0
2439
- driver.cuEventRecord(self.handle, hstream)
2503
+ handle = self.handle
2504
+ driver.cuEventRecord(handle, hstream)
2440
2505
 
2441
2506
  def synchronize(self):
2442
2507
  """
2443
2508
  Synchronize the host thread for the completion of the event.
2444
2509
  """
2445
- driver.cuEventSynchronize(self.handle)
2510
+ if USE_NV_BINDING:
2511
+ handle = self.handle.value
2512
+ else:
2513
+ handle = self.handle
2514
+ driver.cuEventSynchronize(handle)
2446
2515
 
2447
2516
  def wait(self, stream=0):
2448
2517
  """
2449
2518
  All future works submitted to stream will wait util the event completes.
2450
2519
  """
2451
2520
  if USE_NV_BINDING:
2452
- hstream = stream.handle if stream else binding.CUstream(0)
2521
+ hstream = stream.handle.value if stream else binding.CUstream(0)
2522
+ handle = self.handle.value
2453
2523
  else:
2454
2524
  hstream = stream.handle if stream else 0
2525
+ handle = self.handle
2455
2526
  flags = 0
2456
- driver.cuStreamWaitEvent(hstream, self.handle, flags)
2527
+ driver.cuStreamWaitEvent(hstream, handle, flags)
2457
2528
 
2458
2529
  def elapsed_time(self, evtend):
2459
2530
  return event_elapsed_time(self, evtend)
@@ -2464,7 +2535,9 @@ def event_elapsed_time(evtstart, evtend):
2464
2535
  Compute the elapsed time between two events in milliseconds.
2465
2536
  """
2466
2537
  if USE_NV_BINDING:
2467
- return driver.cuEventElapsedTime(evtstart.handle, evtend.handle)
2538
+ return driver.cuEventElapsedTime(
2539
+ evtstart.handle.value, evtend.handle.value
2540
+ )
2468
2541
  else:
2469
2542
  msec = c_float()
2470
2543
  driver.cuEventElapsedTime(byref(msec), evtstart.handle, evtend.handle)
@@ -2722,7 +2795,7 @@ def launch_kernel(
2722
2795
  )
2723
2796
 
2724
2797
 
2725
- class Linker(metaclass=ABCMeta):
2798
+ class _LinkerBase(metaclass=ABCMeta):
2726
2799
  """Abstract base class for linkers"""
2727
2800
 
2728
2801
  @classmethod
@@ -2735,30 +2808,27 @@ class Linker(metaclass=ABCMeta):
2735
2808
  additional_flags=None,
2736
2809
  ):
2737
2810
  driver_ver = driver.get_version()
2738
- if config.CUDA_ENABLE_MINOR_VERSION_COMPATIBILITY and driver_ver >= (
2739
- 12,
2740
- 0,
2741
- ):
2742
- raise ValueError("Use CUDA_ENABLE_PYNVJITLINK for CUDA >= 12.0 MVC")
2743
- if config.CUDA_ENABLE_PYNVJITLINK and driver_ver < (12, 0):
2744
- raise ValueError("Enabling pynvjitlink requires CUDA 12.")
2745
- if config.CUDA_ENABLE_PYNVJITLINK:
2746
- linker = PyNvJitLinker
2747
-
2748
- elif config.CUDA_ENABLE_MINOR_VERSION_COMPATIBILITY:
2749
- linker = MVCLinker
2811
+ if driver_ver < (12, 0):
2812
+ if config.CUDA_ENABLE_MINOR_VERSION_COMPATIBILITY:
2813
+ linker = MVCLinker
2814
+ elif USE_NV_BINDING:
2815
+ linker = _Linker
2816
+ else:
2817
+ linker = CtypesLinker
2750
2818
  else:
2751
2819
  if USE_NV_BINDING:
2752
- linker = CudaPythonLinker
2820
+ linker = _Linker
2753
2821
  else:
2754
2822
  linker = CtypesLinker
2755
2823
 
2756
- if linker is PyNvJitLinker:
2757
- return linker(max_registers, lineinfo, cc, lto, additional_flags)
2758
- elif additional_flags or lto:
2759
- raise ValueError("LTO and additional flags require PyNvJitLinker")
2824
+ params = (max_registers, lineinfo, cc)
2825
+ if linker is _Linker:
2826
+ params = (*params, lto, additional_flags)
2760
2827
  else:
2761
- return linker(max_registers, lineinfo, cc)
2828
+ if lto or additional_flags:
2829
+ raise ValueError("LTO and additional flags require nvjitlink")
2830
+
2831
+ return linker(*params)
2762
2832
 
2763
2833
  @abstractmethod
2764
2834
  def __init__(self, max_registers, lineinfo, cc):
@@ -2786,7 +2856,6 @@ class Linker(metaclass=ABCMeta):
2786
2856
  with driver.get_active_context() as ac:
2787
2857
  dev = driver.get_device(ac.devnum)
2788
2858
  cc = dev.compute_capability
2789
-
2790
2859
  ptx, log = nvrtc.compile(cu, name, cc)
2791
2860
 
2792
2861
  if config.DUMP_ASSEMBLY:
@@ -2821,7 +2890,6 @@ class Linker(metaclass=ABCMeta):
2821
2890
  LTO-ed portion of the PTX when linker is added with objects that can be
2822
2891
  both LTO-ed and not LTO-ed.
2823
2892
  """
2824
-
2825
2893
  if isinstance(path_or_code, str):
2826
2894
  ext = pathlib.Path(path_or_code).suffix
2827
2895
  if ext == "":
@@ -2901,7 +2969,148 @@ class Linker(metaclass=ABCMeta):
2901
2969
  """
2902
2970
 
2903
2971
 
2904
- class MVCLinker(Linker):
2972
+ class _Linker(_LinkerBase):
2973
+ def __init__(
2974
+ self,
2975
+ max_registers=None,
2976
+ lineinfo=False,
2977
+ cc=None,
2978
+ lto=None,
2979
+ additional_flags=None,
2980
+ ):
2981
+ arch = f"sm_{cc[0]}{cc[1]}"
2982
+ self.max_registers = max_registers if max_registers else None
2983
+ self.lineinfo = lineinfo
2984
+ self.cc = cc
2985
+ self.arch = arch
2986
+ if lto is False:
2987
+ # WAR for apparent nvjitlink issue
2988
+ lto = None
2989
+ self.lto = lto
2990
+ self.additional_flags = additional_flags
2991
+
2992
+ self.options = LinkerOptions(
2993
+ max_register_count=self.max_registers,
2994
+ lineinfo=lineinfo,
2995
+ arch=arch,
2996
+ link_time_optimization=lto,
2997
+ )
2998
+ self._complete = False
2999
+ self._object_codes = []
3000
+ self.linker = None # need at least one program
3001
+
3002
+ @property
3003
+ def info_log(self):
3004
+ if not self.linker:
3005
+ raise ValueError("Not Initialized")
3006
+ if self._complete:
3007
+ return self._info_log
3008
+ raise RuntimeError("Link not yet complete.")
3009
+
3010
+ @property
3011
+ def error_log(self):
3012
+ if not self.linker:
3013
+ raise ValueError("Not Initialized")
3014
+ if self._complete:
3015
+ return self._error_log
3016
+ raise RuntimeError("Link not yet complete.")
3017
+
3018
+ def add_ptx(self, ptx, name="<cudapy-ptx>"):
3019
+ obj = ObjectCode.from_ptx(ptx, name=name)
3020
+ self._object_codes.append(obj)
3021
+
3022
+ def add_cu(self, cu, name="<cudapy-cu>"):
3023
+ with driver.get_active_context() as ac:
3024
+ dev = driver.get_device(ac.devnum)
3025
+ cc = dev.compute_capability
3026
+ obj, log = nvrtc.compile(cu, name, cc, ltoir=self.lto)
3027
+
3028
+ if not self.lto and config.DUMP_ASSEMBLY:
3029
+ print(("ASSEMBLY %s" % name).center(80, "-"))
3030
+ print(obj.code)
3031
+
3032
+ self._object_codes.append(obj)
3033
+
3034
+ def add_cubin(self, cubin, name="<cudapy-cubin>"):
3035
+ obj = ObjectCode.from_cubin(cubin, name=name)
3036
+ self._object_codes.append(obj)
3037
+
3038
+ def add_ltoir(self, ltoir, name="<cudapy-ltoir>"):
3039
+ obj = ObjectCode.from_ltoir(ltoir, name=name)
3040
+ self._object_codes.append(obj)
3041
+
3042
+ def add_fatbin(self, fatbin, name="<cudapy-fatbin>"):
3043
+ obj = ObjectCode.from_fatbin(fatbin, name=name)
3044
+ self._object_codes.append(obj)
3045
+
3046
+ def add_object(self, obj, name="<cudapy-object>"):
3047
+ obj = ObjectCode.from_object(obj, name=name)
3048
+ self._object_codes.append(obj)
3049
+
3050
+ def add_library(self, lib, name="<cudapy-lib>"):
3051
+ obj = ObjectCode.from_library(lib, name=name)
3052
+ self._object_codes.append(obj)
3053
+
3054
+ def add_file(self, path, kind):
3055
+ try:
3056
+ data = cached_file_read(path, how="rb")
3057
+ except FileNotFoundError:
3058
+ raise LinkerError(f"{path} not found")
3059
+ name = pathlib.Path(path).name
3060
+ self.add_data(data, kind, name)
3061
+
3062
+ def add_data(self, data, kind, name):
3063
+ if kind == FILE_EXTENSION_MAP["ptx"]:
3064
+ fn = self.add_ptx
3065
+ elif kind == FILE_EXTENSION_MAP["cubin"]:
3066
+ fn = self.add_cubin
3067
+ elif kind == "cu":
3068
+ fn = self.add_cu
3069
+ elif (
3070
+ kind == FILE_EXTENSION_MAP["lib"] or kind == FILE_EXTENSION_MAP["a"]
3071
+ ):
3072
+ fn = self.add_library
3073
+ elif kind == FILE_EXTENSION_MAP["fatbin"]:
3074
+ fn = self.add_fatbin
3075
+ elif kind == FILE_EXTENSION_MAP["o"]:
3076
+ fn = self.add_object
3077
+ elif kind == FILE_EXTENSION_MAP["ltoir"]:
3078
+ fn = self.add_ltoir
3079
+ else:
3080
+ raise LinkerError(f"Don't know how to link {kind}")
3081
+
3082
+ fn(data, name)
3083
+
3084
+ def get_linked_ptx(self):
3085
+ options = LinkerOptions(
3086
+ max_register_count=self.max_registers,
3087
+ lineinfo=self.lineinfo,
3088
+ arch=self.arch,
3089
+ link_time_optimization=True,
3090
+ ptx=True,
3091
+ )
3092
+
3093
+ self.linker = Linker(*self._object_codes, options=options)
3094
+
3095
+ result = self.linker.link("ptx")
3096
+ self.close()
3097
+ self._complete = True
3098
+ return result.code
3099
+
3100
+ def close(self):
3101
+ self._info_log = self.linker.get_info_log()
3102
+ self._error_log = self.linker.get_error_log()
3103
+ self.linker.close()
3104
+
3105
+ def complete(self):
3106
+ self.linker = Linker(*self._object_codes, options=self.options)
3107
+ result = self.linker.link("cubin")
3108
+ self.close()
3109
+ self._complete = True
3110
+ return result
3111
+
3112
+
3113
+ class MVCLinker(_LinkerBase):
2905
3114
  """
2906
3115
  Linker supporting Minor Version Compatibility, backed by the cubinlinker
2907
3116
  package.
@@ -2996,7 +3205,7 @@ class MVCLinker(Linker):
2996
3205
  raise LinkerError from e
2997
3206
 
2998
3207
 
2999
- class CtypesLinker(Linker):
3208
+ class CtypesLinker(_LinkerBase):
3000
3209
  """
3001
3210
  Links for current device if no CC given
3002
3211
  """
@@ -3139,266 +3348,6 @@ class CtypesLinker(Linker):
3139
3348
  return bytes(np.ctypeslib.as_array(cubin_ptr, shape=(size,)))
3140
3349
 
3141
3350
 
3142
- class CudaPythonLinker(Linker):
3143
- """
3144
- Links for current device if no CC given
3145
- """
3146
-
3147
- def __init__(self, max_registers=0, lineinfo=False, cc=None):
3148
- super().__init__(max_registers, lineinfo, cc)
3149
-
3150
- logsz = config.CUDA_LOG_SIZE
3151
- linkerinfo = bytearray(logsz)
3152
- linkererrors = bytearray(logsz)
3153
-
3154
- jit_option = binding.CUjit_option
3155
-
3156
- options = {
3157
- jit_option.CU_JIT_INFO_LOG_BUFFER: linkerinfo,
3158
- jit_option.CU_JIT_INFO_LOG_BUFFER_SIZE_BYTES: logsz,
3159
- jit_option.CU_JIT_ERROR_LOG_BUFFER: linkererrors,
3160
- jit_option.CU_JIT_ERROR_LOG_BUFFER_SIZE_BYTES: logsz,
3161
- jit_option.CU_JIT_LOG_VERBOSE: 1,
3162
- }
3163
- if max_registers:
3164
- options[jit_option.CU_JIT_MAX_REGISTERS] = max_registers
3165
- if lineinfo:
3166
- options[jit_option.CU_JIT_GENERATE_LINE_INFO] = 1
3167
-
3168
- if cc is None:
3169
- # No option value is needed, but we need something as a placeholder
3170
- options[jit_option.CU_JIT_TARGET_FROM_CUCONTEXT] = 1
3171
- else:
3172
- cc_val = cc[0] * 10 + cc[1]
3173
- cc_enum = getattr(
3174
- binding.CUjit_target, f"CU_TARGET_COMPUTE_{cc_val}"
3175
- )
3176
- options[jit_option.CU_JIT_TARGET] = cc_enum
3177
-
3178
- raw_keys = list(options.keys())
3179
- raw_values = list(options.values())
3180
- self.handle = driver.cuLinkCreate(len(raw_keys), raw_keys, raw_values)
3181
-
3182
- weakref.finalize(self, driver.cuLinkDestroy, self.handle)
3183
-
3184
- self.linker_info_buf = linkerinfo
3185
- self.linker_errors_buf = linkererrors
3186
-
3187
- self._keep_alive = [linkerinfo, linkererrors, raw_keys, raw_values]
3188
-
3189
- @property
3190
- def info_log(self):
3191
- return self.linker_info_buf.decode("utf8")
3192
-
3193
- @property
3194
- def error_log(self):
3195
- return self.linker_errors_buf.decode("utf8")
3196
-
3197
- def add_cubin(self, cubin, name="<unnamed-cubin>"):
3198
- input_type = binding.CUjitInputType.CU_JIT_INPUT_CUBIN
3199
- return self._add_data(input_type, cubin, name)
3200
-
3201
- def add_ptx(self, ptx, name="<unnamed-ptx>"):
3202
- input_type = binding.CUjitInputType.CU_JIT_INPUT_PTX
3203
- return self._add_data(input_type, ptx, name)
3204
-
3205
- def add_object(self, object_, name="<unnamed-object>"):
3206
- input_type = binding.CUjitInputType.CU_JIT_INPUT_OBJECT
3207
- return self._add_data(input_type, object_, name)
3208
-
3209
- def add_fatbin(self, fatbin, name="<unnamed-fatbin>"):
3210
- input_type = binding.CUjitInputType.CU_JIT_INPUT_FATBINARY
3211
- return self._add_data(input_type, fatbin, name)
3212
-
3213
- def add_library(self, library, name="<unnamed-library>"):
3214
- input_type = binding.CUjitInputType.CU_JIT_INPUT_LIBRARY
3215
- return self._add_data(input_type, library, name)
3216
-
3217
- def _add_data(self, input_type, data, name):
3218
- name_buffer = name.encode("utf8")
3219
- self._keep_alive += [data, name_buffer]
3220
- try:
3221
- driver.cuLinkAddData(
3222
- self.handle, input_type, data, len(data), name_buffer, 0, [], []
3223
- )
3224
- except CudaAPIError as e:
3225
- raise LinkerError("%s\n%s" % (e, self.error_log))
3226
-
3227
- def add_data(self, data, kind, name=None):
3228
- # We pass the name as **kwargs to ensure the default name for the input
3229
- # type is used if none is supplied
3230
- kws = {}
3231
- if name is not None:
3232
- kws["name"] = name
3233
-
3234
- if kind == FILE_EXTENSION_MAP["cubin"]:
3235
- self.add_cubin(data, **kws)
3236
- elif kind == FILE_EXTENSION_MAP["fatbin"]:
3237
- self.add_fatbin(data, **kws)
3238
- elif kind == FILE_EXTENSION_MAP["a"]:
3239
- self.add_library(data, **kws)
3240
- elif kind == FILE_EXTENSION_MAP["ptx"]:
3241
- self.add_ptx(data, **kws)
3242
- elif kind == FILE_EXTENSION_MAP["o"]:
3243
- self.add_object(data, **kws)
3244
- elif kind == FILE_EXTENSION_MAP["ltoir"]:
3245
- raise LinkerError("CudaPythonLinker cannot link LTO-IR")
3246
- else:
3247
- raise LinkerError(f"Don't know how to link {kind}")
3248
-
3249
- def add_file(self, path, kind):
3250
- pathbuf = path.encode("utf8")
3251
- self._keep_alive.append(pathbuf)
3252
-
3253
- try:
3254
- driver.cuLinkAddFile(self.handle, kind, pathbuf, 0, [], [])
3255
- except CudaAPIError as e:
3256
- if e.code == binding.CUresult.CUDA_ERROR_FILE_NOT_FOUND:
3257
- msg = f"{path} not found"
3258
- else:
3259
- msg = "%s\n%s" % (e, self.error_log)
3260
- raise LinkerError(msg)
3261
-
3262
- def complete(self):
3263
- try:
3264
- cubin_buf, size = driver.cuLinkComplete(self.handle)
3265
- except CudaAPIError as e:
3266
- raise LinkerError("%s\n%s" % (e, self.error_log))
3267
-
3268
- assert size > 0, "linker returned a zero sized cubin"
3269
- del self._keep_alive[:]
3270
- # We return a copy of the cubin because it's owned by the linker
3271
- cubin_ptr = ctypes.cast(cubin_buf, ctypes.POINTER(ctypes.c_char))
3272
- return bytes(np.ctypeslib.as_array(cubin_ptr, shape=(size,)))
3273
-
3274
-
3275
- class PyNvJitLinker(Linker):
3276
- def __init__(
3277
- self,
3278
- max_registers=None,
3279
- lineinfo=False,
3280
- cc=None,
3281
- lto=False,
3282
- additional_flags=None,
3283
- ):
3284
- if NvJitLinker is None:
3285
- raise ImportError(
3286
- "Using pynvjitlink requires the pynvjitlink package to be "
3287
- "available"
3288
- )
3289
-
3290
- if config.CUDA_ENABLE_MINOR_VERSION_COMPATIBILITY:
3291
- raise ValueError(
3292
- "Can't set CUDA_ENABLE_MINOR_VERSION_COMPATIBILITY and "
3293
- "CUDA_ENABLE_PYNVJITLINK at the same time"
3294
- )
3295
-
3296
- if cc is None:
3297
- raise RuntimeError("PyNvJitLinker requires CC to be specified")
3298
- if not any(isinstance(cc, t) for t in [list, tuple]):
3299
- raise TypeError("`cc` must be a list or tuple of length 2")
3300
-
3301
- sm_ver = f"{cc[0] * 10 + cc[1]}"
3302
- arch = f"-arch=sm_{sm_ver}"
3303
- options = [arch]
3304
- if max_registers:
3305
- options.append(f"-maxrregcount={max_registers}")
3306
- if lineinfo:
3307
- options.append("-lineinfo")
3308
- if lto:
3309
- options.append("-lto")
3310
- if additional_flags is not None:
3311
- options.extend(additional_flags)
3312
-
3313
- self._linker = NvJitLinker(*options)
3314
- self.lto = lto
3315
- self.options = options
3316
-
3317
- @property
3318
- def info_log(self):
3319
- return self._linker.info_log
3320
-
3321
- @property
3322
- def error_log(self):
3323
- return self._linker.error_log
3324
-
3325
- def add_ptx(self, ptx, name="<cudapy-ptx>"):
3326
- self._linker.add_ptx(ptx, name)
3327
-
3328
- def add_fatbin(self, fatbin, name="<external-fatbin>"):
3329
- self._linker.add_fatbin(fatbin, name)
3330
-
3331
- def add_ltoir(self, ltoir, name="<external-ltoir>"):
3332
- self._linker.add_ltoir(ltoir, name)
3333
-
3334
- def add_object(self, obj, name="<external-object>"):
3335
- self._linker.add_object(obj, name)
3336
-
3337
- def add_file(self, path, kind):
3338
- try:
3339
- data = cached_file_read(path, "rb")
3340
- except FileNotFoundError:
3341
- raise LinkerError(f"{path} not found")
3342
-
3343
- name = pathlib.Path(path).name
3344
- self.add_data(data, kind, name)
3345
-
3346
- def add_cu(self, cu, name):
3347
- """Add CUDA source in a string to the link. The name of the source
3348
- file should be specified in `name`."""
3349
- with driver.get_active_context() as ac:
3350
- dev = driver.get_device(ac.devnum)
3351
- cc = dev.compute_capability
3352
-
3353
- program, log = nvrtc.compile(cu, name, cc, ltoir=self.lto)
3354
-
3355
- if not self.lto and config.DUMP_ASSEMBLY:
3356
- print(("ASSEMBLY %s" % name).center(80, "-"))
3357
- print(program)
3358
- print("=" * 80)
3359
-
3360
- suffix = ".ltoir" if self.lto else ".ptx"
3361
- program_name = os.path.splitext(name)[0] + suffix
3362
- # Link the program's PTX or LTOIR using the normal linker mechanism
3363
- if self.lto:
3364
- self.add_ltoir(program, program_name)
3365
- else:
3366
- self.add_ptx(program.encode(), program_name)
3367
-
3368
- def add_data(self, data, kind, name):
3369
- if kind == FILE_EXTENSION_MAP["cubin"]:
3370
- fn = self._linker.add_cubin
3371
- elif kind == FILE_EXTENSION_MAP["fatbin"]:
3372
- fn = self._linker.add_fatbin
3373
- elif kind == FILE_EXTENSION_MAP["a"]:
3374
- fn = self._linker.add_library
3375
- elif kind == FILE_EXTENSION_MAP["ptx"]:
3376
- return self.add_ptx(data, name)
3377
- elif kind == FILE_EXTENSION_MAP["o"]:
3378
- fn = self._linker.add_object
3379
- elif kind == FILE_EXTENSION_MAP["ltoir"]:
3380
- fn = self._linker.add_ltoir
3381
- else:
3382
- raise LinkerError(f"Don't know how to link {kind}")
3383
-
3384
- try:
3385
- fn(data, name)
3386
- except NvJitLinkError as e:
3387
- raise LinkerError from e
3388
-
3389
- def get_linked_ptx(self):
3390
- try:
3391
- return self._linker.get_linked_ptx()
3392
- except NvJitLinkError as e:
3393
- raise LinkerError from e
3394
-
3395
- def complete(self):
3396
- try:
3397
- return self._linker.get_linked_cubin()
3398
- except NvJitLinkError as e:
3399
- raise LinkerError from e
3400
-
3401
-
3402
3351
  # -----------------------------------------------------------------------------
3403
3352
 
3404
3353
 
@@ -3566,7 +3515,11 @@ def host_to_device(dst, src, size, stream=0):
3566
3515
  if stream:
3567
3516
  assert isinstance(stream, Stream)
3568
3517
  fn = driver.cuMemcpyHtoDAsync
3569
- varargs.append(stream.handle)
3518
+ if USE_NV_BINDING:
3519
+ handle = stream.handle.value
3520
+ else:
3521
+ handle = stream.handle
3522
+ varargs.append(handle)
3570
3523
  else:
3571
3524
  fn = driver.cuMemcpyHtoD
3572
3525
 
@@ -3584,7 +3537,11 @@ def device_to_host(dst, src, size, stream=0):
3584
3537
  if stream:
3585
3538
  assert isinstance(stream, Stream)
3586
3539
  fn = driver.cuMemcpyDtoHAsync
3587
- varargs.append(stream.handle)
3540
+ if USE_NV_BINDING:
3541
+ handle = stream.handle.value
3542
+ else:
3543
+ handle = stream.handle
3544
+ varargs.append(handle)
3588
3545
  else:
3589
3546
  fn = driver.cuMemcpyDtoH
3590
3547
 
@@ -3602,7 +3559,11 @@ def device_to_device(dst, src, size, stream=0):
3602
3559
  if stream:
3603
3560
  assert isinstance(stream, Stream)
3604
3561
  fn = driver.cuMemcpyDtoDAsync
3605
- varargs.append(stream.handle)
3562
+ if USE_NV_BINDING:
3563
+ handle = stream.handle.value
3564
+ else:
3565
+ handle = stream.handle
3566
+ varargs.append(handle)
3606
3567
  else:
3607
3568
  fn = driver.cuMemcpyDtoD
3608
3569
 
@@ -3623,7 +3584,11 @@ def device_memset(dst, val, size, stream=0):
3623
3584
  if stream:
3624
3585
  assert isinstance(stream, Stream)
3625
3586
  fn = driver.cuMemsetD8Async
3626
- varargs.append(stream.handle)
3587
+ if USE_NV_BINDING:
3588
+ handle = stream.handle.value
3589
+ else:
3590
+ handle = stream.handle
3591
+ varargs.append(handle)
3627
3592
  else:
3628
3593
  fn = driver.cuMemsetD8
3629
3594