kirin-toolchain 0.13.0__py3-none-any.whl

Sign up to get free protection for your applications and to get access to all the features.
Files changed (225) hide show
  1. kirin/__init__.py +7 -0
  2. kirin/analysis/__init__.py +24 -0
  3. kirin/analysis/callgraph.py +61 -0
  4. kirin/analysis/cfg.py +112 -0
  5. kirin/analysis/const/__init__.py +20 -0
  6. kirin/analysis/const/_visitor.py +2 -0
  7. kirin/analysis/const/_visitor.pyi +8 -0
  8. kirin/analysis/const/lattice.py +219 -0
  9. kirin/analysis/const/prop.py +116 -0
  10. kirin/analysis/forward.py +100 -0
  11. kirin/analysis/typeinfer/__init__.py +5 -0
  12. kirin/analysis/typeinfer/analysis.py +90 -0
  13. kirin/analysis/typeinfer/solve.py +141 -0
  14. kirin/decl/__init__.py +108 -0
  15. kirin/decl/base.py +65 -0
  16. kirin/decl/camel2snake.py +2 -0
  17. kirin/decl/emit/__init__.py +0 -0
  18. kirin/decl/emit/_create_fn.py +29 -0
  19. kirin/decl/emit/_set_new_attribute.py +22 -0
  20. kirin/decl/emit/dialect.py +8 -0
  21. kirin/decl/emit/init.py +277 -0
  22. kirin/decl/emit/name.py +10 -0
  23. kirin/decl/emit/property.py +182 -0
  24. kirin/decl/emit/repr.py +31 -0
  25. kirin/decl/emit/traits.py +13 -0
  26. kirin/decl/emit/typecheck.py +77 -0
  27. kirin/decl/emit/verify.py +51 -0
  28. kirin/decl/info.py +346 -0
  29. kirin/decl/scan_fields.py +157 -0
  30. kirin/decl/verify.py +69 -0
  31. kirin/dialects/__init__.py +14 -0
  32. kirin/dialects/_pprint_helper.py +53 -0
  33. kirin/dialects/cf/__init__.py +20 -0
  34. kirin/dialects/cf/constprop.py +51 -0
  35. kirin/dialects/cf/dialect.py +3 -0
  36. kirin/dialects/cf/emit.py +58 -0
  37. kirin/dialects/cf/interp.py +24 -0
  38. kirin/dialects/cf/stmts.py +68 -0
  39. kirin/dialects/cf/typeinfer.py +27 -0
  40. kirin/dialects/eltype.py +23 -0
  41. kirin/dialects/func/__init__.py +20 -0
  42. kirin/dialects/func/attrs.py +39 -0
  43. kirin/dialects/func/constprop.py +138 -0
  44. kirin/dialects/func/dialect.py +3 -0
  45. kirin/dialects/func/emit.py +80 -0
  46. kirin/dialects/func/interp.py +68 -0
  47. kirin/dialects/func/stmts.py +233 -0
  48. kirin/dialects/func/typeinfer.py +124 -0
  49. kirin/dialects/ilist/__init__.py +33 -0
  50. kirin/dialects/ilist/_dialect.py +3 -0
  51. kirin/dialects/ilist/_wrapper.py +51 -0
  52. kirin/dialects/ilist/interp.py +85 -0
  53. kirin/dialects/ilist/lowering.py +25 -0
  54. kirin/dialects/ilist/passes.py +32 -0
  55. kirin/dialects/ilist/rewrite/__init__.py +3 -0
  56. kirin/dialects/ilist/rewrite/const.py +45 -0
  57. kirin/dialects/ilist/rewrite/list.py +38 -0
  58. kirin/dialects/ilist/rewrite/unroll.py +131 -0
  59. kirin/dialects/ilist/runtime.py +63 -0
  60. kirin/dialects/ilist/stmts.py +102 -0
  61. kirin/dialects/ilist/typeinfer.py +120 -0
  62. kirin/dialects/lowering/__init__.py +7 -0
  63. kirin/dialects/lowering/call.py +48 -0
  64. kirin/dialects/lowering/cf.py +206 -0
  65. kirin/dialects/lowering/func.py +134 -0
  66. kirin/dialects/math/__init__.py +41 -0
  67. kirin/dialects/math/_gen.py +176 -0
  68. kirin/dialects/math/dialect.py +3 -0
  69. kirin/dialects/math/interp.py +190 -0
  70. kirin/dialects/math/stmts.py +369 -0
  71. kirin/dialects/module.py +139 -0
  72. kirin/dialects/py/__init__.py +40 -0
  73. kirin/dialects/py/assertion.py +91 -0
  74. kirin/dialects/py/assign.py +103 -0
  75. kirin/dialects/py/attr.py +59 -0
  76. kirin/dialects/py/base.py +34 -0
  77. kirin/dialects/py/binop/__init__.py +23 -0
  78. kirin/dialects/py/binop/_dialect.py +3 -0
  79. kirin/dialects/py/binop/interp.py +60 -0
  80. kirin/dialects/py/binop/julia.py +33 -0
  81. kirin/dialects/py/binop/lowering.py +22 -0
  82. kirin/dialects/py/binop/stmts.py +79 -0
  83. kirin/dialects/py/binop/typeinfer.py +108 -0
  84. kirin/dialects/py/boolop.py +84 -0
  85. kirin/dialects/py/builtin.py +78 -0
  86. kirin/dialects/py/cmp/__init__.py +16 -0
  87. kirin/dialects/py/cmp/_dialect.py +3 -0
  88. kirin/dialects/py/cmp/interp.py +48 -0
  89. kirin/dialects/py/cmp/julia.py +33 -0
  90. kirin/dialects/py/cmp/lowering.py +45 -0
  91. kirin/dialects/py/cmp/stmts.py +62 -0
  92. kirin/dialects/py/constant.py +79 -0
  93. kirin/dialects/py/indexing.py +251 -0
  94. kirin/dialects/py/iterable.py +90 -0
  95. kirin/dialects/py/len.py +57 -0
  96. kirin/dialects/py/list/__init__.py +15 -0
  97. kirin/dialects/py/list/_dialect.py +3 -0
  98. kirin/dialects/py/list/interp.py +21 -0
  99. kirin/dialects/py/list/lowering.py +25 -0
  100. kirin/dialects/py/list/stmts.py +22 -0
  101. kirin/dialects/py/list/typeinfer.py +54 -0
  102. kirin/dialects/py/range.py +76 -0
  103. kirin/dialects/py/slice.py +120 -0
  104. kirin/dialects/py/tuple.py +109 -0
  105. kirin/dialects/py/unary/__init__.py +24 -0
  106. kirin/dialects/py/unary/_dialect.py +3 -0
  107. kirin/dialects/py/unary/constprop.py +20 -0
  108. kirin/dialects/py/unary/interp.py +24 -0
  109. kirin/dialects/py/unary/julia.py +21 -0
  110. kirin/dialects/py/unary/lowering.py +22 -0
  111. kirin/dialects/py/unary/stmts.py +33 -0
  112. kirin/dialects/py/unary/typeinfer.py +23 -0
  113. kirin/dialects/py/unpack.py +90 -0
  114. kirin/dialects/scf/__init__.py +23 -0
  115. kirin/dialects/scf/_dialect.py +3 -0
  116. kirin/dialects/scf/absint.py +64 -0
  117. kirin/dialects/scf/constprop.py +140 -0
  118. kirin/dialects/scf/interp.py +35 -0
  119. kirin/dialects/scf/lowering.py +123 -0
  120. kirin/dialects/scf/stmts.py +250 -0
  121. kirin/dialects/scf/trim.py +36 -0
  122. kirin/dialects/scf/typeinfer.py +58 -0
  123. kirin/dialects/scf/unroll.py +92 -0
  124. kirin/emit/__init__.py +3 -0
  125. kirin/emit/abc.py +89 -0
  126. kirin/emit/abc.pyi +38 -0
  127. kirin/emit/exceptions.py +5 -0
  128. kirin/emit/julia.py +63 -0
  129. kirin/emit/str.py +51 -0
  130. kirin/exceptions.py +59 -0
  131. kirin/graph.py +34 -0
  132. kirin/idtable.py +57 -0
  133. kirin/interp/__init__.py +39 -0
  134. kirin/interp/abstract.py +253 -0
  135. kirin/interp/base.py +438 -0
  136. kirin/interp/concrete.py +62 -0
  137. kirin/interp/exceptions.py +26 -0
  138. kirin/interp/frame.py +151 -0
  139. kirin/interp/impl.py +197 -0
  140. kirin/interp/result.py +93 -0
  141. kirin/interp/state.py +71 -0
  142. kirin/interp/table.py +40 -0
  143. kirin/interp/value.py +73 -0
  144. kirin/ir/__init__.py +46 -0
  145. kirin/ir/attrs/__init__.py +20 -0
  146. kirin/ir/attrs/_types.py +8 -0
  147. kirin/ir/attrs/_types.pyi +13 -0
  148. kirin/ir/attrs/abc.py +46 -0
  149. kirin/ir/attrs/py.py +45 -0
  150. kirin/ir/attrs/types.py +522 -0
  151. kirin/ir/dialect.py +125 -0
  152. kirin/ir/group.py +249 -0
  153. kirin/ir/method.py +118 -0
  154. kirin/ir/nodes/__init__.py +7 -0
  155. kirin/ir/nodes/base.py +149 -0
  156. kirin/ir/nodes/block.py +458 -0
  157. kirin/ir/nodes/region.py +337 -0
  158. kirin/ir/nodes/stmt.py +713 -0
  159. kirin/ir/nodes/view.py +142 -0
  160. kirin/ir/ssa.py +204 -0
  161. kirin/ir/traits/__init__.py +36 -0
  162. kirin/ir/traits/abc.py +42 -0
  163. kirin/ir/traits/basic.py +78 -0
  164. kirin/ir/traits/callable.py +51 -0
  165. kirin/ir/traits/lowering/__init__.py +2 -0
  166. kirin/ir/traits/lowering/call.py +37 -0
  167. kirin/ir/traits/lowering/context.py +120 -0
  168. kirin/ir/traits/region/__init__.py +2 -0
  169. kirin/ir/traits/region/ssacfg.py +22 -0
  170. kirin/ir/traits/symbol.py +57 -0
  171. kirin/ir/use.py +17 -0
  172. kirin/lattice/__init__.py +13 -0
  173. kirin/lattice/abc.py +128 -0
  174. kirin/lattice/empty.py +25 -0
  175. kirin/lattice/mixin.py +51 -0
  176. kirin/lowering/__init__.py +7 -0
  177. kirin/lowering/binding.py +65 -0
  178. kirin/lowering/core.py +72 -0
  179. kirin/lowering/dialect.py +35 -0
  180. kirin/lowering/dialect.pyi +183 -0
  181. kirin/lowering/frame.py +171 -0
  182. kirin/lowering/result.py +68 -0
  183. kirin/lowering/state.py +441 -0
  184. kirin/lowering/stream.py +53 -0
  185. kirin/passes/__init__.py +3 -0
  186. kirin/passes/abc.py +44 -0
  187. kirin/passes/aggressive/__init__.py +1 -0
  188. kirin/passes/aggressive/fold.py +43 -0
  189. kirin/passes/fold.py +45 -0
  190. kirin/passes/inline.py +25 -0
  191. kirin/passes/typeinfer.py +25 -0
  192. kirin/prelude.py +197 -0
  193. kirin/print/__init__.py +15 -0
  194. kirin/print/printable.py +141 -0
  195. kirin/print/printer.py +415 -0
  196. kirin/py.typed +0 -0
  197. kirin/registry.py +105 -0
  198. kirin/registry.pyi +52 -0
  199. kirin/rewrite/__init__.py +14 -0
  200. kirin/rewrite/abc.py +43 -0
  201. kirin/rewrite/aggressive/__init__.py +1 -0
  202. kirin/rewrite/aggressive/fold.py +43 -0
  203. kirin/rewrite/alias.py +16 -0
  204. kirin/rewrite/apply_type.py +47 -0
  205. kirin/rewrite/call2invoke.py +34 -0
  206. kirin/rewrite/chain.py +39 -0
  207. kirin/rewrite/compactify.py +288 -0
  208. kirin/rewrite/cse.py +48 -0
  209. kirin/rewrite/dce.py +19 -0
  210. kirin/rewrite/fixpoint.py +34 -0
  211. kirin/rewrite/fold.py +57 -0
  212. kirin/rewrite/getfield.py +21 -0
  213. kirin/rewrite/getitem.py +37 -0
  214. kirin/rewrite/inline.py +143 -0
  215. kirin/rewrite/result.py +15 -0
  216. kirin/rewrite/walk.py +83 -0
  217. kirin/rewrite/wrap_const.py +55 -0
  218. kirin/source.py +21 -0
  219. kirin/symbol_table.py +27 -0
  220. kirin/types.py +34 -0
  221. kirin/worklist.py +30 -0
  222. kirin_toolchain-0.13.0.dist-info/METADATA +42 -0
  223. kirin_toolchain-0.13.0.dist-info/RECORD +225 -0
  224. kirin_toolchain-0.13.0.dist-info/WHEEL +4 -0
  225. kirin_toolchain-0.13.0.dist-info/licenses/LICENSE +234 -0
@@ -0,0 +1,233 @@
1
+ from kirin import types
2
+ from kirin.ir import (
3
+ Pure,
4
+ Method,
5
+ Region,
6
+ HasParent,
7
+ MaybePure,
8
+ Statement,
9
+ ResultValue,
10
+ ConstantLike,
11
+ HasSignature,
12
+ IsTerminator,
13
+ SSACFGRegion,
14
+ IsolatedFromAbove,
15
+ SymbolOpInterface,
16
+ CallableStmtInterface,
17
+ )
18
+ from kirin.decl import info, statement
19
+ from kirin.ir.ssa import SSAValue
20
+ from kirin.exceptions import VerificationError
21
+ from kirin.print.printer import Printer
22
+ from kirin.dialects.func.attrs import Signature, MethodType
23
+ from kirin.dialects.func.dialect import dialect
24
+
25
+ from .._pprint_helper import pprint_calllike
26
+
27
+
28
class FuncOpCallableInterface(CallableStmtInterface["Function"]):
    """Callable-statement trait used by `Function` and `Lambda`.

    The region executed when such a statement is called is its `body`.
    """

    @classmethod
    def get_callable_region(cls, stmt: "Function") -> Region:
        # Both statements carrying this trait keep their code in `body`.
        callable_region: Region = stmt.body
        return callable_region
33
+
34
+
35
@statement(dialect=dialect)
class Function(Statement):
    """A named function definition.

    Holds the function's symbol name, its `Signature`, and a `body` region.
    """

    name = "func"
    traits = frozenset(
        {
            IsolatedFromAbove(),
            SymbolOpInterface(),
            HasSignature(),
            FuncOpCallableInterface(),
            SSACFGRegion(),
        }
    )
    # The symbol name of the function.
    sym_name: str = info.attribute()
    signature: Signature = info.attribute()
    body: Region = info.region(multi=True)

    def print_impl(self, printer: Printer) -> None:
        sig = self.signature

        # header: statement name, symbol, input types and output type
        with printer.rich(style="keyword"):
            printer.print_name(self)
        printer.plain_print(" ")
        with printer.rich(style="symbol"):
            printer.plain_print(self.sym_name)
        printer.print_seq(sig.inputs, prefix="(", suffix=")", delim=", ")
        with printer.rich(style="comment"):
            printer.plain_print(" -> ")
            printer.print(sig.output)
        printer.plain_print(" ")

        # body region, then a trailing marker naming the function
        printer.print(self.body)
        with printer.rich(style="comment"):
            printer.plain_print(f" // func.func {self.sym_name}")
71
+
72
+
73
@statement(dialect=dialect)
class Call(Statement):
    """Call a callee given as an SSA value.

    `inputs` are the argument values and `kwargs` the argument names used
    for keyword-argument permutation.
    """

    name = "call"
    traits = frozenset({MaybePure()})
    # not a fixed type here so just any
    callee: SSAValue = info.argument()
    inputs: tuple[SSAValue, ...] = info.argument()
    kwargs: tuple[str, ...] = info.attribute(default_factory=lambda: ())
    result: ResultValue = info.result()
    purity: bool = info.attribute(default=False)

    def print_impl(self, printer: Printer) -> None:
        # the callee is a runtime value, so it is shown by its SSA id
        callee_id = printer.state.ssa_id[self.callee]
        pprint_calllike(self, callee_id, printer)
86
+
87
+
88
@statement(dialect=dialect)
class ConstantNone(Statement):
    """A constant None value.

    This is mainly used to represent the None return value of a function
    to match Python semantics.
    """

    name = "const.none"
    # Marked Pure and ConstantLike.
    traits = frozenset({Pure(), ConstantLike()})
    # Single result, typed as `NoneType`.
    result: ResultValue = info.result(types.NoneType)
99
+
100
+
101
+ @statement(dialect=dialect, init=False)
102
+ class Return(Statement):
103
+ name = "return"
104
+ traits = frozenset({IsTerminator(), HasParent((Function,))})
105
+ value: SSAValue = info.argument()
106
+
107
+ def __init__(self, value_or_stmt: SSAValue | Statement | None = None) -> None:
108
+ if isinstance(value_or_stmt, SSAValue):
109
+ args = [value_or_stmt]
110
+ elif isinstance(value_or_stmt, Statement):
111
+ if len(value_or_stmt._results) == 1:
112
+ args = [value_or_stmt._results[0]]
113
+ else:
114
+ raise ValueError(
115
+ f"expected a single result, got {len(value_or_stmt._results)} results from {value_or_stmt.name}"
116
+ )
117
+ elif value_or_stmt is None:
118
+ args = []
119
+ else:
120
+ raise ValueError(f"expected SSAValue or Statement, got {value_or_stmt}")
121
+
122
+ super().__init__(args=args, args_slice={"value": 0})
123
+
124
+ def print_impl(self, printer: Printer) -> None:
125
+ with printer.rich(style="keyword"):
126
+ printer.print_name(self)
127
+
128
+ if self.args:
129
+ printer.plain_print(" ")
130
+ printer.print_seq(self.args, delim=", ")
131
+
132
+ def verify(self) -> None:
133
+ if not self.args:
134
+ raise VerificationError(
135
+ self, "return statement must have at least one value"
136
+ )
137
+
138
+ if len(self.args) > 1:
139
+ raise VerificationError(
140
+ self,
141
+ "return statement must have at most one value"
142
+ ", wrap multiple values in a tuple",
143
+ )
144
+
145
+
146
@statement(dialect=dialect)
class Lambda(Statement):
    """A function-valued statement with a `body` region and captured SSA values.

    Its single result is typed `MethodType`.
    """

    name = "lambda"
    traits = frozenset(
        {
            Pure(),
            HasSignature(),
            SymbolOpInterface(),
            FuncOpCallableInterface(),
            SSACFGRegion(),
        }
    )
    sym_name: str = info.attribute()
    signature: Signature = info.attribute()
    captured: tuple[SSAValue, ...] = info.argument()
    body: Region = info.region(multi=True)
    result: ResultValue = info.result(MethodType)

    def verify(self) -> None:
        if not self.body.blocks.isempty():
            return
        raise VerificationError(self, "lambda body must not be empty")

    def print_impl(self, printer: Printer) -> None:
        # header: statement name, symbol and captured values
        with printer.rich(style="keyword"):
            printer.print_name(self)
        printer.plain_print(" ")
        with printer.rich(style="symbol"):
            printer.plain_print(self.sym_name)
        printer.print_seq(self.captured, prefix="(", suffix=")", delim=", ")

        with printer.rich(style="bright_black"):
            printer.plain_print(" -> ")
            printer.print(self.signature.output)

        printer.plain_print(" ")
        printer.print(self.body)

        trailer = f" // func.lambda {self.sym_name}"
        with printer.rich(style="black"):
            printer.plain_print(trailer)
187
+
188
+
189
@statement(dialect=dialect)
class GetField(Statement):
    """Read field number `field` from the `MethodType` value `obj`."""

    name = "getfield"
    traits = frozenset({Pure()})
    obj: SSAValue = info.argument(MethodType)
    field: int = info.attribute()
    # NOTE: mypy somehow doesn't understand default init=False
    result: ResultValue = info.result(init=False)

    def print_impl(self, printer: Printer) -> None:
        printer.print_name(self)
        obj_id = printer.state.ssa_id[self.obj]
        printer.plain_print("(", obj_id, ", ", str(self.field), ")")
        # trailing result type annotation
        with printer.rich(style="black"):
            printer.plain_print(" : ")
            printer.print(self.result.type)
206
+
207
+
208
@statement(dialect=dialect)
class Invoke(Statement):
    """Call a statically known `Method` (`callee`) with `inputs`.

    `kwargs` holds keyword-argument names; `verify` checks them against
    `callee.arg_names`, or, when there are none, checks the positional count.
    """

    name = "invoke"
    traits = frozenset({MaybePure()})
    callee: Method = info.attribute()
    inputs: tuple[SSAValue, ...] = info.argument()
    kwargs: tuple[str, ...] = info.attribute()
    result: ResultValue = info.result()
    purity: bool = info.attribute(default=False)

    def print_impl(self, printer: Printer) -> None:
        pprint_calllike(self, self.callee.sym_name, printer)

    def verify(self) -> None:
        """Check keyword names, or the positional argument count.

        Raises:
            VerificationError: if a keyword name is unknown to the callee or
                the number of positional arguments does not match.
        """
        if self.kwargs:
            for name in self.kwargs:
                if name not in self.callee.arg_names:
                    raise VerificationError(
                        self,
                        f"method {self.callee.sym_name} does not have argument {name}",
                    )
        else:
            # `arg_names` presumably includes the method's own `self` slot
            # (first entry), which is not passed explicitly - hence `- 1`.
            # Report the same count that is compared.
            n_expected = len(self.callee.arg_names) - 1
            if n_expected != len(self.args):
                raise VerificationError(
                    self,
                    f"expected {n_expected} arguments, got {len(self.args)}",
                )
@@ -0,0 +1,124 @@
1
+ from typing import Iterable, cast
2
+
3
+ from kirin import ir, types
4
+ from kirin.interp import Frame, MethodTable, ReturnValue, impl
5
+ from kirin.analysis import const
6
+ from kirin.analysis.typeinfer import TypeInference, TypeResolution
7
+ from kirin.dialects.func.stmts import (
8
+ Call,
9
+ Invoke,
10
+ Lambda,
11
+ Return,
12
+ GetField,
13
+ ConstantNone,
14
+ )
15
+ from kirin.dialects.func.dialect import dialect
16
+
17
+
18
# NOTE: a lot of the type infer rules are same as the builtin dialect
@dialect.register(key="typeinfer")
class TypeInfer(MethodTable):
    """Type-inference rules for the `func` dialect statements."""

    @impl(ConstantNone)
    def const_none(self, interp: TypeInference, frame: Frame, stmt: ConstantNone):
        return (types.NoneType,)

    @impl(Return)
    def return_(self, interp: TypeInference, frame: Frame, stmt: Return) -> ReturnValue:
        # Prefer a literal type when a non-None constant hint is available.
        if (
            isinstance(hint := stmt.value.hints.get("const"), const.Value)
            and hint.data is not None
        ):
            return ReturnValue(types.Literal(hint.data))
        return ReturnValue(frame.get(stmt.value))

    @impl(Call)
    def call(self, interp: TypeInference, frame: Frame, stmt: Call):
        # give up on dynamic method calls
        mt = interp.maybe_const(stmt.callee, ir.Method)
        if mt is None:
            return self._solve_method_type(interp, frame, stmt)
        return self._invoke_method(
            interp,
            frame,
            mt,
            stmt.args[1:],
            interp.permute_values(
                mt.arg_names, frame.get_values(stmt.inputs), stmt.kwargs
            ),
        )

    def _solve_method_type(self, interp: TypeInference, frame: Frame, stmt: Call):
        """Infer a call result from the callee's inferred generic method type.

        Expects the callee type to be a `Generic` with two parameters
        (argument pack, result). Returns `Bottom` when it is not.
        """
        mt_inferred = frame.get(stmt.callee)
        if not isinstance(mt_inferred, types.Generic):
            return (types.Bottom,)

        if len(mt_inferred.vars) != 2:
            return (types.Bottom,)

        args = mt_inferred.vars[0]
        result = mt_inferred.vars[1]
        if not args.is_subseteq(types.Tuple):
            return (types.Bottom,)

        resolve = TypeResolution()
        # Only a parameterized (Generic) argument pack exposes per-argument
        # `vars`. For anything else there is nothing to match the inputs
        # against, so substitute without bindings instead of failing on `.vars`.
        if isinstance(args, types.Generic):
            for arg, value in zip(args.vars, frame.get_values(stmt.inputs)):
                resolve.solve(arg, value)
        return (resolve.substitute(result),)

    @impl(Invoke)
    def invoke(self, interp: TypeInference, frame: Frame, stmt: Invoke):
        return self._invoke_method(
            interp,
            frame,
            stmt.callee,
            stmt.inputs,
            interp.permute_values(
                stmt.callee.arg_names, frame.get_values(stmt.inputs), stmt.kwargs
            ),
        )

    def _invoke_method(
        self,
        interp: TypeInference,
        frame: Frame,
        mt: ir.Method,
        args: Iterable[ir.SSAValue],
        values: tuple,
    ):
        """Infer the result of calling the known method `mt` with `values`.

        Also narrows the types recorded in `frame` for the call-site `args`.
        """
        if mt.inferred:  # so we don't end up in infinite loop
            return (mt.return_type,)

        # NOTE: narrowing the argument type based on method signature
        inputs = tuple(
            typ.meet(input_typ) for typ, input_typ in zip(mt.arg_types, values)
        )

        # NOTE: we use lower bound here because function call contains an
        # implicit type check at call site. This will be validated either compile time
        # or runtime.
        # update the results with the narrowed types
        frame.set_values(args, inputs)
        _, ret = interp.run_method(mt, inputs)
        return (ret,)

    @impl(Lambda)
    def lambda_(
        self, interp_: TypeInference, frame: Frame[types.TypeAttribute], stmt: Lambda
    ):
        entry_args = stmt.body.blocks[0].args
        # The first entry-block argument stands for the lambda itself; the
        # remaining ones are its parameters. Compute their types once.
        argtypes = tuple(arg.type for arg in entry_args[1:])
        body_frame, ret = interp_.run_callable(stmt, (types.MethodType,) + argtypes)
        ret = types.MethodType[[*argtypes], ret]
        frame.entries.update(body_frame.entries)  # pass results back to upper frame
        frame.set(entry_args[0], ret)
        return (ret,)

    @impl(GetField)
    def getfield(self, interp: TypeInference, frame, stmt: GetField):
        return (stmt.result.type,)
@@ -0,0 +1,33 @@
1
+ """
2
+ Immutable list dialect for Python.
3
+
4
+ This dialect provides a simple, immutable list dialect similar
5
+ to Python's built-in list type.
6
+ """
7
+
8
+ from . import (
9
+ interp as interp,
10
+ rewrite as rewrite,
11
+ lowering as lowering,
12
+ typeinfer as typeinfer,
13
+ )
14
+ from .stmts import (
15
+ Map as Map,
16
+ New as New,
17
+ Push as Push,
18
+ Scan as Scan,
19
+ Foldl as Foldl,
20
+ Foldr as Foldr,
21
+ ForEach as ForEach,
22
+ IListType as IListType,
23
+ )
24
+ from .passes import IListDesugar as IListDesugar
25
+ from .runtime import IList as IList
26
+ from ._dialect import dialect as dialect
27
+ from ._wrapper import (
28
+ map as map,
29
+ scan as scan,
30
+ foldl as foldl,
31
+ foldr as foldr,
32
+ for_each as for_each,
33
+ )
@@ -0,0 +1,3 @@
1
+ from kirin import ir
2
+
3
# Dialect object for the immutable-list statements, named "py.ilist";
# method tables in this package attach to it via `@dialect.register`.
dialect = ir.Dialect("py.ilist")
@@ -0,0 +1,51 @@
1
+ import typing
2
+
3
+ from kirin.lowering import wraps
4
+
5
+ from . import stmts
6
+ from .runtime import IList
7
+
8
# Type variables shared by the wrapper signatures in this module:
#   ElemT    - element type of the input collection
#   OutElemT - value produced by `fn` (mapped element or accumulator)
#   LenT     - length parameter of `IList`
#   ResultT  - per-element output collected by `scan`
ElemT = typing.TypeVar("ElemT")
OutElemT = typing.TypeVar("OutElemT")
LenT = typing.TypeVar("LenT")
ResultT = typing.TypeVar("ResultT")
12
+
13
+ # NOTE: we use Callable here to make nested function work.
14
+
15
+
16
@wraps(stmts.Map)
def map(
    fn: typing.Callable[[ElemT], OutElemT],
    collection: IList[ElemT, LenT] | list[ElemT],
) -> IList[OutElemT, LenT]:
    """Apply `fn` to every element of `collection`, producing a new `IList`.

    Typed Python-level signature bound to the `stmts.Map` statement via
    `@wraps`; the body is only a placeholder.
    """
    ...
21
+
22
+
23
@wraps(stmts.Foldr)
def foldr(
    fn: typing.Callable[[ElemT, OutElemT], OutElemT],
    collection: IList[ElemT, LenT] | list[ElemT],
    init: OutElemT,
) -> OutElemT:
    """Right fold of `fn` over `collection`, starting from `init`.

    Per the annotation, `fn` takes an element first and the accumulator
    second. Typed Python-level signature bound to the `stmts.Foldr`
    statement via `@wraps`; the body is only a placeholder.
    """
    ...
29
+
30
+
31
@wraps(stmts.Foldl)
def foldl(
    fn: typing.Callable[[OutElemT, ElemT], OutElemT],
    collection: IList[ElemT, LenT] | list[ElemT],
    init: OutElemT,
) -> OutElemT:
    """Left fold of `fn` over `collection`, starting from `init`.

    Per the annotation, `fn` takes the accumulator first and an element
    second. Typed Python-level signature bound to the `stmts.Foldl`
    statement via `@wraps`; the body is only a placeholder.
    """
    ...
37
+
38
+
39
@wraps(stmts.Scan)
def scan(
    fn: typing.Callable[[OutElemT, ElemT], tuple[OutElemT, ResultT]],
    collection: IList[ElemT, LenT] | list[ElemT],
    init: OutElemT,
) -> tuple[OutElemT, IList[ResultT, LenT]]:
    """Thread a carry through `collection` while collecting per-element outputs.

    `fn(carry, elem)` returns `(new_carry, y)`; the result is the final
    carry together with an `IList` of all `y`s. Typed Python-level
    signature bound to the `stmts.Scan` statement via `@wraps`; the body is
    only a placeholder.
    """
    ...
45
+
46
+
47
@wraps(stmts.ForEach)
def for_each(
    fn: typing.Callable[[ElemT], typing.Any],
    collection: IList[ElemT, LenT] | list[ElemT],
) -> None:
    """Call `fn` on every element of `collection`, discarding the results.

    Typed Python-level signature bound to the `stmts.ForEach` statement via
    `@wraps`; the body is only a placeholder.
    """
    ...
@@ -0,0 +1,85 @@
1
+ from kirin import ir, types
2
+ from kirin.interp import Frame, Interpreter, MethodTable, impl
3
+ from kirin.dialects.py.len import Len
4
+ from kirin.dialects.py.binop import Add
5
+ from kirin.dialects.py.range import Range
6
+
7
+ from .stmts import Map, New, Push, Scan, Foldl, Foldr, ForEach
8
+ from .runtime import IList
9
+ from ._dialect import dialect
10
+
11
+
12
@dialect.register
class IListInterpreter(MethodTable):
    """Concrete interpreter rules for the `py.ilist` dialect.

    Each rule returns a tuple with one entry per statement result.
    """

    @impl(Range)
    def _range(self, interp, frame: Frame, stmt: Range):
        # NOTE: the `range` object is stored as-is, so rules below that build
        # a new list unpack `.data` instead of relying on `+` (range + list
        # raises TypeError).
        return (IList(range(*frame.get_values(stmt.args))),)

    @impl(New)
    def new(self, interp, frame: Frame, stmt: New):
        return (IList(list(frame.get_values(stmt.values))),)

    @impl(Len, types.PyClass(IList))
    def len(self, interp, frame: Frame, stmt: Len):
        return (len(frame.get(stmt.value).data),)

    @impl(Add, types.PyClass(IList), types.PyClass(IList))
    def add(self, interp, frame: Frame, stmt: Add):
        lhs = frame.get(stmt.lhs).data
        rhs = frame.get(stmt.rhs).data
        # unpack rather than `lhs + rhs`: either side may be a `range`
        return (IList([*lhs, *rhs]),)

    @impl(Push)
    def push(self, interp, frame: Frame, stmt: Push):
        lst = frame.get(stmt.lst).data
        return (IList([*lst, frame.get(stmt.value)]),)

    @impl(Map)
    def map(self, interp: Interpreter, frame: Frame, stmt: Map):
        fn: ir.Method = frame.get(stmt.fn)
        coll: IList = frame.get(stmt.collection)
        ret = []
        for elem in coll.data:
            # NOTE: assume fn has been type checked
            _, item = interp.run_method(fn, (elem,))
            ret.append(item)
        return (IList(ret),)

    @impl(Scan)
    def scan(self, interp: Interpreter, frame: Frame, stmt: Scan):
        fn: ir.Method = frame.get(stmt.fn)
        init = frame.get(stmt.init)
        coll: IList = frame.get(stmt.collection)

        carry = init
        ys = []
        for elem in coll.data:
            # NOTE: assume fn has been type checked
            _, (carry, y) = interp.run_method(fn, (carry, elem))
            ys.append(y)
        return ((carry, IList(ys)),)

    @impl(Foldr)
    def foldr(self, interp: Interpreter, frame: Frame, stmt: Foldr):
        # foldr's `fn` takes (element, accumulator), matching the `ilist.foldr`
        # wrapper signature `Callable[[ElemT, OutElemT], OutElemT]`.
        return self.fold(
            interp,
            frame,
            stmt,
            reversed(frame.get(stmt.collection).data),
            elem_first=True,
        )

    @impl(Foldl)
    def foldl(self, interp: Interpreter, frame: Frame, stmt: Foldl):
        return self.fold(interp, frame, stmt, frame.get(stmt.collection).data)

    def fold(
        self,
        interp: Interpreter,
        frame: Frame,
        stmt: Foldr | Foldl,
        coll,
        elem_first: bool = False,
    ):
        """Fold `stmt.fn` over `coll`, starting from `stmt.init`.

        `fn` is called as `fn(acc, elem)`, or as `fn(elem, acc)` when
        `elem_first` is true (the foldr convention).
        """
        fn: ir.Method = frame.get(stmt.fn)
        acc = frame.get(stmt.init)
        for elem in coll:
            fn_args = (elem, acc) if elem_first else (acc, elem)
            # NOTE: assume fn has been type checked
            _, acc = interp.run_method(fn, fn_args)
        return (acc,)

    @impl(ForEach)
    def for_each(self, interp: Interpreter, frame: Frame, stmt: ForEach):
        fn: ir.Method = frame.get(stmt.fn)
        coll: IList = frame.get(stmt.collection)
        for elem in coll.data:
            # NOTE: assume fn has been type checked
            interp.run_method(fn, (elem,))
        return (None,)
@@ -0,0 +1,25 @@
1
+ import ast
2
+
3
+ from kirin import types
4
+ from kirin.lowering import Result, FromPythonAST, LoweringState
5
+
6
+ from . import stmts as ilist
7
+ from ._dialect import dialect
8
+
9
+
10
@dialect.register
class PythonLowering(FromPythonAST):
    """Lower Python list displays (``[a, b, ...]``) to `ilist.New`."""

    def lower_List(self, state: LoweringState, node: ast.List) -> Result:
        # Lower each element expression and take its single result value.
        # (The previous element-type join was computed but never used.)
        elts = tuple(state.visit(each).expect_one() for each in node.elts)
        stmt = ilist.New(values=elts)
        state.append_stmt(stmt)
        return Result(stmt)
@@ -0,0 +1,32 @@
1
+ from kirin import ir, types
2
+ from kirin.rewrite import Walk, Chain, Fixpoint
3
+ from kirin.passes.abc import Pass
4
+ from kirin.rewrite.result import RewriteResult
5
+ from kirin.dialects.ilist.rewrite import List2IList, ConstList2IList
6
+
7
+
8
class IListDesugar(Pass):
    """Desugar Python `list` usage in a kernel into the immutable `IList` dialect.

    Raises `TypeError` if any kernel argument type mentions a Python `list`.
    Otherwise it applies `ConstList2IList` and `List2IList` over the method
    body under `Fixpoint(Walk(...))`.
    """

    def unsafe_run(self, mt: ir.Method) -> RewriteResult:
        # kernel signatures must already use IList
        for arg in mt.args:
            arg_type = arg.type
            _check_list(arg_type, arg_type)

        rules = Chain(ConstList2IList(), List2IList())
        return Fixpoint(Walk(rules)).rewrite(mt.code)
18
+
19
+
20
def _check_list(total: types.TypeAttribute, type_: types.TypeAttribute):
    """Raise `TypeError` if `type_` or any nested component is a Python list class.

    Recurses into a generic's body, its type parameters and its vararg.
    `total` is the full type under inspection and is used only in the error
    message.
    """
    if isinstance(type_, types.Generic):
        components = [type_.body, *type_.vars]
        if type_.vararg:
            components.append(type_.vararg.typ)
        for component in components:
            _check_list(total, component)
        return

    if not isinstance(type_, types.PyClass):
        return
    if issubclass(type_.typ, list):
        raise TypeError(
            f"Invalid type {total} for this kernel, use IList instead of {type_}."
        )
@@ -0,0 +1,3 @@
1
+ from .list import List2IList as List2IList
2
+ from .const import ConstList2IList as ConstList2IList
3
+ from .unroll import Unroll as Unroll
@@ -0,0 +1,45 @@
1
+ from kirin import ir, types
2
+ from kirin.analysis import const
3
+ from kirin.rewrite.abc import RewriteRule
4
+ from kirin.rewrite.result import RewriteResult
5
+
6
+ from ..stmts import IListType
7
+ from ..runtime import IList
8
+
9
+
10
class ConstList2IList(RewriteRule):
    """Rewrite type annotation for SSAValue with constant `IList`
    in `Hinted` type. This should be run after constant folding and
    `WrapConst` rule.

    For every statement result whose ``"const"`` hint is a `const.Value`
    wrapping a non-empty `IList`, and whose current type is `IList` (bare or
    generic), the type is narrowed to
    ``IListType[<join of element classes>, Literal(<length>)]``.
    """

    def rewrite_Statement(self, node: ir.Statement) -> RewriteResult:
        has_done_something = False
        for result in node.results:
            if not isinstance(hint := result.hints.get("const"), const.Value):
                continue

            typ = result.type
            is_ilist = (
                isinstance(typ, types.PyClass)
                and typ.is_subseteq(types.PyClass(IList))
            ) or (
                isinstance(typ, types.Generic)
                and typ.body.is_subseteq(types.PyClass(IList))
            )
            # NOTE: accumulate rather than overwrite, so a later result that is
            # left unchanged does not hide an earlier rewrite.
            if is_ilist and self._rewrite_IList_type(result, hint.data):
                has_done_something = True
        return RewriteResult(has_done_something=has_done_something)

    def _rewrite_IList_type(self, result: ir.SSAValue, data) -> bool:
        """Narrow `result.type` from the constant `data`; return whether it changed."""
        if not isinstance(data, IList) or not data.data:
            return False

        elem_type = types.PyClass(type(data.data[0]))
        for elem in data.data[1:]:
            elem_type = elem_type.join(types.PyClass(type(elem)))

        new_type = IListType[elem_type, types.Literal(len(data.data))]
        if result.type == new_type:
            # Already refined on an earlier pass. Reporting a change here
            # would keep an enclosing `Fixpoint` from ever converging.
            return False

        result.type = new_type
        result.hints["const"] = const.Value(data)
        return True
@@ -0,0 +1,38 @@
1
+ from kirin import ir, types
2
+ from kirin.dialects.py import constant
3
+ from kirin.rewrite.abc import RewriteRule
4
+ from kirin.rewrite.result import RewriteResult
5
+ from kirin.dialects.ilist.stmts import IListType
6
+ from kirin.dialects.ilist.runtime import IList
7
+
8
+
9
class List2IList(RewriteRule):
    """Retype Python `list` values as `IList`.

    Block arguments and statement results typed ``list[T]`` become
    ``IListType[T, Any]``, and a bare ``list`` becomes ``IListType[Any, Any]``.
    When a `constant.Constant` statement's result is retyped, the statement
    is replaced by a constant holding ``IList(data=<original value>)``.
    """

    def rewrite_Block(self, node: ir.Block) -> RewriteResult:
        has_done_something = False
        for arg in node.args:
            # Rewrite first (so every argument is visited), then accumulate:
            # an unchanged later argument must not mask an earlier rewrite.
            has_done_something = self._rewrite_SSAValue_type(arg) or has_done_something
        return RewriteResult(has_done_something=has_done_something)

    def rewrite_Statement(self, node: ir.Statement) -> RewriteResult:
        has_done_something = False
        for result in node.results:
            has_done_something = (
                self._rewrite_SSAValue_type(result) or has_done_something
            )

        if has_done_something and isinstance(node, constant.Constant):
            node.replace_by(constant.Constant(value=IList(data=node.value)))

        return RewriteResult(has_done_something=has_done_something)

    def _rewrite_SSAValue_type(self, value: ir.SSAValue) -> bool:
        """Retype `value` in place if it is a Python list type; return whether it changed."""
        # NOTE: cannot use issubseteq here because type can be Bottom
        if isinstance(value.type, types.Generic) and issubclass(
            value.type.body.typ, list
        ):
            value.type = IListType[value.type.vars[0], types.Any]
            return True

        elif isinstance(value.type, types.PyClass) and issubclass(value.type.typ, list):
            value.type = IListType[types.Any, types.Any]
            return True
        return False
+ return False