numba-cuda 0.19.1__py3-none-any.whl → 0.20.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of numba-cuda might be problematic; see the registry's advisory page for this release for more details.

Files changed (172)
  1. numba_cuda/VERSION +1 -1
  2. numba_cuda/numba/cuda/__init__.py +1 -1
  3. numba_cuda/numba/cuda/_internal/cuda_bf16.py +12706 -1470
  4. numba_cuda/numba/cuda/_internal/cuda_fp16.py +2653 -8769
  5. numba_cuda/numba/cuda/api.py +6 -1
  6. numba_cuda/numba/cuda/bf16.py +285 -2
  7. numba_cuda/numba/cuda/cgutils.py +2 -2
  8. numba_cuda/numba/cuda/cloudpickle/__init__.py +21 -0
  9. numba_cuda/numba/cuda/cloudpickle/cloudpickle.py +1598 -0
  10. numba_cuda/numba/cuda/cloudpickle/cloudpickle_fast.py +17 -0
  11. numba_cuda/numba/cuda/codegen.py +1 -1
  12. numba_cuda/numba/cuda/compiler.py +373 -30
  13. numba_cuda/numba/cuda/core/analysis.py +319 -0
  14. numba_cuda/numba/cuda/core/annotations/__init__.py +0 -0
  15. numba_cuda/numba/cuda/core/annotations/type_annotations.py +304 -0
  16. numba_cuda/numba/cuda/core/base.py +1289 -0
  17. numba_cuda/numba/cuda/core/bytecode.py +727 -0
  18. numba_cuda/numba/cuda/core/caching.py +2 -2
  19. numba_cuda/numba/cuda/core/compiler.py +6 -14
  20. numba_cuda/numba/cuda/core/compiler_machinery.py +497 -0
  21. numba_cuda/numba/cuda/core/config.py +747 -0
  22. numba_cuda/numba/cuda/core/consts.py +124 -0
  23. numba_cuda/numba/cuda/core/cpu.py +370 -0
  24. numba_cuda/numba/cuda/core/environment.py +68 -0
  25. numba_cuda/numba/cuda/core/event.py +511 -0
  26. numba_cuda/numba/cuda/core/funcdesc.py +330 -0
  27. numba_cuda/numba/cuda/core/inline_closurecall.py +1889 -0
  28. numba_cuda/numba/cuda/core/interpreter.py +48 -26
  29. numba_cuda/numba/cuda/core/ir_utils.py +15 -26
  30. numba_cuda/numba/cuda/core/options.py +262 -0
  31. numba_cuda/numba/cuda/core/postproc.py +249 -0
  32. numba_cuda/numba/cuda/core/pythonapi.py +1868 -0
  33. numba_cuda/numba/cuda/core/rewrites/__init__.py +26 -0
  34. numba_cuda/numba/cuda/core/rewrites/ir_print.py +90 -0
  35. numba_cuda/numba/cuda/core/rewrites/registry.py +104 -0
  36. numba_cuda/numba/cuda/core/rewrites/static_binop.py +40 -0
  37. numba_cuda/numba/cuda/core/rewrites/static_getitem.py +187 -0
  38. numba_cuda/numba/cuda/core/rewrites/static_raise.py +98 -0
  39. numba_cuda/numba/cuda/core/ssa.py +496 -0
  40. numba_cuda/numba/cuda/core/targetconfig.py +329 -0
  41. numba_cuda/numba/cuda/core/tracing.py +231 -0
  42. numba_cuda/numba/cuda/core/transforms.py +952 -0
  43. numba_cuda/numba/cuda/core/typed_passes.py +738 -7
  44. numba_cuda/numba/cuda/core/typeinfer.py +1948 -0
  45. numba_cuda/numba/cuda/core/unsafe/__init__.py +0 -0
  46. numba_cuda/numba/cuda/core/unsafe/bytes.py +67 -0
  47. numba_cuda/numba/cuda/core/unsafe/eh.py +66 -0
  48. numba_cuda/numba/cuda/core/unsafe/refcount.py +98 -0
  49. numba_cuda/numba/cuda/core/untyped_passes.py +1983 -0
  50. numba_cuda/numba/cuda/cpython/cmathimpl.py +560 -0
  51. numba_cuda/numba/cuda/cpython/mathimpl.py +499 -0
  52. numba_cuda/numba/cuda/cpython/numbers.py +1474 -0
  53. numba_cuda/numba/cuda/cuda_paths.py +422 -246
  54. numba_cuda/numba/cuda/cudadecl.py +1 -1
  55. numba_cuda/numba/cuda/cudadrv/__init__.py +1 -1
  56. numba_cuda/numba/cuda/cudadrv/devicearray.py +2 -1
  57. numba_cuda/numba/cuda/cudadrv/driver.py +11 -140
  58. numba_cuda/numba/cuda/cudadrv/dummyarray.py +111 -24
  59. numba_cuda/numba/cuda/cudadrv/libs.py +5 -5
  60. numba_cuda/numba/cuda/cudadrv/mappings.py +1 -1
  61. numba_cuda/numba/cuda/cudadrv/nvrtc.py +19 -8
  62. numba_cuda/numba/cuda/cudadrv/nvvm.py +1 -4
  63. numba_cuda/numba/cuda/cudadrv/runtime.py +1 -1
  64. numba_cuda/numba/cuda/cudaimpl.py +5 -1
  65. numba_cuda/numba/cuda/debuginfo.py +85 -2
  66. numba_cuda/numba/cuda/decorators.py +3 -3
  67. numba_cuda/numba/cuda/descriptor.py +3 -4
  68. numba_cuda/numba/cuda/deviceufunc.py +66 -2
  69. numba_cuda/numba/cuda/dispatcher.py +18 -39
  70. numba_cuda/numba/cuda/flags.py +141 -1
  71. numba_cuda/numba/cuda/fp16.py +0 -2
  72. numba_cuda/numba/cuda/include/13/cuda_bf16.h +5118 -0
  73. numba_cuda/numba/cuda/include/13/cuda_bf16.hpp +3865 -0
  74. numba_cuda/numba/cuda/include/13/cuda_fp16.h +5363 -0
  75. numba_cuda/numba/cuda/include/13/cuda_fp16.hpp +3483 -0
  76. numba_cuda/numba/cuda/lowering.py +7 -144
  77. numba_cuda/numba/cuda/mathimpl.py +2 -1
  78. numba_cuda/numba/cuda/memory_management/nrt.py +43 -17
  79. numba_cuda/numba/cuda/misc/findlib.py +75 -0
  80. numba_cuda/numba/cuda/models.py +9 -1
  81. numba_cuda/numba/cuda/np/npdatetime_helpers.py +217 -0
  82. numba_cuda/numba/cuda/np/npyfuncs.py +1807 -0
  83. numba_cuda/numba/cuda/np/numpy_support.py +553 -0
  84. numba_cuda/numba/cuda/np/ufunc/ufuncbuilder.py +59 -0
  85. numba_cuda/numba/cuda/nvvmutils.py +1 -1
  86. numba_cuda/numba/cuda/printimpl.py +12 -1
  87. numba_cuda/numba/cuda/random.py +1 -1
  88. numba_cuda/numba/cuda/serialize.py +1 -1
  89. numba_cuda/numba/cuda/simulator/__init__.py +1 -1
  90. numba_cuda/numba/cuda/simulator/api.py +1 -1
  91. numba_cuda/numba/cuda/simulator/compiler.py +4 -0
  92. numba_cuda/numba/cuda/simulator/cudadrv/devicearray.py +1 -1
  93. numba_cuda/numba/cuda/simulator/kernelapi.py +1 -1
  94. numba_cuda/numba/cuda/simulator/memory_management/nrt.py +14 -2
  95. numba_cuda/numba/cuda/target.py +35 -17
  96. numba_cuda/numba/cuda/testing.py +7 -19
  97. numba_cuda/numba/cuda/tests/__init__.py +1 -1
  98. numba_cuda/numba/cuda/tests/cloudpickle_main_class.py +9 -0
  99. numba_cuda/numba/cuda/tests/core/test_serialize.py +4 -4
  100. numba_cuda/numba/cuda/tests/cudadrv/test_cuda_devicerecord.py +1 -1
  101. numba_cuda/numba/cuda/tests/cudadrv/test_cuda_libraries.py +1 -1
  102. numba_cuda/numba/cuda/tests/cudadrv/test_deallocations.py +1 -1
  103. numba_cuda/numba/cuda/tests/cudadrv/test_detect.py +6 -3
  104. numba_cuda/numba/cuda/tests/cudadrv/test_emm_plugins.py +1 -1
  105. numba_cuda/numba/cuda/tests/cudadrv/test_linker.py +18 -2
  106. numba_cuda/numba/cuda/tests/cudadrv/test_module_callbacks.py +2 -1
  107. numba_cuda/numba/cuda/tests/cudadrv/test_nvjitlink.py +1 -1
  108. numba_cuda/numba/cuda/tests/cudadrv/test_ptds.py +1 -1
  109. numba_cuda/numba/cuda/tests/cudapy/extensions_usecases.py +1 -1
  110. numba_cuda/numba/cuda/tests/cudapy/test_array.py +2 -1
  111. numba_cuda/numba/cuda/tests/cudapy/test_atomics.py +1 -1
  112. numba_cuda/numba/cuda/tests/cudapy/test_bfloat16.py +539 -2
  113. numba_cuda/numba/cuda/tests/cudapy/test_bfloat16_bindings.py +81 -1
  114. numba_cuda/numba/cuda/tests/cudapy/test_caching.py +1 -3
  115. numba_cuda/numba/cuda/tests/cudapy/test_complex.py +1 -1
  116. numba_cuda/numba/cuda/tests/cudapy/test_constmem.py +1 -1
  117. numba_cuda/numba/cuda/tests/cudapy/test_cooperative_groups.py +2 -3
  118. numba_cuda/numba/cuda/tests/cudapy/test_copy_propagate.py +130 -0
  119. numba_cuda/numba/cuda/tests/cudapy/test_datetime.py +1 -1
  120. numba_cuda/numba/cuda/tests/cudapy/test_debug.py +1 -1
  121. numba_cuda/numba/cuda/tests/cudapy/test_debuginfo.py +293 -4
  122. numba_cuda/numba/cuda/tests/cudapy/test_debuginfo_types.py +1 -1
  123. numba_cuda/numba/cuda/tests/cudapy/test_dispatcher.py +1 -1
  124. numba_cuda/numba/cuda/tests/cudapy/test_errors.py +1 -1
  125. numba_cuda/numba/cuda/tests/cudapy/test_exception.py +1 -1
  126. numba_cuda/numba/cuda/tests/cudapy/test_extending.py +2 -1
  127. numba_cuda/numba/cuda/tests/cudapy/test_inline.py +18 -8
  128. numba_cuda/numba/cuda/tests/cudapy/test_intrinsics.py +23 -21
  129. numba_cuda/numba/cuda/tests/cudapy/test_ir_utils.py +10 -37
  130. numba_cuda/numba/cuda/tests/cudapy/test_laplace.py +1 -1
  131. numba_cuda/numba/cuda/tests/cudapy/test_math.py +1 -1
  132. numba_cuda/numba/cuda/tests/cudapy/test_matmul.py +1 -1
  133. numba_cuda/numba/cuda/tests/cudapy/test_operator.py +1 -1
  134. numba_cuda/numba/cuda/tests/cudapy/test_print.py +20 -0
  135. numba_cuda/numba/cuda/tests/cudapy/test_record_dtype.py +1 -1
  136. numba_cuda/numba/cuda/tests/cudapy/test_reduction.py +1 -1
  137. numba_cuda/numba/cuda/tests/cudapy/test_serialize.py +1 -1
  138. numba_cuda/numba/cuda/tests/cudapy/test_sm.py +1 -1
  139. numba_cuda/numba/cuda/tests/cudapy/test_ssa.py +453 -0
  140. numba_cuda/numba/cuda/tests/cudapy/test_sync.py +1 -1
  141. numba_cuda/numba/cuda/tests/cudapy/test_typeinfer.py +538 -0
  142. numba_cuda/numba/cuda/tests/cudapy/test_ufuncs.py +263 -2
  143. numba_cuda/numba/cuda/tests/cudapy/test_userexc.py +1 -1
  144. numba_cuda/numba/cuda/tests/cudapy/test_vector_type.py +1 -1
  145. numba_cuda/numba/cuda/tests/cudapy/test_vectorize_decor.py +112 -6
  146. numba_cuda/numba/cuda/tests/cudapy/test_warning.py +1 -1
  147. numba_cuda/numba/cuda/tests/cudapy/test_warp_ops.py +1 -1
  148. numba_cuda/numba/cuda/tests/doc_examples/test_cg.py +0 -2
  149. numba_cuda/numba/cuda/tests/doc_examples/test_ffi.py +3 -2
  150. numba_cuda/numba/cuda/tests/doc_examples/test_laplace.py +0 -2
  151. numba_cuda/numba/cuda/tests/doc_examples/test_sessionize.py +0 -2
  152. numba_cuda/numba/cuda/tests/nocuda/test_import.py +3 -1
  153. numba_cuda/numba/cuda/tests/nocuda/test_library_lookup.py +24 -12
  154. numba_cuda/numba/cuda/tests/nrt/test_nrt.py +2 -1
  155. numba_cuda/numba/cuda/tests/support.py +55 -15
  156. numba_cuda/numba/cuda/tests/test_tracing.py +200 -0
  157. numba_cuda/numba/cuda/types.py +56 -0
  158. numba_cuda/numba/cuda/typing/__init__.py +9 -1
  159. numba_cuda/numba/cuda/typing/cffi_utils.py +55 -0
  160. numba_cuda/numba/cuda/typing/context.py +751 -0
  161. numba_cuda/numba/cuda/typing/enumdecl.py +74 -0
  162. numba_cuda/numba/cuda/typing/npydecl.py +658 -0
  163. numba_cuda/numba/cuda/typing/templates.py +7 -6
  164. numba_cuda/numba/cuda/ufuncs.py +3 -3
  165. numba_cuda/numba/cuda/utils.py +6 -112
  166. {numba_cuda-0.19.1.dist-info → numba_cuda-0.20.1.dist-info}/METADATA +4 -3
  167. {numba_cuda-0.19.1.dist-info → numba_cuda-0.20.1.dist-info}/RECORD +171 -116
  168. numba_cuda/numba/cuda/tests/cudadrv/test_mvc.py +0 -60
  169. {numba_cuda-0.19.1.dist-info → numba_cuda-0.20.1.dist-info}/WHEEL +0 -0
  170. {numba_cuda-0.19.1.dist-info → numba_cuda-0.20.1.dist-info}/licenses/LICENSE +0 -0
  171. {numba_cuda-0.19.1.dist-info → numba_cuda-0.20.1.dist-info}/licenses/LICENSE.numba +0 -0
  172. {numba_cuda-0.19.1.dist-info → numba_cuda-0.20.1.dist-info}/top_level.txt +0 -0
@@ -2,11 +2,59 @@
2
2
  # SPDX-License-Identifier: BSD-2-Clause
3
3
 
4
4
  import abc
5
- import warnings
6
5
  from contextlib import contextmanager
7
- from numba.core import errors, types, funcdesc
8
- from numba.core.compiler_machinery import LoweringPass
9
- from llvmlite import binding as llvm
6
+ from collections import defaultdict, namedtuple
7
+ from copy import copy
8
+ import warnings
9
+
10
+ from numba.cuda.core import typeinfer
11
+ from numba.core import (
12
+ errors,
13
+ types,
14
+ typing,
15
+ ir,
16
+ lowering,
17
+ )
18
+
19
+ from numba.cuda.core.compiler_machinery import (
20
+ FunctionPass,
21
+ LoweringPass,
22
+ AnalysisPass,
23
+ register_pass,
24
+ )
25
+ from numba.cuda.core.annotations import type_annotations
26
+ from numba.cuda.core.ir_utils import (
27
+ raise_on_unsupported_feature,
28
+ warn_deprecated,
29
+ check_and_legalize_ir,
30
+ guard,
31
+ dead_code_elimination,
32
+ simplify_CFG,
33
+ get_definition,
34
+ compute_cfg_from_blocks,
35
+ is_operator_or_getitem,
36
+ )
37
+
38
+ from numba.cuda.core import postproc, rewrites, funcdesc, config
39
+
40
+
41
+ try:
42
+ # llvmlite < 0.45
43
+ from llvmlite.binding import passmanagers
44
+ except ImportError:
45
+ # llvmlite >= 0.45
46
+ from llvmlite.binding import newpassmanagers as passmanagers
47
+
48
+ # Outputs of type inference pass
49
+ _TypingResults = namedtuple(
50
+ "_TypingResults",
51
+ [
52
+ "typemap",
53
+ "return_type",
54
+ "calltypes",
55
+ "typing_errors",
56
+ ],
57
+ )
10
58
 
11
59
 
12
60
  @contextmanager
@@ -38,6 +86,216 @@ def fallback_context(state, msg):
38
86
  raise
39
87
 
40
88
 
89
+ def type_inference_stage(
90
+ typingctx,
91
+ targetctx,
92
+ interp,
93
+ args,
94
+ return_type,
95
+ locals=None,
96
+ raise_errors=True,
97
+ ):
98
+ if locals is None:
99
+ locals = {}
100
+ if len(args) != interp.arg_count:
101
+ raise TypeError("Mismatch number of argument types")
102
+ warnings = errors.WarningsFixer(errors.NumbaWarning)
103
+
104
+ infer = typeinfer.TypeInferer(typingctx, interp, warnings)
105
+ callstack_ctx = typingctx.callstack.register(
106
+ targetctx.target, infer, interp.func_id, args
107
+ )
108
+ # Setup two contexts: 1) callstack setup/teardown 2) flush warnings
109
+ with callstack_ctx, warnings:
110
+ # Seed argument types
111
+ for index, (name, ty) in enumerate(zip(interp.arg_names, args)):
112
+ infer.seed_argument(name, index, ty)
113
+
114
+ # Seed return type
115
+ if return_type is not None:
116
+ infer.seed_return(return_type)
117
+
118
+ # Seed local types
119
+ for k, v in locals.items():
120
+ infer.seed_type(k, v)
121
+
122
+ infer.build_constraint()
123
+ # return errors in case of partial typing
124
+ errs = infer.propagate(raise_errors=raise_errors)
125
+ typemap, restype, calltypes = infer.unify(raise_errors=raise_errors)
126
+
127
+ return _TypingResults(typemap, restype, calltypes, errs)
128
+
129
+
130
class BaseTypeInference(FunctionPass):
    """Shared implementation of the type-inference passes.

    ``_raise_errors`` controls whether inference errors are raised. When it
    is false, the inferred return type and call types are not written to
    the compiler state, and return-type legalization does not raise.
    """

    _raise_errors = True

    def __init__(self):
        FunctionPass.__init__(self)

    def run_pass(self, state):
        """
        Type inference and legalization
        """
        with fallback_context(
            state,
            'Function "%s" failed type inference' % (state.func_id.func_name,),
        ):
            # Type inference
            typemap, return_type, calltypes, errs = type_inference_stage(
                state.typingctx,
                state.targetctx,
                state.func_ir,
                state.args,
                state.return_type,
                state.locals,
                raise_errors=self._raise_errors,
            )
            state.typemap = typemap
            # save errors in case of partial typing
            state.typing_errors = errs
            if self._raise_errors:
                state.return_type = return_type
                state.calltypes = calltypes

        with fallback_context(
            state,
            'Function "%s" has invalid return type'
            % (state.func_id.func_name,),
        ):
            self._legalize_return_type(
                state.return_type, state.func_ir, state.targetctx
            )
        return True

    def _legalize_return_type(self, return_type, interp, targetctx):
        """
        Only accept array return type iff it is passed into the function.
        Reject function object return types if in nopython mode.

        ``errors.NumbaTypeError`` is raised only when ``self._raise_errors``
        is true.
        """
        if not targetctx.enable_nrt and isinstance(return_type, types.Array):
            # Without NRT, an array may only be returned if it came in as an
            # argument. Walk the IR to discover all arguments, casts and
            # return statements.
            retstmts = []
            caststmts = {}
            argvars = set()
            for blk in interp.blocks.values():
                for inst in blk.body:
                    if isinstance(inst, ir.Return):
                        retstmts.append(inst.value.name)
                    elif isinstance(inst, ir.Assign):
                        if (
                            isinstance(inst.value, ir.Expr)
                            and inst.value.op == "cast"
                        ):
                            caststmts[inst.target.name] = inst.value
                        elif isinstance(inst.value, ir.Arg):
                            argvars.add(inst.target.name)

            assert retstmts, "No return statements?"

            for var in retstmts:
                # Each returned value must be a cast of an argument variable.
                cast = caststmts.get(var)
                if cast is None or cast.value.name not in argvars:
                    if self._raise_errors:
                        msg = (
                            "Only accept returning of array passed into "
                            "the function as argument"
                        )
                        raise errors.NumbaTypeError(msg)

        elif isinstance(return_type, (types.Function, types.Phantom)):
            if self._raise_errors:
                msg = "Can't return function object ({}) in nopython mode"
                raise errors.NumbaTypeError(msg.format(return_type))
214
+
215
+
216
@register_pass(mutates_CFG=True, analysis_only=False)
class NopythonTypeInference(BaseTypeInference):
    """Strict nopython type inference.

    Inherits ``_raise_errors = True`` from ``BaseTypeInference``, so typing
    errors are raised and the inferred return/call types are stored.
    """

    _name = "nopython_type_inference"
219
+
220
+
221
@register_pass(mutates_CFG=True, analysis_only=False)
class PartialTypeInference(BaseTypeInference):
    """Type inference that tolerates failures (partial typing).

    With ``_raise_errors`` disabled, ``raise_errors=False`` is forwarded to
    inference and the collected errors are kept on ``state.typing_errors``.
    The return type and call types are not written to the state.
    """

    _raise_errors = False
    _name = "partial_type_inference"
225
+
226
+
227
@register_pass(mutates_CFG=False, analysis_only=False)
class AnnotateTypes(AnalysisPass):
    """Attach a ``TypeAnnotation`` to the compiler state after typing.

    When ``config.ANNOTATE`` is set the annotation is printed; when
    ``config.HTML`` is set an HTML rendering is written to that path.
    """

    _name = "annotate_types"

    def __init__(self):
        AnalysisPass.__init__(self)

    def get_analysis_usage(self, AU):
        # Declare IRLegalization as a required predecessor of this pass.
        AU.add_required(IRLegalization)

    def run_pass(self, state):
        """
        Create type annotation after type inference
        """
        annotation = type_annotations.TypeAnnotation(
            func_ir=state.func_ir.copy(),
            typemap=state.typemap,
            calltypes=state.calltypes,
            lifted=state.lifted,
            lifted_from=state.lifted_from,
            args=state.args,
            return_type=state.return_type,
            html_output=config.HTML,
        )
        state.type_annotation = annotation

        if config.ANNOTATE:
            banner = "ANNOTATION".center(80, "-")
            print(banner)
            print(annotation)
            print("=" * 80)

        html_path = config.HTML
        if html_path:
            with open(html_path, "w") as fout:
                annotation.html_annotate(fout)

        return False
262
+
263
+
264
@register_pass(mutates_CFG=True, analysis_only=False)
class NopythonRewrites(FunctionPass):
    """Apply the registered "after-inference" IR rewrites."""

    _name = "nopython_rewrites"

    def __init__(self):
        FunctionPass.__init__(self)

    def run_pass(self, state):
        """
        Perform any intermediate representation rewrites after type
        inference.
        """
        # Several of the rewrites make assumptions about (or depend on)
        # ir.Del placement, so ir.Del nodes are emitted before the rewrites
        # run and stripped again once they are done.
        func_ir = state.func_ir

        # Ensure we have an IR and type information.
        assert func_ir
        assert isinstance(getattr(state, "typemap", None), dict)
        assert isinstance(getattr(state, "calltypes", None), dict)

        func_name = state.func_id.func_name
        msg = (
            "Internal error in post-inference rewriting "
            "pass encountered during compilation of "
            'function "%s"' % (func_name,)
        )

        post_processor = postproc.PostProcessor(func_ir)
        post_processor.run(True)
        with fallback_context(state, msg):
            rewrites.rewrite_registry.apply("after-inference", state)
        post_processor.remove_dels()
        return True
297
+
298
+
41
299
  class BaseNativeLowering(abc.ABC, LoweringPass):
42
300
  """The base class for a lowering pass. The lowering functionality must be
43
301
  specified in inheriting classes by providing an appropriate lowering class
@@ -71,7 +329,7 @@ class BaseNativeLowering(abc.ABC, LoweringPass):
71
329
  calltypes = state.calltypes
72
330
  flags = state.flags
73
331
  metadata = state.metadata
74
- pre_stats = llvm.passmanagers.dump_refprune_stats()
332
+ pre_stats = passmanagers.dump_refprune_stats()
75
333
 
76
334
  msg = "Function %s failed at nopython mode lowering" % (
77
335
  state.func_id.func_name,
@@ -117,7 +375,7 @@ class BaseNativeLowering(abc.ABC, LoweringPass):
117
375
  call_helper = lower.call_helper
118
376
  del lower
119
377
 
120
- from numba.core.compiler import _LowerResult # TODO: move this
378
+ from numba.cuda.compiler import _LowerResult # TODO: move this
121
379
 
122
380
  if flags.no_compile:
123
381
  state["cr"] = _LowerResult(
@@ -134,9 +392,482 @@ class BaseNativeLowering(abc.ABC, LoweringPass):
134
392
  )
135
393
 
136
394
  # capture pruning stats
137
- post_stats = llvm.passmanagers.dump_refprune_stats()
395
+ post_stats = passmanagers.dump_refprune_stats()
138
396
  metadata["prune_stats"] = post_stats - pre_stats
139
397
 
140
398
  # Save the LLVM pass timings
141
399
  metadata["llvm_pass_timings"] = library.recorded_timings
142
400
  return True
401
+
402
+
403
@register_pass(mutates_CFG=True, analysis_only=False)
class NativeLowering(BaseNativeLowering):
    """Lowering pass for a native function IR described solely in terms of
    Numba's standard `numba.core.ir` nodes."""

    _name = "native_lowering"

    @property
    def lowering_class(self):
        """The lowering implementation used by this pass: ``lowering.Lower``."""
        lower_cls = lowering.Lower
        return lower_cls
413
+
414
+
415
@register_pass(mutates_CFG=False, analysis_only=True)
class NoPythonSupportedFeatureValidation(AnalysisPass):
    """NoPython Mode check: Validates the IR to ensure that features in use are
    in a form that is supported"""

    _name = "nopython_supported_feature_validation"

    def __init__(self):
        AnalysisPass.__init__(self)

    def run_pass(self, state):
        """Raise on unsupported features, then emit deprecation warnings.

        Returns ``False`` (this analysis pass reports no IR modification).
        """
        func_ir, typemap = state.func_ir, state.typemap
        raise_on_unsupported_feature(func_ir, typemap)
        warn_deprecated(func_ir, typemap)
        return False
429
+
430
+
431
@register_pass(mutates_CFG=False, analysis_only=True)
class IRLegalization(AnalysisPass):
    """Check (and fix up) invalid IR via ``check_and_legalize_ir``."""

    _name = "ir_legalization"

    def __init__(self):
        AnalysisPass.__init__(self)

    def run_pass(self, state):
        func_ir = state.func_ir
        flags = state.flags
        # NOTE: this function call must go last, it checks and fixes invalid IR!
        check_and_legalize_ir(func_ir, flags=flags)
        return True
442
+
443
+
444
@register_pass(mutates_CFG=True, analysis_only=False)
class NoPythonBackend(LoweringPass):
    """Wrap the lowered function in a ``cuda_compile_result`` on ``state.cr``."""

    _name = "nopython_backend"

    def __init__(self):
        LoweringPass.__init__(self)

    def run_pass(self, state):
        """
        Back-end: Generate LLVM IR from Numba IR, compile to machine code
        """
        # ``state["cr"]`` holds the lowering result produced earlier; it is
        # replaced below by the final compile result.
        lowered = state["cr"]
        sig = typing.signature(state.return_type, *state.args)

        from numba.cuda.compiler import cuda_compile_result

        result = cuda_compile_result(
            typing_context=state.typingctx,
            target_context=state.targetctx,
            entry_point=lowered.cfunc,
            typing_error=state.status.fail_reason,
            type_annotation=state.type_annotation,
            library=state.library,
            call_helper=lowered.call_helper,
            signature=sig,
            objectmode=False,
            lifted=state.lifted,
            fndesc=lowered.fndesc,
            environment=lowered.env,
            metadata=state.metadata,
            reload_init=state.reload_init,
        )
        state.cr = result
        return True
477
+
478
+
479
@register_pass(mutates_CFG=True, analysis_only=False)
class InlineOverloads(FunctionPass):
    """
    This pass will inline a function wrapped by the numba.extending.overload
    decorator directly into the site of its call depending on the value set in
    the 'inline' kwarg to the decorator.

    This is a typed pass. CFG simplification and DCE are performed on
    completion.
    """

    _name = "inline_overloads"

    # Set to True to print the function IR before/after inlining.
    _DEBUG = False

    def __init__(self):
        FunctionPass.__init__(self)

    def _debug_dump(self, state, title):
        """Print ``title`` and the function IR when ``_DEBUG`` is enabled."""
        if self._DEBUG:
            print(title.center(80, "-"))
            print(state.func_id.unique_name)
            print(state.func_ir.dump())
            print("".center(80, "-"))

    def run_pass(self, state):
        """Run inlining of overloads"""
        self._debug_dump(state, "before overload inline")
        from numba.cuda.core.inline_closurecall import (
            InlineWorker,
            callee_ir_validator,
        )

        inline_worker = InlineWorker(
            state.typingctx,
            state.targetctx,
            state.locals,
            state.pipeline,
            state.flags,
            callee_ir_validator,
            state.typemap,
            state.calltypes,
        )
        modified = False
        work_list = list(state.func_ir.blocks.items())
        # use a work list, look for call sites via `ir.Expr.op == call` and
        # then pass these to `self._do_work_expr` to make decisions about
        # inlining. Blocks created by inlining are pushed back onto the list.
        while work_list:
            _label, block = work_list.pop()
            for i, instr in enumerate(block.body):
                # TO-DO: other statements (setitem)
                if isinstance(instr, ir.Assign) and isinstance(
                    instr.value, ir.Expr
                ):
                    if guard(
                        self._do_work_expr,
                        state,
                        work_list,
                        block,
                        i,
                        instr.value,
                        inline_worker,
                    ):
                        modified = True
                        break  # because block structure changed

        self._debug_dump(state, "after overload inline")

        if modified:
            # Remove dead blocks, this is safe as it relies on the CFG only.
            cfg = compute_cfg_from_blocks(state.func_ir.blocks)
            for dead in cfg.dead_nodes():
                del state.func_ir.blocks[dead]
            # clean up blocks
            dead_code_elimination(state.func_ir, typemap=state.typemap)
            # clean up unconditional branches that appear due to inlined
            # functions introducing blocks
            state.func_ir.blocks = simplify_CFG(state.func_ir.blocks)

        self._debug_dump(state, "after overload inline DCE")
        return True

    def _get_attr_info(self, state, expr):
        """Return ``(templates, sig, arg_typs, is_method)`` for a getattr.

        Returns ``None`` when no getattr template matches or when the matched
        template represents a method (methods are not inlined here).
        """
        recv_type = state.typemap[expr.value.name]
        recv_type = types.unliteral(recv_type)
        matched = state.typingctx.find_matching_getattr_template(
            recv_type,
            expr.attr,
        )
        if not matched:
            return None

        template = matched["template"]
        if getattr(template, "is_method", False):
            # The attribute template is representing a method.
            # Don't inline the getattr.
            return None

        templates = [template]
        sig = typing.signature(matched["return_type"], recv_type)
        arg_typs = sig.args
        is_method = False

        return templates, sig, arg_typs, is_method

    def _get_callable_info(self, state, expr):
        """Return ``(templates, sig, arg_typs, is_method)`` for a call or
        operator expression, or ``None`` if it is not an inlining candidate.
        """

        def get_func_type(state, expr):
            func_ty = None
            if expr.op == "call":
                # check this is a known and typed function
                try:
                    func_ty = state.typemap[expr.func.name]
                except KeyError:
                    # e.g. Calls to CUDA Intrinsic have no mapped type
                    # so KeyError
                    return None
                if not hasattr(func_ty, "get_call_type"):
                    return None

            elif is_operator_or_getitem(expr):
                func_ty = state.typingctx.resolve_value_type(expr.fn)
            else:
                return None

            return func_ty

        if expr.op == "call":
            # try and get a definition for the call, this isn't always
            # possible as it might be a eval(str)/part generated
            # awaiting update etc. (parfors)
            try:
                to_inline = state.func_ir.get_definition(expr.func)
            except Exception:
                return None

            # do not handle closure inlining here, another pass deals with that
            if getattr(to_inline, "op", False) == "make_function":
                return None

        func_ty = get_func_type(state, expr)
        if func_ty is None:
            return None

        sig = state.calltypes[expr]
        if not sig:
            return None

        if getattr(func_ty, "template", None) is not None:
            # @overload_method
            is_method = True
            templates = [func_ty.template]
            arg_typs = (func_ty.template.this,) + sig.args
        else:
            # @overload case
            is_method = False
            templates = getattr(func_ty, "templates", None)
            arg_typs = sig.args

        return templates, sig, arg_typs, is_method

    def _do_work_expr(self, state, work_list, block, i, expr, inline_worker):
        """Try to inline the overload behind ``expr`` (statement ``i`` of
        ``block``). Returns True if the block was modified.
        """

        def select_template(templates, args):
            # Pick the first template that has inline options, knows this
            # signature, is not "never" inlined and yields an implementation.
            if templates is None:
                return None

            impl = None
            for template in templates:
                inline_type = getattr(template, "_inline", None)
                if inline_type is None:
                    # inline not defined
                    continue
                if args not in template._inline_overloads:
                    # skip overloads not matching signature
                    continue
                if not inline_type.is_never_inline:
                    try:
                        impl = template._overload_func(*args)
                        if impl is None:
                            raise Exception  # abort for this template
                        break
                    except Exception:
                        continue
            else:
                return None

            return template, inline_type, impl

        if expr.op == "getattr":
            inlinee_info = self._get_attr_info(state, expr)
        else:
            inlinee_info = self._get_callable_info(state, expr)

        if not inlinee_info:
            return False

        templates, sig, arg_typs, is_method = inlinee_info
        inlinee = select_template(templates, arg_typs)
        if inlinee is None:
            return False
        template, inlinee_type, impl = inlinee

        return self._run_inliner(
            state,
            inlinee_type,
            sig,
            template,
            arg_typs,
            expr,
            i,
            impl,
            block,
            work_list,
            is_method,
            inline_worker,
        )

    def _run_inliner(
        self,
        state,
        inline_type,
        sig,
        template,
        arg_typs,
        expr,
        i,
        impl,
        block,
        work_list,
        is_method,
        inline_worker,
    ):
        """Decide whether to inline and, if so, splice the callee IR in.

        Returns True when the callee was inlined; new blocks are appended to
        ``work_list`` (when given) so they are scanned as well.

        Raises ``AssertionError`` if an inline option is neither "always"
        nor backed by a cost model (previously a silent no-op).
        """
        do_inline = True
        if not inline_type.is_always_inline:
            from numba.core.typing.templates import _inline_info

            caller_inline_info = _inline_info(
                state.func_ir, state.typemap, state.calltypes, sig
            )

            # must be a cost-model function, run the function
            iinfo = template._inline_overloads[arg_typs]["iinfo"]
            if inline_type.has_cost_model:
                do_inline = inline_type.value(expr, caller_inline_info, iinfo)
            else:
                # The previous ``assert "unreachable"`` could never fail (a
                # non-empty string is truthy), so this state silently fell
                # through to inlining. Fail loudly instead.
                raise AssertionError("unreachable")

        if not do_inline:
            return False

        if is_method:
            if not self._add_method_self_arg(state, expr):
                return False
        arg_typs = template._inline_overloads[arg_typs]["folded_args"]
        iinfo = template._inline_overloads[arg_typs]["iinfo"]
        freevars = iinfo.func_ir.func_id.func.__code__.co_freevars
        _, _, _, new_blocks = inline_worker.inline_ir(
            state.func_ir,
            block,
            i,
            iinfo.func_ir,
            freevars,
            arg_typs=arg_typs,
        )
        if work_list is not None:
            work_list.extend(new_blocks)
        return True

    def _add_method_self_arg(self, state, expr):
        """Prepend the method's receiver to ``expr.args``.

        Returns False when the definition of ``expr.func`` cannot be found.
        """
        func_def = guard(get_definition, state.func_ir, expr.func)
        if func_def is None:
            return False
        expr.args.insert(0, func_def.value)
        return True
762
+
763
+
764
@register_pass(mutates_CFG=False, analysis_only=False)
class DeadCodeElimination(FunctionPass):
    """Run ``dead_code_elimination`` over the function IR.

    The state's ``typemap`` is passed along to the elimination routine.
    """

    _name = "dead_code_elimination"

    def __init__(self):
        FunctionPass.__init__(self)

    def run_pass(self, state):
        func_ir, typemap = state.func_ir, state.typemap
        dead_code_elimination(func_ir, typemap)
        return True
778
+
779
+
780
@register_pass(mutates_CFG=False, analysis_only=False)
class PreLowerStripPhis(FunctionPass):
    """Remove phi nodes (ir.Expr.phi) introduced by SSA.

    This is needed before Lowering because the phi nodes in Numba IR do not
    match the semantics of phi nodes in LLVM IR. In Numba IR, phi nodes may
    expand into multiple LLVM instructions.
    """

    _name = "strip_phis"

    def __init__(self):
        FunctionPass.__init__(self)

    def run_pass(self, state):
        """Strip phis, refresh IR metadata and rebuild a generator type.

        The generator return type is rebuilt (from the state variables'
        types) only when the IR has generator info and a typemap exists.
        """
        state.func_ir = self._strip_phi_nodes(state.func_ir)

        # Rerun postprocessor to update metadata
        post_proc = postproc.PostProcessor(state.func_ir)
        post_proc.run(emit_dels=False)

        # Ensure we are not in objectmode generator
        if (
            state.func_ir.generator_info is not None
            and state.typemap is not None
        ):
            # Rebuild generator type
            # TODO: move this into PostProcessor
            gentype = state.return_type
            state_vars = state.func_ir.generator_info.state_vars
            state_types = [state.typemap[k] for k in state_vars]
            state.return_type = types.Generator(
                gen_func=gentype.gen_func,
                yield_type=gentype.yield_type,
                arg_types=gentype.arg_types,
                state_types=state_types,
                has_finalizer=gentype.has_finalizer,
            )
        return True

    def _strip_phi_nodes(self, func_ir):
        """Strip Phi nodes from ``func_ir``

        For each phi node, put incoming value to their respective incoming
        basic-block at possibly the latest position (i.e. after the latest
        assignment to the corresponding variable).

        Blocks are shallow-copied with new bodies; ``func_ir.blocks`` is
        replaced and ``func_ir`` is returned.
        """
        # incoming block label -> [(phi target, incoming value), ...]
        exporters = defaultdict(list)
        phis = set()
        # Find all variables that need to be exported
        for block in func_ir.blocks.values():
            for assign in block.find_insts(ir.Assign):
                phi = assign.value
                if isinstance(phi, ir.Expr) and phi.op == "phi":
                    phis.add(assign)
                    for ib, iv in zip(
                        phi.incoming_blocks, phi.incoming_values
                    ):
                        exporters[ib].append((assign.target, iv))

        # Rewrite the blocks with the new exporting assignments
        newblocks = {}
        for label, block in func_ir.blocks.items():
            newblk = copy(block)
            newblocks[label] = newblk

            # strip phis
            newblk.body = [stmt for stmt in block.body if stmt not in phis]

            # insert exporters (``.get`` avoids populating the defaultdict
            # with empty entries for blocks that export nothing)
            for target, rhs in exporters.get(label, ()):
                # If RHS is undefined
                if rhs is ir.UNDEFINED:
                    # Put in a NULL initializer, set the location to be in what
                    # will eventually materialize as the prologue.
                    rhs = ir.Expr.null(loc=func_ir.loc)

                assign = ir.Assign(target=target, value=rhs, loc=rhs.loc)
                # Insert right after the last assignment to rhs, i.e. the
                # earliest point at which rhs holds its final value in this
                # block. The body is rescanned for every exporter because an
                # assignment inserted by a previous exporter may itself be
                # the last assignment to this rhs.
                assignments = [
                    stmt
                    for stmt in newblk.find_insts(ir.Assign)
                    if stmt.target == rhs
                ]
                if assignments:
                    newblk.insert_after(assign, assignments[-1])
                else:
                    newblk.prepend(assign)

        func_ir.blocks = newblocks
        return func_ir
+ return func_ir