kirin-toolchain 0.13.0__py3-none-any.whl

Sign up to get free protection for your applications and to get access to all the features.
Files changed (225) hide show
  1. kirin/__init__.py +7 -0
  2. kirin/analysis/__init__.py +24 -0
  3. kirin/analysis/callgraph.py +61 -0
  4. kirin/analysis/cfg.py +112 -0
  5. kirin/analysis/const/__init__.py +20 -0
  6. kirin/analysis/const/_visitor.py +2 -0
  7. kirin/analysis/const/_visitor.pyi +8 -0
  8. kirin/analysis/const/lattice.py +219 -0
  9. kirin/analysis/const/prop.py +116 -0
  10. kirin/analysis/forward.py +100 -0
  11. kirin/analysis/typeinfer/__init__.py +5 -0
  12. kirin/analysis/typeinfer/analysis.py +90 -0
  13. kirin/analysis/typeinfer/solve.py +141 -0
  14. kirin/decl/__init__.py +108 -0
  15. kirin/decl/base.py +65 -0
  16. kirin/decl/camel2snake.py +2 -0
  17. kirin/decl/emit/__init__.py +0 -0
  18. kirin/decl/emit/_create_fn.py +29 -0
  19. kirin/decl/emit/_set_new_attribute.py +22 -0
  20. kirin/decl/emit/dialect.py +8 -0
  21. kirin/decl/emit/init.py +277 -0
  22. kirin/decl/emit/name.py +10 -0
  23. kirin/decl/emit/property.py +182 -0
  24. kirin/decl/emit/repr.py +31 -0
  25. kirin/decl/emit/traits.py +13 -0
  26. kirin/decl/emit/typecheck.py +77 -0
  27. kirin/decl/emit/verify.py +51 -0
  28. kirin/decl/info.py +346 -0
  29. kirin/decl/scan_fields.py +157 -0
  30. kirin/decl/verify.py +69 -0
  31. kirin/dialects/__init__.py +14 -0
  32. kirin/dialects/_pprint_helper.py +53 -0
  33. kirin/dialects/cf/__init__.py +20 -0
  34. kirin/dialects/cf/constprop.py +51 -0
  35. kirin/dialects/cf/dialect.py +3 -0
  36. kirin/dialects/cf/emit.py +58 -0
  37. kirin/dialects/cf/interp.py +24 -0
  38. kirin/dialects/cf/stmts.py +68 -0
  39. kirin/dialects/cf/typeinfer.py +27 -0
  40. kirin/dialects/eltype.py +23 -0
  41. kirin/dialects/func/__init__.py +20 -0
  42. kirin/dialects/func/attrs.py +39 -0
  43. kirin/dialects/func/constprop.py +138 -0
  44. kirin/dialects/func/dialect.py +3 -0
  45. kirin/dialects/func/emit.py +80 -0
  46. kirin/dialects/func/interp.py +68 -0
  47. kirin/dialects/func/stmts.py +233 -0
  48. kirin/dialects/func/typeinfer.py +124 -0
  49. kirin/dialects/ilist/__init__.py +33 -0
  50. kirin/dialects/ilist/_dialect.py +3 -0
  51. kirin/dialects/ilist/_wrapper.py +51 -0
  52. kirin/dialects/ilist/interp.py +85 -0
  53. kirin/dialects/ilist/lowering.py +25 -0
  54. kirin/dialects/ilist/passes.py +32 -0
  55. kirin/dialects/ilist/rewrite/__init__.py +3 -0
  56. kirin/dialects/ilist/rewrite/const.py +45 -0
  57. kirin/dialects/ilist/rewrite/list.py +38 -0
  58. kirin/dialects/ilist/rewrite/unroll.py +131 -0
  59. kirin/dialects/ilist/runtime.py +63 -0
  60. kirin/dialects/ilist/stmts.py +102 -0
  61. kirin/dialects/ilist/typeinfer.py +120 -0
  62. kirin/dialects/lowering/__init__.py +7 -0
  63. kirin/dialects/lowering/call.py +48 -0
  64. kirin/dialects/lowering/cf.py +206 -0
  65. kirin/dialects/lowering/func.py +134 -0
  66. kirin/dialects/math/__init__.py +41 -0
  67. kirin/dialects/math/_gen.py +176 -0
  68. kirin/dialects/math/dialect.py +3 -0
  69. kirin/dialects/math/interp.py +190 -0
  70. kirin/dialects/math/stmts.py +369 -0
  71. kirin/dialects/module.py +139 -0
  72. kirin/dialects/py/__init__.py +40 -0
  73. kirin/dialects/py/assertion.py +91 -0
  74. kirin/dialects/py/assign.py +103 -0
  75. kirin/dialects/py/attr.py +59 -0
  76. kirin/dialects/py/base.py +34 -0
  77. kirin/dialects/py/binop/__init__.py +23 -0
  78. kirin/dialects/py/binop/_dialect.py +3 -0
  79. kirin/dialects/py/binop/interp.py +60 -0
  80. kirin/dialects/py/binop/julia.py +33 -0
  81. kirin/dialects/py/binop/lowering.py +22 -0
  82. kirin/dialects/py/binop/stmts.py +79 -0
  83. kirin/dialects/py/binop/typeinfer.py +108 -0
  84. kirin/dialects/py/boolop.py +84 -0
  85. kirin/dialects/py/builtin.py +78 -0
  86. kirin/dialects/py/cmp/__init__.py +16 -0
  87. kirin/dialects/py/cmp/_dialect.py +3 -0
  88. kirin/dialects/py/cmp/interp.py +48 -0
  89. kirin/dialects/py/cmp/julia.py +33 -0
  90. kirin/dialects/py/cmp/lowering.py +45 -0
  91. kirin/dialects/py/cmp/stmts.py +62 -0
  92. kirin/dialects/py/constant.py +79 -0
  93. kirin/dialects/py/indexing.py +251 -0
  94. kirin/dialects/py/iterable.py +90 -0
  95. kirin/dialects/py/len.py +57 -0
  96. kirin/dialects/py/list/__init__.py +15 -0
  97. kirin/dialects/py/list/_dialect.py +3 -0
  98. kirin/dialects/py/list/interp.py +21 -0
  99. kirin/dialects/py/list/lowering.py +25 -0
  100. kirin/dialects/py/list/stmts.py +22 -0
  101. kirin/dialects/py/list/typeinfer.py +54 -0
  102. kirin/dialects/py/range.py +76 -0
  103. kirin/dialects/py/slice.py +120 -0
  104. kirin/dialects/py/tuple.py +109 -0
  105. kirin/dialects/py/unary/__init__.py +24 -0
  106. kirin/dialects/py/unary/_dialect.py +3 -0
  107. kirin/dialects/py/unary/constprop.py +20 -0
  108. kirin/dialects/py/unary/interp.py +24 -0
  109. kirin/dialects/py/unary/julia.py +21 -0
  110. kirin/dialects/py/unary/lowering.py +22 -0
  111. kirin/dialects/py/unary/stmts.py +33 -0
  112. kirin/dialects/py/unary/typeinfer.py +23 -0
  113. kirin/dialects/py/unpack.py +90 -0
  114. kirin/dialects/scf/__init__.py +23 -0
  115. kirin/dialects/scf/_dialect.py +3 -0
  116. kirin/dialects/scf/absint.py +64 -0
  117. kirin/dialects/scf/constprop.py +140 -0
  118. kirin/dialects/scf/interp.py +35 -0
  119. kirin/dialects/scf/lowering.py +123 -0
  120. kirin/dialects/scf/stmts.py +250 -0
  121. kirin/dialects/scf/trim.py +36 -0
  122. kirin/dialects/scf/typeinfer.py +58 -0
  123. kirin/dialects/scf/unroll.py +92 -0
  124. kirin/emit/__init__.py +3 -0
  125. kirin/emit/abc.py +89 -0
  126. kirin/emit/abc.pyi +38 -0
  127. kirin/emit/exceptions.py +5 -0
  128. kirin/emit/julia.py +63 -0
  129. kirin/emit/str.py +51 -0
  130. kirin/exceptions.py +59 -0
  131. kirin/graph.py +34 -0
  132. kirin/idtable.py +57 -0
  133. kirin/interp/__init__.py +39 -0
  134. kirin/interp/abstract.py +253 -0
  135. kirin/interp/base.py +438 -0
  136. kirin/interp/concrete.py +62 -0
  137. kirin/interp/exceptions.py +26 -0
  138. kirin/interp/frame.py +151 -0
  139. kirin/interp/impl.py +197 -0
  140. kirin/interp/result.py +93 -0
  141. kirin/interp/state.py +71 -0
  142. kirin/interp/table.py +40 -0
  143. kirin/interp/value.py +73 -0
  144. kirin/ir/__init__.py +46 -0
  145. kirin/ir/attrs/__init__.py +20 -0
  146. kirin/ir/attrs/_types.py +8 -0
  147. kirin/ir/attrs/_types.pyi +13 -0
  148. kirin/ir/attrs/abc.py +46 -0
  149. kirin/ir/attrs/py.py +45 -0
  150. kirin/ir/attrs/types.py +522 -0
  151. kirin/ir/dialect.py +125 -0
  152. kirin/ir/group.py +249 -0
  153. kirin/ir/method.py +118 -0
  154. kirin/ir/nodes/__init__.py +7 -0
  155. kirin/ir/nodes/base.py +149 -0
  156. kirin/ir/nodes/block.py +458 -0
  157. kirin/ir/nodes/region.py +337 -0
  158. kirin/ir/nodes/stmt.py +713 -0
  159. kirin/ir/nodes/view.py +142 -0
  160. kirin/ir/ssa.py +204 -0
  161. kirin/ir/traits/__init__.py +36 -0
  162. kirin/ir/traits/abc.py +42 -0
  163. kirin/ir/traits/basic.py +78 -0
  164. kirin/ir/traits/callable.py +51 -0
  165. kirin/ir/traits/lowering/__init__.py +2 -0
  166. kirin/ir/traits/lowering/call.py +37 -0
  167. kirin/ir/traits/lowering/context.py +120 -0
  168. kirin/ir/traits/region/__init__.py +2 -0
  169. kirin/ir/traits/region/ssacfg.py +22 -0
  170. kirin/ir/traits/symbol.py +57 -0
  171. kirin/ir/use.py +17 -0
  172. kirin/lattice/__init__.py +13 -0
  173. kirin/lattice/abc.py +128 -0
  174. kirin/lattice/empty.py +25 -0
  175. kirin/lattice/mixin.py +51 -0
  176. kirin/lowering/__init__.py +7 -0
  177. kirin/lowering/binding.py +65 -0
  178. kirin/lowering/core.py +72 -0
  179. kirin/lowering/dialect.py +35 -0
  180. kirin/lowering/dialect.pyi +183 -0
  181. kirin/lowering/frame.py +171 -0
  182. kirin/lowering/result.py +68 -0
  183. kirin/lowering/state.py +441 -0
  184. kirin/lowering/stream.py +53 -0
  185. kirin/passes/__init__.py +3 -0
  186. kirin/passes/abc.py +44 -0
  187. kirin/passes/aggressive/__init__.py +1 -0
  188. kirin/passes/aggressive/fold.py +43 -0
  189. kirin/passes/fold.py +45 -0
  190. kirin/passes/inline.py +25 -0
  191. kirin/passes/typeinfer.py +25 -0
  192. kirin/prelude.py +197 -0
  193. kirin/print/__init__.py +15 -0
  194. kirin/print/printable.py +141 -0
  195. kirin/print/printer.py +415 -0
  196. kirin/py.typed +0 -0
  197. kirin/registry.py +105 -0
  198. kirin/registry.pyi +52 -0
  199. kirin/rewrite/__init__.py +14 -0
  200. kirin/rewrite/abc.py +43 -0
  201. kirin/rewrite/aggressive/__init__.py +1 -0
  202. kirin/rewrite/aggressive/fold.py +43 -0
  203. kirin/rewrite/alias.py +16 -0
  204. kirin/rewrite/apply_type.py +47 -0
  205. kirin/rewrite/call2invoke.py +34 -0
  206. kirin/rewrite/chain.py +39 -0
  207. kirin/rewrite/compactify.py +288 -0
  208. kirin/rewrite/cse.py +48 -0
  209. kirin/rewrite/dce.py +19 -0
  210. kirin/rewrite/fixpoint.py +34 -0
  211. kirin/rewrite/fold.py +57 -0
  212. kirin/rewrite/getfield.py +21 -0
  213. kirin/rewrite/getitem.py +37 -0
  214. kirin/rewrite/inline.py +143 -0
  215. kirin/rewrite/result.py +15 -0
  216. kirin/rewrite/walk.py +83 -0
  217. kirin/rewrite/wrap_const.py +55 -0
  218. kirin/source.py +21 -0
  219. kirin/symbol_table.py +27 -0
  220. kirin/types.py +34 -0
  221. kirin/worklist.py +30 -0
  222. kirin_toolchain-0.13.0.dist-info/METADATA +42 -0
  223. kirin_toolchain-0.13.0.dist-info/RECORD +225 -0
  224. kirin_toolchain-0.13.0.dist-info/WHEEL +4 -0
  225. kirin_toolchain-0.13.0.dist-info/licenses/LICENSE +234 -0
@@ -0,0 +1,441 @@
1
+ import ast
2
+ import inspect
3
+ import builtins
4
+ from typing import TYPE_CHECKING, Any, TypeVar, get_origin
5
+ from dataclasses import dataclass
6
+
7
+ from kirin.ir import Method, SSAValue, Statement, DialectGroup, traits
8
+ from kirin.source import SourceInfo
9
+ from kirin.exceptions import DialectLoweringError
10
+ from kirin.lowering.frame import Frame
11
+ from kirin.lowering.result import Result
12
+ from kirin.lowering.binding import Binding
13
+ from kirin.lowering.dialect import FromPythonAST
14
+
15
+ if TYPE_CHECKING:
16
+ from kirin.lowering.core import Lowering
17
+
18
+
19
+ @dataclass
20
+ class LoweringState(ast.NodeVisitor):
21
+ # from parent
22
+ dialects: DialectGroup
23
+ registry: dict[str, FromPythonAST]
24
+
25
+ # debug info
26
+ lines: list[str]
27
+ lineno_offset: int
28
+ "lineno offset at the beginning of the source"
29
+ col_offset: int
30
+ "column offset at the beginning of the source"
31
+ source: SourceInfo
32
+ "source info of the current node"
33
+ # line_range: tuple[int, int] # current (<start>, <end>)
34
+ # col_range: tuple[int, int] # current (<start>, <end>)
35
+ max_lines: int = 3
36
+ _current_frame: Frame | None = None
37
+
38
+ @classmethod
39
+ def from_stmt(
40
+ cls,
41
+ lowering: "Lowering",
42
+ stmt: ast.stmt,
43
+ source: str | None = None,
44
+ globals: dict[str, Any] | None = None,
45
+ max_lines: int = 3,
46
+ lineno_offset: int = 0,
47
+ col_offset: int = 0,
48
+ ):
49
+ if not isinstance(stmt, ast.stmt):
50
+ raise ValueError(f"Expected ast.stmt, got {type(stmt)}")
51
+
52
+ if not source:
53
+ source = ast.unparse(stmt)
54
+
55
+ state = cls(
56
+ dialects=lowering.dialects,
57
+ registry=lowering.registry,
58
+ lines=source.splitlines(),
59
+ lineno_offset=lineno_offset,
60
+ col_offset=col_offset,
61
+ source=SourceInfo.from_ast(stmt, lineno_offset, col_offset),
62
+ max_lines=max_lines,
63
+ )
64
+
65
+ frame = Frame.from_stmts([stmt], state, globals=globals)
66
+ state.push_frame(frame)
67
+ return state
68
+
69
+ @property
70
+ def current_frame(self):
71
+ if self._current_frame is None:
72
+ raise ValueError("No frame")
73
+ return self._current_frame
74
+
75
+ @property
76
+ def code(self):
77
+ stmt = self.current_frame.curr_region.blocks[0].first_stmt
78
+ if stmt:
79
+ return stmt
80
+ raise ValueError("No code generated")
81
+
82
+ StmtType = TypeVar("StmtType", bound=Statement)
83
+
84
+ def append_stmt(self, stmt: StmtType) -> StmtType:
85
+ """Shorthand for appending a statement to the current block of current frame."""
86
+ return self.current_frame.append_stmt(stmt)
87
+
88
+ def push_frame(self, frame: Frame):
89
+ frame.parent = self._current_frame
90
+ self._current_frame = frame
91
+ return frame
92
+
93
+ def pop_frame(self, finalize_next: bool = True):
94
+ """Pop the current frame and return it.
95
+
96
+ Args:
97
+ finalize_next(bool): If True, append the next block of the current frame.
98
+
99
+ Returns:
100
+ Frame: The popped frame.
101
+ """
102
+ if self._current_frame is None:
103
+ raise ValueError("No frame to pop")
104
+ frame = self._current_frame
105
+
106
+ if finalize_next and frame.next_block.parent is None:
107
+ frame.append_block(frame.next_block)
108
+ self._current_frame = frame.parent
109
+ return frame
110
+
111
+ def update_lineno(self, node):
112
+ self.source = SourceInfo.from_ast(node, self.lineno_offset, self.col_offset)
113
+
114
+ def __repr__(self) -> str:
115
+ return f"LoweringState({self.current_frame})"
116
+
117
+ def visit(self, node: ast.AST) -> Result:
118
+ self.update_lineno(node)
119
+ name = node.__class__.__name__
120
+ if name in self.registry:
121
+ return self.registry[name].lower(self, node)
122
+ elif isinstance(node, ast.Call):
123
+ # NOTE: if lower_Call is implemented,
124
+ # it will be called first before __dispatch_Call
125
+ # because "Call" exists in self.registry
126
+ return self.__dispatch_Call(node)
127
+ elif isinstance(node, ast.With):
128
+ return self.__dispatch_With(node)
129
+ return super().visit(node)
130
+
131
+ def generic_visit(self, node: ast.AST):
132
+ raise DialectLoweringError(f"unsupported ast node {type(node)}:")
133
+
134
+ def __dispatch_With(self, node: ast.With):
135
+ if len(node.items) != 1:
136
+ raise DialectLoweringError("expected exactly one item in with statement")
137
+
138
+ item = node.items[0]
139
+ if not isinstance(item.context_expr, ast.Call):
140
+ raise DialectLoweringError("expected context expression to be a call")
141
+
142
+ global_callee_result = self.get_global_nothrow(item.context_expr.func)
143
+ if global_callee_result is None:
144
+ raise DialectLoweringError("cannot find call func in with context")
145
+
146
+ global_callee = global_callee_result.unwrap()
147
+ if not issubclass(global_callee, Statement):
148
+ raise DialectLoweringError("expected callee to be a statement")
149
+
150
+ if (
151
+ trait := global_callee.get_trait(traits.FromPythonWithSingleItem)
152
+ ) is not None:
153
+ return trait.lower(global_callee, self, node)
154
+
155
+ raise DialectLoweringError(
156
+ "unsupported callee, missing FromPythonWithSingleItem trait"
157
+ )
158
+
159
+ def __dispatch_Call(self, node: ast.Call):
160
+ # 1. try to lookup global statement object
161
+ # 2. lookup local values
162
+ global_callee_result = self.get_global_nothrow(node.func)
163
+ if global_callee_result is None: # not found in globals, has to be local
164
+ return self.__lower_Call_local(node)
165
+
166
+ global_callee = global_callee_result.unwrap()
167
+ if isinstance(global_callee, Binding):
168
+ global_callee = global_callee.parent
169
+
170
+ if isinstance(global_callee, Method):
171
+ if "Call_global_method" in self.registry:
172
+ return self.registry["Call_global_method"].lower_Call_global_method(
173
+ self, global_callee, node
174
+ )
175
+ raise DialectLoweringError("`lower_Call_global_method` not implemented")
176
+ elif inspect.isclass(global_callee):
177
+ if issubclass(global_callee, Statement):
178
+ if global_callee.dialect is None:
179
+ raise DialectLoweringError(
180
+ f"unsupported dialect `None` for {global_callee.name}"
181
+ )
182
+
183
+ if global_callee.dialect not in self.dialects.data:
184
+ raise DialectLoweringError(
185
+ f"unsupported dialect `{global_callee.dialect.name}`"
186
+ )
187
+
188
+ if (
189
+ trait := global_callee.get_trait(traits.FromPythonCall)
190
+ ) is not None:
191
+ return trait.lower(global_callee, self, node)
192
+
193
+ raise DialectLoweringError(
194
+ f"unsupported callee {global_callee.__name__}, "
195
+ "missing FromPythonAST lowering, or traits.FromPythonCall trait"
196
+ )
197
+ elif issubclass(global_callee, slice):
198
+ if "Call_slice" in self.registry:
199
+ return self.registry["Call_slice"].lower_Call_slice(self, node)
200
+ raise DialectLoweringError("`lower_Call_slice` not implemented")
201
+ elif issubclass(global_callee, range):
202
+ if "Call_range" in self.registry:
203
+ return self.registry["Call_range"].lower_Call_range(self, node)
204
+ raise DialectLoweringError("`lower_Call_range` not implemented")
205
+ elif inspect.isbuiltin(global_callee):
206
+ name = f"Call_{global_callee.__name__}"
207
+ if "Call_builtins" in self.registry:
208
+ dialect_lowering = self.registry["Call_builtins"]
209
+ return dialect_lowering.lower_Call_builtins(self, node)
210
+ elif name in self.registry:
211
+ dialect_lowering = self.registry[name]
212
+ return getattr(dialect_lowering, f"lower_{name}")(self, node)
213
+ else:
214
+ raise DialectLoweringError(
215
+ f"`lower_{name}` is not implemented for builtin function `{global_callee.__name__}`."
216
+ )
217
+
218
+ # symbol exist in global, but not ir.Statement, it may refer to a
219
+ # local value that shadows the global value
220
+ try:
221
+ return self.__lower_Call_local(node)
222
+ except DialectLoweringError:
223
+ # symbol exist in global, but not ir.Statement, not found in locals either
224
+ # this means the symbol is referring to an external uncallable object
225
+ if inspect.isfunction(global_callee):
226
+ raise DialectLoweringError(
227
+ f"unsupported callee: {repr(global_callee)}."
228
+ "Are you trying to call a python function? This is not supported."
229
+ )
230
+ else: # well not much we can do, can't hint
231
+ raise DialectLoweringError(
232
+ f"unsupported callee type: {repr(global_callee)}"
233
+ )
234
+
235
+ def __lower_Call_local(self, node: ast.Call) -> Result:
236
+ callee = self.visit(node.func).expect_one()
237
+ if "Call_local" in self.registry:
238
+ return self.registry["Call_local"].lower_Call_local(self, callee, node)
239
+ raise DialectLoweringError("`lower_Call_local` not implemented")
240
+
241
+ def default_Call_lower(self, stmt: type[Statement], node: ast.Call) -> Result:
242
+ """Default lowering for Python call to statement.
243
+
244
+ This method is intended to be used by traits like `FromPythonCall` to
245
+ provide a default lowering for Python calls to statements.
246
+
247
+ Args:
248
+ stmt(type[Statement]): Statement class to construct.
249
+ node(ast.Call): Python call node to lower.
250
+
251
+ Returns:
252
+ Result: Result of lowering the Python call to statement.
253
+ """
254
+ args, kwargs = self.default_Call_inputs(stmt, node)
255
+ return Result(self.append_stmt(stmt(*args.values(), **kwargs)))
256
+
257
+ def default_Call_inputs(
258
+ self, stmt: type[Statement], node: ast.Call
259
+ ) -> tuple[dict[str, SSAValue | tuple[SSAValue, ...]], dict[str, Any]]:
260
+ from kirin.decl import fields
261
+
262
+ fs = fields(stmt)
263
+ stmt_std_arg_names = fs.std_args.keys()
264
+ stmt_kw_args_name = fs.kw_args.keys()
265
+ stmt_attr_prop_names = fs.attr_or_props
266
+ stmt_required_names = fs.required_names
267
+ stmt_group_arg_names = fs.group_arg_names
268
+ args, kwargs = {}, {}
269
+ for name, value in zip(stmt_std_arg_names, node.args):
270
+ self._parse_arg(stmt_group_arg_names, args, name, value)
271
+ for kw in node.keywords:
272
+ if not isinstance(kw.arg, str):
273
+ raise DialectLoweringError("Expected string for keyword argument name")
274
+
275
+ arg: str = kw.arg
276
+ if arg in node.args:
277
+ raise DialectLoweringError(
278
+ f"Keyword argument {arg} is already present in positional arguments"
279
+ )
280
+ elif arg in stmt_std_arg_names or arg in stmt_kw_args_name:
281
+ self._parse_arg(stmt_group_arg_names, kwargs, kw.arg, kw.value)
282
+ elif arg in stmt_attr_prop_names:
283
+ if (
284
+ isinstance(kw.value, ast.Name)
285
+ and self.current_frame.get_local(kw.value.id) is not None
286
+ ):
287
+ raise DialectLoweringError(
288
+ f"Expected global/constant value for attribute or property {arg}"
289
+ )
290
+ global_value = self.get_global_nothrow(kw.value)
291
+ if global_value is None:
292
+ raise DialectLoweringError(
293
+ f"Expected global value for attribute or property {arg}"
294
+ )
295
+ if (decl := fs.attributes.get(arg)) is not None:
296
+ if decl.annotation is Any:
297
+ kwargs[arg] = global_value.unwrap()
298
+ else:
299
+ kwargs[arg] = global_value.expect(
300
+ get_origin(decl.annotation) or decl.annotation
301
+ )
302
+ else:
303
+ raise DialectLoweringError(
304
+ f"Unexpected attribute or property {arg}"
305
+ )
306
+ else:
307
+ raise DialectLoweringError(f"Unexpected keyword argument {arg}")
308
+
309
+ for name in stmt_required_names:
310
+ if name not in args and name not in kwargs:
311
+ raise DialectLoweringError(f"Missing required argument {name}")
312
+
313
+ return args, kwargs
314
+
315
+ def _parse_arg(
316
+ self,
317
+ group_names: set[str],
318
+ target: dict,
319
+ name: str,
320
+ value: ast.AST,
321
+ ):
322
+ if name in group_names:
323
+ if not isinstance(value, ast.Tuple):
324
+ raise DialectLoweringError(f"Expected tuple for group argument {name}")
325
+ target[name] = tuple(self.visit(elem).expect_one() for elem in value.elts)
326
+ else:
327
+ target[name] = self.visit(value).expect_one()
328
+
329
+ ValueT = TypeVar("ValueT", bound=SSAValue)
330
+
331
+ def exhaust(self, frame: Frame | None = None) -> Frame:
332
+ """Exhaust given frame's stream. If not given, exhaust current frame's stream."""
333
+ if not frame:
334
+ current_frame = self.current_frame
335
+ else:
336
+ current_frame = frame
337
+
338
+ stream = current_frame.stream
339
+ while stream.has_next():
340
+ stmt = stream.pop()
341
+ self.visit(stmt)
342
+ return current_frame
343
+
344
+ def error_hint(self) -> str:
345
+ begin = max(0, self.source.lineno - self.max_lines)
346
+ end = max(self.source.lineno + self.max_lines, self.source.end_lineno or 0)
347
+ end = min(len(self.lines), end) # make sure end is within bounds
348
+ lines = self.lines[begin:end]
349
+ code_indent = min(map(self.__get_indent, lines), default=0)
350
+ lines.append("") # in case the last line errors
351
+
352
+ snippet_lines = []
353
+ for lineno, line in enumerate(lines, begin):
354
+ if lineno == self.source.lineno:
355
+ snippet_lines.append(("-" * (self.source.col_offset)) + "^")
356
+
357
+ snippet_lines.append(line[code_indent:])
358
+
359
+ return "\n".join(snippet_lines)
360
+
361
+ @staticmethod
362
+ def __get_indent(line: str) -> int:
363
+ if len(line) == 0:
364
+ return int(1e9) # very large number
365
+ return len(line) - len(line.lstrip())
366
+
367
+ @dataclass
368
+ class GlobalRefResult:
369
+ data: Any
370
+
371
+ def unwrap(self):
372
+ return self.data
373
+
374
+ ExpectT = TypeVar("ExpectT")
375
+
376
+ def expect(self, typ: type[ExpectT]) -> ExpectT:
377
+ if not isinstance(self.data, typ):
378
+ raise DialectLoweringError(f"expected {typ}, got {type(self.data)}")
379
+ return self.data
380
+
381
+ def get_global_nothrow(self, node) -> GlobalRefResult | None:
382
+ try:
383
+ return self.get_global(node)
384
+ except DialectLoweringError:
385
+ return None
386
+
387
+ def get_global(self, node) -> GlobalRefResult:
388
+ return getattr(
389
+ self, f"get_global_{node.__class__.__name__}", self.get_global_fallback
390
+ )(node)
391
+
392
+ def get_global_fallback(self, node: ast.AST) -> GlobalRefResult:
393
+ raise DialectLoweringError(
394
+ f"unsupported global access get_global_{node.__class__.__name__}: {ast.unparse(node)}"
395
+ )
396
+
397
+ def get_global_Constant(self, node: ast.Constant) -> GlobalRefResult:
398
+ return self.GlobalRefResult(node.value)
399
+
400
+ def get_global_str(self, node: str) -> GlobalRefResult:
401
+ if node in (globals := self.current_frame.globals):
402
+ return self.GlobalRefResult(globals[node])
403
+
404
+ if hasattr(builtins, node):
405
+ return self.GlobalRefResult(getattr(builtins, node))
406
+
407
+ raise DialectLoweringError(f"global {node} not found")
408
+
409
+ def get_global_Name(self, node: ast.Name) -> GlobalRefResult:
410
+ return self.get_global_str(node.id)
411
+
412
+ def get_global_Attribute(self, node: ast.Attribute) -> GlobalRefResult:
413
+ if not isinstance(node.ctx, ast.Load):
414
+ raise DialectLoweringError("unsupported attribute access")
415
+
416
+ match node.value:
417
+ case ast.Name(id):
418
+ value = self.get_global_str(id).unwrap()
419
+ case ast.Attribute():
420
+ value = self.get_global(node.value).unwrap()
421
+ case _:
422
+ raise DialectLoweringError("unsupported attribute access")
423
+
424
+ if hasattr(value, node.attr):
425
+ return self.GlobalRefResult(getattr(value, node.attr))
426
+
427
+ raise DialectLoweringError(f"attribute {node.attr} not found in {value}")
428
+
429
+ def get_global_Subscript(self, node: ast.Subscript) -> GlobalRefResult:
430
+ value = self.get_global(node.value).unwrap()
431
+ if isinstance(node.slice, ast.Tuple):
432
+ index = tuple(self.get_global(elt).unwrap() for elt in node.slice.elts)
433
+ else:
434
+ index = self.get_global(node.slice).unwrap()
435
+ return self.GlobalRefResult(value[index])
436
+
437
+ def get_global_Call(self, node: ast.Call) -> GlobalRefResult:
438
+ func = self.get_global(node.func).unwrap()
439
+ args = [self.get_global(arg).unwrap() for arg in node.args]
440
+ kwargs = {kw.arg: self.get_global(kw.value).unwrap() for kw in node.keywords}
441
+ return self.GlobalRefResult(func(*args, **kwargs))
@@ -0,0 +1,53 @@
1
+ from typing import Generic, TypeVar, Sequence
2
+ from dataclasses import field, dataclass
3
+
4
Stmt = TypeVar("Stmt")


@dataclass
class StmtStream(Generic[Stmt]):
    """A cursor over a list of statements, consumed front to back.

    ``peek``, ``pop`` and iteration start at ``cursor``. ``len()``, indexing
    and item assignment operate on the whole underlying list, regardless of
    the cursor position.
    """

    stmts: list[Stmt] = field(default_factory=list)
    cursor: int = 0

    def __init__(self, stmts: Sequence[Stmt], cursor: int = 0):
        # Copy into a fresh list so the stream never aliases the caller's sequence.
        self.stmts = list(stmts)
        self.cursor = cursor

    def __iter__(self):
        return self

    def __next__(self):
        if not self.has_next():
            raise StopIteration
        return self.pop()

    def peek(self):
        """Return the statement at the cursor without consuming it.

        Raises:
            IndexError: If the cursor is past the end of the list.
        """
        return self.stmts[self.cursor]

    def has_next(self):
        """Whether at least one statement remains after the cursor."""
        return self.cursor < len(self.stmts)

    def split(self) -> "StmtStream":
        """Hand the remaining statements to a new stream.

        The new stream copies the full list and starts at this stream's
        cursor. This stream is then marked as exhausted.
        """
        cursor = self.cursor
        self.cursor = len(self.stmts)
        return StmtStream(self.stmts, cursor)

    def __len__(self):
        return len(self.stmts)

    def __getitem__(self, key):
        return self.stmts[key]

    def __setitem__(self, key, value):
        self.stmts[key] = value

    def pop(self):
        """Consume and return the statement at the cursor.

        Raises:
            IndexError: If the cursor is past the end of the list.
        """
        stmt = self.stmts[self.cursor]
        self.cursor += 1
        return stmt

    def is_empty(self):
        """Whether no statements remain.

        This is the exact complement of ``has_next``. It also holds when the
        cursor was placed past the end.
        """
        return not self.has_next()
@@ -0,0 +1,3 @@
1
+ from kirin.passes.abc import Pass as Pass
2
+ from kirin.passes.fold import Fold as Fold
3
+ from kirin.passes.typeinfer import TypeInfer as TypeInfer
kirin/passes/abc.py ADDED
@@ -0,0 +1,44 @@
1
+ from abc import ABC, abstractmethod
2
+ from typing import ClassVar
3
+ from dataclasses import dataclass
4
+
5
+ from kirin.ir import Method, DialectGroup
6
+ from kirin.rewrite.abc import RewriteResult
7
+
8
+
9
@dataclass
class Pass(ABC):
    """A pass is a transformation that is applied to a method. It wraps
    the analysis and rewrites needed to transform the method as an independent
    unit.

    Unlike LLVM/MLIR passes, a pass in Kirin does not apply to a module,
    this is because we focus on individual methods defined within
    python modules. This is a design choice to allow seamless integration
    within the Python interpreter.

    A Kirin compile unit is an `ir.Method` object, which is always equivalent
    to an LLVM/MLIR module if it were lowered to LLVM/MLIR just like other JIT
    compilers.
    """

    name: ClassVar[str]
    dialects: DialectGroup

    def __call__(self, mt: Method) -> RewriteResult:
        """Run the pass once on ``mt``, then verify the method's code.

        Returns:
            The result reported by ``unsafe_run``.
        """
        result = self.unsafe_run(mt)
        mt.code.verify()
        return result

    def fixpoint(self, mt: Method, max_iter: int = 32) -> RewriteResult:
        """Run the pass repeatedly until one iteration changes nothing.

        Args:
            mt: Method to transform in place.
            max_iter: Upper bound on the number of iterations.

        Returns:
            The join of every iteration's rewrite result.
        """
        result = RewriteResult()
        for _ in range(max_iter):
            step = self.unsafe_run(mt)
            result = step.join(result)
            # Stop as soon as the latest iteration made no change. Testing the
            # accumulated `result` instead would never stop after the first
            # change, so every call would run `max_iter` iterations.
            if not step.has_done_something:
                break
        mt.code.verify()
        return result

    @abstractmethod
    def unsafe_run(self, mt: Method) -> RewriteResult:
        """Apply the pass to ``mt`` without verifying the result."""
        ...
@@ -0,0 +1 @@
1
+ from .fold import Fold as Fold
@@ -0,0 +1,43 @@
1
+ from dataclasses import field, dataclass
2
+
3
+ from kirin.passes import Pass
4
+ from kirin.rewrite import (
5
+ Walk,
6
+ Chain,
7
+ Inline,
8
+ Fixpoint,
9
+ WrapConst,
10
+ Call2Invoke,
11
+ ConstantFold,
12
+ CFGCompactify,
13
+ InlineGetItem,
14
+ InlineGetField,
15
+ DeadCodeElimination,
16
+ )
17
+ from kirin.analysis import const
18
+ from kirin.ir.method import Method
19
+ from kirin.rewrite.abc import RewriteResult
20
+
21
+
22
@dataclass
class Fold(Pass):
    """Aggressive folding pass.

    One run performs, in order:

    1. constant propagation, followed by ``WrapConst`` with the analysis frame;
    2. a fixpoint of constant folding, call-to-invoke conversion, getfield and
       getitem inlining, and dead-code elimination;
    3. inlining with a heuristic that accepts every candidate;
    4. a fixpoint of CFG compactification.

    The returned result joins the results of all four stages.
    """

    constprop: const.Propagate = field(init=False)

    def __post_init__(self):
        # Not an init argument: derived from `dialects` once per instance.
        self.constprop = const.Propagate(self.dialects)

    def unsafe_run(self, mt: Method) -> RewriteResult:
        frame, _ = self.constprop.run_analysis(mt)
        stages = (
            Walk(WrapConst(frame)),
            Fixpoint(
                Walk(
                    Chain(
                        ConstantFold(),
                        Call2Invoke(),
                        InlineGetField(),
                        InlineGetItem(),
                        DeadCodeElimination(),
                    )
                )
            ),
            Walk(Inline(lambda _: True)),
            Fixpoint(CFGCompactify()),
        )

        accumulated = RewriteResult()
        for stage in stages:
            accumulated = stage.rewrite(mt.code).join(accumulated)
        return accumulated
kirin/passes/fold.py ADDED
@@ -0,0 +1,45 @@
1
+ from dataclasses import dataclass
2
+
3
+ from kirin.ir import Method, SSACFGRegion
4
+ from kirin.rewrite import (
5
+ Walk,
6
+ Chain,
7
+ Fixpoint,
8
+ WrapConst,
9
+ Call2Invoke,
10
+ ConstantFold,
11
+ CFGCompactify,
12
+ InlineGetItem,
13
+ DeadCodeElimination,
14
+ )
15
+ from kirin.analysis import const
16
+ from kirin.passes.abc import Pass
17
+ from kirin.rewrite.abc import RewriteResult
18
+
19
+
20
@dataclass
class Fold(Pass):
    """Constant-folding pass.

    Each run does the following, in order:

    1. Run constant propagation and apply ``WrapConst`` with the resulting frame.
    2. Iterate constant folding, getitem inlining, call-to-invoke conversion
       and dead-code elimination until a fixpoint is reached.
    3. Compactify the CFG, but only when the method body has the
       ``SSACFGRegion`` trait.
    4. Run a final dead-code-elimination fixpoint.
    """

    def unsafe_run(self, mt: Method) -> RewriteResult:
        # A fresh analysis instance is created on every run.
        analysis = const.Propagate(self.dialects)
        frame, _ = analysis.run_analysis(mt)
        result = Walk(WrapConst(frame)).rewrite(mt.code)

        simplify = Walk(
            Chain(
                ConstantFold(),
                InlineGetItem(),
                Call2Invoke(),
                DeadCodeElimination(),
            )
        )
        result = Fixpoint(simplify).rewrite(mt.code).join(result)

        if mt.code.has_trait(SSACFGRegion):
            result = Walk(CFGCompactify()).rewrite(mt.code).join(result)

        cleanup = Fixpoint(Walk(DeadCodeElimination()))
        return cleanup.rewrite(mt.code).join(result)
kirin/passes/inline.py ADDED
@@ -0,0 +1,25 @@
1
+ from typing import Callable
2
+ from dataclasses import field, dataclass
3
+
4
+ from kirin import ir
5
+ from kirin.passes import Pass
6
+ from kirin.rewrite import Walk, Inline, Fixpoint, CFGCompactify, DeadCodeElimination
7
+ from kirin.rewrite.abc import RewriteResult
8
+
9
+
10
def aggresive(x: ir.IRNode) -> bool:
    """Inlining heuristic that accepts every candidate node.

    The argument is ignored; the answer is always ``True``.
    """
    del x  # unused: the decision does not depend on the node
    return True
12
+
13
+
14
@dataclass
class InlinePass(Pass):
    """Inline call sites accepted by a heuristic, then clean up the IR.

    ``herustic`` is passed to ``Inline`` as its ``heuristic``. The default
    accepts every node. After inlining, the CFG is compactified and dead code
    is eliminated until a fixpoint is reached.
    """

    herustic: Callable[[ir.IRNode], bool] = field(default=aggresive)

    def unsafe_run(self, mt: ir.Method) -> RewriteResult:
        inline_calls = Walk(Inline(heuristic=self.herustic))
        result = inline_calls.rewrite(mt.code)

        compactify = Walk(CFGCompactify())
        result = compactify.rewrite(mt.code).join(result)

        # dce, repeated until nothing more is removed
        dead_code = DeadCodeElimination()
        cleanup = Fixpoint(Walk(dead_code))
        return cleanup.rewrite(mt.code).join(result)
@@ -0,0 +1,25 @@
1
+ from dataclasses import dataclass
2
+
3
+ from kirin.ir import Method, HasSignature
4
+ from kirin.rewrite import Walk
5
+ from kirin.passes.abc import Pass
6
+ from kirin.rewrite.abc import RewriteResult
7
+ from kirin.dialects.func import Signature
8
+ from kirin.analysis.typeinfer import TypeInference
9
+ from kirin.rewrite.apply_type import ApplyType
10
+
11
+
12
@dataclass
class TypeInfer(Pass):
    """Run type inference on a method and write the inferred types back.

    When the method body has a ``HasSignature`` trait, its signature is
    updated with the inferred return type. ``ApplyType`` is then applied with
    the analysis frame's entries, and the method is marked as inferred.
    """

    def __post_init__(self):
        # A single analysis instance per pass object, reused by every run.
        self.infer = TypeInference(self.dialects)

    def unsafe_run(self, mt: Method) -> RewriteResult:
        frame, return_type = self.infer.run_analysis(mt, mt.arg_types)

        signature_trait = mt.code.get_trait(HasSignature)
        if signature_trait:
            signature_trait.set_signature(
                mt.code, Signature(mt.arg_types, return_type)
            )

        applied = Walk(ApplyType(frame.entries)).rewrite(mt.code)
        mt.inferred = True
        return applied
+ return result