kirin-toolchain 0.13.0__py3-none-any.whl
Sign up to get free protection for your applications and to get access to all the features.
- kirin/__init__.py +7 -0
- kirin/analysis/__init__.py +24 -0
- kirin/analysis/callgraph.py +61 -0
- kirin/analysis/cfg.py +112 -0
- kirin/analysis/const/__init__.py +20 -0
- kirin/analysis/const/_visitor.py +2 -0
- kirin/analysis/const/_visitor.pyi +8 -0
- kirin/analysis/const/lattice.py +219 -0
- kirin/analysis/const/prop.py +116 -0
- kirin/analysis/forward.py +100 -0
- kirin/analysis/typeinfer/__init__.py +5 -0
- kirin/analysis/typeinfer/analysis.py +90 -0
- kirin/analysis/typeinfer/solve.py +141 -0
- kirin/decl/__init__.py +108 -0
- kirin/decl/base.py +65 -0
- kirin/decl/camel2snake.py +2 -0
- kirin/decl/emit/__init__.py +0 -0
- kirin/decl/emit/_create_fn.py +29 -0
- kirin/decl/emit/_set_new_attribute.py +22 -0
- kirin/decl/emit/dialect.py +8 -0
- kirin/decl/emit/init.py +277 -0
- kirin/decl/emit/name.py +10 -0
- kirin/decl/emit/property.py +182 -0
- kirin/decl/emit/repr.py +31 -0
- kirin/decl/emit/traits.py +13 -0
- kirin/decl/emit/typecheck.py +77 -0
- kirin/decl/emit/verify.py +51 -0
- kirin/decl/info.py +346 -0
- kirin/decl/scan_fields.py +157 -0
- kirin/decl/verify.py +69 -0
- kirin/dialects/__init__.py +14 -0
- kirin/dialects/_pprint_helper.py +53 -0
- kirin/dialects/cf/__init__.py +20 -0
- kirin/dialects/cf/constprop.py +51 -0
- kirin/dialects/cf/dialect.py +3 -0
- kirin/dialects/cf/emit.py +58 -0
- kirin/dialects/cf/interp.py +24 -0
- kirin/dialects/cf/stmts.py +68 -0
- kirin/dialects/cf/typeinfer.py +27 -0
- kirin/dialects/eltype.py +23 -0
- kirin/dialects/func/__init__.py +20 -0
- kirin/dialects/func/attrs.py +39 -0
- kirin/dialects/func/constprop.py +138 -0
- kirin/dialects/func/dialect.py +3 -0
- kirin/dialects/func/emit.py +80 -0
- kirin/dialects/func/interp.py +68 -0
- kirin/dialects/func/stmts.py +233 -0
- kirin/dialects/func/typeinfer.py +124 -0
- kirin/dialects/ilist/__init__.py +33 -0
- kirin/dialects/ilist/_dialect.py +3 -0
- kirin/dialects/ilist/_wrapper.py +51 -0
- kirin/dialects/ilist/interp.py +85 -0
- kirin/dialects/ilist/lowering.py +25 -0
- kirin/dialects/ilist/passes.py +32 -0
- kirin/dialects/ilist/rewrite/__init__.py +3 -0
- kirin/dialects/ilist/rewrite/const.py +45 -0
- kirin/dialects/ilist/rewrite/list.py +38 -0
- kirin/dialects/ilist/rewrite/unroll.py +131 -0
- kirin/dialects/ilist/runtime.py +63 -0
- kirin/dialects/ilist/stmts.py +102 -0
- kirin/dialects/ilist/typeinfer.py +120 -0
- kirin/dialects/lowering/__init__.py +7 -0
- kirin/dialects/lowering/call.py +48 -0
- kirin/dialects/lowering/cf.py +206 -0
- kirin/dialects/lowering/func.py +134 -0
- kirin/dialects/math/__init__.py +41 -0
- kirin/dialects/math/_gen.py +176 -0
- kirin/dialects/math/dialect.py +3 -0
- kirin/dialects/math/interp.py +190 -0
- kirin/dialects/math/stmts.py +369 -0
- kirin/dialects/module.py +139 -0
- kirin/dialects/py/__init__.py +40 -0
- kirin/dialects/py/assertion.py +91 -0
- kirin/dialects/py/assign.py +103 -0
- kirin/dialects/py/attr.py +59 -0
- kirin/dialects/py/base.py +34 -0
- kirin/dialects/py/binop/__init__.py +23 -0
- kirin/dialects/py/binop/_dialect.py +3 -0
- kirin/dialects/py/binop/interp.py +60 -0
- kirin/dialects/py/binop/julia.py +33 -0
- kirin/dialects/py/binop/lowering.py +22 -0
- kirin/dialects/py/binop/stmts.py +79 -0
- kirin/dialects/py/binop/typeinfer.py +108 -0
- kirin/dialects/py/boolop.py +84 -0
- kirin/dialects/py/builtin.py +78 -0
- kirin/dialects/py/cmp/__init__.py +16 -0
- kirin/dialects/py/cmp/_dialect.py +3 -0
- kirin/dialects/py/cmp/interp.py +48 -0
- kirin/dialects/py/cmp/julia.py +33 -0
- kirin/dialects/py/cmp/lowering.py +45 -0
- kirin/dialects/py/cmp/stmts.py +62 -0
- kirin/dialects/py/constant.py +79 -0
- kirin/dialects/py/indexing.py +251 -0
- kirin/dialects/py/iterable.py +90 -0
- kirin/dialects/py/len.py +57 -0
- kirin/dialects/py/list/__init__.py +15 -0
- kirin/dialects/py/list/_dialect.py +3 -0
- kirin/dialects/py/list/interp.py +21 -0
- kirin/dialects/py/list/lowering.py +25 -0
- kirin/dialects/py/list/stmts.py +22 -0
- kirin/dialects/py/list/typeinfer.py +54 -0
- kirin/dialects/py/range.py +76 -0
- kirin/dialects/py/slice.py +120 -0
- kirin/dialects/py/tuple.py +109 -0
- kirin/dialects/py/unary/__init__.py +24 -0
- kirin/dialects/py/unary/_dialect.py +3 -0
- kirin/dialects/py/unary/constprop.py +20 -0
- kirin/dialects/py/unary/interp.py +24 -0
- kirin/dialects/py/unary/julia.py +21 -0
- kirin/dialects/py/unary/lowering.py +22 -0
- kirin/dialects/py/unary/stmts.py +33 -0
- kirin/dialects/py/unary/typeinfer.py +23 -0
- kirin/dialects/py/unpack.py +90 -0
- kirin/dialects/scf/__init__.py +23 -0
- kirin/dialects/scf/_dialect.py +3 -0
- kirin/dialects/scf/absint.py +64 -0
- kirin/dialects/scf/constprop.py +140 -0
- kirin/dialects/scf/interp.py +35 -0
- kirin/dialects/scf/lowering.py +123 -0
- kirin/dialects/scf/stmts.py +250 -0
- kirin/dialects/scf/trim.py +36 -0
- kirin/dialects/scf/typeinfer.py +58 -0
- kirin/dialects/scf/unroll.py +92 -0
- kirin/emit/__init__.py +3 -0
- kirin/emit/abc.py +89 -0
- kirin/emit/abc.pyi +38 -0
- kirin/emit/exceptions.py +5 -0
- kirin/emit/julia.py +63 -0
- kirin/emit/str.py +51 -0
- kirin/exceptions.py +59 -0
- kirin/graph.py +34 -0
- kirin/idtable.py +57 -0
- kirin/interp/__init__.py +39 -0
- kirin/interp/abstract.py +253 -0
- kirin/interp/base.py +438 -0
- kirin/interp/concrete.py +62 -0
- kirin/interp/exceptions.py +26 -0
- kirin/interp/frame.py +151 -0
- kirin/interp/impl.py +197 -0
- kirin/interp/result.py +93 -0
- kirin/interp/state.py +71 -0
- kirin/interp/table.py +40 -0
- kirin/interp/value.py +73 -0
- kirin/ir/__init__.py +46 -0
- kirin/ir/attrs/__init__.py +20 -0
- kirin/ir/attrs/_types.py +8 -0
- kirin/ir/attrs/_types.pyi +13 -0
- kirin/ir/attrs/abc.py +46 -0
- kirin/ir/attrs/py.py +45 -0
- kirin/ir/attrs/types.py +522 -0
- kirin/ir/dialect.py +125 -0
- kirin/ir/group.py +249 -0
- kirin/ir/method.py +118 -0
- kirin/ir/nodes/__init__.py +7 -0
- kirin/ir/nodes/base.py +149 -0
- kirin/ir/nodes/block.py +458 -0
- kirin/ir/nodes/region.py +337 -0
- kirin/ir/nodes/stmt.py +713 -0
- kirin/ir/nodes/view.py +142 -0
- kirin/ir/ssa.py +204 -0
- kirin/ir/traits/__init__.py +36 -0
- kirin/ir/traits/abc.py +42 -0
- kirin/ir/traits/basic.py +78 -0
- kirin/ir/traits/callable.py +51 -0
- kirin/ir/traits/lowering/__init__.py +2 -0
- kirin/ir/traits/lowering/call.py +37 -0
- kirin/ir/traits/lowering/context.py +120 -0
- kirin/ir/traits/region/__init__.py +2 -0
- kirin/ir/traits/region/ssacfg.py +22 -0
- kirin/ir/traits/symbol.py +57 -0
- kirin/ir/use.py +17 -0
- kirin/lattice/__init__.py +13 -0
- kirin/lattice/abc.py +128 -0
- kirin/lattice/empty.py +25 -0
- kirin/lattice/mixin.py +51 -0
- kirin/lowering/__init__.py +7 -0
- kirin/lowering/binding.py +65 -0
- kirin/lowering/core.py +72 -0
- kirin/lowering/dialect.py +35 -0
- kirin/lowering/dialect.pyi +183 -0
- kirin/lowering/frame.py +171 -0
- kirin/lowering/result.py +68 -0
- kirin/lowering/state.py +441 -0
- kirin/lowering/stream.py +53 -0
- kirin/passes/__init__.py +3 -0
- kirin/passes/abc.py +44 -0
- kirin/passes/aggressive/__init__.py +1 -0
- kirin/passes/aggressive/fold.py +43 -0
- kirin/passes/fold.py +45 -0
- kirin/passes/inline.py +25 -0
- kirin/passes/typeinfer.py +25 -0
- kirin/prelude.py +197 -0
- kirin/print/__init__.py +15 -0
- kirin/print/printable.py +141 -0
- kirin/print/printer.py +415 -0
- kirin/py.typed +0 -0
- kirin/registry.py +105 -0
- kirin/registry.pyi +52 -0
- kirin/rewrite/__init__.py +14 -0
- kirin/rewrite/abc.py +43 -0
- kirin/rewrite/aggressive/__init__.py +1 -0
- kirin/rewrite/aggressive/fold.py +43 -0
- kirin/rewrite/alias.py +16 -0
- kirin/rewrite/apply_type.py +47 -0
- kirin/rewrite/call2invoke.py +34 -0
- kirin/rewrite/chain.py +39 -0
- kirin/rewrite/compactify.py +288 -0
- kirin/rewrite/cse.py +48 -0
- kirin/rewrite/dce.py +19 -0
- kirin/rewrite/fixpoint.py +34 -0
- kirin/rewrite/fold.py +57 -0
- kirin/rewrite/getfield.py +21 -0
- kirin/rewrite/getitem.py +37 -0
- kirin/rewrite/inline.py +143 -0
- kirin/rewrite/result.py +15 -0
- kirin/rewrite/walk.py +83 -0
- kirin/rewrite/wrap_const.py +55 -0
- kirin/source.py +21 -0
- kirin/symbol_table.py +27 -0
- kirin/types.py +34 -0
- kirin/worklist.py +30 -0
- kirin_toolchain-0.13.0.dist-info/METADATA +42 -0
- kirin_toolchain-0.13.0.dist-info/RECORD +225 -0
- kirin_toolchain-0.13.0.dist-info/WHEEL +4 -0
- kirin_toolchain-0.13.0.dist-info/licenses/LICENSE +234 -0
kirin/print/printer.py
ADDED
@@ -0,0 +1,415 @@
|
|
1
|
+
import io
|
2
|
+
from typing import (
|
3
|
+
IO,
|
4
|
+
TYPE_CHECKING,
|
5
|
+
Any,
|
6
|
+
Union,
|
7
|
+
Literal,
|
8
|
+
TypeVar,
|
9
|
+
Callable,
|
10
|
+
Iterable,
|
11
|
+
Generator,
|
12
|
+
)
|
13
|
+
from contextlib import contextmanager
|
14
|
+
from dataclasses import field, dataclass
|
15
|
+
|
16
|
+
from rich.theme import Theme
|
17
|
+
from rich.console import Console
|
18
|
+
|
19
|
+
from kirin.idtable import IdTable
|
20
|
+
from kirin.print.printable import Printable
|
21
|
+
|
22
|
+
if TYPE_CHECKING:
|
23
|
+
from kirin import ir
|
24
|
+
|
25
|
+
|
26
|
+
# Built-in Rich themes selectable by name via `Printer(theme="dark" | "light")`.
# Both themes define the same set of style names: Rich raises a MissingStyle
# error when printing with a style name the active theme does not define, so
# a style present in one theme but absent from the other would break printing
# only under that theme.
DEFAULT_THEME = {
    "dark": Theme(
        {
            "dialect": "dark_blue",
            "type": "dark_blue",
            "comment": "bright_black",
            "keyword": "red",
            "symbol": "cyan",
            "warning": "yellow",
            "string": "green",
            "irrational": "default",
            "number": "default",
        }
    ),
    "light": Theme(
        {
            "dialect": "blue",
            "type": "blue",
            "comment": "bright_black",
            "keyword": "red",
            "symbol": "cyan",
            "warning": "yellow",
            "string": "green",
            "irrational": "magenta",
            "number": "default",
        }
    ),
}
|
53
|
+
|
54
|
+
|
55
|
+
@dataclass
class PrintState:
    """Mutable state shared by a `Printer` while it prints IR."""

    # id tables for SSA values and blocks (block ids use the "^" prefix)
    ssa_id: IdTable["ir.SSAValue"] = field(default_factory=IdTable["ir.SSAValue"])
    block_id: IdTable["ir.Block"] = field(
        default_factory=lambda: IdTable["ir.Block"](prefix="^")
    )
    # current indentation, in columns
    indent: int = 0
    result_width: int = 0
    "SSA-value column width in printing"
    # columns at which an indent mark is drawn when indent marks are shown
    indent_marks: list[int] = field(default_factory=list)
    # style / highlight used by `Printer.plain_print` when none is passed
    rich_style: str | None = None
    rich_highlight: bool | None = False
    # pending debug messages, emitted and cleared by `Printer.print_newline`
    messages: list[str] = field(default_factory=list)
|
69
|
+
|
70
|
+
|
71
|
+
IOType = TypeVar("IOType", bound=IO)


def _default_console():
    """Create the Rich console used by a `Printer` when none is supplied.

    Jupyter detection is turned off (`force_jupyter=False`).
    """
    console = Console(force_jupyter=False)
    return console
|
76
|
+
|
77
|
+
|
78
|
+
@dataclass
class Printer:
    """An IR pretty printer built on top of a Rich console.

    Output goes through ``self.console``. Indentation, result-column
    alignment and the active Rich style are tracked in ``self.state``.
    """

    console: Console = field(default_factory=_default_console)
    """Rich console"""
    analysis: dict["ir.SSAValue", Any] | None = None
    """Analysis results"""
    hint: str | None = None
    """key of the SSAValue hint to print"""
    state: PrintState = field(default_factory=PrintState)
    """Printing state"""
    show_indent_mark: bool = field(default=True, kw_only=True)
    "Whether to show indent marks, e.g │"
    theme: Theme | dict | Literal["dark", "light"] = field(default="dark", kw_only=True)
    "Theme to use for printing"

    def __post_init__(self):
        # Normalize `theme` to a rich Theme (dicts are wrapped, names are
        # looked up in DEFAULT_THEME) and activate it on the console.
        if isinstance(self.theme, dict):
            self.theme = Theme(self.theme)
        elif isinstance(self.theme, str):
            self.theme = DEFAULT_THEME[self.theme]
        self.console.push_theme(self.theme)

    def print(self, object):
        """entry point for printing an object

        `Printable` objects print themselves via `print_impl`; any other
        object is dispatched to a `print_<ClassName>` method of this printer.

        Args:
            object: object to print.

        Raises:
            NotImplementedError: if no `print_<ClassName>` method exists.
        """
        if isinstance(object, Printable):
            object.print_impl(self)
        else:
            fn = getattr(self, f"print_{object.__class__.__name__}", None)
            if fn is None:
                raise NotImplementedError(
                    f"Printer for {object.__class__.__name__} not found"
                )
            fn(object)

    def print_name(
        self, node: Union["ir.Attribute", "ir.Statement"], prefix: str = ""
    ) -> None:
        """print the name of a node as `<dialect>.<name>` (or just `<name>`).

        Args:
            node(ir.Attribute | ir.Statement): node to print
            prefix(str): prefix to print before the name, default to ""
        """
        self.print_dialect_path(node, prefix=prefix)
        if node.dialect:
            self.plain_print(".")
        self.plain_print(node.name)

    def print_stmt(self, node: "ir.Statement"):
        """print one statement line.

        The result column is right-aligned to `state.result_width`, then the
        statement itself is printed, followed by any analysis results and
        hint for its results.

        Args:
            node(ir.Statement): statement to print
        """
        if node._results:
            result_str = self.result_str(node._results)
            self.plain_print(result_str.rjust(self.state.result_width), " = ")
        elif self.state.result_width:
            # Pad with as many columns as " = " so statements without results
            # line up with those that have them (and with the +3 indent below).
            self.plain_print(" " * self.state.result_width, "   ")
        with self.indent(self.state.result_width + 3, mark=True):
            self.print(node)
        with self.rich(style="warning"):
            self.print_analysis(*node._results, prefix=" # ---> ")
        with self.rich(style="comment"):
            self.print_hint(*node._results)

    def print_hint(
        self,
        *values: "ir.SSAValue",
        prefix: str = " // hint<",
        suffix: str = ">",
    ):
        """print the `self.hint` entry of each value's hints.

        Nothing is printed when no hint key is set or `values` is empty.
        Values lacking the hint are printed as `missing`.

        Args:
            *values(ir.SSAValue): values whose hints to print

        Keyword Args:
            prefix(str): printed before the hints, default to " // hint<"
            suffix(str): printed after the hints, default to ">"
        """
        if not self.hint or not values:
            return

        self.plain_print(prefix)
        self.plain_print(self.hint)
        for idx, item in enumerate(values):
            if idx > 0:
                self.plain_print(", ")

            self.plain_print("=")
            hint_value = item.hints.get(self.hint)
            if hint_value is not None:
                self.plain_print(repr(hint_value))
            else:
                self.plain_print("missing")
        self.plain_print(suffix)

    def print_analysis(
        self,
        *values: "ir.SSAValue",
        prefix: str = "",
    ):
        """print the analysis result of each value.

        Nothing is printed when no analysis is attached or `values` is empty.
        Values without a result are printed as `missing`.

        Args:
            *values(ir.SSAValue): values whose analysis results to print

        Keyword Args:
            prefix(str): printed before the results, default to ""
        """
        if self.analysis is None or not values:
            return

        self.plain_print(prefix)
        for idx, value in enumerate(values):
            if idx > 0:
                self.plain_print(", ")

            # Compare against None rather than truthiness: a result object that
            # is falsy (e.g. defines __bool__/__len__) is still a real result.
            result = self.analysis.get(value)
            if result is not None:
                self.plain_print(repr(result))
            else:
                self.plain_print("missing")

    def print_dialect_path(
        self, node: Union["ir.Attribute", "ir.Statement"], prefix: str = ""
    ) -> None:
        """print the dialect path of a node.

        The prefix is always printed; the dialect name follows only when the
        node belongs to a dialect.

        Args:
            node(ir.Attribute | ir.Statement): node to print
            prefix(str): prefix to print before the dialect path, default to ""
        """
        self.plain_print(prefix)
        if node.dialect:  # not None
            self.plain_print(node.dialect.name, style="dialect")

    def print_newline(self):
        """print a newline character.

        This method also prints any messages in the state for debugging.
        """
        self.plain_print("\n")

        if self.state.messages:
            for message in self.state.messages:
                self.plain_print(message)
                self.plain_print("\n")
            self.state.messages.clear()
        self.print_indent()

    def print_indent(self):
        """print the current indentation level optionally with indent marks."""
        if self.show_indent_mark and self.state.indent_marks:
            marks = set(self.state.indent_marks)  # O(1) lookup per column
            indent_str = "".join(
                "│" if i in marks else " " for i in range(self.state.indent)
            )
            with self.rich(style="comment"):
                self.plain_print(indent_str)
        else:
            self.plain_print(" " * self.state.indent)

    def plain_print(self, *objects, sep="", end="", style=None, highlight=None):
        """print objects without any formatting.

        Args:
            *objects: objects to print

        Keyword Args:
            sep(str): separator between objects, default to ""
            end(str): end character, default to ""
            style(str): style to use, default to None (use `state.rich_style`)
            highlight(bool): whether to highlight the text, default to None
                (use `state.rich_highlight`)
        """
        self.console.out(
            *objects,
            sep=sep,
            end=end,
            style=style or self.state.rich_style,
            # An explicit highlight=False must win over the state default, so
            # only fall back to the state when no value was given.
            highlight=(
                highlight if highlight is not None else self.state.rich_highlight
            ),
        )

    ElemType = TypeVar("ElemType")

    def print_seq(
        self,
        seq: Iterable[ElemType],
        *,
        emit: Callable[[ElemType], None] | None = None,
        delim: str = ", ",
        prefix: str = "",
        suffix: str = "",
        style=None,
        highlight=None,
    ) -> None:
        """print a sequence of objects.

        Args:
            seq(Iterable[ElemType]): sequence of objects to print

        Keyword Args:
            emit(Callable[[ElemType], None]): function to print each element, default to None
            delim(str): delimiter between elements, default to ", "
            prefix(str): prefix to print before the sequence, default to ""
            suffix(str): suffix to print after the sequence, default to ""
            style(str): style to use, default to None
            highlight(bool): whether to highlight the text, default to None
        """
        emit = emit or self.print
        self.plain_print(prefix, style=style, highlight=highlight)
        for idx, item in enumerate(seq):
            if idx > 0:
                self.plain_print(delim, style=style, highlight=highlight)
            emit(item)
        self.plain_print(suffix, style=style, highlight=highlight)

    KeyType = TypeVar("KeyType")
    ValueType = TypeVar("ValueType", bound=Printable)

    def print_mapping(
        self,
        elems: dict[KeyType, ValueType],
        *,
        emit: Callable[[ValueType], None] | None = None,
        delim: str = ", ",
    ) -> None:
        """print a mapping of key-value pairs as `key=value`.

        Args:
            elems(dict[KeyType, ValueType]): mapping to print

        Keyword Args:
            emit(Callable[[ValueType], None]): function to print each value, default to None
            delim(str): delimiter between key-value pairs, default to ", "
        """
        emit = emit or self.print
        for i, (key, value) in enumerate(elems.items()):
            if i > 0:
                self.plain_print(delim)
            self.plain_print(f"{key}=")
            emit(value)

    def result_width(self, stmts: Iterable["ir.Statement"]) -> int:
        """return the maximum width of the result column for a sequence of statements.

        Args:
            stmts(Iterable[ir.Statement]): sequence of statements

        Returns:
            int: maximum width of the result column (0 for no statements)
        """
        return max(
            (len(self.result_str(stmt._results)) for stmt in stmts), default=0
        )

    @contextmanager
    def align(self, width: int) -> Generator[PrintState, Any, None]:
        """align the result column width, and restore it after the context.

        Args:
            width(int): width of the column

        Yields:
            PrintState: the state with the new column width
        """
        old_width = self.state.result_width
        self.state.result_width = width
        try:
            yield self.state
        finally:
            self.state.result_width = old_width

    @contextmanager
    def indent(
        self, increase: int = 2, mark: bool | None = None
    ) -> Generator[PrintState, Any, None]:
        """increase the indentation level, and restore it after the context.

        Args:
            increase(int): amount to increase the indentation level, default to 2
            mark(bool): whether to mark the indentation level, default to None
                (use `show_indent_mark`)

        Yields:
            PrintState: the state with the new indentation level.
        """
        mark = mark if mark is not None else self.show_indent_mark
        self.state.indent += increase
        if mark:
            self.state.indent_marks.append(self.state.indent)
        try:
            yield self.state
        finally:
            self.state.indent -= increase
            if mark:
                self.state.indent_marks.pop()

    @contextmanager
    def rich(
        self, style: str | None = None, highlight: bool = False
    ) -> Generator[PrintState, Any, None]:
        """set the rich style and highlight, and restore them after the context.

        Args:
            style(str | None): style to use, default to None
            highlight(bool): whether to highlight the text, default to False

        Yields:
            PrintState: the state with the new style and highlight.
        """
        old_style = self.state.rich_style
        old_highlight = self.state.rich_highlight
        self.state.rich_style = style
        self.state.rich_highlight = highlight
        try:
            yield self.state
        finally:
            self.state.rich_style = old_style
            self.state.rich_highlight = old_highlight

    @contextmanager
    def string_io(self) -> Generator[io.StringIO, Any, None]:
        """Temporary string IO for capturing output.

        The console's file is redirected to the stream for the duration of
        the context, then restored; the stream is closed on exit.

        Yields:
            io.StringIO: the string IO object.
        """
        stream = io.StringIO()
        old_file = self.console.file
        self.console.file = stream
        try:
            yield stream
        finally:
            self.console.file = old_file
            stream.close()

    def result_str(self, results: list["ir.ResultValue"]) -> str:
        """return the string representation of a list of result values.

        Args:
            results(list[ir.ResultValue]): list of result values to print
        """
        with self.string_io() as stream:
            self.print_seq(results, delim=", ")
            # read before the context exits, since exiting closes the stream
            return stream.getvalue()

    def debug(self, message: str):
        """Queue a debug message; it is printed at the next newline."""
        self.state.messages.append(f"DEBUG: {message}")
|
kirin/py.typed
ADDED
File without changes
|
kirin/registry.py
ADDED
@@ -0,0 +1,105 @@
|
|
1
|
+
from typing import TYPE_CHECKING, Callable, Iterable
|
2
|
+
from dataclasses import dataclass
|
3
|
+
|
4
|
+
if TYPE_CHECKING:
|
5
|
+
from kirin.ir import Attribute
|
6
|
+
from kirin.ir.group import DialectGroup
|
7
|
+
from kirin.ir.nodes import Statement
|
8
|
+
from kirin.lowering import FromPythonAST
|
9
|
+
from kirin.interp.impl import Signature
|
10
|
+
from kirin.interp.table import MethodTable
|
11
|
+
|
12
|
+
|
13
|
+
@dataclass
class StatementImpl:
    """A statement implementation function bound to the method table owning it.

    Calling the instance forwards to ``impl`` and passes the owning table as
    the first argument, followed by the interpreter, frame and statement.
    """

    parent: "MethodTable"
    impl: Callable

    def __call__(self, interp, frame, stmt: "Statement"):
        owner = self.parent
        return self.impl(owner, interp, frame, stmt)

    def __repr__(self) -> str:
        impl_name = self.impl.__name__
        owner_cls = self.parent.__class__
        return f"method impl `{impl_name}` in {owner_cls!r}"
|
23
|
+
|
24
|
+
|
25
|
+
@dataclass
class AttributeImpl:
    """An attribute implementation function bound to the method table owning it.

    Calling the instance forwards to ``impl`` and passes the owning table as
    the first argument, followed by the interpreter and attribute.
    """

    parent: "MethodTable"
    impl: Callable

    def __call__(self, interp, attr: "Attribute"):
        owner = self.parent
        return self.impl(owner, interp, attr)

    def __repr__(self) -> str:
        impl_name = self.impl.__name__
        owner_cls = self.parent.__class__
        return f"attribute impl `{impl_name}` in {owner_cls!r}"
|
35
|
+
|
36
|
+
|
37
|
+
@dataclass
class InterpreterRegistry:
    """Interpreter implementations collected from a dialect group."""

    # attribute type -> bound implementation for attributes of that type
    attributes: dict[type["Attribute"], "AttributeImpl"]
    # statement signature -> bound implementation for matching statements
    statements: dict["Signature", "StatementImpl"]
|
41
|
+
|
42
|
+
|
43
|
+
@dataclass
class Registry:
    """Proxy class to build different registries from a dialect group."""

    dialects: "DialectGroup"
    """The dialect group to build the registry from."""

    def ast(self, keys: Iterable[str]) -> dict[str, "FromPythonAST"]:
        """select the lowering of every dialect for the given keys.

        For each dialect in the group, the first key (in the given order) the
        dialect has a lowering for is selected.

        Args:
            keys (Iterable[str]): the keys to search for in the dialects, in
                order of preference. Any iterable is accepted (including a
                one-shot generator); it is materialized once.

        Returns:
            a map from every name in each selected lowering's ``names`` to
            that lowering.

        Raises:
            KeyError: if some dialect has no lowering for any of the keys, or
                if two selected lowerings claim the same name.
        """
        # `keys` is scanned once per dialect (and again for the error message),
        # so a generator must not be consumed by the first dialect.
        keys = tuple(keys)
        ret: dict[str, "FromPythonAST"] = {}
        for dialect in self.dialects.data:
            # Reset per dialect: otherwise a dialect without a matching
            # lowering would silently reuse the previous dialect's lowering.
            from_ast = None
            for key in keys:
                if key in dialect.lowering:
                    from_ast = dialect.lowering[key]
                    break

            if from_ast is None:
                msg = ",".join(keys)
                raise KeyError(f"Lowering not found for {msg}")

            for name in from_ast.names:
                if name in ret:
                    raise KeyError(f"Lowering {name} already exists")

                ret[name] = from_ast
        return ret

    def interpreter(self, keys: Iterable[str]) -> "InterpreterRegistry":
        """select the dialect interpreters for the given keys.

        Every dialect's method table for every matching key is merged; for
        each attribute type / statement signature, the first dialect (in
        group order) and first key providing it wins.

        Args:
            keys (Iterable[str]): the keys to search for in the dialects, in
                order of preference. Any iterable is accepted; it is
                materialized once.

        Returns:
            InterpreterRegistry: the attribute-type -> `AttributeImpl` map and
            the statement-signature -> `StatementImpl` map.
        """
        keys = tuple(keys)
        attributes: dict[type["Attribute"], "AttributeImpl"] = {}
        table: dict["Signature", "StatementImpl"] = {}
        for dialect in self.dialects.data:
            for key in keys:
                if key not in dialect.interps:
                    continue

                dialect_table = dialect.interps[key]
                for sig, func in dialect_table.attribute.items():
                    if sig not in attributes:
                        attributes[sig] = AttributeImpl(dialect_table, func)

                for sig, func in dialect_table.table.items():
                    if sig not in table:
                        table[sig] = StatementImpl(dialect_table, func)

        return InterpreterRegistry(attributes, table)
|
kirin/registry.pyi
ADDED
@@ -0,0 +1,52 @@
|
|
1
|
+
from typing import Generic, TypeVar, Callable, Iterable, TypeAlias
|
2
|
+
from dataclasses import dataclass
|
3
|
+
|
4
|
+
from kirin.ir.group import DialectGroup
|
5
|
+
from kirin.ir.nodes import Statement
|
6
|
+
from kirin.lowering import FromPythonAST
|
7
|
+
from kirin.interp.base import FrameABC, BaseInterpreter
|
8
|
+
from kirin.interp.impl import Signature
|
9
|
+
from kirin.interp.table import MethodTable
|
10
|
+
from kirin.interp.value import StatementResult
|
11
|
+
from kirin.ir.attrs.abc import Attribute
|
12
|
+
|
13
|
+
# Type variables parameterizing method-table implementation functions.
MethodTableSelf = TypeVar("MethodTableSelf", bound="MethodTable")
InterpreterType = TypeVar("InterpreterType", bound="BaseInterpreter")
FrameType = TypeVar("FrameType", bound="FrameABC")
StatementType = TypeVar("StatementType", bound="Statement")

# Shape of a statement implementation:
#   (method table, interpreter, frame, statement) -> StatementResult
MethodFunction: TypeAlias = Callable[
    [MethodTableSelf, InterpreterType, FrameType, StatementType],
    StatementResult,
]
|
20
|
+
|
21
|
+
@dataclass
class StatementImpl(Generic[InterpreterType, FrameType]):
    """Typed view of a statement implementation bound to its method table."""

    parent: "MethodTable"
    impl: MethodFunction["MethodTable", InterpreterType, FrameType, "Statement"]

    def __call__(
        self, interp: InterpreterType, frame: FrameType, stmt: "Statement"
    ) -> StatementResult: ...
    def __repr__(self) -> str: ...
|
30
|
+
|
31
|
+
@dataclass
class AttributeImpl:
    """Typed view of an attribute implementation bound to its method table."""

    parent: "MethodTable"
    impl: Callable

    def __call__(self, interp, attr: "Attribute"): ...
    def __repr__(self) -> str: ...
|
38
|
+
|
39
|
+
@dataclass
class InterpreterRegistry:
    """Attribute and statement implementations returned by `Registry.interpreter`."""

    attributes: dict[type["Attribute"], "AttributeImpl"]
    statements: dict["Signature", "StatementImpl"]
|
43
|
+
|
44
|
+
@dataclass
class Registry:
    """Proxy class to build different registries from a dialect group."""

    dialects: "DialectGroup"
    """The dialect group to build the registry from."""

    def ast(self, keys: Iterable[str]) -> dict[str, "FromPythonAST"]:
        """Map lowering names to each dialect's lowering for the first matching key."""
        ...
    def interpreter(self, keys: Iterable[str]) -> InterpreterRegistry:
        """Collect attribute and statement implementations for the given keys."""
        ...
|
@@ -0,0 +1,14 @@
|
|
1
|
+
from .cse import CommonSubexpressionElimination as CommonSubexpressionElimination
|
2
|
+
from .dce import DeadCodeElimination as DeadCodeElimination
|
3
|
+
from .fold import ConstantFold as ConstantFold
|
4
|
+
from .walk import Walk as Walk
|
5
|
+
from .alias import InlineAlias as InlineAlias
|
6
|
+
from .chain import Chain as Chain
|
7
|
+
from .inline import Inline as Inline
|
8
|
+
from .getitem import InlineGetItem as InlineGetItem
|
9
|
+
from .fixpoint import Fixpoint as Fixpoint
|
10
|
+
from .getfield import InlineGetField as InlineGetField
|
11
|
+
from .apply_type import ApplyType as ApplyType
|
12
|
+
from .compactify import CFGCompactify as CFGCompactify
|
13
|
+
from .wrap_const import WrapConst as WrapConst
|
14
|
+
from .call2invoke import Call2Invoke as Call2Invoke
|
kirin/rewrite/abc.py
ADDED
@@ -0,0 +1,43 @@
|
|
1
|
+
from abc import ABC
|
2
|
+
from dataclasses import dataclass
|
3
|
+
|
4
|
+
from kirin.ir import Pure, Block, IRNode, Region, MaybePure, Statement
|
5
|
+
from kirin.rewrite.result import RewriteResult
|
6
|
+
|
7
|
+
|
8
|
+
@dataclass(repr=False)
class RewriteRule(ABC):
    """A rewrite rule that matches and rewrites IR nodes.

    The rule is applied by calling ``rewrite`` with the node to rewrite. A
    rule mutates the node in place instead of returning a new node, and
    reports through the returned `RewriteResult` whether it has done
    something, whether rewriting should terminate, and whether the maximum
    number of iterations has been exceeded.
    """

    def rewrite(self, node: IRNode) -> RewriteResult:
        """Dispatch ``node`` to the handler for its kind.

        Regions, blocks and statements are routed to ``rewrite_Region``,
        ``rewrite_Block`` and ``rewrite_Statement`` (checked in that order);
        any other node yields an empty `RewriteResult`.
        """
        if isinstance(node, Region):
            return self.rewrite_Region(node)
        if isinstance(node, Block):
            return self.rewrite_Block(node)
        if isinstance(node, Statement):
            return self.rewrite_Statement(node)
        return RewriteResult()

    def rewrite_Region(self, node: Region) -> RewriteResult:
        """Rewrite a region; the default does nothing."""
        return RewriteResult()

    def rewrite_Block(self, node: Block) -> RewriteResult:
        """Rewrite a block; the default does nothing."""
        return RewriteResult()

    def rewrite_Statement(self, node: Statement) -> RewriteResult:
        """Rewrite a statement; the default does nothing."""
        return RewriteResult()

    def is_pure(self, node: Statement):
        """Return True if ``node`` has the `Pure` trait, or has a `MaybePure`
        trait that reports it pure; otherwise return False."""
        if node.has_trait(Pure):
            return True
        maybe_pure = node.get_trait(MaybePure)
        if maybe_pure and maybe_pure.is_pure(node):
            return True
        return False
|
@@ -0,0 +1 @@
|
|
1
|
+
from .fold import Fold as Fold
|
@@ -0,0 +1,43 @@
|
|
1
|
+
from dataclasses import dataclass
|
2
|
+
|
3
|
+
from kirin.rewrite import Walk, Chain, Fixpoint
|
4
|
+
from kirin.analysis import const
|
5
|
+
from kirin.rewrite.abc import RewriteRule, RewriteResult
|
6
|
+
from kirin.rewrite.dce import DeadCodeElimination
|
7
|
+
from kirin.rewrite.fold import ConstantFold
|
8
|
+
from kirin.ir.nodes.base import IRNode
|
9
|
+
from kirin.rewrite.inline import Inline
|
10
|
+
from kirin.rewrite.getitem import InlineGetItem
|
11
|
+
from kirin.rewrite.getfield import InlineGetField
|
12
|
+
from kirin.rewrite.compactify import CFGCompactify
|
13
|
+
from kirin.rewrite.wrap_const import WrapConst
|
14
|
+
from kirin.rewrite.call2invoke import Call2Invoke
|
15
|
+
|
16
|
+
|
17
|
+
@dataclass
class Fold(RewriteRule):
    """Aggressive folding pipeline built from a constant-propagation frame.

    A single `Chain` of stages (wrap constants from ``frame``, inline calls
    with a predicate that always returns True, constant-fold, convert calls
    to invokes, a nested getitem/getfield/dead-code cleanup `Fixpoint`, and
    CFG compaction) is wrapped in a `Fixpoint`.
    """

    rule: RewriteRule

    def __init__(self, frame: const.Frame):
        # Stages are constructed in the same order a single nested expression
        # would construct them.
        wrap_consts = Walk(WrapConst(frame))
        inline_all = Walk(Inline(lambda _: True))
        fold_consts = Walk(ConstantFold())
        calls_to_invokes = Walk(Call2Invoke())
        cleanup = Fixpoint(
            Walk(
                Chain(
                    InlineGetItem(),
                    InlineGetField(),
                    DeadCodeElimination(),
                )
            )
        )
        compact_cfg = Walk(CFGCompactify())
        self.rule = Fixpoint(
            Chain(
                wrap_consts,
                inline_all,
                fold_consts,
                calls_to_invokes,
                cleanup,
                compact_cfg,
            )
        )

    def rewrite(self, node: IRNode) -> RewriteResult:
        """Apply the composed pipeline to ``node``."""
        return self.rule.rewrite(node)
|
kirin/rewrite/alias.py
ADDED
@@ -0,0 +1,16 @@
|
|
1
|
+
from dataclasses import dataclass
|
2
|
+
|
3
|
+
from kirin import ir
|
4
|
+
from kirin.rewrite.abc import RewriteRule, RewriteResult
|
5
|
+
from kirin.dialects.py.assign import Alias
|
6
|
+
|
7
|
+
|
8
|
+
@dataclass
class InlineAlias(RewriteRule):
    """Redirect uses of an `Alias` statement's result to the aliased value."""

    def rewrite_Statement(self, node: ir.Statement) -> RewriteResult:
        """Replace the result of an `Alias` by its value via `replace_by`.

        Statements other than `Alias` are left untouched. The `Alias`
        statement itself is not removed by this rule.
        """
        if isinstance(node, Alias):
            node.result.replace_by(node.value)
            return RewriteResult(has_done_something=True)
        return RewriteResult()
|