kirin-toolchain 0.13.0__py3-none-any.whl

Sign up to get free protection for your applications and to get access to all the features.
Files changed (225) hide show
  1. kirin/__init__.py +7 -0
  2. kirin/analysis/__init__.py +24 -0
  3. kirin/analysis/callgraph.py +61 -0
  4. kirin/analysis/cfg.py +112 -0
  5. kirin/analysis/const/__init__.py +20 -0
  6. kirin/analysis/const/_visitor.py +2 -0
  7. kirin/analysis/const/_visitor.pyi +8 -0
  8. kirin/analysis/const/lattice.py +219 -0
  9. kirin/analysis/const/prop.py +116 -0
  10. kirin/analysis/forward.py +100 -0
  11. kirin/analysis/typeinfer/__init__.py +5 -0
  12. kirin/analysis/typeinfer/analysis.py +90 -0
  13. kirin/analysis/typeinfer/solve.py +141 -0
  14. kirin/decl/__init__.py +108 -0
  15. kirin/decl/base.py +65 -0
  16. kirin/decl/camel2snake.py +2 -0
  17. kirin/decl/emit/__init__.py +0 -0
  18. kirin/decl/emit/_create_fn.py +29 -0
  19. kirin/decl/emit/_set_new_attribute.py +22 -0
  20. kirin/decl/emit/dialect.py +8 -0
  21. kirin/decl/emit/init.py +277 -0
  22. kirin/decl/emit/name.py +10 -0
  23. kirin/decl/emit/property.py +182 -0
  24. kirin/decl/emit/repr.py +31 -0
  25. kirin/decl/emit/traits.py +13 -0
  26. kirin/decl/emit/typecheck.py +77 -0
  27. kirin/decl/emit/verify.py +51 -0
  28. kirin/decl/info.py +346 -0
  29. kirin/decl/scan_fields.py +157 -0
  30. kirin/decl/verify.py +69 -0
  31. kirin/dialects/__init__.py +14 -0
  32. kirin/dialects/_pprint_helper.py +53 -0
  33. kirin/dialects/cf/__init__.py +20 -0
  34. kirin/dialects/cf/constprop.py +51 -0
  35. kirin/dialects/cf/dialect.py +3 -0
  36. kirin/dialects/cf/emit.py +58 -0
  37. kirin/dialects/cf/interp.py +24 -0
  38. kirin/dialects/cf/stmts.py +68 -0
  39. kirin/dialects/cf/typeinfer.py +27 -0
  40. kirin/dialects/eltype.py +23 -0
  41. kirin/dialects/func/__init__.py +20 -0
  42. kirin/dialects/func/attrs.py +39 -0
  43. kirin/dialects/func/constprop.py +138 -0
  44. kirin/dialects/func/dialect.py +3 -0
  45. kirin/dialects/func/emit.py +80 -0
  46. kirin/dialects/func/interp.py +68 -0
  47. kirin/dialects/func/stmts.py +233 -0
  48. kirin/dialects/func/typeinfer.py +124 -0
  49. kirin/dialects/ilist/__init__.py +33 -0
  50. kirin/dialects/ilist/_dialect.py +3 -0
  51. kirin/dialects/ilist/_wrapper.py +51 -0
  52. kirin/dialects/ilist/interp.py +85 -0
  53. kirin/dialects/ilist/lowering.py +25 -0
  54. kirin/dialects/ilist/passes.py +32 -0
  55. kirin/dialects/ilist/rewrite/__init__.py +3 -0
  56. kirin/dialects/ilist/rewrite/const.py +45 -0
  57. kirin/dialects/ilist/rewrite/list.py +38 -0
  58. kirin/dialects/ilist/rewrite/unroll.py +131 -0
  59. kirin/dialects/ilist/runtime.py +63 -0
  60. kirin/dialects/ilist/stmts.py +102 -0
  61. kirin/dialects/ilist/typeinfer.py +120 -0
  62. kirin/dialects/lowering/__init__.py +7 -0
  63. kirin/dialects/lowering/call.py +48 -0
  64. kirin/dialects/lowering/cf.py +206 -0
  65. kirin/dialects/lowering/func.py +134 -0
  66. kirin/dialects/math/__init__.py +41 -0
  67. kirin/dialects/math/_gen.py +176 -0
  68. kirin/dialects/math/dialect.py +3 -0
  69. kirin/dialects/math/interp.py +190 -0
  70. kirin/dialects/math/stmts.py +369 -0
  71. kirin/dialects/module.py +139 -0
  72. kirin/dialects/py/__init__.py +40 -0
  73. kirin/dialects/py/assertion.py +91 -0
  74. kirin/dialects/py/assign.py +103 -0
  75. kirin/dialects/py/attr.py +59 -0
  76. kirin/dialects/py/base.py +34 -0
  77. kirin/dialects/py/binop/__init__.py +23 -0
  78. kirin/dialects/py/binop/_dialect.py +3 -0
  79. kirin/dialects/py/binop/interp.py +60 -0
  80. kirin/dialects/py/binop/julia.py +33 -0
  81. kirin/dialects/py/binop/lowering.py +22 -0
  82. kirin/dialects/py/binop/stmts.py +79 -0
  83. kirin/dialects/py/binop/typeinfer.py +108 -0
  84. kirin/dialects/py/boolop.py +84 -0
  85. kirin/dialects/py/builtin.py +78 -0
  86. kirin/dialects/py/cmp/__init__.py +16 -0
  87. kirin/dialects/py/cmp/_dialect.py +3 -0
  88. kirin/dialects/py/cmp/interp.py +48 -0
  89. kirin/dialects/py/cmp/julia.py +33 -0
  90. kirin/dialects/py/cmp/lowering.py +45 -0
  91. kirin/dialects/py/cmp/stmts.py +62 -0
  92. kirin/dialects/py/constant.py +79 -0
  93. kirin/dialects/py/indexing.py +251 -0
  94. kirin/dialects/py/iterable.py +90 -0
  95. kirin/dialects/py/len.py +57 -0
  96. kirin/dialects/py/list/__init__.py +15 -0
  97. kirin/dialects/py/list/_dialect.py +3 -0
  98. kirin/dialects/py/list/interp.py +21 -0
  99. kirin/dialects/py/list/lowering.py +25 -0
  100. kirin/dialects/py/list/stmts.py +22 -0
  101. kirin/dialects/py/list/typeinfer.py +54 -0
  102. kirin/dialects/py/range.py +76 -0
  103. kirin/dialects/py/slice.py +120 -0
  104. kirin/dialects/py/tuple.py +109 -0
  105. kirin/dialects/py/unary/__init__.py +24 -0
  106. kirin/dialects/py/unary/_dialect.py +3 -0
  107. kirin/dialects/py/unary/constprop.py +20 -0
  108. kirin/dialects/py/unary/interp.py +24 -0
  109. kirin/dialects/py/unary/julia.py +21 -0
  110. kirin/dialects/py/unary/lowering.py +22 -0
  111. kirin/dialects/py/unary/stmts.py +33 -0
  112. kirin/dialects/py/unary/typeinfer.py +23 -0
  113. kirin/dialects/py/unpack.py +90 -0
  114. kirin/dialects/scf/__init__.py +23 -0
  115. kirin/dialects/scf/_dialect.py +3 -0
  116. kirin/dialects/scf/absint.py +64 -0
  117. kirin/dialects/scf/constprop.py +140 -0
  118. kirin/dialects/scf/interp.py +35 -0
  119. kirin/dialects/scf/lowering.py +123 -0
  120. kirin/dialects/scf/stmts.py +250 -0
  121. kirin/dialects/scf/trim.py +36 -0
  122. kirin/dialects/scf/typeinfer.py +58 -0
  123. kirin/dialects/scf/unroll.py +92 -0
  124. kirin/emit/__init__.py +3 -0
  125. kirin/emit/abc.py +89 -0
  126. kirin/emit/abc.pyi +38 -0
  127. kirin/emit/exceptions.py +5 -0
  128. kirin/emit/julia.py +63 -0
  129. kirin/emit/str.py +51 -0
  130. kirin/exceptions.py +59 -0
  131. kirin/graph.py +34 -0
  132. kirin/idtable.py +57 -0
  133. kirin/interp/__init__.py +39 -0
  134. kirin/interp/abstract.py +253 -0
  135. kirin/interp/base.py +438 -0
  136. kirin/interp/concrete.py +62 -0
  137. kirin/interp/exceptions.py +26 -0
  138. kirin/interp/frame.py +151 -0
  139. kirin/interp/impl.py +197 -0
  140. kirin/interp/result.py +93 -0
  141. kirin/interp/state.py +71 -0
  142. kirin/interp/table.py +40 -0
  143. kirin/interp/value.py +73 -0
  144. kirin/ir/__init__.py +46 -0
  145. kirin/ir/attrs/__init__.py +20 -0
  146. kirin/ir/attrs/_types.py +8 -0
  147. kirin/ir/attrs/_types.pyi +13 -0
  148. kirin/ir/attrs/abc.py +46 -0
  149. kirin/ir/attrs/py.py +45 -0
  150. kirin/ir/attrs/types.py +522 -0
  151. kirin/ir/dialect.py +125 -0
  152. kirin/ir/group.py +249 -0
  153. kirin/ir/method.py +118 -0
  154. kirin/ir/nodes/__init__.py +7 -0
  155. kirin/ir/nodes/base.py +149 -0
  156. kirin/ir/nodes/block.py +458 -0
  157. kirin/ir/nodes/region.py +337 -0
  158. kirin/ir/nodes/stmt.py +713 -0
  159. kirin/ir/nodes/view.py +142 -0
  160. kirin/ir/ssa.py +204 -0
  161. kirin/ir/traits/__init__.py +36 -0
  162. kirin/ir/traits/abc.py +42 -0
  163. kirin/ir/traits/basic.py +78 -0
  164. kirin/ir/traits/callable.py +51 -0
  165. kirin/ir/traits/lowering/__init__.py +2 -0
  166. kirin/ir/traits/lowering/call.py +37 -0
  167. kirin/ir/traits/lowering/context.py +120 -0
  168. kirin/ir/traits/region/__init__.py +2 -0
  169. kirin/ir/traits/region/ssacfg.py +22 -0
  170. kirin/ir/traits/symbol.py +57 -0
  171. kirin/ir/use.py +17 -0
  172. kirin/lattice/__init__.py +13 -0
  173. kirin/lattice/abc.py +128 -0
  174. kirin/lattice/empty.py +25 -0
  175. kirin/lattice/mixin.py +51 -0
  176. kirin/lowering/__init__.py +7 -0
  177. kirin/lowering/binding.py +65 -0
  178. kirin/lowering/core.py +72 -0
  179. kirin/lowering/dialect.py +35 -0
  180. kirin/lowering/dialect.pyi +183 -0
  181. kirin/lowering/frame.py +171 -0
  182. kirin/lowering/result.py +68 -0
  183. kirin/lowering/state.py +441 -0
  184. kirin/lowering/stream.py +53 -0
  185. kirin/passes/__init__.py +3 -0
  186. kirin/passes/abc.py +44 -0
  187. kirin/passes/aggressive/__init__.py +1 -0
  188. kirin/passes/aggressive/fold.py +43 -0
  189. kirin/passes/fold.py +45 -0
  190. kirin/passes/inline.py +25 -0
  191. kirin/passes/typeinfer.py +25 -0
  192. kirin/prelude.py +197 -0
  193. kirin/print/__init__.py +15 -0
  194. kirin/print/printable.py +141 -0
  195. kirin/print/printer.py +415 -0
  196. kirin/py.typed +0 -0
  197. kirin/registry.py +105 -0
  198. kirin/registry.pyi +52 -0
  199. kirin/rewrite/__init__.py +14 -0
  200. kirin/rewrite/abc.py +43 -0
  201. kirin/rewrite/aggressive/__init__.py +1 -0
  202. kirin/rewrite/aggressive/fold.py +43 -0
  203. kirin/rewrite/alias.py +16 -0
  204. kirin/rewrite/apply_type.py +47 -0
  205. kirin/rewrite/call2invoke.py +34 -0
  206. kirin/rewrite/chain.py +39 -0
  207. kirin/rewrite/compactify.py +288 -0
  208. kirin/rewrite/cse.py +48 -0
  209. kirin/rewrite/dce.py +19 -0
  210. kirin/rewrite/fixpoint.py +34 -0
  211. kirin/rewrite/fold.py +57 -0
  212. kirin/rewrite/getfield.py +21 -0
  213. kirin/rewrite/getitem.py +37 -0
  214. kirin/rewrite/inline.py +143 -0
  215. kirin/rewrite/result.py +15 -0
  216. kirin/rewrite/walk.py +83 -0
  217. kirin/rewrite/wrap_const.py +55 -0
  218. kirin/source.py +21 -0
  219. kirin/symbol_table.py +27 -0
  220. kirin/types.py +34 -0
  221. kirin/worklist.py +30 -0
  222. kirin_toolchain-0.13.0.dist-info/METADATA +42 -0
  223. kirin_toolchain-0.13.0.dist-info/RECORD +225 -0
  224. kirin_toolchain-0.13.0.dist-info/WHEEL +4 -0
  225. kirin_toolchain-0.13.0.dist-info/licenses/LICENSE +234 -0
kirin/interp/base.py ADDED
@@ -0,0 +1,438 @@
1
+ import sys
2
+ from abc import ABC, ABCMeta, abstractmethod
3
+ from enum import Enum
4
+ from typing import TYPE_CHECKING, Generic, TypeVar, ClassVar, Optional, Sequence
5
+ from dataclasses import field, dataclass
6
+
7
+ from typing_extensions import Self, deprecated
8
+
9
+ from kirin.ir import Block, Region, Statement, DialectGroup, traits
10
+ from kirin.ir.method import Method
11
+
12
+ from .impl import Signature
13
+ from .frame import FrameABC
14
+ from .state import InterpreterState
15
+ from .value import ReturnValue, SpecialValue, StatementResult
16
+ from .result import Ok, Err, Result
17
+ from .exceptions import InterpreterError
18
+
19
+ if TYPE_CHECKING:
20
+ from kirin.registry import StatementImpl, InterpreterRegistry
21
+
22
# Type of the runtime values an interpreter handles (arguments, results, `void`).
ValueType = TypeVar("ValueType")
# Type of the frames an interpreter pushes on its state; must subclass FrameABC.
FrameType = TypeVar("FrameType", bound=FrameABC)
24
+
25
+
26
class InterpreterMeta(ABCMeta):
    """Metaclass shared by every interpreter class.

    It currently adds no behavior of its own on top of ``ABCMeta``; it exists
    as a common hook point for interpreter classes.
    """
30
+
31
+
32
@dataclass
class BaseInterpreter(ABC, Generic[FrameType, ValueType], metaclass=InterpreterMeta):
    """A base class for interpreters.

    This class defines the basic structure of an interpreter. It is
    designed to be subclassed to provide the actual implementation of
    the interpreter.

    ### Required Overrides
    When subclassing, if `ABC` is not one of the subclass's direct bases,
    the subclass must define the following attributes (this is checked in
    `__init_subclass__`):

    - `keys`: a list of strings that defines the order of dialects to select from.
    - `void`: the value to return when the interpreter evaluates nothing.
    """

    keys: ClassVar[list[str]]
    """The name of the interpreter to select from dialects by order.
    """
    void: ValueType = field(init=False)
    """What to return when the interpreter evaluates nothing.
    """
    dialects: DialectGroup
    """The dialects to interpret.
    """
    fuel: int | None = field(default=None, kw_only=True)
    """The fuel limit for the interpreter. Each call to `consume_fuel`
    spends one unit; `None` means no limit.
    """
    debug: bool = field(default=False, kw_only=True)
    """Whether to enable debug mode.
    """
    max_depth: int = field(default=128, kw_only=True)
    """The maximum depth of the interpreter stack.
    """
    max_python_recursion_depth: int = field(default=8192, kw_only=True)
    """The maximum recursion depth of the Python interpreter.
    """

    # global states
    registry: "InterpreterRegistry" = field(init=False, compare=False)
    """The interpreter registry.
    """
    symbol_table: dict[str, Statement] = field(init=False, compare=False)
    """The symbol table.
    """
    state: InterpreterState[FrameType] = field(init=False, compare=False)
    """The interpreter state.
    """

    # private
    # Set while `run` is executing; used to reject re-entrant `run` calls.
    _eval_lock: bool = field(default=False, init=False, compare=False)

    def __post_init__(self) -> None:
        # Build the registry that `lookup_registry` consults, selecting
        # implementations from the dialects by this interpreter's `keys`.
        self.registry = self.dialects.registry.interpreter(keys=self.keys)

    def initialize(self) -> Self:
        """Initialize the interpreter global states. This method is called right upon
        calling [`run`][kirin.interp.base.BaseInterpreter.run] to initialize the
        interpreter global states.

        !!! note "Default Implementation"
            This method provides default behavior but may be overridden by subclasses
            to customize or extend functionality.
        """
        self.symbol_table: dict[str, Statement] = {}
        self.state: InterpreterState[FrameType] = InterpreterState()
        return self

    def __init_subclass__(cls) -> None:
        super().__init_subclass__()
        # Classes that directly list `ABC` as a base are treated as further
        # abstract layers and are not required to define `keys`/`void`.
        if ABC in cls.__bases__:
            return

        if not hasattr(cls, "keys"):
            raise TypeError(f"keys is not defined for class {cls.__name__}")
        if not hasattr(cls, "void"):
            raise TypeError(f"void is not defined for class {cls.__name__}")

    @deprecated("use run instead")
    def eval(
        self,
        mt: Method,
        args: tuple[ValueType, ...],
        kwargs: dict[str, ValueType] | None = None,
    ) -> Result[ValueType]:
        """Deprecated alias of [`run`][kirin.interp.base.BaseInterpreter.run]."""
        return self.run(mt, args, kwargs)

    def run(
        self,
        mt: Method,
        args: tuple[ValueType, ...],
        kwargs: dict[str, ValueType] | None = None,
    ) -> Result[ValueType]:
        """Run a method. This is the main entry point of the interpreter.

        The interpreter state is re-initialized, the Python recursion limit is
        temporarily set to `max_python_recursion_depth`, and both the
        re-entrancy lock and the recursion limit are restored on every exit
        path, including when setup itself raises.

        Args:
            mt (Method): the method to run.
            args (tuple[ValueType]): the arguments to the method, does not include self.
            kwargs (dict[str, ValueType], optional): the keyword arguments to the method.

        Returns:
            Result[ValueType]: `Ok` with the method's result, or `Err` carrying the
            `InterpreterError` raised by `run_method` and the interpreter frames
            at the time of the error.

        Raises:
            InterpreterError: if `run` is called while another `run` on the same
                interpreter is still executing.
            KeyError: if `kwargs` is non-empty but lacks a required argument.
        """
        if self._eval_lock:
            raise InterpreterError(
                "recursive eval is not allowed, use run_method instead"
            )

        current_recursion_limit = sys.getrecursionlimit()
        self._eval_lock = True
        try:
            self.initialize()
            sys.setrecursionlimit(self.max_python_recursion_depth)
            # +1 skips the method's own `self` argument, which is not user input.
            args = self.get_args(mt.arg_names[len(args) + 1 :], args, kwargs)
            try:
                _, results = self.run_method(mt, args)
            except InterpreterError as e:
                # NOTE: initialize will create new State
                # so we don't need to copy the frames.
                return Err(e, self.state.frames)
        finally:
            # Always release the lock and restore the caller's recursion limit,
            # even if initialization or argument binding failed.
            self._eval_lock = False
            sys.setrecursionlimit(current_recursion_limit)
        return Ok(results)

    def run_stmt(
        self, stmt: Statement, args: tuple[ValueType, ...]
    ) -> StatementResult[ValueType]:
        """execute a statement with arguments in a new frame.

        Args:
            stmt (Statement): the statement to run.
            args (tuple[ValueType]): the arguments to the statement.

        Returns:
            StatementResult[ValueType]: the result of the statement.
        """
        frame = self.new_frame(stmt)
        self.state.push_frame(frame)
        frame.set_values(stmt.args, args)
        results = self.eval_stmt(frame, stmt)
        # NOTE(review): the frame is only popped on success; if `eval_stmt`
        # raises, the frame stays on the stack (presumably so the error report
        # built by `run` still sees it).
        self.state.pop_frame()
        return results

    @abstractmethod
    def run_method(
        self, method: Method, args: tuple[ValueType, ...]
    ) -> tuple[FrameType, ValueType]:
        """How to run a method.

        This is defined by subclasses to describe what's the corresponding
        value of a method during the interpretation. Usually, this method
        just calls [`run_callable`][kirin.interp.base.BaseInterpreter.run_callable].

        Args:
            method (Method): the method to run.
            args (tuple[ValueType]): the arguments to the method, does not include self.

        Returns:
            tuple[FrameType, ValueType]: the frame the method ran in and its result.
        """
        ...

    def run_callable(
        self, code: Statement, args: tuple[ValueType, ...]
    ) -> tuple[FrameType, ValueType]:
        """Run a callable statement.

        Args:
            code (Statement): the statement to run.
            args (tuple[ValueType]): the arguments to the statement,
                includes self if the corresponding callable region contains a self argument.

        Returns:
            tuple[FrameType, ValueType]: the popped frame and the result of the
            statement (`void` when the callable region has no blocks).

        Raises:
            InterpreterError: if `code` does not implement `CallableStmtInterface`.
        """
        # Delegate to the subclass hook once the interpreter stack is too deep.
        if len(self.state.frames) >= self.max_depth:
            return self.eval_recursion_limit(self.state.current_frame())

        interface = code.get_trait(traits.CallableStmtInterface)
        if interface is None:
            raise InterpreterError(f"statement {code.name} is not callable")

        frame = self.new_frame(code)
        self.state.push_frame(frame)
        body = interface.get_callable_region(code)
        if not body.blocks:
            return self.state.pop_frame(), self.void
        # Arguments bind to the entry block's arguments.
        frame.set_values(body.blocks[0].args, args)
        results = self.run_callable_region(frame, code, body)
        return self.state.pop_frame(), results

    def run_callable_region(
        self, frame: FrameType, code: Statement, region: Region
    ) -> ValueType:
        """A hook defines how to run the callable region given
        the interpreter context. Frame should be pushed before calling
        this method and popped after calling this method.

        A callable region is a region that can be called as a function.
        Unlike a general region (or the MLIR convention), it always return a value
        to be compatible with the Python convention.

        Raises:
            InterpreterError: if the region finishes with a non-empty result
                that is not a `ReturnValue`.
        """
        results = self.run_ssacfg_region(frame, region)
        if isinstance(results, ReturnValue):
            return results.value
        elif not results:  # empty result or None
            return self.void
        raise InterpreterError(
            f"callable region {code.name} does not return `ReturnValue`, got {results}"
        )

    def run_block(self, frame: FrameType, block: Block) -> SpecialValue[ValueType]:
        """Run a block within the current frame.

        The base implementation does nothing and returns `None`; interpreters
        that execute blocks are expected to override it.

        Args:
            frame: the current frame.
            block: the block to run.

        Returns:
            SpecialValue: the result of running the block terminator.
        """
        ...

    @abstractmethod
    def new_frame(self, code: Statement) -> FrameType:
        """Create a new frame for the given method."""
        ...

    @staticmethod
    def get_args(
        left_arg_names: Sequence[str],
        args: tuple[ValueType, ...],
        kwargs: dict[str, ValueType] | None,
    ) -> tuple[ValueType, ...]:
        """Append keyword-argument values to the positional arguments.

        Args:
            left_arg_names: names of the parameters not covered by `args`, in
                declaration order. Callers in this class compute it by skipping
                the leading `self` parameter and the given positionals.
            args: the positional argument values.
            kwargs: keyword argument values keyed by parameter name, or None.

        Returns:
            `args` extended with `kwargs[name]` for each name in
            `left_arg_names` (unchanged when `kwargs` is empty or None).

        Raises:
            KeyError: if `kwargs` is non-empty and lacks one of `left_arg_names`.
        """
        if kwargs:
            args += tuple(kwargs[name] for name in left_arg_names)
        return args

    @staticmethod
    def permute_values(
        arg_names: Sequence[str],
        values: tuple[ValueType, ...],
        kwarg_names: tuple[str, ...],
    ) -> tuple[ValueType, ...]:
        """Permute the arguments according to the method signature and
        the given keyword arguments, where the keyword argument names
        refer to the last n arguments in the values tuple.

        Args:
            arg_names: the argument names
            values: the values tuple (should not contain method itself)
            kwarg_names: the keyword argument names
        """
        n_total = len(values)
        if kwarg_names:
            kwargs = dict(zip(kwarg_names, values[n_total - len(kwarg_names) :]))
        else:
            kwargs = None

        positionals = values[: n_total - len(kwarg_names)]
        # +1 skips the `self` entry of `arg_names`.
        args = BaseInterpreter.get_args(
            arg_names[len(positionals) + 1 :], positionals, kwargs
        )
        return args

    def eval_stmt(
        self, frame: FrameType, stmt: Statement
    ) -> StatementResult[ValueType]:
        """Run a statement within the current frame. This is the entry
        point of running a statement. It will look up the statement implementation
        in the dialect registry, or optionally call a fallback implementation.

        Args:
            frame: the current frame
            stmt: the statement to run

        Returns:
            StatementResult: the result of running the statement

        Note:
            Overload this method for the following reasons:
            - to change the source tracking information
            - to take control of how to run a statement
            - to change the implementation lookup behavior that cannot achieve
              by overloading [`lookup_registry`][kirin.interp.base.BaseInterpreter.lookup_registry]

        Example:
            * implement an interpreter that only handles MyStmt:
            ```python
                class MyInterpreter(BaseInterpreter):
                    ...
                    def eval_stmt(self, frame: FrameType, stmt: Statement) -> StatementResult[ValueType]:
                        if isinstance(stmt, MyStmt):
                            return self.run_my_stmt(frame, stmt)
                        else:
                            return ()
            ```

        """
        # TODO: update tracking information
        method = self.lookup_registry(frame, stmt)
        if method is not None:
            results = method(self, frame, stmt)
            # In debug mode, validate the implementation's return type.
            if self.debug and not isinstance(results, (tuple, SpecialValue)):
                raise InterpreterError(
                    f"method must return tuple or SpecialValue, got {results}"
                )
            return results

        return self.eval_stmt_fallback(frame, stmt)

    @deprecated("use eval_stmt_fallback instead")
    def run_stmt_fallback(
        self, frame: FrameType, stmt: Statement
    ) -> StatementResult[ValueType]:
        """Deprecated alias of `eval_stmt_fallback`."""
        return self.eval_stmt_fallback(frame, stmt)

    def eval_stmt_fallback(
        self, frame: FrameType, stmt: Statement
    ) -> StatementResult[ValueType]:
        """The fallback implementation of statements.

        This is called when no implementation is found for the statement.

        Args:
            frame: the current frame
            stmt: the statement to run

        Returns:
            StatementResult: the result of running the statement

        Raises:
            InterpreterError: always, in this default implementation.

        Note:
            Overload this method to provide a fallback implementation for statements.
        """
        # NOTE: not using f-string here because 3.10 and 3.11 have
        # parser bug that doesn't allow f-string in raise statement
        raise InterpreterError(
            "no implementation for stmt "
            + stmt.print_str(end="")
            + " from "
            + str(type(self))
        )

    def eval_recursion_limit(self, frame: FrameType) -> tuple[FrameType, ValueType]:
        """Return the value of recursion exception, e.g in concrete
        interpreter, it will raise an exception if the limit is reached;
        in type inference, it will return a special value.

        Raises:
            InterpreterError: always, in this default implementation.
        """
        raise InterpreterError("maximum recursion depth exceeded")

    def build_signature(self, frame: FrameType, stmt: Statement) -> "Signature":
        """build signature for querying the statement implementation."""
        return Signature(stmt.__class__, tuple(arg.type for arg in stmt.args))

    def lookup_registry(
        self, frame: FrameType, stmt: Statement
    ) -> Optional["StatementImpl[Self, FrameType]"]:
        """Lookup the statement implementation in the registry.

        The typed signature from `build_signature` is tried first, then a
        signature keyed on the statement class alone.

        Args:
            frame: the current frame
            stmt: the statement to run

        Returns:
            Optional[StatementImpl]: the statement implementation if found, None otherwise.
        """
        sig = self.build_signature(frame, stmt)
        if sig in self.registry.statements:
            return self.registry.statements[sig]
        elif (class_sig := Signature(stmt.__class__)) in self.registry.statements:
            return self.registry.statements[class_sig]
        return None

    @abstractmethod
    def run_ssacfg_region(
        self, frame: FrameType, region: Region
    ) -> tuple[ValueType, ...] | None | ReturnValue[ValueType]:
        """This implements how to run a region with MLIR SSA CFG convention.

        Args:
            frame: the current frame.
            region: the region to run.

        Returns:
            tuple[ValueType, ...] | SpecialValue[ValueType]: the result of running the region.

        when region returns `tuple[ValueType, ...]`, it means the region terminates normally
        with `YieldValue`. When region returns `ReturnValue`, it means the region terminates
        and needs to pop the frame. Region cannot return `Successor` because reference to
        external region is not allowed.
        """
        ...

    class FuelResult(Enum):
        """Outcome of `consume_fuel`."""

        Stop = 0
        Continue = 1

    def consume_fuel(self) -> FuelResult:
        """Spend one unit of fuel.

        Returns:
            `FuelResult.Continue` when there is no fuel limit or fuel remained
            (in which case one unit is spent), and `FuelResult.Stop` when the
            fuel is exhausted. A non-positive `fuel` counts as exhausted, so a
            negative value can never mean "run forever".
        """
        if self.fuel is None:  # no fuel limit
            return self.FuelResult.Continue

        if self.fuel <= 0:
            return self.FuelResult.Stop
        self.fuel -= 1
        return self.FuelResult.Continue
@@ -0,0 +1,62 @@
1
+ from typing import Any
2
+
3
+ from kirin.ir import Block, Region
4
+ from kirin.ir.method import Method
5
+ from kirin.ir.nodes.stmt import Statement
6
+
7
+ from .base import BaseInterpreter
8
+ from .frame import Frame
9
+ from .value import Successor, YieldValue, ReturnValue, SpecialValue
10
+ from .exceptions import FuelExhaustedError
11
+
12
+
13
class Interpreter(BaseInterpreter[Frame[Any], Any]):
    """Concrete interpreter for the IR.

    This is a concrete interpreter for the IR. It evaluates the IR by
    executing the statements in the IR using a simple stack-based
    interpreter.
    """

    keys = ["main"]
    void = None

    def new_frame(self, code: Statement) -> Frame[Any]:
        """Create an empty frame for the callable statement `code`."""
        return Frame.from_func_like(code)

    def run_method(
        self, method: Method, args: tuple[Any, ...]
    ) -> tuple[Frame[Any], Any]:
        """Run `method`, passing the method object itself as the leading
        (`self`) argument of its callable region."""
        return self.run_callable(method.code, (method,) + args)

    def run_ssacfg_region(
        self, frame: Frame[Any], region: Region
    ) -> tuple[Any, ...] | None | ReturnValue[Any]:
        """Execute `region` block by block, starting from its first block.

        Control moves between blocks through `Successor` results, whose
        arguments are bound to the target block's arguments.

        Returns:
            The `ReturnValue` when a block returns, the yielded values when a
            block yields (`YieldValue`), otherwise whatever `run_block`
            returned. Returns `None` for a region without blocks.
        """
        if not region.blocks:
            # Nothing to execute; avoid indexing into an empty region.
            return None

        block = region.blocks[0]
        while block is not None:
            results = self.run_block(frame, block)
            if isinstance(results, Successor):
                block = results.block
                frame.set_values(block.args, results.block_args)
            elif isinstance(results, ReturnValue):
                return results
            elif isinstance(results, YieldValue):
                return results.values
            else:
                return results
        return None  # a successor led to no block

    def run_block(self, frame: Frame[Any], block: Block) -> SpecialValue[Any]:
        """Execute the statements of `block` in order within `frame`.

        Each statement consumes one unit of fuel. Tuple results are bound to
        the statement's result values, `None` results are skipped, and any
        other result (a terminator's special value) is returned immediately.

        Returns:
            The first non-tuple, non-`None` statement result, or `None`.

        Raises:
            FuelExhaustedError: when the interpreter runs out of fuel.
        """
        for stmt in block.stmts:
            if self.consume_fuel() == self.FuelResult.Stop:
                raise FuelExhaustedError("fuel exhausted")
            # Record the current statement and its source line for error reports.
            frame.stmt = stmt
            frame.lino = stmt.source.lineno if stmt.source else 0
            stmt_results = self.eval_stmt(frame, stmt)
            if isinstance(stmt_results, tuple):
                frame.set_values(stmt._results, stmt_results)
            elif stmt_results is None:
                continue  # empty result
            else:  # terminator
                return stmt_results
        return None
@@ -0,0 +1,26 @@
1
+ from dataclasses import dataclass
2
+
3
+
4
+ # errors
5
class InterpreterError(Exception):
    """Root of the interpreter's exception hierarchy.

    Errors of this type are caught by the interpreter and reported together
    with a stack trace of Kirin frames (not Python frames).
    """
14
+
15
+
16
@dataclass()
class WrapException(InterpreterError):
    """Interpreter error that carries a native Python exception.

    Attributes:
        exception: the wrapped Python exception.
    """

    exception: Exception
21
+
22
+
23
class FuelExhaustedError(InterpreterError):
    """Signals that the interpreter's fuel budget has been used up."""
kirin/interp/frame.py ADDED
@@ -0,0 +1,151 @@
1
+ from abc import ABC, abstractmethod
2
+ from typing import Any, Generic, TypeVar, Iterable
3
+ from dataclasses import field, dataclass
4
+
5
+ from typing_extensions import Self
6
+
7
+ from kirin.ir import SSAValue, Statement
8
+
9
+ from .exceptions import InterpreterError
10
+
11
# Type of the values a frame associates with SSA values.
ValueType = TypeVar("ValueType")
12
+
13
+
14
@dataclass
class FrameABC(ABC, Generic[ValueType]):
    """Interface implemented by every interpreter frame.

    A frame is tied to the callable statement it executes (`code`) and maps
    SSA values to interpreter values via `get`/`set`.
    """

    code: Statement
    """func statement being interpreted.
    """

    @classmethod
    @abstractmethod
    def from_func_like(cls, code: Statement) -> Self:
        """Build a fresh frame for executing `code`."""
        ...

    @abstractmethod
    def get(self, key: SSAValue) -> ValueType:
        """Look up the value bound to one [`SSAValue`][kirin.ir.SSAValue].
        For several keys at once use
        [`get_values`][kirin.interp.frame.FrameABC.get_values].

        Args:
            key(SSAValue): the SSA value to look up.

        Returns:
            ValueType: the bound value.
        """
        ...

    @abstractmethod
    def set(self, key: SSAValue, value: ValueType) -> None:
        """Bind a value to one [`SSAValue`][kirin.ir.SSAValue].
        For several keys at once use
        [`set_values`][kirin.interp.frame.FrameABC.set_values].

        Args:
            key(SSAValue): the SSA value to bind.
            value(ValueType): the value to bind to it.
        """
        ...

    def get_values(self, keys: Iterable[SSAValue]) -> tuple[ValueType, ...]:
        """Look up several [`SSAValue`][kirin.ir.SSAValue] keys, in order,
        by calling [`get`][kirin.interp.frame.FrameABC.get] on each.

        Args:
            keys(Iterable[SSAValue]): the SSA values to look up.

        Returns:
            tuple[ValueType, ...]: the bound values, in the order of `keys`.
        """
        return tuple(map(self.get, keys))

    def set_values(self, keys: Iterable[SSAValue], values: Iterable[ValueType]) -> None:
        """Bind several [`SSAValue`][kirin.ir.SSAValue] keys pairwise to
        `values` by calling [`set`][kirin.interp.frame.FrameABC.set] on each
        pair. Pairing stops at the shorter of the two iterables.

        Args:
            keys(Iterable[SSAValue]): the SSA values to bind.
            values(Iterable[ValueType]): the values to bind to them.
        """
        for pair in zip(keys, values):
            self.set(*pair)

    @abstractmethod
    def set_stmt(self, stmt: Statement) -> Self:
        """Record `stmt` as the statement currently being executed."""
        ...
79
+
80
+
81
+ @dataclass
82
+ class Frame(FrameABC[ValueType]):
83
+ """Interpreter frame."""
84
+
85
+ lino: int = 0
86
+ stmt: Statement | None = None
87
+ """statement being interpreted.
88
+ """
89
+
90
+ globals: dict[str, Any] = field(default_factory=dict)
91
+ """Global variables this frame has access to.
92
+ """
93
+
94
+ # NOTE: we are sharing the same frame within blocks
95
+ # this is because we are validating e.g SSA value pointing
96
+ # to other blocks separately. This avoids the need
97
+ # to have a separate frame for each block.
98
+ entries: dict[SSAValue, ValueType] = field(default_factory=dict)
99
+ """SSA values and their corresponding values.
100
+ """
101
+
102
+ @classmethod
103
+ def from_func_like(cls, code: Statement) -> Self:
104
+ """Create a new frame for the given statement."""
105
+ return cls(code=code)
106
+
107
+ def get(self, key: SSAValue) -> ValueType:
108
+ """Get the value for the given [`SSAValue`][kirin.ir.SSAValue].
109
+
110
+ Args:
111
+ key(SSAValue): The key to get the value for.
112
+
113
+ Returns:
114
+ ValueType: The value.
115
+
116
+ Raises:
117
+ InterpreterError: If the value is not found. This will be catched by the interpreter.
118
+ """
119
+ err = InterpreterError(f"SSAValue {key} not found")
120
+ value = self.entries.get(key, err)
121
+ if isinstance(value, InterpreterError):
122
+ raise err
123
+ else:
124
+ return value
125
+
126
+ ExpectedType = TypeVar("ExpectedType")
127
+
128
+ def get_typed(self, key: SSAValue, type_: type[ExpectedType]) -> ExpectedType:
129
+ """Similar to [`get`][kirin.interp.frame.Frame.get] but also checks the type.
130
+
131
+ Args:
132
+ key(SSAValue): The key to get the value for.
133
+ type_(type): The expected type.
134
+
135
+ Returns:
136
+ ExpectedType: The value.
137
+
138
+ Raises:
139
+ InterpreterError: If the value is not of the expected type.
140
+ """
141
+ value = self.get(key)
142
+ if not isinstance(value, type_):
143
+ raise InterpreterError(f"expected {type_}, got {type(value)}")
144
+ return value
145
+
146
+ def set(self, key: SSAValue, value: ValueType) -> None:
147
+ self.entries[key] = value
148
+
149
+ def set_stmt(self, stmt: Statement) -> Self:
150
+ self.stmt = stmt
151
+ return self