kirin-toolchain 0.13.0__py3-none-any.whl
Sign up to get free protection for your applications and to get access to all the features.
- kirin/__init__.py +7 -0
- kirin/analysis/__init__.py +24 -0
- kirin/analysis/callgraph.py +61 -0
- kirin/analysis/cfg.py +112 -0
- kirin/analysis/const/__init__.py +20 -0
- kirin/analysis/const/_visitor.py +2 -0
- kirin/analysis/const/_visitor.pyi +8 -0
- kirin/analysis/const/lattice.py +219 -0
- kirin/analysis/const/prop.py +116 -0
- kirin/analysis/forward.py +100 -0
- kirin/analysis/typeinfer/__init__.py +5 -0
- kirin/analysis/typeinfer/analysis.py +90 -0
- kirin/analysis/typeinfer/solve.py +141 -0
- kirin/decl/__init__.py +108 -0
- kirin/decl/base.py +65 -0
- kirin/decl/camel2snake.py +2 -0
- kirin/decl/emit/__init__.py +0 -0
- kirin/decl/emit/_create_fn.py +29 -0
- kirin/decl/emit/_set_new_attribute.py +22 -0
- kirin/decl/emit/dialect.py +8 -0
- kirin/decl/emit/init.py +277 -0
- kirin/decl/emit/name.py +10 -0
- kirin/decl/emit/property.py +182 -0
- kirin/decl/emit/repr.py +31 -0
- kirin/decl/emit/traits.py +13 -0
- kirin/decl/emit/typecheck.py +77 -0
- kirin/decl/emit/verify.py +51 -0
- kirin/decl/info.py +346 -0
- kirin/decl/scan_fields.py +157 -0
- kirin/decl/verify.py +69 -0
- kirin/dialects/__init__.py +14 -0
- kirin/dialects/_pprint_helper.py +53 -0
- kirin/dialects/cf/__init__.py +20 -0
- kirin/dialects/cf/constprop.py +51 -0
- kirin/dialects/cf/dialect.py +3 -0
- kirin/dialects/cf/emit.py +58 -0
- kirin/dialects/cf/interp.py +24 -0
- kirin/dialects/cf/stmts.py +68 -0
- kirin/dialects/cf/typeinfer.py +27 -0
- kirin/dialects/eltype.py +23 -0
- kirin/dialects/func/__init__.py +20 -0
- kirin/dialects/func/attrs.py +39 -0
- kirin/dialects/func/constprop.py +138 -0
- kirin/dialects/func/dialect.py +3 -0
- kirin/dialects/func/emit.py +80 -0
- kirin/dialects/func/interp.py +68 -0
- kirin/dialects/func/stmts.py +233 -0
- kirin/dialects/func/typeinfer.py +124 -0
- kirin/dialects/ilist/__init__.py +33 -0
- kirin/dialects/ilist/_dialect.py +3 -0
- kirin/dialects/ilist/_wrapper.py +51 -0
- kirin/dialects/ilist/interp.py +85 -0
- kirin/dialects/ilist/lowering.py +25 -0
- kirin/dialects/ilist/passes.py +32 -0
- kirin/dialects/ilist/rewrite/__init__.py +3 -0
- kirin/dialects/ilist/rewrite/const.py +45 -0
- kirin/dialects/ilist/rewrite/list.py +38 -0
- kirin/dialects/ilist/rewrite/unroll.py +131 -0
- kirin/dialects/ilist/runtime.py +63 -0
- kirin/dialects/ilist/stmts.py +102 -0
- kirin/dialects/ilist/typeinfer.py +120 -0
- kirin/dialects/lowering/__init__.py +7 -0
- kirin/dialects/lowering/call.py +48 -0
- kirin/dialects/lowering/cf.py +206 -0
- kirin/dialects/lowering/func.py +134 -0
- kirin/dialects/math/__init__.py +41 -0
- kirin/dialects/math/_gen.py +176 -0
- kirin/dialects/math/dialect.py +3 -0
- kirin/dialects/math/interp.py +190 -0
- kirin/dialects/math/stmts.py +369 -0
- kirin/dialects/module.py +139 -0
- kirin/dialects/py/__init__.py +40 -0
- kirin/dialects/py/assertion.py +91 -0
- kirin/dialects/py/assign.py +103 -0
- kirin/dialects/py/attr.py +59 -0
- kirin/dialects/py/base.py +34 -0
- kirin/dialects/py/binop/__init__.py +23 -0
- kirin/dialects/py/binop/_dialect.py +3 -0
- kirin/dialects/py/binop/interp.py +60 -0
- kirin/dialects/py/binop/julia.py +33 -0
- kirin/dialects/py/binop/lowering.py +22 -0
- kirin/dialects/py/binop/stmts.py +79 -0
- kirin/dialects/py/binop/typeinfer.py +108 -0
- kirin/dialects/py/boolop.py +84 -0
- kirin/dialects/py/builtin.py +78 -0
- kirin/dialects/py/cmp/__init__.py +16 -0
- kirin/dialects/py/cmp/_dialect.py +3 -0
- kirin/dialects/py/cmp/interp.py +48 -0
- kirin/dialects/py/cmp/julia.py +33 -0
- kirin/dialects/py/cmp/lowering.py +45 -0
- kirin/dialects/py/cmp/stmts.py +62 -0
- kirin/dialects/py/constant.py +79 -0
- kirin/dialects/py/indexing.py +251 -0
- kirin/dialects/py/iterable.py +90 -0
- kirin/dialects/py/len.py +57 -0
- kirin/dialects/py/list/__init__.py +15 -0
- kirin/dialects/py/list/_dialect.py +3 -0
- kirin/dialects/py/list/interp.py +21 -0
- kirin/dialects/py/list/lowering.py +25 -0
- kirin/dialects/py/list/stmts.py +22 -0
- kirin/dialects/py/list/typeinfer.py +54 -0
- kirin/dialects/py/range.py +76 -0
- kirin/dialects/py/slice.py +120 -0
- kirin/dialects/py/tuple.py +109 -0
- kirin/dialects/py/unary/__init__.py +24 -0
- kirin/dialects/py/unary/_dialect.py +3 -0
- kirin/dialects/py/unary/constprop.py +20 -0
- kirin/dialects/py/unary/interp.py +24 -0
- kirin/dialects/py/unary/julia.py +21 -0
- kirin/dialects/py/unary/lowering.py +22 -0
- kirin/dialects/py/unary/stmts.py +33 -0
- kirin/dialects/py/unary/typeinfer.py +23 -0
- kirin/dialects/py/unpack.py +90 -0
- kirin/dialects/scf/__init__.py +23 -0
- kirin/dialects/scf/_dialect.py +3 -0
- kirin/dialects/scf/absint.py +64 -0
- kirin/dialects/scf/constprop.py +140 -0
- kirin/dialects/scf/interp.py +35 -0
- kirin/dialects/scf/lowering.py +123 -0
- kirin/dialects/scf/stmts.py +250 -0
- kirin/dialects/scf/trim.py +36 -0
- kirin/dialects/scf/typeinfer.py +58 -0
- kirin/dialects/scf/unroll.py +92 -0
- kirin/emit/__init__.py +3 -0
- kirin/emit/abc.py +89 -0
- kirin/emit/abc.pyi +38 -0
- kirin/emit/exceptions.py +5 -0
- kirin/emit/julia.py +63 -0
- kirin/emit/str.py +51 -0
- kirin/exceptions.py +59 -0
- kirin/graph.py +34 -0
- kirin/idtable.py +57 -0
- kirin/interp/__init__.py +39 -0
- kirin/interp/abstract.py +253 -0
- kirin/interp/base.py +438 -0
- kirin/interp/concrete.py +62 -0
- kirin/interp/exceptions.py +26 -0
- kirin/interp/frame.py +151 -0
- kirin/interp/impl.py +197 -0
- kirin/interp/result.py +93 -0
- kirin/interp/state.py +71 -0
- kirin/interp/table.py +40 -0
- kirin/interp/value.py +73 -0
- kirin/ir/__init__.py +46 -0
- kirin/ir/attrs/__init__.py +20 -0
- kirin/ir/attrs/_types.py +8 -0
- kirin/ir/attrs/_types.pyi +13 -0
- kirin/ir/attrs/abc.py +46 -0
- kirin/ir/attrs/py.py +45 -0
- kirin/ir/attrs/types.py +522 -0
- kirin/ir/dialect.py +125 -0
- kirin/ir/group.py +249 -0
- kirin/ir/method.py +118 -0
- kirin/ir/nodes/__init__.py +7 -0
- kirin/ir/nodes/base.py +149 -0
- kirin/ir/nodes/block.py +458 -0
- kirin/ir/nodes/region.py +337 -0
- kirin/ir/nodes/stmt.py +713 -0
- kirin/ir/nodes/view.py +142 -0
- kirin/ir/ssa.py +204 -0
- kirin/ir/traits/__init__.py +36 -0
- kirin/ir/traits/abc.py +42 -0
- kirin/ir/traits/basic.py +78 -0
- kirin/ir/traits/callable.py +51 -0
- kirin/ir/traits/lowering/__init__.py +2 -0
- kirin/ir/traits/lowering/call.py +37 -0
- kirin/ir/traits/lowering/context.py +120 -0
- kirin/ir/traits/region/__init__.py +2 -0
- kirin/ir/traits/region/ssacfg.py +22 -0
- kirin/ir/traits/symbol.py +57 -0
- kirin/ir/use.py +17 -0
- kirin/lattice/__init__.py +13 -0
- kirin/lattice/abc.py +128 -0
- kirin/lattice/empty.py +25 -0
- kirin/lattice/mixin.py +51 -0
- kirin/lowering/__init__.py +7 -0
- kirin/lowering/binding.py +65 -0
- kirin/lowering/core.py +72 -0
- kirin/lowering/dialect.py +35 -0
- kirin/lowering/dialect.pyi +183 -0
- kirin/lowering/frame.py +171 -0
- kirin/lowering/result.py +68 -0
- kirin/lowering/state.py +441 -0
- kirin/lowering/stream.py +53 -0
- kirin/passes/__init__.py +3 -0
- kirin/passes/abc.py +44 -0
- kirin/passes/aggressive/__init__.py +1 -0
- kirin/passes/aggressive/fold.py +43 -0
- kirin/passes/fold.py +45 -0
- kirin/passes/inline.py +25 -0
- kirin/passes/typeinfer.py +25 -0
- kirin/prelude.py +197 -0
- kirin/print/__init__.py +15 -0
- kirin/print/printable.py +141 -0
- kirin/print/printer.py +415 -0
- kirin/py.typed +0 -0
- kirin/registry.py +105 -0
- kirin/registry.pyi +52 -0
- kirin/rewrite/__init__.py +14 -0
- kirin/rewrite/abc.py +43 -0
- kirin/rewrite/aggressive/__init__.py +1 -0
- kirin/rewrite/aggressive/fold.py +43 -0
- kirin/rewrite/alias.py +16 -0
- kirin/rewrite/apply_type.py +47 -0
- kirin/rewrite/call2invoke.py +34 -0
- kirin/rewrite/chain.py +39 -0
- kirin/rewrite/compactify.py +288 -0
- kirin/rewrite/cse.py +48 -0
- kirin/rewrite/dce.py +19 -0
- kirin/rewrite/fixpoint.py +34 -0
- kirin/rewrite/fold.py +57 -0
- kirin/rewrite/getfield.py +21 -0
- kirin/rewrite/getitem.py +37 -0
- kirin/rewrite/inline.py +143 -0
- kirin/rewrite/result.py +15 -0
- kirin/rewrite/walk.py +83 -0
- kirin/rewrite/wrap_const.py +55 -0
- kirin/source.py +21 -0
- kirin/symbol_table.py +27 -0
- kirin/types.py +34 -0
- kirin/worklist.py +30 -0
- kirin_toolchain-0.13.0.dist-info/METADATA +42 -0
- kirin_toolchain-0.13.0.dist-info/RECORD +225 -0
- kirin_toolchain-0.13.0.dist-info/WHEEL +4 -0
- kirin_toolchain-0.13.0.dist-info/licenses/LICENSE +234 -0
@@ -0,0 +1,250 @@
|
|
1
|
+
from kirin import ir, types
|
2
|
+
from kirin.decl import info, statement
|
3
|
+
from kirin.exceptions import VerificationError, DialectLoweringError
|
4
|
+
from kirin.print.printer import Printer
|
5
|
+
|
6
|
+
from ._dialect import dialect
|
7
|
+
|
8
|
+
|
9
|
+
@statement(dialect=dialect, init=False)
class IfElse(ir.Statement):
    """Python-like if-else statement.

    This statement has a condition, a then body, and an else body. Each body
    is a single-block region that terminates with either a `scf.yield`
    statement or a `func.return` statement.
    """

    name = "if"
    traits = frozenset({ir.MaybePure()})
    purity: bool = info.attribute(default=False)
    cond: ir.SSAValue = info.argument(types.Any)
    # NOTE: we don't enforce the type here
    # because anything implements __bool__ in Python
    # can be used as a condition
    then_body: ir.Region = info.region(multi=False)
    else_body: ir.Region = info.region(multi=False, default_factory=ir.Region)

    def __init__(
        self,
        cond: ir.SSAValue,
        then_body: ir.Region | ir.Block,
        else_body: ir.Region | ir.Block | None = None,
    ):
        """Build an if-else statement.

        Args:
            cond: the condition value.
            then_body: a single-block region, or a block that gets wrapped in
                a new region.
            else_body: a single-block region, a block, an empty region, or
                None. For None a new empty region is used.

        Raises:
            DialectLoweringError: if a non-empty region given for either body
                does not have exactly one block.
            TypeError: if a body is not a region, a block (or None for the
                else body).
        """
        if isinstance(then_body, ir.Region):
            if len(then_body.blocks) != 1:
                raise DialectLoweringError(
                    "if-else statement must have a single block in the then region"
                )
            then_body_region = then_body
            then_body = then_body_region.blocks[0]
        elif isinstance(then_body, ir.Block):
            then_body_region = ir.Region(then_body)
        else:
            # previously this fell through and failed with an
            # UnboundLocalError on `then_body_region`
            raise TypeError(
                f"then_body must be an ir.Region or ir.Block, got {type(then_body)}"
            )

        if isinstance(else_body, ir.Region):
            if not else_body.blocks:  # empty region
                else_body_region = else_body
                else_body = None
            elif len(else_body.blocks) != 1:
                raise DialectLoweringError(
                    "if-else statement must have a single block in the else region"
                )
            else:
                else_body_region = else_body
                else_body = else_body_region.blocks[0]
        elif isinstance(else_body, ir.Block):
            else_body_region = ir.Region(else_body)
        elif else_body is None:
            else_body_region = ir.Region()
        else:
            # previously this failed later with an AttributeError on `.last_stmt`
            raise TypeError(
                f"else_body must be an ir.Region, ir.Block or None, got {type(else_body)}"
            )

        # if either then or else body has yield, we generate results
        # we assume if both have yields, they have the same number of results;
        # the then body takes precedence
        then_yield = then_body.last_stmt
        else_yield = else_body.last_stmt if else_body is not None else None
        if isinstance(then_yield, Yield):
            results = then_yield.values
        elif isinstance(else_yield, Yield):
            results = else_yield.values
        else:
            results = ()

        result_types = tuple(value.type for value in results)
        super().__init__(
            args=(cond,),
            regions=(then_body_region, else_body_region),
            result_types=result_types,
            args_slice={"cond": 0},
            attributes={"purity": ir.PyAttr(False)},
        )

    def print_impl(self, printer: Printer) -> None:
        """Print `if <cond> <then> [else <else>]` followed by the purity flag.

        The else branch is omitted when its only statement is an empty yield.
        """
        printer.print_name(self)
        printer.plain_print(" ")
        printer.print(self.cond)
        printer.plain_print(" ")
        printer.print(self.then_body)
        if self.else_body.blocks and not (
            len(self.else_body.blocks[0].stmts) == 1
            and isinstance(else_term := self.else_body.blocks[0].last_stmt, Yield)
            and not else_term.values  # empty yield
        ):
            printer.plain_print(" else ", style="keyword")
            printer.print(self.else_body)

        with printer.rich(style="comment"):
            printer.plain_print(f" -> purity={self.purity}")

    def verify(self) -> None:
        """Check region shape, block arguments and terminators.

        Raises:
            VerificationError: if either region does not have exactly one
                block, either block does not take exactly one argument, or
                either block does not end with a `Yield` or `func.Return`.
        """
        # imported lazily, presumably to avoid an import cycle with func
        from kirin.dialects.func import Return

        if len(self.then_body.blocks) != 1:
            raise VerificationError(self, "then region must have a single block")

        # NOTE(review): the constructor accepts an empty else region, but this
        # check rejects it; confirm whether lowering always provides an else block.
        if len(self.else_body.blocks) != 1:
            raise VerificationError(self, "else region must have a single block")

        then_block = self.then_body.blocks[0]
        else_block = self.else_body.blocks[0]
        if len(then_block.args) != 1:
            raise VerificationError(
                self, "then block must have a single argument for condition"
            )

        if len(else_block.args) != 1:
            raise VerificationError(
                self, "else block must have a single argument for condition"
            )

        then_stmt = then_block.last_stmt
        else_stmt = else_block.last_stmt
        if then_stmt is None or not isinstance(then_stmt, (Yield, Return)):
            raise VerificationError(
                self, "then block must terminate with a yield or return"
            )

        if else_stmt is None or not isinstance(else_stmt, (Yield, Return)):
            raise VerificationError(
                self, "else block must terminate with a yield or return"
            )
|
129
|
+
|
130
|
+
|
131
|
+
@statement(dialect=dialect, init=False)
class For(ir.Statement):
    """Structured for loop over an iterable with optional loop-carried values.

    The body region has a single block. Its first argument is the loop
    variable, and its remaining arguments receive the loop-carried values
    (`initializers` on entry). The body terminates with a `scf.yield` of the
    next loop-carried values, or with a `func.return`.
    """

    name = "for"
    traits = frozenset({ir.MaybePure()})
    purity: bool = info.attribute(default=False)
    iterable: ir.SSAValue = info.argument(types.Any)
    body: ir.Region = info.region(multi=False)
    initializers: tuple[ir.SSAValue, ...] = info.argument(types.Any)

    def __init__(
        self,
        iterable: ir.SSAValue,
        body: ir.Region,
        *initializers: ir.SSAValue,
    ):
        """Build a for loop.

        Args:
            iterable: the value iterated over.
            body: the loop body region. If its first block ends with a
                `Yield`, the yielded values' types become the result types.
            *initializers: initial values of the loop-carried variables.

        Raises:
            DialectLoweringError: if `body` has no block.
        """
        if not body.blocks:
            # consistent with IfElse.__init__; previously an IndexError
            raise DialectLoweringError("for loop body must have a single block")

        stmt = body.blocks[0].last_stmt
        if isinstance(stmt, Yield):
            result_types = tuple(value.type for value in stmt.values)
        else:
            result_types = ()
        super().__init__(
            args=(iterable, *initializers),
            regions=(body,),
            result_types=result_types,
            args_slice={"iterable": 0, "initializers": slice(1, None)},
            attributes={"purity": ir.PyAttr(False)},
        )

    def verify(self) -> None:
        """Check the body's shape and that yields match initializers and results.

        Raises:
            VerificationError: if the body does not have exactly one block,
                the block's arity is not `len(initializers) + 1`, the block
                does not end in `Yield`/`func.Return`, or a `Yield` does not
                carry one value per initializer and per result.
        """
        # imported lazily, presumably to avoid an import cycle with func
        from kirin.dialects.func import Return

        if len(self.body.blocks) != 1:
            raise VerificationError(self, "for loop body must have a single block")

        if len(self.body.blocks[0].args) != len(self.initializers) + 1:
            raise VerificationError(
                self,
                "for loop body must have arguments for all initializers and the loop variable",
            )

        stmt = self.body.blocks[0].last_stmt
        if stmt is None or not isinstance(stmt, (Yield, Return)):
            raise VerificationError(
                self, "for loop body must terminate with a yield or return"
            )

        # a returning body produces no loop results, so nothing left to check
        if isinstance(stmt, Return):
            return

        if len(stmt.values) != len(self.initializers):
            raise VerificationError(
                self,
                "for loop body must have the same number of results as initializers",
            )
        if len(self.results) != len(stmt.values):
            raise VerificationError(
                self,
                "for loop must have the same number of results as the yield in the body",
            )

    def print_impl(self, printer: Printer) -> None:
        """Print `for <var> in <iterable> [-> types] [iter_args(...)] { ... }`."""
        printer.print_name(self)
        printer.plain_print(" ")
        block = self.body.blocks[0]
        printer.print(block.args[0])
        printer.plain_print(" in ", style="keyword")
        printer.print(self.iterable)
        if self.results:
            with printer.rich(style="comment"):
                printer.plain_print(" -> ")
                printer.print_seq(
                    tuple(result.type for result in self.results),
                    delim=", ",
                    style="comment",
                )

        with printer.indent():
            if self.initializers:
                printer.print_newline()
                printer.plain_print("iter_args(")
                # pair each loop-carried block argument with its initial value
                for idx, (arg, val) in enumerate(
                    zip(block.args[1:], self.initializers)
                ):
                    printer.print(arg)
                    printer.plain_print(" = ")
                    printer.print(val)
                    if idx < len(self.initializers) - 1:
                        printer.plain_print(", ")
                printer.plain_print(")")

            printer.plain_print(" {")
            if printer.analysis is not None:
                with printer.rich(style="warning"):
                    for arg in block.args:
                        printer.print_newline()
                        printer.print_analysis(
                            arg, prefix=f"{printer.state.ssa_id[arg]} --> "
                        )
            with printer.align(printer.result_width(block.stmts)):
                for stmt in block.stmts:
                    printer.print_newline()
                    printer.print_stmt(stmt)
        printer.print_newline()
        printer.plain_print("}")
        with printer.rich(style="comment"):
            printer.plain_print(f" -> purity={self.purity}")
|
237
|
+
|
238
|
+
|
239
|
+
@statement(dialect=dialect)
class Yield(ir.Statement):
    """Terminator carrying the values produced by an enclosing `scf` body.

    `IfElse` and `For` read `values` to determine their result types.
    """

    name = "yield"
    traits = frozenset({ir.IsTerminator()})
    values: tuple[ir.SSAValue, ...] = info.argument(types.Any)

    def __init__(self, *values: ir.SSAValue):
        # every positional argument belongs to the `values` field
        all_args = slice(None)
        super().__init__(args=values, args_slice={"values": all_args})

    def print_impl(self, printer: Printer) -> None:
        # `yield v1, v2, ...`
        printer.print_name(self)
        printer.print_seq(self.values, delim=", ", prefix=" ")
|
@@ -0,0 +1,36 @@
|
|
1
|
+
from kirin import ir
|
2
|
+
from kirin.rewrite.abc import RewriteRule
|
3
|
+
from kirin.rewrite.result import RewriteResult
|
4
|
+
|
5
|
+
from .stmts import For, Yield, IfElse
|
6
|
+
|
7
|
+
|
8
|
+
class UnusedYield(RewriteRule):
    """Trim unused results from `IfElse` statements.

    Results without uses are removed from the statement, together with the
    matching operand of every `Yield` that terminates one of its regions.

    `For` statements are deliberately left untouched. Each `For` result is a
    loop-carried value tied to an initializer, a body block argument and a
    yield operand; `For.verify` requires the yield to carry exactly one value
    per initializer. Dropping only the result and the yield operand would
    leave the loop inconsistent.
    """

    def rewrite_Statement(self, node: ir.Statement) -> RewriteResult:
        if not isinstance(node, IfElse):
            return RewriteResult()

        any_unused = False
        uses: list[int] = []  # indices of results that are still used
        results: list[ir.ResultValue] = []
        for idx, result in enumerate(node.results):
            if result.uses:
                uses.append(idx)
                results.append(result)
            else:
                any_unused = True

        if not any_unused:
            return RewriteResult()

        node._results = results
        for region in node.regions:
            for block in region.blocks:
                # branches ending in func.Return have no yielded values to trim
                if not isinstance(block.last_stmt, Yield):
                    continue

                block.last_stmt.args = [block.last_stmt.args[idx] for idx in uses]

        return RewriteResult(has_done_something=True)
|
@@ -0,0 +1,58 @@
|
|
1
|
+
from kirin import ir, types, interp
|
2
|
+
from kirin.analysis import ForwardFrame, TypeInference
|
3
|
+
from kirin.dialects import func
|
4
|
+
from kirin.dialects.eltype import ElType
|
5
|
+
|
6
|
+
from . import absint
|
7
|
+
from .stmts import For, IfElse
|
8
|
+
from ._dialect import dialect
|
9
|
+
|
10
|
+
|
11
|
+
@dialect.register(key="typeinfer")
class TypeInfer(absint.Methods):
    """Type-inference method table for the `scf` dialect.

    Builds on the shared abstract-interpretation methods in `absint.Methods`.
    """

    @interp.impl(IfElse)
    def if_else_(
        self,
        interp_: TypeInference,
        frame: ForwardFrame[types.TypeAttribute],
        stmt: IfElse,
    ):
        """Narrow the condition's type with `Bool`, then run the generic if-else rule."""
        frame.set(
            stmt.cond, frame.get(stmt.cond).meet(types.Bool)
        )  # set cond backwards
        # NOTE(review): `self` is passed explicitly; this presumably relies on
        # `interp.impl` producing an object that is not bound like a normal
        # method. Confirm against `kirin.interp.impl`.
        return super().if_else(self, interp_, frame, stmt)

    @interp.impl(For)
    def for_loop(
        self,
        interp_: TypeInference,
        frame: ForwardFrame[types.TypeAttribute],
        stmt: For,
    ):
        """Infer types for a `For` loop's block arguments and results.

        The loop variable gets the element type of the iterable, computed by
        running an `ElType` statement. The remaining block arguments get the
        initializers' types. If the body ends in `func.Return`, the body block
        is queued on the frame's worklist and nothing is returned. Otherwise
        the body region is run once in a nested frame, and the yielded types
        are joined with the initializer types.
        """
        iterable = frame.get(stmt.iterable)
        loop_vars = frame.get_values(stmt.initializers)
        body_block = stmt.body.blocks[0]
        block_args = body_block.args

        # evaluate ElType on the iterable's type to get the element type
        eltype = interp_.run_stmt(ElType(ir.TestValue()), (iterable,))
        if not isinstance(eltype, tuple):  # error
            return
        item = eltype[0]
        frame.set_values(block_args, (item,) + loop_vars)

        if isinstance(body_block.last_stmt, func.Return):
            frame.worklist.append(interp.Successor(body_block, item, *loop_vars))
            return  # if terminate is Return, there is no result

        # run the body once in a child frame seeded with the outer entries
        with interp_.state.new_frame(interp_.new_frame(stmt)) as body_frame:
            body_frame.entries.update(frame.entries)
            loop_vars_ = interp_.run_ssacfg_region(body_frame, stmt.body)

        # propagate types inferred inside the body back to the outer frame
        frame.entries.update(body_frame.entries)
        if isinstance(loop_vars_, interp.ReturnValue):
            return loop_vars_
        elif isinstance(loop_vars_, tuple):
            # results may be either the initial values or the yielded values
            return interp_.join_results(loop_vars, loop_vars_)
        else:  # None, loop has no result
            return
|
@@ -0,0 +1,92 @@
|
|
1
|
+
from kirin import ir
|
2
|
+
from kirin.analysis import const
|
3
|
+
from kirin.dialects import func
|
4
|
+
from kirin.rewrite.abc import RewriteRule
|
5
|
+
from kirin.rewrite.result import RewriteResult
|
6
|
+
from kirin.dialects.py.constant import Constant
|
7
|
+
|
8
|
+
from .stmts import For, Yield, IfElse
|
9
|
+
|
10
|
+
|
11
|
+
class PickIfElse(RewriteRule):
    """Replace an `IfElse` whose condition is a known constant by the taken branch.

    Requires a `const.Value` hint on the condition. The statements of the
    selected branch are moved in front of the `IfElse`. A `Yield` terminator
    forwards its values to the statement's results. A `func.Return`
    terminator replaces the `IfElse`, and the rest of the enclosing block is
    deleted.
    """

    def rewrite_Statement(self, node: ir.Statement) -> RewriteResult:
        if not isinstance(node, IfElse):
            return RewriteResult()

        if not isinstance(hint := node.cond.hints.get("const"), const.Value):
            return RewriteResult()

        if hint.data:
            return self.insert_body(node, node.then_body)
        else:
            return self.insert_body(node, node.else_body)

    def insert_body(self, node: IfElse, body: ir.Region):
        """Inline the single block of `body` in place of `node`.

        Leaves the IR untouched and returns an empty `RewriteResult` when the
        region has no block, or when its terminator is neither a `Yield` nor
        a `func.Return`.
        """
        # an IfElse constructed without an else body has an empty else region
        if not body.blocks:
            return RewriteResult()

        body_block = body.blocks[0]
        terminator = body_block.last_stmt
        # Decide before mutating anything. Otherwise an unsupported terminator
        # would leave a half-inlined body behind while reporting no change.
        if not isinstance(terminator, (Yield, func.Return)):
            return RewriteResult()

        # the block argument stands for the condition value
        body_block.args[0].replace_by(node.cond)
        block_stmt = body_block.first_stmt
        while block_stmt and not block_stmt.has_trait(ir.IsTerminator):
            block_stmt.detach()
            block_stmt.insert_before(node)
            block_stmt = body_block.first_stmt

        if isinstance(terminator, Yield):
            for result, output in zip(node.results, terminator.values):
                result.replace_by(output)
            node.delete()
            return RewriteResult(has_done_something=True)

        # func.Return: everything after `node` in the enclosing block is dead
        block = node.parent
        assert block is not None
        stmt = block.last_stmt
        while stmt is not None and stmt is not node:  # remove the rest of the block
            delete_stmt = stmt
            stmt = stmt.prev_stmt
            delete_stmt.delete()

        terminator.detach()
        node.replace_by(terminator)
        return RewriteResult(has_done_something=True)
|
54
|
+
|
55
|
+
|
56
|
+
class ForLoop(RewriteRule):
    """Fully unroll a `For` loop whose iterable has a known constant value.

    Requires a `const.Value` hint on the iterable. For every item, the body
    is cloned and its statements are spliced in front of the loop. The item
    is materialized as a `Constant`, and the loop-carried values are threaded
    from one copy's `Yield` into the next copy. The loop's results are then
    replaced by the final loop-carried values.
    """

    def rewrite_Statement(self, node: ir.Statement) -> RewriteResult:
        if not isinstance(node, For):
            return RewriteResult()

        # TODO: support for PartialTuple and IList with known length
        if not isinstance(hint := node.iterable.hints.get("const"), const.Value):
            return RewriteResult()

        # Only bodies ending in a Yield can be unrolled by splicing. A body
        # ending in func.Return would otherwise have its return silently
        # dropped. Check this before touching the IR.
        if not node.body.blocks or not isinstance(
            node.body.blocks[0].last_stmt, Yield
        ):
            return RewriteResult()

        try:
            items = tuple(hint.data)
        except TypeError:  # constant is not iterable; nothing to unroll
            return RewriteResult()

        loop_vars = node.initializers
        for item in items:
            body = node.body.clone()
            block = body.blocks[0]
            item_stmt = Constant(item)
            item_stmt.insert_before(node)
            # bind the loop variable and the loop-carried arguments of this copy
            block.args[0].replace_by(item_stmt.result)
            for var, input in zip(block.args[1:], loop_vars):
                var.replace_by(input)

            block_stmt = block.first_stmt
            while block_stmt and not block_stmt.has_trait(ir.IsTerminator):
                block_stmt.detach()
                block_stmt.insert_before(node)
                block_stmt = block.first_stmt

            terminator = block.last_stmt
            # For.verify requires the Yield to carry one value per initializer
            if isinstance(terminator, Yield):
                loop_vars = terminator.values
                terminator.delete()

        for result, output in zip(node.results, loop_vars):
            result.replace_by(output)
        node.delete()
        return RewriteResult(has_done_something=True)
|
kirin/emit/__init__.py
ADDED
kirin/emit/abc.py
ADDED
@@ -0,0 +1,89 @@
|
|
1
|
+
from abc import ABC
|
2
|
+
from typing import TypeVar
|
3
|
+
from dataclasses import field, dataclass
|
4
|
+
|
5
|
+
from kirin import ir, interp
|
6
|
+
from kirin.worklist import WorkList
|
7
|
+
|
8
|
+
# Type of the value an emitter produces per SSA value / block
# (e.g. `str` for the string emitters).
ValueType = TypeVar("ValueType")
|
9
|
+
|
10
|
+
|
11
|
+
@dataclass
class EmitFrame(interp.Frame[ValueType]):
    """Interpreter frame used while emitting code.

    Adds a worklist of successor blocks still to be emitted, and a map from
    each block to the value recorded for it.
    """

    # successor blocks still to be emitted; drained by run_ssacfg_region
    worklist: WorkList[interp.Successor] = field(default_factory=WorkList)
    # per-block value, e.g. the label name recorded by EmitJulia
    block_ref: dict[ir.Block, ValueType] = field(default_factory=dict)
|
15
|
+
|
16
|
+
|
17
|
+
# Concrete frame type used by an emitter; must be an EmitFrame subclass.
FrameType = TypeVar("FrameType", bound=EmitFrame)
|
18
|
+
|
19
|
+
|
20
|
+
@dataclass
|
21
|
+
class EmitABC(interp.BaseInterpreter[FrameType, ValueType], ABC):
|
22
|
+
|
23
|
+
def run_callable_region(
|
24
|
+
self, frame: FrameType, code: ir.Statement, region: ir.Region
|
25
|
+
) -> ValueType:
|
26
|
+
results = self.eval_stmt(frame, code)
|
27
|
+
if isinstance(results, tuple):
|
28
|
+
if len(results) == 0:
|
29
|
+
return self.void
|
30
|
+
elif len(results) == 1:
|
31
|
+
return results[0]
|
32
|
+
raise interp.InterpreterError(f"Unexpected results {results}")
|
33
|
+
|
34
|
+
def run_ssacfg_region(
|
35
|
+
self, frame: FrameType, region: ir.Region
|
36
|
+
) -> tuple[ValueType, ...]:
|
37
|
+
frame.worklist.append(
|
38
|
+
interp.Successor(region.blocks[0], frame.get_values(region.blocks[0].args))
|
39
|
+
)
|
40
|
+
while (succ := frame.worklist.pop()) is not None:
|
41
|
+
block_header = self.emit_block(frame, succ.block)
|
42
|
+
frame.block_ref[succ.block] = block_header
|
43
|
+
return ()
|
44
|
+
|
45
|
+
def emit_attribute(self, attr: ir.Attribute) -> ValueType:
|
46
|
+
return getattr(
|
47
|
+
self, f"emit_type_{type(attr).__name__}", self.emit_attribute_fallback
|
48
|
+
)(attr)
|
49
|
+
|
50
|
+
def emit_attribute_fallback(self, attr: ir.Attribute) -> ValueType:
|
51
|
+
if (method := self.registry.attributes.get(type(attr))) is not None:
|
52
|
+
return method(self, attr)
|
53
|
+
raise NotImplementedError(f"Attribute {type(attr)} not implemented")
|
54
|
+
|
55
|
+
def emit_stmt_begin(self, frame: FrameType, stmt: ir.Statement) -> None:
|
56
|
+
return
|
57
|
+
|
58
|
+
def emit_stmt_end(self, frame: FrameType, stmt: ir.Statement) -> None:
|
59
|
+
return
|
60
|
+
|
61
|
+
def emit_block_begin(self, frame: FrameType, block: ir.Block) -> None:
|
62
|
+
return
|
63
|
+
|
64
|
+
def emit_block_end(self, frame: FrameType, block: ir.Block) -> None:
|
65
|
+
return
|
66
|
+
|
67
|
+
def emit_block(self, frame: FrameType, block: ir.Block) -> ValueType:
|
68
|
+
self.emit_block_begin(frame, block)
|
69
|
+
stmt = block.first_stmt
|
70
|
+
while stmt is not None:
|
71
|
+
if self.consume_fuel() == self.FuelResult.Stop:
|
72
|
+
raise interp.FuelExhaustedError("fuel exhausted")
|
73
|
+
|
74
|
+
self.emit_stmt_begin(frame, stmt)
|
75
|
+
stmt_results = self.eval_stmt(frame, stmt)
|
76
|
+
self.emit_stmt_end(frame, stmt)
|
77
|
+
|
78
|
+
match stmt_results:
|
79
|
+
case tuple(values):
|
80
|
+
frame.set_values(stmt._results, values)
|
81
|
+
case interp.ReturnValue(_) | interp.YieldValue(_):
|
82
|
+
pass
|
83
|
+
case _:
|
84
|
+
raise ValueError(f"Unexpected result {stmt_results}")
|
85
|
+
|
86
|
+
stmt = stmt.next_stmt
|
87
|
+
|
88
|
+
self.emit_block_end(frame, block)
|
89
|
+
return frame.block_ref[block]
|
kirin/emit/abc.pyi
ADDED
@@ -0,0 +1,38 @@
|
|
1
|
+
from typing import TypeVar
|
2
|
+
from dataclasses import field, dataclass
|
3
|
+
|
4
|
+
from kirin import ir, types, interp
|
5
|
+
from kirin.worklist import WorkList
|
6
|
+
|
7
|
+
# Type of the value an emitter produces per SSA value / block.
ValueType = TypeVar("ValueType")
|
8
|
+
|
9
|
+
@dataclass
class EmitFrame(interp.Frame[ValueType]):
    """Interpreter frame used while emitting code (type stub)."""

    # successor blocks still to be emitted
    worklist: WorkList[interp.Successor] = field(default_factory=WorkList)
    # value recorded for each emitted block
    block_ref: dict[ir.Block, ValueType] = field(default_factory=dict)
|
13
|
+
|
14
|
+
# Concrete frame type used by an emitter; must be an EmitFrame subclass.
FrameType = TypeVar("FrameType", bound=EmitFrame)
|
15
|
+
|
16
|
+
class EmitABC(interp.BaseInterpreter[FrameType, ValueType]):
    """Type stub for the emitter base class.

    Declares the `emit_type_*` hooks that `emit_attribute` may dispatch to.
    """

    def run_callable_region(
        self, frame: FrameType, code: ir.Statement, region: ir.Region
    ) -> ValueType: ...
    def run_ssacfg_region(
        self, frame: FrameType, region: ir.Region
    ) -> tuple[ValueType, ...]: ...
    def emit_attribute(self, attr: ir.Attribute) -> ValueType: ...
    # NOTE(review): emit_attribute looks up `emit_type_<class name>`; if the
    # classes are named AnyType / BottomType, these two hooks would never be
    # dispatched to. Confirm the runtime class names.
    def emit_type_Any(self, attr: types.AnyType) -> ValueType: ...
    def emit_type_Bottom(self, attr: types.BottomType) -> ValueType: ...
    def emit_type_Literal(self, attr: types.Literal) -> ValueType: ...
    def emit_type_Union(self, attr: types.Union) -> ValueType: ...
    def emit_type_TypeVar(self, attr: types.TypeVar) -> ValueType: ...
    def emit_type_Vararg(self, attr: types.Vararg) -> ValueType: ...
    def emit_type_Generic(self, attr: types.Generic) -> ValueType: ...
    def emit_type_PyClass(self, attr: types.PyClass) -> ValueType: ...
    def emit_type_PyAttr(self, attr: ir.PyAttr) -> ValueType: ...
    def emit_attribute_fallback(self, attr: ir.Attribute) -> ValueType: ...
    def emit_stmt_begin(self, frame: FrameType, stmt: ir.Statement) -> None: ...
    def emit_stmt_end(self, frame: FrameType, stmt: ir.Statement) -> None: ...
    def emit_block_begin(self, frame: FrameType, block: ir.Block) -> None: ...
    def emit_block_end(self, frame: FrameType, block: ir.Block) -> None: ...
    def emit_block(self, frame: FrameType, block: ir.Block) -> ValueType: ...
|
kirin/emit/exceptions.py
ADDED
kirin/emit/julia.py
ADDED
@@ -0,0 +1,63 @@
|
|
1
|
+
from typing import IO, TypeVar
|
2
|
+
|
3
|
+
from kirin import ir
|
4
|
+
from kirin.ir.attrs.types import PyClass
|
5
|
+
from kirin.ir.nodes.block import Block
|
6
|
+
|
7
|
+
from .str import EmitStr, EmitStrFrame
|
8
|
+
|
9
|
+
# Type of the file-like object the emitter writes to.
IO_t = TypeVar("IO_t", bound=IO)
|
10
|
+
|
11
|
+
|
12
|
+
class EmitJulia(EmitStr[IO_t]):
    """String emitter that writes Julia source code."""

    keys = ["emit.julia"]

    # Python classes mapped to the Julia type names used in emitted code
    PYTYPE_MAP = {
        int: "Int",
        float: "Real",
        str: "String",
        bool: "Bool",
        type(None): "Nothing",
        dict: "Dict",
        list: "Vector",
        tuple: "Tuple",
    }

    # Escapes for Julia double-quoted string literals; `$` must be escaped
    # because Julia interpolates `$name` inside strings.
    _JULIA_STRING_ESCAPES = str.maketrans(
        {
            "\\": "\\\\",
            '"': '\\"',
            "$": "\\$",
            "\n": "\\n",
            "\r": "\\r",
            "\t": "\\t",
        }
    )

    def emit_block_begin(self, frame: EmitStrFrame, block: Block) -> None:
        """Record the block's label in `frame.block_ref` and write `@label <id>;`."""
        block_id = self.block_id[block]
        frame.block_ref[block] = block_id
        self.newline(frame)
        self.write(f"@label {block_id};")

    def emit_type_PyClass(self, attr: PyClass) -> str:
        """Map a Python class to its Julia type name, defaulting to `Any`."""
        return self.PYTYPE_MAP.get(attr.typ, "Any")

    def write_assign(self, frame: EmitStrFrame, result: ir.SSAValue, *args):
        """Write `<result> = <args...>` on a new line and return the result's name."""
        result_sym = self.ssa_id[result]
        frame.set(result, result_sym)
        self.writeln(frame, result_sym, " = ", *args)
        return result_sym

    def emit_binaryop(
        self,
        frame: EmitStrFrame,
        sym: str,
        lhs: ir.SSAValue,
        rhs: ir.SSAValue,
        result: ir.ResultValue,
    ):
        """Emit `result = lhs <sym> rhs`; returns a 1-tuple with the result's name."""
        return (
            self.write_assign(
                frame,
                result,
                f"{frame.get(lhs)} {sym} {frame.get(rhs)}",
            ),
        )

    def emit_type_PyAttr(self, attr: ir.PyAttr) -> str:
        """Render a Python constant as a Julia literal.

        Raises:
            ValueError: for data other than bool, int, float or str.
        """
        import math

        data = attr.data
        if isinstance(data, bool):
            # must precede the int check: bool subclasses int, and
            # repr(True) == "True" is not a Julia literal
            return "true" if data else "false"
        elif isinstance(data, int):
            return repr(data)
        elif isinstance(data, float):
            # Python's repr gives "nan"/"inf", which Julia does not parse
            if math.isnan(data):
                return "NaN"
            if math.isinf(data):
                return "Inf" if data > 0 else "-Inf"
            return repr(data)
        elif isinstance(data, str):
            return '"' + data.translate(self._JULIA_STRING_ESCAPES) + '"'
        else:
            raise ValueError(f"unsupported type {type(data)}")
|
kirin/emit/str.py
ADDED
@@ -0,0 +1,51 @@
|
|
1
|
+
from abc import ABC
|
2
|
+
from typing import IO, Generic, TypeVar
|
3
|
+
from dataclasses import field, dataclass
|
4
|
+
|
5
|
+
from kirin import ir, interp, idtable
|
6
|
+
from kirin.emit.abc import EmitABC, EmitFrame
|
7
|
+
|
8
|
+
# Type of the file-like object the emitter writes to.
IO_t = TypeVar("IO_t", bound=IO)
|
9
|
+
|
10
|
+
|
11
|
+
@dataclass
class EmitStrFrame(EmitFrame[str]):
    """Emit frame whose per-value and per-block entries are strings."""

    # number of spaces written after each newline (see EmitStr.newline)
    indent: int = 0
    # tuple of strings recorded per SSA value; presumably captured variables
    # of nested callables. TODO confirm against users of this field.
    captured: dict[ir.SSAValue, tuple[str, ...]] = field(default_factory=dict)
|
15
|
+
|
16
|
+
|
17
|
+
@dataclass
class EmitStr(EmitABC[EmitStrFrame, str], ABC, Generic[IO_t]):
    """Emitter base class that writes generated text to a file-like object.

    `initialize` creates the id tables that name SSA values and blocks,
    configured from `prefix` and `prefix_if_none`.
    """

    void = ""
    file: IO_t
    prefix: str = field(default="", kw_only=True)
    prefix_if_none: str = field(default="var_", kw_only=True)

    def initialize(self):
        super().initialize()
        # fresh naming tables for values and blocks on every initialization
        self.ssa_id = idtable.IdTable[ir.SSAValue](
            prefix=self.prefix, prefix_if_none=self.prefix_if_none
        )
        label_prefix = self.prefix + "block_"
        self.block_id = idtable.IdTable[ir.Block](prefix=label_prefix)
        return self

    def new_frame(self, code: ir.Statement) -> EmitStrFrame:
        frame = EmitStrFrame.from_func_like(code)
        return frame

    def run_method(
        self, method: ir.Method, args: tuple[str, ...]
    ) -> tuple[EmitStrFrame, str]:
        current_depth = len(self.state.frames)
        if current_depth >= self.max_depth:
            raise interp.InterpreterError("maximum recursion depth exceeded")
        # the method's symbol name is passed as the leading argument
        call_args = (method.sym_name,) + args
        return self.run_callable(method.code, call_args)

    def write(self, *args):
        # one write call per fragment
        for fragment in args:
            self.file.write(fragment)

    def newline(self, frame: EmitStrFrame):
        padding = " " * frame.indent
        self.file.write("\n" + padding)

    def writeln(self, frame: EmitStrFrame, *args):
        self.newline(frame)
        self.write(*args)
|