onnxslim 0.1.82__py3-none-any.whl → 0.1.84__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- onnxslim/core/optimization/dead_node_elimination.py +85 -4
- onnxslim/core/pattern/elimination/slice.py +15 -8
- onnxslim/core/pattern/fusion/concat_reshape.py +3 -1
- onnxslim/core/pattern/fusion/convadd.py +23 -7
- onnxslim/core/pattern/fusion/convbn.py +24 -11
- onnxslim/core/pattern/fusion/convmul.py +26 -9
- onnxslim/core/pattern/fusion/gemm.py +7 -5
- onnxslim/core/pattern/fusion/padconv.py +5 -0
- onnxslim/core/shape_inference/__init__.py +378 -0
- onnxslim/core/shape_inference/aten_ops/__init__.py +16 -0
- onnxslim/core/shape_inference/aten_ops/argmax.py +47 -0
- onnxslim/core/shape_inference/aten_ops/bitwise_or.py +28 -0
- onnxslim/core/shape_inference/aten_ops/diagonal.py +52 -0
- onnxslim/core/shape_inference/aten_ops/embedding.py +23 -0
- onnxslim/core/shape_inference/aten_ops/group_norm.py +41 -0
- onnxslim/core/shape_inference/aten_ops/min_max.py +64 -0
- onnxslim/core/shape_inference/aten_ops/multinomial.py +39 -0
- onnxslim/core/shape_inference/aten_ops/numpy_t.py +22 -0
- onnxslim/core/shape_inference/aten_ops/pool2d.py +40 -0
- onnxslim/core/shape_inference/aten_ops/unfold.py +44 -0
- onnxslim/core/shape_inference/aten_ops/upsample.py +44 -0
- onnxslim/core/shape_inference/base.py +111 -0
- onnxslim/core/shape_inference/context.py +645 -0
- onnxslim/core/shape_inference/contrib_ops/__init__.py +8 -0
- onnxslim/core/shape_inference/contrib_ops/attention/__init__.py +15 -0
- onnxslim/core/shape_inference/contrib_ops/attention/attention.py +61 -0
- onnxslim/core/shape_inference/contrib_ops/attention/decoder_masked_mha.py +37 -0
- onnxslim/core/shape_inference/contrib_ops/attention/gated_relative_position_bias.py +35 -0
- onnxslim/core/shape_inference/contrib_ops/attention/longformer_attention.py +21 -0
- onnxslim/core/shape_inference/contrib_ops/attention/multi_head_attention.py +82 -0
- onnxslim/core/shape_inference/contrib_ops/attention/multi_scale_deformable_attn.py +29 -0
- onnxslim/core/shape_inference/contrib_ops/attention/packed_attention.py +39 -0
- onnxslim/core/shape_inference/contrib_ops/attention/packed_multi_head_attention.py +33 -0
- onnxslim/core/shape_inference/contrib_ops/attention/remove_padding.py +41 -0
- onnxslim/core/shape_inference/contrib_ops/attention/restore_padding.py +29 -0
- onnxslim/core/shape_inference/contrib_ops/misc/__init__.py +15 -0
- onnxslim/core/shape_inference/contrib_ops/misc/bias_add.py +21 -0
- onnxslim/core/shape_inference/contrib_ops/misc/bias_gelu.py +21 -0
- onnxslim/core/shape_inference/contrib_ops/misc/bias_split_gelu.py +30 -0
- onnxslim/core/shape_inference/contrib_ops/misc/fast_gelu.py +21 -0
- onnxslim/core/shape_inference/contrib_ops/misc/gelu.py +21 -0
- onnxslim/core/shape_inference/contrib_ops/misc/gemm_fast_gelu.py +21 -0
- onnxslim/core/shape_inference/contrib_ops/misc/gemm_float8.py +21 -0
- onnxslim/core/shape_inference/contrib_ops/misc/python_op.py +67 -0
- onnxslim/core/shape_inference/contrib_ops/misc/quick_gelu.py +21 -0
- onnxslim/core/shape_inference/contrib_ops/misc/rotary_embedding.py +31 -0
- onnxslim/core/shape_inference/contrib_ops/normalization/__init__.py +12 -0
- onnxslim/core/shape_inference/contrib_ops/normalization/embed_layer_normalization.py +41 -0
- onnxslim/core/shape_inference/contrib_ops/normalization/group_norm.py +21 -0
- onnxslim/core/shape_inference/contrib_ops/normalization/layer_normalization.py +42 -0
- onnxslim/core/shape_inference/contrib_ops/normalization/simplified_layer_normalization.py +23 -0
- onnxslim/core/shape_inference/contrib_ops/normalization/skip_group_norm.py +23 -0
- onnxslim/core/shape_inference/contrib_ops/normalization/skip_layer_normalization.py +26 -0
- onnxslim/core/shape_inference/contrib_ops/normalization/skip_simplified_layer_normalization.py +23 -0
- onnxslim/core/shape_inference/registry.py +90 -0
- onnxslim/core/shape_inference/standard_ops/__init__.py +11 -0
- onnxslim/core/shape_inference/standard_ops/control_flow/__init__.py +8 -0
- onnxslim/core/shape_inference/standard_ops/control_flow/if_op.py +43 -0
- onnxslim/core/shape_inference/standard_ops/control_flow/loop.py +74 -0
- onnxslim/core/shape_inference/standard_ops/control_flow/scan.py +54 -0
- onnxslim/core/shape_inference/standard_ops/math/__init__.py +20 -0
- onnxslim/core/shape_inference/standard_ops/math/_symbolic_compute.py +34 -0
- onnxslim/core/shape_inference/standard_ops/math/add.py +10 -0
- onnxslim/core/shape_inference/standard_ops/math/div.py +10 -0
- onnxslim/core/shape_inference/standard_ops/math/einsum.py +119 -0
- onnxslim/core/shape_inference/standard_ops/math/equal.py +10 -0
- onnxslim/core/shape_inference/standard_ops/math/floor.py +10 -0
- onnxslim/core/shape_inference/standard_ops/math/matmul.py +21 -0
- onnxslim/core/shape_inference/standard_ops/math/matmul_integer.py +23 -0
- onnxslim/core/shape_inference/standard_ops/math/max.py +10 -0
- onnxslim/core/shape_inference/standard_ops/math/min.py +10 -0
- onnxslim/core/shape_inference/standard_ops/math/mul.py +10 -0
- onnxslim/core/shape_inference/standard_ops/math/neg.py +10 -0
- onnxslim/core/shape_inference/standard_ops/math/reduce_prod.py +27 -0
- onnxslim/core/shape_inference/standard_ops/math/reduce_sum.py +53 -0
- onnxslim/core/shape_inference/standard_ops/math/sub.py +10 -0
- onnxslim/core/shape_inference/standard_ops/math/where.py +10 -0
- onnxslim/core/shape_inference/standard_ops/misc/__init__.py +22 -0
- onnxslim/core/shape_inference/standard_ops/misc/array_feature_extractor.py +32 -0
- onnxslim/core/shape_inference/standard_ops/misc/cast.py +21 -0
- onnxslim/core/shape_inference/standard_ops/misc/category_mapper.py +30 -0
- onnxslim/core/shape_inference/standard_ops/misc/compress.py +39 -0
- onnxslim/core/shape_inference/standard_ops/misc/constant.py +27 -0
- onnxslim/core/shape_inference/standard_ops/misc/constant_of_shape.py +45 -0
- onnxslim/core/shape_inference/standard_ops/misc/dequantize_linear.py +26 -0
- onnxslim/core/shape_inference/standard_ops/misc/non_max_suppression.py +26 -0
- onnxslim/core/shape_inference/standard_ops/misc/non_zero.py +26 -0
- onnxslim/core/shape_inference/standard_ops/misc/one_hot.py +42 -0
- onnxslim/core/shape_inference/standard_ops/misc/quantize_linear.py +29 -0
- onnxslim/core/shape_inference/standard_ops/misc/range.py +41 -0
- onnxslim/core/shape_inference/standard_ops/misc/relative_position_bias.py +31 -0
- onnxslim/core/shape_inference/standard_ops/misc/resize.py +74 -0
- onnxslim/core/shape_inference/standard_ops/misc/scatter_elements.py +31 -0
- onnxslim/core/shape_inference/standard_ops/misc/softmax_cross_entropy_loss.py +44 -0
- onnxslim/core/shape_inference/standard_ops/misc/top_k.py +44 -0
- onnxslim/core/shape_inference/standard_ops/nn/__init__.py +18 -0
- onnxslim/core/shape_inference/standard_ops/nn/all_reduce.py +9 -0
- onnxslim/core/shape_inference/standard_ops/nn/average_pool.py +40 -0
- onnxslim/core/shape_inference/standard_ops/nn/batch_normalization.py +26 -0
- onnxslim/core/shape_inference/standard_ops/nn/conv.py +33 -0
- onnxslim/core/shape_inference/standard_ops/nn/cum_sum.py +9 -0
- onnxslim/core/shape_inference/standard_ops/nn/identity.py +9 -0
- onnxslim/core/shape_inference/standard_ops/nn/max_pool.py +9 -0
- onnxslim/core/shape_inference/standard_ops/nn/memcpy_from_host.py +9 -0
- onnxslim/core/shape_inference/standard_ops/nn/memcpy_to_host.py +9 -0
- onnxslim/core/shape_inference/standard_ops/nn/moe.py +9 -0
- onnxslim/core/shape_inference/standard_ops/nn/nhwc_conv.py +33 -0
- onnxslim/core/shape_inference/standard_ops/nn/reciprocal.py +9 -0
- onnxslim/core/shape_inference/standard_ops/nn/round.py +9 -0
- onnxslim/core/shape_inference/standard_ops/sequence/__init__.py +10 -0
- onnxslim/core/shape_inference/standard_ops/sequence/concat_from_sequence.py +40 -0
- onnxslim/core/shape_inference/standard_ops/sequence/sequence_at.py +31 -0
- onnxslim/core/shape_inference/standard_ops/sequence/sequence_insert.py +26 -0
- onnxslim/core/shape_inference/standard_ops/sequence/split_to_sequence.py +24 -0
- onnxslim/core/shape_inference/standard_ops/sequence/zip_map.py +36 -0
- onnxslim/core/shape_inference/standard_ops/tensor/__init__.py +20 -0
- onnxslim/core/shape_inference/standard_ops/tensor/concat.py +62 -0
- onnxslim/core/shape_inference/standard_ops/tensor/expand.py +36 -0
- onnxslim/core/shape_inference/standard_ops/tensor/gather.py +48 -0
- onnxslim/core/shape_inference/standard_ops/tensor/gather_elements.py +31 -0
- onnxslim/core/shape_inference/standard_ops/tensor/gather_nd.py +42 -0
- onnxslim/core/shape_inference/standard_ops/tensor/pad.py +41 -0
- onnxslim/core/shape_inference/standard_ops/tensor/reshape.py +72 -0
- onnxslim/core/shape_inference/standard_ops/tensor/shape.py +38 -0
- onnxslim/core/shape_inference/standard_ops/tensor/size.py +29 -0
- onnxslim/core/shape_inference/standard_ops/tensor/slice.py +183 -0
- onnxslim/core/shape_inference/standard_ops/tensor/split.py +57 -0
- onnxslim/core/shape_inference/standard_ops/tensor/squeeze.py +69 -0
- onnxslim/core/shape_inference/standard_ops/tensor/tile.py +41 -0
- onnxslim/core/shape_inference/standard_ops/tensor/transpose.py +30 -0
- onnxslim/core/shape_inference/standard_ops/tensor/unsqueeze.py +54 -0
- onnxslim/core/shape_inference/utils.py +244 -0
- onnxslim/third_party/onnx_graphsurgeon/ir/graph.py +0 -103
- onnxslim/third_party/symbolic_shape_infer.py +73 -3156
- onnxslim/utils.py +4 -2
- {onnxslim-0.1.82.dist-info → onnxslim-0.1.84.dist-info}/METADATA +21 -11
- onnxslim-0.1.84.dist-info/RECORD +187 -0
- onnxslim-0.1.82.dist-info/RECORD +0 -63
- {onnxslim-0.1.82.dist-info → onnxslim-0.1.84.dist-info}/WHEEL +0 -0
- {onnxslim-0.1.82.dist-info → onnxslim-0.1.84.dist-info}/entry_points.txt +0 -0
- {onnxslim-0.1.82.dist-info → onnxslim-0.1.84.dist-info}/licenses/LICENSE +0 -0
|
@@ -0,0 +1,42 @@
|
|
|
1
|
+
# Copyright (c) Microsoft Corporation. All rights reserved.
|
|
2
|
+
# Licensed under the MIT License.
|
|
3
|
+
|
|
4
|
+
"""Shape handler for OneHot operator."""
|
|
5
|
+
|
|
6
|
+
from onnx import helper
|
|
7
|
+
|
|
8
|
+
from ...base import ShapeHandler
|
|
9
|
+
from ...registry import register_shape_handler
|
|
10
|
+
from ...utils import get_attribute, get_shape_from_sympy_shape, handle_negative_axis, is_literal
|
|
11
|
+
|
|
12
|
+
|
|
13
|
+
class OneHotHandler(ShapeHandler):
    """Symbolic shape inference for the ONNX ``OneHot`` operator."""

    @property
    def op_type(self) -> str:
        return "OneHot"

    def infer_shape(self, node, ctx) -> None:
        """Insert the depth dimension at ``axis``; dtype follows input 2 (values)."""
        input_sympy_shape = ctx.get_sympy_shape(node, 0)
        depth = ctx.try_get_value(node, 1)
        # The output has one more dimension than the input, so normalize the
        # axis against rank + 1.
        axis = handle_negative_axis(get_attribute(node, "axis", -1), len(input_sympy_shape) + 1)
        if is_literal(depth):
            depth_dim = depth
        else:
            # Depth is not statically known; model it with a fresh symbolic dim.
            depth_dim = ctx.new_symbolic_dim_from_output(node)
        combined = [*input_sympy_shape[:axis], depth_dim, *input_sympy_shape[axis:]]
        output_shape = get_shape_from_sympy_shape(combined)
        output_dtype = ctx.known_vi_[node.input[2]].type.tensor_type.elem_type
        vi = ctx.known_vi_[node.output[0]]
        vi.CopyFrom(helper.make_tensor_value_info(node.output[0], output_dtype, output_shape))


register_shape_handler(OneHotHandler())
|
|
@@ -0,0 +1,29 @@
|
|
|
1
|
+
# Copyright (c) Microsoft Corporation. All rights reserved.
|
|
2
|
+
# Licensed under the MIT License.
|
|
3
|
+
|
|
4
|
+
"""Shape handler for QuantizeLinear operator."""
|
|
5
|
+
|
|
6
|
+
import onnx
|
|
7
|
+
from onnx import helper
|
|
8
|
+
|
|
9
|
+
from ...base import ShapeHandler
|
|
10
|
+
from ...registry import register_shape_handler
|
|
11
|
+
|
|
12
|
+
|
|
13
|
+
class QuantizeLinearHandler(ShapeHandler):
    """Symbolic shape inference for the ONNX ``QuantizeLinear`` operator."""

    @property
    def op_type(self) -> str:
        return "QuantizeLinear"

    def infer_shape(self, node, ctx) -> None:
        """Output shape equals input 0; dtype follows the optional zero-point input."""
        has_zero_point = len(node.input) > 2 and node.input[2]
        if has_zero_point:
            out_dtype = ctx.known_vi_[node.input[2]].type.tensor_type.elem_type
        else:
            # Without a zero-point input the quantized output defaults to uint8.
            out_dtype = onnx.TensorProto.UINT8
        out_shape = ctx.get_shape(node, 0)
        ctx.known_vi_[node.output[0]].CopyFrom(
            helper.make_tensor_value_info(node.output[0], out_dtype, out_shape)
        )


register_shape_handler(QuantizeLinearHandler())
|
|
@@ -0,0 +1,41 @@
|
|
|
1
|
+
# Copyright (c) Microsoft Corporation. All rights reserved.
|
|
2
|
+
# Licensed under the MIT License.
|
|
3
|
+
|
|
4
|
+
"""Shape handler for Range operator."""
|
|
5
|
+
|
|
6
|
+
import sympy
|
|
7
|
+
from onnx import helper
|
|
8
|
+
|
|
9
|
+
from ...base import ShapeHandler
|
|
10
|
+
from ...registry import register_shape_handler
|
|
11
|
+
from ...utils import as_scalar, get_shape_from_sympy_shape
|
|
12
|
+
|
|
13
|
+
|
|
14
|
+
class RangeHandler(ShapeHandler):
    """Symbolic shape inference for the ONNX ``Range`` operator."""

    @property
    def op_type(self) -> str:
        return "Range"

    def infer_shape(self, node, ctx) -> None:
        """Compute the 1-D output length from start/limit/delta when all are known."""
        vi = ctx.known_vi_[node.output[0]]
        values = ctx.get_int_or_float_values(node)
        if any(v is None for v in values):
            # At least one of start/limit/delta is unknown -> symbolic length.
            new_sympy_shape = [ctx.new_symbolic_dim_from_output(node)]
        else:
            start = as_scalar(values[0])
            limit = as_scalar(values[1])
            delta = as_scalar(values[2])
            # Element count is max(ceil((limit - start) / delta), 0).
            new_sympy_shape = [sympy.Max(sympy.ceiling((limit - start) / delta), 0)]
        ctx.update_computed_dims(new_sympy_shape)
        vi.CopyFrom(
            helper.make_tensor_value_info(
                node.output[0],
                ctx.known_vi_[node.input[0]].type.tensor_type.elem_type,
                get_shape_from_sympy_shape(new_sympy_shape),
            )
        )


register_shape_handler(RangeHandler())
|
|
@@ -0,0 +1,31 @@
|
|
|
1
|
+
# Copyright (c) Microsoft Corporation. All rights reserved.
|
|
2
|
+
# Licensed under the MIT License.
|
|
3
|
+
|
|
4
|
+
"""Shape handler for RelativePositionBias operator."""
|
|
5
|
+
|
|
6
|
+
from onnx import helper
|
|
7
|
+
|
|
8
|
+
from ...base import ShapeHandler
|
|
9
|
+
from ...registry import register_shape_handler
|
|
10
|
+
|
|
11
|
+
|
|
12
|
+
class RelativePositionBiasHandler(ShapeHandler):
    """Symbolic shape inference for the contrib ``RelativePositionBias`` operator."""

    @property
    def op_type(self) -> str:
        return "RelativePositionBias"

    def infer_shape(self, node, ctx) -> None:
        """Output is [1, num_heads, seq_len, real_seq_len] when both lengths are known."""
        seq_len = ctx.try_get_value(node, 1)
        real_seq_len = ctx.try_get_value(node, 2)
        # Without both sequence lengths the shape cannot be stated; leave it alone.
        if seq_len is None or real_seq_len is None:
            return
        num_heads = ctx.get_sympy_shape(node, 0)[1]
        bias_shape = [1, num_heads, str(seq_len), str(real_seq_len)]
        out_dtype = ctx.known_vi_[node.input[0]].type.tensor_type.elem_type
        ctx.known_vi_[node.output[0]].CopyFrom(
            helper.make_tensor_value_info(node.output[0], out_dtype, bias_shape)
        )


register_shape_handler(RelativePositionBiasHandler())
|
|
@@ -0,0 +1,74 @@
|
|
|
1
|
+
# Copyright (c) Microsoft Corporation. All rights reserved.
|
|
2
|
+
# Licensed under the MIT License.
|
|
3
|
+
|
|
4
|
+
"""Shape handler for Resize operator."""
|
|
5
|
+
|
|
6
|
+
import numpy as np
|
|
7
|
+
import sympy
|
|
8
|
+
from onnx import helper
|
|
9
|
+
|
|
10
|
+
from ...base import ShapeHandler
|
|
11
|
+
from ...registry import register_shape_handler
|
|
12
|
+
from ...utils import get_attribute, get_opset, get_shape_from_sympy_shape
|
|
13
|
+
|
|
14
|
+
|
|
15
|
+
class ResizeHandler(ShapeHandler):
    """Handler for Resize operator.

    Output dims come from an explicit ``sizes`` input, from a ``scales``
    input (optionally cropped by ``roi``), or fall back to a fully
    symbolic shape when neither is statically known.
    """

    @property
    def op_type(self) -> str:
        return "Resize"

    def infer_shape(self, node, ctx) -> None:
        vi = ctx.known_vi_[node.output[0]]
        input_sympy_shape = ctx.get_sympy_shape(node, 0)
        if get_opset(ctx.out_mp_) <= 10:
            # Opset <= 10: Resize only takes (X, scales); each output dim is
            # floor(input_dim * scale).
            scales = ctx.try_get_value(node, 1)
            if scales is not None:
                new_sympy_shape = [sympy.simplify(sympy.floor(d * s)) for d, s in zip(input_sympy_shape, scales)]
                ctx.update_computed_dims(new_sympy_shape)
                vi.CopyFrom(
                    helper.make_tensor_value_info(
                        node.output[0],
                        ctx.known_vi_[node.input[0]].type.tensor_type.elem_type,
                        get_shape_from_sympy_shape(new_sympy_shape),
                    )
                )
            # When scales is unknown the output value_info is left untouched.
        else:
            # Opset >= 11: inputs are (X, roi, scales, sizes); an explicit
            # sizes input takes precedence over scales.
            roi = ctx.try_get_value(node, 1)
            scales = ctx.try_get_value(node, 2)
            sizes = ctx.try_get_value(node, 3)
            if sizes is not None:
                new_sympy_shape = [sympy.simplify(round(s)) for s in sizes]
                ctx.update_computed_dims(new_sympy_shape)
            elif scales is not None:
                rank = len(scales)
                # NOTE(review): ONNX STRING attributes are bytes when read via
                # helper.get_attribute_value; confirm get_attribute decodes to
                # str here, otherwise this comparison can never match.
                if get_attribute(node, "coordinate_transformation_mode") == "tf_crop_and_resize":
                    assert len(roi) == 2 * rank
                    roi_start = list(roi)[:rank]
                    roi_end = list(roi)[rank:]
                else:
                    # Without cropping the effective ROI spans each full axis.
                    roi_start = [0] * rank
                    roi_end = [1] * rank
                if isinstance(scales, np.ndarray):
                    scales = scales.tolist()
                else:
                    scales = list(scales)
                # floor(x + 1/2) implements round-half-up of the scaled extent.
                new_sympy_shape = [
                    sympy.floor(d * (end - start) * scale + sympy.Rational(1, 2))
                    for d, start, end, scale in zip(input_sympy_shape, roi_start, roi_end, scales)
                ]
                ctx.update_computed_dims(new_sympy_shape)
            else:
                # Neither sizes nor scales is known: every output dim is symbolic.
                new_sympy_shape = ctx.new_symbolic_shape(ctx.get_shape_rank(node, 0), node)

            vi.CopyFrom(
                helper.make_tensor_value_info(
                    node.output[0],
                    ctx.known_vi_[node.input[0]].type.tensor_type.elem_type,
                    get_shape_from_sympy_shape(new_sympy_shape),
                )
            )


register_shape_handler(ResizeHandler())
|
|
@@ -0,0 +1,31 @@
|
|
|
1
|
+
# Copyright (c) Microsoft Corporation. All rights reserved.
|
|
2
|
+
# Licensed under the MIT License.
|
|
3
|
+
|
|
4
|
+
"""Shape handler for ScatterElements operator."""
|
|
5
|
+
|
|
6
|
+
from onnx import helper
|
|
7
|
+
|
|
8
|
+
from ...base import ShapeHandler
|
|
9
|
+
from ...registry import register_shape_handler
|
|
10
|
+
|
|
11
|
+
|
|
12
|
+
class ScatterElementsHandler(ShapeHandler):
    """Symbolic shape inference for the ONNX ``ScatterElements`` operator."""

    @property
    def op_type(self) -> str:
        return "ScatterElements"

    def infer_shape(self, node, ctx) -> None:
        """The output has the same shape and dtype as the ``data`` input (input 0)."""
        out_vi = ctx.known_vi_[node.output[0]]
        out_vi.CopyFrom(
            helper.make_tensor_value_info(
                node.output[0],
                ctx.known_vi_[node.input[0]].type.tensor_type.elem_type,
                ctx.get_shape(node, 0),
            )
        )


register_shape_handler(ScatterElementsHandler())
|
|
@@ -0,0 +1,44 @@
|
|
|
1
|
+
# Copyright (c) Microsoft Corporation. All rights reserved.
|
|
2
|
+
# Licensed under the MIT License.
|
|
3
|
+
|
|
4
|
+
"""Shape handler for SoftmaxCrossEntropyLoss operator."""
|
|
5
|
+
|
|
6
|
+
import onnx
|
|
7
|
+
from onnx import helper
|
|
8
|
+
|
|
9
|
+
from ...base import MultiOpHandler, ShapeHandler
|
|
10
|
+
from ...registry import register_shape_handler
|
|
11
|
+
from ...utils import get_attribute
|
|
12
|
+
|
|
13
|
+
|
|
14
|
+
class SoftmaxCrossEntropyLossHandler(ShapeHandler):
    """Symbolic shape inference for the ``SoftmaxCrossEntropyLoss`` operator."""

    @property
    def op_type(self) -> str:
        return "SoftmaxCrossEntropyLoss"

    def infer_shape(self, node, ctx) -> None:
        """Set output 0's dtype (shape left empty); output 1 mirrors input 0's shape."""
        loss_vi = ctx.known_vi_[node.output[0]]
        # An explicit output_type attribute overrides the dtype inherited from input 0.
        elem_type = get_attribute(node, "output_type", None)
        if elem_type is None:
            elem_type = ctx.known_vi_[node.input[0]].type.tensor_type.elem_type

        loss_vi.type.tensor_type.elem_type = elem_type
        loss_vi.type.tensor_type.shape.CopyFrom(onnx.TensorShapeProto())

        if len(node.output) > 1:
            # Optional second output (log probabilities) has input 0's shape.
            log_prob_vi = ctx.known_vi_[node.output[1]]
            log_prob_vi.CopyFrom(
                helper.make_tensor_value_info(log_prob_vi.name, elem_type, ctx.get_shape(node, 0))
            )


def _infer_softmax_cross_entropy(node, ctx):
    SoftmaxCrossEntropyLossHandler().infer_shape(node, ctx)


register_shape_handler(SoftmaxCrossEntropyLossHandler())
register_shape_handler(MultiOpHandler("SoftmaxCrossEntropyLossInternal", _infer_softmax_cross_entropy))
register_shape_handler(MultiOpHandler("NegativeLogLikelihoodLossInternal", _infer_softmax_cross_entropy))
|
|
@@ -0,0 +1,44 @@
|
|
|
1
|
+
# Copyright (c) Microsoft Corporation. All rights reserved.
|
|
2
|
+
# Licensed under the MIT License.
|
|
3
|
+
|
|
4
|
+
"""Shape handler for TopK operator."""
|
|
5
|
+
|
|
6
|
+
from onnx import helper
|
|
7
|
+
|
|
8
|
+
from ...base import ShapeHandler
|
|
9
|
+
from ...registry import register_shape_handler
|
|
10
|
+
from ...utils import as_scalar, get_attribute, get_opset, get_shape_from_sympy_shape, handle_negative_axis
|
|
11
|
+
|
|
12
|
+
|
|
13
|
+
class TopKHandler(ShapeHandler):
    """Handler for TopK operator.

    Both outputs (values and indices) share the input shape with the
    ``axis`` dimension replaced by ``k``.
    """

    @property
    def op_type(self) -> str:
        return "TopK"

    def infer_shape(self, node, ctx) -> None:
        rank = ctx.get_shape_rank(node, 0)
        axis = handle_negative_axis(get_attribute(node, "axis", -1), rank)
        # NOTE(review): when k is an int/str this list is mutated in place
        # below; confirm ctx.get_shape returns a fresh list rather than a
        # shared/cached one.
        new_shape = ctx.get_shape(node, 0)

        # Opset <= 9 carries k as an attribute; later opsets pass it as input 1.
        if get_opset(ctx.out_mp_) <= 9:
            k = get_attribute(node, "k")
        else:
            k = ctx.get_int_or_float_values(node)[1]

        # Unknown k becomes a fresh symbolic dim named after the output.
        k = ctx.new_symbolic_dim_from_output(node) if k is None else as_scalar(k)
        if type(k) in {int, str}:
            new_shape[axis] = k
        else:
            # k is a sympy expression: rebuild the shape symbolically so the
            # computed dim can be registered.
            new_sympy_shape = ctx.get_sympy_shape(node, 0)
            new_sympy_shape[axis] = k
            ctx.update_computed_dims(new_sympy_shape)
            new_shape = get_shape_from_sympy_shape(new_sympy_shape)

        # Values and indices outputs get the same shape; each keeps its own dtype.
        for i_o in range(len(node.output)):
            vi = ctx.known_vi_[node.output[i_o]]
            vi.CopyFrom(helper.make_tensor_value_info(node.output[i_o], vi.type.tensor_type.elem_type, new_shape))


register_shape_handler(TopKHandler())
|
|
@@ -0,0 +1,18 @@
|
|
|
1
|
+
# Copyright (c) Microsoft Corporation. All rights reserved.
|
|
2
|
+
# Licensed under the MIT License.
|
|
3
|
+
|
|
4
|
+
"""Neural network operator shape handlers."""
|
|
5
|
+
|
|
6
|
+
from . import conv
|
|
7
|
+
from . import nhwc_conv
|
|
8
|
+
from . import average_pool
|
|
9
|
+
from . import max_pool
|
|
10
|
+
from . import batch_normalization
|
|
11
|
+
from . import identity
|
|
12
|
+
from . import cum_sum
|
|
13
|
+
from . import round
|
|
14
|
+
from . import reciprocal
|
|
15
|
+
from . import memcpy_from_host
|
|
16
|
+
from . import memcpy_to_host
|
|
17
|
+
from . import moe
|
|
18
|
+
from . import all_reduce
|
|
@@ -0,0 +1,9 @@
|
|
|
1
|
+
# Copyright (c) Microsoft Corporation. All rights reserved.
|
|
2
|
+
# Licensed under the MIT License.
|
|
3
|
+
|
|
4
|
+
"""Shape handler for AllReduce operator."""
|
|
5
|
+
|
|
6
|
+
from ...base import PassthroughHandler
|
|
7
|
+
from ...registry import register_shape_handler
|
|
8
|
+
|
|
9
|
+
# AllReduce preserves its input's shape and element type.
_all_reduce_handler = PassthroughHandler("AllReduce")
register_shape_handler(_all_reduce_handler)
|
|
@@ -0,0 +1,40 @@
|
|
|
1
|
+
# Copyright (c) Microsoft Corporation. All rights reserved.
|
|
2
|
+
# Licensed under the MIT License.
|
|
3
|
+
|
|
4
|
+
"""Shape handler for AveragePool operator."""
|
|
5
|
+
|
|
6
|
+
from onnx import helper
|
|
7
|
+
|
|
8
|
+
from ...base import ShapeHandler
|
|
9
|
+
from ...registry import register_shape_handler
|
|
10
|
+
from ...utils import get_shape_from_sympy_shape
|
|
11
|
+
|
|
12
|
+
|
|
13
|
+
class PoolHandler(ShapeHandler):
    """Shared shape inference for pooling operators (AveragePool, MaxPool)."""

    def __init__(self, op_type_name):
        super().__init__()
        # The same shape logic serves several pooling ops, so the op type is
        # supplied at construction time.
        self._op_type = op_type_name

    @property
    def op_type(self) -> str:
        return self._op_type

    def infer_shape(self, node, ctx) -> None:
        """Apply the conv/pool output-shape formula to every non-empty output."""
        pooled_shape = ctx.compute_conv_pool_shape(node)
        ctx.update_computed_dims(pooled_shape)
        for out_name in node.output:
            # Optional outputs may be present as empty names; skip them.
            if not out_name:
                continue
            vi = ctx.known_vi_[out_name]
            vi.CopyFrom(
                helper.make_tensor_value_info(
                    out_name,
                    vi.type.tensor_type.elem_type,
                    get_shape_from_sympy_shape(pooled_shape),
                )
            )


register_shape_handler(PoolHandler("AveragePool"))
|
|
@@ -0,0 +1,26 @@
|
|
|
1
|
+
# Copyright (c) Microsoft Corporation. All rights reserved.
|
|
2
|
+
# Licensed under the MIT License.
|
|
3
|
+
|
|
4
|
+
"""Shape handler for BatchNormalization operator."""
|
|
5
|
+
|
|
6
|
+
from ...base import ShapeHandler
|
|
7
|
+
from ...registry import register_shape_handler
|
|
8
|
+
|
|
9
|
+
|
|
10
|
+
class BatchNormalizationHandler(ShapeHandler):
    """Symbolic shape inference for the ``BatchNormalization`` operator."""

    @property
    def op_type(self) -> str:
        return "BatchNormalization"

    def infer_shape(self, node, ctx) -> None:
        """Output 0 mirrors input 0; training outputs 1-4 mirror input 1 (scale)."""
        ctx.propagate_shape_and_type(node)

        # Outputs 1-4 exist only for some opsets/training modes, so guard on
        # both index validity and a non-empty output name.
        for out_idx in (1, 2, 3, 4):
            if out_idx < len(node.output) and node.output[out_idx]:
                ctx.propagate_shape_and_type(node, input_index=1, output_index=out_idx)


register_shape_handler(BatchNormalizationHandler())
|
|
@@ -0,0 +1,33 @@
|
|
|
1
|
+
# Copyright (c) Microsoft Corporation. All rights reserved.
|
|
2
|
+
# Licensed under the MIT License.
|
|
3
|
+
|
|
4
|
+
"""Shape handler for Conv operator."""
|
|
5
|
+
|
|
6
|
+
from onnx import helper
|
|
7
|
+
|
|
8
|
+
from ...base import ShapeHandler
|
|
9
|
+
from ...registry import register_shape_handler
|
|
10
|
+
from ...utils import get_shape_from_sympy_shape
|
|
11
|
+
|
|
12
|
+
|
|
13
|
+
class ConvHandler(ShapeHandler):
    """Symbolic shape inference for the ONNX ``Conv`` operator."""

    @property
    def op_type(self) -> str:
        return "Conv"

    def infer_shape(self, node, ctx) -> None:
        """Compute conv output dims; the output keeps its existing element type."""
        out_sympy_shape = ctx.compute_conv_pool_shape(node)
        ctx.update_computed_dims(out_sympy_shape)
        vi = ctx.known_vi_[node.output[0]]
        vi.CopyFrom(
            helper.make_tensor_value_info(
                node.output[0],
                vi.type.tensor_type.elem_type,
                get_shape_from_sympy_shape(out_sympy_shape),
            )
        )


register_shape_handler(ConvHandler())
|
|
@@ -0,0 +1,9 @@
|
|
|
1
|
+
# Copyright (c) Microsoft Corporation. All rights reserved.
|
|
2
|
+
# Licensed under the MIT License.
|
|
3
|
+
|
|
4
|
+
"""Shape handler for CumSum operator."""
|
|
5
|
+
|
|
6
|
+
from ...base import PassthroughHandler
|
|
7
|
+
from ...registry import register_shape_handler
|
|
8
|
+
|
|
9
|
+
# CumSum's output has the same shape and dtype as its input.
_cum_sum_handler = PassthroughHandler("CumSum")
register_shape_handler(_cum_sum_handler)
|
|
@@ -0,0 +1,9 @@
|
|
|
1
|
+
# Copyright (c) Microsoft Corporation. All rights reserved.
|
|
2
|
+
# Licensed under the MIT License.
|
|
3
|
+
|
|
4
|
+
"""Shape handler for Identity operator."""
|
|
5
|
+
|
|
6
|
+
from ...base import PassthroughHandler
|
|
7
|
+
from ...registry import register_shape_handler
|
|
8
|
+
|
|
9
|
+
# Identity trivially forwards shape and dtype.
_identity_handler = PassthroughHandler("Identity")
register_shape_handler(_identity_handler)
|
|
@@ -0,0 +1,9 @@
|
|
|
1
|
+
# Copyright (c) Microsoft Corporation. All rights reserved.
|
|
2
|
+
# Licensed under the MIT License.
|
|
3
|
+
|
|
4
|
+
"""Shape handler for MaxPool operator."""
|
|
5
|
+
|
|
6
|
+
from ...registry import register_shape_handler
|
|
7
|
+
from .average_pool import PoolHandler
|
|
8
|
+
|
|
9
|
+
# MaxPool reuses the pooling shape logic implemented by PoolHandler.
_max_pool_handler = PoolHandler("MaxPool")
register_shape_handler(_max_pool_handler)
|
|
@@ -0,0 +1,9 @@
|
|
|
1
|
+
# Copyright (c) Microsoft Corporation. All rights reserved.
|
|
2
|
+
# Licensed under the MIT License.
|
|
3
|
+
|
|
4
|
+
"""Shape handler for MemcpyFromHost operator."""
|
|
5
|
+
|
|
6
|
+
from ...base import PassthroughHandler
|
|
7
|
+
from ...registry import register_shape_handler
|
|
8
|
+
|
|
9
|
+
# MemcpyFromHost moves data between devices without changing shape or dtype.
_memcpy_from_host_handler = PassthroughHandler("MemcpyFromHost")
register_shape_handler(_memcpy_from_host_handler)
|
|
@@ -0,0 +1,9 @@
|
|
|
1
|
+
# Copyright (c) Microsoft Corporation. All rights reserved.
|
|
2
|
+
# Licensed under the MIT License.
|
|
3
|
+
|
|
4
|
+
"""Shape handler for MemcpyToHost operator."""
|
|
5
|
+
|
|
6
|
+
from ...base import PassthroughHandler
|
|
7
|
+
from ...registry import register_shape_handler
|
|
8
|
+
|
|
9
|
+
# MemcpyToHost moves data between devices without changing shape or dtype.
_memcpy_to_host_handler = PassthroughHandler("MemcpyToHost")
register_shape_handler(_memcpy_to_host_handler)
|
|
@@ -0,0 +1,9 @@
|
|
|
1
|
+
# Copyright (c) Microsoft Corporation. All rights reserved.
|
|
2
|
+
# Licensed under the MIT License.
|
|
3
|
+
|
|
4
|
+
"""Shape handler for MoE operator."""
|
|
5
|
+
|
|
6
|
+
from ...base import PassthroughHandler
|
|
7
|
+
from ...registry import register_shape_handler
|
|
8
|
+
|
|
9
|
+
# MoE is registered as shape-preserving with respect to its first input.
_moe_handler = PassthroughHandler("MoE")
register_shape_handler(_moe_handler)
|
|
@@ -0,0 +1,33 @@
|
|
|
1
|
+
# Copyright (c) Microsoft Corporation. All rights reserved.
|
|
2
|
+
# Licensed under the MIT License.
|
|
3
|
+
|
|
4
|
+
"""Shape handler for NhwcConv operator."""
|
|
5
|
+
|
|
6
|
+
from onnx import helper
|
|
7
|
+
|
|
8
|
+
from ...base import ShapeHandler
|
|
9
|
+
from ...registry import register_shape_handler
|
|
10
|
+
from ...utils import get_shape_from_sympy_shape
|
|
11
|
+
|
|
12
|
+
|
|
13
|
+
class NhwcConvHandler(ShapeHandler):
    """Symbolic shape inference for ``NhwcConv`` (channels-last convolution)."""

    @property
    def op_type(self) -> str:
        return "NhwcConv"

    def infer_shape(self, node, ctx) -> None:
        """Compute channels-last conv output dims; dtype comes from input 0."""
        out_sympy_shape = ctx.compute_conv_pool_shape(node, channels_last=True)
        ctx.update_computed_dims(out_sympy_shape)
        out_dtype = ctx.known_vi_[node.input[0]].type.tensor_type.elem_type
        ctx.known_vi_[node.output[0]].CopyFrom(
            helper.make_tensor_value_info(
                node.output[0],
                out_dtype,
                get_shape_from_sympy_shape(out_sympy_shape),
            )
        )


register_shape_handler(NhwcConvHandler())
|
|
@@ -0,0 +1,9 @@
|
|
|
1
|
+
# Copyright (c) Microsoft Corporation. All rights reserved.
|
|
2
|
+
# Licensed under the MIT License.
|
|
3
|
+
|
|
4
|
+
"""Shape handler for Reciprocal operator."""
|
|
5
|
+
|
|
6
|
+
from ...base import PassthroughHandler
|
|
7
|
+
from ...registry import register_shape_handler
|
|
8
|
+
|
|
9
|
+
# Reciprocal is elementwise, so shape and dtype pass through unchanged.
_reciprocal_handler = PassthroughHandler("Reciprocal")
register_shape_handler(_reciprocal_handler)
|
|
@@ -0,0 +1,9 @@
|
|
|
1
|
+
# Copyright (c) Microsoft Corporation. All rights reserved.
|
|
2
|
+
# Licensed under the MIT License.
|
|
3
|
+
|
|
4
|
+
"""Shape handler for Round operator."""
|
|
5
|
+
|
|
6
|
+
from ...base import PassthroughHandler
|
|
7
|
+
from ...registry import register_shape_handler
|
|
8
|
+
|
|
9
|
+
# Round is elementwise, so shape and dtype pass through unchanged.
_round_handler = PassthroughHandler("Round")
register_shape_handler(_round_handler)
|
|
@@ -0,0 +1,10 @@
|
|
|
1
|
+
# Copyright (c) Microsoft Corporation. All rights reserved.
|
|
2
|
+
# Licensed under the MIT License.
|
|
3
|
+
|
|
4
|
+
"""Sequence operator shape handlers."""
|
|
5
|
+
|
|
6
|
+
from . import concat_from_sequence
|
|
7
|
+
from . import split_to_sequence
|
|
8
|
+
from . import sequence_at
|
|
9
|
+
from . import sequence_insert
|
|
10
|
+
from . import zip_map
|
|
@@ -0,0 +1,40 @@
|
|
|
1
|
+
# Copyright (c) Microsoft Corporation. All rights reserved.
|
|
2
|
+
# Licensed under the MIT License.
|
|
3
|
+
|
|
4
|
+
"""Shape handler for ConcatFromSequence operator."""
|
|
5
|
+
|
|
6
|
+
from onnx import helper
|
|
7
|
+
|
|
8
|
+
from ...base import ShapeHandler
|
|
9
|
+
from ...registry import register_shape_handler
|
|
10
|
+
from ...utils import get_attribute, handle_negative_axis
|
|
11
|
+
|
|
12
|
+
|
|
13
|
+
class ConcatFromSequenceHandler(ShapeHandler):
    """Symbolic shape inference for the ONNX ``ConcatFromSequence`` operator."""

    @property
    def op_type(self) -> str:
        return "ConcatFromSequence"

    def infer_shape(self, node, ctx) -> None:
        """Concatenate the sequence elements along ``axis``.

        The concatenated extent depends on the (unknown) sequence length, so
        it is modeled with a fresh symbolic dim. With ``new_axis=1`` the
        symbolic dim is inserted at ``axis``; otherwise it replaces the
        element dim at ``axis``. The output dtype is the sequence's element
        tensor dtype.
        """
        seq_shape = ctx.get_shape(node, 0)
        new_axis = 1 if get_attribute(node, "new_axis") else 0
        # new_axis adds a dimension, so normalize axis against rank + new_axis.
        axis = handle_negative_axis(get_attribute(node, "axis"), len(seq_shape) + new_axis)
        concat_dim = str(ctx.new_symbolic_dim_from_output(node, 0, axis))
        if new_axis:
            new_shape = [*seq_shape[:axis], concat_dim, *seq_shape[axis:]]
        else:
            # Copy before assigning: the original code mutated the list returned
            # by ctx.get_shape in place, which corrupts it if that list is
            # shared with cached shape info.
            new_shape = list(seq_shape)
            new_shape[axis] = concat_dim
        vi = ctx.known_vi_[node.output[0]]
        vi.CopyFrom(
            helper.make_tensor_value_info(
                node.output[0],
                ctx.known_vi_[node.input[0]].type.sequence_type.elem_type.tensor_type.elem_type,
                new_shape,
            )
        )


register_shape_handler(ConcatFromSequenceHandler())
|