kirin-toolchain 0.13.0__py3-none-any.whl

Sign up to get free protection for your applications and to get access to all the features.
Files changed (225) hide show
  1. kirin/__init__.py +7 -0
  2. kirin/analysis/__init__.py +24 -0
  3. kirin/analysis/callgraph.py +61 -0
  4. kirin/analysis/cfg.py +112 -0
  5. kirin/analysis/const/__init__.py +20 -0
  6. kirin/analysis/const/_visitor.py +2 -0
  7. kirin/analysis/const/_visitor.pyi +8 -0
  8. kirin/analysis/const/lattice.py +219 -0
  9. kirin/analysis/const/prop.py +116 -0
  10. kirin/analysis/forward.py +100 -0
  11. kirin/analysis/typeinfer/__init__.py +5 -0
  12. kirin/analysis/typeinfer/analysis.py +90 -0
  13. kirin/analysis/typeinfer/solve.py +141 -0
  14. kirin/decl/__init__.py +108 -0
  15. kirin/decl/base.py +65 -0
  16. kirin/decl/camel2snake.py +2 -0
  17. kirin/decl/emit/__init__.py +0 -0
  18. kirin/decl/emit/_create_fn.py +29 -0
  19. kirin/decl/emit/_set_new_attribute.py +22 -0
  20. kirin/decl/emit/dialect.py +8 -0
  21. kirin/decl/emit/init.py +277 -0
  22. kirin/decl/emit/name.py +10 -0
  23. kirin/decl/emit/property.py +182 -0
  24. kirin/decl/emit/repr.py +31 -0
  25. kirin/decl/emit/traits.py +13 -0
  26. kirin/decl/emit/typecheck.py +77 -0
  27. kirin/decl/emit/verify.py +51 -0
  28. kirin/decl/info.py +346 -0
  29. kirin/decl/scan_fields.py +157 -0
  30. kirin/decl/verify.py +69 -0
  31. kirin/dialects/__init__.py +14 -0
  32. kirin/dialects/_pprint_helper.py +53 -0
  33. kirin/dialects/cf/__init__.py +20 -0
  34. kirin/dialects/cf/constprop.py +51 -0
  35. kirin/dialects/cf/dialect.py +3 -0
  36. kirin/dialects/cf/emit.py +58 -0
  37. kirin/dialects/cf/interp.py +24 -0
  38. kirin/dialects/cf/stmts.py +68 -0
  39. kirin/dialects/cf/typeinfer.py +27 -0
  40. kirin/dialects/eltype.py +23 -0
  41. kirin/dialects/func/__init__.py +20 -0
  42. kirin/dialects/func/attrs.py +39 -0
  43. kirin/dialects/func/constprop.py +138 -0
  44. kirin/dialects/func/dialect.py +3 -0
  45. kirin/dialects/func/emit.py +80 -0
  46. kirin/dialects/func/interp.py +68 -0
  47. kirin/dialects/func/stmts.py +233 -0
  48. kirin/dialects/func/typeinfer.py +124 -0
  49. kirin/dialects/ilist/__init__.py +33 -0
  50. kirin/dialects/ilist/_dialect.py +3 -0
  51. kirin/dialects/ilist/_wrapper.py +51 -0
  52. kirin/dialects/ilist/interp.py +85 -0
  53. kirin/dialects/ilist/lowering.py +25 -0
  54. kirin/dialects/ilist/passes.py +32 -0
  55. kirin/dialects/ilist/rewrite/__init__.py +3 -0
  56. kirin/dialects/ilist/rewrite/const.py +45 -0
  57. kirin/dialects/ilist/rewrite/list.py +38 -0
  58. kirin/dialects/ilist/rewrite/unroll.py +131 -0
  59. kirin/dialects/ilist/runtime.py +63 -0
  60. kirin/dialects/ilist/stmts.py +102 -0
  61. kirin/dialects/ilist/typeinfer.py +120 -0
  62. kirin/dialects/lowering/__init__.py +7 -0
  63. kirin/dialects/lowering/call.py +48 -0
  64. kirin/dialects/lowering/cf.py +206 -0
  65. kirin/dialects/lowering/func.py +134 -0
  66. kirin/dialects/math/__init__.py +41 -0
  67. kirin/dialects/math/_gen.py +176 -0
  68. kirin/dialects/math/dialect.py +3 -0
  69. kirin/dialects/math/interp.py +190 -0
  70. kirin/dialects/math/stmts.py +369 -0
  71. kirin/dialects/module.py +139 -0
  72. kirin/dialects/py/__init__.py +40 -0
  73. kirin/dialects/py/assertion.py +91 -0
  74. kirin/dialects/py/assign.py +103 -0
  75. kirin/dialects/py/attr.py +59 -0
  76. kirin/dialects/py/base.py +34 -0
  77. kirin/dialects/py/binop/__init__.py +23 -0
  78. kirin/dialects/py/binop/_dialect.py +3 -0
  79. kirin/dialects/py/binop/interp.py +60 -0
  80. kirin/dialects/py/binop/julia.py +33 -0
  81. kirin/dialects/py/binop/lowering.py +22 -0
  82. kirin/dialects/py/binop/stmts.py +79 -0
  83. kirin/dialects/py/binop/typeinfer.py +108 -0
  84. kirin/dialects/py/boolop.py +84 -0
  85. kirin/dialects/py/builtin.py +78 -0
  86. kirin/dialects/py/cmp/__init__.py +16 -0
  87. kirin/dialects/py/cmp/_dialect.py +3 -0
  88. kirin/dialects/py/cmp/interp.py +48 -0
  89. kirin/dialects/py/cmp/julia.py +33 -0
  90. kirin/dialects/py/cmp/lowering.py +45 -0
  91. kirin/dialects/py/cmp/stmts.py +62 -0
  92. kirin/dialects/py/constant.py +79 -0
  93. kirin/dialects/py/indexing.py +251 -0
  94. kirin/dialects/py/iterable.py +90 -0
  95. kirin/dialects/py/len.py +57 -0
  96. kirin/dialects/py/list/__init__.py +15 -0
  97. kirin/dialects/py/list/_dialect.py +3 -0
  98. kirin/dialects/py/list/interp.py +21 -0
  99. kirin/dialects/py/list/lowering.py +25 -0
  100. kirin/dialects/py/list/stmts.py +22 -0
  101. kirin/dialects/py/list/typeinfer.py +54 -0
  102. kirin/dialects/py/range.py +76 -0
  103. kirin/dialects/py/slice.py +120 -0
  104. kirin/dialects/py/tuple.py +109 -0
  105. kirin/dialects/py/unary/__init__.py +24 -0
  106. kirin/dialects/py/unary/_dialect.py +3 -0
  107. kirin/dialects/py/unary/constprop.py +20 -0
  108. kirin/dialects/py/unary/interp.py +24 -0
  109. kirin/dialects/py/unary/julia.py +21 -0
  110. kirin/dialects/py/unary/lowering.py +22 -0
  111. kirin/dialects/py/unary/stmts.py +33 -0
  112. kirin/dialects/py/unary/typeinfer.py +23 -0
  113. kirin/dialects/py/unpack.py +90 -0
  114. kirin/dialects/scf/__init__.py +23 -0
  115. kirin/dialects/scf/_dialect.py +3 -0
  116. kirin/dialects/scf/absint.py +64 -0
  117. kirin/dialects/scf/constprop.py +140 -0
  118. kirin/dialects/scf/interp.py +35 -0
  119. kirin/dialects/scf/lowering.py +123 -0
  120. kirin/dialects/scf/stmts.py +250 -0
  121. kirin/dialects/scf/trim.py +36 -0
  122. kirin/dialects/scf/typeinfer.py +58 -0
  123. kirin/dialects/scf/unroll.py +92 -0
  124. kirin/emit/__init__.py +3 -0
  125. kirin/emit/abc.py +89 -0
  126. kirin/emit/abc.pyi +38 -0
  127. kirin/emit/exceptions.py +5 -0
  128. kirin/emit/julia.py +63 -0
  129. kirin/emit/str.py +51 -0
  130. kirin/exceptions.py +59 -0
  131. kirin/graph.py +34 -0
  132. kirin/idtable.py +57 -0
  133. kirin/interp/__init__.py +39 -0
  134. kirin/interp/abstract.py +253 -0
  135. kirin/interp/base.py +438 -0
  136. kirin/interp/concrete.py +62 -0
  137. kirin/interp/exceptions.py +26 -0
  138. kirin/interp/frame.py +151 -0
  139. kirin/interp/impl.py +197 -0
  140. kirin/interp/result.py +93 -0
  141. kirin/interp/state.py +71 -0
  142. kirin/interp/table.py +40 -0
  143. kirin/interp/value.py +73 -0
  144. kirin/ir/__init__.py +46 -0
  145. kirin/ir/attrs/__init__.py +20 -0
  146. kirin/ir/attrs/_types.py +8 -0
  147. kirin/ir/attrs/_types.pyi +13 -0
  148. kirin/ir/attrs/abc.py +46 -0
  149. kirin/ir/attrs/py.py +45 -0
  150. kirin/ir/attrs/types.py +522 -0
  151. kirin/ir/dialect.py +125 -0
  152. kirin/ir/group.py +249 -0
  153. kirin/ir/method.py +118 -0
  154. kirin/ir/nodes/__init__.py +7 -0
  155. kirin/ir/nodes/base.py +149 -0
  156. kirin/ir/nodes/block.py +458 -0
  157. kirin/ir/nodes/region.py +337 -0
  158. kirin/ir/nodes/stmt.py +713 -0
  159. kirin/ir/nodes/view.py +142 -0
  160. kirin/ir/ssa.py +204 -0
  161. kirin/ir/traits/__init__.py +36 -0
  162. kirin/ir/traits/abc.py +42 -0
  163. kirin/ir/traits/basic.py +78 -0
  164. kirin/ir/traits/callable.py +51 -0
  165. kirin/ir/traits/lowering/__init__.py +2 -0
  166. kirin/ir/traits/lowering/call.py +37 -0
  167. kirin/ir/traits/lowering/context.py +120 -0
  168. kirin/ir/traits/region/__init__.py +2 -0
  169. kirin/ir/traits/region/ssacfg.py +22 -0
  170. kirin/ir/traits/symbol.py +57 -0
  171. kirin/ir/use.py +17 -0
  172. kirin/lattice/__init__.py +13 -0
  173. kirin/lattice/abc.py +128 -0
  174. kirin/lattice/empty.py +25 -0
  175. kirin/lattice/mixin.py +51 -0
  176. kirin/lowering/__init__.py +7 -0
  177. kirin/lowering/binding.py +65 -0
  178. kirin/lowering/core.py +72 -0
  179. kirin/lowering/dialect.py +35 -0
  180. kirin/lowering/dialect.pyi +183 -0
  181. kirin/lowering/frame.py +171 -0
  182. kirin/lowering/result.py +68 -0
  183. kirin/lowering/state.py +441 -0
  184. kirin/lowering/stream.py +53 -0
  185. kirin/passes/__init__.py +3 -0
  186. kirin/passes/abc.py +44 -0
  187. kirin/passes/aggressive/__init__.py +1 -0
  188. kirin/passes/aggressive/fold.py +43 -0
  189. kirin/passes/fold.py +45 -0
  190. kirin/passes/inline.py +25 -0
  191. kirin/passes/typeinfer.py +25 -0
  192. kirin/prelude.py +197 -0
  193. kirin/print/__init__.py +15 -0
  194. kirin/print/printable.py +141 -0
  195. kirin/print/printer.py +415 -0
  196. kirin/py.typed +0 -0
  197. kirin/registry.py +105 -0
  198. kirin/registry.pyi +52 -0
  199. kirin/rewrite/__init__.py +14 -0
  200. kirin/rewrite/abc.py +43 -0
  201. kirin/rewrite/aggressive/__init__.py +1 -0
  202. kirin/rewrite/aggressive/fold.py +43 -0
  203. kirin/rewrite/alias.py +16 -0
  204. kirin/rewrite/apply_type.py +47 -0
  205. kirin/rewrite/call2invoke.py +34 -0
  206. kirin/rewrite/chain.py +39 -0
  207. kirin/rewrite/compactify.py +288 -0
  208. kirin/rewrite/cse.py +48 -0
  209. kirin/rewrite/dce.py +19 -0
  210. kirin/rewrite/fixpoint.py +34 -0
  211. kirin/rewrite/fold.py +57 -0
  212. kirin/rewrite/getfield.py +21 -0
  213. kirin/rewrite/getitem.py +37 -0
  214. kirin/rewrite/inline.py +143 -0
  215. kirin/rewrite/result.py +15 -0
  216. kirin/rewrite/walk.py +83 -0
  217. kirin/rewrite/wrap_const.py +55 -0
  218. kirin/source.py +21 -0
  219. kirin/symbol_table.py +27 -0
  220. kirin/types.py +34 -0
  221. kirin/worklist.py +30 -0
  222. kirin_toolchain-0.13.0.dist-info/METADATA +42 -0
  223. kirin_toolchain-0.13.0.dist-info/RECORD +225 -0
  224. kirin_toolchain-0.13.0.dist-info/WHEEL +4 -0
  225. kirin_toolchain-0.13.0.dist-info/licenses/LICENSE +234 -0
kirin/ir/nodes/stmt.py ADDED
@@ -0,0 +1,713 @@
1
+ from __future__ import annotations
2
+
3
+ from typing import TYPE_CHECKING, Mapping, TypeVar, ClassVar, Iterator, Sequence
4
+ from dataclasses import field, dataclass
5
+
6
+ from typing_extensions import Self
7
+
8
+ from kirin.print import Printer, Printable
9
+ from kirin.ir.ssa import SSAValue, ResultValue
10
+ from kirin.ir.use import Use
11
+ from kirin.ir.traits import StmtTrait
12
+ from kirin.ir.attrs.abc import Attribute
13
+ from kirin.ir.nodes.base import IRNode
14
+ from kirin.ir.nodes.view import MutableSequenceView
15
+ from kirin.ir.nodes.block import Block
16
+ from kirin.ir.nodes.region import Region
17
+
18
+ if TYPE_CHECKING:
19
+ from kirin.source import SourceInfo
20
+ from kirin.ir.dialect import Dialect
21
+ from kirin.ir.attrs.types import TypeAttribute
22
+ from kirin.ir.nodes.block import Block
23
+ from kirin.ir.nodes.region import Region
24
+
25
+
26
@dataclass
class ArgumentList(
    MutableSequenceView[tuple[SSAValue, ...], "Statement", SSAValue], Printable
):
    """A View object that contains the arguments of a Statement.

    Description:
        This is a proxy object that provides an API to manipulate the
        arguments of a statement. Every argument is registered on its
        `SSAValue` as `Use(node, position)`, so any operation that shifts
        arguments (delete, insert) re-registers the uses of all shifted
        arguments to keep the recorded positions in sync with the tuple.

    !!! note "Pretty Printing"
        This object is pretty printable via
        [`.print()`][kirin.print.printable.Printable.print] method.
    """

    def _rebind(self, new_args: tuple[SSAValue, ...], start: int) -> None:
        """Install `new_args`, refreshing the uses of every position >= `start`.

        Args:
            new_args (tuple[SSAValue, ...]): The complete new argument tuple.
            start (int): First position whose content differs from the old tuple.
        """
        node = self.node
        # Drop the uses recorded for the old layout ...
        for pos, arg in enumerate(self.field[start:], start):
            arg.remove_use(Use(node, pos))
        # ... and record them again with the positions of the new layout.
        for pos, arg in enumerate(new_args[start:], start):
            arg.add_use(Use(node, pos))
        node._args = new_args
        self.field = new_args

    def __delitem__(self, idx: int) -> None:
        """Remove the argument at `idx` (negative indices count from the end).

        Raises:
            IndexError: If `idx` is out of range.
        """
        args = self.field
        # range(...)[idx] normalizes negative indices and raises IndexError
        # for out-of-range ones, like tuple indexing does.
        pos = range(len(args))[idx]
        self._rebind((*args[:pos], *args[pos + 1 :]), pos)

    def set_item(self, idx: int, value: SSAValue) -> None:
        """Set the argument SSAValue at the specified index.

        Args:
            idx (int): The index of the item to set (negative indices allowed).
            value (SSAValue): The value to set.

        Raises:
            IndexError: If `idx` is out of range.
        """
        args = self.field
        pos = range(len(args))[idx]
        args[pos].remove_use(Use(self.node, pos))
        value.add_use(Use(self.node, pos))
        new_args = (*args[:pos], value, *args[pos + 1 :])
        self.node._args = new_args
        self.field = new_args

    def insert(self, idx: int, value: SSAValue) -> None:
        """Insert the argument SSAValue at the specified index.

        Args:
            idx (int): The index to insert the value. Out-of-range indices
                are clamped the same way slicing clamps them.
            value (SSAValue): The value to insert.
        """
        args = self.field
        # Length of the prefix slice is the effective insertion position.
        pos = len(args[:idx])
        self._rebind((*args[:pos], value, *args[pos:]), pos)

    def get_slice(self, name: str) -> slice:
        """Get the slice of the arguments registered under `name`.

        Args:
            name (str): The name of the slice.

        Returns:
            slice: The slice of the arguments; a single integer position is
                widened to a one-element slice.
        """
        index = self.node._name_args_slice[name]
        if isinstance(index, int):
            return slice(index, index + 1)
        return index

    def print_impl(self, printer: Printer) -> None:
        """Print the arguments as a bracketed, comma separated sequence."""
        printer.print_seq(self.field, delim=", ", prefix="[", suffix="]")
90
+
91
+
92
@dataclass
class ResultList(MutableSequenceView[list[ResultValue], "Statement", ResultValue]):
    """A View object that contains the list of ResultValue of a Statement.

    Description:
        This is a proxy object over the result values of a statement.
        Results cannot be assigned through the view; removing one also
        calls `delete()` on the removed `ResultValue`.
    """

    def __setitem__(
        self, idx: int | slice, value: ResultValue | Sequence[ResultValue]
    ) -> None:
        """Always rejected: result values cannot be assigned directly.

        Raises:
            NotImplementedError: Unconditionally.
        """
        raise NotImplementedError("Cannot set result value directly")

    def __delitem__(self, idx: int) -> None:
        """Remove the result at `idx` from the list, then delete it."""
        doomed = self.field[idx]
        del self.field[idx]
        doomed.delete()

    @property
    def types(self) -> Sequence[TypeAttribute]:
        """Get the result types of the Statement.

        Returns:
            Sequence[TypeAttribute]: type of each result value, in order.
        """
        collected = []
        for value in self.field:
            collected.append(value.type)
        return collected
122
+
123
+
124
+ @dataclass(repr=False)
125
+ class Statement(IRNode["Block"]):
126
+ """The Statement is an instruction in the IR
127
+
128
+ !!! note "Pretty Printing"
129
+ This object is pretty printable via
130
+ [`.print()`][kirin.print.printable.Printable.print] method.
131
+ """
132
+
133
+ name: ClassVar[str]
134
+ dialect: ClassVar[Dialect | None] = field(default=None, init=False, repr=False)
135
+ traits: ClassVar[frozenset[StmtTrait]]
136
+ _arg_groups: ClassVar[frozenset[str]] = frozenset()
137
+
138
+ _args: tuple[SSAValue, ...] = field(init=False)
139
+ _results: list[ResultValue] = field(init=False, default_factory=list)
140
+ successors: list[Block] = field(init=False)
141
+ _regions: list[Region] = field(init=False)
142
+ attributes: dict[str, Attribute] = field(init=False)
143
+
144
+ parent: Block | None = field(default=None, init=False, repr=False)
145
+ _next_stmt: Statement | None = field(default=None, init=False, repr=False)
146
+ _prev_stmt: Statement | None = field(default=None, init=False, repr=False)
147
+
148
+ # NOTE: This is only for syntax sugar to provide
149
+ # access to args via the properties
150
+ _name_args_slice: dict[str, int | slice] = field(
151
+ init=False, repr=False, default_factory=dict
152
+ )
153
+ source: SourceInfo | None = field(default=None, init=False, repr=False)
154
+
155
@property
def parent_stmt(self) -> Statement | None:
    """Get the parent statement.

    Returns:
        Statement | None: The `parent_stmt` of this statement's parent
            node, or None when the parent node is unset (or falsy).
    """
    block = self.parent_node
    return block.parent_stmt if block else None
165
+
166
@property
def parent_node(self) -> Block | None:
    """Get the parent node.

    Returns:
        Block | None: The value stored in `self.parent`.
    """
    return self.parent


@parent_node.setter
def parent_node(self, parent: Block | None) -> None:
    """Set the parent Block.

    The new value is passed to `self.assert_parent(Block, parent)` before
    it is stored in `self.parent`.
    """
    # Function-scope import, kept exactly as in the original.
    from kirin.ir.nodes.block import Block

    self.assert_parent(Block, parent)
    self.parent = parent
182
+
183
@property
def parent_region(self) -> Region | None:
    """Get the parent Region.

    Returns:
        Region | None: The `parent_node` of this statement's parent block,
            or None when the statement has no parent block.
    """
    block = self.parent_node
    if block is None:
        return None
    return block.parent_node
192
+
193
@property
def parent_block(self) -> Block | None:
    """Get the parent Block.

    Returns:
        Block | None: Same value as `parent_node`.
    """
    block = self.parent_node
    return block
201
+
202
@property
def next_stmt(self) -> Statement | None:
    """Get the next statement (the value of `_next_stmt`)."""
    return self._next_stmt


@next_stmt.setter
def next_stmt(self, stmt: Statement) -> None:
    """Reject direct assignment of the next statement.

    Note:
        Do not directly call this API. Use `stmt.insert_after(self)` instead.

    Raises:
        ValueError: Always.
    """
    message = "Cannot set next_stmt directly, use stmt.insert_after(self) or stmt.insert_before(self)"
    raise ValueError(message)
218
+
219
@property
def prev_stmt(self) -> Statement | None:
    """Get the previous statement (the value of `_prev_stmt`)."""
    return self._prev_stmt


@prev_stmt.setter
def prev_stmt(self, stmt: Statement) -> None:
    """Reject direct assignment of the previous statement.

    Note:
        Do not directly call this API. Use `stmt.insert_before(self)` instead.

    Raises:
        ValueError: Always.
    """
    message = "Cannot set prev_stmt directly, use stmt.insert_after(self) or stmt.insert_before(self)"
    raise ValueError(message)
235
+
236
def insert_after(self, stmt: Statement) -> None:
    """Insert the current Statement after the input Statement.

    Args:
        stmt (Statement): Input Statement.

    Raises:
        ValueError: If the current Statement is already linked to a
            neighbouring statement (i.e. it is still part of a block).

    Example:
        The following example demonstrates how to insert a Statement after another Statement.
        After `insert_after` is called, `stmt1` will be inserted after `stmt2`, which appears in IR in the order (stmt2 -> stmt1)
        ```python
        stmt1 = Statement()
        stmt2 = Statement()
        stmt1.insert_after(stmt2)
        ```
    """
    # A statement that is linked on either side is still in a block; linking
    # it again would overwrite its pointers and leave the old neighbours and
    # the old block's bookkeeping dangling. Checking with `and` only caught
    # statements in the middle of a block, never the first or last one.
    if self._next_stmt is not None or self._prev_stmt is not None:
        raise ValueError(
            f"Cannot insert after a statement that is already in a block: {self.name}"
        )

    if stmt._next_stmt is not None:
        stmt._next_stmt._prev_stmt = self

    self._prev_stmt = stmt
    self._next_stmt = stmt._next_stmt

    self.parent = stmt.parent
    stmt._next_stmt = self

    if self.parent:
        self.parent._stmt_len += 1

        # Inserted after the old tail: this statement is the new last one.
        if self._next_stmt is None:
            self.parent._last_stmt = self
270
+
271
def insert_before(self, stmt: Statement) -> None:
    """Insert the current Statement before the input Statement.

    Args:
        stmt (Statement): Input Statement.

    Raises:
        ValueError: If the current Statement is already linked to a
            neighbouring statement (i.e. it is still part of a block).

    Example:
        The following example demonstrates how to insert a Statement before another Statement.
        After `insert_before` is called, `stmt1` will be inserted before `stmt2`, which appears in IR in the order (stmt1 -> stmt2)
        ```python
        stmt1 = Statement()
        stmt2 = Statement()
        stmt1.insert_before(stmt2)
        ```
    """
    # A statement that is linked on either side is still in a block; linking
    # it again would overwrite its pointers and leave the old neighbours and
    # the old block's bookkeeping dangling. Checking with `and` only caught
    # statements in the middle of a block, never the first or last one.
    if self._next_stmt is not None or self._prev_stmt is not None:
        raise ValueError(
            f"Cannot insert before a statement that is already in a block: {self.name}"
        )

    if stmt._prev_stmt is not None:
        stmt._prev_stmt._next_stmt = self

    self._next_stmt = stmt
    self._prev_stmt = stmt._prev_stmt

    self.parent = stmt.parent
    stmt._prev_stmt = self

    if self.parent:
        self.parent._stmt_len += 1

        # Inserted before the old head: this statement is the new first one.
        if self._prev_stmt is None:
            self.parent._first_stmt = self
305
+
306
def replace_by(self, stmt: Statement) -> None:
    """Replace the current Statement by the input Statement.

    The input statement is inserted before this one, each of this
    statement's results is replaced by the result at the same position of
    the input statement (carrying over a non-empty name), and finally this
    statement is deleted.

    Args:
        stmt (Statement): Input Statement.
    """
    stmt.insert_before(self)
    for replacement, stale in zip(stmt._results, self._results):
        stale.replace_by(replacement)
        carried_name = stale.name
        if carried_name:
            replacement.name = carried_name
    self.delete()
318
+
319
@property
def args(self) -> ArgumentList:
    """Get the arguments of the Statement.

    Returns:
        ArgumentList: A new arguments View over the current argument tuple.
    """
    return ArgumentList(self, self._args)


@args.setter
def args(self, args: Sequence[SSAValue]) -> None:
    """Set the arguments of the Statement.

    The use of every current argument is removed and a use is added for
    every new argument, each recorded with its position.

    Args:
        args (Sequence[SSAValue]): The arguments to set.
    """
    replacement = tuple(args)
    for position, current in enumerate(self._args):
        current.remove_use(Use(self, position))
    for position, incoming in enumerate(replacement):
        incoming.add_use(Use(self, position))
    self._args = replacement
341
+
342
@property
def results(self) -> ResultList:
    """Get the result values of the Statement.

    Returns:
        ResultList: A new result values View over the statement's result list.
    """
    view = ResultList(self, self._results)
    return view
350
+
351
@property
def regions(self) -> list[Region]:
    """Get a list of regions of the Statement.

    Returns:
        list[Region]: The list of regions of the Statement (the stored
            list itself, not a copy).
    """
    return self._regions


@regions.setter
def regions(self, regions: list[Region]) -> None:
    """Set the regions of the Statement.

    The `_parent` of every current region is cleared first, then every new
    region gets this statement as its `_parent`.
    """
    for outgoing in self._regions:
        outgoing._parent = None
    for incoming in regions:
        incoming._parent = self
    self._regions = regions
368
+
369
def drop_all_references(self) -> None:
    """Remove the references this Statement holds into the IR.

    Clears `parent`, removes the use this statement registered on each of
    its arguments, and calls `drop_all_references()` on every region.
    """
    self.parent = None
    for position, operand in enumerate(self._args):
        operand.remove_use(Use(self, position))
    for nested in self._regions:
        nested.drop_all_references()
376
+
377
def delete(self, safe: bool = True) -> None:
    """Delete the Statement completely from the IR graph.

    Note:
        This method detaches the Statement (`detach`), removes its
        references (`drop_all_references`) and then deletes every result.

    Args:
        safe (bool, optional): Forwarded as `safe=` to the `delete` of
            every result value. Defaults to True.
    """
    self.detach()
    self.drop_all_references()
    for produced in self._results:
        produced.delete(safe=safe)
390
+
391
def detach(self) -> None:
    """Detach the statement from its parent block.

    Does nothing when the statement has no parent. Otherwise the
    neighbouring statements are linked to each other, the block's
    first/last pointers are updated when this statement was at either end,
    this statement's own links and parent are cleared, and the block's
    statement count is decremented.
    """
    block = self.parent
    if block is None:
        return

    before = self.prev_stmt
    after = self.next_stmt

    if before is None:
        # No predecessor: this statement must be the head of the block.
        assert (
            block._first_stmt is self
        ), "Invalid statement, has no prev_stmt but not first_stmt"
        block._first_stmt = after
    else:
        before._next_stmt = after
        self._prev_stmt = None

    if after is None:
        # No successor: this statement must be the tail of the block.
        assert (
            block._last_stmt is self
        ), "Invalid statement, has no next_stmt but not last_stmt"
        block._last_stmt = before
    else:
        after._prev_stmt = before
        self._next_stmt = None

    self.parent = None
    block._stmt_len -= 1
421
+
422
def __post_init__(self):
    """Sanity-check the statement name: it must be a non-empty string."""
    stmt_name = self.name
    assert stmt_name != ""
    assert isinstance(stmt_name, str)
425
+
426
def __init__(
    self,
    *,
    args: Sequence[SSAValue] = (),
    regions: Sequence[Region] = (),
    successors: Sequence[Block] = (),
    attributes: Mapping[str, Attribute] = {},
    results: Sequence[ResultValue] = (),
    result_types: Sequence[TypeAttribute] = (),
    args_slice: Mapping[str, int | slice] = {},
    source: SourceInfo | None = None,
) -> None:
    """Initialize the Statement.

    Args:
        args (Sequence[SSAValue], optional): The arguments of the Statement. Defaults to ().
        regions (Sequence[Region], optional): The regions owned by the Statement; each gets this Statement as its parent. Defaults to ().
        successors (Sequence[Block], optional): The successors of the Statement. Defaults to ().
        attributes (Mapping[str, Attribute], optional): The attributes of the Statement. Defaults to {}.
        results (Sequence[ResultValue], optional): The result values of the Statement. Defaults to ().
        result_types (Sequence[TypeAttribute], optional): The result types of the Statement; a fresh ResultValue is created for each. Defaults to ().
        args_slice (Mapping[str, int | slice], optional): The arguments slice of the Statement. Defaults to {}.
        source (SourceInfo | None, optional): The source information of the Statement for debugging/stacktracing. Defaults to None.

    Raises:
        AssertionError: If both `results` and `result_types` are given.
    """
    # NOTE: the `{}` defaults are only ever copied (dict(...)), never mutated.
    super().__init__()
    # `_args`/`_regions` must exist before the `args`/`regions` setters run,
    # because the setters first walk the previous values.
    self._args = ()
    self._regions = []
    self._name_args_slice = dict(args_slice)
    self.source = source
    self.args = args

    assert not (
        results and result_types
    ), "expect either results or result_types specified, got both"
    if result_types:
        self._results = [
            ResultValue(self, idx, type=result_type)
            for idx, result_type in enumerate(result_types)
        ]
    else:
        self._results = list(results)

    self.successors = list(successors)
    self.attributes = dict(attributes)
    self.regions = list(regions)

    self.parent = None
    self._next_stmt = None
    self._prev_stmt = None
    self.__post_init__()
481
+
482
@classmethod
def from_stmt(
    cls,
    other: Statement,
    args: Sequence[SSAValue] | None = None,
    regions: list[Region] | None = None,
    successors: list[Block] | None = None,
    attributes: dict[str, Attribute] | None = None,
) -> Self:
    """Create a similar Statement with new `ResultValue` and without
    attaching to any parent block.

    Every component that is not overridden is taken from `other`: the new
    statement references the same successors and regions (the regions are
    handed to `Statement.__init__`, which assigns them via the `regions`
    setter). Result values are created fresh from the result types of `other`.

    Args:
        other (Statement): The statement to copy from.
        args (Sequence[SSAValue] | None, optional): Replacement arguments; None keeps those of `other`.
        regions (list[Region] | None, optional): Replacement regions; None keeps those of `other`.
        successors (list[Block] | None, optional): Replacement successors; None keeps those of `other`.
        attributes (dict[str, Attribute] | None, optional): Replacement attributes; None keeps those of `other`.

    Returns:
        Self: The new, unattached statement.
    """
    obj = cls.__new__(cls)
    # Only None means "not overridden": an explicitly passed empty
    # sequence/dict is a valid override and must not fall back to `other`.
    Statement.__init__(
        obj,
        args=other._args if args is None else args,
        regions=other._regions if regions is None else regions,
        successors=other.successors if successors is None else successors,
        attributes=other.attributes if attributes is None else attributes,
        result_types=[result.type for result in other._results],
        args_slice=other._name_args_slice,
    )
    return obj
506
+
507
def walk(
    self,
    *,
    reverse: bool = False,
    region_first: bool = False,
    include_self: bool = True,
) -> Iterator[Statement]:
    """Traverse the Statements of the Regions of this Statement.

    Args:
        reverse (bool, optional): If walk in the reversed manner. Defaults to False.
        region_first (bool, optional): If True the statement itself is yielded after its regions instead of before them. Defaults to False.
        include_self (bool, optional): If the walk should include the Statement itself. Defaults to True.

    Yields:
        Iterator[Statement]: The statements produced by `walk` of each region, in the specified order, plus this statement when requested.
    """
    self_leads = include_self and not region_first
    self_trails = include_self and region_first

    if self_leads:
        yield self

    ordered_regions = reversed(self.regions) if reverse else self.regions
    for nested in ordered_regions:
        yield from nested.walk(reverse=reverse, region_first=region_first)

    if self_trails:
        yield self
532
+
533
def is_structurally_equal(
    self,
    other: Self,
    context: dict[IRNode | SSAValue, IRNode | SSAValue] | None = None,
) -> bool:
    """Check if the Statement is structurally equal to another Statement.

    Compared, in this order: the name; the number of arguments, regions and
    successors; the attributes; the parents (only when both are set, looked
    up through `context`); then arguments, successors and regions pairwise.
    On success the results of `self` are recorded in `context` as
    equivalent to the results of `other`.

    Args:
        other (Self): The other Statement to compare with.
        context (dict[IRNode | SSAValue, IRNode | SSAValue] | None, optional): A map of IRNode/SSAValue to hint that they are equivalent so the check will treat them as equivalent. Defaults to None.

    Returns:
        bool: True if the Statement is structurally equal to the other.
    """
    mapping = {} if context is None else context

    if self.name != other.name:
        return False

    if len(self.args) != len(other.args):
        return False
    if len(self.regions) != len(other.regions):
        return False
    if len(self.successors) != len(other.successors):
        return False
    if self.attributes != other.attributes:
        return False

    both_attached = self.parent is not None and other.parent is not None
    if both_attached and mapping.get(self.parent) != other.parent:
        return False

    for mine, theirs in zip(self.args, other.args):
        if not mapping.get(mine, mine) == theirs:
            return False

    for mine, theirs in zip(self.successors, other.successors):
        if not mapping.get(mine, mine) == theirs:
            return False

    for mine, theirs in zip(self.regions, other.regions):
        if not mine.is_structurally_equal(theirs, mapping):
            return False

    # Record the result correspondence for later comparisons.
    for mine, theirs in zip(self._results, other._results):
        mapping[mine] = theirs

    return True
590
+
591
def __hash__(self) -> int:
    """Hash by object identity."""
    identity = id(self)
    return identity
593
+
594
def print_impl(self, printer: Printer) -> None:
    """Print the statement.

    Output order: the name, the parenthesized arguments (grouped by the
    named argument slices when there are any), the successors in brackets,
    the regions, the attributes in braces, and the result types after " : ".
    """
    from kirin.decl import fields as stmt_fields

    printer.print_name(self)
    printer.plain_print("(")
    for idx, (name, s) in enumerate(self._name_args_slice.items()):
        values = self.args[s]
        fields = stmt_fields(self)
        name_hidden = fields and not fields.args[name].print
        # The `name=` label is skipped for fields declared with print=False;
        # the values themselves are printed either way.
        if not name_hidden:
            with printer.rich(style="orange4"):
                printer.plain_print(name, "=")

        if isinstance(values, SSAValue):
            printer.print(values)
        else:
            printer.print_seq(values, delim=", ")

        is_last = not idx < len(self._name_args_slice) - 1
        if not is_last:
            printer.plain_print(", ")

    # NOTE: args are specified manually without names
    if not self._name_args_slice and self._args:
        printer.print_seq(self._args, delim=", ")

    printer.plain_print(")")

    if self.successors:
        successor_ids = (
            printer.state.block_id[successor] for successor in self.successors
        )
        printer.print_seq(
            successor_ids,
            emit=printer.plain_print,
            delim=", ",
            prefix="[",
            suffix="]",
        )

    if self.regions:
        printer.print_seq(self.regions, delim=" ", prefix=" (", suffix=")")

    if self.attributes:
        printer.plain_print("{")
        with printer.rich(highlight=True):
            printer.print_mapping(self.attributes, delim=", ")
        printer.plain_print("}")

    if self._results:
        with printer.rich(style="black"):
            printer.plain_print(" : ")
            result_types = [result.type for result in self._results]
            printer.print_seq(result_types, delim=", ")
651
+
652
def get_attr_or_prop(self, key: str) -> Attribute | None:
    """Get the attribute or property of the Statement.

    Args:
        key (str): The key of the attribute or property.

    Returns:
        Attribute | None: The entry of `self.attributes` stored under
            `key`, or None when there is no such entry.
    """
    found = self.attributes.get(key)
    return found
662
+
663
@classmethod
def has_trait(cls, trait_type: type[StmtTrait]) -> bool:
    """Check if the Statement has a specific trait.

    Args:
        trait_type (type[StmtTrait]): The type of trait to check for.

    Returns:
        bool: True if any entry of `cls.traits` is an instance of `trait_type`, False otherwise.
    """
    return any(isinstance(candidate, trait_type) for candidate in cls.traits)
677
+
678
TraitType = TypeVar("TraitType", bound=StmtTrait)

@classmethod
def get_trait(cls, trait: type[TraitType]) -> TraitType | None:
    """Get the trait of the Statement.

    Args:
        trait (type[TraitType]): The type of trait to look for.

    Returns:
        TraitType | None: The first entry of `cls.traits` that is an instance of `trait`, or None if there is none.
    """
    matches = (candidate for candidate in cls.traits if isinstance(candidate, trait))
    return next(matches, None)
687
+
688
def expect_one_result(self) -> ResultValue:
    """Check that the statement has exactly one result, and return it.

    Raises:
        ValueError: If the number of results is not 1.
    """
    count = len(self._results)
    if count == 1:
        return self._results[0]
    raise ValueError(f"expected one result, got {count}")
693
+
694
# NOTE: a statement should implement typecheck.
# This is done automatically via @statement, but
# when a manual implementation is needed,
# it should be provided by overriding this method.
# NOTE: not an @abstractmethod to make linter happy
def typecheck(self) -> None:
    """Check the type of the statement.

    Note:
        1. A Statement should implement typecheck.
            This is done automatically via @statement, but
            when a manual implementation is needed,
            it should be provided by overriding this method.
        2. This API should be called after all the types are figured out (by typeinfer)

    Raises:
        NotImplementedError: Always, in this base implementation.
    """
    raise NotImplementedError
711
+
712
def verify(self) -> None:
    """Verify the statement; this base implementation checks nothing."""
    return None