numba-cuda 0.18.0__py3-none-any.whl → 0.19.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of numba-cuda might be problematic. Click here for more details.

Files changed (90)
  1. numba_cuda/VERSION +1 -1
  2. numba_cuda/numba/cuda/__init__.py +1 -1
  3. numba_cuda/numba/cuda/_internal/cuda_bf16.py +2 -2
  4. numba_cuda/numba/cuda/_internal/cuda_fp16.py +1 -1
  5. numba_cuda/numba/cuda/api.py +2 -7
  6. numba_cuda/numba/cuda/compiler.py +7 -4
  7. numba_cuda/numba/cuda/core/interpreter.py +3592 -0
  8. numba_cuda/numba/cuda/core/ir_utils.py +2645 -0
  9. numba_cuda/numba/cuda/core/sigutils.py +55 -0
  10. numba_cuda/numba/cuda/cuda_paths.py +9 -17
  11. numba_cuda/numba/cuda/cudadecl.py +1 -1
  12. numba_cuda/numba/cuda/cudadrv/driver.py +4 -19
  13. numba_cuda/numba/cuda/cudadrv/libs.py +1 -2
  14. numba_cuda/numba/cuda/cudadrv/nvrtc.py +44 -44
  15. numba_cuda/numba/cuda/cudadrv/nvvm.py +3 -18
  16. numba_cuda/numba/cuda/cudadrv/runtime.py +12 -1
  17. numba_cuda/numba/cuda/cudamath.py +1 -1
  18. numba_cuda/numba/cuda/decorators.py +4 -3
  19. numba_cuda/numba/cuda/deviceufunc.py +2 -1
  20. numba_cuda/numba/cuda/dispatcher.py +5 -3
  21. numba_cuda/numba/cuda/extending.py +1 -1
  22. numba_cuda/numba/cuda/itanium_mangler.py +211 -0
  23. numba_cuda/numba/cuda/libdevicedecl.py +1 -1
  24. numba_cuda/numba/cuda/libdevicefuncs.py +1 -1
  25. numba_cuda/numba/cuda/lowering.py +1 -1
  26. numba_cuda/numba/cuda/simulator/api.py +1 -1
  27. numba_cuda/numba/cuda/simulator/cudadrv/driver.py +0 -7
  28. numba_cuda/numba/cuda/target.py +1 -2
  29. numba_cuda/numba/cuda/testing.py +4 -6
  30. numba_cuda/numba/cuda/tests/core/test_itanium_mangler.py +80 -0
  31. numba_cuda/numba/cuda/tests/cudadrv/test_context_stack.py +10 -4
  32. numba_cuda/numba/cuda/tests/cudadrv/test_cuda_ndarray.py +1 -1
  33. numba_cuda/numba/cuda/tests/cudadrv/test_deallocations.py +1 -1
  34. numba_cuda/numba/cuda/tests/cudadrv/test_detect.py +1 -1
  35. numba_cuda/numba/cuda/tests/cudadrv/test_emm_plugins.py +1 -1
  36. numba_cuda/numba/cuda/tests/cudadrv/test_linker.py +1 -1
  37. numba_cuda/numba/cuda/tests/cudadrv/test_managed_alloc.py +1 -1
  38. numba_cuda/numba/cuda/tests/cudadrv/test_mvc.py +1 -1
  39. numba_cuda/numba/cuda/tests/cudadrv/test_nvrtc.py +4 -6
  40. numba_cuda/numba/cuda/tests/cudadrv/test_nvvm_driver.py +0 -4
  41. numba_cuda/numba/cuda/tests/cudadrv/test_ptds.py +1 -1
  42. numba_cuda/numba/cuda/tests/cudapy/test_bfloat16.py +1 -3
  43. numba_cuda/numba/cuda/tests/cudapy/test_bfloat16_bindings.py +1 -3
  44. numba_cuda/numba/cuda/tests/cudapy/test_caching.py +146 -3
  45. numba_cuda/numba/cuda/tests/cudapy/test_cffi.py +1 -1
  46. numba_cuda/numba/cuda/tests/cudapy/test_compiler.py +15 -4
  47. numba_cuda/numba/cuda/tests/cudapy/test_cuda_array_interface.py +1 -1
  48. numba_cuda/numba/cuda/tests/cudapy/test_cuda_jit_no_types.py +1 -1
  49. numba_cuda/numba/cuda/tests/cudapy/test_debug.py +1 -1
  50. numba_cuda/numba/cuda/tests/cudapy/test_debuginfo.py +1 -284
  51. numba_cuda/numba/cuda/tests/cudapy/test_debuginfo_types.py +473 -0
  52. numba_cuda/numba/cuda/tests/cudapy/test_device_func.py +1 -1
  53. numba_cuda/numba/cuda/tests/cudapy/test_errors.py +1 -1
  54. numba_cuda/numba/cuda/tests/cudapy/test_extending.py +1 -6
  55. numba_cuda/numba/cuda/tests/cudapy/test_gufunc.py +1 -1
  56. numba_cuda/numba/cuda/tests/cudapy/test_inspect.py +14 -0
  57. numba_cuda/numba/cuda/tests/cudapy/test_ipc.py +1 -1
  58. numba_cuda/numba/cuda/tests/cudapy/test_ir_utils.py +295 -0
  59. numba_cuda/numba/cuda/tests/cudapy/test_lineinfo.py +1 -1
  60. numba_cuda/numba/cuda/tests/cudapy/test_operator.py +1 -1
  61. numba_cuda/numba/cuda/tests/cudapy/test_ufuncs.py +1 -1
  62. numba_cuda/numba/cuda/tests/cudapy/test_warning.py +5 -1
  63. numba_cuda/numba/cuda/tests/doc_examples/test_cpointer.py +1 -1
  64. numba_cuda/numba/cuda/tests/doc_examples/test_cpu_gpu_compat.py +1 -1
  65. numba_cuda/numba/cuda/tests/doc_examples/test_ffi.py +1 -1
  66. numba_cuda/numba/cuda/tests/doc_examples/test_laplace.py +1 -1
  67. numba_cuda/numba/cuda/tests/doc_examples/test_matmul.py +1 -1
  68. numba_cuda/numba/cuda/tests/doc_examples/test_montecarlo.py +1 -1
  69. numba_cuda/numba/cuda/tests/doc_examples/test_reduction.py +1 -1
  70. numba_cuda/numba/cuda/tests/doc_examples/test_sessionize.py +1 -1
  71. numba_cuda/numba/cuda/tests/doc_examples/test_ufunc.py +1 -1
  72. numba_cuda/numba/cuda/tests/doc_examples/test_vecadd.py +1 -1
  73. numba_cuda/numba/cuda/tests/nocuda/test_import.py +1 -1
  74. numba_cuda/numba/cuda/tests/nrt/test_nrt.py +2 -2
  75. numba_cuda/numba/cuda/tests/nrt/test_nrt_refct.py +1 -1
  76. numba_cuda/numba/cuda/tests/support.py +752 -0
  77. numba_cuda/numba/cuda/tests/test_binary_generation/Makefile +3 -3
  78. numba_cuda/numba/cuda/tests/test_binary_generation/generate_raw_ltoir.py +4 -1
  79. numba_cuda/numba/cuda/typing/__init__.py +8 -0
  80. numba_cuda/numba/cuda/typing/templates.py +1453 -0
  81. numba_cuda/numba/cuda/vector_types.py +3 -3
  82. {numba_cuda-0.18.0.dist-info → numba_cuda-0.19.0.dist-info}/METADATA +21 -28
  83. {numba_cuda-0.18.0.dist-info → numba_cuda-0.19.0.dist-info}/RECORD +86 -81
  84. numba_cuda/numba/cuda/include/11/cuda_bf16.h +0 -3749
  85. numba_cuda/numba/cuda/include/11/cuda_bf16.hpp +0 -2683
  86. numba_cuda/numba/cuda/include/11/cuda_fp16.h +0 -3794
  87. numba_cuda/numba/cuda/include/11/cuda_fp16.hpp +0 -2614
  88. {numba_cuda-0.18.0.dist-info → numba_cuda-0.19.0.dist-info}/WHEEL +0 -0
  89. {numba_cuda-0.18.0.dist-info → numba_cuda-0.19.0.dist-info}/licenses/LICENSE +0 -0
  90. {numba_cuda-0.18.0.dist-info → numba_cuda-0.19.0.dist-info}/top_level.txt +0 -0
@@ -0,0 +1,1453 @@
1
+ """
2
+ Define typing templates
3
+ """
4
+
5
+ from abc import ABC, abstractmethod
6
+ import functools
7
+ import sys
8
+ import inspect
9
+ import os.path
10
+ from collections import namedtuple
11
+ from collections.abc import Sequence
12
+ from types import MethodType, FunctionType, MappingProxyType
13
+
14
+ import numba
15
+ from numba.core import types, utils, targetconfig
16
+ from numba.core.errors import (
17
+ TypingError,
18
+ InternalError,
19
+ )
20
+ from numba.core.cpu_options import InlineOptions
21
+ from numba.core.typing.templates import Signature as CoreSignature
22
+ from numba.cuda.core import ir_utils
23
+
24
# Record handed to inliner callbacks (e.g. a user-supplied cost model):
# the typed function IR, its typemap and calltypes, and the call signature.
_inline_info = namedtuple(
    "inline_info", ["func_ir", "typemap", "calltypes", "signature"]
)
26
+
27
+
28
# HACK: Remove this inheritance once all references to CoreSignature are removed
class Signature(CoreSignature):
    """
    The signature of a function call or operation, i.e. its argument types
    and return type, plus an optional receiver type (``recvr``, set for
    bound-method signatures) and an optional Python signature (``pysig``).
    """

    # XXX Perhaps the signature should be a BoundArguments, instead
    # of separate args and pysig...
    __slots__ = "_return_type", "_args", "_recvr", "_pysig"

    def __init__(self, return_type, args, recvr, pysig=None):
        # Normalize lists to tuples so ``args`` is hashable (see __hash__).
        if isinstance(args, list):
            args = tuple(args)
        self._return_type = return_type
        self._args = args
        self._recvr = recvr
        self._pysig = pysig

    @property
    def return_type(self):
        return self._return_type

    @property
    def args(self):
        return self._args

    @property
    def recvr(self):
        return self._recvr

    @property
    def pysig(self):
        return self._pysig

    def replace(self, **kwargs):
        """Copy and replace the given attributes provided as keyword arguments.
        Returns an updated copy.
        """
        curstate = dict(
            return_type=self.return_type,
            args=self.args,
            recvr=self.recvr,
            pysig=self.pysig,
        )
        curstate.update(kwargs)
        return Signature(**curstate)

    def __getstate__(self):
        """
        Needed because of __slots__.
        """
        return self._return_type, self._args, self._recvr, self._pysig

    def __setstate__(self, state):
        """
        Needed because of __slots__.
        """
        self._return_type, self._args, self._recvr, self._pysig = state

    def __hash__(self):
        return hash((self.args, self.return_type))

    def __eq__(self, other):
        """
        Two signatures are equal when args, return type, receiver and pysig
        all match.

        Any ``CoreSignature`` (this class's base) is accepted as the other
        operand, so a CUDA signature can equal an equivalent core signature.
        Before, only other CUDA ``Signature`` objects were compared and
        everything else gave a falsy ``None``. Unrelated types now return
        ``NotImplemented`` so Python can fall back to its default handling.
        """
        if not isinstance(other, CoreSignature):
            return NotImplemented
        return (
            self.args == other.args
            and self.return_type == other.return_type
            and self.recvr == other.recvr
            and self.pysig == other.pysig
        )

    def __ne__(self, other):
        return not (self == other)

    def __repr__(self):
        return "%s -> %s" % (self.args, self.return_type)

    @property
    def is_method(self):
        """
        Whether this signature represents a bound method or a regular
        function.
        """
        return self.recvr is not None

    def as_method(self):
        """
        Convert this signature to a bound method signature.

        The first argument type becomes the receiver. If a Python signature
        is attached, its first parameter is dropped to match.
        """
        if self.recvr is not None:
            return self
        sig = signature(self.return_type, *self.args[1:], recvr=self.args[0])

        # Adjust the python signature, if there is one (a signature built
        # without a pysig used to crash here with AttributeError).
        if self.pysig is not None:
            params = list(self.pysig.parameters.values())[1:]
            sig = sig.replace(
                pysig=utils.pySignature(
                    parameters=params,
                    return_annotation=self.pysig.return_annotation,
                ),
            )
        return sig

    def as_function(self):
        """
        Convert this signature to a regular function signature, with the
        receiver (if any) prepended to the argument types.
        """
        if self.recvr is None:
            return self
        sig = signature(self.return_type, *((self.recvr,) + self.args))
        return sig

    def as_type(self):
        """
        Convert this signature to a first-class function type.
        """
        return types.FunctionType(self)

    def __unliteral__(self):
        return signature(
            types.unliteral(self.return_type), *map(types.unliteral, self.args)
        )

    def dump(self, tab=""):
        """Print a debugging description of the argument and return types."""
        c = self.as_type()._code
        print(f"{tab}DUMP {type(self).__name__} [type code: {c}]")
        print(f"{tab} Argument types:")
        for a in self.args:
            a.dump(tab=tab + " | ")
        print(f"{tab} Return type:")
        self.return_type.dump(tab=tab + " | ")
        print(f"{tab}END DUMP")

    def is_precise(self):
        """Return True if every argument type and the return type are precise."""
        for atype in self.args:
            if not atype.is_precise():
                return False
        return self.return_type.is_precise()
167
+
168
+
169
def make_concrete_template(name, key, signatures):
    """
    Create a ``ConcreteTemplate`` subclass called *name* whose ``key`` is
    *key* and whose ``cases`` are the given *signatures*, copied into a list.
    """
    namespace = {"key": key, "cases": [*signatures]}
    return type(name, (ConcreteTemplate,), namespace)
173
+
174
+
175
def make_callable_template(key, typer, recvr=None):
    """
    Create a callable template with the given key and typer function.

    The returned ``CallableTemplate`` subclass is named
    ``"<key>_CallableTemplate"``. Its ``generic()`` returns *typer*, and
    *recvr* is stored as its ``recvr`` attribute.
    """

    def generic(self):
        # The template's generic() simply hands back the captured typer.
        return typer

    return type(
        str(key) + "_CallableTemplate",
        (CallableTemplate,),
        {"key": key, "generic": generic, "recvr": recvr},
    )
187
+
188
+
189
def signature(return_type, *args, **kws):
    """
    Build a ``Signature`` from a return type and positional argument types.

    The only keyword accepted is ``recvr`` (the receiver type, defaulting
    to None). Any other keyword trips the assertion.
    """
    extra = dict(kws)
    receiver = extra.pop("recvr", None)
    assert len(extra) == 0
    return Signature(return_type, args, recvr=receiver)
193
+
194
+
195
def fold_arguments(
    pysig, args, kws, normal_handler, default_handler, stararg_handler
):
    """
    Given the signature *pysig*, explicit *args* and *kws*, resolve
    omitted arguments and keyword arguments. A tuple of positional
    arguments is returned.

    If *pysig* has keyword-only parameters, their values are taken from the
    trailing ``len(kwonly)`` entries of *args*, in declaration order. They
    are moved into the keyword mapping before binding.

    Various handlers allow processing of the arguments:
    - normal_handler(index, param, value) is called for normal arguments
    - default_handler(index, param, default) is called for omitted arguments
    - stararg_handler(index, param, values) is called for a "*args" argument

    Raises TypingError if *args* / *kws* cannot be bound to *pysig*.
    """
    if isinstance(kws, Sequence):
        # Normalize a sequence of (name, value) pairs into a dict
        kws = dict(kws)

    # deal with kwonly args
    params = pysig.parameters
    kwonly = [name for name, p in params.items() if p.kind == p.KEYWORD_ONLY]

    bind_kws = kws.copy()
    if kwonly:
        # The keyword-only values are the *trailing* entries of args, so
        # index them from the split point. Reading args[len(kwonly) + idx]
        # is only correct when len(args) == 2 * len(kwonly).
        split = max(len(args) - len(kwonly), 0)
        bind_args = args[:split]
        bind_kws.update(zip(kwonly, args[split:]))
    else:
        bind_args = args

    # now bind
    try:
        ba = pysig.bind(*bind_args, **bind_kws)
    except TypeError as e:
        # The binding attempt can raise if the args don't match up, this needs
        # to be converted to a TypingError so that e.g. partial type inference
        # doesn't just halt.
        msg = (
            f"Cannot bind 'args={bind_args} kws={bind_kws}' to "
            f"signature '{pysig}' due to \"{type(e).__name__}: {e}\"."
        )
        raise TypingError(msg) from e
    for i, param in enumerate(pysig.parameters.values()):
        name = param.name
        default = param.default
        if param.kind == param.VAR_POSITIONAL:
            # stararg may be omitted, in which case its "default" value
            # is simply the empty tuple
            if name in ba.arguments:
                argval = ba.arguments[name]
                # NOTE: avoid wrapping the tuple type for stararg in another
                # tuple.
                if len(argval) == 1 and isinstance(
                    argval[0], (types.StarArgTuple, types.StarArgUniTuple)
                ):
                    argval = tuple(argval[0])
            else:
                argval = ()
            out = stararg_handler(i, param, argval)

            ba.arguments[name] = out
        elif name in ba.arguments:
            # Non-stararg, present
            ba.arguments[name] = normal_handler(i, param, ba.arguments[name])
        else:
            # Non-stararg, omitted
            assert default is not param.empty
            ba.arguments[name] = default_handler(i, param, default)
    # Collect args in the right order
    args = tuple(
        ba.arguments[param.name] for param in pysig.parameters.values()
    )
    return args
270
+
271
+
272
class FunctionTemplate(ABC):
    """
    Abstract base class for function typing templates.

    Subclasses provide a ``key`` class attribute, which is used for overload
    resolution in ``_select`` and returned by ``get_impl_key``. They must
    also implement ``get_template_info``. ``self.context`` is the typing
    context passed at construction.
    """

    # Set to False to disable unsafe casting during overload resolution.
    # subclass overridable
    unsafe_casting = True
    # Set to True to require exact match without casting.
    # subclass overridable
    exact_match_required = False
    # Set to True to prefer literal arguments.
    # Useful for definitions that specialize on literal but also support
    # non-literals.
    # subclass overridable
    prefer_literal = False
    # metadata (e.g. a "target" entry consulted by overload templates).
    # NOTE: class-level dict, shared by subclasses that don't override it.
    metadata = {}

    def __init__(self, context):
        self.context = context

    def _select(self, cases, args, kws):
        """Resolve *args*/*kws* against the candidate signatures *cases*."""
        options = {
            "unsafe_casting": self.unsafe_casting,
            "exact_match_required": self.exact_match_required,
        }
        selected = self.context.resolve_overload(
            self.key, cases, args, kws, **options
        )
        return selected

    def get_impl_key(self, sig):
        """
        Return the key for looking up the implementation for the given
        signature on the target context.
        """
        # Look the key up on the class rather than the instance, so that a
        # plain function stored as ``key`` is not bound to ``self``. The old
        # Python 2 unbound-method unwrapping (``im_self``/``im_func``) is
        # gone: those attributes don't exist on Python 3, so it only ever
        # raised AttributeError for bound-method keys.
        return type(self).key

    @classmethod
    def get_source_code_info(cls, impl):
        """
        Gets the source information about function impl.
        Returns:

        code - str: source code as a string
        firstlineno - int: the first line number of the function impl
        path - str: the path to file containing impl

        if any of the above are not available something generic is returned
        """
        try:
            code, firstlineno = inspect.getsourcelines(impl)
        except OSError:  # missing source, probably a string
            code = "None available (built from string?)"
            firstlineno = 0
        path = inspect.getsourcefile(impl)
        if path is None:
            path = "<unknown> (built from string?)"
        return code, firstlineno, path

    @abstractmethod
    def get_template_info(self):
        """
        Returns a dictionary with information specific to the template that will
        govern how error messages are displayed to users. The dictionary must
        be of the form:
            info = {
                'kind': "unknown",  # str: The kind of template, e.g. "Overload"
                'name': "unknown",  # str: The name of the source function
                'sig': "unknown",  # str: The signature(s) of the source function
                'filename': "unknown",  # str: The filename of the source function
                'lines': ("start", "end"),  # tuple(int, int): The start and
                                            # end line of the source function.
                'docstring': "unknown",  # str: The docstring of the source function
            }
        """
        pass

    def __str__(self):
        info = self.get_template_info()
        srcinfo = f"{info['filename']}:{info['lines'][0]}"
        return f"<{self.__class__.__name__} {srcinfo}>"

    __repr__ = __str__
359
+
360
+
361
class AbstractTemplate(FunctionTemplate):
    """
    Template whose ``generic(self, args, kws)`` method computes a candidate
    signature from the input types. The candidate does not have to match
    the input types exactly; it is compared against them afterwards.
    """

    def apply(self, args, kws):
        generic = getattr(self, "generic")
        sig = generic(args, kws)

        # generic() must hand back None or a signature object.
        # HACK: the core Signature is accepted too until all references to
        # CoreSignature are removed.
        if sig is not None and not isinstance(
            sig, (Signature, CoreSignature)
        ):
            raise AssertionError(
                "generic() must return a Signature or None. "
                "{} returned {}".format(generic, type(sig)),
            )

        # Nothing matched but some arguments are Optional: retry once with
        # every Optional replaced by the type it wraps.
        if sig or not any(isinstance(ty, types.Optional) for ty in args):
            return sig
        args = [
            ty.type if isinstance(ty, types.Optional) else ty for ty in args
        ]
        assert not kws  # Not supported yet
        return generic(args, kws)

    def get_template_info(self):
        impl = getattr(self, "generic")
        numba_root = os.path.dirname(os.path.dirname(numba.__file__))

        source_lines, first_line, source_path = self.get_source_code_info(impl)
        pysig_text = str(utils.pysignature(impl))
        last_line = first_line + len(source_lines) - 1
        return {
            "kind": "overload",
            "name": getattr(impl, "__qualname__", impl.__name__),
            "sig": pysig_text,
            "filename": utils.safe_relpath(source_path, start=numba_root),
            "lines": (first_line, last_line),
            "docstring": impl.__doc__,
        }
412
+
413
+
414
class CallableTemplate(FunctionTemplate):
    """
    Base class for a template defining a ``generic(self)`` method
    returning a callable to be called with the actual ``*args`` and
    ``**kwargs`` representing the call signature. The callable has
    to return a return type, a full signature, or None. The signature
    does not have to match the input types. It is compared against the
    input types afterwards.
    """

    # Receiver type applied to the resulting signature, if not None.
    recvr = None

    def apply(self, args, kws):
        generic = getattr(self, "generic")
        typer = generic()
        match_sig = inspect.signature(typer)
        try:
            match_sig.bind(*args, **kws)
        except TypeError as e:
            # The typer cannot accept these arguments at all: report that as
            # a TypingError. (A ValueError from inspect.signature above, by
            # contrast, likely signals an unrecoverable problem.)
            raise TypingError(str(e)) from e

        sig = typer(*args, **kws)

        # Unpack optional type if no matching signature
        if sig is None:
            if any(isinstance(x, types.Optional) for x in args):
                args = [
                    x.type if isinstance(x, types.Optional) else x
                    for x in args
                ]
                sig = typer(*args, **kws)
            if sig is None:
                return

        # Get the pysig
        try:
            pysig = typer.pysig
        except AttributeError:
            pysig = utils.pysignature(typer)

        # Fold any keyword arguments
        bound = pysig.bind(*args, **kws)
        if bound.kwargs:
            raise TypingError("unsupported call signature")
        # The typer may return a full signature or just a return type. Any
        # CoreSignature counts as a full signature; that includes this
        # module's Signature subclass. Checking only the subclass made a
        # core Signature be misread as a return type and raise TypeError.
        if not isinstance(sig, CoreSignature):
            # If not a signature, `sig` is assumed to be the return type
            if not isinstance(sig, types.Type):
                raise TypeError(
                    "invalid return type for callable template: got %r" % (sig,)
                )
            sig = signature(sig, *bound.args)
        if self.recvr is not None:
            sig = sig.replace(recvr=self.recvr)
        # Hack any omitted parameters out of the typer's pysig,
        # as lowering expects an exact match between formal signature
        # and actual args.
        if len(bound.args) < len(pysig.parameters):
            parameters = list(pysig.parameters.values())[: len(bound.args)]
            pysig = pysig.replace(parameters=parameters)
        sig = sig.replace(pysig=pysig)
        cases = [sig]
        return self._select(cases, bound.args, bound.kwargs)

    def get_template_info(self):
        impl = getattr(self, "generic")
        basepath = os.path.dirname(os.path.dirname(numba.__file__))
        code, firstlineno, path = self.get_source_code_info(impl)
        sig = str(utils.pysignature(impl))
        info = {
            "kind": "overload",
            "name": getattr(
                self.key,
                "__name__",
                getattr(impl, "__qualname__", impl.__name__),
            ),
            "sig": sig,
            "filename": utils.safe_relpath(path, start=basepath),
            "lines": (firstlineno, firstlineno + len(code) - 1),
            "docstring": impl.__doc__,
        }
        return info
502
+
503
+
504
class ConcreteTemplate(FunctionTemplate):
    """
    Template that matches the input types against the fixed list of
    signatures held in its ``cases`` attribute.
    """

    def apply(self, args, kws):
        return self._select(getattr(self, "cases"), args, kws)

    def get_template_info(self):
        import operator

        name = getattr(self.key, "__name__", "unknown")
        # The key counts as an operator overload only when it is the very
        # function that the `operator` module exposes under the same name.
        op_func = getattr(operator, name, None)
        if op_func is not None and self.key is op_func:
            kind = "operator overload"
        else:
            kind = "Type restricted function"
        return {
            "kind": kind,
            "name": name,
            "sig": "unknown",
            "filename": "unknown",
            "lines": ("unknown", "unknown"),
            "docstring": "unknown",
        }
533
+
534
+
535
class _EmptyImplementationEntry(InternalError):
    """
    Placeholder stored as the compiled overload of an entry that is expected
    to be inlined. Seeing it during lowering means inlining did not happen.
    """

    def __init__(self, reason):
        super().__init__(f"_EmptyImplementationEntry({reason!r})")
540
+
541
+
542
+ class _OverloadFunctionTemplate(AbstractTemplate):
543
+ """
544
+ A base class of templates for overload functions.
545
+ """
546
+
547
def _validate_sigs(self, typing_func, impl_func):
    """
    Check that the implementation returned by an overload has a signature
    compatible with the typing (overload) function.

    The typing signature is considered golden and must be adhered to by
    the implementation.
    Things that are valid:
    1. args match exactly
    2. kwargs match exactly in name and default value
    3. Use of *args in the same location by the same name in both typing
       and implementation signature
    4. Use of *args in the implementation signature to consume any number
       of arguments in the typing signature.
    Things that are invalid:
    5. Use of *args in the typing signature that is not replicated
       in the implementing signature
    6. Use of **kwargs

    Raises InternalError describing the first mismatch found.
    """
    typing_sig = utils.pysignature(typing_func)
    impl_sig = utils.pysignature(impl_func)

    def get_args_kwargs(sig):
        # Split *sig* into parameters without a default (this includes any
        # *args), parameters with a default, and the *args parameter itself
        # (or None). **kwargs is rejected outright.
        kws = []
        args = []
        pos_arg = None
        for x in sig.parameters.values():
            # Compare with the sentinel by identity: ``==`` would call the
            # default value's own __eq__ (e.g. elementwise for arrays).
            if x.default is utils.pyParameter.empty:
                args.append(x)
                if x.kind == utils.pyParameter.VAR_POSITIONAL:
                    pos_arg = x
                elif x.kind == utils.pyParameter.VAR_KEYWORD:
                    msg = (
                        "The use of VAR_KEYWORD (e.g. **kwargs) is "
                        "unsupported. (offending argument name is '%s')"
                    )
                    raise InternalError(msg % x)
            else:
                kws.append(x)
        return args, kws, pos_arg

    ty_args, ty_kws, ty_pos = get_args_kwargs(typing_sig)
    im_args, im_kws, im_pos = get_args_kwargs(impl_sig)

    sig_fmt = "Typing signature: %s\nImplementation signature: %s"
    sig_str = sig_fmt % (typing_sig, impl_sig)

    err_prefix = "Typing and implementation arguments differ in "

    a = ty_args
    b = im_args
    if ty_pos:
        if not im_pos:
            # case 5. described above
            msg = (
                "VAR_POSITIONAL (e.g. *args) argument kind (offending "
                "argument name is '%s') found in the typing function "
                "signature, but is not in the implementing function "
                "signature.\n%s"
            ) % (ty_pos, sig_str)
            raise InternalError(msg)
    elif im_pos:
        # no *args in typing but there's a *args in the implementation
        # this is case 4. described above
        b = im_args[: im_args.index(im_pos)]
        if not b:
            # *args is the implementation's first required parameter, so it
            # consumes every typing argument and there is nothing to compare
            # (previously b[-1] raised an uncaught IndexError here).
            a = []
        else:
            try:
                a = ty_args[: ty_args.index(b[-1]) + 1]
            except ValueError:
                # there's no b[-1] arg name in the ty_args, something is
                # very wrong, we can't work out a diff (*args consumes
                # unknown quantity of args) so just report first error
                specialized = "argument names.\n%s\nFirst difference: '%s'"
                msg = err_prefix + specialized % (sig_str, b[-1])
                raise InternalError(msg)

    def gen_diff(typing, implementing):
        diff = set(typing) ^ set(implementing)
        return "Difference: %s" % diff

    if a != b:
        specialized = "argument names.\n%s\n%s" % (sig_str, gen_diff(a, b))
        raise InternalError(err_prefix + specialized)

    # ensure kwargs are the same
    ty = [x.name for x in ty_kws]
    im = [x.name for x in im_kws]
    if ty != im:
        specialized = "keyword argument names.\n%s\n%s"
        msg = err_prefix + specialized % (sig_str, gen_diff(ty_kws, im_kws))
        raise InternalError(msg)
    same = [x.default for x in ty_kws] == [x.default for x in im_kws]
    if not same:
        specialized = "keyword argument default values.\n%s\n%s"
        msg = err_prefix + specialized % (sig_str, gen_diff(ty_kws, im_kws))
        raise InternalError(msg)
639
+
640
def generic(self, args, kws):
    """
    Type the overloaded function by compiling the appropriate
    implementation for the given args.

    Returns the resolved signature, or None when no implementation is
    available. It records entries in ``self._compiled_overloads`` and, when
    inlining may happen, in ``self._inline_overloads``. Both are keyed by
    the signature's argument types.
    """
    from numba.core.typed_passes import PreLowerStripPhis

    dispatcher, resolved_args = self._get_impl(args, kws)
    if dispatcher is None:
        return
    # Compile and type it for the given types
    dispatcher_type = types.Dispatcher(dispatcher)

    if self._inline.is_never_inline:
        # Never inlined: only the compiled overload is needed for lowering.
        call_sig = dispatcher_type.get_call_type(
            self.context, resolved_args, kws
        )
        if call_sig is None:  # can't resolve for this target
            return None
        self._compiled_overloads[call_sig.args] = (
            dispatcher_type.get_overload(call_sig)
        )
        return call_sig

    # Inlining may happen: run the compiler front end up to type inference
    # to compute a signature and keep the typed IR for the inliner. Nothing
    # is compiled eagerly here, because inlined functions would never use
    # the result.
    from numba.core import typed_passes, compiler
    from numba.core.inline_closurecall import InlineWorker

    fcomp = dispatcher._compiler
    # NOTE: target options are deliberately not parsed into these flags
    # (fcomp.targetdescr.options.parse_as_flags / fcomp._customize_flags);
    # the original author noted that doing so caused problems.
    flags = compiler.Flags()

    # spoof a compiler pipeline like the one that will be in use
    typing_ctx = fcomp.targetdescr.typing_context
    target_ctx = fcomp.targetdescr.target_context
    pipeline = fcomp.pipeline_class(
        typing_ctx, target_ctx, None, None, None, flags, None
    )
    inline_worker = InlineWorker(
        typing_ctx, target_ctx, fcomp.locals, pipeline, flags, None
    )

    # If the inlinee contains something that triggers literal arg dispatch,
    # the pipeline call would unconditionally fail with a ForceLiteralArg
    # exception. Resolving via the dispatcher first hits any `literally`
    # calls and lets ForceLiteralArg propagate correctly, so the pipeline
    # is only run in situations that will succeed. For context see #5887.
    # NOTE: this rebinds `kws`, which the cost-model path below relies on.
    _template, _pysig, folded_args, kws = (
        dispatcher_type.dispatcher.get_call_template(resolved_args, kws)
    )
    func_ir = inline_worker.run_untyped_passes(
        dispatcher_type.dispatcher.py_func, enable_ssa=True
    )
    typemap, return_type, calltypes, _ = typed_passes.type_inference_stage(
        self.context, target_ctx, func_ir, folded_args, None
    )
    func_ir = PreLowerStripPhis()._strip_phi_nodes(func_ir)
    func_ir._definitions = ir_utils.build_definitions(func_ir.blocks)

    sig = Signature(return_type, folded_args, None)
    # Inliner info for the cost model; by default only the folded args.
    self._inline_overloads[sig.args] = {"folded_args": folded_args}
    # The key must exist for type resolution even when the function is
    # always inlined and no compiled overload is ever produced.
    # NOTE: If lowering is failing on a `_EmptyImplementationEntry`,
    # the inliner has failed to inline this entry correctly.
    self._compiled_overloads[sig.args] = _EmptyImplementationEntry(
        "always inlined"
    )
    if self._inline.is_always_inline:
        return sig

    # A user-supplied function decides whether to inline, so both the
    # compiled function and the inliner info are needed now; delaying the
    # computation leads to an internal state mess at present. TODO: Fix!
    sig = dispatcher_type.get_call_type(self.context, resolved_args, kws)
    self._compiled_overloads[sig.args] = dispatcher_type.get_overload(sig)
    # inliner information used later in the cost model function call
    self._inline_overloads[sig.args] = {
        "folded_args": folded_args,
        "iinfo": _inline_info(func_ir, typemap, calltypes, sig),
    }
    return sig
744
+
745
def _get_impl(self, args, kws):
    """
    Return ``(dispatcher, arg_types)`` for the given argument types.

    Results are memoized in ``self._impl_cache``. The cache key is the
    typing context, the argument types, the keyword items and the active
    target-configuration flags. On a miss, ``self._build_impl`` builds and
    caches the entry.
    """
    config_flags = targetconfig.ConfigStack.top_or_none()
    cache_key = (self.context, tuple(args), tuple(kws.items()), config_flags)
    try:
        cached_impl, cached_args = self._impl_cache[cache_key]
    except KeyError:
        # Build outside the except clause, so that a failure inside
        # _build_impl is not reported as raised while handling a KeyError.
        pass
    else:
        return cached_impl, cached_args
    built_impl, built_args = self._build_impl(cache_key, args, kws)
    return built_impl, built_args
762
+
763
def _get_jit_decorator(self):
    """
    Return a jit decorator suitable for the current target.

    The target string ``self.metadata["target"]`` (default ``"generic"``) is
    first looked up in the jit registry. If that finds nothing, the string
    is resolved to a registered target class. The jit decorator registered
    for the local target (``get_local_target(self.context)``) is then used.

    Raises ValueError if the target string is not registered, or if no jit
    decorator is found.
    """
    from numba.core.target_extension import (
        target_registry,
        get_local_target,
        jit_registry,
    )

    jitter_str = self.metadata.get("target", "generic")
    jitter = jit_registry.get(jitter_str, None)

    if jitter is None:
        # No JIT known for target string, see if something is
        # registered for the string and report if not.
        target_class = target_registry.get(jitter_str, None)
        if target_class is None:
            # A single string: with a stray comma this was a tuple, and
            # `.format` raised AttributeError instead of this ValueError.
            msg = "Unknown target '{}', has it been registered?"
            raise ValueError(msg.format(jitter_str))

        target_hw = get_local_target(self.context)

        # NOTE(review): the original code built a "No overloads exist for
        # the requested target" message when target_hw is not a subclass of
        # target_class, but never raised it. The existing fallback (use the
        # local target's jit) is kept; confirm whether that case should be
        # an error instead.
        jitter = jit_registry[target_hw]

    if jitter is None:
        raise ValueError("Cannot find a suitable jit decorator")

    return jitter
796
+
797
def _build_impl(self, cache_key, args, kws):
    """Create, cache and return the dispatcher implementing this overload.

    The ``@overload`` function is called with the argument types. The
    Python function it returns is wrapped in a jit dispatcher for the
    current target, and that dispatcher is typed once with the argument
    types so that the implementation is known to compile.

    Parameters
    ----------
    cache_key : hashable
        Key under which the result is stored in ``self._impl_cache``.
    args : Tuple[Type]
        Types of the positional arguments.
    kws : Dict[str, Type]
        Types of the keyword arguments.

    Returns
    -------
    disp, args :
        ``(Dispatcher, Tuple[Type])`` on success, or ``(None, None)`` when
        the ``@overload`` function returns ``None``. The returned argument
        types differ from *args* only when the ``@overload`` function
        returned its own ``(signature, impl)`` pair, in which case the
        signature's argument types are used and nothing is cached.

    Raises
    ------
    TypingError
        If the argument types cannot be bound to the ``@overload``
        function's Python signature.
    AssertionError
        If the implementation returned is not a plain Python function.
    """
    jit_decorator = self._get_jit_decorator()

    # Check the call shape against the @overload function's signature
    # first, so a mismatch surfaces as a TypingError.
    overload_sig = inspect.signature(self._overload_func)
    try:
        overload_sig.bind(*args, **kws)
    except TypeError as exc:
        raise TypingError(str(exc)) from exc

    result = self._overload_func(*args, **kws)

    if result is None:
        # No implementation for these types: remember the failure.
        self._impl_cache[cache_key] = None, None
        return None, None

    if isinstance(result, tuple):
        # The @overload function dictated the signature that type
        # inference must use; this result is deliberately not cached.
        sig, impl_func = result
        args, kws, cache_key = sig.args, {}, None
    else:
        impl_func = result

    if not isinstance(impl_func, FunctionType):
        raise AssertionError(
            "Implementation function returned by `@overload` "
            "has an unexpected type. Got {}".format(impl_func)
        )

    # Check that the typing and implementation signatures agree.
    if self._strict:
        self._validate_sigs(self._overload_func, impl_func)

    dispatcher = jit_decorator(**self._jit_options)(impl_func)
    # Type the dispatcher once so any compilation problem surfaces here.
    types.Dispatcher(dispatcher).get_call_type(self.context, args, kws)

    if cache_key is not None:
        self._impl_cache[cache_key] = dispatcher, args
    return dispatcher, args
871
+
872
def get_impl_key(self, sig):
    """
    Return the key used to look up the implementation of *sig* on the
    target context: the overload recorded for ``sig.args`` in
    ``self._compiled_overloads``.
    """
    compiled = self._compiled_overloads
    return compiled[sig.args]
878
+
879
@classmethod
def get_source_info(cls):
    """Return a dictionary with information about the source code of the
    implementation.

    Returns
    -------
    info : dict
        - "kind" : str
            The implementation kind (always ``"overload"`` here).
        - "name" : str
            The qualified name (falling back to ``__name__``) of the
            function that provided the definition.
        - "sig" : str
            The formatted signature of the function.
        - "filename" : str
            The name of the source file, relative to the directory that
            contains the ``numba`` package.
        - "lines": tuple (int, int)
            First and last line number.
        - "docstring": str
            The docstring of the definition.
    """
    basepath = os.path.dirname(os.path.dirname(numba.__file__))
    impl = cls._overload_func
    code, firstlineno, path = cls.get_source_code_info(impl)
    sig = str(utils.pysignature(impl))
    info = {
        "kind": "overload",
        "name": getattr(impl, "__qualname__", impl.__name__),
        "sig": sig,
        "filename": utils.safe_relpath(path, start=basepath),
        "lines": (firstlineno, firstlineno + len(code) - 1),
        "docstring": impl.__doc__,
    }
    return info
913
+
914
def get_template_info(self):
    """Describe the ``@overload`` function backing this template.

    Returns a dict with the keys ``kind`` (always ``"overload"``),
    ``name``, ``sig``, ``filename`` (relative to the directory containing
    the ``numba`` package), ``lines`` (first and last line) and
    ``docstring``.
    """
    package_root = os.path.dirname(os.path.dirname(numba.__file__))
    func = self._overload_func
    source_lines, start, source_path = self.get_source_code_info(func)
    formatted_sig = str(utils.pysignature(func))
    last = start + len(source_lines) - 1
    return {
        "kind": "overload",
        "name": getattr(func, "__qualname__", func.__name__),
        "sig": formatted_sig,
        "filename": utils.safe_relpath(source_path, start=package_root),
        "lines": (start, last),
        "docstring": func.__doc__,
    }
928
+
929
+
930
def make_overload_template(
    func,
    overload_func,
    jit_options,
    strict,
    inline,
    prefer_literal=False,
    **kwargs,
):
    """
    Create a ``_OverloadFunctionTemplate`` subclass that types calls to
    *func* through the ``@overload`` implementation *overload_func*.

    *jit_options* are passed to the jit decorator when the implementation
    is compiled. *strict* enables signature validation of the
    implementation. *inline* is wrapped in ``InlineOptions``. Any extra
    keyword arguments (such as ``target``) become the template's
    ``metadata``. Every generated class gets its own, empty caches.
    """
    func_name = getattr(func, "__name__", str(func))
    class_name = "OverloadTemplate_%s" % (func_name,)
    base = _OverloadFunctionTemplate
    namespace = {
        "key": func,
        "_overload_func": staticmethod(overload_func),
        "_impl_cache": {},
        "_compiled_overloads": {},
        "_jit_options": jit_options,
        "_strict": strict,
        "_inline": staticmethod(InlineOptions(inline)),
        "_inline_overloads": {},
        "prefer_literal": prefer_literal,
        "metadata": kwargs,
    }
    metaclass = type(base)
    return metaclass(class_name, (base,), namespace)
959
+
960
+
961
class _TemplateTargetHelperMixin(object):
    """Mixin for helper methods that assist with target/registry resolution"""

    def _get_target_registry(self, reason):
        """Return the lowering registry to install implementations into.

        Parameters
        ----------
        reason: str
            Reason for the resolution, as a noun (e.g. ``"intrinsic"``).
            It is forwarded to ``_get_local_target_checked``; presumably
            used in its error message (verify in target_extension).

        Returns
        -------
        reg : the first registry of the target context belonging to the
            local target resolved for this template's ``target`` metadata.
        """
        from numba.core.target_extension import (
            _get_local_target_checked,
            dispatcher_registry,
        )

        requested = self.metadata.get("target", "generic")
        local_target = _get_local_target_checked(
            self.context, requested, reason
        )
        target_ctx = dispatcher_registry[local_target].targetdescr.target_context

        # XXX: Upstream Numba preferred its builtin registry here when the
        # target had installed it. The CUDA target "borrows" CPU
        # implementations from specific registries, so an @intrinsic impl
        # shared by both must live in a registry CUDA borrows from. That
        # builtin registry was removed from this module on purpose: it would
        # duplicate upstream's and invite mixing the two up. Only upstream's
        # fallback remains, which treats any of the target's registries as
        # an acceptable home. Upstream noted the preference itself should
        # disappear once targets rely solely on the extension APIs.
        #
        # XXX: The refresh is also inherited from upstream. It handles the
        # case where the target has swapped (e.g. CUDA borrowing from the
        # CPU), so that pending registrations are populated. It may become
        # unnecessary once targets have fewer registries.
        target_ctx.refresh()

        # Pick the first registry of the target context.
        return next(iter(target_ctx._registries))
1040
+
1041
+
1042
class _IntrinsicTemplate(_TemplateTargetHelperMixin, AbstractTemplate):
    """
    Base class for the typing templates generated for ``@intrinsic``
    definitions.
    """

    def generic(self, args, kws):
        """
        Type the intrinsic for the given argument types.

        The definition function is called as
        ``defn(context, *args, **kws)``. It returns either ``None`` (no
        typing, so this returns ``None``) or a ``(signature,
        implementation)`` pair. On success:

        - the signature, with the leading ``context`` parameter removed
          from its pysig, is cached;
        - the implementation is recorded for ``get_impl_key``;
        - the implementation is registered as the lowering for the
          signature's argument types.
        """
        lower = self._get_target_registry("intrinsic").lower
        key = self.context, args, tuple(kws.items())
        try:
            return self._impl_cache[key]
        except KeyError:
            pass

        outcome = self._definition_func(self.context, *args, **kws)
        if outcome is None:
            return None
        sig, lowering_impl = outcome

        definition_pysig = utils.pysignature(self._definition_func)
        # Hide the typing-context parameter: users never pass it.
        user_params = list(definition_pysig.parameters.values())[1:]
        sig = sig.replace(pysig=definition_pysig.replace(parameters=user_params))

        self._impl_cache[key] = sig
        self._overload_cache[sig.args] = lowering_impl
        # Register the lowering, keyed on the implementation itself.
        lower(lowering_impl, *sig.args)(lowering_impl)
        return sig

    def get_impl_key(self, sig):
        """
        Return the implementation recorded for ``sig.args``, which is the
        key used to look up the lowering on the target context.
        """
        recorded = self._overload_cache
        return recorded[sig.args]

    def get_template_info(self):
        package_root = os.path.dirname(os.path.dirname(numba.__file__))
        func = self._definition_func
        source_lines, start, source_path = self.get_source_code_info(func)
        formatted_sig = str(utils.pysignature(func))
        return {
            "kind": "intrinsic",
            "name": getattr(func, "__qualname__", func.__name__),
            "sig": formatted_sig,
            "filename": utils.safe_relpath(source_path, start=package_root),
            "lines": (start, start + len(source_lines) - 1),
            "docstring": func.__doc__,
        }
1092
+
1093
+
1094
def make_intrinsic_template(
    handle, defn, name, *, prefer_literal=False, kwargs=None
):
    """
    Make a template class for an intrinsic handle *handle* defined by the
    function *defn*. The *name* is used for naming the new template class.

    Parameters
    ----------
    handle :
        The typing key of the generated template.
    defn : callable
        The intrinsic definition, called as ``defn(context, *args, **kws)``.
    name :
        Suffix of the generated class name.
    prefer_literal : bool, optional
        Stored as the template's ``prefer_literal`` attribute.
    kwargs : mapping, optional
        Template metadata (e.g. ``target``). It is exposed read-only via
        ``MappingProxyType``.
    """
    kwargs = MappingProxyType({} if kwargs is None else kwargs)
    base = _IntrinsicTemplate
    # Use an explicit 1-tuple so the formatting also works when *name*
    # is itself a tuple (``(name)`` is just ``name``).
    name = "_IntrinsicTemplate_%s" % (name,)
    dct = dict(
        key=handle,
        _definition_func=staticmethod(defn),
        _impl_cache={},
        _overload_cache={},
        prefer_literal=prefer_literal,
        metadata=kwargs,
    )
    return type(base)(name, (base,), dct)
1113
+
1114
+
1115
class AttributeTemplate(object):
    """Base class for attribute typing templates.

    ``resolve(value, attr)`` dispatches as follows:

    - to ``resolve_<attr>(value)`` when the subclass defines one;
    - otherwise to ``generic_resolve(value, attr)`` when that is
      overridden;
    - otherwise attributes of ``types.Module`` values are resolved as
      module constants through the context, and anything else resolves
      to ``None``.
    """

    # Subclasses may replace this with a method generic_resolve(value, attr).
    generic_resolve = NotImplemented

    def __init__(self, context):
        self.context = context

    def resolve(self, value, attr):
        return self._resolve(value, attr)

    def _resolve(self, value, attr):
        specific = getattr(self, "resolve_%s" % attr, None)
        if specific is not None:
            return specific(value)
        fallback = self.generic_resolve
        if fallback is not NotImplemented:
            return fallback(value, attr)
        if isinstance(value, types.Module):
            return self.context.resolve_module_constants(value, attr)
        return None
1137
+
1138
+
1139
class _OverloadAttributeTemplate(_TemplateTargetHelperMixin, AttributeTemplate):
    """
    A base class of templates for @overload_attribute functions.

    Typing of the attribute ``_attr`` on ``key`` is delegated to the
    function type of ``_overload_func``. A matching getattr lowering is
    registered in the target's registry when the template is instantiated.
    """

    is_method = False

    def __init__(self, context):
        # AttributeTemplate.__init__ stores ``context`` on the instance.
        super(_OverloadAttributeTemplate, self).__init__(context)
        self._init_once()

    def _init_once(self):
        # NOTE(review): despite its name this runs on every instantiation
        # (it is called from __init__), registering the lowering each time.
        cls = type(self)
        attr = cls._attr

        lower_getattr = self._get_target_registry("attribute").lower_getattr

        @lower_getattr(cls.key, attr)
        def getattr_impl(context, builder, typ, value):
            typingctx = context.typing_context
            fnty = cls._get_function_type(typingctx, typ)
            sig = cls._get_signature(typingctx, fnty, (typ,), {})
            call = context.get_function(fnty, sig)
            return call(builder, (value,))

    def _resolve(self, typ, attr):
        """Return the attribute's type, or None if it cannot be resolved."""
        if self._attr != attr:
            return None
        fnty = self._get_function_type(self.context, typ)
        sig = self._get_signature(self.context, fnty, (typ,), {})
        # There should only be one template
        for template in fnty.templates:
            self._inline_overloads.update(template._inline_overloads)
        if sig is None:
            return None
        return sig.return_type

    @classmethod
    def _get_signature(cls, typingctx, fnty, args, kws):
        """Type a call of *fnty*, attaching the overload function's pysig.

        Returns None when ``get_call_type`` yields no signature, instead
        of failing with AttributeError on ``None.replace``.
        """
        sig = fnty.get_call_type(typingctx, args, kws)
        if sig is None:
            return None
        return sig.replace(pysig=utils.pysignature(cls._overload_func))

    @classmethod
    def _get_function_type(cls, typingctx, typ):
        return typingctx.resolve_value_type(cls._overload_func)
1184
+
1185
+
1186
class _OverloadMethodTemplate(_OverloadAttributeTemplate):
    """
    A base class of templates for @overload_method functions.

    Resolving the method attribute yields a ``types.BoundFunction`` whose
    template types calls through the ``@overload_method`` function, with
    the receiver prepended to the arguments.
    """

    is_method = True

    def _init_once(self):
        """
        Overriding parent definition: register a lowering for the method
        key ``(self.key, attr)`` that accepts the receiver type followed by
        any arguments.
        """
        attr = self._attr

        registry = self._get_target_registry("method")

        @registry.lower((self.key, attr), self.key, types.VarArg(types.Any))
        def method_impl(context, builder, sig, args):
            typ = sig.args[0]
            typing_context = context.typing_context
            fnty = self._get_function_type(typing_context, typ)
            sig = self._get_signature(typing_context, fnty, sig.args, {})
            call = context.get_function(fnty, sig)
            # Link dependent library
            context.add_linking_libs(getattr(call, "libs", ()))
            return call(builder, args)

    def _resolve(self, typ, attr):
        if self._attr != attr:
            return None

        # Type refs and callables must be the exact key; other types must
        # be instances of the key class.
        if isinstance(typ, (types.TypeRef, types.Callable)):
            assert typ == self.key
        else:
            assert isinstance(typ, self.key)

        class MethodTemplate(AbstractTemplate):
            key = (self.key, attr)
            _inline = self._inline
            _overload_func = staticmethod(self._overload_func)
            _inline_overloads = self._inline_overloads
            prefer_literal = self.prefer_literal

            # ``_`` is the MethodTemplate instance; ``self`` is the enclosing
            # attribute template captured by the closure.
            def generic(_, args, kws):
                args = (typ,) + tuple(args)
                fnty = self._get_function_type(self.context, typ)
                # _get_signature already attaches the overload's pysig.
                sig = self._get_signature(self.context, fnty, args, kws)
                for template in fnty.templates:
                    self._inline_overloads.update(template._inline_overloads)
                # Check before use: previously ``sig.replace`` ran first, so
                # this guard could never take effect.
                if sig is None:
                    return None
                return sig.as_method()

            def get_template_info(self):
                basepath = os.path.dirname(os.path.dirname(numba.__file__))
                impl = self._overload_func
                code, firstlineno, path = self.get_source_code_info(impl)
                sig = str(utils.pysignature(impl))
                info = {
                    "kind": "overload_method",
                    "name": getattr(impl, "__qualname__", impl.__name__),
                    "sig": sig,
                    "filename": utils.safe_relpath(path, start=basepath),
                    "lines": (firstlineno, firstlineno + len(code) - 1),
                    "docstring": impl.__doc__,
                }

                return info

        return types.BoundFunction(MethodTemplate, typ)
1257
+
1258
+
1259
def make_overload_attribute_template(
    typ,
    attr,
    overload_func,
    inline="never",
    prefer_literal=False,
    base=_OverloadAttributeTemplate,
    **kwargs,
):
    """
    Make a template class for attribute *attr* of *typ* overloaded by
    *overload_func*.

    *typ* may be a Numba type instance or a Numba type class. *base* is
    the template base class (``_OverloadMethodTemplate`` for methods).
    Extra keyword arguments become the template's ``metadata``.
    """
    # Guard issubclass with a class check so that a non-Type instance
    # fails this assertion rather than raising TypeError from issubclass.
    assert isinstance(typ, types.Type) or (
        isinstance(typ, type) and issubclass(typ, types.Type)
    )
    name = "OverloadAttributeTemplate_%s_%s" % (typ, attr)
    # Note the implementation cache is subclass-specific
    dct = dict(
        key=typ,
        _attr=attr,
        _impl_cache={},
        _inline=staticmethod(InlineOptions(inline)),
        _inline_overloads={},
        _overload_func=staticmethod(overload_func),
        prefer_literal=prefer_literal,
        metadata=kwargs,
    )
    obj = type(base)(name, (base,), dct)
    return obj
1287
+
1288
+
1289
def make_overload_method_template(
    typ, attr, overload_func, inline, prefer_literal=False, **kwargs
):
    """
    Build the template class for method *attr* of *typ*, implemented by
    *overload_func*.

    This is ``make_overload_attribute_template`` with
    ``_OverloadMethodTemplate`` as the base class. Extra keyword
    arguments are forwarded and end up as the template's metadata.
    """
    method_options = dict(
        inline=inline,
        base=_OverloadMethodTemplate,
        prefer_literal=prefer_literal,
    )
    return make_overload_attribute_template(
        typ, attr, overload_func, **method_options, **kwargs
    )
1305
+
1306
+
1307
def bound_function(template_key):
    """
    Decorate an AttributeTemplate ``resolve_*`` method so that it resolves
    an instance method's signature rather than an instance attribute.

    The decorated method is called as ``method(self, ty, args, kws)`` and
    must return the resolved method's signature for the given receiver
    type, argument types and keyword types. A returned signature without
    a receiver gets *ty* as its receiver.

    Usage:

        class ComplexAttributes(AttributeTemplate):
            @bound_function("complex.conjugate")
            def resolve_conjugate(self, ty, args, kwds):
                return ty

    *template_key* (e.g. "complex.conjugate" above) will be used by the
    target to look up the method's implementation, as a regular function.
    """

    def decorator(method_resolver):
        @functools.wraps(method_resolver)
        def attribute_resolver(self, ty):
            owner = self

            class MethodTemplate(AbstractTemplate):
                key = template_key

                def generic(_, args, kws):
                    sig = method_resolver(owner, ty, args, kws)
                    if sig is None or sig.recvr is not None:
                        return sig
                    return sig.replace(recvr=ty)

            return types.BoundFunction(MethodTemplate, ty)

        return attribute_resolver

    return decorator
1342
+
1343
+
1344
+ # -----------------------------
1345
+
1346
+
1347
class Registry(object):
    """
    A registry of typing declarations. The registry stores such declarations
    for functions, attributes and globals.
    """

    def __init__(self):
        self.functions = []
        self.attributes = []
        self.globals = []

    def register(self, item):
        """Register a function template class (this module's or Numba's)."""
        assert issubclass(
            item,
            (FunctionTemplate, numba.core.typing.templates.FunctionTemplate),
        )
        self.functions.append(item)
        return item

    def register_attr(self, item):
        """Register an attribute template class (this module's or Numba's)."""
        assert issubclass(
            item,
            (AttributeTemplate, numba.core.typing.templates.AttributeTemplate),
        )
        self.attributes.append(item)
        return item

    def register_global(self, val=None, typ=None, **kwargs):
        """
        Register the typing of a global value.
        Functional usage with a Numba type::
            register_global(value, typ)

        Decorator usage with a template class::
            @register_global(value, typing_key=None)
            class Template: ...

        In decorator form the value must be callable; otherwise a
        ``TypeError`` is raised when the class is decorated. When no
        ``typing_key`` is given, the value itself is the key and must be
        reachable as ``<module>.<name>`` (``ValueError`` otherwise).
        """
        if typ is not None:
            # register_global(val, typ)
            assert val is not None
            assert not kwargs
            self.globals.append((val, typ))
        else:

            def decorate(cls, typing_key):
                class Template(cls):
                    key = typing_key

                if callable(val):
                    typ = types.Function(Template)
                else:
                    # Interpolate the value: the message previously showed
                    # a literal "%r" because the format was never applied.
                    raise TypeError(
                        "cannot infer type for global value %r" % (val,)
                    )
                self.globals.append((val, typ))
                return cls

            # register_global(val, typing_key=None)(<template class>)
            assert val is not None
            typing_key = kwargs.pop("typing_key", val)
            assert not kwargs
            if typing_key is val:
                # Check the value is globally reachable, as it is going
                # to be used as the key.
                mod = sys.modules[val.__module__]
                if getattr(mod, val.__name__) is not val:
                    raise ValueError(
                        "%r is not globally reachable as '%s.%s'"
                        % (mod, val.__module__, val.__name__)
                    )

            def decorator(cls):
                return decorate(cls, typing_key)

            return decorator
1420
+
1421
+
1422
class BaseRegistryLoader(object):
    """
    Incrementally load the contents of a registry.

    Every call to ``new_registrations(name)`` yields only the entries of
    the registry's ``name`` list that were added since the previous call
    for that name. This is needed because:

    - there can be several contexts;
    - each context wants to install all registrations;
    - registrations can be added after the first installation, so a
      context must be able to pick up just the new ones.

    Each context therefore keeps its own loader per registry, without
    duplicating the registries themselves. Subclasses define
    ``registry_items``, the names of the registry lists to follow.
    """

    def __init__(self, registry):
        self._registrations = {
            item_name: utils.stream_list(getattr(registry, item_name))
            for item_name in self.registry_items
        }

    def new_registrations(self, name):
        yield from next(self._registrations[name])
1446
+
1447
+
1448
+ class RegistryLoader(BaseRegistryLoader):
1449
+ """
1450
+ An incremental loader for a typing registry.
1451
+ """
1452
+
1453
+ registry_items = ("functions", "attributes", "globals")