numba-cuda 0.18.1__py3-none-any.whl → 0.19.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of numba-cuda might be problematic. Click here for more details.
- numba_cuda/VERSION +1 -1
- numba_cuda/numba/cuda/__init__.py +1 -1
- numba_cuda/numba/cuda/_internal/cuda_bf16.py +2 -2
- numba_cuda/numba/cuda/_internal/cuda_fp16.py +1 -1
- numba_cuda/numba/cuda/api.py +2 -7
- numba_cuda/numba/cuda/compiler.py +7 -4
- numba_cuda/numba/cuda/core/interpreter.py +3592 -0
- numba_cuda/numba/cuda/core/ir_utils.py +2645 -0
- numba_cuda/numba/cuda/core/sigutils.py +55 -0
- numba_cuda/numba/cuda/cuda_paths.py +9 -17
- numba_cuda/numba/cuda/cudadecl.py +1 -1
- numba_cuda/numba/cuda/cudadrv/driver.py +4 -19
- numba_cuda/numba/cuda/cudadrv/libs.py +1 -2
- numba_cuda/numba/cuda/cudadrv/nvrtc.py +44 -44
- numba_cuda/numba/cuda/cudadrv/nvvm.py +3 -18
- numba_cuda/numba/cuda/cudadrv/runtime.py +12 -1
- numba_cuda/numba/cuda/cudamath.py +1 -1
- numba_cuda/numba/cuda/decorators.py +4 -3
- numba_cuda/numba/cuda/deviceufunc.py +2 -1
- numba_cuda/numba/cuda/dispatcher.py +3 -2
- numba_cuda/numba/cuda/extending.py +1 -1
- numba_cuda/numba/cuda/itanium_mangler.py +211 -0
- numba_cuda/numba/cuda/libdevicedecl.py +1 -1
- numba_cuda/numba/cuda/libdevicefuncs.py +1 -1
- numba_cuda/numba/cuda/lowering.py +1 -1
- numba_cuda/numba/cuda/simulator/api.py +1 -1
- numba_cuda/numba/cuda/simulator/cudadrv/driver.py +0 -7
- numba_cuda/numba/cuda/target.py +1 -2
- numba_cuda/numba/cuda/testing.py +4 -6
- numba_cuda/numba/cuda/tests/core/test_itanium_mangler.py +80 -0
- numba_cuda/numba/cuda/tests/cudadrv/test_cuda_ndarray.py +1 -1
- numba_cuda/numba/cuda/tests/cudadrv/test_deallocations.py +1 -1
- numba_cuda/numba/cuda/tests/cudadrv/test_detect.py +1 -1
- numba_cuda/numba/cuda/tests/cudadrv/test_emm_plugins.py +1 -1
- numba_cuda/numba/cuda/tests/cudadrv/test_linker.py +1 -1
- numba_cuda/numba/cuda/tests/cudadrv/test_managed_alloc.py +1 -1
- numba_cuda/numba/cuda/tests/cudadrv/test_mvc.py +1 -1
- numba_cuda/numba/cuda/tests/cudadrv/test_nvrtc.py +4 -6
- numba_cuda/numba/cuda/tests/cudadrv/test_nvvm_driver.py +0 -4
- numba_cuda/numba/cuda/tests/cudadrv/test_ptds.py +1 -1
- numba_cuda/numba/cuda/tests/cudapy/test_bfloat16.py +1 -3
- numba_cuda/numba/cuda/tests/cudapy/test_bfloat16_bindings.py +1 -3
- numba_cuda/numba/cuda/tests/cudapy/test_caching.py +146 -3
- numba_cuda/numba/cuda/tests/cudapy/test_cffi.py +1 -1
- numba_cuda/numba/cuda/tests/cudapy/test_compiler.py +0 -4
- numba_cuda/numba/cuda/tests/cudapy/test_cuda_array_interface.py +1 -1
- numba_cuda/numba/cuda/tests/cudapy/test_cuda_jit_no_types.py +1 -1
- numba_cuda/numba/cuda/tests/cudapy/test_debug.py +1 -1
- numba_cuda/numba/cuda/tests/cudapy/test_debuginfo.py +1 -284
- numba_cuda/numba/cuda/tests/cudapy/test_debuginfo_types.py +473 -0
- numba_cuda/numba/cuda/tests/cudapy/test_device_func.py +1 -1
- numba_cuda/numba/cuda/tests/cudapy/test_errors.py +1 -1
- numba_cuda/numba/cuda/tests/cudapy/test_extending.py +1 -6
- numba_cuda/numba/cuda/tests/cudapy/test_gufunc.py +1 -1
- numba_cuda/numba/cuda/tests/cudapy/test_ipc.py +1 -1
- numba_cuda/numba/cuda/tests/cudapy/test_ir_utils.py +295 -0
- numba_cuda/numba/cuda/tests/cudapy/test_lineinfo.py +1 -1
- numba_cuda/numba/cuda/tests/cudapy/test_operator.py +1 -1
- numba_cuda/numba/cuda/tests/cudapy/test_ufuncs.py +1 -1
- numba_cuda/numba/cuda/tests/cudapy/test_warning.py +5 -1
- numba_cuda/numba/cuda/tests/doc_examples/test_cpointer.py +1 -1
- numba_cuda/numba/cuda/tests/doc_examples/test_cpu_gpu_compat.py +1 -1
- numba_cuda/numba/cuda/tests/doc_examples/test_ffi.py +1 -1
- numba_cuda/numba/cuda/tests/doc_examples/test_laplace.py +1 -1
- numba_cuda/numba/cuda/tests/doc_examples/test_matmul.py +1 -1
- numba_cuda/numba/cuda/tests/doc_examples/test_montecarlo.py +1 -1
- numba_cuda/numba/cuda/tests/doc_examples/test_reduction.py +1 -1
- numba_cuda/numba/cuda/tests/doc_examples/test_sessionize.py +1 -1
- numba_cuda/numba/cuda/tests/doc_examples/test_ufunc.py +1 -1
- numba_cuda/numba/cuda/tests/doc_examples/test_vecadd.py +1 -1
- numba_cuda/numba/cuda/tests/nocuda/test_import.py +1 -1
- numba_cuda/numba/cuda/tests/nrt/test_nrt.py +2 -2
- numba_cuda/numba/cuda/tests/nrt/test_nrt_refct.py +1 -1
- numba_cuda/numba/cuda/tests/support.py +752 -0
- numba_cuda/numba/cuda/tests/test_binary_generation/Makefile +3 -3
- numba_cuda/numba/cuda/tests/test_binary_generation/generate_raw_ltoir.py +4 -1
- numba_cuda/numba/cuda/typing/__init__.py +8 -0
- numba_cuda/numba/cuda/typing/templates.py +1453 -0
- numba_cuda/numba/cuda/vector_types.py +3 -3
- {numba_cuda-0.18.1.dist-info → numba_cuda-0.19.0.dist-info}/METADATA +21 -28
- {numba_cuda-0.18.1.dist-info → numba_cuda-0.19.0.dist-info}/RECORD +84 -79
- numba_cuda/numba/cuda/include/11/cuda_bf16.h +0 -3749
- numba_cuda/numba/cuda/include/11/cuda_bf16.hpp +0 -2683
- numba_cuda/numba/cuda/include/11/cuda_fp16.h +0 -3794
- numba_cuda/numba/cuda/include/11/cuda_fp16.hpp +0 -2614
- {numba_cuda-0.18.1.dist-info → numba_cuda-0.19.0.dist-info}/WHEEL +0 -0
- {numba_cuda-0.18.1.dist-info → numba_cuda-0.19.0.dist-info}/licenses/LICENSE +0 -0
- {numba_cuda-0.18.1.dist-info → numba_cuda-0.19.0.dist-info}/top_level.txt +0 -0
|
@@ -1,4 +1,34 @@
|
|
|
1
|
+
import cmath
|
|
2
|
+
import contextlib
|
|
3
|
+
import enum
|
|
4
|
+
import gc
|
|
5
|
+
import math
|
|
6
|
+
import unittest
|
|
7
|
+
import os
|
|
8
|
+
import io
|
|
9
|
+
import subprocess
|
|
10
|
+
import sys
|
|
11
|
+
import shutil
|
|
12
|
+
import warnings
|
|
13
|
+
import tempfile
|
|
14
|
+
import time
|
|
15
|
+
import types as pytypes
|
|
16
|
+
from functools import cached_property
|
|
17
|
+
|
|
18
|
+
import numpy as np
|
|
19
|
+
|
|
20
|
+
from numba import types
|
|
21
|
+
from numba.core import errors, config
|
|
22
|
+
from numba.core.typing import cffi_utils
|
|
1
23
|
from numba.cuda.memory_management.nrt import rtsys
|
|
24
|
+
from numba.core.extending import (
|
|
25
|
+
typeof_impl,
|
|
26
|
+
register_model,
|
|
27
|
+
unbox,
|
|
28
|
+
NativeValue,
|
|
29
|
+
)
|
|
30
|
+
from numba.core.datamodel.models import OpaqueModel
|
|
31
|
+
from numba.np import numpy_support
|
|
2
32
|
|
|
3
33
|
|
|
4
34
|
class EnableNRTStatsMixin(object):
|
|
@@ -9,3 +39,725 @@ class EnableNRTStatsMixin(object):
|
|
|
9
39
|
|
|
10
40
|
def tearDown(self):
|
|
11
41
|
rtsys.memsys_disable_stats()
|
|
42
|
+
|
|
43
|
+
|
|
44
|
+
# Skip decorators shared by the test suite.

# Skips the decorated test unless cffi_utils reports cffi as supported.
skip_unless_cffi = unittest.skipUnless(cffi_utils.SUPPORTED, "requires cffi")

# Platform gates, keyed on the prefix of sys.platform.
_lnx_reason = "linux only test"
linux_only = unittest.skipUnless(sys.platform.startswith("linux"), _lnx_reason)

_win_reason = "Windows only test"
windows_only = unittest.skipUnless(sys.platform.startswith("win"), _win_reason)

# True when the installed NumPy is version 2.0 or newer.
IS_NUMPY_2 = numpy_support.numpy_version >= (2, 0)
skip_if_numpy_2 = unittest.skipIf(IS_NUMPY_2, "Not supported on numpy 2.0+")
|
|
54
|
+
|
|
55
|
+
# Location of the shared "trashcan" directory that holds the temporary
# directories created by temp_directory().
_trashcan_dir = "numba-cuda-tests"

if os.name != "nt":
    # Mix the UID into the directory name to allow different users to
    # run the test suite without permission errors (issue #1586).
    # Under Windows no suffix is needed: gettempdir() already points to the
    # user-local temp dir.
    _trashcan_dir = "%s.%s" % (_trashcan_dir, os.getuid())

_trashcan_dir = os.path.join(tempfile.gettempdir(), _trashcan_dir)

# Stale temporary directories are deleted after they are older than this value.
# The test suite probably won't ever take longer than this...
_trashcan_timeout = 24 * 3600  # 1 day
|
|
70
|
+
|
|
71
|
+
|
|
72
|
+
def _create_trashcan_dir():
    """Create the trashcan directory; an already existing one is accepted."""
    with contextlib.suppress(FileExistsError):
        os.mkdir(_trashcan_dir)
|
|
77
|
+
|
|
78
|
+
|
|
79
|
+
def _purge_trashcan_dir():
    """
    Remove every entry of the trashcan directory whose modification time is
    more than ``_trashcan_timeout`` seconds in the past.
    """
    cutoff = time.time() - _trashcan_timeout
    for entry in sorted(os.listdir(_trashcan_dir)):
        path = os.path.join(_trashcan_dir, entry)
        try:
            is_stale = os.stat(path).st_mtime < cutoff
            if is_stale:
                shutil.rmtree(path, ignore_errors=True)
        except OSError:
            # In parallel testing, several processes can attempt to
            # remove the same entry at once, ignore.
            continue
|
|
91
|
+
|
|
92
|
+
|
|
93
|
+
def _create_trashcan_subdir(prefix):
    """
    Purge stale entries from the trashcan directory, then create and return
    the path of a fresh sub-directory whose name starts with ``prefix-``.
    """
    _purge_trashcan_dir()
    return tempfile.mkdtemp(prefix=prefix + "-", dir=_trashcan_dir)
|
|
97
|
+
|
|
98
|
+
|
|
99
|
+
def temp_directory(prefix):
    """
    Return the path of a new temporary directory named after *prefix*.

    The directory is not removed when the process exits; it lives inside the
    shared trashcan directory and is only deleted by a later purge once it
    has become stale enough.

    This is necessary because a DLL file can't be deleted while in use
    under Windows.

    As a side-effect, test files can still be inspected shortly after a
    test suite run.
    """
    _create_trashcan_dir()
    subdir = _create_trashcan_subdir(prefix)
    return subdir
|
|
113
|
+
|
|
114
|
+
|
|
115
|
+
def import_dynamic(modname):
    """
    Import the module called *modname* and return the module object.

    The import system's finder caches are invalidated first, so that a
    module file created after interpreter start-up can still be found.
    """
    from importlib import invalidate_caches

    invalidate_caches()
    __import__(modname)
    module = sys.modules[modname]
    return module
|
|
125
|
+
|
|
126
|
+
|
|
127
|
+
def ignore_internal_warnings():
    """Install filters for warnings that are unrelated to the code under test
    or internally generated by Numba.

    Intended for use in testing, inside a `warnings.catch_warnings` block.
    """
    # Filter out warnings from typeguard
    warnings.filterwarnings(action="ignore", module="typeguard")

    # Filter out warnings about TBB interface mismatch
    tbb_filter = dict(
        action="ignore",
        message=r".*TBB_INTERFACE_VERSION.*",
        category=errors.NumbaWarning,
        module=r"numba\.np\.ufunc\.parallel.*",
    )
    warnings.filterwarnings(**tbb_filter)
|
|
140
|
+
|
|
141
|
+
|
|
142
|
+
@contextlib.contextmanager
def override_config(name, value):
    """
    Context manager that sets the Numba config variable *name* to *value*
    for the duration of the ``with`` block and restores the previous value
    on exit, even if the block raises.

    *name* must already exist as an attribute of ``config``; otherwise the
    initial lookup raises AttributeError.
    """
    previous = getattr(config, name)
    setattr(config, name, value)
    try:
        yield
    finally:
        # Always put the saved value back.
        setattr(config, name, previous)
|
|
155
|
+
|
|
156
|
+
|
|
157
|
+
def run_in_subprocess(code, flags=None, env=None, timeout=30):
    """Run a snippet of Python code in a subprocess.

    The child is started as ``sys.executable <flags...> -c <code>``.

    Parameters
    ----------
    code : str
        Source passed to the interpreter's ``-c`` option.
    flags : iterable of str, optional
        Extra interpreter flags placed before ``-c``. ``None`` means none.
    env : mapping, optional
        Environment for the child, passed straight to ``subprocess``.
    timeout : float, optional
        Seconds to wait for the child to finish.

    Returns
    -------
    (bytes, bytes)
        The stdout and stderr of the subprocess after its termination.

    Raises
    ------
    AssertionError
        If the child exits with a non-zero return code; the message carries
        the decoded stderr.
    subprocess.TimeoutExpired
        If the child does not finish within *timeout* seconds.
    """
    cmd = [sys.executable, *(flags or ()), "-c", code]
    # subprocess.run() kills the child and drains its pipes when the timeout
    # expires, before re-raising TimeoutExpired. A bare
    # Popen.communicate(timeout=...) would leave the process running.
    status = subprocess.run(
        cmd,
        stdout=subprocess.PIPE,
        stderr=subprocess.PIPE,
        env=env,
        timeout=timeout,
    )
    if status.returncode != 0:
        msg = "process failed with code %s: stderr follows\n%s\n"
        # errors="replace": undecodable bytes on stderr must not mask the
        # actual failure with a UnicodeDecodeError.
        raise AssertionError(
            msg % (status.returncode, status.stderr.decode(errors="replace"))
        )
    return status.stdout, status.stderr
|
|
181
|
+
|
|
182
|
+
|
|
183
|
+
@contextlib.contextmanager
def captured_output(stream_name):
    """Context manager that swaps the ``sys`` attribute named *stream_name*
    for a fresh ``io.StringIO`` and yields that buffer.

    The original stream object is put back when the block exits, whether or
    not it raised. Used by captured_stdout() and captured_stderr()."""
    saved_stream = getattr(sys, stream_name)
    buffer = io.StringIO()
    setattr(sys, stream_name, buffer)
    try:
        yield buffer
    finally:
        setattr(sys, stream_name, saved_stream)
|
|
193
|
+
|
|
194
|
+
|
|
195
|
+
def captured_stdout():
    """Return a context manager that captures what is written to sys.stdout.

    Example::

        with captured_stdout() as stdout:
            print("hello")
        self.assertEqual(stdout.getvalue(), "hello\n")
    """
    stream_name = "stdout"
    return captured_output(stream_name)
|
|
203
|
+
|
|
204
|
+
|
|
205
|
+
def captured_stderr():
    """Return a context manager that captures what is written to sys.stderr.

    Example::

        with captured_stderr() as stderr:
            print("hello", file=sys.stderr)
        self.assertEqual(stderr.getvalue(), "hello\n")
    """
    stream_name = "stderr"
    return captured_output(stream_name)
|
|
213
|
+
|
|
214
|
+
|
|
215
|
+
class TestCase(unittest.TestCase):
|
|
216
|
+
longMessage = True
|
|
217
|
+
|
|
218
|
+
# A random state yielding the same random numbers for any test case.
|
|
219
|
+
# Use as `self.random.<method name>`
|
|
220
|
+
@cached_property
|
|
221
|
+
def random(self):
|
|
222
|
+
return np.random.RandomState(42)
|
|
223
|
+
|
|
224
|
+
def reset_module_warnings(self, module):
|
|
225
|
+
"""
|
|
226
|
+
Reset the warnings registry of a module. This can be necessary
|
|
227
|
+
as the warnings module is buggy in that regard.
|
|
228
|
+
See http://bugs.python.org/issue4180
|
|
229
|
+
"""
|
|
230
|
+
if isinstance(module, str):
|
|
231
|
+
module = sys.modules[module]
|
|
232
|
+
try:
|
|
233
|
+
del module.__warningregistry__
|
|
234
|
+
except AttributeError:
|
|
235
|
+
pass
|
|
236
|
+
|
|
237
|
+
@contextlib.contextmanager
|
|
238
|
+
def assertTypingError(self):
|
|
239
|
+
"""
|
|
240
|
+
A context manager that asserts the enclosed code block fails
|
|
241
|
+
compiling in nopython mode.
|
|
242
|
+
"""
|
|
243
|
+
_accepted_errors = (
|
|
244
|
+
errors.LoweringError,
|
|
245
|
+
errors.TypingError,
|
|
246
|
+
TypeError,
|
|
247
|
+
NotImplementedError,
|
|
248
|
+
)
|
|
249
|
+
with self.assertRaises(_accepted_errors) as cm:
|
|
250
|
+
yield cm
|
|
251
|
+
|
|
252
|
+
@contextlib.contextmanager
|
|
253
|
+
def assertRefCount(self, *objects):
|
|
254
|
+
"""
|
|
255
|
+
A context manager that asserts the given objects have the
|
|
256
|
+
same reference counts before and after executing the
|
|
257
|
+
enclosed block.
|
|
258
|
+
"""
|
|
259
|
+
old_refcounts = [sys.getrefcount(x) for x in objects]
|
|
260
|
+
yield
|
|
261
|
+
gc.collect()
|
|
262
|
+
new_refcounts = [sys.getrefcount(x) for x in objects]
|
|
263
|
+
for old, new, obj in zip(old_refcounts, new_refcounts, objects):
|
|
264
|
+
if old != new:
|
|
265
|
+
self.fail(
|
|
266
|
+
"Refcount changed from %d to %d for object: %r"
|
|
267
|
+
% (old, new, obj)
|
|
268
|
+
)
|
|
269
|
+
|
|
270
|
+
def assertRefCountEqual(self, *objects):
|
|
271
|
+
gc.collect()
|
|
272
|
+
rc = [sys.getrefcount(x) for x in objects]
|
|
273
|
+
rc_0 = rc[0]
|
|
274
|
+
for i in range(len(objects))[1:]:
|
|
275
|
+
rc_i = rc[i]
|
|
276
|
+
if rc_0 != rc_i:
|
|
277
|
+
self.fail(
|
|
278
|
+
f"Refcount for objects does not match. "
|
|
279
|
+
f"#0({rc_0}) != #{i}({rc_i}) does not match."
|
|
280
|
+
)
|
|
281
|
+
|
|
282
|
+
@contextlib.contextmanager
|
|
283
|
+
def assertNoNRTLeak(self):
|
|
284
|
+
"""
|
|
285
|
+
A context manager that asserts no NRT leak was created during
|
|
286
|
+
the execution of the enclosed block.
|
|
287
|
+
"""
|
|
288
|
+
old = rtsys.get_allocation_stats()
|
|
289
|
+
yield
|
|
290
|
+
new = rtsys.get_allocation_stats()
|
|
291
|
+
total_alloc = new.alloc - old.alloc
|
|
292
|
+
total_free = new.free - old.free
|
|
293
|
+
total_mi_alloc = new.mi_alloc - old.mi_alloc
|
|
294
|
+
total_mi_free = new.mi_free - old.mi_free
|
|
295
|
+
self.assertEqual(
|
|
296
|
+
total_alloc,
|
|
297
|
+
total_free,
|
|
298
|
+
"number of data allocs != number of data frees",
|
|
299
|
+
)
|
|
300
|
+
self.assertEqual(
|
|
301
|
+
total_mi_alloc,
|
|
302
|
+
total_mi_free,
|
|
303
|
+
"number of meminfo allocs != number of meminfo frees",
|
|
304
|
+
)
|
|
305
|
+
|
|
306
|
+
_bool_types = (bool, np.bool_)
|
|
307
|
+
_exact_typesets = [
|
|
308
|
+
_bool_types,
|
|
309
|
+
(int,),
|
|
310
|
+
(str,),
|
|
311
|
+
(np.integer,),
|
|
312
|
+
(bytes, np.bytes_),
|
|
313
|
+
]
|
|
314
|
+
_approx_typesets = [(float,), (complex,), (np.inexact)]
|
|
315
|
+
_sequence_typesets = [(tuple, list)]
|
|
316
|
+
_float_types = (float, np.floating)
|
|
317
|
+
_complex_types = (complex, np.complexfloating)
|
|
318
|
+
|
|
319
|
+
def _detect_family(self, numeric_object):
|
|
320
|
+
"""
|
|
321
|
+
This function returns a string description of the type family
|
|
322
|
+
that the object in question belongs to. Possible return values
|
|
323
|
+
are: "exact", "complex", "approximate", "sequence", and "unknown"
|
|
324
|
+
"""
|
|
325
|
+
if isinstance(numeric_object, np.ndarray):
|
|
326
|
+
return "ndarray"
|
|
327
|
+
|
|
328
|
+
if isinstance(numeric_object, enum.Enum):
|
|
329
|
+
return "enum"
|
|
330
|
+
|
|
331
|
+
for tp in self._sequence_typesets:
|
|
332
|
+
if isinstance(numeric_object, tp):
|
|
333
|
+
return "sequence"
|
|
334
|
+
|
|
335
|
+
for tp in self._exact_typesets:
|
|
336
|
+
if isinstance(numeric_object, tp):
|
|
337
|
+
return "exact"
|
|
338
|
+
|
|
339
|
+
for tp in self._complex_types:
|
|
340
|
+
if isinstance(numeric_object, tp):
|
|
341
|
+
return "complex"
|
|
342
|
+
|
|
343
|
+
for tp in self._approx_typesets:
|
|
344
|
+
if isinstance(numeric_object, tp):
|
|
345
|
+
return "approximate"
|
|
346
|
+
|
|
347
|
+
return "unknown"
|
|
348
|
+
|
|
349
|
+
def _fix_dtype(self, dtype):
|
|
350
|
+
"""
|
|
351
|
+
Fix the given *dtype* for comparison.
|
|
352
|
+
"""
|
|
353
|
+
# Under 64-bit Windows, Numpy may return either int32 or int64
|
|
354
|
+
# arrays depending on the function.
|
|
355
|
+
if (
|
|
356
|
+
sys.platform == "win32"
|
|
357
|
+
and sys.maxsize > 2**32
|
|
358
|
+
and dtype == np.dtype("int32")
|
|
359
|
+
):
|
|
360
|
+
return np.dtype("int64")
|
|
361
|
+
else:
|
|
362
|
+
return dtype
|
|
363
|
+
|
|
364
|
+
def _fix_strides(self, arr):
|
|
365
|
+
"""
|
|
366
|
+
Return the strides of the given array, fixed for comparison.
|
|
367
|
+
Strides for 0- or 1-sized dimensions are ignored.
|
|
368
|
+
"""
|
|
369
|
+
if arr.size == 0:
|
|
370
|
+
return [0] * arr.ndim
|
|
371
|
+
else:
|
|
372
|
+
return [
|
|
373
|
+
stride / arr.itemsize
|
|
374
|
+
for (stride, shape) in zip(arr.strides, arr.shape)
|
|
375
|
+
if shape > 1
|
|
376
|
+
]
|
|
377
|
+
|
|
378
|
+
def assertStridesEqual(self, first, second):
|
|
379
|
+
"""
|
|
380
|
+
Test that two arrays have the same shape and strides.
|
|
381
|
+
"""
|
|
382
|
+
self.assertEqual(first.shape, second.shape, "shapes differ")
|
|
383
|
+
self.assertEqual(first.itemsize, second.itemsize, "itemsizes differ")
|
|
384
|
+
self.assertEqual(
|
|
385
|
+
self._fix_strides(first),
|
|
386
|
+
self._fix_strides(second),
|
|
387
|
+
"strides differ",
|
|
388
|
+
)
|
|
389
|
+
|
|
390
|
+
def assertPreciseEqual(
|
|
391
|
+
self,
|
|
392
|
+
first,
|
|
393
|
+
second,
|
|
394
|
+
prec="exact",
|
|
395
|
+
ulps=1,
|
|
396
|
+
msg=None,
|
|
397
|
+
ignore_sign_on_zero=False,
|
|
398
|
+
abs_tol=None,
|
|
399
|
+
):
|
|
400
|
+
"""
|
|
401
|
+
Versatile equality testing function with more built-in checks than
|
|
402
|
+
standard assertEqual().
|
|
403
|
+
|
|
404
|
+
For arrays, test that layout, dtype, shape are identical, and
|
|
405
|
+
recursively call assertPreciseEqual() on the contents.
|
|
406
|
+
|
|
407
|
+
For other sequences, recursively call assertPreciseEqual() on
|
|
408
|
+
the contents.
|
|
409
|
+
|
|
410
|
+
For scalars, test that two scalars or have similar types and are
|
|
411
|
+
equal up to a computed precision.
|
|
412
|
+
If the scalars are instances of exact types or if *prec* is
|
|
413
|
+
'exact', they are compared exactly.
|
|
414
|
+
If the scalars are instances of inexact types (float, complex)
|
|
415
|
+
and *prec* is not 'exact', then the number of significant bits
|
|
416
|
+
is computed according to the value of *prec*: 53 bits if *prec*
|
|
417
|
+
is 'double', 24 bits if *prec* is single. This number of bits
|
|
418
|
+
can be lowered by raising the *ulps* value.
|
|
419
|
+
ignore_sign_on_zero can be set to True if zeros are to be considered
|
|
420
|
+
equal regardless of their sign bit.
|
|
421
|
+
abs_tol if this is set to a float value its value is used in the
|
|
422
|
+
following. If, however, this is set to the string "eps" then machine
|
|
423
|
+
precision of the type(first) is used in the following instead. This
|
|
424
|
+
kwarg is used to check if the absolute difference in value between first
|
|
425
|
+
and second is less than the value set, if so the numbers being compared
|
|
426
|
+
are considered equal. (This is to handle small numbers typically of
|
|
427
|
+
magnitude less than machine precision).
|
|
428
|
+
|
|
429
|
+
Any value of *prec* other than 'exact', 'single' or 'double'
|
|
430
|
+
will raise an error.
|
|
431
|
+
"""
|
|
432
|
+
try:
|
|
433
|
+
self._assertPreciseEqual(
|
|
434
|
+
first, second, prec, ulps, msg, ignore_sign_on_zero, abs_tol
|
|
435
|
+
)
|
|
436
|
+
except AssertionError as exc:
|
|
437
|
+
failure_msg = str(exc)
|
|
438
|
+
# Fall off of the 'except' scope to avoid Python 3 exception
|
|
439
|
+
# chaining.
|
|
440
|
+
else:
|
|
441
|
+
return
|
|
442
|
+
# Decorate the failure message with more information
|
|
443
|
+
self.fail("when comparing %s and %s: %s" % (first, second, failure_msg))
|
|
444
|
+
|
|
445
|
+
def _assertPreciseEqual(
|
|
446
|
+
self,
|
|
447
|
+
first,
|
|
448
|
+
second,
|
|
449
|
+
prec="exact",
|
|
450
|
+
ulps=1,
|
|
451
|
+
msg=None,
|
|
452
|
+
ignore_sign_on_zero=False,
|
|
453
|
+
abs_tol=None,
|
|
454
|
+
):
|
|
455
|
+
"""Recursive workhorse for assertPreciseEqual()."""
|
|
456
|
+
|
|
457
|
+
def _assertNumberEqual(first, second, delta=None):
|
|
458
|
+
if (
|
|
459
|
+
delta is None
|
|
460
|
+
or first == second == 0.0
|
|
461
|
+
or math.isinf(first)
|
|
462
|
+
or math.isinf(second)
|
|
463
|
+
):
|
|
464
|
+
self.assertEqual(first, second, msg=msg)
|
|
465
|
+
# For signed zeros
|
|
466
|
+
if not ignore_sign_on_zero:
|
|
467
|
+
try:
|
|
468
|
+
if math.copysign(1, first) != math.copysign(1, second):
|
|
469
|
+
self.fail(
|
|
470
|
+
self._formatMessage(
|
|
471
|
+
msg, "%s != %s" % (first, second)
|
|
472
|
+
)
|
|
473
|
+
)
|
|
474
|
+
except TypeError:
|
|
475
|
+
pass
|
|
476
|
+
else:
|
|
477
|
+
self.assertAlmostEqual(first, second, delta=delta, msg=msg)
|
|
478
|
+
|
|
479
|
+
first_family = self._detect_family(first)
|
|
480
|
+
second_family = self._detect_family(second)
|
|
481
|
+
|
|
482
|
+
assertion_message = "Type Family mismatch. (%s != %s)" % (
|
|
483
|
+
first_family,
|
|
484
|
+
second_family,
|
|
485
|
+
)
|
|
486
|
+
if msg:
|
|
487
|
+
assertion_message += ": %s" % (msg,)
|
|
488
|
+
self.assertEqual(first_family, second_family, msg=assertion_message)
|
|
489
|
+
|
|
490
|
+
# We now know they are in the same comparison family
|
|
491
|
+
compare_family = first_family
|
|
492
|
+
|
|
493
|
+
# For recognized sequences, recurse
|
|
494
|
+
if compare_family == "ndarray":
|
|
495
|
+
dtype = self._fix_dtype(first.dtype)
|
|
496
|
+
self.assertEqual(dtype, self._fix_dtype(second.dtype))
|
|
497
|
+
self.assertEqual(
|
|
498
|
+
first.ndim, second.ndim, "different number of dimensions"
|
|
499
|
+
)
|
|
500
|
+
self.assertEqual(first.shape, second.shape, "different shapes")
|
|
501
|
+
self.assertEqual(
|
|
502
|
+
first.flags.writeable,
|
|
503
|
+
second.flags.writeable,
|
|
504
|
+
"different mutability",
|
|
505
|
+
)
|
|
506
|
+
# itemsize is already checked by the dtype test above
|
|
507
|
+
self.assertEqual(
|
|
508
|
+
self._fix_strides(first),
|
|
509
|
+
self._fix_strides(second),
|
|
510
|
+
"different strides",
|
|
511
|
+
)
|
|
512
|
+
if first.dtype != dtype:
|
|
513
|
+
first = first.astype(dtype)
|
|
514
|
+
if second.dtype != dtype:
|
|
515
|
+
second = second.astype(dtype)
|
|
516
|
+
for a, b in zip(first.flat, second.flat):
|
|
517
|
+
self._assertPreciseEqual(
|
|
518
|
+
a, b, prec, ulps, msg, ignore_sign_on_zero, abs_tol
|
|
519
|
+
)
|
|
520
|
+
return
|
|
521
|
+
|
|
522
|
+
elif compare_family == "sequence":
|
|
523
|
+
self.assertEqual(len(first), len(second), msg=msg)
|
|
524
|
+
for a, b in zip(first, second):
|
|
525
|
+
self._assertPreciseEqual(
|
|
526
|
+
a, b, prec, ulps, msg, ignore_sign_on_zero, abs_tol
|
|
527
|
+
)
|
|
528
|
+
return
|
|
529
|
+
|
|
530
|
+
elif compare_family == "exact":
|
|
531
|
+
exact_comparison = True
|
|
532
|
+
|
|
533
|
+
elif compare_family in ["complex", "approximate"]:
|
|
534
|
+
exact_comparison = False
|
|
535
|
+
|
|
536
|
+
elif compare_family == "enum":
|
|
537
|
+
self.assertIs(first.__class__, second.__class__)
|
|
538
|
+
self._assertPreciseEqual(
|
|
539
|
+
first.value,
|
|
540
|
+
second.value,
|
|
541
|
+
prec,
|
|
542
|
+
ulps,
|
|
543
|
+
msg,
|
|
544
|
+
ignore_sign_on_zero,
|
|
545
|
+
abs_tol,
|
|
546
|
+
)
|
|
547
|
+
return
|
|
548
|
+
|
|
549
|
+
elif compare_family == "unknown":
|
|
550
|
+
# Assume these are non-numeric types: we will fall back
|
|
551
|
+
# on regular unittest comparison.
|
|
552
|
+
self.assertIs(first.__class__, second.__class__)
|
|
553
|
+
exact_comparison = True
|
|
554
|
+
|
|
555
|
+
else:
|
|
556
|
+
assert 0, "unexpected family"
|
|
557
|
+
|
|
558
|
+
# If a Numpy scalar, check the dtype is exactly the same too
|
|
559
|
+
# (required for datetime64 and timedelta64).
|
|
560
|
+
if hasattr(first, "dtype") and hasattr(second, "dtype"):
|
|
561
|
+
self.assertEqual(first.dtype, second.dtype)
|
|
562
|
+
|
|
563
|
+
# Mixing bools and non-bools should always fail
|
|
564
|
+
if isinstance(first, self._bool_types) != isinstance(
|
|
565
|
+
second, self._bool_types
|
|
566
|
+
):
|
|
567
|
+
assertion_message = "Mismatching return types (%s vs. %s)" % (
|
|
568
|
+
first.__class__,
|
|
569
|
+
second.__class__,
|
|
570
|
+
)
|
|
571
|
+
if msg:
|
|
572
|
+
assertion_message += ": %s" % (msg,)
|
|
573
|
+
self.fail(assertion_message)
|
|
574
|
+
|
|
575
|
+
try:
|
|
576
|
+
if cmath.isnan(first) and cmath.isnan(second):
|
|
577
|
+
# The NaNs will compare unequal, skip regular comparison
|
|
578
|
+
return
|
|
579
|
+
except TypeError:
|
|
580
|
+
# Not floats.
|
|
581
|
+
pass
|
|
582
|
+
|
|
583
|
+
# if absolute comparison is set, use it
|
|
584
|
+
if abs_tol is not None:
|
|
585
|
+
if abs_tol == "eps":
|
|
586
|
+
rtol = np.finfo(type(first)).eps
|
|
587
|
+
elif isinstance(abs_tol, float):
|
|
588
|
+
rtol = abs_tol
|
|
589
|
+
else:
|
|
590
|
+
raise ValueError(
|
|
591
|
+
'abs_tol is not "eps" or a float, found %s' % abs_tol
|
|
592
|
+
)
|
|
593
|
+
if abs(first - second) < rtol:
|
|
594
|
+
return
|
|
595
|
+
|
|
596
|
+
exact_comparison = exact_comparison or prec == "exact"
|
|
597
|
+
|
|
598
|
+
if not exact_comparison and prec != "exact":
|
|
599
|
+
if prec == "single":
|
|
600
|
+
bits = 24
|
|
601
|
+
elif prec == "double":
|
|
602
|
+
bits = 53
|
|
603
|
+
else:
|
|
604
|
+
raise ValueError("unsupported precision %r" % (prec,))
|
|
605
|
+
k = 2 ** (ulps - bits - 1)
|
|
606
|
+
delta = k * (abs(first) + abs(second))
|
|
607
|
+
else:
|
|
608
|
+
delta = None
|
|
609
|
+
if isinstance(first, self._complex_types):
|
|
610
|
+
_assertNumberEqual(first.real, second.real, delta)
|
|
611
|
+
_assertNumberEqual(first.imag, second.imag, delta)
|
|
612
|
+
elif isinstance(first, (np.timedelta64, np.datetime64)):
|
|
613
|
+
# Since Np 1.16 NaT == NaT is False, so special comparison needed
|
|
614
|
+
if np.isnat(first):
|
|
615
|
+
self.assertEqual(np.isnat(first), np.isnat(second))
|
|
616
|
+
else:
|
|
617
|
+
_assertNumberEqual(first, second, delta)
|
|
618
|
+
else:
|
|
619
|
+
_assertNumberEqual(first, second, delta)
|
|
620
|
+
|
|
621
|
+
def subprocess_test_runner(
|
|
622
|
+
self,
|
|
623
|
+
test_module,
|
|
624
|
+
test_class=None,
|
|
625
|
+
test_name=None,
|
|
626
|
+
envvars=None,
|
|
627
|
+
timeout=60,
|
|
628
|
+
flags=None,
|
|
629
|
+
_subproc_test_env="1",
|
|
630
|
+
):
|
|
631
|
+
"""
|
|
632
|
+
Runs named unit test(s) as specified in the arguments as:
|
|
633
|
+
test_module.test_class.test_name. test_module must always be supplied
|
|
634
|
+
and if no further refinement is made with test_class and test_name then
|
|
635
|
+
all tests in the module will be run. The tests will be run in a
|
|
636
|
+
subprocess with environment variables specified in `envvars`.
|
|
637
|
+
If given, envvars must be a map of form:
|
|
638
|
+
environment variable name (str) -> value (str)
|
|
639
|
+
If given, flags must be a map of form:
|
|
640
|
+
flag including the `-` (str) -> value (str)
|
|
641
|
+
It is most convenient to use this method in conjunction with
|
|
642
|
+
@needs_subprocess as the decorator will cause the decorated test to be
|
|
643
|
+
skipped unless the `SUBPROC_TEST` environment variable is set to
|
|
644
|
+
the same value of ``_subproc_test_env``
|
|
645
|
+
(this special environment variable is set by this method such that the
|
|
646
|
+
specified test(s) will not be skipped in the subprocess).
|
|
647
|
+
|
|
648
|
+
|
|
649
|
+
Following execution in the subprocess this method will check the test(s)
|
|
650
|
+
executed without error. The timeout kwarg can be used to allow more time
|
|
651
|
+
for longer running tests, it defaults to 60 seconds.
|
|
652
|
+
"""
|
|
653
|
+
parts = (test_module, test_class, test_name)
|
|
654
|
+
fully_qualified_test = ".".join(x for x in parts if x is not None)
|
|
655
|
+
flags_args = []
|
|
656
|
+
if flags is not None:
|
|
657
|
+
for flag, value in flags.items():
|
|
658
|
+
flags_args.append(f"{flag}")
|
|
659
|
+
flags_args.append(f"{value}")
|
|
660
|
+
cmd = [
|
|
661
|
+
sys.executable,
|
|
662
|
+
*flags_args,
|
|
663
|
+
"-m",
|
|
664
|
+
"numba.runtests",
|
|
665
|
+
fully_qualified_test,
|
|
666
|
+
]
|
|
667
|
+
env_copy = os.environ.copy()
|
|
668
|
+
env_copy["SUBPROC_TEST"] = _subproc_test_env
|
|
669
|
+
try:
|
|
670
|
+
env_copy["COVERAGE_PROCESS_START"] = os.environ["COVERAGE_RCFILE"]
|
|
671
|
+
except KeyError:
|
|
672
|
+
pass # ignored
|
|
673
|
+
envvars = pytypes.MappingProxyType({} if envvars is None else envvars)
|
|
674
|
+
env_copy.update(envvars)
|
|
675
|
+
status = subprocess.run(
|
|
676
|
+
cmd,
|
|
677
|
+
stdout=subprocess.PIPE,
|
|
678
|
+
stderr=subprocess.PIPE,
|
|
679
|
+
timeout=timeout,
|
|
680
|
+
env=env_copy,
|
|
681
|
+
universal_newlines=True,
|
|
682
|
+
)
|
|
683
|
+
streams = (
|
|
684
|
+
f"\ncaptured stdout: {status.stdout}\n"
|
|
685
|
+
f"captured stderr: {status.stderr}"
|
|
686
|
+
)
|
|
687
|
+
self.assertEqual(status.returncode, 0, streams)
|
|
688
|
+
# Python 3.12.1 report
|
|
689
|
+
no_tests_ran = "NO TESTS RAN"
|
|
690
|
+
if no_tests_ran in status.stderr:
|
|
691
|
+
self.skipTest(no_tests_ran)
|
|
692
|
+
else:
|
|
693
|
+
self.assertIn("OK", status.stderr)
|
|
694
|
+
return status
|
|
695
|
+
|
|
696
|
+
def run_test_in_subprocess(maybefunc=None, timeout=60, envvars=None):
|
|
697
|
+
"""Runs the decorated test in a subprocess via invoking numba's test
|
|
698
|
+
runner. kwargs timeout and envvars are passed through to
|
|
699
|
+
subprocess_test_runner."""
|
|
700
|
+
|
|
701
|
+
def wrapper(func):
|
|
702
|
+
def inner(self, *args, **kwargs):
|
|
703
|
+
if os.environ.get("SUBPROC_TEST", None) != func.__name__:
|
|
704
|
+
# Not in a subprocess test env, so stage the call to run the
|
|
705
|
+
# test in a subprocess which will set the env var.
|
|
706
|
+
class_name = self.__class__.__name__
|
|
707
|
+
self.subprocess_test_runner(
|
|
708
|
+
test_module=self.__module__,
|
|
709
|
+
test_class=class_name,
|
|
710
|
+
test_name=func.__name__,
|
|
711
|
+
timeout=timeout,
|
|
712
|
+
envvars=envvars,
|
|
713
|
+
_subproc_test_env=func.__name__,
|
|
714
|
+
)
|
|
715
|
+
else:
|
|
716
|
+
# env var is set, so we're in the subprocess, run the
|
|
717
|
+
# actual test.
|
|
718
|
+
func(self)
|
|
719
|
+
|
|
720
|
+
return inner
|
|
721
|
+
|
|
722
|
+
if isinstance(maybefunc, pytypes.FunctionType):
|
|
723
|
+
return wrapper(maybefunc)
|
|
724
|
+
else:
|
|
725
|
+
return wrapper
|
|
726
|
+
|
|
727
|
+
def make_dummy_type(self):
|
|
728
|
+
"""Use to generate a dummy type unique to this test. Returns a python
|
|
729
|
+
Dummy class and a corresponding Numba type DummyType."""
|
|
730
|
+
|
|
731
|
+
# Use test_id to make sure no collision is possible.
|
|
732
|
+
test_id = self.id()
|
|
733
|
+
DummyType = type("DummyTypeFor{}".format(test_id), (types.Opaque,), {})
|
|
734
|
+
|
|
735
|
+
dummy_type = DummyType("my_dummy")
|
|
736
|
+
register_model(DummyType)(OpaqueModel)
|
|
737
|
+
|
|
738
|
+
class Dummy(object):
|
|
739
|
+
pass
|
|
740
|
+
|
|
741
|
+
@typeof_impl.register(Dummy)
|
|
742
|
+
def typeof_dummy(val, c):
|
|
743
|
+
return dummy_type
|
|
744
|
+
|
|
745
|
+
@unbox(DummyType)
|
|
746
|
+
def unbox_dummy(typ, obj, c):
|
|
747
|
+
return NativeValue(c.context.get_dummy_value())
|
|
748
|
+
|
|
749
|
+
return Dummy, DummyType
|
|
750
|
+
|
|
751
|
+
def skip_if_no_external_compiler(self):
|
|
752
|
+
"""
|
|
753
|
+
Call this to ensure the test is skipped if no suitable external compiler
|
|
754
|
+
is found. This is a method on the TestCase opposed to a stand-alone
|
|
755
|
+
decorator so as to make it "lazy" via runtime evaluation opposed to
|
|
756
|
+
running at test-discovery time.
|
|
757
|
+
"""
|
|
758
|
+
# This is a local import to avoid deprecation warnings being generated
|
|
759
|
+
# through the use of the numba.pycc module.
|
|
760
|
+
from numba.pycc.platform import external_compiler_works
|
|
761
|
+
|
|
762
|
+
if not external_compiler_works():
|
|
763
|
+
self.skipTest("No suitable external compiler was found.")
|