numba-cuda 0.19.1__py3-none-any.whl → 0.20.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release: this version of numba-cuda might be problematic.
- numba_cuda/VERSION +1 -1
- numba_cuda/numba/cuda/__init__.py +1 -1
- numba_cuda/numba/cuda/_internal/cuda_bf16.py +12706 -1470
- numba_cuda/numba/cuda/_internal/cuda_fp16.py +2653 -8769
- numba_cuda/numba/cuda/api.py +6 -1
- numba_cuda/numba/cuda/bf16.py +285 -2
- numba_cuda/numba/cuda/cgutils.py +2 -2
- numba_cuda/numba/cuda/cloudpickle/__init__.py +21 -0
- numba_cuda/numba/cuda/cloudpickle/cloudpickle.py +1598 -0
- numba_cuda/numba/cuda/cloudpickle/cloudpickle_fast.py +17 -0
- numba_cuda/numba/cuda/codegen.py +1 -1
- numba_cuda/numba/cuda/compiler.py +373 -30
- numba_cuda/numba/cuda/core/analysis.py +319 -0
- numba_cuda/numba/cuda/core/annotations/__init__.py +0 -0
- numba_cuda/numba/cuda/core/annotations/type_annotations.py +304 -0
- numba_cuda/numba/cuda/core/base.py +1289 -0
- numba_cuda/numba/cuda/core/bytecode.py +727 -0
- numba_cuda/numba/cuda/core/caching.py +2 -2
- numba_cuda/numba/cuda/core/compiler.py +6 -14
- numba_cuda/numba/cuda/core/compiler_machinery.py +497 -0
- numba_cuda/numba/cuda/core/config.py +747 -0
- numba_cuda/numba/cuda/core/consts.py +124 -0
- numba_cuda/numba/cuda/core/cpu.py +370 -0
- numba_cuda/numba/cuda/core/environment.py +68 -0
- numba_cuda/numba/cuda/core/event.py +511 -0
- numba_cuda/numba/cuda/core/funcdesc.py +330 -0
- numba_cuda/numba/cuda/core/inline_closurecall.py +1889 -0
- numba_cuda/numba/cuda/core/interpreter.py +48 -26
- numba_cuda/numba/cuda/core/ir_utils.py +15 -26
- numba_cuda/numba/cuda/core/options.py +262 -0
- numba_cuda/numba/cuda/core/postproc.py +249 -0
- numba_cuda/numba/cuda/core/pythonapi.py +1868 -0
- numba_cuda/numba/cuda/core/rewrites/__init__.py +26 -0
- numba_cuda/numba/cuda/core/rewrites/ir_print.py +90 -0
- numba_cuda/numba/cuda/core/rewrites/registry.py +104 -0
- numba_cuda/numba/cuda/core/rewrites/static_binop.py +40 -0
- numba_cuda/numba/cuda/core/rewrites/static_getitem.py +187 -0
- numba_cuda/numba/cuda/core/rewrites/static_raise.py +98 -0
- numba_cuda/numba/cuda/core/ssa.py +496 -0
- numba_cuda/numba/cuda/core/targetconfig.py +329 -0
- numba_cuda/numba/cuda/core/tracing.py +231 -0
- numba_cuda/numba/cuda/core/transforms.py +952 -0
- numba_cuda/numba/cuda/core/typed_passes.py +738 -7
- numba_cuda/numba/cuda/core/typeinfer.py +1948 -0
- numba_cuda/numba/cuda/core/unsafe/__init__.py +0 -0
- numba_cuda/numba/cuda/core/unsafe/bytes.py +67 -0
- numba_cuda/numba/cuda/core/unsafe/eh.py +66 -0
- numba_cuda/numba/cuda/core/unsafe/refcount.py +98 -0
- numba_cuda/numba/cuda/core/untyped_passes.py +1983 -0
- numba_cuda/numba/cuda/cpython/cmathimpl.py +560 -0
- numba_cuda/numba/cuda/cpython/mathimpl.py +499 -0
- numba_cuda/numba/cuda/cpython/numbers.py +1474 -0
- numba_cuda/numba/cuda/cuda_paths.py +422 -246
- numba_cuda/numba/cuda/cudadecl.py +1 -1
- numba_cuda/numba/cuda/cudadrv/__init__.py +1 -1
- numba_cuda/numba/cuda/cudadrv/devicearray.py +2 -1
- numba_cuda/numba/cuda/cudadrv/driver.py +11 -140
- numba_cuda/numba/cuda/cudadrv/dummyarray.py +111 -24
- numba_cuda/numba/cuda/cudadrv/libs.py +5 -5
- numba_cuda/numba/cuda/cudadrv/mappings.py +1 -1
- numba_cuda/numba/cuda/cudadrv/nvrtc.py +19 -8
- numba_cuda/numba/cuda/cudadrv/nvvm.py +1 -4
- numba_cuda/numba/cuda/cudadrv/runtime.py +1 -1
- numba_cuda/numba/cuda/cudaimpl.py +5 -1
- numba_cuda/numba/cuda/debuginfo.py +85 -2
- numba_cuda/numba/cuda/decorators.py +3 -3
- numba_cuda/numba/cuda/descriptor.py +3 -4
- numba_cuda/numba/cuda/deviceufunc.py +66 -2
- numba_cuda/numba/cuda/dispatcher.py +18 -39
- numba_cuda/numba/cuda/flags.py +141 -1
- numba_cuda/numba/cuda/fp16.py +0 -2
- numba_cuda/numba/cuda/include/13/cuda_bf16.h +5118 -0
- numba_cuda/numba/cuda/include/13/cuda_bf16.hpp +3865 -0
- numba_cuda/numba/cuda/include/13/cuda_fp16.h +5363 -0
- numba_cuda/numba/cuda/include/13/cuda_fp16.hpp +3483 -0
- numba_cuda/numba/cuda/lowering.py +7 -144
- numba_cuda/numba/cuda/mathimpl.py +2 -1
- numba_cuda/numba/cuda/memory_management/nrt.py +43 -17
- numba_cuda/numba/cuda/misc/findlib.py +75 -0
- numba_cuda/numba/cuda/models.py +9 -1
- numba_cuda/numba/cuda/np/npdatetime_helpers.py +217 -0
- numba_cuda/numba/cuda/np/npyfuncs.py +1807 -0
- numba_cuda/numba/cuda/np/numpy_support.py +553 -0
- numba_cuda/numba/cuda/np/ufunc/ufuncbuilder.py +59 -0
- numba_cuda/numba/cuda/nvvmutils.py +1 -1
- numba_cuda/numba/cuda/printimpl.py +12 -1
- numba_cuda/numba/cuda/random.py +1 -1
- numba_cuda/numba/cuda/serialize.py +1 -1
- numba_cuda/numba/cuda/simulator/__init__.py +1 -1
- numba_cuda/numba/cuda/simulator/api.py +1 -1
- numba_cuda/numba/cuda/simulator/compiler.py +4 -0
- numba_cuda/numba/cuda/simulator/cudadrv/devicearray.py +1 -1
- numba_cuda/numba/cuda/simulator/kernelapi.py +1 -1
- numba_cuda/numba/cuda/simulator/memory_management/nrt.py +14 -2
- numba_cuda/numba/cuda/target.py +35 -17
- numba_cuda/numba/cuda/testing.py +4 -19
- numba_cuda/numba/cuda/tests/__init__.py +1 -1
- numba_cuda/numba/cuda/tests/cloudpickle_main_class.py +9 -0
- numba_cuda/numba/cuda/tests/core/test_serialize.py +4 -4
- numba_cuda/numba/cuda/tests/cudadrv/test_cuda_devicerecord.py +1 -1
- numba_cuda/numba/cuda/tests/cudadrv/test_cuda_libraries.py +1 -1
- numba_cuda/numba/cuda/tests/cudadrv/test_deallocations.py +1 -1
- numba_cuda/numba/cuda/tests/cudadrv/test_detect.py +6 -3
- numba_cuda/numba/cuda/tests/cudadrv/test_emm_plugins.py +1 -1
- numba_cuda/numba/cuda/tests/cudadrv/test_linker.py +18 -2
- numba_cuda/numba/cuda/tests/cudadrv/test_module_callbacks.py +2 -1
- numba_cuda/numba/cuda/tests/cudadrv/test_nvjitlink.py +1 -1
- numba_cuda/numba/cuda/tests/cudadrv/test_ptds.py +1 -1
- numba_cuda/numba/cuda/tests/cudapy/extensions_usecases.py +1 -1
- numba_cuda/numba/cuda/tests/cudapy/test_array.py +2 -1
- numba_cuda/numba/cuda/tests/cudapy/test_atomics.py +1 -1
- numba_cuda/numba/cuda/tests/cudapy/test_bfloat16.py +539 -2
- numba_cuda/numba/cuda/tests/cudapy/test_bfloat16_bindings.py +81 -1
- numba_cuda/numba/cuda/tests/cudapy/test_caching.py +1 -3
- numba_cuda/numba/cuda/tests/cudapy/test_complex.py +1 -1
- numba_cuda/numba/cuda/tests/cudapy/test_constmem.py +1 -1
- numba_cuda/numba/cuda/tests/cudapy/test_cooperative_groups.py +2 -3
- numba_cuda/numba/cuda/tests/cudapy/test_copy_propagate.py +130 -0
- numba_cuda/numba/cuda/tests/cudapy/test_datetime.py +1 -1
- numba_cuda/numba/cuda/tests/cudapy/test_debug.py +1 -1
- numba_cuda/numba/cuda/tests/cudapy/test_debuginfo.py +293 -4
- numba_cuda/numba/cuda/tests/cudapy/test_debuginfo_types.py +1 -1
- numba_cuda/numba/cuda/tests/cudapy/test_dispatcher.py +1 -1
- numba_cuda/numba/cuda/tests/cudapy/test_errors.py +1 -1
- numba_cuda/numba/cuda/tests/cudapy/test_exception.py +1 -1
- numba_cuda/numba/cuda/tests/cudapy/test_extending.py +2 -1
- numba_cuda/numba/cuda/tests/cudapy/test_inline.py +18 -8
- numba_cuda/numba/cuda/tests/cudapy/test_ir_utils.py +10 -37
- numba_cuda/numba/cuda/tests/cudapy/test_laplace.py +1 -1
- numba_cuda/numba/cuda/tests/cudapy/test_math.py +1 -1
- numba_cuda/numba/cuda/tests/cudapy/test_matmul.py +1 -1
- numba_cuda/numba/cuda/tests/cudapy/test_operator.py +1 -1
- numba_cuda/numba/cuda/tests/cudapy/test_print.py +20 -0
- numba_cuda/numba/cuda/tests/cudapy/test_record_dtype.py +1 -1
- numba_cuda/numba/cuda/tests/cudapy/test_reduction.py +1 -1
- numba_cuda/numba/cuda/tests/cudapy/test_serialize.py +1 -1
- numba_cuda/numba/cuda/tests/cudapy/test_sm.py +1 -1
- numba_cuda/numba/cuda/tests/cudapy/test_ssa.py +453 -0
- numba_cuda/numba/cuda/tests/cudapy/test_sync.py +1 -1
- numba_cuda/numba/cuda/tests/cudapy/test_typeinfer.py +538 -0
- numba_cuda/numba/cuda/tests/cudapy/test_ufuncs.py +263 -2
- numba_cuda/numba/cuda/tests/cudapy/test_userexc.py +1 -1
- numba_cuda/numba/cuda/tests/cudapy/test_vector_type.py +1 -1
- numba_cuda/numba/cuda/tests/cudapy/test_vectorize_decor.py +112 -6
- numba_cuda/numba/cuda/tests/cudapy/test_warning.py +1 -1
- numba_cuda/numba/cuda/tests/cudapy/test_warp_ops.py +1 -1
- numba_cuda/numba/cuda/tests/doc_examples/test_cg.py +0 -2
- numba_cuda/numba/cuda/tests/doc_examples/test_ffi.py +3 -2
- numba_cuda/numba/cuda/tests/doc_examples/test_laplace.py +0 -2
- numba_cuda/numba/cuda/tests/doc_examples/test_sessionize.py +0 -2
- numba_cuda/numba/cuda/tests/nocuda/test_import.py +3 -1
- numba_cuda/numba/cuda/tests/nocuda/test_library_lookup.py +24 -12
- numba_cuda/numba/cuda/tests/nrt/test_nrt.py +2 -1
- numba_cuda/numba/cuda/tests/support.py +55 -15
- numba_cuda/numba/cuda/tests/test_tracing.py +200 -0
- numba_cuda/numba/cuda/types.py +56 -0
- numba_cuda/numba/cuda/typing/__init__.py +9 -1
- numba_cuda/numba/cuda/typing/cffi_utils.py +55 -0
- numba_cuda/numba/cuda/typing/context.py +751 -0
- numba_cuda/numba/cuda/typing/enumdecl.py +74 -0
- numba_cuda/numba/cuda/typing/npydecl.py +658 -0
- numba_cuda/numba/cuda/typing/templates.py +7 -6
- numba_cuda/numba/cuda/ufuncs.py +3 -3
- numba_cuda/numba/cuda/utils.py +6 -112
- {numba_cuda-0.19.1.dist-info → numba_cuda-0.20.0.dist-info}/METADATA +2 -1
- {numba_cuda-0.19.1.dist-info → numba_cuda-0.20.0.dist-info}/RECORD +170 -115
- numba_cuda/numba/cuda/tests/cudadrv/test_mvc.py +0 -60
- {numba_cuda-0.19.1.dist-info → numba_cuda-0.20.0.dist-info}/WHEEL +0 -0
- {numba_cuda-0.19.1.dist-info → numba_cuda-0.20.0.dist-info}/licenses/LICENSE +0 -0
- {numba_cuda-0.19.1.dist-info → numba_cuda-0.20.0.dist-info}/licenses/LICENSE.numba +0 -0
- {numba_cuda-0.19.1.dist-info → numba_cuda-0.20.0.dist-info}/top_level.txt +0 -0
numba_cuda/numba/cuda/core/untyped_passes.py
@@ -0,0 +1,1983 @@
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: BSD-2-Clause

from collections import defaultdict, namedtuple
from contextlib import contextmanager
from copy import deepcopy, copy
import warnings

from numba.cuda.core.compiler_machinery import (
    FunctionPass,
    AnalysisPass,
    SSACompliantMixin,
    register_pass,
)
from numba.cuda.core import postproc, bytecode, transforms, inline_closurecall
from numba.core import (
    errors,
    types,
    ir,
)
from numba.cuda.core import consts, rewrites, config
from numba.cuda.core.interpreter import Interpreter


from numba.misc.special import literal_unroll
from numba.cuda.core.analysis import dead_branch_prune
from numba.core.analysis import (
    rewrite_semantic_constants,
    find_literally_calls,
    compute_cfg_from_blocks,
    compute_use_defs,
)
from numba.cuda.core.ir_utils import (
    guard,
    resolve_func_from_module,
    simplify_CFG,
    GuardException,
    convert_code_obj_to_function,
    build_definitions,
    replace_var_names,
    get_name_var_table,
    compile_to_numba_ir,
    get_definition,
    find_max_label,
    rename_labels,
    transfer_scope,
    fixup_var_define_in_scope,
)
from numba.cuda.core.ssa import reconstruct_ssa


@contextmanager
def fallback_context(state, msg):
    """
    Wraps code that would signal a fallback to object mode
    """
    try:
        yield
    except Exception as e:
        if not state.status.can_fallback:
            raise
        else:
            # Clear all references attached to the traceback
            e = e.with_traceback(None)
            # this emits a warning containing the error message body in the
            # case of fallback from npm to objmode
            loop_lift = "" if state.flags.enable_looplift else "OUT"
            msg_rewrite = (
                "\nCompilation is falling back to object mode "
                "WITH%s looplifting enabled because %s" % (loop_lift, msg)
            )
            warnings.warn_explicit(
                "%s due to: %s" % (msg_rewrite, e),
                errors.NumbaWarning,
                state.func_id.filename,
                state.func_id.firstlineno,
            )
            raise


@register_pass(mutates_CFG=True, analysis_only=False)
class ExtractByteCode(FunctionPass):
    _name = "extract_bytecode"

    def __init__(self):
        FunctionPass.__init__(self)

    def run_pass(self, state):
        """
        Extract bytecode from function
        """
        func_id = state["func_id"]
        bc = bytecode.ByteCode(func_id)
        if config.DUMP_BYTECODE:
            print(bc.dump())

        state["bc"] = bc
        return True


@register_pass(mutates_CFG=True, analysis_only=False)
class TranslateByteCode(FunctionPass):
    _name = "translate_bytecode"

    def __init__(self):
        FunctionPass.__init__(self)

    def run_pass(self, state):
        """
        Analyze bytecode and translate it to Numba IR
        """
        func_id = state["func_id"]
        bc = state["bc"]
        interp = Interpreter(func_id)
        func_ir = interp.interpret(bc)
        state["func_ir"] = func_ir
        return True


@register_pass(mutates_CFG=True, analysis_only=False)
class FixupArgs(FunctionPass):
    _name = "fixup_args"

    def __init__(self):
        FunctionPass.__init__(self)

    def run_pass(self, state):
        state["nargs"] = state["func_ir"].arg_count
        if not state["args"] and state["flags"].force_pyobject:
            # Allow an empty argument types specification when object mode
            # is explicitly requested.
            state["args"] = (types.pyobject,) * state["nargs"]
        elif len(state["args"]) != state["nargs"]:
            raise TypeError(
                "Signature mismatch: %d argument types given, "
                "but function takes %d arguments"
                % (len(state["args"]), state["nargs"])
            )
        return True
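An aside for orientation, not part of the diff: a minimal sketch of how front-end passes like the ones above are typically strung together, assuming the CUDA fork keeps the PassManager interface of numba.core's compiler_machinery. The pipeline name and the commented run call are illustrative.

# Hedged sketch, assuming PassManager mirrors numba.core's API.
from numba.cuda.core.compiler_machinery import PassManager
from numba.cuda.core.untyped_passes import (
    ExtractByteCode,
    TranslateByteCode,
    FixupArgs,
    IRProcessing,
)

pm = PassManager("frontend_sketch")  # illustrative pipeline name
pm.add_pass(ExtractByteCode, "extract bytecode")
pm.add_pass(TranslateByteCode, "analyzing bytecode")
pm.add_pass(FixupArgs, "fix up args")
pm.add_pass(IRProcessing, "processing IR")
pm.finalize()
# pm.run(state)  # `state` is the compiler state threaded into each run_pass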
@register_pass(mutates_CFG=True, analysis_only=False)
class IRProcessing(FunctionPass):
    _name = "ir_processing"

    def __init__(self):
        FunctionPass.__init__(self)

    def run_pass(self, state):
        func_ir = state["func_ir"]
        post_proc = postproc.PostProcessor(func_ir)
        post_proc.run()

        if config.DEBUG or config.DUMP_IR:
            name = func_ir.func_id.func_qualname
            print(("IR DUMP: %s" % name).center(80, "-"))
            func_ir.dump()
            if func_ir.is_generator:
                print(("GENERATOR INFO: %s" % name).center(80, "-"))
                func_ir.dump_generator_info()
        return True


@register_pass(mutates_CFG=True, analysis_only=False)
class RewriteSemanticConstants(FunctionPass):
    _name = "rewrite_semantic_constants"

    def __init__(self):
        FunctionPass.__init__(self)

    def run_pass(self, state):
        """
        Rewrites IR expressions that are semantically constant for the given
        argument types into constants, enabling the subsequent dead branch
        pruning pass.
        """
        assert state.func_ir
        msg = (
            "Internal error in pre-inference dead branch pruning "
            "pass encountered during compilation of "
            'function "%s"' % (state.func_id.func_name,)
        )
        with fallback_context(state, msg):
            rewrite_semantic_constants(state.func_ir, state.args)

        return True
@register_pass(mutates_CFG=True, analysis_only=False)
class DeadBranchPrune(SSACompliantMixin, FunctionPass):
    _name = "dead_branch_prune"

    def __init__(self):
        FunctionPass.__init__(self)

    def run_pass(self, state):
        """
        This prunes dead branches; a dead branch is one which is derivable as
        not taken at compile time purely based on const/literal evaluation.
        """

        # purely for demonstration purposes, obtain the analysis from a pass
        # declared as a required dependent
        semantic_const_analysis = self.get_analysis(type(self))  # noqa

        assert state.func_ir
        msg = (
            "Internal error in pre-inference dead branch pruning "
            "pass encountered during compilation of "
            'function "%s"' % (state.func_id.func_name,)
        )
        with fallback_context(state, msg):
            dead_branch_prune(state.func_ir, state.args)

        return True

    def get_analysis_usage(self, AU):
        AU.add_required(RewriteSemanticConstants)
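An illustrative aside, not from the diff: the classic effect of dead branch pruning on a typed compile. A branch that is decidable purely from the argument types is cut before type inference ever sees it; a minimal CPU-style sketch:

# Illustrative only: compiling f for an int64 argument lets the compiler
# discard the `a is None` arm entirely, so the impossible branch never
# reaches type inference.
from numba import njit

@njit
def f(a):
    if a is None:  # statically False for an int64 argument
        return 0
    else:
        return a + 1

f(5)  # only the `else` arm is compiled for this specialization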
@register_pass(mutates_CFG=True, analysis_only=False)
class InlineClosureLikes(FunctionPass):
    _name = "inline_closure_likes"

    def __init__(self):
        FunctionPass.__init__(self)

    def run_pass(self, state):
        # Ensure we have an IR and type information.
        assert state.func_ir

        # if the return type is a pyobject, there's no type info available and
        # no ability to resolve certain typed function calls in the array
        # inlining code; this variable indicates whether type info is available
        typed_pass = not isinstance(state.return_type, types.misc.PyObject)

        inline_pass = inline_closurecall.InlineClosureCallPass(
            state.func_ir,
            state.flags.auto_parallel,
            None,
            typed_pass,
        )
        inline_pass.run()

        # Remove all Dels, and re-run postproc
        post_proc = postproc.PostProcessor(state.func_ir)
        post_proc.run()

        fixup_var_define_in_scope(state.func_ir.blocks)

        return True
@register_pass(mutates_CFG=True, analysis_only=False)
class GenericRewrites(FunctionPass):
    _name = "generic_rewrites"

    def __init__(self):
        FunctionPass.__init__(self)

    def run_pass(self, state):
        """
        Perform any intermediate representation rewrites before type
        inference.
        """
        assert state.func_ir
        msg = (
            "Internal error in pre-inference rewriting "
            "pass encountered during compilation of "
            'function "%s"' % (state.func_id.func_name,)
        )
        with fallback_context(state, msg):
            rewrites.rewrite_registry.apply("before-inference", state)
        return True


@register_pass(mutates_CFG=True, analysis_only=False)
class WithLifting(FunctionPass):
    _name = "with_lifting"

    def __init__(self):
        FunctionPass.__init__(self)

    def run_pass(self, state):
        """
        Extract with-contexts
        """
        main, withs = transforms.with_lifting(
            func_ir=state.func_ir,
            typingctx=state.typingctx,
            targetctx=state.targetctx,
            flags=state.flags,
            locals=state.locals,
        )
        if withs:
            from numba.cuda.compiler import compile_ir
            from numba.cuda.core.compiler import _EarlyPipelineCompletion

            cres = compile_ir(
                state.typingctx,
                state.targetctx,
                main,
                state.args,
                state.return_type,
                state.flags,
                state.locals,
                lifted=tuple(withs),
                lifted_from=None,
                pipeline_class=type(state.pipeline),
            )
            raise _EarlyPipelineCompletion(cres)
        return True
@register_pass(mutates_CFG=True, analysis_only=False)
class InlineInlinables(FunctionPass):
    """
    This pass will inline a function wrapped by the numba.jit decorator directly
    into the site of its call depending on the value set in the 'inline' kwarg
    to the decorator.

    This is an untyped pass. CFG simplification is performed at the end of the
    pass but no block level clean up is performed on the mutated IR (typing
    information is not available to do so).
    """

    _name = "inline_inlinables"
    _DEBUG = False

    def __init__(self):
        FunctionPass.__init__(self)

    def run_pass(self, state):
        """Run inlining of inlinables"""
        if self._DEBUG:
            print("before inline".center(80, "-"))
            print(state.func_ir.dump())
            print("".center(80, "-"))

        inline_worker = inline_closurecall.InlineWorker(
            state.typingctx,
            state.targetctx,
            state.locals,
            state.pipeline,
            state.flags,
            validator=inline_closurecall.callee_ir_validator,
        )

        modified = False
        # use a work list, look for call sites via `ir.Expr.op == call` and
        # then pass these to `self._do_work` to make decisions about inlining.
        work_list = list(state.func_ir.blocks.items())
        while work_list:
            label, block = work_list.pop()
            for i, instr in enumerate(block.body):
                if isinstance(instr, ir.Assign):
                    expr = instr.value
                    if isinstance(expr, ir.Expr) and expr.op == "call":
                        if guard(
                            self._do_work,
                            state,
                            work_list,
                            block,
                            i,
                            expr,
                            inline_worker,
                        ):
                            modified = True
                            break  # because block structure changed

        if modified:
            # clean up unconditional branches that appear due to inlined
            # functions introducing blocks
            cfg = compute_cfg_from_blocks(state.func_ir.blocks)
            for dead in cfg.dead_nodes():
                del state.func_ir.blocks[dead]
            post_proc = postproc.PostProcessor(state.func_ir)
            post_proc.run()
            state.func_ir.blocks = simplify_CFG(state.func_ir.blocks)

        if self._DEBUG:
            print("after inline".center(80, "-"))
            print(state.func_ir.dump())
            print("".center(80, "-"))
        return True

    def _do_work(self, state, work_list, block, i, expr, inline_worker):
        from numba.cuda.compiler import run_frontend
        from numba.cuda.core.options import InlineOptions

        # try and get a definition for the call, this isn't always possible as
        # it might be a eval(str)/part generated awaiting update etc. (parfors)
        to_inline = None
        try:
            to_inline = state.func_ir.get_definition(expr.func)
        except Exception:
            if self._DEBUG:
                print("Cannot find definition for %s" % expr.func)
            return False
        # do not handle closure inlining here, another pass deals with that.
        if getattr(to_inline, "op", False) == "make_function":
            return False

        # see if the definition is a "getattr", in which case walk the IR to
        # try and find the python function via the module from which it's
        # imported, this should all be encoded in the IR.
        if getattr(to_inline, "op", False) == "getattr":
            val = resolve_func_from_module(state.func_ir, to_inline)
        else:
            # This is likely a freevar or global
            #
            # NOTE: getattr 'value' on a call may fail if it's an ir.Expr as
            # getattr is overloaded to look in _kws.
            try:
                val = getattr(to_inline, "value", False)
            except Exception:
                raise GuardException

        # if something was found...
        if val:
            # check it's dispatcher-like, the targetoptions attr holds the
            # kwargs supplied in the jit decorator and is where 'inline' will
            # be if it is present.
            topt = getattr(val, "targetoptions", False)
            if topt:
                inline_type = topt.get("inline", None)
                # has 'inline' been specified?
                if inline_type is not None:
                    inline_opt = InlineOptions(inline_type)
                    # Could this be inlinable?
                    if not inline_opt.is_never_inline:
                        # yes, it could be inlinable
                        do_inline = True
                        pyfunc = val.py_func
                        # Has it got an associated cost model?
                        if inline_opt.has_cost_model:
                            # yes, it has a cost model, use it to determine
                            # whether to do the inline
                            py_func_ir = run_frontend(pyfunc)
                            do_inline = inline_type(
                                expr, state.func_ir, py_func_ir
                            )
                        # if do_inline is True then inline!
                        if do_inline:
                            _, _, _, new_blocks = inline_worker.inline_function(
                                state.func_ir,
                                block,
                                i,
                                pyfunc,
                            )
                            if work_list is not None:
                                for blk in new_blocks:
                                    work_list.append(blk)
                            return True
        return False
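Aside: a minimal sketch of the decorator option this pass consumes. The `inline` kwarg ends up in the dispatcher's `targetoptions`, which is exactly where `_do_work` looks for it:

# Illustrative sketch of the 'inline' kwarg read by InlineInlinables.
from numba import njit

@njit(inline="always")
def bar(x):
    return x + 1

@njit
def foo(x):
    # the call to bar is replaced by bar's body at this call site
    return bar(x) * 2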
@register_pass(mutates_CFG=False, analysis_only=False)
class PreserveIR(AnalysisPass):
    """
    Preserves the IR in the metadata
    """

    _name = "preserve_ir"

    def __init__(self):
        AnalysisPass.__init__(self)

    def run_pass(self, state):
        state.metadata["preserved_ir"] = state.func_ir.copy()
        return False


@register_pass(mutates_CFG=False, analysis_only=True)
class FindLiterallyCalls(FunctionPass):
    """Find calls to `numba.literally()` and signal if its requirement is not
    satisfied.
    """

    _name = "find_literally"

    def __init__(self):
        FunctionPass.__init__(self)

    def run_pass(self, state):
        find_literally_calls(state.func_ir, state.args)
        return False
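Aside: a minimal sketch of the `numba.literally` request that FindLiterallyCalls checks; if the literal requirement cannot be met for the given arguments, compilation signals an error:

# Illustrative sketch: literally(exp) asks the compiler to re-dispatch with
# a literal type for exp; this pass flags the request when it is unmet.
from numba import njit, literally

@njit
def power(base, exp):
    return base ** literally(exp)

power(2.0, 3)  # compiles a specialization with exp == 3 as a literal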
@register_pass(mutates_CFG=True, analysis_only=False)
class CanonicalizeLoopExit(FunctionPass):
    """A pass to canonicalize loop exit by splitting it from function exit."""

    _name = "canonicalize_loop_exit"

    def __init__(self):
        FunctionPass.__init__(self)

    def run_pass(self, state):
        fir = state.func_ir
        cfg = compute_cfg_from_blocks(fir.blocks)
        status = False
        for loop in cfg.loops().values():
            for exit_label in loop.exits:
                if exit_label in cfg.exit_points():
                    self._split_exit_block(fir, cfg, exit_label)
                    status = True

        fir._reset_analysis_variables()

        vlt = postproc.VariableLifetime(fir.blocks)
        fir.variable_lifetime = vlt
        return status

    def _split_exit_block(self, fir, cfg, exit_label):
        curblock = fir.blocks[exit_label]
        newlabel = exit_label + 1
        newlabel = find_max_label(fir.blocks) + 1
        fir.blocks[newlabel] = curblock
        newblock = ir.Block(scope=curblock.scope, loc=curblock.loc)
        newblock.append(ir.Jump(newlabel, loc=curblock.loc))
        fir.blocks[exit_label] = newblock
        # Rename all labels
        fir.blocks = rename_labels(fir.blocks)


@register_pass(mutates_CFG=True, analysis_only=False)
class CanonicalizeLoopEntry(FunctionPass):
    """A pass to canonicalize loop header by splitting it from function entry.

    This is needed for loop-lifting; esp in py3.8
    """

    _name = "canonicalize_loop_entry"
    _supported_globals = {range, enumerate, zip}

    def __init__(self):
        FunctionPass.__init__(self)

    def run_pass(self, state):
        fir = state.func_ir
        cfg = compute_cfg_from_blocks(fir.blocks)
        status = False
        for loop in cfg.loops().values():
            if len(loop.entries) == 1:
                [entry_label] = loop.entries
                if entry_label == cfg.entry_point():
                    self._split_entry_block(fir, cfg, loop, entry_label)
                    status = True
        fir._reset_analysis_variables()

        vlt = postproc.VariableLifetime(fir.blocks)
        fir.variable_lifetime = vlt
        return status

    def _split_entry_block(self, fir, cfg, loop, entry_label):
        # Find iterator inputs into the for-loop header
        header_block = fir.blocks[loop.header]
        deps = set()
        for expr in header_block.find_exprs(op="iternext"):
            deps.add(expr.value)
        # Find the getiter for each iterator
        entry_block = fir.blocks[entry_label]

        # Find the start of loop entry statement that needs to be included.
        startpt = None
        list_of_insts = list(entry_block.find_insts(ir.Assign))
        for assign in reversed(list_of_insts):
            if assign.target in deps:
                rhs = assign.value
                if isinstance(rhs, ir.Var):
                    if rhs.is_temp:
                        deps.add(rhs)
                elif isinstance(rhs, ir.Expr):
                    expr = rhs
                    if expr.op == "getiter":
                        startpt = assign
                        if expr.value.is_temp:
                            deps.add(expr.value)
                    elif expr.op == "call":
                        defn = guard(get_definition, fir, expr.func)
                        if isinstance(defn, ir.Global):
                            if expr.func.is_temp:
                                deps.add(expr.func)
                elif (
                    isinstance(rhs, ir.Global)
                    and rhs.value in self._supported_globals
                ):
                    startpt = assign

        if startpt is None:
            return

        splitpt = entry_block.body.index(startpt)
        new_block = entry_block.copy()
        new_block.body = new_block.body[splitpt:]
        new_block.loc = new_block.body[0].loc
        new_label = find_max_label(fir.blocks) + 1
        entry_block.body = entry_block.body[:splitpt]
        entry_block.append(ir.Jump(new_label, loc=new_block.loc))

        fir.blocks[new_label] = new_block
        # Rename all labels
        fir.blocks = rename_labels(fir.blocks)


@register_pass(mutates_CFG=False, analysis_only=True)
class PrintIRCFG(FunctionPass):
    _name = "print_ir_cfg"

    def __init__(self):
        FunctionPass.__init__(self)
        self._ver = 0

    def run_pass(self, state):
        fir = state.func_ir
        self._ver += 1
        fir.render_dot(filename_prefix="v{}".format(self._ver)).render()
        return False
@register_pass(mutates_CFG=True, analysis_only=False)
class MakeFunctionToJitFunction(FunctionPass):
    """
    This swaps an ir.Expr.op == "make_function" i.e. a closure, for a compiled
    function containing the closure body and puts it in ir.Global. It's a 1:1
    statement value swap. `make_function` is already untyped
    """

    _name = "make_function_op_code_to_jit_function"

    def __init__(self):
        FunctionPass.__init__(self)

    def run_pass(self, state):
        from numba import njit

        func_ir = state.func_ir
        mutated = False
        for idx, blk in func_ir.blocks.items():
            for stmt in blk.body:
                if isinstance(stmt, ir.Assign):
                    if isinstance(stmt.value, ir.Expr):
                        if stmt.value.op == "make_function":
                            node = stmt.value
                            getdef = func_ir.get_definition
                            kw_default = getdef(node.defaults)
                            ok = False
                            if kw_default is None or isinstance(
                                kw_default, ir.Const
                            ):
                                ok = True
                            elif isinstance(kw_default, tuple):
                                ok = all(
                                    [
                                        isinstance(getdef(x), ir.Const)
                                        for x in kw_default
                                    ]
                                )
                            elif isinstance(kw_default, ir.Expr):
                                if kw_default.op != "build_tuple":
                                    continue
                                ok = all(
                                    [
                                        isinstance(getdef(x), ir.Const)
                                        for x in kw_default.items
                                    ]
                                )
                            if not ok:
                                continue

                            pyfunc = convert_code_obj_to_function(node, func_ir)
                            func = njit()(pyfunc)
                            new_node = ir.Global(
                                node.code.co_name, func, stmt.loc
                            )
                            stmt.value = new_node
                            mutated |= True

        # if a change was made the del ordering is probably wrong, patch up
        if mutated:
            post_proc = postproc.PostProcessor(func_ir)
            post_proc.run()

        return mutated
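Aside: a minimal sketch of the kind of closure this pass rewrites. The lambda below appears in the IR as a `make_function` expression with constant-only defaults, so it can be swapped for an njit-compiled function held in an `ir.Global`:

# Illustrative sketch only: a closure inside a jitted function.
from numba import njit

@njit
def apply_twice(x):
    g = lambda v: v + 1  # make_function with no non-const defaults
    return g(g(x))

apply_twice(1)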
@register_pass(mutates_CFG=True, analysis_only=False)
class TransformLiteralUnrollConstListToTuple(FunctionPass):
    """This pass spots a `literal_unroll([<constant values>])` and rewrites it
    as a `literal_unroll(tuple(<constant values>))`.
    """

    _name = "transform_literal_unroll_const_list_to_tuple"

    _accepted_types = (types.BaseTuple, types.LiteralList)

    def __init__(self):
        FunctionPass.__init__(self)

    def run_pass(self, state):
        mutated = False
        func_ir = state.func_ir
        for label, blk in func_ir.blocks.items():
            calls = [_ for _ in blk.find_exprs("call")]
            for call in calls:
                glbl = guard(get_definition, func_ir, call.func)
                if glbl and isinstance(glbl, (ir.Global, ir.FreeVar)):
                    # find a literal_unroll
                    if glbl.value is literal_unroll:
                        if len(call.args) > 1:
                            msg = "literal_unroll takes one argument, found %s"
                            raise errors.UnsupportedError(
                                msg % len(call.args), call.loc
                            )
                        # get the arg, make sure its a build_list
                        unroll_var = call.args[0]
                        to_unroll = guard(get_definition, func_ir, unroll_var)
                        if (
                            isinstance(to_unroll, ir.Expr)
                            and to_unroll.op == "build_list"
                        ):
                            # make sure they are all const items in the list
                            for i, item in enumerate(to_unroll.items):
                                val = guard(get_definition, func_ir, item)
                                if not val:
                                    msg = (
                                        "multiple definitions for variable "
                                        "%s, cannot resolve constant"
                                    )
                                    raise errors.UnsupportedError(
                                        msg % item, to_unroll.loc
                                    )
                                if not isinstance(val, ir.Const):
                                    msg = (
                                        "Found non-constant value at "
                                        "position %s in a list argument to "
                                        "literal_unroll" % i
                                    )
                                    raise errors.UnsupportedError(
                                        msg, to_unroll.loc
                                    )
                            # The above appears ok, now swap the build_list for
                            # a built tuple.

                            # find the assignment for the unroll target
                            to_unroll_lhs = guard(
                                get_definition,
                                func_ir,
                                unroll_var,
                                lhs_only=True,
                            )

                            if to_unroll_lhs is None:
                                msg = (
                                    "multiple definitions for variable "
                                    "%s, cannot resolve constant"
                                )
                                raise errors.UnsupportedError(
                                    msg % unroll_var, to_unroll.loc
                                )
                            # scan all blocks looking for the LHS
                            for b in func_ir.blocks.values():
                                asgn = b.find_variable_assignment(
                                    to_unroll_lhs.name
                                )
                                if asgn is not None:
                                    break
                            else:
                                msg = (
                                    "Cannot find assignment for known "
                                    "variable %s"
                                ) % to_unroll_lhs.name
                                raise errors.CompilerError(msg, to_unroll.loc)

                            # Create a tuple with the list items as contents
                            tup = ir.Expr.build_tuple(
                                to_unroll.items, to_unroll.loc
                            )

                            # swap the list for the tuple
                            asgn.value = tup
                            mutated = True
                        elif (
                            isinstance(to_unroll, ir.Expr)
                            and to_unroll.op == "build_tuple"
                        ):
                            # this is fine, do nothing
                            pass
                        elif isinstance(
                            to_unroll, (ir.Global, ir.FreeVar)
                        ) and isinstance(to_unroll.value, tuple):
                            # this is fine, do nothing
                            pass
                        elif isinstance(to_unroll, ir.Arg):
                            # this is only fine if the arg is a tuple
                            ty = state.typemap[to_unroll.name]
                            if not isinstance(ty, self._accepted_types):
                                msg = (
                                    "Invalid use of literal_unroll with a "
                                    "function argument, only tuples are "
                                    "supported as function arguments, found "
                                    "%s"
                                ) % ty
                                raise errors.UnsupportedError(
                                    msg, to_unroll.loc
                                )
                        else:
                            extra = None
                            if isinstance(to_unroll, ir.Expr):
                                # probably a slice
                                if to_unroll.op == "getitem":
                                    ty = state.typemap[to_unroll.value.name]
                                    # check if this is a tuple slice
                                    if not isinstance(ty, self._accepted_types):
                                        extra = "operation %s" % to_unroll.op
                                        loc = to_unroll.loc
                            elif isinstance(to_unroll, ir.Arg):
                                extra = "non-const argument %s" % to_unroll.name
                                loc = to_unroll.loc
                            else:
                                if to_unroll is None:
                                    extra = (
                                        "multiple definitions of "
                                        'variable "%s".' % unroll_var.name
                                    )
                                    loc = unroll_var.loc
                                else:
                                    loc = to_unroll.loc
                                    extra = "unknown problem"

                            if extra:
                                msg = (
                                    "Invalid use of literal_unroll, "
                                    "argument should be a tuple or a list "
                                    "of constant values. Failure reason: "
                                    "found %s" % extra
                                )
                                raise errors.UnsupportedError(msg, loc)
        return mutated
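Aside: a minimal sketch, grounded in the docstring above, of the rewrite this pass performs on a constant-list argument to `literal_unroll`:

# Illustrative sketch: the build_list of constants below is rewritten to the
# tuple form before unrolling, i.e. it behaves as literal_unroll((1, 2.0)).
from numba import njit, literal_unroll

@njit
def scan_list():
    acc = 0.0
    for v in literal_unroll([1, 2.0]):  # constant list -> tuple
        acc += v
    return acc

scan_list()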
@register_pass(mutates_CFG=True, analysis_only=False)
class MixedContainerUnroller(FunctionPass):
    _name = "mixed_container_unroller"

    _DEBUG = False

    _accepted_types = (types.BaseTuple, types.LiteralList)

    def __init__(self):
        FunctionPass.__init__(self)

    def analyse_tuple(self, tup):
        """
        Returns a map of type->list(indexes) for a typed tuple
        """
        d = defaultdict(list)
        for i, ty in enumerate(tup):
            d[ty].append(i)
        return d

    def add_offset_to_labels_w_ignore(self, blocks, offset, ignore=None):
        """add an offset to all block labels and jump/branch targets
        don't add an offset to anything in the ignore list
        """
        if ignore is None:
            ignore = set()

        new_blocks = {}
        for l, b in blocks.items():
            # some parfor last blocks might be empty
            term = None
            if b.body:
                term = b.body[-1]
            if isinstance(term, ir.Jump):
                if term.target not in ignore:
                    b.body[-1] = ir.Jump(term.target + offset, term.loc)
            if isinstance(term, ir.Branch):
                if term.truebr not in ignore:
                    new_true = term.truebr + offset
                else:
                    new_true = term.truebr

                if term.falsebr not in ignore:
                    new_false = term.falsebr + offset
                else:
                    new_false = term.falsebr
                b.body[-1] = ir.Branch(term.cond, new_true, new_false, term.loc)
            new_blocks[l + offset] = b
        return new_blocks

    def inject_loop_body(
        self, switch_ir, loop_ir, caller_max_label, dont_replace, switch_data
    ):
        """
        Injects the "loop body" held in `loop_ir` into `switch_ir` wherever
        there is a statement of the form `SENTINEL.<int> = RHS`. It also:
        * Finds and then deliberately does not relabel non-local jumps so as to
          make the switch table suitable for injection into the IR from which
          the loop body was derived.
        * Looks for `typed_getitem` and wires them up to loop body version
          specific variables or, if possible, directly writes in their constant
          value at their use site.

        Args:
        - switch_ir, the switch table with SENTINELS as generated by
          self.gen_switch
        - loop_ir, the IR of the loop blocks (derived from the original func_ir)
        - caller_max_label, the maximum label in the func_ir caller
        - dont_replace, variables that should not be renamed (to handle
          references to variables that are incoming at the loop head/escaping at
          the loop exit.
        - switch_data, the switch table data used to generate the switch_ir,
          can be generated by self.analyse_tuple.

        Returns:
        - A type specific switch table with each case containing a versioned
          loop body suitable for injection as a replacement for the loop_ir.
        """

        # Switch IR came from code gen, immediately relabel to prevent
        # collisions with IR derived from the user code (caller)
        switch_ir.blocks = self.add_offset_to_labels_w_ignore(
            switch_ir.blocks, caller_max_label + 1
        )

        # Find the sentinels and validate the form
        sentinel_exits = set()
        sentinel_blocks = []
        for lbl, blk in switch_ir.blocks.items():
            for i, stmt in enumerate(blk.body):
                if isinstance(stmt, ir.Assign):
                    if "SENTINEL" in stmt.target.name:
                        sentinel_blocks.append(lbl)
                        sentinel_exits.add(blk.body[-1].target)
                        break

        assert len(sentinel_exits) == 1  # should only be 1 exit
        switch_ir.blocks.pop(sentinel_exits.pop())  # kill the exit, it's dead

        # find jumps that are non-local, we won't relabel these
        ignore_set = set()
        local_lbl = [x for x in loop_ir.blocks.keys()]
        for lbl, blk in loop_ir.blocks.items():
            for i, stmt in enumerate(blk.body):
                if isinstance(stmt, ir.Jump):
                    if stmt.target not in local_lbl:
                        ignore_set.add(stmt.target)
                if isinstance(stmt, ir.Branch):
                    if stmt.truebr not in local_lbl:
                        ignore_set.add(stmt.truebr)
                    if stmt.falsebr not in local_lbl:
                        ignore_set.add(stmt.falsebr)

        # make sure the generated switch table matches the switch data
        assert len(sentinel_blocks) == len(switch_data)

        # replace the sentinel_blocks with the loop body
        for lbl, branch_ty in zip(sentinel_blocks, switch_data.keys()):
            loop_blocks = deepcopy(loop_ir.blocks)
            # relabel blocks WRT switch table, each block replacement will shift
            # the maximum label
            max_label = max(switch_ir.blocks.keys())
            loop_blocks = self.add_offset_to_labels_w_ignore(
                loop_blocks, max_label + 1, ignore_set
            )

            # start label
            loop_start_lbl = min(loop_blocks.keys())

            # fix the typed_getitem locations in the loop blocks
            for blk in loop_blocks.values():
                new_body = []
                for stmt in blk.body:
                    if isinstance(stmt, ir.Assign):
                        if (
                            isinstance(stmt.value, ir.Expr)
                            and stmt.value.op == "typed_getitem"
                        ):
                            if isinstance(branch_ty, types.Literal):
                                scope = switch_ir.blocks[lbl].scope
                                new_const_name = scope.redefine(
                                    "branch_const", stmt.loc
                                ).name
                                new_const_var = ir.Var(
                                    blk.scope, new_const_name, stmt.loc
                                )
                                new_const_val = ir.Const(
                                    branch_ty.literal_value, stmt.loc
                                )
                                const_assign = ir.Assign(
                                    new_const_val, new_const_var, stmt.loc
                                )
                                new_assign = ir.Assign(
                                    new_const_var, stmt.target, stmt.loc
                                )
                                new_body.append(const_assign)
                                new_body.append(new_assign)
                                dont_replace.append(new_const_name)
                            else:
                                orig = stmt.value
                                new_typed_getitem = ir.Expr.typed_getitem(
                                    value=orig.value,
                                    dtype=branch_ty,
                                    index=orig.index,
                                    loc=orig.loc,
                                )
                                new_assign = ir.Assign(
                                    new_typed_getitem, stmt.target, stmt.loc
                                )
                                new_body.append(new_assign)
                        else:
                            new_body.append(stmt)
                    else:
                        new_body.append(stmt)
                blk.body = new_body

            # rename
            var_table = get_name_var_table(loop_blocks)
            drop_keys = []
            for k, v in var_table.items():
                if v.name in dont_replace:
                    drop_keys.append(k)
            for k in drop_keys:
                var_table.pop(k)

            new_var_dict = {}
            for name, var in var_table.items():
                scope = switch_ir.blocks[lbl].scope
                try:
                    scope.get_exact(name)
                except errors.NotDefinedError:
                    # In case the scope doesn't have the variable, we need to
                    # define it prior to creating new copies of it! This is
                    # because the scope of the function and the scope of the
                    # loop are different and the variable needs to be redefined
                    # within the scope of the loop.
                    scope.define(name, var.loc)
                new_var_dict[name] = scope.redefine(name, var.loc).name
            replace_var_names(loop_blocks, new_var_dict)

            # clobber the sentinel body and then stuff in the rest
            switch_ir.blocks[lbl] = deepcopy(loop_blocks[loop_start_lbl])
            remaining_keys = [y for y in loop_blocks.keys()]
            remaining_keys.remove(loop_start_lbl)
            for k in remaining_keys:
                switch_ir.blocks[k] = deepcopy(loop_blocks[k])

        if self._DEBUG:
            print("-" * 80 + "EXIT STUFFER")
            switch_ir.dump()
            print("-" * 80)

        return switch_ir

    def gen_switch(self, data, index):
        """
        Generates a function with a switch table like
        def foo():
            if PLACEHOLDER_INDEX in (<integers>):
                SENTINEL = None
            elif PLACEHOLDER_INDEX in (<integers>):
                SENTINEL = None
            ...
            else:
                raise RuntimeError

        The data is a map of (type : indexes) for example:
        (int64, int64, float64)
        might give:
        {int64: [0, 1], float64: [2]}

        The index is the index variable for the driving range loop over the
        mixed tuple.
        """
        elif_tplt = "\n\telif PLACEHOLDER_INDEX in (%s,):\n\t\tSENTINEL = None"

        # Note regarding the insertion of the garbage/defeat variables below:
        # These values have been designed and inserted to defeat a specific
        # behaviour of the cpython optimizer. The optimization was introduced
        # in Python 3.10.

        # The URL for the BPO is:
        # https://bugs.python.org/issue44626
        # The code for the optimization can be found at:
        # https://github.com/python/cpython/blob/d41abe8/Python/compile.c#L7533-L7557

        # Essentially the CPython optimizer will inline the exit block under
        # certain circumstances and thus replace the jump with a return if the
        # exit block is small enough. This is an issue for the unroller, as it
        # looks for a jump, not a return, when it inserts the generated switch
        # table.

        # Part of the condition for this optimization to be applied is that the
        # exit block not exceed a certain (4 at the time of writing) number of
        # bytecode instructions. We defeat the optimizer by inserting a
        # sufficient number of instructions so that the exit block is big
        # enough. We don't care about this garbage, because the generated exit
        # block is discarded anyway when we smash the switch table into the
        # original function and so all the inserted garbage is dropped again.

        # The final lines of the stacktrace w/o this will look like:
        #
        # File "/numba/numba/core/untyped_passes.py", line 830, \
        # in inject_loop_body
        # sentinel_exits.add(blk.body[-1].target)
        # AttributeError: Failed in nopython mode pipeline \
        # (step: handles literal_unroll)
        # Failed in literal_unroll_subpipeline mode pipeline \
        # (step: performs mixed container unroll)
        # 'Return' object has no attribute 'target'
        #
        # Which indicates that a Return has been found instead of a Jump

        b = (
            "def foo():\n\tif PLACEHOLDER_INDEX in (%s,):\n\t\t"
            "SENTINEL = None\n%s\n\telse:\n\t\t"
            'raise RuntimeError("Unreachable")\n\t'
            "py310_defeat1 = 1\n\t"
            "py310_defeat2 = 2\n\t"
            "py310_defeat3 = 3\n\t"
            "py310_defeat4 = 4\n\t"
        )
        keys = [k for k in data.keys()]

        elifs = []
        for i in range(1, len(keys)):
            elifs.append(elif_tplt % ",".join(map(str, data[keys[i]])))
        src = b % (",".join(map(str, data[keys[0]])), "".join(elifs))
        wstr = src
        l = {}
        exec(wstr, {}, l)
        bfunc = l["foo"]
        branches = compile_to_numba_ir(bfunc, {})
        for lbl, blk in branches.blocks.items():
            for stmt in blk.body:
                if isinstance(stmt, ir.Assign):
                    if isinstance(stmt.value, ir.Global):
                        if stmt.value.name == "PLACEHOLDER_INDEX":
                            stmt.value = index
        return branches

    def apply_transform(self, state):
        # compute new CFG
        func_ir = state.func_ir
        cfg = compute_cfg_from_blocks(func_ir.blocks)
        # find loops
        loops = cfg.loops()

        # 0. Find the loops containing literal_unroll and store this
        # information
        unroll_info = namedtuple(
            "unroll_info", ["loop", "call", "arg", "getitem"]
        )

        def get_call_args(init_arg, want):
            # Chases the assignment of a called value back through a specific
            # call to a global function "want" and returns the arguments
            # supplied to that function's call
            some_call = get_definition(func_ir, init_arg)
            if not isinstance(some_call, ir.Expr):
                raise GuardException
            if not some_call.op == "call":
                raise GuardException
            the_global = get_definition(func_ir, some_call.func)
            if not isinstance(the_global, ir.Global):
                raise GuardException
            if the_global.value is not want:
                raise GuardException
            return some_call

        def find_unroll_loops(loops):
            """This finds loops which are compliant with the form:
            for i in range(len(literal_unroll(<something>)))"""
            unroll_loops = {}
            for header_lbl, loop in loops.items():
                # TODO: check the loop head has literal_unroll, if it does but
                # does not conform to the following then raise

                # scan loop header
                iternexts = [
                    _
                    for _ in func_ir.blocks[loop.header].find_exprs("iternext")
                ]
                # needs to be a single iternext driven loop
                if len(iternexts) != 1:
                    continue
                for iternext in iternexts:
                    # Walk the canonicalised loop structure and check it
                    # Check loop form range(len(literal_unroll(container)))
                    phi = guard(get_definition, func_ir, iternext.value)
                    if phi is None:
                        continue

                    # check call global "range"
                    range_call = guard(get_call_args, phi.value, range)
                    if range_call is None:
                        continue
                    range_arg = range_call.args[0]

                    # check call global "len"
                    len_call = guard(get_call_args, range_arg, len)
                    if len_call is None:
                        continue
                    len_arg = len_call.args[0]

                    # check literal_unroll
                    literal_unroll_call = guard(
                        get_definition, func_ir, len_arg
                    )
                    if literal_unroll_call is None:
                        continue
                    if not isinstance(literal_unroll_call, ir.Expr):
                        continue
                    if literal_unroll_call.op != "call":
                        continue
                    literal_func = getattr(literal_unroll_call, "func", None)
                    if not literal_func:
                        continue
                    call_func = guard(
                        get_definition, func_ir, literal_unroll_call.func
                    )
                    if call_func is None:
                        continue
                    call_func_value = call_func.value

                    if call_func_value is literal_unroll:
                        assert len(literal_unroll_call.args) == 1
                        unroll_loops[loop] = literal_unroll_call
            return unroll_loops

        def ensure_no_nested_unroll(unroll_loops):
            # Validate loop nests, nested literal_unroll loops are unsupported.
            # This doesn't check that there's a getitem or anything else
            # required for the transform to work, simply just that there's no
            # nesting.
            for test_loop in unroll_loops:
                for ref_loop in unroll_loops:
                    if test_loop == ref_loop:  # comparing to self! skip
                        continue
                    if test_loop.header in ref_loop.body:
                        msg = "Nesting of literal_unroll is unsupported"
                        loc = func_ir.blocks[test_loop.header].loc
                        raise errors.UnsupportedError(msg, loc)

        def collect_literal_unroll_info(literal_unroll_loops):
            """Finds the loops induced by `literal_unroll`, returns a list of
            unroll_info namedtuples for use in the transform pass.
            """

            literal_unroll_info = []
            for loop, literal_unroll_call in literal_unroll_loops.items():
                arg = literal_unroll_call.args[0]
                typemap = state.typemap
                resolved_arg = guard(
                    get_definition, func_ir, arg, lhs_only=True
                )
                ty = typemap[resolved_arg.name]
                assert isinstance(ty, self._accepted_types)
                # loop header is spelled ok, now make sure the body
                # actually contains a getitem

                # find a "getitem"... only looks in the blocks that belong
                # _solely_ to this literal_unroll (there should not be nested
                # literal_unroll loops, this is unsupported).
                tuple_getitem = None
                for lbli in loop.body:
                    blk = func_ir.blocks[lbli]
                    for stmt in blk.body:
                        if isinstance(stmt, ir.Assign):
                            if (
                                isinstance(stmt.value, ir.Expr)
                                and stmt.value.op == "getitem"
                            ):
                                # check for something like a[i]
                                if stmt.value.value != arg:
                                    # that failed, so check for the
                                    # definition
                                    dfn = guard(
                                        get_definition,
                                        func_ir,
                                        stmt.value.value,
                                    )
                                    if dfn is None:
                                        continue
                                    try:
                                        args = getattr(dfn, "args", False)
                                    except KeyError:
                                        continue
                                    if not args:
                                        continue
                                    if not args[0] == arg:
                                        continue
                                target_ty = state.typemap[arg.name]
                                if not isinstance(
                                    target_ty, self._accepted_types
                                ):
                                    continue
                                tuple_getitem = stmt
                                break
                    if tuple_getitem:
                        break
                else:
                    continue  # no getitem in this loop

                ui = unroll_info(loop, literal_unroll_call, arg, tuple_getitem)
                literal_unroll_info.append(ui)
            return literal_unroll_info

        # 1. Collect info about the literal_unroll loops, ensure they are legal
        literal_unroll_loops = find_unroll_loops(loops)
        # validate
        ensure_no_nested_unroll(literal_unroll_loops)
        # assemble info
        literal_unroll_info = collect_literal_unroll_info(literal_unroll_loops)
        if not literal_unroll_info:
            return False

        # 2. Do the unroll, get a loop and process it!
|
|
1320
|
+
info = literal_unroll_info[0]
|
|
1321
|
+
self.unroll_loop(state, info)
|
|
1322
|
+
|
|
1323
|
+
# 3. Rebuild the state, the IR has taken a hammering
|
|
1324
|
+
func_ir.blocks = simplify_CFG(func_ir.blocks)
|
|
1325
|
+
post_proc = postproc.PostProcessor(func_ir)
|
|
1326
|
+
post_proc.run()
|
|
1327
|
+
if self._DEBUG:
|
|
1328
|
+
print("-" * 80 + "END OF PASS, SIMPLIFY DONE")
|
|
1329
|
+
func_ir.dump()
|
|
1330
|
+
func_ir._definitions = build_definitions(func_ir.blocks)
|
|
1331
|
+
return True
|
|
1332
|
+
|
|
1333
|
+
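    # [Editor's note] A minimal sketch of the source-level loop shape that
    # find_unroll_loops() above matches. This example is an editorial
    # addition (not part of the diffed file), assuming Numba's public
    # `literal_unroll` API:
    #
    #     from numba import njit, literal_unroll
    #
    #     @njit
    #     def f():
    #         tup = (1, 2.5, 3j)  # heterogeneous tuple
    #         acc = 0.0
    #         for i in range(len(literal_unroll(tup))):
    #             acc += abs(tup[i])  # getitem with the loop-induced index
    #         return acc
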
    def unroll_loop(self, state, loop_info):
        # The general idea here is to:
        # 1. Find *a* getitem that conforms to the literal_unroll semantic,
        #    i.e. one that is targeting a tuple with a loop induced index
        # 2. Compute a structure from the tuple that describes which
        #    iterations of a loop will have which type
        # 3. Generate a switch table in IR form for the structure in 2
        # 4. Switch out getitems for the tuple for a `typed_getitem`
        # 5. Inject switch table as replacement loop body
        # 6. Patch up
        func_ir = state.func_ir
        getitem_target = loop_info.arg
        target_ty = state.typemap[getitem_target.name]
        assert isinstance(target_ty, self._accepted_types)

        # 1. find a "getitem" that conforms
        tuple_getitem = []
        for lbl in loop_info.loop.body:
            blk = func_ir.blocks[lbl]
            for stmt in blk.body:
                if isinstance(stmt, ir.Assign):
                    if (
                        isinstance(stmt.value, ir.Expr)
                        and stmt.value.op == "getitem"
                    ):
                        # try a couple of spellings... a[i] and ref(a)[i]
                        if stmt.value.value != getitem_target:
                            dfn = func_ir.get_definition(stmt.value.value)
                            try:
                                args = getattr(dfn, "args", False)
                            except KeyError:
                                continue
                            if not args:
                                continue
                            if not args[0] == getitem_target:
                                continue
                            target_ty = state.typemap[getitem_target.name]
                            if not isinstance(target_ty, self._accepted_types):
                                continue
                        tuple_getitem.append(stmt)

        if not tuple_getitem:
            msg = (
                "Loop unrolling analysis has failed, there's no getitem "
                "in loop body that conforms to literal_unroll "
                "requirements."
            )
            LOC = func_ir.blocks[loop_info.loop.header].loc
            raise errors.CompilerError(msg, LOC)

        # 2. get switch data
        switch_data = self.analyse_tuple(target_ty)

        # 3. generate switch IR
        index = func_ir._definitions[tuple_getitem[0].value.index.name][0]
        branches = self.gen_switch(switch_data, index)

        # 4. swap getitems for a typed_getitem, these are actually just
        # placeholders at this point. When the loop is duplicated they can
        # be swapped for a typed_getitem of the correct type or if the item
        # is literal it can be shoved straight into the duplicated loop body
        for item in tuple_getitem:
            old = item.value
            new = ir.Expr.typed_getitem(
                old.value, types.void, old.index, old.loc
            )
            item.value = new

        # 5. Inject switch table

        # Find the actual loop without the header (that won't get replaced)
        # and derive some new IR for this set of blocks
        this_loop = loop_info.loop
        this_loop_body = this_loop.body - set([this_loop.header])
        loop_blocks = {x: func_ir.blocks[x] for x in this_loop_body}
        new_ir = func_ir.derive(loop_blocks)

        # Work out what is live on entry and exit so as to prevent
        # replacement (defined vars can escape, used vars live at the header
        # need to remain as-is so their references are correct, they can
        # also escape).

        usedefs = compute_use_defs(func_ir.blocks)
        idx = this_loop.header
        keep = set()
        keep |= usedefs.usemap[idx] | usedefs.defmap[idx]
        keep |= func_ir.variable_lifetime.livemap[idx]
        dont_replace = [x for x in (keep)]

        # compute the unrolled body
        unrolled_body = self.inject_loop_body(
            branches,
            new_ir,
            max(func_ir.blocks.keys()) + 1,
            dont_replace,
            switch_data,
        )

        # 6. Patch in the unrolled body and fix up
        blks = state.func_ir.blocks
        the_scope = next(iter(blks.values())).scope
        orig_lbl = tuple(this_loop_body)

        replace, *delete = orig_lbl
        unroll, header_block = unrolled_body, this_loop.header
        unroll_lbl = [x for x in sorted(unroll.blocks.keys())]
        blks[replace] = transfer_scope(unroll.blocks[unroll_lbl[0]], the_scope)
        [blks.pop(d) for d in delete]
        for k in unroll_lbl[1:]:
            blks[k] = transfer_scope(unroll.blocks[k], the_scope)
        # stitch up the loop predicate true -> new loop body jump
        blks[header_block].body[-1].truebr = replace

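    # [Editor's note] Conceptually (an editorial sketch, not the actual IR
    # produced by gen_switch/inject_loop_body), unrolling the body over
    # tup = (1, 2.5) yields a per-iteration type switch equivalent to:
    #
    #     for i in range(2):
    #         if i == 0:
    #             x = tup[0]   # typed_getitem resolved as an integer
    #             ...          # loop body specialised for that type
    #         elif i == 1:
    #             x = tup[1]   # typed_getitem resolved as a float
    #             ...          # loop body specialised for that type
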
    def run_pass(self, state):
        mutated = False
        func_ir = state.func_ir
        # first limit the work by squashing the CFG if possible
        func_ir.blocks = simplify_CFG(func_ir.blocks)

        if self._DEBUG:
            print("-" * 80 + "PASS ENTRY")
            func_ir.dump()
            print("-" * 80)

        # limitations:
        # 1. No nested unrolls
        # 2. Opt in via `numba.literal_unroll`
        # 3. No multiple mix-tuple use

        # keep running the transform loop until it reports no more changes
        while True:
            stat = self.apply_transform(state)
            mutated |= stat
            if not stat:
                break

        # reset type inference now we are done with the partial results
        state.typemap = {}
        state.calltypes = None

        return mutated


@register_pass(mutates_CFG=True, analysis_only=False)
class IterLoopCanonicalization(FunctionPass):
    """Transforms loops that are induced by `getiter` into range() driven loops
    If the typemap is available this will only impact Tuple and UniTuple, if it
    is not available it will impact all matching loops.
    """

    _name = "iter_loop_canonicalisation"

    _DEBUG = False

    # if partial typing info is available it will only look at these types
    _accepted_types = (types.BaseTuple, types.LiteralList)
    _accepted_calls = (literal_unroll,)

    def __init__(self):
        FunctionPass.__init__(self)

    def assess_loop(self, loop, func_ir, partial_typemap=None):
        # it's an iter loop if:
        # - loop header is driven by an iternext
        # - the iternext value is a phi derived from getiter()

        # check header
        iternexts = [
            _ for _ in func_ir.blocks[loop.header].find_exprs("iternext")
        ]
        if len(iternexts) != 1:
            return False
        for iternext in iternexts:
            phi = guard(get_definition, func_ir, iternext.value)
            if phi is None:
                return False
            if getattr(phi, "op", False) == "getiter":
                if partial_typemap:
                    # check that the call site is accepted, until we're
                    # confident that tuple unrolling is behaving require opt-in
                    # guard of `literal_unroll`, remove this later!
                    phi_val_defn = guard(get_definition, func_ir, phi.value)
                    if not isinstance(phi_val_defn, ir.Expr):
                        return False
                    if not phi_val_defn.op == "call":
                        return False
                    call = guard(get_definition, func_ir, phi_val_defn)
                    if call is None or len(call.args) != 1:
                        return False
                    func_var = guard(get_definition, func_ir, call.func)
                    func = guard(get_definition, func_ir, func_var)
                    if func is None or not isinstance(
                        func, (ir.Global, ir.FreeVar)
                    ):
                        return False
                    if (
                        func.value is None
                        or func.value not in self._accepted_calls
                    ):
                        return False

                    # now check the type is supported
                    ty = partial_typemap.get(call.args[0].name, None)
                    if ty and isinstance(ty, self._accepted_types):
                        return len(loop.entries) == 1
                else:
                    return len(loop.entries) == 1

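    # [Editor's note] An editorial sketch of what assess_loop() above
    # accepts: with a partial typemap, only single-entry loops whose
    # getiter feeds off a call in _accepted_calls are transformed, e.g.:
    #
    #     for x in literal_unroll(tup):   # accepted: opted-in call
    #         ...
    #
    #     for x in mylist:                # rejected: no literal_unroll guard
    #         ...
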
    def transform(self, loop, func_ir, cfg):
        def get_range(a):
            return range(len(a))

        iternext = [
            _ for _ in func_ir.blocks[loop.header].find_exprs("iternext")
        ][0]
        LOC = func_ir.blocks[loop.header].loc
        scope = func_ir.blocks[loop.header].scope
        get_range_var = scope.redefine("CANONICALISER_get_range_gbl", LOC)
        get_range_global = ir.Global("get_range", get_range, LOC)
        assgn = ir.Assign(get_range_global, get_range_var, LOC)

        loop_entry = tuple(loop.entries)[0]
        entry_block = func_ir.blocks[loop_entry]
        entry_block.body.insert(0, assgn)

        iterarg = guard(get_definition, func_ir, iternext.value)
        if iterarg is not None:
            iterarg = iterarg.value

        # look for the getiter in the entry block
        idx = 0
        for stmt in entry_block.body:
            if isinstance(stmt, ir.Assign):
                if (
                    isinstance(stmt.value, ir.Expr)
                    and stmt.value.op == "getiter"
                ):
                    break
            idx += 1
        else:
            raise ValueError("problem")

        # create a range(len(tup)) and inject it
        call_get_range_var = scope.redefine("CANONICALISER_call_get_range", LOC)
        make_call = ir.Expr.call(get_range_var, (stmt.value.value,), (), LOC)
        assgn_call = ir.Assign(make_call, call_get_range_var, LOC)
        entry_block.body.insert(idx, assgn_call)
        entry_block.body[idx + 1].value.value = call_get_range_var

        glbls = copy(func_ir.func_id.func.__globals__)

        inline_closurecall.inline_closure_call(
            func_ir,
            glbls,
            entry_block,
            idx,
            get_range,
        )
        kill = entry_block.body.index(assgn)
        entry_block.body.pop(kill)

        # find the induction variable + references in the loop header
        # fixed point iter to do this, it's a bit clunky
        induction_vars = set()
        header_block = func_ir.blocks[loop.header]

        # find induction var
        ind = [x for x in header_block.find_exprs("pair_first")]
        for x in ind:
            induction_vars.add(func_ir.get_assignee(x, loop.header))
        # find aliases of the induction var
        tmp = set()
        for x in induction_vars:
            try:  # there's not always an alias, e.g. loop from inlined closure
                tmp.add(func_ir.get_assignee(x, loop.header))
            except ValueError:
                pass
        induction_vars |= tmp
        induction_var_names = set([x.name for x in induction_vars])

        # Find the downstream blocks that might reference the induction var
        succ = set()
        for lbl in loop.exits:
            succ |= set([x[0] for x in cfg.successors(lbl)])
        check_blocks = (loop.body | loop.exits | succ) ^ {loop.header}

        # replace RHS use of induction var with getitem
        for lbl in check_blocks:
            for stmt in func_ir.blocks[lbl].body:
                if isinstance(stmt, ir.Assign):
                    # check for aliases
                    try:
                        lookup = getattr(stmt.value, "name", None)
                    except KeyError:
                        continue
                    if lookup and lookup in induction_var_names:
                        stmt.value = ir.Expr.getitem(
                            iterarg, stmt.value, stmt.loc
                        )

        post_proc = postproc.PostProcessor(func_ir)
        post_proc.run()

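    # [Editor's note] A source-level sketch (editorial addition; the real
    # rewrite happens on Numba IR) of the canonicalisation performed above:
    #
    #     # before: getiter/iternext driven
    #     for x in literal_unroll(tup):
    #         use(x)
    #
    #     # after: range driven, induction-variable uses become getitems
    #     for i in range(len(tup)):
    #         use(tup[i])
    #
    # get_range() is inlined via inline_closure_call so the range call is
    # real IR rather than a call to a runtime closure.
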
    def run_pass(self, state):
        func_ir = state.func_ir
        cfg = compute_cfg_from_blocks(func_ir.blocks)
        loops = cfg.loops()

        mutated = False
        for header, loop in loops.items():
            stat = self.assess_loop(loop, func_ir, state.typemap)
            if stat:
                if self._DEBUG:
                    print("Canonicalising loop", loop)
                self.transform(loop, func_ir, cfg)
                mutated = True
            else:
                if self._DEBUG:
                    print("NOT Canonicalising loop", loop)

        func_ir.blocks = simplify_CFG(func_ir.blocks)
        return mutated


@register_pass(mutates_CFG=False, analysis_only=False)
class PropagateLiterals(FunctionPass):
    """Implement literal propagation based on partial type inference"""

    _name = "PropagateLiterals"

    def __init__(self):
        FunctionPass.__init__(self)

    def get_analysis_usage(self, AU):
        AU.add_required(ReconstructSSA)

    def run_pass(self, state):
        func_ir = state.func_ir
        typemap = state.typemap
        flags = state.flags

        accepted_functions = ("isinstance", "hasattr")

        if not hasattr(func_ir, "_definitions") and not flags.enable_ssa:
            func_ir._definitions = build_definitions(func_ir.blocks)

        changed = False

        for block in func_ir.blocks.values():
            for assign in block.find_insts(ir.Assign):
                value = assign.value
                if isinstance(value, (ir.Arg, ir.Const, ir.FreeVar, ir.Global)):
                    continue

                # 1) Don't change return stmt in the form
                #    $return_xyz = cast(value=ABC)
                # 2) Don't propagate literal values that are not primitives
                if isinstance(value, ir.Expr) and value.op in (
                    "cast",
                    "build_map",
                    "build_list",
                    "build_tuple",
                    "build_set",
                ):
                    continue

                target = assign.target
                if not flags.enable_ssa:
                    # SSA is disabled when doing inlining
                    if guard(get_definition, func_ir, target.name) is None:  # noqa: E501
                        continue

                # Numba cannot safely determine if an isinstance call
                # with a PHI node is True/False. For instance, in
                # the case below, the partial type inference step can coerce
                # '$z' to float, so any call to 'isinstance(z, int)' would fail.
                #
                # def fn(x):
                #     if x > 4:
                #         z = 1
                #     else:
                #         z = 3.14
                #     if isinstance(z, int):
                #         print('int')
                #     else:
                #         print('float')
                #
                # At the moment, we avoid propagating the literal
                # value if the argument is a PHI node

                if isinstance(value, ir.Expr) and value.op == "call":
                    fn = guard(get_definition, func_ir, value.func.name)
                    if fn is None:
                        continue

                    if not (
                        isinstance(fn, ir.Global)
                        and fn.name in accepted_functions
                    ):
                        continue

                    for arg in value.args:
                        # check if any of the args to isinstance is a PHI node
                        iv = func_ir._definitions[arg.name]
                        assert len(iv) == 1  # SSA!
                        if isinstance(iv[0], ir.Expr) and iv[0].op == "phi":
                            msg = (
                                f"{fn.name}() cannot determine the "
                                f'type of variable "{arg.unversioned_name}" '
                                "due to a branch."
                            )
                            raise errors.NumbaTypeError(msg, loc=assign.loc)

                # Only propagate a PHI node if all arguments are the same
                # constant
                if isinstance(value, ir.Expr) and value.op == "phi":
                    # typemap will return None in case `inc.name` not in typemap
                    v = [typemap.get(inc.name) for inc in value.incoming_values]
                    # stop if the elements in `v` do not hold the same value
                    if v[0] is not None and any([v[0] != vi for vi in v]):
                        continue

                lit = typemap.get(target.name, None)
                if lit and isinstance(lit, types.Literal):
                    # replace assign instruction by ir.Const(lit) iff
                    # lit is a literal value
                    rhs = ir.Const(lit.literal_value, assign.loc)
                    new_assign = ir.Assign(rhs, target, assign.loc)

                    # replace instruction
                    block.insert_after(new_assign, assign)
                    block.remove(assign)

                    changed = True

        # reset type inference now we are done with the partial results
        state.typemap = None
        state.calltypes = None

        if changed:
            # Rebuild definitions
            func_ir._definitions = build_definitions(func_ir.blocks)

        return changed

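# [Editor's note] An editorial sketch of the rewrite performed above: if
# partial type inference assigns a Literal type to an assignment target,
# the whole right-hand side is collapsed to a constant, e.g.
#
#     $14 = call isinstance(x, int)   # typemap[$14] is a literal True
#
# becomes
#
#     $14 = const(True)
#
# leaving a constant predicate that DeadBranchPrune can later eliminate.
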
@register_pass(mutates_CFG=True, analysis_only=False)
class LiteralPropagationSubPipelinePass(FunctionPass):
    """Implement literal propagation based on partial type inference"""

    _name = "LiteralPropagation"

    def __init__(self):
        FunctionPass.__init__(self)

    def run_pass(self, state):
        # Determine whether to even attempt this pass... if there's no
        # `isinstance` as a global or as a freevar then just skip.

        found = False
        func_ir = state.func_ir
        for blk in func_ir.blocks.values():
            for asgn in blk.find_insts(ir.Assign):
                if isinstance(asgn.value, (ir.Global, ir.FreeVar)):
                    value = asgn.value.value
                    if value is isinstance or value is hasattr:
                        found = True
                        break
            if found:
                break
        if not found:
            return False

        # run as subpipeline
        from numba.cuda.core.compiler_machinery import PassManager
        from numba.cuda.core.typed_passes import PartialTypeInference

        pm = PassManager("literal_propagation_subpipeline")

        pm.add_pass(PartialTypeInference, "performs partial type inference")
        pm.add_pass(PropagateLiterals, "performs propagation of literal values")

        # rewrite consts / dead branch pruning
        pm.add_pass(RewriteSemanticConstants, "rewrite semantic constants")
        pm.add_pass(DeadBranchPrune, "dead branch pruning")

        pm.finalize()
        pm.run(state)
        return True

    def get_analysis_usage(self, AU):
        AU.add_required(ReconstructSSA)

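# [Editor's note] A minimal usage sketch (editorial addition, assuming
# Numba's public API) of the kind of code this subpipeline supports:
#
#     from numba import njit
#
#     @njit
#     def f(x):
#         if isinstance(x, int):
#             return x + 1
#         else:
#             return x * 2.0
#
# For each specialisation the isinstance() test becomes a literal after
# PartialTypeInference + PropagateLiterals, and DeadBranchPrune then
# removes the untaken branch.
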
@register_pass(mutates_CFG=True, analysis_only=False)
class LiteralUnroll(FunctionPass):
    """Implement the literal_unroll semantics"""

    _name = "literal_unroll"

    def __init__(self):
        FunctionPass.__init__(self)

    def run_pass(self, state):
        # Determine whether to even attempt this pass... if there's no
        # `literal_unroll` as a global or as a freevar then just skip.
        found = False
        func_ir = state.func_ir
        for blk in func_ir.blocks.values():
            for asgn in blk.find_insts(ir.Assign):
                if isinstance(asgn.value, (ir.Global, ir.FreeVar)):
                    if asgn.value.value is literal_unroll:
                        found = True
                        break
            if found:
                break
        if not found:
            return False

        # run as subpipeline
        from numba.cuda.core.compiler_machinery import PassManager
        from numba.cuda.core.typed_passes import PartialTypeInference

        pm = PassManager("literal_unroll_subpipeline")
        # get types where possible to help with list->tuple change
        pm.add_pass(PartialTypeInference, "performs partial type inference")
        # make const lists tuples
        pm.add_pass(
            TransformLiteralUnrollConstListToTuple,
            "switch const list for tuples",
        )
        # recompute partial typemap following IR change
        pm.add_pass(PartialTypeInference, "performs partial type inference")
        # canonicalise loops
        pm.add_pass(
            IterLoopCanonicalization, "switch iter loops for range driven loops"
        )
        # rewrite consts
        pm.add_pass(RewriteSemanticConstants, "rewrite semantic constants")
        # do the unroll
        pm.add_pass(MixedContainerUnroller, "performs mixed container unroll")
        # rewrite dynamic getitem to static getitem as it's possible some more
        # getitems will now be statically resolvable
        pm.add_pass(GenericRewrites, "Generic Rewrites")
        pm.add_pass(RewriteSemanticConstants, "rewrite semantic constants")
        pm.finalize()
        pm.run(state)
        return True

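# [Editor's note] A minimal usage sketch (editorial addition, assuming the
# public API) of the semantics the subpipeline above implements:
#
#     from numba import njit, literal_unroll
#
#     @njit
#     def f():
#         out = 0.0
#         for x in literal_unroll((1, 2.5, True)):  # mixed-type tuple
#             out += x
#         return out
#
# IterLoopCanonicalization turns the iterator loop into a range loop and
# MixedContainerUnroller then versions the body once per element type.
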
@register_pass(mutates_CFG=True, analysis_only=False)
class SimplifyCFG(FunctionPass):
    """Perform CFG simplification"""

    _name = "simplify_cfg"

    def __init__(self):
        FunctionPass.__init__(self)

    def run_pass(self, state):
        blks = state.func_ir.blocks
        new_blks = simplify_CFG(blks)
        state.func_ir.blocks = new_blks
        mutated = blks != new_blks
        return mutated


@register_pass(mutates_CFG=False, analysis_only=False)
class ReconstructSSA(FunctionPass):
    """Perform SSA-reconstruction

    Produces minimal SSA.
    """

    _name = "reconstruct_ssa"

    def __init__(self):
        FunctionPass.__init__(self)

    def run_pass(self, state):
        state.func_ir = reconstruct_ssa(state.func_ir)
        self._patch_locals(state)

        # Rebuild definitions
        state.func_ir._definitions = build_definitions(state.func_ir.blocks)

        # Rerun postprocessor to update metadata, e.g. generator_info
        post_proc = postproc.PostProcessor(state.func_ir)
        post_proc.run(emit_dels=False)

        if config.DEBUG or config.DUMP_SSA:
            name = state.func_ir.func_id.func_qualname
            print(f"SSA IR DUMP: {name}".center(80, "-"))
            state.func_ir.dump()

        return True  # XXX detect if it actually got changed

    def _patch_locals(self, state):
        # Fix dispatcher locals dictionary type annotation
        locals_dict = state.get("locals")
        if locals_dict is None:
            return

        first_blk, *_ = state.func_ir.blocks.values()
        scope = first_blk.scope
        for parent, redefs in scope.var_redefinitions.items():
            if parent in locals_dict:
                typ = locals_dict[parent]
                for derived in redefs:
                    locals_dict[derived] = typ

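# [Editor's note] An editorial sketch of why _patch_locals() is needed:
# SSA reconstruction versions redefinitions (x -> x, x.1, x.2, ...), so a
# dispatcher-level annotation such as
#
#     @njit(locals={"x": numba.float32})
#
# is copied onto each derived name; otherwise the constraint would apply
# only to the first definition of x.
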
@register_pass(mutates_CFG=False, analysis_only=False)
class RewriteDynamicRaises(FunctionPass):
    """Replace existing raise statements by dynamic raises in Numba IR."""

    _name = "Rewrite dynamic raises"

    def __init__(self):
        FunctionPass.__init__(self)

    def run_pass(self, state):
        func_ir = state.func_ir
        changed = False

        for block in func_ir.blocks.values():
            for raise_ in block.find_insts((ir.Raise, ir.TryRaise)):
                call_inst = guard(get_definition, func_ir, raise_.exception)
                if call_inst is None:
                    continue
                exc_type = func_ir.infer_constant(call_inst.func.name)
                exc_args = []
                for exc_arg in call_inst.args:
                    try:
                        const = func_ir.infer_constant(exc_arg)
                        exc_args.append(const)
                    except consts.ConstantInferenceError:
                        exc_args.append(exc_arg)
                loc = raise_.loc

                cls = {
                    ir.TryRaise: ir.DynamicTryRaise,
                    ir.Raise: ir.DynamicRaise,
                }[type(raise_)]

                dyn_raise = cls(exc_type, tuple(exc_args), loc)
                block.insert_after(dyn_raise, raise_)
                block.remove(raise_)
                changed = True
        return changed

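# [Editor's note] An editorial sketch of what this rewrite enables:
# exception arguments that are only known at runtime, e.g.
#
#     @njit
#     def f(x):
#         if x < 0:
#             raise ValueError("expected non-negative, got", x)
#
# Here `x` is not a compile-time constant, so the ir.Raise is swapped for
# an ir.DynamicRaise that carries the variable through to runtime.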