numba-cuda 0.19.1__py3-none-any.whl → 0.20.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of numba-cuda might be problematic. Click here for more details.
- numba_cuda/VERSION +1 -1
- numba_cuda/numba/cuda/__init__.py +1 -1
- numba_cuda/numba/cuda/_internal/cuda_bf16.py +12706 -1470
- numba_cuda/numba/cuda/_internal/cuda_fp16.py +2653 -8769
- numba_cuda/numba/cuda/api.py +6 -1
- numba_cuda/numba/cuda/bf16.py +285 -2
- numba_cuda/numba/cuda/cgutils.py +2 -2
- numba_cuda/numba/cuda/cloudpickle/__init__.py +21 -0
- numba_cuda/numba/cuda/cloudpickle/cloudpickle.py +1598 -0
- numba_cuda/numba/cuda/cloudpickle/cloudpickle_fast.py +17 -0
- numba_cuda/numba/cuda/codegen.py +1 -1
- numba_cuda/numba/cuda/compiler.py +373 -30
- numba_cuda/numba/cuda/core/analysis.py +319 -0
- numba_cuda/numba/cuda/core/annotations/__init__.py +0 -0
- numba_cuda/numba/cuda/core/annotations/type_annotations.py +304 -0
- numba_cuda/numba/cuda/core/base.py +1289 -0
- numba_cuda/numba/cuda/core/bytecode.py +727 -0
- numba_cuda/numba/cuda/core/caching.py +2 -2
- numba_cuda/numba/cuda/core/compiler.py +6 -14
- numba_cuda/numba/cuda/core/compiler_machinery.py +497 -0
- numba_cuda/numba/cuda/core/config.py +747 -0
- numba_cuda/numba/cuda/core/consts.py +124 -0
- numba_cuda/numba/cuda/core/cpu.py +370 -0
- numba_cuda/numba/cuda/core/environment.py +68 -0
- numba_cuda/numba/cuda/core/event.py +511 -0
- numba_cuda/numba/cuda/core/funcdesc.py +330 -0
- numba_cuda/numba/cuda/core/inline_closurecall.py +1889 -0
- numba_cuda/numba/cuda/core/interpreter.py +48 -26
- numba_cuda/numba/cuda/core/ir_utils.py +15 -26
- numba_cuda/numba/cuda/core/options.py +262 -0
- numba_cuda/numba/cuda/core/postproc.py +249 -0
- numba_cuda/numba/cuda/core/pythonapi.py +1868 -0
- numba_cuda/numba/cuda/core/rewrites/__init__.py +26 -0
- numba_cuda/numba/cuda/core/rewrites/ir_print.py +90 -0
- numba_cuda/numba/cuda/core/rewrites/registry.py +104 -0
- numba_cuda/numba/cuda/core/rewrites/static_binop.py +40 -0
- numba_cuda/numba/cuda/core/rewrites/static_getitem.py +187 -0
- numba_cuda/numba/cuda/core/rewrites/static_raise.py +98 -0
- numba_cuda/numba/cuda/core/ssa.py +496 -0
- numba_cuda/numba/cuda/core/targetconfig.py +329 -0
- numba_cuda/numba/cuda/core/tracing.py +231 -0
- numba_cuda/numba/cuda/core/transforms.py +952 -0
- numba_cuda/numba/cuda/core/typed_passes.py +738 -7
- numba_cuda/numba/cuda/core/typeinfer.py +1948 -0
- numba_cuda/numba/cuda/core/unsafe/__init__.py +0 -0
- numba_cuda/numba/cuda/core/unsafe/bytes.py +67 -0
- numba_cuda/numba/cuda/core/unsafe/eh.py +66 -0
- numba_cuda/numba/cuda/core/unsafe/refcount.py +98 -0
- numba_cuda/numba/cuda/core/untyped_passes.py +1983 -0
- numba_cuda/numba/cuda/cpython/cmathimpl.py +560 -0
- numba_cuda/numba/cuda/cpython/mathimpl.py +499 -0
- numba_cuda/numba/cuda/cpython/numbers.py +1474 -0
- numba_cuda/numba/cuda/cuda_paths.py +422 -246
- numba_cuda/numba/cuda/cudadecl.py +1 -1
- numba_cuda/numba/cuda/cudadrv/__init__.py +1 -1
- numba_cuda/numba/cuda/cudadrv/devicearray.py +2 -1
- numba_cuda/numba/cuda/cudadrv/driver.py +11 -140
- numba_cuda/numba/cuda/cudadrv/dummyarray.py +111 -24
- numba_cuda/numba/cuda/cudadrv/libs.py +5 -5
- numba_cuda/numba/cuda/cudadrv/mappings.py +1 -1
- numba_cuda/numba/cuda/cudadrv/nvrtc.py +19 -8
- numba_cuda/numba/cuda/cudadrv/nvvm.py +1 -4
- numba_cuda/numba/cuda/cudadrv/runtime.py +1 -1
- numba_cuda/numba/cuda/cudaimpl.py +5 -1
- numba_cuda/numba/cuda/debuginfo.py +85 -2
- numba_cuda/numba/cuda/decorators.py +3 -3
- numba_cuda/numba/cuda/descriptor.py +3 -4
- numba_cuda/numba/cuda/deviceufunc.py +66 -2
- numba_cuda/numba/cuda/dispatcher.py +18 -39
- numba_cuda/numba/cuda/flags.py +141 -1
- numba_cuda/numba/cuda/fp16.py +0 -2
- numba_cuda/numba/cuda/include/13/cuda_bf16.h +5118 -0
- numba_cuda/numba/cuda/include/13/cuda_bf16.hpp +3865 -0
- numba_cuda/numba/cuda/include/13/cuda_fp16.h +5363 -0
- numba_cuda/numba/cuda/include/13/cuda_fp16.hpp +3483 -0
- numba_cuda/numba/cuda/lowering.py +7 -144
- numba_cuda/numba/cuda/mathimpl.py +2 -1
- numba_cuda/numba/cuda/memory_management/nrt.py +43 -17
- numba_cuda/numba/cuda/misc/findlib.py +75 -0
- numba_cuda/numba/cuda/models.py +9 -1
- numba_cuda/numba/cuda/np/npdatetime_helpers.py +217 -0
- numba_cuda/numba/cuda/np/npyfuncs.py +1807 -0
- numba_cuda/numba/cuda/np/numpy_support.py +553 -0
- numba_cuda/numba/cuda/np/ufunc/ufuncbuilder.py +59 -0
- numba_cuda/numba/cuda/nvvmutils.py +1 -1
- numba_cuda/numba/cuda/printimpl.py +12 -1
- numba_cuda/numba/cuda/random.py +1 -1
- numba_cuda/numba/cuda/serialize.py +1 -1
- numba_cuda/numba/cuda/simulator/__init__.py +1 -1
- numba_cuda/numba/cuda/simulator/api.py +1 -1
- numba_cuda/numba/cuda/simulator/compiler.py +4 -0
- numba_cuda/numba/cuda/simulator/cudadrv/devicearray.py +1 -1
- numba_cuda/numba/cuda/simulator/kernelapi.py +1 -1
- numba_cuda/numba/cuda/simulator/memory_management/nrt.py +14 -2
- numba_cuda/numba/cuda/target.py +35 -17
- numba_cuda/numba/cuda/testing.py +4 -19
- numba_cuda/numba/cuda/tests/__init__.py +1 -1
- numba_cuda/numba/cuda/tests/cloudpickle_main_class.py +9 -0
- numba_cuda/numba/cuda/tests/core/test_serialize.py +4 -4
- numba_cuda/numba/cuda/tests/cudadrv/test_cuda_devicerecord.py +1 -1
- numba_cuda/numba/cuda/tests/cudadrv/test_cuda_libraries.py +1 -1
- numba_cuda/numba/cuda/tests/cudadrv/test_deallocations.py +1 -1
- numba_cuda/numba/cuda/tests/cudadrv/test_detect.py +6 -3
- numba_cuda/numba/cuda/tests/cudadrv/test_emm_plugins.py +1 -1
- numba_cuda/numba/cuda/tests/cudadrv/test_linker.py +18 -2
- numba_cuda/numba/cuda/tests/cudadrv/test_module_callbacks.py +2 -1
- numba_cuda/numba/cuda/tests/cudadrv/test_nvjitlink.py +1 -1
- numba_cuda/numba/cuda/tests/cudadrv/test_ptds.py +1 -1
- numba_cuda/numba/cuda/tests/cudapy/extensions_usecases.py +1 -1
- numba_cuda/numba/cuda/tests/cudapy/test_array.py +2 -1
- numba_cuda/numba/cuda/tests/cudapy/test_atomics.py +1 -1
- numba_cuda/numba/cuda/tests/cudapy/test_bfloat16.py +539 -2
- numba_cuda/numba/cuda/tests/cudapy/test_bfloat16_bindings.py +81 -1
- numba_cuda/numba/cuda/tests/cudapy/test_caching.py +1 -3
- numba_cuda/numba/cuda/tests/cudapy/test_complex.py +1 -1
- numba_cuda/numba/cuda/tests/cudapy/test_constmem.py +1 -1
- numba_cuda/numba/cuda/tests/cudapy/test_cooperative_groups.py +2 -3
- numba_cuda/numba/cuda/tests/cudapy/test_copy_propagate.py +130 -0
- numba_cuda/numba/cuda/tests/cudapy/test_datetime.py +1 -1
- numba_cuda/numba/cuda/tests/cudapy/test_debug.py +1 -1
- numba_cuda/numba/cuda/tests/cudapy/test_debuginfo.py +293 -4
- numba_cuda/numba/cuda/tests/cudapy/test_debuginfo_types.py +1 -1
- numba_cuda/numba/cuda/tests/cudapy/test_dispatcher.py +1 -1
- numba_cuda/numba/cuda/tests/cudapy/test_errors.py +1 -1
- numba_cuda/numba/cuda/tests/cudapy/test_exception.py +1 -1
- numba_cuda/numba/cuda/tests/cudapy/test_extending.py +2 -1
- numba_cuda/numba/cuda/tests/cudapy/test_inline.py +18 -8
- numba_cuda/numba/cuda/tests/cudapy/test_ir_utils.py +10 -37
- numba_cuda/numba/cuda/tests/cudapy/test_laplace.py +1 -1
- numba_cuda/numba/cuda/tests/cudapy/test_math.py +1 -1
- numba_cuda/numba/cuda/tests/cudapy/test_matmul.py +1 -1
- numba_cuda/numba/cuda/tests/cudapy/test_operator.py +1 -1
- numba_cuda/numba/cuda/tests/cudapy/test_print.py +20 -0
- numba_cuda/numba/cuda/tests/cudapy/test_record_dtype.py +1 -1
- numba_cuda/numba/cuda/tests/cudapy/test_reduction.py +1 -1
- numba_cuda/numba/cuda/tests/cudapy/test_serialize.py +1 -1
- numba_cuda/numba/cuda/tests/cudapy/test_sm.py +1 -1
- numba_cuda/numba/cuda/tests/cudapy/test_ssa.py +453 -0
- numba_cuda/numba/cuda/tests/cudapy/test_sync.py +1 -1
- numba_cuda/numba/cuda/tests/cudapy/test_typeinfer.py +538 -0
- numba_cuda/numba/cuda/tests/cudapy/test_ufuncs.py +263 -2
- numba_cuda/numba/cuda/tests/cudapy/test_userexc.py +1 -1
- numba_cuda/numba/cuda/tests/cudapy/test_vector_type.py +1 -1
- numba_cuda/numba/cuda/tests/cudapy/test_vectorize_decor.py +112 -6
- numba_cuda/numba/cuda/tests/cudapy/test_warning.py +1 -1
- numba_cuda/numba/cuda/tests/cudapy/test_warp_ops.py +1 -1
- numba_cuda/numba/cuda/tests/doc_examples/test_cg.py +0 -2
- numba_cuda/numba/cuda/tests/doc_examples/test_ffi.py +3 -2
- numba_cuda/numba/cuda/tests/doc_examples/test_laplace.py +0 -2
- numba_cuda/numba/cuda/tests/doc_examples/test_sessionize.py +0 -2
- numba_cuda/numba/cuda/tests/nocuda/test_import.py +3 -1
- numba_cuda/numba/cuda/tests/nocuda/test_library_lookup.py +24 -12
- numba_cuda/numba/cuda/tests/nrt/test_nrt.py +2 -1
- numba_cuda/numba/cuda/tests/support.py +55 -15
- numba_cuda/numba/cuda/tests/test_tracing.py +200 -0
- numba_cuda/numba/cuda/types.py +56 -0
- numba_cuda/numba/cuda/typing/__init__.py +9 -1
- numba_cuda/numba/cuda/typing/cffi_utils.py +55 -0
- numba_cuda/numba/cuda/typing/context.py +751 -0
- numba_cuda/numba/cuda/typing/enumdecl.py +74 -0
- numba_cuda/numba/cuda/typing/npydecl.py +658 -0
- numba_cuda/numba/cuda/typing/templates.py +7 -6
- numba_cuda/numba/cuda/ufuncs.py +3 -3
- numba_cuda/numba/cuda/utils.py +6 -112
- {numba_cuda-0.19.1.dist-info → numba_cuda-0.20.0.dist-info}/METADATA +2 -1
- {numba_cuda-0.19.1.dist-info → numba_cuda-0.20.0.dist-info}/RECORD +170 -115
- numba_cuda/numba/cuda/tests/cudadrv/test_mvc.py +0 -60
- {numba_cuda-0.19.1.dist-info → numba_cuda-0.20.0.dist-info}/WHEEL +0 -0
- {numba_cuda-0.19.1.dist-info → numba_cuda-0.20.0.dist-info}/licenses/LICENSE +0 -0
- {numba_cuda-0.19.1.dist-info → numba_cuda-0.20.0.dist-info}/licenses/LICENSE.numba +0 -0
- {numba_cuda-0.19.1.dist-info → numba_cuda-0.20.0.dist-info}/top_level.txt +0 -0
|
@@ -0,0 +1,538 @@
|
|
|
1
|
+
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
|
2
|
+
# SPDX-License-Identifier: BSD-2-Clause
|
|
3
|
+
|
|
4
|
+
import itertools
|
|
5
|
+
|
|
6
|
+
from numba.core import errors, types, typing
|
|
7
|
+
from numba.core.typeconv import Conversion
|
|
8
|
+
|
|
9
|
+
from numba.cuda.testing import CUDATestCase, skip_on_cudasim
|
|
10
|
+
from numba.tests.test_typeconv import CompatibilityTestMixin
|
|
11
|
+
from numba.cuda.core.untyped_passes import TranslateByteCode, IRProcessing
|
|
12
|
+
from numba.cuda.core.typed_passes import PartialTypeInference
|
|
13
|
+
from numba.cuda.core.compiler_machinery import FunctionPass, register_pass
|
|
14
|
+
import unittest
|
|
15
|
+
|
|
16
|
+
from numba.cuda.flags import Flags
|
|
17
|
+
|
|
18
|
+
|
|
19
|
+
# Short module-level aliases for the Numba scalar types used by the tests
# below, grouped by kind: signed ints, unsigned ints, floats, complex.
i8, i16, i32, i64 = types.int8, types.int16, types.int32, types.int64
u8, u16, u32, u64 = types.uint8, types.uint16, types.uint32, types.uint64
f32, f64 = types.float32, types.float64
c64, c128 = types.complex64, types.complex128
|
|
31
|
+
|
|
32
|
+
|
|
33
|
+
class TestUnify(CUDATestCase):
    """
    Tests for type unification with a typing context.
    """

    # Expected unification result for pairs of integer types, keyed by the
    # string names of the two types. Mixed pairs are listed in one order
    # only; test_integer() falls back to the reversed key when needed.
    int_unify = {
        ("uint8", "uint8"): "uint8",
        ("int8", "int8"): "int8",
        ("uint16", "uint16"): "uint16",
        ("int16", "int16"): "int16",
        ("uint32", "uint32"): "uint32",
        ("int32", "int32"): "int32",
        ("uint64", "uint64"): "uint64",
        ("int64", "int64"): "int64",
        ("int8", "uint8"): "int16",
        ("int8", "uint16"): "int32",
        ("int8", "uint32"): "int64",
        ("uint8", "int32"): "int32",
        ("uint8", "uint64"): "uint64",
        ("int16", "int8"): "int16",
        ("int16", "uint8"): "int16",
        ("int16", "uint16"): "int32",
        ("int16", "uint32"): "int64",
        ("int16", "int64"): "int64",
        ("int16", "uint64"): "float64",
        ("uint16", "uint8"): "uint16",
        ("uint16", "uint32"): "uint32",
        ("uint16", "int32"): "int32",
        ("uint16", "uint64"): "uint64",
        ("int32", "int8"): "int32",
        ("int32", "int16"): "int32",
        ("int32", "uint32"): "int64",
        ("int32", "int64"): "int64",
        ("uint32", "uint8"): "uint32",
        ("uint32", "int64"): "int64",
        ("uint32", "uint64"): "uint64",
        ("int64", "int8"): "int64",
        ("int64", "uint8"): "int64",
        ("int64", "uint16"): "int64",
        ("uint64", "int8"): "float64",
        ("uint64", "int32"): "float64",
        ("uint64", "int64"): "float64",
    }

    def assert_unify(self, aty, bty, expected):
        """
        Assert that *aty* and *bty* unify to *expected*.

        The check is made with both ``unify_types`` and ``unify_pairs`` of
        a fresh typing context, and in both argument orders.
        """
        ctx = typing.Context()
        template = "{0}, {1} -> {2} != {3}"
        for unify_func in ctx.unify_types, ctx.unify_pairs:
            # Unification is expected to be symmetric, so try both orders.
            for first, second in ((aty, bty), (bty, aty)):
                unified = unify_func(first, second)
                self.assertEqual(
                    unified,
                    expected,
                    msg=template.format(first, second, unified, expected),
                )

    def assert_unify_failure(self, aty, bty):
        """
        Assert that *aty* and *bty* cannot be unified (the result is None).
        """
        self.assert_unify(aty, bty, None)

    def test_integer(self):
        """
        Check every ordered pair of integer types against ``int_unify``.
        """
        for aty, bty in itertools.product(
            types.integer_domain, types.integer_domain
        ):
            key = (str(aty), str(bty))
            try:
                expected = self.int_unify[key]
            except KeyError:
                # The table lists this pair in the opposite order.
                expected = self.int_unify[key[::-1]]
            self.assert_unify(aty, bty, getattr(types, expected))

    def test_bool(self):
        aty = types.boolean
        for bty in types.integer_domain:
            self.assert_unify(aty, bty, bty)
        # Not sure about this one, but it respects transitivity
        for cty in types.real_domain:
            self.assert_unify(aty, cty, cty)

    def unify_number_pair_test(self, n):
        """
        Test all permutations of N-combinations of numeric types and ensure
        that the order of types in the sequence is irrelevant.
        """
        ctx = typing.Context()
        for tys in itertools.combinations(types.number_domain, n):
            res = [
                ctx.unify_types(*comb) for comb in itertools.permutations(tys)
            ]
            first_result = res[0]
            # Sanity check
            self.assertIsInstance(first_result, types.Number)
            # All results must be equal
            for other in res[1:]:
                self.assertEqual(first_result, other)

    def test_unify_number_pair(self):
        self.unify_number_pair_test(2)
        self.unify_number_pair_test(3)

    def test_none_to_optional(self):
        """
        Test unification of `none` and multiple number types to optional type
        """
        ctx = typing.Context()
        for tys in itertools.combinations(types.number_domain, 2):
            # First unify without none, to provide the control value
            tys = list(tys)
            expected = types.Optional(ctx.unify_types(*tys))
            results = [
                ctx.unify_types(*comb)
                for comb in itertools.permutations(tys + [types.none])
            ]
            # All results must be equal
            for res in results:
                self.assertEqual(res, expected)

    def test_none(self):
        aty = types.none
        bty = types.none
        self.assert_unify(aty, bty, types.none)

    def test_optional(self):
        aty = types.Optional(i32)
        bty = types.none
        self.assert_unify(aty, bty, aty)
        aty = types.Optional(i32)
        bty = types.Optional(i64)
        self.assert_unify(aty, bty, bty)
        aty = types.Optional(i32)
        bty = i64
        self.assert_unify(aty, bty, types.Optional(i64))
        # Failure
        aty = types.Optional(i32)
        bty = types.Optional(types.slice3_type)
        self.assert_unify_failure(aty, bty)

    def test_tuple(self):
        aty = types.UniTuple(i32, 3)
        bty = types.UniTuple(i64, 3)
        self.assert_unify(aty, bty, types.UniTuple(i64, 3))
        # (Tuple, UniTuple) -> Tuple
        aty = types.UniTuple(i32, 2)
        bty = types.Tuple((i16, i64))
        self.assert_unify(aty, bty, types.Tuple((i32, i64)))
        aty = types.UniTuple(i64, 0)
        bty = types.Tuple(())
        self.assert_unify(aty, bty, bty)
        # (Tuple, Tuple) -> Tuple
        aty = types.Tuple((i8, i16, i32))
        bty = types.Tuple((i32, i16, i8))
        self.assert_unify(aty, bty, types.Tuple((i32, i16, i32)))
        aty = types.Tuple((i8, i32))
        bty = types.Tuple((i32, i8))
        self.assert_unify(aty, bty, types.Tuple((i32, i32)))
        aty = types.Tuple((i8, i16))
        bty = types.Tuple((i16, i8))
        self.assert_unify(aty, bty, types.Tuple((i16, i16)))
        # Different number kinds
        aty = types.UniTuple(f64, 3)
        bty = types.UniTuple(c64, 3)
        self.assert_unify(aty, bty, types.UniTuple(c128, 3))
        # Tuples of tuples
        aty = types.UniTuple(types.Tuple((u32, f32)), 2)
        bty = types.UniTuple(types.Tuple((i16, f32)), 2)
        self.assert_unify(aty, bty, types.UniTuple(types.Tuple((i64, f32)), 2))
        # Failures
        aty = types.UniTuple(i32, 1)
        bty = types.UniTuple(types.slice3_type, 1)
        self.assert_unify_failure(aty, bty)
        aty = types.UniTuple(i32, 1)
        bty = types.UniTuple(i32, 2)
        self.assert_unify_failure(aty, bty)
        aty = types.Tuple((i8, types.slice3_type))
        bty = types.Tuple((i32, i8))
        self.assert_unify_failure(aty, bty)

    def test_optional_tuple(self):
        # Unify to optional tuple
        aty = types.none
        bty = types.UniTuple(i32, 2)
        self.assert_unify(aty, bty, types.Optional(types.UniTuple(i32, 2)))
        aty = types.Optional(types.UniTuple(i16, 2))
        bty = types.UniTuple(i32, 2)
        self.assert_unify(aty, bty, types.Optional(types.UniTuple(i32, 2)))
        # Unify to tuple of optionals
        aty = types.Tuple((types.none, i32))
        bty = types.Tuple((i16, types.none))
        self.assert_unify(
            aty, bty, types.Tuple((types.Optional(i16), types.Optional(i32)))
        )
        aty = types.Tuple((types.Optional(i32), i64))
        bty = types.Tuple((i16, types.Optional(i8)))
        self.assert_unify(
            aty, bty, types.Tuple((types.Optional(i32), types.Optional(i64)))
        )

    def test_arrays(self):
        aty = types.Array(i32, 3, "C")
        bty = types.Array(i32, 3, "A")
        self.assert_unify(aty, bty, bty)
        aty = types.Array(i32, 3, "C")
        bty = types.Array(i32, 3, "F")
        self.assert_unify(aty, bty, types.Array(i32, 3, "A"))
        aty = types.Array(i32, 3, "C")
        bty = types.Array(i32, 3, "C", readonly=True)
        self.assert_unify(aty, bty, bty)
        aty = types.Array(i32, 3, "A")
        bty = types.Array(i32, 3, "C", readonly=True)
        self.assert_unify(aty, bty, types.Array(i32, 3, "A", readonly=True))
        # Failures
        aty = types.Array(i32, 2, "C")
        bty = types.Array(i32, 3, "C")
        self.assert_unify_failure(aty, bty)
        aty = types.Array(i32, 2, "C")
        bty = types.Array(u32, 2, "C")
        self.assert_unify_failure(aty, bty)

    def test_list(self):
        aty = types.List(types.undefined)
        bty = types.List(i32)
        self.assert_unify(aty, bty, bty)
        aty = types.List(i16)
        bty = types.List(i32)
        self.assert_unify(aty, bty, bty)
        aty = types.List(types.Tuple([i32, i16]))
        bty = types.List(types.Tuple([i16, i64]))
        cty = types.List(types.Tuple([i32, i64]))
        self.assert_unify(aty, bty, cty)
        # Different reflections
        aty = types.List(i16, reflected=True)
        bty = types.List(i32)
        cty = types.List(i32, reflected=True)
        self.assert_unify(aty, bty, cty)
        # Incompatible dtypes
        aty = types.List(i16)
        bty = types.List(types.Tuple([i16]))
        self.assert_unify_failure(aty, bty)

    def test_set(self):
        # Different reflections
        aty = types.Set(i16, reflected=True)
        bty = types.Set(i32)
        cty = types.Set(i32, reflected=True)
        self.assert_unify(aty, bty, cty)
        # Incompatible dtypes
        aty = types.Set(i16)
        bty = types.Set(types.Tuple([i16]))
        self.assert_unify_failure(aty, bty)

    def test_range(self):
        aty = types.range_state32_type
        bty = types.range_state64_type
        self.assert_unify(aty, bty, bty)
|
|
292
|
+
|
|
293
|
+
|
|
294
|
+
class TestTypeConversion(CompatibilityTestMixin, unittest.TestCase):
    """
    Test for conversion between types with a typing context.
    """

    def assert_can_convert(self, aty, bty, expected):
        """
        Assert that converting *aty* to *bty* is rated *expected* by a
        fresh typing context's ``can_convert``.
        """
        ctx = typing.Context()
        got = ctx.can_convert(aty, bty)
        self.assertEqual(got, expected)

    def assert_cannot_convert(self, aty, bty):
        """
        Assert that ``can_convert`` reports no conversion (None) from
        *aty* to *bty*.
        """
        ctx = typing.Context()
        got = ctx.can_convert(aty, bty)
        self.assertIsNone(got)

    def test_convert_number_types(self):
        # Check that Context.can_convert() is compatible with the default
        # number conversion rules registered in the typeconv module
        # (which is used internally by the C _Dispatcher object).
        ctx = typing.Context()
        self.check_number_compatibility(ctx.can_convert)

    def test_tuple(self):
        # UniTuple -> UniTuple
        aty = types.UniTuple(i32, 3)
        bty = types.UniTuple(i64, 3)
        self.assert_can_convert(aty, aty, Conversion.exact)
        self.assert_can_convert(aty, bty, Conversion.promote)
        aty = types.UniTuple(i32, 3)
        bty = types.UniTuple(f64, 3)
        self.assert_can_convert(aty, bty, Conversion.safe)
        # Tuple -> Tuple
        aty = types.Tuple((i32, i32))
        bty = types.Tuple((i32, i64))
        self.assert_can_convert(aty, bty, Conversion.promote)
        # UniTuple <-> Tuple
        aty = types.UniTuple(i32, 2)
        bty = types.Tuple((i32, i64))
        self.assert_can_convert(aty, bty, Conversion.promote)
        self.assert_can_convert(bty, aty, Conversion.unsafe)
        # Empty tuples
        aty = types.UniTuple(i64, 0)
        bty = types.UniTuple(i32, 0)
        cty = types.Tuple(())
        self.assert_can_convert(aty, bty, Conversion.safe)
        self.assert_can_convert(bty, aty, Conversion.safe)
        self.assert_can_convert(aty, cty, Conversion.safe)
        self.assert_can_convert(cty, aty, Conversion.safe)
        # Failures
        aty = types.UniTuple(i64, 3)
        bty = types.UniTuple(types.none, 3)
        self.assert_cannot_convert(aty, bty)
        # Same element type but different lengths. This pair was set up
        # without ever being checked, so the case silently tested nothing.
        aty = types.UniTuple(i64, 2)
        bty = types.UniTuple(i64, 3)
        self.assert_cannot_convert(aty, bty)

    def test_arrays(self):
        # Different layouts
        aty = types.Array(i32, 3, "C")
        bty = types.Array(i32, 3, "A")
        self.assert_can_convert(aty, bty, Conversion.safe)
        aty = types.Array(i32, 2, "C")
        bty = types.Array(i32, 2, "F")
        self.assert_cannot_convert(aty, bty)
        # Different mutabilities
        aty = types.Array(i32, 3, "C")
        bty = types.Array(i32, 3, "C", readonly=True)
        self.assert_can_convert(aty, aty, Conversion.exact)
        self.assert_can_convert(bty, bty, Conversion.exact)
        self.assert_can_convert(aty, bty, Conversion.safe)
        self.assert_cannot_convert(bty, aty)
        # Various failures
        aty = types.Array(i32, 2, "C")
        bty = types.Array(i32, 3, "C")
        self.assert_cannot_convert(aty, bty)
        aty = types.Array(i32, 2, "C")
        bty = types.Array(i64, 2, "C")
        self.assert_cannot_convert(aty, bty)

    def test_optional(self):
        aty = types.int32
        bty = types.Optional(i32)
        self.assert_can_convert(types.none, bty, Conversion.promote)
        self.assert_can_convert(aty, bty, Conversion.promote)
        self.assert_cannot_convert(bty, types.none)
        self.assert_can_convert(bty, aty, Conversion.safe)  # XXX ???
        # Optional array
        aty = types.Array(i32, 2, "C")
        bty = types.Optional(aty)
        self.assert_can_convert(types.none, bty, Conversion.promote)
        self.assert_can_convert(aty, bty, Conversion.promote)
        self.assert_can_convert(bty, aty, Conversion.safe)
        aty = types.Array(i32, 2, "C")
        bty = types.Optional(aty.copy(layout="A"))
        self.assert_can_convert(aty, bty, Conversion.safe)  # C -> A
        self.assert_cannot_convert(bty, aty)  # A -> C
        aty = types.Array(i32, 2, "C")
        bty = types.Optional(aty.copy(layout="F"))
        self.assert_cannot_convert(aty, bty)
        self.assert_cannot_convert(bty, aty)
|
|
393
|
+
|
|
394
|
+
|
|
395
|
+
class TestResolveOverload(unittest.TestCase):
    """
    Tests for typing.Context.resolve_overload().
    """

    def assert_resolve_overload(self, cases, args, expected):
        """
        Assert that resolving the overload set *cases* for positional
        argument types *args* (no keyword arguments) yields *expected*.
        """
        ctx = typing.Context()
        got = ctx.resolve_overload("foo", cases, args, {})
        self.assertEqual(got, expected)

    def test_non_ambiguous_match(self):
        def check(args, expected):
            self.assert_resolve_overload(cases, args, expected)
            # Order shouldn't matter here
            self.assert_resolve_overload(cases[::-1], args, expected)

        cases = [i8(i8, i8), i32(i32, i32), f64(f64, f64)]
        # Exact match
        check((i8, i8), cases[0])
        check((i32, i32), cases[1])
        check((f64, f64), cases[2])
        # "Promote" conversion
        check((i8, i16), cases[1])
        check((i32, i8), cases[1])
        # Mirrored argument order of the previous check; this line was an
        # exact duplicate of it before.
        check((i8, i32), cases[1])
        check((f32, f32), cases[2])
        # "Safe" conversion
        check((u32, u32), cases[2])
        # "Unsafe" conversion
        check((i64, i64), cases[2])

    def test_ambiguous_match(self):
        # When the best match is ambiguous (there is a tie), the first
        # best case in original sequence order should be returned.
        def check(args, expected, expected_reverse):
            self.assert_resolve_overload(cases, args, expected)
            self.assert_resolve_overload(cases[::-1], args, expected_reverse)

        cases = [i16(i16, i16), i32(i32, i32), f64(f64, f64)]
        # Two "promote" conversions
        check((i8, i8), cases[0], cases[1])
        # Two "safe" conversions
        check((u16, u16), cases[1], cases[2])

        cases = [i32(i32, i32), f32(f32, f32)]
        # Two "unsafe" conversions
        check((u32, u32), cases[0], cases[1])

    def test_ambiguous_error(self):
        # With allow_ambiguous=False a tie must raise instead of silently
        # picking the first candidate; the message lists every candidate.
        ctx = typing.Context()
        cases = [i16(i16, i16), i32(i32, i32)]
        with self.assertRaises(TypeError) as raises:
            ctx.resolve_overload(
                "foo", cases, (i8, i8), {}, allow_ambiguous=False
            )
        self.assertEqual(
            str(raises.exception).splitlines(),
            [
                "Ambiguous overloading for foo (int8, int8):",
                "(int16, int16) -> int16",
                "(int32, int32) -> int32",
            ],
        )
|
|
458
|
+
|
|
459
|
+
|
|
460
|
+
@register_pass(mutates_CFG=False, analysis_only=True)
class DummyCR(FunctionPass):
    """Placeholder pass that sets "cr" on the compiler state.

    TyperCompiler has no lowering, so nothing else would populate "cr";
    setting it here avoids the errors that would otherwise follow.
    """

    _name = "dummy_cr"

    def __init__(self):
        super().__init__()

    def run_pass(self, state):
        # Any non-None value will do; the content is never inspected here.
        placeholder = 1
        state.cr = placeholder
        return True
|
|
474
|
+
|
|
475
|
+
|
|
476
|
+
def get_func_typing_errs(func, arg_types):
    """
    Return the typing errors recorded while compiling 'func'.

    'func' is compiled for 'arg_types' with a custom pipeline that runs
    only the untyped passes and partial type inference; the
    ``typing_errors`` stored on the pipeline state are returned.
    """
    from numba.cuda.compiler import CompilerBase, PassManager
    from numba.cuda.descriptor import cuda_target

    class TyperCompiler(CompilerBase):
        """Pipeline that stops after typing: partial typing information is
        produced but nothing is lowered.
        """

        def define_pipelines(self):
            stages = (
                (TranslateByteCode, "analyzing bytecode"),
                (IRProcessing, "processing IR"),
                (PartialTypeInference, "do partial typing"),
            )
            manager = PassManager("custom_pipeline")
            for pass_cls, description in stages:
                manager.add_pass(pass_cls, description)
            manager.add_pass_after(DummyCR, PartialTypeInference)
            manager.finalize()
            return [manager]

    typing_context = cuda_target.typing_context
    target_context = cuda_target.target_context

    # No library, no declared return type and no local type overrides.
    compiler = TyperCompiler(
        typing_context, target_context, None, arg_types, None, Flags(), {}
    )
    compiler.compile_extra(func)
    return compiler.state.typing_errors
|
|
511
|
+
|
|
512
|
+
|
|
513
|
+
@skip_on_cudasim
class TestPartialTypingErrors(CUDATestCase):
    """
    Check how partial typing reports a type error for a function that
    cannot be typed.
    """

    def test_partial_typing_error(self):
        # 'a' is an int on one branch and a str on the other, so the two
        # assignments cannot be unified.
        def impl(flag):
            if flag:
                a = 1
            else:
                a = ""
            return a

        expected_message = (
            r"Cannot unify Literal\[int]\(1\) and Literal\[str]\(\) for 'a'"
        )
        with self.assertRaisesRegex(errors.TypingError, expected_message):
            get_func_typing_errs(impl, (types.bool_,))
|
|
535
|
+
|
|
536
|
+
|
|
537
|
+
# Run the tests in this module when the file is executed as a script.
if __name__ == "__main__":
|
|
538
|
+
unittest.main()
|