onnxslim 0.1.72__py3-none-any.whl → 0.1.73__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -1,6 +1,7 @@
1
1
  from .concat_reshape import *
2
2
  from .convadd import *
3
3
  from .convbn import *
4
+ from .convmul import *
4
5
  from .gelu import *
5
6
  from .gemm import *
6
7
  from .padconv import *
@@ -0,0 +1,69 @@
1
+ import onnxslim.third_party.onnx_graphsurgeon as gs
2
+ from onnxslim.core.pattern import Pattern, PatternMatcher
3
+ from onnxslim.core.pattern.registry import register_fusion_pattern
4
+
5
+
6
class ConvMulMatcher(PatternMatcher):
    """Fuses ``Conv -> Mul(constant)`` into a single ``Conv`` with rescaled weight (and bias).

    ``Conv(x, W, b) * s == Conv(x, W * s_w, b * s)`` where ``s_w`` is ``s`` laid out along the
    weight's output-channel axis. The identity only holds when ``s`` varies along the Conv
    output's channel axis alone, so ``rewrite`` validates the Mul constant's shape first.
    """

    def __init__(self, priority):
        """Initializes the ConvMulMatcher for fusing Conv and Mul layers in an ONNX graph."""
        pattern = Pattern(
            """
            input input 0 1 conv_0
            Conv conv_0 1+ 1 input mul_0
            Mul mul_0 2 1 conv_0 ? output
            output output 1 0 mul_0
            """
        )
        super().__init__(pattern, priority)

    @property
    def name(self):
        """Returns the name of the FusionConvMul pattern."""
        return "FusionConvMul"

    @staticmethod
    def _channel_scale(mul_values, out_channels, out_rank):
        """Return the Mul constant as a 1-D vector of length ``out_channels``, or None if not fusible.

        The constant is right-aligned against the Conv output layout ``[N, M, *spatial]``
        (ONNX/NumPy multidirectional broadcasting). It is fusible only if broadcasting adds no
        axes and every axis except the channel axis (index 1) has size 1. The channel axis must
        be 1 (a uniform scale) or ``out_channels``.

        Args:
            mul_values: The Mul's constant operand (a NumPy array).
            out_channels: Number of Conv output channels, i.e. ``W.shape[0]``.
            out_rank: Rank of the Conv output, which equals the weight rank ``W.ndim``.

        Returns:
            A 1-D array of length ``out_channels``, or None when folding would change results.
        """
        if out_rank < 2 or mul_values.ndim > out_rank:
            # Broadcasting would prepend axes (changing the output rank); cannot fold.
            return None
        aligned = (1,) * (out_rank - mul_values.ndim) + tuple(mul_values.shape)
        if any(dim != 1 for axis, dim in enumerate(aligned) if axis != 1):
            # e.g. a plain [M] vector against a 4-D output scales the last (width) axis,
            # not the channels, even when its length happens to equal M.
            return None
        if aligned[1] not in (1, out_channels):
            return None
        flat = mul_values.reshape(-1)
        # Size-1 (uniform) scales are expanded so callers always get one factor per channel.
        return flat if flat.shape[0] == out_channels else flat.repeat(out_channels)

    def rewrite(self, opset=11):
        """Fold the constant Mul into the Conv's weight (and bias, if present).

        Args:
            opset: Target opset version (unused; kept for the PatternMatcher interface).

        Returns:
            dict: ``{conv_name: node_kwargs}`` describing the replacement Conv, or an empty
            dict when the pair cannot be fused safely (the matched nodes are then untouched).
        """
        match_case = {}
        conv_node = self.conv_0
        mul_node = self.mul_0

        # The Conv output must feed only this Mul, otherwise other consumers would see scaled data.
        if not (len(conv_node.users) == 1 and conv_node.users[0] == mul_node):
            return match_case

        conv_weight = conv_node.inputs[1]
        conv_bias = conv_node.inputs[2] if len(conv_node.inputs) == 3 else None
        mul_operand = mul_node.inputs[1]

        # Only initializers can be folded; a dynamic weight/bias/scale has no `.values`.
        params = [conv_weight] if conv_bias is None else [conv_weight, conv_bias]
        if not all(isinstance(tensor, gs.Constant) for tensor in [mul_operand, *params]):
            return match_case
        # The fused tensors reuse the original names. If another node also consumes the original
        # tensor, the graph would hold two different tensors under one name, so only fold when
        # this Conv is the sole consumer.
        if any(len(tensor.outputs) != 1 for tensor in params):
            return match_case

        weight = conv_weight.values
        scale = self._channel_scale(mul_operand.values, weight.shape[0], weight.ndim)
        if scale is None:
            return match_case

        # Scale W along its output-channel axis (axis 0); cast back so a differently-typed
        # constant cannot change the initializer dtype.
        scale_for_weight = scale.reshape((-1,) + (1,) * (weight.ndim - 1))
        new_weight = (weight * scale_for_weight).astype(weight.dtype, copy=False)
        inputs = [conv_node.inputs[0], gs.Constant(conv_weight.name, values=new_weight)]

        if conv_bias is not None:
            bias = conv_bias.values
            new_bias = (bias * scale).astype(bias.dtype, copy=False)
            inputs.append(gs.Constant(conv_bias.name, values=new_bias))

        outputs = list(mul_node.outputs)

        # Detach the matched nodes; the replacement Conv is described by match_case below.
        conv_node.outputs.clear()
        mul_node.inputs.clear()
        mul_node.outputs.clear()

        match_case[conv_node.name] = {
            "op": conv_node.op,
            "inputs": inputs,
            "outputs": outputs,
            "name": conv_node.name,
            "attrs": conv_node.attrs,
            "domain": None,
        }

        return match_case
67
+
68
+
69
# Make the Conv+Mul fusion available to the pattern registry at priority 1.
register_fusion_pattern(ConvMulMatcher(priority=1))
@@ -289,16 +289,38 @@ class GemmAddPatternMatcher(PatternMatcher):
289
289
  )
290
290
  and add_bias_variable
291
291
  and len(reshape_node.users) == 1
292
+ and gemm_node.outputs[0].shape
292
293
  ):
294
+
295
def can_broadcast_to(shape_from, shape_to):
    """Return True if ``shape_from`` broadcasts *unidirectionally* to ``shape_to``.

    Unidirectional means the broadcast result is exactly ``shape_to``, which is what ONNX Gemm
    requires of its ``C`` input. ``shape_from`` therefore may not have more axes than
    ``shape_to``. Each of its trailing-aligned dims must be 1 or equal to the target dim.

    Dims may be symbolic (e.g. ``"batch"``) or unknown (``None``). These are compared by
    equality only, and an unknown source dim is accepted only if... never (it is rejected
    unless it is literally 1). No arrays are allocated.

    Args:
        shape_from: Shape of the tensor being broadcast, or None if unknown.
        shape_to: Target shape, or None if unknown.

    Returns:
        bool: False when either shape is None or the broadcast is not guaranteed valid.
    """
    if shape_from is None or shape_to is None:
        return False
    if len(shape_from) > len(shape_to):
        # Extra leading source axes would enlarge the result beyond shape_to.
        return False
    # Zipping the reversed shapes aligns trailing axes; extra leading target axes are fine.
    return all(
        dim_from == 1 or (dim_from is not None and dim_from == dim_to)
        for dim_from, dim_to in zip(reversed(shape_from), reversed(shape_to))
    )
304
+
293
305
  gemm_bias_constant = gemm_node.inputs[2] if len(gemm_node.inputs) == 3 else None
294
306
  if gemm_bias_constant:
295
307
  gemm_bias = gemm_bias_constant.values
296
308
  add_bias = add_bias_variable.values
297
- gemm_bias_fused = gemm_bias + add_bias
298
- gemm_bias_fused_constant = gs.Constant(gemm_bias_constant.name + "_fused", values=gemm_bias_fused)
299
- gemm_node.inputs[2] = gemm_bias_fused_constant
309
+ if (
310
+ can_broadcast_to(gemm_bias.shape, gemm_node.outputs[0].shape)
311
+ and can_broadcast_to(add_bias.shape, gemm_node.outputs[0].shape)
312
+ and add_bias.ndim <= 2
313
+ ):
314
+ gemm_bias_fused = gemm_bias + add_bias
315
+ gemm_bias_fused_constant = gs.Constant(gemm_bias_constant.name + "_fused", values=gemm_bias_fused)
316
+ gemm_node.inputs[2] = gemm_bias_fused_constant
317
+ else:
318
+ return match_case
300
319
  else:
301
- gemm_node.inputs[2] = add_bias_variable
320
+ if can_broadcast_to(add_bias_variable.values.shape, gemm_node.outputs[0].shape):
321
+ gemm_node.inputs[2] = add_bias_variable
322
+ else:
323
+ return match_case
302
324
 
303
325
  add_node.replace_all_uses_with(reshape_node)
304
326
 
@@ -24,6 +24,7 @@ class PadConvMatcher(PatternMatcher):
24
24
  def parameter_check(self) -> bool:
25
25
  """Validates if the padding parameter for a convolutional node is a constant."""
26
26
  pad_node = self.pad_0
27
+
27
28
  return isinstance(pad_node.inputs[1], gs.Constant)
28
29
 
29
30
  def rewrite(self, opset=11):
@@ -42,7 +43,7 @@ class PadConvMatcher(PatternMatcher):
42
43
  ):
43
44
  if (
44
45
  isinstance(pad_node.inputs[1], gs.Constant)
45
- and pad_node.attrs["mode"] == "constant"
46
+ and pad_node.attrs.get("mode", "constant") == "constant"
46
47
  and conv_node.inputs[1].shape
47
48
  ):
48
49
  conv_weight_dim = len(conv_node.inputs[1].shape)
onnxslim/misc/tabulate.py CHANGED
@@ -1606,7 +1606,7 @@ def tabulate(
1606
1606
  given header. Possible values are: "global" (no override), "same"
1607
1607
  (follow column alignment), "right", "center", "left".
1608
1608
 
1609
- Note on intended behaviour: If there is no `tabular_data`, any column
1609
+ Note on intended behavior: If there is no `tabular_data`, any column
1610
1610
  alignment argument is ignored. Hence, in this case, header
1611
1611
  alignment cannot be inferred from column alignment.
1612
1612
 
@@ -52,7 +52,7 @@ def simple_floordiv_gcd(p: sympy.Basic, q: sympy.Basic) -> sympy.Basic:
52
52
 
53
53
  We try to rewrite p and q in the form n*e*p1 + n*e*p2 and n*e*q0, where n is the greatest common integer factor and
54
54
  e is the largest syntactic common factor (i.e., common sub-expression) in p and q. Then the gcd returned is n*e,
55
- cancelling which we would be left with p1 + p2 and q0.
55
+ canceling which we would be left with p1 + p2 and q0.
56
56
 
57
57
  Note that further factoring of p1 + p2 and q0 might be possible with sympy.factor (which uses domain-specific
58
58
  theories). E.g., we are unable to find that x*y + x + y + 1 is divisible by x + 1. More generally, when q is of the
onnxslim/version.py CHANGED
@@ -1 +1 @@
1
- __version__ = "0.1.72"
1
+ __version__ = "0.1.73"
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.4
2
2
  Name: onnxslim
3
- Version: 0.1.72
3
+ Version: 0.1.73
4
4
  Summary: OnnxSlim: A Toolkit to Help Optimize Onnx Model
5
5
  Home-page: https://github.com/inisis/OnnxSlim
6
6
  Author: inisis
@@ -2,7 +2,7 @@ onnxslim/__init__.py,sha256=ECHGdxzg4b-4SZaPhxM_KulBi-xDbVcVUbpJc8i6a60,571
2
2
  onnxslim/__main__.py,sha256=FgDcl6xX8kV_52rB-jPVsmGqidlVhkpe_YhXK75-nFU,75
3
3
  onnxslim/argparser.py,sha256=pFv3nEZH2BiHO9ejS4Iq5ZuZ3GrpdyRQJypAyR0xF7w,8942
4
4
  onnxslim/utils.py,sha256=Z39vRKAtD7o18NFbf3Qrws9xDg81uBJ9F-RYEFAfqM8,28095
5
- onnxslim/version.py,sha256=gUdfeW8Q-eHCbw8IWb2lI3qRGa6BpSk5KcikOGNzSdM,23
5
+ onnxslim/version.py,sha256=B2xptQoAZtCL_O_fLMSWF71JDtlCZp8jLmdtHyMdosE,23
6
6
  onnxslim/cli/__init__.py,sha256=kxK27cDgWotBOWRs86rbRQf_dtmniKr1GZJeasxfESE,42
7
7
  onnxslim/cli/_main.py,sha256=jEKv9q7y_y9g0GsrfXcnk_wyMVej6jhe9QNPChE7yTs,5550
8
8
  onnxslim/core/__init__.py,sha256=uDg-Eu29Ezb3txwZf5mN0zQRVuqF-K9BvktE8WBYS4E,8825
@@ -18,20 +18,21 @@ onnxslim/core/pattern/elimination/reshape.py,sha256=XwvuPAZnXCCEwJb2n1guigstnsl3
18
18
  onnxslim/core/pattern/elimination/reshape_as.py,sha256=FI3LYR0pzbp2pDmaX13duHrQ4uqwaKNu4bG78en-7wY,2034
19
19
  onnxslim/core/pattern/elimination/slice.py,sha256=moZibU-TbtdwtmGIUwyjnjf3oRCeCBcQq0M1gY5ZWDk,5033
20
20
  onnxslim/core/pattern/elimination/unsqueeze.py,sha256=v7Rin3qB6F49ETrxXWEQQxUgtlF18nvHb6JFarf0kwQ,3855
21
- onnxslim/core/pattern/fusion/__init__.py,sha256=e1uydy8ab-5Niq1yQjoxcoz8q7bBzb0wFPL8NVUvhCs,160
21
+ onnxslim/core/pattern/fusion/__init__.py,sha256=3ajHvRurL7WHL4tfNsBoLQh6Sq2fyiqH-VsPuftYMGg,183
22
22
  onnxslim/core/pattern/fusion/concat_reshape.py,sha256=LvknixTAsSUqUkGSuoEA1QpC-TmBrsx6AHZoeT0gTbI,1615
23
23
  onnxslim/core/pattern/fusion/convadd.py,sha256=P1GI7hJAHgDBO17aDDghNxMEhWkFIcqGLIfnpTMGhWk,2432
24
24
  onnxslim/core/pattern/fusion/convbn.py,sha256=1wI0nPCRj_3y2Ozortrm6gGDvy6qwH6CwHlyYLl_lRI,3340
25
+ onnxslim/core/pattern/fusion/convmul.py,sha256=W2C6H3kWSDUg0he0jfR4tXI5GMi7gsyylQR4aSh-rik,2581
25
26
  onnxslim/core/pattern/fusion/gelu.py,sha256=uR67AJ_tL1gboY6VsTdqajHxW3Pbu656UMhCe1mQZDY,1469
26
- onnxslim/core/pattern/fusion/gemm.py,sha256=M0femDKmtak5fGIifN3MZDAU0WFW9GBqypUi38zeU5c,11942
27
- onnxslim/core/pattern/fusion/padconv.py,sha256=2Y1bgW6yRO3Yv-u2Pf5xbMIOeI3JMQc43ZgRy__rIuA,3580
27
+ onnxslim/core/pattern/fusion/gemm.py,sha256=Ti9yZAfEprFRvW1FiAD0zvewELOJbRjposIk3yjjXfQ,12928
28
+ onnxslim/core/pattern/fusion/padconv.py,sha256=O2DtY1XxP7-3k3vRBRPiNh4TuAyFW1n1-mPebwMdcXc,3597
28
29
  onnxslim/core/pattern/fusion/reduce.py,sha256=dMC7CPlFglrJxugsJWjcc-jQCIa_GIbW1y9K2FRvvcE,2755
29
30
  onnxslim/misc/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
30
- onnxslim/misc/tabulate.py,sha256=SqfeKm24drfruIyCwHQmVftsIyRW3lCYyq6Hs7RPwLw,99696
31
+ onnxslim/misc/tabulate.py,sha256=Pg5uU0UP18HbwG-c8LlA82LbIb_5JWQeuIB1AnturbM,99695
31
32
  onnxslim/third_party/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
32
33
  onnxslim/third_party/symbolic_shape_infer.py,sha256=KbdxXuRnjhj1vJGDR4SCIwq-sdUKPsSnc0ZrzAtaQ8s,152390
33
34
  onnxslim/third_party/_sympy/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
34
- onnxslim/third_party/_sympy/functions.py,sha256=FrrTHPkDhwh_EccLGugbcodeOjfGeM17T5_oVPUjOlA,8877
35
+ onnxslim/third_party/_sympy/functions.py,sha256=s3pKzyjYCKnvlddLFR_H8UmbbcdMB51PRxqhe9zGI9E,8876
35
36
  onnxslim/third_party/_sympy/numbers.py,sha256=w1dJJcQkKRzLDCJMn70YTwPtrEWPFrhCJpAsZhukJOk,11383
36
37
  onnxslim/third_party/_sympy/printers.py,sha256=Kv2vpR-YgjCsNxIBKkcHyvEnj_H-gmJqG03Hwh4rWdk,20429
37
38
  onnxslim/third_party/_sympy/solve.py,sha256=gcqmluQbAKzIQTTtzsgALguhK8ViRGlVP-CRDcbwP6A,6465
@@ -55,10 +56,10 @@ onnxslim/third_party/onnx_graphsurgeon/logger/logger.py,sha256=L12rrwn33RHH-2WLv
55
56
  onnxslim/third_party/onnx_graphsurgeon/util/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
56
57
  onnxslim/third_party/onnx_graphsurgeon/util/exception.py,sha256=KrsHbKEQ4237UbjlODsUzvkXoAY72LZi23ApBeFANWg,786
57
58
  onnxslim/third_party/onnx_graphsurgeon/util/misc.py,sha256=kyxInD2SCRLU4wHMeiDEYEHB3871fGks6kQTuF9uATY,8960
58
- onnxslim-0.1.72.dist-info/licenses/LICENSE,sha256=oHZXw-yrBwdNVGu4JtlZhMgmQHKIZ7BJJlJdhu1HKvI,1062
59
- onnxslim-0.1.72.dist-info/METADATA,sha256=HyWNHr0zT0rr-y-9du1FsX4q8uQu5JDXkm43qDPaRo0,7621
60
- onnxslim-0.1.72.dist-info/WHEEL,sha256=_zCd3N1l69ArxyTb8rzEoP9TpbYXkqRFSNOD5OuxnTs,91
61
- onnxslim-0.1.72.dist-info/entry_points.txt,sha256=O2QgceCVeGeRhnxRSDRcGiFd0ZNfElwrTiRo1W2V7KA,47
62
- onnxslim-0.1.72.dist-info/top_level.txt,sha256=EWFTb99i0kc6cC9akqNKp88ipzg17_VZzYN7z1kQNlA,9
63
- onnxslim-0.1.72.dist-info/zip-safe,sha256=AbpHGcgLb-kRsJGnwFEktk7uzpZOCcBY74-YBdrKVGs,1
64
- onnxslim-0.1.72.dist-info/RECORD,,
59
+ onnxslim-0.1.73.dist-info/licenses/LICENSE,sha256=oHZXw-yrBwdNVGu4JtlZhMgmQHKIZ7BJJlJdhu1HKvI,1062
60
+ onnxslim-0.1.73.dist-info/METADATA,sha256=NZ6PTQNwmexVWSCO0kplgCbKwpG13zfFmCKeqAyxJSk,7621
61
+ onnxslim-0.1.73.dist-info/WHEEL,sha256=_zCd3N1l69ArxyTb8rzEoP9TpbYXkqRFSNOD5OuxnTs,91
62
+ onnxslim-0.1.73.dist-info/entry_points.txt,sha256=O2QgceCVeGeRhnxRSDRcGiFd0ZNfElwrTiRo1W2V7KA,47
63
+ onnxslim-0.1.73.dist-info/top_level.txt,sha256=EWFTb99i0kc6cC9akqNKp88ipzg17_VZzYN7z1kQNlA,9
64
+ onnxslim-0.1.73.dist-info/zip-safe,sha256=AbpHGcgLb-kRsJGnwFEktk7uzpZOCcBY74-YBdrKVGs,1
65
+ onnxslim-0.1.73.dist-info/RECORD,,