kirin-toolchain 0.13.0__py3-none-any.whl
Sign up to get free protection for your applications and to get access to all the features.
- kirin/__init__.py +7 -0
- kirin/analysis/__init__.py +24 -0
- kirin/analysis/callgraph.py +61 -0
- kirin/analysis/cfg.py +112 -0
- kirin/analysis/const/__init__.py +20 -0
- kirin/analysis/const/_visitor.py +2 -0
- kirin/analysis/const/_visitor.pyi +8 -0
- kirin/analysis/const/lattice.py +219 -0
- kirin/analysis/const/prop.py +116 -0
- kirin/analysis/forward.py +100 -0
- kirin/analysis/typeinfer/__init__.py +5 -0
- kirin/analysis/typeinfer/analysis.py +90 -0
- kirin/analysis/typeinfer/solve.py +141 -0
- kirin/decl/__init__.py +108 -0
- kirin/decl/base.py +65 -0
- kirin/decl/camel2snake.py +2 -0
- kirin/decl/emit/__init__.py +0 -0
- kirin/decl/emit/_create_fn.py +29 -0
- kirin/decl/emit/_set_new_attribute.py +22 -0
- kirin/decl/emit/dialect.py +8 -0
- kirin/decl/emit/init.py +277 -0
- kirin/decl/emit/name.py +10 -0
- kirin/decl/emit/property.py +182 -0
- kirin/decl/emit/repr.py +31 -0
- kirin/decl/emit/traits.py +13 -0
- kirin/decl/emit/typecheck.py +77 -0
- kirin/decl/emit/verify.py +51 -0
- kirin/decl/info.py +346 -0
- kirin/decl/scan_fields.py +157 -0
- kirin/decl/verify.py +69 -0
- kirin/dialects/__init__.py +14 -0
- kirin/dialects/_pprint_helper.py +53 -0
- kirin/dialects/cf/__init__.py +20 -0
- kirin/dialects/cf/constprop.py +51 -0
- kirin/dialects/cf/dialect.py +3 -0
- kirin/dialects/cf/emit.py +58 -0
- kirin/dialects/cf/interp.py +24 -0
- kirin/dialects/cf/stmts.py +68 -0
- kirin/dialects/cf/typeinfer.py +27 -0
- kirin/dialects/eltype.py +23 -0
- kirin/dialects/func/__init__.py +20 -0
- kirin/dialects/func/attrs.py +39 -0
- kirin/dialects/func/constprop.py +138 -0
- kirin/dialects/func/dialect.py +3 -0
- kirin/dialects/func/emit.py +80 -0
- kirin/dialects/func/interp.py +68 -0
- kirin/dialects/func/stmts.py +233 -0
- kirin/dialects/func/typeinfer.py +124 -0
- kirin/dialects/ilist/__init__.py +33 -0
- kirin/dialects/ilist/_dialect.py +3 -0
- kirin/dialects/ilist/_wrapper.py +51 -0
- kirin/dialects/ilist/interp.py +85 -0
- kirin/dialects/ilist/lowering.py +25 -0
- kirin/dialects/ilist/passes.py +32 -0
- kirin/dialects/ilist/rewrite/__init__.py +3 -0
- kirin/dialects/ilist/rewrite/const.py +45 -0
- kirin/dialects/ilist/rewrite/list.py +38 -0
- kirin/dialects/ilist/rewrite/unroll.py +131 -0
- kirin/dialects/ilist/runtime.py +63 -0
- kirin/dialects/ilist/stmts.py +102 -0
- kirin/dialects/ilist/typeinfer.py +120 -0
- kirin/dialects/lowering/__init__.py +7 -0
- kirin/dialects/lowering/call.py +48 -0
- kirin/dialects/lowering/cf.py +206 -0
- kirin/dialects/lowering/func.py +134 -0
- kirin/dialects/math/__init__.py +41 -0
- kirin/dialects/math/_gen.py +176 -0
- kirin/dialects/math/dialect.py +3 -0
- kirin/dialects/math/interp.py +190 -0
- kirin/dialects/math/stmts.py +369 -0
- kirin/dialects/module.py +139 -0
- kirin/dialects/py/__init__.py +40 -0
- kirin/dialects/py/assertion.py +91 -0
- kirin/dialects/py/assign.py +103 -0
- kirin/dialects/py/attr.py +59 -0
- kirin/dialects/py/base.py +34 -0
- kirin/dialects/py/binop/__init__.py +23 -0
- kirin/dialects/py/binop/_dialect.py +3 -0
- kirin/dialects/py/binop/interp.py +60 -0
- kirin/dialects/py/binop/julia.py +33 -0
- kirin/dialects/py/binop/lowering.py +22 -0
- kirin/dialects/py/binop/stmts.py +79 -0
- kirin/dialects/py/binop/typeinfer.py +108 -0
- kirin/dialects/py/boolop.py +84 -0
- kirin/dialects/py/builtin.py +78 -0
- kirin/dialects/py/cmp/__init__.py +16 -0
- kirin/dialects/py/cmp/_dialect.py +3 -0
- kirin/dialects/py/cmp/interp.py +48 -0
- kirin/dialects/py/cmp/julia.py +33 -0
- kirin/dialects/py/cmp/lowering.py +45 -0
- kirin/dialects/py/cmp/stmts.py +62 -0
- kirin/dialects/py/constant.py +79 -0
- kirin/dialects/py/indexing.py +251 -0
- kirin/dialects/py/iterable.py +90 -0
- kirin/dialects/py/len.py +57 -0
- kirin/dialects/py/list/__init__.py +15 -0
- kirin/dialects/py/list/_dialect.py +3 -0
- kirin/dialects/py/list/interp.py +21 -0
- kirin/dialects/py/list/lowering.py +25 -0
- kirin/dialects/py/list/stmts.py +22 -0
- kirin/dialects/py/list/typeinfer.py +54 -0
- kirin/dialects/py/range.py +76 -0
- kirin/dialects/py/slice.py +120 -0
- kirin/dialects/py/tuple.py +109 -0
- kirin/dialects/py/unary/__init__.py +24 -0
- kirin/dialects/py/unary/_dialect.py +3 -0
- kirin/dialects/py/unary/constprop.py +20 -0
- kirin/dialects/py/unary/interp.py +24 -0
- kirin/dialects/py/unary/julia.py +21 -0
- kirin/dialects/py/unary/lowering.py +22 -0
- kirin/dialects/py/unary/stmts.py +33 -0
- kirin/dialects/py/unary/typeinfer.py +23 -0
- kirin/dialects/py/unpack.py +90 -0
- kirin/dialects/scf/__init__.py +23 -0
- kirin/dialects/scf/_dialect.py +3 -0
- kirin/dialects/scf/absint.py +64 -0
- kirin/dialects/scf/constprop.py +140 -0
- kirin/dialects/scf/interp.py +35 -0
- kirin/dialects/scf/lowering.py +123 -0
- kirin/dialects/scf/stmts.py +250 -0
- kirin/dialects/scf/trim.py +36 -0
- kirin/dialects/scf/typeinfer.py +58 -0
- kirin/dialects/scf/unroll.py +92 -0
- kirin/emit/__init__.py +3 -0
- kirin/emit/abc.py +89 -0
- kirin/emit/abc.pyi +38 -0
- kirin/emit/exceptions.py +5 -0
- kirin/emit/julia.py +63 -0
- kirin/emit/str.py +51 -0
- kirin/exceptions.py +59 -0
- kirin/graph.py +34 -0
- kirin/idtable.py +57 -0
- kirin/interp/__init__.py +39 -0
- kirin/interp/abstract.py +253 -0
- kirin/interp/base.py +438 -0
- kirin/interp/concrete.py +62 -0
- kirin/interp/exceptions.py +26 -0
- kirin/interp/frame.py +151 -0
- kirin/interp/impl.py +197 -0
- kirin/interp/result.py +93 -0
- kirin/interp/state.py +71 -0
- kirin/interp/table.py +40 -0
- kirin/interp/value.py +73 -0
- kirin/ir/__init__.py +46 -0
- kirin/ir/attrs/__init__.py +20 -0
- kirin/ir/attrs/_types.py +8 -0
- kirin/ir/attrs/_types.pyi +13 -0
- kirin/ir/attrs/abc.py +46 -0
- kirin/ir/attrs/py.py +45 -0
- kirin/ir/attrs/types.py +522 -0
- kirin/ir/dialect.py +125 -0
- kirin/ir/group.py +249 -0
- kirin/ir/method.py +118 -0
- kirin/ir/nodes/__init__.py +7 -0
- kirin/ir/nodes/base.py +149 -0
- kirin/ir/nodes/block.py +458 -0
- kirin/ir/nodes/region.py +337 -0
- kirin/ir/nodes/stmt.py +713 -0
- kirin/ir/nodes/view.py +142 -0
- kirin/ir/ssa.py +204 -0
- kirin/ir/traits/__init__.py +36 -0
- kirin/ir/traits/abc.py +42 -0
- kirin/ir/traits/basic.py +78 -0
- kirin/ir/traits/callable.py +51 -0
- kirin/ir/traits/lowering/__init__.py +2 -0
- kirin/ir/traits/lowering/call.py +37 -0
- kirin/ir/traits/lowering/context.py +120 -0
- kirin/ir/traits/region/__init__.py +2 -0
- kirin/ir/traits/region/ssacfg.py +22 -0
- kirin/ir/traits/symbol.py +57 -0
- kirin/ir/use.py +17 -0
- kirin/lattice/__init__.py +13 -0
- kirin/lattice/abc.py +128 -0
- kirin/lattice/empty.py +25 -0
- kirin/lattice/mixin.py +51 -0
- kirin/lowering/__init__.py +7 -0
- kirin/lowering/binding.py +65 -0
- kirin/lowering/core.py +72 -0
- kirin/lowering/dialect.py +35 -0
- kirin/lowering/dialect.pyi +183 -0
- kirin/lowering/frame.py +171 -0
- kirin/lowering/result.py +68 -0
- kirin/lowering/state.py +441 -0
- kirin/lowering/stream.py +53 -0
- kirin/passes/__init__.py +3 -0
- kirin/passes/abc.py +44 -0
- kirin/passes/aggressive/__init__.py +1 -0
- kirin/passes/aggressive/fold.py +43 -0
- kirin/passes/fold.py +45 -0
- kirin/passes/inline.py +25 -0
- kirin/passes/typeinfer.py +25 -0
- kirin/prelude.py +197 -0
- kirin/print/__init__.py +15 -0
- kirin/print/printable.py +141 -0
- kirin/print/printer.py +415 -0
- kirin/py.typed +0 -0
- kirin/registry.py +105 -0
- kirin/registry.pyi +52 -0
- kirin/rewrite/__init__.py +14 -0
- kirin/rewrite/abc.py +43 -0
- kirin/rewrite/aggressive/__init__.py +1 -0
- kirin/rewrite/aggressive/fold.py +43 -0
- kirin/rewrite/alias.py +16 -0
- kirin/rewrite/apply_type.py +47 -0
- kirin/rewrite/call2invoke.py +34 -0
- kirin/rewrite/chain.py +39 -0
- kirin/rewrite/compactify.py +288 -0
- kirin/rewrite/cse.py +48 -0
- kirin/rewrite/dce.py +19 -0
- kirin/rewrite/fixpoint.py +34 -0
- kirin/rewrite/fold.py +57 -0
- kirin/rewrite/getfield.py +21 -0
- kirin/rewrite/getitem.py +37 -0
- kirin/rewrite/inline.py +143 -0
- kirin/rewrite/result.py +15 -0
- kirin/rewrite/walk.py +83 -0
- kirin/rewrite/wrap_const.py +55 -0
- kirin/source.py +21 -0
- kirin/symbol_table.py +27 -0
- kirin/types.py +34 -0
- kirin/worklist.py +30 -0
- kirin_toolchain-0.13.0.dist-info/METADATA +42 -0
- kirin_toolchain-0.13.0.dist-info/RECORD +225 -0
- kirin_toolchain-0.13.0.dist-info/WHEEL +4 -0
- kirin_toolchain-0.13.0.dist-info/licenses/LICENSE +234 -0
kirin/ir/group.py
ADDED
@@ -0,0 +1,249 @@
|
|
1
|
+
import inspect
|
2
|
+
from types import ModuleType
|
3
|
+
from typing import (
|
4
|
+
TYPE_CHECKING,
|
5
|
+
Union,
|
6
|
+
Generic,
|
7
|
+
TypeVar,
|
8
|
+
Callable,
|
9
|
+
ParamSpec,
|
10
|
+
Concatenate,
|
11
|
+
overload,
|
12
|
+
)
|
13
|
+
from functools import update_wrapper
|
14
|
+
from dataclasses import dataclass
|
15
|
+
from collections.abc import Iterable
|
16
|
+
|
17
|
+
from kirin.ir.method import Method
|
18
|
+
from kirin.exceptions import CompilerError
|
19
|
+
|
20
|
+
if TYPE_CHECKING:
|
21
|
+
from kirin.registry import Registry
|
22
|
+
from kirin.ir.dialect import Dialect
|
23
|
+
|
24
|
+
# Extra positional/keyword parameters that `group(py_func, *args, **options)`
# forwards to the group's `run_pass` function.
PassParams = ParamSpec("PassParams")

# A pass runner: receives the freshly created Method first, then PassParams.
RunPass = Callable[Concatenate[Method, PassParams], None]

# A factory called with the owning DialectGroup that builds its RunPass.
RunPassGen = Callable[["DialectGroup"], RunPass[PassParams]]
|
27
|
+
|
28
|
+
|
29
|
+
@dataclass(init=False)
class DialectGroup(Generic[PassParams]):
    """A set of dialects together with an optional pass pipeline.

    Calling the group on a Python function, either directly as `@group` or
    parameterized as `@group(*args, **options)`, lowers the function into a
    `Method` using this group's dialects. If `run_pass` is set, it is then
    invoked on the new method with `*args, **options`.
    """

    # method wrapper params
    Param = ParamSpec("Param")
    RetType = TypeVar("RetType")
    MethodTransform = Callable[[Callable[Param, RetType]], Method[Param, RetType]]

    data: frozenset["Dialect"]
    """The set of dialects in the group."""
    # NOTE: this is used to create new dialect groups from existing one
    run_pass_gen: RunPassGen[PassParams] | None = None
    """the function that generates the `run_pass` function.

    This is used to create new dialect groups from existing ones, while
    keeping the same `run_pass` function.
    """
    run_pass: RunPass[PassParams] | None = None
    """the function that runs the passes on the method."""

    def __init__(
        self,
        dialects: Iterable[Union["Dialect", ModuleType]],
        run_pass: RunPassGen[PassParams] | None = None,
    ):
        """Create a dialect group.

        Args:
            dialects: the dialects (or modules with a `dialect` attribute)
                to include in the group.
            run_pass: optional generator; it is called with this group and
                must return the function used to run passes on each method.
        """
        self.data = frozenset(self.map_module(dialect) for dialect in dialects)
        # Keep the generator so `add`/`union`/`discard` can rebuild `run_pass`
        # for the derived group.
        self.run_pass_gen = run_pass
        self.run_pass = None if run_pass is None else run_pass(self)

    def __iter__(self):
        """Iterate over the dialects in the group."""
        return iter(self.data)

    def __repr__(self) -> str:
        # Sort the names: frozenset iteration order is not stable across runs.
        names = ", ".join(sorted(each.name for each in self.data))
        return f"DialectGroup([{names}])"

    @staticmethod
    def map_module(dialect: Union["Dialect", ModuleType]) -> "Dialect":
        """map the module to the dialect if it is a module.
        It assumes that the module has a `dialect` attribute
        that is an instance of [`Dialect`][kirin.ir.Dialect].
        """
        if isinstance(dialect, ModuleType):
            return getattr(dialect, "dialect")
        return dialect

    def add(self, dialect: Union["Dialect", ModuleType]) -> "DialectGroup":
        """add a dialect to the group.

        Args:
            dialect (Union[Dialect, ModuleType]): the dialect to add

        Returns:
            DialectGroup: the new dialect group with the added dialect.
        """
        return self.union([dialect])

    def union(self, dialect: Iterable[Union["Dialect", ModuleType]]) -> "DialectGroup":
        """union a set of dialects to the group.

        Args:
            dialect (Iterable[Union[Dialect, ModuleType]]): the dialects to union

        Returns:
            DialectGroup: the new dialect group with the union.
        """
        return DialectGroup(
            dialects=self.data.union(frozenset(self.map_module(d) for d in dialect)),
            run_pass=self.run_pass_gen,  # pass the run_pass_gen function
        )

    def discard(self, dialect: Union["Dialect", ModuleType]) -> "DialectGroup":
        """discard a dialect from the group.

        !!! note
            This does not raise an error if the dialect is not in the group.
            Dialects are matched by `name`.

        Args:
            dialect (Union[Dialect, ModuleType]): the dialect to discard

        Returns:
            DialectGroup: the new dialect group with the discarded dialect.
        """
        dialect_ = self.map_module(dialect)
        return DialectGroup(
            dialects=frozenset(
                each for each in self.data if each.name != dialect_.name
            ),
            run_pass=self.run_pass_gen,  # pass the run_pass_gen function
        )

    @property
    def registry(self) -> "Registry":
        """return the registry for the dialect group. This
        returns a proxy object that can be used to select
        the lowering interpreters, interpreters, and codegen
        for the dialects in the group.

        Returns:
            Registry: the registry object.
        """
        from kirin.registry import Registry

        return Registry(self)

    @overload
    def __call__(
        self,
        py_func: Callable[Param, RetType],
        *args: PassParams.args,
        **options: PassParams.kwargs,
    ) -> Method[Param, RetType]: ...

    @overload
    def __call__(
        self,
        py_func: None = None,
        *args: PassParams.args,
        **options: PassParams.kwargs,
    ) -> MethodTransform[Param, RetType]: ...

    def __call__(
        self,
        py_func: Callable[Param, RetType] | None = None,
        *args: PassParams.args,
        **options: PassParams.kwargs,
    ) -> Method[Param, RetType] | MethodTransform[Param, RetType]:
        """create a method from the python function.

        Supports both decorator forms: `@group` (`py_func` given, the
        `Method` is returned) and `@group(*args, **options)` (`py_func`
        omitted, a decorator producing the `Method` is returned).

        Args:
            py_func (Callable): the python function to create the method from.
            args (PassParams.args): the arguments to pass to the run_pass function.
            options (PassParams.kwargs): the keyword arguments to pass to the run_pass function.

        Returns:
            Method: the method created from the python function, or a
                decorator that creates it when `py_func` is None.

        Raises:
            ValueError: if `py_func` is a lambda.
            CompilerError: if the function's name is already bound in the
                scope that applies the decorator (a redefinition).
        """
        from kirin.lowering import Lowering

        emit_ir = Lowering(self)

        def compile_function(py_func: Callable, call_site_frame) -> Method:
            # `call_site_frame` is the frame of the user code applying the
            # decorator, or None when it cannot be determined.
            if py_func.__name__ == "<lambda>":
                raise ValueError("Cannot compile lambda functions")

            lineno_offset, file = 0, ""
            if call_site_frame is not None:
                # The decorated name is bound only after decoration returns,
                # so finding it already in the caller's locals means the
                # user is redefining an existing name.
                if py_func.__name__ in call_site_frame.f_locals:
                    raise CompilerError(
                        f"overwriting function definition of `{py_func.__name__}`"
                    )

                lineno_offset = call_site_frame.f_lineno - 1
                file = call_site_frame.f_code.co_filename

            code = emit_ir.run(py_func, lineno_offset=lineno_offset)
            mt = Method(
                mod=inspect.getmodule(py_func),
                py_func=py_func,
                sym_name=py_func.__name__,
                # "#self#" occupies the implicit first argument slot.
                arg_names=["#self#"] + inspect.getfullargspec(py_func).args,
                dialects=self,
                code=code,
                file=file,
            )
            if doc := inspect.getdoc(py_func):
                mt.__doc__ = doc

            if self.run_pass is not None:
                self.run_pass(mt, *args, **options)
            return mt

        def wrapper(py_func: Callable) -> Method:
            # `@group(...)` form: Python calls this closure directly from the
            # user's code, so the call site is the immediate caller.
            frame = inspect.currentframe()
            try:
                return compile_function(
                    py_func, frame.f_back if frame is not None else None
                )
            finally:
                del frame  # break the frame reference cycle early

        if py_func is not None:
            # `@group` form: `__call__` itself is invoked from the user's
            # code, so the call site is this frame's caller.
            frame = inspect.currentframe()
            try:
                return compile_function(
                    py_func, frame.f_back if frame is not None else None
                )
            finally:
                del frame
        return wrapper
|
212
|
+
|
213
|
+
|
214
|
+
def dialect_group(
    dialects: Iterable[Union["Dialect", ModuleType]],
) -> Callable[[RunPassGen[PassParams]], DialectGroup[PassParams]]:
    """Create a dialect group from the given dialects based on the
    definition of a `run_pass` generator function.

    Args:
        dialects (Iterable[Union[Dialect, ModuleType]]): the dialects (or
            modules with a `dialect` attribute) to include in the group.

    Returns:
        Callable[[RunPassGen[PassParams]], DialectGroup[PassParams]]: a
            decorator. It takes a generator function and returns the new
            `DialectGroup`. The generator receives that group and returns
            the function that runs passes on each method.

    Example:
        ```python
        from kirin.ir.method import Method
        from kirin.dialects import cf, func, math

        @dialect_group([cf, func, math])
        def basic_no_opt(self):
            # initializations
            def run_pass(mt: Method) -> None:
                # how passes are applied to the method
                pass

            return run_pass
        ```
    """

    # NOTE: do not alias the annotation below
    def wrapper(
        transform: RunPassGen[PassParams],
    ) -> DialectGroup[PassParams]:
        group = DialectGroup(dialects, run_pass=transform)
        # Copy `__name__`, `__doc__` and related attributes from the
        # generator onto the group instance.
        update_wrapper(group, transform)
        return group

    return wrapper
|
kirin/ir/method.py
ADDED
@@ -0,0 +1,118 @@
|
|
1
|
+
import typing
|
2
|
+
from types import ModuleType
|
3
|
+
|
4
|
+
# from typing import TYPE_CHECKING, Generic, TypeVar, Callable, ParamSpec
|
5
|
+
from dataclasses import field, dataclass
|
6
|
+
|
7
|
+
from kirin.ir.traits import HasSignature, CallableStmtInterface
|
8
|
+
from kirin.exceptions import VerificationError
|
9
|
+
from kirin.ir.nodes.stmt import Statement
|
10
|
+
from kirin.print.printer import Printer
|
11
|
+
from kirin.ir.attrs.types import Generic
|
12
|
+
from kirin.print.printable import Printable
|
13
|
+
|
14
|
+
if typing.TYPE_CHECKING:
|
15
|
+
from kirin.ir.group import DialectGroup
|
16
|
+
|
17
|
+
Param = typing.ParamSpec("Param")
|
18
|
+
RetType = typing.TypeVar("RetType")
|
19
|
+
|
20
|
+
|
21
|
+
@dataclass
class Method(Printable, typing.Generic[Param, RetType]):
    """A Python function paired with its Kirin IR body.

    `code` is the IR statement implementing the function; calling a `Method`
    runs `code` with the concrete interpreter using `dialects`.
    """

    mod: ModuleType | None  # ref
    py_func: typing.Callable[Param, RetType] | None  # ref
    sym_name: str
    arg_names: list[str]
    dialects: "DialectGroup"  # own
    code: Statement  # own, the corresponding IR, a func.func usually
    # values contained if closure
    fields: tuple = field(default_factory=tuple)  # own
    file: str = ""
    inferred: bool = False
    """if typeinfer has been run on this method
    """
    verified: bool = False
    """if `code.verify` has been run on this method
    """

    def __hash__(self) -> int:
        # Hash by identity: methods are mutated in place (e.g. `verified`).
        return id(self)

    def __call__(self, *args: Param.args, **kwargs: Param.kwargs) -> RetType:
        """Run the method with the concrete interpreter.

        Raises:
            ValueError: if the number of arguments differs from `arg_names`
                minus the leading self entry.
        """
        from kirin.interp.concrete import Interpreter

        if len(args) + len(kwargs) != len(self.arg_names) - 1:
            raise ValueError("Incorrect number of arguments")
        # NOTE: multi-return values will be wrapped in a tuple for Python
        interp = Interpreter(self.dialects)
        return interp.run(self, args=args, kwargs=kwargs).expect()

    @property
    def args(self):
        """Return the arguments of the method. (excluding self)"""
        return tuple(self.callable_region.blocks[0].args[1:])

    @property
    def arg_types(self):
        """Return the types of the arguments of the method. (excluding self)"""
        return tuple(arg.type for arg in self.args)

    def _signature(self):
        """Return the signature of `code` through its `HasSignature` trait.

        Raises:
            ValueError: if `code` does not implement `HasSignature`.
        """
        trait = self.code.get_trait(HasSignature)
        if trait is None:
            raise ValueError("Method body must implement HasSignature")
        return trait.get_signature(self.code)

    @property
    def self_type(self):
        """Return the type of the self argument of the method."""
        signature = self._signature()
        return Generic(Method, Generic(tuple, *signature.inputs), signature.output)

    @property
    def callable_region(self):
        """Return the body region of `code` via `CallableStmtInterface`.

        Raises:
            ValueError: if `code` does not implement `CallableStmtInterface`.
        """
        trait = self.code.get_trait(CallableStmtInterface)
        if trait is None:
            raise ValueError("Method body must implement CallableStmtInterface")
        return trait.get_callable_region(self.code)

    @property
    def return_type(self):
        """Return the output type from the signature of `code`."""
        return self._signature().output

    def __repr__(self) -> str:
        return f'Method("{self.sym_name}")'

    def print_impl(self, printer: Printer) -> None:
        return printer.print(self.code)

    def verify(self) -> None:
        """verify the method body.

        Sets `verified` to True on success.

        Raises:
            Exception: wrapping the `VerificationError`, with the file, the
                line (when the failing statement has source info) and the
                statement name prepended to the message.
        """
        try:
            self.code.verify()
        except VerificationError as e:
            msg = f'File "{self.file}"'
            if isinstance(e.node, Statement):
                if e.node.source:
                    msg += f", line {e.node.source.lineno}"
                msg += f", in {e.node.name}"

            msg += f":\n Verification failed for {self.sym_name}: {e.args[0]}"
            raise Exception(msg) from e
        self.verified = True

    def similar(self, dialects: typing.Optional["DialectGroup"] = None):
        """Return a copy of this method built around a clone of its body.

        Args:
            dialects: dialect group for the copy; defaults to this method's.

        Returns:
            A new `Method` whose `code` is rebuilt from `code` with a cloned
            callable region. `arg_names` is copied so that mutating one
            method's list does not affect the other.
        """
        return Method(
            mod=self.mod,
            py_func=self.py_func,
            sym_name=self.sym_name,
            arg_names=list(self.arg_names),
            dialects=self.dialects if dialects is None else dialects,
            code=self.code.from_stmt(
                self.code, regions=[self.callable_region.clone()]
            ),
            fields=self.fields,
            file=self.file,
            inferred=self.inferred,
            verified=self.verified,
        )
|
@@ -0,0 +1,7 @@
|
|
1
|
+
"""Definition of Kirin's Intermediate Representation (IR) nodes.
|
2
|
+
"""
|
3
|
+
|
4
|
+
from kirin.ir.nodes.base import IRNode as IRNode
|
5
|
+
from kirin.ir.nodes.stmt import Statement as Statement
|
6
|
+
from kirin.ir.nodes.block import Block as Block
|
7
|
+
from kirin.ir.nodes.region import Region as Region
|
kirin/ir/nodes/base.py
ADDED
@@ -0,0 +1,149 @@
|
|
1
|
+
from __future__ import annotations
|
2
|
+
|
3
|
+
from abc import ABC, abstractmethod
|
4
|
+
from typing import TYPE_CHECKING, Generic, TypeVar, Iterator
|
5
|
+
from dataclasses import dataclass
|
6
|
+
|
7
|
+
from typing_extensions import Self
|
8
|
+
|
9
|
+
from kirin.print import Printer, Printable
|
10
|
+
from kirin.ir.ssa import SSAValue
|
11
|
+
|
12
|
+
if TYPE_CHECKING:
|
13
|
+
from kirin.ir.nodes.stmt import Statement
|
14
|
+
|
15
|
+
|
16
|
+
# Type of an IR node's parent. The bound is a forward reference because
# IRNode is defined after this line.
ParentType = TypeVar("ParentType", bound="IRNode")
|
17
|
+
|
18
|
+
|
19
|
+
@dataclass
class IRNode(Generic[ParentType], ABC, Printable):
    """Base class for all IR nodes. All IR nodes are hashable and can be compared
    for equality. The hash of an IR node is the same as the id of the object.

    !!! note "Pretty Printing"
        This object is pretty printable via
        [`.print()`][kirin.print.printable.Printable.print] method.
    """

    def assert_parent(self, type_: type[IRNode], parent) -> None:
        """Assert that `parent` is an instance of `type_` or None."""
        assert (
            isinstance(parent, type_) or parent is None
        ), f"Invalid parent, expect {type_} or None, got {type(parent)}"

    @property
    @abstractmethod
    def parent_node(self) -> ParentType | None:
        """Parent node of the current node."""
        ...

    @parent_node.setter
    @abstractmethod
    def parent_node(self, parent: ParentType | None) -> None: ...

    def is_ancestor(self, op: IRNode) -> bool:
        """Check if the current node is an ancestor of the given node.

        A node counts as its own ancestor: `node.is_ancestor(node)` is True.

        Args:
            op: the node whose chain of parents is searched for `self`.

        Returns:
            True if `self` is `op` or one of `op`'s (transitive) parents.
        """
        node: IRNode | None = op
        # Iterative walk: deeply nested IR cannot hit the recursion limit.
        while node is not None:
            if node is self:
                return True
            node = node.parent_node
        return False

    def get_root(self) -> IRNode:
        """Get the root node of the current node (itself if it has no parent)."""
        node: IRNode = self
        while (parent := node.parent_node) is not None:
            node = parent
        return node

    def is_equal(self, other: IRNode, context: dict | None = None) -> bool:
        """Check if the current node is equal to the other node.

        Args:
            other: The other node to compare.
            context: The context to store the visited nodes. A fresh empty
                dict is used when None.

        Returns:
            True if the nodes are equal, False otherwise.

        !!! note
            This method is not the same as the `==` operator. It checks for
            structural equality rather than identity. To change the behavior
            of structural equality, override the `is_structurally_equal` method.
        """
        if not isinstance(other, type(self)):
            return False
        if context is None:
            # A new mapping per top-level comparison; a shared default dict
            # would carry visited nodes over between unrelated calls.
            context = {}
        return self.is_structurally_equal(other, context)

    def attach(self, parent: ParentType) -> None:
        """Attach the current node to the parent node.

        Raises:
            ValueError: if the node already has a parent, or if it is an
                ancestor of (or the same as) `parent`, which would form a cycle.
        """
        assert isinstance(parent, IRNode), f"Expected IRNode, got {type(parent)}"

        # Compare against None explicitly: a parent object could be falsy.
        if self.parent_node is not None:
            raise ValueError("Node already has a parent")
        if self.is_ancestor(parent):
            raise ValueError("Node is an ancestor of the parent")
        self.parent_node = parent

    @abstractmethod
    def detach(self) -> None:
        """Detach the current node from the parent node."""
        ...

    @abstractmethod
    def drop_all_references(self) -> None:
        """Drop all references to other nodes."""
        ...

    @abstractmethod
    def delete(self, safe: bool = True) -> None:
        """Delete the current node.

        Args:
            safe: If True, check if the node has any references before deleting.
        """
        ...

    @abstractmethod
    def is_structurally_equal(
        self,
        other: Self,
        context: dict[IRNode | SSAValue, IRNode | SSAValue] | None = None,
    ) -> bool:
        """Check if the current node is structurally equal to the other node.

        !!! note
            This method is for tweaking the behavior of structural equality.
            To check if two nodes are structurally equal, use the `is_equal` method.

        Args:
            other: The other node to compare.
            context: The context to store the visited nodes.

        Returns:
            True if the nodes are structurally equal, False otherwise.
        """
        ...

    def __eq__(self, other) -> bool:
        # Identity equality; structural comparison lives in `is_equal`.
        return self is other

    def __hash__(self) -> int:
        return id(self)

    @abstractmethod
    def walk(
        self, *, reverse: bool = False, region_first: bool = False
    ) -> Iterator[Statement]: ...

    @abstractmethod
    def print_impl(self, printer: Printer) -> None: ...

    @abstractmethod
    def typecheck(self) -> None:
        """check if types are correct."""
        ...

    @abstractmethod
    def verify(self) -> None:
        """run mandatory validation checks. This is not same as typecheck, which may be optional."""
        ...
|