kirin-toolchain 0.13.0__py3-none-any.whl
Sign up to get free protection for your applications and to get access to all the features.
- kirin/__init__.py +7 -0
- kirin/analysis/__init__.py +24 -0
- kirin/analysis/callgraph.py +61 -0
- kirin/analysis/cfg.py +112 -0
- kirin/analysis/const/__init__.py +20 -0
- kirin/analysis/const/_visitor.py +2 -0
- kirin/analysis/const/_visitor.pyi +8 -0
- kirin/analysis/const/lattice.py +219 -0
- kirin/analysis/const/prop.py +116 -0
- kirin/analysis/forward.py +100 -0
- kirin/analysis/typeinfer/__init__.py +5 -0
- kirin/analysis/typeinfer/analysis.py +90 -0
- kirin/analysis/typeinfer/solve.py +141 -0
- kirin/decl/__init__.py +108 -0
- kirin/decl/base.py +65 -0
- kirin/decl/camel2snake.py +2 -0
- kirin/decl/emit/__init__.py +0 -0
- kirin/decl/emit/_create_fn.py +29 -0
- kirin/decl/emit/_set_new_attribute.py +22 -0
- kirin/decl/emit/dialect.py +8 -0
- kirin/decl/emit/init.py +277 -0
- kirin/decl/emit/name.py +10 -0
- kirin/decl/emit/property.py +182 -0
- kirin/decl/emit/repr.py +31 -0
- kirin/decl/emit/traits.py +13 -0
- kirin/decl/emit/typecheck.py +77 -0
- kirin/decl/emit/verify.py +51 -0
- kirin/decl/info.py +346 -0
- kirin/decl/scan_fields.py +157 -0
- kirin/decl/verify.py +69 -0
- kirin/dialects/__init__.py +14 -0
- kirin/dialects/_pprint_helper.py +53 -0
- kirin/dialects/cf/__init__.py +20 -0
- kirin/dialects/cf/constprop.py +51 -0
- kirin/dialects/cf/dialect.py +3 -0
- kirin/dialects/cf/emit.py +58 -0
- kirin/dialects/cf/interp.py +24 -0
- kirin/dialects/cf/stmts.py +68 -0
- kirin/dialects/cf/typeinfer.py +27 -0
- kirin/dialects/eltype.py +23 -0
- kirin/dialects/func/__init__.py +20 -0
- kirin/dialects/func/attrs.py +39 -0
- kirin/dialects/func/constprop.py +138 -0
- kirin/dialects/func/dialect.py +3 -0
- kirin/dialects/func/emit.py +80 -0
- kirin/dialects/func/interp.py +68 -0
- kirin/dialects/func/stmts.py +233 -0
- kirin/dialects/func/typeinfer.py +124 -0
- kirin/dialects/ilist/__init__.py +33 -0
- kirin/dialects/ilist/_dialect.py +3 -0
- kirin/dialects/ilist/_wrapper.py +51 -0
- kirin/dialects/ilist/interp.py +85 -0
- kirin/dialects/ilist/lowering.py +25 -0
- kirin/dialects/ilist/passes.py +32 -0
- kirin/dialects/ilist/rewrite/__init__.py +3 -0
- kirin/dialects/ilist/rewrite/const.py +45 -0
- kirin/dialects/ilist/rewrite/list.py +38 -0
- kirin/dialects/ilist/rewrite/unroll.py +131 -0
- kirin/dialects/ilist/runtime.py +63 -0
- kirin/dialects/ilist/stmts.py +102 -0
- kirin/dialects/ilist/typeinfer.py +120 -0
- kirin/dialects/lowering/__init__.py +7 -0
- kirin/dialects/lowering/call.py +48 -0
- kirin/dialects/lowering/cf.py +206 -0
- kirin/dialects/lowering/func.py +134 -0
- kirin/dialects/math/__init__.py +41 -0
- kirin/dialects/math/_gen.py +176 -0
- kirin/dialects/math/dialect.py +3 -0
- kirin/dialects/math/interp.py +190 -0
- kirin/dialects/math/stmts.py +369 -0
- kirin/dialects/module.py +139 -0
- kirin/dialects/py/__init__.py +40 -0
- kirin/dialects/py/assertion.py +91 -0
- kirin/dialects/py/assign.py +103 -0
- kirin/dialects/py/attr.py +59 -0
- kirin/dialects/py/base.py +34 -0
- kirin/dialects/py/binop/__init__.py +23 -0
- kirin/dialects/py/binop/_dialect.py +3 -0
- kirin/dialects/py/binop/interp.py +60 -0
- kirin/dialects/py/binop/julia.py +33 -0
- kirin/dialects/py/binop/lowering.py +22 -0
- kirin/dialects/py/binop/stmts.py +79 -0
- kirin/dialects/py/binop/typeinfer.py +108 -0
- kirin/dialects/py/boolop.py +84 -0
- kirin/dialects/py/builtin.py +78 -0
- kirin/dialects/py/cmp/__init__.py +16 -0
- kirin/dialects/py/cmp/_dialect.py +3 -0
- kirin/dialects/py/cmp/interp.py +48 -0
- kirin/dialects/py/cmp/julia.py +33 -0
- kirin/dialects/py/cmp/lowering.py +45 -0
- kirin/dialects/py/cmp/stmts.py +62 -0
- kirin/dialects/py/constant.py +79 -0
- kirin/dialects/py/indexing.py +251 -0
- kirin/dialects/py/iterable.py +90 -0
- kirin/dialects/py/len.py +57 -0
- kirin/dialects/py/list/__init__.py +15 -0
- kirin/dialects/py/list/_dialect.py +3 -0
- kirin/dialects/py/list/interp.py +21 -0
- kirin/dialects/py/list/lowering.py +25 -0
- kirin/dialects/py/list/stmts.py +22 -0
- kirin/dialects/py/list/typeinfer.py +54 -0
- kirin/dialects/py/range.py +76 -0
- kirin/dialects/py/slice.py +120 -0
- kirin/dialects/py/tuple.py +109 -0
- kirin/dialects/py/unary/__init__.py +24 -0
- kirin/dialects/py/unary/_dialect.py +3 -0
- kirin/dialects/py/unary/constprop.py +20 -0
- kirin/dialects/py/unary/interp.py +24 -0
- kirin/dialects/py/unary/julia.py +21 -0
- kirin/dialects/py/unary/lowering.py +22 -0
- kirin/dialects/py/unary/stmts.py +33 -0
- kirin/dialects/py/unary/typeinfer.py +23 -0
- kirin/dialects/py/unpack.py +90 -0
- kirin/dialects/scf/__init__.py +23 -0
- kirin/dialects/scf/_dialect.py +3 -0
- kirin/dialects/scf/absint.py +64 -0
- kirin/dialects/scf/constprop.py +140 -0
- kirin/dialects/scf/interp.py +35 -0
- kirin/dialects/scf/lowering.py +123 -0
- kirin/dialects/scf/stmts.py +250 -0
- kirin/dialects/scf/trim.py +36 -0
- kirin/dialects/scf/typeinfer.py +58 -0
- kirin/dialects/scf/unroll.py +92 -0
- kirin/emit/__init__.py +3 -0
- kirin/emit/abc.py +89 -0
- kirin/emit/abc.pyi +38 -0
- kirin/emit/exceptions.py +5 -0
- kirin/emit/julia.py +63 -0
- kirin/emit/str.py +51 -0
- kirin/exceptions.py +59 -0
- kirin/graph.py +34 -0
- kirin/idtable.py +57 -0
- kirin/interp/__init__.py +39 -0
- kirin/interp/abstract.py +253 -0
- kirin/interp/base.py +438 -0
- kirin/interp/concrete.py +62 -0
- kirin/interp/exceptions.py +26 -0
- kirin/interp/frame.py +151 -0
- kirin/interp/impl.py +197 -0
- kirin/interp/result.py +93 -0
- kirin/interp/state.py +71 -0
- kirin/interp/table.py +40 -0
- kirin/interp/value.py +73 -0
- kirin/ir/__init__.py +46 -0
- kirin/ir/attrs/__init__.py +20 -0
- kirin/ir/attrs/_types.py +8 -0
- kirin/ir/attrs/_types.pyi +13 -0
- kirin/ir/attrs/abc.py +46 -0
- kirin/ir/attrs/py.py +45 -0
- kirin/ir/attrs/types.py +522 -0
- kirin/ir/dialect.py +125 -0
- kirin/ir/group.py +249 -0
- kirin/ir/method.py +118 -0
- kirin/ir/nodes/__init__.py +7 -0
- kirin/ir/nodes/base.py +149 -0
- kirin/ir/nodes/block.py +458 -0
- kirin/ir/nodes/region.py +337 -0
- kirin/ir/nodes/stmt.py +713 -0
- kirin/ir/nodes/view.py +142 -0
- kirin/ir/ssa.py +204 -0
- kirin/ir/traits/__init__.py +36 -0
- kirin/ir/traits/abc.py +42 -0
- kirin/ir/traits/basic.py +78 -0
- kirin/ir/traits/callable.py +51 -0
- kirin/ir/traits/lowering/__init__.py +2 -0
- kirin/ir/traits/lowering/call.py +37 -0
- kirin/ir/traits/lowering/context.py +120 -0
- kirin/ir/traits/region/__init__.py +2 -0
- kirin/ir/traits/region/ssacfg.py +22 -0
- kirin/ir/traits/symbol.py +57 -0
- kirin/ir/use.py +17 -0
- kirin/lattice/__init__.py +13 -0
- kirin/lattice/abc.py +128 -0
- kirin/lattice/empty.py +25 -0
- kirin/lattice/mixin.py +51 -0
- kirin/lowering/__init__.py +7 -0
- kirin/lowering/binding.py +65 -0
- kirin/lowering/core.py +72 -0
- kirin/lowering/dialect.py +35 -0
- kirin/lowering/dialect.pyi +183 -0
- kirin/lowering/frame.py +171 -0
- kirin/lowering/result.py +68 -0
- kirin/lowering/state.py +441 -0
- kirin/lowering/stream.py +53 -0
- kirin/passes/__init__.py +3 -0
- kirin/passes/abc.py +44 -0
- kirin/passes/aggressive/__init__.py +1 -0
- kirin/passes/aggressive/fold.py +43 -0
- kirin/passes/fold.py +45 -0
- kirin/passes/inline.py +25 -0
- kirin/passes/typeinfer.py +25 -0
- kirin/prelude.py +197 -0
- kirin/print/__init__.py +15 -0
- kirin/print/printable.py +141 -0
- kirin/print/printer.py +415 -0
- kirin/py.typed +0 -0
- kirin/registry.py +105 -0
- kirin/registry.pyi +52 -0
- kirin/rewrite/__init__.py +14 -0
- kirin/rewrite/abc.py +43 -0
- kirin/rewrite/aggressive/__init__.py +1 -0
- kirin/rewrite/aggressive/fold.py +43 -0
- kirin/rewrite/alias.py +16 -0
- kirin/rewrite/apply_type.py +47 -0
- kirin/rewrite/call2invoke.py +34 -0
- kirin/rewrite/chain.py +39 -0
- kirin/rewrite/compactify.py +288 -0
- kirin/rewrite/cse.py +48 -0
- kirin/rewrite/dce.py +19 -0
- kirin/rewrite/fixpoint.py +34 -0
- kirin/rewrite/fold.py +57 -0
- kirin/rewrite/getfield.py +21 -0
- kirin/rewrite/getitem.py +37 -0
- kirin/rewrite/inline.py +143 -0
- kirin/rewrite/result.py +15 -0
- kirin/rewrite/walk.py +83 -0
- kirin/rewrite/wrap_const.py +55 -0
- kirin/source.py +21 -0
- kirin/symbol_table.py +27 -0
- kirin/types.py +34 -0
- kirin/worklist.py +30 -0
- kirin_toolchain-0.13.0.dist-info/METADATA +42 -0
- kirin_toolchain-0.13.0.dist-info/RECORD +225 -0
- kirin_toolchain-0.13.0.dist-info/WHEEL +4 -0
- kirin_toolchain-0.13.0.dist-info/licenses/LICENSE +234 -0
@@ -0,0 +1,15 @@
|
|
1
|
+
"""The list dialect for Python.
|
2
|
+
|
3
|
+
This module contains the dialect for list semantics in Python, including:
|
4
|
+
|
5
|
+
- The `New` and `Append` statement classes.
|
6
|
+
- The lowering pass for list operations.
|
7
|
+
- The concrete implementation of list operations.
|
8
|
+
- The type inference implementation of list operations.
|
9
|
+
|
10
|
+
This dialect maps `list()`, `ast.List` and `append()` calls to the `New` and `Append` statements.
|
11
|
+
"""
|
12
|
+
|
13
|
+
from . import interp as interp, lowering as lowering, typeinfer as typeinfer
|
14
|
+
from .stmts import New as New, Append as Append
|
15
|
+
from ._dialect import dialect as dialect
|
@@ -0,0 +1,21 @@
|
|
1
|
+
from kirin import types, interp
|
2
|
+
from kirin.dialects.py.binop import Add
|
3
|
+
|
4
|
+
from .stmts import New, Append
|
5
|
+
from ._dialect import dialect
|
6
|
+
|
7
|
+
|
8
|
+
@dialect.register
class ListMethods(interp.MethodTable):
    """Concrete (Python runtime) implementations of the `py.list` statements."""

    @interp.impl(New)
    def new(self, interp, frame: interp.Frame, stmt: New):
        """Build a fresh Python `list` from the statement's operand values."""
        return (list(frame.get_values(stmt.values)),)

    @interp.impl(Add, types.PyClass(list), types.PyClass(list))
    def add(self, interp, frame: interp.Frame, stmt: Add):
        """Concatenate two lists with `+`, producing a new list."""
        return (frame.get(stmt.lhs) + frame.get(stmt.rhs),)

    @interp.impl(Append)
    def append(self, interp, frame: interp.Frame, stmt: Append):
        """Append the value to the list in place."""
        frame.get(stmt.list_).append(frame.get(stmt.value))
        # `Append` declares no result value, so return an empty result tuple
        # instead of wrapping the `None` that `list.append` returns.
        return ()
|
@@ -0,0 +1,25 @@
|
|
1
|
+
import ast
|
2
|
+
|
3
|
+
from kirin import types
|
4
|
+
from kirin.lowering import Result, FromPythonAST, LoweringState
|
5
|
+
|
6
|
+
from .stmts import New
|
7
|
+
from ._dialect import dialect
|
8
|
+
|
9
|
+
|
10
|
+
@dialect.register
class PythonLowering(FromPythonAST):
    """Lower Python list displays (`[a, b, ...]`) to `py.list` statements."""

    def lower_List(self, state: LoweringState, node: ast.List) -> Result:
        """Lower an `ast.List` display to a `New` statement.

        Each element expression is lowered in source order and must produce
        exactly one SSA value; those values become the operands of `New`.
        """
        elts = tuple(state.visit(each).expect_one() for each in node.elts)
        stmt = New(values=elts)
        state.append_stmt(stmt)
        return Result(stmt)
|
@@ -0,0 +1,22 @@
|
|
1
|
+
from kirin import ir, types
|
2
|
+
from kirin.decl import info, statement
|
3
|
+
|
4
|
+
from ._dialect import dialect
|
5
|
+
|
6
|
+
# Type variable shared by the list statements below: it ties the type of the
# element operands (`New.values`, `Append.value`) to the list's element type.
T = types.TypeVar("T")


@statement(dialect=dialect)
class New(ir.Statement):
    """Create a new Python list from zero or more SSA values.

    Produced when lowering `list()` calls and `ast.List` displays.
    """

    name = "list"
    traits = frozenset({ir.FromPythonCall()})
    # Elements of the new list, in order.
    values: tuple[ir.SSAValue, ...] = info.argument(T)
    # The newly created list, typed `List[T]`.
    result: ir.ResultValue = info.result(types.List[T])


@statement(dialect=dialect)
class Append(ir.Statement):
    """Append `value` to `list_` in place (`list_.append(value)`).

    Declares no result value.
    """

    name = "append"
    traits = frozenset({ir.FromPythonCall()})
    # The list being mutated.
    list_: ir.SSAValue = info.argument(types.List[T])
    # The value appended to the end of the list.
    value: ir.SSAValue = info.argument(T)
|
@@ -0,0 +1,54 @@
|
|
1
|
+
from kirin import types, interp
|
2
|
+
from kirin.dialects.eltype import ElType
|
3
|
+
from kirin.dialects.py.binop import Add
|
4
|
+
from kirin.dialects.py.indexing import GetItem
|
5
|
+
|
6
|
+
from ._dialect import dialect
|
7
|
+
|
8
|
+
|
9
|
+
def _first_type_param(typ: types.TypeAttribute) -> types.TypeAttribute:
    """Return the first type parameter of a generic type, else `types.Any`."""
    return typ.vars[0] if isinstance(typ, types.Generic) else types.Any


@dialect.register(key="typeinfer")
class TypeInfer(interp.MethodTable):
    """Type-inference rules for list operations in the `py.list` dialect.

    Element types are read from the first parameter of a generic list type;
    when the list type is not parameterized, `Any` is used instead.
    """

    @interp.impl(ElType, types.PyClass(list))
    def eltype_list(self, interp, frame: interp.Frame, stmt: ElType):
        # The element type of `list[T]` is `T`.
        return (_first_type_param(frame.get(stmt.container)),)

    @interp.impl(Add, types.PyClass(list), types.PyClass(list))
    def add(self, interp, frame: interp.Frame, stmt: Add):
        # `list[A] + list[B]` is a list whose element type is the join of A, B.
        lhs_elem = _first_type_param(frame.get(stmt.lhs))
        rhs_elem = _first_type_param(frame.get(stmt.rhs))
        return (types.List[lhs_elem.join(rhs_elem)],)

    @interp.impl(GetItem, types.PyClass(list), types.Int)
    def getitem_list_int(
        self, interp, frame: interp.Frame[types.TypeAttribute], stmt: GetItem
    ):
        # Indexing with an int yields a single element.
        return (_first_type_param(frame.get(stmt.obj)),)

    @interp.impl(GetItem, types.PyClass(list), types.PyClass(slice))
    def getitem_list_slice(
        self, interp, frame: interp.Frame[types.TypeAttribute], stmt: GetItem
    ):
        # Slicing keeps the element type, and the result is again a list.
        obj_type = frame.get(stmt.obj)
        if not isinstance(obj_type, types.Generic):
            return (types.Any,)
        return (types.List[obj_type.vars[0]],)
|
@@ -0,0 +1,76 @@
|
|
1
|
+
"""The range dialect for Python.
|
2
|
+
|
3
|
+
This dialect models the builtin `range()` function in Python.
|
4
|
+
|
5
|
+
The dialect includes:
|
6
|
+
- The `Range` statement class.
|
7
|
+
- The lowering pass for the `range()` function.
|
8
|
+
|
9
|
+
This dialect does not include a concrete implementation or type inference
for the `range()` function beyond the element type of a range. Another
dialect, e.g., the `ilist` dialect, must provide the concrete implementation
and type inference.
|
12
|
+
"""
|
13
|
+
|
14
|
+
import ast
|
15
|
+
from dataclasses import dataclass
|
16
|
+
|
17
|
+
from kirin import ir, types, interp, lowering, exceptions
|
18
|
+
from kirin.decl import info, statement
|
19
|
+
from kirin.dialects import eltype
|
20
|
+
|
21
|
+
# The `py.range` dialect; the statement and method tables below register on it.
dialect = ir.Dialect("py.range")


@dataclass(frozen=True)
class RangeLowering(ir.FromPythonCall["Range"]):
    """Lowering trait attached to `Range`: lowers a `range(...)` call node."""

    def lower(
        self, stmt: type["Range"], state: lowering.LoweringState, node: ast.Call
    ) -> lowering.Result:
        # Uses the same `_lower_range` helper as `Lowering.lower_Call_range`.
        return _lower_range(state, node)
|
31
|
+
|
32
|
+
|
33
|
+
@statement(dialect=dialect)
class Range(ir.Statement):
    """Statement for `range(start, stop, step)`; its result is a Python `range`.

    All three operands are explicit; `_lower_range` fills in the constant
    defaults `0` (start) and `1` (step) when the call omits them.
    """

    name = "range"
    traits = frozenset({ir.Pure(), RangeLowering()})
    start: ir.SSAValue = info.argument(types.Int)
    stop: ir.SSAValue = info.argument(types.Int)
    step: ir.SSAValue = info.argument(types.Int)
    result: ir.ResultValue = info.result(types.PyClass(range))
|
41
|
+
|
42
|
+
|
43
|
+
@dialect.register
class Lowering(lowering.FromPythonAST):
    """Python-AST lowering for the `py.range` dialect."""

    def lower_Call_range(
        self, state: lowering.LoweringState, node: ast.Call
    ) -> lowering.Result:
        # Delegates to the shared `_lower_range` helper, which the
        # `RangeLowering` trait also uses.
        return _lower_range(state, node)
|
50
|
+
|
51
|
+
|
52
|
+
@dialect.register(key="typeinfer")
class TypeInfer(interp.MethodTable):
    """Type-inference rules for the `py.range` dialect."""

    @interp.impl(eltype.ElType, types.PyClass(range))
    def eltype_range(self, interp_, frame: interp.Frame, stmt: eltype.ElType):
        # The element type of a `range` is always `Int`.
        return (types.Int,)
|
58
|
+
|
59
|
+
|
60
|
+
def _lower_range(state: lowering.LoweringState, node: ast.Call) -> lowering.Result:
    """Lower a Python `range(...)` call to a `Range` statement.

    Accepts the builtin's positional forms `range(stop)`, `range(start, stop)`
    and `range(start, stop, step)`. A missing `start` is lowered from the
    constant `0` and a missing `step` from the constant `1`, so `Range` always
    receives three operands. Operands are lowered in the order start, stop,
    step.

    Raises:
        exceptions.DialectLoweringError: if keyword arguments are passed, or if
            the number of positional arguments is not 1-3.
    """
    # Like the builtin, reject keywords instead of silently dropping them
    # (e.g. `range(5, step=2)` must not lower as `range(5)`).
    if node.keywords:
        raise exceptions.DialectLoweringError("range() takes no keyword arguments")

    args = node.args
    if len(args) == 1:
        start_expr, stop_expr, step_expr = ast.Constant(0), args[0], ast.Constant(1)
    elif len(args) == 2:
        start_expr, stop_expr, step_expr = args[0], args[1], ast.Constant(1)
    elif len(args) == 3:
        start_expr, stop_expr, step_expr = args
    else:
        raise exceptions.DialectLoweringError("range() takes 1-3 arguments")

    start = state.visit(start_expr).expect_one()
    stop = state.visit(stop_expr).expect_one()
    step = state.visit(step_expr).expect_one()
    return lowering.Result(state.append_stmt(Range(start, stop, step)))
|
@@ -0,0 +1,120 @@
|
|
1
|
+
"""The slice dialect for Python.
|
2
|
+
|
3
|
+
This dialect provides a `Slice` statement that represents a slice object in Python:
|
4
|
+
|
5
|
+
- The `Slice` statement class.
|
6
|
+
- The lowering pass for the `slice` call.
|
7
|
+
- The concrete implementation of the `slice` call.
|
8
|
+
- The type inference implementation of the `slice` call.
|
9
|
+
"""
|
10
|
+
|
11
|
+
import ast
|
12
|
+
from dataclasses import dataclass
|
13
|
+
|
14
|
+
from kirin import ir, types, interp, lowering, exceptions
|
15
|
+
from kirin.decl import info, statement
|
16
|
+
from kirin.dialects.py.constant import Constant
|
17
|
+
|
18
|
+
# The `py.slice` dialect; the statement and method tables below register on it.
dialect = ir.Dialect("py.slice")


@dataclass(frozen=True)
class SliceLowering(ir.FromPythonCall["Slice"]):
    """Lowering trait attached to `Slice`: lowers a `slice(...)` call node."""

    def lower(
        self, stmt: type["Slice"], state: lowering.LoweringState, node: ast.Call
    ) -> lowering.Result:
        # Uses the same `_lower_slice` helper as `Lowering.lower_Call_slice`.
        return _lower_slice(state, node)
|
28
|
+
|
29
|
+
|
30
|
+
# Type variable for the index type of a slice; the non-`None` operands of a
# `Slice` share it.
T = types.TypeVar("T")


@statement(dialect=dialect, init=False)
class Slice(ir.Statement):
    """Construct a Python `slice` object from `start`, `stop` and `step`.

    Each operand may be the constant `None`. The result type is `Slice[X]`,
    where `X` is the type of the first operand (in the order `start`, `stop`,
    `step`) whose type is not a subtype of `NoneType`. If all three are
    `None`, or the `start`/`stop` types are not `TypeAttribute`s, the result
    type is `Bottom`.
    """

    name = "slice"
    traits = frozenset({ir.Pure(), SliceLowering()})
    start: ir.SSAValue = info.argument(T | types.NoneType)
    stop: ir.SSAValue = info.argument(T | types.NoneType)
    step: ir.SSAValue = info.argument(T | types.NoneType)
    result: ir.ResultValue = info.result(types.Slice[T])

    def __init__(
        self, start: ir.SSAValue, stop: ir.SSAValue, step: ir.SSAValue
    ) -> None:
        if not (
            isinstance(stop.type, types.TypeAttribute)
            and isinstance(start.type, types.TypeAttribute)
        ):
            result_type = types.Bottom
        elif not start.type.is_subseteq(types.NoneType):
            result_type = types.Slice[start.type]
        elif not stop.type.is_subseteq(types.NoneType):
            result_type = types.Slice[stop.type]
        elif isinstance(step.type, types.TypeAttribute) and not (
            step.type.is_subseteq(types.NoneType)
        ):
            # Step-only slices such as `x[::2]` carry their type in `step`;
            # without this branch they were typed `Bottom`.
            result_type = types.Slice[step.type]
        else:
            result_type = types.Bottom

        super().__init__(
            args=(start, stop, step),
            result_types=[result_type],
            args_slice={"start": 0, "stop": 1, "step": 2},
        )
|
63
|
+
|
64
|
+
|
65
|
+
@dialect.register
class Concrete(interp.MethodTable):
    """Concrete interpreter implementation of the `py.slice` dialect."""

    @interp.impl(Slice)
    def _slice(self, interp, frame: interp.Frame, stmt: Slice):
        # `slice(stop)` and `slice(start, stop)` build the same objects as
        # `slice(None, stop, None)` and `slice(start, stop, None)`, so one
        # three-argument call covers every combination of `None` operands.
        start, stop, step = frame.get_values(stmt.args)
        return (slice(start, stop, step),)
|
77
|
+
|
78
|
+
|
79
|
+
@dialect.register
class Lowering(lowering.FromPythonAST):
    """Lower slice syntax (`a[lo:hi:step]`) and `slice(...)` calls to `Slice`."""

    def lower_Slice(
        self, state: lowering.LoweringState, node: ast.Slice
    ) -> lowering.Result:
        def lower_bound(expr: ast.expr | None) -> ir.SSAValue:
            # An omitted bound becomes an explicit `None` constant statement.
            if expr is None:
                return state.append_stmt(Constant(None)).result
            return state.visit(expr).expect_one()

        # Bounds are lowered in source order: lower, upper, step.
        start = lower_bound(node.lower)
        stop = lower_bound(node.upper)
        step = lower_bound(node.step)
        slice_stmt = state.append_stmt(Slice(start=start, stop=stop, step=step))
        return lowering.Result(slice_stmt)

    def lower_Call_slice(
        self, state: lowering.LoweringState, node: ast.Call
    ) -> lowering.Result:
        # `slice(...)` calls share the helper used by the `SliceLowering` trait.
        return _lower_slice(state, node)
|
102
|
+
|
103
|
+
|
104
|
+
def _lower_slice(state: lowering.LoweringState, node: ast.Call) -> lowering.Result:
    """Lower a Python `slice(...)` call to a `Slice` statement.

    Accepts the builtin's positional forms `slice(stop)`, `slice(start, stop)`
    and `slice(start, stop, step)`. Missing operands are lowered from the
    constant `None`, so `Slice` always receives three operands. Operands are
    lowered in the order start, stop, step.

    Raises:
        exceptions.DialectLoweringError: if keyword arguments are passed, or if
            the number of positional arguments is not 1-3.
    """
    # Like the builtin, reject keywords instead of silently dropping them
    # (e.g. `slice(1, step=2)` must not lower as `slice(1)`).
    if node.keywords:
        raise exceptions.DialectLoweringError("slice() takes no keyword arguments")

    args = node.args
    if len(args) == 1:
        start_expr, stop_expr, step_expr = (
            ast.Constant(None),
            args[0],
            ast.Constant(None),
        )
    elif len(args) == 2:
        start_expr, stop_expr, step_expr = args[0], args[1], ast.Constant(None)
    elif len(args) == 3:
        start_expr, stop_expr, step_expr = args
    else:
        raise exceptions.DialectLoweringError("slice() takes 1-3 arguments")

    start = state.visit(start_expr).expect_one()
    stop = state.visit(stop_expr).expect_one()
    step = state.visit(step_expr).expect_one()
    return lowering.Result(state.append_stmt(Slice(start, stop, step)))
|
@@ -0,0 +1,109 @@
|
|
1
|
+
"""The tuple dialect for Python.
|
2
|
+
|
3
|
+
This dialect provides a way to work with Python tuples in the IR, including:
|
4
|
+
|
5
|
+
- The `New` statement class.
|
6
|
+
- The lowering pass for the tuple statement.
|
7
|
+
- The concrete implementation of the tuple statement.
|
8
|
+
- The type inference implementation of the tuple addition with `py.binop.Add`.
|
9
|
+
- The constant propagation implementation of the tuple statement.
|
10
|
+
- The Julia emitter for the tuple statement.
|
11
|
+
|
12
|
+
This dialect maps `ast.Tuple` nodes to the `New` statement.
|
13
|
+
"""
|
14
|
+
|
15
|
+
import ast
|
16
|
+
|
17
|
+
from kirin import ir, types, interp, lowering
|
18
|
+
from kirin.decl import info, statement
|
19
|
+
from kirin.analysis import const
|
20
|
+
from kirin.emit.julia import EmitJulia, EmitStrFrame
|
21
|
+
from kirin.dialects.eltype import ElType
|
22
|
+
from kirin.dialects.py.binop import Add
|
23
|
+
|
24
|
+
dialect = ir.Dialect("py.tuple")


@statement(dialect=dialect)
class New(ir.Statement):
    """Build a Python tuple from its element SSA values.

    The result type is `tuple` parameterized by each element's type, in
    element order.
    """

    traits = frozenset({ir.Pure(), ir.FromPythonCall()})
    result: ir.ResultValue = info.result()

    def __init__(self, values: tuple[ir.SSAValue, ...]) -> None:
        element_types = [value.type for value in values]
        super().__init__(
            args=values,
            result_types=[types.Generic(tuple, *element_types)],
        )
|
40
|
+
|
41
|
+
|
42
|
+
@dialect.register
class Concrete(interp.MethodTable):
    """Concrete interpreter implementation of the `py.tuple` dialect."""

    @interp.impl(New)
    def new(self, interp: interp.Interpreter, frame: interp.Frame, stmt: New):
        # The operand values read from the frame are the single result.
        elements = frame.get_values(stmt.args)
        return (elements,)
|
48
|
+
|
49
|
+
|
50
|
+
@dialect.register(key="typeinfer")
class TypeInfer(interp.MethodTable):
    """Type-inference rules for tuples in the `py.tuple` dialect."""

    @interp.impl(ElType, types.PyClass(tuple))
    def eltype_tuple(self, interp, frame: interp.Frame, stmt: ElType):
        """Element type of a tuple: the join of all its parameter types.

        A parameterized tuple with no parameters (as `New` builds for `()`)
        has no elements, so its element type is `Bottom`. A tuple type without
        parameters information gives `Any`.
        """
        tuple_type = frame.get(stmt.container)
        if not isinstance(tuple_type, types.Generic):
            return (types.Any,)
        # Guard the empty case: indexing `vars[0]` would raise `IndexError`.
        if not tuple_type.vars:
            return (types.Bottom,)
        ret = tuple_type.vars[0]
        for var in tuple_type.vars[1:]:
            ret = ret.join(var)
        return (ret,)

    @interp.impl(Add, types.PyClass(tuple), types.PyClass(tuple))
    def add(self, interp, frame: interp.Frame[types.TypeAttribute], stmt):
        """`tuple[A...] + tuple[B...]` is typed `tuple[A..., B...]`."""
        lhs = frame.get(stmt.lhs)
        rhs = frame.get(stmt.rhs)
        if isinstance(lhs, types.Generic) and isinstance(rhs, types.Generic):
            return (types.Generic(tuple, *(lhs.vars + rhs.vars)),)
        else:
            return (types.PyClass(tuple),)  # no type param, so unknown
|
72
|
+
|
73
|
+
|
74
|
+
@dialect.register(key="constprop")
class ConstPropTable(interp.MethodTable):
    """Constant-propagation rules for the `py.tuple` dialect."""

    @interp.impl(New)
    def new_tuple(
        self,
        _: const.Propagate,
        frame: const.Frame,
        stmt: New,
    ) -> interp.StatementResult[const.Result]:
        # Record the lattice value of every element in a `PartialTuple`.
        elements = tuple(frame.get_values(stmt.args))
        return (const.PartialTuple(elements),)
|
85
|
+
|
86
|
+
|
87
|
+
@dialect.register
class Lowering(lowering.FromPythonAST):
    """Lower Python tuple displays (`ast.Tuple`) to `py.tuple.New`."""

    def lower_Tuple(
        self, state: lowering.LoweringState, node: ast.Tuple
    ) -> lowering.Result:
        # Lower each element in source order; each must yield exactly one value.
        elements = tuple(state.visit(elem).expect_one() for elem in node.elts)
        new_stmt = state.append_stmt(stmt=New(elements))
        return lowering.Result(new_stmt)
|
98
|
+
|
99
|
+
|
100
|
+
@dialect.register(key="emit.julia")
class JuliaTable(interp.MethodTable):
    """Julia code emission for the `py.tuple` dialect."""

    @interp.impl(New)
    def emit_NewTuple(self, emit: EmitJulia, frame: EmitStrFrame, stmt: New):
        """Emit a Julia tuple literal built from the element variables.

        Julia, like Python, needs a trailing comma for a one-element tuple:
        `(x)` is just a parenthesized `x`, whereas `(x,)` is a 1-tuple.
        """
        elements = tuple(frame.get_values(stmt.args))
        body = ", ".join(elements)
        if len(elements) == 1:
            body += ","
        return (emit.write_assign(frame, stmt.result, "(" + body + ")"),)
|
@@ -0,0 +1,24 @@
|
|
1
|
+
"""The unary dialect for Python.
|
2
|
+
|
3
|
+
This module contains the dialect for unary semantics in Python, including:
|
4
|
+
|
5
|
+
- The `UnaryOp` base class for unary operations.
|
6
|
+
- The `UAdd`, `USub`, `Not`, and `Invert` statement classes.
|
7
|
+
- The lowering pass for unary operations.
|
8
|
+
- The concrete implementation of unary operations.
|
9
|
+
- The type inference implementation of unary operations.
|
10
|
+
- The constant propagation implementation of unary operations.
|
11
|
+
- The Julia emitter for unary operations.
|
12
|
+
|
13
|
+
This dialect maps `ast.UnaryOp` nodes to the `UAdd`, `USub`, `Not`, and `Invert` statements.
|
14
|
+
"""
|
15
|
+
|
16
|
+
from . import (
|
17
|
+
julia as julia,
|
18
|
+
interp as interp,
|
19
|
+
lowering as lowering,
|
20
|
+
constprop as constprop,
|
21
|
+
typeinfer as typeinfer,
|
22
|
+
)
|
23
|
+
from .stmts import * # noqa: F403
|
24
|
+
from ._dialect import dialect as dialect
|
@@ -0,0 +1,20 @@
|
|
1
|
+
from kirin import interp
|
2
|
+
from kirin.analysis import const
|
3
|
+
|
4
|
+
from . import stmts
|
5
|
+
from ._dialect import dialect
|
6
|
+
|
7
|
+
|
8
|
+
@dialect.register(key="constprop")
class ConstProp(interp.MethodTable):
    """Constant-propagation rules for the unary-operator dialect."""

    @interp.impl(stmts.Not)
    def not_(
        self, _: const.Propagate, frame: const.Frame, stmt: stmts.Not
    ) -> interp.StatementResult[const.Result]:
        operand = frame.get(stmt.value)
        # `not` is decidable for a known constant and for a partial tuple,
        # whose `data` (the element tuple) determines its truthiness.
        if not isinstance(operand, (const.PartialTuple, const.Value)):
            return (const.Unknown(),)
        return (const.Value(not operand.data),)
|
@@ -0,0 +1,24 @@
|
|
1
|
+
from kirin import interp
|
2
|
+
|
3
|
+
from . import stmts
|
4
|
+
from ._dialect import dialect
|
5
|
+
|
6
|
+
|
7
|
+
@dialect.register
class Concrete(interp.MethodTable):
    """Concrete interpreter implementation of the Python unary operators.

    Each method applies the matching Python operator to the operand's runtime
    value and returns the outcome as the single result.
    """

    @interp.impl(stmts.UAdd)
    def uadd(self, interp, frame: interp.Frame, stmt: stmts.UAdd):
        operand = frame.get(stmt.value)
        return (+operand,)

    @interp.impl(stmts.USub)
    def usub(self, interp, frame: interp.Frame, stmt: stmts.USub):
        operand = frame.get(stmt.value)
        return (-operand,)

    @interp.impl(stmts.Not)
    def not_(self, interp, frame: interp.Frame, stmt: stmts.Not):
        operand = frame.get(stmt.value)
        return (not operand,)

    @interp.impl(stmts.Invert)
    def invert(self, interp, frame: interp.Frame, stmt: stmts.Invert):
        operand = frame.get(stmt.value)
        return (~operand,)
|
@@ -0,0 +1,21 @@
|
|
1
|
+
from kirin import interp
|
2
|
+
from kirin.emit.julia import EmitJulia, EmitStrFrame
|
3
|
+
|
4
|
+
from . import stmts
|
5
|
+
from ._dialect import dialect
|
6
|
+
|
7
|
+
|
8
|
+
@dialect.register(key="emit.julia")
class JuliaTable(interp.MethodTable):
    """Julia code emission for the Python unary operators."""

    @interp.impl(stmts.Not)
    def emit_Not(self, emit: EmitJulia, frame: EmitStrFrame, stmt: stmts.Not):
        # Julia spells logical negation `!`.
        return (emit.write_assign(frame, stmt.result, f"!{frame.get(stmt.value)}"),)

    @interp.impl(stmts.USub)
    def emit_USub(self, emit: EmitJulia, frame: EmitStrFrame, stmt: stmts.USub):
        return (emit.write_assign(frame, stmt.result, f"-{frame.get(stmt.value)}"),)

    @interp.impl(stmts.UAdd)
    def emit_UAdd(self, emit: EmitJulia, frame: EmitStrFrame, stmt: stmts.UAdd):
        return (emit.write_assign(frame, stmt.result, f"+{frame.get(stmt.value)}"),)

    @interp.impl(stmts.Invert)
    def emit_Invert(
        self, emit: EmitJulia, frame: EmitStrFrame, stmt: stmts.Invert
    ):
        # Bitwise complement is `~` in Julia as in Python and agrees for
        # integers. NOTE(review): Julia's `~true` is `false` while Python's
        # `~True` is `-2`, so Bool operands would diverge.
        return (emit.write_assign(frame, stmt.result, f"~{frame.get(stmt.value)}"),)
|
@@ -0,0 +1,22 @@
|
|
1
|
+
import ast
|
2
|
+
|
3
|
+
from kirin import lowering, exceptions
|
4
|
+
|
5
|
+
from . import stmts
|
6
|
+
from ._dialect import dialect
|
7
|
+
|
8
|
+
|
9
|
+
@dialect.register
class Lowering(lowering.FromPythonAST):
    """Lower Python unary-operator expressions (`ast.UnaryOp`) to statements."""

    def lower_UnaryOp(
        self, state: lowering.LoweringState, node: ast.UnaryOp
    ) -> lowering.Result:
        # Each Python AST unary operator maps to the statement of the same name.
        op_table = {
            ast.UAdd: stmts.UAdd,
            ast.USub: stmts.USub,
            ast.Not: stmts.Not,
            ast.Invert: stmts.Invert,
        }
        op = op_table.get(type(node.op))
        if op is None:
            raise exceptions.DialectLoweringError(
                f"unsupported unary operator {node.op}"
            )
        operand = state.visit(node.operand).expect_one()
        return lowering.Result(state.append_stmt(op(operand)))
|
@@ -0,0 +1,33 @@
|
|
1
|
+
from kirin import ir, types
|
2
|
+
from kirin.decl import info, statement
|
3
|
+
|
4
|
+
from ._dialect import dialect
|
5
|
+
|
6
|
+
# Type variable linking a unary operator's operand type to its declared
# result type.
T = types.TypeVar("T")


@statement
class UnaryOp(ir.Statement):
    """Base class for the Python unary-operator statements.

    Declares one operand, `value`, and one result, both typed with the same
    type variable `T`. Subclasses below only set `name`.
    """

    # Decorated without a dialect; presumably only the subclasses below are
    # registered in `dialect` — confirm against `kirin.decl.statement`.
    traits = frozenset({ir.Pure(), ir.FromPythonCall()})
    # The operand; `print=False` presumably hides it from the printed form.
    value: ir.SSAValue = info.argument(T, print=False)
    result: ir.ResultValue = info.result(T)


@statement(dialect=dialect)
class UAdd(UnaryOp):
    """Unary plus, `+value`."""

    name = "uadd"


@statement(dialect=dialect)
class USub(UnaryOp):
    """Unary minus, `-value`."""

    name = "usub"


@statement(dialect=dialect)
class Not(UnaryOp):
    """Logical negation, `not value`."""

    name = "not"


@statement(dialect=dialect)
class Invert(UnaryOp):
    """Bitwise inversion, `~value`."""

    name = "invert"
|
@@ -0,0 +1,23 @@
|
|
1
|
+
from kirin import types, interp
|
2
|
+
|
3
|
+
from . import stmts
|
4
|
+
from ._dialect import dialect
|
5
|
+
|
6
|
+
|
7
|
+
@dialect.register(key="typeinfer")
class TypeInfer(interp.MethodTable):
    """Type-inference rules for the Python unary operators."""

    @interp.impl(stmts.UAdd)
    @interp.impl(stmts.USub)
    def uadd(
        self, interp, frame: interp.Frame[types.TypeAttribute], stmt: stmts.UnaryOp
    ):
        """`+x` / `-x` keep the operand type, except that booleans become ints.

        In Python `+True == 1` and `-True == -1` are `int`s, not `bool`s, so
        reporting `Bool` for these results would be unsound.
        """
        operand = frame.get(stmt.value)
        # Exclude `Bottom` (a subtype of everything) so unreachable operands
        # stay `Bottom` rather than becoming `Int`.
        if operand.is_subseteq(types.Bool) and not operand.is_subseteq(
            types.Bottom
        ):
            return (types.Int,)
        return (operand,)

    @interp.impl(stmts.Not)
    def not_(self, interp, frame, stmt: stmts.Not):
        # `not x` is always a bool.
        return (types.Bool,)

    @interp.impl(stmts.Invert, types.Int)
    def invert(self, interp, frame, stmt):
        # `~x` on an int is an int.
        return (types.Int,)
|