numba-cuda 0.19.1__py3-none-any.whl → 0.20.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of numba-cuda might be problematic. Click here for more details.
- numba_cuda/VERSION +1 -1
- numba_cuda/numba/cuda/__init__.py +1 -1
- numba_cuda/numba/cuda/_internal/cuda_bf16.py +12706 -1470
- numba_cuda/numba/cuda/_internal/cuda_fp16.py +2653 -8769
- numba_cuda/numba/cuda/api.py +6 -1
- numba_cuda/numba/cuda/bf16.py +285 -2
- numba_cuda/numba/cuda/cgutils.py +2 -2
- numba_cuda/numba/cuda/cloudpickle/__init__.py +21 -0
- numba_cuda/numba/cuda/cloudpickle/cloudpickle.py +1598 -0
- numba_cuda/numba/cuda/cloudpickle/cloudpickle_fast.py +17 -0
- numba_cuda/numba/cuda/codegen.py +1 -1
- numba_cuda/numba/cuda/compiler.py +373 -30
- numba_cuda/numba/cuda/core/analysis.py +319 -0
- numba_cuda/numba/cuda/core/annotations/__init__.py +0 -0
- numba_cuda/numba/cuda/core/annotations/type_annotations.py +304 -0
- numba_cuda/numba/cuda/core/base.py +1289 -0
- numba_cuda/numba/cuda/core/bytecode.py +727 -0
- numba_cuda/numba/cuda/core/caching.py +2 -2
- numba_cuda/numba/cuda/core/compiler.py +6 -14
- numba_cuda/numba/cuda/core/compiler_machinery.py +497 -0
- numba_cuda/numba/cuda/core/config.py +747 -0
- numba_cuda/numba/cuda/core/consts.py +124 -0
- numba_cuda/numba/cuda/core/cpu.py +370 -0
- numba_cuda/numba/cuda/core/environment.py +68 -0
- numba_cuda/numba/cuda/core/event.py +511 -0
- numba_cuda/numba/cuda/core/funcdesc.py +330 -0
- numba_cuda/numba/cuda/core/inline_closurecall.py +1889 -0
- numba_cuda/numba/cuda/core/interpreter.py +48 -26
- numba_cuda/numba/cuda/core/ir_utils.py +15 -26
- numba_cuda/numba/cuda/core/options.py +262 -0
- numba_cuda/numba/cuda/core/postproc.py +249 -0
- numba_cuda/numba/cuda/core/pythonapi.py +1868 -0
- numba_cuda/numba/cuda/core/rewrites/__init__.py +26 -0
- numba_cuda/numba/cuda/core/rewrites/ir_print.py +90 -0
- numba_cuda/numba/cuda/core/rewrites/registry.py +104 -0
- numba_cuda/numba/cuda/core/rewrites/static_binop.py +40 -0
- numba_cuda/numba/cuda/core/rewrites/static_getitem.py +187 -0
- numba_cuda/numba/cuda/core/rewrites/static_raise.py +98 -0
- numba_cuda/numba/cuda/core/ssa.py +496 -0
- numba_cuda/numba/cuda/core/targetconfig.py +329 -0
- numba_cuda/numba/cuda/core/tracing.py +231 -0
- numba_cuda/numba/cuda/core/transforms.py +952 -0
- numba_cuda/numba/cuda/core/typed_passes.py +738 -7
- numba_cuda/numba/cuda/core/typeinfer.py +1948 -0
- numba_cuda/numba/cuda/core/unsafe/__init__.py +0 -0
- numba_cuda/numba/cuda/core/unsafe/bytes.py +67 -0
- numba_cuda/numba/cuda/core/unsafe/eh.py +66 -0
- numba_cuda/numba/cuda/core/unsafe/refcount.py +98 -0
- numba_cuda/numba/cuda/core/untyped_passes.py +1983 -0
- numba_cuda/numba/cuda/cpython/cmathimpl.py +560 -0
- numba_cuda/numba/cuda/cpython/mathimpl.py +499 -0
- numba_cuda/numba/cuda/cpython/numbers.py +1474 -0
- numba_cuda/numba/cuda/cuda_paths.py +422 -246
- numba_cuda/numba/cuda/cudadecl.py +1 -1
- numba_cuda/numba/cuda/cudadrv/__init__.py +1 -1
- numba_cuda/numba/cuda/cudadrv/devicearray.py +2 -1
- numba_cuda/numba/cuda/cudadrv/driver.py +11 -140
- numba_cuda/numba/cuda/cudadrv/dummyarray.py +111 -24
- numba_cuda/numba/cuda/cudadrv/libs.py +5 -5
- numba_cuda/numba/cuda/cudadrv/mappings.py +1 -1
- numba_cuda/numba/cuda/cudadrv/nvrtc.py +19 -8
- numba_cuda/numba/cuda/cudadrv/nvvm.py +1 -4
- numba_cuda/numba/cuda/cudadrv/runtime.py +1 -1
- numba_cuda/numba/cuda/cudaimpl.py +5 -1
- numba_cuda/numba/cuda/debuginfo.py +85 -2
- numba_cuda/numba/cuda/decorators.py +3 -3
- numba_cuda/numba/cuda/descriptor.py +3 -4
- numba_cuda/numba/cuda/deviceufunc.py +66 -2
- numba_cuda/numba/cuda/dispatcher.py +18 -39
- numba_cuda/numba/cuda/flags.py +141 -1
- numba_cuda/numba/cuda/fp16.py +0 -2
- numba_cuda/numba/cuda/include/13/cuda_bf16.h +5118 -0
- numba_cuda/numba/cuda/include/13/cuda_bf16.hpp +3865 -0
- numba_cuda/numba/cuda/include/13/cuda_fp16.h +5363 -0
- numba_cuda/numba/cuda/include/13/cuda_fp16.hpp +3483 -0
- numba_cuda/numba/cuda/lowering.py +7 -144
- numba_cuda/numba/cuda/mathimpl.py +2 -1
- numba_cuda/numba/cuda/memory_management/nrt.py +43 -17
- numba_cuda/numba/cuda/misc/findlib.py +75 -0
- numba_cuda/numba/cuda/models.py +9 -1
- numba_cuda/numba/cuda/np/npdatetime_helpers.py +217 -0
- numba_cuda/numba/cuda/np/npyfuncs.py +1807 -0
- numba_cuda/numba/cuda/np/numpy_support.py +553 -0
- numba_cuda/numba/cuda/np/ufunc/ufuncbuilder.py +59 -0
- numba_cuda/numba/cuda/nvvmutils.py +1 -1
- numba_cuda/numba/cuda/printimpl.py +12 -1
- numba_cuda/numba/cuda/random.py +1 -1
- numba_cuda/numba/cuda/serialize.py +1 -1
- numba_cuda/numba/cuda/simulator/__init__.py +1 -1
- numba_cuda/numba/cuda/simulator/api.py +1 -1
- numba_cuda/numba/cuda/simulator/compiler.py +4 -0
- numba_cuda/numba/cuda/simulator/cudadrv/devicearray.py +1 -1
- numba_cuda/numba/cuda/simulator/kernelapi.py +1 -1
- numba_cuda/numba/cuda/simulator/memory_management/nrt.py +14 -2
- numba_cuda/numba/cuda/target.py +35 -17
- numba_cuda/numba/cuda/testing.py +7 -19
- numba_cuda/numba/cuda/tests/__init__.py +1 -1
- numba_cuda/numba/cuda/tests/cloudpickle_main_class.py +9 -0
- numba_cuda/numba/cuda/tests/core/test_serialize.py +4 -4
- numba_cuda/numba/cuda/tests/cudadrv/test_cuda_devicerecord.py +1 -1
- numba_cuda/numba/cuda/tests/cudadrv/test_cuda_libraries.py +1 -1
- numba_cuda/numba/cuda/tests/cudadrv/test_deallocations.py +1 -1
- numba_cuda/numba/cuda/tests/cudadrv/test_detect.py +6 -3
- numba_cuda/numba/cuda/tests/cudadrv/test_emm_plugins.py +1 -1
- numba_cuda/numba/cuda/tests/cudadrv/test_linker.py +18 -2
- numba_cuda/numba/cuda/tests/cudadrv/test_module_callbacks.py +2 -1
- numba_cuda/numba/cuda/tests/cudadrv/test_nvjitlink.py +1 -1
- numba_cuda/numba/cuda/tests/cudadrv/test_ptds.py +1 -1
- numba_cuda/numba/cuda/tests/cudapy/extensions_usecases.py +1 -1
- numba_cuda/numba/cuda/tests/cudapy/test_array.py +2 -1
- numba_cuda/numba/cuda/tests/cudapy/test_atomics.py +1 -1
- numba_cuda/numba/cuda/tests/cudapy/test_bfloat16.py +539 -2
- numba_cuda/numba/cuda/tests/cudapy/test_bfloat16_bindings.py +81 -1
- numba_cuda/numba/cuda/tests/cudapy/test_caching.py +1 -3
- numba_cuda/numba/cuda/tests/cudapy/test_complex.py +1 -1
- numba_cuda/numba/cuda/tests/cudapy/test_constmem.py +1 -1
- numba_cuda/numba/cuda/tests/cudapy/test_cooperative_groups.py +2 -3
- numba_cuda/numba/cuda/tests/cudapy/test_copy_propagate.py +130 -0
- numba_cuda/numba/cuda/tests/cudapy/test_datetime.py +1 -1
- numba_cuda/numba/cuda/tests/cudapy/test_debug.py +1 -1
- numba_cuda/numba/cuda/tests/cudapy/test_debuginfo.py +293 -4
- numba_cuda/numba/cuda/tests/cudapy/test_debuginfo_types.py +1 -1
- numba_cuda/numba/cuda/tests/cudapy/test_dispatcher.py +1 -1
- numba_cuda/numba/cuda/tests/cudapy/test_errors.py +1 -1
- numba_cuda/numba/cuda/tests/cudapy/test_exception.py +1 -1
- numba_cuda/numba/cuda/tests/cudapy/test_extending.py +2 -1
- numba_cuda/numba/cuda/tests/cudapy/test_inline.py +18 -8
- numba_cuda/numba/cuda/tests/cudapy/test_intrinsics.py +23 -21
- numba_cuda/numba/cuda/tests/cudapy/test_ir_utils.py +10 -37
- numba_cuda/numba/cuda/tests/cudapy/test_laplace.py +1 -1
- numba_cuda/numba/cuda/tests/cudapy/test_math.py +1 -1
- numba_cuda/numba/cuda/tests/cudapy/test_matmul.py +1 -1
- numba_cuda/numba/cuda/tests/cudapy/test_operator.py +1 -1
- numba_cuda/numba/cuda/tests/cudapy/test_print.py +20 -0
- numba_cuda/numba/cuda/tests/cudapy/test_record_dtype.py +1 -1
- numba_cuda/numba/cuda/tests/cudapy/test_reduction.py +1 -1
- numba_cuda/numba/cuda/tests/cudapy/test_serialize.py +1 -1
- numba_cuda/numba/cuda/tests/cudapy/test_sm.py +1 -1
- numba_cuda/numba/cuda/tests/cudapy/test_ssa.py +453 -0
- numba_cuda/numba/cuda/tests/cudapy/test_sync.py +1 -1
- numba_cuda/numba/cuda/tests/cudapy/test_typeinfer.py +538 -0
- numba_cuda/numba/cuda/tests/cudapy/test_ufuncs.py +263 -2
- numba_cuda/numba/cuda/tests/cudapy/test_userexc.py +1 -1
- numba_cuda/numba/cuda/tests/cudapy/test_vector_type.py +1 -1
- numba_cuda/numba/cuda/tests/cudapy/test_vectorize_decor.py +112 -6
- numba_cuda/numba/cuda/tests/cudapy/test_warning.py +1 -1
- numba_cuda/numba/cuda/tests/cudapy/test_warp_ops.py +1 -1
- numba_cuda/numba/cuda/tests/doc_examples/test_cg.py +0 -2
- numba_cuda/numba/cuda/tests/doc_examples/test_ffi.py +3 -2
- numba_cuda/numba/cuda/tests/doc_examples/test_laplace.py +0 -2
- numba_cuda/numba/cuda/tests/doc_examples/test_sessionize.py +0 -2
- numba_cuda/numba/cuda/tests/nocuda/test_import.py +3 -1
- numba_cuda/numba/cuda/tests/nocuda/test_library_lookup.py +24 -12
- numba_cuda/numba/cuda/tests/nrt/test_nrt.py +2 -1
- numba_cuda/numba/cuda/tests/support.py +55 -15
- numba_cuda/numba/cuda/tests/test_tracing.py +200 -0
- numba_cuda/numba/cuda/types.py +56 -0
- numba_cuda/numba/cuda/typing/__init__.py +9 -1
- numba_cuda/numba/cuda/typing/cffi_utils.py +55 -0
- numba_cuda/numba/cuda/typing/context.py +751 -0
- numba_cuda/numba/cuda/typing/enumdecl.py +74 -0
- numba_cuda/numba/cuda/typing/npydecl.py +658 -0
- numba_cuda/numba/cuda/typing/templates.py +7 -6
- numba_cuda/numba/cuda/ufuncs.py +3 -3
- numba_cuda/numba/cuda/utils.py +6 -112
- {numba_cuda-0.19.1.dist-info → numba_cuda-0.20.1.dist-info}/METADATA +4 -3
- {numba_cuda-0.19.1.dist-info → numba_cuda-0.20.1.dist-info}/RECORD +171 -116
- numba_cuda/numba/cuda/tests/cudadrv/test_mvc.py +0 -60
- {numba_cuda-0.19.1.dist-info → numba_cuda-0.20.1.dist-info}/WHEEL +0 -0
- {numba_cuda-0.19.1.dist-info → numba_cuda-0.20.1.dist-info}/licenses/LICENSE +0 -0
- {numba_cuda-0.19.1.dist-info → numba_cuda-0.20.1.dist-info}/licenses/LICENSE.numba +0 -0
- {numba_cuda-0.19.1.dist-info → numba_cuda-0.20.1.dist-info}/top_level.txt +0 -0
|
@@ -21,8 +21,9 @@ from functools import cached_property
|
|
|
21
21
|
import numpy as np
|
|
22
22
|
|
|
23
23
|
from numba import types
|
|
24
|
-
from numba.core import errors
|
|
25
|
-
from numba.core
|
|
24
|
+
from numba.core import errors
|
|
25
|
+
from numba.cuda.core import config
|
|
26
|
+
from numba.cuda.typing import cffi_utils
|
|
26
27
|
from numba.cuda.memory_management.nrt import rtsys
|
|
27
28
|
from numba.core.extending import (
|
|
28
29
|
typeof_impl,
|
|
@@ -31,7 +32,7 @@ from numba.core.extending import (
|
|
|
31
32
|
NativeValue,
|
|
32
33
|
)
|
|
33
34
|
from numba.core.datamodel.models import OpaqueModel
|
|
34
|
-
from numba.np import numpy_support
|
|
35
|
+
from numba.cuda.np import numpy_support
|
|
35
36
|
|
|
36
37
|
|
|
37
38
|
class EnableNRTStatsMixin(object):
|
|
@@ -751,16 +752,55 @@ class TestCase(unittest.TestCase):
|
|
|
751
752
|
|
|
752
753
|
return Dummy, DummyType
|
|
753
754
|
|
|
754
|
-
def skip_if_no_external_compiler(self):
|
|
755
|
-
"""
|
|
756
|
-
Call this to ensure the test is skipped if no suitable external compiler
|
|
757
|
-
is found. This is a method on the TestCase opposed to a stand-alone
|
|
758
|
-
decorator so as to make it "lazy" via runtime evaluation opposed to
|
|
759
|
-
running at test-discovery time.
|
|
760
|
-
"""
|
|
761
|
-
# This is a local import to avoid deprecation warnings being generated
|
|
762
|
-
# through the use of the numba.pycc module.
|
|
763
|
-
from numba.pycc.platform import external_compiler_works
|
|
764
755
|
|
|
765
|
-
|
|
766
|
-
|
|
756
|
+
class MemoryLeak(object):
    """Helpers that check NRT allocation statistics for leaks.

    ``memory_leak_setup`` records a baseline of
    ``rtsys.get_allocation_stats()``; ``memory_leak_teardown`` then asserts
    that allocations and frees since that baseline balance out, unless
    ``disable_leak_check`` was called in between.
    """

    # Class-wide default; disable_leak_check() overrides it per instance.
    __enable_leak_check = True

    def memory_leak_setup(self):
        """Collect garbage, then snapshot the NRT allocation counters."""
        # Clean up any NRT-backed objects hanging in a dead reference cycle
        # so they do not distort the baseline.
        gc.collect()
        self.__init_stats = rtsys.get_allocation_stats()

    def memory_leak_teardown(self):
        """Run the leak assertion unless it was disabled for this test."""
        if not self.__enable_leak_check:
            return
        self.assert_no_memory_leak()

    def assert_no_memory_leak(self):
        """Assert that allocations and frees since setup are balanced.

        Both the ``alloc``/``free`` counters and the ``mi_alloc``/``mi_free``
        counters are compared (the ``mi_*`` pair presumably counts MemInfo
        records -- confirm against the NRT stats definition).
        """
        before = self.__init_stats
        after = rtsys.get_allocation_stats()
        delta = {
            field: getattr(after, field) - getattr(before, field)
            for field in ("alloc", "free", "mi_alloc", "mi_free")
        }
        self.assertEqual(delta["alloc"], delta["free"])
        self.assertEqual(delta["mi_alloc"], delta["mi_free"])

    def disable_leak_check(self):
        """Skip the leak assertion for this test instance.

        For per-test use when MemoryLeakMixin is injected into a TestCase.
        """
        self.__enable_leak_check = False
|
|
781
|
+
|
|
782
|
+
|
|
783
|
+
class MemoryLeakMixin(EnableNRTStatsMixin, MemoryLeak):
    """TestCase mixin that wraps every test in an NRT leak check.

    ``setUp`` runs the cooperating base ``setUp`` first and then records the
    allocation baseline. ``tearDown`` collects garbage, runs the leak check
    and finally defers to the base ``tearDown``.
    """

    def setUp(self):
        super().setUp()
        self.memory_leak_setup()

    def tearDown(self):
        # Free objects kept alive only by reference cycles before comparing
        # the counters against the baseline.
        gc.collect()
        self.memory_leak_teardown()
        super().tearDown()
|
|
792
|
+
|
|
793
|
+
|
|
794
|
+
class CheckWarningsMixin(object):
    """TestCase mixin providing a context manager to assert on warnings."""

    @contextlib.contextmanager
    def check_warnings(self, messages, category=RuntimeWarning):
        """Assert that the managed block emits warnings matching *messages*.

        All warnings raised inside the ``with`` block are recorded (the
        "always" filter keeps repeated warnings from being suppressed).
        After the block finishes:

        * every recorded warning whose text contains one of *messages* must
          have exactly *category* as its class;
        * the total number of (warning, message) matches must equal
          ``len(messages)``;
        * every entry of *messages* must occur in at least one recorded
          warning.

        :param messages: substrings expected in the emitted warning texts.
        :param category: warning class that matching warnings must have.
        :raises AssertionError: via ``self.assertEqual`` when a check fails.
        """
        with warnings.catch_warnings(record=True) as catch:
            warnings.simplefilter("always")
            yield
        found = 0
        for w in catch:
            text = str(w.message)
            for m in messages:
                if m in text:
                    self.assertEqual(w.category, category)
                    found += 1
        self.assertEqual(found, len(messages))
        # The count alone can be satisfied by one message matching twice
        # while another never appears, so require each one to have been seen.
        missing = [
            m for m in messages if not any(m in str(w.message) for w in catch)
        ]
        self.assertEqual(
            missing, [], "expected warnings not emitted: %r" % (missing,)
        )
|
|
@@ -0,0 +1,200 @@
|
|
|
1
|
+
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
|
2
|
+
# SPDX-License-Identifier: BSD-2-Clause
|
|
3
|
+
|
|
4
|
+
from io import StringIO
|
|
5
|
+
import logging
|
|
6
|
+
|
|
7
|
+
import unittest
|
|
8
|
+
from numba.cuda.core import tracing
|
|
9
|
+
|
|
10
|
+
# Logger whose handlers CapturedTrace swaps out; presumably the logger that
# numba.cuda.core.tracing writes to -- the tests below rely on that.
logger = logging.getLogger("trace")
logger.setLevel(logging.INFO)

# Make sure tracing is enabled for every decorator applied below, keeping
# the original so it can be restored at the end of the module.
orig_trace, tracing.trace = tracing.trace, tracing.dotrace
|
|
16
|
+
|
|
17
|
+
|
|
18
|
+
class CapturedTrace:
    """Capture the trace temporarily for validation.

    On entry, the handlers of the module-level ``logger`` are replaced by a
    single StreamHandler writing to a fresh in-memory buffer. On exit, the
    previous handlers are restored. ``getvalue()`` returns the captured text.
    """

    def __init__(self):
        self.buffer = StringIO()
        self.handler = logging.StreamHandler(self.buffer)

    def __enter__(self):
        self._handlers = logger.handlers
        # Start every capture from an empty buffer and install a handler
        # bound to it, so repeated use does not accumulate old output.
        self.buffer = StringIO()
        self.handler = logging.StreamHandler(self.buffer)
        logger.handlers = [self.handler]
        return self

    def __exit__(self, exc_type, exc_value, traceback):
        logger.handlers = self._handlers
        # Never suppress exceptions raised inside the with-block.
        return False

    def getvalue(self):
        """Return the captured log text with this module's qualifier removed."""
        # Depending on how the tests are run, object names may be
        # qualified by their containing module.
        # Remove that to make the trace output independent from the testing mode.
        log = self.buffer.getvalue()
        log = log.replace(__name__ + ".", "")
        return log
|
|
40
|
+
|
|
41
|
+
|
|
42
|
+
class Class(object):
    """Fixture whose members are wrapped by ``tracing.trace`` at class
    creation time. TestTracing asserts on the exact log lines they emit, so
    member names, parameter names and ``__repr__`` must stay unchanged."""

    # trace is applied on top of the classmethod object itself.
    @tracing.trace
    @classmethod
    def class_method(cls):
        pass

    # trace is applied on top of the staticmethod object itself; its
    # expected log line carries no "Class." prefix.
    @tracing.trace
    @staticmethod
    def static_method():
        pass

    # Backing storage for the ``test`` property (name-mangled).
    __test = None

    def _test_get(self):
        return self.__test

    def _test_set(self, value):
        self.__test = value

    # Tracing a property: accesses are logged under the accessor names
    # _test_get / _test_set.
    test = tracing.trace(property(_test_get, _test_set))

    # Exercises defaulted, *args and **kwds parameters in the trace line.
    @tracing.trace
    def method(self, some, other="value", *args, **kwds):
        pass

    def __repr__(self):
        """Generate a deterministic string for testing."""
        return "<Class instance>"
|
|
70
|
+
|
|
71
|
+
|
|
72
|
+
class Class2(object):
    """Undecorated counterpart of ``Class``.

    Used only by the skipped ``TestTracing.test_injected``, which passes the
    whole class to ``tracing.trace(..., recursive=True)``.
    """

    @classmethod
    def class_method(cls):
        pass

    @staticmethod
    def static_method():
        pass

    # Backing storage for the ``test`` property (name-mangled).
    __test = None

    @property
    def test(self):
        return self.__test

    @test.setter
    def test(self, value):
        self.__test = value

    def method(self):
        pass

    def __str__(self):
        return "Test(" + str(self.test) + ")"

    def __repr__(self):
        """Generate a deterministic string for testing."""
        return "<Class2 instance>"
|
|
100
|
+
|
|
101
|
+
|
|
102
|
+
@tracing.trace
def test_traced_function():
    """Traced module-level function with a fixed, self-checking computation.

    It takes no arguments and returns None, so ``TestTracing.test_function``
    only looks for its ``>>``/``<<`` entry and exit markers.
    NOTE(review): the ``test_`` prefix means pytest may also collect and run
    this function directly.
    """
    x = y = 5
    z = True

    total = x + y
    product = x * y
    result = total if z else product

    # With z set, the sum branch is taken: 5 + 5 == 10.
    assert result == 10
|
|
117
|
+
|
|
118
|
+
|
|
119
|
+
class TestTracing(unittest.TestCase):
    """Check the log lines emitted by ``tracing.trace``-decorated callables."""

    def setUp(self):
        self.capture = CapturedTrace()

    def tearDown(self):
        del self.capture

    def test_method(self):
        with self.capture:
            Class().method("foo", bar="baz")
        self.assertEqual(
            self.capture.getvalue(),
            ">> Class.method(self=<Class instance>, some='foo', other='value', bar='baz')\n"
            + "<< Class.method\n",
        )

    def test_class_method(self):
        with self.capture:
            Class.class_method()
        self.assertEqual(
            self.capture.getvalue(),
            ">> Class.class_method(cls=<class 'Class'>)\n"
            + "<< Class.class_method\n",
        )

    def test_static_method(self):
        with self.capture:
            Class.static_method()
        self.assertEqual(
            self.capture.getvalue(),
            ">> static_method()\n" + "<< static_method\n",
        )

    def test_property(self):
        with self.capture:
            obj = Class()
            obj.test = 1
            self.assertEqual(obj.test, 1)
        # Setter and getter are logged under their accessor names; the
        # getter's return value is appended after "->".
        self.assertEqual(
            self.capture.getvalue(),
            ">> Class._test_set(self=<Class instance>, value=1)\n"
            + "<< Class._test_set\n"
            + ">> Class._test_get(self=<Class instance>)\n"
            + "<< Class._test_get -> 1\n",
        )

    def test_function(self):
        with self.capture:
            test_traced_function()
        # The test function should be traced when called
        trace_output = self.capture.getvalue()
        self.assertIn(">> test_traced_function()", trace_output)
        self.assertIn("<< test_traced_function", trace_output)

    @unittest.skip("recursive decoration not yet implemented")
    def test_injected(self):
        with self.capture:
            tracing.trace(Class2, recursive=True)
            Class2.class_method()
            Class2.static_method()
            obj = Class2()
            obj.test = 1
            self.assertEqual(obj.test, 1)
            obj.method()

        # Python 3 reprs classes as "<class '...'>", as in test_class_method.
        self.assertEqual(
            self.capture.getvalue(),
            ">> Class2.class_method(cls=<class 'Class2'>)\n"
            + "<< Class2.class_method\n"
            + ">> static_method()\n"
            + "<< static_method\n",
        )
|
|
194
|
+
|
|
195
|
+
|
|
196
|
+
# Reset tracing to its original value. This must come after every
# ``tracing.trace`` decoration above; those were applied with dotrace at
# definition time.
tracing.trace = orig_trace

if __name__ == "__main__":
    unittest.main()
|
numba_cuda/numba/cuda/types.py
CHANGED
|
@@ -2,6 +2,7 @@
|
|
|
2
2
|
# SPDX-License-Identifier: BSD-2-Clause
|
|
3
3
|
|
|
4
4
|
from numba.core import types
|
|
5
|
+
from numba.core.typeconv import Conversion
|
|
5
6
|
|
|
6
7
|
|
|
7
8
|
class Dim3(types.Type):
|
|
@@ -41,3 +42,58 @@ class CUDADispatcher(types.Dispatcher):
|
|
|
41
42
|
# is still probably a good idea to have a separate type for CUDA
|
|
42
43
|
# dispatchers, and this type might get other differentiation from the CPU
|
|
43
44
|
# dispatcher type in future.
|
|
45
|
+
|
|
46
|
+
|
|
47
|
+
class Bfloat16(types.Number):
    """
    A bfloat16 type. Has 8 exponent bits and 7 significand bits.

    Conversion rules:
        Floats:
            from:
                fp32, fp64: UNSAFE
                fp16: UNSAFE (loses precision)
            to:
                fp32, fp64: SAFE (same exponent, more mantissa, so the
                    value is preserved)
                fp16: UNSAFE (loses range)

        Integers:
            from:
                8-bit integers (int8, uint8): SAFE
                other int: All UNSAFE (bf16 cannot represent all integers in range)
            to: UNSAFE (loses precision, round to zeros)

    Unification with a float or integer type yields whichever side the
    other converts to safely under the rules above, or nothing if neither
    direction is safe.

    All other conversions are not allowed.
    """

    def __init__(self):
        super().__init__(name="__nv_bfloat16")

        self.alignof_ = 2
        self.bitwidth = 16

    def can_convert_from(self, typingctx, other):
        """Return the Conversion kind for ``other`` -> bfloat16, or None."""
        if isinstance(other, types.Float):
            return Conversion.unsafe
        elif isinstance(other, types.Integer):
            if other.bitwidth == 8:
                return Conversion.safe
            return Conversion.unsafe
        return None

    def can_convert_to(self, typingctx, other):
        """Return the Conversion kind for bfloat16 -> ``other``, or None."""
        if isinstance(other, types.Float):
            if other.bitwidth >= 32:
                return Conversion.safe
            return Conversion.unsafe
        elif isinstance(other, types.Integer):
            return Conversion.unsafe
        return None

    def unify(self, typingctx, other):
        """Unify bfloat16 with a float or integer type.

        Returns ``other`` when bfloat16 converts to it safely, ``self`` when
        ``other`` converts to bfloat16 safely, and None otherwise (including
        for non-numeric types).
        """
        if not isinstance(other, (types.Float, types.Integer)):
            return None
        # Do not delegate to typingctx.unify_pairs(self, other): Numba's
        # BaseContext.unify_pairs starts by calling first.unify(...) -- this
        # very method -- so delegating recurses without bound. Apply this
        # type's own conversion rules directly instead.
        to_other = self.can_convert_to(typingctx, other)
        if to_other is not None and to_other <= Conversion.safe:
            return other
        from_other = self.can_convert_from(typingctx, other)
        if from_other is not None and from_other <= Conversion.safe:
            return self
        return None


bfloat16 = Bfloat16()
|
|
@@ -7,5 +7,13 @@ from .templates import (
|
|
|
7
7
|
Signature,
|
|
8
8
|
fold_arguments,
|
|
9
9
|
)
|
|
10
|
+
from .context import BaseContext, Context
|
|
10
11
|
|
|
11
|
-
__all__ = [
|
|
12
|
+
# Public names of numba.cuda.typing (what ``from numba.cuda.typing import *``
# re-exports): the template helpers plus the typing contexts.
__all__ = [
    "signature",
    "make_concrete_template",
    "Signature",
    "fold_arguments",
    "BaseContext",
    "Context",
]
|
|
@@ -0,0 +1,55 @@
|
|
|
1
|
+
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
|
2
|
+
# SPDX-License-Identifier: BSD-2-Clause
|
|
3
|
+
|
|
4
|
+
"""
|
|
5
|
+
Support for CFFI: typing of ``ffi.from_buffer()`` and attribute resolution on
the ``ffi`` object type. ``SUPPORTED`` records whether ``cffi`` is importable.
|
|
7
|
+
"""
|
|
8
|
+
|
|
9
|
+
from numba.core import types
|
|
10
|
+
from numba.core.errors import TypingError
|
|
11
|
+
from numba.cuda.typing import templates
|
|
12
|
+
|
|
13
|
+
# cffi is optional: if importing it (or creating the FFI instance) raises
# ImportError, ``ffi`` is None and SUPPORTED is False.
try:
    import cffi

    ffi = cffi.FFI()
except ImportError:
    ffi = None

SUPPORTED = ffi is not None
# Registry collecting the typing templates declared in this module.
registry = templates.Registry()
|
|
22
|
+
|
|
23
|
+
|
|
24
|
+
@registry.register
class FFI_from_buffer(templates.AbstractTemplate):
    """Typing template for ``ffi.from_buffer(buf)``.

    Types a call with exactly one positional buffer argument that is
    contiguous (C layout, or F layout only when one-dimensional) as
    returning a ``CPointer`` to the buffer's dtype.
    """

    key = "ffi.from_buffer"

    def generic(self, args, kws):
        # Any other calling form is left untyped by this template.
        if kws or len(args) != 1:
            return None

        (buf,) = args
        if not isinstance(buf, types.Buffer):
            raise TypingError(
                "from_buffer() expected a buffer object, got %s" % (buf,)
            )

        layout = buf.layout
        if layout not in ("C", "F"):
            raise TypingError(
                "from_buffer() unsupported on non-contiguous buffers (got %s)"
                % (buf,)
            )
        # Fortran order is only acceptable for one-dimensional buffers.
        if buf.ndim > 1 and layout != "C":
            raise TypingError(
                "from_buffer() only supports multidimensional arrays with C layout (got %s)"
                % (buf,)
            )

        return templates.signature(types.CPointer(buf.dtype), buf)
|
|
48
|
+
|
|
49
|
+
|
|
50
|
+
@registry.register_attr
class FFIAttribute(templates.AttributeTemplate):
    """Attribute typing for the ``ffi`` object type (``types.ffi``)."""

    key = types.ffi

    def resolve_from_buffer(self, ffi):
        """Type ``ffi.from_buffer`` as the FFI_from_buffer template bound to ffi."""
        # The receiver argument is not consulted: the binding always uses
        # types.ffi, which is also this template's key.
        bound_method = types.BoundFunction(FFI_from_buffer, types.ffi)
        return bound_method
|