kirin-toolchain 0.13.0__py3-none-any.whl

Sign up to get free protection for your applications and to get access to all the features.
Files changed (225) hide show
  1. kirin/__init__.py +7 -0
  2. kirin/analysis/__init__.py +24 -0
  3. kirin/analysis/callgraph.py +61 -0
  4. kirin/analysis/cfg.py +112 -0
  5. kirin/analysis/const/__init__.py +20 -0
  6. kirin/analysis/const/_visitor.py +2 -0
  7. kirin/analysis/const/_visitor.pyi +8 -0
  8. kirin/analysis/const/lattice.py +219 -0
  9. kirin/analysis/const/prop.py +116 -0
  10. kirin/analysis/forward.py +100 -0
  11. kirin/analysis/typeinfer/__init__.py +5 -0
  12. kirin/analysis/typeinfer/analysis.py +90 -0
  13. kirin/analysis/typeinfer/solve.py +141 -0
  14. kirin/decl/__init__.py +108 -0
  15. kirin/decl/base.py +65 -0
  16. kirin/decl/camel2snake.py +2 -0
  17. kirin/decl/emit/__init__.py +0 -0
  18. kirin/decl/emit/_create_fn.py +29 -0
  19. kirin/decl/emit/_set_new_attribute.py +22 -0
  20. kirin/decl/emit/dialect.py +8 -0
  21. kirin/decl/emit/init.py +277 -0
  22. kirin/decl/emit/name.py +10 -0
  23. kirin/decl/emit/property.py +182 -0
  24. kirin/decl/emit/repr.py +31 -0
  25. kirin/decl/emit/traits.py +13 -0
  26. kirin/decl/emit/typecheck.py +77 -0
  27. kirin/decl/emit/verify.py +51 -0
  28. kirin/decl/info.py +346 -0
  29. kirin/decl/scan_fields.py +157 -0
  30. kirin/decl/verify.py +69 -0
  31. kirin/dialects/__init__.py +14 -0
  32. kirin/dialects/_pprint_helper.py +53 -0
  33. kirin/dialects/cf/__init__.py +20 -0
  34. kirin/dialects/cf/constprop.py +51 -0
  35. kirin/dialects/cf/dialect.py +3 -0
  36. kirin/dialects/cf/emit.py +58 -0
  37. kirin/dialects/cf/interp.py +24 -0
  38. kirin/dialects/cf/stmts.py +68 -0
  39. kirin/dialects/cf/typeinfer.py +27 -0
  40. kirin/dialects/eltype.py +23 -0
  41. kirin/dialects/func/__init__.py +20 -0
  42. kirin/dialects/func/attrs.py +39 -0
  43. kirin/dialects/func/constprop.py +138 -0
  44. kirin/dialects/func/dialect.py +3 -0
  45. kirin/dialects/func/emit.py +80 -0
  46. kirin/dialects/func/interp.py +68 -0
  47. kirin/dialects/func/stmts.py +233 -0
  48. kirin/dialects/func/typeinfer.py +124 -0
  49. kirin/dialects/ilist/__init__.py +33 -0
  50. kirin/dialects/ilist/_dialect.py +3 -0
  51. kirin/dialects/ilist/_wrapper.py +51 -0
  52. kirin/dialects/ilist/interp.py +85 -0
  53. kirin/dialects/ilist/lowering.py +25 -0
  54. kirin/dialects/ilist/passes.py +32 -0
  55. kirin/dialects/ilist/rewrite/__init__.py +3 -0
  56. kirin/dialects/ilist/rewrite/const.py +45 -0
  57. kirin/dialects/ilist/rewrite/list.py +38 -0
  58. kirin/dialects/ilist/rewrite/unroll.py +131 -0
  59. kirin/dialects/ilist/runtime.py +63 -0
  60. kirin/dialects/ilist/stmts.py +102 -0
  61. kirin/dialects/ilist/typeinfer.py +120 -0
  62. kirin/dialects/lowering/__init__.py +7 -0
  63. kirin/dialects/lowering/call.py +48 -0
  64. kirin/dialects/lowering/cf.py +206 -0
  65. kirin/dialects/lowering/func.py +134 -0
  66. kirin/dialects/math/__init__.py +41 -0
  67. kirin/dialects/math/_gen.py +176 -0
  68. kirin/dialects/math/dialect.py +3 -0
  69. kirin/dialects/math/interp.py +190 -0
  70. kirin/dialects/math/stmts.py +369 -0
  71. kirin/dialects/module.py +139 -0
  72. kirin/dialects/py/__init__.py +40 -0
  73. kirin/dialects/py/assertion.py +91 -0
  74. kirin/dialects/py/assign.py +103 -0
  75. kirin/dialects/py/attr.py +59 -0
  76. kirin/dialects/py/base.py +34 -0
  77. kirin/dialects/py/binop/__init__.py +23 -0
  78. kirin/dialects/py/binop/_dialect.py +3 -0
  79. kirin/dialects/py/binop/interp.py +60 -0
  80. kirin/dialects/py/binop/julia.py +33 -0
  81. kirin/dialects/py/binop/lowering.py +22 -0
  82. kirin/dialects/py/binop/stmts.py +79 -0
  83. kirin/dialects/py/binop/typeinfer.py +108 -0
  84. kirin/dialects/py/boolop.py +84 -0
  85. kirin/dialects/py/builtin.py +78 -0
  86. kirin/dialects/py/cmp/__init__.py +16 -0
  87. kirin/dialects/py/cmp/_dialect.py +3 -0
  88. kirin/dialects/py/cmp/interp.py +48 -0
  89. kirin/dialects/py/cmp/julia.py +33 -0
  90. kirin/dialects/py/cmp/lowering.py +45 -0
  91. kirin/dialects/py/cmp/stmts.py +62 -0
  92. kirin/dialects/py/constant.py +79 -0
  93. kirin/dialects/py/indexing.py +251 -0
  94. kirin/dialects/py/iterable.py +90 -0
  95. kirin/dialects/py/len.py +57 -0
  96. kirin/dialects/py/list/__init__.py +15 -0
  97. kirin/dialects/py/list/_dialect.py +3 -0
  98. kirin/dialects/py/list/interp.py +21 -0
  99. kirin/dialects/py/list/lowering.py +25 -0
  100. kirin/dialects/py/list/stmts.py +22 -0
  101. kirin/dialects/py/list/typeinfer.py +54 -0
  102. kirin/dialects/py/range.py +76 -0
  103. kirin/dialects/py/slice.py +120 -0
  104. kirin/dialects/py/tuple.py +109 -0
  105. kirin/dialects/py/unary/__init__.py +24 -0
  106. kirin/dialects/py/unary/_dialect.py +3 -0
  107. kirin/dialects/py/unary/constprop.py +20 -0
  108. kirin/dialects/py/unary/interp.py +24 -0
  109. kirin/dialects/py/unary/julia.py +21 -0
  110. kirin/dialects/py/unary/lowering.py +22 -0
  111. kirin/dialects/py/unary/stmts.py +33 -0
  112. kirin/dialects/py/unary/typeinfer.py +23 -0
  113. kirin/dialects/py/unpack.py +90 -0
  114. kirin/dialects/scf/__init__.py +23 -0
  115. kirin/dialects/scf/_dialect.py +3 -0
  116. kirin/dialects/scf/absint.py +64 -0
  117. kirin/dialects/scf/constprop.py +140 -0
  118. kirin/dialects/scf/interp.py +35 -0
  119. kirin/dialects/scf/lowering.py +123 -0
  120. kirin/dialects/scf/stmts.py +250 -0
  121. kirin/dialects/scf/trim.py +36 -0
  122. kirin/dialects/scf/typeinfer.py +58 -0
  123. kirin/dialects/scf/unroll.py +92 -0
  124. kirin/emit/__init__.py +3 -0
  125. kirin/emit/abc.py +89 -0
  126. kirin/emit/abc.pyi +38 -0
  127. kirin/emit/exceptions.py +5 -0
  128. kirin/emit/julia.py +63 -0
  129. kirin/emit/str.py +51 -0
  130. kirin/exceptions.py +59 -0
  131. kirin/graph.py +34 -0
  132. kirin/idtable.py +57 -0
  133. kirin/interp/__init__.py +39 -0
  134. kirin/interp/abstract.py +253 -0
  135. kirin/interp/base.py +438 -0
  136. kirin/interp/concrete.py +62 -0
  137. kirin/interp/exceptions.py +26 -0
  138. kirin/interp/frame.py +151 -0
  139. kirin/interp/impl.py +197 -0
  140. kirin/interp/result.py +93 -0
  141. kirin/interp/state.py +71 -0
  142. kirin/interp/table.py +40 -0
  143. kirin/interp/value.py +73 -0
  144. kirin/ir/__init__.py +46 -0
  145. kirin/ir/attrs/__init__.py +20 -0
  146. kirin/ir/attrs/_types.py +8 -0
  147. kirin/ir/attrs/_types.pyi +13 -0
  148. kirin/ir/attrs/abc.py +46 -0
  149. kirin/ir/attrs/py.py +45 -0
  150. kirin/ir/attrs/types.py +522 -0
  151. kirin/ir/dialect.py +125 -0
  152. kirin/ir/group.py +249 -0
  153. kirin/ir/method.py +118 -0
  154. kirin/ir/nodes/__init__.py +7 -0
  155. kirin/ir/nodes/base.py +149 -0
  156. kirin/ir/nodes/block.py +458 -0
  157. kirin/ir/nodes/region.py +337 -0
  158. kirin/ir/nodes/stmt.py +713 -0
  159. kirin/ir/nodes/view.py +142 -0
  160. kirin/ir/ssa.py +204 -0
  161. kirin/ir/traits/__init__.py +36 -0
  162. kirin/ir/traits/abc.py +42 -0
  163. kirin/ir/traits/basic.py +78 -0
  164. kirin/ir/traits/callable.py +51 -0
  165. kirin/ir/traits/lowering/__init__.py +2 -0
  166. kirin/ir/traits/lowering/call.py +37 -0
  167. kirin/ir/traits/lowering/context.py +120 -0
  168. kirin/ir/traits/region/__init__.py +2 -0
  169. kirin/ir/traits/region/ssacfg.py +22 -0
  170. kirin/ir/traits/symbol.py +57 -0
  171. kirin/ir/use.py +17 -0
  172. kirin/lattice/__init__.py +13 -0
  173. kirin/lattice/abc.py +128 -0
  174. kirin/lattice/empty.py +25 -0
  175. kirin/lattice/mixin.py +51 -0
  176. kirin/lowering/__init__.py +7 -0
  177. kirin/lowering/binding.py +65 -0
  178. kirin/lowering/core.py +72 -0
  179. kirin/lowering/dialect.py +35 -0
  180. kirin/lowering/dialect.pyi +183 -0
  181. kirin/lowering/frame.py +171 -0
  182. kirin/lowering/result.py +68 -0
  183. kirin/lowering/state.py +441 -0
  184. kirin/lowering/stream.py +53 -0
  185. kirin/passes/__init__.py +3 -0
  186. kirin/passes/abc.py +44 -0
  187. kirin/passes/aggressive/__init__.py +1 -0
  188. kirin/passes/aggressive/fold.py +43 -0
  189. kirin/passes/fold.py +45 -0
  190. kirin/passes/inline.py +25 -0
  191. kirin/passes/typeinfer.py +25 -0
  192. kirin/prelude.py +197 -0
  193. kirin/print/__init__.py +15 -0
  194. kirin/print/printable.py +141 -0
  195. kirin/print/printer.py +415 -0
  196. kirin/py.typed +0 -0
  197. kirin/registry.py +105 -0
  198. kirin/registry.pyi +52 -0
  199. kirin/rewrite/__init__.py +14 -0
  200. kirin/rewrite/abc.py +43 -0
  201. kirin/rewrite/aggressive/__init__.py +1 -0
  202. kirin/rewrite/aggressive/fold.py +43 -0
  203. kirin/rewrite/alias.py +16 -0
  204. kirin/rewrite/apply_type.py +47 -0
  205. kirin/rewrite/call2invoke.py +34 -0
  206. kirin/rewrite/chain.py +39 -0
  207. kirin/rewrite/compactify.py +288 -0
  208. kirin/rewrite/cse.py +48 -0
  209. kirin/rewrite/dce.py +19 -0
  210. kirin/rewrite/fixpoint.py +34 -0
  211. kirin/rewrite/fold.py +57 -0
  212. kirin/rewrite/getfield.py +21 -0
  213. kirin/rewrite/getitem.py +37 -0
  214. kirin/rewrite/inline.py +143 -0
  215. kirin/rewrite/result.py +15 -0
  216. kirin/rewrite/walk.py +83 -0
  217. kirin/rewrite/wrap_const.py +55 -0
  218. kirin/source.py +21 -0
  219. kirin/symbol_table.py +27 -0
  220. kirin/types.py +34 -0
  221. kirin/worklist.py +30 -0
  222. kirin_toolchain-0.13.0.dist-info/METADATA +42 -0
  223. kirin_toolchain-0.13.0.dist-info/RECORD +225 -0
  224. kirin_toolchain-0.13.0.dist-info/WHEEL +4 -0
  225. kirin_toolchain-0.13.0.dist-info/licenses/LICENSE +234 -0
@@ -0,0 +1,183 @@
1
+ import ast
2
+ import sys
3
+
4
+ from kirin.ir import Method, SSAValue
5
+ from kirin.lowering.state import LoweringState
6
+ from kirin.lowering.result import Result
7
+
8
class FromPythonAST:
    """Type stub for the Python-AST lowering interface of a dialect.

    Declares one ``lower_<Name>`` hook per Python ``ast`` node class, plus a few
    ``lower_Call_*`` variants. Each hook takes the current ``LoweringState`` and
    the node and returns a lowering ``Result``.

    The runtime behavior lives in the companion ``.py`` module, which is not part
    of this stub. Presumably ``lower`` dispatches to these hooks by node class
    name and ``unreachable`` is the fallback for unsupported nodes — TODO confirm
    against kirin/lowering/dialect.py.
    """

    # -- interface entry points ----------------------------------------------
    @property
    def names(self) -> list[str]: ...
    def lower(self, state: LoweringState, node: ast.AST) -> Result: ...
    def unreachable(self, state: LoweringState, node: ast.AST) -> Result: ...

    # -- top-level ``mod`` nodes ---------------------------------------------
    def lower_Module(self, state: LoweringState, node: ast.Module) -> Result: ...
    def lower_Interactive(
        self, state: LoweringState, node: ast.Interactive
    ) -> Result: ...
    def lower_Expression(
        self, state: LoweringState, node: ast.Expression
    ) -> Result: ...

    # -- statement (``stmt``) nodes ------------------------------------------
    def lower_FunctionDef(
        self, state: LoweringState, node: ast.FunctionDef
    ) -> Result: ...
    def lower_AsyncFunctionDef(
        self, state: LoweringState, node: ast.AsyncFunctionDef
    ) -> Result: ...
    def lower_ClassDef(self, state: LoweringState, node: ast.ClassDef) -> Result: ...
    def lower_Return(self, state: LoweringState, node: ast.Return) -> Result: ...
    def lower_Delete(self, state: LoweringState, node: ast.Delete) -> Result: ...
    def lower_Assign(self, state: LoweringState, node: ast.Assign) -> Result: ...
    def lower_AugAssign(self, state: LoweringState, node: ast.AugAssign) -> Result: ...
    def lower_AnnAssign(self, state: LoweringState, node: ast.AnnAssign) -> Result: ...
    def lower_For(self, state: LoweringState, node: ast.For) -> Result: ...
    def lower_AsyncFor(self, state: LoweringState, node: ast.AsyncFor) -> Result: ...
    def lower_While(self, state: LoweringState, node: ast.While) -> Result: ...
    def lower_If(self, state: LoweringState, node: ast.If) -> Result: ...
    def lower_With(self, state: LoweringState, node: ast.With) -> Result: ...
    def lower_AsyncWith(self, state: LoweringState, node: ast.AsyncWith) -> Result: ...
    def lower_Raise(self, state: LoweringState, node: ast.Raise) -> Result: ...
    def lower_Try(self, state: LoweringState, node: ast.Try) -> Result: ...
    def lower_Assert(self, state: LoweringState, node: ast.Assert) -> Result: ...
    def lower_Import(self, state: LoweringState, node: ast.Import) -> Result: ...
    def lower_ImportFrom(
        self, state: LoweringState, node: ast.ImportFrom
    ) -> Result: ...
    def lower_Global(self, state: LoweringState, node: ast.Global) -> Result: ...
    def lower_Nonlocal(self, state: LoweringState, node: ast.Nonlocal) -> Result: ...
    def lower_Expr(self, state: LoweringState, node: ast.Expr) -> Result: ...
    def lower_Pass(self, state: LoweringState, node: ast.Pass) -> Result: ...
    def lower_Break(self, state: LoweringState, node: ast.Break) -> Result: ...
    def lower_Continue(self, state: LoweringState, node: ast.Continue) -> Result: ...

    # -- expression (``expr``) nodes -----------------------------------------
    def lower_Slice(self, state: LoweringState, node: ast.Slice) -> Result: ...
    def lower_BoolOp(self, state: LoweringState, node: ast.BoolOp) -> Result: ...
    def lower_BinOp(self, state: LoweringState, node: ast.BinOp) -> Result: ...
    def lower_UnaryOp(self, state: LoweringState, node: ast.UnaryOp) -> Result: ...
    def lower_Lambda(self, state: LoweringState, node: ast.Lambda) -> Result: ...
    def lower_IfExp(self, state: LoweringState, node: ast.IfExp) -> Result: ...
    def lower_Dict(self, state: LoweringState, node: ast.Dict) -> Result: ...
    def lower_Set(self, state: LoweringState, node: ast.Set) -> Result: ...
    def lower_ListComp(self, state: LoweringState, node: ast.ListComp) -> Result: ...
    def lower_SetComp(self, state: LoweringState, node: ast.SetComp) -> Result: ...
    def lower_DictComp(self, state: LoweringState, node: ast.DictComp) -> Result: ...
    def lower_GeneratorExp(
        self, state: LoweringState, node: ast.GeneratorExp
    ) -> Result: ...
    def lower_Await(self, state: LoweringState, node: ast.Await) -> Result: ...
    def lower_Yield(self, state: LoweringState, node: ast.Yield) -> Result: ...
    def lower_YieldFrom(self, state: LoweringState, node: ast.YieldFrom) -> Result: ...
    def lower_Compare(self, state: LoweringState, node: ast.Compare) -> Result: ...
    def lower_Call(self, state: LoweringState, node: ast.Call) -> Result: ...
    # Specialized ``ast.Call`` hooks. The callee categories (builtins, global
    # methods, statements, slice/range/len/iter/next, local SSA callees) are
    # inferred from the names here — confirm the routing in dialect.py.
    def lower_Call_builtins(self, state: LoweringState, node: ast.Call) -> Result: ...
    def lower_Call_global_method(
        self, state: LoweringState, method: Method, node: ast.Call
    ) -> Result: ...
    def lower_Call_statement(self, state: LoweringState, node: ast.Call) -> Result: ...
    def lower_Call_slice(self, state: LoweringState, node: ast.Call) -> Result: ...
    def lower_Call_range(self, state: LoweringState, node: ast.Call) -> Result: ...
    def lower_Call_len(self, state: LoweringState, node: ast.Call) -> Result: ...
    def lower_Call_iter(self, state: LoweringState, node: ast.Call) -> Result: ...
    def lower_Call_next(self, state: LoweringState, node: ast.Call) -> Result: ...
    def lower_Call_local(
        self, state: LoweringState, callee: SSAValue, node: ast.Call
    ) -> Result: ...
    def lower_FormattedValue(
        self, state: LoweringState, node: ast.FormattedValue
    ) -> Result: ...
    def lower_JoinedStr(self, state: LoweringState, node: ast.JoinedStr) -> Result: ...
    def lower_Constant(self, state: LoweringState, node: ast.Constant) -> Result: ...
    def lower_NamedExpr(self, state: LoweringState, node: ast.NamedExpr) -> Result: ...
    # ``type_ignore`` node (not an expression; listed here in original order)
    def lower_TypeIgnore(
        self, state: LoweringState, node: ast.TypeIgnore
    ) -> Result: ...
    # remaining expression nodes (attribute/subscript access, names, literals)
    def lower_Attribute(self, state: LoweringState, node: ast.Attribute) -> Result: ...
    def lower_Subscript(self, state: LoweringState, node: ast.Subscript) -> Result: ...
    def lower_Starred(self, state: LoweringState, node: ast.Starred) -> Result: ...
    def lower_Name(self, state: LoweringState, node: ast.Name) -> Result: ...
    def lower_List(self, state: LoweringState, node: ast.List) -> Result: ...
    def lower_Tuple(self, state: LoweringState, node: ast.Tuple) -> Result: ...

    # -- ``expr_context`` nodes ----------------------------------------------
    def lower_Del(self, state: LoweringState, node: ast.Del) -> Result: ...
    def lower_Load(self, state: LoweringState, node: ast.Load) -> Result: ...
    def lower_Store(self, state: LoweringState, node: ast.Store) -> Result: ...

    # -- ``boolop`` nodes ----------------------------------------------------
    def lower_And(self, state: LoweringState, node: ast.And) -> Result: ...
    def lower_Or(self, state: LoweringState, node: ast.Or) -> Result: ...

    # -- binary ``operator`` nodes -------------------------------------------
    def lower_Add(self, state: LoweringState, node: ast.Add) -> Result: ...
    def lower_BitAnd(self, state: LoweringState, node: ast.BitAnd) -> Result: ...
    def lower_BitOr(self, state: LoweringState, node: ast.BitOr) -> Result: ...
    def lower_BitXor(self, state: LoweringState, node: ast.BitXor) -> Result: ...
    def lower_Div(self, state: LoweringState, node: ast.Div) -> Result: ...
    def lower_FloorDiv(self, state: LoweringState, node: ast.FloorDiv) -> Result: ...
    def lower_LShift(self, state: LoweringState, node: ast.LShift) -> Result: ...
    def lower_Mod(self, state: LoweringState, node: ast.Mod) -> Result: ...
    def lower_Mult(self, state: LoweringState, node: ast.Mult) -> Result: ...
    def lower_MatMult(self, state: LoweringState, node: ast.MatMult) -> Result: ...
    def lower_Pow(self, state: LoweringState, node: ast.Pow) -> Result: ...
    def lower_RShift(self, state: LoweringState, node: ast.RShift) -> Result: ...
    def lower_Sub(self, state: LoweringState, node: ast.Sub) -> Result: ...

    # -- ``unaryop`` nodes ---------------------------------------------------
    def lower_Invert(self, state: LoweringState, node: ast.Invert) -> Result: ...
    def lower_Not(self, state: LoweringState, node: ast.Not) -> Result: ...
    def lower_UAdd(self, state: LoweringState, node: ast.UAdd) -> Result: ...
    def lower_USub(self, state: LoweringState, node: ast.USub) -> Result: ...

    # -- comparison (``cmpop``) nodes ----------------------------------------
    def lower_Eq(self, state: LoweringState, node: ast.Eq) -> Result: ...
    def lower_Gt(self, state: LoweringState, node: ast.Gt) -> Result: ...
    def lower_GtE(self, state: LoweringState, node: ast.GtE) -> Result: ...
    def lower_In(self, state: LoweringState, node: ast.In) -> Result: ...
    def lower_Is(self, state: LoweringState, node: ast.Is) -> Result: ...
    def lower_IsNot(self, state: LoweringState, node: ast.IsNot) -> Result: ...
    def lower_Lt(self, state: LoweringState, node: ast.Lt) -> Result: ...
    def lower_LtE(self, state: LoweringState, node: ast.LtE) -> Result: ...
    def lower_NotEq(self, state: LoweringState, node: ast.NotEq) -> Result: ...
    def lower_NotIn(self, state: LoweringState, node: ast.NotIn) -> Result: ...

    # -- auxiliary nodes (comprehension clauses, handlers, call/def parts) ---
    def lower_comprehension(
        self, state: LoweringState, node: ast.comprehension
    ) -> Result: ...
    def lower_ExceptHandler(
        self, state: LoweringState, node: ast.ExceptHandler
    ) -> Result: ...
    def lower_arguments(self, state: LoweringState, node: ast.arguments) -> Result: ...
    def lower_arg(self, state: LoweringState, node: ast.arg) -> Result: ...
    def lower_keyword(self, state: LoweringState, node: ast.keyword) -> Result: ...
    def lower_alias(self, state: LoweringState, node: ast.alias) -> Result: ...
    def lower_withitem(self, state: LoweringState, node: ast.withitem) -> Result: ...

    # -- version-gated nodes: only declared when the running Python's ``ast``
    # module defines them.
    if sys.version_info >= (3, 10):
        # structural pattern matching: ``match`` statement and pattern nodes
        def lower_Match(self, state: LoweringState, node: ast.Match) -> Result: ...
        def lower_match_case(
            self, state: LoweringState, node: ast.match_case
        ) -> Result: ...
        def lower_MatchValue(
            self, state: LoweringState, node: ast.MatchValue
        ) -> Result: ...
        def lower_MatchSequence(
            self, state: LoweringState, node: ast.MatchSequence
        ) -> Result: ...
        def lower_MatchSingleton(
            self, state: LoweringState, node: ast.MatchSingleton
        ) -> Result: ...
        def lower_MatchStar(
            self, state: LoweringState, node: ast.MatchStar
        ) -> Result: ...
        def lower_MatchMapping(
            self, state: LoweringState, node: ast.MatchMapping
        ) -> Result: ...
        def lower_MatchClass(
            self, state: LoweringState, node: ast.MatchClass
        ) -> Result: ...
        def lower_MatchAs(self, state: LoweringState, node: ast.MatchAs) -> Result: ...
        def lower_MatchOr(self, state: LoweringState, node: ast.MatchOr) -> Result: ...

    if sys.version_info >= (3, 11):
        # ``try`` ... ``except*`` statement
        def lower_TryStar(self, state: LoweringState, node: ast.TryStar) -> Result: ...

    if sys.version_info >= (3, 12):
        # PEP 695 type-parameter nodes and the ``type`` alias statement
        def lower_TypeVar(self, state: LoweringState, node: ast.TypeVar) -> Result: ...
        def lower_ParamSpec(
            self, state: LoweringState, node: ast.ParamSpec
        ) -> Result: ...
        def lower_TypeVarTuple(
            self, state: LoweringState, node: ast.TypeVarTuple
        ) -> Result: ...
        def lower_TypeAlias(
            self, state: LoweringState, node: ast.TypeAlias
        ) -> Result: ...
181
+
182
class NoSpecialLowering(FromPythonAST):
    """A ``FromPythonAST`` that declares no hooks of its own.

    Every lowering hook is inherited unchanged from ``FromPythonAST``.
    """
@@ -0,0 +1,171 @@
1
+ import ast
2
+ from typing import TYPE_CHECKING, Any, TypeVar, Callable, Optional, Sequence
3
+ from dataclasses import field, dataclass
4
+
5
+ from kirin.ir import Block, Region, SSAValue, Statement
6
+ from kirin.exceptions import DialectLoweringError
7
+ from kirin.lowering.stream import StmtStream
8
+
9
+ if TYPE_CHECKING:
10
+ from kirin.lowering.state import LoweringState
11
+
12
+
13
# Signature of ``Frame.capture_callback``. It is called as
# ``callback(frame, captured_value)`` when ``frame`` reads a name that was
# found in an enclosing frame. It returns the ``SSAValue`` that ``frame``
# records in its own ``defs`` and uses in place of the captured one.
CallbackFn = Callable[["Frame", SSAValue], SSAValue]
14
+
15
+
16
@dataclass
class Frame:
    """Lowering context for one sequence of Python statements.

    A frame holds three things:

    - the stream of AST statements still to be lowered;
    - the region and blocks that lowered IR statements are appended to;
    - the names visible while lowering: local SSA definitions, globals, and
      values captured from enclosing (parent) frames.
    """

    state: "LoweringState"
    """lowering state"""
    parent: Optional["Frame"]
    """parent frame, if any"""
    stream: StmtStream[ast.stmt]
    """stream of statements to be lowered"""

    curr_region: Region
    """current region being lowered"""
    entr_block: Block
    """entry block of the frame region"""
    curr_block: Block
    """current block being lowered"""
    next_block: Block
    """next block to be lowered, but not yet inserted in the region"""

    # known variables, local SSA values or global values
    defs: dict[str, SSAValue] = field(default_factory=dict)
    """values defined in the current frame"""
    globals: dict[str, Any] = field(default_factory=dict)
    """global values known to the current frame"""
    captures: dict[str, SSAValue] = field(default_factory=dict)
    """values accessed from the parent frame"""
    capture_callback: Optional[CallbackFn] = None
    """callback function that creates a local SSAValue value when a captured value was used."""

    @classmethod
    def from_stmts(
        cls,
        stmts: Sequence[ast.stmt] | StmtStream[ast.stmt],
        state: "LoweringState",
        parent: Optional["Frame"] = None,
        region: Optional[Region] = None,
        entr_block: Optional[Block] = None,
        next_block: Optional[Block] = None,
        globals: dict[str, Any] | None = None,
        capture_callback: Optional[CallbackFn] = None,
    ):
        """Create a new frame from a list of statements or a new `StmtStream`.

        - `stmts`: list of statements or a `StmtStream` to be lowered.
        - `state`: the `LoweringState` this frame belongs to.
        - `parent`: enclosing frame consulted by `get_local`, default `None`.
        - `region`: `Region` to append the new block to, `None` to create a new one, default `None`.
        - `entr_block`: `Block` to append the new statements to, `None` to create
            a new one. Either way, it is appended to the region. Default `None`.
        - `next_block`: `Block` to use if branching to a new block, if `None` to create
            a new one without attaching to the region. (note: this should not attach to
            the region at frame construction)
        - `globals`: global variables, default `None` (an empty dict).
        - `capture_callback`: stored as `Frame.capture_callback`, default `None`.
        """
        if not isinstance(stmts, StmtStream):
            stmts = StmtStream(stmts)

        # Explicit `is None` checks rather than `x or default`: a caller-supplied
        # IR node must never be swapped for a fresh one just because it is falsy.
        if region is None:
            region = Region()
        if entr_block is None:
            entr_block = Block()
        region.blocks.append(entr_block)

        return cls(
            state=state,
            parent=parent,
            stream=stmts,
            curr_region=region,
            entr_block=entr_block,
            curr_block=entr_block,
            next_block=Block() if next_block is None else next_block,
            globals={} if globals is None else globals,
            capture_callback=capture_callback,
        )

    def get(self, name: str) -> SSAValue | None:
        """Look up `name`: local and captured values first, then globals.

        A global hit is turned into an SSA value by visiting
        `ast.Constant(value)` with the lowering state.

        Returns:
            SSAValue | None: the value, or `None` if the name is unknown.
        """
        value = self.get_local(name)
        if value is not None:
            return value

        # NOTE: look up local first, then globals
        if name in self.globals:
            return self.state.visit(ast.Constant(self.globals[name])).expect_one()
        return None

    def get_local(self, name: str) -> SSAValue | None:
        """Look up `name` in this frame, then recursively in parent frames.

        A value found in a parent frame is recorded in `captures`. If
        `capture_callback` is set, its return value is stored in `defs` and
        returned in place of the parent's value.

        Returns:
            SSAValue | None: the value, or `None` if no frame defines the name.
        """
        if name in self.defs:
            return self.defs[name]

        if self.parent is None:
            return None  # no parent frame, return None

        value = self.parent.get_local(name)
        if value is not None:
            self.captures[name] = value
            if self.capture_callback:
                # whatever generates a local value gets defined
                ret = self.capture_callback(self, value)
                self.defs[name] = ret
                return ret
            return value
        return None

    def get_scope(self, name: str) -> SSAValue:
        """Get a variable defined directly in this frame (parents are not searched).

        Args:
            name(str): variable name

        Returns:
            SSAValue: the value of the variable

        Raises:
            DialectLoweringError: if `name` has no SSA value in this frame's `defs`.
        """
        value = self.defs.get(name)
        if not isinstance(value, SSAValue):
            raise DialectLoweringError(f"Variable {name} not found in scope")
        return value

    StmtType = TypeVar("StmtType", bound=Statement)

    def append_stmt(self, stmt: StmtType) -> StmtType:
        """Append `stmt` to the current block and tag it with the current source.

        Returns:
            the same statement, for chaining.

        Raises:
            DialectLoweringError: if the statement has no dialect, or its dialect
                is not among the lowering state's dialects.
        """
        if not stmt.dialect:
            raise DialectLoweringError(f"unexpected builtin statement {stmt.name}")
        elif stmt.dialect not in self.state.dialects:
            raise DialectLoweringError(
                f"Unsupported dialect `{stmt.dialect.name}` in statement {stmt.name}"
            )
        self.curr_block.stmts.append(stmt)
        stmt.source = self.state.source
        return stmt

    def jump_next(self):
        """Jump to the next block and return it.
        This appends the current `Frame.next_block` to the current region
        and creates a new Block for `next_block`.

        Returns:
            Block: the next block
        """
        block = self.append_block(self.next_block)
        self.next_block = Block()
        return block

    def append_block(self, block: Block | None = None):
        """Append a block to the current region and make it the current block.

        Args:
            block(Block): block to append, default `None` to create a new block.

        Returns:
            Block: the appended block.
        """
        # `is None`, not truthiness: `jump_next` passes `next_block`, which is
        # normally still empty and must be the very object that gets appended.
        if block is None:
            block = Block()
        self.curr_region.blocks.append(block)
        self.curr_block = block
        return block

    def __repr__(self):
        return f"Frame({len(self.defs)} defs, {len(self.globals)} globals)"
@@ -0,0 +1,68 @@
1
+ from __future__ import annotations
2
+
3
+ from typing import Iterable, Sequence, overload
4
+ from dataclasses import field, dataclass
5
+
6
+ from kirin.ir import SSAValue, Statement
7
+ from kirin.exceptions import DialectLoweringError
8
+
9
+
10
@dataclass(init=False)
class Result(Sequence[SSAValue]):
    """The SSA values produced by lowering one Python AST node.

    Acts as a read-only ``Sequence[SSAValue]`` by forwarding to ``values``.
    It can be built from:

    - nothing;
    - one or more ``SSAValue`` arguments;
    - an iterable of ``SSAValue``;
    - a ``Statement``, whose result values are used.
    """

    values: Sequence[SSAValue] = field(default_factory=list)

    @overload
    def __init__(self, value: None = None) -> None: ...

    @overload
    def __init__(self, value: SSAValue, *values: SSAValue) -> None: ...

    @overload
    def __init__(self, value: Iterable[SSAValue]) -> None: ...

    @overload
    def __init__(self, value: Statement) -> None: ...

    def __init__(
        self,
        value: SSAValue | Iterable[SSAValue] | Statement | None = None,
        *values: SSAValue,
    ) -> None:
        """Collect the given value(s) into a new list stored in ``values``.

        Extra positional ``values`` are only accepted when ``value`` is a single
        ``SSAValue``. The other forms assert that none were passed.
        """
        if value is None:
            assert not values, "unexpected values"
            self.values = []
        elif isinstance(value, SSAValue):
            self.values = [value, *values]
        elif isinstance(value, Statement):
            assert not values, "unexpected values"
            # Copy instead of storing ``value._results`` itself. Otherwise this
            # Result would alias the statement's internal container, and any
            # later mutation of it would show up here as well.
            self.values = list(value._results)
        else:
            assert not values, "unexpected values"
            self.values = list(value)

    def expect_one(self) -> SSAValue:
        """Return the only value in this result.

        Raises:
            DialectLoweringError: if the result does not hold exactly one value.
        """
        if len(self.values) != 1:
            raise DialectLoweringError("expected one result")
        return self.values[0]

    # forward the sequence methods
    def __len__(self):
        return len(self.values)

    def __getitem__(self, key):
        return self.values[key]

    def __iter__(self):
        return iter(self.values)

    def __contains__(self, value):
        return value in self.values

    def __reversed__(self):
        return reversed(self.values)

    def __eq__(self, other: object) -> bool:
        # Return NotImplemented for non-Result operands so Python falls back to
        # its default comparison, instead of raising AttributeError on
        # ``other.values``.
        if not isinstance(other, Result):
            return NotImplemented
        return self.values == other.values

    def __ne__(self, other: object) -> bool:
        if not isinstance(other, Result):
            return NotImplemented
        return self.values != other.values