kirin-toolchain 0.13.0__py3-none-any.whl

Sign up to get free protection for your applications and to get access to all the features.
Files changed (225) hide show
  1. kirin/__init__.py +7 -0
  2. kirin/analysis/__init__.py +24 -0
  3. kirin/analysis/callgraph.py +61 -0
  4. kirin/analysis/cfg.py +112 -0
  5. kirin/analysis/const/__init__.py +20 -0
  6. kirin/analysis/const/_visitor.py +2 -0
  7. kirin/analysis/const/_visitor.pyi +8 -0
  8. kirin/analysis/const/lattice.py +219 -0
  9. kirin/analysis/const/prop.py +116 -0
  10. kirin/analysis/forward.py +100 -0
  11. kirin/analysis/typeinfer/__init__.py +5 -0
  12. kirin/analysis/typeinfer/analysis.py +90 -0
  13. kirin/analysis/typeinfer/solve.py +141 -0
  14. kirin/decl/__init__.py +108 -0
  15. kirin/decl/base.py +65 -0
  16. kirin/decl/camel2snake.py +2 -0
  17. kirin/decl/emit/__init__.py +0 -0
  18. kirin/decl/emit/_create_fn.py +29 -0
  19. kirin/decl/emit/_set_new_attribute.py +22 -0
  20. kirin/decl/emit/dialect.py +8 -0
  21. kirin/decl/emit/init.py +277 -0
  22. kirin/decl/emit/name.py +10 -0
  23. kirin/decl/emit/property.py +182 -0
  24. kirin/decl/emit/repr.py +31 -0
  25. kirin/decl/emit/traits.py +13 -0
  26. kirin/decl/emit/typecheck.py +77 -0
  27. kirin/decl/emit/verify.py +51 -0
  28. kirin/decl/info.py +346 -0
  29. kirin/decl/scan_fields.py +157 -0
  30. kirin/decl/verify.py +69 -0
  31. kirin/dialects/__init__.py +14 -0
  32. kirin/dialects/_pprint_helper.py +53 -0
  33. kirin/dialects/cf/__init__.py +20 -0
  34. kirin/dialects/cf/constprop.py +51 -0
  35. kirin/dialects/cf/dialect.py +3 -0
  36. kirin/dialects/cf/emit.py +58 -0
  37. kirin/dialects/cf/interp.py +24 -0
  38. kirin/dialects/cf/stmts.py +68 -0
  39. kirin/dialects/cf/typeinfer.py +27 -0
  40. kirin/dialects/eltype.py +23 -0
  41. kirin/dialects/func/__init__.py +20 -0
  42. kirin/dialects/func/attrs.py +39 -0
  43. kirin/dialects/func/constprop.py +138 -0
  44. kirin/dialects/func/dialect.py +3 -0
  45. kirin/dialects/func/emit.py +80 -0
  46. kirin/dialects/func/interp.py +68 -0
  47. kirin/dialects/func/stmts.py +233 -0
  48. kirin/dialects/func/typeinfer.py +124 -0
  49. kirin/dialects/ilist/__init__.py +33 -0
  50. kirin/dialects/ilist/_dialect.py +3 -0
  51. kirin/dialects/ilist/_wrapper.py +51 -0
  52. kirin/dialects/ilist/interp.py +85 -0
  53. kirin/dialects/ilist/lowering.py +25 -0
  54. kirin/dialects/ilist/passes.py +32 -0
  55. kirin/dialects/ilist/rewrite/__init__.py +3 -0
  56. kirin/dialects/ilist/rewrite/const.py +45 -0
  57. kirin/dialects/ilist/rewrite/list.py +38 -0
  58. kirin/dialects/ilist/rewrite/unroll.py +131 -0
  59. kirin/dialects/ilist/runtime.py +63 -0
  60. kirin/dialects/ilist/stmts.py +102 -0
  61. kirin/dialects/ilist/typeinfer.py +120 -0
  62. kirin/dialects/lowering/__init__.py +7 -0
  63. kirin/dialects/lowering/call.py +48 -0
  64. kirin/dialects/lowering/cf.py +206 -0
  65. kirin/dialects/lowering/func.py +134 -0
  66. kirin/dialects/math/__init__.py +41 -0
  67. kirin/dialects/math/_gen.py +176 -0
  68. kirin/dialects/math/dialect.py +3 -0
  69. kirin/dialects/math/interp.py +190 -0
  70. kirin/dialects/math/stmts.py +369 -0
  71. kirin/dialects/module.py +139 -0
  72. kirin/dialects/py/__init__.py +40 -0
  73. kirin/dialects/py/assertion.py +91 -0
  74. kirin/dialects/py/assign.py +103 -0
  75. kirin/dialects/py/attr.py +59 -0
  76. kirin/dialects/py/base.py +34 -0
  77. kirin/dialects/py/binop/__init__.py +23 -0
  78. kirin/dialects/py/binop/_dialect.py +3 -0
  79. kirin/dialects/py/binop/interp.py +60 -0
  80. kirin/dialects/py/binop/julia.py +33 -0
  81. kirin/dialects/py/binop/lowering.py +22 -0
  82. kirin/dialects/py/binop/stmts.py +79 -0
  83. kirin/dialects/py/binop/typeinfer.py +108 -0
  84. kirin/dialects/py/boolop.py +84 -0
  85. kirin/dialects/py/builtin.py +78 -0
  86. kirin/dialects/py/cmp/__init__.py +16 -0
  87. kirin/dialects/py/cmp/_dialect.py +3 -0
  88. kirin/dialects/py/cmp/interp.py +48 -0
  89. kirin/dialects/py/cmp/julia.py +33 -0
  90. kirin/dialects/py/cmp/lowering.py +45 -0
  91. kirin/dialects/py/cmp/stmts.py +62 -0
  92. kirin/dialects/py/constant.py +79 -0
  93. kirin/dialects/py/indexing.py +251 -0
  94. kirin/dialects/py/iterable.py +90 -0
  95. kirin/dialects/py/len.py +57 -0
  96. kirin/dialects/py/list/__init__.py +15 -0
  97. kirin/dialects/py/list/_dialect.py +3 -0
  98. kirin/dialects/py/list/interp.py +21 -0
  99. kirin/dialects/py/list/lowering.py +25 -0
  100. kirin/dialects/py/list/stmts.py +22 -0
  101. kirin/dialects/py/list/typeinfer.py +54 -0
  102. kirin/dialects/py/range.py +76 -0
  103. kirin/dialects/py/slice.py +120 -0
  104. kirin/dialects/py/tuple.py +109 -0
  105. kirin/dialects/py/unary/__init__.py +24 -0
  106. kirin/dialects/py/unary/_dialect.py +3 -0
  107. kirin/dialects/py/unary/constprop.py +20 -0
  108. kirin/dialects/py/unary/interp.py +24 -0
  109. kirin/dialects/py/unary/julia.py +21 -0
  110. kirin/dialects/py/unary/lowering.py +22 -0
  111. kirin/dialects/py/unary/stmts.py +33 -0
  112. kirin/dialects/py/unary/typeinfer.py +23 -0
  113. kirin/dialects/py/unpack.py +90 -0
  114. kirin/dialects/scf/__init__.py +23 -0
  115. kirin/dialects/scf/_dialect.py +3 -0
  116. kirin/dialects/scf/absint.py +64 -0
  117. kirin/dialects/scf/constprop.py +140 -0
  118. kirin/dialects/scf/interp.py +35 -0
  119. kirin/dialects/scf/lowering.py +123 -0
  120. kirin/dialects/scf/stmts.py +250 -0
  121. kirin/dialects/scf/trim.py +36 -0
  122. kirin/dialects/scf/typeinfer.py +58 -0
  123. kirin/dialects/scf/unroll.py +92 -0
  124. kirin/emit/__init__.py +3 -0
  125. kirin/emit/abc.py +89 -0
  126. kirin/emit/abc.pyi +38 -0
  127. kirin/emit/exceptions.py +5 -0
  128. kirin/emit/julia.py +63 -0
  129. kirin/emit/str.py +51 -0
  130. kirin/exceptions.py +59 -0
  131. kirin/graph.py +34 -0
  132. kirin/idtable.py +57 -0
  133. kirin/interp/__init__.py +39 -0
  134. kirin/interp/abstract.py +253 -0
  135. kirin/interp/base.py +438 -0
  136. kirin/interp/concrete.py +62 -0
  137. kirin/interp/exceptions.py +26 -0
  138. kirin/interp/frame.py +151 -0
  139. kirin/interp/impl.py +197 -0
  140. kirin/interp/result.py +93 -0
  141. kirin/interp/state.py +71 -0
  142. kirin/interp/table.py +40 -0
  143. kirin/interp/value.py +73 -0
  144. kirin/ir/__init__.py +46 -0
  145. kirin/ir/attrs/__init__.py +20 -0
  146. kirin/ir/attrs/_types.py +8 -0
  147. kirin/ir/attrs/_types.pyi +13 -0
  148. kirin/ir/attrs/abc.py +46 -0
  149. kirin/ir/attrs/py.py +45 -0
  150. kirin/ir/attrs/types.py +522 -0
  151. kirin/ir/dialect.py +125 -0
  152. kirin/ir/group.py +249 -0
  153. kirin/ir/method.py +118 -0
  154. kirin/ir/nodes/__init__.py +7 -0
  155. kirin/ir/nodes/base.py +149 -0
  156. kirin/ir/nodes/block.py +458 -0
  157. kirin/ir/nodes/region.py +337 -0
  158. kirin/ir/nodes/stmt.py +713 -0
  159. kirin/ir/nodes/view.py +142 -0
  160. kirin/ir/ssa.py +204 -0
  161. kirin/ir/traits/__init__.py +36 -0
  162. kirin/ir/traits/abc.py +42 -0
  163. kirin/ir/traits/basic.py +78 -0
  164. kirin/ir/traits/callable.py +51 -0
  165. kirin/ir/traits/lowering/__init__.py +2 -0
  166. kirin/ir/traits/lowering/call.py +37 -0
  167. kirin/ir/traits/lowering/context.py +120 -0
  168. kirin/ir/traits/region/__init__.py +2 -0
  169. kirin/ir/traits/region/ssacfg.py +22 -0
  170. kirin/ir/traits/symbol.py +57 -0
  171. kirin/ir/use.py +17 -0
  172. kirin/lattice/__init__.py +13 -0
  173. kirin/lattice/abc.py +128 -0
  174. kirin/lattice/empty.py +25 -0
  175. kirin/lattice/mixin.py +51 -0
  176. kirin/lowering/__init__.py +7 -0
  177. kirin/lowering/binding.py +65 -0
  178. kirin/lowering/core.py +72 -0
  179. kirin/lowering/dialect.py +35 -0
  180. kirin/lowering/dialect.pyi +183 -0
  181. kirin/lowering/frame.py +171 -0
  182. kirin/lowering/result.py +68 -0
  183. kirin/lowering/state.py +441 -0
  184. kirin/lowering/stream.py +53 -0
  185. kirin/passes/__init__.py +3 -0
  186. kirin/passes/abc.py +44 -0
  187. kirin/passes/aggressive/__init__.py +1 -0
  188. kirin/passes/aggressive/fold.py +43 -0
  189. kirin/passes/fold.py +45 -0
  190. kirin/passes/inline.py +25 -0
  191. kirin/passes/typeinfer.py +25 -0
  192. kirin/prelude.py +197 -0
  193. kirin/print/__init__.py +15 -0
  194. kirin/print/printable.py +141 -0
  195. kirin/print/printer.py +415 -0
  196. kirin/py.typed +0 -0
  197. kirin/registry.py +105 -0
  198. kirin/registry.pyi +52 -0
  199. kirin/rewrite/__init__.py +14 -0
  200. kirin/rewrite/abc.py +43 -0
  201. kirin/rewrite/aggressive/__init__.py +1 -0
  202. kirin/rewrite/aggressive/fold.py +43 -0
  203. kirin/rewrite/alias.py +16 -0
  204. kirin/rewrite/apply_type.py +47 -0
  205. kirin/rewrite/call2invoke.py +34 -0
  206. kirin/rewrite/chain.py +39 -0
  207. kirin/rewrite/compactify.py +288 -0
  208. kirin/rewrite/cse.py +48 -0
  209. kirin/rewrite/dce.py +19 -0
  210. kirin/rewrite/fixpoint.py +34 -0
  211. kirin/rewrite/fold.py +57 -0
  212. kirin/rewrite/getfield.py +21 -0
  213. kirin/rewrite/getitem.py +37 -0
  214. kirin/rewrite/inline.py +143 -0
  215. kirin/rewrite/result.py +15 -0
  216. kirin/rewrite/walk.py +83 -0
  217. kirin/rewrite/wrap_const.py +55 -0
  218. kirin/source.py +21 -0
  219. kirin/symbol_table.py +27 -0
  220. kirin/types.py +34 -0
  221. kirin/worklist.py +30 -0
  222. kirin_toolchain-0.13.0.dist-info/METADATA +42 -0
  223. kirin_toolchain-0.13.0.dist-info/RECORD +225 -0
  224. kirin_toolchain-0.13.0.dist-info/WHEEL +4 -0
  225. kirin_toolchain-0.13.0.dist-info/licenses/LICENSE +234 -0
@@ -0,0 +1,51 @@
1
+ from kirin.interp import Successor, MethodTable, impl
2
+ from kirin.analysis import const
3
+ from kirin.dialects.cf.stmts import Branch, ConditionalBranch
4
+ from kirin.dialects.cf.dialect import dialect
5
+
6
+
7
@dialect.register(key="constprop")
class DialectConstProp(MethodTable):
    """Constant-propagation method table for the `cf` branch statements."""

    @impl(Branch)
    def branch(self, interp: const.Propagate, frame: const.Frame, stmt: Branch):
        """Queue the branch target together with the values of its arguments.

        The successor is appended to the worklist of the interpreter's
        current frame; the statement itself produces no results.
        """
        pending = interp.state.current_frame().worklist
        target = Successor(stmt.successor, *frame.get_values(stmt.arguments))
        pending.append(target)
        return ()

    @impl(ConditionalBranch)
    def conditional_branch(
        self,
        interp: const.Propagate,
        frame: const.Frame,
        stmt: ConditionalBranch,
    ):
        """Queue the successor(s) a conditional branch may jump to.

        If the condition is a `const.Value`, only the successor selected by
        its `data` is queued. Otherwise both successors are queued; while each
        one's arguments are read, the condition entry is overwritten with
        `const.Value(True)` / `const.Value(False)` respectively, and the
        original entry is put back afterwards.
        """
        # NOTE: the `frame` parameter is replaced by the interpreter's current frame
        frame = interp.state.current_frame()
        cond = frame.get(stmt.cond)

        if not isinstance(cond, const.Value):
            # then-branch: the condition is known to be true inside it
            frame.entries[stmt.cond] = const.Value(True)
            frame.worklist.append(
                Successor(stmt.then_successor, *frame.get_values(stmt.then_arguments))
            )

            # else-branch: the condition is known to be false inside it
            frame.entries[stmt.cond] = const.Value(False)
            frame.worklist.append(
                Successor(stmt.else_successor, *frame.get_values(stmt.else_arguments))
            )

            # restore what was recorded for the condition before the branch
            frame.entries[stmt.cond] = cond
            return ()

        on_false = Successor(
            stmt.else_successor, *frame.get_values(stmt.else_arguments)
        )
        on_true = Successor(
            stmt.then_successor, *frame.get_values(stmt.then_arguments)
        )
        frame.worklist.append(on_true if cond.data else on_false)
        return ()
@@ -0,0 +1,3 @@
1
+ from kirin import ir
2
+
3
+ dialect = ir.Dialect("cf")
@@ -0,0 +1,58 @@
1
+ from typing import IO, TypeVar
2
+
3
+ from kirin import emit
4
+ from kirin.interp import Successor, MethodTable, impl
5
+ from kirin.emit.julia import EmitJulia
6
+
7
+ from .stmts import Branch, ConditionalBranch
8
+ from .dialect import dialect
9
+
10
+ IO_t = TypeVar("IO_t", bound=IO)
11
+
12
+
13
@dialect.register(key="emit.julia")
class JuliaMethodTable(MethodTable):
    """Julia emission rules for the `cf` branch statements."""

    @impl(Branch)
    def emit_branch(
        self, interp: EmitJulia[IO_t], frame: emit.EmitStrFrame, stmt: Branch
    ):
        """Write an `@goto` to the successor block and queue it for emission."""
        interp.writeln(frame, f"@goto {interp.block_id[stmt.successor]};")
        frame.worklist.append(
            # block arguments are passed one per value, not as a single tuple
            Successor(stmt.successor, *frame.get_values(stmt.arguments))
        )
        return ()

    @impl(ConditionalBranch)
    def emit_cbr(
        self, interp: EmitJulia[IO_t], frame: emit.EmitStrFrame, stmt: ConditionalBranch
    ):
        """Write an `if ... else ... end` whose arms jump to the two successors.

        Each arm assigns the branch arguments to the successor's block
        arguments and then emits the `@goto`. Both successors are queued on
        the frame's worklist afterwards (then-successor first).
        """
        cond = frame.get(stmt.cond)
        interp.writeln(frame, f"if {cond}")
        self._emit_arm(interp, frame, stmt.then_successor, stmt.then_arguments)
        interp.writeln(frame, "else")
        self._emit_arm(interp, frame, stmt.else_successor, stmt.else_arguments)
        interp.writeln(frame, "end")

        frame.worklist.append(
            # block arguments are passed one per value, not as a single tuple
            Successor(stmt.then_successor, *frame.get_values(stmt.then_arguments))
        )
        frame.worklist.append(
            Successor(stmt.else_successor, *frame.get_values(stmt.else_arguments))
        )
        return ()

    def _emit_arm(self, interp, frame, successor, arguments) -> None:
        """Write one indented arm of a conditional branch.

        Args:
            interp: the Julia emitter (provides `ssa_id`, `block_id`, `writeln`).
            frame: the emit frame whose indent and values are updated.
            successor: the block this arm jumps to.
            arguments: SSA values passed to `successor`'s block arguments.
        """
        frame.indent += 1
        # read the incoming values before the block arguments are rebound
        values = frame.get_values(arguments)
        block_values = tuple(interp.ssa_id[x] for x in successor.args)
        frame.set_values(successor.args, block_values)
        for x, y in zip(block_values, values):
            interp.writeln(frame, f"{x} = {y};")
        interp.writeln(frame, f"@goto {interp.block_id[successor]};")
        frame.indent -= 1
@@ -0,0 +1,24 @@
1
+ from kirin.interp import Frame, Successor, Interpreter, MethodTable, impl
2
+ from kirin.dialects.cf.stmts import Branch, ConditionalBranch
3
+ from kirin.dialects.cf.dialect import dialect
4
+
5
+
6
@dialect.register
class CfInterpreter(MethodTable):
    """Concrete interpretation of the `cf` branch statements."""

    @impl(Branch)
    def branch(self, interp: Interpreter, frame: Frame, stmt: Branch):
        """Jump to the successor, passing the current values of the arguments."""
        block_args = frame.get_values(stmt.arguments)
        return Successor(stmt.successor, *block_args)

    @impl(ConditionalBranch)
    def conditional_branch(
        self, interp: Interpreter, frame: Frame, stmt: ConditionalBranch
    ):
        """Jump to the then- or else-successor depending on the condition's truthiness."""
        if frame.get(stmt.cond):
            target, arguments = stmt.then_successor, stmt.then_arguments
        else:
            target, arguments = stmt.else_successor, stmt.else_arguments
        return Successor(target, *frame.get_values(arguments))
@@ -0,0 +1,68 @@
1
+ from kirin import ir, types
2
+ from kirin.decl import info, statement
3
+ from kirin.print.printer import Printer
4
+ from kirin.dialects.cf.dialect import dialect
5
+
6
+
7
@statement(dialect=dialect)
class Branch(ir.Statement):
    """Unconditional branch (`br`) to `successor`, passing `arguments`."""

    name = "br"
    traits = frozenset({ir.IsTerminator()})

    arguments: tuple[ir.SSAValue, ...]
    successor: ir.Block = info.block()

    def print_impl(self, printer: Printer) -> None:
        """Print as `<name> <block id>(<arguments>)`, with the name styled as a keyword."""
        with printer.rich(style="keyword"):
            printer.print_name(self)

        target_id = printer.state.block_id[self.successor]
        printer.plain_print(" ")
        printer.plain_print(target_id)
        printer.print_seq(self.arguments, delim=", ", prefix="(", suffix=")")

    def verify(self) -> None:
        """Perform no checks."""
        return None
30
+
31
+
32
@statement(dialect=dialect)
class ConditionalBranch(ir.Statement):
    """Conditional branch (`cond_br`) on `cond` to one of two successor blocks."""

    name = "cond_br"
    traits = frozenset({ir.IsTerminator()})

    cond: ir.SSAValue = info.argument(types.Bool)
    then_arguments: tuple[ir.SSAValue, ...]
    else_arguments: tuple[ir.SSAValue, ...]

    then_successor: ir.Block = info.block()
    else_successor: ir.Block = info.block()

    def print_impl(self, printer: Printer) -> None:
        """Print as `<name> <cond> goto <then id>(<args>) else <else id>(<args>)`."""
        with printer.rich(style="keyword"):
            printer.print_name(self)

        printer.plain_print(" ")
        printer.print(self.cond)

        # both arms are printed the same way; only the keyword and target differ
        arms = (
            (" goto ", self.then_successor, self.then_arguments),
            (" else ", self.else_successor, self.else_arguments),
        )
        for keyword, block, arguments in arms:
            with printer.rich(style="keyword"):
                printer.plain_print(keyword)

            printer.plain_print(printer.state.block_id[block])
            printer.plain_print("(")
            printer.print_seq(arguments, delim=", ")
            printer.plain_print(")")

    def verify(self) -> None:
        """Perform no checks."""
        return None
@@ -0,0 +1,27 @@
1
+ from kirin.interp import Successor, MethodTable, AbstractFrame, impl
2
+ from kirin.dialects.cf.stmts import Branch, ConditionalBranch
3
+ from kirin.analysis.typeinfer import TypeInference
4
+ from kirin.dialects.cf.dialect import dialect
5
+
6
+
7
@dialect.register(key="typeinfer")
class TypeInfer(MethodTable):
    """Type-inference rules for the `cf` branch statements."""

    @impl(Branch)
    def branch(self, interp: TypeInference, frame: AbstractFrame, stmt: Branch):
        """Queue the successor with the values currently held for the arguments."""
        target = Successor(stmt.successor, *frame.get_values(stmt.arguments))
        frame.worklist.append(target)
        return ()

    @impl(ConditionalBranch)
    def conditional_branch(
        self, interp: TypeInference, frame: AbstractFrame, stmt: ConditionalBranch
    ):
        """Queue both successors; the else-successor is appended first."""
        arms = (
            (stmt.else_successor, stmt.else_arguments),
            (stmt.then_successor, stmt.then_arguments),
        )
        for block, arguments in arms:
            frame.worklist.append(Successor(block, *frame.get_values(arguments)))
        return ()
@@ -0,0 +1,23 @@
1
+ """This dialect offers a statement `eltype` for other dialects'
2
+ type inference to query/implement the element type of a value.
3
+ For example, the `ilist` dialect implements the `eltype` statement
4
+ on the `ilist.IList` type to return the element type.
5
+ """
6
+
7
+ from kirin import ir, types
8
+ from kirin.decl import info, statement
9
+
10
+ dialect = ir.Dialect("eltype")
11
+
12
+
13
@statement(dialect=dialect)
class ElType(ir.Statement):
    """Query the element type of a value.

    Other dialects use this statement in their type inference to ask for
    (or to implement) the element type of `container`.
    """

    container: ir.SSAValue = info.argument(types.Any)
    """The value whose element type is queried."""
    elem: ir.ResultValue = info.result(types.PyClass(types.TypeAttribute))
    """The element type of `container`."""
@@ -0,0 +1,20 @@
1
+ """A function dialect that is compatible with python semantics.
2
+ """
3
+
4
+ from kirin.dialects.func import (
5
+ emit as emit,
6
+ interp as interp,
7
+ constprop as constprop,
8
+ typeinfer as typeinfer,
9
+ )
10
+ from kirin.dialects.func.attrs import Signature as Signature, MethodType as MethodType
11
+ from kirin.dialects.func.stmts import (
12
+ Call as Call,
13
+ Invoke as Invoke,
14
+ Lambda as Lambda,
15
+ Return as Return,
16
+ Function as Function,
17
+ GetField as GetField,
18
+ ConstantNone as ConstantNone,
19
+ )
20
+ from kirin.dialects.func.dialect import dialect as dialect
@@ -0,0 +1,39 @@
1
+ from typing import Generic, TypeVar
2
+ from dataclasses import dataclass
3
+
4
+ from kirin import types
5
+ from kirin.ir import Method, Attribute
6
+ from kirin.print.printer import Printer
7
+ from kirin.dialects.func.dialect import dialect
8
+
9
+ TypeofMethodType = types.PyClass[Method]
10
+ MethodType = types.Generic(
11
+ Method, types.TypeVar("Params", types.Tuple), types.TypeVar("Ret")
12
+ )
13
+ TypeLatticeElem = TypeVar("TypeLatticeElem", bound="types.TypeAttribute")
14
+
15
+
16
@dialect.register
@dataclass
class Signature(Generic[TypeLatticeElem], Attribute):
    """Signature of a function body.

    It is deliberately not a type attribute: it only records the signature
    at the function's definition site, and type inference is not performed
    on it directly.

    The type of the function itself is the type of `inputs[0]`, which is
    typically a `MethodType`.
    """

    name = "Signature"
    inputs: tuple[TypeLatticeElem, ...]
    output: TypeLatticeElem  # multiple outputs must be expressed as a tuple

    def print_impl(self, printer: Printer) -> None:
        """Print as `(<inputs>) -> <output>`."""
        printer.print_seq(self.inputs, prefix="(", suffix=")", delim=", ")
        printer.plain_print(" -> ")
        printer.print(self.output)

    def __hash__(self) -> int:
        """Hash on the `(inputs, output)` pair."""
        key = (self.inputs, self.output)
        return hash(key)
@@ -0,0 +1,138 @@
1
+ from kirin import ir, types
2
+ from kirin.interp import MethodTable, ReturnValue, StatementResult, impl
3
+ from kirin.analysis import const
4
+ from kirin.dialects.func.stmts import Call, Invoke, Lambda, Return, GetField
5
+ from kirin.dialects.func.dialect import dialect
6
+
7
+
8
@dialect.register(key="constprop")
class DialectConstProp(MethodTable):
    """Constant-propagation rules for the `func` dialect statements."""

    @impl(Return)
    def return_(
        self, interp: const.Propagate, frame: const.Frame, stmt: Return
    ) -> StatementResult[const.Result]:
        """Return the propagated result of the returned value."""
        return ReturnValue(frame.get(stmt.value))

    @impl(Call)
    def call(
        self, interp: const.Propagate, frame: const.Frame, stmt: Call
    ) -> StatementResult[const.Result]:
        """Propagate through a call whose callee is statically known.

        A `const.PartialLambda` callee is run via `_call_lambda`; a
        `const.Value` wrapping an `ir.Method` is run via `interp.run_method`.
        Any other callee yields `const.Result.bottom()`. When the callee's
        frame is not flagged `frame_is_not_pure`, the call statement is
        recorded in `frame.should_be_pure`.
        """
        callee = frame.get(stmt.callee)
        if isinstance(callee, const.PartialLambda):
            call_frame, ret = self._call_lambda(
                interp,
                callee,
                interp.permute_values(
                    callee.argnames, frame.get_values(stmt.inputs), stmt.kwargs
                ),
            )
        elif isinstance(callee, const.Value) and isinstance(callee.data, ir.Method):
            mt: ir.Method = callee.data
            call_frame, ret = interp.run_method(
                mt,
                interp.permute_values(
                    mt.arg_names, frame.get_values(stmt.inputs), stmt.kwargs
                ),
            )
        else:
            # give up on dynamic method calls
            return (const.Result.bottom(),)

        if not call_frame.frame_is_not_pure:
            frame.should_be_pure.add(stmt)
        return (ret,)

    def _call_lambda(
        self,
        interp: const.Propagate,
        callee: const.PartialLambda,
        args: tuple[const.Result, ...],
    ):
        """Wrap a partial lambda in a temporary `ir.Method` and run it.

        Returns whatever `interp.run_method` returns for that method.
        """
        # NOTE: we still use PartialLambda because
        # we want to guarantee what we receive here in captured
        # values are all lattice elements and not just obtain via
        # Const(Method(...)) which is Any.
        if (trait := callee.code.get_trait(ir.SymbolOpInterface)) is not None:
            name = trait.get_sym_name(callee.code).data
        else:
            name = "lambda"

        mt = ir.Method(
            mod=None,
            py_func=None,
            sym_name=name,
            arg_names=callee.argnames,
            dialects=interp.dialects,
            code=callee.code,
            fields=callee.captured,
        )
        return interp.run_method(mt, args)

    @impl(Invoke)
    def invoke(
        self,
        interp: const.Propagate,
        frame: const.Frame,
        stmt: Invoke,
    ) -> StatementResult[const.Result]:
        """Propagate through a static invocation of `stmt.callee`.

        The statement is recorded in `frame.should_be_pure` when the callee's
        frame is not flagged `frame_is_not_pure`.
        """
        invoke_frame, ret = interp.run_method(
            stmt.callee,
            interp.permute_values(
                stmt.callee.arg_names, frame.get_values(stmt.inputs), stmt.kwargs
            ),
        )
        if not invoke_frame.frame_is_not_pure:
            frame.should_be_pure.add(stmt)
        return (ret,)

    @impl(Lambda)
    def lambda_(
        self, interp: const.Propagate, frame: const.Frame, stmt: Lambda
    ) -> StatementResult[const.Result]:
        """Build the constant (or partial) representation of a closure.

        When the body has at least one block and every captured value is a
        `const.Value`, the result is a `const.Value` wrapping an `ir.Method`;
        otherwise it is a `const.PartialLambda` carrying the captured results.
        """
        captured = frame.get_values(stmt.captured)
        blocks = stmt.body.blocks
        # Only index the entry block once we know it exists, so that the
        # emptiness check below is meaningful instead of raising IndexError.
        if blocks:
            arg_names = [
                arg.name or str(idx) for idx, arg in enumerate(blocks[0].args)
            ]
        else:
            arg_names = []

        if blocks and types.is_tuple_of(captured, const.Value):
            return (
                const.Value(
                    ir.Method(
                        mod=None,
                        py_func=None,
                        sym_name=stmt.sym_name,
                        arg_names=arg_names,
                        dialects=interp.dialects,
                        code=stmt,
                        fields=tuple(each.data for each in captured),
                    )
                ),
            )

        return (
            const.PartialLambda(
                arg_names,
                stmt,
                tuple(captured),
            ),
        )

    @impl(GetField)
    def getfield(
        self,
        interp: const.Propagate,
        frame: const.Frame,
        stmt: GetField,
    ) -> StatementResult[const.Result]:
        """Read captured field `stmt.field` from a known method or partial lambda.

        Yields `const.Unknown()` when the object is neither.
        """
        callee_self = frame.get(stmt.obj)
        if isinstance(callee_self, const.Value) and isinstance(
            callee_self.data, ir.Method
        ):
            mt: ir.Method = callee_self.data
            # fields of a constant method are plain data: wrap it back up
            return (const.Value(mt.fields[stmt.field]),)
        elif isinstance(callee_self, const.PartialLambda):
            # captured entries of a partial lambda are already results
            return (callee_self.captured[stmt.field],)
        return (const.Unknown(),)
@@ -0,0 +1,3 @@
1
+ from kirin import ir
2
+
3
+ dialect = ir.Dialect(name="func")
@@ -0,0 +1,80 @@
1
+ from typing import IO, TypeVar
2
+
3
+ from kirin import emit
4
+ from kirin.interp import MethodTable, InterpreterError, impl
5
+ from kirin.emit.julia import EmitJulia
6
+
7
+ from .stmts import Call, Invoke, Lambda, Return, Function, GetField, ConstantNone
8
+ from .dialect import dialect
9
+
10
+ IO_t = TypeVar("IO_t", bound=IO)
11
+
12
+
13
@dialect.register(key="emit.julia")
class JuliaMethodTable(MethodTable):
    """Julia emission rules for the `func` dialect statements."""

    @impl(Function)
    def emit_function(
        self, interp: EmitJulia[IO_t], frame: emit.EmitStrFrame, stmt: Function
    ):
        """Write a Julia `function ... end` definition for `stmt`.

        The first entry-block argument is skipped; the remaining ones become
        `name::type` parameters.
        """
        fn_args = stmt.body.blocks[0].args[1:]
        argnames = frame.get_values(fn_args)
        argtypes = tuple(interp.emit_attribute(x.type) for x in fn_args)
        args = [f"{name}::{typ}" for name, typ in zip(argnames, argtypes)]
        interp.write(f"function {stmt.sym_name}({', '.join(args)})")
        frame.indent += 1
        interp.run_ssacfg_region(frame, stmt.body)
        frame.indent -= 1
        interp.writeln(frame, "end")
        return ()

    @impl(Return)
    def emit_return(
        self, interp: EmitJulia[IO_t], frame: emit.EmitStrFrame, stmt: Return
    ):
        """Write a `return <value>` line."""
        interp.writeln(frame, f"return {frame.get(stmt.value)}")
        return ()

    @impl(ConstantNone)
    def emit_constant_none(
        self, interp: EmitJulia[IO_t], frame: emit.EmitStrFrame, stmt: ConstantNone
    ):
        """Emit the expression `nothing`."""
        return ("nothing",)

    @impl(Call)
    def emit_call(self, interp: EmitJulia[IO_t], frame: emit.EmitStrFrame, stmt: Call):
        """Emit a call expression `callee(arg, ...)`.

        Raises:
            InterpreterError: if the call carries keyword arguments.
        """
        if stmt.kwargs:
            raise InterpreterError("cannot emit kwargs for dyanmic calls")
        return (
            f"{frame.get(stmt.callee)}({', '.join(frame.get_values(stmt.inputs))})",
        )

    @impl(Invoke)
    def emit_invoke(
        self, interp: EmitJulia[IO_t], frame: emit.EmitStrFrame, stmt: Invoke
    ):
        """Emit a call to `stmt.callee` by symbol name, arguments in callee order."""
        args = interp.permute_values(
            stmt.callee.arg_names, frame.get_values(stmt.inputs), stmt.kwargs
        )
        return (f"{stmt.callee.sym_name}({', '.join(args)})",)

    @impl(Lambda)
    def emit_lambda(
        self, interp: EmitJulia[IO_t], frame: emit.EmitStrFrame, stmt: Lambda
    ):
        """Write a nested Julia function for the closure and return its name.

        The first entry-block argument stands for the closure itself: it is
        bound to `stmt.sym_name` and keys the captured values in
        `frame.captured`. The remaining block arguments are the parameters.
        """
        entry_args = stmt.body.blocks[0].args
        self_arg, params = entry_args[0], entry_args[1:]
        # Name each real parameter and bind it to its own block argument, so
        # that names and arguments stay aligned and none is dropped from the
        # emitted parameter list.
        param_names = tuple(interp.ssa_id[x] for x in params)
        frame.set_values(params, param_names)
        frame.set_values((self_arg,), (stmt.sym_name,))
        frame.captured[self_arg] = frame.get_values(stmt.captured)
        interp.writeln(frame, f"function {stmt.sym_name}({', '.join(param_names)})")
        frame.indent += 1
        interp.run_ssacfg_region(frame, stmt.body)
        frame.indent -= 1
        interp.writeln(frame, "end")
        return (stmt.sym_name,)

    @impl(GetField)
    def emit_getfield(
        self, interp: EmitJulia[IO_t], frame: emit.EmitStrFrame, stmt: GetField
    ):
        """Emit the captured value stored for `stmt.obj` at index `stmt.field`."""
        return (frame.captured[stmt.obj][stmt.field],)
@@ -0,0 +1,68 @@
1
+ from kirin.ir import Method
2
+ from kirin.interp import Frame, MethodTable, ReturnValue, impl, concrete
3
+ from kirin.dialects.func.stmts import (
4
+ Call,
5
+ Invoke,
6
+ Lambda,
7
+ Return,
8
+ GetField,
9
+ ConstantNone,
10
+ )
11
+ from kirin.dialects.func.dialect import dialect
12
+
13
+
14
@dialect.register
class Interpreter(MethodTable):
    """Concrete interpretation of the `func` dialect statements."""

    @impl(Call)
    def call(self, interp: concrete.Interpreter, frame: Frame, stmt: Call):
        """Run the method held by `stmt.callee` and return its result."""
        mt: Method = frame.get(stmt.callee)
        _, result = interp.run_method(
            mt,
            interp.permute_values(
                mt.arg_names, frame.get_values(stmt.inputs), stmt.kwargs
            ),
        )
        return (result,)

    @impl(Invoke)
    def invoke(self, interp: concrete.Interpreter, frame: Frame, stmt: Invoke):
        """Run the statically known `stmt.callee` and return its result."""
        _, result = interp.run_method(
            stmt.callee,
            interp.permute_values(
                stmt.callee.arg_names, frame.get_values(stmt.inputs), stmt.kwargs
            ),
        )
        return (result,)

    @impl(Return)
    def return_(self, interp: concrete.Interpreter, frame: Frame, stmt: Return):
        """Return the value bound to `stmt.value`."""
        return ReturnValue(frame.get(stmt.value))

    @impl(ConstantNone)
    def const_none(
        self, interp: concrete.Interpreter, frame: Frame, stmt: ConstantNone
    ):
        """Produce `None`."""
        return (None,)

    @impl(GetField)
    def getfield(self, interp: concrete.Interpreter, frame: Frame, stmt: GetField):
        """Read field `stmt.field` from the method bound to `stmt.obj`."""
        mt: Method = frame.get(stmt.obj)
        return (mt.fields[stmt.field],)

    @impl(Lambda)
    def lambda_(self, interp: concrete.Interpreter, frame: Frame, stmt: Lambda):
        """Create a `Method` for the closure, capturing the current values.

        Argument names come from the entry block's arguments; an unnamed
        argument falls back to its positional index as a string.
        """
        return (
            Method(
                mod=None,
                py_func=None,
                # the closure's own symbol name, not the statement's op name
                sym_name=stmt.sym_name,
                arg_names=[
                    arg.name or str(idx)
                    for idx, arg in enumerate(stmt.body.blocks[0].args)
                ],
                dialects=interp.dialects,
                code=stmt,
                fields=frame.get_values(stmt.captured),
            ),
        )