onnxslim 0.1.82__py3-none-any.whl → 0.1.83__py3-none-any.whl
This diff shows the changes between two publicly released versions of the package, as they appear in their public registry, and is provided for informational purposes only.
- onnxslim/core/optimization/dead_node_elimination.py +85 -4
- onnxslim/core/pattern/fusion/convadd.py +21 -1
- onnxslim/core/pattern/fusion/convbn.py +21 -4
- onnxslim/core/pattern/fusion/convmul.py +23 -5
- onnxslim/core/pattern/fusion/padconv.py +5 -0
- onnxslim/core/shape_inference/__init__.py +378 -0
- onnxslim/core/shape_inference/aten_ops/__init__.py +16 -0
- onnxslim/core/shape_inference/aten_ops/argmax.py +47 -0
- onnxslim/core/shape_inference/aten_ops/bitwise_or.py +28 -0
- onnxslim/core/shape_inference/aten_ops/diagonal.py +52 -0
- onnxslim/core/shape_inference/aten_ops/embedding.py +23 -0
- onnxslim/core/shape_inference/aten_ops/group_norm.py +41 -0
- onnxslim/core/shape_inference/aten_ops/min_max.py +64 -0
- onnxslim/core/shape_inference/aten_ops/multinomial.py +39 -0
- onnxslim/core/shape_inference/aten_ops/numpy_t.py +22 -0
- onnxslim/core/shape_inference/aten_ops/pool2d.py +40 -0
- onnxslim/core/shape_inference/aten_ops/unfold.py +44 -0
- onnxslim/core/shape_inference/aten_ops/upsample.py +44 -0
- onnxslim/core/shape_inference/base.py +111 -0
- onnxslim/core/shape_inference/context.py +645 -0
- onnxslim/core/shape_inference/contrib_ops/__init__.py +8 -0
- onnxslim/core/shape_inference/contrib_ops/attention/__init__.py +15 -0
- onnxslim/core/shape_inference/contrib_ops/attention/attention.py +61 -0
- onnxslim/core/shape_inference/contrib_ops/attention/decoder_masked_mha.py +37 -0
- onnxslim/core/shape_inference/contrib_ops/attention/gated_relative_position_bias.py +35 -0
- onnxslim/core/shape_inference/contrib_ops/attention/longformer_attention.py +21 -0
- onnxslim/core/shape_inference/contrib_ops/attention/multi_head_attention.py +82 -0
- onnxslim/core/shape_inference/contrib_ops/attention/multi_scale_deformable_attn.py +29 -0
- onnxslim/core/shape_inference/contrib_ops/attention/packed_attention.py +39 -0
- onnxslim/core/shape_inference/contrib_ops/attention/packed_multi_head_attention.py +33 -0
- onnxslim/core/shape_inference/contrib_ops/attention/remove_padding.py +41 -0
- onnxslim/core/shape_inference/contrib_ops/attention/restore_padding.py +29 -0
- onnxslim/core/shape_inference/contrib_ops/misc/__init__.py +15 -0
- onnxslim/core/shape_inference/contrib_ops/misc/bias_add.py +21 -0
- onnxslim/core/shape_inference/contrib_ops/misc/bias_gelu.py +21 -0
- onnxslim/core/shape_inference/contrib_ops/misc/bias_split_gelu.py +30 -0
- onnxslim/core/shape_inference/contrib_ops/misc/fast_gelu.py +21 -0
- onnxslim/core/shape_inference/contrib_ops/misc/gelu.py +21 -0
- onnxslim/core/shape_inference/contrib_ops/misc/gemm_fast_gelu.py +21 -0
- onnxslim/core/shape_inference/contrib_ops/misc/gemm_float8.py +21 -0
- onnxslim/core/shape_inference/contrib_ops/misc/python_op.py +67 -0
- onnxslim/core/shape_inference/contrib_ops/misc/quick_gelu.py +21 -0
- onnxslim/core/shape_inference/contrib_ops/misc/rotary_embedding.py +31 -0
- onnxslim/core/shape_inference/contrib_ops/normalization/__init__.py +12 -0
- onnxslim/core/shape_inference/contrib_ops/normalization/embed_layer_normalization.py +41 -0
- onnxslim/core/shape_inference/contrib_ops/normalization/group_norm.py +21 -0
- onnxslim/core/shape_inference/contrib_ops/normalization/layer_normalization.py +42 -0
- onnxslim/core/shape_inference/contrib_ops/normalization/simplified_layer_normalization.py +23 -0
- onnxslim/core/shape_inference/contrib_ops/normalization/skip_group_norm.py +23 -0
- onnxslim/core/shape_inference/contrib_ops/normalization/skip_layer_normalization.py +26 -0
- onnxslim/core/shape_inference/contrib_ops/normalization/skip_simplified_layer_normalization.py +23 -0
- onnxslim/core/shape_inference/registry.py +90 -0
- onnxslim/core/shape_inference/standard_ops/__init__.py +11 -0
- onnxslim/core/shape_inference/standard_ops/control_flow/__init__.py +8 -0
- onnxslim/core/shape_inference/standard_ops/control_flow/if_op.py +43 -0
- onnxslim/core/shape_inference/standard_ops/control_flow/loop.py +74 -0
- onnxslim/core/shape_inference/standard_ops/control_flow/scan.py +54 -0
- onnxslim/core/shape_inference/standard_ops/math/__init__.py +20 -0
- onnxslim/core/shape_inference/standard_ops/math/_symbolic_compute.py +34 -0
- onnxslim/core/shape_inference/standard_ops/math/add.py +10 -0
- onnxslim/core/shape_inference/standard_ops/math/div.py +10 -0
- onnxslim/core/shape_inference/standard_ops/math/einsum.py +119 -0
- onnxslim/core/shape_inference/standard_ops/math/equal.py +10 -0
- onnxslim/core/shape_inference/standard_ops/math/floor.py +10 -0
- onnxslim/core/shape_inference/standard_ops/math/matmul.py +21 -0
- onnxslim/core/shape_inference/standard_ops/math/matmul_integer.py +23 -0
- onnxslim/core/shape_inference/standard_ops/math/max.py +10 -0
- onnxslim/core/shape_inference/standard_ops/math/min.py +10 -0
- onnxslim/core/shape_inference/standard_ops/math/mul.py +10 -0
- onnxslim/core/shape_inference/standard_ops/math/neg.py +10 -0
- onnxslim/core/shape_inference/standard_ops/math/reduce_prod.py +27 -0
- onnxslim/core/shape_inference/standard_ops/math/reduce_sum.py +53 -0
- onnxslim/core/shape_inference/standard_ops/math/sub.py +10 -0
- onnxslim/core/shape_inference/standard_ops/math/where.py +10 -0
- onnxslim/core/shape_inference/standard_ops/misc/__init__.py +22 -0
- onnxslim/core/shape_inference/standard_ops/misc/array_feature_extractor.py +32 -0
- onnxslim/core/shape_inference/standard_ops/misc/cast.py +21 -0
- onnxslim/core/shape_inference/standard_ops/misc/category_mapper.py +30 -0
- onnxslim/core/shape_inference/standard_ops/misc/compress.py +39 -0
- onnxslim/core/shape_inference/standard_ops/misc/constant.py +27 -0
- onnxslim/core/shape_inference/standard_ops/misc/constant_of_shape.py +45 -0
- onnxslim/core/shape_inference/standard_ops/misc/dequantize_linear.py +26 -0
- onnxslim/core/shape_inference/standard_ops/misc/non_max_suppression.py +26 -0
- onnxslim/core/shape_inference/standard_ops/misc/non_zero.py +26 -0
- onnxslim/core/shape_inference/standard_ops/misc/one_hot.py +42 -0
- onnxslim/core/shape_inference/standard_ops/misc/quantize_linear.py +29 -0
- onnxslim/core/shape_inference/standard_ops/misc/range.py +41 -0
- onnxslim/core/shape_inference/standard_ops/misc/relative_position_bias.py +31 -0
- onnxslim/core/shape_inference/standard_ops/misc/resize.py +74 -0
- onnxslim/core/shape_inference/standard_ops/misc/scatter_elements.py +31 -0
- onnxslim/core/shape_inference/standard_ops/misc/softmax_cross_entropy_loss.py +44 -0
- onnxslim/core/shape_inference/standard_ops/misc/top_k.py +44 -0
- onnxslim/core/shape_inference/standard_ops/nn/__init__.py +18 -0
- onnxslim/core/shape_inference/standard_ops/nn/all_reduce.py +9 -0
- onnxslim/core/shape_inference/standard_ops/nn/average_pool.py +40 -0
- onnxslim/core/shape_inference/standard_ops/nn/batch_normalization.py +26 -0
- onnxslim/core/shape_inference/standard_ops/nn/conv.py +33 -0
- onnxslim/core/shape_inference/standard_ops/nn/cum_sum.py +9 -0
- onnxslim/core/shape_inference/standard_ops/nn/identity.py +9 -0
- onnxslim/core/shape_inference/standard_ops/nn/max_pool.py +9 -0
- onnxslim/core/shape_inference/standard_ops/nn/memcpy_from_host.py +9 -0
- onnxslim/core/shape_inference/standard_ops/nn/memcpy_to_host.py +9 -0
- onnxslim/core/shape_inference/standard_ops/nn/moe.py +9 -0
- onnxslim/core/shape_inference/standard_ops/nn/nhwc_conv.py +33 -0
- onnxslim/core/shape_inference/standard_ops/nn/reciprocal.py +9 -0
- onnxslim/core/shape_inference/standard_ops/nn/round.py +9 -0
- onnxslim/core/shape_inference/standard_ops/sequence/__init__.py +10 -0
- onnxslim/core/shape_inference/standard_ops/sequence/concat_from_sequence.py +40 -0
- onnxslim/core/shape_inference/standard_ops/sequence/sequence_at.py +31 -0
- onnxslim/core/shape_inference/standard_ops/sequence/sequence_insert.py +26 -0
- onnxslim/core/shape_inference/standard_ops/sequence/split_to_sequence.py +24 -0
- onnxslim/core/shape_inference/standard_ops/sequence/zip_map.py +36 -0
- onnxslim/core/shape_inference/standard_ops/tensor/__init__.py +20 -0
- onnxslim/core/shape_inference/standard_ops/tensor/concat.py +62 -0
- onnxslim/core/shape_inference/standard_ops/tensor/expand.py +36 -0
- onnxslim/core/shape_inference/standard_ops/tensor/gather.py +48 -0
- onnxslim/core/shape_inference/standard_ops/tensor/gather_elements.py +31 -0
- onnxslim/core/shape_inference/standard_ops/tensor/gather_nd.py +42 -0
- onnxslim/core/shape_inference/standard_ops/tensor/pad.py +41 -0
- onnxslim/core/shape_inference/standard_ops/tensor/reshape.py +72 -0
- onnxslim/core/shape_inference/standard_ops/tensor/shape.py +38 -0
- onnxslim/core/shape_inference/standard_ops/tensor/size.py +29 -0
- onnxslim/core/shape_inference/standard_ops/tensor/slice.py +183 -0
- onnxslim/core/shape_inference/standard_ops/tensor/split.py +57 -0
- onnxslim/core/shape_inference/standard_ops/tensor/squeeze.py +69 -0
- onnxslim/core/shape_inference/standard_ops/tensor/tile.py +41 -0
- onnxslim/core/shape_inference/standard_ops/tensor/transpose.py +30 -0
- onnxslim/core/shape_inference/standard_ops/tensor/unsqueeze.py +54 -0
- onnxslim/core/shape_inference/utils.py +244 -0
- onnxslim/third_party/symbolic_shape_infer.py +73 -3156
- onnxslim/utils.py +4 -2
- {onnxslim-0.1.82.dist-info → onnxslim-0.1.83.dist-info}/METADATA +21 -11
- onnxslim-0.1.83.dist-info/RECORD +187 -0
- onnxslim-0.1.82.dist-info/RECORD +0 -63
- {onnxslim-0.1.82.dist-info → onnxslim-0.1.83.dist-info}/WHEEL +0 -0
- {onnxslim-0.1.82.dist-info → onnxslim-0.1.83.dist-info}/entry_points.txt +0 -0
- {onnxslim-0.1.82.dist-info → onnxslim-0.1.83.dist-info}/licenses/LICENSE +0 -0
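
The headline change is that the monolithic onnxslim/third_party/symbolic_shape_infer.py (-3,156 lines) has been decomposed into per-operator modules under onnxslim/core/shape_inference/. Every hunk below follows the same pattern: subclass ShapeHandler, expose an op_type property, implement infer_shape(node, ctx), and register a singleton via register_shape_handler at import time. The new base.py and registry.py are not reproduced in this excerpt, so the following is only a minimal sketch of what they plausibly provide, inferred from the handler modules shown; the registry storage and the get_shape_handler lookup are assumptions.

# Sketch only: inferred from the handler modules in this diff; the actual
# base.py / registry.py in 0.1.83 may differ in details.
from abc import ABC, abstractmethod

_handlers = {}  # op_type -> handler instance (assumed registry storage)


class ShapeHandler(ABC):
    """One subclass per ONNX operator."""

    @property
    @abstractmethod
    def op_type(self) -> str:
        """Operator name this handler serves, e.g. "Concat"."""

    @abstractmethod
    def infer_shape(self, node, ctx) -> None:
        """Write inferred value_info for node's outputs into ctx.known_vi_."""


def register_shape_handler(handler):
    """Called at module import time; importing a subpackage registers its ops."""
    _handlers[handler.op_type] = handler


def get_shape_handler(op_type):  # assumed lookup used by the inference loop
    return _handlers.get(op_type)
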
onnxslim/core/shape_inference/standard_ops/sequence/sequence_at.py
@@ -0,0 +1,31 @@
+# Copyright (c) Microsoft Corporation. All rights reserved.
+# Licensed under the MIT License.
+
+"""Shape handler for SequenceAt operator."""
+
+import onnx
+
+from ...base import ShapeHandler
+from ...registry import register_shape_handler
+
+
+class SequenceAtHandler(ShapeHandler):
+    """Handler for SequenceAt operator."""
+
+    @property
+    def op_type(self) -> str:
+        return "SequenceAt"
+
+    def infer_shape(self, node, ctx) -> None:
+        seq_shape = ctx.get_shape(node, 0)
+        if seq_shape is not None:
+            vi = ctx.known_vi_[node.output[0]]
+            for di, d in enumerate(seq_shape):
+                if d is not None:
+                    continue
+                new_dim = onnx.TensorShapeProto.Dimension()
+                new_dim.dim_param = str(ctx.new_symbolic_dim_from_output(node, 0, di))
+                vi.type.tensor_type.shape.dim[di].CopyFrom(new_dim)
+
+
+register_shape_handler(SequenceAtHandler())
onnxslim/core/shape_inference/standard_ops/sequence/sequence_insert.py
@@ -0,0 +1,26 @@
+# Copyright (c) Microsoft Corporation. All rights reserved.
+# Licensed under the MIT License.
+
+"""Shape handler for SequenceInsert operator."""
+
+from ...base import ShapeHandler
+from ...registry import register_shape_handler
+
+
+class SequenceInsertHandler(ShapeHandler):
+    """Handler for SequenceInsert operator."""
+
+    @property
+    def op_type(self) -> str:
+        return "SequenceInsert"
+
+    def infer_shape(self, node, ctx) -> None:
+        vi_seq = ctx.known_vi_[node.input[0]]
+        vi_tensor = ctx.known_vi_[node.input[1]]
+        vi_out_seq = ctx.known_vi_[node.output[0]]
+        vi_out_seq.CopyFrom(vi_seq)
+        vi_out_seq.name = node.output[0]
+        ctx.fuse_tensor_type(node, 0, vi_out_seq.type, vi_tensor.type)
+
+
+register_shape_handler(SequenceInsertHandler())
onnxslim/core/shape_inference/standard_ops/sequence/split_to_sequence.py
@@ -0,0 +1,24 @@
+# Copyright (c) Microsoft Corporation. All rights reserved.
+# Licensed under the MIT License.
+
+"""Shape handler for SplitToSequence operator."""
+
+from onnx import helper
+
+from ...base import ShapeHandler
+from ...registry import register_shape_handler
+from ..tensor.split import infer_split_common
+
+
+class SplitToSequenceHandler(ShapeHandler):
+    """Handler for SplitToSequence operator."""
+
+    @property
+    def op_type(self) -> str:
+        return "SplitToSequence"
+
+    def infer_shape(self, node, ctx) -> None:
+        infer_split_common(node, ctx, helper.make_sequence_value_info)
+
+
+register_shape_handler(SplitToSequenceHandler())
onnxslim/core/shape_inference/standard_ops/sequence/zip_map.py
@@ -0,0 +1,36 @@
+# Copyright (c) Microsoft Corporation. All rights reserved.
+# Licensed under the MIT License.
+
+"""Shape handler for ZipMap operator."""
+
+import onnx
+
+from ...base import ShapeHandler
+from ...registry import register_shape_handler
+from ...utils import get_attribute
+
+
+class ZipMapHandler(ShapeHandler):
+    """Handler for ZipMap operator."""
+
+    @property
+    def op_type(self) -> str:
+        return "ZipMap"
+
+    def infer_shape(self, node, ctx) -> None:
+        map_key_type = None
+        if get_attribute(node, "classlabels_int64s") is not None:
+            map_key_type = onnx.TensorProto.INT64
+        elif get_attribute(node, "classlabels_strings") is not None:
+            map_key_type = onnx.TensorProto.STRING
+
+        assert map_key_type is not None
+        new_vi = onnx.ValueInfoProto()
+        new_vi.name = node.output[0]
+        new_vi.type.sequence_type.elem_type.map_type.value_type.tensor_type.elem_type = onnx.TensorProto.FLOAT
+        new_vi.type.sequence_type.elem_type.map_type.key_type = map_key_type
+        vi = ctx.known_vi_[node.output[0]]
+        vi.CopyFrom(new_vi)
+
+
+register_shape_handler(ZipMapHandler())
onnxslim/core/shape_inference/standard_ops/tensor/__init__.py
@@ -0,0 +1,20 @@
+# Copyright (c) Microsoft Corporation. All rights reserved.
+# Licensed under the MIT License.
+
+"""Tensor manipulation operator shape handlers."""
+
+from . import concat
+from . import gather
+from . import gather_elements
+from . import gather_nd
+from . import reshape
+from . import slice
+from . import split
+from . import squeeze
+from . import unsqueeze
+from . import transpose
+from . import tile
+from . import expand
+from . import pad
+from . import shape
+from . import size
onnxslim/core/shape_inference/standard_ops/tensor/concat.py
@@ -0,0 +1,62 @@
+# Copyright (c) Microsoft Corporation. All rights reserved.
+# Licensed under the MIT License.
+
+"""Shape handler for Concat operator."""
+
+from onnx import helper
+
+from ...base import ShapeHandler
+from ...registry import register_shape_handler
+from ...utils import get_attribute, get_shape_from_sympy_shape, handle_negative_axis
+
+
+class ConcatHandler(ShapeHandler):
+    """Handler for Concat operator."""
+
+    @property
+    def op_type(self) -> str:
+        return "Concat"
+
+    def infer_shape(self, node, ctx) -> None:
+        if any(i in ctx.sympy_data_ or i in ctx.initializers_ for i in node.input):
+            values = ctx.get_int_or_float_values(node)
+            if all(v is not None for v in values):
+                assert get_attribute(node, "axis") == 0
+                ctx.sympy_data_[node.output[0]] = []
+                for i in range(len(node.input)):
+                    value = values[i]
+                    if isinstance(value, list):
+                        ctx.sympy_data_[node.output[0]].extend(value)
+                    else:
+                        ctx.sympy_data_[node.output[0]].append(value)
+
+        sympy_shape = ctx.get_sympy_shape(node, 0)
+        axis = handle_negative_axis(get_attribute(node, "axis"), len(sympy_shape))
+        for i_idx in range(1, len(node.input)):
+            input_shape = ctx.get_sympy_shape(node, i_idx)
+            if input_shape:
+                sympy_shape[axis] = sympy_shape[axis] + input_shape[axis]
+        ctx.update_computed_dims(sympy_shape)
+        # merge symbolic dims for non-concat axes
+        for d in range(len(sympy_shape)):
+            if d == axis:
+                continue
+            dims = [ctx.get_shape(node, i_idx)[d] for i_idx in range(len(node.input)) if ctx.get_shape(node, i_idx)]
+            if all(d == dims[0] for d in dims):
+                continue
+            merged = ctx.merge_symbols(dims)
+            if type(merged) == str:
+                sympy_shape[d] = ctx.symbolic_dims_[merged] if merged else None
+            else:
+                sympy_shape[d] = merged
+        vi = ctx.known_vi_[node.output[0]]
+        vi.CopyFrom(
+            helper.make_tensor_value_info(
+                node.output[0],
+                ctx.known_vi_[node.input[0]].type.tensor_type.elem_type,
+                get_shape_from_sympy_shape(sympy_shape),
+            )
+        )
+
+
+register_shape_handler(ConcatHandler())
onnxslim/core/shape_inference/standard_ops/tensor/expand.py
@@ -0,0 +1,36 @@
+# Copyright (c) Microsoft Corporation. All rights reserved.
+# Licensed under the MIT License.
+
+"""Shape handler for Expand operator."""
+
+from onnx import helper
+
+from ...base import ShapeHandler
+from ...registry import register_shape_handler
+from ...utils import as_list, get_shape_from_sympy_shape
+
+
+class ExpandHandler(ShapeHandler):
+    """Handler for Expand operator."""
+
+    @property
+    def op_type(self) -> str:
+        return "Expand"
+
+    def infer_shape(self, node, ctx) -> None:
+        expand_to_shape = as_list(ctx.try_get_value(node, 1), keep_none=True)
+        if expand_to_shape is not None:
+            ctx.update_computed_dims(expand_to_shape)
+            shape = ctx.get_shape(node, 0)
+            new_shape = ctx.broadcast_shapes(shape, get_shape_from_sympy_shape(expand_to_shape))
+            vi = ctx.known_vi_[node.output[0]]
+            vi.CopyFrom(
+                helper.make_tensor_value_info(
+                    node.output[0],
+                    ctx.known_vi_[node.input[0]].type.tensor_type.elem_type,
+                    new_shape,
+                )
+            )
+
+
+register_shape_handler(ExpandHandler())
onnxslim/core/shape_inference/standard_ops/tensor/gather.py
@@ -0,0 +1,48 @@
+# Copyright (c) Microsoft Corporation. All rights reserved.
+# Licensed under the MIT License.
+
+"""Shape handler for Gather operator."""
+
+import numpy as np
+from onnx import helper
+
+from ...base import ShapeHandler
+from ...registry import register_shape_handler
+from ...utils import get_attribute, handle_negative_axis
+
+
+class GatherHandler(ShapeHandler):
+    """Handler for Gather operator."""
+
+    @property
+    def op_type(self) -> str:
+        return "Gather"
+
+    def infer_shape(self, node, ctx) -> None:
+        data_shape = ctx.get_shape(node, 0)
+        axis = handle_negative_axis(get_attribute(node, "axis", 0), len(data_shape))
+        indices_shape = ctx.get_shape(node, 1)
+        vi = ctx.known_vi_[node.output[0]]
+        vi.CopyFrom(
+            helper.make_tensor_value_info(
+                node.output[0],
+                ctx.known_vi_[node.input[0]].type.tensor_type.elem_type,
+                data_shape[:axis] + indices_shape + data_shape[axis + 1 :],
+            )
+        )
+        # for 1D input, do some sympy compute
+        if node.input[0] in ctx.sympy_data_ and len(data_shape) == 1 and get_attribute(node, "axis", 0) == 0:
+            idx = ctx.try_get_value(node, 1)
+            if idx is not None:
+                data = ctx.sympy_data_[node.input[0]]
+                if type(data) == list:
+                    if type(idx) == np.ndarray and len(idx.shape) == 1:
+                        ctx.sympy_data_[node.output[0]] = [data[int(i)] for i in idx]
+                    else:
+                        ctx.sympy_data_[node.output[0]] = data[int(idx)]
+                else:
+                    assert idx in {0, -1}
+                    ctx.sympy_data_[node.output[0]] = data
+
+
+register_shape_handler(GatherHandler())
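
The sympy_data_ branch at the end of GatherHandler is what keeps Shape → Gather chains symbolic: ShapeOpHandler (later in this diff) stores the input's dimension list in ctx.sympy_data_, and a 1-D Gather over that value then picks out individual dimensions as sympy symbols rather than opaque unknowns. A toy trace of the bookkeeping, with a plain dict standing in for ctx.sympy_data_ (this is illustration, not onnxslim API):

# Toy trace: a dict stands in for ctx.sympy_data_.
import sympy

sympy_data = {}
N = sympy.Symbol("N")
sympy_data["shape_out"] = [N, 3, 224, 224]  # what ShapeOpHandler would store
idx = 0                                     # Gather(axis=0, indices=0)
sympy_data["gather_out"] = sympy_data["shape_out"][int(idx)]
assert sympy_data["gather_out"] == N        # the batch dim stays symbolic
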
onnxslim/core/shape_inference/standard_ops/tensor/gather_elements.py
@@ -0,0 +1,31 @@
+# Copyright (c) Microsoft Corporation. All rights reserved.
+# Licensed under the MIT License.
+
+"""Shape handler for GatherElements operator."""
+
+from onnx import helper
+
+from ...base import ShapeHandler
+from ...registry import register_shape_handler
+
+
+class GatherElementsHandler(ShapeHandler):
+    """Handler for GatherElements operator."""
+
+    @property
+    def op_type(self) -> str:
+        return "GatherElements"
+
+    def infer_shape(self, node, ctx) -> None:
+        indices_shape = ctx.get_shape(node, 1)
+        vi = ctx.known_vi_[node.output[0]]
+        vi.CopyFrom(
+            helper.make_tensor_value_info(
+                node.output[0],
+                ctx.known_vi_[node.input[0]].type.tensor_type.elem_type,
+                indices_shape,
+            )
+        )
+
+
+register_shape_handler(GatherElementsHandler())
onnxslim/core/shape_inference/standard_ops/tensor/gather_nd.py
@@ -0,0 +1,42 @@
+# Copyright (c) Microsoft Corporation. All rights reserved.
+# Licensed under the MIT License.
+
+"""Shape handler for GatherND operator."""
+
+from onnx import helper
+
+from ...base import ShapeHandler
+from ...registry import register_shape_handler
+from ...utils import get_attribute, is_literal
+
+
+class GatherNDHandler(ShapeHandler):
+    """Handler for GatherND operator."""
+
+    @property
+    def op_type(self) -> str:
+        return "GatherND"
+
+    def infer_shape(self, node, ctx) -> None:
+        data_shape = ctx.get_shape(node, 0)
+        data_rank = len(data_shape)
+        indices_shape = ctx.get_shape(node, 1)
+        last_index_dimension = indices_shape[-1]
+        batch_dims = get_attribute(node, "batch_dims", 0)
+        assert (
+            is_literal(last_index_dimension)
+            and is_literal(batch_dims)
+            and (batch_dims + last_index_dimension) <= data_rank
+        )
+        new_shape = indices_shape[:-1] + data_shape[batch_dims + last_index_dimension :]
+        vi = ctx.known_vi_[node.output[0]]
+        vi.CopyFrom(
+            helper.make_tensor_value_info(
+                node.output[0],
+                ctx.known_vi_[node.input[0]].type.tensor_type.elem_type,
+                new_shape,
+            )
+        )
+
+
+register_shape_handler(GatherNDHandler())
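
For reference, the new_shape expression above is the standard GatherND rule: the trailing dimension of indices selects into data, so the output is indices_shape[:-1] plus whatever of data_shape remains past batch_dims + last_index_dimension. A quick numeric check:

# Numeric check of the new_shape expression in GatherNDHandler.
data_shape, indices_shape, batch_dims = [2, 3, 4], [5, 2], 0
last_index_dimension = indices_shape[-1]
new_shape = indices_shape[:-1] + data_shape[batch_dims + last_index_dimension:]
assert new_shape == [5, 4]
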
onnxslim/core/shape_inference/standard_ops/tensor/pad.py
@@ -0,0 +1,41 @@
+# Copyright (c) Microsoft Corporation. All rights reserved.
+# Licensed under the MIT License.
+
+"""Shape handler for Pad operator."""
+
+from onnx import helper
+
+from ...base import ShapeHandler
+from ...registry import register_shape_handler
+from ...utils import get_attribute, get_opset, get_shape_from_sympy_shape
+
+
+class PadHandler(ShapeHandler):
+    """Handler for Pad operator."""
+
+    @property
+    def op_type(self) -> str:
+        return "Pad"
+
+    def infer_shape(self, node, ctx) -> None:
+        if get_opset(ctx.out_mp_) <= 10:
+            pads = get_attribute(node, "pads")
+        else:
+            pads = ctx.try_get_value(node, 1)
+
+        sympy_shape = ctx.get_sympy_shape(node, 0)
+        rank = len(sympy_shape)
+
+        if pads is not None:
+            assert len(pads) == 2 * rank
+            new_sympy_shape = [d + pad_up + pad_down for d, pad_up, pad_down in zip(sympy_shape, pads[:rank], pads[rank:])]
+            ctx.update_computed_dims(new_sympy_shape)
+        else:
+            new_sympy_shape = ctx.new_symbolic_shape(rank, node)
+        output_tp = ctx.known_vi_[node.input[0]].type.tensor_type.elem_type
+
+        vi = ctx.known_vi_[node.output[0]]
+        vi.CopyFrom(helper.make_tensor_value_info(node.output[0], output_tp, get_shape_from_sympy_shape(new_sympy_shape)))
+
+
+register_shape_handler(PadHandler())
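
The pads layout assumed by the zip in PadHandler is the ONNX one: the first rank entries pad the start of each axis, the last rank entries pad the end. A quick check of the arithmetic with plain ints in place of sympy expressions:

# Check of the new_sympy_shape comprehension in PadHandler.
sympy_shape = [4, 5]        # rank-2 input
pads = [1, 2, 3, 0]         # [begin_0, begin_1, end_0, end_1]
rank = len(sympy_shape)
new_shape = [d + up + down for d, up, down in zip(sympy_shape, pads[:rank], pads[rank:])]
assert new_shape == [8, 7]  # (4+1+3, 5+2+0)
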
onnxslim/core/shape_inference/standard_ops/tensor/reshape.py
@@ -0,0 +1,72 @@
+# Copyright (c) Microsoft Corporation. All rights reserved.
+# Licensed under the MIT License.
+
+"""Shape handler for Reshape operator."""
+
+import sympy
+from onnx import helper
+
+from ...base import ShapeHandler
+from ...registry import register_shape_handler
+from ...utils import get_shape_from_sympy_shape, is_literal
+
+
+class ReshapeHandler(ShapeHandler):
+    """Handler for Reshape operator."""
+
+    @property
+    def op_type(self) -> str:
+        return "Reshape"
+
+    def infer_shape(self, node, ctx) -> None:
+        shape_value = ctx.try_get_value(node, 1)
+        vi = ctx.known_vi_[node.output[0]]
+        if shape_value is None:
+            shape_shape = ctx.get_shape(node, 1)
+            assert len(shape_shape) == 1
+            shape_rank = shape_shape[0]
+            assert is_literal(shape_rank)
+            vi.CopyFrom(
+                helper.make_tensor_value_info(
+                    node.output[0],
+                    vi.type.tensor_type.elem_type,
+                    get_shape_from_sympy_shape(ctx.new_symbolic_shape(shape_rank, node)),
+                )
+            )
+        else:
+            input_sympy_shape = ctx.get_sympy_shape(node, 0)
+            total = 1
+            for d in input_sympy_shape:
+                total = total * d
+            new_sympy_shape = []
+            deferred_dim_idx = -1
+            non_deferred_size = 1
+            for i, d in enumerate(shape_value):
+                if type(d) == sympy.Symbol or d != 0:
+                    new_sympy_shape.append(d)
+                else:
+                    new_sympy_shape.append(input_sympy_shape[i])
+                    non_deferred_size = non_deferred_size * input_sympy_shape[i]
+                if d == -1:
+                    deferred_dim_idx = i
+                elif d != 0:
+                    non_deferred_size = non_deferred_size * d
+
+            assert new_sympy_shape.count(-1) < 2
+            if -1 in new_sympy_shape:
+                new_dim = total // non_deferred_size
+                new_sympy_shape[deferred_dim_idx] = new_dim
+
+            ctx.update_computed_dims(new_sympy_shape)
+            vi.CopyFrom(
+                helper.make_tensor_value_info(
+                    node.output[0],
+                    vi.type.tensor_type.elem_type,
+                    get_shape_from_sympy_shape(new_sympy_shape),
+                )
+            )
+
+        ctx.pass_on_sympy_data(node)
+
+
+register_shape_handler(ReshapeHandler())
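
To make the deferred-dim bookkeeping in ReshapeHandler concrete: a 0 in the target copies the input dimension at that position, and a single -1 is solved as the total element count divided by the product of the known dimensions. A standalone recreation with plain integers in place of sympy expressions (illustration only, not onnxslim API):

# Standalone recreation of the Reshape deferred-dim arithmetic above.
def resolve_reshape(input_shape, target):
    total = 1
    for d in input_shape:
        total *= d
    out, deferred_idx, known = [], -1, 1
    for i, d in enumerate(target):
        if d == 0:                    # 0 copies the corresponding input dim
            out.append(input_shape[i])
            known *= input_shape[i]
        else:
            out.append(d)
            if d == -1:               # at most one deferred dim allowed
                deferred_idx = i
            else:
                known *= d
    if deferred_idx >= 0:
        out[deferred_idx] = total // known
    return out

assert resolve_reshape([2, 3, 4], [0, -1]) == [2, 12]
assert resolve_reshape([2, 3, 4], [4, 3, 2]) == [4, 3, 2]
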
onnxslim/core/shape_inference/standard_ops/tensor/shape.py
@@ -0,0 +1,38 @@
+# Copyright (c) Microsoft Corporation. All rights reserved.
+# Licensed under the MIT License.
+
+"""Shape handler for Shape operator."""
+
+from ...base import ShapeHandler
+from ...registry import register_shape_handler
+from ...utils import get_attribute
+
+
+class ShapeOpHandler(ShapeHandler):
+    """Handler for Shape operator."""
+
+    @property
+    def op_type(self) -> str:
+        return "Shape"
+
+    def infer_shape(self, node, ctx) -> None:
+        start = get_attribute(node, "start", 0)
+        end = get_attribute(node, "end", None)
+
+        full_sympy_shape = ctx.get_sympy_shape(node, 0)
+        num_dims = len(full_sympy_shape)
+
+        if start < 0:
+            start = num_dims + start
+        if end is None:
+            end = num_dims
+        elif end < 0:
+            end = num_dims + end
+
+        assert 0 <= start <= end <= num_dims, f"reshape start/end invalid: start={start}, end={end}, total_dims={num_dims}"
+
+        target_sympy_shape = full_sympy_shape[start:end]
+        ctx.sympy_data_[node.output[0]] = target_sympy_shape
+
+
+register_shape_handler(ShapeOpHandler())
onnxslim/core/shape_inference/standard_ops/tensor/size.py
@@ -0,0 +1,29 @@
+# Copyright (c) Microsoft Corporation. All rights reserved.
+# Licensed under the MIT License.
+
+"""Shape handler for Size operator."""
+
+import onnx
+from onnx import helper
+
+from ...base import ShapeHandler
+from ...registry import register_shape_handler
+from ...utils import sympy_reduce_product
+
+
+class SizeHandler(ShapeHandler):
+    """Handler for Size operator."""
+
+    @property
+    def op_type(self) -> str:
+        return "Size"
+
+    def infer_shape(self, node, ctx) -> None:
+        sympy_shape = ctx.get_sympy_shape(node, 0)
+        ctx.sympy_data_[node.output[0]] = sympy_reduce_product(sympy_shape)
+        ctx.known_vi_[node.output[0]].CopyFrom(
+            helper.make_tensor_value_info(node.output[0], onnx.TensorProto.INT64, [])
+        )
+
+
+register_shape_handler(SizeHandler())