numba-cuda 0.9.0__py3-none-any.whl → 0.10.1__py3-none-any.whl

This diff compares the contents of two publicly released versions of the package, as published to a supported registry. It is provided for informational purposes only.
Files changed (43)
  1. numba_cuda/VERSION +1 -1
  2. numba_cuda/numba/cuda/compiler.py +35 -3
  3. numba_cuda/numba/cuda/cuda_bf16.py +5155 -0
  4. numba_cuda/numba/cuda/cuda_paths.py +2 -0
  5. numba_cuda/numba/cuda/cudadecl.py +0 -42
  6. numba_cuda/numba/cuda/cudadrv/linkable_code.py +11 -2
  7. numba_cuda/numba/cuda/cudadrv/nvrtc.py +10 -3
  8. numba_cuda/numba/cuda/cudaimpl.py +0 -63
  9. numba_cuda/numba/cuda/debuginfo.py +92 -2
  10. numba_cuda/numba/cuda/decorators.py +27 -1
  11. numba_cuda/numba/cuda/device_init.py +4 -5
  12. numba_cuda/numba/cuda/dispatcher.py +4 -3
  13. numba_cuda/numba/cuda/extending.py +54 -0
  14. numba_cuda/numba/cuda/include/11/cuda_bf16.h +3749 -0
  15. numba_cuda/numba/cuda/include/11/cuda_bf16.hpp +2683 -0
  16. numba_cuda/numba/cuda/{cuda_fp16.h → include/11/cuda_fp16.h} +550 -387
  17. numba_cuda/numba/cuda/{cuda_fp16.hpp → include/11/cuda_fp16.hpp} +465 -316
  18. numba_cuda/numba/cuda/include/12/cuda_bf16.h +5118 -0
  19. numba_cuda/numba/cuda/include/12/cuda_bf16.hpp +3865 -0
  20. numba_cuda/numba/cuda/include/12/cuda_fp16.h +5363 -0
  21. numba_cuda/numba/cuda/include/12/cuda_fp16.hpp +3483 -0
  22. numba_cuda/numba/cuda/intrinsic_wrapper.py +0 -39
  23. numba_cuda/numba/cuda/intrinsics.py +172 -1
  24. numba_cuda/numba/cuda/lowering.py +43 -0
  25. numba_cuda/numba/cuda/stubs.py +0 -11
  26. numba_cuda/numba/cuda/target.py +28 -0
  27. numba_cuda/numba/cuda/tests/cudapy/extensions_usecases.py +4 -2
  28. numba_cuda/numba/cuda/tests/cudapy/test_array_args.py +1 -1
  29. numba_cuda/numba/cuda/tests/cudapy/test_bfloat16_bindings.py +257 -0
  30. numba_cuda/numba/cuda/tests/cudapy/test_blackscholes.py +1 -1
  31. numba_cuda/numba/cuda/tests/cudapy/test_debuginfo.py +46 -0
  32. numba_cuda/numba/cuda/tests/cudapy/test_enums.py +18 -0
  33. numba_cuda/numba/cuda/tests/cudapy/test_extending.py +4 -2
  34. numba_cuda/numba/cuda/tests/cudapy/test_inline.py +156 -0
  35. numba_cuda/numba/cuda/tests/cudapy/test_laplace.py +1 -1
  36. numba_cuda/numba/cuda/tests/cudapy/test_warp_ops.py +50 -5
  37. numba_cuda/numba/cuda/vector_types.py +3 -1
  38. numba_cuda/numba/cuda/vectorizers.py +1 -1
  39. {numba_cuda-0.9.0.dist-info → numba_cuda-0.10.1.dist-info}/METADATA +1 -1
  40. {numba_cuda-0.9.0.dist-info → numba_cuda-0.10.1.dist-info}/RECORD +43 -33
  41. {numba_cuda-0.9.0.dist-info → numba_cuda-0.10.1.dist-info}/WHEEL +1 -1
  42. {numba_cuda-0.9.0.dist-info → numba_cuda-0.10.1.dist-info}/licenses/LICENSE +0 -0
  43. {numba_cuda-0.9.0.dist-info → numba_cuda-0.10.1.dist-info}/top_level.txt +0 -0

numba_cuda/numba/cuda/cuda_paths.py

@@ -201,6 +201,8 @@ def _get_nvrtc_wheel():
 
 def _get_libdevice_paths():
     by, libdir = _get_libdevice_path_decision()
+    if not libdir:
+        return _env_path_tuple(by, None)
     out = os.path.join(libdir, "libdevice.10.bc")
     return _env_path_tuple(by, out)
 
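The guard added here matters because `_get_libdevice_path_decision()` can return `None` for `libdir` when no libdevice location is found; previously that `None` reached `os.path.join` and raised a `TypeError` instead of reporting a missing path. A minimal sketch of the new control flow, with `_env_path_tuple` stubbed as a plain `(by, path)` pair for illustration:

    import os

    def _env_path_tuple(by, path):  # stand-in for the real helper
        return (by, path)

    def get_libdevice_path(by, libdir):
        # Mirrors the hunk above: bail out early with a None path
        # rather than crashing in os.path.join(None, ...).
        if not libdir:
            return _env_path_tuple(by, None)
        return _env_path_tuple(by, os.path.join(libdir, "libdevice.10.bc"))

    assert get_libdevice_path("<unknown>", None) == ("<unknown>", None)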

numba_cuda/numba/cuda/cudadecl.py

@@ -100,45 +100,6 @@ class Cuda_syncwarp(ConcreteTemplate):
     cases = [signature(types.none), signature(types.none, types.i4)]
 
 
-@register
-class Cuda_shfl_sync_intrinsic(ConcreteTemplate):
-    key = cuda.shfl_sync_intrinsic
-    cases = [
-        signature(
-            types.Tuple((types.i4, types.b1)),
-            types.i4,
-            types.i4,
-            types.i4,
-            types.i4,
-            types.i4,
-        ),
-        signature(
-            types.Tuple((types.i8, types.b1)),
-            types.i4,
-            types.i4,
-            types.i8,
-            types.i4,
-            types.i4,
-        ),
-        signature(
-            types.Tuple((types.f4, types.b1)),
-            types.i4,
-            types.i4,
-            types.f4,
-            types.i4,
-            types.i4,
-        ),
-        signature(
-            types.Tuple((types.f8, types.b1)),
-            types.i4,
-            types.i4,
-            types.f8,
-            types.i4,
-            types.i4,
-        ),
-    ]
-
-
 @register
 class Cuda_vote_sync_intrinsic(ConcreteTemplate):
     key = cuda.vote_sync_intrinsic
@@ -815,9 +776,6 @@ class CudaModuleTemplate(AttributeTemplate):
     def resolve_syncwarp(self, mod):
         return types.Function(Cuda_syncwarp)
 
-    def resolve_shfl_sync_intrinsic(self, mod):
-        return types.Function(Cuda_shfl_sync_intrinsic)
-
     def resolve_vote_sync_intrinsic(self, mod):
         return types.Function(Cuda_vote_sync_intrinsic)
 

numba_cuda/numba/cuda/cudadrv/linkable_code.py

@@ -1,10 +1,13 @@
+import io
 from .mappings import FILE_EXTENSION_MAP
 
 
 class LinkableCode:
     """An object that holds code to be linked from memory.
 
-    :param data: A buffer containing the data to link.
+    :param data: A buffer, StringIO or BytesIO containing the data to link.
+                 If a file object is passed, the content in the object is
+                 read when `data` property is accessed.
     :param name: The name of the file to be referenced in any compilation or
                  linking errors that may be produced.
     :param setup_callback: A function called prior to the launch of a kernel
@@ -23,8 +26,8 @@ class LinkableCode:
         if teardown_callback and not callable(teardown_callback):
             raise TypeError("teardown_callback must be callable")
 
-        self.data = data
         self._name = name
+        self._data = data
        self.setup_callback = setup_callback
         self.teardown_callback = teardown_callback
 
@@ -32,6 +35,12 @@ class LinkableCode:
     def name(self):
         return self._name or self.default_name
 
+    @property
+    def data(self):
+        if isinstance(self._data, (io.StringIO, io.BytesIO)):
+            return self._data.getvalue()
+        return self._data
+
 
 class PTXSource(LinkableCode):
     """PTX source code in memory."""

numba_cuda/numba/cuda/cudadrv/nvrtc.py

@@ -372,19 +372,26 @@ def compile(src, name, cc, ltoir=False):
         f"-I{get_cuda_paths()['include_dir'].info}",
     ]
 
+    nvrtc_version = nvrtc.get_version()
+    nvrtc_ver_major = nvrtc_version[0]
+
     cudadrv_path = os.path.dirname(os.path.abspath(__file__))
     numba_cuda_path = os.path.dirname(cudadrv_path)
-    numba_include = f"-I{numba_cuda_path}"
+
+    if nvrtc_ver_major == 11:
+        numba_include = f"-I{os.path.join(numba_cuda_path, 'include', '11')}"
+    else:
+        numba_include = f"-I{os.path.join(numba_cuda_path, 'include', '12')}"
 
     nrt_path = os.path.join(numba_cuda_path, "runtime")
     nrt_include = f"-I{nrt_path}"
 
-    options = [arch, *cuda_include, numba_include, nrt_include, "-rdc", "true"]
+    options = [arch, numba_include, *cuda_include, nrt_include, "-rdc", "true"]
 
     if ltoir:
         options.append("-dlto")
 
-    if nvrtc.get_version() < (12, 0):
+    if nvrtc_version < (12, 0):
         options += ["-std=c++17"]
 
     # Compile the program
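The net effect is that NVRTC's major version now selects between the vendored header trees added in this release (`include/11` and `include/12` in the file list above). A standalone sketch of the same selection logic:

    import os

    def numba_include_flag(numba_cuda_path, nvrtc_version):
        # CUDA 11 toolchains get the include/11 headers; anything newer
        # falls through to include/12, mirroring the hunk above.
        subdir = "11" if nvrtc_version[0] == 11 else "12"
        return f"-I{os.path.join(numba_cuda_path, 'include', subdir)}"

    assert numba_include_flag("/pkg", (11, 8)).endswith(os.path.join("include", "11"))
    assert numba_include_flag("/pkg", (12, 4)).endswith(os.path.join("include", "12"))

Note that `numba_include` also moves ahead of `*cuda_include` in the options list, so the vendored `cuda_fp16.h`/`cuda_bf16.h` headers are found before any copies in the CTK include directory.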

numba_cuda/numba/cuda/cudaimpl.py

@@ -204,69 +204,6 @@ def ptx_syncwarp_mask(context, builder, sig, args):
     return context.get_dummy_value()
 
 
-@lower(
-    stubs.shfl_sync_intrinsic, types.i4, types.i4, types.i4, types.i4, types.i4
-)
-@lower(
-    stubs.shfl_sync_intrinsic, types.i4, types.i4, types.i8, types.i4, types.i4
-)
-@lower(
-    stubs.shfl_sync_intrinsic, types.i4, types.i4, types.f4, types.i4, types.i4
-)
-@lower(
-    stubs.shfl_sync_intrinsic, types.i4, types.i4, types.f8, types.i4, types.i4
-)
-def ptx_shfl_sync_i32(context, builder, sig, args):
-    """
-    The NVVM intrinsic for shfl only supports i32, but the cuda intrinsic
-    function supports both 32 and 64 bit ints and floats, so for feature parity,
-    i64, f32, and f64 are implemented. Floats by way of bitcasting the float to
-    an int, then shuffling, then bitcasting back. And 64-bit values by packing
-    them into 2 32bit values, shuffling thoose, and then packing back together.
-    """
-    mask, mode, value, index, clamp = args
-    value_type = sig.args[2]
-    if value_type in types.real_domain:
-        value = builder.bitcast(value, ir.IntType(value_type.bitwidth))
-    fname = "llvm.nvvm.shfl.sync.i32"
-    lmod = builder.module
-    fnty = ir.FunctionType(
-        ir.LiteralStructType((ir.IntType(32), ir.IntType(1))),
-        (
-            ir.IntType(32),
-            ir.IntType(32),
-            ir.IntType(32),
-            ir.IntType(32),
-            ir.IntType(32),
-        ),
-    )
-    func = cgutils.get_or_insert_function(lmod, fnty, fname)
-    if value_type.bitwidth == 32:
-        ret = builder.call(func, (mask, mode, value, index, clamp))
-        if value_type == types.float32:
-            rv = builder.extract_value(ret, 0)
-            pred = builder.extract_value(ret, 1)
-            fv = builder.bitcast(rv, ir.FloatType())
-            ret = cgutils.make_anonymous_struct(builder, (fv, pred))
-    else:
-        value1 = builder.trunc(value, ir.IntType(32))
-        value_lshr = builder.lshr(value, context.get_constant(types.i8, 32))
-        value2 = builder.trunc(value_lshr, ir.IntType(32))
-        ret1 = builder.call(func, (mask, mode, value1, index, clamp))
-        ret2 = builder.call(func, (mask, mode, value2, index, clamp))
-        rv1 = builder.extract_value(ret1, 0)
-        rv2 = builder.extract_value(ret2, 0)
-        pred = builder.extract_value(ret1, 1)
-        rv1_64 = builder.zext(rv1, ir.IntType(64))
-        rv2_64 = builder.zext(rv2, ir.IntType(64))
-        rv_shl = builder.shl(rv2_64, context.get_constant(types.i8, 32))
-        rv = builder.or_(rv_shl, rv1_64)
-        if value_type == types.float64:
-            rv = builder.bitcast(rv, ir.DoubleType())
-        ret = cgutils.make_anonymous_struct(builder, (rv, pred))
-    return ret
-
-
 @lower(stubs.vote_sync_intrinsic, types.i4, types.i4, types.boolean)
 def ptx_vote_sync(context, builder, sig, args):
     fname = "llvm.nvvm.vote.sync"
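The removed docstring describes the 64-bit strategy: split the value into two 32-bit halves, shuffle each half with the i32 NVVM intrinsic, then reassemble (floats take an extra bitcast on each side). A quick arithmetic check of that pack/unpack round trip:

    value = 0x1122334455667788
    lo = value & 0xFFFFFFFF          # low half, shuffled on its own
    hi = (value >> 32) & 0xFFFFFFFF  # high half, shuffled on its own
    assert (hi << 32) | lo == value  # recombining restores the original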

numba_cuda/numba/cuda/debuginfo.py

@@ -1,5 +1,5 @@
 from llvmlite import ir
-from numba.core import types
+from numba.core import types, cgutils
 from numba.core.debuginfo import DIBuilder
 from numba.cuda.types import GridGroup
 
@@ -7,8 +7,14 @@ _BYTE_SIZE = 8
 
 
 class CUDADIBuilder(DIBuilder):
+    def __init__(self, module, filepath, cgctx, directives_only):
+        super().__init__(module, filepath, cgctx, directives_only)
+        # Cache for local variable metadata type and line deduplication
+        self._vartypelinemap = {}
+
     def _var_type(self, lltype, size, datamodel=None):
         is_bool = False
+        is_int_literal = False
         is_grid_group = False
 
         if isinstance(lltype, ir.IntType):
@@ -20,15 +26,23 @@ class CUDADIBuilder(DIBuilder):
                 name = str(datamodel.fe_type)
                 if isinstance(datamodel.fe_type, types.Boolean):
                     is_bool = True
+                    if isinstance(datamodel.fe_type, types.BooleanLiteral):
+                        name = "bool"
+                elif isinstance(datamodel.fe_type, types.Integer):
+                    if isinstance(datamodel.fe_type, types.IntegerLiteral):
+                        name = f"int{_BYTE_SIZE * size}"
+                        is_int_literal = True
                 elif isinstance(datamodel.fe_type, GridGroup):
                     is_grid_group = True
 
-        if is_bool or is_grid_group:
+        if is_bool or is_int_literal or is_grid_group:
             m = self.module
             bitsize = _BYTE_SIZE * size
             # Boolean type workaround until upstream Numba is fixed
             if is_bool:
                 ditok = "DW_ATE_boolean"
+            elif is_int_literal:
+                ditok = "DW_ATE_signed"
             # GridGroup type should use numba.cuda implementation
             elif is_grid_group:
                 ditok = "DW_ATE_unsigned"
@@ -44,3 +58,79 @@ class CUDADIBuilder(DIBuilder):
 
         # For other cases, use upstream Numba implementation
         return super()._var_type(lltype, size, datamodel=datamodel)
+
+    def mark_variable(
+        self,
+        builder,
+        allocavalue,
+        name,
+        lltype,
+        size,
+        line,
+        datamodel=None,
+        argidx=None,
+    ):
+        if name.startswith("$") or "." in name:
+            # Do not emit llvm.dbg.declare on user variable alias
+            return
+        else:
+            int_type = (ir.IntType,)
+            real_type = ir.FloatType, ir.DoubleType
+            if isinstance(lltype, int_type + real_type):
+                # Start with scalar variable, swtiching llvm.dbg.declare
+                # to llvm.dbg.value
+                return
+            else:
+                return super().mark_variable(
+                    builder,
+                    allocavalue,
+                    name,
+                    lltype,
+                    size,
+                    line,
+                    datamodel,
+                    argidx,
+                )
+
+    def update_variable(
+        self,
+        builder,
+        value,
+        name,
+        lltype,
+        size,
+        line,
+        datamodel=None,
+        argidx=None,
+    ):
+        m = self.module
+        fnty = ir.FunctionType(ir.VoidType(), [ir.MetaDataType()] * 3)
+        decl = cgutils.get_or_insert_function(m, fnty, "llvm.dbg.value")
+
+        mdtype = self._var_type(lltype, size, datamodel)
+        index = name.find(".")
+        if index >= 0:
+            name = name[:index]
+        # Merge DILocalVariable nodes with same name and type but different
+        # lines. Use the cached [(name, type) -> line] info to deduplicate
+        # metadata. Use the lltype as part of key.
+        key = (name, lltype)
+        if key in self._vartypelinemap:
+            line = self._vartypelinemap[key]
+        else:
+            self._vartypelinemap[key] = line
+        arg_index = 0 if argidx is None else argidx
+        mdlocalvar = m.add_debug_info(
+            "DILocalVariable",
+            {
+                "name": name,
+                "arg": arg_index,
+                "scope": self.subprograms[-1],
+                "file": self.difile,
+                "line": line,
+                "type": mdtype,
+            },
+        )
+        mdexpr = m.add_debug_info("DIExpression", {})
+
+        return builder.call(decl, [value, mdlocalvar, mdexpr])
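The `_vartypelinemap` cache exists so that repeated assignments to one user variable produce a single `DILocalVariable` node pinned to its first line, instead of one node per assignment. A toy model of the lookup in `update_variable` above:

    _vartypelinemap = {}

    def dedup_line(name, lltype, line):
        # First sighting records the line; later sightings reuse it.
        return _vartypelinemap.setdefault((name, lltype), line)

    assert dedup_line("x", "i32", 10) == 10
    assert dedup_line("x", "i32", 14) == 10  # merged with the line-10 node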

numba_cuda/numba/cuda/decorators.py

@@ -16,7 +16,8 @@ _msg_deprecated_signature_arg = (
 def jit(
     func_or_sig=None,
     device=False,
-    inline=False,
+    inline="never",
+    forceinline=False,
     link=[],
     debug=None,
     opt=None,
@@ -39,6 +40,14 @@ def jit(
     .. note:: A kernel cannot have any return value.
     :param device: Indicates whether this is a device function.
     :type device: bool
+    :param inline: Enables inlining at the Numba IR level when set to
+        ``"always"``. See `Notes on Inlining
+        <https://numba.readthedocs.io/en/stable/developer/inlining.html>`_.
+    :type inline: str
+    :param forceinline: Enables inlining at the NVVM IR level when set to
+        ``True``. This is accomplished by adding the ``alwaysinline`` function
+        attribute to the function definition.
+    :type forceinline: bool
     :param link: A list of files containing PTX or CUDA C/C++ source to link
         with the function
     :type link: list
@@ -81,6 +90,17 @@ def jit(
         msg = _msg_deprecated_signature_arg.format("bind")
         raise DeprecationError(msg)
 
+    if isinstance(inline, bool):
+        DeprecationWarning(
+            "Passing bool to inline argument is deprecated, please refer to "
+            "Numba's documentation on inlining: "
+            "https://numba.readthedocs.io/en/stable/developer/inlining.html. "
+            "You may have wanted the forceinline argument instead, to force "
+            "inlining at the NVVM IR level."
+        )
+
+        inline = "always" if inline else "never"
+
     debug = config.CUDA_DEBUGINFO_DEFAULT if debug is None else debug
     opt = (config.OPT != 0) if opt is None else opt
     fastmath = kws.get("fastmath", False)
@@ -130,6 +150,8 @@ def jit(
            targetoptions["opt"] = opt
            targetoptions["fastmath"] = fastmath
            targetoptions["device"] = device
+           targetoptions["inline"] = inline
+           targetoptions["forceinline"] = forceinline
            targetoptions["extensions"] = extensions
 
            disp = CUDADispatcher(func, targetoptions=targetoptions)
@@ -171,6 +193,8 @@ def jit(
             return jit(
                 func,
                 device=device,
+                inline=inline,
+                forceinline=forceinline,
                 debug=debug,
                 opt=opt,
                 lineinfo=lineinfo,
@@ -194,6 +218,8 @@ def jit(
         targetoptions["link"] = link
         targetoptions["fastmath"] = fastmath
         targetoptions["device"] = device
+        targetoptions["inline"] = inline
+        targetoptions["forceinline"] = forceinline
         targetoptions["extensions"] = extensions
         disp = CUDADispatcher(func_or_sig, targetoptions=targetoptions)
 
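The decorator now separates the two inlining levels: `inline="always"` inlines at the Numba IR level before lowering, while `forceinline=True` attaches the NVVM `alwaysinline` attribute. A usage sketch of the new keywords:

    from numba import cuda

    @cuda.jit(device=True, inline="always", forceinline=True)
    def add(a, b):
        # Inlined into callers at the Numba IR level, and also marked
        # alwaysinline for NVVM.
        return a + b

Passing a bool to `inline` still maps to "always"/"never" for backward compatibility; note that the `DeprecationWarning` in the hunk above is constructed but never passed to `warnings.warn`, so as written it is silent.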

numba_cuda/numba/cuda/device_init.py

@@ -13,7 +13,6 @@ from .stubs import (
     local,
     const,
     atomic,
-    shfl_sync_intrinsic,
     vote_sync_intrinsic,
     match_any_sync,
     match_all_sync,
@@ -40,6 +39,10 @@ from .intrinsics import (
     syncthreads_and,
     syncthreads_count,
     syncthreads_or,
+    shfl_sync,
+    shfl_up_sync,
+    shfl_down_sync,
+    shfl_xor_sync,
 )
 from .cudadrv.error import CudaSupportError
 from numba.cuda.cudadrv.driver import (
@@ -68,10 +71,6 @@ from .intrinsic_wrapper import (
     any_sync,
     eq_sync,
     ballot_sync,
-    shfl_sync,
-    shfl_up_sync,
-    shfl_down_sync,
-    shfl_xor_sync,
 )
 
 from .kernels import reduction
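The shuffle functions keep their public names under `numba.cuda`; only their implementation moved from stub-plus-wrapper pairs into `.intrinsics`. A small kernel sketch using the re-exported `cuda.shfl_sync` (launch configuration and device array setup omitted):

    from numba import cuda

    @cuda.jit
    def broadcast_lane0(out):
        # Every thread in the warp receives lane 0's value.
        val = cuda.shfl_sync(0xFFFFFFFF, cuda.threadIdx.x, 0)
        out[cuda.threadIdx.x] = val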

numba_cuda/numba/cuda/dispatcher.py

@@ -137,6 +137,7 @@ class _Kernel(serialize.ReduceMixin):
         debug=False,
         lineinfo=False,
         inline=False,
+        forceinline=False,
         fastmath=False,
         extensions=None,
         max_registers=None,
@@ -182,7 +183,7 @@ class _Kernel(serialize.ReduceMixin):
             self.argtypes,
             debug=self.debug,
             lineinfo=lineinfo,
-            inline=inline,
+            forceinline=forceinline,
             fastmath=fastmath,
             nvvm_options=nvvm_options,
             cc=cc,
@@ -1073,7 +1074,7 @@ class CUDADispatcher(Dispatcher, serialize.ReduceMixin):
         with self._compiling_counter:
             debug = self.targetoptions.get("debug")
             lineinfo = self.targetoptions.get("lineinfo")
-            inline = self.targetoptions.get("inline")
+            forceinline = self.targetoptions.get("forceinline")
             fastmath = self.targetoptions.get("fastmath")
 
             nvvm_options = {
@@ -1091,7 +1092,7 @@ class CUDADispatcher(Dispatcher, serialize.ReduceMixin):
                 args,
                 debug=debug,
                 lineinfo=lineinfo,
-                inline=inline,
+                forceinline=forceinline,
                 fastmath=fastmath,
                 nvvm_options=nvvm_options,
                 cc=cc,

numba_cuda/numba/cuda/extending.py

@@ -3,5 +3,59 @@ Added for symmetry with the core API
 """
 
 from numba.core.extending import intrinsic as _intrinsic
+from numba.cuda.models import register_model  # noqa: F401
+from numba.cuda import models  # noqa: F401
 
 intrinsic = _intrinsic(target="cuda")
+
+
+def make_attribute_wrapper(typeclass, struct_attr, python_attr):
+    """
+    Make an automatic attribute wrapper exposing member named *struct_attr*
+    as a read-only attribute named *python_attr*.
+    The given *typeclass*'s model must be a StructModel subclass.
+
+    Vendored from numba.core.extending with a change to consider the CUDA data
+    model manager.
+    """
+    from numba.core.typing.templates import AttributeTemplate
+
+    from numba.core.datamodel import default_manager
+    from numba.core.datamodel.models import StructModel
+    from numba.core.imputils import impl_ret_borrowed
+    from numba.core import cgutils, types
+
+    from numba.cuda.models import cuda_data_manager
+    from numba.cuda.cudadecl import registry as cuda_registry
+    from numba.cuda.cudaimpl import registry as cuda_impl_registry
+
+    data_model_manager = cuda_data_manager.chain(default_manager)
+
+    if not isinstance(typeclass, type) or not issubclass(typeclass, types.Type):
+        raise TypeError(f"typeclass should be a Type subclass, got {typeclass}")
+
+    def get_attr_fe_type(typ):
+        """
+        Get the Numba type of member *struct_attr* in *typ*.
+        """
+        model = data_model_manager.lookup(typ)
+        if not isinstance(model, StructModel):
+            raise TypeError(
+                f"make_attribute_wrapper() needs a type with a StructModel, but got {model}"
+            )
+        return model.get_member_fe_type(struct_attr)
+
+    @cuda_registry.register_attr
+    class StructAttribute(AttributeTemplate):
+        key = typeclass
+
+        def generic_resolve(self, typ, attr):
+            if attr == python_attr:
+                return get_attr_fe_type(typ)
+
+    @cuda_impl_registry.lower_getattr(typeclass, python_attr)
+    def struct_getattr_impl(context, builder, typ, val):
+        val = cgutils.create_struct_proxy(typ)(context, builder, value=val)
+        attrty = get_attr_fe_type(typ)
+        attrval = getattr(val, struct_attr)
+        return impl_ret_borrowed(context, builder, attrty, attrval)
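A hedged end-to-end sketch of the vendored helper; the `Interval` type and its members are illustrative (not from the diff), and it assumes the `register_model` re-exported above mirrors the core decorator API:

    from numba.core import types
    from numba.core.datamodel import models
    from numba.cuda.extending import make_attribute_wrapper, register_model

    class Interval(types.Type):
        # Hypothetical frontend type for illustration only.
        def __init__(self):
            super().__init__(name="Interval")

    @register_model(Interval)
    class IntervalModel(models.StructModel):
        def __init__(self, dmm, fe_type):
            members = [("lo", types.float64), ("hi", types.float64)]
            super().__init__(dmm, fe_type, members)

    # Exposes .lo/.hi as read-only attributes in device code, resolved via
    # the CUDA data model manager chained with the default manager.
    make_attribute_wrapper(Interval, "lo", "lo")
    make_attribute_wrapper(Interval, "hi", "hi")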