onnxslim 0.1.80__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- onnxslim/__init__.py +16 -0
- onnxslim/__main__.py +4 -0
- onnxslim/argparser.py +215 -0
- onnxslim/cli/__init__.py +1 -0
- onnxslim/cli/_main.py +180 -0
- onnxslim/core/__init__.py +219 -0
- onnxslim/core/optimization/__init__.py +146 -0
- onnxslim/core/optimization/dead_node_elimination.py +151 -0
- onnxslim/core/optimization/subexpression_elimination.py +76 -0
- onnxslim/core/optimization/weight_tying.py +59 -0
- onnxslim/core/pattern/__init__.py +249 -0
- onnxslim/core/pattern/elimination/__init__.py +5 -0
- onnxslim/core/pattern/elimination/concat.py +61 -0
- onnxslim/core/pattern/elimination/reshape.py +77 -0
- onnxslim/core/pattern/elimination/reshape_as.py +64 -0
- onnxslim/core/pattern/elimination/slice.py +108 -0
- onnxslim/core/pattern/elimination/unsqueeze.py +92 -0
- onnxslim/core/pattern/fusion/__init__.py +8 -0
- onnxslim/core/pattern/fusion/concat_reshape.py +50 -0
- onnxslim/core/pattern/fusion/convadd.py +70 -0
- onnxslim/core/pattern/fusion/convbn.py +86 -0
- onnxslim/core/pattern/fusion/convmul.py +69 -0
- onnxslim/core/pattern/fusion/gelu.py +47 -0
- onnxslim/core/pattern/fusion/gemm.py +330 -0
- onnxslim/core/pattern/fusion/padconv.py +89 -0
- onnxslim/core/pattern/fusion/reduce.py +67 -0
- onnxslim/core/pattern/registry.py +28 -0
- onnxslim/misc/__init__.py +0 -0
- onnxslim/misc/tabulate.py +2681 -0
- onnxslim/third_party/__init__.py +0 -0
- onnxslim/third_party/_sympy/__init__.py +0 -0
- onnxslim/third_party/_sympy/functions.py +205 -0
- onnxslim/third_party/_sympy/numbers.py +397 -0
- onnxslim/third_party/_sympy/printers.py +491 -0
- onnxslim/third_party/_sympy/solve.py +172 -0
- onnxslim/third_party/_sympy/symbol.py +102 -0
- onnxslim/third_party/onnx_graphsurgeon/__init__.py +15 -0
- onnxslim/third_party/onnx_graphsurgeon/exporters/__init__.py +1 -0
- onnxslim/third_party/onnx_graphsurgeon/exporters/base_exporter.py +33 -0
- onnxslim/third_party/onnx_graphsurgeon/exporters/onnx_exporter.py +432 -0
- onnxslim/third_party/onnx_graphsurgeon/graph_pattern/__init__.py +4 -0
- onnxslim/third_party/onnx_graphsurgeon/graph_pattern/graph_pattern.py +466 -0
- onnxslim/third_party/onnx_graphsurgeon/importers/__init__.py +1 -0
- onnxslim/third_party/onnx_graphsurgeon/importers/base_importer.py +33 -0
- onnxslim/third_party/onnx_graphsurgeon/importers/onnx_importer.py +558 -0
- onnxslim/third_party/onnx_graphsurgeon/ir/__init__.py +0 -0
- onnxslim/third_party/onnx_graphsurgeon/ir/function.py +274 -0
- onnxslim/third_party/onnx_graphsurgeon/ir/graph.py +1575 -0
- onnxslim/third_party/onnx_graphsurgeon/ir/node.py +266 -0
- onnxslim/third_party/onnx_graphsurgeon/ir/tensor.py +504 -0
- onnxslim/third_party/onnx_graphsurgeon/logger/__init__.py +1 -0
- onnxslim/third_party/onnx_graphsurgeon/logger/logger.py +261 -0
- onnxslim/third_party/onnx_graphsurgeon/util/__init__.py +0 -0
- onnxslim/third_party/onnx_graphsurgeon/util/exception.py +20 -0
- onnxslim/third_party/onnx_graphsurgeon/util/misc.py +252 -0
- onnxslim/third_party/symbolic_shape_infer.py +3273 -0
- onnxslim/utils.py +794 -0
- onnxslim/version.py +1 -0
- onnxslim-0.1.80.dist-info/METADATA +207 -0
- onnxslim-0.1.80.dist-info/RECORD +65 -0
- onnxslim-0.1.80.dist-info/WHEEL +5 -0
- onnxslim-0.1.80.dist-info/entry_points.txt +2 -0
- onnxslim-0.1.80.dist-info/licenses/LICENSE +21 -0
- onnxslim-0.1.80.dist-info/top_level.txt +1 -0
- onnxslim-0.1.80.dist-info/zip-safe +1 -0
|
@@ -0,0 +1,69 @@
|
|
|
1
|
+
import onnxslim.third_party.onnx_graphsurgeon as gs
|
|
2
|
+
from onnxslim.core.pattern import Pattern, PatternMatcher
|
|
3
|
+
from onnxslim.core.pattern.registry import register_fusion_pattern
|
|
4
|
+
|
|
5
|
+
|
|
6
|
+
class ConvMulMatcher(PatternMatcher):
    """Fuses a Conv followed by a per-channel constant Mul into a single Conv."""

    def __init__(self, priority):
        """Build the Conv -> Mul pattern and hand it to the base matcher with *priority*."""
        pattern = Pattern(
            """
            input input 0 1 conv_0
            Conv conv_0 1+ 1 input mul_0
            Mul mul_0 2 1 conv_0 ? output
            output output 1 0 mul_0
            """
        )
        super().__init__(pattern, priority)

    @property
    def name(self):
        """Name under which this fusion is registered."""
        return "FusionConvMul"

    def rewrite(self, opset=11):
        """Fold the Mul's per-channel constant into the Conv weight (and bias, if present)."""
        result = {}
        conv = self.conv_0
        mul = self.mul_0
        weight = list(conv.inputs)[1]

        # Fuse only when the Conv feeds the Mul exclusively and the Mul's
        # second operand is a constant.
        fusible = len(conv.users) == 1 and conv.users[0] == mul and isinstance(mul.inputs[1], gs.Constant)
        if fusible:
            squeezed = mul.inputs[1].values.squeeze()
            if squeezed.ndim == 1 and squeezed.shape[0] == weight.shape[0]:
                # Reshape the per-output-channel scale so it broadcasts over the
                # remaining weight axes (C_in, kH, kW, ...).
                broadcast_shape = [-1] + [1] * (weight.values.ndim - 1)
                fused_weight = weight.values * squeezed.reshape(broadcast_shape)

                new_inputs = [next(iter(conv.inputs)), gs.Constant(weight.name, values=fused_weight)]
                if len(conv.inputs) == 3:
                    bias = conv.inputs[2]
                    new_inputs.append(gs.Constant(bias.name, values=bias.values * squeezed))

                new_outputs = list(mul.outputs)

                # Detach both original nodes so the rewritten Conv takes their place.
                conv.outputs.clear()
                mul.inputs.clear()
                mul.outputs.clear()

                result[conv.name] = {
                    "op": conv.op,
                    "inputs": new_inputs,
                    "outputs": new_outputs,
                    "name": conv.name,
                    "attrs": conv.attrs,
                    "domain": None,
                }

        return result


register_fusion_pattern(ConvMulMatcher(1))
|
|
@@ -0,0 +1,47 @@
|
|
|
1
|
+
from onnxslim.core.pattern import Pattern, PatternMatcher
|
|
2
|
+
|
|
3
|
+
|
|
4
|
+
class GeluPatternMatcher(PatternMatcher):
    """Recognizes the erf-based GELU subgraph and collapses it into one Gelu node."""

    def __init__(self, priority):
        """Set up the Div -> Erf -> Add -> Mul -> Mul GELU pattern with the given priority."""
        pattern = Pattern(
            """
            input input 0 2 mul_0 div_0
            Div div_0 2 1 input ? erf_0
            Erf erf_0 1 1 div_0 add_0
            Add add_0 2 1 erf_0 ? mul_0
            Mul mul_0 2 1 input add_0 mul_1
            Mul mul_1 2 1 mul_0 ? output
            output output 1 0 mul_1
            """
        )
        super().__init__(pattern, priority)

    @property
    def name(self):
        """Name under which this fusion would be registered."""
        return "FusionGelu"

    def rewrite(self, opset=11):
        """Detach the matched subgraph and emit a single Gelu node in its place."""
        source = self.div_0.inputs[0]

        # The source variable feeds both the first Mul and the Div; unhook both consumers.
        source.outputs.remove(self.mul_0)
        source.outputs.remove(self.div_0)

        sink = self.mul_1.outputs[0]
        sink.inputs.clear()

        replacement = {
            "op": "Gelu",
            "inputs": [source],
            "outputs": [sink],
            "domain": None,
        }
        return {self.mul_1.name: replacement}


# register_fusion_pattern(GeluPatternMatcher(1))
|
|
@@ -0,0 +1,330 @@
|
|
|
1
|
+
import numpy as np
|
|
2
|
+
|
|
3
|
+
import onnxslim.third_party.onnx_graphsurgeon as gs
|
|
4
|
+
from onnxslim.core.optimization.dead_node_elimination import get_constant_variable
|
|
5
|
+
from onnxslim.core.pattern import Pattern, PatternMatcher
|
|
6
|
+
from onnxslim.core.pattern.registry import register_fusion_pattern
|
|
7
|
+
|
|
8
|
+
|
|
9
|
+
class MatMulAddPatternMatcher(PatternMatcher):
    def __init__(self, priority):
        """Initializes a matcher for fusing MatMul and Add operations in ONNX graph optimization.

        Args:
            priority: Relative priority of this pattern among registered fusions.
        """
        pattern = Pattern(
            """
            input input 0 1 matmul_0
            MatMul matmul_0 2 1 input ? add_0
            Add add_0 1* 1 matmul_0 output
            output output 1 0 add_0
            """
        )
        super().__init__(pattern, priority)

    @property
    def name(self):
        """Returns the name of the fusion pattern as a string 'FusionGemm'."""
        return "FusionGemm"

    def rewrite(self, opset=11):
        """Rewrites the graph for the fusion pattern 'FusionGemm' based on matching criteria and constant variables in
        matmul nodes.

        MatMul(x, W) followed by Add(b) becomes Gemm(x, W.T, b) with transB=1. Gemm consumes
        rank-2 inputs, so a rank>2 activation is wrapped in a pre-Reshape to (-1, K) and a
        post-Reshape back to the original leading dimensions.

        Args:
            opset: Target ONNX opset; unused here, kept for the common rewrite interface.

        Returns:
            dict: Replacement-node specs keyed by node name; empty when not fusible.
        """
        match_case = {}
        node = self.add_0
        matmul_node = self.matmul_0
        # Constant operands: W from the MatMul and b from the Add (falsy when absent).
        matmul_bias_variable = get_constant_variable(matmul_node)
        add_bias_variable = get_constant_variable(node)
        # The activation is whichever MatMul input is NOT the constant weight.
        input_variable = (
            matmul_node.inputs[0] if isinstance(matmul_node.inputs[1], gs.Constant) else matmul_node.inputs[1]
        )
        users = matmul_node.users
        # Fuse only when the MatMul feeds the Add exclusively, both constants exist,
        # and the weight is rank-2.
        if len(users) == 1 and matmul_bias_variable and add_bias_variable and len(matmul_bias_variable.shape) == 2:
            if (
                input_variable.shape
                and len(input_variable.shape) > 2
                and all([isinstance(value, int) for value in input_variable.shape])
            ):
                # Rank > 2 activation: flatten leading dims so Gemm can consume it.
                pre_reshape_const = gs.Constant(
                    f"{matmul_node.name}_pre_reshape_in",
                    values=np.array([-1, matmul_bias_variable.values.shape[0]], dtype=np.int64),
                )
                inputs = []
                inputs.append(input_variable)
                inputs.append(pre_reshape_const)

                reshape_out_variable = gs.Variable(
                    f"{matmul_node.name}_pre_reshape_out",
                    dtype=input_variable.dtype,
                )
                outputs = [reshape_out_variable]

                match_case.update(
                    {
                        f"{matmul_node.name}_pre_reshape": {
                            "op": "Reshape",
                            "inputs": inputs,
                            "outputs": outputs,
                            "name": f"{matmul_node.name}_pre_reshape",
                            "domain": None,
                        }
                    }
                )

                add_node = node
                # NOTE(review): recomputed although add_bias_variable above already holds this.
                add_bias_variable = get_constant_variable(add_node)

                # Detach the Add from its first input (the MatMul output per the pattern
                # wiring — assumes inputs[0] is that edge; TODO confirm for swapped operands).
                output_variable = add_node.inputs[0]
                output_variable.outputs.remove(add_node)

                # Gemm is emitted with transB=1, so store W transposed.
                matmul_bias_transpose_constant = gs.Constant(
                    matmul_bias_variable.name, values=matmul_bias_variable.values.T
                )

                inputs = []
                inputs.append(reshape_out_variable)
                inputs.append(matmul_bias_transpose_constant)
                inputs.append(add_bias_variable)

                gemm_out_variable = gs.Variable(f"{matmul_node.name}_gemm_out", dtype=output_variable.dtype)
                outputs = [gemm_out_variable]

                match_case.update(
                    {
                        matmul_node.name: {
                            "op": "Gemm",
                            "inputs": inputs,
                            "outputs": outputs,
                            "name": matmul_node.name,
                            "attrs": {
                                "alpha": 1.0,
                                "beta": 1.0,
                                "transA": 0,
                                "transB": 1,
                            },
                            "domain": None,
                        }
                    }
                )

                # Restore the original leading dims; last dim becomes W's output width.
                values = [*list(input_variable.shape[:-1]), matmul_bias_variable.values.shape[-1]]
                post_reshape_const = gs.Constant(
                    f"{matmul_node.name}_post_reshape_in",
                    values=np.array(values, dtype=np.int64),
                )

                inputs = []
                inputs.append(gemm_out_variable)
                inputs.append(post_reshape_const)
                outputs = list(add_node.outputs)

                # Fully detach the fused-away nodes from the graph.
                matmul_node.outputs.clear()
                add_node.inputs.clear()
                add_node.outputs.clear()

                match_case.update(
                    {
                        f"{matmul_node.name}_post_reshape": {
                            "op": "Reshape",
                            "inputs": inputs,
                            "outputs": outputs,
                            "name": f"{matmul_node.name}_post_reshape",
                            "domain": None,
                        }
                    }
                )
            elif (
                input_variable.shape
                and len(input_variable.shape) == 2
                and all([isinstance(value, int) for value in input_variable.shape])
            ):
                # Rank == 2 activation: Gemm replaces MatMul+Add directly, no reshapes.
                add_node = node
                # NOTE(review): recomputed although add_bias_variable above already holds this.
                add_bias_variable = get_constant_variable(add_node)

                output_variable = add_node.inputs[0]
                output_variable.outputs.remove(add_node)

                matmul_bias_transpose_constant = gs.Constant(
                    matmul_bias_variable.name, values=matmul_bias_variable.values.T
                )

                inputs = []
                inputs.append(input_variable)
                inputs.append(matmul_bias_transpose_constant)
                inputs.append(add_bias_variable)

                outputs = list(add_node.outputs)
                add_node.inputs.clear()
                add_node.outputs.clear()
                match_case.update(
                    {
                        matmul_node.name: {
                            "op": "Gemm",
                            "inputs": inputs,
                            "outputs": outputs,
                            "name": matmul_node.name,
                            "attrs": {
                                "alpha": 1.0,
                                "beta": 1.0,
                                "transA": 0,
                                "transB": 1,
                            },
                            "domain": None,
                        }
                    }
                )
        return match_case


register_fusion_pattern(MatMulAddPatternMatcher(1))
|
|
178
|
+
|
|
179
|
+
|
|
180
|
+
class GemmMulPatternMatcher(PatternMatcher):
    """Folds a constant Mul that follows a Gemm -> Reshape chain into the Gemm weights."""

    def __init__(self, priority):
        """Initializes a matcher for fusing Gemm, Reshape and Mul operations in ONNX graph optimization.

        Args:
            priority: Relative priority of this pattern among registered fusions.
        """
        pattern = Pattern(
            """
            input input 0 1 gemm_0
            Gemm gemm_0 1+ 1 input reshape_0
            Reshape reshape_0 2 1 gemm_0 ? mul_0
            Mul mul_0 1* 1 reshape_0 output
            output output 1 0 mul_0
            """
        )
        super().__init__(pattern, priority)

    @property
    def name(self):
        """Returns the name of the fusion pattern as a string 'FusionGemmMul'."""
        return "FusionGemmMul"

    def rewrite(self, opset=11):
        """Rewrites the graph for the fusion pattern 'FusionGemmMul' based on matching criteria and constant variables
        in gemm nodes.

        The Mul's constant (scalar, or 1-D matching the Gemm output channels) is multiplied
        into the Gemm's weight B (and bias C when present); the Mul is then bypassed.

        Returns:
            dict: Always empty — the rewrite mutates the Gemm in place and reroutes the Mul.
        """
        match_case = {}
        gemm_node = self.gemm_0
        reshape_node = self.reshape_0
        mul_node = self.mul_0
        mul_bias_variable = get_constant_variable(mul_node)

        if (
            (
                (len(gemm_node.inputs) == 2 and isinstance(gemm_node.inputs[1], gs.Constant))
                or (
                    len(gemm_node.inputs) == 3
                    and isinstance(gemm_node.inputs[1], gs.Constant)
                    and isinstance(gemm_node.inputs[2], gs.Constant)
                )
            )
            and mul_bias_variable
            and len(reshape_node.users) == 1
        ):
            gemm_attr = gemm_node.attrs
            gemm_weight_constant = gemm_node.inputs[1]
            gemm_bias_constant = gemm_node.inputs[2] if len(gemm_node.inputs) == 3 else None
            # ONNX Gemm defaults: transA=0, transB=0. Use .get so nodes that omit the
            # attributes don't raise KeyError (previously `gemm_attr["transA"]` crashed
            # on such nodes); the fusion itself only applies when transB is explicitly 1.
            if (
                gemm_attr.get("transA", 0) == 0
                and gemm_attr.get("transB", 0) == 1
                and (
                    (mul_bias_variable.values.ndim == 1 and gemm_weight_constant.shape[0] == mul_bias_variable.shape[0])
                    or mul_bias_variable.values.ndim == 0
                )
            ):
                gemm_weight = gemm_weight_constant.values
                mul_weight = mul_bias_variable.values
                if mul_bias_variable.values.ndim == 1:
                    # transB=1: rows of B correspond to output channels, so scale per row.
                    gemm_weight_fused = gemm_weight * mul_weight[:, None]
                else:
                    gemm_weight_fused = gemm_weight * mul_weight
                gemm_weight_fused_constant = gs.Constant(gemm_weight_constant.name + "_fused", values=gemm_weight_fused)
                gemm_node.inputs[1] = gemm_weight_fused_constant

                if gemm_bias_constant:
                    gemm_bias = gemm_bias_constant.values
                    mul_bias = mul_bias_variable.values
                    gemm_bias_fused = gemm_bias * mul_bias
                    gemm_bias_fused_constant = gs.Constant(gemm_bias_constant.name + "_fused", values=gemm_bias_fused)
                    gemm_node.inputs[2] = gemm_bias_fused_constant

                # The Mul is now a no-op: route its consumers to the Reshape output.
                mul_node.replace_all_uses_with(reshape_node)

        return match_case


register_fusion_pattern(GemmMulPatternMatcher(1))
|
|
254
|
+
|
|
255
|
+
|
|
256
|
+
class GemmAddPatternMatcher(PatternMatcher):
    """Folds an Add that follows a Gemm -> Reshape chain into the Gemm's bias input C."""

    def __init__(self, priority):
        """Initializes a matcher for fusing Gemm, Reshape and Add operations in ONNX graph optimization.

        Args:
            priority: Relative priority of this pattern among registered fusions.
        """
        pattern = Pattern(
            """
            input input 0 1 gemm_0
            Gemm gemm_0 1+ 1 input reshape_0
            Reshape reshape_0 2 1 gemm_0 ? add_0
            Add add_0 1* 1 reshape_0 output
            output output 1 0 add_0
            """
        )
        super().__init__(pattern, priority)

    @property
    def name(self):
        """Returns the name of the fusion pattern as a string 'FusionGemmAdd'."""
        return "FusionGemmAdd"

    def rewrite(self, opset=11):
        """Rewrites the graph for the fusion pattern 'FusionGemmAdd' based on matching criteria and constant variables
        in gemm nodes.

        The Add's constant is summed into the Gemm's existing bias C, or installed as a new
        bias when the Gemm has none; the Add is then bypassed.

        Returns:
            dict: Always empty — the rewrite mutates the Gemm in place and reroutes the Add.
        """
        match_case = {}
        gemm_node = self.gemm_0
        reshape_node = self.reshape_0
        add_node = self.add_0
        add_bias_variable = get_constant_variable(add_node)

        if (
            (
                (len(gemm_node.inputs) == 2)
                or (len(gemm_node.inputs) == 3 and isinstance(gemm_node.inputs[2], gs.Constant))
            )
            and add_bias_variable
            and len(reshape_node.users) == 1
            and gemm_node.outputs[0].shape
        ):

            def can_broadcast_to(shape_from, shape_to):
                """Return True if shape_from can broadcast to shape_to per NumPy rules."""
                if shape_from is None or shape_to is None:
                    return False
                try:
                    np.empty(shape_to, dtype=np.float32) + np.empty(shape_from, dtype=np.float32)
                    return True
                # TypeError covers symbolic (string) dimensions, which np.empty rejects;
                # treat them as non-broadcastable instead of crashing the whole pass.
                except (TypeError, ValueError):
                    return False

            gemm_bias_constant = gemm_node.inputs[2] if len(gemm_node.inputs) == 3 else None
            if gemm_bias_constant:
                gemm_bias = gemm_bias_constant.values
                add_bias = add_bias_variable.values
                if (
                    can_broadcast_to(gemm_bias.shape, gemm_node.outputs[0].shape)
                    and can_broadcast_to(add_bias.shape, gemm_node.outputs[0].shape)
                    and add_bias.ndim <= 2
                ):
                    gemm_bias_fused = gemm_bias + add_bias
                    gemm_bias_fused_constant = gs.Constant(gemm_bias_constant.name + "_fused", values=gemm_bias_fused)
                    gemm_node.inputs[2] = gemm_bias_fused_constant
                else:
                    return match_case
            else:
                if can_broadcast_to(add_bias_variable.values.shape, gemm_node.outputs[0].shape):
                    # The Gemm has only (A, B) here; append the Add constant as the new
                    # bias C. (Assigning `gemm_node.inputs[2] = ...` on a 2-element input
                    # list raised IndexError and broke this branch entirely.)
                    gemm_node.inputs.append(add_bias_variable)
                else:
                    return match_case

            # The Add is now a no-op: route its consumers to the Reshape output.
            add_node.replace_all_uses_with(reshape_node)

        return match_case


register_fusion_pattern(GemmAddPatternMatcher(1))
|
|
@@ -0,0 +1,89 @@
|
|
|
1
|
+
import onnxslim.third_party.onnx_graphsurgeon as gs
|
|
2
|
+
from onnxslim.core.pattern import Pattern, PatternMatcher
|
|
3
|
+
from onnxslim.core.pattern.registry import register_fusion_pattern
|
|
4
|
+
|
|
5
|
+
|
|
6
|
+
class PadConvMatcher(PatternMatcher):
    """Folds a constant, zero-valued Pad preceding a Conv into the Conv's own `pads` attribute."""

    def __init__(self, priority):
        """Initializes the PadConvMatcher with a specified priority and defines its matching pattern."""
        pattern = Pattern(
            """
            input input 0 1 pad_0
            Pad pad_0 1+ 1 input conv_0
            Conv conv_0 1+ 1 pad_0 output
            output output 1 0 conv_0
            """
        )
        super().__init__(pattern, priority)

    @property
    def name(self):
        """Returns the name of the fusion pattern used."""
        return "FusionPadConv"

    def parameter_check(self) -> bool:
        """Validates that the Pad node's pads input is a constant (required for static folding)."""
        pad_node = self.pad_0
        return isinstance(pad_node.inputs[1], gs.Constant)

    def rewrite(self, opset=11):
        """Merge the Pad node's spatial padding into the Conv's `pads` attribute and bypass the Pad.

        Only constant-mode pads with zero fill value and no padding on the batch/channel
        axes can be expressed as Conv padding.

        Returns:
            dict: Replacement Conv spec keyed by node name; empty when not fusible.
        """
        match_case = {}
        conv_node = self.conv_0
        pad_node = self.pad_0

        pad_inputs = len(pad_node.inputs)
        # The Pad's constant_value (3rd input) must be absent, zero, or an unnamed
        # (omitted) variable for the padding to be equivalent to Conv zero-padding.
        if pad_inputs < 3 or (
            (pad_inputs >= 3 and (isinstance(pad_node.inputs[2], gs.Constant) and pad_node.inputs[2].values == 0))
            or (pad_inputs >= 3 and (isinstance(pad_node.inputs[2], gs.Variable) and pad_node.inputs[2].name == ""))
        ):
            if (
                isinstance(pad_node.inputs[1], gs.Constant)
                and pad_node.attrs.get("mode", "constant") == "constant"
                and conv_node.inputs[1].shape
            ):
                conv_weight_dim = len(conv_node.inputs[1].shape)
                pad_value = pad_node.inputs[1].values.tolist()
                # Fuse only when batch and channel axes are unpadded — Conv's `pads`
                # covers spatial dimensions only.
                if all(pad == 0 for pad in (pad_value[:2] + pad_value[conv_weight_dim : conv_weight_dim + 2])):
                    input_variable = self.pad_0.inputs[0]
                    pad_variable = pad_node.outputs[0]  # pad output variable
                    # Rewire the Conv to consume the Pad's input directly.
                    index = conv_node.inputs.index(pad_variable)
                    conv_node.inputs.pop(index)
                    conv_node.inputs.insert(index, input_variable)

                    inputs = list(conv_node.inputs)
                    outputs = list(conv_node.outputs)
                    attrs = conv_node.attrs

                    conv_node.inputs.clear()
                    conv_node.outputs.clear()
                    # Remove the Pad node when, after the rewiring above, nothing consumes
                    # it anymore. (Previously this tested a stale user snapshot taken while
                    # the Conv was still attached, so a dead Pad was never removed.)
                    if len(pad_node.users) == 0:
                        input_variable.outputs.remove(pad_node)
                        pad_node.inputs.clear()
                        pad_node.outputs.clear()

                    # Spatial begin/end pads from the Pad node, in Conv's pads layout.
                    pads = pad_value[2:conv_weight_dim] + pad_value[conv_weight_dim + 2 :]
                    # Merge with any padding the Conv already declares. (The original code
                    # used `hasattr(attrs, "pads")`, which is always False for a dict, so
                    # pre-existing Conv pads were silently overwritten.)
                    if "pads" in attrs:
                        conv_pads = attrs["pads"]
                        pads = [pad + conv_pad for pad, conv_pad in zip(pads, conv_pads)]

                    attrs["pads"] = pads
                    match_case[conv_node.name] = {
                        "op": "Conv",
                        "inputs": inputs,
                        "outputs": outputs,
                        "name": conv_node.name,
                        "attrs": conv_node.attrs,
                        "domain": None,
                    }

        return match_case


register_fusion_pattern(PadConvMatcher(1))
|
|
@@ -0,0 +1,67 @@
|
|
|
1
|
+
import onnxslim.third_party.onnx_graphsurgeon as gs
|
|
2
|
+
from onnxslim.core.pattern import Pattern, PatternMatcher
|
|
3
|
+
from onnxslim.core.pattern.registry import register_fusion_pattern
|
|
4
|
+
|
|
5
|
+
|
|
6
|
+
class ReducePatternMatcher(PatternMatcher):
    """Fuses ReduceSum(keepdims=0) + Unsqueeze over identical axes into ReduceSum(keepdims=1)."""

    def __init__(self, priority):
        """Initializes the ReducePatternMatcher with a specified pattern matching priority level."""
        pattern = Pattern(
            """
            input input 0 1 reduce_0
            ReduceSum reduce_0 1+ 1 input unsqueeze_0
            Unsqueeze unsqueeze_0 1+ 1 reduce_0 output
            output output 1 0 unsqueeze_0
            """
        )
        super().__init__(pattern, priority)

    @property
    def name(self):
        """Returns the name of the fusion pattern 'FusionReduce'."""
        return "FusionReduce"

    def rewrite(self, opset=11):
        """Rewrites the graph pattern based on opset version; reuses Reduce and Unsqueeze nodes if possible.

        When the ReduceSum drops its axes (keepdims=0) and the Unsqueeze restores exactly the
        same axes, the pair is equivalent to a single ReduceSum with keepdims=1.

        Args:
            opset: Target ONNX opset; decides whether axes live in attributes (<13) or inputs.

        Returns:
            dict: Replacement ReduceSum spec keyed by node name; empty when not fusible.
        """
        match_case = {}
        node = self.unsqueeze_0
        reduce_node = self.reduce_0
        reduce_node_node_users = reduce_node.users
        if len(reduce_node_node_users) == 1:
            unsqueeze_node = node

            if opset < 13:
                # Pre-opset-13: both ops carry their axes as attributes (plain lists).
                reduce_node_axes = reduce_node.attrs.get("axes", None)
                reduce_node_keepdims = reduce_node.attrs.get("keepdims", 1)
                unsqueeze_node_axes = unsqueeze_node.attrs.get("axes", None)
            else:
                # From opset 13 the axes are tensor inputs; ReduceSum's axes input is
                # optional, so guard against a missing second input (was an IndexError).
                if len(reduce_node.inputs) < 2 or len(unsqueeze_node.inputs) < 2:
                    return match_case
                reduce_node_axes_ = reduce_node.inputs[1]
                reduce_node_keepdims = reduce_node.attrs.get("keepdims", 1)
                unsqueeze_node_axes_ = unsqueeze_node.inputs[1]
                if isinstance(reduce_node_axes_, gs.Constant) and isinstance(unsqueeze_node_axes_, gs.Constant):
                    # Compare as plain lists: `==` on numpy arrays is element-wise, and
                    # its truth value is ambiguous when more than one axis is reduced.
                    reduce_node_axes = reduce_node_axes_.values.tolist()
                    unsqueeze_node_axes = unsqueeze_node_axes_.values.tolist()
                else:
                    return match_case

            if reduce_node_axes == unsqueeze_node_axes and reduce_node_keepdims == 0:
                inputs = list(reduce_node.inputs)
                outputs = list(unsqueeze_node.outputs)
                attrs = reduce_node.attrs
                # Detach both nodes; the rebuilt ReduceSum takes their place.
                reduce_node.outputs.clear()
                unsqueeze_node.inputs.clear()
                unsqueeze_node.outputs.clear()
                # keepdims=1 makes the reduce output shape match what Unsqueeze produced.
                attrs["keepdims"] = 1
                match_case[reduce_node.name] = {
                    "op": reduce_node.op,
                    "inputs": inputs,
                    "outputs": outputs,
                    "name": reduce_node.name,
                    "attrs": attrs,
                    "domain": None,
                }

        return match_case


register_fusion_pattern(ReducePatternMatcher(1))
|
|
@@ -0,0 +1,28 @@
|
|
|
1
|
+
from __future__ import annotations
|
|
2
|
+
|
|
3
|
+
from collections import OrderedDict
|
|
4
|
+
|
|
5
|
+
# Registry of fusion patterns, keyed by pattern name; insertion order is preserved
# so patterns are applied in registration order.
DEFAULT_FUSION_PATTERNS = OrderedDict()


def register_fusion_pattern(fusion_pattern):
    """Register *fusion_pattern* under its ``name`` in DEFAULT_FUSION_PATTERNS.

    Args:
        fusion_pattern: A pattern-matcher instance exposing a ``name`` attribute/property.

    Raises:
        ValueError: If a pattern with the same name is already registered.
            (The original bare ``raise`` here had no active exception and would
            itself fail with an unhelpful RuntimeError.)
    """
    layer_type = fusion_pattern.name

    if layer_type in DEFAULT_FUSION_PATTERNS:
        raise ValueError(f"Fusion pattern '{layer_type}' is already registered")
    DEFAULT_FUSION_PATTERNS[layer_type] = fusion_pattern
|
|
15
|
+
|
|
16
|
+
|
|
17
|
+
def get_fusion_patterns(skip_fusion_patterns: str | None = None):
    """Return a copy of the registered fusion patterns, minus any names listed in *skip_fusion_patterns*."""
    patterns = DEFAULT_FUSION_PATTERNS.copy()
    if skip_fusion_patterns:
        # Drop each requested pattern; an unknown name raises KeyError, surfacing typos.
        for pattern_name in skip_fusion_patterns:
            patterns.pop(pattern_name)
    return patterns
|
|
25
|
+
|
|
26
|
+
|
|
27
|
+
from .elimination import *
|
|
28
|
+
from .fusion import *
|
|
File without changes
|