onnxslim 0.1.82__py3-none-any.whl → 0.1.83__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- onnxslim/core/optimization/dead_node_elimination.py +85 -4
- onnxslim/core/pattern/fusion/convadd.py +21 -1
- onnxslim/core/pattern/fusion/convbn.py +21 -4
- onnxslim/core/pattern/fusion/convmul.py +23 -5
- onnxslim/core/pattern/fusion/padconv.py +5 -0
- onnxslim/core/shape_inference/__init__.py +378 -0
- onnxslim/core/shape_inference/aten_ops/__init__.py +16 -0
- onnxslim/core/shape_inference/aten_ops/argmax.py +47 -0
- onnxslim/core/shape_inference/aten_ops/bitwise_or.py +28 -0
- onnxslim/core/shape_inference/aten_ops/diagonal.py +52 -0
- onnxslim/core/shape_inference/aten_ops/embedding.py +23 -0
- onnxslim/core/shape_inference/aten_ops/group_norm.py +41 -0
- onnxslim/core/shape_inference/aten_ops/min_max.py +64 -0
- onnxslim/core/shape_inference/aten_ops/multinomial.py +39 -0
- onnxslim/core/shape_inference/aten_ops/numpy_t.py +22 -0
- onnxslim/core/shape_inference/aten_ops/pool2d.py +40 -0
- onnxslim/core/shape_inference/aten_ops/unfold.py +44 -0
- onnxslim/core/shape_inference/aten_ops/upsample.py +44 -0
- onnxslim/core/shape_inference/base.py +111 -0
- onnxslim/core/shape_inference/context.py +645 -0
- onnxslim/core/shape_inference/contrib_ops/__init__.py +8 -0
- onnxslim/core/shape_inference/contrib_ops/attention/__init__.py +15 -0
- onnxslim/core/shape_inference/contrib_ops/attention/attention.py +61 -0
- onnxslim/core/shape_inference/contrib_ops/attention/decoder_masked_mha.py +37 -0
- onnxslim/core/shape_inference/contrib_ops/attention/gated_relative_position_bias.py +35 -0
- onnxslim/core/shape_inference/contrib_ops/attention/longformer_attention.py +21 -0
- onnxslim/core/shape_inference/contrib_ops/attention/multi_head_attention.py +82 -0
- onnxslim/core/shape_inference/contrib_ops/attention/multi_scale_deformable_attn.py +29 -0
- onnxslim/core/shape_inference/contrib_ops/attention/packed_attention.py +39 -0
- onnxslim/core/shape_inference/contrib_ops/attention/packed_multi_head_attention.py +33 -0
- onnxslim/core/shape_inference/contrib_ops/attention/remove_padding.py +41 -0
- onnxslim/core/shape_inference/contrib_ops/attention/restore_padding.py +29 -0
- onnxslim/core/shape_inference/contrib_ops/misc/__init__.py +15 -0
- onnxslim/core/shape_inference/contrib_ops/misc/bias_add.py +21 -0
- onnxslim/core/shape_inference/contrib_ops/misc/bias_gelu.py +21 -0
- onnxslim/core/shape_inference/contrib_ops/misc/bias_split_gelu.py +30 -0
- onnxslim/core/shape_inference/contrib_ops/misc/fast_gelu.py +21 -0
- onnxslim/core/shape_inference/contrib_ops/misc/gelu.py +21 -0
- onnxslim/core/shape_inference/contrib_ops/misc/gemm_fast_gelu.py +21 -0
- onnxslim/core/shape_inference/contrib_ops/misc/gemm_float8.py +21 -0
- onnxslim/core/shape_inference/contrib_ops/misc/python_op.py +67 -0
- onnxslim/core/shape_inference/contrib_ops/misc/quick_gelu.py +21 -0
- onnxslim/core/shape_inference/contrib_ops/misc/rotary_embedding.py +31 -0
- onnxslim/core/shape_inference/contrib_ops/normalization/__init__.py +12 -0
- onnxslim/core/shape_inference/contrib_ops/normalization/embed_layer_normalization.py +41 -0
- onnxslim/core/shape_inference/contrib_ops/normalization/group_norm.py +21 -0
- onnxslim/core/shape_inference/contrib_ops/normalization/layer_normalization.py +42 -0
- onnxslim/core/shape_inference/contrib_ops/normalization/simplified_layer_normalization.py +23 -0
- onnxslim/core/shape_inference/contrib_ops/normalization/skip_group_norm.py +23 -0
- onnxslim/core/shape_inference/contrib_ops/normalization/skip_layer_normalization.py +26 -0
- onnxslim/core/shape_inference/contrib_ops/normalization/skip_simplified_layer_normalization.py +23 -0
- onnxslim/core/shape_inference/registry.py +90 -0
- onnxslim/core/shape_inference/standard_ops/__init__.py +11 -0
- onnxslim/core/shape_inference/standard_ops/control_flow/__init__.py +8 -0
- onnxslim/core/shape_inference/standard_ops/control_flow/if_op.py +43 -0
- onnxslim/core/shape_inference/standard_ops/control_flow/loop.py +74 -0
- onnxslim/core/shape_inference/standard_ops/control_flow/scan.py +54 -0
- onnxslim/core/shape_inference/standard_ops/math/__init__.py +20 -0
- onnxslim/core/shape_inference/standard_ops/math/_symbolic_compute.py +34 -0
- onnxslim/core/shape_inference/standard_ops/math/add.py +10 -0
- onnxslim/core/shape_inference/standard_ops/math/div.py +10 -0
- onnxslim/core/shape_inference/standard_ops/math/einsum.py +119 -0
- onnxslim/core/shape_inference/standard_ops/math/equal.py +10 -0
- onnxslim/core/shape_inference/standard_ops/math/floor.py +10 -0
- onnxslim/core/shape_inference/standard_ops/math/matmul.py +21 -0
- onnxslim/core/shape_inference/standard_ops/math/matmul_integer.py +23 -0
- onnxslim/core/shape_inference/standard_ops/math/max.py +10 -0
- onnxslim/core/shape_inference/standard_ops/math/min.py +10 -0
- onnxslim/core/shape_inference/standard_ops/math/mul.py +10 -0
- onnxslim/core/shape_inference/standard_ops/math/neg.py +10 -0
- onnxslim/core/shape_inference/standard_ops/math/reduce_prod.py +27 -0
- onnxslim/core/shape_inference/standard_ops/math/reduce_sum.py +53 -0
- onnxslim/core/shape_inference/standard_ops/math/sub.py +10 -0
- onnxslim/core/shape_inference/standard_ops/math/where.py +10 -0
- onnxslim/core/shape_inference/standard_ops/misc/__init__.py +22 -0
- onnxslim/core/shape_inference/standard_ops/misc/array_feature_extractor.py +32 -0
- onnxslim/core/shape_inference/standard_ops/misc/cast.py +21 -0
- onnxslim/core/shape_inference/standard_ops/misc/category_mapper.py +30 -0
- onnxslim/core/shape_inference/standard_ops/misc/compress.py +39 -0
- onnxslim/core/shape_inference/standard_ops/misc/constant.py +27 -0
- onnxslim/core/shape_inference/standard_ops/misc/constant_of_shape.py +45 -0
- onnxslim/core/shape_inference/standard_ops/misc/dequantize_linear.py +26 -0
- onnxslim/core/shape_inference/standard_ops/misc/non_max_suppression.py +26 -0
- onnxslim/core/shape_inference/standard_ops/misc/non_zero.py +26 -0
- onnxslim/core/shape_inference/standard_ops/misc/one_hot.py +42 -0
- onnxslim/core/shape_inference/standard_ops/misc/quantize_linear.py +29 -0
- onnxslim/core/shape_inference/standard_ops/misc/range.py +41 -0
- onnxslim/core/shape_inference/standard_ops/misc/relative_position_bias.py +31 -0
- onnxslim/core/shape_inference/standard_ops/misc/resize.py +74 -0
- onnxslim/core/shape_inference/standard_ops/misc/scatter_elements.py +31 -0
- onnxslim/core/shape_inference/standard_ops/misc/softmax_cross_entropy_loss.py +44 -0
- onnxslim/core/shape_inference/standard_ops/misc/top_k.py +44 -0
- onnxslim/core/shape_inference/standard_ops/nn/__init__.py +18 -0
- onnxslim/core/shape_inference/standard_ops/nn/all_reduce.py +9 -0
- onnxslim/core/shape_inference/standard_ops/nn/average_pool.py +40 -0
- onnxslim/core/shape_inference/standard_ops/nn/batch_normalization.py +26 -0
- onnxslim/core/shape_inference/standard_ops/nn/conv.py +33 -0
- onnxslim/core/shape_inference/standard_ops/nn/cum_sum.py +9 -0
- onnxslim/core/shape_inference/standard_ops/nn/identity.py +9 -0
- onnxslim/core/shape_inference/standard_ops/nn/max_pool.py +9 -0
- onnxslim/core/shape_inference/standard_ops/nn/memcpy_from_host.py +9 -0
- onnxslim/core/shape_inference/standard_ops/nn/memcpy_to_host.py +9 -0
- onnxslim/core/shape_inference/standard_ops/nn/moe.py +9 -0
- onnxslim/core/shape_inference/standard_ops/nn/nhwc_conv.py +33 -0
- onnxslim/core/shape_inference/standard_ops/nn/reciprocal.py +9 -0
- onnxslim/core/shape_inference/standard_ops/nn/round.py +9 -0
- onnxslim/core/shape_inference/standard_ops/sequence/__init__.py +10 -0
- onnxslim/core/shape_inference/standard_ops/sequence/concat_from_sequence.py +40 -0
- onnxslim/core/shape_inference/standard_ops/sequence/sequence_at.py +31 -0
- onnxslim/core/shape_inference/standard_ops/sequence/sequence_insert.py +26 -0
- onnxslim/core/shape_inference/standard_ops/sequence/split_to_sequence.py +24 -0
- onnxslim/core/shape_inference/standard_ops/sequence/zip_map.py +36 -0
- onnxslim/core/shape_inference/standard_ops/tensor/__init__.py +20 -0
- onnxslim/core/shape_inference/standard_ops/tensor/concat.py +62 -0
- onnxslim/core/shape_inference/standard_ops/tensor/expand.py +36 -0
- onnxslim/core/shape_inference/standard_ops/tensor/gather.py +48 -0
- onnxslim/core/shape_inference/standard_ops/tensor/gather_elements.py +31 -0
- onnxslim/core/shape_inference/standard_ops/tensor/gather_nd.py +42 -0
- onnxslim/core/shape_inference/standard_ops/tensor/pad.py +41 -0
- onnxslim/core/shape_inference/standard_ops/tensor/reshape.py +72 -0
- onnxslim/core/shape_inference/standard_ops/tensor/shape.py +38 -0
- onnxslim/core/shape_inference/standard_ops/tensor/size.py +29 -0
- onnxslim/core/shape_inference/standard_ops/tensor/slice.py +183 -0
- onnxslim/core/shape_inference/standard_ops/tensor/split.py +57 -0
- onnxslim/core/shape_inference/standard_ops/tensor/squeeze.py +69 -0
- onnxslim/core/shape_inference/standard_ops/tensor/tile.py +41 -0
- onnxslim/core/shape_inference/standard_ops/tensor/transpose.py +30 -0
- onnxslim/core/shape_inference/standard_ops/tensor/unsqueeze.py +54 -0
- onnxslim/core/shape_inference/utils.py +244 -0
- onnxslim/third_party/symbolic_shape_infer.py +73 -3156
- onnxslim/utils.py +4 -2
- {onnxslim-0.1.82.dist-info → onnxslim-0.1.83.dist-info}/METADATA +21 -11
- onnxslim-0.1.83.dist-info/RECORD +187 -0
- onnxslim-0.1.82.dist-info/RECORD +0 -63
- {onnxslim-0.1.82.dist-info → onnxslim-0.1.83.dist-info}/WHEEL +0 -0
- {onnxslim-0.1.82.dist-info → onnxslim-0.1.83.dist-info}/entry_points.txt +0 -0
- {onnxslim-0.1.82.dist-info → onnxslim-0.1.83.dist-info}/licenses/LICENSE +0 -0
onnxslim/core/shape_inference/contrib_ops/misc/quick_gelu.py
ADDED
@@ -0,0 +1,21 @@
+# Copyright (c) Microsoft Corporation. All rights reserved.
+# Licensed under the MIT License.
+
+"""Shape handler for QuickGelu operator."""
+
+from ...base import ShapeHandler
+from ...registry import register_shape_handler
+
+
+class QuickGeluHandler(ShapeHandler):
+    """Handler for QuickGelu operator."""
+
+    @property
+    def op_type(self) -> str:
+        return "QuickGelu"
+
+    def infer_shape(self, node, ctx) -> None:
+        ctx.propagate_shape_and_type(node)
+
+
+register_shape_handler(QuickGeluHandler())
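The QuickGelu handler (like the other elementwise handlers below) delegates to ctx.propagate_shape_and_type, which lives in the new context.py (+645, not shown in this diff). A rough sketch of what that call presumably does, assuming ctx.known_vi_ maps tensor names to ValueInfoProto entries; the body here is an illustration, not the actual implementation:

    from onnx import helper

    def propagate_shape_and_type(ctx, node, input_index=0, output_index=0):
        # Copy the input's inferred element type and shape onto the output.
        in_vi = ctx.known_vi_[node.input[input_index]]
        elem_type = in_vi.type.tensor_type.elem_type
        # dim_param is "" when a dim is concrete, so this keeps symbolic names.
        shape = [d.dim_param or d.dim_value for d in in_vi.type.tensor_type.shape.dim]
        out_vi = ctx.known_vi_[node.output[output_index]]
        out_vi.CopyFrom(helper.make_tensor_value_info(node.output[output_index], elem_type, shape))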
onnxslim/core/shape_inference/contrib_ops/misc/rotary_embedding.py
ADDED
@@ -0,0 +1,31 @@
+# Copyright (c) Microsoft Corporation. All rights reserved.
+# Licensed under the MIT License.
+
+"""Shape handler for RotaryEmbedding operator."""
+
+from ...base import ShapeHandler
+from ...registry import register_shape_handler
+
+
+class RotaryEmbeddingHandler(ShapeHandler):
+    """Handler for RotaryEmbedding operator."""
+
+    @property
+    def op_type(self) -> str:
+        return "RotaryEmbedding"
+
+    def infer_shape(self, node, ctx) -> None:
+        if len(node.output) == 1:
+            ctx.propagate_shape_and_type(node)
+        elif len(node.output) == 2:
+            # Extraneous constant nodes outputted by RotaryEmbedding function made with `export_modules_as_functions`
+            ctx.propagate_shape_and_type(node, input_index=1, output_index=0)
+            ctx.propagate_shape_and_type(node, input_index=0, output_index=1)
+        elif len(node.output) == 3:
+            # Extraneous constant nodes outputted by RotaryEmbedding function made with `export_modules_as_functions`
+            ctx.propagate_shape_and_type(node, input_index=1, output_index=0)
+            ctx.propagate_shape_and_type(node, input_index=1, output_index=1)
+            ctx.propagate_shape_and_type(node, input_index=0, output_index=2)
+
+
+register_shape_handler(RotaryEmbeddingHandler())
onnxslim/core/shape_inference/contrib_ops/normalization/__init__.py
ADDED
@@ -0,0 +1,12 @@
+# Copyright (c) Microsoft Corporation. All rights reserved.
+# Licensed under the MIT License.
+
+"""Normalization-related contrib operator shape handlers."""
+
+from . import layer_normalization
+from . import simplified_layer_normalization
+from . import skip_layer_normalization
+from . import skip_simplified_layer_normalization
+from . import group_norm
+from . import skip_group_norm
+from . import embed_layer_normalization
onnxslim/core/shape_inference/contrib_ops/normalization/embed_layer_normalization.py
ADDED
@@ -0,0 +1,41 @@
+# Copyright (c) Microsoft Corporation. All rights reserved.
+# Licensed under the MIT License.
+
+"""Shape handler for EmbedLayerNormalization operator."""
+
+import onnx
+from onnx import helper
+
+from ...base import ShapeHandler
+from ...registry import register_shape_handler
+
+
+class EmbedLayerNormalizationHandler(ShapeHandler):
+    """Handler for EmbedLayerNormalization operator."""
+
+    @property
+    def op_type(self) -> str:
+        return "EmbedLayerNormalization"
+
+    def infer_shape(self, node, ctx) -> None:
+        input_ids_shape = ctx.get_shape(node, 0)
+        word_embedding_shape = ctx.get_shape(node, 2)
+        assert len(input_ids_shape) == 2 and len(word_embedding_shape) == 2
+        output_shape = [*input_ids_shape, word_embedding_shape[1]]
+
+        word_embedding_dtype = ctx.known_vi_[node.input[2]].type.tensor_type.elem_type
+        vi = ctx.known_vi_[node.output[0]]
+        vi.CopyFrom(helper.make_tensor_value_info(node.output[0], word_embedding_dtype, output_shape))
+
+        if len(node.output) > 1 and node.output[1]:
+            mask_index_shape = [input_ids_shape[0]]
+            vi = ctx.known_vi_[node.output[1]]
+            vi.CopyFrom(helper.make_tensor_value_info(node.output[1], onnx.TensorProto.INT32, mask_index_shape))
+
+        if len(node.output) > 2:
+            # Optional output of add before layer normalization is done
+            vi = ctx.known_vi_[node.output[2]]
+            vi.CopyFrom(helper.make_tensor_value_info(node.output[2], word_embedding_dtype, output_shape))
+
+
+register_shape_handler(EmbedLayerNormalizationHandler())
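Concretely: with input_ids of shape [batch, seq_len] and a word embedding table of shape [vocab, hidden], the rule above yields a [batch, seq_len, hidden] output and a per-row INT32 mask index. A quick check with illustrative values (not from the diff):

    input_ids_shape = ["batch", "seq_len"]    # ctx.get_shape(node, 0)
    word_embedding_shape = [30522, 768]       # ctx.get_shape(node, 2), e.g. a BERT vocab
    output_shape = [*input_ids_shape, word_embedding_shape[1]]
    print(output_shape)        # ['batch', 'seq_len', 768]
    mask_index_shape = [input_ids_shape[0]]
    print(mask_index_shape)    # ['batch']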
onnxslim/core/shape_inference/contrib_ops/normalization/group_norm.py
ADDED
@@ -0,0 +1,21 @@
+# Copyright (c) Microsoft Corporation. All rights reserved.
+# Licensed under the MIT License.
+
+"""Shape handler for GroupNorm operator."""
+
+from ...base import ShapeHandler
+from ...registry import register_shape_handler
+
+
+class GroupNormHandler(ShapeHandler):
+    """Handler for GroupNorm operator."""
+
+    @property
+    def op_type(self) -> str:
+        return "GroupNorm"
+
+    def infer_shape(self, node, ctx) -> None:
+        ctx.propagate_shape_and_type(node)
+
+
+register_shape_handler(GroupNormHandler())
onnxslim/core/shape_inference/contrib_ops/normalization/layer_normalization.py
ADDED
@@ -0,0 +1,42 @@
+# Copyright (c) Microsoft Corporation. All rights reserved.
+# Licensed under the MIT License.
+
+"""Shape handler for LayerNormalization operator."""
+
+import onnx
+from onnx import helper
+
+from ...base import ShapeHandler
+from ...registry import register_shape_handler
+from ...utils import get_attribute, handle_negative_axis
+
+
+class LayerNormalizationHandler(ShapeHandler):
+    """Handler for LayerNormalization operator."""
+
+    @property
+    def op_type(self) -> str:
+        return "LayerNormalization"
+
+    def infer_shape(self, node, ctx) -> None:
+        ctx.propagate_shape_and_type(node)
+        if len(node.output) > 1:
+            axis = get_attribute(node, "axis")
+            if axis is None:
+                axis = -1
+            x_shape = ctx.get_shape(node, 0)
+            if x_shape is not None:
+                rank = len(x_shape)
+                axis = handle_negative_axis(axis, rank)
+                mean_shape = x_shape[:axis] + [1 for _ in range(rank - axis)]
+                mean_dtype = ctx.known_vi_[node.input[0]].type.tensor_type.elem_type
+                if mean_dtype in {onnx.TensorProto.FLOAT16, onnx.TensorProto.BFLOAT16}:
+                    mean_dtype = onnx.TensorProto.FLOAT
+                vi = ctx.known_vi_[node.output[1]]
+                vi.CopyFrom(helper.make_tensor_value_info(node.output[1], mean_dtype, mean_shape))
+                if len(node.output) > 2:
+                    vi = ctx.known_vi_[node.output[2]]
+                    vi.CopyFrom(helper.make_tensor_value_info(node.output[2], mean_dtype, mean_shape))
+
+
+register_shape_handler(LayerNormalizationHandler())
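The optional mean/inv-std-dev outputs keep the leading dims up to axis and collapse the normalized dims to 1, with fp16/bf16 stats promoted to fp32. A worked example with illustrative shapes, using Python's modulo as a stand-in for handle_negative_axis:

    x_shape = ["batch", "seq_len", 768]    # rank 3 input
    axis = -1
    rank = len(x_shape)
    axis = axis % rank                     # -1 -> 2
    mean_shape = x_shape[:axis] + [1 for _ in range(rank - axis)]
    print(mean_shape)                      # ['batch', 'seq_len', 1]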
onnxslim/core/shape_inference/contrib_ops/normalization/simplified_layer_normalization.py
ADDED
@@ -0,0 +1,23 @@
+# Copyright (c) Microsoft Corporation. All rights reserved.
+# Licensed under the MIT License.
+
+"""Shape handler for SimplifiedLayerNormalization operator."""
+
+from ...base import ShapeHandler
+from ...registry import register_shape_handler
+from .layer_normalization import LayerNormalizationHandler
+
+
+class SimplifiedLayerNormalizationHandler(ShapeHandler):
+    """Handler for SimplifiedLayerNormalization operator."""
+
+    @property
+    def op_type(self) -> str:
+        return "SimplifiedLayerNormalization"
+
+    def infer_shape(self, node, ctx) -> None:
+        # Reuse LayerNormalization handler
+        LayerNormalizationHandler().infer_shape(node, ctx)
+
+
+register_shape_handler(SimplifiedLayerNormalizationHandler())
onnxslim/core/shape_inference/contrib_ops/normalization/skip_group_norm.py
ADDED
@@ -0,0 +1,23 @@
+# Copyright (c) Microsoft Corporation. All rights reserved.
+# Licensed under the MIT License.
+
+"""Shape handler for SkipGroupNorm operator."""
+
+from ...base import ShapeHandler
+from ...registry import register_shape_handler
+
+
+class SkipGroupNormHandler(ShapeHandler):
+    """Handler for SkipGroupNorm operator."""
+
+    @property
+    def op_type(self) -> str:
+        return "SkipGroupNorm"
+
+    def infer_shape(self, node, ctx) -> None:
+        ctx.propagate_shape_and_type(node, 0, 0)
+        if len(node.output) > 1:
+            ctx.propagate_shape_and_type(node, 0, 1)
+
+
+register_shape_handler(SkipGroupNormHandler())
onnxslim/core/shape_inference/contrib_ops/normalization/skip_layer_normalization.py
ADDED
@@ -0,0 +1,26 @@
+# Copyright (c) Microsoft Corporation. All rights reserved.
+# Licensed under the MIT License.
+
+"""Shape handler for SkipLayerNormalization operator."""
+
+from ...base import ShapeHandler
+from ...registry import register_shape_handler
+
+
+class SkipLayerNormalizationHandler(ShapeHandler):
+    """Handler for SkipLayerNormalization operator."""
+
+    @property
+    def op_type(self) -> str:
+        return "SkipLayerNormalization"
+
+    def infer_shape(self, node, ctx) -> None:
+        ctx.propagate_shape_and_type(node)
+
+        # If the SkipLayerNormalization node contains the optional
+        # output for inference, infer the shape and type for it too
+        if len(node.output) > 3:
+            ctx.propagate_shape_and_type(node, 0, 3)
+
+
+register_shape_handler(SkipLayerNormalizationHandler())
onnxslim/core/shape_inference/contrib_ops/normalization/skip_simplified_layer_normalization.py
ADDED
@@ -0,0 +1,23 @@
+# Copyright (c) Microsoft Corporation. All rights reserved.
+# Licensed under the MIT License.
+
+"""Shape handler for SkipSimplifiedLayerNormalization operator."""
+
+from ...base import ShapeHandler
+from ...registry import register_shape_handler
+from .skip_layer_normalization import SkipLayerNormalizationHandler
+
+
+class SkipSimplifiedLayerNormalizationHandler(ShapeHandler):
+    """Handler for SkipSimplifiedLayerNormalization operator."""
+
+    @property
+    def op_type(self) -> str:
+        return "SkipSimplifiedLayerNormalization"
+
+    def infer_shape(self, node, ctx) -> None:
+        # Reuse SkipLayerNormalization handler
+        SkipLayerNormalizationHandler().infer_shape(node, ctx)
+
+
+register_shape_handler(SkipSimplifiedLayerNormalizationHandler())
onnxslim/core/shape_inference/registry.py
ADDED
@@ -0,0 +1,90 @@
+# Copyright (c) Microsoft Corporation. All rights reserved.
+# Licensed under the MIT License.
+
+"""Registry for shape inference handlers."""
+
+from collections import OrderedDict
+
+# Global registries for shape handlers
+SHAPE_HANDLERS = OrderedDict()
+ATEN_SHAPE_HANDLERS = OrderedDict()
+
+
+def register_shape_handler(handler):
+    """Register a shape handler for a specific ONNX operator type.
+
+    Args:
+        handler: A ShapeHandler instance to register.
+
+    Returns:
+        The registered handler.
+
+    Raises:
+        ValueError: If a handler for the same op_type is already registered.
+    """
+    op_type = handler.op_type
+    if op_type in SHAPE_HANDLERS:
+        raise ValueError(f"Handler for op_type '{op_type}' is already registered")
+    SHAPE_HANDLERS[op_type] = handler
+    return handler
+
+
+def register_aten_handler(handler):
+    """Register a shape handler for a PyTorch ATen operator.
+
+    Args:
+        handler: A ShapeHandler instance to register.
+
+    Returns:
+        The registered handler.
+
+    Raises:
+        ValueError: If a handler for the same op_name is already registered.
+    """
+    op_name = handler.op_type
+    if op_name in ATEN_SHAPE_HANDLERS:
+        raise ValueError(f"Handler for ATen op '{op_name}' is already registered")
+    ATEN_SHAPE_HANDLERS[op_name] = handler
+    return handler
+
+
+def get_shape_handler(op_type):
+    """Get the shape handler for a given ONNX operator type.
+
+    Args:
+        op_type: The ONNX operator type string.
+
+    Returns:
+        The registered ShapeHandler or None if not found.
+    """
+    return SHAPE_HANDLERS.get(op_type)
+
+
+def get_aten_handler(op_name):
+    """Get the shape handler for a given ATen operator name.
+
+    Args:
+        op_name: The ATen operator name string.
+
+    Returns:
+        The registered ShapeHandler or None if not found.
+    """
+    return ATEN_SHAPE_HANDLERS.get(op_name)
+
+
+def get_all_shape_handlers():
+    """Get all registered shape handlers.
+
+    Returns:
+        OrderedDict of all registered shape handlers.
+    """
+    return SHAPE_HANDLERS.copy()
+
+
+def get_all_aten_handlers():
+    """Get all registered ATen shape handlers.
+
+    Returns:
+        OrderedDict of all registered ATen shape handlers.
+    """
+    return ATEN_SHAPE_HANDLERS.copy()
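Handler modules self-register at import time, so supporting a new op means defining a handler and calling register_shape_handler once. A sketch for a hypothetical elementwise op ("MyGelu" is made up for illustration; the registration API is as shown in the diff):

    from onnxslim.core.shape_inference.base import ShapeHandler
    from onnxslim.core.shape_inference.registry import get_shape_handler, register_shape_handler

    class MyGeluHandler(ShapeHandler):
        """Hypothetical handler for an elementwise 'MyGelu' op."""

        @property
        def op_type(self) -> str:
            return "MyGelu"

        def infer_shape(self, node, ctx) -> None:
            ctx.propagate_shape_and_type(node)  # output mirrors input 0

    register_shape_handler(MyGeluHandler())
    assert get_shape_handler("MyGelu") is not None
    # Registering the same op_type a second time raises ValueError.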
onnxslim/core/shape_inference/standard_ops/__init__.py
ADDED
@@ -0,0 +1,11 @@
+# Copyright (c) Microsoft Corporation. All rights reserved.
+# Licensed under the MIT License.
+
+"""Standard ONNX operator shape handlers."""
+
+from . import tensor
+from . import math
+from . import nn
+from . import control_flow
+from . import sequence
+from . import misc
onnxslim/core/shape_inference/standard_ops/control_flow/if_op.py
ADDED
@@ -0,0 +1,43 @@
+# Copyright (c) Microsoft Corporation. All rights reserved.
+# Licensed under the MIT License.
+
+"""Shape handler for If operator."""
+
+from ...base import ShapeHandler
+from ...registry import register_shape_handler
+from ...utils import as_scalar, get_attribute
+
+
+class IfHandler(ShapeHandler):
+    """Handler for If operator."""
+
+    @property
+    def op_type(self) -> str:
+        return "If"
+
+    def infer_shape(self, node, ctx) -> None:
+        subgraphs = [
+            get_attribute(node, "then_branch"),
+            get_attribute(node, "else_branch"),
+        ]
+
+        cond = ctx.try_get_value(node, 0)
+
+        for i_sub, subgraph in enumerate(subgraphs):
+            subgraph_infer = ctx.onnx_infer_subgraph(node, subgraph, use_node_input=False)
+            for i_out in range(len(node.output)):
+                vi = ctx.known_vi_[node.output[i_out]]
+                if i_sub == 0:
+                    vi.CopyFrom(subgraph.output[i_out])
+                    vi.name = node.output[i_out]
+                else:
+                    ctx.fuse_tensor_type(node, i_out, vi.type, subgraph.output[i_out].type)
+                if (
+                    cond is not None
+                    and i_sub == (0 if as_scalar(cond) > 0 else 1)
+                    and subgraph.output[i_out].name in subgraph_infer.sympy_data_
+                ):
+                    ctx.sympy_data_[vi.name] = subgraph_infer.sympy_data_[subgraph.output[i_out].name]
+
+
+register_shape_handler(IfHandler())
onnxslim/core/shape_inference/standard_ops/control_flow/loop.py
ADDED
@@ -0,0 +1,74 @@
+# Copyright (c) Microsoft Corporation. All rights reserved.
+# Licensed under the MIT License.
+
+"""Shape handler for Loop operator."""
+
+import logging
+
+import onnx
+
+from ...base import ShapeHandler
+from ...registry import register_shape_handler
+from ...utils import get_attribute, get_shape_from_value_info, is_sequence
+
+logger = logging.getLogger(__name__)
+
+
+class LoopHandler(ShapeHandler):
+    """Handler for Loop operator."""
+
+    @property
+    def op_type(self) -> str:
+        return "Loop"
+
+    def infer_shape(self, node, ctx) -> None:
+        subgraph = get_attribute(node, "body")
+        assert len(subgraph.input) == len(node.input)
+        num_loop_carried = len(node.input) - 2
+
+        for i, si in enumerate(subgraph.input):
+            si_name = si.name
+            si.CopyFrom(ctx.known_vi_[node.input[i]])
+            si.name = si_name
+
+        ctx.onnx_infer_subgraph(node, subgraph)
+
+        need_second_infer = False
+        for i_out in range(1, num_loop_carried + 1):
+            so = subgraph.output[i_out]
+            so_shape = get_shape_from_value_info(so)
+            if is_sequence(so.type):
+                if so_shape and None in so_shape:
+                    subgraph.input[i_out + 1].type.sequence_type.elem_type.CopyFrom(so.type.sequence_type.elem_type)
+                    need_second_infer = True
+            else:
+                si = subgraph.input[i_out + 1]
+                si_shape = get_shape_from_value_info(si)
+                for di, dims in enumerate(zip(si_shape, so_shape)):
+                    if dims[0] != dims[1]:
+                        new_dim = onnx.TensorShapeProto.Dimension()
+                        new_dim.dim_param = str(ctx.new_symbolic_dim_from_output(node, i_out, di))
+                        si.type.tensor_type.shape.dim[di].CopyFrom(new_dim)
+                        so.type.tensor_type.shape.dim[di].CopyFrom(new_dim)
+                        need_second_infer = True
+
+        if need_second_infer:
+            if ctx.verbose_ > 2:
+                logger.debug(f"Rerun Loop: {node.name}({node.output[0]}...), because of sequence in loop carried variables")
+            ctx.onnx_infer_subgraph(node, subgraph, inc_subgraph_id=False)
+
+        loop_iter_dim = str(ctx.new_symbolic_dim_from_output(node))
+        for i in range(len(node.output)):
+            vi = ctx.known_vi_[node.output[i]]
+            vi.CopyFrom(subgraph.output[i + 1])
+            if i >= num_loop_carried:
+                assert not is_sequence(vi.type)
+                subgraph_vi_dim = subgraph.output[i + 1].type.tensor_type.shape.dim
+                vi.type.tensor_type.shape.ClearField("dim")
+                vi_dim = vi.type.tensor_type.shape.dim
+                vi_dim.add().dim_param = loop_iter_dim
+                vi_dim.extend(list(subgraph_vi_dim))
+            vi.name = node.output[i]
+
+
+register_shape_handler(LoopHandler())
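The net effect for Loop outputs: loop-carried outputs keep the body's per-iteration shape, while scan outputs gain a fresh symbolic iteration dimension in front. A small check of that rule with illustrative shapes ('loop_iter_0' stands in for the symbolic dim minted by ctx.new_symbolic_dim_from_output):

    num_loop_carried = 1
    body_output_shapes = [["cond"], ["batch", 768], ["batch", 10]]  # cond, carried, scan
    loop_iter_dim = "loop_iter_0"
    node_output_shapes = [
        shape if i < num_loop_carried else [loop_iter_dim, *shape]
        for i, shape in enumerate(body_output_shapes[1:])  # body output 0 is the condition
    ]
    print(node_output_shapes)  # [['batch', 768], ['loop_iter_0', 'batch', 10]]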
onnxslim/core/shape_inference/standard_ops/control_flow/scan.py
ADDED
@@ -0,0 +1,54 @@
+# Copyright (c) Microsoft Corporation. All rights reserved.
+# Licensed under the MIT License.
+
+"""Shape handler for Scan operator."""
+
+from onnx import helper
+
+from ...base import ShapeHandler
+from ...registry import register_shape_handler
+from ...utils import get_attribute, get_shape_from_type_proto, handle_negative_axis
+
+
+class ScanHandler(ShapeHandler):
+    """Handler for Scan operator."""
+
+    @property
+    def op_type(self) -> str:
+        return "Scan"
+
+    def infer_shape(self, node, ctx) -> None:
+        subgraph = get_attribute(node, "body")
+        num_scan_inputs = get_attribute(node, "num_scan_inputs")
+        scan_input_axes = get_attribute(node, "scan_input_axes", [0] * num_scan_inputs)
+        num_scan_states = len(node.input) - num_scan_inputs
+        scan_input_axes = [
+            handle_negative_axis(ax, ctx.get_shape_rank(node, i + num_scan_states)) for i, ax in enumerate(scan_input_axes)
+        ]
+
+        assert len(subgraph.input) >= len(node.input)
+        subgraph_inputs = subgraph.input[: len(node.input)]
+        for i, si in enumerate(subgraph_inputs):
+            subgraph_name = si.name
+            si.CopyFrom(ctx.known_vi_[node.input[i]])
+            if i >= num_scan_states:
+                scan_input_dim = si.type.tensor_type.shape.dim
+                scan_input_dim.remove(scan_input_dim[scan_input_axes[i - num_scan_states]])
+            si.name = subgraph_name
+        ctx.onnx_infer_subgraph(node, subgraph)
+        num_scan_outputs = len(node.output) - num_scan_states
+        scan_output_axes = get_attribute(node, "scan_output_axes", [0] * num_scan_outputs)
+        scan_input_dim = get_shape_from_type_proto(ctx.known_vi_[node.input[-1]].type)[scan_input_axes[-1]]
+        for i, o in enumerate(node.output):
+            vi = ctx.known_vi_[o]
+            if i >= num_scan_states:
+                shape = get_shape_from_type_proto(subgraph.output[i].type)
+                new_dim = handle_negative_axis(scan_output_axes[i - num_scan_states], len(shape) + 1)
+                shape = [*shape[:new_dim], scan_input_dim, *shape[new_dim:]]
+                vi.CopyFrom(helper.make_tensor_value_info(o, subgraph.output[i].type.tensor_type.elem_type, shape))
+            else:
+                vi.CopyFrom(subgraph.output[i])
+            vi.name = o
+
+
+register_shape_handler(ScanHandler())
onnxslim/core/shape_inference/standard_ops/math/__init__.py
ADDED
@@ -0,0 +1,20 @@
+# Copyright (c) Microsoft Corporation. All rights reserved.
+# Licensed under the MIT License.
+
+"""Mathematical operator shape handlers."""
+
+from . import matmul
+from . import matmul_integer
+from . import einsum
+from . import reduce_sum
+from . import reduce_prod
+from . import add
+from . import sub
+from . import mul
+from . import div
+from . import neg
+from . import floor
+from . import min
+from . import max
+from . import equal
+from . import where
onnxslim/core/shape_inference/standard_ops/math/_symbolic_compute.py
ADDED
@@ -0,0 +1,34 @@
+# Copyright (c) Microsoft Corporation. All rights reserved.
+# Licensed under the MIT License.
+
+"""Shared symbolic computation helper for math operators."""
+
+import sympy
+
+from ...utils import is_literal
+
+
+def infer_symbolic_compute_ops(node, ctx):
+    """Handles symbolic computation operations for given node based on predefined functions."""
+    funcs = {
+        "Add": lambda l: l[0] + l[1],
+        "Div": lambda l: (int(l[0] // l[1]) if isinstance(l[0] // l[1], float) else l[0] // l[1]),
+        "Equal": lambda l: l[0] == l[1],
+        "Floor": lambda l: sympy.floor(l[0]),
+        "Max": lambda l: (
+            l[1]
+            if is_literal(l[0]) and int(l[0]) < -ctx.int_max_
+            else (l[0] if is_literal(l[1]) and int(l[1]) < -ctx.int_max_ else sympy.Max(l[0], l[1]))
+        ),
+        "Min": lambda l: (
+            l[1]
+            if is_literal(l[0]) and int(l[0]) > ctx.int_max_
+            else (l[0] if is_literal(l[1]) and int(l[1]) > ctx.int_max_ else sympy.Min(l[0], l[1]))
+        ),
+        "Mul": lambda l: (int(l[0] * l[1]) if isinstance(l[0] * l[1], float) else l[0] * l[1]),
+        "Sub": lambda l: l[0] - l[1],
+        "Where": lambda l: l[1] if l[0] else l[2],
+        "Neg": lambda l: -l[0],
+    }
+    assert node.op_type in funcs
+    ctx.compute_on_sympy_data(node, funcs[node.op_type])
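These lambdas operate on sympy expressions tracked per tensor, which is how integer arithmetic on shape tensors stays symbolic through ops like Div. A quick standalone check of the Div rule (lambda copied from the hunk above; the symbol name is illustrative):

    import sympy

    batch = sympy.Symbol("batch", positive=True, integer=True)
    div = lambda l: (int(l[0] // l[1]) if isinstance(l[0] // l[1], float) else l[0] // l[1])
    print(div([batch * 8, 2]))   # 4*batch -- stays a sympy expression
    print(div([10, 4]))          # 2      -- plain ints floor-divide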
onnxslim/core/shape_inference/standard_ops/math/add.py
ADDED
@@ -0,0 +1,10 @@
+# Copyright (c) Microsoft Corporation. All rights reserved.
+# Licensed under the MIT License.
+
+"""Shape handler for Add operator."""
+
+from ...base import MultiOpHandler
+from ...registry import register_shape_handler
+from ._symbolic_compute import infer_symbolic_compute_ops
+
+register_shape_handler(MultiOpHandler("Add", infer_symbolic_compute_ops))
onnxslim/core/shape_inference/standard_ops/math/div.py
ADDED
@@ -0,0 +1,10 @@
+# Copyright (c) Microsoft Corporation. All rights reserved.
+# Licensed under the MIT License.
+
+"""Shape handler for Div operator."""
+
+from ...base import MultiOpHandler
+from ...registry import register_shape_handler
+from ._symbolic_compute import infer_symbolic_compute_ops
+
+register_shape_handler(MultiOpHandler("Div", infer_symbolic_compute_ops))
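The Add/Div (and sibling Sub/Mul/Neg/Floor/Min/Max/Equal/Where) modules share one inference function via MultiOpHandler instead of a class per op. MultiOpHandler itself lives in the new base.py (+111, not shown in this diff); given how it is called and what the registry expects (an op_type property and an infer_shape method), it is presumably something like this sketch:

    class MultiOpHandler:
        """Sketch of the assumed API: wrap a shared infer function under one op_type."""

        def __init__(self, op_type, infer_fn):
            self._op_type = op_type
            self._infer_fn = infer_fn

        @property
        def op_type(self) -> str:
            return self._op_type

        def infer_shape(self, node, ctx) -> None:
            # Dispatch to the shared function, which switches on node.op_type.
            self._infer_fn(node, ctx)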