kirin-toolchain 0.13.0__py3-none-any.whl

Sign up to get free protection for your applications and to get access to all the features.
Files changed (225) hide show
  1. kirin/__init__.py +7 -0
  2. kirin/analysis/__init__.py +24 -0
  3. kirin/analysis/callgraph.py +61 -0
  4. kirin/analysis/cfg.py +112 -0
  5. kirin/analysis/const/__init__.py +20 -0
  6. kirin/analysis/const/_visitor.py +2 -0
  7. kirin/analysis/const/_visitor.pyi +8 -0
  8. kirin/analysis/const/lattice.py +219 -0
  9. kirin/analysis/const/prop.py +116 -0
  10. kirin/analysis/forward.py +100 -0
  11. kirin/analysis/typeinfer/__init__.py +5 -0
  12. kirin/analysis/typeinfer/analysis.py +90 -0
  13. kirin/analysis/typeinfer/solve.py +141 -0
  14. kirin/decl/__init__.py +108 -0
  15. kirin/decl/base.py +65 -0
  16. kirin/decl/camel2snake.py +2 -0
  17. kirin/decl/emit/__init__.py +0 -0
  18. kirin/decl/emit/_create_fn.py +29 -0
  19. kirin/decl/emit/_set_new_attribute.py +22 -0
  20. kirin/decl/emit/dialect.py +8 -0
  21. kirin/decl/emit/init.py +277 -0
  22. kirin/decl/emit/name.py +10 -0
  23. kirin/decl/emit/property.py +182 -0
  24. kirin/decl/emit/repr.py +31 -0
  25. kirin/decl/emit/traits.py +13 -0
  26. kirin/decl/emit/typecheck.py +77 -0
  27. kirin/decl/emit/verify.py +51 -0
  28. kirin/decl/info.py +346 -0
  29. kirin/decl/scan_fields.py +157 -0
  30. kirin/decl/verify.py +69 -0
  31. kirin/dialects/__init__.py +14 -0
  32. kirin/dialects/_pprint_helper.py +53 -0
  33. kirin/dialects/cf/__init__.py +20 -0
  34. kirin/dialects/cf/constprop.py +51 -0
  35. kirin/dialects/cf/dialect.py +3 -0
  36. kirin/dialects/cf/emit.py +58 -0
  37. kirin/dialects/cf/interp.py +24 -0
  38. kirin/dialects/cf/stmts.py +68 -0
  39. kirin/dialects/cf/typeinfer.py +27 -0
  40. kirin/dialects/eltype.py +23 -0
  41. kirin/dialects/func/__init__.py +20 -0
  42. kirin/dialects/func/attrs.py +39 -0
  43. kirin/dialects/func/constprop.py +138 -0
  44. kirin/dialects/func/dialect.py +3 -0
  45. kirin/dialects/func/emit.py +80 -0
  46. kirin/dialects/func/interp.py +68 -0
  47. kirin/dialects/func/stmts.py +233 -0
  48. kirin/dialects/func/typeinfer.py +124 -0
  49. kirin/dialects/ilist/__init__.py +33 -0
  50. kirin/dialects/ilist/_dialect.py +3 -0
  51. kirin/dialects/ilist/_wrapper.py +51 -0
  52. kirin/dialects/ilist/interp.py +85 -0
  53. kirin/dialects/ilist/lowering.py +25 -0
  54. kirin/dialects/ilist/passes.py +32 -0
  55. kirin/dialects/ilist/rewrite/__init__.py +3 -0
  56. kirin/dialects/ilist/rewrite/const.py +45 -0
  57. kirin/dialects/ilist/rewrite/list.py +38 -0
  58. kirin/dialects/ilist/rewrite/unroll.py +131 -0
  59. kirin/dialects/ilist/runtime.py +63 -0
  60. kirin/dialects/ilist/stmts.py +102 -0
  61. kirin/dialects/ilist/typeinfer.py +120 -0
  62. kirin/dialects/lowering/__init__.py +7 -0
  63. kirin/dialects/lowering/call.py +48 -0
  64. kirin/dialects/lowering/cf.py +206 -0
  65. kirin/dialects/lowering/func.py +134 -0
  66. kirin/dialects/math/__init__.py +41 -0
  67. kirin/dialects/math/_gen.py +176 -0
  68. kirin/dialects/math/dialect.py +3 -0
  69. kirin/dialects/math/interp.py +190 -0
  70. kirin/dialects/math/stmts.py +369 -0
  71. kirin/dialects/module.py +139 -0
  72. kirin/dialects/py/__init__.py +40 -0
  73. kirin/dialects/py/assertion.py +91 -0
  74. kirin/dialects/py/assign.py +103 -0
  75. kirin/dialects/py/attr.py +59 -0
  76. kirin/dialects/py/base.py +34 -0
  77. kirin/dialects/py/binop/__init__.py +23 -0
  78. kirin/dialects/py/binop/_dialect.py +3 -0
  79. kirin/dialects/py/binop/interp.py +60 -0
  80. kirin/dialects/py/binop/julia.py +33 -0
  81. kirin/dialects/py/binop/lowering.py +22 -0
  82. kirin/dialects/py/binop/stmts.py +79 -0
  83. kirin/dialects/py/binop/typeinfer.py +108 -0
  84. kirin/dialects/py/boolop.py +84 -0
  85. kirin/dialects/py/builtin.py +78 -0
  86. kirin/dialects/py/cmp/__init__.py +16 -0
  87. kirin/dialects/py/cmp/_dialect.py +3 -0
  88. kirin/dialects/py/cmp/interp.py +48 -0
  89. kirin/dialects/py/cmp/julia.py +33 -0
  90. kirin/dialects/py/cmp/lowering.py +45 -0
  91. kirin/dialects/py/cmp/stmts.py +62 -0
  92. kirin/dialects/py/constant.py +79 -0
  93. kirin/dialects/py/indexing.py +251 -0
  94. kirin/dialects/py/iterable.py +90 -0
  95. kirin/dialects/py/len.py +57 -0
  96. kirin/dialects/py/list/__init__.py +15 -0
  97. kirin/dialects/py/list/_dialect.py +3 -0
  98. kirin/dialects/py/list/interp.py +21 -0
  99. kirin/dialects/py/list/lowering.py +25 -0
  100. kirin/dialects/py/list/stmts.py +22 -0
  101. kirin/dialects/py/list/typeinfer.py +54 -0
  102. kirin/dialects/py/range.py +76 -0
  103. kirin/dialects/py/slice.py +120 -0
  104. kirin/dialects/py/tuple.py +109 -0
  105. kirin/dialects/py/unary/__init__.py +24 -0
  106. kirin/dialects/py/unary/_dialect.py +3 -0
  107. kirin/dialects/py/unary/constprop.py +20 -0
  108. kirin/dialects/py/unary/interp.py +24 -0
  109. kirin/dialects/py/unary/julia.py +21 -0
  110. kirin/dialects/py/unary/lowering.py +22 -0
  111. kirin/dialects/py/unary/stmts.py +33 -0
  112. kirin/dialects/py/unary/typeinfer.py +23 -0
  113. kirin/dialects/py/unpack.py +90 -0
  114. kirin/dialects/scf/__init__.py +23 -0
  115. kirin/dialects/scf/_dialect.py +3 -0
  116. kirin/dialects/scf/absint.py +64 -0
  117. kirin/dialects/scf/constprop.py +140 -0
  118. kirin/dialects/scf/interp.py +35 -0
  119. kirin/dialects/scf/lowering.py +123 -0
  120. kirin/dialects/scf/stmts.py +250 -0
  121. kirin/dialects/scf/trim.py +36 -0
  122. kirin/dialects/scf/typeinfer.py +58 -0
  123. kirin/dialects/scf/unroll.py +92 -0
  124. kirin/emit/__init__.py +3 -0
  125. kirin/emit/abc.py +89 -0
  126. kirin/emit/abc.pyi +38 -0
  127. kirin/emit/exceptions.py +5 -0
  128. kirin/emit/julia.py +63 -0
  129. kirin/emit/str.py +51 -0
  130. kirin/exceptions.py +59 -0
  131. kirin/graph.py +34 -0
  132. kirin/idtable.py +57 -0
  133. kirin/interp/__init__.py +39 -0
  134. kirin/interp/abstract.py +253 -0
  135. kirin/interp/base.py +438 -0
  136. kirin/interp/concrete.py +62 -0
  137. kirin/interp/exceptions.py +26 -0
  138. kirin/interp/frame.py +151 -0
  139. kirin/interp/impl.py +197 -0
  140. kirin/interp/result.py +93 -0
  141. kirin/interp/state.py +71 -0
  142. kirin/interp/table.py +40 -0
  143. kirin/interp/value.py +73 -0
  144. kirin/ir/__init__.py +46 -0
  145. kirin/ir/attrs/__init__.py +20 -0
  146. kirin/ir/attrs/_types.py +8 -0
  147. kirin/ir/attrs/_types.pyi +13 -0
  148. kirin/ir/attrs/abc.py +46 -0
  149. kirin/ir/attrs/py.py +45 -0
  150. kirin/ir/attrs/types.py +522 -0
  151. kirin/ir/dialect.py +125 -0
  152. kirin/ir/group.py +249 -0
  153. kirin/ir/method.py +118 -0
  154. kirin/ir/nodes/__init__.py +7 -0
  155. kirin/ir/nodes/base.py +149 -0
  156. kirin/ir/nodes/block.py +458 -0
  157. kirin/ir/nodes/region.py +337 -0
  158. kirin/ir/nodes/stmt.py +713 -0
  159. kirin/ir/nodes/view.py +142 -0
  160. kirin/ir/ssa.py +204 -0
  161. kirin/ir/traits/__init__.py +36 -0
  162. kirin/ir/traits/abc.py +42 -0
  163. kirin/ir/traits/basic.py +78 -0
  164. kirin/ir/traits/callable.py +51 -0
  165. kirin/ir/traits/lowering/__init__.py +2 -0
  166. kirin/ir/traits/lowering/call.py +37 -0
  167. kirin/ir/traits/lowering/context.py +120 -0
  168. kirin/ir/traits/region/__init__.py +2 -0
  169. kirin/ir/traits/region/ssacfg.py +22 -0
  170. kirin/ir/traits/symbol.py +57 -0
  171. kirin/ir/use.py +17 -0
  172. kirin/lattice/__init__.py +13 -0
  173. kirin/lattice/abc.py +128 -0
  174. kirin/lattice/empty.py +25 -0
  175. kirin/lattice/mixin.py +51 -0
  176. kirin/lowering/__init__.py +7 -0
  177. kirin/lowering/binding.py +65 -0
  178. kirin/lowering/core.py +72 -0
  179. kirin/lowering/dialect.py +35 -0
  180. kirin/lowering/dialect.pyi +183 -0
  181. kirin/lowering/frame.py +171 -0
  182. kirin/lowering/result.py +68 -0
  183. kirin/lowering/state.py +441 -0
  184. kirin/lowering/stream.py +53 -0
  185. kirin/passes/__init__.py +3 -0
  186. kirin/passes/abc.py +44 -0
  187. kirin/passes/aggressive/__init__.py +1 -0
  188. kirin/passes/aggressive/fold.py +43 -0
  189. kirin/passes/fold.py +45 -0
  190. kirin/passes/inline.py +25 -0
  191. kirin/passes/typeinfer.py +25 -0
  192. kirin/prelude.py +197 -0
  193. kirin/print/__init__.py +15 -0
  194. kirin/print/printable.py +141 -0
  195. kirin/print/printer.py +415 -0
  196. kirin/py.typed +0 -0
  197. kirin/registry.py +105 -0
  198. kirin/registry.pyi +52 -0
  199. kirin/rewrite/__init__.py +14 -0
  200. kirin/rewrite/abc.py +43 -0
  201. kirin/rewrite/aggressive/__init__.py +1 -0
  202. kirin/rewrite/aggressive/fold.py +43 -0
  203. kirin/rewrite/alias.py +16 -0
  204. kirin/rewrite/apply_type.py +47 -0
  205. kirin/rewrite/call2invoke.py +34 -0
  206. kirin/rewrite/chain.py +39 -0
  207. kirin/rewrite/compactify.py +288 -0
  208. kirin/rewrite/cse.py +48 -0
  209. kirin/rewrite/dce.py +19 -0
  210. kirin/rewrite/fixpoint.py +34 -0
  211. kirin/rewrite/fold.py +57 -0
  212. kirin/rewrite/getfield.py +21 -0
  213. kirin/rewrite/getitem.py +37 -0
  214. kirin/rewrite/inline.py +143 -0
  215. kirin/rewrite/result.py +15 -0
  216. kirin/rewrite/walk.py +83 -0
  217. kirin/rewrite/wrap_const.py +55 -0
  218. kirin/source.py +21 -0
  219. kirin/symbol_table.py +27 -0
  220. kirin/types.py +34 -0
  221. kirin/worklist.py +30 -0
  222. kirin_toolchain-0.13.0.dist-info/METADATA +42 -0
  223. kirin_toolchain-0.13.0.dist-info/RECORD +225 -0
  224. kirin_toolchain-0.13.0.dist-info/WHEEL +4 -0
  225. kirin_toolchain-0.13.0.dist-info/licenses/LICENSE +234 -0
kirin/print/printer.py ADDED
@@ -0,0 +1,415 @@
1
+ import io
2
+ from typing import (
3
+ IO,
4
+ TYPE_CHECKING,
5
+ Any,
6
+ Union,
7
+ Literal,
8
+ TypeVar,
9
+ Callable,
10
+ Iterable,
11
+ Generator,
12
+ )
13
+ from contextlib import contextmanager
14
+ from dataclasses import field, dataclass
15
+
16
+ from rich.theme import Theme
17
+ from rich.console import Console
18
+
19
+ from kirin.idtable import IdTable
20
+ from kirin.print.printable import Printable
21
+
22
+ if TYPE_CHECKING:
23
+ from kirin import ir
24
+
25
+
26
# Built-in colour themes, selectable by name through ``Printer(theme="dark")``
# or ``Printer(theme="light")``. Both themes define the same set of style names
# so that any style used while printing resolves under either theme.
DEFAULT_THEME = {
    "dark": Theme(
        {
            "dialect": "dark_blue",
            "type": "dark_blue",
            "comment": "bright_black",
            "keyword": "red",
            "symbol": "cyan",
            "warning": "yellow",
            "string": "green",
            "irrational": "default",
            "number": "default",
        }
    ),
    "light": Theme(
        {
            "dialect": "blue",
            "type": "blue",
            "comment": "bright_black",
            "keyword": "red",
            "symbol": "cyan",
            "warning": "yellow",
            "string": "green",
            "irrational": "magenta",
            # was missing: "number" is defined by the dark theme, and rich cannot
            # resolve a style name that the active theme does not define
            "number": "default",
        }
    ),
}
53
+
54
+
55
@dataclass
class PrintState:
    """Mutable state shared by the printing routines of a ``Printer``."""

    # Table used to name SSA values in the printed output.
    ssa_id: IdTable["ir.SSAValue"] = field(default_factory=IdTable["ir.SSAValue"])
    # Table used to name blocks; it is created with the prefix "^".
    block_id: IdTable["ir.Block"] = field(
        default_factory=lambda: IdTable["ir.Block"](prefix="^")
    )
    # Current indentation, in columns.
    indent: int = 0
    # SSA-value column width in printing. This field was previously declared
    # twice; the duplicate is removed and the original field position is kept.
    result_width: int = 0
    # Indentation columns at which an indent mark is drawn.
    indent_marks: list[int] = field(default_factory=list)
    # Style / highlight used by ``Printer.plain_print`` when none is passed.
    rich_style: str | None = None
    rich_highlight: bool | None = False
    # Pending debug messages, flushed by ``Printer.print_newline``.
    messages: list[str] = field(default_factory=list)
69
+
70
+
71
# Type variable for IO-like objects (bounded by ``typing.IO``).
IOType = TypeVar("IOType", bound=IO)


def _default_console():
    """Build the Rich console a ``Printer`` uses when none is given.

    Jupyter rendering is explicitly switched off instead of being auto-detected.
    """
    console = Console(force_jupyter=False)
    return console
76
+
77
+
78
@dataclass
class Printer:
    """An IR pretty printer built on top of a Rich console.

    Output goes through ``console``. Indentation, the result-column width and
    the active rich style are kept in ``state``. They are adjusted with the
    ``align``/``indent``/``rich`` context managers, which restore the previous
    values on exit.
    """

    console: Console = field(default_factory=_default_console)
    """Rich console"""
    analysis: dict["ir.SSAValue", Any] | None = None
    """Analysis results"""
    hint: str | None = None
    """key of the SSAValue hint to print"""
    state: PrintState = field(default_factory=PrintState)
    """Printing state"""
    show_indent_mark: bool = field(default=True, kw_only=True)
    "Whether to show indent marks, e.g │"
    theme: Theme | dict | Literal["dark", "light"] = field(default="dark", kw_only=True)
    "Theme to use for printing"

    def __post_init__(self):
        # Normalise ``theme`` to a ``Theme``: dicts are wrapped, and names are
        # looked up in DEFAULT_THEME. Then make it the console's active theme.
        if isinstance(self.theme, dict):
            self.theme = Theme(self.theme)
        elif isinstance(self.theme, str):
            self.theme = DEFAULT_THEME[self.theme]
        self.console.push_theme(self.theme)

    def print(self, object):
        """Entry point for printing an object.

        ``Printable`` objects print themselves via ``print_impl``. Any other
        object is dispatched to a ``print_<ClassName>`` method of this printer.

        Args:
            object: object to print.

        Raises:
            NotImplementedError: if no ``print_<ClassName>`` method exists.
        """
        if isinstance(object, Printable):
            object.print_impl(self)
            return
        fn = getattr(self, f"print_{object.__class__.__name__}", None)
        if fn is None:
            raise NotImplementedError(
                f"Printer for {object.__class__.__name__} not found"
            )
        fn(object)

    def print_name(
        self, node: Union["ir.Attribute", "ir.Statement"], prefix: str = ""
    ) -> None:
        """Print the qualified name of a node as ``<prefix><dialect>.<name>``.

        Args:
            node(ir.Attribute | ir.Statement): node to print
            prefix(str): prefix to print before the name, default to ""
        """
        self.print_dialect_path(node, prefix=prefix)
        if node.dialect:
            self.plain_print(".")
        self.plain_print(node.name)

    def print_stmt(self, node: "ir.Statement"):
        """Print one statement as ``<results> = <statement>``, plus trailing notes.

        The result column is right-aligned to ``state.result_width``, and the
        statement itself is indented past it. After the statement come the
        analysis results (only when ``analysis`` is set) and the SSA hints
        (only when ``hint`` is set).

        Args:
            node(ir.Statement): statement to print
        """
        if node._results:
            result_str = self.result_str(node._results)
            self.plain_print(result_str.rjust(self.state.result_width), " = ")
        elif self.state.result_width:
            # Pad by the result column plus the 3 characters of " = " so that
            # statements without results line up with those that have them.
            self.plain_print(" " * self.state.result_width, "   ")
        with self.indent(self.state.result_width + 3, mark=True):
            self.print(node)
        with self.rich(style="warning"):
            self.print_analysis(*node._results, prefix=" # ---> ")
        with self.rich(style="comment"):
            self.print_hint(*node._results)

    def print_hint(
        self,
        *values: "ir.SSAValue",
        prefix: str = " // hint<",
        suffix: str = ">",
    ):
        """Print the ``self.hint`` entry of each value's ``hints`` mapping.

        Nothing is printed when ``hint`` is unset or no values are given.
        Values without that hint are shown as "missing".

        Args:
            *values(ir.SSAValue): values whose hint to print
            prefix(str): text printed before the hints
            suffix(str): text printed after the hints
        """
        if not self.hint or not values:
            return

        self.plain_print(prefix)
        self.plain_print(self.hint)
        for idx, item in enumerate(values):
            if idx > 0:
                self.plain_print(", ")

            self.plain_print("=")
            hint_value = item.hints.get(self.hint)
            if hint_value is not None:
                self.plain_print(repr(hint_value))
            else:
                self.plain_print("missing")
        self.plain_print(suffix)

    def print_analysis(
        self,
        *values: "ir.SSAValue",
        prefix: str = "",
    ):
        """Print the analysis result recorded for each value.

        Nothing is printed when ``analysis`` is None or no values are given.
        Values absent from ``analysis`` (or mapped to None) are shown as
        "missing".

        Args:
            *values(ir.SSAValue): values whose analysis result to print
            prefix(str): text printed before the results
        """
        if self.analysis is None or not values:
            return

        self.plain_print(prefix)
        for idx, value in enumerate(values):
            if idx > 0:
                self.plain_print(", ")

            # Compare against None rather than testing truthiness, so that
            # legitimately falsy results (e.g. 0) are still printed.
            result = self.analysis.get(value)
            if result is not None:
                self.plain_print(repr(result))
            else:
                self.plain_print("missing")

    def print_dialect_path(
        self, node: Union["ir.Attribute", "ir.Statement"], prefix: str = ""
    ) -> None:
        """Print ``prefix`` followed by the node's dialect name, if it has one.

        Args:
            node(ir.Attribute | ir.Statement): node to print
            prefix(str): prefix to print before the dialect path, default to ""
        """
        self.plain_print(prefix)
        if node.dialect:  # not None
            self.plain_print(node.dialect.name, style="dialect")

    def print_newline(self):
        """Print a newline character.

        Pending debug messages in ``state.messages`` are then printed, one per
        line, and cleared. Finally the indentation for the next line is printed.
        """
        self.plain_print("\n")

        if self.state.messages:
            for message in self.state.messages:
                self.plain_print(message)
                self.plain_print("\n")
            self.state.messages.clear()
        self.print_indent()

    def print_indent(self):
        """Print the current indentation, optionally with indent marks.

        With ``show_indent_mark`` set and at least one mark recorded, every
        column listed in ``state.indent_marks`` is drawn as "│" in the
        "comment" style, and all other columns are spaces.
        """
        if self.show_indent_mark and self.state.indent_marks:
            marks = set(self.state.indent_marks)
            indent_str = "".join(
                "│" if i in marks else " " for i in range(self.state.indent)
            )
            with self.rich(style="comment"):
                self.plain_print(indent_str)
        else:
            self.plain_print(" " * self.state.indent)

    def plain_print(self, *objects, sep="", end="", style=None, highlight=None):
        """Print objects without any formatting.

        Args:
            *objects: objects to print

        Keyword Args:
            sep(str): separator between objects, default to ""
            end(str): end character, default to ""
            style(str): style to use. Falls back to ``state.rich_style``
                when not given.
            highlight(bool): whether to highlight the text. Falls back to
                ``state.rich_highlight`` only when None; an explicit False
                is honoured.
        """
        self.console.out(
            *objects,
            sep=sep,
            end=end,
            style=style or self.state.rich_style,
            # ``highlight or state`` would replace an explicit False with the
            # state's value, so only fall back when nothing was passed.
            highlight=(
                highlight if highlight is not None else self.state.rich_highlight
            ),
        )

    ElemType = TypeVar("ElemType")

    def print_seq(
        self,
        seq: Iterable[ElemType],
        *,
        emit: Callable[[ElemType], None] | None = None,
        delim: str = ", ",
        prefix: str = "",
        suffix: str = "",
        style=None,
        highlight=None,
    ) -> None:
        """Print a sequence of objects.

        Args:
            seq(Iterable[ElemType]): sequence of objects to print

        Keyword Args:
            emit(Callable[[ElemType], None]): function to print each element,
                defaults to ``self.print``
            delim(str): delimiter between elements, default to ", "
            prefix(str): prefix to print before the sequence, default to ""
            suffix(str): suffix to print after the sequence, default to ""
            style(str): style for prefix, delimiters and suffix, default to None
            highlight(bool): highlighting for prefix, delimiters and suffix,
                default to None
        """
        emit = emit or self.print
        self.plain_print(prefix, style=style, highlight=highlight)
        for idx, item in enumerate(seq):
            if idx > 0:
                self.plain_print(delim, style=style, highlight=highlight)
            emit(item)
        self.plain_print(suffix, style=style, highlight=highlight)

    KeyType = TypeVar("KeyType")
    ValueType = TypeVar("ValueType", bound=Printable)

    def print_mapping(
        self,
        elems: dict[KeyType, ValueType],
        *,
        emit: Callable[[ValueType], None] | None = None,
        delim: str = ", ",
    ) -> None:
        """Print a mapping as ``key=value`` pairs.

        Args:
            elems(dict[KeyType, ValueType]): mapping to print

        Keyword Args:
            emit(Callable[[ValueType], None]): function to print each value,
                defaults to ``self.print``
            delim(str): delimiter between key-value pairs, default to ", "
        """
        emit = emit or self.print
        for i, (key, value) in enumerate(elems.items()):
            if i > 0:
                self.plain_print(delim)
            self.plain_print(f"{key}=")
            emit(value)

    def result_width(self, stmts: Iterable["ir.Statement"]) -> int:
        """Return the maximum width of the result column for a sequence of statements.

        Args:
            stmts(Iterable[ir.Statement]): sequence of statements

        Returns:
            int: maximum width of the result column (0 for no statements)
        """
        return max(
            (len(self.result_str(stmt._results)) for stmt in stmts), default=0
        )

    @contextmanager
    def align(self, width: int) -> Generator[PrintState, Any, None]:
        """Set the result column width, and restore it after the context.

        Args:
            width(int): width of the column

        Yields:
            PrintState: the state with the new column width
        """
        old_width = self.state.result_width
        self.state.result_width = width
        try:
            yield self.state
        finally:
            self.state.result_width = old_width

    @contextmanager
    def indent(
        self, increase: int = 2, mark: bool | None = None
    ) -> Generator[PrintState, Any, None]:
        """Increase the indentation level, and restore it after the context.

        Args:
            increase(int): amount to increase the indentation level, default to 2
            mark(bool): whether to record an indent mark at the new level.
                Defaults to ``show_indent_mark``.

        Yields:
            PrintState: the state with the new indentation level.
        """
        mark = mark if mark is not None else self.show_indent_mark
        self.state.indent += increase
        if mark:
            self.state.indent_marks.append(self.state.indent)
        try:
            yield self.state
        finally:
            self.state.indent -= increase
            if mark:
                self.state.indent_marks.pop()

    @contextmanager
    def rich(
        self, style: str | None = None, highlight: bool = False
    ) -> Generator[PrintState, Any, None]:
        """Set the rich style and highlight, and restore them after the context.

        Args:
            style(str | None): style to use, default to None
            highlight(bool): whether to highlight the text, default to False

        Yields:
            PrintState: the state with the new style and highlight.
        """
        old_style = self.state.rich_style
        old_highlight = self.state.rich_highlight
        self.state.rich_style = style
        self.state.rich_highlight = highlight
        try:
            yield self.state
        finally:
            self.state.rich_style = old_style
            self.state.rich_highlight = old_highlight

    @contextmanager
    def string_io(self) -> Generator[io.StringIO, Any, None]:
        """Temporarily redirect console output into a string buffer.

        The console's previous file is restored, and the buffer closed, on exit.

        Yields:
            io.StringIO: the string IO object.
        """
        stream = io.StringIO()
        old_file = self.console.file
        self.console.file = stream
        try:
            yield stream
        finally:
            self.console.file = old_file
            stream.close()

    def result_str(self, results: list["ir.ResultValue"]) -> str:
        """Return the string representation of a list of result values.

        Args:
            results(list[ir.ResultValue]): list of result values to print
        """
        with self.string_io() as stream:
            self.print_seq(results, delim=", ")
            return stream.getvalue()

    def debug(self, message: str):
        """Queue a debug message; it is printed after the next newline."""
        self.state.messages.append(f"DEBUG: {message}")
kirin/py.typed ADDED
File without changes
kirin/registry.py ADDED
@@ -0,0 +1,105 @@
1
+ from typing import TYPE_CHECKING, Callable, Iterable
2
+ from dataclasses import dataclass
3
+
4
+ if TYPE_CHECKING:
5
+ from kirin.ir import Attribute
6
+ from kirin.ir.group import DialectGroup
7
+ from kirin.ir.nodes import Statement
8
+ from kirin.lowering import FromPythonAST
9
+ from kirin.interp.impl import Signature
10
+ from kirin.interp.table import MethodTable
11
+
12
+
13
@dataclass
class StatementImpl:
    """A statement implementation function paired with the method table owning it."""

    parent: "MethodTable"
    impl: Callable

    def __call__(self, interp, frame, stmt: "Statement"):
        # ``impl`` is stored unbound, so the owning table is passed explicitly
        # as its first argument.
        fn = self.impl
        return fn(self.parent, interp, frame, stmt)

    def __repr__(self) -> str:
        name = self.impl.__name__
        owner = self.parent.__class__
        return f"method impl `{name}` in {owner!r}"
23
+
24
+
25
@dataclass
class AttributeImpl:
    """An attribute implementation function paired with the method table owning it."""

    parent: "MethodTable"
    impl: Callable

    def __call__(self, interp, attr: "Attribute"):
        # the owning table is supplied as the unbound function's first argument
        fn = self.impl
        return fn(self.parent, interp, attr)

    def __repr__(self) -> str:
        name = self.impl.__name__
        owner = self.parent.__class__
        return f"attribute impl `{name}` in {owner!r}"
35
+
36
+
37
@dataclass
class InterpreterRegistry:
    """Lookup tables returned by ``Registry.interpreter``."""

    # attribute type -> implementation bound to its method table
    attributes: dict[type["Attribute"], "AttributeImpl"]
    # statement signature -> implementation bound to its method table
    statements: dict["Signature", "StatementImpl"]
41
+
42
+
43
@dataclass
class Registry:
    """Proxy class to build different registries from a dialect group."""

    dialects: "DialectGroup"
    """The dialect group to build the registry from."""

    def ast(self, keys: Iterable[str]) -> dict[str, "FromPythonAST"]:
        """Select the dialect lowering interpreters for the given keys.

        For each dialect in the group, the first key (in the given order) found
        in ``dialect.lowering`` selects that dialect's lowering. Every name the
        lowering declares in ``names`` is then mapped to it.

        Args:
            keys (Iterable[str]): the keys to search for in the dialects, in
                priority order. May be any iterable, including a generator.

        Returns:
            a map of lowering names to their lowering interpreters

        Raises:
            KeyError: if a dialect has no lowering for any of ``keys``, or if two
                dialects declare the same lowering name.
        """
        # Materialise once: ``keys`` is scanned for every dialect, so a one-shot
        # iterator would otherwise be exhausted after the first dialect.
        keys = tuple(keys)
        ret: dict[str, "FromPythonAST"] = {}
        for dialect in self.dialects.data:
            # Reset per dialect so that a dialect without a matching lowering is
            # reported instead of silently reusing the previous dialect's lowering.
            from_ast = None
            for key in keys:
                if key in dialect.lowering:
                    from_ast = dialect.lowering[key]
                    break

            if from_ast is None:
                msg = ",".join(keys)
                raise KeyError(f"Lowering not found for {msg}")

            for name in from_ast.names:
                if name in ret:
                    raise KeyError(f"Lowering {name} already exists")
                ret[name] = from_ast
        return ret

    def interpreter(self, keys: Iterable[str]) -> "InterpreterRegistry":
        """Select the dialect interpreters for the given keys.

        Every key present in a dialect's ``interps`` contributes its method
        table. When several tables provide the same attribute type or statement
        signature, the first one encountered wins: dialects in group order, and
        keys in the given order within a dialect.

        Args:
            keys (Iterable[str]): the keys to search for in the dialects, in
                priority order. May be any iterable, including a generator.

        Returns:
            an ``InterpreterRegistry`` holding two maps: attribute types to
            ``AttributeImpl`` and statement signatures to ``StatementImpl``.
        """
        # Materialise once: ``keys`` is re-scanned for every dialect.
        keys = tuple(keys)
        attributes: dict[type["Attribute"], "AttributeImpl"] = {}
        table: dict["Signature", "StatementImpl"] = {}
        for dialect in self.dialects.data:
            for key in keys:
                if key not in dialect.interps:
                    continue

                dialect_table = dialect.interps[key]
                for sig, func in dialect_table.attribute.items():
                    if sig not in attributes:
                        attributes[sig] = AttributeImpl(dialect_table, func)

                for sig, func in dialect_table.table.items():
                    if sig not in table:
                        table[sig] = StatementImpl(dialect_table, func)

        return InterpreterRegistry(attributes, table)
kirin/registry.pyi ADDED
@@ -0,0 +1,52 @@
1
+ from typing import Generic, TypeVar, Callable, Iterable, TypeAlias
2
+ from dataclasses import dataclass
3
+
4
+ from kirin.ir.group import DialectGroup
5
+ from kirin.ir.nodes import Statement
6
+ from kirin.lowering import FromPythonAST
7
+ from kirin.interp.base import FrameABC, BaseInterpreter
8
+ from kirin.interp.impl import Signature
9
+ from kirin.interp.table import MethodTable
10
+ from kirin.interp.value import StatementResult
11
+ from kirin.ir.attrs.abc import Attribute
12
+
13
# Type parameters of the statement-implementation signature below.
MethodTableSelf = TypeVar("MethodTableSelf", bound="MethodTable")
InterpreterType = TypeVar("InterpreterType", bound="BaseInterpreter")
FrameType = TypeVar("FrameType", bound="FrameABC")
StatementType = TypeVar("StatementType", bound="Statement")

# Signature of an unbound statement implementation:
# (method table, interpreter, frame, statement) -> StatementResult.
MethodFunction: TypeAlias = Callable[
    [MethodTableSelf, InterpreterType, FrameType, StatementType],
    StatementResult,
]

@dataclass
class StatementImpl(Generic[InterpreterType, FrameType]):
    """A statement implementation paired with the method table owning it."""

    parent: "MethodTable"
    impl: MethodFunction["MethodTable", InterpreterType, FrameType, "Statement"]

    def __call__(
        self,
        interp: InterpreterType,
        frame: FrameType,
        stmt: "Statement",
    ) -> StatementResult: ...
    def __repr__(self) -> str: ...
30
+
31
@dataclass
class AttributeImpl:
    """An attribute implementation paired with the method table owning it."""

    parent: "MethodTable"
    impl: Callable

    def __call__(
        self,
        interp,
        attr: "Attribute",
    ): ...
    def __repr__(self) -> str: ...
38
+
39
@dataclass
class InterpreterRegistry:
    """Lookup tables returned by ``Registry.interpreter``."""

    # attribute type -> implementation
    attributes: dict[type["Attribute"], "AttributeImpl"]
    # statement signature -> implementation
    statements: dict["Signature", "StatementImpl"]
43
+
44
@dataclass
class Registry:
    """Proxy class to build different registries from a dialect group."""

    dialects: "DialectGroup"
    """The dialect group to build the registry from."""

    def ast(self, keys: Iterable[str]) -> dict[str, "FromPythonAST"]:
        """Map lowering names to the lowering selected by ``keys``."""
        ...
    def interpreter(self, keys: Iterable[str]) -> InterpreterRegistry:
        """Collect attribute and statement implementations selected by ``keys``."""
        ...
@@ -0,0 +1,14 @@
1
+ from .cse import CommonSubexpressionElimination as CommonSubexpressionElimination
2
+ from .dce import DeadCodeElimination as DeadCodeElimination
3
+ from .fold import ConstantFold as ConstantFold
4
+ from .walk import Walk as Walk
5
+ from .alias import InlineAlias as InlineAlias
6
+ from .chain import Chain as Chain
7
+ from .inline import Inline as Inline
8
+ from .getitem import InlineGetItem as InlineGetItem
9
+ from .fixpoint import Fixpoint as Fixpoint
10
+ from .getfield import InlineGetField as InlineGetField
11
+ from .apply_type import ApplyType as ApplyType
12
+ from .compactify import CFGCompactify as CFGCompactify
13
+ from .wrap_const import WrapConst as WrapConst
14
+ from .call2invoke import Call2Invoke as Call2Invoke
kirin/rewrite/abc.py ADDED
@@ -0,0 +1,43 @@
1
+ from abc import ABC
2
+ from dataclasses import dataclass
3
+
4
+ from kirin.ir import Pure, Block, IRNode, Region, MaybePure, Statement
5
+ from kirin.rewrite.result import RewriteResult
6
+
7
+
8
@dataclass(repr=False)
class RewriteRule(ABC):
    """Base class for rules that match and rewrite IR nodes in place.

    A rule is applied through ``rewrite(node)``. It must mutate the node rather
    than build a replacement. The returned ``RewriteResult`` reports whether the
    rule did something, whether rewriting should terminate, and whether the
    maximum number of iterations was exceeded.
    """

    def rewrite(self, node: IRNode) -> RewriteResult:
        """Route ``node`` to the handler matching its kind.

        Regions, blocks and statements go to ``rewrite_Region``,
        ``rewrite_Block`` and ``rewrite_Statement`` (checked in that order).
        Any other node yields an empty ``RewriteResult``.
        """
        for node_kind, handler_name in (
            (Region, "rewrite_Region"),
            (Block, "rewrite_Block"),
            (Statement, "rewrite_Statement"),
        ):
            if isinstance(node, node_kind):
                return getattr(self, handler_name)(node)
        return RewriteResult()

    def rewrite_Region(self, node: Region) -> RewriteResult:
        """Rewrite a region; the default does nothing."""
        return RewriteResult()

    def rewrite_Block(self, node: Block) -> RewriteResult:
        """Rewrite a block; the default does nothing."""
        return RewriteResult()

    def rewrite_Statement(self, node: Statement) -> RewriteResult:
        """Rewrite a statement; the default does nothing."""
        return RewriteResult()

    def is_pure(self, node: Statement):
        """Return True if ``node`` is pure.

        A statement is pure when it has the ``Pure`` trait, or when its
        ``MaybePure`` trait reports it as pure.
        """
        if node.has_trait(Pure):
            return True
        maybe_pure = node.get_trait(MaybePure)
        return bool(maybe_pure and maybe_pure.is_pure(node))
@@ -0,0 +1 @@
1
+ from .fold import Fold as Fold
@@ -0,0 +1,43 @@
1
+ from dataclasses import dataclass
2
+
3
+ from kirin.rewrite import Walk, Chain, Fixpoint
4
+ from kirin.analysis import const
5
+ from kirin.rewrite.abc import RewriteRule, RewriteResult
6
+ from kirin.rewrite.dce import DeadCodeElimination
7
+ from kirin.rewrite.fold import ConstantFold
8
+ from kirin.ir.nodes.base import IRNode
9
+ from kirin.rewrite.inline import Inline
10
+ from kirin.rewrite.getitem import InlineGetItem
11
+ from kirin.rewrite.getfield import InlineGetField
12
+ from kirin.rewrite.compactify import CFGCompactify
13
+ from kirin.rewrite.wrap_const import WrapConst
14
+ from kirin.rewrite.call2invoke import Call2Invoke
15
+
16
+
17
@dataclass
class Fold(RewriteRule):
    """Aggressive folding pipeline, repeated until it reaches a fixpoint.

    Each round does the following:
    - wraps constants from ``frame``;
    - inlines every candidate;
    - constant-folds;
    - turns calls into invokes;
    - runs a getitem/getfield/dead-code cleanup to its own fixpoint;
    - compacts the CFG.
    """

    rule: RewriteRule

    def __init__(self, frame: const.Frame):
        folding_steps = (
            Walk(WrapConst(frame)),
            # the predicate always returns True, so every candidate is inlined
            Walk(Inline(lambda _: True)),
            Walk(ConstantFold()),
            Walk(Call2Invoke()),
        )
        cleanup = Fixpoint(
            Walk(Chain(InlineGetItem(), InlineGetField(), DeadCodeElimination()))
        )
        compaction = Walk(CFGCompactify())
        self.rule = Fixpoint(Chain(*folding_steps, cleanup, compaction))

    def rewrite(self, node: IRNode) -> RewriteResult:
        """Apply the whole pipeline to ``node``."""
        return self.rule.rewrite(node)
kirin/rewrite/alias.py ADDED
@@ -0,0 +1,16 @@
1
+ from dataclasses import dataclass
2
+
3
+ from kirin import ir
4
+ from kirin.rewrite.abc import RewriteRule, RewriteResult
5
+ from kirin.dialects.py.assign import Alias
6
+
7
+
8
@dataclass
class InlineAlias(RewriteRule):
    """Replace every use of an ``Alias`` statement's result with its value.

    The ``Alias`` statement itself is left in place; this rule does not delete it.
    """

    def rewrite_Statement(self, node: ir.Statement) -> RewriteResult:
        if isinstance(node, Alias):
            node.result.replace_by(node.value)
            return RewriteResult(has_done_something=True)
        return RewriteResult()