kirin-toolchain 0.13.0__py3-none-any.whl
Sign up to get free protection for your applications and to get access to all the features.
- kirin/__init__.py +7 -0
- kirin/analysis/__init__.py +24 -0
- kirin/analysis/callgraph.py +61 -0
- kirin/analysis/cfg.py +112 -0
- kirin/analysis/const/__init__.py +20 -0
- kirin/analysis/const/_visitor.py +2 -0
- kirin/analysis/const/_visitor.pyi +8 -0
- kirin/analysis/const/lattice.py +219 -0
- kirin/analysis/const/prop.py +116 -0
- kirin/analysis/forward.py +100 -0
- kirin/analysis/typeinfer/__init__.py +5 -0
- kirin/analysis/typeinfer/analysis.py +90 -0
- kirin/analysis/typeinfer/solve.py +141 -0
- kirin/decl/__init__.py +108 -0
- kirin/decl/base.py +65 -0
- kirin/decl/camel2snake.py +2 -0
- kirin/decl/emit/__init__.py +0 -0
- kirin/decl/emit/_create_fn.py +29 -0
- kirin/decl/emit/_set_new_attribute.py +22 -0
- kirin/decl/emit/dialect.py +8 -0
- kirin/decl/emit/init.py +277 -0
- kirin/decl/emit/name.py +10 -0
- kirin/decl/emit/property.py +182 -0
- kirin/decl/emit/repr.py +31 -0
- kirin/decl/emit/traits.py +13 -0
- kirin/decl/emit/typecheck.py +77 -0
- kirin/decl/emit/verify.py +51 -0
- kirin/decl/info.py +346 -0
- kirin/decl/scan_fields.py +157 -0
- kirin/decl/verify.py +69 -0
- kirin/dialects/__init__.py +14 -0
- kirin/dialects/_pprint_helper.py +53 -0
- kirin/dialects/cf/__init__.py +20 -0
- kirin/dialects/cf/constprop.py +51 -0
- kirin/dialects/cf/dialect.py +3 -0
- kirin/dialects/cf/emit.py +58 -0
- kirin/dialects/cf/interp.py +24 -0
- kirin/dialects/cf/stmts.py +68 -0
- kirin/dialects/cf/typeinfer.py +27 -0
- kirin/dialects/eltype.py +23 -0
- kirin/dialects/func/__init__.py +20 -0
- kirin/dialects/func/attrs.py +39 -0
- kirin/dialects/func/constprop.py +138 -0
- kirin/dialects/func/dialect.py +3 -0
- kirin/dialects/func/emit.py +80 -0
- kirin/dialects/func/interp.py +68 -0
- kirin/dialects/func/stmts.py +233 -0
- kirin/dialects/func/typeinfer.py +124 -0
- kirin/dialects/ilist/__init__.py +33 -0
- kirin/dialects/ilist/_dialect.py +3 -0
- kirin/dialects/ilist/_wrapper.py +51 -0
- kirin/dialects/ilist/interp.py +85 -0
- kirin/dialects/ilist/lowering.py +25 -0
- kirin/dialects/ilist/passes.py +32 -0
- kirin/dialects/ilist/rewrite/__init__.py +3 -0
- kirin/dialects/ilist/rewrite/const.py +45 -0
- kirin/dialects/ilist/rewrite/list.py +38 -0
- kirin/dialects/ilist/rewrite/unroll.py +131 -0
- kirin/dialects/ilist/runtime.py +63 -0
- kirin/dialects/ilist/stmts.py +102 -0
- kirin/dialects/ilist/typeinfer.py +120 -0
- kirin/dialects/lowering/__init__.py +7 -0
- kirin/dialects/lowering/call.py +48 -0
- kirin/dialects/lowering/cf.py +206 -0
- kirin/dialects/lowering/func.py +134 -0
- kirin/dialects/math/__init__.py +41 -0
- kirin/dialects/math/_gen.py +176 -0
- kirin/dialects/math/dialect.py +3 -0
- kirin/dialects/math/interp.py +190 -0
- kirin/dialects/math/stmts.py +369 -0
- kirin/dialects/module.py +139 -0
- kirin/dialects/py/__init__.py +40 -0
- kirin/dialects/py/assertion.py +91 -0
- kirin/dialects/py/assign.py +103 -0
- kirin/dialects/py/attr.py +59 -0
- kirin/dialects/py/base.py +34 -0
- kirin/dialects/py/binop/__init__.py +23 -0
- kirin/dialects/py/binop/_dialect.py +3 -0
- kirin/dialects/py/binop/interp.py +60 -0
- kirin/dialects/py/binop/julia.py +33 -0
- kirin/dialects/py/binop/lowering.py +22 -0
- kirin/dialects/py/binop/stmts.py +79 -0
- kirin/dialects/py/binop/typeinfer.py +108 -0
- kirin/dialects/py/boolop.py +84 -0
- kirin/dialects/py/builtin.py +78 -0
- kirin/dialects/py/cmp/__init__.py +16 -0
- kirin/dialects/py/cmp/_dialect.py +3 -0
- kirin/dialects/py/cmp/interp.py +48 -0
- kirin/dialects/py/cmp/julia.py +33 -0
- kirin/dialects/py/cmp/lowering.py +45 -0
- kirin/dialects/py/cmp/stmts.py +62 -0
- kirin/dialects/py/constant.py +79 -0
- kirin/dialects/py/indexing.py +251 -0
- kirin/dialects/py/iterable.py +90 -0
- kirin/dialects/py/len.py +57 -0
- kirin/dialects/py/list/__init__.py +15 -0
- kirin/dialects/py/list/_dialect.py +3 -0
- kirin/dialects/py/list/interp.py +21 -0
- kirin/dialects/py/list/lowering.py +25 -0
- kirin/dialects/py/list/stmts.py +22 -0
- kirin/dialects/py/list/typeinfer.py +54 -0
- kirin/dialects/py/range.py +76 -0
- kirin/dialects/py/slice.py +120 -0
- kirin/dialects/py/tuple.py +109 -0
- kirin/dialects/py/unary/__init__.py +24 -0
- kirin/dialects/py/unary/_dialect.py +3 -0
- kirin/dialects/py/unary/constprop.py +20 -0
- kirin/dialects/py/unary/interp.py +24 -0
- kirin/dialects/py/unary/julia.py +21 -0
- kirin/dialects/py/unary/lowering.py +22 -0
- kirin/dialects/py/unary/stmts.py +33 -0
- kirin/dialects/py/unary/typeinfer.py +23 -0
- kirin/dialects/py/unpack.py +90 -0
- kirin/dialects/scf/__init__.py +23 -0
- kirin/dialects/scf/_dialect.py +3 -0
- kirin/dialects/scf/absint.py +64 -0
- kirin/dialects/scf/constprop.py +140 -0
- kirin/dialects/scf/interp.py +35 -0
- kirin/dialects/scf/lowering.py +123 -0
- kirin/dialects/scf/stmts.py +250 -0
- kirin/dialects/scf/trim.py +36 -0
- kirin/dialects/scf/typeinfer.py +58 -0
- kirin/dialects/scf/unroll.py +92 -0
- kirin/emit/__init__.py +3 -0
- kirin/emit/abc.py +89 -0
- kirin/emit/abc.pyi +38 -0
- kirin/emit/exceptions.py +5 -0
- kirin/emit/julia.py +63 -0
- kirin/emit/str.py +51 -0
- kirin/exceptions.py +59 -0
- kirin/graph.py +34 -0
- kirin/idtable.py +57 -0
- kirin/interp/__init__.py +39 -0
- kirin/interp/abstract.py +253 -0
- kirin/interp/base.py +438 -0
- kirin/interp/concrete.py +62 -0
- kirin/interp/exceptions.py +26 -0
- kirin/interp/frame.py +151 -0
- kirin/interp/impl.py +197 -0
- kirin/interp/result.py +93 -0
- kirin/interp/state.py +71 -0
- kirin/interp/table.py +40 -0
- kirin/interp/value.py +73 -0
- kirin/ir/__init__.py +46 -0
- kirin/ir/attrs/__init__.py +20 -0
- kirin/ir/attrs/_types.py +8 -0
- kirin/ir/attrs/_types.pyi +13 -0
- kirin/ir/attrs/abc.py +46 -0
- kirin/ir/attrs/py.py +45 -0
- kirin/ir/attrs/types.py +522 -0
- kirin/ir/dialect.py +125 -0
- kirin/ir/group.py +249 -0
- kirin/ir/method.py +118 -0
- kirin/ir/nodes/__init__.py +7 -0
- kirin/ir/nodes/base.py +149 -0
- kirin/ir/nodes/block.py +458 -0
- kirin/ir/nodes/region.py +337 -0
- kirin/ir/nodes/stmt.py +713 -0
- kirin/ir/nodes/view.py +142 -0
- kirin/ir/ssa.py +204 -0
- kirin/ir/traits/__init__.py +36 -0
- kirin/ir/traits/abc.py +42 -0
- kirin/ir/traits/basic.py +78 -0
- kirin/ir/traits/callable.py +51 -0
- kirin/ir/traits/lowering/__init__.py +2 -0
- kirin/ir/traits/lowering/call.py +37 -0
- kirin/ir/traits/lowering/context.py +120 -0
- kirin/ir/traits/region/__init__.py +2 -0
- kirin/ir/traits/region/ssacfg.py +22 -0
- kirin/ir/traits/symbol.py +57 -0
- kirin/ir/use.py +17 -0
- kirin/lattice/__init__.py +13 -0
- kirin/lattice/abc.py +128 -0
- kirin/lattice/empty.py +25 -0
- kirin/lattice/mixin.py +51 -0
- kirin/lowering/__init__.py +7 -0
- kirin/lowering/binding.py +65 -0
- kirin/lowering/core.py +72 -0
- kirin/lowering/dialect.py +35 -0
- kirin/lowering/dialect.pyi +183 -0
- kirin/lowering/frame.py +171 -0
- kirin/lowering/result.py +68 -0
- kirin/lowering/state.py +441 -0
- kirin/lowering/stream.py +53 -0
- kirin/passes/__init__.py +3 -0
- kirin/passes/abc.py +44 -0
- kirin/passes/aggressive/__init__.py +1 -0
- kirin/passes/aggressive/fold.py +43 -0
- kirin/passes/fold.py +45 -0
- kirin/passes/inline.py +25 -0
- kirin/passes/typeinfer.py +25 -0
- kirin/prelude.py +197 -0
- kirin/print/__init__.py +15 -0
- kirin/print/printable.py +141 -0
- kirin/print/printer.py +415 -0
- kirin/py.typed +0 -0
- kirin/registry.py +105 -0
- kirin/registry.pyi +52 -0
- kirin/rewrite/__init__.py +14 -0
- kirin/rewrite/abc.py +43 -0
- kirin/rewrite/aggressive/__init__.py +1 -0
- kirin/rewrite/aggressive/fold.py +43 -0
- kirin/rewrite/alias.py +16 -0
- kirin/rewrite/apply_type.py +47 -0
- kirin/rewrite/call2invoke.py +34 -0
- kirin/rewrite/chain.py +39 -0
- kirin/rewrite/compactify.py +288 -0
- kirin/rewrite/cse.py +48 -0
- kirin/rewrite/dce.py +19 -0
- kirin/rewrite/fixpoint.py +34 -0
- kirin/rewrite/fold.py +57 -0
- kirin/rewrite/getfield.py +21 -0
- kirin/rewrite/getitem.py +37 -0
- kirin/rewrite/inline.py +143 -0
- kirin/rewrite/result.py +15 -0
- kirin/rewrite/walk.py +83 -0
- kirin/rewrite/wrap_const.py +55 -0
- kirin/source.py +21 -0
- kirin/symbol_table.py +27 -0
- kirin/types.py +34 -0
- kirin/worklist.py +30 -0
- kirin_toolchain-0.13.0.dist-info/METADATA +42 -0
- kirin_toolchain-0.13.0.dist-info/RECORD +225 -0
- kirin_toolchain-0.13.0.dist-info/WHEEL +4 -0
- kirin_toolchain-0.13.0.dist-info/licenses/LICENSE +234 -0
@@ -0,0 +1,134 @@
|
|
1
|
+
import ast
|
2
|
+
|
3
|
+
from kirin import ir, types, lowering
|
4
|
+
from kirin.dialects import cf, func
|
5
|
+
from kirin.exceptions import DialectLoweringError
|
6
|
+
|
7
|
+
# Dialect that owns the Python-AST lowering rules for ``return`` statements and
# ``def`` blocks (the ``Lowering`` class registered on it below).
dialect = ir.Dialect("lowering.func")
|
8
|
+
|
9
|
+
|
10
|
+
@dialect.register
class Lowering(lowering.FromPythonAST):
    """Lower Python ``return`` statements and ``def`` blocks into the ``func`` dialect.

    Top-level function definitions become ``func.Function`` statements; nested
    definitions become ``func.Lambda`` statements carrying the values captured
    from enclosing frames.
    """

    def lower_Return(
        self, state: lowering.LoweringState, node: ast.Return
    ) -> lowering.Result:
        """Lower ``return`` / ``return <expr>`` into a ``func.Return``.

        A bare ``return`` first materializes a ``func.ConstantNone`` and returns
        its result.
        """
        if node.value is None:
            result = state.append_stmt(func.ConstantNone()).result
        else:
            result = state.visit(node.value).expect_one()
        state.append_stmt(func.Return(result))
        return lowering.Result()

    def lower_FunctionDef(
        self, state: lowering.LoweringState, node: ast.FunctionDef
    ) -> lowering.Result:
        """Lower a ``def`` into ``func.Function`` (top level) or ``func.Lambda`` (nested).

        The entry block's arguments are the method object itself
        (``<name>_self``) followed by one argument per positional parameter.
        Any block left without a terminator branches to the frame's
        ``next_block``, which returns ``None``.

        Raises:
            DialectLoweringError: for unsupported parameter kinds, annotations
                that are not type hints, decorators on nested functions, or a
                nested function whose entry block is empty.
        """
        self.assert_simple_arguments(node.args)
        signature = func.Signature(
            inputs=tuple(
                self.get_hint(state, arg.annotation) for arg in node.args.args
            ),
            output=self.get_hint(state, node.returns),
        )
        frame = state.current_frame

        # Entry-block arguments: the method itself (bound under the function's
        # own name so the body can refer to it) followed by the parameters.
        entries: dict[str, ir.SSAValue] = {}
        entr_block = ir.Block()
        fn_self = entr_block.args.append_from(
            types.Generic(
                ir.Method, types.Tuple.where(signature.inputs), signature.output
            ),
            node.name + "_self",
        )
        entries[node.name] = fn_self
        for arg, arg_type in zip(node.args.args, signature.inputs):
            entries[arg.arg] = entr_block.args.append_from(arg_type, arg.arg)

        def callback(frame: lowering.Frame, value: ir.SSAValue):
            # Passed as ``capture_callback``; presumably invoked when the body
            # reads a name owned by an outer frame. The captured value is
            # re-read from ``fn_self`` via ``func.GetField`` at the top of the
            # entry block. NOTE(review): the field index assumes the new
            # capture is already counted in ``frame.captures`` — confirm.
            first_stmt = entr_block.first_stmt
            stmt = func.GetField(obj=fn_self, field=len(frame.captures) - 1)
            if value.name:
                stmt.result.name = value.name
            stmt.result.type = value.type
            stmt.source = state.source
            if first_stmt:
                stmt.insert_before(first_stmt)
            else:
                entr_block.stmts.append(stmt)
            return stmt.result

        func_frame = state.push_frame(
            lowering.Frame.from_stmts(
                node.body,
                state,
                entr_block=entr_block,
                globals=frame.globals,
                capture_callback=callback,
            )
        )
        func_frame.defs.update(entries)
        state.exhaust()

        # Implicit fall-through: blocks without a terminator jump to the
        # frame's next block, which performs Python's implicit ``return None``.
        for block in func_frame.curr_region.blocks:
            if not block.last_stmt or not block.last_stmt.has_trait(ir.IsTerminator):
                block.stmts.append(
                    cf.Branch(arguments=(), successor=func_frame.next_block)
                )

        none_stmt = func.ConstantNone()
        rtrn_stmt = func.Return(none_stmt.result)
        func_frame.next_block.stmts.append(none_stmt)
        func_frame.next_block.stmts.append(rtrn_stmt)
        state.pop_frame()

        if state.current_frame.parent is None:  # toplevel function
            stmt = frame.append_stmt(
                func.Function(
                    sym_name=node.name,
                    signature=signature,
                    body=func_frame.curr_region,
                )
            )
            return lowering.Result(stmt)

        if node.decorator_list:
            raise DialectLoweringError(
                "decorators are not supported on nested functions"
            )

        # nested function, lookup unknown variables
        first_stmt = func_frame.curr_region.blocks[0].first_stmt
        if first_stmt is None:
            raise DialectLoweringError("empty function body")

        lambda_stmt = func.Lambda(
            tuple(func_frame.captures.values()),
            sym_name=node.name,
            signature=signature,
            body=func_frame.curr_region,
        )
        lambda_stmt.result.name = node.name
        # NOTE: Python automatically assigns the lambda to the name
        frame.defs[node.name] = frame.append_stmt(lambda_stmt).result
        return lowering.Result(lambda_stmt)

    def assert_simple_arguments(self, node: ast.arguments) -> None:
        """Reject parameter kinds this lowering cannot represent.

        Only plain positional-or-keyword parameters (``node.args``) are turned
        into signature inputs and entry-block arguments, so every other kind is
        an error instead of being silently dropped.

        Raises:
            DialectLoweringError: on keyword-only, positional-only, ``*args``
                or ``**kwargs`` parameters.
        """
        if node.kwonlyargs:
            raise DialectLoweringError("keyword-only arguments are not supported")

        if node.posonlyargs:
            raise DialectLoweringError("positional-only arguments are not supported")

        if node.vararg is not None:
            raise DialectLoweringError(
                "variadic positional arguments (*args) are not supported"
            )

        if node.kwarg is not None:
            raise DialectLoweringError(
                "variadic keyword arguments (**kwargs) are not supported"
            )

    @staticmethod
    def get_hint(state: lowering.LoweringState, node: ast.expr | None):
        """Resolve an annotation expression to a kirin type.

        A missing annotation (``None``) maps to ``types.Any``.

        Raises:
            DialectLoweringError: if the annotation cannot be resolved via
                ``state.get_global`` or converted by ``types.hint2type``; the
                underlying error is chained as ``__cause__``.
        """
        if node is None:
            return types.Any

        try:
            t = state.get_global(node).unwrap()
            return types.hint2type(t)
        except Exception as e:
            raise DialectLoweringError(
                f"expect a type hint, got {ast.unparse(node)}"
            ) from e
|
@@ -0,0 +1,41 @@
|
|
1
|
+
"math dialect, modeling functions in python's `math` stdlib" # This file is generated by gen.py
|
2
|
+
from kirin.dialects.math.stmts import (
|
3
|
+
cos as cos,
|
4
|
+
erf as erf,
|
5
|
+
exp as exp,
|
6
|
+
pow as pow,
|
7
|
+
sin as sin,
|
8
|
+
tan as tan,
|
9
|
+
ulp as ulp,
|
10
|
+
acos as acos,
|
11
|
+
asin as asin,
|
12
|
+
atan as atan,
|
13
|
+
ceil as ceil,
|
14
|
+
cosh as cosh,
|
15
|
+
erfc as erfc,
|
16
|
+
fabs as fabs,
|
17
|
+
fmod as fmod,
|
18
|
+
log2 as log2,
|
19
|
+
sinh as sinh,
|
20
|
+
sqrt as sqrt,
|
21
|
+
tanh as tanh,
|
22
|
+
asinh as asinh,
|
23
|
+
atan2 as atan2,
|
24
|
+
atanh as atanh,
|
25
|
+
expm1 as expm1,
|
26
|
+
floor as floor,
|
27
|
+
gamma as gamma,
|
28
|
+
isinf as isinf,
|
29
|
+
isnan as isnan,
|
30
|
+
log1p as log1p,
|
31
|
+
log10 as log10,
|
32
|
+
trunc as trunc,
|
33
|
+
lgamma as lgamma,
|
34
|
+
degrees as degrees,
|
35
|
+
radians as radians,
|
36
|
+
copysign as copysign,
|
37
|
+
isfinite as isfinite,
|
38
|
+
remainder as remainder,
|
39
|
+
)
|
40
|
+
from kirin.dialects.math.interp import MathMethodTable as MathMethodTable
|
41
|
+
from kirin.dialects.math.dialect import dialect as dialect
|
@@ -0,0 +1,176 @@
|
|
1
|
+
import os
|
2
|
+
import math
|
3
|
+
import inspect
|
4
|
+
import textwrap
|
5
|
+
from pathlib import Path
|
6
|
+
|
7
|
+
import black
|
8
|
+
|
9
|
+
# NOTE: typeinfer and lowering should be the default, so we don't generate them.
|
10
|
+
|
11
|
+
|
12
|
+
def builtin_math_functions():
    """Yield ``(name, obj, signature)`` for each wrappable builtin in :mod:`math`.

    Only builtin functions whose signature :func:`inspect.signature` can
    recover are yielded, in :func:`inspect.getmembers` order (sorted by name).
    A fixed set of special cases is skipped (``cbrt``/``exp2`` for Python 3.10
    compatibility).
    """
    # skip some special cases for now
    skipped = frozenset(
        {
            "prod",
            "perm",
            "modf",
            "ldexp",
            "lcm",
            "isqrt",
            "isclose",
            "gcd",
            "fsum",
            "frexp",
            "factorial",
            "acosh",
            "comb",
            "dist",
            "sumprod",
            "nextafter",
            # 3.10 compat
            "cbrt",
            "exp2",
        }
    )
    for name, obj in inspect.getmembers(math):
        if name in skipped or not inspect.isbuiltin(obj):
            continue
        try:
            sig = inspect.signature(obj)
        except (ValueError, TypeError):
            # No introspectable signature (e.g. builtins with optional args).
            continue
        # Yield outside the try: a bare except around the yield would swallow
        # GeneratorExit and make generator.close() raise RuntimeError.
        yield name, obj, sig
|
44
|
+
|
45
|
+
|
46
|
+
# stmts.py: one Pure, FromPythonCall statement class per wrapped math function.
with open(os.path.join(os.path.dirname(__file__), "stmts.py"), "w") as f:
    # Provenance note plus the imports every generated statement relies on.
    f.write(
        "# This file is generated by gen.py\n"
        "from kirin import ir, types\n"
        "from kirin.decl import statement, info\n"
        "from kirin.dialects.math.dialect import dialect\n"
        "\n"
    )
    for fn_name, _fn, fn_sig in builtin_math_functions():
        # Every parameter of the stdlib function becomes a Float SSA argument.
        arg_lines = []
        for param in fn_sig.parameters:
            arg_lines.append(f"    {param} : ir.SSAValue = info.argument(types.Float)")
        fields = "\n".join(arg_lines)
        stmt_src = f"""
@statement(dialect=dialect)
class {fn_name}(ir.Statement):
    \"\"\"{fn_name} statement, wrapping the math.{fn_name} function
    \"\"\"
    name = "{fn_name}"
    traits = frozenset({{ir.Pure(), ir.FromPythonCall()}})
{fields}
    result: ir.ResultValue = info.result(types.Float)
"""
        f.write(textwrap.dedent(stmt_src))
|
73
|
+
|
74
|
+
|
75
|
+
# interp.py: a MethodTable with one concrete @impl per generated statement.
with open(os.path.join(os.path.dirname(__file__), "interp.py"), "w") as f:
    # Provenance note plus the imports the generated table needs.
    for header_line in (
        "# This file is generated by gen.py\n",
        "import math\n",
        "from kirin.dialects.math.dialect import dialect\n",
        "from kirin.dialects.math import stmts\n",
        "from kirin.interp import MethodTable, Frame, impl\n",
        "\n",
    ):
        f.write(header_line)

    # Each method reads the operand values positionally and forwards them to
    # the same-named stdlib function, returning a 1-tuple.
    implements = []
    for fn_name, _fn, fn_sig in builtin_math_functions():
        call_args = ", ".join(f"values[{i}]" for i in range(len(fn_sig.parameters)))
        implements.append(
            f"""
    @impl(stmts.{fn_name})
    def {fn_name}(self, interp, frame: Frame, stmt: stmts.{fn_name}):
        values = frame.get_values(stmt.args)
        return (math.{fn_name}({call_args}),)"""
        )

    # Stitch all methods into the single registered interpreter class.
    methods = "\n\n".join(implements)
    f.write(
        f"""
@dialect.register
class MathMethodTable(MethodTable):
{methods}
"""
    )
|
105
|
+
|
106
|
+
# __init__.py
|
107
|
+
# __init__.py: re-export the dialect, every statement class and the method table.
with open(os.path.join(os.path.dirname(__file__), "__init__.py"), "w") as f:
    # The module docstring must end with a newline; without it the provenance
    # comment below is glued onto the docstring line.
    f.write('"math dialect, modeling functions in python\'s `math` stdlib"\n')
    f.write("# This file is generated by gen.py\n")
    f.write("from kirin.dialects.math.dialect import dialect as dialect\n")
    f.write("from kirin.dialects.math.stmts import (\n")
    for name, obj, sig in builtin_math_functions():
        # Redundant "X as X" aliases mark the names as explicit re-exports.
        f.write(f"    {name} as {name},\n")
    f.write(")\n")
    f.write(
        "from kirin.dialects.math.interp import MathMethodTable as MathMethodTable\n"
    )
    f.write("\n")
|
119
|
+
|
120
|
+
for file in ["__init__.py", "interp.py", "stmts.py"]:
    # Reformat each generated file in place with black's default mode
    # (``FileMode()`` does not load settings from the project's pyproject.toml).
    # ``write_back`` must be YES: the default (WriteBack.NO) only computes the
    # formatted text and never writes it to disk.
    black.format_file_in_place(
        Path(os.path.join(os.path.dirname(__file__), file)),
        fast=False,
        mode=black.FileMode(),
        write_back=black.WriteBack.YES,
    )
|
127
|
+
|
128
|
+
|
129
|
+
# import math as pymath
|
130
|
+
|
131
|
+
# from kirin.compile import compile
|
132
|
+
# from kirin.dialects import math
|
133
|
+
|
134
|
+
|
135
|
+
# # print(math.sin(x=TestValue()))
|
136
|
+
# # print(inspect.getargspec(math.sin.__init__))
|
137
|
+
# # print(math.sin.__init__)
|
138
|
+
# @basic
|
139
|
+
# def complicated_math_expr(x):
|
140
|
+
# return math.sin(math.cos(x) + math.tan(0.5))
|
141
|
+
|
142
|
+
|
143
|
+
# def test_math():
|
144
|
+
# complicated_math_expr.code.print()
|
145
|
+
# complicated_math_expr.narrow_types()
|
146
|
+
# truth = pymath.sin(pymath.cos(1) + pymath.tan(0.5))
|
147
|
+
# assert (complicated_math_expr(1) - truth) / truth < 1e-6
|
148
|
+
|
149
|
+
# test_basic.py
|
150
|
+
# test_basic.py: one smoke test per wrapped function, comparing the kirin
# kernel against the stdlib result at a fixed input of 0.42 per argument.
project_dir = Path(__file__).parent.parent.parent.parent.parent
with open(project_dir.joinpath("test", "dialects", "math", "test_basic.py"), "w") as f:
    f.write("# type: ignore\n")
    f.write("# This file is generated by gen.py\n")
    f.write("import math as pymath\n")
    f.write("from kirin.prelude import basic\n")
    f.write("from kirin.dialects import math\n")
    f.write("\n")
    f.write("\n")

    for name, obj, sig in builtin_math_functions():
        args = ", ".join(sig.parameters)
        inputs = ", ".join("0.42" for _ in sig.parameters)

        # Compare with abs(): a signed difference would accept any result
        # that is smaller than the expected value.
        f.write(
            textwrap.dedent(
                f"""
@basic
def {name}_func({args}):
    return math.{name}({args})

def test_{name}():
    truth = pymath.{name}({inputs})
    assert abs({name}_func({inputs}) - truth) < 1e-6
"""
            )
        )
|
@@ -0,0 +1,190 @@
|
|
1
|
+
# This file is generated by gen.py
|
2
|
+
import math
|
3
|
+
|
4
|
+
from kirin.interp import Frame, MethodTable, impl
|
5
|
+
from kirin.dialects.math import stmts
|
6
|
+
from kirin.dialects.math.dialect import dialect
|
7
|
+
|
8
|
+
|
9
|
+
@dialect.register
class MathMethodTable(MethodTable):
    """Concrete interpreter implementations for the ``math`` dialect statements.

    Every method reads its operand values from ``frame``, applies the
    same-named function from Python's :mod:`math` module, and returns the
    result as a 1-tuple. This class is generated (see the header of this
    file); change the generator template and regenerate rather than editing
    it by hand.
    """

    @impl(stmts.acos)
    def acos(self, interp, frame: Frame, stmt: stmts.acos):
        values = frame.get_values(stmt.args)
        return (math.acos(values[0]),)

    @impl(stmts.asin)
    def asin(self, interp, frame: Frame, stmt: stmts.asin):
        values = frame.get_values(stmt.args)
        return (math.asin(values[0]),)

    @impl(stmts.asinh)
    def asinh(self, interp, frame: Frame, stmt: stmts.asinh):
        values = frame.get_values(stmt.args)
        return (math.asinh(values[0]),)

    @impl(stmts.atan)
    def atan(self, interp, frame: Frame, stmt: stmts.atan):
        values = frame.get_values(stmt.args)
        return (math.atan(values[0]),)

    @impl(stmts.atan2)
    def atan2(self, interp, frame: Frame, stmt: stmts.atan2):
        values = frame.get_values(stmt.args)
        return (math.atan2(values[0], values[1]),)

    @impl(stmts.atanh)
    def atanh(self, interp, frame: Frame, stmt: stmts.atanh):
        values = frame.get_values(stmt.args)
        return (math.atanh(values[0]),)

    @impl(stmts.ceil)
    def ceil(self, interp, frame: Frame, stmt: stmts.ceil):
        values = frame.get_values(stmt.args)
        return (math.ceil(values[0]),)

    @impl(stmts.copysign)
    def copysign(self, interp, frame: Frame, stmt: stmts.copysign):
        values = frame.get_values(stmt.args)
        return (math.copysign(values[0], values[1]),)

    @impl(stmts.cos)
    def cos(self, interp, frame: Frame, stmt: stmts.cos):
        values = frame.get_values(stmt.args)
        return (math.cos(values[0]),)

    @impl(stmts.cosh)
    def cosh(self, interp, frame: Frame, stmt: stmts.cosh):
        values = frame.get_values(stmt.args)
        return (math.cosh(values[0]),)

    @impl(stmts.degrees)
    def degrees(self, interp, frame: Frame, stmt: stmts.degrees):
        values = frame.get_values(stmt.args)
        return (math.degrees(values[0]),)

    @impl(stmts.erf)
    def erf(self, interp, frame: Frame, stmt: stmts.erf):
        values = frame.get_values(stmt.args)
        return (math.erf(values[0]),)

    @impl(stmts.erfc)
    def erfc(self, interp, frame: Frame, stmt: stmts.erfc):
        values = frame.get_values(stmt.args)
        return (math.erfc(values[0]),)

    @impl(stmts.exp)
    def exp(self, interp, frame: Frame, stmt: stmts.exp):
        values = frame.get_values(stmt.args)
        return (math.exp(values[0]),)

    @impl(stmts.expm1)
    def expm1(self, interp, frame: Frame, stmt: stmts.expm1):
        values = frame.get_values(stmt.args)
        return (math.expm1(values[0]),)

    @impl(stmts.fabs)
    def fabs(self, interp, frame: Frame, stmt: stmts.fabs):
        values = frame.get_values(stmt.args)
        return (math.fabs(values[0]),)

    @impl(stmts.floor)
    def floor(self, interp, frame: Frame, stmt: stmts.floor):
        values = frame.get_values(stmt.args)
        return (math.floor(values[0]),)

    @impl(stmts.fmod)
    def fmod(self, interp, frame: Frame, stmt: stmts.fmod):
        values = frame.get_values(stmt.args)
        return (math.fmod(values[0], values[1]),)

    @impl(stmts.gamma)
    def gamma(self, interp, frame: Frame, stmt: stmts.gamma):
        values = frame.get_values(stmt.args)
        return (math.gamma(values[0]),)

    @impl(stmts.isfinite)
    def isfinite(self, interp, frame: Frame, stmt: stmts.isfinite):
        values = frame.get_values(stmt.args)
        return (math.isfinite(values[0]),)

    @impl(stmts.isinf)
    def isinf(self, interp, frame: Frame, stmt: stmts.isinf):
        values = frame.get_values(stmt.args)
        return (math.isinf(values[0]),)

    @impl(stmts.isnan)
    def isnan(self, interp, frame: Frame, stmt: stmts.isnan):
        values = frame.get_values(stmt.args)
        return (math.isnan(values[0]),)

    @impl(stmts.lgamma)
    def lgamma(self, interp, frame: Frame, stmt: stmts.lgamma):
        values = frame.get_values(stmt.args)
        return (math.lgamma(values[0]),)

    @impl(stmts.log10)
    def log10(self, interp, frame: Frame, stmt: stmts.log10):
        values = frame.get_values(stmt.args)
        return (math.log10(values[0]),)

    @impl(stmts.log1p)
    def log1p(self, interp, frame: Frame, stmt: stmts.log1p):
        values = frame.get_values(stmt.args)
        return (math.log1p(values[0]),)

    @impl(stmts.log2)
    def log2(self, interp, frame: Frame, stmt: stmts.log2):
        values = frame.get_values(stmt.args)
        return (math.log2(values[0]),)

    @impl(stmts.pow)
    def pow(self, interp, frame: Frame, stmt: stmts.pow):
        values = frame.get_values(stmt.args)
        return (math.pow(values[0], values[1]),)

    @impl(stmts.radians)
    def radians(self, interp, frame: Frame, stmt: stmts.radians):
        values = frame.get_values(stmt.args)
        return (math.radians(values[0]),)

    @impl(stmts.remainder)
    def remainder(self, interp, frame: Frame, stmt: stmts.remainder):
        values = frame.get_values(stmt.args)
        return (math.remainder(values[0], values[1]),)

    @impl(stmts.sin)
    def sin(self, interp, frame: Frame, stmt: stmts.sin):
        values = frame.get_values(stmt.args)
        return (math.sin(values[0]),)

    @impl(stmts.sinh)
    def sinh(self, interp, frame: Frame, stmt: stmts.sinh):
        values = frame.get_values(stmt.args)
        return (math.sinh(values[0]),)

    @impl(stmts.sqrt)
    def sqrt(self, interp, frame: Frame, stmt: stmts.sqrt):
        values = frame.get_values(stmt.args)
        return (math.sqrt(values[0]),)

    @impl(stmts.tan)
    def tan(self, interp, frame: Frame, stmt: stmts.tan):
        values = frame.get_values(stmt.args)
        return (math.tan(values[0]),)

    @impl(stmts.tanh)
    def tanh(self, interp, frame: Frame, stmt: stmts.tanh):
        values = frame.get_values(stmt.args)
        return (math.tanh(values[0]),)

    @impl(stmts.trunc)
    def trunc(self, interp, frame: Frame, stmt: stmts.trunc):
        values = frame.get_values(stmt.args)
        return (math.trunc(values[0]),)

    @impl(stmts.ulp)
    def ulp(self, interp, frame: Frame, stmt: stmts.ulp):
        values = frame.get_values(stmt.args)
        return (math.ulp(values[0]),)
|