kirin-toolchain 0.13.0__py3-none-any.whl

Sign up to get free protection for your applications and to get access to all the features.
Files changed (225) hide show
  1. kirin/__init__.py +7 -0
  2. kirin/analysis/__init__.py +24 -0
  3. kirin/analysis/callgraph.py +61 -0
  4. kirin/analysis/cfg.py +112 -0
  5. kirin/analysis/const/__init__.py +20 -0
  6. kirin/analysis/const/_visitor.py +2 -0
  7. kirin/analysis/const/_visitor.pyi +8 -0
  8. kirin/analysis/const/lattice.py +219 -0
  9. kirin/analysis/const/prop.py +116 -0
  10. kirin/analysis/forward.py +100 -0
  11. kirin/analysis/typeinfer/__init__.py +5 -0
  12. kirin/analysis/typeinfer/analysis.py +90 -0
  13. kirin/analysis/typeinfer/solve.py +141 -0
  14. kirin/decl/__init__.py +108 -0
  15. kirin/decl/base.py +65 -0
  16. kirin/decl/camel2snake.py +2 -0
  17. kirin/decl/emit/__init__.py +0 -0
  18. kirin/decl/emit/_create_fn.py +29 -0
  19. kirin/decl/emit/_set_new_attribute.py +22 -0
  20. kirin/decl/emit/dialect.py +8 -0
  21. kirin/decl/emit/init.py +277 -0
  22. kirin/decl/emit/name.py +10 -0
  23. kirin/decl/emit/property.py +182 -0
  24. kirin/decl/emit/repr.py +31 -0
  25. kirin/decl/emit/traits.py +13 -0
  26. kirin/decl/emit/typecheck.py +77 -0
  27. kirin/decl/emit/verify.py +51 -0
  28. kirin/decl/info.py +346 -0
  29. kirin/decl/scan_fields.py +157 -0
  30. kirin/decl/verify.py +69 -0
  31. kirin/dialects/__init__.py +14 -0
  32. kirin/dialects/_pprint_helper.py +53 -0
  33. kirin/dialects/cf/__init__.py +20 -0
  34. kirin/dialects/cf/constprop.py +51 -0
  35. kirin/dialects/cf/dialect.py +3 -0
  36. kirin/dialects/cf/emit.py +58 -0
  37. kirin/dialects/cf/interp.py +24 -0
  38. kirin/dialects/cf/stmts.py +68 -0
  39. kirin/dialects/cf/typeinfer.py +27 -0
  40. kirin/dialects/eltype.py +23 -0
  41. kirin/dialects/func/__init__.py +20 -0
  42. kirin/dialects/func/attrs.py +39 -0
  43. kirin/dialects/func/constprop.py +138 -0
  44. kirin/dialects/func/dialect.py +3 -0
  45. kirin/dialects/func/emit.py +80 -0
  46. kirin/dialects/func/interp.py +68 -0
  47. kirin/dialects/func/stmts.py +233 -0
  48. kirin/dialects/func/typeinfer.py +124 -0
  49. kirin/dialects/ilist/__init__.py +33 -0
  50. kirin/dialects/ilist/_dialect.py +3 -0
  51. kirin/dialects/ilist/_wrapper.py +51 -0
  52. kirin/dialects/ilist/interp.py +85 -0
  53. kirin/dialects/ilist/lowering.py +25 -0
  54. kirin/dialects/ilist/passes.py +32 -0
  55. kirin/dialects/ilist/rewrite/__init__.py +3 -0
  56. kirin/dialects/ilist/rewrite/const.py +45 -0
  57. kirin/dialects/ilist/rewrite/list.py +38 -0
  58. kirin/dialects/ilist/rewrite/unroll.py +131 -0
  59. kirin/dialects/ilist/runtime.py +63 -0
  60. kirin/dialects/ilist/stmts.py +102 -0
  61. kirin/dialects/ilist/typeinfer.py +120 -0
  62. kirin/dialects/lowering/__init__.py +7 -0
  63. kirin/dialects/lowering/call.py +48 -0
  64. kirin/dialects/lowering/cf.py +206 -0
  65. kirin/dialects/lowering/func.py +134 -0
  66. kirin/dialects/math/__init__.py +41 -0
  67. kirin/dialects/math/_gen.py +176 -0
  68. kirin/dialects/math/dialect.py +3 -0
  69. kirin/dialects/math/interp.py +190 -0
  70. kirin/dialects/math/stmts.py +369 -0
  71. kirin/dialects/module.py +139 -0
  72. kirin/dialects/py/__init__.py +40 -0
  73. kirin/dialects/py/assertion.py +91 -0
  74. kirin/dialects/py/assign.py +103 -0
  75. kirin/dialects/py/attr.py +59 -0
  76. kirin/dialects/py/base.py +34 -0
  77. kirin/dialects/py/binop/__init__.py +23 -0
  78. kirin/dialects/py/binop/_dialect.py +3 -0
  79. kirin/dialects/py/binop/interp.py +60 -0
  80. kirin/dialects/py/binop/julia.py +33 -0
  81. kirin/dialects/py/binop/lowering.py +22 -0
  82. kirin/dialects/py/binop/stmts.py +79 -0
  83. kirin/dialects/py/binop/typeinfer.py +108 -0
  84. kirin/dialects/py/boolop.py +84 -0
  85. kirin/dialects/py/builtin.py +78 -0
  86. kirin/dialects/py/cmp/__init__.py +16 -0
  87. kirin/dialects/py/cmp/_dialect.py +3 -0
  88. kirin/dialects/py/cmp/interp.py +48 -0
  89. kirin/dialects/py/cmp/julia.py +33 -0
  90. kirin/dialects/py/cmp/lowering.py +45 -0
  91. kirin/dialects/py/cmp/stmts.py +62 -0
  92. kirin/dialects/py/constant.py +79 -0
  93. kirin/dialects/py/indexing.py +251 -0
  94. kirin/dialects/py/iterable.py +90 -0
  95. kirin/dialects/py/len.py +57 -0
  96. kirin/dialects/py/list/__init__.py +15 -0
  97. kirin/dialects/py/list/_dialect.py +3 -0
  98. kirin/dialects/py/list/interp.py +21 -0
  99. kirin/dialects/py/list/lowering.py +25 -0
  100. kirin/dialects/py/list/stmts.py +22 -0
  101. kirin/dialects/py/list/typeinfer.py +54 -0
  102. kirin/dialects/py/range.py +76 -0
  103. kirin/dialects/py/slice.py +120 -0
  104. kirin/dialects/py/tuple.py +109 -0
  105. kirin/dialects/py/unary/__init__.py +24 -0
  106. kirin/dialects/py/unary/_dialect.py +3 -0
  107. kirin/dialects/py/unary/constprop.py +20 -0
  108. kirin/dialects/py/unary/interp.py +24 -0
  109. kirin/dialects/py/unary/julia.py +21 -0
  110. kirin/dialects/py/unary/lowering.py +22 -0
  111. kirin/dialects/py/unary/stmts.py +33 -0
  112. kirin/dialects/py/unary/typeinfer.py +23 -0
  113. kirin/dialects/py/unpack.py +90 -0
  114. kirin/dialects/scf/__init__.py +23 -0
  115. kirin/dialects/scf/_dialect.py +3 -0
  116. kirin/dialects/scf/absint.py +64 -0
  117. kirin/dialects/scf/constprop.py +140 -0
  118. kirin/dialects/scf/interp.py +35 -0
  119. kirin/dialects/scf/lowering.py +123 -0
  120. kirin/dialects/scf/stmts.py +250 -0
  121. kirin/dialects/scf/trim.py +36 -0
  122. kirin/dialects/scf/typeinfer.py +58 -0
  123. kirin/dialects/scf/unroll.py +92 -0
  124. kirin/emit/__init__.py +3 -0
  125. kirin/emit/abc.py +89 -0
  126. kirin/emit/abc.pyi +38 -0
  127. kirin/emit/exceptions.py +5 -0
  128. kirin/emit/julia.py +63 -0
  129. kirin/emit/str.py +51 -0
  130. kirin/exceptions.py +59 -0
  131. kirin/graph.py +34 -0
  132. kirin/idtable.py +57 -0
  133. kirin/interp/__init__.py +39 -0
  134. kirin/interp/abstract.py +253 -0
  135. kirin/interp/base.py +438 -0
  136. kirin/interp/concrete.py +62 -0
  137. kirin/interp/exceptions.py +26 -0
  138. kirin/interp/frame.py +151 -0
  139. kirin/interp/impl.py +197 -0
  140. kirin/interp/result.py +93 -0
  141. kirin/interp/state.py +71 -0
  142. kirin/interp/table.py +40 -0
  143. kirin/interp/value.py +73 -0
  144. kirin/ir/__init__.py +46 -0
  145. kirin/ir/attrs/__init__.py +20 -0
  146. kirin/ir/attrs/_types.py +8 -0
  147. kirin/ir/attrs/_types.pyi +13 -0
  148. kirin/ir/attrs/abc.py +46 -0
  149. kirin/ir/attrs/py.py +45 -0
  150. kirin/ir/attrs/types.py +522 -0
  151. kirin/ir/dialect.py +125 -0
  152. kirin/ir/group.py +249 -0
  153. kirin/ir/method.py +118 -0
  154. kirin/ir/nodes/__init__.py +7 -0
  155. kirin/ir/nodes/base.py +149 -0
  156. kirin/ir/nodes/block.py +458 -0
  157. kirin/ir/nodes/region.py +337 -0
  158. kirin/ir/nodes/stmt.py +713 -0
  159. kirin/ir/nodes/view.py +142 -0
  160. kirin/ir/ssa.py +204 -0
  161. kirin/ir/traits/__init__.py +36 -0
  162. kirin/ir/traits/abc.py +42 -0
  163. kirin/ir/traits/basic.py +78 -0
  164. kirin/ir/traits/callable.py +51 -0
  165. kirin/ir/traits/lowering/__init__.py +2 -0
  166. kirin/ir/traits/lowering/call.py +37 -0
  167. kirin/ir/traits/lowering/context.py +120 -0
  168. kirin/ir/traits/region/__init__.py +2 -0
  169. kirin/ir/traits/region/ssacfg.py +22 -0
  170. kirin/ir/traits/symbol.py +57 -0
  171. kirin/ir/use.py +17 -0
  172. kirin/lattice/__init__.py +13 -0
  173. kirin/lattice/abc.py +128 -0
  174. kirin/lattice/empty.py +25 -0
  175. kirin/lattice/mixin.py +51 -0
  176. kirin/lowering/__init__.py +7 -0
  177. kirin/lowering/binding.py +65 -0
  178. kirin/lowering/core.py +72 -0
  179. kirin/lowering/dialect.py +35 -0
  180. kirin/lowering/dialect.pyi +183 -0
  181. kirin/lowering/frame.py +171 -0
  182. kirin/lowering/result.py +68 -0
  183. kirin/lowering/state.py +441 -0
  184. kirin/lowering/stream.py +53 -0
  185. kirin/passes/__init__.py +3 -0
  186. kirin/passes/abc.py +44 -0
  187. kirin/passes/aggressive/__init__.py +1 -0
  188. kirin/passes/aggressive/fold.py +43 -0
  189. kirin/passes/fold.py +45 -0
  190. kirin/passes/inline.py +25 -0
  191. kirin/passes/typeinfer.py +25 -0
  192. kirin/prelude.py +197 -0
  193. kirin/print/__init__.py +15 -0
  194. kirin/print/printable.py +141 -0
  195. kirin/print/printer.py +415 -0
  196. kirin/py.typed +0 -0
  197. kirin/registry.py +105 -0
  198. kirin/registry.pyi +52 -0
  199. kirin/rewrite/__init__.py +14 -0
  200. kirin/rewrite/abc.py +43 -0
  201. kirin/rewrite/aggressive/__init__.py +1 -0
  202. kirin/rewrite/aggressive/fold.py +43 -0
  203. kirin/rewrite/alias.py +16 -0
  204. kirin/rewrite/apply_type.py +47 -0
  205. kirin/rewrite/call2invoke.py +34 -0
  206. kirin/rewrite/chain.py +39 -0
  207. kirin/rewrite/compactify.py +288 -0
  208. kirin/rewrite/cse.py +48 -0
  209. kirin/rewrite/dce.py +19 -0
  210. kirin/rewrite/fixpoint.py +34 -0
  211. kirin/rewrite/fold.py +57 -0
  212. kirin/rewrite/getfield.py +21 -0
  213. kirin/rewrite/getitem.py +37 -0
  214. kirin/rewrite/inline.py +143 -0
  215. kirin/rewrite/result.py +15 -0
  216. kirin/rewrite/walk.py +83 -0
  217. kirin/rewrite/wrap_const.py +55 -0
  218. kirin/source.py +21 -0
  219. kirin/symbol_table.py +27 -0
  220. kirin/types.py +34 -0
  221. kirin/worklist.py +30 -0
  222. kirin_toolchain-0.13.0.dist-info/METADATA +42 -0
  223. kirin_toolchain-0.13.0.dist-info/RECORD +225 -0
  224. kirin_toolchain-0.13.0.dist-info/WHEEL +4 -0
  225. kirin_toolchain-0.13.0.dist-info/licenses/LICENSE +234 -0
@@ -0,0 +1,15 @@
1
+ """The list dialect for Python.
2
+
3
+ This module contains the dialect for list semantics in Python, including:
4
+
5
+ - The `New` and `Append` statement classes.
6
+ - The lowering pass for list operations.
7
+ - The concrete implementation of list operations.
8
+ - The type inference implementation of list operations.
9
+
10
+ This dialect maps `list()`, `ast.List` and `append()` calls to the `New` and `Append` statements.
11
+ """
12
+
13
+ from . import interp as interp, lowering as lowering, typeinfer as typeinfer
14
+ from .stmts import New as New, Append as Append
15
+ from ._dialect import dialect as dialect
@@ -0,0 +1,3 @@
1
+ from kirin import ir
2
+
3
# Dialect object named "py.list"; sibling modules of this package import it
# as `dialect` to register their statements and method tables.
dialect = ir.Dialect("py.list")
@@ -0,0 +1,21 @@
1
+ from kirin import types, interp
2
+ from kirin.dialects.py.binop import Add
3
+
4
+ from .stmts import New, Append
5
+ from ._dialect import dialect
6
+
7
+
8
@dialect.register
class ListMethods(interp.MethodTable):
    """Concrete interpreter implementations for the list statements.

    Every method returns a one-element tuple holding the computed value.
    """

    @interp.impl(New)
    def new(self, interp, frame: interp.Frame, stmt: New):
        """Build a Python ``list`` from the frame values of ``stmt.values``."""
        items = frame.get_values(stmt.values)
        return (list(items),)

    @interp.impl(Add, types.PyClass(list), types.PyClass(list))
    def add(self, interp, frame: interp.Frame, stmt: Add):
        """Evaluate ``lhs + rhs`` for an ``Add`` registered on two ``list`` operands."""
        left = frame.get(stmt.lhs)
        right = frame.get(stmt.rhs)
        return (left + right,)

    @interp.impl(Append)
    def append(self, interp, frame: interp.Frame, stmt: Append):
        """Append the frame value of ``stmt.value`` to the object in ``stmt.list_``."""
        target = frame.get(stmt.list_)
        # NOTE(review): the return value of ``.append`` is forwarded as-is; for
        # a builtin list that is ``None`` -- confirm ``(None,)`` is intended.
        outcome = target.append(frame.get(stmt.value))
        return (outcome,)
@@ -0,0 +1,25 @@
1
+ import ast
2
+
3
+ from kirin import types
4
+ from kirin.lowering import Result, FromPythonAST, LoweringState
5
+
6
+ from .stmts import New
7
+ from ._dialect import dialect
8
+
9
+
10
@dialect.register
class PythonLowering(FromPythonAST):
    """Lowering rules that turn Python list literals into ``New`` statements."""

    def lower_List(self, state: LoweringState, node: ast.List) -> Result:
        """Lower an ``ast.List`` node to a ``New`` statement.

        Args:
            state: the lowering state used to visit sub-expressions and to
                append the new statement.
            node: the list literal being lowered.

        Returns:
            A ``Result`` wrapping the appended ``New`` statement.
        """
        # Elements are visited in source order; each visit result is reduced
        # with ``expect_one()``.
        elts = tuple(state.visit(each).expect_one() for each in node.elts)

        # No element type is folded here: the statement is built from the
        # values alone, so a joined type would be computed and never read.
        stmt = New(values=elts)
        state.append_stmt(stmt)
        return Result(stmt)
@@ -0,0 +1,22 @@
1
+ from kirin import ir, types
2
+ from kirin.decl import info, statement
3
+
4
+ from ._dialect import dialect
5
+
6
# Type variable named "T", used by the statement field declarations below.
T = types.TypeVar("T")
7
+
8
+
9
@statement(dialect=dialect)
class New(ir.Statement):
    # Statement named "list": a variable number of operands declared with the
    # type variable `T`, and one result declared as `types.List[T]`.
    name = "list"
    traits = frozenset({ir.FromPythonCall()})
    # Operands of the statement (a tuple of SSA values).
    values: tuple[ir.SSAValue, ...] = info.argument(T)
    # The single declared result.
    result: ir.ResultValue = info.result(types.List[T])
15
+
16
+
17
@statement(dialect=dialect)
class Append(ir.Statement):
    # Statement named "append": two operands and no declared result field.
    name = "append"
    traits = frozenset({ir.FromPythonCall()})
    # First operand, declared as `types.List[T]`.
    list_: ir.SSAValue = info.argument(types.List[T])
    # Second operand, declared with the same type variable `T`.
    value: ir.SSAValue = info.argument(T)
@@ -0,0 +1,54 @@
1
+ from kirin import types, interp
2
+ from kirin.dialects.eltype import ElType
3
+ from kirin.dialects.py.binop import Add
4
+ from kirin.dialects.py.indexing import GetItem
5
+
6
+ from ._dialect import dialect
7
+
8
+
9
@dialect.register(key="typeinfer")
class TypeInfer(interp.MethodTable):
    """Type-inference rules for list operands, registered under "typeinfer".

    Each rule reads operand types from the frame. When an operand type is a
    ``types.Generic`` its first type parameter is taken as the element type;
    otherwise the rule falls back to ``types.Any``.
    """

    @interp.impl(ElType, types.PyClass(list))
    def eltype_list(self, interp, frame: interp.Frame, stmt: ElType):
        """Return the element type of the container operand."""
        container_type = frame.get(stmt.container)
        if not isinstance(container_type, types.Generic):
            return (types.Any,)
        return (container_type.vars[0],)

    @interp.impl(Add, types.PyClass(list), types.PyClass(list))
    def add(self, interp, frame: interp.Frame, stmt: Add):
        """Return ``types.List`` of the join of both operands' element types."""
        lhs_type = frame.get(stmt.lhs)
        rhs_type = frame.get(stmt.rhs)
        left_elem = (
            lhs_type.vars[0] if isinstance(lhs_type, types.Generic) else types.Any
        )
        right_elem = (
            rhs_type.vars[0] if isinstance(rhs_type, types.Generic) else types.Any
        )
        return (types.List[left_elem.join(right_elem)],)

    @interp.impl(GetItem, types.PyClass(list), types.Int)
    def getitem_list_int(
        self, interp, frame: interp.Frame[types.TypeAttribute], stmt: GetItem
    ):
        """Return the element type for a ``GetItem`` registered with an ``Int`` index."""
        indexed_type = frame.get(stmt.obj)
        if not isinstance(indexed_type, types.Generic):
            return (types.Any,)
        return (indexed_type.vars[0],)

    @interp.impl(GetItem, types.PyClass(list), types.PyClass(slice))
    def getitem_list_slice(
        self, interp, frame: interp.Frame[types.TypeAttribute], stmt: GetItem
    ):
        """Return ``types.List`` of the element type for a ``slice`` index."""
        indexed_type = frame.get(stmt.obj)
        if not isinstance(indexed_type, types.Generic):
            # Non-generic operand: the whole result is reported as ``Any``.
            return (types.Any,)
        return (types.List[indexed_type.vars[0]],)
@@ -0,0 +1,76 @@
1
+ """The range dialect for Python.
2
+
3
+ This dialect models the builtin `range()` function in Python.
4
+
5
+ The dialect includes:
6
+ - The `Range` statement class.
7
+ - The lowering pass for the `range()` function.
8
+
9
+ This dialect does not include a concrete implementation or type inference
10
+ for the `range()` function. One needs to use another dialect for the concrete
11
+ implementation and type inference, e.g., `ilist` dialect.
12
+ """
13
+
14
+ import ast
15
+ from dataclasses import dataclass
16
+
17
+ from kirin import ir, types, interp, lowering, exceptions
18
+ from kirin.decl import info, statement
19
+ from kirin.dialects import eltype
20
+
21
# Dialect object named "py.range"; the statement, lowering and type-inference
# classes defined in this module are attached to it.
dialect = ir.Dialect("py.range")
22
+
23
+
24
@dataclass(frozen=True)
class RangeLowering(ir.FromPythonCall["Range"]):
    """Call-lowering trait for ``Range`` that delegates to ``_lower_range``."""

    def lower(
        self, stmt: type["Range"], state: lowering.LoweringState, node: ast.Call
    ) -> lowering.Result:
        """Lower the call ``node``; the ``stmt`` argument is not used."""
        lowered = _lower_range(state, node)
        return lowered
31
+
32
+
33
@statement(dialect=dialect)
class Range(ir.Statement):
    # Statement named "range": three operands declared as `types.Int` and one
    # result declared as `types.PyClass(range)`.
    name = "range"
    # `RangeLowering` routes call lowering through `_lower_range`.
    traits = frozenset({ir.Pure(), RangeLowering()})
    # Operands, in positional order: start, stop, step.
    start: ir.SSAValue = info.argument(types.Int)
    stop: ir.SSAValue = info.argument(types.Int)
    step: ir.SSAValue = info.argument(types.Int)
    # The single declared result.
    result: ir.ResultValue = info.result(types.PyClass(range))
41
+
42
+
43
@dialect.register
class Lowering(lowering.FromPythonAST):
    """AST lowering rules registered on this module's dialect."""

    def lower_Call_range(
        self, state: lowering.LoweringState, node: ast.Call
    ) -> lowering.Result:
        """Lower the call ``node`` by delegating to ``_lower_range``."""
        lowered = _lower_range(state, node)
        return lowered
50
+
51
+
52
@dialect.register(key="typeinfer")
class TypeInfer(interp.MethodTable):
    """Type-inference rules for ``range`` operands, registered under "typeinfer"."""

    @interp.impl(eltype.ElType, types.PyClass(range))
    def eltype_range(self, interp_, frame: interp.Frame, stmt: eltype.ElType):
        """Report ``types.Int`` as the element type; no input is inspected."""
        element_type = types.Int
        return (element_type,)
58
+
59
+
60
def _lower_range(state: lowering.LoweringState, node: ast.Call) -> lowering.Result:
    """Lower a ``range(...)`` call with 1-3 positional arguments to ``Range``.

    Missing arguments are filled with constants: ``0`` for the start and
    ``1`` for the step. The three operands are always visited in the order
    start, stop, step.

    Raises:
        exceptions.DialectLoweringError: if the call does not have between
            one and three positional arguments.
    """
    argc = len(node.args)
    if argc not in (1, 2, 3):
        raise exceptions.DialectLoweringError("range() takes 1-3 arguments")

    # Choose the expression for each operand first, then visit them in order.
    if argc == 1:
        start_expr, stop_expr, step_expr = (
            ast.Constant(0),
            node.args[0],
            ast.Constant(1),
        )
    elif argc == 2:
        start_expr, stop_expr, step_expr = (
            node.args[0],
            node.args[1],
            ast.Constant(1),
        )
    else:
        start_expr, stop_expr, step_expr = node.args

    start, stop, step = (
        state.visit(expr).expect_one() for expr in (start_expr, stop_expr, step_expr)
    )
    return lowering.Result(state.append_stmt(Range(start, stop, step)))
@@ -0,0 +1,120 @@
1
+ """The slice dialect for Python.
2
+
3
+ This dialect provides a `Slice` statement that represents a slice object in Python:
4
+
5
+ - The `Slice` statement class.
6
+ - The lowering pass for the `slice` call.
7
+ - The concrete implementation of the `slice` call.
8
+ - The type inference implementation of the `slice` call.
9
+ """
10
+
11
+ import ast
12
+ from dataclasses import dataclass
13
+
14
+ from kirin import ir, types, interp, lowering, exceptions
15
+ from kirin.decl import info, statement
16
+ from kirin.dialects.py.constant import Constant
17
+
18
# Dialect object named "py.slice"; the statement, interpreter and lowering
# classes defined in this module are attached to it.
dialect = ir.Dialect("py.slice")
19
+
20
+
21
@dataclass(frozen=True)
class SliceLowering(ir.FromPythonCall["Slice"]):
    """Call-lowering trait for ``Slice`` that delegates to ``_lower_slice``."""

    def lower(
        self, stmt: type["Slice"], state: lowering.LoweringState, node: ast.Call
    ) -> lowering.Result:
        """Lower the call ``node``; the ``stmt`` argument is not used."""
        lowered = _lower_slice(state, node)
        return lowered
28
+
29
+
30
# Type variable named "T", used by the `Slice` field declarations below.
T = types.TypeVar("T")
31
+
32
+
33
@statement(dialect=dialect, init=False)
class Slice(ir.Statement):
    # Statement named "slice" with three operands (start, stop, step), each
    # declared as `T | NoneType`, and one result declared as `types.Slice[T]`.
    # `init=False` is passed to `@statement`; the constructor below is
    # hand-written so it can pick the result type from the operand types.
    name = "slice"
    traits = frozenset({ir.Pure(), SliceLowering()})
    start: ir.SSAValue = info.argument(T | types.NoneType)
    stop: ir.SSAValue = info.argument(T | types.NoneType)
    step: ir.SSAValue = info.argument(T | types.NoneType)
    result: ir.ResultValue = info.result(types.Slice[T])

    def __init__(
        self, start: ir.SSAValue, stop: ir.SSAValue, step: ir.SSAValue
    ) -> None:
        """Create the statement and choose its result type.

        The result type is ``types.Slice[start.type]`` unless the start type
        is a subset of ``NoneType``, in which case ``types.Slice[stop.type]``
        is used. It is ``types.Bottom`` when both are subsets of ``NoneType``
        or when either type is not a ``types.TypeAttribute``. The type of
        ``step`` is not consulted.
        """
        stop_type = stop.type
        start_type = start.type
        both_are_types = isinstance(stop_type, types.TypeAttribute) and isinstance(
            start_type, types.TypeAttribute
        )

        if both_are_types:
            if not start_type.is_subseteq(types.NoneType):
                result_type = types.Slice[start_type]
            elif not stop_type.is_subseteq(types.NoneType):
                result_type = types.Slice[stop_type]
            else:
                # NOTE(review): start and stop both None-typed (e.g. only a
                # step is given) ends up as Bottom -- confirm this is intended.
                result_type = types.Bottom
        else:
            result_type = types.Bottom

        super().__init__(
            args=(start, stop, step),
            result_types=[result_type],
            args_slice={"start": 0, "stop": 1, "step": 2},
        )
63
+
64
+
65
@dialect.register
class Concrete(interp.MethodTable):
    """Concrete interpreter implementation for ``Slice``."""

    @interp.impl(Slice)
    def _slice(self, interp, frame: interp.Frame, stmt: Slice):
        """Build a builtin ``slice`` from the statement's three operand values."""
        start, stop, step = frame.get_values(stmt.args)
        # ``slice(stop)`` is ``slice(None, stop, None)`` and ``slice(start, stop)``
        # is ``slice(start, stop, None)``, so the three-argument form yields the
        # same object for every combination of ``None`` operands.
        return (slice(start, stop, step),)
77
+
78
+
79
@dialect.register
class Lowering(lowering.FromPythonAST):
    """AST lowering rules registered on this module's dialect."""

    def lower_Slice(
        self, state: lowering.LoweringState, node: ast.Slice
    ) -> lowering.Result:
        """Lower an ``ast.Slice`` node to a ``Slice`` statement.

        An absent bound is replaced by the result of a freshly appended
        ``Constant(None)`` statement. Bounds are processed in the order
        lower, upper, step.
        """

        def operand_for(expr: ast.expr | None) -> ir.SSAValue:
            if expr is None:
                return state.append_stmt(Constant(None)).result
            return state.visit(expr).expect_one()

        start_value = operand_for(node.lower)
        stop_value = operand_for(node.upper)
        step_value = operand_for(node.step)
        slice_stmt = Slice(start=start_value, stop=stop_value, step=step_value)
        return lowering.Result(state.append_stmt(slice_stmt))

    def lower_Call_slice(
        self, state: lowering.LoweringState, node: ast.Call
    ) -> lowering.Result:
        """Lower the call ``node`` by delegating to ``_lower_slice``."""
        lowered = _lower_slice(state, node)
        return lowered
102
+
103
+
104
def _lower_slice(state: lowering.LoweringState, node: ast.Call) -> lowering.Result:
    """Lower a ``slice(...)`` call with 1-3 positional arguments to ``Slice``.

    Missing arguments are filled by visiting an ``ast.Constant(None)`` node.
    The three operands are always visited in the order start, stop, step.

    Raises:
        exceptions.DialectLoweringError: if the call does not have between
            one and three positional arguments.
    """
    argc = len(node.args)
    if argc not in (1, 2, 3):
        raise exceptions.DialectLoweringError("slice() takes 1-3 arguments")

    # Choose the expression for each operand first, then visit them in order.
    if argc == 1:
        start_expr, stop_expr, step_expr = (
            ast.Constant(None),
            node.args[0],
            ast.Constant(None),
        )
    elif argc == 2:
        start_expr, stop_expr, step_expr = (
            node.args[0],
            node.args[1],
            ast.Constant(None),
        )
    else:
        start_expr, stop_expr, step_expr = node.args

    start, stop, step = (
        state.visit(expr).expect_one() for expr in (start_expr, stop_expr, step_expr)
    )
    return lowering.Result(state.append_stmt(Slice(start, stop, step)))
@@ -0,0 +1,109 @@
1
+ """The tuple dialect for Python.
2
+
3
+ This dialect provides a way to work with Python tuples in the IR, including:
4
+
5
+ - The `New` statement class.
6
+ - The lowering pass for the tuple statement.
7
+ - The concrete implementation of the tuple statement.
8
+ - The type inference implementation of the tuple addition with `py.binop.Add`.
9
+ - The constant propagation implementation of the tuple statement.
10
+ - The Julia emitter for the tuple statement.
11
+
12
+ This dialect maps `ast.Tuple` nodes to the `New` statement.
13
+ """
14
+
15
+ import ast
16
+
17
+ from kirin import ir, types, interp, lowering
18
+ from kirin.decl import info, statement
19
+ from kirin.analysis import const
20
+ from kirin.emit.julia import EmitJulia, EmitStrFrame
21
+ from kirin.dialects.eltype import ElType
22
+ from kirin.dialects.py.binop import Add
23
+
24
# Dialect object named "py.tuple"; the statement and method-table classes
# defined in this module are attached to it.
dialect = ir.Dialect("py.tuple")
25
+
26
+
27
@statement(dialect=dialect)
class New(ir.Statement):
    # Tuple-construction statement: its operands are the element values and
    # its single result is typed from the operand types in `__init__`.
    traits = frozenset({ir.Pure(), ir.FromPythonCall()})
    result: ir.ResultValue = info.result()

    def __init__(self, values: tuple[ir.SSAValue, ...]) -> None:
        """Create the statement with ``values`` as its operands.

        The result type is ``types.Generic(tuple, ...)`` parameterized by the
        type of each operand, in order.
        """
        element_types = [value.type for value in values]
        tuple_type = types.Generic(tuple, *element_types)
        super().__init__(args=values, result_types=[tuple_type])
40
+
41
+
42
@dialect.register
class Concrete(interp.MethodTable):
    """Concrete interpreter implementation for the tuple ``New`` statement."""

    @interp.impl(New)
    def new(self, interp: interp.Interpreter, frame: interp.Frame, stmt: New):
        """Return the operand values from the frame as the single result."""
        # Whatever ``get_values`` returns is used directly as the tuple value
        # (presumably a tuple -- confirm against ``Frame.get_values``).
        operand_values = frame.get_values(stmt.args)
        return (operand_values,)
48
+
49
+
50
@dialect.register(key="typeinfer")
class TypeInfer(interp.MethodTable):
    """Type-inference rules for tuple operands, registered under "typeinfer"."""

    @interp.impl(ElType, types.PyClass(tuple))
    def eltype_tuple(self, interp, frame: interp.Frame, stmt: ElType):
        """Return the join of all type parameters of the container type.

        Falls back to ``types.Any`` when the container type is not a
        ``types.Generic``.
        """
        container_type = frame.get(stmt.container)
        if not isinstance(container_type, types.Generic):
            return (types.Any,)

        # NOTE(review): indexing ``vars[0]`` assumes at least one type
        # parameter -- confirm an empty ``vars`` cannot reach this rule.
        joined = container_type.vars[0]
        for param in container_type.vars[1:]:
            joined = joined.join(param)
        return (joined,)

    @interp.impl(Add, types.PyClass(tuple), types.PyClass(tuple))
    def add(self, interp, frame: interp.Frame[types.TypeAttribute], stmt):
        """Return the tuple type whose parameters are ``lhs`` followed by ``rhs``."""
        lhs = frame.get(stmt.lhs)
        rhs = frame.get(stmt.rhs)
        both_generic = isinstance(lhs, types.Generic) and isinstance(
            rhs, types.Generic
        )
        if not both_generic:
            # no type param, so unknown
            return (types.PyClass(tuple),)
        combined = lhs.vars + rhs.vars
        return (types.Generic(tuple, *combined),)
72
+
73
+
74
@dialect.register(key="constprop")
class ConstPropTable(interp.MethodTable):
    """Constant-propagation rule for the tuple ``New`` statement."""

    @interp.impl(New)
    def new_tuple(
        self,
        _: const.Propagate,
        frame: const.Frame,
        stmt: New,
    ) -> interp.StatementResult[const.Result]:
        """Wrap the operands' frame values in a ``const.PartialTuple``."""
        elements = tuple(frame.get_values(stmt.args))
        return (const.PartialTuple(elements),)
85
+
86
+
87
@dialect.register
class Lowering(lowering.FromPythonAST):
    """AST lowering rules registered on this module's dialect."""

    def lower_Tuple(
        self, state: lowering.LoweringState, node: ast.Tuple
    ) -> lowering.Result:
        """Lower an ``ast.Tuple`` node to a ``New`` statement.

        Elements are visited in source order and each visit result is
        reduced with ``expect_one()``.
        """
        elements = tuple(state.visit(elem).expect_one() for elem in node.elts)
        appended = state.append_stmt(stmt=New(elements))
        return lowering.Result(appended)
98
+
99
+
100
@dialect.register(key="emit.julia")
class JuliaTable(interp.MethodTable):
    """Julia emission rule for the tuple ``New`` statement."""

    @interp.impl(New)
    def emit_NewTuple(self, emit: EmitJulia, frame: EmitStrFrame, stmt: New):
        """Emit an assignment of a parenthesized, comma-separated operand list."""
        operand_text = ", ".join(frame.get_values(stmt.args))
        rendered = "(" + operand_text + ")"
        return (emit.write_assign(frame, stmt.result, rendered),)
@@ -0,0 +1,24 @@
1
+ """The unary dialect for Python.
2
+
3
+ This module contains the dialect for unary semantics in Python, including:
4
+
5
+ - The `UnaryOp` base class for unary operations.
6
+ - The `UAdd`, `USub`, `Not`, and `Invert` statement classes.
7
+ - The lowering pass for unary operations.
8
+ - The concrete implementation of unary operations.
9
+ - The type inference implementation of unary operations.
10
+ - The constant propagation implementation of unary operations.
11
+ - The Julia emitter for unary operations.
12
+
13
+ This dialect maps `ast.UnaryOp` nodes to the `UAdd`, `USub`, `Not`, and `Invert` statements.
14
+ """
15
+
16
+ from . import (
17
+ julia as julia,
18
+ interp as interp,
19
+ lowering as lowering,
20
+ constprop as constprop,
21
+ typeinfer as typeinfer,
22
+ )
23
+ from .stmts import * # noqa: F403
24
+ from ._dialect import dialect as dialect
@@ -0,0 +1,3 @@
1
+ from kirin import ir
2
+
3
# Dialect object named "py.unary"; sibling modules of this package import it
# as `dialect` to register their statements and method tables.
dialect = ir.Dialect("py.unary")
@@ -0,0 +1,20 @@
1
+ from kirin import interp
2
+ from kirin.analysis import const
3
+
4
+ from . import stmts
5
+ from ._dialect import dialect
6
+
7
+
8
@dialect.register(key="constprop")
class ConstProp(interp.MethodTable):
    """Constant-propagation rule for the ``Not`` statement."""

    @interp.impl(stmts.Not)
    def not_(
        self, _: const.Propagate, frame: const.Frame, stmt: stmts.Not
    ) -> interp.StatementResult[const.Result]:
        """Fold ``not`` when the operand is a ``const.Value`` or ``const.PartialTuple``.

        The result is ``const.Value(not hint.data)`` for those two kinds of
        operand and ``const.Unknown()`` for anything else.
        """
        hint = frame.get(stmt.value)
        if not isinstance(hint, (const.PartialTuple, const.Value)):
            return (const.Unknown(),)
        return (const.Value(not hint.data),)
@@ -0,0 +1,24 @@
1
+ from kirin import interp
2
+
3
+ from . import stmts
4
+ from ._dialect import dialect
5
+
6
+
7
@dialect.register
class Concrete(interp.MethodTable):
    """Concrete interpreter implementations for the unary statements.

    Each method applies the matching Python operator to the operand's frame
    value and returns the outcome in a one-element tuple.
    """

    @interp.impl(stmts.UAdd)
    def uadd(self, interp, frame: interp.Frame, stmt: stmts.UAdd):
        """Apply unary ``+``."""
        operand = frame.get(stmt.value)
        return (+operand,)

    @interp.impl(stmts.USub)
    def usub(self, interp, frame: interp.Frame, stmt: stmts.USub):
        """Apply unary ``-``."""
        operand = frame.get(stmt.value)
        return (-operand,)

    @interp.impl(stmts.Not)
    def not_(self, interp, frame: interp.Frame, stmt: stmts.Not):
        """Apply logical ``not``."""
        operand = frame.get(stmt.value)
        return (not operand,)

    @interp.impl(stmts.Invert)
    def invert(self, interp, frame: interp.Frame, stmt: stmts.Invert):
        """Apply bitwise ``~``."""
        operand = frame.get(stmt.value)
        return (~operand,)
@@ -0,0 +1,21 @@
1
+ from kirin import interp
2
+ from kirin.emit.julia import EmitJulia, EmitStrFrame
3
+
4
+ from . import stmts
5
+ from ._dialect import dialect
6
+
7
+
8
@dialect.register(key="emit.julia")
class JuliaTable(interp.MethodTable):
    """Julia emission rules for ``Not``, ``USub`` and ``UAdd``.

    Each rule prefixes the operand's frame value with the operator character
    and passes the text to ``emit.write_assign``.
    """

    @interp.impl(stmts.Not)
    def emit_Not(self, emit: EmitJulia, frame: EmitStrFrame, stmt: stmts.Not):
        """Emit ``!<operand>``."""
        expression = f"!{frame.get(stmt.value)}"
        return (emit.write_assign(frame, stmt.result, expression),)

    @interp.impl(stmts.USub)
    def emit_USub(self, emit: EmitJulia, frame: EmitStrFrame, stmt: stmts.USub):
        """Emit ``-<operand>``."""
        expression = f"-{frame.get(stmt.value)}"
        return (emit.write_assign(frame, stmt.result, expression),)

    @interp.impl(stmts.UAdd)
    def emit_UAdd(self, emit: EmitJulia, frame: EmitStrFrame, stmt: stmts.UAdd):
        """Emit ``+<operand>``."""
        expression = f"+{frame.get(stmt.value)}"
        return (emit.write_assign(frame, stmt.result, expression),)
@@ -0,0 +1,22 @@
1
+ import ast
2
+
3
+ from kirin import lowering, exceptions
4
+
5
+ from . import stmts
6
+ from ._dialect import dialect
7
+
8
+
9
@dialect.register
class Lowering(lowering.FromPythonAST):
    """AST lowering rule for unary operators."""

    def lower_UnaryOp(
        self, state: lowering.LoweringState, node: ast.UnaryOp
    ) -> lowering.Result:
        """Lower an ``ast.UnaryOp`` node to the statement named after its operator.

        The statement class is looked up in ``stmts`` by the class name of
        ``node.op``.

        Raises:
            exceptions.DialectLoweringError: if no such attribute is found
                in ``stmts``.
        """
        operator_name = node.op.__class__.__name__
        stmt_cls = getattr(stmts, operator_name, None)
        if not stmt_cls:
            raise exceptions.DialectLoweringError(
                f"unsupported unary operator {node.op}"
            )

        operand = state.visit(node.operand).expect_one()
        return lowering.Result(state.append_stmt(stmt_cls(operand)))
@@ -0,0 +1,33 @@
1
+ from kirin import ir, types
2
+ from kirin.decl import info, statement
3
+
4
+ from ._dialect import dialect
5
+
6
# Type variable named "T", used by the `UnaryOp` field declarations below.
T = types.TypeVar("T")
7
+
8
+
9
@statement
class UnaryOp(ir.Statement):
    # Common base of the unary statements defined below: one operand and one
    # result that share the type variable `T`. It uses a bare `@statement`
    # (no `dialect=`) and sets no `name`; each subclass supplies both.
    traits = frozenset({ir.Pure(), ir.FromPythonCall()})
    # The single operand. NOTE(review): `print=False` presumably affects how
    # the operand is pretty-printed -- confirm against `info.argument`.
    value: ir.SSAValue = info.argument(T, print=False)
    # The single result, declared with the same type variable as the operand.
    result: ir.ResultValue = info.result(T)
14
+
15
+
16
@statement(dialect=dialect)
class UAdd(UnaryOp):
    # `UnaryOp` subclass attached to `dialect` under the name "uadd".
    name = "uadd"
19
+
20
+
21
@statement(dialect=dialect)
class USub(UnaryOp):
    # `UnaryOp` subclass attached to `dialect` under the name "usub".
    name = "usub"
24
+
25
+
26
@statement(dialect=dialect)
class Not(UnaryOp):
    # `UnaryOp` subclass attached to `dialect` under the name "not".
    name = "not"
29
+
30
+
31
@statement(dialect=dialect)
class Invert(UnaryOp):
    # `UnaryOp` subclass attached to `dialect` under the name "invert".
    name = "invert"
@@ -0,0 +1,23 @@
1
+ from kirin import types, interp
2
+
3
+ from . import stmts
4
+ from ._dialect import dialect
5
+
6
+
7
@dialect.register(key="typeinfer")
class TypeInfer(interp.MethodTable):
    """Type-inference rules for the unary statements, registered under "typeinfer"."""

    @interp.impl(stmts.UAdd)
    @interp.impl(stmts.USub)
    def uadd(
        self, interp, frame: interp.Frame[types.TypeAttribute], stmt: stmts.UnaryOp
    ):
        """Return the operand's type unchanged (registered for ``UAdd`` and ``USub``)."""
        operand_type = frame.get(stmt.value)
        return (operand_type,)

    @interp.impl(stmts.Not)
    def not_(self, interp, frame, stmt: stmts.Not):
        """Return ``types.Bool``; the operand type is not inspected."""
        result_type = types.Bool
        return (result_type,)

    @interp.impl(stmts.Invert, types.Int)
    def invert(self, interp, frame, stmt):
        """Return ``types.Int`` for an ``Invert`` registered with an ``Int`` operand."""
        result_type = types.Int
        return (result_type,)
+ return (types.Int,)