kirin-toolchain 0.13.0__py3-none-any.whl
Sign up to get free protection for your applications and to get access to all the features.
- kirin/__init__.py +7 -0
- kirin/analysis/__init__.py +24 -0
- kirin/analysis/callgraph.py +61 -0
- kirin/analysis/cfg.py +112 -0
- kirin/analysis/const/__init__.py +20 -0
- kirin/analysis/const/_visitor.py +2 -0
- kirin/analysis/const/_visitor.pyi +8 -0
- kirin/analysis/const/lattice.py +219 -0
- kirin/analysis/const/prop.py +116 -0
- kirin/analysis/forward.py +100 -0
- kirin/analysis/typeinfer/__init__.py +5 -0
- kirin/analysis/typeinfer/analysis.py +90 -0
- kirin/analysis/typeinfer/solve.py +141 -0
- kirin/decl/__init__.py +108 -0
- kirin/decl/base.py +65 -0
- kirin/decl/camel2snake.py +2 -0
- kirin/decl/emit/__init__.py +0 -0
- kirin/decl/emit/_create_fn.py +29 -0
- kirin/decl/emit/_set_new_attribute.py +22 -0
- kirin/decl/emit/dialect.py +8 -0
- kirin/decl/emit/init.py +277 -0
- kirin/decl/emit/name.py +10 -0
- kirin/decl/emit/property.py +182 -0
- kirin/decl/emit/repr.py +31 -0
- kirin/decl/emit/traits.py +13 -0
- kirin/decl/emit/typecheck.py +77 -0
- kirin/decl/emit/verify.py +51 -0
- kirin/decl/info.py +346 -0
- kirin/decl/scan_fields.py +157 -0
- kirin/decl/verify.py +69 -0
- kirin/dialects/__init__.py +14 -0
- kirin/dialects/_pprint_helper.py +53 -0
- kirin/dialects/cf/__init__.py +20 -0
- kirin/dialects/cf/constprop.py +51 -0
- kirin/dialects/cf/dialect.py +3 -0
- kirin/dialects/cf/emit.py +58 -0
- kirin/dialects/cf/interp.py +24 -0
- kirin/dialects/cf/stmts.py +68 -0
- kirin/dialects/cf/typeinfer.py +27 -0
- kirin/dialects/eltype.py +23 -0
- kirin/dialects/func/__init__.py +20 -0
- kirin/dialects/func/attrs.py +39 -0
- kirin/dialects/func/constprop.py +138 -0
- kirin/dialects/func/dialect.py +3 -0
- kirin/dialects/func/emit.py +80 -0
- kirin/dialects/func/interp.py +68 -0
- kirin/dialects/func/stmts.py +233 -0
- kirin/dialects/func/typeinfer.py +124 -0
- kirin/dialects/ilist/__init__.py +33 -0
- kirin/dialects/ilist/_dialect.py +3 -0
- kirin/dialects/ilist/_wrapper.py +51 -0
- kirin/dialects/ilist/interp.py +85 -0
- kirin/dialects/ilist/lowering.py +25 -0
- kirin/dialects/ilist/passes.py +32 -0
- kirin/dialects/ilist/rewrite/__init__.py +3 -0
- kirin/dialects/ilist/rewrite/const.py +45 -0
- kirin/dialects/ilist/rewrite/list.py +38 -0
- kirin/dialects/ilist/rewrite/unroll.py +131 -0
- kirin/dialects/ilist/runtime.py +63 -0
- kirin/dialects/ilist/stmts.py +102 -0
- kirin/dialects/ilist/typeinfer.py +120 -0
- kirin/dialects/lowering/__init__.py +7 -0
- kirin/dialects/lowering/call.py +48 -0
- kirin/dialects/lowering/cf.py +206 -0
- kirin/dialects/lowering/func.py +134 -0
- kirin/dialects/math/__init__.py +41 -0
- kirin/dialects/math/_gen.py +176 -0
- kirin/dialects/math/dialect.py +3 -0
- kirin/dialects/math/interp.py +190 -0
- kirin/dialects/math/stmts.py +369 -0
- kirin/dialects/module.py +139 -0
- kirin/dialects/py/__init__.py +40 -0
- kirin/dialects/py/assertion.py +91 -0
- kirin/dialects/py/assign.py +103 -0
- kirin/dialects/py/attr.py +59 -0
- kirin/dialects/py/base.py +34 -0
- kirin/dialects/py/binop/__init__.py +23 -0
- kirin/dialects/py/binop/_dialect.py +3 -0
- kirin/dialects/py/binop/interp.py +60 -0
- kirin/dialects/py/binop/julia.py +33 -0
- kirin/dialects/py/binop/lowering.py +22 -0
- kirin/dialects/py/binop/stmts.py +79 -0
- kirin/dialects/py/binop/typeinfer.py +108 -0
- kirin/dialects/py/boolop.py +84 -0
- kirin/dialects/py/builtin.py +78 -0
- kirin/dialects/py/cmp/__init__.py +16 -0
- kirin/dialects/py/cmp/_dialect.py +3 -0
- kirin/dialects/py/cmp/interp.py +48 -0
- kirin/dialects/py/cmp/julia.py +33 -0
- kirin/dialects/py/cmp/lowering.py +45 -0
- kirin/dialects/py/cmp/stmts.py +62 -0
- kirin/dialects/py/constant.py +79 -0
- kirin/dialects/py/indexing.py +251 -0
- kirin/dialects/py/iterable.py +90 -0
- kirin/dialects/py/len.py +57 -0
- kirin/dialects/py/list/__init__.py +15 -0
- kirin/dialects/py/list/_dialect.py +3 -0
- kirin/dialects/py/list/interp.py +21 -0
- kirin/dialects/py/list/lowering.py +25 -0
- kirin/dialects/py/list/stmts.py +22 -0
- kirin/dialects/py/list/typeinfer.py +54 -0
- kirin/dialects/py/range.py +76 -0
- kirin/dialects/py/slice.py +120 -0
- kirin/dialects/py/tuple.py +109 -0
- kirin/dialects/py/unary/__init__.py +24 -0
- kirin/dialects/py/unary/_dialect.py +3 -0
- kirin/dialects/py/unary/constprop.py +20 -0
- kirin/dialects/py/unary/interp.py +24 -0
- kirin/dialects/py/unary/julia.py +21 -0
- kirin/dialects/py/unary/lowering.py +22 -0
- kirin/dialects/py/unary/stmts.py +33 -0
- kirin/dialects/py/unary/typeinfer.py +23 -0
- kirin/dialects/py/unpack.py +90 -0
- kirin/dialects/scf/__init__.py +23 -0
- kirin/dialects/scf/_dialect.py +3 -0
- kirin/dialects/scf/absint.py +64 -0
- kirin/dialects/scf/constprop.py +140 -0
- kirin/dialects/scf/interp.py +35 -0
- kirin/dialects/scf/lowering.py +123 -0
- kirin/dialects/scf/stmts.py +250 -0
- kirin/dialects/scf/trim.py +36 -0
- kirin/dialects/scf/typeinfer.py +58 -0
- kirin/dialects/scf/unroll.py +92 -0
- kirin/emit/__init__.py +3 -0
- kirin/emit/abc.py +89 -0
- kirin/emit/abc.pyi +38 -0
- kirin/emit/exceptions.py +5 -0
- kirin/emit/julia.py +63 -0
- kirin/emit/str.py +51 -0
- kirin/exceptions.py +59 -0
- kirin/graph.py +34 -0
- kirin/idtable.py +57 -0
- kirin/interp/__init__.py +39 -0
- kirin/interp/abstract.py +253 -0
- kirin/interp/base.py +438 -0
- kirin/interp/concrete.py +62 -0
- kirin/interp/exceptions.py +26 -0
- kirin/interp/frame.py +151 -0
- kirin/interp/impl.py +197 -0
- kirin/interp/result.py +93 -0
- kirin/interp/state.py +71 -0
- kirin/interp/table.py +40 -0
- kirin/interp/value.py +73 -0
- kirin/ir/__init__.py +46 -0
- kirin/ir/attrs/__init__.py +20 -0
- kirin/ir/attrs/_types.py +8 -0
- kirin/ir/attrs/_types.pyi +13 -0
- kirin/ir/attrs/abc.py +46 -0
- kirin/ir/attrs/py.py +45 -0
- kirin/ir/attrs/types.py +522 -0
- kirin/ir/dialect.py +125 -0
- kirin/ir/group.py +249 -0
- kirin/ir/method.py +118 -0
- kirin/ir/nodes/__init__.py +7 -0
- kirin/ir/nodes/base.py +149 -0
- kirin/ir/nodes/block.py +458 -0
- kirin/ir/nodes/region.py +337 -0
- kirin/ir/nodes/stmt.py +713 -0
- kirin/ir/nodes/view.py +142 -0
- kirin/ir/ssa.py +204 -0
- kirin/ir/traits/__init__.py +36 -0
- kirin/ir/traits/abc.py +42 -0
- kirin/ir/traits/basic.py +78 -0
- kirin/ir/traits/callable.py +51 -0
- kirin/ir/traits/lowering/__init__.py +2 -0
- kirin/ir/traits/lowering/call.py +37 -0
- kirin/ir/traits/lowering/context.py +120 -0
- kirin/ir/traits/region/__init__.py +2 -0
- kirin/ir/traits/region/ssacfg.py +22 -0
- kirin/ir/traits/symbol.py +57 -0
- kirin/ir/use.py +17 -0
- kirin/lattice/__init__.py +13 -0
- kirin/lattice/abc.py +128 -0
- kirin/lattice/empty.py +25 -0
- kirin/lattice/mixin.py +51 -0
- kirin/lowering/__init__.py +7 -0
- kirin/lowering/binding.py +65 -0
- kirin/lowering/core.py +72 -0
- kirin/lowering/dialect.py +35 -0
- kirin/lowering/dialect.pyi +183 -0
- kirin/lowering/frame.py +171 -0
- kirin/lowering/result.py +68 -0
- kirin/lowering/state.py +441 -0
- kirin/lowering/stream.py +53 -0
- kirin/passes/__init__.py +3 -0
- kirin/passes/abc.py +44 -0
- kirin/passes/aggressive/__init__.py +1 -0
- kirin/passes/aggressive/fold.py +43 -0
- kirin/passes/fold.py +45 -0
- kirin/passes/inline.py +25 -0
- kirin/passes/typeinfer.py +25 -0
- kirin/prelude.py +197 -0
- kirin/print/__init__.py +15 -0
- kirin/print/printable.py +141 -0
- kirin/print/printer.py +415 -0
- kirin/py.typed +0 -0
- kirin/registry.py +105 -0
- kirin/registry.pyi +52 -0
- kirin/rewrite/__init__.py +14 -0
- kirin/rewrite/abc.py +43 -0
- kirin/rewrite/aggressive/__init__.py +1 -0
- kirin/rewrite/aggressive/fold.py +43 -0
- kirin/rewrite/alias.py +16 -0
- kirin/rewrite/apply_type.py +47 -0
- kirin/rewrite/call2invoke.py +34 -0
- kirin/rewrite/chain.py +39 -0
- kirin/rewrite/compactify.py +288 -0
- kirin/rewrite/cse.py +48 -0
- kirin/rewrite/dce.py +19 -0
- kirin/rewrite/fixpoint.py +34 -0
- kirin/rewrite/fold.py +57 -0
- kirin/rewrite/getfield.py +21 -0
- kirin/rewrite/getitem.py +37 -0
- kirin/rewrite/inline.py +143 -0
- kirin/rewrite/result.py +15 -0
- kirin/rewrite/walk.py +83 -0
- kirin/rewrite/wrap_const.py +55 -0
- kirin/source.py +21 -0
- kirin/symbol_table.py +27 -0
- kirin/types.py +34 -0
- kirin/worklist.py +30 -0
- kirin_toolchain-0.13.0.dist-info/METADATA +42 -0
- kirin_toolchain-0.13.0.dist-info/RECORD +225 -0
- kirin_toolchain-0.13.0.dist-info/WHEEL +4 -0
- kirin_toolchain-0.13.0.dist-info/licenses/LICENSE +234 -0
kirin/ir/nodes/block.py
ADDED
@@ -0,0 +1,458 @@
|
|
1
|
+
from __future__ import annotations
|
2
|
+
|
3
|
+
from typing import TYPE_CHECKING, Iterable, Iterator
|
4
|
+
from dataclasses import field, dataclass
|
5
|
+
from collections.abc import Sequence
|
6
|
+
|
7
|
+
from typing_extensions import Self
|
8
|
+
|
9
|
+
from kirin.print import Printer
|
10
|
+
from kirin.ir.ssa import SSAValue, BlockArgument
|
11
|
+
from kirin.exceptions import VerificationError
|
12
|
+
from kirin.ir.nodes.base import IRNode
|
13
|
+
from kirin.ir.nodes.view import View, MutableSequenceView
|
14
|
+
|
15
|
+
if TYPE_CHECKING:
|
16
|
+
from kirin.ir.nodes.stmt import Statement
|
17
|
+
from kirin.ir.attrs.types import TypeAttribute
|
18
|
+
from kirin.ir.nodes.region import Region
|
19
|
+
|
20
|
+
|
21
|
+
@dataclass
class BlockArguments(MutableSequenceView[tuple, "Block", BlockArgument]):
    """A View object that contains a list of BlockArgument.

    Description:
        This is a proxy object that provides a safe API to manipulate the
        arguments of a Block. All mutating methods read and write the owning
        Block's `_args` tuple directly (never the snapshot stored in this
        view), so repeated mutations through the same view object stay
        consistent with the Block.
    """

    def append_from(self, typ: TypeAttribute, name: str | None = None) -> BlockArgument:
        """Append a new argument to the Block that this View references.

        Description:
            This method will create a new [`BlockArgument`][kirin.ir.BlockArgument] and append it to the argument list
            of the referenced `Block`.

        Args:
            typ (TypeAttribute): The type of the argument.
            name (str | None, optional): name of the argument. Defaults to `None`.

        Returns:
            BlockArgument: The newly created [`BlockArgument`][kirin.ir.BlockArgument].
        """
        new_arg = BlockArgument(self.node, len(self.node._args), typ)
        if name:
            new_arg.name = name

        self.node._args += (new_arg,)
        return new_arg

    def insert_from(
        self, idx: int, typ: TypeAttribute, name: str | None = None
    ) -> BlockArgument:
        """Insert a new argument to the Block that this View references.

        Description:
            This method will create a new `BlockArgument` and insert it to the argument list
            of the referenced Block at the specified index. Arguments at or after
            `idx` have their `index` shifted up by one.

        Args:
            idx (int): Insert location index, `0 <= idx <= len(block args)`.
            typ (TypeAttribute): The type of the argument.
            name (str | None, optional): Name of the argument. Defaults to `None`.

        Returns:
            BlockArgument: The newly created BlockArgument.

        Raises:
            ValueError: If `idx` is outside `[0, len(block args)]`.
        """
        if idx < 0 or idx > len(self.node._args):
            raise ValueError("Invalid index")

        new_arg = BlockArgument(self.node, idx, typ)
        if name:
            new_arg.name = name

        for arg in self.node._args[idx:]:
            arg.index += 1
        self.node._args = self.node._args[:idx] + (new_arg,) + self.node._args[idx:]
        return new_arg

    def delete(self, arg: BlockArgument, safe: bool = True) -> None:
        """Delete a BlockArgument from the Block that this View references.

        Args:
            arg (BlockArgument): The argument to remove; it must belong to the referenced Block.
            safe (bool, optional): If True, error will be raised if the BlockArgument has any Use by others. Defaults to True.

        Raises:
            ValueError: If the argument does not belong to the referenced block.
        """
        if arg.block is not self.node:
            raise ValueError("Attempt to delete an argument that is not in the block")

        position = arg.index
        # Delete the SSA value first: when `safe` rejects the deletion, the
        # Block is still untouched instead of having lost an argument that
        # is still referenced elsewhere.
        arg.delete(safe=safe)

        # Use the live tuple; `self.field` is the snapshot captured when this
        # view was created and may be stale after earlier mutations.
        current = self.node._args
        for block_arg in current[position + 1 :]:
            block_arg.index -= 1
        self.node._args = current[:position] + current[position + 1 :]

    def __delitem__(self, idx: int) -> None:
        # Index into the live argument tuple, not the view's snapshot.
        self.delete(self.node._args[idx])
|
103
|
+
|
104
|
+
|
105
|
+
@dataclass
class BlockStmtIterator:
    """Forward iterator over the Statements of a Block.

    Starting from the Statement it was created with, it follows each
    Statement's ``next_stmt`` link and stops once that link is ``None``.
    """

    next_stmt: Statement | None

    def __iter__(self) -> BlockStmtIterator:
        return self

    def __next__(self) -> Statement:
        current = self.next_stmt
        if current is None:
            raise StopIteration
        # advance the cursor before handing out the current statement
        self.next_stmt = current.next_stmt
        return current
|
120
|
+
|
121
|
+
|
122
|
+
@dataclass
class BlockStmtsReverseIterator:
    """Backward iterator over the Statements of a Block.

    Starting from the Statement it was created with, it follows each
    Statement's ``prev_stmt`` link and stops once that link is ``None``.
    """

    next_stmt: Statement | None

    def __iter__(self) -> BlockStmtsReverseIterator:
        return self

    def __next__(self) -> Statement:
        current = self.next_stmt
        if current is None:
            raise StopIteration
        # move the cursor one step backwards before yielding
        self.next_stmt = current.prev_stmt
        return current
|
137
|
+
|
138
|
+
|
139
|
+
@dataclass
class BlockStmts(View["Block", "Statement"]):
    """A View object that contains a list of Statements.

    Description:
        This is a proxy object that provides a safe API to manipulate the
        statements of a Block. The statements are stored as a doubly linked
        list on the Block (`_first_stmt` / `_last_stmt` plus each Statement's
        `next_stmt` / `prev_stmt`), so positional access walks the list.
    """

    def __iter__(self) -> Iterator[Statement]:
        return BlockStmtIterator(self.node.first_stmt)

    def __len__(self) -> int:
        return self.node._stmt_len

    def __reversed__(self) -> Iterator[Statement]:
        return BlockStmtsReverseIterator(self.node.last_stmt)

    def __repr__(self) -> str:
        return f"BlockStmts(len={len(self)})"

    def __getitem__(self, index: int) -> Statement:
        # Deliberately unsupported: indexing a linked list is O(n), and the
        # explicit `at()` name makes that cost visible at the call site.
        raise NotImplementedError("Use at() instead")

    def at(self, index: int) -> Statement:
        """Return the Statement at `index` by walking the linked list.

        This plays the role of `__getitem__` (which is intentionally not
        supported) but costs O(n) time. Negative indices count from the end,
        as for Python sequences.

        Args:
            index (int): Index of the Statement.

        Returns:
            Statement: The Statement at the specified index.

        Raises:
            IndexError: If `index` is outside `[-len(self), len(self))`.
        """
        size = len(self)
        if index >= size or index < -size:
            raise IndexError("Index out of range")

        if index < 0:
            return self._at_reverse(-index - 1)

        return self._at_forward(index)

    def _at_forward(self, index: int) -> Statement:
        # Walk `index` steps from the first statement via `next_stmt`.
        if self.node.first_stmt is None:
            raise IndexError("Index out of range")

        stmt = self.node.first_stmt
        for _ in range(index):
            if stmt is None:
                raise IndexError("Index out of range")
            stmt = stmt.next_stmt

        if stmt is None:
            raise IndexError("Index out of range")
        return stmt

    def _at_reverse(self, index: int) -> Statement:
        # Walk `index` steps from the last statement via `prev_stmt`.
        if self.node.last_stmt is None:
            raise IndexError("Index out of range")

        stmt = self.node.last_stmt
        for _ in range(index):
            if stmt is None:
                raise IndexError("Index out of range")
            stmt = stmt.prev_stmt

        if stmt is None:
            raise IndexError("Index out of range")
        return stmt

    def append(self, value: Statement) -> None:
        """Append a Statement to the referenced Block.

        Args:
            value (Statement): A Statement to be appended.

        Raises:
            ValueError: If `value` is not a Statement, or the Block claims to
                be non-empty but has no last statement.
        """
        from kirin.ir.nodes.stmt import Statement

        if not isinstance(value, Statement):
            raise ValueError(f"Expected Statement, got {type(value).__name__}")

        if self.node._stmt_len == 0:  # empty block
            value.attach(self.node)
            self.node._first_stmt = value
            self.node._last_stmt = value
            self.node._stmt_len += 1
        elif self.node._last_stmt is not None:
            # Explicit `is not None`: a Statement's truthiness must not decide
            # whether the block is considered well-formed.
            value.insert_after(self.node._last_stmt)
        else:
            raise ValueError("Invalid block, last_stmt is None")
|
228
|
+
|
229
|
+
|
230
|
+
@dataclass
class Block(IRNode["Region"]):
    """
    Block consist of a list of Statements and optionally input arguments.

    Blocks compare by identity (`==` is `is`), consistent with their
    identity-based `__hash__`; use `is_structurally_equal` for structural
    comparison.

    !!! note "Pretty Printing"
        This object is pretty printable via
        [`.print()`][kirin.print.printable.Printable.print] method.
    """

    _args: tuple[BlockArgument, ...]

    # NOTE: we need linked list since stmts are inserted frequently
    _first_stmt: Statement | None = field(repr=False)
    _last_stmt: Statement | None = field(repr=False)
    _stmt_len: int = field(default=0, repr=False)

    parent: Region | None = field(default=None, repr=False)
    """Parent Region of the Block."""

    def __init__(
        self,
        stmts: Sequence[Statement] = (),
        argtypes: Iterable[TypeAttribute] = (),
    ):
        """
        Args:
            stmts (Sequence[Statement], optional): A list of statements. Defaults to ().
            argtypes (Iterable[TypeAttribute], optional): The type of the block arguments. Defaults to ().
        """
        super().__init__()
        self._args = tuple(
            BlockArgument(self, i, argtype) for i, argtype in enumerate(argtypes)
        )

        self._first_stmt = None
        self._last_stmt = None
        # NOTE(review): `_first_branch` / `_last_branch` are not read anywhere
        # in this module; kept in case other modules rely on them — TODO confirm.
        self._first_branch = None
        self._last_branch = None
        self._stmt_len = 0
        self.stmts.extend(stmts)

    @property
    def parent_stmt(self) -> Statement | None:
        """parent statement of the Block."""
        if self.parent is None:
            return None
        return self.parent.parent_node

    @property
    def parent_node(self) -> Region | None:
        """Get parent Region of the Block."""
        return self.parent

    @parent_node.setter
    def parent_node(self, parent: Region | None) -> None:
        """Set the parent Region of the Block."""
        from kirin.ir.nodes.region import Region

        self.assert_parent(Region, parent)
        self.parent = parent

    @property
    def args(self) -> BlockArguments:
        """Get the arguments of the Block.

        Returns:
            BlockArguments: The arguments view of the Block.
        """
        return BlockArguments(self, self._args)

    @property
    def first_stmt(self) -> Statement | None:
        """Get the first Statement of the Block.

        Returns:
            Statement | None: The first Statement of the Block.
        """
        return self._first_stmt

    @property
    def last_stmt(self) -> Statement | None:
        """Get the last Statement of the Block.

        Returns:
            Statement | None: The last Statement of the Block.
        """
        return self._last_stmt

    @property
    def stmts(self) -> BlockStmts:
        """Get the list of Statements of the Block.

        Returns:
            BlockStmts: The Statements of the Block.
        """
        return BlockStmts(self)

    def drop_all_references(self) -> None:
        """Remove all the dependency that reference/uses this Block."""
        self.parent = None
        for stmt in self.stmts:
            stmt.drop_all_references()

    def detach(self) -> None:
        """Detach this Block from the IR.

        Note:
            Detach only detach the Block from the IR graph. It does not remove uses that reference the Block.
        """
        if self.parent is None:
            return

        idx = self.parent[self]
        del self.parent._blocks[idx]
        del self.parent._block_idx[self]
        # Blocks that followed this one move up by one position.
        for block in self.parent._blocks[idx:]:
            self.parent._block_idx[block] -= 1
        self.parent = None

    def delete(self, safe: bool = True) -> None:
        """Delete the Block completely from the IR.

        Note:
            This method will detach + remove references of the block.

        Args:
            safe (bool, optional): If True, raise error if there is anything that still reference components in the block. Defaults to True.
        """
        self.detach()
        self.drop_all_references()
        for stmt in self.stmts:
            stmt.delete(safe=safe)

    def is_structurally_equal(
        self,
        other: Self,
        context: dict[IRNode | SSAValue, IRNode | SSAValue] | None = None,
    ) -> bool:
        """Check if the Block is structurally equal to another Block.

        Args:
            other (Self): The other Block to compare with.
            context (dict[IRNode | SSAValue, IRNode | SSAValue] | None, optional): A map of IRNode/SSAValue to hint that they are equivalent so the check will treat them as equivalent. Defaults to None.

        Returns:
            bool: True if the Block is structurally equal to the other Block.
        """
        if context is None:
            context = {}

        if len(self._args) != len(other._args) or len(self.stmts) != len(other.stmts):
            return False

        for arg, other_arg in zip(self._args, other._args):
            if arg.type != other_arg.type:
                return False
            context[arg] = other_arg

        context[self] = other
        if not all(
            stmt.is_structurally_equal(other_stmt, context)
            for stmt, other_stmt in zip(self.stmts, other.stmts)
        ):
            return False

        return True

    def __eq__(self, other: object) -> bool:
        # Identity equality, matching the identity-based __hash__ below.
        # Without this, @dataclass would generate a field-wise __eq__ that
        # breaks the hash/eq contract (e.g. two distinct empty Blocks would
        # compare equal while hashing differently).
        return self is other

    def __hash__(self) -> int:
        return id(self)

    def walk(
        self, *, reverse: bool = False, region_first: bool = False
    ) -> Iterator[Statement]:
        """Traverse the Statements in a Block.

        Args:
            reverse (bool, optional): If walk in the reversed manner. Defaults to False.
            region_first (bool, optional): If the walk should go through the Statement first or the Region of a Statement first. Defaults to False.

        Yields:
            Iterator[Statement]: An iterator that yield Statements in the Block in the specified order.
        """
        for stmt in reversed(self.stmts) if reverse else self.stmts:
            yield from stmt.walk(reverse=reverse, region_first=region_first)

    def print_impl(self, printer: Printer) -> None:
        printer.plain_print(printer.state.block_id[self])
        printer.print_seq(
            [printer.state.ssa_id[arg] for arg in self.args],
            delim=", ",
            prefix="(",
            suffix="):",
            emit=printer.plain_print,
        )

        if printer.analysis is not None:
            with printer.indent(increase=4, mark=False):
                for arg in self.args:
                    printer.print_newline()
                    with printer.rich(style="warning"):
                        printer.print_analysis(
                            arg, prefix=f"{printer.state.ssa_id[arg]} --> "
                        )

        with printer.indent(increase=2, mark=False):
            for stmt in self.stmts:
                printer.print_newline()
                printer.print_stmt(stmt)

    def typecheck(self) -> None:
        """Check the types of the Statements in the Block."""
        for stmt in self.stmts:
            stmt.typecheck()

    def verify(self) -> None:
        """Verify the correctness of the Block.

        Raises:
            VerificationError: If the Block is not correct.
        """
        # Import Region from its defining module (as the parent_node setter
        # does) rather than relying on kirin.ir.nodes.stmt re-exporting it.
        from kirin.ir.nodes.region import Region

        if not isinstance(self.parent, Region):
            raise VerificationError(self, "Parent is not a region")

        for stmt in self.stmts:
            stmt.verify()
|