numba-cuda 0.19.1__py3-none-any.whl → 0.20.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of numba-cuda might be problematic. Click here for more details.
- numba_cuda/VERSION +1 -1
- numba_cuda/numba/cuda/__init__.py +1 -1
- numba_cuda/numba/cuda/_internal/cuda_bf16.py +12706 -1470
- numba_cuda/numba/cuda/_internal/cuda_fp16.py +2653 -8769
- numba_cuda/numba/cuda/api.py +6 -1
- numba_cuda/numba/cuda/bf16.py +285 -2
- numba_cuda/numba/cuda/cgutils.py +2 -2
- numba_cuda/numba/cuda/cloudpickle/__init__.py +21 -0
- numba_cuda/numba/cuda/cloudpickle/cloudpickle.py +1598 -0
- numba_cuda/numba/cuda/cloudpickle/cloudpickle_fast.py +17 -0
- numba_cuda/numba/cuda/codegen.py +1 -1
- numba_cuda/numba/cuda/compiler.py +373 -30
- numba_cuda/numba/cuda/core/analysis.py +319 -0
- numba_cuda/numba/cuda/core/annotations/__init__.py +0 -0
- numba_cuda/numba/cuda/core/annotations/type_annotations.py +304 -0
- numba_cuda/numba/cuda/core/base.py +1289 -0
- numba_cuda/numba/cuda/core/bytecode.py +727 -0
- numba_cuda/numba/cuda/core/caching.py +2 -2
- numba_cuda/numba/cuda/core/compiler.py +6 -14
- numba_cuda/numba/cuda/core/compiler_machinery.py +497 -0
- numba_cuda/numba/cuda/core/config.py +747 -0
- numba_cuda/numba/cuda/core/consts.py +124 -0
- numba_cuda/numba/cuda/core/cpu.py +370 -0
- numba_cuda/numba/cuda/core/environment.py +68 -0
- numba_cuda/numba/cuda/core/event.py +511 -0
- numba_cuda/numba/cuda/core/funcdesc.py +330 -0
- numba_cuda/numba/cuda/core/inline_closurecall.py +1889 -0
- numba_cuda/numba/cuda/core/interpreter.py +48 -26
- numba_cuda/numba/cuda/core/ir_utils.py +15 -26
- numba_cuda/numba/cuda/core/options.py +262 -0
- numba_cuda/numba/cuda/core/postproc.py +249 -0
- numba_cuda/numba/cuda/core/pythonapi.py +1868 -0
- numba_cuda/numba/cuda/core/rewrites/__init__.py +26 -0
- numba_cuda/numba/cuda/core/rewrites/ir_print.py +90 -0
- numba_cuda/numba/cuda/core/rewrites/registry.py +104 -0
- numba_cuda/numba/cuda/core/rewrites/static_binop.py +40 -0
- numba_cuda/numba/cuda/core/rewrites/static_getitem.py +187 -0
- numba_cuda/numba/cuda/core/rewrites/static_raise.py +98 -0
- numba_cuda/numba/cuda/core/ssa.py +496 -0
- numba_cuda/numba/cuda/core/targetconfig.py +329 -0
- numba_cuda/numba/cuda/core/tracing.py +231 -0
- numba_cuda/numba/cuda/core/transforms.py +952 -0
- numba_cuda/numba/cuda/core/typed_passes.py +738 -7
- numba_cuda/numba/cuda/core/typeinfer.py +1948 -0
- numba_cuda/numba/cuda/core/unsafe/__init__.py +0 -0
- numba_cuda/numba/cuda/core/unsafe/bytes.py +67 -0
- numba_cuda/numba/cuda/core/unsafe/eh.py +66 -0
- numba_cuda/numba/cuda/core/unsafe/refcount.py +98 -0
- numba_cuda/numba/cuda/core/untyped_passes.py +1983 -0
- numba_cuda/numba/cuda/cpython/cmathimpl.py +560 -0
- numba_cuda/numba/cuda/cpython/mathimpl.py +499 -0
- numba_cuda/numba/cuda/cpython/numbers.py +1474 -0
- numba_cuda/numba/cuda/cuda_paths.py +422 -246
- numba_cuda/numba/cuda/cudadecl.py +1 -1
- numba_cuda/numba/cuda/cudadrv/__init__.py +1 -1
- numba_cuda/numba/cuda/cudadrv/devicearray.py +2 -1
- numba_cuda/numba/cuda/cudadrv/driver.py +11 -140
- numba_cuda/numba/cuda/cudadrv/dummyarray.py +111 -24
- numba_cuda/numba/cuda/cudadrv/libs.py +5 -5
- numba_cuda/numba/cuda/cudadrv/mappings.py +1 -1
- numba_cuda/numba/cuda/cudadrv/nvrtc.py +19 -8
- numba_cuda/numba/cuda/cudadrv/nvvm.py +1 -4
- numba_cuda/numba/cuda/cudadrv/runtime.py +1 -1
- numba_cuda/numba/cuda/cudaimpl.py +5 -1
- numba_cuda/numba/cuda/debuginfo.py +85 -2
- numba_cuda/numba/cuda/decorators.py +3 -3
- numba_cuda/numba/cuda/descriptor.py +3 -4
- numba_cuda/numba/cuda/deviceufunc.py +66 -2
- numba_cuda/numba/cuda/dispatcher.py +18 -39
- numba_cuda/numba/cuda/flags.py +141 -1
- numba_cuda/numba/cuda/fp16.py +0 -2
- numba_cuda/numba/cuda/include/13/cuda_bf16.h +5118 -0
- numba_cuda/numba/cuda/include/13/cuda_bf16.hpp +3865 -0
- numba_cuda/numba/cuda/include/13/cuda_fp16.h +5363 -0
- numba_cuda/numba/cuda/include/13/cuda_fp16.hpp +3483 -0
- numba_cuda/numba/cuda/lowering.py +7 -144
- numba_cuda/numba/cuda/mathimpl.py +2 -1
- numba_cuda/numba/cuda/memory_management/nrt.py +43 -17
- numba_cuda/numba/cuda/misc/findlib.py +75 -0
- numba_cuda/numba/cuda/models.py +9 -1
- numba_cuda/numba/cuda/np/npdatetime_helpers.py +217 -0
- numba_cuda/numba/cuda/np/npyfuncs.py +1807 -0
- numba_cuda/numba/cuda/np/numpy_support.py +553 -0
- numba_cuda/numba/cuda/np/ufunc/ufuncbuilder.py +59 -0
- numba_cuda/numba/cuda/nvvmutils.py +1 -1
- numba_cuda/numba/cuda/printimpl.py +12 -1
- numba_cuda/numba/cuda/random.py +1 -1
- numba_cuda/numba/cuda/serialize.py +1 -1
- numba_cuda/numba/cuda/simulator/__init__.py +1 -1
- numba_cuda/numba/cuda/simulator/api.py +1 -1
- numba_cuda/numba/cuda/simulator/compiler.py +4 -0
- numba_cuda/numba/cuda/simulator/cudadrv/devicearray.py +1 -1
- numba_cuda/numba/cuda/simulator/kernelapi.py +1 -1
- numba_cuda/numba/cuda/simulator/memory_management/nrt.py +14 -2
- numba_cuda/numba/cuda/target.py +35 -17
- numba_cuda/numba/cuda/testing.py +7 -19
- numba_cuda/numba/cuda/tests/__init__.py +1 -1
- numba_cuda/numba/cuda/tests/cloudpickle_main_class.py +9 -0
- numba_cuda/numba/cuda/tests/core/test_serialize.py +4 -4
- numba_cuda/numba/cuda/tests/cudadrv/test_cuda_devicerecord.py +1 -1
- numba_cuda/numba/cuda/tests/cudadrv/test_cuda_libraries.py +1 -1
- numba_cuda/numba/cuda/tests/cudadrv/test_deallocations.py +1 -1
- numba_cuda/numba/cuda/tests/cudadrv/test_detect.py +6 -3
- numba_cuda/numba/cuda/tests/cudadrv/test_emm_plugins.py +1 -1
- numba_cuda/numba/cuda/tests/cudadrv/test_linker.py +18 -2
- numba_cuda/numba/cuda/tests/cudadrv/test_module_callbacks.py +2 -1
- numba_cuda/numba/cuda/tests/cudadrv/test_nvjitlink.py +1 -1
- numba_cuda/numba/cuda/tests/cudadrv/test_ptds.py +1 -1
- numba_cuda/numba/cuda/tests/cudapy/extensions_usecases.py +1 -1
- numba_cuda/numba/cuda/tests/cudapy/test_array.py +2 -1
- numba_cuda/numba/cuda/tests/cudapy/test_atomics.py +1 -1
- numba_cuda/numba/cuda/tests/cudapy/test_bfloat16.py +539 -2
- numba_cuda/numba/cuda/tests/cudapy/test_bfloat16_bindings.py +81 -1
- numba_cuda/numba/cuda/tests/cudapy/test_caching.py +1 -3
- numba_cuda/numba/cuda/tests/cudapy/test_complex.py +1 -1
- numba_cuda/numba/cuda/tests/cudapy/test_constmem.py +1 -1
- numba_cuda/numba/cuda/tests/cudapy/test_cooperative_groups.py +2 -3
- numba_cuda/numba/cuda/tests/cudapy/test_copy_propagate.py +130 -0
- numba_cuda/numba/cuda/tests/cudapy/test_datetime.py +1 -1
- numba_cuda/numba/cuda/tests/cudapy/test_debug.py +1 -1
- numba_cuda/numba/cuda/tests/cudapy/test_debuginfo.py +293 -4
- numba_cuda/numba/cuda/tests/cudapy/test_debuginfo_types.py +1 -1
- numba_cuda/numba/cuda/tests/cudapy/test_dispatcher.py +1 -1
- numba_cuda/numba/cuda/tests/cudapy/test_errors.py +1 -1
- numba_cuda/numba/cuda/tests/cudapy/test_exception.py +1 -1
- numba_cuda/numba/cuda/tests/cudapy/test_extending.py +2 -1
- numba_cuda/numba/cuda/tests/cudapy/test_inline.py +18 -8
- numba_cuda/numba/cuda/tests/cudapy/test_intrinsics.py +23 -21
- numba_cuda/numba/cuda/tests/cudapy/test_ir_utils.py +10 -37
- numba_cuda/numba/cuda/tests/cudapy/test_laplace.py +1 -1
- numba_cuda/numba/cuda/tests/cudapy/test_math.py +1 -1
- numba_cuda/numba/cuda/tests/cudapy/test_matmul.py +1 -1
- numba_cuda/numba/cuda/tests/cudapy/test_operator.py +1 -1
- numba_cuda/numba/cuda/tests/cudapy/test_print.py +20 -0
- numba_cuda/numba/cuda/tests/cudapy/test_record_dtype.py +1 -1
- numba_cuda/numba/cuda/tests/cudapy/test_reduction.py +1 -1
- numba_cuda/numba/cuda/tests/cudapy/test_serialize.py +1 -1
- numba_cuda/numba/cuda/tests/cudapy/test_sm.py +1 -1
- numba_cuda/numba/cuda/tests/cudapy/test_ssa.py +453 -0
- numba_cuda/numba/cuda/tests/cudapy/test_sync.py +1 -1
- numba_cuda/numba/cuda/tests/cudapy/test_typeinfer.py +538 -0
- numba_cuda/numba/cuda/tests/cudapy/test_ufuncs.py +263 -2
- numba_cuda/numba/cuda/tests/cudapy/test_userexc.py +1 -1
- numba_cuda/numba/cuda/tests/cudapy/test_vector_type.py +1 -1
- numba_cuda/numba/cuda/tests/cudapy/test_vectorize_decor.py +112 -6
- numba_cuda/numba/cuda/tests/cudapy/test_warning.py +1 -1
- numba_cuda/numba/cuda/tests/cudapy/test_warp_ops.py +1 -1
- numba_cuda/numba/cuda/tests/doc_examples/test_cg.py +0 -2
- numba_cuda/numba/cuda/tests/doc_examples/test_ffi.py +3 -2
- numba_cuda/numba/cuda/tests/doc_examples/test_laplace.py +0 -2
- numba_cuda/numba/cuda/tests/doc_examples/test_sessionize.py +0 -2
- numba_cuda/numba/cuda/tests/nocuda/test_import.py +3 -1
- numba_cuda/numba/cuda/tests/nocuda/test_library_lookup.py +24 -12
- numba_cuda/numba/cuda/tests/nrt/test_nrt.py +2 -1
- numba_cuda/numba/cuda/tests/support.py +55 -15
- numba_cuda/numba/cuda/tests/test_tracing.py +200 -0
- numba_cuda/numba/cuda/types.py +56 -0
- numba_cuda/numba/cuda/typing/__init__.py +9 -1
- numba_cuda/numba/cuda/typing/cffi_utils.py +55 -0
- numba_cuda/numba/cuda/typing/context.py +751 -0
- numba_cuda/numba/cuda/typing/enumdecl.py +74 -0
- numba_cuda/numba/cuda/typing/npydecl.py +658 -0
- numba_cuda/numba/cuda/typing/templates.py +7 -6
- numba_cuda/numba/cuda/ufuncs.py +3 -3
- numba_cuda/numba/cuda/utils.py +6 -112
- {numba_cuda-0.19.1.dist-info → numba_cuda-0.20.1.dist-info}/METADATA +4 -3
- {numba_cuda-0.19.1.dist-info → numba_cuda-0.20.1.dist-info}/RECORD +171 -116
- numba_cuda/numba/cuda/tests/cudadrv/test_mvc.py +0 -60
- {numba_cuda-0.19.1.dist-info → numba_cuda-0.20.1.dist-info}/WHEEL +0 -0
- {numba_cuda-0.19.1.dist-info → numba_cuda-0.20.1.dist-info}/licenses/LICENSE +0 -0
- {numba_cuda-0.19.1.dist-info → numba_cuda-0.20.1.dist-info}/licenses/LICENSE.numba +0 -0
- {numba_cuda-0.19.1.dist-info → numba_cuda-0.20.1.dist-info}/top_level.txt +0 -0
|
@@ -0,0 +1,751 @@
|
|
|
1
|
+
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
|
2
|
+
# SPDX-License-Identifier: BSD-2-Clause
|
|
3
|
+
|
|
4
|
+
from collections import defaultdict
|
|
5
|
+
from collections.abc import Sequence
|
|
6
|
+
import types as pytypes
|
|
7
|
+
import weakref
|
|
8
|
+
import threading
|
|
9
|
+
import contextlib
|
|
10
|
+
import operator
|
|
11
|
+
|
|
12
|
+
from numba.core import types, errors
|
|
13
|
+
from numba.core.typeconv import Conversion, rules
|
|
14
|
+
from numba.core.typing.typeof import typeof, Purpose
|
|
15
|
+
from numba.core.typing import templates
|
|
16
|
+
from numba.cuda import utils
|
|
17
|
+
from numba.cuda.utils import order_by_target_specificity
|
|
18
|
+
|
|
19
|
+
|
|
20
|
+
class Rating(object):
    """Accumulated cost of converting actual argument types to the formal
    types of one candidate signature.

    Each counter tallies how many arguments required that kind of
    conversion. Ratings are compared through :meth:`astuple`.
    """

    __slots__ = ("promote", "safe_convert", "unsafe_convert")

    def __init__(self):
        self.promote = self.safe_convert = self.unsafe_convert = 0

    def astuple(self):
        """Return ``(unsafe_convert, safe_convert, promote)``.

        The costliest conversion kind comes first, so ordinary tuple
        comparison considers the worst situation first.
        """
        return self.unsafe_convert, self.safe_convert, self.promote

    def __add__(self, other):
        # Only ratings of exactly the same class can be summed.
        if type(other) is not type(self):
            return NotImplemented
        total = Rating()
        for slot in ("promote", "safe_convert", "unsafe_convert"):
            setattr(total, slot, getattr(self, slot) + getattr(other, slot))
        return total
|
|
42
|
+
|
|
43
|
+
|
|
44
|
+
class CallStack(Sequence):
    """
    A compile-time call stack.

    Frames are stored bottom-to-top in ``_stack``; indexing and iteration
    present them top-first (index 0 is the most recently registered frame).
    """

    def __init__(self):
        self._stack = []
        # Re-entrant so the thread already inside ``register`` (compiling a
        # callee) can register nested frames.
        self._lock = threading.RLock()

    def __getitem__(self, index):
        """
        Returns item in the stack where index=0 is the top and index=1 is
        the second item from the top. Negative indices count from the
        bottom (index=-1 is the bottom-most frame).

        Raises IndexError if *index* is out of range.
        """
        size = len(self._stack)
        if index < 0:
            index += size
        # Bounds must be checked explicitly: ``size - index - 1`` becomes a
        # negative (wrapping) list index for index >= size, and
        # Sequence.__iter__ only stops on IndexError, so without this check
        # iteration would yield every frame twice.
        if not 0 <= index < size:
            raise IndexError("CallStack index out of range")
        return self._stack[size - index - 1]

    def __len__(self):
        return len(self._stack)

    @contextlib.contextmanager
    def register(self, target, typeinfer, func_id, args):
        """
        Push a CallFrame for *func_id* called with *args* for the duration
        of the ``with`` block, popping it again on exit.

        Raises NumbaRuntimeError if a frame for the same function with the
        same argument types is already on the stack.
        """
        with self._lock:
            # Guard compiling the same function with the same signature.
            # The check reads ``_stack``, so it runs under the lock: another
            # thread's frames can then neither change mid-iteration nor be
            # mistaken for re-entrancy of this thread.
            if self.match(func_id.func, args):
                msg = "compiler re-entrant to the same function signature"
                raise errors.NumbaRuntimeError(msg)
            self._stack.append(CallFrame(target, typeinfer, func_id, args))
            try:
                yield
            finally:
                self._stack.pop()

    def finditer(self, py_func):
        """
        Yields frame that matches the function object starting from the top
        of stack.
        """
        for frame in self:
            if frame.func_id.func is py_func:
                yield frame

    def findfirst(self, py_func):
        """
        Returns the first result from `.finditer(py_func)`; or None if no match.
        """
        try:
            return next(self.finditer(py_func))
        except StopIteration:
            return

    def match(self, py_func, args):
        """
        Returns first function that matches *py_func* and the arguments types in
        *args*; or, None if no match.
        """
        for frame in self.finditer(py_func):
            if frame.args == args:
                return frame
|
|
103
|
+
|
|
104
|
+
|
|
105
|
+
class CallFrame(object):
    """
    A compile-time call frame: one function being typed for a given tuple
    of argument types, plus the return types inferred for it so far.
    """

    def __init__(self, target, typeinfer, func_id, args):
        self.typeinfer = typeinfer
        self.func_id = func_id
        self.args = args
        self.target = target
        # Distinct return types recorded via add_return_type().
        self._inferred_retty = set()

    def __repr__(self):
        return f"CallFrame({self.func_id}, {self.args})"

    def add_return_type(self, return_type):
        """Record *return_type* among the inferred return types.

        Raises ``TypingError`` once 16 distinct return types have been
        recorded, treating that as non-convergence of recursive inference.
        """
        # Arbitrary cap; not expected to need user configuration.
        retty_limit = 16
        seen = self._inferred_retty
        seen.add(return_type)
        if len(seen) < retty_limit:
            return
        raise errors.TypingError(
            "Return type of recursive function does not converge"
        )
|
|
131
|
+
|
|
132
|
+
|
|
133
|
+
class BaseContext(object):
|
|
134
|
+
"""A typing context for storing function typing constrain template."""
|
|
135
|
+
|
|
136
|
+
def __init__(self):
    # Compile-time call stack shared by everything typed in this context.
    self.callstack = CallStack()
    # Conversion rules known to the type manager (see can_convert()).
    self.tm = rules.default_type_manager
    # Global value (weakly referenced where possible) -> Numba type.
    self._globals = utils.UniqueDict()
    # Typing declarations extracted from the registries or other sources,
    # each key mapping to a list of templates.
    self._attributes = defaultdict(list)
    self._functions = defaultdict(list)
    # Installed registries, each mapped to its RegistryLoader.
    self._registries = {}

    # Subclass hook; runs last so it sees fully initialized state.
    self.init()
|
|
148
|
+
|
|
149
|
+
def init(self):
    """
    Subclass hook run at the end of ``__init__`` to finish setting up the
    typing context. The base implementation does nothing.
    """
    return None
|
|
153
|
+
|
|
154
|
+
def refresh(self):
    """
    Pull new declarations from known registries into this context.
    Useful for third-party extensions registering declarations late.

    Target-specific registries are loaded first, then the builtin registry
    is re-installed because extensions may have augmented it.
    """
    for load_step in (self.load_additional_registries, self._load_builtins):
        load_step()
|
|
162
|
+
|
|
163
|
+
def explain_function_type(self, func):
    """
    Returns a string description of the type of a function.

    *func* is either a ``types.Callable`` (asked for its call signatures)
    or a key of the registered function templates (their ``cases`` are
    listed). Anything else produces a single "No type info available"
    line. An empty string is returned when no signatures are found.
    """
    signatures = []
    if isinstance(func, types.Callable):
        # The second element (whether func is parametric) is not reported.
        found, _parametric = func.get_call_signatures()
        signatures.extend(found)
    elif func in self._functions:
        for template in self._functions[func]:
            signatures.extend(getattr(template, "cases", []))
    else:
        msg = "No type info available for {func!r} as a callable."
        return msg.format(func=func)

    if not signatures:
        return ""
    lines = ["Known signatures:"]
    lines.extend(" * {0}".format(sig) for sig in signatures)
    return "\n".join(lines)
|
|
189
|
+
|
|
190
|
+
def resolve_function_type(self, func, args, kws):
    """
    Resolve function type *func* for argument types *args* and *kws*.
    A signature is returned (or None if nothing matched).

    User definitions are preferred; builtin templates are the fallback. A
    TypingError raised by the user lookup is re-raised only if the builtin
    lookup finds no signature either.
    """
    user_error = None
    try:
        sig = self._resolve_user_function_type(func, args, kws)
    except errors.TypingError as exc:
        user_error = exc
    else:
        if sig is not None:
            return sig

    sig = self._resolve_builtin_function_type(func, args, kws)
    if sig is None and user_error is not None:
        raise user_error
    return sig
|
|
217
|
+
|
|
218
|
+
def _resolve_builtin_function_type(self, func, args, kws):
    """
    Try each builtin template registered under *func* and return the first
    non-None signature, or None.

    Each template is applied to *args* as given (literal types kept) and,
    if that fails, to *args* with literal types stripped.
    """
    # NOTE: we should reduce usage of this
    # Note: Duplicating code with types.Function.get_call_type().
    if func not in self._functions:
        return None
    for template in self._functions[func]:
        sig = template.apply(args, kws)
        if sig is not None:
            return sig
        sig = template.apply([types.unliteral(a) for a in args], kws)
        if sig is not None:
            return sig
    return None
|
|
233
|
+
|
|
234
|
+
def _resolve_user_function_type(self, func, args, kws, literals=None):
    """
    Resolve *func* through user-level typing: a registered global type,
    a type exposing a ``__call__`` attribute, or a ``types.Callable``.

    Returns a signature or None. *literals* is accepted but unused.
    """
    # It's not a known function type, perhaps it's a global?
    registered = self._lookup_global(func)
    if registered is not None:
        func = registered

    if isinstance(func, types.Type):
        # A type may be callable through its __call__ attribute.
        call_type = self.resolve_getattr(func, "__call__")
        if call_type is not None:
            return self.resolve_function_type(call_type, args, kws)

    if not isinstance(func, types.Callable):
        return None
    # XXX fold this into the __call__ attribute logic?
    return func.get_call_type(self, args, kws)
|
|
250
|
+
|
|
251
|
+
def _get_attribute_templates(self, typ):
    """
    Yield the AttributeTemplates registered for the Numba type *typ*.

    Templates keyed by *typ* itself win outright. Only when *typ* has no
    entry are templates keyed by the classes of ``type(typ).__mro__``
    yielded, most-derived class first.
    """
    registry = self._attributes
    if typ in registry:
        yield from registry[typ]
        return
    for klass in type(typ).__mro__:
        if klass in registry:
            yield from registry[klass]
|
|
263
|
+
|
|
264
|
+
def resolve_getattr(self, typ, attr):
    """
    Resolve getting the attribute *attr* (a string) on the Numba type.
    The attribute's type is returned, or None if resolution failed.

    Attribute templates are tried for *typ*, then for its non-literal
    form; for module types, module-level constants are tried last.
    """

    def lookup(ty):
        found = self.find_matching_getattr_template(ty, attr)
        return found["return_type"] if found else None

    attr_type = lookup(typ)
    if attr_type is None:
        # Try again without literals
        attr_type = lookup(types.unliteral(typ))
    if attr_type is not None:
        return attr_type

    if isinstance(typ, types.Module):
        return self.resolve_module_constants(typ, attr)
    return None
|
|
288
|
+
|
|
289
|
+
def find_matching_getattr_template(self, typ, attr):
    """
    Find the first attribute template able to resolve *attr* on *typ*.

    Candidates are tried in the order produced by
    ``order_by_target_specificity`` for the current target. Returns a dict
    with keys ``"template"`` and ``"return_type"``, or None if no template
    resolves the attribute.
    """
    candidates = list(self._get_attribute_templates(typ))

    # Imported locally to avoid a circular import.
    from numba.core.target_extension import get_local_target

    ordered = order_by_target_specificity(
        get_local_target(self), candidates, fnkey=attr
    )
    for candidate in ordered:
        resolved = candidate.resolve(typ, attr)
        if resolved is not None:
            return {"template": candidate, "return_type": resolved}
    return None
|
|
305
|
+
|
|
306
|
+
def resolve_setattr(self, target, attr, value):
    """
    Resolve setting the attribute *attr* (a string) on the *target* type
    to the given *value* type.

    Returns ``signature(void, target, expected_type)`` for the first
    attribute template that resolves *attr*, or None if none does.
    Convertibility from *value* to the expected type is left to the caller.
    """
    resolved_types = (
        template.resolve(target, attr)
        for template in self._get_attribute_templates(target)
    )
    expected = next((ty for ty in resolved_types if ty is not None), None)
    if expected is None:
        return None
    return templates.signature(types.void, target, expected)
|
|
318
|
+
|
|
319
|
+
def resolve_static_getitem(self, value, index):
    """
    Resolve ``value[index]`` where *index* is a compile-time constant
    (a Python value, not a Numba type).

    Returns the signature found for the ``"static_getitem"`` function key,
    or None.
    """
    assert not isinstance(index, types.Type), index
    args = value, index
    # Keyword arguments are passed as an empty dict, like the other
    # resolve_* helpers (static_setitem, setitem, delitem). This method
    # previously passed an empty tuple, which lacks the mapping interface
    # (e.g. ``.items()``) that callees may rely on.
    kws = {}
    return self.resolve_function_type("static_getitem", args, kws)
|
|
324
|
+
|
|
325
|
+
def resolve_static_setitem(self, target, index, value):
    """
    Resolve ``target[index] = value`` where *index* is a compile-time
    constant rather than a Numba type. Returns a signature or None.
    """
    assert not isinstance(index, types.Type), index
    return self.resolve_function_type(
        "static_setitem", (target, index, value), {}
    )
|
|
330
|
+
|
|
331
|
+
def resolve_setitem(self, target, index, value):
    """
    Resolve ``target[index] = value`` for a typed *index* by typing a call
    to ``operator.setitem``. Returns the resulting signature.
    """
    assert isinstance(index, types.Type), index
    setitem_type = self.resolve_value_type(operator.setitem)
    return setitem_type.get_call_type(self, (target, index, value), {})
|
|
336
|
+
|
|
337
|
+
def resolve_delitem(self, target, index):
    """
    Resolve ``del target[index]`` by typing a call to ``operator.delitem``.
    Returns the resulting signature.
    """
    delitem_type = self.resolve_value_type(operator.delitem)
    return delitem_type.get_call_type(self, (target, index), {})
|
|
343
|
+
|
|
344
|
+
def resolve_module_constants(self, typ, attr):
    """
    Resolve module-level global constants.
    Return None or the attribute type.

    None is returned both when the module has no attribute *attr* and
    when the attribute's value has no Numba type (``resolve_value_type``
    raised ValueError).
    """
    assert isinstance(typ, types.Module)
    # Look the attribute up with a sentinel default: a missing attribute
    # used to escape as AttributeError, contrary to the contract above.
    # Returning None lets resolve_getattr() report the attribute as
    # unresolved instead.
    missing = object()
    attrval = getattr(typ.pymod, attr, missing)
    if attrval is missing:
        return None
    try:
        return self.resolve_value_type(attrval)
    except ValueError:
        return None
|
|
355
|
+
|
|
356
|
+
def resolve_value_type(self, val):
    """
    Return the numba type of a Python value that is being used
    as a runtime constant.

    Tries ``typeof`` (constant purpose) first, then accepts an
    ``ExternalFunction`` as its own type, then consults target-specific
    global typing. If all fail, the ValueError from ``typeof`` (after
    ``utils.erase_traceback``) is raised.
    """
    try:
        return typeof(val, Purpose.constant)
    except ValueError as exc:
        # Make sure the exception doesn't hold a reference to the user
        # value.
        typeof_exc = utils.erase_traceback(exc)

    if isinstance(val, types.ExternalFunction):
        return val

    # Try to look up target specific typing information
    global_type = self._get_global_type(val)
    if global_type is not None:
        return global_type
    raise typeof_exc
|
|
380
|
+
|
|
381
|
+
def resolve_value_type_prefer_literal(self, value):
    """Resolve the type of *value*, returning a Literal type whenever
    ``types.maybe_literal`` can build one."""
    literal_type = types.maybe_literal(value)
    if literal_type is not None:
        return literal_type
    return self.resolve_value_type(value)
|
|
388
|
+
|
|
389
|
+
def _get_global_type(self, gv):
    """
    Return the type registered for global value *gv*. Failing that, return
    a ``types.Module`` wrapper for Python modules, and otherwise None.
    """
    registered = self._lookup_global(gv)
    if registered is None and isinstance(gv, pytypes.ModuleType):
        return types.Module(gv)
    return registered
|
|
395
|
+
|
|
396
|
+
def _load_builtins(self):
    """Import the core typing declaration modules and install the builtin
    template registry into this context."""
    # Initialize declarations: these modules are imported only for their
    # import-time side effects. The import order is kept as-is.
    from numba.core.typing import (  # noqa: F401
        builtins,
        arraydecl,
        npdatetime,
        ctypes_utils,
        bufproto,
    )
    from numba.core.unsafe import eh  # noqa: F401

    self.install_registry(templates.builtin_registry)
|
|
403
|
+
|
|
404
|
+
def load_additional_registries(self):
    """
    Hook for subclasses to install target-specific registries; called from
    ``refresh()``. The base implementation installs nothing.
    """
    return None
|
|
408
|
+
|
|
409
|
+
def install_registry(self, registry):
    """
    Install a *registry* (a templates.Registry instance) of function,
    attribute and global declarations.

    One RegistryLoader per registry is cached in ``self._registries`` and
    reused. Function and attribute templates whose ``metadata["target"]``
    names a nonexistent target, or one the current target does not
    inherit from, are skipped.

    Raises TypeError if a global is already registered with a type that
    cannot be augmented with the newly registered type.
    """
    loader = self._registries.get(registry)
    if loader is None:
        loader = templates.RegistryLoader(registry)
        self._registries[registry] = loader

    from numba.core.target_extension import (
        get_local_target,
        resolve_target_str,
    )

    current_target = get_local_target(self)

    def is_for_this_target(ftcls):
        metadata = getattr(ftcls, "metadata", None)
        target_str = None if metadata is None else metadata.get("target")
        if target_str is None:
            # No target restriction.
            return True
        # There may be pending registrations for nonexistent targets.
        # Rejecting them eagerly at registration time would require all
        # target registration to run before any typing registration during
        # import, a constraint that is very hard to maintain for Numba and
        # its dependents. Such registrations are instead reported as "not
        # for this target". The loader's new registrations are a stream
        # that never re-yields items, so they are not encountered again on
        # later refreshes.
        try:
            ft_target = resolve_target_str(target_str)
        except errors.NonexistentTargetError:
            return False
        return current_target.inherits_from(ft_target)

    template_sinks = (
        ("functions", self.insert_function),
        ("attributes", self.insert_attributes),
    )
    for kind, insert in template_sinks:
        for ftcls in loader.new_registrations(kind):
            if is_for_this_target(ftcls):
                insert(ftcls(self))

    for gv, gty in loader.new_registrations("globals"):
        existing = self._lookup_global(gv)
        if existing is None:
            self.insert_global(gv, gty)
            continue
        # A type was already inserted, see if we can add to it
        merged = existing.augment(gty)
        if merged is None:
            raise TypeError("cannot augment %s with %s" % (existing, gty))
        # Remove the old entry before re-inserting (presumably because
        # _globals, a utils.UniqueDict, will not overwrite a key).
        self._remove_global(gv)
        self._insert_global(gv, merged)
|
|
483
|
+
|
|
484
|
+
def _lookup_global(self, gv):
    """
    Look up the registered type for global value *gv*.

    Entries are keyed by a weak reference to the value where possible, so
    the lookup key is built the same way. Returns None if *gv* is not
    registered or its key is unhashable.
    """
    try:
        key = weakref.ref(gv)
    except TypeError:
        # Not weak-referenceable: use the value itself as the key.
        key = gv
    try:
        return self._globals.get(key, None)
    except TypeError:
        # Unhashable type
        return None
|
|
497
|
+
|
|
498
|
+
def _insert_global(self, gv, gty):
    """
    Register type *gty* for value *gv*. Only a weak reference
    to *gv* is kept, if possible; such an entry is removed by a weakref
    callback when *gv* is collected.
    """

    # pop() is bound when the callback is defined, to avoid a crash late
    # at shutdown on 3.5 (https://bugs.python.org/issue25217).
    def on_disposal(wr, pop=self._globals.pop):
        pop(wr)

    try:
        key = weakref.ref(gv, on_disposal)
    except TypeError:
        # Not weak-referenceable: keep a strong reference as the key.
        key = gv
    self._globals[key] = gty
|
|
514
|
+
|
|
515
|
+
def _remove_global(self, gv):
    """
    Remove the registered type for global value *gv*.

    The key is built exactly as in _lookup_global(); KeyError propagates if
    *gv* is not registered.
    """
    try:
        key = weakref.ref(gv)
    except TypeError:
        key = gv
    del self._globals[key]
|
|
524
|
+
|
|
525
|
+
def insert_global(self, gv, gty):
    """Register type *gty* for global value *gv* (public wrapper around
    ``_insert_global``)."""
    self._insert_global(gv, gty)
|
|
527
|
+
|
|
528
|
+
def insert_attributes(self, at):
    """Append attribute template *at* to the templates kept under its
    ``key``."""
    self._attributes[at.key].append(at)
|
|
531
|
+
|
|
532
|
+
def insert_function(self, ft):
    """Append function template *ft* to the templates kept under its
    ``key``."""
    self._functions[ft.key].append(ft)
|
|
535
|
+
|
|
536
|
+
def insert_user_function(self, fn, ft):
    """Insert a user function.

    Registers ``types.Function(ft)`` as the global type of *fn*.

    Args
    ----
    - fn:
        object used as callee
    - ft:
        function template
    """
    function_type = types.Function(ft)
    self._insert_global(fn, function_type)
|
|
547
|
+
|
|
548
|
+
def can_convert(self, fromty, toty):
    """
    Check whether conversion is possible from *fromty* to *toty*.
    If successful, return a numba.typeconv.Conversion instance;
    otherwise None is returned.

    Equal types convert exactly. Otherwise the type manager's registered
    rules are consulted first, then the types' own ``can_convert_to`` /
    ``can_convert_from`` hooks; when both hooks answer, the ``min`` of the
    two results is returned.
    """
    if fromty == toty:
        return Conversion.exact

    # Rules registered with the type manager (see numba.typeconv.rules).
    known = self.tm.check_compatible(fromty, toty)
    if known is not None:
        return known

    # Fall back on type-specific rules
    forward = fromty.can_convert_to(self, toty)
    backward = toty.can_convert_from(self, fromty)
    if forward is None or backward is None:
        return backward if forward is None else forward
    return min(forward, backward)
|
|
572
|
+
|
|
573
|
+
def _rate_arguments(
    self,
    actualargs,
    formalargs,
    unsafe_casting=True,
    exact_match_required=False,
):
    """
    Score how well *actualargs* fit *formalargs*.

    Every actual/formal pair goes through ``self.can_convert`` and the kind
    of conversion needed is tallied on a ``Rating``.  Returns that Rating,
    or None when the argument counts differ, a pair cannot be converted, an
    unsafe conversion is needed while *unsafe_casting* is False, or a
    non-exact conversion is needed while *exact_match_required* is True.
    """
    if len(actualargs) != len(formalargs):
        return None

    rating = Rating()
    for actual, formal in zip(actualargs, formalargs):
        conv = self.can_convert(actual, formal)
        rejected = (
            conv is None
            or (not unsafe_casting and conv >= Conversion.unsafe)
            or (exact_match_required and conv != Conversion.exact)
        )
        if rejected:
            return None

        if conv == Conversion.promote:
            rating.promote += 1
        elif conv == Conversion.safe:
            rating.safe_convert += 1
        elif conv == Conversion.unsafe:
            rating.unsafe_convert += 1
        elif conv == Conversion.exact:
            pass  # exact matches add nothing to the score
        else:
            raise AssertionError("unreachable", conv)

    return rating
|
|
608
|
+
|
|
609
|
+
def install_possible_conversions(self, actualargs, formalargs):
    """
    Record in the (C++) type manager how each actual argument type
    converts to its formal argument type.

    Pairs the type manager already knows about are left alone; for the
    others the conversion found by ``self.can_convert`` is stored with
    ``set_compatible``.  Returns True when every pair is convertible and
    False on an arity mismatch or at the first inconvertible pair (pairs
    processed before it stay installed).
    """
    if len(actualargs) != len(formalargs):
        return False

    tm = self.tm
    for actual, formal in zip(actualargs, formalargs):
        if tm.check_compatible(actual, formal) is not None:
            continue  # nothing new to teach the type manager
        conv = self.can_convert(actual, formal)
        if conv is None:
            return False
        # Exact conversions are not expected here (presumably the type
        # manager already answers for identical types).
        assert conv is not Conversion.exact
        tm.set_compatible(actual, formal, conv)
    return True
|
|
627
|
+
|
|
628
|
+
def resolve_overload(
    self,
    key,
    cases,
    args,
    kws,
    allow_ambiguous=True,
    unsafe_casting=True,
    exact_match_required=False,
):
    """
    Choose the signature in *cases* that best fits the argument types *args*.

    Cases with a different number of arguments are skipped.  The rest are
    scored with ``self._rate_arguments`` (honouring *unsafe_casting* and
    *exact_match_required*) and ranked by ``Rating.astuple()``, lowest
    first.  Returns the winning case, or None if no case matches.

    *key* is only used in the error message.  *kws* must be empty
    (keyword arguments are not supported yet; checked with ``assert``).
    When *allow_ambiguous* is False and several cases share the best
    score, TypeError is raised listing them.
    """
    assert not kws, "Keyword arguments are not supported, yet"

    scored = []
    for case in cases:
        if len(args) != len(case.args):
            continue
        rating = self._rate_arguments(
            args,
            case.args,
            unsafe_casting=unsafe_casting,
            exact_match_required=exact_match_required,
        )
        if rating is not None:
            scored.append((rating.astuple(), case))

    if not scored:
        return None

    # list.sort() is stable: among equally rated cases the one that came
    # first in *cases* stays first.  That is what resolves e.g. calling an
    # (int32, int32) / (int64, int64) template with int16 arguments.
    scored.sort(key=lambda item: item[0])
    best_rate, best = scored[0]

    if not allow_ambiguous:
        # Equal ratings sit at the front after sorting; count that prefix.
        n_tied = 1
        while n_tied < len(scored) and not (scored[n_tied][0] != best_rate):
            n_tied += 1
        if n_tied > 1:
            listing = "\n".join(str(case) for _, case in scored[:n_tied])
            msg = "Ambiguous overloading for %s %s:\n%s" % (key, args, listing)
            raise TypeError(msg)

    return best
|
|
681
|
+
|
|
682
|
+
def unify_types(self, *typelist):
    """
    Fold all of *typelist* into one common type via ``unify_pairs``.

    The types are first ordered by their ``bitwidth`` attribute (0 when the
    attribute is missing).  ``sorted`` is stable, so ties keep their input
    order and the result is deterministic.  Sorting by bit width before
    unifying pairwise is an idea credited to aterrel.

    Returns the unified type, or None as soon as one pairwise unification
    fails.  An empty *typelist* raises IndexError.
    """
    ordered = sorted(typelist, key=lambda ty: getattr(ty, "bitwidth", 0))
    unified, remaining = ordered[0], ordered[1:]
    for ty in remaining:
        unified = self.unify_pairs(unified, ty)
        if unified is None:
            return None
    return unified
|
|
698
|
+
|
|
699
|
+
def unify_pairs(self, first, second):
    """
    Find a single type that both *first* and *second* fit into.

    Returns the unified type, or None when unification fails.  Strategies,
    in order:

    1. equal types unify to *first*; ``types.undefined`` yields the other;
    2. each type's own ``unify`` hook (first's, then second's);
    3. a conversion no worse than ``Conversion.safe`` from one type to the
       other, returning the target (first -> second is tried first);
    4. if either side is a ``types.Literal``, retry with both unliteral-ed.
    """
    if first == second:
        return first
    if first is types.undefined:
        return second
    if second is types.undefined:
        return first

    # Types carrying special unification rules get asked in both directions.
    for this, other in ((first, second), (second, first)):
        unified = this.unify(self, other)
        if unified is not None:
            return unified

    # Otherwise settle for a sufficiently cheap one-way conversion.
    for src, dst in ((first, second), (second, first)):
        conv = self.can_convert(fromty=src, toty=dst)
        if conv is not None and conv <= Conversion.safe:
            return dst

    # Literal types may only unify once their literal value is dropped.
    if isinstance(first, types.Literal) or isinstance(second, types.Literal):
        return self.unify_pairs(types.unliteral(first), types.unliteral(second))

    return None
|
|
741
|
+
|
|
742
|
+
|
|
743
|
+
class Context(BaseContext):
    # The registries installed below are expected to grow until they
    # cover everything CUDA requires.
    def load_additional_registries(self):
        """Install the extra typing registries (currently only ``npydecl``)."""
        from . import npydecl

        registry = npydecl.registry
        self.install_registry(registry)
|