kirin-toolchain 0.13.0__py3-none-any.whl
Sign up to get free protection for your applications and to get access to all the features.
- kirin/__init__.py +7 -0
- kirin/analysis/__init__.py +24 -0
- kirin/analysis/callgraph.py +61 -0
- kirin/analysis/cfg.py +112 -0
- kirin/analysis/const/__init__.py +20 -0
- kirin/analysis/const/_visitor.py +2 -0
- kirin/analysis/const/_visitor.pyi +8 -0
- kirin/analysis/const/lattice.py +219 -0
- kirin/analysis/const/prop.py +116 -0
- kirin/analysis/forward.py +100 -0
- kirin/analysis/typeinfer/__init__.py +5 -0
- kirin/analysis/typeinfer/analysis.py +90 -0
- kirin/analysis/typeinfer/solve.py +141 -0
- kirin/decl/__init__.py +108 -0
- kirin/decl/base.py +65 -0
- kirin/decl/camel2snake.py +2 -0
- kirin/decl/emit/__init__.py +0 -0
- kirin/decl/emit/_create_fn.py +29 -0
- kirin/decl/emit/_set_new_attribute.py +22 -0
- kirin/decl/emit/dialect.py +8 -0
- kirin/decl/emit/init.py +277 -0
- kirin/decl/emit/name.py +10 -0
- kirin/decl/emit/property.py +182 -0
- kirin/decl/emit/repr.py +31 -0
- kirin/decl/emit/traits.py +13 -0
- kirin/decl/emit/typecheck.py +77 -0
- kirin/decl/emit/verify.py +51 -0
- kirin/decl/info.py +346 -0
- kirin/decl/scan_fields.py +157 -0
- kirin/decl/verify.py +69 -0
- kirin/dialects/__init__.py +14 -0
- kirin/dialects/_pprint_helper.py +53 -0
- kirin/dialects/cf/__init__.py +20 -0
- kirin/dialects/cf/constprop.py +51 -0
- kirin/dialects/cf/dialect.py +3 -0
- kirin/dialects/cf/emit.py +58 -0
- kirin/dialects/cf/interp.py +24 -0
- kirin/dialects/cf/stmts.py +68 -0
- kirin/dialects/cf/typeinfer.py +27 -0
- kirin/dialects/eltype.py +23 -0
- kirin/dialects/func/__init__.py +20 -0
- kirin/dialects/func/attrs.py +39 -0
- kirin/dialects/func/constprop.py +138 -0
- kirin/dialects/func/dialect.py +3 -0
- kirin/dialects/func/emit.py +80 -0
- kirin/dialects/func/interp.py +68 -0
- kirin/dialects/func/stmts.py +233 -0
- kirin/dialects/func/typeinfer.py +124 -0
- kirin/dialects/ilist/__init__.py +33 -0
- kirin/dialects/ilist/_dialect.py +3 -0
- kirin/dialects/ilist/_wrapper.py +51 -0
- kirin/dialects/ilist/interp.py +85 -0
- kirin/dialects/ilist/lowering.py +25 -0
- kirin/dialects/ilist/passes.py +32 -0
- kirin/dialects/ilist/rewrite/__init__.py +3 -0
- kirin/dialects/ilist/rewrite/const.py +45 -0
- kirin/dialects/ilist/rewrite/list.py +38 -0
- kirin/dialects/ilist/rewrite/unroll.py +131 -0
- kirin/dialects/ilist/runtime.py +63 -0
- kirin/dialects/ilist/stmts.py +102 -0
- kirin/dialects/ilist/typeinfer.py +120 -0
- kirin/dialects/lowering/__init__.py +7 -0
- kirin/dialects/lowering/call.py +48 -0
- kirin/dialects/lowering/cf.py +206 -0
- kirin/dialects/lowering/func.py +134 -0
- kirin/dialects/math/__init__.py +41 -0
- kirin/dialects/math/_gen.py +176 -0
- kirin/dialects/math/dialect.py +3 -0
- kirin/dialects/math/interp.py +190 -0
- kirin/dialects/math/stmts.py +369 -0
- kirin/dialects/module.py +139 -0
- kirin/dialects/py/__init__.py +40 -0
- kirin/dialects/py/assertion.py +91 -0
- kirin/dialects/py/assign.py +103 -0
- kirin/dialects/py/attr.py +59 -0
- kirin/dialects/py/base.py +34 -0
- kirin/dialects/py/binop/__init__.py +23 -0
- kirin/dialects/py/binop/_dialect.py +3 -0
- kirin/dialects/py/binop/interp.py +60 -0
- kirin/dialects/py/binop/julia.py +33 -0
- kirin/dialects/py/binop/lowering.py +22 -0
- kirin/dialects/py/binop/stmts.py +79 -0
- kirin/dialects/py/binop/typeinfer.py +108 -0
- kirin/dialects/py/boolop.py +84 -0
- kirin/dialects/py/builtin.py +78 -0
- kirin/dialects/py/cmp/__init__.py +16 -0
- kirin/dialects/py/cmp/_dialect.py +3 -0
- kirin/dialects/py/cmp/interp.py +48 -0
- kirin/dialects/py/cmp/julia.py +33 -0
- kirin/dialects/py/cmp/lowering.py +45 -0
- kirin/dialects/py/cmp/stmts.py +62 -0
- kirin/dialects/py/constant.py +79 -0
- kirin/dialects/py/indexing.py +251 -0
- kirin/dialects/py/iterable.py +90 -0
- kirin/dialects/py/len.py +57 -0
- kirin/dialects/py/list/__init__.py +15 -0
- kirin/dialects/py/list/_dialect.py +3 -0
- kirin/dialects/py/list/interp.py +21 -0
- kirin/dialects/py/list/lowering.py +25 -0
- kirin/dialects/py/list/stmts.py +22 -0
- kirin/dialects/py/list/typeinfer.py +54 -0
- kirin/dialects/py/range.py +76 -0
- kirin/dialects/py/slice.py +120 -0
- kirin/dialects/py/tuple.py +109 -0
- kirin/dialects/py/unary/__init__.py +24 -0
- kirin/dialects/py/unary/_dialect.py +3 -0
- kirin/dialects/py/unary/constprop.py +20 -0
- kirin/dialects/py/unary/interp.py +24 -0
- kirin/dialects/py/unary/julia.py +21 -0
- kirin/dialects/py/unary/lowering.py +22 -0
- kirin/dialects/py/unary/stmts.py +33 -0
- kirin/dialects/py/unary/typeinfer.py +23 -0
- kirin/dialects/py/unpack.py +90 -0
- kirin/dialects/scf/__init__.py +23 -0
- kirin/dialects/scf/_dialect.py +3 -0
- kirin/dialects/scf/absint.py +64 -0
- kirin/dialects/scf/constprop.py +140 -0
- kirin/dialects/scf/interp.py +35 -0
- kirin/dialects/scf/lowering.py +123 -0
- kirin/dialects/scf/stmts.py +250 -0
- kirin/dialects/scf/trim.py +36 -0
- kirin/dialects/scf/typeinfer.py +58 -0
- kirin/dialects/scf/unroll.py +92 -0
- kirin/emit/__init__.py +3 -0
- kirin/emit/abc.py +89 -0
- kirin/emit/abc.pyi +38 -0
- kirin/emit/exceptions.py +5 -0
- kirin/emit/julia.py +63 -0
- kirin/emit/str.py +51 -0
- kirin/exceptions.py +59 -0
- kirin/graph.py +34 -0
- kirin/idtable.py +57 -0
- kirin/interp/__init__.py +39 -0
- kirin/interp/abstract.py +253 -0
- kirin/interp/base.py +438 -0
- kirin/interp/concrete.py +62 -0
- kirin/interp/exceptions.py +26 -0
- kirin/interp/frame.py +151 -0
- kirin/interp/impl.py +197 -0
- kirin/interp/result.py +93 -0
- kirin/interp/state.py +71 -0
- kirin/interp/table.py +40 -0
- kirin/interp/value.py +73 -0
- kirin/ir/__init__.py +46 -0
- kirin/ir/attrs/__init__.py +20 -0
- kirin/ir/attrs/_types.py +8 -0
- kirin/ir/attrs/_types.pyi +13 -0
- kirin/ir/attrs/abc.py +46 -0
- kirin/ir/attrs/py.py +45 -0
- kirin/ir/attrs/types.py +522 -0
- kirin/ir/dialect.py +125 -0
- kirin/ir/group.py +249 -0
- kirin/ir/method.py +118 -0
- kirin/ir/nodes/__init__.py +7 -0
- kirin/ir/nodes/base.py +149 -0
- kirin/ir/nodes/block.py +458 -0
- kirin/ir/nodes/region.py +337 -0
- kirin/ir/nodes/stmt.py +713 -0
- kirin/ir/nodes/view.py +142 -0
- kirin/ir/ssa.py +204 -0
- kirin/ir/traits/__init__.py +36 -0
- kirin/ir/traits/abc.py +42 -0
- kirin/ir/traits/basic.py +78 -0
- kirin/ir/traits/callable.py +51 -0
- kirin/ir/traits/lowering/__init__.py +2 -0
- kirin/ir/traits/lowering/call.py +37 -0
- kirin/ir/traits/lowering/context.py +120 -0
- kirin/ir/traits/region/__init__.py +2 -0
- kirin/ir/traits/region/ssacfg.py +22 -0
- kirin/ir/traits/symbol.py +57 -0
- kirin/ir/use.py +17 -0
- kirin/lattice/__init__.py +13 -0
- kirin/lattice/abc.py +128 -0
- kirin/lattice/empty.py +25 -0
- kirin/lattice/mixin.py +51 -0
- kirin/lowering/__init__.py +7 -0
- kirin/lowering/binding.py +65 -0
- kirin/lowering/core.py +72 -0
- kirin/lowering/dialect.py +35 -0
- kirin/lowering/dialect.pyi +183 -0
- kirin/lowering/frame.py +171 -0
- kirin/lowering/result.py +68 -0
- kirin/lowering/state.py +441 -0
- kirin/lowering/stream.py +53 -0
- kirin/passes/__init__.py +3 -0
- kirin/passes/abc.py +44 -0
- kirin/passes/aggressive/__init__.py +1 -0
- kirin/passes/aggressive/fold.py +43 -0
- kirin/passes/fold.py +45 -0
- kirin/passes/inline.py +25 -0
- kirin/passes/typeinfer.py +25 -0
- kirin/prelude.py +197 -0
- kirin/print/__init__.py +15 -0
- kirin/print/printable.py +141 -0
- kirin/print/printer.py +415 -0
- kirin/py.typed +0 -0
- kirin/registry.py +105 -0
- kirin/registry.pyi +52 -0
- kirin/rewrite/__init__.py +14 -0
- kirin/rewrite/abc.py +43 -0
- kirin/rewrite/aggressive/__init__.py +1 -0
- kirin/rewrite/aggressive/fold.py +43 -0
- kirin/rewrite/alias.py +16 -0
- kirin/rewrite/apply_type.py +47 -0
- kirin/rewrite/call2invoke.py +34 -0
- kirin/rewrite/chain.py +39 -0
- kirin/rewrite/compactify.py +288 -0
- kirin/rewrite/cse.py +48 -0
- kirin/rewrite/dce.py +19 -0
- kirin/rewrite/fixpoint.py +34 -0
- kirin/rewrite/fold.py +57 -0
- kirin/rewrite/getfield.py +21 -0
- kirin/rewrite/getitem.py +37 -0
- kirin/rewrite/inline.py +143 -0
- kirin/rewrite/result.py +15 -0
- kirin/rewrite/walk.py +83 -0
- kirin/rewrite/wrap_const.py +55 -0
- kirin/source.py +21 -0
- kirin/symbol_table.py +27 -0
- kirin/types.py +34 -0
- kirin/worklist.py +30 -0
- kirin_toolchain-0.13.0.dist-info/METADATA +42 -0
- kirin_toolchain-0.13.0.dist-info/RECORD +225 -0
- kirin_toolchain-0.13.0.dist-info/WHEEL +4 -0
- kirin_toolchain-0.13.0.dist-info/licenses/LICENSE +234 -0
@@ -0,0 +1,90 @@
|
|
1
|
+
"""The unpack dialect for Python.
|
2
|
+
|
3
|
+
This module contains the dialect for the Python unpack semantics, including:
|
4
|
+
|
5
|
+
- The `Unpack` statement class.
|
6
|
+
- The lowering pass for the unpack statement.
|
7
|
+
- The concrete implementation of the unpack statement.
|
8
|
+
- The type inference implementation of the unpack statement.
|
9
|
+
- A helper function `unpacking` for unpacking Python AST nodes during lowering.
|
10
|
+
"""
|
11
|
+
|
12
|
+
import ast
|
13
|
+
|
14
|
+
from kirin import ir, types, interp, lowering
|
15
|
+
from kirin.decl import info, statement
|
16
|
+
from kirin.print import Printer
|
17
|
+
from kirin.exceptions import DialectLoweringError
|
18
|
+
|
19
|
+
# The ``py.unpack`` dialect; the statement and method tables in this module register with it.
dialect = ir.Dialect("py.unpack")
|
20
|
+
|
21
|
+
|
22
|
+
@statement(dialect=dialect, init=False)
class Unpack(ir.Statement):
    """Destructure one SSA value into ``len(names)`` results.

    Every result is created with type ``types.Any``; a result whose entry in
    ``names`` is not ``None`` is given that name.
    """

    value: ir.SSAValue = info.argument(types.Any)
    # One entry per result. The ``unpacking`` lowering helper uses ``None``
    # for targets that are not plain names (they are unpacked further).
    names: tuple[str | None, ...] = info.attribute()

    def __init__(self, value: ir.SSAValue, names: tuple[str | None, ...]):
        n_targets = len(names)
        super().__init__(
            args=(value,),
            args_slice={"value": 0},
            result_types=[types.Any for _ in range(n_targets)],
            attributes={"names": ir.PyAttr(names)},
        )
        # Propagate the target names onto the freshly created results.
        for ssa_result, target_name in zip(self.results, names):
            ssa_result.name = target_name

    def print_impl(self, printer: Printer) -> None:
        """Print the statement name followed by its unpacked operand."""
        printer.print_name(self)
        printer.plain_print(" ")
        printer.print(self.value)
|
42
|
+
|
43
|
+
|
44
|
+
@dialect.register
class Concrete(interp.MethodTable):
    """Concrete (execution) semantics for ``Unpack``."""

    @interp.impl(Unpack)
    def unpack(self, interp: interp.Interpreter, frame: interp.Frame, stmt: Unpack):
        """Unpack the runtime value of ``stmt.value`` into one value per result.

        Follows Python's iterable-unpacking rules, including the arity check
        performed by ``a, b = xs``.

        Returns:
            A tuple with exactly ``len(stmt.results)`` elements.

        Raises:
            ValueError: if the iterable produces more or fewer items than there
                are unpack targets.
            TypeError: if the value is not iterable (raised by ``tuple``).
        """
        values = tuple(frame.get(stmt.value))
        expected = len(stmt.results)
        # Check arity explicitly so the interpreter never receives a tuple whose
        # length differs from the statement's results.
        if len(values) > expected:
            raise ValueError(f"too many values to unpack (expected {expected})")
        if len(values) < expected:
            raise ValueError(
                f"not enough values to unpack (expected {expected}, got {len(values)})"
            )
        return values
|
50
|
+
|
51
|
+
|
52
|
+
@dialect.register(key="typeinfer")
class TypeInfer(interp.MethodTable):
    """Type-inference rule for ``Unpack``."""

    @interp.impl(Unpack)
    def unpack(self, interp, frame: interp.Frame[types.TypeAttribute], stmt: Unpack):
        """Return the inferred types of the unpack targets.

        For a value whose type is a ``Generic`` subtype of ``types.Tuple`` the
        tuple's element types are used; when the tuple type has a vararg, its
        element type fills every target beyond the fixed elements. Any other
        value type yields ``types.Any`` for every target.
        """
        source_type = frame.get(stmt.value)
        if not (
            isinstance(source_type, types.Generic)
            and source_type.is_subseteq(types.Tuple)
        ):
            # TODO: support unpacking other types
            return tuple(types.Any for _ in stmt.names)

        if not source_type.vararg:
            # NOTE(review): returned as-is, even if its length differs from the
            # number of targets.
            return source_type.vars

        fixed_part = tuple(source_type.vars)
        remaining_targets = stmt.names[len(source_type.vars) :]
        return fixed_part + tuple(source_type.vararg.typ for _ in remaining_targets)
|
66
|
+
|
67
|
+
|
68
|
+
def unpacking(state: lowering.LoweringState, node: ast.expr, value: ir.SSAValue):
    """Bind the assignment/loop target ``node`` to ``value`` during lowering.

    * ``ast.Name``: ``value`` is bound to the name in the current frame and
      renamed after it.
    * ``ast.Tuple`` / ``ast.List``: an ``Unpack`` statement is emitted with one
      result per element. Name elements are bound directly; other elements are
      unpacked recursively from their corresponding result.

    Raises:
        DialectLoweringError: for starred targets (``a, *b = x``) and any other
            unsupported target node.
    """
    if isinstance(node, ast.Name):
        state.current_frame.defs[node.id] = value
        value.name = node.id
        return
    # Python treats list targets exactly like tuple targets.
    elif not isinstance(node, (ast.Tuple, ast.List)):
        raise DialectLoweringError(f"unsupported unpack node {node}")

    # Reject starred targets before emitting an Unpack of the wrong arity.
    if any(isinstance(item, ast.Starred) for item in node.elts):
        raise DialectLoweringError("starred unpack targets are not supported")

    names: list[str | None] = [
        item.id if isinstance(item, ast.Name) else None for item in node.elts
    ]
    stmt = state.append_stmt(Unpack(value, tuple(names)))
    for name, result in zip(names, stmt.results):
        if name is not None:
            state.current_frame.defs[name] = result

    # Nested targets (e.g. ``a, (b, c) = x``) are unpacked after all direct
    # names of this level have been bound.
    for item, result in zip(node.elts, stmt.results):
        if not isinstance(item, ast.Name):
            unpacking(state, item, result)
|
@@ -0,0 +1,23 @@
|
|
1
|
+
"""A Python-like structural Control Flow dialect.
|
2
|
+
|
3
|
+
This dialect provides constructs for expressing control flow in a structured
|
4
|
+
manner. The dialect provides constructs for expressing loops and conditionals.
|
5
|
+
Unlike the MLIR SCF dialect, this dialect does not restrict control flow to
|
6
|
+
statically analyzable forms. This dialect is designed to be compatible with
|
7
|
+
Python native control flow constructs.
|
8
|
+
|
9
|
+
This dialect depends on the following dialects:
|
10
|
+
- `eltype`: for obtaining the element type of a value.
|
11
|
+
"""
|
12
|
+
|
13
|
+
from . import (
|
14
|
+
trim as trim,
|
15
|
+
absint as absint,
|
16
|
+
interp as interp,
|
17
|
+
unroll as unroll,
|
18
|
+
lowering as lowering,
|
19
|
+
constprop as constprop,
|
20
|
+
typeinfer as typeinfer,
|
21
|
+
)
|
22
|
+
from .stmts import For as For, Yield as Yield, IfElse as IfElse
|
23
|
+
from ._dialect import dialect as dialect
|
@@ -0,0 +1,64 @@
|
|
1
|
+
from kirin import ir, interp
|
2
|
+
from kirin.analysis import const
|
3
|
+
from kirin.dialects import func
|
4
|
+
|
5
|
+
from .stmts import Yield, IfElse
|
6
|
+
from ._dialect import dialect
|
7
|
+
|
8
|
+
|
9
|
+
@dialect.register(key="absint")
class Methods(interp.MethodTable):
    """Abstract-interpretation rules for ``Yield`` and ``IfElse``, registered
    under the ``"absint"`` key."""

    @interp.impl(Yield)
    def yield_stmt(
        self,
        interp_: interp.AbstractInterpreter,
        frame: interp.AbstractFrame,
        stmt: Yield,
    ):
        # Forward the abstract values of the yielded operands to the parent statement.
        return interp.YieldValue(frame.get_values(stmt.values))

    @interp.impl(IfElse)
    def if_else(
        self,
        interp_: interp.AbstractInterpreter,
        frame: interp.AbstractFrame,
        stmt: IfElse,
    ):
        """Abstractly evaluate an ``IfElse``.

        If the condition carries a ``"const"`` hint holding a ``const.Value``,
        only the branch selected by that value is analysed. Otherwise both
        branches are analysed and their results are combined.
        """
        if isinstance(hint := stmt.cond.hints.get("const"), const.Value):
            if hint.data:
                return self._infer_if_else_cond(interp_, frame, stmt, stmt.then_body)
            else:
                return self._infer_if_else_cond(interp_, frame, stmt, stmt.else_body)
        then_results = self._infer_if_else_cond(interp_, frame, stmt, stmt.then_body)
        else_results = self._infer_if_else_cond(interp_, frame, stmt, stmt.else_body)

        # NOTE(review): when exactly one branch produces a ReturnValue, that
        # ReturnValue is propagated and the other branch's results are dropped.
        # The "constprop" IfElse rule keeps the non-returning branch's results
        # instead; confirm which behaviour is intended here.
        match (then_results, else_results):
            case (interp.ReturnValue(then_value), interp.ReturnValue(else_value)):
                # Both branches return: join the returned abstract values.
                return interp.ReturnValue(then_value.join(else_value))
            case (interp.ReturnValue(then_value), _):
                return then_results
            case (_, interp.ReturnValue(else_value)):
                return else_results
            case _:
                return interp_.join_results(then_results, else_results)

    def _infer_if_else_cond(
        self,
        interp_: interp.AbstractInterpreter,
        frame: interp.AbstractFrame,
        stmt: IfElse,
        body: ir.Region,
    ):
        """Analyse one branch region of ``stmt``.

        Returns ``None`` when the branch's first block ends in ``func.Return``
        (that block is scheduled on ``frame``'s worklist instead). Otherwise,
        returns the result of running ``body`` in a fresh frame.
        """
        body_block = body.blocks[0]
        body_term = body_block.last_stmt
        if isinstance(body_term, func.Return):
            # Schedule the returning block on the enclosing frame's worklist,
            # passing the condition's abstract value along (presumably as the
            # block argument).
            frame.worklist.append(interp.Successor(body_block, frame.get(stmt.cond)))
            return

        with interp_.state.new_frame(interp_.new_frame(stmt)) as body_frame:
            # Seed the branch frame with the parent's values and bind the
            # block's first argument to the condition.
            body_frame.entries.update(frame.entries)
            body_frame.set(body_block.args[0], frame.get(stmt.cond))
            ret = interp_.run_ssacfg_region(body_frame, body)
        # Copy values computed inside the branch back to the parent frame.
        frame.entries.update(body_frame.entries)
        return ret
|
@@ -0,0 +1,140 @@
|
|
1
|
+
from collections.abc import Iterable
|
2
|
+
|
3
|
+
from kirin import ir, interp
|
4
|
+
from kirin.analysis import const
|
5
|
+
from kirin.dialects import func
|
6
|
+
|
7
|
+
from .stmts import For, Yield, IfElse
|
8
|
+
from ._dialect import dialect
|
9
|
+
|
10
|
+
# NOTE: unlike concrete interpreter, we need to use a new frame
|
11
|
+
# for each iteration because otherwise joining two constant values
# would result in the bottom (error) element.
|
13
|
+
|
14
|
+
|
15
|
+
@dialect.register(key="constprop")
class DialectConstProp(interp.MethodTable):
    """Constant-propagation rules for ``scf`` statements.

    Besides computing lattice values, the ``IfElse`` and ``For`` rules add the
    statement to ``frame.should_be_pure`` when none of the analysed body frames
    was flagged via ``frame_is_not_pure`` (see each rule for the exact
    conditions).
    """

    @interp.impl(Yield)
    def yield_stmt(
        self,
        interp_: const.Propagate,
        frame: const.Frame,
        stmt: Yield,
    ):
        # Forward the lattice values of the yielded operands to the parent.
        return interp.YieldValue(frame.get_values(stmt.values))

    @interp.impl(IfElse)
    def if_else(
        self,
        interp_: const.Propagate,
        frame: const.Frame,
        stmt: IfElse,
    ):
        """Propagate constants through an ``IfElse``.

        With a constant condition, only the taken branch is analysed. Otherwise
        both branches are analysed, with the condition bound to
        ``const.Value(True)`` and ``const.Value(False)`` respectively, and their
        results are combined.
        """
        cond = frame.get(stmt.cond)
        if isinstance(cond, const.Value):
            if cond.data:
                body = stmt.then_body
            else:
                body = stmt.else_body
            body_frame, ret = self._prop_const_cond_ifelse(
                interp_, frame, stmt, cond, body
            )
            # Make values computed inside the branch visible to the parent frame.
            frame.entries.update(body_frame.entries)
            # Recorded as pure only if the branch frame was not flagged impure
            # and the branch does not end in func.Return.
            if not body_frame.frame_is_not_pure and not isinstance(
                body.blocks[0].last_stmt, func.Return
            ):
                frame.should_be_pure.add(stmt)
            return ret
        else:
            then_frame, then_results = self._prop_const_cond_ifelse(
                interp_, frame, stmt, const.Value(True), stmt.then_body
            )
            else_frame, else_results = self._prop_const_cond_ifelse(
                interp_, frame, stmt, const.Value(False), stmt.else_body
            )
            # NOTE: then_frame and else_frame do not change
            # parent frame variables value except cond
            frame.entries.update(then_frame.entries)
            frame.entries.update(else_frame.entries)
            # TODO: pick the non-return value
            if isinstance(then_results, interp.ReturnValue) and isinstance(
                else_results, interp.ReturnValue
            ):
                # Both branches return: join the returned lattice values.
                return interp.ReturnValue(then_results.value.join(else_results.value))
            elif isinstance(then_results, interp.ReturnValue):
                # Only the else-branch falls through; its results are used.
                ret = else_results
            elif isinstance(else_results, interp.ReturnValue):
                # Only the then-branch falls through; its results are used.
                ret = then_results
            else:
                # Neither branch returns: pure if neither frame was flagged
                # impure, and the results are joined.
                if not (
                    then_frame.frame_is_not_pure is True
                    or else_frame.frame_is_not_pure is True
                ):
                    frame.should_be_pure.add(stmt)
                ret = interp_.join_results(then_results, else_results)
            return ret

    def _prop_const_cond_ifelse(
        self,
        interp_: const.Propagate,
        frame: const.Frame,
        stmt: IfElse,
        cond: const.Value,
        body: ir.Region,
    ):
        """Run one branch region in a fresh frame.

        The frame is seeded with ``frame``'s entries, and the block's first
        argument is bound to ``cond``. Returns ``(body_frame, results)``.
        """
        with interp_.state.new_frame(interp_.new_frame(stmt)) as body_frame:
            body_frame.entries.update(frame.entries)
            body_frame.set(body.blocks[0].args[0], cond)
            results = interp_.run_ssacfg_region(body_frame, body)
            return body_frame, results

    @interp.impl(For)
    def for_loop(
        self,
        interp_: const.Propagate,
        frame: const.Frame,
        stmt: For,
    ):
        """Propagate constants through a ``For``.

        A constant iterable is unrolled iteration by iteration; any other
        iterable makes every result the lattice top.
        """
        iterable = frame.get(stmt.iterable)
        if isinstance(iterable, const.Value):
            return self._prop_const_iterable_forloop(interp_, frame, stmt, iterable)
        else:  # TODO: support other iteration
            return tuple(interp_.lattice.top() for _ in stmt.results)

    def _prop_const_iterable_forloop(
        self,
        interp_: const.Propagate,
        frame: const.Frame,
        stmt: For,
        iterable: const.Value,
    ):
        """Unroll ``stmt`` over the Python iterable held by ``iterable``.

        Each iteration runs in a fresh frame seeded with ``frame``'s entries. The
        body block's arguments receive ``(const.Value(element), *loop_vars)``.
        A ``ReturnValue`` from any iteration is returned immediately.

        NOTE(review): unlike ``if_else``, entries computed in the iteration
        frames are not copied back into ``frame``; confirm this is intended.

        Raises:
            interp.InterpreterError: if ``iterable.data`` is not iterable.
        """
        # Becomes True as soon as any iteration frame is flagged impure.
        frame_is_not_pure = False
        if not isinstance(iterable.data, Iterable):
            raise interp.InterpreterError(
                f"Expected iterable, got {type(iterable.data)}"
            )

        loop_vars = frame.get_values(stmt.initializers)
        body_block = stmt.body.blocks[0]
        block_args = body_block.args

        for value in iterable.data:
            with interp_.state.new_frame(interp_.new_frame(stmt)) as body_frame:
                body_frame.entries.update(frame.entries)
                body_frame.set_values(
                    block_args,
                    (const.Value(value),) + loop_vars,
                )
                loop_vars = interp_.run_ssacfg_region(body_frame, stmt.body)

            if body_frame.frame_is_not_pure:
                frame_is_not_pure = True
            # A body that yields nothing carries no loop variables forward.
            if loop_vars is None:
                loop_vars = ()
            elif isinstance(loop_vars, interp.ReturnValue):
                return loop_vars

        if not frame_is_not_pure:
            frame.should_be_pure.add(stmt)
        return loop_vars
|
@@ -0,0 +1,35 @@
|
|
1
|
+
from kirin import interp
|
2
|
+
|
3
|
+
from .stmts import For, Yield, IfElse
|
4
|
+
from ._dialect import dialect
|
5
|
+
|
6
|
+
|
7
|
+
@dialect.register
class Concrete(interp.MethodTable):
    """Concrete (execution) semantics for the ``scf`` statements."""

    @interp.impl(Yield)
    def yield_stmt(self, interp_: interp.Interpreter, frame: interp.Frame, stmt: Yield):
        # Hand the yielded operand values back to the enclosing statement.
        return interp.YieldValue(frame.get_values(stmt.values))

    @interp.impl(IfElse)
    def if_else(self, interp_: interp.Interpreter, frame: interp.Frame, stmt: IfElse):
        """Run the branch selected by the truthiness of ``stmt.cond``.

        The lowering gives each branch block the condition as its first block
        argument, and the abstract/const-prop rules bind it before running the
        branch. It is bound here as well, so branch code that refers to the
        condition sees its runtime value.
        """
        cond = frame.get(stmt.cond)
        body = stmt.then_body if cond else stmt.else_body
        body_block = body.blocks[0]
        if body_block.args:
            frame.set(body_block.args[0], cond)
        return interp_.run_ssacfg_region(frame, body)

    @interp.impl(For)
    def for_loop(self, interpreter: interp.Interpreter, frame: interp.Frame, stmt: For):
        """Execute ``stmt`` over the runtime iterable.

        The body block's arguments receive ``(element, *loop_vars)`` on each
        iteration, and the body's yielded values become the next loop
        variables. Returns the final loop variables (the initializers for an
        empty iterable), or the ``ReturnValue`` from the first iteration that
        returns.
        """
        iterable = frame.get(stmt.iterable)
        loop_vars = frame.get_values(stmt.initializers)
        block_args = stmt.body.blocks[0].args
        for value in iterable:
            frame.set_values(block_args, (value,) + loop_vars)
            loop_vars = interpreter.run_ssacfg_region(frame, stmt.body)
            if isinstance(loop_vars, interp.ReturnValue):
                return loop_vars
            elif loop_vars is None:
                # A body that yields nothing carries no loop variables forward.
                loop_vars = ()
        return loop_vars
|
@@ -0,0 +1,123 @@
|
|
1
|
+
import ast
|
2
|
+
|
3
|
+
from kirin import ir, types, lowering
|
4
|
+
from kirin.exceptions import DialectLoweringError
|
5
|
+
from kirin.dialects.py.unpack import unpacking
|
6
|
+
|
7
|
+
from .stmts import For, Yield, IfElse
|
8
|
+
from ._dialect import dialect
|
9
|
+
|
10
|
+
|
11
|
+
@dialect.register
class Lowering(lowering.FromPythonAST):
    """Lower Python ``if`` and ``for`` statements to ``IfElse`` / ``For``."""

    def lower_If(self, state: lowering.LoweringState, node: ast.If) -> lowering.Result:
        """Lower an ``if`` statement into an ``IfElse``.

        The ``IfElse`` gets one result per variable that must be visible
        after the statement:

        * a name assigned in both branches, or
        * a name assigned in one branch that is already defined in the
          enclosing frame. The other branch then yields the enclosing value.

        Raises:
            DialectLoweringError: if an enclosing-frame name has no value.
        """
        cond = state.visit(node.test).expect_one()
        frame = state.current_frame
        body_frame = lowering.Frame.from_stmts(node.body, state, globals=frame.globals)
        # Each branch block takes the condition as its first block argument; a
        # named condition is rebound to that argument inside the branch.
        then_cond = body_frame.curr_block.args.append_from(types.Bool, cond.name)
        if cond.name:
            body_frame.defs[cond.name] = then_cond
        state.push_frame(body_frame)
        state.exhaust(body_frame)
        state.pop_frame(finalize_next=False)  # NOTE: scf does not have multiple blocks

        else_frame = lowering.Frame.from_stmts(
            node.orelse, state, globals=frame.globals
        )
        else_cond = else_frame.curr_block.args.append_from(types.Bool, cond.name)
        if cond.name:
            else_frame.defs[cond.name] = else_cond
        state.push_frame(else_frame)
        state.exhaust(else_frame)
        state.pop_frame(finalize_next=False)  # NOTE: scf does not have multiple blocks

        def outer_value(name: str) -> ir.SSAValue:
            # Value of `name` in the enclosing frame, yielded by a branch that
            # does not reassign it.
            value = frame.get(name)
            if value is None:
                raise DialectLoweringError(f"expected value for {name}")
            return value

        yield_names: list[str] = []
        body_yields: list[ir.SSAValue] = []
        else_yields: list[ir.SSAValue] = []
        if node.orelse:
            for name in body_frame.defs.keys():
                if name in else_frame.defs:
                    yield_names.append(name)
                    body_yields.append(body_frame.get_scope(name))
                    else_yields.append(else_frame.get_scope(name))
                elif name in frame.defs:
                    # Assigned only in the then-branch: the else-branch keeps
                    # the enclosing value.
                    yield_names.append(name)
                    body_yields.append(body_frame.get_scope(name))
                    else_yields.append(outer_value(name))
            for name in else_frame.defs.keys():
                if name not in body_frame.defs and name in frame.defs:
                    # Assigned only in the else-branch: the then-branch keeps
                    # the enclosing value.
                    yield_names.append(name)
                    body_yields.append(outer_value(name))
                    else_yields.append(else_frame.get_scope(name))
        else:
            # No else clause: the (empty) else-branch yields enclosing values.
            for name in body_frame.defs.keys():
                if name in frame.defs:
                    yield_names.append(name)
                    body_yields.append(body_frame.get_scope(name))
                    else_yields.append(outer_value(name))

        # Only append a Yield when the branch does not already end in a terminator.
        if not (
            body_frame.curr_block.last_stmt
            and body_frame.curr_block.last_stmt.has_trait(ir.IsTerminator)
        ):
            body_frame.append_stmt(Yield(*body_yields))

        if not (
            else_frame.curr_block.last_stmt
            and else_frame.curr_block.last_stmt.has_trait(ir.IsTerminator)
        ):
            else_frame.append_stmt(Yield(*else_yields))

        stmt = IfElse(
            cond,
            then_body=body_frame.curr_region,
            else_body=else_frame.curr_region,
        )
        # Name and type each result, and rebind the name in the enclosing frame.
        for result, name, body, else_ in zip(
            stmt.results, yield_names, body_yields, else_yields
        ):
            result.name = name
            result.type = body.type.join(else_.type)
            frame.defs[name] = result
        state.append_stmt(stmt)
        return lowering.Result()

    def lower_For(
        self, state: lowering.LoweringState, node: ast.For
    ) -> lowering.Result:
        """Lower a ``for`` loop into a ``For``.

        ``new_block_arg_if_inside_loop`` is the body frame's capture callback.
        Each captured value:

        * gets an extra body block argument,
        * has its name recorded in ``yields``,
        * is initialised from the enclosing frame, and
        * is rebound to the corresponding ``For`` result afterwards.

        NOTE(review): if the capture callback only fires when the body reads an
        outer name (as its role suggests), a name the body assigns without
        reading is not carried out of the loop. Confirm against
        ``lowering.Frame``.

        Raises:
            DialectLoweringError: if a captured value is unnamed or an
                initializer cannot be found in the enclosing frame.
        """
        iter_ = state.visit(node.iter).expect_one()

        # Names of captured (loop-carried) values, in block-argument order.
        yields: list[str] = []

        def new_block_arg_if_inside_loop(frame: lowering.Frame, capture: ir.SSAValue):
            if not capture.name:
                raise DialectLoweringError("unexpected loop variable captured")
            yields.append(capture.name)
            return frame.curr_block.args.append_from(capture.type, capture.name)

        body_frame = state.push_frame(
            lowering.Frame.from_stmts(
                node.body,
                state,
                globals=state.current_frame.globals,
                capture_callback=new_block_arg_if_inside_loop,
            )
        )
        # Block argument that receives the current element of the iterable.
        loop_var = body_frame.curr_block.args.append_from(types.Any)
        unpacking(state, node.target, loop_var)
        state.exhaust(body_frame)
        # NOTE: this frame won't have phi nodes
        if yields and (
            body_frame.curr_block.last_stmt is None
            or not body_frame.curr_block.last_stmt.has_trait(ir.IsTerminator)
        ):
            body_frame.append_stmt(Yield(*[body_frame.defs[name] for name in yields]))  # type: ignore
        state.pop_frame(finalize_next=False)  # NOTE: scf does not have multiple blocks

        # Initial values of the loop-carried variables come from the enclosing frame.
        initializers: list[ir.SSAValue] = []
        for name in yields:
            value = state.current_frame.get(name)
            if value is None:
                raise DialectLoweringError(f"expected value for {name}")
            initializers.append(value)
        stmt = For(iter_, body_frame.curr_region, *initializers)
        for name, result in zip(yields, stmt.results):
            state.current_frame.defs[name] = result
            result.name = name
        state.append_stmt(stmt)
        return lowering.Result()
|