onnxslim 0.1.80__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (65)
  1. onnxslim/__init__.py +16 -0
  2. onnxslim/__main__.py +4 -0
  3. onnxslim/argparser.py +215 -0
  4. onnxslim/cli/__init__.py +1 -0
  5. onnxslim/cli/_main.py +180 -0
  6. onnxslim/core/__init__.py +219 -0
  7. onnxslim/core/optimization/__init__.py +146 -0
  8. onnxslim/core/optimization/dead_node_elimination.py +151 -0
  9. onnxslim/core/optimization/subexpression_elimination.py +76 -0
  10. onnxslim/core/optimization/weight_tying.py +59 -0
  11. onnxslim/core/pattern/__init__.py +249 -0
  12. onnxslim/core/pattern/elimination/__init__.py +5 -0
  13. onnxslim/core/pattern/elimination/concat.py +61 -0
  14. onnxslim/core/pattern/elimination/reshape.py +77 -0
  15. onnxslim/core/pattern/elimination/reshape_as.py +64 -0
  16. onnxslim/core/pattern/elimination/slice.py +108 -0
  17. onnxslim/core/pattern/elimination/unsqueeze.py +92 -0
  18. onnxslim/core/pattern/fusion/__init__.py +8 -0
  19. onnxslim/core/pattern/fusion/concat_reshape.py +50 -0
  20. onnxslim/core/pattern/fusion/convadd.py +70 -0
  21. onnxslim/core/pattern/fusion/convbn.py +86 -0
  22. onnxslim/core/pattern/fusion/convmul.py +69 -0
  23. onnxslim/core/pattern/fusion/gelu.py +47 -0
  24. onnxslim/core/pattern/fusion/gemm.py +330 -0
  25. onnxslim/core/pattern/fusion/padconv.py +89 -0
  26. onnxslim/core/pattern/fusion/reduce.py +67 -0
  27. onnxslim/core/pattern/registry.py +28 -0
  28. onnxslim/misc/__init__.py +0 -0
  29. onnxslim/misc/tabulate.py +2681 -0
  30. onnxslim/third_party/__init__.py +0 -0
  31. onnxslim/third_party/_sympy/__init__.py +0 -0
  32. onnxslim/third_party/_sympy/functions.py +205 -0
  33. onnxslim/third_party/_sympy/numbers.py +397 -0
  34. onnxslim/third_party/_sympy/printers.py +491 -0
  35. onnxslim/third_party/_sympy/solve.py +172 -0
  36. onnxslim/third_party/_sympy/symbol.py +102 -0
  37. onnxslim/third_party/onnx_graphsurgeon/__init__.py +15 -0
  38. onnxslim/third_party/onnx_graphsurgeon/exporters/__init__.py +1 -0
  39. onnxslim/third_party/onnx_graphsurgeon/exporters/base_exporter.py +33 -0
  40. onnxslim/third_party/onnx_graphsurgeon/exporters/onnx_exporter.py +432 -0
  41. onnxslim/third_party/onnx_graphsurgeon/graph_pattern/__init__.py +4 -0
  42. onnxslim/third_party/onnx_graphsurgeon/graph_pattern/graph_pattern.py +466 -0
  43. onnxslim/third_party/onnx_graphsurgeon/importers/__init__.py +1 -0
  44. onnxslim/third_party/onnx_graphsurgeon/importers/base_importer.py +33 -0
  45. onnxslim/third_party/onnx_graphsurgeon/importers/onnx_importer.py +558 -0
  46. onnxslim/third_party/onnx_graphsurgeon/ir/__init__.py +0 -0
  47. onnxslim/third_party/onnx_graphsurgeon/ir/function.py +274 -0
  48. onnxslim/third_party/onnx_graphsurgeon/ir/graph.py +1575 -0
  49. onnxslim/third_party/onnx_graphsurgeon/ir/node.py +266 -0
  50. onnxslim/third_party/onnx_graphsurgeon/ir/tensor.py +504 -0
  51. onnxslim/third_party/onnx_graphsurgeon/logger/__init__.py +1 -0
  52. onnxslim/third_party/onnx_graphsurgeon/logger/logger.py +261 -0
  53. onnxslim/third_party/onnx_graphsurgeon/util/__init__.py +0 -0
  54. onnxslim/third_party/onnx_graphsurgeon/util/exception.py +20 -0
  55. onnxslim/third_party/onnx_graphsurgeon/util/misc.py +252 -0
  56. onnxslim/third_party/symbolic_shape_infer.py +3273 -0
  57. onnxslim/utils.py +794 -0
  58. onnxslim/version.py +1 -0
  59. onnxslim-0.1.80.dist-info/METADATA +207 -0
  60. onnxslim-0.1.80.dist-info/RECORD +65 -0
  61. onnxslim-0.1.80.dist-info/WHEEL +5 -0
  62. onnxslim-0.1.80.dist-info/entry_points.txt +2 -0
  63. onnxslim-0.1.80.dist-info/licenses/LICENSE +21 -0
  64. onnxslim-0.1.80.dist-info/top_level.txt +1 -0
  65. onnxslim-0.1.80.dist-info/zip-safe +1 -0
@@ -0,0 +1,491 @@
1
+ from __future__ import annotations
2
+
3
+ import sys
4
+
5
+ import sympy
6
+ from sympy.printing.precedence import PRECEDENCE, precedence
7
+ from sympy.printing.str import StrPrinter
8
+
9
# C/C++ integer type that CppPrinter uses for integer-valued expressions,
# together with the range of values representable in it (signed 64-bit).
INDEX_TYPE = "int64_t"
INDEX_TYPE_MAX = 2**63 - 1
INDEX_TYPE_MIN = -(2**63)
12
+
13
+
14
class ExprPrinter(StrPrinter):
    """Printer for sympy expressions with the rules shared by C/C++ and Python.

    Target-specific printers subclass this and override the hooks that have no
    common spelling. Those hooks deliberately raise ``NotImplementedError``
    here: otherwise SymPy's default printing would silently emit text such as
    ``ToFloat(...)``, whereas the error message names the printer class that
    is missing the override.
    """

    # Use a custom print-method hook name instead of StrPrinter's default so
    # that the ``_print_FloorDiv``-style methods of these printers are used.
    printmethod = "_torch_sympystr"

    # ------------------------------------------------------------------
    # helpers
    # ------------------------------------------------------------------

    def _not_implemented(self, hook: str) -> NotImplementedError:
        """Build the error raised by a hook that this printer cannot render."""
        return NotImplementedError(f"{hook} not implemented for {type(self)}")

    # ------------------------------------------------------------------
    # operators shared by all targets
    # ------------------------------------------------------------------

    def _print_Mul(self, expr: sympy.Expr) -> str:
        level = precedence(expr)
        return self.stringify(expr.args, "*", level)

    def _print_Not(self, expr: sympy.Expr) -> str:
        operand = self._print(expr.args[0])
        return "not (" + operand + ")"

    def _print_Add(self, expr: sympy.Expr, order: str | None = None) -> str:
        # ``order`` is unused: terms are printed in their stored order.
        level = precedence(expr)
        return self.stringify(expr.args, " + ", level)

    def _print_Relational(self, expr: sympy.Expr) -> str:
        separator = f" {expr.rel_op} "
        return self.stringify(expr.args, separator, precedence(expr))

    def _print_BitwiseFn_bitwise_and(self, expr: sympy.Expr) -> str:
        level = PRECEDENCE["BitwiseAnd"]
        return self.stringify(expr.args, " & ", level)

    def _print_BitwiseFn_bitwise_or(self, expr: sympy.Expr) -> str:
        level = PRECEDENCE["BitwiseOr"]
        return self.stringify(expr.args, " | ", level)

    # NB: safe to share between targets because Mod is only defined for
    # positive numbers, where C/C++ and Python ``%`` behave the same.
    def _print_Mod(self, expr: sympy.Expr) -> str:
        level = PRECEDENCE["Atom"] - 0.5
        return self.stringify(expr.args, " % ", level)

    def _print_FloatTrueDiv(self, expr: sympy.Expr) -> str:
        quotient = self.stringify(expr.args, " / ", PRECEDENCE["Atom"] - 0.5)
        return "(" + quotient + ")"

    def _print_CleanDiv(self, expr: sympy.Expr) -> str:
        # CleanDiv is rendered through the target's FloorDiv hook.
        return self._print_FloorDiv(expr)

    def _print_Identity(self, expr: sympy.Expr) -> str:
        # Transparent wrapper: only the wrapped argument is printed.
        return self._print(expr.args[0])

    def _print_Float(self, expr: sympy.Expr) -> str:
        if expr._prec != 53:
            # Precisions other than double are not used in pytorch; keep
            # SymPy's own rendering for them.
            return str(expr)
        # 53 mantissa bits == IEEE-754 double. SymPy prints 15 significant
        # digits by default, but 17 are needed to round-trip a double.
        return str(sympy.Float(expr, dps=17))

    # SymPy folds ``x * x`` into ``Pow(x, 2)`` without being asked, so Pow must
    # be printable; it is expanded back into repeated multiplication. Floats
    # must use FloatPow and symbolic exponents PowByNatural, so the exponent
    # here is required to be a non-negative integer.
    def _print_Pow(self, expr: sympy.Expr) -> str:
        base, exponent = expr.args
        assert exponent == int(exponent), exponent
        count = int(exponent)
        assert count >= 0
        if count == 0:
            return "1"
        return self.stringify([base] * count, "*", PRECEDENCE["Mul"])

    # ------------------------------------------------------------------
    # hooks every concrete target has to provide
    # ------------------------------------------------------------------

    def _print_ToFloat(self, expr: sympy.Expr) -> str:
        raise self._not_implemented("_print_ToFloat")

    def _print_Infinity(self, expr: sympy.Expr) -> str:
        raise self._not_implemented("_print_Infinity")

    def _print_NegativeInfinity(self, expr: sympy.Expr) -> str:
        raise self._not_implemented("_print_NegativeInfinity")

    def _print_FloorDiv(self, expr: sympy.Expr) -> str:
        raise self._not_implemented("_print_FloorDiv")

    def _print_PythonMod(self, expr: sympy.Expr) -> str:
        raise self._not_implemented("_print_PythonMod")

    def _print_IntTrueDiv(self, expr: sympy.Expr) -> str:
        raise self._not_implemented("_print_IntTrueDiv")

    def _print_PowByNatural(self, expr: sympy.Expr) -> str:
        raise self._not_implemented("_print_PowByNatural")

    def _print_FloatPow(self, expr: sympy.Expr) -> str:
        raise self._not_implemented("_print_FloatPow")

    def _print_TruncToInt(self, expr: sympy.Expr) -> str:
        raise self._not_implemented("_print_TruncToInt")

    def _print_RoundToInt(self, expr: sympy.Expr) -> str:
        raise self._not_implemented("_print_RoundToInt")

    def _print_RoundDecimal(self, expr: sympy.Expr) -> str:
        raise self._not_implemented("_print_RoundDecimal")

    # NB: Some float operations are INTENTIONALLY not implemented for
    # printers. You can implement them as a quick unblock, but it is better
    # to ask yourself why this computation is not done in the Tensor
    # universe instead.

    def _print_TruncToFloat(self, expr: sympy.Expr) -> str:
        raise self._not_implemented("_print_TruncToFloat")
+
124
+
125
class PythonPrinter(ExprPrinter):
    """Render sympy expressions as Python source using ``math`` / builtins."""

    def _unary_call(self, func: str, expr: sympy.Expr) -> str:
        """Print the single-argument ``expr`` as ``func(<argument>)``."""
        assert len(expr.args) == 1
        return f"{func}({self._print(expr.args[0])})"

    def _print_ToFloat(self, expr: sympy.Expr) -> str:
        # NB: torch.sym_float rather than float: this printer is used for cache
        # serialization, and cache guards are evaluated with SymInt to
        # propagate guards to the parent ShapeEnv. That costs runtime for
        # guards involving float; if it is too much overhead, use two printers
        # (one for inputs guaranteed to be int, one for possible SymInt).
        # sym_min/sym_max share this problem but were left as is.
        # See https://github.com/pytorch/pytorch/issues/142507 for context.
        return self._unary_call("torch.sym_float", expr)

    # ---------------------------------------------------------------
    # boolean connectives
    # ---------------------------------------------------------------

    def _print_And(self, expr: sympy.Expr) -> str:
        level = precedence(expr)
        return self.stringify(expr.args, " and ", level)

    def _print_Or(self, expr: sympy.Expr) -> str:
        level = precedence(expr)
        return self.stringify(expr.args, " or ", level)

    # ---------------------------------------------------------------
    # integer arithmetic
    # ---------------------------------------------------------------

    def _print_ModularIndexing(self, expr: sympy.Expr) -> str:
        # ModularIndexing(x, div, mod) -> ((x // div) % mod); the floor
        # division is omitted when div prints as "1".
        x, div, mod = (self.parenthesize(arg, PRECEDENCE["Atom"] - 0.5) for arg in expr.args)
        dividend = x if div == "1" else f"({x} // {div})"
        return f"({dividend} % {mod})"

    def _print_Infinity(self, expr: sympy.Expr) -> str:
        return "math.inf"

    def _print_NegativeInfinity(self, expr: sympy.Expr) -> str:
        return "-math.inf"

    # WARNING: this is dangerous for Triton, which has C-style modulus
    def _print_PythonMod(self, expr: sympy.Expr) -> str:
        level = PRECEDENCE["Atom"] - 0.5
        return self.stringify(expr.args, " % ", level)

    # WARNING: this is dangerous for Triton, which has C-style modulus
    def _print_FloorDiv(self, expr: sympy.Expr) -> str:
        numerator, denominator = (self.parenthesize(arg, PRECEDENCE["Atom"] - 0.5) for arg in expr.args)
        return f"{numerator} // {denominator}"

    # WARNING: this is dangerous for Triton; when lhs, rhs > 2**53 Python
    # uses a special algorithm
    def _print_IntTrueDiv(self, expr: sympy.Expr) -> str:
        level = PRECEDENCE["Atom"] - 0.5
        return self.stringify(expr.args, " / ", level)

    # ---------------------------------------------------------------
    # powers and roots
    # ---------------------------------------------------------------

    def _helper_sqrt(self, expr: sympy.Expr) -> str:
        # ``expr`` is the radicand itself, not the sqrt node.
        return "math.sqrt(" + self._print(expr) + ")"

    def _print_OpaqueUnaryFn_sqrt(self, expr: sympy.Expr) -> str:
        radicand = expr.args[0]
        return self._helper_sqrt(radicand)

    def _print_FloatPow(self, expr: sympy.Expr) -> str:
        level = PRECEDENCE["Pow"]
        return self.stringify(expr.args, " ** ", level)

    # TODO: Not sure this works with Triton, even when base/exp are integral
    def _print_PowByNatural(self, expr: sympy.Expr) -> str:
        level = PRECEDENCE["Pow"]
        return self.stringify(expr.args, " ** ", level)

    # ---------------------------------------------------------------
    # rounding and absolute value
    # ---------------------------------------------------------------

    def _print_floor(self, expr: sympy.Expr) -> str:
        return self._unary_call("math.floor", expr)

    def _print_FloorToInt(self, expr: sympy.Expr) -> str:
        return self._unary_call("math.floor", expr)

    def _print_TruncToInt(self, expr: sympy.Expr) -> str:
        # int() would behave the same for floats.
        return self._unary_call("math.trunc", expr)

    def _print_ceiling(self, expr: sympy.Expr) -> str:
        return self._unary_call("math.ceil", expr)

    def _print_CeilToInt(self, expr: sympy.Expr) -> str:
        return self._unary_call("math.ceil", expr)

    def _print_Abs(self, expr: sympy.Expr) -> str:
        return self._unary_call("abs", expr)

    # NB: any type promotion is expected to be explicit in the sympy
    # expression already, so Python's non-promoting max/min is fine.
    def _print_Max(self, expr: sympy.Expr) -> str:
        assert len(expr.args) >= 2
        return "max(" + ", ".join(self._print(arg) for arg in expr.args) + ")"

    def _print_Min(self, expr: sympy.Expr) -> str:
        assert len(expr.args) >= 2
        return "min(" + ", ".join(self._print(arg) for arg in expr.args) + ")"

    # ---------------------------------------------------------------
    # transcendental functions (all map onto the ``math`` module)
    # ---------------------------------------------------------------

    def _print_OpaqueUnaryFn_cos(self, expr: sympy.Expr) -> str:
        return self._unary_call("math.cos", expr)

    def _print_OpaqueUnaryFn_cosh(self, expr: sympy.Expr) -> str:
        return self._unary_call("math.cosh", expr)

    def _print_OpaqueUnaryFn_acos(self, expr: sympy.Expr) -> str:
        return self._unary_call("math.acos", expr)

    def _print_OpaqueUnaryFn_sin(self, expr: sympy.Expr) -> str:
        return self._unary_call("math.sin", expr)

    def _print_OpaqueUnaryFn_sinh(self, expr: sympy.Expr) -> str:
        return self._unary_call("math.sinh", expr)

    def _print_OpaqueUnaryFn_asin(self, expr: sympy.Expr) -> str:
        return self._unary_call("math.asin", expr)

    def _print_OpaqueUnaryFn_tan(self, expr: sympy.Expr) -> str:
        return self._unary_call("math.tan", expr)

    def _print_OpaqueUnaryFn_tanh(self, expr: sympy.Expr) -> str:
        return self._unary_call("math.tanh", expr)

    def _print_OpaqueUnaryFn_atan(self, expr: sympy.Expr) -> str:
        return self._unary_call("math.atan", expr)

    def _print_OpaqueUnaryFn_log2(self, expr: sympy.Expr) -> str:
        return self._unary_call("math.log2", expr)

    def _print_RoundToInt(self, expr: sympy.Expr) -> str:
        return self._unary_call("round", expr)

    def _print_RoundDecimal(self, expr: sympy.Expr) -> str:
        # RoundDecimal(number, ndigits) -> round(number, ndigits); ndigits
        # must be a literal sympy Integer.
        assert len(expr.args) == 2
        number, ndigits = expr.args
        assert isinstance(ndigits, sympy.Integer)
        return f"round({self._print(number)}, {ndigits})"
+
274
+
275
class CppPrinter(ExprPrinter):
    """Render sympy expressions as C++ source code.

    Integer-valued results are expressed in ``INDEX_TYPE`` (``int64_t``),
    floor division / Python-style modulus go through ``c10::`` helpers, and
    math functions map onto their ``std::`` counterparts.
    """

    def _print_Integer(self, expr: sympy.Expr) -> str:
        # NOTE(review): the suffix presumably makes the literal's type match
        # int64_t on each platform (long long on macOS/Windows, long elsewhere).
        suffix = "LL" if sys.platform in ["darwin", "win32"] else "L"
        i = int(expr)
        if i > INDEX_TYPE_MAX or i < INDEX_TYPE_MIN:
            raise OverflowError(f"{i} too big to convert to {INDEX_TYPE}")
        elif i == INDEX_TYPE_MIN:
            assert i == (-1) << 63
            # Writing -9223372036854775808L makes the value overflow
            # as it is parsed as -(9223372036854775808L) by the C/C++ compiler
            return f"(-1{suffix} << 63)"
        return f"{i}{suffix}"

    def _print_Where(self, expr: sympy.Expr) -> str:
        # Where(cond, a, b) -> C ternary, operands parenthesized as needed.
        c, p, q = (self.parenthesize(arg, PRECEDENCE["Atom"] - 0.5) for arg in expr.args)
        return f"{c} ? {p} : {q}"

    def _print_ModularIndexing(self, expr: sympy.Expr) -> str:
        # ModularIndexing(x, div, mod) -> floor(x / div) % mod, where the
        # floor division is skipped when div == 1.
        x, div, mod = expr.args
        x = self.doprint(x)
        if div != 1:
            div = self.doprint(div)
            if expr.is_integer:
                x = f"c10::div_floor_integer(static_cast<int64_t>({x}), static_cast<int64_t>({div}))"
            else:
                x = f"c10::div_floor_floating(static_cast<double>({x}), static_cast<double>({div}))"
        mod = self.doprint(mod)
        return f"(static_cast<{INDEX_TYPE}>({x}) % static_cast<{INDEX_TYPE}>({mod}))"

    def _print_FloorDiv(self, expr: sympy.Expr) -> str:
        # Integer-valued results use the integer helper, others the floating one.
        x, div = expr.args
        x = self.doprint(x)
        div = self.doprint(div)
        if expr.is_integer:
            return f"c10::div_floor_integer(static_cast<int64_t>({x}), static_cast<int64_t>({div}))"
        return f"c10::div_floor_floating(static_cast<double>({x}), static_cast<double>({div}))"

    def _print_floor(self, expr: sympy.Expr) -> str:
        assert len(expr.args) == 1
        r = f"std::floor({self._print(expr.args[0])})"
        return f"static_cast<{INDEX_TYPE}>({r})" if expr.is_integer else r

    def _print_FloorToInt(self, expr: sympy.Expr) -> str:
        assert len(expr.args) == 1
        r = f"std::floor({self._print(expr.args[0])})"
        return f"static_cast<{INDEX_TYPE}>({r})" if expr.is_integer else r

    def _print_TruncToInt(self, expr: sympy.Expr) -> str:
        assert len(expr.args) == 1
        r = f"std::trunc({self._print(expr.args[0])})"
        return f"static_cast<{INDEX_TYPE}>({r})"

    def _print_TruncToFloat(self, expr: sympy.Expr) -> str:
        assert len(expr.args) == 1
        return f"std::trunc({self._print(expr.args[0])})"

    def _print_ToFloat(self, expr: sympy.Expr) -> str:
        assert len(expr.args) == 1
        return f"static_cast<double>({self._print(expr.args[0])})"

    def _print_PythonMod(self, expr: sympy.Expr) -> str:
        # Python-style modulus is delegated to c10::div_mod.
        x, div = expr.args
        x = self.doprint(x)
        div = self.doprint(div)
        return f"c10::div_mod({x}, {div})"

    def _print_IntTrueDiv(self, expr: sympy.Expr) -> str:
        lhs, rhs = expr.args
        # TODO: This is only accurate up to 2**53
        return f"static_cast<double>({self._print(lhs)}) / static_cast<double>({self._print(rhs)})"

    # TODO: PowByNatural: we need to implement our own int-int pow. Do NOT
    # use std::pow, that operates on floats
    def _print_PowByNatural(self, expr: sympy.Expr) -> str:
        # Only the special case 2**x is implemented for now (as a shift).
        base, exp = expr.args
        if base == 2:
            return f"(1 << ({self._print(exp)}))"
        raise NotImplementedError(f"_print_PowByNatural not implemented for {type(self)}")

    def _print_FloatPow(self, expr: sympy.Expr) -> str:
        base, exp = expr.args
        return f"std::pow({self._print(base)}, {self._print(exp)})"

    def _print_Pow(self, expr: sympy.Expr) -> str:
        """Print ``base ** exp`` using floating-point constants for divisions.

        * ``exp == +/-0.5`` -> ``std::sqrt`` / ``1.0/std::sqrt``.
        * integer ``exp`` -> repeated multiplication, or ``1.0/(...)`` for
          negative exponents; integer-valued results are cast to INDEX_TYPE.
        * any other ``exp`` -> ``std::pow``.
        """
        base, exp = expr.args

        if exp == 0.5 or exp == -0.5:
            base_str = self._print(base)
            return f"std::sqrt({base_str})" if exp == 0.5 else f"1.0/std::sqrt({base_str})"
        if exp.is_integer:
            exp = int(exp)
            if exp > 0:
                r = self.stringify([base] * exp, "*", PRECEDENCE["Mul"])
            elif exp < -1:
                r = "1.0/(" + self.stringify([base] * abs(exp), "*", PRECEDENCE["Mul"]) + ")"
            elif exp == -1:
                # Parenthesize non-atomic bases: without it Pow(x + y, -1)
                # would print as "1.0/x + y", i.e. (1.0/x) + y.
                r = "1.0/" + self.parenthesize(base, PRECEDENCE["Atom"], strict=True)
            else:  # exp == 0
                r = "1.0"

            return f"static_cast<{INDEX_TYPE}>({r})" if expr.is_integer else r
        # TODO: float vs double
        # Print the base through this printer (not sympy's str()) so C++
        # rules such as integer suffixes and c10 division helpers apply.
        return f"std::pow({self._print(base)}, {float(exp)})"

    def _print_Rational(self, expr: sympy.Expr) -> str:
        # Uses float constants to perform FP div
        if expr.q == 1:
            r = f"{expr.p}"
        else:
            r = f"{expr.p}.0/{expr.q}.0"
        return f"static_cast<{INDEX_TYPE}>({r})" if expr.is_integer else r

    def _print_ceiling(self, expr: sympy.Expr) -> str:
        assert len(expr.args) == 1
        r = f"std::ceil({self._print(expr.args[0])})"
        return f"static_cast<{INDEX_TYPE}>({r})" if expr.is_integer else r

    def _print_CeilToInt(self, expr: sympy.Expr) -> str:
        assert len(expr.args) == 1
        r = f"std::ceil({self._print(expr.args[0])})"
        return f"static_cast<{INDEX_TYPE}>({r})" if expr.is_integer else r

    def _print_Min(self, expr: sympy.Expr) -> str:
        args = [self._print(a) for a in expr.args]
        if len(args) == 2:
            return f"std::min(static_cast<{INDEX_TYPE}>({args[0]}), static_cast<{INDEX_TYPE}>({args[1]}))"
        else:
            # Initializer list overload
            il = "{" + ", ".join(args) + "}"
            return f"std::min<{INDEX_TYPE}>({il})"

    def _print_Max(self, expr: sympy.Expr) -> str:
        args = [self._print(a) for a in expr.args]
        if len(args) == 2:
            return f"std::max(static_cast<{INDEX_TYPE}>({args[0]}), static_cast<{INDEX_TYPE}>({args[1]}))"
        else:
            # Initializer list overload
            il = "{" + ", ".join(args) + "}"
            return f"std::max<{INDEX_TYPE}>({il})"

    def _print_Abs(self, expr: sympy.Expr) -> str:
        assert len(expr.args) == 1
        return f"std::abs({self._print(expr.args[0])})"

    def _print_OpaqueUnaryFn_cos(self, expr: sympy.Expr) -> str:
        assert len(expr.args) == 1
        return f"std::cos({self._print(expr.args[0])})"

    def _print_OpaqueUnaryFn_cosh(self, expr: sympy.Expr) -> str:
        assert len(expr.args) == 1
        return f"std::cosh({self._print(expr.args[0])})"

    def _print_OpaqueUnaryFn_acos(self, expr: sympy.Expr) -> str:
        assert len(expr.args) == 1
        return f"std::acos({self._print(expr.args[0])})"

    def _print_OpaqueUnaryFn_sin(self, expr: sympy.Expr) -> str:
        assert len(expr.args) == 1
        return f"std::sin({self._print(expr.args[0])})"

    def _print_OpaqueUnaryFn_sinh(self, expr: sympy.Expr) -> str:
        assert len(expr.args) == 1
        return f"std::sinh({self._print(expr.args[0])})"

    def _print_OpaqueUnaryFn_asin(self, expr: sympy.Expr) -> str:
        assert len(expr.args) == 1
        return f"std::asin({self._print(expr.args[0])})"

    def _print_OpaqueUnaryFn_tan(self, expr: sympy.Expr) -> str:
        assert len(expr.args) == 1
        return f"std::tan({self._print(expr.args[0])})"

    def _print_OpaqueUnaryFn_tanh(self, expr: sympy.Expr) -> str:
        assert len(expr.args) == 1
        return f"std::tanh({self._print(expr.args[0])})"

    def _print_OpaqueUnaryFn_atan(self, expr: sympy.Expr) -> str:
        assert len(expr.args) == 1
        return f"std::atan({self._print(expr.args[0])})"

    def _print_OpaqueUnaryFn_sqrt(self, expr: sympy.Expr) -> str:
        return f"std::sqrt({self._print(expr.args[0])})"

    def _print_OpaqueUnaryFn_log2(self, expr: sympy.Expr) -> str:
        return f"std::log2({self._print(expr.args[0])})"

    def _print_RoundToInt(self, expr: sympy.Expr) -> str:
        assert len(expr.args) == 1
        # TODO: dispatch to llrint depending on index type
        return f"std::lrint({self._print(expr.args[0])})"

    def _print_RoundDecimal(self, expr: sympy.Expr) -> str:
        """Print ``round(number, ndigits)`` as a scaled ``std::nearbyint``.

        Raises:
            ValueError: for integer ``number`` (only reachable with negative
                ``ndigits``, which is not supported).
        """
        assert len(expr.args) == 2
        number, ndigits = expr.args
        if number.is_integer:
            # Rounding an integer to ndigits >= 0 is a no-op, which the sympy
            # function is presumably expected to fold away; hence only
            # negative ndigits should reach this point.
            assert ndigits < 0
            raise ValueError(
                f"For integer inputs, only non-negative ndigits are currently supported, but got {ndigits}."
            )
        number_str = self.parenthesize(number, PRECEDENCE["Mul"])
        return f"static_cast<double>(std::nearbyint(1e{ndigits} * {number_str}) * 1e{-ndigits})"

    def _print_BooleanTrue(self, expr: sympy.Expr) -> str:
        return "true"

    def _print_BooleanFalse(self, expr: sympy.Expr) -> str:
        return "false"

    def _print_Infinity(self, expr: sympy.Expr) -> str:
        return "std::numeric_limits<double>::infinity()"

    def _print_NegativeInfinity(self, expr: sympy.Expr) -> str:
        return f"-{self._print_Infinity(expr)}"
@@ -0,0 +1,172 @@
1
+ from __future__ import annotations
2
+
3
+ import logging
4
+
5
+ import sympy
6
+
7
+ from onnxslim.third_party._sympy.functions import FloorDiv
8
+
9
log = logging.getLogger(__name__)

# For each supported relational class, the class obtained by swapping its
# operands: ``a OP b`` is equivalent to ``b MIRROR[OP] a``.
_MIRROR_REL_OP: dict[type[sympy.Basic], type[sympy.Rel]] = {
    rel: mirrored
    for rel, mirrored in (
        (sympy.Eq, sympy.Eq),
        (sympy.Ne, sympy.Ne),
        (sympy.Ge, sympy.Le),
        (sympy.Gt, sympy.Lt),
        (sympy.Le, sympy.Ge),
        (sympy.Lt, sympy.Gt),
    )
}

# Relations whose direction flips when both sides are divided by a negative.
INEQUALITY_TYPES = (sympy.Gt, sympy.Ge, sympy.Lt, sympy.Le)
21
+
22
+
23
def mirror_rel_op(type: type) -> type[sympy.Rel] | None:
    """Return the relational class equivalent to ``type`` with swapped operands.

    Looks ``type`` up in ``_MIRROR_REL_OP``; returns ``None`` when it has no
    registered mirror.
    """
    return _MIRROR_REL_OP.get(type)
25
+
26
+
27
def try_solve(
    expr: sympy.Basic,
    thing: sympy.Basic,
    trials: int = 5,
    floordiv_inequality: bool = True,
) -> tuple[sympy.Rel, sympy.Expr] | None:
    """Try to rewrite the relation ``expr`` so that ``thing`` alone is on its left-hand side.

    Args:
        expr: relational expression to rewrite. It is handled only if it is a
            ``sympy.Rel`` whose type ``mirror_rel_op`` knows.
        thing: sub-expression to isolate. It must occur on exactly one side of
            ``expr``; when it is on the right-hand side the relation is
            mirrored first (``a < b`` -> ``b > a``).
        trials: maximum number of ``_try_isolate_lhs`` steps per orientation;
            iteration stops early once a step leaves the expression unchanged.
        floordiv_inequality: forwarded to ``_try_isolate_lhs``; enables
            converting a ``FloorDiv`` on the left-hand side into inequalities.

    Returns:
        ``(solved, rhs)`` where ``solved`` is the rewritten relation whose
        left-hand side is exactly ``thing`` and ``rhs`` is its right-hand side,
        or ``None`` if ``expr`` is unsupported, ``thing`` is on both or neither
        side, or ``thing`` could not be isolated.
    """
    mirror = mirror_rel_op(type(expr))

    # Only relational operations with a known mirror are supported (this also
    # keeps unexpected Rel subclasses out).
    if mirror is None or not isinstance(expr, sympy.Rel):
        log.debug("expression with unsupported type: %s", type(expr))
        return None

    on_left = expr.lhs.has(thing)
    on_right = expr.rhs.has(thing)

    # The rewriting steps only move terms from one side to the other, so they
    # cannot isolate 'thing' when it appears on both sides.
    if on_left and on_right:
        log.debug("thing (%s) found in both sides of expression: %s", thing, expr)
        return None

    # Orient the relation so that 'thing' is on the left-hand side.
    candidates = []
    if on_left:
        candidates.append(expr)
    if on_right:
        candidates.append(mirror(expr.rhs, expr.lhs))

    for candidate in candidates:
        if candidate is None:
            continue
        assert isinstance(candidate, sympy.Rel)

        current = candidate
        for _ in range(trials):
            rewritten = _try_isolate_lhs(current, thing, floordiv_inequality=floordiv_inequality)
            if rewritten == current:
                # Fixed point: further steps would not change anything.
                break
            current = rewritten  # type: ignore[assignment]

        if isinstance(current, sympy.Rel) and current.lhs == thing:
            log.debug("solved: %s ---> %s", expr, current)
            return current, current.rhs

    return None
95
+
96
+
97
def _try_isolate_lhs(e: sympy.Basic, thing: sympy.Basic, floordiv_inequality: bool) -> sympy.Basic:
    """Apply one round of rewriting that moves ``e`` towards ``thing <op> rhs``.

    The steps, in order:

    1. If ``e`` is a relation whose left-hand side is an ``Add``, every term
       not containing ``thing`` is subtracted from both sides.
    2. If the left-hand side is then a ``Mul``, both sides are divided by the
       product ``other`` of the factors not containing ``thing``, but only when
       that is valid:

       * for inequalities, ``other`` must be known to be strictly positive or
         strictly negative (a negative divisor mirrors the operator);
       * for ``Eq``/``Ne``, ``other`` or the right-hand side must be known to
         be nonzero.
    3. If ``floordiv_inequality`` is set and the left-hand side is
       ``FloorDiv(a, b)`` with ``b`` known positive and an integer right-hand
       side, the floor division is eliminated: ``Eq``/``Ne`` become an
       ``And``/``Or`` of two inequalities, the other relations a single one.

    Returns:
        The rewritten expression. It is structurally equal to ``e`` when no
        rule applied, and it may be non-relational (``And``/``Or``, or possibly
        a boolean atom if SymPy evaluates a rebuilt relation).
    """
    op = type(e)

    if isinstance(e, sympy.Rel):
        # Move any constants in the left-hand side to the right-hand side.
        lhs_not_thing = sum(a for a in e.lhs.args if not a.has(thing)) if isinstance(e.lhs, sympy.Add) else 0
        e = op(e.lhs - lhs_not_thing, e.rhs - lhs_not_thing)  # type: ignore[attr-defined]

    # Divide both sides by the factors that don't contain thing.
    if isinstance(e, sympy.Rel) and isinstance(e.lhs, sympy.Mul):
        lhs, rhs = e.args
        other = sympy.Mul(*[a for a in lhs.args if not a.has(thing)])

        if isinstance(e, INEQUALITY_TYPES):
            # The sign of 'other' decides whether the operator must be
            # mirrored, and dividing by zero is invalid. So require a sign that
            # is known to be strictly positive or strictly negative:
            # 'is_negative is False' alone would still admit other == 0.
            can_divide = bool(other.is_positive or other.is_negative)
        else:
            # Eq/Ne: dividing is valid when 'other' is known to be nonzero, or
            # when 'rhs' is *known* to be nonzero (for Eq this rules out
            # other == 0 at any solution). An unknown 'rhs.is_zero' (None) is
            # not treated as nonzero.
            can_divide = other.is_zero is False or rhs.is_zero is False

        if can_divide:
            # Divide both sides by 'other'.
            lhs = lhs / other
            rhs = rhs / other

            # If 'e' is an inequality and 'other' is negative, we have to
            # mirror the expression.
            if isinstance(e, INEQUALITY_TYPES) and other.is_negative:
                op = mirror_rel_op(op)  # type: ignore[assignment]

            assert op is not None
            e = op(lhs, rhs)

    ################################################################################
    # left-hand side is FloorDiv
    ################################################################################
    #
    # Given the expression: a // b op c
    # where 'op' is a relational operation, these rules only work if:
    #   - b > 0
    #   - c is an integer
    if (
        floordiv_inequality
        and isinstance(e, sympy.Rel)
        and isinstance(e.lhs, FloorDiv)
        and e.lhs.divisor.is_positive
        and e.rhs.is_integer
    ):
        # a // b == expr
        # => a >= (b * expr) and a < (b * (expr + 1))
        if isinstance(e, sympy.Eq):
            numerator, denominator = e.lhs.args
            return sympy.And(
                sympy.Ge(numerator, (e.rhs * denominator)),  # type: ignore[arg-type]
                sympy.Lt(numerator, ((e.rhs + 1) * denominator)),  # type: ignore[arg-type]
            )
        # a // b != expr
        # => a < (b * expr) or a >= (b * (expr + 1))
        if isinstance(e, sympy.Ne):
            numerator, denominator = e.lhs.args
            return sympy.Or(
                sympy.Lt(numerator, (e.rhs * denominator)),  # type: ignore[arg-type]
                sympy.Ge(numerator, ((e.rhs + 1) * denominator)),  # type: ignore[arg-type]
            )
        # The transformations below only work if b is positive.
        # Note: we only have this information for constants.
        # a // b > expr  => a >= b * (expr + 1)
        # a // b >= expr => a >= b * expr
        if isinstance(e, (sympy.Gt, sympy.Ge)):
            quotient = e.rhs if isinstance(e, sympy.Ge) else (e.rhs + 1)  # type: ignore[arg-type]
            return sympy.Ge(e.lhs.args[0], (quotient * e.lhs.args[1]))  # type: ignore[arg-type]
        # a // b < expr  => a < b * expr
        # a // b <= expr => a < b * (expr + 1)
        if isinstance(e, (sympy.Lt, sympy.Le)):
            quotient = e.rhs if isinstance(e, sympy.Lt) else (e.rhs + 1)  # type: ignore[arg-type]
            return sympy.Lt(e.lhs.args[0], (quotient * e.lhs.args[1]))  # type: ignore[arg-type]

    return e