kirin-toolchain 0.13.0__py3-none-any.whl

Sign up to get free protection for your applications and to get access to all the features.
Files changed (225)
  1. kirin/__init__.py +7 -0
  2. kirin/analysis/__init__.py +24 -0
  3. kirin/analysis/callgraph.py +61 -0
  4. kirin/analysis/cfg.py +112 -0
  5. kirin/analysis/const/__init__.py +20 -0
  6. kirin/analysis/const/_visitor.py +2 -0
  7. kirin/analysis/const/_visitor.pyi +8 -0
  8. kirin/analysis/const/lattice.py +219 -0
  9. kirin/analysis/const/prop.py +116 -0
  10. kirin/analysis/forward.py +100 -0
  11. kirin/analysis/typeinfer/__init__.py +5 -0
  12. kirin/analysis/typeinfer/analysis.py +90 -0
  13. kirin/analysis/typeinfer/solve.py +141 -0
  14. kirin/decl/__init__.py +108 -0
  15. kirin/decl/base.py +65 -0
  16. kirin/decl/camel2snake.py +2 -0
  17. kirin/decl/emit/__init__.py +0 -0
  18. kirin/decl/emit/_create_fn.py +29 -0
  19. kirin/decl/emit/_set_new_attribute.py +22 -0
  20. kirin/decl/emit/dialect.py +8 -0
  21. kirin/decl/emit/init.py +277 -0
  22. kirin/decl/emit/name.py +10 -0
  23. kirin/decl/emit/property.py +182 -0
  24. kirin/decl/emit/repr.py +31 -0
  25. kirin/decl/emit/traits.py +13 -0
  26. kirin/decl/emit/typecheck.py +77 -0
  27. kirin/decl/emit/verify.py +51 -0
  28. kirin/decl/info.py +346 -0
  29. kirin/decl/scan_fields.py +157 -0
  30. kirin/decl/verify.py +69 -0
  31. kirin/dialects/__init__.py +14 -0
  32. kirin/dialects/_pprint_helper.py +53 -0
  33. kirin/dialects/cf/__init__.py +20 -0
  34. kirin/dialects/cf/constprop.py +51 -0
  35. kirin/dialects/cf/dialect.py +3 -0
  36. kirin/dialects/cf/emit.py +58 -0
  37. kirin/dialects/cf/interp.py +24 -0
  38. kirin/dialects/cf/stmts.py +68 -0
  39. kirin/dialects/cf/typeinfer.py +27 -0
  40. kirin/dialects/eltype.py +23 -0
  41. kirin/dialects/func/__init__.py +20 -0
  42. kirin/dialects/func/attrs.py +39 -0
  43. kirin/dialects/func/constprop.py +138 -0
  44. kirin/dialects/func/dialect.py +3 -0
  45. kirin/dialects/func/emit.py +80 -0
  46. kirin/dialects/func/interp.py +68 -0
  47. kirin/dialects/func/stmts.py +233 -0
  48. kirin/dialects/func/typeinfer.py +124 -0
  49. kirin/dialects/ilist/__init__.py +33 -0
  50. kirin/dialects/ilist/_dialect.py +3 -0
  51. kirin/dialects/ilist/_wrapper.py +51 -0
  52. kirin/dialects/ilist/interp.py +85 -0
  53. kirin/dialects/ilist/lowering.py +25 -0
  54. kirin/dialects/ilist/passes.py +32 -0
  55. kirin/dialects/ilist/rewrite/__init__.py +3 -0
  56. kirin/dialects/ilist/rewrite/const.py +45 -0
  57. kirin/dialects/ilist/rewrite/list.py +38 -0
  58. kirin/dialects/ilist/rewrite/unroll.py +131 -0
  59. kirin/dialects/ilist/runtime.py +63 -0
  60. kirin/dialects/ilist/stmts.py +102 -0
  61. kirin/dialects/ilist/typeinfer.py +120 -0
  62. kirin/dialects/lowering/__init__.py +7 -0
  63. kirin/dialects/lowering/call.py +48 -0
  64. kirin/dialects/lowering/cf.py +206 -0
  65. kirin/dialects/lowering/func.py +134 -0
  66. kirin/dialects/math/__init__.py +41 -0
  67. kirin/dialects/math/_gen.py +176 -0
  68. kirin/dialects/math/dialect.py +3 -0
  69. kirin/dialects/math/interp.py +190 -0
  70. kirin/dialects/math/stmts.py +369 -0
  71. kirin/dialects/module.py +139 -0
  72. kirin/dialects/py/__init__.py +40 -0
  73. kirin/dialects/py/assertion.py +91 -0
  74. kirin/dialects/py/assign.py +103 -0
  75. kirin/dialects/py/attr.py +59 -0
  76. kirin/dialects/py/base.py +34 -0
  77. kirin/dialects/py/binop/__init__.py +23 -0
  78. kirin/dialects/py/binop/_dialect.py +3 -0
  79. kirin/dialects/py/binop/interp.py +60 -0
  80. kirin/dialects/py/binop/julia.py +33 -0
  81. kirin/dialects/py/binop/lowering.py +22 -0
  82. kirin/dialects/py/binop/stmts.py +79 -0
  83. kirin/dialects/py/binop/typeinfer.py +108 -0
  84. kirin/dialects/py/boolop.py +84 -0
  85. kirin/dialects/py/builtin.py +78 -0
  86. kirin/dialects/py/cmp/__init__.py +16 -0
  87. kirin/dialects/py/cmp/_dialect.py +3 -0
  88. kirin/dialects/py/cmp/interp.py +48 -0
  89. kirin/dialects/py/cmp/julia.py +33 -0
  90. kirin/dialects/py/cmp/lowering.py +45 -0
  91. kirin/dialects/py/cmp/stmts.py +62 -0
  92. kirin/dialects/py/constant.py +79 -0
  93. kirin/dialects/py/indexing.py +251 -0
  94. kirin/dialects/py/iterable.py +90 -0
  95. kirin/dialects/py/len.py +57 -0
  96. kirin/dialects/py/list/__init__.py +15 -0
  97. kirin/dialects/py/list/_dialect.py +3 -0
  98. kirin/dialects/py/list/interp.py +21 -0
  99. kirin/dialects/py/list/lowering.py +25 -0
  100. kirin/dialects/py/list/stmts.py +22 -0
  101. kirin/dialects/py/list/typeinfer.py +54 -0
  102. kirin/dialects/py/range.py +76 -0
  103. kirin/dialects/py/slice.py +120 -0
  104. kirin/dialects/py/tuple.py +109 -0
  105. kirin/dialects/py/unary/__init__.py +24 -0
  106. kirin/dialects/py/unary/_dialect.py +3 -0
  107. kirin/dialects/py/unary/constprop.py +20 -0
  108. kirin/dialects/py/unary/interp.py +24 -0
  109. kirin/dialects/py/unary/julia.py +21 -0
  110. kirin/dialects/py/unary/lowering.py +22 -0
  111. kirin/dialects/py/unary/stmts.py +33 -0
  112. kirin/dialects/py/unary/typeinfer.py +23 -0
  113. kirin/dialects/py/unpack.py +90 -0
  114. kirin/dialects/scf/__init__.py +23 -0
  115. kirin/dialects/scf/_dialect.py +3 -0
  116. kirin/dialects/scf/absint.py +64 -0
  117. kirin/dialects/scf/constprop.py +140 -0
  118. kirin/dialects/scf/interp.py +35 -0
  119. kirin/dialects/scf/lowering.py +123 -0
  120. kirin/dialects/scf/stmts.py +250 -0
  121. kirin/dialects/scf/trim.py +36 -0
  122. kirin/dialects/scf/typeinfer.py +58 -0
  123. kirin/dialects/scf/unroll.py +92 -0
  124. kirin/emit/__init__.py +3 -0
  125. kirin/emit/abc.py +89 -0
  126. kirin/emit/abc.pyi +38 -0
  127. kirin/emit/exceptions.py +5 -0
  128. kirin/emit/julia.py +63 -0
  129. kirin/emit/str.py +51 -0
  130. kirin/exceptions.py +59 -0
  131. kirin/graph.py +34 -0
  132. kirin/idtable.py +57 -0
  133. kirin/interp/__init__.py +39 -0
  134. kirin/interp/abstract.py +253 -0
  135. kirin/interp/base.py +438 -0
  136. kirin/interp/concrete.py +62 -0
  137. kirin/interp/exceptions.py +26 -0
  138. kirin/interp/frame.py +151 -0
  139. kirin/interp/impl.py +197 -0
  140. kirin/interp/result.py +93 -0
  141. kirin/interp/state.py +71 -0
  142. kirin/interp/table.py +40 -0
  143. kirin/interp/value.py +73 -0
  144. kirin/ir/__init__.py +46 -0
  145. kirin/ir/attrs/__init__.py +20 -0
  146. kirin/ir/attrs/_types.py +8 -0
  147. kirin/ir/attrs/_types.pyi +13 -0
  148. kirin/ir/attrs/abc.py +46 -0
  149. kirin/ir/attrs/py.py +45 -0
  150. kirin/ir/attrs/types.py +522 -0
  151. kirin/ir/dialect.py +125 -0
  152. kirin/ir/group.py +249 -0
  153. kirin/ir/method.py +118 -0
  154. kirin/ir/nodes/__init__.py +7 -0
  155. kirin/ir/nodes/base.py +149 -0
  156. kirin/ir/nodes/block.py +458 -0
  157. kirin/ir/nodes/region.py +337 -0
  158. kirin/ir/nodes/stmt.py +713 -0
  159. kirin/ir/nodes/view.py +142 -0
  160. kirin/ir/ssa.py +204 -0
  161. kirin/ir/traits/__init__.py +36 -0
  162. kirin/ir/traits/abc.py +42 -0
  163. kirin/ir/traits/basic.py +78 -0
  164. kirin/ir/traits/callable.py +51 -0
  165. kirin/ir/traits/lowering/__init__.py +2 -0
  166. kirin/ir/traits/lowering/call.py +37 -0
  167. kirin/ir/traits/lowering/context.py +120 -0
  168. kirin/ir/traits/region/__init__.py +2 -0
  169. kirin/ir/traits/region/ssacfg.py +22 -0
  170. kirin/ir/traits/symbol.py +57 -0
  171. kirin/ir/use.py +17 -0
  172. kirin/lattice/__init__.py +13 -0
  173. kirin/lattice/abc.py +128 -0
  174. kirin/lattice/empty.py +25 -0
  175. kirin/lattice/mixin.py +51 -0
  176. kirin/lowering/__init__.py +7 -0
  177. kirin/lowering/binding.py +65 -0
  178. kirin/lowering/core.py +72 -0
  179. kirin/lowering/dialect.py +35 -0
  180. kirin/lowering/dialect.pyi +183 -0
  181. kirin/lowering/frame.py +171 -0
  182. kirin/lowering/result.py +68 -0
  183. kirin/lowering/state.py +441 -0
  184. kirin/lowering/stream.py +53 -0
  185. kirin/passes/__init__.py +3 -0
  186. kirin/passes/abc.py +44 -0
  187. kirin/passes/aggressive/__init__.py +1 -0
  188. kirin/passes/aggressive/fold.py +43 -0
  189. kirin/passes/fold.py +45 -0
  190. kirin/passes/inline.py +25 -0
  191. kirin/passes/typeinfer.py +25 -0
  192. kirin/prelude.py +197 -0
  193. kirin/print/__init__.py +15 -0
  194. kirin/print/printable.py +141 -0
  195. kirin/print/printer.py +415 -0
  196. kirin/py.typed +0 -0
  197. kirin/registry.py +105 -0
  198. kirin/registry.pyi +52 -0
  199. kirin/rewrite/__init__.py +14 -0
  200. kirin/rewrite/abc.py +43 -0
  201. kirin/rewrite/aggressive/__init__.py +1 -0
  202. kirin/rewrite/aggressive/fold.py +43 -0
  203. kirin/rewrite/alias.py +16 -0
  204. kirin/rewrite/apply_type.py +47 -0
  205. kirin/rewrite/call2invoke.py +34 -0
  206. kirin/rewrite/chain.py +39 -0
  207. kirin/rewrite/compactify.py +288 -0
  208. kirin/rewrite/cse.py +48 -0
  209. kirin/rewrite/dce.py +19 -0
  210. kirin/rewrite/fixpoint.py +34 -0
  211. kirin/rewrite/fold.py +57 -0
  212. kirin/rewrite/getfield.py +21 -0
  213. kirin/rewrite/getitem.py +37 -0
  214. kirin/rewrite/inline.py +143 -0
  215. kirin/rewrite/result.py +15 -0
  216. kirin/rewrite/walk.py +83 -0
  217. kirin/rewrite/wrap_const.py +55 -0
  218. kirin/source.py +21 -0
  219. kirin/symbol_table.py +27 -0
  220. kirin/types.py +34 -0
  221. kirin/worklist.py +30 -0
  222. kirin_toolchain-0.13.0.dist-info/METADATA +42 -0
  223. kirin_toolchain-0.13.0.dist-info/RECORD +225 -0
  224. kirin_toolchain-0.13.0.dist-info/WHEEL +4 -0
  225. kirin_toolchain-0.13.0.dist-info/licenses/LICENSE +234 -0
@@ -0,0 +1,134 @@
1
+ import ast
2
+
3
+ from kirin import ir, types, lowering
4
+ from kirin.dialects import cf, func
5
+ from kirin.exceptions import DialectLoweringError
6
+
7
dialect = ir.Dialect("lowering.func")


@dialect.register
class Lowering(lowering.FromPythonAST):
    """Lower Python ``return`` statements and ``def`` blocks into the ``func`` dialect.

    A ``def`` lowered while the current frame has no parent becomes a
    ``func.Function``; a nested ``def`` becomes a ``func.Lambda`` whose
    captured outer values are read back inside the body through
    ``func.GetField`` on the function's own ``self`` argument.
    """

    def lower_Return(
        self, state: lowering.LoweringState, node: ast.Return
    ) -> lowering.Result:
        """Lower ``return`` / ``return <expr>`` into a ``func.Return``.

        A bare ``return`` returns a freshly appended ``func.ConstantNone``.
        Returns an empty ``lowering.Result`` (the statement produces no value
        for the surrounding expression).
        """
        if node.value is None:
            stmt = func.Return(state.append_stmt(func.ConstantNone()).result)
            state.append_stmt(stmt)
        else:
            result = state.visit(node.value).expect_one()
            stmt = func.Return(result)
            state.append_stmt(stmt)
        return lowering.Result()

    def lower_FunctionDef(
        self, state: lowering.LoweringState, node: ast.FunctionDef
    ) -> lowering.Result:
        """Lower a ``def`` statement into ``func.Function`` or ``func.Lambda``.

        The body is lowered in a fresh frame whose entry block takes the
        function value itself (``<name>_self``) followed by one argument per
        positional parameter. Blocks that do not end in a terminator branch to
        the frame's ``next_block``, which returns ``None``.

        Returns:
            ``lowering.Result`` wrapping the emitted ``func.Function`` (top
            level) or ``func.Lambda`` (nested; also bound to ``node.name`` in
            the enclosing frame).

        Raises:
            DialectLoweringError: on unsupported argument kinds, an invalid
                type hint, decorators on a nested function, or an empty body.
        """
        self.assert_simple_arguments(node.args)
        signature = func.Signature(
            inputs=tuple(
                self.get_hint(state, arg.annotation) for arg in node.args.args
            ),
            output=self.get_hint(state, node.returns),
        )
        frame = state.current_frame

        entries: dict[str, ir.SSAValue] = {}
        entr_block = ir.Block()
        # First block argument: the function value itself, typed as
        # Method[Tuple[inputs...], output]. It is bound to the function's own
        # name in the body frame and is the object captures are read from.
        fn_self = entr_block.args.append_from(
            types.Generic(
                ir.Method, types.Tuple.where(signature.inputs), signature.output
            ),
            node.name + "_self",
        )
        entries[node.name] = fn_self
        for arg, arg_type in zip(node.args.args, signature.inputs):
            entries[arg.arg] = entr_block.args.append_from(arg_type, arg.arg)

        def callback(frame: lowering.Frame, value: ir.SSAValue):
            # Capture hook handed to the body frame; presumably invoked when
            # the body references a value from an enclosing frame. It reads the
            # value back via GetField on `fn_self`, inserted at the top of the
            # entry block. The field index assumes `value` was already added to
            # `frame.captures` before this hook runs — TODO confirm.
            first_stmt = entr_block.first_stmt
            stmt = func.GetField(obj=fn_self, field=len(frame.captures) - 1)
            if value.name:
                stmt.result.name = value.name
            stmt.result.type = value.type
            stmt.source = state.source
            if first_stmt:
                stmt.insert_before(first_stmt)
            else:
                entr_block.stmts.append(stmt)
            return stmt.result

        func_frame = state.push_frame(
            lowering.Frame.from_stmts(
                node.body,
                state,
                entr_block=entr_block,
                globals=frame.globals,
                capture_callback=callback,
            )
        )
        func_frame.defs.update(entries)
        state.exhaust()

        # Fall-through: every block lacking a terminator jumps to next_block,
        # which returns None below (Python's implicit `return None`).
        for block in func_frame.curr_region.blocks:
            if not block.last_stmt or not block.last_stmt.has_trait(ir.IsTerminator):
                block.stmts.append(
                    cf.Branch(arguments=(), successor=func_frame.next_block)
                )

        none_stmt = func.ConstantNone()
        rtrn_stmt = func.Return(none_stmt.result)
        func_frame.next_block.stmts.append(none_stmt)
        func_frame.next_block.stmts.append(rtrn_stmt)
        state.pop_frame()

        if state.current_frame.parent is None:  # toplevel function
            stmt = frame.append_stmt(
                func.Function(
                    sym_name=node.name,
                    signature=signature,
                    body=func_frame.curr_region,
                )
            )
            return lowering.Result(stmt)

        if node.decorator_list:
            raise DialectLoweringError(
                "decorators are not supported on nested functions"
            )

        # nested function, lookup unknown variables
        first_stmt = func_frame.curr_region.blocks[0].first_stmt
        if first_stmt is None:
            raise DialectLoweringError("empty function body")

        lambda_stmt = func.Lambda(
            tuple(func_frame.captures.values()),
            sym_name=node.name,
            signature=signature,
            body=func_frame.curr_region,
        )
        lambda_stmt.result.name = node.name
        # NOTE: Python automatically assigns the lambda to the name
        frame.defs[node.name] = frame.append_stmt(lambda_stmt).result
        return lowering.Result(lambda_stmt)

    def assert_simple_arguments(self, node: ast.arguments) -> None:
        """Reject parameter kinds that the lowered signature cannot represent.

        ``lower_FunctionDef`` builds the signature from ``node.args`` only, so
        any other parameter kind would be silently dropped; report it instead.

        Raises:
            DialectLoweringError: for keyword-only, positional-only, ``*args``
                or ``**kwargs`` parameters.
        """
        if node.kwonlyargs:
            raise DialectLoweringError("keyword-only arguments are not supported")

        if node.posonlyargs:
            raise DialectLoweringError("positional-only arguments are not supported")

        # `*args` / `**kwargs` live outside `node.args`; without these checks
        # they vanish from the signature with no diagnostic.
        if node.vararg is not None:
            raise DialectLoweringError("variadic arguments (*args) are not supported")

        if node.kwarg is not None:
            raise DialectLoweringError(
                "variadic keyword arguments (**kwargs) are not supported"
            )

    @staticmethod
    def get_hint(state: lowering.LoweringState, node: ast.expr | None):
        """Resolve an annotation expression to a kirin type.

        A missing annotation (``None``) maps to ``types.Any``. Otherwise the
        expression is evaluated as a global via ``state.get_global`` and
        converted with ``types.hint2type``.

        Raises:
            DialectLoweringError: if the annotation cannot be resolved or
                converted; the underlying error is chained as the cause.
        """
        if node is None:
            return types.Any

        try:
            t = state.get_global(node).unwrap()
            return types.hint2type(t)
        except Exception as e:
            # Narrowed from a bare `except:` so KeyboardInterrupt/SystemExit
            # propagate, and chained so the real failure stays visible.
            raise DialectLoweringError(
                f"expect a type hint, got {ast.unparse(node)}"
            ) from e
@@ -0,0 +1,41 @@
1
+ "math dialect, modeling functions in python's `math` stdlib" # This file is generated by gen.py
2
+ from kirin.dialects.math.stmts import (
3
+ cos as cos,
4
+ erf as erf,
5
+ exp as exp,
6
+ pow as pow,
7
+ sin as sin,
8
+ tan as tan,
9
+ ulp as ulp,
10
+ acos as acos,
11
+ asin as asin,
12
+ atan as atan,
13
+ ceil as ceil,
14
+ cosh as cosh,
15
+ erfc as erfc,
16
+ fabs as fabs,
17
+ fmod as fmod,
18
+ log2 as log2,
19
+ sinh as sinh,
20
+ sqrt as sqrt,
21
+ tanh as tanh,
22
+ asinh as asinh,
23
+ atan2 as atan2,
24
+ atanh as atanh,
25
+ expm1 as expm1,
26
+ floor as floor,
27
+ gamma as gamma,
28
+ isinf as isinf,
29
+ isnan as isnan,
30
+ log1p as log1p,
31
+ log10 as log10,
32
+ trunc as trunc,
33
+ lgamma as lgamma,
34
+ degrees as degrees,
35
+ radians as radians,
36
+ copysign as copysign,
37
+ isfinite as isfinite,
38
+ remainder as remainder,
39
+ )
40
+ from kirin.dialects.math.interp import MathMethodTable as MathMethodTable
41
+ from kirin.dialects.math.dialect import dialect as dialect
@@ -0,0 +1,176 @@
1
+ import os
2
+ import math
3
+ import inspect
4
+ import textwrap
5
+ from pathlib import Path
6
+
7
+ import black
8
+
9
+ # NOTE: typeinfer and lowering should be the default, so we don't generate them.
10
+
11
+
12
# `math` functions deliberately left out of the generated dialect ("special
# cases for now"). `cbrt` and `exp2` only exist on Python >= 3.11, so they are
# skipped to keep the generated files identical on Python 3.10.
_SKIPPED_FUNCTIONS = frozenset(
    {
        "prod",
        "perm",
        "modf",
        "ldexp",
        "lcm",
        "isqrt",
        "isclose",
        "gcd",
        "fsum",
        "frexp",
        "factorial",
        "acosh",
        "comb",
        "dist",
        "sumprod",
        "nextafter",
        # 3.10 compat
        "cbrt",
        "exp2",
    }
)


def builtin_math_functions():
    """Yield ``(name, function, signature)`` for every modeled ``math`` builtin.

    Members are visited in the name order produced by ``inspect.getmembers``.
    Non-builtins (e.g. the constants ``pi``/``inf``), names listed in
    ``_SKIPPED_FUNCTIONS`` and builtins without an introspectable signature
    are omitted.
    """
    for name, obj in inspect.getmembers(math, inspect.isbuiltin):
        if name in _SKIPPED_FUNCTIONS:
            continue

        # Only the introspection is guarded. The old bare `except:` also
        # wrapped the `yield`, so closing the generator early swallowed
        # GeneratorExit and raised "generator ignored GeneratorExit".
        try:
            sig = inspect.signature(obj)
        except (ValueError, TypeError):
            # C builtins such as math.log / math.hypot have no text signature.
            continue
        yield name, obj, sig
44
+
45
+
46
# Running this module regenerates stmts.py, interp.py and __init__.py next to
# it, formats them with black, and rewrites test/dialects/math/test_basic.py.

# Collect once so every generated file agrees on the same functions and order.
MATH_FUNCTIONS = list(builtin_math_functions())
_HERE = os.path.dirname(__file__)

# stmts.py: one `ir.Statement` subclass per math function, with one
# Float-typed SSA argument per parameter and a single Float result.
with open(os.path.join(_HERE, "stmts.py"), "w", encoding="utf-8") as f:
    f.write("# This file is generated by gen.py\n")
    f.write("from kirin import ir, types\n")
    f.write("from kirin.decl import statement, info\n")
    f.write("from kirin.dialects.math.dialect import dialect\n")
    f.write("\n")
    for name, _, sig in MATH_FUNCTIONS:
        # Built line by line so that indentation never depends on how many
        # parameters are substituted (dedent-after-format did).
        fields = "".join(
            f"    {arg}: ir.SSAValue = info.argument(types.Float)\n"
            for arg in sig.parameters
        )
        f.write(
            "\n"
            "@statement(dialect=dialect)\n"
            f"class {name}(ir.Statement):\n"
            f'    """{name} statement, wrapping the math.{name} function\n'
            '    """\n'
            f'    name = "{name}"\n'
            "    traits = frozenset({ir.Pure(), ir.FromPythonCall()})\n"
            f"{fields}"
            "    result: ir.ResultValue = info.result(types.Float)\n"
        )

# interp.py: a MethodTable forwarding each statement's argument values,
# positionally, to the matching `math` function.
with open(os.path.join(_HERE, "interp.py"), "w", encoding="utf-8") as f:
    f.write("# This file is generated by gen.py\n")
    f.write("import math\n")
    f.write("from kirin.dialects.math.dialect import dialect\n")
    f.write("from kirin.dialects.math import stmts\n")
    f.write("from kirin.interp import MethodTable, Frame, impl\n")
    f.write("\n")

    implements = []
    for name, _, sig in MATH_FUNCTIONS:
        fields = ", ".join(f"values[{idx}]" for idx in range(len(sig.parameters)))
        implements.append(
            f"    @impl(stmts.{name})\n"
            f"    def {name}(self, interp, frame: Frame, stmt: stmts.{name}):\n"
            "        values = frame.get_values(stmt.args)\n"
            f"        return (math.{name}({fields}),)\n"
        )

    # Write the interpreter class
    f.write("\n@dialect.register\nclass MathMethodTable(MethodTable):\n")
    f.write("\n".join(implements))

# __init__.py: re-export the dialect, every statement and the method table.
with open(os.path.join(_HERE, "__init__.py"), "w", encoding="utf-8") as f:
    # The docstring needs its own line; without the newline the comment below
    # was appended to the docstring line.
    f.write('"math dialect, modeling functions in python\'s `math` stdlib"\n')
    f.write("# This file is generated by gen.py\n")
    f.write("from kirin.dialects.math.dialect import dialect as dialect\n")
    f.write("from kirin.dialects.math.stmts import (\n")
    for name, _, _ in MATH_FUNCTIONS:
        f.write(f"    {name} as {name},\n")
    f.write(")\n")
    f.write(
        "from kirin.dialects.math.interp import MathMethodTable as MathMethodTable\n"
    )
    f.write("\n")

for file in ["__init__.py", "interp.py", "stmts.py"]:
    # format the file in place + using the project config
    black.format_file_in_place(
        Path(os.path.join(_HERE, file)),
        fast=False,
        mode=black.FileMode(),
    )

# test_basic.py: for each function, compare the lowered/interpreted version
# against Python's `math` at the input 0.42.
project_dir = Path(__file__).parent.parent.parent.parent.parent
with open(
    project_dir.joinpath("test", "dialects", "math", "test_basic.py"),
    "w",
    encoding="utf-8",
) as f:
    f.write("# type: ignore\n")
    f.write("# This file is generated by gen.py\n")
    f.write("import math as pymath\n")
    f.write("from kirin.prelude import basic\n")
    f.write("from kirin.dialects import math\n")
    f.write("\n")
    f.write("\n")

    for name, _, sig in MATH_FUNCTIONS:
        args = ", ".join(sig.parameters)
        inputs = ", ".join("0.42" for _ in sig.parameters)

        # abs() is required: without it any result far *below* the truth
        # (e.g. -1e9) satisfied `< 1e-6`.
        f.write(
            textwrap.dedent(
                f"""
                @basic
                def {name}_func({args}):
                    return math.{name}({args})

                def test_{name}():
                    truth = pymath.{name}({inputs})
                    assert abs({name}_func({inputs}) - truth) < 1e-6
                """
            )
        )
@@ -0,0 +1,3 @@
1
+ from kirin import ir
2
+
3
# The `math` dialect object. The generated statements are declared with
# `@statement(dialect=dialect)` and `MathMethodTable` registers itself via
# `@dialect.register`.
dialect = ir.Dialect("math")
@@ -0,0 +1,190 @@
1
+ # This file is generated by gen.py
2
+ import math
3
+
4
+ from kirin.interp import Frame, MethodTable, impl
5
+ from kirin.dialects.math import stmts
6
+ from kirin.dialects.math.dialect import dialect
7
+
8
+
9
@dialect.register
class MathMethodTable(MethodTable):
    """Concrete interpreter implementations for the `math` dialect.

    Each method evaluates one statement from `stmts`. It reads the statement's
    argument values from the frame in order, passes them positionally to the
    Python `math` function of the same name, and returns the result as a
    1-tuple.

    NOTE: this class is generated by gen.py, which overwrites this file on
    every run. Edit the generator, not this file.
    NOTE(review): the generator declares every statement result as
    `types.Float`, but at runtime `math.ceil`/`floor`/`trunc` return `int` and
    `math.isfinite`/`isinf`/`isnan` return `bool`. Confirm whether downstream
    analyses rely on the declared type.
    """

    @impl(stmts.acos)
    def acos(self, interp, frame: Frame, stmt: stmts.acos):
        values = frame.get_values(stmt.args)
        return (math.acos(values[0]),)

    @impl(stmts.asin)
    def asin(self, interp, frame: Frame, stmt: stmts.asin):
        values = frame.get_values(stmt.args)
        return (math.asin(values[0]),)

    @impl(stmts.asinh)
    def asinh(self, interp, frame: Frame, stmt: stmts.asinh):
        values = frame.get_values(stmt.args)
        return (math.asinh(values[0]),)

    @impl(stmts.atan)
    def atan(self, interp, frame: Frame, stmt: stmts.atan):
        values = frame.get_values(stmt.args)
        return (math.atan(values[0]),)

    @impl(stmts.atan2)
    def atan2(self, interp, frame: Frame, stmt: stmts.atan2):
        values = frame.get_values(stmt.args)
        return (math.atan2(values[0], values[1]),)

    @impl(stmts.atanh)
    def atanh(self, interp, frame: Frame, stmt: stmts.atanh):
        values = frame.get_values(stmt.args)
        return (math.atanh(values[0]),)

    @impl(stmts.ceil)
    def ceil(self, interp, frame: Frame, stmt: stmts.ceil):
        values = frame.get_values(stmt.args)
        return (math.ceil(values[0]),)

    @impl(stmts.copysign)
    def copysign(self, interp, frame: Frame, stmt: stmts.copysign):
        values = frame.get_values(stmt.args)
        return (math.copysign(values[0], values[1]),)

    @impl(stmts.cos)
    def cos(self, interp, frame: Frame, stmt: stmts.cos):
        values = frame.get_values(stmt.args)
        return (math.cos(values[0]),)

    @impl(stmts.cosh)
    def cosh(self, interp, frame: Frame, stmt: stmts.cosh):
        values = frame.get_values(stmt.args)
        return (math.cosh(values[0]),)

    @impl(stmts.degrees)
    def degrees(self, interp, frame: Frame, stmt: stmts.degrees):
        values = frame.get_values(stmt.args)
        return (math.degrees(values[0]),)

    @impl(stmts.erf)
    def erf(self, interp, frame: Frame, stmt: stmts.erf):
        values = frame.get_values(stmt.args)
        return (math.erf(values[0]),)

    @impl(stmts.erfc)
    def erfc(self, interp, frame: Frame, stmt: stmts.erfc):
        values = frame.get_values(stmt.args)
        return (math.erfc(values[0]),)

    @impl(stmts.exp)
    def exp(self, interp, frame: Frame, stmt: stmts.exp):
        values = frame.get_values(stmt.args)
        return (math.exp(values[0]),)

    @impl(stmts.expm1)
    def expm1(self, interp, frame: Frame, stmt: stmts.expm1):
        values = frame.get_values(stmt.args)
        return (math.expm1(values[0]),)

    @impl(stmts.fabs)
    def fabs(self, interp, frame: Frame, stmt: stmts.fabs):
        values = frame.get_values(stmt.args)
        return (math.fabs(values[0]),)

    @impl(stmts.floor)
    def floor(self, interp, frame: Frame, stmt: stmts.floor):
        values = frame.get_values(stmt.args)
        return (math.floor(values[0]),)

    @impl(stmts.fmod)
    def fmod(self, interp, frame: Frame, stmt: stmts.fmod):
        values = frame.get_values(stmt.args)
        return (math.fmod(values[0], values[1]),)

    @impl(stmts.gamma)
    def gamma(self, interp, frame: Frame, stmt: stmts.gamma):
        values = frame.get_values(stmt.args)
        return (math.gamma(values[0]),)

    @impl(stmts.isfinite)
    def isfinite(self, interp, frame: Frame, stmt: stmts.isfinite):
        values = frame.get_values(stmt.args)
        return (math.isfinite(values[0]),)

    @impl(stmts.isinf)
    def isinf(self, interp, frame: Frame, stmt: stmts.isinf):
        values = frame.get_values(stmt.args)
        return (math.isinf(values[0]),)

    @impl(stmts.isnan)
    def isnan(self, interp, frame: Frame, stmt: stmts.isnan):
        values = frame.get_values(stmt.args)
        return (math.isnan(values[0]),)

    @impl(stmts.lgamma)
    def lgamma(self, interp, frame: Frame, stmt: stmts.lgamma):
        values = frame.get_values(stmt.args)
        return (math.lgamma(values[0]),)

    @impl(stmts.log10)
    def log10(self, interp, frame: Frame, stmt: stmts.log10):
        values = frame.get_values(stmt.args)
        return (math.log10(values[0]),)

    @impl(stmts.log1p)
    def log1p(self, interp, frame: Frame, stmt: stmts.log1p):
        values = frame.get_values(stmt.args)
        return (math.log1p(values[0]),)

    @impl(stmts.log2)
    def log2(self, interp, frame: Frame, stmt: stmts.log2):
        values = frame.get_values(stmt.args)
        return (math.log2(values[0]),)

    @impl(stmts.pow)
    def pow(self, interp, frame: Frame, stmt: stmts.pow):
        values = frame.get_values(stmt.args)
        return (math.pow(values[0], values[1]),)

    @impl(stmts.radians)
    def radians(self, interp, frame: Frame, stmt: stmts.radians):
        values = frame.get_values(stmt.args)
        return (math.radians(values[0]),)

    @impl(stmts.remainder)
    def remainder(self, interp, frame: Frame, stmt: stmts.remainder):
        values = frame.get_values(stmt.args)
        return (math.remainder(values[0], values[1]),)

    @impl(stmts.sin)
    def sin(self, interp, frame: Frame, stmt: stmts.sin):
        values = frame.get_values(stmt.args)
        return (math.sin(values[0]),)

    @impl(stmts.sinh)
    def sinh(self, interp, frame: Frame, stmt: stmts.sinh):
        values = frame.get_values(stmt.args)
        return (math.sinh(values[0]),)

    @impl(stmts.sqrt)
    def sqrt(self, interp, frame: Frame, stmt: stmts.sqrt):
        values = frame.get_values(stmt.args)
        return (math.sqrt(values[0]),)

    @impl(stmts.tan)
    def tan(self, interp, frame: Frame, stmt: stmts.tan):
        values = frame.get_values(stmt.args)
        return (math.tan(values[0]),)

    @impl(stmts.tanh)
    def tanh(self, interp, frame: Frame, stmt: stmts.tanh):
        values = frame.get_values(stmt.args)
        return (math.tanh(values[0]),)

    @impl(stmts.trunc)
    def trunc(self, interp, frame: Frame, stmt: stmts.trunc):
        values = frame.get_values(stmt.args)
        return (math.trunc(values[0]),)

    @impl(stmts.ulp)
    def ulp(self, interp, frame: Frame, stmt: stmts.ulp):
        values = frame.get_values(stmt.args)
        return (math.ulp(values[0]),)