kirin-toolchain 0.13.0__py3-none-any.whl

Sign up to get free protection for your applications and to get access to all the features.
Files changed (225) hide show
  1. kirin/__init__.py +7 -0
  2. kirin/analysis/__init__.py +24 -0
  3. kirin/analysis/callgraph.py +61 -0
  4. kirin/analysis/cfg.py +112 -0
  5. kirin/analysis/const/__init__.py +20 -0
  6. kirin/analysis/const/_visitor.py +2 -0
  7. kirin/analysis/const/_visitor.pyi +8 -0
  8. kirin/analysis/const/lattice.py +219 -0
  9. kirin/analysis/const/prop.py +116 -0
  10. kirin/analysis/forward.py +100 -0
  11. kirin/analysis/typeinfer/__init__.py +5 -0
  12. kirin/analysis/typeinfer/analysis.py +90 -0
  13. kirin/analysis/typeinfer/solve.py +141 -0
  14. kirin/decl/__init__.py +108 -0
  15. kirin/decl/base.py +65 -0
  16. kirin/decl/camel2snake.py +2 -0
  17. kirin/decl/emit/__init__.py +0 -0
  18. kirin/decl/emit/_create_fn.py +29 -0
  19. kirin/decl/emit/_set_new_attribute.py +22 -0
  20. kirin/decl/emit/dialect.py +8 -0
  21. kirin/decl/emit/init.py +277 -0
  22. kirin/decl/emit/name.py +10 -0
  23. kirin/decl/emit/property.py +182 -0
  24. kirin/decl/emit/repr.py +31 -0
  25. kirin/decl/emit/traits.py +13 -0
  26. kirin/decl/emit/typecheck.py +77 -0
  27. kirin/decl/emit/verify.py +51 -0
  28. kirin/decl/info.py +346 -0
  29. kirin/decl/scan_fields.py +157 -0
  30. kirin/decl/verify.py +69 -0
  31. kirin/dialects/__init__.py +14 -0
  32. kirin/dialects/_pprint_helper.py +53 -0
  33. kirin/dialects/cf/__init__.py +20 -0
  34. kirin/dialects/cf/constprop.py +51 -0
  35. kirin/dialects/cf/dialect.py +3 -0
  36. kirin/dialects/cf/emit.py +58 -0
  37. kirin/dialects/cf/interp.py +24 -0
  38. kirin/dialects/cf/stmts.py +68 -0
  39. kirin/dialects/cf/typeinfer.py +27 -0
  40. kirin/dialects/eltype.py +23 -0
  41. kirin/dialects/func/__init__.py +20 -0
  42. kirin/dialects/func/attrs.py +39 -0
  43. kirin/dialects/func/constprop.py +138 -0
  44. kirin/dialects/func/dialect.py +3 -0
  45. kirin/dialects/func/emit.py +80 -0
  46. kirin/dialects/func/interp.py +68 -0
  47. kirin/dialects/func/stmts.py +233 -0
  48. kirin/dialects/func/typeinfer.py +124 -0
  49. kirin/dialects/ilist/__init__.py +33 -0
  50. kirin/dialects/ilist/_dialect.py +3 -0
  51. kirin/dialects/ilist/_wrapper.py +51 -0
  52. kirin/dialects/ilist/interp.py +85 -0
  53. kirin/dialects/ilist/lowering.py +25 -0
  54. kirin/dialects/ilist/passes.py +32 -0
  55. kirin/dialects/ilist/rewrite/__init__.py +3 -0
  56. kirin/dialects/ilist/rewrite/const.py +45 -0
  57. kirin/dialects/ilist/rewrite/list.py +38 -0
  58. kirin/dialects/ilist/rewrite/unroll.py +131 -0
  59. kirin/dialects/ilist/runtime.py +63 -0
  60. kirin/dialects/ilist/stmts.py +102 -0
  61. kirin/dialects/ilist/typeinfer.py +120 -0
  62. kirin/dialects/lowering/__init__.py +7 -0
  63. kirin/dialects/lowering/call.py +48 -0
  64. kirin/dialects/lowering/cf.py +206 -0
  65. kirin/dialects/lowering/func.py +134 -0
  66. kirin/dialects/math/__init__.py +41 -0
  67. kirin/dialects/math/_gen.py +176 -0
  68. kirin/dialects/math/dialect.py +3 -0
  69. kirin/dialects/math/interp.py +190 -0
  70. kirin/dialects/math/stmts.py +369 -0
  71. kirin/dialects/module.py +139 -0
  72. kirin/dialects/py/__init__.py +40 -0
  73. kirin/dialects/py/assertion.py +91 -0
  74. kirin/dialects/py/assign.py +103 -0
  75. kirin/dialects/py/attr.py +59 -0
  76. kirin/dialects/py/base.py +34 -0
  77. kirin/dialects/py/binop/__init__.py +23 -0
  78. kirin/dialects/py/binop/_dialect.py +3 -0
  79. kirin/dialects/py/binop/interp.py +60 -0
  80. kirin/dialects/py/binop/julia.py +33 -0
  81. kirin/dialects/py/binop/lowering.py +22 -0
  82. kirin/dialects/py/binop/stmts.py +79 -0
  83. kirin/dialects/py/binop/typeinfer.py +108 -0
  84. kirin/dialects/py/boolop.py +84 -0
  85. kirin/dialects/py/builtin.py +78 -0
  86. kirin/dialects/py/cmp/__init__.py +16 -0
  87. kirin/dialects/py/cmp/_dialect.py +3 -0
  88. kirin/dialects/py/cmp/interp.py +48 -0
  89. kirin/dialects/py/cmp/julia.py +33 -0
  90. kirin/dialects/py/cmp/lowering.py +45 -0
  91. kirin/dialects/py/cmp/stmts.py +62 -0
  92. kirin/dialects/py/constant.py +79 -0
  93. kirin/dialects/py/indexing.py +251 -0
  94. kirin/dialects/py/iterable.py +90 -0
  95. kirin/dialects/py/len.py +57 -0
  96. kirin/dialects/py/list/__init__.py +15 -0
  97. kirin/dialects/py/list/_dialect.py +3 -0
  98. kirin/dialects/py/list/interp.py +21 -0
  99. kirin/dialects/py/list/lowering.py +25 -0
  100. kirin/dialects/py/list/stmts.py +22 -0
  101. kirin/dialects/py/list/typeinfer.py +54 -0
  102. kirin/dialects/py/range.py +76 -0
  103. kirin/dialects/py/slice.py +120 -0
  104. kirin/dialects/py/tuple.py +109 -0
  105. kirin/dialects/py/unary/__init__.py +24 -0
  106. kirin/dialects/py/unary/_dialect.py +3 -0
  107. kirin/dialects/py/unary/constprop.py +20 -0
  108. kirin/dialects/py/unary/interp.py +24 -0
  109. kirin/dialects/py/unary/julia.py +21 -0
  110. kirin/dialects/py/unary/lowering.py +22 -0
  111. kirin/dialects/py/unary/stmts.py +33 -0
  112. kirin/dialects/py/unary/typeinfer.py +23 -0
  113. kirin/dialects/py/unpack.py +90 -0
  114. kirin/dialects/scf/__init__.py +23 -0
  115. kirin/dialects/scf/_dialect.py +3 -0
  116. kirin/dialects/scf/absint.py +64 -0
  117. kirin/dialects/scf/constprop.py +140 -0
  118. kirin/dialects/scf/interp.py +35 -0
  119. kirin/dialects/scf/lowering.py +123 -0
  120. kirin/dialects/scf/stmts.py +250 -0
  121. kirin/dialects/scf/trim.py +36 -0
  122. kirin/dialects/scf/typeinfer.py +58 -0
  123. kirin/dialects/scf/unroll.py +92 -0
  124. kirin/emit/__init__.py +3 -0
  125. kirin/emit/abc.py +89 -0
  126. kirin/emit/abc.pyi +38 -0
  127. kirin/emit/exceptions.py +5 -0
  128. kirin/emit/julia.py +63 -0
  129. kirin/emit/str.py +51 -0
  130. kirin/exceptions.py +59 -0
  131. kirin/graph.py +34 -0
  132. kirin/idtable.py +57 -0
  133. kirin/interp/__init__.py +39 -0
  134. kirin/interp/abstract.py +253 -0
  135. kirin/interp/base.py +438 -0
  136. kirin/interp/concrete.py +62 -0
  137. kirin/interp/exceptions.py +26 -0
  138. kirin/interp/frame.py +151 -0
  139. kirin/interp/impl.py +197 -0
  140. kirin/interp/result.py +93 -0
  141. kirin/interp/state.py +71 -0
  142. kirin/interp/table.py +40 -0
  143. kirin/interp/value.py +73 -0
  144. kirin/ir/__init__.py +46 -0
  145. kirin/ir/attrs/__init__.py +20 -0
  146. kirin/ir/attrs/_types.py +8 -0
  147. kirin/ir/attrs/_types.pyi +13 -0
  148. kirin/ir/attrs/abc.py +46 -0
  149. kirin/ir/attrs/py.py +45 -0
  150. kirin/ir/attrs/types.py +522 -0
  151. kirin/ir/dialect.py +125 -0
  152. kirin/ir/group.py +249 -0
  153. kirin/ir/method.py +118 -0
  154. kirin/ir/nodes/__init__.py +7 -0
  155. kirin/ir/nodes/base.py +149 -0
  156. kirin/ir/nodes/block.py +458 -0
  157. kirin/ir/nodes/region.py +337 -0
  158. kirin/ir/nodes/stmt.py +713 -0
  159. kirin/ir/nodes/view.py +142 -0
  160. kirin/ir/ssa.py +204 -0
  161. kirin/ir/traits/__init__.py +36 -0
  162. kirin/ir/traits/abc.py +42 -0
  163. kirin/ir/traits/basic.py +78 -0
  164. kirin/ir/traits/callable.py +51 -0
  165. kirin/ir/traits/lowering/__init__.py +2 -0
  166. kirin/ir/traits/lowering/call.py +37 -0
  167. kirin/ir/traits/lowering/context.py +120 -0
  168. kirin/ir/traits/region/__init__.py +2 -0
  169. kirin/ir/traits/region/ssacfg.py +22 -0
  170. kirin/ir/traits/symbol.py +57 -0
  171. kirin/ir/use.py +17 -0
  172. kirin/lattice/__init__.py +13 -0
  173. kirin/lattice/abc.py +128 -0
  174. kirin/lattice/empty.py +25 -0
  175. kirin/lattice/mixin.py +51 -0
  176. kirin/lowering/__init__.py +7 -0
  177. kirin/lowering/binding.py +65 -0
  178. kirin/lowering/core.py +72 -0
  179. kirin/lowering/dialect.py +35 -0
  180. kirin/lowering/dialect.pyi +183 -0
  181. kirin/lowering/frame.py +171 -0
  182. kirin/lowering/result.py +68 -0
  183. kirin/lowering/state.py +441 -0
  184. kirin/lowering/stream.py +53 -0
  185. kirin/passes/__init__.py +3 -0
  186. kirin/passes/abc.py +44 -0
  187. kirin/passes/aggressive/__init__.py +1 -0
  188. kirin/passes/aggressive/fold.py +43 -0
  189. kirin/passes/fold.py +45 -0
  190. kirin/passes/inline.py +25 -0
  191. kirin/passes/typeinfer.py +25 -0
  192. kirin/prelude.py +197 -0
  193. kirin/print/__init__.py +15 -0
  194. kirin/print/printable.py +141 -0
  195. kirin/print/printer.py +415 -0
  196. kirin/py.typed +0 -0
  197. kirin/registry.py +105 -0
  198. kirin/registry.pyi +52 -0
  199. kirin/rewrite/__init__.py +14 -0
  200. kirin/rewrite/abc.py +43 -0
  201. kirin/rewrite/aggressive/__init__.py +1 -0
  202. kirin/rewrite/aggressive/fold.py +43 -0
  203. kirin/rewrite/alias.py +16 -0
  204. kirin/rewrite/apply_type.py +47 -0
  205. kirin/rewrite/call2invoke.py +34 -0
  206. kirin/rewrite/chain.py +39 -0
  207. kirin/rewrite/compactify.py +288 -0
  208. kirin/rewrite/cse.py +48 -0
  209. kirin/rewrite/dce.py +19 -0
  210. kirin/rewrite/fixpoint.py +34 -0
  211. kirin/rewrite/fold.py +57 -0
  212. kirin/rewrite/getfield.py +21 -0
  213. kirin/rewrite/getitem.py +37 -0
  214. kirin/rewrite/inline.py +143 -0
  215. kirin/rewrite/result.py +15 -0
  216. kirin/rewrite/walk.py +83 -0
  217. kirin/rewrite/wrap_const.py +55 -0
  218. kirin/source.py +21 -0
  219. kirin/symbol_table.py +27 -0
  220. kirin/types.py +34 -0
  221. kirin/worklist.py +30 -0
  222. kirin_toolchain-0.13.0.dist-info/METADATA +42 -0
  223. kirin_toolchain-0.13.0.dist-info/RECORD +225 -0
  224. kirin_toolchain-0.13.0.dist-info/WHEEL +4 -0
  225. kirin_toolchain-0.13.0.dist-info/licenses/LICENSE +234 -0
kirin/ir/group.py ADDED
@@ -0,0 +1,249 @@
1
+ import inspect
2
+ from types import ModuleType
3
+ from typing import (
4
+ TYPE_CHECKING,
5
+ Union,
6
+ Generic,
7
+ TypeVar,
8
+ Callable,
9
+ ParamSpec,
10
+ Concatenate,
11
+ overload,
12
+ )
13
+ from functools import update_wrapper
14
+ from dataclasses import dataclass
15
+ from collections.abc import Iterable
16
+
17
+ from kirin.ir.method import Method
18
+ from kirin.exceptions import CompilerError
19
+
20
+ if TYPE_CHECKING:
21
+ from kirin.registry import Registry
22
+ from kirin.ir.dialect import Dialect
23
+
24
# Extra positional/keyword parameters accepted by a group's `run_pass` after
# the `Method` argument; `DialectGroup.__call__` forwards its `*args` and
# `**options` through these.
PassParams = ParamSpec("PassParams")
# A pass pipeline: called with the freshly compiled `Method` followed by the
# `PassParams` arguments; returns None.
RunPass = Callable[Concatenate[Method, PassParams], None]
# Factory that receives the owning `DialectGroup` and returns its `RunPass`.
RunPassGen = Callable[["DialectGroup"], RunPass[PassParams]]
27
+
28
+
29
@dataclass(init=False)
class DialectGroup(Generic[PassParams]):
    """A set of dialects together with an optional pass pipeline.

    Calling a group on a Python function lowers the function to a
    [`Method`][kirin.ir.Method] using the dialects in the group, then runs
    `run_pass` on the result when one is configured.
    """

    # method wrapper params
    Param = ParamSpec("Param")
    RetType = TypeVar("RetType")
    MethodTransform = Callable[[Callable[Param, RetType]], Method[Param, RetType]]

    data: frozenset["Dialect"]
    """The set of dialects in the group."""
    # NOTE: this is used to create new dialect groups from existing one
    run_pass_gen: RunPassGen[PassParams] | None = None
    """the function that generates the `run_pass` function.

    This is used to create new dialect groups from existing ones, while
    keeping the same `run_pass` function.
    """
    run_pass: RunPass[PassParams] | None = None
    """the function that runs the passes on the method."""

    def __init__(
        self,
        dialects: Iterable[Union["Dialect", ModuleType]],
        run_pass: RunPassGen[PassParams] | None = None,
    ):
        """Create a dialect group.

        Args:
            dialects: the dialects, or modules exposing a `dialect` attribute,
                to include in the group.
            run_pass: optional generator that receives this group and returns
                the `run_pass` function applied to every compiled method.
        """
        self.data = frozenset(self.map_module(dialect) for dialect in dialects)
        # Keep the generator (not only its product) so that `union`/`discard`
        # can rebuild a `run_pass` bound to the derived group.
        self.run_pass_gen = run_pass
        self.run_pass = None if run_pass is None else run_pass(self)

    def __iter__(self):
        return iter(self.data)

    def __repr__(self) -> str:
        names = ", ".join(each.name for each in self.data)
        return f"DialectGroup([{names}])"

    @staticmethod
    def map_module(dialect: Union["Dialect", ModuleType]) -> "Dialect":
        """map the module to the dialect if it is a module.
        It assumes that the module has a `dialect` attribute
        that is an instance of [`Dialect`][kirin.ir.Dialect].
        """
        if isinstance(dialect, ModuleType):
            return getattr(dialect, "dialect")
        return dialect

    def add(self, dialect: Union["Dialect", ModuleType]) -> "DialectGroup":
        """add a dialect to the group.

        Args:
            dialect (Union[Dialect, ModuleType]): the dialect to add

        Returns:
            DialectGroup: the new dialect group with the added dialect.
        """
        return self.union([dialect])

    def union(self, dialect: Iterable[Union["Dialect", ModuleType]]) -> "DialectGroup":
        """union a set of dialects to the group.

        Args:
            dialect (Iterable[Union[Dialect, ModuleType]]): the dialects to union

        Returns:
            DialectGroup: the new dialect group with the union.
        """
        return DialectGroup(
            dialects=self.data.union(frozenset(self.map_module(d) for d in dialect)),
            run_pass=self.run_pass_gen,  # pass the run_pass_gen function
        )

    def discard(self, dialect: Union["Dialect", ModuleType]) -> "DialectGroup":
        """discard a dialect from the group.

        !!! note
            This does not raise an error if the dialect is not in the group.
            Dialects are matched by `name`.

        Args:
            dialect (Union[Dialect, ModuleType]): the dialect to discard

        Returns:
            DialectGroup: the new dialect group with the discarded dialect.
        """
        dialect_ = self.map_module(dialect)
        return DialectGroup(
            dialects=frozenset(
                each for each in self.data if each.name != dialect_.name
            ),
            run_pass=self.run_pass_gen,  # pass the run_pass_gen function
        )

    @property
    def registry(self) -> "Registry":
        """return the registry for the dialect group. This
        returns a proxy object that can be used to select
        the lowering interpreters, interpreters, and codegen
        for the dialects in the group.

        Returns:
            Registry: the registry object.
        """
        from kirin.registry import Registry

        return Registry(self)

    @overload
    def __call__(
        self,
        py_func: Callable[Param, RetType],
        *args: PassParams.args,
        **options: PassParams.kwargs,
    ) -> Method[Param, RetType]: ...

    @overload
    def __call__(
        self,
        py_func: None = None,
        *args: PassParams.args,
        **options: PassParams.kwargs,
    ) -> MethodTransform[Param, RetType]: ...

    def __call__(
        self,
        py_func: Callable[Param, RetType] | None = None,
        *args: PassParams.args,
        **options: PassParams.kwargs,
    ) -> Method[Param, RetType] | MethodTransform[Param, RetType]:
        """create a method from the python function.

        Supports both `@group` / `group(func)` and `@group(*args, **options)`.

        Args:
            py_func (Callable): the python function to create the method from.
                When None, a decorator is returned instead.
            args (PassParams.args): the arguments to pass to the run_pass function.
            options (PassParams.kwargs): the keyword arguments to pass to the run_pass function.

        Returns:
            Method: the method created from the python function, or a
                decorator producing it when `py_func` is None.

        Raises:
            ValueError: if the function is a lambda.
            CompilerError: if the call site already has a local variable with
                the same name as the function being compiled.
        """
        from kirin.lowering import Lowering

        emit_ir = Lowering(self)

        def compile_method(fn: Callable, call_site_frame) -> Method:
            # `call_site_frame` is the user frame that applied the decorator
            # (or called the group directly), or None if it is unavailable.
            if fn.__name__ == "<lambda>":
                raise ValueError("Cannot compile lambda functions")

            lineno_offset, file = 0, ""
            if call_site_frame is not None:
                # Decorators run before the name is bound in the caller, so an
                # existing local with this name means a redefinition.
                if fn.__name__ in call_site_frame.f_locals:
                    raise CompilerError(
                        f"overwriting function definition of `{fn.__name__}`"
                    )

                lineno_offset = call_site_frame.f_lineno - 1
                file = call_site_frame.f_code.co_filename

            code = emit_ir.run(fn, lineno_offset=lineno_offset)
            mt = Method(
                mod=inspect.getmodule(fn),
                py_func=fn,
                sym_name=fn.__name__,
                arg_names=["#self#"] + inspect.getfullargspec(fn).args,
                dialects=self,
                code=code,
                file=file,
            )
            if doc := inspect.getdoc(fn):
                mt.__doc__ = doc

            if self.run_pass is not None:
                self.run_pass(mt, *args, **options)
            return mt

        def wrapper(py_func: Callable) -> Method:
            # `@group(...)` form: the user's frame calls this wrapper directly,
            # so the call site is exactly one frame up from here.
            frame = inspect.currentframe()
            try:
                return compile_method(
                    py_func, frame.f_back if frame is not None else None
                )
            finally:
                # break the frame reference cycle deterministically
                del frame

        if py_func is not None:
            # `@group` / `group(func)` form: the call site is the caller of
            # `__call__` itself.
            frame = inspect.currentframe()
            try:
                return compile_method(
                    py_func, frame.f_back if frame is not None else None
                )
            finally:
                del frame
        return wrapper
212
+
213
+
214
def dialect_group(
    dialects: Iterable[Union["Dialect", ModuleType]]
) -> Callable[[RunPassGen[PassParams]], DialectGroup[PassParams]]:
    """Create a dialect group from the given dialects based on the
    definition of `run_pass` function.

    Args:
        dialects (Iterable[Union[Dialect, ModuleType]]): the dialects to include in the group.

    Returns:
        Callable[[RunPassGen[PassParams]], DialectGroup[PassParams]]: a decorator
            that turns a `run_pass` generator into a `DialectGroup`. The group
            takes over the generator's metadata (name, docstring, ...) via
            `functools.update_wrapper`.

    Example:
        ```python
        from kirin.ir import Method
        from kirin.dialects import cf, func, math

        @dialect_group([cf, func, math])
        def basic_no_opt(self):
            # initializations
            def run_pass(mt: Method) -> None:
                # how passes are applied to the method
                pass

            return run_pass
        ```
    """
    # Materialize once so the decorator still sees every dialect even when a
    # one-shot iterator (e.g. a generator) is passed and it is applied twice.
    dialects = tuple(dialects)

    # NOTE: do not alias the annotation below
    def wrapper(
        transform: RunPassGen[PassParams],
    ) -> DialectGroup[PassParams]:
        ret = DialectGroup(dialects, run_pass=transform)
        update_wrapper(ret, transform)
        return ret

    return wrapper
kirin/ir/method.py ADDED
@@ -0,0 +1,118 @@
1
+ import typing
2
+ from types import ModuleType
3
+
4
+ # from typing import TYPE_CHECKING, Generic, TypeVar, Callable, ParamSpec
5
+ from dataclasses import field, dataclass
6
+
7
+ from kirin.ir.traits import HasSignature, CallableStmtInterface
8
+ from kirin.exceptions import VerificationError
9
+ from kirin.ir.nodes.stmt import Statement
10
+ from kirin.print.printer import Printer
11
+ from kirin.ir.attrs.types import Generic
12
+ from kirin.print.printable import Printable
13
+
14
+ if typing.TYPE_CHECKING:
15
+ from kirin.ir.group import DialectGroup
16
+
17
# Type parameters of `Method`: the parameter spec and return type of the
# wrapped Python function (see the `py_func` field and `Method.__call__`).
Param = typing.ParamSpec("Param")
RetType = typing.TypeVar("RetType")
19
+
20
+
21
@dataclass
class Method(Printable, typing.Generic[Param, RetType]):
    """A compiled method: a Python function paired with its Kirin IR.

    Calling a `Method` runs its IR with the concrete interpreter built from the
    dialect group it was compiled with.
    """

    mod: ModuleType | None  # ref
    py_func: typing.Callable[Param, RetType] | None  # ref
    sym_name: str
    arg_names: list[str]
    dialects: "DialectGroup"  # own
    code: Statement  # own, the corresponding IR, a func.func usually
    # values contained if closure
    fields: tuple = field(default_factory=tuple)  # own
    file: str = ""
    inferred: bool = False
    """if typeinfer has been run on this method
    """
    verified: bool = False
    """if `code.verify` has been run on this method
    """

    def __hash__(self) -> int:
        return id(self)

    def __call__(self, *args: Param.args, **kwargs: Param.kwargs) -> RetType:
        """Run the method with the concrete interpreter.

        `arg_names` includes the leading `#self#` entry, so exactly
        `len(arg_names) - 1` arguments (positional plus keyword) are expected.

        Raises:
            ValueError: if the number of arguments does not match.
        """
        from kirin.interp.concrete import Interpreter

        if len(args) + len(kwargs) != len(self.arg_names) - 1:
            raise ValueError("Incorrect number of arguments")
        # NOTE: multi-return values will be wrapped in a tuple for Python
        interp = Interpreter(self.dialects)
        return interp.run(self, args=args, kwargs=kwargs).expect()

    @property
    def args(self):
        """Return the arguments of the method. (excluding self)"""
        return tuple(self.callable_region.blocks[0].args[1:])

    @property
    def arg_types(self):
        """Return the types of the arguments of the method. (excluding self)"""
        return tuple(arg.type for arg in self.args)

    def _signature(self):
        """Return the signature of the method body via its `HasSignature` trait.

        Raises:
            ValueError: if the body does not implement `HasSignature`.
        """
        trait = self.code.get_trait(HasSignature)
        if trait is None:
            raise ValueError("Method body must implement HasSignature")
        return trait.get_signature(self.code)

    @property
    def self_type(self):
        """Return the type of the self argument of the method."""
        signature = self._signature()
        return Generic(Method, Generic(tuple, *signature.inputs), signature.output)

    @property
    def callable_region(self):
        """Return the region holding the method body.

        Raises:
            ValueError: if the body does not implement `CallableStmtInterface`.
        """
        trait = self.code.get_trait(CallableStmtInterface)
        if trait is None:
            raise ValueError("Method body must implement CallableStmtInterface")
        return trait.get_callable_region(self.code)

    @property
    def return_type(self):
        """Return the declared output type from the body's signature."""
        return self._signature().output

    def __repr__(self) -> str:
        return f'Method("{self.sym_name}")'

    def print_impl(self, printer: Printer) -> None:
        return printer.print(self.code)

    def verify(self) -> None:
        """verify the method body.

        Sets `verified` to True on success.

        Raises:
            Exception: wrapping the `VerificationError`, with the file, line
                and statement name (when available) prefixed to the message.
        """
        try:
            self.code.verify()
        except VerificationError as e:
            msg = f'File "{self.file}"'
            if isinstance(e.node, Statement):
                if e.node.source:
                    msg += f", line {e.node.source.lineno}"
                msg += f", in {e.node.name}"

            msg += f":\n Verification failed for {self.sym_name}: {e.args[0]}"
            raise Exception(msg) from e
        self.verified = True

    def similar(self, dialects: typing.Optional["DialectGroup"] = None):
        """Return a copy of this method whose body is rebuilt from a clone of
        the callable region.

        Args:
            dialects: dialect group for the copy; defaults to this method's.
        """
        return Method(
            self.mod,
            self.py_func,
            self.sym_name,
            # copy so the two methods do not share one mutable list
            list(self.arg_names),
            dialects if dialects is not None else self.dialects,
            self.code.from_stmt(self.code, regions=[self.callable_region.clone()]),
            self.fields,
            self.file,
            self.inferred,
            self.verified,
        )
kirin/ir/nodes/__init__.py ADDED
@@ -0,0 +1,7 @@
1
+ """Definition of Kirin's Intermediate Representation (IR) nodes.
2
+ """
3
+
4
+ from kirin.ir.nodes.base import IRNode as IRNode
5
+ from kirin.ir.nodes.stmt import Statement as Statement
6
+ from kirin.ir.nodes.block import Block as Block
7
+ from kirin.ir.nodes.region import Region as Region
kirin/ir/nodes/base.py ADDED
@@ -0,0 +1,149 @@
1
+ from __future__ import annotations
2
+
3
+ from abc import ABC, abstractmethod
4
+ from typing import TYPE_CHECKING, Generic, TypeVar, Iterator
5
+ from dataclasses import dataclass
6
+
7
+ from typing_extensions import Self
8
+
9
+ from kirin.print import Printer, Printable
10
+ from kirin.ir.ssa import SSAValue
11
+
12
+ if TYPE_CHECKING:
13
+ from kirin.ir.nodes.stmt import Statement
14
+
15
+
16
# Type of a node's parent; bound to `IRNode`, so parents are themselves IR nodes.
ParentType = TypeVar("ParentType", bound="IRNode")
17
+
18
+
19
@dataclass
class IRNode(Generic[ParentType], ABC, Printable):
    """Base class for all IR nodes. All IR nodes are hashable and can be compared
    for equality. The hash of an IR node is the same as the id of the object.

    !!! note "Pretty Printing"
        This object is pretty printable via
        [`.print()`][kirin.print.printable.Printable.print] method.
    """

    def assert_parent(self, type_: type[IRNode], parent) -> None:
        """Assert that `parent` is an instance of `type_` or None."""
        assert (
            isinstance(parent, type_) or parent is None
        ), f"Invalid parent, expect {type_} or None, got {type(parent)}"

    @property
    @abstractmethod
    def parent_node(self) -> ParentType | None:
        """Parent node of the current node."""
        ...

    @parent_node.setter
    @abstractmethod
    def parent_node(self, parent: ParentType | None) -> None: ...

    def is_ancestor(self, op: IRNode) -> bool:
        """Check if the current node is an ancestor of the given node.

        A node counts as its own ancestor, so `node.is_ancestor(node)` is True.

        Args:
            op: the node whose chain of parents is searched for `self`.

        Returns:
            True if `self` is `op` or one of `op`'s (transitive) parents.
        """
        node: IRNode | None = op
        while node is not None:
            if node is self:
                return True
            node = node.parent_node
        return False

    def get_root(self) -> IRNode:
        """Get the root node of the current node."""
        if (parent := self.parent_node) is None:
            return self
        return parent.get_root()

    def is_equal(
        self,
        other: IRNode,
        context: dict[IRNode | SSAValue, IRNode | SSAValue] | None = None,
    ) -> bool:
        """Check if the current node is equal to the other node.

        Args:
            other: The other node to compare.
            context: The context to store the visited nodes. A new empty
                dict is created for each call when omitted.

        Returns:
            True if the nodes are equal, False otherwise.

        !!! note
            This method is not the same as the `==` operator. It checks for
            structural equality rather than identity. To change the behavior
            of structural equality, override the `is_structurally_equal` method.
        """
        if not isinstance(other, type(self)):
            return False
        # Fresh dict per call: a shared default dict would carry visited-node
        # entries from one comparison into every later one.
        if context is None:
            context = {}
        return self.is_structurally_equal(other, context)

    def attach(self, parent: ParentType) -> None:
        """Attach the current node to the parent node.

        Raises:
            ValueError: if the node already has a parent, or if the node is an
                ancestor of `parent` (attaching would create a cycle).
        """
        assert isinstance(parent, IRNode), f"Expected IRNode, got {type(parent)}"

        # Explicit None check: a parent's truthiness may be customized by a
        # subclass (e.g. via `__len__`) and must not affect this test.
        if self.parent_node is not None:
            raise ValueError("Node already has a parent")
        if self.is_ancestor(parent):
            raise ValueError("Node is an ancestor of the parent")
        self.parent_node = parent

    @abstractmethod
    def detach(self) -> None:
        """Detach the current node from the parent node."""
        ...

    @abstractmethod
    def drop_all_references(self) -> None:
        """Drop all references to other nodes."""
        ...

    @abstractmethod
    def delete(self, safe: bool = True) -> None:
        """Delete the current node.

        Args:
            safe: If True, check if the node has any references before deleting.
        """
        ...

    @abstractmethod
    def is_structurally_equal(
        self,
        other: Self,
        context: dict[IRNode | SSAValue, IRNode | SSAValue] | None = None,
    ) -> bool:
        """Check if the current node is structurally equal to the other node.

        !!! note
            This method is for tweaking the behavior of structural equality.
            To check if two nodes are structurally equal, use the `is_equal` method.

        Args:
            other: The other node to compare.
            context: The context to store the visited nodes.

        Returns:
            True if the nodes are structurally equal, False otherwise.
        """
        ...

    def __eq__(self, other) -> bool:
        # identity semantics; structural comparison lives in `is_equal`
        return self is other

    def __hash__(self) -> int:
        return id(self)

    @abstractmethod
    def walk(
        self, *, reverse: bool = False, region_first: bool = False
    ) -> Iterator[Statement]: ...

    @abstractmethod
    def print_impl(self, printer: Printer) -> None: ...

    @abstractmethod
    def typecheck(self) -> None:
        """check if types are correct."""
        ...

    @abstractmethod
    def verify(self) -> None:
        """run mandatory validation checks. This is not same as typecheck, which may be optional."""
        ...