onnxslim 0.1.82__py3-none-any.whl → 0.1.84__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (141)
  1. onnxslim/core/optimization/dead_node_elimination.py +85 -4
  2. onnxslim/core/pattern/elimination/slice.py +15 -8
  3. onnxslim/core/pattern/fusion/concat_reshape.py +3 -1
  4. onnxslim/core/pattern/fusion/convadd.py +23 -7
  5. onnxslim/core/pattern/fusion/convbn.py +24 -11
  6. onnxslim/core/pattern/fusion/convmul.py +26 -9
  7. onnxslim/core/pattern/fusion/gemm.py +7 -5
  8. onnxslim/core/pattern/fusion/padconv.py +5 -0
  9. onnxslim/core/shape_inference/__init__.py +378 -0
  10. onnxslim/core/shape_inference/aten_ops/__init__.py +16 -0
  11. onnxslim/core/shape_inference/aten_ops/argmax.py +47 -0
  12. onnxslim/core/shape_inference/aten_ops/bitwise_or.py +28 -0
  13. onnxslim/core/shape_inference/aten_ops/diagonal.py +52 -0
  14. onnxslim/core/shape_inference/aten_ops/embedding.py +23 -0
  15. onnxslim/core/shape_inference/aten_ops/group_norm.py +41 -0
  16. onnxslim/core/shape_inference/aten_ops/min_max.py +64 -0
  17. onnxslim/core/shape_inference/aten_ops/multinomial.py +39 -0
  18. onnxslim/core/shape_inference/aten_ops/numpy_t.py +22 -0
  19. onnxslim/core/shape_inference/aten_ops/pool2d.py +40 -0
  20. onnxslim/core/shape_inference/aten_ops/unfold.py +44 -0
  21. onnxslim/core/shape_inference/aten_ops/upsample.py +44 -0
  22. onnxslim/core/shape_inference/base.py +111 -0
  23. onnxslim/core/shape_inference/context.py +645 -0
  24. onnxslim/core/shape_inference/contrib_ops/__init__.py +8 -0
  25. onnxslim/core/shape_inference/contrib_ops/attention/__init__.py +15 -0
  26. onnxslim/core/shape_inference/contrib_ops/attention/attention.py +61 -0
  27. onnxslim/core/shape_inference/contrib_ops/attention/decoder_masked_mha.py +37 -0
  28. onnxslim/core/shape_inference/contrib_ops/attention/gated_relative_position_bias.py +35 -0
  29. onnxslim/core/shape_inference/contrib_ops/attention/longformer_attention.py +21 -0
  30. onnxslim/core/shape_inference/contrib_ops/attention/multi_head_attention.py +82 -0
  31. onnxslim/core/shape_inference/contrib_ops/attention/multi_scale_deformable_attn.py +29 -0
  32. onnxslim/core/shape_inference/contrib_ops/attention/packed_attention.py +39 -0
  33. onnxslim/core/shape_inference/contrib_ops/attention/packed_multi_head_attention.py +33 -0
  34. onnxslim/core/shape_inference/contrib_ops/attention/remove_padding.py +41 -0
  35. onnxslim/core/shape_inference/contrib_ops/attention/restore_padding.py +29 -0
  36. onnxslim/core/shape_inference/contrib_ops/misc/__init__.py +15 -0
  37. onnxslim/core/shape_inference/contrib_ops/misc/bias_add.py +21 -0
  38. onnxslim/core/shape_inference/contrib_ops/misc/bias_gelu.py +21 -0
  39. onnxslim/core/shape_inference/contrib_ops/misc/bias_split_gelu.py +30 -0
  40. onnxslim/core/shape_inference/contrib_ops/misc/fast_gelu.py +21 -0
  41. onnxslim/core/shape_inference/contrib_ops/misc/gelu.py +21 -0
  42. onnxslim/core/shape_inference/contrib_ops/misc/gemm_fast_gelu.py +21 -0
  43. onnxslim/core/shape_inference/contrib_ops/misc/gemm_float8.py +21 -0
  44. onnxslim/core/shape_inference/contrib_ops/misc/python_op.py +67 -0
  45. onnxslim/core/shape_inference/contrib_ops/misc/quick_gelu.py +21 -0
  46. onnxslim/core/shape_inference/contrib_ops/misc/rotary_embedding.py +31 -0
  47. onnxslim/core/shape_inference/contrib_ops/normalization/__init__.py +12 -0
  48. onnxslim/core/shape_inference/contrib_ops/normalization/embed_layer_normalization.py +41 -0
  49. onnxslim/core/shape_inference/contrib_ops/normalization/group_norm.py +21 -0
  50. onnxslim/core/shape_inference/contrib_ops/normalization/layer_normalization.py +42 -0
  51. onnxslim/core/shape_inference/contrib_ops/normalization/simplified_layer_normalization.py +23 -0
  52. onnxslim/core/shape_inference/contrib_ops/normalization/skip_group_norm.py +23 -0
  53. onnxslim/core/shape_inference/contrib_ops/normalization/skip_layer_normalization.py +26 -0
  54. onnxslim/core/shape_inference/contrib_ops/normalization/skip_simplified_layer_normalization.py +23 -0
  55. onnxslim/core/shape_inference/registry.py +90 -0
  56. onnxslim/core/shape_inference/standard_ops/__init__.py +11 -0
  57. onnxslim/core/shape_inference/standard_ops/control_flow/__init__.py +8 -0
  58. onnxslim/core/shape_inference/standard_ops/control_flow/if_op.py +43 -0
  59. onnxslim/core/shape_inference/standard_ops/control_flow/loop.py +74 -0
  60. onnxslim/core/shape_inference/standard_ops/control_flow/scan.py +54 -0
  61. onnxslim/core/shape_inference/standard_ops/math/__init__.py +20 -0
  62. onnxslim/core/shape_inference/standard_ops/math/_symbolic_compute.py +34 -0
  63. onnxslim/core/shape_inference/standard_ops/math/add.py +10 -0
  64. onnxslim/core/shape_inference/standard_ops/math/div.py +10 -0
  65. onnxslim/core/shape_inference/standard_ops/math/einsum.py +119 -0
  66. onnxslim/core/shape_inference/standard_ops/math/equal.py +10 -0
  67. onnxslim/core/shape_inference/standard_ops/math/floor.py +10 -0
  68. onnxslim/core/shape_inference/standard_ops/math/matmul.py +21 -0
  69. onnxslim/core/shape_inference/standard_ops/math/matmul_integer.py +23 -0
  70. onnxslim/core/shape_inference/standard_ops/math/max.py +10 -0
  71. onnxslim/core/shape_inference/standard_ops/math/min.py +10 -0
  72. onnxslim/core/shape_inference/standard_ops/math/mul.py +10 -0
  73. onnxslim/core/shape_inference/standard_ops/math/neg.py +10 -0
  74. onnxslim/core/shape_inference/standard_ops/math/reduce_prod.py +27 -0
  75. onnxslim/core/shape_inference/standard_ops/math/reduce_sum.py +53 -0
  76. onnxslim/core/shape_inference/standard_ops/math/sub.py +10 -0
  77. onnxslim/core/shape_inference/standard_ops/math/where.py +10 -0
  78. onnxslim/core/shape_inference/standard_ops/misc/__init__.py +22 -0
  79. onnxslim/core/shape_inference/standard_ops/misc/array_feature_extractor.py +32 -0
  80. onnxslim/core/shape_inference/standard_ops/misc/cast.py +21 -0
  81. onnxslim/core/shape_inference/standard_ops/misc/category_mapper.py +30 -0
  82. onnxslim/core/shape_inference/standard_ops/misc/compress.py +39 -0
  83. onnxslim/core/shape_inference/standard_ops/misc/constant.py +27 -0
  84. onnxslim/core/shape_inference/standard_ops/misc/constant_of_shape.py +45 -0
  85. onnxslim/core/shape_inference/standard_ops/misc/dequantize_linear.py +26 -0
  86. onnxslim/core/shape_inference/standard_ops/misc/non_max_suppression.py +26 -0
  87. onnxslim/core/shape_inference/standard_ops/misc/non_zero.py +26 -0
  88. onnxslim/core/shape_inference/standard_ops/misc/one_hot.py +42 -0
  89. onnxslim/core/shape_inference/standard_ops/misc/quantize_linear.py +29 -0
  90. onnxslim/core/shape_inference/standard_ops/misc/range.py +41 -0
  91. onnxslim/core/shape_inference/standard_ops/misc/relative_position_bias.py +31 -0
  92. onnxslim/core/shape_inference/standard_ops/misc/resize.py +74 -0
  93. onnxslim/core/shape_inference/standard_ops/misc/scatter_elements.py +31 -0
  94. onnxslim/core/shape_inference/standard_ops/misc/softmax_cross_entropy_loss.py +44 -0
  95. onnxslim/core/shape_inference/standard_ops/misc/top_k.py +44 -0
  96. onnxslim/core/shape_inference/standard_ops/nn/__init__.py +18 -0
  97. onnxslim/core/shape_inference/standard_ops/nn/all_reduce.py +9 -0
  98. onnxslim/core/shape_inference/standard_ops/nn/average_pool.py +40 -0
  99. onnxslim/core/shape_inference/standard_ops/nn/batch_normalization.py +26 -0
  100. onnxslim/core/shape_inference/standard_ops/nn/conv.py +33 -0
  101. onnxslim/core/shape_inference/standard_ops/nn/cum_sum.py +9 -0
  102. onnxslim/core/shape_inference/standard_ops/nn/identity.py +9 -0
  103. onnxslim/core/shape_inference/standard_ops/nn/max_pool.py +9 -0
  104. onnxslim/core/shape_inference/standard_ops/nn/memcpy_from_host.py +9 -0
  105. onnxslim/core/shape_inference/standard_ops/nn/memcpy_to_host.py +9 -0
  106. onnxslim/core/shape_inference/standard_ops/nn/moe.py +9 -0
  107. onnxslim/core/shape_inference/standard_ops/nn/nhwc_conv.py +33 -0
  108. onnxslim/core/shape_inference/standard_ops/nn/reciprocal.py +9 -0
  109. onnxslim/core/shape_inference/standard_ops/nn/round.py +9 -0
  110. onnxslim/core/shape_inference/standard_ops/sequence/__init__.py +10 -0
  111. onnxslim/core/shape_inference/standard_ops/sequence/concat_from_sequence.py +40 -0
  112. onnxslim/core/shape_inference/standard_ops/sequence/sequence_at.py +31 -0
  113. onnxslim/core/shape_inference/standard_ops/sequence/sequence_insert.py +26 -0
  114. onnxslim/core/shape_inference/standard_ops/sequence/split_to_sequence.py +24 -0
  115. onnxslim/core/shape_inference/standard_ops/sequence/zip_map.py +36 -0
  116. onnxslim/core/shape_inference/standard_ops/tensor/__init__.py +20 -0
  117. onnxslim/core/shape_inference/standard_ops/tensor/concat.py +62 -0
  118. onnxslim/core/shape_inference/standard_ops/tensor/expand.py +36 -0
  119. onnxslim/core/shape_inference/standard_ops/tensor/gather.py +48 -0
  120. onnxslim/core/shape_inference/standard_ops/tensor/gather_elements.py +31 -0
  121. onnxslim/core/shape_inference/standard_ops/tensor/gather_nd.py +42 -0
  122. onnxslim/core/shape_inference/standard_ops/tensor/pad.py +41 -0
  123. onnxslim/core/shape_inference/standard_ops/tensor/reshape.py +72 -0
  124. onnxslim/core/shape_inference/standard_ops/tensor/shape.py +38 -0
  125. onnxslim/core/shape_inference/standard_ops/tensor/size.py +29 -0
  126. onnxslim/core/shape_inference/standard_ops/tensor/slice.py +183 -0
  127. onnxslim/core/shape_inference/standard_ops/tensor/split.py +57 -0
  128. onnxslim/core/shape_inference/standard_ops/tensor/squeeze.py +69 -0
  129. onnxslim/core/shape_inference/standard_ops/tensor/tile.py +41 -0
  130. onnxslim/core/shape_inference/standard_ops/tensor/transpose.py +30 -0
  131. onnxslim/core/shape_inference/standard_ops/tensor/unsqueeze.py +54 -0
  132. onnxslim/core/shape_inference/utils.py +244 -0
  133. onnxslim/third_party/onnx_graphsurgeon/ir/graph.py +0 -103
  134. onnxslim/third_party/symbolic_shape_infer.py +73 -3156
  135. onnxslim/utils.py +4 -2
  136. {onnxslim-0.1.82.dist-info → onnxslim-0.1.84.dist-info}/METADATA +21 -11
  137. onnxslim-0.1.84.dist-info/RECORD +187 -0
  138. onnxslim-0.1.82.dist-info/RECORD +0 -63
  139. {onnxslim-0.1.82.dist-info → onnxslim-0.1.84.dist-info}/WHEEL +0 -0
  140. {onnxslim-0.1.82.dist-info → onnxslim-0.1.84.dist-info}/entry_points.txt +0 -0
  141. {onnxslim-0.1.82.dist-info → onnxslim-0.1.84.dist-info}/licenses/LICENSE +0 -0
@@ -53,10 +53,17 @@ def dead_node_elimination(graph, is_subgraph=False):
53
53
  node.inputs.pop(1)
54
54
  node.inputs.insert(1, reshape_const)
55
55
  logger.debug(f"replacing {node.op} op: {node.name}")
56
- # elif node.op == "Slice":
57
- # if node.inputs[0].shape and node.outputs[0].shape and node.inputs[0].shape == node.outputs[0].shape and all(isinstance(item, int) for item in node.inputs[0].shape):
58
- # node.erase()
59
- # logger.debug(f"removing {node.op} op: {node.name}")
56
+ elif node.op == "Slice":
57
+ if (node.inputs[0].shape and node.outputs[0].shape
58
+ and node.inputs[0].shape == node.outputs[0].shape
59
+ and all(isinstance(item, int) for item in node.inputs[0].shape)):
60
+
61
+ # Check if slice is a no-op by analyzing parameters directly
62
+ # Slice inputs: data, starts, ends, [axes], [steps]
63
+ if is_noop_slice(node):
64
+ node.erase()
65
+ logger.debug(f"removing {node.op} op: {node.name}")
66
+
60
67
  elif node.op == "Mul":
61
68
  if (isinstance(node.inputs[1], Constant) and isinstance(node.inputs[0], Variable)) or (
62
69
  isinstance(node.inputs[0], Constant) and isinstance(node.inputs[1], Variable)
@@ -153,3 +160,77 @@ def get_constant_variable(node, return_idx=False):
153
160
  for idx, input in enumerate(list(node.inputs)):
154
161
  if isinstance(input, Constant):
155
162
  return (idx, input) if return_idx else input
163
+
164
+
165
+ def is_noop_slice(node):
166
+ """Check if a Slice node is a no-op by analyzing its parameters directly.
167
+
168
+ A Slice is a no-op when it extracts the entire tensor, i.e., for each sliced axis:
169
+ - start == 0 (or equivalent negative index)
170
+ - end >= dim_size (or is INT_MAX-like value)
171
+ - step == 1
172
+ """
173
+ # Slice inputs: data, starts, ends, [axes], [steps]
174
+ if len(node.inputs) < 3:
175
+ return False
176
+
177
+ data_shape = node.inputs[0].shape
178
+ if not data_shape or not all(isinstance(d, int) for d in data_shape):
179
+ return False
180
+
181
+ # Get starts and ends (required)
182
+ starts_input = node.inputs[1]
183
+ ends_input = node.inputs[2]
184
+
185
+ if not isinstance(starts_input, Constant) or not isinstance(ends_input, Constant):
186
+ return False
187
+
188
+ starts = starts_input.values.flatten().tolist()
189
+ ends = ends_input.values.flatten().tolist()
190
+
191
+ # Get axes (optional, defaults to [0, 1, 2, ...])
192
+ if len(node.inputs) > 3 and isinstance(node.inputs[3], Constant):
193
+ axes = node.inputs[3].values.flatten().tolist()
194
+ else:
195
+ axes = list(range(len(starts)))
196
+
197
+ # Get steps (optional, defaults to [1, 1, 1, ...])
198
+ if len(node.inputs) > 4 and isinstance(node.inputs[4], Constant):
199
+ steps = node.inputs[4].values.flatten().tolist()
200
+ else:
201
+ steps = [1] * len(starts)
202
+
203
+ # Check each axis
204
+ ndim = len(data_shape)
205
+ for start, end, axis, step in zip(starts, ends, axes, steps):
206
+ # Normalize negative axis
207
+ if axis < 0:
208
+ axis = ndim + axis
209
+
210
+ if axis < 0 or axis >= ndim:
211
+ return False
212
+
213
+ dim_size = data_shape[axis]
214
+
215
+ # Step must be 1 for no-op
216
+ if step != 1:
217
+ return False
218
+
219
+ # Normalize negative start index
220
+ if start < 0:
221
+ start = max(0, dim_size + start)
222
+
223
+ # Start must be 0
224
+ if start != 0:
225
+ return False
226
+
227
+ # Normalize negative end index
228
+ if end < 0:
229
+ end = dim_size + end
230
+
231
+ # End must cover the entire dimension
232
+ # Common patterns: end == dim_size, or end is a large value like INT_MAX
233
+ if end < dim_size:
234
+ return False
235
+
236
+ return True
@@ -39,6 +39,16 @@ class SlicePatternMatcher(PatternMatcher):
39
39
  first_slice_node_axes = first_slice_node_inputs[3].values.tolist()
40
40
  first_slice_node_steps = first_slice_node_inputs[4].values.tolist()
41
41
 
42
+ # Check all users upfront before modifying the graph.
43
+ # If any user has overlapping axes, skip the optimization entirely
44
+ # to avoid corrupting the graph (fixes GitHub issue #277).
45
+ for user_node in first_slice_node_users:
46
+ second_slice_node_inputs = list(user_node.inputs)
47
+ second_slice_node_axes = second_slice_node_inputs[3].values.tolist()
48
+ new_axes = first_slice_node_axes + second_slice_node_axes
49
+ if len(new_axes) != len(set(new_axes)):
50
+ return match_case
51
+
42
52
  for user_node in first_slice_node_users:
43
53
  second_slice_node = user_node
44
54
  second_slice_node_inputs = list(second_slice_node.inputs)
@@ -52,33 +62,30 @@ class SlicePatternMatcher(PatternMatcher):
52
62
  new_axes = first_slice_node_axes + second_slice_node_axes
53
63
  new_steps = first_slice_node_steps + second_slice_node_steps
54
64
 
55
- if len(new_axes) != len(set(new_axes)):
56
- continue
57
-
58
65
  inputs = []
66
+ output_name = second_slice_node.outputs[0].name
59
67
  inputs.extend(
60
68
  (
61
69
  next(iter(first_slice_node.inputs)),
62
70
  gs.Constant(
63
- second_slice_node_inputs[1].name + "_starts",
71
+ output_name + "_starts",
64
72
  values=np.array(new_starts, dtype=np.int64),
65
73
  ),
66
74
  gs.Constant(
67
- second_slice_node_inputs[2].name + "_ends",
75
+ output_name + "_ends",
68
76
  values=np.array(new_ends, dtype=np.int64),
69
77
  ),
70
78
  gs.Constant(
71
- second_slice_node_inputs[3].name + "_axes",
79
+ output_name + "_axes",
72
80
  values=np.array(new_axes, dtype=np.int64),
73
81
  ),
74
82
  gs.Constant(
75
- second_slice_node_inputs[4].name + "_steps",
83
+ output_name + "_steps",
76
84
  values=np.array(new_steps, dtype=np.int64),
77
85
  ),
78
86
  )
79
87
  )
80
88
  outputs = list(second_slice_node.outputs)
81
-
82
89
  first_slice_node.outputs.clear()
83
90
  second_slice_node.inputs.clear()
84
91
  second_slice_node.outputs.clear()
@@ -36,9 +36,11 @@ class ConcatReshapeMatcher(PatternMatcher):
36
36
  def rewrite(self, opset=11):
37
37
  match_case = {}
38
38
  concat_node = self.concat_0
39
+ reshape_node = self.reshape_0
39
40
  index = next(idx for idx, i in enumerate(concat_node.inputs) if isinstance(i, gs.Variable))
41
+ output_name = reshape_node.outputs[0].name
40
42
  constant = gs.Constant(
41
- concat_node.inputs[index].name + "_fixed",
43
+ output_name + "_fixed",
42
44
  values=np.array([-1], dtype=np.int64),
43
45
  )
44
46
  concat_node.inputs.pop(index)
@@ -27,12 +27,13 @@ class ConvAddMatcher(PatternMatcher):
27
27
  conv_weight = list(conv_node.inputs)[1]
28
28
  conv_node_users = conv_node.users
29
29
  node = self.add_0
30
+ oc_axis = 0 if conv_node.op == "Conv" else 1 # output_channel_axis
30
31
  if (
31
32
  len(conv_node_users) == 1
32
33
  and isinstance(node.inputs[1], gs.Constant)
33
34
  and isinstance(conv_weight, gs.Constant)
34
35
  and node.inputs[1].values.squeeze().ndim == 1
35
- and node.inputs[1].values.squeeze().shape[0] == conv_weight.shape[0]
36
+ and node.inputs[1].values.squeeze().shape[0] == conv_weight.shape[oc_axis]
36
37
  ):
37
38
  add_node = node
38
39
  if len(conv_node.inputs) == 2:
@@ -43,12 +44,8 @@ class ConvAddMatcher(PatternMatcher):
43
44
  inputs = []
44
45
  inputs.append(next(iter(conv_node.inputs)))
45
46
  inputs.append(conv_weight)
46
- weight_name = list(conv_node.inputs)[1].name
47
- if weight_name.endswith("weight"):
48
- bias_name = f"{weight_name[:-6]}bias"
49
- else:
50
- bias_name = f"{weight_name}_bias"
51
- inputs.append(gs.Constant(bias_name, values=conv_bias))
47
+ output_name = add_node.outputs[0].name
48
+ inputs.append(gs.Constant(output_name + "_bias", values=conv_bias))
52
49
  outputs = list(add_node.outputs)
53
50
 
54
51
  conv_node.outputs.clear()
@@ -66,5 +63,24 @@ class ConvAddMatcher(PatternMatcher):
66
63
 
67
64
  return match_case
68
65
 
66
+ class ConvTransposeAddMatcher(ConvAddMatcher):
67
+ def __init__(self, priority):
68
+ """Initializes the ConvTransposeAddMatcher for fusing ConvTranspose and Add layers in an ONNX graph."""
69
+ pattern = Pattern(
70
+ """
71
+ input input 0 1 conv_0
72
+ ConvTranspose conv_0 1+ 1 input bn_0
73
+ Add add_0 2 1 conv_0 ? output
74
+ output output 1 0 add_0
75
+ """
76
+ )
77
+ super(ConvAddMatcher, self).__init__(pattern, priority)
78
+
79
+ @property
80
+ def name(self):
81
+ """Returns the name of the FusionConvTransposeAdd pattern."""
82
+ return "FusionConvTransposeAdd"
83
+
69
84
 
70
85
  register_fusion_pattern(ConvAddMatcher(1))
86
+ register_fusion_pattern(ConvTransposeAddMatcher(1))
@@ -44,25 +44,19 @@ class ConvBatchNormMatcher(PatternMatcher):
44
44
  conv_transpose_bias = conv_transpose_node.inputs[2].values
45
45
 
46
46
  bn_var_rsqrt = bn_scale / np.sqrt(bn_running_var + bn_eps)
47
+ oc_axis = 0 if conv_transpose_node.op == "Conv" else 1 # output_channel_axis
47
48
  shape = [1] * len(conv_transpose_weight.shape)
48
- if bn_node.i(0).op == "Conv":
49
- shape[0] = -1
50
- else:
51
- shape[1] = -1
49
+ shape[oc_axis] = -1
52
50
  conv_w = conv_transpose_weight * bn_var_rsqrt.reshape(shape)
53
51
  conv_b = (conv_transpose_bias - bn_running_mean) * bn_var_rsqrt + bn_bias
54
52
 
55
53
  inputs = []
56
54
  inputs.append(next(iter(conv_transpose_node.inputs)))
57
- weight_name = list(conv_transpose_node.inputs)[1].name
58
- if weight_name.endswith("weight"):
59
- bias_name = f"{weight_name[:-6]}bias"
60
- else:
61
- bias_name = f"{weight_name}_bias"
55
+ output_name = bn_node.outputs[0].name
62
56
  inputs.extend(
63
57
  (
64
- gs.Constant(weight_name + "_weight", values=conv_w),
65
- gs.Constant(bias_name, values=conv_b),
58
+ gs.Constant(output_name + "_weight", values=conv_w),
59
+ gs.Constant(output_name + "_bias", values=conv_b),
66
60
  )
67
61
  )
68
62
  outputs = list(bn_node.outputs)
@@ -82,5 +76,24 @@ class ConvBatchNormMatcher(PatternMatcher):
82
76
 
83
77
  return match_case
84
78
 
79
+ class ConvTransposeBatchNormMatcher(ConvBatchNormMatcher):
80
+ def __init__(self, priority):
81
+ """Initializes the ConvTransposeBatchNormMatcher for fusing ConvTranspose and BatchNormalization layers in an ONNX graph."""
82
+ pattern = Pattern(
83
+ """
84
+ input input 0 1 conv_0
85
+ ConvTranspose conv_0 1+ 1 input bn_0
86
+ BatchNormalization bn_0 5 1 conv_0 ? ? ? ? output
87
+ output output 1 0 bn_0
88
+ """
89
+ )
90
+ super(ConvBatchNormMatcher, self).__init__(pattern, priority)
91
+
92
+ @property
93
+ def name(self):
94
+ """Returns the name of the FusionConvTransposeBN pattern."""
95
+ return "FusionConvTransposeBN"
96
+
85
97
 
86
98
  register_fusion_pattern(ConvBatchNormMatcher(1))
99
+ register_fusion_pattern(ConvTransposeBatchNormMatcher(1))
@@ -28,25 +28,23 @@ class ConvMulMatcher(PatternMatcher):
28
28
  conv_weight = list(conv_node.inputs)[1]
29
29
  if len(conv_node.users) == 1 and conv_node.users[0] == mul_node and isinstance(mul_node.inputs[1], gs.Constant):
30
30
  mul_constant = mul_node.inputs[1].values
31
-
32
- if mul_constant.squeeze().ndim == 1 and mul_constant.squeeze().shape[0] == conv_weight.shape[0]:
33
- weight_shape = conv_weight.values.shape
34
- reshape_shape = [-1] + [1] * (len(weight_shape) - 1)
35
-
31
+ oc_axis = 0 if conv_node.op == "Conv" else 1 # output_channel_axis
32
+ if mul_constant.squeeze().ndim == 1 and mul_constant.squeeze().shape[0] == conv_weight.shape[oc_axis]:
33
+ reshape_shape = [1] * len(conv_weight.values.shape)
34
+ reshape_shape[oc_axis] = -1
36
35
  mul_scale_reshaped = mul_constant.squeeze().reshape(reshape_shape)
37
36
  new_weight = conv_weight.values * mul_scale_reshaped
38
37
 
39
38
  inputs = []
40
39
  inputs.append(next(iter(conv_node.inputs)))
41
40
 
42
- weight_name = list(conv_node.inputs)[1].name
43
- inputs.append(gs.Constant(weight_name, values=new_weight))
41
+ output_name = mul_node.outputs[0].name
42
+ inputs.append(gs.Constant(output_name + "_weight", values=new_weight))
44
43
 
45
44
  if len(conv_node.inputs) == 3:
46
45
  conv_bias = conv_node.inputs[2].values
47
46
  new_bias = conv_bias * mul_constant.squeeze()
48
- bias_name = list(conv_node.inputs)[2].name
49
- inputs.append(gs.Constant(bias_name, values=new_bias))
47
+ inputs.append(gs.Constant(output_name + "_bias", values=new_bias))
50
48
 
51
49
  outputs = list(mul_node.outputs)
52
50
 
@@ -65,5 +63,24 @@ class ConvMulMatcher(PatternMatcher):
65
63
 
66
64
  return match_case
67
65
 
66
+ class ConvTransposeMulMatcher(ConvMulMatcher):
67
+ def __init__(self, priority):
68
+ """Initializes the ConvTransposeMulMatcher for fusing ConvTranspose and Mul layers in an ONNX graph."""
69
+ pattern = Pattern(
70
+ """
71
+ input input 0 1 conv_0
72
+ ConvTranspose conv_0 1+ 1 input mul_0
73
+ Mul mul_0 2 1 conv_0 ? output
74
+ output output 1 0 mul_0
75
+ """
76
+ )
77
+ super(ConvMulMatcher, self).__init__(pattern, priority)
78
+
79
+ @property
80
+ def name(self):
81
+ """Returns the name of the FusionConvTransposeMul pattern."""
82
+ return "FusionConvTransposeMul"
83
+
68
84
 
69
85
  register_fusion_pattern(ConvMulMatcher(1))
86
+ register_fusion_pattern(ConvTransposeMulMatcher(1))
@@ -76,7 +76,7 @@ class MatMulAddPatternMatcher(PatternMatcher):
76
76
  output_variable.outputs.remove(add_node)
77
77
 
78
78
  matmul_bias_transpose_constant = gs.Constant(
79
- matmul_bias_variable.name, values=matmul_bias_variable.values.T
79
+ f"{matmul_node.name}_weight", values=matmul_bias_variable.values.T
80
80
  )
81
81
 
82
82
  inputs = []
@@ -143,7 +143,7 @@ class MatMulAddPatternMatcher(PatternMatcher):
143
143
  output_variable.outputs.remove(add_node)
144
144
 
145
145
  matmul_bias_transpose_constant = gs.Constant(
146
- matmul_bias_variable.name, values=matmul_bias_variable.values.T
146
+ f"{matmul_node.name}_weight", values=matmul_bias_variable.values.T
147
147
  )
148
148
 
149
149
  inputs = []
@@ -235,14 +235,15 @@ class GemmMulPatternMatcher(PatternMatcher):
235
235
  gemm_weight_fused = gemm_weight * mul_weight[:, None]
236
236
  else:
237
237
  gemm_weight_fused = gemm_weight * mul_weight
238
- gemm_weight_fused_constant = gs.Constant(gemm_weight_constant.name + "_fused", values=gemm_weight_fused)
238
+ output_name = reshape_node.outputs[0].name
239
+ gemm_weight_fused_constant = gs.Constant(output_name + "_weight_fused", values=gemm_weight_fused)
239
240
  gemm_node.inputs[1] = gemm_weight_fused_constant
240
241
 
241
242
  if gemm_bias_constant:
242
243
  gemm_bias = gemm_bias_constant.values
243
244
  mul_bias = mul_bias_variable.values
244
245
  gemm_bias_fused = gemm_bias * mul_bias
245
- gemm_bias_fused_constant = gs.Constant(gemm_bias_constant.name + "_fused", values=gemm_bias_fused)
246
+ gemm_bias_fused_constant = gs.Constant(output_name + "_bias_fused", values=gemm_bias_fused)
246
247
  gemm_node.inputs[2] = gemm_bias_fused_constant
247
248
 
248
249
  mul_node.replace_all_uses_with(reshape_node)
@@ -312,7 +313,8 @@ class GemmAddPatternMatcher(PatternMatcher):
312
313
  and add_bias.ndim <= 2
313
314
  ):
314
315
  gemm_bias_fused = gemm_bias + add_bias
315
- gemm_bias_fused_constant = gs.Constant(gemm_bias_constant.name + "_fused", values=gemm_bias_fused)
316
+ output_name = reshape_node.outputs[0].name
317
+ gemm_bias_fused_constant = gs.Constant(output_name + "_bias_fused", values=gemm_bias_fused)
316
318
  gemm_node.inputs[2] = gemm_bias_fused_constant
317
319
  else:
318
320
  return match_case
@@ -37,6 +37,8 @@ class PadConvMatcher(PatternMatcher):
37
37
  pad_node_users = pad_node.users
38
38
 
39
39
  pad_inputs = len(pad_node.inputs)
40
+ auto_pad = pad_node.attrs.get("auto_pad", "NOTSET")
41
+
40
42
  if pad_inputs < 3 or (
41
43
  (pad_inputs >= 3 and (isinstance(pad_node.inputs[2], gs.Constant) and pad_node.inputs[2].values == 0))
42
44
  or (pad_inputs >= 3 and (isinstance(pad_node.inputs[2], gs.Variable) and pad_node.inputs[2].name == ""))
@@ -45,6 +47,7 @@ class PadConvMatcher(PatternMatcher):
45
47
  isinstance(pad_node.inputs[1], gs.Constant)
46
48
  and pad_node.attrs.get("mode", "constant") == "constant"
47
49
  and conv_node.inputs[1].shape
50
+ and (auto_pad == "NOTSET" or auto_pad == "VALID")
48
51
  ):
49
52
  conv_weight_dim = len(conv_node.inputs[1].shape)
50
53
  pad_value = pad_node.inputs[1].values.tolist()
@@ -74,6 +77,8 @@ class PadConvMatcher(PatternMatcher):
74
77
  pads = [pad + conv_pad for pad, conv_pad in zip(pads, conv_pads)]
75
78
 
76
79
  attrs["pads"] = pads
80
+ conv_node.attrs.pop("auto_pad", None)
81
+
77
82
  match_case[conv_node.name] = {
78
83
  "op": "Conv",
79
84
  "inputs": inputs,