kirin-toolchain 0.13.0__py3-none-any.whl
Sign up to get free protection for your applications and to get access to all the features.
- kirin/__init__.py +7 -0
- kirin/analysis/__init__.py +24 -0
- kirin/analysis/callgraph.py +61 -0
- kirin/analysis/cfg.py +112 -0
- kirin/analysis/const/__init__.py +20 -0
- kirin/analysis/const/_visitor.py +2 -0
- kirin/analysis/const/_visitor.pyi +8 -0
- kirin/analysis/const/lattice.py +219 -0
- kirin/analysis/const/prop.py +116 -0
- kirin/analysis/forward.py +100 -0
- kirin/analysis/typeinfer/__init__.py +5 -0
- kirin/analysis/typeinfer/analysis.py +90 -0
- kirin/analysis/typeinfer/solve.py +141 -0
- kirin/decl/__init__.py +108 -0
- kirin/decl/base.py +65 -0
- kirin/decl/camel2snake.py +2 -0
- kirin/decl/emit/__init__.py +0 -0
- kirin/decl/emit/_create_fn.py +29 -0
- kirin/decl/emit/_set_new_attribute.py +22 -0
- kirin/decl/emit/dialect.py +8 -0
- kirin/decl/emit/init.py +277 -0
- kirin/decl/emit/name.py +10 -0
- kirin/decl/emit/property.py +182 -0
- kirin/decl/emit/repr.py +31 -0
- kirin/decl/emit/traits.py +13 -0
- kirin/decl/emit/typecheck.py +77 -0
- kirin/decl/emit/verify.py +51 -0
- kirin/decl/info.py +346 -0
- kirin/decl/scan_fields.py +157 -0
- kirin/decl/verify.py +69 -0
- kirin/dialects/__init__.py +14 -0
- kirin/dialects/_pprint_helper.py +53 -0
- kirin/dialects/cf/__init__.py +20 -0
- kirin/dialects/cf/constprop.py +51 -0
- kirin/dialects/cf/dialect.py +3 -0
- kirin/dialects/cf/emit.py +58 -0
- kirin/dialects/cf/interp.py +24 -0
- kirin/dialects/cf/stmts.py +68 -0
- kirin/dialects/cf/typeinfer.py +27 -0
- kirin/dialects/eltype.py +23 -0
- kirin/dialects/func/__init__.py +20 -0
- kirin/dialects/func/attrs.py +39 -0
- kirin/dialects/func/constprop.py +138 -0
- kirin/dialects/func/dialect.py +3 -0
- kirin/dialects/func/emit.py +80 -0
- kirin/dialects/func/interp.py +68 -0
- kirin/dialects/func/stmts.py +233 -0
- kirin/dialects/func/typeinfer.py +124 -0
- kirin/dialects/ilist/__init__.py +33 -0
- kirin/dialects/ilist/_dialect.py +3 -0
- kirin/dialects/ilist/_wrapper.py +51 -0
- kirin/dialects/ilist/interp.py +85 -0
- kirin/dialects/ilist/lowering.py +25 -0
- kirin/dialects/ilist/passes.py +32 -0
- kirin/dialects/ilist/rewrite/__init__.py +3 -0
- kirin/dialects/ilist/rewrite/const.py +45 -0
- kirin/dialects/ilist/rewrite/list.py +38 -0
- kirin/dialects/ilist/rewrite/unroll.py +131 -0
- kirin/dialects/ilist/runtime.py +63 -0
- kirin/dialects/ilist/stmts.py +102 -0
- kirin/dialects/ilist/typeinfer.py +120 -0
- kirin/dialects/lowering/__init__.py +7 -0
- kirin/dialects/lowering/call.py +48 -0
- kirin/dialects/lowering/cf.py +206 -0
- kirin/dialects/lowering/func.py +134 -0
- kirin/dialects/math/__init__.py +41 -0
- kirin/dialects/math/_gen.py +176 -0
- kirin/dialects/math/dialect.py +3 -0
- kirin/dialects/math/interp.py +190 -0
- kirin/dialects/math/stmts.py +369 -0
- kirin/dialects/module.py +139 -0
- kirin/dialects/py/__init__.py +40 -0
- kirin/dialects/py/assertion.py +91 -0
- kirin/dialects/py/assign.py +103 -0
- kirin/dialects/py/attr.py +59 -0
- kirin/dialects/py/base.py +34 -0
- kirin/dialects/py/binop/__init__.py +23 -0
- kirin/dialects/py/binop/_dialect.py +3 -0
- kirin/dialects/py/binop/interp.py +60 -0
- kirin/dialects/py/binop/julia.py +33 -0
- kirin/dialects/py/binop/lowering.py +22 -0
- kirin/dialects/py/binop/stmts.py +79 -0
- kirin/dialects/py/binop/typeinfer.py +108 -0
- kirin/dialects/py/boolop.py +84 -0
- kirin/dialects/py/builtin.py +78 -0
- kirin/dialects/py/cmp/__init__.py +16 -0
- kirin/dialects/py/cmp/_dialect.py +3 -0
- kirin/dialects/py/cmp/interp.py +48 -0
- kirin/dialects/py/cmp/julia.py +33 -0
- kirin/dialects/py/cmp/lowering.py +45 -0
- kirin/dialects/py/cmp/stmts.py +62 -0
- kirin/dialects/py/constant.py +79 -0
- kirin/dialects/py/indexing.py +251 -0
- kirin/dialects/py/iterable.py +90 -0
- kirin/dialects/py/len.py +57 -0
- kirin/dialects/py/list/__init__.py +15 -0
- kirin/dialects/py/list/_dialect.py +3 -0
- kirin/dialects/py/list/interp.py +21 -0
- kirin/dialects/py/list/lowering.py +25 -0
- kirin/dialects/py/list/stmts.py +22 -0
- kirin/dialects/py/list/typeinfer.py +54 -0
- kirin/dialects/py/range.py +76 -0
- kirin/dialects/py/slice.py +120 -0
- kirin/dialects/py/tuple.py +109 -0
- kirin/dialects/py/unary/__init__.py +24 -0
- kirin/dialects/py/unary/_dialect.py +3 -0
- kirin/dialects/py/unary/constprop.py +20 -0
- kirin/dialects/py/unary/interp.py +24 -0
- kirin/dialects/py/unary/julia.py +21 -0
- kirin/dialects/py/unary/lowering.py +22 -0
- kirin/dialects/py/unary/stmts.py +33 -0
- kirin/dialects/py/unary/typeinfer.py +23 -0
- kirin/dialects/py/unpack.py +90 -0
- kirin/dialects/scf/__init__.py +23 -0
- kirin/dialects/scf/_dialect.py +3 -0
- kirin/dialects/scf/absint.py +64 -0
- kirin/dialects/scf/constprop.py +140 -0
- kirin/dialects/scf/interp.py +35 -0
- kirin/dialects/scf/lowering.py +123 -0
- kirin/dialects/scf/stmts.py +250 -0
- kirin/dialects/scf/trim.py +36 -0
- kirin/dialects/scf/typeinfer.py +58 -0
- kirin/dialects/scf/unroll.py +92 -0
- kirin/emit/__init__.py +3 -0
- kirin/emit/abc.py +89 -0
- kirin/emit/abc.pyi +38 -0
- kirin/emit/exceptions.py +5 -0
- kirin/emit/julia.py +63 -0
- kirin/emit/str.py +51 -0
- kirin/exceptions.py +59 -0
- kirin/graph.py +34 -0
- kirin/idtable.py +57 -0
- kirin/interp/__init__.py +39 -0
- kirin/interp/abstract.py +253 -0
- kirin/interp/base.py +438 -0
- kirin/interp/concrete.py +62 -0
- kirin/interp/exceptions.py +26 -0
- kirin/interp/frame.py +151 -0
- kirin/interp/impl.py +197 -0
- kirin/interp/result.py +93 -0
- kirin/interp/state.py +71 -0
- kirin/interp/table.py +40 -0
- kirin/interp/value.py +73 -0
- kirin/ir/__init__.py +46 -0
- kirin/ir/attrs/__init__.py +20 -0
- kirin/ir/attrs/_types.py +8 -0
- kirin/ir/attrs/_types.pyi +13 -0
- kirin/ir/attrs/abc.py +46 -0
- kirin/ir/attrs/py.py +45 -0
- kirin/ir/attrs/types.py +522 -0
- kirin/ir/dialect.py +125 -0
- kirin/ir/group.py +249 -0
- kirin/ir/method.py +118 -0
- kirin/ir/nodes/__init__.py +7 -0
- kirin/ir/nodes/base.py +149 -0
- kirin/ir/nodes/block.py +458 -0
- kirin/ir/nodes/region.py +337 -0
- kirin/ir/nodes/stmt.py +713 -0
- kirin/ir/nodes/view.py +142 -0
- kirin/ir/ssa.py +204 -0
- kirin/ir/traits/__init__.py +36 -0
- kirin/ir/traits/abc.py +42 -0
- kirin/ir/traits/basic.py +78 -0
- kirin/ir/traits/callable.py +51 -0
- kirin/ir/traits/lowering/__init__.py +2 -0
- kirin/ir/traits/lowering/call.py +37 -0
- kirin/ir/traits/lowering/context.py +120 -0
- kirin/ir/traits/region/__init__.py +2 -0
- kirin/ir/traits/region/ssacfg.py +22 -0
- kirin/ir/traits/symbol.py +57 -0
- kirin/ir/use.py +17 -0
- kirin/lattice/__init__.py +13 -0
- kirin/lattice/abc.py +128 -0
- kirin/lattice/empty.py +25 -0
- kirin/lattice/mixin.py +51 -0
- kirin/lowering/__init__.py +7 -0
- kirin/lowering/binding.py +65 -0
- kirin/lowering/core.py +72 -0
- kirin/lowering/dialect.py +35 -0
- kirin/lowering/dialect.pyi +183 -0
- kirin/lowering/frame.py +171 -0
- kirin/lowering/result.py +68 -0
- kirin/lowering/state.py +441 -0
- kirin/lowering/stream.py +53 -0
- kirin/passes/__init__.py +3 -0
- kirin/passes/abc.py +44 -0
- kirin/passes/aggressive/__init__.py +1 -0
- kirin/passes/aggressive/fold.py +43 -0
- kirin/passes/fold.py +45 -0
- kirin/passes/inline.py +25 -0
- kirin/passes/typeinfer.py +25 -0
- kirin/prelude.py +197 -0
- kirin/print/__init__.py +15 -0
- kirin/print/printable.py +141 -0
- kirin/print/printer.py +415 -0
- kirin/py.typed +0 -0
- kirin/registry.py +105 -0
- kirin/registry.pyi +52 -0
- kirin/rewrite/__init__.py +14 -0
- kirin/rewrite/abc.py +43 -0
- kirin/rewrite/aggressive/__init__.py +1 -0
- kirin/rewrite/aggressive/fold.py +43 -0
- kirin/rewrite/alias.py +16 -0
- kirin/rewrite/apply_type.py +47 -0
- kirin/rewrite/call2invoke.py +34 -0
- kirin/rewrite/chain.py +39 -0
- kirin/rewrite/compactify.py +288 -0
- kirin/rewrite/cse.py +48 -0
- kirin/rewrite/dce.py +19 -0
- kirin/rewrite/fixpoint.py +34 -0
- kirin/rewrite/fold.py +57 -0
- kirin/rewrite/getfield.py +21 -0
- kirin/rewrite/getitem.py +37 -0
- kirin/rewrite/inline.py +143 -0
- kirin/rewrite/result.py +15 -0
- kirin/rewrite/walk.py +83 -0
- kirin/rewrite/wrap_const.py +55 -0
- kirin/source.py +21 -0
- kirin/symbol_table.py +27 -0
- kirin/types.py +34 -0
- kirin/worklist.py +30 -0
- kirin_toolchain-0.13.0.dist-info/METADATA +42 -0
- kirin_toolchain-0.13.0.dist-info/RECORD +225 -0
- kirin_toolchain-0.13.0.dist-info/WHEEL +4 -0
- kirin_toolchain-0.13.0.dist-info/licenses/LICENSE +234 -0
kirin/ir/nodes/stmt.py
ADDED
@@ -0,0 +1,713 @@
|
|
1
|
+
from __future__ import annotations
|
2
|
+
|
3
|
+
from typing import TYPE_CHECKING, Mapping, TypeVar, ClassVar, Iterator, Sequence
|
4
|
+
from dataclasses import field, dataclass
|
5
|
+
|
6
|
+
from typing_extensions import Self
|
7
|
+
|
8
|
+
from kirin.print import Printer, Printable
|
9
|
+
from kirin.ir.ssa import SSAValue, ResultValue
|
10
|
+
from kirin.ir.use import Use
|
11
|
+
from kirin.ir.traits import StmtTrait
|
12
|
+
from kirin.ir.attrs.abc import Attribute
|
13
|
+
from kirin.ir.nodes.base import IRNode
|
14
|
+
from kirin.ir.nodes.view import MutableSequenceView
|
15
|
+
from kirin.ir.nodes.block import Block
|
16
|
+
from kirin.ir.nodes.region import Region
|
17
|
+
|
18
|
+
if TYPE_CHECKING:
|
19
|
+
from kirin.source import SourceInfo
|
20
|
+
from kirin.ir.dialect import Dialect
|
21
|
+
from kirin.ir.attrs.types import TypeAttribute
|
22
|
+
from kirin.ir.nodes.block import Block
|
23
|
+
from kirin.ir.nodes.region import Region
|
24
|
+
|
25
|
+
|
26
|
+
@dataclass
class ArgumentList(
    MutableSequenceView[tuple[SSAValue, ...], "Statement", SSAValue], Printable
):
    """A View object that contains a list of Arguments of a Statement.

    Description:
        This is a proxy object that provides a safe API to manipulate the
        arguments of a statement. Every mutation keeps the statement's `_args`
        tuple, this view's `field`, and the `Use(statement, index)` record
        registered on each argument in sync.

    !!! note "Pretty Printing"
        This object is pretty printable via
        [`.print()`][kirin.print.printable.Printable.print] method.
    """

    def _normalize_index(self, idx: int) -> int:
        """Return the non-negative position that `idx` refers to.

        Raises:
            IndexError: If `idx` is out of range for the current arguments.
        """
        return range(len(self.field))[idx]

    def _replace_args(self, new_args: tuple[SSAValue, ...]) -> None:
        """Replace all arguments of the statement with `new_args`.

        This goes through the `Statement.args` setter, which unregisters the
        `Use` of every current argument and registers a fresh `Use` at each new
        position. It is used whenever positions shift, so that arguments after
        the edited index end up with `Use` records matching their new index.
        """
        self.node.args = new_args
        self.field = self.node._args

    def __delitem__(self, idx: int) -> None:
        """Remove the argument at `idx` (negative indices are supported).

        Raises:
            IndexError: If `idx` is out of range.
        """
        # Normalize first: slicing with a raw negative index (e.g. -1) would
        # produce `(*f[:-1], *f[0:])`, duplicating arguments instead of
        # removing one.
        pos = self._normalize_index(idx)
        args = self.field
        self._replace_args((*args[:pos], *args[pos + 1 :]))

    def set_item(self, idx: int, value: SSAValue) -> None:
        """Set the argument SSAValue at the specified index.

        Args:
            idx (int): The index of the item to set (negative indices are supported).
            value (SSAValue): The value to set.

        Raises:
            IndexError: If `idx` is out of range.
        """
        pos = self._normalize_index(idx)
        args = self.field
        # `Use` records are registered with non-negative positions (see the
        # `Statement.args` setter), so the normalized index must be used for
        # both removal and registration. No other position shifts here.
        args[pos].remove_use(Use(self.node, pos))
        value.add_use(Use(self.node, pos))
        new_args = (*args[:pos], value, *args[pos + 1 :])
        self.node._args = new_args
        self.field = new_args

    def insert(self, idx: int, value: SSAValue) -> None:
        """Insert the argument SSAValue at the specified index.

        Follows `list.insert` semantics: negative indices count from the end
        and out-of-range indices are clamped to the ends.

        Args:
            idx (int): The index to insert the value.
            value (SSAValue): The value to insert.
        """
        args = self.field
        self._replace_args((*args[:idx], value, *args[idx:]))

    def get_slice(self, name: str) -> slice:
        """Get the slice of the arguments.

        Args:
            name (str): The name of the slice.

        Returns:
            slice: The slice of the arguments. A single-index entry `i` is
                returned as `slice(i, i + 1)`.
        """
        index = self.node._name_args_slice[name]
        if isinstance(index, int):
            return slice(index, index + 1)
        return index

    def print_impl(self, printer: Printer) -> None:
        printer.print_seq(self.field, delim=", ", prefix="[", suffix="]")
|
90
|
+
|
91
|
+
|
92
|
+
@dataclass
class ResultList(
    MutableSequenceView[list[ResultValue], "Statement", ResultValue], Printable
):
    """A View object that contains a list of ResultValue of a Statement.

    Description:
        This is a proxy object that provides a safe API to manipulate the
        result values of a statement. Direct assignment of result values is
        rejected.

    !!! note "Pretty Printing"
        This object is pretty printable via
        [`.print()`][kirin.print.printable.Printable.print] method.
    """

    def __setitem__(
        self, idx: int | slice, value: ResultValue | Sequence[ResultValue]
    ) -> None:
        raise NotImplementedError("Cannot set result value directly")

    def __delitem__(self, idx: int) -> None:
        """Delete the result value at `idx` and remove it from the statement.

        The value is deleted before it is removed from the list. If
        `ResultValue.delete()` raises (presumably because the value still has
        uses; TODO confirm), the statement's result list is left unchanged
        instead of being half-mutated.
        """
        result = self.field[idx]
        result.delete()
        del self.field[idx]
        # NOTE(review): the remaining results are not renumbered here; confirm
        # whether their recorded index should be updated after a deletion.

    @property
    def types(self) -> Sequence[TypeAttribute]:
        """Get the result types of the Statement.

        Returns:
            Sequence[TypeAttribute]: type of each result value.
        """
        return [result.type for result in self.field]

    def print_impl(self, printer: Printer) -> None:
        printer.print_seq(self.field, delim=", ", prefix="[", suffix="]")
|
122
|
+
|
123
|
+
|
124
|
+
@dataclass(repr=False)
class Statement(IRNode["Block"]):
    """The Statement is an instruction in the IR.

    !!! note "Pretty Printing"
        This object is pretty printable via
        [`.print()`][kirin.print.printable.Printable.print] method.
    """

    name: ClassVar[str]
    dialect: ClassVar[Dialect | None] = field(default=None, init=False, repr=False)
    traits: ClassVar[frozenset[StmtTrait]]
    _arg_groups: ClassVar[frozenset[str]] = frozenset()

    _args: tuple[SSAValue, ...] = field(init=False)
    _results: list[ResultValue] = field(init=False, default_factory=list)
    successors: list[Block] = field(init=False)
    _regions: list[Region] = field(init=False)
    attributes: dict[str, Attribute] = field(init=False)

    parent: Block | None = field(default=None, init=False, repr=False)
    _next_stmt: Statement | None = field(default=None, init=False, repr=False)
    _prev_stmt: Statement | None = field(default=None, init=False, repr=False)

    # NOTE: This is only for syntax sugar to provide
    # access to args via the properties
    _name_args_slice: dict[str, int | slice] = field(
        init=False, repr=False, default_factory=dict
    )
    source: SourceInfo | None = field(default=None, init=False, repr=False)

    @property
    def parent_stmt(self) -> Statement | None:
        """Get the parent statement.

        Returns:
            Statement | None: The `parent_stmt` of the enclosing Block, or
                None if this Statement is not attached to a Block.
        """
        if self.parent_node is None:
            return None
        return self.parent_node.parent_stmt

    @property
    def parent_node(self) -> Block | None:
        """Get the parent node.

        Returns:
            Block | None: The parent node.
        """
        return self.parent

    @parent_node.setter
    def parent_node(self, parent: Block | None) -> None:
        """Set the parent Block."""
        from kirin.ir.nodes.block import Block

        self.assert_parent(Block, parent)
        self.parent = parent

    @property
    def parent_region(self) -> Region | None:
        """Get the parent Region.

        Returns:
            Region | None: The parent Region.
        """
        if (p := self.parent_node) is not None:
            return p.parent_node
        return None

    @property
    def parent_block(self) -> Block | None:
        """Get the parent Block.

        Returns:
            Block | None: The parent Block.
        """
        return self.parent_node

    @property
    def next_stmt(self) -> Statement | None:
        """Get the next statement."""
        return self._next_stmt

    @next_stmt.setter
    def next_stmt(self, stmt: Statement) -> None:
        """Set the next statement.

        Note:
            Do not directly call this API. use `stmt.insert_after(self)` instead.

        """
        raise ValueError(
            "Cannot set next_stmt directly, use stmt.insert_after(self) or stmt.insert_before(self)"
        )

    @property
    def prev_stmt(self) -> Statement | None:
        """Get the previous statement."""
        return self._prev_stmt

    @prev_stmt.setter
    def prev_stmt(self, stmt: Statement) -> None:
        """Set the previous statement.

        Note:
            Do not directly call this API. use `stmt.insert_before(self)` instead

        """
        raise ValueError(
            "Cannot set prev_stmt directly, use stmt.insert_after(self) or stmt.insert_before(self)"
        )

    def insert_after(self, stmt: Statement) -> None:
        """Insert the current Statement after the input Statement.

        Args:
            stmt (Statement): Input Statement.

        Raises:
            ValueError: If the current Statement is still linked to a
                neighbouring Statement (detach it first).

        Example:
            The following example demonstrates how to insert a Statement after another Statement.
            After `insert_after` is called, `stmt1` will be inserted after `stmt2`, which appears in IR in the order (stmt2 -> stmt1)
            ```python
            stmt1 = Statement()
            stmt2 = Statement()
            stmt1.insert_after(stmt2)
            ```
        """
        # A Statement at the head or tail of a block has only one neighbour
        # link set, so either link being present means it is still linked.
        # NOTE(review): a Statement that is the only one in its block has no
        # links but does have a parent; that case is not rejected here.
        if self._next_stmt is not None or self._prev_stmt is not None:
            raise ValueError(
                f"Cannot insert after a statement that is already in a block: {self.name}"
            )

        if stmt._next_stmt is not None:
            stmt._next_stmt._prev_stmt = self

        self._prev_stmt = stmt
        self._next_stmt = stmt._next_stmt

        self.parent = stmt.parent
        stmt._next_stmt = self

        if self.parent is not None:
            self.parent._stmt_len += 1

            # Appended at the tail: the block's last statement is now this one.
            if self._next_stmt is None:
                self.parent._last_stmt = self

    def insert_before(self, stmt: Statement) -> None:
        """Insert the current Statement before the input Statement.

        Args:
            stmt (Statement): Input Statement.

        Raises:
            ValueError: If the current Statement is still linked to a
                neighbouring Statement (detach it first).

        Example:
            The following example demonstrates how to insert a Statement before another Statement.
            After `insert_before` is called, `stmt1` will be inserted before `stmt2`, which appears in IR in the order (stmt1 -> stmt2)
            ```python
            stmt1 = Statement()
            stmt2 = Statement()
            stmt1.insert_before(stmt2)
            ```
        """
        # A Statement at the head or tail of a block has only one neighbour
        # link set, so either link being present means it is still linked.
        if self._next_stmt is not None or self._prev_stmt is not None:
            raise ValueError(
                f"Cannot insert before a statement that is already in a block: {self.name}"
            )

        if stmt._prev_stmt is not None:
            stmt._prev_stmt._next_stmt = self

        self._next_stmt = stmt
        self._prev_stmt = stmt._prev_stmt

        self.parent = stmt.parent
        stmt._prev_stmt = self

        if self.parent is not None:
            self.parent._stmt_len += 1

            # Inserted at the head: the block's first statement is now this one.
            if self._prev_stmt is None:
                self.parent._first_stmt = self

    def replace_by(self, stmt: Statement) -> None:
        """Replace the current Statement by the input Statement.

        The input Statement is inserted right before the current one. Each
        result of the current Statement is replaced positionally by the
        corresponding result of the input Statement (pairs are formed with
        `zip`, so extra results on either side are left unpaired), carrying
        over result names. Finally the current Statement is deleted.

        Args:
            stmt (Statement): Input Statement.
        """
        stmt.insert_before(self)
        for result, old_result in zip(stmt._results, self._results):
            old_result.replace_by(result)
            if old_result.name:
                result.name = old_result.name
        self.delete()

    @property
    def args(self) -> ArgumentList:
        """Get the arguments of the Statement.

        Returns:
            ArgumentList: The arguments View of the Statement.
        """
        return ArgumentList(self, self._args)

    @args.setter
    def args(self, args: Sequence[SSAValue]) -> None:
        """Set the arguments of the Statement.

        Unregisters the `Use(self, idx)` of every current argument and
        registers a `Use(self, idx)` for every new argument at its position.

        Args:
            args (Sequence[SSAValue]): The arguments to set.
        """
        new = tuple(args)
        for idx, arg in enumerate(self._args):
            arg.remove_use(Use(self, idx))
        for idx, arg in enumerate(new):
            arg.add_use(Use(self, idx))
        self._args = new

    @property
    def results(self) -> ResultList:
        """Get the result values of the Statement.

        Returns:
            ResultList: The result values View of the Statement.
        """
        return ResultList(self, self._results)

    @property
    def regions(self) -> list[Region]:
        """Get a list of regions of the Statement.

        Returns:
            list[Region]: The list of regions of the Statement.
        """
        return self._regions

    @regions.setter
    def regions(self, regions: list[Region]) -> None:
        """Set the regions of the Statement.

        Old regions have their parent cleared before the new regions are
        re-parented to this Statement, so a region present in both lists
        ends up owned by this Statement.
        """
        for region in self._regions:
            region._parent = None
        for region in regions:
            region._parent = self
        self._regions = regions

    def drop_all_references(self) -> None:
        """Drop the references held by this Statement.

        Unsets the parent Block, unregisters this Statement's `Use` of each of
        its arguments, and recursively drops the references held inside its
        regions. The argument tuple itself is left as-is.
        """
        self.parent = None
        for idx, arg in enumerate(self._args):
            arg.remove_use(Use(self, idx))
        for region in self._regions:
            region.drop_all_references()

    def delete(self, safe: bool = True) -> None:
        """Delete the Statement completely from the IR graph.

        Note:
            This method will detach + remove references of the Statement,
            then delete each of its result values.

        Args:
            safe (bool, optional): If True, raise error if there is anything that still reference components in the Statement. Defaults to True.
        """
        self.detach()
        self.drop_all_references()
        for result in self._results:
            result.delete(safe=safe)

    def detach(self) -> None:
        """Detach the statement from its parent block.

        Unlinks the Statement from its neighbours, updates the block's first /
        last statement and statement count, and clears the parent. Does
        nothing if the Statement has no parent.
        """
        if self.parent is None:
            return

        parent: Block = self.parent
        prev_stmt = self.prev_stmt
        next_stmt = self.next_stmt

        if prev_stmt is not None:
            prev_stmt._next_stmt = next_stmt
            self._prev_stmt = None
        else:
            assert (
                parent._first_stmt is self
            ), "Invalid statement, has no prev_stmt but not first_stmt"
            parent._first_stmt = next_stmt

        if next_stmt is not None:
            next_stmt._prev_stmt = prev_stmt
            self._next_stmt = None
        else:
            assert (
                parent._last_stmt is self
            ), "Invalid statement, has no next_stmt but not last_stmt"
            parent._last_stmt = prev_stmt

        self.parent = None
        parent._stmt_len -= 1
        return

    def __post_init__(self):
        assert self.name != ""
        assert isinstance(self.name, str)

    def __init__(
        self,
        *,
        args: Sequence[SSAValue] = (),
        regions: Sequence[Region] = (),
        successors: Sequence[Block] = (),
        attributes: Mapping[str, Attribute] = {},
        results: Sequence[ResultValue] = (),
        result_types: Sequence[TypeAttribute] = (),
        args_slice: Mapping[str, int | slice] = {},
        source: SourceInfo | None = None,
    ) -> None:
        """Initialize the Statement.

        Args:
            args (Sequence[SSAValue], optional): The arguments of the Statement. Defaults to ().
            regions (Sequence[Region], optional): The regions owned by the Statement. Defaults to ().
            successors (Sequence[Block], optional): The successors of the Statement. Defaults to ().
            attributes (Mapping[str, Attribute], optional): The attributes of the Statement. Defaults to {}.
            results (Sequence[ResultValue], optional): The result values of the Statement. Defaults to ().
            result_types (Sequence[TypeAttribute], optional): The result types of the Statement; one new `ResultValue` is created per type. Defaults to ().
            args_slice (Mapping[str, int | slice], optional): The arguments slice of the Statement. Defaults to {}.
            source (SourceInfo | None, optional): The source information of the Statement for debugging/stacktracing. Defaults to None.

        Note:
            `results` and `result_types` are mutually exclusive. The mapping
            defaults are only read (copied via `dict()`), never mutated.
        """
        super().__init__()
        self._args = ()
        self._regions = []
        self._name_args_slice = dict(args_slice)
        self.source = source
        # goes through the `args` setter so each argument registers its `Use`
        self.args = args

        if results:
            self._results = list(results)
            assert (
                len(result_types) == 0
            ), "expect either results or result_types specified, got both"

        if result_types:
            self._results = [
                ResultValue(self, idx, type=ty) for idx, ty in enumerate(result_types)
            ]

        if not results and not result_types:
            self._results = []

        self.successors = list(successors)
        self.attributes = dict(attributes)
        # goes through the `regions` setter so each region is re-parented
        self.regions = list(regions)

        self.parent = None
        self._next_stmt = None
        self._prev_stmt = None
        self.__post_init__()

    @classmethod
    def from_stmt(
        cls,
        other: Statement,
        args: Sequence[SSAValue] | None = None,
        regions: list[Region] | None = None,
        successors: list[Block] | None = None,
        attributes: dict[str, Attribute] | None = None,
    ) -> Self:
        """Create a similar Statement with new `ResultValue` and without
        attaching to any parent block. This still references to the old successor
        and regions.

        Any of `args`, `regions`, `successors`, `attributes` left as None is
        taken from `other`; an explicitly empty value is honoured as empty.
        New results are created with the same types as `other`'s results.
        """
        obj = cls.__new__(cls)
        # `is None` (not `or`) so that e.g. `args=[]` really means "no args".
        Statement.__init__(
            obj,
            args=other._args if args is None else args,
            regions=other._regions if regions is None else regions,
            successors=other.successors if successors is None else successors,
            attributes=other.attributes if attributes is None else attributes,
            result_types=[result.type for result in other._results],
            args_slice=other._name_args_slice,
        )
        return obj

    def walk(
        self,
        *,
        reverse: bool = False,
        region_first: bool = False,
        include_self: bool = True,
    ) -> Iterator[Statement]:
        """Traverse this Statement and the Statements nested in its Regions.

        Args:
            reverse (bool, optional): If walk in the reversed manner. Defaults to False.
            region_first (bool, optional): If True, yield the Statements of the Regions before the Statement itself; otherwise yield the Statement first. Defaults to False.
            include_self (bool, optional): If the walk should include the Statement itself. Defaults to True.

        Yields:
            Iterator[Statement]: An iterator that yield Statements of Blocks in the Region, in the specified order.
        """
        if include_self and not region_first:
            yield self

        for region in reversed(self.regions) if reverse else self.regions:
            yield from region.walk(reverse=reverse, region_first=region_first)

        if include_self and region_first:
            yield self

    def is_structurally_equal(
        self,
        other: Self,
        context: dict[IRNode | SSAValue, IRNode | SSAValue] | None = None,
    ) -> bool:
        """Check if the Statement is structurally equal to another Statement.

        Args:
            other (Self): The other Statement to compare with.
            context (dict[IRNode | SSAValue, IRNode | SSAValue] | None, optional): A map of IRNode/SSAValue to hint that they are equivalent so the check will treat them as equivalent. Defaults to None.

        Returns:
            bool: True if the IRNode is structurally equal to the other. On
                success, each result of this Statement is mapped in `context`
                to the positionally corresponding result of `other`.
        """
        if context is None:
            context = {}

        if self.name != other.name:
            return False

        if (
            len(self.args) != len(other.args)
            or len(self.regions) != len(other.regions)
            or len(self.successors) != len(other.successors)
            or self.attributes != other.attributes
        ):
            return False

        # When both are attached, their parents must already be paired in context.
        if (
            self.parent is not None
            and other.parent is not None
            and context.get(self.parent) != other.parent
        ):
            return False

        if not all(
            context.get(arg, arg) == other_arg
            for arg, other_arg in zip(self.args, other.args)
        ):
            return False

        if not all(
            context.get(successor, successor) == other_successor
            for successor, other_successor in zip(self.successors, other.successors)
        ):
            return False

        if not all(
            region.is_structurally_equal(other_region, context)
            for region, other_region in zip(self.regions, other.regions)
        ):
            return False

        # NOTE(review): result counts/types are not compared; results are only
        # paired up positionally via `zip`.
        for result, other_result in zip(self._results, other._results):
            context[result] = other_result

        return True

    def __hash__(self) -> int:
        # Identity-based: two distinct Statement objects never hash equal.
        return id(self)

    def print_impl(self, printer: Printer) -> None:
        from kirin.decl import fields as stmt_fields

        printer.print_name(self)
        printer.plain_print("(")
        # Field metadata is only needed to honour per-argument `print` flags;
        # look it up once instead of once per named argument group.
        fields = stmt_fields(self) if self._name_args_slice else None
        for idx, (name, s) in enumerate(self._name_args_slice.items()):
            values = self.args[s]
            # Print the `name=` prefix unless the field says it is hidden.
            if not fields or fields.args[name].print:
                with printer.rich(style="orange4"):
                    printer.plain_print(name, "=")

            if isinstance(values, SSAValue):
                printer.print(values)
            else:
                printer.print_seq(values, delim=", ")

            if idx < len(self._name_args_slice) - 1:
                printer.plain_print(", ")

        # NOTE: args are specified manually without names
        if not self._name_args_slice and self._args:
            printer.print_seq(self._args, delim=", ")

        printer.plain_print(")")

        if self.successors:
            printer.print_seq(
                (printer.state.block_id[successor] for successor in self.successors),
                emit=printer.plain_print,
                delim=", ",
                prefix="[",
                suffix="]",
            )

        if self.regions:
            printer.print_seq(
                self.regions,
                delim=" ",
                prefix=" (",
                suffix=")",
            )

        if self.attributes:
            printer.plain_print("{")
            with printer.rich(highlight=True):
                printer.print_mapping(self.attributes, delim=", ")
            printer.plain_print("}")

        if self._results:
            with printer.rich(style="black"):
                printer.plain_print(" : ")
                printer.print_seq(
                    [result.type for result in self._results],
                    delim=", ",
                )

    def get_attr_or_prop(self, key: str) -> Attribute | None:
        """Get the attribute or property of the Statement.

        Args:
            key (str): The key of the attribute or property.

        Returns:
            Attribute | None: The attribute stored under `key`, or None if absent.
        """
        return self.attributes.get(key)

    @classmethod
    def has_trait(cls, trait_type: type[StmtTrait]) -> bool:
        """Check if the Statement has a specific trait.

        Args:
            trait_type (type[StmtTrait]): The type of trait to check for.

        Returns:
            bool: True if the class has the specified trait, False otherwise.
        """
        return any(isinstance(trait, trait_type) for trait in cls.traits)

    TraitType = TypeVar("TraitType", bound=StmtTrait)

    @classmethod
    def get_trait(cls, trait: type[TraitType]) -> TraitType | None:
        """Get the first trait of the Statement class that is an instance of
        `trait`, or None if there is none."""
        for t in cls.traits:
            if isinstance(t, trait):
                return t
        return None

    def expect_one_result(self) -> ResultValue:
        """Return the only result of the Statement.

        Raises:
            ValueError: If the Statement does not have exactly one result.
        """
        if len(self._results) != 1:
            raise ValueError(f"expected one result, got {len(self._results)}")
        return self._results[0]

    # NOTE: statement should implement typecheck
    # this is done automatically via @statement, but
    # in the case manual implementation is needed,
    # it should be implemented here.
    # NOTE: not an @abstractmethod to make linter happy
    def typecheck(self) -> None:
        """Check the type of the statement.

        Note:
            1. Statement should implement typecheck.
            this is done automatically via @statement, but
            in the case manual implementation is needed,
            it should be implemented here.
            2. This API should be called after all the types are figured out (by typeinfer)

        Raises:
            NotImplementedError: Always, in this base implementation.
        """
        raise NotImplementedError

    def verify(self) -> None:
        """Verify the statement. The base implementation performs no checks."""
        return
|