onnxslim 0.1.82__py3-none-any.whl → 0.1.84__py3-none-any.whl

This diff shows the changes between publicly released package versions as they appear in their respective public registries. It is provided for informational purposes only.
Files changed (141)
  1. onnxslim/core/optimization/dead_node_elimination.py +85 -4
  2. onnxslim/core/pattern/elimination/slice.py +15 -8
  3. onnxslim/core/pattern/fusion/concat_reshape.py +3 -1
  4. onnxslim/core/pattern/fusion/convadd.py +23 -7
  5. onnxslim/core/pattern/fusion/convbn.py +24 -11
  6. onnxslim/core/pattern/fusion/convmul.py +26 -9
  7. onnxslim/core/pattern/fusion/gemm.py +7 -5
  8. onnxslim/core/pattern/fusion/padconv.py +5 -0
  9. onnxslim/core/shape_inference/__init__.py +378 -0
  10. onnxslim/core/shape_inference/aten_ops/__init__.py +16 -0
  11. onnxslim/core/shape_inference/aten_ops/argmax.py +47 -0
  12. onnxslim/core/shape_inference/aten_ops/bitwise_or.py +28 -0
  13. onnxslim/core/shape_inference/aten_ops/diagonal.py +52 -0
  14. onnxslim/core/shape_inference/aten_ops/embedding.py +23 -0
  15. onnxslim/core/shape_inference/aten_ops/group_norm.py +41 -0
  16. onnxslim/core/shape_inference/aten_ops/min_max.py +64 -0
  17. onnxslim/core/shape_inference/aten_ops/multinomial.py +39 -0
  18. onnxslim/core/shape_inference/aten_ops/numpy_t.py +22 -0
  19. onnxslim/core/shape_inference/aten_ops/pool2d.py +40 -0
  20. onnxslim/core/shape_inference/aten_ops/unfold.py +44 -0
  21. onnxslim/core/shape_inference/aten_ops/upsample.py +44 -0
  22. onnxslim/core/shape_inference/base.py +111 -0
  23. onnxslim/core/shape_inference/context.py +645 -0
  24. onnxslim/core/shape_inference/contrib_ops/__init__.py +8 -0
  25. onnxslim/core/shape_inference/contrib_ops/attention/__init__.py +15 -0
  26. onnxslim/core/shape_inference/contrib_ops/attention/attention.py +61 -0
  27. onnxslim/core/shape_inference/contrib_ops/attention/decoder_masked_mha.py +37 -0
  28. onnxslim/core/shape_inference/contrib_ops/attention/gated_relative_position_bias.py +35 -0
  29. onnxslim/core/shape_inference/contrib_ops/attention/longformer_attention.py +21 -0
  30. onnxslim/core/shape_inference/contrib_ops/attention/multi_head_attention.py +82 -0
  31. onnxslim/core/shape_inference/contrib_ops/attention/multi_scale_deformable_attn.py +29 -0
  32. onnxslim/core/shape_inference/contrib_ops/attention/packed_attention.py +39 -0
  33. onnxslim/core/shape_inference/contrib_ops/attention/packed_multi_head_attention.py +33 -0
  34. onnxslim/core/shape_inference/contrib_ops/attention/remove_padding.py +41 -0
  35. onnxslim/core/shape_inference/contrib_ops/attention/restore_padding.py +29 -0
  36. onnxslim/core/shape_inference/contrib_ops/misc/__init__.py +15 -0
  37. onnxslim/core/shape_inference/contrib_ops/misc/bias_add.py +21 -0
  38. onnxslim/core/shape_inference/contrib_ops/misc/bias_gelu.py +21 -0
  39. onnxslim/core/shape_inference/contrib_ops/misc/bias_split_gelu.py +30 -0
  40. onnxslim/core/shape_inference/contrib_ops/misc/fast_gelu.py +21 -0
  41. onnxslim/core/shape_inference/contrib_ops/misc/gelu.py +21 -0
  42. onnxslim/core/shape_inference/contrib_ops/misc/gemm_fast_gelu.py +21 -0
  43. onnxslim/core/shape_inference/contrib_ops/misc/gemm_float8.py +21 -0
  44. onnxslim/core/shape_inference/contrib_ops/misc/python_op.py +67 -0
  45. onnxslim/core/shape_inference/contrib_ops/misc/quick_gelu.py +21 -0
  46. onnxslim/core/shape_inference/contrib_ops/misc/rotary_embedding.py +31 -0
  47. onnxslim/core/shape_inference/contrib_ops/normalization/__init__.py +12 -0
  48. onnxslim/core/shape_inference/contrib_ops/normalization/embed_layer_normalization.py +41 -0
  49. onnxslim/core/shape_inference/contrib_ops/normalization/group_norm.py +21 -0
  50. onnxslim/core/shape_inference/contrib_ops/normalization/layer_normalization.py +42 -0
  51. onnxslim/core/shape_inference/contrib_ops/normalization/simplified_layer_normalization.py +23 -0
  52. onnxslim/core/shape_inference/contrib_ops/normalization/skip_group_norm.py +23 -0
  53. onnxslim/core/shape_inference/contrib_ops/normalization/skip_layer_normalization.py +26 -0
  54. onnxslim/core/shape_inference/contrib_ops/normalization/skip_simplified_layer_normalization.py +23 -0
  55. onnxslim/core/shape_inference/registry.py +90 -0
  56. onnxslim/core/shape_inference/standard_ops/__init__.py +11 -0
  57. onnxslim/core/shape_inference/standard_ops/control_flow/__init__.py +8 -0
  58. onnxslim/core/shape_inference/standard_ops/control_flow/if_op.py +43 -0
  59. onnxslim/core/shape_inference/standard_ops/control_flow/loop.py +74 -0
  60. onnxslim/core/shape_inference/standard_ops/control_flow/scan.py +54 -0
  61. onnxslim/core/shape_inference/standard_ops/math/__init__.py +20 -0
  62. onnxslim/core/shape_inference/standard_ops/math/_symbolic_compute.py +34 -0
  63. onnxslim/core/shape_inference/standard_ops/math/add.py +10 -0
  64. onnxslim/core/shape_inference/standard_ops/math/div.py +10 -0
  65. onnxslim/core/shape_inference/standard_ops/math/einsum.py +119 -0
  66. onnxslim/core/shape_inference/standard_ops/math/equal.py +10 -0
  67. onnxslim/core/shape_inference/standard_ops/math/floor.py +10 -0
  68. onnxslim/core/shape_inference/standard_ops/math/matmul.py +21 -0
  69. onnxslim/core/shape_inference/standard_ops/math/matmul_integer.py +23 -0
  70. onnxslim/core/shape_inference/standard_ops/math/max.py +10 -0
  71. onnxslim/core/shape_inference/standard_ops/math/min.py +10 -0
  72. onnxslim/core/shape_inference/standard_ops/math/mul.py +10 -0
  73. onnxslim/core/shape_inference/standard_ops/math/neg.py +10 -0
  74. onnxslim/core/shape_inference/standard_ops/math/reduce_prod.py +27 -0
  75. onnxslim/core/shape_inference/standard_ops/math/reduce_sum.py +53 -0
  76. onnxslim/core/shape_inference/standard_ops/math/sub.py +10 -0
  77. onnxslim/core/shape_inference/standard_ops/math/where.py +10 -0
  78. onnxslim/core/shape_inference/standard_ops/misc/__init__.py +22 -0
  79. onnxslim/core/shape_inference/standard_ops/misc/array_feature_extractor.py +32 -0
  80. onnxslim/core/shape_inference/standard_ops/misc/cast.py +21 -0
  81. onnxslim/core/shape_inference/standard_ops/misc/category_mapper.py +30 -0
  82. onnxslim/core/shape_inference/standard_ops/misc/compress.py +39 -0
  83. onnxslim/core/shape_inference/standard_ops/misc/constant.py +27 -0
  84. onnxslim/core/shape_inference/standard_ops/misc/constant_of_shape.py +45 -0
  85. onnxslim/core/shape_inference/standard_ops/misc/dequantize_linear.py +26 -0
  86. onnxslim/core/shape_inference/standard_ops/misc/non_max_suppression.py +26 -0
  87. onnxslim/core/shape_inference/standard_ops/misc/non_zero.py +26 -0
  88. onnxslim/core/shape_inference/standard_ops/misc/one_hot.py +42 -0
  89. onnxslim/core/shape_inference/standard_ops/misc/quantize_linear.py +29 -0
  90. onnxslim/core/shape_inference/standard_ops/misc/range.py +41 -0
  91. onnxslim/core/shape_inference/standard_ops/misc/relative_position_bias.py +31 -0
  92. onnxslim/core/shape_inference/standard_ops/misc/resize.py +74 -0
  93. onnxslim/core/shape_inference/standard_ops/misc/scatter_elements.py +31 -0
  94. onnxslim/core/shape_inference/standard_ops/misc/softmax_cross_entropy_loss.py +44 -0
  95. onnxslim/core/shape_inference/standard_ops/misc/top_k.py +44 -0
  96. onnxslim/core/shape_inference/standard_ops/nn/__init__.py +18 -0
  97. onnxslim/core/shape_inference/standard_ops/nn/all_reduce.py +9 -0
  98. onnxslim/core/shape_inference/standard_ops/nn/average_pool.py +40 -0
  99. onnxslim/core/shape_inference/standard_ops/nn/batch_normalization.py +26 -0
  100. onnxslim/core/shape_inference/standard_ops/nn/conv.py +33 -0
  101. onnxslim/core/shape_inference/standard_ops/nn/cum_sum.py +9 -0
  102. onnxslim/core/shape_inference/standard_ops/nn/identity.py +9 -0
  103. onnxslim/core/shape_inference/standard_ops/nn/max_pool.py +9 -0
  104. onnxslim/core/shape_inference/standard_ops/nn/memcpy_from_host.py +9 -0
  105. onnxslim/core/shape_inference/standard_ops/nn/memcpy_to_host.py +9 -0
  106. onnxslim/core/shape_inference/standard_ops/nn/moe.py +9 -0
  107. onnxslim/core/shape_inference/standard_ops/nn/nhwc_conv.py +33 -0
  108. onnxslim/core/shape_inference/standard_ops/nn/reciprocal.py +9 -0
  109. onnxslim/core/shape_inference/standard_ops/nn/round.py +9 -0
  110. onnxslim/core/shape_inference/standard_ops/sequence/__init__.py +10 -0
  111. onnxslim/core/shape_inference/standard_ops/sequence/concat_from_sequence.py +40 -0
  112. onnxslim/core/shape_inference/standard_ops/sequence/sequence_at.py +31 -0
  113. onnxslim/core/shape_inference/standard_ops/sequence/sequence_insert.py +26 -0
  114. onnxslim/core/shape_inference/standard_ops/sequence/split_to_sequence.py +24 -0
  115. onnxslim/core/shape_inference/standard_ops/sequence/zip_map.py +36 -0
  116. onnxslim/core/shape_inference/standard_ops/tensor/__init__.py +20 -0
  117. onnxslim/core/shape_inference/standard_ops/tensor/concat.py +62 -0
  118. onnxslim/core/shape_inference/standard_ops/tensor/expand.py +36 -0
  119. onnxslim/core/shape_inference/standard_ops/tensor/gather.py +48 -0
  120. onnxslim/core/shape_inference/standard_ops/tensor/gather_elements.py +31 -0
  121. onnxslim/core/shape_inference/standard_ops/tensor/gather_nd.py +42 -0
  122. onnxslim/core/shape_inference/standard_ops/tensor/pad.py +41 -0
  123. onnxslim/core/shape_inference/standard_ops/tensor/reshape.py +72 -0
  124. onnxslim/core/shape_inference/standard_ops/tensor/shape.py +38 -0
  125. onnxslim/core/shape_inference/standard_ops/tensor/size.py +29 -0
  126. onnxslim/core/shape_inference/standard_ops/tensor/slice.py +183 -0
  127. onnxslim/core/shape_inference/standard_ops/tensor/split.py +57 -0
  128. onnxslim/core/shape_inference/standard_ops/tensor/squeeze.py +69 -0
  129. onnxslim/core/shape_inference/standard_ops/tensor/tile.py +41 -0
  130. onnxslim/core/shape_inference/standard_ops/tensor/transpose.py +30 -0
  131. onnxslim/core/shape_inference/standard_ops/tensor/unsqueeze.py +54 -0
  132. onnxslim/core/shape_inference/utils.py +244 -0
  133. onnxslim/third_party/onnx_graphsurgeon/ir/graph.py +0 -103
  134. onnxslim/third_party/symbolic_shape_infer.py +73 -3156
  135. onnxslim/utils.py +4 -2
  136. {onnxslim-0.1.82.dist-info → onnxslim-0.1.84.dist-info}/METADATA +21 -11
  137. onnxslim-0.1.84.dist-info/RECORD +187 -0
  138. onnxslim-0.1.82.dist-info/RECORD +0 -63
  139. {onnxslim-0.1.82.dist-info → onnxslim-0.1.84.dist-info}/WHEEL +0 -0
  140. {onnxslim-0.1.82.dist-info → onnxslim-0.1.84.dist-info}/entry_points.txt +0 -0
  141. {onnxslim-0.1.82.dist-info → onnxslim-0.1.84.dist-info}/licenses/LICENSE +0 -0
onnxslim/core/shape_inference/contrib_ops/misc/quick_gelu.py
@@ -0,0 +1,21 @@
+ # Copyright (c) Microsoft Corporation. All rights reserved.
+ # Licensed under the MIT License.
+
+ """Shape handler for QuickGelu operator."""
+
+ from ...base import ShapeHandler
+ from ...registry import register_shape_handler
+
+
+ class QuickGeluHandler(ShapeHandler):
+     """Handler for QuickGelu operator."""
+
+     @property
+     def op_type(self) -> str:
+         return "QuickGelu"
+
+     def infer_shape(self, node, ctx) -> None:
+         ctx.propagate_shape_and_type(node)
+
+
+ register_shape_handler(QuickGeluHandler())
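
base.py (+111 lines) is part of this release but not excerpted here. From the handlers in this diff, the interface they implement looks roughly like the sketch below; the abstract-base details are an assumption, not the actual onnxslim source.

from abc import ABC, abstractmethod


class ShapeHandler(ABC):
    """Sketch of the per-operator handler interface (assumed, not verbatim)."""

    @property
    @abstractmethod
    def op_type(self) -> str:
        """The ONNX op type this handler serves, e.g. "QuickGelu"."""

    @abstractmethod
    def infer_shape(self, node, ctx) -> None:
        """Write inferred shapes/types for the node's outputs into ctx."""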
onnxslim/core/shape_inference/contrib_ops/misc/rotary_embedding.py
@@ -0,0 +1,31 @@
+ # Copyright (c) Microsoft Corporation. All rights reserved.
+ # Licensed under the MIT License.
+
+ """Shape handler for RotaryEmbedding operator."""
+
+ from ...base import ShapeHandler
+ from ...registry import register_shape_handler
+
+
+ class RotaryEmbeddingHandler(ShapeHandler):
+     """Handler for RotaryEmbedding operator."""
+
+     @property
+     def op_type(self) -> str:
+         return "RotaryEmbedding"
+
+     def infer_shape(self, node, ctx) -> None:
+         if len(node.output) == 1:
+             ctx.propagate_shape_and_type(node)
+         elif len(node.output) == 2:
+             # Extraneous constant nodes outputted by RotaryEmbedding function made with `export_modules_as_functions`
+             ctx.propagate_shape_and_type(node, input_index=1, output_index=0)
+             ctx.propagate_shape_and_type(node, input_index=0, output_index=1)
+         elif len(node.output) == 3:
+             # Extraneous constant nodes outputted by RotaryEmbedding function made with `export_modules_as_functions`
+             ctx.propagate_shape_and_type(node, input_index=1, output_index=0)
+             ctx.propagate_shape_and_type(node, input_index=1, output_index=1)
+             ctx.propagate_shape_and_type(node, input_index=0, output_index=2)
+
+
+ register_shape_handler(RotaryEmbeddingHandler())
onnxslim/core/shape_inference/contrib_ops/normalization/__init__.py
@@ -0,0 +1,12 @@
+ # Copyright (c) Microsoft Corporation. All rights reserved.
+ # Licensed under the MIT License.
+
+ """Normalization-related contrib operator shape handlers."""
+
+ from . import layer_normalization
+ from . import simplified_layer_normalization
+ from . import skip_layer_normalization
+ from . import skip_simplified_layer_normalization
+ from . import group_norm
+ from . import skip_group_norm
+ from . import embed_layer_normalization
onnxslim/core/shape_inference/contrib_ops/normalization/embed_layer_normalization.py
@@ -0,0 +1,41 @@
+ # Copyright (c) Microsoft Corporation. All rights reserved.
+ # Licensed under the MIT License.
+
+ """Shape handler for EmbedLayerNormalization operator."""
+
+ import onnx
+ from onnx import helper
+
+ from ...base import ShapeHandler
+ from ...registry import register_shape_handler
+
+
+ class EmbedLayerNormalizationHandler(ShapeHandler):
+     """Handler for EmbedLayerNormalization operator."""
+
+     @property
+     def op_type(self) -> str:
+         return "EmbedLayerNormalization"
+
+     def infer_shape(self, node, ctx) -> None:
+         input_ids_shape = ctx.get_shape(node, 0)
+         word_embedding_shape = ctx.get_shape(node, 2)
+         assert len(input_ids_shape) == 2 and len(word_embedding_shape) == 2
+         output_shape = [*input_ids_shape, word_embedding_shape[1]]
+
+         word_embedding_dtype = ctx.known_vi_[node.input[2]].type.tensor_type.elem_type
+         vi = ctx.known_vi_[node.output[0]]
+         vi.CopyFrom(helper.make_tensor_value_info(node.output[0], word_embedding_dtype, output_shape))
+
+         if len(node.output) > 1 and node.output[1]:
+             mask_index_shape = [input_ids_shape[0]]
+             vi = ctx.known_vi_[node.output[1]]
+             vi.CopyFrom(helper.make_tensor_value_info(node.output[1], onnx.TensorProto.INT32, mask_index_shape))
+
+         if len(node.output) > 2:
+             # Optional output of add before layer normalization is done
+             vi = ctx.known_vi_[node.output[2]]
+             vi.CopyFrom(helper.make_tensor_value_info(node.output[2], word_embedding_dtype, output_shape))
+
+
+ register_shape_handler(EmbedLayerNormalizationHandler())
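
To make the shape arithmetic above concrete (dimension names here are hypothetical): input_ids of shape [batch, seq_len] combined with a [vocab, hidden] word-embedding table yields a [batch, seq_len, hidden] main output, and the optional mask index holds one INT32 value per batch row.

# Hypothetical dims illustrating infer_shape above; not taken from a real model.
input_ids_shape = ["batch", "seq_len"]      # rank 2, as the assert requires
word_embedding_shape = ["vocab", "hidden"]  # rank 2 embedding table

output_shape = [*input_ids_shape, word_embedding_shape[1]]
assert output_shape == ["batch", "seq_len", "hidden"]

mask_index_shape = [input_ids_shape[0]]
assert mask_index_shape == ["batch"]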
onnxslim/core/shape_inference/contrib_ops/normalization/group_norm.py
@@ -0,0 +1,21 @@
+ # Copyright (c) Microsoft Corporation. All rights reserved.
+ # Licensed under the MIT License.
+
+ """Shape handler for GroupNorm operator."""
+
+ from ...base import ShapeHandler
+ from ...registry import register_shape_handler
+
+
+ class GroupNormHandler(ShapeHandler):
+     """Handler for GroupNorm operator."""
+
+     @property
+     def op_type(self) -> str:
+         return "GroupNorm"
+
+     def infer_shape(self, node, ctx) -> None:
+         ctx.propagate_shape_and_type(node)
+
+
+ register_shape_handler(GroupNormHandler())
onnxslim/core/shape_inference/contrib_ops/normalization/layer_normalization.py
@@ -0,0 +1,42 @@
+ # Copyright (c) Microsoft Corporation. All rights reserved.
+ # Licensed under the MIT License.
+
+ """Shape handler for LayerNormalization operator."""
+
+ import onnx
+ from onnx import helper
+
+ from ...base import ShapeHandler
+ from ...registry import register_shape_handler
+ from ...utils import get_attribute, handle_negative_axis
+
+
+ class LayerNormalizationHandler(ShapeHandler):
+     """Handler for LayerNormalization operator."""
+
+     @property
+     def op_type(self) -> str:
+         return "LayerNormalization"
+
+     def infer_shape(self, node, ctx) -> None:
+         ctx.propagate_shape_and_type(node)
+         if len(node.output) > 1:
+             axis = get_attribute(node, "axis")
+             if axis is None:
+                 axis = -1
+             x_shape = ctx.get_shape(node, 0)
+             if x_shape is not None:
+                 rank = len(x_shape)
+                 axis = handle_negative_axis(axis, rank)
+                 mean_shape = x_shape[:axis] + [1 for _ in range(rank - axis)]
+                 mean_dtype = ctx.known_vi_[node.input[0]].type.tensor_type.elem_type
+                 if mean_dtype in {onnx.TensorProto.FLOAT16, onnx.TensorProto.BFLOAT16}:
+                     mean_dtype = onnx.TensorProto.FLOAT
+                 vi = ctx.known_vi_[node.output[1]]
+                 vi.CopyFrom(helper.make_tensor_value_info(node.output[1], mean_dtype, mean_shape))
+                 if len(node.output) > 2:
+                     vi = ctx.known_vi_[node.output[2]]
+                     vi.CopyFrom(helper.make_tensor_value_info(node.output[2], mean_dtype, mean_shape))
+
+
+ register_shape_handler(LayerNormalizationHandler())
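
A worked example of the mean/inv-std-dev shape computed above, assuming handle_negative_axis maps -1 to rank - 1 (the usual convention): every dim from axis onward collapses to 1.

# Hypothetical rank-3 input with the default axis=-1.
x_shape = ["batch", "seq_len", "hidden"]
rank = len(x_shape)
axis = -1 % rank  # handle_negative_axis(-1, 3) == 2 under the usual convention
mean_shape = x_shape[:axis] + [1 for _ in range(rank - axis)]
assert mean_shape == ["batch", "seq_len", 1]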
onnxslim/core/shape_inference/contrib_ops/normalization/simplified_layer_normalization.py
@@ -0,0 +1,23 @@
+ # Copyright (c) Microsoft Corporation. All rights reserved.
+ # Licensed under the MIT License.
+
+ """Shape handler for SimplifiedLayerNormalization operator."""
+
+ from ...base import ShapeHandler
+ from ...registry import register_shape_handler
+ from .layer_normalization import LayerNormalizationHandler
+
+
+ class SimplifiedLayerNormalizationHandler(ShapeHandler):
+     """Handler for SimplifiedLayerNormalization operator."""
+
+     @property
+     def op_type(self) -> str:
+         return "SimplifiedLayerNormalization"
+
+     def infer_shape(self, node, ctx) -> None:
+         # Reuse LayerNormalization handler
+         LayerNormalizationHandler().infer_shape(node, ctx)
+
+
+ register_shape_handler(SimplifiedLayerNormalizationHandler())
onnxslim/core/shape_inference/contrib_ops/normalization/skip_group_norm.py
@@ -0,0 +1,23 @@
+ # Copyright (c) Microsoft Corporation. All rights reserved.
+ # Licensed under the MIT License.
+
+ """Shape handler for SkipGroupNorm operator."""
+
+ from ...base import ShapeHandler
+ from ...registry import register_shape_handler
+
+
+ class SkipGroupNormHandler(ShapeHandler):
+     """Handler for SkipGroupNorm operator."""
+
+     @property
+     def op_type(self) -> str:
+         return "SkipGroupNorm"
+
+     def infer_shape(self, node, ctx) -> None:
+         ctx.propagate_shape_and_type(node, 0, 0)
+         if len(node.output) > 1:
+             ctx.propagate_shape_and_type(node, 0, 1)
+
+
+ register_shape_handler(SkipGroupNormHandler())
onnxslim/core/shape_inference/contrib_ops/normalization/skip_layer_normalization.py
@@ -0,0 +1,26 @@
+ # Copyright (c) Microsoft Corporation. All rights reserved.
+ # Licensed under the MIT License.
+
+ """Shape handler for SkipLayerNormalization operator."""
+
+ from ...base import ShapeHandler
+ from ...registry import register_shape_handler
+
+
+ class SkipLayerNormalizationHandler(ShapeHandler):
+     """Handler for SkipLayerNormalization operator."""
+
+     @property
+     def op_type(self) -> str:
+         return "SkipLayerNormalization"
+
+     def infer_shape(self, node, ctx) -> None:
+         ctx.propagate_shape_and_type(node)
+
+         # If the SkipLayerNormalization node contains the optional
+         # output for inference, infer the shape and type for it too
+         if len(node.output) > 3:
+             ctx.propagate_shape_and_type(node, 0, 3)
+
+
+ register_shape_handler(SkipLayerNormalizationHandler())
onnxslim/core/shape_inference/contrib_ops/normalization/skip_simplified_layer_normalization.py
@@ -0,0 +1,23 @@
+ # Copyright (c) Microsoft Corporation. All rights reserved.
+ # Licensed under the MIT License.
+
+ """Shape handler for SkipSimplifiedLayerNormalization operator."""
+
+ from ...base import ShapeHandler
+ from ...registry import register_shape_handler
+ from .skip_layer_normalization import SkipLayerNormalizationHandler
+
+
+ class SkipSimplifiedLayerNormalizationHandler(ShapeHandler):
+     """Handler for SkipSimplifiedLayerNormalization operator."""
+
+     @property
+     def op_type(self) -> str:
+         return "SkipSimplifiedLayerNormalization"
+
+     def infer_shape(self, node, ctx) -> None:
+         # Reuse SkipLayerNormalization handler
+         SkipLayerNormalizationHandler().infer_shape(node, ctx)
+
+
+ register_shape_handler(SkipSimplifiedLayerNormalizationHandler())
onnxslim/core/shape_inference/registry.py
@@ -0,0 +1,90 @@
+ # Copyright (c) Microsoft Corporation. All rights reserved.
+ # Licensed under the MIT License.
+
+ """Registry for shape inference handlers."""
+
+ from collections import OrderedDict
+
+ # Global registries for shape handlers
+ SHAPE_HANDLERS = OrderedDict()
+ ATEN_SHAPE_HANDLERS = OrderedDict()
+
+
+ def register_shape_handler(handler):
+     """Register a shape handler for a specific ONNX operator type.
+
+     Args:
+         handler: A ShapeHandler instance to register.
+
+     Returns:
+         The registered handler.
+
+     Raises:
+         ValueError: If a handler for the same op_type is already registered.
+     """
+     op_type = handler.op_type
+     if op_type in SHAPE_HANDLERS:
+         raise ValueError(f"Handler for op_type '{op_type}' is already registered")
+     SHAPE_HANDLERS[op_type] = handler
+     return handler
+
+
+ def register_aten_handler(handler):
+     """Register a shape handler for a PyTorch ATen operator.
+
+     Args:
+         handler: A ShapeHandler instance to register.
+
+     Returns:
+         The registered handler.
+
+     Raises:
+         ValueError: If a handler for the same op_name is already registered.
+     """
+     op_name = handler.op_type
+     if op_name in ATEN_SHAPE_HANDLERS:
+         raise ValueError(f"Handler for ATen op '{op_name}' is already registered")
+     ATEN_SHAPE_HANDLERS[op_name] = handler
+     return handler
+
+
+ def get_shape_handler(op_type):
+     """Get the shape handler for a given ONNX operator type.
+
+     Args:
+         op_type: The ONNX operator type string.
+
+     Returns:
+         The registered ShapeHandler or None if not found.
+     """
+     return SHAPE_HANDLERS.get(op_type)
+
+
+ def get_aten_handler(op_name):
+     """Get the shape handler for a given ATen operator name.
+
+     Args:
+         op_name: The ATen operator name string.
+
+     Returns:
+         The registered ShapeHandler or None if not found.
+     """
+     return ATEN_SHAPE_HANDLERS.get(op_name)
+
+
+ def get_all_shape_handlers():
+     """Get all registered shape handlers.
+
+     Returns:
+         OrderedDict of all registered shape handlers.
+     """
+     return SHAPE_HANDLERS.copy()
+
+
+ def get_all_aten_handlers():
+     """Get all registered ATen shape handlers.
+
+     Returns:
+         OrderedDict of all registered ATen shape handlers.
+     """
+     return ATEN_SHAPE_HANDLERS.copy()
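
Handlers self-register at import time via the register_shape_handler(...) call at the bottom of each module, so lookups presumably happen only after the package __init__ has imported every handler module. A hedged usage sketch:

# Sketch only: assumes importing the package triggers handler registration.
from onnxslim.core.shape_inference.registry import get_shape_handler

handler = get_shape_handler("QuickGelu")  # None if no handler is registered
if handler is not None:
    print(handler.op_type)              # -> "QuickGelu"
    # handler.infer_shape(node, ctx)    # node/ctx come from the inference engine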
onnxslim/core/shape_inference/standard_ops/__init__.py
@@ -0,0 +1,11 @@
+ # Copyright (c) Microsoft Corporation. All rights reserved.
+ # Licensed under the MIT License.
+
+ """Standard ONNX operator shape handlers."""
+
+ from . import tensor
+ from . import math
+ from . import nn
+ from . import control_flow
+ from . import sequence
+ from . import misc
onnxslim/core/shape_inference/standard_ops/control_flow/__init__.py
@@ -0,0 +1,8 @@
+ # Copyright (c) Microsoft Corporation. All rights reserved.
+ # Licensed under the MIT License.
+
+ """Control flow operator shape handlers."""
+
+ from . import if_op
+ from . import loop
+ from . import scan
onnxslim/core/shape_inference/standard_ops/control_flow/if_op.py
@@ -0,0 +1,43 @@
+ # Copyright (c) Microsoft Corporation. All rights reserved.
+ # Licensed under the MIT License.
+
+ """Shape handler for If operator."""
+
+ from ...base import ShapeHandler
+ from ...registry import register_shape_handler
+ from ...utils import as_scalar, get_attribute
+
+
+ class IfHandler(ShapeHandler):
+     """Handler for If operator."""
+
+     @property
+     def op_type(self) -> str:
+         return "If"
+
+     def infer_shape(self, node, ctx) -> None:
+         subgraphs = [
+             get_attribute(node, "then_branch"),
+             get_attribute(node, "else_branch"),
+         ]
+
+         cond = ctx.try_get_value(node, 0)
+
+         for i_sub, subgraph in enumerate(subgraphs):
+             subgraph_infer = ctx.onnx_infer_subgraph(node, subgraph, use_node_input=False)
+             for i_out in range(len(node.output)):
+                 vi = ctx.known_vi_[node.output[i_out]]
+                 if i_sub == 0:
+                     vi.CopyFrom(subgraph.output[i_out])
+                     vi.name = node.output[i_out]
+                 else:
+                     ctx.fuse_tensor_type(node, i_out, vi.type, subgraph.output[i_out].type)
+                 if (
+                     cond is not None
+                     and i_sub == (0 if as_scalar(cond) > 0 else 1)
+                     and subgraph.output[i_out].name in subgraph_infer.sympy_data_
+                 ):
+                     ctx.sympy_data_[vi.name] = subgraph_infer.sympy_data_[subgraph.output[i_out].name]
+
+
+ register_shape_handler(IfHandler())
onnxslim/core/shape_inference/standard_ops/control_flow/loop.py
@@ -0,0 +1,74 @@
+ # Copyright (c) Microsoft Corporation. All rights reserved.
+ # Licensed under the MIT License.
+
+ """Shape handler for Loop operator."""
+
+ import logging
+
+ import onnx
+
+ from ...base import ShapeHandler
+ from ...registry import register_shape_handler
+ from ...utils import get_attribute, get_shape_from_value_info, is_sequence
+
+ logger = logging.getLogger(__name__)
+
+
+ class LoopHandler(ShapeHandler):
+     """Handler for Loop operator."""
+
+     @property
+     def op_type(self) -> str:
+         return "Loop"
+
+     def infer_shape(self, node, ctx) -> None:
+         subgraph = get_attribute(node, "body")
+         assert len(subgraph.input) == len(node.input)
+         num_loop_carried = len(node.input) - 2
+
+         for i, si in enumerate(subgraph.input):
+             si_name = si.name
+             si.CopyFrom(ctx.known_vi_[node.input[i]])
+             si.name = si_name
+
+         ctx.onnx_infer_subgraph(node, subgraph)
+
+         need_second_infer = False
+         for i_out in range(1, num_loop_carried + 1):
+             so = subgraph.output[i_out]
+             so_shape = get_shape_from_value_info(so)
+             if is_sequence(so.type):
+                 if so_shape and None in so_shape:
+                     subgraph.input[i_out + 1].type.sequence_type.elem_type.CopyFrom(so.type.sequence_type.elem_type)
+                     need_second_infer = True
+             else:
+                 si = subgraph.input[i_out + 1]
+                 si_shape = get_shape_from_value_info(si)
+                 for di, dims in enumerate(zip(si_shape, so_shape)):
+                     if dims[0] != dims[1]:
+                         new_dim = onnx.TensorShapeProto.Dimension()
+                         new_dim.dim_param = str(ctx.new_symbolic_dim_from_output(node, i_out, di))
+                         si.type.tensor_type.shape.dim[di].CopyFrom(new_dim)
+                         so.type.tensor_type.shape.dim[di].CopyFrom(new_dim)
+                         need_second_infer = True
+
+         if need_second_infer:
+             if ctx.verbose_ > 2:
+                 logger.debug(f"Rerun Loop: {node.name}({node.output[0]}...), because of sequence in loop carried variables")
+             ctx.onnx_infer_subgraph(node, subgraph, inc_subgraph_id=False)
+
+         loop_iter_dim = str(ctx.new_symbolic_dim_from_output(node))
+         for i in range(len(node.output)):
+             vi = ctx.known_vi_[node.output[i]]
+             vi.CopyFrom(subgraph.output[i + 1])
+             if i >= num_loop_carried:
+                 assert not is_sequence(vi.type)
+                 subgraph_vi_dim = subgraph.output[i + 1].type.tensor_type.shape.dim
+                 vi.type.tensor_type.shape.ClearField("dim")
+                 vi_dim = vi.type.tensor_type.shape.dim
+                 vi_dim.add().dim_param = loop_iter_dim
+                 vi_dim.extend(list(subgraph_vi_dim))
+             vi.name = node.output[i]
+
+
+ register_shape_handler(LoopHandler())
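
The final block above prepends a fresh symbolic iteration-count dim to each scan output. Illustrated with hypothetical names (the symbol string is invented for the example):

# If the body emits a per-iteration scan output of shape [hidden], the Loop
# node's corresponding output becomes [<fresh loop dim>, hidden].
per_iteration_dims = ["hidden"]
loop_iter_dim = "loop_0_iters"  # hypothetical name for the fresh symbolic dim
full_output_dims = [loop_iter_dim, *per_iteration_dims]
assert full_output_dims == ["loop_0_iters", "hidden"]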
onnxslim/core/shape_inference/standard_ops/control_flow/scan.py
@@ -0,0 +1,54 @@
+ # Copyright (c) Microsoft Corporation. All rights reserved.
+ # Licensed under the MIT License.
+
+ """Shape handler for Scan operator."""
+
+ from onnx import helper
+
+ from ...base import ShapeHandler
+ from ...registry import register_shape_handler
+ from ...utils import get_attribute, get_shape_from_type_proto, handle_negative_axis
+
+
+ class ScanHandler(ShapeHandler):
+     """Handler for Scan operator."""
+
+     @property
+     def op_type(self) -> str:
+         return "Scan"
+
+     def infer_shape(self, node, ctx) -> None:
+         subgraph = get_attribute(node, "body")
+         num_scan_inputs = get_attribute(node, "num_scan_inputs")
+         scan_input_axes = get_attribute(node, "scan_input_axes", [0] * num_scan_inputs)
+         num_scan_states = len(node.input) - num_scan_inputs
+         scan_input_axes = [
+             handle_negative_axis(ax, ctx.get_shape_rank(node, i + num_scan_states)) for i, ax in enumerate(scan_input_axes)
+         ]
+
+         assert len(subgraph.input) >= len(node.input)
+         subgraph_inputs = subgraph.input[: len(node.input)]
+         for i, si in enumerate(subgraph_inputs):
+             subgraph_name = si.name
+             si.CopyFrom(ctx.known_vi_[node.input[i]])
+             if i >= num_scan_states:
+                 scan_input_dim = si.type.tensor_type.shape.dim
+                 scan_input_dim.remove(scan_input_dim[scan_input_axes[i - num_scan_states]])
+             si.name = subgraph_name
+         ctx.onnx_infer_subgraph(node, subgraph)
+         num_scan_outputs = len(node.output) - num_scan_states
+         scan_output_axes = get_attribute(node, "scan_output_axes", [0] * num_scan_outputs)
+         scan_input_dim = get_shape_from_type_proto(ctx.known_vi_[node.input[-1]].type)[scan_input_axes[-1]]
+         for i, o in enumerate(node.output):
+             vi = ctx.known_vi_[o]
+             if i >= num_scan_states:
+                 shape = get_shape_from_type_proto(subgraph.output[i].type)
+                 new_dim = handle_negative_axis(scan_output_axes[i - num_scan_states], len(shape) + 1)
+                 shape = [*shape[:new_dim], scan_input_dim, *shape[new_dim:]]
+                 vi.CopyFrom(helper.make_tensor_value_info(o, subgraph.output[i].type.tensor_type.elem_type, shape))
+             else:
+                 vi.CopyFrom(subgraph.output[i])
+             vi.name = o
+
+
+ register_shape_handler(ScanHandler())
onnxslim/core/shape_inference/standard_ops/math/__init__.py
@@ -0,0 +1,20 @@
+ # Copyright (c) Microsoft Corporation. All rights reserved.
+ # Licensed under the MIT License.
+
+ """Mathematical operator shape handlers."""
+
+ from . import matmul
+ from . import matmul_integer
+ from . import einsum
+ from . import reduce_sum
+ from . import reduce_prod
+ from . import add
+ from . import sub
+ from . import mul
+ from . import div
+ from . import neg
+ from . import floor
+ from . import min
+ from . import max
+ from . import equal
+ from . import where
onnxslim/core/shape_inference/standard_ops/math/_symbolic_compute.py
@@ -0,0 +1,34 @@
+ # Copyright (c) Microsoft Corporation. All rights reserved.
+ # Licensed under the MIT License.
+
+ """Shared symbolic computation helper for math operators."""
+
+ import sympy
+
+ from ...utils import is_literal
+
+
+ def infer_symbolic_compute_ops(node, ctx):
+     """Handles symbolic computation operations for given node based on predefined functions."""
+     funcs = {
+         "Add": lambda l: l[0] + l[1],
+         "Div": lambda l: (int(l[0] // l[1]) if isinstance(l[0] // l[1], float) else l[0] // l[1]),
+         "Equal": lambda l: l[0] == l[1],
+         "Floor": lambda l: sympy.floor(l[0]),
+         "Max": lambda l: (
+             l[1]
+             if is_literal(l[0]) and int(l[0]) < -ctx.int_max_
+             else (l[0] if is_literal(l[1]) and int(l[1]) < -ctx.int_max_ else sympy.Max(l[0], l[1]))
+         ),
+         "Min": lambda l: (
+             l[1]
+             if is_literal(l[0]) and int(l[0]) > ctx.int_max_
+             else (l[0] if is_literal(l[1]) and int(l[1]) > ctx.int_max_ else sympy.Min(l[0], l[1]))
+         ),
+         "Mul": lambda l: (int(l[0] * l[1]) if isinstance(l[0] * l[1], float) else l[0] * l[1]),
+         "Sub": lambda l: l[0] - l[1],
+         "Where": lambda l: l[1] if l[0] else l[2],
+         "Neg": lambda l: -l[0],
+     }
+     assert node.op_type in funcs
+     ctx.compute_on_sympy_data(node, funcs[node.op_type])
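
The lambdas above receive sympy expressions rather than concrete ints, so arithmetic on unknown dims stays symbolic, while literal sentinels beyond ctx.int_max_ get clamped away by the Max/Min cases. A small standalone demonstration of the underlying sympy behavior:

import sympy

# A symbolic batch dim, declared a positive integer as shape dims are.
batch = sympy.Symbol("batch", integer=True, positive=True)

print(batch + 1)               # Add   -> batch + 1
print((batch * 2) // 2)        # Div   -> simplifies back to batch
print(sympy.Max(batch, 4))     # Max   -> stays an expression until batch is known
print(sympy.floor(batch / 2))  # Floor -> floor(batch/2)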
onnxslim/core/shape_inference/standard_ops/math/add.py
@@ -0,0 +1,10 @@
+ # Copyright (c) Microsoft Corporation. All rights reserved.
+ # Licensed under the MIT License.
+
+ """Shape handler for Add operator."""
+
+ from ...base import MultiOpHandler
+ from ...registry import register_shape_handler
+ from ._symbolic_compute import infer_symbolic_compute_ops
+
+ register_shape_handler(MultiOpHandler("Add", infer_symbolic_compute_ops))
onnxslim/core/shape_inference/standard_ops/math/div.py
@@ -0,0 +1,10 @@
+ # Copyright (c) Microsoft Corporation. All rights reserved.
+ # Licensed under the MIT License.
+
+ """Shape handler for Div operator."""
+
+ from ...base import MultiOpHandler
+ from ...registry import register_shape_handler
+ from ._symbolic_compute import infer_symbolic_compute_ops
+
+ register_shape_handler(MultiOpHandler("Div", infer_symbolic_compute_ops))