kirin-toolchain 0.13.0__py3-none-any.whl

Sign up to get free protection for your applications and to get access to all the features.
Files changed (225) hide show
  1. kirin/__init__.py +7 -0
  2. kirin/analysis/__init__.py +24 -0
  3. kirin/analysis/callgraph.py +61 -0
  4. kirin/analysis/cfg.py +112 -0
  5. kirin/analysis/const/__init__.py +20 -0
  6. kirin/analysis/const/_visitor.py +2 -0
  7. kirin/analysis/const/_visitor.pyi +8 -0
  8. kirin/analysis/const/lattice.py +219 -0
  9. kirin/analysis/const/prop.py +116 -0
  10. kirin/analysis/forward.py +100 -0
  11. kirin/analysis/typeinfer/__init__.py +5 -0
  12. kirin/analysis/typeinfer/analysis.py +90 -0
  13. kirin/analysis/typeinfer/solve.py +141 -0
  14. kirin/decl/__init__.py +108 -0
  15. kirin/decl/base.py +65 -0
  16. kirin/decl/camel2snake.py +2 -0
  17. kirin/decl/emit/__init__.py +0 -0
  18. kirin/decl/emit/_create_fn.py +29 -0
  19. kirin/decl/emit/_set_new_attribute.py +22 -0
  20. kirin/decl/emit/dialect.py +8 -0
  21. kirin/decl/emit/init.py +277 -0
  22. kirin/decl/emit/name.py +10 -0
  23. kirin/decl/emit/property.py +182 -0
  24. kirin/decl/emit/repr.py +31 -0
  25. kirin/decl/emit/traits.py +13 -0
  26. kirin/decl/emit/typecheck.py +77 -0
  27. kirin/decl/emit/verify.py +51 -0
  28. kirin/decl/info.py +346 -0
  29. kirin/decl/scan_fields.py +157 -0
  30. kirin/decl/verify.py +69 -0
  31. kirin/dialects/__init__.py +14 -0
  32. kirin/dialects/_pprint_helper.py +53 -0
  33. kirin/dialects/cf/__init__.py +20 -0
  34. kirin/dialects/cf/constprop.py +51 -0
  35. kirin/dialects/cf/dialect.py +3 -0
  36. kirin/dialects/cf/emit.py +58 -0
  37. kirin/dialects/cf/interp.py +24 -0
  38. kirin/dialects/cf/stmts.py +68 -0
  39. kirin/dialects/cf/typeinfer.py +27 -0
  40. kirin/dialects/eltype.py +23 -0
  41. kirin/dialects/func/__init__.py +20 -0
  42. kirin/dialects/func/attrs.py +39 -0
  43. kirin/dialects/func/constprop.py +138 -0
  44. kirin/dialects/func/dialect.py +3 -0
  45. kirin/dialects/func/emit.py +80 -0
  46. kirin/dialects/func/interp.py +68 -0
  47. kirin/dialects/func/stmts.py +233 -0
  48. kirin/dialects/func/typeinfer.py +124 -0
  49. kirin/dialects/ilist/__init__.py +33 -0
  50. kirin/dialects/ilist/_dialect.py +3 -0
  51. kirin/dialects/ilist/_wrapper.py +51 -0
  52. kirin/dialects/ilist/interp.py +85 -0
  53. kirin/dialects/ilist/lowering.py +25 -0
  54. kirin/dialects/ilist/passes.py +32 -0
  55. kirin/dialects/ilist/rewrite/__init__.py +3 -0
  56. kirin/dialects/ilist/rewrite/const.py +45 -0
  57. kirin/dialects/ilist/rewrite/list.py +38 -0
  58. kirin/dialects/ilist/rewrite/unroll.py +131 -0
  59. kirin/dialects/ilist/runtime.py +63 -0
  60. kirin/dialects/ilist/stmts.py +102 -0
  61. kirin/dialects/ilist/typeinfer.py +120 -0
  62. kirin/dialects/lowering/__init__.py +7 -0
  63. kirin/dialects/lowering/call.py +48 -0
  64. kirin/dialects/lowering/cf.py +206 -0
  65. kirin/dialects/lowering/func.py +134 -0
  66. kirin/dialects/math/__init__.py +41 -0
  67. kirin/dialects/math/_gen.py +176 -0
  68. kirin/dialects/math/dialect.py +3 -0
  69. kirin/dialects/math/interp.py +190 -0
  70. kirin/dialects/math/stmts.py +369 -0
  71. kirin/dialects/module.py +139 -0
  72. kirin/dialects/py/__init__.py +40 -0
  73. kirin/dialects/py/assertion.py +91 -0
  74. kirin/dialects/py/assign.py +103 -0
  75. kirin/dialects/py/attr.py +59 -0
  76. kirin/dialects/py/base.py +34 -0
  77. kirin/dialects/py/binop/__init__.py +23 -0
  78. kirin/dialects/py/binop/_dialect.py +3 -0
  79. kirin/dialects/py/binop/interp.py +60 -0
  80. kirin/dialects/py/binop/julia.py +33 -0
  81. kirin/dialects/py/binop/lowering.py +22 -0
  82. kirin/dialects/py/binop/stmts.py +79 -0
  83. kirin/dialects/py/binop/typeinfer.py +108 -0
  84. kirin/dialects/py/boolop.py +84 -0
  85. kirin/dialects/py/builtin.py +78 -0
  86. kirin/dialects/py/cmp/__init__.py +16 -0
  87. kirin/dialects/py/cmp/_dialect.py +3 -0
  88. kirin/dialects/py/cmp/interp.py +48 -0
  89. kirin/dialects/py/cmp/julia.py +33 -0
  90. kirin/dialects/py/cmp/lowering.py +45 -0
  91. kirin/dialects/py/cmp/stmts.py +62 -0
  92. kirin/dialects/py/constant.py +79 -0
  93. kirin/dialects/py/indexing.py +251 -0
  94. kirin/dialects/py/iterable.py +90 -0
  95. kirin/dialects/py/len.py +57 -0
  96. kirin/dialects/py/list/__init__.py +15 -0
  97. kirin/dialects/py/list/_dialect.py +3 -0
  98. kirin/dialects/py/list/interp.py +21 -0
  99. kirin/dialects/py/list/lowering.py +25 -0
  100. kirin/dialects/py/list/stmts.py +22 -0
  101. kirin/dialects/py/list/typeinfer.py +54 -0
  102. kirin/dialects/py/range.py +76 -0
  103. kirin/dialects/py/slice.py +120 -0
  104. kirin/dialects/py/tuple.py +109 -0
  105. kirin/dialects/py/unary/__init__.py +24 -0
  106. kirin/dialects/py/unary/_dialect.py +3 -0
  107. kirin/dialects/py/unary/constprop.py +20 -0
  108. kirin/dialects/py/unary/interp.py +24 -0
  109. kirin/dialects/py/unary/julia.py +21 -0
  110. kirin/dialects/py/unary/lowering.py +22 -0
  111. kirin/dialects/py/unary/stmts.py +33 -0
  112. kirin/dialects/py/unary/typeinfer.py +23 -0
  113. kirin/dialects/py/unpack.py +90 -0
  114. kirin/dialects/scf/__init__.py +23 -0
  115. kirin/dialects/scf/_dialect.py +3 -0
  116. kirin/dialects/scf/absint.py +64 -0
  117. kirin/dialects/scf/constprop.py +140 -0
  118. kirin/dialects/scf/interp.py +35 -0
  119. kirin/dialects/scf/lowering.py +123 -0
  120. kirin/dialects/scf/stmts.py +250 -0
  121. kirin/dialects/scf/trim.py +36 -0
  122. kirin/dialects/scf/typeinfer.py +58 -0
  123. kirin/dialects/scf/unroll.py +92 -0
  124. kirin/emit/__init__.py +3 -0
  125. kirin/emit/abc.py +89 -0
  126. kirin/emit/abc.pyi +38 -0
  127. kirin/emit/exceptions.py +5 -0
  128. kirin/emit/julia.py +63 -0
  129. kirin/emit/str.py +51 -0
  130. kirin/exceptions.py +59 -0
  131. kirin/graph.py +34 -0
  132. kirin/idtable.py +57 -0
  133. kirin/interp/__init__.py +39 -0
  134. kirin/interp/abstract.py +253 -0
  135. kirin/interp/base.py +438 -0
  136. kirin/interp/concrete.py +62 -0
  137. kirin/interp/exceptions.py +26 -0
  138. kirin/interp/frame.py +151 -0
  139. kirin/interp/impl.py +197 -0
  140. kirin/interp/result.py +93 -0
  141. kirin/interp/state.py +71 -0
  142. kirin/interp/table.py +40 -0
  143. kirin/interp/value.py +73 -0
  144. kirin/ir/__init__.py +46 -0
  145. kirin/ir/attrs/__init__.py +20 -0
  146. kirin/ir/attrs/_types.py +8 -0
  147. kirin/ir/attrs/_types.pyi +13 -0
  148. kirin/ir/attrs/abc.py +46 -0
  149. kirin/ir/attrs/py.py +45 -0
  150. kirin/ir/attrs/types.py +522 -0
  151. kirin/ir/dialect.py +125 -0
  152. kirin/ir/group.py +249 -0
  153. kirin/ir/method.py +118 -0
  154. kirin/ir/nodes/__init__.py +7 -0
  155. kirin/ir/nodes/base.py +149 -0
  156. kirin/ir/nodes/block.py +458 -0
  157. kirin/ir/nodes/region.py +337 -0
  158. kirin/ir/nodes/stmt.py +713 -0
  159. kirin/ir/nodes/view.py +142 -0
  160. kirin/ir/ssa.py +204 -0
  161. kirin/ir/traits/__init__.py +36 -0
  162. kirin/ir/traits/abc.py +42 -0
  163. kirin/ir/traits/basic.py +78 -0
  164. kirin/ir/traits/callable.py +51 -0
  165. kirin/ir/traits/lowering/__init__.py +2 -0
  166. kirin/ir/traits/lowering/call.py +37 -0
  167. kirin/ir/traits/lowering/context.py +120 -0
  168. kirin/ir/traits/region/__init__.py +2 -0
  169. kirin/ir/traits/region/ssacfg.py +22 -0
  170. kirin/ir/traits/symbol.py +57 -0
  171. kirin/ir/use.py +17 -0
  172. kirin/lattice/__init__.py +13 -0
  173. kirin/lattice/abc.py +128 -0
  174. kirin/lattice/empty.py +25 -0
  175. kirin/lattice/mixin.py +51 -0
  176. kirin/lowering/__init__.py +7 -0
  177. kirin/lowering/binding.py +65 -0
  178. kirin/lowering/core.py +72 -0
  179. kirin/lowering/dialect.py +35 -0
  180. kirin/lowering/dialect.pyi +183 -0
  181. kirin/lowering/frame.py +171 -0
  182. kirin/lowering/result.py +68 -0
  183. kirin/lowering/state.py +441 -0
  184. kirin/lowering/stream.py +53 -0
  185. kirin/passes/__init__.py +3 -0
  186. kirin/passes/abc.py +44 -0
  187. kirin/passes/aggressive/__init__.py +1 -0
  188. kirin/passes/aggressive/fold.py +43 -0
  189. kirin/passes/fold.py +45 -0
  190. kirin/passes/inline.py +25 -0
  191. kirin/passes/typeinfer.py +25 -0
  192. kirin/prelude.py +197 -0
  193. kirin/print/__init__.py +15 -0
  194. kirin/print/printable.py +141 -0
  195. kirin/print/printer.py +415 -0
  196. kirin/py.typed +0 -0
  197. kirin/registry.py +105 -0
  198. kirin/registry.pyi +52 -0
  199. kirin/rewrite/__init__.py +14 -0
  200. kirin/rewrite/abc.py +43 -0
  201. kirin/rewrite/aggressive/__init__.py +1 -0
  202. kirin/rewrite/aggressive/fold.py +43 -0
  203. kirin/rewrite/alias.py +16 -0
  204. kirin/rewrite/apply_type.py +47 -0
  205. kirin/rewrite/call2invoke.py +34 -0
  206. kirin/rewrite/chain.py +39 -0
  207. kirin/rewrite/compactify.py +288 -0
  208. kirin/rewrite/cse.py +48 -0
  209. kirin/rewrite/dce.py +19 -0
  210. kirin/rewrite/fixpoint.py +34 -0
  211. kirin/rewrite/fold.py +57 -0
  212. kirin/rewrite/getfield.py +21 -0
  213. kirin/rewrite/getitem.py +37 -0
  214. kirin/rewrite/inline.py +143 -0
  215. kirin/rewrite/result.py +15 -0
  216. kirin/rewrite/walk.py +83 -0
  217. kirin/rewrite/wrap_const.py +55 -0
  218. kirin/source.py +21 -0
  219. kirin/symbol_table.py +27 -0
  220. kirin/types.py +34 -0
  221. kirin/worklist.py +30 -0
  222. kirin_toolchain-0.13.0.dist-info/METADATA +42 -0
  223. kirin_toolchain-0.13.0.dist-info/RECORD +225 -0
  224. kirin_toolchain-0.13.0.dist-info/WHEEL +4 -0
  225. kirin_toolchain-0.13.0.dist-info/licenses/LICENSE +234 -0
kirin/__init__.py ADDED
@@ -0,0 +1,7 @@
1
+ # re-exports the public API of the kirin package
2
+ from kirin import ir
3
+ from kirin.decl import info, statement
4
+
5
+ from . import types as types
6
+
7
# Public names re-exported by `from kirin import *`.
__all__ = [
    "ir",
    "types",
    "statement",
    "info",
]
@@ -0,0 +1,24 @@
1
+ """Analysis module for kirin.
2
+
3
+ This module contains the analysis framework for kirin. The analysis framework is
4
+ built on top of the interpreter framework. This module provides a set of base classes
5
+ and frameworks for implementing compiler analysis passes on the IR.
6
+
7
+ The analysis framework contains the following modules:
8
+
9
+ - [`cfg`][kirin.analysis.cfg]: Control flow graph for a given IR.
10
+ - [`forward`][kirin.analysis.forward]: Forward dataflow analysis.
11
+ - [`callgraph`][kirin.analysis.callgraph]: Call graph for a given IR.
12
+ - [`typeinfer`][kirin.analysis.typeinfer]: Type inference analysis.
13
+ - [`const`][kirin.analysis.const]: Constant propagation analysis (lattice and propagation algorithm).
14
+ """
15
+
16
+ from kirin.analysis import const as const
17
+ from kirin.analysis.cfg import CFG as CFG
18
+ from kirin.analysis.forward import (
19
+ Forward as Forward,
20
+ ForwardExtra as ForwardExtra,
21
+ ForwardFrame as ForwardFrame,
22
+ )
23
+ from kirin.analysis.callgraph import CallGraph as CallGraph
24
+ from kirin.analysis.typeinfer import TypeInference as TypeInference
@@ -0,0 +1,61 @@
1
+ from typing import Iterable
2
+ from dataclasses import field, dataclass
3
+
4
+ from kirin import ir
5
+ from kirin.print import Printable
6
+ from kirin.dialects import func
7
+ from kirin.print.printer import Printer
8
+
9
+
10
@dataclass
class CallGraph(Printable):
    """Call graph for a given [`ir.Method`][kirin.ir.Method].

    The graph is built by walking the method's callable region and following
    every `func.Invoke` statement to its callee, transitively. Nodes are
    method symbol names.

    This class implements the [`kirin.graph.Graph`][kirin.graph.Graph] protocol.
    Edges are stored in the reverse ("back") direction: the neighbors of a
    node are the methods that invoke it.

    !!! note "Pretty Printing"
        This object is pretty printable via
        [`.print()`][kirin.print.printable.Printable.print] method.
    """

    defs: dict[str, ir.Method] = field(default_factory=dict)
    """Mapping from symbol names to the methods discovered from the root method."""
    backedges: dict[str, set[str]] = field(default_factory=dict)
    """Mapping from a callee's symbol name to the symbol names of its callers."""

    def __init__(self, mt: ir.Method):
        self.defs = {}
        self.backedges = {}
        self.__build(mt)

    def __build(self, mt: ir.Method):
        """Register `mt` and, transitively, every method it invokes."""
        self.defs[mt.sym_name] = mt
        for stmt in mt.callable_region.walk():
            if not isinstance(stmt, func.Invoke):
                continue
            callee = stmt.callee
            self.backedges.setdefault(callee.sym_name, set()).add(mt.sym_name)
            # Only descend into callees that have not been registered yet;
            # otherwise a (mutually) recursive method recurses forever.
            if callee.sym_name not in self.defs:
                self.__build(callee)

    def get_neighbors(self, node: str) -> Iterable[str]:
        """Return the symbol names of the methods that invoke `node`."""
        return self.backedges.get(node, ())

    def get_edges(self) -> Iterable[tuple[str, str]]:
        """Yield `(callee, caller)` pairs, one per recorded invoke relation."""
        for callee, callers in self.backedges.items():
            for caller in callers:
                yield callee, caller

    def get_nodes(self) -> Iterable[str]:
        """Return the symbol names of all discovered methods."""
        return self.defs.keys()

    def print_impl(self, printer: Printer) -> None:
        # One line per callee, formatted as `callee -> [caller, ...]`.
        for idx, (callee, callers) in enumerate(self.backedges.items()):
            printer.plain_print(callee)
            printer.plain_print(" -> ")
            printer.print_seq(
                callers, delim=", ", prefix="[", suffix="]", emit=printer.plain_print
            )
            if idx < len(self.backedges) - 1:
                printer.print_newline()
kirin/analysis/cfg.py ADDED
@@ -0,0 +1,112 @@
1
+ from typing import Iterable
2
+ from functools import cached_property
3
+ from dataclasses import dataclass
4
+
5
+ from kirin import ir
6
+ from kirin.print import Printer, Printable
7
+ from kirin.worklist import WorkList
8
+
9
+
10
@dataclass
class CFG(Printable):
    """Control flow graph (CFG) of an IR region.

    The entry block is the first block of `parent`. Only blocks reachable
    from the entry (by following terminator successors) appear in the graph.

    This class implements the [`kirin.graph.Graph`][kirin.graph.Graph] protocol.

    !!! note "Pretty Printing"
        This object is pretty printable via
        [`.print()`][kirin.print.printable.Printable.print] method.
    """

    parent: ir.Region
    """Region whose blocks form the graph.
    """
    entry: ir.Block | None = None
    """First block of `parent`, or `None` if the region has no blocks.
    """

    def __post_init__(self):
        blocks = self.parent.blocks
        self.entry = None if blocks.isempty() else blocks[0]

    @cached_property
    def predecessors(self):
        """Mapping from each block to the set of blocks that branch to it."""
        preds: dict[ir.Block, set[ir.Block]] = {}
        for src, dsts in self.successors.items():
            for dst in dsts:
                preds.setdefault(dst, set()).add(src)
        return preds

    @cached_property
    def successors(self):
        """Mapping from each reachable block to the blocks it may branch to."""
        succs: dict[ir.Block, set[ir.Block]] = {}
        seen: set[ir.Block] = set()
        pending: WorkList[ir.Block] = WorkList()
        if self.parent.blocks.isempty():
            return succs

        current = self.entry
        while current is not None:
            out_edges = succs.setdefault(current, set())
            terminator = current.last_stmt
            if terminator is not None:
                out_edges.update(terminator.successors)
                pending.extend(terminator.successors)
            seen.add(current)

            # pull the next block that has not been expanded yet
            current = pending.pop()
            while current is not None and current in seen:
                current = pending.pop()
        return succs

    # graph interface
    def get_neighbors(self, node: ir.Block) -> Iterable[ir.Block]:
        """Return the successors of `node` (raises KeyError if unreachable)."""
        return self.successors[node]

    def get_edges(self) -> Iterable[tuple[ir.Block, ir.Block]]:
        """Yield every `(source, destination)` edge of the graph."""
        for src, dsts in self.successors.items():
            for dst in dsts:
                yield src, dst

    def get_nodes(self) -> Iterable[ir.Block]:
        """Return all reachable blocks."""
        return self.successors.keys()

    # printable interface
    def print_impl(self, printer: Printer) -> None:
        # Print the parent into a throw-away buffer so block names match
        # the ones used when the parent region itself is printed.
        with printer.string_io():
            self.parent.print(printer)

        def emit_block(blk):
            return printer.plain_print(printer.state.block_id[blk])

        printer.plain_print("Successors:")
        printer.print_newline()
        for src, dsts in self.successors.items():
            printer.plain_print(f"{printer.state.block_id[src]} -> ", end="")
            printer.print_seq(
                dsts, delim=", ", prefix="[", suffix="]", emit=emit_block
            )
            printer.print_newline()

        if not self.predecessors:
            return

        printer.print_newline()
        printer.plain_print("Predecessors:")
        printer.print_newline()
        for dst, srcs in self.predecessors.items():
            printer.plain_print(f"{printer.state.block_id[dst]} <- ", end="")
            printer.print_seq(
                srcs, delim=", ", prefix="[", suffix="]", emit=emit_block
            )
            printer.print_newline()
@@ -0,0 +1,20 @@
1
+ """Const analysis module.
2
+
3
+ This module contains the constant analysis framework for kirin. The constant
4
+ analysis framework is built on top of the interpreter framework.
5
+
6
+ This module provides a lattice for constant propagation analysis and a
7
+ propagation algorithm for computing the constant values for each SSA value in
8
+ the IR.
9
+ """
10
+
11
+ from .prop import Frame as Frame, Propagate as Propagate
12
+ from .lattice import (
13
+ Value as Value,
14
+ Bottom as Bottom,
15
+ Result as Result,
16
+ Unknown as Unknown,
17
+ PartialConst as PartialConst,
18
+ PartialTuple as PartialTuple,
19
+ PartialLambda as PartialLambda,
20
+ )
@@ -0,0 +1,2 @@
1
class _ElemVisitor:
    """Runtime placeholder base; `is_subseteq_*` signatures are declared in `_visitor.pyi`."""
@@ -0,0 +1,8 @@
1
+ from .lattice import Value, Bottom, Unknown, PartialTuple, PartialLambda
2
+
3
class _ElemVisitor:
    # Type-only declarations of the `is_subseteq_<ClassName>` methods that the
    # constant-lattice element classes in `lattice.py` define. Each hook is
    # named after, and typed with, the class of its `other` argument.
    def is_subseteq_Value(self, other: Value) -> bool: ...
    def is_subseteq_Unknown(self, other: Unknown) -> bool: ...
    def is_subseteq_Bottom(self, other: Bottom) -> bool: ...
    def is_subseteq_PartialTuple(self, other: PartialTuple) -> bool: ...
    def is_subseteq_PartialLambda(self, other: PartialLambda) -> bool: ...
@@ -0,0 +1,219 @@
1
+ """Lattice for constant analysis.
2
+ """
3
+
4
+ from typing import Any, final
5
+ from dataclasses import dataclass
6
+
7
+ from kirin import ir
8
+ from kirin.lattice import (
9
+ BoundedLattice,
10
+ IsSubsetEqMixin,
11
+ SimpleJoinMixin,
12
+ SimpleMeetMixin,
13
+ )
14
+ from kirin.ir.attrs.abc import LatticeAttributeMeta, SingletonLatticeAttributeMeta
15
+ from kirin.print.printer import Printer
16
+
17
+ from ._visitor import _ElemVisitor
18
+
19
+
20
+ @dataclass
21
+ class Result(
22
+ ir.Attribute,
23
+ IsSubsetEqMixin["Result"],
24
+ SimpleJoinMixin["Result"],
25
+ SimpleMeetMixin["Result"],
26
+ BoundedLattice["Result"],
27
+ _ElemVisitor,
28
+ metaclass=LatticeAttributeMeta,
29
+ ):
30
+ """Base class for constant analysis results."""
31
+
32
+ @classmethod
33
+ def top(cls) -> "Result":
34
+ return Unknown()
35
+
36
+ @classmethod
37
+ def bottom(cls) -> "Result":
38
+ return Bottom()
39
+
40
+ def print_impl(self, printer: Printer) -> None:
41
+ printer.plain_print(repr(self))
42
+
43
+
44
@final
@dataclass
class Unknown(Result, metaclass=SingletonLatticeAttributeMeta):
    """Top of the constant lattice: nothing is known about the value."""

    def __hash__(self) -> int:
        # identity hash; presumably a singleton via its metaclass
        return id(self)

    def is_subseteq(self, other: Result) -> bool:
        # only another top element is above (or equal to) top
        return isinstance(other, Unknown)
54
+
55
+
56
@final
@dataclass
class Bottom(Result, metaclass=SingletonLatticeAttributeMeta):
    """Least element of the constant lattice."""

    def __hash__(self) -> int:
        # identity hash; presumably a singleton via its metaclass
        return id(self)

    def is_subseteq(self, other: Result) -> bool:
        # bottom is below every element
        return True
66
+
67
+
68
@final
@dataclass
class Value(Result):
    """Known constant value wrapping an arbitrary Python object in `data`."""

    data: Any

    def is_subseteq_Value(self, other: "Value") -> bool:
        return self.data == other.data

    def is_equal(self, other: Result) -> bool:
        if isinstance(other, Value):
            return self.data == other.data
        return False

    def __hash__(self) -> int:
        # The dataclass-generated `__eq__` compares `data`, so equal instances
        # must hash equally: hash the payload whenever it is hashable.
        # Unhashable payloads (lists, arrays, unfrozen dataclasses, ...) fall
        # back to identity. This keeps hashing from raising, but equal
        # instances of such values will not be found in hash containers.
        try:
            return hash(self.data)
        except TypeError:
            return id(self)
88
+
89
+
90
@dataclass
class PartialConst(Result):
    """Common base for values that are only partially known to be constant."""
95
+
96
+
97
@final
class PartialTupleMeta(LatticeAttributeMeta):
    """Metaclass that folds fully-constant `PartialTuple`s into a `Value`.

    Calling `PartialTuple(data)` when every element of `data` is a `Value`
    yields `Value(tuple of the element payloads)` instead of a `PartialTuple`.
    """

    def __call__(cls, data: tuple[Result, ...]):
        if not all(isinstance(elem, Value) for elem in data):
            return super().__call__(data)
        payload = tuple(elem.data for elem in data)
        return Value(payload)  # type: ignore
109
+
110
+
111
@final
@dataclass
class PartialTuple(PartialConst, metaclass=PartialTupleMeta):
    """Tuple whose elements are tracked individually as lattice elements.

    All element-wise operations require both tuples to have the same length;
    `zip` would otherwise silently drop the extra elements of the longer one.
    """

    data: tuple[Result, ...]

    def join(self, other: Result) -> Result:
        if other.is_subseteq(self):
            return self
        elif self.is_subseteq(other):
            return other
        elif isinstance(other, PartialTuple):
            if len(self.data) != len(other.data):
                return Unknown()  # tuples of different shape: no common constant
            return PartialTuple(tuple(x.join(y) for x, y in zip(self.data, other.data)))
        elif isinstance(other, Value) and isinstance(other.data, tuple):
            if len(self.data) != len(other.data):
                return Unknown()
            return PartialTuple(
                tuple(x.join(Value(y)) for x, y in zip(self.data, other.data))
            )
        return Unknown()

    def meet(self, other: Result) -> Result:
        if self.is_subseteq(other):
            return self
        elif other.is_subseteq(self):
            return other
        elif isinstance(other, PartialTuple):
            if len(self.data) != len(other.data):
                return self.bottom()  # tuples of different shape never coincide
            return PartialTuple(tuple(x.meet(y) for x, y in zip(self.data, other.data)))
        elif isinstance(other, Value) and isinstance(other.data, tuple):
            if len(self.data) != len(other.data):
                return self.bottom()
            return PartialTuple(
                tuple(x.meet(Value(y)) for x, y in zip(self.data, other.data))
            )
        return self.bottom()

    def is_equal(self, other: Result) -> bool:
        if isinstance(other, PartialTuple):
            return len(self.data) == len(other.data) and all(
                x.is_equal(y) for x, y in zip(self.data, other.data)
            )
        elif isinstance(other, Value) and isinstance(other.data, tuple):
            return len(self.data) == len(other.data) and all(
                x.is_equal(Value(y)) for x, y in zip(self.data, other.data)
            )
        return False

    def is_subseteq_PartialTuple(self, other: "PartialTuple") -> bool:
        if len(self.data) != len(other.data):
            return False
        return all(x.is_subseteq(y) for x, y in zip(self.data, other.data))

    def is_subseteq_Value(self, other: Value) -> bool:
        if not isinstance(other.data, tuple) or len(self.data) != len(other.data):
            return False
        return all(x.is_subseteq(Value(y)) for x, y in zip(self.data, other.data))

    def __hash__(self) -> int:
        return hash(self.data)
161
+
162
+
163
@final
@dataclass
class PartialLambda(PartialConst):
    """Partial lambda constant value.

    This represents a closure with captured variables: `code` is the
    statement defining the lambda and `captured` holds a lattice element for
    each captured value.
    """

    argnames: list[str]
    code: ir.Statement
    captured: tuple[Result, ...]

    def __hash__(self) -> int:
        # `argnames` is a list, which is unhashable; convert it to a tuple so
        # that hashing a PartialLambda does not raise TypeError.
        return hash((tuple(self.argnames), self.code, self.captured))

    def is_subseteq_PartialLambda(self, other: "PartialLambda") -> bool:
        if self.code is not other.code:
            return False
        if len(self.captured) != len(other.captured):
            return False

        return all(x.is_subseteq(y) for x, y in zip(self.captured, other.captured))

    def join(self, other: Result) -> Result:
        if other is other.bottom():
            return self

        if not isinstance(other, PartialLambda):
            return Unknown().join(other)  # widen self

        if self.code is not other.code:
            return Unknown()  # lambda stmt is pure

        if len(self.captured) != len(other.captured):
            # Inconsistent closures for the same code. The least upper bound
            # must be above `self`, so widen to top (bottom is not an upper bound).
            return Unknown()

        return PartialLambda(
            self.argnames,
            self.code,
            tuple(x.join(y) for x, y in zip(self.captured, other.captured)),
        )

    def meet(self, other: Result) -> Result:
        if isinstance(other, Unknown):
            return self  # top is the identity element of meet

        if not isinstance(other, PartialLambda):
            return Unknown().meet(other)

        if self.code is not other.code:
            return self.bottom()

        if len(self.captured) != len(other.captured):
            # The greatest lower bound must be below `self`, so use bottom.
            return self.bottom()

        return PartialLambda(
            self.argnames,
            self.code,
            tuple(x.meet(y) for x, y in zip(self.captured, other.captured)),
        )
@@ -0,0 +1,116 @@
1
+ from dataclasses import field, dataclass
2
+
3
+ from kirin import ir, types, interp
4
+ from kirin.analysis.forward import ForwardExtra, ForwardFrame
5
+
6
+ from .lattice import Value, Result, Unknown
7
+
8
+
9
@dataclass
class Frame(ForwardFrame[Result]):
    """Constant-propagation frame that additionally tracks purity."""

    should_be_pure: set[ir.Statement] = field(default_factory=set)
    """`ir.MaybePure` statements whose implementation decided they are pure."""
    frame_is_not_pure: bool = False
    """Set once a statement in this frame is not known to be pure."""
15
+
16
+
17
+ @dataclass
18
+ class Propagate(ForwardExtra[Frame, Result]):
19
+ """Forward dataflow analysis for constant propagation.
20
+
21
+ This analysis is a forward dataflow analysis that propagates constant values
22
+ through the program. It uses the `Result` lattice to track the constant
23
+ values and purity of the values.
24
+
25
+ The analysis is implemented as a forward dataflow analysis, where the
26
+ `eval_stmt` method is overridden to handle the different types of statements
27
+ in the IR. The analysis uses the `interp.Interpreter` to evaluate the
28
+ statements and propagate the constant values.
29
+
30
+ When a statement is registered under the "constprop" key in the method table,
31
+ the analysis will call the method to evaluate the statement instead of using
32
+ the interpreter. This allows for custom handling of statements.
33
+ """
34
+
35
+ keys = ["constprop"]
36
+ lattice = Result
37
+
38
+ _interp: interp.Interpreter = field(init=False)
39
+
40
+ def __post_init__(self) -> None:
41
+ super().__post_init__()
42
+ self._interp = interp.Interpreter(
43
+ self.dialects,
44
+ fuel=self.fuel,
45
+ debug=self.debug,
46
+ max_depth=self.max_depth,
47
+ max_python_recursion_depth=self.max_python_recursion_depth,
48
+ )
49
+
50
+ def initialize(self):
51
+ super().initialize()
52
+ self._interp.initialize()
53
+ return self
54
+
55
+ def new_frame(self, code: ir.Statement) -> Frame:
56
+ return Frame.from_func_like(code)
57
+
58
+ def _try_eval_const_pure(
59
+ self,
60
+ frame: Frame,
61
+ stmt: ir.Statement,
62
+ values: tuple[Value, ...],
63
+ ) -> interp.StatementResult[Result]:
64
+ _frame = self._interp.new_frame(frame.code)
65
+ _frame.set_values(stmt.args, tuple(x.data for x in values))
66
+ method = self._interp.lookup_registry(frame, stmt)
67
+ if method is not None:
68
+ value = method(self._interp, _frame, stmt)
69
+ else:
70
+ return (Unknown(),)
71
+ match value:
72
+ case tuple():
73
+ return tuple(Value(each) for each in value)
74
+ case interp.ReturnValue(ret):
75
+ return interp.ReturnValue(Value(ret))
76
+ case interp.YieldValue(yields):
77
+ return interp.YieldValue(tuple(Value(each) for each in yields))
78
+ case interp.Successor(block, args):
79
+ return interp.Successor(
80
+ block,
81
+ *tuple(Value(each) for each in args),
82
+ )
83
+
84
+ def eval_stmt(
85
+ self, frame: Frame, stmt: ir.Statement
86
+ ) -> interp.StatementResult[Result]:
87
+ if stmt.has_trait(ir.ConstantLike):
88
+ return self._try_eval_const_pure(frame, stmt, ())
89
+ elif stmt.has_trait(ir.Pure):
90
+ values = frame.get_values(stmt.args)
91
+ if types.is_tuple_of(values, Value):
92
+ return self._try_eval_const_pure(frame, stmt, values)
93
+
94
+ method = self.lookup_registry(frame, stmt)
95
+ if method is None:
96
+ if stmt.has_trait(ir.Pure):
97
+ return (Unknown(),) # no implementation but pure
98
+ # not pure, and no implementation, let's say it's not pure
99
+ frame.frame_is_not_pure = True
100
+ return (Unknown(),)
101
+
102
+ ret = method(self, frame, stmt)
103
+ if stmt.has_trait(ir.IsTerminator):
104
+ return ret
105
+ elif not stmt.has_trait(ir.MaybePure): # cannot be pure at all
106
+ frame.frame_is_not_pure = True
107
+ elif (
108
+ stmt not in frame.should_be_pure
109
+ ): # implementation cannot decide if it's pure
110
+ frame.frame_is_not_pure = True
111
+ return ret
112
+
113
+ def run_method(
114
+ self, method: ir.Method, args: tuple[Result, ...]
115
+ ) -> tuple[Frame, Result]:
116
+ return self.run_callable(method.code, (Value(method),) + args)
@@ -0,0 +1,100 @@
1
+ import sys
2
+ from abc import ABC
3
+ from typing import TypeVar, Iterable
4
+ from dataclasses import dataclass
5
+
6
+ from kirin import ir, interp
7
+ from kirin.interp import AbstractFrame, AbstractInterpreter
8
+ from kirin.lattice import BoundedLattice
9
+
10
# Type variable for analysis-specific extra data (not referenced by the
# classes in this module).
ExtraType = TypeVar("ExtraType")
# Lattice element type tracked per SSA value; must be a bounded lattice.
LatticeElemType = TypeVar("LatticeElemType", bound=BoundedLattice)
12
+
13
+
14
@dataclass
class ForwardFrame(AbstractFrame[LatticeElemType]):
    """Frame for forward dataflow analyses; adds nothing to `AbstractFrame`."""
17
+
18
+
19
# Frame type used by a forward analysis; any `ForwardFrame` subclass.
ForwardFrameType = TypeVar("ForwardFrameType", bound=ForwardFrame)
20
+
21
+
22
@dataclass
class ForwardExtra(
    AbstractInterpreter[ForwardFrameType, LatticeElemType],
    ABC,
):
    """Abstract interpreter but record results for each SSA value.

    Params:
        ForwardFrameType: The frame type; subclasses of `ForwardFrame` may
            carry extra per-frame information.
        LatticeElemType: The lattice element type.
    """

    def run_analysis(
        self,
        method: ir.Method,
        args: tuple[LatticeElemType, ...] | None = None,
    ) -> tuple[ForwardFrameType, LatticeElemType]:
        """Run the forward dataflow analysis.

        Args:
            method(ir.Method): The method to analyze.
            args(tuple[LatticeElemType]): The arguments to the method. Defaults to tuple of top values.

        Returns:
            ForwardFrameType: The results of the analysis contained in the frame.
            LatticeElemType: The result of the analysis for the method return value.
                If an `InterpreterError` is raised during analysis, a fresh
                frame and the lattice bottom are returned instead.

        Raises:
            interp.InterpreterError: If called while another analysis run on
                this interpreter is in progress.
        """
        args = args or tuple(self.lattice.top() for _ in method.args)

        if self._eval_lock:
            raise interp.InterpreterError(
                "recursive eval is not allowed, use run_method instead"
            )

        self._eval_lock = True
        current_recursion_limit = sys.getrecursionlimit()
        # Everything after taking the lock runs inside `try`. That way the
        # lock is released and the recursion limit restored even if
        # `initialize()` or `setrecursionlimit()` fails; otherwise every later
        # call would be rejected as recursive.
        try:
            self.initialize()
            sys.setrecursionlimit(self.max_python_recursion_depth)
            try:
                frame, ret = self.run_method(method, args)
            except interp.InterpreterError:
                # NOTE: initialize will create new State
                # so we don't need to copy the frames.
                return self.new_frame(method.code), self.lattice.bottom()
        finally:
            self._eval_lock = False
            sys.setrecursionlimit(current_recursion_limit)
        return frame, ret

    def set_values(
        self,
        frame: AbstractFrame[LatticeElemType],
        ssa: Iterable[ir.SSAValue],
        results: Iterable[LatticeElemType],
    ):
        """Set the abstract values for the given SSA values in the frame.

        An SSA value that already has an entry in the frame is joined with the
        new result rather than overwritten; otherwise the result is stored
        directly.
        """
        for ssa_value, result in zip(ssa, results):
            if ssa_value in frame.entries:
                frame.entries[ssa_value] = frame.entries[ssa_value].join(result)
            else:
                frame.entries[ssa_value] = result
89
+
90
+
91
class Forward(ForwardExtra[ForwardFrame[LatticeElemType], LatticeElemType], ABC):
    """Forward dataflow analysis over plain `ForwardFrame`s.

    Analyses that need to store extra information on each frame should
    subclass [`ForwardExtra`][kirin.analysis.forward.ForwardExtra] instead.
    """

    def new_frame(self, code: ir.Statement) -> ForwardFrame[LatticeElemType]:
        frame = ForwardFrame.from_func_like(code)
        return frame
@@ -0,0 +1,5 @@
1
+ """Type inference analysis for kirin.
2
+ """
3
+
4
+ from .solve import TypeResolution as TypeResolution
5
+ from .analysis import TypeInference as TypeInference