kirin-toolchain 0.13.0__py3-none-any.whl

Sign up to get free protection for your applications and to get access to all the features.
Files changed (225) hide show
  1. kirin/__init__.py +7 -0
  2. kirin/analysis/__init__.py +24 -0
  3. kirin/analysis/callgraph.py +61 -0
  4. kirin/analysis/cfg.py +112 -0
  5. kirin/analysis/const/__init__.py +20 -0
  6. kirin/analysis/const/_visitor.py +2 -0
  7. kirin/analysis/const/_visitor.pyi +8 -0
  8. kirin/analysis/const/lattice.py +219 -0
  9. kirin/analysis/const/prop.py +116 -0
  10. kirin/analysis/forward.py +100 -0
  11. kirin/analysis/typeinfer/__init__.py +5 -0
  12. kirin/analysis/typeinfer/analysis.py +90 -0
  13. kirin/analysis/typeinfer/solve.py +141 -0
  14. kirin/decl/__init__.py +108 -0
  15. kirin/decl/base.py +65 -0
  16. kirin/decl/camel2snake.py +2 -0
  17. kirin/decl/emit/__init__.py +0 -0
  18. kirin/decl/emit/_create_fn.py +29 -0
  19. kirin/decl/emit/_set_new_attribute.py +22 -0
  20. kirin/decl/emit/dialect.py +8 -0
  21. kirin/decl/emit/init.py +277 -0
  22. kirin/decl/emit/name.py +10 -0
  23. kirin/decl/emit/property.py +182 -0
  24. kirin/decl/emit/repr.py +31 -0
  25. kirin/decl/emit/traits.py +13 -0
  26. kirin/decl/emit/typecheck.py +77 -0
  27. kirin/decl/emit/verify.py +51 -0
  28. kirin/decl/info.py +346 -0
  29. kirin/decl/scan_fields.py +157 -0
  30. kirin/decl/verify.py +69 -0
  31. kirin/dialects/__init__.py +14 -0
  32. kirin/dialects/_pprint_helper.py +53 -0
  33. kirin/dialects/cf/__init__.py +20 -0
  34. kirin/dialects/cf/constprop.py +51 -0
  35. kirin/dialects/cf/dialect.py +3 -0
  36. kirin/dialects/cf/emit.py +58 -0
  37. kirin/dialects/cf/interp.py +24 -0
  38. kirin/dialects/cf/stmts.py +68 -0
  39. kirin/dialects/cf/typeinfer.py +27 -0
  40. kirin/dialects/eltype.py +23 -0
  41. kirin/dialects/func/__init__.py +20 -0
  42. kirin/dialects/func/attrs.py +39 -0
  43. kirin/dialects/func/constprop.py +138 -0
  44. kirin/dialects/func/dialect.py +3 -0
  45. kirin/dialects/func/emit.py +80 -0
  46. kirin/dialects/func/interp.py +68 -0
  47. kirin/dialects/func/stmts.py +233 -0
  48. kirin/dialects/func/typeinfer.py +124 -0
  49. kirin/dialects/ilist/__init__.py +33 -0
  50. kirin/dialects/ilist/_dialect.py +3 -0
  51. kirin/dialects/ilist/_wrapper.py +51 -0
  52. kirin/dialects/ilist/interp.py +85 -0
  53. kirin/dialects/ilist/lowering.py +25 -0
  54. kirin/dialects/ilist/passes.py +32 -0
  55. kirin/dialects/ilist/rewrite/__init__.py +3 -0
  56. kirin/dialects/ilist/rewrite/const.py +45 -0
  57. kirin/dialects/ilist/rewrite/list.py +38 -0
  58. kirin/dialects/ilist/rewrite/unroll.py +131 -0
  59. kirin/dialects/ilist/runtime.py +63 -0
  60. kirin/dialects/ilist/stmts.py +102 -0
  61. kirin/dialects/ilist/typeinfer.py +120 -0
  62. kirin/dialects/lowering/__init__.py +7 -0
  63. kirin/dialects/lowering/call.py +48 -0
  64. kirin/dialects/lowering/cf.py +206 -0
  65. kirin/dialects/lowering/func.py +134 -0
  66. kirin/dialects/math/__init__.py +41 -0
  67. kirin/dialects/math/_gen.py +176 -0
  68. kirin/dialects/math/dialect.py +3 -0
  69. kirin/dialects/math/interp.py +190 -0
  70. kirin/dialects/math/stmts.py +369 -0
  71. kirin/dialects/module.py +139 -0
  72. kirin/dialects/py/__init__.py +40 -0
  73. kirin/dialects/py/assertion.py +91 -0
  74. kirin/dialects/py/assign.py +103 -0
  75. kirin/dialects/py/attr.py +59 -0
  76. kirin/dialects/py/base.py +34 -0
  77. kirin/dialects/py/binop/__init__.py +23 -0
  78. kirin/dialects/py/binop/_dialect.py +3 -0
  79. kirin/dialects/py/binop/interp.py +60 -0
  80. kirin/dialects/py/binop/julia.py +33 -0
  81. kirin/dialects/py/binop/lowering.py +22 -0
  82. kirin/dialects/py/binop/stmts.py +79 -0
  83. kirin/dialects/py/binop/typeinfer.py +108 -0
  84. kirin/dialects/py/boolop.py +84 -0
  85. kirin/dialects/py/builtin.py +78 -0
  86. kirin/dialects/py/cmp/__init__.py +16 -0
  87. kirin/dialects/py/cmp/_dialect.py +3 -0
  88. kirin/dialects/py/cmp/interp.py +48 -0
  89. kirin/dialects/py/cmp/julia.py +33 -0
  90. kirin/dialects/py/cmp/lowering.py +45 -0
  91. kirin/dialects/py/cmp/stmts.py +62 -0
  92. kirin/dialects/py/constant.py +79 -0
  93. kirin/dialects/py/indexing.py +251 -0
  94. kirin/dialects/py/iterable.py +90 -0
  95. kirin/dialects/py/len.py +57 -0
  96. kirin/dialects/py/list/__init__.py +15 -0
  97. kirin/dialects/py/list/_dialect.py +3 -0
  98. kirin/dialects/py/list/interp.py +21 -0
  99. kirin/dialects/py/list/lowering.py +25 -0
  100. kirin/dialects/py/list/stmts.py +22 -0
  101. kirin/dialects/py/list/typeinfer.py +54 -0
  102. kirin/dialects/py/range.py +76 -0
  103. kirin/dialects/py/slice.py +120 -0
  104. kirin/dialects/py/tuple.py +109 -0
  105. kirin/dialects/py/unary/__init__.py +24 -0
  106. kirin/dialects/py/unary/_dialect.py +3 -0
  107. kirin/dialects/py/unary/constprop.py +20 -0
  108. kirin/dialects/py/unary/interp.py +24 -0
  109. kirin/dialects/py/unary/julia.py +21 -0
  110. kirin/dialects/py/unary/lowering.py +22 -0
  111. kirin/dialects/py/unary/stmts.py +33 -0
  112. kirin/dialects/py/unary/typeinfer.py +23 -0
  113. kirin/dialects/py/unpack.py +90 -0
  114. kirin/dialects/scf/__init__.py +23 -0
  115. kirin/dialects/scf/_dialect.py +3 -0
  116. kirin/dialects/scf/absint.py +64 -0
  117. kirin/dialects/scf/constprop.py +140 -0
  118. kirin/dialects/scf/interp.py +35 -0
  119. kirin/dialects/scf/lowering.py +123 -0
  120. kirin/dialects/scf/stmts.py +250 -0
  121. kirin/dialects/scf/trim.py +36 -0
  122. kirin/dialects/scf/typeinfer.py +58 -0
  123. kirin/dialects/scf/unroll.py +92 -0
  124. kirin/emit/__init__.py +3 -0
  125. kirin/emit/abc.py +89 -0
  126. kirin/emit/abc.pyi +38 -0
  127. kirin/emit/exceptions.py +5 -0
  128. kirin/emit/julia.py +63 -0
  129. kirin/emit/str.py +51 -0
  130. kirin/exceptions.py +59 -0
  131. kirin/graph.py +34 -0
  132. kirin/idtable.py +57 -0
  133. kirin/interp/__init__.py +39 -0
  134. kirin/interp/abstract.py +253 -0
  135. kirin/interp/base.py +438 -0
  136. kirin/interp/concrete.py +62 -0
  137. kirin/interp/exceptions.py +26 -0
  138. kirin/interp/frame.py +151 -0
  139. kirin/interp/impl.py +197 -0
  140. kirin/interp/result.py +93 -0
  141. kirin/interp/state.py +71 -0
  142. kirin/interp/table.py +40 -0
  143. kirin/interp/value.py +73 -0
  144. kirin/ir/__init__.py +46 -0
  145. kirin/ir/attrs/__init__.py +20 -0
  146. kirin/ir/attrs/_types.py +8 -0
  147. kirin/ir/attrs/_types.pyi +13 -0
  148. kirin/ir/attrs/abc.py +46 -0
  149. kirin/ir/attrs/py.py +45 -0
  150. kirin/ir/attrs/types.py +522 -0
  151. kirin/ir/dialect.py +125 -0
  152. kirin/ir/group.py +249 -0
  153. kirin/ir/method.py +118 -0
  154. kirin/ir/nodes/__init__.py +7 -0
  155. kirin/ir/nodes/base.py +149 -0
  156. kirin/ir/nodes/block.py +458 -0
  157. kirin/ir/nodes/region.py +337 -0
  158. kirin/ir/nodes/stmt.py +713 -0
  159. kirin/ir/nodes/view.py +142 -0
  160. kirin/ir/ssa.py +204 -0
  161. kirin/ir/traits/__init__.py +36 -0
  162. kirin/ir/traits/abc.py +42 -0
  163. kirin/ir/traits/basic.py +78 -0
  164. kirin/ir/traits/callable.py +51 -0
  165. kirin/ir/traits/lowering/__init__.py +2 -0
  166. kirin/ir/traits/lowering/call.py +37 -0
  167. kirin/ir/traits/lowering/context.py +120 -0
  168. kirin/ir/traits/region/__init__.py +2 -0
  169. kirin/ir/traits/region/ssacfg.py +22 -0
  170. kirin/ir/traits/symbol.py +57 -0
  171. kirin/ir/use.py +17 -0
  172. kirin/lattice/__init__.py +13 -0
  173. kirin/lattice/abc.py +128 -0
  174. kirin/lattice/empty.py +25 -0
  175. kirin/lattice/mixin.py +51 -0
  176. kirin/lowering/__init__.py +7 -0
  177. kirin/lowering/binding.py +65 -0
  178. kirin/lowering/core.py +72 -0
  179. kirin/lowering/dialect.py +35 -0
  180. kirin/lowering/dialect.pyi +183 -0
  181. kirin/lowering/frame.py +171 -0
  182. kirin/lowering/result.py +68 -0
  183. kirin/lowering/state.py +441 -0
  184. kirin/lowering/stream.py +53 -0
  185. kirin/passes/__init__.py +3 -0
  186. kirin/passes/abc.py +44 -0
  187. kirin/passes/aggressive/__init__.py +1 -0
  188. kirin/passes/aggressive/fold.py +43 -0
  189. kirin/passes/fold.py +45 -0
  190. kirin/passes/inline.py +25 -0
  191. kirin/passes/typeinfer.py +25 -0
  192. kirin/prelude.py +197 -0
  193. kirin/print/__init__.py +15 -0
  194. kirin/print/printable.py +141 -0
  195. kirin/print/printer.py +415 -0
  196. kirin/py.typed +0 -0
  197. kirin/registry.py +105 -0
  198. kirin/registry.pyi +52 -0
  199. kirin/rewrite/__init__.py +14 -0
  200. kirin/rewrite/abc.py +43 -0
  201. kirin/rewrite/aggressive/__init__.py +1 -0
  202. kirin/rewrite/aggressive/fold.py +43 -0
  203. kirin/rewrite/alias.py +16 -0
  204. kirin/rewrite/apply_type.py +47 -0
  205. kirin/rewrite/call2invoke.py +34 -0
  206. kirin/rewrite/chain.py +39 -0
  207. kirin/rewrite/compactify.py +288 -0
  208. kirin/rewrite/cse.py +48 -0
  209. kirin/rewrite/dce.py +19 -0
  210. kirin/rewrite/fixpoint.py +34 -0
  211. kirin/rewrite/fold.py +57 -0
  212. kirin/rewrite/getfield.py +21 -0
  213. kirin/rewrite/getitem.py +37 -0
  214. kirin/rewrite/inline.py +143 -0
  215. kirin/rewrite/result.py +15 -0
  216. kirin/rewrite/walk.py +83 -0
  217. kirin/rewrite/wrap_const.py +55 -0
  218. kirin/source.py +21 -0
  219. kirin/symbol_table.py +27 -0
  220. kirin/types.py +34 -0
  221. kirin/worklist.py +30 -0
  222. kirin_toolchain-0.13.0.dist-info/METADATA +42 -0
  223. kirin_toolchain-0.13.0.dist-info/RECORD +225 -0
  224. kirin_toolchain-0.13.0.dist-info/WHEEL +4 -0
  225. kirin_toolchain-0.13.0.dist-info/licenses/LICENSE +234 -0
@@ -0,0 +1,90 @@
1
+ """The unpack dialect for Python.
2
+
3
+ This module contains the dialect for the Python unpack semantics, including:
4
+
5
+ - The `Unpack` statement class.
6
+ - The lowering pass for the unpack statement.
7
+ - The concrete implementation of the unpack statement.
8
+ - The type inference implementation of the unpack statement.
9
+ - A helper function `unpacking` for unpacking Python AST nodes during lowering.
10
+ """
11
+
12
+ import ast
13
+
14
+ from kirin import ir, types, interp, lowering
15
+ from kirin.decl import info, statement
16
+ from kirin.print import Printer
17
+ from kirin.exceptions import DialectLoweringError
18
+
19
# Dialect that owns the `Unpack` statement below; the concrete and
# type-inference method tables in this module are registered on it.
dialect = ir.Dialect("py.unpack")
20
+
21
+
22
@statement(dialect=dialect, init=False)
class Unpack(ir.Statement):
    """Destructure one SSA value into a fixed number of results.

    The statement takes a single argument, `value`, and produces one result
    per entry of the `names` attribute. Each entry is the Python variable
    name the corresponding result is bound to, or `None` when that position
    is a nested target (e.g. the `(b, c)` in `a, (b, c) = ...`) that gets
    unpacked again by a further `Unpack`.
    """

    value: ir.SSAValue = info.argument(types.Any)
    names: tuple[str | None, ...] = info.attribute()

    def __init__(self, value: ir.SSAValue, names: tuple[str | None, ...]):
        # Every result starts out as `Any`; the type-inference table narrows
        # them per target.
        super().__init__(
            args=(value,),
            result_types=[types.Any for _ in names],
            args_slice={"value": 0},
            attributes={"names": ir.PyAttr(names)},
        )
        # Name each result after the variable it binds (None for nested
        # targets) so the printed IR mirrors the source.
        for name, ssa in zip(names, self.results):
            ssa.name = name

    def print_impl(self, printer: Printer) -> None:
        """Print the statement name followed by its single operand."""
        printer.print_name(self)
        printer.plain_print(" ")
        printer.print(self.value)
+
43
+
44
@dialect.register
class Concrete(interp.MethodTable):
    """Concrete (Python-semantics) interpreter implementation of `Unpack`."""

    @interp.impl(Unpack)
    def unpack(self, interp: interp.Interpreter, frame: interp.Frame, stmt: Unpack):
        """Unpack the runtime value into one result per target name.

        Returns:
            A tuple with exactly `len(stmt.names)` items.

        Raises:
            ValueError: if the value does not produce exactly
                `len(stmt.names)` items, mirroring the error Python raises
                for `a, b = xs` with the wrong number of elements.
        """
        values = tuple(frame.get(stmt.value))
        expected = len(stmt.names)
        got = len(values)
        # Without this check a mismatch would silently produce too many or
        # too few results for the statement.
        if got > expected:
            raise ValueError(f"too many values to unpack (expected {expected})")
        if got < expected:
            raise ValueError(
                f"not enough values to unpack (expected {expected}, got {got})"
            )
        return values
50
+
51
+
52
@dialect.register(key="typeinfer")
class TypeInfer(interp.MethodTable):
    """Type-inference implementation of `Unpack`."""

    @interp.impl(Unpack)
    def unpack(self, interp, frame: interp.Frame[types.TypeAttribute], stmt: Unpack):
        """Infer one type per unpack target.

        For a tuple type, the element types are used positionally. A vararg
        element type fills any positions beyond the fixed prefix. The result
        always has exactly `len(stmt.names)` entries. When the tuple's arity
        cannot match the number of targets, or the value is not a tuple type,
        every target is typed as `Any`.
        """
        value = frame.get(stmt.value)
        count = len(stmt.names)
        if isinstance(value, types.Generic) and value.is_subseteq(types.Tuple):
            fixed = tuple(value.vars)
            if value.vararg:
                # tuple[A, B, *C] needs at least len(fixed) targets; the
                # remaining targets all get the vararg element type.
                if len(fixed) <= count:
                    return fixed + (value.vararg.typ,) * (count - len(fixed))
            elif len(fixed) == count:
                return fixed
        # TODO: support unpacking other types
        return tuple(types.Any for _ in stmt.names)
66
+
67
+
68
def unpacking(state: lowering.LoweringState, node: ast.expr, value: ir.SSAValue):
    """Bind the assignment target `node` to `value` in the current lowering frame.

    A plain name is bound directly, and `value` is renamed after it. A tuple
    or list target (`a, b = ...` / `[a, b] = ...`) emits an `Unpack`
    statement. Each element is then bound to the matching result; nested
    targets such as `a, (b, c)` are unpacked recursively.

    Raises:
        DialectLoweringError: if the target is not a name, tuple or list, or
            if it contains a starred element (`a, *rest = ...`).
    """
    if isinstance(node, ast.Name):
        state.current_frame.defs[node.id] = value
        value.name = node.id
        return
    # Tuple and list targets have identical unpacking semantics in Python.
    if not isinstance(node, (ast.Tuple, ast.List)):
        raise DialectLoweringError(f"unsupported unpack node {node}")
    # Reject starred targets up front, before emitting any statement.
    if any(isinstance(item, ast.Starred) for item in node.elts):
        raise DialectLoweringError("starred unpacking is not supported")

    # `None` marks positions whose target is itself a pattern to unpack.
    names: list[str | None] = [
        item.id if isinstance(item, ast.Name) else None for item in node.elts
    ]
    stmt = state.append_stmt(Unpack(value, tuple(names)))
    for name, result in zip(names, stmt.results):
        if name is not None:
            state.current_frame.defs[name] = result

    for item, result in zip(node.elts, stmt.results):
        if not isinstance(item, ast.Name):
            unpacking(state, item, result)
@@ -0,0 +1,23 @@
1
+ """A Python-like structural Control Flow dialect.
2
+
3
+ This dialect provides constructs for expressing control flow in a structured
4
+ manner. The dialect provides constructs for expressing loops and conditionals.
5
+ Unlike MLIR SCF dialect, this dialect does not restrict the control flow to
6
+ statically analyzable forms. This dialect is designed to be compatible with
7
+ Python native control flow constructs.
8
+
9
+ This dialect depends on the following dialects:
10
+ - `eltype`: for obtaining the element type of a value.
11
+ """
12
+
13
+ from . import (
14
+ trim as trim,
15
+ absint as absint,
16
+ interp as interp,
17
+ unroll as unroll,
18
+ lowering as lowering,
19
+ constprop as constprop,
20
+ typeinfer as typeinfer,
21
+ )
22
+ from .stmts import For as For, Yield as Yield, IfElse as IfElse
23
+ from ._dialect import dialect as dialect
@@ -0,0 +1,3 @@
1
+ from kirin import ir
2
+
3
# The single `scf` dialect instance. The other modules of this package import
# it from here and register their method tables on it.
dialect = ir.Dialect("scf")
@@ -0,0 +1,64 @@
1
+ from kirin import ir, interp
2
+ from kirin.analysis import const
3
+ from kirin.dialects import func
4
+
5
+ from .stmts import Yield, IfElse
6
+ from ._dialect import dialect
7
+
8
+
9
@dialect.register(key="absint")
class Methods(interp.MethodTable):
    """Generic abstract-interpretation rules for `scf` statements."""

    @interp.impl(Yield)
    def yield_stmt(
        self,
        interp_: interp.AbstractInterpreter,
        frame: interp.AbstractFrame,
        stmt: Yield,
    ):
        """Forward the abstract values of the yielded operands."""
        return interp.YieldValue(frame.get_values(stmt.values))

    @interp.impl(IfElse)
    def if_else(
        self,
        interp_: interp.AbstractInterpreter,
        frame: interp.AbstractFrame,
        stmt: IfElse,
    ):
        """Abstractly evaluate an `IfElse`.

        If the condition carries a constant hint (`hints["const"]` is a
        `const.Value`), only the selected branch is analyzed. Otherwise both
        branches are analyzed and their outcomes are combined.
        """
        # A constant hint on the condition lets us skip the dead branch.
        if isinstance(hint := stmt.cond.hints.get("const"), const.Value):
            if hint.data:
                return self._infer_if_else_cond(interp_, frame, stmt, stmt.then_body)
            else:
                return self._infer_if_else_cond(interp_, frame, stmt, stmt.else_body)
        then_results = self._infer_if_else_cond(interp_, frame, stmt, stmt.then_body)
        else_results = self._infer_if_else_cond(interp_, frame, stmt, stmt.else_body)

        match (then_results, else_results):
            case (interp.ReturnValue(then_value), interp.ReturnValue(else_value)):
                # Both branches return: the statement returns the join.
                return interp.ReturnValue(then_value.join(else_value))
            case (interp.ReturnValue(then_value), _):
                # NOTE(review): this propagates the *returning* branch, while
                # the scf constprop table keeps the non-returning branch's
                # result in the same situation; confirm which is intended.
                return then_results
            case (_, interp.ReturnValue(else_value)):
                return else_results
            case _:
                # Covers the case where a branch returned None because it
                # was scheduled on the worklist in `_infer_if_else_cond`.
                return interp_.join_results(then_results, else_results)

    def _infer_if_else_cond(
        self,
        interp_: interp.AbstractInterpreter,
        frame: interp.AbstractFrame,
        stmt: IfElse,
        body: ir.Region,
    ):
        """Analyze one branch region of `stmt`.

        There are two cases:

        - If the branch's entry block ends in `func.Return`, that block is
          pushed onto the enclosing frame's worklist, with the condition as
          its argument, and `None` is returned.
        - Otherwise the region runs in a fresh child frame seeded with the
          parent's entries. The child's entries are copied back into the
          parent afterwards, and the region's result is returned.
        """
        body_block = body.blocks[0]
        body_term = body_block.last_stmt
        if isinstance(body_term, func.Return):
            frame.worklist.append(interp.Successor(body_block, frame.get(stmt.cond)))
            return

        with interp_.state.new_frame(interp_.new_frame(stmt)) as body_frame:
            body_frame.entries.update(frame.entries)
            # The branch's entry block takes the condition as its first argument.
            body_frame.set(body_block.args[0], frame.get(stmt.cond))
            ret = interp_.run_ssacfg_region(body_frame, body)
        frame.entries.update(body_frame.entries)
        return ret
@@ -0,0 +1,140 @@
1
+ from collections.abc import Iterable
2
+
3
+ from kirin import ir, interp
4
+ from kirin.analysis import const
5
+ from kirin.dialects import func
6
+
7
+ from .stmts import For, Yield, IfElse
8
+ from ._dialect import dialect
9
+
10
+ # NOTE: unlike concrete interpreter, we need to use a new frame
11
+ # for each iteration because otherwise join two constant values
12
+ # will result in bottom (error) element.
13
+
14
+
15
+ @dialect.register(key="constprop")
16
+ class DialectConstProp(interp.MethodTable):
17
+
18
+ @interp.impl(Yield)
19
+ def yield_stmt(
20
+ self,
21
+ interp_: const.Propagate,
22
+ frame: const.Frame,
23
+ stmt: Yield,
24
+ ):
25
+ return interp.YieldValue(frame.get_values(stmt.values))
26
+
27
+ @interp.impl(IfElse)
28
+ def if_else(
29
+ self,
30
+ interp_: const.Propagate,
31
+ frame: const.Frame,
32
+ stmt: IfElse,
33
+ ):
34
+ cond = frame.get(stmt.cond)
35
+ if isinstance(cond, const.Value):
36
+ if cond.data:
37
+ body = stmt.then_body
38
+ else:
39
+ body = stmt.else_body
40
+ body_frame, ret = self._prop_const_cond_ifelse(
41
+ interp_, frame, stmt, cond, body
42
+ )
43
+ frame.entries.update(body_frame.entries)
44
+ if not body_frame.frame_is_not_pure and not isinstance(
45
+ body.blocks[0].last_stmt, func.Return
46
+ ):
47
+ frame.should_be_pure.add(stmt)
48
+ return ret
49
+ else:
50
+ then_frame, then_results = self._prop_const_cond_ifelse(
51
+ interp_, frame, stmt, const.Value(True), stmt.then_body
52
+ )
53
+ else_frame, else_results = self._prop_const_cond_ifelse(
54
+ interp_, frame, stmt, const.Value(False), stmt.else_body
55
+ )
56
+ # NOTE: then_frame and else_frame do not change
57
+ # parent frame variables value except cond
58
+ frame.entries.update(then_frame.entries)
59
+ frame.entries.update(else_frame.entries)
60
+ # TODO: pick the non-return value
61
+ if isinstance(then_results, interp.ReturnValue) and isinstance(
62
+ else_results, interp.ReturnValue
63
+ ):
64
+ return interp.ReturnValue(then_results.value.join(else_results.value))
65
+ elif isinstance(then_results, interp.ReturnValue):
66
+ ret = else_results
67
+ elif isinstance(else_results, interp.ReturnValue):
68
+ ret = then_results
69
+ else:
70
+ if not (
71
+ then_frame.frame_is_not_pure is True
72
+ or else_frame.frame_is_not_pure is True
73
+ ):
74
+ frame.should_be_pure.add(stmt)
75
+ ret = interp_.join_results(then_results, else_results)
76
+ return ret
77
+
78
+ def _prop_const_cond_ifelse(
79
+ self,
80
+ interp_: const.Propagate,
81
+ frame: const.Frame,
82
+ stmt: IfElse,
83
+ cond: const.Value,
84
+ body: ir.Region,
85
+ ):
86
+ with interp_.state.new_frame(interp_.new_frame(stmt)) as body_frame:
87
+ body_frame.entries.update(frame.entries)
88
+ body_frame.set(body.blocks[0].args[0], cond)
89
+ results = interp_.run_ssacfg_region(body_frame, body)
90
+ return body_frame, results
91
+
92
+ @interp.impl(For)
93
+ def for_loop(
94
+ self,
95
+ interp_: const.Propagate,
96
+ frame: const.Frame,
97
+ stmt: For,
98
+ ):
99
+ iterable = frame.get(stmt.iterable)
100
+ if isinstance(iterable, const.Value):
101
+ return self._prop_const_iterable_forloop(interp_, frame, stmt, iterable)
102
+ else: # TODO: support other iteration
103
+ return tuple(interp_.lattice.top() for _ in stmt.results)
104
+
105
+ def _prop_const_iterable_forloop(
106
+ self,
107
+ interp_: const.Propagate,
108
+ frame: const.Frame,
109
+ stmt: For,
110
+ iterable: const.Value,
111
+ ):
112
+ frame_is_not_pure = False
113
+ if not isinstance(iterable.data, Iterable):
114
+ raise interp.InterpreterError(
115
+ f"Expected iterable, got {type(iterable.data)}"
116
+ )
117
+
118
+ loop_vars = frame.get_values(stmt.initializers)
119
+ body_block = stmt.body.blocks[0]
120
+ block_args = body_block.args
121
+
122
+ for value in iterable.data:
123
+ with interp_.state.new_frame(interp_.new_frame(stmt)) as body_frame:
124
+ body_frame.entries.update(frame.entries)
125
+ body_frame.set_values(
126
+ block_args,
127
+ (const.Value(value),) + loop_vars,
128
+ )
129
+ loop_vars = interp_.run_ssacfg_region(body_frame, stmt.body)
130
+
131
+ if body_frame.frame_is_not_pure:
132
+ frame_is_not_pure = True
133
+ if loop_vars is None:
134
+ loop_vars = ()
135
+ elif isinstance(loop_vars, interp.ReturnValue):
136
+ return loop_vars
137
+
138
+ if not frame_is_not_pure:
139
+ frame.should_be_pure.add(stmt)
140
+ return loop_vars
@@ -0,0 +1,35 @@
1
+ from kirin import interp
2
+
3
+ from .stmts import For, Yield, IfElse
4
+ from ._dialect import dialect
5
+
6
+
7
@dialect.register
class Concrete(interp.MethodTable):
    """Concrete (Python-semantics) interpreter for `scf` statements."""

    @interp.impl(Yield)
    def yield_stmt(self, interp_: interp.Interpreter, frame: interp.Frame, stmt: Yield):
        """Hand the yielded values back to the region being run."""
        return interp.YieldValue(frame.get_values(stmt.values))

    @interp.impl(IfElse)
    def if_else(self, interp_: interp.Interpreter, frame: interp.Frame, stmt: IfElse):
        """Run the branch selected by the Python truthiness of the condition.

        The lowering gives each branch's entry block the condition as its
        argument, and the absint/constprop tables bind it before running a
        branch. This table binds it the same way; otherwise a branch that
        reads the condition variable would see an unbound value.
        """
        cond = frame.get(stmt.cond)
        body = stmt.then_body if cond else stmt.else_body
        frame.set_values(body.blocks[0].args, (cond,))
        return interp_.run_ssacfg_region(frame, body)

    @interp.impl(For)
    def for_loop(self, interpreter: interp.Interpreter, frame: interp.Frame, stmt: For):
        """Iterate over the runtime iterable, threading the loop-carried values.

        On each iteration the body's entry block receives
        `(item, *loop_vars)`. A `ReturnValue` from the body is propagated
        immediately, and a body that yields nothing carries `()` forward.

        Returns:
            The final loop-carried values, or the initializers if the
            iterable is empty.
        """
        iterable = frame.get(stmt.iterable)
        loop_vars = frame.get_values(stmt.initializers)
        block_args = stmt.body.blocks[0].args
        for value in iterable:
            frame.set_values(block_args, (value,) + loop_vars)
            loop_vars = interpreter.run_ssacfg_region(frame, stmt.body)
            if isinstance(loop_vars, interp.ReturnValue):
                return loop_vars
            elif loop_vars is None:
                loop_vars = ()
        return loop_vars
@@ -0,0 +1,123 @@
1
+ import ast
2
+
3
+ from kirin import ir, types, lowering
4
+ from kirin.exceptions import DialectLoweringError
5
+ from kirin.dialects.py.unpack import unpacking
6
+
7
+ from .stmts import For, Yield, IfElse
8
+ from ._dialect import dialect
9
+
10
+
11
@dialect.register
class Lowering(lowering.FromPythonAST):
    """Lower Python `if` and `for` statements into structured `scf` statements."""

    def lower_If(self, state: lowering.LoweringState, node: ast.If) -> lowering.Result:
        """Lower an `if` statement into an `IfElse`.

        Each branch is lowered into its own single-block region. The region's
        entry block takes the condition as a `Bool` argument, which is bound
        to the condition's name inside the branch when it has one.

        After the statement, each name that stays visible and was assigned in
        a branch is yielded from both branches. A branch that does not assign
        the name yields the enclosing frame's value. The name is then rebound
        to the corresponding `IfElse` result.

        Raises:
            DialectLoweringError: if the enclosing value of a yielded name
                cannot be found.
        """
        cond = state.visit(node.test).expect_one()
        frame = state.current_frame
        body_frame = lowering.Frame.from_stmts(node.body, state, globals=frame.globals)
        then_cond = body_frame.curr_block.args.append_from(types.Bool, cond.name)
        if cond.name:
            body_frame.defs[cond.name] = then_cond
        state.push_frame(body_frame)
        state.exhaust(body_frame)
        state.pop_frame(finalize_next=False)  # NOTE: scf does not have multiple blocks

        else_frame = lowering.Frame.from_stmts(
            node.orelse, state, globals=frame.globals
        )
        else_cond = else_frame.curr_block.args.append_from(types.Bool, cond.name)
        if cond.name:
            else_frame.defs[cond.name] = else_cond
        state.push_frame(else_frame)
        state.exhaust(else_frame)
        state.pop_frame(finalize_next=False)  # NOTE: scf does not have multiple blocks

        yield_names: list[str] = []
        body_yields: list[ir.SSAValue] = []
        else_yields: list[ir.SSAValue] = []

        def outer_value(name: str) -> ir.SSAValue:
            # Value of `name` in the enclosing frame, yielded by a branch
            # that does not reassign it.
            value = frame.get(name)
            if value is None:
                raise DialectLoweringError(f"expected value for {name}")
            return value

        def add_yield(name: str, then_value: ir.SSAValue, else_value: ir.SSAValue):
            yield_names.append(name)
            body_yields.append(then_value)
            else_yields.append(else_value)

        if node.orelse:
            for name in body_frame.defs.keys():
                if name in else_frame.defs:
                    add_yield(
                        name, body_frame.get_scope(name), else_frame.get_scope(name)
                    )
                elif name in frame.defs:
                    # Reassigned only in the `then` branch. If it were not
                    # yielded, the enclosing frame would keep the pre-`if`
                    # value even when the `then` branch runs.
                    add_yield(name, body_frame.get_scope(name), outer_value(name))
            for name in else_frame.defs.keys():
                # Symmetric case: reassigned only in the `else` branch.
                if name not in body_frame.defs and name in frame.defs:
                    add_yield(name, outer_value(name), else_frame.get_scope(name))
        else:
            for name in body_frame.defs.keys():
                if name in frame.defs:
                    add_yield(name, body_frame.get_scope(name), outer_value(name))

        if not (
            body_frame.curr_block.last_stmt
            and body_frame.curr_block.last_stmt.has_trait(ir.IsTerminator)
        ):
            body_frame.append_stmt(Yield(*body_yields))

        if not (
            else_frame.curr_block.last_stmt
            and else_frame.curr_block.last_stmt.has_trait(ir.IsTerminator)
        ):
            else_frame.append_stmt(Yield(*else_yields))

        stmt = IfElse(
            cond,
            then_body=body_frame.curr_region,
            else_body=else_frame.curr_region,
        )
        for result, name, body, else_ in zip(
            stmt.results, yield_names, body_yields, else_yields
        ):
            result.name = name
            result.type = body.type.join(else_.type)
            frame.defs[name] = result
        state.append_stmt(stmt)
        return lowering.Result()

    def lower_For(
        self, state: lowering.LoweringState, node: ast.For
    ) -> lowering.Result:
        """Lower a `for` statement into a `For`.

        The body becomes a single-block region. Its entry block takes the
        loop item followed by one argument per loop-carried name. A name is
        loop-carried when the body reads it from an enclosing frame (via
        `capture_callback`) or reassigns a name already defined in the
        current frame.

        The body yields the updated values and the `For` is initialized with
        the current values. Afterwards the `For` results are rebound to the
        loop-carried names.

        Raises:
            DialectLoweringError: if a captured value has no name, or if an
                initializer for a loop-carried name cannot be found.
        """
        iter_ = state.visit(node.iter).expect_one()
        outer_frame = state.current_frame

        yields: list[str] = []

        def new_block_arg_if_inside_loop(frame: lowering.Frame, capture: ir.SSAValue):
            if not capture.name:
                raise DialectLoweringError("unexpected loop variable captured")
            yields.append(capture.name)
            return frame.curr_block.args.append_from(capture.type, capture.name)

        body_frame = state.push_frame(
            lowering.Frame.from_stmts(
                node.body,
                state,
                globals=outer_frame.globals,
                capture_callback=new_block_arg_if_inside_loop,
            )
        )
        loop_var = body_frame.curr_block.args.append_from(types.Any)
        unpacking(state, node.target, loop_var)
        state.exhaust(body_frame)

        # Names of the current frame that the body reassigns without reading
        # are never captured. Without this step, the current frame would keep
        # their pre-loop values after the loop. Carry them explicitly; their
        # block argument is simply unused inside the body.
        for name in list(body_frame.defs.keys()):
            if name not in yields and name in outer_frame.defs:
                yields.append(name)
                body_frame.curr_block.args.append_from(
                    outer_frame.defs[name].type, name
                )

        # NOTE: this frame won't have phi nodes
        if yields and (
            body_frame.curr_block.last_stmt is None
            or not body_frame.curr_block.last_stmt.has_trait(ir.IsTerminator)
        ):
            body_frame.append_stmt(Yield(*[body_frame.defs[name] for name in yields]))  # type: ignore
        state.pop_frame(finalize_next=False)  # NOTE: scf does not have multiple blocks

        initializers: list[ir.SSAValue] = []
        for name in yields:
            value = state.current_frame.get(name)
            if value is None:
                raise DialectLoweringError(f"expected value for {name}")
            initializers.append(value)
        stmt = For(iter_, body_frame.curr_region, *initializers)
        for name, result in zip(yields, stmt.results):
            state.current_frame.defs[name] = result
            result.name = name
        state.append_stmt(stmt)
        return lowering.Result()