numba-cuda 0.19.1__py3-none-any.whl → 0.20.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of numba-cuda might be problematic; see the advisory details below for more information.

Files changed (172) hide show
  1. numba_cuda/VERSION +1 -1
  2. numba_cuda/numba/cuda/__init__.py +1 -1
  3. numba_cuda/numba/cuda/_internal/cuda_bf16.py +12706 -1470
  4. numba_cuda/numba/cuda/_internal/cuda_fp16.py +2653 -8769
  5. numba_cuda/numba/cuda/api.py +6 -1
  6. numba_cuda/numba/cuda/bf16.py +285 -2
  7. numba_cuda/numba/cuda/cgutils.py +2 -2
  8. numba_cuda/numba/cuda/cloudpickle/__init__.py +21 -0
  9. numba_cuda/numba/cuda/cloudpickle/cloudpickle.py +1598 -0
  10. numba_cuda/numba/cuda/cloudpickle/cloudpickle_fast.py +17 -0
  11. numba_cuda/numba/cuda/codegen.py +1 -1
  12. numba_cuda/numba/cuda/compiler.py +373 -30
  13. numba_cuda/numba/cuda/core/analysis.py +319 -0
  14. numba_cuda/numba/cuda/core/annotations/__init__.py +0 -0
  15. numba_cuda/numba/cuda/core/annotations/type_annotations.py +304 -0
  16. numba_cuda/numba/cuda/core/base.py +1289 -0
  17. numba_cuda/numba/cuda/core/bytecode.py +727 -0
  18. numba_cuda/numba/cuda/core/caching.py +2 -2
  19. numba_cuda/numba/cuda/core/compiler.py +6 -14
  20. numba_cuda/numba/cuda/core/compiler_machinery.py +497 -0
  21. numba_cuda/numba/cuda/core/config.py +747 -0
  22. numba_cuda/numba/cuda/core/consts.py +124 -0
  23. numba_cuda/numba/cuda/core/cpu.py +370 -0
  24. numba_cuda/numba/cuda/core/environment.py +68 -0
  25. numba_cuda/numba/cuda/core/event.py +511 -0
  26. numba_cuda/numba/cuda/core/funcdesc.py +330 -0
  27. numba_cuda/numba/cuda/core/inline_closurecall.py +1889 -0
  28. numba_cuda/numba/cuda/core/interpreter.py +48 -26
  29. numba_cuda/numba/cuda/core/ir_utils.py +15 -26
  30. numba_cuda/numba/cuda/core/options.py +262 -0
  31. numba_cuda/numba/cuda/core/postproc.py +249 -0
  32. numba_cuda/numba/cuda/core/pythonapi.py +1868 -0
  33. numba_cuda/numba/cuda/core/rewrites/__init__.py +26 -0
  34. numba_cuda/numba/cuda/core/rewrites/ir_print.py +90 -0
  35. numba_cuda/numba/cuda/core/rewrites/registry.py +104 -0
  36. numba_cuda/numba/cuda/core/rewrites/static_binop.py +40 -0
  37. numba_cuda/numba/cuda/core/rewrites/static_getitem.py +187 -0
  38. numba_cuda/numba/cuda/core/rewrites/static_raise.py +98 -0
  39. numba_cuda/numba/cuda/core/ssa.py +496 -0
  40. numba_cuda/numba/cuda/core/targetconfig.py +329 -0
  41. numba_cuda/numba/cuda/core/tracing.py +231 -0
  42. numba_cuda/numba/cuda/core/transforms.py +952 -0
  43. numba_cuda/numba/cuda/core/typed_passes.py +738 -7
  44. numba_cuda/numba/cuda/core/typeinfer.py +1948 -0
  45. numba_cuda/numba/cuda/core/unsafe/__init__.py +0 -0
  46. numba_cuda/numba/cuda/core/unsafe/bytes.py +67 -0
  47. numba_cuda/numba/cuda/core/unsafe/eh.py +66 -0
  48. numba_cuda/numba/cuda/core/unsafe/refcount.py +98 -0
  49. numba_cuda/numba/cuda/core/untyped_passes.py +1983 -0
  50. numba_cuda/numba/cuda/cpython/cmathimpl.py +560 -0
  51. numba_cuda/numba/cuda/cpython/mathimpl.py +499 -0
  52. numba_cuda/numba/cuda/cpython/numbers.py +1474 -0
  53. numba_cuda/numba/cuda/cuda_paths.py +422 -246
  54. numba_cuda/numba/cuda/cudadecl.py +1 -1
  55. numba_cuda/numba/cuda/cudadrv/__init__.py +1 -1
  56. numba_cuda/numba/cuda/cudadrv/devicearray.py +2 -1
  57. numba_cuda/numba/cuda/cudadrv/driver.py +11 -140
  58. numba_cuda/numba/cuda/cudadrv/dummyarray.py +111 -24
  59. numba_cuda/numba/cuda/cudadrv/libs.py +5 -5
  60. numba_cuda/numba/cuda/cudadrv/mappings.py +1 -1
  61. numba_cuda/numba/cuda/cudadrv/nvrtc.py +19 -8
  62. numba_cuda/numba/cuda/cudadrv/nvvm.py +1 -4
  63. numba_cuda/numba/cuda/cudadrv/runtime.py +1 -1
  64. numba_cuda/numba/cuda/cudaimpl.py +5 -1
  65. numba_cuda/numba/cuda/debuginfo.py +85 -2
  66. numba_cuda/numba/cuda/decorators.py +3 -3
  67. numba_cuda/numba/cuda/descriptor.py +3 -4
  68. numba_cuda/numba/cuda/deviceufunc.py +66 -2
  69. numba_cuda/numba/cuda/dispatcher.py +18 -39
  70. numba_cuda/numba/cuda/flags.py +141 -1
  71. numba_cuda/numba/cuda/fp16.py +0 -2
  72. numba_cuda/numba/cuda/include/13/cuda_bf16.h +5118 -0
  73. numba_cuda/numba/cuda/include/13/cuda_bf16.hpp +3865 -0
  74. numba_cuda/numba/cuda/include/13/cuda_fp16.h +5363 -0
  75. numba_cuda/numba/cuda/include/13/cuda_fp16.hpp +3483 -0
  76. numba_cuda/numba/cuda/lowering.py +7 -144
  77. numba_cuda/numba/cuda/mathimpl.py +2 -1
  78. numba_cuda/numba/cuda/memory_management/nrt.py +43 -17
  79. numba_cuda/numba/cuda/misc/findlib.py +75 -0
  80. numba_cuda/numba/cuda/models.py +9 -1
  81. numba_cuda/numba/cuda/np/npdatetime_helpers.py +217 -0
  82. numba_cuda/numba/cuda/np/npyfuncs.py +1807 -0
  83. numba_cuda/numba/cuda/np/numpy_support.py +553 -0
  84. numba_cuda/numba/cuda/np/ufunc/ufuncbuilder.py +59 -0
  85. numba_cuda/numba/cuda/nvvmutils.py +1 -1
  86. numba_cuda/numba/cuda/printimpl.py +12 -1
  87. numba_cuda/numba/cuda/random.py +1 -1
  88. numba_cuda/numba/cuda/serialize.py +1 -1
  89. numba_cuda/numba/cuda/simulator/__init__.py +1 -1
  90. numba_cuda/numba/cuda/simulator/api.py +1 -1
  91. numba_cuda/numba/cuda/simulator/compiler.py +4 -0
  92. numba_cuda/numba/cuda/simulator/cudadrv/devicearray.py +1 -1
  93. numba_cuda/numba/cuda/simulator/kernelapi.py +1 -1
  94. numba_cuda/numba/cuda/simulator/memory_management/nrt.py +14 -2
  95. numba_cuda/numba/cuda/target.py +35 -17
  96. numba_cuda/numba/cuda/testing.py +7 -19
  97. numba_cuda/numba/cuda/tests/__init__.py +1 -1
  98. numba_cuda/numba/cuda/tests/cloudpickle_main_class.py +9 -0
  99. numba_cuda/numba/cuda/tests/core/test_serialize.py +4 -4
  100. numba_cuda/numba/cuda/tests/cudadrv/test_cuda_devicerecord.py +1 -1
  101. numba_cuda/numba/cuda/tests/cudadrv/test_cuda_libraries.py +1 -1
  102. numba_cuda/numba/cuda/tests/cudadrv/test_deallocations.py +1 -1
  103. numba_cuda/numba/cuda/tests/cudadrv/test_detect.py +6 -3
  104. numba_cuda/numba/cuda/tests/cudadrv/test_emm_plugins.py +1 -1
  105. numba_cuda/numba/cuda/tests/cudadrv/test_linker.py +18 -2
  106. numba_cuda/numba/cuda/tests/cudadrv/test_module_callbacks.py +2 -1
  107. numba_cuda/numba/cuda/tests/cudadrv/test_nvjitlink.py +1 -1
  108. numba_cuda/numba/cuda/tests/cudadrv/test_ptds.py +1 -1
  109. numba_cuda/numba/cuda/tests/cudapy/extensions_usecases.py +1 -1
  110. numba_cuda/numba/cuda/tests/cudapy/test_array.py +2 -1
  111. numba_cuda/numba/cuda/tests/cudapy/test_atomics.py +1 -1
  112. numba_cuda/numba/cuda/tests/cudapy/test_bfloat16.py +539 -2
  113. numba_cuda/numba/cuda/tests/cudapy/test_bfloat16_bindings.py +81 -1
  114. numba_cuda/numba/cuda/tests/cudapy/test_caching.py +1 -3
  115. numba_cuda/numba/cuda/tests/cudapy/test_complex.py +1 -1
  116. numba_cuda/numba/cuda/tests/cudapy/test_constmem.py +1 -1
  117. numba_cuda/numba/cuda/tests/cudapy/test_cooperative_groups.py +2 -3
  118. numba_cuda/numba/cuda/tests/cudapy/test_copy_propagate.py +130 -0
  119. numba_cuda/numba/cuda/tests/cudapy/test_datetime.py +1 -1
  120. numba_cuda/numba/cuda/tests/cudapy/test_debug.py +1 -1
  121. numba_cuda/numba/cuda/tests/cudapy/test_debuginfo.py +293 -4
  122. numba_cuda/numba/cuda/tests/cudapy/test_debuginfo_types.py +1 -1
  123. numba_cuda/numba/cuda/tests/cudapy/test_dispatcher.py +1 -1
  124. numba_cuda/numba/cuda/tests/cudapy/test_errors.py +1 -1
  125. numba_cuda/numba/cuda/tests/cudapy/test_exception.py +1 -1
  126. numba_cuda/numba/cuda/tests/cudapy/test_extending.py +2 -1
  127. numba_cuda/numba/cuda/tests/cudapy/test_inline.py +18 -8
  128. numba_cuda/numba/cuda/tests/cudapy/test_intrinsics.py +23 -21
  129. numba_cuda/numba/cuda/tests/cudapy/test_ir_utils.py +10 -37
  130. numba_cuda/numba/cuda/tests/cudapy/test_laplace.py +1 -1
  131. numba_cuda/numba/cuda/tests/cudapy/test_math.py +1 -1
  132. numba_cuda/numba/cuda/tests/cudapy/test_matmul.py +1 -1
  133. numba_cuda/numba/cuda/tests/cudapy/test_operator.py +1 -1
  134. numba_cuda/numba/cuda/tests/cudapy/test_print.py +20 -0
  135. numba_cuda/numba/cuda/tests/cudapy/test_record_dtype.py +1 -1
  136. numba_cuda/numba/cuda/tests/cudapy/test_reduction.py +1 -1
  137. numba_cuda/numba/cuda/tests/cudapy/test_serialize.py +1 -1
  138. numba_cuda/numba/cuda/tests/cudapy/test_sm.py +1 -1
  139. numba_cuda/numba/cuda/tests/cudapy/test_ssa.py +453 -0
  140. numba_cuda/numba/cuda/tests/cudapy/test_sync.py +1 -1
  141. numba_cuda/numba/cuda/tests/cudapy/test_typeinfer.py +538 -0
  142. numba_cuda/numba/cuda/tests/cudapy/test_ufuncs.py +263 -2
  143. numba_cuda/numba/cuda/tests/cudapy/test_userexc.py +1 -1
  144. numba_cuda/numba/cuda/tests/cudapy/test_vector_type.py +1 -1
  145. numba_cuda/numba/cuda/tests/cudapy/test_vectorize_decor.py +112 -6
  146. numba_cuda/numba/cuda/tests/cudapy/test_warning.py +1 -1
  147. numba_cuda/numba/cuda/tests/cudapy/test_warp_ops.py +1 -1
  148. numba_cuda/numba/cuda/tests/doc_examples/test_cg.py +0 -2
  149. numba_cuda/numba/cuda/tests/doc_examples/test_ffi.py +3 -2
  150. numba_cuda/numba/cuda/tests/doc_examples/test_laplace.py +0 -2
  151. numba_cuda/numba/cuda/tests/doc_examples/test_sessionize.py +0 -2
  152. numba_cuda/numba/cuda/tests/nocuda/test_import.py +3 -1
  153. numba_cuda/numba/cuda/tests/nocuda/test_library_lookup.py +24 -12
  154. numba_cuda/numba/cuda/tests/nrt/test_nrt.py +2 -1
  155. numba_cuda/numba/cuda/tests/support.py +55 -15
  156. numba_cuda/numba/cuda/tests/test_tracing.py +200 -0
  157. numba_cuda/numba/cuda/types.py +56 -0
  158. numba_cuda/numba/cuda/typing/__init__.py +9 -1
  159. numba_cuda/numba/cuda/typing/cffi_utils.py +55 -0
  160. numba_cuda/numba/cuda/typing/context.py +751 -0
  161. numba_cuda/numba/cuda/typing/enumdecl.py +74 -0
  162. numba_cuda/numba/cuda/typing/npydecl.py +658 -0
  163. numba_cuda/numba/cuda/typing/templates.py +7 -6
  164. numba_cuda/numba/cuda/ufuncs.py +3 -3
  165. numba_cuda/numba/cuda/utils.py +6 -112
  166. {numba_cuda-0.19.1.dist-info → numba_cuda-0.20.1.dist-info}/METADATA +4 -3
  167. {numba_cuda-0.19.1.dist-info → numba_cuda-0.20.1.dist-info}/RECORD +171 -116
  168. numba_cuda/numba/cuda/tests/cudadrv/test_mvc.py +0 -60
  169. {numba_cuda-0.19.1.dist-info → numba_cuda-0.20.1.dist-info}/WHEEL +0 -0
  170. {numba_cuda-0.19.1.dist-info → numba_cuda-0.20.1.dist-info}/licenses/LICENSE +0 -0
  171. {numba_cuda-0.19.1.dist-info → numba_cuda-0.20.1.dist-info}/licenses/LICENSE.numba +0 -0
  172. {numba_cuda-0.19.1.dist-info → numba_cuda-0.20.1.dist-info}/top_level.txt +0 -0
@@ -0,0 +1,1983 @@
1
+ # SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2
+ # SPDX-License-Identifier: BSD-2-Clause
3
+
4
+ from collections import defaultdict, namedtuple
5
+ from contextlib import contextmanager
6
+ from copy import deepcopy, copy
7
+ import warnings
8
+
9
+ from numba.cuda.core.compiler_machinery import (
10
+ FunctionPass,
11
+ AnalysisPass,
12
+ SSACompliantMixin,
13
+ register_pass,
14
+ )
15
+ from numba.cuda.core import postproc, bytecode, transforms, inline_closurecall
16
+ from numba.core import (
17
+ errors,
18
+ types,
19
+ ir,
20
+ )
21
+ from numba.cuda.core import consts, rewrites, config
22
+ from numba.cuda.core.interpreter import Interpreter
23
+
24
+
25
+ from numba.misc.special import literal_unroll
26
+ from numba.cuda.core.analysis import dead_branch_prune
27
+ from numba.core.analysis import (
28
+ rewrite_semantic_constants,
29
+ find_literally_calls,
30
+ compute_cfg_from_blocks,
31
+ compute_use_defs,
32
+ )
33
+ from numba.cuda.core.ir_utils import (
34
+ guard,
35
+ resolve_func_from_module,
36
+ simplify_CFG,
37
+ GuardException,
38
+ convert_code_obj_to_function,
39
+ build_definitions,
40
+ replace_var_names,
41
+ get_name_var_table,
42
+ compile_to_numba_ir,
43
+ get_definition,
44
+ find_max_label,
45
+ rename_labels,
46
+ transfer_scope,
47
+ fixup_var_define_in_scope,
48
+ )
49
+ from numba.cuda.core.ssa import reconstruct_ssa
50
+
51
+
52
@contextmanager
def fallback_context(state, msg):
    """
    Context manager wrapping code that may signal a fallback to object mode.

    If the wrapped code raises and the compilation status allows a fallback,
    a NumbaWarning describing the fallback (and the triggering error) is
    emitted before the exception is re-raised; otherwise the exception
    propagates untouched.
    """
    try:
        yield
    except Exception as e:
        if not state.status.can_fallback:
            raise
        # Clear all references attached to the traceback.
        e = e.with_traceback(None)
        # This emits a warning containing the error message body in the
        # case of fallback from npm to objmode.
        looplift = "" if state.flags.enable_looplift else "OUT"
        msg_rewrite = (
            "\nCompilation is falling back to object mode "
            "WITH%s looplifting enabled because %s" % (looplift, msg)
        )
        warnings.warn_explicit(
            "%s due to: %s" % (msg_rewrite, e),
            errors.NumbaWarning,
            state.func_id.filename,
            state.func_id.firstlineno,
        )
        raise
80
+
81
@register_pass(mutates_CFG=True, analysis_only=False)
class ExtractByteCode(FunctionPass):
    """Extract the bytecode of the function under compilation."""

    _name = "extract_bytecode"

    def __init__(self):
        FunctionPass.__init__(self)

    def run_pass(self, state):
        """Extract bytecode from the function and store it on ``state``."""
        bc = bytecode.ByteCode(state["func_id"])
        if config.DUMP_BYTECODE:
            # Honor the debugging knob that dumps raw bytecode to stdout.
            print(bc.dump())
        state["bc"] = bc
        return True
100
+
101
@register_pass(mutates_CFG=True, analysis_only=False)
class TranslateByteCode(FunctionPass):
    """Translate the previously-extracted bytecode into Numba IR."""

    _name = "translate_bytecode"

    def __init__(self):
        FunctionPass.__init__(self)

    def run_pass(self, state):
        """Analyze bytecode and translate it to Numba IR."""
        interp = Interpreter(state["func_id"])
        state["func_ir"] = interp.interpret(state["bc"])
        return True
119
+
120
@register_pass(mutates_CFG=True, analysis_only=False)
class FixupArgs(FunctionPass):
    """Reconcile the supplied argument types with the function's arity."""

    _name = "fixup_args"

    def __init__(self):
        FunctionPass.__init__(self)

    def run_pass(self, state):
        nargs = state["func_ir"].arg_count
        state["nargs"] = nargs
        given = state["args"]
        if not given and state["flags"].force_pyobject:
            # Allow an empty argument types specification when object mode
            # is explicitly requested: treat every argument as a pyobject.
            state["args"] = (types.pyobject,) * nargs
        elif len(given) != nargs:
            raise TypeError(
                "Signature mismatch: %d argument types given, "
                "but function takes %d arguments" % (len(given), nargs)
            )
        return True
141
+
142
@register_pass(mutates_CFG=True, analysis_only=False)
class IRProcessing(FunctionPass):
    """Run the IR post-processor over freshly built Numba IR."""

    _name = "ir_processing"

    def __init__(self):
        FunctionPass.__init__(self)

    def run_pass(self, state):
        func_ir = state["func_ir"]
        postproc.PostProcessor(func_ir).run()

        if config.DEBUG or config.DUMP_IR:
            # Debug aid: dump the processed IR (and generator info when the
            # function is a generator).
            qualname = func_ir.func_id.func_qualname
            print(("IR DUMP: %s" % qualname).center(80, "-"))
            func_ir.dump()
            if func_ir.is_generator:
                print(("GENERATOR INFO: %s" % qualname).center(80, "-"))
                func_ir.dump_generator_info()
        return True
163
+
164
@register_pass(mutates_CFG=True, analysis_only=False)
class RewriteSemanticConstants(FunctionPass):
    """Rewrite semantic constants into the IR prior to type inference."""

    _name = "rewrite_semantic_constants"

    def __init__(self):
        FunctionPass.__init__(self)

    def run_pass(self, state):
        """
        Apply ``rewrite_semantic_constants`` to the function IR given the
        argument types; errors fall back to object mode when permitted.
        """
        assert state.func_ir
        error_context = (
            "Internal error in pre-inference dead branch pruning "
            "pass encountered during compilation of "
            'function "%s"' % (state.func_id.func_name,)
        )
        with fallback_context(state, error_context):
            rewrite_semantic_constants(state.func_ir, state.args)

        return True
187
+
188
@register_pass(mutates_CFG=True, analysis_only=False)
class DeadBranchPrune(SSACompliantMixin, FunctionPass):
    # Prunes branches provably not taken based on const/literal evaluation
    # of the argument types. Depends on RewriteSemanticConstants (declared
    # in get_analysis_usage below).
    _name = "dead_branch_prune"

    def __init__(self):
        FunctionPass.__init__(self)

    def run_pass(self, state):
        """
        This prunes dead branches, a dead branch is one which is derivable as
        not taken at compile time purely based on const/literal evaluation.
        """

        # purely for demonstration purposes, obtain the analysis from a pass
        # declare as a required dependent
        semantic_const_analysis = self.get_analysis(type(self))  # noqa

        assert state.func_ir
        msg = (
            "Internal error in pre-inference dead branch pruning "
            "pass encountered during compilation of "
            'function "%s"' % (state.func_id.func_name,)
        )
        # Any failure here may fall back to object mode when permitted.
        with fallback_context(state, msg):
            dead_branch_prune(state.func_ir, state.args)

        return True

    def get_analysis_usage(self, AU):
        # Declare RewriteSemanticConstants as a required dependency so the
        # pass manager schedules it before this pass.
        AU.add_required(RewriteSemanticConstants)
219
+
220
@register_pass(mutates_CFG=True, analysis_only=False)
class InlineClosureLikes(FunctionPass):
    # Runs InlineClosureCallPass over the function IR to inline
    # closure-like call sites, then repairs the IR afterwards.
    _name = "inline_closure_likes"

    def __init__(self):
        FunctionPass.__init__(self)

    def run_pass(self, state):
        # Ensure we have an IR and type information.
        assert state.func_ir

        # if the return type is a pyobject, there's no type info available and
        # no ability to resolve certain typed function calls in the array
        # inlining code, use this variable to indicate
        typed_pass = not isinstance(state.return_type, types.misc.PyObject)

        inline_pass = inline_closurecall.InlineClosureCallPass(
            state.func_ir,
            state.flags.auto_parallel,
            None,
            typed_pass,
        )
        inline_pass.run()

        # Remove all Dels, and re-run postproc
        post_proc = postproc.PostProcessor(state.func_ir)
        post_proc.run()

        # Repair scope/variable-definition bookkeeping after the block
        # mutations performed by inlining.
        fixup_var_define_in_scope(state.func_ir.blocks)

        return True
252
+
253
@register_pass(mutates_CFG=True, analysis_only=False)
class GenericRewrites(FunctionPass):
    """Apply all registered 'before-inference' IR rewrites."""

    _name = "generic_rewrites"

    def __init__(self):
        FunctionPass.__init__(self)

    def run_pass(self, state):
        """
        Perform any intermediate representation rewrites before type
        inference.
        """
        assert state.func_ir
        error_context = (
            "Internal error in pre-inference rewriting "
            "pass encountered during compilation of "
            'function "%s"' % (state.func_id.func_name,)
        )
        # Failures may fall back to object mode when permitted.
        with fallback_context(state, error_context):
            rewrites.rewrite_registry.apply("before-inference", state)
        return True
275
+
276
@register_pass(mutates_CFG=True, analysis_only=False)
class WithLifting(FunctionPass):
    # Extracts `with` contexts out of the function; when any are found the
    # remaining "main" function is compiled immediately and the pipeline is
    # short-circuited via _EarlyPipelineCompletion.
    _name = "with_lifting"

    def __init__(self):
        FunctionPass.__init__(self)

    def run_pass(self, state):
        """
        Extract with-contexts
        """
        main, withs = transforms.with_lifting(
            func_ir=state.func_ir,
            typingctx=state.typingctx,
            targetctx=state.targetctx,
            flags=state.flags,
            locals=state.locals,
        )
        if withs:
            # Imported locally, presumably to avoid a circular import at
            # module load time -- TODO confirm.
            from numba.cuda.compiler import compile_ir
            from numba.cuda.core.compiler import _EarlyPipelineCompletion

            cres = compile_ir(
                state.typingctx,
                state.targetctx,
                main,
                state.args,
                state.return_type,
                state.flags,
                state.locals,
                lifted=tuple(withs),
                lifted_from=None,
                pipeline_class=type(state.pipeline),
            )
            # Signal that compilation completed early with this result; the
            # surrounding pipeline machinery catches this exception.
            raise _EarlyPipelineCompletion(cres)
        return True
313
+
314
@register_pass(mutates_CFG=True, analysis_only=False)
class InlineInlinables(FunctionPass):
    """
    This pass will inline a function wrapped by the numba.jit decorator directly
    into the site of its call depending on the value set in the 'inline' kwarg
    to the decorator.

    This is an untyped pass. CFG simplification is performed at the end of the
    pass but no block level clean up is performed on the mutated IR (typing
    information is not available to do so).
    """

    _name = "inline_inlinables"
    # Flip to True to print the IR before/after inlining for debugging.
    _DEBUG = False

    def __init__(self):
        FunctionPass.__init__(self)

    def run_pass(self, state):
        """Run inlining of inlinables"""
        if self._DEBUG:
            print("before inline".center(80, "-"))
            print(state.func_ir.dump())
            print("".center(80, "-"))

        inline_worker = inline_closurecall.InlineWorker(
            state.typingctx,
            state.targetctx,
            state.locals,
            state.pipeline,
            state.flags,
            validator=inline_closurecall.callee_ir_validator,
        )

        modified = False
        # use a work list, look for call sites via `ir.Expr.op == call` and
        # then pass these to `self._do_work` to make decisions about inlining.
        work_list = list(state.func_ir.blocks.items())
        while work_list:
            label, block = work_list.pop()
            for i, instr in enumerate(block.body):
                if isinstance(instr, ir.Assign):
                    expr = instr.value
                    if isinstance(expr, ir.Expr) and expr.op == "call":
                        # guard() converts a GuardException raised inside
                        # _do_work into a None/no-op result.
                        if guard(
                            self._do_work,
                            state,
                            work_list,
                            block,
                            i,
                            expr,
                            inline_worker,
                        ):
                            modified = True
                            break  # because block structure changed

        if modified:
            # clean up unconditional branches that appear due to inlined
            # functions introducing blocks
            cfg = compute_cfg_from_blocks(state.func_ir.blocks)
            for dead in cfg.dead_nodes():
                del state.func_ir.blocks[dead]
            post_proc = postproc.PostProcessor(state.func_ir)
            post_proc.run()
            state.func_ir.blocks = simplify_CFG(state.func_ir.blocks)

        if self._DEBUG:
            print("after inline".center(80, "-"))
            print(state.func_ir.dump())
            print("".center(80, "-"))
        return True

    def _do_work(self, state, work_list, block, i, expr, inline_worker):
        """Decide whether the call expression ``expr`` at ``block.body[i]``
        targets an inlinable jit-decorated function and, if so, inline it.

        Returns True when inlining happened (freshly created blocks are
        appended to ``work_list``), False when the call is not inlinable.
        """
        from numba.cuda.compiler import run_frontend
        from numba.cuda.core.options import InlineOptions

        # try and get a definition for the call, this isn't always possible as
        # it might be a eval(str)/part generated awaiting update etc. (parfors)
        to_inline = None
        try:
            to_inline = state.func_ir.get_definition(expr.func)
        except Exception:
            if self._DEBUG:
                print("Cannot find definition for %s" % expr.func)
            return False
        # do not handle closure inlining here, another pass deals with that.
        if getattr(to_inline, "op", False) == "make_function":
            return False

        # see if the definition is a "getattr", in which case walk the IR to
        # try and find the python function via the module from which it's
        # imported, this should all be encoded in the IR.
        if getattr(to_inline, "op", False) == "getattr":
            val = resolve_func_from_module(state.func_ir, to_inline)
        else:
            # This is likely a freevar or global
            #
            # NOTE: getattr 'value' on a call may fail if it's an ir.Expr as
            # getattr is overloaded to look in _kws.
            try:
                val = getattr(to_inline, "value", False)
            except Exception:
                raise GuardException

        # if something was found...
        if val:
            # check it's dispatcher-like, the targetoptions attr holds the
            # kwargs supplied in the jit decorator and is where 'inline' will
            # be if it is present.
            topt = getattr(val, "targetoptions", False)
            if topt:
                inline_type = topt.get("inline", None)
                # has 'inline' been specified?
                if inline_type is not None:
                    inline_opt = InlineOptions(inline_type)
                    # Could this be inlinable?
                    if not inline_opt.is_never_inline:
                        # yes, it could be inlinable
                        do_inline = True
                        pyfunc = val.py_func
                        # Has it got an associated cost model?
                        if inline_opt.has_cost_model:
                            # yes, it has a cost model, use it to determine
                            # whether to do the inline
                            py_func_ir = run_frontend(pyfunc)
                            do_inline = inline_type(
                                expr, state.func_ir, py_func_ir
                            )
                        # if do_inline is True then inline!
                        if do_inline:
                            _, _, _, new_blocks = inline_worker.inline_function(
                                state.func_ir,
                                block,
                                i,
                                pyfunc,
                            )
                            if work_list is not None:
                                for blk in new_blocks:
                                    work_list.append(blk)
                            return True
        return False
456
+
457
@register_pass(mutates_CFG=False, analysis_only=False)
class PreserveIR(AnalysisPass):
    """Stash a copy of the current Numba IR into the compilation metadata."""

    _name = "preserve_ir"

    def __init__(self):
        AnalysisPass.__init__(self)

    def run_pass(self, state):
        # Copy so later pipeline mutations of func_ir don't affect the
        # preserved snapshot.
        state.metadata["preserved_ir"] = state.func_ir.copy()
        return False
472
+
473
@register_pass(mutates_CFG=False, analysis_only=True)
class FindLiterallyCalls(FunctionPass):
    """Find calls to `numba.literally()` and signal if its requirement is not
    satisfied.
    """

    _name = "find_literally"

    def __init__(self):
        FunctionPass.__init__(self)

    def run_pass(self, state):
        # Analysis-only: raises/signals via find_literally_calls, never
        # mutates the IR.
        find_literally_calls(state.func_ir, state.args)
        return False
488
+
489
@register_pass(mutates_CFG=True, analysis_only=False)
class CanonicalizeLoopExit(FunctionPass):
    """A pass to canonicalize loop exit by splitting it from function exit."""

    _name = "canonicalize_loop_exit"

    def __init__(self):
        FunctionPass.__init__(self)

    def run_pass(self, state):
        """Split every loop-exit block that is also a function exit point.

        Returns True if the IR was mutated.
        """
        fir = state.func_ir
        cfg = compute_cfg_from_blocks(fir.blocks)
        status = False
        for loop in cfg.loops().values():
            for exit_label in loop.exits:
                if exit_label in cfg.exit_points():
                    self._split_exit_block(fir, cfg, exit_label)
                    status = True

        fir._reset_analysis_variables()

        # Recompute variable lifetimes invalidated by the block changes.
        vlt = postproc.VariableLifetime(fir.blocks)
        fir.variable_lifetime = vlt
        return status

    def _split_exit_block(self, fir, cfg, exit_label):
        """Move the exit block's body under a fresh label and replace the
        original block with a jump to it."""
        curblock = fir.blocks[exit_label]
        # Fix: the original assigned ``exit_label + 1`` and immediately
        # overwrote it; compute the guaranteed-unused label directly.
        newlabel = find_max_label(fir.blocks) + 1
        fir.blocks[newlabel] = curblock
        newblock = ir.Block(scope=curblock.scope, loc=curblock.loc)
        newblock.append(ir.Jump(newlabel, loc=curblock.loc))
        fir.blocks[exit_label] = newblock
        # Rename all labels
        fir.blocks = rename_labels(fir.blocks)
525
+
526
@register_pass(mutates_CFG=True, analysis_only=False)
class CanonicalizeLoopEntry(FunctionPass):
    """A pass to canonicalize loop header by splitting it from function entry.

    This is needed for loop-lifting; esp in py3.8
    """

    _name = "canonicalize_loop_entry"
    # Only these builtin iterator factories are recognized when locating
    # the start of the loop-setup statements in the entry block.
    _supported_globals = {range, enumerate, zip}

    def __init__(self):
        FunctionPass.__init__(self)

    def run_pass(self, state):
        # Returns True when any entry block was split (IR mutated).
        fir = state.func_ir
        cfg = compute_cfg_from_blocks(fir.blocks)
        status = False
        for loop in cfg.loops().values():
            if len(loop.entries) == 1:
                [entry_label] = loop.entries
                if entry_label == cfg.entry_point():
                    self._split_entry_block(fir, cfg, loop, entry_label)
                    status = True
        fir._reset_analysis_variables()

        # Recompute variable lifetimes after the block-structure changes.
        vlt = postproc.VariableLifetime(fir.blocks)
        fir.variable_lifetime = vlt
        return status

    def _split_entry_block(self, fir, cfg, loop, entry_label):
        # Find iterator inputs into the for-loop header
        header_block = fir.blocks[loop.header]
        deps = set()
        for expr in header_block.find_exprs(op="iternext"):
            deps.add(expr.value)
        # Find the getiter for each iterator
        entry_block = fir.blocks[entry_label]

        # Find the start of loop entry statement that needs to be included.
        # Walk assignments backwards, growing `deps` with temporaries that
        # feed the iterator, and remember the earliest statement (`startpt`)
        # that belongs to the loop setup.
        startpt = None
        list_of_insts = list(entry_block.find_insts(ir.Assign))
        for assign in reversed(list_of_insts):
            if assign.target in deps:
                rhs = assign.value
                if isinstance(rhs, ir.Var):
                    if rhs.is_temp:
                        deps.add(rhs)
                elif isinstance(rhs, ir.Expr):
                    expr = rhs
                    if expr.op == "getiter":
                        startpt = assign
                        if expr.value.is_temp:
                            deps.add(expr.value)
                    elif expr.op == "call":
                        defn = guard(get_definition, fir, expr.func)
                        if isinstance(defn, ir.Global):
                            if expr.func.is_temp:
                                deps.add(expr.func)
                elif (
                    isinstance(rhs, ir.Global)
                    and rhs.value in self._supported_globals
                ):
                    startpt = assign

        if startpt is None:
            # No recognizable loop-setup statement found; leave IR as-is.
            return

        # Split at `startpt`: everything from it onwards moves into a new
        # block, and the shortened entry block jumps to it.
        splitpt = entry_block.body.index(startpt)
        new_block = entry_block.copy()
        new_block.body = new_block.body[splitpt:]
        new_block.loc = new_block.body[0].loc
        new_label = find_max_label(fir.blocks) + 1
        entry_block.body = entry_block.body[:splitpt]
        entry_block.append(ir.Jump(new_label, loc=new_block.loc))

        fir.blocks[new_label] = new_block
        # Rename all labels
        fir.blocks = rename_labels(fir.blocks)
605
+
606
@register_pass(mutates_CFG=False, analysis_only=True)
class PrintIRCFG(FunctionPass):
    """Render the function IR's CFG to a dot file, versioned per run."""

    _name = "print_ir_cfg"

    def __init__(self):
        FunctionPass.__init__(self)
        # Monotonic counter so each rendering gets a distinct file prefix.
        self._ver = 0

    def run_pass(self, state):
        self._ver += 1
        prefix = "v{}".format(self._ver)
        state.func_ir.render_dot(filename_prefix=prefix).render()
        return False
620
+
621
@register_pass(mutates_CFG=True, analysis_only=False)
class MakeFunctionToJitFunction(FunctionPass):
    """
    This swaps an ir.Expr.op == "make_function" i.e. a closure, for a compiled
    function containing the closure body and puts it in ir.Global. It's a 1:1
    statement value swap. `make_function` is already untyped
    """

    _name = "make_function_op_code_to_jit_function"

    def __init__(self):
        FunctionPass.__init__(self)

    def run_pass(self, state):
        # Returns True when at least one make_function was swapped.
        from numba import njit

        func_ir = state.func_ir
        mutated = False
        for idx, blk in func_ir.blocks.items():
            for stmt in blk.body:
                if isinstance(stmt, ir.Assign):
                    if isinstance(stmt.value, ir.Expr):
                        if stmt.value.op == "make_function":
                            node = stmt.value
                            getdef = func_ir.get_definition
                            kw_default = getdef(node.defaults)
                            # The swap is only performed when every default
                            # value of the closure resolves to a constant.
                            ok = False
                            if kw_default is None or isinstance(
                                kw_default, ir.Const
                            ):
                                ok = True
                            elif isinstance(kw_default, tuple):
                                ok = all(
                                    [
                                        isinstance(getdef(x), ir.Const)
                                        for x in kw_default
                                    ]
                                )
                            elif isinstance(kw_default, ir.Expr):
                                if kw_default.op != "build_tuple":
                                    continue
                                ok = all(
                                    [
                                        isinstance(getdef(x), ir.Const)
                                        for x in kw_default.items
                                    ]
                                )
                            if not ok:
                                continue

                            pyfunc = convert_code_obj_to_function(node, func_ir)
                            func = njit()(pyfunc)
                            new_node = ir.Global(
                                node.code.co_name, func, stmt.loc
                            )
                            stmt.value = new_node
                            mutated |= True

        # if a change was made the del ordering is probably wrong, patch up
        if mutated:
            post_proc = postproc.PostProcessor(func_ir)
            post_proc.run()

        return mutated
686
+
687
@register_pass(mutates_CFG=True, analysis_only=False)
class TransformLiteralUnrollConstListToTuple(FunctionPass):
    """This pass spots a `literal_unroll([<constant values>])` and rewrites it
    as a `literal_unroll(tuple(<constant values>))`.
    """

    _name = "transform_literal_unroll_const_list_to_tuple"

    # types that are acceptable as-is as the literal_unroll argument
    _accepted_types = (types.BaseTuple, types.LiteralList)

    def __init__(self):
        # Plain FunctionPass initialisation; this pass carries no extra state.
        FunctionPass.__init__(self)

    def run_pass(self, state):
        """Rewrite constant build_list args of literal_unroll to build_tuple.

        Walks all call expressions; for each call to the `literal_unroll`
        global, validates its single argument and, when the argument is a
        build_list of all-constant items, replaces that list construction
        with an equivalent build_tuple in place. Raises UnsupportedError
        for malformed uses and CompilerError if the defining assignment
        cannot be located. Returns True if any rewrite happened.
        """
        mutated = False
        func_ir = state.func_ir
        for label, blk in func_ir.blocks.items():
            calls = [_ for _ in blk.find_exprs("call")]
            for call in calls:
                glbl = guard(get_definition, func_ir, call.func)
                if glbl and isinstance(glbl, (ir.Global, ir.FreeVar)):
                    # find a literal_unroll
                    if glbl.value is literal_unroll:
                        if len(call.args) > 1:
                            msg = "literal_unroll takes one argument, found %s"
                            raise errors.UnsupportedError(
                                msg % len(call.args), call.loc
                            )
                        # get the arg, make sure its a build_list
                        unroll_var = call.args[0]
                        to_unroll = guard(get_definition, func_ir, unroll_var)
                        if (
                            isinstance(to_unroll, ir.Expr)
                            and to_unroll.op == "build_list"
                        ):
                            # make sure they are all const items in the list
                            for i, item in enumerate(to_unroll.items):
                                val = guard(get_definition, func_ir, item)
                                if not val:
                                    msg = (
                                        "multiple definitions for variable "
                                        "%s, cannot resolve constant"
                                    )
                                    raise errors.UnsupportedError(
                                        msg % item, to_unroll.loc
                                    )
                                if not isinstance(val, ir.Const):
                                    msg = (
                                        "Found non-constant value at "
                                        "position %s in a list argument to "
                                        "literal_unroll" % i
                                    )
                                    raise errors.UnsupportedError(
                                        msg, to_unroll.loc
                                    )
                            # The above appears ok, now swap the build_list for
                            # a built tuple.

                            # find the assignment for the unroll target
                            to_unroll_lhs = guard(
                                get_definition,
                                func_ir,
                                unroll_var,
                                lhs_only=True,
                            )

                            if to_unroll_lhs is None:
                                msg = (
                                    "multiple definitions for variable "
                                    "%s, cannot resolve constant"
                                )
                                raise errors.UnsupportedError(
                                    msg % unroll_var, to_unroll.loc
                                )
                            # scan all blocks looking for the LHS
                            for b in func_ir.blocks.values():
                                asgn = b.find_variable_assignment(
                                    to_unroll_lhs.name
                                )
                                if asgn is not None:
                                    break
                            else:
                                # for/else: no block defined the variable
                                msg = (
                                    "Cannot find assignment for known "
                                    "variable %s"
                                ) % to_unroll_lhs.name
                                raise errors.CompilerError(msg, to_unroll.loc)

                            # Create a tuple with the list items as contents
                            tup = ir.Expr.build_tuple(
                                to_unroll.items, to_unroll.loc
                            )

                            # swap the list for the tuple
                            asgn.value = tup
                            mutated = True
                        elif (
                            isinstance(to_unroll, ir.Expr)
                            and to_unroll.op == "build_tuple"
                        ):
                            # this is fine, do nothing
                            pass
                        elif isinstance(
                            to_unroll, (ir.Global, ir.FreeVar)
                        ) and isinstance(to_unroll.value, tuple):
                            # this is fine, do nothing
                            pass
                        elif isinstance(to_unroll, ir.Arg):
                            # this is only fine if the arg is a tuple
                            ty = state.typemap[to_unroll.name]
                            if not isinstance(ty, self._accepted_types):
                                msg = (
                                    "Invalid use of literal_unroll with a "
                                    "function argument, only tuples are "
                                    "supported as function arguments, found "
                                    "%s"
                                ) % ty
                                raise errors.UnsupportedError(
                                    msg, to_unroll.loc
                                )
                        else:
                            # Fallthrough: try to produce a helpful diagnostic
                            # for the unsupported argument form.
                            extra = None
                            if isinstance(to_unroll, ir.Expr):
                                # probably a slice
                                if to_unroll.op == "getitem":
                                    ty = state.typemap[to_unroll.value.name]
                                    # check if this is a tuple slice
                                    if not isinstance(ty, self._accepted_types):
                                        extra = "operation %s" % to_unroll.op
                                        loc = to_unroll.loc
                            elif isinstance(to_unroll, ir.Arg):
                                extra = "non-const argument %s" % to_unroll.name
                                loc = to_unroll.loc
                            else:
                                if to_unroll is None:
                                    extra = (
                                        "multiple definitions of "
                                        'variable "%s".' % unroll_var.name
                                    )
                                    loc = unroll_var.loc
                                else:
                                    loc = to_unroll.loc
                                    extra = "unknown problem"

                            if extra:
                                msg = (
                                    "Invalid use of literal_unroll, "
                                    "argument should be a tuple or a list "
                                    "of constant values. Failure reason: "
                                    "found %s" % extra
                                )
                                raise errors.UnsupportedError(msg, loc)
        return mutated
840
+
841
+
842
@register_pass(mutates_CFG=True, analysis_only=False)
class MixedContainerUnroller(FunctionPass):
    """Unrolls loops over heterogeneous (mixed-type) containers driven by
    `literal_unroll`, by versioning the loop body per contained type and
    dispatching via a generated switch table.
    """

    _name = "mixed_container_unroller"

    # Set True to dump the IR at key points of the transform.
    _DEBUG = False

    # container types this pass is prepared to unroll
    _accepted_types = (types.BaseTuple, types.LiteralList)

    def __init__(self):
        # Plain FunctionPass initialisation; this pass carries no extra state.
        FunctionPass.__init__(self)

    def analyse_tuple(self, tup):
        """
        Returns a map of type->list(indexes) for a typed tuple
        """
        d = defaultdict(list)
        for i, ty in enumerate(tup):
            d[ty].append(i)
        return d

    def add_offset_to_labels_w_ignore(self, blocks, offset, ignore=None):
        """add an offset to all block labels and jump/branch targets
        don't add an offset to anything in the ignore list
        """
        if ignore is None:
            ignore = set()

        new_blocks = {}
        for l, b in blocks.items():
            # some parfor last blocks might be empty
            term = None
            if b.body:
                term = b.body[-1]
            if isinstance(term, ir.Jump):
                if term.target not in ignore:
                    b.body[-1] = ir.Jump(term.target + offset, term.loc)
            if isinstance(term, ir.Branch):
                if term.truebr not in ignore:
                    new_true = term.truebr + offset
                else:
                    new_true = term.truebr

                if term.falsebr not in ignore:
                    new_false = term.falsebr + offset
                else:
                    new_false = term.falsebr
                b.body[-1] = ir.Branch(term.cond, new_true, new_false, term.loc)
            new_blocks[l + offset] = b
        return new_blocks

    def inject_loop_body(
        self, switch_ir, loop_ir, caller_max_label, dont_replace, switch_data
    ):
        """
        Injects the "loop body" held in `loop_ir` into `switch_ir` where ever
        there is a statement of the form `SENTINEL.<int> = RHS`. It also:
        * Finds and then deliberately does not relabel non-local jumps so as to
          make the switch table suitable for injection into the IR from which
          the loop body was derived.
        * Looks for `typed_getitem` and wires them up to loop body version
          specific variables or, if possible, directly writes in their constant
          value at their use site.

        Args:
        - switch_ir, the switch table with SENTINELS as generated by
          self.gen_switch
        - loop_ir, the IR of the loop blocks (derived from the original func_ir)
        - caller_max_label, the maximum label in the func_ir caller
        - dont_replace, variables that should not be renamed (to handle
          references to variables that are incoming at the loop head/escaping at
          the loop exit.
        - switch_data, the switch table data used to generated the switch_ir,
          can be generated by self.analyse_tuple.

        Returns:
        - A type specific switch table with each case containing a versioned
          loop body suitable for injection as a replacement for the loop_ir.
        """

        # Switch IR came from code gen, immediately relabel to prevent
        # collisions with IR derived from the user code (caller)
        switch_ir.blocks = self.add_offset_to_labels_w_ignore(
            switch_ir.blocks, caller_max_label + 1
        )

        # Find the sentinels and validate the form
        sentinel_exits = set()
        sentinel_blocks = []
        for lbl, blk in switch_ir.blocks.items():
            for i, stmt in enumerate(blk.body):
                if isinstance(stmt, ir.Assign):
                    if "SENTINEL" in stmt.target.name:
                        sentinel_blocks.append(lbl)
                        sentinel_exits.add(blk.body[-1].target)
                        break

        assert len(sentinel_exits) == 1  # should only be 1 exit
        switch_ir.blocks.pop(sentinel_exits.pop())  # kill the exit, it's dead

        # find jumps that are non-local, we won't relabel these
        ignore_set = set()
        local_lbl = [x for x in loop_ir.blocks.keys()]
        for lbl, blk in loop_ir.blocks.items():
            for i, stmt in enumerate(blk.body):
                if isinstance(stmt, ir.Jump):
                    if stmt.target not in local_lbl:
                        ignore_set.add(stmt.target)
                if isinstance(stmt, ir.Branch):
                    if stmt.truebr not in local_lbl:
                        ignore_set.add(stmt.truebr)
                    if stmt.falsebr not in local_lbl:
                        ignore_set.add(stmt.falsebr)

        # make sure the generated switch table matches the switch data
        assert len(sentinel_blocks) == len(switch_data)

        # replace the sentinel_blocks with the loop body
        for lbl, branch_ty in zip(sentinel_blocks, switch_data.keys()):
            loop_blocks = deepcopy(loop_ir.blocks)
            # relabel blocks WRT switch table, each block replacement will shift
            # the maximum label
            max_label = max(switch_ir.blocks.keys())
            loop_blocks = self.add_offset_to_labels_w_ignore(
                loop_blocks, max_label + 1, ignore_set
            )

            # start label
            loop_start_lbl = min(loop_blocks.keys())

            # fix the typed_getitem locations in the loop blocks
            for blk in loop_blocks.values():
                new_body = []
                for stmt in blk.body:
                    if isinstance(stmt, ir.Assign):
                        if (
                            isinstance(stmt.value, ir.Expr)
                            and stmt.value.op == "typed_getitem"
                        ):
                            if isinstance(branch_ty, types.Literal):
                                # literal value: write the constant straight
                                # into this loop body version
                                scope = switch_ir.blocks[lbl].scope
                                new_const_name = scope.redefine(
                                    "branch_const", stmt.loc
                                ).name
                                new_const_var = ir.Var(
                                    blk.scope, new_const_name, stmt.loc
                                )
                                new_const_val = ir.Const(
                                    branch_ty.literal_value, stmt.loc
                                )
                                const_assign = ir.Assign(
                                    new_const_val, new_const_var, stmt.loc
                                )
                                new_assign = ir.Assign(
                                    new_const_var, stmt.target, stmt.loc
                                )
                                new_body.append(const_assign)
                                new_body.append(new_assign)
                                dont_replace.append(new_const_name)
                            else:
                                # non-literal: retype the typed_getitem with
                                # this branch's concrete dtype
                                orig = stmt.value
                                new_typed_getitem = ir.Expr.typed_getitem(
                                    value=orig.value,
                                    dtype=branch_ty,
                                    index=orig.index,
                                    loc=orig.loc,
                                )
                                new_assign = ir.Assign(
                                    new_typed_getitem, stmt.target, stmt.loc
                                )
                                new_body.append(new_assign)
                        else:
                            new_body.append(stmt)
                    else:
                        new_body.append(stmt)
                blk.body = new_body

            # rename
            var_table = get_name_var_table(loop_blocks)
            drop_keys = []
            for k, v in var_table.items():
                if v.name in dont_replace:
                    drop_keys.append(k)
            for k in drop_keys:
                var_table.pop(k)

            new_var_dict = {}
            for name, var in var_table.items():
                scope = switch_ir.blocks[lbl].scope
                try:
                    scope.get_exact(name)
                except errors.NotDefinedError:
                    # In case the scope doesn't have the variable, we need to
                    # define it prior creating new copies of it! This is
                    # because the scope of the function and the scope of the
                    # loop are different and the variable needs to be redefined
                    # within the scope of the loop.
                    scope.define(name, var.loc)
                new_var_dict[name] = scope.redefine(name, var.loc).name
            replace_var_names(loop_blocks, new_var_dict)

            # clobber the sentinel body and then stuff in the rest
            switch_ir.blocks[lbl] = deepcopy(loop_blocks[loop_start_lbl])
            remaining_keys = [y for y in loop_blocks.keys()]
            remaining_keys.remove(loop_start_lbl)
            for k in remaining_keys:
                switch_ir.blocks[k] = deepcopy(loop_blocks[k])

        if self._DEBUG:
            print("-" * 80 + "EXIT STUFFER")
            switch_ir.dump()
            print("-" * 80)

        return switch_ir

    def gen_switch(self, data, index):
        """
        Generates a function with a switch table like
        def foo():
            if PLACEHOLDER_INDEX in (<integers>):
                SENTINEL = None
            elif PLACEHOLDER_INDEX in (<integers>):
                SENTINEL = None
            ...
            else:
                raise RuntimeError

        The data is a map of (type : indexes) for example:
        (int64, int64, float64)
        might give:
        {int64: [0, 1], float64: [2]}

        The index is the index variable for the driving range loop over the
        mixed tuple.
        """
        elif_tplt = "\n\telif PLACEHOLDER_INDEX in (%s,):\n\t\tSENTINEL = None"

        # Note regarding the insertion of the garbage/defeat variables below:
        # These values have been designed and inserted to defeat a specific
        # behaviour of the cpython optimizer. The optimization was introduced
        # in Python 3.10.

        # The URL for the BPO is:
        # https://bugs.python.org/issue44626
        # The code for the optimization can be found at:
        # https://github.com/python/cpython/blob/d41abe8/Python/compile.c#L7533-L7557

        # Essentially the CPython optimizer will inline the exit block under
        # certain circumstances and thus replace the jump with a return if the
        # exit block is small enough. This is an issue for unroller, as it
        # looks for a jump, not a return, when it inserts the generated switch
        # table.

        # Part of the condition for this optimization to be applied is that the
        # exit block not exceed a certain (4 at the time of writing) number of
        # bytecode instructions. We defeat the optimizer by inserting a
        # sufficient number of instructions so that the exit block is big
        # enough. We don't care about this garbage, because the generated exit
        # block is discarded anyway when we smash the switch table into the
        # original function and so all the inserted garbage is dropped again.

        # The final lines of the stacktrace w/o this will look like:
        #
        # File "/numba/numba/core/untyped_passes.py", line 830, \
        #     in inject_loop_body
        #   sentinel_exits.add(blk.body[-1].target)
        # AttributeError: Failed in nopython mode pipeline \
        #     (step: handles literal_unroll)
        # Failed in literal_unroll_subpipeline mode pipeline \
        #     (step: performs mixed container unroll)
        # 'Return' object has no attribute 'target'
        #
        # Which indicates that a Return has been found instead of a Jump

        b = (
            "def foo():\n\tif PLACEHOLDER_INDEX in (%s,):\n\t\t"
            "SENTINEL = None\n%s\n\telse:\n\t\t"
            'raise RuntimeError("Unreachable")\n\t'
            "py310_defeat1 = 1\n\t"
            "py310_defeat2 = 2\n\t"
            "py310_defeat3 = 3\n\t"
            "py310_defeat4 = 4\n\t"
        )
        keys = [k for k in data.keys()]

        elifs = []
        for i in range(1, len(keys)):
            elifs.append(elif_tplt % ",".join(map(str, data[keys[i]])))
        src = b % (",".join(map(str, data[keys[0]])), "".join(elifs))
        wstr = src
        l = {}
        # exec of internally-generated source only; not user input
        exec(wstr, {}, l)
        bfunc = l["foo"]
        branches = compile_to_numba_ir(bfunc, {})
        # patch the PLACEHOLDER_INDEX global references with the real index var
        for lbl, blk in branches.blocks.items():
            for stmt in blk.body:
                if isinstance(stmt, ir.Assign):
                    if isinstance(stmt.value, ir.Global):
                        if stmt.value.name == "PLACEHOLDER_INDEX":
                            stmt.value = index
        return branches

    def apply_transform(self, state):
        """Find one eligible literal_unroll loop and unroll it.

        Returns True if a loop was transformed (callers re-run until False).
        """
        # compute new CFG
        func_ir = state.func_ir
        cfg = compute_cfg_from_blocks(func_ir.blocks)
        # find loops
        loops = cfg.loops()

        # 0. Find the loops containing literal_unroll and store this
        # information
        unroll_info = namedtuple(
            "unroll_info", ["loop", "call", "arg", "getitem"]
        )

        def get_call_args(init_arg, want):
            # Chases the assignment of a called value back through a specific
            # call to a global function "want" and returns the arguments
            # supplied to that function's call
            some_call = get_definition(func_ir, init_arg)
            if not isinstance(some_call, ir.Expr):
                raise GuardException
            if not some_call.op == "call":
                raise GuardException
            the_global = get_definition(func_ir, some_call.func)
            if not isinstance(the_global, ir.Global):
                raise GuardException
            if the_global.value is not want:
                raise GuardException
            return some_call

        def find_unroll_loops(loops):
            """This finds loops which are compliant with the form:
            for i in range(len(literal_unroll(<something>>)))"""
            unroll_loops = {}
            for header_lbl, loop in loops.items():
                # TODO: check the loop head has literal_unroll, if it does but
                # does not conform to the following then raise

                # scan loop header
                iternexts = [
                    _
                    for _ in func_ir.blocks[loop.header].find_exprs("iternext")
                ]
                # needs to be an single iternext driven loop
                if len(iternexts) != 1:
                    continue
                for iternext in iternexts:
                    # Walk the canonicalised loop structure and check it
                    # Check loop form range(literal_unroll(container)))
                    phi = guard(get_definition, func_ir, iternext.value)
                    if phi is None:
                        continue

                    # check call global "range"
                    range_call = guard(get_call_args, phi.value, range)
                    if range_call is None:
                        continue
                    range_arg = range_call.args[0]

                    # check call global "len"
                    len_call = guard(get_call_args, range_arg, len)
                    if len_call is None:
                        continue
                    len_arg = len_call.args[0]

                    # check literal_unroll
                    literal_unroll_call = guard(
                        get_definition, func_ir, len_arg
                    )
                    if literal_unroll_call is None:
                        continue
                    if not isinstance(literal_unroll_call, ir.Expr):
                        continue
                    if literal_unroll_call.op != "call":
                        continue
                    literal_func = getattr(literal_unroll_call, "func", None)
                    if not literal_func:
                        continue
                    call_func = guard(
                        get_definition, func_ir, literal_unroll_call.func
                    )
                    if call_func is None:
                        continue
                    call_func_value = call_func.value

                    if call_func_value is literal_unroll:
                        assert len(literal_unroll_call.args) == 1
                        unroll_loops[loop] = literal_unroll_call
            return unroll_loops

        def ensure_no_nested_unroll(unroll_loops):
            # Validate loop nests, nested literal_unroll loops are unsupported.
            # This doesn't check that there's a getitem or anything else
            # required for the transform to work, simply just that there's no
            # nesting.
            for test_loop in unroll_loops:
                for ref_loop in unroll_loops:
                    if test_loop == ref_loop:  # comparing to self! skip
                        continue
                    if test_loop.header in ref_loop.body:
                        msg = "Nesting of literal_unroll is unsupported"
                        loc = func_ir.blocks[test_loop.header].loc
                        raise errors.UnsupportedError(msg, loc)

        def collect_literal_unroll_info(literal_unroll_loops):
            """Finds the loops induced by `literal_unroll`, returns a list of
            unroll_info namedtuples for use in the transform pass.
            """

            literal_unroll_info = []
            for loop, literal_unroll_call in literal_unroll_loops.items():
                arg = literal_unroll_call.args[0]
                typemap = state.typemap
                resolved_arg = guard(
                    get_definition, func_ir, arg, lhs_only=True
                )
                ty = typemap[resolved_arg.name]
                assert isinstance(ty, self._accepted_types)
                # loop header is spelled ok, now make sure the body
                # actually contains a getitem

                # find a "getitem"... only looks in the blocks that belong
                # _solely_ to this literal_unroll (there should not be nested
                # literal_unroll loops, this is unsupported).
                tuple_getitem = None
                for lbli in loop.body:
                    blk = func_ir.blocks[lbli]
                    for stmt in blk.body:
                        if isinstance(stmt, ir.Assign):
                            if (
                                isinstance(stmt.value, ir.Expr)
                                and stmt.value.op == "getitem"
                            ):
                                # check for something like a[i]
                                if stmt.value.value != arg:
                                    # that failed, so check for the
                                    # definition
                                    dfn = guard(
                                        get_definition,
                                        func_ir,
                                        stmt.value.value,
                                    )
                                    if dfn is None:
                                        continue
                                    try:
                                        args = getattr(dfn, "args", False)
                                    except KeyError:
                                        continue
                                    if not args:
                                        continue
                                    if not args[0] == arg:
                                        continue
                                    target_ty = state.typemap[arg.name]
                                    if not isinstance(
                                        target_ty, self._accepted_types
                                    ):
                                        continue
                                tuple_getitem = stmt
                                break
                    if tuple_getitem:
                        break
                else:
                    continue  # no getitem in this loop

                ui = unroll_info(loop, literal_unroll_call, arg, tuple_getitem)
                literal_unroll_info.append(ui)
            return literal_unroll_info

        # 1. Collect info about the literal_unroll loops, ensure they are legal
        literal_unroll_loops = find_unroll_loops(loops)
        # validate
        ensure_no_nested_unroll(literal_unroll_loops)
        # assemble info
        literal_unroll_info = collect_literal_unroll_info(literal_unroll_loops)
        if not literal_unroll_info:
            return False

        # 2. Do the unroll, get a loop and process it!
        info = literal_unroll_info[0]
        self.unroll_loop(state, info)

        # 3. Rebuild the state, the IR has taken a hammering
        func_ir.blocks = simplify_CFG(func_ir.blocks)
        post_proc = postproc.PostProcessor(func_ir)
        post_proc.run()
        if self._DEBUG:
            print("-" * 80 + "END OF PASS, SIMPLIFY DONE")
            func_ir.dump()
        func_ir._definitions = build_definitions(func_ir.blocks)
        return True

    def unroll_loop(self, state, loop_info):
        """Perform the unroll of a single validated literal_unroll loop."""
        # The general idea here is to:
        # 1. Find *a* getitem that conforms to the literal_unroll semantic,
        #    i.e. one that is targeting a tuple with a loop induced index
        # 2. Compute a structure from the tuple that describes which
        #    iterations of a loop will have which type
        # 3. Generate a switch table in IR form for the structure in 2
        # 4. Switch out getitems for the tuple for a `typed_getitem`
        # 5. Inject switch table as replacement loop body
        # 6. Patch up
        func_ir = state.func_ir
        getitem_target = loop_info.arg
        target_ty = state.typemap[getitem_target.name]
        assert isinstance(target_ty, self._accepted_types)

        # 1. find a "getitem" that conforms
        tuple_getitem = []
        for lbl in loop_info.loop.body:
            blk = func_ir.blocks[lbl]
            for stmt in blk.body:
                if isinstance(stmt, ir.Assign):
                    if (
                        isinstance(stmt.value, ir.Expr)
                        and stmt.value.op == "getitem"
                    ):
                        # try a couple of spellings... a[i] and ref(a)[i]
                        if stmt.value.value != getitem_target:
                            dfn = func_ir.get_definition(stmt.value.value)
                            try:
                                args = getattr(dfn, "args", False)
                            except KeyError:
                                continue
                            if not args:
                                continue
                            if not args[0] == getitem_target:
                                continue
                            target_ty = state.typemap[getitem_target.name]
                            if not isinstance(target_ty, self._accepted_types):
                                continue
                        tuple_getitem.append(stmt)

        if not tuple_getitem:
            msg = (
                "Loop unrolling analysis has failed, there's no getitem "
                "in loop body that conforms to literal_unroll "
                "requirements."
            )
            LOC = func_ir.blocks[loop_info.loop.header].loc
            raise errors.CompilerError(msg, LOC)

        # 2. get switch data
        switch_data = self.analyse_tuple(target_ty)

        # 3. generate switch IR
        index = func_ir._definitions[tuple_getitem[0].value.index.name][0]
        branches = self.gen_switch(switch_data, index)

        # 4. swap getitems for a typed_getitem, these are actually just
        # placeholders at this point. When the loop is duplicated they can
        # be swapped for a typed_getitem of the correct type or if the item
        # is literal it can be shoved straight into the duplicated loop body
        for item in tuple_getitem:
            old = item.value
            new = ir.Expr.typed_getitem(
                old.value, types.void, old.index, old.loc
            )
            item.value = new

        # 5. Inject switch table

        # Find the actual loop without the header (that won't get replaced)
        # and derive some new IR for this set of blocks
        this_loop = loop_info.loop
        this_loop_body = this_loop.body - set([this_loop.header])
        loop_blocks = {x: func_ir.blocks[x] for x in this_loop_body}
        new_ir = func_ir.derive(loop_blocks)

        # Work out what is live on entry and exit so as to prevent
        # replacement (defined vars can escape, used vars live at the header
        # need to remain as-is so their references are correct, they can
        # also escape).

        usedefs = compute_use_defs(func_ir.blocks)
        idx = this_loop.header
        keep = set()
        keep |= usedefs.usemap[idx] | usedefs.defmap[idx]
        keep |= func_ir.variable_lifetime.livemap[idx]
        dont_replace = [x for x in (keep)]

        # compute the unrolled body
        unrolled_body = self.inject_loop_body(
            branches,
            new_ir,
            max(func_ir.blocks.keys()) + 1,
            dont_replace,
            switch_data,
        )

        # 6. Patch in the unrolled body and fix up
        blks = state.func_ir.blocks
        the_scope = next(iter(blks.values())).scope
        orig_lbl = tuple(this_loop_body)

        replace, *delete = orig_lbl
        unroll, header_block = unrolled_body, this_loop.header
        unroll_lbl = [x for x in sorted(unroll.blocks.keys())]
        blks[replace] = transfer_scope(unroll.blocks[unroll_lbl[0]], the_scope)
        [blks.pop(d) for d in delete]
        for k in unroll_lbl[1:]:
            blks[k] = transfer_scope(unroll.blocks[k], the_scope)
        # stitch up the loop predicate true -> new loop body jump
        blks[header_block].body[-1].truebr = replace

    def run_pass(self, state):
        """Repeatedly apply the unroll transform until a fixed point.

        Returns True if any loop was unrolled. Clears the partial typemap
        and calltypes on exit since the IR has changed under them.
        """
        mutated = False
        func_ir = state.func_ir
        # first limit the work by squashing the CFG if possible
        func_ir.blocks = simplify_CFG(func_ir.blocks)

        if self._DEBUG:
            print("-" * 80 + "PASS ENTRY")
            func_ir.dump()
            print("-" * 80)

        # limitations:
        # 1. No nested unrolls
        # 2. Opt in via `numba.literal_unroll`
        # 3. No multiple mix-tuple use

        # keep running the transform loop until it reports no more changes
        while True:
            stat = self.apply_transform(state)
            mutated |= stat
            if not stat:
                break

        # reset type inference now we are done with the partial results
        state.typemap = {}
        state.calltypes = None

        return mutated
1474
+
1475
+
1476
@register_pass(mutates_CFG=True, analysis_only=False)
class IterLoopCanonicalization(FunctionPass):
    """Transforms loops that are induced by `getiter` into range() driven loops
    If the typemap is available this will only impact Tuple and UniTuple, if it
    is not available it will impact all matching loops.
    """

    _name = "iter_loop_canonicalisation"

    # Set True to print which loops are (not) canonicalised.
    _DEBUG = False

    # if partial typing info is available it will only look at these types
    _accepted_types = (types.BaseTuple, types.LiteralList)
    _accepted_calls = (literal_unroll,)

    def __init__(self):
        # Plain FunctionPass initialisation; this pass carries no extra state.
        FunctionPass.__init__(self)

    def assess_loop(self, loop, func_ir, partial_typemap=None):
        """Decide whether `loop` is a canonicalisable getiter-driven loop.

        Returns True when the loop should be transformed; False (or an
        implicit None on the non-getiter path) otherwise. When a partial
        typemap is given, only `literal_unroll` call sites over accepted
        container types are allowed.
        """
        # it's a iter loop if:
        # - loop header is driven by an iternext
        # - the iternext value is a phi derived from getiter()

        # check header
        iternexts = [
            _ for _ in func_ir.blocks[loop.header].find_exprs("iternext")
        ]
        if len(iternexts) != 1:
            return False
        for iternext in iternexts:
            phi = guard(get_definition, func_ir, iternext.value)
            if phi is None:
                return False
            if getattr(phi, "op", False) == "getiter":
                if partial_typemap:
                    # check that the call site is accepted, until we're
                    # confident that tuple unrolling is behaving require opt-in
                    # guard of `literal_unroll`, remove this later!
                    phi_val_defn = guard(get_definition, func_ir, phi.value)
                    if not isinstance(phi_val_defn, ir.Expr):
                        return False
                    if not phi_val_defn.op == "call":
                        return False
                    call = guard(get_definition, func_ir, phi_val_defn)
                    if call is None or len(call.args) != 1:
                        return False
                    func_var = guard(get_definition, func_ir, call.func)
                    func = guard(get_definition, func_ir, func_var)
                    if func is None or not isinstance(
                        func, (ir.Global, ir.FreeVar)
                    ):
                        return False
                    if (
                        func.value is None
                        or func.value not in self._accepted_calls
                    ):
                        return False

                    # now check the type is supported
                    ty = partial_typemap.get(call.args[0].name, None)
                    if ty and isinstance(ty, self._accepted_types):
                        return len(loop.entries) == 1
                else:
                    # no typing info: accept any single-entry getiter loop
                    return len(loop.entries) == 1

    def transform(self, loop, func_ir, cfg):
        """Rewrite a getiter loop in place into a range(len(container)) loop.

        Inserts a `range(len(x))` call (via an inlined closure) before the
        loop, then replaces RHS uses of the induction variable downstream
        with `getitem` on the original iterable.
        """
        def get_range(a):
            return range(len(a))

        iternext = [
            _ for _ in func_ir.blocks[loop.header].find_exprs("iternext")
        ][0]
        LOC = func_ir.blocks[loop.header].loc
        scope = func_ir.blocks[loop.header].scope
        get_range_var = scope.redefine("CANONICALISER_get_range_gbl", LOC)
        get_range_global = ir.Global("get_range", get_range, LOC)
        assgn = ir.Assign(get_range_global, get_range_var, LOC)

        loop_entry = tuple(loop.entries)[0]
        entry_block = func_ir.blocks[loop_entry]
        entry_block.body.insert(0, assgn)

        iterarg = guard(get_definition, func_ir, iternext.value)
        if iterarg is not None:
            iterarg = iterarg.value

        # look for iternext
        idx = 0
        for stmt in entry_block.body:
            if isinstance(stmt, ir.Assign):
                if (
                    isinstance(stmt.value, ir.Expr)
                    and stmt.value.op == "getiter"
                ):
                    break
            idx += 1
        else:
            # for/else: no getiter found in the entry block
            raise ValueError("problem")

        # create a range(len(tup)) and inject it
        call_get_range_var = scope.redefine("CANONICALISER_call_get_range", LOC)
        make_call = ir.Expr.call(get_range_var, (stmt.value.value,), (), LOC)
        assgn_call = ir.Assign(make_call, call_get_range_var, LOC)
        entry_block.body.insert(idx, assgn_call)
        entry_block.body[idx + 1].value.value = call_get_range_var

        glbls = copy(func_ir.func_id.func.__globals__)

        inline_closurecall.inline_closure_call(
            func_ir,
            glbls,
            entry_block,
            idx,
            get_range,
        )
        # the temporary global assignment is no longer needed once inlined
        kill = entry_block.body.index(assgn)
        entry_block.body.pop(kill)

        # find the induction variable + references in the loop header
        # fixed point iter to do this, it's a bit clunky
        induction_vars = set()
        header_block = func_ir.blocks[loop.header]

        # find induction var
        ind = [x for x in header_block.find_exprs("pair_first")]
        for x in ind:
            induction_vars.add(func_ir.get_assignee(x, loop.header))
        # find aliases of the induction var
        tmp = set()
        for x in induction_vars:
            try:  # there's not always an alias, e.g. loop from inlined closure
                tmp.add(func_ir.get_assignee(x, loop.header))
            except ValueError:
                pass
        induction_vars |= tmp
        induction_var_names = set([x.name for x in induction_vars])

        # Find the downstream blocks that might reference the induction var
        succ = set()
        for lbl in loop.exits:
            succ |= set([x[0] for x in cfg.successors(lbl)])
        check_blocks = (loop.body | loop.exits | succ) ^ {loop.header}

        # replace RHS use of induction var with getitem
        for lbl in check_blocks:
            for stmt in func_ir.blocks[lbl].body:
                if isinstance(stmt, ir.Assign):
                    # check for aliases
                    try:
                        lookup = getattr(stmt.value, "name", None)
                    except KeyError:
                        continue
                    if lookup and lookup in induction_var_names:
                        stmt.value = ir.Expr.getitem(
                            iterarg, stmt.value, stmt.loc
                        )

        post_proc = postproc.PostProcessor(func_ir)
        post_proc.run()

    def run_pass(self, state):
        """Canonicalise every eligible getiter loop; return True if any
        loop was rewritten."""
        func_ir = state.func_ir
        cfg = compute_cfg_from_blocks(func_ir.blocks)
        loops = cfg.loops()

        mutated = False
        for header, loop in loops.items():
            stat = self.assess_loop(loop, func_ir, state.typemap)
            if stat:
                if self._DEBUG:
                    print("Canonicalising loop", loop)
                self.transform(loop, func_ir, cfg)
                mutated = True
            else:
                if self._DEBUG:
                    print("NOT Canonicalising loop", loop)

        func_ir.blocks = simplify_CFG(func_ir.blocks)
        return mutated
1655
+
1656
+
1657
@register_pass(mutates_CFG=False, analysis_only=False)
class PropagateLiterals(FunctionPass):
    """Implement literal propagation based on partial type inference"""

    _name = "PropagateLiterals"

    def __init__(self):
        FunctionPass.__init__(self)

    def get_analysis_usage(self, AU):
        # SSA form is required so each variable has exactly one definition.
        AU.add_required(ReconstructSSA)

    def run_pass(self, state):
        """Replace assignments whose target was partially inferred as a
        ``types.Literal`` with an ``ir.Const`` of the literal value.

        Relies on a prior partial type inference having populated
        ``state.typemap``. Resets ``state.typemap``/``state.calltypes`` at
        the end since the partial results are no longer valid. Returns
        True if the IR was mutated.

        Raises ``errors.NumbaTypeError`` if ``isinstance``/``hasattr`` is
        called on a variable whose definition is a PHI node (its type
        cannot be determined statically across branches).
        """
        func_ir = state.func_ir
        typemap = state.typemap
        flags = state.flags

        # Calls to these builtins are the only calls inspected for PHI
        # arguments below.
        accepted_functions = ("isinstance", "hasattr")

        if not hasattr(func_ir, "_definitions") and not flags.enable_ssa:
            func_ir._definitions = build_definitions(func_ir.blocks)

        changed = False

        for block in func_ir.blocks.values():
            for assign in block.find_insts(ir.Assign):
                value = assign.value
                # Leaf values carry no literal information worth rewriting.
                if isinstance(value, (ir.Arg, ir.Const, ir.FreeVar, ir.Global)):
                    continue

                # 1) Don't change return stmt in the form
                # $return_xyz = cast(value=ABC)
                # 2) Don't propagate literal values that are not primitives
                if isinstance(value, ir.Expr) and value.op in (
                    "cast",
                    "build_map",
                    "build_list",
                    "build_tuple",
                    "build_set",
                ):
                    continue

                target = assign.target
                if not flags.enable_ssa:
                    # SSA is disabled when doing inlining
                    if guard(get_definition, func_ir, target.name) is None:  # noqa: E501
                        continue

                # Numba cannot safely determine if an isinstance call
                # with a PHI node is True/False. For instance, in
                # the case below, the partial type inference step can coerce
                # '$z' to float, so any call to 'isinstance(z, int)' would fail.
                #
                # def fn(x):
                #     if x > 4:
                #         z = 1
                #     else:
                #         z = 3.14
                #     if isinstance(z, int):
                #         print('int')
                #     else:
                #         print('float')
                #
                # At the moment, one avoid propagating the literal
                # value if the argument is a PHI node

                if isinstance(value, ir.Expr) and value.op == "call":
                    fn = guard(get_definition, func_ir, value.func.name)
                    if fn is None:
                        continue

                    # Only isinstance/hasattr calls are checked for PHI
                    # arguments; other calls pass through untouched.
                    if not (
                        isinstance(fn, ir.Global)
                        and fn.name in accepted_functions
                    ):
                        continue

                    for arg in value.args:
                        # check if any of the args to isinstance is a PHI node
                        iv = func_ir._definitions[arg.name]
                        assert len(iv) == 1  # SSA!
                        if isinstance(iv[0], ir.Expr) and iv[0].op == "phi":
                            msg = (
                                f"{fn.name}() cannot determine the "
                                f'type of variable "{arg.unversioned_name}" '
                                "due to a branch."
                            )
                            raise errors.NumbaTypeError(msg, loc=assign.loc)

                # Only propagate a PHI node if all arguments are the same
                # constant
                if isinstance(value, ir.Expr) and value.op == "phi":
                    # typemap will return None in case `inc.name` not in typemap
                    v = [typemap.get(inc.name) for inc in value.incoming_values]
                    # stop if the elements in `v` do not hold the same value
                    if v[0] is not None and any([v[0] != vi for vi in v]):
                        continue

                lit = typemap.get(target.name, None)
                if lit and isinstance(lit, types.Literal):
                    # replace assign instruction by ir.Const(lit) iff
                    # lit is a literal value
                    rhs = ir.Const(lit.literal_value, assign.loc)
                    new_assign = ir.Assign(rhs, target, assign.loc)

                    # replace instruction
                    block.insert_after(new_assign, assign)
                    block.remove(assign)

                    changed = True

        # reset type inference now we are done with the partial results
        state.typemap = None
        state.calltypes = None

        if changed:
            # Rebuild definitions
            func_ir._definitions = build_definitions(func_ir.blocks)

        return changed
1777
+
1778
+
1779
@register_pass(mutates_CFG=True, analysis_only=False)
class LiteralPropagationSubPipelinePass(FunctionPass):
    """Implement literal propagation based on partial type inference"""

    _name = "LiteralPropagation"

    def __init__(self):
        FunctionPass.__init__(self)

    def run_pass(self, state):
        """Run the literal-propagation sub-pipeline.

        Bails out early (returning False) unless the function references
        ``isinstance`` or ``hasattr`` through a global or a freevar, since
        the sub-pipeline is only useful in that case.
        """
        func_ir = state.func_ir

        def _uses_type_predicate():
            # Look for isinstance/hasattr bound via a Global or FreeVar.
            for block in func_ir.blocks.values():
                for inst in block.find_insts(ir.Assign):
                    if isinstance(inst.value, (ir.Global, ir.FreeVar)):
                        bound = inst.value.value
                        if bound is isinstance or bound is hasattr:
                            return True
            return False

        if not _uses_type_predicate():
            return False

        # run as subpipeline
        from numba.cuda.core.compiler_machinery import PassManager
        from numba.cuda.core.typed_passes import PartialTypeInference

        sub_pm = PassManager("literal_propagation_subpipeline")

        sub_pm.add_pass(PartialTypeInference, "performs partial type inference")
        sub_pm.add_pass(
            PropagateLiterals, "performs propagation of literal values"
        )

        # rewrite consts / dead branch pruning
        sub_pm.add_pass(RewriteSemanticConstants, "rewrite semantic constants")
        sub_pm.add_pass(DeadBranchPrune, "dead branch pruning")

        sub_pm.finalize()
        sub_pm.run(state)
        return True

    def get_analysis_usage(self, AU):
        # SSA form is required before literal propagation runs.
        AU.add_required(ReconstructSSA)
1825
+
1826
+
1827
@register_pass(mutates_CFG=True, analysis_only=False)
class LiteralUnroll(FunctionPass):
    """Implement the literal_unroll semantics"""

    _name = "literal_unroll"

    def __init__(self):
        FunctionPass.__init__(self)

    def run_pass(self, state):
        """Expand ``literal_unroll`` loops via a dedicated sub-pipeline.

        Returns False quickly when the function never references
        ``literal_unroll`` through a global or a freevar.
        """
        func_ir = state.func_ir

        def _references_literal_unroll():
            for block in func_ir.blocks.values():
                for inst in block.find_insts(ir.Assign):
                    if isinstance(inst.value, (ir.Global, ir.FreeVar)):
                        if inst.value.value is literal_unroll:
                            return True
            return False

        if not _references_literal_unroll():
            return False

        # run as subpipeline
        from numba.cuda.core.compiler_machinery import PassManager
        from numba.cuda.core.typed_passes import PartialTypeInference

        stages = [
            # get types where possible to help with list->tuple change
            (PartialTypeInference, "performs partial type inference"),
            # make const lists tuples
            (
                TransformLiteralUnrollConstListToTuple,
                "switch const list for tuples",
            ),
            # recompute partial typemap following IR change
            (PartialTypeInference, "performs partial type inference"),
            # canonicalise loops
            (
                IterLoopCanonicalization,
                "switch iter loops for range driven loops",
            ),
            # rewrite consts
            (RewriteSemanticConstants, "rewrite semantic constants"),
            # do the unroll
            (MixedContainerUnroller, "performs mixed container unroll"),
            # rewrite dynamic getitem to static getitem as it's possible
            # some more getitems will now be statically resolvable
            (GenericRewrites, "Generic Rewrites"),
            (RewriteSemanticConstants, "rewrite semantic constants"),
        ]

        sub_pm = PassManager("literal_unroll_subpipeline")
        for pass_cls, description in stages:
            sub_pm.add_pass(pass_cls, description)
        sub_pm.finalize()
        sub_pm.run(state)
        return True
1881
+
1882
+
1883
@register_pass(mutates_CFG=True, analysis_only=False)
class SimplifyCFG(FunctionPass):
    """Perform CFG simplification"""

    _name = "simplify_cfg"

    def __init__(self):
        FunctionPass.__init__(self)

    def run_pass(self, state):
        """Simplify the CFG in place; report whether the block map changed."""
        original_blocks = state.func_ir.blocks
        simplified_blocks = simplify_CFG(original_blocks)
        state.func_ir.blocks = simplified_blocks
        return original_blocks != simplified_blocks
1898
+
1899
+
1900
@register_pass(mutates_CFG=False, analysis_only=False)
class ReconstructSSA(FunctionPass):
    """Perform SSA-reconstruction

    Produces minimal SSA.
    """

    _name = "reconstruct_ssa"

    def __init__(self):
        FunctionPass.__init__(self)

    def run_pass(self, state):
        """Rewrite ``state.func_ir`` into minimal SSA form.

        Also propagates user type annotations to SSA-renamed variables,
        rebuilds definitions, and reruns the post-processor so metadata
        (e.g. generator_info) stays in sync with the new IR.
        """
        ssa_ir = reconstruct_ssa(state.func_ir)
        state.func_ir = ssa_ir
        self._patch_locals(state)

        # Rebuild definitions
        ssa_ir._definitions = build_definitions(ssa_ir.blocks)

        # Rerun postprocessor to update metadata
        # example generator_info
        postproc.PostProcessor(ssa_ir).run(emit_dels=False)

        if config.DEBUG or config.DUMP_SSA:
            name = ssa_ir.func_id.func_qualname
            print(f"SSA IR DUMP: {name}".center(80, "-"))
            ssa_ir.dump()

        return True  # XXX detect if it actually got changed

    def _patch_locals(self, state):
        # Fix dispatcher locals dictionary type annotation: propagate the
        # annotation on each original variable to every SSA redefinition.
        locals_dict = state.get("locals")
        if locals_dict is None:
            return

        first_blk, *_ = state.func_ir.blocks.values()
        redefinitions = first_blk.scope.var_redefinitions
        for original_name, derived_names in redefinitions.items():
            if original_name not in locals_dict:
                continue
            annotated_type = locals_dict[original_name]
            for derived_name in derived_names:
                locals_dict[derived_name] = annotated_type
1944
+
1945
+
1946
@register_pass(mutates_CFG=False, analysis_only=False)
class RewriteDynamicRaises(FunctionPass):
    """Replace existing raise statements by dynamic raises in Numba IR."""

    _name = "Rewrite dynamic raises"

    def __init__(self):
        FunctionPass.__init__(self)

    def run_pass(self, state):
        """Swap each ``Raise``/``TryRaise`` whose exception constructor call
        is resolvable for the corresponding dynamic-raise statement.

        Constructor arguments are folded to constants where constant
        inference succeeds; otherwise the IR variable itself is kept.
        Returns True if any statement was replaced.
        """
        func_ir = state.func_ir
        changed = False

        # Exact-type dispatch: TryRaise and Raise map to distinct
        # dynamic counterparts.
        dispatch = {
            ir.TryRaise: ir.DynamicTryRaise,
            ir.Raise: ir.DynamicRaise,
        }

        for block in func_ir.blocks.values():
            for raise_inst in block.find_insts((ir.Raise, ir.TryRaise)):
                ctor_call = guard(get_definition, func_ir, raise_inst.exception)
                if ctor_call is None:
                    continue
                exc_type = func_ir.infer_constant(ctor_call.func.name)

                resolved_args = []
                for ctor_arg in ctor_call.args:
                    try:
                        resolved_args.append(func_ir.infer_constant(ctor_arg))
                    except consts.ConstantInferenceError:
                        # Not a compile-time constant; pass the variable.
                        resolved_args.append(ctor_arg)

                dyn_cls = dispatch[type(raise_inst)]
                replacement = dyn_cls(
                    exc_type, tuple(resolved_args), raise_inst.loc
                )
                block.insert_after(replacement, raise_inst)
                block.remove(raise_inst)
                changed = True
        return changed