kirin-toolchain 0.13.0__py3-none-any.whl

Sign up to get free protection for your applications and to get access to all the features.
Files changed (225) hide show
  1. kirin/__init__.py +7 -0
  2. kirin/analysis/__init__.py +24 -0
  3. kirin/analysis/callgraph.py +61 -0
  4. kirin/analysis/cfg.py +112 -0
  5. kirin/analysis/const/__init__.py +20 -0
  6. kirin/analysis/const/_visitor.py +2 -0
  7. kirin/analysis/const/_visitor.pyi +8 -0
  8. kirin/analysis/const/lattice.py +219 -0
  9. kirin/analysis/const/prop.py +116 -0
  10. kirin/analysis/forward.py +100 -0
  11. kirin/analysis/typeinfer/__init__.py +5 -0
  12. kirin/analysis/typeinfer/analysis.py +90 -0
  13. kirin/analysis/typeinfer/solve.py +141 -0
  14. kirin/decl/__init__.py +108 -0
  15. kirin/decl/base.py +65 -0
  16. kirin/decl/camel2snake.py +2 -0
  17. kirin/decl/emit/__init__.py +0 -0
  18. kirin/decl/emit/_create_fn.py +29 -0
  19. kirin/decl/emit/_set_new_attribute.py +22 -0
  20. kirin/decl/emit/dialect.py +8 -0
  21. kirin/decl/emit/init.py +277 -0
  22. kirin/decl/emit/name.py +10 -0
  23. kirin/decl/emit/property.py +182 -0
  24. kirin/decl/emit/repr.py +31 -0
  25. kirin/decl/emit/traits.py +13 -0
  26. kirin/decl/emit/typecheck.py +77 -0
  27. kirin/decl/emit/verify.py +51 -0
  28. kirin/decl/info.py +346 -0
  29. kirin/decl/scan_fields.py +157 -0
  30. kirin/decl/verify.py +69 -0
  31. kirin/dialects/__init__.py +14 -0
  32. kirin/dialects/_pprint_helper.py +53 -0
  33. kirin/dialects/cf/__init__.py +20 -0
  34. kirin/dialects/cf/constprop.py +51 -0
  35. kirin/dialects/cf/dialect.py +3 -0
  36. kirin/dialects/cf/emit.py +58 -0
  37. kirin/dialects/cf/interp.py +24 -0
  38. kirin/dialects/cf/stmts.py +68 -0
  39. kirin/dialects/cf/typeinfer.py +27 -0
  40. kirin/dialects/eltype.py +23 -0
  41. kirin/dialects/func/__init__.py +20 -0
  42. kirin/dialects/func/attrs.py +39 -0
  43. kirin/dialects/func/constprop.py +138 -0
  44. kirin/dialects/func/dialect.py +3 -0
  45. kirin/dialects/func/emit.py +80 -0
  46. kirin/dialects/func/interp.py +68 -0
  47. kirin/dialects/func/stmts.py +233 -0
  48. kirin/dialects/func/typeinfer.py +124 -0
  49. kirin/dialects/ilist/__init__.py +33 -0
  50. kirin/dialects/ilist/_dialect.py +3 -0
  51. kirin/dialects/ilist/_wrapper.py +51 -0
  52. kirin/dialects/ilist/interp.py +85 -0
  53. kirin/dialects/ilist/lowering.py +25 -0
  54. kirin/dialects/ilist/passes.py +32 -0
  55. kirin/dialects/ilist/rewrite/__init__.py +3 -0
  56. kirin/dialects/ilist/rewrite/const.py +45 -0
  57. kirin/dialects/ilist/rewrite/list.py +38 -0
  58. kirin/dialects/ilist/rewrite/unroll.py +131 -0
  59. kirin/dialects/ilist/runtime.py +63 -0
  60. kirin/dialects/ilist/stmts.py +102 -0
  61. kirin/dialects/ilist/typeinfer.py +120 -0
  62. kirin/dialects/lowering/__init__.py +7 -0
  63. kirin/dialects/lowering/call.py +48 -0
  64. kirin/dialects/lowering/cf.py +206 -0
  65. kirin/dialects/lowering/func.py +134 -0
  66. kirin/dialects/math/__init__.py +41 -0
  67. kirin/dialects/math/_gen.py +176 -0
  68. kirin/dialects/math/dialect.py +3 -0
  69. kirin/dialects/math/interp.py +190 -0
  70. kirin/dialects/math/stmts.py +369 -0
  71. kirin/dialects/module.py +139 -0
  72. kirin/dialects/py/__init__.py +40 -0
  73. kirin/dialects/py/assertion.py +91 -0
  74. kirin/dialects/py/assign.py +103 -0
  75. kirin/dialects/py/attr.py +59 -0
  76. kirin/dialects/py/base.py +34 -0
  77. kirin/dialects/py/binop/__init__.py +23 -0
  78. kirin/dialects/py/binop/_dialect.py +3 -0
  79. kirin/dialects/py/binop/interp.py +60 -0
  80. kirin/dialects/py/binop/julia.py +33 -0
  81. kirin/dialects/py/binop/lowering.py +22 -0
  82. kirin/dialects/py/binop/stmts.py +79 -0
  83. kirin/dialects/py/binop/typeinfer.py +108 -0
  84. kirin/dialects/py/boolop.py +84 -0
  85. kirin/dialects/py/builtin.py +78 -0
  86. kirin/dialects/py/cmp/__init__.py +16 -0
  87. kirin/dialects/py/cmp/_dialect.py +3 -0
  88. kirin/dialects/py/cmp/interp.py +48 -0
  89. kirin/dialects/py/cmp/julia.py +33 -0
  90. kirin/dialects/py/cmp/lowering.py +45 -0
  91. kirin/dialects/py/cmp/stmts.py +62 -0
  92. kirin/dialects/py/constant.py +79 -0
  93. kirin/dialects/py/indexing.py +251 -0
  94. kirin/dialects/py/iterable.py +90 -0
  95. kirin/dialects/py/len.py +57 -0
  96. kirin/dialects/py/list/__init__.py +15 -0
  97. kirin/dialects/py/list/_dialect.py +3 -0
  98. kirin/dialects/py/list/interp.py +21 -0
  99. kirin/dialects/py/list/lowering.py +25 -0
  100. kirin/dialects/py/list/stmts.py +22 -0
  101. kirin/dialects/py/list/typeinfer.py +54 -0
  102. kirin/dialects/py/range.py +76 -0
  103. kirin/dialects/py/slice.py +120 -0
  104. kirin/dialects/py/tuple.py +109 -0
  105. kirin/dialects/py/unary/__init__.py +24 -0
  106. kirin/dialects/py/unary/_dialect.py +3 -0
  107. kirin/dialects/py/unary/constprop.py +20 -0
  108. kirin/dialects/py/unary/interp.py +24 -0
  109. kirin/dialects/py/unary/julia.py +21 -0
  110. kirin/dialects/py/unary/lowering.py +22 -0
  111. kirin/dialects/py/unary/stmts.py +33 -0
  112. kirin/dialects/py/unary/typeinfer.py +23 -0
  113. kirin/dialects/py/unpack.py +90 -0
  114. kirin/dialects/scf/__init__.py +23 -0
  115. kirin/dialects/scf/_dialect.py +3 -0
  116. kirin/dialects/scf/absint.py +64 -0
  117. kirin/dialects/scf/constprop.py +140 -0
  118. kirin/dialects/scf/interp.py +35 -0
  119. kirin/dialects/scf/lowering.py +123 -0
  120. kirin/dialects/scf/stmts.py +250 -0
  121. kirin/dialects/scf/trim.py +36 -0
  122. kirin/dialects/scf/typeinfer.py +58 -0
  123. kirin/dialects/scf/unroll.py +92 -0
  124. kirin/emit/__init__.py +3 -0
  125. kirin/emit/abc.py +89 -0
  126. kirin/emit/abc.pyi +38 -0
  127. kirin/emit/exceptions.py +5 -0
  128. kirin/emit/julia.py +63 -0
  129. kirin/emit/str.py +51 -0
  130. kirin/exceptions.py +59 -0
  131. kirin/graph.py +34 -0
  132. kirin/idtable.py +57 -0
  133. kirin/interp/__init__.py +39 -0
  134. kirin/interp/abstract.py +253 -0
  135. kirin/interp/base.py +438 -0
  136. kirin/interp/concrete.py +62 -0
  137. kirin/interp/exceptions.py +26 -0
  138. kirin/interp/frame.py +151 -0
  139. kirin/interp/impl.py +197 -0
  140. kirin/interp/result.py +93 -0
  141. kirin/interp/state.py +71 -0
  142. kirin/interp/table.py +40 -0
  143. kirin/interp/value.py +73 -0
  144. kirin/ir/__init__.py +46 -0
  145. kirin/ir/attrs/__init__.py +20 -0
  146. kirin/ir/attrs/_types.py +8 -0
  147. kirin/ir/attrs/_types.pyi +13 -0
  148. kirin/ir/attrs/abc.py +46 -0
  149. kirin/ir/attrs/py.py +45 -0
  150. kirin/ir/attrs/types.py +522 -0
  151. kirin/ir/dialect.py +125 -0
  152. kirin/ir/group.py +249 -0
  153. kirin/ir/method.py +118 -0
  154. kirin/ir/nodes/__init__.py +7 -0
  155. kirin/ir/nodes/base.py +149 -0
  156. kirin/ir/nodes/block.py +458 -0
  157. kirin/ir/nodes/region.py +337 -0
  158. kirin/ir/nodes/stmt.py +713 -0
  159. kirin/ir/nodes/view.py +142 -0
  160. kirin/ir/ssa.py +204 -0
  161. kirin/ir/traits/__init__.py +36 -0
  162. kirin/ir/traits/abc.py +42 -0
  163. kirin/ir/traits/basic.py +78 -0
  164. kirin/ir/traits/callable.py +51 -0
  165. kirin/ir/traits/lowering/__init__.py +2 -0
  166. kirin/ir/traits/lowering/call.py +37 -0
  167. kirin/ir/traits/lowering/context.py +120 -0
  168. kirin/ir/traits/region/__init__.py +2 -0
  169. kirin/ir/traits/region/ssacfg.py +22 -0
  170. kirin/ir/traits/symbol.py +57 -0
  171. kirin/ir/use.py +17 -0
  172. kirin/lattice/__init__.py +13 -0
  173. kirin/lattice/abc.py +128 -0
  174. kirin/lattice/empty.py +25 -0
  175. kirin/lattice/mixin.py +51 -0
  176. kirin/lowering/__init__.py +7 -0
  177. kirin/lowering/binding.py +65 -0
  178. kirin/lowering/core.py +72 -0
  179. kirin/lowering/dialect.py +35 -0
  180. kirin/lowering/dialect.pyi +183 -0
  181. kirin/lowering/frame.py +171 -0
  182. kirin/lowering/result.py +68 -0
  183. kirin/lowering/state.py +441 -0
  184. kirin/lowering/stream.py +53 -0
  185. kirin/passes/__init__.py +3 -0
  186. kirin/passes/abc.py +44 -0
  187. kirin/passes/aggressive/__init__.py +1 -0
  188. kirin/passes/aggressive/fold.py +43 -0
  189. kirin/passes/fold.py +45 -0
  190. kirin/passes/inline.py +25 -0
  191. kirin/passes/typeinfer.py +25 -0
  192. kirin/prelude.py +197 -0
  193. kirin/print/__init__.py +15 -0
  194. kirin/print/printable.py +141 -0
  195. kirin/print/printer.py +415 -0
  196. kirin/py.typed +0 -0
  197. kirin/registry.py +105 -0
  198. kirin/registry.pyi +52 -0
  199. kirin/rewrite/__init__.py +14 -0
  200. kirin/rewrite/abc.py +43 -0
  201. kirin/rewrite/aggressive/__init__.py +1 -0
  202. kirin/rewrite/aggressive/fold.py +43 -0
  203. kirin/rewrite/alias.py +16 -0
  204. kirin/rewrite/apply_type.py +47 -0
  205. kirin/rewrite/call2invoke.py +34 -0
  206. kirin/rewrite/chain.py +39 -0
  207. kirin/rewrite/compactify.py +288 -0
  208. kirin/rewrite/cse.py +48 -0
  209. kirin/rewrite/dce.py +19 -0
  210. kirin/rewrite/fixpoint.py +34 -0
  211. kirin/rewrite/fold.py +57 -0
  212. kirin/rewrite/getfield.py +21 -0
  213. kirin/rewrite/getitem.py +37 -0
  214. kirin/rewrite/inline.py +143 -0
  215. kirin/rewrite/result.py +15 -0
  216. kirin/rewrite/walk.py +83 -0
  217. kirin/rewrite/wrap_const.py +55 -0
  218. kirin/source.py +21 -0
  219. kirin/symbol_table.py +27 -0
  220. kirin/types.py +34 -0
  221. kirin/worklist.py +30 -0
  222. kirin_toolchain-0.13.0.dist-info/METADATA +42 -0
  223. kirin_toolchain-0.13.0.dist-info/RECORD +225 -0
  224. kirin_toolchain-0.13.0.dist-info/WHEEL +4 -0
  225. kirin_toolchain-0.13.0.dist-info/licenses/LICENSE +234 -0
@@ -0,0 +1,369 @@
1
+ # This file is generated by gen.py
2
+ from kirin import ir, types
3
+ from kirin.decl import info, statement
4
+ from kirin.dialects.math.dialect import dialect
5
+
6
+
7
@statement(dialect=dialect)
class acos(ir.Statement):
    """Statement form of Python's `math.acos`.

    Operand: `x` (`Float`). Result: `result` (`Float`).
    Traits: `Pure`, `FromPythonCall`.
    """

    traits = frozenset([ir.FromPythonCall(), ir.Pure()])
    name = "acos"

    x: ir.SSAValue = info.argument(types.Float)
    result: ir.ResultValue = info.result(types.Float)
15
+
16
+
17
@statement(dialect=dialect)
class asin(ir.Statement):
    """Statement form of Python's `math.asin`.

    Operand: `x` (`Float`). Result: `result` (`Float`).
    Traits: `Pure`, `FromPythonCall`.
    """

    traits = frozenset([ir.FromPythonCall(), ir.Pure()])
    name = "asin"

    x: ir.SSAValue = info.argument(types.Float)
    result: ir.ResultValue = info.result(types.Float)
25
+
26
+
27
@statement(dialect=dialect)
class asinh(ir.Statement):
    """Statement form of Python's `math.asinh`.

    Operand: `x` (`Float`). Result: `result` (`Float`).
    Traits: `Pure`, `FromPythonCall`.
    """

    traits = frozenset([ir.FromPythonCall(), ir.Pure()])
    name = "asinh"

    x: ir.SSAValue = info.argument(types.Float)
    result: ir.ResultValue = info.result(types.Float)
35
+
36
+
37
@statement(dialect=dialect)
class atan(ir.Statement):
    """Statement form of Python's `math.atan`.

    Operand: `x` (`Float`). Result: `result` (`Float`).
    Traits: `Pure`, `FromPythonCall`.
    """

    traits = frozenset([ir.FromPythonCall(), ir.Pure()])
    name = "atan"

    x: ir.SSAValue = info.argument(types.Float)
    result: ir.ResultValue = info.result(types.Float)
45
+
46
+
47
@statement(dialect=dialect)
class atan2(ir.Statement):
    """Statement form of Python's `math.atan2`.

    Operands: `y`, `x` (both `Float`, in this positional order, matching
    `math.atan2(y, x)`). Result: `result` (`Float`).
    Traits: `Pure`, `FromPythonCall`.
    """

    traits = frozenset([ir.FromPythonCall(), ir.Pure()])
    name = "atan2"

    y: ir.SSAValue = info.argument(types.Float)
    x: ir.SSAValue = info.argument(types.Float)
    result: ir.ResultValue = info.result(types.Float)
56
+
57
+
58
@statement(dialect=dialect)
class atanh(ir.Statement):
    """Statement form of Python's `math.atanh`.

    Operand: `x` (`Float`). Result: `result` (`Float`).
    Traits: `Pure`, `FromPythonCall`.
    """

    traits = frozenset([ir.FromPythonCall(), ir.Pure()])
    name = "atanh"

    x: ir.SSAValue = info.argument(types.Float)
    result: ir.ResultValue = info.result(types.Float)
66
+
67
+
68
@statement(dialect=dialect)
class ceil(ir.Statement):
    """ceil statement, wrapping the math.ceil function.

    Takes a `Float` operand `x`. Like `math.ceil`, it produces an `Int`
    `result`.
    """

    name = "ceil"
    traits = frozenset({ir.Pure(), ir.FromPythonCall()})
    x: ir.SSAValue = info.argument(types.Float)
    # `math.ceil` returns an `int` in Python 3 (not a float), so the result is
    # declared `Int` rather than `Float`.
    # NOTE(review): this module is generated; apply the same change to the
    # generator (presumably kirin/dialects/math/_gen.py) so it survives
    # regeneration.
    result: ir.ResultValue = info.result(types.Int)
76
+
77
+
78
@statement(dialect=dialect)
class copysign(ir.Statement):
    """Statement form of Python's `math.copysign`.

    Operands: `x`, `y` (both `Float`, in this positional order).
    Result: `result` (`Float`). Traits: `Pure`, `FromPythonCall`.
    """

    traits = frozenset([ir.FromPythonCall(), ir.Pure()])
    name = "copysign"

    x: ir.SSAValue = info.argument(types.Float)
    y: ir.SSAValue = info.argument(types.Float)
    result: ir.ResultValue = info.result(types.Float)
87
+
88
+
89
@statement(dialect=dialect)
class cos(ir.Statement):
    """Statement form of Python's `math.cos`.

    Operand: `x` (`Float`). Result: `result` (`Float`).
    Traits: `Pure`, `FromPythonCall`.
    """

    traits = frozenset([ir.FromPythonCall(), ir.Pure()])
    name = "cos"

    x: ir.SSAValue = info.argument(types.Float)
    result: ir.ResultValue = info.result(types.Float)
97
+
98
+
99
@statement(dialect=dialect)
class cosh(ir.Statement):
    """Statement form of Python's `math.cosh`.

    Operand: `x` (`Float`). Result: `result` (`Float`).
    Traits: `Pure`, `FromPythonCall`.
    """

    traits = frozenset([ir.FromPythonCall(), ir.Pure()])
    name = "cosh"

    x: ir.SSAValue = info.argument(types.Float)
    result: ir.ResultValue = info.result(types.Float)
107
+
108
+
109
@statement(dialect=dialect)
class degrees(ir.Statement):
    """Statement form of Python's `math.degrees`.

    Operand: `x` (`Float`). Result: `result` (`Float`).
    Traits: `Pure`, `FromPythonCall`.
    """

    traits = frozenset([ir.FromPythonCall(), ir.Pure()])
    name = "degrees"

    x: ir.SSAValue = info.argument(types.Float)
    result: ir.ResultValue = info.result(types.Float)
117
+
118
+
119
@statement(dialect=dialect)
class erf(ir.Statement):
    """Statement form of Python's `math.erf`.

    Operand: `x` (`Float`). Result: `result` (`Float`).
    Traits: `Pure`, `FromPythonCall`.
    """

    traits = frozenset([ir.FromPythonCall(), ir.Pure()])
    name = "erf"

    x: ir.SSAValue = info.argument(types.Float)
    result: ir.ResultValue = info.result(types.Float)
127
+
128
+
129
@statement(dialect=dialect)
class erfc(ir.Statement):
    """Statement form of Python's `math.erfc`.

    Operand: `x` (`Float`). Result: `result` (`Float`).
    Traits: `Pure`, `FromPythonCall`.
    """

    traits = frozenset([ir.FromPythonCall(), ir.Pure()])
    name = "erfc"

    x: ir.SSAValue = info.argument(types.Float)
    result: ir.ResultValue = info.result(types.Float)
137
+
138
+
139
@statement(dialect=dialect)
class exp(ir.Statement):
    """Statement form of Python's `math.exp`.

    Operand: `x` (`Float`). Result: `result` (`Float`).
    Traits: `Pure`, `FromPythonCall`.
    """

    traits = frozenset([ir.FromPythonCall(), ir.Pure()])
    name = "exp"

    x: ir.SSAValue = info.argument(types.Float)
    result: ir.ResultValue = info.result(types.Float)
147
+
148
+
149
@statement(dialect=dialect)
class expm1(ir.Statement):
    """Statement form of Python's `math.expm1`.

    Operand: `x` (`Float`). Result: `result` (`Float`).
    Traits: `Pure`, `FromPythonCall`.
    """

    traits = frozenset([ir.FromPythonCall(), ir.Pure()])
    name = "expm1"

    x: ir.SSAValue = info.argument(types.Float)
    result: ir.ResultValue = info.result(types.Float)
157
+
158
+
159
@statement(dialect=dialect)
class fabs(ir.Statement):
    """Statement form of Python's `math.fabs`.

    Operand: `x` (`Float`). Result: `result` (`Float`).
    Traits: `Pure`, `FromPythonCall`.
    """

    traits = frozenset([ir.FromPythonCall(), ir.Pure()])
    name = "fabs"

    x: ir.SSAValue = info.argument(types.Float)
    result: ir.ResultValue = info.result(types.Float)
167
+
168
+
169
@statement(dialect=dialect)
class floor(ir.Statement):
    """floor statement, wrapping the math.floor function.

    Takes a `Float` operand `x`. Like `math.floor`, it produces an `Int`
    `result`.
    """

    name = "floor"
    traits = frozenset({ir.Pure(), ir.FromPythonCall()})
    x: ir.SSAValue = info.argument(types.Float)
    # `math.floor` returns an `int` in Python 3 (not a float), so the result
    # is declared `Int` rather than `Float`.
    # NOTE(review): this module is generated; apply the same change to the
    # generator (presumably kirin/dialects/math/_gen.py) so it survives
    # regeneration.
    result: ir.ResultValue = info.result(types.Int)
177
+
178
+
179
@statement(dialect=dialect)
class fmod(ir.Statement):
    """Statement form of Python's `math.fmod`.

    Operands: `x`, `y` (both `Float`, in this positional order).
    Result: `result` (`Float`). Traits: `Pure`, `FromPythonCall`.
    """

    traits = frozenset([ir.FromPythonCall(), ir.Pure()])
    name = "fmod"

    x: ir.SSAValue = info.argument(types.Float)
    y: ir.SSAValue = info.argument(types.Float)
    result: ir.ResultValue = info.result(types.Float)
188
+
189
+
190
@statement(dialect=dialect)
class gamma(ir.Statement):
    """Statement form of Python's `math.gamma`.

    Operand: `x` (`Float`). Result: `result` (`Float`).
    Traits: `Pure`, `FromPythonCall`.
    """

    traits = frozenset([ir.FromPythonCall(), ir.Pure()])
    name = "gamma"

    x: ir.SSAValue = info.argument(types.Float)
    result: ir.ResultValue = info.result(types.Float)
198
+
199
+
200
@statement(dialect=dialect)
class isfinite(ir.Statement):
    """isfinite statement, wrapping the math.isfinite function.

    Takes a `Float` operand `x`. Like `math.isfinite`, it produces a `Bool`
    `result`.
    """

    name = "isfinite"
    traits = frozenset({ir.Pure(), ir.FromPythonCall()})
    x: ir.SSAValue = info.argument(types.Float)
    # `math.isfinite` is a predicate returning `bool`, so the result is
    # declared `Bool` rather than `Float`.
    # NOTE(review): this module is generated; apply the same change to the
    # generator (presumably kirin/dialects/math/_gen.py) so it survives
    # regeneration.
    result: ir.ResultValue = info.result(types.Bool)
208
+
209
+
210
@statement(dialect=dialect)
class isinf(ir.Statement):
    """isinf statement, wrapping the math.isinf function.

    Takes a `Float` operand `x`. Like `math.isinf`, it produces a `Bool`
    `result`.
    """

    name = "isinf"
    traits = frozenset({ir.Pure(), ir.FromPythonCall()})
    x: ir.SSAValue = info.argument(types.Float)
    # `math.isinf` is a predicate returning `bool`, so the result is declared
    # `Bool` rather than `Float`.
    # NOTE(review): this module is generated; apply the same change to the
    # generator (presumably kirin/dialects/math/_gen.py) so it survives
    # regeneration.
    result: ir.ResultValue = info.result(types.Bool)
218
+
219
+
220
@statement(dialect=dialect)
class isnan(ir.Statement):
    """isnan statement, wrapping the math.isnan function.

    Takes a `Float` operand `x`. Like `math.isnan`, it produces a `Bool`
    `result`.
    """

    name = "isnan"
    traits = frozenset({ir.Pure(), ir.FromPythonCall()})
    x: ir.SSAValue = info.argument(types.Float)
    # `math.isnan` is a predicate returning `bool`, so the result is declared
    # `Bool` rather than `Float`.
    # NOTE(review): this module is generated; apply the same change to the
    # generator (presumably kirin/dialects/math/_gen.py) so it survives
    # regeneration.
    result: ir.ResultValue = info.result(types.Bool)
228
+
229
+
230
@statement(dialect=dialect)
class lgamma(ir.Statement):
    """Statement form of Python's `math.lgamma`.

    Operand: `x` (`Float`). Result: `result` (`Float`).
    Traits: `Pure`, `FromPythonCall`.
    """

    traits = frozenset([ir.FromPythonCall(), ir.Pure()])
    name = "lgamma"

    x: ir.SSAValue = info.argument(types.Float)
    result: ir.ResultValue = info.result(types.Float)
238
+
239
+
240
@statement(dialect=dialect)
class log10(ir.Statement):
    """Statement form of Python's `math.log10`.

    Operand: `x` (`Float`). Result: `result` (`Float`).
    Traits: `Pure`, `FromPythonCall`.
    """

    traits = frozenset([ir.FromPythonCall(), ir.Pure()])
    name = "log10"

    x: ir.SSAValue = info.argument(types.Float)
    result: ir.ResultValue = info.result(types.Float)
248
+
249
+
250
@statement(dialect=dialect)
class log1p(ir.Statement):
    """Statement form of Python's `math.log1p`.

    Operand: `x` (`Float`). Result: `result` (`Float`).
    Traits: `Pure`, `FromPythonCall`.
    """

    traits = frozenset([ir.FromPythonCall(), ir.Pure()])
    name = "log1p"

    x: ir.SSAValue = info.argument(types.Float)
    result: ir.ResultValue = info.result(types.Float)
258
+
259
+
260
@statement(dialect=dialect)
class log2(ir.Statement):
    """Statement form of Python's `math.log2`.

    Operand: `x` (`Float`). Result: `result` (`Float`).
    Traits: `Pure`, `FromPythonCall`.
    """

    traits = frozenset([ir.FromPythonCall(), ir.Pure()])
    name = "log2"

    x: ir.SSAValue = info.argument(types.Float)
    result: ir.ResultValue = info.result(types.Float)
268
+
269
+
270
@statement(dialect=dialect)
class pow(ir.Statement):
    """Statement form of Python's `math.pow`.

    Operands: `x`, `y` (both `Float`, in this positional order).
    Result: `result` (`Float`). Traits: `Pure`, `FromPythonCall`.
    """

    traits = frozenset([ir.FromPythonCall(), ir.Pure()])
    name = "pow"

    x: ir.SSAValue = info.argument(types.Float)
    y: ir.SSAValue = info.argument(types.Float)
    result: ir.ResultValue = info.result(types.Float)
279
+
280
+
281
@statement(dialect=dialect)
class radians(ir.Statement):
    """Statement form of Python's `math.radians`.

    Operand: `x` (`Float`). Result: `result` (`Float`).
    Traits: `Pure`, `FromPythonCall`.
    """

    traits = frozenset([ir.FromPythonCall(), ir.Pure()])
    name = "radians"

    x: ir.SSAValue = info.argument(types.Float)
    result: ir.ResultValue = info.result(types.Float)
289
+
290
+
291
@statement(dialect=dialect)
class remainder(ir.Statement):
    """Statement form of Python's `math.remainder`.

    Operands: `x`, `y` (both `Float`, in this positional order).
    Result: `result` (`Float`). Traits: `Pure`, `FromPythonCall`.
    """

    traits = frozenset([ir.FromPythonCall(), ir.Pure()])
    name = "remainder"

    x: ir.SSAValue = info.argument(types.Float)
    y: ir.SSAValue = info.argument(types.Float)
    result: ir.ResultValue = info.result(types.Float)
300
+
301
+
302
@statement(dialect=dialect)
class sin(ir.Statement):
    """Statement form of Python's `math.sin`.

    Operand: `x` (`Float`). Result: `result` (`Float`).
    Traits: `Pure`, `FromPythonCall`.
    """

    traits = frozenset([ir.FromPythonCall(), ir.Pure()])
    name = "sin"

    x: ir.SSAValue = info.argument(types.Float)
    result: ir.ResultValue = info.result(types.Float)
310
+
311
+
312
@statement(dialect=dialect)
class sinh(ir.Statement):
    """Statement form of Python's `math.sinh`.

    Operand: `x` (`Float`). Result: `result` (`Float`).
    Traits: `Pure`, `FromPythonCall`.
    """

    traits = frozenset([ir.FromPythonCall(), ir.Pure()])
    name = "sinh"

    x: ir.SSAValue = info.argument(types.Float)
    result: ir.ResultValue = info.result(types.Float)
320
+
321
+
322
@statement(dialect=dialect)
class sqrt(ir.Statement):
    """Statement form of Python's `math.sqrt`.

    Operand: `x` (`Float`). Result: `result` (`Float`).
    Traits: `Pure`, `FromPythonCall`.
    """

    traits = frozenset([ir.FromPythonCall(), ir.Pure()])
    name = "sqrt"

    x: ir.SSAValue = info.argument(types.Float)
    result: ir.ResultValue = info.result(types.Float)
330
+
331
+
332
@statement(dialect=dialect)
class tan(ir.Statement):
    """Statement form of Python's `math.tan`.

    Operand: `x` (`Float`). Result: `result` (`Float`).
    Traits: `Pure`, `FromPythonCall`.
    """

    traits = frozenset([ir.FromPythonCall(), ir.Pure()])
    name = "tan"

    x: ir.SSAValue = info.argument(types.Float)
    result: ir.ResultValue = info.result(types.Float)
340
+
341
+
342
@statement(dialect=dialect)
class tanh(ir.Statement):
    """Statement form of Python's `math.tanh`.

    Operand: `x` (`Float`). Result: `result` (`Float`).
    Traits: `Pure`, `FromPythonCall`.
    """

    traits = frozenset([ir.FromPythonCall(), ir.Pure()])
    name = "tanh"

    x: ir.SSAValue = info.argument(types.Float)
    result: ir.ResultValue = info.result(types.Float)
350
+
351
+
352
@statement(dialect=dialect)
class trunc(ir.Statement):
    """trunc statement, wrapping the math.trunc function.

    Takes a `Float` operand `x`. Like `math.trunc`, it produces an `Int`
    `result`.
    """

    name = "trunc"
    traits = frozenset({ir.Pure(), ir.FromPythonCall()})
    x: ir.SSAValue = info.argument(types.Float)
    # `math.trunc` returns an `int` (it truncates to an Integral), so the
    # result is declared `Int` rather than `Float`.
    # NOTE(review): this module is generated; apply the same change to the
    # generator (presumably kirin/dialects/math/_gen.py) so it survives
    # regeneration.
    result: ir.ResultValue = info.result(types.Int)
360
+
361
+
362
@statement(dialect=dialect)
class ulp(ir.Statement):
    """Statement form of Python's `math.ulp`.

    Operand: `x` (`Float`). Result: `result` (`Float`).
    Traits: `Pure`, `FromPythonCall`.
    """

    traits = frozenset([ir.FromPythonCall(), ir.Pure()])
    name = "ulp"

    x: ir.SSAValue = info.argument(types.Float)
    result: ir.ResultValue = info.result(types.Float)
@@ -0,0 +1,139 @@
1
+ """Module dialect provides a simple module
2
+ that is roughly a list of function statements.
3
+
4
+ This dialect is required when compiling a function, together with all of
5
+ its callee functions, into lower-level IR.
6
+ """
7
+
8
+ from kirin import ir, types, interp
9
+ from kirin.decl import info, statement
10
+ from kirin.print import Printer
11
+ from kirin.analysis import TypeInference
12
+ from kirin.exceptions import VerificationError
13
+
14
+ from ._pprint_helper import pprint_calllike
15
+
16
# The "module" dialect. The statements and method tables below attach
# themselves to it via `@statement(dialect=dialect)` and `@dialect.register`.
dialect = ir.Dialect("module")
17
+
18
+
19
@statement(dialect=dialect)
class Module(ir.Statement):
    """A named module statement whose body region holds its member statements.

    `sym_name` is the module's symbol name. `entry` is a string attribute;
    presumably it names the entry symbol (confirm). `body` is declared with
    `multi=False`.
    """

    traits = frozenset(
        {ir.SymbolOpInterface(), ir.SymbolTable(), ir.IsolatedFromAbove()}
    )

    sym_name: str = info.attribute()
    entry: str = info.attribute()
    body: ir.Region = info.region(multi=False)
27
+
28
+
29
@statement(dialect=dialect)
class Invoke(ir.Statement):
    """A call to a function identified by its symbol name.

    Note:
        This statement is here for completeness, for interpretation,
        it is recommended to rewrite this statement into a `func.Invoke`
        after looking up the symbol table.
    """

    callee: str = info.attribute()
    inputs: tuple[ir.SSAValue, ...] = info.argument()
    kwargs: tuple[str, ...] = info.attribute()
    result: ir.ResultValue = info.result()

    def print_impl(self, printer: Printer) -> None:
        pprint_calllike(self, self.callee, printer)

    def verify(self) -> None:
        """Check the invocation for internal consistency.

        `callee` is only a symbol *name* (a `str`), so the callee's signature
        is not available here. Matching argument names or counts against the
        callee needs a symbol-table lookup and is not done in this method.
        The previous checks treated the name string as a method: a substring
        test on `callee` and a comparison against `len(callee)`. That raised
        spurious errors.

        Raises:
            VerificationError: if a keyword name appears more than once, or if
                there are more keyword names than input values.
        """
        seen: set[str] = set()
        for name in self.kwargs:
            if name in seen:
                raise VerificationError(
                    self,
                    f"duplicate keyword argument {name} in call to {self.callee}",
                )
            seen.add(name)

        # every keyword name must be bound to one of the input values
        if len(self.kwargs) > len(self.inputs):
            raise VerificationError(
                self,
                f"got {len(self.kwargs)} keyword names but only "
                f"{len(self.inputs)} arguments in call to {self.callee}",
            )
61
+
62
+
63
@dialect.register
class Concrete(interp.MethodTable):

    @interp.impl(Module)
    def interp_Module(
        self, interp: interp.Interpreter, frame: interp.Frame, stmt: Module
    ):
        """Register each symbol-defining statement of the module body.

        Statements in the first block of `stmt.body` that carry
        `SymbolOpInterface` are stored in the interpreter's `symbol_table`
        under their symbol name.
        """
        for stmt_ in stmt.body.blocks[0].stmts:
            # Query the child statement, not the module: `Module` itself always
            # has `SymbolOpInterface`, so checking `stmt` made the guard a no-op.
            if (trait := stmt_.get_trait(ir.SymbolOpInterface)) is not None:
                interp.symbol_table[trait.get_sym_name(stmt_).data] = stmt_
        return ()

    @interp.impl(Invoke)
    def interp_Invoke(
        self, interpreter: interp.Interpreter, frame: interp.Frame, stmt: Invoke
    ):
        """Call the symbol named by `stmt.callee` with the invoke's inputs.

        Raises:
            InterpreterError: if the symbol is unknown or is not callable.
        """
        callee = interpreter.symbol_table.get(stmt.callee)
        if callee is None:
            raise interp.InterpreterError(f"symbol {stmt.callee} not found")

        trait = callee.get_trait(ir.CallableStmtInterface)
        if trait is None:
            raise interp.InterpreterError(
                f"{stmt.callee} is not callable, got {callee.__class__.__name__}"
            )

        body = trait.get_callable_region(callee)
        mt = ir.Method(
            mod=None,
            py_func=None,
            sym_name=stmt.callee,
            arg_names=[
                arg.name or str(idx) for idx, arg in enumerate(body.blocks[0].args)
            ],
            dialects=interpreter.dialects,
            # The method's code is the callable definition that was looked up,
            # not this Invoke statement (which has no callable region).
            code=callee,
        )
        # Method-table impls return a tuple of result values (as the other
        # impls in this module do); Invoke has exactly one result.
        return (interpreter.run_method(mt, frame.get_values(stmt.inputs)),)
101
+
102
+
103
@dialect.register(key="typeinfer")
class TypeInfer(interp.MethodTable):

    @interp.impl(Module)
    def typeinfer_Module(
        self, interp: TypeInference, frame: interp.Frame, stmt: Module
    ):
        """Record each symbol-defining statement of the module body in the
        type-inference symbol table."""
        for stmt_ in stmt.body.blocks[0].stmts:
            # Query the child statement, not the module: `Module` itself always
            # has `SymbolOpInterface`, so checking `stmt` made the guard a no-op.
            if (trait := stmt_.get_trait(ir.SymbolOpInterface)) is not None:
                interp.symbol_table[trait.get_sym_name(stmt_).data] = stmt_
        return ()

    @interp.impl(Invoke)
    def typeinfer_Invoke(
        self, interp: TypeInference, frame: interp.Frame, stmt: Invoke
    ):
        """Infer the result type of a by-name invocation.

        Returns `Bottom` when the symbol is unknown or not callable. Otherwise
        it wraps the callee in a temporary `ir.Method`, runs inference on it,
        and returns the inferred type.
        """
        callee = interp.symbol_table.get(stmt.callee)
        if callee is None:
            return (types.Bottom,)

        trait = callee.get_trait(ir.CallableStmtInterface)
        if trait is None:
            return (types.Bottom,)

        body = trait.get_callable_region(callee)
        mt = ir.Method(
            mod=None,
            py_func=None,
            sym_name=stmt.callee,
            arg_names=[
                arg.name or str(idx) for idx, arg in enumerate(body.blocks[0].args)
            ],
            dialects=interp.dialects,
            # The method's code is the callable definition that was looked up,
            # not this Invoke statement (which has no callable region).
            code=callee,
        )
        # `callee.results` are the results of the definition statement itself,
        # not the value this call returns, so use the inferred call type.
        # NOTE(review): assumes `run_method` returns the inferred return type as
        # a single lattice value; confirm against TypeInference.run_method.
        return (interp.run_method(mt, mt.arg_types),)
@@ -0,0 +1,40 @@
1
+ """Python dialects module.
2
+
3
+ This module contains a set of dialects that represent
4
+ different subsets of the Python language. The dialects
5
+ are designed to be used in a union to represent the
6
+ entire Python language.
7
+ """
8
+
9
+ from . import (
10
+ cmp as cmp,
11
+ len as len,
12
+ attr as attr,
13
+ base as base,
14
+ list as list,
15
+ binop as binop,
16
+ range as range,
17
+ slice as slice,
18
+ tuple as tuple,
19
+ unary as unary,
20
+ assign as assign,
21
+ boolop as boolop,
22
+ unpack as unpack,
23
+ builtin as builtin,
24
+ constant as constant,
25
+ indexing as indexing,
26
+ iterable as iterable,
27
+ )
28
+ from .len import Len as Len
29
+ from .attr import GetAttr as GetAttr
30
+ from .range import Range as Range
31
+ from .slice import Slice as Slice
32
+ from .assign import Alias as Alias, SetItem as SetItem
33
+ from .boolop import Or as Or, And as And
34
+ from .builtin import Abs as Abs, Sum as Sum
35
+ from .constant import Constant as Constant
36
+ from .indexing import GetItem as GetItem, PyGetItemLike as PyGetItemLike
37
+ from .cmp.stmts import * # noqa: F403
38
+ from .list.stmts import Append as Append
39
+ from .binop.stmts import * # noqa: F403
40
+ from .unary.stmts import * # noqa: F403
@@ -0,0 +1,91 @@
1
+ """Assertion dialect for Python.
2
+
3
+ This module contains the dialect for the Python `assert` statement, including:
4
+
5
+ - The `Assert` statement class.
6
+ - The lowering pass for the `assert` statement.
7
+ - The concrete implementation of the `assert` statement.
8
+ - The type inference implementation of the `assert` statement.
9
+ - The Julia emitter for the `assert` statement.
10
+
11
+ This dialect maps `ast.Assert` nodes to the `Assert` statement.
12
+ """
13
+
14
+ import ast
15
+
16
+ from kirin import ir, types, interp, lowering
17
+ from kirin.decl import info, statement
18
+ from kirin.emit import EmitStrFrame, julia
19
+ from kirin.print import Printer
20
+
21
# The "py.assert" dialect. `Assert` and the method tables in this module attach
# themselves to it via `@statement(dialect=dialect)` and `@dialect.register`.
dialect = ir.Dialect("py.assert")
22
+
23
+
24
@statement(dialect=dialect)
class Assert(ir.Statement):
    """IR form of Python's `assert condition, message` statement."""

    condition: ir.SSAValue
    message: ir.SSAValue = info.argument(types.String)

    def print_impl(self, printer: Printer) -> None:
        # statement name in keyword style, then the operands
        with printer.rich(style="keyword"):
            printer.print_name(self)
        printer.plain_print(" ")
        printer.print(self.condition)
        if not self.message:
            return
        printer.plain_print(", ")
        printer.print(self.message)
39
+
40
+
41
@dialect.register
class Lowering(lowering.FromPythonAST):

    def lower_Assert(
        self, state: lowering.LoweringState, node: ast.Assert
    ) -> lowering.Result:
        """Lower `assert test[, msg]` into an `Assert` statement.

        Without an explicit message, an empty-string `Constant` is appended and
        its result is used as the message, so `Assert` always gets one.
        """
        # Imported locally, as in the original; presumably to avoid an import
        # cycle between py dialect modules (confirm).
        from kirin.dialects.py.constant import Constant

        condition = state.visit(node.test).expect_one()
        if node.msg is not None:
            msg_value = state.visit(node.msg).expect_one()
        else:
            msg_value = state.append_stmt(Constant("")).result
        state.append_stmt(Assert(condition=condition, message=msg_value))
        return lowering.Result()
57
+
58
+
59
@dialect.register
class Concrete(interp.MethodTable):

    @interp.impl(Assert)
    def assert_stmt(
        self, interp_: interp.Interpreter, frame: interp.Frame, stmt: Assert
    ):
        """Evaluate an `Assert` statement.

        Python's `assert` tests *truthiness*, so any truthy condition value
        (e.g. `1`, a non-empty list, `numpy.bool_(True)`) passes, not only the
        `True` singleton.

        Raises:
            WrapException: wrapping an `AssertionError` when the condition is
                falsy.
        """
        if frame.get(stmt.condition):
            return ()

        # NOTE(review): `stmt.message` is an SSA value and lowering always
        # supplies one (an empty-string constant when absent), so the fallback
        # branch is likely unreachable; confirm.
        if stmt.message:
            raise interp.WrapException(AssertionError(frame.get(stmt.message)))
        else:
            raise interp.WrapException(AssertionError("Assertion failed"))
73
+
74
+
75
@dialect.register(key="typeinfer")
class TypeInfer(interp.MethodTable):

    @interp.impl(Assert)
    def assert_stmt(self, interp, frame, stmt: Assert):
        """`Assert` declares no results, so there is nothing to infer.

        Returns an empty tuple, like the concrete and Julia-emit
        implementations. The previous `(types.Bottom,)` supplied one value for
        zero result slots.
        """
        return ()
81
+
82
+
83
@dialect.register(key="emit.julia")
class EmitJulia(interp.MethodTable):

    @interp.impl(Assert)
    def emit_assert(self, interp: julia.EmitJulia, frame: EmitStrFrame, stmt: Assert):
        """Write a Julia `@assert <condition> <message>` line for `stmt`."""
        cond_text = frame.get(stmt.condition)
        msg_text = frame.get(stmt.message)
        interp.writeln(frame, f"@assert {cond_text} {msg_text}")
        return ()