onnxslim 0.1.82__py3-none-any.whl → 0.1.83__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (137)
  1. onnxslim/core/optimization/dead_node_elimination.py +85 -4
  2. onnxslim/core/pattern/fusion/convadd.py +21 -1
  3. onnxslim/core/pattern/fusion/convbn.py +21 -4
  4. onnxslim/core/pattern/fusion/convmul.py +23 -5
  5. onnxslim/core/pattern/fusion/padconv.py +5 -0
  6. onnxslim/core/shape_inference/__init__.py +378 -0
  7. onnxslim/core/shape_inference/aten_ops/__init__.py +16 -0
  8. onnxslim/core/shape_inference/aten_ops/argmax.py +47 -0
  9. onnxslim/core/shape_inference/aten_ops/bitwise_or.py +28 -0
  10. onnxslim/core/shape_inference/aten_ops/diagonal.py +52 -0
  11. onnxslim/core/shape_inference/aten_ops/embedding.py +23 -0
  12. onnxslim/core/shape_inference/aten_ops/group_norm.py +41 -0
  13. onnxslim/core/shape_inference/aten_ops/min_max.py +64 -0
  14. onnxslim/core/shape_inference/aten_ops/multinomial.py +39 -0
  15. onnxslim/core/shape_inference/aten_ops/numpy_t.py +22 -0
  16. onnxslim/core/shape_inference/aten_ops/pool2d.py +40 -0
  17. onnxslim/core/shape_inference/aten_ops/unfold.py +44 -0
  18. onnxslim/core/shape_inference/aten_ops/upsample.py +44 -0
  19. onnxslim/core/shape_inference/base.py +111 -0
  20. onnxslim/core/shape_inference/context.py +645 -0
  21. onnxslim/core/shape_inference/contrib_ops/__init__.py +8 -0
  22. onnxslim/core/shape_inference/contrib_ops/attention/__init__.py +15 -0
  23. onnxslim/core/shape_inference/contrib_ops/attention/attention.py +61 -0
  24. onnxslim/core/shape_inference/contrib_ops/attention/decoder_masked_mha.py +37 -0
  25. onnxslim/core/shape_inference/contrib_ops/attention/gated_relative_position_bias.py +35 -0
  26. onnxslim/core/shape_inference/contrib_ops/attention/longformer_attention.py +21 -0
  27. onnxslim/core/shape_inference/contrib_ops/attention/multi_head_attention.py +82 -0
  28. onnxslim/core/shape_inference/contrib_ops/attention/multi_scale_deformable_attn.py +29 -0
  29. onnxslim/core/shape_inference/contrib_ops/attention/packed_attention.py +39 -0
  30. onnxslim/core/shape_inference/contrib_ops/attention/packed_multi_head_attention.py +33 -0
  31. onnxslim/core/shape_inference/contrib_ops/attention/remove_padding.py +41 -0
  32. onnxslim/core/shape_inference/contrib_ops/attention/restore_padding.py +29 -0
  33. onnxslim/core/shape_inference/contrib_ops/misc/__init__.py +15 -0
  34. onnxslim/core/shape_inference/contrib_ops/misc/bias_add.py +21 -0
  35. onnxslim/core/shape_inference/contrib_ops/misc/bias_gelu.py +21 -0
  36. onnxslim/core/shape_inference/contrib_ops/misc/bias_split_gelu.py +30 -0
  37. onnxslim/core/shape_inference/contrib_ops/misc/fast_gelu.py +21 -0
  38. onnxslim/core/shape_inference/contrib_ops/misc/gelu.py +21 -0
  39. onnxslim/core/shape_inference/contrib_ops/misc/gemm_fast_gelu.py +21 -0
  40. onnxslim/core/shape_inference/contrib_ops/misc/gemm_float8.py +21 -0
  41. onnxslim/core/shape_inference/contrib_ops/misc/python_op.py +67 -0
  42. onnxslim/core/shape_inference/contrib_ops/misc/quick_gelu.py +21 -0
  43. onnxslim/core/shape_inference/contrib_ops/misc/rotary_embedding.py +31 -0
  44. onnxslim/core/shape_inference/contrib_ops/normalization/__init__.py +12 -0
  45. onnxslim/core/shape_inference/contrib_ops/normalization/embed_layer_normalization.py +41 -0
  46. onnxslim/core/shape_inference/contrib_ops/normalization/group_norm.py +21 -0
  47. onnxslim/core/shape_inference/contrib_ops/normalization/layer_normalization.py +42 -0
  48. onnxslim/core/shape_inference/contrib_ops/normalization/simplified_layer_normalization.py +23 -0
  49. onnxslim/core/shape_inference/contrib_ops/normalization/skip_group_norm.py +23 -0
  50. onnxslim/core/shape_inference/contrib_ops/normalization/skip_layer_normalization.py +26 -0
  51. onnxslim/core/shape_inference/contrib_ops/normalization/skip_simplified_layer_normalization.py +23 -0
  52. onnxslim/core/shape_inference/registry.py +90 -0
  53. onnxslim/core/shape_inference/standard_ops/__init__.py +11 -0
  54. onnxslim/core/shape_inference/standard_ops/control_flow/__init__.py +8 -0
  55. onnxslim/core/shape_inference/standard_ops/control_flow/if_op.py +43 -0
  56. onnxslim/core/shape_inference/standard_ops/control_flow/loop.py +74 -0
  57. onnxslim/core/shape_inference/standard_ops/control_flow/scan.py +54 -0
  58. onnxslim/core/shape_inference/standard_ops/math/__init__.py +20 -0
  59. onnxslim/core/shape_inference/standard_ops/math/_symbolic_compute.py +34 -0
  60. onnxslim/core/shape_inference/standard_ops/math/add.py +10 -0
  61. onnxslim/core/shape_inference/standard_ops/math/div.py +10 -0
  62. onnxslim/core/shape_inference/standard_ops/math/einsum.py +119 -0
  63. onnxslim/core/shape_inference/standard_ops/math/equal.py +10 -0
  64. onnxslim/core/shape_inference/standard_ops/math/floor.py +10 -0
  65. onnxslim/core/shape_inference/standard_ops/math/matmul.py +21 -0
  66. onnxslim/core/shape_inference/standard_ops/math/matmul_integer.py +23 -0
  67. onnxslim/core/shape_inference/standard_ops/math/max.py +10 -0
  68. onnxslim/core/shape_inference/standard_ops/math/min.py +10 -0
  69. onnxslim/core/shape_inference/standard_ops/math/mul.py +10 -0
  70. onnxslim/core/shape_inference/standard_ops/math/neg.py +10 -0
  71. onnxslim/core/shape_inference/standard_ops/math/reduce_prod.py +27 -0
  72. onnxslim/core/shape_inference/standard_ops/math/reduce_sum.py +53 -0
  73. onnxslim/core/shape_inference/standard_ops/math/sub.py +10 -0
  74. onnxslim/core/shape_inference/standard_ops/math/where.py +10 -0
  75. onnxslim/core/shape_inference/standard_ops/misc/__init__.py +22 -0
  76. onnxslim/core/shape_inference/standard_ops/misc/array_feature_extractor.py +32 -0
  77. onnxslim/core/shape_inference/standard_ops/misc/cast.py +21 -0
  78. onnxslim/core/shape_inference/standard_ops/misc/category_mapper.py +30 -0
  79. onnxslim/core/shape_inference/standard_ops/misc/compress.py +39 -0
  80. onnxslim/core/shape_inference/standard_ops/misc/constant.py +27 -0
  81. onnxslim/core/shape_inference/standard_ops/misc/constant_of_shape.py +45 -0
  82. onnxslim/core/shape_inference/standard_ops/misc/dequantize_linear.py +26 -0
  83. onnxslim/core/shape_inference/standard_ops/misc/non_max_suppression.py +26 -0
  84. onnxslim/core/shape_inference/standard_ops/misc/non_zero.py +26 -0
  85. onnxslim/core/shape_inference/standard_ops/misc/one_hot.py +42 -0
  86. onnxslim/core/shape_inference/standard_ops/misc/quantize_linear.py +29 -0
  87. onnxslim/core/shape_inference/standard_ops/misc/range.py +41 -0
  88. onnxslim/core/shape_inference/standard_ops/misc/relative_position_bias.py +31 -0
  89. onnxslim/core/shape_inference/standard_ops/misc/resize.py +74 -0
  90. onnxslim/core/shape_inference/standard_ops/misc/scatter_elements.py +31 -0
  91. onnxslim/core/shape_inference/standard_ops/misc/softmax_cross_entropy_loss.py +44 -0
  92. onnxslim/core/shape_inference/standard_ops/misc/top_k.py +44 -0
  93. onnxslim/core/shape_inference/standard_ops/nn/__init__.py +18 -0
  94. onnxslim/core/shape_inference/standard_ops/nn/all_reduce.py +9 -0
  95. onnxslim/core/shape_inference/standard_ops/nn/average_pool.py +40 -0
  96. onnxslim/core/shape_inference/standard_ops/nn/batch_normalization.py +26 -0
  97. onnxslim/core/shape_inference/standard_ops/nn/conv.py +33 -0
  98. onnxslim/core/shape_inference/standard_ops/nn/cum_sum.py +9 -0
  99. onnxslim/core/shape_inference/standard_ops/nn/identity.py +9 -0
  100. onnxslim/core/shape_inference/standard_ops/nn/max_pool.py +9 -0
  101. onnxslim/core/shape_inference/standard_ops/nn/memcpy_from_host.py +9 -0
  102. onnxslim/core/shape_inference/standard_ops/nn/memcpy_to_host.py +9 -0
  103. onnxslim/core/shape_inference/standard_ops/nn/moe.py +9 -0
  104. onnxslim/core/shape_inference/standard_ops/nn/nhwc_conv.py +33 -0
  105. onnxslim/core/shape_inference/standard_ops/nn/reciprocal.py +9 -0
  106. onnxslim/core/shape_inference/standard_ops/nn/round.py +9 -0
  107. onnxslim/core/shape_inference/standard_ops/sequence/__init__.py +10 -0
  108. onnxslim/core/shape_inference/standard_ops/sequence/concat_from_sequence.py +40 -0
  109. onnxslim/core/shape_inference/standard_ops/sequence/sequence_at.py +31 -0
  110. onnxslim/core/shape_inference/standard_ops/sequence/sequence_insert.py +26 -0
  111. onnxslim/core/shape_inference/standard_ops/sequence/split_to_sequence.py +24 -0
  112. onnxslim/core/shape_inference/standard_ops/sequence/zip_map.py +36 -0
  113. onnxslim/core/shape_inference/standard_ops/tensor/__init__.py +20 -0
  114. onnxslim/core/shape_inference/standard_ops/tensor/concat.py +62 -0
  115. onnxslim/core/shape_inference/standard_ops/tensor/expand.py +36 -0
  116. onnxslim/core/shape_inference/standard_ops/tensor/gather.py +48 -0
  117. onnxslim/core/shape_inference/standard_ops/tensor/gather_elements.py +31 -0
  118. onnxslim/core/shape_inference/standard_ops/tensor/gather_nd.py +42 -0
  119. onnxslim/core/shape_inference/standard_ops/tensor/pad.py +41 -0
  120. onnxslim/core/shape_inference/standard_ops/tensor/reshape.py +72 -0
  121. onnxslim/core/shape_inference/standard_ops/tensor/shape.py +38 -0
  122. onnxslim/core/shape_inference/standard_ops/tensor/size.py +29 -0
  123. onnxslim/core/shape_inference/standard_ops/tensor/slice.py +183 -0
  124. onnxslim/core/shape_inference/standard_ops/tensor/split.py +57 -0
  125. onnxslim/core/shape_inference/standard_ops/tensor/squeeze.py +69 -0
  126. onnxslim/core/shape_inference/standard_ops/tensor/tile.py +41 -0
  127. onnxslim/core/shape_inference/standard_ops/tensor/transpose.py +30 -0
  128. onnxslim/core/shape_inference/standard_ops/tensor/unsqueeze.py +54 -0
  129. onnxslim/core/shape_inference/utils.py +244 -0
  130. onnxslim/third_party/symbolic_shape_infer.py +73 -3156
  131. onnxslim/utils.py +4 -2
  132. {onnxslim-0.1.82.dist-info → onnxslim-0.1.83.dist-info}/METADATA +21 -11
  133. onnxslim-0.1.83.dist-info/RECORD +187 -0
  134. onnxslim-0.1.82.dist-info/RECORD +0 -63
  135. {onnxslim-0.1.82.dist-info → onnxslim-0.1.83.dist-info}/WHEEL +0 -0
  136. {onnxslim-0.1.82.dist-info → onnxslim-0.1.83.dist-info}/entry_points.txt +0 -0
  137. {onnxslim-0.1.82.dist-info → onnxslim-0.1.83.dist-info}/licenses/LICENSE +0 -0
@@ -0,0 +1,119 @@
1
+ # Copyright (c) Microsoft Corporation. All rights reserved.
2
+ # Licensed under the MIT License.
3
+
4
+ """Shape handler for Einsum operator."""
5
+
6
+ from onnx import helper
7
+
8
+ from ...base import ShapeHandler
9
+ from ...registry import register_shape_handler
10
+ from ...utils import get_attribute
11
+
12
+
13
_ELLIPSIS = b"..."
_DOT = ord(".")


def _broadcast_dim(current, other):
    """Combine two ellipsis-covered dims under numpy-style broadcasting.

    Dims may be ints, symbolic names (str) or None (unknown). Unlike a bare
    ``max()``, this never compares a str with an int.
    """
    if current == other or other == 1:
        return current
    if current == 1:
        return other
    if isinstance(current, int) and isinstance(other, int):
        # Incompatible concrete sizes: keep the historical "take the larger" choice.
        return max(current, other)
    # Mismatched symbolic/unknown dims: prefer whichever one is known.
    return current if current is not None else other


def _einsum_output_shape(equation, get_shape):
    """Compute the output shape of an Einsum ``equation``.

    Args:
        equation: Einsum equation as ``bytes`` with all spaces removed.
        get_shape: Callable mapping an operand index to that operand's shape
            (a sequence of int / str / None dims).

    Returns:
        list: The output dims, in output-subscript order.

    Raises:
        ValueError: If an operand's rank does not match its subscripts, or the
            ellipsis covers a different number of dims in different operands.
    """
    mid_index = equation.find(b"->")
    left_equation = equation[:mid_index] if mid_index != -1 else equation

    label_maps = {}  # label byte -> index into dims_value
    dims_value = []  # dim bound to each label at its first occurrence
    repeated_labels = set()  # labels seen more than once (dropped in implicit mode)
    ellipsis_dims_value = None  # dims covered by "...", broadcast across operands
    num_ellipsis_indices = 0

    for operand_idx, term in enumerate(left_equation.split(b",")):
        shape = get_shape(operand_idx)
        rank = len(shape)
        ellipsis_index = term.find(_ELLIPSIS)
        term_size = len(term) - term.count(b".")  # number of subscript letters

        # Validate before indexing so mismatches raise ValueError, not IndexError.
        ellipsis_dims = 0
        if ellipsis_index != -1:
            ellipsis_dims = rank - term_size
            if ellipsis_dims < 0 or (ellipsis_dims_value is not None and ellipsis_dims != num_ellipsis_indices):
                raise ValueError("Ellipsis represents incompatible dimensions.")
        elif rank != term_size:
            raise ValueError(f"Rank of input {operand_idx} does not match the equation indices.")

        pos = 0  # current position in ``shape``
        index = 0  # current position in ``term``
        while index < len(term):
            if index == ellipsis_index:
                covered = list(shape[pos : pos + ellipsis_dims])
                if ellipsis_dims_value is None:
                    ellipsis_dims_value = covered
                    num_ellipsis_indices = ellipsis_dims
                else:
                    ellipsis_dims_value = [_broadcast_dim(a, b) for a, b in zip(ellipsis_dims_value, covered)]
                pos += ellipsis_dims
                index += len(_ELLIPSIS)
                continue

            char = term[index]
            index += 1
            if char == _DOT:
                # Stray '.' outside a full "...": ignored and consumes no dim.
                continue
            if char in label_maps:
                repeated_labels.add(char)
            else:
                label_maps[char] = len(dims_value)
                dims_value.append(shape[pos])
            pos += 1

    ellipsis_out = ellipsis_dims_value or []
    if mid_index == -1:
        # Implicit mode: "..." dims first, then labels that occur exactly once,
        # in increasing (ASCII) order, matching numpy.einsum / ONNX Einsum.
        return ellipsis_out + [dims_value[label_maps[c]] for c in sorted(label_maps) if c not in repeated_labels]

    # Explicit mode: follow the right-hand side, placing "..." where it appears.
    right_equation = equation[mid_index + 2 :]
    right_ellipsis_index = right_equation.find(_ELLIPSIS)
    output_shape = []
    index = 0
    while index < len(right_equation):
        if index == right_ellipsis_index:
            output_shape.extend(ellipsis_out)
            index += len(_ELLIPSIS)
            continue
        char = right_equation[index]
        if char != _DOT:
            output_shape.append(dims_value[label_maps[char]])
        index += 1
    return output_shape


class EinsumHandler(ShapeHandler):
    """Handler for Einsum operator."""

    @property
    def op_type(self) -> str:
        return "Einsum"

    def infer_shape(self, node, ctx) -> None:
        """Set the output value info from the equation and the input shapes.

        The output element type is copied from input 0.
        """
        equation = get_attribute(node, "equation").replace(b" ", b"")
        new_sympy_shape = _einsum_output_shape(equation, lambda idx: ctx.get_shape(node, idx))
        output_dtype = ctx.known_vi_[node.input[0]].type.tensor_type.elem_type
        vi = ctx.known_vi_[node.output[0]]
        vi.CopyFrom(helper.make_tensor_value_info(node.output[0], output_dtype, new_sympy_shape))


register_shape_handler(EinsumHandler())
@@ -0,0 +1,10 @@
1
+ # Copyright (c) Microsoft Corporation. All rights reserved.
2
+ # Licensed under the MIT License.
3
+
4
+ """Shape handler for Equal operator."""
5
+
6
+ from ...base import MultiOpHandler
7
+ from ...registry import register_shape_handler
8
+ from ._symbolic_compute import infer_symbolic_compute_ops
9
+
10
+ register_shape_handler(MultiOpHandler("Equal", infer_symbolic_compute_ops))
@@ -0,0 +1,10 @@
1
+ # Copyright (c) Microsoft Corporation. All rights reserved.
2
+ # Licensed under the MIT License.
3
+
4
+ """Shape handler for Floor operator."""
5
+
6
+ from ...base import MultiOpHandler
7
+ from ...registry import register_shape_handler
8
+ from ._symbolic_compute import infer_symbolic_compute_ops
9
+
10
+ register_shape_handler(MultiOpHandler("Floor", infer_symbolic_compute_ops))
@@ -0,0 +1,21 @@
1
+ # Copyright (c) Microsoft Corporation. All rights reserved.
2
+ # Licensed under the MIT License.
3
+
4
+ """Shape handler for MatMul operator."""
5
+
6
+ from ...base import ShapeHandler
7
+ from ...registry import register_shape_handler
8
+
9
+
10
class MatMulHandler(ShapeHandler):
    """Shape handler for ``MatMul``.

    Output inference is delegated entirely to the context's shared matmul
    helper.
    """

    @property
    def op_type(self) -> str:
        """Name of the ONNX operator handled here."""
        return "MatMul"

    def infer_shape(self, node, ctx) -> None:
        """Infer ``node``'s output via ``ctx.compute_matmul_shape``."""
        compute = ctx.compute_matmul_shape
        compute(node)


register_shape_handler(MatMulHandler())
@@ -0,0 +1,23 @@
1
+ # Copyright (c) Microsoft Corporation. All rights reserved.
2
+ # Licensed under the MIT License.
3
+
4
+ """Shape handler for MatMulInteger16 operator."""
5
+
6
+ import onnx
7
+
8
+ from ...base import ShapeHandler
9
+ from ...registry import register_shape_handler
10
+
11
+
12
class MatMulIntegerHandler(ShapeHandler):
    """Shape handler for ``MatMulInteger16``, whose output type is pinned to INT32."""

    @property
    def op_type(self) -> str:
        """Name of the ONNX operator handled here."""
        return "MatMulInteger16"

    def infer_shape(self, node, ctx) -> None:
        """Infer the matmul output shape with an INT32 element type."""
        # Shape follows the shared matmul rules; only the element type is fixed.
        output_elem_type = onnx.TensorProto.INT32
        ctx.compute_matmul_shape(node, output_elem_type)


register_shape_handler(MatMulIntegerHandler())
@@ -0,0 +1,10 @@
1
+ # Copyright (c) Microsoft Corporation. All rights reserved.
2
+ # Licensed under the MIT License.
3
+
4
+ """Shape handler for Max operator."""
5
+
6
+ from ...base import MultiOpHandler
7
+ from ...registry import register_shape_handler
8
+ from ._symbolic_compute import infer_symbolic_compute_ops
9
+
10
+ register_shape_handler(MultiOpHandler("Max", infer_symbolic_compute_ops))
@@ -0,0 +1,10 @@
1
+ # Copyright (c) Microsoft Corporation. All rights reserved.
2
+ # Licensed under the MIT License.
3
+
4
+ """Shape handler for Min operator."""
5
+
6
+ from ...base import MultiOpHandler
7
+ from ...registry import register_shape_handler
8
+ from ._symbolic_compute import infer_symbolic_compute_ops
9
+
10
+ register_shape_handler(MultiOpHandler("Min", infer_symbolic_compute_ops))
@@ -0,0 +1,10 @@
1
+ # Copyright (c) Microsoft Corporation. All rights reserved.
2
+ # Licensed under the MIT License.
3
+
4
+ """Shape handler for Mul operator."""
5
+
6
+ from ...base import MultiOpHandler
7
+ from ...registry import register_shape_handler
8
+ from ._symbolic_compute import infer_symbolic_compute_ops
9
+
10
+ register_shape_handler(MultiOpHandler("Mul", infer_symbolic_compute_ops))
@@ -0,0 +1,10 @@
1
+ # Copyright (c) Microsoft Corporation. All rights reserved.
2
+ # Licensed under the MIT License.
3
+
4
+ """Shape handler for Neg operator."""
5
+
6
+ from ...base import MultiOpHandler
7
+ from ...registry import register_shape_handler
8
+ from ._symbolic_compute import infer_symbolic_compute_ops
9
+
10
+ register_shape_handler(MultiOpHandler("Neg", infer_symbolic_compute_ops))
@@ -0,0 +1,27 @@
1
+ # Copyright (c) Microsoft Corporation. All rights reserved.
2
+ # Licensed under the MIT License.
3
+
4
+ """Shape handler for ReduceProd operator."""
5
+
6
+ from ...base import ShapeHandler
7
+ from ...registry import register_shape_handler
8
+ from ...utils import get_attribute, sympy_reduce_product
9
+
10
+
11
class ReduceProdHandler(ShapeHandler):
    """Shape handler for ``ReduceProd``.

    Only propagates symbolic data. When ``keepdims=0`` and ``axes=[0]`` and the
    input values are known, the product of those values is stored as the
    output's sympy data. The node's value info is not touched here.
    """

    @property
    def op_type(self) -> str:
        """Name of the ONNX operator handled here."""
        return "ReduceProd"

    def infer_shape(self, node, ctx) -> None:
        """Fold a known ``ReduceProd(axes=[0], keepdims=0)`` into sympy data."""
        reduce_axes = get_attribute(node, "axes")
        keepdims = get_attribute(node, "keepdims", 1)
        if keepdims != 0 or reduce_axes != [0]:
            return

        values = ctx.get_int_or_float_values(node)[0]
        if values is None:
            return
        ctx.sympy_data_[node.output[0]] = sympy_reduce_product(values)


register_shape_handler(ReduceProdHandler())
@@ -0,0 +1,53 @@
1
+ # Copyright (c) Microsoft Corporation. All rights reserved.
2
+ # Licensed under the MIT License.
3
+
4
+ """Shape handler for ReduceSum operator."""
5
+
6
+ from onnx import helper
7
+
8
+ from ...base import ShapeHandler
9
+ from ...registry import register_shape_handler
10
+ from ...utils import get_attribute, get_opset, get_shape_from_sympy_shape, handle_negative_axis
11
+
12
+
13
class ReduceSumHandler(ShapeHandler):
    """Handler for ReduceSum operator.

    Only acts from opset 13 on, when ``axes`` is passed as input 1. For earlier
    opsets (axes as an attribute) this handler leaves the node untouched.
    """

    @property
    def op_type(self) -> str:
        return "ReduceSum"

    def infer_shape(self, node, ctx) -> None:
        """Infer the output value info of a ReduceSum node.

        * An omitted or empty ``axes`` reduces every axis, unless
          ``noop_with_empty_axes`` is set, in which case the shape is unchanged.
        * Constant ``axes`` reduce exactly those axes; negative values are allowed.
        * Non-constant ``axes`` are only supported with ``keepdims=1``. The rank
          is kept and the dims come from ``ctx.new_symbolic_shape``.
        """
        keep_dims = get_attribute(node, "keepdims", 1)
        if get_opset(ctx.out_mp_) < 13 or len(node.input) <= 1:
            return

        # An empty input name means the optional axes input was omitted,
        # which ONNX treats the same as an empty axes list.
        axes = ctx.try_get_value(node, 1) if node.input[1] else []
        vi = ctx.known_vi_[node.output[0]]
        elem_type = ctx.known_vi_[node.input[0]].type.tensor_type.elem_type

        if axes is None:
            # Unknown axes: only keepdims=1 keeps the output rank predictable.
            assert keep_dims
            vi.CopyFrom(
                helper.make_tensor_value_info(
                    node.output[0],
                    elem_type,
                    get_shape_from_sympy_shape(ctx.new_symbolic_shape(ctx.get_shape_rank(node, 0), node)),
                )
            )
            return

        shape = ctx.get_shape(node, 0)
        rank = len(shape)
        if len(axes) > 0:
            reduced_axes = {handle_negative_axis(a, rank) for a in axes}
        elif get_attribute(node, "noop_with_empty_axes", 0):
            reduced_axes = set()
        else:
            # Empty axes with noop_with_empty_axes=0 reduce over all dimensions.
            reduced_axes = set(range(rank))

        output_shape = []
        for i, d in enumerate(shape):
            if i not in reduced_axes:
                output_shape.append(d)
            elif keep_dims:
                output_shape.append(1)

        vi.CopyFrom(helper.make_tensor_value_info(node.output[0], elem_type, output_shape))


register_shape_handler(ReduceSumHandler())
@@ -0,0 +1,10 @@
1
+ # Copyright (c) Microsoft Corporation. All rights reserved.
2
+ # Licensed under the MIT License.
3
+
4
+ """Shape handler for Sub operator."""
5
+
6
+ from ...base import MultiOpHandler
7
+ from ...registry import register_shape_handler
8
+ from ._symbolic_compute import infer_symbolic_compute_ops
9
+
10
+ register_shape_handler(MultiOpHandler("Sub", infer_symbolic_compute_ops))
@@ -0,0 +1,10 @@
1
+ # Copyright (c) Microsoft Corporation. All rights reserved.
2
+ # Licensed under the MIT License.
3
+
4
+ """Shape handler for Where operator."""
5
+
6
+ from ...base import MultiOpHandler
7
+ from ...registry import register_shape_handler
8
+ from ._symbolic_compute import infer_symbolic_compute_ops
9
+
10
+ register_shape_handler(MultiOpHandler("Where", infer_symbolic_compute_ops))
@@ -0,0 +1,22 @@
1
+ # Copyright (c) Microsoft Corporation. All rights reserved.
2
+ # Licensed under the MIT License.
3
+
4
+ """Miscellaneous operator shape handlers."""
5
+
6
+ from . import constant
7
+ from . import constant_of_shape
8
+ from . import cast
9
+ from . import category_mapper
10
+ from . import compress
11
+ from . import one_hot
12
+ from . import non_max_suppression
13
+ from . import non_zero
14
+ from . import top_k
15
+ from . import range
16
+ from . import resize
17
+ from . import scatter_elements
18
+ from . import array_feature_extractor
19
+ from . import softmax_cross_entropy_loss
20
+ from . import dequantize_linear
21
+ from . import quantize_linear
22
+ from . import relative_position_bias
@@ -0,0 +1,32 @@
1
+ # Copyright (c) Microsoft Corporation. All rights reserved.
2
+ # Licensed under the MIT License.
3
+
4
+ """Shape handler for ArrayFeatureExtractor operator."""
5
+
6
+ from onnx import helper
7
+
8
+ from ...base import ShapeHandler
9
+ from ...registry import register_shape_handler
10
+
11
+
12
class ArrayFeatureExtractorHandler(ShapeHandler):
    """Shape handler for ``ArrayFeatureExtractor``.

    The output shape is the data shape without its last axis, followed by the
    full indices shape. The element type is taken from the data input.
    """

    @property
    def op_type(self) -> str:
        """Name of the ONNX operator handled here."""
        return "ArrayFeatureExtractor"

    def infer_shape(self, node, ctx) -> None:
        """Write ``data_shape[:-1] + indices_shape`` into the output value info."""
        leading_dims = ctx.get_shape(node, 0)[:-1]
        index_dims = ctx.get_shape(node, 1)
        out_name = node.output[0]
        data_elem_type = ctx.known_vi_[node.input[0]].type.tensor_type.elem_type
        ctx.known_vi_[out_name].CopyFrom(
            helper.make_tensor_value_info(out_name, data_elem_type, leading_dims + index_dims)
        )


register_shape_handler(ArrayFeatureExtractorHandler())
@@ -0,0 +1,21 @@
1
+ # Copyright (c) Microsoft Corporation. All rights reserved.
2
+ # Licensed under the MIT License.
3
+
4
+ """Shape handler for Cast operator."""
5
+
6
+ from ...base import ShapeHandler
7
+ from ...registry import register_shape_handler
8
+
9
+
10
class CastHandler(ShapeHandler):
    """Shape handler for ``Cast``.

    Does not write value info itself. It only delegates to
    ``ctx.pass_on_sympy_data``, which presumably forwards the input's known
    symbolic data to the output.
    """

    @property
    def op_type(self) -> str:
        """Name of the ONNX operator handled here."""
        return "Cast"

    def infer_shape(self, node, ctx) -> None:
        """Forward symbolic data through the Cast node."""
        forward = ctx.pass_on_sympy_data
        forward(node)


register_shape_handler(CastHandler())
@@ -0,0 +1,30 @@
1
+ # Copyright (c) Microsoft Corporation. All rights reserved.
2
+ # Licensed under the MIT License.
3
+
4
+ """Shape handler for CategoryMapper operator."""
5
+
6
+ import onnx
7
+ from onnx import helper
8
+
9
+ from ...base import ShapeHandler
10
+ from ...registry import register_shape_handler
11
+
12
+
13
class CategoryMapperHandler(ShapeHandler):
    """Shape handler for ``CategoryMapper``.

    The output keeps the input shape. A STRING input maps to an INT64 output,
    and any other input type maps to a STRING output.
    """

    @property
    def op_type(self) -> str:
        """Name of the ONNX operator handled here."""
        return "CategoryMapper"

    def infer_shape(self, node, ctx) -> None:
        """Set the output value info, flipping between STRING and INT64."""
        in_type = ctx.known_vi_[node.input[0]].type.tensor_type.elem_type
        out_type = onnx.TensorProto.INT64 if in_type == onnx.TensorProto.STRING else onnx.TensorProto.STRING
        out_name = node.output[0]
        ctx.known_vi_[out_name].CopyFrom(helper.make_tensor_value_info(out_name, out_type, ctx.get_shape(node, 0)))


register_shape_handler(CategoryMapperHandler())
@@ -0,0 +1,39 @@
1
+ # Copyright (c) Microsoft Corporation. All rights reserved.
2
+ # Licensed under the MIT License.
3
+
4
+ """Shape handler for Compress operator."""
5
+
6
+ from onnx import helper
7
+
8
+ from ...base import ShapeHandler
9
+ from ...registry import register_shape_handler
10
+ from ...utils import get_attribute, handle_negative_axis
11
+
12
+
13
class CompressHandler(ShapeHandler):
    """Handler for Compress operator.

    The compressed length is represented by a new symbolic dim. Without
    ``axis`` the output is 1-D of that length. With ``axis`` the input shape
    is kept and that axis is replaced by the symbolic length.
    """

    @property
    def op_type(self) -> str:
        return "Compress"

    def infer_shape(self, node, ctx) -> None:
        """Set the output value info of a Compress node."""
        input_shape = ctx.get_shape(node, 0)
        compress_len = str(ctx.new_symbolic_dim_from_output(node))
        axis = get_attribute(node, "axis")
        if axis is None:
            output_shape = [compress_len]
        else:
            # Copy so the shape returned by ctx.get_shape is never mutated
            # (it may be shared/cached, or an immutable tuple).
            output_shape = list(input_shape)
            output_shape[handle_negative_axis(axis, len(output_shape))] = compress_len
        vi = ctx.known_vi_[node.output[0]]
        vi.CopyFrom(
            helper.make_tensor_value_info(
                node.output[0],
                ctx.known_vi_[node.input[0]].type.tensor_type.elem_type,
                output_shape,
            )
        )


register_shape_handler(CompressHandler())
@@ -0,0 +1,27 @@
1
+ # Copyright (c) Microsoft Corporation. All rights reserved.
2
+ # Licensed under the MIT License.
3
+
4
+ """Shape handler for Constant operator."""
5
+
6
+ from onnx import numpy_helper
7
+
8
+ from ...base import ShapeHandler
9
+ from ...registry import register_shape_handler
10
+ from ...utils import get_attribute
11
+
12
+
13
class ConstantHandler(ShapeHandler):
    """Handler for Constant operator.

    Registers the constant's tensor as an initializer named after the node's
    output, and stores its numpy value as that output's sympy data.
    """

    @property
    def op_type(self) -> str:
        return "Constant"

    def infer_shape(self, node, ctx) -> None:
        """Record the Constant's value for later constant folding.

        Handles the ``value`` tensor attribute as well as the
        ``value_float(s)`` / ``value_int(s)`` forms. Other forms (e.g. sparse
        or string values) are skipped instead of raising.
        """
        t = get_attribute(node, "value")
        if t is None:
            t = self._tensor_from_scalar_attributes(node)
            if t is None:
                return
        t.name = node.output[0]
        ctx.initializers_[node.output[0]] = t
        ctx.sympy_data_[node.output[0]] = numpy_helper.to_array(t)

    @staticmethod
    def _tensor_from_scalar_attributes(node):
        """Build a TensorProto from a ``value_float(s)``/``value_int(s)`` attribute, or return None."""
        import numpy as np  # local: only needed for the non-tensor attribute forms

        for attr_name, dtype in (
            ("value_float", np.float32),
            ("value_floats", np.float32),
            ("value_int", np.int64),
            ("value_ints", np.int64),
        ):
            value = get_attribute(node, attr_name)
            if value is not None:
                return numpy_helper.from_array(np.asarray(value, dtype=dtype))
        return None


register_shape_handler(ConstantHandler())
@@ -0,0 +1,45 @@
1
+ # Copyright (c) Microsoft Corporation. All rights reserved.
2
+ # Licensed under the MIT License.
3
+
4
+ """Shape handler for ConstantOfShape operator."""
5
+
6
+ import numpy as np
7
+ import onnx
8
+ from onnx import helper, numpy_helper
9
+
10
+ from ...base import ShapeHandler
11
+ from ...registry import register_shape_handler
12
+ from ...utils import get_attribute, get_shape_from_sympy_shape, is_literal
13
+
14
+
15
class ConstantOfShapeHandler(ShapeHandler):
    """Handler for ConstantOfShape operator.

    When the requested shape is known, its dims are used as the output shape.
    For an INT64 output with all-literal dims, the filled array is also stored
    as sympy data. Otherwise a new symbolic shape with the input's length is
    used.
    """

    @property
    def op_type(self) -> str:
        return "ConstantOfShape"

    def infer_shape(self, node, ctx) -> None:
        """Set the output value info (and sympy data when fully known)."""
        sympy_shape = ctx.get_int_or_float_values(node)[0]
        vi = ctx.known_vi_[node.output[0]]
        if sympy_shape is not None:
            if not isinstance(sympy_shape, list):
                sympy_shape = [sympy_shape]
            ctx.update_computed_dims(sympy_shape)
            if vi.type.tensor_type.elem_type == onnx.TensorProto.INT64 and all(is_literal(x) for x in sympy_shape):
                value = get_attribute(node, "value")
                # Per the ONNX spec the fill value defaults to 0 when "value" is absent;
                # otherwise it is the single element of the "value" tensor.
                fill = numpy_helper.to_array(value).reshape(-1)[0] if value is not None else 0
                # np.full keeps an empty shape as a 0-d scalar (np.ones([]) * value would be (1,)).
                ctx.sympy_data_[node.output[0]] = np.full([int(x) for x in sympy_shape], fill, dtype=np.int64)
        else:
            sympy_shape = ctx.new_symbolic_shape(ctx.get_shape(node, 0)[0], node)

        vi.CopyFrom(
            helper.make_tensor_value_info(
                node.output[0],
                vi.type.tensor_type.elem_type,
                get_shape_from_sympy_shape(sympy_shape),
            )
        )


register_shape_handler(ConstantOfShapeHandler())
@@ -0,0 +1,26 @@
1
+ # Copyright (c) Microsoft Corporation. All rights reserved.
2
+ # Licensed under the MIT License.
3
+
4
+ """Shape handler for DequantizeLinear operator."""
5
+
6
+ from onnx import helper
7
+
8
+ from ...base import ShapeHandler
9
+ from ...registry import register_shape_handler
10
+
11
+
12
class DequantizeLinearHandler(ShapeHandler):
    """Shape handler for ``DequantizeLinear``.

    The output keeps the shape of input 0 and takes the element type of
    input 1 (the scale).
    """

    @property
    def op_type(self) -> str:
        """Name of the ONNX operator handled here."""
        return "DequantizeLinear"

    def infer_shape(self, node, ctx) -> None:
        """Set the output value info from the data shape and the scale type."""
        scale_elem_type = ctx.known_vi_[node.input[1]].type.tensor_type.elem_type
        data_shape = ctx.get_shape(node, 0)
        y_name = node.output[0]
        ctx.known_vi_[y_name].CopyFrom(helper.make_tensor_value_info(y_name, scale_elem_type, data_shape))


register_shape_handler(DequantizeLinearHandler())
@@ -0,0 +1,26 @@
1
+ # Copyright (c) Microsoft Corporation. All rights reserved.
2
+ # Licensed under the MIT License.
3
+
4
+ """Shape handler for NonMaxSuppression operator."""
5
+
6
+ import onnx
7
+ from onnx import helper
8
+
9
+ from ...base import ShapeHandler
10
+ from ...registry import register_shape_handler
11
+
12
+
13
class NonMaxSuppressionHandler(ShapeHandler):
    """Shape handler for ``NonMaxSuppression``.

    The output is an INT64 tensor of shape ``[num_selected, 3]``, where
    ``num_selected`` is a new symbolic dim.
    """

    @property
    def op_type(self) -> str:
        """Name of the ONNX operator handled here."""
        return "NonMaxSuppression"

    def infer_shape(self, node, ctx) -> None:
        """Set the selected-indices output value info."""
        num_selected = str(ctx.new_symbolic_dim_from_output(node))
        out_name = node.output[0]
        ctx.known_vi_[out_name].CopyFrom(
            helper.make_tensor_value_info(out_name, onnx.TensorProto.INT64, [num_selected, 3])
        )


register_shape_handler(NonMaxSuppressionHandler())
@@ -0,0 +1,26 @@
1
+ # Copyright (c) Microsoft Corporation. All rights reserved.
2
+ # Licensed under the MIT License.
3
+
4
+ """Shape handler for NonZero operator."""
5
+
6
+ from onnx import helper
7
+
8
+ from ...base import ShapeHandler
9
+ from ...registry import register_shape_handler
10
+
11
+
12
class NonZeroHandler(ShapeHandler):
    """Shape handler for ``NonZero``.

    The output shape is ``[rank(input), num_nonzero]``, where ``num_nonzero``
    is a new symbolic dim. The output's existing element type is kept.
    """

    @property
    def op_type(self) -> str:
        """Name of the ONNX operator handled here."""
        return "NonZero"

    def infer_shape(self, node, ctx) -> None:
        """Set the output value info of a NonZero node."""
        rank = ctx.get_shape_rank(node, 0)
        num_nonzero = str(ctx.new_symbolic_dim_from_output(node, 0, 1))
        out_name = node.output[0]
        out_vi = ctx.known_vi_[out_name]
        existing_elem_type = out_vi.type.tensor_type.elem_type
        out_vi.CopyFrom(helper.make_tensor_value_info(out_name, existing_elem_type, [rank, num_nonzero]))


register_shape_handler(NonZeroHandler())