numba-cuda 0.10.1__py3-none-any.whl → 0.11.0__py3-none-any.whl

This diff compares the contents of two package versions publicly released to a supported registry. It is provided for informational purposes only and reflects the changes between the versions as they appear in their public registry.
@@ -1,27 +1,25 @@
  import numpy as np
  import os
- import re
  import sys
  import ctypes
  import functools
- from collections import defaultdict

- from numba.core import config, ir, serialize, sigutils, types, typing, utils
+ from numba.core import config, serialize, sigutils, types, typing, utils
  from numba.core.caching import Cache, CacheImpl
  from numba.core.compiler_lock import global_compiler_lock
  from numba.core.dispatcher import Dispatcher
  from numba.core.errors import NumbaPerformanceWarning
  from numba.core.typing.typeof import Purpose, typeof
- from numba.core.types.functions import Function
  from numba.cuda.api import get_current_device
  from numba.cuda.args import wrap_arg
  from numba.cuda.compiler import (
      compile_cuda,
      CUDACompiler,
      kernel_fixup,
-     ExternFunction,
  )
+ import re
  from numba.cuda.cudadrv import driver
+ from numba.cuda.cudadrv.linkable_code import LinkableCode
  from numba.cuda.cudadrv.devices import get_context
  from numba.cuda.descriptor import cuda_target
  from numba.cuda.errors import (
@@ -29,7 +27,7 @@ from numba.cuda.errors import (
      normalize_kernel_dimensions,
  )
  from numba.cuda import types as cuda_types
- from numba.cuda.runtime.nrt import rtsys
+ from numba.cuda.runtime.nrt import rtsys, NRT_LIBRARY
  from numba.cuda.locks import module_init_lock

  from numba import cuda
@@ -59,54 +57,6 @@ cuda_fp16_math_funcs = [
  reshape_funcs = ["nocopy_empty_reshape", "numba_attempt_nocopy_reshape"]


- def get_cres_link_objects(cres):
-     """Given a compile result, return a set of all linkable code objects that
-     are required for it to be fully linked."""
-
-     link_objects = set()
-
-     # List of calls into declared device functions
-     device_func_calls = [
-         (name, v)
-         for name, v in cres.fndesc.typemap.items()
-         if (isinstance(v, cuda_types.CUDADispatcher))
-     ]
-
-     # List of tuples with SSA name of calls and corresponding signature
-     call_signatures = [
-         (call.func.name, sig)
-         for call, sig in cres.fndesc.calltypes.items()
-         if (isinstance(call, ir.Expr) and call.op == "call")
-     ]
-
-     # Map SSA names to all invoked signatures
-     call_signature_d = defaultdict(list)
-     for name, sig in call_signatures:
-         call_signature_d[name].append(sig)
-
-     # Add the link objects from the current function's callees
-     for name, v in device_func_calls:
-         for sig in call_signature_d.get(name, []):
-             called_cres = v.dispatcher.overloads[sig.args]
-             called_link_objects = get_cres_link_objects(called_cres)
-             link_objects.update(called_link_objects)
-
-     # From this point onwards, we are only interested in ExternFunction
-     # declarations - these are the calls made directly in this function to
-     # them.
-     for name, v in cres.fndesc.typemap.items():
-         if not isinstance(v, Function):
-             continue
-
-         if not isinstance(v.typing_key, ExternFunction):
-             continue
-
-         for obj in v.typing_key.link:
-             link_objects.add(obj)
-
-     return link_objects
-
-
  class _Kernel(serialize.ReduceMixin):
      """
      CUDA Kernel specialized for a given set of argument types. When called, this
@@ -238,9 +188,6 @@ class _Kernel(serialize.ReduceMixin):

          self.maybe_link_nrt(link, tgt_ctx, asm)

-         for obj in get_cres_link_objects(cres):
-             lib.add_linking_file(obj)
-
          for filepath in link:
              lib.add_linking_file(filepath)

@@ -263,6 +210,13 @@ class _Kernel(serialize.ReduceMixin):
          self.reload_init = []

      def maybe_link_nrt(self, link, tgt_ctx, asm):
+         """
+         Add the NRT source code to the link if the necessary conditions
+         are met: NRT must be enabled for the CUDATargetContext, and either
+         NRT functions must be detected in the kernel asm, or an NRT-enabled
+         LinkableCode object must be passed.
+         """
+
          if not tgt_ctx.enable_nrt:
              return

@@ -272,13 +226,19 @@ class _Kernel(serialize.ReduceMixin):
              + all_nrt
              + r")\s*\([^)]*\)\s*;"
          )
-
+         link_nrt = False
          nrt_in_asm = re.findall(pattern, asm)
-
-         basedir = os.path.dirname(os.path.abspath(__file__))
-         if nrt_in_asm:
-             nrt_path = os.path.join(basedir, "runtime", "nrt.cu")
-             link.append(nrt_path)
+         if len(nrt_in_asm) > 0:
+             link_nrt = True
+         if not link_nrt:
+             for file in link:
+                 if isinstance(file, LinkableCode):
+                     if file.nrt:
+                         link_nrt = True
+                         break
+
+         if link_nrt:
+             link.append(NRT_LIBRARY)

      @property
      def library(self):
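A minimal sketch of the decision the rewritten maybe_link_nrt makes, assuming `asm` holds the kernel's PTX, `link` is the user-supplied list of linkable items, and `pattern` is the NRT-call regex built above; the helper name `needs_nrt` is illustrative, not part of the package:

    import re

    def needs_nrt(asm, link, pattern):
        # NRT is linked when the kernel asm references an NRT function...
        if re.findall(pattern, asm):
            return True
        # ...or when any supplied LinkableCode object is NRT-enabled.
        return any(getattr(item, "nrt", False) for item in link)

Either condition appends the prebuilt NRT_LIBRARY CUSource to the link list, replacing the old behaviour of appending a filesystem path to runtime/nrt.cu.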
@@ -4,30 +4,14 @@
  #include <cuda/atomic>

  #include "memsys.cuh"
+ #include "nrt.cuh"

- typedef void (*NRT_dtor_function)(void* ptr, size_t size, void* info);
- typedef void (*NRT_dealloc_func)(void* ptr, void* dealloc_info);
-
- typedef struct MemInfo NRT_MemInfo;
-
- extern "C" {
- struct MemInfo {
-   cuda::atomic<size_t, cuda::thread_scope_device> refct;
-   NRT_dtor_function dtor;
-   void* dtor_info;
-   void* data;
-   size_t size;
- };
- }

  extern "C" __global__ void NRT_MemSys_set(NRT_MemSys *memsys_ptr)
  {
    TheMSys = memsys_ptr;
  }

- static __device__ void *nrt_allocate_meminfo_and_data_align(size_t size, unsigned align, NRT_MemInfo **mi);
- static __device__ void *nrt_allocate_meminfo_and_data(size_t size, NRT_MemInfo **mi_out);
- extern "C" __device__ void* NRT_Allocate_External(size_t size);

  extern "C" __device__ void* NRT_Allocate(size_t size)
  {
@@ -177,6 +161,7 @@ extern "C" __device__ void NRT_decref(NRT_MemInfo* mi)
    }
  }

+
  #endif

  extern "C" __device__ void NRT_incref(NRT_MemInfo* mi)
@@ -0,0 +1,33 @@
+ #include <cuda/atomic>
+
+ typedef void (*NRT_dtor_function)(void* ptr, size_t size, void* info);
+ typedef void (*NRT_dealloc_func)(void* ptr, void* dealloc_info);
+
+ extern "C"
+ struct MemInfo {
+   cuda::atomic<size_t, cuda::thread_scope_device> refct;
+   NRT_dtor_function dtor;
+   void* dtor_info;
+   void* data;
+   size_t size;
+ };
+ typedef struct MemInfo NRT_MemInfo;
+
+ extern "C" __device__ void* NRT_Allocate(size_t size);
+ extern "C" __device__ void NRT_MemInfo_init(NRT_MemInfo* mi,
+                                             void* data,
+                                             size_t size,
+                                             NRT_dtor_function dtor,
+                                             void* dtor_info);
+ static __device__ void *nrt_allocate_meminfo_and_data_align(size_t size, unsigned align, NRT_MemInfo **mi);
+ static __device__ void *nrt_allocate_meminfo_and_data(size_t size, NRT_MemInfo **mi_out);
+ extern "C" __device__ void* NRT_Allocate_External(size_t size);
+ extern "C" __device__ void NRT_decref(NRT_MemInfo* mi);
+ extern "C" __device__ void NRT_incref(NRT_MemInfo* mi);
+ extern "C" __device__ NRT_MemInfo *NRT_MemInfo_alloc_aligned(size_t size, unsigned align);
+ extern "C" __device__ void* NRT_MemInfo_data_fast(NRT_MemInfo *mi);
+ extern "C" __device__ void NRT_MemInfo_call_dtor(NRT_MemInfo* mi);
+ extern "C" __device__ void NRT_MemInfo_destroy(NRT_MemInfo* mi);
+ extern "C" __device__ void NRT_dealloc(NRT_MemInfo* mi);
+ extern "C" __device__ void NRT_Free(void* ptr);
+ extern "C" __device__ NRT_MemInfo* NRT_MemInfo_new(void* data, size_t size, NRT_dtor_function dtor, void* dtor_info);
@@ -13,7 +13,8 @@ from numba.cuda.cudadrv.driver import (
  )
  from numba.cuda.cudadrv import devices
  from numba.cuda.api import get_current_device
- from numba.cuda.utils import _readenv
+ from numba.cuda.utils import _readenv, cached_file_read
+ from numba.cuda.cudadrv.linkable_code import CUSource


  # Check environment variable or config for NRT statistics enablement
@@ -32,6 +33,11 @@ if not hasattr(config, "NUMBA_CUDA_ENABLE_NRT"):
      config.CUDA_ENABLE_NRT = ENABLE_NRT


+ def get_include():
+     """Return the include path for the NRT header"""
+     return os.path.dirname(os.path.abspath(__file__))
+
+
  # Protect method to ensure NRT memory allocation and initialization
  def _alloc_init_guard(method):
      """
@@ -340,3 +346,9 @@ class _Runtime:

  # Create an instance of the runtime
  rtsys = _Runtime()
+
+
+ basedir = os.path.dirname(os.path.abspath(__file__))
+ nrt_path = os.path.join(basedir, "nrt.cu")
+ nrt_src = cached_file_read(nrt_path)
+ NRT_LIBRARY = CUSource(nrt_src, name="nrt.cu", nrt=True)
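The new get_include() helper follows the same convention as numpy.get_include(): it returns the directory that ships nrt.cuh so external build steps can add it to their include path. A minimal usage sketch (the `header` variable is illustrative):

    import os
    from numba.cuda.runtime.nrt import get_include

    # Locate the NRT header shipped alongside nrt.py, e.g. to pass its
    # directory as an include path to an external CUDA compilation step.
    header = os.path.join(get_include(), "nrt.cuh")
    assert os.path.isfile(header)

NRT_LIBRARY itself is a CUSource marked nrt=True, which is exactly the flag that maybe_link_nrt checks for on user-supplied LinkableCode objects.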
@@ -129,12 +129,16 @@ class shared(Stub):
      _description_ = "<shared>"

      @stub_function
-     def array(shape, dtype):
+     def array(shape, dtype, alignment=None):
          """
-         Allocate a shared array of the given *shape* and *type*. *shape* is
-         either an integer or a tuple of integers representing the array's
-         dimensions. *type* is a :ref:`Numba type <numba-types>` of the
-         elements needing to be stored in the array.
+         Allocate a shared array of the given *shape*, *type*, and, optionally,
+         *alignment*. *shape* is either an integer or a tuple of integers
+         representing the array's dimensions. *type* is a :ref:`Numba type
+         <numba-types>` of the elements needing to be stored in the array.
+         *alignment* is an optional integer specifying the byte alignment of
+         the array. When specified, it must be a power of two, and a multiple
+         of the size of a pointer (8 bytes). When not specified, the array is
+         allocated with an alignment appropriate for the supplied *dtype*.

          The returned array-like object can be read and written to like any
          normal device array (e.g. through indexing).
@@ -149,12 +153,20 @@ class local(Stub):
      _description_ = "<local>"

      @stub_function
-     def array(shape, dtype):
-         """
-         Allocate a local array of the given *shape* and *type*. The array is
-         private to the current thread, and resides in global memory. An
-         array-like object is returned which can be read and written to like any
-         standard array (e.g. through indexing).
+     def array(shape, dtype, alignment=None):
+         """
+         Allocate a local array of the given *shape*, *type*, and, optionally,
+         *alignment*. *shape* is either an integer or a tuple of integers
+         representing the array's dimensions. *type* is a :ref:`Numba type
+         <numba-types>` of the elements needing to be stored in the array.
+         *alignment* is an optional integer specifying the byte alignment of
+         the array. When specified, it must be a power of two, and a multiple
+         of the size of a pointer (8 bytes). When not specified, the array is
+         allocated with an alignment appropriate for the supplied *dtype*.
+
+         The array is private to the current thread, and resides in global
+         memory. An array-like object is returned which can be read and
+         written to like any standard array (e.g. through indexing).
          """


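A short usage sketch for the new keyword argument, assuming a CUDA-capable device (the kernel and buffer names are illustrative):

    import numpy as np
    from numba import cuda

    @cuda.jit
    def reverse(out):
        # 128 is a power of two and a multiple of 8 (the pointer size),
        # so it is a valid alignment.
        buf = cuda.shared.array(32, dtype=np.float32, alignment=128)
        i = cuda.threadIdx.x
        buf[i] = np.float32(i)
        cuda.syncthreads()
        out[i] = buf[31 - i]

    out = np.zeros(32, dtype=np.float32)
    reverse[1, 32](out)

cuda.local.array accepts the same alignment keyword, subject to the same constraints.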
@@ -0,0 +1,236 @@
+ import re
+ import itertools
+ import numpy as np
+ from numba import cuda
+ from numba.core.errors import TypingError
+ from numba.cuda.testing import CUDATestCase
+ import unittest
+
+
+ # Set to True if you want to see dots printed for each subtest.
+ NOISY = False
+
+
+ # In order to verify the alignment of the local and shared memory arrays, we
+ # inspect the LLVM IR of the generated kernel using the following regexes.
+
+ # Shared memory example:
+ # @"_cudapy_smem_38" = addrspace(3) global [1 x i8] undef, align 16
+ SMEM_PATTERN = re.compile(
+     r'^@"_cudapy_smem_\d+".*?align (\d+)',
+     re.MULTILINE,
+ )
+
+ # Local memory example:
+ # %"_cudapy_lmem" = alloca [1 x i8], align 64
+ LMEM_PATTERN = re.compile(
+     r'^\s*%"_cudapy_lmem".*?align (\d+)',
+     re.MULTILINE,
+ )
+
+
+ DTYPES = [np.uint8, np.uint32, np.uint64]
+
+ # Add in some record dtypes with and without alignment.
+ for align in (True, False):
+     DTYPES += [
+         np.dtype(
+             [
+                 ("a", np.uint8),
+                 ("b", np.int32),
+                 ("c", np.float64),
+             ],
+             align=align,
+         ),
+         np.dtype(
+             [
+                 ("a", np.uint32),
+                 ("b", np.uint8),
+             ],
+             align=align,
+         ),
+         np.dtype(
+             [
+                 ("a", np.uint8),
+                 ("b", np.int32),
+                 ("c", np.float64),
+                 ("d", np.complex64),
+                 ("e", (np.uint8, 5)),
+             ],
+             align=align,
+         ),
+     ]
+
+ # N.B. We name the test class TestArrayAddressAlignment to avoid a name
+ # conflict with the test_alignment.TestArrayAlignment class.
+
+
+ class TestArrayAddressAlignment(CUDATestCase):
+     """
+     Test cuda.local.array and cuda.shared.array support for an alignment
+     keyword argument.
+     """
+
+     def test_array_alignment_1d(self):
+         shapes = (1, 8, 50)
+         alignments = (None, 16, 64, 256)
+         array_types = [(0, "local"), (1, "shared")]
+         self._do_test(array_types, shapes, DTYPES, alignments)
+
+     def test_array_alignment_2d(self):
+         shapes = ((2, 3),)
+         alignments = (None, 16, 64, 256)
+         array_types = [(0, "local"), (1, "shared")]
+         self._do_test(array_types, shapes, DTYPES, alignments)
+
+     def test_array_alignment_3d(self):
+         shapes = ((2, 3, 4), (1, 4, 5))
+         alignments = (None, 16, 64, 256)
+         array_types = [(0, "local"), (1, "shared")]
+         self._do_test(array_types, shapes, DTYPES, alignments)
+
+     def _do_test(self, array_types, shapes, dtypes, alignments):
+         items = itertools.product(array_types, shapes, dtypes, alignments)
+
+         for (which, array_type), shape, dtype, alignment in items:
+             with self.subTest(
+                 array_type=array_type,
+                 shape=shape,
+                 dtype=dtype,
+                 alignment=alignment,
+             ):
+
+                 @cuda.jit
+                 def f(loc, shrd, which):
+                     i = cuda.grid(1)
+                     if which == 0:
+                         local_array = cuda.local.array(
+                             shape=shape,
+                             dtype=dtype,
+                             alignment=alignment,
+                         )
+                         if i == 0:
+                             loc[0] = local_array.ctypes.data
+                     else:
+                         shared_array = cuda.shared.array(
+                             shape=shape,
+                             dtype=dtype,
+                             alignment=alignment,
+                         )
+                         if i == 0:
+                             shrd[0] = shared_array.ctypes.data
+
+                 loc = np.zeros(1, dtype=np.uint64)
+                 shrd = np.zeros(1, dtype=np.uint64)
+                 f[1, 1](loc, shrd, which)
+
+                 kernel = f.overloads[f.signatures[0]]
+                 llvm_ir = kernel.inspect_llvm()
+
+                 if alignment is None:
+                     if which == 0:
+                         # Local memory shouldn't have any alignment
+                         # information when no alignment is specified.
+                         match = LMEM_PATTERN.findall(llvm_ir)
+                         self.assertEqual(len(match), 0)
+                     else:
+                         # Shared memory should at least have a power-of-two
+                         # alignment when no alignment is specified.
+                         match = SMEM_PATTERN.findall(llvm_ir)
+                         self.assertEqual(len(match), 1)
+
+                         alignment = int(match[0])
+                         # Verify alignment is a power of two.
+                         self.assertTrue(alignment & (alignment - 1) == 0)
+                 else:
+                     # Verify alignment is in the LLVM IR.
+                     if which == 0:
+                         match = LMEM_PATTERN.findall(llvm_ir)
+                         self.assertEqual(len(match), 1)
+                         actual_alignment = int(match[0])
+                         self.assertEqual(alignment, actual_alignment)
+                     else:
+                         match = SMEM_PATTERN.findall(llvm_ir)
+                         self.assertEqual(len(match), 1)
+                         actual_alignment = int(match[0])
+                         self.assertEqual(alignment, actual_alignment)
+
+                     # Also verify that the address of the array is aligned.
+                     # If this fails, the problem is likely with NVVM.
+                     address = loc[0] if which == 0 else shrd[0]
+                     alignment_mod = int(address % alignment)
+                     self.assertEqual(alignment_mod, 0)
+
+                 if NOISY:
+                     print(".", end="", flush=True)
+
+     def test_invalid_alignments(self):
+         shapes = (1, 50)
+         dtypes = (np.uint8, np.uint64)
+         invalid_alignment_values = (-1, 0, 3, 17, 33)
+         invalid_alignment_types = ("1.0", "1", "foo", 1.0, 1.5, 3.2)
+         alignments = invalid_alignment_values + invalid_alignment_types
+         array_types = [(0, "local"), (1, "shared")]
+
+         # Use a regex pattern to match the error message, handling potential
+         # ANSI color codes which appear on CI.
+         expected_invalid_type_error_regex = (
+             r"RequireLiteralValue:.*alignment must be a constant integer"
+         )
+
+         items = itertools.product(array_types, shapes, dtypes, alignments)
+
+         for (which, array_type), shape, dtype, alignment in items:
+             with self.subTest(
+                 array_type=array_type,
+                 shape=shape,
+                 dtype=dtype,
+                 alignment=alignment,
+             ):
+                 if which == 0:
+
+                     @cuda.jit
+                     def f(dest_array):
+                         i = cuda.grid(1)
+                         local_array = cuda.local.array(
+                             shape=shape,
+                             dtype=dtype,
+                             alignment=alignment,
+                         )
+                         if i == 0:
+                             dest_array[0] = local_array.ctypes.data
+                 else:
+
+                     @cuda.jit
+                     def f(dest_array):
+                         i = cuda.grid(1)
+                         shared_array = cuda.shared.array(
+                             shape=shape,
+                             dtype=dtype,
+                             alignment=alignment,
+                         )
+                         if i == 0:
+                             dest_array[0] = shared_array.ctypes.data
+
+                 array = np.zeros(1, dtype=np.uint64)
+
+                 # The type of error we expect differs between an invalid
+                 # value that is still an int, and an invalid type.
+                 if isinstance(alignment, int):
+                     self.assertRaisesRegex(
+                         ValueError, r"Alignment must be.*", f[1, 1], array
+                     )
+                 else:
+                     self.assertRaisesRegex(
+                         TypingError,
+                         expected_invalid_type_error_regex,
+                         f[1, 1],
+                         array,
+                     )
+
+                 if NOISY:
+                     print(".", end="", flush=True)
+
+
+ if __name__ == "__main__":
+     unittest.main()
@@ -1,7 +1,10 @@
  from numba.cuda.testing import skip_on_cudasim, unittest, CUDATestCase
+ from llvmlite import ir

  import numpy as np
+ import os
  from numba import config, cuda, njit, types
+ from numba.extending import overload


  class Interval:
@@ -160,5 +163,142 @@ class TestExtending(CUDATestCase):
          np.testing.assert_allclose(r, expected)


+ TEST_BIN_DIR = os.getenv("NUMBA_CUDA_TEST_BIN_DIR")
+ if TEST_BIN_DIR:
+     test_device_functions_a = os.path.join(
+         TEST_BIN_DIR, "test_device_functions.a"
+     )
+     test_device_functions_cubin = os.path.join(
+         TEST_BIN_DIR, "test_device_functions.cubin"
+     )
+     test_device_functions_cu = os.path.join(
+         TEST_BIN_DIR, "test_device_functions.cu"
+     )
+     test_device_functions_fatbin = os.path.join(
+         TEST_BIN_DIR, "test_device_functions.fatbin"
+     )
+     test_device_functions_fatbin_multi = os.path.join(
+         TEST_BIN_DIR, "test_device_functions_multi.fatbin"
+     )
+     test_device_functions_o = os.path.join(
+         TEST_BIN_DIR, "test_device_functions.o"
+     )
+     test_device_functions_ptx = os.path.join(
+         TEST_BIN_DIR, "test_device_functions.ptx"
+     )
+     test_device_functions_ltoir = os.path.join(
+         TEST_BIN_DIR, "test_device_functions.ltoir"
+     )
+
+
+ class TestExtendingLinkage(CUDATestCase):
+     def test_extension_adds_linkable_code(self):
+         cuda_major_version = cuda.runtime.get_version()[0]
+
+         if cuda_major_version < 12:
+             self.skipTest("CUDA 12 required for linking in-memory data")
+
+         files = (
+             (test_device_functions_a, cuda.Archive),
+             (test_device_functions_cubin, cuda.Cubin),
+             (test_device_functions_cu, cuda.CUSource),
+             (test_device_functions_fatbin, cuda.Fatbin),
+             (test_device_functions_o, cuda.Object),
+             (test_device_functions_ptx, cuda.PTXSource),
+             (test_device_functions_ltoir, cuda.LTOIR),
+         )
+
+         lto = config.CUDA_ENABLE_PYNVJITLINK
+
+         for path, ctor in files:
+             if ctor == cuda.LTOIR and not lto:
+                 # Don't try to test with LTOIR if LTO is not enabled
+                 continue
+
+             with open(path, "rb") as f:
+                 code_object = ctor(f.read())
+
+             def external_add(x, y):
+                 return x + y
+
+             @type_callable(external_add)
+             def type_external_add(context):
+                 def typer(x, y):
+                     if x == types.uint32 and y == types.uint32:
+                         return types.uint32
+
+                 return typer
+
+             @lower_builtin(external_add, types.uint32, types.uint32)
+             def lower_external_add(context, builder, sig, args):
+                 context.active_code_library.add_linking_file(code_object)
+                 i32 = ir.IntType(32)
+                 fnty = ir.FunctionType(i32, [i32, i32])
+                 fn = cgutils.get_or_insert_function(
+                     builder.module, fnty, "add_cabi"
+                 )
+                 return builder.call(fn, args)
+
+             @cuda.jit(lto=lto)
+             def use_external_add(r, x, y):
+                 r[0] = external_add(x[0], y[0])
+
+             r = np.zeros(1, dtype=np.uint32)
+             x = np.ones(1, dtype=np.uint32)
+             y = np.ones(1, dtype=np.uint32) * 2
+
+             use_external_add[1, 1](r, x, y)
+
+             np.testing.assert_equal(r[0], 3)
+
+             @cuda.jit(lto=lto)
+             def use_external_add_device(x, y):
+                 return external_add(x, y)
+
+             @cuda.jit(lto=lto)
+             def use_external_add_kernel(r, x, y):
+                 r[0] = use_external_add_device(x[0], y[0])
+
+             r = np.zeros(1, dtype=np.uint32)
+             x = np.ones(1, dtype=np.uint32)
+             y = np.ones(1, dtype=np.uint32) * 2
+
+             use_external_add_kernel[1, 1](r, x, y)
+
+             np.testing.assert_equal(r[0], 3)
+
+     def test_linked_called_through_overload(self):
+         cu_code = cuda.CUSource("""
+         extern "C" __device__
+         int bar(int *out, int a)
+         {
+           *out = a * 2;
+           return 0;
+         }
+         """)
+
+         bar = cuda.declare_device("bar", "int32(int32)", link=cu_code)
+
+         def bar_call(val):
+             pass
+
+         @overload(bar_call, target="cuda")
+         def ol_bar_call(a):
+             return lambda a: bar(a)
+
+         @cuda.jit("void(int32[::1], int32[::1])")
+         def foo(r, x):
+             i = cuda.grid(1)
+             if i < len(r):
+                 r[i] = bar_call(x[i])
+
+         x = np.arange(10, dtype=np.int32)
+         r = np.empty_like(x)
+
+         foo[1, 32](r, x)
+
+         np.testing.assert_equal(r, x * 2)
+
+
  if __name__ == "__main__":
      unittest.main()