onnxslim 0.1.82__py3-none-any.whl → 0.1.84__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- onnxslim/core/optimization/dead_node_elimination.py +85 -4
- onnxslim/core/pattern/elimination/slice.py +15 -8
- onnxslim/core/pattern/fusion/concat_reshape.py +3 -1
- onnxslim/core/pattern/fusion/convadd.py +23 -7
- onnxslim/core/pattern/fusion/convbn.py +24 -11
- onnxslim/core/pattern/fusion/convmul.py +26 -9
- onnxslim/core/pattern/fusion/gemm.py +7 -5
- onnxslim/core/pattern/fusion/padconv.py +5 -0
- onnxslim/core/shape_inference/__init__.py +378 -0
- onnxslim/core/shape_inference/aten_ops/__init__.py +16 -0
- onnxslim/core/shape_inference/aten_ops/argmax.py +47 -0
- onnxslim/core/shape_inference/aten_ops/bitwise_or.py +28 -0
- onnxslim/core/shape_inference/aten_ops/diagonal.py +52 -0
- onnxslim/core/shape_inference/aten_ops/embedding.py +23 -0
- onnxslim/core/shape_inference/aten_ops/group_norm.py +41 -0
- onnxslim/core/shape_inference/aten_ops/min_max.py +64 -0
- onnxslim/core/shape_inference/aten_ops/multinomial.py +39 -0
- onnxslim/core/shape_inference/aten_ops/numpy_t.py +22 -0
- onnxslim/core/shape_inference/aten_ops/pool2d.py +40 -0
- onnxslim/core/shape_inference/aten_ops/unfold.py +44 -0
- onnxslim/core/shape_inference/aten_ops/upsample.py +44 -0
- onnxslim/core/shape_inference/base.py +111 -0
- onnxslim/core/shape_inference/context.py +645 -0
- onnxslim/core/shape_inference/contrib_ops/__init__.py +8 -0
- onnxslim/core/shape_inference/contrib_ops/attention/__init__.py +15 -0
- onnxslim/core/shape_inference/contrib_ops/attention/attention.py +61 -0
- onnxslim/core/shape_inference/contrib_ops/attention/decoder_masked_mha.py +37 -0
- onnxslim/core/shape_inference/contrib_ops/attention/gated_relative_position_bias.py +35 -0
- onnxslim/core/shape_inference/contrib_ops/attention/longformer_attention.py +21 -0
- onnxslim/core/shape_inference/contrib_ops/attention/multi_head_attention.py +82 -0
- onnxslim/core/shape_inference/contrib_ops/attention/multi_scale_deformable_attn.py +29 -0
- onnxslim/core/shape_inference/contrib_ops/attention/packed_attention.py +39 -0
- onnxslim/core/shape_inference/contrib_ops/attention/packed_multi_head_attention.py +33 -0
- onnxslim/core/shape_inference/contrib_ops/attention/remove_padding.py +41 -0
- onnxslim/core/shape_inference/contrib_ops/attention/restore_padding.py +29 -0
- onnxslim/core/shape_inference/contrib_ops/misc/__init__.py +15 -0
- onnxslim/core/shape_inference/contrib_ops/misc/bias_add.py +21 -0
- onnxslim/core/shape_inference/contrib_ops/misc/bias_gelu.py +21 -0
- onnxslim/core/shape_inference/contrib_ops/misc/bias_split_gelu.py +30 -0
- onnxslim/core/shape_inference/contrib_ops/misc/fast_gelu.py +21 -0
- onnxslim/core/shape_inference/contrib_ops/misc/gelu.py +21 -0
- onnxslim/core/shape_inference/contrib_ops/misc/gemm_fast_gelu.py +21 -0
- onnxslim/core/shape_inference/contrib_ops/misc/gemm_float8.py +21 -0
- onnxslim/core/shape_inference/contrib_ops/misc/python_op.py +67 -0
- onnxslim/core/shape_inference/contrib_ops/misc/quick_gelu.py +21 -0
- onnxslim/core/shape_inference/contrib_ops/misc/rotary_embedding.py +31 -0
- onnxslim/core/shape_inference/contrib_ops/normalization/__init__.py +12 -0
- onnxslim/core/shape_inference/contrib_ops/normalization/embed_layer_normalization.py +41 -0
- onnxslim/core/shape_inference/contrib_ops/normalization/group_norm.py +21 -0
- onnxslim/core/shape_inference/contrib_ops/normalization/layer_normalization.py +42 -0
- onnxslim/core/shape_inference/contrib_ops/normalization/simplified_layer_normalization.py +23 -0
- onnxslim/core/shape_inference/contrib_ops/normalization/skip_group_norm.py +23 -0
- onnxslim/core/shape_inference/contrib_ops/normalization/skip_layer_normalization.py +26 -0
- onnxslim/core/shape_inference/contrib_ops/normalization/skip_simplified_layer_normalization.py +23 -0
- onnxslim/core/shape_inference/registry.py +90 -0
- onnxslim/core/shape_inference/standard_ops/__init__.py +11 -0
- onnxslim/core/shape_inference/standard_ops/control_flow/__init__.py +8 -0
- onnxslim/core/shape_inference/standard_ops/control_flow/if_op.py +43 -0
- onnxslim/core/shape_inference/standard_ops/control_flow/loop.py +74 -0
- onnxslim/core/shape_inference/standard_ops/control_flow/scan.py +54 -0
- onnxslim/core/shape_inference/standard_ops/math/__init__.py +20 -0
- onnxslim/core/shape_inference/standard_ops/math/_symbolic_compute.py +34 -0
- onnxslim/core/shape_inference/standard_ops/math/add.py +10 -0
- onnxslim/core/shape_inference/standard_ops/math/div.py +10 -0
- onnxslim/core/shape_inference/standard_ops/math/einsum.py +119 -0
- onnxslim/core/shape_inference/standard_ops/math/equal.py +10 -0
- onnxslim/core/shape_inference/standard_ops/math/floor.py +10 -0
- onnxslim/core/shape_inference/standard_ops/math/matmul.py +21 -0
- onnxslim/core/shape_inference/standard_ops/math/matmul_integer.py +23 -0
- onnxslim/core/shape_inference/standard_ops/math/max.py +10 -0
- onnxslim/core/shape_inference/standard_ops/math/min.py +10 -0
- onnxslim/core/shape_inference/standard_ops/math/mul.py +10 -0
- onnxslim/core/shape_inference/standard_ops/math/neg.py +10 -0
- onnxslim/core/shape_inference/standard_ops/math/reduce_prod.py +27 -0
- onnxslim/core/shape_inference/standard_ops/math/reduce_sum.py +53 -0
- onnxslim/core/shape_inference/standard_ops/math/sub.py +10 -0
- onnxslim/core/shape_inference/standard_ops/math/where.py +10 -0
- onnxslim/core/shape_inference/standard_ops/misc/__init__.py +22 -0
- onnxslim/core/shape_inference/standard_ops/misc/array_feature_extractor.py +32 -0
- onnxslim/core/shape_inference/standard_ops/misc/cast.py +21 -0
- onnxslim/core/shape_inference/standard_ops/misc/category_mapper.py +30 -0
- onnxslim/core/shape_inference/standard_ops/misc/compress.py +39 -0
- onnxslim/core/shape_inference/standard_ops/misc/constant.py +27 -0
- onnxslim/core/shape_inference/standard_ops/misc/constant_of_shape.py +45 -0
- onnxslim/core/shape_inference/standard_ops/misc/dequantize_linear.py +26 -0
- onnxslim/core/shape_inference/standard_ops/misc/non_max_suppression.py +26 -0
- onnxslim/core/shape_inference/standard_ops/misc/non_zero.py +26 -0
- onnxslim/core/shape_inference/standard_ops/misc/one_hot.py +42 -0
- onnxslim/core/shape_inference/standard_ops/misc/quantize_linear.py +29 -0
- onnxslim/core/shape_inference/standard_ops/misc/range.py +41 -0
- onnxslim/core/shape_inference/standard_ops/misc/relative_position_bias.py +31 -0
- onnxslim/core/shape_inference/standard_ops/misc/resize.py +74 -0
- onnxslim/core/shape_inference/standard_ops/misc/scatter_elements.py +31 -0
- onnxslim/core/shape_inference/standard_ops/misc/softmax_cross_entropy_loss.py +44 -0
- onnxslim/core/shape_inference/standard_ops/misc/top_k.py +44 -0
- onnxslim/core/shape_inference/standard_ops/nn/__init__.py +18 -0
- onnxslim/core/shape_inference/standard_ops/nn/all_reduce.py +9 -0
- onnxslim/core/shape_inference/standard_ops/nn/average_pool.py +40 -0
- onnxslim/core/shape_inference/standard_ops/nn/batch_normalization.py +26 -0
- onnxslim/core/shape_inference/standard_ops/nn/conv.py +33 -0
- onnxslim/core/shape_inference/standard_ops/nn/cum_sum.py +9 -0
- onnxslim/core/shape_inference/standard_ops/nn/identity.py +9 -0
- onnxslim/core/shape_inference/standard_ops/nn/max_pool.py +9 -0
- onnxslim/core/shape_inference/standard_ops/nn/memcpy_from_host.py +9 -0
- onnxslim/core/shape_inference/standard_ops/nn/memcpy_to_host.py +9 -0
- onnxslim/core/shape_inference/standard_ops/nn/moe.py +9 -0
- onnxslim/core/shape_inference/standard_ops/nn/nhwc_conv.py +33 -0
- onnxslim/core/shape_inference/standard_ops/nn/reciprocal.py +9 -0
- onnxslim/core/shape_inference/standard_ops/nn/round.py +9 -0
- onnxslim/core/shape_inference/standard_ops/sequence/__init__.py +10 -0
- onnxslim/core/shape_inference/standard_ops/sequence/concat_from_sequence.py +40 -0
- onnxslim/core/shape_inference/standard_ops/sequence/sequence_at.py +31 -0
- onnxslim/core/shape_inference/standard_ops/sequence/sequence_insert.py +26 -0
- onnxslim/core/shape_inference/standard_ops/sequence/split_to_sequence.py +24 -0
- onnxslim/core/shape_inference/standard_ops/sequence/zip_map.py +36 -0
- onnxslim/core/shape_inference/standard_ops/tensor/__init__.py +20 -0
- onnxslim/core/shape_inference/standard_ops/tensor/concat.py +62 -0
- onnxslim/core/shape_inference/standard_ops/tensor/expand.py +36 -0
- onnxslim/core/shape_inference/standard_ops/tensor/gather.py +48 -0
- onnxslim/core/shape_inference/standard_ops/tensor/gather_elements.py +31 -0
- onnxslim/core/shape_inference/standard_ops/tensor/gather_nd.py +42 -0
- onnxslim/core/shape_inference/standard_ops/tensor/pad.py +41 -0
- onnxslim/core/shape_inference/standard_ops/tensor/reshape.py +72 -0
- onnxslim/core/shape_inference/standard_ops/tensor/shape.py +38 -0
- onnxslim/core/shape_inference/standard_ops/tensor/size.py +29 -0
- onnxslim/core/shape_inference/standard_ops/tensor/slice.py +183 -0
- onnxslim/core/shape_inference/standard_ops/tensor/split.py +57 -0
- onnxslim/core/shape_inference/standard_ops/tensor/squeeze.py +69 -0
- onnxslim/core/shape_inference/standard_ops/tensor/tile.py +41 -0
- onnxslim/core/shape_inference/standard_ops/tensor/transpose.py +30 -0
- onnxslim/core/shape_inference/standard_ops/tensor/unsqueeze.py +54 -0
- onnxslim/core/shape_inference/utils.py +244 -0
- onnxslim/third_party/onnx_graphsurgeon/ir/graph.py +0 -103
- onnxslim/third_party/symbolic_shape_infer.py +73 -3156
- onnxslim/utils.py +4 -2
- {onnxslim-0.1.82.dist-info → onnxslim-0.1.84.dist-info}/METADATA +21 -11
- onnxslim-0.1.84.dist-info/RECORD +187 -0
- onnxslim-0.1.82.dist-info/RECORD +0 -63
- {onnxslim-0.1.82.dist-info → onnxslim-0.1.84.dist-info}/WHEEL +0 -0
- {onnxslim-0.1.82.dist-info → onnxslim-0.1.84.dist-info}/entry_points.txt +0 -0
- {onnxslim-0.1.82.dist-info → onnxslim-0.1.84.dist-info}/licenses/LICENSE +0 -0
|
@@ -53,10 +53,17 @@ def dead_node_elimination(graph, is_subgraph=False):
|
|
|
53
53
|
node.inputs.pop(1)
|
|
54
54
|
node.inputs.insert(1, reshape_const)
|
|
55
55
|
logger.debug(f"replacing {node.op} op: {node.name}")
|
|
56
|
-
|
|
57
|
-
|
|
58
|
-
|
|
59
|
-
|
|
56
|
+
elif node.op == "Slice":
|
|
57
|
+
if (node.inputs[0].shape and node.outputs[0].shape
|
|
58
|
+
and node.inputs[0].shape == node.outputs[0].shape
|
|
59
|
+
and all(isinstance(item, int) for item in node.inputs[0].shape)):
|
|
60
|
+
|
|
61
|
+
# Check if slice is a no-op by analyzing parameters directly
|
|
62
|
+
# Slice inputs: data, starts, ends, [axes], [steps]
|
|
63
|
+
if is_noop_slice(node):
|
|
64
|
+
node.erase()
|
|
65
|
+
logger.debug(f"removing {node.op} op: {node.name}")
|
|
66
|
+
|
|
60
67
|
elif node.op == "Mul":
|
|
61
68
|
if (isinstance(node.inputs[1], Constant) and isinstance(node.inputs[0], Variable)) or (
|
|
62
69
|
isinstance(node.inputs[0], Constant) and isinstance(node.inputs[1], Variable)
|
|
@@ -153,3 +160,77 @@ def get_constant_variable(node, return_idx=False):
|
|
|
153
160
|
for idx, input in enumerate(list(node.inputs)):
|
|
154
161
|
if isinstance(input, Constant):
|
|
155
162
|
return (idx, input) if return_idx else input
|
|
163
|
+
|
|
164
|
+
|
|
165
|
+
def is_noop_slice(node):
    """Check if a Slice node is a no-op by analyzing its parameters directly.

    A Slice is a no-op when it extracts the entire tensor, i.e., for each sliced axis:
    - start == 0 (or equivalent negative index)
    - end >= dim_size (or is INT_MAX-like value)
    - step == 1

    The check is conservative: whenever a parameter cannot be read statically
    (non-constant starts/ends/axes/steps, non-integer data shape, mismatched
    parameter lengths) the slice is reported as *not* a no-op.

    Args:
        node: Slice node whose ``inputs`` are ``data, starts, ends, [axes], [steps]``.

    Returns:
        bool: True only when every sliced axis is proven to be kept whole.
    """
    # Slice inputs: data, starts, ends, [axes], [steps]
    if len(node.inputs) < 3:
        return False

    data_shape = node.inputs[0].shape
    if not data_shape or not all(isinstance(d, int) for d in data_shape):
        return False

    # Get starts and ends (required)
    starts_input = node.inputs[1]
    ends_input = node.inputs[2]

    if not isinstance(starts_input, Constant) or not isinstance(ends_input, Constant):
        return False

    starts = starts_input.values.flatten().tolist()
    ends = ends_input.values.flatten().tolist()

    def optional_values(index, default):
        """Return the constant values of optional input ``index``, ``default`` if omitted, None if dynamic."""
        if len(node.inputs) <= index:
            return default
        tensor = node.inputs[index]
        if isinstance(tensor, Constant):
            return tensor.values.flatten().tolist()
        # An omitted optional input may still occupy its slot as a placeholder
        # tensor with an empty name; only that case falls back to the default.
        if getattr(tensor, "name", None) == "":
            return default
        # A real, non-constant tensor: its values are unknown here, so the
        # defaults must not be assumed.
        return None

    # Get axes (optional, defaults to [0, 1, 2, ...])
    axes = optional_values(3, list(range(len(starts))))
    # Get steps (optional, defaults to [1, 1, 1, ...])
    steps = optional_values(4, [1] * len(starts))
    if axes is None or steps is None:
        return False

    # zip() would silently drop trailing entries, hiding a malformed node.
    if not (len(starts) == len(ends) == len(axes) == len(steps)):
        return False

    # Check each axis
    ndim = len(data_shape)
    for start, end, axis, step in zip(starts, ends, axes, steps):
        # Normalize negative axis
        if axis < 0:
            axis = ndim + axis

        if axis < 0 or axis >= ndim:
            return False

        dim_size = data_shape[axis]

        # Step must be 1 for no-op
        if step != 1:
            return False

        # Normalize negative start index
        if start < 0:
            start = max(0, dim_size + start)

        # Start must be 0
        if start != 0:
            return False

        # Normalize negative end index
        if end < 0:
            end = dim_size + end

        # End must cover the entire dimension
        # Common patterns: end == dim_size, or end is a large value like INT_MAX
        if end < dim_size:
            return False

    return True
|
|
@@ -39,6 +39,16 @@ class SlicePatternMatcher(PatternMatcher):
|
|
|
39
39
|
first_slice_node_axes = first_slice_node_inputs[3].values.tolist()
|
|
40
40
|
first_slice_node_steps = first_slice_node_inputs[4].values.tolist()
|
|
41
41
|
|
|
42
|
+
# Check all users upfront before modifying the graph.
|
|
43
|
+
# If any user has overlapping axes, skip the optimization entirely
|
|
44
|
+
# to avoid corrupting the graph (fixes GitHub issue #277).
|
|
45
|
+
for user_node in first_slice_node_users:
|
|
46
|
+
second_slice_node_inputs = list(user_node.inputs)
|
|
47
|
+
second_slice_node_axes = second_slice_node_inputs[3].values.tolist()
|
|
48
|
+
new_axes = first_slice_node_axes + second_slice_node_axes
|
|
49
|
+
if len(new_axes) != len(set(new_axes)):
|
|
50
|
+
return match_case
|
|
51
|
+
|
|
42
52
|
for user_node in first_slice_node_users:
|
|
43
53
|
second_slice_node = user_node
|
|
44
54
|
second_slice_node_inputs = list(second_slice_node.inputs)
|
|
@@ -52,33 +62,30 @@ class SlicePatternMatcher(PatternMatcher):
|
|
|
52
62
|
new_axes = first_slice_node_axes + second_slice_node_axes
|
|
53
63
|
new_steps = first_slice_node_steps + second_slice_node_steps
|
|
54
64
|
|
|
55
|
-
if len(new_axes) != len(set(new_axes)):
|
|
56
|
-
continue
|
|
57
|
-
|
|
58
65
|
inputs = []
|
|
66
|
+
output_name = second_slice_node.outputs[0].name
|
|
59
67
|
inputs.extend(
|
|
60
68
|
(
|
|
61
69
|
next(iter(first_slice_node.inputs)),
|
|
62
70
|
gs.Constant(
|
|
63
|
-
|
|
71
|
+
output_name + "_starts",
|
|
64
72
|
values=np.array(new_starts, dtype=np.int64),
|
|
65
73
|
),
|
|
66
74
|
gs.Constant(
|
|
67
|
-
|
|
75
|
+
output_name + "_ends",
|
|
68
76
|
values=np.array(new_ends, dtype=np.int64),
|
|
69
77
|
),
|
|
70
78
|
gs.Constant(
|
|
71
|
-
|
|
79
|
+
output_name + "_axes",
|
|
72
80
|
values=np.array(new_axes, dtype=np.int64),
|
|
73
81
|
),
|
|
74
82
|
gs.Constant(
|
|
75
|
-
|
|
83
|
+
output_name + "_steps",
|
|
76
84
|
values=np.array(new_steps, dtype=np.int64),
|
|
77
85
|
),
|
|
78
86
|
)
|
|
79
87
|
)
|
|
80
88
|
outputs = list(second_slice_node.outputs)
|
|
81
|
-
|
|
82
89
|
first_slice_node.outputs.clear()
|
|
83
90
|
second_slice_node.inputs.clear()
|
|
84
91
|
second_slice_node.outputs.clear()
|
|
@@ -36,9 +36,11 @@ class ConcatReshapeMatcher(PatternMatcher):
|
|
|
36
36
|
def rewrite(self, opset=11):
|
|
37
37
|
match_case = {}
|
|
38
38
|
concat_node = self.concat_0
|
|
39
|
+
reshape_node = self.reshape_0
|
|
39
40
|
index = next(idx for idx, i in enumerate(concat_node.inputs) if isinstance(i, gs.Variable))
|
|
41
|
+
output_name = reshape_node.outputs[0].name
|
|
40
42
|
constant = gs.Constant(
|
|
41
|
-
|
|
43
|
+
output_name + "_fixed",
|
|
42
44
|
values=np.array([-1], dtype=np.int64),
|
|
43
45
|
)
|
|
44
46
|
concat_node.inputs.pop(index)
|
|
@@ -27,12 +27,13 @@ class ConvAddMatcher(PatternMatcher):
|
|
|
27
27
|
conv_weight = list(conv_node.inputs)[1]
|
|
28
28
|
conv_node_users = conv_node.users
|
|
29
29
|
node = self.add_0
|
|
30
|
+
oc_axis = 0 if conv_node.op == "Conv" else 1 # output_channel_axis
|
|
30
31
|
if (
|
|
31
32
|
len(conv_node_users) == 1
|
|
32
33
|
and isinstance(node.inputs[1], gs.Constant)
|
|
33
34
|
and isinstance(conv_weight, gs.Constant)
|
|
34
35
|
and node.inputs[1].values.squeeze().ndim == 1
|
|
35
|
-
and node.inputs[1].values.squeeze().shape[0] == conv_weight.shape[
|
|
36
|
+
and node.inputs[1].values.squeeze().shape[0] == conv_weight.shape[oc_axis]
|
|
36
37
|
):
|
|
37
38
|
add_node = node
|
|
38
39
|
if len(conv_node.inputs) == 2:
|
|
@@ -43,12 +44,8 @@ class ConvAddMatcher(PatternMatcher):
|
|
|
43
44
|
inputs = []
|
|
44
45
|
inputs.append(next(iter(conv_node.inputs)))
|
|
45
46
|
inputs.append(conv_weight)
|
|
46
|
-
|
|
47
|
-
|
|
48
|
-
bias_name = f"{weight_name[:-6]}bias"
|
|
49
|
-
else:
|
|
50
|
-
bias_name = f"{weight_name}_bias"
|
|
51
|
-
inputs.append(gs.Constant(bias_name, values=conv_bias))
|
|
47
|
+
output_name = add_node.outputs[0].name
|
|
48
|
+
inputs.append(gs.Constant(output_name + "_bias", values=conv_bias))
|
|
52
49
|
outputs = list(add_node.outputs)
|
|
53
50
|
|
|
54
51
|
conv_node.outputs.clear()
|
|
@@ -66,5 +63,24 @@ class ConvAddMatcher(PatternMatcher):
|
|
|
66
63
|
|
|
67
64
|
return match_case
|
|
68
65
|
|
|
66
|
+
class ConvTransposeAddMatcher(ConvAddMatcher):
    """Variant of ConvAddMatcher whose pattern starts at a ConvTranspose node instead of a Conv node."""

    def __init__(self, priority):
        """Initializes the ConvTransposeAddMatcher for fusing ConvTranspose and Add layers in an ONNX graph.

        Args:
            priority: Forwarded unchanged to the base matcher constructor.
        """
        # The ConvTranspose row names add_0 as its consumer, matching the Add
        # row that lists conv_0 as its first input.
        pattern = Pattern(
            """
            input input 0 1 conv_0
            ConvTranspose conv_0 1+ 1 input add_0
            Add add_0 2 1 conv_0 ? output
            output output 1 0 add_0
            """
        )
        # super(ConvAddMatcher, self) starts the lookup *after* ConvAddMatcher in
        # the MRO, so ConvAddMatcher.__init__ is skipped and the ConvTranspose
        # pattern built above is the one handed to the base constructor.
        super(ConvAddMatcher, self).__init__(pattern, priority)

    @property
    def name(self):
        """Returns the name of the FusionConvTransposeAdd pattern."""
        return "FusionConvTransposeAdd"
|
|
83
|
+
|
|
69
84
|
|
|
70
85
|
register_fusion_pattern(ConvAddMatcher(1))
|
|
86
|
+
register_fusion_pattern(ConvTransposeAddMatcher(1))
|
|
@@ -44,25 +44,19 @@ class ConvBatchNormMatcher(PatternMatcher):
|
|
|
44
44
|
conv_transpose_bias = conv_transpose_node.inputs[2].values
|
|
45
45
|
|
|
46
46
|
bn_var_rsqrt = bn_scale / np.sqrt(bn_running_var + bn_eps)
|
|
47
|
+
oc_axis = 0 if conv_transpose_node.op == "Conv" else 1 # output_channel_axis
|
|
47
48
|
shape = [1] * len(conv_transpose_weight.shape)
|
|
48
|
-
|
|
49
|
-
shape[0] = -1
|
|
50
|
-
else:
|
|
51
|
-
shape[1] = -1
|
|
49
|
+
shape[oc_axis] = -1
|
|
52
50
|
conv_w = conv_transpose_weight * bn_var_rsqrt.reshape(shape)
|
|
53
51
|
conv_b = (conv_transpose_bias - bn_running_mean) * bn_var_rsqrt + bn_bias
|
|
54
52
|
|
|
55
53
|
inputs = []
|
|
56
54
|
inputs.append(next(iter(conv_transpose_node.inputs)))
|
|
57
|
-
|
|
58
|
-
if weight_name.endswith("weight"):
|
|
59
|
-
bias_name = f"{weight_name[:-6]}bias"
|
|
60
|
-
else:
|
|
61
|
-
bias_name = f"{weight_name}_bias"
|
|
55
|
+
output_name = bn_node.outputs[0].name
|
|
62
56
|
inputs.extend(
|
|
63
57
|
(
|
|
64
|
-
gs.Constant(
|
|
65
|
-
gs.Constant(
|
|
58
|
+
gs.Constant(output_name + "_weight", values=conv_w),
|
|
59
|
+
gs.Constant(output_name + "_bias", values=conv_b),
|
|
66
60
|
)
|
|
67
61
|
)
|
|
68
62
|
outputs = list(bn_node.outputs)
|
|
@@ -82,5 +76,24 @@ class ConvBatchNormMatcher(PatternMatcher):
|
|
|
82
76
|
|
|
83
77
|
return match_case
|
|
84
78
|
|
|
79
|
+
class ConvTransposeBatchNormMatcher(ConvBatchNormMatcher):
    """ConvBatchNormMatcher variant whose pattern begins with a ConvTranspose node."""

    def __init__(self, priority):
        """Build the ConvTranspose -> BatchNormalization pattern and register it with the given priority."""
        layout = """
            input input 0 1 conv_0
            ConvTranspose conv_0 1+ 1 input bn_0
            BatchNormalization bn_0 5 1 conv_0 ? ? ? ? output
            output output 1 0 bn_0
            """
        # Lookup starts after ConvBatchNormMatcher in the MRO, so that class's
        # own __init__ is bypassed and this pattern is the one passed upward.
        super(ConvBatchNormMatcher, self).__init__(Pattern(layout), priority)

    @property
    def name(self):
        """Identifier reported for this fusion pattern."""
        return "FusionConvTransposeBN"
|
|
96
|
+
|
|
85
97
|
|
|
86
98
|
register_fusion_pattern(ConvBatchNormMatcher(1))
|
|
99
|
+
register_fusion_pattern(ConvTransposeBatchNormMatcher(1))
|
|
@@ -28,25 +28,23 @@ class ConvMulMatcher(PatternMatcher):
|
|
|
28
28
|
conv_weight = list(conv_node.inputs)[1]
|
|
29
29
|
if len(conv_node.users) == 1 and conv_node.users[0] == mul_node and isinstance(mul_node.inputs[1], gs.Constant):
|
|
30
30
|
mul_constant = mul_node.inputs[1].values
|
|
31
|
-
|
|
32
|
-
if mul_constant.squeeze().ndim == 1 and mul_constant.squeeze().shape[0] == conv_weight.shape[
|
|
33
|
-
|
|
34
|
-
reshape_shape
|
|
35
|
-
|
|
31
|
+
oc_axis = 0 if conv_node.op == "Conv" else 1 # output_channel_axis
|
|
32
|
+
if mul_constant.squeeze().ndim == 1 and mul_constant.squeeze().shape[0] == conv_weight.shape[oc_axis]:
|
|
33
|
+
reshape_shape = [1] * len(conv_weight.values.shape)
|
|
34
|
+
reshape_shape[oc_axis] = -1
|
|
36
35
|
mul_scale_reshaped = mul_constant.squeeze().reshape(reshape_shape)
|
|
37
36
|
new_weight = conv_weight.values * mul_scale_reshaped
|
|
38
37
|
|
|
39
38
|
inputs = []
|
|
40
39
|
inputs.append(next(iter(conv_node.inputs)))
|
|
41
40
|
|
|
42
|
-
|
|
43
|
-
inputs.append(gs.Constant(
|
|
41
|
+
output_name = mul_node.outputs[0].name
|
|
42
|
+
inputs.append(gs.Constant(output_name + "_weight", values=new_weight))
|
|
44
43
|
|
|
45
44
|
if len(conv_node.inputs) == 3:
|
|
46
45
|
conv_bias = conv_node.inputs[2].values
|
|
47
46
|
new_bias = conv_bias * mul_constant.squeeze()
|
|
48
|
-
|
|
49
|
-
inputs.append(gs.Constant(bias_name, values=new_bias))
|
|
47
|
+
inputs.append(gs.Constant(output_name + "_bias", values=new_bias))
|
|
50
48
|
|
|
51
49
|
outputs = list(mul_node.outputs)
|
|
52
50
|
|
|
@@ -65,5 +63,24 @@ class ConvMulMatcher(PatternMatcher):
|
|
|
65
63
|
|
|
66
64
|
return match_case
|
|
67
65
|
|
|
66
|
+
class ConvTransposeMulMatcher(ConvMulMatcher):
    """ConvMulMatcher variant whose pattern begins with a ConvTranspose node."""

    def __init__(self, priority):
        """Build the ConvTranspose -> Mul pattern and register it with the given priority."""
        layout = """
            input input 0 1 conv_0
            ConvTranspose conv_0 1+ 1 input mul_0
            Mul mul_0 2 1 conv_0 ? output
            output output 1 0 mul_0
            """
        # Lookup starts after ConvMulMatcher in the MRO, so that class's own
        # __init__ is bypassed and this pattern is the one passed upward.
        super(ConvMulMatcher, self).__init__(Pattern(layout), priority)

    @property
    def name(self):
        """Identifier reported for this fusion pattern."""
        return "FusionConvTransposeMul"
|
|
83
|
+
|
|
68
84
|
|
|
69
85
|
register_fusion_pattern(ConvMulMatcher(1))
|
|
86
|
+
register_fusion_pattern(ConvTransposeMulMatcher(1))
|
|
@@ -76,7 +76,7 @@ class MatMulAddPatternMatcher(PatternMatcher):
|
|
|
76
76
|
output_variable.outputs.remove(add_node)
|
|
77
77
|
|
|
78
78
|
matmul_bias_transpose_constant = gs.Constant(
|
|
79
|
-
|
|
79
|
+
f"{matmul_node.name}_weight", values=matmul_bias_variable.values.T
|
|
80
80
|
)
|
|
81
81
|
|
|
82
82
|
inputs = []
|
|
@@ -143,7 +143,7 @@ class MatMulAddPatternMatcher(PatternMatcher):
|
|
|
143
143
|
output_variable.outputs.remove(add_node)
|
|
144
144
|
|
|
145
145
|
matmul_bias_transpose_constant = gs.Constant(
|
|
146
|
-
|
|
146
|
+
f"{matmul_node.name}_weight", values=matmul_bias_variable.values.T
|
|
147
147
|
)
|
|
148
148
|
|
|
149
149
|
inputs = []
|
|
@@ -235,14 +235,15 @@ class GemmMulPatternMatcher(PatternMatcher):
|
|
|
235
235
|
gemm_weight_fused = gemm_weight * mul_weight[:, None]
|
|
236
236
|
else:
|
|
237
237
|
gemm_weight_fused = gemm_weight * mul_weight
|
|
238
|
-
|
|
238
|
+
output_name = reshape_node.outputs[0].name
|
|
239
|
+
gemm_weight_fused_constant = gs.Constant(output_name + "_weight_fused", values=gemm_weight_fused)
|
|
239
240
|
gemm_node.inputs[1] = gemm_weight_fused_constant
|
|
240
241
|
|
|
241
242
|
if gemm_bias_constant:
|
|
242
243
|
gemm_bias = gemm_bias_constant.values
|
|
243
244
|
mul_bias = mul_bias_variable.values
|
|
244
245
|
gemm_bias_fused = gemm_bias * mul_bias
|
|
245
|
-
gemm_bias_fused_constant = gs.Constant(
|
|
246
|
+
gemm_bias_fused_constant = gs.Constant(output_name + "_bias_fused", values=gemm_bias_fused)
|
|
246
247
|
gemm_node.inputs[2] = gemm_bias_fused_constant
|
|
247
248
|
|
|
248
249
|
mul_node.replace_all_uses_with(reshape_node)
|
|
@@ -312,7 +313,8 @@ class GemmAddPatternMatcher(PatternMatcher):
|
|
|
312
313
|
and add_bias.ndim <= 2
|
|
313
314
|
):
|
|
314
315
|
gemm_bias_fused = gemm_bias + add_bias
|
|
315
|
-
|
|
316
|
+
output_name = reshape_node.outputs[0].name
|
|
317
|
+
gemm_bias_fused_constant = gs.Constant(output_name + "_bias_fused", values=gemm_bias_fused)
|
|
316
318
|
gemm_node.inputs[2] = gemm_bias_fused_constant
|
|
317
319
|
else:
|
|
318
320
|
return match_case
|
|
@@ -37,6 +37,8 @@ class PadConvMatcher(PatternMatcher):
|
|
|
37
37
|
pad_node_users = pad_node.users
|
|
38
38
|
|
|
39
39
|
pad_inputs = len(pad_node.inputs)
|
|
40
|
+
auto_pad = pad_node.attrs.get("auto_pad", "NOTSET")
|
|
41
|
+
|
|
40
42
|
if pad_inputs < 3 or (
|
|
41
43
|
(pad_inputs >= 3 and (isinstance(pad_node.inputs[2], gs.Constant) and pad_node.inputs[2].values == 0))
|
|
42
44
|
or (pad_inputs >= 3 and (isinstance(pad_node.inputs[2], gs.Variable) and pad_node.inputs[2].name == ""))
|
|
@@ -45,6 +47,7 @@ class PadConvMatcher(PatternMatcher):
|
|
|
45
47
|
isinstance(pad_node.inputs[1], gs.Constant)
|
|
46
48
|
and pad_node.attrs.get("mode", "constant") == "constant"
|
|
47
49
|
and conv_node.inputs[1].shape
|
|
50
|
+
and (auto_pad == "NOTSET" or auto_pad == "VALID")
|
|
48
51
|
):
|
|
49
52
|
conv_weight_dim = len(conv_node.inputs[1].shape)
|
|
50
53
|
pad_value = pad_node.inputs[1].values.tolist()
|
|
@@ -74,6 +77,8 @@ class PadConvMatcher(PatternMatcher):
|
|
|
74
77
|
pads = [pad + conv_pad for pad, conv_pad in zip(pads, conv_pads)]
|
|
75
78
|
|
|
76
79
|
attrs["pads"] = pads
|
|
80
|
+
conv_node.attrs.pop("auto_pad", None)
|
|
81
|
+
|
|
77
82
|
match_case[conv_node.name] = {
|
|
78
83
|
"op": "Conv",
|
|
79
84
|
"inputs": inputs,
|