kirin-toolchain 0.13.0__py3-none-any.whl
Sign up to get free protection for your applications and to get access to all the features.
- kirin/__init__.py +7 -0
- kirin/analysis/__init__.py +24 -0
- kirin/analysis/callgraph.py +61 -0
- kirin/analysis/cfg.py +112 -0
- kirin/analysis/const/__init__.py +20 -0
- kirin/analysis/const/_visitor.py +2 -0
- kirin/analysis/const/_visitor.pyi +8 -0
- kirin/analysis/const/lattice.py +219 -0
- kirin/analysis/const/prop.py +116 -0
- kirin/analysis/forward.py +100 -0
- kirin/analysis/typeinfer/__init__.py +5 -0
- kirin/analysis/typeinfer/analysis.py +90 -0
- kirin/analysis/typeinfer/solve.py +141 -0
- kirin/decl/__init__.py +108 -0
- kirin/decl/base.py +65 -0
- kirin/decl/camel2snake.py +2 -0
- kirin/decl/emit/__init__.py +0 -0
- kirin/decl/emit/_create_fn.py +29 -0
- kirin/decl/emit/_set_new_attribute.py +22 -0
- kirin/decl/emit/dialect.py +8 -0
- kirin/decl/emit/init.py +277 -0
- kirin/decl/emit/name.py +10 -0
- kirin/decl/emit/property.py +182 -0
- kirin/decl/emit/repr.py +31 -0
- kirin/decl/emit/traits.py +13 -0
- kirin/decl/emit/typecheck.py +77 -0
- kirin/decl/emit/verify.py +51 -0
- kirin/decl/info.py +346 -0
- kirin/decl/scan_fields.py +157 -0
- kirin/decl/verify.py +69 -0
- kirin/dialects/__init__.py +14 -0
- kirin/dialects/_pprint_helper.py +53 -0
- kirin/dialects/cf/__init__.py +20 -0
- kirin/dialects/cf/constprop.py +51 -0
- kirin/dialects/cf/dialect.py +3 -0
- kirin/dialects/cf/emit.py +58 -0
- kirin/dialects/cf/interp.py +24 -0
- kirin/dialects/cf/stmts.py +68 -0
- kirin/dialects/cf/typeinfer.py +27 -0
- kirin/dialects/eltype.py +23 -0
- kirin/dialects/func/__init__.py +20 -0
- kirin/dialects/func/attrs.py +39 -0
- kirin/dialects/func/constprop.py +138 -0
- kirin/dialects/func/dialect.py +3 -0
- kirin/dialects/func/emit.py +80 -0
- kirin/dialects/func/interp.py +68 -0
- kirin/dialects/func/stmts.py +233 -0
- kirin/dialects/func/typeinfer.py +124 -0
- kirin/dialects/ilist/__init__.py +33 -0
- kirin/dialects/ilist/_dialect.py +3 -0
- kirin/dialects/ilist/_wrapper.py +51 -0
- kirin/dialects/ilist/interp.py +85 -0
- kirin/dialects/ilist/lowering.py +25 -0
- kirin/dialects/ilist/passes.py +32 -0
- kirin/dialects/ilist/rewrite/__init__.py +3 -0
- kirin/dialects/ilist/rewrite/const.py +45 -0
- kirin/dialects/ilist/rewrite/list.py +38 -0
- kirin/dialects/ilist/rewrite/unroll.py +131 -0
- kirin/dialects/ilist/runtime.py +63 -0
- kirin/dialects/ilist/stmts.py +102 -0
- kirin/dialects/ilist/typeinfer.py +120 -0
- kirin/dialects/lowering/__init__.py +7 -0
- kirin/dialects/lowering/call.py +48 -0
- kirin/dialects/lowering/cf.py +206 -0
- kirin/dialects/lowering/func.py +134 -0
- kirin/dialects/math/__init__.py +41 -0
- kirin/dialects/math/_gen.py +176 -0
- kirin/dialects/math/dialect.py +3 -0
- kirin/dialects/math/interp.py +190 -0
- kirin/dialects/math/stmts.py +369 -0
- kirin/dialects/module.py +139 -0
- kirin/dialects/py/__init__.py +40 -0
- kirin/dialects/py/assertion.py +91 -0
- kirin/dialects/py/assign.py +103 -0
- kirin/dialects/py/attr.py +59 -0
- kirin/dialects/py/base.py +34 -0
- kirin/dialects/py/binop/__init__.py +23 -0
- kirin/dialects/py/binop/_dialect.py +3 -0
- kirin/dialects/py/binop/interp.py +60 -0
- kirin/dialects/py/binop/julia.py +33 -0
- kirin/dialects/py/binop/lowering.py +22 -0
- kirin/dialects/py/binop/stmts.py +79 -0
- kirin/dialects/py/binop/typeinfer.py +108 -0
- kirin/dialects/py/boolop.py +84 -0
- kirin/dialects/py/builtin.py +78 -0
- kirin/dialects/py/cmp/__init__.py +16 -0
- kirin/dialects/py/cmp/_dialect.py +3 -0
- kirin/dialects/py/cmp/interp.py +48 -0
- kirin/dialects/py/cmp/julia.py +33 -0
- kirin/dialects/py/cmp/lowering.py +45 -0
- kirin/dialects/py/cmp/stmts.py +62 -0
- kirin/dialects/py/constant.py +79 -0
- kirin/dialects/py/indexing.py +251 -0
- kirin/dialects/py/iterable.py +90 -0
- kirin/dialects/py/len.py +57 -0
- kirin/dialects/py/list/__init__.py +15 -0
- kirin/dialects/py/list/_dialect.py +3 -0
- kirin/dialects/py/list/interp.py +21 -0
- kirin/dialects/py/list/lowering.py +25 -0
- kirin/dialects/py/list/stmts.py +22 -0
- kirin/dialects/py/list/typeinfer.py +54 -0
- kirin/dialects/py/range.py +76 -0
- kirin/dialects/py/slice.py +120 -0
- kirin/dialects/py/tuple.py +109 -0
- kirin/dialects/py/unary/__init__.py +24 -0
- kirin/dialects/py/unary/_dialect.py +3 -0
- kirin/dialects/py/unary/constprop.py +20 -0
- kirin/dialects/py/unary/interp.py +24 -0
- kirin/dialects/py/unary/julia.py +21 -0
- kirin/dialects/py/unary/lowering.py +22 -0
- kirin/dialects/py/unary/stmts.py +33 -0
- kirin/dialects/py/unary/typeinfer.py +23 -0
- kirin/dialects/py/unpack.py +90 -0
- kirin/dialects/scf/__init__.py +23 -0
- kirin/dialects/scf/_dialect.py +3 -0
- kirin/dialects/scf/absint.py +64 -0
- kirin/dialects/scf/constprop.py +140 -0
- kirin/dialects/scf/interp.py +35 -0
- kirin/dialects/scf/lowering.py +123 -0
- kirin/dialects/scf/stmts.py +250 -0
- kirin/dialects/scf/trim.py +36 -0
- kirin/dialects/scf/typeinfer.py +58 -0
- kirin/dialects/scf/unroll.py +92 -0
- kirin/emit/__init__.py +3 -0
- kirin/emit/abc.py +89 -0
- kirin/emit/abc.pyi +38 -0
- kirin/emit/exceptions.py +5 -0
- kirin/emit/julia.py +63 -0
- kirin/emit/str.py +51 -0
- kirin/exceptions.py +59 -0
- kirin/graph.py +34 -0
- kirin/idtable.py +57 -0
- kirin/interp/__init__.py +39 -0
- kirin/interp/abstract.py +253 -0
- kirin/interp/base.py +438 -0
- kirin/interp/concrete.py +62 -0
- kirin/interp/exceptions.py +26 -0
- kirin/interp/frame.py +151 -0
- kirin/interp/impl.py +197 -0
- kirin/interp/result.py +93 -0
- kirin/interp/state.py +71 -0
- kirin/interp/table.py +40 -0
- kirin/interp/value.py +73 -0
- kirin/ir/__init__.py +46 -0
- kirin/ir/attrs/__init__.py +20 -0
- kirin/ir/attrs/_types.py +8 -0
- kirin/ir/attrs/_types.pyi +13 -0
- kirin/ir/attrs/abc.py +46 -0
- kirin/ir/attrs/py.py +45 -0
- kirin/ir/attrs/types.py +522 -0
- kirin/ir/dialect.py +125 -0
- kirin/ir/group.py +249 -0
- kirin/ir/method.py +118 -0
- kirin/ir/nodes/__init__.py +7 -0
- kirin/ir/nodes/base.py +149 -0
- kirin/ir/nodes/block.py +458 -0
- kirin/ir/nodes/region.py +337 -0
- kirin/ir/nodes/stmt.py +713 -0
- kirin/ir/nodes/view.py +142 -0
- kirin/ir/ssa.py +204 -0
- kirin/ir/traits/__init__.py +36 -0
- kirin/ir/traits/abc.py +42 -0
- kirin/ir/traits/basic.py +78 -0
- kirin/ir/traits/callable.py +51 -0
- kirin/ir/traits/lowering/__init__.py +2 -0
- kirin/ir/traits/lowering/call.py +37 -0
- kirin/ir/traits/lowering/context.py +120 -0
- kirin/ir/traits/region/__init__.py +2 -0
- kirin/ir/traits/region/ssacfg.py +22 -0
- kirin/ir/traits/symbol.py +57 -0
- kirin/ir/use.py +17 -0
- kirin/lattice/__init__.py +13 -0
- kirin/lattice/abc.py +128 -0
- kirin/lattice/empty.py +25 -0
- kirin/lattice/mixin.py +51 -0
- kirin/lowering/__init__.py +7 -0
- kirin/lowering/binding.py +65 -0
- kirin/lowering/core.py +72 -0
- kirin/lowering/dialect.py +35 -0
- kirin/lowering/dialect.pyi +183 -0
- kirin/lowering/frame.py +171 -0
- kirin/lowering/result.py +68 -0
- kirin/lowering/state.py +441 -0
- kirin/lowering/stream.py +53 -0
- kirin/passes/__init__.py +3 -0
- kirin/passes/abc.py +44 -0
- kirin/passes/aggressive/__init__.py +1 -0
- kirin/passes/aggressive/fold.py +43 -0
- kirin/passes/fold.py +45 -0
- kirin/passes/inline.py +25 -0
- kirin/passes/typeinfer.py +25 -0
- kirin/prelude.py +197 -0
- kirin/print/__init__.py +15 -0
- kirin/print/printable.py +141 -0
- kirin/print/printer.py +415 -0
- kirin/py.typed +0 -0
- kirin/registry.py +105 -0
- kirin/registry.pyi +52 -0
- kirin/rewrite/__init__.py +14 -0
- kirin/rewrite/abc.py +43 -0
- kirin/rewrite/aggressive/__init__.py +1 -0
- kirin/rewrite/aggressive/fold.py +43 -0
- kirin/rewrite/alias.py +16 -0
- kirin/rewrite/apply_type.py +47 -0
- kirin/rewrite/call2invoke.py +34 -0
- kirin/rewrite/chain.py +39 -0
- kirin/rewrite/compactify.py +288 -0
- kirin/rewrite/cse.py +48 -0
- kirin/rewrite/dce.py +19 -0
- kirin/rewrite/fixpoint.py +34 -0
- kirin/rewrite/fold.py +57 -0
- kirin/rewrite/getfield.py +21 -0
- kirin/rewrite/getitem.py +37 -0
- kirin/rewrite/inline.py +143 -0
- kirin/rewrite/result.py +15 -0
- kirin/rewrite/walk.py +83 -0
- kirin/rewrite/wrap_const.py +55 -0
- kirin/source.py +21 -0
- kirin/symbol_table.py +27 -0
- kirin/types.py +34 -0
- kirin/worklist.py +30 -0
- kirin_toolchain-0.13.0.dist-info/METADATA +42 -0
- kirin_toolchain-0.13.0.dist-info/RECORD +225 -0
- kirin_toolchain-0.13.0.dist-info/WHEEL +4 -0
- kirin_toolchain-0.13.0.dist-info/licenses/LICENSE +234 -0
kirin/lowering/state.py
ADDED
@@ -0,0 +1,441 @@
|
|
1
|
+
import ast
|
2
|
+
import inspect
|
3
|
+
import builtins
|
4
|
+
from typing import TYPE_CHECKING, Any, TypeVar, get_origin
|
5
|
+
from dataclasses import dataclass
|
6
|
+
|
7
|
+
from kirin.ir import Method, SSAValue, Statement, DialectGroup, traits
|
8
|
+
from kirin.source import SourceInfo
|
9
|
+
from kirin.exceptions import DialectLoweringError
|
10
|
+
from kirin.lowering.frame import Frame
|
11
|
+
from kirin.lowering.result import Result
|
12
|
+
from kirin.lowering.binding import Binding
|
13
|
+
from kirin.lowering.dialect import FromPythonAST
|
14
|
+
|
15
|
+
if TYPE_CHECKING:
|
16
|
+
from kirin.lowering.core import Lowering
|
17
|
+
|
18
|
+
|
19
|
+
@dataclass
|
20
|
+
class LoweringState(ast.NodeVisitor):
|
21
|
+
# from parent
|
22
|
+
dialects: DialectGroup
|
23
|
+
registry: dict[str, FromPythonAST]
|
24
|
+
|
25
|
+
# debug info
|
26
|
+
lines: list[str]
|
27
|
+
lineno_offset: int
|
28
|
+
"lineno offset at the beginning of the source"
|
29
|
+
col_offset: int
|
30
|
+
"column offset at the beginning of the source"
|
31
|
+
source: SourceInfo
|
32
|
+
"source info of the current node"
|
33
|
+
# line_range: tuple[int, int] # current (<start>, <end>)
|
34
|
+
# col_range: tuple[int, int] # current (<start>, <end>)
|
35
|
+
max_lines: int = 3
|
36
|
+
_current_frame: Frame | None = None
|
37
|
+
|
38
|
+
@classmethod
|
39
|
+
def from_stmt(
|
40
|
+
cls,
|
41
|
+
lowering: "Lowering",
|
42
|
+
stmt: ast.stmt,
|
43
|
+
source: str | None = None,
|
44
|
+
globals: dict[str, Any] | None = None,
|
45
|
+
max_lines: int = 3,
|
46
|
+
lineno_offset: int = 0,
|
47
|
+
col_offset: int = 0,
|
48
|
+
):
|
49
|
+
if not isinstance(stmt, ast.stmt):
|
50
|
+
raise ValueError(f"Expected ast.stmt, got {type(stmt)}")
|
51
|
+
|
52
|
+
if not source:
|
53
|
+
source = ast.unparse(stmt)
|
54
|
+
|
55
|
+
state = cls(
|
56
|
+
dialects=lowering.dialects,
|
57
|
+
registry=lowering.registry,
|
58
|
+
lines=source.splitlines(),
|
59
|
+
lineno_offset=lineno_offset,
|
60
|
+
col_offset=col_offset,
|
61
|
+
source=SourceInfo.from_ast(stmt, lineno_offset, col_offset),
|
62
|
+
max_lines=max_lines,
|
63
|
+
)
|
64
|
+
|
65
|
+
frame = Frame.from_stmts([stmt], state, globals=globals)
|
66
|
+
state.push_frame(frame)
|
67
|
+
return state
|
68
|
+
|
69
|
+
@property
|
70
|
+
def current_frame(self):
|
71
|
+
if self._current_frame is None:
|
72
|
+
raise ValueError("No frame")
|
73
|
+
return self._current_frame
|
74
|
+
|
75
|
+
@property
|
76
|
+
def code(self):
|
77
|
+
stmt = self.current_frame.curr_region.blocks[0].first_stmt
|
78
|
+
if stmt:
|
79
|
+
return stmt
|
80
|
+
raise ValueError("No code generated")
|
81
|
+
|
82
|
+
StmtType = TypeVar("StmtType", bound=Statement)
|
83
|
+
|
84
|
+
def append_stmt(self, stmt: StmtType) -> StmtType:
|
85
|
+
"""Shorthand for appending a statement to the current block of current frame."""
|
86
|
+
return self.current_frame.append_stmt(stmt)
|
87
|
+
|
88
|
+
def push_frame(self, frame: Frame):
|
89
|
+
frame.parent = self._current_frame
|
90
|
+
self._current_frame = frame
|
91
|
+
return frame
|
92
|
+
|
93
|
+
def pop_frame(self, finalize_next: bool = True):
|
94
|
+
"""Pop the current frame and return it.
|
95
|
+
|
96
|
+
Args:
|
97
|
+
finalize_next(bool): If True, append the next block of the current frame.
|
98
|
+
|
99
|
+
Returns:
|
100
|
+
Frame: The popped frame.
|
101
|
+
"""
|
102
|
+
if self._current_frame is None:
|
103
|
+
raise ValueError("No frame to pop")
|
104
|
+
frame = self._current_frame
|
105
|
+
|
106
|
+
if finalize_next and frame.next_block.parent is None:
|
107
|
+
frame.append_block(frame.next_block)
|
108
|
+
self._current_frame = frame.parent
|
109
|
+
return frame
|
110
|
+
|
111
|
+
def update_lineno(self, node):
|
112
|
+
self.source = SourceInfo.from_ast(node, self.lineno_offset, self.col_offset)
|
113
|
+
|
114
|
+
def __repr__(self) -> str:
|
115
|
+
return f"LoweringState({self.current_frame})"
|
116
|
+
|
117
|
+
def visit(self, node: ast.AST) -> Result:
|
118
|
+
self.update_lineno(node)
|
119
|
+
name = node.__class__.__name__
|
120
|
+
if name in self.registry:
|
121
|
+
return self.registry[name].lower(self, node)
|
122
|
+
elif isinstance(node, ast.Call):
|
123
|
+
# NOTE: if lower_Call is implemented,
|
124
|
+
# it will be called first before __dispatch_Call
|
125
|
+
# because "Call" exists in self.registry
|
126
|
+
return self.__dispatch_Call(node)
|
127
|
+
elif isinstance(node, ast.With):
|
128
|
+
return self.__dispatch_With(node)
|
129
|
+
return super().visit(node)
|
130
|
+
|
131
|
+
def generic_visit(self, node: ast.AST):
|
132
|
+
raise DialectLoweringError(f"unsupported ast node {type(node)}:")
|
133
|
+
|
134
|
+
def __dispatch_With(self, node: ast.With):
|
135
|
+
if len(node.items) != 1:
|
136
|
+
raise DialectLoweringError("expected exactly one item in with statement")
|
137
|
+
|
138
|
+
item = node.items[0]
|
139
|
+
if not isinstance(item.context_expr, ast.Call):
|
140
|
+
raise DialectLoweringError("expected context expression to be a call")
|
141
|
+
|
142
|
+
global_callee_result = self.get_global_nothrow(item.context_expr.func)
|
143
|
+
if global_callee_result is None:
|
144
|
+
raise DialectLoweringError("cannot find call func in with context")
|
145
|
+
|
146
|
+
global_callee = global_callee_result.unwrap()
|
147
|
+
if not issubclass(global_callee, Statement):
|
148
|
+
raise DialectLoweringError("expected callee to be a statement")
|
149
|
+
|
150
|
+
if (
|
151
|
+
trait := global_callee.get_trait(traits.FromPythonWithSingleItem)
|
152
|
+
) is not None:
|
153
|
+
return trait.lower(global_callee, self, node)
|
154
|
+
|
155
|
+
raise DialectLoweringError(
|
156
|
+
"unsupported callee, missing FromPythonWithSingleItem trait"
|
157
|
+
)
|
158
|
+
|
159
|
+
def __dispatch_Call(self, node: ast.Call):
|
160
|
+
# 1. try to lookup global statement object
|
161
|
+
# 2. lookup local values
|
162
|
+
global_callee_result = self.get_global_nothrow(node.func)
|
163
|
+
if global_callee_result is None: # not found in globals, has to be local
|
164
|
+
return self.__lower_Call_local(node)
|
165
|
+
|
166
|
+
global_callee = global_callee_result.unwrap()
|
167
|
+
if isinstance(global_callee, Binding):
|
168
|
+
global_callee = global_callee.parent
|
169
|
+
|
170
|
+
if isinstance(global_callee, Method):
|
171
|
+
if "Call_global_method" in self.registry:
|
172
|
+
return self.registry["Call_global_method"].lower_Call_global_method(
|
173
|
+
self, global_callee, node
|
174
|
+
)
|
175
|
+
raise DialectLoweringError("`lower_Call_global_method` not implemented")
|
176
|
+
elif inspect.isclass(global_callee):
|
177
|
+
if issubclass(global_callee, Statement):
|
178
|
+
if global_callee.dialect is None:
|
179
|
+
raise DialectLoweringError(
|
180
|
+
f"unsupported dialect `None` for {global_callee.name}"
|
181
|
+
)
|
182
|
+
|
183
|
+
if global_callee.dialect not in self.dialects.data:
|
184
|
+
raise DialectLoweringError(
|
185
|
+
f"unsupported dialect `{global_callee.dialect.name}`"
|
186
|
+
)
|
187
|
+
|
188
|
+
if (
|
189
|
+
trait := global_callee.get_trait(traits.FromPythonCall)
|
190
|
+
) is not None:
|
191
|
+
return trait.lower(global_callee, self, node)
|
192
|
+
|
193
|
+
raise DialectLoweringError(
|
194
|
+
f"unsupported callee {global_callee.__name__}, "
|
195
|
+
"missing FromPythonAST lowering, or traits.FromPythonCall trait"
|
196
|
+
)
|
197
|
+
elif issubclass(global_callee, slice):
|
198
|
+
if "Call_slice" in self.registry:
|
199
|
+
return self.registry["Call_slice"].lower_Call_slice(self, node)
|
200
|
+
raise DialectLoweringError("`lower_Call_slice` not implemented")
|
201
|
+
elif issubclass(global_callee, range):
|
202
|
+
if "Call_range" in self.registry:
|
203
|
+
return self.registry["Call_range"].lower_Call_range(self, node)
|
204
|
+
raise DialectLoweringError("`lower_Call_range` not implemented")
|
205
|
+
elif inspect.isbuiltin(global_callee):
|
206
|
+
name = f"Call_{global_callee.__name__}"
|
207
|
+
if "Call_builtins" in self.registry:
|
208
|
+
dialect_lowering = self.registry["Call_builtins"]
|
209
|
+
return dialect_lowering.lower_Call_builtins(self, node)
|
210
|
+
elif name in self.registry:
|
211
|
+
dialect_lowering = self.registry[name]
|
212
|
+
return getattr(dialect_lowering, f"lower_{name}")(self, node)
|
213
|
+
else:
|
214
|
+
raise DialectLoweringError(
|
215
|
+
f"`lower_{name}` is not implemented for builtin function `{global_callee.__name__}`."
|
216
|
+
)
|
217
|
+
|
218
|
+
# symbol exist in global, but not ir.Statement, it may refer to a
|
219
|
+
# local value that shadows the global value
|
220
|
+
try:
|
221
|
+
return self.__lower_Call_local(node)
|
222
|
+
except DialectLoweringError:
|
223
|
+
# symbol exist in global, but not ir.Statement, not found in locals either
|
224
|
+
# this means the symbol is referring to an external uncallable object
|
225
|
+
if inspect.isfunction(global_callee):
|
226
|
+
raise DialectLoweringError(
|
227
|
+
f"unsupported callee: {repr(global_callee)}."
|
228
|
+
"Are you trying to call a python function? This is not supported."
|
229
|
+
)
|
230
|
+
else: # well not much we can do, can't hint
|
231
|
+
raise DialectLoweringError(
|
232
|
+
f"unsupported callee type: {repr(global_callee)}"
|
233
|
+
)
|
234
|
+
|
235
|
+
def __lower_Call_local(self, node: ast.Call) -> Result:
|
236
|
+
callee = self.visit(node.func).expect_one()
|
237
|
+
if "Call_local" in self.registry:
|
238
|
+
return self.registry["Call_local"].lower_Call_local(self, callee, node)
|
239
|
+
raise DialectLoweringError("`lower_Call_local` not implemented")
|
240
|
+
|
241
|
+
def default_Call_lower(self, stmt: type[Statement], node: ast.Call) -> Result:
|
242
|
+
"""Default lowering for Python call to statement.
|
243
|
+
|
244
|
+
This method is intended to be used by traits like `FromPythonCall` to
|
245
|
+
provide a default lowering for Python calls to statements.
|
246
|
+
|
247
|
+
Args:
|
248
|
+
stmt(type[Statement]): Statement class to construct.
|
249
|
+
node(ast.Call): Python call node to lower.
|
250
|
+
|
251
|
+
Returns:
|
252
|
+
Result: Result of lowering the Python call to statement.
|
253
|
+
"""
|
254
|
+
args, kwargs = self.default_Call_inputs(stmt, node)
|
255
|
+
return Result(self.append_stmt(stmt(*args.values(), **kwargs)))
|
256
|
+
|
257
|
+
def default_Call_inputs(
|
258
|
+
self, stmt: type[Statement], node: ast.Call
|
259
|
+
) -> tuple[dict[str, SSAValue | tuple[SSAValue, ...]], dict[str, Any]]:
|
260
|
+
from kirin.decl import fields
|
261
|
+
|
262
|
+
fs = fields(stmt)
|
263
|
+
stmt_std_arg_names = fs.std_args.keys()
|
264
|
+
stmt_kw_args_name = fs.kw_args.keys()
|
265
|
+
stmt_attr_prop_names = fs.attr_or_props
|
266
|
+
stmt_required_names = fs.required_names
|
267
|
+
stmt_group_arg_names = fs.group_arg_names
|
268
|
+
args, kwargs = {}, {}
|
269
|
+
for name, value in zip(stmt_std_arg_names, node.args):
|
270
|
+
self._parse_arg(stmt_group_arg_names, args, name, value)
|
271
|
+
for kw in node.keywords:
|
272
|
+
if not isinstance(kw.arg, str):
|
273
|
+
raise DialectLoweringError("Expected string for keyword argument name")
|
274
|
+
|
275
|
+
arg: str = kw.arg
|
276
|
+
if arg in node.args:
|
277
|
+
raise DialectLoweringError(
|
278
|
+
f"Keyword argument {arg} is already present in positional arguments"
|
279
|
+
)
|
280
|
+
elif arg in stmt_std_arg_names or arg in stmt_kw_args_name:
|
281
|
+
self._parse_arg(stmt_group_arg_names, kwargs, kw.arg, kw.value)
|
282
|
+
elif arg in stmt_attr_prop_names:
|
283
|
+
if (
|
284
|
+
isinstance(kw.value, ast.Name)
|
285
|
+
and self.current_frame.get_local(kw.value.id) is not None
|
286
|
+
):
|
287
|
+
raise DialectLoweringError(
|
288
|
+
f"Expected global/constant value for attribute or property {arg}"
|
289
|
+
)
|
290
|
+
global_value = self.get_global_nothrow(kw.value)
|
291
|
+
if global_value is None:
|
292
|
+
raise DialectLoweringError(
|
293
|
+
f"Expected global value for attribute or property {arg}"
|
294
|
+
)
|
295
|
+
if (decl := fs.attributes.get(arg)) is not None:
|
296
|
+
if decl.annotation is Any:
|
297
|
+
kwargs[arg] = global_value.unwrap()
|
298
|
+
else:
|
299
|
+
kwargs[arg] = global_value.expect(
|
300
|
+
get_origin(decl.annotation) or decl.annotation
|
301
|
+
)
|
302
|
+
else:
|
303
|
+
raise DialectLoweringError(
|
304
|
+
f"Unexpected attribute or property {arg}"
|
305
|
+
)
|
306
|
+
else:
|
307
|
+
raise DialectLoweringError(f"Unexpected keyword argument {arg}")
|
308
|
+
|
309
|
+
for name in stmt_required_names:
|
310
|
+
if name not in args and name not in kwargs:
|
311
|
+
raise DialectLoweringError(f"Missing required argument {name}")
|
312
|
+
|
313
|
+
return args, kwargs
|
314
|
+
|
315
|
+
def _parse_arg(
|
316
|
+
self,
|
317
|
+
group_names: set[str],
|
318
|
+
target: dict,
|
319
|
+
name: str,
|
320
|
+
value: ast.AST,
|
321
|
+
):
|
322
|
+
if name in group_names:
|
323
|
+
if not isinstance(value, ast.Tuple):
|
324
|
+
raise DialectLoweringError(f"Expected tuple for group argument {name}")
|
325
|
+
target[name] = tuple(self.visit(elem).expect_one() for elem in value.elts)
|
326
|
+
else:
|
327
|
+
target[name] = self.visit(value).expect_one()
|
328
|
+
|
329
|
+
ValueT = TypeVar("ValueT", bound=SSAValue)
|
330
|
+
|
331
|
+
def exhaust(self, frame: Frame | None = None) -> Frame:
|
332
|
+
"""Exhaust given frame's stream. If not given, exhaust current frame's stream."""
|
333
|
+
if not frame:
|
334
|
+
current_frame = self.current_frame
|
335
|
+
else:
|
336
|
+
current_frame = frame
|
337
|
+
|
338
|
+
stream = current_frame.stream
|
339
|
+
while stream.has_next():
|
340
|
+
stmt = stream.pop()
|
341
|
+
self.visit(stmt)
|
342
|
+
return current_frame
|
343
|
+
|
344
|
+
def error_hint(self) -> str:
|
345
|
+
begin = max(0, self.source.lineno - self.max_lines)
|
346
|
+
end = max(self.source.lineno + self.max_lines, self.source.end_lineno or 0)
|
347
|
+
end = min(len(self.lines), end) # make sure end is within bounds
|
348
|
+
lines = self.lines[begin:end]
|
349
|
+
code_indent = min(map(self.__get_indent, lines), default=0)
|
350
|
+
lines.append("") # in case the last line errors
|
351
|
+
|
352
|
+
snippet_lines = []
|
353
|
+
for lineno, line in enumerate(lines, begin):
|
354
|
+
if lineno == self.source.lineno:
|
355
|
+
snippet_lines.append(("-" * (self.source.col_offset)) + "^")
|
356
|
+
|
357
|
+
snippet_lines.append(line[code_indent:])
|
358
|
+
|
359
|
+
return "\n".join(snippet_lines)
|
360
|
+
|
361
|
+
@staticmethod
|
362
|
+
def __get_indent(line: str) -> int:
|
363
|
+
if len(line) == 0:
|
364
|
+
return int(1e9) # very large number
|
365
|
+
return len(line) - len(line.lstrip())
|
366
|
+
|
367
|
+
@dataclass
|
368
|
+
class GlobalRefResult:
|
369
|
+
data: Any
|
370
|
+
|
371
|
+
def unwrap(self):
|
372
|
+
return self.data
|
373
|
+
|
374
|
+
ExpectT = TypeVar("ExpectT")
|
375
|
+
|
376
|
+
def expect(self, typ: type[ExpectT]) -> ExpectT:
|
377
|
+
if not isinstance(self.data, typ):
|
378
|
+
raise DialectLoweringError(f"expected {typ}, got {type(self.data)}")
|
379
|
+
return self.data
|
380
|
+
|
381
|
+
def get_global_nothrow(self, node) -> GlobalRefResult | None:
|
382
|
+
try:
|
383
|
+
return self.get_global(node)
|
384
|
+
except DialectLoweringError:
|
385
|
+
return None
|
386
|
+
|
387
|
+
def get_global(self, node) -> GlobalRefResult:
|
388
|
+
return getattr(
|
389
|
+
self, f"get_global_{node.__class__.__name__}", self.get_global_fallback
|
390
|
+
)(node)
|
391
|
+
|
392
|
+
def get_global_fallback(self, node: ast.AST) -> GlobalRefResult:
|
393
|
+
raise DialectLoweringError(
|
394
|
+
f"unsupported global access get_global_{node.__class__.__name__}: {ast.unparse(node)}"
|
395
|
+
)
|
396
|
+
|
397
|
+
def get_global_Constant(self, node: ast.Constant) -> GlobalRefResult:
|
398
|
+
return self.GlobalRefResult(node.value)
|
399
|
+
|
400
|
+
def get_global_str(self, node: str) -> GlobalRefResult:
|
401
|
+
if node in (globals := self.current_frame.globals):
|
402
|
+
return self.GlobalRefResult(globals[node])
|
403
|
+
|
404
|
+
if hasattr(builtins, node):
|
405
|
+
return self.GlobalRefResult(getattr(builtins, node))
|
406
|
+
|
407
|
+
raise DialectLoweringError(f"global {node} not found")
|
408
|
+
|
409
|
+
def get_global_Name(self, node: ast.Name) -> GlobalRefResult:
|
410
|
+
return self.get_global_str(node.id)
|
411
|
+
|
412
|
+
def get_global_Attribute(self, node: ast.Attribute) -> GlobalRefResult:
|
413
|
+
if not isinstance(node.ctx, ast.Load):
|
414
|
+
raise DialectLoweringError("unsupported attribute access")
|
415
|
+
|
416
|
+
match node.value:
|
417
|
+
case ast.Name(id):
|
418
|
+
value = self.get_global_str(id).unwrap()
|
419
|
+
case ast.Attribute():
|
420
|
+
value = self.get_global(node.value).unwrap()
|
421
|
+
case _:
|
422
|
+
raise DialectLoweringError("unsupported attribute access")
|
423
|
+
|
424
|
+
if hasattr(value, node.attr):
|
425
|
+
return self.GlobalRefResult(getattr(value, node.attr))
|
426
|
+
|
427
|
+
raise DialectLoweringError(f"attribute {node.attr} not found in {value}")
|
428
|
+
|
429
|
+
def get_global_Subscript(self, node: ast.Subscript) -> GlobalRefResult:
|
430
|
+
value = self.get_global(node.value).unwrap()
|
431
|
+
if isinstance(node.slice, ast.Tuple):
|
432
|
+
index = tuple(self.get_global(elt).unwrap() for elt in node.slice.elts)
|
433
|
+
else:
|
434
|
+
index = self.get_global(node.slice).unwrap()
|
435
|
+
return self.GlobalRefResult(value[index])
|
436
|
+
|
437
|
+
def get_global_Call(self, node: ast.Call) -> GlobalRefResult:
|
438
|
+
func = self.get_global(node.func).unwrap()
|
439
|
+
args = [self.get_global(arg).unwrap() for arg in node.args]
|
440
|
+
kwargs = {kw.arg: self.get_global(kw.value).unwrap() for kw in node.keywords}
|
441
|
+
return self.GlobalRefResult(func(*args, **kwargs))
|
kirin/lowering/stream.py
ADDED
@@ -0,0 +1,53 @@
|
|
1
|
+
from typing import Generic, TypeVar, Sequence
|
2
|
+
from dataclasses import field, dataclass
|
3
|
+
|
4
|
+
Stmt = TypeVar("Stmt")
|
5
|
+
|
6
|
+
|
7
|
+
@dataclass
class StmtStream(Generic[Stmt]):
    """A list of statements paired with a cursor marking the next one to read."""

    stmts: list[Stmt] = field(default_factory=list)
    cursor: int = 0

    def __init__(self, stmts: Sequence[Stmt], cursor: int = 0):
        # Copy the input so the stream owns its own list.
        self.stmts = list(stmts)
        self.cursor = cursor

    # -- iterator protocol -------------------------------------------------

    def __iter__(self):
        return self

    def __next__(self):
        if not self.cursor < len(self.stmts):
            raise StopIteration
        current = self.stmts[self.cursor]
        self.cursor += 1
        return current

    # -- cursor based access -----------------------------------------------

    def peek(self):
        """Return the statement under the cursor without advancing."""
        return self.stmts[self.cursor]

    def pop(self):
        """Return the statement under the cursor and advance past it."""
        current = self.stmts[self.cursor]
        self.cursor += 1
        return current

    def has_next(self):
        """True while the cursor is before the end of the statements."""
        return self.cursor < len(self.stmts)

    def is_empty(self):
        """True when the cursor sits exactly at the end of the statements."""
        return len(self.stmts) == self.cursor

    def split(self) -> "StmtStream":
        """Hand the unread statements to a new stream and exhaust this one.

        The new stream starts at this stream's old cursor; this stream's
        cursor is moved to the end.
        """
        remainder = StmtStream(self.stmts, self.cursor)
        self.cursor = len(self.stmts)
        return remainder

    # -- sequence style access (ignores the cursor) --------------------------

    def __len__(self):
        return len(self.stmts)

    def __getitem__(self, key):
        return self.stmts[key]

    def __setitem__(self, key, value):
        self.stmts[key] = value
|
kirin/passes/__init__.py
ADDED
kirin/passes/abc.py
ADDED
@@ -0,0 +1,44 @@
|
|
1
|
+
from abc import ABC, abstractmethod
|
2
|
+
from typing import ClassVar
|
3
|
+
from dataclasses import dataclass
|
4
|
+
|
5
|
+
from kirin.ir import Method, DialectGroup
|
6
|
+
from kirin.rewrite.abc import RewriteResult
|
7
|
+
|
8
|
+
|
9
|
+
@dataclass
class Pass(ABC):
    """A pass is a transformation that is applied to a method. It wraps
    the analysis and rewrites needed to transform the method as an independent
    unit.

    Unlike LLVM/MLIR passes, a pass in Kirin does not apply to a module,
    this is because we focus on individual methods defined within
    python modules. This is a design choice to allow seamless integration
    within the Python interpreter.

    A Kirin compile unit is a `ir.Method` object, which is always equivalent
    to a LLVM/MLIR module if it were lowered to LLVM/MLIR just like other JIT
    compilers.
    """

    name: ClassVar[str]
    dialects: DialectGroup

    def __call__(self, mt: Method) -> RewriteResult:
        """Run the pass once on `mt`, then verify the method's code.

        Args:
            mt: the method to transform in place.

        Returns:
            The result reported by `unsafe_run`.
        """
        result = self.unsafe_run(mt)
        mt.code.verify()
        return result

    def fixpoint(self, mt: Method, max_iter: int = 32) -> RewriteResult:
        """Run the pass repeatedly until an iteration changes nothing.

        Args:
            mt: the method to transform in place.
            max_iter: upper bound on the number of iterations.

        Returns:
            The join of the results of every iteration that ran.
        """
        result = RewriteResult()
        for _ in range(max_iter):
            result_ = self.unsafe_run(mt)
            result = result_.join(result)
            # Test the latest iteration, not the accumulated result: the
            # joined result presumably stays "changed" once any iteration
            # changed something, which would keep the loop running until
            # max_iter.
            if not result_.has_done_something:
                break
        mt.code.verify()
        return result

    @abstractmethod
    def unsafe_run(self, mt: Method) -> RewriteResult:
        """Apply the transformation to `mt` without verifying the result."""
        ...
|
@@ -0,0 +1 @@
|
|
1
|
+
from .fold import Fold as Fold
|
@@ -0,0 +1,43 @@
|
|
1
|
+
from dataclasses import field, dataclass
|
2
|
+
|
3
|
+
from kirin.passes import Pass
|
4
|
+
from kirin.rewrite import (
|
5
|
+
Walk,
|
6
|
+
Chain,
|
7
|
+
Inline,
|
8
|
+
Fixpoint,
|
9
|
+
WrapConst,
|
10
|
+
Call2Invoke,
|
11
|
+
ConstantFold,
|
12
|
+
CFGCompactify,
|
13
|
+
InlineGetItem,
|
14
|
+
InlineGetField,
|
15
|
+
DeadCodeElimination,
|
16
|
+
)
|
17
|
+
from kirin.analysis import const
|
18
|
+
from kirin.ir.method import Method
|
19
|
+
from kirin.rewrite.abc import RewriteResult
|
20
|
+
|
21
|
+
|
22
|
+
@dataclass
class Fold(Pass):
    """Folding pass that also inlines every call and compacts the CFG.

    The constant propagation analysis is created once per pass instance.
    """

    constprop: const.Propagate = field(init=False)

    def __post_init__(self):
        self.constprop = const.Propagate(self.dialects)

    def unsafe_run(self, mt: Method) -> RewriteResult:
        """Run constant wrapping, folding, inlining and CFG compaction on `mt`.

        Returns:
            The join of the results of all four rewrite stages.
        """
        result = RewriteResult()
        frame, _ = self.constprop.run_analysis(mt)

        # stage 1: attach the analysis results to the IR
        wrapped = Walk(WrapConst(frame)).rewrite(mt.code)
        result = wrapped.join(result)

        # stage 2: fold / simplify until nothing changes
        fold_rules = Chain(
            ConstantFold(),
            Call2Invoke(),
            InlineGetField(),
            InlineGetItem(),
            DeadCodeElimination(),
        )
        folded = Fixpoint(Walk(fold_rules)).rewrite(mt.code)
        result = folded.join(result)

        # stage 3: inline unconditionally
        inlined = Walk(Inline(lambda _: True)).rewrite(mt.code)
        result = inlined.join(result)

        # stage 4: compact the control flow graph until nothing changes
        compacted = Fixpoint(CFGCompactify()).rewrite(mt.code)
        return compacted.join(result)
|
kirin/passes/fold.py
ADDED
@@ -0,0 +1,45 @@
|
|
1
|
+
from dataclasses import dataclass
|
2
|
+
|
3
|
+
from kirin.ir import Method, SSACFGRegion
|
4
|
+
from kirin.rewrite import (
|
5
|
+
Walk,
|
6
|
+
Chain,
|
7
|
+
Fixpoint,
|
8
|
+
WrapConst,
|
9
|
+
Call2Invoke,
|
10
|
+
ConstantFold,
|
11
|
+
CFGCompactify,
|
12
|
+
InlineGetItem,
|
13
|
+
DeadCodeElimination,
|
14
|
+
)
|
15
|
+
from kirin.analysis import const
|
16
|
+
from kirin.passes.abc import Pass
|
17
|
+
from kirin.rewrite.abc import RewriteResult
|
18
|
+
|
19
|
+
|
20
|
+
@dataclass
class Fold(Pass):
    """Constant folding pass followed by CFG compaction and dead code removal."""

    def unsafe_run(self, mt: Method) -> RewriteResult:
        """Fold constants in `mt` and clean up afterwards.

        Returns:
            The join of the results of every rewrite that ran.
        """
        constprop = const.Propagate(self.dialects)
        frame, _ = constprop.run_analysis(mt)
        result = Walk(WrapConst(frame)).rewrite(mt.code)

        # fold / simplify until nothing changes
        fold_rules = Chain(
            ConstantFold(),
            InlineGetItem(),
            Call2Invoke(),
            DeadCodeElimination(),
        )
        result = Fixpoint(Walk(fold_rules)).rewrite(mt.code).join(result)

        # CFG compaction only runs on code carrying the SSACFGRegion trait
        if mt.code.has_trait(SSACFGRegion):
            result = Walk(CFGCompactify()).rewrite(mt.code).join(result)

        cleanup = Fixpoint(Walk(DeadCodeElimination()))
        return cleanup.rewrite(mt.code).join(result)
|
kirin/passes/inline.py
ADDED
@@ -0,0 +1,25 @@
|
|
1
|
+
from typing import Callable
|
2
|
+
from dataclasses import field, dataclass
|
3
|
+
|
4
|
+
from kirin import ir
|
5
|
+
from kirin.passes import Pass
|
6
|
+
from kirin.rewrite import Walk, Inline, Fixpoint, CFGCompactify, DeadCodeElimination
|
7
|
+
from kirin.rewrite.abc import RewriteResult
|
8
|
+
|
9
|
+
|
10
|
+
def aggresive(x: ir.IRNode) -> bool:
    """Inline heuristic that accepts every node.

    The misspelled name is kept because it is part of the module's public API.
    """
    return True
|
12
|
+
|
13
|
+
|
14
|
+
@dataclass
class InlinePass(Pass):
    """Inline calls selected by a heuristic, then compact the CFG and run DCE."""

    # Predicate handed to `Inline`; the field name's spelling is kept because
    # it is a constructor keyword.
    herustic: Callable[[ir.IRNode], bool] = field(default=aggresive)

    def unsafe_run(self, mt: ir.Method) -> RewriteResult:
        """Inline into `mt.code` and clean up the result.

        Returns:
            The join of the inline, compaction and dead-code results.
        """
        inline_rule = Inline(heuristic=self.herustic)
        inlined = Walk(inline_rule).rewrite(mt.code)
        compacted = Walk(CFGCompactify()).rewrite(mt.code)
        result = compacted.join(inlined)

        # remove whatever became dead, repeating until nothing changes
        cleanup = Fixpoint(Walk(DeadCodeElimination()))
        return cleanup.rewrite(mt.code).join(result)
|
@@ -0,0 +1,25 @@
|
|
1
|
+
from dataclasses import dataclass
|
2
|
+
|
3
|
+
from kirin.ir import Method, HasSignature
|
4
|
+
from kirin.rewrite import Walk
|
5
|
+
from kirin.passes.abc import Pass
|
6
|
+
from kirin.rewrite.abc import RewriteResult
|
7
|
+
from kirin.dialects.func import Signature
|
8
|
+
from kirin.analysis.typeinfer import TypeInference
|
9
|
+
from kirin.rewrite.apply_type import ApplyType
|
10
|
+
|
11
|
+
|
12
|
+
@dataclass
class TypeInfer(Pass):
    """Run type inference on a method and write the inferred types into its IR."""

    def __post_init__(self):
        # one inference engine per pass instance
        self.infer = TypeInference(self.dialects)

    def unsafe_run(self, mt: Method) -> RewriteResult:
        """Infer types for `mt`, update its signature and mark it as inferred.

        Returns:
            The result of applying the inferred types to `mt.code`.
        """
        frame, return_type = self.infer.run_analysis(mt, mt.arg_types)

        # Only code that carries a signature gets the inferred return type.
        signature_trait = mt.code.get_trait(HasSignature)
        if signature_trait:
            inferred_signature = Signature(mt.arg_types, return_type)
            signature_trait.set_signature(mt.code, inferred_signature)

        rewritten = Walk(ApplyType(frame.entries)).rewrite(mt.code)
        mt.inferred = True
        return rewritten
|