onnxslim 0.1.72__py3-none-any.whl → 0.1.74__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- onnxslim/core/pattern/fusion/__init__.py +1 -0
- onnxslim/core/pattern/fusion/convmul.py +69 -0
- onnxslim/core/pattern/fusion/gemm.py +26 -4
- onnxslim/core/pattern/fusion/padconv.py +5 -3
- onnxslim/misc/tabulate.py +1 -1
- onnxslim/third_party/_sympy/functions.py +1 -1
- onnxslim/third_party/symbolic_shape_infer.py +1 -0
- onnxslim/version.py +1 -1
- {onnxslim-0.1.72.dist-info → onnxslim-0.1.74.dist-info}/METADATA +1 -1
- {onnxslim-0.1.72.dist-info → onnxslim-0.1.74.dist-info}/RECORD +15 -14
- {onnxslim-0.1.72.dist-info → onnxslim-0.1.74.dist-info}/WHEEL +0 -0
- {onnxslim-0.1.72.dist-info → onnxslim-0.1.74.dist-info}/entry_points.txt +0 -0
- {onnxslim-0.1.72.dist-info → onnxslim-0.1.74.dist-info}/licenses/LICENSE +0 -0
- {onnxslim-0.1.72.dist-info → onnxslim-0.1.74.dist-info}/top_level.txt +0 -0
- {onnxslim-0.1.72.dist-info → onnxslim-0.1.74.dist-info}/zip-safe +0 -0
onnxslim/core/pattern/fusion/convmul.py
ADDED

@@ -0,0 +1,69 @@
+import onnxslim.third_party.onnx_graphsurgeon as gs
+from onnxslim.core.pattern import Pattern, PatternMatcher
+from onnxslim.core.pattern.registry import register_fusion_pattern
+
+
+class ConvMulMatcher(PatternMatcher):
+    def __init__(self, priority):
+        """Initializes the ConvMulMatcher for fusing Conv and Mul layers in an ONNX graph."""
+        pattern = Pattern(
+            """
+            input input 0 1 conv_0
+            Conv conv_0 1+ 1 input mul_0
+            Mul mul_0 2 1 conv_0 ? output
+            output output 1 0 mul_0
+            """
+        )
+        super().__init__(pattern, priority)
+
+    @property
+    def name(self):
+        """Returns the name of the FusionConvMul pattern."""
+        return "FusionConvMul"
+
+    def rewrite(self, opset=11):
+        match_case = {}
+        conv_node = self.conv_0
+        mul_node = self.mul_0
+        conv_weight = list(conv_node.inputs)[1]
+        if len(conv_node.users) == 1 and conv_node.users[0] == mul_node and isinstance(mul_node.inputs[1], gs.Constant):
+            mul_constant = mul_node.inputs[1].values
+
+            if mul_constant.squeeze().ndim == 1 and mul_constant.squeeze().shape[0] == conv_weight.shape[0]:
+                weight_shape = conv_weight.values.shape
+                reshape_shape = [-1] + [1] * (len(weight_shape) - 1)
+
+                mul_scale_reshaped = mul_constant.squeeze().reshape(reshape_shape)
+                new_weight = conv_weight.values * mul_scale_reshaped
+
+                inputs = []
+                inputs.append(next(iter(conv_node.inputs)))
+
+                weight_name = list(conv_node.inputs)[1].name
+                inputs.append(gs.Constant(weight_name, values=new_weight))
+
+                if len(conv_node.inputs) == 3:
+                    conv_bias = conv_node.inputs[2].values
+                    new_bias = conv_bias * mul_constant.squeeze()
+                    bias_name = list(conv_node.inputs)[2].name
+                    inputs.append(gs.Constant(bias_name, values=new_bias))
+
+                outputs = list(mul_node.outputs)
+
+                conv_node.outputs.clear()
+                mul_node.inputs.clear()
+                mul_node.outputs.clear()
+
+                match_case[conv_node.name] = {
+                    "op": conv_node.op,
+                    "inputs": inputs,
+                    "outputs": outputs,
+                    "name": conv_node.name,
+                    "attrs": conv_node.attrs,
+                    "domain": None,
+                }
+
+        return match_case
+
+
+register_fusion_pattern(ConvMulMatcher(1))
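Note: the fusion above relies on the identity that multiplying a Conv output by a per-output-channel constant equals scaling the corresponding filter (and bias) of the Conv itself. A minimal NumPy sketch of that identity, written against a hypothetical 1x1 convolution expressed as an einsum (illustrative only, not onnxslim code):

import numpy as np

rng = np.random.default_rng(0)
n, c_in, c_out, h, w = 1, 3, 4, 5, 5
x = rng.standard_normal((n, c_in, h, w)).astype(np.float32)
W = rng.standard_normal((c_out, c_in, 1, 1)).astype(np.float32)
b = rng.standard_normal(c_out).astype(np.float32)
s = rng.standard_normal(c_out).astype(np.float32)  # the Mul constant: one scale per output channel

def conv1x1(x, W, b):
    # a 1x1 NCHW convolution reduces to a contraction over the input-channel axis
    return np.einsum("nchw,ocij->nohw", x, W) + b[None, :, None, None]

lhs = conv1x1(x, W, b) * s[None, :, None, None]      # Conv followed by Mul
rhs = conv1x1(x, W * s[:, None, None, None], b * s)  # Mul folded into the Conv weight and bias
print(np.allclose(lhs, rhs, atol=1e-5))              # True

The same algebra holds for arbitrary kernel sizes, which is why the rewrite only needs to reshape the Mul constant to (C_out, 1, ..., 1) before multiplying it into the weight tensor.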
onnxslim/core/pattern/fusion/gemm.py
CHANGED

@@ -289,16 +289,38 @@ class GemmAddPatternMatcher(PatternMatcher):
             )
             and add_bias_variable
             and len(reshape_node.users) == 1
+            and gemm_node.outputs[0].shape
         ):
+
+            def can_broadcast_to(shape_from, shape_to):
+                """Return True if shape_from can broadcast to shape_to per NumPy rules."""
+                if shape_from is None or shape_to is None:
+                    return False
+                try:
+                    np.empty(shape_to, dtype=np.float32) + np.empty(shape_from, dtype=np.float32)
+                    return True
+                except ValueError:
+                    return False
+
             gemm_bias_constant = gemm_node.inputs[2] if len(gemm_node.inputs) == 3 else None
             if gemm_bias_constant:
                 gemm_bias = gemm_bias_constant.values
                 add_bias = add_bias_variable.values
-
-
-
+                if (
+                    can_broadcast_to(gemm_bias.shape, gemm_node.outputs[0].shape)
+                    and can_broadcast_to(add_bias.shape, gemm_node.outputs[0].shape)
+                    and add_bias.ndim <= 2
+                ):
+                    gemm_bias_fused = gemm_bias + add_bias
+                    gemm_bias_fused_constant = gs.Constant(gemm_bias_constant.name + "_fused", values=gemm_bias_fused)
+                    gemm_node.inputs[2] = gemm_bias_fused_constant
+                else:
+                    return match_case
             else:
-                gemm_node.
+                if can_broadcast_to(add_bias_variable.values.shape, gemm_node.outputs[0].shape):
+                    gemm_node.inputs[2] = add_bias_variable
+                else:
+                    return match_case
 
             add_node.replace_all_uses_with(reshape_node)
 
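Note: the new can_broadcast_to guard lets NumPy itself decide broadcastability by attempting an addition on empty arrays of the two shapes. The same helper, lifted into a standalone sketch with made-up example shapes (illustrative only):

import numpy as np

def can_broadcast_to(shape_from, shape_to):
    """Return True if shape_from can broadcast to shape_to per NumPy rules."""
    if shape_from is None or shape_to is None:
        return False
    try:
        np.empty(shape_to, dtype=np.float32) + np.empty(shape_from, dtype=np.float32)
        return True
    except ValueError:
        return False

print(can_broadcast_to((256,), (32, 256)))   # True  -> the Add bias can be folded into Gemm's C input
print(can_broadcast_to((32, 1), (16, 256)))  # False -> the rewrite bails out and returns match_case

With this check, the fusion is skipped whenever either bias cannot broadcast to the Gemm output shape (or the Add bias has more than two dimensions), rather than producing an invalid Gemm.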
onnxslim/core/pattern/fusion/padconv.py
CHANGED

@@ -24,6 +24,7 @@ class PadConvMatcher(PatternMatcher):
     def parameter_check(self) -> bool:
         """Validates if the padding parameter for a convolutional node is a constant."""
         pad_node = self.pad_0
+
         return isinstance(pad_node.inputs[1], gs.Constant)
 
     def rewrite(self, opset=11):
@@ -42,7 +43,7 @@ class PadConvMatcher(PatternMatcher):
         ):
             if (
                 isinstance(pad_node.inputs[1], gs.Constant)
-                and pad_node.attrs
+                and pad_node.attrs.get("mode", "constant") == "constant"
                 and conv_node.inputs[1].shape
             ):
                 conv_weight_dim = len(conv_node.inputs[1].shape)
@@ -67,9 +68,10 @@ class PadConvMatcher(PatternMatcher):
                 pad_node.inputs.clear()
                 pad_node.outputs.clear()
 
-                conv_pads = attrs["pads"]
                 pads = pad_value[2:conv_weight_dim] + pad_value[conv_weight_dim + 2 :]
-
+                if hasattr(attrs, "pads"):
+                    conv_pads = attrs["pads"]
+                    pads = [pad + conv_pad for pad, conv_pad in zip(pads, conv_pads)]
 
                 attrs["pads"] = pads
                 match_case[conv_node.name] = {
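Note: the pads arithmetic in the last hunk maps ONNX Pad's per-axis layout [x1_begin, x2_begin, ..., x1_end, x2_end, ...] onto Conv's spatial-only pads attribute and, when the Conv already carries pads, sums the two. A small sketch with made-up values (illustrative only, not onnxslim code):

# Pad on an NCHW input: 1 pixel on each side of H and W, nothing on N and C
pad_value = [0, 0, 1, 1, 0, 0, 1, 1]
conv_weight_dim = 4  # length of the Conv weight shape (OIHW)

# keep only the spatial begin/end entries, dropping the N and C axes
pads = pad_value[2:conv_weight_dim] + pad_value[conv_weight_dim + 2:]
print(pads)  # [1, 1, 1, 1]

# merge with pads the Conv already had, if any
conv_pads = [2, 2, 2, 2]
pads = [pad + conv_pad for pad, conv_pad in zip(pads, conv_pads)]
print(pads)  # [3, 3, 3, 3]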
onnxslim/misc/tabulate.py
CHANGED
@@ -1606,7 +1606,7 @@ def tabulate(
     given header. Possible values are: "global" (no override), "same"
     (follow column alignment), "right", "center", "left".
 
-    Note on intended
+    Note on intended behavior: If there is no `tabular_data`, any column
     alignment argument is ignored. Hence, in this case, header
     alignment cannot be inferred from column alignment.
 
onnxslim/third_party/_sympy/functions.py
CHANGED

@@ -52,7 +52,7 @@ def simple_floordiv_gcd(p: sympy.Basic, q: sympy.Basic) -> sympy.Basic:
 
     We try to rewrite p and q in the form n*e*p1 + n*e*p2 and n*e*q0, where n is the greatest common integer factor and
     e is the largest syntactic common factor (i.e., common sub-expression) in p and q. Then the gcd returned is n*e,
-
+    canceling which we would be left with p1 + p2 and q0.
 
     Note that further factoring of p1 + p2 and q0 might be possible with sympy.factor (which uses domain-specific
     theories). E.g., we are unable to find that x*y + x + y + 1 is divisible by x + 1. More generally, when q is of the
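Note: as a worked example of the restored docstring sentence, take p = 6*x*y + 6*x and q = 3*x. With n = 3 and e = x, p = n*e*(2*y) + n*e*2 and q = n*e*1, so the gcd is n*e = 3*x; canceling it leaves p1 + p2 = 2*y + 2 and q0 = 1. A quick sympy check of that cancellation (illustrative only; it does not call simple_floordiv_gcd):

import sympy

x, y = sympy.symbols("x y", positive=True)
p = 6 * x * y + 6 * x
q = 3 * x
common = 3 * x  # n*e from the decomposition above
print(sympy.cancel(p / common), sympy.cancel(q / common))  # 2*y + 2  1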
onnxslim/third_party/symbolic_shape_infer.py
CHANGED

@@ -532,6 +532,7 @@ class SymbolicShapeInference:
         initializers = []
         if (get_opset(self.out_mp_) >= 9) and (
             node.op_type == "Unsqueeze" or node.op_type == "ReduceMax" or node.op_type == "ReduceMean"
+            or node.op_type == "DFT" or node.op_type == "ReduceL2" or node.op_type == "ReduceMin"
         ):
             initializers = [
                 self.initializers_[name]
onnxslim/version.py
CHANGED
@@ -1 +1 @@
-__version__ = "0.1.72"
+__version__ = "0.1.74"
{onnxslim-0.1.72.dist-info → onnxslim-0.1.74.dist-info}/RECORD
CHANGED

@@ -2,7 +2,7 @@ onnxslim/__init__.py,sha256=ECHGdxzg4b-4SZaPhxM_KulBi-xDbVcVUbpJc8i6a60,571
 onnxslim/__main__.py,sha256=FgDcl6xX8kV_52rB-jPVsmGqidlVhkpe_YhXK75-nFU,75
 onnxslim/argparser.py,sha256=pFv3nEZH2BiHO9ejS4Iq5ZuZ3GrpdyRQJypAyR0xF7w,8942
 onnxslim/utils.py,sha256=Z39vRKAtD7o18NFbf3Qrws9xDg81uBJ9F-RYEFAfqM8,28095
-onnxslim/version.py,sha256=
+onnxslim/version.py,sha256=KewwZnnRMQkstdiLHG6R7ciil5WtDn_QVk-Z392irYc,23
 onnxslim/cli/__init__.py,sha256=kxK27cDgWotBOWRs86rbRQf_dtmniKr1GZJeasxfESE,42
 onnxslim/cli/_main.py,sha256=jEKv9q7y_y9g0GsrfXcnk_wyMVej6jhe9QNPChE7yTs,5550
 onnxslim/core/__init__.py,sha256=uDg-Eu29Ezb3txwZf5mN0zQRVuqF-K9BvktE8WBYS4E,8825
@@ -18,20 +18,21 @@ onnxslim/core/pattern/elimination/reshape.py,sha256=XwvuPAZnXCCEwJb2n1guigstnsl3
 onnxslim/core/pattern/elimination/reshape_as.py,sha256=FI3LYR0pzbp2pDmaX13duHrQ4uqwaKNu4bG78en-7wY,2034
 onnxslim/core/pattern/elimination/slice.py,sha256=moZibU-TbtdwtmGIUwyjnjf3oRCeCBcQq0M1gY5ZWDk,5033
 onnxslim/core/pattern/elimination/unsqueeze.py,sha256=v7Rin3qB6F49ETrxXWEQQxUgtlF18nvHb6JFarf0kwQ,3855
-onnxslim/core/pattern/fusion/__init__.py,sha256=
+onnxslim/core/pattern/fusion/__init__.py,sha256=3ajHvRurL7WHL4tfNsBoLQh6Sq2fyiqH-VsPuftYMGg,183
 onnxslim/core/pattern/fusion/concat_reshape.py,sha256=LvknixTAsSUqUkGSuoEA1QpC-TmBrsx6AHZoeT0gTbI,1615
 onnxslim/core/pattern/fusion/convadd.py,sha256=P1GI7hJAHgDBO17aDDghNxMEhWkFIcqGLIfnpTMGhWk,2432
 onnxslim/core/pattern/fusion/convbn.py,sha256=1wI0nPCRj_3y2Ozortrm6gGDvy6qwH6CwHlyYLl_lRI,3340
+onnxslim/core/pattern/fusion/convmul.py,sha256=W2C6H3kWSDUg0he0jfR4tXI5GMi7gsyylQR4aSh-rik,2581
 onnxslim/core/pattern/fusion/gelu.py,sha256=uR67AJ_tL1gboY6VsTdqajHxW3Pbu656UMhCe1mQZDY,1469
-onnxslim/core/pattern/fusion/gemm.py,sha256=
-onnxslim/core/pattern/fusion/padconv.py,sha256=
+onnxslim/core/pattern/fusion/gemm.py,sha256=Ti9yZAfEprFRvW1FiAD0zvewELOJbRjposIk3yjjXfQ,12928
+onnxslim/core/pattern/fusion/padconv.py,sha256=eOutev5rOrHuyyw-BRIFzMjcvu9MxXj73kY215GaeG8,3652
 onnxslim/core/pattern/fusion/reduce.py,sha256=dMC7CPlFglrJxugsJWjcc-jQCIa_GIbW1y9K2FRvvcE,2755
 onnxslim/misc/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
-onnxslim/misc/tabulate.py,sha256=
+onnxslim/misc/tabulate.py,sha256=Pg5uU0UP18HbwG-c8LlA82LbIb_5JWQeuIB1AnturbM,99695
 onnxslim/third_party/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
-onnxslim/third_party/symbolic_shape_infer.py,sha256=
+onnxslim/third_party/symbolic_shape_infer.py,sha256=1rY4J73ZQgdMscAeTceGPrET_bRQVq5O4IP424yTtlQ,152492
 onnxslim/third_party/_sympy/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
-onnxslim/third_party/_sympy/functions.py,sha256=
+onnxslim/third_party/_sympy/functions.py,sha256=s3pKzyjYCKnvlddLFR_H8UmbbcdMB51PRxqhe9zGI9E,8876
 onnxslim/third_party/_sympy/numbers.py,sha256=w1dJJcQkKRzLDCJMn70YTwPtrEWPFrhCJpAsZhukJOk,11383
 onnxslim/third_party/_sympy/printers.py,sha256=Kv2vpR-YgjCsNxIBKkcHyvEnj_H-gmJqG03Hwh4rWdk,20429
 onnxslim/third_party/_sympy/solve.py,sha256=gcqmluQbAKzIQTTtzsgALguhK8ViRGlVP-CRDcbwP6A,6465
@@ -55,10 +56,10 @@ onnxslim/third_party/onnx_graphsurgeon/logger/logger.py,sha256=L12rrwn33RHH-2WLv
 onnxslim/third_party/onnx_graphsurgeon/util/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
 onnxslim/third_party/onnx_graphsurgeon/util/exception.py,sha256=KrsHbKEQ4237UbjlODsUzvkXoAY72LZi23ApBeFANWg,786
 onnxslim/third_party/onnx_graphsurgeon/util/misc.py,sha256=kyxInD2SCRLU4wHMeiDEYEHB3871fGks6kQTuF9uATY,8960
-onnxslim-0.1.
-onnxslim-0.1.
-onnxslim-0.1.
-onnxslim-0.1.
-onnxslim-0.1.
-onnxslim-0.1.
-onnxslim-0.1.
+onnxslim-0.1.74.dist-info/licenses/LICENSE,sha256=oHZXw-yrBwdNVGu4JtlZhMgmQHKIZ7BJJlJdhu1HKvI,1062
+onnxslim-0.1.74.dist-info/METADATA,sha256=a8Ckc7D7p-LlSnAlKbpXm6dda_ZBD9yVvW2V97Ynfok,7621
+onnxslim-0.1.74.dist-info/WHEEL,sha256=_zCd3N1l69ArxyTb8rzEoP9TpbYXkqRFSNOD5OuxnTs,91
+onnxslim-0.1.74.dist-info/entry_points.txt,sha256=O2QgceCVeGeRhnxRSDRcGiFd0ZNfElwrTiRo1W2V7KA,47
+onnxslim-0.1.74.dist-info/top_level.txt,sha256=EWFTb99i0kc6cC9akqNKp88ipzg17_VZzYN7z1kQNlA,9
+onnxslim-0.1.74.dist-info/zip-safe,sha256=AbpHGcgLb-kRsJGnwFEktk7uzpZOCcBY74-YBdrKVGs,1
+onnxslim-0.1.74.dist-info/RECORD,,
{onnxslim-0.1.72.dist-info → onnxslim-0.1.74.dist-info}/WHEEL
File without changes

{onnxslim-0.1.72.dist-info → onnxslim-0.1.74.dist-info}/entry_points.txt
File without changes

{onnxslim-0.1.72.dist-info → onnxslim-0.1.74.dist-info}/licenses/LICENSE
File without changes

{onnxslim-0.1.72.dist-info → onnxslim-0.1.74.dist-info}/top_level.txt
File without changes

{onnxslim-0.1.72.dist-info → onnxslim-0.1.74.dist-info}/zip-safe
File without changes