onnxslim 0.1.80__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- onnxslim/__init__.py +16 -0
- onnxslim/__main__.py +4 -0
- onnxslim/argparser.py +215 -0
- onnxslim/cli/__init__.py +1 -0
- onnxslim/cli/_main.py +180 -0
- onnxslim/core/__init__.py +219 -0
- onnxslim/core/optimization/__init__.py +146 -0
- onnxslim/core/optimization/dead_node_elimination.py +151 -0
- onnxslim/core/optimization/subexpression_elimination.py +76 -0
- onnxslim/core/optimization/weight_tying.py +59 -0
- onnxslim/core/pattern/__init__.py +249 -0
- onnxslim/core/pattern/elimination/__init__.py +5 -0
- onnxslim/core/pattern/elimination/concat.py +61 -0
- onnxslim/core/pattern/elimination/reshape.py +77 -0
- onnxslim/core/pattern/elimination/reshape_as.py +64 -0
- onnxslim/core/pattern/elimination/slice.py +108 -0
- onnxslim/core/pattern/elimination/unsqueeze.py +92 -0
- onnxslim/core/pattern/fusion/__init__.py +8 -0
- onnxslim/core/pattern/fusion/concat_reshape.py +50 -0
- onnxslim/core/pattern/fusion/convadd.py +70 -0
- onnxslim/core/pattern/fusion/convbn.py +86 -0
- onnxslim/core/pattern/fusion/convmul.py +69 -0
- onnxslim/core/pattern/fusion/gelu.py +47 -0
- onnxslim/core/pattern/fusion/gemm.py +330 -0
- onnxslim/core/pattern/fusion/padconv.py +89 -0
- onnxslim/core/pattern/fusion/reduce.py +67 -0
- onnxslim/core/pattern/registry.py +28 -0
- onnxslim/misc/__init__.py +0 -0
- onnxslim/misc/tabulate.py +2681 -0
- onnxslim/third_party/__init__.py +0 -0
- onnxslim/third_party/_sympy/__init__.py +0 -0
- onnxslim/third_party/_sympy/functions.py +205 -0
- onnxslim/third_party/_sympy/numbers.py +397 -0
- onnxslim/third_party/_sympy/printers.py +491 -0
- onnxslim/third_party/_sympy/solve.py +172 -0
- onnxslim/third_party/_sympy/symbol.py +102 -0
- onnxslim/third_party/onnx_graphsurgeon/__init__.py +15 -0
- onnxslim/third_party/onnx_graphsurgeon/exporters/__init__.py +1 -0
- onnxslim/third_party/onnx_graphsurgeon/exporters/base_exporter.py +33 -0
- onnxslim/third_party/onnx_graphsurgeon/exporters/onnx_exporter.py +432 -0
- onnxslim/third_party/onnx_graphsurgeon/graph_pattern/__init__.py +4 -0
- onnxslim/third_party/onnx_graphsurgeon/graph_pattern/graph_pattern.py +466 -0
- onnxslim/third_party/onnx_graphsurgeon/importers/__init__.py +1 -0
- onnxslim/third_party/onnx_graphsurgeon/importers/base_importer.py +33 -0
- onnxslim/third_party/onnx_graphsurgeon/importers/onnx_importer.py +558 -0
- onnxslim/third_party/onnx_graphsurgeon/ir/__init__.py +0 -0
- onnxslim/third_party/onnx_graphsurgeon/ir/function.py +274 -0
- onnxslim/third_party/onnx_graphsurgeon/ir/graph.py +1575 -0
- onnxslim/third_party/onnx_graphsurgeon/ir/node.py +266 -0
- onnxslim/third_party/onnx_graphsurgeon/ir/tensor.py +504 -0
- onnxslim/third_party/onnx_graphsurgeon/logger/__init__.py +1 -0
- onnxslim/third_party/onnx_graphsurgeon/logger/logger.py +261 -0
- onnxslim/third_party/onnx_graphsurgeon/util/__init__.py +0 -0
- onnxslim/third_party/onnx_graphsurgeon/util/exception.py +20 -0
- onnxslim/third_party/onnx_graphsurgeon/util/misc.py +252 -0
- onnxslim/third_party/symbolic_shape_infer.py +3273 -0
- onnxslim/utils.py +794 -0
- onnxslim/version.py +1 -0
- onnxslim-0.1.80.dist-info/METADATA +207 -0
- onnxslim-0.1.80.dist-info/RECORD +65 -0
- onnxslim-0.1.80.dist-info/WHEEL +5 -0
- onnxslim-0.1.80.dist-info/entry_points.txt +2 -0
- onnxslim-0.1.80.dist-info/licenses/LICENSE +21 -0
- onnxslim-0.1.80.dist-info/top_level.txt +1 -0
- onnxslim-0.1.80.dist-info/zip-safe +1 -0
onnxslim/core/pattern/elimination/concat.py
@@ -0,0 +1,61 @@
from onnxslim.core.pattern import Pattern, PatternMatcher
from onnxslim.core.pattern.registry import register_fusion_pattern


class ConcatPatternMatcher(PatternMatcher):
    def __init__(self, priority):
        """Initializes the ConcatPatternMatcher with a specified priority using a predefined graph pattern."""
        pattern = Pattern(
            """
            input input 0 1 concat_0
            Concat concat_0 1+ 1 input concat_1
            Concat concat_1 1* 1 concat_0 output
            output output 1 0 concat_1
            """
        )
        super().__init__(pattern, priority)

    @property
    def name(self):
        """Returns the name of the elimination pattern, 'EliminationConcat'."""
        return "EliminationConcat"

    def rewrite(self, opset=11):
        """Rewrites an elimination pattern for concat nodes by merging nested concat operations."""
        match_case = {}

        node_concat_0 = self.concat_0
        users_node_concat_0 = node_concat_0.users
        node_concat_1 = self.concat_1
        node_concat_0_axis = node_concat_0.attrs.get("axis", 0)
        node_concat_1.attrs.get("axis", 0)

        if all(user.op == "Concat" and user.attrs.get("axis", 0) == node_concat_0_axis for user in users_node_concat_0):
            index = node_concat_1.inputs.index(node_concat_0.outputs[0])
            node_concat_1.inputs.pop(index)
            for i, item in enumerate(node_concat_0.inputs):
                node_concat_1.inputs.insert(index + i, item)
            inputs = list(node_concat_1.inputs)
            outputs = list(node_concat_1.outputs)
            node_concat_1.inputs.clear()
            node_concat_1.outputs.clear()

            if len(users_node_concat_0) == 0:
                node_concat_0.inputs.clear()
                node_concat_0.outputs.clear()

            attrs = {"axis": node_concat_0_axis}

            match_case[node_concat_1.name] = {
                "op": "Concat",
                "inputs": inputs,
                "outputs": outputs,
                "name": node_concat_1.name,
                "attrs": attrs,
                "domain": None,
            }

        return match_case


register_fusion_pattern(ConcatPatternMatcher(1))
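
A minimal sketch (not part of the package) of the graph shape this matcher targets: a Concat whose consumers are all Concat nodes on the same axis gets folded into its consumer. It builds a nested-Concat model with the vendored onnx_graphsurgeon and assumes the public onnxslim.slim API accepts and returns an in-memory ONNX model.

import numpy as np

import onnxslim
import onnxslim.third_party.onnx_graphsurgeon as gs

# Concat(Concat(a, b), c) on axis 1, all on the same axis.
a = gs.Variable("a", dtype=np.float32, shape=(1, 2))
b = gs.Variable("b", dtype=np.float32, shape=(1, 3))
c = gs.Variable("c", dtype=np.float32, shape=(1, 4))
ab = gs.Variable("ab", dtype=np.float32, shape=(1, 5))
out = gs.Variable("out", dtype=np.float32, shape=(1, 9))

nodes = [
    gs.Node(op="Concat", attrs={"axis": 1}, inputs=[a, b], outputs=[ab]),
    gs.Node(op="Concat", attrs={"axis": 1}, inputs=[ab, c], outputs=[out]),
]
graph = gs.Graph(nodes=nodes, inputs=[a, b, c], outputs=[out], opset=13)
model = gs.export_onnx(graph)

slimmed = onnxslim.slim(model)  # assumed to return the optimized onnx.ModelProto
print([node.op_type for node in slimmed.graph.node])  # expected to leave a single Concat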
onnxslim/core/pattern/elimination/reshape.py
@@ -0,0 +1,77 @@
import numpy as np

import onnxslim.third_party.onnx_graphsurgeon as gs
from onnxslim.core.pattern import Pattern, PatternMatcher
from onnxslim.core.pattern.registry import register_fusion_pattern


class ReshapePatternMatcher(PatternMatcher):
    def __init__(self, priority):
        """Initializes the ReshapePatternMatcher with a priority and a specific pattern for detecting nested reshape
        operations.
        """
        pattern = Pattern(
            """
            input input 0 1 reshape_0
            Reshape reshape_0 2 1 input ? reshape_1
            Reshape reshape_1 2 1 reshape_0 ? output
            output output 1 0 reshape_1
            """
        )
        super().__init__(pattern, priority)

    @property
    def name(self):
        """Returns the name 'EliminationReshape'."""
        return "EliminationReshape"

    def rewrite(self, opset=11):
        """Rewrite the computational graph by eliminating redundant reshape operations when certain conditions are
        met.
        """
        match_case = {}
        node = self.reshape_1
        first_reshape_node = node.i(0)
        first_reshape_node_inputs = list(first_reshape_node.inputs)
        first_reshape_node_users = first_reshape_node.users
        if len(first_reshape_node_users) == 1:
            second_reshape_node = node

            def check_constant_mergeable(reshape_node):
                """Check if a reshape node's shape input, containing zero dimensions, can be merged with its input
                node's shape.
                """
                if isinstance(reshape_node.inputs[1], gs.Constant):
                    input_shape = reshape_node.inputs[0].shape
                    reshape_shape = reshape_node.inputs[1].values.tolist()
                    if input_shape is not None and np.any(np.array(reshape_shape) == 0):
                        shape = [
                            input_shape[i] if dim_size == 0 else reshape_shape[i]
                            for i, dim_size in enumerate(reshape_shape)
                        ]
                        if not all(isinstance(item, int) for item in shape):
                            return False
                return True

            if check_constant_mergeable(first_reshape_node) and check_constant_mergeable(second_reshape_node):
                inputs = []
                inputs.append(first_reshape_node_inputs[0])
                inputs.append(second_reshape_node.inputs[1])
                outputs = list(second_reshape_node.outputs)
                first_reshape_node.outputs.clear()
                second_reshape_node.inputs.clear()
                second_reshape_node.outputs.clear()

                match_case[first_reshape_node.name] = {
                    "op": "Reshape",
                    "inputs": inputs,
                    "outputs": outputs,
                    "name": first_reshape_node.name,
                    "attrs": first_reshape_node.attrs,
                    "domain": None,
                }

        return match_case


register_fusion_pattern(ReshapePatternMatcher(1))
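
For reference, the zero entries that check_constant_mergeable inspects follow ONNX Reshape semantics (with the default allowzero=0): a 0 in the shape tensor means "keep the corresponding input dimension", so merging is only safe when that dimension can be resolved from a static input shape. A small numpy sketch, not part of the package, of how such a shape resolves:

import numpy as np

input_shape = (2, 3, 4)   # static shape of the Reshape input
reshape_shape = [0, -1]   # shape constant: 0 copies dim 0, -1 is inferred

resolved = [input_shape[i] if d == 0 else d for i, d in enumerate(reshape_shape)]
print(resolved)  # [2, -1]

x = np.zeros(input_shape, dtype=np.float32)
print(x.reshape(resolved).shape)  # (2, 12), what the merged Reshape produces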
onnxslim/core/pattern/elimination/reshape_as.py
@@ -0,0 +1,64 @@
import onnxslim.third_party.onnx_graphsurgeon as gs
from onnxslim.core.pattern import Pattern, PatternMatcher
from onnxslim.core.pattern.registry import register_fusion_pattern


class ReshapeAsPatternMatcher(PatternMatcher):
    def __init__(self, priority):
        """Initializes the ReshapeAsPatternMatcher with a priority and a specific pattern for reshape-as operations."""
        pattern = Pattern(
            """
            input input 0 1 shape
            Shape shape 1+ 1 input gather
            Gather gather 1+ 1 shape unsqueeze
            Unsqueeze unsqueeze 1+ 1 gather output
            Concat concat 1+ 1 unsqueeze output
            output output 1 0 concat
            """
        )
        super().__init__(pattern, priority)

    @property
    def name(self):
        """Returns the name 'EliminationReshapeAs'."""
        return "EliminationReshapeAs"

    def parameter_check(self) -> bool:
        shape_node = self.shape
        if shape_node.outputs[0].shape is None:
            return False

        if len(shape_node.users) != shape_node.outputs[0].shape[0]:
            return False

        if not all([user.op == "Gather" for user in shape_node.users]):
            return False

        for idx, user in enumerate(shape_node.users):
            if not isinstance(user.inputs[1], gs.Constant):
                return False

            if user.inputs[1].values.shape != ():
                return False

            if user.inputs[1].values != idx:
                return False

        concat_node = self.concat
        if len(concat_node.inputs) != len(shape_node.users):
            return False

        return True

    def rewrite(self, opset=11):
        """Rewrites the pattern by replacing the Concat node with the Shape node."""
        match_case = {}
        shape_node = self.shape
        concat_node = self.concat

        concat_node.replace_all_uses_with(shape_node)

        return match_case


register_fusion_pattern(ReshapeAsPatternMatcher(1))
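
The subgraph this matcher removes rebuilds a tensor's shape element by element: Shape, one Gather per index, Unsqueeze, then Concat. When every index appears exactly once and in order, the Concat output is identical to the Shape output, so the Concat's uses can be rewired to the Shape node. A numpy sketch, not part of the package, of that identity:

import numpy as np

x = np.zeros((2, 3, 5), dtype=np.float32)

shape = np.array(x.shape, dtype=np.int64)           # Shape
pieces = [shape[i] for i in range(shape.shape[0])]  # Gather with indices 0..rank-1
rebuilt = np.concatenate([np.expand_dims(p, 0) for p in pieces])  # Unsqueeze + Concat

assert np.array_equal(rebuilt, shape)  # Concat == Shape, so the subgraph can be dropped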
onnxslim/core/pattern/elimination/slice.py
@@ -0,0 +1,108 @@
import numpy as np

import onnxslim.third_party.onnx_graphsurgeon as gs
from onnxslim.core.pattern import Pattern, PatternMatcher
from onnxslim.core.pattern.registry import register_fusion_pattern


class SlicePatternMatcher(PatternMatcher):
    def __init__(self, priority):
        """Initializes the SlicePatternMatcher with a specified priority using a predefined graph pattern."""
        pattern = Pattern(
            """
            input input 0 1 slice_0
            Slice slice_0 5 1 input ? ? ? ? slice_1
            Slice slice_1 5 1 slice_0 ? ? ? ? output
            output output 1 0 slice_1
            """
        )  # to check here slice_0
        super().__init__(pattern, priority)

    @property
    def name(self):
        """Returns the name of the elimination pattern, 'EliminationSlice'."""
        return "EliminationSlice"

    def rewrite(self, opset=11):
        """Rewrites an elimination pattern for slice nodes by optimizing nested slice operations."""
        match_case = {}
        first_slice_node = self.slice_0
        first_slice_node_inputs = list(first_slice_node.inputs)
        if all(isinstance(input, gs.Constant) for input in first_slice_node_inputs[1:]):
            first_slice_node_users = first_slice_node.users
            if all(
                user.op == "Slice" and all(isinstance(input, gs.Constant) for input in list(user.inputs)[1:])
                for user in first_slice_node_users
            ):
                first_slice_node_starts = first_slice_node_inputs[1].values.tolist()
                first_slice_node_ends = first_slice_node_inputs[2].values.tolist()
                first_slice_node_axes = first_slice_node_inputs[3].values.tolist()
                first_slice_node_steps = first_slice_node_inputs[4].values.tolist()

                for user_node in first_slice_node_users:
                    second_slice_node = user_node
                    second_slice_node_inputs = list(second_slice_node.inputs)
                    second_slice_node_starts = second_slice_node_inputs[1].values.tolist()
                    second_slice_node_ends = second_slice_node_inputs[2].values.tolist()
                    second_slice_node_axes = second_slice_node_inputs[3].values.tolist()
                    second_slice_node_steps = second_slice_node_inputs[4].values.tolist()

                    new_starts = first_slice_node_starts + second_slice_node_starts
                    new_ends = first_slice_node_ends + second_slice_node_ends
                    new_axes = first_slice_node_axes + second_slice_node_axes
                    new_steps = first_slice_node_steps + second_slice_node_steps

                    if len(new_axes) != len(set(new_axes)):
                        continue

                    inputs = []
                    inputs.extend(
                        (
                            next(iter(first_slice_node.inputs)),
                            gs.Constant(
                                second_slice_node_inputs[1].name + "_starts",
                                values=np.array(new_starts, dtype=np.int64),
                            ),
                            gs.Constant(
                                second_slice_node_inputs[2].name + "_ends",
                                values=np.array(new_ends, dtype=np.int64),
                            ),
                            gs.Constant(
                                second_slice_node_inputs[3].name + "_axes",
                                values=np.array(new_axes, dtype=np.int64),
                            ),
                            gs.Constant(
                                second_slice_node_inputs[4].name + "_steps",
                                values=np.array(new_steps, dtype=np.int64),
                            ),
                        )
                    )
                    outputs = list(second_slice_node.outputs)

                    first_slice_node.outputs.clear()
                    second_slice_node.inputs.clear()
                    second_slice_node.outputs.clear()

                    if len(first_slice_node_users) == 1:
                        match_case[first_slice_node.name] = {
                            "op": "Slice",
                            "inputs": inputs,
                            "outputs": outputs,
                            "name": first_slice_node.name,
                            "attrs": first_slice_node.attrs,
                            "domain": None,
                        }
                    else:
                        match_case[second_slice_node.name] = {
                            "op": "Slice",
                            "inputs": inputs,
                            "outputs": outputs,
                            "name": second_slice_node.name,
                            "attrs": second_slice_node.attrs,
                            "domain": None,
                        }

        return match_case


register_fusion_pattern(SlicePatternMatcher(1))
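
The rewrite skips the merge when the two Slice ops share an axis (the len(new_axes) != len(set(new_axes)) check); when the axes are disjoint, concatenating their starts, ends, axes, and steps yields one equivalent Slice. A numpy sketch, not part of the package:

import numpy as np

x = np.arange(4 * 6).reshape(4, 6)

chained = x[1:3, :][:, 2:5]  # Slice on axis 0, then Slice on axis 1

# One Slice with starts/ends/axes/steps concatenated, as the rewrite builds them.
starts, ends, axes, steps = [1, 2], [3, 5], [0, 1], [1, 1]
index = [slice(None)] * x.ndim
for s, e, a, st in zip(starts, ends, axes, steps):
    index[a] = slice(s, e, st)
merged = x[tuple(index)]

assert np.array_equal(chained, merged)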
onnxslim/core/pattern/elimination/unsqueeze.py
@@ -0,0 +1,92 @@
import numpy as np

import onnxslim.third_party.onnx_graphsurgeon as gs
from onnxslim.core.pattern import Pattern, PatternMatcher
from onnxslim.core.pattern.registry import register_fusion_pattern


class UnsqueezePatternMatcher(PatternMatcher):
    def __init__(self, priority):
        """Initializes the UnsqueezePatternMatcher with a specified priority using a predefined graph pattern."""
        pattern = Pattern(
            """
            input input 0 1 unsqueeze_0
            Unsqueeze unsqueeze_0 1+ 1 input unsqueeze_1
            Unsqueeze unsqueeze_1 1+ 1 unsqueeze_0 output
            output output 1 0 unsqueeze_1
            """
        )
        super().__init__(pattern, priority)

    @property
    def name(self):
        """Returns the name of the elimination pattern, 'EliminationUnsqueeze'."""
        return "EliminationUnsqueeze"

    def rewrite(self, opset=11):
        """Rewrites an elimination pattern for unsqueeze nodes by merging nested unsqueeze operations."""
        match_case = {}
        node_unsqueeze_0 = self.unsqueeze_0
        users_node_unsqueeze_0 = node_unsqueeze_0.users
        node_unsqueeze_1 = self.unsqueeze_1
        if len(users_node_unsqueeze_0) == 1 and node_unsqueeze_0.inputs[0].shape and node_unsqueeze_1.inputs[0].shape:
            if opset < 13 or (
                isinstance(node_unsqueeze_0.inputs[1], gs.Constant)
                and isinstance(node_unsqueeze_1.inputs[1], gs.Constant)
            ):

                def get_unsqueeze_axes(unsqueeze_node, opset):
                    dim = len(unsqueeze_node.inputs[0].shape)
                    if opset < 13:
                        axes = unsqueeze_node.attrs["axes"]
                    else:
                        axes = unsqueeze_node.inputs[1].values
                    return [axis + dim + len(axes) if axis < 0 else axis for axis in axes]

                axes_node_unsqueeze_0 = get_unsqueeze_axes(node_unsqueeze_0, opset)
                axes_node_unsqueeze_1 = get_unsqueeze_axes(node_unsqueeze_1, opset)

                axes_node_unsqueeze_0 = [
                    axis + sum(1 for axis_ in axes_node_unsqueeze_1 if axis_ <= axis) for axis in axes_node_unsqueeze_0
                ]

                inputs = [next(iter(node_unsqueeze_0.inputs))]
                outputs = list(node_unsqueeze_1.outputs)

                index = node_unsqueeze_1.inputs.index(node_unsqueeze_0.outputs[0])
                node_unsqueeze_1.inputs.pop(index)
                for i, item in enumerate(node_unsqueeze_0.inputs):
                    node_unsqueeze_1.inputs.insert(index + i, item)
                inputs = [next(iter(node_unsqueeze_1.inputs))]
                outputs = list(node_unsqueeze_1.outputs)
                node_unsqueeze_1.inputs.clear()
                node_unsqueeze_1.outputs.clear()

                if len(users_node_unsqueeze_0) == 0:
                    outputs.inputs.clear()
                    outputs.outputs.clear()

                if opset < 13:
                    attrs = {"axes": axes_node_unsqueeze_0 + axes_node_unsqueeze_1}
                else:
                    attrs = None
                    inputs.append(
                        gs.Constant(
                            name=f"{node_unsqueeze_0.name}_axes",
                            values=np.array(axes_node_unsqueeze_0 + axes_node_unsqueeze_1, dtype=np.int64),
                        )
                    )

                match_case[node_unsqueeze_0.name] = {
                    "op": "Unsqueeze",
                    "inputs": inputs,
                    "outputs": outputs,
                    "name": node_unsqueeze_0.name,
                    "attrs": attrs,
                    "domain": None,
                }

        return match_case


register_fusion_pattern(UnsqueezePatternMatcher(1))
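
Two stacked Unsqueeze ops insert their axes one after another, so they can be replaced by a single Unsqueeze once the first node's axes are re-expressed against the final tensor (the sum(1 for axis_ in ...) adjustment above). A numpy sketch of the end result, not part of the package:

import numpy as np

x = np.zeros((3, 4), dtype=np.float32)

# Unsqueeze(axes=[0]) followed by Unsqueeze(axes=[3]) on its output.
step1 = np.expand_dims(x, 0)       # (1, 3, 4)
nested = np.expand_dims(step1, 3)  # (1, 3, 4, 1)

# A single Unsqueeze with both axes expressed against the final tensor.
merged = np.expand_dims(x, (0, 3))  # numpy >= 1.18 accepts a tuple of axes
assert nested.shape == merged.shape == (1, 3, 4, 1)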
onnxslim/core/pattern/fusion/concat_reshape.py
@@ -0,0 +1,50 @@
import numpy as np

import onnxslim.third_party.onnx_graphsurgeon as gs
from onnxslim.core.pattern import Pattern, PatternMatcher
from onnxslim.core.pattern.registry import register_fusion_pattern


class ConcatReshapeMatcher(PatternMatcher):
    def __init__(self, priority):
        pattern = Pattern(
            """
            input input 0 1 concat_0
            Concat concat_0 1+ 1 input reshape_0
            Reshape reshape_0 2 1 ? concat_0 output
            output output 1 0 reshape_0
            """
        )
        super().__init__(pattern, priority)

    @property
    def name(self):
        return "FusionConcatReshape"

    def parameter_check(self):
        concat_node = self.concat_0

        def check_inputs(inputs):
            vars = [i for i in inputs if isinstance(i, gs.Variable)]
            consts = [i for i in inputs if isinstance(i, gs.Constant)]
            return (
                len(vars) == 1 and all(c.values.size == 1 and c.values != -1 for c in consts) and vars[0].shape == [1]
            )

        return check_inputs(concat_node.inputs)

    def rewrite(self, opset=11):
        match_case = {}
        concat_node = self.concat_0
        index = next(idx for idx, i in enumerate(concat_node.inputs) if isinstance(i, gs.Variable))
        constant = gs.Constant(
            concat_node.inputs[index].name + "_fixed",
            values=np.array([-1], dtype=np.int64),
        )
        concat_node.inputs.pop(index)
        concat_node.inputs.insert(index, constant)

        return match_case


register_fusion_pattern(ConcatReshapeMatcher(1))
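
The idea behind FusionConcatReshape: when a Reshape's target shape is assembled by a Concat of scalar constants plus exactly one dynamic length-1 value, that dynamic entry can be pinned to -1 (Reshape infers at most one dimension), making the whole shape constant. A numpy sketch, not part of the package:

import numpy as np

x = np.arange(24, dtype=np.float32)

dynamic_dim = x.size // (2 * 3)      # the value the graph would compute at run time
via_dynamic = x.reshape(2, 3, dynamic_dim)
via_minus_one = x.reshape(2, 3, -1)  # what the rewrite substitutes

assert via_dynamic.shape == via_minus_one.shape == (2, 3, 4)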
onnxslim/core/pattern/fusion/convadd.py
@@ -0,0 +1,70 @@
import onnxslim.third_party.onnx_graphsurgeon as gs
from onnxslim.core.pattern import Pattern, PatternMatcher
from onnxslim.core.pattern.registry import register_fusion_pattern


class ConvAddMatcher(PatternMatcher):
    def __init__(self, priority):
        """Initializes the ConvAddMatcher for fusing Conv and Add layers in an ONNX graph."""
        pattern = Pattern(
            """
            input input 0 1 conv_0
            Conv conv_0 1+ 1 input bn_0
            Add add_0 2 1 conv_0 ? output
            output output 1 0 add_0
            """
        )
        super().__init__(pattern, priority)

    @property
    def name(self):
        """Returns the name of the FusionConvAdd pattern."""
        return "FusionConvAdd"

    def rewrite(self, opset=11):
        match_case = {}
        conv_node = self.conv_0
        conv_weight = list(conv_node.inputs)[1]
        conv_node_users = conv_node.users
        node = self.add_0
        if (
            len(conv_node_users) == 1
            and isinstance(node.inputs[1], gs.Constant)
            and isinstance(conv_weight, gs.Constant)
            and node.inputs[1].values.squeeze().ndim == 1
            and node.inputs[1].values.squeeze().shape[0] == conv_weight.shape[0]
        ):
            add_node = node
            if len(conv_node.inputs) == 2:
                conv_bias = node.inputs[1].values.squeeze()
            else:
                conv_bias = conv_node.inputs[2].values + node.inputs[1].values.squeeze()

            inputs = []
            inputs.append(next(iter(conv_node.inputs)))
            inputs.append(conv_weight)
            weight_name = list(conv_node.inputs)[1].name
            if weight_name.endswith("weight"):
                bias_name = f"{weight_name[:-6]}bias"
            else:
                bias_name = f"{weight_name}_bias"
            inputs.append(gs.Constant(bias_name, values=conv_bias))
            outputs = list(add_node.outputs)

            conv_node.outputs.clear()
            add_node.inputs.clear()
            add_node.outputs.clear()

            match_case[conv_node.name] = {
                "op": conv_node.op,
                "inputs": inputs,
                "outputs": outputs,
                "name": conv_node.name,
                "attrs": conv_node.attrs,
                "domain": None,
            }

        return match_case


register_fusion_pattern(ConvAddMatcher(1))
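
Folding the per-channel Add into the Conv bias relies on conv(x, W, b) + c == conv(x, W, b + c) for a constant c broadcast over output channels, which is exactly the conv_bias computation above. A numpy sketch, not part of the package, using a 1x1 convolution written as a matmul:

import numpy as np

rng = np.random.default_rng(0)
x = rng.standard_normal((8, 3))  # 8 spatial positions, 3 input channels
W = rng.standard_normal((5, 3))  # 5 output channels (1x1 conv == matmul)
b = rng.standard_normal(5)       # existing conv bias
c = rng.standard_normal(5)       # per-channel Add constant

conv_then_add = x @ W.T + b + c
fused = x @ W.T + (b + c)        # Conv with bias b + c, Add removed

assert np.allclose(conv_then_add, fused)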
onnxslim/core/pattern/fusion/convbn.py
@@ -0,0 +1,86 @@
import numpy as np

import onnxslim.third_party.onnx_graphsurgeon as gs
from onnxslim.core.pattern import Pattern, PatternMatcher
from onnxslim.core.pattern.registry import register_fusion_pattern


class ConvBatchNormMatcher(PatternMatcher):
    def __init__(self, priority):
        """Initializes the ConvBatchNormMatcher for fusing Conv and BatchNormalization layers in an ONNX graph."""
        pattern = Pattern(
            """
            input input 0 1 conv_0
            Conv conv_0 1+ 1 input bn_0
            BatchNormalization bn_0 5 1 conv_0 ? ? ? ? output
            output output 1 0 bn_0
            """
        )
        super().__init__(pattern, priority)

    @property
    def name(self):
        """Returns the name of the FusionConvBN pattern."""
        return "FusionConvBN"

    def rewrite(self, opset=11):
        """Rewrites the weights and biases of a BatchNormalization layer fused with a convolution layer."""
        match_case = {}
        conv_transpose_node = self.conv_0
        conv_transpose_node_users = conv_transpose_node.users
        node = self.bn_0
        if len(conv_transpose_node_users) == 1 and isinstance(conv_transpose_node.inputs[1], gs.Constant):
            conv_transpose_weight = conv_transpose_node.inputs[1].values
            bn_node = node
            bn_scale = bn_node.inputs[1].values
            bn_bias = bn_node.inputs[2].values
            bn_running_mean = bn_node.inputs[3].values
            bn_running_var = bn_node.inputs[4].values
            bn_eps = bn_node.attrs.get("epsilon", 1.0e-5)

            if len(conv_transpose_node.inputs) == 2:
                conv_transpose_bias = np.zeros_like(bn_running_mean)
            else:
                conv_transpose_bias = conv_transpose_node.inputs[2].values

            bn_var_rsqrt = bn_scale / np.sqrt(bn_running_var + bn_eps)
            shape = [1] * len(conv_transpose_weight.shape)
            if bn_node.i(0).op == "Conv":
                shape[0] = -1
            else:
                shape[1] = -1
            conv_w = conv_transpose_weight * bn_var_rsqrt.reshape(shape)
            conv_b = (conv_transpose_bias - bn_running_mean) * bn_var_rsqrt + bn_bias

            inputs = []
            inputs.append(next(iter(conv_transpose_node.inputs)))
            weight_name = list(conv_transpose_node.inputs)[1].name
            if weight_name.endswith("weight"):
                bias_name = f"{weight_name[:-6]}bias"
            else:
                bias_name = f"{weight_name}_bias"
            inputs.extend(
                (
                    gs.Constant(weight_name + "_weight", values=conv_w),
                    gs.Constant(bias_name, values=conv_b),
                )
            )
            outputs = list(bn_node.outputs)

            conv_transpose_node.outputs.clear()
            bn_node.inputs.clear()
            bn_node.outputs.clear()

            match_case[conv_transpose_node.name] = {
                "op": conv_transpose_node.op,
                "inputs": inputs,
                "outputs": outputs,
                "name": conv_transpose_node.name,
                "attrs": conv_transpose_node.attrs,
                "domain": None,
            }

        return match_case


register_fusion_pattern(ConvBatchNormMatcher(1))
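
The FusionConvBN rewrite uses the standard folding: with s = gamma / sqrt(var + eps), the fused weights are W' = W * s (broadcast over output channels, the shape[0] = -1 reshape above) and the fused bias is b' = (b - mean) * s + beta. A numpy check of those formulas, not part of the package, again using a 1x1 conv as a matmul:

import numpy as np

rng = np.random.default_rng(0)
x = rng.standard_normal((8, 3))   # 8 positions, 3 input channels
W = rng.standard_normal((5, 3))   # 5 output channels
b = rng.standard_normal(5)
gamma, beta = rng.standard_normal(5), rng.standard_normal(5)
mean, var, eps = rng.standard_normal(5), rng.random(5) + 0.5, 1e-5

y = x @ W.T + b
bn_out = gamma * (y - mean) / np.sqrt(var + eps) + beta

s = gamma / np.sqrt(var + eps)    # bn_var_rsqrt in the code above
W_fused = W * s[:, None]          # per-output-channel scaling of the weights
b_fused = (b - mean) * s + beta
fused_out = x @ W_fused.T + b_fused

assert np.allclose(bn_out, fused_out)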