kirin-toolchain 0.13.0__py3-none-any.whl

Sign up to get free protection for your applications and to get access to all the features.
Files changed (225) hide show
  1. kirin/__init__.py +7 -0
  2. kirin/analysis/__init__.py +24 -0
  3. kirin/analysis/callgraph.py +61 -0
  4. kirin/analysis/cfg.py +112 -0
  5. kirin/analysis/const/__init__.py +20 -0
  6. kirin/analysis/const/_visitor.py +2 -0
  7. kirin/analysis/const/_visitor.pyi +8 -0
  8. kirin/analysis/const/lattice.py +219 -0
  9. kirin/analysis/const/prop.py +116 -0
  10. kirin/analysis/forward.py +100 -0
  11. kirin/analysis/typeinfer/__init__.py +5 -0
  12. kirin/analysis/typeinfer/analysis.py +90 -0
  13. kirin/analysis/typeinfer/solve.py +141 -0
  14. kirin/decl/__init__.py +108 -0
  15. kirin/decl/base.py +65 -0
  16. kirin/decl/camel2snake.py +2 -0
  17. kirin/decl/emit/__init__.py +0 -0
  18. kirin/decl/emit/_create_fn.py +29 -0
  19. kirin/decl/emit/_set_new_attribute.py +22 -0
  20. kirin/decl/emit/dialect.py +8 -0
  21. kirin/decl/emit/init.py +277 -0
  22. kirin/decl/emit/name.py +10 -0
  23. kirin/decl/emit/property.py +182 -0
  24. kirin/decl/emit/repr.py +31 -0
  25. kirin/decl/emit/traits.py +13 -0
  26. kirin/decl/emit/typecheck.py +77 -0
  27. kirin/decl/emit/verify.py +51 -0
  28. kirin/decl/info.py +346 -0
  29. kirin/decl/scan_fields.py +157 -0
  30. kirin/decl/verify.py +69 -0
  31. kirin/dialects/__init__.py +14 -0
  32. kirin/dialects/_pprint_helper.py +53 -0
  33. kirin/dialects/cf/__init__.py +20 -0
  34. kirin/dialects/cf/constprop.py +51 -0
  35. kirin/dialects/cf/dialect.py +3 -0
  36. kirin/dialects/cf/emit.py +58 -0
  37. kirin/dialects/cf/interp.py +24 -0
  38. kirin/dialects/cf/stmts.py +68 -0
  39. kirin/dialects/cf/typeinfer.py +27 -0
  40. kirin/dialects/eltype.py +23 -0
  41. kirin/dialects/func/__init__.py +20 -0
  42. kirin/dialects/func/attrs.py +39 -0
  43. kirin/dialects/func/constprop.py +138 -0
  44. kirin/dialects/func/dialect.py +3 -0
  45. kirin/dialects/func/emit.py +80 -0
  46. kirin/dialects/func/interp.py +68 -0
  47. kirin/dialects/func/stmts.py +233 -0
  48. kirin/dialects/func/typeinfer.py +124 -0
  49. kirin/dialects/ilist/__init__.py +33 -0
  50. kirin/dialects/ilist/_dialect.py +3 -0
  51. kirin/dialects/ilist/_wrapper.py +51 -0
  52. kirin/dialects/ilist/interp.py +85 -0
  53. kirin/dialects/ilist/lowering.py +25 -0
  54. kirin/dialects/ilist/passes.py +32 -0
  55. kirin/dialects/ilist/rewrite/__init__.py +3 -0
  56. kirin/dialects/ilist/rewrite/const.py +45 -0
  57. kirin/dialects/ilist/rewrite/list.py +38 -0
  58. kirin/dialects/ilist/rewrite/unroll.py +131 -0
  59. kirin/dialects/ilist/runtime.py +63 -0
  60. kirin/dialects/ilist/stmts.py +102 -0
  61. kirin/dialects/ilist/typeinfer.py +120 -0
  62. kirin/dialects/lowering/__init__.py +7 -0
  63. kirin/dialects/lowering/call.py +48 -0
  64. kirin/dialects/lowering/cf.py +206 -0
  65. kirin/dialects/lowering/func.py +134 -0
  66. kirin/dialects/math/__init__.py +41 -0
  67. kirin/dialects/math/_gen.py +176 -0
  68. kirin/dialects/math/dialect.py +3 -0
  69. kirin/dialects/math/interp.py +190 -0
  70. kirin/dialects/math/stmts.py +369 -0
  71. kirin/dialects/module.py +139 -0
  72. kirin/dialects/py/__init__.py +40 -0
  73. kirin/dialects/py/assertion.py +91 -0
  74. kirin/dialects/py/assign.py +103 -0
  75. kirin/dialects/py/attr.py +59 -0
  76. kirin/dialects/py/base.py +34 -0
  77. kirin/dialects/py/binop/__init__.py +23 -0
  78. kirin/dialects/py/binop/_dialect.py +3 -0
  79. kirin/dialects/py/binop/interp.py +60 -0
  80. kirin/dialects/py/binop/julia.py +33 -0
  81. kirin/dialects/py/binop/lowering.py +22 -0
  82. kirin/dialects/py/binop/stmts.py +79 -0
  83. kirin/dialects/py/binop/typeinfer.py +108 -0
  84. kirin/dialects/py/boolop.py +84 -0
  85. kirin/dialects/py/builtin.py +78 -0
  86. kirin/dialects/py/cmp/__init__.py +16 -0
  87. kirin/dialects/py/cmp/_dialect.py +3 -0
  88. kirin/dialects/py/cmp/interp.py +48 -0
  89. kirin/dialects/py/cmp/julia.py +33 -0
  90. kirin/dialects/py/cmp/lowering.py +45 -0
  91. kirin/dialects/py/cmp/stmts.py +62 -0
  92. kirin/dialects/py/constant.py +79 -0
  93. kirin/dialects/py/indexing.py +251 -0
  94. kirin/dialects/py/iterable.py +90 -0
  95. kirin/dialects/py/len.py +57 -0
  96. kirin/dialects/py/list/__init__.py +15 -0
  97. kirin/dialects/py/list/_dialect.py +3 -0
  98. kirin/dialects/py/list/interp.py +21 -0
  99. kirin/dialects/py/list/lowering.py +25 -0
  100. kirin/dialects/py/list/stmts.py +22 -0
  101. kirin/dialects/py/list/typeinfer.py +54 -0
  102. kirin/dialects/py/range.py +76 -0
  103. kirin/dialects/py/slice.py +120 -0
  104. kirin/dialects/py/tuple.py +109 -0
  105. kirin/dialects/py/unary/__init__.py +24 -0
  106. kirin/dialects/py/unary/_dialect.py +3 -0
  107. kirin/dialects/py/unary/constprop.py +20 -0
  108. kirin/dialects/py/unary/interp.py +24 -0
  109. kirin/dialects/py/unary/julia.py +21 -0
  110. kirin/dialects/py/unary/lowering.py +22 -0
  111. kirin/dialects/py/unary/stmts.py +33 -0
  112. kirin/dialects/py/unary/typeinfer.py +23 -0
  113. kirin/dialects/py/unpack.py +90 -0
  114. kirin/dialects/scf/__init__.py +23 -0
  115. kirin/dialects/scf/_dialect.py +3 -0
  116. kirin/dialects/scf/absint.py +64 -0
  117. kirin/dialects/scf/constprop.py +140 -0
  118. kirin/dialects/scf/interp.py +35 -0
  119. kirin/dialects/scf/lowering.py +123 -0
  120. kirin/dialects/scf/stmts.py +250 -0
  121. kirin/dialects/scf/trim.py +36 -0
  122. kirin/dialects/scf/typeinfer.py +58 -0
  123. kirin/dialects/scf/unroll.py +92 -0
  124. kirin/emit/__init__.py +3 -0
  125. kirin/emit/abc.py +89 -0
  126. kirin/emit/abc.pyi +38 -0
  127. kirin/emit/exceptions.py +5 -0
  128. kirin/emit/julia.py +63 -0
  129. kirin/emit/str.py +51 -0
  130. kirin/exceptions.py +59 -0
  131. kirin/graph.py +34 -0
  132. kirin/idtable.py +57 -0
  133. kirin/interp/__init__.py +39 -0
  134. kirin/interp/abstract.py +253 -0
  135. kirin/interp/base.py +438 -0
  136. kirin/interp/concrete.py +62 -0
  137. kirin/interp/exceptions.py +26 -0
  138. kirin/interp/frame.py +151 -0
  139. kirin/interp/impl.py +197 -0
  140. kirin/interp/result.py +93 -0
  141. kirin/interp/state.py +71 -0
  142. kirin/interp/table.py +40 -0
  143. kirin/interp/value.py +73 -0
  144. kirin/ir/__init__.py +46 -0
  145. kirin/ir/attrs/__init__.py +20 -0
  146. kirin/ir/attrs/_types.py +8 -0
  147. kirin/ir/attrs/_types.pyi +13 -0
  148. kirin/ir/attrs/abc.py +46 -0
  149. kirin/ir/attrs/py.py +45 -0
  150. kirin/ir/attrs/types.py +522 -0
  151. kirin/ir/dialect.py +125 -0
  152. kirin/ir/group.py +249 -0
  153. kirin/ir/method.py +118 -0
  154. kirin/ir/nodes/__init__.py +7 -0
  155. kirin/ir/nodes/base.py +149 -0
  156. kirin/ir/nodes/block.py +458 -0
  157. kirin/ir/nodes/region.py +337 -0
  158. kirin/ir/nodes/stmt.py +713 -0
  159. kirin/ir/nodes/view.py +142 -0
  160. kirin/ir/ssa.py +204 -0
  161. kirin/ir/traits/__init__.py +36 -0
  162. kirin/ir/traits/abc.py +42 -0
  163. kirin/ir/traits/basic.py +78 -0
  164. kirin/ir/traits/callable.py +51 -0
  165. kirin/ir/traits/lowering/__init__.py +2 -0
  166. kirin/ir/traits/lowering/call.py +37 -0
  167. kirin/ir/traits/lowering/context.py +120 -0
  168. kirin/ir/traits/region/__init__.py +2 -0
  169. kirin/ir/traits/region/ssacfg.py +22 -0
  170. kirin/ir/traits/symbol.py +57 -0
  171. kirin/ir/use.py +17 -0
  172. kirin/lattice/__init__.py +13 -0
  173. kirin/lattice/abc.py +128 -0
  174. kirin/lattice/empty.py +25 -0
  175. kirin/lattice/mixin.py +51 -0
  176. kirin/lowering/__init__.py +7 -0
  177. kirin/lowering/binding.py +65 -0
  178. kirin/lowering/core.py +72 -0
  179. kirin/lowering/dialect.py +35 -0
  180. kirin/lowering/dialect.pyi +183 -0
  181. kirin/lowering/frame.py +171 -0
  182. kirin/lowering/result.py +68 -0
  183. kirin/lowering/state.py +441 -0
  184. kirin/lowering/stream.py +53 -0
  185. kirin/passes/__init__.py +3 -0
  186. kirin/passes/abc.py +44 -0
  187. kirin/passes/aggressive/__init__.py +1 -0
  188. kirin/passes/aggressive/fold.py +43 -0
  189. kirin/passes/fold.py +45 -0
  190. kirin/passes/inline.py +25 -0
  191. kirin/passes/typeinfer.py +25 -0
  192. kirin/prelude.py +197 -0
  193. kirin/print/__init__.py +15 -0
  194. kirin/print/printable.py +141 -0
  195. kirin/print/printer.py +415 -0
  196. kirin/py.typed +0 -0
  197. kirin/registry.py +105 -0
  198. kirin/registry.pyi +52 -0
  199. kirin/rewrite/__init__.py +14 -0
  200. kirin/rewrite/abc.py +43 -0
  201. kirin/rewrite/aggressive/__init__.py +1 -0
  202. kirin/rewrite/aggressive/fold.py +43 -0
  203. kirin/rewrite/alias.py +16 -0
  204. kirin/rewrite/apply_type.py +47 -0
  205. kirin/rewrite/call2invoke.py +34 -0
  206. kirin/rewrite/chain.py +39 -0
  207. kirin/rewrite/compactify.py +288 -0
  208. kirin/rewrite/cse.py +48 -0
  209. kirin/rewrite/dce.py +19 -0
  210. kirin/rewrite/fixpoint.py +34 -0
  211. kirin/rewrite/fold.py +57 -0
  212. kirin/rewrite/getfield.py +21 -0
  213. kirin/rewrite/getitem.py +37 -0
  214. kirin/rewrite/inline.py +143 -0
  215. kirin/rewrite/result.py +15 -0
  216. kirin/rewrite/walk.py +83 -0
  217. kirin/rewrite/wrap_const.py +55 -0
  218. kirin/source.py +21 -0
  219. kirin/symbol_table.py +27 -0
  220. kirin/types.py +34 -0
  221. kirin/worklist.py +30 -0
  222. kirin_toolchain-0.13.0.dist-info/METADATA +42 -0
  223. kirin_toolchain-0.13.0.dist-info/RECORD +225 -0
  224. kirin_toolchain-0.13.0.dist-info/WHEEL +4 -0
  225. kirin_toolchain-0.13.0.dist-info/licenses/LICENSE +234 -0
@@ -0,0 +1,120 @@
1
+ """Traits for customizing lowering of Python `with` syntax to a statement.
2
+ """
3
+
4
+ import ast
5
+ from typing import TYPE_CHECKING, TypeVar
6
+ from dataclasses import dataclass
7
+
8
+ from kirin.exceptions import DialectLoweringError
9
+
10
+ from ..abc import PythonLoweringTrait
11
+
12
+ if TYPE_CHECKING:
13
+ from kirin.ir import Statement
14
+ from kirin.lowering import Result, LoweringState
15
+
16
+ StatementType = TypeVar("StatementType", bound="Statement")
17
+
18
+
@dataclass(frozen=True)
class FromPythonWith(PythonLoweringTrait[StatementType, ast.With]):
    """Base trait for lowering a Python ``with`` statement to a statement.

    Subclass this trait to customize how a Python ``with`` statement is
    lowered to the statement. The ``lower`` method should be implemented to
    parse the arguments from the Python ``with`` statement and construct the
    statement instance.
    """
29
+
30
+
@dataclass(frozen=True)
class FromPythonWithSingleItem(FromPythonWith[StatementType]):
    """Trait for lowering the following Python ``with`` syntax to a statement:

    ```python
    with <stmt>[ as <name>]:
        <body>
    ```

    Here `<stmt>` is the statement being lowered, `<name>` is an optional name
    for the result of the statement, and `<body>` is the body of the with
    statement. The optional `as <name>` is not valid when the statement has no
    results.

    This differs slightly from the standard Python `with` statement: `<name>`
    refers to the result of the statement, not to the context manager. Thus
    typically one should access `<name>` in `<body>` to use the result of the
    statement.

    In some cases, however, `<name>` may be used as a reference to a special
    value `self` that is passed to the `<body>` of the statement. This is
    useful for statements that behave similarly to a closure.
    """

    def lower(
        self, stmt: type[StatementType], state: "LoweringState", node: ast.With
    ) -> "Result":
        """Lower ``node`` into an instance of ``stmt`` appended to ``state``.

        The body of the ``with`` is lowered in its own frame and becomes the
        single declared region of ``stmt``; the call arguments of the context
        expression become the remaining inputs.

        Raises:
            DialectLoweringError: if ``stmt`` does not declare exactly one
                region, the ``with`` has more than one item, the context
                expression is not a call, a non-multi region ends up with
                more than one block, or the statement has more than one result.
        """
        from kirin import ir, lowering
        from kirin.decl import fields
        from kirin.dialects import cf

        declared = fields(stmt)
        if len(declared.regions) != 1:
            raise DialectLoweringError(
                "Expected exactly one region in statement declaration"
            )

        if len(node.items) != 1:
            raise DialectLoweringError("Expected exactly one item in statement")

        (with_item,) = node.items
        call_expr = with_item.context_expr
        if not isinstance(call_expr, ast.Call):
            raise DialectLoweringError(
                f"Expected context expression to be a call for with {stmt.name}"
            )

        # lower the body of the `with` inside a fresh child frame
        frame = lowering.Frame.from_stmts(
            node.body, state, parent=state.current_frame
        )
        state.push_frame(frame)
        state.exhaust()

        region_name, region_info = next(iter(declared.regions.items()))
        if region_info.multi:
            # branch to exit block if not terminated
            for block in frame.curr_region.blocks:
                last = block.last_stmt
                if last is None or not last.has_trait(ir.IsTerminator):
                    block.stmts.append(
                        cf.Branch(arguments=(), successor=frame.next_block)
                    )
            state.pop_frame()
        else:
            if len(frame.curr_region.blocks) != 1:
                raise DialectLoweringError(
                    f"Expected exactly one block in region {region_name}"
                )
            state.pop_frame(finalize_next=False)

        args, kwargs = state.default_Call_inputs(stmt, call_expr)
        kwargs[region_name] = frame.curr_region
        results = state.append_stmt(stmt(*args.values(), **kwargs)).results

        if len(results) == 0:
            return lowering.Result()
        if len(results) > 1:
            raise DialectLoweringError(
                f"Expected exactly one result or no result from statement {stmt.name}"
            )

        result = results[0]
        target = with_item.optional_vars
        if isinstance(target, ast.Name):  # `as <name>` binds the result
            result.name = target.id
            state.current_frame.defs[result.name] = result
        return lowering.Result(result)

    def verify(self, stmt: "Statement"):
        """Assert that ``stmt`` has one region, no successors, at most one result."""
        region_msg = "FromPythonWithSingleItem statements must have one region"
        assert len(stmt.regions) == 1, region_msg

        successor_msg = "FromPythonWithSingleItem statements cannot have successors"
        assert len(stmt.successors) == 0, successor_msg

        result_msg = "FromPythonWithSingleItem statements can have at most one result"
        assert len(stmt.results) <= 1, result_msg
@@ -0,0 +1,2 @@
1
+ """Builtin region traits.
2
+ """
@@ -0,0 +1,22 @@
1
+ """SSACFG region trait.
2
+
3
+ This module defines the SSACFGRegion trait, which is used to indicate that a
4
+ region has an SSACFG graph.
5
+ """
6
+
7
+ from typing import TYPE_CHECKING
8
+ from dataclasses import dataclass
9
+
10
+ from kirin.ir.traits.abc import RegionTrait
11
+
12
+ if TYPE_CHECKING:
13
+ from kirin.ir import Region
14
+
15
+
@dataclass(frozen=True)
class SSACFGRegion(RegionTrait):
    """Region trait indicating that a region has an SSACFG graph."""

    def get_graph(self, region: "Region"):
        """Return a ``CFG`` constructed from ``region``."""
        # imported inside the method, as in the rest of this module's traits
        from kirin.analysis.cfg import CFG as _CFG

        graph = _CFG(region)
        return graph
@@ -0,0 +1,57 @@
1
+ from typing import TYPE_CHECKING
2
+ from dataclasses import dataclass
3
+
4
+ from kirin.exceptions import VerificationError
5
+ from kirin.ir.attrs.py import PyAttr
6
+ from kirin.ir.traits.abc import StmtTrait
7
+
8
+ if TYPE_CHECKING:
9
+ from kirin.ir import Statement
10
+
11
+
@dataclass(frozen=True)
class SymbolOpInterface(StmtTrait):
    """Trait marking a statement as a symbol operation.

    A symbol operation is a statement that carries a symbol name under the
    attribute or property ``sym_name``.
    """

    def get_sym_name(self, stmt: "Statement") -> "PyAttr[str]":
        """Return the ``sym_name`` attribute/property of ``stmt``.

        Raises:
            ValueError: if ``stmt`` has no ``sym_name``.
        """
        attr: PyAttr[str] | None = stmt.get_attr_or_prop("sym_name")  # type: ignore
        if attr is not None:
            return attr
        # NOTE: unlike MLIR or xDSL we do not allow empty symbol names
        raise ValueError(f"Statement {stmt.name} does not have a symbol name")

    def verify(self, stmt: "Statement"):
        """Check that the symbol name of ``stmt`` is a string ``PyAttr``.

        Raises:
            ValueError: if the symbol name is missing, is not a ``PyAttr``, or
                its type is not a subset of ``String``.
        """
        from kirin.types import String

        name_attr = self.get_sym_name(stmt)
        if isinstance(name_attr, PyAttr) and name_attr.type.is_subseteq(String):
            return
        raise ValueError(f"Symbol name {name_attr} is not a string attribute")
32
+
33
+
@dataclass(frozen=True)
class SymbolTable(StmtTrait):
    """
    Statement with SymbolTable trait can only have one region with one block.
    """

    @staticmethod
    def walk(stmt: "Statement"):
        """Return the statements of the first block of the first region."""
        first_region = stmt.regions[0]
        first_block = first_region.blocks[0]
        return first_block.stmts

    def verify(self, stmt: "Statement"):
        """Check the one-region / one-block shape required by this trait.

        Raises:
            VerificationError: if ``stmt`` does not have exactly one region,
                or that region does not have exactly one block.
        """
        region_count = len(stmt.regions)
        if region_count != 1:
            raise VerificationError(
                stmt,
                f"Statement {stmt.name} with SymbolTable trait must have exactly one region",
            )

        block_count = len(stmt.regions[0].blocks)
        if block_count != 1:
            raise VerificationError(
                stmt,
                f"Statement {stmt.name} with SymbolTable trait must have exactly one block",
            )

        # TODO: check uniqueness of symbol names
kirin/ir/use.py ADDED
@@ -0,0 +1,17 @@
1
+ from __future__ import annotations
2
+
3
+ from typing import TYPE_CHECKING
4
+ from dataclasses import dataclass
5
+
6
+ if TYPE_CHECKING:
7
+ from kirin.ir.nodes.stmt import Statement
8
+
9
+
@dataclass(frozen=True)
class Use:
    """One use of an SSA value by a statement.

    Declared as a frozen dataclass: the two fields are set at construction
    and cannot be reassigned afterwards.
    """

    stmt: Statement
    """The statement that uses the SSA value."""

    index: int
    """The index of the use in the statement."""
@@ -0,0 +1,13 @@
1
+ from kirin.lattice.abc import (
2
+ Lattice as Lattice,
3
+ UnionMeta as UnionMeta,
4
+ LatticeMeta as LatticeMeta,
5
+ SingletonMeta as SingletonMeta,
6
+ BoundedLattice as BoundedLattice,
7
+ )
8
+ from kirin.lattice.empty import EmptyLattice as EmptyLattice
9
+ from kirin.lattice.mixin import (
10
+ IsSubsetEqMixin as IsSubsetEqMixin,
11
+ SimpleJoinMixin as SimpleJoinMixin,
12
+ SimpleMeetMixin as SimpleMeetMixin,
13
+ )
kirin/lattice/abc.py ADDED
@@ -0,0 +1,128 @@
1
+ from abc import ABC, ABCMeta, abstractmethod
2
+ from typing import Generic, TypeVar, Iterable
3
+
4
+
5
+ class LatticeMeta(ABCMeta):
6
+ pass
7
+
8
+
9
+ class SingletonMeta(LatticeMeta):
10
+ """
11
+ Singleton metaclass for lattices. It ensures that only one instance of a lattice is created.
12
+
13
+ See https://stackoverflow.com/questions/674304/why-is-init-always-called-after-new/8665179#8665179
14
+ """
15
+
16
+ def __init__(cls, name, bases, attrs):
17
+ super().__init__(name, bases, attrs)
18
+ cls._instance = None
19
+
20
+ def __call__(cls):
21
+ if cls._instance is None:
22
+ cls._instance = super().__call__()
23
+ return cls._instance
24
+
25
+
26
+ LatticeType = TypeVar("LatticeType", bound="Lattice")
27
+
28
+
29
+ class Lattice(ABC, Generic[LatticeType], metaclass=LatticeMeta):
30
+ """ABC for lattices as Python class.
31
+
32
+ While `Lattice` is only an interface, `LatticeABC` is an abstract
33
+ class that can be inherited from. This provides a few default
34
+ implementations for the lattice operations.
35
+ """
36
+
37
+ @abstractmethod
38
+ def join(self, other: LatticeType) -> LatticeType:
39
+ """Join operation."""
40
+ ...
41
+
42
+ @abstractmethod
43
+ def meet(self, other: LatticeType) -> LatticeType:
44
+ """Meet operation."""
45
+ ...
46
+
47
+ @abstractmethod
48
+ def is_subseteq(self, other: LatticeType) -> bool:
49
+ """Subseteq operation."""
50
+ ...
51
+
52
+ def is_equal(self, other: LatticeType) -> bool:
53
+ """Check if two lattices are equal."""
54
+ if self is other:
55
+ return True
56
+ else:
57
+ return self.is_subseteq(other) and other.is_subseteq(self)
58
+
59
+ def is_subset(self, other: LatticeType) -> bool:
60
+ return self.is_subseteq(other) and not other.is_subseteq(self)
61
+
62
+ def __eq__(self, value: object) -> bool:
63
+ raise NotImplementedError(
64
+ "Equality is not implemented for lattices, use is_equal instead"
65
+ )
66
+
67
+ def __hash__(self) -> int:
68
+ raise NotImplementedError("Hash is not implemented for lattices")
69
+
70
+
71
+ BoundedLatticeType = TypeVar("BoundedLatticeType", bound="BoundedLattice")
72
+
73
+
74
+ class BoundedLattice(Lattice[BoundedLatticeType]):
75
+ """ABC for bounded lattices as Python class.
76
+
77
+ `BoundedLattice` is an abstract class that can be inherited from.
78
+ It requires the implementation of the `bottom` and `top` methods.
79
+ """
80
+
81
+ @classmethod
82
+ @abstractmethod
83
+ def bottom(cls) -> BoundedLatticeType: ...
84
+
85
+ @classmethod
86
+ @abstractmethod
87
+ def top(cls) -> BoundedLatticeType: ...
88
+
89
+
90
+ class UnionMeta(LatticeMeta):
91
+ """Meta class for union types. It simplifies the union if possible."""
92
+
93
+ def __call__(
94
+ self,
95
+ typ: Iterable[LatticeType] | LatticeType,
96
+ *others: LatticeType,
97
+ ):
98
+ from kirin.lattice.abc import Lattice
99
+
100
+ if isinstance(typ, Lattice):
101
+ typs: Iterable[LatticeType] = (typ, *others)
102
+ elif not others:
103
+ typs = typ
104
+ else:
105
+ raise ValueError(
106
+ "Expected an iterable of types or variadic arguments of types"
107
+ )
108
+
109
+ # try if the union can be simplified
110
+ params: list[LatticeType] = []
111
+ for typ in typs:
112
+ contains = False
113
+ for idx, other in enumerate(params):
114
+ if typ.is_subseteq(other):
115
+ contains = True
116
+ break
117
+ elif other.is_subseteq(typ):
118
+ params[idx] = typ
119
+ contains = True
120
+ break
121
+
122
+ if not contains:
123
+ params.append(typ)
124
+
125
+ if len(params) == 1:
126
+ return params[0]
127
+
128
+ return super(UnionMeta, self).__call__(*params)
kirin/lattice/empty.py ADDED
@@ -0,0 +1,25 @@
1
+ from kirin.lattice.abc import SingletonMeta, BoundedLattice
2
+
3
+
class EmptyLattice(BoundedLattice["EmptyLattice"], metaclass=SingletonMeta):
    """Empty lattice."""

    @classmethod
    def top(cls):
        """Return the result of calling the class with no arguments."""
        return cls()

    @classmethod
    def bottom(cls):
        """Return the result of calling the class with no arguments."""
        return cls()

    def is_subseteq(self, other: "EmptyLattice") -> bool:
        """Always true."""
        return True

    def join(self, other: "EmptyLattice") -> "EmptyLattice":
        """Return ``self``; ``other`` is not inspected."""
        return self

    def meet(self, other: "EmptyLattice") -> "EmptyLattice":
        """Return ``self``; ``other`` is not inspected."""
        return self

    def __hash__(self) -> int:
        # identity-based hash
        return id(self)
kirin/lattice/mixin.py ADDED
@@ -0,0 +1,51 @@
1
+ from typing import TypeVar
2
+
3
+ from .abc import BoundedLattice
4
+
5
+ BoundedLatticeType = TypeVar("BoundedLatticeType", bound="BoundedLattice")
6
+
7
+
class IsSubsetEqMixin(BoundedLattice[BoundedLatticeType]):
    """A special mixin for lattices that provides a default implementation for `is_subseteq`
    by using the visitor pattern. This is useful if the lattice has a lot of different
    subclasses that need to be compared.

    The comparison is dispatched to a method named `is_subseteq_<ClassName>`,
    where `<ClassName>` is the class name of the other element, falling back to
    `is_subseteq_fallback` when no such method exists.

    Must be used before `BoundedLattice` in the inheritance chain.
    """

    def is_subseteq(self, other: BoundedLatticeType) -> bool:
        """Return whether `self` is a subset of or equal to `other`.

        `other` being the top element gives `True`, being the bottom element
        gives `False`; otherwise the class-name-specific method (or the
        fallback) decides, and `False` is returned when neither exists.
        """
        if other is self.top():
            return True
        if other is self.bottom():
            return False

        fallback = getattr(self, "is_subseteq_fallback", None)
        handler = getattr(self, "is_subseteq_" + other.__class__.__name__, fallback)
        if handler is None:
            return False
        return handler(other)
30
+
31
+
class SimpleJoinMixin(BoundedLattice[BoundedLatticeType]):
    """A mixin that provides a simple implementation for the join operation."""

    def join(self, other: BoundedLatticeType) -> BoundedLatticeType:
        """Return the larger of the two elements, or top if they are unordered."""
        if self.is_subseteq(other):
            return other
        if other.is_subseteq(self):
            return self  # type: ignore
        # neither contains the other
        return self.top()
41
+
42
+
class SimpleMeetMixin(BoundedLattice[BoundedLatticeType]):
    """A mixin that provides a simple implementation for the meet operation."""

    def meet(self, other: BoundedLatticeType) -> BoundedLatticeType:
        """Return the smaller of the two elements, or bottom if they are unordered."""
        if self.is_subseteq(other):
            return self  # type: ignore
        if other.is_subseteq(self):
            return other
        # neither contains the other
        return self.bottom()
@@ -0,0 +1,7 @@
1
+ from kirin.lowering.core import Lowering as Lowering
2
+ from kirin.lowering.frame import Frame as Frame
3
+ from kirin.lowering.state import LoweringState as LoweringState
4
+ from kirin.lowering.result import Result as Result
5
+ from kirin.lowering.stream import StmtStream as StmtStream
6
+ from kirin.lowering.binding import wraps as wraps
7
+ from kirin.lowering.dialect import FromPythonAST as FromPythonAST
@@ -0,0 +1,65 @@
1
+ from typing import TYPE_CHECKING, Generic, TypeVar, Callable, ParamSpec
2
+ from dataclasses import dataclass
3
+
4
+ if TYPE_CHECKING:
5
+ from kirin.ir.nodes.stmt import Statement
6
+
7
+ Params = ParamSpec("Params")
8
+ RetType = TypeVar("RetType")
9
+
10
+
11
+ @dataclass(frozen=True)
12
+ class Binding(Generic[Params, RetType]):
13
+ parent: type["Statement"]
14
+
15
+ def __call__(self, *args: Params.args, **kwargs: Params.kwargs) -> RetType:
16
+ raise NotImplementedError(
17
+ f"Binding of {self.parent.name} can \
18
+ only be called from a kernel"
19
+ )
20
+
21
+
22
+ def wraps(parent: type["Statement"]):
23
+ """Wraps a [`Statement`][kirin.ir.nodes.stmt.Statement] to a `Binding` object
24
+ which will be special cased in the lowering process.
25
+
26
+ This is useful for providing type hints by faking the call signature of a
27
+ [`Statement`][kirin.ir.nodes.stmt.Statement].
28
+
29
+ ## Example
30
+
31
+ Directly writing a function with the statement will let Python linter think
32
+ you intend to call the constructor of the statement class. However, given the
33
+ context of a kernel, our intention is to actually "call" the statement, e.g
34
+ the following will produce type errors with pyright or mypy:
35
+
36
+ ```python
37
+ from kirin.dialects import math
38
+ from kirin.prelude import basic_no_opt
39
+
40
+ @basic_no_opt
41
+ def main(x: float):
42
+ return math.sin(x) # this is a statement, not a function
43
+ ```
44
+
45
+ the `@lowering.wraps` decorator allows us to provide a type hint for the
46
+ statement, e.g:
47
+
48
+ ```python
49
+ from kirin import lowering
50
+
51
+ @lowering.wraps(math.sin)
52
+ def sin(value: float) -> float: ...
53
+
54
+ @basic_no_opt
55
+ def main(x: float):
56
+ return sin(x) # linter now thinks this is a function
57
+
58
+ sin(1.0) # this will raise a NotImplementedError("Binding of sin can only be called from a kernel")
59
+ ```
60
+ """
61
+
62
+ def wrapper(func: Callable[Params, RetType]) -> Binding[Params, RetType]:
63
+ return Binding(parent)
64
+
65
+ return wrapper
kirin/lowering/core.py ADDED
@@ -0,0 +1,72 @@
1
+ import ast
2
+ import inspect
3
+ import textwrap
4
+ from types import ModuleType
5
+ from typing import Any, Callable, Iterable
6
+ from dataclasses import dataclass
7
+
8
+ from kirin.ir import Dialect, DialectGroup
9
+ from kirin.exceptions import DialectLoweringError
10
+ from kirin.lowering.state import LoweringState
11
+ from kirin.lowering.dialect import FromPythonAST
12
+
13
+
@dataclass(init=False)
class Lowering(ast.NodeVisitor):
    """Driver that lowers a Python function or AST statement using a dialect group.

    The AST lowering registry is taken from the dialect group at construction;
    each `run` call creates a fresh `LoweringState` and visits the statement.
    """

    dialects: DialectGroup
    registry: dict[str, FromPythonAST]
    state: LoweringState | None = None

    # max lines to show in error hint
    max_lines: int = 3

    def __init__(
        self,
        dialects: DialectGroup | Iterable[Dialect | ModuleType],
        keys: list[str] | None = None,
        max_lines: int = 3,
    ):
        """Create a lowering driver.

        Args:
            dialects: a `DialectGroup`, or an iterable that is passed to
                `DialectGroup(...)` to build one.
            keys: keys passed to the registry's `ast` lookup; defaults to
                `["main", "default"]`.
            max_lines: max lines to show in error hint.
        """
        if isinstance(dialects, DialectGroup):
            self.dialects = dialects
        else:
            self.dialects = DialectGroup(dialects)

        self.max_lines = max_lines
        self.registry: dict[str, FromPythonAST] = self.dialects.registry.ast(
            keys=keys or ["main", "default"]
        )
        self.state = None

    def run(
        self,
        stmt: ast.stmt | Callable,
        source: str | None = None,
        globals: dict[str, Any] | None = None,
        lineno_offset: int = 0,
        col_offset: int = 0,
        compactify: bool = True,
    ):
        """Lower `stmt` and return the resulting `state.code`.

        Args:
            stmt: an AST statement, or a callable whose source is fetched with
                `inspect.getsource`, dedented and parsed (first top-level
                statement only).
            source: source text; for a callable it defaults to its dedented
                source.
            globals: names visible to the lowering; for a callable it defaults
                to `stmt.__globals__`, extended with the callable's closure
                variables.
            lineno_offset: forwarded to `LoweringState.from_stmt`.
            col_offset: forwarded to `LoweringState.from_stmt`.
            compactify: when true, run `Walk(CFGCompactify())` on the result.

        Raises:
            DialectLoweringError: re-raised from the visit with the state's
                error hint appended to the first message argument.
        """
        if callable(stmt):
            source = source or textwrap.dedent(inspect.getsource(stmt))
            try:
                nonlocals = inspect.getclosurevars(stmt).nonlocals
            except Exception:
                # best effort: closure variables are optional extra context
                nonlocals = {}
            # Merge into a fresh dict. Updating `stmt.__globals__` (or the
            # dict handed in by the caller) in place would leak closure
            # variables into the real module namespace and could overwrite
            # module globals of the same name.
            globals = {**(globals or stmt.__globals__), **nonlocals}
            stmt = ast.parse(source).body[0]

        state = LoweringState.from_stmt(
            self, stmt, source, globals, self.max_lines, lineno_offset, col_offset
        )
        try:
            state.visit(stmt)
        except DialectLoweringError as e:
            hint = state.error_hint()
            # an error raised without arguments must not turn into IndexError
            message = f"{e.args[0]}\n\n{hint}" if e.args else hint
            e.args = (message, *e.args[1:])
            raise

        if compactify:
            from kirin.rewrite import Walk, CFGCompactify

            Walk(CFGCompactify()).rewrite(state.code)
        return state.code
@@ -0,0 +1,35 @@
1
+ # NOTE: this module is only interface, will be used inside
2
+ # the `ir` module try to minimize the dependencies as much
3
+ # as possible
4
+
5
+ from __future__ import annotations
6
+
7
+ import ast
8
+ from abc import ABC
9
+ from typing import TYPE_CHECKING
10
+
11
+ from kirin.exceptions import DialectLoweringError
12
+ from kirin.lowering.result import Result
13
+
14
+ if TYPE_CHECKING:
15
+ from kirin.lowering.state import LoweringState
16
+
17
+
class FromPythonAST(ABC):
    """Base class for dialect-specific lowering from Python AST nodes.

    A subclass handles an AST node class `X` by defining a method named
    `lower_X(state, node)`.
    """

    @property
    def names(self) -> list[str]:
        """Names of the handled node classes, i.e. the `lower_*` attribute
        names with the `lower_` prefix removed."""
        prefix = "lower_"
        handled = []
        for attr in dir(self):
            if attr.startswith(prefix):
                handled.append(attr[len(prefix):])
        return handled

    def lower(self, state: LoweringState, node: ast.AST) -> Result:
        """Entry point of dialect specific lowering."""
        # dispatch on the node's class name; `unreachable` when no handler exists
        handler = getattr(self, "lower_" + node.__class__.__name__, self.unreachable)
        return handler(state, node)

    def unreachable(self, state: LoweringState, node: ast.AST) -> Result:
        """Raise `DialectLoweringError` naming the unhandled node class."""
        raise DialectLoweringError(f"unreachable reached for {node.__class__.__name__}")
32
+
33
+
class NoSpecialLowering(FromPythonAST):
    """A `FromPythonAST` subclass that defines no `lower_*` methods of its own."""