onnxslim 0.1.80__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (65)
  1. onnxslim/__init__.py +16 -0
  2. onnxslim/__main__.py +4 -0
  3. onnxslim/argparser.py +215 -0
  4. onnxslim/cli/__init__.py +1 -0
  5. onnxslim/cli/_main.py +180 -0
  6. onnxslim/core/__init__.py +219 -0
  7. onnxslim/core/optimization/__init__.py +146 -0
  8. onnxslim/core/optimization/dead_node_elimination.py +151 -0
  9. onnxslim/core/optimization/subexpression_elimination.py +76 -0
  10. onnxslim/core/optimization/weight_tying.py +59 -0
  11. onnxslim/core/pattern/__init__.py +249 -0
  12. onnxslim/core/pattern/elimination/__init__.py +5 -0
  13. onnxslim/core/pattern/elimination/concat.py +61 -0
  14. onnxslim/core/pattern/elimination/reshape.py +77 -0
  15. onnxslim/core/pattern/elimination/reshape_as.py +64 -0
  16. onnxslim/core/pattern/elimination/slice.py +108 -0
  17. onnxslim/core/pattern/elimination/unsqueeze.py +92 -0
  18. onnxslim/core/pattern/fusion/__init__.py +8 -0
  19. onnxslim/core/pattern/fusion/concat_reshape.py +50 -0
  20. onnxslim/core/pattern/fusion/convadd.py +70 -0
  21. onnxslim/core/pattern/fusion/convbn.py +86 -0
  22. onnxslim/core/pattern/fusion/convmul.py +69 -0
  23. onnxslim/core/pattern/fusion/gelu.py +47 -0
  24. onnxslim/core/pattern/fusion/gemm.py +330 -0
  25. onnxslim/core/pattern/fusion/padconv.py +89 -0
  26. onnxslim/core/pattern/fusion/reduce.py +67 -0
  27. onnxslim/core/pattern/registry.py +28 -0
  28. onnxslim/misc/__init__.py +0 -0
  29. onnxslim/misc/tabulate.py +2681 -0
  30. onnxslim/third_party/__init__.py +0 -0
  31. onnxslim/third_party/_sympy/__init__.py +0 -0
  32. onnxslim/third_party/_sympy/functions.py +205 -0
  33. onnxslim/third_party/_sympy/numbers.py +397 -0
  34. onnxslim/third_party/_sympy/printers.py +491 -0
  35. onnxslim/third_party/_sympy/solve.py +172 -0
  36. onnxslim/third_party/_sympy/symbol.py +102 -0
  37. onnxslim/third_party/onnx_graphsurgeon/__init__.py +15 -0
  38. onnxslim/third_party/onnx_graphsurgeon/exporters/__init__.py +1 -0
  39. onnxslim/third_party/onnx_graphsurgeon/exporters/base_exporter.py +33 -0
  40. onnxslim/third_party/onnx_graphsurgeon/exporters/onnx_exporter.py +432 -0
  41. onnxslim/third_party/onnx_graphsurgeon/graph_pattern/__init__.py +4 -0
  42. onnxslim/third_party/onnx_graphsurgeon/graph_pattern/graph_pattern.py +466 -0
  43. onnxslim/third_party/onnx_graphsurgeon/importers/__init__.py +1 -0
  44. onnxslim/third_party/onnx_graphsurgeon/importers/base_importer.py +33 -0
  45. onnxslim/third_party/onnx_graphsurgeon/importers/onnx_importer.py +558 -0
  46. onnxslim/third_party/onnx_graphsurgeon/ir/__init__.py +0 -0
  47. onnxslim/third_party/onnx_graphsurgeon/ir/function.py +274 -0
  48. onnxslim/third_party/onnx_graphsurgeon/ir/graph.py +1575 -0
  49. onnxslim/third_party/onnx_graphsurgeon/ir/node.py +266 -0
  50. onnxslim/third_party/onnx_graphsurgeon/ir/tensor.py +504 -0
  51. onnxslim/third_party/onnx_graphsurgeon/logger/__init__.py +1 -0
  52. onnxslim/third_party/onnx_graphsurgeon/logger/logger.py +261 -0
  53. onnxslim/third_party/onnx_graphsurgeon/util/__init__.py +0 -0
  54. onnxslim/third_party/onnx_graphsurgeon/util/exception.py +20 -0
  55. onnxslim/third_party/onnx_graphsurgeon/util/misc.py +252 -0
  56. onnxslim/third_party/symbolic_shape_infer.py +3273 -0
  57. onnxslim/utils.py +794 -0
  58. onnxslim/version.py +1 -0
  59. onnxslim-0.1.80.dist-info/METADATA +207 -0
  60. onnxslim-0.1.80.dist-info/RECORD +65 -0
  61. onnxslim-0.1.80.dist-info/WHEEL +5 -0
  62. onnxslim-0.1.80.dist-info/entry_points.txt +2 -0
  63. onnxslim-0.1.80.dist-info/licenses/LICENSE +21 -0
  64. onnxslim-0.1.80.dist-info/top_level.txt +1 -0
  65. onnxslim-0.1.80.dist-info/zip-safe +1 -0
@@ -0,0 +1,61 @@
1
+ from onnxslim.core.pattern import Pattern, PatternMatcher
2
+ from onnxslim.core.pattern.registry import register_fusion_pattern
3
+
4
+
5
class ConcatPatternMatcher(PatternMatcher):
    """Flattens a Concat whose output feeds another Concat on the same axis."""

    def __init__(self, priority):
        """Initializes the ConcatPatternMatcher with a specified priority using a predefined graph pattern."""
        pattern = Pattern(
            """
            input input 0 1 concat_0
            Concat concat_0 1+ 1 input concat_1
            Concat concat_1 1* 1 concat_0 output
            output output 1 0 concat_1
            """
        )
        super().__init__(pattern, priority)

    @property
    def name(self):
        """Returns the name of the elimination pattern, 'EliminationConcat'."""
        return "EliminationConcat"

    def rewrite(self, opset=11):
        """Splice the inputs of ``concat_0`` into ``concat_1`` in place of ``concat_0``'s output.

        ``Concat(a, Concat(b, c), d)`` becomes ``Concat(a, b, c, d)``. The rewrite is only
        attempted when every consumer of ``concat_0`` is a Concat whose ``axis`` (missing is
        read as 0) equals ``concat_0``'s.

        Args:
            opset: Unused here; kept for the common ``rewrite`` signature.

        Returns:
            dict: ``{concat_1.name: node description}`` for the replacement Concat, or an empty
            dict when the pattern does not apply.
        """
        match_case = {}

        node_concat_0 = self.concat_0
        node_concat_1 = self.concat_1
        users_node_concat_0 = node_concat_0.users
        axis = node_concat_0.attrs.get("axis", 0)

        if not all(user.op == "Concat" and user.attrs.get("axis", 0) == axis for user in users_node_concat_0):
            return match_case

        # Replace the first occurrence of concat_0's output in concat_1 by concat_0's own inputs.
        index = node_concat_1.inputs.index(node_concat_0.outputs[0])
        node_concat_1.inputs.pop(index)
        for offset, item in enumerate(node_concat_0.inputs):
            node_concat_1.inputs.insert(index + offset, item)

        inputs = list(node_concat_1.inputs)
        outputs = list(node_concat_1.outputs)
        # Detach the old concat_1; the match_case entry below reuses its input/output tensors.
        node_concat_1.inputs.clear()
        node_concat_1.outputs.clear()

        # NOTE(review): users_node_concat_0 was captured before the rewrite and includes
        # concat_1, so this looks unreachable unless ``users`` is a live view — confirm.
        if len(users_node_concat_0) == 0:
            node_concat_0.inputs.clear()
            node_concat_0.outputs.clear()

        match_case[node_concat_1.name] = {
            "op": "Concat",
            "inputs": inputs,
            "outputs": outputs,
            "name": node_concat_1.name,
            "attrs": {"axis": axis},
            "domain": None,
        }

        return match_case


register_fusion_pattern(ConcatPatternMatcher(1))
@@ -0,0 +1,77 @@
1
+ import numpy as np
2
+
3
+ import onnxslim.third_party.onnx_graphsurgeon as gs
4
+ from onnxslim.core.pattern import Pattern, PatternMatcher
5
+ from onnxslim.core.pattern.registry import register_fusion_pattern
6
+
7
+
8
class ReshapePatternMatcher(PatternMatcher):
    """Merges two back-to-back Reshape nodes into one."""

    def __init__(self, priority):
        """Initializes the ReshapePatternMatcher with a priority and a specific pattern for detecting nested reshape
        operations.
        """
        pattern = Pattern(
            """
            input input 0 1 reshape_0
            Reshape reshape_0 2 1 input ? reshape_1
            Reshape reshape_1 2 1 reshape_0 ? output
            output output 1 0 reshape_1
            """
        )
        super().__init__(pattern, priority)

    @property
    def name(self):
        """Returns the name 'EliminationReshape'."""
        return "EliminationReshape"

    @staticmethod
    def _merged_shape_input(reshape_node):
        """Return the shape tensor to use once ``reshape_node``'s data input is bypassed, or None.

        In a Reshape target shape (``allowzero`` unset), 0 means "copy this dimension from the
        data input". Giving the same shape to a Reshape with a different data input would change
        what those zeros refer to, so constant zeros are replaced by the concrete dimensions of
        ``reshape_node``'s current input. Returns None when that cannot be done safely.
        """
        shape_tensor = reshape_node.inputs[1]
        if not isinstance(shape_tensor, gs.Constant):
            # NOTE(review): a dynamic target shape is reused unchanged, as before; a runtime 0
            # entry would then refer to the merged node's input — confirm this cannot occur.
            return shape_tensor

        target = shape_tensor.values.reshape(-1).tolist()
        if 0 not in target:
            return shape_tensor
        if reshape_node.attrs.get("allowzero", 0):
            # Zeros are literal sizes here, but the merged node keeps the first Reshape's attrs.
            return None

        input_shape = reshape_node.inputs[0].shape
        if input_shape is None:
            return None

        resolved = []
        for axis, dim in enumerate(target):
            if dim == 0:
                if axis >= len(input_shape):
                    return None
                dim = input_shape[axis]
                # A symbolic or zero-sized dim cannot be written back as a constant safely.
                if not isinstance(dim, int) or dim <= 0:
                    return None
            resolved.append(dim)

        return gs.Constant(
            shape_tensor.name + "_resolved",
            values=np.array(resolved, dtype=np.int64),
        )

    def rewrite(self, opset=11):
        """Merge ``Reshape(Reshape(x, s0), s1)`` into a single ``Reshape(x, s1')``.

        Only ``s1`` determines the final shape, so ``s0`` is dropped. ``s1'`` is ``s1`` with any
        constant 0 ("copy input dim") entries replaced by ``reshape_0``'s output dims; otherwise
        those zeros would be read against ``x`` after the merge. Applies only when ``reshape_0``
        has a single consumer.

        Returns:
            dict: ``{reshape_0.name: node description}`` for the merged Reshape (which keeps
            ``reshape_0``'s name and attrs), or an empty dict when nothing is rewritten.
        """
        match_case = {}
        second_reshape_node = self.reshape_1
        first_reshape_node = second_reshape_node.i(0)
        if len(first_reshape_node.users) != 1:
            return match_case

        shape_input = self._merged_shape_input(second_reshape_node)
        if shape_input is None:
            return match_case

        inputs = [first_reshape_node.inputs[0], shape_input]
        outputs = list(second_reshape_node.outputs)
        first_reshape_node.outputs.clear()
        second_reshape_node.inputs.clear()
        second_reshape_node.outputs.clear()

        match_case[first_reshape_node.name] = {
            "op": "Reshape",
            "inputs": inputs,
            "outputs": outputs,
            "name": first_reshape_node.name,
            "attrs": first_reshape_node.attrs,
            "domain": None,
        }

        return match_case


register_fusion_pattern(ReshapePatternMatcher(1))
@@ -0,0 +1,64 @@
1
+ import onnxslim.third_party.onnx_graphsurgeon as gs
2
+ from onnxslim.core.pattern import Pattern, PatternMatcher
3
+ from onnxslim.core.pattern.registry import register_fusion_pattern
4
+
5
+
6
class ReshapeAsPatternMatcher(PatternMatcher):
    """Replaces ``Concat(Unsqueeze(Gather(Shape(x), i)) for i in range(rank))`` with ``Shape(x)``."""

    def __init__(self, priority):
        """Initializes the ReshapeAsPatternMatcher with a priority and a specific pattern for reshape as operations."""
        pattern = Pattern(
            """
            input input 0 1 shape
            Shape shape 1+ 1 input gather
            Gather gather 1+ 1 shape unsqueeze
            Unsqueeze unsqueeze 1+ 1 gather output
            Concat concat 1+ 1 unsqueeze output
            output output 1 0 concat
            """
        )
        super().__init__(pattern, priority)

    @property
    def name(self):
        """Returns the name 'EliminationReshapeAs'."""
        return "EliminationReshapeAs"

    @staticmethod
    def _producer(tensor):
        """Return the single node producing ``tensor``, or None for constants and other tensors."""
        if isinstance(tensor, gs.Constant) or len(tensor.inputs) != 1:
            return None
        return tensor.inputs[0]

    @staticmethod
    def _unsqueeze_axes(node):
        """Return an Unsqueeze node's axes as a list of ints, or None if not statically known."""
        if "axes" in node.attrs:  # axes stored as an attribute (older opsets)
            return list(node.attrs["axes"])
        if len(node.inputs) > 1 and isinstance(node.inputs[1], gs.Constant):  # axes as an input
            return node.inputs[1].values.reshape(-1).tolist()
        return None

    def parameter_check(self) -> bool:
        """Return True only if ``concat`` rebuilds ``shape``'s output element by element, in order.

        Input ``i`` of the Concat must be ``Unsqueeze(Gather(shape_out, i), axes=[0])`` with a
        scalar Constant index ``i``, for every ``i`` in ``range(rank)``, concatenated on axis 0.
        The previous check compared ``len(concat.inputs)`` against the users *list*, so it never
        passed. That check alone would also not prove that the Concat reproduces the Shape output.
        """
        shape_output = self.shape.outputs[0]
        dims = shape_output.shape
        if dims is None or len(dims) != 1 or not isinstance(dims[0], int):
            return False
        rank = dims[0]

        concat_node = self.concat
        if concat_node.attrs.get("axis", 0) not in (0, -1) or len(concat_node.inputs) != rank:
            return False

        for idx, concat_input in enumerate(concat_node.inputs):
            unsqueeze = self._producer(concat_input)
            if unsqueeze is None or unsqueeze.op != "Unsqueeze":
                return False
            if self._unsqueeze_axes(unsqueeze) not in ([0], [-1]):
                return False

            gather = self._producer(unsqueeze.inputs[0])
            if gather is None or gather.op != "Gather":
                return False
            if gather.inputs[0] is not shape_output or gather.attrs.get("axis", 0) not in (0, -1):
                return False

            index = gather.inputs[1]
            if not isinstance(index, gs.Constant) or index.values.shape != () or int(index.values) != idx:
                return False

        return True

    def rewrite(self, opset=11):
        """Rewrites the pattern by replacing the Concat node with the Shape node.

        Uses ``replace_all_uses_with`` (presumably rerouting the Concat's consumers to the Shape
        node). The graph is edited directly, so the returned dict is always empty.
        """
        match_case = {}
        shape_node = self.shape
        concat_node = self.concat

        concat_node.replace_all_uses_with(shape_node)

        return match_case


register_fusion_pattern(ReshapeAsPatternMatcher(1))
@@ -0,0 +1,108 @@
1
+ import numpy as np
2
+
3
+ import onnxslim.third_party.onnx_graphsurgeon as gs
4
+ from onnxslim.core.pattern import Pattern, PatternMatcher
5
+ from onnxslim.core.pattern.registry import register_fusion_pattern
6
+
7
+
8
class SlicePatternMatcher(PatternMatcher):
    """Merges a constant Slice into the constant Slices that consume it."""

    def __init__(self, priority):
        """Initializes the SlicePatternMatcher with a specified priority using a predefined graph pattern."""
        pattern = Pattern(
            """
            input input 0 1 slice_0
            Slice slice_0 5 1 input ? ? ? ? slice_1
            Slice slice_1 5 1 slice_0 ? ? ? ? output
            output output 1 0 slice_1
            """
        )  # to check here slice_0
        super().__init__(pattern, priority)

    @property
    def name(self):
        """Returns the name of the elimination pattern, 'EliminationSlice'."""
        return "EliminationSlice"

    def rewrite(self, opset=11):
        """Merge constant Slices that read ``slice_0``'s output into single Slices on its input.

        Two Slices acting on disjoint axes compose into one Slice whose starts/ends/axes/steps are
        the concatenation of both. Every consumer of ``slice_0`` must be a 5-input Slice with
        Constant parameters. A consumer is skipped when its axes overlap ``slice_0``'s, where
        negative axes are mapped through the input rank; if that rank is unknown and an axis is
        negative, the consumer is skipped too.

        Returns:
            dict: replacement-node descriptions keyed by node name. With a single consumer the
            merged node reuses ``slice_0``'s name and attrs; otherwise each merged node reuses its
            consumer's name and attrs.
        """
        match_case = {}
        first_slice_node = self.slice_0
        first_inputs = list(first_slice_node.inputs)
        if not all(isinstance(tensor, gs.Constant) for tensor in first_inputs[1:]):
            return match_case

        first_users = first_slice_node.users
        if not all(
            user.op == "Slice"
            and len(user.inputs) == 5  # inputs[1..4] are indexed below
            and all(isinstance(tensor, gs.Constant) for tensor in list(user.inputs)[1:])
            for user in first_users
        ):
            return match_case

        first_starts = first_inputs[1].values.tolist()
        first_ends = first_inputs[2].values.tolist()
        first_axes = first_inputs[3].values.tolist()
        first_steps = first_inputs[4].values.tolist()

        # Slice keeps the rank of its data input, so both Slices see the same rank.
        data_shape = first_inputs[0].shape
        rank = None if data_shape is None else len(data_shape)

        def normalize_axes(axes):
            """Map negative axes to non-negative ones; None if that needs an unknown rank."""
            if rank is None:
                return None if any(axis < 0 for axis in axes) else list(axes)
            return [axis + rank if axis < 0 else axis for axis in axes]

        for second_slice_node in first_users:
            second_inputs = list(second_slice_node.inputs)
            new_starts = first_starts + second_inputs[1].values.tolist()
            new_ends = first_ends + second_inputs[2].values.tolist()
            new_axes = first_axes + second_inputs[3].values.tolist()
            new_steps = first_steps + second_inputs[4].values.tolist()

            normalized_axes = normalize_axes(new_axes)
            if normalized_axes is None or len(normalized_axes) != len(set(normalized_axes)):
                continue

            inputs = [
                first_inputs[0],
                gs.Constant(
                    second_inputs[1].name + "_starts",
                    values=np.array(new_starts, dtype=np.int64),
                ),
                gs.Constant(
                    second_inputs[2].name + "_ends",
                    values=np.array(new_ends, dtype=np.int64),
                ),
                gs.Constant(
                    second_inputs[3].name + "_axes",
                    values=np.array(new_axes, dtype=np.int64),
                ),
                gs.Constant(
                    second_inputs[4].name + "_steps",
                    values=np.array(new_steps, dtype=np.int64),
                ),
            ]
            outputs = list(second_slice_node.outputs)

            second_slice_node.inputs.clear()
            second_slice_node.outputs.clear()

            if len(first_users) == 1:
                # The merged node takes slice_0's name, so disconnect slice_0 from its output.
                first_slice_node.outputs.clear()
                match_case[first_slice_node.name] = {
                    "op": "Slice",
                    "inputs": inputs,
                    "outputs": outputs,
                    "name": first_slice_node.name,
                    "attrs": first_slice_node.attrs,
                    "domain": None,
                }
            else:
                # Other consumers (including ones skipped above) may still read slice_0's
                # output, so it must keep its producer.
                match_case[second_slice_node.name] = {
                    "op": "Slice",
                    "inputs": inputs,
                    "outputs": outputs,
                    "name": second_slice_node.name,
                    "attrs": second_slice_node.attrs,
                    "domain": None,
                }

        return match_case


register_fusion_pattern(SlicePatternMatcher(1))
@@ -0,0 +1,92 @@
1
+ import numpy as np
2
+
3
+ import onnxslim.third_party.onnx_graphsurgeon as gs
4
+ from onnxslim.core.pattern import Pattern, PatternMatcher
5
+ from onnxslim.core.pattern.registry import register_fusion_pattern
6
+
7
+
8
class UnsqueezePatternMatcher(PatternMatcher):
    """Merges two back-to-back Unsqueeze nodes into one."""

    def __init__(self, priority):
        """Initializes the UnsqueezePatternMatcher with a specified priority using a predefined graph pattern."""
        pattern = Pattern(
            """
            input input 0 1 unsqueeze_0
            Unsqueeze unsqueeze_0 1+ 1 input unsqueeze_1
            Unsqueeze unsqueeze_1 1+ 1 unsqueeze_0 output
            output output 1 0 unsqueeze_1
            """
        )
        super().__init__(pattern, priority)

    @property
    def name(self):
        """Returns the name of the elimination pattern, 'EliminationUnsqueeze'."""
        return "EliminationUnsqueeze"

    @staticmethod
    def _normalized_axes(unsqueeze_node, opset):
        """Return ``unsqueeze_node``'s axes as non-negative ints in its output's coordinates.

        Axes come from the ``axes`` attribute when ``opset < 13`` and from the Constant second
        input otherwise. Negative axes are offset by the output rank (input rank + number of axes).
        """
        input_rank = len(unsqueeze_node.inputs[0].shape)
        if opset < 13:
            axes = list(unsqueeze_node.attrs["axes"])
        else:
            axes = unsqueeze_node.inputs[1].values.reshape(-1).tolist()
        output_rank = input_rank + len(axes)
        return [int(axis) + output_rank if axis < 0 else int(axis) for axis in axes]

    @staticmethod
    def _shift_past_inserted(axis, inserted_axes):
        """Map position ``axis`` of a tensor to its position after inserting ``inserted_axes``.

        ``inserted_axes`` are non-negative positions in the result. Walking them in ascending
        order, each one at or before the running position pushes it one slot to the right.
        """
        for inserted in sorted(inserted_axes):
            if inserted <= axis:
                axis += 1
        return axis

    def rewrite(self, opset=11):
        """Merge ``Unsqueeze(Unsqueeze(x, axes_0), axes_1)`` into ``Unsqueeze(x, axes')``.

        ``axes_0`` positions are re-expressed in the final output's coordinates and combined
        with ``axes_1``. Applies only when ``unsqueeze_0`` has a single consumer, both input
        ranks are known, and (for ``opset >= 13``) both axes inputs are Constants.

        Returns:
            dict: ``{unsqueeze_0.name: node description}`` for the merged Unsqueeze, or an empty
            dict when nothing is rewritten.
        """
        match_case = {}
        node_unsqueeze_0 = self.unsqueeze_0
        node_unsqueeze_1 = self.unsqueeze_1

        if len(node_unsqueeze_0.users) != 1:
            return match_case
        # Ranks must be known to normalise negative axes; rank 0 (an empty shape) is a known rank.
        if node_unsqueeze_0.inputs[0].shape is None or node_unsqueeze_1.inputs[0].shape is None:
            return match_case
        if opset >= 13 and not (
            isinstance(node_unsqueeze_0.inputs[1], gs.Constant) and isinstance(node_unsqueeze_1.inputs[1], gs.Constant)
        ):
            return match_case

        axes_0 = self._normalized_axes(node_unsqueeze_0, opset)
        axes_1 = self._normalized_axes(node_unsqueeze_1, opset)
        # axes_0 index unsqueeze_0's output; express them in unsqueeze_1's output instead.
        axes_0 = [self._shift_past_inserted(axis, axes_1) for axis in axes_0]
        merged_axes = axes_0 + axes_1

        inputs = [node_unsqueeze_0.inputs[0]]
        outputs = list(node_unsqueeze_1.outputs)
        node_unsqueeze_1.inputs.clear()
        node_unsqueeze_1.outputs.clear()

        if opset < 13:
            attrs = {"axes": merged_axes}
        else:
            attrs = None
            inputs.append(
                gs.Constant(
                    name=f"{node_unsqueeze_0.name}_axes",
                    values=np.array(merged_axes, dtype=np.int64),
                )
            )

        match_case[node_unsqueeze_0.name] = {
            "op": "Unsqueeze",
            "inputs": inputs,
            "outputs": outputs,
            "name": node_unsqueeze_0.name,
            "attrs": attrs,
            "domain": None,
        }

        return match_case


register_fusion_pattern(UnsqueezePatternMatcher(1))
@@ -0,0 +1,8 @@
1
+ from .concat_reshape import *
2
+ from .convadd import *
3
+ from .convbn import *
4
+ from .convmul import *
5
+ from .gelu import *
6
+ from .gemm import *
7
+ from .padconv import *
8
+ from .reduce import *
@@ -0,0 +1,50 @@
1
+ import numpy as np
2
+
3
+ import onnxslim.third_party.onnx_graphsurgeon as gs
4
+ from onnxslim.core.pattern import Pattern, PatternMatcher
5
+ from onnxslim.core.pattern.registry import register_fusion_pattern
6
+
7
+
8
class ConcatReshapeMatcher(PatternMatcher):
    """Replaces the single dynamic entry of a Concat-built Reshape target shape with -1."""

    def __init__(self, priority):
        """Initializes the matcher for a Concat that feeds the shape input of a Reshape."""
        pattern = Pattern(
            """
            input input 0 1 concat_0
            Concat concat_0 1+ 1 input reshape_0
            Reshape reshape_0 2 1 ? concat_0 output
            output output 1 0 reshape_0
            """
        )
        super().__init__(pattern, priority)

    @property
    def name(self):
        """Returns the name 'FusionConcatReshape'."""
        return "FusionConcatReshape"

    def parameter_check(self):
        """Return True when concat_0's single Variable input can be replaced by a constant -1.

        Requirements:
        - concat_0 feeds nothing but reshape_0, because the rewrite edits concat_0 in place.
        - concat_0 has exactly one Variable input, and it has shape ``[1]``.
        - Every Constant input holds one value other than -1, so the new -1 is the only one.
        - If reshape_0 sets ``allowzero``, no Constant entry is 0 (0 and -1 may not be combined
          then).
        """
        concat_node = self.concat_0
        if len(concat_node.users) != 1:
            return False

        variables = [tensor for tensor in concat_node.inputs if isinstance(tensor, gs.Variable)]
        constants = [tensor for tensor in concat_node.inputs if isinstance(tensor, gs.Constant)]
        if len(variables) != 1 or variables[0].shape != [1]:
            return False
        if not all(const.values.size == 1 and const.values.item() != -1 for const in constants):
            return False
        if self.reshape_0.attrs.get("allowzero", 0) and any(const.values.item() == 0 for const in constants):
            return False
        return True

    def rewrite(self, opset=11):
        """Swap concat_0's Variable input for a constant ``[-1]`` at the same position.

        The graph is edited in place, so the returned dict is always empty.
        """
        match_case = {}
        concat_node = self.concat_0
        index = next(idx for idx, tensor in enumerate(concat_node.inputs) if isinstance(tensor, gs.Variable))
        constant = gs.Constant(
            concat_node.inputs[index].name + "_fixed",
            values=np.array([-1], dtype=np.int64),
        )
        concat_node.inputs.pop(index)
        concat_node.inputs.insert(index, constant)

        return match_case


register_fusion_pattern(ConcatReshapeMatcher(1))
@@ -0,0 +1,70 @@
1
+ import onnxslim.third_party.onnx_graphsurgeon as gs
2
+ from onnxslim.core.pattern import Pattern, PatternMatcher
3
+ from onnxslim.core.pattern.registry import register_fusion_pattern
4
+
5
+
6
class ConvAddMatcher(PatternMatcher):
    """Folds a per-output-channel constant Add that follows a Conv into the Conv bias."""

    def __init__(self, priority):
        """Initializes the ConvAddMatcher for fusing Conv and Add layers in an ONNX graph."""
        # NOTE(review): the Conv row names "bn_0" as its consumer although the Add node is
        # "add_0"; kept unchanged — confirm whether Pattern reads the output column.
        pattern = Pattern(
            """
            input input 0 1 conv_0
            Conv conv_0 1+ 1 input bn_0
            Add add_0 2 1 conv_0 ? output
            output output 1 0 add_0
            """
        )
        super().__init__(pattern, priority)

    @property
    def name(self):
        """Returns the name of the FusionConvAdd pattern."""
        return "FusionConvAdd"

    @staticmethod
    def _is_channel_vector(addend_shape, weight_shape):
        """Return True if a tensor of ``addend_shape`` broadcasts over the Conv output on axis 1 only.

        The Conv output has the same rank as the weight, and its channel count is
        ``weight_shape[0]``. After right-aligned broadcasting, the addend must therefore look
        like ``[1, C, 1, ..., 1]``. A plain ``[C]`` is rejected, for example, because it would
        broadcast along the last spatial axis.
        """
        rank = len(weight_shape)
        if rank < 2 or len(addend_shape) > rank:
            return False
        padded = (1,) * (rank - len(addend_shape)) + tuple(addend_shape)
        return padded[1] == weight_shape[0] and all(dim == 1 for axis, dim in enumerate(padded) if axis != 1)

    def rewrite(self, opset=11):
        """Fold the Add's constant operand into the Conv bias (created if the Conv has none).

        Applies only when conv_0 feeds nothing but add_0, the Conv weight, the Conv bias (if
        present) and the Add operand are Constants, and the operand is a per-channel vector.

        Returns:
            dict: ``{conv_0.name: node description}`` for the fused Conv, or an empty dict.
        """
        match_case = {}
        conv_node = self.conv_0
        add_node = self.add_0
        conv_weight = list(conv_node.inputs)[1]
        addend = add_node.inputs[1]

        if len(conv_node.users) != 1:
            return match_case
        if not (isinstance(addend, gs.Constant) and isinstance(conv_weight, gs.Constant)):
            return match_case
        has_bias = len(conv_node.inputs) != 2
        if has_bias and not isinstance(conv_node.inputs[2], gs.Constant):
            return match_case
        if not self._is_channel_vector(addend.values.shape, conv_weight.shape):
            return match_case

        addend_values = addend.values.reshape(-1)
        conv_bias = conv_node.inputs[2].values + addend_values if has_bias else addend_values

        weight_name = conv_weight.name
        if weight_name.endswith("weight"):
            bias_name = f"{weight_name[:-6]}bias"
        else:
            bias_name = f"{weight_name}_bias"

        inputs = [
            next(iter(conv_node.inputs)),
            conv_weight,
            gs.Constant(bias_name, values=conv_bias),
        ]
        outputs = list(add_node.outputs)

        conv_node.outputs.clear()
        add_node.inputs.clear()
        add_node.outputs.clear()

        match_case[conv_node.name] = {
            "op": conv_node.op,
            "inputs": inputs,
            "outputs": outputs,
            "name": conv_node.name,
            "attrs": conv_node.attrs,
            "domain": None,
        }

        return match_case


register_fusion_pattern(ConvAddMatcher(1))
@@ -0,0 +1,86 @@
1
+ import numpy as np
2
+
3
+ import onnxslim.third_party.onnx_graphsurgeon as gs
4
+ from onnxslim.core.pattern import Pattern, PatternMatcher
5
+ from onnxslim.core.pattern.registry import register_fusion_pattern
6
+
7
+
8
class ConvBatchNormMatcher(PatternMatcher):
    """Folds a BatchNormalization that directly follows a Conv into the Conv weight and bias."""

    def __init__(self, priority):
        """Initializes the ConvBatchNormMatcher for fusing Conv and BatchNormalization layers in an ONNX graph."""
        pattern = Pattern(
            """
            input input 0 1 conv_0
            Conv conv_0 1+ 1 input bn_0
            BatchNormalization bn_0 5 1 conv_0 ? ? ? ? output
            output output 1 0 bn_0
            """
        )
        super().__init__(pattern, priority)

    @property
    def name(self):
        """Returns the name of the FusionConvBN pattern."""
        return "FusionConvBN"

    def rewrite(self, opset=11):
        """Rewrites the weights and biases of a BatchNormalization layer fused with a convolution layer.

        With ``s = scale / sqrt(var + epsilon)``, the new weight is ``W * s`` (per output
        channel) and the new bias is ``(b - mean) * s + bias``. Here ``b`` is the Conv bias, or
        zeros when the Conv has none. Applies only when conv_0 feeds nothing but bn_0 and the
        Conv weight, the Conv bias (if any) and all four BatchNormalization parameters are
        Constants; otherwise reading ``.values`` would fail.

        Returns:
            dict: ``{conv_0.name: node description}`` for the fused Conv, or an empty dict.
        """
        match_case = {}
        conv_node = self.conv_0
        bn_node = self.bn_0

        if len(conv_node.users) != 1 or not isinstance(conv_node.inputs[1], gs.Constant):
            return match_case
        bn_params = list(bn_node.inputs)[1:5]
        if not all(isinstance(param, gs.Constant) for param in bn_params):
            return match_case
        has_bias = len(conv_node.inputs) != 2
        if has_bias and not isinstance(conv_node.inputs[2], gs.Constant):
            return match_case

        conv_weight = conv_node.inputs[1].values
        bn_scale, bn_bias, bn_running_mean, bn_running_var = (param.values for param in bn_params)
        bn_eps = bn_node.attrs.get("epsilon", 1.0e-5)
        conv_bias = conv_node.inputs[2].values if has_bias else np.zeros_like(bn_running_mean)

        bn_var_rsqrt = bn_scale / np.sqrt(bn_running_var + bn_eps)
        # Broadcast the per-channel factor along the weight's output-channel axis: axis 0 for a
        # Conv weight; axis 1 for any other producer (presumably ConvTranspose — confirm).
        shape = [1] * len(conv_weight.shape)
        if conv_node.op == "Conv":
            shape[0] = -1
        else:
            shape[1] = -1
        conv_w = conv_weight * bn_var_rsqrt.reshape(shape)
        conv_b = (conv_bias - bn_running_mean) * bn_var_rsqrt + bn_bias

        weight_name = list(conv_node.inputs)[1].name
        if weight_name.endswith("weight"):
            bias_name = f"{weight_name[:-6]}bias"
        else:
            bias_name = f"{weight_name}_bias"

        inputs = [
            next(iter(conv_node.inputs)),
            gs.Constant(weight_name + "_weight", values=conv_w),
            gs.Constant(bias_name, values=conv_b),
        ]
        outputs = list(bn_node.outputs)

        conv_node.outputs.clear()
        bn_node.inputs.clear()
        bn_node.outputs.clear()

        match_case[conv_node.name] = {
            "op": conv_node.op,
            "inputs": inputs,
            "outputs": outputs,
            "name": conv_node.name,
            "attrs": conv_node.attrs,
            "domain": None,
        }

        return match_case


register_fusion_pattern(ConvBatchNormMatcher(1))