kirin-toolchain 0.13.0__py3-none-any.whl
Sign up to get free protection for your applications and to get access to all the features.
- kirin/__init__.py +7 -0
- kirin/analysis/__init__.py +24 -0
- kirin/analysis/callgraph.py +61 -0
- kirin/analysis/cfg.py +112 -0
- kirin/analysis/const/__init__.py +20 -0
- kirin/analysis/const/_visitor.py +2 -0
- kirin/analysis/const/_visitor.pyi +8 -0
- kirin/analysis/const/lattice.py +219 -0
- kirin/analysis/const/prop.py +116 -0
- kirin/analysis/forward.py +100 -0
- kirin/analysis/typeinfer/__init__.py +5 -0
- kirin/analysis/typeinfer/analysis.py +90 -0
- kirin/analysis/typeinfer/solve.py +141 -0
- kirin/decl/__init__.py +108 -0
- kirin/decl/base.py +65 -0
- kirin/decl/camel2snake.py +2 -0
- kirin/decl/emit/__init__.py +0 -0
- kirin/decl/emit/_create_fn.py +29 -0
- kirin/decl/emit/_set_new_attribute.py +22 -0
- kirin/decl/emit/dialect.py +8 -0
- kirin/decl/emit/init.py +277 -0
- kirin/decl/emit/name.py +10 -0
- kirin/decl/emit/property.py +182 -0
- kirin/decl/emit/repr.py +31 -0
- kirin/decl/emit/traits.py +13 -0
- kirin/decl/emit/typecheck.py +77 -0
- kirin/decl/emit/verify.py +51 -0
- kirin/decl/info.py +346 -0
- kirin/decl/scan_fields.py +157 -0
- kirin/decl/verify.py +69 -0
- kirin/dialects/__init__.py +14 -0
- kirin/dialects/_pprint_helper.py +53 -0
- kirin/dialects/cf/__init__.py +20 -0
- kirin/dialects/cf/constprop.py +51 -0
- kirin/dialects/cf/dialect.py +3 -0
- kirin/dialects/cf/emit.py +58 -0
- kirin/dialects/cf/interp.py +24 -0
- kirin/dialects/cf/stmts.py +68 -0
- kirin/dialects/cf/typeinfer.py +27 -0
- kirin/dialects/eltype.py +23 -0
- kirin/dialects/func/__init__.py +20 -0
- kirin/dialects/func/attrs.py +39 -0
- kirin/dialects/func/constprop.py +138 -0
- kirin/dialects/func/dialect.py +3 -0
- kirin/dialects/func/emit.py +80 -0
- kirin/dialects/func/interp.py +68 -0
- kirin/dialects/func/stmts.py +233 -0
- kirin/dialects/func/typeinfer.py +124 -0
- kirin/dialects/ilist/__init__.py +33 -0
- kirin/dialects/ilist/_dialect.py +3 -0
- kirin/dialects/ilist/_wrapper.py +51 -0
- kirin/dialects/ilist/interp.py +85 -0
- kirin/dialects/ilist/lowering.py +25 -0
- kirin/dialects/ilist/passes.py +32 -0
- kirin/dialects/ilist/rewrite/__init__.py +3 -0
- kirin/dialects/ilist/rewrite/const.py +45 -0
- kirin/dialects/ilist/rewrite/list.py +38 -0
- kirin/dialects/ilist/rewrite/unroll.py +131 -0
- kirin/dialects/ilist/runtime.py +63 -0
- kirin/dialects/ilist/stmts.py +102 -0
- kirin/dialects/ilist/typeinfer.py +120 -0
- kirin/dialects/lowering/__init__.py +7 -0
- kirin/dialects/lowering/call.py +48 -0
- kirin/dialects/lowering/cf.py +206 -0
- kirin/dialects/lowering/func.py +134 -0
- kirin/dialects/math/__init__.py +41 -0
- kirin/dialects/math/_gen.py +176 -0
- kirin/dialects/math/dialect.py +3 -0
- kirin/dialects/math/interp.py +190 -0
- kirin/dialects/math/stmts.py +369 -0
- kirin/dialects/module.py +139 -0
- kirin/dialects/py/__init__.py +40 -0
- kirin/dialects/py/assertion.py +91 -0
- kirin/dialects/py/assign.py +103 -0
- kirin/dialects/py/attr.py +59 -0
- kirin/dialects/py/base.py +34 -0
- kirin/dialects/py/binop/__init__.py +23 -0
- kirin/dialects/py/binop/_dialect.py +3 -0
- kirin/dialects/py/binop/interp.py +60 -0
- kirin/dialects/py/binop/julia.py +33 -0
- kirin/dialects/py/binop/lowering.py +22 -0
- kirin/dialects/py/binop/stmts.py +79 -0
- kirin/dialects/py/binop/typeinfer.py +108 -0
- kirin/dialects/py/boolop.py +84 -0
- kirin/dialects/py/builtin.py +78 -0
- kirin/dialects/py/cmp/__init__.py +16 -0
- kirin/dialects/py/cmp/_dialect.py +3 -0
- kirin/dialects/py/cmp/interp.py +48 -0
- kirin/dialects/py/cmp/julia.py +33 -0
- kirin/dialects/py/cmp/lowering.py +45 -0
- kirin/dialects/py/cmp/stmts.py +62 -0
- kirin/dialects/py/constant.py +79 -0
- kirin/dialects/py/indexing.py +251 -0
- kirin/dialects/py/iterable.py +90 -0
- kirin/dialects/py/len.py +57 -0
- kirin/dialects/py/list/__init__.py +15 -0
- kirin/dialects/py/list/_dialect.py +3 -0
- kirin/dialects/py/list/interp.py +21 -0
- kirin/dialects/py/list/lowering.py +25 -0
- kirin/dialects/py/list/stmts.py +22 -0
- kirin/dialects/py/list/typeinfer.py +54 -0
- kirin/dialects/py/range.py +76 -0
- kirin/dialects/py/slice.py +120 -0
- kirin/dialects/py/tuple.py +109 -0
- kirin/dialects/py/unary/__init__.py +24 -0
- kirin/dialects/py/unary/_dialect.py +3 -0
- kirin/dialects/py/unary/constprop.py +20 -0
- kirin/dialects/py/unary/interp.py +24 -0
- kirin/dialects/py/unary/julia.py +21 -0
- kirin/dialects/py/unary/lowering.py +22 -0
- kirin/dialects/py/unary/stmts.py +33 -0
- kirin/dialects/py/unary/typeinfer.py +23 -0
- kirin/dialects/py/unpack.py +90 -0
- kirin/dialects/scf/__init__.py +23 -0
- kirin/dialects/scf/_dialect.py +3 -0
- kirin/dialects/scf/absint.py +64 -0
- kirin/dialects/scf/constprop.py +140 -0
- kirin/dialects/scf/interp.py +35 -0
- kirin/dialects/scf/lowering.py +123 -0
- kirin/dialects/scf/stmts.py +250 -0
- kirin/dialects/scf/trim.py +36 -0
- kirin/dialects/scf/typeinfer.py +58 -0
- kirin/dialects/scf/unroll.py +92 -0
- kirin/emit/__init__.py +3 -0
- kirin/emit/abc.py +89 -0
- kirin/emit/abc.pyi +38 -0
- kirin/emit/exceptions.py +5 -0
- kirin/emit/julia.py +63 -0
- kirin/emit/str.py +51 -0
- kirin/exceptions.py +59 -0
- kirin/graph.py +34 -0
- kirin/idtable.py +57 -0
- kirin/interp/__init__.py +39 -0
- kirin/interp/abstract.py +253 -0
- kirin/interp/base.py +438 -0
- kirin/interp/concrete.py +62 -0
- kirin/interp/exceptions.py +26 -0
- kirin/interp/frame.py +151 -0
- kirin/interp/impl.py +197 -0
- kirin/interp/result.py +93 -0
- kirin/interp/state.py +71 -0
- kirin/interp/table.py +40 -0
- kirin/interp/value.py +73 -0
- kirin/ir/__init__.py +46 -0
- kirin/ir/attrs/__init__.py +20 -0
- kirin/ir/attrs/_types.py +8 -0
- kirin/ir/attrs/_types.pyi +13 -0
- kirin/ir/attrs/abc.py +46 -0
- kirin/ir/attrs/py.py +45 -0
- kirin/ir/attrs/types.py +522 -0
- kirin/ir/dialect.py +125 -0
- kirin/ir/group.py +249 -0
- kirin/ir/method.py +118 -0
- kirin/ir/nodes/__init__.py +7 -0
- kirin/ir/nodes/base.py +149 -0
- kirin/ir/nodes/block.py +458 -0
- kirin/ir/nodes/region.py +337 -0
- kirin/ir/nodes/stmt.py +713 -0
- kirin/ir/nodes/view.py +142 -0
- kirin/ir/ssa.py +204 -0
- kirin/ir/traits/__init__.py +36 -0
- kirin/ir/traits/abc.py +42 -0
- kirin/ir/traits/basic.py +78 -0
- kirin/ir/traits/callable.py +51 -0
- kirin/ir/traits/lowering/__init__.py +2 -0
- kirin/ir/traits/lowering/call.py +37 -0
- kirin/ir/traits/lowering/context.py +120 -0
- kirin/ir/traits/region/__init__.py +2 -0
- kirin/ir/traits/region/ssacfg.py +22 -0
- kirin/ir/traits/symbol.py +57 -0
- kirin/ir/use.py +17 -0
- kirin/lattice/__init__.py +13 -0
- kirin/lattice/abc.py +128 -0
- kirin/lattice/empty.py +25 -0
- kirin/lattice/mixin.py +51 -0
- kirin/lowering/__init__.py +7 -0
- kirin/lowering/binding.py +65 -0
- kirin/lowering/core.py +72 -0
- kirin/lowering/dialect.py +35 -0
- kirin/lowering/dialect.pyi +183 -0
- kirin/lowering/frame.py +171 -0
- kirin/lowering/result.py +68 -0
- kirin/lowering/state.py +441 -0
- kirin/lowering/stream.py +53 -0
- kirin/passes/__init__.py +3 -0
- kirin/passes/abc.py +44 -0
- kirin/passes/aggressive/__init__.py +1 -0
- kirin/passes/aggressive/fold.py +43 -0
- kirin/passes/fold.py +45 -0
- kirin/passes/inline.py +25 -0
- kirin/passes/typeinfer.py +25 -0
- kirin/prelude.py +197 -0
- kirin/print/__init__.py +15 -0
- kirin/print/printable.py +141 -0
- kirin/print/printer.py +415 -0
- kirin/py.typed +0 -0
- kirin/registry.py +105 -0
- kirin/registry.pyi +52 -0
- kirin/rewrite/__init__.py +14 -0
- kirin/rewrite/abc.py +43 -0
- kirin/rewrite/aggressive/__init__.py +1 -0
- kirin/rewrite/aggressive/fold.py +43 -0
- kirin/rewrite/alias.py +16 -0
- kirin/rewrite/apply_type.py +47 -0
- kirin/rewrite/call2invoke.py +34 -0
- kirin/rewrite/chain.py +39 -0
- kirin/rewrite/compactify.py +288 -0
- kirin/rewrite/cse.py +48 -0
- kirin/rewrite/dce.py +19 -0
- kirin/rewrite/fixpoint.py +34 -0
- kirin/rewrite/fold.py +57 -0
- kirin/rewrite/getfield.py +21 -0
- kirin/rewrite/getitem.py +37 -0
- kirin/rewrite/inline.py +143 -0
- kirin/rewrite/result.py +15 -0
- kirin/rewrite/walk.py +83 -0
- kirin/rewrite/wrap_const.py +55 -0
- kirin/source.py +21 -0
- kirin/symbol_table.py +27 -0
- kirin/types.py +34 -0
- kirin/worklist.py +30 -0
- kirin_toolchain-0.13.0.dist-info/METADATA +42 -0
- kirin_toolchain-0.13.0.dist-info/RECORD +225 -0
- kirin_toolchain-0.13.0.dist-info/WHEEL +4 -0
- kirin_toolchain-0.13.0.dist-info/licenses/LICENSE +234 -0
@@ -0,0 +1,369 @@
|
|
1
|
+
# This file is generated by gen.py
|
2
|
+
from kirin import ir, types
|
3
|
+
from kirin.decl import info, statement
|
4
|
+
from kirin.dialects.math.dialect import dialect
|
5
|
+
|
6
|
+
|
7
|
+
@statement(dialect=dialect)
class acos(ir.Statement):
    """Statement form of `math.acos`: arc cosine of `x`, in radians."""

    traits = frozenset((ir.Pure(), ir.FromPythonCall()))
    name = "acos"
    x: ir.SSAValue = info.argument(types.Float)
    result: ir.ResultValue = info.result(types.Float)
|
15
|
+
|
16
|
+
|
17
|
+
@statement(dialect=dialect)
class asin(ir.Statement):
    """Statement form of `math.asin`: arc sine of `x`, in radians."""

    traits = frozenset((ir.Pure(), ir.FromPythonCall()))
    name = "asin"
    x: ir.SSAValue = info.argument(types.Float)
    result: ir.ResultValue = info.result(types.Float)
|
25
|
+
|
26
|
+
|
27
|
+
@statement(dialect=dialect)
class asinh(ir.Statement):
    """Statement form of `math.asinh`: inverse hyperbolic sine of `x`."""

    traits = frozenset((ir.Pure(), ir.FromPythonCall()))
    name = "asinh"
    x: ir.SSAValue = info.argument(types.Float)
    result: ir.ResultValue = info.result(types.Float)
|
35
|
+
|
36
|
+
|
37
|
+
@statement(dialect=dialect)
class atan(ir.Statement):
    """Statement form of `math.atan`: arc tangent of `x`, in radians."""

    traits = frozenset((ir.Pure(), ir.FromPythonCall()))
    name = "atan"
    x: ir.SSAValue = info.argument(types.Float)
    result: ir.ResultValue = info.result(types.Float)
|
45
|
+
|
46
|
+
|
47
|
+
@statement(dialect=dialect)
class atan2(ir.Statement):
    """Statement form of `math.atan2(y, x)`: arc tangent of `y / x`, in radians.

    Operands are declared `y` first, then `x`, matching `math.atan2`.
    """

    traits = frozenset((ir.Pure(), ir.FromPythonCall()))
    name = "atan2"
    y: ir.SSAValue = info.argument(types.Float)
    x: ir.SSAValue = info.argument(types.Float)
    result: ir.ResultValue = info.result(types.Float)
|
56
|
+
|
57
|
+
|
58
|
+
@statement(dialect=dialect)
class atanh(ir.Statement):
    """Statement form of `math.atanh`: inverse hyperbolic tangent of `x`."""

    traits = frozenset((ir.Pure(), ir.FromPythonCall()))
    name = "atanh"
    x: ir.SSAValue = info.argument(types.Float)
    result: ir.ResultValue = info.result(types.Float)
|
66
|
+
|
67
|
+
|
68
|
+
@statement(dialect=dialect)
class ceil(ir.Statement):
    """ceil statement, wrapping the math.ceil function.

    `math.ceil` of a float returns an `int` (the smallest integer >= x), so
    the result is declared as `types.Int` instead of `types.Float`.
    """

    name = "ceil"
    traits = frozenset({ir.Pure(), ir.FromPythonCall()})
    x: ir.SSAValue = info.argument(types.Float)
    # NOTE(review): this file is generated by gen.py; the generator must emit
    # this result type too, or regenerating will revert the fix.
    result: ir.ResultValue = info.result(types.Int)
|
76
|
+
|
77
|
+
|
78
|
+
@statement(dialect=dialect)
class copysign(ir.Statement):
    """Statement form of `math.copysign(x, y)`: magnitude of `x`, sign of `y`."""

    traits = frozenset((ir.Pure(), ir.FromPythonCall()))
    name = "copysign"
    x: ir.SSAValue = info.argument(types.Float)
    y: ir.SSAValue = info.argument(types.Float)
    result: ir.ResultValue = info.result(types.Float)
|
87
|
+
|
88
|
+
|
89
|
+
@statement(dialect=dialect)
class cos(ir.Statement):
    """Statement form of `math.cos`: cosine of `x` (radians)."""

    traits = frozenset((ir.Pure(), ir.FromPythonCall()))
    name = "cos"
    x: ir.SSAValue = info.argument(types.Float)
    result: ir.ResultValue = info.result(types.Float)
|
97
|
+
|
98
|
+
|
99
|
+
@statement(dialect=dialect)
class cosh(ir.Statement):
    """Statement form of `math.cosh`: hyperbolic cosine of `x`."""

    traits = frozenset((ir.Pure(), ir.FromPythonCall()))
    name = "cosh"
    x: ir.SSAValue = info.argument(types.Float)
    result: ir.ResultValue = info.result(types.Float)
|
107
|
+
|
108
|
+
|
109
|
+
@statement(dialect=dialect)
class degrees(ir.Statement):
    """Statement form of `math.degrees`: converts `x` from radians to degrees."""

    traits = frozenset((ir.Pure(), ir.FromPythonCall()))
    name = "degrees"
    x: ir.SSAValue = info.argument(types.Float)
    result: ir.ResultValue = info.result(types.Float)
|
117
|
+
|
118
|
+
|
119
|
+
@statement(dialect=dialect)
class erf(ir.Statement):
    """Statement form of `math.erf`: the error function evaluated at `x`."""

    traits = frozenset((ir.Pure(), ir.FromPythonCall()))
    name = "erf"
    x: ir.SSAValue = info.argument(types.Float)
    result: ir.ResultValue = info.result(types.Float)
|
127
|
+
|
128
|
+
|
129
|
+
@statement(dialect=dialect)
class erfc(ir.Statement):
    """Statement form of `math.erfc`: the complementary error function at `x`."""

    traits = frozenset((ir.Pure(), ir.FromPythonCall()))
    name = "erfc"
    x: ir.SSAValue = info.argument(types.Float)
    result: ir.ResultValue = info.result(types.Float)
|
137
|
+
|
138
|
+
|
139
|
+
@statement(dialect=dialect)
class exp(ir.Statement):
    """Statement form of `math.exp`: `e` raised to the power `x`."""

    traits = frozenset((ir.Pure(), ir.FromPythonCall()))
    name = "exp"
    x: ir.SSAValue = info.argument(types.Float)
    result: ir.ResultValue = info.result(types.Float)
|
147
|
+
|
148
|
+
|
149
|
+
@statement(dialect=dialect)
class expm1(ir.Statement):
    """Statement form of `math.expm1`: `exp(x) - 1`, precise for small `x`."""

    traits = frozenset((ir.Pure(), ir.FromPythonCall()))
    name = "expm1"
    x: ir.SSAValue = info.argument(types.Float)
    result: ir.ResultValue = info.result(types.Float)
|
157
|
+
|
158
|
+
|
159
|
+
@statement(dialect=dialect)
class fabs(ir.Statement):
    """Statement form of `math.fabs`: absolute value of `x` as a float."""

    traits = frozenset((ir.Pure(), ir.FromPythonCall()))
    name = "fabs"
    x: ir.SSAValue = info.argument(types.Float)
    result: ir.ResultValue = info.result(types.Float)
|
167
|
+
|
168
|
+
|
169
|
+
@statement(dialect=dialect)
class floor(ir.Statement):
    """floor statement, wrapping the math.floor function.

    `math.floor` of a float returns an `int` (the largest integer <= x), so
    the result is declared as `types.Int` instead of `types.Float`.
    """

    name = "floor"
    traits = frozenset({ir.Pure(), ir.FromPythonCall()})
    x: ir.SSAValue = info.argument(types.Float)
    # NOTE(review): this file is generated by gen.py; the generator must emit
    # this result type too, or regenerating will revert the fix.
    result: ir.ResultValue = info.result(types.Int)
|
177
|
+
|
178
|
+
|
179
|
+
@statement(dialect=dialect)
class fmod(ir.Statement):
    """Statement form of `math.fmod(x, y)`: C-library floating remainder."""

    traits = frozenset((ir.Pure(), ir.FromPythonCall()))
    name = "fmod"
    x: ir.SSAValue = info.argument(types.Float)
    y: ir.SSAValue = info.argument(types.Float)
    result: ir.ResultValue = info.result(types.Float)
|
188
|
+
|
189
|
+
|
190
|
+
@statement(dialect=dialect)
class gamma(ir.Statement):
    """Statement form of `math.gamma`: the Gamma function evaluated at `x`."""

    traits = frozenset((ir.Pure(), ir.FromPythonCall()))
    name = "gamma"
    x: ir.SSAValue = info.argument(types.Float)
    result: ir.ResultValue = info.result(types.Float)
|
198
|
+
|
199
|
+
|
200
|
+
@statement(dialect=dialect)
class isfinite(ir.Statement):
    """isfinite statement, wrapping the math.isfinite function.

    `math.isfinite` is a predicate returning `bool`, so the result is
    declared as `types.Bool` instead of `types.Float`.
    """

    name = "isfinite"
    traits = frozenset({ir.Pure(), ir.FromPythonCall()})
    x: ir.SSAValue = info.argument(types.Float)
    # NOTE(review): this file is generated by gen.py; the generator must emit
    # this result type too, or regenerating will revert the fix.
    result: ir.ResultValue = info.result(types.Bool)
|
208
|
+
|
209
|
+
|
210
|
+
@statement(dialect=dialect)
class isinf(ir.Statement):
    """isinf statement, wrapping the math.isinf function.

    `math.isinf` is a predicate returning `bool`, so the result is declared
    as `types.Bool` instead of `types.Float`.
    """

    name = "isinf"
    traits = frozenset({ir.Pure(), ir.FromPythonCall()})
    x: ir.SSAValue = info.argument(types.Float)
    # NOTE(review): this file is generated by gen.py; the generator must emit
    # this result type too, or regenerating will revert the fix.
    result: ir.ResultValue = info.result(types.Bool)
|
218
|
+
|
219
|
+
|
220
|
+
@statement(dialect=dialect)
class isnan(ir.Statement):
    """isnan statement, wrapping the math.isnan function.

    `math.isnan` is a predicate returning `bool`, so the result is declared
    as `types.Bool` instead of `types.Float`.
    """

    name = "isnan"
    traits = frozenset({ir.Pure(), ir.FromPythonCall()})
    x: ir.SSAValue = info.argument(types.Float)
    # NOTE(review): this file is generated by gen.py; the generator must emit
    # this result type too, or regenerating will revert the fix.
    result: ir.ResultValue = info.result(types.Bool)
|
228
|
+
|
229
|
+
|
230
|
+
@statement(dialect=dialect)
class lgamma(ir.Statement):
    """Statement form of `math.lgamma`: log of |Gamma(x)|."""

    traits = frozenset((ir.Pure(), ir.FromPythonCall()))
    name = "lgamma"
    x: ir.SSAValue = info.argument(types.Float)
    result: ir.ResultValue = info.result(types.Float)
|
238
|
+
|
239
|
+
|
240
|
+
@statement(dialect=dialect)
class log10(ir.Statement):
    """Statement form of `math.log10`: base-10 logarithm of `x`."""

    traits = frozenset((ir.Pure(), ir.FromPythonCall()))
    name = "log10"
    x: ir.SSAValue = info.argument(types.Float)
    result: ir.ResultValue = info.result(types.Float)
|
248
|
+
|
249
|
+
|
250
|
+
@statement(dialect=dialect)
class log1p(ir.Statement):
    """Statement form of `math.log1p`: natural log of `1 + x`, precise near 0."""

    traits = frozenset((ir.Pure(), ir.FromPythonCall()))
    name = "log1p"
    x: ir.SSAValue = info.argument(types.Float)
    result: ir.ResultValue = info.result(types.Float)
|
258
|
+
|
259
|
+
|
260
|
+
@statement(dialect=dialect)
class log2(ir.Statement):
    """Statement form of `math.log2`: base-2 logarithm of `x`."""

    traits = frozenset((ir.Pure(), ir.FromPythonCall()))
    name = "log2"
    x: ir.SSAValue = info.argument(types.Float)
    result: ir.ResultValue = info.result(types.Float)
|
268
|
+
|
269
|
+
|
270
|
+
@statement(dialect=dialect)
class pow(ir.Statement):
    """Statement form of `math.pow(x, y)`: `x` raised to the power `y`.

    Note: this class name shadows the builtin `pow` within this module.
    """

    traits = frozenset((ir.Pure(), ir.FromPythonCall()))
    name = "pow"
    x: ir.SSAValue = info.argument(types.Float)
    y: ir.SSAValue = info.argument(types.Float)
    result: ir.ResultValue = info.result(types.Float)
|
279
|
+
|
280
|
+
|
281
|
+
@statement(dialect=dialect)
class radians(ir.Statement):
    """Statement form of `math.radians`: converts `x` from degrees to radians."""

    traits = frozenset((ir.Pure(), ir.FromPythonCall()))
    name = "radians"
    x: ir.SSAValue = info.argument(types.Float)
    result: ir.ResultValue = info.result(types.Float)
|
289
|
+
|
290
|
+
|
291
|
+
@statement(dialect=dialect)
class remainder(ir.Statement):
    """Statement form of `math.remainder(x, y)`: IEEE 754-style remainder."""

    traits = frozenset((ir.Pure(), ir.FromPythonCall()))
    name = "remainder"
    x: ir.SSAValue = info.argument(types.Float)
    y: ir.SSAValue = info.argument(types.Float)
    result: ir.ResultValue = info.result(types.Float)
|
300
|
+
|
301
|
+
|
302
|
+
@statement(dialect=dialect)
class sin(ir.Statement):
    """Statement form of `math.sin`: sine of `x` (radians)."""

    traits = frozenset((ir.Pure(), ir.FromPythonCall()))
    name = "sin"
    x: ir.SSAValue = info.argument(types.Float)
    result: ir.ResultValue = info.result(types.Float)
|
310
|
+
|
311
|
+
|
312
|
+
@statement(dialect=dialect)
class sinh(ir.Statement):
    """Statement form of `math.sinh`: hyperbolic sine of `x`."""

    traits = frozenset((ir.Pure(), ir.FromPythonCall()))
    name = "sinh"
    x: ir.SSAValue = info.argument(types.Float)
    result: ir.ResultValue = info.result(types.Float)
|
320
|
+
|
321
|
+
|
322
|
+
@statement(dialect=dialect)
class sqrt(ir.Statement):
    """Statement form of `math.sqrt`: square root of `x`."""

    traits = frozenset((ir.Pure(), ir.FromPythonCall()))
    name = "sqrt"
    x: ir.SSAValue = info.argument(types.Float)
    result: ir.ResultValue = info.result(types.Float)
|
330
|
+
|
331
|
+
|
332
|
+
@statement(dialect=dialect)
class tan(ir.Statement):
    """Statement form of `math.tan`: tangent of `x` (radians)."""

    traits = frozenset((ir.Pure(), ir.FromPythonCall()))
    name = "tan"
    x: ir.SSAValue = info.argument(types.Float)
    result: ir.ResultValue = info.result(types.Float)
|
340
|
+
|
341
|
+
|
342
|
+
@statement(dialect=dialect)
class tanh(ir.Statement):
    """Statement form of `math.tanh`: hyperbolic tangent of `x`."""

    traits = frozenset((ir.Pure(), ir.FromPythonCall()))
    name = "tanh"
    x: ir.SSAValue = info.argument(types.Float)
    result: ir.ResultValue = info.result(types.Float)
|
350
|
+
|
351
|
+
|
352
|
+
@statement(dialect=dialect)
class trunc(ir.Statement):
    """trunc statement, wrapping the math.trunc function.

    `math.trunc` of a float returns an `int` (x truncated toward zero), so
    the result is declared as `types.Int` instead of `types.Float`.
    """

    name = "trunc"
    traits = frozenset({ir.Pure(), ir.FromPythonCall()})
    x: ir.SSAValue = info.argument(types.Float)
    # NOTE(review): this file is generated by gen.py; the generator must emit
    # this result type too, or regenerating will revert the fix.
    result: ir.ResultValue = info.result(types.Int)
|
360
|
+
|
361
|
+
|
362
|
+
@statement(dialect=dialect)
class ulp(ir.Statement):
    """Statement form of `math.ulp`: value of the least significant bit of `x`."""

    traits = frozenset((ir.Pure(), ir.FromPythonCall()))
    name = "ulp"
    x: ir.SSAValue = info.argument(types.Float)
    result: ir.ResultValue = info.result(types.Float)
|
kirin/dialects/module.py
ADDED
@@ -0,0 +1,139 @@
|
|
1
|
+
"""Module dialect provides a simple module
|
2
|
+
that is roughly a list of function statements.
|
3
|
+
|
4
|
+
This dialect provides what is necessary for compiling a function into
|
5
|
+
lower-level IR with all its callee functions.
|
6
|
+
"""
|
7
|
+
|
8
|
+
from kirin import ir, types, interp
|
9
|
+
from kirin.decl import info, statement
|
10
|
+
from kirin.print import Printer
|
11
|
+
from kirin.analysis import TypeInference
|
12
|
+
from kirin.exceptions import VerificationError
|
13
|
+
|
14
|
+
from ._pprint_helper import pprint_calllike
|
15
|
+
|
16
|
+
dialect = ir.Dialect("module")
|
17
|
+
|
18
|
+
|
19
|
+
@statement(dialect=dialect)
class Module(ir.Statement):
    """A named module: a symbol table holding statements in one body region.

    `sym_name` is the module's own symbol name. `entry` is presumably the
    symbol name of the entry function (TODO confirm with callers).
    """

    traits = frozenset(
        [ir.IsolatedFromAbove(), ir.SymbolTable(), ir.SymbolOpInterface()]
    )
    sym_name: str = info.attribute()
    entry: str = info.attribute()
    body: ir.Region = info.region(multi=False)
|
27
|
+
|
28
|
+
|
29
|
+
@statement(dialect=dialect)
class Invoke(ir.Statement):
    """Call a function by its symbol name.

    `callee` is a plain symbol name (`str`). It is looked up in the
    interpreter's symbol table when the statement is executed.

    Note:
        This statement exists for completeness. For interpretation, it is
        recommended to rewrite it into a `func.Invoke` after looking up the
        symbol table.
    """

    callee: str = info.attribute()
    inputs: tuple[ir.SSAValue, ...] = info.argument()
    kwargs: tuple[str, ...] = info.attribute()
    result: ir.ResultValue = info.result()

    def print_impl(self, printer: Printer) -> None:
        pprint_calllike(self, self.callee, printer)

    def verify(self) -> None:
        """Check the invariants that do not require resolving `callee`.

        `callee` is only a symbol name, so the callee's argument names and
        arity are not available here; those must be checked after the symbol
        has been resolved. Substring tests against, or `len()` of, the name
        string say nothing about the callee's signature.

        Raises:
            VerificationError: if there are more keyword names than inputs,
                or a keyword name is repeated.
        """
        # NOTE(review): assumes each keyword name labels one of `inputs`;
        # confirm against the lowering that builds this statement.
        if len(self.kwargs) > len(self.inputs):
            raise VerificationError(
                self,
                f"got {len(self.kwargs)} keyword names for "
                f"{len(self.inputs)} arguments",
            )

        seen: set[str] = set()
        for name in self.kwargs:
            if name in seen:
                raise VerificationError(
                    self,
                    f"duplicate keyword argument {name} in call to {self.callee}",
                )
            seen.add(name)
|
61
|
+
|
62
|
+
|
63
|
+
@dialect.register
class Concrete(interp.MethodTable):
    """Concrete interpreter implementations for the module dialect."""

    @interp.impl(Module)
    def interp_Module(
        self, interp: interp.Interpreter, frame: interp.Frame, stmt: Module
    ):
        """Register each symbol-defining statement of the module body.

        Every statement in the module's single block that carries the
        `SymbolOpInterface` trait is stored in `interp.symbol_table` under its
        symbol name, so `Invoke` can find it later.
        """
        for stmt_ in stmt.body.blocks[0].stmts:
            # Query the trait on the child statement. The module itself always
            # has `SymbolOpInterface`, so asking `stmt` would try to register
            # every child, including ones that define no symbol.
            if (trait := stmt_.get_trait(ir.SymbolOpInterface)) is not None:
                interp.symbol_table[trait.get_sym_name(stmt_).data] = stmt_
        return ()

    @interp.impl(Invoke)
    def interp_Invoke(
        self, interpreter: interp.Interpreter, frame: interp.Frame, stmt: Invoke
    ):
        """Run the callable statement registered under `stmt.callee`.

        Returns:
            A 1-tuple holding the callee's return value (the single result of
            `Invoke`).

        Raises:
            interp.InterpreterError: if the symbol is unknown or the
                registered statement is not callable.
        """
        callee = interpreter.symbol_table.get(stmt.callee)
        if callee is None:
            raise interp.InterpreterError(f"symbol {stmt.callee} not found")

        trait = callee.get_trait(ir.CallableStmtInterface)
        if trait is None:
            raise interp.InterpreterError(
                f"{stmt.callee} is not callable, got {callee.__class__.__name__}"
            )

        body = trait.get_callable_region(callee)
        mt = ir.Method(
            mod=None,
            py_func=None,
            sym_name=stmt.callee,
            arg_names=[
                arg.name or str(idx) for idx, arg in enumerate(body.blocks[0].args)
            ],
            dialects=interpreter.dialects,
            # The method's code is the callable statement being invoked. The
            # `Invoke` statement itself has no callable region.
            code=callee,
        )
        # Statement impls return a tuple of result values (cf. `return ()`
        # above); `Invoke` has exactly one result.
        return (interpreter.run_method(mt, frame.get_values(stmt.inputs)),)
|
101
|
+
|
102
|
+
|
103
|
+
@dialect.register(key="typeinfer")
class TypeInfer(interp.MethodTable):
    """Type-inference implementations for the module dialect."""

    @interp.impl(Module)
    def typeinfer_Module(
        self, interp: TypeInference, frame: interp.Frame, stmt: Module
    ):
        """Register each symbol-defining statement of the module body.

        This mirrors the concrete interpreter, so `Invoke` can resolve callees
        during inference.
        """
        for stmt_ in stmt.body.blocks[0].stmts:
            # Query the trait on the child statement, not on the module (which
            # always has `SymbolOpInterface`).
            if (trait := stmt_.get_trait(ir.SymbolOpInterface)) is not None:
                interp.symbol_table[trait.get_sym_name(stmt_).data] = stmt_
        return ()

    @interp.impl(Invoke)
    def typeinfer_Invoke(
        self, interp: TypeInference, frame: interp.Frame, stmt: Invoke
    ):
        """Infer the result type of a call by symbol name.

        Returns `(types.Bottom,)` when the symbol is unknown or not callable.
        Otherwise it returns the type inferred by analysing the callee with
        its declared argument types.
        """
        callee = interp.symbol_table.get(stmt.callee)
        if callee is None:
            return (types.Bottom,)

        trait = callee.get_trait(ir.CallableStmtInterface)
        if trait is None:
            return (types.Bottom,)

        body = trait.get_callable_region(callee)
        mt = ir.Method(
            mod=None,
            py_func=None,
            sym_name=stmt.callee,
            arg_names=[
                arg.name or str(idx) for idx, arg in enumerate(body.blocks[0].args)
            ],
            dialects=interp.dialects,
            # Analyse the callable statement, not this `Invoke` statement.
            code=callee,
        )
        # Use the inferred return type of the analysed callee. The callee's
        # own statement results do not describe what the call returns.
        # NOTE(review): assumes `run_method` returns the inferred return type
        # (single value); confirm against TypeInference.run_method.
        return (interp.run_method(mt, mt.arg_types),)
|
@@ -0,0 +1,40 @@
|
|
1
|
+
"""Python dialects module.
|
2
|
+
|
3
|
+
This module contains a set of dialects that represent
|
4
|
+
different subsets of the Python language. The dialects
|
5
|
+
are designed to be used in a union to represent the
|
6
|
+
entire Python language.
|
7
|
+
"""
|
8
|
+
|
9
|
+
from . import (
|
10
|
+
cmp as cmp,
|
11
|
+
len as len,
|
12
|
+
attr as attr,
|
13
|
+
base as base,
|
14
|
+
list as list,
|
15
|
+
binop as binop,
|
16
|
+
range as range,
|
17
|
+
slice as slice,
|
18
|
+
tuple as tuple,
|
19
|
+
unary as unary,
|
20
|
+
assign as assign,
|
21
|
+
boolop as boolop,
|
22
|
+
unpack as unpack,
|
23
|
+
builtin as builtin,
|
24
|
+
constant as constant,
|
25
|
+
indexing as indexing,
|
26
|
+
iterable as iterable,
|
27
|
+
)
|
28
|
+
from .len import Len as Len
|
29
|
+
from .attr import GetAttr as GetAttr
|
30
|
+
from .range import Range as Range
|
31
|
+
from .slice import Slice as Slice
|
32
|
+
from .assign import Alias as Alias, SetItem as SetItem
|
33
|
+
from .boolop import Or as Or, And as And
|
34
|
+
from .builtin import Abs as Abs, Sum as Sum
|
35
|
+
from .constant import Constant as Constant
|
36
|
+
from .indexing import GetItem as GetItem, PyGetItemLike as PyGetItemLike
|
37
|
+
from .cmp.stmts import * # noqa: F403
|
38
|
+
from .list.stmts import Append as Append
|
39
|
+
from .binop.stmts import * # noqa: F403
|
40
|
+
from .unary.stmts import * # noqa: F403
|
@@ -0,0 +1,91 @@
|
|
1
|
+
"""Assertion dialect for Python.
|
2
|
+
|
3
|
+
This module contains the dialect for the Python `assert` statement, including:
|
4
|
+
|
5
|
+
- The `Assert` statement class.
|
6
|
+
- The lowering pass for the `assert` statement.
|
7
|
+
- The concrete implementation of the `assert` statement.
|
8
|
+
- The type inference implementation of the `assert` statement.
|
9
|
+
- The Julia emitter for the `assert` statement.
|
10
|
+
|
11
|
+
This dialect maps `ast.Assert` nodes to the `Assert` statement.
|
12
|
+
"""
|
13
|
+
|
14
|
+
import ast
|
15
|
+
|
16
|
+
from kirin import ir, types, interp, lowering
|
17
|
+
from kirin.decl import info, statement
|
18
|
+
from kirin.emit import EmitStrFrame, julia
|
19
|
+
from kirin.print import Printer
|
20
|
+
|
21
|
+
dialect = ir.Dialect("py.assert")
|
22
|
+
|
23
|
+
|
24
|
+
@statement(dialect=dialect)
class Assert(ir.Statement):
    """Statement for Python's `assert condition, message`.

    `message` is a string-typed SSA value. When the source `assert` has no
    message, lowering supplies an empty-string constant.
    """

    condition: ir.SSAValue
    message: ir.SSAValue = info.argument(types.String)

    def print_impl(self, printer: Printer) -> None:
        with printer.rich(style="keyword"):
            printer.print_name(self)
        printer.plain_print(" ")
        printer.print(self.condition)
        if not self.message:
            return
        printer.plain_print(", ")
        printer.print(self.message)
|
39
|
+
|
40
|
+
|
41
|
+
@dialect.register
class Lowering(lowering.FromPythonAST):

    def lower_Assert(
        self, state: lowering.LoweringState, node: ast.Assert
    ) -> lowering.Result:
        """Lower `ast.Assert` into an `Assert` statement.

        When the source has no message, an empty-string `Constant` is emitted
        so that `Assert` always receives a message operand.
        """
        # Imported inside the method; presumably this avoids an import cycle
        # between dialect modules (TODO confirm).
        from kirin.dialects.py.constant import Constant

        cond = state.visit(node.test).expect_one()
        if node.msg:
            message = state.visit(node.msg).expect_one()
        else:
            message = state.append_stmt(Constant("")).result
        state.append_stmt(Assert(condition=cond, message=message))
        return lowering.Result()
|
57
|
+
|
58
|
+
|
59
|
+
@dialect.register
class Concrete(interp.MethodTable):
    """Concrete interpreter implementation of `Assert`."""

    @interp.impl(Assert)
    def assert_stmt(
        self, interp_: interp.Interpreter, frame: interp.Frame, stmt: Assert
    ):
        """Raise a wrapped `AssertionError` when the condition is falsy.

        Raises:
            interp.WrapException: wrapping `AssertionError(message)`, or
                `AssertionError("Assertion failed")` when the message is empty.
        """
        # Python's `assert` tests truthiness, not identity with `True`. An
        # `is True` check would wrongly fail for truthy values such as 1 or a
        # non-empty list.
        if frame.get(stmt.condition):
            return ()

        # The message operand is always present (lowering emits "" when the
        # source has none), so test its runtime value, not the SSA value.
        message = frame.get(stmt.message)
        if message:
            raise interp.WrapException(AssertionError(message))
        raise interp.WrapException(AssertionError("Assertion failed"))
|
73
|
+
|
74
|
+
|
75
|
+
@dialect.register(key="typeinfer")
class TypeInfer(interp.MethodTable):
    """Type-inference implementation of `Assert`."""

    @interp.impl(Assert)
    def assert_stmt(self, interp, frame, stmt: Assert):
        """`Assert` declares no results, so there are no result types to report."""
        return ()
|
81
|
+
|
82
|
+
|
83
|
+
@dialect.register(key="emit.julia")
class EmitJulia(interp.MethodTable):
    """Julia emitter for `Assert`, producing an `@assert cond msg` line."""

    @interp.impl(Assert)
    def emit_assert(self, interp: julia.EmitJulia, frame: EmitStrFrame, stmt: Assert):
        cond_text = frame.get(stmt.condition)
        msg_text = frame.get(stmt.message)
        interp.writeln(frame, f"@assert {cond_text} {msg_text}")
        return ()
|