onnxslim 0.1.82__py3-none-any.whl → 0.1.83__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (137)
  1. onnxslim/core/optimization/dead_node_elimination.py +85 -4
  2. onnxslim/core/pattern/fusion/convadd.py +21 -1
  3. onnxslim/core/pattern/fusion/convbn.py +21 -4
  4. onnxslim/core/pattern/fusion/convmul.py +23 -5
  5. onnxslim/core/pattern/fusion/padconv.py +5 -0
  6. onnxslim/core/shape_inference/__init__.py +378 -0
  7. onnxslim/core/shape_inference/aten_ops/__init__.py +16 -0
  8. onnxslim/core/shape_inference/aten_ops/argmax.py +47 -0
  9. onnxslim/core/shape_inference/aten_ops/bitwise_or.py +28 -0
  10. onnxslim/core/shape_inference/aten_ops/diagonal.py +52 -0
  11. onnxslim/core/shape_inference/aten_ops/embedding.py +23 -0
  12. onnxslim/core/shape_inference/aten_ops/group_norm.py +41 -0
  13. onnxslim/core/shape_inference/aten_ops/min_max.py +64 -0
  14. onnxslim/core/shape_inference/aten_ops/multinomial.py +39 -0
  15. onnxslim/core/shape_inference/aten_ops/numpy_t.py +22 -0
  16. onnxslim/core/shape_inference/aten_ops/pool2d.py +40 -0
  17. onnxslim/core/shape_inference/aten_ops/unfold.py +44 -0
  18. onnxslim/core/shape_inference/aten_ops/upsample.py +44 -0
  19. onnxslim/core/shape_inference/base.py +111 -0
  20. onnxslim/core/shape_inference/context.py +645 -0
  21. onnxslim/core/shape_inference/contrib_ops/__init__.py +8 -0
  22. onnxslim/core/shape_inference/contrib_ops/attention/__init__.py +15 -0
  23. onnxslim/core/shape_inference/contrib_ops/attention/attention.py +61 -0
  24. onnxslim/core/shape_inference/contrib_ops/attention/decoder_masked_mha.py +37 -0
  25. onnxslim/core/shape_inference/contrib_ops/attention/gated_relative_position_bias.py +35 -0
  26. onnxslim/core/shape_inference/contrib_ops/attention/longformer_attention.py +21 -0
  27. onnxslim/core/shape_inference/contrib_ops/attention/multi_head_attention.py +82 -0
  28. onnxslim/core/shape_inference/contrib_ops/attention/multi_scale_deformable_attn.py +29 -0
  29. onnxslim/core/shape_inference/contrib_ops/attention/packed_attention.py +39 -0
  30. onnxslim/core/shape_inference/contrib_ops/attention/packed_multi_head_attention.py +33 -0
  31. onnxslim/core/shape_inference/contrib_ops/attention/remove_padding.py +41 -0
  32. onnxslim/core/shape_inference/contrib_ops/attention/restore_padding.py +29 -0
  33. onnxslim/core/shape_inference/contrib_ops/misc/__init__.py +15 -0
  34. onnxslim/core/shape_inference/contrib_ops/misc/bias_add.py +21 -0
  35. onnxslim/core/shape_inference/contrib_ops/misc/bias_gelu.py +21 -0
  36. onnxslim/core/shape_inference/contrib_ops/misc/bias_split_gelu.py +30 -0
  37. onnxslim/core/shape_inference/contrib_ops/misc/fast_gelu.py +21 -0
  38. onnxslim/core/shape_inference/contrib_ops/misc/gelu.py +21 -0
  39. onnxslim/core/shape_inference/contrib_ops/misc/gemm_fast_gelu.py +21 -0
  40. onnxslim/core/shape_inference/contrib_ops/misc/gemm_float8.py +21 -0
  41. onnxslim/core/shape_inference/contrib_ops/misc/python_op.py +67 -0
  42. onnxslim/core/shape_inference/contrib_ops/misc/quick_gelu.py +21 -0
  43. onnxslim/core/shape_inference/contrib_ops/misc/rotary_embedding.py +31 -0
  44. onnxslim/core/shape_inference/contrib_ops/normalization/__init__.py +12 -0
  45. onnxslim/core/shape_inference/contrib_ops/normalization/embed_layer_normalization.py +41 -0
  46. onnxslim/core/shape_inference/contrib_ops/normalization/group_norm.py +21 -0
  47. onnxslim/core/shape_inference/contrib_ops/normalization/layer_normalization.py +42 -0
  48. onnxslim/core/shape_inference/contrib_ops/normalization/simplified_layer_normalization.py +23 -0
  49. onnxslim/core/shape_inference/contrib_ops/normalization/skip_group_norm.py +23 -0
  50. onnxslim/core/shape_inference/contrib_ops/normalization/skip_layer_normalization.py +26 -0
  51. onnxslim/core/shape_inference/contrib_ops/normalization/skip_simplified_layer_normalization.py +23 -0
  52. onnxslim/core/shape_inference/registry.py +90 -0
  53. onnxslim/core/shape_inference/standard_ops/__init__.py +11 -0
  54. onnxslim/core/shape_inference/standard_ops/control_flow/__init__.py +8 -0
  55. onnxslim/core/shape_inference/standard_ops/control_flow/if_op.py +43 -0
  56. onnxslim/core/shape_inference/standard_ops/control_flow/loop.py +74 -0
  57. onnxslim/core/shape_inference/standard_ops/control_flow/scan.py +54 -0
  58. onnxslim/core/shape_inference/standard_ops/math/__init__.py +20 -0
  59. onnxslim/core/shape_inference/standard_ops/math/_symbolic_compute.py +34 -0
  60. onnxslim/core/shape_inference/standard_ops/math/add.py +10 -0
  61. onnxslim/core/shape_inference/standard_ops/math/div.py +10 -0
  62. onnxslim/core/shape_inference/standard_ops/math/einsum.py +119 -0
  63. onnxslim/core/shape_inference/standard_ops/math/equal.py +10 -0
  64. onnxslim/core/shape_inference/standard_ops/math/floor.py +10 -0
  65. onnxslim/core/shape_inference/standard_ops/math/matmul.py +21 -0
  66. onnxslim/core/shape_inference/standard_ops/math/matmul_integer.py +23 -0
  67. onnxslim/core/shape_inference/standard_ops/math/max.py +10 -0
  68. onnxslim/core/shape_inference/standard_ops/math/min.py +10 -0
  69. onnxslim/core/shape_inference/standard_ops/math/mul.py +10 -0
  70. onnxslim/core/shape_inference/standard_ops/math/neg.py +10 -0
  71. onnxslim/core/shape_inference/standard_ops/math/reduce_prod.py +27 -0
  72. onnxslim/core/shape_inference/standard_ops/math/reduce_sum.py +53 -0
  73. onnxslim/core/shape_inference/standard_ops/math/sub.py +10 -0
  74. onnxslim/core/shape_inference/standard_ops/math/where.py +10 -0
  75. onnxslim/core/shape_inference/standard_ops/misc/__init__.py +22 -0
  76. onnxslim/core/shape_inference/standard_ops/misc/array_feature_extractor.py +32 -0
  77. onnxslim/core/shape_inference/standard_ops/misc/cast.py +21 -0
  78. onnxslim/core/shape_inference/standard_ops/misc/category_mapper.py +30 -0
  79. onnxslim/core/shape_inference/standard_ops/misc/compress.py +39 -0
  80. onnxslim/core/shape_inference/standard_ops/misc/constant.py +27 -0
  81. onnxslim/core/shape_inference/standard_ops/misc/constant_of_shape.py +45 -0
  82. onnxslim/core/shape_inference/standard_ops/misc/dequantize_linear.py +26 -0
  83. onnxslim/core/shape_inference/standard_ops/misc/non_max_suppression.py +26 -0
  84. onnxslim/core/shape_inference/standard_ops/misc/non_zero.py +26 -0
  85. onnxslim/core/shape_inference/standard_ops/misc/one_hot.py +42 -0
  86. onnxslim/core/shape_inference/standard_ops/misc/quantize_linear.py +29 -0
  87. onnxslim/core/shape_inference/standard_ops/misc/range.py +41 -0
  88. onnxslim/core/shape_inference/standard_ops/misc/relative_position_bias.py +31 -0
  89. onnxslim/core/shape_inference/standard_ops/misc/resize.py +74 -0
  90. onnxslim/core/shape_inference/standard_ops/misc/scatter_elements.py +31 -0
  91. onnxslim/core/shape_inference/standard_ops/misc/softmax_cross_entropy_loss.py +44 -0
  92. onnxslim/core/shape_inference/standard_ops/misc/top_k.py +44 -0
  93. onnxslim/core/shape_inference/standard_ops/nn/__init__.py +18 -0
  94. onnxslim/core/shape_inference/standard_ops/nn/all_reduce.py +9 -0
  95. onnxslim/core/shape_inference/standard_ops/nn/average_pool.py +40 -0
  96. onnxslim/core/shape_inference/standard_ops/nn/batch_normalization.py +26 -0
  97. onnxslim/core/shape_inference/standard_ops/nn/conv.py +33 -0
  98. onnxslim/core/shape_inference/standard_ops/nn/cum_sum.py +9 -0
  99. onnxslim/core/shape_inference/standard_ops/nn/identity.py +9 -0
  100. onnxslim/core/shape_inference/standard_ops/nn/max_pool.py +9 -0
  101. onnxslim/core/shape_inference/standard_ops/nn/memcpy_from_host.py +9 -0
  102. onnxslim/core/shape_inference/standard_ops/nn/memcpy_to_host.py +9 -0
  103. onnxslim/core/shape_inference/standard_ops/nn/moe.py +9 -0
  104. onnxslim/core/shape_inference/standard_ops/nn/nhwc_conv.py +33 -0
  105. onnxslim/core/shape_inference/standard_ops/nn/reciprocal.py +9 -0
  106. onnxslim/core/shape_inference/standard_ops/nn/round.py +9 -0
  107. onnxslim/core/shape_inference/standard_ops/sequence/__init__.py +10 -0
  108. onnxslim/core/shape_inference/standard_ops/sequence/concat_from_sequence.py +40 -0
  109. onnxslim/core/shape_inference/standard_ops/sequence/sequence_at.py +31 -0
  110. onnxslim/core/shape_inference/standard_ops/sequence/sequence_insert.py +26 -0
  111. onnxslim/core/shape_inference/standard_ops/sequence/split_to_sequence.py +24 -0
  112. onnxslim/core/shape_inference/standard_ops/sequence/zip_map.py +36 -0
  113. onnxslim/core/shape_inference/standard_ops/tensor/__init__.py +20 -0
  114. onnxslim/core/shape_inference/standard_ops/tensor/concat.py +62 -0
  115. onnxslim/core/shape_inference/standard_ops/tensor/expand.py +36 -0
  116. onnxslim/core/shape_inference/standard_ops/tensor/gather.py +48 -0
  117. onnxslim/core/shape_inference/standard_ops/tensor/gather_elements.py +31 -0
  118. onnxslim/core/shape_inference/standard_ops/tensor/gather_nd.py +42 -0
  119. onnxslim/core/shape_inference/standard_ops/tensor/pad.py +41 -0
  120. onnxslim/core/shape_inference/standard_ops/tensor/reshape.py +72 -0
  121. onnxslim/core/shape_inference/standard_ops/tensor/shape.py +38 -0
  122. onnxslim/core/shape_inference/standard_ops/tensor/size.py +29 -0
  123. onnxslim/core/shape_inference/standard_ops/tensor/slice.py +183 -0
  124. onnxslim/core/shape_inference/standard_ops/tensor/split.py +57 -0
  125. onnxslim/core/shape_inference/standard_ops/tensor/squeeze.py +69 -0
  126. onnxslim/core/shape_inference/standard_ops/tensor/tile.py +41 -0
  127. onnxslim/core/shape_inference/standard_ops/tensor/transpose.py +30 -0
  128. onnxslim/core/shape_inference/standard_ops/tensor/unsqueeze.py +54 -0
  129. onnxslim/core/shape_inference/utils.py +244 -0
  130. onnxslim/third_party/symbolic_shape_infer.py +73 -3156
  131. onnxslim/utils.py +4 -2
  132. {onnxslim-0.1.82.dist-info → onnxslim-0.1.83.dist-info}/METADATA +21 -11
  133. onnxslim-0.1.83.dist-info/RECORD +187 -0
  134. onnxslim-0.1.82.dist-info/RECORD +0 -63
  135. {onnxslim-0.1.82.dist-info → onnxslim-0.1.83.dist-info}/WHEEL +0 -0
  136. {onnxslim-0.1.82.dist-info → onnxslim-0.1.83.dist-info}/entry_points.txt +0 -0
  137. {onnxslim-0.1.82.dist-info → onnxslim-0.1.83.dist-info}/licenses/LICENSE +0 -0
@@ -0,0 +1,31 @@
1
+ # Copyright (c) Microsoft Corporation. All rights reserved.
2
+ # Licensed under the MIT License.
3
+
4
+ """Shape handler for SequenceAt operator."""
5
+
6
+ import onnx
7
+
8
+ from ...base import ShapeHandler
9
+ from ...registry import register_shape_handler
10
+
11
+
12
class SequenceAtHandler(ShapeHandler):
    """Shape handler for the SequenceAt operator.

    Every dimension reported as unknown (``None``) by ``ctx.get_shape`` for
    input 0 is given a freshly created symbolic name in the output value info.
    """

    @property
    def op_type(self) -> str:
        return "SequenceAt"

    def infer_shape(self, node, ctx) -> None:
        """Replace unknown output dims with new symbolic dims.

        Does nothing when ``ctx.get_shape(node, 0)`` returns ``None``.
        """
        element_shape = ctx.get_shape(node, 0)
        if element_shape is None:
            return
        out_vi = ctx.known_vi_[node.output[0]]
        # NOTE(review): assumes out_vi already holds a tensor shape with at least
        # len(element_shape) dims (presumably filled by an earlier inference pass).
        unknown_axes = (axis for axis, dim in enumerate(element_shape) if dim is None)
        for axis in unknown_axes:
            symbolic = onnx.TensorShapeProto.Dimension()
            symbolic.dim_param = str(ctx.new_symbolic_dim_from_output(node, 0, axis))
            out_vi.type.tensor_type.shape.dim[axis].CopyFrom(symbolic)


register_shape_handler(SequenceAtHandler())
@@ -0,0 +1,26 @@
1
+ # Copyright (c) Microsoft Corporation. All rights reserved.
2
+ # Licensed under the MIT License.
3
+
4
+ """Shape handler for SequenceInsert operator."""
5
+
6
+ from ...base import ShapeHandler
7
+ from ...registry import register_shape_handler
8
+
9
+
10
class SequenceInsertHandler(ShapeHandler):
    """Shape handler for the SequenceInsert operator."""

    @property
    def op_type(self) -> str:
        return "SequenceInsert"

    def infer_shape(self, node, ctx) -> None:
        """Type the output as the input sequence, fused with the inserted tensor's type.

        The output value info is overwritten with a copy of input 0's value info
        (renamed to the output), then ``ctx.fuse_tensor_type`` merges input 1's
        type into it.
        """
        known = ctx.known_vi_
        seq_in, tensor_in = (known[name] for name in (node.input[0], node.input[1]))
        seq_out = known[node.output[0]]

        seq_out.CopyFrom(seq_in)
        seq_out.name = node.output[0]
        ctx.fuse_tensor_type(node, 0, seq_out.type, tensor_in.type)


register_shape_handler(SequenceInsertHandler())
@@ -0,0 +1,24 @@
1
+ # Copyright (c) Microsoft Corporation. All rights reserved.
2
+ # Licensed under the MIT License.
3
+
4
+ """Shape handler for SplitToSequence operator."""
5
+
6
+ from onnx import helper
7
+
8
+ from ...base import ShapeHandler
9
+ from ...registry import register_shape_handler
10
+ from ..tensor.split import infer_split_common
11
+
12
+
13
class SplitToSequenceHandler(ShapeHandler):
    """Shape handler for the SplitToSequence operator."""

    @property
    def op_type(self) -> str:
        return "SplitToSequence"

    def infer_shape(self, node, ctx) -> None:
        """Delegate to the shared Split logic with a sequence value-info factory.

        ``infer_split_common`` (from the Split handler) does the work; this
        handler only supplies ``helper.make_sequence_value_info`` as the factory.
        """
        make_output_vi = helper.make_sequence_value_info
        infer_split_common(node, ctx, make_output_vi)


register_shape_handler(SplitToSequenceHandler())
@@ -0,0 +1,36 @@
1
+ # Copyright (c) Microsoft Corporation. All rights reserved.
2
+ # Licensed under the MIT License.
3
+
4
+ """Shape handler for ZipMap operator."""
5
+
6
+ import onnx
7
+
8
+ from ...base import ShapeHandler
9
+ from ...registry import register_shape_handler
10
+ from ...utils import get_attribute
11
+
12
+
13
class ZipMapHandler(ShapeHandler):
    """Shape handler for the ZipMap operator."""

    @property
    def op_type(self) -> str:
        return "ZipMap"

    def infer_shape(self, node, ctx) -> None:
        """Type the output as ``sequence<map<key, tensor<float>>>``.

        The key type is INT64 when ``classlabels_int64s`` is set, otherwise
        STRING when ``classlabels_strings`` is set. An AssertionError is raised
        if neither attribute is present.
        """
        # checked in order; the first attribute present decides the key type
        key_type_by_attr = (
            ("classlabels_int64s", onnx.TensorProto.INT64),
            ("classlabels_strings", onnx.TensorProto.STRING),
        )
        map_key_type = next(
            (key_type for attr_name, key_type in key_type_by_attr if get_attribute(node, attr_name) is not None),
            None,
        )
        assert map_key_type is not None

        out_name = node.output[0]
        zipped = onnx.ValueInfoProto()
        zipped.name = out_name
        map_type = zipped.type.sequence_type.elem_type.map_type
        map_type.value_type.tensor_type.elem_type = onnx.TensorProto.FLOAT
        map_type.key_type = map_key_type
        ctx.known_vi_[out_name].CopyFrom(zipped)


register_shape_handler(ZipMapHandler())
@@ -0,0 +1,20 @@
1
+ # Copyright (c) Microsoft Corporation. All rights reserved.
2
+ # Licensed under the MIT License.
3
+
4
+ """Tensor manipulation operator shape handlers."""
5
+
6
+ from . import concat
7
+ from . import gather
8
+ from . import gather_elements
9
+ from . import gather_nd
10
+ from . import reshape
11
+ from . import slice
12
+ from . import split
13
+ from . import squeeze
14
+ from . import unsqueeze
15
+ from . import transpose
16
+ from . import tile
17
+ from . import expand
18
+ from . import pad
19
+ from . import shape
20
+ from . import size
@@ -0,0 +1,62 @@
1
+ # Copyright (c) Microsoft Corporation. All rights reserved.
2
+ # Licensed under the MIT License.
3
+
4
+ """Shape handler for Concat operator."""
5
+
6
+ from onnx import helper
7
+
8
+ from ...base import ShapeHandler
9
+ from ...registry import register_shape_handler
10
+ from ...utils import get_attribute, get_shape_from_sympy_shape, handle_negative_axis
11
+
12
+
13
class ConcatHandler(ShapeHandler):
    """Handler for Concat operator."""

    @property
    def op_type(self) -> str:
        return "Concat"

    def infer_shape(self, node, ctx) -> None:
        """Infer Concat's output shape and propagate constant data when possible.

        * If any input is tracked sympy data or an initializer and all input
          values are available via ``ctx.get_int_or_float_values``, their
          concatenation is stored in ``ctx.sympy_data_`` for the output. This is
          only supported when the (normalized) concat axis is 0.
        * The output shape is input 0's shape with the concat-axis dims of all
          inputs summed. On every other axis, differing dims are merged through
          ``ctx.merge_symbols``.
        """
        output_name = node.output[0]
        sympy_shape = ctx.get_sympy_shape(node, 0)
        # Normalize once so a negative axis (e.g. -1 on 1-D shape tensors) is
        # treated exactly like its non-negative equivalent below.
        axis = handle_negative_axis(get_attribute(node, "axis"), len(sympy_shape))

        if any(i in ctx.sympy_data_ or i in ctx.initializers_ for i in node.input):
            values = ctx.get_int_or_float_values(node)
            if all(v is not None for v in values):
                # Values are concatenated as flat lists, which is only correct
                # along the leading axis.
                assert axis == 0
                concatenated = []
                for value in values:
                    if isinstance(value, list):
                        concatenated.extend(value)
                    else:
                        concatenated.append(value)
                ctx.sympy_data_[output_name] = concatenated

        for i_idx in range(1, len(node.input)):
            input_shape = ctx.get_sympy_shape(node, i_idx)
            if input_shape:
                sympy_shape[axis] = sympy_shape[axis] + input_shape[axis]
        ctx.update_computed_dims(sympy_shape)

        # merge symbolic dims for non-concat axes
        for dim_idx in range(len(sympy_shape)):
            if dim_idx == axis:
                continue
            input_shapes = [ctx.get_shape(node, i_idx) for i_idx in range(len(node.input))]
            dims = [shape[dim_idx] for shape in input_shapes if shape]
            if all(dim == dims[0] for dim in dims):
                continue
            merged = ctx.merge_symbols(dims)
            if isinstance(merged, str):
                sympy_shape[dim_idx] = ctx.symbolic_dims_[merged] if merged else None
            else:
                sympy_shape[dim_idx] = merged

        vi = ctx.known_vi_[output_name]
        vi.CopyFrom(
            helper.make_tensor_value_info(
                output_name,
                ctx.known_vi_[node.input[0]].type.tensor_type.elem_type,
                get_shape_from_sympy_shape(sympy_shape),
            )
        )


register_shape_handler(ConcatHandler())
@@ -0,0 +1,36 @@
1
+ # Copyright (c) Microsoft Corporation. All rights reserved.
2
+ # Licensed under the MIT License.
3
+
4
+ """Shape handler for Expand operator."""
5
+
6
+ from onnx import helper
7
+
8
+ from ...base import ShapeHandler
9
+ from ...registry import register_shape_handler
10
+ from ...utils import as_list, get_shape_from_sympy_shape
11
+
12
+
13
class ExpandHandler(ShapeHandler):
    """Shape handler for the Expand operator."""

    @property
    def op_type(self) -> str:
        return "Expand"

    def infer_shape(self, node, ctx) -> None:
        """Broadcast input 0's shape against the target shape held in input 1.

        Nothing is written when the value of input 1 is not available.
        """
        target_shape = as_list(ctx.try_get_value(node, 1), keep_none=True)
        if target_shape is None:
            return

        # the target's dims may introduce computed symbolic expressions
        ctx.update_computed_dims(target_shape)
        broadcast = ctx.broadcast_shapes(ctx.get_shape(node, 0), get_shape_from_sympy_shape(target_shape))

        out_name = node.output[0]
        out_vi = ctx.known_vi_[out_name]
        elem_type = ctx.known_vi_[node.input[0]].type.tensor_type.elem_type
        out_vi.CopyFrom(helper.make_tensor_value_info(out_name, elem_type, broadcast))


register_shape_handler(ExpandHandler())
@@ -0,0 +1,48 @@
1
+ # Copyright (c) Microsoft Corporation. All rights reserved.
2
+ # Licensed under the MIT License.
3
+
4
+ """Shape handler for Gather operator."""
5
+
6
+ import numpy as np
7
+ from onnx import helper
8
+
9
+ from ...base import ShapeHandler
10
+ from ...registry import register_shape_handler
11
+ from ...utils import get_attribute, handle_negative_axis
12
+
13
+
14
class GatherHandler(ShapeHandler):
    """Handler for Gather operator."""

    @property
    def op_type(self) -> str:
        return "Gather"

    def infer_shape(self, node, ctx) -> None:
        """Infer Gather's output shape and propagate 1-D constant data.

        The output shape is ``data[:axis] + indices + data[axis + 1:]``. When
        input 0 is 1-D, has tracked sympy data, and the indices value is known,
        the gathered values are stored in ``ctx.sympy_data_`` for the output.
        """
        data_shape = ctx.get_shape(node, 0)
        axis = handle_negative_axis(get_attribute(node, "axis", 0), len(data_shape))
        indices_shape = ctx.get_shape(node, 1)
        vi = ctx.known_vi_[node.output[0]]
        vi.CopyFrom(
            helper.make_tensor_value_info(
                node.output[0],
                ctx.known_vi_[node.input[0]].type.tensor_type.elem_type,
                data_shape[:axis] + indices_shape + data_shape[axis + 1 :],
            )
        )
        # For 1-D input, do some sympy compute. The normalized axis is used so
        # that axis=-1 (equivalent to 0 for 1-D data) is also handled.
        if node.input[0] in ctx.sympy_data_ and len(data_shape) == 1 and axis == 0:
            idx = ctx.try_get_value(node, 1)
            if idx is not None:
                data = ctx.sympy_data_[node.input[0]]
                if isinstance(data, list):
                    # 1-D indices (ndarray or tracked sympy list) gather element-wise;
                    # Python negative indexing matches ONNX negative indices.
                    if isinstance(idx, list) or (isinstance(idx, np.ndarray) and idx.ndim == 1):
                        ctx.sympy_data_[node.output[0]] = [data[int(i)] for i in idx]
                    else:
                        ctx.sympy_data_[node.output[0]] = data[int(idx)]
                else:
                    # Non-list data stands for a single element: only index 0 / -1 is valid.
                    assert idx in {0, -1}
                    ctx.sympy_data_[node.output[0]] = data


register_shape_handler(GatherHandler())
@@ -0,0 +1,31 @@
1
+ # Copyright (c) Microsoft Corporation. All rights reserved.
2
+ # Licensed under the MIT License.
3
+
4
+ """Shape handler for GatherElements operator."""
5
+
6
+ from onnx import helper
7
+
8
+ from ...base import ShapeHandler
9
+ from ...registry import register_shape_handler
10
+
11
+
12
class GatherElementsHandler(ShapeHandler):
    """Shape handler for the GatherElements operator."""

    @property
    def op_type(self) -> str:
        return "GatherElements"

    def infer_shape(self, node, ctx) -> None:
        """Give the output the indices' shape and the data's element type."""
        out_shape = ctx.get_shape(node, 1)
        out_name = node.output[0]
        out_vi = ctx.known_vi_[out_name]
        elem_type = ctx.known_vi_[node.input[0]].type.tensor_type.elem_type
        out_vi.CopyFrom(helper.make_tensor_value_info(out_name, elem_type, out_shape))


register_shape_handler(GatherElementsHandler())
@@ -0,0 +1,42 @@
1
+ # Copyright (c) Microsoft Corporation. All rights reserved.
2
+ # Licensed under the MIT License.
3
+
4
+ """Shape handler for GatherND operator."""
5
+
6
+ from onnx import helper
7
+
8
+ from ...base import ShapeHandler
9
+ from ...registry import register_shape_handler
10
+ from ...utils import get_attribute, is_literal
11
+
12
+
13
class GatherNDHandler(ShapeHandler):
    """Shape handler for the GatherND operator."""

    @property
    def op_type(self) -> str:
        return "GatherND"

    def infer_shape(self, node, ctx) -> None:
        """Set the output shape to ``indices[:-1] + data[batch_dims + indices[-1]:]``.

        ``batch_dims`` and the last indices dim must both be literals whose sum
        does not exceed the data rank; otherwise an AssertionError is raised.
        """
        data_shape = ctx.get_shape(node, 0)
        rank = len(data_shape)
        indices_shape = ctx.get_shape(node, 1)
        index_depth = indices_shape[-1]
        batch_dims = get_attribute(node, "batch_dims", 0)
        # single short-circuiting expression: the sum is only evaluated for literals
        assert is_literal(index_depth) and is_literal(batch_dims) and (batch_dims + index_depth) <= rank

        leading = indices_shape[:-1]
        trailing = data_shape[batch_dims + index_depth :]
        out_name = node.output[0]
        out_vi = ctx.known_vi_[out_name]
        out_vi.CopyFrom(
            helper.make_tensor_value_info(
                out_name,
                ctx.known_vi_[node.input[0]].type.tensor_type.elem_type,
                leading + trailing,
            )
        )


register_shape_handler(GatherNDHandler())
@@ -0,0 +1,41 @@
1
+ # Copyright (c) Microsoft Corporation. All rights reserved.
2
+ # Licensed under the MIT License.
3
+
4
+ """Shape handler for Pad operator."""
5
+
6
+ from onnx import helper
7
+
8
+ from ...base import ShapeHandler
9
+ from ...registry import register_shape_handler
10
+ from ...utils import get_attribute, get_opset, get_shape_from_sympy_shape
11
+
12
+
13
class PadHandler(ShapeHandler):
    """Handler for Pad operator."""

    @property
    def op_type(self) -> str:
        return "Pad"

    @staticmethod
    def _expand_pads(pads, axes, rank):
        """Scatter per-axis pads (all begins, then all ends for ``axes``) into a full ``2 * rank`` list."""
        axes = [int(a) for a in axes]
        num_axes = len(axes)
        assert len(pads) == 2 * num_axes
        full_pads = [0] * (2 * rank)
        for i, axis in enumerate(axes):
            if axis < 0:
                axis += rank
            full_pads[axis] = pads[i]
            full_pads[axis + rank] = pads[i + num_axes]
        return full_pads

    def infer_shape(self, node, ctx) -> None:
        """Infer Pad's output shape.

        ``pads`` comes from the attribute for opset <= 10 and from input 1
        otherwise. Since opset 18, an optional ``axes`` input (index 3) limits
        ``pads`` to the listed axes (``2 * len(axes)`` values). These are expanded
        to a full ``2 * rank`` list before use. When the pads (or axes) value is
        not statically known, the output gets a new symbolic shape of the same
        rank.
        """
        opset = get_opset(ctx.out_mp_)
        if opset <= 10:
            pads = get_attribute(node, "pads")
        else:
            pads = ctx.try_get_value(node, 1)

        sympy_shape = ctx.get_sympy_shape(node, 0)
        rank = len(sympy_shape)

        if pads is not None and opset >= 18 and len(node.input) > 3 and node.input[3]:
            axes = ctx.try_get_value(node, 3)
            # unknown axes -> fall back to a fully symbolic output shape
            pads = None if axes is None else self._expand_pads(pads, axes, rank)

        if pads is not None:
            assert len(pads) == 2 * rank
            new_sympy_shape = [d + pad_up + pad_down for d, pad_up, pad_down in zip(sympy_shape, pads[:rank], pads[rank:])]
            ctx.update_computed_dims(new_sympy_shape)
        else:
            # dynamic pads: create new symbolic dimensions
            new_sympy_shape = ctx.new_symbolic_shape(rank, node)
        output_tp = ctx.known_vi_[node.input[0]].type.tensor_type.elem_type

        vi = ctx.known_vi_[node.output[0]]
        vi.CopyFrom(helper.make_tensor_value_info(node.output[0], output_tp, get_shape_from_sympy_shape(new_sympy_shape)))


register_shape_handler(PadHandler())
@@ -0,0 +1,72 @@
1
+ # Copyright (c) Microsoft Corporation. All rights reserved.
2
+ # Licensed under the MIT License.
3
+
4
+ """Shape handler for Reshape operator."""
5
+
6
+ import sympy
7
+ from onnx import helper
8
+
9
+ from ...base import ShapeHandler
10
+ from ...registry import register_shape_handler
11
+ from ...utils import get_shape_from_sympy_shape, is_literal
12
+
13
+
14
class ReshapeHandler(ShapeHandler):
    """Handler for Reshape operator."""

    @property
    def op_type(self) -> str:
        return "Reshape"

    def infer_shape(self, node, ctx) -> None:
        """Infer Reshape's output shape, then pass sympy data through.

        If the target shape (input 1) has no known value, the output gets a new
        symbolic shape whose rank is the (literal) length of the shape tensor.
        Otherwise the target is resolved as follows:

        * ``0`` copies the matching input dim, unless ``allowzero=1`` (opset 14+),
          in which case it is a literal zero-sized dim.
        * A single ``-1`` is solved so the total element count is preserved.
        """
        shape_value = ctx.try_get_value(node, 1)
        vi = ctx.known_vi_[node.output[0]]
        if shape_value is None:
            shape_shape = ctx.get_shape(node, 1)
            assert len(shape_shape) == 1
            shape_rank = shape_shape[0]
            assert is_literal(shape_rank)
            vi.CopyFrom(
                helper.make_tensor_value_info(
                    node.output[0],
                    vi.type.tensor_type.elem_type,
                    get_shape_from_sympy_shape(ctx.new_symbolic_shape(shape_rank, node)),
                )
            )
        else:
            allowzero = next(
                (helper.get_attribute_value(attr) for attr in node.attribute if attr.name == "allowzero"),
                0,
            )
            input_sympy_shape = ctx.get_sympy_shape(node, 0)
            total = 1
            for d in input_sympy_shape:
                total = total * d
            new_sympy_shape = []
            deferred_dim_idx = -1
            non_deferred_size = 1
            for i, d in enumerate(shape_value):
                if not allowzero and not isinstance(d, sympy.Symbol) and d == 0:
                    # 0 means "copy the corresponding input dimension"
                    new_sympy_shape.append(input_sympy_shape[i])
                    non_deferred_size = non_deferred_size * input_sympy_shape[i]
                else:
                    new_sympy_shape.append(d)
                if d == -1:
                    deferred_dim_idx = i
                elif d != 0:
                    non_deferred_size = non_deferred_size * d

            assert new_sympy_shape.count(-1) < 2
            if -1 in new_sympy_shape:
                # solve the single -1 so the element count is preserved
                new_sympy_shape[deferred_dim_idx] = total // non_deferred_size

            ctx.update_computed_dims(new_sympy_shape)
            vi.CopyFrom(
                helper.make_tensor_value_info(
                    node.output[0],
                    vi.type.tensor_type.elem_type,
                    get_shape_from_sympy_shape(new_sympy_shape),
                )
            )

        ctx.pass_on_sympy_data(node)


register_shape_handler(ReshapeHandler())
@@ -0,0 +1,38 @@
1
+ # Copyright (c) Microsoft Corporation. All rights reserved.
2
+ # Licensed under the MIT License.
3
+
4
+ """Shape handler for Shape operator."""
5
+
6
+ from ...base import ShapeHandler
7
+ from ...registry import register_shape_handler
8
+ from ...utils import get_attribute
9
+
10
+
11
class ShapeOpHandler(ShapeHandler):
    """Handler for Shape operator."""

    @property
    def op_type(self) -> str:
        return "Shape"

    def infer_shape(self, node, ctx) -> None:
        """Record the (optionally sliced) input shape as the output's sympy data.

        This follows Shape-15 ``start``/``end`` semantics:

        * Negative values count from the back.
        * Both values are then clamped to ``[0, rank]``.
        * ``start >= end`` yields an empty shape.

        Only ``ctx.sympy_data_`` is written here; the output value info is not touched.
        """
        start = get_attribute(node, "start", 0)
        end = get_attribute(node, "end", None)

        full_sympy_shape = ctx.get_sympy_shape(node, 0)
        num_dims = len(full_sympy_shape)

        if start < 0:
            start += num_dims
        if end is None:
            end = num_dims
        elif end < 0:
            end += num_dims

        # Out-of-range axes are clamped rather than rejected, per the ONNX spec.
        start = min(max(start, 0), num_dims)
        end = min(max(end, 0), num_dims)

        # start > end slices to an empty list, matching the spec's empty result.
        ctx.sympy_data_[node.output[0]] = full_sympy_shape[start:end]


register_shape_handler(ShapeOpHandler())
@@ -0,0 +1,29 @@
1
+ # Copyright (c) Microsoft Corporation. All rights reserved.
2
+ # Licensed under the MIT License.
3
+
4
+ """Shape handler for Size operator."""
5
+
6
+ import onnx
7
+ from onnx import helper
8
+
9
+ from ...base import ShapeHandler
10
+ from ...registry import register_shape_handler
11
+ from ...utils import sympy_reduce_product
12
+
13
+
14
class SizeHandler(ShapeHandler):
    """Shape handler for the Size operator."""

    @property
    def op_type(self) -> str:
        return "Size"

    def infer_shape(self, node, ctx) -> None:
        """Store the input's element count as sympy data; type the output as an INT64 scalar."""
        out_name = node.output[0]
        element_count = sympy_reduce_product(ctx.get_sympy_shape(node, 0))
        ctx.sympy_data_[out_name] = element_count

        out_vi = ctx.known_vi_[out_name]
        scalar_int64 = helper.make_tensor_value_info(out_name, onnx.TensorProto.INT64, [])
        out_vi.CopyFrom(scalar_int64)


register_shape_handler(SizeHandler())