numba-cuda 0.16.0__py3-none-any.whl → 0.18.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (70):
  1. numba_cuda/VERSION +1 -1
  2. numba_cuda/numba/cuda/__init__.py +0 -8
  3. numba_cuda/numba/cuda/_internal/cuda_fp16.py +14225 -0
  4. numba_cuda/numba/cuda/api_util.py +6 -0
  5. numba_cuda/numba/cuda/cgutils.py +1291 -0
  6. numba_cuda/numba/cuda/codegen.py +32 -14
  7. numba_cuda/numba/cuda/compiler.py +113 -10
  8. numba_cuda/numba/cuda/core/caching.py +741 -0
  9. numba_cuda/numba/cuda/core/callconv.py +338 -0
  10. numba_cuda/numba/cuda/core/codegen.py +168 -0
  11. numba_cuda/numba/cuda/core/compiler.py +205 -0
  12. numba_cuda/numba/cuda/core/typed_passes.py +139 -0
  13. numba_cuda/numba/cuda/cuda_paths.py +1 -1
  14. numba_cuda/numba/cuda/cudadecl.py +0 -268
  15. numba_cuda/numba/cuda/cudadrv/devicearray.py +3 -0
  16. numba_cuda/numba/cuda/cudadrv/devices.py +4 -6
  17. numba_cuda/numba/cuda/cudadrv/driver.py +105 -50
  18. numba_cuda/numba/cuda/cudadrv/nvvm.py +1 -1
  19. numba_cuda/numba/cuda/cudaimpl.py +4 -178
  20. numba_cuda/numba/cuda/debuginfo.py +469 -3
  21. numba_cuda/numba/cuda/device_init.py +0 -1
  22. numba_cuda/numba/cuda/dispatcher.py +311 -14
  23. numba_cuda/numba/cuda/extending.py +2 -1
  24. numba_cuda/numba/cuda/fp16.py +348 -0
  25. numba_cuda/numba/cuda/intrinsics.py +1 -1
  26. numba_cuda/numba/cuda/libdeviceimpl.py +2 -1
  27. numba_cuda/numba/cuda/lowering.py +1833 -8
  28. numba_cuda/numba/cuda/mathimpl.py +2 -90
  29. numba_cuda/numba/cuda/memory_management/nrt.py +1 -1
  30. numba_cuda/numba/cuda/nvvmutils.py +2 -1
  31. numba_cuda/numba/cuda/printimpl.py +2 -1
  32. numba_cuda/numba/cuda/serialize.py +264 -0
  33. numba_cuda/numba/cuda/simulator/__init__.py +2 -0
  34. numba_cuda/numba/cuda/simulator/dispatcher.py +7 -0
  35. numba_cuda/numba/cuda/stubs.py +0 -308
  36. numba_cuda/numba/cuda/target.py +13 -5
  37. numba_cuda/numba/cuda/testing.py +156 -5
  38. numba_cuda/numba/cuda/tests/complex_usecases.py +113 -0
  39. numba_cuda/numba/cuda/tests/core/serialize_usecases.py +110 -0
  40. numba_cuda/numba/cuda/tests/core/test_serialize.py +359 -0
  41. numba_cuda/numba/cuda/tests/cudadrv/test_context_stack.py +16 -5
  42. numba_cuda/numba/cuda/tests/cudadrv/test_cuda_driver.py +5 -1
  43. numba_cuda/numba/cuda/tests/cudadrv/test_cuda_ndarray.py +33 -0
  44. numba_cuda/numba/cuda/tests/cudadrv/test_emm_plugins.py +1 -1
  45. numba_cuda/numba/cuda/tests/cudadrv/test_runtime.py +2 -2
  46. numba_cuda/numba/cuda/tests/cudadrv/test_streams.py +1 -0
  47. numba_cuda/numba/cuda/tests/cudapy/extensions_usecases.py +1 -1
  48. numba_cuda/numba/cuda/tests/cudapy/test_caching.py +5 -10
  49. numba_cuda/numba/cuda/tests/cudapy/test_complex.py +1 -1
  50. numba_cuda/numba/cuda/tests/cudapy/test_cuda_array_interface.py +1 -5
  51. numba_cuda/numba/cuda/tests/cudapy/test_debuginfo.py +381 -0
  52. numba_cuda/numba/cuda/tests/cudapy/test_enums.py +1 -1
  53. numba_cuda/numba/cuda/tests/cudapy/test_extending.py +1 -1
  54. numba_cuda/numba/cuda/tests/cudapy/test_inspect.py +94 -24
  55. numba_cuda/numba/cuda/tests/cudapy/test_intrinsics.py +37 -23
  56. numba_cuda/numba/cuda/tests/cudapy/test_operator.py +43 -27
  57. numba_cuda/numba/cuda/tests/cudapy/test_stream_api.py +2 -5
  58. numba_cuda/numba/cuda/tests/cudapy/test_ufuncs.py +26 -9
  59. numba_cuda/numba/cuda/tests/cudapy/test_warning.py +27 -2
  60. numba_cuda/numba/cuda/tests/enum_usecases.py +56 -0
  61. numba_cuda/numba/cuda/tests/nocuda/test_library_lookup.py +1 -2
  62. numba_cuda/numba/cuda/tests/nocuda/test_nvvm.py +1 -1
  63. numba_cuda/numba/cuda/utils.py +785 -0
  64. numba_cuda/numba/cuda/vector_types.py +1 -1
  65. {numba_cuda-0.16.0.dist-info → numba_cuda-0.18.0.dist-info}/METADATA +18 -4
  66. {numba_cuda-0.16.0.dist-info → numba_cuda-0.18.0.dist-info}/RECORD +69 -56
  67. numba_cuda/numba/cuda/cpp_function_wrappers.cu +0 -46
  68. {numba_cuda-0.16.0.dist-info → numba_cuda-0.18.0.dist-info}/WHEEL +0 -0
  69. {numba_cuda-0.16.0.dist-info → numba_cuda-0.18.0.dist-info}/licenses/LICENSE +0 -0
  70. {numba_cuda-0.16.0.dist-info → numba_cuda-0.18.0.dist-info}/top_level.txt +0 -0
@@ -2,13 +2,19 @@ import numpy as np
2
2
  import os
3
3
  import sys
4
4
  import ctypes
5
+ import collections
5
6
  import functools
7
+ import types as pytypes
8
+ import weakref
9
+ import uuid
6
10
 
7
- from numba.core import config, serialize, sigutils, types, typing, utils
8
- from numba.core.caching import Cache, CacheImpl
11
+ from numba.core import compiler, sigutils, types, typing, config
12
+ from numba.cuda import serialize, utils
13
+ from numba.cuda.core.caching import Cache, CacheImpl, NullCache
9
14
  from numba.core.compiler_lock import global_compiler_lock
10
- from numba.core.dispatcher import Dispatcher
11
- from numba.core.errors import NumbaPerformanceWarning
15
+ from numba.core.dispatcher import _DispatcherBase
16
+ from numba.core.errors import NumbaPerformanceWarning, TypingError
17
+ from numba.core.typing.templates import fold_arguments
12
18
  from numba.core.typing.typeof import Purpose, typeof
13
19
  from numba.cuda.api import get_current_device
14
20
  from numba.cuda.args import wrap_arg
@@ -185,10 +191,6 @@ class _Kernel(serialize.ReduceMixin):
185
191
 
186
192
  # Link to the helper library functions if needed
187
193
  link_to_library_functions(reshape_funcs, "reshape_funcs.cu")
188
- # Link to the CUDA FP16 math library functions if needed
189
- link_to_library_functions(
190
- cuda_fp16_math_funcs, "cpp_function_wrappers.cu", "__numba_wrapper_"
191
- )
192
194
 
193
195
  self.maybe_link_nrt(link, tgt_ctx, asm)
194
196
 
@@ -384,6 +386,12 @@ class _Kernel(serialize.ReduceMixin):
384
386
  """
385
387
  return self._codelibrary.get_asm_str(cc=cc)
386
388
 
389
def inspect_lto_ptx(self, cc):
    """
    Return the result of ``get_lto_ptx`` on this kernel's code library.

    :param cc: Compute capability, forwarded unchanged as the ``cc``
        keyword argument of ``get_lto_ptx``.
    :return: Whatever the code library's ``get_lto_ptx`` returns; presumably
        the link-time optimized PTX text (TODO confirm against the code
        library implementation).
    """
    library = self._codelibrary
    return library.get_lto_ptx(cc=cc)
394
+
387
395
  def inspect_sass_cfg(self):
388
396
  """
389
397
  Returns the CFG of the SASS for this kernel.
@@ -458,11 +466,10 @@ class _Kernel(serialize.ReduceMixin):
458
466
  self._prepare_args(t, v, stream, retr, kernelargs)
459
467
 
460
468
  if driver.USE_NV_BINDING:
461
- zero_stream = driver.binding.CUstream(0)
469
+ stream_handle = stream and stream.handle.value or 0
462
470
  else:
463
471
  zero_stream = None
464
-
465
- stream_handle = stream and stream.handle or zero_stream
472
+ stream_handle = stream and stream.handle or zero_stream
466
473
 
467
474
  # Invoke kernel
468
475
  driver.launch_kernel(
@@ -726,7 +733,134 @@ class CUDACache(Cache):
726
733
  return super().load_overload(sig, target_context)
727
734
 
728
735
 
729
- class CUDADispatcher(Dispatcher, serialize.ReduceMixin):
736
class _MemoMixin:
    """
    Mixin that gives each instance a lazily created UUID string.

    On assignment, the instance is registered in a class-wide memo keyed by
    that UUID, so deserialization can find an existing instance instead of
    building a duplicate.
    """

    # Per-instance UUID string (name-mangled to ``_MemoMixin__uuid``);
    # stays ``None`` until first requested or explicitly set.
    __uuid = None
    # uuid -> instance. Values are weak: an entry disappears once nothing
    # else references the instance.
    _memo = weakref.WeakValueDictionary()
    # Strong references to the most recently registered instances, bounded
    # by config.FUNCTION_CACHE_SIZE. They keep those entries alive in
    # ``_memo`` even when no other reference exists.
    _recent = collections.deque(maxlen=config.FUNCTION_CACHE_SIZE)

    @property
    def _uuid(self):
        """
        Return this instance's UUID string.

        On first access a new UUID is generated and registered via
        ``_set_uuid``. Generation is deferred until needed, for performance
        reasons.
        """
        existing = self.__uuid
        if existing is not None:
            return existing
        fresh = str(uuid.uuid4())
        self._set_uuid(fresh)
        return fresh

    def _set_uuid(self, u):
        """
        Assign UUID ``u`` to this instance and register it in ``_memo`` and
        ``_recent``.

        Asserts that no UUID has been assigned yet.
        """
        assert self.__uuid is None
        self.__uuid = u
        self._memo[u] = self
        self._recent.append(self)
763
+
764
+
765
+ _CompileStats = collections.namedtuple(
766
+ "_CompileStats", ("cache_path", "cache_hits", "cache_misses")
767
+ )
768
+
769
+
770
+ class _FunctionCompiler(object):
771
+ def __init__(self, py_func, targetdescr, targetoptions, pipeline_class):
772
+ self.py_func = py_func
773
+ self.targetdescr = targetdescr
774
+ self.targetoptions = targetoptions
775
+ self.pysig = utils.pysignature(self.py_func)
776
+ self.pipeline_class = pipeline_class
777
+ # Remember key=(args, return_type) combinations that will fail
778
+ # compilation to avoid compilation attempt on them. The values are
779
+ # the exceptions.
780
+ self._failed_cache = {}
781
+
782
+ def fold_argument_types(self, args, kws):
783
+ """
784
+ Given positional and named argument types, fold keyword arguments
785
+ and resolve defaults by inserting types.Omitted() instances.
786
+
787
+ A (pysig, argument types) tuple is returned.
788
+ """
789
+
790
+ def normal_handler(index, param, value):
791
+ return value
792
+
793
+ def default_handler(index, param, default):
794
+ return types.Omitted(default)
795
+
796
+ def stararg_handler(index, param, values):
797
+ return types.StarArgTuple(values)
798
+
799
+ # For now, we take argument values from the @jit function
800
+ args = fold_arguments(
801
+ self.pysig,
802
+ args,
803
+ kws,
804
+ normal_handler,
805
+ default_handler,
806
+ stararg_handler,
807
+ )
808
+ return self.pysig, args
809
+
810
+ def compile(self, args, return_type):
811
+ status, retval = self._compile_cached(args, return_type)
812
+ if status:
813
+ return retval
814
+ else:
815
+ raise retval
816
+
817
+ def _compile_cached(self, args, return_type):
818
+ key = tuple(args), return_type
819
+ try:
820
+ return False, self._failed_cache[key]
821
+ except KeyError:
822
+ pass
823
+
824
+ try:
825
+ retval = self._compile_core(args, return_type)
826
+ except TypingError as e:
827
+ self._failed_cache[key] = e
828
+ return False, e
829
+ else:
830
+ return True, retval
831
+
832
+ def _compile_core(self, args, return_type):
833
+ flags = compiler.Flags()
834
+ self.targetdescr.options.parse_as_flags(flags, self.targetoptions)
835
+ flags = self._customize_flags(flags)
836
+
837
+ impl = self._get_implementation(args, {})
838
+ cres = compiler.compile_extra(
839
+ self.targetdescr.typing_context,
840
+ self.targetdescr.target_context,
841
+ impl,
842
+ args=args,
843
+ return_type=return_type,
844
+ flags=flags,
845
+ locals={},
846
+ pipeline_class=self.pipeline_class,
847
+ )
848
+ # Check typing error if object mode is used
849
+ if cres.typing_error is not None and not flags.enable_pyobject:
850
+ raise cres.typing_error
851
+ return cres
852
+
853
+ def get_globals_for_reduction(self):
854
+ return serialize._get_function_globals_for_reduction(self.py_func)
855
+
856
+ def _get_implementation(self, args, kws):
857
+ return self.py_func
858
+
859
+ def _customize_flags(self, flags):
860
+ return flags
861
+
862
+
863
+ class CUDADispatcher(serialize.ReduceMixin, _MemoMixin, _DispatcherBase):
730
864
  """
731
865
  CUDA Dispatcher object. When configured and called, the dispatcher will
732
866
  specialize itself for the given arguments (if no suitable specialized
@@ -745,10 +879,42 @@ class CUDADispatcher(Dispatcher, serialize.ReduceMixin):
745
879
  targetdescr = cuda_target
746
880
 
747
881
  def __init__(self, py_func, targetoptions, pipeline_class=CUDACompiler):
748
- super().__init__(
749
- py_func, targetoptions=targetoptions, pipeline_class=pipeline_class
882
+ """
883
+ Parameters
884
+ ----------
885
+ py_func: function object to be compiled
886
+ targetoptions: dict, optional
887
+ Target-specific config options.
888
+ pipeline_class: type numba.compiler.CompilerBase
889
+ The compiler pipeline type.
890
+ """
891
+ self.typingctx = self.targetdescr.typing_context
892
+ self.targetctx = self.targetdescr.target_context
893
+
894
+ pysig = utils.pysignature(py_func)
895
+ arg_count = len(pysig.parameters)
896
+ can_fallback = not targetoptions.get("nopython", False)
897
+
898
+ _DispatcherBase.__init__(
899
+ self,
900
+ arg_count,
901
+ py_func,
902
+ pysig,
903
+ can_fallback,
904
+ exact_match_required=False,
750
905
  )
751
906
 
907
+ functools.update_wrapper(self, py_func)
908
+
909
+ self.targetoptions = targetoptions
910
+ self._cache = NullCache()
911
+ compiler_class = _FunctionCompiler
912
+ self._compiler = compiler_class(
913
+ py_func, self.targetdescr, targetoptions, pipeline_class
914
+ )
915
+ self._cache_hits = collections.Counter()
916
+ self._cache_misses = collections.Counter()
917
+
752
918
  # The following properties are for specialization of CUDADispatchers. A
753
919
  # specialized CUDADispatcher is one that is compiled for exactly one
754
920
  # set of argument types, and bypasses some argument type checking for
@@ -761,6 +927,15 @@ class CUDADispatcher(Dispatcher, serialize.ReduceMixin):
761
927
  # argument types
762
928
  self.specializations = {}
763
929
 
930
def dump(self, tab=""):
    """
    Print a debugging dump of this dispatcher and each of its overloads.

    :param tab: Prefix printed before the header and footer lines. Each
        overload's ``dump`` receives it with extra indentation appended.
    """
    kind = type(self).__name__
    name = self.py_func.__name__
    print(f"{tab}DUMP {kind}[{name}, type code={self._type._code}]")
    for overload in self.overloads.values():
        overload.dump(tab=tab + " ")
    print(f"{tab}END DUMP {kind}[{name}]")
938
+
764
939
  @property
765
940
  def _numba_type_(self):
766
941
  return cuda_types.CUDADispatcher(self)
@@ -768,6 +943,13 @@ class CUDADispatcher(Dispatcher, serialize.ReduceMixin):
768
943
  def enable_caching(self):
769
944
  self._cache = CUDACache(self.py_func)
770
945
 
946
+ def __get__(self, obj, objtype=None):
947
+ """Allow a JIT function to be bound as a method to an object"""
948
+ if obj is None: # Unbound method
949
+ return self
950
+ else: # Bound method
951
+ return pytypes.MethodType(self, obj)
952
+
771
953
  @functools.lru_cache(maxsize=128)
772
954
  def configure(self, griddim, blockdim, stream=0, sharedmem=0):
773
955
  griddim, blockdim = normalize_kernel_dimensions(griddim, blockdim)
@@ -1115,6 +1297,93 @@ class CUDADispatcher(Dispatcher, serialize.ReduceMixin):
1115
1297
 
1116
1298
  return kernel
1117
1299
 
1300
def get_compile_result(self, sig):
    """
    Return the compilation result for ``sig``, compiling it first if needed.

    :param sig: Signature whose ``args`` select the overload.
    :return: The ``CompileResult`` stored in ``self.overloads``.
    :raises TypingError: If the overload is missing and compilation is
        disabled. Compiling may itself raise other Numba errors.
    """
    arg_types = tuple(sig.args)
    if arg_types in self.overloads:
        return self.overloads[arg_types]
    if not self._can_compile:
        raise TypingError(f"{sig} not available and compilation disabled")
    # Compiling may raise any NumbaError
    self.compile(arg_types)
    return self.overloads[arg_types]
1316
+
1317
def recompile(self):
    """
    Discard every compiled overload and compile all known signatures again.

    Compilation is enabled for the rebuild. The previous ``_can_compile``
    setting is restored afterwards, even if a compilation fails.
    """
    signatures = list(self.overloads)
    saved_can_compile = self._can_compile
    # Run the finalizer first so the old overloads, including their
    # compiled functions, are disposed of before anything is rebuilt.
    finalizer = self._make_finalizer()
    finalizer()
    self._reset_overloads()
    self._cache.flush()
    self._can_compile = True
    try:
        for signature in signatures:
            self.compile(signature)
    finally:
        self._can_compile = saved_can_compile
1334
+
1335
@property
def stats(self):
    """
    Cache statistics for this dispatcher.

    Returns a ``_CompileStats`` built from the cache's ``cache_path`` and
    the dispatcher's hit and miss counters.
    """
    return _CompileStats(
        self._cache.cache_path, self._cache_hits, self._cache_misses
    )
1342
+
1343
def parallel_diagnostics(self, signature=None, level=1):
    """
    Print parallel diagnostic information for the given signature.

    If ``signature`` is ``None``, the diagnostics for every signature in
    ``self.signatures`` are printed. ``level`` adjusts the verbosity:
    1 (default) is minimal, and 2, 3 and 4 are increasingly verbose.

    :raises ValueError: If a selected overload has no
        ``parfor_diagnostics`` entry in its metadata.
    """

    def _dump_one(sig):
        # Print the diagnostics recorded for a single compiled overload.
        ol = self.overloads[sig]
        pfdiag = ol.metadata.get("parfor_diagnostics", None)
        if pfdiag is None:
            msg = "No parfors diagnostic available, is 'parallel=True' set?"
            raise ValueError(msg)
        pfdiag.dump(level)

    if signature is not None:
        _dump_one(signature)
    else:
        # The calls are made only for their side effects, so use a plain
        # loop rather than building and discarding a list.
        for sig in self.signatures:
            _dump_one(sig)
1363
+
1364
def get_metadata(self, signature=None):
    """
    Obtain the compilation metadata for a given signature.

    :param signature: Signature to look up, or ``None`` to collect the
        metadata for every signature in ``self.signatures``.
    :return: The overload's ``metadata`` object, or a dict mapping each
        signature to its overload's metadata.
    """
    if signature is not None:
        return self.overloads[signature].metadata
    # A dict comprehension replaces dict(<generator of pairs>).
    return {sig: self.overloads[sig].metadata for sig in self.signatures}
1374
+
1375
def get_function_type(self):
    """Return the unique function type of this dispatcher, or ``None``.

    A Dispatcher instance has a unique function type when it contains
    exactly one compilation result and its compilation has been disabled
    (via its disable_compile method).
    """
    if self._can_compile or len(self.overloads) != 1:
        # Explicit return keeps every exit path returning a value.
        return None
    # Unpack the single compile result without building a temporary tuple.
    (cres,) = self.overloads.values()
    return types.FunctionType(cres.signature)
1386
+
1118
1387
  def inspect_llvm(self, signature=None):
1119
1388
  """
1120
1389
  Return the LLVM IR for this kernel.
@@ -1170,6 +1439,34 @@ class CUDADispatcher(Dispatcher, serialize.ReduceMixin):
1170
1439
  for sig, overload in self.overloads.items()
1171
1440
  }
1172
1441
 
1442
def inspect_lto_ptx(self, signature=None):
    """
    Return link-time optimized PTX code for one or all signatures.

    The PTX is requested for the compute capability of the current device.
    For device functions (``targetoptions["device"]`` truthy) it is read
    from each overload's ``library``. Otherwise the overload's own
    ``inspect_lto_ptx`` is used.

    :param signature: A tuple of argument types, or ``None``.
    :return: The PTX code for the given signature, or a dict of PTX codes
        for all previously-encountered signatures.
    """
    cc = get_current_device().compute_capability

    # Choose the PTX extractor once instead of branching on every path.
    if self.targetoptions.get("device"):

        def lto_ptx(overload):
            return overload.library.get_lto_ptx(cc)

    else:

        def lto_ptx(overload):
            return overload.inspect_lto_ptx(cc)

    if signature is None:
        return {sig: lto_ptx(ol) for sig, ol in self.overloads.items()}
    return lto_ptx(self.overloads[signature])
+ }
1469
+
1173
1470
  def inspect_sass_cfg(self, signature=None):
1174
1471
  """
1175
1472
  Return this kernel's CFG for the device in the current context.
@@ -23,7 +23,8 @@ def make_attribute_wrapper(typeclass, struct_attr, python_attr):
23
23
  from numba.core.datamodel import default_manager
24
24
  from numba.core.datamodel.models import StructModel
25
25
  from numba.core.imputils import impl_ret_borrowed
26
- from numba.core import cgutils, types
26
+ from numba.core import types
27
+ from numba.cuda import cgutils
27
28
 
28
29
  from numba.cuda.models import cuda_data_manager
29
30
  from numba.cuda.cudadecl import registry as cuda_registry