onnxslim 0.1.80__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (65)
  1. onnxslim/__init__.py +16 -0
  2. onnxslim/__main__.py +4 -0
  3. onnxslim/argparser.py +215 -0
  4. onnxslim/cli/__init__.py +1 -0
  5. onnxslim/cli/_main.py +180 -0
  6. onnxslim/core/__init__.py +219 -0
  7. onnxslim/core/optimization/__init__.py +146 -0
  8. onnxslim/core/optimization/dead_node_elimination.py +151 -0
  9. onnxslim/core/optimization/subexpression_elimination.py +76 -0
  10. onnxslim/core/optimization/weight_tying.py +59 -0
  11. onnxslim/core/pattern/__init__.py +249 -0
  12. onnxslim/core/pattern/elimination/__init__.py +5 -0
  13. onnxslim/core/pattern/elimination/concat.py +61 -0
  14. onnxslim/core/pattern/elimination/reshape.py +77 -0
  15. onnxslim/core/pattern/elimination/reshape_as.py +64 -0
  16. onnxslim/core/pattern/elimination/slice.py +108 -0
  17. onnxslim/core/pattern/elimination/unsqueeze.py +92 -0
  18. onnxslim/core/pattern/fusion/__init__.py +8 -0
  19. onnxslim/core/pattern/fusion/concat_reshape.py +50 -0
  20. onnxslim/core/pattern/fusion/convadd.py +70 -0
  21. onnxslim/core/pattern/fusion/convbn.py +86 -0
  22. onnxslim/core/pattern/fusion/convmul.py +69 -0
  23. onnxslim/core/pattern/fusion/gelu.py +47 -0
  24. onnxslim/core/pattern/fusion/gemm.py +330 -0
  25. onnxslim/core/pattern/fusion/padconv.py +89 -0
  26. onnxslim/core/pattern/fusion/reduce.py +67 -0
  27. onnxslim/core/pattern/registry.py +28 -0
  28. onnxslim/misc/__init__.py +0 -0
  29. onnxslim/misc/tabulate.py +2681 -0
  30. onnxslim/third_party/__init__.py +0 -0
  31. onnxslim/third_party/_sympy/__init__.py +0 -0
  32. onnxslim/third_party/_sympy/functions.py +205 -0
  33. onnxslim/third_party/_sympy/numbers.py +397 -0
  34. onnxslim/third_party/_sympy/printers.py +491 -0
  35. onnxslim/third_party/_sympy/solve.py +172 -0
  36. onnxslim/third_party/_sympy/symbol.py +102 -0
  37. onnxslim/third_party/onnx_graphsurgeon/__init__.py +15 -0
  38. onnxslim/third_party/onnx_graphsurgeon/exporters/__init__.py +1 -0
  39. onnxslim/third_party/onnx_graphsurgeon/exporters/base_exporter.py +33 -0
  40. onnxslim/third_party/onnx_graphsurgeon/exporters/onnx_exporter.py +432 -0
  41. onnxslim/third_party/onnx_graphsurgeon/graph_pattern/__init__.py +4 -0
  42. onnxslim/third_party/onnx_graphsurgeon/graph_pattern/graph_pattern.py +466 -0
  43. onnxslim/third_party/onnx_graphsurgeon/importers/__init__.py +1 -0
  44. onnxslim/third_party/onnx_graphsurgeon/importers/base_importer.py +33 -0
  45. onnxslim/third_party/onnx_graphsurgeon/importers/onnx_importer.py +558 -0
  46. onnxslim/third_party/onnx_graphsurgeon/ir/__init__.py +0 -0
  47. onnxslim/third_party/onnx_graphsurgeon/ir/function.py +274 -0
  48. onnxslim/third_party/onnx_graphsurgeon/ir/graph.py +1575 -0
  49. onnxslim/third_party/onnx_graphsurgeon/ir/node.py +266 -0
  50. onnxslim/third_party/onnx_graphsurgeon/ir/tensor.py +504 -0
  51. onnxslim/third_party/onnx_graphsurgeon/logger/__init__.py +1 -0
  52. onnxslim/third_party/onnx_graphsurgeon/logger/logger.py +261 -0
  53. onnxslim/third_party/onnx_graphsurgeon/util/__init__.py +0 -0
  54. onnxslim/third_party/onnx_graphsurgeon/util/exception.py +20 -0
  55. onnxslim/third_party/onnx_graphsurgeon/util/misc.py +252 -0
  56. onnxslim/third_party/symbolic_shape_infer.py +3273 -0
  57. onnxslim/utils.py +794 -0
  58. onnxslim/version.py +1 -0
  59. onnxslim-0.1.80.dist-info/METADATA +207 -0
  60. onnxslim-0.1.80.dist-info/RECORD +65 -0
  61. onnxslim-0.1.80.dist-info/WHEEL +5 -0
  62. onnxslim-0.1.80.dist-info/entry_points.txt +2 -0
  63. onnxslim-0.1.80.dist-info/licenses/LICENSE +21 -0
  64. onnxslim-0.1.80.dist-info/top_level.txt +1 -0
  65. onnxslim-0.1.80.dist-info/zip-safe +1 -0
onnxslim/core/pattern/fusion/convmul.py
@@ -0,0 +1,69 @@
+ import onnxslim.third_party.onnx_graphsurgeon as gs
+ from onnxslim.core.pattern import Pattern, PatternMatcher
+ from onnxslim.core.pattern.registry import register_fusion_pattern
+
+
+ class ConvMulMatcher(PatternMatcher):
+     def __init__(self, priority):
+         """Initializes the ConvMulMatcher for fusing Conv and Mul layers in an ONNX graph."""
+         # each row: op type, node tag, input count, output count, then the
+         # node's input tags followed by its output tags ('?' is a wildcard,
+         # '1+' means one or more inputs)
+         pattern = Pattern(
+             """
+             input input 0 1 conv_0
+             Conv conv_0 1+ 1 input mul_0
+             Mul mul_0 2 1 conv_0 ? output
+             output output 1 0 mul_0
+             """
+         )
+         super().__init__(pattern, priority)
+
+     @property
+     def name(self):
+         """Returns the name of the FusionConvMul pattern."""
+         return "FusionConvMul"
+
+     def rewrite(self, opset=11):
+         """Folds a per-output-channel Mul constant into the preceding Conv's weight and bias."""
+         match_case = {}
+         conv_node = self.conv_0
+         mul_node = self.mul_0
+         conv_weight = list(conv_node.inputs)[1]
+         if len(conv_node.users) == 1 and conv_node.users[0] == mul_node and isinstance(mul_node.inputs[1], gs.Constant):
+             mul_constant = mul_node.inputs[1].values
+
+             if mul_constant.squeeze().ndim == 1 and mul_constant.squeeze().shape[0] == conv_weight.shape[0]:
+                 weight_shape = conv_weight.values.shape
+                 # broadcast the per-channel scale over the remaining weight axes
+                 reshape_shape = [-1] + [1] * (len(weight_shape) - 1)
+
+                 mul_scale_reshaped = mul_constant.squeeze().reshape(reshape_shape)
+                 new_weight = conv_weight.values * mul_scale_reshaped
+
+                 inputs = []
+                 inputs.append(next(iter(conv_node.inputs)))
+
+                 weight_name = list(conv_node.inputs)[1].name
+                 inputs.append(gs.Constant(weight_name, values=new_weight))
+
+                 if len(conv_node.inputs) == 3:
+                     conv_bias = conv_node.inputs[2].values
+                     new_bias = conv_bias * mul_constant.squeeze()
+                     bias_name = list(conv_node.inputs)[2].name
+                     inputs.append(gs.Constant(bias_name, values=new_bias))
+
+                 outputs = list(mul_node.outputs)
+
+                 conv_node.outputs.clear()
+                 mul_node.inputs.clear()
+                 mul_node.outputs.clear()
+
+                 match_case[conv_node.name] = {
+                     "op": conv_node.op,
+                     "inputs": inputs,
+                     "outputs": outputs,
+                     "name": conv_node.name,
+                     "attrs": conv_node.attrs,
+                     "domain": None,
+                 }
+
+         return match_case
+
+
+ register_fusion_pattern(ConvMulMatcher(1))
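
The rewrite above relies on convolution being linear in its weights: multiplying Conv output channel c by a scalar s[c] is equivalent to pre-scaling weight slice c and bias entry c by s[c]. Below is a minimal NumPy sketch of that identity, with hypothetical shapes and a stride-1, unpadded 2D convolution; it illustrates the math and is not part of the package.

import numpy as np

rng = np.random.default_rng(0)
x = rng.standard_normal((1, 3, 8, 8))  # NCHW input
w = rng.standard_normal((4, 3, 3, 3))  # Conv weight: (C_out, C_in, kH, kW)
b = rng.standard_normal(4)             # Conv bias, one entry per output channel
s = rng.standard_normal(4)             # the per-channel Mul constant


def conv2d(x, w, b):
    """Minimal stride-1, unpadded 2D convolution via sliding windows."""
    win = np.lib.stride_tricks.sliding_window_view(x, w.shape[-2:], axis=(2, 3))
    return np.einsum("nchwij,ocij->nohw", win, w) + b[None, :, None, None]


y_conv_mul = conv2d(x, w, b) * s[None, :, None, None]   # Conv followed by Mul
y_fused = conv2d(x, w * s[:, None, None, None], b * s)  # fused Conv
assert np.allclose(y_conv_mul, y_fused)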
onnxslim/core/pattern/fusion/gelu.py
@@ -0,0 +1,47 @@
+ from onnxslim.core.pattern import Pattern, PatternMatcher
+
+
+ class GeluPatternMatcher(PatternMatcher):
+     def __init__(self, priority):
+         """Initializes a `GeluPatternMatcher` to identify and fuse GELU patterns in a computational graph."""
+         pattern = Pattern(
+             """
+             input input 0 2 mul_0 div_0
+             Div div_0 2 1 input ? erf_0
+             Erf erf_0 1 1 div_0 add_0
+             Add add_0 2 1 erf_0 ? mul_0
+             Mul mul_0 2 1 input add_0 mul_1
+             Mul mul_1 2 1 mul_0 ? output
+             output output 1 0 mul_1
+             """
+         )
+         super().__init__(pattern, priority)
+
+     @property
+     def name(self):
+         """Returns the name of the fusion pattern, 'FusionGelu'."""
+         return "FusionGelu"
+
+     def rewrite(self, opset=11):
+         """Rewrite the computation graph pattern to fuse GELU operations."""
+         input_variable = self.div_0.inputs[0]
+         mul_node = self.mul_0
+         div_node = self.div_0
+
+         input_variable.outputs.remove(mul_node)
+         input_variable.outputs.remove(div_node)
+
+         output_variable = self.mul_1.outputs[0]
+         output_variable.inputs.clear()
+
+         return {
+             self.mul_1.name: {
+                 "op": "Gelu",
+                 "inputs": [input_variable],
+                 "outputs": [output_variable],
+                 "domain": None,
+             }
+         }
+
+
+ # register_fusion_pattern(GeluPatternMatcher(1))
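
The subgraph this matcher targets is the erf formulation of GELU: x * 0.5 * (1 + erf(x / sqrt(2))). A scalar Python sketch of the five matched nodes, mirroring their pattern tags:

import math

def gelu_subgraph(x):
    div_0 = x / math.sqrt(2.0)  # Div
    erf_0 = math.erf(div_0)     # Erf
    add_0 = erf_0 + 1.0         # Add
    mul_0 = x * add_0           # Mul
    mul_1 = mul_0 * 0.5         # Mul
    return mul_1                # == GELU(x) = x * Phi(x)

print(gelu_subgraph(1.0))  # ~0.8413

Note that the register_fusion_pattern call is commented out, so this matcher ships disabled by default.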
onnxslim/core/pattern/fusion/gemm.py
@@ -0,0 +1,330 @@
+ import numpy as np
+
+ import onnxslim.third_party.onnx_graphsurgeon as gs
+ from onnxslim.core.optimization.dead_node_elimination import get_constant_variable
+ from onnxslim.core.pattern import Pattern, PatternMatcher
+ from onnxslim.core.pattern.registry import register_fusion_pattern
+
+
+ class MatMulAddPatternMatcher(PatternMatcher):
+     def __init__(self, priority):
+         """Initializes a matcher for fusing MatMul and Add operations in ONNX graph optimization."""
+         pattern = Pattern(
+             """
+             input input 0 1 matmul_0
+             MatMul matmul_0 2 1 input ? add_0
+             Add add_0 1* 1 matmul_0 output
+             output output 1 0 add_0
+             """
+         )
+         super().__init__(pattern, priority)
+
+     @property
+     def name(self):
+         """Returns the name of the fusion pattern as a string 'FusionGemm'."""
+         return "FusionGemm"
+
+     def rewrite(self, opset=11):
+         """Rewrites the graph for the fusion pattern 'FusionGemm' based on matching criteria and constant variables in
+         matmul nodes.
+         """
+         match_case = {}
+         node = self.add_0
+         matmul_node = self.matmul_0
+         matmul_bias_variable = get_constant_variable(matmul_node)
+         add_bias_variable = get_constant_variable(node)
+         input_variable = (
+             matmul_node.inputs[0] if isinstance(matmul_node.inputs[1], gs.Constant) else matmul_node.inputs[1]
+         )
+         users = matmul_node.users
+         if len(users) == 1 and matmul_bias_variable and add_bias_variable and len(matmul_bias_variable.shape) == 2:
+             if (
+                 input_variable.shape
+                 and len(input_variable.shape) > 2
+                 and all([isinstance(value, int) for value in input_variable.shape])
+             ):
+                 # rank > 2: Gemm only takes rank-2 operands, so wrap it in a
+                 # pre-Reshape to (-1, K) and a post-Reshape back
+                 pre_reshape_const = gs.Constant(
+                     f"{matmul_node.name}_pre_reshape_in",
+                     values=np.array([-1, matmul_bias_variable.values.shape[0]], dtype=np.int64),
+                 )
+                 inputs = []
+                 inputs.append(input_variable)
+                 inputs.append(pre_reshape_const)
+
+                 reshape_out_variable = gs.Variable(
+                     f"{matmul_node.name}_pre_reshape_out",
+                     dtype=input_variable.dtype,
+                 )
+                 outputs = [reshape_out_variable]
+
+                 match_case.update(
+                     {
+                         f"{matmul_node.name}_pre_reshape": {
+                             "op": "Reshape",
+                             "inputs": inputs,
+                             "outputs": outputs,
+                             "name": f"{matmul_node.name}_pre_reshape",
+                             "domain": None,
+                         }
+                     }
+                 )
+
+                 add_node = node
+                 add_bias_variable = get_constant_variable(add_node)
+
+                 output_variable = add_node.inputs[0]
+                 output_variable.outputs.remove(add_node)
+
+                 matmul_bias_transpose_constant = gs.Constant(
+                     matmul_bias_variable.name, values=matmul_bias_variable.values.T
+                 )
+
+                 inputs = []
+                 inputs.append(reshape_out_variable)
+                 inputs.append(matmul_bias_transpose_constant)
+                 inputs.append(add_bias_variable)
+
+                 gemm_out_variable = gs.Variable(f"{matmul_node.name}_gemm_out", dtype=output_variable.dtype)
+                 outputs = [gemm_out_variable]
+
+                 match_case.update(
+                     {
+                         matmul_node.name: {
+                             "op": "Gemm",
+                             "inputs": inputs,
+                             "outputs": outputs,
+                             "name": matmul_node.name,
+                             "attrs": {
+                                 "alpha": 1.0,
+                                 "beta": 1.0,
+                                 "transA": 0,
+                                 "transB": 1,
+                             },
+                             "domain": None,
+                         }
+                     }
+                 )
+
+                 values = [*list(input_variable.shape[:-1]), matmul_bias_variable.values.shape[-1]]
+                 post_reshape_const = gs.Constant(
+                     f"{matmul_node.name}_post_reshape_in",
+                     values=np.array(values, dtype=np.int64),
+                 )
+
+                 inputs = []
+                 inputs.append(gemm_out_variable)
+                 inputs.append(post_reshape_const)
+                 outputs = list(add_node.outputs)
+
+                 matmul_node.outputs.clear()
+                 add_node.inputs.clear()
+                 add_node.outputs.clear()
+
+                 match_case.update(
+                     {
+                         f"{matmul_node.name}_post_reshape": {
+                             "op": "Reshape",
+                             "inputs": inputs,
+                             "outputs": outputs,
+                             "name": f"{matmul_node.name}_post_reshape",
+                             "domain": None,
+                         }
+                     }
+                 )
+             elif (
+                 input_variable.shape
+                 and len(input_variable.shape) == 2
+                 and all([isinstance(value, int) for value in input_variable.shape])
+             ):
+                 # rank == 2: replace MatMul + Add with a single Gemm directly
+                 add_node = node
+                 add_bias_variable = get_constant_variable(add_node)
+
+                 output_variable = add_node.inputs[0]
+                 output_variable.outputs.remove(add_node)
+
+                 matmul_bias_transpose_constant = gs.Constant(
+                     matmul_bias_variable.name, values=matmul_bias_variable.values.T
+                 )
+
+                 inputs = []
+                 inputs.append(input_variable)
+                 inputs.append(matmul_bias_transpose_constant)
+                 inputs.append(add_bias_variable)
+
+                 outputs = list(add_node.outputs)
+                 add_node.inputs.clear()
+                 add_node.outputs.clear()
+                 match_case.update(
+                     {
+                         matmul_node.name: {
+                             "op": "Gemm",
+                             "inputs": inputs,
+                             "outputs": outputs,
+                             "name": matmul_node.name,
+                             "attrs": {
+                                 "alpha": 1.0,
+                                 "beta": 1.0,
+                                 "transA": 0,
+                                 "transB": 1,
+                             },
+                             "domain": None,
+                         }
+                     }
+                 )
+         return match_case
+
+
+ register_fusion_pattern(MatMulAddPatternMatcher(1))
+
+
+ class GemmMulPatternMatcher(PatternMatcher):
+     def __init__(self, priority):
+         """Initializes a matcher for fusing Gemm and Mul operations in ONNX graph optimization."""
+         pattern = Pattern(
+             """
+             input input 0 1 gemm_0
+             Gemm gemm_0 1+ 1 input reshape_0
+             Reshape reshape_0 2 1 gemm_0 ? mul_0
+             Mul mul_0 1* 1 reshape_0 output
+             output output 1 0 mul_0
+             """
+         )
+         super().__init__(pattern, priority)
+
+     @property
+     def name(self):
+         """Returns the name of the fusion pattern as a string 'FusionGemmMul'."""
+         return "FusionGemmMul"
+
+     def rewrite(self, opset=11):
+         """Rewrites the graph for the fusion pattern 'FusionGemmMul' based on matching criteria and constant variables
+         in gemm nodes.
+         """
+         match_case = {}
+         gemm_node = self.gemm_0
+         reshape_node = self.reshape_0
+         mul_node = self.mul_0
+         mul_bias_variable = get_constant_variable(mul_node)
+
+         if (
+             (
+                 (len(gemm_node.inputs) == 2 and isinstance(gemm_node.inputs[1], gs.Constant))
+                 or (
+                     len(gemm_node.inputs) == 3
+                     and isinstance(gemm_node.inputs[1], gs.Constant)
+                     and isinstance(gemm_node.inputs[2], gs.Constant)
+                 )
+             )
+             and mul_bias_variable
+             and len(reshape_node.users) == 1
+         ):
+             gemm_attr = gemm_node.attrs
+             gemm_weight_constant = gemm_node.inputs[1]
+             gemm_bias_constant = gemm_node.inputs[2] if len(gemm_node.inputs) == 3 else None
+             if (
+                 gemm_attr.get("transA", 0) == 0
+                 and gemm_attr.get("transB", 0) == 1
+                 and (
+                     (mul_bias_variable.values.ndim == 1 and gemm_weight_constant.shape[0] == mul_bias_variable.shape[0])
+                     or mul_bias_variable.values.ndim == 0
+                 )
+             ):
+                 gemm_weight = gemm_weight_constant.values
+                 mul_weight = mul_bias_variable.values
+                 if mul_bias_variable.values.ndim == 1:
+                     # with transB=1 the output channel is the leading weight axis
+                     gemm_weight_fused = gemm_weight * mul_weight[:, None]
+                 else:
+                     gemm_weight_fused = gemm_weight * mul_weight
+                 gemm_weight_fused_constant = gs.Constant(gemm_weight_constant.name + "_fused", values=gemm_weight_fused)
+                 gemm_node.inputs[1] = gemm_weight_fused_constant
+
+                 if gemm_bias_constant:
+                     gemm_bias = gemm_bias_constant.values
+                     mul_bias = mul_bias_variable.values
+                     gemm_bias_fused = gemm_bias * mul_bias
+                     gemm_bias_fused_constant = gs.Constant(gemm_bias_constant.name + "_fused", values=gemm_bias_fused)
+                     gemm_node.inputs[2] = gemm_bias_fused_constant
+
+                 mul_node.replace_all_uses_with(reshape_node)
+
+         return match_case
+
+
+ register_fusion_pattern(GemmMulPatternMatcher(1))
+
+
+ class GemmAddPatternMatcher(PatternMatcher):
+     def __init__(self, priority):
+         """Initializes a matcher for fusing Gemm and Add operations in ONNX graph optimization."""
+         pattern = Pattern(
+             """
+             input input 0 1 gemm_0
+             Gemm gemm_0 1+ 1 input reshape_0
+             Reshape reshape_0 2 1 gemm_0 ? add_0
+             Add add_0 1* 1 reshape_0 output
+             output output 1 0 add_0
+             """
+         )
+         super().__init__(pattern, priority)
+
+     @property
+     def name(self):
+         """Returns the name of the fusion pattern as a string 'FusionGemmAdd'."""
+         return "FusionGemmAdd"
+
+     def rewrite(self, opset=11):
+         """Rewrites the graph for the fusion pattern 'FusionGemmAdd' based on matching criteria and constant variables
+         in gemm nodes.
+         """
+         match_case = {}
+         gemm_node = self.gemm_0
+         reshape_node = self.reshape_0
+         add_node = self.add_0
+         add_bias_variable = get_constant_variable(add_node)
+
+         if (
+             (
+                 (len(gemm_node.inputs) == 2)
+                 or (len(gemm_node.inputs) == 3 and isinstance(gemm_node.inputs[2], gs.Constant))
+             )
+             and add_bias_variable
+             and len(reshape_node.users) == 1
+             and gemm_node.outputs[0].shape
+         ):
+
+             def can_broadcast_to(shape_from, shape_to):
+                 """Return True if shape_from can broadcast to shape_to per NumPy rules."""
+                 if shape_from is None or shape_to is None:
+                     return False
+                 try:
+                     np.empty(shape_to, dtype=np.float32) + np.empty(shape_from, dtype=np.float32)
+                     return True
+                 except ValueError:
+                     return False
+
+             gemm_bias_constant = gemm_node.inputs[2] if len(gemm_node.inputs) == 3 else None
+             if gemm_bias_constant:
+                 gemm_bias = gemm_bias_constant.values
+                 add_bias = add_bias_variable.values
+                 if (
+                     can_broadcast_to(gemm_bias.shape, gemm_node.outputs[0].shape)
+                     and can_broadcast_to(add_bias.shape, gemm_node.outputs[0].shape)
+                     and add_bias.ndim <= 2
+                 ):
+                     gemm_bias_fused = gemm_bias + add_bias
+                     gemm_bias_fused_constant = gs.Constant(gemm_bias_constant.name + "_fused", values=gemm_bias_fused)
+                     gemm_node.inputs[2] = gemm_bias_fused_constant
+                 else:
+                     return match_case
+             else:
+                 if can_broadcast_to(add_bias_variable.values.shape, gemm_node.outputs[0].shape):
+                     # Gemm had no bias input; append the Add constant as C
+                     gemm_node.inputs.append(add_bias_variable)
+                 else:
+                     return match_case
+
+             add_node.replace_all_uses_with(reshape_node)
+
+         return match_case
+
+
+ register_fusion_pattern(GemmAddPatternMatcher(1))
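
The reshape sandwich in MatMulAddPatternMatcher exists because ONNX Gemm only accepts rank-2 operands: a rank-N input is flattened to (-1, K), pushed through Gemm with the constant stored transposed (hence transB=1), then reshaped back. A NumPy sketch of the equivalence, with hypothetical shapes:

import numpy as np

rng = np.random.default_rng(0)
x = rng.standard_normal((2, 5, 8))  # rank-3 MatMul input
w = rng.standard_normal((8, 4))     # constant MatMul operand, shape (K, N)
c = rng.standard_normal(4)          # constant Add operand

y_ref = x @ w + c  # original MatMul + Add

b = w.T                                            # the rewrite stores the constant transposed
flat = x.reshape(-1, w.shape[0])                   # pre-reshape to (-1, K)
gemm = flat @ b.T + c                              # Gemm with transB=1: A @ B.T + C
y_fused = gemm.reshape(*x.shape[:-1], w.shape[1])  # post-reshape to the original leading dims
assert np.allclose(y_ref, y_fused)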
onnxslim/core/pattern/fusion/padconv.py
@@ -0,0 +1,89 @@
+ import onnxslim.third_party.onnx_graphsurgeon as gs
+ from onnxslim.core.pattern import Pattern, PatternMatcher
+ from onnxslim.core.pattern.registry import register_fusion_pattern
+
+
+ class PadConvMatcher(PatternMatcher):
+     def __init__(self, priority):
+         """Initializes the PadConvMatcher with a specified priority and defines its matching pattern."""
+         pattern = Pattern(
+             """
+             input input 0 1 pad_0
+             Pad pad_0 1+ 1 input conv_0
+             Conv conv_0 1+ 1 pad_0 output
+             output output 1 0 conv_0
+             """
+         )
+         super().__init__(pattern, priority)
+
+     @property
+     def name(self):
+         """Returns the name of the fusion pattern used."""
+         return "FusionPadConv"
+
+     def parameter_check(self) -> bool:
+         """Validates that the Pad node's pads input is a constant."""
+         pad_node = self.pad_0
+
+         return isinstance(pad_node.inputs[1], gs.Constant)
+
+     def rewrite(self, opset=11):
+         """Folds a constant-mode, zero-value Pad node into the `pads` attribute of the Conv node that follows it."""
+         match_case = {}
+         conv_node = self.conv_0
+         pad_node = self.pad_0
+         pad_node_users = pad_node.users
+
+         pad_inputs = len(pad_node.inputs)
+         # the Pad must have no constant_value input, a zero one, or an empty optional input
+         if pad_inputs < 3 or (
+             (pad_inputs >= 3 and (isinstance(pad_node.inputs[2], gs.Constant) and pad_node.inputs[2].values == 0))
+             or (pad_inputs >= 3 and (isinstance(pad_node.inputs[2], gs.Variable) and pad_node.inputs[2].name == ""))
+         ):
+             if (
+                 isinstance(pad_node.inputs[1], gs.Constant)
+                 and pad_node.attrs.get("mode", "constant") == "constant"
+                 and conv_node.inputs[1].shape
+             ):
+                 conv_weight_dim = len(conv_node.inputs[1].shape)
+                 pad_value = pad_node.inputs[1].values.tolist()
+                 # only fuse if the batch and channel dimensions are not padded
+                 if all(pad == 0 for pad in (pad_value[:2] + pad_value[conv_weight_dim : conv_weight_dim + 2])):
+                     input_variable = self.pad_0.inputs[0]
+                     pad_variable = pad_node.outputs[0]  # pad output variable
+                     index = conv_node.inputs.index(pad_variable)
+                     conv_node.inputs.pop(index)
+                     conv_node.inputs.insert(index, input_variable)
+
+                     inputs = list(conv_node.inputs)
+                     outputs = list(conv_node.outputs)
+                     attrs = conv_node.attrs
+
+                     conv_node.inputs.clear()
+                     conv_node.outputs.clear()
+                     # remove the pad node if the conv was its only user
+                     if len(pad_node_users) == 1:
+                         input_variable.outputs.remove(pad_node)
+                         pad_node.inputs.clear()
+                         pad_node.outputs.clear()
+
+                     pads = pad_value[2:conv_weight_dim] + pad_value[conv_weight_dim + 2 :]
+                     if "pads" in attrs:
+                         conv_pads = attrs["pads"]
+                         pads = [pad + conv_pad for pad, conv_pad in zip(pads, conv_pads)]
+
+                     attrs["pads"] = pads
+                     match_case[conv_node.name] = {
+                         "op": "Conv",
+                         "inputs": inputs,
+                         "outputs": outputs,
+                         "name": conv_node.name,
+                         "attrs": conv_node.attrs,
+                         "domain": None,
+                     }
+
+         return match_case
+
+
+ register_fusion_pattern(PadConvMatcher(1))
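
The index arithmetic in the rewrite follows the ONNX pads layout: for a 2D Conv, the Pad input is [N_begin, C_begin, H_begin, W_begin, N_end, C_end, H_end, W_end], while Conv's pads attribute holds spatial begins and ends only. A small sketch with hypothetical values:

# conv_weight_dim is the rank of the Conv weight, 4 for (C_out, C_in, kH, kW)
conv_weight_dim = 4
pad_value = [0, 0, 1, 2, 0, 0, 3, 4]  # zero padding on N and C, as the guard requires

# the guard: batch and channel dims must not be padded
assert all(p == 0 for p in pad_value[:2] + pad_value[conv_weight_dim : conv_weight_dim + 2])

# Conv's `pads` attribute wants [H_begin, W_begin, H_end, W_end]
pads = pad_value[2:conv_weight_dim] + pad_value[conv_weight_dim + 2 :]
assert pads == [1, 2, 3, 4]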
onnxslim/core/pattern/fusion/reduce.py
@@ -0,0 +1,67 @@
+ import onnxslim.third_party.onnx_graphsurgeon as gs
+ from onnxslim.core.pattern import Pattern, PatternMatcher
+ from onnxslim.core.pattern.registry import register_fusion_pattern
+
+
+ class ReducePatternMatcher(PatternMatcher):
+     def __init__(self, priority):
+         """Initializes the ReducePatternMatcher with a specified pattern matching priority level."""
+         pattern = Pattern(
+             """
+             input input 0 1 reduce_0
+             ReduceSum reduce_0 1+ 1 input unsqueeze_0
+             Unsqueeze unsqueeze_0 1+ 1 reduce_0 output
+             output output 1 0 unsqueeze_0
+             """
+         )
+         super().__init__(pattern, priority)
+
+     @property
+     def name(self):
+         """Returns the name of the fusion pattern 'FusionReduce'."""
+         return "FusionReduce"
+
+     def rewrite(self, opset=11):
+         """Rewrites the graph pattern based on opset version; reuses Reduce and Unsqueeze nodes if possible."""
+         match_case = {}
+         node = self.unsqueeze_0
+         reduce_node = self.reduce_0
+         reduce_node_users = reduce_node.users
+         if len(reduce_node_users) == 1:
+             unsqueeze_node = node
+
+             if opset < 13:
+                 # before opset 13, axes live in the node attributes
+                 reduce_node_axes = reduce_node.attrs.get("axes", None)
+                 reduce_node_keepdims = reduce_node.attrs.get("keepdims", 1)
+                 unsqueeze_node_axes = unsqueeze_node.attrs.get("axes", None)
+             else:
+                 # from opset 13 onwards, axes are passed as a second input
+                 reduce_node_axes_ = reduce_node.inputs[1]
+                 reduce_node_keepdims = reduce_node.attrs.get("keepdims", 1)
+                 unsqueeze_node_axes_ = unsqueeze_node.inputs[1]
+                 if isinstance(reduce_node_axes_, gs.Constant) and isinstance(unsqueeze_node_axes_, gs.Constant):
+                     # compare as plain lists so that == yields a single boolean
+                     reduce_node_axes = reduce_node_axes_.values.tolist()
+                     unsqueeze_node_axes = unsqueeze_node_axes_.values.tolist()
+                 else:
+                     return match_case
+
+             if reduce_node_axes == unsqueeze_node_axes and reduce_node_keepdims == 0:
+                 inputs = list(reduce_node.inputs)
+                 outputs = list(unsqueeze_node.outputs)
+                 attrs = reduce_node.attrs
+                 reduce_node.outputs.clear()
+                 unsqueeze_node.inputs.clear()
+                 unsqueeze_node.outputs.clear()
+                 attrs["keepdims"] = 1
+                 match_case[reduce_node.name] = {
+                     "op": reduce_node.op,
+                     "inputs": inputs,
+                     "outputs": outputs,
+                     "name": reduce_node.name,
+                     "attrs": attrs,
+                     "domain": None,
+                 }
+
+         return match_case
+
+
+ register_fusion_pattern(ReducePatternMatcher(1))
+ register_fusion_pattern(ReducePatternMatcher(1))
@@ -0,0 +1,28 @@
1
+ from __future__ import annotations
2
+
3
+ from collections import OrderedDict
4
+
5
+ DEFAULT_FUSION_PATTERNS = OrderedDict()
6
+
7
+
8
+ def register_fusion_pattern(fusion_pattern):
9
+ """Registers a fusion pattern function for a specified layer type in the DEFAULT_FUSION_PATTERNS dictionary."""
10
+ layer_type = fusion_pattern.name
11
+
12
+ if layer_type in DEFAULT_FUSION_PATTERNS.keys():
13
+ raise
14
+ DEFAULT_FUSION_PATTERNS[layer_type] = fusion_pattern
15
+
16
+
17
+ def get_fusion_patterns(skip_fusion_patterns: str | None = None):
18
+ """Returns a copy of the default fusion patterns, optionally excluding specific patterns."""
19
+ default_fusion_patterns = DEFAULT_FUSION_PATTERNS.copy()
20
+ if skip_fusion_patterns:
21
+ for pattern in skip_fusion_patterns:
22
+ default_fusion_patterns.pop(pattern)
23
+
24
+ return default_fusion_patterns
25
+
26
+
27
+ from .elimination import *
28
+ from .fusion import *
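
Each fusion module calls register_fusion_pattern at import time (triggered by the imports above), so consumers normally interact only with get_fusion_patterns. A minimal usage sketch; the skip entries are the matchers' name properties, e.g. "FusionPadConv":

from onnxslim.core.pattern.registry import get_fusion_patterns

patterns = get_fusion_patterns()  # every registered matcher, keyed by name
without_padconv = get_fusion_patterns(["FusionPadConv"])
assert "FusionPadConv" not in without_padconv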