numba-cuda 0.19.1__py3-none-any.whl → 0.20.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of numba-cuda has been flagged as potentially problematic; consult the registry's advisory for this release for details.

Files changed (171)
  1. numba_cuda/VERSION +1 -1
  2. numba_cuda/numba/cuda/__init__.py +1 -1
  3. numba_cuda/numba/cuda/_internal/cuda_bf16.py +12706 -1470
  4. numba_cuda/numba/cuda/_internal/cuda_fp16.py +2653 -8769
  5. numba_cuda/numba/cuda/api.py +6 -1
  6. numba_cuda/numba/cuda/bf16.py +285 -2
  7. numba_cuda/numba/cuda/cgutils.py +2 -2
  8. numba_cuda/numba/cuda/cloudpickle/__init__.py +21 -0
  9. numba_cuda/numba/cuda/cloudpickle/cloudpickle.py +1598 -0
  10. numba_cuda/numba/cuda/cloudpickle/cloudpickle_fast.py +17 -0
  11. numba_cuda/numba/cuda/codegen.py +1 -1
  12. numba_cuda/numba/cuda/compiler.py +373 -30
  13. numba_cuda/numba/cuda/core/analysis.py +319 -0
  14. numba_cuda/numba/cuda/core/annotations/__init__.py +0 -0
  15. numba_cuda/numba/cuda/core/annotations/type_annotations.py +304 -0
  16. numba_cuda/numba/cuda/core/base.py +1289 -0
  17. numba_cuda/numba/cuda/core/bytecode.py +727 -0
  18. numba_cuda/numba/cuda/core/caching.py +2 -2
  19. numba_cuda/numba/cuda/core/compiler.py +6 -14
  20. numba_cuda/numba/cuda/core/compiler_machinery.py +497 -0
  21. numba_cuda/numba/cuda/core/config.py +747 -0
  22. numba_cuda/numba/cuda/core/consts.py +124 -0
  23. numba_cuda/numba/cuda/core/cpu.py +370 -0
  24. numba_cuda/numba/cuda/core/environment.py +68 -0
  25. numba_cuda/numba/cuda/core/event.py +511 -0
  26. numba_cuda/numba/cuda/core/funcdesc.py +330 -0
  27. numba_cuda/numba/cuda/core/inline_closurecall.py +1889 -0
  28. numba_cuda/numba/cuda/core/interpreter.py +48 -26
  29. numba_cuda/numba/cuda/core/ir_utils.py +15 -26
  30. numba_cuda/numba/cuda/core/options.py +262 -0
  31. numba_cuda/numba/cuda/core/postproc.py +249 -0
  32. numba_cuda/numba/cuda/core/pythonapi.py +1868 -0
  33. numba_cuda/numba/cuda/core/rewrites/__init__.py +26 -0
  34. numba_cuda/numba/cuda/core/rewrites/ir_print.py +90 -0
  35. numba_cuda/numba/cuda/core/rewrites/registry.py +104 -0
  36. numba_cuda/numba/cuda/core/rewrites/static_binop.py +40 -0
  37. numba_cuda/numba/cuda/core/rewrites/static_getitem.py +187 -0
  38. numba_cuda/numba/cuda/core/rewrites/static_raise.py +98 -0
  39. numba_cuda/numba/cuda/core/ssa.py +496 -0
  40. numba_cuda/numba/cuda/core/targetconfig.py +329 -0
  41. numba_cuda/numba/cuda/core/tracing.py +231 -0
  42. numba_cuda/numba/cuda/core/transforms.py +952 -0
  43. numba_cuda/numba/cuda/core/typed_passes.py +738 -7
  44. numba_cuda/numba/cuda/core/typeinfer.py +1948 -0
  45. numba_cuda/numba/cuda/core/unsafe/__init__.py +0 -0
  46. numba_cuda/numba/cuda/core/unsafe/bytes.py +67 -0
  47. numba_cuda/numba/cuda/core/unsafe/eh.py +66 -0
  48. numba_cuda/numba/cuda/core/unsafe/refcount.py +98 -0
  49. numba_cuda/numba/cuda/core/untyped_passes.py +1983 -0
  50. numba_cuda/numba/cuda/cpython/cmathimpl.py +560 -0
  51. numba_cuda/numba/cuda/cpython/mathimpl.py +499 -0
  52. numba_cuda/numba/cuda/cpython/numbers.py +1474 -0
  53. numba_cuda/numba/cuda/cuda_paths.py +422 -246
  54. numba_cuda/numba/cuda/cudadecl.py +1 -1
  55. numba_cuda/numba/cuda/cudadrv/__init__.py +1 -1
  56. numba_cuda/numba/cuda/cudadrv/devicearray.py +2 -1
  57. numba_cuda/numba/cuda/cudadrv/driver.py +11 -140
  58. numba_cuda/numba/cuda/cudadrv/dummyarray.py +111 -24
  59. numba_cuda/numba/cuda/cudadrv/libs.py +5 -5
  60. numba_cuda/numba/cuda/cudadrv/mappings.py +1 -1
  61. numba_cuda/numba/cuda/cudadrv/nvrtc.py +19 -8
  62. numba_cuda/numba/cuda/cudadrv/nvvm.py +1 -4
  63. numba_cuda/numba/cuda/cudadrv/runtime.py +1 -1
  64. numba_cuda/numba/cuda/cudaimpl.py +5 -1
  65. numba_cuda/numba/cuda/debuginfo.py +85 -2
  66. numba_cuda/numba/cuda/decorators.py +3 -3
  67. numba_cuda/numba/cuda/descriptor.py +3 -4
  68. numba_cuda/numba/cuda/deviceufunc.py +66 -2
  69. numba_cuda/numba/cuda/dispatcher.py +18 -39
  70. numba_cuda/numba/cuda/flags.py +141 -1
  71. numba_cuda/numba/cuda/fp16.py +0 -2
  72. numba_cuda/numba/cuda/include/13/cuda_bf16.h +5118 -0
  73. numba_cuda/numba/cuda/include/13/cuda_bf16.hpp +3865 -0
  74. numba_cuda/numba/cuda/include/13/cuda_fp16.h +5363 -0
  75. numba_cuda/numba/cuda/include/13/cuda_fp16.hpp +3483 -0
  76. numba_cuda/numba/cuda/lowering.py +7 -144
  77. numba_cuda/numba/cuda/mathimpl.py +2 -1
  78. numba_cuda/numba/cuda/memory_management/nrt.py +43 -17
  79. numba_cuda/numba/cuda/misc/findlib.py +75 -0
  80. numba_cuda/numba/cuda/models.py +9 -1
  81. numba_cuda/numba/cuda/np/npdatetime_helpers.py +217 -0
  82. numba_cuda/numba/cuda/np/npyfuncs.py +1807 -0
  83. numba_cuda/numba/cuda/np/numpy_support.py +553 -0
  84. numba_cuda/numba/cuda/np/ufunc/ufuncbuilder.py +59 -0
  85. numba_cuda/numba/cuda/nvvmutils.py +1 -1
  86. numba_cuda/numba/cuda/printimpl.py +12 -1
  87. numba_cuda/numba/cuda/random.py +1 -1
  88. numba_cuda/numba/cuda/serialize.py +1 -1
  89. numba_cuda/numba/cuda/simulator/__init__.py +1 -1
  90. numba_cuda/numba/cuda/simulator/api.py +1 -1
  91. numba_cuda/numba/cuda/simulator/compiler.py +4 -0
  92. numba_cuda/numba/cuda/simulator/cudadrv/devicearray.py +1 -1
  93. numba_cuda/numba/cuda/simulator/kernelapi.py +1 -1
  94. numba_cuda/numba/cuda/simulator/memory_management/nrt.py +14 -2
  95. numba_cuda/numba/cuda/target.py +35 -17
  96. numba_cuda/numba/cuda/testing.py +4 -19
  97. numba_cuda/numba/cuda/tests/__init__.py +1 -1
  98. numba_cuda/numba/cuda/tests/cloudpickle_main_class.py +9 -0
  99. numba_cuda/numba/cuda/tests/core/test_serialize.py +4 -4
  100. numba_cuda/numba/cuda/tests/cudadrv/test_cuda_devicerecord.py +1 -1
  101. numba_cuda/numba/cuda/tests/cudadrv/test_cuda_libraries.py +1 -1
  102. numba_cuda/numba/cuda/tests/cudadrv/test_deallocations.py +1 -1
  103. numba_cuda/numba/cuda/tests/cudadrv/test_detect.py +6 -3
  104. numba_cuda/numba/cuda/tests/cudadrv/test_emm_plugins.py +1 -1
  105. numba_cuda/numba/cuda/tests/cudadrv/test_linker.py +18 -2
  106. numba_cuda/numba/cuda/tests/cudadrv/test_module_callbacks.py +2 -1
  107. numba_cuda/numba/cuda/tests/cudadrv/test_nvjitlink.py +1 -1
  108. numba_cuda/numba/cuda/tests/cudadrv/test_ptds.py +1 -1
  109. numba_cuda/numba/cuda/tests/cudapy/extensions_usecases.py +1 -1
  110. numba_cuda/numba/cuda/tests/cudapy/test_array.py +2 -1
  111. numba_cuda/numba/cuda/tests/cudapy/test_atomics.py +1 -1
  112. numba_cuda/numba/cuda/tests/cudapy/test_bfloat16.py +539 -2
  113. numba_cuda/numba/cuda/tests/cudapy/test_bfloat16_bindings.py +81 -1
  114. numba_cuda/numba/cuda/tests/cudapy/test_caching.py +1 -3
  115. numba_cuda/numba/cuda/tests/cudapy/test_complex.py +1 -1
  116. numba_cuda/numba/cuda/tests/cudapy/test_constmem.py +1 -1
  117. numba_cuda/numba/cuda/tests/cudapy/test_cooperative_groups.py +2 -3
  118. numba_cuda/numba/cuda/tests/cudapy/test_copy_propagate.py +130 -0
  119. numba_cuda/numba/cuda/tests/cudapy/test_datetime.py +1 -1
  120. numba_cuda/numba/cuda/tests/cudapy/test_debug.py +1 -1
  121. numba_cuda/numba/cuda/tests/cudapy/test_debuginfo.py +293 -4
  122. numba_cuda/numba/cuda/tests/cudapy/test_debuginfo_types.py +1 -1
  123. numba_cuda/numba/cuda/tests/cudapy/test_dispatcher.py +1 -1
  124. numba_cuda/numba/cuda/tests/cudapy/test_errors.py +1 -1
  125. numba_cuda/numba/cuda/tests/cudapy/test_exception.py +1 -1
  126. numba_cuda/numba/cuda/tests/cudapy/test_extending.py +2 -1
  127. numba_cuda/numba/cuda/tests/cudapy/test_inline.py +18 -8
  128. numba_cuda/numba/cuda/tests/cudapy/test_ir_utils.py +10 -37
  129. numba_cuda/numba/cuda/tests/cudapy/test_laplace.py +1 -1
  130. numba_cuda/numba/cuda/tests/cudapy/test_math.py +1 -1
  131. numba_cuda/numba/cuda/tests/cudapy/test_matmul.py +1 -1
  132. numba_cuda/numba/cuda/tests/cudapy/test_operator.py +1 -1
  133. numba_cuda/numba/cuda/tests/cudapy/test_print.py +20 -0
  134. numba_cuda/numba/cuda/tests/cudapy/test_record_dtype.py +1 -1
  135. numba_cuda/numba/cuda/tests/cudapy/test_reduction.py +1 -1
  136. numba_cuda/numba/cuda/tests/cudapy/test_serialize.py +1 -1
  137. numba_cuda/numba/cuda/tests/cudapy/test_sm.py +1 -1
  138. numba_cuda/numba/cuda/tests/cudapy/test_ssa.py +453 -0
  139. numba_cuda/numba/cuda/tests/cudapy/test_sync.py +1 -1
  140. numba_cuda/numba/cuda/tests/cudapy/test_typeinfer.py +538 -0
  141. numba_cuda/numba/cuda/tests/cudapy/test_ufuncs.py +263 -2
  142. numba_cuda/numba/cuda/tests/cudapy/test_userexc.py +1 -1
  143. numba_cuda/numba/cuda/tests/cudapy/test_vector_type.py +1 -1
  144. numba_cuda/numba/cuda/tests/cudapy/test_vectorize_decor.py +112 -6
  145. numba_cuda/numba/cuda/tests/cudapy/test_warning.py +1 -1
  146. numba_cuda/numba/cuda/tests/cudapy/test_warp_ops.py +1 -1
  147. numba_cuda/numba/cuda/tests/doc_examples/test_cg.py +0 -2
  148. numba_cuda/numba/cuda/tests/doc_examples/test_ffi.py +3 -2
  149. numba_cuda/numba/cuda/tests/doc_examples/test_laplace.py +0 -2
  150. numba_cuda/numba/cuda/tests/doc_examples/test_sessionize.py +0 -2
  151. numba_cuda/numba/cuda/tests/nocuda/test_import.py +3 -1
  152. numba_cuda/numba/cuda/tests/nocuda/test_library_lookup.py +24 -12
  153. numba_cuda/numba/cuda/tests/nrt/test_nrt.py +2 -1
  154. numba_cuda/numba/cuda/tests/support.py +55 -15
  155. numba_cuda/numba/cuda/tests/test_tracing.py +200 -0
  156. numba_cuda/numba/cuda/types.py +56 -0
  157. numba_cuda/numba/cuda/typing/__init__.py +9 -1
  158. numba_cuda/numba/cuda/typing/cffi_utils.py +55 -0
  159. numba_cuda/numba/cuda/typing/context.py +751 -0
  160. numba_cuda/numba/cuda/typing/enumdecl.py +74 -0
  161. numba_cuda/numba/cuda/typing/npydecl.py +658 -0
  162. numba_cuda/numba/cuda/typing/templates.py +7 -6
  163. numba_cuda/numba/cuda/ufuncs.py +3 -3
  164. numba_cuda/numba/cuda/utils.py +6 -112
  165. {numba_cuda-0.19.1.dist-info → numba_cuda-0.20.0.dist-info}/METADATA +2 -1
  166. {numba_cuda-0.19.1.dist-info → numba_cuda-0.20.0.dist-info}/RECORD +170 -115
  167. numba_cuda/numba/cuda/tests/cudadrv/test_mvc.py +0 -60
  168. {numba_cuda-0.19.1.dist-info → numba_cuda-0.20.0.dist-info}/WHEEL +0 -0
  169. {numba_cuda-0.19.1.dist-info → numba_cuda-0.20.0.dist-info}/licenses/LICENSE +0 -0
  170. {numba_cuda-0.19.1.dist-info → numba_cuda-0.20.0.dist-info}/licenses/LICENSE.numba +0 -0
  171. {numba_cuda-0.19.1.dist-info → numba_cuda-0.20.0.dist-info}/top_level.txt +0 -0
@@ -21,8 +21,9 @@ from functools import cached_property
21
21
  import numpy as np
22
22
 
23
23
  from numba import types
24
- from numba.core import errors, config
25
- from numba.core.typing import cffi_utils
24
+ from numba.core import errors
25
+ from numba.cuda.core import config
26
+ from numba.cuda.typing import cffi_utils
26
27
  from numba.cuda.memory_management.nrt import rtsys
27
28
  from numba.core.extending import (
28
29
  typeof_impl,
@@ -31,7 +32,7 @@ from numba.core.extending import (
31
32
  NativeValue,
32
33
  )
33
34
  from numba.core.datamodel.models import OpaqueModel
34
- from numba.np import numpy_support
35
+ from numba.cuda.np import numpy_support
35
36
 
36
37
 
37
38
  class EnableNRTStatsMixin(object):
@@ -751,16 +752,55 @@ class TestCase(unittest.TestCase):
751
752
 
752
753
  return Dummy, DummyType
753
754
 
754
- def skip_if_no_external_compiler(self):
755
- """
756
- Call this to ensure the test is skipped if no suitable external compiler
757
- is found. This is a method on the TestCase opposed to a stand-alone
758
- decorator so as to make it "lazy" via runtime evaluation opposed to
759
- running at test-discovery time.
760
- """
761
- # This is a local import to avoid deprecation warnings being generated
762
- # through the use of the numba.pycc module.
763
- from numba.pycc.platform import external_compiler_works
764
755
 
765
- if not external_compiler_works():
766
- self.skipTest("No suitable external compiler was found.")
756
class MemoryLeak:
    """Helpers that assert NRT allocations are balanced across a test.

    Call ``memory_leak_setup`` before the code under test to snapshot the
    NRT allocation statistics, and ``memory_leak_teardown`` afterwards to
    compare against that snapshot (unless ``disable_leak_check`` was called).
    """

    __enable_leak_check = True

    def memory_leak_setup(self):
        # Collect objects stuck in dead reference cycles first, so NRT-backed
        # objects left over from earlier code are not charged to this test.
        gc.collect()
        self.__init_stats = rtsys.get_allocation_stats()

    def memory_leak_teardown(self):
        if not self.__enable_leak_check:
            return
        self.assert_no_memory_leak()

    def assert_no_memory_leak(self):
        before = self.__init_stats
        after = rtsys.get_allocation_stats()
        allocated = after.alloc - before.alloc
        freed = after.free - before.free
        mi_allocated = after.mi_alloc - before.mi_alloc
        mi_freed = after.mi_free - before.mi_free
        self.assertEqual(allocated, freed)
        self.assertEqual(mi_allocated, mi_freed)

    def disable_leak_check(self):
        # Per-test opt-out for use when MemoryLeakMixin is mixed into a
        # TestCase; only affects this instance.
        self.__enable_leak_check = False
781
+
782
+
783
class MemoryLeakMixin(EnableNRTStatsMixin, MemoryLeak):
    """TestCase mixin that fails a test which leaks NRT allocations.

    ``setUp`` records a baseline of the NRT allocation statistics and
    ``tearDown`` compares against it, unless ``disable_leak_check`` was
    called during the test.
    """

    def setUp(self):
        super(MemoryLeakMixin, self).setUp()
        self.memory_leak_setup()

    def tearDown(self):
        # Collect dead reference cycles first so already-unreachable objects
        # are freed before the statistics are compared.
        gc.collect()
        try:
            self.memory_leak_teardown()
        finally:
            # Run the rest of the cooperative tearDown chain (starting with
            # EnableNRTStatsMixin) even when the leak assertion fails;
            # previously a detected leak skipped the parent teardown entirely.
            super(MemoryLeakMixin, self).tearDown()
792
+
793
+
794
class CheckWarningsMixin:
    """TestCase mixin for asserting that specific warnings are emitted."""

    @contextlib.contextmanager
    def check_warnings(self, messages, category=RuntimeWarning):
        """Record every warning raised in the ``with`` body and check it.

        Each recorded warning whose text contains one of *messages* must be
        of *category*. The total number of (warning, message) matches must
        equal ``len(messages)``.
        """
        with warnings.catch_warnings(record=True) as recorded:
            warnings.simplefilter("always")
            yield
        matches = 0
        for entry in recorded:
            for needle in messages:
                if needle not in str(entry.message):
                    continue
                self.assertEqual(entry.category, category)
                matches += 1
        self.assertEqual(matches, len(messages))
@@ -0,0 +1,200 @@
1
+ # SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2
+ # SPDX-License-Identifier: BSD-2-Clause
3
+
4
+ from io import StringIO
5
+ import logging
6
+
7
+ import unittest
8
+ from numba.cuda.core import tracing
9
+
10
# Force tracing on while this module's decorated definitions are created;
# the original ``tracing.trace`` is saved so it can be restored at the end
# of the module.
orig_trace = tracing.trace
tracing.trace = tracing.dotrace

# Logger whose handlers CapturedTrace swaps out during a capture.
logger = logging.getLogger("trace")
logger.setLevel(logging.INFO)
16
+
17
+
18
class CapturedTrace:
    """Capture the trace temporarily for validation.

    Inside the ``with`` block, the handlers of the module-level ``logger``
    are replaced by one handler that writes into an in-memory buffer. The
    previous handlers are restored on exit. Each capture starts from an
    empty buffer, and ``__enter__`` returns the capture object itself.
    """

    def __init__(self):
        self.buffer = StringIO()
        self.handler = logging.StreamHandler(self.buffer)
        self._handlers = None

    def __enter__(self):
        # Fresh buffer and a handler bound to it, so repeated captures with
        # the same object never see output from an earlier capture.
        self.buffer = StringIO()
        self.handler = logging.StreamHandler(self.buffer)
        self._handlers = logger.handlers
        logger.handlers = [self.handler]
        return self

    def __exit__(self, exc_type, exc_value, traceback):
        logger.handlers = self._handlers
        # Release the temporary handler; StreamHandler.close() does not close
        # the underlying StringIO, so getvalue() keeps working afterwards.
        self.handler.close()
        # Never suppress an exception raised inside the ``with`` block.
        return False

    def getvalue(self):
        """Return the captured log text with this module's prefix removed."""
        # Depending on how the tests are run, object names may be
        # qualified by their containing module.
        # Remove that to make the trace output independent from the testing mode.
        log = self.buffer.getvalue()
        log = log.replace(__name__ + ".", "")
        return log
40
+
41
+
42
class Class(object):
    """Fixture whose members are wrapped with ``tracing.trace``.

    Member names, parameter names and ``__repr__`` appear verbatim in the
    expected trace strings asserted by ``TestTracing``. Renaming anything
    here changes the expected output.
    """

    # ``tracing.trace`` is the outermost decorator, so it wraps the
    # classmethod / staticmethod objects rather than the plain functions.
    @tracing.trace
    @classmethod
    def class_method(cls):
        pass

    @tracing.trace
    @staticmethod
    def static_method():
        pass

    # Backing storage for the ``test`` property (name-mangled per class).
    __test = None

    def _test_get(self):
        return self.__test

    def _test_set(self, value):
        self.__test = value

    # Traced property. TestTracing.test_property expects accesses to be
    # logged under the accessor names ``_test_get`` / ``_test_set``.
    test = tracing.trace(property(_test_get, _test_set))

    # Mix of positional, defaulted, *args and **kwds parameters so the trace
    # output of argument binding can be checked.
    @tracing.trace
    def method(self, some, other="value", *args, **kwds):
        pass

    def __repr__(self):
        """Generate a deterministic string for testing."""
        return "<Class instance>"
70
+
71
+
72
class Class2(object):
    """Undecorated counterpart of ``Class``.

    Used only by the skipped ``TestTracing.test_injected``, which calls
    ``tracing.trace(Class2, recursive=True)`` and expects the members below
    to be traced after the fact. Names appear in that test's expected output.
    """

    @classmethod
    def class_method(cls):
        pass

    @staticmethod
    def static_method():
        pass

    # Backing storage for the ``test`` property (name-mangled per class).
    __test = None

    @property
    def test(self):
        return self.__test

    @test.setter
    def test(self, value):
        self.__test = value

    def method(self):
        pass

    def __str__(self):
        return "Test(" + str(self.test) + ")"

    def __repr__(self):
        """Generate a deterministic string for testing."""
        return "<Class2 instance>"
100
+
101
+
102
@tracing.trace
def test_traced_function():
    """Traced free function; only its name matters to the trace checks."""
    x = y = 5
    use_sum = True

    total = x + y
    product = x * y
    result = total if use_sum else product

    # use_sum is True, so the sum (5 + 5) is selected.
    assert result == 10
117
+
118
+
119
class TestTracing(unittest.TestCase):
    """Compare the exact text logged by ``tracing.trace``-wrapped callables."""

    def setUp(self):
        # A fresh capture object per test.
        self.capture = CapturedTrace()

    def tearDown(self):
        del self.capture

    def test_method(self):
        with self.capture:
            Class().method("foo", bar="baz")
        self.assertEqual(
            self.capture.getvalue(),
            ">> Class.method(self=<Class instance>, some='foo', other='value', bar='baz')\n"
            + "<< Class.method\n",
        )

    def test_class_method(self):
        with self.capture:
            Class.class_method()
        self.assertEqual(
            self.capture.getvalue(),
            ">> Class.class_method(cls=<class 'Class'>)\n"
            + "<< Class.class_method\n",
        )

    def test_static_method(self):
        with self.capture:
            Class.static_method()
        self.assertEqual(
            self.capture.getvalue(),
            ">> static_method()\n" + "<< static_method\n",
        )

    def test_property(self):
        with self.capture:
            test = Class()
            test.test = 1
            # Must not be a bare ``assert``: under ``python -O`` the statement
            # is stripped, the getter never runs, and the expected
            # ``_test_get`` lines below would be missing.
            self.assertEqual(test.test, 1)
        self.assertEqual(
            self.capture.getvalue(),
            ">> Class._test_set(self=<Class instance>, value=1)\n"
            + "<< Class._test_set\n"
            + ">> Class._test_get(self=<Class instance>)\n"
            + "<< Class._test_get -> 1\n",
        )

    def test_function(self):
        with self.capture:
            test_traced_function()
        # The test function should be traced when called
        trace_output = self.capture.getvalue()
        self.assertIn(">> test_traced_function()", trace_output)
        self.assertIn("<< test_traced_function", trace_output)

    @unittest.skip("recursive decoration not yet implemented")
    def test_injected(self):
        with self.capture:
            tracing.trace(Class2, recursive=True)
            Class2.class_method()
            Class2.static_method()
            test = Class2()
            test.test = 1
            self.assertEqual(test.test, 1)
            test.method()

        # Python 3 renders a class as "<class '...'>" (as expected in
        # test_class_method), not the Python 2 "<type '...'>".
        self.assertEqual(
            self.capture.getvalue(),
            ">> Class2.class_method(cls=<class 'Class2'>)\n"
            + "<< Class2.class_method\n"
            + ">> static_method()\n"
            + "<< static_method\n",
        )
194
+
195
+
196
# Every traced definition above now exists, so put back the
# ``tracing.trace`` that was saved before they were created.
tracing.trace = orig_trace

if __name__ == "__main__":
    unittest.main()
@@ -2,6 +2,7 @@
2
2
  # SPDX-License-Identifier: BSD-2-Clause
3
3
 
4
4
  from numba.core import types
5
+ from numba.core.typeconv import Conversion
5
6
 
6
7
 
7
8
  class Dim3(types.Type):
@@ -41,3 +42,58 @@ class CUDADispatcher(types.Dispatcher):
41
42
  # is still probably a good idea to have a separate type for CUDA
42
43
  # dispatchers, and this type might get other differentiation from the CPU
43
44
  # dispatcher type in future.
45
+
46
+
47
class Bfloat16(types.Number):
    """
    A bfloat16 type. Has 8 exponent bits and 7 significand bits.

    Conversion rules (as implemented by ``can_convert_from`` /
    ``can_convert_to``):
        Floats:
            from:
                fp32, fp64: UNSAFE
                fp16: UNSAFE (loses precision)
            to:
                fp32, fp64: SAFE (every bfloat16 value is representable)
                fp16: UNSAFE (loses range)

        Integers:
            from:
                8-bit integers (int8, uint8): SAFE
                other widths: UNSAFE (bf16 cannot represent all integers in range)
            to: UNSAFE (loses precision, rounds toward zero)

        All other conversions are not allowed.

    Unification with a float or integer type yields whichever of the two
    types the other converts to with at most a SAFE conversion, or None if
    neither direction is safe.
    """

    def __init__(self):
        super().__init__(name="__nv_bfloat16")

        self.alignof_ = 2
        self.bitwidth = 16

    def can_convert_from(self, typingctx, other):
        if isinstance(other, types.Float):
            return Conversion.unsafe

        elif isinstance(other, types.Integer):
            # 8 significant bits (7 stored + implicit) represent every 8-bit
            # integer exactly; wider integers may be rounded.
            if other.bitwidth == 8:
                return Conversion.safe
            else:
                return Conversion.unsafe
        return None

    def can_convert_to(self, typingctx, other):
        if isinstance(other, types.Float):
            if other.bitwidth >= 32:
                return Conversion.safe
            else:
                return Conversion.unsafe
        elif isinstance(other, types.Integer):
            return Conversion.unsafe
        return None

    def unify(self, typingctx, other):
        """Unify bfloat16 with a float or integer type.

        Returns *other* if bfloat16 converts to it safely, ``self`` if *other*
        converts to bfloat16 safely, otherwise None (cannot unify here).
        """
        if not isinstance(other, (types.Float, types.Integer)):
            return None
        # Do not delegate to ``typingctx.unify_pairs(self, other)``: the
        # typing context's unify_pairs calls ``self.unify(...)`` before
        # anything else, so delegating recursed without bound.
        to_other = self.can_convert_to(typingctx, other)
        if to_other is not None and to_other <= Conversion.safe:
            return other
        from_other = self.can_convert_from(typingctx, other)
        if from_other is not None and from_other <= Conversion.safe:
            return self
        return None


bfloat16 = Bfloat16()
@@ -7,5 +7,13 @@ from .templates import (
7
7
  Signature,
8
8
  fold_arguments,
9
9
  )
10
+ from .context import BaseContext, Context
10
11
 
11
- __all__ = ["signature", "make_concrete_template", "Signature", "fold_arguments"]
12
+ __all__ = [
13
+ "signature",
14
+ "make_concrete_template",
15
+ "Signature",
16
+ "fold_arguments",
17
+ "BaseContext",
18
+ "Context",
19
+ ]
@@ -0,0 +1,55 @@
1
+ # SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2
+ # SPDX-License-Identifier: BSD-2-Clause
3
+
4
"""
Support for CFFI in the CUDA target. Records whether the optional ``cffi``
package is usable (``SUPPORTED``) and provides typing for
``ffi.from_buffer()``.
"""
8
+
9
+ from numba.core import types
10
+ from numba.core.errors import TypingError
11
+ from numba.cuda.typing import templates
12
+
13
ffi = None
try:
    import cffi

    ffi = cffi.FFI()
except ImportError:
    # cffi could not be imported or instantiated; ``ffi`` stays None.
    pass

# True only when an ``ffi`` instance was created above.
SUPPORTED = ffi is not None

# Registry that the typing templates below register themselves with.
registry = templates.Registry()
22
+
23
+
24
@registry.register
class FFI_from_buffer(templates.AbstractTemplate):
    """Typing for ``ffi.from_buffer(buf)``: returns a pointer to buf's dtype."""

    key = "ffi.from_buffer"

    def generic(self, args, kws):
        # Only the single positional-argument form is typed here.
        if kws or len(args) != 1:
            return
        (buf,) = args
        if not isinstance(buf, types.Buffer):
            raise TypingError(
                "from_buffer() expected a buffer object, got %s" % (buf,)
            )
        layout = buf.layout
        if layout not in ("C", "F"):
            raise TypingError(
                "from_buffer() unsupported on non-contiguous buffers (got %s)"
                % (buf,)
            )
        # F-ordered buffers are accepted only when one-dimensional.
        if layout != "C" and buf.ndim > 1:
            raise TypingError(
                "from_buffer() only supports multidimensional arrays with C layout (got %s)"
                % (buf,)
            )
        pointer_type = types.CPointer(buf.dtype)
        return templates.signature(pointer_type, buf)
48
+
49
+
50
@registry.register_attr
class FFIAttribute(templates.AttributeTemplate):
    """Attribute typing for ``types.ffi``: exposes ``ffi.from_buffer``."""

    key = types.ffi

    def resolve_from_buffer(self, ffi):
        # The template is always bound to ``types.ffi``; the ``ffi``
        # argument itself is not consulted.
        bound = types.BoundFunction(FFI_from_buffer, types.ffi)
        return bound