kirin-toolchain 0.13.0__py3-none-any.whl
Sign up to get free protection for your applications and to get access to all the features.
- kirin/__init__.py +7 -0
- kirin/analysis/__init__.py +24 -0
- kirin/analysis/callgraph.py +61 -0
- kirin/analysis/cfg.py +112 -0
- kirin/analysis/const/__init__.py +20 -0
- kirin/analysis/const/_visitor.py +2 -0
- kirin/analysis/const/_visitor.pyi +8 -0
- kirin/analysis/const/lattice.py +219 -0
- kirin/analysis/const/prop.py +116 -0
- kirin/analysis/forward.py +100 -0
- kirin/analysis/typeinfer/__init__.py +5 -0
- kirin/analysis/typeinfer/analysis.py +90 -0
- kirin/analysis/typeinfer/solve.py +141 -0
- kirin/decl/__init__.py +108 -0
- kirin/decl/base.py +65 -0
- kirin/decl/camel2snake.py +2 -0
- kirin/decl/emit/__init__.py +0 -0
- kirin/decl/emit/_create_fn.py +29 -0
- kirin/decl/emit/_set_new_attribute.py +22 -0
- kirin/decl/emit/dialect.py +8 -0
- kirin/decl/emit/init.py +277 -0
- kirin/decl/emit/name.py +10 -0
- kirin/decl/emit/property.py +182 -0
- kirin/decl/emit/repr.py +31 -0
- kirin/decl/emit/traits.py +13 -0
- kirin/decl/emit/typecheck.py +77 -0
- kirin/decl/emit/verify.py +51 -0
- kirin/decl/info.py +346 -0
- kirin/decl/scan_fields.py +157 -0
- kirin/decl/verify.py +69 -0
- kirin/dialects/__init__.py +14 -0
- kirin/dialects/_pprint_helper.py +53 -0
- kirin/dialects/cf/__init__.py +20 -0
- kirin/dialects/cf/constprop.py +51 -0
- kirin/dialects/cf/dialect.py +3 -0
- kirin/dialects/cf/emit.py +58 -0
- kirin/dialects/cf/interp.py +24 -0
- kirin/dialects/cf/stmts.py +68 -0
- kirin/dialects/cf/typeinfer.py +27 -0
- kirin/dialects/eltype.py +23 -0
- kirin/dialects/func/__init__.py +20 -0
- kirin/dialects/func/attrs.py +39 -0
- kirin/dialects/func/constprop.py +138 -0
- kirin/dialects/func/dialect.py +3 -0
- kirin/dialects/func/emit.py +80 -0
- kirin/dialects/func/interp.py +68 -0
- kirin/dialects/func/stmts.py +233 -0
- kirin/dialects/func/typeinfer.py +124 -0
- kirin/dialects/ilist/__init__.py +33 -0
- kirin/dialects/ilist/_dialect.py +3 -0
- kirin/dialects/ilist/_wrapper.py +51 -0
- kirin/dialects/ilist/interp.py +85 -0
- kirin/dialects/ilist/lowering.py +25 -0
- kirin/dialects/ilist/passes.py +32 -0
- kirin/dialects/ilist/rewrite/__init__.py +3 -0
- kirin/dialects/ilist/rewrite/const.py +45 -0
- kirin/dialects/ilist/rewrite/list.py +38 -0
- kirin/dialects/ilist/rewrite/unroll.py +131 -0
- kirin/dialects/ilist/runtime.py +63 -0
- kirin/dialects/ilist/stmts.py +102 -0
- kirin/dialects/ilist/typeinfer.py +120 -0
- kirin/dialects/lowering/__init__.py +7 -0
- kirin/dialects/lowering/call.py +48 -0
- kirin/dialects/lowering/cf.py +206 -0
- kirin/dialects/lowering/func.py +134 -0
- kirin/dialects/math/__init__.py +41 -0
- kirin/dialects/math/_gen.py +176 -0
- kirin/dialects/math/dialect.py +3 -0
- kirin/dialects/math/interp.py +190 -0
- kirin/dialects/math/stmts.py +369 -0
- kirin/dialects/module.py +139 -0
- kirin/dialects/py/__init__.py +40 -0
- kirin/dialects/py/assertion.py +91 -0
- kirin/dialects/py/assign.py +103 -0
- kirin/dialects/py/attr.py +59 -0
- kirin/dialects/py/base.py +34 -0
- kirin/dialects/py/binop/__init__.py +23 -0
- kirin/dialects/py/binop/_dialect.py +3 -0
- kirin/dialects/py/binop/interp.py +60 -0
- kirin/dialects/py/binop/julia.py +33 -0
- kirin/dialects/py/binop/lowering.py +22 -0
- kirin/dialects/py/binop/stmts.py +79 -0
- kirin/dialects/py/binop/typeinfer.py +108 -0
- kirin/dialects/py/boolop.py +84 -0
- kirin/dialects/py/builtin.py +78 -0
- kirin/dialects/py/cmp/__init__.py +16 -0
- kirin/dialects/py/cmp/_dialect.py +3 -0
- kirin/dialects/py/cmp/interp.py +48 -0
- kirin/dialects/py/cmp/julia.py +33 -0
- kirin/dialects/py/cmp/lowering.py +45 -0
- kirin/dialects/py/cmp/stmts.py +62 -0
- kirin/dialects/py/constant.py +79 -0
- kirin/dialects/py/indexing.py +251 -0
- kirin/dialects/py/iterable.py +90 -0
- kirin/dialects/py/len.py +57 -0
- kirin/dialects/py/list/__init__.py +15 -0
- kirin/dialects/py/list/_dialect.py +3 -0
- kirin/dialects/py/list/interp.py +21 -0
- kirin/dialects/py/list/lowering.py +25 -0
- kirin/dialects/py/list/stmts.py +22 -0
- kirin/dialects/py/list/typeinfer.py +54 -0
- kirin/dialects/py/range.py +76 -0
- kirin/dialects/py/slice.py +120 -0
- kirin/dialects/py/tuple.py +109 -0
- kirin/dialects/py/unary/__init__.py +24 -0
- kirin/dialects/py/unary/_dialect.py +3 -0
- kirin/dialects/py/unary/constprop.py +20 -0
- kirin/dialects/py/unary/interp.py +24 -0
- kirin/dialects/py/unary/julia.py +21 -0
- kirin/dialects/py/unary/lowering.py +22 -0
- kirin/dialects/py/unary/stmts.py +33 -0
- kirin/dialects/py/unary/typeinfer.py +23 -0
- kirin/dialects/py/unpack.py +90 -0
- kirin/dialects/scf/__init__.py +23 -0
- kirin/dialects/scf/_dialect.py +3 -0
- kirin/dialects/scf/absint.py +64 -0
- kirin/dialects/scf/constprop.py +140 -0
- kirin/dialects/scf/interp.py +35 -0
- kirin/dialects/scf/lowering.py +123 -0
- kirin/dialects/scf/stmts.py +250 -0
- kirin/dialects/scf/trim.py +36 -0
- kirin/dialects/scf/typeinfer.py +58 -0
- kirin/dialects/scf/unroll.py +92 -0
- kirin/emit/__init__.py +3 -0
- kirin/emit/abc.py +89 -0
- kirin/emit/abc.pyi +38 -0
- kirin/emit/exceptions.py +5 -0
- kirin/emit/julia.py +63 -0
- kirin/emit/str.py +51 -0
- kirin/exceptions.py +59 -0
- kirin/graph.py +34 -0
- kirin/idtable.py +57 -0
- kirin/interp/__init__.py +39 -0
- kirin/interp/abstract.py +253 -0
- kirin/interp/base.py +438 -0
- kirin/interp/concrete.py +62 -0
- kirin/interp/exceptions.py +26 -0
- kirin/interp/frame.py +151 -0
- kirin/interp/impl.py +197 -0
- kirin/interp/result.py +93 -0
- kirin/interp/state.py +71 -0
- kirin/interp/table.py +40 -0
- kirin/interp/value.py +73 -0
- kirin/ir/__init__.py +46 -0
- kirin/ir/attrs/__init__.py +20 -0
- kirin/ir/attrs/_types.py +8 -0
- kirin/ir/attrs/_types.pyi +13 -0
- kirin/ir/attrs/abc.py +46 -0
- kirin/ir/attrs/py.py +45 -0
- kirin/ir/attrs/types.py +522 -0
- kirin/ir/dialect.py +125 -0
- kirin/ir/group.py +249 -0
- kirin/ir/method.py +118 -0
- kirin/ir/nodes/__init__.py +7 -0
- kirin/ir/nodes/base.py +149 -0
- kirin/ir/nodes/block.py +458 -0
- kirin/ir/nodes/region.py +337 -0
- kirin/ir/nodes/stmt.py +713 -0
- kirin/ir/nodes/view.py +142 -0
- kirin/ir/ssa.py +204 -0
- kirin/ir/traits/__init__.py +36 -0
- kirin/ir/traits/abc.py +42 -0
- kirin/ir/traits/basic.py +78 -0
- kirin/ir/traits/callable.py +51 -0
- kirin/ir/traits/lowering/__init__.py +2 -0
- kirin/ir/traits/lowering/call.py +37 -0
- kirin/ir/traits/lowering/context.py +120 -0
- kirin/ir/traits/region/__init__.py +2 -0
- kirin/ir/traits/region/ssacfg.py +22 -0
- kirin/ir/traits/symbol.py +57 -0
- kirin/ir/use.py +17 -0
- kirin/lattice/__init__.py +13 -0
- kirin/lattice/abc.py +128 -0
- kirin/lattice/empty.py +25 -0
- kirin/lattice/mixin.py +51 -0
- kirin/lowering/__init__.py +7 -0
- kirin/lowering/binding.py +65 -0
- kirin/lowering/core.py +72 -0
- kirin/lowering/dialect.py +35 -0
- kirin/lowering/dialect.pyi +183 -0
- kirin/lowering/frame.py +171 -0
- kirin/lowering/result.py +68 -0
- kirin/lowering/state.py +441 -0
- kirin/lowering/stream.py +53 -0
- kirin/passes/__init__.py +3 -0
- kirin/passes/abc.py +44 -0
- kirin/passes/aggressive/__init__.py +1 -0
- kirin/passes/aggressive/fold.py +43 -0
- kirin/passes/fold.py +45 -0
- kirin/passes/inline.py +25 -0
- kirin/passes/typeinfer.py +25 -0
- kirin/prelude.py +197 -0
- kirin/print/__init__.py +15 -0
- kirin/print/printable.py +141 -0
- kirin/print/printer.py +415 -0
- kirin/py.typed +0 -0
- kirin/registry.py +105 -0
- kirin/registry.pyi +52 -0
- kirin/rewrite/__init__.py +14 -0
- kirin/rewrite/abc.py +43 -0
- kirin/rewrite/aggressive/__init__.py +1 -0
- kirin/rewrite/aggressive/fold.py +43 -0
- kirin/rewrite/alias.py +16 -0
- kirin/rewrite/apply_type.py +47 -0
- kirin/rewrite/call2invoke.py +34 -0
- kirin/rewrite/chain.py +39 -0
- kirin/rewrite/compactify.py +288 -0
- kirin/rewrite/cse.py +48 -0
- kirin/rewrite/dce.py +19 -0
- kirin/rewrite/fixpoint.py +34 -0
- kirin/rewrite/fold.py +57 -0
- kirin/rewrite/getfield.py +21 -0
- kirin/rewrite/getitem.py +37 -0
- kirin/rewrite/inline.py +143 -0
- kirin/rewrite/result.py +15 -0
- kirin/rewrite/walk.py +83 -0
- kirin/rewrite/wrap_const.py +55 -0
- kirin/source.py +21 -0
- kirin/symbol_table.py +27 -0
- kirin/types.py +34 -0
- kirin/worklist.py +30 -0
- kirin_toolchain-0.13.0.dist-info/METADATA +42 -0
- kirin_toolchain-0.13.0.dist-info/RECORD +225 -0
- kirin_toolchain-0.13.0.dist-info/WHEEL +4 -0
- kirin_toolchain-0.13.0.dist-info/licenses/LICENSE +234 -0
kirin/ir/nodes/view.py
ADDED
@@ -0,0 +1,142 @@
|
|
1
|
+
from abc import ABC, abstractmethod
|
2
|
+
from typing import Generic, TypeVar, Iterator, Sequence, overload
|
3
|
+
from dataclasses import dataclass
|
4
|
+
|
5
|
+
from typing_extensions import Self
|
6
|
+
|
7
|
+
from kirin.ir.ssa import SSAValue
|
8
|
+
from kirin.ir.nodes.base import IRNode
|
9
|
+
|
10
|
+
# Element type yielded by a view; bounded to IR nodes or SSA values.
ElemType = TypeVar("ElemType", bound=IRNode | SSAValue)
# Type of the underlying sequence stored in `SequenceView.field`.
FieldType = TypeVar("FieldType", bound=Sequence)
# Type of the object a view is attached to (`View.node`); IR node or SSA value.
NodeType = TypeVar("NodeType", bound=IRNode | SSAValue)
|
13
|
+
|
14
|
+
|
15
|
+
@dataclass
class View(ABC, Generic[NodeType, ElemType]):
    """Abstract, list-like window over elements attached to `node`.

    Subclasses must implement `__iter__` and `__len__`. `append` and
    `__reversed__` raise `NotImplementedError` unless a subclass overrides
    them; `extend` is expressed in terms of `append`.
    """

    node: NodeType  # the IR node / SSA value this view is attached to

    @abstractmethod
    def __iter__(self) -> Iterator[ElemType]: ...

    @abstractmethod
    def __len__(self) -> int: ...

    def __bool__(self) -> bool:
        # truthy exactly when the view holds at least one element
        return len(self) != 0

    def append(self, value: ElemType) -> None:
        """Add `value` at the end of the view (unsupported by default)."""
        raise NotImplementedError

    def extend(self, values: Sequence[ElemType]) -> None:
        """Append every element of `values`, in order, through `append`."""
        for item in values:
            self.append(item)

    def __reversed__(self) -> Iterator[ElemType]:
        raise NotImplementedError
|
37
|
+
|
38
|
+
|
39
|
+
@dataclass
class SequenceView(Generic[FieldType, NodeType, ElemType], View[NodeType, ElemType]):
    """A `View` whose elements live in the sequence stored in `field`.

    Iteration, length, reversal and integer indexing are forwarded to
    `field`. Indexing with a slice returns a new view of the same class,
    attached to the same `node`, over the sliced field.
    """

    field: FieldType

    @classmethod
    def similar(cls, node: NodeType, field: FieldType) -> Self:
        """Construct another view of this class for `node` over `field`."""
        return cls(node, field)

    def __iter__(self) -> Iterator[ElemType]:
        return iter(self.field)

    def __len__(self) -> int:
        return len(self.field)

    def __reversed__(self) -> Iterator[ElemType]:
        return reversed(self.field)

    def isempty(self) -> bool:
        """Return True when the view has no elements."""
        return not len(self)

    def __bool__(self) -> bool:
        return not self.isempty()

    # optional interface
    @overload
    def __getitem__(self, idx: int) -> ElemType: ...

    @overload
    def __getitem__(self, idx: slice) -> Self: ...

    def __getitem__(self, idx: int | slice) -> ElemType | Self:
        # plain index: hand back the element itself
        if not isinstance(idx, slice):
            return self.field[idx]
        # slice: wrap the sliced field in a view of the same kind
        sliced: FieldType = self.field[idx]  # type: ignore
        return self.similar(self.node, sliced)
|
75
|
+
|
76
|
+
|
77
|
+
@dataclass
class MutableSequenceView(SequenceView[FieldType, NodeType, ElemType]):
    """A `SequenceView` that also supports in-place mutation.

    Subclasses provide the primitive operations `set_item`, `__delitem__`
    and `insert`. Slice assignment, `pop`, `poplast` and `popfirst` are
    implemented here on top of those primitives.
    """

    @overload
    def __setitem__(self, idx: int, value: ElemType) -> None: ...

    @overload
    def __setitem__(self, idx: slice, value: Sequence[ElemType]) -> None: ...

    def __setitem__(
        self, idx: int | slice, value: ElemType | Sequence[ElemType]
    ) -> None:
        """Assign one element (`int` index) or a contiguous range (`slice`).

        Raises:
            ValueError: if a slice with an explicit step is given.
            TypeError: if `idx` is neither an `int` nor a `slice`, or an
                `int` index is paired with a sequence value.
        """
        if isinstance(idx, int) and not isinstance(value, Sequence):
            return self.set_item(idx, value)
        elif isinstance(idx, slice):
            assert isinstance(value, Sequence), "Expected sequence of values"
            if idx.step is not None:  # no need to support step
                raise ValueError("Slice step is not supported")
            return self.set_item_slice(idx, value)
        else:
            raise TypeError("Expected int or slice")

    def set_item(self, idx: int, value: ElemType) -> None:
        """Replace the element at `idx` with `value` (provided by subclasses)."""
        raise NotImplementedError

    def set_item_slice(self, s: slice, value: Sequence[ElemType]) -> None:
        """Replace the elements selected by the step-less slice `s` with `value`.

        Mirrors Python's `lst[start:stop] = value`: missing, negative or
        out-of-range bounds are resolved against the current length; surplus
        old elements are deleted and surplus new values are inserted.
        """
        # Resolve None / negative / out-of-range bounds exactly like list slicing.
        start, stop, _ = s.indices(len(self))
        n_old = max(stop - start, 0)
        n_new = len(value)
        n_common = min(n_old, n_new)

        # Overwrite the overlapping part in place. `value` is indexed relative
        # to the slice start, not by absolute position in the view.
        for offset in range(n_common):
            self.set_item(start + offset, value[offset])

        # Remove leftover old elements back-to-front so that deleting one does
        # not shift the indices of those still to be deleted.
        for idx in range(stop - 1, start + n_common - 1, -1):
            del self[idx]

        # Insert any remaining new values right after the overwritten part.
        for offset in range(n_common, n_new):
            self.insert(start + offset, value[offset])

    def __delitem__(self, idx: int) -> None:
        raise NotImplementedError

    def insert(self, idx: int, value: ElemType) -> None:
        """Insert `value` before position `idx` (provided by subclasses)."""
        raise NotImplementedError

    def pop(self, idx: int = -1) -> ElemType:
        """Remove and return the element at `idx` (default: the last one)."""
        item = self.field[idx]
        del self[idx]
        return item

    def poplast(self) -> ElemType | None:
        """Pop the last element from the view.

        Returns:
            The last element in the view, or `None` if the view is empty.
        """
        if self:
            return self.pop(-1)
        return None

    def popfirst(self) -> ElemType | None:
        """Pop the first element from the view.

        Returns:
            The first element in the view, or `None` if the view is empty.
        """
        if self:
            return self.pop(0)
        return None
|
kirin/ir/ssa.py
ADDED
@@ -0,0 +1,204 @@
|
|
1
|
+
from __future__ import annotations
|
2
|
+
|
3
|
+
import re
|
4
|
+
from abc import ABC, abstractmethod
|
5
|
+
from typing import TYPE_CHECKING, ClassVar
|
6
|
+
from dataclasses import field, dataclass
|
7
|
+
|
8
|
+
from typing_extensions import Self
|
9
|
+
|
10
|
+
from kirin.print import Printer, Printable
|
11
|
+
from kirin.ir.attrs.abc import Attribute
|
12
|
+
from kirin.ir.attrs.types import AnyType, TypeAttribute
|
13
|
+
|
14
|
+
if TYPE_CHECKING:
|
15
|
+
from kirin.ir.use import Use
|
16
|
+
from kirin.ir.nodes.stmt import Statement
|
17
|
+
from kirin.ir.nodes.block import Block
|
18
|
+
|
19
|
+
|
20
|
+
@dataclass
class SSAValue(ABC, Printable):
    """Base class for all SSA values in the IR."""

    type: TypeAttribute = field(default_factory=AnyType, init=False, repr=True)
    """The type of this SSA value."""
    hints: dict[str, Attribute] = field(default_factory=dict, init=False, repr=False)
    """Hints for this SSA value."""
    uses: set[Use] = field(init=False, default_factory=set, repr=False)
    """The uses of this SSA value."""
    _name: str | None = field(init=False, default=None, repr=True)
    """The name of this SSA value."""
    name_pattern: ClassVar[re.Pattern[str]] = re.compile(r"([A-Za-z_$.-][\w$.-]*)")
    """The pattern that the name of this SSA value must match."""

    @property
    @abstractmethod
    def owner(self) -> Statement | Block:
        """The object that owns this SSA value."""
        ...

    @property
    def name(self) -> str | None:
        """The name of this SSA value."""
        return self._name

    @name.setter
    def name(self, name: str | None) -> None:
        # None / "" clear the name; any other value must fully match name_pattern.
        if name and not self.name_pattern.fullmatch(name):
            raise ValueError(f"Invalid name: {name}")
        self._name = name

    def __repr__(self) -> str:
        if self.name:
            return f"{type(self).__name__}({self.name})"
        return f"{type(self).__name__}({id(self)})"

    def __hash__(self) -> int:
        # hash by object identity
        return id(self)

    def add_use(self, use: Use) -> Self:
        """Add a use to this SSA value."""
        self.uses.add(use)
        return self

    def remove_use(self, use: Use) -> Self:
        """Remove a use from this SSA value; a use that is not registered is ignored."""
        self.uses.discard(use)
        return self

    def replace_by(self, other: SSAValue) -> None:
        """Replace this SSA value with another SSA value. Update all uses.

        Each use of this value is redirected to `other` by rewriting the
        argument of the using statement. If `other` has no name it inherits
        this value's name. Replacing a value by itself is a no-op.
        """
        if other is self:
            # Nothing to redirect; rewriting every use to `self` would leave
            # `self.uses` populated and fail the assertion below.
            return

        # Iterate over a snapshot: rewriting `args` is expected to unregister
        # the use from `self.uses` (checked by the assertion below).
        for use in self.uses.copy():
            use.stmt.args[use.index] = other

        if other.name is None and self.name is not None:
            other.name = self.name

        assert len(self.uses) == 0, "Uses not empty"

    # TODO: also delete BlockArgument from arglist
    def delete(self, safe: bool = True) -> None:
        """Delete this SSA value. If `safe` is `True`, raise an error if there are uses.

        Any remaining uses (only possible with ``safe=False``) are redirected
        to a `DeletedSSAValue` wrapping this value.

        Raises:
            ValueError: if `safe` is true and the value still has uses.
        """
        if safe and len(self.uses) > 0:
            raise ValueError("Cannot delete SSA value with uses")
        self.replace_by(DeletedSSAValue(self))

    def print_impl(self, printer: Printer) -> None:
        printer.plain_print(printer.state.ssa_id[self])
|
92
|
+
|
93
|
+
|
94
|
+
@dataclass
class ResultValue(SSAValue):
    """SSAValue that is a result of a [`Statement`][kirin.ir.nodes.stmt.Statement]."""

    stmt: Statement = field(init=False)
    """The statement that this value is a result of."""
    index: int = field(init=False)
    """The index of this value in the statement's result list."""

    # NOTE: we will assign AnyType unless specified.
    # when SSAValue is a ResultValue, the type is inferred
    # later in the compilation process.
    def __init__(
        self, stmt: Statement, index: int, type: TypeAttribute | None = None
    ) -> None:
        super().__init__()
        # Compare against None explicitly: `type or AnyType()` would also
        # replace any type attribute that happens to evaluate as falsy.
        self.type = AnyType() if type is None else type
        self.stmt = stmt
        self.index = index

    @property
    def owner(self) -> Statement:
        """The statement producing this result."""
        return self.stmt

    def __hash__(self) -> int:
        return id(self)

    def __repr__(self) -> str:
        # omit the type annotation when it is the top (AnyType) element
        if self.type is self.type.top():
            type_str = ""
        else:
            type_str = f"[{self.type}]"

        if self.name:
            return (
                f"<{type(self).__name__}{type_str} {self.name}, uses: {len(self.uses)}>"
            )
        return f"<{type(self).__name__}{type_str} stmt: {self.stmt.name}, uses: {len(self.uses)}>"
|
132
|
+
|
133
|
+
|
134
|
+
@dataclass
class BlockArgument(SSAValue):
    """SSAValue that is an argument to a [`Block`][kirin.ir.Block].

    Printing appends ` : <type>` unless the type is `AnyType`.
    """

    block: Block = field(init=False)
    """The block that this argument belongs to."""
    index: int = field(init=False)
    """The index of this argument in the block's argument list."""

    def __init__(
        self, block: Block, index: int, type: TypeAttribute = AnyType()
    ) -> None:
        super().__init__()
        self.block, self.index, self.type = block, index, type

    @property
    def owner(self) -> Block:
        """The block declaring this argument."""
        return self.block

    def __hash__(self) -> int:
        return id(self)

    def __repr__(self) -> str:
        # prefer the SSA name; fall back to the argument position
        label = self.name if self.name else f"index: {self.index}"
        return f"<{type(self).__name__}[{self.type}] {label}, uses: {len(self.uses)}>"

    def print_impl(self, printer: Printer) -> None:
        super().print_impl(printer)
        if isinstance(self.type, AnyType):
            return
        with printer.rich(style="comment"):
            printer.plain_print(" : ")
            printer.print(self.type)
|
169
|
+
|
170
|
+
|
171
|
+
@dataclass
class DeletedSSAValue(SSAValue):
    """Stand-in that `SSAValue.delete` substitutes for a deleted value.

    It keeps a reference to the deleted value, copies its type, and
    forwards `owner` to it.
    """

    value: SSAValue = field(init=False)  # the SSA value that was deleted

    def __init__(self, value: SSAValue) -> None:
        super().__init__()
        self.value = value
        self.type = value.type

    def __hash__(self) -> int:
        return id(self)

    def __repr__(self) -> str:
        cls_name = type(self).__name__
        return f"<{cls_name}[{self.type}] value: {self.value}, uses: {len(self.uses)}>"

    @property
    def owner(self) -> Statement | Block:
        """Owner of the deleted value this placeholder stands for."""
        return self.value.owner
|
189
|
+
|
190
|
+
|
191
|
+
@dataclass
class TestValue(SSAValue):
    """Test SSAValue for testing IR construction.

    It has no owner: reading `owner` raises `NotImplementedError`.
    """

    def __init__(self, type: TypeAttribute = AnyType()) -> None:
        super().__init__()
        self.type = type

    @property
    def owner(self) -> Statement | Block:
        raise NotImplementedError

    def __hash__(self) -> int:
        return id(self)
|
@@ -0,0 +1,36 @@
|
|
1
|
+
"""Kirin IR Traits.
|
2
|
+
|
3
|
+
This module defines the traits that can be used to define the behavior of
|
4
|
+
Kirin IR nodes. The base trait is `StmtTrait`, which is a `dataclass` that
|
5
|
+
implements the `__hash__` and `__eq__` methods.
|
6
|
+
|
7
|
+
There are also some basic traits that are provided for convenience, such as
|
8
|
+
`Pure`, `MaybePure`, `HasParent`, `ConstantLike`, `IsTerminator`,
`NoTerminator`, and `IsolatedFromAbove`.
|
10
|
+
"""
|
11
|
+
|
12
|
+
from .abc import (
|
13
|
+
StmtTrait as StmtTrait,
|
14
|
+
RegionTrait as RegionTrait,
|
15
|
+
PythonLoweringTrait as PythonLoweringTrait,
|
16
|
+
)
|
17
|
+
from .basic import (
|
18
|
+
Pure as Pure,
|
19
|
+
HasParent as HasParent,
|
20
|
+
MaybePure as MaybePure,
|
21
|
+
ConstantLike as ConstantLike,
|
22
|
+
IsTerminator as IsTerminator,
|
23
|
+
NoTerminator as NoTerminator,
|
24
|
+
IsolatedFromAbove as IsolatedFromAbove,
|
25
|
+
)
|
26
|
+
from .symbol import SymbolTable as SymbolTable, SymbolOpInterface as SymbolOpInterface
|
27
|
+
from .callable import (
|
28
|
+
HasSignature as HasSignature,
|
29
|
+
CallableStmtInterface as CallableStmtInterface,
|
30
|
+
)
|
31
|
+
from .lowering.call import FromPythonCall as FromPythonCall
|
32
|
+
from .region.ssacfg import SSACFGRegion as SSACFGRegion
|
33
|
+
from .lowering.context import (
|
34
|
+
FromPythonWith as FromPythonWith,
|
35
|
+
FromPythonWithSingleItem as FromPythonWithSingleItem,
|
36
|
+
)
|
kirin/ir/traits/abc.py
ADDED
@@ -0,0 +1,42 @@
|
|
1
|
+
import ast
|
2
|
+
from abc import ABC, abstractmethod
|
3
|
+
from typing import TYPE_CHECKING, Generic, TypeVar
|
4
|
+
from dataclasses import dataclass
|
5
|
+
|
6
|
+
if TYPE_CHECKING:
|
7
|
+
from kirin import lowering
|
8
|
+
from kirin.ir import Block, Region, Statement
|
9
|
+
from kirin.graph import Graph
|
10
|
+
|
11
|
+
|
12
|
+
@dataclass(frozen=True)
class StmtTrait(ABC):
    """Base class for all statement traits.

    Traits are frozen dataclasses: they are immutable and get field-based
    `__eq__` and `__hash__` generated automatically.
    """

    def verify(self, stmt: "Statement"):
        """Check that `stmt` satisfies this trait.

        The base implementation accepts every statement; subclasses override
        it to raise when the statement violates the trait.
        """
        return None
|
18
|
+
|
19
|
+
|
20
|
+
# Graph type produced by a region trait. The bound is a string because
# `Graph` and `Block` are only imported under TYPE_CHECKING.
GraphType = TypeVar("GraphType", bound="Graph[Block]")


@dataclass(frozen=True)
class RegionTrait(StmtTrait, Generic[GraphType]):
    """A trait that indicates the properties of the statement's region.

    Subclasses implement `get_graph`, returning a `GraphType` for a region.
    """

    @abstractmethod
    def get_graph(self, region: "Region") -> GraphType:
        """Return the graph representation of `region`."""
        ...
|
29
|
+
|
30
|
+
|
31
|
+
# Python AST node type consumed by a lowering trait.
ASTNode = TypeVar("ASTNode", bound=ast.AST)
# Statement class targeted by a lowering trait.
StatementType = TypeVar("StatementType", bound="Statement")


@dataclass(frozen=True)
class PythonLoweringTrait(StmtTrait, Generic[StatementType, ASTNode]):
    """A trait that indicates that a statement can be lowered from Python AST.

    Implementations provide `lower`, which receives the statement class, the
    current lowering state and the AST node to translate.
    """

    @abstractmethod
    def lower(
        self,
        stmt: type[StatementType],
        state: "lowering.LoweringState",
        node: ASTNode,
    ) -> "lowering.Result":
        """Translate `node` for statement class `stmt` using `state`."""
        ...
|
kirin/ir/traits/basic.py
ADDED
@@ -0,0 +1,78 @@
|
|
1
|
+
from typing import TYPE_CHECKING
|
2
|
+
from dataclasses import dataclass
|
3
|
+
|
4
|
+
from .abc import StmtTrait
|
5
|
+
|
6
|
+
if TYPE_CHECKING:
|
7
|
+
from kirin.ir import Statement
|
8
|
+
|
9
|
+
|
10
|
+
@dataclass(frozen=True)
class Pure(StmtTrait):
    """Marker trait: the statement is pure, i.e. it has no side effects.

    The trait carries no data and defines no behavior of its own.
    """
|
17
|
+
|
18
|
+
|
19
|
+
@dataclass(frozen=True)
class MaybePure(StmtTrait):
    """A trait that indicates the statement may be pure,
    i.e., a call statement can be pure if the callee is pure.

    Purity is recorded per statement in its ``"purity"`` attribute as a
    `PyAttr`: `set_pure` writes it and `is_pure` reads it back.
    """

    @classmethod
    def is_pure(cls, stmt: "Statement") -> bool:
        """Return True when `stmt` carries a ``"purity"`` `PyAttr` with truthy data."""
        # TODO: simplify this after removing property
        from kirin.ir.attrs.py import PyAttr

        flag = stmt.attributes.get("purity")
        if not isinstance(flag, PyAttr):
            return False
        return bool(flag.data)

    @classmethod
    def set_pure(cls, stmt: "Statement") -> None:
        """Mark `stmt` as pure by storing ``PyAttr(True)`` under ``"purity"``."""
        from kirin.ir.attrs.py import PyAttr

        stmt.attributes["purity"] = PyAttr(True)
|
40
|
+
|
41
|
+
|
42
|
+
@dataclass(frozen=True)
class ConstantLike(StmtTrait):
    """Marker trait: the statement is constant-like, i.e. it represents a
    constant value.

    The trait carries no data and defines no behavior of its own.
    """
|
49
|
+
|
50
|
+
|
51
|
+
@dataclass(frozen=True)
class IsTerminator(StmtTrait):
    """Marker trait: the statement is a terminator, i.e. it terminates a block.

    The trait carries no data and defines no behavior of its own.
    """
|
58
|
+
|
59
|
+
|
60
|
+
@dataclass(frozen=True)
class NoTerminator(StmtTrait):
    """Marker trait: the region of the statement has no terminator.

    The trait carries no data and defines no behavior of its own.
    """
|
65
|
+
|
66
|
+
|
67
|
+
@dataclass(frozen=True)
class IsolatedFromAbove(StmtTrait):
    """Marker trait with no data or behavior of its own.

    NOTE(review): going by the name, this presumably marks statements whose
    regions may not use values defined outside them — confirm with the
    code that checks for this trait.
    """
|
70
|
+
|
71
|
+
|
72
|
+
@dataclass(frozen=True)
class HasParent(StmtTrait):
    """A trait that indicates that a statement has a parent
    statement.

    `parents` holds the statement classes associated with this trait (any
    number of them). NOTE(review): presumably these are the allowed parent
    statement classes — confirm against the code that consumes this trait.
    """

    # variable-length tuple: `tuple[X]` would only describe a 1-tuple
    parents: tuple[type["Statement"], ...]
|
@@ -0,0 +1,51 @@
|
|
1
|
+
from abc import ABC, abstractmethod
|
2
|
+
from typing import TYPE_CHECKING, Generic, TypeVar
|
3
|
+
from dataclasses import dataclass
|
4
|
+
|
5
|
+
from kirin.ir.traits.abc import StmtTrait
|
6
|
+
|
7
|
+
if TYPE_CHECKING:
|
8
|
+
from kirin.ir import Region, Statement
|
9
|
+
from kirin.dialects.func.attrs import Signature
|
10
|
+
|
11
|
+
# Statement class a callable interface is parameterized over.
StmtType = TypeVar("StmtType", bound="Statement")


@dataclass(frozen=True)
class CallableStmtInterface(StmtTrait, Generic[StmtType]):
    """A trait that indicates that a statement is a callable statement.

    A callable statement is a statement that can be called as a function.
    Concrete traits must implement `get_callable_region`.
    """

    @classmethod
    @abstractmethod
    def get_callable_region(cls, stmt: "StmtType") -> "Region":
        """Returns the body of the callable region"""
        ...
|
26
|
+
|
27
|
+
|
28
|
+
@dataclass(frozen=True)
class HasSignature(StmtTrait, ABC):
    """A trait that indicates that a statement has a function signature
    attribute.

    The signature is stored in the statement's ``"signature"`` attribute.
    """

    @classmethod
    def get_signature(cls, stmt: "Statement"):
        """Return the ``"signature"`` attribute of `stmt`.

        Raises:
            ValueError: if `stmt` has no ``"signature"`` attribute.
        """
        signature: Signature | None = stmt.attributes.get("signature")  # type: ignore
        if signature is not None:
            return signature
        raise ValueError(f"Statement {stmt.name} does not have a function type")

    @classmethod
    def set_signature(cls, stmt: "Statement", signature: "Signature"):
        """Store `signature` as the ``"signature"`` attribute of `stmt`."""
        stmt.attributes["signature"] = signature

    def verify(self, stmt: "Statement"):
        """Raise `ValueError` unless `stmt` carries a `Signature` attribute."""
        from kirin.dialects.func.attrs import Signature

        sig = self.get_signature(stmt)
        if isinstance(sig, Signature):
            return
        raise ValueError(f"{sig} is not a Signature attribute")
|
@@ -0,0 +1,37 @@
|
|
1
|
+
import ast
|
2
|
+
from typing import TYPE_CHECKING, TypeVar
|
3
|
+
from dataclasses import dataclass
|
4
|
+
|
5
|
+
from ..abc import PythonLoweringTrait
|
6
|
+
|
7
|
+
if TYPE_CHECKING:
|
8
|
+
from kirin.ir import Statement
|
9
|
+
from kirin.lowering import Result, LoweringState
|
10
|
+
|
11
|
+
|
12
|
+
# Statement class constructed from a Python call.
StatementType = TypeVar("StatementType", bound="Statement")


@dataclass(frozen=True)
class FromPythonCall(PythonLoweringTrait[StatementType, ast.Call]):
    """Trait for customizing lowering of Python calls to a statement.

    Declared in a statement definition to indicate that the statement can be
    constructed from a Python call (i.e., a function call `ast.Call` in the
    Python AST).

    Subclassing this trait allows for customizing the lowering of Python calls
    to the statement. The `lower` method should be implemented to parse the
    arguments from the Python call and construct the statement instance.
    """

    def lower(
        self, stmt: type[StatementType], state: "LoweringState", node: ast.Call
    ) -> "Result":
        # default: delegate to the lowering state's generic call lowering
        return state.default_Call_lower(stmt, node)

    def verify(self, stmt: "Statement"):
        """Check that `stmt` has neither regions nor successors.

        Raises:
            AssertionError: if either collection is non-empty (the check is
                skipped when Python runs with ``-O``).
        """
        assert not len(stmt.regions), "FromPythonCall statements cannot have regions"
        assert not len(
            stmt.successors
        ), "FromPythonCall statements cannot have successors"
|