numba-cuda 0.19.1__py3-none-any.whl → 0.20.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of numba-cuda might be problematic. See the package's registry page for more details.

Files changed (172)
  1. numba_cuda/VERSION +1 -1
  2. numba_cuda/numba/cuda/__init__.py +1 -1
  3. numba_cuda/numba/cuda/_internal/cuda_bf16.py +12706 -1470
  4. numba_cuda/numba/cuda/_internal/cuda_fp16.py +2653 -8769
  5. numba_cuda/numba/cuda/api.py +6 -1
  6. numba_cuda/numba/cuda/bf16.py +285 -2
  7. numba_cuda/numba/cuda/cgutils.py +2 -2
  8. numba_cuda/numba/cuda/cloudpickle/__init__.py +21 -0
  9. numba_cuda/numba/cuda/cloudpickle/cloudpickle.py +1598 -0
  10. numba_cuda/numba/cuda/cloudpickle/cloudpickle_fast.py +17 -0
  11. numba_cuda/numba/cuda/codegen.py +1 -1
  12. numba_cuda/numba/cuda/compiler.py +373 -30
  13. numba_cuda/numba/cuda/core/analysis.py +319 -0
  14. numba_cuda/numba/cuda/core/annotations/__init__.py +0 -0
  15. numba_cuda/numba/cuda/core/annotations/type_annotations.py +304 -0
  16. numba_cuda/numba/cuda/core/base.py +1289 -0
  17. numba_cuda/numba/cuda/core/bytecode.py +727 -0
  18. numba_cuda/numba/cuda/core/caching.py +2 -2
  19. numba_cuda/numba/cuda/core/compiler.py +6 -14
  20. numba_cuda/numba/cuda/core/compiler_machinery.py +497 -0
  21. numba_cuda/numba/cuda/core/config.py +747 -0
  22. numba_cuda/numba/cuda/core/consts.py +124 -0
  23. numba_cuda/numba/cuda/core/cpu.py +370 -0
  24. numba_cuda/numba/cuda/core/environment.py +68 -0
  25. numba_cuda/numba/cuda/core/event.py +511 -0
  26. numba_cuda/numba/cuda/core/funcdesc.py +330 -0
  27. numba_cuda/numba/cuda/core/inline_closurecall.py +1889 -0
  28. numba_cuda/numba/cuda/core/interpreter.py +48 -26
  29. numba_cuda/numba/cuda/core/ir_utils.py +15 -26
  30. numba_cuda/numba/cuda/core/options.py +262 -0
  31. numba_cuda/numba/cuda/core/postproc.py +249 -0
  32. numba_cuda/numba/cuda/core/pythonapi.py +1868 -0
  33. numba_cuda/numba/cuda/core/rewrites/__init__.py +26 -0
  34. numba_cuda/numba/cuda/core/rewrites/ir_print.py +90 -0
  35. numba_cuda/numba/cuda/core/rewrites/registry.py +104 -0
  36. numba_cuda/numba/cuda/core/rewrites/static_binop.py +40 -0
  37. numba_cuda/numba/cuda/core/rewrites/static_getitem.py +187 -0
  38. numba_cuda/numba/cuda/core/rewrites/static_raise.py +98 -0
  39. numba_cuda/numba/cuda/core/ssa.py +496 -0
  40. numba_cuda/numba/cuda/core/targetconfig.py +329 -0
  41. numba_cuda/numba/cuda/core/tracing.py +231 -0
  42. numba_cuda/numba/cuda/core/transforms.py +952 -0
  43. numba_cuda/numba/cuda/core/typed_passes.py +738 -7
  44. numba_cuda/numba/cuda/core/typeinfer.py +1948 -0
  45. numba_cuda/numba/cuda/core/unsafe/__init__.py +0 -0
  46. numba_cuda/numba/cuda/core/unsafe/bytes.py +67 -0
  47. numba_cuda/numba/cuda/core/unsafe/eh.py +66 -0
  48. numba_cuda/numba/cuda/core/unsafe/refcount.py +98 -0
  49. numba_cuda/numba/cuda/core/untyped_passes.py +1983 -0
  50. numba_cuda/numba/cuda/cpython/cmathimpl.py +560 -0
  51. numba_cuda/numba/cuda/cpython/mathimpl.py +499 -0
  52. numba_cuda/numba/cuda/cpython/numbers.py +1474 -0
  53. numba_cuda/numba/cuda/cuda_paths.py +422 -246
  54. numba_cuda/numba/cuda/cudadecl.py +1 -1
  55. numba_cuda/numba/cuda/cudadrv/__init__.py +1 -1
  56. numba_cuda/numba/cuda/cudadrv/devicearray.py +2 -1
  57. numba_cuda/numba/cuda/cudadrv/driver.py +11 -140
  58. numba_cuda/numba/cuda/cudadrv/dummyarray.py +111 -24
  59. numba_cuda/numba/cuda/cudadrv/libs.py +5 -5
  60. numba_cuda/numba/cuda/cudadrv/mappings.py +1 -1
  61. numba_cuda/numba/cuda/cudadrv/nvrtc.py +19 -8
  62. numba_cuda/numba/cuda/cudadrv/nvvm.py +1 -4
  63. numba_cuda/numba/cuda/cudadrv/runtime.py +1 -1
  64. numba_cuda/numba/cuda/cudaimpl.py +5 -1
  65. numba_cuda/numba/cuda/debuginfo.py +85 -2
  66. numba_cuda/numba/cuda/decorators.py +3 -3
  67. numba_cuda/numba/cuda/descriptor.py +3 -4
  68. numba_cuda/numba/cuda/deviceufunc.py +66 -2
  69. numba_cuda/numba/cuda/dispatcher.py +18 -39
  70. numba_cuda/numba/cuda/flags.py +141 -1
  71. numba_cuda/numba/cuda/fp16.py +0 -2
  72. numba_cuda/numba/cuda/include/13/cuda_bf16.h +5118 -0
  73. numba_cuda/numba/cuda/include/13/cuda_bf16.hpp +3865 -0
  74. numba_cuda/numba/cuda/include/13/cuda_fp16.h +5363 -0
  75. numba_cuda/numba/cuda/include/13/cuda_fp16.hpp +3483 -0
  76. numba_cuda/numba/cuda/lowering.py +7 -144
  77. numba_cuda/numba/cuda/mathimpl.py +2 -1
  78. numba_cuda/numba/cuda/memory_management/nrt.py +43 -17
  79. numba_cuda/numba/cuda/misc/findlib.py +75 -0
  80. numba_cuda/numba/cuda/models.py +9 -1
  81. numba_cuda/numba/cuda/np/npdatetime_helpers.py +217 -0
  82. numba_cuda/numba/cuda/np/npyfuncs.py +1807 -0
  83. numba_cuda/numba/cuda/np/numpy_support.py +553 -0
  84. numba_cuda/numba/cuda/np/ufunc/ufuncbuilder.py +59 -0
  85. numba_cuda/numba/cuda/nvvmutils.py +1 -1
  86. numba_cuda/numba/cuda/printimpl.py +12 -1
  87. numba_cuda/numba/cuda/random.py +1 -1
  88. numba_cuda/numba/cuda/serialize.py +1 -1
  89. numba_cuda/numba/cuda/simulator/__init__.py +1 -1
  90. numba_cuda/numba/cuda/simulator/api.py +1 -1
  91. numba_cuda/numba/cuda/simulator/compiler.py +4 -0
  92. numba_cuda/numba/cuda/simulator/cudadrv/devicearray.py +1 -1
  93. numba_cuda/numba/cuda/simulator/kernelapi.py +1 -1
  94. numba_cuda/numba/cuda/simulator/memory_management/nrt.py +14 -2
  95. numba_cuda/numba/cuda/target.py +35 -17
  96. numba_cuda/numba/cuda/testing.py +7 -19
  97. numba_cuda/numba/cuda/tests/__init__.py +1 -1
  98. numba_cuda/numba/cuda/tests/cloudpickle_main_class.py +9 -0
  99. numba_cuda/numba/cuda/tests/core/test_serialize.py +4 -4
  100. numba_cuda/numba/cuda/tests/cudadrv/test_cuda_devicerecord.py +1 -1
  101. numba_cuda/numba/cuda/tests/cudadrv/test_cuda_libraries.py +1 -1
  102. numba_cuda/numba/cuda/tests/cudadrv/test_deallocations.py +1 -1
  103. numba_cuda/numba/cuda/tests/cudadrv/test_detect.py +6 -3
  104. numba_cuda/numba/cuda/tests/cudadrv/test_emm_plugins.py +1 -1
  105. numba_cuda/numba/cuda/tests/cudadrv/test_linker.py +18 -2
  106. numba_cuda/numba/cuda/tests/cudadrv/test_module_callbacks.py +2 -1
  107. numba_cuda/numba/cuda/tests/cudadrv/test_nvjitlink.py +1 -1
  108. numba_cuda/numba/cuda/tests/cudadrv/test_ptds.py +1 -1
  109. numba_cuda/numba/cuda/tests/cudapy/extensions_usecases.py +1 -1
  110. numba_cuda/numba/cuda/tests/cudapy/test_array.py +2 -1
  111. numba_cuda/numba/cuda/tests/cudapy/test_atomics.py +1 -1
  112. numba_cuda/numba/cuda/tests/cudapy/test_bfloat16.py +539 -2
  113. numba_cuda/numba/cuda/tests/cudapy/test_bfloat16_bindings.py +81 -1
  114. numba_cuda/numba/cuda/tests/cudapy/test_caching.py +1 -3
  115. numba_cuda/numba/cuda/tests/cudapy/test_complex.py +1 -1
  116. numba_cuda/numba/cuda/tests/cudapy/test_constmem.py +1 -1
  117. numba_cuda/numba/cuda/tests/cudapy/test_cooperative_groups.py +2 -3
  118. numba_cuda/numba/cuda/tests/cudapy/test_copy_propagate.py +130 -0
  119. numba_cuda/numba/cuda/tests/cudapy/test_datetime.py +1 -1
  120. numba_cuda/numba/cuda/tests/cudapy/test_debug.py +1 -1
  121. numba_cuda/numba/cuda/tests/cudapy/test_debuginfo.py +293 -4
  122. numba_cuda/numba/cuda/tests/cudapy/test_debuginfo_types.py +1 -1
  123. numba_cuda/numba/cuda/tests/cudapy/test_dispatcher.py +1 -1
  124. numba_cuda/numba/cuda/tests/cudapy/test_errors.py +1 -1
  125. numba_cuda/numba/cuda/tests/cudapy/test_exception.py +1 -1
  126. numba_cuda/numba/cuda/tests/cudapy/test_extending.py +2 -1
  127. numba_cuda/numba/cuda/tests/cudapy/test_inline.py +18 -8
  128. numba_cuda/numba/cuda/tests/cudapy/test_intrinsics.py +23 -21
  129. numba_cuda/numba/cuda/tests/cudapy/test_ir_utils.py +10 -37
  130. numba_cuda/numba/cuda/tests/cudapy/test_laplace.py +1 -1
  131. numba_cuda/numba/cuda/tests/cudapy/test_math.py +1 -1
  132. numba_cuda/numba/cuda/tests/cudapy/test_matmul.py +1 -1
  133. numba_cuda/numba/cuda/tests/cudapy/test_operator.py +1 -1
  134. numba_cuda/numba/cuda/tests/cudapy/test_print.py +20 -0
  135. numba_cuda/numba/cuda/tests/cudapy/test_record_dtype.py +1 -1
  136. numba_cuda/numba/cuda/tests/cudapy/test_reduction.py +1 -1
  137. numba_cuda/numba/cuda/tests/cudapy/test_serialize.py +1 -1
  138. numba_cuda/numba/cuda/tests/cudapy/test_sm.py +1 -1
  139. numba_cuda/numba/cuda/tests/cudapy/test_ssa.py +453 -0
  140. numba_cuda/numba/cuda/tests/cudapy/test_sync.py +1 -1
  141. numba_cuda/numba/cuda/tests/cudapy/test_typeinfer.py +538 -0
  142. numba_cuda/numba/cuda/tests/cudapy/test_ufuncs.py +263 -2
  143. numba_cuda/numba/cuda/tests/cudapy/test_userexc.py +1 -1
  144. numba_cuda/numba/cuda/tests/cudapy/test_vector_type.py +1 -1
  145. numba_cuda/numba/cuda/tests/cudapy/test_vectorize_decor.py +112 -6
  146. numba_cuda/numba/cuda/tests/cudapy/test_warning.py +1 -1
  147. numba_cuda/numba/cuda/tests/cudapy/test_warp_ops.py +1 -1
  148. numba_cuda/numba/cuda/tests/doc_examples/test_cg.py +0 -2
  149. numba_cuda/numba/cuda/tests/doc_examples/test_ffi.py +3 -2
  150. numba_cuda/numba/cuda/tests/doc_examples/test_laplace.py +0 -2
  151. numba_cuda/numba/cuda/tests/doc_examples/test_sessionize.py +0 -2
  152. numba_cuda/numba/cuda/tests/nocuda/test_import.py +3 -1
  153. numba_cuda/numba/cuda/tests/nocuda/test_library_lookup.py +24 -12
  154. numba_cuda/numba/cuda/tests/nrt/test_nrt.py +2 -1
  155. numba_cuda/numba/cuda/tests/support.py +55 -15
  156. numba_cuda/numba/cuda/tests/test_tracing.py +200 -0
  157. numba_cuda/numba/cuda/types.py +56 -0
  158. numba_cuda/numba/cuda/typing/__init__.py +9 -1
  159. numba_cuda/numba/cuda/typing/cffi_utils.py +55 -0
  160. numba_cuda/numba/cuda/typing/context.py +751 -0
  161. numba_cuda/numba/cuda/typing/enumdecl.py +74 -0
  162. numba_cuda/numba/cuda/typing/npydecl.py +658 -0
  163. numba_cuda/numba/cuda/typing/templates.py +7 -6
  164. numba_cuda/numba/cuda/ufuncs.py +3 -3
  165. numba_cuda/numba/cuda/utils.py +6 -112
  166. {numba_cuda-0.19.1.dist-info → numba_cuda-0.20.1.dist-info}/METADATA +4 -3
  167. {numba_cuda-0.19.1.dist-info → numba_cuda-0.20.1.dist-info}/RECORD +171 -116
  168. numba_cuda/numba/cuda/tests/cudadrv/test_mvc.py +0 -60
  169. {numba_cuda-0.19.1.dist-info → numba_cuda-0.20.1.dist-info}/WHEEL +0 -0
  170. {numba_cuda-0.19.1.dist-info → numba_cuda-0.20.1.dist-info}/licenses/LICENSE +0 -0
  171. {numba_cuda-0.19.1.dist-info → numba_cuda-0.20.1.dist-info}/licenses/LICENSE.numba +0 -0
  172. {numba_cuda-0.19.1.dist-info → numba_cuda-0.20.1.dist-info}/top_level.txt +0 -0
@@ -0,0 +1,538 @@
1
+ # SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2
+ # SPDX-License-Identifier: BSD-2-Clause
3
+
4
+ import itertools
5
+
6
+ from numba.core import errors, types, typing
7
+ from numba.core.typeconv import Conversion
8
+
9
+ from numba.cuda.testing import CUDATestCase, skip_on_cudasim
10
+ from numba.tests.test_typeconv import CompatibilityTestMixin
11
+ from numba.cuda.core.untyped_passes import TranslateByteCode, IRProcessing
12
+ from numba.cuda.core.typed_passes import PartialTypeInference
13
+ from numba.cuda.core.compiler_machinery import FunctionPass, register_pass
14
+ import unittest
15
+
16
+ from numba.cuda.flags import Flags
17
+
18
+
19
# Short aliases for the Numba numeric types used throughout these tests.
i8, i16, i32, i64 = types.int8, types.int16, types.int32, types.int64
u8, u16, u32, u64 = types.uint8, types.uint16, types.uint32, types.uint64
f32, f64 = types.float32, types.float64
c64, c128 = types.complex64, types.complex128
31
+
32
+
33
class TestUnify(CUDATestCase):
    """
    Tests for type unification with a typing context.
    """

    # Expected result of unifying each pair of integer types (by type name).
    # Each unordered pair is listed once; ``test_integer`` falls back to the
    # reversed key when the direct key is missing.
    int_unify = {
        ("uint8", "uint8"): "uint8",
        ("int8", "int8"): "int8",
        ("uint16", "uint16"): "uint16",
        ("int16", "int16"): "int16",
        ("uint32", "uint32"): "uint32",
        ("int32", "int32"): "int32",
        ("uint64", "uint64"): "uint64",
        ("int64", "int64"): "int64",
        ("int8", "uint8"): "int16",
        ("int8", "uint16"): "int32",
        ("int8", "uint32"): "int64",
        ("uint8", "int32"): "int32",
        ("uint8", "uint64"): "uint64",
        ("int16", "int8"): "int16",
        ("int16", "uint8"): "int16",
        ("int16", "uint16"): "int32",
        ("int16", "uint32"): "int64",
        ("int16", "int64"): "int64",
        ("int16", "uint64"): "float64",
        ("uint16", "uint8"): "uint16",
        ("uint16", "uint32"): "uint32",
        ("uint16", "int32"): "int32",
        ("uint16", "uint64"): "uint64",
        ("int32", "int8"): "int32",
        ("int32", "int16"): "int32",
        ("int32", "uint32"): "int64",
        ("int32", "int64"): "int64",
        ("uint32", "uint8"): "uint32",
        ("uint32", "int64"): "int64",
        ("uint32", "uint64"): "uint64",
        ("int64", "int8"): "int64",
        ("int64", "uint8"): "int64",
        ("int64", "uint16"): "int64",
        ("uint64", "int8"): "float64",
        ("uint64", "int32"): "float64",
        ("uint64", "int64"): "float64",
    }

    def assert_unify(self, aty, bty, expected):
        """
        Assert that ``aty`` and ``bty`` unify to ``expected``.

        Both ``Context.unify_types`` and ``Context.unify_pairs`` are checked,
        each with the arguments in both orders, since unification must be
        symmetric.
        """
        ctx = typing.Context()
        template = "{0}, {1} -> {2} != {3}"
        for unify_func in ctx.unify_types, ctx.unify_pairs:
            for first, second in ((aty, bty), (bty, aty)):
                unified = unify_func(first, second)
                self.assertEqual(
                    unified,
                    expected,
                    msg=template.format(first, second, unified, expected),
                )

    def assert_unify_failure(self, aty, bty):
        """Assert that ``aty`` and ``bty`` do not unify (result is None)."""
        self.assert_unify(aty, bty, None)

    def test_integer(self):
        for aty, bty in itertools.product(
            types.integer_domain, types.integer_domain
        ):
            key = (str(aty), str(bty))
            try:
                expected = self.int_unify[key]
            except KeyError:
                # The table lists each unordered pair only once.
                expected = self.int_unify[key[::-1]]
            self.assert_unify(aty, bty, getattr(types, expected))

    def test_bool(self):
        aty = types.boolean
        for bty in types.integer_domain:
            self.assert_unify(aty, bty, bty)
        # Not sure about this one, but it respects transitivity
        for cty in types.real_domain:
            self.assert_unify(aty, cty, cty)

    def unify_number_pair_test(self, n):
        """
        Test all permutations of N-combinations of numeric types and ensure
        that the order of types in the sequence is irrelevant.
        """
        ctx = typing.Context()
        for tys in itertools.combinations(types.number_domain, n):
            res = [
                ctx.unify_types(*comb) for comb in itertools.permutations(tys)
            ]
            first_result = res[0]
            # Sanity check
            self.assertIsInstance(first_result, types.Number)
            # All results must be equal
            for other in res[1:]:
                self.assertEqual(first_result, other)

    def test_unify_number_pair(self):
        self.unify_number_pair_test(2)
        self.unify_number_pair_test(3)

    def test_none_to_optional(self):
        """
        Test unification of `none` and multiple number types to optional type
        """
        ctx = typing.Context()
        for tys in itertools.combinations(types.number_domain, 2):
            # First unify without none, to provide the control value
            tys = list(tys)
            expected = types.Optional(ctx.unify_types(*tys))
            results = [
                ctx.unify_types(*comb)
                for comb in itertools.permutations(tys + [types.none])
            ]
            # All results must be equal
            for res in results:
                self.assertEqual(res, expected)

    def test_none(self):
        aty = types.none
        bty = types.none
        self.assert_unify(aty, bty, types.none)

    def test_optional(self):
        aty = types.Optional(i32)
        bty = types.none
        self.assert_unify(aty, bty, aty)
        aty = types.Optional(i32)
        bty = types.Optional(i64)
        self.assert_unify(aty, bty, bty)
        aty = types.Optional(i32)
        bty = i64
        self.assert_unify(aty, bty, types.Optional(i64))
        # Failure
        aty = types.Optional(i32)
        bty = types.Optional(types.slice3_type)
        self.assert_unify_failure(aty, bty)

    def test_tuple(self):
        aty = types.UniTuple(i32, 3)
        bty = types.UniTuple(i64, 3)
        self.assert_unify(aty, bty, types.UniTuple(i64, 3))
        # (Tuple, UniTuple) -> Tuple
        aty = types.UniTuple(i32, 2)
        bty = types.Tuple((i16, i64))
        self.assert_unify(aty, bty, types.Tuple((i32, i64)))
        aty = types.UniTuple(i64, 0)
        bty = types.Tuple(())
        self.assert_unify(aty, bty, bty)
        # (Tuple, Tuple) -> Tuple
        aty = types.Tuple((i8, i16, i32))
        bty = types.Tuple((i32, i16, i8))
        self.assert_unify(aty, bty, types.Tuple((i32, i16, i32)))
        aty = types.Tuple((i8, i32))
        bty = types.Tuple((i32, i8))
        self.assert_unify(aty, bty, types.Tuple((i32, i32)))
        aty = types.Tuple((i8, i16))
        bty = types.Tuple((i16, i8))
        self.assert_unify(aty, bty, types.Tuple((i16, i16)))
        # Different number kinds
        aty = types.UniTuple(f64, 3)
        bty = types.UniTuple(c64, 3)
        self.assert_unify(aty, bty, types.UniTuple(c128, 3))
        # Tuples of tuples
        aty = types.UniTuple(types.Tuple((u32, f32)), 2)
        bty = types.UniTuple(types.Tuple((i16, f32)), 2)
        self.assert_unify(aty, bty, types.UniTuple(types.Tuple((i64, f32)), 2))
        # Failures
        aty = types.UniTuple(i32, 1)
        bty = types.UniTuple(types.slice3_type, 1)
        self.assert_unify_failure(aty, bty)
        aty = types.UniTuple(i32, 1)
        bty = types.UniTuple(i32, 2)
        self.assert_unify_failure(aty, bty)
        aty = types.Tuple((i8, types.slice3_type))
        bty = types.Tuple((i32, i8))
        self.assert_unify_failure(aty, bty)

    def test_optional_tuple(self):
        # Unify to optional tuple
        aty = types.none
        bty = types.UniTuple(i32, 2)
        self.assert_unify(aty, bty, types.Optional(types.UniTuple(i32, 2)))
        aty = types.Optional(types.UniTuple(i16, 2))
        bty = types.UniTuple(i32, 2)
        self.assert_unify(aty, bty, types.Optional(types.UniTuple(i32, 2)))
        # Unify to tuple of optionals
        aty = types.Tuple((types.none, i32))
        bty = types.Tuple((i16, types.none))
        self.assert_unify(
            aty, bty, types.Tuple((types.Optional(i16), types.Optional(i32)))
        )
        aty = types.Tuple((types.Optional(i32), i64))
        bty = types.Tuple((i16, types.Optional(i8)))
        self.assert_unify(
            aty, bty, types.Tuple((types.Optional(i32), types.Optional(i64)))
        )

    def test_arrays(self):
        aty = types.Array(i32, 3, "C")
        bty = types.Array(i32, 3, "A")
        self.assert_unify(aty, bty, bty)
        aty = types.Array(i32, 3, "C")
        bty = types.Array(i32, 3, "F")
        self.assert_unify(aty, bty, types.Array(i32, 3, "A"))
        aty = types.Array(i32, 3, "C")
        bty = types.Array(i32, 3, "C", readonly=True)
        self.assert_unify(aty, bty, bty)
        aty = types.Array(i32, 3, "A")
        bty = types.Array(i32, 3, "C", readonly=True)
        self.assert_unify(aty, bty, types.Array(i32, 3, "A", readonly=True))
        # Failures
        aty = types.Array(i32, 2, "C")
        bty = types.Array(i32, 3, "C")
        self.assert_unify_failure(aty, bty)
        aty = types.Array(i32, 2, "C")
        bty = types.Array(u32, 2, "C")
        self.assert_unify_failure(aty, bty)

    def test_list(self):
        aty = types.List(types.undefined)
        bty = types.List(i32)
        self.assert_unify(aty, bty, bty)
        aty = types.List(i16)
        bty = types.List(i32)
        self.assert_unify(aty, bty, bty)
        aty = types.List(types.Tuple([i32, i16]))
        bty = types.List(types.Tuple([i16, i64]))
        cty = types.List(types.Tuple([i32, i64]))
        self.assert_unify(aty, bty, cty)
        # Different reflections
        aty = types.List(i16, reflected=True)
        bty = types.List(i32)
        cty = types.List(i32, reflected=True)
        self.assert_unify(aty, bty, cty)
        # Incompatible dtypes
        aty = types.List(i16)
        bty = types.List(types.Tuple([i16]))
        self.assert_unify_failure(aty, bty)

    def test_set(self):
        # Different reflections
        aty = types.Set(i16, reflected=True)
        bty = types.Set(i32)
        cty = types.Set(i32, reflected=True)
        self.assert_unify(aty, bty, cty)
        # Incompatible dtypes
        aty = types.Set(i16)
        bty = types.Set(types.Tuple([i16]))
        self.assert_unify_failure(aty, bty)

    def test_range(self):
        aty = types.range_state32_type
        bty = types.range_state64_type
        self.assert_unify(aty, bty, bty)
292
+
293
+
294
class TestTypeConversion(CompatibilityTestMixin, unittest.TestCase):
    """
    Test for conversion between types with a typing context.
    """

    def assert_can_convert(self, aty, bty, expected):
        """Assert ``Context.can_convert(aty, bty)`` equals ``expected``."""
        ctx = typing.Context()
        got = ctx.can_convert(aty, bty)
        self.assertEqual(got, expected)

    def assert_cannot_convert(self, aty, bty):
        """Assert ``Context.can_convert(aty, bty)`` returns None."""
        ctx = typing.Context()
        got = ctx.can_convert(aty, bty)
        self.assertIsNone(got)

    def test_convert_number_types(self):
        # Check that Context.can_convert() is compatible with the default
        # number conversion rules registered in the typeconv module
        # (which is used internally by the C _Dispatcher object).
        ctx = typing.Context()
        self.check_number_compatibility(ctx.can_convert)

    def test_tuple(self):
        # UniTuple -> UniTuple
        aty = types.UniTuple(i32, 3)
        bty = types.UniTuple(i64, 3)
        self.assert_can_convert(aty, aty, Conversion.exact)
        self.assert_can_convert(aty, bty, Conversion.promote)
        aty = types.UniTuple(i32, 3)
        bty = types.UniTuple(f64, 3)
        self.assert_can_convert(aty, bty, Conversion.safe)
        # Tuple -> Tuple
        aty = types.Tuple((i32, i32))
        bty = types.Tuple((i32, i64))
        self.assert_can_convert(aty, bty, Conversion.promote)
        # UniTuple <-> Tuple
        aty = types.UniTuple(i32, 2)
        bty = types.Tuple((i32, i64))
        self.assert_can_convert(aty, bty, Conversion.promote)
        self.assert_can_convert(bty, aty, Conversion.unsafe)
        # Empty tuples
        aty = types.UniTuple(i64, 0)
        bty = types.UniTuple(i32, 0)
        cty = types.Tuple(())
        self.assert_can_convert(aty, bty, Conversion.safe)
        self.assert_can_convert(bty, aty, Conversion.safe)
        self.assert_can_convert(aty, cty, Conversion.safe)
        self.assert_can_convert(cty, aty, Conversion.safe)
        # Failures
        aty = types.UniTuple(i64, 3)
        bty = types.UniTuple(types.none, 3)
        self.assert_cannot_convert(aty, bty)
        # Tuples of different lengths are never convertible; this case was
        # previously constructed but never asserted.
        aty = types.UniTuple(i64, 2)
        bty = types.UniTuple(i64, 3)
        self.assert_cannot_convert(aty, bty)

    def test_arrays(self):
        # Different layouts
        aty = types.Array(i32, 3, "C")
        bty = types.Array(i32, 3, "A")
        self.assert_can_convert(aty, bty, Conversion.safe)
        aty = types.Array(i32, 2, "C")
        bty = types.Array(i32, 2, "F")
        self.assert_cannot_convert(aty, bty)
        # Different mutabilities
        aty = types.Array(i32, 3, "C")
        bty = types.Array(i32, 3, "C", readonly=True)
        self.assert_can_convert(aty, aty, Conversion.exact)
        self.assert_can_convert(bty, bty, Conversion.exact)
        self.assert_can_convert(aty, bty, Conversion.safe)
        self.assert_cannot_convert(bty, aty)
        # Various failures
        aty = types.Array(i32, 2, "C")
        bty = types.Array(i32, 3, "C")
        self.assert_cannot_convert(aty, bty)
        aty = types.Array(i32, 2, "C")
        bty = types.Array(i64, 2, "C")
        self.assert_cannot_convert(aty, bty)

    def test_optional(self):
        aty = i32
        bty = types.Optional(i32)
        self.assert_can_convert(types.none, bty, Conversion.promote)
        self.assert_can_convert(aty, bty, Conversion.promote)
        self.assert_cannot_convert(bty, types.none)
        # NOTE(review): unwrapping Optional(i32) -> i32 is reported as a
        # "safe" conversion; this pins the current behavior, which is
        # debatable.
        self.assert_can_convert(bty, aty, Conversion.safe)
        # Optional array
        aty = types.Array(i32, 2, "C")
        bty = types.Optional(aty)
        self.assert_can_convert(types.none, bty, Conversion.promote)
        self.assert_can_convert(aty, bty, Conversion.promote)
        self.assert_can_convert(bty, aty, Conversion.safe)
        aty = types.Array(i32, 2, "C")
        bty = types.Optional(aty.copy(layout="A"))
        self.assert_can_convert(aty, bty, Conversion.safe)  # C -> A
        self.assert_cannot_convert(bty, aty)  # A -> C
        aty = types.Array(i32, 2, "C")
        bty = types.Optional(aty.copy(layout="F"))
        self.assert_cannot_convert(aty, bty)
        self.assert_cannot_convert(bty, aty)
393
+
394
+
395
class TestResolveOverload(unittest.TestCase):
    """
    Tests for typing.Context.resolve_overload().
    """

    def assert_resolve_overload(self, cases, args, expected):
        """Assert that resolving ``args`` against ``cases`` picks ``expected``."""
        ctx = typing.Context()
        got = ctx.resolve_overload("foo", cases, args, {})
        self.assertEqual(got, expected)

    def test_non_ambiguous_match(self):
        def check(args, expected):
            self.assert_resolve_overload(cases, args, expected)
            # Order shouldn't matter here
            self.assert_resolve_overload(cases[::-1], args, expected)

        cases = [i8(i8, i8), i32(i32, i32), f64(f64, f64)]
        # Exact match
        check((i8, i8), cases[0])
        check((i32, i32), cases[1])
        check((f64, f64), cases[2])
        # "Promote" conversion
        check((i8, i16), cases[1])
        check((i32, i8), cases[1])
        # Mirrored argument order (previously a duplicate of the line above).
        check((i8, i32), cases[1])
        check((f32, f32), cases[2])
        # "Safe" conversion
        check((u32, u32), cases[2])
        # "Unsafe" conversion
        check((i64, i64), cases[2])

    def test_ambiguous_match(self):
        # When the best match is ambiguous (there is a tie), the first
        # best case in original sequence order should be returned.
        def check(args, expected, expected_reverse):
            self.assert_resolve_overload(cases, args, expected)
            self.assert_resolve_overload(cases[::-1], args, expected_reverse)

        cases = [i16(i16, i16), i32(i32, i32), f64(f64, f64)]
        # Two "promote" conversions
        check((i8, i8), cases[0], cases[1])
        # Two "safe" conversions
        check((u16, u16), cases[1], cases[2])

        cases = [i32(i32, i32), f32(f32, f32)]
        # Two "unsafe" conversions
        check((u32, u32), cases[0], cases[1])

    def test_ambiguous_error(self):
        ctx = typing.Context()
        cases = [i16(i16, i16), i32(i32, i32)]
        with self.assertRaises(TypeError) as raises:
            ctx.resolve_overload(
                "foo", cases, (i8, i8), {}, allow_ambiguous=False
            )
        self.assertEqual(
            str(raises.exception).splitlines(),
            [
                "Ambiguous overloading for foo (int8, int8):",
                "(int16, int16) -> int16",
                "(int32, int32) -> int32",
            ],
        )
458
+
459
+
460
@register_pass(mutates_CFG=False, analysis_only=True)
class DummyCR(FunctionPass):
    """Analysis-only pass that puts a placeholder ``cr`` on the compiler state.

    ``TyperCompiler`` has no lowering stage, so nothing else sets ``state.cr``;
    this pass fills it in to avoid errors in that pipeline.
    """

    _name = "dummy_cr"

    def __init__(self):
        super().__init__()

    def run_pass(self, state):
        placeholder_cr = 1  # arbitrary non-None value
        state.cr = placeholder_cr
        return True
474
+
475
+
476
def get_func_typing_errs(func, arg_types):
    """
    Return the typing errors recorded while compiling ``func``.

    ``func`` is run through a custom pipeline that performs bytecode
    translation, IR processing and partial type inference for ``arg_types``
    (no lowering), and the compiler state's ``typing_errors`` is returned.
    """
    from numba.cuda.compiler import CompilerBase, PassManager

    class TyperCompiler(CompilerBase):
        """Pipeline that stops after (partial) typing and skips lowering."""

        def define_pipelines(self):
            manager = PassManager("custom_pipeline")
            for pass_cls, description in (
                (TranslateByteCode, "analyzing bytecode"),
                (IRProcessing, "processing IR"),
                (PartialTypeInference, "do partial typing"),
            ):
                manager.add_pass(pass_cls, description)
            # Supplies ``state.cr`` since there is no lowering pass.
            manager.add_pass_after(DummyCR, PartialTypeInference)
            manager.finalize()
            return [manager]

    from numba.cuda.descriptor import cuda_target

    pipeline = TyperCompiler(
        cuda_target.typing_context,
        cuda_target.target_context,
        None,  # library
        arg_types,
        None,  # return_type
        Flags(),
        {},  # locals
    )
    pipeline.compile_extra(func)
    return pipeline.state.typing_errors
511
+
512
+
513
@skip_on_cudasim
class TestPartialTypingErrors(CUDATestCase):
    """
    Check that partial typing records type errors in the compiler state.
    """

    def test_partial_typing_error(self):
        # ``a`` is an int on one branch and a str on the other, which is a
        # type unification error.
        def impl(flag):
            if flag:
                a = 1
            else:
                a = ""
            return a

        expected_msg = (
            r"Cannot unify Literal\[int]\(1\) and Literal\[str]\(\) for 'a'"
        )
        with self.assertRaisesRegex(errors.TypingError, expected_msg):
            get_func_typing_errs(impl, (types.bool_,))
535
+
536
+
537
# Allow running this test module directly as a script.
if __name__ == "__main__":
    unittest.main()