onnxslim 0.1.80__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- onnxslim/__init__.py +16 -0
- onnxslim/__main__.py +4 -0
- onnxslim/argparser.py +215 -0
- onnxslim/cli/__init__.py +1 -0
- onnxslim/cli/_main.py +180 -0
- onnxslim/core/__init__.py +219 -0
- onnxslim/core/optimization/__init__.py +146 -0
- onnxslim/core/optimization/dead_node_elimination.py +151 -0
- onnxslim/core/optimization/subexpression_elimination.py +76 -0
- onnxslim/core/optimization/weight_tying.py +59 -0
- onnxslim/core/pattern/__init__.py +249 -0
- onnxslim/core/pattern/elimination/__init__.py +5 -0
- onnxslim/core/pattern/elimination/concat.py +61 -0
- onnxslim/core/pattern/elimination/reshape.py +77 -0
- onnxslim/core/pattern/elimination/reshape_as.py +64 -0
- onnxslim/core/pattern/elimination/slice.py +108 -0
- onnxslim/core/pattern/elimination/unsqueeze.py +92 -0
- onnxslim/core/pattern/fusion/__init__.py +8 -0
- onnxslim/core/pattern/fusion/concat_reshape.py +50 -0
- onnxslim/core/pattern/fusion/convadd.py +70 -0
- onnxslim/core/pattern/fusion/convbn.py +86 -0
- onnxslim/core/pattern/fusion/convmul.py +69 -0
- onnxslim/core/pattern/fusion/gelu.py +47 -0
- onnxslim/core/pattern/fusion/gemm.py +330 -0
- onnxslim/core/pattern/fusion/padconv.py +89 -0
- onnxslim/core/pattern/fusion/reduce.py +67 -0
- onnxslim/core/pattern/registry.py +28 -0
- onnxslim/misc/__init__.py +0 -0
- onnxslim/misc/tabulate.py +2681 -0
- onnxslim/third_party/__init__.py +0 -0
- onnxslim/third_party/_sympy/__init__.py +0 -0
- onnxslim/third_party/_sympy/functions.py +205 -0
- onnxslim/third_party/_sympy/numbers.py +397 -0
- onnxslim/third_party/_sympy/printers.py +491 -0
- onnxslim/third_party/_sympy/solve.py +172 -0
- onnxslim/third_party/_sympy/symbol.py +102 -0
- onnxslim/third_party/onnx_graphsurgeon/__init__.py +15 -0
- onnxslim/third_party/onnx_graphsurgeon/exporters/__init__.py +1 -0
- onnxslim/third_party/onnx_graphsurgeon/exporters/base_exporter.py +33 -0
- onnxslim/third_party/onnx_graphsurgeon/exporters/onnx_exporter.py +432 -0
- onnxslim/third_party/onnx_graphsurgeon/graph_pattern/__init__.py +4 -0
- onnxslim/third_party/onnx_graphsurgeon/graph_pattern/graph_pattern.py +466 -0
- onnxslim/third_party/onnx_graphsurgeon/importers/__init__.py +1 -0
- onnxslim/third_party/onnx_graphsurgeon/importers/base_importer.py +33 -0
- onnxslim/third_party/onnx_graphsurgeon/importers/onnx_importer.py +558 -0
- onnxslim/third_party/onnx_graphsurgeon/ir/__init__.py +0 -0
- onnxslim/third_party/onnx_graphsurgeon/ir/function.py +274 -0
- onnxslim/third_party/onnx_graphsurgeon/ir/graph.py +1575 -0
- onnxslim/third_party/onnx_graphsurgeon/ir/node.py +266 -0
- onnxslim/third_party/onnx_graphsurgeon/ir/tensor.py +504 -0
- onnxslim/third_party/onnx_graphsurgeon/logger/__init__.py +1 -0
- onnxslim/third_party/onnx_graphsurgeon/logger/logger.py +261 -0
- onnxslim/third_party/onnx_graphsurgeon/util/__init__.py +0 -0
- onnxslim/third_party/onnx_graphsurgeon/util/exception.py +20 -0
- onnxslim/third_party/onnx_graphsurgeon/util/misc.py +252 -0
- onnxslim/third_party/symbolic_shape_infer.py +3273 -0
- onnxslim/utils.py +794 -0
- onnxslim/version.py +1 -0
- onnxslim-0.1.80.dist-info/METADATA +207 -0
- onnxslim-0.1.80.dist-info/RECORD +65 -0
- onnxslim-0.1.80.dist-info/WHEEL +5 -0
- onnxslim-0.1.80.dist-info/entry_points.txt +2 -0
- onnxslim-0.1.80.dist-info/licenses/LICENSE +21 -0
- onnxslim-0.1.80.dist-info/top_level.txt +1 -0
- onnxslim-0.1.80.dist-info/zip-safe +1 -0
|
@@ -0,0 +1,491 @@
|
|
|
1
|
+
from __future__ import annotations
|
|
2
|
+
|
|
3
|
+
import sys
|
|
4
|
+
|
|
5
|
+
import sympy
|
|
6
|
+
from sympy.printing.precedence import PRECEDENCE, precedence
|
|
7
|
+
from sympy.printing.str import StrPrinter
|
|
8
|
+
|
|
9
|
+
# Name of the C++ integer type used for index values in generated code,
# together with the inclusive bounds of a signed 64-bit integer.
INDEX_TYPE = "int64_t"
INDEX_TYPE_MAX = 2**63 - 1
INDEX_TYPE_MIN = -(2**63)
|
|
12
|
+
|
|
13
|
+
|
|
14
|
+
# This printer contains rules that are supposed to be generic for both C/C++ and
|
|
15
|
+
# Python
|
|
16
|
+
class ExprPrinter(StrPrinter):
    """Printing rules meant to be generic for both the C/C++ and Python printers.

    Hooks that have no target-independent spelling raise
    ``NotImplementedError`` here so that a subclass has to provide them; this
    keeps sympy's default printing (e.g. a literal ``ToFloat(...)``) out of
    the generated text and names the printer class that is missing the hook.
    """

    # override this so that _print_FloorDiv is used
    printmethod = "_torch_sympystr"

    def _unsupported(self, hook: str) -> NotImplementedError:
        """Build the error raised by a printing hook this class does not implement."""
        return NotImplementedError(f"{hook} not implemented for {type(self)}")

    def _print_Mul(self, expr: sympy.Expr) -> str:
        """Join the factors with ``*``."""
        level = precedence(expr)
        return self.stringify(expr.args, "*", level)

    def _print_Not(self, expr: sympy.Expr) -> str:
        """Print logical negation as ``not (...)``."""
        operand = self._print(expr.args[0])
        return f"not ({operand})"

    def _print_Add(self, expr: sympy.Expr, order: str | None = None) -> str:
        """Join the terms with `` + ``; ``order`` is accepted but not used."""
        level = precedence(expr)
        return self.stringify(expr.args, " + ", level)

    def _print_Relational(self, expr: sympy.Expr) -> str:
        """Print a relational using its own ``rel_op`` as the infix operator."""
        separator = f" {expr.rel_op} "
        return self.stringify(expr.args, separator, precedence(expr))

    def _print_BitwiseFn_bitwise_and(self, expr: sympy.Expr) -> str:
        """Join the operands with `` & ``."""
        level = PRECEDENCE["BitwiseAnd"]
        return self.stringify(expr.args, " & ", level)

    def _print_BitwiseFn_bitwise_or(self, expr: sympy.Expr) -> str:
        """Join the operands with `` | ``."""
        level = PRECEDENCE["BitwiseOr"]
        return self.stringify(expr.args, " | ", level)

    # NB: this is OK to put here, because Mod is only defined for positive
    # numbers, and so across C/Python its behavior is consistent
    def _print_Mod(self, expr: sympy.Expr) -> str:
        """Join the operands with `` % ``, parenthesizing non-atomic operands."""
        level = PRECEDENCE["Atom"] - 0.5
        return self.stringify(expr.args, " % ", level)

    def _print_FloatTrueDiv(self, expr: sympy.Expr) -> str:
        """Print float division as a parenthesized ``a / b``."""
        level = PRECEDENCE["Atom"] - 0.5
        quotient = self.stringify(expr.args, " / ", level)
        return f"({quotient})"

    def _print_CleanDiv(self, expr: sympy.Expr) -> str:
        """Print ``CleanDiv`` exactly like ``FloorDiv``."""
        return self._print_FloorDiv(expr)

    def _print_Identity(self, expr: sympy.Expr) -> str:
        """Print only the wrapped argument."""
        (inner, *_rest) = expr.args
        return self._print(inner)

    def _print_Float(self, expr: sympy.Expr) -> str:
        """Print a float, using 17 digits for 53-bit (double) precision values."""
        if expr._prec != 53:
            # We don't use other precisions in pytorch
            return str(expr)
        # IEEE-754 double precision have 53 bits. SymPy prints them with
        # 15 digits, but we need 17 for round-trip correctness
        return str(sympy.Float(expr, dps=17))

    # This must be implemented because sympy will collect x * x into Pow(x, 2), without
    # any explicit intervention.  We print it just like x * x, notably, we
    # never generate sympy.Pow with floats.
    #
    # NB: this pow by natural, you should never have used builtin sympy.pow
    # for FloatPow, and a symbolic exponent should be PowByNatural.  These
    # means exp is guaranteed to be integer.
    def _print_Pow(self, expr: sympy.Expr) -> str:
        """Print a non-negative integer power as repeated multiplication."""
        base, exp = expr.args
        assert exp == int(exp), exp
        repeat = int(exp)
        assert repeat >= 0
        if repeat == 0:
            return "1"
        return self.stringify([base] * repeat, "*", PRECEDENCE["Mul"])

    # Explicit NotImplemented functions are to prevent default sympy printing
    # behavior, which will just barf out ToFloat(...) to your IR.  The error
    # message is better here because it tells you which printer class it needs
    # to go in.

    def _print_ToFloat(self, expr: sympy.Expr) -> str:
        raise self._unsupported("_print_ToFloat")

    def _print_Infinity(self, expr: sympy.Expr) -> str:
        raise self._unsupported("_print_Infinity")

    def _print_NegativeInfinity(self, expr: sympy.Expr) -> str:
        raise self._unsupported("_print_NegativeInfinity")

    def _print_FloorDiv(self, expr: sympy.Expr) -> str:
        raise self._unsupported("_print_FloorDiv")

    def _print_PythonMod(self, expr: sympy.Expr) -> str:
        raise self._unsupported("_print_PythonMod")

    def _print_IntTrueDiv(self, expr: sympy.Expr) -> str:
        raise self._unsupported("_print_IntTrueDiv")

    def _print_PowByNatural(self, expr: sympy.Expr) -> str:
        raise self._unsupported("_print_PowByNatural")

    def _print_FloatPow(self, expr: sympy.Expr) -> str:
        raise self._unsupported("_print_FloatPow")

    def _print_TruncToInt(self, expr: sympy.Expr) -> str:
        raise self._unsupported("_print_TruncToInt")

    def _print_RoundToInt(self, expr: sympy.Expr) -> str:
        raise self._unsupported("_print_RoundToInt")

    def _print_RoundDecimal(self, expr: sympy.Expr) -> str:
        raise self._unsupported("_print_RoundDecimal")

    # NB: Some float operations are INTENTIONALLY not implemented for
    # printers.  You can implement them as a quick unblock, but it is better
    # to ask yourself why we haven't done this computation in the Tensor
    # universe instead

    def _print_TruncToFloat(self, expr: sympy.Expr) -> str:
        raise self._unsupported("_print_TruncToFloat")
|
|
123
|
+
|
|
124
|
+
|
|
125
|
+
class PythonPrinter(ExprPrinter):
    """Printer that renders expressions as Python source text.

    The emitted text refers to ``math.*`` functions, the builtins ``abs``,
    ``max``, ``min`` and ``round``, and ``torch.sym_float``.
    """

    def _unary_call(self, func: str, expr: sympy.Expr) -> str:
        """Render ``func(<arg>)`` for an expression that has exactly one argument."""
        assert len(expr.args) == 1
        return f"{func}({self._print(expr.args[0])})"

    def _print_ToFloat(self, expr: sympy.Expr) -> str:
        # NB: We use sym_float here because the printer is used for cache
        # serialization, and cache guards get evaluated with SymInt to
        # propagate guards to the parent ShapeEnv.  However, this comes at a
        # runtime cost for guards involving float.  If this is unacceptable
        # overhead, what you want to do is have two separate printers for
        # SymInt, one for when the inputs are guaranteed to be int, and
        # another for when they could be SymInt.
        #
        # NB: sym_min/sym_max also have this problem, but I chose not to fix
        # those.
        #
        # See https://github.com/pytorch/pytorch/issues/142507 for more
        # context.
        return self._unary_call("torch.sym_float", expr)

    def _print_And(self, expr: sympy.Expr) -> str:
        """Join the operands with `` and ``."""
        level = precedence(expr)
        return self.stringify(expr.args, " and ", level)

    def _print_Or(self, expr: sympy.Expr) -> str:
        """Join the operands with `` or ``."""
        level = precedence(expr)
        return self.stringify(expr.args, " or ", level)

    def _print_ModularIndexing(self, expr: sympy.Expr) -> str:
        """Print ``ModularIndexing(x, div, mod)`` as ``((x // div) % mod)``.

        The floor division is left out when ``div`` prints as ``"1"``.
        """
        level = PRECEDENCE["Atom"] - 0.5
        dividend, divisor, modulus = [self.parenthesize(arg, level) for arg in expr.args]
        quotient = dividend if divisor == "1" else f"({dividend} // {divisor})"
        return f"({quotient} % {modulus})"

    def _print_Infinity(self, expr: sympy.Expr) -> str:
        return "math.inf"

    def _print_NegativeInfinity(self, expr: sympy.Expr) -> str:
        return "-math.inf"

    # WARNING: this is dangerous for Triton, which has C-style modulus
    def _print_PythonMod(self, expr: sympy.Expr) -> str:
        """Join the operands with `` % ``."""
        level = PRECEDENCE["Atom"] - 0.5
        return self.stringify(expr.args, " % ", level)

    # WARNING: this is dangerous for Triton, which has C-style modulus
    def _print_FloorDiv(self, expr: sympy.Expr) -> str:
        """Print floor division with the ``//`` operator."""
        level = PRECEDENCE["Atom"] - 0.5
        numerator, denominator = [self.parenthesize(arg, level) for arg in expr.args]
        return f"{numerator} // {denominator}"

    # WARNING: this is dangerous for Triton, when lhs, rhs > 2**53, Python
    # does a special algorithm
    def _print_IntTrueDiv(self, expr: sympy.Expr) -> str:
        """Join the operands with `` / ``."""
        level = PRECEDENCE["Atom"] - 0.5
        return self.stringify(expr.args, " / ", level)

    def _helper_sqrt(self, expr: sympy.Expr) -> str:
        """Render ``math.sqrt(...)`` around the printed form of ``expr`` itself."""
        radicand = self._print(expr)
        return f"math.sqrt({radicand})"

    def _print_OpaqueUnaryFn_sqrt(self, expr: sympy.Expr) -> str:
        (operand, *_rest) = expr.args
        return self._helper_sqrt(operand)

    def _print_FloatPow(self, expr: sympy.Expr) -> str:
        """Join base and exponent with `` ** ``."""
        level = PRECEDENCE["Pow"]
        return self.stringify(expr.args, " ** ", level)

    # TODO: Not sure this works with Triton, even when base/exp are integral
    def _print_PowByNatural(self, expr: sympy.Expr) -> str:
        """Join base and exponent with `` ** ``."""
        level = PRECEDENCE["Pow"]
        return self.stringify(expr.args, " ** ", level)

    def _print_floor(self, expr: sympy.Expr) -> str:
        return self._unary_call("math.floor", expr)

    def _print_FloorToInt(self, expr: sympy.Expr) -> str:
        return self._unary_call("math.floor", expr)

    def _print_TruncToInt(self, expr: sympy.Expr) -> str:
        # This also could have been int(), they'll do the same thing for float
        return self._unary_call("math.trunc", expr)

    def _print_ceiling(self, expr: sympy.Expr) -> str:
        return self._unary_call("math.ceil", expr)

    def _print_CeilToInt(self, expr: sympy.Expr) -> str:
        return self._unary_call("math.ceil", expr)

    def _print_Abs(self, expr: sympy.Expr) -> str:
        return self._unary_call("abs", expr)

    # NB: It's expected that we've made explicit any promotion in the sympy
    # expression, so it doesn't matter that Python max/min doesn't perform
    # promotion
    def _print_Max(self, expr: sympy.Expr) -> str:
        """Print as a call to the builtin ``max`` (at least two arguments)."""
        assert len(expr.args) >= 2
        rendered = ", ".join(self._print(arg) for arg in expr.args)
        return f"max({rendered})"

    def _print_Min(self, expr: sympy.Expr) -> str:
        """Print as a call to the builtin ``min`` (at least two arguments)."""
        assert len(expr.args) >= 2
        rendered = ", ".join(self._print(arg) for arg in expr.args)
        return f"min({rendered})"

    def _print_OpaqueUnaryFn_cos(self, expr: sympy.Expr) -> str:
        return self._unary_call("math.cos", expr)

    def _print_OpaqueUnaryFn_cosh(self, expr: sympy.Expr) -> str:
        return self._unary_call("math.cosh", expr)

    def _print_OpaqueUnaryFn_acos(self, expr: sympy.Expr) -> str:
        return self._unary_call("math.acos", expr)

    def _print_OpaqueUnaryFn_sin(self, expr: sympy.Expr) -> str:
        return self._unary_call("math.sin", expr)

    def _print_OpaqueUnaryFn_sinh(self, expr: sympy.Expr) -> str:
        return self._unary_call("math.sinh", expr)

    def _print_OpaqueUnaryFn_asin(self, expr: sympy.Expr) -> str:
        return self._unary_call("math.asin", expr)

    def _print_OpaqueUnaryFn_tan(self, expr: sympy.Expr) -> str:
        return self._unary_call("math.tan", expr)

    def _print_OpaqueUnaryFn_tanh(self, expr: sympy.Expr) -> str:
        return self._unary_call("math.tanh", expr)

    def _print_OpaqueUnaryFn_atan(self, expr: sympy.Expr) -> str:
        return self._unary_call("math.atan", expr)

    def _print_OpaqueUnaryFn_log2(self, expr: sympy.Expr) -> str:
        return self._unary_call("math.log2", expr)

    def _print_RoundToInt(self, expr: sympy.Expr) -> str:
        return self._unary_call("round", expr)

    def _print_RoundDecimal(self, expr: sympy.Expr) -> str:
        """Print ``RoundDecimal(number, ndigits)`` as ``round(number, ndigits)``.

        ``ndigits`` must be a ``sympy.Integer`` and is interpolated directly.
        """
        assert len(expr.args) == 2
        value, ndigits = expr.args
        assert isinstance(ndigits, sympy.Integer)
        printed_value = self._print(value)
        return f"round({printed_value}, {ndigits})"
|
|
273
|
+
|
|
274
|
+
|
|
275
|
+
class CppPrinter(ExprPrinter):
    """Printer that renders expressions as C++ source text.

    The emitted text refers to ``std::`` math functions, ``c10::`` division
    helpers and ``static_cast``; integer-valued results are cast to
    ``INDEX_TYPE``.
    """

    def _print_Integer(self, expr: sympy.Expr) -> str:
        """Print an integer literal with a platform-dependent long suffix.

        Raises:
            OverflowError: if the value lies outside
                ``[INDEX_TYPE_MIN, INDEX_TYPE_MAX]``.
        """
        suffix = "LL" if sys.platform in ["darwin", "win32"] else "L"
        i = int(expr)
        if i > INDEX_TYPE_MAX or i < INDEX_TYPE_MIN:
            raise OverflowError(f"{i} too big to convert to {INDEX_TYPE}")
        elif i == INDEX_TYPE_MIN:
            assert i == (-1) << 63
            # Writing -9223372036854775808L makes the value overflow
            # as it is parsed as -(9223372036854775808L) by the C/C++ compiler
            return f"(-1{suffix} << 63)"
        return f"{i}{suffix}"

    def _print_Where(self, expr: sympy.Expr) -> str:
        """Print ``Where(c, p, q)`` as the ternary ``c ? p : q``."""
        c, p, q = (self.parenthesize(arg, PRECEDENCE["Atom"] - 0.5) for arg in expr.args)
        return f"{c} ? {p} : {q}"

    def _print_ModularIndexing(self, expr: sympy.Expr) -> str:
        """Print ``ModularIndexing(x, div, mod)`` as a floor division followed by ``%``.

        The division is skipped when ``div`` equals 1.
        """
        x, div, mod = expr.args
        x = self.doprint(x)
        if div != 1:
            div = self.doprint(div)
            if expr.is_integer:
                x = f"c10::div_floor_integer(static_cast<int64_t>({x}), static_cast<int64_t>({div}))"
            else:
                x = f"c10::div_floor_floating(static_cast<double>({x}), static_cast<double>({div}))"
        mod = self.doprint(mod)
        return f"(static_cast<{INDEX_TYPE}>({x}) % static_cast<{INDEX_TYPE}>({mod}))"

    def _print_FloorDiv(self, expr: sympy.Expr) -> str:
        """Print floor division via ``c10::div_floor_integer`` / ``c10::div_floor_floating``."""
        x, div = expr.args
        x = self.doprint(x)
        div = self.doprint(div)
        if expr.is_integer:
            return f"c10::div_floor_integer(static_cast<int64_t>({x}), static_cast<int64_t>({div}))"
        return f"c10::div_floor_floating(static_cast<double>({x}), static_cast<double>({div}))"

    def _print_floor(self, expr: sympy.Expr) -> str:
        """Print ``std::floor``, cast to ``INDEX_TYPE`` when the result is an integer."""
        assert len(expr.args) == 1
        r = f"std::floor({self._print(expr.args[0])})"
        return f"static_cast<{INDEX_TYPE}>({r})" if expr.is_integer else r

    def _print_FloorToInt(self, expr: sympy.Expr) -> str:
        """Print ``std::floor``, cast to ``INDEX_TYPE`` when the result is an integer."""
        assert len(expr.args) == 1
        r = f"std::floor({self._print(expr.args[0])})"
        return f"static_cast<{INDEX_TYPE}>({r})" if expr.is_integer else r

    def _print_TruncToInt(self, expr: sympy.Expr) -> str:
        """Print ``std::trunc`` cast to ``INDEX_TYPE``."""
        assert len(expr.args) == 1
        r = f"std::trunc({self._print(expr.args[0])})"
        return f"static_cast<{INDEX_TYPE}>({r})"

    def _print_TruncToFloat(self, expr: sympy.Expr) -> str:
        """Print ``std::trunc`` without an integer cast."""
        assert len(expr.args) == 1
        return f"std::trunc({self._print(expr.args[0])})"

    def _print_ToFloat(self, expr: sympy.Expr) -> str:
        """Print a cast to ``double``."""
        assert len(expr.args) == 1
        return f"static_cast<double>({self._print(expr.args[0])})"

    def _print_PythonMod(self, expr: sympy.Expr) -> str:
        """Print Python-style modulus as a call to ``c10::div_mod``."""
        x, div = expr.args
        x = self.doprint(x)
        div = self.doprint(div)
        return f"c10::div_mod({x}, {div})"

    def _print_IntTrueDiv(self, expr: sympy.Expr) -> str:
        """Print true division of integers by casting both operands to ``double``."""
        lhs, rhs = expr.args
        # TODO: This is only accurate up to 2**53
        return f"static_cast<double>({self._print(lhs)}) / static_cast<double>({self._print(rhs)})"

    # TODO: PowByNatural: we need to implement our own int-int pow.  Do NOT
    # use std::pow, that operates on floats
    def _print_PowByNatural(self, expr: sympy.Expr) -> str:
        """Print ``2**x`` as a left shift; any other base is unsupported.

        Raises:
            NotImplementedError: if the base is not 2.
        """
        # Implement the special-case of 2**x for now
        base, exp = expr.args
        if base == 2:
            # NOTE(review): the emitted literal `1` has no long suffix, unlike
            # _print_Integer -- confirm the shift is wide enough for callers.
            return f"(1 << ({self._print(exp)}))"
        raise NotImplementedError(f"_print_PowByNatural not implemented for {type(self)}")

    def _print_FloatPow(self, expr: sympy.Expr) -> str:
        """Print a float power as ``std::pow(base, exp)``."""
        base, exp = expr.args
        return f"std::pow({self._print(base)}, {self._print(exp)})"

    def _print_Pow(self, expr: sympy.Expr) -> str:
        """Print ``base**exp`` using sqrt, repeated multiplication or ``std::pow``.

        Exponents equal to ``0.5`` / ``-0.5`` use ``std::sqrt``, integer
        exponents expand to products (reciprocals for negative exponents), and
        anything else falls back to ``std::pow`` with a float exponent.
        """
        # Uses float constants to perform FP div
        base, exp = expr.args

        if exp == 0.5 or exp == -0.5:
            base = self._print(base)
            return f"std::sqrt({base})" if exp == 0.5 else f"1.0/std::sqrt({base})"
        if exp.is_integer:
            exp = int(exp)
            if exp > 0:
                r = self.stringify([base] * exp, "*", PRECEDENCE["Mul"])
            elif exp < -1:
                r = "1.0/(" + self.stringify([base] * abs(exp), "*", PRECEDENCE["Mul"]) + ")"
            elif exp == -1:
                # Parenthesize non-atomic bases: without this, (a + b)**-1
                # would be emitted as "1.0/a + b".
                r = "1.0/" + self.parenthesize(base, PRECEDENCE["Atom"] - 0.5)
            else:  # exp == 0
                r = "1.0"

            return f"static_cast<{INDEX_TYPE}>({r})" if expr.is_integer else r
        else:
            # Print the base with this printer rather than interpolating the
            # sympy object, so nested expressions use the C++ rules as well.
            # TODO: float vs double
            return f"std::pow({self._print(base)}, {float(exp)})"

    def _print_Rational(self, expr: sympy.Expr) -> str:
        """Print a rational as ``p.0/q.0`` (or just ``p`` when ``q == 1``)."""
        # Uses float constants to perform FP div
        if expr.q == 1:
            r = f"{expr.p}"
        else:
            r = f"{expr.p}.0/{expr.q}.0"
        return f"static_cast<{INDEX_TYPE}>({r})" if expr.is_integer else r

    def _print_ceiling(self, expr: sympy.Expr) -> str:
        """Print ``std::ceil``, cast to ``INDEX_TYPE`` when the result is an integer."""
        assert len(expr.args) == 1
        r = f"std::ceil({self._print(expr.args[0])})"
        return f"static_cast<{INDEX_TYPE}>({r})" if expr.is_integer else r

    def _print_CeilToInt(self, expr: sympy.Expr) -> str:
        """Print ``std::ceil``, cast to ``INDEX_TYPE`` when the result is an integer."""
        assert len(expr.args) == 1
        r = f"std::ceil({self._print(expr.args[0])})"
        return f"static_cast<{INDEX_TYPE}>({r})" if expr.is_integer else r

    def _print_Min(self, expr: sympy.Expr) -> str:
        """Print ``std::min``; more than two arguments use the initializer-list overload."""
        args = [self._print(a) for a in expr.args]
        if len(args) == 2:
            return f"std::min(static_cast<{INDEX_TYPE}>({args[0]}), static_cast<{INDEX_TYPE}>({args[1]}))"
        else:
            # Initializer list overload
            il = "{" + ", ".join(args) + "}"
            return f"std::min<{INDEX_TYPE}>({il})"

    def _print_Max(self, expr: sympy.Expr) -> str:
        """Print ``std::max``; more than two arguments use the initializer-list overload."""
        args = [self._print(a) for a in expr.args]
        if len(args) == 2:
            return f"std::max(static_cast<{INDEX_TYPE}>({args[0]}), static_cast<{INDEX_TYPE}>({args[1]}))"
        else:
            # Initializer list overload
            il = "{" + ", ".join(args) + "}"
            return f"std::max<{INDEX_TYPE}>({il})"

    def _print_Abs(self, expr: sympy.Expr) -> str:
        assert len(expr.args) == 1
        return f"std::abs({self._print(expr.args[0])})"

    def _print_OpaqueUnaryFn_cos(self, expr: sympy.Expr) -> str:
        assert len(expr.args) == 1
        return f"std::cos({self._print(expr.args[0])})"

    def _print_OpaqueUnaryFn_cosh(self, expr: sympy.Expr) -> str:
        assert len(expr.args) == 1
        return f"std::cosh({self._print(expr.args[0])})"

    def _print_OpaqueUnaryFn_acos(self, expr: sympy.Expr) -> str:
        assert len(expr.args) == 1
        return f"std::acos({self._print(expr.args[0])})"

    def _print_OpaqueUnaryFn_sin(self, expr: sympy.Expr) -> str:
        assert len(expr.args) == 1
        return f"std::sin({self._print(expr.args[0])})"

    def _print_OpaqueUnaryFn_sinh(self, expr: sympy.Expr) -> str:
        assert len(expr.args) == 1
        return f"std::sinh({self._print(expr.args[0])})"

    def _print_OpaqueUnaryFn_asin(self, expr: sympy.Expr) -> str:
        assert len(expr.args) == 1
        return f"std::asin({self._print(expr.args[0])})"

    def _print_OpaqueUnaryFn_tan(self, expr: sympy.Expr) -> str:
        assert len(expr.args) == 1
        return f"std::tan({self._print(expr.args[0])})"

    def _print_OpaqueUnaryFn_tanh(self, expr: sympy.Expr) -> str:
        assert len(expr.args) == 1
        return f"std::tanh({self._print(expr.args[0])})"

    def _print_OpaqueUnaryFn_atan(self, expr: sympy.Expr) -> str:
        assert len(expr.args) == 1
        return f"std::atan({self._print(expr.args[0])})"

    def _print_OpaqueUnaryFn_sqrt(self, expr: sympy.Expr) -> str:
        return f"std::sqrt({self._print(expr.args[0])})"

    def _print_OpaqueUnaryFn_log2(self, expr: sympy.Expr) -> str:
        return f"std::log2({self._print(expr.args[0])})"

    def _print_RoundToInt(self, expr: sympy.Expr) -> str:
        """Print rounding to an integer as ``std::lrint``."""
        assert len(expr.args) == 1
        # TODO: dispatch to llrint depending on index type
        return f"std::lrint({self._print(expr.args[0])})"

    def _print_RoundDecimal(self, expr: sympy.Expr) -> str:
        """Print ``RoundDecimal(number, ndigits)`` via ``std::nearbyint`` scaling.

        Raises:
            ValueError: if ``number`` is an integer (only reached with a
                negative ``ndigits``, which is not supported).
        """
        assert len(expr.args) == 2
        number, ndigits = expr.args
        if number.is_integer:
            # ndigits >= 0 should have been filtered by the sympy function,
            # so only the unsupported negative case can reach this point.
            assert ndigits < 0
            raise ValueError(
                f"For integer inputs, only non-negative ndigits are currently supported, but got {ndigits}."
            )
        number_str = self.parenthesize(number, PRECEDENCE["Mul"])
        return f"static_cast<double>(std::nearbyint(1e{ndigits} * {number_str}) * 1e{-ndigits})"

    def _print_BooleanTrue(self, expr: sympy.Expr) -> str:
        return "true"

    def _print_BooleanFalse(self, expr: sympy.Expr) -> str:
        return "false"

    def _print_Infinity(self, expr: sympy.Expr) -> str:
        return "std::numeric_limits<double>::infinity()"

    def _print_NegativeInfinity(self, expr: sympy.Expr) -> str:
        return f"-{self._print_Infinity(expr)}"
|
|
@@ -0,0 +1,172 @@
|
|
|
1
|
+
from __future__ import annotations
|
|
2
|
+
|
|
3
|
+
import logging
|
|
4
|
+
|
|
5
|
+
import sympy
|
|
6
|
+
|
|
7
|
+
from onnxslim.third_party._sympy.functions import FloorDiv
|
|
8
|
+
|
|
9
|
+
log = logging.getLogger(__name__)

# For each supported relational class, the class to use once the two operands
# are swapped (a < b  <=>  b > a). Eq and Ne map to themselves.
_MIRROR_REL_OP: dict[type[sympy.Basic], type[sympy.Rel]] = {
    # Unchanged by swapping operands.
    sympy.Eq: sympy.Eq,
    sympy.Ne: sympy.Ne,
    # Direction flips when operands are swapped.
    sympy.Ge: sympy.Le,
    sympy.Gt: sympy.Lt,
    sympy.Le: sympy.Ge,
    sympy.Lt: sympy.Gt,
}

# The relational classes that are treated as inequalities.
INEQUALITY_TYPES = (
    sympy.Gt,
    sympy.Ge,
    sympy.Lt,
    sympy.Le,
)
|
|
21
|
+
|
|
22
|
+
|
|
23
|
+
def mirror_rel_op(type: type) -> type[sympy.Rel] | None:
    """Return the relational class to use after swapping operands.

    Args:
        type: class to look up in ``_MIRROR_REL_OP`` (the parameter name
            shadows the builtin inside this function).

    Returns:
        The mirrored relational class, or ``None`` when ``type`` has no entry.
    """
    # dict.get already yields None for a missing key.
    return _MIRROR_REL_OP.get(type)
|
|
25
|
+
|
|
26
|
+
|
|
27
|
+
# Tries to simplify 'expr', so as to leave only 'thing' in the left-hand side.
|
|
28
|
+
#
|
|
29
|
+
# Returns a tuple of:
|
|
30
|
+
# 1. The simplified expression
|
|
31
|
+
# 2. The expression on the right-hand side
|
|
32
|
+
#
|
|
33
|
+
# Returns 'None' if it can't reach a state where the only thing in the left
|
|
34
|
+
# hand side is 'thing'.
|
|
35
|
+
#
|
|
36
|
+
# 'trials': number of times 'try_solve' will try to isolate 'thing' to the
|
|
37
|
+
# left-hand side.
|
|
38
|
+
#
|
|
39
|
+
# 'floordiv_inequality': flag to enable conversion of 'FloorDiv' into
|
|
40
|
+
# inequalities.
|
|
41
|
+
def try_solve(
    expr: sympy.Basic,
    thing: sympy.Basic,
    trials: int = 5,
    floordiv_inequality: bool = True,
) -> tuple[sympy.Rel, sympy.Expr] | None:
    """Try to rewrite the relational ``expr`` so that only ``thing`` is on its left.

    Args:
        expr: expression to solve; anything that is not a relational with an
            entry in the mirror table is rejected.
        thing: the sub-expression to isolate.
        trials: maximum number of isolation passes applied to a candidate.
        floordiv_inequality: forwarded to ``_try_isolate_lhs`` to enable the
            conversion of ``FloorDiv`` into inequalities.

    Returns:
        ``(solved, solved.rhs)`` when ``thing`` ends up alone on the left-hand
        side, otherwise ``None``.
    """
    mirror = mirror_rel_op(type(expr))

    # Ignore unsupported expressions:
    #   - Those that are not relational operations
    #   - Those that don't have a mirror (just avoiding unexpected classes)
    if mirror is None or not isinstance(expr, sympy.Rel):
        log.debug("expression with unsupported type: %s", type(expr))
        return None

    in_lhs = expr.lhs.has(thing)
    in_rhs = expr.rhs.has(thing)

    # Give up when 'thing' occurs on both sides: the isolation steps only
    # handle a 'thing' that lives on a single side of the relation.
    if in_lhs and in_rhs:
        log.debug("thing (%s) found in both sides of expression: %s", thing, expr)
        return None

    # Collect the forms of 'expr' that have 'thing' on the left-hand side,
    # mirroring the relation when it sits on the right:  a < b ==> b > a
    candidates = []
    if in_lhs:
        candidates.append(expr)
    if in_rhs:
        candidates.append(mirror(expr.rhs, expr.lhs))

    for candidate in candidates:
        if candidate is None:
            continue

        assert isinstance(candidate, sympy.Rel)

        current = candidate
        for _ in range(trials):
            attempt = _try_isolate_lhs(current, thing, floordiv_inequality=floordiv_inequality)
            # A pass that changes nothing means no further progress is possible.
            if attempt == current:
                break
            current = attempt  # type: ignore[assignment]

        # Success only if 'thing' is now the entire left-hand side.
        if isinstance(current, sympy.Rel) and current.lhs == thing:
            log.debug("solved: %s ---> %s", expr, current)
            return current, current.rhs

    return None
|
|
95
|
+
|
|
96
|
+
|
|
97
|
+
def _try_isolate_lhs(e: sympy.Basic, thing: sympy.Basic, floordiv_inequality: bool) -> sympy.Basic:
    """Apply one round of rewrites that move ``thing`` towards being the whole LHS.

    Three steps run in order, each only while ``e`` is still a relational:
    terms of an ``Add`` left-hand side that do not contain ``thing`` are
    subtracted from both sides; factors of a ``Mul`` left-hand side that do
    not contain ``thing`` are divided out; and, when ``floordiv_inequality``
    is set, a ``FloorDiv`` left-hand side with a positive divisor and an
    integer right-hand side is rewritten into plain inequalities.

    Returns:
        The rewritten expression. It may be an ``And``/``Or`` for the
        ``FloorDiv`` equality cases, or ``e`` itself if nothing applied.
    """
    op = type(e)

    if isinstance(e, sympy.Rel):
        # Move the terms of the left-hand side that don't involve 'thing'
        # over to the right-hand side.
        if isinstance(e.lhs, sympy.Add):
            loose_terms = sum(term for term in e.lhs.args if not term.has(thing))
        else:
            loose_terms = 0
        e = op(e.lhs - loose_terms, e.rhs - loose_terms)  # type: ignore[attr-defined]

    # Divide both sides by the factors that don't contain thing.
    if isinstance(e, sympy.Rel) and isinstance(e.lhs, sympy.Mul):
        lhs, rhs = e.args
        other = sympy.Mul(*[factor for factor in lhs.args if not factor.has(thing)])
        is_inequality = isinstance(e, INEQUALITY_TYPES)

        # For an inequality the sign of 'other' must be known, since it decides
        # whether the relation has to be mirrored. For Eq/Ne the division is
        # skipped only when 'rhs' is known to be zero.
        can_divide = (other.is_negative is not None) if is_inequality else (not rhs.is_zero)

        if can_divide:
            # Divide both sides by 'other'.
            lhs = lhs / other
            rhs = rhs / other

            # Dividing an inequality by a negative quantity flips its direction.
            if is_inequality and other.is_negative:
                op = mirror_rel_op(op)  # type: ignore[assignment]

            assert op is not None
            e = op(lhs, rhs)

    ################################################################################
    # left-hand side is FloorDiv
    ################################################################################
    #
    # Given the expression: a // b op c
    # where 'op' is a relational operation, these rules only work if:
    #   - b > 0
    #   - c is an integer
    applies = (
        floordiv_inequality
        and isinstance(e, sympy.Rel)
        and isinstance(e.lhs, FloorDiv)
        and e.lhs.divisor.is_positive
        and e.rhs.is_integer
    )
    if not applies:
        return e

    # a // b == expr
    # => a >= (b * expr) and a < (b * (expr + 1))
    if isinstance(e, sympy.Eq):
        numerator, denominator = e.lhs.args
        lower = sympy.Ge(numerator, (e.rhs * denominator))  # type: ignore[arg-type]
        upper = sympy.Lt(numerator, ((e.rhs + 1) * denominator))  # type: ignore[arg-type]
        return sympy.And(lower, upper)

    # a // b != expr
    # => a < (b * expr) or a >= (b * (expr + 1))
    if isinstance(e, sympy.Ne):
        numerator, denominator = e.lhs.args
        below = sympy.Lt(numerator, (e.rhs * denominator))  # type: ignore[arg-type]
        above = sympy.Ge(numerator, ((e.rhs + 1) * denominator))  # type: ignore[arg-type]
        return sympy.Or(below, above)

    # The transformations below only work if b is positive.
    # Note: we only have this information for constants.
    # a // b > expr  => a >= b * (expr + 1)
    # a // b >= expr => a >= b * expr
    if isinstance(e, (sympy.Gt, sympy.Ge)):
        if isinstance(e, sympy.Ge):
            quotient = e.rhs
        else:
            quotient = e.rhs + 1  # type: ignore[arg-type]
        return sympy.Ge(e.lhs.args[0], (quotient * e.lhs.args[1]))  # type: ignore[arg-type]

    # a // b < expr  => a < b * expr
    # a // b <= expr => a < b * (expr + 1)
    if isinstance(e, (sympy.Lt, sympy.Le)):
        if isinstance(e, sympy.Lt):
            quotient = e.rhs
        else:
            quotient = e.rhs + 1  # type: ignore[arg-type]
        return sympy.Lt(e.lhs.args[0], (quotient * e.lhs.args[1]))  # type: ignore[arg-type]

    return e
|