kirin-toolchain 0.13.0__py3-none-any.whl

Sign up to get free protection for your applications and to get access to all the features.
Files changed (225) hide show
  1. kirin/__init__.py +7 -0
  2. kirin/analysis/__init__.py +24 -0
  3. kirin/analysis/callgraph.py +61 -0
  4. kirin/analysis/cfg.py +112 -0
  5. kirin/analysis/const/__init__.py +20 -0
  6. kirin/analysis/const/_visitor.py +2 -0
  7. kirin/analysis/const/_visitor.pyi +8 -0
  8. kirin/analysis/const/lattice.py +219 -0
  9. kirin/analysis/const/prop.py +116 -0
  10. kirin/analysis/forward.py +100 -0
  11. kirin/analysis/typeinfer/__init__.py +5 -0
  12. kirin/analysis/typeinfer/analysis.py +90 -0
  13. kirin/analysis/typeinfer/solve.py +141 -0
  14. kirin/decl/__init__.py +108 -0
  15. kirin/decl/base.py +65 -0
  16. kirin/decl/camel2snake.py +2 -0
  17. kirin/decl/emit/__init__.py +0 -0
  18. kirin/decl/emit/_create_fn.py +29 -0
  19. kirin/decl/emit/_set_new_attribute.py +22 -0
  20. kirin/decl/emit/dialect.py +8 -0
  21. kirin/decl/emit/init.py +277 -0
  22. kirin/decl/emit/name.py +10 -0
  23. kirin/decl/emit/property.py +182 -0
  24. kirin/decl/emit/repr.py +31 -0
  25. kirin/decl/emit/traits.py +13 -0
  26. kirin/decl/emit/typecheck.py +77 -0
  27. kirin/decl/emit/verify.py +51 -0
  28. kirin/decl/info.py +346 -0
  29. kirin/decl/scan_fields.py +157 -0
  30. kirin/decl/verify.py +69 -0
  31. kirin/dialects/__init__.py +14 -0
  32. kirin/dialects/_pprint_helper.py +53 -0
  33. kirin/dialects/cf/__init__.py +20 -0
  34. kirin/dialects/cf/constprop.py +51 -0
  35. kirin/dialects/cf/dialect.py +3 -0
  36. kirin/dialects/cf/emit.py +58 -0
  37. kirin/dialects/cf/interp.py +24 -0
  38. kirin/dialects/cf/stmts.py +68 -0
  39. kirin/dialects/cf/typeinfer.py +27 -0
  40. kirin/dialects/eltype.py +23 -0
  41. kirin/dialects/func/__init__.py +20 -0
  42. kirin/dialects/func/attrs.py +39 -0
  43. kirin/dialects/func/constprop.py +138 -0
  44. kirin/dialects/func/dialect.py +3 -0
  45. kirin/dialects/func/emit.py +80 -0
  46. kirin/dialects/func/interp.py +68 -0
  47. kirin/dialects/func/stmts.py +233 -0
  48. kirin/dialects/func/typeinfer.py +124 -0
  49. kirin/dialects/ilist/__init__.py +33 -0
  50. kirin/dialects/ilist/_dialect.py +3 -0
  51. kirin/dialects/ilist/_wrapper.py +51 -0
  52. kirin/dialects/ilist/interp.py +85 -0
  53. kirin/dialects/ilist/lowering.py +25 -0
  54. kirin/dialects/ilist/passes.py +32 -0
  55. kirin/dialects/ilist/rewrite/__init__.py +3 -0
  56. kirin/dialects/ilist/rewrite/const.py +45 -0
  57. kirin/dialects/ilist/rewrite/list.py +38 -0
  58. kirin/dialects/ilist/rewrite/unroll.py +131 -0
  59. kirin/dialects/ilist/runtime.py +63 -0
  60. kirin/dialects/ilist/stmts.py +102 -0
  61. kirin/dialects/ilist/typeinfer.py +120 -0
  62. kirin/dialects/lowering/__init__.py +7 -0
  63. kirin/dialects/lowering/call.py +48 -0
  64. kirin/dialects/lowering/cf.py +206 -0
  65. kirin/dialects/lowering/func.py +134 -0
  66. kirin/dialects/math/__init__.py +41 -0
  67. kirin/dialects/math/_gen.py +176 -0
  68. kirin/dialects/math/dialect.py +3 -0
  69. kirin/dialects/math/interp.py +190 -0
  70. kirin/dialects/math/stmts.py +369 -0
  71. kirin/dialects/module.py +139 -0
  72. kirin/dialects/py/__init__.py +40 -0
  73. kirin/dialects/py/assertion.py +91 -0
  74. kirin/dialects/py/assign.py +103 -0
  75. kirin/dialects/py/attr.py +59 -0
  76. kirin/dialects/py/base.py +34 -0
  77. kirin/dialects/py/binop/__init__.py +23 -0
  78. kirin/dialects/py/binop/_dialect.py +3 -0
  79. kirin/dialects/py/binop/interp.py +60 -0
  80. kirin/dialects/py/binop/julia.py +33 -0
  81. kirin/dialects/py/binop/lowering.py +22 -0
  82. kirin/dialects/py/binop/stmts.py +79 -0
  83. kirin/dialects/py/binop/typeinfer.py +108 -0
  84. kirin/dialects/py/boolop.py +84 -0
  85. kirin/dialects/py/builtin.py +78 -0
  86. kirin/dialects/py/cmp/__init__.py +16 -0
  87. kirin/dialects/py/cmp/_dialect.py +3 -0
  88. kirin/dialects/py/cmp/interp.py +48 -0
  89. kirin/dialects/py/cmp/julia.py +33 -0
  90. kirin/dialects/py/cmp/lowering.py +45 -0
  91. kirin/dialects/py/cmp/stmts.py +62 -0
  92. kirin/dialects/py/constant.py +79 -0
  93. kirin/dialects/py/indexing.py +251 -0
  94. kirin/dialects/py/iterable.py +90 -0
  95. kirin/dialects/py/len.py +57 -0
  96. kirin/dialects/py/list/__init__.py +15 -0
  97. kirin/dialects/py/list/_dialect.py +3 -0
  98. kirin/dialects/py/list/interp.py +21 -0
  99. kirin/dialects/py/list/lowering.py +25 -0
  100. kirin/dialects/py/list/stmts.py +22 -0
  101. kirin/dialects/py/list/typeinfer.py +54 -0
  102. kirin/dialects/py/range.py +76 -0
  103. kirin/dialects/py/slice.py +120 -0
  104. kirin/dialects/py/tuple.py +109 -0
  105. kirin/dialects/py/unary/__init__.py +24 -0
  106. kirin/dialects/py/unary/_dialect.py +3 -0
  107. kirin/dialects/py/unary/constprop.py +20 -0
  108. kirin/dialects/py/unary/interp.py +24 -0
  109. kirin/dialects/py/unary/julia.py +21 -0
  110. kirin/dialects/py/unary/lowering.py +22 -0
  111. kirin/dialects/py/unary/stmts.py +33 -0
  112. kirin/dialects/py/unary/typeinfer.py +23 -0
  113. kirin/dialects/py/unpack.py +90 -0
  114. kirin/dialects/scf/__init__.py +23 -0
  115. kirin/dialects/scf/_dialect.py +3 -0
  116. kirin/dialects/scf/absint.py +64 -0
  117. kirin/dialects/scf/constprop.py +140 -0
  118. kirin/dialects/scf/interp.py +35 -0
  119. kirin/dialects/scf/lowering.py +123 -0
  120. kirin/dialects/scf/stmts.py +250 -0
  121. kirin/dialects/scf/trim.py +36 -0
  122. kirin/dialects/scf/typeinfer.py +58 -0
  123. kirin/dialects/scf/unroll.py +92 -0
  124. kirin/emit/__init__.py +3 -0
  125. kirin/emit/abc.py +89 -0
  126. kirin/emit/abc.pyi +38 -0
  127. kirin/emit/exceptions.py +5 -0
  128. kirin/emit/julia.py +63 -0
  129. kirin/emit/str.py +51 -0
  130. kirin/exceptions.py +59 -0
  131. kirin/graph.py +34 -0
  132. kirin/idtable.py +57 -0
  133. kirin/interp/__init__.py +39 -0
  134. kirin/interp/abstract.py +253 -0
  135. kirin/interp/base.py +438 -0
  136. kirin/interp/concrete.py +62 -0
  137. kirin/interp/exceptions.py +26 -0
  138. kirin/interp/frame.py +151 -0
  139. kirin/interp/impl.py +197 -0
  140. kirin/interp/result.py +93 -0
  141. kirin/interp/state.py +71 -0
  142. kirin/interp/table.py +40 -0
  143. kirin/interp/value.py +73 -0
  144. kirin/ir/__init__.py +46 -0
  145. kirin/ir/attrs/__init__.py +20 -0
  146. kirin/ir/attrs/_types.py +8 -0
  147. kirin/ir/attrs/_types.pyi +13 -0
  148. kirin/ir/attrs/abc.py +46 -0
  149. kirin/ir/attrs/py.py +45 -0
  150. kirin/ir/attrs/types.py +522 -0
  151. kirin/ir/dialect.py +125 -0
  152. kirin/ir/group.py +249 -0
  153. kirin/ir/method.py +118 -0
  154. kirin/ir/nodes/__init__.py +7 -0
  155. kirin/ir/nodes/base.py +149 -0
  156. kirin/ir/nodes/block.py +458 -0
  157. kirin/ir/nodes/region.py +337 -0
  158. kirin/ir/nodes/stmt.py +713 -0
  159. kirin/ir/nodes/view.py +142 -0
  160. kirin/ir/ssa.py +204 -0
  161. kirin/ir/traits/__init__.py +36 -0
  162. kirin/ir/traits/abc.py +42 -0
  163. kirin/ir/traits/basic.py +78 -0
  164. kirin/ir/traits/callable.py +51 -0
  165. kirin/ir/traits/lowering/__init__.py +2 -0
  166. kirin/ir/traits/lowering/call.py +37 -0
  167. kirin/ir/traits/lowering/context.py +120 -0
  168. kirin/ir/traits/region/__init__.py +2 -0
  169. kirin/ir/traits/region/ssacfg.py +22 -0
  170. kirin/ir/traits/symbol.py +57 -0
  171. kirin/ir/use.py +17 -0
  172. kirin/lattice/__init__.py +13 -0
  173. kirin/lattice/abc.py +128 -0
  174. kirin/lattice/empty.py +25 -0
  175. kirin/lattice/mixin.py +51 -0
  176. kirin/lowering/__init__.py +7 -0
  177. kirin/lowering/binding.py +65 -0
  178. kirin/lowering/core.py +72 -0
  179. kirin/lowering/dialect.py +35 -0
  180. kirin/lowering/dialect.pyi +183 -0
  181. kirin/lowering/frame.py +171 -0
  182. kirin/lowering/result.py +68 -0
  183. kirin/lowering/state.py +441 -0
  184. kirin/lowering/stream.py +53 -0
  185. kirin/passes/__init__.py +3 -0
  186. kirin/passes/abc.py +44 -0
  187. kirin/passes/aggressive/__init__.py +1 -0
  188. kirin/passes/aggressive/fold.py +43 -0
  189. kirin/passes/fold.py +45 -0
  190. kirin/passes/inline.py +25 -0
  191. kirin/passes/typeinfer.py +25 -0
  192. kirin/prelude.py +197 -0
  193. kirin/print/__init__.py +15 -0
  194. kirin/print/printable.py +141 -0
  195. kirin/print/printer.py +415 -0
  196. kirin/py.typed +0 -0
  197. kirin/registry.py +105 -0
  198. kirin/registry.pyi +52 -0
  199. kirin/rewrite/__init__.py +14 -0
  200. kirin/rewrite/abc.py +43 -0
  201. kirin/rewrite/aggressive/__init__.py +1 -0
  202. kirin/rewrite/aggressive/fold.py +43 -0
  203. kirin/rewrite/alias.py +16 -0
  204. kirin/rewrite/apply_type.py +47 -0
  205. kirin/rewrite/call2invoke.py +34 -0
  206. kirin/rewrite/chain.py +39 -0
  207. kirin/rewrite/compactify.py +288 -0
  208. kirin/rewrite/cse.py +48 -0
  209. kirin/rewrite/dce.py +19 -0
  210. kirin/rewrite/fixpoint.py +34 -0
  211. kirin/rewrite/fold.py +57 -0
  212. kirin/rewrite/getfield.py +21 -0
  213. kirin/rewrite/getitem.py +37 -0
  214. kirin/rewrite/inline.py +143 -0
  215. kirin/rewrite/result.py +15 -0
  216. kirin/rewrite/walk.py +83 -0
  217. kirin/rewrite/wrap_const.py +55 -0
  218. kirin/source.py +21 -0
  219. kirin/symbol_table.py +27 -0
  220. kirin/types.py +34 -0
  221. kirin/worklist.py +30 -0
  222. kirin_toolchain-0.13.0.dist-info/METADATA +42 -0
  223. kirin_toolchain-0.13.0.dist-info/RECORD +225 -0
  224. kirin_toolchain-0.13.0.dist-info/WHEEL +4 -0
  225. kirin_toolchain-0.13.0.dist-info/licenses/LICENSE +234 -0
@@ -0,0 +1,47 @@
1
+ from dataclasses import dataclass
2
+
3
+ from kirin import ir, types
4
+ from kirin.rewrite.abc import RewriteRule, RewriteResult
5
+ from kirin.dialects.func.attrs import Signature
6
+
7
+
8
@dataclass
class ApplyType(RewriteRule):
    """Write inferred types back onto the IR.

    `results` maps SSA values (block arguments and statement results) to the
    type attribute that should be assigned to them. A statement that carries
    both the `HasSignature` and `CallableStmtInterface` traits additionally
    gets its signature refreshed when its single result is inferred to be a
    `types.Generic` that is a subtype of `types.MethodType`.
    """

    results: dict[ir.SSAValue, types.TypeAttribute]

    def rewrite_Block(self, node: ir.Block) -> RewriteResult:
        """Assign the inferred type to every block argument found in `results`."""
        changed = False
        for block_arg in node.args:
            if block_arg not in self.results:
                continue
            block_arg.type = self.results[block_arg]
            changed = True
        return RewriteResult(has_done_something=changed)

    def rewrite_Statement(self, node: ir.Statement) -> RewriteResult:
        """Assign inferred result types and, for callables, refresh the signature."""
        changed = False
        for ssa in node._results:
            if ssa not in self.results:
                continue
            ssa.type = self.results[ssa]
            changed = True

        signature_trait = node.get_trait(ir.HasSignature)
        if signature_trait is None:
            return RewriteResult(has_done_something=changed)
        callable_trait = node.get_trait(ir.CallableStmtInterface)
        if callable_trait is None:
            return RewriteResult(has_done_something=changed)

        region = callable_trait.get_callable_region(node)
        # prefer the freshly inferred type, fall back to the current annotation
        entry_args = region.blocks[0].args
        inputs = tuple(self.results.get(a, a.type) for a in entry_args)

        if len(node._results) != 1:
            return RewriteResult(has_done_something=changed)
        inferred = self.results.get(node._results[0])
        if isinstance(inferred, types.Generic) and inferred.is_subseteq(
            types.MethodType
        ):
            # vars[1] of the inferred method type becomes the signature output
            signature_trait.set_signature(node, Signature(inputs, inferred.vars[1]))
            changed = True
        return RewriteResult(has_done_something=changed)
@@ -0,0 +1,34 @@
1
+ from dataclasses import dataclass
2
+
3
+ from kirin import ir
4
+ from kirin.analysis import const
5
+ from kirin.rewrite.abc import RewriteRule, RewriteResult
6
+ from kirin.dialects.func import Call, Invoke
7
+
8
+
9
@dataclass
class Call2Invoke(RewriteRule):
    """Rewrite a `Call` statement to an `Invoke` statement.

    The rewrite only fires when the callee SSA value carries a `"const"` hint
    that is a `const.Value` wrapping an `ir.Method`; that method becomes the
    static callee of the new `Invoke`. The name, type and `"const"` hint of
    each call result are copied onto the corresponding new result.
    """

    def rewrite_Statement(self, node: ir.Statement) -> RewriteResult:
        if not isinstance(node, Call):
            return RewriteResult()

        if (mt := node.callee.hints.get("const")) is None:
            return RewriteResult()

        if not isinstance(mt, const.Value):
            return RewriteResult()

        if not isinstance(mt.data, ir.Method):
            return RewriteResult()

        stmt = Invoke(inputs=node.inputs, callee=mt.data, kwargs=node.kwargs)
        for result, new_result in zip(node.results, stmt.results):
            new_result.name = result.name
            new_result.type = result.type
            # compare against None (not truthiness) so that a hint object that
            # happens to evaluate as falsy is still propagated
            if (result_hint := result.hints.get("const")) is not None:
                new_result.hints["const"] = result_hint

        node.replace_by(stmt)
        return RewriteResult(has_done_something=True)
kirin/rewrite/chain.py ADDED
@@ -0,0 +1,39 @@
1
+ from typing import Iterable
2
+ from dataclasses import dataclass
3
+
4
+ from kirin.ir import IRNode
5
+ from kirin.rewrite.abc import RewriteRule
6
+ from kirin.rewrite.result import RewriteResult
7
+
8
+
9
@dataclass
class Chain(RewriteRule):
    """Chain multiple rewrites together.

    The rules run in order on the same node. If a rule reports `terminated`,
    its result is returned immediately and the remaining rules are skipped;
    otherwise the chain reports whether any rule did something.
    """

    rules: list[RewriteRule]

    def __init__(self, rule: RewriteRule | Iterable[RewriteRule], *others: RewriteRule):
        # accept either `Chain(r1, r2, ...)` or `Chain([r1, r2, ...])`
        if not isinstance(rule, RewriteRule):
            assert (
                others == ()
            ), "Cannot pass multiple positional arguments if the first argument is an iterable"
            self.rules = list(rule)
            return
        self.rules = [rule, *others]

    def rewrite(self, node: IRNode) -> RewriteResult:
        changed = False
        for current in self.rules:
            outcome = current.rewrite(node)
            if outcome.terminated:
                return outcome
            if outcome.has_done_something:
                changed = True
        return RewriteResult(has_done_something=changed)

    def __repr__(self):
        return " -> ".join(str(r) for r in self.rules)
@@ -0,0 +1,288 @@
1
+ from dataclasses import dataclass
2
+
3
+ from kirin import ir
4
+ from kirin.dialects import cf
5
+ from kirin.rewrite.abc import RewriteRule, RewriteResult
6
+ from kirin.analysis.cfg import CFG
7
+ from kirin.rewrite.walk import Walk
8
+ from kirin.rewrite.chain import Chain
9
+ from kirin.rewrite.fixpoint import Fixpoint
10
+
11
+
12
@dataclass
class DeadBlock(RewriteRule):
    """Compactify the CFG by removing dead blocks.

    A non-entry block is treated as dead when the CFG records no predecessors
    for it. Its outgoing edges are removed from the CFG (in place) before the
    block is deleted.
    """

    cfg: CFG

    def rewrite_Region(self, node: ir.Region) -> RewriteResult:
        # TODO: check if this region is using SSACFG convention?
        removed_any = False
        # index 0 is the entry block, which is never removed
        for candidate in node.blocks[1:]:
            if self.cfg.predecessors.get(candidate):
                continue  # still has at least one incoming edge
            for succ in self.cfg.successors.get(candidate, set()):
                self.cfg.predecessors[succ].discard(candidate)
            self.cfg.successors.pop(candidate, None)
            self.cfg.predecessors.pop(candidate, None)
            candidate.delete()
            removed_any = True
        return RewriteResult(has_done_something=removed_any)
33
+
34
+
35
@dataclass
class CFGEdge(RewriteRule):
    """Merge non-branching blocks on the edge of the CFG.

    Example:

            /---> [B] --> [D] --> [E]
        [A]-----> [C] -------------^

    [B] and [D] are non-branching blocks on the same edge. They can be merged into one block.

            /---> [B,D] --> [E]
        [A]-----> [C] -------^

    A block is merged with its successor only when:
    - the block has exactly one successor in the CFG,
    - that successor has exactly one predecessor,
    - the successor is not the block itself (a self-loop),
    - the block ends in an unconditional `cf.Branch`,
    - and, when driven through `rewrite_Region`, the successor is not the
      region's entry block.
    """

    cfg: CFG

    def rewrite_Region(self, node: ir.Region) -> RewriteResult:
        result = RewriteResult()
        # iterate over a snapshot: merging deletes blocks from the region, and
        # deleting from the underlying list while iterating it can skip blocks
        blocks = list(node.blocks)
        if not blocks:
            return result
        entry = blocks[0]
        for block in blocks:
            # the entry block is entered implicitly rather than through a CFG
            # edge, so it must never be merged into (and deleted with) one of
            # its predecessors
            if entry in self.cfg.successors.get(block, ()):
                continue
            result = self.rewrite_Block(block).join(result)
        return result

    def rewrite_Block(self, node: ir.Block) -> RewriteResult:
        successors = self.cfg.successors.get(node, None)
        if not successors or len(successors) > 1:
            # no outgoing edge, or multiple outgoing edges
            return RewriteResult()

        successor = next(iter(successors))
        if successor is node:
            # a self-loop cannot be merged with itself
            return RewriteResult()

        if len(self.cfg.predecessors[successor]) > 1:  # multiple incoming edges
            return RewriteResult()

        if not ((last_stmt := node.last_stmt) and isinstance(last_stmt, cf.Branch)):
            return RewriteResult()

        # merge the two blocks: forward the branch arguments into the
        # successor's block arguments, drop the branch, then move the
        # successor's statements to the end of `node`
        for arg, value in zip(successor.args, last_stmt.arguments):
            arg.replace_by(value)
        last_stmt.delete()
        for stmt in list(successor.stmts):
            stmt.detach()
            node.stmts.append(stmt)
        successor.delete()

        # update the CFG: `node` inherits the successor's outgoing edges; a
        # successor without outgoing edges may not be a key of `successors`
        new_successors = self.cfg.successors.pop(successor, set())
        self.cfg.successors[node] = new_successors
        for new_successor in new_successors:
            preds = self.cfg.predecessors.setdefault(new_successor, set())
            preds.discard(successor)
            preds.add(node)
        self.cfg.predecessors.pop(successor, None)  # this was just {node}
        return RewriteResult(has_done_something=True)
90
+
91
+
92
class DuplicatedBranch(RewriteRule):
    """Merge duplicated branches into a single branch.

    Example:

        [A]-->[B]
           -----^

    Merge the two branches into one without changing the CFG:

        [A]-->[B]

    Applies to a `cf.ConditionalBranch` whose then/else successors are the
    same block and whose then/else arguments are pairwise the same SSA values.
    """

    def rewrite_Statement(self, node: ir.Statement) -> RewriteResult:
        if not isinstance(node, cf.ConditionalBranch):
            return RewriteResult()
        if node.then_successor is not node.else_successor:
            return RewriteResult()

        same_args = all(
            lhs is rhs for lhs, rhs in zip(node.then_arguments, node.else_arguments)
        )
        if not same_args:
            return RewriteResult()

        replacement = cf.Branch(
            arguments=node.then_arguments, successor=node.then_successor
        )
        node.replace_by(replacement)
        return RewriteResult(has_done_something=True)
120
+
121
+
122
@dataclass
class SkipBlock(RewriteRule):
    """Simplify a block that only contains a branch statement.

    When a block consists of a single `cf.Branch`, has exactly one
    predecessor in the CFG, and its block arguments are used by nothing but
    that branch, the predecessor's terminator is redirected straight to the
    block's successor. The block's arguments are substituted by the values the
    predecessor passed in. The skipped block itself stays in the region; only
    its incoming CFG edge from the predecessor is removed.
    """

    cfg: CFG

    def rewrite_Region(self, node: ir.Region) -> RewriteResult:
        result = RewriteResult()
        for block in node.blocks:
            result = self.rewrite_Block(block).join(result)
        return result

    def rewrite_Block(self, node: ir.Block) -> RewriteResult:
        if len(node.stmts) != 1:
            return RewriteResult()

        stmt = node.last_stmt
        if not isinstance(stmt, cf.Branch):
            return RewriteResult()

        # A block that branches to itself cannot be skipped: if it has a single
        # predecessor, that predecessor is the block itself, and "redirecting"
        # it rebuilds the same self-loop while reporting progress, so a
        # fixpoint driver would never converge.
        if stmt.successor is node:
            return RewriteResult()

        has_done_something = False
        predecessors = self.cfg.predecessors.get(node, set())
        # only if there is one predecessor and no uses of the arguments
        if len(predecessors) == 1 and all(
            self.can_skip(stmt, each) for each in node.args
        ):
            has_done_something = self.rewrite_pred(node, stmt, next(iter(predecessors)))
        return RewriteResult(has_done_something=has_done_something)

    def can_skip(self, terminator: cf.Branch, value: ir.SSAValue) -> bool:
        """Return True if every use of `value` is in `terminator`."""
        return all(use.stmt is terminator for use in value.uses)

    def rewrite_pred(
        self, node: ir.Block, node_terminator: cf.Branch, predecessor: ir.Block
    ) -> bool:
        """Redirect `predecessor`'s terminator past `node`; return True on change."""
        terminator = predecessor.last_stmt
        if isinstance(terminator, cf.Branch):
            return self.rewrite_pred_Branch(
                node, node_terminator, predecessor, terminator
            )
        elif isinstance(terminator, cf.ConditionalBranch):
            return self.rewrite_pred_ConditionalBranch(
                node, node_terminator, predecessor, terminator
            )
        return False

    def rewrite_pred_Branch(
        self,
        node: ir.Block,
        node_terminator: cf.Branch,
        predecessor: ir.Block,
        pred_terminator: cf.Branch,
    ) -> bool:
        ssamap = self._block_inputs(node, pred_terminator.arguments)
        pred_terminator.replace_by(
            cf.Branch(
                # NOTE: the argument can also be SSAs from previous blocks (non-phi)
                arguments=tuple(
                    ssamap.get(arg, arg) for arg in node_terminator.arguments
                ),
                successor=node_terminator.successor,
            )
        )

        self.fix_cfg(predecessor, node, node_terminator.successor)
        return True

    def rewrite_pred_ConditionalBranch(
        self,
        node: ir.Block,
        node_terminator: cf.Branch,
        predecessor: ir.Block,
        pred_terminator: cf.ConditionalBranch,
    ) -> bool:
        if (
            pred_terminator.then_successor is not node
            and pred_terminator.else_successor is not node
        ):
            # the terminator does not target `node`; nothing to redirect, so
            # avoid rebuilding an identical branch
            return False

        then_arguments = pred_terminator.then_arguments
        else_arguments = pred_terminator.else_arguments
        then_successor = pred_terminator.then_successor
        else_successor = pred_terminator.else_successor

        if pred_terminator.then_successor is node:
            ssamap = self._block_inputs(node, pred_terminator.then_arguments)
            then_arguments = tuple(
                ssamap.get(arg, arg) for arg in node_terminator.arguments
            )
            then_successor = node_terminator.successor
            self.fix_cfg(predecessor, node, then_successor)

        if pred_terminator.else_successor is node:
            ssamap = self._block_inputs(node, pred_terminator.else_arguments)
            else_arguments = tuple(
                ssamap.get(arg, arg) for arg in node_terminator.arguments
            )
            else_successor = node_terminator.successor
            self.fix_cfg(predecessor, node, else_successor)

        pred_terminator.replace_by(
            cf.ConditionalBranch(
                cond=pred_terminator.cond,
                then_arguments=then_arguments,
                then_successor=then_successor,
                else_arguments=else_arguments,
                else_successor=else_successor,
            )
        )
        return True

    def fix_cfg(self, predecessor: ir.Block, node: ir.Block, successor: ir.Block):
        """Replace the CFG edge predecessor->node with predecessor->successor."""
        node_pred_succ = self.cfg.successors.setdefault(predecessor, set())
        node_pred_succ.discard(node)
        node_pred_succ.add(successor)

        node_succ_pred = self.cfg.predecessors.setdefault(successor, set())
        node_succ_pred.add(predecessor)

        node_pred = self.cfg.predecessors.setdefault(node, set())
        node_pred.discard(predecessor)

    def _block_inputs(
        self, block: ir.Block, arguments: tuple[ir.SSAValue, ...]
    ) -> dict[ir.SSAValue, ir.SSAValue]:
        """Map each block argument of `block` to the value passed for it."""
        return dict(zip(block.args, arguments))
250
+
251
+
252
@dataclass
class CompactifyRegion(RewriteRule):
    """Wrapper to share the CFG object with same CFG region.

    Every CFG-aware sub-rule is built around the same `cfg` instance, so the
    CFG edits made by one rule are visible to the others. The combined chain
    is applied until a fixpoint is reached.
    """

    cfg: CFG

    def __init__(self, cfg: CFG):
        self.cfg = cfg
        pipeline = Chain(
            DeadBlock(cfg),
            Walk(DuplicatedBranch()),
            SkipBlock(cfg),
            CFGEdge(cfg),
        )
        self.rule = Fixpoint(pipeline)

    def rewrite(self, node: ir.IRNode) -> RewriteResult:
        return self.rule.rewrite(node)
268
+
269
+
270
@dataclass
class CFGCompactify(RewriteRule):
    """Compactify the CFG by removing dead blocks and merging blocks
    if the statement uses the SSACFG convention. Do nothing if given
    `ir.Region` or `ir.Block` due to no context of the region.

    To compactify hierarchical CFG, combine this rule with `kirin.rewrite.Walk`
    to recursively apply this rule to all statements.
    """

    def rewrite_Statement(self, node: ir.Statement) -> RewriteResult:
        trait = node.get_trait(ir.SSACFGRegion)
        if not trait:
            return RewriteResult()

        combined = RewriteResult()
        for region in node.regions:
            # a fresh CFG per region, shared by all sub-rules for that region
            region_cfg = trait.get_graph(region)
            combined = CompactifyRegion(region_cfg).rewrite(region).join(combined)
        return combined
kirin/rewrite/cse.py ADDED
@@ -0,0 +1,48 @@
1
+ from dataclasses import dataclass
2
+
3
+ from kirin.ir import Pure, Block, Statement
4
+ from kirin.rewrite.abc import RewriteRule, RewriteResult
5
+
6
+
7
@dataclass
class CommonSubexpressionElimination(RewriteRule):
    """Eliminate duplicated pure statements within a block.

    Two pure, region-free statements are considered equivalent when they have
    the same Python type, the very same SSA operands (by identity), equal
    attributes (by name and value) and the very same successor blocks (by
    identity). For each later duplicate, every result is replaced by the
    corresponding result of the first occurrence and the duplicate is deleted.
    """

    def rewrite_Block(self, node: Block) -> RewriteResult:
        seen: dict[tuple, Statement] = {}
        has_done_something = False

        # iterate over a snapshot because duplicates are deleted on the fly
        for stmt in list(node.stmts):
            if not stmt.has_trait(Pure):
                continue

            if stmt.regions:
                continue

            key = self._key(stmt)
            old_stmt = seen.get(key)
            if old_stmt is None:
                seen[key] = stmt
                continue

            # defensive: only merge when results can be paired one-to-one
            if len(stmt._results) != len(old_stmt._results):
                continue

            for result, old_result in zip(stmt._results, old_stmt._results):
                result.replace_by(old_result)
            stmt.delete()
            has_done_something = True
        return RewriteResult(has_done_something=has_done_something)

    @staticmethod
    def _key(stmt: Statement) -> tuple:
        """Build an equality key identifying equivalent pure statements.

        The full key (not just its hash) is used for lookup so that hash
        collisions can never merge non-equivalent statements. SSA operands and
        successors are compared by identity via `id()`; the ids stay valid
        because every statement stored in `seen` keeps its operands alive.
        """
        return (
            type(stmt),
            tuple(id(arg) for arg in stmt.args),
            tuple(sorted(stmt.attributes.items(), key=lambda item: item[0])),
            tuple(id(succ) for succ in stmt.successors),
        )

    def rewrite_Statement(self, node: Statement) -> RewriteResult:
        if not node.regions:
            return RewriteResult()

        has_done_something = False
        for region in node.regions:
            for block in region.blocks:
                if self.rewrite_Block(block).has_done_something:
                    has_done_something = True

        return RewriteResult(has_done_something=has_done_something)
kirin/rewrite/dce.py ADDED
@@ -0,0 +1,19 @@
1
+ from dataclasses import dataclass
2
+
3
+ from kirin import ir
4
+ from kirin.rewrite.abc import RewriteRule, RewriteResult
5
+
6
+
7
@dataclass
class DeadCodeElimination(RewriteRule):
    """Delete a statement when `is_pure` holds and none of its results are used."""

    def rewrite_Statement(self, node: ir.Statement) -> RewriteResult:
        if not self.is_pure(node):
            return RewriteResult()
        if any(result.uses for result in node._results):
            # at least one result is still consumed somewhere
            return RewriteResult()
        node.delete()
        return RewriteResult(has_done_something=True)
@@ -0,0 +1,34 @@
1
+ from dataclasses import dataclass
2
+
3
+ from kirin.ir import IRNode
4
+ from kirin.rewrite.abc import RewriteRule
5
+ from kirin.rewrite.result import RewriteResult
6
+
7
+
8
@dataclass
class Fixpoint(RewriteRule):
    """Apply a rewrite rule until a fixpoint is reached.

    The rewrite rule is applied to the node until the rewrite rule does not do anything.

    ### Parameters
    - `rule`: The rewrite rule to apply.
    - `max_iter`: The maximum number of iterations to apply the rewrite rule. Default is 32.

    ### Returns
    - the inner result unchanged if the rule reports `terminated`;
    - otherwise a result whose `has_done_something` is true if any iteration
      changed the IR, with `exceeded_max_iter` set when all `max_iter`
      iterations reported progress.
    """

    rule: RewriteRule
    max_iter: int = 32

    def rewrite(self, node: IRNode) -> RewriteResult:
        has_done_something = False
        for _ in range(self.max_iter):
            result = self.rule.rewrite(node)
            if result.terminated:
                return result

            if not result.has_done_something:
                return RewriteResult(has_done_something=has_done_something)
            has_done_something = True

        # still report the progress made: the IR *was* changed even though no
        # fixpoint was reached within `max_iter` iterations
        return RewriteResult(
            has_done_something=has_done_something, exceeded_max_iter=True
        )
kirin/rewrite/fold.py ADDED
@@ -0,0 +1,57 @@
1
+ from dataclasses import dataclass
2
+
3
+ from kirin import ir
4
+ from kirin.analysis import const
5
+ from kirin.dialects import cf
6
+ from kirin.rewrite.abc import RewriteRule, RewriteResult
7
+ from kirin.dialects.py.constant import Constant
8
+
9
+
10
@dataclass
class ConstantFold(RewriteRule):
    """Replace uses of constant-valued results with `Constant` statements.

    - `ConstantLike` statements are left alone.
    - A `cf.ConditionalBranch` whose condition is a known `True`/`False` is
      replaced by an unconditional `cf.Branch`.
    - For other statements that `is_pure` accepts, every *used* result
      carrying a `const.Value` hint is rewired to a new `Constant` inserted
      before the statement.
    """

    def get_const(self, value: ir.SSAValue):
        """Return the `"const"` hint of `value` if it is a `const.Value`, else None."""
        ret = value.hints.get("const")
        if isinstance(ret, const.Value):
            return ret
        return None

    def rewrite_Statement(self, node: ir.Statement) -> RewriteResult:
        if node.has_trait(ir.ConstantLike):
            return RewriteResult()
        elif isinstance(node, cf.ConditionalBranch):
            return self.rewrite_cf_ConditionalBranch(node)

        if not self.is_pure(node):
            return RewriteResult()

        has_done_something = False
        for old_result in node.results:
            if not old_result.uses:
                # nothing to rewire: materializing a Constant would only add a
                # dead statement and, being reported as progress on every
                # re-run, keep a fixpoint driver from converging
                continue
            if (value := self.get_const(old_result)) is None:
                continue
            stmt = Constant(value.data)
            stmt.insert_before(node)
            old_result.replace_by(stmt.result)
            stmt.result.hints["const"] = value
            if old_result.name:
                stmt.result.name = old_result.name
            has_done_something = True
        return RewriteResult(has_done_something=has_done_something)

    def rewrite_cf_ConditionalBranch(self, node: cf.ConditionalBranch):
        """Turn a branch on a constant `True`/`False` condition into a `cf.Branch`.

        Raises:
            ValueError: if the constant condition is neither `True` nor `False`.
        """
        if (value := self.get_const(node.cond)) is not None:
            if value.data is True:
                cf.Branch(
                    arguments=node.then_arguments,
                    successor=node.then_successor,
                ).insert_before(node)
            elif value.data is False:
                cf.Branch(
                    arguments=node.else_arguments,
                    successor=node.else_successor,
                ).insert_before(node)
            else:
                raise ValueError(f"Invalid constant value for branch: {value.data}")
            node.delete()
            return RewriteResult(has_done_something=True)
        return RewriteResult()
@@ -0,0 +1,21 @@
1
+ from dataclasses import dataclass
2
+
3
+ from kirin import ir
4
+ from kirin.dialects import func
5
+ from kirin.rewrite.abc import RewriteRule, RewriteResult
6
+
7
+
8
@dataclass
class InlineGetField(RewriteRule):
    """Forward a `func.GetField` on a `func.Lambda` to the captured value.

    The users of the `GetField` result are rewired to the lambda's captured
    SSA value at index `node.field`, and the `GetField` is deleted.
    """

    def rewrite_Statement(self, node: ir.Statement) -> RewriteResult:
        if not isinstance(node, func.GetField):
            return RewriteResult()

        lambda_stmt = node.obj.owner
        if not isinstance(lambda_stmt, func.Lambda):
            return RewriteResult()

        captured_value = lambda_stmt.captured[node.field]
        node.result.replace_by(captured_value)
        node.delete()
        return RewriteResult(has_done_something=True)
@@ -0,0 +1,37 @@
1
+ from dataclasses import dataclass
2
+
3
+ from kirin import ir
4
+ from kirin.analysis import const
5
+ from kirin.dialects import py
6
+ from kirin.rewrite.abc import RewriteRule, RewriteResult
7
+
8
+
9
@dataclass
class InlineGetItem(RewriteRule):
    """Fold a `py.indexing.GetItem` on a `py.tuple.New` with a constant index.

    - An in-range integer index forwards the selected tuple element to the
      users of the `GetItem` result and deletes the `GetItem`.
    - A slice index replaces the `GetItem` with a new `py.tuple.New` holding
      the selected elements, in slice order.
    - Anything else (non-constant index, out-of-range integer, other index
      types) is left untouched.
    """

    def rewrite_Statement(self, node: ir.Statement) -> RewriteResult:
        if not isinstance(node, py.indexing.GetItem):
            return RewriteResult()

        if not isinstance(node.obj.owner, py.tuple.New):
            return RewriteResult()

        if not isinstance(index_value := node.index.hints.get("const"), const.Value):
            return RewriteResult()

        stmt = node.obj.owner
        index = index_value.data
        size = len(stmt.args)
        if isinstance(index, int) and -size <= index < size:
            node.result.replace_by(stmt.args[index])
            # the GetItem is now unused; remove it so that re-running this rule
            # does not keep reporting progress on the same statement
            node.delete()
            return RewriteResult(has_done_something=True)
        elif isinstance(index, slice):
            # `slice.indices()` normalizes the slice for a sequence of length
            # `size`, but its (start, stop, step) must not be reused as a slice:
            # for `t[::-1]` it yields stop == -1, which as a slice bound means
            # "the last element" and would produce an empty tuple. Iterating
            # the equivalent range selects the right elements.
            elements = tuple(stmt.args[i] for i in range(*index.indices(size)))
            node.replace_by(py.tuple.New(elements))
            return RewriteResult(has_done_something=True)
        else:
            return RewriteResult()