numba-cuda 0.10.1__py3-none-any.whl → 0.12.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (45)
  1. numba_cuda/VERSION +1 -1
  2. numba_cuda/numba/cuda/{cuda_bf16.py → _internal/cuda_bf16.py} +1 -1
  3. numba_cuda/numba/cuda/api.py +13 -0
  4. numba_cuda/numba/cuda/bf16.py +112 -0
  5. numba_cuda/numba/cuda/cg.py +2 -0
  6. numba_cuda/numba/cuda/codegen.py +77 -2
  7. numba_cuda/numba/cuda/compiler.py +22 -16
  8. numba_cuda/numba/cuda/cudadecl.py +21 -6
  9. numba_cuda/numba/cuda/cudadrv/driver.py +107 -20
  10. numba_cuda/numba/cuda/cudadrv/linkable_code.py +10 -2
  11. numba_cuda/numba/cuda/cudadrv/nvrtc.py +23 -1
  12. numba_cuda/numba/cuda/cudaimpl.py +103 -11
  13. numba_cuda/numba/cuda/debuginfo.py +27 -0
  14. numba_cuda/numba/cuda/decorators.py +7 -2
  15. numba_cuda/numba/cuda/dispatcher.py +25 -65
  16. numba_cuda/numba/cuda/runtime/nrt.cu +2 -17
  17. numba_cuda/numba/cuda/runtime/nrt.cuh +41 -0
  18. numba_cuda/numba/cuda/runtime/nrt.py +13 -1
  19. numba_cuda/numba/cuda/stubs.py +23 -11
  20. numba_cuda/numba/cuda/target.py +10 -1
  21. numba_cuda/numba/cuda/tests/cudapy/cache_usecases.py +0 -12
  22. numba_cuda/numba/cuda/tests/cudapy/cg_cache_usecases.py +33 -0
  23. numba_cuda/numba/cuda/tests/cudapy/test_array_alignment.py +236 -0
  24. numba_cuda/numba/cuda/tests/cudapy/test_bfloat16.py +55 -0
  25. numba_cuda/numba/cuda/tests/cudapy/test_bfloat16_bindings.py +49 -23
  26. numba_cuda/numba/cuda/tests/cudapy/test_caching.py +34 -51
  27. numba_cuda/numba/cuda/tests/cudapy/test_cooperative_groups.py +34 -0
  28. numba_cuda/numba/cuda/tests/cudapy/test_debuginfo.py +17 -0
  29. numba_cuda/numba/cuda/tests/cudapy/test_extending.py +140 -0
  30. numba_cuda/numba/cuda/tests/data/cta_barrier.cu +23 -0
  31. numba_cuda/numba/cuda/tests/data/include/add.cuh +3 -0
  32. numba_cuda/numba/cuda/tests/doc_examples/ffi/include/mul.cuh +3 -0
  33. numba_cuda/numba/cuda/tests/doc_examples/ffi/saxpy.cu +9 -0
  34. numba_cuda/numba/cuda/tests/doc_examples/test_ffi.py +48 -1
  35. numba_cuda/numba/cuda/tests/nrt/test_nrt.py +122 -3
  36. numba_cuda/numba/cuda/tests/test_binary_generation/Makefile +11 -0
  37. numba_cuda/numba/cuda/tests/test_binary_generation/generate_raw_ltoir.py +5 -2
  38. numba_cuda/numba/cuda/tests/test_binary_generation/nrt_extern.cu +7 -0
  39. numba_cuda/numba/cuda/tests/test_binary_generation/test_device_functions.cu +4 -0
  40. numba_cuda/numba/cuda/utils.py +7 -0
  41. {numba_cuda-0.10.1.dist-info → numba_cuda-0.12.1.dist-info}/METADATA +1 -1
  42. {numba_cuda-0.10.1.dist-info → numba_cuda-0.12.1.dist-info}/RECORD +45 -35
  43. {numba_cuda-0.10.1.dist-info → numba_cuda-0.12.1.dist-info}/WHEEL +1 -1
  44. {numba_cuda-0.10.1.dist-info → numba_cuda-0.12.1.dist-info}/licenses/LICENSE +0 -0
  45. {numba_cuda-0.10.1.dist-info → numba_cuda-0.12.1.dist-info}/top_level.txt +0 -0
numba_cuda/numba/cuda/tests/cudapy/cg_cache_usecases.py
@@ -0,0 +1,33 @@
+ from numba import cuda
+ from numba.cuda.testing import CUDATestCase
+ import sys
+
+ from numba.cuda.tests.cudapy.cache_usecases import CUDAUseCase
+
+
+ # Usecase with cooperative groups
+
+
+ @cuda.jit(cache=True)
+ def cg_usecase_kernel(r, x):
+     grid = cuda.cg.this_grid()
+     grid.sync()
+
+
+ cg_usecase = CUDAUseCase(cg_usecase_kernel)
+
+
+ class _TestModule(CUDATestCase):
+     """
+     Tests for functionality of this module's functions.
+     Note this does not define any "test_*" method, instead check_module()
+     should be called by hand.
+     """
+
+     def check_module(self, mod):
+         mod.cg_usecase(0)
+
+
+ def self_test():
+     mod = sys.modules[__name__]
+     _TestModule().check_module(mod)
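
The new cg_cache_usecases.py module is the fixture imported by the cooperative-group caching tests later in this diff. As a rough usage sketch of the same pattern outside the test harness (not taken from the package; the kernel body and launch configuration are illustrative, and grid sync needs compute capability 6.0+ plus the CUDA device runtime, per the skip decorators used in the caching tests):

    import numpy as np
    from numba import cuda


    @cuda.jit(cache=True)
    def bump(x):
        # Kernels that call grid.sync() are compiled as cooperative kernels;
        # cache=True persists the compiled artifact on disk, as in the fixture.
        grid = cuda.cg.this_grid()
        i = cuda.grid(1)
        if i < x.size:
            x[i] += 1.0
        grid.sync()


    x = np.zeros(64, dtype=np.float64)
    bump[2, 32](x)
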
numba_cuda/numba/cuda/tests/cudapy/test_array_alignment.py
@@ -0,0 +1,236 @@
+ import re
+ import itertools
+ import numpy as np
+ from numba import cuda
+ from numba.core.errors import TypingError
+ from numba.cuda.testing import CUDATestCase
+ import unittest
+
+
+ # Set to true if you want to see dots printed for each subtest.
+ NOISY = False
+
+
+ # In order to verify the alignment of the local and shared memory arrays, we
+ # inspect the LLVM IR of the generated kernel using the following regexes.
+
+ # Shared memory example:
+ # @"_cudapy_smem_38" = addrspace(3) global [1 x i8] undef, align 16
+ SMEM_PATTERN = re.compile(
+     r'^@"_cudapy_smem_\d+".*?align (\d+)',
+     re.MULTILINE,
+ )
+
+ # Local memory example:
+ # %"_cudapy_lmem" = alloca [1 x i8], align 64
+ LMEM_PATTERN = re.compile(
+     r'^\s*%"_cudapy_lmem".*?align (\d+)',
+     re.MULTILINE,
+ )
+
+
+ DTYPES = [np.uint8, np.uint32, np.uint64]
+
+ # Add in some record dtypes with and without alignment.
+ for align in (True, False):
+     DTYPES += [
+         np.dtype(
+             [
+                 ("a", np.uint8),
+                 ("b", np.int32),
+                 ("c", np.float64),
+             ],
+             align=align,
+         ),
+         np.dtype(
+             [
+                 ("a", np.uint32),
+                 ("b", np.uint8),
+             ],
+             align=align,
+         ),
+         np.dtype(
+             [
+                 ("a", np.uint8),
+                 ("b", np.int32),
+                 ("c", np.float64),
+                 ("d", np.complex64),
+                 ("e", (np.uint8, 5)),
+             ],
+             align=align,
+         ),
+     ]
+
+ # N.B. We name the test class TestArrayAddressAlignment to avoid name conflict
+ # with the test_alignment.TestArrayAlignment class.
+
+
+ class TestArrayAddressAlignment(CUDATestCase):
+     """
+     Test cuda.local.array and cuda.shared.array support for an alignment
+     keyword argument.
+     """
+
+     def test_array_alignment_1d(self):
+         shapes = (1, 8, 50)
+         alignments = (None, 16, 64, 256)
+         array_types = [(0, "local"), (1, "shared")]
+         self._do_test(array_types, shapes, DTYPES, alignments)
+
+     def test_array_alignment_2d(self):
+         shapes = ((2, 3),)
+         alignments = (None, 16, 64, 256)
+         array_types = [(0, "local"), (1, "shared")]
+         self._do_test(array_types, shapes, DTYPES, alignments)
+
+     def test_array_alignment_3d(self):
+         shapes = ((2, 3, 4), (1, 4, 5))
+         alignments = (None, 16, 64, 256)
+         array_types = [(0, "local"), (1, "shared")]
+         self._do_test(array_types, shapes, DTYPES, alignments)
+
+     def _do_test(self, array_types, shapes, dtypes, alignments):
+         items = itertools.product(array_types, shapes, dtypes, alignments)
+
+         for (which, array_type), shape, dtype, alignment in items:
+             with self.subTest(
+                 array_type=array_type,
+                 shape=shape,
+                 dtype=dtype,
+                 alignment=alignment,
+             ):
+
+                 @cuda.jit
+                 def f(loc, shrd, which):
+                     i = cuda.grid(1)
+                     if which == 0:
+                         local_array = cuda.local.array(
+                             shape=shape,
+                             dtype=dtype,
+                             alignment=alignment,
+                         )
+                         if i == 0:
+                             loc[0] = local_array.ctypes.data
+                     else:
+                         shared_array = cuda.shared.array(
+                             shape=shape,
+                             dtype=dtype,
+                             alignment=alignment,
+                         )
+                         if i == 0:
+                             shrd[0] = shared_array.ctypes.data
+
+                 loc = np.zeros(1, dtype=np.uint64)
+                 shrd = np.zeros(1, dtype=np.uint64)
+                 f[1, 1](loc, shrd, which)
+
+                 kernel = f.overloads[f.signatures[0]]
+                 llvm_ir = kernel.inspect_llvm()
+
+                 if alignment is None:
+                     if which == 0:
+                         # Local memory shouldn't have any alignment information
+                         # when no alignment is specified.
+                         match = LMEM_PATTERN.findall(llvm_ir)
+                         self.assertEqual(len(match), 0)
+                     else:
+                         # Shared memory should at least have a power-of-two
+                         # alignment when no alignment is specified.
+                         match = SMEM_PATTERN.findall(llvm_ir)
+                         self.assertEqual(len(match), 1)
+
+                         alignment = int(match[0])
+                         # Verify alignment is a power of two.
+                         self.assertTrue(alignment & (alignment - 1) == 0)
+                 else:
+                     # Verify alignment is in the LLVM IR.
+                     if which == 0:
+                         match = LMEM_PATTERN.findall(llvm_ir)
+                         self.assertEqual(len(match), 1)
+                         actual_alignment = int(match[0])
+                         self.assertEqual(alignment, actual_alignment)
+                     else:
+                         match = SMEM_PATTERN.findall(llvm_ir)
+                         self.assertEqual(len(match), 1)
+                         actual_alignment = int(match[0])
+                         self.assertEqual(alignment, actual_alignment)
+
+                     # Also verify that the address of the array is aligned.
+                     # If this fails, there problem is likely with NVVM.
+                     address = loc[0] if which == 0 else shrd[0]
+                     alignment_mod = int(address % alignment)
+                     self.assertEqual(alignment_mod, 0)
+
+                 if NOISY:
+                     print(".", end="", flush=True)
+
+     def test_invalid_aligments(self):
+         shapes = (1, 50)
+         dtypes = (np.uint8, np.uint64)
+         invalid_alignment_values = (-1, 0, 3, 17, 33)
+         invalid_alignment_types = ("1.0", "1", "foo", 1.0, 1.5, 3.2)
+         alignments = invalid_alignment_values + invalid_alignment_types
+         array_types = [(0, "local"), (1, "shared")]
+
+         # Use regex pattern to match error message, handling potential ANSI
+         # color codes which appear on CI.
+         expected_invalid_type_error_regex = (
+             r"RequireLiteralValue:.*alignment must be a constant integer"
+         )
+
+         items = itertools.product(array_types, shapes, dtypes, alignments)
+
+         for (which, array_type), shape, dtype, alignment in items:
+             with self.subTest(
+                 array_type=array_type,
+                 shape=shape,
+                 dtype=dtype,
+                 alignment=alignment,
+             ):
+                 if which == 0:
+
+                     @cuda.jit
+                     def f(dest_array):
+                         i = cuda.grid(1)
+                         local_array = cuda.local.array(
+                             shape=shape,
+                             dtype=dtype,
+                             alignment=alignment,
+                         )
+                         if i == 0:
+                             dest_array[0] = local_array.ctypes.data
+                 else:
+
+                     @cuda.jit
+                     def f(dest_array):
+                         i = cuda.grid(1)
+                         shared_array = cuda.shared.array(
+                             shape=shape,
+                             dtype=dtype,
+                             alignment=alignment,
+                         )
+                         if i == 0:
+                             dest_array[0] = shared_array.ctypes.data
+
+                 array = np.zeros(1, dtype=np.uint64)
+
+                 # The type of error we expect differs between an invalid value
+                 # that is still an int, and an invalid type.
+                 if isinstance(alignment, int):
+                     self.assertRaisesRegex(
+                         ValueError, r"Alignment must be.*", f[1, 1], array
+                     )
+                 else:
+                     self.assertRaisesRegex(
+                         TypingError,
+                         expected_invalid_type_error_regex,
+                         f[1, 1],
+                         array,
+                     )
+
+                 if NOISY:
+                     print(".", end="", flush=True)
+
+
+ if __name__ == "__main__":
+     unittest.main()
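
test_array_alignment.py exercises a new `alignment` keyword accepted by `cuda.local.array` and `cuda.shared.array`. A minimal sketch of the user-facing API (not taken from the package; the shape, dtype and 256-byte value are illustrative), assuming the constraints the tests assert, i.e. a constant power-of-two integer:

    import numpy as np
    from numba import cuda


    @cuda.jit
    def kernel(out):
        # Request 256-byte alignment for the local scratch buffer; shared
        # arrays accept the same keyword.
        scratch = cuda.local.array(shape=16, dtype=np.float32, alignment=256)
        i = cuda.grid(1)
        if i < out.size:
            scratch[0] = out[i] + 1.0
            out[i] = scratch[0]


    out = np.zeros(32, dtype=np.float32)
    kernel[1, 32](out)
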
numba_cuda/numba/cuda/tests/cudapy/test_bfloat16.py
@@ -0,0 +1,55 @@
+ from numba import cuda, float32
+ from numba.cuda.bf16 import bfloat16
+ from numba.cuda.testing import CUDATestCase
+
+ import math
+
+
+ class TestBfloat16HighLevelBindings(CUDATestCase):
+     def skip_unsupported(self):
+         if not cuda.is_bfloat16_supported():
+             self.skipTest(
+                 "bfloat16 requires compute capability 8.0+ and CUDA version>= 12.0"
+             )
+
+     def test_use_type_in_kernel(self):
+         self.skip_unsupported()
+
+         @cuda.jit
+         def kernel():
+             bfloat16(3.14)
+
+         kernel[1, 1]()
+
+     def test_math_bindings(self):
+         self.skip_unsupported()
+         functions = [
+             math.trunc,
+             math.ceil,
+             math.floor,
+             math.sqrt,
+             math.log,
+             math.log10,
+             math.cos,
+             math.sin,
+             math.tanh,
+             math.exp,
+             math.exp2,
+         ]
+
+         for f in functions:
+             with self.subTest(func=f):
+
+                 @cuda.jit
+                 def kernel(arr):
+                     x = bfloat16(3.14)
+                     y = f(x)
+                     arr[0] = float32(y)
+
+                 arr = cuda.device_array((1,), dtype="float32")
+                 kernel[1, 1](arr)
+
+                 if f in (math.exp, math.exp2):
+                     self.assertAlmostEqual(arr[0], f(3.14), delta=1e-1)
+                 else:
+                     self.assertAlmostEqual(arr[0], f(3.14), delta=1e-2)
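
test_bfloat16.py covers the new high-level API: the `bfloat16` type from `numba.cuda.bf16` (file 4 in the list above), `cuda.is_bfloat16_supported()`, and math-module functions applied to bfloat16 values inside kernels. A minimal sketch (not taken from the package), assuming a device with compute capability 8.0+ and CUDA 12 or later:

    import math

    import numpy as np
    from numba import cuda, float32
    from numba.cuda.bf16 import bfloat16


    @cuda.jit
    def kernel(out):
        # Construct a bfloat16, apply a math-module function, cast back.
        x = bfloat16(3.14)
        out[0] = float32(math.sqrt(x))


    if cuda.is_bfloat16_supported():
        out = np.zeros(1, dtype=np.float32)
        kernel[1, 1](out)
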
numba_cuda/numba/cuda/tests/cudapy/test_bfloat16_bindings.py
@@ -5,7 +5,7 @@ import numpy as np
  from numba import int16, int32, int64, uint16, uint32, uint64, float32, float64
  from numba.types import float16

- from numba.cuda.cuda_bf16 import (
+ from numba.cuda._internal.cuda_bf16 import (
      nv_bfloat16,
      htrunc,
      hceil,
@@ -22,21 +22,23 @@ from numba.cuda.cuda_bf16 import (
      hexp,
      hexp2,
      hexp10,
+     htanh,
+     htanh_approx,
  )

- from numba.cuda.cudadrv.runtime import get_version
-
- cuda_version = get_version()
-
  dtypes = [int16, int32, int64, uint16, uint32, uint64, float32]


- @unittest.skipIf(
-     (cuda.get_current_device().compute_capability < (8, 0)),
-     "bfloat16 requires compute capability 8.0+",
- )
  class Bfloat16Test(CUDATestCase):
+     def skip_unsupported(self):
+         if not cuda.is_bfloat16_supported():
+             self.skipTest(
+                 "bfloat16 requires compute capability 8.0+ and CUDA version>= 12.0"
+             )
+
      def test_ctor(self):
+         self.skip_unsupported()
+
          @cuda.jit
          def simple_kernel():
              a = nv_bfloat16(float64(1.0)) # noqa: F841
@@ -47,18 +49,13 @@ class Bfloat16Test(CUDATestCase):
              f = nv_bfloat16(uint16(6)) # noqa: F841
              g = nv_bfloat16(uint32(7)) # noqa: F841
              h = nv_bfloat16(uint64(8)) # noqa: F841
+             i = nv_bfloat16(float16(9)) # noqa: F841

          simple_kernel[1, 1]()

-         if cuda_version >= (12, 0):
-
-             @cuda.jit
-             def simple_kernel_fp16():
-                 i = nv_bfloat16(float16(9)) # noqa: F841
-
-             simple_kernel_fp16[1, 1]()
-
      def test_casts(self):
+         self.skip_unsupported()
+
          @cuda.jit
          def simple_kernel(b, c, d, e, f, g, h):
              a = nv_bfloat16(3.14)
@@ -90,6 +87,7 @@ class Bfloat16Test(CUDATestCase):
          assert h[0] == 3

      def test_ctor_cast_loop(self):
+         self.skip_unsupported()
          for dtype in dtypes:
              with self.subTest(dtype=dtype):

@@ -106,6 +104,8 @@
              assert a[0] == 3

      def test_arithmetic(self):
+         self.skip_unsupported()
+
          @cuda.jit
          def simple_kernel(arith, logic):
              # Binary Arithmetic Operators
@@ -175,6 +175,8 @@
          )

      def test_math_func(self):
+         self.skip_unsupported()
+
          @cuda.jit
          def simple_kernel(a):
              x = nv_bfloat16(3.14)
@@ -191,16 +193,18 @@
              a[9] = float32(hlog10(x))
              a[10] = float32(hcos(x))
              a[11] = float32(hsin(x))
-             a[12] = float32(hexp(x))
-             a[13] = float32(hexp2(x))
-             a[14] = float32(hexp10(x))
+             a[12] = float32(htanh(x))
+             a[13] = float32(htanh_approx(x))
+             a[14] = float32(hexp(x))
+             a[15] = float32(hexp2(x))
+             a[16] = float32(hexp10(x))

-         a = np.zeros(15, dtype=np.float32)
+         a = np.zeros(17, dtype=np.float32)
          simple_kernel[1, 1](a)

          x = 3.14
          np.testing.assert_allclose(
-             a[:12],
+             a[:14],
              [
                  np.trunc(x),
                  np.ceil(x),
@@ -214,15 +218,19 @@
                  np.log10(x),
                  np.cos(x),
                  np.sin(x),
+                 np.tanh(x),
+                 np.tanh(x),
              ],
              atol=1e-2,
          )

          np.testing.assert_allclose(
-             a[12:], [np.exp(x), np.exp2(x), np.power(10, x)], atol=1e2
+             a[14:], [np.exp(x), np.exp2(x), np.power(10, x)], atol=1e2
          )

      def test_check_bfloat16_type(self):
+         self.skip_unsupported()
+
          @cuda.jit
          def kernel(arr):
              x = nv_bfloat16(3.14)
@@ -237,6 +245,8 @@
          np.testing.assert_allclose(arr, [3.14], atol=1e-2)

      def test_use_within_device_func(self):
+         self.skip_unsupported()
+
          @cuda.jit(device=True)
          def add_bf16(a, b):
              return a + b
@@ -252,6 +262,22 @@

          np.testing.assert_allclose(arr, [8], atol=1e-2)

+     def test_use_binding_inside_dfunc(self):
+         @cuda.jit(device=True)
+         def f(arr):
+             pi = nv_bfloat16(3.14)
+             three = htrunc(pi)
+             arr[0] = float32(three)
+
+         @cuda.jit
+         def kernel(arr):
+             f(arr)
+
+         arr = np.zeros(1, np.float32)
+         kernel[1, 1](arr)
+
+         np.testing.assert_allclose(arr, [3], atol=1e-2)
+

  if __name__ == "__main__":
      unittest.main()
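
The hunks above track the move of the low-level bindings from `numba.cuda.cuda_bf16` to `numba.cuda._internal.cuda_bf16` and the new `htanh`/`htanh_approx` intrinsics. A sketch mirroring the test (the `_internal` location signals this is not a public entry point; the high-level `numba.cuda.bf16` module is the intended API):

    import numpy as np
    from numba import cuda, float32
    from numba.cuda._internal.cuda_bf16 import htanh, nv_bfloat16


    @cuda.jit
    def kernel(out):
        # Construct an nv_bfloat16 and call the new htanh intrinsic on it.
        x = nv_bfloat16(3.14)
        out[0] = float32(htanh(x))


    out = np.zeros(1, dtype=np.float32)
    kernel[1, 1](out)
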
numba_cuda/numba/cuda/tests/cudapy/test_caching.py
@@ -1,8 +1,6 @@
  import multiprocessing
  import os
  import shutil
- import subprocess
- import sys
  import unittest
  import warnings

@@ -163,55 +161,6 @@ class CUDACachingTest(SerialMixin, DispatcherCacheUsecasesTest):
          f = mod.renamed_function2
          self.assertPreciseEqual(f(2), 8)

-     @skip_unless_cc_60
-     @skip_if_cudadevrt_missing
-     @skip_if_mvc_enabled("CG not supported with MVC")
-     def test_cache_cg(self):
-         # Functions using cooperative groups should be cacheable. See Issue
-         # #8888: https://github.com/numba/numba/issues/8888
-         self.check_pycache(0)
-         mod = self.import_module()
-         self.check_pycache(0)
-
-         mod.cg_usecase(0)
-         self.check_pycache(2) # 1 index, 1 data
-
-         # Check the code runs ok from another process
-         self.run_in_separate_process()
-
-     @skip_unless_cc_60
-     @skip_if_cudadevrt_missing
-     @skip_if_mvc_enabled("CG not supported with MVC")
-     def test_cache_cg_clean_run(self):
-         # See Issue #9432: https://github.com/numba/numba/issues/9432
-         # If a cached function using CG sync was the first thing to compile,
-         # the compile would fail.
-         self.check_pycache(0)
-
-         # This logic is modelled on run_in_separate_process(), but executes the
-         # CG usecase directly in the subprocess.
-         code = """if 1:
-             import sys
-
-             sys.path.insert(0, %(tempdir)r)
-             mod = __import__(%(modname)r)
-             mod.cg_usecase(0)
-             """ % dict(tempdir=self.tempdir, modname=self.modname)
-
-         popen = subprocess.Popen(
-             [sys.executable, "-c", code],
-             stdout=subprocess.PIPE,
-             stderr=subprocess.PIPE,
-         )
-         out, err = popen.communicate(timeout=60)
-         if popen.returncode != 0:
-             raise AssertionError(
-                 "process failed with code %s: \n"
-                 "stdout follows\n%s\n"
-                 "stderr follows\n%s\n"
-                 % (popen.returncode, out.decode(), err.decode()),
-             )
-
      def _test_pycache_fallback(self):
          """
          With a disabled __pycache__, test there is a working fallback
@@ -275,6 +224,40 @@ class CUDACachingTest(SerialMixin, DispatcherCacheUsecasesTest):
                  pass


+ @skip_on_cudasim("Simulator does not implement caching")
+ class CUDACooperativeGroupTest(SerialMixin, DispatcherCacheUsecasesTest):
+     # See Issue #9432: https://github.com/numba/numba/issues/9432
+     # If a cached function using CG sync was the first thing to compile,
+     # the compile would fail.
+     here = os.path.dirname(__file__)
+     usecases_file = os.path.join(here, "cg_cache_usecases.py")
+     modname = "cuda_cooperative_caching_test_fodder"
+
+     def setUp(self):
+         DispatcherCacheUsecasesTest.setUp(self)
+         CUDATestCase.setUp(self)
+
+     def tearDown(self):
+         CUDATestCase.tearDown(self)
+         DispatcherCacheUsecasesTest.tearDown(self)
+
+     @skip_unless_cc_60
+     @skip_if_cudadevrt_missing
+     @skip_if_mvc_enabled("CG not supported with MVC")
+     def test_cache_cg(self):
+         # Functions using cooperative groups should be cacheable. See Issue
+         # #8888: https://github.com/numba/numba/issues/8888
+         self.check_pycache(0)
+         mod = self.import_module()
+         self.check_pycache(0)
+
+         mod.cg_usecase(0)
+         self.check_pycache(2) # 1 index, 1 data
+
+         # Check the code runs ok from another process
+         self.run_in_separate_process()
+
+
  @skip_on_cudasim("Simulator does not implement caching")
  class CUDAAndCPUCachingTest(SerialMixin, DispatcherCacheUsecasesTest):
      here = os.path.dirname(__file__)
numba_cuda/numba/cuda/tests/cudapy/test_cooperative_groups.py
@@ -1,8 +1,13 @@
  from __future__ import print_function

+ import os
+
+ import cffi
+
  import numpy as np

  from numba import config, cuda, int32
+ from numba.types import CPointer
  from numba.cuda.testing import (
      unittest,
      CUDATestCase,
@@ -11,6 +16,9 @@ from numba.cuda.testing import (
      skip_if_cudadevrt_missing,
      skip_if_mvc_enabled,
  )
+ from numba.core.typing import signature
+
+ ffi = cffi.FFI()


  @cuda.jit
@@ -149,6 +157,32 @@ class TestCudaCooperativeGroups(CUDATestCase):
          self.assertEqual(blocks1d, blocks2d)
          self.assertEqual(blocks1d, blocks3d)

+     @skip_unless_cc_60
+     def test_external_cooperative_func(self):
+         cudapy_test_path = os.path.dirname(__file__)
+         tests_path = os.path.dirname(cudapy_test_path)
+         data_path = os.path.join(tests_path, "data")
+         src = os.path.join(data_path, "cta_barrier.cu")
+
+         sig = signature(
+             CPointer(int32),
+         )
+         cta_barrier = cuda.declare_device(
+             "cta_barrier", sig=sig, link=[src], use_cooperative=True
+         )
+
+         @cuda.jit
+         def kernel():
+             cta_barrier()
+
+         block_size = 32
+         grid_size = 1024
+
+         kernel[grid_size, block_size]()
+
+         overload = kernel.overloads[()]
+         self.assertTrue(overload.cooperative)
+

  if __name__ == "__main__":
      unittest.main()
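
test_cooperative_groups.py adds coverage for a new `use_cooperative` flag to `cuda.declare_device`, which lets an externally linked device function participate in a cooperative launch. A condensed sketch of the same pattern; "my_barrier.cu" and the function name are hypothetical stand-ins for a user-supplied CUDA source like the test's cta_barrier.cu:

    from numba import cuda, int32
    from numba.core.typing import signature
    from numba.types import CPointer

    # Declare an external device function defined in a linked .cu source file.
    sig = signature(CPointer(int32))
    my_barrier = cuda.declare_device(
        "my_barrier", sig=sig, link=["my_barrier.cu"], use_cooperative=True
    )


    @cuda.jit
    def kernel():
        my_barrier()


    kernel[1024, 32]()
    # The compiled overload is marked cooperative, as the test asserts.
    assert kernel.overloads[()].cooperative
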
numba_cuda/numba/cuda/tests/cudapy/test_debuginfo.py
@@ -310,6 +310,23 @@ class TestCudaDebugInfo(CUDATestCase):
          with captured_stdout():
              self._test_kernel_args_types()

+     def test_kernel_args_names(self):
+         sig = (types.int32,)
+
+         @cuda.jit("void(int32)", debug=True, opt=False)
+         def f(x):
+             z = x # noqa: F841
+
+         llvm_ir = f.inspect_llvm(sig)
+
+         # Verify argument name is not prefixed with "arg."
+         pat = r"define void @.*\(i32 %\"x\"\)"
+         match = re.compile(pat).search(llvm_ir)
+         self.assertIsNotNone(match, msg=llvm_ir)
+         pat = r"define void @.*\(i32 %\"arg\.x\"\)"
+         match = re.compile(pat).search(llvm_ir)
+         self.assertIsNone(match, msg=llvm_ir)
+
      def test_llvm_dbg_value(self):
          sig = (types.int32, types.int32)
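
The added debug-info test verifies that, with debug=True, kernel parameters keep their Python names in the generated LLVM IR instead of being prefixed with "arg.". A minimal sketch of inspecting that IR directly, mirroring the test:

    from numba import cuda, types


    @cuda.jit("void(int32)", debug=True, opt=False)
    def f(x):
        z = x  # noqa: F841


    # The parameter should appear as %"x" rather than %"arg.x" in the IR.
    print(f.inspect_llvm((types.int32,)))
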