numba-cuda 0.16.0__py3-none-any.whl → 0.18.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (70)
  1. numba_cuda/VERSION +1 -1
  2. numba_cuda/numba/cuda/__init__.py +0 -8
  3. numba_cuda/numba/cuda/_internal/cuda_fp16.py +14225 -0
  4. numba_cuda/numba/cuda/api_util.py +6 -0
  5. numba_cuda/numba/cuda/cgutils.py +1291 -0
  6. numba_cuda/numba/cuda/codegen.py +32 -14
  7. numba_cuda/numba/cuda/compiler.py +113 -10
  8. numba_cuda/numba/cuda/core/caching.py +741 -0
  9. numba_cuda/numba/cuda/core/callconv.py +338 -0
  10. numba_cuda/numba/cuda/core/codegen.py +168 -0
  11. numba_cuda/numba/cuda/core/compiler.py +205 -0
  12. numba_cuda/numba/cuda/core/typed_passes.py +139 -0
  13. numba_cuda/numba/cuda/cuda_paths.py +1 -1
  14. numba_cuda/numba/cuda/cudadecl.py +0 -268
  15. numba_cuda/numba/cuda/cudadrv/devicearray.py +3 -0
  16. numba_cuda/numba/cuda/cudadrv/devices.py +4 -6
  17. numba_cuda/numba/cuda/cudadrv/driver.py +105 -50
  18. numba_cuda/numba/cuda/cudadrv/nvvm.py +1 -1
  19. numba_cuda/numba/cuda/cudaimpl.py +4 -178
  20. numba_cuda/numba/cuda/debuginfo.py +469 -3
  21. numba_cuda/numba/cuda/device_init.py +0 -1
  22. numba_cuda/numba/cuda/dispatcher.py +311 -14
  23. numba_cuda/numba/cuda/extending.py +2 -1
  24. numba_cuda/numba/cuda/fp16.py +348 -0
  25. numba_cuda/numba/cuda/intrinsics.py +1 -1
  26. numba_cuda/numba/cuda/libdeviceimpl.py +2 -1
  27. numba_cuda/numba/cuda/lowering.py +1833 -8
  28. numba_cuda/numba/cuda/mathimpl.py +2 -90
  29. numba_cuda/numba/cuda/memory_management/nrt.py +1 -1
  30. numba_cuda/numba/cuda/nvvmutils.py +2 -1
  31. numba_cuda/numba/cuda/printimpl.py +2 -1
  32. numba_cuda/numba/cuda/serialize.py +264 -0
  33. numba_cuda/numba/cuda/simulator/__init__.py +2 -0
  34. numba_cuda/numba/cuda/simulator/dispatcher.py +7 -0
  35. numba_cuda/numba/cuda/stubs.py +0 -308
  36. numba_cuda/numba/cuda/target.py +13 -5
  37. numba_cuda/numba/cuda/testing.py +156 -5
  38. numba_cuda/numba/cuda/tests/complex_usecases.py +113 -0
  39. numba_cuda/numba/cuda/tests/core/serialize_usecases.py +110 -0
  40. numba_cuda/numba/cuda/tests/core/test_serialize.py +359 -0
  41. numba_cuda/numba/cuda/tests/cudadrv/test_context_stack.py +16 -5
  42. numba_cuda/numba/cuda/tests/cudadrv/test_cuda_driver.py +5 -1
  43. numba_cuda/numba/cuda/tests/cudadrv/test_cuda_ndarray.py +33 -0
  44. numba_cuda/numba/cuda/tests/cudadrv/test_emm_plugins.py +1 -1
  45. numba_cuda/numba/cuda/tests/cudadrv/test_runtime.py +2 -2
  46. numba_cuda/numba/cuda/tests/cudadrv/test_streams.py +1 -0
  47. numba_cuda/numba/cuda/tests/cudapy/extensions_usecases.py +1 -1
  48. numba_cuda/numba/cuda/tests/cudapy/test_caching.py +5 -10
  49. numba_cuda/numba/cuda/tests/cudapy/test_complex.py +1 -1
  50. numba_cuda/numba/cuda/tests/cudapy/test_cuda_array_interface.py +1 -5
  51. numba_cuda/numba/cuda/tests/cudapy/test_debuginfo.py +381 -0
  52. numba_cuda/numba/cuda/tests/cudapy/test_enums.py +1 -1
  53. numba_cuda/numba/cuda/tests/cudapy/test_extending.py +1 -1
  54. numba_cuda/numba/cuda/tests/cudapy/test_inspect.py +94 -24
  55. numba_cuda/numba/cuda/tests/cudapy/test_intrinsics.py +37 -23
  56. numba_cuda/numba/cuda/tests/cudapy/test_operator.py +43 -27
  57. numba_cuda/numba/cuda/tests/cudapy/test_stream_api.py +2 -5
  58. numba_cuda/numba/cuda/tests/cudapy/test_ufuncs.py +26 -9
  59. numba_cuda/numba/cuda/tests/cudapy/test_warning.py +27 -2
  60. numba_cuda/numba/cuda/tests/enum_usecases.py +56 -0
  61. numba_cuda/numba/cuda/tests/nocuda/test_library_lookup.py +1 -2
  62. numba_cuda/numba/cuda/tests/nocuda/test_nvvm.py +1 -1
  63. numba_cuda/numba/cuda/utils.py +785 -0
  64. numba_cuda/numba/cuda/vector_types.py +1 -1
  65. {numba_cuda-0.16.0.dist-info → numba_cuda-0.18.0.dist-info}/METADATA +18 -4
  66. {numba_cuda-0.16.0.dist-info → numba_cuda-0.18.0.dist-info}/RECORD +69 -56
  67. numba_cuda/numba/cuda/cpp_function_wrappers.cu +0 -46
  68. {numba_cuda-0.16.0.dist-info → numba_cuda-0.18.0.dist-info}/WHEEL +0 -0
  69. {numba_cuda-0.16.0.dist-info → numba_cuda-0.18.0.dist-info}/licenses/LICENSE +0 -0
  70. {numba_cuda-0.16.0.dist-info → numba_cuda-0.18.0.dist-info}/top_level.txt +0 -0
@@ -44,7 +44,8 @@ from collections import namedtuple, deque
 
 
 from numba import mviewbuf
-from numba.core import utils, serialize, config
+from numba.core import config
+from numba.cuda import utils, serialize
 from .error import CudaSupportError, CudaDriverError
 from .drvapi import API_PROTOTYPES
 from .drvapi import cu_occupancy_b2d_size, cu_stream_callback_pyobj, cu_uuid
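The hunk above swaps the numba.core imports of utils and serialize for copies vendored under numba.cuda (utils.py, serialize.py and cgutils.py appear as new files in the list above). A minimal, hypothetical sketch of downstream code tolerating both layouts:

    # Sketch only: prefer the modules vendored by numba-cuda 0.18.0 and fall
    # back to the numba.core copies on older releases.
    try:
        from numba.cuda import utils, serialize
    except ImportError:
        from numba.core import utils, serialize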
@@ -490,11 +491,11 @@ class Driver(object):
         with self.get_active_context() as ac:
             if ac.devnum is not None:
                 if USE_NV_BINDING:
-                    return driver.cuCtxPopCurrent()
+                    popped = drvapi.cu_context(int(driver.cuCtxPopCurrent()))
                 else:
                     popped = drvapi.cu_context()
                     driver.cuCtxPopCurrent(byref(popped))
-                    return popped
+                return popped
 
     def get_active_context(self):
         """Returns an instance of ``_ActiveContext``."""
@@ -538,6 +539,8 @@ class _ActiveContext(object):
                 hctx = driver.cuCtxGetCurrent()
                 if int(hctx) == 0:
                     hctx = None
+                else:
+                    hctx = drvapi.cu_context(int(hctx))
             else:
                 hctx = drvapi.cu_context(0)
                 driver.cuCtxGetCurrent(byref(hctx))
@@ -716,6 +719,7 @@ class Device(object):
         # create primary context
         if USE_NV_BINDING:
             hctx = driver.cuDevicePrimaryCtxRetain(self.id)
+            hctx = drvapi.cu_context(int(hctx))
         else:
             hctx = drvapi.cu_context()
             driver.cuDevicePrimaryCtxRetain(byref(hctx), self.id)
@@ -1254,6 +1258,7 @@ class _PendingDeallocs(object):
             [dtor, handle, size] = self._cons.popleft()
             _logger.info("dealloc: %s %s bytes", dtor.__name__, size)
             dtor(handle)
+
         self._size = 0
 
     @contextlib.contextmanager
@@ -1430,7 +1435,10 @@ class Context(object):
         """
         Pushes this context on the current CPU Thread.
         """
-        driver.cuCtxPushCurrent(self.handle)
+        if USE_NV_BINDING:
+            driver.cuCtxPushCurrent(self.handle.value)
+        else:
+            driver.cuCtxPushCurrent(self.handle)
         self.prepare_for_use()
 
     def pop(self):
@@ -1439,10 +1447,7 @@ class Context(object):
         must be at the top of the context stack, otherwise an error will occur.
         """
         popped = driver.pop_active_context()
-        if USE_NV_BINDING:
-            assert int(popped) == int(self.handle)
-        else:
-            assert popped.value == self.handle.value
+        assert popped.value == self.handle.value
 
     def memalloc(self, bytesize):
         return self.memory_manager.memalloc(bytesize)
@@ -1535,21 +1540,25 @@ class Context(object):
 
     def get_default_stream(self):
         if USE_NV_BINDING:
-            handle = binding.CUstream(CU_STREAM_DEFAULT)
+            handle = drvapi.cu_stream(int(binding.CUstream(CU_STREAM_DEFAULT)))
         else:
             handle = drvapi.cu_stream(drvapi.CU_STREAM_DEFAULT)
         return Stream(weakref.proxy(self), handle, None)
 
     def get_legacy_default_stream(self):
         if USE_NV_BINDING:
-            handle = binding.CUstream(binding.CU_STREAM_LEGACY)
+            handle = drvapi.cu_stream(
+                int(binding.CUstream(binding.CU_STREAM_LEGACY))
+            )
         else:
             handle = drvapi.cu_stream(drvapi.CU_STREAM_LEGACY)
         return Stream(weakref.proxy(self), handle, None)
 
     def get_per_thread_default_stream(self):
         if USE_NV_BINDING:
-            handle = binding.CUstream(binding.CU_STREAM_PER_THREAD)
+            handle = drvapi.cu_stream(
+                int(binding.CUstream(binding.CU_STREAM_PER_THREAD))
+            )
         else:
             handle = drvapi.cu_stream(drvapi.CU_STREAM_PER_THREAD)
         return Stream(weakref.proxy(self), handle, None)
@@ -1561,7 +1570,7 @@ class Context(object):
             # default stream, which we define also as CU_STREAM_DEFAULT when
             # the NV binding is in use).
             flags = binding.CUstream_flags.CU_STREAM_DEFAULT.value
-            handle = driver.cuStreamCreate(flags)
+            handle = drvapi.cu_stream(int(driver.cuStreamCreate(flags)))
         else:
             handle = drvapi.cu_stream()
             driver.cuStreamCreate(byref(handle), 0)
@@ -1575,7 +1584,7 @@ class Context(object):
         if not isinstance(ptr, int):
             raise TypeError("ptr for external stream must be an int")
         if USE_NV_BINDING:
-            handle = binding.CUstream(ptr)
+            handle = drvapi.cu_stream(int(binding.CUstream(ptr)))
         else:
             handle = drvapi.cu_stream(ptr)
         return Stream(weakref.proxy(self), handle, None, external=True)
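The change above wraps an externally created stream pointer in a drvapi.cu_stream handle; the public entry point for this path is numba.cuda.external_stream, which requires the raw handle as an int. An illustrative sketch with a placeholder pointer value:

    # Illustrative only: `raw_ptr` stands in for an integer CUstream handle
    # obtained from another library or a C extension.
    from numba import cuda

    raw_ptr = 0  # placeholder; a real handle comes from the foreign library
    s = cuda.external_stream(raw_ptr)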
@@ -1585,7 +1594,7 @@ class Context(object):
         if not timing:
             flags |= enums.CU_EVENT_DISABLE_TIMING
         if USE_NV_BINDING:
-            handle = driver.cuEventCreate(flags)
+            handle = drvapi.cu_event(int(driver.cuEventCreate(flags)))
         else:
             handle = drvapi.cu_event()
             driver.cuEventCreate(byref(handle), flags)
@@ -1776,14 +1785,14 @@ def _pin_finalizer(memory_manager, ptr, alloc_key, mapped):
 
 def _event_finalizer(deallocs, handle):
     def core():
-        deallocs.add_item(driver.cuEventDestroy, handle)
+        deallocs.add_item(driver.cuEventDestroy, handle.value)
 
     return core
 
 
 def _stream_finalizer(deallocs, handle):
     def core():
-        deallocs.add_item(driver.cuStreamDestroy, handle)
+        deallocs.add_item(driver.cuStreamDestroy, handle.value)
 
     return core
 
@@ -2054,6 +2063,9 @@ class MemoryPointer(object):
     __cuda_memory__ = True
 
     def __init__(self, context, pointer, size, owner=None, finalizer=None):
+        if USE_NV_BINDING and isinstance(pointer, ctypes.c_void_p):
+            pointer = binding.CUdeviceptr(pointer.value)
+
         self.context = context
         self.device_pointer = pointer
         self.size = size
@@ -2086,9 +2098,11 @@ class MemoryPointer(object):
     def memset(self, byte, count=None, stream=0):
         count = self.size if count is None else count
         if stream:
-            driver.cuMemsetD8Async(
-                self.device_pointer, byte, count, stream.handle
-            )
+            if USE_NV_BINDING:
+                handle = stream.handle.value
+            else:
+                handle = stream.handle
+            driver.cuMemsetD8Async(self.device_pointer, byte, count, handle)
         else:
             driver.cuMemsetD8(self.device_pointer, byte, count)
 
@@ -2326,27 +2340,16 @@ class Stream(object):
         weakref.finalize(self, finalizer)
 
     def __int__(self):
-        if USE_NV_BINDING:
-            return int(self.handle)
-        else:
-            # The default stream's handle.value is 0, which gives `None`
-            return self.handle.value or drvapi.CU_STREAM_DEFAULT
+        # The default stream's handle.value is 0, which gives `None`
+        return self.handle.value or drvapi.CU_STREAM_DEFAULT
 
     def __repr__(self):
-        if USE_NV_BINDING:
-            default_streams = {
-                CU_STREAM_DEFAULT: "<Default CUDA stream on %s>",
-                binding.CU_STREAM_LEGACY: "<Legacy default CUDA stream on %s>",
-                binding.CU_STREAM_PER_THREAD: "<Per-thread default CUDA stream on %s>",
-            }
-            ptr = int(self.handle) or 0
-        else:
-            default_streams = {
-                drvapi.CU_STREAM_DEFAULT: "<Default CUDA stream on %s>",
-                drvapi.CU_STREAM_LEGACY: "<Legacy default CUDA stream on %s>",
-                drvapi.CU_STREAM_PER_THREAD: "<Per-thread default CUDA stream on %s>",
-            }
-            ptr = self.handle.value or drvapi.CU_STREAM_DEFAULT
+        default_streams = {
+            drvapi.CU_STREAM_DEFAULT: "<Default CUDA stream on %s>",
+            drvapi.CU_STREAM_LEGACY: "<Legacy default CUDA stream on %s>",
+            drvapi.CU_STREAM_PER_THREAD: "<Per-thread default CUDA stream on %s>",
+        }
+        ptr = self.handle.value or drvapi.CU_STREAM_DEFAULT
 
         if ptr in default_streams:
             return default_streams[ptr] % self.context
@@ -2360,7 +2363,11 @@ class Stream(object):
         Wait for all commands in this stream to execute. This will commit any
         pending memory transfers.
         """
-        driver.cuStreamSynchronize(self.handle)
+        if USE_NV_BINDING:
+            handle = self.handle.value
+        else:
+            handle = self.handle
+        driver.cuStreamSynchronize(handle)
 
     @contextlib.contextmanager
     def auto_synchronize(self):
@@ -2385,6 +2392,16 @@ class Stream(object):
         callback will block later work in the stream and may block other
         callbacks from being executed.
 
+        .. warning::
+            There is a potential for deadlock due to a lock ordering issue
+            between the GIL and the CUDA driver lock when using libraries
+            that call CUDA functions without releasing the GIL. This can
+            occur when the callback function, which holds the CUDA driver lock,
+            attempts to acquire the GIL while another thread that holds the GIL
+            is waiting for the CUDA driver lock. Consider using libraries that
+            properly release the GIL around CUDA operations or restructure
+            your code to avoid this situation.
+
         Note: The driver function underlying this method is marked for
         eventual deprecation and may be replaced in a future CUDA release.
 
@@ -2398,9 +2415,11 @@ class Stream(object):
             stream_callback = binding.CUstreamCallback(ptr)
             # The callback needs to receive a pointer to the data PyObject
             data = id(data)
+            handle = self.handle.value
         else:
             stream_callback = self._stream_callback
-        driver.cuStreamAddCallback(self.handle, stream_callback, data, 0)
+            handle = self.handle
+        driver.cuStreamAddCallback(handle, stream_callback, data, 0)
 
     @staticmethod
     @cu_stream_callback_pyobj
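A minimal sketch of the pattern the new docstring warning suggests: keep the host callback free of CUDA calls and long GIL-holding work, and only signal completion to another thread. The kernel launches are elided, and the callback signature (stream, status, arg) is the one numba.cuda passes to stream callbacks.

    import threading
    from numba import cuda

    done = threading.Event()

    def on_done(stream, status, arg):
        # Runs on a driver thread; avoid CUDA calls and heavy work here.
        done.set()

    stream = cuda.stream()
    # ... enqueue kernels / copies on `stream` ...
    stream.add_callback(on_done, None)
    done.wait()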
@@ -2417,6 +2436,16 @@ class Stream(object):
         """
         Return an awaitable that resolves once all preceding stream operations
         are complete. The result of the awaitable is the current stream.
+
+        .. warning::
+            There is a potential for deadlock due to a lock ordering issue
+            between the GIL and the CUDA driver lock when using libraries
+            that call CUDA functions without releasing the GIL. This can
+            occur when the callback function (internally used by this method),
+            which holds the CUDA driver lock, attempts to acquire the GIL
+            while another thread that holds the GIL is waiting for the CUDA driver lock.
+            Consider using libraries that properly release the GIL around
+            CUDA operations or restructure your code to avoid this situation.
         """
         loop = asyncio.get_running_loop()
         future = loop.create_future()
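For reference, an illustrative asyncio use of Stream.async_done(); the same caveat applies because it is implemented with a stream callback internally:

    import asyncio
    from numba import cuda

    async def wait_for_stream(stream):
        # Resolves once all work queued on `stream` has completed.
        await stream.async_done()

    stream = cuda.stream()
    # ... enqueue work on `stream` ...
    asyncio.run(wait_for_stream(stream))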
@@ -2468,27 +2497,35 @@ class Event(object):
         completed.
         """
         if USE_NV_BINDING:
-            hstream = stream.handle if stream else binding.CUstream(0)
+            hstream = stream.handle.value if stream else binding.CUstream(0)
+            handle = self.handle.value
         else:
             hstream = stream.handle if stream else 0
-        driver.cuEventRecord(self.handle, hstream)
+            handle = self.handle
+        driver.cuEventRecord(handle, hstream)
 
     def synchronize(self):
         """
         Synchronize the host thread for the completion of the event.
         """
-        driver.cuEventSynchronize(self.handle)
+        if USE_NV_BINDING:
+            handle = self.handle.value
+        else:
+            handle = self.handle
+        driver.cuEventSynchronize(handle)
 
     def wait(self, stream=0):
         """
         All future works submitted to stream will wait util the event completes.
         """
         if USE_NV_BINDING:
-            hstream = stream.handle if stream else binding.CUstream(0)
+            hstream = stream.handle.value if stream else binding.CUstream(0)
+            handle = self.handle.value
         else:
             hstream = stream.handle if stream else 0
+            handle = self.handle
         flags = 0
-        driver.cuStreamWaitEvent(hstream, self.handle, flags)
+        driver.cuStreamWaitEvent(hstream, handle, flags)
 
     def elapsed_time(self, evtend):
         return event_elapsed_time(self, evtend)
@@ -2499,7 +2536,9 @@ def event_elapsed_time(evtstart, evtend):
     Compute the elapsed time between two events in milliseconds.
     """
     if USE_NV_BINDING:
-        return driver.cuEventElapsedTime(evtstart.handle, evtend.handle)
+        return driver.cuEventElapsedTime(
+            evtstart.handle.value, evtend.handle.value
+        )
     else:
         msec = c_float()
         driver.cuEventElapsedTime(byref(msec), evtstart.handle, evtend.handle)
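These handle changes service the public event-timing API (cuda.event, Event.record, Event.elapsed_time). A short sketch of that pattern; the kernel launch is elided and hypothetical:

    from numba import cuda

    start = cuda.event(timing=True)
    end = cuda.event(timing=True)
    stream = cuda.stream()

    start.record(stream=stream)
    # some_kernel[blocks, threads, stream](...)  # hypothetical launch
    end.record(stream=stream)
    end.synchronize()
    print("elapsed ms:", start.elapsed_time(end))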
@@ -3477,7 +3516,11 @@ def host_to_device(dst, src, size, stream=0):
     if stream:
         assert isinstance(stream, Stream)
         fn = driver.cuMemcpyHtoDAsync
-        varargs.append(stream.handle)
+        if USE_NV_BINDING:
+            handle = stream.handle.value
+        else:
+            handle = stream.handle
+        varargs.append(handle)
     else:
         fn = driver.cuMemcpyHtoD
 
@@ -3495,7 +3538,11 @@ def device_to_host(dst, src, size, stream=0):
     if stream:
         assert isinstance(stream, Stream)
         fn = driver.cuMemcpyDtoHAsync
-        varargs.append(stream.handle)
+        if USE_NV_BINDING:
+            handle = stream.handle.value
+        else:
+            handle = stream.handle
+        varargs.append(handle)
     else:
         fn = driver.cuMemcpyDtoH
 
@@ -3513,7 +3560,11 @@ def device_to_device(dst, src, size, stream=0):
     if stream:
         assert isinstance(stream, Stream)
         fn = driver.cuMemcpyDtoDAsync
-        varargs.append(stream.handle)
+        if USE_NV_BINDING:
+            handle = stream.handle.value
+        else:
+            handle = stream.handle
+        varargs.append(handle)
     else:
         fn = driver.cuMemcpyDtoD
 
@@ -3534,7 +3585,11 @@ def device_memset(dst, val, size, stream=0):
     if stream:
         assert isinstance(stream, Stream)
         fn = driver.cuMemsetD8Async
-        varargs.append(stream.handle)
+        if USE_NV_BINDING:
+            handle = stream.handle.value
+        else:
+            handle = stream.handle
+        varargs.append(handle)
     else:
         fn = driver.cuMemsetD8
 
@@ -14,7 +14,7 @@ from llvmlite import ir
 
 from .error import NvvmError, NvvmSupportError, NvvmWarning
 from .libs import get_libdevice, open_libdevice, open_cudalib
-from numba.core import cgutils
+from numba.cuda import cgutils
 
 
 logger = logging.getLogger(__name__)
@@ -6,15 +6,16 @@ import struct
 from llvmlite import ir
 import llvmlite.binding as ll
 
-from numba.core.imputils import Registry, lower_cast
+from numba.core.imputils import Registry
 from numba.core.typing.npydecl import parse_dtype
 from numba.core.datamodel import models
-from numba.core import types, cgutils
+from numba.core import types
+from numba.cuda import cgutils
 from numba.np import ufunc_db
 from numba.np.npyimpl import register_ufuncs
 from .cudadrv import nvvm
 from numba import cuda
-from numba.cuda import nvvmutils, stubs, errors
+from numba.cuda import nvvmutils, stubs
 from numba.cuda.types import dim3, CUDADispatcher
 
 registry = Registry()
@@ -346,181 +347,6 @@ def ptx_fma(context, builder, sig, args):
     return builder.fma(*args)
 
 
-def float16_float_ty_constraint(bitwidth):
-    typemap = {32: ("f32", "f"), 64: ("f64", "d")}
-
-    try:
-        return typemap[bitwidth]
-    except KeyError:
-        msg = f"Conversion between float16 and float{bitwidth} unsupported"
-        raise errors.CudaLoweringError(msg)
-
-
-@lower_cast(types.float16, types.Float)
-def float16_to_float_cast(context, builder, fromty, toty, val):
-    if fromty.bitwidth == toty.bitwidth:
-        return val
-
-    ty, constraint = float16_float_ty_constraint(toty.bitwidth)
-
-    fnty = ir.FunctionType(context.get_value_type(toty), [ir.IntType(16)])
-    asm = ir.InlineAsm(fnty, f"cvt.{ty}.f16 $0, $1;", f"={constraint},h")
-    return builder.call(asm, [val])
-
-
-@lower_cast(types.Float, types.float16)
-def float_to_float16_cast(context, builder, fromty, toty, val):
-    if fromty.bitwidth == toty.bitwidth:
-        return val
-
-    ty, constraint = float16_float_ty_constraint(fromty.bitwidth)
-
-    fnty = ir.FunctionType(ir.IntType(16), [context.get_value_type(fromty)])
-    asm = ir.InlineAsm(fnty, f"cvt.rn.f16.{ty} $0, $1;", f"=h,{constraint}")
-    return builder.call(asm, [val])
-
-
-def float16_int_constraint(bitwidth):
-    typemap = {8: "c", 16: "h", 32: "r", 64: "l"}
-
-    try:
-        return typemap[bitwidth]
-    except KeyError:
-        msg = f"Conversion between float16 and int{bitwidth} unsupported"
-        raise errors.CudaLoweringError(msg)
-
-
-@lower_cast(types.float16, types.Integer)
-def float16_to_integer_cast(context, builder, fromty, toty, val):
-    bitwidth = toty.bitwidth
-    constraint = float16_int_constraint(bitwidth)
-    signedness = "s" if toty.signed else "u"
-
-    fnty = ir.FunctionType(context.get_value_type(toty), [ir.IntType(16)])
-    asm = ir.InlineAsm(
-        fnty, f"cvt.rni.{signedness}{bitwidth}.f16 $0, $1;", f"={constraint},h"
-    )
-    return builder.call(asm, [val])
-
-
-@lower_cast(types.Integer, types.float16)
-@lower_cast(types.IntegerLiteral, types.float16)
-def integer_to_float16_cast(context, builder, fromty, toty, val):
-    bitwidth = fromty.bitwidth
-    constraint = float16_int_constraint(bitwidth)
-    signedness = "s" if fromty.signed else "u"
-
-    fnty = ir.FunctionType(ir.IntType(16), [context.get_value_type(fromty)])
-    asm = ir.InlineAsm(
-        fnty, f"cvt.rn.f16.{signedness}{bitwidth} $0, $1;", f"=h,{constraint}"
-    )
-    return builder.call(asm, [val])
-
-
-def lower_fp16_binary(fn, op):
-    @lower(fn, types.float16, types.float16)
-    def ptx_fp16_binary(context, builder, sig, args):
-        fnty = ir.FunctionType(ir.IntType(16), [ir.IntType(16), ir.IntType(16)])
-        asm = ir.InlineAsm(fnty, f"{op}.f16 $0,$1,$2;", "=h,h,h")
-        return builder.call(asm, args)
-
-
-lower_fp16_binary(stubs.fp16.hadd, "add")
-lower_fp16_binary(operator.add, "add")
-lower_fp16_binary(operator.iadd, "add")
-lower_fp16_binary(stubs.fp16.hsub, "sub")
-lower_fp16_binary(operator.sub, "sub")
-lower_fp16_binary(operator.isub, "sub")
-lower_fp16_binary(stubs.fp16.hmul, "mul")
-lower_fp16_binary(operator.mul, "mul")
-lower_fp16_binary(operator.imul, "mul")
-
-
-@lower(stubs.fp16.hneg, types.float16)
-def ptx_fp16_hneg(context, builder, sig, args):
-    fnty = ir.FunctionType(ir.IntType(16), [ir.IntType(16)])
-    asm = ir.InlineAsm(fnty, "neg.f16 $0, $1;", "=h,h")
-    return builder.call(asm, args)
-
-
-@lower(operator.neg, types.float16)
-def operator_hneg(context, builder, sig, args):
-    return ptx_fp16_hneg(context, builder, sig, args)
-
-
-@lower(stubs.fp16.habs, types.float16)
-def ptx_fp16_habs(context, builder, sig, args):
-    fnty = ir.FunctionType(ir.IntType(16), [ir.IntType(16)])
-    asm = ir.InlineAsm(fnty, "abs.f16 $0, $1;", "=h,h")
-    return builder.call(asm, args)
-
-
-@lower(abs, types.float16)
-def operator_habs(context, builder, sig, args):
-    return ptx_fp16_habs(context, builder, sig, args)
-
-
-@lower(stubs.fp16.hfma, types.float16, types.float16, types.float16)
-def ptx_hfma(context, builder, sig, args):
-    argtys = [ir.IntType(16), ir.IntType(16), ir.IntType(16)]
-    fnty = ir.FunctionType(ir.IntType(16), argtys)
-    asm = ir.InlineAsm(fnty, "fma.rn.f16 $0,$1,$2,$3;", "=h,h,h,h")
-    return builder.call(asm, args)
-
-
-@lower(operator.truediv, types.float16, types.float16)
-@lower(operator.itruediv, types.float16, types.float16)
-def fp16_div_impl(context, builder, sig, args):
-    def fp16_div(x, y):
-        return cuda.fp16.hdiv(x, y)
-
-    return context.compile_internal(builder, fp16_div, sig, args)
-
-
-_fp16_cmp = """{{
-    .reg .pred __$$f16_cmp_tmp;
-    setp.{op}.f16 __$$f16_cmp_tmp, $1, $2;
-    selp.u16 $0, 1, 0, __$$f16_cmp_tmp;
-}}"""
-
-
-def _gen_fp16_cmp(op):
-    def ptx_fp16_comparison(context, builder, sig, args):
-        fnty = ir.FunctionType(ir.IntType(16), [ir.IntType(16), ir.IntType(16)])
-        asm = ir.InlineAsm(fnty, _fp16_cmp.format(op=op), "=h,h,h")
-        result = builder.call(asm, args)
-
-        zero = context.get_constant(types.int16, 0)
-        int_result = builder.bitcast(result, ir.IntType(16))
-        return builder.icmp_unsigned("!=", int_result, zero)
-
-    return ptx_fp16_comparison
-
-
-lower(stubs.fp16.heq, types.float16, types.float16)(_gen_fp16_cmp("eq"))
-lower(operator.eq, types.float16, types.float16)(_gen_fp16_cmp("eq"))
-lower(stubs.fp16.hne, types.float16, types.float16)(_gen_fp16_cmp("ne"))
-lower(operator.ne, types.float16, types.float16)(_gen_fp16_cmp("ne"))
-lower(stubs.fp16.hge, types.float16, types.float16)(_gen_fp16_cmp("ge"))
-lower(operator.ge, types.float16, types.float16)(_gen_fp16_cmp("ge"))
-lower(stubs.fp16.hgt, types.float16, types.float16)(_gen_fp16_cmp("gt"))
-lower(operator.gt, types.float16, types.float16)(_gen_fp16_cmp("gt"))
-lower(stubs.fp16.hle, types.float16, types.float16)(_gen_fp16_cmp("le"))
-lower(operator.le, types.float16, types.float16)(_gen_fp16_cmp("le"))
-lower(stubs.fp16.hlt, types.float16, types.float16)(_gen_fp16_cmp("lt"))
-lower(operator.lt, types.float16, types.float16)(_gen_fp16_cmp("lt"))
-
-
-def lower_fp16_minmax(fn, fname, op):
-    @lower(fn, types.float16, types.float16)
-    def ptx_fp16_minmax(context, builder, sig, args):
-        choice = _gen_fp16_cmp(op)(context, builder, sig, args)
-        return builder.select(choice, args[0], args[1])
-
-
-lower_fp16_minmax(stubs.fp16.hmax, "max", "gt")
-lower_fp16_minmax(stubs.fp16.hmin, "min", "lt")
-
 # See:
 # https://docs.nvidia.com/cuda/libdevice-users-guide/__nv_cbrt.html#__nv_cbrt
 # https://docs.nvidia.com/cuda/libdevice-users-guide/__nv_cbrtf.html#__nv_cbrtf
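The float16 cast, arithmetic and comparison lowerings removed above appear to be superseded by the fp16 support added elsewhere in this release (fp16.py and _internal/cuda_fp16.py in the file list). A minimal sketch of user-facing float16 arithmetic in a kernel, not taken from the diff:

    import numpy as np
    from numba import cuda

    @cuda.jit
    def scale_add(x, y, out):
        i = cuda.grid(1)
        if i < out.size:
            # float16 arithmetic is lowered by the CUDA target
            out[i] = x[i] * y[i] + x[i]

    n = 1024
    x = np.linspace(0, 1, n).astype(np.float16)
    y = np.ones(n, dtype=np.float16)
    out = np.zeros(n, dtype=np.float16)
    scale_add[(n + 127) // 128, 128](x, y, out)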