kirin-toolchain 0.13.0__py3-none-any.whl
Sign up to get free protection for your applications and to get access to all the features.
- kirin/__init__.py +7 -0
- kirin/analysis/__init__.py +24 -0
- kirin/analysis/callgraph.py +61 -0
- kirin/analysis/cfg.py +112 -0
- kirin/analysis/const/__init__.py +20 -0
- kirin/analysis/const/_visitor.py +2 -0
- kirin/analysis/const/_visitor.pyi +8 -0
- kirin/analysis/const/lattice.py +219 -0
- kirin/analysis/const/prop.py +116 -0
- kirin/analysis/forward.py +100 -0
- kirin/analysis/typeinfer/__init__.py +5 -0
- kirin/analysis/typeinfer/analysis.py +90 -0
- kirin/analysis/typeinfer/solve.py +141 -0
- kirin/decl/__init__.py +108 -0
- kirin/decl/base.py +65 -0
- kirin/decl/camel2snake.py +2 -0
- kirin/decl/emit/__init__.py +0 -0
- kirin/decl/emit/_create_fn.py +29 -0
- kirin/decl/emit/_set_new_attribute.py +22 -0
- kirin/decl/emit/dialect.py +8 -0
- kirin/decl/emit/init.py +277 -0
- kirin/decl/emit/name.py +10 -0
- kirin/decl/emit/property.py +182 -0
- kirin/decl/emit/repr.py +31 -0
- kirin/decl/emit/traits.py +13 -0
- kirin/decl/emit/typecheck.py +77 -0
- kirin/decl/emit/verify.py +51 -0
- kirin/decl/info.py +346 -0
- kirin/decl/scan_fields.py +157 -0
- kirin/decl/verify.py +69 -0
- kirin/dialects/__init__.py +14 -0
- kirin/dialects/_pprint_helper.py +53 -0
- kirin/dialects/cf/__init__.py +20 -0
- kirin/dialects/cf/constprop.py +51 -0
- kirin/dialects/cf/dialect.py +3 -0
- kirin/dialects/cf/emit.py +58 -0
- kirin/dialects/cf/interp.py +24 -0
- kirin/dialects/cf/stmts.py +68 -0
- kirin/dialects/cf/typeinfer.py +27 -0
- kirin/dialects/eltype.py +23 -0
- kirin/dialects/func/__init__.py +20 -0
- kirin/dialects/func/attrs.py +39 -0
- kirin/dialects/func/constprop.py +138 -0
- kirin/dialects/func/dialect.py +3 -0
- kirin/dialects/func/emit.py +80 -0
- kirin/dialects/func/interp.py +68 -0
- kirin/dialects/func/stmts.py +233 -0
- kirin/dialects/func/typeinfer.py +124 -0
- kirin/dialects/ilist/__init__.py +33 -0
- kirin/dialects/ilist/_dialect.py +3 -0
- kirin/dialects/ilist/_wrapper.py +51 -0
- kirin/dialects/ilist/interp.py +85 -0
- kirin/dialects/ilist/lowering.py +25 -0
- kirin/dialects/ilist/passes.py +32 -0
- kirin/dialects/ilist/rewrite/__init__.py +3 -0
- kirin/dialects/ilist/rewrite/const.py +45 -0
- kirin/dialects/ilist/rewrite/list.py +38 -0
- kirin/dialects/ilist/rewrite/unroll.py +131 -0
- kirin/dialects/ilist/runtime.py +63 -0
- kirin/dialects/ilist/stmts.py +102 -0
- kirin/dialects/ilist/typeinfer.py +120 -0
- kirin/dialects/lowering/__init__.py +7 -0
- kirin/dialects/lowering/call.py +48 -0
- kirin/dialects/lowering/cf.py +206 -0
- kirin/dialects/lowering/func.py +134 -0
- kirin/dialects/math/__init__.py +41 -0
- kirin/dialects/math/_gen.py +176 -0
- kirin/dialects/math/dialect.py +3 -0
- kirin/dialects/math/interp.py +190 -0
- kirin/dialects/math/stmts.py +369 -0
- kirin/dialects/module.py +139 -0
- kirin/dialects/py/__init__.py +40 -0
- kirin/dialects/py/assertion.py +91 -0
- kirin/dialects/py/assign.py +103 -0
- kirin/dialects/py/attr.py +59 -0
- kirin/dialects/py/base.py +34 -0
- kirin/dialects/py/binop/__init__.py +23 -0
- kirin/dialects/py/binop/_dialect.py +3 -0
- kirin/dialects/py/binop/interp.py +60 -0
- kirin/dialects/py/binop/julia.py +33 -0
- kirin/dialects/py/binop/lowering.py +22 -0
- kirin/dialects/py/binop/stmts.py +79 -0
- kirin/dialects/py/binop/typeinfer.py +108 -0
- kirin/dialects/py/boolop.py +84 -0
- kirin/dialects/py/builtin.py +78 -0
- kirin/dialects/py/cmp/__init__.py +16 -0
- kirin/dialects/py/cmp/_dialect.py +3 -0
- kirin/dialects/py/cmp/interp.py +48 -0
- kirin/dialects/py/cmp/julia.py +33 -0
- kirin/dialects/py/cmp/lowering.py +45 -0
- kirin/dialects/py/cmp/stmts.py +62 -0
- kirin/dialects/py/constant.py +79 -0
- kirin/dialects/py/indexing.py +251 -0
- kirin/dialects/py/iterable.py +90 -0
- kirin/dialects/py/len.py +57 -0
- kirin/dialects/py/list/__init__.py +15 -0
- kirin/dialects/py/list/_dialect.py +3 -0
- kirin/dialects/py/list/interp.py +21 -0
- kirin/dialects/py/list/lowering.py +25 -0
- kirin/dialects/py/list/stmts.py +22 -0
- kirin/dialects/py/list/typeinfer.py +54 -0
- kirin/dialects/py/range.py +76 -0
- kirin/dialects/py/slice.py +120 -0
- kirin/dialects/py/tuple.py +109 -0
- kirin/dialects/py/unary/__init__.py +24 -0
- kirin/dialects/py/unary/_dialect.py +3 -0
- kirin/dialects/py/unary/constprop.py +20 -0
- kirin/dialects/py/unary/interp.py +24 -0
- kirin/dialects/py/unary/julia.py +21 -0
- kirin/dialects/py/unary/lowering.py +22 -0
- kirin/dialects/py/unary/stmts.py +33 -0
- kirin/dialects/py/unary/typeinfer.py +23 -0
- kirin/dialects/py/unpack.py +90 -0
- kirin/dialects/scf/__init__.py +23 -0
- kirin/dialects/scf/_dialect.py +3 -0
- kirin/dialects/scf/absint.py +64 -0
- kirin/dialects/scf/constprop.py +140 -0
- kirin/dialects/scf/interp.py +35 -0
- kirin/dialects/scf/lowering.py +123 -0
- kirin/dialects/scf/stmts.py +250 -0
- kirin/dialects/scf/trim.py +36 -0
- kirin/dialects/scf/typeinfer.py +58 -0
- kirin/dialects/scf/unroll.py +92 -0
- kirin/emit/__init__.py +3 -0
- kirin/emit/abc.py +89 -0
- kirin/emit/abc.pyi +38 -0
- kirin/emit/exceptions.py +5 -0
- kirin/emit/julia.py +63 -0
- kirin/emit/str.py +51 -0
- kirin/exceptions.py +59 -0
- kirin/graph.py +34 -0
- kirin/idtable.py +57 -0
- kirin/interp/__init__.py +39 -0
- kirin/interp/abstract.py +253 -0
- kirin/interp/base.py +438 -0
- kirin/interp/concrete.py +62 -0
- kirin/interp/exceptions.py +26 -0
- kirin/interp/frame.py +151 -0
- kirin/interp/impl.py +197 -0
- kirin/interp/result.py +93 -0
- kirin/interp/state.py +71 -0
- kirin/interp/table.py +40 -0
- kirin/interp/value.py +73 -0
- kirin/ir/__init__.py +46 -0
- kirin/ir/attrs/__init__.py +20 -0
- kirin/ir/attrs/_types.py +8 -0
- kirin/ir/attrs/_types.pyi +13 -0
- kirin/ir/attrs/abc.py +46 -0
- kirin/ir/attrs/py.py +45 -0
- kirin/ir/attrs/types.py +522 -0
- kirin/ir/dialect.py +125 -0
- kirin/ir/group.py +249 -0
- kirin/ir/method.py +118 -0
- kirin/ir/nodes/__init__.py +7 -0
- kirin/ir/nodes/base.py +149 -0
- kirin/ir/nodes/block.py +458 -0
- kirin/ir/nodes/region.py +337 -0
- kirin/ir/nodes/stmt.py +713 -0
- kirin/ir/nodes/view.py +142 -0
- kirin/ir/ssa.py +204 -0
- kirin/ir/traits/__init__.py +36 -0
- kirin/ir/traits/abc.py +42 -0
- kirin/ir/traits/basic.py +78 -0
- kirin/ir/traits/callable.py +51 -0
- kirin/ir/traits/lowering/__init__.py +2 -0
- kirin/ir/traits/lowering/call.py +37 -0
- kirin/ir/traits/lowering/context.py +120 -0
- kirin/ir/traits/region/__init__.py +2 -0
- kirin/ir/traits/region/ssacfg.py +22 -0
- kirin/ir/traits/symbol.py +57 -0
- kirin/ir/use.py +17 -0
- kirin/lattice/__init__.py +13 -0
- kirin/lattice/abc.py +128 -0
- kirin/lattice/empty.py +25 -0
- kirin/lattice/mixin.py +51 -0
- kirin/lowering/__init__.py +7 -0
- kirin/lowering/binding.py +65 -0
- kirin/lowering/core.py +72 -0
- kirin/lowering/dialect.py +35 -0
- kirin/lowering/dialect.pyi +183 -0
- kirin/lowering/frame.py +171 -0
- kirin/lowering/result.py +68 -0
- kirin/lowering/state.py +441 -0
- kirin/lowering/stream.py +53 -0
- kirin/passes/__init__.py +3 -0
- kirin/passes/abc.py +44 -0
- kirin/passes/aggressive/__init__.py +1 -0
- kirin/passes/aggressive/fold.py +43 -0
- kirin/passes/fold.py +45 -0
- kirin/passes/inline.py +25 -0
- kirin/passes/typeinfer.py +25 -0
- kirin/prelude.py +197 -0
- kirin/print/__init__.py +15 -0
- kirin/print/printable.py +141 -0
- kirin/print/printer.py +415 -0
- kirin/py.typed +0 -0
- kirin/registry.py +105 -0
- kirin/registry.pyi +52 -0
- kirin/rewrite/__init__.py +14 -0
- kirin/rewrite/abc.py +43 -0
- kirin/rewrite/aggressive/__init__.py +1 -0
- kirin/rewrite/aggressive/fold.py +43 -0
- kirin/rewrite/alias.py +16 -0
- kirin/rewrite/apply_type.py +47 -0
- kirin/rewrite/call2invoke.py +34 -0
- kirin/rewrite/chain.py +39 -0
- kirin/rewrite/compactify.py +288 -0
- kirin/rewrite/cse.py +48 -0
- kirin/rewrite/dce.py +19 -0
- kirin/rewrite/fixpoint.py +34 -0
- kirin/rewrite/fold.py +57 -0
- kirin/rewrite/getfield.py +21 -0
- kirin/rewrite/getitem.py +37 -0
- kirin/rewrite/inline.py +143 -0
- kirin/rewrite/result.py +15 -0
- kirin/rewrite/walk.py +83 -0
- kirin/rewrite/wrap_const.py +55 -0
- kirin/source.py +21 -0
- kirin/symbol_table.py +27 -0
- kirin/types.py +34 -0
- kirin/worklist.py +30 -0
- kirin_toolchain-0.13.0.dist-info/METADATA +42 -0
- kirin_toolchain-0.13.0.dist-info/RECORD +225 -0
- kirin_toolchain-0.13.0.dist-info/WHEEL +4 -0
- kirin_toolchain-0.13.0.dist-info/licenses/LICENSE +234 -0
@@ -0,0 +1,120 @@
|
|
1
|
+
"""Traits for customizing lowering of Python `with` syntax to a statement.
|
2
|
+
"""
|
3
|
+
|
4
|
+
import ast
|
5
|
+
from typing import TYPE_CHECKING, TypeVar
|
6
|
+
from dataclasses import dataclass
|
7
|
+
|
8
|
+
from kirin.exceptions import DialectLoweringError
|
9
|
+
|
10
|
+
from ..abc import PythonLoweringTrait
|
11
|
+
|
12
|
+
if TYPE_CHECKING:
|
13
|
+
from kirin.ir import Statement
|
14
|
+
from kirin.lowering import Result, LoweringState
|
15
|
+
|
16
|
+
StatementType = TypeVar("StatementType", bound="Statement")
|
17
|
+
|
18
|
+
|
19
|
+
@dataclass(frozen=True)
class FromPythonWith(PythonLoweringTrait[StatementType, ast.With]):
    """Base trait for statements that are lowered from a Python `with` statement.

    Derive from this trait to control how an `ast.With` node is turned into
    the statement. A subclass implements `lower`, which reads the arguments
    out of the Python `with` statement and builds the statement instance.
    """
|
29
|
+
|
30
|
+
|
31
|
+
@dataclass(frozen=True)
class FromPythonWithSingleItem(FromPythonWith[StatementType]):
    """Trait for customizing lowering of the following Python with syntax to a statement:

    ```python
    with <stmt>[ as <name>]:
        <body>
    ```

    where `<stmt>` is the statement being lowered, `<name>` is an optional name for the result
    of the statement, and `<body>` is the body of the with statement. The optional `as <name>`
    is not valid when the statement has no results.

    This syntax is slightly different from the standard Python `with` statement in that
    `<name>` refers to the result of the statement, not the context manager. Thus typically
    one should access `<name>` in `<body>` to use the result of the statement.

    In some cases, however, `<name>` may be used as a reference of a special value `self` that
    is passed to the `<body>` of the statement. This is useful for statements that have a similar
    behavior to a closure.
    """

    def lower(
        self, stmt: type[StatementType], state: "LoweringState", node: ast.With
    ) -> "Result":
        """Lower a single-item `with` node into an instance of `stmt`.

        The body of the `with` is lowered into its own frame and the resulting
        region is passed to the statement constructor under the name of the
        statement's only declared region. The call in the `with` item supplies
        the remaining constructor inputs.

        Args:
            stmt: the statement class to instantiate.
            state: the current lowering state.
            node: the `ast.With` node being lowered.

        Returns:
            An empty `lowering.Result` if the statement has no results,
            otherwise a `lowering.Result` wrapping its single result.

        Raises:
            DialectLoweringError: if the statement does not declare exactly one
                region, the `with` does not have exactly one item, the item is
                not a call, a non-`multi` region ends up with more than one
                block, or the statement produces more than one result.
        """
        from kirin import ir, lowering
        from kirin.decl import fields
        from kirin.dialects import cf

        fs = fields(stmt)
        if len(fs.regions) != 1:
            raise DialectLoweringError(
                "Expected exactly one region in statement declaration"
            )

        if len(node.items) != 1:
            raise DialectLoweringError("Expected exactly one item in statement")

        item, body = node.items[0], node.body
        if not isinstance(item.context_expr, ast.Call):
            raise DialectLoweringError(
                f"Expected context expression to be a call for with {stmt.name}"
            )

        # The body is lowered in a child frame of the current frame before the
        # statement itself is constructed.
        body_frame = lowering.Frame.from_stmts(body, state, parent=state.current_frame)
        state.push_frame(body_frame)
        state.exhaust()
        region_name, region_info = next(iter(fs.regions.items()))
        if region_info.multi:  # branch to exit block if not terminated
            for block in body_frame.curr_region.blocks:
                if block.last_stmt is None or not block.last_stmt.has_trait(
                    ir.IsTerminator
                ):
                    block.stmts.append(
                        cf.Branch(arguments=(), successor=body_frame.next_block)
                    )
            state.pop_frame()
        else:
            if len(body_frame.curr_region.blocks) != 1:
                raise DialectLoweringError(
                    f"Expected exactly one block in region {region_name}"
                )
            state.pop_frame(finalize_next=False)

        args, kwargs = state.default_Call_inputs(stmt, item.context_expr)
        kwargs[region_name] = body_frame.curr_region
        results = state.append_stmt(stmt(*args.values(), **kwargs)).results
        if len(results) == 0:
            return lowering.Result()
        elif len(results) > 1:
            raise DialectLoweringError(
                f"Expected exactly one result or no result from statement {stmt.name}"
            )

        result = results[0]
        # `isinstance` is already False for None, so no separate None check is
        # needed. Targets that are not a plain name are left unbound.
        if isinstance(item.optional_vars, ast.Name):
            result.name = item.optional_vars.id
            state.current_frame.defs[result.name] = result
        return lowering.Result(result)

    def verify(self, stmt: "Statement"):
        """Check the structural requirements of a statement carrying this trait.

        Raises:
            AssertionError: if the statement does not have exactly one region,
                has any successors, or has more than one result.
        """
        # Raise explicitly rather than using `assert`, so the checks are not
        # stripped when Python runs with -O. The exception type is unchanged.
        if len(stmt.regions) != 1:
            raise AssertionError(
                "FromPythonWithSingleItem statements must have one region"
            )
        if len(stmt.successors) != 0:
            raise AssertionError(
                "FromPythonWithSingleItem statements cannot have successors"
            )
        if len(stmt.results) > 1:
            raise AssertionError(
                "FromPythonWithSingleItem statements can have at most one result"
            )
|
@@ -0,0 +1,22 @@
|
|
1
|
+
"""SSACFG region trait.
|
2
|
+
|
3
|
+
This module defines the SSACFGRegion trait, which is used to indicate that a
|
4
|
+
region has an SSACFG graph.
|
5
|
+
"""
|
6
|
+
|
7
|
+
from typing import TYPE_CHECKING
|
8
|
+
from dataclasses import dataclass
|
9
|
+
|
10
|
+
from kirin.ir.traits.abc import RegionTrait
|
11
|
+
|
12
|
+
if TYPE_CHECKING:
|
13
|
+
from kirin.ir import Region
|
14
|
+
|
15
|
+
|
16
|
+
@dataclass(frozen=True)
class SSACFGRegion(RegionTrait):
    """Region trait whose `get_graph` builds a `CFG` for a region."""

    def get_graph(self, region: "Region"):
        """Return a `kirin.analysis.cfg.CFG` constructed from `region`."""
        # Imported at call time rather than at module import time
        # (presumably to avoid an import cycle -- TODO confirm).
        from kirin.analysis.cfg import CFG

        graph = CFG(region)
        return graph
|
@@ -0,0 +1,57 @@
|
|
1
|
+
from typing import TYPE_CHECKING
|
2
|
+
from dataclasses import dataclass
|
3
|
+
|
4
|
+
from kirin.exceptions import VerificationError
|
5
|
+
from kirin.ir.attrs.py import PyAttr
|
6
|
+
from kirin.ir.traits.abc import StmtTrait
|
7
|
+
|
8
|
+
if TYPE_CHECKING:
|
9
|
+
from kirin.ir import Statement
|
10
|
+
|
11
|
+
|
12
|
+
@dataclass(frozen=True)
class SymbolOpInterface(StmtTrait):
    """Trait marking a statement as a symbol operation.

    A symbol operation carries a symbol name, stored under `sym_name` as an
    attribute or property of the statement.
    """

    def get_sym_name(self, stmt: "Statement") -> "PyAttr[str]":
        """Return the `sym_name` attribute/property of `stmt`.

        Raises:
            ValueError: if the statement has no `sym_name`.
        """
        # NOTE: unlike MLIR or xDSL we do not allow empty symbol names
        found: PyAttr[str] | None = stmt.get_attr_or_prop("sym_name")  # type: ignore
        if found is not None:
            return found
        raise ValueError(f"Statement {stmt.name} does not have a symbol name")

    def verify(self, stmt: "Statement"):
        """Check that the symbol name of `stmt` is a string-typed `PyAttr`.

        Raises:
            ValueError: if the symbol name is missing, is not a `PyAttr`, or
                its type is not a subset of `kirin.types.String`.
        """
        from kirin.types import String

        name_attr = self.get_sym_name(stmt)
        if not isinstance(name_attr, PyAttr) or not name_attr.type.is_subseteq(String):
            raise ValueError(f"Symbol name {name_attr} is not a string attribute")
|
32
|
+
|
33
|
+
|
34
|
+
@dataclass(frozen=True)
class SymbolTable(StmtTrait):
    """
    Trait for statements that hold exactly one region containing exactly one block.
    """

    @staticmethod
    def walk(stmt: "Statement"):
        """Return the statements of the first block of the first region of `stmt`."""
        first_region = stmt.regions[0]
        first_block = first_region.blocks[0]
        return first_block.stmts

    def verify(self, stmt: "Statement"):
        """Check the one-region / one-block shape required by this trait.

        Raises:
            VerificationError: if `stmt` does not have exactly one region, or
                that region does not have exactly one block.
        """
        region_count = len(stmt.regions)
        if region_count != 1:
            raise VerificationError(
                stmt,
                f"Statement {stmt.name} with SymbolTable trait must have exactly one region",
            )

        block_count = len(stmt.regions[0].blocks)
        if block_count != 1:
            raise VerificationError(
                stmt,
                f"Statement {stmt.name} with SymbolTable trait must have exactly one block",
            )

        # TODO: check uniqueness of symbol names
|
kirin/ir/use.py
ADDED
@@ -0,0 +1,17 @@
|
|
1
|
+
from __future__ import annotations
|
2
|
+
|
3
|
+
from typing import TYPE_CHECKING
|
4
|
+
from dataclasses import dataclass
|
5
|
+
|
6
|
+
if TYPE_CHECKING:
|
7
|
+
from kirin.ir.nodes.stmt import Statement
|
8
|
+
|
9
|
+
|
10
|
+
@dataclass(frozen=True)
class Use:
    """A use of an SSA value in a statement.

    Immutable record pairing the using statement with an integer index.
    """

    stmt: Statement
    """The statement that uses the SSA value."""
    index: int
    """The index of the use in the statement."""
|
@@ -0,0 +1,13 @@
|
|
1
|
+
from kirin.lattice.abc import (
|
2
|
+
Lattice as Lattice,
|
3
|
+
UnionMeta as UnionMeta,
|
4
|
+
LatticeMeta as LatticeMeta,
|
5
|
+
SingletonMeta as SingletonMeta,
|
6
|
+
BoundedLattice as BoundedLattice,
|
7
|
+
)
|
8
|
+
from kirin.lattice.empty import EmptyLattice as EmptyLattice
|
9
|
+
from kirin.lattice.mixin import (
|
10
|
+
IsSubsetEqMixin as IsSubsetEqMixin,
|
11
|
+
SimpleJoinMixin as SimpleJoinMixin,
|
12
|
+
SimpleMeetMixin as SimpleMeetMixin,
|
13
|
+
)
|
kirin/lattice/abc.py
ADDED
@@ -0,0 +1,128 @@
|
|
1
|
+
from abc import ABC, ABCMeta, abstractmethod
|
2
|
+
from typing import Generic, TypeVar, Iterable
|
3
|
+
|
4
|
+
|
5
|
+
class LatticeMeta(ABCMeta):
    """Common metaclass for lattice classes.

    Adds nothing to `ABCMeta` itself; it is the shared base of the more
    specific lattice metaclasses.
    """
|
7
|
+
|
8
|
+
|
9
|
+
class SingletonMeta(LatticeMeta):
    """
    Metaclass that makes a lattice class a singleton: calling the class always
    yields the same instance, which is created on the first call.

    See https://stackoverflow.com/questions/674304/why-is-init-always-called-after-new/8665179#8665179
    """

    def __init__(cls, name, bases, attrs):
        super().__init__(name, bases, attrs)
        # Each class built by this metaclass starts without a cached instance.
        cls._instance = None

    def __call__(cls):
        cached = cls._instance
        if cached is None:
            # First call: build the instance through the normal path and cache it.
            cached = cls._instance = super().__call__()
        return cached
|
24
|
+
|
25
|
+
|
26
|
+
LatticeType = TypeVar("LatticeType", bound="Lattice")
|
27
|
+
|
28
|
+
|
29
|
+
class Lattice(ABC, Generic[LatticeType], metaclass=LatticeMeta):
    """ABC for lattices as Python class.

    Subclasses implement `join`, `meet` and `is_subseteq`. `is_equal` and
    `is_subset` are derived from `is_subseteq`. `==` and `hash()` raise
    `NotImplementedError` unless a subclass overrides them; use `is_equal`
    to compare lattice elements.
    """

    @abstractmethod
    def join(self, other: LatticeType) -> LatticeType:
        """Join operation."""
        ...

    @abstractmethod
    def meet(self, other: LatticeType) -> LatticeType:
        """Meet operation."""
        ...

    @abstractmethod
    def is_subseteq(self, other: LatticeType) -> bool:
        """Subseteq operation."""
        ...

    def is_equal(self, other: LatticeType) -> bool:
        """Check if two lattices are equal.

        Identical objects are equal; otherwise both must be subseteq of
        each other.
        """
        return self is other or (
            self.is_subseteq(other) and other.is_subseteq(self)
        )

    def is_subset(self, other: LatticeType) -> bool:
        """Check that `self` is subseteq of `other` but not the other way round."""
        if not self.is_subseteq(other):
            return False
        return not other.is_subseteq(self)

    def __eq__(self, value: object) -> bool:
        raise NotImplementedError(
            "Equality is not implemented for lattices, use is_equal instead"
        )

    def __hash__(self) -> int:
        raise NotImplementedError("Hash is not implemented for lattices")
|
69
|
+
|
70
|
+
|
71
|
+
BoundedLatticeType = TypeVar("BoundedLatticeType", bound="BoundedLattice")
|
72
|
+
|
73
|
+
|
74
|
+
class BoundedLattice(Lattice[BoundedLatticeType]):
    """ABC for bounded lattices as Python class.

    On top of the `Lattice` operations, a subclass must provide the `bottom`
    and `top` class methods.
    """

    @classmethod
    @abstractmethod
    def bottom(cls) -> BoundedLatticeType:
        """Return the bottom element of the lattice."""
        ...

    @classmethod
    @abstractmethod
    def top(cls) -> BoundedLatticeType:
        """Return the top element of the lattice."""
        ...
|
88
|
+
|
89
|
+
|
90
|
+
class UnionMeta(LatticeMeta):
    """Meta class for union types. It simplifies the union if possible."""

    def __call__(
        self,
        typ: Iterable[LatticeType] | LatticeType,
        *others: LatticeType,
    ):
        """Build a union from its members, dropping redundant members first.

        Members may be given either as one iterable or as variadic arguments.
        A member that is `is_subseteq` of another member is dropped. If a
        single member remains it is returned directly instead of a union.

        Raises:
            ValueError: if the first argument is not a `Lattice` and further
                positional arguments are given as well.
        """
        from kirin.lattice.abc import Lattice

        if isinstance(typ, Lattice):
            typs: Iterable[LatticeType] = (typ, *others)
        elif not others:
            typs = typ
        else:
            raise ValueError(
                "Expected an iterable of types or variadic arguments of types"
            )

        # try if the union can be simplified
        params: list[LatticeType] = []
        for candidate in typs:
            # Already covered by a member we keep: nothing to add.
            if any(candidate.is_subseteq(kept) for kept in params):
                continue

            # The candidate may cover several kept members, not just the first
            # one found. It takes the place of the first covered member and
            # the remaining covered members are dropped.
            merged: list[LatticeType] = []
            placed = False
            for kept in params:
                if not kept.is_subseteq(candidate):
                    merged.append(kept)
                elif not placed:
                    merged.append(candidate)
                    placed = True
            if not placed:
                merged.append(candidate)
            params = merged

        if len(params) == 1:
            return params[0]

        return super(UnionMeta, self).__call__(*params)
|
kirin/lattice/empty.py
ADDED
@@ -0,0 +1,25 @@
|
|
1
|
+
from kirin.lattice.abc import SingletonMeta, BoundedLattice
|
2
|
+
|
3
|
+
|
4
|
+
class EmptyLattice(BoundedLattice["EmptyLattice"], metaclass=SingletonMeta):
    """Empty lattice.

    Uses `SingletonMeta`, so there is a single instance; it is its own
    bottom and top, and every operation returns it.
    """

    @classmethod
    def bottom(cls):
        return cls()

    @classmethod
    def top(cls):
        return cls()

    def is_subseteq(self, other: "EmptyLattice") -> bool:
        return True

    def join(self, other: "EmptyLattice") -> "EmptyLattice":
        return self

    def meet(self, other: "EmptyLattice") -> "EmptyLattice":
        return self

    def __hash__(self) -> int:
        # Identity-based hash, replacing the raising `Lattice.__hash__`.
        return id(self)
|
kirin/lattice/mixin.py
ADDED
@@ -0,0 +1,51 @@
|
|
1
|
+
from typing import TypeVar
|
2
|
+
|
3
|
+
from .abc import BoundedLattice
|
4
|
+
|
5
|
+
BoundedLatticeType = TypeVar("BoundedLatticeType", bound="BoundedLattice")
|
6
|
+
|
7
|
+
|
8
|
+
class IsSubsetEqMixin(BoundedLattice[BoundedLatticeType]):
    """Mixin supplying a default `is_subseteq` that dispatches on the class
    name of the other element (visitor pattern). Handy when a lattice has many
    subclasses that need to be compared with each other.

    Dispatch looks up `is_subseteq_<OtherClassName>` on `self`, then
    `is_subseteq_fallback`; with neither present the answer is `False`.

    Must be used before `BoundedLattice` in the inheritance chain.
    """

    def is_subseteq(self, other: BoundedLatticeType) -> bool:
        # Top and bottom are recognised by identity before any dispatch.
        if other is self.top():
            return True
        if other is self.bottom():
            return False

        fallback = getattr(self, "is_subseteq_fallback", None)
        handler = getattr(self, "is_subseteq_" + other.__class__.__name__, fallback)
        if handler is None:
            return False
        return handler(other)
|
30
|
+
|
31
|
+
|
32
|
+
class SimpleJoinMixin(BoundedLattice[BoundedLatticeType]):
    """Mixin with a simple `join` built on `is_subseteq`."""

    def join(self, other: BoundedLatticeType) -> BoundedLatticeType:
        """Return the larger operand when they are comparable, else `top()`."""
        if self.is_subseteq(other):
            return other
        if other.is_subseteq(self):
            return self  # type: ignore
        return self.top()
|
41
|
+
|
42
|
+
|
43
|
+
class SimpleMeetMixin(BoundedLattice[BoundedLatticeType]):
    """Mixin with a simple `meet` built on `is_subseteq`."""

    def meet(self, other: BoundedLatticeType) -> BoundedLatticeType:
        """Return the smaller operand when they are comparable, else `bottom()`."""
        if self.is_subseteq(other):
            return self  # type: ignore
        if other.is_subseteq(self):
            return other
        return self.bottom()
|
@@ -0,0 +1,7 @@
|
|
1
|
+
from kirin.lowering.core import Lowering as Lowering
|
2
|
+
from kirin.lowering.frame import Frame as Frame
|
3
|
+
from kirin.lowering.state import LoweringState as LoweringState
|
4
|
+
from kirin.lowering.result import Result as Result
|
5
|
+
from kirin.lowering.stream import StmtStream as StmtStream
|
6
|
+
from kirin.lowering.binding import wraps as wraps
|
7
|
+
from kirin.lowering.dialect import FromPythonAST as FromPythonAST
|
@@ -0,0 +1,65 @@
|
|
1
|
+
from typing import TYPE_CHECKING, Generic, TypeVar, Callable, ParamSpec
|
2
|
+
from dataclasses import dataclass
|
3
|
+
|
4
|
+
if TYPE_CHECKING:
|
5
|
+
from kirin.ir.nodes.stmt import Statement
|
6
|
+
|
7
|
+
Params = ParamSpec("Params")
|
8
|
+
RetType = TypeVar("RetType")
|
9
|
+
|
10
|
+
|
11
|
+
@dataclass(frozen=True)
|
12
|
+
class Binding(Generic[Params, RetType]):
|
13
|
+
parent: type["Statement"]
|
14
|
+
|
15
|
+
def __call__(self, *args: Params.args, **kwargs: Params.kwargs) -> RetType:
|
16
|
+
raise NotImplementedError(
|
17
|
+
f"Binding of {self.parent.name} can \
|
18
|
+
only be called from a kernel"
|
19
|
+
)
|
20
|
+
|
21
|
+
|
22
|
+
def wraps(parent: type["Statement"]):
    """Wrap a [`Statement`][kirin.ir.nodes.stmt.Statement] in a `Binding` object
    that is special cased in the lowering process.

    The decorated function is used only for its call signature: it lets a
    [`Statement`][kirin.ir.nodes.stmt.Statement] be type-hinted as if it were
    an ordinary function. The function itself is discarded.

    ## Example

    Writing the statement directly inside a kernel makes a Python linter think
    you are calling the constructor of the statement class. In the context of a
    kernel the intention is to "call" the statement, so the following produces
    type errors with pyright or mypy:

    ```python
    from kirin.dialects import math
    from kirin.prelude import basic_no_opt

    @basic_no_opt
    def main(x: float):
        return math.sin(x) # this is a statement, not a function
    ```

    The `@lowering.wraps` decorator lets us provide a type hint for the
    statement instead, e.g.:

    ```python
    from kirin import lowering

    @lowering.wraps(math.sin)
    def sin(value: float) -> float: ...

    @basic_no_opt
    def main(x: float):
        return sin(x) # linter now thinks this is a function

    sin(1.0) # this will raise a NotImplementedError("Binding of sin can only be called from a kernel")
    ```
    """

    def bind(func: Callable[Params, RetType]) -> Binding[Params, RetType]:
        # `func` contributes only its signature to the type checker.
        binding = Binding(parent)
        return binding

    return bind
|
kirin/lowering/core.py
ADDED
@@ -0,0 +1,72 @@
|
|
1
|
+
import ast
|
2
|
+
import inspect
|
3
|
+
import textwrap
|
4
|
+
from types import ModuleType
|
5
|
+
from typing import Any, Callable, Iterable
|
6
|
+
from dataclasses import dataclass
|
7
|
+
|
8
|
+
from kirin.ir import Dialect, DialectGroup
|
9
|
+
from kirin.exceptions import DialectLoweringError
|
10
|
+
from kirin.lowering.state import LoweringState
|
11
|
+
from kirin.lowering.dialect import FromPythonAST
|
12
|
+
|
13
|
+
|
14
|
+
@dataclass(init=False)
class Lowering(ast.NodeVisitor):
    """Driver that lowers Python source (an AST statement or a function) to IR.

    It holds the dialect group and the lowering registry obtained from that
    group, and creates a fresh `LoweringState` for each `run`.
    """

    dialects: DialectGroup
    registry: dict[str, FromPythonAST]
    state: LoweringState | None = None

    # max lines to show in error hint
    max_lines: int = 3

    def __init__(
        self,
        dialects: DialectGroup | Iterable[Dialect | ModuleType],
        keys: list[str] | None = None,
        max_lines: int = 3,
    ):
        """Create a lowering driver.

        Args:
            dialects: a `DialectGroup`, or an iterable that is passed to
                `DialectGroup(...)`.
            keys: keys used to look up the lowering registry; defaults to
                `["main", "default"]` when `None` or empty.
            max_lines: maximum number of lines to show in an error hint.
        """
        if isinstance(dialects, DialectGroup):
            self.dialects = dialects
        else:
            self.dialects = DialectGroup(dialects)

        self.max_lines = max_lines
        self.registry: dict[str, FromPythonAST] = self.dialects.registry.ast(
            keys=keys or ["main", "default"]
        )
        self.state = None

    def run(
        self,
        stmt: ast.stmt | Callable,
        source: str | None = None,
        globals: dict[str, Any] | None = None,
        lineno_offset: int = 0,
        col_offset: int = 0,
        compactify: bool = True,
    ):
        """Lower `stmt` and return the resulting code.

        Args:
            stmt: an AST statement, or a callable whose source is fetched with
                `inspect.getsource` and parsed.
            source: source text; for a callable it defaults to the dedented
                source of that callable.
            globals: global names for lowering; for a callable it defaults to
                the callable's `__globals__`. The callable's closure variables
                are merged on top.
            lineno_offset: line offset forwarded to `LoweringState.from_stmt`.
            col_offset: column offset forwarded to `LoweringState.from_stmt`.
            compactify: when true, run `Walk(CFGCompactify())` on the result.

        Returns:
            `state.code` of the lowering state created for this run.

        Raises:
            DialectLoweringError: re-raised from lowering with the state's
                error hint appended to its message.
        """
        if callable(stmt):
            source = source or textwrap.dedent(inspect.getsource(stmt))
            base_globals = globals or stmt.__globals__
            try:
                nonlocals = inspect.getclosurevars(stmt).nonlocals
            except Exception:
                # Best effort: lower without closure variables if they cannot
                # be inspected.
                nonlocals = {}
            # Merge into a copy. Updating `base_globals` in place would write
            # the closure variables into the function's live module globals
            # (or into the caller's dict).
            globals = {**base_globals, **nonlocals}
            stmt = ast.parse(source).body[0]

        state = LoweringState.from_stmt(
            self, stmt, source, globals, self.max_lines, lineno_offset, col_offset
        )
        try:
            state.visit(stmt)
        except DialectLoweringError as e:
            hint = state.error_hint()
            if e.args:
                e.args = (f"{e.args[0]}\n\n{hint}",) + e.args[1:]
            else:
                # No message to extend; indexing args[0] would raise
                # IndexError and hide the lowering error.
                e.args = (hint,)
            raise

        if compactify:
            from kirin.rewrite import Walk, CFGCompactify

            Walk(CFGCompactify()).rewrite(state.code)
        return state.code
|
@@ -0,0 +1,35 @@
|
|
1
|
+
# NOTE: this module only defines an interface and is used inside the `ir`
# module, so keep its dependencies to a minimum.
|
4
|
+
|
5
|
+
from __future__ import annotations
|
6
|
+
|
7
|
+
import ast
|
8
|
+
from abc import ABC
|
9
|
+
from typing import TYPE_CHECKING
|
10
|
+
|
11
|
+
from kirin.exceptions import DialectLoweringError
|
12
|
+
from kirin.lowering.result import Result
|
13
|
+
|
14
|
+
if TYPE_CHECKING:
|
15
|
+
from kirin.lowering.state import LoweringState
|
16
|
+
|
17
|
+
|
18
|
+
class FromPythonAST(ABC):
    """Base class for dialect-specific lowering from Python AST nodes.

    A subclass handles an AST node class `Foo` by defining a method
    `lower_Foo(state, node)`.
    """

    @property
    def names(self) -> list[str]:  # show the name without lower_
        """Names of the handled AST node classes, i.e. every `lower_*`
        attribute with the `lower_` prefix removed."""
        prefix = "lower_"
        return [attr[len(prefix) :] for attr in dir(self) if attr.startswith(prefix)]

    def lower(self, state: LoweringState, node: ast.AST) -> Result:
        """Entry point of dialect specific lowering."""
        # Dispatch on the node's class name; unknown nodes go to `unreachable`.
        handler = getattr(self, f"lower_{node.__class__.__name__}", self.unreachable)
        return handler(state, node)

    def unreachable(self, state: LoweringState, node: ast.AST) -> Result:
        """Fallback used when no `lower_<NodeClass>` method exists.

        Raises:
            DialectLoweringError: always.
        """
        raise DialectLoweringError(f"unreachable reached for {node.__class__.__name__}")
|
32
|
+
|
33
|
+
|
34
|
+
class NoSpecialLowering(FromPythonAST):
    """`FromPythonAST` subclass that defines no `lower_*` methods of its own."""
|