onnxslim 0.1.81__py3-none-any.whl → 0.1.83__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- onnxslim/core/optimization/dead_node_elimination.py +84 -3
- onnxslim/core/pattern/fusion/convadd.py +21 -1
- onnxslim/core/pattern/fusion/convbn.py +21 -4
- onnxslim/core/pattern/fusion/convmul.py +23 -5
- onnxslim/core/pattern/fusion/padconv.py +5 -0
- onnxslim/core/shape_inference/__init__.py +378 -0
- onnxslim/core/shape_inference/aten_ops/__init__.py +16 -0
- onnxslim/core/shape_inference/aten_ops/argmax.py +47 -0
- onnxslim/core/shape_inference/aten_ops/bitwise_or.py +28 -0
- onnxslim/core/shape_inference/aten_ops/diagonal.py +52 -0
- onnxslim/core/shape_inference/aten_ops/embedding.py +23 -0
- onnxslim/core/shape_inference/aten_ops/group_norm.py +41 -0
- onnxslim/core/shape_inference/aten_ops/min_max.py +64 -0
- onnxslim/core/shape_inference/aten_ops/multinomial.py +39 -0
- onnxslim/core/shape_inference/aten_ops/numpy_t.py +22 -0
- onnxslim/core/shape_inference/aten_ops/pool2d.py +40 -0
- onnxslim/core/shape_inference/aten_ops/unfold.py +44 -0
- onnxslim/core/shape_inference/aten_ops/upsample.py +44 -0
- onnxslim/core/shape_inference/base.py +111 -0
- onnxslim/core/shape_inference/context.py +645 -0
- onnxslim/core/shape_inference/contrib_ops/__init__.py +8 -0
- onnxslim/core/shape_inference/contrib_ops/attention/__init__.py +15 -0
- onnxslim/core/shape_inference/contrib_ops/attention/attention.py +61 -0
- onnxslim/core/shape_inference/contrib_ops/attention/decoder_masked_mha.py +37 -0
- onnxslim/core/shape_inference/contrib_ops/attention/gated_relative_position_bias.py +35 -0
- onnxslim/core/shape_inference/contrib_ops/attention/longformer_attention.py +21 -0
- onnxslim/core/shape_inference/contrib_ops/attention/multi_head_attention.py +82 -0
- onnxslim/core/shape_inference/contrib_ops/attention/multi_scale_deformable_attn.py +29 -0
- onnxslim/core/shape_inference/contrib_ops/attention/packed_attention.py +39 -0
- onnxslim/core/shape_inference/contrib_ops/attention/packed_multi_head_attention.py +33 -0
- onnxslim/core/shape_inference/contrib_ops/attention/remove_padding.py +41 -0
- onnxslim/core/shape_inference/contrib_ops/attention/restore_padding.py +29 -0
- onnxslim/core/shape_inference/contrib_ops/misc/__init__.py +15 -0
- onnxslim/core/shape_inference/contrib_ops/misc/bias_add.py +21 -0
- onnxslim/core/shape_inference/contrib_ops/misc/bias_gelu.py +21 -0
- onnxslim/core/shape_inference/contrib_ops/misc/bias_split_gelu.py +30 -0
- onnxslim/core/shape_inference/contrib_ops/misc/fast_gelu.py +21 -0
- onnxslim/core/shape_inference/contrib_ops/misc/gelu.py +21 -0
- onnxslim/core/shape_inference/contrib_ops/misc/gemm_fast_gelu.py +21 -0
- onnxslim/core/shape_inference/contrib_ops/misc/gemm_float8.py +21 -0
- onnxslim/core/shape_inference/contrib_ops/misc/python_op.py +67 -0
- onnxslim/core/shape_inference/contrib_ops/misc/quick_gelu.py +21 -0
- onnxslim/core/shape_inference/contrib_ops/misc/rotary_embedding.py +31 -0
- onnxslim/core/shape_inference/contrib_ops/normalization/__init__.py +12 -0
- onnxslim/core/shape_inference/contrib_ops/normalization/embed_layer_normalization.py +41 -0
- onnxslim/core/shape_inference/contrib_ops/normalization/group_norm.py +21 -0
- onnxslim/core/shape_inference/contrib_ops/normalization/layer_normalization.py +42 -0
- onnxslim/core/shape_inference/contrib_ops/normalization/simplified_layer_normalization.py +23 -0
- onnxslim/core/shape_inference/contrib_ops/normalization/skip_group_norm.py +23 -0
- onnxslim/core/shape_inference/contrib_ops/normalization/skip_layer_normalization.py +26 -0
- onnxslim/core/shape_inference/contrib_ops/normalization/skip_simplified_layer_normalization.py +23 -0
- onnxslim/core/shape_inference/registry.py +90 -0
- onnxslim/core/shape_inference/standard_ops/__init__.py +11 -0
- onnxslim/core/shape_inference/standard_ops/control_flow/__init__.py +8 -0
- onnxslim/core/shape_inference/standard_ops/control_flow/if_op.py +43 -0
- onnxslim/core/shape_inference/standard_ops/control_flow/loop.py +74 -0
- onnxslim/core/shape_inference/standard_ops/control_flow/scan.py +54 -0
- onnxslim/core/shape_inference/standard_ops/math/__init__.py +20 -0
- onnxslim/core/shape_inference/standard_ops/math/_symbolic_compute.py +34 -0
- onnxslim/core/shape_inference/standard_ops/math/add.py +10 -0
- onnxslim/core/shape_inference/standard_ops/math/div.py +10 -0
- onnxslim/core/shape_inference/standard_ops/math/einsum.py +119 -0
- onnxslim/core/shape_inference/standard_ops/math/equal.py +10 -0
- onnxslim/core/shape_inference/standard_ops/math/floor.py +10 -0
- onnxslim/core/shape_inference/standard_ops/math/matmul.py +21 -0
- onnxslim/core/shape_inference/standard_ops/math/matmul_integer.py +23 -0
- onnxslim/core/shape_inference/standard_ops/math/max.py +10 -0
- onnxslim/core/shape_inference/standard_ops/math/min.py +10 -0
- onnxslim/core/shape_inference/standard_ops/math/mul.py +10 -0
- onnxslim/core/shape_inference/standard_ops/math/neg.py +10 -0
- onnxslim/core/shape_inference/standard_ops/math/reduce_prod.py +27 -0
- onnxslim/core/shape_inference/standard_ops/math/reduce_sum.py +53 -0
- onnxslim/core/shape_inference/standard_ops/math/sub.py +10 -0
- onnxslim/core/shape_inference/standard_ops/math/where.py +10 -0
- onnxslim/core/shape_inference/standard_ops/misc/__init__.py +22 -0
- onnxslim/core/shape_inference/standard_ops/misc/array_feature_extractor.py +32 -0
- onnxslim/core/shape_inference/standard_ops/misc/cast.py +21 -0
- onnxslim/core/shape_inference/standard_ops/misc/category_mapper.py +30 -0
- onnxslim/core/shape_inference/standard_ops/misc/compress.py +39 -0
- onnxslim/core/shape_inference/standard_ops/misc/constant.py +27 -0
- onnxslim/core/shape_inference/standard_ops/misc/constant_of_shape.py +45 -0
- onnxslim/core/shape_inference/standard_ops/misc/dequantize_linear.py +26 -0
- onnxslim/core/shape_inference/standard_ops/misc/non_max_suppression.py +26 -0
- onnxslim/core/shape_inference/standard_ops/misc/non_zero.py +26 -0
- onnxslim/core/shape_inference/standard_ops/misc/one_hot.py +42 -0
- onnxslim/core/shape_inference/standard_ops/misc/quantize_linear.py +29 -0
- onnxslim/core/shape_inference/standard_ops/misc/range.py +41 -0
- onnxslim/core/shape_inference/standard_ops/misc/relative_position_bias.py +31 -0
- onnxslim/core/shape_inference/standard_ops/misc/resize.py +74 -0
- onnxslim/core/shape_inference/standard_ops/misc/scatter_elements.py +31 -0
- onnxslim/core/shape_inference/standard_ops/misc/softmax_cross_entropy_loss.py +44 -0
- onnxslim/core/shape_inference/standard_ops/misc/top_k.py +44 -0
- onnxslim/core/shape_inference/standard_ops/nn/__init__.py +18 -0
- onnxslim/core/shape_inference/standard_ops/nn/all_reduce.py +9 -0
- onnxslim/core/shape_inference/standard_ops/nn/average_pool.py +40 -0
- onnxslim/core/shape_inference/standard_ops/nn/batch_normalization.py +26 -0
- onnxslim/core/shape_inference/standard_ops/nn/conv.py +33 -0
- onnxslim/core/shape_inference/standard_ops/nn/cum_sum.py +9 -0
- onnxslim/core/shape_inference/standard_ops/nn/identity.py +9 -0
- onnxslim/core/shape_inference/standard_ops/nn/max_pool.py +9 -0
- onnxslim/core/shape_inference/standard_ops/nn/memcpy_from_host.py +9 -0
- onnxslim/core/shape_inference/standard_ops/nn/memcpy_to_host.py +9 -0
- onnxslim/core/shape_inference/standard_ops/nn/moe.py +9 -0
- onnxslim/core/shape_inference/standard_ops/nn/nhwc_conv.py +33 -0
- onnxslim/core/shape_inference/standard_ops/nn/reciprocal.py +9 -0
- onnxslim/core/shape_inference/standard_ops/nn/round.py +9 -0
- onnxslim/core/shape_inference/standard_ops/sequence/__init__.py +10 -0
- onnxslim/core/shape_inference/standard_ops/sequence/concat_from_sequence.py +40 -0
- onnxslim/core/shape_inference/standard_ops/sequence/sequence_at.py +31 -0
- onnxslim/core/shape_inference/standard_ops/sequence/sequence_insert.py +26 -0
- onnxslim/core/shape_inference/standard_ops/sequence/split_to_sequence.py +24 -0
- onnxslim/core/shape_inference/standard_ops/sequence/zip_map.py +36 -0
- onnxslim/core/shape_inference/standard_ops/tensor/__init__.py +20 -0
- onnxslim/core/shape_inference/standard_ops/tensor/concat.py +62 -0
- onnxslim/core/shape_inference/standard_ops/tensor/expand.py +36 -0
- onnxslim/core/shape_inference/standard_ops/tensor/gather.py +48 -0
- onnxslim/core/shape_inference/standard_ops/tensor/gather_elements.py +31 -0
- onnxslim/core/shape_inference/standard_ops/tensor/gather_nd.py +42 -0
- onnxslim/core/shape_inference/standard_ops/tensor/pad.py +41 -0
- onnxslim/core/shape_inference/standard_ops/tensor/reshape.py +72 -0
- onnxslim/core/shape_inference/standard_ops/tensor/shape.py +38 -0
- onnxslim/core/shape_inference/standard_ops/tensor/size.py +29 -0
- onnxslim/core/shape_inference/standard_ops/tensor/slice.py +183 -0
- onnxslim/core/shape_inference/standard_ops/tensor/split.py +57 -0
- onnxslim/core/shape_inference/standard_ops/tensor/squeeze.py +69 -0
- onnxslim/core/shape_inference/standard_ops/tensor/tile.py +41 -0
- onnxslim/core/shape_inference/standard_ops/tensor/transpose.py +30 -0
- onnxslim/core/shape_inference/standard_ops/tensor/unsqueeze.py +54 -0
- onnxslim/core/shape_inference/utils.py +244 -0
- onnxslim/third_party/symbolic_shape_infer.py +73 -3156
- onnxslim/utils.py +4 -2
- {onnxslim-0.1.81.dist-info → onnxslim-0.1.83.dist-info}/METADATA +21 -11
- onnxslim-0.1.83.dist-info/RECORD +187 -0
- onnxslim-0.1.81.dist-info/RECORD +0 -63
- {onnxslim-0.1.81.dist-info → onnxslim-0.1.83.dist-info}/WHEEL +0 -0
- {onnxslim-0.1.81.dist-info → onnxslim-0.1.83.dist-info}/entry_points.txt +0 -0
- {onnxslim-0.1.81.dist-info → onnxslim-0.1.83.dist-info}/licenses/LICENSE +0 -0
onnxslim/core/shape_inference/standard_ops/math/einsum.py
@@ -0,0 +1,119 @@
+# Copyright (c) Microsoft Corporation. All rights reserved.
+# Licensed under the MIT License.
+
+"""Shape handler for Einsum operator."""
+
+from onnx import helper
+
+from ...base import ShapeHandler
+from ...registry import register_shape_handler
+from ...utils import get_attribute
+
+
+class EinsumHandler(ShapeHandler):
+    """Handler for Einsum operator."""
+
+    @property
+    def op_type(self) -> str:
+        return "Einsum"
+
+    def infer_shape(self, node, ctx) -> None:
+        equation = get_attribute(node, "equation")
+        equation = equation.replace(b" ", b"")
+        mid_index = equation.find(b"->")
+        left_equation = equation[:mid_index] if mid_index != -1 else equation
+
+        num_operands = 0
+        num_ellipsis = 0
+        num_ellipsis_indices = 0
+        num_labels = 0
+        ellipsis_flag = True
+        dims_value = []
+        ellipsis_dims_value = []
+
+        label_maps = {}
+        repeated_labels = set()
+
+        terms = left_equation.split(b",")
+        for term in terms:
+            ellipsis_index = term.find(b"...")
+            shape = ctx.get_shape(node, num_operands)
+            rank = len(shape)
+            ellipsis_dims = 0
+            term_size = 0
+            num_illegal_char = 0
+
+            for i in range(len(term)):
+                if term[i] != 46:
+                    term_size = term_size + 1
+
+            index = 0
+            while index < len(term):
+                if index == ellipsis_index:
+                    ellipsis_dims = rank - term_size
+                    if ellipsis_flag:
+                        ellipsis_flag = False
+                        for i in range(ellipsis_dims):
+                            ellipsis_dims_value.append(shape[index + i - num_illegal_char])
+                    else:
+                        for i in range(ellipsis_dims):
+                            shape_dim = shape[index + i - num_illegal_char]
+                            current_dim = ellipsis_dims_value[i]
+                            ellipsis_dims_value[i] = max(current_dim, shape_dim)
+
+                    num_illegal_char += 3
+                    index += 3
+                    continue
+
+                elif term[index] == 46:
+                    num_illegal_char += 1
+                    index += 1
+                    continue
+
+                char = term[index]
+                if char not in label_maps:
+                    label_maps[char] = num_labels
+                    dims_value.append(shape[index + ellipsis_dims - num_illegal_char])
+                    num_labels += 1
+                else:
+                    repeated_labels.add(char)
+
+                index += 1
+
+            if ellipsis_index != -1:
+                if num_ellipsis == 0:
+                    if rank < term_size:
+                        raise ValueError("Ellipsis represents incompatible dimensions.")
+                    num_ellipsis_indices = rank - term_size
+                else:
+                    if num_ellipsis_indices != rank - term_size:
+                        raise ValueError("Ellipsis represents incompatible dimensions.")
+                num_ellipsis += 1
+            else:
+                if rank != term_size:
+                    raise ValueError("Rank of input ", num_operands, " does not match the equation indices.")
+            num_operands += 1
+
+        new_sympy_shape = []
+        if mid_index != -1:
+            right_equation = equation[mid_index + 2 :]
+            right_ellipsis_index = right_equation.find(b"...")
+            if right_ellipsis_index != -1:
+                for i in range(num_ellipsis_indices):
+                    new_sympy_shape.append(ellipsis_dims_value[i])
+            for c in right_equation:
+                if c != 46:
+                    new_sympy_shape.append(dims_value[label_maps[c]])
+        else:
+            for i in range(num_ellipsis_indices):
+                new_sympy_shape.append(ellipsis_dims_value[i])
+            for label, idx in label_maps.items():
+                if label not in repeated_labels:
+                    new_sympy_shape.append(dims_value[idx])
+
+        output_dtype = ctx.known_vi_[node.input[0]].type.tensor_type.elem_type
+        vi = ctx.known_vi_[node.output[0]]
+        vi.CopyFrom(helper.make_tensor_value_info(node.output[0], output_dtype, new_sympy_shape))
+
+
+register_shape_handler(EinsumHandler())
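The Einsum handler mirrors numpy's einsum shape rules: with an explicit output (an "->" in the equation) the labels to its right determine the shape, broadcast ellipsis dims first; without one, the broadcast ellipsis dims plus the non-repeated labels survive. A quick numpy check of both paths (illustration only; the handler manipulates symbolic dims, not arrays):

import numpy as np

a = np.zeros((2, 3, 4))
b = np.zeros((2, 4, 5))
# Explicit output: the labels right of "->" decide the shape.
print(np.einsum("bij,bjk->bik", a, b).shape)  # (2, 3, 5)

# Implicit output (the mid_index == -1 path): the repeated label i is
# contracted; broadcast ellipsis dims are kept in front.
print(np.einsum("...i,...i", np.zeros((7, 3)), np.zeros((7, 3))).shape)  # (7,)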
onnxslim/core/shape_inference/standard_ops/math/equal.py
@@ -0,0 +1,10 @@
+# Copyright (c) Microsoft Corporation. All rights reserved.
+# Licensed under the MIT License.
+
+"""Shape handler for Equal operator."""
+
+from ...base import MultiOpHandler
+from ...registry import register_shape_handler
+from ._symbolic_compute import infer_symbolic_compute_ops
+
+register_shape_handler(MultiOpHandler("Equal", infer_symbolic_compute_ops))
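Equal above, and the Floor/Max/Min/Mul/Neg/Sub/Where files that follow, all route through one shared inference function via MultiOpHandler. Its definition lives in base.py (+111 above, not shown in this excerpt); judging from these call sites it plausibly reduces to something like this hypothetical sketch:

# Hypothetical sketch only -- the real class is defined in base.py.
class MultiOpHandler:
    """Adapts a shared inference function to a single op_type."""

    def __init__(self, op_type: str, infer_fn):
        self._op_type = op_type
        self._infer_fn = infer_fn

    @property
    def op_type(self) -> str:
        return self._op_type

    def infer_shape(self, node, ctx) -> None:
        self._infer_fn(node, ctx)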
onnxslim/core/shape_inference/standard_ops/math/floor.py
@@ -0,0 +1,10 @@
+# Copyright (c) Microsoft Corporation. All rights reserved.
+# Licensed under the MIT License.
+
+"""Shape handler for Floor operator."""
+
+from ...base import MultiOpHandler
+from ...registry import register_shape_handler
+from ._symbolic_compute import infer_symbolic_compute_ops
+
+register_shape_handler(MultiOpHandler("Floor", infer_symbolic_compute_ops))
onnxslim/core/shape_inference/standard_ops/math/matmul.py
@@ -0,0 +1,21 @@
+# Copyright (c) Microsoft Corporation. All rights reserved.
+# Licensed under the MIT License.
+
+"""Shape handler for MatMul operator."""
+
+from ...base import ShapeHandler
+from ...registry import register_shape_handler
+
+
+class MatMulHandler(ShapeHandler):
+    """Handler for MatMul operator."""
+
+    @property
+    def op_type(self) -> str:
+        return "MatMul"
+
+    def infer_shape(self, node, ctx) -> None:
+        ctx.compute_matmul_shape(node)
+
+
+register_shape_handler(MatMulHandler())
onnxslim/core/shape_inference/standard_ops/math/matmul_integer.py
@@ -0,0 +1,23 @@
+# Copyright (c) Microsoft Corporation. All rights reserved.
+# Licensed under the MIT License.
+
+"""Shape handler for MatMulInteger16 operator."""
+
+import onnx
+
+from ...base import ShapeHandler
+from ...registry import register_shape_handler
+
+
+class MatMulIntegerHandler(ShapeHandler):
+    """Handler for MatMulInteger16 operator."""
+
+    @property
+    def op_type(self) -> str:
+        return "MatMulInteger16"
+
+    def infer_shape(self, node, ctx) -> None:
+        ctx.compute_matmul_shape(node, onnx.TensorProto.INT32)
+
+
+register_shape_handler(MatMulIntegerHandler())
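Both MatMul handlers delegate to ctx.compute_matmul_shape, defined in context.py (not shown in this excerpt); MatMulInteger16 merely forces an INT32 output dtype. ONNX MatMul follows numpy's batched-matmul shape rule, which is presumably what the symbolic version implements (illustration only):

import numpy as np

a = np.zeros((7, 1, 2, 3))
b = np.zeros((5, 3, 4))
# Batch dims broadcast ((7, 1) and (5,) -> (7, 5)); the last two dims contract.
print(np.matmul(a, b).shape)  # (7, 5, 2, 4)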
onnxslim/core/shape_inference/standard_ops/math/max.py
@@ -0,0 +1,10 @@
+# Copyright (c) Microsoft Corporation. All rights reserved.
+# Licensed under the MIT License.
+
+"""Shape handler for Max operator."""
+
+from ...base import MultiOpHandler
+from ...registry import register_shape_handler
+from ._symbolic_compute import infer_symbolic_compute_ops
+
+register_shape_handler(MultiOpHandler("Max", infer_symbolic_compute_ops))
onnxslim/core/shape_inference/standard_ops/math/min.py
@@ -0,0 +1,10 @@
+# Copyright (c) Microsoft Corporation. All rights reserved.
+# Licensed under the MIT License.
+
+"""Shape handler for Min operator."""
+
+from ...base import MultiOpHandler
+from ...registry import register_shape_handler
+from ._symbolic_compute import infer_symbolic_compute_ops
+
+register_shape_handler(MultiOpHandler("Min", infer_symbolic_compute_ops))
onnxslim/core/shape_inference/standard_ops/math/mul.py
@@ -0,0 +1,10 @@
+# Copyright (c) Microsoft Corporation. All rights reserved.
+# Licensed under the MIT License.
+
+"""Shape handler for Mul operator."""
+
+from ...base import MultiOpHandler
+from ...registry import register_shape_handler
+from ._symbolic_compute import infer_symbolic_compute_ops
+
+register_shape_handler(MultiOpHandler("Mul", infer_symbolic_compute_ops))
onnxslim/core/shape_inference/standard_ops/math/neg.py
@@ -0,0 +1,10 @@
+# Copyright (c) Microsoft Corporation. All rights reserved.
+# Licensed under the MIT License.
+
+"""Shape handler for Neg operator."""
+
+from ...base import MultiOpHandler
+from ...registry import register_shape_handler
+from ._symbolic_compute import infer_symbolic_compute_ops
+
+register_shape_handler(MultiOpHandler("Neg", infer_symbolic_compute_ops))
onnxslim/core/shape_inference/standard_ops/math/reduce_prod.py
@@ -0,0 +1,27 @@
+# Copyright (c) Microsoft Corporation. All rights reserved.
+# Licensed under the MIT License.
+
+"""Shape handler for ReduceProd operator."""
+
+from ...base import ShapeHandler
+from ...registry import register_shape_handler
+from ...utils import get_attribute, sympy_reduce_product
+
+
+class ReduceProdHandler(ShapeHandler):
+    """Handler for ReduceProd operator."""
+
+    @property
+    def op_type(self) -> str:
+        return "ReduceProd"
+
+    def infer_shape(self, node, ctx) -> None:
+        axes = get_attribute(node, "axes")
+        keep_dims = get_attribute(node, "keepdims", 1)
+        if keep_dims == 0 and axes == [0]:
+            data = ctx.get_int_or_float_values(node)[0]
+            if data is not None:
+                ctx.sympy_data_[node.output[0]] = sympy_reduce_product(data)
+
+
+register_shape_handler(ReduceProdHandler())
onnxslim/core/shape_inference/standard_ops/math/reduce_sum.py
@@ -0,0 +1,53 @@
+# Copyright (c) Microsoft Corporation. All rights reserved.
+# Licensed under the MIT License.
+
+"""Shape handler for ReduceSum operator."""
+
+from onnx import helper
+
+from ...base import ShapeHandler
+from ...registry import register_shape_handler
+from ...utils import get_attribute, get_opset, get_shape_from_sympy_shape, handle_negative_axis
+
+
+class ReduceSumHandler(ShapeHandler):
+    """Handler for ReduceSum operator."""
+
+    @property
+    def op_type(self) -> str:
+        return "ReduceSum"
+
+    def infer_shape(self, node, ctx) -> None:
+        keep_dims = get_attribute(node, "keepdims", 1)
+        if get_opset(ctx.out_mp_) >= 13 and len(node.input) > 1:
+            axes = ctx.try_get_value(node, 1)
+            vi = ctx.known_vi_[node.output[0]]
+            if axes is None:
+                assert keep_dims
+                vi.CopyFrom(
+                    helper.make_tensor_value_info(
+                        node.output[0],
+                        ctx.known_vi_[node.input[0]].type.tensor_type.elem_type,
+                        get_shape_from_sympy_shape(ctx.new_symbolic_shape(ctx.get_shape_rank(node, 0), node)),
+                    )
+                )
+            else:
+                shape = ctx.get_shape(node, 0)
+                output_shape = []
+                axes = [handle_negative_axis(a, len(shape)) for a in axes]
+                for i, d in enumerate(shape):
+                    if i in axes:
+                        if keep_dims:
+                            output_shape.append(1)
+                    else:
+                        output_shape.append(d)
+                vi.CopyFrom(
+                    helper.make_tensor_value_info(
+                        node.output[0],
+                        ctx.known_vi_[node.input[0]].type.tensor_type.elem_type,
+                        output_shape,
+                    )
+                )
+
+
+register_shape_handler(ReduceSumHandler())
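From opset 13, ReduceSum takes its axes as a second input rather than an attribute; when their value is unknown at inference time the handler can only assert keepdims and emit a fresh fully symbolic shape of the same rank. When the axes are known, the output_shape loop reproduces numpy's reduction semantics (illustration only):

import numpy as np

x = np.zeros((2, 3, 4))
print(np.sum(x, axis=1, keepdims=True).shape)   # (2, 1, 4): reduced dim kept as 1
print(np.sum(x, axis=1, keepdims=False).shape)  # (2, 4): reduced dim dropped
print(np.sum(x, axis=-1).shape)                 # (2, 3): negative axis normalized first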
onnxslim/core/shape_inference/standard_ops/math/sub.py
@@ -0,0 +1,10 @@
+# Copyright (c) Microsoft Corporation. All rights reserved.
+# Licensed under the MIT License.
+
+"""Shape handler for Sub operator."""
+
+from ...base import MultiOpHandler
+from ...registry import register_shape_handler
+from ._symbolic_compute import infer_symbolic_compute_ops
+
+register_shape_handler(MultiOpHandler("Sub", infer_symbolic_compute_ops))
onnxslim/core/shape_inference/standard_ops/math/where.py
@@ -0,0 +1,10 @@
+# Copyright (c) Microsoft Corporation. All rights reserved.
+# Licensed under the MIT License.
+
+"""Shape handler for Where operator."""
+
+from ...base import MultiOpHandler
+from ...registry import register_shape_handler
+from ._symbolic_compute import infer_symbolic_compute_ops
+
+register_shape_handler(MultiOpHandler("Where", infer_symbolic_compute_ops))
onnxslim/core/shape_inference/standard_ops/misc/__init__.py
@@ -0,0 +1,22 @@
+# Copyright (c) Microsoft Corporation. All rights reserved.
+# Licensed under the MIT License.
+
+"""Miscellaneous operator shape handlers."""
+
+from . import constant
+from . import constant_of_shape
+from . import cast
+from . import category_mapper
+from . import compress
+from . import one_hot
+from . import non_max_suppression
+from . import non_zero
+from . import top_k
+from . import range
+from . import resize
+from . import scatter_elements
+from . import array_feature_extractor
+from . import softmax_cross_entropy_loss
+from . import dequantize_linear
+from . import quantize_linear
+from . import relative_position_bias
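These imports are for their side effects: each submodule calls register_shape_handler(...) at module level, so importing the package populates the dispatch table. The real registry is in registry.py (+90 above, not shown); the pattern plausibly reduces to a sketch like this, with hypothetical names:

# Hypothetical sketch only -- the real registry lives in registry.py.
_SHAPE_HANDLERS = {}


def register_shape_handler(handler):
    """Index the handler by its op_type so inference can dispatch on node.op_type."""
    _SHAPE_HANDLERS[handler.op_type] = handler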
onnxslim/core/shape_inference/standard_ops/misc/array_feature_extractor.py
@@ -0,0 +1,32 @@
+# Copyright (c) Microsoft Corporation. All rights reserved.
+# Licensed under the MIT License.
+
+"""Shape handler for ArrayFeatureExtractor operator."""
+
+from onnx import helper
+
+from ...base import ShapeHandler
+from ...registry import register_shape_handler
+
+
+class ArrayFeatureExtractorHandler(ShapeHandler):
+    """Handler for ArrayFeatureExtractor operator."""
+
+    @property
+    def op_type(self) -> str:
+        return "ArrayFeatureExtractor"
+
+    def infer_shape(self, node, ctx) -> None:
+        data_shape = ctx.get_shape(node, 0)
+        indices_shape = ctx.get_shape(node, 1)
+        vi = ctx.known_vi_[node.output[0]]
+        vi.CopyFrom(
+            helper.make_tensor_value_info(
+                node.output[0],
+                ctx.known_vi_[node.input[0]].type.tensor_type.elem_type,
+                data_shape[:-1] + indices_shape,
+            )
+        )
+
+
+register_shape_handler(ArrayFeatureExtractorHandler())
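ArrayFeatureExtractor selects entries along the last axis of the data tensor, so the output shape is data_shape[:-1] + indices_shape, exactly what the handler builds. The equivalent numpy check (illustration only):

import numpy as np

data = np.zeros((2, 3, 5))
indices = np.array([0, 4])
print(np.take(data, indices, axis=-1).shape)  # (2, 3, 2) == data.shape[:-1] + indices.shape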
onnxslim/core/shape_inference/standard_ops/misc/cast.py
@@ -0,0 +1,21 @@
+# Copyright (c) Microsoft Corporation. All rights reserved.
+# Licensed under the MIT License.
+
+"""Shape handler for Cast operator."""
+
+from ...base import ShapeHandler
+from ...registry import register_shape_handler
+
+
+class CastHandler(ShapeHandler):
+    """Handler for Cast operator."""
+
+    @property
+    def op_type(self) -> str:
+        return "Cast"
+
+    def infer_shape(self, node, ctx) -> None:
+        ctx.pass_on_sympy_data(node)
+
+
+register_shape_handler(CastHandler())
onnxslim/core/shape_inference/standard_ops/misc/category_mapper.py
@@ -0,0 +1,30 @@
+# Copyright (c) Microsoft Corporation. All rights reserved.
+# Licensed under the MIT License.
+
+"""Shape handler for CategoryMapper operator."""
+
+import onnx
+from onnx import helper
+
+from ...base import ShapeHandler
+from ...registry import register_shape_handler
+
+
+class CategoryMapperHandler(ShapeHandler):
+    """Handler for CategoryMapper operator."""
+
+    @property
+    def op_type(self) -> str:
+        return "CategoryMapper"
+
+    def infer_shape(self, node, ctx) -> None:
+        input_type = ctx.known_vi_[node.input[0]].type.tensor_type.elem_type
+        if input_type == onnx.TensorProto.STRING:
+            output_type = onnx.TensorProto.INT64
+        else:
+            output_type = onnx.TensorProto.STRING
+        vi = ctx.known_vi_[node.output[0]]
+        vi.CopyFrom(helper.make_tensor_value_info(node.output[0], output_type, ctx.get_shape(node, 0)))
+
+
+register_shape_handler(CategoryMapperHandler())
onnxslim/core/shape_inference/standard_ops/misc/compress.py
@@ -0,0 +1,39 @@
+# Copyright (c) Microsoft Corporation. All rights reserved.
+# Licensed under the MIT License.
+
+"""Shape handler for Compress operator."""
+
+from onnx import helper
+
+from ...base import ShapeHandler
+from ...registry import register_shape_handler
+from ...utils import get_attribute, handle_negative_axis
+
+
+class CompressHandler(ShapeHandler):
+    """Handler for Compress operator."""
+
+    @property
+    def op_type(self) -> str:
+        return "Compress"
+
+    def infer_shape(self, node, ctx) -> None:
+        input_shape = ctx.get_shape(node, 0)
+        compress_len = str(ctx.new_symbolic_dim_from_output(node))
+        axis = get_attribute(node, "axis")
+        if axis is None:
+            output_shape = [compress_len]
+        else:
+            output_shape = input_shape
+            output_shape[handle_negative_axis(axis, len(input_shape))] = compress_len
+        vi = ctx.known_vi_[node.output[0]]
+        vi.CopyFrom(
+            helper.make_tensor_value_info(
+                node.output[0],
+                ctx.known_vi_[node.input[0]].type.tensor_type.elem_type,
+                output_shape,
+            )
+        )
+
+
+register_shape_handler(CompressHandler())
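Compress keeps only the entries where the condition tensor is true, so the output length along the chosen axis depends on runtime values; the handler substitutes a fresh symbolic dim there, and with no axis attribute the result is 1-D. The numpy equivalent (illustration only):

import numpy as np

x = np.arange(12).reshape(3, 4)
print(np.compress([True, False, True], x, axis=0).shape)        # (2, 4): axis-0 length is data dependent
print(np.compress([True, False, True, True], x.ravel()).shape)  # (3,): 1-D when axis is omitted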
onnxslim/core/shape_inference/standard_ops/misc/constant.py
@@ -0,0 +1,27 @@
+# Copyright (c) Microsoft Corporation. All rights reserved.
+# Licensed under the MIT License.
+
+"""Shape handler for Constant operator."""
+
+from onnx import numpy_helper
+
+from ...base import ShapeHandler
+from ...registry import register_shape_handler
+from ...utils import get_attribute
+
+
+class ConstantHandler(ShapeHandler):
+    """Handler for Constant operator."""
+
+    @property
+    def op_type(self) -> str:
+        return "Constant"
+
+    def infer_shape(self, node, ctx) -> None:
+        t = get_attribute(node, "value")
+        t.name = node.output[0]
+        ctx.initializers_[node.output[0]] = t
+        ctx.sympy_data_[node.output[0]] = numpy_helper.to_array(t)
+
+
+register_shape_handler(ConstantHandler())
onnxslim/core/shape_inference/standard_ops/misc/constant_of_shape.py
@@ -0,0 +1,45 @@
+# Copyright (c) Microsoft Corporation. All rights reserved.
+# Licensed under the MIT License.
+
+"""Shape handler for ConstantOfShape operator."""
+
+import numpy as np
+import onnx
+from onnx import helper, numpy_helper
+
+from ...base import ShapeHandler
+from ...registry import register_shape_handler
+from ...utils import get_attribute, get_shape_from_sympy_shape, is_literal
+
+
+class ConstantOfShapeHandler(ShapeHandler):
+    """Handler for ConstantOfShape operator."""
+
+    @property
+    def op_type(self) -> str:
+        return "ConstantOfShape"
+
+    def infer_shape(self, node, ctx) -> None:
+        sympy_shape = ctx.get_int_or_float_values(node)[0]
+        vi = ctx.known_vi_[node.output[0]]
+        if sympy_shape is not None:
+            if type(sympy_shape) != list:
+                sympy_shape = [sympy_shape]
+            ctx.update_computed_dims(sympy_shape)
+            if vi.type.tensor_type.elem_type == onnx.TensorProto.INT64 and all(is_literal(x) for x in sympy_shape):
+                ctx.sympy_data_[node.output[0]] = np.ones(
+                    [int(x) for x in sympy_shape], dtype=np.int64
+                ) * numpy_helper.to_array(get_attribute(node, "value", 0))
+        else:
+            sympy_shape = ctx.new_symbolic_shape(ctx.get_shape(node, 0)[0], node)
+
+        vi.CopyFrom(
+            helper.make_tensor_value_info(
+                node.output[0],
+                vi.type.tensor_type.elem_type,
+                get_shape_from_sympy_shape(sympy_shape),
+            )
+        )
+
+
+register_shape_handler(ConstantOfShapeHandler())
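When the requested shape is fully literal and the output element type is INT64, the handler also materializes the filled tensor into ctx.sympy_data_ so that later shape-consuming ops can fold it, mirroring this numpy computation (illustration only):

import numpy as np

shape, value = [2, 3], np.int64(7)
print(np.ones([int(x) for x in shape], dtype=np.int64) * value)
# [[7 7 7]
#  [7 7 7]]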
onnxslim/core/shape_inference/standard_ops/misc/dequantize_linear.py
@@ -0,0 +1,26 @@
+# Copyright (c) Microsoft Corporation. All rights reserved.
+# Licensed under the MIT License.
+
+"""Shape handler for DequantizeLinear operator."""
+
+from onnx import helper
+
+from ...base import ShapeHandler
+from ...registry import register_shape_handler
+
+
+class DequantizeLinearHandler(ShapeHandler):
+    """Handler for DequantizeLinear operator."""
+
+    @property
+    def op_type(self) -> str:
+        return "DequantizeLinear"
+
+    def infer_shape(self, node, ctx) -> None:
+        output_dtype = ctx.known_vi_[node.input[1]].type.tensor_type.elem_type
+        output_shape = ctx.get_shape(node, 0)
+        vi = ctx.known_vi_[node.output[0]]
+        vi.CopyFrom(helper.make_tensor_value_info(node.output[0], output_dtype, output_shape))
+
+
+register_shape_handler(DequantizeLinearHandler())
onnxslim/core/shape_inference/standard_ops/misc/non_max_suppression.py
@@ -0,0 +1,26 @@
+# Copyright (c) Microsoft Corporation. All rights reserved.
+# Licensed under the MIT License.
+
+"""Shape handler for NonMaxSuppression operator."""
+
+import onnx
+from onnx import helper
+
+from ...base import ShapeHandler
+from ...registry import register_shape_handler
+
+
+class NonMaxSuppressionHandler(ShapeHandler):
+    """Handler for NonMaxSuppression operator."""
+
+    @property
+    def op_type(self) -> str:
+        return "NonMaxSuppression"
+
+    def infer_shape(self, node, ctx) -> None:
+        selected = str(ctx.new_symbolic_dim_from_output(node))
+        vi = ctx.known_vi_[node.output[0]]
+        vi.CopyFrom(helper.make_tensor_value_info(node.output[0], onnx.TensorProto.INT64, [selected, 3]))
+
+
+register_shape_handler(NonMaxSuppressionHandler())
onnxslim/core/shape_inference/standard_ops/misc/non_zero.py
@@ -0,0 +1,26 @@
+# Copyright (c) Microsoft Corporation. All rights reserved.
+# Licensed under the MIT License.
+
+"""Shape handler for NonZero operator."""
+
+from onnx import helper
+
+from ...base import ShapeHandler
+from ...registry import register_shape_handler
+
+
+class NonZeroHandler(ShapeHandler):
+    """Handler for NonZero operator."""
+
+    @property
+    def op_type(self) -> str:
+        return "NonZero"
+
+    def infer_shape(self, node, ctx) -> None:
+        input_rank = ctx.get_shape_rank(node, 0)
+        nz_len = str(ctx.new_symbolic_dim_from_output(node, 0, 1))
+        vi = ctx.known_vi_[node.output[0]]
+        vi.CopyFrom(helper.make_tensor_value_info(node.output[0], vi.type.tensor_type.elem_type, [input_rank, nz_len]))
+
+
+register_shape_handler(NonZeroHandler())
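NonZero returns the coordinates of the non-zero elements as an [input_rank, num_nonzero] tensor, so only the second dimension is data dependent and gets a fresh symbolic dim (illustration only):

import numpy as np

x = np.array([[1, 0], [0, 3]])
print(np.array(np.nonzero(x)).shape)  # (2, 2): [input_rank, num_nonzero]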