numba-cuda 0.18.1__py3-none-any.whl → 0.19.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of numba-cuda might be problematic. Click here for more details.

Files changed (88)
  1. numba_cuda/VERSION +1 -1
  2. numba_cuda/numba/cuda/__init__.py +1 -1
  3. numba_cuda/numba/cuda/_internal/cuda_bf16.py +2 -2
  4. numba_cuda/numba/cuda/_internal/cuda_fp16.py +1 -1
  5. numba_cuda/numba/cuda/api.py +2 -7
  6. numba_cuda/numba/cuda/compiler.py +7 -4
  7. numba_cuda/numba/cuda/core/interpreter.py +3592 -0
  8. numba_cuda/numba/cuda/core/ir_utils.py +2645 -0
  9. numba_cuda/numba/cuda/core/sigutils.py +55 -0
  10. numba_cuda/numba/cuda/cuda_paths.py +9 -17
  11. numba_cuda/numba/cuda/cudadecl.py +1 -1
  12. numba_cuda/numba/cuda/cudadrv/driver.py +4 -19
  13. numba_cuda/numba/cuda/cudadrv/libs.py +1 -2
  14. numba_cuda/numba/cuda/cudadrv/nvrtc.py +44 -44
  15. numba_cuda/numba/cuda/cudadrv/nvvm.py +3 -18
  16. numba_cuda/numba/cuda/cudadrv/runtime.py +12 -1
  17. numba_cuda/numba/cuda/cudamath.py +1 -1
  18. numba_cuda/numba/cuda/decorators.py +4 -3
  19. numba_cuda/numba/cuda/deviceufunc.py +2 -1
  20. numba_cuda/numba/cuda/dispatcher.py +3 -2
  21. numba_cuda/numba/cuda/extending.py +1 -1
  22. numba_cuda/numba/cuda/itanium_mangler.py +211 -0
  23. numba_cuda/numba/cuda/libdevicedecl.py +1 -1
  24. numba_cuda/numba/cuda/libdevicefuncs.py +1 -1
  25. numba_cuda/numba/cuda/lowering.py +1 -1
  26. numba_cuda/numba/cuda/simulator/api.py +1 -1
  27. numba_cuda/numba/cuda/simulator/cudadrv/driver.py +0 -7
  28. numba_cuda/numba/cuda/target.py +1 -2
  29. numba_cuda/numba/cuda/testing.py +4 -6
  30. numba_cuda/numba/cuda/tests/core/test_itanium_mangler.py +80 -0
  31. numba_cuda/numba/cuda/tests/cudadrv/test_cuda_ndarray.py +1 -1
  32. numba_cuda/numba/cuda/tests/cudadrv/test_deallocations.py +1 -1
  33. numba_cuda/numba/cuda/tests/cudadrv/test_detect.py +1 -1
  34. numba_cuda/numba/cuda/tests/cudadrv/test_emm_plugins.py +1 -1
  35. numba_cuda/numba/cuda/tests/cudadrv/test_linker.py +1 -1
  36. numba_cuda/numba/cuda/tests/cudadrv/test_managed_alloc.py +1 -1
  37. numba_cuda/numba/cuda/tests/cudadrv/test_mvc.py +1 -1
  38. numba_cuda/numba/cuda/tests/cudadrv/test_nvrtc.py +4 -6
  39. numba_cuda/numba/cuda/tests/cudadrv/test_nvvm_driver.py +0 -4
  40. numba_cuda/numba/cuda/tests/cudadrv/test_ptds.py +1 -1
  41. numba_cuda/numba/cuda/tests/cudapy/test_bfloat16.py +1 -3
  42. numba_cuda/numba/cuda/tests/cudapy/test_bfloat16_bindings.py +1 -3
  43. numba_cuda/numba/cuda/tests/cudapy/test_caching.py +146 -3
  44. numba_cuda/numba/cuda/tests/cudapy/test_cffi.py +1 -1
  45. numba_cuda/numba/cuda/tests/cudapy/test_compiler.py +0 -4
  46. numba_cuda/numba/cuda/tests/cudapy/test_cuda_array_interface.py +1 -1
  47. numba_cuda/numba/cuda/tests/cudapy/test_cuda_jit_no_types.py +1 -1
  48. numba_cuda/numba/cuda/tests/cudapy/test_debug.py +1 -1
  49. numba_cuda/numba/cuda/tests/cudapy/test_debuginfo.py +1 -284
  50. numba_cuda/numba/cuda/tests/cudapy/test_debuginfo_types.py +473 -0
  51. numba_cuda/numba/cuda/tests/cudapy/test_device_func.py +1 -1
  52. numba_cuda/numba/cuda/tests/cudapy/test_errors.py +1 -1
  53. numba_cuda/numba/cuda/tests/cudapy/test_extending.py +1 -6
  54. numba_cuda/numba/cuda/tests/cudapy/test_gufunc.py +1 -1
  55. numba_cuda/numba/cuda/tests/cudapy/test_ipc.py +1 -1
  56. numba_cuda/numba/cuda/tests/cudapy/test_ir_utils.py +295 -0
  57. numba_cuda/numba/cuda/tests/cudapy/test_lineinfo.py +1 -1
  58. numba_cuda/numba/cuda/tests/cudapy/test_operator.py +1 -1
  59. numba_cuda/numba/cuda/tests/cudapy/test_ufuncs.py +1 -1
  60. numba_cuda/numba/cuda/tests/cudapy/test_warning.py +5 -1
  61. numba_cuda/numba/cuda/tests/doc_examples/test_cpointer.py +1 -1
  62. numba_cuda/numba/cuda/tests/doc_examples/test_cpu_gpu_compat.py +1 -1
  63. numba_cuda/numba/cuda/tests/doc_examples/test_ffi.py +1 -1
  64. numba_cuda/numba/cuda/tests/doc_examples/test_laplace.py +1 -1
  65. numba_cuda/numba/cuda/tests/doc_examples/test_matmul.py +1 -1
  66. numba_cuda/numba/cuda/tests/doc_examples/test_montecarlo.py +1 -1
  67. numba_cuda/numba/cuda/tests/doc_examples/test_reduction.py +1 -1
  68. numba_cuda/numba/cuda/tests/doc_examples/test_sessionize.py +1 -1
  69. numba_cuda/numba/cuda/tests/doc_examples/test_ufunc.py +1 -1
  70. numba_cuda/numba/cuda/tests/doc_examples/test_vecadd.py +1 -1
  71. numba_cuda/numba/cuda/tests/nocuda/test_import.py +1 -1
  72. numba_cuda/numba/cuda/tests/nrt/test_nrt.py +2 -2
  73. numba_cuda/numba/cuda/tests/nrt/test_nrt_refct.py +1 -1
  74. numba_cuda/numba/cuda/tests/support.py +752 -0
  75. numba_cuda/numba/cuda/tests/test_binary_generation/Makefile +3 -3
  76. numba_cuda/numba/cuda/tests/test_binary_generation/generate_raw_ltoir.py +4 -1
  77. numba_cuda/numba/cuda/typing/__init__.py +8 -0
  78. numba_cuda/numba/cuda/typing/templates.py +1453 -0
  79. numba_cuda/numba/cuda/vector_types.py +3 -3
  80. {numba_cuda-0.18.1.dist-info → numba_cuda-0.19.0.dist-info}/METADATA +21 -28
  81. {numba_cuda-0.18.1.dist-info → numba_cuda-0.19.0.dist-info}/RECORD +84 -79
  82. numba_cuda/numba/cuda/include/11/cuda_bf16.h +0 -3749
  83. numba_cuda/numba/cuda/include/11/cuda_bf16.hpp +0 -2683
  84. numba_cuda/numba/cuda/include/11/cuda_fp16.h +0 -3794
  85. numba_cuda/numba/cuda/include/11/cuda_fp16.hpp +0 -2614
  86. {numba_cuda-0.18.1.dist-info → numba_cuda-0.19.0.dist-info}/WHEEL +0 -0
  87. {numba_cuda-0.18.1.dist-info → numba_cuda-0.19.0.dist-info}/licenses/LICENSE +0 -0
  88. {numba_cuda-0.18.1.dist-info → numba_cuda-0.19.0.dist-info}/top_level.txt +0 -0
@@ -1,4 +1,34 @@
1
+ import cmath
2
+ import contextlib
3
+ import enum
4
+ import gc
5
+ import math
6
+ import unittest
7
+ import os
8
+ import io
9
+ import subprocess
10
+ import sys
11
+ import shutil
12
+ import warnings
13
+ import tempfile
14
+ import time
15
+ import types as pytypes
16
+ from functools import cached_property
17
+
18
+ import numpy as np
19
+
20
+ from numba import types
21
+ from numba.core import errors, config
22
+ from numba.core.typing import cffi_utils
1
23
  from numba.cuda.memory_management.nrt import rtsys
24
+ from numba.core.extending import (
25
+ typeof_impl,
26
+ register_model,
27
+ unbox,
28
+ NativeValue,
29
+ )
30
+ from numba.core.datamodel.models import OpaqueModel
31
+ from numba.np import numpy_support
2
32
 
3
33
 
4
34
  class EnableNRTStatsMixin(object):
@@ -9,3 +39,725 @@ class EnableNRTStatsMixin(object):
9
39
 
10
40
  def tearDown(self):
11
41
  rtsys.memsys_disable_stats()
42
+
43
+
44
# Skip decorators for tests that depend on optional packages, on a specific
# operating system, or on the installed NumPy major version.
skip_unless_cffi = unittest.skipUnless(cffi_utils.SUPPORTED, "requires cffi")

_lnx_reason = "linux only test"
linux_only = unittest.skipUnless(sys.platform.startswith("linux"), _lnx_reason)

_win_reason = "Windows only test"
windows_only = unittest.skipUnless(sys.platform.startswith("win"), _win_reason)

# True when running against NumPy 2.x or newer.
IS_NUMPY_2 = numpy_support.numpy_version >= (2, 0)
skip_if_numpy_2 = unittest.skipIf(IS_NUMPY_2, "Not supported on numpy 2.0+")
54
+
55
_trashcan_dir = "numba-cuda-tests"

if os.name != "nt":
    # Mix the UID into the directory name to allow different users to
    # run the test suite without permission errors (issue #1586)
    _trashcan_dir = "%s.%s" % (_trashcan_dir, os.getuid())
# Under Windows, gettempdir() already points to the user-local temp dir, so
# the bare name is enough there.
_trashcan_dir = os.path.join(tempfile.gettempdir(), _trashcan_dir)

# Age in seconds after which a leftover entry in the trashcan is considered
# stale and may be deleted; no test-suite run is expected to take this long.
_trashcan_timeout = 24 * 3600  # 1 day
70
+
71
+
72
def _create_trashcan_dir():
    """Make sure the shared trashcan directory exists.

    Calling this when the directory is already present is a no-op. Only the
    final path component is created; parent directories are not.
    """
    with contextlib.suppress(FileExistsError):
        os.mkdir(_trashcan_dir)
77
+
78
+
79
def _purge_trashcan_dir():
    """Delete every trashcan entry last modified more than
    ``_trashcan_timeout`` seconds ago.

    This is best-effort. Entries that cannot be inspected or removed are
    skipped silently.
    """
    cutoff = time.time() - _trashcan_timeout
    for name in sorted(os.listdir(_trashcan_dir)):
        entry = os.path.join(_trashcan_dir, name)
        # In parallel testing, several processes can attempt to remove the
        # same entry at once, so an OSError here is expected and ignored.
        with contextlib.suppress(OSError):
            if os.stat(entry).st_mtime < cutoff:
                shutil.rmtree(entry, ignore_errors=True)
91
+
92
+
93
def _create_trashcan_subdir(prefix):
    """Create a new, uniquely named directory inside the trashcan and return
    its path.

    The directory name is ``prefix + "-"`` followed by a random suffix
    chosen by ``tempfile.mkdtemp``.
    """
    # Clean out stale leftovers from earlier runs before adding a new entry.
    _purge_trashcan_dir()
    return tempfile.mkdtemp(dir=_trashcan_dir, prefix=prefix + "-")
97
+
98
+
99
def temp_directory(prefix):
    """
    Return the path of a new temporary directory whose name starts with
    *prefix*.

    The directory is created inside the shared trashcan directory and is
    not removed when this process exits. Later calls delete trashcan entries
    once they are older than ``_trashcan_timeout`` seconds.

    This is necessary because a DLL file can't be deleted while in use
    under Windows. A useful side effect is that test files can still be
    inspected shortly after a test-suite run.
    """
    _create_trashcan_dir()
    new_dir = _create_trashcan_subdir(prefix)
    return new_dir
113
+
114
+
115
def import_dynamic(modname):
    """
    Import the module named *modname* and return the module object.

    Import-system finder caches are invalidated first, so a module file
    created after interpreter start-up is found.
    """
    import importlib

    importlib.invalidate_caches()
    __import__(modname)
    # For dotted names __import__ returns the top-level package, so the
    # requested (sub)module is fetched from sys.modules instead.
    imported = sys.modules[modname]
    return imported
125
+
126
+
127
def ignore_internal_warnings():
    """Install warning filters for noise unrelated to the code under test.

    Call this inside a ``warnings.catch_warnings()`` block so that the
    filters added here are discarded when the block exits. It ignores any
    warning attributed to the ``typeguard`` module, and ``NumbaWarning``
    messages that mention ``TBB_INTERFACE_VERSION`` and come from
    ``numba.np.ufunc.parallel``.
    """
    # Warnings attributed to the typeguard package.
    warnings.filterwarnings("ignore", module="typeguard")
    # TBB interface-version mismatch warnings.
    tbb_message = r".*TBB_INTERFACE_VERSION.*"
    tbb_module = r"numba\.np\.ufunc\.parallel.*"
    warnings.filterwarnings(
        "ignore",
        tbb_message,
        errors.NumbaWarning,
        tbb_module,
    )
140
+
141
+
142
@contextlib.contextmanager
def override_config(name, value):
    """
    Context manager that temporarily sets the Numba config variable *name*
    to *value*. The previous value is restored on exit, even if the body
    raises. *name* must be the name of an existing variable in
    numba.config; otherwise ``getattr`` raises AttributeError.
    """
    previous = getattr(config, name)
    with contextlib.ExitStack() as restore:
        setattr(config, name, value)
        # The restore callback is registered only after the override
        # succeeded, matching a try/finally placed around the yield.
        restore.callback(setattr, config, name, previous)
        yield
155
+
156
+
157
def run_in_subprocess(code, flags=None, env=None, timeout=30):
    """Run a snippet of Python code in a subprocess with flags, if any are
    given, and return the subprocess's ``(stdout, stderr)`` as bytes.

    :param code: source code executed via ``python -c``.
    :param flags: optional sequence of interpreter flags, placed before
        ``-c`` on the command line.
    :param env: environment mapping passed to ``subprocess.Popen()``.
    :param timeout: seconds to wait, passed to ``Popen.communicate()``.
    :raises AssertionError: if the subprocess exits with a non-zero code.
        The message includes the subprocess's stderr.
    :raises subprocess.TimeoutExpired: if the subprocess does not finish
        within *timeout*. The child is killed and reaped before the
        exception propagates.
    """
    cmd = [sys.executable, *(flags or ()), "-c", code]
    with subprocess.Popen(
        cmd, stdout=subprocess.PIPE, stderr=subprocess.PIPE, env=env
    ) as popen:
        try:
            out, err = popen.communicate(timeout=timeout)
        except subprocess.TimeoutExpired:
            # communicate() does not kill the child on timeout. Kill it and
            # drain its pipes so no orphaned interpreter is left running.
            popen.kill()
            popen.communicate()
            raise
    if popen.returncode != 0:
        msg = "process failed with code %s: stderr follows\n%s\n"
        # errors="replace" keeps non-UTF-8 stderr from masking the failure
        # with a UnicodeDecodeError.
        raise AssertionError(
            msg % (popen.returncode, err.decode(errors="replace"))
        )
    return out, err
181
+
182
+
183
@contextlib.contextmanager
def captured_output(stream_name):
    """Temporarily replace ``sys.<stream_name>`` with an ``io.StringIO``.

    Used by captured_stdout and captured_stderr. The context manager yields
    the buffer, and the original stream is put back on exit even if the body
    raises.
    """
    saved_stream = getattr(sys, stream_name)
    buffer = io.StringIO()
    setattr(sys, stream_name, buffer)
    try:
        yield buffer
    finally:
        setattr(sys, stream_name, saved_stream)
193
+
194
+
195
def captured_stdout():
    """Return a context manager that captures everything written to
    ``sys.stdout``; it yields an ``io.StringIO``.

    Example::

        with captured_stdout() as stdout:
            print("hello")
        self.assertEqual(stdout.getvalue(), "hello\\n")
    """
    stream = "stdout"
    return captured_output(stream)
203
+
204
+
205
def captured_stderr():
    """Return a context manager that captures everything written to
    ``sys.stderr``; it yields an ``io.StringIO``.

    Example::

        with captured_stderr() as stderr:
            print("hello", file=sys.stderr)
        self.assertEqual(stderr.getvalue(), "hello\\n")
    """
    stream = "stderr"
    return captured_output(stream)
213
+
214
+
215
class TestCase(unittest.TestCase):
    """Base test case with extra assertion and subprocess helpers.

    Provides precise numeric comparison (assertPreciseEqual), reference-count
    and NRT-leak assertions, and helpers to re-run tests in a subprocess.
    """

    longMessage = True

    # A random state yielding the same random numbers for any test case.
    # Use as `self.random.<method name>`
    @cached_property
    def random(self):
        return np.random.RandomState(42)

    def reset_module_warnings(self, module):
        """
        Reset the warnings registry of a module. This can be necessary
        as the warnings module is buggy in that regard.
        See http://bugs.python.org/issue4180

        *module* may be a module object or the name of a module already
        present in ``sys.modules``.
        """
        if isinstance(module, str):
            module = sys.modules[module]
        try:
            del module.__warningregistry__
        except AttributeError:
            # The module has no registry, so there is nothing to reset.
            pass

    @contextlib.contextmanager
    def assertTypingError(self):
        """
        A context manager that asserts the enclosed code block fails
        compiling in nopython mode.
        """
        _accepted_errors = (
            errors.LoweringError,
            errors.TypingError,
            TypeError,
            NotImplementedError,
        )
        with self.assertRaises(_accepted_errors) as cm:
            yield cm

    @contextlib.contextmanager
    def assertRefCount(self, *objects):
        """
        A context manager that asserts the given objects have the
        same reference counts before and after executing the
        enclosed block.
        """
        old_refcounts = [sys.getrefcount(x) for x in objects]
        yield
        gc.collect()
        new_refcounts = [sys.getrefcount(x) for x in objects]
        for old, new, obj in zip(old_refcounts, new_refcounts, objects):
            if old != new:
                self.fail(
                    "Refcount changed from %d to %d for object: %r"
                    % (old, new, obj)
                )

    def assertRefCountEqual(self, *objects):
        """
        Assert that all *objects* currently have the same reference count,
        measured after a ``gc.collect()``. Passing zero or one object
        trivially succeeds.
        """
        gc.collect()
        refcounts = [sys.getrefcount(x) for x in objects]
        if not refcounts:
            return
        rc_0 = refcounts[0]
        for i, rc_i in enumerate(refcounts[1:], start=1):
            if rc_0 != rc_i:
                self.fail(
                    f"Refcount for objects does not match. "
                    f"#0({rc_0}) != #{i}({rc_i})"
                )

    @contextlib.contextmanager
    def assertNoNRTLeak(self):
        """
        A context manager that asserts no NRT leak was created during
        the execution of the enclosed block.
        """
        old = rtsys.get_allocation_stats()
        yield
        new = rtsys.get_allocation_stats()
        total_alloc = new.alloc - old.alloc
        total_free = new.free - old.free
        total_mi_alloc = new.mi_alloc - old.mi_alloc
        total_mi_free = new.mi_free - old.mi_free
        self.assertEqual(
            total_alloc,
            total_free,
            "number of data allocs != number of data frees",
        )
        self.assertEqual(
            total_mi_alloc,
            total_mi_free,
            "number of meminfo allocs != number of meminfo frees",
        )

    _bool_types = (bool, np.bool_)
    _exact_typesets = [
        _bool_types,
        (int,),
        (str,),
        (np.integer,),
        (bytes, np.bytes_),
    ]
    _approx_typesets = [(float,), (complex,), (np.inexact,)]
    _sequence_typesets = [(tuple, list)]
    _float_types = (float, np.floating)
    _complex_types = (complex, np.complexfloating)

    def _detect_family(self, numeric_object):
        """
        Return a string describing the comparison family of
        *numeric_object*. The possible return values are "ndarray", "enum",
        "sequence", "exact", "complex", "approximate" and "unknown".

        The checks run in that order. For example, bools and NumPy integers
        are "exact", and Python and NumPy complex numbers are "complex"
        rather than "approximate".
        """
        if isinstance(numeric_object, np.ndarray):
            return "ndarray"

        if isinstance(numeric_object, enum.Enum):
            return "enum"

        if any(isinstance(numeric_object, tp) for tp in self._sequence_typesets):
            return "sequence"

        if any(isinstance(numeric_object, tp) for tp in self._exact_typesets):
            return "exact"

        if any(isinstance(numeric_object, tp) for tp in self._complex_types):
            return "complex"

        if any(isinstance(numeric_object, tp) for tp in self._approx_typesets):
            return "approximate"

        return "unknown"

    def _fix_dtype(self, dtype):
        """
        Fix the given *dtype* for comparison.
        """
        # Under 64-bit Windows, Numpy may return either int32 or int64
        # arrays depending on the function.
        if (
            sys.platform == "win32"
            and sys.maxsize > 2**32
            and dtype == np.dtype("int32")
        ):
            return np.dtype("int64")
        else:
            return dtype

    def _fix_strides(self, arr):
        """
        Return the strides of the given array, fixed for comparison.
        Strides for 0- or 1-sized dimensions are ignored.
        """
        if arr.size == 0:
            return [0] * arr.ndim
        else:
            return [
                stride / arr.itemsize
                for (stride, shape) in zip(arr.strides, arr.shape)
                if shape > 1
            ]

    def assertStridesEqual(self, first, second):
        """
        Test that two arrays have the same shape and strides.
        """
        self.assertEqual(first.shape, second.shape, "shapes differ")
        self.assertEqual(first.itemsize, second.itemsize, "itemsizes differ")
        self.assertEqual(
            self._fix_strides(first),
            self._fix_strides(second),
            "strides differ",
        )

    def assertPreciseEqual(
        self,
        first,
        second,
        prec="exact",
        ulps=1,
        msg=None,
        ignore_sign_on_zero=False,
        abs_tol=None,
    ):
        """
        Versatile equality testing function with more built-in checks than
        standard assertEqual().

        For arrays, test that layout, dtype, shape are identical, and
        recursively call assertPreciseEqual() on the contents.

        For other sequences, recursively call assertPreciseEqual() on
        the contents.

        For scalars, test that the two scalars have similar types and are
        equal up to a computed precision.
        If the scalars are instances of exact types or if *prec* is
        'exact', they are compared exactly.
        If the scalars are instances of inexact types (float, complex)
        and *prec* is not 'exact', then the number of significant bits
        is computed according to the value of *prec*: 53 bits if *prec*
        is 'double', 24 bits if *prec* is 'single'. This number of bits
        can be lowered by raising the *ulps* value.
        ignore_sign_on_zero can be set to True if zeros are to be considered
        equal regardless of their sign bit.
        abs_tol, if set to a float, is used as an absolute tolerance. If it
        is set to the string "eps", the machine epsilon of type(first) is
        used instead. When the absolute difference between first and second
        is less than this tolerance, the numbers are considered equal. This
        handles small numbers, typically of magnitude less than machine
        precision.

        Any value of *prec* other than 'exact', 'single' or 'double'
        will raise an error when an inexact comparison is needed.
        """
        try:
            self._assertPreciseEqual(
                first, second, prec, ulps, msg, ignore_sign_on_zero, abs_tol
            )
        except AssertionError as exc:
            failure_msg = str(exc)
            # Fall off of the 'except' scope to avoid Python 3 exception
            # chaining.
        else:
            return
        # Decorate the failure message with more information
        self.fail("when comparing %s and %s: %s" % (first, second, failure_msg))

    def _assertPreciseEqual(
        self,
        first,
        second,
        prec="exact",
        ulps=1,
        msg=None,
        ignore_sign_on_zero=False,
        abs_tol=None,
    ):
        """Recursive workhorse for assertPreciseEqual()."""

        def _assertNumberEqual(first, second, delta=None):
            if (
                delta is None
                or first == second == 0.0
                or math.isinf(first)
                or math.isinf(second)
            ):
                self.assertEqual(first, second, msg=msg)
                # For signed zeros
                if not ignore_sign_on_zero:
                    try:
                        if math.copysign(1, first) != math.copysign(1, second):
                            self.fail(
                                self._formatMessage(
                                    msg, "%s != %s" % (first, second)
                                )
                            )
                    except TypeError:
                        # math.copysign could not convert the values, so
                        # there is no sign to compare.
                        pass
            else:
                self.assertAlmostEqual(first, second, delta=delta, msg=msg)

        first_family = self._detect_family(first)
        second_family = self._detect_family(second)

        assertion_message = "Type Family mismatch. (%s != %s)" % (
            first_family,
            second_family,
        )
        if msg:
            assertion_message += ": %s" % (msg,)
        self.assertEqual(first_family, second_family, msg=assertion_message)

        # We now know they are in the same comparison family
        compare_family = first_family

        # For recognized sequences, recurse
        if compare_family == "ndarray":
            dtype = self._fix_dtype(first.dtype)
            self.assertEqual(dtype, self._fix_dtype(second.dtype))
            self.assertEqual(
                first.ndim, second.ndim, "different number of dimensions"
            )
            self.assertEqual(first.shape, second.shape, "different shapes")
            self.assertEqual(
                first.flags.writeable,
                second.flags.writeable,
                "different mutability",
            )
            # itemsize is already checked by the dtype test above
            self.assertEqual(
                self._fix_strides(first),
                self._fix_strides(second),
                "different strides",
            )
            if first.dtype != dtype:
                first = first.astype(dtype)
            if second.dtype != dtype:
                second = second.astype(dtype)
            for a, b in zip(first.flat, second.flat):
                self._assertPreciseEqual(
                    a, b, prec, ulps, msg, ignore_sign_on_zero, abs_tol
                )
            return

        elif compare_family == "sequence":
            self.assertEqual(len(first), len(second), msg=msg)
            for a, b in zip(first, second):
                self._assertPreciseEqual(
                    a, b, prec, ulps, msg, ignore_sign_on_zero, abs_tol
                )
            return

        elif compare_family == "exact":
            exact_comparison = True

        elif compare_family in ["complex", "approximate"]:
            exact_comparison = False

        elif compare_family == "enum":
            self.assertIs(first.__class__, second.__class__)
            self._assertPreciseEqual(
                first.value,
                second.value,
                prec,
                ulps,
                msg,
                ignore_sign_on_zero,
                abs_tol,
            )
            return

        elif compare_family == "unknown":
            # Assume these are non-numeric types: we will fall back
            # on regular unittest comparison.
            self.assertIs(first.__class__, second.__class__)
            exact_comparison = True

        else:
            # Raise explicitly: a bare `assert` would vanish under -O and
            # leave exact_comparison unbound.
            raise AssertionError("unexpected family")

        # If a Numpy scalar, check the dtype is exactly the same too
        # (required for datetime64 and timedelta64).
        if hasattr(first, "dtype") and hasattr(second, "dtype"):
            self.assertEqual(first.dtype, second.dtype)

        # Mixing bools and non-bools should always fail
        if isinstance(first, self._bool_types) != isinstance(
            second, self._bool_types
        ):
            assertion_message = "Mismatching return types (%s vs. %s)" % (
                first.__class__,
                second.__class__,
            )
            if msg:
                assertion_message += ": %s" % (msg,)
            self.fail(assertion_message)

        try:
            if cmath.isnan(first) and cmath.isnan(second):
                # The NaNs will compare unequal, skip regular comparison
                return
        except TypeError:
            # Not floats.
            pass

        # if absolute comparison is set, use it
        if abs_tol is not None:
            if abs_tol == "eps":
                atol = np.finfo(type(first)).eps
            elif isinstance(abs_tol, float):
                atol = abs_tol
            else:
                raise ValueError(
                    'abs_tol is not "eps" or a float, found %s' % abs_tol
                )
            if abs(first - second) < atol:
                return

        exact_comparison = exact_comparison or prec == "exact"

        if not exact_comparison:
            if prec == "single":
                bits = 24
            elif prec == "double":
                bits = 53
            else:
                raise ValueError("unsupported precision %r" % (prec,))
            # Relative tolerance of `bits - ulps` significant bits, scaled
            # by the magnitude of both operands.
            k = 2 ** (ulps - bits - 1)
            delta = k * (abs(first) + abs(second))
        else:
            delta = None
        if isinstance(first, self._complex_types):
            _assertNumberEqual(first.real, second.real, delta)
            _assertNumberEqual(first.imag, second.imag, delta)
        elif isinstance(first, (np.timedelta64, np.datetime64)):
            # Since Np 1.16 NaT == NaT is False, so special comparison needed
            if np.isnat(first):
                self.assertEqual(np.isnat(first), np.isnat(second))
            else:
                _assertNumberEqual(first, second, delta)
        else:
            _assertNumberEqual(first, second, delta)

    def subprocess_test_runner(
        self,
        test_module,
        test_class=None,
        test_name=None,
        envvars=None,
        timeout=60,
        flags=None,
        _subproc_test_env="1",
    ):
        """
        Runs named unit test(s) as specified in the arguments as:
        test_module.test_class.test_name. test_module must always be supplied
        and if no further refinement is made with test_class and test_name then
        all tests in the module will be run. The tests will be run in a
        subprocess with environment variables specified in `envvars`.
        If given, envvars must be a map of form:
            environment variable name (str) -> value (str)
        If given, flags must be a map of form:
            flag including the `-` (str) -> value (str)
        The `SUBPROC_TEST` environment variable of the subprocess is set to
        ``_subproc_test_env``. A subprocess-aware decorator (such as
        run_test_in_subprocess on this class) can check it to tell whether
        it is already running inside the subprocess.

        Following execution in the subprocess this method will check the test(s)
        executed without error. The timeout kwarg can be used to allow more time
        for longer running tests, it defaults to 60 seconds.
        """
        parts = (test_module, test_class, test_name)
        fully_qualified_test = ".".join(x for x in parts if x is not None)
        flags_args = []
        if flags is not None:
            for flag, value in flags.items():
                flags_args.append(f"{flag}")
                flags_args.append(f"{value}")
        cmd = [
            sys.executable,
            *flags_args,
            "-m",
            "numba.runtests",
            fully_qualified_test,
        ]
        env_copy = os.environ.copy()
        env_copy["SUBPROC_TEST"] = _subproc_test_env
        try:
            env_copy["COVERAGE_PROCESS_START"] = os.environ["COVERAGE_RCFILE"]
        except KeyError:
            pass  # ignored
        envvars = pytypes.MappingProxyType({} if envvars is None else envvars)
        env_copy.update(envvars)
        status = subprocess.run(
            cmd,
            stdout=subprocess.PIPE,
            stderr=subprocess.PIPE,
            timeout=timeout,
            env=env_copy,
            universal_newlines=True,
        )
        streams = (
            f"\ncaptured stdout: {status.stdout}\n"
            f"captured stderr: {status.stderr}"
        )
        self.assertEqual(status.returncode, 0, streams)
        # Python 3.12.1 report
        no_tests_ran = "NO TESTS RAN"
        if no_tests_ran in status.stderr:
            self.skipTest(no_tests_ran)
        else:
            self.assertIn("OK", status.stderr)
        return status

    def run_test_in_subprocess(maybefunc=None, timeout=60, envvars=None):
        """Runs the decorated test in a subprocess via invoking numba's test
        runner. kwargs timeout and envvars are passed through to
        subprocess_test_runner.

        This is a plain function (it takes no ``self``) used as a decorator.
        It works both bare (applied directly to a function) and with keyword
        arguments. The wrapped test keeps the original test's name and
        docstring.
        """
        import functools

        def wrapper(func):
            @functools.wraps(func)
            def inner(self, *args, **kwargs):
                if os.environ.get("SUBPROC_TEST", None) != func.__name__:
                    # Not in a subprocess test env, so stage the call to run the
                    # test in a subprocess which will set the env var.
                    class_name = self.__class__.__name__
                    self.subprocess_test_runner(
                        test_module=self.__module__,
                        test_class=class_name,
                        test_name=func.__name__,
                        timeout=timeout,
                        envvars=envvars,
                        _subproc_test_env=func.__name__,
                    )
                else:
                    # env var is set, so we're in the subprocess, run the
                    # actual test (forwarding any arguments it was given).
                    func(self, *args, **kwargs)

            return inner

        if isinstance(maybefunc, pytypes.FunctionType):
            return wrapper(maybefunc)
        else:
            return wrapper

    def make_dummy_type(self):
        """Use to generate a dummy type unique to this test. Returns a python
        Dummy class and a corresponding Numba type DummyType."""

        # Use test_id to make sure no collision is possible.
        test_id = self.id()
        DummyType = type("DummyTypeFor{}".format(test_id), (types.Opaque,), {})

        dummy_type = DummyType("my_dummy")
        register_model(DummyType)(OpaqueModel)

        class Dummy(object):
            pass

        @typeof_impl.register(Dummy)
        def typeof_dummy(val, c):
            return dummy_type

        @unbox(DummyType)
        def unbox_dummy(typ, obj, c):
            return NativeValue(c.context.get_dummy_value())

        return Dummy, DummyType

    def skip_if_no_external_compiler(self):
        """
        Call this to ensure the test is skipped if no suitable external compiler
        is found. This is a method on the TestCase opposed to a stand-alone
        decorator so as to make it "lazy" via runtime evaluation opposed to
        running at test-discovery time.
        """
        # This is a local import to avoid deprecation warnings being generated
        # through the use of the numba.pycc module.
        from numba.pycc.platform import external_compiler_works

        if not external_compiler_works():
            self.skipTest("No suitable external compiler was found.")