kirin-toolchain 0.13.0__py3-none-any.whl

Sign up to get free protection for your applications and to get access to all the features.
Files changed (225) hide show
  1. kirin/__init__.py +7 -0
  2. kirin/analysis/__init__.py +24 -0
  3. kirin/analysis/callgraph.py +61 -0
  4. kirin/analysis/cfg.py +112 -0
  5. kirin/analysis/const/__init__.py +20 -0
  6. kirin/analysis/const/_visitor.py +2 -0
  7. kirin/analysis/const/_visitor.pyi +8 -0
  8. kirin/analysis/const/lattice.py +219 -0
  9. kirin/analysis/const/prop.py +116 -0
  10. kirin/analysis/forward.py +100 -0
  11. kirin/analysis/typeinfer/__init__.py +5 -0
  12. kirin/analysis/typeinfer/analysis.py +90 -0
  13. kirin/analysis/typeinfer/solve.py +141 -0
  14. kirin/decl/__init__.py +108 -0
  15. kirin/decl/base.py +65 -0
  16. kirin/decl/camel2snake.py +2 -0
  17. kirin/decl/emit/__init__.py +0 -0
  18. kirin/decl/emit/_create_fn.py +29 -0
  19. kirin/decl/emit/_set_new_attribute.py +22 -0
  20. kirin/decl/emit/dialect.py +8 -0
  21. kirin/decl/emit/init.py +277 -0
  22. kirin/decl/emit/name.py +10 -0
  23. kirin/decl/emit/property.py +182 -0
  24. kirin/decl/emit/repr.py +31 -0
  25. kirin/decl/emit/traits.py +13 -0
  26. kirin/decl/emit/typecheck.py +77 -0
  27. kirin/decl/emit/verify.py +51 -0
  28. kirin/decl/info.py +346 -0
  29. kirin/decl/scan_fields.py +157 -0
  30. kirin/decl/verify.py +69 -0
  31. kirin/dialects/__init__.py +14 -0
  32. kirin/dialects/_pprint_helper.py +53 -0
  33. kirin/dialects/cf/__init__.py +20 -0
  34. kirin/dialects/cf/constprop.py +51 -0
  35. kirin/dialects/cf/dialect.py +3 -0
  36. kirin/dialects/cf/emit.py +58 -0
  37. kirin/dialects/cf/interp.py +24 -0
  38. kirin/dialects/cf/stmts.py +68 -0
  39. kirin/dialects/cf/typeinfer.py +27 -0
  40. kirin/dialects/eltype.py +23 -0
  41. kirin/dialects/func/__init__.py +20 -0
  42. kirin/dialects/func/attrs.py +39 -0
  43. kirin/dialects/func/constprop.py +138 -0
  44. kirin/dialects/func/dialect.py +3 -0
  45. kirin/dialects/func/emit.py +80 -0
  46. kirin/dialects/func/interp.py +68 -0
  47. kirin/dialects/func/stmts.py +233 -0
  48. kirin/dialects/func/typeinfer.py +124 -0
  49. kirin/dialects/ilist/__init__.py +33 -0
  50. kirin/dialects/ilist/_dialect.py +3 -0
  51. kirin/dialects/ilist/_wrapper.py +51 -0
  52. kirin/dialects/ilist/interp.py +85 -0
  53. kirin/dialects/ilist/lowering.py +25 -0
  54. kirin/dialects/ilist/passes.py +32 -0
  55. kirin/dialects/ilist/rewrite/__init__.py +3 -0
  56. kirin/dialects/ilist/rewrite/const.py +45 -0
  57. kirin/dialects/ilist/rewrite/list.py +38 -0
  58. kirin/dialects/ilist/rewrite/unroll.py +131 -0
  59. kirin/dialects/ilist/runtime.py +63 -0
  60. kirin/dialects/ilist/stmts.py +102 -0
  61. kirin/dialects/ilist/typeinfer.py +120 -0
  62. kirin/dialects/lowering/__init__.py +7 -0
  63. kirin/dialects/lowering/call.py +48 -0
  64. kirin/dialects/lowering/cf.py +206 -0
  65. kirin/dialects/lowering/func.py +134 -0
  66. kirin/dialects/math/__init__.py +41 -0
  67. kirin/dialects/math/_gen.py +176 -0
  68. kirin/dialects/math/dialect.py +3 -0
  69. kirin/dialects/math/interp.py +190 -0
  70. kirin/dialects/math/stmts.py +369 -0
  71. kirin/dialects/module.py +139 -0
  72. kirin/dialects/py/__init__.py +40 -0
  73. kirin/dialects/py/assertion.py +91 -0
  74. kirin/dialects/py/assign.py +103 -0
  75. kirin/dialects/py/attr.py +59 -0
  76. kirin/dialects/py/base.py +34 -0
  77. kirin/dialects/py/binop/__init__.py +23 -0
  78. kirin/dialects/py/binop/_dialect.py +3 -0
  79. kirin/dialects/py/binop/interp.py +60 -0
  80. kirin/dialects/py/binop/julia.py +33 -0
  81. kirin/dialects/py/binop/lowering.py +22 -0
  82. kirin/dialects/py/binop/stmts.py +79 -0
  83. kirin/dialects/py/binop/typeinfer.py +108 -0
  84. kirin/dialects/py/boolop.py +84 -0
  85. kirin/dialects/py/builtin.py +78 -0
  86. kirin/dialects/py/cmp/__init__.py +16 -0
  87. kirin/dialects/py/cmp/_dialect.py +3 -0
  88. kirin/dialects/py/cmp/interp.py +48 -0
  89. kirin/dialects/py/cmp/julia.py +33 -0
  90. kirin/dialects/py/cmp/lowering.py +45 -0
  91. kirin/dialects/py/cmp/stmts.py +62 -0
  92. kirin/dialects/py/constant.py +79 -0
  93. kirin/dialects/py/indexing.py +251 -0
  94. kirin/dialects/py/iterable.py +90 -0
  95. kirin/dialects/py/len.py +57 -0
  96. kirin/dialects/py/list/__init__.py +15 -0
  97. kirin/dialects/py/list/_dialect.py +3 -0
  98. kirin/dialects/py/list/interp.py +21 -0
  99. kirin/dialects/py/list/lowering.py +25 -0
  100. kirin/dialects/py/list/stmts.py +22 -0
  101. kirin/dialects/py/list/typeinfer.py +54 -0
  102. kirin/dialects/py/range.py +76 -0
  103. kirin/dialects/py/slice.py +120 -0
  104. kirin/dialects/py/tuple.py +109 -0
  105. kirin/dialects/py/unary/__init__.py +24 -0
  106. kirin/dialects/py/unary/_dialect.py +3 -0
  107. kirin/dialects/py/unary/constprop.py +20 -0
  108. kirin/dialects/py/unary/interp.py +24 -0
  109. kirin/dialects/py/unary/julia.py +21 -0
  110. kirin/dialects/py/unary/lowering.py +22 -0
  111. kirin/dialects/py/unary/stmts.py +33 -0
  112. kirin/dialects/py/unary/typeinfer.py +23 -0
  113. kirin/dialects/py/unpack.py +90 -0
  114. kirin/dialects/scf/__init__.py +23 -0
  115. kirin/dialects/scf/_dialect.py +3 -0
  116. kirin/dialects/scf/absint.py +64 -0
  117. kirin/dialects/scf/constprop.py +140 -0
  118. kirin/dialects/scf/interp.py +35 -0
  119. kirin/dialects/scf/lowering.py +123 -0
  120. kirin/dialects/scf/stmts.py +250 -0
  121. kirin/dialects/scf/trim.py +36 -0
  122. kirin/dialects/scf/typeinfer.py +58 -0
  123. kirin/dialects/scf/unroll.py +92 -0
  124. kirin/emit/__init__.py +3 -0
  125. kirin/emit/abc.py +89 -0
  126. kirin/emit/abc.pyi +38 -0
  127. kirin/emit/exceptions.py +5 -0
  128. kirin/emit/julia.py +63 -0
  129. kirin/emit/str.py +51 -0
  130. kirin/exceptions.py +59 -0
  131. kirin/graph.py +34 -0
  132. kirin/idtable.py +57 -0
  133. kirin/interp/__init__.py +39 -0
  134. kirin/interp/abstract.py +253 -0
  135. kirin/interp/base.py +438 -0
  136. kirin/interp/concrete.py +62 -0
  137. kirin/interp/exceptions.py +26 -0
  138. kirin/interp/frame.py +151 -0
  139. kirin/interp/impl.py +197 -0
  140. kirin/interp/result.py +93 -0
  141. kirin/interp/state.py +71 -0
  142. kirin/interp/table.py +40 -0
  143. kirin/interp/value.py +73 -0
  144. kirin/ir/__init__.py +46 -0
  145. kirin/ir/attrs/__init__.py +20 -0
  146. kirin/ir/attrs/_types.py +8 -0
  147. kirin/ir/attrs/_types.pyi +13 -0
  148. kirin/ir/attrs/abc.py +46 -0
  149. kirin/ir/attrs/py.py +45 -0
  150. kirin/ir/attrs/types.py +522 -0
  151. kirin/ir/dialect.py +125 -0
  152. kirin/ir/group.py +249 -0
  153. kirin/ir/method.py +118 -0
  154. kirin/ir/nodes/__init__.py +7 -0
  155. kirin/ir/nodes/base.py +149 -0
  156. kirin/ir/nodes/block.py +458 -0
  157. kirin/ir/nodes/region.py +337 -0
  158. kirin/ir/nodes/stmt.py +713 -0
  159. kirin/ir/nodes/view.py +142 -0
  160. kirin/ir/ssa.py +204 -0
  161. kirin/ir/traits/__init__.py +36 -0
  162. kirin/ir/traits/abc.py +42 -0
  163. kirin/ir/traits/basic.py +78 -0
  164. kirin/ir/traits/callable.py +51 -0
  165. kirin/ir/traits/lowering/__init__.py +2 -0
  166. kirin/ir/traits/lowering/call.py +37 -0
  167. kirin/ir/traits/lowering/context.py +120 -0
  168. kirin/ir/traits/region/__init__.py +2 -0
  169. kirin/ir/traits/region/ssacfg.py +22 -0
  170. kirin/ir/traits/symbol.py +57 -0
  171. kirin/ir/use.py +17 -0
  172. kirin/lattice/__init__.py +13 -0
  173. kirin/lattice/abc.py +128 -0
  174. kirin/lattice/empty.py +25 -0
  175. kirin/lattice/mixin.py +51 -0
  176. kirin/lowering/__init__.py +7 -0
  177. kirin/lowering/binding.py +65 -0
  178. kirin/lowering/core.py +72 -0
  179. kirin/lowering/dialect.py +35 -0
  180. kirin/lowering/dialect.pyi +183 -0
  181. kirin/lowering/frame.py +171 -0
  182. kirin/lowering/result.py +68 -0
  183. kirin/lowering/state.py +441 -0
  184. kirin/lowering/stream.py +53 -0
  185. kirin/passes/__init__.py +3 -0
  186. kirin/passes/abc.py +44 -0
  187. kirin/passes/aggressive/__init__.py +1 -0
  188. kirin/passes/aggressive/fold.py +43 -0
  189. kirin/passes/fold.py +45 -0
  190. kirin/passes/inline.py +25 -0
  191. kirin/passes/typeinfer.py +25 -0
  192. kirin/prelude.py +197 -0
  193. kirin/print/__init__.py +15 -0
  194. kirin/print/printable.py +141 -0
  195. kirin/print/printer.py +415 -0
  196. kirin/py.typed +0 -0
  197. kirin/registry.py +105 -0
  198. kirin/registry.pyi +52 -0
  199. kirin/rewrite/__init__.py +14 -0
  200. kirin/rewrite/abc.py +43 -0
  201. kirin/rewrite/aggressive/__init__.py +1 -0
  202. kirin/rewrite/aggressive/fold.py +43 -0
  203. kirin/rewrite/alias.py +16 -0
  204. kirin/rewrite/apply_type.py +47 -0
  205. kirin/rewrite/call2invoke.py +34 -0
  206. kirin/rewrite/chain.py +39 -0
  207. kirin/rewrite/compactify.py +288 -0
  208. kirin/rewrite/cse.py +48 -0
  209. kirin/rewrite/dce.py +19 -0
  210. kirin/rewrite/fixpoint.py +34 -0
  211. kirin/rewrite/fold.py +57 -0
  212. kirin/rewrite/getfield.py +21 -0
  213. kirin/rewrite/getitem.py +37 -0
  214. kirin/rewrite/inline.py +143 -0
  215. kirin/rewrite/result.py +15 -0
  216. kirin/rewrite/walk.py +83 -0
  217. kirin/rewrite/wrap_const.py +55 -0
  218. kirin/source.py +21 -0
  219. kirin/symbol_table.py +27 -0
  220. kirin/types.py +34 -0
  221. kirin/worklist.py +30 -0
  222. kirin_toolchain-0.13.0.dist-info/METADATA +42 -0
  223. kirin_toolchain-0.13.0.dist-info/RECORD +225 -0
  224. kirin_toolchain-0.13.0.dist-info/WHEEL +4 -0
  225. kirin_toolchain-0.13.0.dist-info/licenses/LICENSE +234 -0
@@ -0,0 +1,103 @@
1
+ """Assignment dialect for Python.
2
+
3
+ This module contains the dialect for the Python assignment statement, including:
4
+
5
+ - Statements: `Alias`, `SetItem`.
6
+ - The lowering pass for the assignments.
7
+ - The concrete implementation of the assignment statements.
8
+
9
+ This dialect maps Python assignment syntax (`ast.Assign`) to the `Alias` and `SetItem` statements.
10
+ """
11
+
12
+ import ast
13
+
14
+ from kirin import ir, types, interp, lowering, exceptions
15
+ from kirin.decl import info, statement
16
+ from kirin.print import Printer
17
+
18
+ dialect = ir.Dialect("py.assign")
19
+
20
+ T = types.TypeVar("T")
21
+
22
+
23
@statement(dialect=dialect)
class Alias(ir.Statement):
    """Bind an existing SSA value to a new Python-level name.

    The result is typed with the same type variable `T` as the aliased
    value.
    """

    name = "alias"
    traits = frozenset({ir.Pure(), ir.FromPythonCall()})
    # The SSA value being aliased.
    value: ir.SSAValue = info.argument(T)
    # The Python identifier the value is bound to.
    target: ir.PyAttr[str] = info.attribute()
    # The aliased value under its new name.
    result: ir.ResultValue = info.result(T)

    def print_impl(self, printer: Printer) -> None:
        """Print the statement as `<name> <target> = <value>`.

        The target is printed with the "symbol" style and the ` = `
        separator with the "keyword" style.
        """
        printer.print_name(self)
        printer.plain_print(" ")

        target_name = self.target.data
        with printer.rich(style="symbol"):
            printer.plain_print(target_name)
        with printer.rich(style="keyword"):
            printer.plain_print(" = ")

        printer.print(self.value)
42
+
43
@statement(dialect=dialect)
class SetItem(ir.Statement):
    """Item assignment, i.e. `obj[index] = value`.

    The statement declares no result and is not marked `ir.Pure`.
    """

    name = "setitem"
    traits = frozenset({ir.FromPythonCall()})
    # Container being written to.
    obj: ir.SSAValue = info.argument(print=False)
    # Value to store. NOTE: declared before `index`, so the field order is
    # (obj, value, index) rather than (obj, index, value).
    value: ir.SSAValue = info.argument(print=False)
    # Key or position to store at.
    index: ir.SSAValue = info.argument(print=False)
50
+
51
+
52
@dialect.register
class Concrete(interp.MethodTable):
    """Concrete interpreter implementations of `Alias` and `SetItem`."""

    @interp.impl(Alias)
    def alias(self, interp, frame: interp.Frame, stmt: Alias):
        """Return the aliased value unchanged as the single result."""
        aliased = frame.get(stmt.value)
        return (aliased,)

    @interp.impl(SetItem)
    def setindex(self, interp, frame: interp.Frame, stmt: SetItem):
        """Execute `obj[index] = value` and return `(None,)`."""
        # Operands are read in the order Python evaluates a subscript
        # assignment: right-hand side first, then container, then key.
        new_value = frame.get(stmt.value)
        container = frame.get(stmt.obj)
        key = frame.get(stmt.index)
        container[key] = new_value
        return (None,)
63
+
64
+
65
+ @dialect.register
66
+ class Lowering(lowering.FromPythonAST):
67
+
68
+ def lower_Assign(
69
+ self, state: lowering.LoweringState, node: ast.Assign
70
+ ) -> lowering.Result:
71
+ results: lowering.Result = state.visit(node.value)
72
+ assert len(node.targets) == len(
73
+ results
74
+ ), "number of targets and results do not match"
75
+
76
+ current_frame = state.current_frame
77
+ match node:
78
+ case ast.Assign(
79
+ targets=[ast.Name(lhs_name, ast.Store())], value=ast.Name(_, ast.Load())
80
+ ):
81
+ stmt = Alias(
82
+ value=results[0], target=ir.PyAttr(lhs_name)
83
+ ) # NOTE: this is guaranteed to be one result
84
+ stmt.result.name = lhs_name
85
+ current_frame.defs[lhs_name] = state.append_stmt(stmt).result
86
+ case _:
87
+ for target, value in zip(node.targets, results.values):
88
+ match target:
89
+ # NOTE: if the name exists new ssa value will be
90
+ # used in the future to shadow the old one
91
+ case ast.Name(name, ast.Store()):
92
+ value.name = name
93
+ current_frame.defs[name] = value
94
+ case ast.Subscript(obj, slice):
95
+ obj = state.visit(obj).expect_one()
96
+ slice = state.visit(slice).expect_one()
97
+ stmt = SetItem(obj=obj, index=slice, value=value)
98
+ state.append_stmt(stmt)
99
+ case _:
100
+ raise exceptions.DialectLoweringError(
101
+ f"unsupported target {target}"
102
+ )
103
+ return lowering.Result() # python assign does not have value
@@ -0,0 +1,59 @@
1
+ """Attribute access dialect for Python.
2
+
3
+ This module contains the dialect for the Python attribute access statement, including:
4
+
5
+ - The `GetAttr` statement class.
6
+ - The lowering pass for the attribute access statement.
7
+ - The concrete implementation of the attribute access statement.
8
+
9
+ This dialect maps `ast.Attribute` nodes to the `GetAttr` statement.
10
+ """
11
+
12
+ import ast
13
+
14
+ from kirin import ir, interp, lowering, exceptions
15
+ from kirin.decl import info, statement
16
+
17
+ dialect = ir.Dialect("py.attr")
18
+
19
+
20
@statement(dialect=dialect)
class GetAttr(ir.Statement):
    """Attribute read, i.e. `obj.attrname`."""

    name = "getattr"
    traits = frozenset({ir.FromPythonCall()})
    # Object whose attribute is read.
    obj: ir.SSAValue = info.argument(print=False)
    # Name of the attribute to read.
    attrname: str = info.attribute()
    # Value of the attribute; no result type is declared.
    result: ir.ResultValue = info.result()
27
+
28
+
29
@dialect.register
class Concrete(interp.MethodTable):
    """Concrete interpreter implementation of `GetAttr`."""

    @interp.impl(GetAttr)
    def getattr(self, interp: interp.Interpreter, frame: interp.Frame, stmt: GetAttr):
        """Read `stmt.attrname` from the runtime object with builtin `getattr`."""
        owner = frame.get(stmt.obj)
        # Inside the method body `getattr` resolves to the builtin, not to
        # this method (class scope is not visible from function bodies).
        return (getattr(owner, stmt.attrname),)
35
+
36
+
37
@dialect.register
class Lowering(lowering.FromPythonAST):
    """Lowering rules for attribute access (`ast.Attribute`)."""

    def lower_Attribute(
        self, state: lowering.LoweringState, node: ast.Attribute
    ) -> lowering.Result:
        """Lower an attribute read.

        If the whole attribute expression resolves to a global value, it is
        lowered to a `Constant`; otherwise the owner expression is lowered
        and a `GetAttr` statement is emitted.

        Raises:
            exceptions.DialectLoweringError: if the attribute is not used in
                a load context.
        """
        from kirin.dialects.py import Constant

        if not isinstance(node.ctx, ast.Load):
            raise exceptions.DialectLoweringError(
                f"unsupported attribute context {node.ctx}"
            )

        # NOTE: eagerly load global variables
        global_value = state.get_global_nothrow(node)
        if global_value is None:
            owner = state.visit(node.value).expect_one()
            getattr_stmt = GetAttr(obj=owner, attrname=node.attr)
            state.append_stmt(getattr_stmt)
            return lowering.Result(getattr_stmt)

        constant_stmt = state.append_stmt(Constant(global_value.unwrap()))
        return lowering.Result(constant_stmt)
@@ -0,0 +1,34 @@
1
+ """Base dialect for Python.
2
+
3
+ This dialect does not contain statements. It only contains
4
+ lowering rules for `ast.Name` and `ast.Expr`.
5
+ """
6
+
7
+ import ast
8
+
9
+ from kirin import ir, lowering, exceptions
10
+
11
+ dialect = ir.Dialect("py.base")
12
+
13
+
14
@dialect.register
class PythonLowering(lowering.FromPythonAST):
    """Lowering rules for `ast.Name` and `ast.Expr`."""

    def lower_Name(
        self, state: lowering.LoweringState, node: ast.Name
    ) -> lowering.Result:
        """Resolve a name read against the current frame.

        Raises:
            exceptions.DialectLoweringError: if the name is used in a store
                or del context, or if it is not defined in the current frame.
        """
        name = node.id
        ctx = node.ctx

        # Only loads are handled here; reject the other contexts up front.
        if isinstance(ctx, ast.Store):
            raise exceptions.DialectLoweringError("unhandled store operation")
        if not isinstance(ctx, ast.Load):  # Del
            raise exceptions.DialectLoweringError("unhandled del operation")

        value = state.current_frame.get(name)
        if value is None:
            raise exceptions.DialectLoweringError(f"{name} is not defined")
        return lowering.Result(value)

    def lower_Expr(
        self, state: lowering.LoweringState, node: ast.Expr
    ) -> lowering.Result:
        """Lower an expression statement by lowering the expression it wraps."""
        inner = node.value
        return state.visit(inner)
@@ -0,0 +1,23 @@
1
+ """The binop dialect for Python.
2
+
3
+ This module contains the dialect for binary operation semantics in Python, including:
4
+
5
+ - The `Add`, `Sub`, `Mult`, `Div`, `FloorDiv`, `Mod`, `Pow`,
6
+ `LShift`, `RShift`, `BitOr`, `BitXor`, `BitAnd`, and `MatMult` statement classes.
7
+ - The lowering pass for binary operations.
8
+ - The concrete implementation of binary operations.
9
+ - The type inference implementation of binary operations.
10
+ - The Julia emitter for binary operations.
11
+
12
+ This dialect maps `ast.BinOp` nodes to the `Add`, `Sub`, `Mult`, `Div`, `FloorDiv`,
13
+ `Mod`, `Pow`, `LShift`, `RShift`, `BitOr`, `BitXor`, `BitAnd`, and `MatMult` statements.
14
+ """
15
+
16
+ from . import (
17
+ julia as julia,
18
+ interp as interp,
19
+ lowering as lowering,
20
+ typeinfer as typeinfer,
21
+ )
22
+ from .stmts import * # noqa: F403
23
+ from ._dialect import dialect as dialect
@@ -0,0 +1,3 @@
1
+ from kirin import ir
2
+
3
+ dialect = ir.Dialect("py.binop")
@@ -0,0 +1,60 @@
1
+ from kirin import interp
2
+
3
+ from . import stmts
4
+ from ._dialect import dialect
5
+
6
+
7
@dialect.register
class PyMethodTable(interp.MethodTable):
    """Concrete interpreter implementations of the binary operators.

    Each method reads both operands from the frame (left first, then right),
    applies the matching Python operator and returns the result as a 1-tuple.
    """

    @interp.impl(stmts.Add)
    def add(self, interp, frame: interp.Frame, stmt: stmts.Add):
        """Evaluate `lhs + rhs`."""
        lhs, rhs = frame.get(stmt.lhs), frame.get(stmt.rhs)
        return (lhs + rhs,)

    @interp.impl(stmts.Sub)
    def sub(self, interp, frame: interp.Frame, stmt: stmts.Sub):
        """Evaluate `lhs - rhs`."""
        lhs, rhs = frame.get(stmt.lhs), frame.get(stmt.rhs)
        return (lhs - rhs,)

    @interp.impl(stmts.Mult)
    def mult(self, interp, frame: interp.Frame, stmt: stmts.Mult):
        """Evaluate `lhs * rhs`."""
        lhs, rhs = frame.get(stmt.lhs), frame.get(stmt.rhs)
        return (lhs * rhs,)

    @interp.impl(stmts.Div)
    def div(self, interp, frame: interp.Frame, stmt: stmts.Div):
        """Evaluate `lhs / rhs`."""
        lhs, rhs = frame.get(stmt.lhs), frame.get(stmt.rhs)
        return (lhs / rhs,)

    @interp.impl(stmts.Mod)
    def mod(self, interp, frame: interp.Frame, stmt: stmts.Mod):
        """Evaluate `lhs % rhs`."""
        lhs, rhs = frame.get(stmt.lhs), frame.get(stmt.rhs)
        return (lhs % rhs,)

    @interp.impl(stmts.BitAnd)
    def bit_and(self, interp, frame: interp.Frame, stmt: stmts.BitAnd):
        """Evaluate `lhs & rhs`."""
        lhs, rhs = frame.get(stmt.lhs), frame.get(stmt.rhs)
        return (lhs & rhs,)

    @interp.impl(stmts.BitOr)
    def bit_or(self, interp, frame: interp.Frame, stmt: stmts.BitOr):
        """Evaluate `lhs | rhs`."""
        lhs, rhs = frame.get(stmt.lhs), frame.get(stmt.rhs)
        return (lhs | rhs,)

    @interp.impl(stmts.BitXor)
    def bit_xor(self, interp, frame: interp.Frame, stmt: stmts.BitXor):
        """Evaluate `lhs ^ rhs`."""
        lhs, rhs = frame.get(stmt.lhs), frame.get(stmt.rhs)
        return (lhs ^ rhs,)

    @interp.impl(stmts.LShift)
    def lshift(self, interp, frame: interp.Frame, stmt: stmts.LShift):
        """Evaluate `lhs << rhs`."""
        lhs, rhs = frame.get(stmt.lhs), frame.get(stmt.rhs)
        return (lhs << rhs,)

    @interp.impl(stmts.RShift)
    def rshift(self, interp, frame: interp.Frame, stmt: stmts.RShift):
        """Evaluate `lhs >> rhs`."""
        lhs, rhs = frame.get(stmt.lhs), frame.get(stmt.rhs)
        return (lhs >> rhs,)

    @interp.impl(stmts.FloorDiv)
    def floor_div(self, interp, frame: interp.Frame, stmt: stmts.FloorDiv):
        """Evaluate `lhs // rhs`."""
        lhs, rhs = frame.get(stmt.lhs), frame.get(stmt.rhs)
        return (lhs // rhs,)

    @interp.impl(stmts.Pow)
    def pow(self, interp, frame: interp.Frame, stmt: stmts.Pow):
        """Evaluate `lhs ** rhs`."""
        lhs, rhs = frame.get(stmt.lhs), frame.get(stmt.rhs)
        return (lhs ** rhs,)

    @interp.impl(stmts.MatMult)
    def mat_mult(self, interp, frame: interp.Frame, stmt: stmts.MatMult):
        """Evaluate `lhs @ rhs`."""
        lhs, rhs = frame.get(stmt.lhs), frame.get(stmt.rhs)
        return (lhs @ rhs,)
@@ -0,0 +1,33 @@
1
+ from kirin import interp
2
+ from kirin.emit.julia import EmitJulia, EmitStrFrame
3
+
4
+ from . import stmts
5
+ from ._dialect import dialect
6
+
7
+
8
@dialect.register(key="emit.julia")
class JuliaTable(interp.MethodTable):
    """Julia emitter rules for the binary operators.

    Only `Add`, `Sub`, `Mult`, `Div`, `Mod` and `Pow` have rules here. Each
    rule delegates to `emit.emit_binaryop` with the operator's Julia
    spelling.
    """

    @interp.impl(stmts.Add)
    def emit_Add(self, emit: EmitJulia, frame: EmitStrFrame, stmt: stmts.Add):
        """Emit `lhs + rhs`."""
        return emit.emit_binaryop(
            frame, "+", stmt.lhs, stmt.rhs, stmt.result
        )

    @interp.impl(stmts.Sub)
    def emit_Sub(self, emit: EmitJulia, frame: EmitStrFrame, stmt: stmts.Sub):
        """Emit `lhs - rhs`."""
        return emit.emit_binaryop(
            frame, "-", stmt.lhs, stmt.rhs, stmt.result
        )

    @interp.impl(stmts.Mult)
    def emit_Mult(self, emit: EmitJulia, frame: EmitStrFrame, stmt: stmts.Mult):
        """Emit `lhs * rhs`."""
        return emit.emit_binaryop(
            frame, "*", stmt.lhs, stmt.rhs, stmt.result
        )

    @interp.impl(stmts.Div)
    def emit_Div(self, emit: EmitJulia, frame: EmitStrFrame, stmt: stmts.Div):
        """Emit `lhs / rhs`."""
        return emit.emit_binaryop(
            frame, "/", stmt.lhs, stmt.rhs, stmt.result
        )

    @interp.impl(stmts.Mod)
    def emit_Mod(self, emit: EmitJulia, frame: EmitStrFrame, stmt: stmts.Mod):
        """Emit `lhs % rhs`."""
        return emit.emit_binaryop(
            frame, "%", stmt.lhs, stmt.rhs, stmt.result
        )

    @interp.impl(stmts.Pow)
    def emit_Pow(self, emit: EmitJulia, frame: EmitStrFrame, stmt: stmts.Pow):
        """Emit `lhs ^ rhs` (`^` is the operator string used for `Pow`)."""
        return emit.emit_binaryop(
            frame, "^", stmt.lhs, stmt.rhs, stmt.result
        )
@@ -0,0 +1,22 @@
1
+ import ast
2
+
3
+ from kirin import lowering, exceptions
4
+
5
+ from . import stmts
6
+ from ._dialect import dialect
7
+
8
+
9
@dialect.register
class Lowering(lowering.FromPythonAST):
    """Lowering rules for binary operations (`ast.BinOp`)."""

    def lower_BinOp(
        self, state: lowering.LoweringState, node: ast.BinOp
    ) -> lowering.Result:
        """Lower a binary operation to the statement named after its operator.

        The statement class is looked up in `stmts` by the class name of
        `node.op` (e.g. `ast.Add` -> `stmts.Add`). Both operands are lowered
        before the lookup happens.

        Raises:
            exceptions.DialectLoweringError: if `stmts` has no truthy
                attribute with the operator's class name.
        """
        lhs = state.visit(node.left).expect_one()
        rhs = state.visit(node.right).expect_one()

        stmt_type = getattr(stmts, node.op.__class__.__name__, None)
        if not stmt_type:
            raise exceptions.DialectLoweringError(f"unsupported binop {node.op}")

        binop_stmt = stmt_type(lhs=lhs, rhs=rhs)
        return lowering.Result(state.append_stmt(binop_stmt))
@@ -0,0 +1,79 @@
1
+ from kirin import ir, types
2
+ from kirin.decl import info, statement
3
+
4
+ from ._dialect import dialect
5
+
6
+ T = types.TypeVar("T")
7
+
8
+
9
@statement
class BinOp(ir.Statement):
    """Common base for the binary operator statements.

    It is decorated without a dialect; the concrete operator subclasses
    pass `dialect=dialect` themselves. Both operands and the result are
    declared with the same type variable `T`.
    """

    traits = frozenset({ir.Pure(), ir.FromPythonCall()})
    # Left operand.
    lhs: ir.SSAValue = info.argument(T, print=False)
    # Right operand.
    rhs: ir.SSAValue = info.argument(T, print=False)
    # Result of applying the operator.
    result: ir.ResultValue = info.result(T)
15
+
16
+
17
# Concrete binary operator statements. Each one only sets its `name`; the
# operands, result and traits are inherited from `BinOp`. The class names
# match the Python `ast` operator class names.


@statement(dialect=dialect)
class Add(BinOp):
    """Statement for the binary `+` operator."""

    name = "add"


@statement(dialect=dialect)
class Sub(BinOp):
    """Statement for the binary `-` operator."""

    name = "sub"


@statement(dialect=dialect)
class Mult(BinOp):
    """Statement for the binary `*` operator."""

    name = "mult"


@statement(dialect=dialect)
class Div(BinOp):
    """Statement for the binary `/` operator."""

    name = "div"


@statement(dialect=dialect)
class Mod(BinOp):
    """Statement for the binary `%` operator."""

    name = "mod"


@statement(dialect=dialect)
class Pow(BinOp):
    """Statement for the binary `**` operator."""

    name = "pow"


@statement(dialect=dialect)
class LShift(BinOp):
    """Statement for the binary `<<` operator."""

    name = "lshift"


@statement(dialect=dialect)
class RShift(BinOp):
    """Statement for the binary `>>` operator."""

    name = "rshift"


@statement(dialect=dialect)
class BitAnd(BinOp):
    """Statement for the binary `&` operator."""

    name = "bitand"


@statement(dialect=dialect)
class BitOr(BinOp):
    """Statement for the binary `|` operator."""

    name = "bitor"


@statement(dialect=dialect)
class BitXor(BinOp):
    """Statement for the binary `^` operator."""

    name = "bitxor"


@statement(dialect=dialect)
class FloorDiv(BinOp):
    """Statement for the binary `//` operator."""

    name = "floordiv"


@statement(dialect=dialect)
class MatMult(BinOp):
    """Statement for the binary `@` operator."""

    name = "matmult"
@@ -0,0 +1,108 @@
1
+ from kirin import types, interp
2
+
3
+ from . import stmts
4
+ from ._dialect import dialect
5
+
6
+
7
@dialect.register(key="typeinfer")
class TypeInfer(interp.MethodTable):
    """Type inference rules for the binary operators.

    Each rule is registered for a statement class plus the types of its
    operands and returns the inferred result type as a 1-tuple. Rules
    registered with only the statement class (`Div`, `MatMult`) carry no
    operand types.
    """

    @interp.impl(stmts.Add, types.Float, types.Float)
    @interp.impl(stmts.Add, types.Float, types.Int)
    @interp.impl(stmts.Add, types.Int, types.Float)
    def addf(self, interp, frame, stmt):
        """`+` with at least one Float operand yields Float."""
        return (types.Float,)

    @interp.impl(stmts.Add, types.Int, types.Int)
    def addi(self, interp, frame, stmt):
        """Int `+` Int yields Int."""
        return (types.Int,)

    @interp.impl(stmts.Sub, types.Float, types.Float)
    @interp.impl(stmts.Sub, types.Float, types.Int)
    @interp.impl(stmts.Sub, types.Int, types.Float)
    def subf(self, *_):
        """`-` with at least one Float operand yields Float."""
        return (types.Float,)

    @interp.impl(stmts.Sub, types.Int, types.Int)
    def subi(self, *_):
        """Int `-` Int yields Int."""
        return (types.Int,)

    @interp.impl(stmts.Mult, types.Float, types.Float)
    @interp.impl(stmts.Mult, types.Float, types.Int)
    @interp.impl(stmts.Mult, types.Int, types.Float)
    def multf(self, *_):
        """`*` with at least one Float operand yields Float."""
        return (types.Float,)

    @interp.impl(stmts.Mult, types.Int, types.Int)
    def multi(self, *_):
        """Int `*` Int yields Int."""
        return (types.Int,)

    @interp.impl(stmts.Div)
    def divf(self, typeinfer_, frame, stmt):
        """`/` yields Float; the rule is registered without operand types."""
        return (types.Float,)

    @interp.impl(stmts.Mod, types.Float, types.Float)
    @interp.impl(stmts.Mod, types.Float, types.Int)
    @interp.impl(stmts.Mod, types.Int, types.Float)
    def modf(self, *_):
        """`%` with at least one Float operand yields Float."""
        return (types.Float,)

    @interp.impl(stmts.Mod, types.Int, types.Int)
    def modi(self, *_):
        """Int `%` Int yields Int."""
        return (types.Int,)

    @interp.impl(stmts.BitAnd, types.Int, types.Int)
    def bit_andi(self, interp, frame, stmt):
        """Int `&` Int yields Int."""
        return (types.Int,)

    @interp.impl(stmts.BitAnd, types.Bool, types.Bool)
    def bit_andb(self, interp, frame, stmt):
        """Bool `&` Bool yields Bool."""
        return (types.Bool,)

    @interp.impl(stmts.BitOr, types.Int, types.Int)
    def bit_ori(self, interp, frame, stmt):
        """Int `|` Int yields Int."""
        return (types.Int,)

    @interp.impl(stmts.BitOr, types.Bool, types.Bool)
    def bit_orb(self, interp, frame, stmt):
        """Bool `|` Bool yields Bool."""
        return (types.Bool,)

    @interp.impl(stmts.BitXor, types.Int, types.Int)
    def bit_xori(self, interp, frame, stmt):
        """Int `^` Int yields Int."""
        return (types.Int,)

    @interp.impl(stmts.BitXor, types.Bool, types.Bool)
    def bit_xorb(self, interp, frame, stmt):
        """Bool `^` Bool yields Bool."""
        return (types.Bool,)

    # The shift rules list both operand types, like every other two-operand
    # rule in this table. They were previously registered with a single
    # `types.Int`.
    # NOTE(review): presumably a one-type signature never matches a
    # two-operand statement - confirm against the signature lookup.
    @interp.impl(stmts.LShift, types.Int, types.Int)
    def lshift(self, interp, frame, stmt):
        """Int `<<` Int yields Int."""
        return (types.Int,)

    @interp.impl(stmts.RShift, types.Int, types.Int)
    def rshift(self, interp, frame, stmt):
        """Int `>>` Int yields Int."""
        return (types.Int,)

    @interp.impl(stmts.FloorDiv, types.Float, types.Float)
    @interp.impl(stmts.FloorDiv, types.Int, types.Float)
    @interp.impl(stmts.FloorDiv, types.Float, types.Int)
    def floor_divf(self, interp, frame, stmt):
        """`//` with at least one Float operand yields Float."""
        return (types.Float,)

    @interp.impl(stmts.FloorDiv, types.Int, types.Int)
    def floor_divi(self, interp, frame, stmt):
        """Int `//` Int yields Int."""
        return (types.Int,)

    @interp.impl(stmts.Pow, types.Float, types.Float)
    @interp.impl(stmts.Pow, types.Float, types.Int)
    @interp.impl(stmts.Pow, types.Int, types.Float)
    def powf(self, interp, frame, stmt):
        """`**` with at least one Float operand yields Float."""
        return (types.Float,)

    @interp.impl(stmts.Pow, types.Int, types.Int)
    def powi(self, interp, frame, stmt):
        """Int `**` Int is inferred as Int."""
        return (types.Int,)

    @interp.impl(stmts.MatMult)
    def mat_mult(self, interp, frame, stmt):
        """Type inference for `@` is not implemented.

        Raises:
            NotImplementedError: always.
        """
        raise NotImplementedError("np.array @ np.array not implemented")
@@ -0,0 +1,84 @@
1
+ """Boolean operators for Python dialect.
2
+
3
+ This module contains the dialect for the Python boolean operators, including:
4
+
5
+ - The `And` and `Or` statement classes.
6
+ - The lowering pass for the boolean operators.
7
+ - The concrete implementation of the boolean operators.
8
+ - The Julia emitter for the boolean operators.
9
+
10
+ This dialect maps `ast.BoolOp` nodes to the `And` and `Or` statements.
11
+ """
12
+
13
+ import ast
14
+
15
+ from kirin import ir, types, interp, lowering
16
+ from kirin.decl import info, statement
17
+ from kirin.emit.julia import EmitJulia, EmitStrFrame
18
+ from kirin.exceptions import DialectLoweringError
19
+
20
+ dialect = ir.Dialect("py.boolop")
21
+
22
+
23
@statement
class BoolOp(ir.Statement):
    """Common base for the two-operand boolean operator statements.

    It is decorated without a dialect; `And` and `Or` pass
    `dialect=dialect` themselves. The operands carry no declared type; the
    result is declared as `types.Bool`.
    """

    traits = frozenset({ir.Pure(), ir.FromPythonCall()})
    # Left operand.
    lhs: ir.SSAValue = info.argument(print=False)
    # Right operand.
    rhs: ir.SSAValue = info.argument(print=False)
    # Result of the boolean operation.
    result: ir.ResultValue = info.result(types.Bool)
29
+
30
+
31
@statement(dialect=dialect)
class And(BoolOp):
    """Statement for the boolean `and` of two operands."""

    name = "and"
34
+
35
+
36
@statement(dialect=dialect)
class Or(BoolOp):
    """Statement for the boolean `or` of two operands."""

    name = "or"
39
+
40
+
41
@dialect.register
class PythonLowering(lowering.FromPythonAST):
    """Lowering rules for boolean operators (`ast.BoolOp`)."""

    def lower_BoolOp(
        self, state: lowering.LoweringState, node: ast.BoolOp
    ) -> lowering.Result:
        """Lower `a and b and ...` / `a or b or ...` to chained statements.

        The operands are folded left to right: the first operand seeds the
        accumulator, and each further operand is combined with it through
        one `And`/`Or` statement. Every operand is lowered unconditionally.

        Raises:
            DialectLoweringError: if the operator is neither `ast.And` nor
                `ast.Or`.
        """
        # The first operand is lowered before the operator is inspected.
        acc = state.visit(node.values[0]).expect_one()

        if isinstance(node.op, ast.And):
            stmt_type = And
        elif isinstance(node.op, ast.Or):
            stmt_type = Or
        else:
            raise DialectLoweringError(f"unsupported boolop {node.op}")

        for operand in node.values[1:]:
            rhs = state.visit(operand).expect_one()
            acc = state.append_stmt(stmt_type(lhs=acc, rhs=rhs)).result
        return lowering.Result(acc)
61
+
62
+
63
@dialect.register
class BoolOpMethod(interp.MethodTable):
    """Concrete interpreter implementations of `And` and `Or`.

    Python's `and`/`or` are applied directly, so the result is one of the
    two operand values rather than a coerced `bool`.
    """

    @interp.impl(And)
    def and_(self, interp, frame: interp.Frame, stmt: And):
        """Evaluate `lhs and rhs`; `rhs` is only read if `lhs` is truthy."""
        left = frame.get(stmt.lhs)
        return (left and frame.get(stmt.rhs),)

    @interp.impl(Or)
    def or_(self, interp, frame: interp.Frame, stmt: Or):
        """Evaluate `lhs or rhs`; `rhs` is only read if `lhs` is falsy."""
        left = frame.get(stmt.lhs)
        return (left or frame.get(stmt.rhs),)
73
+
74
+
75
@dialect.register(key="emit.julia")
class JuliaTable(interp.MethodTable):
    """Julia emitter rules for `And` and `Or`.

    Both rules delegate to `emit.emit_binaryop` with the operator's Julia
    spelling.
    """

    @interp.impl(And)
    def emit_And(self, emit: EmitJulia, frame: EmitStrFrame, stmt: And):
        """Emit `lhs && rhs`."""
        return emit.emit_binaryop(
            frame, "&&", stmt.lhs, stmt.rhs, stmt.result
        )

    @interp.impl(Or)
    def emit_Or(self, emit: EmitJulia, frame: EmitStrFrame, stmt: Or):
        """Emit `lhs || rhs`."""
        return emit.emit_binaryop(
            frame, "||", stmt.lhs, stmt.rhs, stmt.result
        )
@@ -0,0 +1,78 @@
1
+ """builtin dialect for python builtins
2
+
3
+ This dialect provides implementations for builtin functions like abs and sum.
4
+
5
+ - Statements: `Abs`, `Sum`.
6
+ - The lowering pass for the builtin functions.
7
+ - The concrete implementation of the builtin functions.
8
+ - The type inference implementation of the builtin functions.
9
+
10
+ This dialect maps `ast.Call` nodes of builtin functions to the `Abs` and `Sum` statements.
11
+ """
12
+
13
+ from ast import Call
14
+
15
+ from kirin import ir, types, interp, lowering
16
+ from kirin.decl import info, statement
17
+
18
+ dialect = ir.Dialect("py.builtin")
19
+
20
+ T = types.TypeVar("T", bound=types.Int | types.Float)
21
+
22
+
23
@statement(dialect=dialect)
class Abs(ir.Statement):
    """Statement for the builtin `abs(value)`.

    Operand and result share the type variable `T`, which is bound to
    `Int | Float`.
    """

    name = "abs"
    traits = frozenset({ir.Pure(), ir.FromPythonCall()})
    # The number whose absolute value is taken.
    value: ir.SSAValue = info.argument(T, print=False)
    # The absolute value.
    result: ir.ResultValue = info.result(T)
29
+
30
+
31
@statement(dialect=dialect)
class Sum(ir.Statement):
    """Statement for the builtin `sum(value)`.

    Only the iterable operand is modelled; there is no field for `sum`'s
    optional `start` argument. Operand and result are typed `Any`.
    """

    name = "sum"
    traits = frozenset({ir.Pure(), ir.FromPythonCall()})
    # The iterable being summed.
    value: ir.SSAValue = info.argument(types.Any, print=False)
    # The sum of the iterable's items.
    result: ir.ResultValue = info.result(types.Any)
37
+
38
+
39
def _lower_single_arg(
    state: lowering.LoweringState, node: Call, fn_name: str
) -> ir.SSAValue:
    """Lower the sole positional argument of a builtin call.

    Args:
        state: the lowering state used to visit the argument.
        node: the call node being lowered.
        fn_name: the builtin's name, used in the error message.

    Returns:
        The single value produced by lowering `node.args[0]`.

    Raises:
        DialectLoweringError: if the call does not have exactly one
            positional argument and no keyword arguments.
    """
    # Imported locally so this block does not depend on a module-level
    # import of the exceptions module.
    from kirin.exceptions import DialectLoweringError

    if len(node.args) != 1 or node.keywords:
        raise DialectLoweringError(
            f"{fn_name}() expects exactly one positional argument, got "
            f"{len(node.args)} positional and {len(node.keywords)} keyword"
        )
    return state.visit(node.args[0]).expect_one()


@dialect.register
class Lowering(lowering.FromPythonAST):
    """Lowering rules for calls to the builtins `abs` and `sum`."""

    def lower_Call_abs(
        self, state: lowering.LoweringState, node: Call
    ) -> lowering.Result:
        """Lower `abs(x)` to an `Abs` statement.

        Raises:
            DialectLoweringError: if the call does not have exactly one
                positional argument (previously an empty call raised a bare
                `IndexError`).
        """
        operand = _lower_single_arg(state, node, "abs")
        return lowering.Result(state.append_stmt(Abs(operand)))

    def lower_Call_sum(
        self, state: lowering.LoweringState, node: Call
    ) -> lowering.Result:
        """Lower `sum(xs)` to a `Sum` statement.

        `Sum` has no field for a `start` value, so `sum(xs, start)` is
        rejected instead of silently dropping `start`.

        Raises:
            DialectLoweringError: if the call does not have exactly one
                positional argument.
        """
        operand = _lower_single_arg(state, node, "sum")
        return lowering.Result(state.append_stmt(Sum(operand)))
55
+
56
+
57
@dialect.register
class Concrete(interp.MethodTable):
    """Concrete interpreter implementations of `Abs` and `Sum`."""

    @interp.impl(Abs)
    def abs(self, interp, frame: interp.Frame, stmt: Abs):
        """Apply the builtin `abs` to the operand."""
        operand = frame.get(stmt.value)
        # `abs` here is the builtin; the method name lives in class scope
        # and is not visible from inside the function body.
        return (abs(operand),)

    @interp.impl(Sum)
    def _sum(self, interp, frame: interp.Frame, stmt: Sum):
        """Apply the builtin `sum` to the operand."""
        items = frame.get(stmt.value)
        return (sum(items),)
67
+
68
+
69
@dialect.register(key="typeinfer")
class TypeInfer(interp.MethodTable):
    """Type inference rules for `Abs`.

    There is no rule for `Sum` in this table.
    """

    @interp.impl(Abs, types.Int)
    def absi(self, interp, frame, stmt):
        """`abs` of an Int yields Int."""
        inferred = types.Int
        return (inferred,)

    @interp.impl(Abs, types.Float)
    def absf(self, interp, frame, stmt):
        """`abs` of a Float yields Float."""
        inferred = types.Float
        return (inferred,)