kirin-toolchain 0.13.0__py3-none-any.whl

Sign up to get free protection for your applications and to get access to all the features.
Files changed (225) hide show
  1. kirin/__init__.py +7 -0
  2. kirin/analysis/__init__.py +24 -0
  3. kirin/analysis/callgraph.py +61 -0
  4. kirin/analysis/cfg.py +112 -0
  5. kirin/analysis/const/__init__.py +20 -0
  6. kirin/analysis/const/_visitor.py +2 -0
  7. kirin/analysis/const/_visitor.pyi +8 -0
  8. kirin/analysis/const/lattice.py +219 -0
  9. kirin/analysis/const/prop.py +116 -0
  10. kirin/analysis/forward.py +100 -0
  11. kirin/analysis/typeinfer/__init__.py +5 -0
  12. kirin/analysis/typeinfer/analysis.py +90 -0
  13. kirin/analysis/typeinfer/solve.py +141 -0
  14. kirin/decl/__init__.py +108 -0
  15. kirin/decl/base.py +65 -0
  16. kirin/decl/camel2snake.py +2 -0
  17. kirin/decl/emit/__init__.py +0 -0
  18. kirin/decl/emit/_create_fn.py +29 -0
  19. kirin/decl/emit/_set_new_attribute.py +22 -0
  20. kirin/decl/emit/dialect.py +8 -0
  21. kirin/decl/emit/init.py +277 -0
  22. kirin/decl/emit/name.py +10 -0
  23. kirin/decl/emit/property.py +182 -0
  24. kirin/decl/emit/repr.py +31 -0
  25. kirin/decl/emit/traits.py +13 -0
  26. kirin/decl/emit/typecheck.py +77 -0
  27. kirin/decl/emit/verify.py +51 -0
  28. kirin/decl/info.py +346 -0
  29. kirin/decl/scan_fields.py +157 -0
  30. kirin/decl/verify.py +69 -0
  31. kirin/dialects/__init__.py +14 -0
  32. kirin/dialects/_pprint_helper.py +53 -0
  33. kirin/dialects/cf/__init__.py +20 -0
  34. kirin/dialects/cf/constprop.py +51 -0
  35. kirin/dialects/cf/dialect.py +3 -0
  36. kirin/dialects/cf/emit.py +58 -0
  37. kirin/dialects/cf/interp.py +24 -0
  38. kirin/dialects/cf/stmts.py +68 -0
  39. kirin/dialects/cf/typeinfer.py +27 -0
  40. kirin/dialects/eltype.py +23 -0
  41. kirin/dialects/func/__init__.py +20 -0
  42. kirin/dialects/func/attrs.py +39 -0
  43. kirin/dialects/func/constprop.py +138 -0
  44. kirin/dialects/func/dialect.py +3 -0
  45. kirin/dialects/func/emit.py +80 -0
  46. kirin/dialects/func/interp.py +68 -0
  47. kirin/dialects/func/stmts.py +233 -0
  48. kirin/dialects/func/typeinfer.py +124 -0
  49. kirin/dialects/ilist/__init__.py +33 -0
  50. kirin/dialects/ilist/_dialect.py +3 -0
  51. kirin/dialects/ilist/_wrapper.py +51 -0
  52. kirin/dialects/ilist/interp.py +85 -0
  53. kirin/dialects/ilist/lowering.py +25 -0
  54. kirin/dialects/ilist/passes.py +32 -0
  55. kirin/dialects/ilist/rewrite/__init__.py +3 -0
  56. kirin/dialects/ilist/rewrite/const.py +45 -0
  57. kirin/dialects/ilist/rewrite/list.py +38 -0
  58. kirin/dialects/ilist/rewrite/unroll.py +131 -0
  59. kirin/dialects/ilist/runtime.py +63 -0
  60. kirin/dialects/ilist/stmts.py +102 -0
  61. kirin/dialects/ilist/typeinfer.py +120 -0
  62. kirin/dialects/lowering/__init__.py +7 -0
  63. kirin/dialects/lowering/call.py +48 -0
  64. kirin/dialects/lowering/cf.py +206 -0
  65. kirin/dialects/lowering/func.py +134 -0
  66. kirin/dialects/math/__init__.py +41 -0
  67. kirin/dialects/math/_gen.py +176 -0
  68. kirin/dialects/math/dialect.py +3 -0
  69. kirin/dialects/math/interp.py +190 -0
  70. kirin/dialects/math/stmts.py +369 -0
  71. kirin/dialects/module.py +139 -0
  72. kirin/dialects/py/__init__.py +40 -0
  73. kirin/dialects/py/assertion.py +91 -0
  74. kirin/dialects/py/assign.py +103 -0
  75. kirin/dialects/py/attr.py +59 -0
  76. kirin/dialects/py/base.py +34 -0
  77. kirin/dialects/py/binop/__init__.py +23 -0
  78. kirin/dialects/py/binop/_dialect.py +3 -0
  79. kirin/dialects/py/binop/interp.py +60 -0
  80. kirin/dialects/py/binop/julia.py +33 -0
  81. kirin/dialects/py/binop/lowering.py +22 -0
  82. kirin/dialects/py/binop/stmts.py +79 -0
  83. kirin/dialects/py/binop/typeinfer.py +108 -0
  84. kirin/dialects/py/boolop.py +84 -0
  85. kirin/dialects/py/builtin.py +78 -0
  86. kirin/dialects/py/cmp/__init__.py +16 -0
  87. kirin/dialects/py/cmp/_dialect.py +3 -0
  88. kirin/dialects/py/cmp/interp.py +48 -0
  89. kirin/dialects/py/cmp/julia.py +33 -0
  90. kirin/dialects/py/cmp/lowering.py +45 -0
  91. kirin/dialects/py/cmp/stmts.py +62 -0
  92. kirin/dialects/py/constant.py +79 -0
  93. kirin/dialects/py/indexing.py +251 -0
  94. kirin/dialects/py/iterable.py +90 -0
  95. kirin/dialects/py/len.py +57 -0
  96. kirin/dialects/py/list/__init__.py +15 -0
  97. kirin/dialects/py/list/_dialect.py +3 -0
  98. kirin/dialects/py/list/interp.py +21 -0
  99. kirin/dialects/py/list/lowering.py +25 -0
  100. kirin/dialects/py/list/stmts.py +22 -0
  101. kirin/dialects/py/list/typeinfer.py +54 -0
  102. kirin/dialects/py/range.py +76 -0
  103. kirin/dialects/py/slice.py +120 -0
  104. kirin/dialects/py/tuple.py +109 -0
  105. kirin/dialects/py/unary/__init__.py +24 -0
  106. kirin/dialects/py/unary/_dialect.py +3 -0
  107. kirin/dialects/py/unary/constprop.py +20 -0
  108. kirin/dialects/py/unary/interp.py +24 -0
  109. kirin/dialects/py/unary/julia.py +21 -0
  110. kirin/dialects/py/unary/lowering.py +22 -0
  111. kirin/dialects/py/unary/stmts.py +33 -0
  112. kirin/dialects/py/unary/typeinfer.py +23 -0
  113. kirin/dialects/py/unpack.py +90 -0
  114. kirin/dialects/scf/__init__.py +23 -0
  115. kirin/dialects/scf/_dialect.py +3 -0
  116. kirin/dialects/scf/absint.py +64 -0
  117. kirin/dialects/scf/constprop.py +140 -0
  118. kirin/dialects/scf/interp.py +35 -0
  119. kirin/dialects/scf/lowering.py +123 -0
  120. kirin/dialects/scf/stmts.py +250 -0
  121. kirin/dialects/scf/trim.py +36 -0
  122. kirin/dialects/scf/typeinfer.py +58 -0
  123. kirin/dialects/scf/unroll.py +92 -0
  124. kirin/emit/__init__.py +3 -0
  125. kirin/emit/abc.py +89 -0
  126. kirin/emit/abc.pyi +38 -0
  127. kirin/emit/exceptions.py +5 -0
  128. kirin/emit/julia.py +63 -0
  129. kirin/emit/str.py +51 -0
  130. kirin/exceptions.py +59 -0
  131. kirin/graph.py +34 -0
  132. kirin/idtable.py +57 -0
  133. kirin/interp/__init__.py +39 -0
  134. kirin/interp/abstract.py +253 -0
  135. kirin/interp/base.py +438 -0
  136. kirin/interp/concrete.py +62 -0
  137. kirin/interp/exceptions.py +26 -0
  138. kirin/interp/frame.py +151 -0
  139. kirin/interp/impl.py +197 -0
  140. kirin/interp/result.py +93 -0
  141. kirin/interp/state.py +71 -0
  142. kirin/interp/table.py +40 -0
  143. kirin/interp/value.py +73 -0
  144. kirin/ir/__init__.py +46 -0
  145. kirin/ir/attrs/__init__.py +20 -0
  146. kirin/ir/attrs/_types.py +8 -0
  147. kirin/ir/attrs/_types.pyi +13 -0
  148. kirin/ir/attrs/abc.py +46 -0
  149. kirin/ir/attrs/py.py +45 -0
  150. kirin/ir/attrs/types.py +522 -0
  151. kirin/ir/dialect.py +125 -0
  152. kirin/ir/group.py +249 -0
  153. kirin/ir/method.py +118 -0
  154. kirin/ir/nodes/__init__.py +7 -0
  155. kirin/ir/nodes/base.py +149 -0
  156. kirin/ir/nodes/block.py +458 -0
  157. kirin/ir/nodes/region.py +337 -0
  158. kirin/ir/nodes/stmt.py +713 -0
  159. kirin/ir/nodes/view.py +142 -0
  160. kirin/ir/ssa.py +204 -0
  161. kirin/ir/traits/__init__.py +36 -0
  162. kirin/ir/traits/abc.py +42 -0
  163. kirin/ir/traits/basic.py +78 -0
  164. kirin/ir/traits/callable.py +51 -0
  165. kirin/ir/traits/lowering/__init__.py +2 -0
  166. kirin/ir/traits/lowering/call.py +37 -0
  167. kirin/ir/traits/lowering/context.py +120 -0
  168. kirin/ir/traits/region/__init__.py +2 -0
  169. kirin/ir/traits/region/ssacfg.py +22 -0
  170. kirin/ir/traits/symbol.py +57 -0
  171. kirin/ir/use.py +17 -0
  172. kirin/lattice/__init__.py +13 -0
  173. kirin/lattice/abc.py +128 -0
  174. kirin/lattice/empty.py +25 -0
  175. kirin/lattice/mixin.py +51 -0
  176. kirin/lowering/__init__.py +7 -0
  177. kirin/lowering/binding.py +65 -0
  178. kirin/lowering/core.py +72 -0
  179. kirin/lowering/dialect.py +35 -0
  180. kirin/lowering/dialect.pyi +183 -0
  181. kirin/lowering/frame.py +171 -0
  182. kirin/lowering/result.py +68 -0
  183. kirin/lowering/state.py +441 -0
  184. kirin/lowering/stream.py +53 -0
  185. kirin/passes/__init__.py +3 -0
  186. kirin/passes/abc.py +44 -0
  187. kirin/passes/aggressive/__init__.py +1 -0
  188. kirin/passes/aggressive/fold.py +43 -0
  189. kirin/passes/fold.py +45 -0
  190. kirin/passes/inline.py +25 -0
  191. kirin/passes/typeinfer.py +25 -0
  192. kirin/prelude.py +197 -0
  193. kirin/print/__init__.py +15 -0
  194. kirin/print/printable.py +141 -0
  195. kirin/print/printer.py +415 -0
  196. kirin/py.typed +0 -0
  197. kirin/registry.py +105 -0
  198. kirin/registry.pyi +52 -0
  199. kirin/rewrite/__init__.py +14 -0
  200. kirin/rewrite/abc.py +43 -0
  201. kirin/rewrite/aggressive/__init__.py +1 -0
  202. kirin/rewrite/aggressive/fold.py +43 -0
  203. kirin/rewrite/alias.py +16 -0
  204. kirin/rewrite/apply_type.py +47 -0
  205. kirin/rewrite/call2invoke.py +34 -0
  206. kirin/rewrite/chain.py +39 -0
  207. kirin/rewrite/compactify.py +288 -0
  208. kirin/rewrite/cse.py +48 -0
  209. kirin/rewrite/dce.py +19 -0
  210. kirin/rewrite/fixpoint.py +34 -0
  211. kirin/rewrite/fold.py +57 -0
  212. kirin/rewrite/getfield.py +21 -0
  213. kirin/rewrite/getitem.py +37 -0
  214. kirin/rewrite/inline.py +143 -0
  215. kirin/rewrite/result.py +15 -0
  216. kirin/rewrite/walk.py +83 -0
  217. kirin/rewrite/wrap_const.py +55 -0
  218. kirin/source.py +21 -0
  219. kirin/symbol_table.py +27 -0
  220. kirin/types.py +34 -0
  221. kirin/worklist.py +30 -0
  222. kirin_toolchain-0.13.0.dist-info/METADATA +42 -0
  223. kirin_toolchain-0.13.0.dist-info/RECORD +225 -0
  224. kirin_toolchain-0.13.0.dist-info/WHEEL +4 -0
  225. kirin_toolchain-0.13.0.dist-info/licenses/LICENSE +234 -0
@@ -0,0 +1,250 @@
1
+ from kirin import ir, types
2
+ from kirin.decl import info, statement
3
+ from kirin.exceptions import VerificationError, DialectLoweringError
4
+ from kirin.print.printer import Printer
5
+
6
+ from ._dialect import dialect
7
+
8
+
9
@statement(dialect=dialect, init=False)
class IfElse(ir.Statement):
    """Structured, Python-like ``if``/``else`` statement.

    The statement carries a condition value, a ``then`` region and an
    ``else`` region. ``verify`` requires each region to hold exactly one
    block that ends in either a ``Yield`` or a ``kirin.dialects.func.Return``.

    The result types of the statement are derived in ``__init__`` from the
    values of whichever branch ends in a ``Yield`` (the ``then`` branch is
    consulted first).
    """

    name = "if"
    traits = frozenset({ir.MaybePure()})
    purity: bool = info.attribute(default=False)
    cond: ir.SSAValue = info.argument(types.Any)
    # NOTE: the condition type is deliberately not constrained, because any
    # Python object implementing __bool__ can be used as a condition.
    then_body: ir.Region = info.region(multi=False)
    else_body: ir.Region = info.region(multi=False, default_factory=ir.Region)

    def __init__(
        self,
        cond: ir.SSAValue,
        then_body: ir.Region | ir.Block,
        else_body: ir.Region | ir.Block | None = None,
    ):
        """Build the statement from a condition and one or two branches.

        Args:
            cond: the value tested by the statement.
            then_body: a single-block region, or a bare block that is wrapped
                in a fresh region.
            else_body: a region (empty or single-block), a bare block, or
                ``None``, in which case an empty region is used.

        Raises:
            DialectLoweringError: if a non-empty region passed for either
                branch does not contain exactly one block.
        """
        # Normalise the `then` branch into a region plus its only block.
        then_block = then_body
        if isinstance(then_body, ir.Region):
            if len(then_body.blocks) != 1:
                raise DialectLoweringError(
                    "if-else statement must have a single block in the then region"
                )
            then_region = then_body
            then_block = then_region.blocks[0]
        elif isinstance(then_body, ir.Block):
            then_region = ir.Region(then_body)

        # Normalise the `else` branch; `else_block` is None when the else
        # region has no block to inspect.
        else_block = else_body
        if isinstance(else_body, ir.Region):
            else_region = else_body
            if not else_body.blocks:  # empty region
                else_block = None
            elif len(else_body.blocks) != 1:
                raise DialectLoweringError(
                    "if-else statement must have a single block in the else region"
                )
            else:
                else_block = else_region.blocks[0]
        elif isinstance(else_body, ir.Block):
            else_region = ir.Region(else_body)
        else:
            else_region = ir.Region()

        # If either branch yields, the statement produces results. When both
        # yield they are assumed to agree on the number of results, so the
        # `then` branch is used.
        then_term = then_block.last_stmt
        else_term = None if else_block is None else else_block.last_stmt
        if isinstance(then_term, Yield):
            yielded = then_term.values
        elif isinstance(else_term, Yield):
            yielded = else_term.values
        else:
            yielded = ()

        super().__init__(
            args=(cond,),
            regions=(then_region, else_region),
            result_types=tuple(value.type for value in yielded),
            args_slice={"cond": 0},
            attributes={"purity": ir.PyAttr(False)},
        )

    def print_impl(self, printer: Printer) -> None:
        """Print ``if <cond> <then> [else <else>] -> purity=...``.

        The else region is omitted when it is empty or when its block
        consists of nothing but a value-less ``Yield``.
        """
        printer.print_name(self)
        printer.plain_print(" ")
        printer.print(self.cond)
        printer.plain_print(" ")
        printer.print(self.then_body)

        if self.else_body.blocks:
            else_block = self.else_body.blocks[0]
            only_empty_yield = (
                len(else_block.stmts) == 1
                and isinstance(else_block.last_stmt, Yield)
                and not else_block.last_stmt.values
            )
            if not only_empty_yield:
                printer.plain_print(" else ", style="keyword")
                printer.print(self.else_body)

        with printer.rich(style="comment"):
            printer.plain_print(f" -> purity={self.purity}")

    def verify(self) -> None:
        """Check the structural invariants of both branches.

        Raises:
            VerificationError: if a region does not hold exactly one block,
                if a block does not take exactly one argument, or if a block
                does not end in a ``Yield`` or ``func.Return``.
        """
        from kirin.dialects.func import Return

        branches = (("then", self.then_body), ("else", self.else_body))

        # The checks run kind-by-kind (all block counts first, then all
        # argument counts, then all terminators) so that the first error
        # reported matches the order then/else within each kind.
        for label, region in branches:
            if len(region.blocks) != 1:
                raise VerificationError(
                    self, f"{label} region must have a single block"
                )

        for label, region in branches:
            if len(region.blocks[0].args) != 1:
                raise VerificationError(
                    self, f"{label} block must have a single argument for condition"
                )

        terminators = [
            (label, region.blocks[0].last_stmt) for label, region in branches
        ]
        for label, terminator in terminators:
            if not isinstance(terminator, (Yield, Return)):
                raise VerificationError(
                    self, f"{label} block must terminate with a yield or return"
                )
129
+
130
+
131
@statement(dialect=dialect, init=False)
class For(ir.Statement):
    """Structured ``for`` loop over an iterable value.

    The body region's first block receives the loop item as its first
    argument followed by one argument per initializer (see ``verify``).
    Result types are taken from the body's terminating ``Yield``, if any.

    NOTE(review): the nesting of ``print_impl`` was reconstructed from a
    flattened listing -- confirm the ``with`` scopes against upstream.
    """

    name = "for"
    traits = frozenset({ir.MaybePure()})
    purity: bool = info.attribute(default=False)
    iterable: ir.SSAValue = info.argument(types.Any)
    body: ir.Region = info.region(multi=False)
    initializers: tuple[ir.SSAValue, ...] = info.argument(types.Any)

    def __init__(
        self,
        iterable: ir.SSAValue,
        body: ir.Region,
        *initializers: ir.SSAValue,
    ):
        """Build the loop.

        Args:
            iterable: the value being iterated.
            body: the loop body; its first block is inspected for a
                terminating ``Yield`` to derive the result types.
            *initializers: initial values of the loop-carried variables.
        """
        terminator = body.blocks[0].last_stmt
        result_types = (
            tuple(value.type for value in terminator.values)
            if isinstance(terminator, Yield)
            else ()
        )
        super().__init__(
            args=(iterable, *initializers),
            regions=(body,),
            result_types=result_types,
            args_slice={"iterable": 0, "initializers": slice(1, None)},
            attributes={"purity": ir.PyAttr(False)},
        )

    def verify(self) -> None:
        """Check the structural invariants of the loop.

        Raises:
            VerificationError: if the body is not a single block, if the
                block arguments do not match ``1 + len(initializers)``, if
                the block does not end in ``Yield``/``func.Return``, or if a
                terminating ``Yield`` disagrees in count with the
                initializers or with the statement's results.
        """
        from kirin.dialects.func import Return

        if len(self.body.blocks) != 1:
            raise VerificationError(self, "for loop body must have a single block")

        expected_args = len(self.initializers) + 1
        if len(self.body.blocks[0].args) != expected_args:
            raise VerificationError(
                self,
                "for loop body must have arguments for all initializers and the loop variable",
            )

        terminator = self.body.blocks[0].last_stmt
        if not isinstance(terminator, (Yield, Return)):
            raise VerificationError(
                self, "for loop body must terminate with a yield or return"
            )

        # A returning body produces no loop results, so nothing more to check.
        if isinstance(terminator, Return):
            return

        n_yielded = len(terminator.values)
        if n_yielded != len(self.initializers):
            raise VerificationError(
                self,
                "for loop body must have the same number of results as initializers",
            )
        if len(self.results) != n_yielded:
            raise VerificationError(
                self,
                "for loop must have the same number of results as the yield in the body",
            )

    def print_impl(self, printer: Printer) -> None:
        """Print the loop header, optional ``iter_args(...)``, and the body."""
        printer.print_name(self)
        printer.plain_print(" ")
        body_block = self.body.blocks[0]
        printer.print(body_block.args[0])
        printer.plain_print(" in ", style="keyword")
        printer.print(self.iterable)

        if self.results:
            with printer.rich(style="comment"):
                printer.plain_print(" -> ")
                printer.print_seq(
                    tuple(result.type for result in self.results),
                    delim=", ",
                    style="comment",
                )

        with printer.indent():
            if self.initializers:
                printer.print_newline()
                printer.plain_print("iter_args(")
                last_position = len(self.initializers) - 1
                pairs = zip(body_block.args[1:], self.initializers)
                for position, (block_arg, initial) in enumerate(pairs):
                    printer.print(block_arg)
                    printer.plain_print(" = ")
                    printer.print(initial)
                    if position < last_position:
                        printer.plain_print(", ")
                printer.plain_print(")")

            printer.plain_print(" {")
            if printer.analysis is not None:
                # Show the analysis result attached to each block argument.
                with printer.rich(style="warning"):
                    for block_arg in body_block.args:
                        printer.print_newline()
                        printer.print_analysis(
                            block_arg,
                            prefix=f"{printer.state.ssa_id[block_arg]} --> ",
                        )
            with printer.align(printer.result_width(body_block.stmts)):
                for body_stmt in body_block.stmts:
                    printer.print_newline()
                    printer.print_stmt(body_stmt)
        printer.print_newline()
        printer.plain_print("}")
        with printer.rich(style="comment"):
            printer.plain_print(f" -> purity={self.purity}")
237
+
238
+
239
@statement(dialect=dialect)
class Yield(ir.Statement):
    """Terminator that hands its values back to the enclosing statement."""

    name = "yield"
    traits = frozenset({ir.IsTerminator()})
    values: tuple[ir.SSAValue, ...] = info.argument(types.Any)

    def __init__(self, *values: ir.SSAValue):
        """Create a yield whose arguments are exactly ``values``."""
        # Every argument of the statement belongs to the `values` group.
        whole_range = slice(None)
        super().__init__(args=values, args_slice={"values": whole_range})

    def print_impl(self, printer: Printer) -> None:
        """Print the statement name followed by the comma-separated values."""
        printer.print_name(self)
        printer.print_seq(self.values, prefix=" ", delim=", ")
@@ -0,0 +1,36 @@
1
+ from kirin import ir
2
+ from kirin.rewrite.abc import RewriteRule
3
+ from kirin.rewrite.result import RewriteResult
4
+
5
+ from .stmts import For, Yield, IfElse
6
+
7
+
8
class UnusedYield(RewriteRule):
    """Trim unused results from `For` and `IfElse` statements."""

    def rewrite_Statement(self, node: ir.Statement) -> RewriteResult:
        """Drop results without uses and the matching ``Yield`` arguments.

        Returns an empty ``RewriteResult`` when ``node`` is not a ``For`` or
        ``IfElse`` or when every result is used; otherwise rewrites ``node``
        in place and reports that something was done.

        NOTE(review): for ``For`` this shrinks the yielded values without
        touching the initializers, which looks at odds with the count check
        in ``For.verify`` -- confirm this is intended.
        """
        if not isinstance(node, (For, IfElse)):
            return RewriteResult()

        all_results = list(node.results)
        # Keep (position, result) pairs for the results that still have uses.
        kept = [
            (position, result)
            for position, result in enumerate(all_results)
            if result.uses
        ]
        if len(kept) == len(all_results):
            return RewriteResult()

        kept_positions: list[int] = [position for position, _ in kept]
        kept_results: list[ir.ResultValue] = [result for _, result in kept]

        node._results = kept_results
        for region in node.regions:
            for block in region.blocks:
                terminator = block.last_stmt
                if isinstance(terminator, Yield):
                    terminator.args = [
                        terminator.args[position] for position in kept_positions
                    ]

        return RewriteResult(has_done_something=True)
@@ -0,0 +1,58 @@
1
+ from kirin import ir, types, interp
2
+ from kirin.analysis import ForwardFrame, TypeInference
3
+ from kirin.dialects import func
4
+ from kirin.dialects.eltype import ElType
5
+
6
+ from . import absint
7
+ from .stmts import For, IfElse
8
+ from ._dialect import dialect
9
+
10
+
11
@dialect.register(key="typeinfer")
class TypeInfer(absint.Methods):
    """Type-inference method table for the ``scf`` statements."""

    @interp.impl(IfElse)
    def if_else_(
        self,
        interp_: TypeInference,
        frame: ForwardFrame[types.TypeAttribute],
        stmt: IfElse,
    ):
        """Narrow the condition to ``Bool``, then defer to the base rule."""
        # Propagate backwards: the condition's type is met with Bool.
        narrowed_cond = frame.get(stmt.cond).meet(types.Bool)
        frame.set(stmt.cond, narrowed_cond)
        # `self` is passed explicitly; presumably `interp.impl` returns an
        # object that is not bound on attribute access -- TODO confirm.
        return super().if_else(self, interp_, frame, stmt)

    @interp.impl(For)
    def for_loop(
        self,
        interp_: TypeInference,
        frame: ForwardFrame[types.TypeAttribute],
        stmt: For,
    ):
        """Infer the loop result types by running the body once.

        Returns ``None`` when the element type cannot be determined, when
        the body ends in ``func.Return``, or when the body yields nothing.
        """
        iterable_type = frame.get(stmt.iterable)
        initial_types = frame.get_values(stmt.initializers)
        loop_block = stmt.body.blocks[0]

        element = interp_.run_stmt(ElType(ir.TestValue()), (iterable_type,))
        if not isinstance(element, tuple):  # error
            return None
        item_type = element[0]
        frame.set_values(loop_block.args, (item_type,) + initial_types)

        if isinstance(loop_block.last_stmt, func.Return):
            # The body returns: queue the block and produce no loop result.
            frame.worklist.append(
                interp.Successor(loop_block, item_type, *initial_types)
            )
            return None

        with interp_.state.new_frame(interp_.new_frame(stmt)) as body_frame:
            body_frame.entries.update(frame.entries)
            outcome = interp_.run_ssacfg_region(body_frame, stmt.body)

        # Copy everything inferred inside the body back to the outer frame.
        frame.entries.update(body_frame.entries)
        if isinstance(outcome, interp.ReturnValue):
            return outcome
        if isinstance(outcome, tuple):
            return interp_.join_results(initial_types, outcome)
        return None  # loop has no result
@@ -0,0 +1,92 @@
1
+ from kirin import ir
2
+ from kirin.analysis import const
3
+ from kirin.dialects import func
4
+ from kirin.rewrite.abc import RewriteRule
5
+ from kirin.rewrite.result import RewriteResult
6
+ from kirin.dialects.py.constant import Constant
7
+
8
+ from .stmts import For, Yield, IfElse
9
+
10
+
11
class PickIfElse(RewriteRule):
    """Replace an `IfElse` with a constant condition by the taken branch."""

    def rewrite_Statement(self, node: ir.Statement) -> RewriteResult:
        """Inline the selected branch when the condition has a const hint."""
        if not isinstance(node, IfElse):
            return RewriteResult()

        hint = node.cond.hints.get("const")
        if not isinstance(hint, const.Value):
            return RewriteResult()

        taken = node.then_body if hint.data else node.else_body
        return self.insert_body(node, taken)

    def insert_body(self, node: IfElse, body: ir.Region):
        """Move the statements of ``body`` in front of ``node``.

        The first block argument is replaced by the condition, every
        non-terminator statement is hoisted before ``node``, and the
        terminator decides what happens to ``node`` itself:

        * ``Yield``: results of ``node`` are replaced by the yielded values
          and ``node`` is deleted.
        * ``func.Return``: everything after ``node`` in its block is deleted
          and ``node`` is replaced by the return.
        * anything else: nothing further is done and an empty result is
          returned.
        """
        branch_block = body.blocks[0]
        branch_block.args[0].replace_by(node.cond)

        # Hoist statements one at a time until only the terminator is left.
        current = branch_block.first_stmt
        while current and not current.has_trait(ir.IsTerminator):
            current.detach()
            current.insert_before(node)
            current = branch_block.first_stmt

        terminator = branch_block.last_stmt
        if isinstance(terminator, Yield):
            for result, yielded in zip(node.results, terminator.values):
                result.replace_by(yielded)
            node.delete()
            return RewriteResult(has_done_something=True)

        if isinstance(terminator, func.Return):
            parent_block = node.parent
            assert parent_block is not None
            # Remove the rest of the block, walking backwards from its end.
            cursor = parent_block.last_stmt
            while cursor is not None and cursor is not node:
                doomed = cursor
                cursor = cursor.prev_stmt
                doomed.delete()

            terminator.detach()
            node.replace_by(terminator)
            return RewriteResult(has_done_something=True)

        return RewriteResult()
54
+
55
+
56
class ForLoop(RewriteRule):
    """Fully unroll a `For` whose iterable carries a constant hint."""

    def rewrite_Statement(self, node: ir.Statement) -> RewriteResult:
        """Emit one copy of the loop body per element of the constant.

        For every element a ``Constant`` is inserted before ``node``, a clone
        of the body is specialised to it, and its non-terminator statements
        are hoisted before ``node``. Values yielded by one copy feed the
        loop-carried arguments of the next; the final ones replace the
        results of ``node``, which is then deleted.

        NOTE(review): a body ending in something other than ``Yield`` leaves
        the carried values unchanged and its terminator behind -- confirm
        that such bodies never reach this rule.
        """
        if not isinstance(node, For):
            return RewriteResult()

        # TODO: support for PartialTuple and IList with known length
        hint = node.iterable.hints.get("const")
        if not isinstance(hint, const.Value):
            return RewriteResult()

        carried = node.initializers
        for element in hint.data:
            body_copy = node.body.clone()
            copy_block = body_copy.blocks[0]

            element_const = Constant(element)
            element_const.insert_before(node)
            copy_block.args[0].replace_by(element_const.result)
            for block_arg, incoming in zip(copy_block.args[1:], carried):
                block_arg.replace_by(incoming)

            # Hoist statements one at a time until only the terminator is left.
            current = copy_block.first_stmt
            while current and not current.has_trait(ir.IsTerminator):
                current.detach()
                current.insert_before(node)
                current = copy_block.first_stmt

            terminator = copy_block.last_stmt
            # we assume Yield has the same # of values as initializers
            # TODO: check this in validation
            if isinstance(terminator, Yield):
                carried = terminator.values
                terminator.delete()

        for result, final_value in zip(node.results, carried):
            result.replace_by(final_value)
        node.delete()
        return RewriteResult(has_done_something=True)
kirin/emit/__init__.py ADDED
@@ -0,0 +1,3 @@
1
+ from .abc import EmitABC as EmitABC, EmitFrame as EmitFrame
2
+ from .str import EmitStr as EmitStr, EmitStrFrame as EmitStrFrame
3
+ from .exceptions import EmitError as EmitError
kirin/emit/abc.py ADDED
@@ -0,0 +1,89 @@
1
+ from abc import ABC
2
+ from typing import TypeVar
3
+ from dataclasses import field, dataclass
4
+
5
+ from kirin import ir, interp
6
+ from kirin.worklist import WorkList
7
+
8
+ ValueType = TypeVar("ValueType")
9
+
10
+
11
@dataclass
class EmitFrame(interp.Frame[ValueType]):
    """Interpreter frame used while emitting code."""

    # Successors still waiting to be emitted.
    worklist: WorkList[interp.Successor] = field(default_factory=WorkList)
    # Emitted reference (of the frame's value type) for each visited block.
    block_ref: dict[ir.Block, ValueType] = field(default_factory=dict)
15
+
16
+
17
+ FrameType = TypeVar("FrameType", bound=EmitFrame)
18
+
19
+
20
+ @dataclass
21
+ class EmitABC(interp.BaseInterpreter[FrameType, ValueType], ABC):
22
+
23
+ def run_callable_region(
24
+ self, frame: FrameType, code: ir.Statement, region: ir.Region
25
+ ) -> ValueType:
26
+ results = self.eval_stmt(frame, code)
27
+ if isinstance(results, tuple):
28
+ if len(results) == 0:
29
+ return self.void
30
+ elif len(results) == 1:
31
+ return results[0]
32
+ raise interp.InterpreterError(f"Unexpected results {results}")
33
+
34
+ def run_ssacfg_region(
35
+ self, frame: FrameType, region: ir.Region
36
+ ) -> tuple[ValueType, ...]:
37
+ frame.worklist.append(
38
+ interp.Successor(region.blocks[0], frame.get_values(region.blocks[0].args))
39
+ )
40
+ while (succ := frame.worklist.pop()) is not None:
41
+ block_header = self.emit_block(frame, succ.block)
42
+ frame.block_ref[succ.block] = block_header
43
+ return ()
44
+
45
+ def emit_attribute(self, attr: ir.Attribute) -> ValueType:
46
+ return getattr(
47
+ self, f"emit_type_{type(attr).__name__}", self.emit_attribute_fallback
48
+ )(attr)
49
+
50
+ def emit_attribute_fallback(self, attr: ir.Attribute) -> ValueType:
51
+ if (method := self.registry.attributes.get(type(attr))) is not None:
52
+ return method(self, attr)
53
+ raise NotImplementedError(f"Attribute {type(attr)} not implemented")
54
+
55
+ def emit_stmt_begin(self, frame: FrameType, stmt: ir.Statement) -> None:
56
+ return
57
+
58
+ def emit_stmt_end(self, frame: FrameType, stmt: ir.Statement) -> None:
59
+ return
60
+
61
+ def emit_block_begin(self, frame: FrameType, block: ir.Block) -> None:
62
+ return
63
+
64
+ def emit_block_end(self, frame: FrameType, block: ir.Block) -> None:
65
+ return
66
+
67
+ def emit_block(self, frame: FrameType, block: ir.Block) -> ValueType:
68
+ self.emit_block_begin(frame, block)
69
+ stmt = block.first_stmt
70
+ while stmt is not None:
71
+ if self.consume_fuel() == self.FuelResult.Stop:
72
+ raise interp.FuelExhaustedError("fuel exhausted")
73
+
74
+ self.emit_stmt_begin(frame, stmt)
75
+ stmt_results = self.eval_stmt(frame, stmt)
76
+ self.emit_stmt_end(frame, stmt)
77
+
78
+ match stmt_results:
79
+ case tuple(values):
80
+ frame.set_values(stmt._results, values)
81
+ case interp.ReturnValue(_) | interp.YieldValue(_):
82
+ pass
83
+ case _:
84
+ raise ValueError(f"Unexpected result {stmt_results}")
85
+
86
+ stmt = stmt.next_stmt
87
+
88
+ self.emit_block_end(frame, block)
89
+ return frame.block_ref[block]
kirin/emit/abc.pyi ADDED
@@ -0,0 +1,38 @@
1
+ from typing import TypeVar
2
+ from dataclasses import field, dataclass
3
+
4
+ from kirin import ir, types, interp
5
+ from kirin.worklist import WorkList
6
+
7
+ ValueType = TypeVar("ValueType")
8
+
9
@dataclass
class EmitFrame(interp.Frame[ValueType]):
    # Type stub for the frame used during emission.
    # Successors still waiting to be emitted.
    worklist: WorkList[interp.Successor] = field(default_factory=WorkList)
    # Emitted reference for each visited block.
    block_ref: dict[ir.Block, ValueType] = field(default_factory=dict)
13
+
14
+ FrameType = TypeVar("FrameType", bound=EmitFrame)
15
+
16
class EmitABC(interp.BaseInterpreter[FrameType, ValueType]):
    # Type stub for the emitter base class.

    # Region drivers.
    def run_callable_region(
        self,
        frame: FrameType,
        code: ir.Statement,
        region: ir.Region,
    ) -> ValueType: ...
    def run_ssacfg_region(
        self,
        frame: FrameType,
        region: ir.Region,
    ) -> tuple[ValueType, ...]: ...

    # Attribute emission: one `emit_type_<ClassName>` entry per attribute
    # class declared here, plus the dispatcher and its fallback.
    def emit_attribute(self, attr: ir.Attribute) -> ValueType: ...
    def emit_type_Any(self, attr: types.AnyType) -> ValueType: ...
    def emit_type_Bottom(self, attr: types.BottomType) -> ValueType: ...
    def emit_type_Literal(self, attr: types.Literal) -> ValueType: ...
    def emit_type_Union(self, attr: types.Union) -> ValueType: ...
    def emit_type_TypeVar(self, attr: types.TypeVar) -> ValueType: ...
    def emit_type_Vararg(self, attr: types.Vararg) -> ValueType: ...
    def emit_type_Generic(self, attr: types.Generic) -> ValueType: ...
    def emit_type_PyClass(self, attr: types.PyClass) -> ValueType: ...
    def emit_type_PyAttr(self, attr: ir.PyAttr) -> ValueType: ...
    def emit_attribute_fallback(self, attr: ir.Attribute) -> ValueType: ...

    # Statement- and block-level hooks.
    def emit_stmt_begin(self, frame: FrameType, stmt: ir.Statement) -> None: ...
    def emit_stmt_end(self, frame: FrameType, stmt: ir.Statement) -> None: ...
    def emit_block_begin(self, frame: FrameType, block: ir.Block) -> None: ...
    def emit_block_end(self, frame: FrameType, block: ir.Block) -> None: ...
    def emit_block(self, frame: FrameType, block: ir.Block) -> ValueType: ...
@@ -0,0 +1,5 @@
1
+ from kirin.interp import InterpreterError
2
+
3
+
4
class EmitError(InterpreterError):
    """Error type for the emit package; a plain ``InterpreterError`` subclass."""
kirin/emit/julia.py ADDED
@@ -0,0 +1,63 @@
1
+ from typing import IO, TypeVar
2
+
3
+ from kirin import ir
4
+ from kirin.ir.attrs.types import PyClass
5
+ from kirin.ir.nodes.block import Block
6
+
7
+ from .str import EmitStr, EmitStrFrame
8
+
9
+ IO_t = TypeVar("IO_t", bound=IO)
10
+
11
+
12
class EmitJulia(EmitStr[IO_t]):
    """String emitter that writes Julia source text.

    Blocks are introduced with ``@label`` lines, SSA values are named through
    ``self.ssa_id``, and Python classes are mapped to Julia type names via
    ``PYTYPE_MAP`` (anything unknown becomes ``Any``).
    """

    keys = ["emit.julia"]

    # Julia type name used for each supported Python class.
    PYTYPE_MAP = {
        int: "Int",
        float: "Real",
        str: "String",
        bool: "Bool",
        type(None): "Nothing",
        dict: "Dict",
        list: "Vector",
        tuple: "Tuple",
    }

    def emit_block_begin(self, frame: EmitStrFrame, block: Block) -> None:
        """Register the block's id in the frame and write its ``@label``."""
        block_id = self.block_id[block]
        frame.block_ref[block] = block_id
        self.newline(frame)
        self.write(f"@label {block_id};")

    def emit_type_PyClass(self, attr: PyClass) -> str:
        """Return the Julia type name for a Python class, or ``"Any"``."""
        return self.PYTYPE_MAP.get(attr.typ, "Any")

    def write_assign(self, frame: EmitStrFrame, result: ir.SSAValue, *args):
        """Write ``<result symbol> = <args...>`` on a new line.

        The symbol is also bound to ``result`` in ``frame`` and returned.
        """
        result_sym = self.ssa_id[result]
        frame.set(result, result_sym)
        self.writeln(frame, result_sym, " = ", *args)
        return result_sym

    def emit_binaryop(
        self,
        frame: EmitStrFrame,
        sym: str,
        lhs: ir.SSAValue,
        rhs: ir.SSAValue,
        result: ir.ResultValue,
    ):
        """Emit ``result = lhs <sym> rhs`` and return a 1-tuple of its symbol."""
        return (
            self.write_assign(
                frame,
                result,
                f"{frame.get(lhs)} {sym} {frame.get(rhs)}",
            ),
        )

    def emit_type_PyAttr(self, attr: ir.PyAttr) -> str:
        """Return the Julia literal for a Python constant.

        Supports ``bool``, ``int``, ``float`` and ``str`` data.

        Raises:
            ValueError: for any other data type.
        """
        data = attr.data
        # bool must be tested before int: `True` is an `int` instance in
        # Python, but Julia spells its boolean literals `true` / `false`.
        if isinstance(data, bool):
            return "true" if data else "false"
        if isinstance(data, (int, float)):
            return repr(data)
        if isinstance(data, str):
            # Escape the characters that would otherwise end the literal
            # early (`"`), start an escape (`\`) or trigger interpolation (`$`).
            escaped = (
                data.replace("\\", "\\\\").replace('"', '\\"').replace("$", "\\$")
            )
            return f'"{escaped}"'
        raise ValueError(f"unsupported type {type(attr.data)}")
kirin/emit/str.py ADDED
@@ -0,0 +1,51 @@
1
+ from abc import ABC
2
+ from typing import IO, Generic, TypeVar
3
+ from dataclasses import field, dataclass
4
+
5
+ from kirin import ir, interp, idtable
6
+ from kirin.emit.abc import EmitABC, EmitFrame
7
+
8
+ IO_t = TypeVar("IO_t", bound=IO)
9
+
10
+
11
@dataclass
class EmitStrFrame(EmitFrame[str]):
    """Emission frame whose values are strings."""

    # Indentation width used by `EmitStr.newline` (number of spaces).
    indent: int = 0
    # Strings recorded per SSA value.
    captured: dict[ir.SSAValue, tuple[str, ...]] = field(default_factory=dict)
15
+
16
+
17
@dataclass
class EmitStr(EmitABC[EmitStrFrame, str], ABC, Generic[IO_t]):
    """Emitter base that writes text to a file-like object.

    ``file`` receives all output. ``prefix`` and ``prefix_if_none`` are
    forwarded to the id tables created in ``initialize``.
    """

    void = ""
    file: IO_t
    prefix: str = field(default="", kw_only=True)
    prefix_if_none: str = field(default="var_", kw_only=True)

    def initialize(self):
        """Create fresh id tables for SSA values and blocks; return ``self``."""
        super().initialize()
        self.ssa_id = idtable.IdTable[ir.SSAValue](
            prefix=self.prefix,
            prefix_if_none=self.prefix_if_none,
        )
        block_prefix = self.prefix + "block_"
        self.block_id = idtable.IdTable[ir.Block](prefix=block_prefix)
        return self

    def new_frame(self, code: ir.Statement) -> EmitStrFrame:
        """Create the string frame for ``code``."""
        return EmitStrFrame.from_func_like(code)

    def run_method(
        self, method: ir.Method, args: tuple[str, ...]
    ) -> tuple[EmitStrFrame, str]:
        """Emit ``method``, passing its symbol name as the first argument.

        Raises:
            interp.InterpreterError: if the frame stack already holds
                ``max_depth`` frames.
        """
        depth = len(self.state.frames)
        if depth >= self.max_depth:
            raise interp.InterpreterError("maximum recursion depth exceeded")
        call_args = (method.sym_name,) + args
        return self.run_callable(method.code, call_args)

    def write(self, *args):
        """Write each argument to ``file`` in order, with no separator."""
        for chunk in args:
            self.file.write(chunk)

    def newline(self, frame: EmitStrFrame):
        """Start a new line indented by ``frame.indent`` spaces."""
        padding = " " * frame.indent
        self.file.write("\n" + padding)

    def writeln(self, frame: EmitStrFrame, *args):
        """Start a new indented line, then write ``args`` on it."""
        self.newline(frame)
        self.write(*args)