kirin-toolchain 0.13.0__py3-none-any.whl
Sign up to get free protection for your applications and to get access to all the features.
- kirin/__init__.py +7 -0
- kirin/analysis/__init__.py +24 -0
- kirin/analysis/callgraph.py +61 -0
- kirin/analysis/cfg.py +112 -0
- kirin/analysis/const/__init__.py +20 -0
- kirin/analysis/const/_visitor.py +2 -0
- kirin/analysis/const/_visitor.pyi +8 -0
- kirin/analysis/const/lattice.py +219 -0
- kirin/analysis/const/prop.py +116 -0
- kirin/analysis/forward.py +100 -0
- kirin/analysis/typeinfer/__init__.py +5 -0
- kirin/analysis/typeinfer/analysis.py +90 -0
- kirin/analysis/typeinfer/solve.py +141 -0
- kirin/decl/__init__.py +108 -0
- kirin/decl/base.py +65 -0
- kirin/decl/camel2snake.py +2 -0
- kirin/decl/emit/__init__.py +0 -0
- kirin/decl/emit/_create_fn.py +29 -0
- kirin/decl/emit/_set_new_attribute.py +22 -0
- kirin/decl/emit/dialect.py +8 -0
- kirin/decl/emit/init.py +277 -0
- kirin/decl/emit/name.py +10 -0
- kirin/decl/emit/property.py +182 -0
- kirin/decl/emit/repr.py +31 -0
- kirin/decl/emit/traits.py +13 -0
- kirin/decl/emit/typecheck.py +77 -0
- kirin/decl/emit/verify.py +51 -0
- kirin/decl/info.py +346 -0
- kirin/decl/scan_fields.py +157 -0
- kirin/decl/verify.py +69 -0
- kirin/dialects/__init__.py +14 -0
- kirin/dialects/_pprint_helper.py +53 -0
- kirin/dialects/cf/__init__.py +20 -0
- kirin/dialects/cf/constprop.py +51 -0
- kirin/dialects/cf/dialect.py +3 -0
- kirin/dialects/cf/emit.py +58 -0
- kirin/dialects/cf/interp.py +24 -0
- kirin/dialects/cf/stmts.py +68 -0
- kirin/dialects/cf/typeinfer.py +27 -0
- kirin/dialects/eltype.py +23 -0
- kirin/dialects/func/__init__.py +20 -0
- kirin/dialects/func/attrs.py +39 -0
- kirin/dialects/func/constprop.py +138 -0
- kirin/dialects/func/dialect.py +3 -0
- kirin/dialects/func/emit.py +80 -0
- kirin/dialects/func/interp.py +68 -0
- kirin/dialects/func/stmts.py +233 -0
- kirin/dialects/func/typeinfer.py +124 -0
- kirin/dialects/ilist/__init__.py +33 -0
- kirin/dialects/ilist/_dialect.py +3 -0
- kirin/dialects/ilist/_wrapper.py +51 -0
- kirin/dialects/ilist/interp.py +85 -0
- kirin/dialects/ilist/lowering.py +25 -0
- kirin/dialects/ilist/passes.py +32 -0
- kirin/dialects/ilist/rewrite/__init__.py +3 -0
- kirin/dialects/ilist/rewrite/const.py +45 -0
- kirin/dialects/ilist/rewrite/list.py +38 -0
- kirin/dialects/ilist/rewrite/unroll.py +131 -0
- kirin/dialects/ilist/runtime.py +63 -0
- kirin/dialects/ilist/stmts.py +102 -0
- kirin/dialects/ilist/typeinfer.py +120 -0
- kirin/dialects/lowering/__init__.py +7 -0
- kirin/dialects/lowering/call.py +48 -0
- kirin/dialects/lowering/cf.py +206 -0
- kirin/dialects/lowering/func.py +134 -0
- kirin/dialects/math/__init__.py +41 -0
- kirin/dialects/math/_gen.py +176 -0
- kirin/dialects/math/dialect.py +3 -0
- kirin/dialects/math/interp.py +190 -0
- kirin/dialects/math/stmts.py +369 -0
- kirin/dialects/module.py +139 -0
- kirin/dialects/py/__init__.py +40 -0
- kirin/dialects/py/assertion.py +91 -0
- kirin/dialects/py/assign.py +103 -0
- kirin/dialects/py/attr.py +59 -0
- kirin/dialects/py/base.py +34 -0
- kirin/dialects/py/binop/__init__.py +23 -0
- kirin/dialects/py/binop/_dialect.py +3 -0
- kirin/dialects/py/binop/interp.py +60 -0
- kirin/dialects/py/binop/julia.py +33 -0
- kirin/dialects/py/binop/lowering.py +22 -0
- kirin/dialects/py/binop/stmts.py +79 -0
- kirin/dialects/py/binop/typeinfer.py +108 -0
- kirin/dialects/py/boolop.py +84 -0
- kirin/dialects/py/builtin.py +78 -0
- kirin/dialects/py/cmp/__init__.py +16 -0
- kirin/dialects/py/cmp/_dialect.py +3 -0
- kirin/dialects/py/cmp/interp.py +48 -0
- kirin/dialects/py/cmp/julia.py +33 -0
- kirin/dialects/py/cmp/lowering.py +45 -0
- kirin/dialects/py/cmp/stmts.py +62 -0
- kirin/dialects/py/constant.py +79 -0
- kirin/dialects/py/indexing.py +251 -0
- kirin/dialects/py/iterable.py +90 -0
- kirin/dialects/py/len.py +57 -0
- kirin/dialects/py/list/__init__.py +15 -0
- kirin/dialects/py/list/_dialect.py +3 -0
- kirin/dialects/py/list/interp.py +21 -0
- kirin/dialects/py/list/lowering.py +25 -0
- kirin/dialects/py/list/stmts.py +22 -0
- kirin/dialects/py/list/typeinfer.py +54 -0
- kirin/dialects/py/range.py +76 -0
- kirin/dialects/py/slice.py +120 -0
- kirin/dialects/py/tuple.py +109 -0
- kirin/dialects/py/unary/__init__.py +24 -0
- kirin/dialects/py/unary/_dialect.py +3 -0
- kirin/dialects/py/unary/constprop.py +20 -0
- kirin/dialects/py/unary/interp.py +24 -0
- kirin/dialects/py/unary/julia.py +21 -0
- kirin/dialects/py/unary/lowering.py +22 -0
- kirin/dialects/py/unary/stmts.py +33 -0
- kirin/dialects/py/unary/typeinfer.py +23 -0
- kirin/dialects/py/unpack.py +90 -0
- kirin/dialects/scf/__init__.py +23 -0
- kirin/dialects/scf/_dialect.py +3 -0
- kirin/dialects/scf/absint.py +64 -0
- kirin/dialects/scf/constprop.py +140 -0
- kirin/dialects/scf/interp.py +35 -0
- kirin/dialects/scf/lowering.py +123 -0
- kirin/dialects/scf/stmts.py +250 -0
- kirin/dialects/scf/trim.py +36 -0
- kirin/dialects/scf/typeinfer.py +58 -0
- kirin/dialects/scf/unroll.py +92 -0
- kirin/emit/__init__.py +3 -0
- kirin/emit/abc.py +89 -0
- kirin/emit/abc.pyi +38 -0
- kirin/emit/exceptions.py +5 -0
- kirin/emit/julia.py +63 -0
- kirin/emit/str.py +51 -0
- kirin/exceptions.py +59 -0
- kirin/graph.py +34 -0
- kirin/idtable.py +57 -0
- kirin/interp/__init__.py +39 -0
- kirin/interp/abstract.py +253 -0
- kirin/interp/base.py +438 -0
- kirin/interp/concrete.py +62 -0
- kirin/interp/exceptions.py +26 -0
- kirin/interp/frame.py +151 -0
- kirin/interp/impl.py +197 -0
- kirin/interp/result.py +93 -0
- kirin/interp/state.py +71 -0
- kirin/interp/table.py +40 -0
- kirin/interp/value.py +73 -0
- kirin/ir/__init__.py +46 -0
- kirin/ir/attrs/__init__.py +20 -0
- kirin/ir/attrs/_types.py +8 -0
- kirin/ir/attrs/_types.pyi +13 -0
- kirin/ir/attrs/abc.py +46 -0
- kirin/ir/attrs/py.py +45 -0
- kirin/ir/attrs/types.py +522 -0
- kirin/ir/dialect.py +125 -0
- kirin/ir/group.py +249 -0
- kirin/ir/method.py +118 -0
- kirin/ir/nodes/__init__.py +7 -0
- kirin/ir/nodes/base.py +149 -0
- kirin/ir/nodes/block.py +458 -0
- kirin/ir/nodes/region.py +337 -0
- kirin/ir/nodes/stmt.py +713 -0
- kirin/ir/nodes/view.py +142 -0
- kirin/ir/ssa.py +204 -0
- kirin/ir/traits/__init__.py +36 -0
- kirin/ir/traits/abc.py +42 -0
- kirin/ir/traits/basic.py +78 -0
- kirin/ir/traits/callable.py +51 -0
- kirin/ir/traits/lowering/__init__.py +2 -0
- kirin/ir/traits/lowering/call.py +37 -0
- kirin/ir/traits/lowering/context.py +120 -0
- kirin/ir/traits/region/__init__.py +2 -0
- kirin/ir/traits/region/ssacfg.py +22 -0
- kirin/ir/traits/symbol.py +57 -0
- kirin/ir/use.py +17 -0
- kirin/lattice/__init__.py +13 -0
- kirin/lattice/abc.py +128 -0
- kirin/lattice/empty.py +25 -0
- kirin/lattice/mixin.py +51 -0
- kirin/lowering/__init__.py +7 -0
- kirin/lowering/binding.py +65 -0
- kirin/lowering/core.py +72 -0
- kirin/lowering/dialect.py +35 -0
- kirin/lowering/dialect.pyi +183 -0
- kirin/lowering/frame.py +171 -0
- kirin/lowering/result.py +68 -0
- kirin/lowering/state.py +441 -0
- kirin/lowering/stream.py +53 -0
- kirin/passes/__init__.py +3 -0
- kirin/passes/abc.py +44 -0
- kirin/passes/aggressive/__init__.py +1 -0
- kirin/passes/aggressive/fold.py +43 -0
- kirin/passes/fold.py +45 -0
- kirin/passes/inline.py +25 -0
- kirin/passes/typeinfer.py +25 -0
- kirin/prelude.py +197 -0
- kirin/print/__init__.py +15 -0
- kirin/print/printable.py +141 -0
- kirin/print/printer.py +415 -0
- kirin/py.typed +0 -0
- kirin/registry.py +105 -0
- kirin/registry.pyi +52 -0
- kirin/rewrite/__init__.py +14 -0
- kirin/rewrite/abc.py +43 -0
- kirin/rewrite/aggressive/__init__.py +1 -0
- kirin/rewrite/aggressive/fold.py +43 -0
- kirin/rewrite/alias.py +16 -0
- kirin/rewrite/apply_type.py +47 -0
- kirin/rewrite/call2invoke.py +34 -0
- kirin/rewrite/chain.py +39 -0
- kirin/rewrite/compactify.py +288 -0
- kirin/rewrite/cse.py +48 -0
- kirin/rewrite/dce.py +19 -0
- kirin/rewrite/fixpoint.py +34 -0
- kirin/rewrite/fold.py +57 -0
- kirin/rewrite/getfield.py +21 -0
- kirin/rewrite/getitem.py +37 -0
- kirin/rewrite/inline.py +143 -0
- kirin/rewrite/result.py +15 -0
- kirin/rewrite/walk.py +83 -0
- kirin/rewrite/wrap_const.py +55 -0
- kirin/source.py +21 -0
- kirin/symbol_table.py +27 -0
- kirin/types.py +34 -0
- kirin/worklist.py +30 -0
- kirin_toolchain-0.13.0.dist-info/METADATA +42 -0
- kirin_toolchain-0.13.0.dist-info/RECORD +225 -0
- kirin_toolchain-0.13.0.dist-info/WHEEL +4 -0
- kirin_toolchain-0.13.0.dist-info/licenses/LICENSE +234 -0
kirin/interp/base.py
ADDED
@@ -0,0 +1,438 @@
|
|
1
|
+
import sys
|
2
|
+
from abc import ABC, ABCMeta, abstractmethod
|
3
|
+
from enum import Enum
|
4
|
+
from typing import TYPE_CHECKING, Generic, TypeVar, ClassVar, Optional, Sequence
|
5
|
+
from dataclasses import field, dataclass
|
6
|
+
|
7
|
+
from typing_extensions import Self, deprecated
|
8
|
+
|
9
|
+
from kirin.ir import Block, Region, Statement, DialectGroup, traits
|
10
|
+
from kirin.ir.method import Method
|
11
|
+
|
12
|
+
from .impl import Signature
|
13
|
+
from .frame import FrameABC
|
14
|
+
from .state import InterpreterState
|
15
|
+
from .value import ReturnValue, SpecialValue, StatementResult
|
16
|
+
from .result import Ok, Err, Result
|
17
|
+
from .exceptions import InterpreterError
|
18
|
+
|
19
|
+
if TYPE_CHECKING:
|
20
|
+
from kirin.registry import StatementImpl, InterpreterRegistry
|
21
|
+
|
22
|
+
ValueType = TypeVar("ValueType")
|
23
|
+
FrameType = TypeVar("FrameType", bound=FrameABC)
|
24
|
+
|
25
|
+
|
26
|
+
class InterpreterMeta(ABCMeta):
|
27
|
+
"""A metaclass for interpreters."""
|
28
|
+
|
29
|
+
pass
|
30
|
+
|
31
|
+
|
32
|
+
@dataclass
|
33
|
+
class BaseInterpreter(ABC, Generic[FrameType, ValueType], metaclass=InterpreterMeta):
|
34
|
+
"""A base class for interpreters.
|
35
|
+
|
36
|
+
This class defines the basic structure of an interpreter. It is
|
37
|
+
designed to be subclassed to provide the actual implementation of
|
38
|
+
the interpreter.
|
39
|
+
|
40
|
+
### Required Overrides
|
41
|
+
When subclassing, if the subclass does not contain `ABC`,
|
42
|
+
the subclass must define the following attributes:
|
43
|
+
|
44
|
+
- `keys`: a list of strings that defines the order of dialects to select from.
|
45
|
+
- `void`: the value to return when the interpreter evaluates nothing.
|
46
|
+
"""
|
47
|
+
|
48
|
+
keys: ClassVar[list[str]]
|
49
|
+
"""The name of the interpreter to select from dialects by order.
|
50
|
+
"""
|
51
|
+
void: ValueType = field(init=False)
|
52
|
+
"""What to return when the interpreter evaluates nothing.
|
53
|
+
"""
|
54
|
+
dialects: DialectGroup
|
55
|
+
"""The dialects to interpret.
|
56
|
+
"""
|
57
|
+
fuel: int | None = field(default=None, kw_only=True)
|
58
|
+
"""The fuel limit for the interpreter.
|
59
|
+
"""
|
60
|
+
debug: bool = field(default=False, kw_only=True)
|
61
|
+
"""Whether to enable debug mode.
|
62
|
+
"""
|
63
|
+
max_depth: int = field(default=128, kw_only=True)
|
64
|
+
"""The maximum depth of the interpreter stack.
|
65
|
+
"""
|
66
|
+
max_python_recursion_depth: int = field(default=8192, kw_only=True)
|
67
|
+
"""The maximum recursion depth of the Python interpreter.
|
68
|
+
"""
|
69
|
+
|
70
|
+
# global states
|
71
|
+
registry: "InterpreterRegistry" = field(init=False, compare=False)
|
72
|
+
"""The interpreter registry.
|
73
|
+
"""
|
74
|
+
symbol_table: dict[str, Statement] = field(init=False, compare=False)
|
75
|
+
"""The symbol table.
|
76
|
+
"""
|
77
|
+
state: InterpreterState[FrameType] = field(init=False, compare=False)
|
78
|
+
"""The interpreter state.
|
79
|
+
"""
|
80
|
+
|
81
|
+
# private
|
82
|
+
_eval_lock: bool = field(default=False, init=False, compare=False)
|
83
|
+
|
84
|
+
def __post_init__(self) -> None:
|
85
|
+
self.registry = self.dialects.registry.interpreter(keys=self.keys)
|
86
|
+
|
87
|
+
def initialize(self) -> Self:
|
88
|
+
"""Initialize the interpreter global states. This method is called right upon
|
89
|
+
calling [`run`][kirin.interp.base.BaseInterpreter.run] to initialize the
|
90
|
+
interpreter global states.
|
91
|
+
|
92
|
+
!!! note "Default Implementation"
|
93
|
+
This method provides default behavior but may be overridden by subclasses
|
94
|
+
to customize or extend functionality.
|
95
|
+
"""
|
96
|
+
self.symbol_table: dict[str, Statement] = {}
|
97
|
+
self.state: InterpreterState[FrameType] = InterpreterState()
|
98
|
+
return self
|
99
|
+
|
100
|
+
def __init_subclass__(cls) -> None:
    """Check that concrete interpreter subclasses declare the required attributes.

    A subclass that lists `ABC` directly among its bases is treated as
    abstract and is not checked. Every other subclass must define both
    `keys` and `void`.

    Raises:
        TypeError: if `keys` or `void` is not defined on the subclass.
    """
    super().__init_subclass__()
    if ABC not in cls.__bases__:
        # checked in this order so a class missing both reports `keys` first
        for required in ("keys", "void"):
            if not hasattr(cls, required):
                raise TypeError(f"{required} is not defined for class {cls.__name__}")
|
109
|
+
|
110
|
+
@deprecated("use run instead")
|
111
|
+
def eval(
|
112
|
+
self,
|
113
|
+
mt: Method,
|
114
|
+
args: tuple[ValueType, ...],
|
115
|
+
kwargs: dict[str, ValueType] | None = None,
|
116
|
+
) -> Result[ValueType]:
|
117
|
+
return self.run(mt, args, kwargs)
|
118
|
+
|
119
|
+
def run(
    self,
    mt: Method,
    args: tuple[ValueType, ...],
    kwargs: dict[str, ValueType] | None = None,
) -> Result[ValueType]:
    """Run a method. This is the main entry point of the interpreter.

    Args:
        mt (Method): the method to run.
        args (tuple[ValueType]): the arguments to the method, does not include self.
        kwargs (dict[str, ValueType], optional): the keyword arguments to the method.

    Returns:
        Result[ValueType]: `Ok` wrapping the value produced by `run_method`, or
            `Err` carrying the `InterpreterError` together with the frames of
            the current interpreter state.

    Raises:
        InterpreterError: if `run` is called while another `run` is still in
            progress on this interpreter.
    """
    if self._eval_lock:
        raise InterpreterError(
            "recursive eval is not allowed, use run_method instead"
        )

    self._eval_lock = True
    current_recursion_limit = sys.getrecursionlimit()
    try:
        # Everything after taking the lock lives inside this `try` so that the
        # lock and the process-wide recursion limit are restored even when
        # initialization or argument binding raises (e.g. a missing keyword).
        self.initialize()
        sys.setrecursionlimit(self.max_python_recursion_depth)
        # `+ 1` skips the leading self slot, which is not part of user input.
        args = self.get_args(mt.arg_names[len(args) + 1 :], args, kwargs)
        try:
            _, results = self.run_method(mt, args)
        except InterpreterError as e:
            # NOTE: initialize will create new State
            # so we don't need to copy the frames.
            return Err(e, self.state.frames)
        return Ok(results)
    finally:
        self._eval_lock = False
        sys.setrecursionlimit(current_recursion_limit)
|
155
|
+
|
156
|
+
def run_stmt(
|
157
|
+
self, stmt: Statement, args: tuple[ValueType, ...]
|
158
|
+
) -> StatementResult[ValueType]:
|
159
|
+
"""execute a statement with arguments in a new frame.
|
160
|
+
|
161
|
+
Args:
|
162
|
+
stmt (Statement): the statement to run.
|
163
|
+
args (tuple[ValueType]): the arguments to the statement.
|
164
|
+
|
165
|
+
Returns:
|
166
|
+
StatementResult[ValueType]: the result of the statement.
|
167
|
+
"""
|
168
|
+
frame = self.new_frame(stmt)
|
169
|
+
self.state.push_frame(frame)
|
170
|
+
frame.set_values(stmt.args, args)
|
171
|
+
results = self.eval_stmt(frame, stmt)
|
172
|
+
self.state.pop_frame()
|
173
|
+
return results
|
174
|
+
|
175
|
+
@abstractmethod
|
176
|
+
def run_method(
|
177
|
+
self, method: Method, args: tuple[ValueType, ...]
|
178
|
+
) -> tuple[FrameType, ValueType]:
|
179
|
+
"""How to run a method.
|
180
|
+
|
181
|
+
This is defined by subclasses to describe what's the corresponding
|
182
|
+
value of a method during the interpretation. Usually, this method
|
183
|
+
just calls [`run_callable`][kirin.interp.base.BaseInterpreter.run_callable].
|
184
|
+
|
185
|
+
Args:
|
186
|
+
method (Method): the method to run.
|
187
|
+
args (tuple[ValueType]): the arguments to the method, does not include self.
|
188
|
+
|
189
|
+
Returns:
|
190
|
+
ValueType: the result of the method.
|
191
|
+
"""
|
192
|
+
...
|
193
|
+
|
194
|
+
def run_callable(
    self, code: Statement, args: tuple[ValueType, ...]
) -> tuple[FrameType, ValueType]:
    """Run a callable statement in a freshly pushed frame.

    Args:
        code (Statement): the statement to run.
        args (tuple[ValueType]): the arguments to the statement,
            includes self if the corresponding callable region contains a self argument.

    Returns:
        tuple[FrameType, ValueType]: the popped frame and the result of the
            statement. When the callable region has no blocks the result is
            `self.void`.

    Raises:
        InterpreterError: if `code` does not have the `CallableStmtInterface` trait.
    """
    # depth guard: delegate to the recursion-limit hook before pushing anything
    if len(self.state.frames) >= self.max_depth:
        return self.eval_recursion_limit(self.state.current_frame())

    interface = code.get_trait(traits.CallableStmtInterface)
    if interface is None:
        raise InterpreterError(f"statement {code.name} is not callable")

    frame = self.new_frame(code)
    self.state.push_frame(frame)
    body = interface.get_callable_region(code)
    if body.blocks:
        # bind the call arguments to the entry block's arguments
        frame.set_values(body.blocks[0].args, args)
        value = self.run_callable_region(frame, code, body)
    else:
        value = self.void
    return self.state.pop_frame(), value
|
222
|
+
|
223
|
+
def run_callable_region(
    self, frame: FrameType, code: Statement, region: Region
) -> ValueType:
    """Hook describing how to run a callable region in the interpreter context.

    The frame should be pushed before calling this method and popped after
    it returns.

    A callable region is a region that can be called as a function. Unlike
    a general region (or the MLIR convention), it always returns a value to
    be compatible with the Python convention.

    Returns:
        ValueType: the wrapped value when the region ends with a
            `ReturnValue`; `self.void` when the region produces an empty
            result or `None`.

    Raises:
        InterpreterError: if the region produces a non-empty result that is
            not a `ReturnValue`.
    """
    outcome = self.run_ssacfg_region(frame, region)
    if isinstance(outcome, ReturnValue):
        return outcome.value
    if outcome:  # non-empty result that is not a ReturnValue
        raise InterpreterError(
            f"callable region {code.name} does not return `ReturnValue`, got {outcome}"
        )
    return self.void  # empty result or None
|
242
|
+
|
243
|
+
def run_block(self, frame: FrameType, block: Block) -> SpecialValue[ValueType]:
|
244
|
+
"""Run a block within the current frame.
|
245
|
+
|
246
|
+
Args:
|
247
|
+
frame: the current frame.
|
248
|
+
block: the block to run.
|
249
|
+
|
250
|
+
Returns:
|
251
|
+
SpecialValue: the result of running the block terminator.
|
252
|
+
"""
|
253
|
+
...
|
254
|
+
|
255
|
+
@abstractmethod
|
256
|
+
def new_frame(self, code: Statement) -> FrameType:
|
257
|
+
"""Create a new frame for the given method."""
|
258
|
+
...
|
259
|
+
|
260
|
+
@staticmethod
def get_args(
    left_arg_names, args: tuple[ValueType, ...], kwargs: dict[str, ValueType] | None
) -> tuple[ValueType, ...]:
    """Append keyword arguments to the positional arguments in signature order.

    Args:
        left_arg_names: names of the parameters not covered by `args`, in
            signature order. Callers slice these from the method's argument
            names, skipping the leading self slot because self is not user input.
        args: the positional argument values.
        kwargs: keyword argument values by name; when empty or `None`,
            `args` is returned unchanged.

    Returns:
        tuple[ValueType, ...]: `args` followed by `kwargs[name]` for each name
            in `left_arg_names`.

    Raises:
        KeyError: if `kwargs` is non-empty but lacks one of `left_arg_names`.
    """
    if not kwargs:
        return args
    trailing = tuple(kwargs[name] for name in left_arg_names)
    args += trailing
    return args
|
270
|
+
|
271
|
+
@staticmethod
def permute_values(
    arg_names: Sequence[str],
    values: tuple[ValueType, ...],
    kwarg_names: tuple[str, ...],
) -> tuple[ValueType, ...]:
    """Reorder call values to match the method signature.

    The last `len(kwarg_names)` entries of `values` are the keyword
    arguments, named by `kwarg_names`; everything before them is
    positional. The keyword values are placed in the order given by
    `arg_names`.

    Args:
        arg_names: the argument names
        values: the values tuple (should not contain method itself)
        kwarg_names: the keyword argument names

    Returns:
        tuple[ValueType, ...]: the positional values followed by the keyword
            values in signature order.
    """
    split = len(values) - len(kwarg_names)
    positionals = values[:split]
    kwargs = dict(zip(kwarg_names, values[split:])) if kwarg_names else None
    # `+ 1` skips the leading self slot of the signature
    return BaseInterpreter.get_args(
        arg_names[len(positionals) + 1 :], positionals, kwargs
    )
|
297
|
+
|
298
|
+
def eval_stmt(
    self, frame: FrameType, stmt: Statement
) -> StatementResult[ValueType]:
    """Run a statement within the current frame.

    This is the entry point of running a statement. The implementation is
    looked up through
    [`lookup_registry`][kirin.interp.base.BaseInterpreter.lookup_registry];
    when none is found,
    [`eval_stmt_fallback`][kirin.interp.base.BaseInterpreter.eval_stmt_fallback]
    is called instead.

    Args:
        frame: the current frame
        stmt: the statement to run

    Returns:
        StatementResult: the result of running the statement

    Raises:
        InterpreterError: in debug mode, if the implementation returns
            something that is neither a tuple nor a `SpecialValue`.

    Note:
        Overload this method for the following reasons:
        - to change the source tracking information
        - to take control of how to run a statement
        - to change the implementation lookup behavior in ways that cannot be
            achieved by overloading
            [`lookup_registry`][kirin.interp.base.BaseInterpreter.lookup_registry]

    Example:
        * implement an interpreter that only handles MyStmt:
        ```python
            class MyInterpreter(BaseInterpreter):
                ...
                def eval_stmt(self, frame: FrameType, stmt: Statement) -> StatementResult[ValueType]:
                    if isinstance(stmt, MyStmt):
                        return self.run_my_stmt(frame, stmt)
                    else:
                        return ()
        ```

    """
    # TODO: update tracking information
    impl = self.lookup_registry(frame, stmt)
    if impl is None:
        return self.eval_stmt_fallback(frame, stmt)

    results = impl(self, frame, stmt)
    # the result shape is only validated in debug mode
    if self.debug and not isinstance(results, (tuple, SpecialValue)):
        raise InterpreterError(
            f"method must return tuple or SpecialResult, got {results}"
        )
    return results
|
343
|
+
|
344
|
+
@deprecated("use eval_stmt_fallback instead")
|
345
|
+
def run_stmt_fallback(
|
346
|
+
self, frame: FrameType, stmt: Statement
|
347
|
+
) -> StatementResult[ValueType]:
|
348
|
+
return self.eval_stmt_fallback(frame, stmt)
|
349
|
+
|
350
|
+
def eval_stmt_fallback(
    self, frame: FrameType, stmt: Statement
) -> StatementResult[ValueType]:
    """Fallback used when no implementation is found for a statement.

    The default behavior is to raise; overload this method to provide a
    fallback implementation for statements.

    Args:
        frame: the current frame
        stmt: the statement to run

    Returns:
        StatementResult: the result of running the statement

    Raises:
        InterpreterError: always, naming the statement and the interpreter type.
    """
    # NOTE: not using f-string here because 3.10 and 3.11 have
    # parser bug that doesn't allow f-string in raise statement
    message = "".join(
        (
            "no implementation for stmt ",
            stmt.print_str(end=""),
            " from ",
            str(type(self)),
        )
    )
    raise InterpreterError(message)
|
375
|
+
|
376
|
+
def eval_recursion_limit(self, frame: FrameType) -> tuple[FrameType, ValueType]:
|
377
|
+
"""Return the value of recursion exception, e.g in concrete
|
378
|
+
interpreter, it will raise an exception if the limit is reached;
|
379
|
+
in type inference, it will return a special value.
|
380
|
+
"""
|
381
|
+
raise InterpreterError("maximum recursion depth exceeded")
|
382
|
+
|
383
|
+
def build_signature(self, frame: FrameType, stmt: Statement) -> "Signature":
|
384
|
+
"""build signature for querying the statement implementation."""
|
385
|
+
return Signature(stmt.__class__, tuple(arg.type for arg in stmt.args))
|
386
|
+
|
387
|
+
def lookup_registry(
    self, frame: FrameType, stmt: Statement
) -> Optional["StatementImpl[Self, FrameType]"]:
    """Look up the statement implementation in the registry.

    The signature built by `build_signature` is tried first. If it is not
    registered, a signature made from the statement class alone is tried.

    Args:
        frame: the current frame
        stmt: the statement to run

    Returns:
        Optional[StatementImpl]: the statement implementation if found, None otherwise.
    """
    exact = self.build_signature(frame, stmt)
    if exact in self.registry.statements:
        return self.registry.statements[exact]

    # fall back to an implementation registered for the statement class only
    by_class = Signature(stmt.__class__)
    if by_class in self.registry.statements:
        return self.registry.statements[by_class]

    return None
|
405
|
+
|
406
|
+
@abstractmethod
|
407
|
+
def run_ssacfg_region(
|
408
|
+
self, frame: FrameType, region: Region
|
409
|
+
) -> tuple[ValueType, ...] | None | ReturnValue[ValueType]:
|
410
|
+
"""This implements how to run a region with MLIR SSA CFG convention.
|
411
|
+
|
412
|
+
Args:
|
413
|
+
frame: the current frame.
|
414
|
+
region: the region to run.
|
415
|
+
|
416
|
+
Returns:
|
417
|
+
tuple[ValueType, ...] | SpecialValue[ValueType]: the result of running the region.
|
418
|
+
|
419
|
+
when region returns `tuple[ValueType, ...]`, it means the region terminates normally
|
420
|
+
with `YieldValue`. When region returns `ReturnValue`, it means the region terminates
|
421
|
+
and needs to pop the frame. Region cannot return `Successor` because reference to
|
422
|
+
external region is not allowed.
|
423
|
+
"""
|
424
|
+
...
|
425
|
+
|
426
|
+
class FuelResult(Enum):
    """Outcome of a fuel check: whether interpretation may continue."""

    Stop = 0
    Continue = 1


def consume_fuel(self) -> FuelResult:
    """Consume one unit of fuel, if a fuel limit is set.

    Returns:
        FuelResult: `Continue` when there is no fuel limit (`fuel is None`)
            or a unit of fuel was consumed; `Stop` when no fuel is left.
    """
    if self.fuel is None:  # no fuel limit
        return self.FuelResult.Continue

    # `<= 0` rather than `== 0`: a negative budget counts as exhausted instead
    # of being decremented forever without ever stopping.
    if self.fuel <= 0:
        return self.FuelResult.Stop

    self.fuel -= 1
    return self.FuelResult.Continue
|
kirin/interp/concrete.py
ADDED
@@ -0,0 +1,62 @@
|
|
1
|
+
from typing import Any
|
2
|
+
|
3
|
+
from kirin.ir import Block, Region
|
4
|
+
from kirin.ir.method import Method
|
5
|
+
from kirin.ir.nodes.stmt import Statement
|
6
|
+
|
7
|
+
from .base import BaseInterpreter
|
8
|
+
from .frame import Frame
|
9
|
+
from .value import Successor, YieldValue, ReturnValue, SpecialValue
|
10
|
+
from .exceptions import FuelExhaustedError
|
11
|
+
|
12
|
+
|
13
|
+
class Interpreter(BaseInterpreter[Frame[Any], Any]):
    """Concrete interpreter for the IR.

    Evaluates the IR by executing its statements one by one, keeping the
    values in `Frame` objects on the interpreter's frame stack.
    """

    keys = ["main"]
    void = None

    def new_frame(self, code: Statement) -> Frame[Any]:
        """Create a fresh `Frame` for the given statement."""
        return Frame.from_func_like(code)

    def run_method(
        self, method: Method, args: tuple[Any, ...]
    ) -> tuple[Frame[Any], Any]:
        """Run `method.code` as a callable, passing the method itself as the first argument."""
        return self.run_callable(method.code, (method,) + args)

    def run_ssacfg_region(
        self, frame: Frame[Any], region: Region
    ) -> tuple[Any, ...] | None | ReturnValue[Any]:
        """Run a region starting at its first block, following successors.

        Returns:
            The `ReturnValue` itself when a block ends with one, the yielded
            values for a `YieldValue`, and otherwise whatever `run_block`
            produced (`None` for a block without a terminator).
        """
        block = region.blocks[0]
        while block is not None:
            outcome = self.run_block(frame, block)
            if not isinstance(outcome, Successor):
                if isinstance(outcome, ReturnValue):
                    return outcome
                if isinstance(outcome, YieldValue):
                    return outcome.values
                return outcome
            # branch: move to the successor block and bind its block arguments
            block = outcome.block
            frame.set_values(block.args, outcome.block_args)
        return None  # no block left to run

    def run_block(self, frame: Frame[Any], block: Block) -> SpecialValue[Any]:
        """Run the statements of a block in order within `frame`.

        Returns:
            The first statement result that is neither a tuple nor `None`
            (the terminator's result), or `None` if the block finishes
            without producing one.

        Raises:
            FuelExhaustedError: if the fuel runs out before a statement is run.
        """
        for stmt in block.stmts:
            if self.consume_fuel() == self.FuelResult.Stop:
                raise FuelExhaustedError("fuel exhausted")

            # record the current position on the frame
            frame.stmt = stmt
            source = stmt.source
            frame.lino = source.lineno if source else 0

            stmt_results = self.eval_stmt(frame, stmt)
            if isinstance(stmt_results, tuple):
                frame.set_values(stmt._results, stmt_results)
            elif stmt_results is not None:  # terminator
                return stmt_results
            # a `None` result is an empty result: carry on with the next statement
        return None
|
@@ -0,0 +1,26 @@
|
|
1
|
+
from dataclasses import dataclass
|
2
|
+
|
3
|
+
|
4
|
+
# errors
|
5
|
+
class InterpreterError(Exception):
    """Generic interpreter error.

    This is the base class for all interpreter errors. Interpreter
    errors are caught by the interpreter and handled as an error with
    the stack trace (of Kirin, not Python) from the interpreter.
    """


@dataclass
class WrapException(InterpreterError):
    """A special interpreter error that wraps a Python exception."""

    exception: Exception


class FuelExhaustedError(InterpreterError):
    """An error raised when the interpreter runs out of fuel."""
|
kirin/interp/frame.py
ADDED
@@ -0,0 +1,151 @@
|
|
1
|
+
from abc import ABC, abstractmethod
|
2
|
+
from typing import Any, Generic, TypeVar, Iterable
|
3
|
+
from dataclasses import field, dataclass
|
4
|
+
|
5
|
+
from typing_extensions import Self
|
6
|
+
|
7
|
+
from kirin.ir import SSAValue, Statement
|
8
|
+
|
9
|
+
from .exceptions import InterpreterError
|
10
|
+
|
11
|
+
ValueType = TypeVar("ValueType")
|
12
|
+
|
13
|
+
|
14
|
+
@dataclass
class FrameABC(ABC, Generic[ValueType]):
    """Abstract base class for interpreter frame."""

    code: Statement
    """The func statement being interpreted."""

    @classmethod
    @abstractmethod
    def from_func_like(cls, code: Statement) -> Self:
        """Create a new frame for the given statement."""
        ...

    @abstractmethod
    def get(self, key: SSAValue) -> ValueType:
        """Look up the value bound to an [`SSAValue`][kirin.ir.SSAValue] key.
        See also [`get_values`][kirin.interp.frame.Frame.get_values].

        Args:
            key(SSAValue): The key to get the value for.

        Returns:
            ValueType: The value.
        """
        ...

    @abstractmethod
    def set(self, key: SSAValue, value: ValueType) -> None:
        """Bind a value to an [`SSAValue`][kirin.ir.SSAValue] key.
        See also [`set_values`][kirin.interp.frame.Frame.set_values].

        Args:
            key(SSAValue): The key to set the value for.
            value(ValueType): The value.
        """
        ...

    def get_values(self, keys: Iterable[SSAValue]) -> tuple[ValueType, ...]:
        """Look up several [`SSAValue`][kirin.ir.SSAValue] keys at once.
        See also [`get`][kirin.interp.frame.Frame.get].

        Args:
            keys(Iterable[SSAValue]): The keys to get the values for.

        Returns:
            tuple[ValueType, ...]: The values, in the order of `keys`.
        """
        return tuple(map(self.get, keys))

    def set_values(self, keys: Iterable[SSAValue], values: Iterable[ValueType]) -> None:
        """Bind several [`SSAValue`][kirin.ir.SSAValue] keys at once.

        Keys and values are paired positionally with `zip`, so pairing stops
        at the end of the shorter iterable.

        Args:
            keys(Iterable[SSAValue]): The keys to set the values for.
            values(Iterable[ValueType]): The values.
        """
        for ssa, bound in zip(keys, values):
            self.set(ssa, bound)

    @abstractmethod
    def set_stmt(self, stmt: Statement) -> Self:
        """Set the current statement."""
        ...
|
79
|
+
|
80
|
+
|
81
|
+
@dataclass
class Frame(FrameABC[ValueType]):
    """Interpreter frame."""

    lino: int = 0
    stmt: Statement | None = None
    """statement being interpreted.
    """

    globals: dict[str, Any] = field(default_factory=dict)
    """Global variables this frame has access to.
    """

    # NOTE: we are sharing the same frame within blocks
    # this is because we are validating e.g SSA value pointing
    # to other blocks separately. This avoids the need
    # to have a separate frame for each block.
    entries: dict[SSAValue, ValueType] = field(default_factory=dict)
    """SSA values and their corresponding values.
    """

    @classmethod
    def from_func_like(cls, code: Statement) -> Self:
        """Create a new frame for the given statement."""
        return cls(code=code)

    def get(self, key: SSAValue) -> ValueType:
        """Get the value for the given [`SSAValue`][kirin.ir.SSAValue].

        Args:
            key(SSAValue): The key to get the value for.

        Returns:
            ValueType: The value.

        Raises:
            InterpreterError: If the value is not found. This will be caught by the interpreter.
        """
        # Look the key up directly: the error (and the formatting of `key`) is
        # only built on a miss, and a stored value that happens to be an
        # InterpreterError instance is returned rather than reported as missing.
        try:
            return self.entries[key]
        except KeyError:
            raise InterpreterError(f"SSAValue {key} not found") from None

    ExpectedType = TypeVar("ExpectedType")

    def get_typed(self, key: SSAValue, type_: type[ExpectedType]) -> ExpectedType:
        """Similar to [`get`][kirin.interp.frame.Frame.get] but also checks the type.

        Args:
            key(SSAValue): The key to get the value for.
            type_(type): The expected type.

        Returns:
            ExpectedType: The value.

        Raises:
            InterpreterError: If the value is not found or is not of the expected type.
        """
        value = self.get(key)
        if not isinstance(value, type_):
            raise InterpreterError(f"expected {type_}, got {type(value)}")
        return value

    def set(self, key: SSAValue, value: ValueType) -> None:
        """Bind `value` to `key` in this frame's entries, replacing any previous binding."""
        self.entries[key] = value

    def set_stmt(self, stmt: Statement) -> Self:
        """Record `stmt` as the statement being interpreted and return the frame."""
        self.stmt = stmt
        return self
|