onnxslim 0.1.81__py3-none-any.whl → 0.1.83__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- onnxslim/core/optimization/dead_node_elimination.py +84 -3
- onnxslim/core/pattern/fusion/convadd.py +21 -1
- onnxslim/core/pattern/fusion/convbn.py +21 -4
- onnxslim/core/pattern/fusion/convmul.py +23 -5
- onnxslim/core/pattern/fusion/padconv.py +5 -0
- onnxslim/core/shape_inference/__init__.py +378 -0
- onnxslim/core/shape_inference/aten_ops/__init__.py +16 -0
- onnxslim/core/shape_inference/aten_ops/argmax.py +47 -0
- onnxslim/core/shape_inference/aten_ops/bitwise_or.py +28 -0
- onnxslim/core/shape_inference/aten_ops/diagonal.py +52 -0
- onnxslim/core/shape_inference/aten_ops/embedding.py +23 -0
- onnxslim/core/shape_inference/aten_ops/group_norm.py +41 -0
- onnxslim/core/shape_inference/aten_ops/min_max.py +64 -0
- onnxslim/core/shape_inference/aten_ops/multinomial.py +39 -0
- onnxslim/core/shape_inference/aten_ops/numpy_t.py +22 -0
- onnxslim/core/shape_inference/aten_ops/pool2d.py +40 -0
- onnxslim/core/shape_inference/aten_ops/unfold.py +44 -0
- onnxslim/core/shape_inference/aten_ops/upsample.py +44 -0
- onnxslim/core/shape_inference/base.py +111 -0
- onnxslim/core/shape_inference/context.py +645 -0
- onnxslim/core/shape_inference/contrib_ops/__init__.py +8 -0
- onnxslim/core/shape_inference/contrib_ops/attention/__init__.py +15 -0
- onnxslim/core/shape_inference/contrib_ops/attention/attention.py +61 -0
- onnxslim/core/shape_inference/contrib_ops/attention/decoder_masked_mha.py +37 -0
- onnxslim/core/shape_inference/contrib_ops/attention/gated_relative_position_bias.py +35 -0
- onnxslim/core/shape_inference/contrib_ops/attention/longformer_attention.py +21 -0
- onnxslim/core/shape_inference/contrib_ops/attention/multi_head_attention.py +82 -0
- onnxslim/core/shape_inference/contrib_ops/attention/multi_scale_deformable_attn.py +29 -0
- onnxslim/core/shape_inference/contrib_ops/attention/packed_attention.py +39 -0
- onnxslim/core/shape_inference/contrib_ops/attention/packed_multi_head_attention.py +33 -0
- onnxslim/core/shape_inference/contrib_ops/attention/remove_padding.py +41 -0
- onnxslim/core/shape_inference/contrib_ops/attention/restore_padding.py +29 -0
- onnxslim/core/shape_inference/contrib_ops/misc/__init__.py +15 -0
- onnxslim/core/shape_inference/contrib_ops/misc/bias_add.py +21 -0
- onnxslim/core/shape_inference/contrib_ops/misc/bias_gelu.py +21 -0
- onnxslim/core/shape_inference/contrib_ops/misc/bias_split_gelu.py +30 -0
- onnxslim/core/shape_inference/contrib_ops/misc/fast_gelu.py +21 -0
- onnxslim/core/shape_inference/contrib_ops/misc/gelu.py +21 -0
- onnxslim/core/shape_inference/contrib_ops/misc/gemm_fast_gelu.py +21 -0
- onnxslim/core/shape_inference/contrib_ops/misc/gemm_float8.py +21 -0
- onnxslim/core/shape_inference/contrib_ops/misc/python_op.py +67 -0
- onnxslim/core/shape_inference/contrib_ops/misc/quick_gelu.py +21 -0
- onnxslim/core/shape_inference/contrib_ops/misc/rotary_embedding.py +31 -0
- onnxslim/core/shape_inference/contrib_ops/normalization/__init__.py +12 -0
- onnxslim/core/shape_inference/contrib_ops/normalization/embed_layer_normalization.py +41 -0
- onnxslim/core/shape_inference/contrib_ops/normalization/group_norm.py +21 -0
- onnxslim/core/shape_inference/contrib_ops/normalization/layer_normalization.py +42 -0
- onnxslim/core/shape_inference/contrib_ops/normalization/simplified_layer_normalization.py +23 -0
- onnxslim/core/shape_inference/contrib_ops/normalization/skip_group_norm.py +23 -0
- onnxslim/core/shape_inference/contrib_ops/normalization/skip_layer_normalization.py +26 -0
- onnxslim/core/shape_inference/contrib_ops/normalization/skip_simplified_layer_normalization.py +23 -0
- onnxslim/core/shape_inference/registry.py +90 -0
- onnxslim/core/shape_inference/standard_ops/__init__.py +11 -0
- onnxslim/core/shape_inference/standard_ops/control_flow/__init__.py +8 -0
- onnxslim/core/shape_inference/standard_ops/control_flow/if_op.py +43 -0
- onnxslim/core/shape_inference/standard_ops/control_flow/loop.py +74 -0
- onnxslim/core/shape_inference/standard_ops/control_flow/scan.py +54 -0
- onnxslim/core/shape_inference/standard_ops/math/__init__.py +20 -0
- onnxslim/core/shape_inference/standard_ops/math/_symbolic_compute.py +34 -0
- onnxslim/core/shape_inference/standard_ops/math/add.py +10 -0
- onnxslim/core/shape_inference/standard_ops/math/div.py +10 -0
- onnxslim/core/shape_inference/standard_ops/math/einsum.py +119 -0
- onnxslim/core/shape_inference/standard_ops/math/equal.py +10 -0
- onnxslim/core/shape_inference/standard_ops/math/floor.py +10 -0
- onnxslim/core/shape_inference/standard_ops/math/matmul.py +21 -0
- onnxslim/core/shape_inference/standard_ops/math/matmul_integer.py +23 -0
- onnxslim/core/shape_inference/standard_ops/math/max.py +10 -0
- onnxslim/core/shape_inference/standard_ops/math/min.py +10 -0
- onnxslim/core/shape_inference/standard_ops/math/mul.py +10 -0
- onnxslim/core/shape_inference/standard_ops/math/neg.py +10 -0
- onnxslim/core/shape_inference/standard_ops/math/reduce_prod.py +27 -0
- onnxslim/core/shape_inference/standard_ops/math/reduce_sum.py +53 -0
- onnxslim/core/shape_inference/standard_ops/math/sub.py +10 -0
- onnxslim/core/shape_inference/standard_ops/math/where.py +10 -0
- onnxslim/core/shape_inference/standard_ops/misc/__init__.py +22 -0
- onnxslim/core/shape_inference/standard_ops/misc/array_feature_extractor.py +32 -0
- onnxslim/core/shape_inference/standard_ops/misc/cast.py +21 -0
- onnxslim/core/shape_inference/standard_ops/misc/category_mapper.py +30 -0
- onnxslim/core/shape_inference/standard_ops/misc/compress.py +39 -0
- onnxslim/core/shape_inference/standard_ops/misc/constant.py +27 -0
- onnxslim/core/shape_inference/standard_ops/misc/constant_of_shape.py +45 -0
- onnxslim/core/shape_inference/standard_ops/misc/dequantize_linear.py +26 -0
- onnxslim/core/shape_inference/standard_ops/misc/non_max_suppression.py +26 -0
- onnxslim/core/shape_inference/standard_ops/misc/non_zero.py +26 -0
- onnxslim/core/shape_inference/standard_ops/misc/one_hot.py +42 -0
- onnxslim/core/shape_inference/standard_ops/misc/quantize_linear.py +29 -0
- onnxslim/core/shape_inference/standard_ops/misc/range.py +41 -0
- onnxslim/core/shape_inference/standard_ops/misc/relative_position_bias.py +31 -0
- onnxslim/core/shape_inference/standard_ops/misc/resize.py +74 -0
- onnxslim/core/shape_inference/standard_ops/misc/scatter_elements.py +31 -0
- onnxslim/core/shape_inference/standard_ops/misc/softmax_cross_entropy_loss.py +44 -0
- onnxslim/core/shape_inference/standard_ops/misc/top_k.py +44 -0
- onnxslim/core/shape_inference/standard_ops/nn/__init__.py +18 -0
- onnxslim/core/shape_inference/standard_ops/nn/all_reduce.py +9 -0
- onnxslim/core/shape_inference/standard_ops/nn/average_pool.py +40 -0
- onnxslim/core/shape_inference/standard_ops/nn/batch_normalization.py +26 -0
- onnxslim/core/shape_inference/standard_ops/nn/conv.py +33 -0
- onnxslim/core/shape_inference/standard_ops/nn/cum_sum.py +9 -0
- onnxslim/core/shape_inference/standard_ops/nn/identity.py +9 -0
- onnxslim/core/shape_inference/standard_ops/nn/max_pool.py +9 -0
- onnxslim/core/shape_inference/standard_ops/nn/memcpy_from_host.py +9 -0
- onnxslim/core/shape_inference/standard_ops/nn/memcpy_to_host.py +9 -0
- onnxslim/core/shape_inference/standard_ops/nn/moe.py +9 -0
- onnxslim/core/shape_inference/standard_ops/nn/nhwc_conv.py +33 -0
- onnxslim/core/shape_inference/standard_ops/nn/reciprocal.py +9 -0
- onnxslim/core/shape_inference/standard_ops/nn/round.py +9 -0
- onnxslim/core/shape_inference/standard_ops/sequence/__init__.py +10 -0
- onnxslim/core/shape_inference/standard_ops/sequence/concat_from_sequence.py +40 -0
- onnxslim/core/shape_inference/standard_ops/sequence/sequence_at.py +31 -0
- onnxslim/core/shape_inference/standard_ops/sequence/sequence_insert.py +26 -0
- onnxslim/core/shape_inference/standard_ops/sequence/split_to_sequence.py +24 -0
- onnxslim/core/shape_inference/standard_ops/sequence/zip_map.py +36 -0
- onnxslim/core/shape_inference/standard_ops/tensor/__init__.py +20 -0
- onnxslim/core/shape_inference/standard_ops/tensor/concat.py +62 -0
- onnxslim/core/shape_inference/standard_ops/tensor/expand.py +36 -0
- onnxslim/core/shape_inference/standard_ops/tensor/gather.py +48 -0
- onnxslim/core/shape_inference/standard_ops/tensor/gather_elements.py +31 -0
- onnxslim/core/shape_inference/standard_ops/tensor/gather_nd.py +42 -0
- onnxslim/core/shape_inference/standard_ops/tensor/pad.py +41 -0
- onnxslim/core/shape_inference/standard_ops/tensor/reshape.py +72 -0
- onnxslim/core/shape_inference/standard_ops/tensor/shape.py +38 -0
- onnxslim/core/shape_inference/standard_ops/tensor/size.py +29 -0
- onnxslim/core/shape_inference/standard_ops/tensor/slice.py +183 -0
- onnxslim/core/shape_inference/standard_ops/tensor/split.py +57 -0
- onnxslim/core/shape_inference/standard_ops/tensor/squeeze.py +69 -0
- onnxslim/core/shape_inference/standard_ops/tensor/tile.py +41 -0
- onnxslim/core/shape_inference/standard_ops/tensor/transpose.py +30 -0
- onnxslim/core/shape_inference/standard_ops/tensor/unsqueeze.py +54 -0
- onnxslim/core/shape_inference/utils.py +244 -0
- onnxslim/third_party/symbolic_shape_infer.py +73 -3156
- onnxslim/utils.py +4 -2
- {onnxslim-0.1.81.dist-info → onnxslim-0.1.83.dist-info}/METADATA +21 -11
- onnxslim-0.1.83.dist-info/RECORD +187 -0
- onnxslim-0.1.81.dist-info/RECORD +0 -63
- {onnxslim-0.1.81.dist-info → onnxslim-0.1.83.dist-info}/WHEEL +0 -0
- {onnxslim-0.1.81.dist-info → onnxslim-0.1.83.dist-info}/entry_points.txt +0 -0
- {onnxslim-0.1.81.dist-info → onnxslim-0.1.83.dist-info}/licenses/LICENSE +0 -0
onnxslim/core/optimization/dead_node_elimination.py

@@ -54,9 +54,16 @@ def dead_node_elimination(graph, is_subgraph=False):
                node.inputs.insert(1, reshape_const)
                logger.debug(f"replacing {node.op} op: {node.name}")
        elif node.op == "Slice":
-            if node.inputs[0].shape and node.outputs[0].shape
-                node.
-
+            if (node.inputs[0].shape and node.outputs[0].shape
+                    and node.inputs[0].shape == node.outputs[0].shape
+                    and all(isinstance(item, int) for item in node.inputs[0].shape)):
+
+                # Check if slice is a no-op by analyzing parameters directly
+                # Slice inputs: data, starts, ends, [axes], [steps]
+                if is_noop_slice(node):
+                    node.erase()
+                    logger.debug(f"removing {node.op} op: {node.name}")
+
        elif node.op == "Mul":
            if (isinstance(node.inputs[1], Constant) and isinstance(node.inputs[0], Variable)) or (
                isinstance(node.inputs[0], Constant) and isinstance(node.inputs[1], Variable)

@@ -153,3 +160,77 @@ def get_constant_variable(node, return_idx=False):
    for idx, input in enumerate(list(node.inputs)):
        if isinstance(input, Constant):
            return (idx, input) if return_idx else input
+
+
+def is_noop_slice(node):
+    """Check if a Slice node is a no-op by analyzing its parameters directly.
+
+    A Slice is a no-op when it extracts the entire tensor, i.e., for each sliced axis:
+    - start == 0 (or equivalent negative index)
+    - end >= dim_size (or is INT_MAX-like value)
+    - step == 1
+    """
+    # Slice inputs: data, starts, ends, [axes], [steps]
+    if len(node.inputs) < 3:
+        return False
+
+    data_shape = node.inputs[0].shape
+    if not data_shape or not all(isinstance(d, int) for d in data_shape):
+        return False
+
+    # Get starts and ends (required)
+    starts_input = node.inputs[1]
+    ends_input = node.inputs[2]
+
+    if not isinstance(starts_input, Constant) or not isinstance(ends_input, Constant):
+        return False
+
+    starts = starts_input.values.flatten().tolist()
+    ends = ends_input.values.flatten().tolist()
+
+    # Get axes (optional, defaults to [0, 1, 2, ...])
+    if len(node.inputs) > 3 and isinstance(node.inputs[3], Constant):
+        axes = node.inputs[3].values.flatten().tolist()
+    else:
+        axes = list(range(len(starts)))
+
+    # Get steps (optional, defaults to [1, 1, 1, ...])
+    if len(node.inputs) > 4 and isinstance(node.inputs[4], Constant):
+        steps = node.inputs[4].values.flatten().tolist()
+    else:
+        steps = [1] * len(starts)
+
+    # Check each axis
+    ndim = len(data_shape)
+    for start, end, axis, step in zip(starts, ends, axes, steps):
+        # Normalize negative axis
+        if axis < 0:
+            axis = ndim + axis
+
+        if axis < 0 or axis >= ndim:
+            return False
+
+        dim_size = data_shape[axis]
+
+        # Step must be 1 for no-op
+        if step != 1:
+            return False
+
+        # Normalize negative start index
+        if start < 0:
+            start = max(0, dim_size + start)
+
+        # Start must be 0
+        if start != 0:
+            return False
+
+        # Normalize negative end index
+        if end < 0:
+            end = dim_size + end
+
+        # End must cover the entire dimension
+        # Common patterns: end == dim_size, or end is a large value like INT_MAX
+        if end < dim_size:
+            return False
+
+    return True
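For context, the new `is_noop_slice` check can be illustrated with a standalone sketch (the shapes and values below are invented; the real function reads them from the node's `Constant` inputs):

```python
# Invented example: a Slice over a (3, 224, 224) tensor with starts=[0],
# ends=[INT_MAX], axes=[1], steps=[1] copies the whole tensor, so
# dead_node_elimination can erase the node.
data_shape = [3, 224, 224]
starts, ends, axes, steps = [0], [2**31 - 1], [1], [1]

noop = True
for start, end, axis, step in zip(starts, ends, axes, steps):
    axis = axis if axis >= 0 else len(data_shape) + axis
    dim = data_shape[axis]
    start = max(0, dim + start) if start < 0 else start
    end = dim + end if end < 0 else end
    if step != 1 or start != 0 or end < dim:
        noop = False

print(noop)  # True -> the Slice is a no-op and can be removed
```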
onnxslim/core/pattern/fusion/convadd.py

@@ -27,12 +27,13 @@ class ConvAddMatcher(PatternMatcher):
        conv_weight = list(conv_node.inputs)[1]
        conv_node_users = conv_node.users
        node = self.add_0
+        oc_axis = 0 if conv_node.op == "Conv" else 1  # output_channel_axis
        if (
            len(conv_node_users) == 1
            and isinstance(node.inputs[1], gs.Constant)
            and isinstance(conv_weight, gs.Constant)
            and node.inputs[1].values.squeeze().ndim == 1
-            and node.inputs[1].values.squeeze().shape[0] == conv_weight.shape[
+            and node.inputs[1].values.squeeze().shape[0] == conv_weight.shape[oc_axis]
        ):
            add_node = node
            if len(conv_node.inputs) == 2:

@@ -66,5 +67,24 @@ class ConvAddMatcher(PatternMatcher):
 
        return match_case
 
+class ConvTransposeAddMatcher(ConvAddMatcher):
+    def __init__(self, priority):
+        """Initializes the ConvTransposeAddMatcher for fusing ConvTranspose and Add layers in an ONNX graph."""
+        pattern = Pattern(
+            """
+            input         input  0 1 conv_0
+            ConvTranspose conv_0 1+ 1 input bn_0
+            Add           add_0  2 1 conv_0 ? output
+            output        output 1 0 add_0
+            """
+        )
+        super(ConvAddMatcher, self).__init__(pattern, priority)
+
+    @property
+    def name(self):
+        """Returns the name of the FusionConvTransposeAdd pattern."""
+        return "FusionConvTransposeAdd"
+
 
 register_fusion_pattern(ConvAddMatcher(1))
+register_fusion_pattern(ConvTransposeAddMatcher(1))
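The `oc_axis` switch that recurs in these fusion hunks reflects the ONNX weight layouts: Conv weights are `(out_channels, in_channels/group, kH, kW)`, while ConvTranspose weights are `(in_channels, out_channels/group, kH, kW)`, so the output-channel count sits on axis 0 and axis 1 respectively. A quick numpy illustration (shapes invented, group=1 assumed):

```python
import numpy as np

conv_w = np.zeros((8, 3, 3, 3))    # Conv: 8 output channels on axis 0
deconv_w = np.zeros((3, 8, 3, 3))  # ConvTranspose: 8 output channels on axis 1
bias = np.zeros(8)                 # the per-channel Add constant being folded in

assert bias.shape[0] == conv_w.shape[0]    # oc_axis = 0 for Conv
assert bias.shape[0] == deconv_w.shape[1]  # oc_axis = 1 for ConvTranspose
```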
onnxslim/core/pattern/fusion/convbn.py

@@ -44,11 +44,9 @@ class ConvBatchNormMatcher(PatternMatcher):
                conv_transpose_bias = conv_transpose_node.inputs[2].values
 
            bn_var_rsqrt = bn_scale / np.sqrt(bn_running_var + bn_eps)
+            oc_axis = 0 if conv_transpose_node.op == "Conv" else 1  # output_channel_axis
            shape = [1] * len(conv_transpose_weight.shape)
-
-                shape[0] = -1
-            else:
-                shape[1] = -1
+            shape[oc_axis] = -1
            conv_w = conv_transpose_weight * bn_var_rsqrt.reshape(shape)
            conv_b = (conv_transpose_bias - bn_running_mean) * bn_var_rsqrt + bn_bias
 

@@ -82,5 +80,24 @@ class ConvBatchNormMatcher(PatternMatcher):
 
        return match_case
 
+class ConvTransposeBatchNormMatcher(ConvBatchNormMatcher):
+    def __init__(self, priority):
+        """Initializes the ConvTransposeBatchNormMatcher for fusing ConvTranspose and BatchNormalization layers in an ONNX graph."""
+        pattern = Pattern(
+            """
+            input              input  0 1 conv_0
+            ConvTranspose      conv_0 1+ 1 input bn_0
+            BatchNormalization bn_0   5 1 conv_0 ? ? ? ? output
+            output             output 1 0 bn_0
+            """
+        )
+        super(ConvBatchNormMatcher, self).__init__(pattern, priority)
+
+    @property
+    def name(self):
+        """Returns the name of the FusionConvTransposeBN pattern."""
+        return "FusionConvTransposeBN"
+
 
 register_fusion_pattern(ConvBatchNormMatcher(1))
+register_fusion_pattern(ConvTransposeBatchNormMatcher(1))
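The folding arithmetic in the first hunk follows from BN(y) = γ·(y − μ)/√(σ² + ε) + β applied to a linear op: the scale is absorbed into the weights along the output-channel axis and the shift into the bias. A self-contained numpy check with invented values:

```python
import numpy as np

rng = np.random.default_rng(0)
gamma, beta = rng.normal(size=4), rng.normal(size=4)
mean, var, eps = rng.normal(size=4), rng.random(4), 1e-5
w, b = rng.normal(size=(4, 3, 3, 3)), rng.normal(size=4)  # Conv, oc_axis = 0

scale = gamma / np.sqrt(var + eps)  # bn_var_rsqrt in the diff
shape = [1] * w.ndim
shape[0] = -1                       # shape[oc_axis] = -1
w_f = w * scale.reshape(shape)      # conv_w
b_f = (b - mean) * scale + beta     # conv_b

# One output channel, all-ones input patch: fused conv equals BN(conv).
y = w[2].sum() + b[2]
assert np.isclose(scale[2] * (y - mean[2]) + beta[2], w_f[2].sum() + b_f[2])
```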
onnxslim/core/pattern/fusion/convmul.py

@@ -28,11 +28,10 @@ class ConvMulMatcher(PatternMatcher):
        conv_weight = list(conv_node.inputs)[1]
        if len(conv_node.users) == 1 and conv_node.users[0] == mul_node and isinstance(mul_node.inputs[1], gs.Constant):
            mul_constant = mul_node.inputs[1].values
-
-            if mul_constant.squeeze().ndim == 1 and mul_constant.squeeze().shape[0] == conv_weight.shape[
-
-                reshape_shape
-
+            oc_axis = 0 if conv_node.op == "Conv" else 1  # output_channel_axis
+            if mul_constant.squeeze().ndim == 1 and mul_constant.squeeze().shape[0] == conv_weight.shape[oc_axis]:
+                reshape_shape = [1] * len(conv_weight.values.shape)
+                reshape_shape[oc_axis] = -1
                mul_scale_reshaped = mul_constant.squeeze().reshape(reshape_shape)
                new_weight = conv_weight.values * mul_scale_reshaped
 

@@ -65,5 +64,24 @@ class ConvMulMatcher(PatternMatcher):
 
        return match_case
 
+class ConvTransposeMulMatcher(ConvMulMatcher):
+    def __init__(self, priority):
+        """Initializes the ConvTransposeMulMatcher for fusing ConvTranspose and Mul layers in an ONNX graph."""
+        pattern = Pattern(
+            """
+            input         input  0 1 conv_0
+            ConvTranspose conv_0 1+ 1 input mul_0
+            Mul           mul_0  2 1 conv_0 ? output
+            output        output 1 0 mul_0
+            """
+        )
+        super(ConvMulMatcher, self).__init__(pattern, priority)
+
+    @property
+    def name(self):
+        """Returns the name of the FusionConvTransposeMul pattern."""
+        return "FusionConvTransposeMul"
+
 
 register_fusion_pattern(ConvMulMatcher(1))
+register_fusion_pattern(ConvTransposeMulMatcher(1))
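Mul folding is the scale half of the same idea; a tiny sketch with invented shapes:

```python
import numpy as np

w = np.ones((4, 3, 3, 3))           # Conv weight, oc_axis = 0
s = np.array([1.0, 2.0, 3.0, 4.0])  # per-channel Mul constant

reshape_shape = [1] * w.ndim
reshape_shape[0] = -1               # reshape_shape[oc_axis] = -1
new_weight = w * s.reshape(reshape_shape)

assert np.allclose(new_weight[1], 2.0)  # channel 1 scaled by s[1]
```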
onnxslim/core/pattern/fusion/padconv.py

@@ -37,6 +37,8 @@ class PadConvMatcher(PatternMatcher):
        pad_node_users = pad_node.users
 
        pad_inputs = len(pad_node.inputs)
+        auto_pad = pad_node.attrs.get("auto_pad", "NOTSET")
+
        if pad_inputs < 3 or (
            (pad_inputs >= 3 and (isinstance(pad_node.inputs[2], gs.Constant) and pad_node.inputs[2].values == 0))
            or (pad_inputs >= 3 and (isinstance(pad_node.inputs[2], gs.Variable) and pad_node.inputs[2].name == ""))

@@ -45,6 +47,7 @@ class PadConvMatcher(PatternMatcher):
            isinstance(pad_node.inputs[1], gs.Constant)
            and pad_node.attrs.get("mode", "constant") == "constant"
            and conv_node.inputs[1].shape
+            and (auto_pad == "NOTSET" or auto_pad == "VALID")
        ):
            conv_weight_dim = len(conv_node.inputs[1].shape)
            pad_value = pad_node.inputs[1].values.tolist()

@@ -74,6 +77,8 @@ class PadConvMatcher(PatternMatcher):
            pads = [pad + conv_pad for pad, conv_pad in zip(pads, conv_pads)]
 
            attrs["pads"] = pads
+            conv_node.attrs.pop("auto_pad", None)
+
            match_case[conv_node.name] = {
                "op": "Conv",
                "inputs": inputs,
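For reference, ONNX `Pad` takes pads as `[x1_begin, ..., xn_begin, x1_end, ..., xn_end]` over all input dims, while Conv's `pads` attribute covers only the spatial dims; the matcher adds the spatial portion of the Pad to the Conv pads. A sketch with invented numbers for a 4-D NCHW input (the spatial-index slicing here is illustrative, not the matcher's exact code):

```python
pad_value = [0, 0, 1, 1, 0, 0, 1, 1]  # Pad: H and W padded by 1 on each side
conv_pads = [0, 0, 0, 0]              # existing Conv pads: H/W begin, H/W end

spatial = pad_value[2:4] + pad_value[6:8]           # drop batch/channel entries
pads = [p + c for p, c in zip(spatial, conv_pads)]  # -> attrs["pads"]
print(pads)  # [1, 1, 1, 1]
```

The new `auto_pad` guard restricts the rewrite to Convs whose padding is explicit (`NOTSET`) or absent (`VALID`), since merging explicit pads under `SAME_UPPER`/`SAME_LOWER` would conflict with the auto-computed padding; `auto_pad` is then dropped once the merged `pads` attribute is written.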
onnxslim/core/shape_inference/__init__.py (new file)

@@ -0,0 +1,378 @@
+# Copyright (c) Microsoft Corporation. All rights reserved.
+# Licensed under the MIT License.
+
+"""
+Symbolic Shape Inference Module
+
+This module provides symbolic shape inference for ONNX models. It replaces the
+monolithic SymbolicShapeInference class with a modular, handler-based architecture.
+
+Usage:
+    from onnxslim.core.shape_inference import ShapeInferencer
+
+    model = onnx.load("model.onnx")
+    model_with_shapes = ShapeInferencer.infer_shapes(model)
+"""
+
+import logging
+
+import onnx
+import sympy
+from onnx import helper
+
+from .context import InferenceContext
+from .registry import get_all_aten_handlers, get_all_shape_handlers, get_aten_handler, get_shape_handler
+from .utils import (
+    get_attribute,
+    get_opset,
+    get_shape_from_type_proto,
+    get_shape_from_value_info,
+    is_literal,
+    is_sequence,
+)
+
+# Import all handlers to trigger registration
+from . import aten_ops  # noqa: F401
+from . import contrib_ops  # noqa: F401
+from . import standard_ops  # noqa: F401
+
+logger = logging.getLogger(__name__)
+
+
+class ShapeInferencer:
+    """Main class for performing symbolic shape inference on ONNX models."""
+
+    def __init__(self, int_max=2**31 - 1, auto_merge=False, guess_output_rank=False, verbose=0, prefix=""):
+        """Initialize the ShapeInferencer.
+
+        Args:
+            int_max: Maximum value for unbounded integers.
+            auto_merge: Whether to automatically merge conflicting dimensions.
+            guess_output_rank: Whether to guess output rank from input.
+            verbose: Logging verbosity level.
+            prefix: Prefix for generated symbolic dimension names.
+        """
+        self.int_max_ = int_max
+        self.auto_merge_ = auto_merge
+        self.guess_output_rank_ = guess_output_rank
+        self.verbose_ = verbose
+        self.prefix_ = prefix
+
+    def _infer_impl(self, ctx, start_sympy_data=None):
+        """Main inference implementation loop."""
+        ctx.sympy_data_ = start_sympy_data or {}
+        ctx.apply_suggested_merge(graph_input_only=True)
+        ctx.input_symbols_ = set()
+
+        # Process graph inputs
+        for i in ctx.out_mp_.graph.input:
+            input_shape = get_shape_from_value_info(i)
+            if input_shape is None:
+                continue
+
+            if is_sequence(i.type):
+                input_dims = i.type.sequence_type.elem_type.tensor_type.shape.dim
+            else:
+                input_dims = i.type.tensor_type.shape.dim
+
+            for i_dim, dim in enumerate(input_shape):
+                if dim is None:
+                    input_dims[i_dim].dim_param = str(ctx.new_symbolic_dim(i.name, i_dim))
+
+            ctx.input_symbols_.update([d for d in input_shape if type(d) == str])
+
+        for s in ctx.input_symbols_:
+            if s in ctx.suggested_merge_:
+                s_merge = ctx.suggested_merge_[s]
+                assert s_merge in ctx.symbolic_dims_
+                ctx.symbolic_dims_[s] = ctx.symbolic_dims_[s_merge]
+            else:
+                ctx.symbolic_dims_[s] = sympy.Symbol(s, integer=True, positive=True)
+
+        # Compute prerequisite for node for topological sort
+        prereq_for_node = {}
+
+        def get_prereq(node):
+            names = {i for i in node.input if i}
+            subgraphs = []
+            if node.op_type == "If":
+                subgraphs = [get_attribute(node, "then_branch"), get_attribute(node, "else_branch")]
+            elif node.op_type in {"Loop", "Scan"}:
+                subgraphs = [get_attribute(node, "body")]
+            for g in subgraphs:
+                g_outputs_and_initializers = {i.name for i in g.initializer}
+                g_prereq = set()
+                for n in g.node:
+                    g_outputs_and_initializers.update(n.output)
+                for n in g.node:
+                    g_prereq.update([i for i in get_prereq(n) if i not in g_outputs_and_initializers])
+                names.update(g_prereq)
+                for i in g.input:
+                    if i.name in names:
+                        names.remove(i.name)
+            return names
+
+        for n in ctx.out_mp_.graph.node:
+            prereq_for_node[n.output[0]] = get_prereq(n)
+
+        # Topological sort nodes
+        sorted_nodes = []
+        sorted_known_vi = {i.name for i in list(ctx.out_mp_.graph.input) + list(ctx.out_mp_.graph.initializer)}
+        if any(o.name in sorted_known_vi for o in ctx.out_mp_.graph.output):
+            sorted_nodes = ctx.out_mp_.graph.node
+        else:
+            while any(o.name not in sorted_known_vi for o in ctx.out_mp_.graph.output):
+                old_sorted_nodes_len = len(sorted_nodes)
+                for node in ctx.out_mp_.graph.node:
+                    if node.output[0] not in sorted_known_vi and all(
+                        i in sorted_known_vi for i in prereq_for_node[node.output[0]] if i
+                    ):
+                        sorted_known_vi.update(node.output)
+                        sorted_nodes.append(node)
+                if old_sorted_nodes_len == len(sorted_nodes) and not all(
+                    o.name in sorted_known_vi for o in ctx.out_mp_.graph.output
+                ):
+                    raise Exception("Invalid model with cyclic graph")
+
+        # Get handlers
+        shape_handlers = get_all_shape_handlers()
+        aten_handlers = get_all_aten_handlers()
+
+        # Process each node
+        for node in sorted_nodes:
+            assert all([i in ctx.known_vi_ for i in node.input if i])
+            ctx.onnx_infer_single_node(node)
+            known_aten_op = False
+
+            # Try standard handlers first
+            handler = get_shape_handler(node.op_type)
+            if handler is not None:
+                handler.infer_shape(node, ctx)
+            elif node.op_type == "ConvTranspose":
+                vi = ctx.known_vi_[node.output[0]]
+                if len(vi.type.tensor_type.shape.dim) == 0:
+                    vi.type.tensor_type.elem_type = onnx.TensorProto.UNDEFINED
+            elif node.op_type == "ATen" and node.domain == "org.pytorch.aten":
+                for attr in node.attribute:
+                    if attr.name == "operator":
+                        aten_op_name = attr.s.decode("utf-8") if isinstance(attr.s, bytes) else attr.s
+                        aten_handler = get_aten_handler(aten_op_name)
+                        if aten_handler is not None:
+                            known_aten_op = True
+                            aten_handler.infer_shape(node, ctx)
+                        break
+
+            if ctx.verbose_ > 2:
+                logger.debug(node.op_type + ": " + node.name)
+                for i, name in enumerate(node.input):
+                    logger.debug(f"  Input {i}: {name} {'initializer' if name in ctx.initializers_ else ''}")
+
+            # Handle dimension merging for broadcast ops
+            if node.op_type in {
+                "Add",
+                "Sub",
+                "Mul",
+                "Div",
+                "MatMul",
+                "MatMulInteger",
+                "MatMulInteger16",
+                "Where",
+                "Sum",
+            }:
+                vi = ctx.known_vi_[node.output[0]]
+                out_rank = len(get_shape_from_type_proto(vi.type))
+                in_shapes = [ctx.get_shape(node, i) for i in range(len(node.input))]
+                for d in range(out_rank - (2 if node.op_type in {"MatMul", "MatMulInteger", "MatMulInteger16"} else 0)):
+                    in_dims = [s[len(s) - out_rank + d] for s in in_shapes if len(s) + d >= out_rank]
+                    if len(in_dims) > 1:
+                        ctx.check_merged_dims(in_dims, allow_broadcast=True)
+
+            # Process outputs
+            for i_o in range(len(node.output)):
+                if node.op_type in {"SkipLayerNormalization", "SkipSimplifiedLayerNormalization"} and i_o in {1, 2}:
+                    continue
+                if node.op_type == "RotaryEmbedding" and len(node.output) > 1:
+                    continue
+
+                vi = ctx.known_vi_[node.output[i_o]]
+                out_type = vi.type
+                out_type_kind = out_type.WhichOneof("value")
+
+                if out_type_kind not in {"tensor_type", "sparse_tensor_type", None}:
+                    if ctx.verbose_ > 2:
+                        if out_type_kind == "sequence_type":
+                            seq_cls_type = out_type.sequence_type.elem_type.WhichOneof("value")
+                            if seq_cls_type == "tensor_type":
+                                logger.debug(
+                                    f"  {node.output[i_o]}: sequence of {str(get_shape_from_value_info(vi))} "
+                                    f"{onnx.TensorProto.DataType.Name(vi.type.sequence_type.elem_type.tensor_type.elem_type)}"
+                                )
+                            else:
+                                logger.debug(f"  {node.output[i_o]}: sequence of {seq_cls_type}")
+                        else:
+                            logger.debug(f"  {node.output[i_o]}: {out_type_kind}")
+                    continue
+
+                out_shape = get_shape_from_value_info(vi)
+                out_type_undefined = out_type.tensor_type.elem_type == onnx.TensorProto.UNDEFINED
+                if ctx.verbose_ > 2:
+                    logger.debug(
+                        f"  {node.output[i_o]}: {out_shape!s} {onnx.TensorProto.DataType.Name(vi.type.tensor_type.elem_type)}"
+                    )
+                    if node.output[i_o] in ctx.sympy_data_:
+                        logger.debug("  Sympy Data: " + str(ctx.sympy_data_[node.output[i_o]]))
+
+                if (out_shape is not None and (None in out_shape or ctx.is_shape_contains_none_dim(out_shape))) or out_type_undefined:
+                    if ctx.auto_merge_:
+                        if node.op_type in {
+                            "Add",
+                            "Sub",
+                            "Mul",
+                            "Div",
+                            "MatMul",
+                            "MatMulInteger",
+                            "MatMulInteger16",
+                            "Concat",
+                            "Where",
+                            "Sum",
+                            "Equal",
+                            "Less",
+                            "Greater",
+                            "LessOrEqual",
+                            "GreaterOrEqual",
+                            "Min",
+                            "Max",
+                        }:
+                            shapes = [ctx.get_shape(node, i) for i in range(len(node.input))]
+                            if node.op_type in {"MatMul", "MatMulInteger", "MatMulInteger16"} and (
+                                None in out_shape or ctx.is_shape_contains_none_dim(out_shape)
+                            ):
+                                if None in out_shape:
+                                    idx = out_shape.index(None)
+                                else:
+                                    idx = out_shape.index(ctx.is_shape_contains_none_dim(out_shape))
+                                dim_idx = [len(s) - len(out_shape) + idx for s in shapes]
+                                assert len(shapes[0]) > 2 and dim_idx[0] < len(shapes[0]) - 2
+                                assert len(shapes[1]) > 2 and dim_idx[1] < len(shapes[1]) - 2
+                        elif node.op_type == "Expand":
+                            shapes = [ctx.get_shape(node, 0), ctx.get_value(node, 1)]
+                        else:
+                            shapes = []
+
+                        if shapes:
+                            for idx in range(len(out_shape)):
+                                if out_shape[idx] is not None and not ctx.is_none_dim(out_shape[idx]):
+                                    continue
+                                dim_idx = [len(s) - len(out_shape) + idx for s in shapes]
+                                if dim_idx:
+                                    ctx.add_suggested_merge(
+                                        [s[i] if is_literal(s[i]) else str(s[i]) for s, i in zip(shapes, dim_idx) if i >= 0]
+                                    )
+                            ctx.run_ = True
+                        else:
+                            ctx.run_ = False
+                    else:
+                        ctx.run_ = False
+
+                    if not ctx.run_ and handler is None and not known_aten_op:
+                        is_unknown_op = out_type_undefined and (out_shape is None or len(out_shape) == 0)
+                        if is_unknown_op:
+                            out_rank = ctx.get_shape_rank(node, 0) if ctx.guess_output_rank_ else -1
+                        else:
+                            out_rank = len(out_shape)
+
+                        if out_rank >= 0:
+                            new_shape = ctx.new_symbolic_shape(out_rank, node, i_o)
+                            if out_type_undefined:
+                                out_dtype = ctx.known_vi_[node.input[0]].type.tensor_type.elem_type
+                            else:
+                                out_dtype = vi.type.tensor_type.elem_type
+                            from .utils import get_shape_from_sympy_shape
+
+                            vi.CopyFrom(
+                                helper.make_tensor_value_info(vi.name, out_dtype, get_shape_from_sympy_shape(new_shape))
+                            )
+
+                            if ctx.verbose_ > 0:
+                                if is_unknown_op:
+                                    logger.debug(f"Possible unknown op: {node.op_type} node: {node.name}, guessing {vi.name} shape")
+                                if ctx.verbose_ > 2:
+                                    logger.debug(f"  {node.output[i_o]}: {new_shape!s} {vi.type.tensor_type.elem_type}")
+                            ctx.run_ = True
+                            continue
+
+                    if ctx.verbose_ > 0 or not ctx.auto_merge_ or out_type_undefined:
+                        logger.debug("Stopping at incomplete shape inference at " + node.op_type + ": " + node.name)
+                        logger.debug("node inputs:")
+                        for i in node.input:
+                            if i in ctx.known_vi_:
+                                logger.debug(ctx.known_vi_[i])
+                            else:
+                                logger.debug(f"not in known_vi_ for {i}")
+                        logger.debug("node outputs:")
+                        for o in node.output:
+                            if o in ctx.known_vi_:
+                                logger.debug(ctx.known_vi_[o])
+                            else:
+                                logger.debug(f"not in known_vi_ for {o}")
+                        if ctx.auto_merge_ and not out_type_undefined:
+                            logger.debug("Merging: " + str(ctx.suggested_merge_))
+                    return False
+
+        ctx.run_ = False
+        return True
+
+    def _update_output_from_vi(self, ctx):
+        """Update output attributes using known value information dictionary."""
+        for output in ctx.out_mp_.graph.output:
+            if output.name in ctx.known_vi_:
+                output.CopyFrom(ctx.known_vi_[output.name])
+
+    @staticmethod
+    def infer_shapes(in_mp, int_max=2**31 - 1, auto_merge=False, guess_output_rank=False, verbose=0):
+        """Perform symbolic shape inference on an ONNX model.
+
+        Args:
+            in_mp: The input ONNX ModelProto.
+            int_max: Maximum value for unbounded integers.
+            auto_merge: Whether to automatically merge conflicting dimensions.
+            guess_output_rank: Whether to guess output rank from input.
+            verbose: Logging verbosity level.
+
+        Returns:
+            The model with inferred shapes.
+
+        Raises:
+            Exception: If shape inference is incomplete.
+        """
+        onnx_opset = get_opset(in_mp)
+        if (not onnx_opset) or onnx_opset < 7:
+            logger.warning("Only support models of onnx opset 7 and above.")
+            return None
+
+        inferencer = ShapeInferencer(int_max, auto_merge, guess_output_rank, verbose)
+
+        # Create inference context
+        ctx = InferenceContext(
+            in_mp,
+            int_max=int_max,
+            auto_merge=auto_merge,
+            guess_output_rank=guess_output_rank,
+            verbose=verbose,
+        )
+        ctx.preprocess()
+
+        all_shapes_inferred = False
+        while ctx.run_:
+            all_shapes_inferred = inferencer._infer_impl(ctx)
+
+        inferencer._update_output_from_vi(ctx)
+
+        if not all_shapes_inferred:
+            raise Exception("Incomplete symbolic shape inference")
+
+        return ctx.out_mp_
+
+
+# For backward compatibility
+SymbolicShapeInference = ShapeInferencer
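Per the module docstring, the static entry point is a drop-in for the old class (the file path below is a placeholder):

```python
import onnx

from onnxslim.core.shape_inference import ShapeInferencer

model = onnx.load("model.onnx")  # placeholder path
inferred = ShapeInferencer.infer_shapes(model, auto_merge=True)
if inferred is not None:         # infer_shapes returns None for opset < 7
    onnx.save(inferred, "model_with_shapes.onnx")
```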
onnxslim/core/shape_inference/aten_ops/__init__.py (new file)

@@ -0,0 +1,16 @@
+# Copyright (c) Microsoft Corporation. All rights reserved.
+# Licensed under the MIT License.
+
+"""PyTorch ATen operator shape handlers."""
+
+from . import bitwise_or
+from . import diagonal
+from . import pool2d
+from . import min_max
+from . import multinomial
+from . import unfold
+from . import argmax
+from . import group_norm
+from . import upsample
+from . import embedding
+from . import numpy_t