numba-cuda: numba_cuda-0.9.0-py3-none-any.whl → numba_cuda-0.10.0-py3-none-any.whl
This diff compares two publicly released versions of the package as they appear in their public registry, and is provided for informational purposes only.
- numba_cuda/VERSION +1 -1
- numba_cuda/numba/cuda/compiler.py +14 -1
- numba_cuda/numba/cuda/cuda_bf16.py +5155 -0
- numba_cuda/numba/cuda/cuda_paths.py +2 -0
- numba_cuda/numba/cuda/cudadecl.py +0 -42
- numba_cuda/numba/cuda/cudadrv/linkable_code.py +11 -2
- numba_cuda/numba/cuda/cudadrv/nvrtc.py +10 -3
- numba_cuda/numba/cuda/cudaimpl.py +0 -63
- numba_cuda/numba/cuda/debuginfo.py +92 -2
- numba_cuda/numba/cuda/decorators.py +13 -1
- numba_cuda/numba/cuda/device_init.py +4 -5
- numba_cuda/numba/cuda/extending.py +54 -0
- numba_cuda/numba/cuda/include/11/cuda_bf16.h +3749 -0
- numba_cuda/numba/cuda/include/11/cuda_bf16.hpp +2683 -0
- numba_cuda/numba/cuda/{cuda_fp16.h → include/11/cuda_fp16.h} +550 -387
- numba_cuda/numba/cuda/{cuda_fp16.hpp → include/11/cuda_fp16.hpp} +465 -316
- numba_cuda/numba/cuda/include/12/cuda_bf16.h +5118 -0
- numba_cuda/numba/cuda/include/12/cuda_bf16.hpp +3865 -0
- numba_cuda/numba/cuda/include/12/cuda_fp16.h +5363 -0
- numba_cuda/numba/cuda/include/12/cuda_fp16.hpp +3483 -0
- numba_cuda/numba/cuda/intrinsic_wrapper.py +0 -39
- numba_cuda/numba/cuda/intrinsics.py +172 -1
- numba_cuda/numba/cuda/lowering.py +43 -0
- numba_cuda/numba/cuda/stubs.py +0 -11
- numba_cuda/numba/cuda/target.py +28 -0
- numba_cuda/numba/cuda/tests/cudapy/extensions_usecases.py +4 -2
- numba_cuda/numba/cuda/tests/cudapy/test_array_args.py +1 -1
- numba_cuda/numba/cuda/tests/cudapy/test_bfloat16_bindings.py +257 -0
- numba_cuda/numba/cuda/tests/cudapy/test_blackscholes.py +1 -1
- numba_cuda/numba/cuda/tests/cudapy/test_debuginfo.py +46 -0
- numba_cuda/numba/cuda/tests/cudapy/test_enums.py +18 -0
- numba_cuda/numba/cuda/tests/cudapy/test_extending.py +4 -2
- numba_cuda/numba/cuda/tests/cudapy/test_inline.py +59 -0
- numba_cuda/numba/cuda/tests/cudapy/test_laplace.py +1 -1
- numba_cuda/numba/cuda/tests/cudapy/test_warp_ops.py +50 -5
- numba_cuda/numba/cuda/vector_types.py +3 -1
- numba_cuda/numba/cuda/vectorizers.py +1 -1
- {numba_cuda-0.9.0.dist-info → numba_cuda-0.10.0.dist-info}/METADATA +1 -1
- {numba_cuda-0.9.0.dist-info → numba_cuda-0.10.0.dist-info}/RECORD +42 -32
- {numba_cuda-0.9.0.dist-info → numba_cuda-0.10.0.dist-info}/WHEEL +1 -1
- {numba_cuda-0.9.0.dist-info → numba_cuda-0.10.0.dist-info}/licenses/LICENSE +0 -0
- {numba_cuda-0.9.0.dist-info → numba_cuda-0.10.0.dist-info}/top_level.txt +0 -0
numba_cuda/numba/cuda/cuda_paths.py

```diff
@@ -201,6 +201,8 @@ def _get_nvrtc_wheel():
 
 def _get_libdevice_paths():
     by, libdir = _get_libdevice_path_decision()
+    if not libdir:
+        return _env_path_tuple(by, None)
     out = os.path.join(libdir, "libdevice.10.bc")
     return _env_path_tuple(by, out)
 
```
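The guard above makes the libdevice lookup degrade gracefully when no library directory is found, returning `(by, None)` instead of attempting `os.path.join(None, ...)`. A minimal sketch of how a caller observes this (the field names follow the `_env_path_tuple` usage visible elsewhere in this diff):

```python
from numba.cuda.cuda_paths import get_cuda_paths

paths = get_cuda_paths()
libdevice = paths["libdevice"]

# `by` records how the path was decided; `info` is now None rather than a
# bogus joined path when no libdevice directory could be located.
print(libdevice.by, libdevice.info)
```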
numba_cuda/numba/cuda/cudadecl.py

```diff
@@ -100,45 +100,6 @@ class Cuda_syncwarp(ConcreteTemplate):
     cases = [signature(types.none), signature(types.none, types.i4)]
 
 
-@register
-class Cuda_shfl_sync_intrinsic(ConcreteTemplate):
-    key = cuda.shfl_sync_intrinsic
-    cases = [
-        signature(
-            types.Tuple((types.i4, types.b1)),
-            types.i4,
-            types.i4,
-            types.i4,
-            types.i4,
-            types.i4,
-        ),
-        signature(
-            types.Tuple((types.i8, types.b1)),
-            types.i4,
-            types.i4,
-            types.i8,
-            types.i4,
-            types.i4,
-        ),
-        signature(
-            types.Tuple((types.f4, types.b1)),
-            types.i4,
-            types.i4,
-            types.f4,
-            types.i4,
-            types.i4,
-        ),
-        signature(
-            types.Tuple((types.f8, types.b1)),
-            types.i4,
-            types.i4,
-            types.f8,
-            types.i4,
-            types.i4,
-        ),
-    ]
-
-
 @register
 class Cuda_vote_sync_intrinsic(ConcreteTemplate):
     key = cuda.vote_sync_intrinsic
@@ -815,9 +776,6 @@ class CudaModuleTemplate(AttributeTemplate):
     def resolve_syncwarp(self, mod):
         return types.Function(Cuda_syncwarp)
 
-    def resolve_shfl_sync_intrinsic(self, mod):
-        return types.Function(Cuda_shfl_sync_intrinsic)
-
     def resolve_vote_sync_intrinsic(self, mod):
         return types.Function(Cuda_vote_sync_intrinsic)
 
```
numba_cuda/numba/cuda/cudadrv/linkable_code.py

```diff
@@ -1,10 +1,13 @@
+import io
 from .mappings import FILE_EXTENSION_MAP
 
 
 class LinkableCode:
     """An object that holds code to be linked from memory.
 
-    :param data: A buffer containing the data to link.
+    :param data: A buffer, StringIO or BytesIO containing the data to link.
+                 If a file object is passed, the content in the object is
+                 read when `data` property is accessed.
     :param name: The name of the file to be referenced in any compilation or
                  linking errors that may be produced.
     :param setup_callback: A function called prior to the launch of a kernel
@@ -23,8 +26,8 @@ class LinkableCode:
         if teardown_callback and not callable(teardown_callback):
             raise TypeError("teardown_callback must be callable")
 
-        self.data = data
         self._name = name
+        self._data = data
         self.setup_callback = setup_callback
         self.teardown_callback = teardown_callback
 
@@ -32,6 +35,12 @@ class LinkableCode:
     def name(self):
         return self._name or self.default_name
 
+    @property
+    def data(self):
+        if isinstance(self._data, (io.StringIO, io.BytesIO)):
+            return self._data.getvalue()
+        return self._data
+
 
 class PTXSource(LinkableCode):
     """PTX source code in memory."""
```
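With the new `data` property, a `LinkableCode` object (or a subclass such as `PTXSource`) can wrap an in-memory file object and still hand plain bytes to the linker. A minimal sketch, using a placeholder payload rather than real PTX:

```python
import io

from numba.cuda.cudadrv.linkable_code import PTXSource

# Placeholder PTX held in a BytesIO; in practice this would come from a
# compiler or a build step.
buf = io.BytesIO(b"// placeholder PTX\n")

src = PTXSource(buf, name="example.ptx")

# The property reads the buffer's contents on access.
assert src.data == b"// placeholder PTX\n"
```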
numba_cuda/numba/cuda/cudadrv/nvrtc.py

```diff
@@ -372,19 +372,26 @@ def compile(src, name, cc, ltoir=False):
         f"-I{get_cuda_paths()['include_dir'].info}",
     ]
 
+    nvrtc_version = nvrtc.get_version()
+    nvrtc_ver_major = nvrtc_version[0]
+
     cudadrv_path = os.path.dirname(os.path.abspath(__file__))
     numba_cuda_path = os.path.dirname(cudadrv_path)
-
+
+    if nvrtc_ver_major == 11:
+        numba_include = f"-I{os.path.join(numba_cuda_path, 'include', '11')}"
+    else:
+        numba_include = f"-I{os.path.join(numba_cuda_path, 'include', '12')}"
 
     nrt_path = os.path.join(numba_cuda_path, "runtime")
     nrt_include = f"-I{nrt_path}"
 
-    options = [arch, *cuda_include,
+    options = [arch, numba_include, *cuda_include, nrt_include, "-rdc", "true"]
 
     if ltoir:
         options.append("-dlto")
 
-    if
+    if nvrtc_version < (12, 0):
         options += ["-std=c++17"]
 
     # Compile the program
```
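The effect of this change is that the bundled `cuda_fp16`/`cuda_bf16` headers are resolved from `include/11` or `include/12` according to the NVRTC major version, and the NRT include plus `-rdc true` are always passed. A hedged sketch of driving `compile` directly, assuming it still returns the compiled program and a log and that `cc` is a `(major, minor)` tuple:

```python
from numba.cuda.cudadrv import nvrtc

src = """
#include <cuda_bf16.h>
extern "C" __global__ void noop() {}
"""

# The bundled bfloat16 header is found via the include/11 or include/12
# directory selected above, depending on the NVRTC version in use.
ptx, log = nvrtc.compile(src, "noop.cu", (8, 0))
```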
numba_cuda/numba/cuda/cudaimpl.py

```diff
@@ -204,69 +204,6 @@ def ptx_syncwarp_mask(context, builder, sig, args):
     return context.get_dummy_value()
 
 
-@lower(
-    stubs.shfl_sync_intrinsic, types.i4, types.i4, types.i4, types.i4, types.i4
-)
-@lower(
-    stubs.shfl_sync_intrinsic, types.i4, types.i4, types.i8, types.i4, types.i4
-)
-@lower(
-    stubs.shfl_sync_intrinsic, types.i4, types.i4, types.f4, types.i4, types.i4
-)
-@lower(
-    stubs.shfl_sync_intrinsic, types.i4, types.i4, types.f8, types.i4, types.i4
-)
-def ptx_shfl_sync_i32(context, builder, sig, args):
-    """
-    The NVVM intrinsic for shfl only supports i32, but the cuda intrinsic
-    function supports both 32 and 64 bit ints and floats, so for feature parity,
-    i64, f32, and f64 are implemented. Floats by way of bitcasting the float to
-    an int, then shuffling, then bitcasting back. And 64-bit values by packing
-    them into 2 32bit values, shuffling thoose, and then packing back together.
-    """
-    mask, mode, value, index, clamp = args
-    value_type = sig.args[2]
-    if value_type in types.real_domain:
-        value = builder.bitcast(value, ir.IntType(value_type.bitwidth))
-    fname = "llvm.nvvm.shfl.sync.i32"
-    lmod = builder.module
-    fnty = ir.FunctionType(
-        ir.LiteralStructType((ir.IntType(32), ir.IntType(1))),
-        (
-            ir.IntType(32),
-            ir.IntType(32),
-            ir.IntType(32),
-            ir.IntType(32),
-            ir.IntType(32),
-        ),
-    )
-    func = cgutils.get_or_insert_function(lmod, fnty, fname)
-    if value_type.bitwidth == 32:
-        ret = builder.call(func, (mask, mode, value, index, clamp))
-        if value_type == types.float32:
-            rv = builder.extract_value(ret, 0)
-            pred = builder.extract_value(ret, 1)
-            fv = builder.bitcast(rv, ir.FloatType())
-            ret = cgutils.make_anonymous_struct(builder, (fv, pred))
-    else:
-        value1 = builder.trunc(value, ir.IntType(32))
-        value_lshr = builder.lshr(value, context.get_constant(types.i8, 32))
-        value2 = builder.trunc(value_lshr, ir.IntType(32))
-        ret1 = builder.call(func, (mask, mode, value1, index, clamp))
-        ret2 = builder.call(func, (mask, mode, value2, index, clamp))
-        rv1 = builder.extract_value(ret1, 0)
-        rv2 = builder.extract_value(ret2, 0)
-        pred = builder.extract_value(ret1, 1)
-        rv1_64 = builder.zext(rv1, ir.IntType(64))
-        rv2_64 = builder.zext(rv2, ir.IntType(64))
-        rv_shl = builder.shl(rv2_64, context.get_constant(types.i8, 32))
-        rv = builder.or_(rv_shl, rv1_64)
-        if value_type == types.float64:
-            rv = builder.bitcast(rv, ir.DoubleType())
-        ret = cgutils.make_anonymous_struct(builder, (rv, pred))
-    return ret
-
-
 @lower(stubs.vote_sync_intrinsic, types.i4, types.i4, types.boolean)
 def ptx_vote_sync(context, builder, sig, args):
     fname = "llvm.nvvm.vote.sync"
```
numba_cuda/numba/cuda/debuginfo.py

```diff
@@ -1,5 +1,5 @@
 from llvmlite import ir
-from numba.core import types
+from numba.core import types, cgutils
 from numba.core.debuginfo import DIBuilder
 from numba.cuda.types import GridGroup
 
@@ -7,8 +7,14 @@ _BYTE_SIZE = 8
 
 
 class CUDADIBuilder(DIBuilder):
+    def __init__(self, module, filepath, cgctx, directives_only):
+        super().__init__(module, filepath, cgctx, directives_only)
+        # Cache for local variable metadata type and line deduplication
+        self._vartypelinemap = {}
+
     def _var_type(self, lltype, size, datamodel=None):
         is_bool = False
+        is_int_literal = False
         is_grid_group = False
 
         if isinstance(lltype, ir.IntType):
@@ -20,15 +26,23 @@ class CUDADIBuilder(DIBuilder):
                 name = str(datamodel.fe_type)
                 if isinstance(datamodel.fe_type, types.Boolean):
                     is_bool = True
+                    if isinstance(datamodel.fe_type, types.BooleanLiteral):
+                        name = "bool"
+                elif isinstance(datamodel.fe_type, types.Integer):
+                    if isinstance(datamodel.fe_type, types.IntegerLiteral):
+                        name = f"int{_BYTE_SIZE * size}"
+                        is_int_literal = True
                 elif isinstance(datamodel.fe_type, GridGroup):
                     is_grid_group = True
 
-        if is_bool or is_grid_group:
+        if is_bool or is_int_literal or is_grid_group:
             m = self.module
             bitsize = _BYTE_SIZE * size
             # Boolean type workaround until upstream Numba is fixed
             if is_bool:
                 ditok = "DW_ATE_boolean"
+            elif is_int_literal:
+                ditok = "DW_ATE_signed"
             # GridGroup type should use numba.cuda implementation
             elif is_grid_group:
                 ditok = "DW_ATE_unsigned"
@@ -44,3 +58,79 @@ class CUDADIBuilder(DIBuilder):
 
         # For other cases, use upstream Numba implementation
         return super()._var_type(lltype, size, datamodel=datamodel)
+
+    def mark_variable(
+        self,
+        builder,
+        allocavalue,
+        name,
+        lltype,
+        size,
+        line,
+        datamodel=None,
+        argidx=None,
+    ):
+        if name.startswith("$") or "." in name:
+            # Do not emit llvm.dbg.declare on user variable alias
+            return
+        else:
+            int_type = (ir.IntType,)
+            real_type = ir.FloatType, ir.DoubleType
+            if isinstance(lltype, int_type + real_type):
+                # Start with scalar variable, swtiching llvm.dbg.declare
+                # to llvm.dbg.value
+                return
+            else:
+                return super().mark_variable(
+                    builder,
+                    allocavalue,
+                    name,
+                    lltype,
+                    size,
+                    line,
+                    datamodel,
+                    argidx,
+                )
+
+    def update_variable(
+        self,
+        builder,
+        value,
+        name,
+        lltype,
+        size,
+        line,
+        datamodel=None,
+        argidx=None,
+    ):
+        m = self.module
+        fnty = ir.FunctionType(ir.VoidType(), [ir.MetaDataType()] * 3)
+        decl = cgutils.get_or_insert_function(m, fnty, "llvm.dbg.value")
+
+        mdtype = self._var_type(lltype, size, datamodel)
+        index = name.find(".")
+        if index >= 0:
+            name = name[:index]
+        # Merge DILocalVariable nodes with same name and type but different
+        # lines. Use the cached [(name, type) -> line] info to deduplicate
+        # metadata. Use the lltype as part of key.
+        key = (name, lltype)
+        if key in self._vartypelinemap:
+            line = self._vartypelinemap[key]
+        else:
+            self._vartypelinemap[key] = line
+        arg_index = 0 if argidx is None else argidx
+        mdlocalvar = m.add_debug_info(
+            "DILocalVariable",
+            {
+                "name": name,
+                "arg": arg_index,
+                "scope": self.subprograms[-1],
+                "file": self.difile,
+                "line": line,
+                "type": mdtype,
+            },
+        )
+        mdexpr = m.add_debug_info("DIExpression", {})
+
+        return builder.call(decl, [value, mdlocalvar, mdexpr])
```
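These hooks only fire when debug info is emitted. A minimal sketch of a kernel compiled with debug info, which exercises the new `llvm.dbg.value`-based handling of scalar locals and the `DILocalVariable` deduplication:

```python
import numpy as np
from numba import cuda


@cuda.jit(debug=True, opt=False)
def scale(arr, factor):
    i = cuda.grid(1)
    if i < arr.size:
        # `tmp` is a scalar local: it is now described with llvm.dbg.value
        # rather than llvm.dbg.declare.
        tmp = arr[i] * factor
        arr[i] = tmp


arr = np.arange(16, dtype=np.float32)
scale[1, 16](arr, np.float32(2.0))
```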
numba_cuda/numba/cuda/decorators.py

```diff
@@ -16,7 +16,7 @@ _msg_deprecated_signature_arg = (
 def jit(
     func_or_sig=None,
     device=False,
-    inline=
+    inline="never",
     link=[],
     debug=None,
     opt=None,
@@ -81,6 +81,15 @@ def jit(
         msg = _msg_deprecated_signature_arg.format("bind")
         raise DeprecationError(msg)
 
+    if isinstance(inline, bool):
+        DeprecationWarning(
+            "Passing bool to inline argument is deprecated, please refer to "
+            "Numba's documentation on inlining: "
+            "https://numba.readthedocs.io/en/stable/developer/inlining.html"
+        )
+
+        inline = "always" if inline else "never"
+
     debug = config.CUDA_DEBUGINFO_DEFAULT if debug is None else debug
     opt = (config.OPT != 0) if opt is None else opt
     fastmath = kws.get("fastmath", False)
@@ -130,6 +139,7 @@ def jit(
         targetoptions["opt"] = opt
         targetoptions["fastmath"] = fastmath
         targetoptions["device"] = device
+        targetoptions["inline"] = inline
         targetoptions["extensions"] = extensions
 
         disp = CUDADispatcher(func, targetoptions=targetoptions)
@@ -171,6 +181,7 @@ def jit(
             return jit(
                 func,
                 device=device,
+                inline=inline,
                 debug=debug,
                 opt=opt,
                 lineinfo=lineinfo,
@@ -194,6 +205,7 @@ def jit(
         targetoptions["link"] = link
         targetoptions["fastmath"] = fastmath
         targetoptions["device"] = device
+        targetoptions["inline"] = inline
         targetoptions["extensions"] = extensions
         disp = CUDADispatcher(func_or_sig, targetoptions=targetoptions)
 
```
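In user code, the `inline` option is now given in Numba's string form; passing a bool still works but is mapped to "always"/"never" and treated as deprecated. A short sketch:

```python
from numba import cuda


@cuda.jit(device=True, inline="always")
def add(a, b):
    return a + b


@cuda.jit
def kernel(out, x, y):
    i = cuda.grid(1)
    if i < out.size:
        out[i] = add(x[i], y[i])


# Deprecated but still accepted: a bool is translated to "always"/"never".
@cuda.jit(device=True, inline=True)
def legacy_add(a, b):
    return a + b
```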
numba_cuda/numba/cuda/device_init.py

```diff
@@ -13,7 +13,6 @@ from .stubs import (
     local,
     const,
     atomic,
-    shfl_sync_intrinsic,
     vote_sync_intrinsic,
     match_any_sync,
     match_all_sync,
@@ -40,6 +39,10 @@ from .intrinsics import (
     syncthreads_and,
     syncthreads_count,
     syncthreads_or,
+    shfl_sync,
+    shfl_up_sync,
+    shfl_down_sync,
+    shfl_xor_sync,
 )
 from .cudadrv.error import CudaSupportError
 from numba.cuda.cudadrv.driver import (
@@ -68,10 +71,6 @@ from .intrinsic_wrapper import (
     any_sync,
     eq_sync,
     ballot_sync,
-    shfl_sync,
-    shfl_up_sync,
-    shfl_down_sync,
-    shfl_xor_sync,
 )
 
 from .kernels import reduction
```
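From the user's point of view the warp shuffle functions are reached the same way as before; only their implementation moved from `intrinsic_wrapper` to `intrinsics`. A sketch of a warp-level reduction using `cuda.shfl_down_sync`:

```python
import numpy as np
from numba import cuda


@cuda.jit
def warp_sum(out, vals):
    i = cuda.grid(1)
    v = vals[i]
    # Tree reduction within a full warp (mask 0xFFFFFFFF).
    v += cuda.shfl_down_sync(0xFFFFFFFF, v, 16)
    v += cuda.shfl_down_sync(0xFFFFFFFF, v, 8)
    v += cuda.shfl_down_sync(0xFFFFFFFF, v, 4)
    v += cuda.shfl_down_sync(0xFFFFFFFF, v, 2)
    v += cuda.shfl_down_sync(0xFFFFFFFF, v, 1)
    if cuda.laneid == 0:
        out[0] = v


vals = np.ones(32, dtype=np.int32)
out = np.zeros(1, dtype=np.int32)
warp_sum[1, 32](out, vals)  # out[0] == 32
```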
numba_cuda/numba/cuda/extending.py

```diff
@@ -3,5 +3,59 @@ Added for symmetry with the core API
 """
 
 from numba.core.extending import intrinsic as _intrinsic
+from numba.cuda.models import register_model  # noqa: F401
+from numba.cuda import models  # noqa: F401
 
 intrinsic = _intrinsic(target="cuda")
+
+
+def make_attribute_wrapper(typeclass, struct_attr, python_attr):
+    """
+    Make an automatic attribute wrapper exposing member named *struct_attr*
+    as a read-only attribute named *python_attr*.
+    The given *typeclass*'s model must be a StructModel subclass.
+
+    Vendored from numba.core.extending with a change to consider the CUDA data
+    model manager.
+    """
+    from numba.core.typing.templates import AttributeTemplate
+
+    from numba.core.datamodel import default_manager
+    from numba.core.datamodel.models import StructModel
+    from numba.core.imputils import impl_ret_borrowed
+    from numba.core import cgutils, types
+
+    from numba.cuda.models import cuda_data_manager
+    from numba.cuda.cudadecl import registry as cuda_registry
+    from numba.cuda.cudaimpl import registry as cuda_impl_registry
+
+    data_model_manager = cuda_data_manager.chain(default_manager)
+
+    if not isinstance(typeclass, type) or not issubclass(typeclass, types.Type):
+        raise TypeError(f"typeclass should be a Type subclass, got {typeclass}")
+
+    def get_attr_fe_type(typ):
+        """
+        Get the Numba type of member *struct_attr* in *typ*.
+        """
+        model = data_model_manager.lookup(typ)
+        if not isinstance(model, StructModel):
+            raise TypeError(
+                f"make_attribute_wrapper() needs a type with a StructModel, but got {model}"
+            )
+        return model.get_member_fe_type(struct_attr)
+
+    @cuda_registry.register_attr
+    class StructAttribute(AttributeTemplate):
+        key = typeclass
+
+        def generic_resolve(self, typ, attr):
+            if attr == python_attr:
+                return get_attr_fe_type(typ)
+
+    @cuda_impl_registry.lower_getattr(typeclass, python_attr)
+    def struct_getattr_impl(context, builder, typ, val):
+        val = cgutils.create_struct_proxy(typ)(context, builder, value=val)
+        attrty = get_attr_fe_type(typ)
+        attrval = getattr(val, struct_attr)
+        return impl_ret_borrowed(context, builder, attrty, attrval)
```