kirin-toolchain 0.13.0__py3-none-any.whl
Sign up to get free protection for your applications and to get access to all the features.
- kirin/__init__.py +7 -0
- kirin/analysis/__init__.py +24 -0
- kirin/analysis/callgraph.py +61 -0
- kirin/analysis/cfg.py +112 -0
- kirin/analysis/const/__init__.py +20 -0
- kirin/analysis/const/_visitor.py +2 -0
- kirin/analysis/const/_visitor.pyi +8 -0
- kirin/analysis/const/lattice.py +219 -0
- kirin/analysis/const/prop.py +116 -0
- kirin/analysis/forward.py +100 -0
- kirin/analysis/typeinfer/__init__.py +5 -0
- kirin/analysis/typeinfer/analysis.py +90 -0
- kirin/analysis/typeinfer/solve.py +141 -0
- kirin/decl/__init__.py +108 -0
- kirin/decl/base.py +65 -0
- kirin/decl/camel2snake.py +2 -0
- kirin/decl/emit/__init__.py +0 -0
- kirin/decl/emit/_create_fn.py +29 -0
- kirin/decl/emit/_set_new_attribute.py +22 -0
- kirin/decl/emit/dialect.py +8 -0
- kirin/decl/emit/init.py +277 -0
- kirin/decl/emit/name.py +10 -0
- kirin/decl/emit/property.py +182 -0
- kirin/decl/emit/repr.py +31 -0
- kirin/decl/emit/traits.py +13 -0
- kirin/decl/emit/typecheck.py +77 -0
- kirin/decl/emit/verify.py +51 -0
- kirin/decl/info.py +346 -0
- kirin/decl/scan_fields.py +157 -0
- kirin/decl/verify.py +69 -0
- kirin/dialects/__init__.py +14 -0
- kirin/dialects/_pprint_helper.py +53 -0
- kirin/dialects/cf/__init__.py +20 -0
- kirin/dialects/cf/constprop.py +51 -0
- kirin/dialects/cf/dialect.py +3 -0
- kirin/dialects/cf/emit.py +58 -0
- kirin/dialects/cf/interp.py +24 -0
- kirin/dialects/cf/stmts.py +68 -0
- kirin/dialects/cf/typeinfer.py +27 -0
- kirin/dialects/eltype.py +23 -0
- kirin/dialects/func/__init__.py +20 -0
- kirin/dialects/func/attrs.py +39 -0
- kirin/dialects/func/constprop.py +138 -0
- kirin/dialects/func/dialect.py +3 -0
- kirin/dialects/func/emit.py +80 -0
- kirin/dialects/func/interp.py +68 -0
- kirin/dialects/func/stmts.py +233 -0
- kirin/dialects/func/typeinfer.py +124 -0
- kirin/dialects/ilist/__init__.py +33 -0
- kirin/dialects/ilist/_dialect.py +3 -0
- kirin/dialects/ilist/_wrapper.py +51 -0
- kirin/dialects/ilist/interp.py +85 -0
- kirin/dialects/ilist/lowering.py +25 -0
- kirin/dialects/ilist/passes.py +32 -0
- kirin/dialects/ilist/rewrite/__init__.py +3 -0
- kirin/dialects/ilist/rewrite/const.py +45 -0
- kirin/dialects/ilist/rewrite/list.py +38 -0
- kirin/dialects/ilist/rewrite/unroll.py +131 -0
- kirin/dialects/ilist/runtime.py +63 -0
- kirin/dialects/ilist/stmts.py +102 -0
- kirin/dialects/ilist/typeinfer.py +120 -0
- kirin/dialects/lowering/__init__.py +7 -0
- kirin/dialects/lowering/call.py +48 -0
- kirin/dialects/lowering/cf.py +206 -0
- kirin/dialects/lowering/func.py +134 -0
- kirin/dialects/math/__init__.py +41 -0
- kirin/dialects/math/_gen.py +176 -0
- kirin/dialects/math/dialect.py +3 -0
- kirin/dialects/math/interp.py +190 -0
- kirin/dialects/math/stmts.py +369 -0
- kirin/dialects/module.py +139 -0
- kirin/dialects/py/__init__.py +40 -0
- kirin/dialects/py/assertion.py +91 -0
- kirin/dialects/py/assign.py +103 -0
- kirin/dialects/py/attr.py +59 -0
- kirin/dialects/py/base.py +34 -0
- kirin/dialects/py/binop/__init__.py +23 -0
- kirin/dialects/py/binop/_dialect.py +3 -0
- kirin/dialects/py/binop/interp.py +60 -0
- kirin/dialects/py/binop/julia.py +33 -0
- kirin/dialects/py/binop/lowering.py +22 -0
- kirin/dialects/py/binop/stmts.py +79 -0
- kirin/dialects/py/binop/typeinfer.py +108 -0
- kirin/dialects/py/boolop.py +84 -0
- kirin/dialects/py/builtin.py +78 -0
- kirin/dialects/py/cmp/__init__.py +16 -0
- kirin/dialects/py/cmp/_dialect.py +3 -0
- kirin/dialects/py/cmp/interp.py +48 -0
- kirin/dialects/py/cmp/julia.py +33 -0
- kirin/dialects/py/cmp/lowering.py +45 -0
- kirin/dialects/py/cmp/stmts.py +62 -0
- kirin/dialects/py/constant.py +79 -0
- kirin/dialects/py/indexing.py +251 -0
- kirin/dialects/py/iterable.py +90 -0
- kirin/dialects/py/len.py +57 -0
- kirin/dialects/py/list/__init__.py +15 -0
- kirin/dialects/py/list/_dialect.py +3 -0
- kirin/dialects/py/list/interp.py +21 -0
- kirin/dialects/py/list/lowering.py +25 -0
- kirin/dialects/py/list/stmts.py +22 -0
- kirin/dialects/py/list/typeinfer.py +54 -0
- kirin/dialects/py/range.py +76 -0
- kirin/dialects/py/slice.py +120 -0
- kirin/dialects/py/tuple.py +109 -0
- kirin/dialects/py/unary/__init__.py +24 -0
- kirin/dialects/py/unary/_dialect.py +3 -0
- kirin/dialects/py/unary/constprop.py +20 -0
- kirin/dialects/py/unary/interp.py +24 -0
- kirin/dialects/py/unary/julia.py +21 -0
- kirin/dialects/py/unary/lowering.py +22 -0
- kirin/dialects/py/unary/stmts.py +33 -0
- kirin/dialects/py/unary/typeinfer.py +23 -0
- kirin/dialects/py/unpack.py +90 -0
- kirin/dialects/scf/__init__.py +23 -0
- kirin/dialects/scf/_dialect.py +3 -0
- kirin/dialects/scf/absint.py +64 -0
- kirin/dialects/scf/constprop.py +140 -0
- kirin/dialects/scf/interp.py +35 -0
- kirin/dialects/scf/lowering.py +123 -0
- kirin/dialects/scf/stmts.py +250 -0
- kirin/dialects/scf/trim.py +36 -0
- kirin/dialects/scf/typeinfer.py +58 -0
- kirin/dialects/scf/unroll.py +92 -0
- kirin/emit/__init__.py +3 -0
- kirin/emit/abc.py +89 -0
- kirin/emit/abc.pyi +38 -0
- kirin/emit/exceptions.py +5 -0
- kirin/emit/julia.py +63 -0
- kirin/emit/str.py +51 -0
- kirin/exceptions.py +59 -0
- kirin/graph.py +34 -0
- kirin/idtable.py +57 -0
- kirin/interp/__init__.py +39 -0
- kirin/interp/abstract.py +253 -0
- kirin/interp/base.py +438 -0
- kirin/interp/concrete.py +62 -0
- kirin/interp/exceptions.py +26 -0
- kirin/interp/frame.py +151 -0
- kirin/interp/impl.py +197 -0
- kirin/interp/result.py +93 -0
- kirin/interp/state.py +71 -0
- kirin/interp/table.py +40 -0
- kirin/interp/value.py +73 -0
- kirin/ir/__init__.py +46 -0
- kirin/ir/attrs/__init__.py +20 -0
- kirin/ir/attrs/_types.py +8 -0
- kirin/ir/attrs/_types.pyi +13 -0
- kirin/ir/attrs/abc.py +46 -0
- kirin/ir/attrs/py.py +45 -0
- kirin/ir/attrs/types.py +522 -0
- kirin/ir/dialect.py +125 -0
- kirin/ir/group.py +249 -0
- kirin/ir/method.py +118 -0
- kirin/ir/nodes/__init__.py +7 -0
- kirin/ir/nodes/base.py +149 -0
- kirin/ir/nodes/block.py +458 -0
- kirin/ir/nodes/region.py +337 -0
- kirin/ir/nodes/stmt.py +713 -0
- kirin/ir/nodes/view.py +142 -0
- kirin/ir/ssa.py +204 -0
- kirin/ir/traits/__init__.py +36 -0
- kirin/ir/traits/abc.py +42 -0
- kirin/ir/traits/basic.py +78 -0
- kirin/ir/traits/callable.py +51 -0
- kirin/ir/traits/lowering/__init__.py +2 -0
- kirin/ir/traits/lowering/call.py +37 -0
- kirin/ir/traits/lowering/context.py +120 -0
- kirin/ir/traits/region/__init__.py +2 -0
- kirin/ir/traits/region/ssacfg.py +22 -0
- kirin/ir/traits/symbol.py +57 -0
- kirin/ir/use.py +17 -0
- kirin/lattice/__init__.py +13 -0
- kirin/lattice/abc.py +128 -0
- kirin/lattice/empty.py +25 -0
- kirin/lattice/mixin.py +51 -0
- kirin/lowering/__init__.py +7 -0
- kirin/lowering/binding.py +65 -0
- kirin/lowering/core.py +72 -0
- kirin/lowering/dialect.py +35 -0
- kirin/lowering/dialect.pyi +183 -0
- kirin/lowering/frame.py +171 -0
- kirin/lowering/result.py +68 -0
- kirin/lowering/state.py +441 -0
- kirin/lowering/stream.py +53 -0
- kirin/passes/__init__.py +3 -0
- kirin/passes/abc.py +44 -0
- kirin/passes/aggressive/__init__.py +1 -0
- kirin/passes/aggressive/fold.py +43 -0
- kirin/passes/fold.py +45 -0
- kirin/passes/inline.py +25 -0
- kirin/passes/typeinfer.py +25 -0
- kirin/prelude.py +197 -0
- kirin/print/__init__.py +15 -0
- kirin/print/printable.py +141 -0
- kirin/print/printer.py +415 -0
- kirin/py.typed +0 -0
- kirin/registry.py +105 -0
- kirin/registry.pyi +52 -0
- kirin/rewrite/__init__.py +14 -0
- kirin/rewrite/abc.py +43 -0
- kirin/rewrite/aggressive/__init__.py +1 -0
- kirin/rewrite/aggressive/fold.py +43 -0
- kirin/rewrite/alias.py +16 -0
- kirin/rewrite/apply_type.py +47 -0
- kirin/rewrite/call2invoke.py +34 -0
- kirin/rewrite/chain.py +39 -0
- kirin/rewrite/compactify.py +288 -0
- kirin/rewrite/cse.py +48 -0
- kirin/rewrite/dce.py +19 -0
- kirin/rewrite/fixpoint.py +34 -0
- kirin/rewrite/fold.py +57 -0
- kirin/rewrite/getfield.py +21 -0
- kirin/rewrite/getitem.py +37 -0
- kirin/rewrite/inline.py +143 -0
- kirin/rewrite/result.py +15 -0
- kirin/rewrite/walk.py +83 -0
- kirin/rewrite/wrap_const.py +55 -0
- kirin/source.py +21 -0
- kirin/symbol_table.py +27 -0
- kirin/types.py +34 -0
- kirin/worklist.py +30 -0
- kirin_toolchain-0.13.0.dist-info/METADATA +42 -0
- kirin_toolchain-0.13.0.dist-info/RECORD +225 -0
- kirin_toolchain-0.13.0.dist-info/WHEEL +4 -0
- kirin_toolchain-0.13.0.dist-info/licenses/LICENSE +234 -0
@@ -0,0 +1,47 @@
|
|
1
|
+
from dataclasses import dataclass
|
2
|
+
|
3
|
+
from kirin import ir, types
|
4
|
+
from kirin.rewrite.abc import RewriteRule, RewriteResult
|
5
|
+
from kirin.dialects.func.attrs import Signature
|
6
|
+
|
7
|
+
|
8
|
+
@dataclass
class ApplyType(RewriteRule):
    """Write the types recorded in `results` back onto the IR.

    `results` maps SSA values to the type attribute that should be assigned
    to them; values without an entry are left untouched.
    """

    results: dict[ir.SSAValue, types.TypeAttribute]

    def rewrite_Block(self, node: ir.Block) -> RewriteResult:
        """Retype every block argument that has an entry in `results`."""
        changed = False
        for block_arg in node.args:
            if block_arg not in self.results:
                continue
            block_arg.type = self.results[block_arg]
            changed = True
        return RewriteResult(has_done_something=changed)

    def rewrite_Statement(self, node: ir.Statement) -> RewriteResult:
        """Retype the statement's results and, if applicable, its signature.

        The signature is only rewritten for statements that carry both the
        `ir.HasSignature` and `ir.CallableStmtInterface` traits.
        """
        changed = False
        for stmt_result in node._results:
            if stmt_result not in self.results:
                continue
            stmt_result.type = self.results[stmt_result]
            changed = True

        signature_trait = node.get_trait(ir.HasSignature)
        if signature_trait is None:
            return RewriteResult(has_done_something=changed)

        callable_trait = node.get_trait(ir.CallableStmtInterface)
        if callable_trait is None:
            return RewriteResult(has_done_something=changed)

        # Input types come from the arguments of the callable region's first
        # block, preferring the recorded type over the current one.
        region = callable_trait.get_callable_region(node)
        inputs = tuple(
            self.results.get(region_arg, region_arg.type)
            for region_arg in region.blocks[0].args
        )

        if len(node._results) != 1:
            return RewriteResult(has_done_something=changed)

        recorded = self.results.get(node._results[0])
        if isinstance(recorded, types.Generic) and recorded.is_subseteq(
            types.MethodType
        ):
            # The second type variable of the recorded generic is used as
            # the output of the new signature.
            signature_trait.set_signature(node, Signature(inputs, recorded.vars[1]))
            changed = True
        return RewriteResult(has_done_something=changed)
|
@@ -0,0 +1,34 @@
|
|
1
|
+
from dataclasses import dataclass
|
2
|
+
|
3
|
+
from kirin import ir
|
4
|
+
from kirin.analysis import const
|
5
|
+
from kirin.rewrite.abc import RewriteRule, RewriteResult
|
6
|
+
from kirin.dialects.func import Call, Invoke
|
7
|
+
|
8
|
+
|
9
|
+
@dataclass
class Call2Invoke(RewriteRule):
    """Rewrite a `Call` statement to an `Invoke` statement."""

    def rewrite_Statement(self, node: ir.Statement) -> RewriteResult:
        """Replace `node` by an `Invoke` when its callee is a constant `ir.Method`.

        Returns an empty `RewriteResult` when `node` is not a `Call`, or when
        the callee's "const" hint is missing, is not a `const.Value`, or does
        not wrap an `ir.Method`.
        """
        if not isinstance(node, Call):
            return RewriteResult()

        callee_hint = node.callee.hints.get("const")
        # A missing hint (None) is rejected by the isinstance check as well.
        if not isinstance(callee_hint, const.Value):
            return RewriteResult()

        method = callee_hint.data
        if not isinstance(method, ir.Method):
            return RewriteResult()

        invoke = Invoke(inputs=node.inputs, callee=method, kwargs=node.kwargs)
        # Carry name, type and constant hint over to the new results.
        for old_result, new_result in zip(node.results, invoke.results):
            new_result.name = old_result.name
            new_result.type = old_result.type
            if const_hint := old_result.hints.get("const"):
                new_result.hints["const"] = const_hint

        node.replace_by(invoke)
        return RewriteResult(has_done_something=True)
|
kirin/rewrite/chain.py
ADDED
@@ -0,0 +1,39 @@
|
|
1
|
+
from typing import Iterable
|
2
|
+
from dataclasses import dataclass
|
3
|
+
|
4
|
+
from kirin.ir import IRNode
|
5
|
+
from kirin.rewrite.abc import RewriteRule
|
6
|
+
from kirin.rewrite.result import RewriteResult
|
7
|
+
|
8
|
+
|
9
|
+
@dataclass
class Chain(RewriteRule):
    """Chain multiple rewrites together.

    The chain will apply each rewrite in order until one of the rewrites terminates.
    """

    rules: list[RewriteRule]

    def __init__(self, rule: RewriteRule | Iterable[RewriteRule], *others: RewriteRule):
        """Accept either `Chain(a, b, c)` or `Chain([a, b, c])`."""
        if not isinstance(rule, RewriteRule):
            assert (
                others == ()
            ), "Cannot pass multiple positional arguments if the first argument is an iterable"
            self.rules = list(rule)
            return
        self.rules = [rule, *others]

    def rewrite(self, node: IRNode) -> RewriteResult:
        """Run every rule on `node` in order.

        A terminated result is returned as-is and stops the chain; otherwise
        the returned result reports whether any rule did something.
        """
        changed = False
        for each_rule in self.rules:
            outcome = each_rule.rewrite(node)
            if outcome.terminated:
                return outcome
            if outcome.has_done_something:
                changed = True
        return RewriteResult(has_done_something=changed)

    def __repr__(self):
        return " -> ".join(str(each_rule) for each_rule in self.rules)
|
@@ -0,0 +1,288 @@
|
|
1
|
+
from dataclasses import dataclass
|
2
|
+
|
3
|
+
from kirin import ir
|
4
|
+
from kirin.dialects import cf
|
5
|
+
from kirin.rewrite.abc import RewriteRule, RewriteResult
|
6
|
+
from kirin.analysis.cfg import CFG
|
7
|
+
from kirin.rewrite.walk import Walk
|
8
|
+
from kirin.rewrite.chain import Chain
|
9
|
+
from kirin.rewrite.fixpoint import Fixpoint
|
10
|
+
|
11
|
+
|
12
|
+
@dataclass
class DeadBlock(RewriteRule):
    """Compactify the CFG by removing dead blocks."""

    cfg: CFG

    def rewrite_Region(self, node: ir.Region) -> RewriteResult:
        """Delete every non-entry block that has no recorded predecessor.

        The shared `cfg` is updated alongside the IR so later rules see a
        consistent graph.
        """
        # TODO: check if this region is using SSACFG convention?
        removed_any = False
        # The first block is the entry block and is never removed.
        for block in node.blocks[1:]:
            if self.cfg.predecessors.get(block):
                continue  # something still branches here

            # Unlink the block from its successors before dropping it.
            for succ in self.cfg.successors.get(block, set()):
                self.cfg.predecessors[succ].discard(block)
            self.cfg.successors.pop(block, None)
            self.cfg.predecessors.pop(block, None)
            block.delete()
            removed_any = True
        return RewriteResult(has_done_something=removed_any)
|
33
|
+
|
34
|
+
|
35
|
+
@dataclass
class CFGEdge(RewriteRule):
    """Merge non-branching blocks on the edge of the CFG.

    Example:

           /---> [B] --> [D] --> [E]
        [A]-----> [C] -------------^

    [B] and [D] are non-branching blocks on the same edge. They can be merged into one block.

           /---> [B,D] --> [E]
        [A]-----> [C] -------^
    """

    cfg: CFG

    def rewrite_Region(self, node: ir.Region) -> RewriteResult:
        """Try to merge every block of the region with its successor."""
        result = RewriteResult()
        for block in node.blocks:
            result = self.rewrite_Block(block).join(result)
        return result

    def rewrite_Block(self, node: ir.Block) -> RewriteResult:
        """Merge `node` with its successor when the edge between them is unique.

        The merge happens only when `node` has exactly one successor, that
        successor is a different block with at most one predecessor, and
        `node` ends in an unconditional `cf.Branch`. The shared `cfg` is
        updated to reflect the merge.
        """
        successors = self.cfg.successors.get(node, None)
        # Need exactly one outgoing edge (neither zero nor several).
        if successors is None or len(successors) != 1:
            return RewriteResult()

        successor = next(iter(successors))
        # A block that branches to itself cannot be merged into itself:
        # doing so would delete its terminator and then the block.
        if successor is node:
            return RewriteResult()

        if len(self.cfg.predecessors[successor]) > 1:  # multiple incoming edges
            return RewriteResult()

        if not ((last_stmt := node.last_stmt) and isinstance(last_stmt, cf.Branch)):
            return RewriteResult()

        # merge the two blocks: the successor's block arguments become the
        # values passed by the branch, then its statements move into `node`
        for arg, input in zip(successor.args, last_stmt.arguments):
            arg.replace_by(input)
        last_stmt.delete()
        for stmt in successor.stmts:
            stmt.detach()
            node.stmts.append(stmt)
        successor.delete()

        # update the CFG: `node` inherits the successor's outgoing edges
        new_successors = self.cfg.successors[successor]
        self.cfg.successors[node] = new_successors
        for new_successor in new_successors:
            self.cfg.predecessors[new_successor].discard(successor)
            self.cfg.predecessors[new_successor].add(node)
        del self.cfg.successors[successor]
        del self.cfg.predecessors[successor]  # this is just [node]
        return RewriteResult(has_done_something=True)
|
90
|
+
|
91
|
+
|
92
|
+
class DuplicatedBranch(RewriteRule):
    """Merge duplicated branches into a single branch.

    Example:

        [A]-->[B]
         -----^

    Merge the two branches into one without changing the CFG:

        [A]-->[B]
    """

    def rewrite_Statement(self, node: ir.Statement) -> RewriteResult:
        """Turn a conditional branch with identical arms into a plain branch.

        Both arms must target the same block object and pass the same SSA
        values (compared by identity).
        """
        if not isinstance(node, cf.ConditionalBranch):
            return RewriteResult()

        if node.then_successor is not node.else_successor:
            return RewriteResult()

        arms_differ = any(
            then_value is not else_value
            for then_value, else_value in zip(node.then_arguments, node.else_arguments)
        )
        if arms_differ:
            return RewriteResult()

        unconditional = cf.Branch(
            arguments=node.then_arguments, successor=node.then_successor
        )
        node.replace_by(unconditional)
        return RewriteResult(has_done_something=True)
|
120
|
+
|
121
|
+
|
122
|
+
@dataclass
class SkipBlock(RewriteRule):
    """Simplify a block that only contains a branch statement.

    When such a block has a single predecessor, the predecessor's terminator
    is rewritten to jump straight to the block's successor, and the shared
    `cfg` is updated so the skipped block no longer has that predecessor.
    """

    # CFG of the region being rewritten; mutated in place by `fix_cfg`.
    cfg: CFG

    def rewrite_Region(self, node: ir.Region) -> RewriteResult:
        """Apply `rewrite_Block` to every block and join the results."""
        result = RewriteResult()
        for block in node.blocks:
            result = self.rewrite_Block(block).join(result)
        return result

    def rewrite_Block(self, node: ir.Block) -> RewriteResult:
        """Bypass `node` if it consists of a single `cf.Branch`.

        Nothing is done unless the block has exactly one statement, that
        statement is a `cf.Branch`, the block has exactly one recorded
        predecessor, and every block argument passes `can_skip`.
        """
        if len(node.stmts) != 1:
            return RewriteResult()

        stmt = node.last_stmt
        if not isinstance(stmt, cf.Branch):
            return RewriteResult()

        has_done_something = False
        predecessors = self.cfg.predecessors.get(node, set())
        # only if there is one predecessor and the block arguments are not
        # used by anything other than the block's own terminator
        if len(predecessors) == 1 and all(
            self.can_skip(stmt, each) for each in node.args
        ):
            has_done_something = self.rewrite_pred(node, stmt, next(iter(predecessors)))
        return RewriteResult(has_done_something=has_done_something)

    def can_skip(self, terminator: cf.Branch, value: ir.SSAValue) -> bool:
        """Return True if every use of `value` is by `terminator`."""
        for use in value.uses:
            if use.stmt is terminator:
                continue
            return False
        return True

    def rewrite_pred(
        self, node: ir.Block, node_terminator: cf.Branch, predecessor: ir.Block
    ) -> bool:
        """Dispatch on the kind of terminator that ends `predecessor`.

        Returns False when the predecessor ends in something other than a
        `cf.Branch` or `cf.ConditionalBranch`; otherwise returns whatever the
        specific handler reports.
        """
        terminator = predecessor.last_stmt
        if isinstance(terminator, cf.Branch):
            return self.rewrite_pred_Branch(
                node, node_terminator, predecessor, terminator
            )
        elif isinstance(terminator, cf.ConditionalBranch):
            return self.rewrite_pred_ConditionalBranch(
                node, node_terminator, predecessor, terminator
            )
        return False

    def rewrite_pred_Branch(
        self,
        node: ir.Block,
        node_terminator: cf.Branch,
        predecessor: ir.Block,
        pred_terminator: cf.Branch,
    ) -> bool:
        """Retarget an unconditional predecessor branch past `node`."""
        # Map each argument of `node` to the value the predecessor passes.
        ssamap = self._block_inputs(node, pred_terminator.arguments)
        pred_terminator.replace_by(
            cf.Branch(
                # NOTE: the argument can also be SSAs from previous blocks (non-phi)
                arguments=tuple(
                    ssamap.get(arg, arg) for arg in node_terminator.arguments
                ),
                successor=node_terminator.successor,
            )
        )

        self.fix_cfg(predecessor, node, node_terminator.successor)
        return True

    def rewrite_pred_ConditionalBranch(
        self,
        node: ir.Block,
        node_terminator: cf.Branch,
        predecessor: ir.Block,
        pred_terminator: cf.ConditionalBranch,
    ) -> bool:
        """Retarget whichever arms of a conditional branch point at `node`.

        Arms that do not target `node` keep their original arguments and
        successor. The conditional branch is replaced in either case; the
        return value tells whether at least one arm was retargeted.
        """
        # Start from the existing arms and overwrite the ones that hit `node`.
        then_arguments = pred_terminator.then_arguments
        else_arguments = pred_terminator.else_arguments
        then_successor = pred_terminator.then_successor
        else_successor = pred_terminator.else_successor

        has_done_something = False
        if pred_terminator.then_successor is node:
            ssamap = self._block_inputs(node, pred_terminator.then_arguments)
            then_arguments = tuple(
                ssamap.get(arg, arg) for arg in node_terminator.arguments
            )
            then_successor = node_terminator.successor
            has_done_something = True
            self.fix_cfg(predecessor, node, then_successor)

        if pred_terminator.else_successor is node:
            ssamap = self._block_inputs(node, pred_terminator.else_arguments)
            else_arguments = tuple(
                ssamap.get(arg, arg) for arg in node_terminator.arguments
            )
            else_successor = node_terminator.successor
            has_done_something = True
            self.fix_cfg(predecessor, node, else_successor)

        pred_terminator.replace_by(
            cf.ConditionalBranch(
                cond=pred_terminator.cond,
                then_arguments=then_arguments,
                then_successor=then_successor,
                else_arguments=else_arguments,
                else_successor=else_successor,
            )
        )
        return has_done_something

    def fix_cfg(self, predecessor: ir.Block, node: ir.Block, successor: ir.Block):
        """Record in `cfg` that `predecessor` now jumps to `successor`, not `node`."""
        # predecessor: drop the edge to `node`, add the edge to `successor`
        node_pred_succ = self.cfg.successors.setdefault(predecessor, set())
        node_pred_succ.discard(node)
        node_pred_succ.add(successor)

        # successor: gains `predecessor` as an incoming edge
        node_succ_pred = self.cfg.predecessors.setdefault(successor, set())
        node_succ_pred.add(predecessor)

        # node: loses `predecessor` as an incoming edge
        node_pred = self.cfg.predecessors.setdefault(node, set())
        node_pred.discard(predecessor)

    def _block_inputs(
        self, block: ir.Block, arguments: tuple[ir.SSAValue, ...]
    ) -> dict[ir.SSAValue, ir.SSAValue]:
        """Pair each argument of `block` with the value passed for it, by position."""
        return dict(zip(block.args, arguments))
|
250
|
+
|
251
|
+
|
252
|
+
@dataclass
class CompactifyRegion(RewriteRule):
    """Wrapper to share the CFG object with same CFG region."""

    cfg: CFG

    def __init__(self, cfg: CFG):
        """Build the compactification pipeline around one shared `cfg`."""
        self.cfg = cfg
        # All CFG-aware rules receive the same graph object, so updates made
        # by one rule are visible to the others.
        pipeline = Chain(
            DeadBlock(cfg), Walk(DuplicatedBranch()), SkipBlock(cfg), CFGEdge(cfg)
        )
        self.rule = Fixpoint(pipeline)

    def rewrite(self, node: ir.IRNode) -> RewriteResult:
        """Delegate to the fixpoint pipeline built in `__init__`."""
        return self.rule.rewrite(node)
|
268
|
+
|
269
|
+
|
270
|
+
@dataclass
class CFGCompactify(RewriteRule):
    """Compactify the CFG by removing dead blocks and merging blocks
    if the statement uses the SSACFG convention. Do nothing if given
    `ir.Region` or `ir.Block` due to no context of the region.

    To compactify hierarchical CFG, combine this rule with `kirin.rewrite.Walk`
    to recursively apply this rule to all statements.
    """

    def rewrite_Statement(self, node: ir.Statement) -> RewriteResult:
        """Compactify each region of a statement carrying `ir.SSACFGRegion`."""
        ssacfg_trait = node.get_trait(ir.SSACFGRegion)
        if not ssacfg_trait:
            return RewriteResult()

        combined = RewriteResult()
        for region in node.regions:
            # A fresh graph is obtained per region and shared by the rules
            # inside `CompactifyRegion`.
            region_cfg = ssacfg_trait.get_graph(region)
            combined = CompactifyRegion(region_cfg).rewrite(region).join(combined)
        return combined
|
kirin/rewrite/cse.py
ADDED
@@ -0,0 +1,48 @@
|
|
1
|
+
from dataclasses import dataclass
|
2
|
+
|
3
|
+
from kirin.ir import Pure, Block, Statement
|
4
|
+
from kirin.rewrite.abc import RewriteRule, RewriteResult
|
5
|
+
|
6
|
+
|
7
|
+
@dataclass
class CommonSubexpressionElimination(RewriteRule):
    """Remove a pure statement that repeats an earlier one in the same block.

    Two statements are considered the same when they have the same type,
    arguments, attribute values and successors. Statements that own regions
    are never eliminated.
    """

    def rewrite_Block(self, node: Block) -> RewriteResult:
        """Eliminate the first duplicated pure statement found in `node`.

        Returns right after one elimination (the block is being mutated);
        run the rule again, e.g. under a fixpoint, to remove further
        duplicates.
        """
        # Keyed by the identifying tuple itself rather than by its hash, so
        # two different statements whose hashes collide are never merged.
        seen: dict[tuple, Statement] = {}

        for stmt in node.stmts:
            if not stmt.has_trait(Pure):
                continue

            if stmt.regions:
                continue

            key = (
                type(stmt),
                *stmt.args,
                *stmt.attributes.values(),
                *stmt.successors,
            )
            old_stmt = seen.get(key)
            if old_stmt is None:
                seen[key] = stmt
                continue

            # Redirect each result to the matching result of the earlier
            # statement, not only to its first one.
            for result, old_result in zip(stmt._results, old_stmt._results):
                result.replace_by(old_result)
            stmt.delete()
            return RewriteResult(has_done_something=True)
        return RewriteResult()

    def rewrite_Statement(self, node: Statement) -> RewriteResult:
        """Apply `rewrite_Block` to every block of every region of `node`."""
        if not node.regions:
            return RewriteResult()

        has_done_something = False
        for region in node.regions:
            for block in region.blocks:
                result = self.rewrite_Block(block)
                if result.has_done_something:
                    has_done_something = True

        return RewriteResult(has_done_something=has_done_something)
|
kirin/rewrite/dce.py
ADDED
@@ -0,0 +1,19 @@
|
|
1
|
+
from dataclasses import dataclass
|
2
|
+
|
3
|
+
from kirin import ir
|
4
|
+
from kirin.rewrite.abc import RewriteRule, RewriteResult
|
5
|
+
|
6
|
+
|
7
|
+
@dataclass
class DeadCodeElimination(RewriteRule):
    """Delete pure statements none of whose results are used."""

    def rewrite_Statement(self, node: ir.Statement) -> RewriteResult:
        """Delete `node` if it is pure and all of its results are unused."""
        if not self.is_pure(node):
            return RewriteResult()

        # A single used result keeps the whole statement alive.
        if any(result.uses for result in node._results):
            return RewriteResult()

        node.delete()
        return RewriteResult(has_done_something=True)
|
@@ -0,0 +1,34 @@
|
|
1
|
+
from dataclasses import dataclass
|
2
|
+
|
3
|
+
from kirin.ir import IRNode
|
4
|
+
from kirin.rewrite.abc import RewriteRule
|
5
|
+
from kirin.rewrite.result import RewriteResult
|
6
|
+
|
7
|
+
|
8
|
+
@dataclass
class Fixpoint(RewriteRule):
    """Apply a rewrite rule until a fixpoint is reached.

    The rewrite rule is applied to the node until the rewrite rule does not do anything.

    ### Parameters
    - `rule`: The rewrite rule to apply.
    - `max_iter`: The maximum number of iterations to apply the rewrite rule. Default is 32.
    """

    rule: RewriteRule
    max_iter: int = 32

    def rewrite(self, node: IRNode) -> RewriteResult:
        """Repeat `rule` on `node` until it reports no change.

        A terminated result from the rule is returned as-is. If `max_iter`
        applications all changed something, the result has
        `exceeded_max_iter` set and still reports that the node was changed.
        """
        has_done_something = False
        for _ in range(self.max_iter):
            result = self.rule.rewrite(node)
            if result.terminated:
                return result

            if not result.has_done_something:
                # fixpoint reached
                return RewriteResult(has_done_something=has_done_something)
            has_done_something = True

        # Every iteration above modified the node, so the change flag must
        # be reported alongside the iteration-limit flag.
        return RewriteResult(
            has_done_something=has_done_something, exceeded_max_iter=True
        )
|
kirin/rewrite/fold.py
ADDED
@@ -0,0 +1,57 @@
|
|
1
|
+
from dataclasses import dataclass
|
2
|
+
|
3
|
+
from kirin import ir
|
4
|
+
from kirin.analysis import const
|
5
|
+
from kirin.dialects import cf
|
6
|
+
from kirin.rewrite.abc import RewriteRule, RewriteResult
|
7
|
+
from kirin.dialects.py.constant import Constant
|
8
|
+
|
9
|
+
|
10
|
+
@dataclass
class ConstantFold(RewriteRule):
    """Replace values with a known "const" hint by constants.

    Also resolves a `cf.ConditionalBranch` whose condition is a known
    constant into an unconditional `cf.Branch`.
    """

    def get_const(self, value: ir.SSAValue):
        """Return the `const.Value` hint of `value`, or None if there is none."""
        hint = value.hints.get("const")
        return hint if isinstance(hint, const.Value) else None

    def rewrite_Statement(self, node: ir.Statement) -> RewriteResult:
        """Fold the constant results of a pure, non-constant-like statement."""
        if node.has_trait(ir.ConstantLike):
            return RewriteResult()
        if isinstance(node, cf.ConditionalBranch):
            return self.rewrite_cf_ConditionalBranch(node)
        if not self.is_pure(node):
            return RewriteResult()

        folded = False
        for old_result in node.results:
            value = self.get_const(old_result)
            if value is None:
                continue

            # Materialize the constant right before the statement and
            # redirect all uses of the old result to it.
            stmt = Constant(value.data)
            stmt.insert_before(node)
            old_result.replace_by(stmt.result)
            stmt.result.hints["const"] = value
            if old_result.name:
                stmt.result.name = old_result.name
            folded = True
        return RewriteResult(has_done_something=folded)

    def rewrite_cf_ConditionalBranch(self, node: cf.ConditionalBranch):
        """Replace a conditional branch on a constant by the taken branch.

        Raises `ValueError` when the constant condition is neither `True`
        nor `False`.
        """
        value = self.get_const(node.cond)
        if value is None:
            return RewriteResult()

        if value.data is True:
            arguments, successor = node.then_arguments, node.then_successor
        elif value.data is False:
            arguments, successor = node.else_arguments, node.else_successor
        else:
            raise ValueError(f"Invalid constant value for branch: {value.data}")

        cf.Branch(arguments=arguments, successor=successor).insert_before(node)
        node.delete()
        return RewriteResult(has_done_something=True)
|
@@ -0,0 +1,21 @@
|
|
1
|
+
from dataclasses import dataclass
|
2
|
+
|
3
|
+
from kirin import ir
|
4
|
+
from kirin.dialects import func
|
5
|
+
from kirin.rewrite.abc import RewriteRule, RewriteResult
|
6
|
+
|
7
|
+
|
8
|
+
@dataclass
class InlineGetField(RewriteRule):
    """Replace a `func.GetField` on a `func.Lambda` by the captured value."""

    def rewrite_Statement(self, node: ir.Statement) -> RewriteResult:
        """Forward the lambda's captured value and delete the `GetField`.

        Applies only when `node` is a `func.GetField` whose object is
        produced by a `func.Lambda` statement.
        """
        if not isinstance(node, func.GetField):
            return RewriteResult()

        owner = node.obj.owner
        if not isinstance(owner, func.Lambda):
            return RewriteResult()

        # `field` indexes into the values captured by the lambda.
        node.result.replace_by(owner.captured[node.field])
        node.delete()
        return RewriteResult(has_done_something=True)
|
kirin/rewrite/getitem.py
ADDED
@@ -0,0 +1,37 @@
|
|
1
|
+
from dataclasses import dataclass
|
2
|
+
|
3
|
+
from kirin import ir
|
4
|
+
from kirin.analysis import const
|
5
|
+
from kirin.dialects import py
|
6
|
+
from kirin.rewrite.abc import RewriteRule, RewriteResult
|
7
|
+
|
8
|
+
|
9
|
+
@dataclass
class InlineGetItem(RewriteRule):
    """Resolve constant indexing into a freshly built tuple.

    Applies to a `py.indexing.GetItem` whose object is produced by a
    `py.tuple.New` statement and whose index has a `const.Value` hint.
    """

    def rewrite_Statement(self, node: ir.Statement) -> RewriteResult:
        """Inline `node` for a constant in-range int index or a constant slice.

        An int index forwards the selected tuple element to the users of
        `node`. A slice index replaces `node` by a new `py.tuple.New` built
        from the selected elements. Any other index, including an
        out-of-range int, leaves the IR untouched.
        """
        if not isinstance(node, py.indexing.GetItem):
            return RewriteResult()

        if not isinstance(node.obj.owner, py.tuple.New):
            return RewriteResult()

        if not isinstance(index_value := node.index.hints.get("const"), const.Value):
            return RewriteResult()

        stmt = node.obj.owner
        index = index_value.data
        size = len(stmt.args)

        if isinstance(index, int):
            if -size <= index < size:
                node.result.replace_by(stmt.args[index])
                return RewriteResult(has_done_something=True)
            return RewriteResult()

        if isinstance(index, slice):
            # `slice.indices` yields range()-style bounds. They must not be
            # fed back into slice syntax: for a negative step the stop is -1,
            # which slicing would read as "last element" and so select
            # nothing for e.g. `t[::-1]`.
            selected = tuple(stmt.args[i] for i in range(*index.indices(size)))
            node.replace_by(py.tuple.New(selected))
            return RewriteResult(has_done_something=True)

        return RewriteResult()
|