kirin-toolchain 0.13.0__py3-none-any.whl

Sign up to get free protection for your applications and to get access to all the features.
Files changed (225) hide show
  1. kirin/__init__.py +7 -0
  2. kirin/analysis/__init__.py +24 -0
  3. kirin/analysis/callgraph.py +61 -0
  4. kirin/analysis/cfg.py +112 -0
  5. kirin/analysis/const/__init__.py +20 -0
  6. kirin/analysis/const/_visitor.py +2 -0
  7. kirin/analysis/const/_visitor.pyi +8 -0
  8. kirin/analysis/const/lattice.py +219 -0
  9. kirin/analysis/const/prop.py +116 -0
  10. kirin/analysis/forward.py +100 -0
  11. kirin/analysis/typeinfer/__init__.py +5 -0
  12. kirin/analysis/typeinfer/analysis.py +90 -0
  13. kirin/analysis/typeinfer/solve.py +141 -0
  14. kirin/decl/__init__.py +108 -0
  15. kirin/decl/base.py +65 -0
  16. kirin/decl/camel2snake.py +2 -0
  17. kirin/decl/emit/__init__.py +0 -0
  18. kirin/decl/emit/_create_fn.py +29 -0
  19. kirin/decl/emit/_set_new_attribute.py +22 -0
  20. kirin/decl/emit/dialect.py +8 -0
  21. kirin/decl/emit/init.py +277 -0
  22. kirin/decl/emit/name.py +10 -0
  23. kirin/decl/emit/property.py +182 -0
  24. kirin/decl/emit/repr.py +31 -0
  25. kirin/decl/emit/traits.py +13 -0
  26. kirin/decl/emit/typecheck.py +77 -0
  27. kirin/decl/emit/verify.py +51 -0
  28. kirin/decl/info.py +346 -0
  29. kirin/decl/scan_fields.py +157 -0
  30. kirin/decl/verify.py +69 -0
  31. kirin/dialects/__init__.py +14 -0
  32. kirin/dialects/_pprint_helper.py +53 -0
  33. kirin/dialects/cf/__init__.py +20 -0
  34. kirin/dialects/cf/constprop.py +51 -0
  35. kirin/dialects/cf/dialect.py +3 -0
  36. kirin/dialects/cf/emit.py +58 -0
  37. kirin/dialects/cf/interp.py +24 -0
  38. kirin/dialects/cf/stmts.py +68 -0
  39. kirin/dialects/cf/typeinfer.py +27 -0
  40. kirin/dialects/eltype.py +23 -0
  41. kirin/dialects/func/__init__.py +20 -0
  42. kirin/dialects/func/attrs.py +39 -0
  43. kirin/dialects/func/constprop.py +138 -0
  44. kirin/dialects/func/dialect.py +3 -0
  45. kirin/dialects/func/emit.py +80 -0
  46. kirin/dialects/func/interp.py +68 -0
  47. kirin/dialects/func/stmts.py +233 -0
  48. kirin/dialects/func/typeinfer.py +124 -0
  49. kirin/dialects/ilist/__init__.py +33 -0
  50. kirin/dialects/ilist/_dialect.py +3 -0
  51. kirin/dialects/ilist/_wrapper.py +51 -0
  52. kirin/dialects/ilist/interp.py +85 -0
  53. kirin/dialects/ilist/lowering.py +25 -0
  54. kirin/dialects/ilist/passes.py +32 -0
  55. kirin/dialects/ilist/rewrite/__init__.py +3 -0
  56. kirin/dialects/ilist/rewrite/const.py +45 -0
  57. kirin/dialects/ilist/rewrite/list.py +38 -0
  58. kirin/dialects/ilist/rewrite/unroll.py +131 -0
  59. kirin/dialects/ilist/runtime.py +63 -0
  60. kirin/dialects/ilist/stmts.py +102 -0
  61. kirin/dialects/ilist/typeinfer.py +120 -0
  62. kirin/dialects/lowering/__init__.py +7 -0
  63. kirin/dialects/lowering/call.py +48 -0
  64. kirin/dialects/lowering/cf.py +206 -0
  65. kirin/dialects/lowering/func.py +134 -0
  66. kirin/dialects/math/__init__.py +41 -0
  67. kirin/dialects/math/_gen.py +176 -0
  68. kirin/dialects/math/dialect.py +3 -0
  69. kirin/dialects/math/interp.py +190 -0
  70. kirin/dialects/math/stmts.py +369 -0
  71. kirin/dialects/module.py +139 -0
  72. kirin/dialects/py/__init__.py +40 -0
  73. kirin/dialects/py/assertion.py +91 -0
  74. kirin/dialects/py/assign.py +103 -0
  75. kirin/dialects/py/attr.py +59 -0
  76. kirin/dialects/py/base.py +34 -0
  77. kirin/dialects/py/binop/__init__.py +23 -0
  78. kirin/dialects/py/binop/_dialect.py +3 -0
  79. kirin/dialects/py/binop/interp.py +60 -0
  80. kirin/dialects/py/binop/julia.py +33 -0
  81. kirin/dialects/py/binop/lowering.py +22 -0
  82. kirin/dialects/py/binop/stmts.py +79 -0
  83. kirin/dialects/py/binop/typeinfer.py +108 -0
  84. kirin/dialects/py/boolop.py +84 -0
  85. kirin/dialects/py/builtin.py +78 -0
  86. kirin/dialects/py/cmp/__init__.py +16 -0
  87. kirin/dialects/py/cmp/_dialect.py +3 -0
  88. kirin/dialects/py/cmp/interp.py +48 -0
  89. kirin/dialects/py/cmp/julia.py +33 -0
  90. kirin/dialects/py/cmp/lowering.py +45 -0
  91. kirin/dialects/py/cmp/stmts.py +62 -0
  92. kirin/dialects/py/constant.py +79 -0
  93. kirin/dialects/py/indexing.py +251 -0
  94. kirin/dialects/py/iterable.py +90 -0
  95. kirin/dialects/py/len.py +57 -0
  96. kirin/dialects/py/list/__init__.py +15 -0
  97. kirin/dialects/py/list/_dialect.py +3 -0
  98. kirin/dialects/py/list/interp.py +21 -0
  99. kirin/dialects/py/list/lowering.py +25 -0
  100. kirin/dialects/py/list/stmts.py +22 -0
  101. kirin/dialects/py/list/typeinfer.py +54 -0
  102. kirin/dialects/py/range.py +76 -0
  103. kirin/dialects/py/slice.py +120 -0
  104. kirin/dialects/py/tuple.py +109 -0
  105. kirin/dialects/py/unary/__init__.py +24 -0
  106. kirin/dialects/py/unary/_dialect.py +3 -0
  107. kirin/dialects/py/unary/constprop.py +20 -0
  108. kirin/dialects/py/unary/interp.py +24 -0
  109. kirin/dialects/py/unary/julia.py +21 -0
  110. kirin/dialects/py/unary/lowering.py +22 -0
  111. kirin/dialects/py/unary/stmts.py +33 -0
  112. kirin/dialects/py/unary/typeinfer.py +23 -0
  113. kirin/dialects/py/unpack.py +90 -0
  114. kirin/dialects/scf/__init__.py +23 -0
  115. kirin/dialects/scf/_dialect.py +3 -0
  116. kirin/dialects/scf/absint.py +64 -0
  117. kirin/dialects/scf/constprop.py +140 -0
  118. kirin/dialects/scf/interp.py +35 -0
  119. kirin/dialects/scf/lowering.py +123 -0
  120. kirin/dialects/scf/stmts.py +250 -0
  121. kirin/dialects/scf/trim.py +36 -0
  122. kirin/dialects/scf/typeinfer.py +58 -0
  123. kirin/dialects/scf/unroll.py +92 -0
  124. kirin/emit/__init__.py +3 -0
  125. kirin/emit/abc.py +89 -0
  126. kirin/emit/abc.pyi +38 -0
  127. kirin/emit/exceptions.py +5 -0
  128. kirin/emit/julia.py +63 -0
  129. kirin/emit/str.py +51 -0
  130. kirin/exceptions.py +59 -0
  131. kirin/graph.py +34 -0
  132. kirin/idtable.py +57 -0
  133. kirin/interp/__init__.py +39 -0
  134. kirin/interp/abstract.py +253 -0
  135. kirin/interp/base.py +438 -0
  136. kirin/interp/concrete.py +62 -0
  137. kirin/interp/exceptions.py +26 -0
  138. kirin/interp/frame.py +151 -0
  139. kirin/interp/impl.py +197 -0
  140. kirin/interp/result.py +93 -0
  141. kirin/interp/state.py +71 -0
  142. kirin/interp/table.py +40 -0
  143. kirin/interp/value.py +73 -0
  144. kirin/ir/__init__.py +46 -0
  145. kirin/ir/attrs/__init__.py +20 -0
  146. kirin/ir/attrs/_types.py +8 -0
  147. kirin/ir/attrs/_types.pyi +13 -0
  148. kirin/ir/attrs/abc.py +46 -0
  149. kirin/ir/attrs/py.py +45 -0
  150. kirin/ir/attrs/types.py +522 -0
  151. kirin/ir/dialect.py +125 -0
  152. kirin/ir/group.py +249 -0
  153. kirin/ir/method.py +118 -0
  154. kirin/ir/nodes/__init__.py +7 -0
  155. kirin/ir/nodes/base.py +149 -0
  156. kirin/ir/nodes/block.py +458 -0
  157. kirin/ir/nodes/region.py +337 -0
  158. kirin/ir/nodes/stmt.py +713 -0
  159. kirin/ir/nodes/view.py +142 -0
  160. kirin/ir/ssa.py +204 -0
  161. kirin/ir/traits/__init__.py +36 -0
  162. kirin/ir/traits/abc.py +42 -0
  163. kirin/ir/traits/basic.py +78 -0
  164. kirin/ir/traits/callable.py +51 -0
  165. kirin/ir/traits/lowering/__init__.py +2 -0
  166. kirin/ir/traits/lowering/call.py +37 -0
  167. kirin/ir/traits/lowering/context.py +120 -0
  168. kirin/ir/traits/region/__init__.py +2 -0
  169. kirin/ir/traits/region/ssacfg.py +22 -0
  170. kirin/ir/traits/symbol.py +57 -0
  171. kirin/ir/use.py +17 -0
  172. kirin/lattice/__init__.py +13 -0
  173. kirin/lattice/abc.py +128 -0
  174. kirin/lattice/empty.py +25 -0
  175. kirin/lattice/mixin.py +51 -0
  176. kirin/lowering/__init__.py +7 -0
  177. kirin/lowering/binding.py +65 -0
  178. kirin/lowering/core.py +72 -0
  179. kirin/lowering/dialect.py +35 -0
  180. kirin/lowering/dialect.pyi +183 -0
  181. kirin/lowering/frame.py +171 -0
  182. kirin/lowering/result.py +68 -0
  183. kirin/lowering/state.py +441 -0
  184. kirin/lowering/stream.py +53 -0
  185. kirin/passes/__init__.py +3 -0
  186. kirin/passes/abc.py +44 -0
  187. kirin/passes/aggressive/__init__.py +1 -0
  188. kirin/passes/aggressive/fold.py +43 -0
  189. kirin/passes/fold.py +45 -0
  190. kirin/passes/inline.py +25 -0
  191. kirin/passes/typeinfer.py +25 -0
  192. kirin/prelude.py +197 -0
  193. kirin/print/__init__.py +15 -0
  194. kirin/print/printable.py +141 -0
  195. kirin/print/printer.py +415 -0
  196. kirin/py.typed +0 -0
  197. kirin/registry.py +105 -0
  198. kirin/registry.pyi +52 -0
  199. kirin/rewrite/__init__.py +14 -0
  200. kirin/rewrite/abc.py +43 -0
  201. kirin/rewrite/aggressive/__init__.py +1 -0
  202. kirin/rewrite/aggressive/fold.py +43 -0
  203. kirin/rewrite/alias.py +16 -0
  204. kirin/rewrite/apply_type.py +47 -0
  205. kirin/rewrite/call2invoke.py +34 -0
  206. kirin/rewrite/chain.py +39 -0
  207. kirin/rewrite/compactify.py +288 -0
  208. kirin/rewrite/cse.py +48 -0
  209. kirin/rewrite/dce.py +19 -0
  210. kirin/rewrite/fixpoint.py +34 -0
  211. kirin/rewrite/fold.py +57 -0
  212. kirin/rewrite/getfield.py +21 -0
  213. kirin/rewrite/getitem.py +37 -0
  214. kirin/rewrite/inline.py +143 -0
  215. kirin/rewrite/result.py +15 -0
  216. kirin/rewrite/walk.py +83 -0
  217. kirin/rewrite/wrap_const.py +55 -0
  218. kirin/source.py +21 -0
  219. kirin/symbol_table.py +27 -0
  220. kirin/types.py +34 -0
  221. kirin/worklist.py +30 -0
  222. kirin_toolchain-0.13.0.dist-info/METADATA +42 -0
  223. kirin_toolchain-0.13.0.dist-info/RECORD +225 -0
  224. kirin_toolchain-0.13.0.dist-info/WHEEL +4 -0
  225. kirin_toolchain-0.13.0.dist-info/licenses/LICENSE +234 -0
@@ -0,0 +1,131 @@
1
+ from kirin import ir, types
2
+ from kirin.rewrite.abc import RewriteRule
3
+ from kirin.rewrite.result import RewriteResult
4
+ from kirin.dialects.py.tuple import New as TupleNew
5
+ from kirin.dialects.func.stmts import Call
6
+ from kirin.dialects.ilist.stmts import Map, New, Scan, Foldl, Foldr, ForEach, IListType
7
+ from kirin.dialects.py.constant import Constant
8
+ from kirin.dialects.py.indexing import GetItem
9
+
10
+
11
class Unroll(RewriteRule):
    """Unroll ilist ``Map``/``Scan``/``Foldl``/``Foldr``/``ForEach`` statements.

    A statement is unrolled only when the static type of its ``collection`` is an
    ``IListType`` generic whose length type argument is a ``types.Literal``
    holding an ``int``. The statement is expanded into one
    ``Constant`` index + ``GetItem`` + ``Call`` sequence per element, inserted
    before it, and the original statement is then replaced or deleted.
    Statements whose collection length is not statically known are left as-is.
    """

    def rewrite_Statement(self, node: ir.Statement) -> RewriteResult:
        """Dispatch to ``rewrite_<StatementClassName>``, else ``rewrite_fallback``."""
        return getattr(
            self, f"rewrite_{node.__class__.__name__}", self.rewrite_fallback
        )(node)

    def rewrite_fallback(self, node: ir.Statement) -> RewriteResult:
        """Leave statements without a dedicated handler unchanged."""
        return RewriteResult()

    def _get_collection_len(self, collection: ir.SSAValue) -> int | None:
        """Return the statically known length of ``collection``, or ``None``.

        ``None`` means the type is not an IList generic carrying an ``int``
        literal length, in which case callers must not unroll.
        """
        coll_type = collection.type
        if not isinstance(coll_type, types.Generic):
            return None

        if not coll_type.is_subseteq(IListType):
            return None

        # Guard the ``vars[1]`` access below: a generic with fewer type
        # arguments than ``IListType[ElemT, ListLen]`` carries no length.
        if len(coll_type.vars) < 2:
            return None

        length = coll_type.vars[1]
        if not (isinstance(length, types.Literal) and isinstance(length.data, int)):
            return None

        return length.data

    def rewrite_Map(self, node: Map) -> RewriteResult:
        """Replace ``Map(fn, xs)`` with ``New([fn(xs[0]), ..., fn(xs[n-1])])``."""
        # NOTE: if node.collection is a constant, we can
        # just leave it because Map is pure, and this will
        # be folded.
        # NOTE(review): ``Map`` does not declare the ``ir.Pure`` trait in the
        # ilist statements module — confirm the purity assumption above.
        if (coll_len := self._get_collection_len(node.collection)) is None:
            return RewriteResult()

        new_elems: list[ir.SSAValue] = []
        for elt_idx in range(coll_len):
            index = Constant(elt_idx)
            index.insert_before(node)
            elt = GetItem(node.collection, index.result)
            elt.insert_before(node)
            fn_call = Call(node.fn, (elt.result,))
            fn_call.insert_before(node)
            new_elems.append(fn_call.result)

        node.replace_by(New(values=tuple(new_elems)))
        return RewriteResult(has_done_something=True)

    def rewrite_Scan(self, node: Scan) -> RewriteResult:
        """Unroll ``Scan`` into a chain of calls that threads the carry.

        Each call ``fn(carry, xs[i])`` is indexed at 0 for the next carry and at
        1 for the i-th output. The statement is replaced by a ``py.tuple.New``
        of the final carry and an IList of all outputs.
        """
        if (coll_len := self._get_collection_len(node.collection)) is None:
            return RewriteResult()

        # Shared index constants used to destructure every call's result tuple.
        index_0 = Constant(0)
        index_1 = Constant(1)
        index_0.insert_before(node)
        index_1.insert_before(node)
        carry = node.init
        ys: list[ir.SSAValue] = []
        for elem_idx in range(coll_len):
            index = Constant(elem_idx)
            elt = GetItem(node.collection, index.result)
            fn_call = Call(node.fn, (carry, elt.result))
            carry_stmt = GetItem(fn_call.result, index_0.result)
            y_stmt = GetItem(fn_call.result, index_1.result)
            carry = carry_stmt.result
            ys.append(y_stmt.result)

            # Insert in dependency order so every operand is defined before use.
            index.insert_before(node)
            elt.insert_before(node)
            fn_call.insert_before(node)
            carry_stmt.insert_before(node)
            y_stmt.insert_before(node)

        ys_stmt = New(values=tuple(ys))
        ys_stmt.insert_before(node)
        ret = TupleNew(values=(carry, ys_stmt.result))
        node.replace_by(ret)
        return RewriteResult(has_done_something=True)

    def rewrite_Foldr(self, node: Foldr) -> RewriteResult:
        """Unroll ``Foldr`` by visiting elements from last to first."""
        return self._rewrite_fold(node, True)

    def rewrite_Foldl(self, node: Foldl) -> RewriteResult:
        """Unroll ``Foldl`` by visiting elements from first to last."""
        return self._rewrite_fold(node, False)

    def _rewrite_fold(self, node: Foldr | Foldl, reverse: bool) -> RewriteResult:
        """Expand a fold into a chain of ``acc = fn(acc, xs[i])`` calls.

        Args:
            node: the fold statement to unroll.
            reverse: visit indices ``n-1 .. 0`` instead of ``0 .. n-1``.

        The fold's result is rewired to the last accumulator and the fold
        statement is deleted.
        """
        if (coll_len := self._get_collection_len(node.collection)) is None:
            return RewriteResult()

        # NOTE(review): ``fn`` is called as ``fn(acc, elem)`` for both folds,
        # while ``Foldr`` declares its ``fn`` as ``(ElemT, OutElemT)`` — confirm
        # the intended argument order for ``Foldr``.
        indices = reversed(range(coll_len)) if reverse else range(coll_len)
        acc = node.init
        for elem_idx in indices:
            index = Constant(elem_idx)
            index.insert_before(node)
            elt = GetItem(node.collection, index.result)
            elt.insert_before(node)

            acc_stmt = Call(node.fn, (acc, elt.result))
            acc_stmt.insert_before(node)
            acc = acc_stmt.result

        node.result.replace_by(acc)
        node.delete()
        return RewriteResult(has_done_something=True)

    def rewrite_ForEach(self, node: ForEach) -> RewriteResult:
        """Replace ``ForEach(fn, xs)`` with ``fn(xs[0]); ...; fn(xs[n-1])``."""
        if (coll_len := self._get_collection_len(node.collection)) is None:
            return RewriteResult()

        for elem_idx in range(coll_len):
            index = Constant(elem_idx)
            index.insert_before(node)
            elt = GetItem(node.collection, index.result)
            elt.insert_before(node)
            fn_call = Call(node.fn, (elt.result,))
            fn_call.insert_before(node)

        # ForEach declares no result, so it can simply be removed.
        node.delete()
        return RewriteResult(has_done_something=True)
@@ -0,0 +1,63 @@
1
+ # TODO: replace with something faster
2
+ from typing import Any, Generic, TypeVar, overload
3
+ from dataclasses import dataclass
4
+ from collections.abc import Sequence
5
+
6
+ T = TypeVar("T")
7
+ L = TypeVar("L")
8
+
9
+
10
+ @dataclass
11
+ class IList(Generic[T, L]):
12
+ """A simple immutable list."""
13
+
14
+ data: Sequence[T]
15
+
16
+ def __hash__(self) -> int:
17
+ return id(self) # do not hash the data
18
+
19
+ def __len__(self) -> int:
20
+ return len(self.data)
21
+
22
+ @overload
23
+ def __add__(self, other: "IList[T, Any]") -> "IList[T, Any]": ...
24
+
25
+ @overload
26
+ def __add__(self, other: list[T]) -> "IList[T, Any]": ...
27
+
28
+ def __add__(self, other):
29
+ return IList(self.data + other)
30
+
31
+ @overload
32
+ def __radd__(self, other: "IList[T, Any]") -> "IList[T, Any]": ...
33
+
34
+ @overload
35
+ def __radd__(self, other: list[T]) -> "IList[T, Any]": ...
36
+
37
+ def __radd__(self, other):
38
+ return IList(other + self.data)
39
+
40
+ def __repr__(self) -> str:
41
+ return f"IList({self.data})"
42
+
43
+ def __str__(self) -> str:
44
+ return f"IList({self.data})"
45
+
46
+ def __iter__(self):
47
+ return iter(self.data)
48
+
49
+ @overload
50
+ def __getitem__(self, index: slice) -> "IList[T, Any]": ...
51
+
52
+ @overload
53
+ def __getitem__(self, index: int) -> T: ...
54
+
55
+ def __getitem__(self, index: int | slice) -> T | "IList[T, Any]":
56
+ if isinstance(index, slice):
57
+ return IList(self.data[index])
58
+ return self.data[index]
59
+
60
+ def __eq__(self, value: object) -> bool:
61
+ if not isinstance(value, IList):
62
+ return False
63
+ return self.data == value.data
@@ -0,0 +1,102 @@
1
+ from typing import Sequence
2
+
3
+ from kirin import ir, types
4
+ from kirin.decl import info, statement
5
+
6
+ from .runtime import IList
7
+ from ._dialect import dialect
8
+
9
ElemT = types.TypeVar("ElemT")
ListLen = types.TypeVar("ListLen")
IListType = types.Generic(IList, ElemT, ListLen)


@statement(dialect=dialect, init=False)
class New(ir.Statement):
    """Build an IList from a sequence of SSA values.

    The result type's element type is the join of all value types (``Any`` for
    an empty list), and its length parameter is the literal number of values.
    """

    traits = frozenset({ir.Pure(), ir.FromPythonCall()})
    values: tuple[ir.SSAValue, ...] = info.argument(ElemT)
    result: ir.ResultValue = info.result(IListType[ElemT])

    def __init__(
        self,
        values: Sequence[ir.SSAValue],
    ) -> None:
        joined = types.Any
        if values:
            joined = values[0].type
            for item in values:
                joined = joined.join(item.type)

        count = len(values)
        super().__init__(
            args=values,
            result_types=(IListType[joined, types.Literal(count)],),
            args_slice={"values": slice(0, count)},
        )
38
+
39
+
40
@statement(dialect=dialect)
class Push(ir.Statement):
    """Push a single ``value`` onto the list ``lst``.

    ``value`` is one element, not a list: the type-inference rule for ``Push``
    compares its type against the list's element type.
    """

    traits = frozenset({ir.FromPythonCall()})
    lst: ir.SSAValue = info.argument(IListType[ElemT])
    # A single element of the list (previously mis-declared as a list type).
    value: ir.SSAValue = info.argument(ElemT)
    result: ir.ResultValue = info.result(IListType[ElemT])
46
+
47
+
48
# Type variable for the output element / accumulator type of the
# higher-order statements below.
OutElemT = types.TypeVar("OutElemT")


@statement(dialect=dialect)
class Map(ir.Statement):
    """Apply ``fn`` to every element of ``collection``.

    The result is an IList of ``fn``'s output type that shares the input
    collection's length type parameter (``ListLen``).
    """

    # NOTE(review): unlike ``New``, no ``ir.Pure`` trait is declared here
    # (presumably because ``fn`` may have side effects) — confirm intended.
    traits = frozenset({ir.FromPythonCall()})
    fn: ir.SSAValue = info.argument(types.MethodType[[ElemT], OutElemT])
    collection: ir.SSAValue = info.argument(IListType[ElemT, ListLen])
    result: ir.ResultValue = info.result(IListType[OutElemT, ListLen])
57
+
58
+
59
@statement(dialect=dialect)
class Foldr(ir.Statement):
    """Right fold of ``fn`` over ``collection``, starting from ``init``.

    NOTE(review): ``fn`` is declared here as ``(ElemT, OutElemT) -> OutElemT``
    (element first), whereas ``Foldl`` declares ``(OutElemT, ElemT)`` and the
    ilist unroll rewrite calls ``fn(acc, elem)`` for both folds. One of these
    argument orders looks inconsistent — confirm against the interpreter.
    """

    traits = frozenset({ir.FromPythonCall()})
    fn: ir.SSAValue = info.argument(
        types.Generic(ir.Method, [ElemT, OutElemT], OutElemT)
    )
    collection: ir.SSAValue = info.argument(IListType[ElemT])
    # Initial accumulator; its type is also the fold's result type.
    init: ir.SSAValue = info.argument(OutElemT)
    result: ir.ResultValue = info.result(OutElemT)
68
+
69
+
70
@statement(dialect=dialect)
class Foldl(ir.Statement):
    """Left fold of ``fn`` over ``collection``, starting from ``init``.

    ``fn`` takes ``(accumulator, element)`` and returns the next accumulator.
    """

    traits = frozenset({ir.FromPythonCall()})
    fn: ir.SSAValue = info.argument(
        types.Generic(ir.Method, [OutElemT, ElemT], OutElemT)
    )
    collection: ir.SSAValue = info.argument(IListType[ElemT])
    # Initial accumulator; its type is also the fold's result type.
    init: ir.SSAValue = info.argument(OutElemT)
    result: ir.ResultValue = info.result(OutElemT)
79
+
80
+
81
# NOTE(review): CarryT is not referenced by any statement in this module;
# Scan uses OutElemT for its carry.
CarryT = types.TypeVar("CarryT")
# Type of the per-element output produced by Scan's ``fn``.
ResultT = types.TypeVar("ResultT")


@statement(dialect=dialect)
class Scan(ir.Statement):
    """Scan ``fn`` over ``collection``, threading a carry that starts at ``init``.

    ``fn(carry, elem)`` returns a ``(new_carry, y)`` tuple. The result is a
    tuple of the final carry and an IList of all ``y`` values, which shares the
    input collection's length type parameter.
    """

    traits = frozenset({ir.FromPythonCall()})
    fn: ir.SSAValue = info.argument(
        types.Generic(ir.Method, [OutElemT, ElemT], types.Tuple[OutElemT, ResultT])
    )
    collection: ir.SSAValue = info.argument(IListType[ElemT, ListLen])
    init: ir.SSAValue = info.argument(OutElemT)
    result: ir.ResultValue = info.result(
        types.Tuple[OutElemT, IListType[ResultT, ListLen]]
    )
96
+
97
+
98
@statement(dialect=dialect)
class ForEach(ir.Statement):
    """Call ``fn`` on each element of ``collection``.

    ``fn`` is declared to return ``None``, and this statement declares no
    result value.
    """

    traits = frozenset({ir.FromPythonCall()})
    fn: ir.SSAValue = info.argument(types.Generic(ir.Method, [ElemT], types.NoneType))
    collection: ir.SSAValue = info.argument(IListType[ElemT])
@@ -0,0 +1,120 @@
1
+ from kirin import types
2
+ from kirin.interp import Frame, MethodTable, impl
3
+ from kirin.dialects.eltype import ElType
4
+ from kirin.dialects.py.binop import Add
5
+ from kirin.analysis.typeinfer import TypeInference
6
+ from kirin.dialects.py.indexing import GetItem
7
+
8
+ from .stmts import New, Push, IListType
9
+ from .runtime import IList
10
+ from ._dialect import dialect
11
+
12
+
13
@dialect.register(key="typeinfer")
class TypeInfer(MethodTable):
    """Type-inference rules for the ilist dialect.

    Besides the ilist statements (``New``, ``Push``), this table registers
    IList-specific rules for ``ElType``, ``py.binop.Add`` and
    ``py.indexing.GetItem``. The second type argument of an IList generic
    tracks the list length: a ``types.Literal`` holding an ``int`` when it is
    statically known, otherwise ``types.Any``.
    """

    @staticmethod
    def _get_list_len(typ: types.Generic):
        """Return the literal ``int`` length of ``typ``, or ``types.Any``.

        Assumes ``typ`` has at least two type arguments (element, length).
        """
        length = typ.vars[1]
        if isinstance(length, types.Literal) and isinstance(length.data, int):
            return length.data
        return types.Any

    @impl(ElType, types.PyClass(IList))
    def eltype_list(
        self, interp: TypeInference, frame: Frame[types.TypeAttribute], stmt: ElType
    ):
        """Element type of an IList container (``Any`` when not parameterized)."""
        list_type = frame.get(stmt.container)
        if isinstance(list_type, types.Generic):
            return (list_type.vars[0],)
        else:
            return (types.Any,)

    @impl(New)
    def new(self, interp: TypeInference, frame: Frame[types.TypeAttribute], stmt: New):
        """Join all value types; the length is the literal number of values."""
        values = frame.get_values(stmt.values)
        if not values:
            return (IListType[types.Any, types.Literal(0)],)

        elem_type = values[0]
        for v in values:
            elem_type = elem_type.join(v)

        return (IListType[elem_type, types.Literal(len(values))],)

    @impl(Push)
    def push(
        self, interp: TypeInference, frame: Frame[types.TypeAttribute], stmt: Push
    ):
        """Push one element onto a list.

        Returns ``Bottom`` when the target is not an IList or the pushed value
        does not fit the list's element type. A known length grows by one.
        """
        lst_type = frame.get(stmt.lst)
        value_type = frame.get(stmt.value)
        if not lst_type.is_subseteq(IListType):
            return (types.Bottom,)

        if not isinstance(lst_type, types.Generic):  # just annotated with list
            lst_type = IListType[types.Any, types.Any]

        elem_type = lst_type.vars[0]
        # The pushed value must be a subtype of the element type (the previous
        # check compared in the opposite direction, which rejected e.g.
        # pushing an Int into IList[Any]).
        if not value_type.is_subseteq(elem_type):
            return (types.Bottom,)

        lst_len = self._get_list_len(lst_type)
        if not isinstance(lst_len, int):
            return (IListType[elem_type, types.Any],)

        return (IListType[elem_type, types.Literal(lst_len + 1)],)

    @impl(Add, types.PyClass(IList), types.PyClass(IList))
    def add(self, interp: TypeInference, frame: Frame[types.TypeAttribute], stmt: Add):
        """Concatenation: join element types; add lengths when both are known.

        Raises:
            TypeError: if either operand type does not carry exactly two
                type arguments.
        """
        lhs_type = frame.get(stmt.lhs)
        rhs_type = frame.get(stmt.rhs)
        if not lhs_type.is_subseteq(IListType) or not rhs_type.is_subseteq(IListType):
            return (types.Bottom,)

        if not isinstance(lhs_type, types.Generic):  # just annotated with list
            lhs_type = IListType[types.Any, types.Any]

        if not isinstance(rhs_type, types.Generic):
            rhs_type = IListType[types.Any, types.Any]

        if len(lhs_type.vars) != 2 or len(rhs_type.vars) != 2:
            raise TypeError("missing type argument for list")

        elem_type = lhs_type.vars[0].join(rhs_type.vars[0])

        lhs_len = self._get_list_len(lhs_type)
        rhs_len = self._get_list_len(rhs_type)
        if isinstance(lhs_len, int) and isinstance(rhs_len, int):
            return (IListType[elem_type, types.Literal(lhs_len + rhs_len)],)
        return (IListType[elem_type, types.Any],)

    @impl(GetItem, types.PyClass(IList), types.PyClass(int))
    def getitem(
        self, interp: TypeInference, frame: Frame[types.TypeAttribute], stmt: GetItem
    ):
        """Integer indexing yields the element type.

        Raises:
            TypeError: if the indexed object is not an IList.
        """
        obj_type = frame.get(stmt.obj)
        if not obj_type.is_subseteq(IListType):
            raise TypeError(f"Expected list, got {obj_type}")

        # just list type
        if not isinstance(obj_type, types.Generic):
            return (types.Any,)
        else:
            return (obj_type.vars[0],)

    @impl(GetItem, types.PyClass(IList), types.PyClass(slice))
    def getitem_slice(
        self, interp: TypeInference, frame: Frame[types.TypeAttribute], stmt: GetItem
    ):
        """Slicing yields an IList of the same element type.

        When both the list length and the slice are constants, the resulting
        length is computed exactly.

        Raises:
            TypeError: if the indexed object is not an IList.
        """
        obj_type = frame.get(stmt.obj)
        if not obj_type.is_subseteq(IListType):
            raise TypeError(f"Expected list, got {obj_type}")

        # just list type
        if not isinstance(obj_type, types.Generic):
            return (IListType[types.Any, types.Any],)

        elem_type = obj_type.vars[0]
        index_ = interp.maybe_const(stmt.index, slice)
        obj_len = self._get_list_len(obj_type)
        if index_ is None or not isinstance(obj_len, int):
            return (IListType[elem_type, types.Any],)

        # Slicing a range of the same length gives the sliced list's length.
        new_len = len(range(obj_len)[index_])
        return (IListType[elem_type, types.Literal(new_len)],)
@@ -0,0 +1,7 @@
1
+ """This module contains the dialects for choosing different lowering strategies.
2
+
3
+ The dialects defined inside this module do not provide any new statements; they only
3
+ provide different lowering strategies for existing statements.
5
+ """
6
+
7
+ from . import cf as cf, call as call, func as func
@@ -0,0 +1,48 @@
1
+ import ast
2
+
3
+ from kirin import ir, types, lowering
4
+ from kirin.dialects import func
5
+ from kirin.exceptions import DialectLoweringError
6
+
7
dialect = ir.Dialect("lowering.call")


@dialect.register
class Lowering(lowering.FromPythonAST):
    """Lower Python ``ast.Call`` nodes into ``func.Call`` / ``func.Invoke``."""

    def lower_Call_local(
        self, state: lowering.LoweringState, callee: ir.SSAValue, node: ast.Call
    ) -> lowering.Result:
        """Lower a call whose callee is an SSA value into ``func.Call``."""
        args, keywords = self.__lower_Call_args_kwargs(state, node)
        stmt = func.Call(callee, args, kwargs=keywords)
        return lowering.Result(state.append_stmt(stmt))

    def lower_Call_global_method(
        self,
        state: lowering.LoweringState,
        method: ir.Method,
        node: ast.Call,
    ) -> lowering.Result:
        """Lower a call to a known ``ir.Method`` into ``func.Invoke``.

        The result type is the method's declared return type, or ``Any``.
        """
        args, keywords = self.__lower_Call_args_kwargs(state, node)
        stmt = func.Invoke(args, callee=method, kwargs=keywords)
        stmt.result.type = method.return_type or types.Any
        return lowering.Result(state.append_stmt(stmt))

    def __lower_Call_args_kwargs(
        self,
        state: lowering.LoweringState,
        node: ast.Call,
    ) -> tuple[tuple[ir.SSAValue, ...], tuple[str, ...]]:
        """Lower call arguments.

        Returns:
            ``(args, keywords)``: all positional values followed by the keyword
            values (in source order), and the keyword names. The last
            ``len(keywords)`` entries of ``args`` belong to the keywords.

        Raises:
            DialectLoweringError: on ``*args`` or ``**kwargs`` unpacking.
        """
        args: list[ir.SSAValue] = []
        for arg in node.args:
            if isinstance(arg, ast.Starred):  # TODO: support *args
                raise DialectLoweringError("starred arguments are not supported")
            args.append(state.visit(arg).expect_one())

        keywords: list[str] = []
        for kw in node.keywords:
            # ``f(**mapping)`` produces a keyword whose ``arg`` is ``None``;
            # previously ``None`` was recorded as a keyword name.
            if kw.arg is None:  # TODO: support **kwargs
                raise DialectLoweringError(
                    "double-starred keyword arguments are not supported"
                )
            keywords.append(kw.arg)
            args.append(state.visit(kw.value).expect_one())

        return tuple(args), tuple(keywords)
@@ -0,0 +1,206 @@
1
+ """Lowering Python AST to cf dialect.
2
+ """
3
+
4
+ import ast
5
+
6
+ from kirin import ir, types
7
+ from kirin.dialects import cf, py
8
+ from kirin.lowering import Frame, Result, FromPythonAST, LoweringState
9
+ from kirin.exceptions import DialectLoweringError
10
+
11
dialect = ir.Dialect("lowering.cf")


@dialect.register
class CfLowering(FromPythonAST):
    """Lower Python ``pass``, ``for`` and ``if`` statements into cf-dialect
    branches between blocks of the current region."""

    def lower_Pass(self, state: LoweringState, node: ast.Pass) -> Result:
        # ``pass`` is lowered to an unconditional branch to the current
        # frame's next block.
        state.append_stmt(
            cf.Branch(arguments=(), successor=state.current_frame.next_block)
        )
        return Result()

    def lower_For(self, state: LoweringState, node: ast.For) -> Result:
        """Lower ``for target in iter: body`` using ``Iter``/``Next``.

        Loop exit is detected by comparing ``Next``'s value against a ``None``
        constant with ``py.cmp.Is``. NOTE(review): this appears to assume
        ``Next`` yields ``None`` once the iterator is exhausted, so an iterable
        that yields ``None`` as an element would end the loop early — confirm.
        Variables captured from outside the body become block arguments of the
        body's entry block and are passed around the loop as "yields".
        """
        # Names of outer variables captured (and possibly updated) by the body.
        yields: list[str] = []

        def new_block_arg_if_inside_loop(frame: Frame, capture: ir.SSAValue):
            # Capture callback: turn a captured outer value into a block
            # argument of the body's entry block and remember its name.
            if not capture.name:
                raise DialectLoweringError("unexpected loop variable captured")
            yields.append(capture.name)
            return frame.entr_block.args.append_from(capture.type, capture.name)

        frame = state.current_frame
        iterable = state.visit(node.iter).expect_one()
        iter_stmt = frame.append_stmt(py.iterable.Iter(iterable))
        none_stmt = frame.append_stmt(py.Constant(None))

        body_frame = state.push_frame(
            Frame.from_stmts(
                node.body,
                state,
                region=state.current_frame.curr_region,
                globals=state.current_frame.globals,
                capture_callback=new_block_arg_if_inside_loop,
            )
        )
        # First block argument of the body entry is the current element.
        next_value = body_frame.entr_block.args.append_from(types.Any, "next_value")
        py.unpack.unpacking(state, node.target, next_value)
        state.exhaust(body_frame)
        self.branch_next_if_not_terminated(body_frame)
        # Loop back-edge: fetch the next element in the body frame's next
        # block; exit to the outer next block when exhausted, otherwise
        # re-enter the body with (next_value, *yields).
        yield_args = tuple(body_frame.get_scope(name) for name in yields)
        next_stmt = py.iterable.Next(iter_stmt.iter)
        cond_stmt = py.cmp.Is(next_stmt.value, none_stmt.result)
        body_frame.next_block.stmts.append(next_stmt)
        body_frame.next_block.stmts.append(cond_stmt)
        body_frame.next_block.stmts.append(
            cf.ConditionalBranch(
                cond_stmt.result,
                yield_args,
                (next_stmt.value,) + yield_args,
                then_successor=frame.next_block,
                else_successor=body_frame.entr_block,
            )
        )
        state.pop_frame()

        # insert the branch to the entrance of the loop (the code block before loop)
        next_stmt = frame.append_stmt(py.iterable.Next(iter_stmt.iter))
        cond_stmt = frame.append_stmt(py.cmp.Is(next_stmt.value, none_stmt.result))
        yield_args = tuple(frame.get_scope(name) for name in yields)
        frame.append_stmt(
            cf.ConditionalBranch(
                cond_stmt.result,
                yield_args,
                (next_stmt.value,) + yield_args,
                then_successor=frame.next_block,  # empty iterator
                else_successor=body_frame.entr_block,
            )
        )

        # After the loop, each yielded name is rebound to a new block argument
        # of the block that follows the loop.
        frame.jump_next()
        for name, arg in zip(yields, yield_args):
            input = frame.curr_block.args.append_from(arg.type, name)
            frame.defs[name] = input
        return Result()

    def lower_If(self, state: LoweringState, node: ast.If) -> Result:
        """Lower ``if``/``else`` into then/else frames joined by an after frame.

        Names assigned in either branch that must be merged (see ``phi``) are
        passed as block arguments of the after frame's entry block.
        """
        cond = state.visit(node.test).expect_one()
        frame = state.current_frame
        before_block = frame.curr_block
        if_frame = state.push_frame(
            Frame.from_stmts(
                node.body,
                state,
                region=frame.curr_region,
                globals=frame.globals,
            )
        )
        # The condition is passed into the branch as a Bool block argument and
        # rebound under its name, if it has one.
        true_cond = if_frame.entr_block.args.append_from(types.Bool, cond.name)
        if cond.name:
            if_frame.defs[cond.name] = true_cond
        state.exhaust()
        self.branch_next_if_not_terminated(if_frame)
        state.pop_frame()

        else_frame = state.push_frame(
            Frame.from_stmts(
                node.orelse,
                state,
                region=frame.curr_region,
                globals=frame.globals,
            )
        )
        # Same for the else branch (the name ``true_cond`` is reused here even
        # though this is the false branch).
        true_cond = else_frame.entr_block.args.append_from(types.Bool, cond.name)
        if cond.name:
            else_frame.defs[cond.name] = true_cond
        state.exhaust()
        self.branch_next_if_not_terminated(else_frame)
        state.pop_frame()

        # The remaining statements after the ``if`` are lowered in a new frame.
        after_frame = state.push_frame(
            Frame.from_stmts(
                frame.stream.split(),
                state,
                region=frame.curr_region,
                globals=frame.globals,
            )
        )

        after_frame.defs.update(frame.defs)
        # Names to merge: assigned in a branch and either visible in the
        # enclosing frame or assigned in both branches.
        phi: set[str] = set()
        for name in if_frame.defs.keys():
            if frame.get(name):
                phi.add(name)
            elif name in else_frame.defs:
                phi.add(name)

        for name in else_frame.defs.keys():
            if frame.get(name):  # visible in the enclosing frame
                phi.add(name)

        # NOTE: the loops over ``phi`` below (block args, if_args, else_args)
        # rely on iterating the same, unmodified set, so their orders agree.
        for name in phi:
            after_frame.defs[name] = after_frame.entr_block.args.append_from(
                types.Any, name
            )

        state.exhaust()
        self.branch_next_if_not_terminated(after_frame)
        after_frame.next_block.stmts.append(
            cf.Branch(arguments=(), successor=frame.next_block)
        )
        state.pop_frame()

        if_args = []
        for name in phi:
            if value := if_frame.get(name):
                if_args.append(value)
            else:
                raise DialectLoweringError(f"undefined variable {name} in if branch")

        else_args = []
        for name in phi:
            if value := else_frame.get(name):
                else_args.append(value)
            else:
                raise DialectLoweringError(f"undefined variable {name} in else branch")

        # Wire both branches to the after frame, passing the merged values.
        if_frame.next_block.stmts.append(
            cf.Branch(
                arguments=tuple(if_args),
                successor=after_frame.entr_block,
            )
        )
        else_frame.next_block.stmts.append(
            cf.Branch(
                arguments=tuple(else_args),
                successor=after_frame.entr_block,
            )
        )
        # Finally, terminate the block that preceded the ``if``.
        before_block.stmts.append(
            cf.ConditionalBranch(
                cond=cond,
                then_arguments=(cond,),
                then_successor=if_frame.entr_block,
                else_arguments=(cond,),
                else_successor=else_frame.entr_block,
            )
        )
        frame.jump_next()
        return Result()

    def branch_next_if_not_terminated(self, frame: Frame):
        """Branch to the next block if the current block is not terminated.

        This must be used after exhausting the current frame and before popping the frame.
        """
        if not frame.curr_block.last_stmt or not frame.curr_block.last_stmt.has_trait(
            ir.IsTerminator
        ):
            frame.curr_block.stmts.append(
                cf.Branch(arguments=(), successor=frame.next_block)
            )

    def current_block_terminated(self, frame: Frame):
        # Truthy when the current block ends with a terminator statement.
        # NOTE(review): returns the ``last_stmt`` value (possibly ``None``)
        # rather than a strict bool when the block is empty.
        return frame.curr_block.last_stmt and frame.curr_block.last_stmt.has_trait(
            ir.IsTerminator
        )
+ )