onnxslim 0.1.82__py3-none-any.whl → 0.1.84__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (141) hide show
  1. onnxslim/core/optimization/dead_node_elimination.py +85 -4
  2. onnxslim/core/pattern/elimination/slice.py +15 -8
  3. onnxslim/core/pattern/fusion/concat_reshape.py +3 -1
  4. onnxslim/core/pattern/fusion/convadd.py +23 -7
  5. onnxslim/core/pattern/fusion/convbn.py +24 -11
  6. onnxslim/core/pattern/fusion/convmul.py +26 -9
  7. onnxslim/core/pattern/fusion/gemm.py +7 -5
  8. onnxslim/core/pattern/fusion/padconv.py +5 -0
  9. onnxslim/core/shape_inference/__init__.py +378 -0
  10. onnxslim/core/shape_inference/aten_ops/__init__.py +16 -0
  11. onnxslim/core/shape_inference/aten_ops/argmax.py +47 -0
  12. onnxslim/core/shape_inference/aten_ops/bitwise_or.py +28 -0
  13. onnxslim/core/shape_inference/aten_ops/diagonal.py +52 -0
  14. onnxslim/core/shape_inference/aten_ops/embedding.py +23 -0
  15. onnxslim/core/shape_inference/aten_ops/group_norm.py +41 -0
  16. onnxslim/core/shape_inference/aten_ops/min_max.py +64 -0
  17. onnxslim/core/shape_inference/aten_ops/multinomial.py +39 -0
  18. onnxslim/core/shape_inference/aten_ops/numpy_t.py +22 -0
  19. onnxslim/core/shape_inference/aten_ops/pool2d.py +40 -0
  20. onnxslim/core/shape_inference/aten_ops/unfold.py +44 -0
  21. onnxslim/core/shape_inference/aten_ops/upsample.py +44 -0
  22. onnxslim/core/shape_inference/base.py +111 -0
  23. onnxslim/core/shape_inference/context.py +645 -0
  24. onnxslim/core/shape_inference/contrib_ops/__init__.py +8 -0
  25. onnxslim/core/shape_inference/contrib_ops/attention/__init__.py +15 -0
  26. onnxslim/core/shape_inference/contrib_ops/attention/attention.py +61 -0
  27. onnxslim/core/shape_inference/contrib_ops/attention/decoder_masked_mha.py +37 -0
  28. onnxslim/core/shape_inference/contrib_ops/attention/gated_relative_position_bias.py +35 -0
  29. onnxslim/core/shape_inference/contrib_ops/attention/longformer_attention.py +21 -0
  30. onnxslim/core/shape_inference/contrib_ops/attention/multi_head_attention.py +82 -0
  31. onnxslim/core/shape_inference/contrib_ops/attention/multi_scale_deformable_attn.py +29 -0
  32. onnxslim/core/shape_inference/contrib_ops/attention/packed_attention.py +39 -0
  33. onnxslim/core/shape_inference/contrib_ops/attention/packed_multi_head_attention.py +33 -0
  34. onnxslim/core/shape_inference/contrib_ops/attention/remove_padding.py +41 -0
  35. onnxslim/core/shape_inference/contrib_ops/attention/restore_padding.py +29 -0
  36. onnxslim/core/shape_inference/contrib_ops/misc/__init__.py +15 -0
  37. onnxslim/core/shape_inference/contrib_ops/misc/bias_add.py +21 -0
  38. onnxslim/core/shape_inference/contrib_ops/misc/bias_gelu.py +21 -0
  39. onnxslim/core/shape_inference/contrib_ops/misc/bias_split_gelu.py +30 -0
  40. onnxslim/core/shape_inference/contrib_ops/misc/fast_gelu.py +21 -0
  41. onnxslim/core/shape_inference/contrib_ops/misc/gelu.py +21 -0
  42. onnxslim/core/shape_inference/contrib_ops/misc/gemm_fast_gelu.py +21 -0
  43. onnxslim/core/shape_inference/contrib_ops/misc/gemm_float8.py +21 -0
  44. onnxslim/core/shape_inference/contrib_ops/misc/python_op.py +67 -0
  45. onnxslim/core/shape_inference/contrib_ops/misc/quick_gelu.py +21 -0
  46. onnxslim/core/shape_inference/contrib_ops/misc/rotary_embedding.py +31 -0
  47. onnxslim/core/shape_inference/contrib_ops/normalization/__init__.py +12 -0
  48. onnxslim/core/shape_inference/contrib_ops/normalization/embed_layer_normalization.py +41 -0
  49. onnxslim/core/shape_inference/contrib_ops/normalization/group_norm.py +21 -0
  50. onnxslim/core/shape_inference/contrib_ops/normalization/layer_normalization.py +42 -0
  51. onnxslim/core/shape_inference/contrib_ops/normalization/simplified_layer_normalization.py +23 -0
  52. onnxslim/core/shape_inference/contrib_ops/normalization/skip_group_norm.py +23 -0
  53. onnxslim/core/shape_inference/contrib_ops/normalization/skip_layer_normalization.py +26 -0
  54. onnxslim/core/shape_inference/contrib_ops/normalization/skip_simplified_layer_normalization.py +23 -0
  55. onnxslim/core/shape_inference/registry.py +90 -0
  56. onnxslim/core/shape_inference/standard_ops/__init__.py +11 -0
  57. onnxslim/core/shape_inference/standard_ops/control_flow/__init__.py +8 -0
  58. onnxslim/core/shape_inference/standard_ops/control_flow/if_op.py +43 -0
  59. onnxslim/core/shape_inference/standard_ops/control_flow/loop.py +74 -0
  60. onnxslim/core/shape_inference/standard_ops/control_flow/scan.py +54 -0
  61. onnxslim/core/shape_inference/standard_ops/math/__init__.py +20 -0
  62. onnxslim/core/shape_inference/standard_ops/math/_symbolic_compute.py +34 -0
  63. onnxslim/core/shape_inference/standard_ops/math/add.py +10 -0
  64. onnxslim/core/shape_inference/standard_ops/math/div.py +10 -0
  65. onnxslim/core/shape_inference/standard_ops/math/einsum.py +119 -0
  66. onnxslim/core/shape_inference/standard_ops/math/equal.py +10 -0
  67. onnxslim/core/shape_inference/standard_ops/math/floor.py +10 -0
  68. onnxslim/core/shape_inference/standard_ops/math/matmul.py +21 -0
  69. onnxslim/core/shape_inference/standard_ops/math/matmul_integer.py +23 -0
  70. onnxslim/core/shape_inference/standard_ops/math/max.py +10 -0
  71. onnxslim/core/shape_inference/standard_ops/math/min.py +10 -0
  72. onnxslim/core/shape_inference/standard_ops/math/mul.py +10 -0
  73. onnxslim/core/shape_inference/standard_ops/math/neg.py +10 -0
  74. onnxslim/core/shape_inference/standard_ops/math/reduce_prod.py +27 -0
  75. onnxslim/core/shape_inference/standard_ops/math/reduce_sum.py +53 -0
  76. onnxslim/core/shape_inference/standard_ops/math/sub.py +10 -0
  77. onnxslim/core/shape_inference/standard_ops/math/where.py +10 -0
  78. onnxslim/core/shape_inference/standard_ops/misc/__init__.py +22 -0
  79. onnxslim/core/shape_inference/standard_ops/misc/array_feature_extractor.py +32 -0
  80. onnxslim/core/shape_inference/standard_ops/misc/cast.py +21 -0
  81. onnxslim/core/shape_inference/standard_ops/misc/category_mapper.py +30 -0
  82. onnxslim/core/shape_inference/standard_ops/misc/compress.py +39 -0
  83. onnxslim/core/shape_inference/standard_ops/misc/constant.py +27 -0
  84. onnxslim/core/shape_inference/standard_ops/misc/constant_of_shape.py +45 -0
  85. onnxslim/core/shape_inference/standard_ops/misc/dequantize_linear.py +26 -0
  86. onnxslim/core/shape_inference/standard_ops/misc/non_max_suppression.py +26 -0
  87. onnxslim/core/shape_inference/standard_ops/misc/non_zero.py +26 -0
  88. onnxslim/core/shape_inference/standard_ops/misc/one_hot.py +42 -0
  89. onnxslim/core/shape_inference/standard_ops/misc/quantize_linear.py +29 -0
  90. onnxslim/core/shape_inference/standard_ops/misc/range.py +41 -0
  91. onnxslim/core/shape_inference/standard_ops/misc/relative_position_bias.py +31 -0
  92. onnxslim/core/shape_inference/standard_ops/misc/resize.py +74 -0
  93. onnxslim/core/shape_inference/standard_ops/misc/scatter_elements.py +31 -0
  94. onnxslim/core/shape_inference/standard_ops/misc/softmax_cross_entropy_loss.py +44 -0
  95. onnxslim/core/shape_inference/standard_ops/misc/top_k.py +44 -0
  96. onnxslim/core/shape_inference/standard_ops/nn/__init__.py +18 -0
  97. onnxslim/core/shape_inference/standard_ops/nn/all_reduce.py +9 -0
  98. onnxslim/core/shape_inference/standard_ops/nn/average_pool.py +40 -0
  99. onnxslim/core/shape_inference/standard_ops/nn/batch_normalization.py +26 -0
  100. onnxslim/core/shape_inference/standard_ops/nn/conv.py +33 -0
  101. onnxslim/core/shape_inference/standard_ops/nn/cum_sum.py +9 -0
  102. onnxslim/core/shape_inference/standard_ops/nn/identity.py +9 -0
  103. onnxslim/core/shape_inference/standard_ops/nn/max_pool.py +9 -0
  104. onnxslim/core/shape_inference/standard_ops/nn/memcpy_from_host.py +9 -0
  105. onnxslim/core/shape_inference/standard_ops/nn/memcpy_to_host.py +9 -0
  106. onnxslim/core/shape_inference/standard_ops/nn/moe.py +9 -0
  107. onnxslim/core/shape_inference/standard_ops/nn/nhwc_conv.py +33 -0
  108. onnxslim/core/shape_inference/standard_ops/nn/reciprocal.py +9 -0
  109. onnxslim/core/shape_inference/standard_ops/nn/round.py +9 -0
  110. onnxslim/core/shape_inference/standard_ops/sequence/__init__.py +10 -0
  111. onnxslim/core/shape_inference/standard_ops/sequence/concat_from_sequence.py +40 -0
  112. onnxslim/core/shape_inference/standard_ops/sequence/sequence_at.py +31 -0
  113. onnxslim/core/shape_inference/standard_ops/sequence/sequence_insert.py +26 -0
  114. onnxslim/core/shape_inference/standard_ops/sequence/split_to_sequence.py +24 -0
  115. onnxslim/core/shape_inference/standard_ops/sequence/zip_map.py +36 -0
  116. onnxslim/core/shape_inference/standard_ops/tensor/__init__.py +20 -0
  117. onnxslim/core/shape_inference/standard_ops/tensor/concat.py +62 -0
  118. onnxslim/core/shape_inference/standard_ops/tensor/expand.py +36 -0
  119. onnxslim/core/shape_inference/standard_ops/tensor/gather.py +48 -0
  120. onnxslim/core/shape_inference/standard_ops/tensor/gather_elements.py +31 -0
  121. onnxslim/core/shape_inference/standard_ops/tensor/gather_nd.py +42 -0
  122. onnxslim/core/shape_inference/standard_ops/tensor/pad.py +41 -0
  123. onnxslim/core/shape_inference/standard_ops/tensor/reshape.py +72 -0
  124. onnxslim/core/shape_inference/standard_ops/tensor/shape.py +38 -0
  125. onnxslim/core/shape_inference/standard_ops/tensor/size.py +29 -0
  126. onnxslim/core/shape_inference/standard_ops/tensor/slice.py +183 -0
  127. onnxslim/core/shape_inference/standard_ops/tensor/split.py +57 -0
  128. onnxslim/core/shape_inference/standard_ops/tensor/squeeze.py +69 -0
  129. onnxslim/core/shape_inference/standard_ops/tensor/tile.py +41 -0
  130. onnxslim/core/shape_inference/standard_ops/tensor/transpose.py +30 -0
  131. onnxslim/core/shape_inference/standard_ops/tensor/unsqueeze.py +54 -0
  132. onnxslim/core/shape_inference/utils.py +244 -0
  133. onnxslim/third_party/onnx_graphsurgeon/ir/graph.py +0 -103
  134. onnxslim/third_party/symbolic_shape_infer.py +73 -3156
  135. onnxslim/utils.py +4 -2
  136. {onnxslim-0.1.82.dist-info → onnxslim-0.1.84.dist-info}/METADATA +21 -11
  137. onnxslim-0.1.84.dist-info/RECORD +187 -0
  138. onnxslim-0.1.82.dist-info/RECORD +0 -63
  139. {onnxslim-0.1.82.dist-info → onnxslim-0.1.84.dist-info}/WHEEL +0 -0
  140. {onnxslim-0.1.82.dist-info → onnxslim-0.1.84.dist-info}/entry_points.txt +0 -0
  141. {onnxslim-0.1.82.dist-info → onnxslim-0.1.84.dist-info}/licenses/LICENSE +0 -0
@@ -0,0 +1,183 @@
1
+ # Copyright (c) Microsoft Corporation. All rights reserved.
2
+ # Licensed under the MIT License.
3
+
4
+ """Shape handler for Slice operator."""
5
+
6
+ import logging
7
+
8
+ import numpy as np
9
+ import sympy
10
+ from onnx import helper
11
+
12
+ from onnxslim.third_party._sympy.solve import try_solve
13
+
14
+ from ...base import ShapeHandler
15
+ from ...registry import register_shape_handler
16
+ from ...utils import (
17
+ as_list,
18
+ get_attribute,
19
+ get_opset,
20
+ get_shape_from_sympy_shape,
21
+ is_literal,
22
+ )
23
+
24
+ logger = logging.getLogger(__name__)
25
+
26
+
27
class SliceHandler(ShapeHandler):
    """Shape-inference handler for the ONNX Slice operator."""

    @property
    def op_type(self) -> str:
        return "Slice"

    def infer_shape(self, node, ctx) -> None:
        """Infer the output shape of ``node`` and, when possible, the sliced data.

        Handles both the opset<=9 attribute form (starts/ends/axes) and the
        opset>=10 input form (starts/ends/axes/steps), including negative
        indices, INT_MAX sentinels, and symbolic (sympy) dimensions.
        """

        def flatten_min(expr):
            """Returns a list with expressions split by min() for inequality proof."""
            assert isinstance(expr, sympy.Add), f"Expected a sum of two arguments, got {expr}"
            min_positions = [idx for idx in range(len(expr.args)) if isinstance(expr.args[idx], sympy.Min)]
            if len(min_positions) == 1:
                min_pos = min_positions[0]

                def replace_min_with_arg(arg_idx):
                    replaced = list(expr.args)
                    assert isinstance(replaced[min_pos], sympy.Min)
                    assert len(replaced[min_pos].args) == 2
                    replaced[min_pos] = replaced[min_pos].args[arg_idx]
                    return sympy.Add(*replaced)

                return [replace_min_with_arg(0), replace_min_with_arg(1)]
            return [expr]

        def less_equal(x, y):
            """Returns True if x is less than or equal to y.

            Tries several equivalent symbolic formulations because sympy may
            be able to decide only some of them.
            """
            try:
                return x <= y
            except TypeError:
                pass
            try:
                return y >= x
            except TypeError:
                pass
            try:
                return -x >= -y
            except TypeError:
                pass
            try:
                return -y <= -x
            except TypeError:
                pass
            try:
                return y - x >= 0
            except TypeError:
                # Last resort: prove the inequality for both branches of a Min.
                return all(d >= 0 for d in flatten_min(y - x))

        def handle_negative_index(index, bound):
            """Normalizes a negative index to be in [0, bound)."""
            try:
                if not less_equal(0, index):
                    if is_literal(index) and index <= -ctx.int_max_:
                        # Leave INT_MIN-like sentinels alone; handled later.
                        return index
                    return bound + index
            except TypeError:
                logger.warning(f"Cannot determine if {index} < 0")
            return index

        if get_opset(ctx.out_mp_) <= 9:
            # Old-style Slice: everything is carried in attributes.
            axes = get_attribute(node, "axes")
            starts = get_attribute(node, "starts")
            ends = get_attribute(node, "ends")
            if not axes:
                axes = list(range(len(starts)))
            steps = [1] * len(axes)
        else:
            # New-style Slice: starts/ends/axes/steps are (optional) inputs.
            starts = as_list(ctx.try_get_value(node, 1), keep_none=True)
            ends = as_list(ctx.try_get_value(node, 2), keep_none=True)
            axes = ctx.try_get_value(node, 3)
            steps = ctx.try_get_value(node, 4)
            if axes is None and (starts is not None or ends is not None):
                axes = list(range(len(starts if starts is not None else ends)))
            if steps is None and (starts is not None or ends is not None):
                steps = [1] * len(starts if starts is not None else ends)
            axes = as_list(axes, keep_none=True)
            steps = as_list(steps, keep_none=True)

        new_sympy_shape = ctx.get_sympy_shape(node, 0)
        if starts is None or ends is None:
            # Unknown slice bounds: sliced dimensions become fresh symbols.
            if axes is None:
                for i in range(len(new_sympy_shape)):
                    new_sympy_shape[i] = ctx.new_symbolic_dim_from_output(node, 0, i)
            else:
                new_sympy_shape = get_shape_from_sympy_shape(new_sympy_shape)
                for i in axes:
                    new_sympy_shape[i] = ctx.new_symbolic_dim_from_output(node, 0, i)
        else:
            for i, s, e, t in zip(axes, starts, ends, steps):
                if is_literal(e):
                    e = handle_negative_index(e, new_sympy_shape[i])
                if is_literal(e):
                    if e >= ctx.int_max_:
                        e = new_sympy_shape[i]
                    elif e <= -ctx.int_max_:
                        e = 0 if s > 0 else -1
                    elif is_literal(new_sympy_shape[i]):
                        if e < 0:
                            e = max(0, e + new_sympy_shape[i])
                        e = min(e, new_sympy_shape[i])
                    else:
                        if e > 0:
                            e = sympy.Min(e, new_sympy_shape[i]) if e > 1 else e
                else:
                    if is_literal(new_sympy_shape[i]):
                        if new_sympy_shape[i] < 0:
                            e = sympy.Min(e, new_sympy_shape[i])
                    else:
                        try:
                            if not less_equal(e, new_sympy_shape[i]):
                                e = new_sympy_shape[i]
                        except Exception:
                            if len(e.free_symbols) == 1:
                                if try_solve((e - new_sympy_shape[i]) >= 0, next(iter(e.free_symbols))) is None:
                                    logger.warning(
                                        f"Unable to solve if {e} <= {new_sympy_shape[i]}, treat as not equal"
                                    )
                            else:
                                logger.warning(f"Unable to determine if {e} <= {new_sympy_shape[i]}, treat as equal")
                                e = new_sympy_shape[i]

                s = handle_negative_index(s, new_sympy_shape[i])
                if is_literal(new_sympy_shape[i]) and is_literal(s):
                    s = max(0, min(s, new_sympy_shape[i]))

                # ceil((e - s) / t) expressed with floor division, valid for
                # both positive and negative steps.
                new_sympy_shape[i] = sympy.simplify((e - s + t + (-1 if t > 0 else 1)) // t)

        ctx.update_computed_dims(new_sympy_shape)

        vi = ctx.known_vi_[node.output[0]]
        vi.CopyFrom(
            helper.make_tensor_value_info(
                node.output[0],
                vi.type.tensor_type.elem_type,
                get_shape_from_sympy_shape(new_sympy_shape),
            )
        )

        # Propagate concrete data for 1-D slices used in shape computations.
        if (
            node.input[0] in ctx.sympy_data_
            and [0] == axes
            and starts is not None
            and len(starts) == 1
            and ends is not None
            and len(ends) == 1
            and steps is not None
            and len(steps) == 1
        ):
            input_sympy_data = ctx.sympy_data_[node.input[0]]
            # BUG FIX: the original compared type() against np.array, which is
            # a factory *function* and never a type, so 1-D ndarray data was
            # silently dropped. The correct type is np.ndarray.
            if isinstance(input_sympy_data, list) or (
                isinstance(input_sympy_data, np.ndarray) and len(input_sympy_data.shape) == 1
            ):
                ctx.sympy_data_[node.output[0]] = input_sympy_data[starts[0] : ends[0] : steps[0]]


register_shape_handler(SliceHandler())
@@ -0,0 +1,57 @@
1
+ # Copyright (c) Microsoft Corporation. All rights reserved.
2
+ # Licensed under the MIT License.
3
+
4
+ """Shape handler for Split operator."""
5
+
6
+ import sympy
7
+ from onnx import helper
8
+
9
+ from ...base import ShapeHandler
10
+ from ...registry import register_shape_handler
11
+ from ...utils import get_attribute, get_opset, get_shape_from_sympy_shape, handle_negative_axis
12
+
13
+
14
class SplitHandler(ShapeHandler):
    """Shape-inference handler for the ONNX Split operator.

    Delegates to ``infer_split_common`` so that Split-like operators can
    share one implementation.
    """

    @property
    def op_type(self) -> str:
        return "Split"

    def infer_shape(self, node, ctx) -> None:
        """Infer output shapes for ``node`` via the shared Split routine."""
        infer_split_common(node, ctx, helper.make_tensor_value_info)
23
+
24
+
25
def infer_split_common(node, ctx, make_value_info_func):
    """Infer per-output shapes for a Split-style operator.

    Args:
        node: The ONNX node being inferred.
        ctx: The shape-inference context (holds known value infos, model, ...).
        make_value_info_func: Factory used to build each output ValueInfo.
    """
    in_shape = ctx.get_sympy_shape(node, 0)
    axis = handle_negative_axis(get_attribute(node, "axis", 0), len(in_shape))

    if get_opset(ctx.out_mp_) < 13:
        # Before opset 13 the sizes come from the "split" attribute.
        sizes = get_attribute(node, "split")
        assert ctx.try_get_value(node, 1) is None
    else:
        # From opset 13 on the sizes are an optional second input.
        sizes = ctx.try_get_value(node, 1)
        assert get_attribute(node, "split") is None

    if sizes is None:
        # No explicit sizes: split the axis evenly across all outputs.
        n_out = len(node.output)
        sizes = [in_shape[axis] / sympy.Integer(n_out)] * n_out
        ctx.update_computed_dims(sizes)
    else:
        sizes = [sympy.Integer(s) for s in sizes]

    elem_type = ctx.known_vi_[node.input[0]].type.tensor_type.elem_type
    for idx, piece in enumerate(sizes):
        vi = ctx.known_vi_[node.output[idx]]
        out_shape = [*in_shape[:axis], piece, *in_shape[axis + 1 :]]
        vi.CopyFrom(
            make_value_info_func(
                node.output[idx],
                elem_type,
                get_shape_from_sympy_shape(out_shape),
            )
        )
        ctx.known_vi_[vi.name] = vi


register_shape_handler(SplitHandler())
@@ -0,0 +1,69 @@
1
+ # Copyright (c) Microsoft Corporation. All rights reserved.
2
+ # Licensed under the MIT License.
3
+
4
+ """Shape handler for Squeeze operator."""
5
+
6
+ import logging
7
+
8
+ from onnx import helper
9
+
10
+ from ...base import ShapeHandler
11
+ from ...registry import register_shape_handler
12
+ from ...utils import get_attribute, get_opset, handle_negative_axis
13
+
14
+ logger = logging.getLogger(__name__)
15
+
16
+
17
class SqueezeHandler(ShapeHandler):
    """Shape-inference handler for the ONNX Squeeze operator."""

    @property
    def op_type(self) -> str:
        return "Squeeze"

    def infer_shape(self, node, ctx) -> None:
        """Compute the squeezed output shape and propagate sympy data."""
        in_shape = ctx.get_shape(node, 0)

        if get_opset(ctx.out_mp_) < 13:
            # Before opset 13, axes come from an attribute.
            axes = get_attribute(node, "axes")
            assert ctx.try_get_value(node, 1) is None
        else:
            # From opset 13 on, axes are an optional second input.
            axes = ctx.try_get_value(node, 1)
            assert get_attribute(node, "axes") is None

        if axes is None:
            # No axes given: drop every dimension equal to 1; symbolic
            # dimensions are assumed to never be 1.
            out_shape = [d for d in in_shape if d != 1]
            if ctx.verbose_ > 0:
                # Exact type check: symbolic dims are non-int objects.
                symbolic_dims = [d for d in in_shape if type(d) != int]
                if symbolic_dims:
                    logger.debug(
                        f"Symbolic dimensions in input shape of op: '{node.op_type}' node: '{node.name}'. "
                        f"Assuming the following dimensions are never equal to 1: {symbolic_dims}"
                    )
        else:
            axes = [handle_negative_axis(a, len(in_shape)) for a in axes]
            out_shape = []
            for idx, dim in enumerate(in_shape):
                if idx not in axes:
                    out_shape.append(dim)
                    continue
                # A squeezed axis must be 1 (or symbolic, assumed to be 1).
                assert dim == 1 or type(dim) != int
                if ctx.verbose_ > 0 and type(dim) != int:
                    logger.debug(
                        f"Symbolic dimensions in input shape of op: '{node.op_type}' node: '{node.name}'. "
                        f"Assuming the dimension '{dim}' at index {idx} of the input to be equal to 1."
                    )

        vi = ctx.known_vi_[node.output[0]]
        vi.CopyFrom(
            helper.make_tensor_value_info(
                node.output[0],
                ctx.known_vi_[node.input[0]].type.tensor_type.elem_type,
                out_shape,
            )
        )
        ctx.pass_on_sympy_data(node)


register_shape_handler(SqueezeHandler())
@@ -0,0 +1,41 @@
1
+ # Copyright (c) Microsoft Corporation. All rights reserved.
2
+ # Licensed under the MIT License.
3
+
4
+ """Shape handler for Tile operator."""
5
+
6
+ from onnx import helper
7
+
8
+ from ...base import ShapeHandler
9
+ from ...registry import register_shape_handler
10
+ from ...utils import get_shape_from_sympy_shape
11
+
12
+
13
class TileHandler(ShapeHandler):
    """Shape-inference handler for the ONNX Tile operator."""

    @property
    def op_type(self) -> str:
        return "Tile"

    def infer_shape(self, node, ctx) -> None:
        """Multiply each input dimension by its repeat count when known."""
        repeats = ctx.try_get_value(node, 1)
        if repeats is not None:
            in_shape = ctx.get_sympy_shape(node, 0)
            out_shape = [dim * repeats[i] for i, dim in enumerate(in_shape)]
            ctx.update_computed_dims(out_shape)
        else:
            # Repeats unknown: fall back to fresh symbolic dimensions.
            out_shape = ctx.new_symbolic_shape(ctx.get_shape_rank(node, 0), node)
        vi = ctx.known_vi_[node.output[0]]
        vi.CopyFrom(
            helper.make_tensor_value_info(
                node.output[0],
                vi.type.tensor_type.elem_type,
                get_shape_from_sympy_shape(out_shape),
            )
        )


register_shape_handler(TileHandler())
@@ -0,0 +1,30 @@
1
+ # Copyright (c) Microsoft Corporation. All rights reserved.
2
+ # Licensed under the MIT License.
3
+
4
+ """Shape handler for Transpose operator."""
5
+
6
+ import numpy as np
7
+
8
+ from ...base import ShapeHandler
9
+ from ...registry import register_shape_handler
10
+ from ...utils import get_attribute
11
+
12
+
13
class TransposeHandler(ShapeHandler):
    """Shape-inference handler for the ONNX Transpose operator.

    Only propagates known (sympy) data values; the output shape itself is
    left to the generic inference machinery.
    """

    @property
    def op_type(self) -> str:
        return "Transpose"

    def infer_shape(self, node, ctx) -> None:
        """Permute cached input data so downstream shape computations see it."""
        src = node.input[0]
        if src not in ctx.sympy_data_:
            return
        shape = ctx.get_shape(node, 0)
        # Default perm is the reversed axis order, matching the attribute default.
        perm = get_attribute(node, "perm", reversed(list(range(len(shape)))))
        permuted = np.transpose(np.array(ctx.sympy_data_[src]).reshape(*shape), axes=tuple(perm))
        ctx.sympy_data_[node.output[0]] = permuted.flatten().tolist()


register_shape_handler(TransposeHandler())
@@ -0,0 +1,54 @@
1
+ # Copyright (c) Microsoft Corporation. All rights reserved.
2
+ # Licensed under the MIT License.
3
+
4
+ """Shape handler for Unsqueeze operator."""
5
+
6
+ from onnx import helper
7
+
8
+ from ...base import ShapeHandler
9
+ from ...registry import register_shape_handler
10
+ from ...utils import get_attribute, get_opset, handle_negative_axis
11
+
12
+
13
class UnsqueezeHandler(ShapeHandler):
    """Shape-inference handler for the ONNX Unsqueeze operator."""

    @property
    def op_type(self) -> str:
        return "Unsqueeze"

    def infer_shape(self, node, ctx) -> None:
        """Insert size-1 dimensions at the requested axes."""
        in_shape = ctx.get_shape(node, 0)

        if get_opset(ctx.out_mp_) < 13:
            # Before opset 13, axes come from an attribute.
            axes = get_attribute(node, "axes")
            assert ctx.try_get_value(node, 1) is None
        else:
            # From opset 13 on, axes are a (required) second input.
            axes = ctx.try_get_value(node, 1)
            assert get_attribute(node, "axes") is None

        # Negative axes are interpreted relative to the *output* rank.
        out_rank = len(in_shape) + len(axes)
        axes = [handle_negative_axis(a, out_rank) for a in axes]

        out_shape = []
        src_idx = 0
        for pos in range(out_rank):
            if pos in axes:
                out_shape.append(1)
            else:
                out_shape.append(in_shape[src_idx])
                src_idx += 1

        vi = ctx.known_vi_[node.output[0]]
        vi.CopyFrom(
            helper.make_tensor_value_info(
                node.output[0],
                ctx.known_vi_[node.input[0]].type.tensor_type.elem_type,
                out_shape,
            )
        )
        ctx.pass_on_sympy_data(node)


register_shape_handler(UnsqueezeHandler())
@@ -0,0 +1,244 @@
1
+ # Copyright (c) Microsoft Corporation. All rights reserved.
2
+ # Licensed under the MIT License.
3
+
4
+ """Utility functions for symbolic shape inference."""
5
+
6
+ import numpy as np
7
+ import onnx
8
+ import sympy
9
+ from onnx import helper, numpy_helper
10
+
11
+
12
def get_attribute(node, attr_name, default_value=None):
    """Retrieve the value of an attribute from an ONNX node.

    Args:
        node: The ONNX node.
        attr_name: The name of the attribute to retrieve.
        default_value: The default value if the attribute is not found.

    Returns:
        The attribute value or the default value.
    """
    for attr in node.attribute:
        if attr.name == attr_name:
            return helper.get_attribute_value(attr)
    return default_value
25
+
26
+
27
def get_dim_from_proto(dim):
    """Retrieve the dimension value from the ONNX protobuf object.

    Args:
        dim: The ONNX TensorShapeProto.Dimension.

    Returns:
        The dimension value (int or str), or None when the dimension is unset.
    """
    which = dim.WhichOneof("value")
    if type(which) is str:
        return getattr(dim, which)
    return None
37
+
38
+
39
def is_sequence(type_proto):
    """Check if the given ONNX proto type is a sequence.

    Args:
        type_proto: The ONNX TypeProto.

    Returns:
        True for a sequence type, False for a tensor type.

    Raises:
        AssertionError: If the proto is neither tensor nor sequence typed.
    """
    kind = type_proto.WhichOneof("value")
    assert kind in {"tensor_type", "sequence_type"}
    return kind == "sequence_type"
51
+
52
+
53
def get_shape_from_type_proto(type_proto):
    """Extract the shape of a tensor from an ONNX type proto.

    Args:
        type_proto: The ONNX TypeProto (must not be a sequence type).

    Returns:
        A list of dimension values, or None if no shape is available.
    """
    assert not is_sequence(type_proto)
    if not type_proto.tensor_type.HasField("shape"):
        return None
    return [get_dim_from_proto(d) for d in type_proto.tensor_type.shape.dim]
67
+
68
+
69
def get_elem_type_from_type_proto(type_proto):
    """Return the element type from a given TypeProto object.

    Args:
        type_proto: The ONNX TypeProto.

    Returns:
        The element type (e.g., TensorProto.FLOAT).
    """
    if is_sequence(type_proto):
        # Sequences wrap a tensor type; unwrap one level.
        return type_proto.sequence_type.elem_type.tensor_type.elem_type
    return type_proto.tensor_type.elem_type
82
+
83
+
84
def get_shape_from_value_info(vi):
    """Return the shape from the given ValueInfoProto object.

    Args:
        vi: The ONNX ValueInfoProto.

    Returns:
        A list of dimension values, or None when no shape can be derived.
    """
    if vi.type.WhichOneof("value") is None:
        return None
    if not is_sequence(vi.type):
        return get_shape_from_type_proto(vi.type)
    # For sequences, only tensor element types carry a shape.
    elem = vi.type.sequence_type.elem_type
    if elem.WhichOneof("value") == "tensor_type":
        return get_shape_from_type_proto(elem)
    return None
102
+
103
+
104
def make_named_value_info(name):
    """Create and return an ONNX ValueInfoProto object with the specified name.

    Args:
        name: The name for the ValueInfoProto.

    Returns:
        A new ValueInfoProto with the given name and no type information.
    """
    value_info = onnx.ValueInfoProto()
    value_info.name = name
    return value_info
116
+
117
+
118
def get_shape_from_sympy_shape(sympy_shape):
    """Convert a sympy shape to a list with int, str, or None elements.

    Args:
        sympy_shape: A list of sympy expressions (or None entries).

    Returns:
        A list where literal dims become int, symbolic dims become str, and
        None entries are preserved.
    """
    result = []
    for dim in sympy_shape:
        if dim is None:
            result.append(None)
        elif is_literal(dim):
            result.append(int(dim))
        else:
            result.append(str(dim))
    return result
128
+
129
+
130
def is_literal(dim):
    """Check if a dimension is a literal number.

    Args:
        dim: The dimension value to check.

    Returns:
        True if the dimension is a concrete (non-symbolic) number.
    """
    # Exact type check on purpose: isinstance would also accept bool (a
    # subclass of int) and arbitrary subclasses with different semantics.
    if type(dim) in {int, np.int64, np.int32, sympy.Integer}:
        return True
    return hasattr(dim, "is_number") and dim.is_number
140
+
141
+
142
def handle_negative_axis(axis, rank):
    """Convert a potentially negative axis to a positive axis.

    Args:
        axis: The axis value (may be negative, counted from the end).
        rank: The total rank of the tensor.

    Returns:
        A non-negative axis value in [0, rank).
    """
    assert -rank <= axis < rank
    return rank + axis if axis < 0 else axis
154
+
155
+
156
def get_opset(mp, domain=None):
    """Retrieve the opset version for a given model namespace.

    Args:
        mp: The ONNX ModelProto.
        domain: A domain string or list of domains to check. Defaults to the
            default ONNX domains ("", "onnx", "ai.onnx").

    Returns:
        The version of the first matching opset import, or None if not found.
    """
    domain = domain or ["", "onnx", "ai.onnx"]
    # isinstance instead of type() comparison (idiomatic; accepts subclasses).
    if not isinstance(domain, list):
        domain = [domain]
    for opset in mp.opset_import:
        if opset.domain in domain:
            return opset.version
    return None
173
+
174
+
175
def as_scalar(x):
    """Convert input to scalar if input is a single-item list or a NumPy ndarray.

    Args:
        x: The input value.

    Returns:
        A scalar value; non-list, non-ndarray inputs are returned unchanged.

    Raises:
        AssertionError: If x is a list with more than one element.
        ValueError: If x is an ndarray with more than one element.
    """
    # isinstance instead of type() comparisons (idiomatic; accepts subclasses).
    if isinstance(x, list):
        assert len(x) == 1
        return x[0]
    if isinstance(x, np.ndarray):
        return x.item()
    return x
191
+
192
+
193
def as_list(x, keep_none):
    """Convert input to list, optionally preserving None values.

    Args:
        x: The input value.
        keep_none: If True, return None as-is instead of wrapping in a list.

    Returns:
        A list, or None when keep_none is True and x is None.
    """
    # isinstance instead of type() comparisons (idiomatic; accepts subclasses).
    if isinstance(x, list):
        return x
    if isinstance(x, np.ndarray):
        return list(x)
    if keep_none and x is None:
        return None
    return [x]
211
+
212
+
213
def sympy_reduce_product(x):
    """Reduce a list or element to a product using Sympy's Integer.

    Args:
        x: A list of factors or a single value.

    Returns:
        The product as a sympy expression, or x itself when not a list.
    """
    if type(x) != list:
        return x
    product = sympy.Integer(1)
    for factor in x:
        product = product * factor
    return product
229
+
230
+
231
def numpy_to_sympy(array):
    """Convert a numpy array to plain Python int values.

    Args:
        array: A numpy array, or any other value (returned unchanged).

    Returns:
        An int for 0-d arrays, a flat list of ints for higher-rank arrays,
        or the input itself when it is not a numpy array.
    """
    if not isinstance(array, np.ndarray):
        return array
    if array.ndim == 0:
        return int(array.item())
    return [int(v) for v in array.flatten()]
+ return array