numba-cuda 0.19.1__py3-none-any.whl → 0.20.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of numba-cuda might be problematic. Click here for more details.
- numba_cuda/VERSION +1 -1
- numba_cuda/numba/cuda/__init__.py +1 -1
- numba_cuda/numba/cuda/_internal/cuda_bf16.py +12706 -1470
- numba_cuda/numba/cuda/_internal/cuda_fp16.py +2653 -8769
- numba_cuda/numba/cuda/api.py +6 -1
- numba_cuda/numba/cuda/bf16.py +285 -2
- numba_cuda/numba/cuda/cgutils.py +2 -2
- numba_cuda/numba/cuda/cloudpickle/__init__.py +21 -0
- numba_cuda/numba/cuda/cloudpickle/cloudpickle.py +1598 -0
- numba_cuda/numba/cuda/cloudpickle/cloudpickle_fast.py +17 -0
- numba_cuda/numba/cuda/codegen.py +1 -1
- numba_cuda/numba/cuda/compiler.py +373 -30
- numba_cuda/numba/cuda/core/analysis.py +319 -0
- numba_cuda/numba/cuda/core/annotations/__init__.py +0 -0
- numba_cuda/numba/cuda/core/annotations/type_annotations.py +304 -0
- numba_cuda/numba/cuda/core/base.py +1289 -0
- numba_cuda/numba/cuda/core/bytecode.py +727 -0
- numba_cuda/numba/cuda/core/caching.py +2 -2
- numba_cuda/numba/cuda/core/compiler.py +6 -14
- numba_cuda/numba/cuda/core/compiler_machinery.py +497 -0
- numba_cuda/numba/cuda/core/config.py +747 -0
- numba_cuda/numba/cuda/core/consts.py +124 -0
- numba_cuda/numba/cuda/core/cpu.py +370 -0
- numba_cuda/numba/cuda/core/environment.py +68 -0
- numba_cuda/numba/cuda/core/event.py +511 -0
- numba_cuda/numba/cuda/core/funcdesc.py +330 -0
- numba_cuda/numba/cuda/core/inline_closurecall.py +1889 -0
- numba_cuda/numba/cuda/core/interpreter.py +48 -26
- numba_cuda/numba/cuda/core/ir_utils.py +15 -26
- numba_cuda/numba/cuda/core/options.py +262 -0
- numba_cuda/numba/cuda/core/postproc.py +249 -0
- numba_cuda/numba/cuda/core/pythonapi.py +1868 -0
- numba_cuda/numba/cuda/core/rewrites/__init__.py +26 -0
- numba_cuda/numba/cuda/core/rewrites/ir_print.py +90 -0
- numba_cuda/numba/cuda/core/rewrites/registry.py +104 -0
- numba_cuda/numba/cuda/core/rewrites/static_binop.py +40 -0
- numba_cuda/numba/cuda/core/rewrites/static_getitem.py +187 -0
- numba_cuda/numba/cuda/core/rewrites/static_raise.py +98 -0
- numba_cuda/numba/cuda/core/ssa.py +496 -0
- numba_cuda/numba/cuda/core/targetconfig.py +329 -0
- numba_cuda/numba/cuda/core/tracing.py +231 -0
- numba_cuda/numba/cuda/core/transforms.py +952 -0
- numba_cuda/numba/cuda/core/typed_passes.py +738 -7
- numba_cuda/numba/cuda/core/typeinfer.py +1948 -0
- numba_cuda/numba/cuda/core/unsafe/__init__.py +0 -0
- numba_cuda/numba/cuda/core/unsafe/bytes.py +67 -0
- numba_cuda/numba/cuda/core/unsafe/eh.py +66 -0
- numba_cuda/numba/cuda/core/unsafe/refcount.py +98 -0
- numba_cuda/numba/cuda/core/untyped_passes.py +1983 -0
- numba_cuda/numba/cuda/cpython/cmathimpl.py +560 -0
- numba_cuda/numba/cuda/cpython/mathimpl.py +499 -0
- numba_cuda/numba/cuda/cpython/numbers.py +1474 -0
- numba_cuda/numba/cuda/cuda_paths.py +422 -246
- numba_cuda/numba/cuda/cudadecl.py +1 -1
- numba_cuda/numba/cuda/cudadrv/__init__.py +1 -1
- numba_cuda/numba/cuda/cudadrv/devicearray.py +2 -1
- numba_cuda/numba/cuda/cudadrv/driver.py +11 -140
- numba_cuda/numba/cuda/cudadrv/dummyarray.py +111 -24
- numba_cuda/numba/cuda/cudadrv/libs.py +5 -5
- numba_cuda/numba/cuda/cudadrv/mappings.py +1 -1
- numba_cuda/numba/cuda/cudadrv/nvrtc.py +19 -8
- numba_cuda/numba/cuda/cudadrv/nvvm.py +1 -4
- numba_cuda/numba/cuda/cudadrv/runtime.py +1 -1
- numba_cuda/numba/cuda/cudaimpl.py +5 -1
- numba_cuda/numba/cuda/debuginfo.py +85 -2
- numba_cuda/numba/cuda/decorators.py +3 -3
- numba_cuda/numba/cuda/descriptor.py +3 -4
- numba_cuda/numba/cuda/deviceufunc.py +66 -2
- numba_cuda/numba/cuda/dispatcher.py +18 -39
- numba_cuda/numba/cuda/flags.py +141 -1
- numba_cuda/numba/cuda/fp16.py +0 -2
- numba_cuda/numba/cuda/include/13/cuda_bf16.h +5118 -0
- numba_cuda/numba/cuda/include/13/cuda_bf16.hpp +3865 -0
- numba_cuda/numba/cuda/include/13/cuda_fp16.h +5363 -0
- numba_cuda/numba/cuda/include/13/cuda_fp16.hpp +3483 -0
- numba_cuda/numba/cuda/lowering.py +7 -144
- numba_cuda/numba/cuda/mathimpl.py +2 -1
- numba_cuda/numba/cuda/memory_management/nrt.py +43 -17
- numba_cuda/numba/cuda/misc/findlib.py +75 -0
- numba_cuda/numba/cuda/models.py +9 -1
- numba_cuda/numba/cuda/np/npdatetime_helpers.py +217 -0
- numba_cuda/numba/cuda/np/npyfuncs.py +1807 -0
- numba_cuda/numba/cuda/np/numpy_support.py +553 -0
- numba_cuda/numba/cuda/np/ufunc/ufuncbuilder.py +59 -0
- numba_cuda/numba/cuda/nvvmutils.py +1 -1
- numba_cuda/numba/cuda/printimpl.py +12 -1
- numba_cuda/numba/cuda/random.py +1 -1
- numba_cuda/numba/cuda/serialize.py +1 -1
- numba_cuda/numba/cuda/simulator/__init__.py +1 -1
- numba_cuda/numba/cuda/simulator/api.py +1 -1
- numba_cuda/numba/cuda/simulator/compiler.py +4 -0
- numba_cuda/numba/cuda/simulator/cudadrv/devicearray.py +1 -1
- numba_cuda/numba/cuda/simulator/kernelapi.py +1 -1
- numba_cuda/numba/cuda/simulator/memory_management/nrt.py +14 -2
- numba_cuda/numba/cuda/target.py +35 -17
- numba_cuda/numba/cuda/testing.py +7 -19
- numba_cuda/numba/cuda/tests/__init__.py +1 -1
- numba_cuda/numba/cuda/tests/cloudpickle_main_class.py +9 -0
- numba_cuda/numba/cuda/tests/core/test_serialize.py +4 -4
- numba_cuda/numba/cuda/tests/cudadrv/test_cuda_devicerecord.py +1 -1
- numba_cuda/numba/cuda/tests/cudadrv/test_cuda_libraries.py +1 -1
- numba_cuda/numba/cuda/tests/cudadrv/test_deallocations.py +1 -1
- numba_cuda/numba/cuda/tests/cudadrv/test_detect.py +6 -3
- numba_cuda/numba/cuda/tests/cudadrv/test_emm_plugins.py +1 -1
- numba_cuda/numba/cuda/tests/cudadrv/test_linker.py +18 -2
- numba_cuda/numba/cuda/tests/cudadrv/test_module_callbacks.py +2 -1
- numba_cuda/numba/cuda/tests/cudadrv/test_nvjitlink.py +1 -1
- numba_cuda/numba/cuda/tests/cudadrv/test_ptds.py +1 -1
- numba_cuda/numba/cuda/tests/cudapy/extensions_usecases.py +1 -1
- numba_cuda/numba/cuda/tests/cudapy/test_array.py +2 -1
- numba_cuda/numba/cuda/tests/cudapy/test_atomics.py +1 -1
- numba_cuda/numba/cuda/tests/cudapy/test_bfloat16.py +539 -2
- numba_cuda/numba/cuda/tests/cudapy/test_bfloat16_bindings.py +81 -1
- numba_cuda/numba/cuda/tests/cudapy/test_caching.py +1 -3
- numba_cuda/numba/cuda/tests/cudapy/test_complex.py +1 -1
- numba_cuda/numba/cuda/tests/cudapy/test_constmem.py +1 -1
- numba_cuda/numba/cuda/tests/cudapy/test_cooperative_groups.py +2 -3
- numba_cuda/numba/cuda/tests/cudapy/test_copy_propagate.py +130 -0
- numba_cuda/numba/cuda/tests/cudapy/test_datetime.py +1 -1
- numba_cuda/numba/cuda/tests/cudapy/test_debug.py +1 -1
- numba_cuda/numba/cuda/tests/cudapy/test_debuginfo.py +293 -4
- numba_cuda/numba/cuda/tests/cudapy/test_debuginfo_types.py +1 -1
- numba_cuda/numba/cuda/tests/cudapy/test_dispatcher.py +1 -1
- numba_cuda/numba/cuda/tests/cudapy/test_errors.py +1 -1
- numba_cuda/numba/cuda/tests/cudapy/test_exception.py +1 -1
- numba_cuda/numba/cuda/tests/cudapy/test_extending.py +2 -1
- numba_cuda/numba/cuda/tests/cudapy/test_inline.py +18 -8
- numba_cuda/numba/cuda/tests/cudapy/test_intrinsics.py +23 -21
- numba_cuda/numba/cuda/tests/cudapy/test_ir_utils.py +10 -37
- numba_cuda/numba/cuda/tests/cudapy/test_laplace.py +1 -1
- numba_cuda/numba/cuda/tests/cudapy/test_math.py +1 -1
- numba_cuda/numba/cuda/tests/cudapy/test_matmul.py +1 -1
- numba_cuda/numba/cuda/tests/cudapy/test_operator.py +1 -1
- numba_cuda/numba/cuda/tests/cudapy/test_print.py +20 -0
- numba_cuda/numba/cuda/tests/cudapy/test_record_dtype.py +1 -1
- numba_cuda/numba/cuda/tests/cudapy/test_reduction.py +1 -1
- numba_cuda/numba/cuda/tests/cudapy/test_serialize.py +1 -1
- numba_cuda/numba/cuda/tests/cudapy/test_sm.py +1 -1
- numba_cuda/numba/cuda/tests/cudapy/test_ssa.py +453 -0
- numba_cuda/numba/cuda/tests/cudapy/test_sync.py +1 -1
- numba_cuda/numba/cuda/tests/cudapy/test_typeinfer.py +538 -0
- numba_cuda/numba/cuda/tests/cudapy/test_ufuncs.py +263 -2
- numba_cuda/numba/cuda/tests/cudapy/test_userexc.py +1 -1
- numba_cuda/numba/cuda/tests/cudapy/test_vector_type.py +1 -1
- numba_cuda/numba/cuda/tests/cudapy/test_vectorize_decor.py +112 -6
- numba_cuda/numba/cuda/tests/cudapy/test_warning.py +1 -1
- numba_cuda/numba/cuda/tests/cudapy/test_warp_ops.py +1 -1
- numba_cuda/numba/cuda/tests/doc_examples/test_cg.py +0 -2
- numba_cuda/numba/cuda/tests/doc_examples/test_ffi.py +3 -2
- numba_cuda/numba/cuda/tests/doc_examples/test_laplace.py +0 -2
- numba_cuda/numba/cuda/tests/doc_examples/test_sessionize.py +0 -2
- numba_cuda/numba/cuda/tests/nocuda/test_import.py +3 -1
- numba_cuda/numba/cuda/tests/nocuda/test_library_lookup.py +24 -12
- numba_cuda/numba/cuda/tests/nrt/test_nrt.py +2 -1
- numba_cuda/numba/cuda/tests/support.py +55 -15
- numba_cuda/numba/cuda/tests/test_tracing.py +200 -0
- numba_cuda/numba/cuda/types.py +56 -0
- numba_cuda/numba/cuda/typing/__init__.py +9 -1
- numba_cuda/numba/cuda/typing/cffi_utils.py +55 -0
- numba_cuda/numba/cuda/typing/context.py +751 -0
- numba_cuda/numba/cuda/typing/enumdecl.py +74 -0
- numba_cuda/numba/cuda/typing/npydecl.py +658 -0
- numba_cuda/numba/cuda/typing/templates.py +7 -6
- numba_cuda/numba/cuda/ufuncs.py +3 -3
- numba_cuda/numba/cuda/utils.py +6 -112
- {numba_cuda-0.19.1.dist-info → numba_cuda-0.20.1.dist-info}/METADATA +4 -3
- {numba_cuda-0.19.1.dist-info → numba_cuda-0.20.1.dist-info}/RECORD +171 -116
- numba_cuda/numba/cuda/tests/cudadrv/test_mvc.py +0 -60
- {numba_cuda-0.19.1.dist-info → numba_cuda-0.20.1.dist-info}/WHEEL +0 -0
- {numba_cuda-0.19.1.dist-info → numba_cuda-0.20.1.dist-info}/licenses/LICENSE +0 -0
- {numba_cuda-0.19.1.dist-info → numba_cuda-0.20.1.dist-info}/licenses/LICENSE.numba +0 -0
- {numba_cuda-0.19.1.dist-info → numba_cuda-0.20.1.dist-info}/top_level.txt +0 -0
|
@@ -0,0 +1,952 @@
|
|
|
1
|
+
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
|
2
|
+
# SPDX-License-Identifier: BSD-2-Clause
|
|
3
|
+
|
|
4
|
+
"""
|
|
5
|
+
Implement transformation on Numba IR
|
|
6
|
+
"""
|
|
7
|
+
|
|
8
|
+
from collections import namedtuple, defaultdict
|
|
9
|
+
import logging
|
|
10
|
+
import operator
|
|
11
|
+
|
|
12
|
+
from numba.core.analysis import compute_cfg_from_blocks, find_top_level_loops
|
|
13
|
+
from numba.core import ir, errors
|
|
14
|
+
from numba.cuda.core import ir_utils
|
|
15
|
+
from numba.core.analysis import compute_use_defs
|
|
16
|
+
|
|
17
|
+
|
|
18
|
+
# Module-level logger; the loop-lifting helpers in this module report
# candidate selection and lifting progress through it.
_logger = logging.getLogger(__name__)
|
|
19
|
+
|
|
20
|
+
|
|
21
|
+
def _extract_loop_lifting_candidates(cfg, blocks):
    """
    Return the top-level loops of *cfg* that are candidates for loop lifting.

    A loop qualifies when all of its exits lead to one common successor,
    it has exactly one entry, none of its blocks (body, entries, exits)
    contains an assignment from a ``yield``, and the function's entry block
    is not one of its entries.
    """

    def same_exit_point(loop):
        "all exits must point to the same location"
        targets = set()
        for exit_label in loop.exits:
            exit_succs = {succ for succ, _ in cfg.successors(exit_label)}
            if not exit_succs:
                # An exit block with no successor ends in a return statement,
                # which the loop-lifting code cannot handle, so the loop is
                # rejected.
                _logger.debug("return-statement in loop.")
                return False
            targets.update(exit_succs)
        single_target = len(targets) == 1
        _logger.debug("same_exit_point=%s (%s)", single_target, targets)
        return single_target

    def one_entry(loop):
        "there is one entry"
        single_entry = len(loop.entries) == 1
        _logger.debug("one_entry=%s", single_entry)
        return single_entry

    def cannot_yield(loop):
        "cannot have yield inside the loop"
        labels = set(loop.body) | set(loop.entries) | set(loop.exits)
        found_yield = any(
            isinstance(stmt, ir.Assign) and isinstance(stmt.value, ir.Yield)
            for label in labels
            for stmt in blocks[label].body
        )
        if found_yield:
            _logger.debug("has yield")
            return False
        _logger.debug("no yield")
        return True

    _logger.info("finding looplift candidates")
    candidates = []
    for loop in find_top_level_loops(cfg):
        _logger.debug("top-level loop: %s", loop)
        # Excluding loops entered from the function's entry block prevents a
        # bad rewrite where the lifted loop's prelude would be written into
        # block -1 when a loop entry is block 0.
        if not (
            same_exit_point(loop)
            and one_entry(loop)
            and cannot_yield(loop)
            and cfg.entry_point() not in loop.entries
        ):
            continue
        candidates.append(loop)
        _logger.debug("add candidate: %s", loop)
    return candidates
|
|
77
|
+
|
|
78
|
+
|
|
79
|
+
def find_region_inout_vars(blocks, livemap, callfrom, returnto, body_block_ids):
    """Find input and output variables to a block region.

    Inputs are the variables live at *callfrom* that the region's blocks use
    or define. Outputs are the variables live at *returnto* that the region
    defines. Both are returned as sorted lists for a stable ordering.
    """
    live_in = livemap[callfrom]
    live_out = livemap[returnto]

    # Drop live variables the region never touches; this saves having to
    # create something valid to run through postproc to get the same effect.
    region = {label: blocks[label] for label in body_block_ids}
    usedefs = compute_use_defs(region)
    defined = set().union(*usedefs.defmap.values())
    touched = defined.union(*usedefs.usemap.values())

    inputs = sorted(touched.intersection(live_in))
    # Every defined variable is also "touched", so intersecting with the
    # defined set alone gives the same result.
    outputs = sorted(defined.intersection(live_out))
    return inputs, outputs
|
|
104
|
+
|
|
105
|
+
|
|
106
|
+
# Record describing one loop-lifting candidate: the loop itself, the
# variables flowing into and out of it, the label of the block that calls
# into the loop and the label control returns to afterwards.
_loop_lift_info = namedtuple(
    "loop_lift_info", ["loop", "inputs", "outputs", "callfrom", "returnto"]
)
|
|
109
|
+
|
|
110
|
+
|
|
111
|
+
def _loop_lift_get_candidate_infos(cfg, blocks, livemap):
    """
    Returns information on looplifting candidates.

    For every loop accepted by ``_extract_loop_lifting_candidates`` this
    builds a ``_loop_lift_info`` recording the block that calls into the loop
    (its single entry), the label control returns to afterwards, and the
    live variables flowing into and out of the loop region.
    """
    infos = []
    for loop in _extract_loop_lifting_candidates(cfg, blocks):
        # The candidate filter guarantees exactly one entry.
        [callfrom] = loop.entries
        some_exit = next(iter(loop.exits))
        if len(loop.exits) > 1:
            # Several exits: the candidate filter guarantees they all lead to
            # one common successor, which is where control returns.
            [(returnto, _)] = cfg.successors(some_exit)
        else:
            # A single exit block is itself the return point.
            returnto = some_exit

        region_ids = set(loop.body) | set(loop.entries) | set(loop.exits)
        inputs, outputs = find_region_inout_vars(
            blocks=blocks,
            livemap=livemap,
            callfrom=callfrom,
            returnto=returnto,
            body_block_ids=region_ids,
        )
        infos.append(
            _loop_lift_info(
                loop=loop,
                inputs=inputs,
                outputs=outputs,
                callfrom=callfrom,
                returnto=returnto,
            )
        )
    return infos
|
|
148
|
+
|
|
149
|
+
|
|
150
|
+
def _loop_lift_modify_call_block(liftedloop, block, inputs, outputs, returnto):
    """
    Build the block that replaces *block* in the top-level function.

    The new block shares *block*'s scope and location and is filled by
    ``ir_utils.fill_block_with_call`` with *liftedloop* as the callee,
    the given *inputs*/*outputs*, and *returnto* as the next label.
    """
    new_block = ir.Block(scope=block.scope, loc=block.loc)
    ir_utils.fill_block_with_call(
        newblock=new_block,
        callee=liftedloop,
        label_next=returnto,
        inputs=inputs,
        outputs=outputs,
    )
    return new_block
|
|
166
|
+
|
|
167
|
+
|
|
168
|
+
def _loop_lift_prepare_loop_func(loopinfo, blocks):
    """
    Inplace transform loop blocks for use as lifted loop.

    Adds a prologue block (built by ``ir_utils.fill_callee_prologue`` from
    ``loopinfo.inputs``, continuing at ``loopinfo.callfrom``) under a label
    one below the current minimum. Replaces the block at
    ``loopinfo.returnto`` with an epilogue built by
    ``ir_utils.fill_callee_epilogue`` from ``loopinfo.outputs``.
    """
    call_site = blocks[loopinfo.callfrom]
    scope, loc = call_site.scope, call_site.loc

    # Lowering treats the block with the smallest label as the first block,
    # so the prologue must sit below every existing label.
    prologue_label = min(blocks) - 1
    blocks[prologue_label] = ir_utils.fill_callee_prologue(
        block=ir.Block(scope=scope, loc=loc),
        inputs=loopinfo.inputs,
        label_next=loopinfo.callfrom,
    )
    blocks[loopinfo.returnto] = ir_utils.fill_callee_epilogue(
        block=ir.Block(scope=scope, loc=loc),
        outputs=loopinfo.outputs,
    )
|
|
187
|
+
|
|
188
|
+
|
|
189
|
+
def _loop_lift_modify_blocks(
    func_ir, loopinfo, blocks, typingctx, targetctx, flags, locals
):
    """
    Modify the block inplace to call to the lifted-loop.

    The loop's blocks are copied into a new function IR and wrapped in a
    ``LiftedLoop``. In *blocks*, those blocks are removed and the loop's
    entry block is replaced by a block that calls the lifted loop.

    Returns the ``LiftedLoop`` object for the lifted loop.
    """
    from numba.core.dispatcher import LiftedLoop

    loop = loopinfo.loop

    # Blocks moved into the lifted function: the loop body and its entry
    # block, plus the exit blocks when there is more than one of them.
    loopblockkeys = set(loop.body) | set(loop.entries)
    if len(loop.exits) > 1:
        # has multiple exits
        loopblockkeys |= loop.exits
    loopblocks = {k: blocks[k].copy() for k in loopblockkeys}
    # Add the callee prologue/epilogue to the copied blocks.
    _loop_lift_prepare_loop_func(loopinfo, loopblocks)

    # Since Python 3.13, [END_FOR, POP_TOP] sequence becomes the start of the
    # block causing the block to have line number of the start of previous
    # loop. Fix this using the loc of the first getiter.
    getiter_exprs = []
    for blk in loopblocks.values():
        getiter_exprs.extend(blk.find_exprs(op="getiter"))
    if getiter_exprs:
        loop_loc = min(getiter_exprs, key=lambda x: x.loc.line).loc
    else:
        # A loop without an iterator (such as a ``while`` loop) has no
        # getiter; use the loop header's location rather than letting
        # ``min`` fail on an empty sequence.
        loop_loc = loopblocks[loop.header].loc

    # Create a new IR for the lifted loop
    lifted_ir = func_ir.derive(
        blocks=loopblocks,
        arg_names=tuple(loopinfo.inputs),
        arg_count=len(loopinfo.inputs),
        force_non_generator=True,
        loc=loop_loc,
    )
    liftedloop = LiftedLoop(lifted_ir, typingctx, targetctx, flags, locals)

    # modify for calling into liftedloop
    callblock = _loop_lift_modify_call_block(
        liftedloop,
        blocks[loopinfo.callfrom],
        loopinfo.inputs,
        loopinfo.outputs,
        loopinfo.returnto,
    )
    # remove blocks (this includes the entry block at loopinfo.callfrom)
    for k in loopblockkeys:
        del blocks[k]
    # update main interpreter callsite into the liftedloop
    blocks[loopinfo.callfrom] = callblock
    return liftedloop
|
|
240
|
+
|
|
241
|
+
|
|
242
|
+
def _has_multiple_loop_exits(cfg, lpinfo):
    """Return True if the loop described by *lpinfo* has more than one exit.

    Exits that post-dominate another exit are "common exits": control leaving
    through the other exit always reaches them, so they need no rewriting.
    Such exits are discarded, and the loop only counts as multi-exit if more
    than one exit survives.
    """
    if len(lpinfo.exits) <= 1:
        return False
    survivors = set(lpinfo.exits)
    pdom = cfg.post_dominators()

    visited = set()
    pending = set(survivors)
    while pending:
        current = pending.pop()
        visited.add(current)
        # discard every other exit that post-dominates the current one
        survivors -= pdom[current] - {current}
        pending = survivors - visited

    return len(survivors) > 1
|
|
263
|
+
|
|
264
|
+
|
|
265
|
+
def _pre_looplift_transform(func_ir):
    """Canonicalize loops for looplifting.

    Every loop that ``_has_multiple_loop_exits`` reports as multi-exit has
    its exits combined into one via ``_fix_multi_exit_blocks``. Afterwards
    the IR's analysis state is reset and the post-processor is rerun.
    Returns the resulting function IR.
    """
    from numba.core.postproc import PostProcessor

    cfg = compute_cfg_from_blocks(func_ir.blocks)
    for loop_info in cfg.loops().values():
        if not _has_multiple_loop_exits(cfg, loop_info):
            continue
        func_ir, _ = _fix_multi_exit_blocks(func_ir, loop_info.exits)

    # The blocks may have changed: drop cached analyses and reprocess.
    func_ir._reset_analysis_variables()
    PostProcessor(func_ir).run()
    return func_ir
|
|
280
|
+
|
|
281
|
+
|
|
282
|
+
def loop_lifting(func_ir, typingctx, targetctx, flags, locals):
    """
    Loop lifting transformation.

    Given the function IR *func_ir*, returns a 2-tuple ``(main_ir, lifted)``.
    *main_ir* is the top-level IR in which every lifted loop has been
    replaced by a call. *lifted* is the list of ``LiftedLoop`` objects,
    one per lifted loop.
    """
    func_ir = _pre_looplift_transform(func_ir)
    new_blocks = func_ir.blocks.copy()
    cfg = compute_cfg_from_blocks(new_blocks)
    candidates = _loop_lift_get_candidate_infos(
        cfg, new_blocks, func_ir.variable_lifetime.livemap
    )
    if candidates:
        _logger.debug(
            "loop lifting this IR with %d candidates:\n%s",
            len(candidates),
            func_ir.dump_to_string(),
        )

    # Each call rewrites new_blocks in place and returns the lifted loop.
    lifted_loops = [
        _loop_lift_modify_blocks(
            func_ir, info, new_blocks, typingctx, targetctx, flags, locals
        )
        for info in candidates
    ]

    # The top-level IR is built from the remaining, rewritten blocks.
    main = func_ir.derive(blocks=new_blocks)
    return main, lifted_loops
|
|
312
|
+
|
|
313
|
+
|
|
314
|
+
def canonicalize_cfg_single_backedge(blocks):
    """
    Rewrite loops that have multiple backedges.

    Returns a new dict of blocks; the input dict is not modified. For every
    loop with more than one backedge, a new tail block (labelled one above
    the current maximum) is added. Each backedge is redirected to that tail,
    and the tail jumps back to the loop header.
    """
    cfg = compute_cfg_from_blocks(blocks)
    newblocks = blocks.copy()

    def new_block_id():
        return max(newblocks) + 1

    def has_multiple_backedges(loop):
        seen_backedge = False
        for label in loop.body:
            if loop.header not in blocks[label].terminator.get_targets():
                continue
            if seen_backedge:
                # a second backedge: no need to look any further
                return True
            seen_backedge = True
        return False

    def replace_target(term, src, dst):
        def redirect(label):
            return dst if label == src else label

        if isinstance(term, ir.Branch):
            return ir.Branch(
                cond=term.cond,
                truebr=redirect(term.truebr),
                falsebr=redirect(term.falsebr),
                loc=term.loc,
            )
        if isinstance(term, ir.Jump):
            return ir.Jump(target=redirect(term.target), loc=term.loc)
        # any other terminator must not have jump targets at all
        assert not term.get_targets()
        return term

    def rewrite_single_backedge(loop):
        """
        Add new tail block that gathers all the backedges
        """
        header = loop.header
        tail_label = new_block_id()
        for label in loop.body:
            original = newblocks[label]
            if header not in original.terminator.get_targets():
                continue
            # send this backedge to the new tail block instead
            patched = original.copy()
            patched.body[-1] = replace_target(
                original.terminator, header, tail_label
            )
            newblocks[label] = patched
        # the tail block holds the single remaining backedge to the header
        header_block = newblocks[header]
        tail = ir.Block(scope=header_block.scope, loc=header_block.loc)
        tail.append(ir.Jump(target=header, loc=tail.loc))
        newblocks[tail_label] = tail

    for lp in cfg.loops().values():
        if has_multiple_backedges(lp):
            rewrite_single_backedge(lp)

    return newblocks
|
|
385
|
+
|
|
386
|
+
|
|
387
|
+
def canonicalize_cfg(blocks):
    """
    Return a new dictionary of blocks with a canonicalized CFG.

    Currently the only canonicalization applied is rewriting loops so that
    each has a single backedge (``canonicalize_cfg_single_backedge``).
    """
    canonical_blocks = canonicalize_cfg_single_backedge(blocks)
    return canonical_blocks
|
|
393
|
+
|
|
394
|
+
|
|
395
|
+
def with_lifting(func_ir, typingctx, targetctx, flags, locals):
    """With-lifting transformation

    Rewrite the IR to extract all withs; only the top-level withs are
    extracted. For each with-region the head block is validated and the
    context-manager object is resolved. The object's ``mutate_with_body``
    hook then rewrites the region and produces a sub-IR.

    Returns the (the_new_ir, the_lifted_with_ir) pair: the rewritten IR and
    the list of sub-IRs, one per with-region.
    """
    from numba.core import postproc

    def dispatcher_factory(func_ir, objectmode=False, **kwargs):
        from numba.core.dispatcher import LiftedWith, ObjModeLiftedWith

        lifted_flags = flags.copy()
        if not objectmode:
            return LiftedWith(
                func_ir, typingctx, targetctx, lifted_flags, locals, **kwargs
            )
        # An object-mode lifted with-block cannot looplift and is compiled
        # in object mode.
        lifted_flags.enable_looplift = False
        lifted_flags.enable_pyobject = True
        lifted_flags.force_pyobject = True
        lifted_flags.no_cpython_wrapper = False
        return ObjModeLiftedWith(
            func_ir, typingctx, targetctx, lifted_flags, locals, **kwargs
        )

    # locate the with-context regions
    withs, func_ir = find_setupwiths(func_ir)
    if not withs:
        return func_ir, []

    # make sure variable lifetime information is available
    postproc.PostProcessor(func_ir).run()
    assert func_ir.variable_lifetime
    cfg = func_ir.variable_lifetime.cfg
    blocks = func_ir.blocks.copy()

    # Mutate every with-region according to its kind of context manager.
    sub_irs = []
    for blk_start, blk_end in withs:
        body_blocks = list(_cfg_nodes_in_region(cfg, blk_start, blk_end))
        _legalize_with_head(blocks[blk_start])
        cmkind, extra = _get_with_contextmanager(func_ir, blocks, blk_start)
        sub_irs.append(
            cmkind.mutate_with_body(
                func_ir,
                blocks,
                blk_start,
                blk_end,
                body_blocks,
                dispatcher_factory,
                extra,
            )
        )

    new_ir = func_ir.derive(blocks) if sub_irs else func_ir
    return new_ir, sub_irs
|
|
458
|
+
|
|
459
|
+
|
|
460
|
+
def _get_with_contextmanager(func_ir, blocks, blk_start):
    """Get the global object used for the context manager

    Scans the head block at *blk_start* for its ``ir.EnterWith`` statement
    and resolves the context-manager object it refers to. Returns
    ``(ctxobj, extra)``. *extra* is ``None`` for a bare context manager.
    When the context manager is called, *extra* is a dict
    ``{"args": [...], "kwargs": {...}}`` of the argument definitions.

    Raises ``errors.CompilerError`` in any of these cases: the context
    manager is undefined or cannot be resolved, the object has no
    ``mutate_with_body``, or the head block contains no ``EnterWith``.
    """
    illegal_cm_msg = "Illegal use of context-manager."
    head = blocks[blk_start]

    def resolve(var_ref):
        dfn = func_ir.get_definition(var_ref)
        extra = None
        if isinstance(dfn, ir.Expr) and dfn.op == "call":
            # The context manager is used as a call: keep the definitions of
            # its arguments and resolve the callee instead.
            extra = {
                "args": [func_ir.get_definition(arg) for arg in dfn.args],
                "kwargs": {
                    name: func_ir.get_definition(val) for name, val in dfn.kws
                },
            }
            var_ref = dfn.func

        ctxobj = ir_utils.guard(ir_utils.find_outer_value, func_ir, var_ref)
        if ctxobj is ir.UNDEFINED:
            raise errors.CompilerError(
                "Undefined variable used as context manager",
                loc=head.loc,
            )
        if ctxobj is None:
            raise errors.CompilerError(illegal_cm_msg, loc=dfn.loc)
        return ctxobj, extra

    enter_stmt = next(
        (stmt for stmt in head.body if isinstance(stmt, ir.EnterWith)), None
    )
    if enter_stmt is None:
        # no context manager found in the head block
        raise errors.CompilerError(
            "malformed with-context usage",
            loc=head.loc,
        )

    ctxobj, extra = resolve(enter_stmt.contextmanager)
    if not hasattr(ctxobj, "mutate_with_body"):
        raise errors.CompilerError(
            "Unsupported context manager in use",
            loc=head.loc,
        )
    return ctxobj, extra
|
|
514
|
+
|
|
515
|
+
|
|
516
|
+
def _legalize_with_head(blk):
    """Given *blk*, the head block of the with-context, check that it doesn't
    do anything else.

    The head block must contain exactly one ``ir.EnterWith`` and exactly one
    ``ir.Jump``. Any number of ``ir.Del`` statements is tolerated, and no
    other statement type is allowed.

    Raises ``errors.CompilerError`` if the block does not have that shape.
    """
    counters = defaultdict(int)
    for stmt in blk.body:
        counters[type(stmt)] += 1
    # A default is required: defaultdict.pop() does not use the default
    # factory, so without it a head block lacking any EnterWith would leak a
    # KeyError instead of reporting the CompilerError below.
    if counters.pop(ir.EnterWith, 0) != 1:
        raise errors.CompilerError(
            "with's head-block must have exactly 1 ENTER_WITH",
            loc=blk.loc,
        )
    if counters.pop(ir.Jump, 0) != 1:
        raise errors.CompilerError(
            "with's head-block must have exactly 1 JUMP",
            loc=blk.loc,
        )
    # Can have any number of del
    counters.pop(ir.Del, None)
    # There MUST NOT be any other statements
    if counters:
        raise errors.CompilerError(
            "illegal statements in with's head-block",
            loc=blk.loc,
        )
|
|
541
|
+
|
|
542
|
+
|
|
543
|
+
def _cfg_nodes_in_region(cfg, region_begin, region_end):
    """Find the set of CFG nodes that are in the given region

    Walks successors from *region_begin* without passing through
    *region_end*. Returns every node reached this way. *region_end* is
    never included; *region_begin* is included only if it is reached again
    from inside the region.
    """
    found = set()
    worklist = [region_begin]
    while worklist:
        current = worklist.pop()
        # a node without successors (e.g. a single-block function)
        # contributes nothing
        fresh = {
            succ
            for succ, _ in cfg.successors(current)
            if succ not in found and succ != region_end
        }
        worklist.extend(fresh)
        found |= fresh
    return found
|
|
564
|
+
|
|
565
|
+
|
|
566
|
+
def find_setupwiths(func_ir):
    """Find all top-level with.

    Returns a ``(with_ranges, func_ir)`` tuple:

    - ``with_ranges`` is a list of ``(setup_label, exit_label)`` pairs, one
      per top-level with-region. ``setup_label`` is the block holding the
      SETUP_WITH. ``exit_label`` is the jump target of the block holding the
      matching POP_BLOCK.
    - ``func_ir`` is the function IR, possibly rewritten so that each with
      has a single exit and returns inside a with are split out.

    Raises ``errors.CompilerError`` when a with-region contains a raise
    statement, or when the block holding its POP_BLOCK has more than one
    outgoing edge.
    """

    def find_ranges(blocks):
        """Map each SETUP_WITH block label to the set of labels of the
        POP_BLOCK blocks that close it."""
        cfg = compute_cfg_from_blocks(blocks)
        sus_setups, sus_pops = set(), set()
        # traverse the cfg and collect all suspected SETUP_WITH and POP_BLOCK
        # statements so that we can iterate over them
        for label, block in blocks.items():
            for stmt in block.body:
                if ir_utils.is_setup_with(stmt):
                    sus_setups.add(label)
                if ir_utils.is_pop_block(stmt):
                    sus_pops.add(label)

        # now that we do have the statements, iterate through them in reverse
        # topo order and from each start looking for pop_blocks
        setup_with_to_pop_blocks_map = defaultdict(set)
        for setup_block in cfg.topo_sort(sus_setups, reverse=True):
            # begin pop_block, search
            to_visit, seen = [setup_block], set()
            while to_visit:
                block = to_visit.pop()
                # A block can be pushed several times before it is first
                # popped, e.g. the join block of an if/else inside the with.
                # Process it only once. Revisiting a POP_BLOCK block after it
                # was removed from ``sus_pops`` would continue the walk past
                # the with-exit into unrelated code. There it could raise a
                # spurious error or claim another with's POP_BLOCK.
                if block in seen:
                    continue
                seen.add(block)
                # go through the body of the block, looking for statements
                for stmt in blocks[block].body:
                    # raise detected before pop_block
                    if ir_utils.is_raise(stmt):
                        raise errors.CompilerError(
                            "unsupported control flow due to raise "
                            "statements inside with block"
                        )
                    # if a pop_block, process it
                    if ir_utils.is_pop_block(stmt) and block in sus_pops:
                        # record the block belonging to this setup
                        setup_with_to_pop_blocks_map[setup_block].add(block)
                        # remove the block from blocks to be matched
                        sus_pops.remove(block)
                        # stop looking, we have reached the frontier
                        break
                    # if we are still here, by the block terminator, add all
                    # its targets to the to_visit stack, unless we have seen
                    # them already
                    if ir_utils.is_terminator(stmt):
                        for t in stmt.get_targets():
                            if t not in seen:
                                to_visit.append(t)

        return setup_with_to_pop_blocks_map

    blocks = func_ir.blocks
    # initial find, will return a dictionary, mapping indices of blocks
    # containing SETUP_WITH statements to a set of indices of blocks containing
    # POP_BLOCK statements
    with_ranges_dict = find_ranges(blocks)
    # rewrite the CFG in case there are multiple POP_BLOCK statements for one
    # with
    func_ir = consolidate_multi_exit_withs(with_ranges_dict, blocks, func_ir)
    # here we need to turn the withs back into a list of tuples so that the
    # rest of the code can cope
    with_ranges_tuple = [(s, list(p)[0]) for (s, p) in with_ranges_dict.items()]

    # check for POP_BLOCKS with multiple outgoing edges and reject
    for _, p in with_ranges_tuple:
        targets = blocks[p].terminator.get_targets()
        if len(targets) != 1:
            raise errors.CompilerError(
                "unsupported control flow: with-context contains branches "
                "(i.e. break/return/raise) that can leave the block "
            )
    # now we check for returns inside with and reject them
    for _, p in with_ranges_tuple:
        target_block = blocks[p]
        if ir_utils.is_return(
            func_ir.blocks[target_block.terminator.get_targets()[0]].terminator
        ):
            _rewrite_return(func_ir, p)

    # now we need to rewrite the tuple such that we have SETUP_WITH matching the
    # successor of the block that contains the POP_BLOCK.
    with_ranges_tuple = [
        (s, func_ir.blocks[p].terminator.get_targets()[0])
        for (s, p) in with_ranges_tuple
    ]

    # finally we check for nested with statements and reject them
    with_ranges_tuple = _eliminate_nested_withs(with_ranges_tuple)

    return with_ranges_tuple, func_ir
|
|
660
|
+
|
|
661
|
+
|
|
662
|
+
def _rewrite_return(func_ir, target_block_label):
    """Rewrite a return block inside a with statement.

    Arguments
    ---------

    func_ir: Function IR
      the CFG to transform (mutated in place)
    target_block_label: int
      the block index/label of the block containing the POP_BLOCK statement

    Returns
    -------

    The same ``func_ir`` object, after its ``_definitions`` are rebuilt.

    Raises
    ------

    AssertionError
      if the target block does not contain exactly one POP_BLOCK and exactly
      one JUMP, or if it does not end with that JUMP.

    This implements a CFG transformation to insert a block between two other
    blocks.

    The input situation is:

            ┌───────────────┐
            │      top      │
            │   POP_BLOCK   │
            │    bottom     │
            └───────┬───────┘
                    │
            ┌───────▼───────┐
            │               │
            │    RETURN     │
            │               │
            └───────────────┘

    If such a pattern is detected in IR, it means there is a `return` statement
    within a `with` context. The CFG is rewritten as follows:

            ┌───────────────┐
            │               │
            │      top      │
            │               │
            └───────┬───────┘
                    │
            ┌───────▼───────┐
            │   POP_BLOCK   │
            │    bottom     │
            │               │
            └───────┬───────┘
                    │
            ┌───────▼───────┐
            │               │
            │    RETURN     │
            │               │
            └───────────────┘

    The block that contains the `POP_BLOCK` statement is split in two.
    Everything *before* the `POP_BLOCK` statement is considered the 'top'.
    It stays under ``target_block_label`` and now jumps to the label of the
    former return block. The `POP_BLOCK` statement and everything after it,
    up to but excluding the original terminating jump, is considered the
    'bottom'. It replaces the body of the former return block, which then
    jumps to a freshly allocated label (max label + 1). That new block holds
    the original return block's statements.
    """
    # the block itself from the index
    target_block = func_ir.blocks[target_block_label]
    # get the index of the block containing the return
    target_block_successor_label = target_block.terminator.get_targets()[0]
    # the return block
    target_block_successor = func_ir.blocks[target_block_successor_label]

    # create the new return block with an appropriate label
    max_label = ir_utils.find_max_label(func_ir.blocks)
    new_label = max_label + 1
    # create the new return block
    new_block_loc = target_block_successor.loc
    # NOTE(review): the new block gets a fresh, parent-less ``ir.Scope``
    # rather than the successor's scope; confirm this is intended.
    new_block_scope = ir.Scope(None, loc=new_block_loc)
    new_block = ir.Block(new_block_scope, loc=new_block_loc)

    # Split the block containing the POP_BLOCK into top and bottom
    # Block must be of the form:
    # -----------------
    # <some stmts>       -> top
    # POP_BLOCK          -> bottom (first statement)
    # <some more stmts>  -> bottom
    # JUMP               -> dropped, replaced by new jumps
    # -----------------
    pop_blocks = [*target_block.find_insts(ir.PopBlock)]
    assert len(pop_blocks) == 1
    assert len([*target_block.find_insts(ir.Jump)]) == 1
    assert isinstance(target_block.body[-1], ir.Jump)
    pb_marker = pop_blocks[0]
    pb_is = target_block.body.index(pb_marker)
    top_body = list(target_block.body[:pb_is])
    top_body.append(ir.Jump(target_block_successor_label, target_block.loc))
    bottom_body = list(target_block.body[pb_is:-1])
    bottom_body.append(ir.Jump(new_label, target_block.loc))

    # Snapshot the contents of the return block before its body is replaced.
    return_body = list(target_block_successor.body)
    # finally, re-assign all blocks
    new_block.body.extend(return_body)
    target_block_successor.body.clear()
    target_block_successor.body.extend(bottom_body)
    target_block.body.clear()
    target_block.body.extend(top_body)

    # finally, append the new return block and rebuild the IR properties
    func_ir.blocks[new_label] = new_block
    func_ir._definitions = ir_utils.build_definitions(func_ir.blocks)
    return func_ir
|
|
767
|
+
|
|
768
|
+
|
|
769
|
+
def _eliminate_nested_withs(with_ranges):
|
|
770
|
+
known_ranges = []
|
|
771
|
+
|
|
772
|
+
def within_known_range(start, end, known_ranges):
|
|
773
|
+
for a, b in known_ranges:
|
|
774
|
+
# FIXME: this should be a comparison in topological order, right
|
|
775
|
+
# now we are comparing the integers of the blocks, stuff probably
|
|
776
|
+
# works by accident.
|
|
777
|
+
if start > a and end < b:
|
|
778
|
+
return True
|
|
779
|
+
return False
|
|
780
|
+
|
|
781
|
+
for s, e in sorted(with_ranges):
|
|
782
|
+
if not within_known_range(s, e, known_ranges):
|
|
783
|
+
known_ranges.append((s, e))
|
|
784
|
+
|
|
785
|
+
return known_ranges
|
|
786
|
+
|
|
787
|
+
|
|
788
|
+
def consolidate_multi_exit_withs(withs: dict, blocks, func_ir):
    """Modify the FunctionIR to merge the exit blocks of with constructs.

    *withs* maps a SETUP_WITH block label to the set of labels of its
    POP_BLOCK exit blocks. It is updated in place: every entry with more
    than one exit is replaced by a single-element set holding the new common
    exit label. *blocks* is not used by this function. Returns the (possibly
    rewritten) *func_ir*.
    """
    for setup_label, exits in withs.items():
        if len(exits) <= 1:
            # Already a single exit (or none); nothing to merge.
            continue
        func_ir, merged_label = _fix_multi_exit_blocks(
            func_ir,
            exits,
            split_condition=ir_utils.is_pop_block,
        )
        withs[setup_label] = {merged_label}
    return func_ir
|
|
800
|
+
|
|
801
|
+
|
|
802
|
+
def _fix_multi_exit_blocks(func_ir, exit_nodes, *, split_condition=None):
    """Modify the FunctionIR to create a single common exit node given the
    original exit nodes.

    Each exit block is cut at its split point. The part before the cut gets
    an assignment ``$cp = <exit index>`` and a jump to a new common block.
    The common block jumps to a new "post" block, which starts a chain of
    ``$cp`` comparisons. The chain branches to the original jump target of
    each exit.

    Parameters
    ----------
    func_ir :
        The FunctionIR. Mutated inplace.
    exit_nodes :
        The original exit nodes. A sequence of block keys.
    split_condition : callable or None
        If not None, it is a callable with the signature
        `split_condition(statement)` that determines if the `statement` is the
        splitting point (e.g. `POP_BLOCK`) in an exit node.
        If it's None, the exit node is not split.

    Returns
    -------
    (func_ir, common_label)
        The same FunctionIR and the label of the new common exit block.
    """

    # Convert the following:
    #
    #    |           |
    # +-------+   +-------+
    # | exit0 |   | exit1 |
    # +-------+   +-------+
    #     |           |
    # +-------+   +-------+
    # | after0|   | after1|
    # +-------+   +-------+
    #    |           |
    #
    # To roughly:
    #
    #    |           |
    # +-------+   +-------+
    # | exit0 |   | exit1 |
    # +-------+   +-------+
    #     |           |
    #     +-----+-----+
    #           |
    #      +---------+
    #      | common  |
    #      +---------+
    #           |
    #       +-------+
    #       | post  |
    #       +-------+
    #           |
    #     +-----+-----+
    #     |           |
    # +-------+   +-------+
    # | after0|   | after1|
    # +-------+   +-------+

    blocks = func_ir.blocks
    # Any block's scope will do. Note that ``min`` here orders the Block
    # objects themselves, not their labels.
    any_blk = min(blocks.values())
    scope = any_blk.scope

    # Allocate labels for the shared exit ("common") and its successor
    # ("post") just past the current maximum label.
    first_free = max(blocks) + 1
    common_label = first_free
    post_label = first_free + 1
    common_block = ir.Block(scope, loc=ir.unknown_loc)
    blocks[common_label] = common_block
    post_block = ir.Block(scope, loc=ir.unknown_loc)
    blocks[post_label] = post_block

    # Cut each exit block and redirect it to the common block, tagging the
    # path taken with its control-point value.
    remainings = []
    for cp_value, exit_label in enumerate(exit_nodes):
        exit_blk = blocks[exit_label]
        body = exit_blk.body

        if split_condition is None:
            # no splitting: only the terminator is carved off
            cut = -1
        else:
            # first matching statement; without a match, cut before the
            # last statement
            cut = next(
                (idx for idx, stmt in enumerate(body) if split_condition(stmt)),
                len(body) - 1,
            )
        remainings.append(body[cut:])

        exit_blk.body = body[:cut]
        loc = exit_blk.loc
        exit_blk.body.append(
            ir.Assign(
                value=ir.Const(cp_value, loc=loc),
                target=scope.get_or_define("$cp", loc=loc),
                loc=loc,
            )
        )
        # Replace terminator with a jump to the common block
        assert not exit_blk.is_terminated
        exit_blk.body.append(ir.Jump(common_label, loc=ir.unknown_loc))

    if split_condition is not None:
        # The first exit's splitting statement moves into the common block.
        common_block.body.append(remainings[0][0])
    assert not common_block.is_terminated
    # ``loc`` is the location of the last exit block processed above.
    common_block.body.append(ir.Jump(post_label, loc=loc))

    # One fall-through block per exit, labelled right after ``post_label``.
    first_case_label = post_label + 1
    case_labels = list(range(first_case_label, first_case_label + len(remainings)))

    # Build the if-else chain dispatching on ``$cp`` to each exit's target.
    unknown = ir.unknown_loc
    current = post_block
    for cp_value, (remain, fallthrough) in enumerate(zip(remainings, case_labels)):
        cp_check = scope.redefine("$cp_check", loc=unknown)
        cp_rhs = scope.redefine("$cp_rhs", loc=unknown)

        current.body.append(
            ir.Assign(value=ir.Const(cp_value, loc=unknown), target=cp_rhs, loc=unknown),
        )
        comparison = ir.Expr.binop(
            fn=operator.eq,
            lhs=scope.get("$cp"),
            rhs=cp_rhs,
            loc=unknown,
        )
        current.body.append(
            ir.Assign(value=comparison, target=cp_check, loc=unknown),
        )

        # Each remaining tail must end in a single-target jump.
        [jump_target] = remain[-1].get_targets()
        current.body.append(
            ir.Branch(cp_check, jump_target, fallthrough, loc=unknown),
        )
        current = ir.Block(scope=scope, loc=unknown)
        blocks[fallthrough] = current

    # The last fall-through block jumps unconditionally to the last target.
    current.body.append(ir.Jump(jump_target, loc=unknown))

    return func_ir, common_label
|