numba-cuda 0.17.0__py3-none-any.whl → 0.18.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of numba-cuda has been flagged as potentially problematic; see the package's registry page for more details.

Files changed (64):
  1. numba_cuda/VERSION +1 -1
  2. numba_cuda/numba/cuda/__init__.py +0 -8
  3. numba_cuda/numba/cuda/_internal/cuda_fp16.py +14225 -0
  4. numba_cuda/numba/cuda/api_util.py +6 -0
  5. numba_cuda/numba/cuda/cgutils.py +1291 -0
  6. numba_cuda/numba/cuda/codegen.py +32 -14
  7. numba_cuda/numba/cuda/compiler.py +113 -10
  8. numba_cuda/numba/cuda/core/caching.py +741 -0
  9. numba_cuda/numba/cuda/core/callconv.py +338 -0
  10. numba_cuda/numba/cuda/core/codegen.py +168 -0
  11. numba_cuda/numba/cuda/core/compiler.py +205 -0
  12. numba_cuda/numba/cuda/core/typed_passes.py +139 -0
  13. numba_cuda/numba/cuda/cudadecl.py +0 -268
  14. numba_cuda/numba/cuda/cudadrv/devicearray.py +3 -0
  15. numba_cuda/numba/cuda/cudadrv/driver.py +2 -1
  16. numba_cuda/numba/cuda/cudadrv/nvvm.py +1 -1
  17. numba_cuda/numba/cuda/cudaimpl.py +4 -178
  18. numba_cuda/numba/cuda/debuginfo.py +469 -3
  19. numba_cuda/numba/cuda/device_init.py +0 -1
  20. numba_cuda/numba/cuda/dispatcher.py +310 -11
  21. numba_cuda/numba/cuda/extending.py +2 -1
  22. numba_cuda/numba/cuda/fp16.py +348 -0
  23. numba_cuda/numba/cuda/intrinsics.py +1 -1
  24. numba_cuda/numba/cuda/libdeviceimpl.py +2 -1
  25. numba_cuda/numba/cuda/lowering.py +1833 -8
  26. numba_cuda/numba/cuda/mathimpl.py +2 -90
  27. numba_cuda/numba/cuda/nvvmutils.py +2 -1
  28. numba_cuda/numba/cuda/printimpl.py +2 -1
  29. numba_cuda/numba/cuda/serialize.py +264 -0
  30. numba_cuda/numba/cuda/simulator/__init__.py +2 -0
  31. numba_cuda/numba/cuda/simulator/dispatcher.py +7 -0
  32. numba_cuda/numba/cuda/stubs.py +0 -308
  33. numba_cuda/numba/cuda/target.py +13 -5
  34. numba_cuda/numba/cuda/testing.py +156 -5
  35. numba_cuda/numba/cuda/tests/complex_usecases.py +113 -0
  36. numba_cuda/numba/cuda/tests/core/serialize_usecases.py +110 -0
  37. numba_cuda/numba/cuda/tests/core/test_serialize.py +359 -0
  38. numba_cuda/numba/cuda/tests/cudadrv/test_context_stack.py +10 -4
  39. numba_cuda/numba/cuda/tests/cudadrv/test_cuda_ndarray.py +33 -0
  40. numba_cuda/numba/cuda/tests/cudadrv/test_runtime.py +2 -2
  41. numba_cuda/numba/cuda/tests/cudadrv/test_streams.py +1 -0
  42. numba_cuda/numba/cuda/tests/cudapy/extensions_usecases.py +1 -1
  43. numba_cuda/numba/cuda/tests/cudapy/test_caching.py +5 -10
  44. numba_cuda/numba/cuda/tests/cudapy/test_compiler.py +15 -0
  45. numba_cuda/numba/cuda/tests/cudapy/test_complex.py +1 -1
  46. numba_cuda/numba/cuda/tests/cudapy/test_debuginfo.py +381 -0
  47. numba_cuda/numba/cuda/tests/cudapy/test_enums.py +1 -1
  48. numba_cuda/numba/cuda/tests/cudapy/test_extending.py +1 -1
  49. numba_cuda/numba/cuda/tests/cudapy/test_inspect.py +108 -24
  50. numba_cuda/numba/cuda/tests/cudapy/test_intrinsics.py +37 -23
  51. numba_cuda/numba/cuda/tests/cudapy/test_operator.py +43 -27
  52. numba_cuda/numba/cuda/tests/cudapy/test_ufuncs.py +26 -9
  53. numba_cuda/numba/cuda/tests/cudapy/test_warning.py +27 -2
  54. numba_cuda/numba/cuda/tests/enum_usecases.py +56 -0
  55. numba_cuda/numba/cuda/tests/nocuda/test_library_lookup.py +1 -2
  56. numba_cuda/numba/cuda/tests/nocuda/test_nvvm.py +1 -1
  57. numba_cuda/numba/cuda/utils.py +785 -0
  58. numba_cuda/numba/cuda/vector_types.py +1 -1
  59. {numba_cuda-0.17.0.dist-info → numba_cuda-0.18.1.dist-info}/METADATA +18 -4
  60. {numba_cuda-0.17.0.dist-info → numba_cuda-0.18.1.dist-info}/RECORD +63 -50
  61. numba_cuda/numba/cuda/cpp_function_wrappers.cu +0 -46
  62. {numba_cuda-0.17.0.dist-info → numba_cuda-0.18.1.dist-info}/WHEEL +0 -0
  63. {numba_cuda-0.17.0.dist-info → numba_cuda-0.18.1.dist-info}/licenses/LICENSE +0 -0
  64. {numba_cuda-0.17.0.dist-info → numba_cuda-0.18.1.dist-info}/top_level.txt +0 -0
@@ -0,0 +1,110 @@
1
+ """
2
+ Separate module with function samples for serialization tests,
3
+ to avoid issues with __main__.
4
+ """
5
+
6
+ import math
7
+ from math import sqrt
8
+ import numpy as np
9
+ import numpy.random as nprand
10
+
11
+ from numba import jit
12
+ from numba.core import types
13
+
14
+
15
+ @jit((types.int32, types.int32))
16
+ def add_with_sig(a, b):
17
+ return a + b
18
+
19
+
20
+ @jit
21
+ def add_without_sig(a, b):
22
+ return a + b
23
+
24
+
25
+ @jit(nopython=True)
26
+ def add_nopython(a, b):
27
+ return a + b
28
+
29
+
30
+ @jit(nopython=True)
31
+ def add_nopython_fail(a, b):
32
+ object()
33
+ return a + b
34
+
35
+
36
+ def closure(a):
37
+ @jit(nopython=True)
38
+ def inner(b, c):
39
+ return a + b + c
40
+
41
+ return inner
42
+
43
+
44
+ K = 3.0
45
+
46
+
47
+ def closure_with_globals(x, **jit_args):
48
+ @jit(**jit_args)
49
+ def inner(y):
50
+ # Exercise a builtin function and a module-level constant
51
+ k = max(K, K + 1)
52
+ # Exercise two functions from another module, one accessed with
53
+ # dotted notation, one imported explicitly.
54
+ return math.hypot(x, y) + sqrt(k)
55
+
56
+ return inner
57
+
58
+
59
+ @jit(nopython=True)
60
+ def other_function(x, y):
61
+ return math.hypot(x, y)
62
+
63
+
64
+ @jit(forceobj=True)
65
+ def get_global_objmode(x):
66
+ return K * x
67
+
68
+
69
+ @jit(nopython=True)
70
+ def get_renamed_module(x):
71
+ nprand.seed(42)
72
+ return np.cos(x), nprand.random()
73
+
74
+
75
+ def closure_calling_other_function(x):
76
+ @jit(nopython=True)
77
+ def inner(y, z):
78
+ return other_function(x, y) + z
79
+
80
+ return inner
81
+
82
+
83
+ def closure_calling_other_closure(x):
84
+ @jit(nopython=True)
85
+ def other_inner(y):
86
+ return math.hypot(x, y)
87
+
88
+ @jit(nopython=True)
89
+ def inner(y):
90
+ return other_inner(y) + x
91
+
92
+ return inner
93
+
94
+
95
+ # A dynamic function calling a builtin function
96
+ def _get_dyn_func(**jit_args):
97
+ code = """
98
+ def dyn_func(x):
99
+ res = 0
100
+ for i in range(x):
101
+ res += x
102
+ return res
103
+ """
104
+ ns = {}
105
+ exec(code.strip(), ns)
106
+ return jit(**jit_args)(ns["dyn_func"])
107
+
108
+
109
+ dyn_func = _get_dyn_func(nopython=True)
110
+ dyn_func_objmode = _get_dyn_func(forceobj=True)
@@ -0,0 +1,359 @@
1
+ import contextlib
2
+ import gc
3
+ import pickle
4
+ import runpy
5
+ import subprocess
6
+ import sys
7
+ import unittest
8
+ from multiprocessing import get_context
9
+
10
+ import numba
11
+ from numba.core.errors import TypingError
12
+ from numba.tests.support import TestCase
13
+ from numba.core.target_extension import resolve_dispatcher_from_str
14
+ from numba.cloudpickle import dumps, loads
15
+
16
+
17
+ class TestDispatcherPickling(TestCase):
18
+ def run_with_protocols(self, meth, *args, **kwargs):
19
+ for proto in range(pickle.HIGHEST_PROTOCOL + 1):
20
+ meth(proto, *args, **kwargs)
21
+
22
+ @contextlib.contextmanager
23
+ def simulate_fresh_target(self):
24
+ hwstr = "cpu"
25
+ dispatcher_cls = resolve_dispatcher_from_str(hwstr)
26
+ old_descr = dispatcher_cls.targetdescr
27
+ # Simulate fresh targetdescr
28
+ dispatcher_cls.targetdescr = type(dispatcher_cls.targetdescr)(hwstr)
29
+ try:
30
+ yield
31
+ finally:
32
+ # Be sure to reinstantiate old descriptor, otherwise other
33
+ # objects may be out of sync.
34
+ dispatcher_cls.targetdescr = old_descr
35
+
36
+ def check_call(self, proto, func, expected_result, args):
37
+ def check_result(func):
38
+ if isinstance(expected_result, type) and issubclass(
39
+ expected_result, Exception
40
+ ):
41
+ self.assertRaises(expected_result, func, *args)
42
+ else:
43
+ self.assertPreciseEqual(func(*args), expected_result)
44
+
45
+ # Control
46
+ check_result(func)
47
+ pickled = pickle.dumps(func, proto)
48
+ with self.simulate_fresh_target():
49
+ new_func = pickle.loads(pickled)
50
+ check_result(new_func)
51
+
52
+ def test_call_with_sig(self):
53
+ from .serialize_usecases import add_with_sig
54
+
55
+ self.run_with_protocols(self.check_call, add_with_sig, 5, (1, 4))
56
+ # Compilation has been disabled => float inputs will be coerced to int
57
+ self.run_with_protocols(self.check_call, add_with_sig, 5, (1.2, 4.2))
58
+
59
+ def test_call_without_sig(self):
60
+ from .serialize_usecases import add_without_sig
61
+
62
+ self.run_with_protocols(self.check_call, add_without_sig, 5, (1, 4))
63
+ self.run_with_protocols(
64
+ self.check_call, add_without_sig, 5.5, (1.2, 4.3)
65
+ )
66
+ # Object mode is enabled
67
+ self.run_with_protocols(
68
+ self.check_call, add_without_sig, "abc", ("a", "bc")
69
+ )
70
+
71
+ def test_call_nopython(self):
72
+ from .serialize_usecases import add_nopython
73
+
74
+ self.run_with_protocols(self.check_call, add_nopython, 5.5, (1.2, 4.3))
75
+ # Object mode is disabled
76
+ self.run_with_protocols(
77
+ self.check_call, add_nopython, TypingError, (object(), object())
78
+ )
79
+
80
+ def test_call_nopython_fail(self):
81
+ from .serialize_usecases import add_nopython_fail
82
+
83
+ # Compilation fails
84
+ self.run_with_protocols(
85
+ self.check_call, add_nopython_fail, TypingError, (1, 2)
86
+ )
87
+
88
+ def test_call_objmode_with_global(self):
89
+ from .serialize_usecases import get_global_objmode
90
+
91
+ self.run_with_protocols(
92
+ self.check_call, get_global_objmode, 7.5, (2.5,)
93
+ )
94
+
95
+ def test_call_closure(self):
96
+ from .serialize_usecases import closure
97
+
98
+ inner = closure(1)
99
+ self.run_with_protocols(self.check_call, inner, 6, (2, 3))
100
+
101
+ def check_call_closure_with_globals(self, **jit_args):
102
+ from .serialize_usecases import closure_with_globals
103
+
104
+ inner = closure_with_globals(3.0, **jit_args)
105
+ self.run_with_protocols(self.check_call, inner, 7.0, (4.0,))
106
+
107
+ def test_call_closure_with_globals_nopython(self):
108
+ self.check_call_closure_with_globals(nopython=True)
109
+
110
+ def test_call_closure_with_globals_objmode(self):
111
+ self.check_call_closure_with_globals(forceobj=True)
112
+
113
+ def test_call_closure_calling_other_function(self):
114
+ from .serialize_usecases import closure_calling_other_function
115
+
116
+ inner = closure_calling_other_function(3.0)
117
+ self.run_with_protocols(self.check_call, inner, 11.0, (4.0, 6.0))
118
+
119
+ def test_call_closure_calling_other_closure(self):
120
+ from .serialize_usecases import closure_calling_other_closure
121
+
122
+ inner = closure_calling_other_closure(3.0)
123
+ self.run_with_protocols(self.check_call, inner, 8.0, (4.0,))
124
+
125
+ def test_call_dyn_func(self):
126
+ from .serialize_usecases import dyn_func
127
+
128
+ # Check serializing a dynamically-created function
129
+ self.run_with_protocols(self.check_call, dyn_func, 36, (6,))
130
+
131
+ def test_call_dyn_func_objmode(self):
132
+ from .serialize_usecases import dyn_func_objmode
133
+
134
+ # Same with an object mode function
135
+ self.run_with_protocols(self.check_call, dyn_func_objmode, 36, (6,))
136
+
137
+ def test_renamed_module(self):
138
+ from .serialize_usecases import get_renamed_module
139
+
140
+ # Issue #1559: using a renamed module (e.g. `import numpy as np`)
141
+ # should not fail serializing
142
+ expected = get_renamed_module(0.0)
143
+ self.run_with_protocols(
144
+ self.check_call, get_renamed_module, expected, (0.0,)
145
+ )
146
+
147
+ def test_other_process(self):
148
+ """
149
+ Check that reconstructing doesn't depend on resources already
150
+ instantiated in the original process.
151
+ """
152
+ from .serialize_usecases import closure_calling_other_closure
153
+
154
+ func = closure_calling_other_closure(3.0)
155
+ pickled = pickle.dumps(func)
156
+ code = """if 1:
157
+ import pickle
158
+
159
+ data = {pickled!r}
160
+ func = pickle.loads(data)
161
+ res = func(4.0)
162
+ assert res == 8.0, res
163
+ """.format(**locals())
164
+ subprocess.check_call([sys.executable, "-c", code])
165
+
166
+ def test_reuse(self):
167
+ """
168
+ Check that deserializing the same function multiple times re-uses
169
+ the same dispatcher object.
170
+
171
+ Note that "same function" is intentionally under-specified.
172
+ """
173
+ from .serialize_usecases import closure
174
+
175
+ func = closure(5)
176
+ pickled = pickle.dumps(func)
177
+ func2 = closure(6)
178
+ pickled2 = pickle.dumps(func2)
179
+
180
+ f = pickle.loads(pickled)
181
+ g = pickle.loads(pickled)
182
+ h = pickle.loads(pickled2)
183
+ self.assertIs(f, g)
184
+ self.assertEqual(f(2, 3), 10)
185
+ g.disable_compile()
186
+ self.assertEqual(g(2, 4), 11)
187
+
188
+ self.assertIsNot(f, h)
189
+ self.assertEqual(h(2, 3), 11)
190
+
191
+ # Now make sure the original object doesn't exist when deserializing
192
+ func = closure(7)
193
+ func(42, 43)
194
+ pickled = pickle.dumps(func)
195
+ del func
196
+ gc.collect()
197
+
198
+ f = pickle.loads(pickled)
199
+ g = pickle.loads(pickled)
200
+ self.assertIs(f, g)
201
+ self.assertEqual(f(2, 3), 12)
202
+ g.disable_compile()
203
+ self.assertEqual(g(2, 4), 13)
204
+
205
+ def test_imp_deprecation(self):
206
+ """
207
+ The imp module was deprecated in v3.4 in favour of importlib
208
+ """
209
+ code = """if 1:
210
+ import pickle
211
+ import warnings
212
+ with warnings.catch_warnings(record=True) as w:
213
+ warnings.simplefilter('always', DeprecationWarning)
214
+ from numba import njit
215
+ @njit
216
+ def foo(x):
217
+ return x + 1
218
+ foo(1)
219
+ serialized_foo = pickle.dumps(foo)
220
+ for x in w:
221
+ if 'serialize.py' in x.filename:
222
+ assert "the imp module is deprecated" not in x.msg
223
+ """
224
+ subprocess.check_call([sys.executable, "-c", code])
225
+
226
+
227
+ class TestSerializationMisc(TestCase):
228
+ def test_numba_unpickle(self):
229
+ # Test that _numba_unpickle is memorizing its output
230
+ from numba.core.serialize import _numba_unpickle
231
+
232
+ random_obj = object()
233
+ bytebuf = pickle.dumps(random_obj)
234
+ hashed = hash(random_obj)
235
+
236
+ got1 = _numba_unpickle(id(random_obj), bytebuf, hashed)
237
+ # not the original object
238
+ self.assertIsNot(got1, random_obj)
239
+ got2 = _numba_unpickle(id(random_obj), bytebuf, hashed)
240
+ # unpickled results are the same objects
241
+ self.assertIs(got1, got2)
242
+
243
+
244
+ class TestCloudPickleIssues(TestCase):
245
+ """This test case includes issues specific to the cloudpickle implementation."""
246
+
247
+ _numba_parallel_test_ = False
248
+
249
+ def test_dynamic_class_reset_on_unpickle(self):
250
+ # a dynamic class
251
+ class Klass:
252
+ classvar = None
253
+
254
+ def mutator():
255
+ Klass.classvar = 100
256
+
257
+ def check():
258
+ self.assertEqual(Klass.classvar, 100)
259
+
260
+ saved = dumps(Klass)
261
+ mutator()
262
+ check()
263
+ loads(saved)
264
+ # Without the patch, each `loads(saved)` will reset `Klass.classvar`
265
+ check()
266
+ loads(saved)
267
+ check()
268
+
269
+ @unittest.skipIf(
270
+ __name__ == "__main__", "Test cannot run as when module is __main__"
271
+ )
272
+ def test_main_class_reset_on_unpickle(self):
273
+ mp = get_context("spawn")
274
+ proc = mp.Process(target=check_main_class_reset_on_unpickle)
275
+ proc.start()
276
+ proc.join(timeout=60)
277
+ self.assertEqual(proc.exitcode, 0)
278
+
279
+ def test_dynamic_class_reset_on_unpickle_new_proc(self):
280
+ # a dynamic class
281
+ class Klass:
282
+ classvar = None
283
+
284
+ # serialize Klass in this process
285
+ saved = dumps(Klass)
286
+
287
+ # Check the reset problem in a new process
288
+ mp = get_context("spawn")
289
+ proc = mp.Process(
290
+ target=check_unpickle_dyn_class_new_proc, args=(saved,)
291
+ )
292
+ proc.start()
293
+ proc.join(timeout=60)
294
+ self.assertEqual(proc.exitcode, 0)
295
+
296
+ def test_dynamic_class_issue_7356(self):
297
+ cfunc = numba.njit(issue_7356)
298
+ self.assertEqual(cfunc(), (100, 100))
299
+
300
+
301
+ class DynClass(object):
302
+ # For testing issue #7356
303
+ a = None
304
+
305
+
306
+ def issue_7356():
307
+ with numba.objmode(before="intp"):
308
+ DynClass.a = 100
309
+ before = DynClass.a
310
+ with numba.objmode(after="intp"):
311
+ after = DynClass.a
312
+ return before, after
313
+
314
+
315
+ def check_main_class_reset_on_unpickle():
316
+ # Load module and get its global dictionary
317
+ glbs = runpy.run_module(
318
+ "numba.tests.cloudpickle_main_class",
319
+ run_name="__main__",
320
+ )
321
+ # Get the Klass and check it is from __main__
322
+ Klass = glbs["Klass"]
323
+ assert Klass.__module__ == "__main__"
324
+ assert Klass.classvar != 100
325
+ saved = dumps(Klass)
326
+ # mutate
327
+ Klass.classvar = 100
328
+ # check
329
+ _check_dyn_class(Klass, saved)
330
+
331
+
332
+ def check_unpickle_dyn_class_new_proc(saved):
333
+ Klass = loads(saved)
334
+ assert Klass.classvar != 100
335
+ # mutate
336
+ Klass.classvar = 100
337
+ # check
338
+ _check_dyn_class(Klass, saved)
339
+
340
+
341
+ def _check_dyn_class(Klass, saved):
342
+ def check():
343
+ if Klass.classvar != 100:
344
+ raise AssertionError("Check failed. Klass reset.")
345
+
346
+ check()
347
+ loaded = loads(saved)
348
+ if loaded is not Klass:
349
+ raise AssertionError("Expected reuse")
350
+ # Without the patch, each `loads(saved)` will reset `Klass.classvar`
351
+ check()
352
+ loaded = loads(saved)
353
+ if loaded is not Klass:
354
+ raise AssertionError("Expected reuse")
355
+ check()
356
+
357
+
358
+ if __name__ == "__main__":
359
+ unittest.main()
@@ -128,11 +128,17 @@ class Test3rdPartyContext(CUDATestCase):
128
128
  "Error getting CUDA driver version",
129
129
  )
130
130
 
131
- # CUDA 13's cuCtxCreate has an optional parameter prepended
132
- if version >= 13000:
133
- args = (None, flags, dev)
134
- else:
131
+ # CUDA 13's cuCtxCreate has an optional parameter prepended. The
132
+ # version of cuCtxCreate in use depends on the cuda.bindings major
133
+ # version rather than the installed driver version on the machine
134
+ # we're running on.
135
+ from cuda import bindings
136
+
137
+ bindings_version = int(bindings.__version__.split(".")[0])
138
+ if bindings_version in (11, 12):
135
139
  args = (flags, dev)
140
+ else:
141
+ args = (None, flags, dev)
136
142
 
137
143
  hctx = the_driver.cuCtxCreate(*args)
138
144
  else:
@@ -45,6 +45,39 @@ class TestCudaNDArray(CUDATestCase):
45
45
  self.assertEqual(ary.shape, dary.shape)
46
46
  self.assertEqual(ary.shape[1:], dary.shape[1:])
47
47
 
48
+ def test_device_array_float(self):
49
+ # Ensure that a float shape raises an TypeError
50
+ with self.assertRaises(TypeError):
51
+ cuda.device_array(shape=1.23)
52
+ with self.assertRaises(TypeError):
53
+ cuda.device_array(shape=np.float64(1.23))
54
+ with self.assertRaises(TypeError):
55
+ cuda.device_array(shape=np.array(1.23))
56
+
57
+ def test_device_array_float_vectors(self):
58
+ # Ensure that np.array, list or tuple inputs with
59
+ # non-ints raise an TypeError
60
+ with self.assertRaises(TypeError):
61
+ cuda.device_array(shape=np.array([1.1]))
62
+ with self.assertRaises(TypeError):
63
+ cuda.device_array(shape=[1.1])
64
+ with self.assertRaises(TypeError):
65
+ cuda.device_array(shape=(1.1,))
66
+ with self.assertRaises(TypeError):
67
+ cuda.device_array(shape=np.array([1.1, 2.2]))
68
+ with self.assertRaises(TypeError):
69
+ cuda.device_array(shape=[1.1, 2.2])
70
+ with self.assertRaises(TypeError):
71
+ cuda.device_array(shape=(1.1, 2.2))
72
+
73
+ def test_device_array_vectors(self):
74
+ # Ensure that np.array or list of inputs with
75
+ # ints still work
76
+ dary = cuda.device_array(shape=np.array([10, 10]), dtype=np.bool)
77
+ self.assertEqual(dary.shape, (10, 10))
78
+ dary = cuda.device_array(shape=[10, 10], dtype=np.bool)
79
+ self.assertEqual(dary.shape, (10, 10))
80
+
48
81
  def test_devicearray(self):
49
82
  array = np.arange(100, dtype=np.int32)
50
83
  original = array.copy()
@@ -1,6 +1,6 @@
1
1
  import multiprocessing
2
2
  import os
3
- from numba.cuda.testing import unittest, SerialMixin
3
+ from numba.cuda.testing import unittest
4
4
 
5
5
 
6
6
  def set_visible_devices_and_check(q):
@@ -15,7 +15,7 @@ def set_visible_devices_and_check(q):
15
15
  q.put(-1)
16
16
 
17
17
 
18
- class TestVisibleDevices(unittest.TestCase, SerialMixin):
18
+ class TestVisibleDevices(unittest.TestCase):
19
19
  def test_visible_devices_set_after_import(self):
20
20
  # See Issue #6149. This test checks that we can set
21
21
  # CUDA_VISIBLE_DEVICES after importing Numba and have the value
@@ -19,6 +19,7 @@ def with_asyncio_loop(f):
19
19
  return runner
20
20
 
21
21
 
22
+ @unittest.skip("Disabled temporarily due to Issue #317")
22
23
  @skip_on_cudasim("CUDA Driver API unsupported in the simulator")
23
24
  class TestCudaStream(CUDATestCase):
24
25
  def test_add_callback(self):
@@ -28,7 +28,7 @@ if not config.ENABLE_CUDASIM:
28
28
  make_attribute_wrapper,
29
29
  )
30
30
  from numba.cuda.cudaimpl import lower
31
- from numba.core import cgutils
31
+ from numba.cuda import cgutils
32
32
 
33
33
  @typeof_impl.register(TestStruct)
34
34
  def typeof_teststruct(val, c):
@@ -14,7 +14,6 @@ from numba.cuda.testing import (
14
14
  skip_if_mvc_enabled,
15
15
  test_data_dir,
16
16
  )
17
- from numba.tests.support import SerialMixin
18
17
  from numba.tests.test_caching import (
19
18
  DispatcherCacheUsecasesTest,
20
19
  skip_bad_access,
@@ -22,7 +21,7 @@ from numba.tests.test_caching import (
22
21
 
23
22
 
24
23
  @skip_on_cudasim("Simulator does not implement caching")
25
- class CUDACachingTest(SerialMixin, DispatcherCacheUsecasesTest):
24
+ class CUDACachingTest(DispatcherCacheUsecasesTest):
26
25
  here = os.path.dirname(__file__)
27
26
  usecases_file = os.path.join(here, "cache_usecases.py")
28
27
  modname = "cuda_caching_test_fodder"
@@ -225,7 +224,7 @@ class CUDACachingTest(SerialMixin, DispatcherCacheUsecasesTest):
225
224
 
226
225
 
227
226
  @skip_on_cudasim("Simulator does not implement caching")
228
- class CUDACooperativeGroupTest(SerialMixin, DispatcherCacheUsecasesTest):
227
+ class CUDACooperativeGroupTest(DispatcherCacheUsecasesTest):
229
228
  # See Issue #9432: https://github.com/numba/numba/issues/9432
230
229
  # If a cached function using CG sync was the first thing to compile,
231
230
  # the compile would fail.
@@ -259,7 +258,7 @@ class CUDACooperativeGroupTest(SerialMixin, DispatcherCacheUsecasesTest):
259
258
 
260
259
 
261
260
  @skip_on_cudasim("Simulator does not implement caching")
262
- class CUDAAndCPUCachingTest(SerialMixin, DispatcherCacheUsecasesTest):
261
+ class CUDAAndCPUCachingTest(DispatcherCacheUsecasesTest):
263
262
  here = os.path.dirname(__file__)
264
263
  usecases_file = os.path.join(here, "cache_with_cpu_usecases.py")
265
264
  modname = "cuda_and_cpu_caching_test_fodder"
@@ -350,7 +349,7 @@ def get_different_cc_gpus():
350
349
 
351
350
 
352
351
  @skip_on_cudasim("Simulator does not implement caching")
353
- class TestMultiCCCaching(SerialMixin, DispatcherCacheUsecasesTest):
352
+ class TestMultiCCCaching(DispatcherCacheUsecasesTest):
354
353
  here = os.path.dirname(__file__)
355
354
  usecases_file = os.path.join(here, "cache_usecases.py")
356
355
  modname = "cuda_multi_cc_caching_test_fodder"
@@ -484,11 +483,7 @@ def child_initializer():
484
483
 
485
484
 
486
485
  @skip_on_cudasim("Simulator does not implement caching")
487
- class TestMultiprocessCache(SerialMixin, DispatcherCacheUsecasesTest):
488
- # Nested multiprocessing.Pool raises AssertionError:
489
- # "daemonic processes are not allowed to have children"
490
- _numba_parallel_test_ = False
491
-
486
+ class TestMultiprocessCache(DispatcherCacheUsecasesTest):
492
487
  here = os.path.dirname(__file__)
493
488
  usecases_file = os.path.join(here, "cache_usecases.py")
494
489
  modname = "cuda_mp_caching_test_fodder"
@@ -252,6 +252,21 @@ class TestCompile(unittest.TestCase):
252
252
  output=illegal_output,
253
253
  )
254
254
 
255
+ def test_functioncompiler_locals(self):
256
+ # Tests against regression fixed in:
257
+ # https://github.com/NVIDIA/numba-cuda/pull/381
258
+ #
259
+ # "AttributeError: '_FunctionCompiler' object has no attribute
260
+ # 'locals'"
261
+ cond = None
262
+
263
+ @cuda.jit("void(float32[::1])")
264
+ def f(b_arg):
265
+ b_smem = cuda.shared.array(shape=(1,), dtype=float32)
266
+
267
+ if cond:
268
+ b_smem[0] = b_arg[0]
269
+
255
270
 
256
271
  @skip_on_cudasim("Compilation unsupported in the simulator")
257
272
  class TestCompileForCurrentDevice(CUDATestCase):
@@ -6,7 +6,7 @@ import numpy as np
6
6
  from numba.cuda.testing import unittest, CUDATestCase
7
7
  from numba.core import types
8
8
  from numba import cuda
9
- from numba.tests.complex_usecases import (
9
+ from numba.cuda.tests.complex_usecases import (
10
10
  real_usecase,
11
11
  imag_usecase,
12
12
  conjugate_usecase,