onnxslim 0.1.82__py3-none-any.whl → 0.1.83__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (137) hide show
  1. onnxslim/core/optimization/dead_node_elimination.py +85 -4
  2. onnxslim/core/pattern/fusion/convadd.py +21 -1
  3. onnxslim/core/pattern/fusion/convbn.py +21 -4
  4. onnxslim/core/pattern/fusion/convmul.py +23 -5
  5. onnxslim/core/pattern/fusion/padconv.py +5 -0
  6. onnxslim/core/shape_inference/__init__.py +378 -0
  7. onnxslim/core/shape_inference/aten_ops/__init__.py +16 -0
  8. onnxslim/core/shape_inference/aten_ops/argmax.py +47 -0
  9. onnxslim/core/shape_inference/aten_ops/bitwise_or.py +28 -0
  10. onnxslim/core/shape_inference/aten_ops/diagonal.py +52 -0
  11. onnxslim/core/shape_inference/aten_ops/embedding.py +23 -0
  12. onnxslim/core/shape_inference/aten_ops/group_norm.py +41 -0
  13. onnxslim/core/shape_inference/aten_ops/min_max.py +64 -0
  14. onnxslim/core/shape_inference/aten_ops/multinomial.py +39 -0
  15. onnxslim/core/shape_inference/aten_ops/numpy_t.py +22 -0
  16. onnxslim/core/shape_inference/aten_ops/pool2d.py +40 -0
  17. onnxslim/core/shape_inference/aten_ops/unfold.py +44 -0
  18. onnxslim/core/shape_inference/aten_ops/upsample.py +44 -0
  19. onnxslim/core/shape_inference/base.py +111 -0
  20. onnxslim/core/shape_inference/context.py +645 -0
  21. onnxslim/core/shape_inference/contrib_ops/__init__.py +8 -0
  22. onnxslim/core/shape_inference/contrib_ops/attention/__init__.py +15 -0
  23. onnxslim/core/shape_inference/contrib_ops/attention/attention.py +61 -0
  24. onnxslim/core/shape_inference/contrib_ops/attention/decoder_masked_mha.py +37 -0
  25. onnxslim/core/shape_inference/contrib_ops/attention/gated_relative_position_bias.py +35 -0
  26. onnxslim/core/shape_inference/contrib_ops/attention/longformer_attention.py +21 -0
  27. onnxslim/core/shape_inference/contrib_ops/attention/multi_head_attention.py +82 -0
  28. onnxslim/core/shape_inference/contrib_ops/attention/multi_scale_deformable_attn.py +29 -0
  29. onnxslim/core/shape_inference/contrib_ops/attention/packed_attention.py +39 -0
  30. onnxslim/core/shape_inference/contrib_ops/attention/packed_multi_head_attention.py +33 -0
  31. onnxslim/core/shape_inference/contrib_ops/attention/remove_padding.py +41 -0
  32. onnxslim/core/shape_inference/contrib_ops/attention/restore_padding.py +29 -0
  33. onnxslim/core/shape_inference/contrib_ops/misc/__init__.py +15 -0
  34. onnxslim/core/shape_inference/contrib_ops/misc/bias_add.py +21 -0
  35. onnxslim/core/shape_inference/contrib_ops/misc/bias_gelu.py +21 -0
  36. onnxslim/core/shape_inference/contrib_ops/misc/bias_split_gelu.py +30 -0
  37. onnxslim/core/shape_inference/contrib_ops/misc/fast_gelu.py +21 -0
  38. onnxslim/core/shape_inference/contrib_ops/misc/gelu.py +21 -0
  39. onnxslim/core/shape_inference/contrib_ops/misc/gemm_fast_gelu.py +21 -0
  40. onnxslim/core/shape_inference/contrib_ops/misc/gemm_float8.py +21 -0
  41. onnxslim/core/shape_inference/contrib_ops/misc/python_op.py +67 -0
  42. onnxslim/core/shape_inference/contrib_ops/misc/quick_gelu.py +21 -0
  43. onnxslim/core/shape_inference/contrib_ops/misc/rotary_embedding.py +31 -0
  44. onnxslim/core/shape_inference/contrib_ops/normalization/__init__.py +12 -0
  45. onnxslim/core/shape_inference/contrib_ops/normalization/embed_layer_normalization.py +41 -0
  46. onnxslim/core/shape_inference/contrib_ops/normalization/group_norm.py +21 -0
  47. onnxslim/core/shape_inference/contrib_ops/normalization/layer_normalization.py +42 -0
  48. onnxslim/core/shape_inference/contrib_ops/normalization/simplified_layer_normalization.py +23 -0
  49. onnxslim/core/shape_inference/contrib_ops/normalization/skip_group_norm.py +23 -0
  50. onnxslim/core/shape_inference/contrib_ops/normalization/skip_layer_normalization.py +26 -0
  51. onnxslim/core/shape_inference/contrib_ops/normalization/skip_simplified_layer_normalization.py +23 -0
  52. onnxslim/core/shape_inference/registry.py +90 -0
  53. onnxslim/core/shape_inference/standard_ops/__init__.py +11 -0
  54. onnxslim/core/shape_inference/standard_ops/control_flow/__init__.py +8 -0
  55. onnxslim/core/shape_inference/standard_ops/control_flow/if_op.py +43 -0
  56. onnxslim/core/shape_inference/standard_ops/control_flow/loop.py +74 -0
  57. onnxslim/core/shape_inference/standard_ops/control_flow/scan.py +54 -0
  58. onnxslim/core/shape_inference/standard_ops/math/__init__.py +20 -0
  59. onnxslim/core/shape_inference/standard_ops/math/_symbolic_compute.py +34 -0
  60. onnxslim/core/shape_inference/standard_ops/math/add.py +10 -0
  61. onnxslim/core/shape_inference/standard_ops/math/div.py +10 -0
  62. onnxslim/core/shape_inference/standard_ops/math/einsum.py +119 -0
  63. onnxslim/core/shape_inference/standard_ops/math/equal.py +10 -0
  64. onnxslim/core/shape_inference/standard_ops/math/floor.py +10 -0
  65. onnxslim/core/shape_inference/standard_ops/math/matmul.py +21 -0
  66. onnxslim/core/shape_inference/standard_ops/math/matmul_integer.py +23 -0
  67. onnxslim/core/shape_inference/standard_ops/math/max.py +10 -0
  68. onnxslim/core/shape_inference/standard_ops/math/min.py +10 -0
  69. onnxslim/core/shape_inference/standard_ops/math/mul.py +10 -0
  70. onnxslim/core/shape_inference/standard_ops/math/neg.py +10 -0
  71. onnxslim/core/shape_inference/standard_ops/math/reduce_prod.py +27 -0
  72. onnxslim/core/shape_inference/standard_ops/math/reduce_sum.py +53 -0
  73. onnxslim/core/shape_inference/standard_ops/math/sub.py +10 -0
  74. onnxslim/core/shape_inference/standard_ops/math/where.py +10 -0
  75. onnxslim/core/shape_inference/standard_ops/misc/__init__.py +22 -0
  76. onnxslim/core/shape_inference/standard_ops/misc/array_feature_extractor.py +32 -0
  77. onnxslim/core/shape_inference/standard_ops/misc/cast.py +21 -0
  78. onnxslim/core/shape_inference/standard_ops/misc/category_mapper.py +30 -0
  79. onnxslim/core/shape_inference/standard_ops/misc/compress.py +39 -0
  80. onnxslim/core/shape_inference/standard_ops/misc/constant.py +27 -0
  81. onnxslim/core/shape_inference/standard_ops/misc/constant_of_shape.py +45 -0
  82. onnxslim/core/shape_inference/standard_ops/misc/dequantize_linear.py +26 -0
  83. onnxslim/core/shape_inference/standard_ops/misc/non_max_suppression.py +26 -0
  84. onnxslim/core/shape_inference/standard_ops/misc/non_zero.py +26 -0
  85. onnxslim/core/shape_inference/standard_ops/misc/one_hot.py +42 -0
  86. onnxslim/core/shape_inference/standard_ops/misc/quantize_linear.py +29 -0
  87. onnxslim/core/shape_inference/standard_ops/misc/range.py +41 -0
  88. onnxslim/core/shape_inference/standard_ops/misc/relative_position_bias.py +31 -0
  89. onnxslim/core/shape_inference/standard_ops/misc/resize.py +74 -0
  90. onnxslim/core/shape_inference/standard_ops/misc/scatter_elements.py +31 -0
  91. onnxslim/core/shape_inference/standard_ops/misc/softmax_cross_entropy_loss.py +44 -0
  92. onnxslim/core/shape_inference/standard_ops/misc/top_k.py +44 -0
  93. onnxslim/core/shape_inference/standard_ops/nn/__init__.py +18 -0
  94. onnxslim/core/shape_inference/standard_ops/nn/all_reduce.py +9 -0
  95. onnxslim/core/shape_inference/standard_ops/nn/average_pool.py +40 -0
  96. onnxslim/core/shape_inference/standard_ops/nn/batch_normalization.py +26 -0
  97. onnxslim/core/shape_inference/standard_ops/nn/conv.py +33 -0
  98. onnxslim/core/shape_inference/standard_ops/nn/cum_sum.py +9 -0
  99. onnxslim/core/shape_inference/standard_ops/nn/identity.py +9 -0
  100. onnxslim/core/shape_inference/standard_ops/nn/max_pool.py +9 -0
  101. onnxslim/core/shape_inference/standard_ops/nn/memcpy_from_host.py +9 -0
  102. onnxslim/core/shape_inference/standard_ops/nn/memcpy_to_host.py +9 -0
  103. onnxslim/core/shape_inference/standard_ops/nn/moe.py +9 -0
  104. onnxslim/core/shape_inference/standard_ops/nn/nhwc_conv.py +33 -0
  105. onnxslim/core/shape_inference/standard_ops/nn/reciprocal.py +9 -0
  106. onnxslim/core/shape_inference/standard_ops/nn/round.py +9 -0
  107. onnxslim/core/shape_inference/standard_ops/sequence/__init__.py +10 -0
  108. onnxslim/core/shape_inference/standard_ops/sequence/concat_from_sequence.py +40 -0
  109. onnxslim/core/shape_inference/standard_ops/sequence/sequence_at.py +31 -0
  110. onnxslim/core/shape_inference/standard_ops/sequence/sequence_insert.py +26 -0
  111. onnxslim/core/shape_inference/standard_ops/sequence/split_to_sequence.py +24 -0
  112. onnxslim/core/shape_inference/standard_ops/sequence/zip_map.py +36 -0
  113. onnxslim/core/shape_inference/standard_ops/tensor/__init__.py +20 -0
  114. onnxslim/core/shape_inference/standard_ops/tensor/concat.py +62 -0
  115. onnxslim/core/shape_inference/standard_ops/tensor/expand.py +36 -0
  116. onnxslim/core/shape_inference/standard_ops/tensor/gather.py +48 -0
  117. onnxslim/core/shape_inference/standard_ops/tensor/gather_elements.py +31 -0
  118. onnxslim/core/shape_inference/standard_ops/tensor/gather_nd.py +42 -0
  119. onnxslim/core/shape_inference/standard_ops/tensor/pad.py +41 -0
  120. onnxslim/core/shape_inference/standard_ops/tensor/reshape.py +72 -0
  121. onnxslim/core/shape_inference/standard_ops/tensor/shape.py +38 -0
  122. onnxslim/core/shape_inference/standard_ops/tensor/size.py +29 -0
  123. onnxslim/core/shape_inference/standard_ops/tensor/slice.py +183 -0
  124. onnxslim/core/shape_inference/standard_ops/tensor/split.py +57 -0
  125. onnxslim/core/shape_inference/standard_ops/tensor/squeeze.py +69 -0
  126. onnxslim/core/shape_inference/standard_ops/tensor/tile.py +41 -0
  127. onnxslim/core/shape_inference/standard_ops/tensor/transpose.py +30 -0
  128. onnxslim/core/shape_inference/standard_ops/tensor/unsqueeze.py +54 -0
  129. onnxslim/core/shape_inference/utils.py +244 -0
  130. onnxslim/third_party/symbolic_shape_infer.py +73 -3156
  131. onnxslim/utils.py +4 -2
  132. {onnxslim-0.1.82.dist-info → onnxslim-0.1.83.dist-info}/METADATA +21 -11
  133. onnxslim-0.1.83.dist-info/RECORD +187 -0
  134. onnxslim-0.1.82.dist-info/RECORD +0 -63
  135. {onnxslim-0.1.82.dist-info → onnxslim-0.1.83.dist-info}/WHEEL +0 -0
  136. {onnxslim-0.1.82.dist-info → onnxslim-0.1.83.dist-info}/entry_points.txt +0 -0
  137. {onnxslim-0.1.82.dist-info → onnxslim-0.1.83.dist-info}/licenses/LICENSE +0 -0
# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License.

"""Shape handler for OneHot operator."""

from onnx import helper

from ...base import ShapeHandler
from ...registry import register_shape_handler
from ...utils import get_attribute, get_shape_from_sympy_shape, handle_negative_axis, is_literal


class OneHotHandler(ShapeHandler):
    """Handler for OneHot operator."""

    @property
    def op_type(self) -> str:
        return "OneHot"

    def infer_shape(self, node, ctx) -> None:
        """Insert a `depth`-sized dimension at `axis` into the indices shape.

        Output dtype is taken from the `values` input (input 2), per the
        ONNX OneHot spec.
        """
        indices_shape = ctx.get_sympy_shape(node, 0)
        depth = ctx.try_get_value(node, 1)
        # Output rank is input rank + 1; axis defaults to -1 (last position).
        axis = handle_negative_axis(get_attribute(node, "axis", -1), len(indices_shape) + 1)
        # When depth is not statically known, mint a fresh symbolic dim.
        depth_dim = depth if is_literal(depth) else ctx.new_symbolic_dim_from_output(node)
        new_shape = get_shape_from_sympy_shape(
            [*indices_shape[:axis], depth_dim, *indices_shape[axis:]]
        )
        elem_type = ctx.known_vi_[node.input[2]].type.tensor_type.elem_type
        vi = ctx.known_vi_[node.output[0]]
        vi.CopyFrom(helper.make_tensor_value_info(node.output[0], elem_type, new_shape))


register_shape_handler(OneHotHandler())
# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License.

"""Shape handler for QuantizeLinear operator."""

import onnx
from onnx import helper

from ...base import ShapeHandler
from ...registry import register_shape_handler


class QuantizeLinearHandler(ShapeHandler):
    """Handler for QuantizeLinear operator."""

    @property
    def op_type(self) -> str:
        return "QuantizeLinear"

    def infer_shape(self, node, ctx) -> None:
        """Output keeps the input shape; dtype follows the zero-point input.

        When no zero-point input is present the ONNX default of uint8 applies.
        """
        dtype = onnx.TensorProto.UINT8
        has_zero_point = len(node.input) > 2 and node.input[2]
        if has_zero_point:
            dtype = ctx.known_vi_[node.input[2]].type.tensor_type.elem_type
        shape = ctx.get_shape(node, 0)
        out_vi = ctx.known_vi_[node.output[0]]
        out_vi.CopyFrom(helper.make_tensor_value_info(node.output[0], dtype, shape))


register_shape_handler(QuantizeLinearHandler())
# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License.

"""Shape handler for Range operator."""

import sympy
from onnx import helper

from ...base import ShapeHandler
from ...registry import register_shape_handler
from ...utils import as_scalar, get_shape_from_sympy_shape


class RangeHandler(ShapeHandler):
    """Handler for Range operator."""

    @property
    def op_type(self) -> str:
        return "Range"

    def infer_shape(self, node, ctx) -> None:
        """Output is 1-D with length ceil((limit - start) / delta), clamped at 0."""
        vi = ctx.known_vi_[node.output[0]]
        values = ctx.get_int_or_float_values(node)
        if any(v is None for v in values):
            # start/limit/delta not statically known: length becomes a new symbol.
            new_sympy_shape = [ctx.new_symbolic_dim_from_output(node)]
        else:
            start, limit, delta = (as_scalar(v) for v in values)
            new_sympy_shape = [sympy.Max(sympy.ceiling((limit - start) / delta), 0)]
        ctx.update_computed_dims(new_sympy_shape)
        vi.CopyFrom(
            helper.make_tensor_value_info(
                node.output[0],
                ctx.known_vi_[node.input[0]].type.tensor_type.elem_type,
                get_shape_from_sympy_shape(new_sympy_shape),
            )
        )


register_shape_handler(RangeHandler())
# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License.

"""Shape handler for RelativePositionBias operator."""

from onnx import helper

from ...base import ShapeHandler
from ...registry import register_shape_handler


class RelativePositionBiasHandler(ShapeHandler):
    """Handler for RelativePositionBias operator."""

    @property
    def op_type(self) -> str:
        return "RelativePositionBias"

    def infer_shape(self, node, ctx) -> None:
        """Output is [1, num_heads, seq_len, real_seq_len] when both lengths are known."""
        seq_len = ctx.try_get_value(node, 1)
        real_seq_len = ctx.try_get_value(node, 2)
        # Both sequence lengths must be statically known to infer anything.
        if seq_len is None or real_seq_len is None:
            return
        num_heads = ctx.get_sympy_shape(node, 0)[1]
        out_shape = [1, num_heads, str(seq_len), str(real_seq_len)]
        dtype = ctx.known_vi_[node.input[0]].type.tensor_type.elem_type
        out_vi = ctx.known_vi_[node.output[0]]
        out_vi.CopyFrom(helper.make_tensor_value_info(node.output[0], dtype, out_shape))


register_shape_handler(RelativePositionBiasHandler())
# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License.

"""Shape handler for Resize operator."""

import numpy as np
import sympy
from onnx import helper

from ...base import ShapeHandler
from ...registry import register_shape_handler
from ...utils import get_attribute, get_opset, get_shape_from_sympy_shape


class ResizeHandler(ShapeHandler):
    """Handler for Resize operator."""

    @property
    def op_type(self) -> str:
        return "Resize"

    def infer_shape(self, node, ctx) -> None:
        """Infer the resized shape from scales/sizes (input layout is opset-dependent)."""
        vi = ctx.known_vi_[node.output[0]]
        input_sympy_shape = ctx.get_sympy_shape(node, 0)

        if get_opset(ctx.out_mp_) <= 10:
            # opset <= 10: input 1 is `scales`; without it nothing can be inferred.
            scales = ctx.try_get_value(node, 1)
            if scales is None:
                return
            new_sympy_shape = [
                sympy.simplify(sympy.floor(dim * s)) for dim, s in zip(input_sympy_shape, scales)
            ]
            ctx.update_computed_dims(new_sympy_shape)
        else:
            # opset >= 11: inputs are (X, roi, scales, sizes).
            roi = ctx.try_get_value(node, 1)
            scales = ctx.try_get_value(node, 2)
            sizes = ctx.try_get_value(node, 3)
            if sizes is not None:
                # Explicit output sizes take precedence over scales.
                new_sympy_shape = [sympy.simplify(round(s)) for s in sizes]
                ctx.update_computed_dims(new_sympy_shape)
            elif scales is not None:
                rank = len(scales)
                if get_attribute(node, "coordinate_transformation_mode") == "tf_crop_and_resize":
                    # roi holds per-axis [starts..., ends...] pairs.
                    assert len(roi) == 2 * rank
                    roi_start, roi_end = list(roi)[:rank], list(roi)[rank:]
                else:
                    roi_start, roi_end = [0] * rank, [1] * rank
                scales = scales.tolist() if isinstance(scales, np.ndarray) else list(scales)
                # Round-half-up of dim * (roi extent) * scale.
                new_sympy_shape = [
                    sympy.floor(dim * (end - start) * scale + sympy.Rational(1, 2))
                    for dim, start, end, scale in zip(input_sympy_shape, roi_start, roi_end, scales)
                ]
                ctx.update_computed_dims(new_sympy_shape)
            else:
                # Neither sizes nor scales known statically: emit fresh symbols.
                new_sympy_shape = ctx.new_symbolic_shape(ctx.get_shape_rank(node, 0), node)

        vi.CopyFrom(
            helper.make_tensor_value_info(
                node.output[0],
                ctx.known_vi_[node.input[0]].type.tensor_type.elem_type,
                get_shape_from_sympy_shape(new_sympy_shape),
            )
        )


register_shape_handler(ResizeHandler())
# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License.

"""Shape handler for ScatterElements operator."""

from onnx import helper

from ...base import ShapeHandler
from ...registry import register_shape_handler


class ScatterElementsHandler(ShapeHandler):
    """Handler for ScatterElements operator."""

    @property
    def op_type(self) -> str:
        return "ScatterElements"

    def infer_shape(self, node, ctx) -> None:
        """Output mirrors the data input: same shape and element type."""
        shape = ctx.get_shape(node, 0)
        dtype = ctx.known_vi_[node.input[0]].type.tensor_type.elem_type
        out_vi = ctx.known_vi_[node.output[0]]
        out_vi.CopyFrom(helper.make_tensor_value_info(node.output[0], dtype, shape))


register_shape_handler(ScatterElementsHandler())
# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License.

"""Shape handler for SoftmaxCrossEntropyLoss operator."""

import onnx
from onnx import helper

from ...base import MultiOpHandler, ShapeHandler
from ...registry import register_shape_handler
from ...utils import get_attribute


class SoftmaxCrossEntropyLossHandler(ShapeHandler):
    """Handler for SoftmaxCrossEntropyLoss operator."""

    @property
    def op_type(self) -> str:
        return "SoftmaxCrossEntropyLoss"

    def infer_shape(self, node, ctx) -> None:
        """Loss output (output 0) is a scalar; optional log_prob (output 1) matches input 0.

        An explicit `output_type` attribute overrides the element type taken
        from the scores input.
        """
        vi = ctx.known_vi_[node.output[0]]
        elem_type = ctx.known_vi_[node.input[0]].type.tensor_type.elem_type

        specified_output_type = get_attribute(node, "output_type", None)
        if specified_output_type is not None:
            elem_type = specified_output_type

        vi.type.tensor_type.elem_type = elem_type
        # Scalar loss: an empty TensorShapeProto (rank 0).
        vi.type.tensor_type.shape.CopyFrom(onnx.TensorShapeProto())

        if len(node.output) > 1:
            data_shape = ctx.get_shape(node, 0)
            vi = ctx.known_vi_[node.output[1]]
            # Name the value_info from node.output[1] rather than vi.name:
            # a freshly-created value_info entry may carry an empty name,
            # which would produce a detached/unnamed proto. Matches the
            # convention used by every sibling handler.
            vi.CopyFrom(helper.make_tensor_value_info(node.output[1], elem_type, data_shape))


def _infer_softmax_cross_entropy(node, ctx):
    """Delegate the internal/NLL op variants to the shared handler logic."""
    SoftmaxCrossEntropyLossHandler().infer_shape(node, ctx)


register_shape_handler(SoftmaxCrossEntropyLossHandler())
register_shape_handler(MultiOpHandler("SoftmaxCrossEntropyLossInternal", _infer_softmax_cross_entropy))
register_shape_handler(MultiOpHandler("NegativeLogLikelihoodLossInternal", _infer_softmax_cross_entropy))
# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License.

"""Shape handler for TopK operator."""

from onnx import helper

from ...base import ShapeHandler
from ...registry import register_shape_handler
from ...utils import as_scalar, get_attribute, get_opset, get_shape_from_sympy_shape, handle_negative_axis


class TopKHandler(ShapeHandler):
    """Handler for TopK operator."""

    @property
    def op_type(self) -> str:
        return "TopK"

    def infer_shape(self, node, ctx) -> None:
        """Replace dimension `axis` of the input shape with k on every output."""
        rank = ctx.get_shape_rank(node, 0)
        axis = handle_negative_axis(get_attribute(node, "axis", -1), rank)
        new_shape = ctx.get_shape(node, 0)

        # k is an attribute through opset 9 and becomes input 1 from opset 10.
        if get_opset(ctx.out_mp_) <= 9:
            k = get_attribute(node, "k")
        else:
            k = ctx.get_int_or_float_values(node)[1]

        if k is None:
            k = ctx.new_symbolic_dim_from_output(node)
        else:
            k = as_scalar(k)

        if type(k) in {int, str}:
            new_shape[axis] = k
        else:
            # Sympy value: route through the symbolic-dim machinery.
            sym_shape = ctx.get_sympy_shape(node, 0)
            sym_shape[axis] = k
            ctx.update_computed_dims(sym_shape)
            new_shape = get_shape_from_sympy_shape(sym_shape)

        # Both Values and Indices outputs share the same shape.
        for out_name in node.output:
            vi = ctx.known_vi_[out_name]
            vi.CopyFrom(helper.make_tensor_value_info(out_name, vi.type.tensor_type.elem_type, new_shape))


register_shape_handler(TopKHandler())
# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License.

"""Neural network operator shape handlers.

Importing each submodule registers its handler(s) as a side effect.
"""

from . import (
    all_reduce,
    average_pool,
    batch_normalization,
    conv,
    cum_sum,
    identity,
    max_pool,
    memcpy_from_host,
    memcpy_to_host,
    moe,
    nhwc_conv,
    reciprocal,
    round,
)
# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License.

"""Shape handler for AllReduce operator (element-wise: shape passes through)."""

from ...base import PassthroughHandler
from ...registry import register_shape_handler

register_shape_handler(PassthroughHandler("AllReduce"))
# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License.

"""Shape handler for AveragePool operator."""

from onnx import helper

from ...base import ShapeHandler
from ...registry import register_shape_handler
from ...utils import get_shape_from_sympy_shape


class PoolHandler(ShapeHandler):
    """Handler for pooling operators (AveragePool, MaxPool)."""

    def __init__(self, op_type_name):
        """Store the concrete op type so one class serves both pool variants."""
        super().__init__()
        self._op_type = op_type_name

    @property
    def op_type(self) -> str:
        return self._op_type

    def infer_shape(self, node, ctx) -> None:
        """All pool outputs share the conv/pool output shape computed by ctx."""
        pooled_shape = ctx.compute_conv_pool_shape(node)
        ctx.update_computed_dims(pooled_shape)
        concrete_shape = get_shape_from_sympy_shape(pooled_shape)
        for out_name in node.output:
            # Optional outputs may be present as empty strings; skip those.
            if not out_name:
                continue
            vi = ctx.known_vi_[out_name]
            vi.CopyFrom(
                helper.make_tensor_value_info(out_name, vi.type.tensor_type.elem_type, concrete_shape)
            )


register_shape_handler(PoolHandler("AveragePool"))
# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License.

"""Shape handler for BatchNormalization operator."""

from ...base import ShapeHandler
from ...registry import register_shape_handler


class BatchNormalizationHandler(ShapeHandler):
    """Handler for BatchNormalization operator."""

    @property
    def op_type(self) -> str:
        return "BatchNormalization"

    def infer_shape(self, node, ctx) -> None:
        """Output 0 mirrors input 0; training outputs 1-4 mirror input 1 (scale).

        Works for opsets < 14 and >= 14 alike, since the loop checks
        i < len(node.output) before touching each optional output.
        """
        ctx.propagate_shape_and_type(node)

        # Use range(1, 5) rather than a set literal: explicit, ordered
        # iteration over the optional training output indices.
        for i in range(1, 5):
            if i < len(node.output) and node.output[i]:
                ctx.propagate_shape_and_type(node, input_index=1, output_index=i)


register_shape_handler(BatchNormalizationHandler())
# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License.

"""Shape handler for Conv operator."""

from onnx import helper

from ...base import ShapeHandler
from ...registry import register_shape_handler
from ...utils import get_shape_from_sympy_shape


class ConvHandler(ShapeHandler):
    """Handler for Conv operator."""

    @property
    def op_type(self) -> str:
        return "Conv"

    def infer_shape(self, node, ctx) -> None:
        """Output shape follows the standard conv formula computed by ctx."""
        conv_shape = ctx.compute_conv_pool_shape(node)
        ctx.update_computed_dims(conv_shape)
        out_vi = ctx.known_vi_[node.output[0]]
        out_vi.CopyFrom(
            helper.make_tensor_value_info(
                node.output[0],
                out_vi.type.tensor_type.elem_type,
                get_shape_from_sympy_shape(conv_shape),
            )
        )


register_shape_handler(ConvHandler())
# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License.

"""Shape handler for CumSum operator (output shape equals input shape)."""

from ...base import PassthroughHandler
from ...registry import register_shape_handler

register_shape_handler(PassthroughHandler("CumSum"))
# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License.

"""Shape handler for Identity operator (trivially shape-preserving)."""

from ...base import PassthroughHandler
from ...registry import register_shape_handler

register_shape_handler(PassthroughHandler("Identity"))
# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License.

"""Shape handler for MaxPool operator (reuses the shared pooling handler)."""

from ...registry import register_shape_handler
from .average_pool import PoolHandler

register_shape_handler(PoolHandler("MaxPool"))
# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License.

"""Shape handler for MemcpyFromHost operator (a copy: shape passes through)."""

from ...base import PassthroughHandler
from ...registry import register_shape_handler

register_shape_handler(PassthroughHandler("MemcpyFromHost"))
# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License.

"""Shape handler for MemcpyToHost operator (a copy: shape passes through)."""

from ...base import PassthroughHandler
from ...registry import register_shape_handler

register_shape_handler(PassthroughHandler("MemcpyToHost"))
# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License.

"""Shape handler for MoE operator (treated as shape-preserving passthrough)."""

from ...base import PassthroughHandler
from ...registry import register_shape_handler

register_shape_handler(PassthroughHandler("MoE"))
# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License.

"""Shape handler for NhwcConv operator."""

from onnx import helper

from ...base import ShapeHandler
from ...registry import register_shape_handler
from ...utils import get_shape_from_sympy_shape


class NhwcConvHandler(ShapeHandler):
    """Handler for NhwcConv operator (channels last format)."""

    @property
    def op_type(self) -> str:
        return "NhwcConv"

    def infer_shape(self, node, ctx) -> None:
        """Same as Conv, but the spatial-dim computation runs channels-last."""
        conv_shape = ctx.compute_conv_pool_shape(node, channels_last=True)
        ctx.update_computed_dims(conv_shape)
        dtype = ctx.known_vi_[node.input[0]].type.tensor_type.elem_type
        out_vi = ctx.known_vi_[node.output[0]]
        out_vi.CopyFrom(
            helper.make_tensor_value_info(
                node.output[0], dtype, get_shape_from_sympy_shape(conv_shape)
            )
        )


register_shape_handler(NhwcConvHandler())
# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License.

"""Shape handler for Reciprocal operator (element-wise: shape passes through)."""

from ...base import PassthroughHandler
from ...registry import register_shape_handler

register_shape_handler(PassthroughHandler("Reciprocal"))
# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License.

"""Shape handler for Round operator (element-wise: shape passes through)."""

from ...base import PassthroughHandler
from ...registry import register_shape_handler

register_shape_handler(PassthroughHandler("Round"))
# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License.

"""Sequence operator shape handlers.

Importing each submodule registers its handler(s) as a side effect.
"""

from . import (
    concat_from_sequence,
    sequence_at,
    sequence_insert,
    split_to_sequence,
    zip_map,
)
# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License.

"""Shape handler for ConcatFromSequence operator."""

from onnx import helper

from ...base import ShapeHandler
from ...registry import register_shape_handler
from ...utils import get_attribute, handle_negative_axis


class ConcatFromSequenceHandler(ShapeHandler):
    """Handler for ConcatFromSequence operator."""

    @property
    def op_type(self) -> str:
        return "ConcatFromSequence"

    def infer_shape(self, node, ctx) -> None:
        """Concatenated dim becomes a fresh symbol; `new_axis` inserts a dim instead.

        Output element type is taken from the sequence's tensor element type.
        """
        seq_shape = ctx.get_shape(node, 0)
        new_axis = 1 if get_attribute(node, "new_axis") else 0
        # With new_axis the output rank grows by one, so normalize against it.
        axis = handle_negative_axis(get_attribute(node, "axis"), len(seq_shape) + new_axis)
        concat_dim = str(ctx.new_symbolic_dim_from_output(node, 0, axis))
        if new_axis:
            new_shape = [*seq_shape[:axis], concat_dim, *seq_shape[axis:]]
        else:
            # Copy before mutating: the original assigned `new_shape = seq_shape`
            # and wrote through the alias, mutating the list returned by
            # ctx.get_shape in place.
            new_shape = list(seq_shape)
            new_shape[axis] = concat_dim
        vi = ctx.known_vi_[node.output[0]]
        vi.CopyFrom(
            helper.make_tensor_value_info(
                node.output[0],
                ctx.known_vi_[node.input[0]].type.sequence_type.elem_type.tensor_type.elem_type,
                new_shape,
            )
        )


register_shape_handler(ConcatFromSequenceHandler())