kirin-toolchain 0.13.0__py3-none-any.whl

Sign up to get free protection for your applications and to get access to all the features.
Files changed (225) hide show
  1. kirin/__init__.py +7 -0
  2. kirin/analysis/__init__.py +24 -0
  3. kirin/analysis/callgraph.py +61 -0
  4. kirin/analysis/cfg.py +112 -0
  5. kirin/analysis/const/__init__.py +20 -0
  6. kirin/analysis/const/_visitor.py +2 -0
  7. kirin/analysis/const/_visitor.pyi +8 -0
  8. kirin/analysis/const/lattice.py +219 -0
  9. kirin/analysis/const/prop.py +116 -0
  10. kirin/analysis/forward.py +100 -0
  11. kirin/analysis/typeinfer/__init__.py +5 -0
  12. kirin/analysis/typeinfer/analysis.py +90 -0
  13. kirin/analysis/typeinfer/solve.py +141 -0
  14. kirin/decl/__init__.py +108 -0
  15. kirin/decl/base.py +65 -0
  16. kirin/decl/camel2snake.py +2 -0
  17. kirin/decl/emit/__init__.py +0 -0
  18. kirin/decl/emit/_create_fn.py +29 -0
  19. kirin/decl/emit/_set_new_attribute.py +22 -0
  20. kirin/decl/emit/dialect.py +8 -0
  21. kirin/decl/emit/init.py +277 -0
  22. kirin/decl/emit/name.py +10 -0
  23. kirin/decl/emit/property.py +182 -0
  24. kirin/decl/emit/repr.py +31 -0
  25. kirin/decl/emit/traits.py +13 -0
  26. kirin/decl/emit/typecheck.py +77 -0
  27. kirin/decl/emit/verify.py +51 -0
  28. kirin/decl/info.py +346 -0
  29. kirin/decl/scan_fields.py +157 -0
  30. kirin/decl/verify.py +69 -0
  31. kirin/dialects/__init__.py +14 -0
  32. kirin/dialects/_pprint_helper.py +53 -0
  33. kirin/dialects/cf/__init__.py +20 -0
  34. kirin/dialects/cf/constprop.py +51 -0
  35. kirin/dialects/cf/dialect.py +3 -0
  36. kirin/dialects/cf/emit.py +58 -0
  37. kirin/dialects/cf/interp.py +24 -0
  38. kirin/dialects/cf/stmts.py +68 -0
  39. kirin/dialects/cf/typeinfer.py +27 -0
  40. kirin/dialects/eltype.py +23 -0
  41. kirin/dialects/func/__init__.py +20 -0
  42. kirin/dialects/func/attrs.py +39 -0
  43. kirin/dialects/func/constprop.py +138 -0
  44. kirin/dialects/func/dialect.py +3 -0
  45. kirin/dialects/func/emit.py +80 -0
  46. kirin/dialects/func/interp.py +68 -0
  47. kirin/dialects/func/stmts.py +233 -0
  48. kirin/dialects/func/typeinfer.py +124 -0
  49. kirin/dialects/ilist/__init__.py +33 -0
  50. kirin/dialects/ilist/_dialect.py +3 -0
  51. kirin/dialects/ilist/_wrapper.py +51 -0
  52. kirin/dialects/ilist/interp.py +85 -0
  53. kirin/dialects/ilist/lowering.py +25 -0
  54. kirin/dialects/ilist/passes.py +32 -0
  55. kirin/dialects/ilist/rewrite/__init__.py +3 -0
  56. kirin/dialects/ilist/rewrite/const.py +45 -0
  57. kirin/dialects/ilist/rewrite/list.py +38 -0
  58. kirin/dialects/ilist/rewrite/unroll.py +131 -0
  59. kirin/dialects/ilist/runtime.py +63 -0
  60. kirin/dialects/ilist/stmts.py +102 -0
  61. kirin/dialects/ilist/typeinfer.py +120 -0
  62. kirin/dialects/lowering/__init__.py +7 -0
  63. kirin/dialects/lowering/call.py +48 -0
  64. kirin/dialects/lowering/cf.py +206 -0
  65. kirin/dialects/lowering/func.py +134 -0
  66. kirin/dialects/math/__init__.py +41 -0
  67. kirin/dialects/math/_gen.py +176 -0
  68. kirin/dialects/math/dialect.py +3 -0
  69. kirin/dialects/math/interp.py +190 -0
  70. kirin/dialects/math/stmts.py +369 -0
  71. kirin/dialects/module.py +139 -0
  72. kirin/dialects/py/__init__.py +40 -0
  73. kirin/dialects/py/assertion.py +91 -0
  74. kirin/dialects/py/assign.py +103 -0
  75. kirin/dialects/py/attr.py +59 -0
  76. kirin/dialects/py/base.py +34 -0
  77. kirin/dialects/py/binop/__init__.py +23 -0
  78. kirin/dialects/py/binop/_dialect.py +3 -0
  79. kirin/dialects/py/binop/interp.py +60 -0
  80. kirin/dialects/py/binop/julia.py +33 -0
  81. kirin/dialects/py/binop/lowering.py +22 -0
  82. kirin/dialects/py/binop/stmts.py +79 -0
  83. kirin/dialects/py/binop/typeinfer.py +108 -0
  84. kirin/dialects/py/boolop.py +84 -0
  85. kirin/dialects/py/builtin.py +78 -0
  86. kirin/dialects/py/cmp/__init__.py +16 -0
  87. kirin/dialects/py/cmp/_dialect.py +3 -0
  88. kirin/dialects/py/cmp/interp.py +48 -0
  89. kirin/dialects/py/cmp/julia.py +33 -0
  90. kirin/dialects/py/cmp/lowering.py +45 -0
  91. kirin/dialects/py/cmp/stmts.py +62 -0
  92. kirin/dialects/py/constant.py +79 -0
  93. kirin/dialects/py/indexing.py +251 -0
  94. kirin/dialects/py/iterable.py +90 -0
  95. kirin/dialects/py/len.py +57 -0
  96. kirin/dialects/py/list/__init__.py +15 -0
  97. kirin/dialects/py/list/_dialect.py +3 -0
  98. kirin/dialects/py/list/interp.py +21 -0
  99. kirin/dialects/py/list/lowering.py +25 -0
  100. kirin/dialects/py/list/stmts.py +22 -0
  101. kirin/dialects/py/list/typeinfer.py +54 -0
  102. kirin/dialects/py/range.py +76 -0
  103. kirin/dialects/py/slice.py +120 -0
  104. kirin/dialects/py/tuple.py +109 -0
  105. kirin/dialects/py/unary/__init__.py +24 -0
  106. kirin/dialects/py/unary/_dialect.py +3 -0
  107. kirin/dialects/py/unary/constprop.py +20 -0
  108. kirin/dialects/py/unary/interp.py +24 -0
  109. kirin/dialects/py/unary/julia.py +21 -0
  110. kirin/dialects/py/unary/lowering.py +22 -0
  111. kirin/dialects/py/unary/stmts.py +33 -0
  112. kirin/dialects/py/unary/typeinfer.py +23 -0
  113. kirin/dialects/py/unpack.py +90 -0
  114. kirin/dialects/scf/__init__.py +23 -0
  115. kirin/dialects/scf/_dialect.py +3 -0
  116. kirin/dialects/scf/absint.py +64 -0
  117. kirin/dialects/scf/constprop.py +140 -0
  118. kirin/dialects/scf/interp.py +35 -0
  119. kirin/dialects/scf/lowering.py +123 -0
  120. kirin/dialects/scf/stmts.py +250 -0
  121. kirin/dialects/scf/trim.py +36 -0
  122. kirin/dialects/scf/typeinfer.py +58 -0
  123. kirin/dialects/scf/unroll.py +92 -0
  124. kirin/emit/__init__.py +3 -0
  125. kirin/emit/abc.py +89 -0
  126. kirin/emit/abc.pyi +38 -0
  127. kirin/emit/exceptions.py +5 -0
  128. kirin/emit/julia.py +63 -0
  129. kirin/emit/str.py +51 -0
  130. kirin/exceptions.py +59 -0
  131. kirin/graph.py +34 -0
  132. kirin/idtable.py +57 -0
  133. kirin/interp/__init__.py +39 -0
  134. kirin/interp/abstract.py +253 -0
  135. kirin/interp/base.py +438 -0
  136. kirin/interp/concrete.py +62 -0
  137. kirin/interp/exceptions.py +26 -0
  138. kirin/interp/frame.py +151 -0
  139. kirin/interp/impl.py +197 -0
  140. kirin/interp/result.py +93 -0
  141. kirin/interp/state.py +71 -0
  142. kirin/interp/table.py +40 -0
  143. kirin/interp/value.py +73 -0
  144. kirin/ir/__init__.py +46 -0
  145. kirin/ir/attrs/__init__.py +20 -0
  146. kirin/ir/attrs/_types.py +8 -0
  147. kirin/ir/attrs/_types.pyi +13 -0
  148. kirin/ir/attrs/abc.py +46 -0
  149. kirin/ir/attrs/py.py +45 -0
  150. kirin/ir/attrs/types.py +522 -0
  151. kirin/ir/dialect.py +125 -0
  152. kirin/ir/group.py +249 -0
  153. kirin/ir/method.py +118 -0
  154. kirin/ir/nodes/__init__.py +7 -0
  155. kirin/ir/nodes/base.py +149 -0
  156. kirin/ir/nodes/block.py +458 -0
  157. kirin/ir/nodes/region.py +337 -0
  158. kirin/ir/nodes/stmt.py +713 -0
  159. kirin/ir/nodes/view.py +142 -0
  160. kirin/ir/ssa.py +204 -0
  161. kirin/ir/traits/__init__.py +36 -0
  162. kirin/ir/traits/abc.py +42 -0
  163. kirin/ir/traits/basic.py +78 -0
  164. kirin/ir/traits/callable.py +51 -0
  165. kirin/ir/traits/lowering/__init__.py +2 -0
  166. kirin/ir/traits/lowering/call.py +37 -0
  167. kirin/ir/traits/lowering/context.py +120 -0
  168. kirin/ir/traits/region/__init__.py +2 -0
  169. kirin/ir/traits/region/ssacfg.py +22 -0
  170. kirin/ir/traits/symbol.py +57 -0
  171. kirin/ir/use.py +17 -0
  172. kirin/lattice/__init__.py +13 -0
  173. kirin/lattice/abc.py +128 -0
  174. kirin/lattice/empty.py +25 -0
  175. kirin/lattice/mixin.py +51 -0
  176. kirin/lowering/__init__.py +7 -0
  177. kirin/lowering/binding.py +65 -0
  178. kirin/lowering/core.py +72 -0
  179. kirin/lowering/dialect.py +35 -0
  180. kirin/lowering/dialect.pyi +183 -0
  181. kirin/lowering/frame.py +171 -0
  182. kirin/lowering/result.py +68 -0
  183. kirin/lowering/state.py +441 -0
  184. kirin/lowering/stream.py +53 -0
  185. kirin/passes/__init__.py +3 -0
  186. kirin/passes/abc.py +44 -0
  187. kirin/passes/aggressive/__init__.py +1 -0
  188. kirin/passes/aggressive/fold.py +43 -0
  189. kirin/passes/fold.py +45 -0
  190. kirin/passes/inline.py +25 -0
  191. kirin/passes/typeinfer.py +25 -0
  192. kirin/prelude.py +197 -0
  193. kirin/print/__init__.py +15 -0
  194. kirin/print/printable.py +141 -0
  195. kirin/print/printer.py +415 -0
  196. kirin/py.typed +0 -0
  197. kirin/registry.py +105 -0
  198. kirin/registry.pyi +52 -0
  199. kirin/rewrite/__init__.py +14 -0
  200. kirin/rewrite/abc.py +43 -0
  201. kirin/rewrite/aggressive/__init__.py +1 -0
  202. kirin/rewrite/aggressive/fold.py +43 -0
  203. kirin/rewrite/alias.py +16 -0
  204. kirin/rewrite/apply_type.py +47 -0
  205. kirin/rewrite/call2invoke.py +34 -0
  206. kirin/rewrite/chain.py +39 -0
  207. kirin/rewrite/compactify.py +288 -0
  208. kirin/rewrite/cse.py +48 -0
  209. kirin/rewrite/dce.py +19 -0
  210. kirin/rewrite/fixpoint.py +34 -0
  211. kirin/rewrite/fold.py +57 -0
  212. kirin/rewrite/getfield.py +21 -0
  213. kirin/rewrite/getitem.py +37 -0
  214. kirin/rewrite/inline.py +143 -0
  215. kirin/rewrite/result.py +15 -0
  216. kirin/rewrite/walk.py +83 -0
  217. kirin/rewrite/wrap_const.py +55 -0
  218. kirin/source.py +21 -0
  219. kirin/symbol_table.py +27 -0
  220. kirin/types.py +34 -0
  221. kirin/worklist.py +30 -0
  222. kirin_toolchain-0.13.0.dist-info/METADATA +42 -0
  223. kirin_toolchain-0.13.0.dist-info/RECORD +225 -0
  224. kirin_toolchain-0.13.0.dist-info/WHEEL +4 -0
  225. kirin_toolchain-0.13.0.dist-info/licenses/LICENSE +234 -0
@@ -0,0 +1,16 @@
"""The cmp dialect for Python.

This module contains the dialect for comparison semantics in Python, including:

- The `Eq`, `NotEq`, `Lt`, `LtE`, `Gt`, `GtE`, `Is`, `IsNot`, `In`, and `NotIn`
  statement classes.
- The lowering pass for comparison operations.
- The concrete implementation of comparison operations.
- The Julia emitter for comparison operations (currently `Eq`, `NotEq`, `Lt`,
  `LtE`, `Gt`, and `GtE` only).

This dialect maps `ast.Compare` nodes to the `Eq`, `NotEq`, `Lt`, `LtE`,
`Gt`, `GtE`, `Is`, `IsNot`, `In`, and `NotIn` statements. Chained comparisons
such as `a < b < c` are lowered into pairwise comparisons combined with
`py.boolop.And`.
"""

from . import julia as julia, interp as interp, lowering as lowering
from .stmts import *  # noqa: F403
from ._dialect import dialect as dialect

@@ -0,0 +1,3 @@
"""The `py.cmp` dialect object."""

from kirin import ir

# Statements, method tables and lowering rules of the cmp package are all
# registered against this single instance.
dialect = ir.Dialect("py.cmp")

@@ -0,0 +1,48 @@
from kirin import interp

from . import stmts as cmp
from ._dialect import dialect


def _operands(frame: interp.Frame, stmt: cmp.Cmp):
    """Read the concrete `lhs` and `rhs` values of `stmt` (lhs first)."""
    return frame.get(stmt.lhs), frame.get(stmt.rhs)


@dialect.register
class CmpMethod(interp.MethodTable):
    """Concrete interpreter implementations of the `py.cmp` statements.

    Every implementation applies the matching Python operator to the concrete
    operand values and returns the outcome as a one-element tuple.
    """

    @interp.impl(cmp.Eq)
    def eq(self, interp, frame: interp.Frame, stmt: cmp.Eq):
        left, right = _operands(frame, stmt)
        return (left == right,)

    @interp.impl(cmp.NotEq)
    def not_eq(self, interp, frame: interp.Frame, stmt: cmp.NotEq):
        left, right = _operands(frame, stmt)
        return (left != right,)

    @interp.impl(cmp.Lt)
    def lt(self, interp, frame: interp.Frame, stmt: cmp.Lt):
        left, right = _operands(frame, stmt)
        return (left < right,)

    @interp.impl(cmp.LtE)
    def lt_eq(self, interp, frame: interp.Frame, stmt: cmp.LtE):
        left, right = _operands(frame, stmt)
        return (left <= right,)

    @interp.impl(cmp.Gt)
    def gt(self, interp, frame: interp.Frame, stmt: cmp.Gt):
        left, right = _operands(frame, stmt)
        return (left > right,)

    @interp.impl(cmp.GtE)
    def gt_eq(self, interp, frame: interp.Frame, stmt: cmp.GtE):
        left, right = _operands(frame, stmt)
        return (left >= right,)

    @interp.impl(cmp.In)
    def in_(self, interp, frame: interp.Frame, stmt: cmp.In):
        left, right = _operands(frame, stmt)
        return (left in right,)

    @interp.impl(cmp.NotIn)
    def not_in(self, interp, frame: interp.Frame, stmt: cmp.NotIn):
        left, right = _operands(frame, stmt)
        return (left not in right,)

    @interp.impl(cmp.Is)
    def is_(self, interp, frame: interp.Frame, stmt: cmp.Is):
        left, right = _operands(frame, stmt)
        return (left is right,)

    @interp.impl(cmp.IsNot)
    def is_not(self, interp, frame: interp.Frame, stmt: cmp.IsNot):
        left, right = _operands(frame, stmt)
        return (left is not right,)

@@ -0,0 +1,33 @@
from kirin import interp
from kirin.emit.julia import EmitJulia, EmitStrFrame

from . import stmts
from ._dialect import dialect


def _emit_cmp(emit: EmitJulia, frame: EmitStrFrame, op: str, stmt: stmts.Cmp):
    """Emit `stmt` as the Julia infix comparison `lhs op rhs`."""
    return emit.emit_binaryop(frame, op, stmt.lhs, stmt.rhs, stmt.result)


@dialect.register(key="emit.julia")
class JuliaTable(interp.MethodTable):
    """Julia code emission for the ordering/equality comparisons.

    `Is`, `IsNot`, `In` and `NotIn` have no emitter registered here.
    """

    @interp.impl(stmts.Eq)
    def emit_Eq(self, emit: EmitJulia, frame: EmitStrFrame, stmt: stmts.Eq):
        return _emit_cmp(emit, frame, "==", stmt)

    @interp.impl(stmts.GtE)
    def emit_GtE(self, emit: EmitJulia, frame: EmitStrFrame, stmt: stmts.GtE):
        return _emit_cmp(emit, frame, ">=", stmt)

    @interp.impl(stmts.LtE)
    def emit_LtE(self, emit: EmitJulia, frame: EmitStrFrame, stmt: stmts.LtE):
        return _emit_cmp(emit, frame, "<=", stmt)

    @interp.impl(stmts.NotEq)
    def emit_NotEq(self, emit: EmitJulia, frame: EmitStrFrame, stmt: stmts.NotEq):
        return _emit_cmp(emit, frame, "!=", stmt)

    @interp.impl(stmts.Gt)
    def emit_Gt(self, emit: EmitJulia, frame: EmitStrFrame, stmt: stmts.Gt):
        return _emit_cmp(emit, frame, ">", stmt)

    @interp.impl(stmts.Lt)
    def emit_Lt(self, emit: EmitJulia, frame: EmitStrFrame, stmt: stmts.Lt):
        return _emit_cmp(emit, frame, "<", stmt)

@@ -0,0 +1,45 @@
import ast

from kirin.ir import SSAValue
from kirin.lowering import Result, FromPythonAST, LoweringState
from kirin.exceptions import DialectLoweringError
from kirin.dialects.py import boolop

from . import stmts as cmp
from ._dialect import dialect


@dialect.register
class PythonLowering(FromPythonAST):
    """Lower Python `ast.Compare` nodes into `py.cmp` statements."""

    def lower_Compare(self, state: LoweringState, node: ast.Compare) -> Result:
        """Lower a (possibly chained) comparison.

        Low-level comparison statements are binary and need a static number of
        arguments, so a chain such as `a < b <= c` is lowered into the pairwise
        comparisons `a < b` and `b <= c`, whose results are then combined with
        `py.boolop.And`. Every operand is visited exactly once.

        Raises:
            DialectLoweringError: if a comparison operator has no statement of
                the same name in the `py.cmp` dialect.
        """
        lhs = state.visit(node.left).expect_one()

        comparators = [
            state.visit(comparator).expect_one() for comparator in node.comparators
        ]

        cmp_results: list[SSAValue] = []
        for op, rhs in zip(node.ops, comparators):
            # Statement classes are named after the `ast` operator classes
            # (`ast.Eq` -> `cmp.Eq`, `ast.NotIn` -> `cmp.NotIn`, ...). Bind the
            # lookup to its own name so the error still reports the operator.
            op_name = type(op).__name__
            stmt_type = getattr(cmp, op_name, None)
            if stmt_type is None:
                raise DialectLoweringError(f"unsupported compare operator {op_name}")
            stmt = stmt_type(lhs=lhs, rhs=rhs)
            state.append_stmt(stmt)
            cmp_results.append(Result(stmt).expect_one())
            # The right operand of this comparison is the left of the next one.
            lhs = rhs

        if len(cmp_results) == 1:
            return Result(cmp_results)

        result = cmp_results[0]
        for partial in cmp_results[1:]:
            stmt = boolop.And(lhs=result, rhs=partial)
            state.append_stmt(stmt)
            result = stmt.result

        return Result(result)

@@ -0,0 +1,62 @@
from kirin import ir, types
from kirin.decl import info, statement

from ._dialect import dialect


@statement
class Cmp(ir.Statement):
    """Base of the binary comparison statements.

    Not attached to a dialect itself. Every comparison takes the two operands
    `lhs` and `rhs` and produces one `Bool` result.
    """

    traits = frozenset({ir.Pure(), ir.FromPythonCall()})
    lhs: ir.SSAValue = info.argument()
    rhs: ir.SSAValue = info.argument()
    result: ir.ResultValue = info.result(types.Bool)


@statement(dialect=dialect)
class Eq(Cmp):
    """`lhs == rhs`."""

    name = "eq"


@statement(dialect=dialect)
class NotEq(Cmp):
    """`lhs != rhs`."""

    name = "ne"


@statement(dialect=dialect)
class Lt(Cmp):
    """`lhs < rhs`."""

    name = "lt"


@statement(dialect=dialect)
class Gt(Cmp):
    """`lhs > rhs`."""

    name = "gt"


@statement(dialect=dialect)
class LtE(Cmp):
    """`lhs <= rhs`."""

    name = "lte"


@statement(dialect=dialect)
class GtE(Cmp):
    """`lhs >= rhs`."""

    name = "gte"


@statement(dialect=dialect)
class Is(Cmp):
    """`lhs is rhs`."""

    name = "is"


@statement(dialect=dialect)
class IsNot(Cmp):
    """`lhs is not rhs`."""

    name = "is_not"


@statement(dialect=dialect)
class In(Cmp):
    """`lhs in rhs`."""

    name = "in"


@statement(dialect=dialect)
class NotIn(Cmp):
    """`lhs not in rhs`."""

    name = "not_in"

@@ -0,0 +1,79 @@
"""Constant statement for Python dialect.

This module defines the dialect for the Python `constant` statement, providing:

- The `Constant` statement class.
- The lowering rule for `ast.Constant` nodes.
- The concrete interpreter implementation of `Constant`.
- The Julia emitter for `Constant`.

Each `ast.Constant` node is lowered to one `Constant` statement.
"""

import ast
from typing import Generic, TypeVar

from kirin import ir, types, interp, lowering, exceptions
from kirin.decl import info, statement
from kirin.print import Printer
from kirin.emit.julia import EmitJulia, EmitStrFrame

dialect = ir.Dialect("py.constant")

# Python type of the wrapped constant value (declared covariant).
T = TypeVar("T", covariant=True)


@statement(dialect=dialect)
class Constant(ir.Statement, Generic[T]):
    """A pure statement producing a fixed Python value.

    The value is stored in the `value` attribute as an `ir.PyAttr`; the
    statement's single result has that attribute's type.
    """

    name = "constant"
    traits = frozenset({ir.Pure(), ir.ConstantLike(), ir.FromPythonCall()})
    value: T = info.attribute()
    result: ir.ResultValue = info.result()

    # NOTE: we allow py.Constant take data.PyAttr too
    def __init__(self, value: T | ir.PyAttr[T]) -> None:
        attr = value if isinstance(value, ir.PyAttr) else ir.PyAttr(value)
        super().__init__(attributes={"value": attr}, result_types=(attr.type,))

    def print_impl(self, printer: Printer) -> None:
        """Print as `<name> <repr(value)> : <result type>`."""
        printer.print_name(self)
        printer.plain_print(" ")
        printer.plain_print(repr(self.value))
        with printer.rich(style="comment"):
            printer.plain_print(" : ")
            printer.print(self.result.type)

    def typecheck(self) -> None:
        """Raise `VerificationError` unless the result type is a type attribute."""
        result_type = self.result.type
        if isinstance(result_type, types.TypeAttribute):
            return
        raise exceptions.VerificationError(
            self, f"Expected result type to be PyType, got {result_type}"
        )


@dialect.register
class Lowering(lowering.FromPythonAST):
    """Lower `ast.Constant` nodes into `Constant` statements."""

    def lower_Constant(
        self, state: lowering.LoweringState, node: ast.Constant
    ) -> lowering.Result:
        appended = state.append_stmt(Constant(node.value))
        return lowering.Result(appended)


@dialect.register
class Concrete(interp.MethodTable):
    """Concrete evaluation of `Constant`: yield the stored value."""

    @interp.impl(Constant)
    def constant(self, interp, frame: interp.Frame, stmt: Constant):
        stored = stmt.value
        return (stored,)


@dialect.register(key="emit.julia")
class JuliaTable(interp.MethodTable):
    """Julia emission of `Constant` through the emitter's attribute printer."""

    @interp.impl(Constant)
    def emit_Constant(self, emit: EmitJulia, frame: EmitStrFrame, stmt: Constant):
        wrapped = ir.PyAttr(stmt.value)
        return (emit.emit_attribute(wrapped),)
@@ -0,0 +1,251 @@
"""The indexing dialect for Python.

This module defines the dialect for Python indexing syntax, providing:

- The `GetItem` statement class.
- `Subscript`, a base class for indexing statements.
- `GetItemLike`, a trait for indexing statements.
- The lowering pass for the indexing statement.
- The concrete implementation of the indexing statement.
- The constant propagation implementation (special case) of the indexing
  statement.
- The type inference implementation of the indexing statement.
- A canonical rewrite rule that turns a given getitem-like statement into
  another getitem-like statement.
"""

import ast
from abc import abstractmethod
from typing import Generic, TypeVar
from dataclasses import dataclass

from kirin import ir, types, interp, lowering, exceptions
from kirin.decl import info, statement
from kirin.analysis import const
from kirin.rewrite.abc import RewriteRule, RewriteResult
from kirin.analysis.typeinfer import TypeInference

dialect = ir.Dialect("py.indexing")

# Any statement type that a `GetItemLike` trait can describe.
GetItemLikeStmt = TypeVar("GetItemLikeStmt", bound=ir.Statement)


@dataclass(frozen=True, eq=False)
class GetItemLike(ir.StmtTrait, Generic[GetItemLikeStmt]):
    """Trait for statements that behave like Python `obj[index]`.

    Lets generic code (e.g. `RewriteGetItem`) read the operands of, and build,
    a getitem-like statement without knowing its concrete class.
    """

    @abstractmethod
    def get_object(self, stmt: GetItemLikeStmt) -> ir.SSAValue:
        """Return the SSA value being indexed."""

    @abstractmethod
    def get_index(self, stmt: GetItemLikeStmt) -> ir.SSAValue:
        """Return the SSA value used as the index."""

    @abstractmethod
    def new(
        self, stmt_type: type[GetItemLikeStmt], obj: ir.SSAValue, index: ir.SSAValue
    ) -> GetItemLikeStmt:
        """Build a `stmt_type` statement indexing `obj` with `index`."""


# Statement types shaped like `GetItem` (with `obj` and `index` operands).
PyGetItemLikeStmt = TypeVar("PyGetItemLikeStmt", bound="GetItem")


@dataclass(frozen=True, eq=False)
class PyGetItemLike(GetItemLike[PyGetItemLikeStmt]):
    """`GetItemLike` for statements storing their operands as `obj`/`index`."""

    def get_object(self, stmt: PyGetItemLikeStmt) -> ir.SSAValue:
        """Return `stmt.obj`."""
        obj = stmt.obj
        return obj

    def get_index(self, stmt: PyGetItemLikeStmt) -> ir.SSAValue:
        """Return `stmt.index`."""
        index = stmt.index
        return index

    def new(
        self, stmt_type: type[PyGetItemLikeStmt], obj: ir.SSAValue, index: ir.SSAValue
    ) -> PyGetItemLikeStmt:
        """Construct `stmt_type(obj=obj, index=index)`."""
        built = stmt_type(obj=obj, index=index)
        return built


# NOTE: in IR setindex is very different from getindex,
# taking Julia's semantics as reference here.
@statement
class Subscript(ir.Statement):
    """Common base class of the indexing statements (no dialect of its own)."""


@statement(dialect=dialect)
class GetItem(Subscript):
    """Python `obj[index]`; result type defaults to `Any`."""

    name = "getitem"
    traits = frozenset({ir.Pure(), PyGetItemLike(), ir.FromPythonCall()})
    obj: ir.SSAValue = info.argument(print=False)
    index: ir.SSAValue = info.argument(print=False)
    result: ir.ResultValue = info.result(types.Any)


@dialect.register
class Lowering(lowering.FromPythonAST):
    """Lower `ast.Subscript` loads into `GetItem` statements."""

    def lower_Subscript(
        self, state: lowering.LoweringState, node: ast.Subscript
    ) -> lowering.Result:
        # Both operands are visited before the context is checked.
        obj = state.visit(node.value).expect_one()
        index = state.visit(node.slice).expect_one()
        if not isinstance(node.ctx, ast.Load):
            raise exceptions.DialectLoweringError(
                f"unsupported subscript context {node.ctx}"
            )
        stmt = GetItem(obj=obj, index=index)
        state.append_stmt(stmt)
        return lowering.Result(stmt)


@dialect.register
class Concrete(interp.MethodTable):
    """Concrete evaluation of `GetItem` via Python's `obj[index]`."""

    @interp.impl(GetItem)
    def getindex(self, interp, frame: interp.Frame, stmt: GetItem):
        container = frame.get(stmt.obj)
        key = frame.get(stmt.index)
        return (container[key],)


@dialect.register(key="typeinfer")
class TypeInfer(interp.MethodTable):
    """Type inference for `GetItem`.

    Tuples get element-precise types when the index (or slice) is a known
    constant; strings index to `String`; everything else infers `Any`.
    """

    @interp.impl(GetItem)
    def getitem(
        self,
        interp: TypeInference,
        frame: interp.Frame[types.TypeAttribute],
        stmt: GetItem,
    ):
        obj = frame.get(stmt.obj)
        index: types.TypeAttribute = frame.get(stmt.index)
        # TODO: replace this when we can multiple dispatch
        if obj.is_subseteq(types.Tuple):
            return self.getitem_tuple(interp, stmt, obj, index)
        elif obj.is_subseteq(types.String):
            return (types.String,)
        else:
            return (types.Any,)

    def getitem_tuple(
        self,
        interp,
        stmt: GetItem,
        obj: types.TypeAttribute,
        index: types.TypeAttribute,
    ):
        """Dispatch tuple indexing on whether the index is an int or a slice."""
        if isinstance(obj, types.Generic):
            if index.is_subseteq(types.Int):
                return self.getitem_tuple_index(interp, stmt, obj, index)
            elif index.is_subseteq(types.Slice):
                return self.getitem_tuple_slice(interp, stmt, obj, index)
            else:
                return (types.Bottom,)
        elif isinstance(obj, types.PyClass):
            return (types.Any,)
        else:
            return (types.Bottom,)

    def getitem_tuple_index(
        self,
        interp: TypeInference,
        stmt: GetItem,
        obj: types.Generic,
        index: types.TypeAttribute,
    ):
        """Infer `tuple[int]`; precise when the index is a known constant."""
        index_ = interp.maybe_const(stmt.index, int)
        # Compare against None explicitly: a constant index of 0 is falsy.
        if index_ is None:
            return (self.getitem_tuple_union(obj),)

        nvars = len(obj.vars)
        if 0 <= index_ < nvars:
            return (obj.vars[index_],)
        if obj.vararg:
            if index_ >= nvars:
                # Past the fixed prefix: the element comes from the vararg tail.
                return (obj.vararg.typ,)
            # A negative index counts from the end of a tuple of unknown
            # length, so it may hit any fixed element or the vararg tail.
            return (self.getitem_tuple_union(obj),)
        if -nvars <= index_ < 0:
            return (obj.vars[index_],)
        return (types.Bottom,)  # out of bounds for a fixed-length tuple

    def getitem_tuple_slice(
        self,
        interp: TypeInference,
        stmt: GetItem,
        obj: types.Generic,
        index: types.TypeAttribute,
    ):
        """Infer `tuple[slice]`; the result is always a tuple type."""
        index_ = interp.maybe_const(stmt.index, slice)
        if index_ is None:
            return (types.Tuple[types.Vararg(self.getitem_tuple_union(obj))],)

        if index_.step == 0:
            # `t[::0]` raises ValueError at runtime.
            return (types.Bottom,)

        if not obj.vararg:
            # Python clamps out-of-range slice bounds, so slicing the known
            # element types with the same slice gives the exact result,
            # including negative bounds and negative steps.
            return (types.Tuple.where(obj.vars[index_]),)

        start, stop, step = index_.start, index_.stop, index_.step
        forward = (step is None or step > 0) and (start is None or start >= 0)
        if forward and stop is not None and 0 <= stop <= len(obj.vars):
            # The slice stays inside the fixed prefix of the tuple.
            return (types.Tuple.where(obj.vars[index_]),)
        if forward:
            # Prefix elements picked from `start` on, then any number of
            # elements from the vararg tail.
            elem = types.Union(*obj.vars[slice(start, None, step)], obj.vararg.typ)
        else:
            elem = self.getitem_tuple_union(obj)
        return (types.Tuple[types.Vararg(elem)],)

    def getitem_tuple_union(self, obj: types.Generic):
        """Union of every element type a tuple of type `obj` may contain."""
        if obj.vararg:
            return types.Union(*obj.vars, obj.vararg.typ)
        else:
            return types.Union(*obj.vars)


@dialect.register(key="constprop")
class ConstProp(interp.MethodTable):
    """Constant propagation for `GetItem` on partially-constant tuples."""

    @interp.impl(GetItem)
    def getitem(
        self,
        _: const.Propagate,
        frame: const.Frame,
        stmt: GetItem,
    ) -> interp.StatementResult[const.Result]:
        obj = frame.get(stmt.obj)
        index = frame.get(stmt.index)
        if not isinstance(index, const.Value):
            return (const.Unknown(),)

        if isinstance(obj, const.PartialTuple):
            elems = obj.data
            key = index.data
            if isinstance(key, int) and -len(elems) <= key < len(elems):
                # Negative indices are valid Python tuple indices too.
                return (elems[key],)
            elif isinstance(key, slice) and key.step != 0:
                # Slice the elements directly. Rebuilding the slice from
                # `slice.indices()` is wrong for negative steps: `[::-1]`
                # becomes `[n-1:-1:-1]`, which selects nothing.
                return (const.PartialTuple(elems[key]),)
        return (const.Unknown(),)


@dataclass(init=False)
class RewriteGetItem(RewriteRule, Generic[GetItemLikeStmt]):
    """Replace `GetItem` statements with another getitem-like statement.

    Every `GetItem` whose object type is a subtype of `obj_type` is replaced
    by a `stmt_type` statement built through that type's `GetItemLike` trait.

    Raises:
        ValueError: (from `__init__`) if `stmt_type` lacks a `GetItemLike`
            trait.
    """

    target_stmt_type: type[GetItemLikeStmt]
    obj_type: types.TypeAttribute
    getitem_like: GetItemLike[GetItemLikeStmt]

    def __init__(self, stmt_type: type[GetItemLikeStmt], obj_type: types.TypeAttribute):
        trait = stmt_type.get_trait(GetItemLike)
        if trait is None:
            raise ValueError(f"{stmt_type} does not have GetItemLike trait")

        self.obj_type = obj_type
        self.target_stmt_type = stmt_type
        self.getitem_like = trait

    def rewrite_Statement(self, node: ir.Statement) -> RewriteResult:
        if not isinstance(node, GetItem):
            return RewriteResult()

        # Already the target statement (e.g. `stmt_type` is `GetItem` itself):
        # rewriting again would report progress forever under a fixpoint.
        if type(node) is self.target_stmt_type:
            return RewriteResult()

        if not node.obj.type.is_subseteq(self.obj_type):
            return RewriteResult()

        node.replace_by(
            self.getitem_like.new(self.target_stmt_type, node.obj, node.index)
        )
        return RewriteResult(has_done_something=True)
@@ -0,0 +1,90 @@
"""Access to Python iterables.

Used when lowering Python loops into the `cf` dialect. This module provides:

- The `Iter` statement class.
- The `Next` statement class.
- The lowering pass for `iter()` and `next()` calls.
- The concrete implementation of both statements.
- Type inference for iterating over a `range`.

`iter()` and `next()` calls map to the `Iter` and `Next` statements.
"""

from ast import Call

from kirin import ir, types, interp, lowering
from kirin.decl import info, statement
from kirin.exceptions import DialectLoweringError

dialect = ir.Dialect("py.iterable")

# Python type of `iter(range(...))`, used by the type inference table below.
PyRangeIterType = types.PyClass(type(iter(range(0))))


@statement(dialect=dialect)
class Iter(ir.Statement):
    """Python `iter(value)`: produce an iterator over `value`."""

    traits = frozenset({ir.Pure()})
    value: ir.SSAValue = info.argument(types.Any)
    iter: ir.ResultValue = info.result(types.Any)


@statement(dialect=dialect)
class Next(ir.Statement):
    """Python `next(iterable, None)`: the next item, or `None` when exhausted.

    Has no `Pure` trait; evaluating it calls `next()`, which advances the
    iterator.
    """

    iter: ir.SSAValue = info.argument(types.Any)
    value: ir.ResultValue = info.result(types.Any)


@dialect.register
class Concrete(interp.MethodTable):
    """Concrete evaluation of `Iter` and `Next` using the Python builtins."""

    @interp.impl(Iter)
    def iter_(self, interp, frame: interp.Frame, stmt: Iter):
        iterable = frame.get(stmt.value)
        return (iter(iterable),)

    @interp.impl(Next)
    def next_(self, interp, frame: interp.Frame, stmt: Next):
        iterator = frame.get(stmt.iter)
        return (next(iterator, None),)


@dialect.register(key="typeinfer")
class TypeInfer(interp.MethodTable):
    """Type inference for iterating over `range` objects."""

    @interp.impl(Iter, types.PyClass(range))
    def iter_(self, interp, frame: interp.Frame, stmt: Iter):
        # `iter(range(...))` yields a range iterator.
        return (PyRangeIterType,)

    @interp.impl(Next, PyRangeIterType)
    def next_(self, interp, frame: interp.Frame, stmt: Next):
        # A range iterator produces ints.
        return (types.Int,)


@dialect.register
class Lowering(lowering.FromPythonAST):
    """Lower `iter(x)` and `next(it)` calls into `Iter` and `Next`."""

    def lower_Call_iter(
        self, state: lowering.LoweringState, node: Call
    ) -> lowering.Result:
        if len(node.args) != 1:
            raise DialectLoweringError("iter() takes exactly 1 argument")
        iterable = state.visit(node.args[0]).expect_one()
        return lowering.Result(state.append_stmt(Iter(iterable)))

    def lower_Call_next(
        self, state: lowering.LoweringState, node: Call
    ) -> lowering.Result:
        nargs = len(node.args)
        if nargs != 1:
            # Two arguments get a dedicated message: the kernel `Next` already
            # returns None instead of raising StopIteration.
            message = (
                "next() does not throw StopIteration inside kernel"
                if nargs == 2
                else "next() takes exactly 1 argument"
            )
            raise DialectLoweringError(message)
        iterator = state.visit(node.args[0]).expect_one()
        return lowering.Result(state.append_stmt(Next(iterator)))
@@ -0,0 +1,57 @@
"""The `Len` dialect.

Maps the `len()` call to the `Len` statement. This module provides:

- The `Len` statement class.
- The lowering pass for the `len()` call.
- The concrete implementation of the `len()` call.
- Constant propagation for `Len`.
"""

import ast

from kirin import ir, types, interp, lowering
from kirin.decl import info, statement
from kirin.analysis import const

dialect = ir.Dialect("py.len")


@statement(dialect=dialect)
class Len(ir.Statement):
    """Python `len(value)`, producing an `Int`."""

    name = "len"
    traits = frozenset({ir.Pure(), ir.FromPythonCall()})
    value: ir.SSAValue = info.argument(types.Any)
    result: ir.ResultValue = info.result(types.Int)


@dialect.register
class Concrete(interp.MethodTable):
    """Concrete evaluation of `Len` with the builtin `len()`."""

    @interp.impl(Len)
    def len(self, interp, frame: interp.Frame, stmt: Len):
        # Inside the body `len` is the builtin; class attributes are not in
        # a method's scope.
        operand = frame.get(stmt.value)
        return (len(operand),)


@dialect.register(key="constprop")
class ConstProp(interp.MethodTable):
    """Constant propagation for `Len`.

    A fully constant operand or a partially-constant tuple has a known length;
    anything else (or a constant without a length) yields the top result.
    """

    @interp.impl(Len)
    def len(self, interp, frame: interp.Frame, stmt: Len):
        value = frame.get(stmt.value)
        if isinstance(value, (const.Value, const.PartialTuple)):
            try:
                size = len(value.data)
            except TypeError:
                # e.g. `len(1)`: the constant has no length. Stay sound
                # instead of crashing the analysis.
                return (const.Result.top(),)
            return (const.Value(size),)
        return (const.Result.top(),)


@dialect.register
class Lowering(lowering.FromPythonAST):
    """Lower `len(x)` calls into `Len` statements."""

    def lower_Call_len(
        self, state: lowering.LoweringState, node: ast.Call
    ) -> lowering.Result:
        """Lower `len(x)`.

        Raises:
            DialectLoweringError: unless the call has exactly one positional
                argument and no keywords (the only form Python accepts).
        """
        from kirin.exceptions import DialectLoweringError

        # Without this check `len()` fails with an IndexError, and extra
        # arguments are silently dropped.
        if len(node.args) != 1 or node.keywords:
            raise DialectLoweringError("len() takes exactly 1 argument")
        return lowering.Result(
            state.append_stmt(Len(value=state.visit(node.args[0]).expect_one()))
        )