onnxslim 0.1.82__py3-none-any.whl → 0.1.84__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (141)
  1. onnxslim/core/optimization/dead_node_elimination.py +85 -4
  2. onnxslim/core/pattern/elimination/slice.py +15 -8
  3. onnxslim/core/pattern/fusion/concat_reshape.py +3 -1
  4. onnxslim/core/pattern/fusion/convadd.py +23 -7
  5. onnxslim/core/pattern/fusion/convbn.py +24 -11
  6. onnxslim/core/pattern/fusion/convmul.py +26 -9
  7. onnxslim/core/pattern/fusion/gemm.py +7 -5
  8. onnxslim/core/pattern/fusion/padconv.py +5 -0
  9. onnxslim/core/shape_inference/__init__.py +378 -0
  10. onnxslim/core/shape_inference/aten_ops/__init__.py +16 -0
  11. onnxslim/core/shape_inference/aten_ops/argmax.py +47 -0
  12. onnxslim/core/shape_inference/aten_ops/bitwise_or.py +28 -0
  13. onnxslim/core/shape_inference/aten_ops/diagonal.py +52 -0
  14. onnxslim/core/shape_inference/aten_ops/embedding.py +23 -0
  15. onnxslim/core/shape_inference/aten_ops/group_norm.py +41 -0
  16. onnxslim/core/shape_inference/aten_ops/min_max.py +64 -0
  17. onnxslim/core/shape_inference/aten_ops/multinomial.py +39 -0
  18. onnxslim/core/shape_inference/aten_ops/numpy_t.py +22 -0
  19. onnxslim/core/shape_inference/aten_ops/pool2d.py +40 -0
  20. onnxslim/core/shape_inference/aten_ops/unfold.py +44 -0
  21. onnxslim/core/shape_inference/aten_ops/upsample.py +44 -0
  22. onnxslim/core/shape_inference/base.py +111 -0
  23. onnxslim/core/shape_inference/context.py +645 -0
  24. onnxslim/core/shape_inference/contrib_ops/__init__.py +8 -0
  25. onnxslim/core/shape_inference/contrib_ops/attention/__init__.py +15 -0
  26. onnxslim/core/shape_inference/contrib_ops/attention/attention.py +61 -0
  27. onnxslim/core/shape_inference/contrib_ops/attention/decoder_masked_mha.py +37 -0
  28. onnxslim/core/shape_inference/contrib_ops/attention/gated_relative_position_bias.py +35 -0
  29. onnxslim/core/shape_inference/contrib_ops/attention/longformer_attention.py +21 -0
  30. onnxslim/core/shape_inference/contrib_ops/attention/multi_head_attention.py +82 -0
  31. onnxslim/core/shape_inference/contrib_ops/attention/multi_scale_deformable_attn.py +29 -0
  32. onnxslim/core/shape_inference/contrib_ops/attention/packed_attention.py +39 -0
  33. onnxslim/core/shape_inference/contrib_ops/attention/packed_multi_head_attention.py +33 -0
  34. onnxslim/core/shape_inference/contrib_ops/attention/remove_padding.py +41 -0
  35. onnxslim/core/shape_inference/contrib_ops/attention/restore_padding.py +29 -0
  36. onnxslim/core/shape_inference/contrib_ops/misc/__init__.py +15 -0
  37. onnxslim/core/shape_inference/contrib_ops/misc/bias_add.py +21 -0
  38. onnxslim/core/shape_inference/contrib_ops/misc/bias_gelu.py +21 -0
  39. onnxslim/core/shape_inference/contrib_ops/misc/bias_split_gelu.py +30 -0
  40. onnxslim/core/shape_inference/contrib_ops/misc/fast_gelu.py +21 -0
  41. onnxslim/core/shape_inference/contrib_ops/misc/gelu.py +21 -0
  42. onnxslim/core/shape_inference/contrib_ops/misc/gemm_fast_gelu.py +21 -0
  43. onnxslim/core/shape_inference/contrib_ops/misc/gemm_float8.py +21 -0
  44. onnxslim/core/shape_inference/contrib_ops/misc/python_op.py +67 -0
  45. onnxslim/core/shape_inference/contrib_ops/misc/quick_gelu.py +21 -0
  46. onnxslim/core/shape_inference/contrib_ops/misc/rotary_embedding.py +31 -0
  47. onnxslim/core/shape_inference/contrib_ops/normalization/__init__.py +12 -0
  48. onnxslim/core/shape_inference/contrib_ops/normalization/embed_layer_normalization.py +41 -0
  49. onnxslim/core/shape_inference/contrib_ops/normalization/group_norm.py +21 -0
  50. onnxslim/core/shape_inference/contrib_ops/normalization/layer_normalization.py +42 -0
  51. onnxslim/core/shape_inference/contrib_ops/normalization/simplified_layer_normalization.py +23 -0
  52. onnxslim/core/shape_inference/contrib_ops/normalization/skip_group_norm.py +23 -0
  53. onnxslim/core/shape_inference/contrib_ops/normalization/skip_layer_normalization.py +26 -0
  54. onnxslim/core/shape_inference/contrib_ops/normalization/skip_simplified_layer_normalization.py +23 -0
  55. onnxslim/core/shape_inference/registry.py +90 -0
  56. onnxslim/core/shape_inference/standard_ops/__init__.py +11 -0
  57. onnxslim/core/shape_inference/standard_ops/control_flow/__init__.py +8 -0
  58. onnxslim/core/shape_inference/standard_ops/control_flow/if_op.py +43 -0
  59. onnxslim/core/shape_inference/standard_ops/control_flow/loop.py +74 -0
  60. onnxslim/core/shape_inference/standard_ops/control_flow/scan.py +54 -0
  61. onnxslim/core/shape_inference/standard_ops/math/__init__.py +20 -0
  62. onnxslim/core/shape_inference/standard_ops/math/_symbolic_compute.py +34 -0
  63. onnxslim/core/shape_inference/standard_ops/math/add.py +10 -0
  64. onnxslim/core/shape_inference/standard_ops/math/div.py +10 -0
  65. onnxslim/core/shape_inference/standard_ops/math/einsum.py +119 -0
  66. onnxslim/core/shape_inference/standard_ops/math/equal.py +10 -0
  67. onnxslim/core/shape_inference/standard_ops/math/floor.py +10 -0
  68. onnxslim/core/shape_inference/standard_ops/math/matmul.py +21 -0
  69. onnxslim/core/shape_inference/standard_ops/math/matmul_integer.py +23 -0
  70. onnxslim/core/shape_inference/standard_ops/math/max.py +10 -0
  71. onnxslim/core/shape_inference/standard_ops/math/min.py +10 -0
  72. onnxslim/core/shape_inference/standard_ops/math/mul.py +10 -0
  73. onnxslim/core/shape_inference/standard_ops/math/neg.py +10 -0
  74. onnxslim/core/shape_inference/standard_ops/math/reduce_prod.py +27 -0
  75. onnxslim/core/shape_inference/standard_ops/math/reduce_sum.py +53 -0
  76. onnxslim/core/shape_inference/standard_ops/math/sub.py +10 -0
  77. onnxslim/core/shape_inference/standard_ops/math/where.py +10 -0
  78. onnxslim/core/shape_inference/standard_ops/misc/__init__.py +22 -0
  79. onnxslim/core/shape_inference/standard_ops/misc/array_feature_extractor.py +32 -0
  80. onnxslim/core/shape_inference/standard_ops/misc/cast.py +21 -0
  81. onnxslim/core/shape_inference/standard_ops/misc/category_mapper.py +30 -0
  82. onnxslim/core/shape_inference/standard_ops/misc/compress.py +39 -0
  83. onnxslim/core/shape_inference/standard_ops/misc/constant.py +27 -0
  84. onnxslim/core/shape_inference/standard_ops/misc/constant_of_shape.py +45 -0
  85. onnxslim/core/shape_inference/standard_ops/misc/dequantize_linear.py +26 -0
  86. onnxslim/core/shape_inference/standard_ops/misc/non_max_suppression.py +26 -0
  87. onnxslim/core/shape_inference/standard_ops/misc/non_zero.py +26 -0
  88. onnxslim/core/shape_inference/standard_ops/misc/one_hot.py +42 -0
  89. onnxslim/core/shape_inference/standard_ops/misc/quantize_linear.py +29 -0
  90. onnxslim/core/shape_inference/standard_ops/misc/range.py +41 -0
  91. onnxslim/core/shape_inference/standard_ops/misc/relative_position_bias.py +31 -0
  92. onnxslim/core/shape_inference/standard_ops/misc/resize.py +74 -0
  93. onnxslim/core/shape_inference/standard_ops/misc/scatter_elements.py +31 -0
  94. onnxslim/core/shape_inference/standard_ops/misc/softmax_cross_entropy_loss.py +44 -0
  95. onnxslim/core/shape_inference/standard_ops/misc/top_k.py +44 -0
  96. onnxslim/core/shape_inference/standard_ops/nn/__init__.py +18 -0
  97. onnxslim/core/shape_inference/standard_ops/nn/all_reduce.py +9 -0
  98. onnxslim/core/shape_inference/standard_ops/nn/average_pool.py +40 -0
  99. onnxslim/core/shape_inference/standard_ops/nn/batch_normalization.py +26 -0
  100. onnxslim/core/shape_inference/standard_ops/nn/conv.py +33 -0
  101. onnxslim/core/shape_inference/standard_ops/nn/cum_sum.py +9 -0
  102. onnxslim/core/shape_inference/standard_ops/nn/identity.py +9 -0
  103. onnxslim/core/shape_inference/standard_ops/nn/max_pool.py +9 -0
  104. onnxslim/core/shape_inference/standard_ops/nn/memcpy_from_host.py +9 -0
  105. onnxslim/core/shape_inference/standard_ops/nn/memcpy_to_host.py +9 -0
  106. onnxslim/core/shape_inference/standard_ops/nn/moe.py +9 -0
  107. onnxslim/core/shape_inference/standard_ops/nn/nhwc_conv.py +33 -0
  108. onnxslim/core/shape_inference/standard_ops/nn/reciprocal.py +9 -0
  109. onnxslim/core/shape_inference/standard_ops/nn/round.py +9 -0
  110. onnxslim/core/shape_inference/standard_ops/sequence/__init__.py +10 -0
  111. onnxslim/core/shape_inference/standard_ops/sequence/concat_from_sequence.py +40 -0
  112. onnxslim/core/shape_inference/standard_ops/sequence/sequence_at.py +31 -0
  113. onnxslim/core/shape_inference/standard_ops/sequence/sequence_insert.py +26 -0
  114. onnxslim/core/shape_inference/standard_ops/sequence/split_to_sequence.py +24 -0
  115. onnxslim/core/shape_inference/standard_ops/sequence/zip_map.py +36 -0
  116. onnxslim/core/shape_inference/standard_ops/tensor/__init__.py +20 -0
  117. onnxslim/core/shape_inference/standard_ops/tensor/concat.py +62 -0
  118. onnxslim/core/shape_inference/standard_ops/tensor/expand.py +36 -0
  119. onnxslim/core/shape_inference/standard_ops/tensor/gather.py +48 -0
  120. onnxslim/core/shape_inference/standard_ops/tensor/gather_elements.py +31 -0
  121. onnxslim/core/shape_inference/standard_ops/tensor/gather_nd.py +42 -0
  122. onnxslim/core/shape_inference/standard_ops/tensor/pad.py +41 -0
  123. onnxslim/core/shape_inference/standard_ops/tensor/reshape.py +72 -0
  124. onnxslim/core/shape_inference/standard_ops/tensor/shape.py +38 -0
  125. onnxslim/core/shape_inference/standard_ops/tensor/size.py +29 -0
  126. onnxslim/core/shape_inference/standard_ops/tensor/slice.py +183 -0
  127. onnxslim/core/shape_inference/standard_ops/tensor/split.py +57 -0
  128. onnxslim/core/shape_inference/standard_ops/tensor/squeeze.py +69 -0
  129. onnxslim/core/shape_inference/standard_ops/tensor/tile.py +41 -0
  130. onnxslim/core/shape_inference/standard_ops/tensor/transpose.py +30 -0
  131. onnxslim/core/shape_inference/standard_ops/tensor/unsqueeze.py +54 -0
  132. onnxslim/core/shape_inference/utils.py +244 -0
  133. onnxslim/third_party/onnx_graphsurgeon/ir/graph.py +0 -103
  134. onnxslim/third_party/symbolic_shape_infer.py +73 -3156
  135. onnxslim/utils.py +4 -2
  136. {onnxslim-0.1.82.dist-info → onnxslim-0.1.84.dist-info}/METADATA +21 -11
  137. onnxslim-0.1.84.dist-info/RECORD +187 -0
  138. onnxslim-0.1.82.dist-info/RECORD +0 -63
  139. {onnxslim-0.1.82.dist-info → onnxslim-0.1.84.dist-info}/WHEEL +0 -0
  140. {onnxslim-0.1.82.dist-info → onnxslim-0.1.84.dist-info}/entry_points.txt +0 -0
  141. {onnxslim-0.1.82.dist-info → onnxslim-0.1.84.dist-info}/licenses/LICENSE +0 -0
onnxslim/core/shape_inference/__init__.py
@@ -0,0 +1,378 @@
+ # Copyright (c) Microsoft Corporation. All rights reserved.
+ # Licensed under the MIT License.
+
+ """
+ Symbolic Shape Inference Module
+
+ This module provides symbolic shape inference for ONNX models. It replaces the
+ monolithic SymbolicShapeInference class with a modular, handler-based architecture.
+
+ Usage:
+     from onnxslim.core.shape_inference import ShapeInferencer
+
+     model = onnx.load("model.onnx")
+     model_with_shapes = ShapeInferencer.infer_shapes(model)
+ """
+
+ import logging
+
+ import onnx
+ import sympy
+ from onnx import helper
+
+ from .context import InferenceContext
+ from .registry import get_all_aten_handlers, get_all_shape_handlers, get_aten_handler, get_shape_handler
+ from .utils import (
+     get_attribute,
+     get_opset,
+     get_shape_from_type_proto,
+     get_shape_from_value_info,
+     is_literal,
+     is_sequence,
+ )
+
+ # Import all handlers to trigger registration
+ from . import aten_ops  # noqa: F401
+ from . import contrib_ops  # noqa: F401
+ from . import standard_ops  # noqa: F401
+
+ logger = logging.getLogger(__name__)
+
+
+ class ShapeInferencer:
+     """Main class for performing symbolic shape inference on ONNX models."""
+
+     def __init__(self, int_max=2**31 - 1, auto_merge=False, guess_output_rank=False, verbose=0, prefix=""):
+         """Initialize the ShapeInferencer.
+
+         Args:
+             int_max: Maximum value for unbounded integers.
+             auto_merge: Whether to automatically merge conflicting dimensions.
+             guess_output_rank: Whether to guess output rank from input.
+             verbose: Logging verbosity level.
+             prefix: Prefix for generated symbolic dimension names.
+         """
+         self.int_max_ = int_max
+         self.auto_merge_ = auto_merge
+         self.guess_output_rank_ = guess_output_rank
+         self.verbose_ = verbose
+         self.prefix_ = prefix
+
+     def _infer_impl(self, ctx, start_sympy_data=None):
+         """Main inference implementation loop."""
+         ctx.sympy_data_ = start_sympy_data or {}
+         ctx.apply_suggested_merge(graph_input_only=True)
+         ctx.input_symbols_ = set()
+
+         # Process graph inputs
+         for i in ctx.out_mp_.graph.input:
+             input_shape = get_shape_from_value_info(i)
+             if input_shape is None:
+                 continue
+
+             if is_sequence(i.type):
+                 input_dims = i.type.sequence_type.elem_type.tensor_type.shape.dim
+             else:
+                 input_dims = i.type.tensor_type.shape.dim
+
+             for i_dim, dim in enumerate(input_shape):
+                 if dim is None:
+                     input_dims[i_dim].dim_param = str(ctx.new_symbolic_dim(i.name, i_dim))
+
+             ctx.input_symbols_.update([d for d in input_shape if type(d) == str])
+
+         for s in ctx.input_symbols_:
+             if s in ctx.suggested_merge_:
+                 s_merge = ctx.suggested_merge_[s]
+                 assert s_merge in ctx.symbolic_dims_
+                 ctx.symbolic_dims_[s] = ctx.symbolic_dims_[s_merge]
+             else:
+                 ctx.symbolic_dims_[s] = sympy.Symbol(s, integer=True, positive=True)
+
+         # Compute prerequisite for node for topological sort
+         prereq_for_node = {}
+
+         def get_prereq(node):
+             names = {i for i in node.input if i}
+             subgraphs = []
+             if node.op_type == "If":
+                 subgraphs = [get_attribute(node, "then_branch"), get_attribute(node, "else_branch")]
+             elif node.op_type in {"Loop", "Scan"}:
+                 subgraphs = [get_attribute(node, "body")]
+             for g in subgraphs:
+                 g_outputs_and_initializers = {i.name for i in g.initializer}
+                 g_prereq = set()
+                 for n in g.node:
+                     g_outputs_and_initializers.update(n.output)
+                 for n in g.node:
+                     g_prereq.update([i for i in get_prereq(n) if i not in g_outputs_and_initializers])
+                 names.update(g_prereq)
+                 for i in g.input:
+                     if i.name in names:
+                         names.remove(i.name)
+             return names
+
+         for n in ctx.out_mp_.graph.node:
+             prereq_for_node[n.output[0]] = get_prereq(n)
+
+         # Topological sort nodes
+         sorted_nodes = []
+         sorted_known_vi = {i.name for i in list(ctx.out_mp_.graph.input) + list(ctx.out_mp_.graph.initializer)}
+         if any(o.name in sorted_known_vi for o in ctx.out_mp_.graph.output):
+             sorted_nodes = ctx.out_mp_.graph.node
+         else:
+             while any(o.name not in sorted_known_vi for o in ctx.out_mp_.graph.output):
+                 old_sorted_nodes_len = len(sorted_nodes)
+                 for node in ctx.out_mp_.graph.node:
+                     if node.output[0] not in sorted_known_vi and all(
+                         i in sorted_known_vi for i in prereq_for_node[node.output[0]] if i
+                     ):
+                         sorted_known_vi.update(node.output)
+                         sorted_nodes.append(node)
+                 if old_sorted_nodes_len == len(sorted_nodes) and not all(
+                     o.name in sorted_known_vi for o in ctx.out_mp_.graph.output
+                 ):
+                     raise Exception("Invalid model with cyclic graph")
+
+         # Get handlers
+         shape_handlers = get_all_shape_handlers()
+         aten_handlers = get_all_aten_handlers()
+
+         # Process each node
+         for node in sorted_nodes:
+             assert all([i in ctx.known_vi_ for i in node.input if i])
+             ctx.onnx_infer_single_node(node)
+             known_aten_op = False
+
+             # Try standard handlers first
+             handler = get_shape_handler(node.op_type)
+             if handler is not None:
+                 handler.infer_shape(node, ctx)
+             elif node.op_type == "ConvTranspose":
+                 vi = ctx.known_vi_[node.output[0]]
+                 if len(vi.type.tensor_type.shape.dim) == 0:
+                     vi.type.tensor_type.elem_type = onnx.TensorProto.UNDEFINED
+             elif node.op_type == "ATen" and node.domain == "org.pytorch.aten":
+                 for attr in node.attribute:
+                     if attr.name == "operator":
+                         aten_op_name = attr.s.decode("utf-8") if isinstance(attr.s, bytes) else attr.s
+                         aten_handler = get_aten_handler(aten_op_name)
+                         if aten_handler is not None:
+                             known_aten_op = True
+                             aten_handler.infer_shape(node, ctx)
+                         break
+
+             if ctx.verbose_ > 2:
+                 logger.debug(node.op_type + ": " + node.name)
+                 for i, name in enumerate(node.input):
+                     logger.debug(f"  Input {i}: {name} {'initializer' if name in ctx.initializers_ else ''}")
+
+             # Handle dimension merging for broadcast ops
+             if node.op_type in {
+                 "Add",
+                 "Sub",
+                 "Mul",
+                 "Div",
+                 "MatMul",
+                 "MatMulInteger",
+                 "MatMulInteger16",
+                 "Where",
+                 "Sum",
+             }:
+                 vi = ctx.known_vi_[node.output[0]]
+                 out_rank = len(get_shape_from_type_proto(vi.type))
+                 in_shapes = [ctx.get_shape(node, i) for i in range(len(node.input))]
+                 for d in range(out_rank - (2 if node.op_type in {"MatMul", "MatMulInteger", "MatMulInteger16"} else 0)):
+                     in_dims = [s[len(s) - out_rank + d] for s in in_shapes if len(s) + d >= out_rank]
+                     if len(in_dims) > 1:
+                         ctx.check_merged_dims(in_dims, allow_broadcast=True)
+
+             # Process outputs
+             for i_o in range(len(node.output)):
+                 if node.op_type in {"SkipLayerNormalization", "SkipSimplifiedLayerNormalization"} and i_o in {1, 2}:
+                     continue
+                 if node.op_type == "RotaryEmbedding" and len(node.output) > 1:
+                     continue
+
+                 vi = ctx.known_vi_[node.output[i_o]]
+                 out_type = vi.type
+                 out_type_kind = out_type.WhichOneof("value")
+
+                 if out_type_kind not in {"tensor_type", "sparse_tensor_type", None}:
+                     if ctx.verbose_ > 2:
+                         if out_type_kind == "sequence_type":
+                             seq_cls_type = out_type.sequence_type.elem_type.WhichOneof("value")
+                             if seq_cls_type == "tensor_type":
+                                 logger.debug(
+                                     f"  {node.output[i_o]}: sequence of {str(get_shape_from_value_info(vi))} "
+                                     f"{onnx.TensorProto.DataType.Name(vi.type.sequence_type.elem_type.tensor_type.elem_type)}"
+                                 )
+                             else:
+                                 logger.debug(f"  {node.output[i_o]}: sequence of {seq_cls_type}")
+                         else:
+                             logger.debug(f"  {node.output[i_o]}: {out_type_kind}")
+                     continue
+
+                 out_shape = get_shape_from_value_info(vi)
+                 out_type_undefined = out_type.tensor_type.elem_type == onnx.TensorProto.UNDEFINED
+                 if ctx.verbose_ > 2:
+                     logger.debug(
+                         f"  {node.output[i_o]}: {out_shape!s} {onnx.TensorProto.DataType.Name(vi.type.tensor_type.elem_type)}"
+                     )
+                     if node.output[i_o] in ctx.sympy_data_:
+                         logger.debug("  Sympy Data: " + str(ctx.sympy_data_[node.output[i_o]]))
+
+                 if (out_shape is not None and (None in out_shape or ctx.is_shape_contains_none_dim(out_shape))) or out_type_undefined:
+                     if ctx.auto_merge_:
+                         if node.op_type in {
+                             "Add",
+                             "Sub",
+                             "Mul",
+                             "Div",
+                             "MatMul",
+                             "MatMulInteger",
+                             "MatMulInteger16",
+                             "Concat",
+                             "Where",
+                             "Sum",
+                             "Equal",
+                             "Less",
+                             "Greater",
+                             "LessOrEqual",
+                             "GreaterOrEqual",
+                             "Min",
+                             "Max",
+                         }:
+                             shapes = [ctx.get_shape(node, i) for i in range(len(node.input))]
+                             if node.op_type in {"MatMul", "MatMulInteger", "MatMulInteger16"} and (
+                                 None in out_shape or ctx.is_shape_contains_none_dim(out_shape)
+                             ):
+                                 if None in out_shape:
+                                     idx = out_shape.index(None)
+                                 else:
+                                     idx = out_shape.index(ctx.is_shape_contains_none_dim(out_shape))
+                                 dim_idx = [len(s) - len(out_shape) + idx for s in shapes]
+                                 assert len(shapes[0]) > 2 and dim_idx[0] < len(shapes[0]) - 2
+                                 assert len(shapes[1]) > 2 and dim_idx[1] < len(shapes[1]) - 2
+                         elif node.op_type == "Expand":
+                             shapes = [ctx.get_shape(node, 0), ctx.get_value(node, 1)]
+                         else:
+                             shapes = []
+
+                         if shapes:
+                             for idx in range(len(out_shape)):
+                                 if out_shape[idx] is not None and not ctx.is_none_dim(out_shape[idx]):
+                                     continue
+                                 dim_idx = [len(s) - len(out_shape) + idx for s in shapes]
+                                 if dim_idx:
+                                     ctx.add_suggested_merge(
+                                         [s[i] if is_literal(s[i]) else str(s[i]) for s, i in zip(shapes, dim_idx) if i >= 0]
+                                     )
+                             ctx.run_ = True
+                         else:
+                             ctx.run_ = False
+                     else:
+                         ctx.run_ = False
+
+                     if not ctx.run_ and handler is None and not known_aten_op:
+                         is_unknown_op = out_type_undefined and (out_shape is None or len(out_shape) == 0)
+                         if is_unknown_op:
+                             out_rank = ctx.get_shape_rank(node, 0) if ctx.guess_output_rank_ else -1
+                         else:
+                             out_rank = len(out_shape)
+
+                         if out_rank >= 0:
+                             new_shape = ctx.new_symbolic_shape(out_rank, node, i_o)
+                             if out_type_undefined:
+                                 out_dtype = ctx.known_vi_[node.input[0]].type.tensor_type.elem_type
+                             else:
+                                 out_dtype = vi.type.tensor_type.elem_type
+                             from .utils import get_shape_from_sympy_shape
+
+                             vi.CopyFrom(
+                                 helper.make_tensor_value_info(vi.name, out_dtype, get_shape_from_sympy_shape(new_shape))
+                             )
+
+                             if ctx.verbose_ > 0:
+                                 if is_unknown_op:
+                                     logger.debug(f"Possible unknown op: {node.op_type} node: {node.name}, guessing {vi.name} shape")
+                                 if ctx.verbose_ > 2:
+                                     logger.debug(f"  {node.output[i_o]}: {new_shape!s} {vi.type.tensor_type.elem_type}")
+                             ctx.run_ = True
+                             continue
+
+                         if ctx.verbose_ > 0 or not ctx.auto_merge_ or out_type_undefined:
+                             logger.debug("Stopping at incomplete shape inference at " + node.op_type + ": " + node.name)
+                             logger.debug("node inputs:")
+                             for i in node.input:
+                                 if i in ctx.known_vi_:
+                                     logger.debug(ctx.known_vi_[i])
+                                 else:
+                                     logger.debug(f"not in known_vi_ for {i}")
+                             logger.debug("node outputs:")
+                             for o in node.output:
+                                 if o in ctx.known_vi_:
+                                     logger.debug(ctx.known_vi_[o])
+                                 else:
+                                     logger.debug(f"not in known_vi_ for {o}")
+                             if ctx.auto_merge_ and not out_type_undefined:
+                                 logger.debug("Merging: " + str(ctx.suggested_merge_))
+                         return False
+
+         ctx.run_ = False
+         return True
+
+     def _update_output_from_vi(self, ctx):
+         """Update output attributes using known value information dictionary."""
+         for output in ctx.out_mp_.graph.output:
+             if output.name in ctx.known_vi_:
+                 output.CopyFrom(ctx.known_vi_[output.name])
+
+     @staticmethod
+     def infer_shapes(in_mp, int_max=2**31 - 1, auto_merge=False, guess_output_rank=False, verbose=0):
+         """Perform symbolic shape inference on an ONNX model.
+
+         Args:
+             in_mp: The input ONNX ModelProto.
+             int_max: Maximum value for unbounded integers.
+             auto_merge: Whether to automatically merge conflicting dimensions.
+             guess_output_rank: Whether to guess output rank from input.
+             verbose: Logging verbosity level.
+
+         Returns:
+             The model with inferred shapes.
+
+         Raises:
+             Exception: If shape inference is incomplete.
+         """
+         onnx_opset = get_opset(in_mp)
+         if (not onnx_opset) or onnx_opset < 7:
+             logger.warning("Only support models of onnx opset 7 and above.")
+             return None
+
+         inferencer = ShapeInferencer(int_max, auto_merge, guess_output_rank, verbose)
+
+         # Create inference context
+         ctx = InferenceContext(
+             in_mp,
+             int_max=int_max,
+             auto_merge=auto_merge,
+             guess_output_rank=guess_output_rank,
+             verbose=verbose,
+         )
+         ctx.preprocess()
+
+         all_shapes_inferred = False
+         while ctx.run_:
+             all_shapes_inferred = inferencer._infer_impl(ctx)
+
+         inferencer._update_output_from_vi(ctx)
+
+         if not all_shapes_inferred:
+             raise Exception("Incomplete symbolic shape inference")
+
+         return ctx.out_mp_
+
+
+ # For backward compatibility
+ SymbolicShapeInference = ShapeInferencer
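
Taken together with the docstring at the top of the file, a minimal end-to-end usage sketch might look as follows (the file names and the auto_merge setting are illustrative, not part of the package):

    import onnx

    from onnxslim.core.shape_inference import ShapeInferencer

    # infer_shapes is a staticmethod; per the docstring it returns the model
    # with inferred shapes, or raises if inference is incomplete.
    model = onnx.load("model.onnx")  # hypothetical local file
    inferred = ShapeInferencer.infer_shapes(model, auto_merge=True)
    onnx.save(inferred, "model_with_shapes.onnx")

The SymbolicShapeInference alias on the last line of the file keeps code written against the old monolithic class importing and running unchanged.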
onnxslim/core/shape_inference/aten_ops/__init__.py
@@ -0,0 +1,16 @@
+ # Copyright (c) Microsoft Corporation. All rights reserved.
+ # Licensed under the MIT License.
+
+ """PyTorch ATen operator shape handlers."""
+
+ from . import bitwise_or
+ from . import diagonal
+ from . import pool2d
+ from . import min_max
+ from . import multinomial
+ from . import unfold
+ from . import argmax
+ from . import group_norm
+ from . import upsample
+ from . import embedding
+ from . import numpy_t
onnxslim/core/shape_inference/aten_ops/argmax.py
@@ -0,0 +1,47 @@
+ # Copyright (c) Microsoft Corporation. All rights reserved.
+ # Licensed under the MIT License.
+
+ """Shape handler for ATen argmax operator."""
+
+ import onnx
+ from onnx import helper
+
+ from ..base import ShapeHandler
+ from ..registry import register_aten_handler
+ from ..utils import get_shape_from_sympy_shape, handle_negative_axis
+
+
+ class AtenArgmaxHandler(ShapeHandler):
+     """Handler for ATen argmax operator."""
+
+     @property
+     def op_type(self) -> str:
+         return "argmax"
+
+     def infer_shape(self, node, ctx) -> None:
+         new_shape = None
+         if not node.input[1]:
+             # The argmax of the flattened input is returned.
+             new_shape = []
+         else:
+             dim = ctx.try_get_value(node, 1)
+             keepdim = ctx.try_get_value(node, 2)
+             if keepdim is not None:
+                 sympy_shape = ctx.get_sympy_shape(node, 0)
+                 if dim is not None:
+                     dim = handle_negative_axis(dim, len(sympy_shape))
+                     if keepdim:
+                         sympy_shape[dim] = 1
+                     else:
+                         del sympy_shape[dim]
+                 else:
+                     rank = len(sympy_shape)
+                     sympy_shape = ctx.new_symbolic_shape(rank if keepdim else rank - 1, node)
+                 ctx.update_computed_dims(sympy_shape)
+                 new_shape = get_shape_from_sympy_shape(sympy_shape)
+         if node.output[0] and new_shape is not None:
+             vi = ctx.known_vi_[node.output[0]]
+             vi.CopyFrom(helper.make_tensor_value_info(node.output[0], onnx.TensorProto.INT64, new_shape))
+
+
+ register_aten_handler(AtenArgmaxHandler())
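
This file shows the full handler contract that the registry dispatch in __init__.py relies on: subclass ShapeHandler, expose the ATen operator name via op_type, write the output value_info inside infer_shape, and register one instance at import time. As a hedged sketch (the operator name and pass-through shape rule below are invented for illustration, not handlers that ship in the package), a new handler would follow the same pattern:

    from onnx import helper

    from onnxslim.core.shape_inference.base import ShapeHandler
    from onnxslim.core.shape_inference.registry import register_aten_handler


    class AtenPassThroughHandler(ShapeHandler):
        """Hypothetical handler: output 0 mirrors input 0's dtype and shape."""

        @property
        def op_type(self) -> str:
            return "my_pass_through"  # invented ATen op name

        def infer_shape(self, node, ctx) -> None:
            t0 = ctx.known_vi_[node.input[0]]
            vi = ctx.known_vi_[node.output[0]]
            vi.CopyFrom(
                helper.make_tensor_value_info(
                    node.output[0], t0.type.tensor_type.elem_type, ctx.get_shape(node, 0)
                )
            )


    register_aten_handler(AtenPassThroughHandler())

Registration is a module-level side effect, which is why aten_ops/__init__.py above imports every handler module.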
onnxslim/core/shape_inference/aten_ops/bitwise_or.py
@@ -0,0 +1,28 @@
+ # Copyright (c) Microsoft Corporation. All rights reserved.
+ # Licensed under the MIT License.
+
+ """Shape handler for ATen bitwise_or operator."""
+
+ from onnx import helper
+
+ from ..base import ShapeHandler
+ from ..registry import register_aten_handler
+
+
+ class AtenBitwiseOrHandler(ShapeHandler):
+     """Handler for ATen bitwise_or operator."""
+
+     @property
+     def op_type(self) -> str:
+         return "bitwise_or"
+
+     def infer_shape(self, node, ctx) -> None:
+         shape0 = ctx.get_shape(node, 0)
+         shape1 = ctx.get_shape(node, 1)
+         new_shape = ctx.broadcast_shapes(shape0, shape1)
+         t0 = ctx.known_vi_[node.input[0]]
+         vi = ctx.known_vi_[node.output[0]]
+         vi.CopyFrom(helper.make_tensor_value_info(node.output[0], t0.type.tensor_type.elem_type, new_shape))
+
+
+ register_aten_handler(AtenBitwiseOrHandler())
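
ctx.broadcast_shapes itself lives in context.py rather than in this hunk; the handler presumably applies the usual NumPy-style broadcasting rules, where shapes are aligned from the trailing dimension and size-1 dimensions stretch. For intuition, the NumPy equivalent:

    import numpy as np

    # Standard broadcasting: align trailing dims, size-1 dims expand.
    print(np.broadcast_shapes((3, 1, 5), (4, 1)))  # -> (3, 4, 5)
    print(np.broadcast_shapes((2, 1), (2, 6)))     # -> (2, 6)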
onnxslim/core/shape_inference/aten_ops/diagonal.py
@@ -0,0 +1,52 @@
+ # Copyright (c) Microsoft Corporation. All rights reserved.
+ # Licensed under the MIT License.
+
+ """Shape handler for ATen diagonal operator."""
+
+ import sympy
+ from onnx import helper
+
+ from ..base import ShapeHandler
+ from ..registry import register_aten_handler
+ from ..utils import get_shape_from_sympy_shape, handle_negative_axis
+
+
+ class AtenDiagonalHandler(ShapeHandler):
+     """Handler for ATen diagonal operator."""
+
+     @property
+     def op_type(self) -> str:
+         return "diagonal"
+
+     def infer_shape(self, node, ctx) -> None:
+         sympy_shape = ctx.get_sympy_shape(node, 0)
+         rank = len(sympy_shape)
+         offset = ctx.try_get_value(node, 1)
+         dim1 = ctx.try_get_value(node, 2)
+         dim2 = ctx.try_get_value(node, 3)
+
+         assert offset is not None and dim1 is not None and dim2 is not None
+         dim1 = handle_negative_axis(dim1, rank)
+         dim2 = handle_negative_axis(dim2, rank)
+
+         new_shape = [val for dim, val in enumerate(sympy_shape) if dim not in {dim1, dim2}]
+         shape1 = sympy_shape[dim1]
+         shape2 = sympy_shape[dim2]
+         if offset >= 0:
+             diag_shape = sympy.Max(0, sympy.Min(shape1, shape2 - offset))
+         else:
+             diag_shape = sympy.Max(0, sympy.Min(shape1 + offset, shape2))
+         new_shape.append(diag_shape)
+
+         if node.output[0]:
+             vi = ctx.known_vi_[node.output[0]]
+             vi.CopyFrom(
+                 helper.make_tensor_value_info(
+                     node.output[0],
+                     ctx.known_vi_[node.input[0]].type.tensor_type.elem_type,
+                     get_shape_from_sympy_shape(new_shape),
+                 )
+             )
+
+
+ register_aten_handler(AtenDiagonalHandler())
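
The Max/Min pair encodes the length of an offset diagonal and clamps it at zero when the offset runs off the matrix. A quick worked check with concrete numbers (invented for illustration), using the same sympy expressions as the handler:

    import sympy

    d1, d2 = 4, 5  # illustrative sizes of the two diagonal dims
    for offset in (1, -3, 7):
        if offset >= 0:
            diag = sympy.Max(0, sympy.Min(d1, d2 - offset))
        else:
            diag = sympy.Max(0, sympy.Min(d1 + offset, d2))
        print(offset, diag)  # 1 -> 4, -3 -> 1, 7 -> 0 (clamped by the outer Max)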
onnxslim/core/shape_inference/aten_ops/embedding.py
@@ -0,0 +1,23 @@
+ # Copyright (c) Microsoft Corporation. All rights reserved.
+ # Licensed under the MIT License.
+
+ """Shape handler for ATen embedding operator."""
+
+ from ..base import ShapeHandler
+ from ..registry import register_aten_handler
+ from ..standard_ops.tensor.gather import GatherHandler
+
+
+ class AtenEmbeddingHandler(ShapeHandler):
+     """Handler for ATen embedding operator (reuses Gather logic)."""
+
+     @property
+     def op_type(self) -> str:
+         return "embedding"
+
+     def infer_shape(self, node, ctx) -> None:
+         # Embedding uses the same logic as Gather
+         GatherHandler().infer_shape(node, ctx)
+
+
+ register_aten_handler(AtenEmbeddingHandler())
onnxslim/core/shape_inference/aten_ops/group_norm.py
@@ -0,0 +1,41 @@
+ # Copyright (c) Microsoft Corporation. All rights reserved.
+ # Licensed under the MIT License.
+
+ """Shape handler for ATen native_group_norm operator."""
+
+ from onnx import helper
+
+ from ..base import ShapeHandler
+ from ..registry import register_aten_handler
+ from ..utils import as_scalar
+
+
+ class AtenGroupNormHandler(ShapeHandler):
+     """Handler for ATen native_group_norm operator."""
+
+     @property
+     def op_type(self) -> str:
+         return "native_group_norm"
+
+     def infer_shape(self, node, ctx) -> None:
+         ctx.propagate_shape_and_type(node)
+         input_shape = ctx.get_shape(node, 0)
+         N = input_shape[0] if input_shape is not None and len(input_shape) != 0 else None
+         group = ctx.try_get_value(node, 6)
+         output_dtype = ctx.known_vi_[node.input[0]].type.tensor_type.elem_type
+         for i in {1, 2}:
+             if node.output[i]:
+                 vi = ctx.known_vi_[node.output[i]]
+                 vi.CopyFrom(
+                     helper.make_tensor_value_info(
+                         node.output[i],
+                         output_dtype,
+                         [
+                             (N if N is not None else str(ctx.new_symbolic_dim_from_output(node, i, 0))),
+                             (as_scalar(group) if group is not None else str(ctx.new_symbolic_dim_from_output(node, i, 1))),
+                         ],
+                     )
+                 )
+
+
+ register_aten_handler(AtenGroupNormHandler())
onnxslim/core/shape_inference/aten_ops/min_max.py
@@ -0,0 +1,64 @@
+ # Copyright (c) Microsoft Corporation. All rights reserved.
+ # Licensed under the MIT License.
+
+ """Shape handler for ATen min/max operators."""
+
+ import onnx
+ from onnx import helper
+
+ from ..base import ShapeHandler
+ from ..registry import register_aten_handler
+ from ..utils import get_shape_from_sympy_shape, handle_negative_axis
+
+
+ class AtenMinMaxHandler(ShapeHandler):
+     """Handler for ATen min/max operators."""
+
+     def __init__(self, op_name):
+         super().__init__()
+         self._op_type = op_name
+
+     @property
+     def op_type(self) -> str:
+         return self._op_type
+
+     def infer_shape(self, node, ctx) -> None:
+         vi = ctx.known_vi_[node.output[0]]
+         if len(node.input) == 1:
+             vi.CopyFrom(
+                 helper.make_tensor_value_info(
+                     node.output[0],
+                     ctx.known_vi_[node.input[0]].type.tensor_type.elem_type,
+                     [],
+                 )
+             )
+         else:
+             assert len(node.input) == 3
+             keepdim = ctx.try_get_value(node, 2)
+             assert keepdim is not None
+             dim = ctx.try_get_value(node, 1)
+             if dim is None:
+                 rank = ctx.get_shape_rank(node, 0)
+                 output_shape = ctx.new_symbolic_shape(rank if keepdim else rank - 1, node)
+             else:
+                 shape = ctx.get_sympy_shape(node, 0)
+                 dim = handle_negative_axis(dim, len(shape))
+                 output_shape = shape[:dim]
+                 if keepdim:
+                     output_shape += [1]
+                 output_shape += shape[dim + 1 :]
+
+             output_shape = get_shape_from_sympy_shape(output_shape)
+             vi.CopyFrom(
+                 helper.make_tensor_value_info(
+                     node.output[0],
+                     ctx.known_vi_[node.input[0]].type.tensor_type.elem_type,
+                     output_shape,
+                 )
+             )
+             vi1 = ctx.known_vi_[node.output[1]]
+             vi1.CopyFrom(helper.make_tensor_value_info(node.output[1], onnx.TensorProto.INT64, output_shape))
+
+
+ register_aten_handler(AtenMinMaxHandler("max"))
+ register_aten_handler(AtenMinMaxHandler("min"))
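
Unlike the single-op files above, this one registers two instances of a single parameterized class, one per operator name, so the reduction logic is written once: reducing dim=1 of a [N, 3, 7] input yields [N, 7], or [N, 1, 7] with keepdim. A hedged sketch of how the ATen dispatch in __init__.py would then resolve these names (assuming the registry keys handlers by op_type, which is what register_aten_handler appears to do):

    import onnxslim.core.shape_inference  # noqa: F401  triggers handler registration

    from onnxslim.core.shape_inference.registry import get_aten_handler

    # Each name maps to its own parameterized AtenMinMaxHandler instance.
    print(get_aten_handler("min").op_type)  # "min"
    print(get_aten_handler("max").op_type)  # "max"
    print(get_aten_handler("no_such_op"))   # None: the dispatch loop then leaves
                                            # known_aten_op False and falls back
                                            # to the unknown-op rank guessing path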