onnxslim-0.1.80-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (65)
  1. onnxslim/__init__.py +16 -0
  2. onnxslim/__main__.py +4 -0
  3. onnxslim/argparser.py +215 -0
  4. onnxslim/cli/__init__.py +1 -0
  5. onnxslim/cli/_main.py +180 -0
  6. onnxslim/core/__init__.py +219 -0
  7. onnxslim/core/optimization/__init__.py +146 -0
  8. onnxslim/core/optimization/dead_node_elimination.py +151 -0
  9. onnxslim/core/optimization/subexpression_elimination.py +76 -0
  10. onnxslim/core/optimization/weight_tying.py +59 -0
  11. onnxslim/core/pattern/__init__.py +249 -0
  12. onnxslim/core/pattern/elimination/__init__.py +5 -0
  13. onnxslim/core/pattern/elimination/concat.py +61 -0
  14. onnxslim/core/pattern/elimination/reshape.py +77 -0
  15. onnxslim/core/pattern/elimination/reshape_as.py +64 -0
  16. onnxslim/core/pattern/elimination/slice.py +108 -0
  17. onnxslim/core/pattern/elimination/unsqueeze.py +92 -0
  18. onnxslim/core/pattern/fusion/__init__.py +8 -0
  19. onnxslim/core/pattern/fusion/concat_reshape.py +50 -0
  20. onnxslim/core/pattern/fusion/convadd.py +70 -0
  21. onnxslim/core/pattern/fusion/convbn.py +86 -0
  22. onnxslim/core/pattern/fusion/convmul.py +69 -0
  23. onnxslim/core/pattern/fusion/gelu.py +47 -0
  24. onnxslim/core/pattern/fusion/gemm.py +330 -0
  25. onnxslim/core/pattern/fusion/padconv.py +89 -0
  26. onnxslim/core/pattern/fusion/reduce.py +67 -0
  27. onnxslim/core/pattern/registry.py +28 -0
  28. onnxslim/misc/__init__.py +0 -0
  29. onnxslim/misc/tabulate.py +2681 -0
  30. onnxslim/third_party/__init__.py +0 -0
  31. onnxslim/third_party/_sympy/__init__.py +0 -0
  32. onnxslim/third_party/_sympy/functions.py +205 -0
  33. onnxslim/third_party/_sympy/numbers.py +397 -0
  34. onnxslim/third_party/_sympy/printers.py +491 -0
  35. onnxslim/third_party/_sympy/solve.py +172 -0
  36. onnxslim/third_party/_sympy/symbol.py +102 -0
  37. onnxslim/third_party/onnx_graphsurgeon/__init__.py +15 -0
  38. onnxslim/third_party/onnx_graphsurgeon/exporters/__init__.py +1 -0
  39. onnxslim/third_party/onnx_graphsurgeon/exporters/base_exporter.py +33 -0
  40. onnxslim/third_party/onnx_graphsurgeon/exporters/onnx_exporter.py +432 -0
  41. onnxslim/third_party/onnx_graphsurgeon/graph_pattern/__init__.py +4 -0
  42. onnxslim/third_party/onnx_graphsurgeon/graph_pattern/graph_pattern.py +466 -0
  43. onnxslim/third_party/onnx_graphsurgeon/importers/__init__.py +1 -0
  44. onnxslim/third_party/onnx_graphsurgeon/importers/base_importer.py +33 -0
  45. onnxslim/third_party/onnx_graphsurgeon/importers/onnx_importer.py +558 -0
  46. onnxslim/third_party/onnx_graphsurgeon/ir/__init__.py +0 -0
  47. onnxslim/third_party/onnx_graphsurgeon/ir/function.py +274 -0
  48. onnxslim/third_party/onnx_graphsurgeon/ir/graph.py +1575 -0
  49. onnxslim/third_party/onnx_graphsurgeon/ir/node.py +266 -0
  50. onnxslim/third_party/onnx_graphsurgeon/ir/tensor.py +504 -0
  51. onnxslim/third_party/onnx_graphsurgeon/logger/__init__.py +1 -0
  52. onnxslim/third_party/onnx_graphsurgeon/logger/logger.py +261 -0
  53. onnxslim/third_party/onnx_graphsurgeon/util/__init__.py +0 -0
  54. onnxslim/third_party/onnx_graphsurgeon/util/exception.py +20 -0
  55. onnxslim/third_party/onnx_graphsurgeon/util/misc.py +252 -0
  56. onnxslim/third_party/symbolic_shape_infer.py +3273 -0
  57. onnxslim/utils.py +794 -0
  58. onnxslim/version.py +1 -0
  59. onnxslim-0.1.80.dist-info/METADATA +207 -0
  60. onnxslim-0.1.80.dist-info/RECORD +65 -0
  61. onnxslim-0.1.80.dist-info/WHEEL +5 -0
  62. onnxslim-0.1.80.dist-info/entry_points.txt +2 -0
  63. onnxslim-0.1.80.dist-info/licenses/LICENSE +21 -0
  64. onnxslim-0.1.80.dist-info/top_level.txt +1 -0
  65. onnxslim-0.1.80.dist-info/zip-safe +1 -0
onnxslim/third_party/symbolic_shape_infer.py
@@ -0,0 +1,3273 @@
+# Copyright (c) Microsoft Corporation. All rights reserved.
+# Licensed under the MIT License.
+
+# -*- coding: UTF-8 -*-
+import argparse
+import logging
+
+import numpy as np
+import onnx
+import sympy
+from onnx import helper, numpy_helper, shape_inference
+from packaging import version
+
+from onnxslim.third_party._sympy.functions import FloorDiv
+from onnxslim.third_party._sympy.printers import PythonPrinter as _PythonPrinter
+from onnxslim.third_party._sympy.solve import try_solve
+
+assert version.parse(onnx.__version__) >= version.parse("1.8.0")
+
+logger = logging.getLogger(__name__)
+
+
+class PythonPrinter(_PythonPrinter):
+    def doprint(self, expr: sympy.Expr, *, simplify: bool = True, p: bool = True) -> str:
+        # TODO: why are people passing strings to the printer here :think:
+        # if simplify and isinstance(expr, sympy.Expr) and hasattr(V.graph, "sizevars"):
+        #     expr = V.graph.sizevars.simplify(expr)
+        return super().doprint(expr)
+
+
+pexpr = PythonPrinter().doprint
+
+
+def get_attribute(node, attr_name, default_value=None):
+    """Retrieve the value of an attribute from an ONNX node, returning a default if the attribute is not found."""
+    found = [attr for attr in node.attribute if attr.name == attr_name]
+    return helper.get_attribute_value(found[0]) if found else default_value
+
+
+def get_dim_from_proto(dim):
+    """Retrieve the dimension value from the ONNX protobuf object if it is a string."""
+    return getattr(dim, dim.WhichOneof("value")) if type(dim.WhichOneof("value")) is str else None
+
+
+def is_sequence(type_proto):
+    """Check if the given ONNX proto type is a sequence."""
+    cls_type = type_proto.WhichOneof("value")
+    assert cls_type in {"tensor_type", "sequence_type"}
+    return cls_type == "sequence_type"
+
+
+def get_shape_from_type_proto(type_proto):
+    """Extract the shape of a tensor from an ONNX type proto if available, otherwise return None."""
+    assert not is_sequence(type_proto)
+    if type_proto.tensor_type.HasField("shape"):
+        return [get_dim_from_proto(d) for d in type_proto.tensor_type.shape.dim]
+    else:
+        return None  # note no shape is different from shape without dim (scalar)
+
+
+def get_elem_type_from_type_proto(type_proto):
+    """Return the element type from a given TypeProto object, either from sequence type or tensor type."""
+    if is_sequence(type_proto):
+        return type_proto.sequence_type.elem_type.tensor_type.elem_type
+    else:
+        return type_proto.tensor_type.elem_type
+
+
+def get_shape_from_value_info(vi):
+    """Return the shape from the given ValueInfoProto object, either from sequence type or tensor type."""
+    cls_type = vi.type.WhichOneof("value")
+    if cls_type is None:
+        return None
+    if not is_sequence(vi.type):
+        return get_shape_from_type_proto(vi.type)
+    if vi.type.sequence_type.elem_type.WhichOneof("value") == "tensor_type":
+        return get_shape_from_type_proto(vi.type.sequence_type.elem_type)
+    else:
+        return None
+
+
+def make_named_value_info(name):
+    """Create and return an ONNX ValueInfoProto object with the specified name."""
+    vi = onnx.ValueInfoProto()
+    vi.name = name
+    return vi
+
+
+def get_shape_from_sympy_shape(sympy_shape):
+    """Convert a sympy shape to a list with int, str, or None elements."""
+    return [None if i is None else (int(i) if is_literal(i) else str(i)) for i in sympy_shape]
+
+
+def is_literal(dim):
+    """Check if a dimension is a literal number (int, np.int64, np.int32, sympy.Integer) or has an 'is_number'
+    attribute.
+    """
+    return type(dim) in {int, np.int64, np.int32, sympy.Integer} or (hasattr(dim, "is_number") and dim.is_number)
+
+
+def handle_negative_axis(axis, rank):
+    """Convert a potentially negative axis to a positive axis based on the given rank."""
+    assert axis < rank and axis >= -rank
+    return axis if axis >= 0 else rank + axis
+
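+# Example: handle_negative_axis(-1, rank=4) -> 3, handle_negative_axis(2, rank=4) -> 2.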
+
+def get_opset(mp, domain=None):
+    """Retrieve the opset version for a given model namespace, defaulting to common ONNX domains if no specific domain
+    is provided.
+    """
+    domain = domain or ["", "onnx", "ai.onnx"]
+    if type(domain) != list:  # noqa: E721
+        domain = [domain]
+    for opset in mp.opset_import:
+        if opset.domain in domain:
+            return opset.version
+
+    return None
+
+
+def as_scalar(x):
+    """Convert input to scalar if input is a list with a single item or a NumPy ndarray."""
+    if type(x) == list:  # noqa: E721
+        assert len(x) == 1
+        return x[0]
+    elif type(x) == np.ndarray:
+        return x.item()
+    else:
+        return x
+
+
+def as_list(x, keep_none):
+    """Convert input to list, optionally preserving None values."""
+    if type(x) == list:  # noqa: E721
+        return x
+    elif type(x) == np.ndarray:
+        return list(x)
+    elif keep_none and x is None:
+        return None
+    else:
+        return [x]
+
+
+def sympy_reduce_product(x):
+    """Reduce a list or element to a product using Sympy's Integer."""
+    if type(x) == list:  # noqa: E721
+        value = sympy.Integer(1)
+        for v in x:
+            value = value * v
+    else:
+        value = x
+    return value
+
+
+class SymbolicShapeInference:
+    def __init__(self, int_max, auto_merge, guess_output_rank, verbose, prefix=""):
+        """Initializes the SymbolicShapeInference class with configuration parameters for symbolic shape inference."""
+        self.dispatcher_ = {
+            "Add": self._infer_symbolic_compute_ops,
+            "ArrayFeatureExtractor": self._infer_ArrayFeatureExtractor,
+            "AveragePool": self._infer_Pool,
+            "BatchNormalization": self._infer_BatchNormalization,
+            "Cast": self._infer_Cast,
+            "CategoryMapper": self._infer_CategoryMapper,
+            "Compress": self._infer_Compress,
+            "Concat": self._infer_Concat,
+            "ConcatFromSequence": self._infer_ConcatFromSequence,
+            "Constant": self._infer_Constant,
+            "ConstantOfShape": self._infer_ConstantOfShape,
+            "Conv": self._infer_Conv,
+            "CumSum": self._pass_on_shape_and_type,
+            "Div": self._infer_symbolic_compute_ops,
+            "Einsum": self._infer_Einsum,
+            "Expand": self._infer_Expand,
+            "Equal": self._infer_symbolic_compute_ops,
+            "Floor": self._infer_symbolic_compute_ops,
+            "Gather": self._infer_Gather,
+            "GatherElements": self._infer_GatherElements,
+            "GatherND": self._infer_GatherND,
+            "Identity": self._pass_on_shape_and_type,
+            "AllReduce": self._pass_on_shape_and_type,
+            "If": self._infer_If,
+            "Loop": self._infer_Loop,
+            "MatMul": self._infer_MatMul,
+            "MatMulInteger16": self._infer_MatMulInteger,
+            "MaxPool": self._infer_Pool,
+            "Max": self._infer_symbolic_compute_ops,
+            "MemcpyFromHost": self._pass_on_shape_and_type,
+            "MemcpyToHost": self._pass_on_shape_and_type,
+            "Min": self._infer_symbolic_compute_ops,
+            "MoE": self._pass_on_shape_and_type,
+            "Mul": self._infer_symbolic_compute_ops,
+            "NonMaxSuppression": self._infer_NonMaxSuppression,
+            "NonZero": self._infer_NonZero,
+            "OneHot": self._infer_OneHot,
+            "Pad": self._infer_Pad,
+            "Range": self._infer_Range,
+            "Reciprocal": self._pass_on_shape_and_type,
+            "ReduceSum": self._infer_ReduceSum,
+            "ReduceProd": self._infer_ReduceProd,
+            "Reshape": self._infer_Reshape,
+            "Resize": self._infer_Resize,
+            "Round": self._pass_on_shape_and_type,
+            "Scan": self._infer_Scan,
+            "ScatterElements": self._infer_ScatterElements,
+            "SequenceAt": self._infer_SequenceAt,
+            "SequenceInsert": self._infer_SequenceInsert,
+            "Shape": self._infer_Shape,
+            "Size": self._infer_Size,
+            "Slice": self._infer_Slice,
+            "SoftmaxCrossEntropyLoss": self._infer_SoftmaxCrossEntropyLoss,
+            "SoftmaxCrossEntropyLossInternal": self._infer_SoftmaxCrossEntropyLoss,
+            "NegativeLogLikelihoodLossInternal": self._infer_SoftmaxCrossEntropyLoss,
+            "Split": self._infer_Split,
+            "SplitToSequence": self._infer_SplitToSequence,
+            "Squeeze": self._infer_Squeeze,
+            "Sub": self._infer_symbolic_compute_ops,
+            "Tile": self._infer_Tile,
+            "TopK": self._infer_TopK,
+            "Transpose": self._infer_Transpose,
+            "Unsqueeze": self._infer_Unsqueeze,
+            "Where": self._infer_symbolic_compute_ops,
+            "ZipMap": self._infer_ZipMap,
+            "Neg": self._infer_symbolic_compute_ops,
+            # contrib ops:
+            "Attention": self._infer_Attention,
+            "BiasAdd": self._infer_BiasAdd,
+            "BiasGelu": self._infer_BiasGelu,
+            "BiasSplitGelu": self._infer_BiasSplitGelu,
+            "DecoderMaskedMultiHeadAttention": self._infer_DecoderMaskedMultiHeadAttention,
+            "DequantizeLinear": self._infer_DequantizeLinear,
+            "EmbedLayerNormalization": self._infer_EmbedLayerNormalization,
+            "FastGelu": self._infer_FastGelu,
+            "GatedRelativePositionBias": self._infer_GatedRelativePositionBias,
+            "Gelu": self._infer_Gelu,
+            "GemmFastGelu": self._infer_GemmFastGelu,
+            "GemmFloat8": self._infer_GemmFloat8,
+            "GroupNorm": self._infer_GroupNorm,
+            "SkipGroupNorm": self._infer_SkipGroupNorm,
+            "LayerNormalization": self._infer_LayerNormalization,
+            "LongformerAttention": self._infer_LongformerAttention,
+            "MultiHeadAttention": self._infer_MultiHeadAttention,
+            "NhwcConv": self._infer_NhwcConv,
+            "PackedAttention": self._infer_PackedAttention,
+            "PackedMultiHeadAttention": self._infer_PackedMultiHeadAttention,
+            "MultiScaleDeformableAttnTRT": self._infer_MultiScaleDeformableAttnTRT,
+            "PythonOp": self._infer_PythonOp,
+            "QuantizeLinear": self._infer_QuantizeLinear,
+            "QuickGelu": self._infer_FastGelu,
+            "RelativePositionBias": self._infer_RelativePositionBias,
+            "RemovePadding": self._infer_RemovePadding,
+            "RestorePadding": self._infer_RestorePadding,
+            "RotaryEmbedding": self._infer_RotaryEmbedding,
+            "SimplifiedLayerNormalization": self._infer_LayerNormalization,
+            "SkipLayerNormalization": self._infer_SkipLayerNormalization,
+            "SkipSimplifiedLayerNormalization": self._infer_SkipLayerNormalization,
+        }
+        self.aten_op_dispatcher_ = {
+            "embedding": self._infer_Gather,
+            "bitwise_or": self._infer_aten_bitwise_or,
+            "diagonal": self._infer_aten_diagonal,
+            "max_pool2d_with_indices": self._infer_aten_pool2d,
+            "max": self._infer_aten_minmax,
+            "min": self._infer_aten_minmax,
+            "multinomial": self._infer_aten_multinomial,
+            "unfold": self._infer_aten_unfold,
+            "argmax": self._infer_aten_argmax,
+            "avg_pool2d": self._infer_aten_pool2d,
+            "_adaptive_avg_pool2d": self._infer_aten_pool2d,
+            "numpy_T": self._infer_Transpose,
+            "native_group_norm": self._infer_aten_group_norm,
+            "upsample_nearest1d": self._infer_aten_upsample,
+            "upsample_nearest2d": self._infer_aten_upsample,
+            "upsample_nearest3d": self._infer_aten_upsample,
+            "upsample_bicubic2d": self._infer_aten_upsample,
+        }
+        self.run_ = True
+        self.suggested_merge_ = {}
+        self.symbolic_dims_ = {}
+        self.input_symbols_ = {}
+        self.auto_merge_ = auto_merge
+        self.guess_output_rank_ = guess_output_rank
+        self.verbose_ = verbose
+        self.int_max_ = int_max
+        self.subgraph_id_ = 0
+        self.prefix_ = prefix
+
+    def _add_suggested_merge(self, symbols, apply=False):
+        """Add suggested merges for input symbols, prioritizing literals, input symbolic dims, or existing symbolic
+        dims.
+        """
+        assert all((type(s) == str and s in self.symbolic_dims_) or is_literal(s) for s in symbols)
+        symbols = set(symbols)
+        for k, v in self.suggested_merge_.items():
+            if k in symbols:
+                symbols.remove(k)
+                symbols.add(v)
+        map_to = None
+        # if there is literal, map to it first
+        for s in symbols:
+            if is_literal(s):
+                map_to = s
+                break
+        # when no literals, map to input symbolic dims, then existing symbolic dims
+        if map_to is None:
+            for s in symbols:
+                if s in self.input_symbols_:
+                    map_to = s
+                    break
+        if map_to is None:
+            for s in symbols:
+                if type(self.symbolic_dims_[s]) == sympy.Symbol:
+                    map_to = s
+                    break
+        # when nothing to map to, use the shorter one
+        if map_to is None:
+            if self.verbose_ > 0:
+                logger.warning(f"Potential unsafe merge between symbolic expressions: ({','.join(symbols)})")
+            symbols_list = list(symbols)
+            lens = [len(s) for s in symbols_list]
+            map_to = symbols_list[lens.index(min(lens))]
+            symbols.remove(map_to)
+
+        for s in symbols:
+            if s == map_to:
+                continue
+            if is_literal(map_to) and is_literal(s):
+                assert int(map_to) == int(s)
+            self.suggested_merge_[s] = int(map_to) if is_literal(map_to) else map_to
+            for k, v in self.suggested_merge_.items():
+                if v == s:
+                    self.suggested_merge_[k] = map_to
+        if apply and self.auto_merge_:
+            self._apply_suggested_merge()
+
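+    # Illustration (hypothetical values): for symbols {"N", "batch", 4} the merge target is
+    # picked literal-first, so "N" and "batch" would both be suggested to merge to 4; with
+    # only {"N", "batch"}, a graph-input symbol wins over other existing symbolic dims.
+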
+    def _apply_suggested_merge(self, graph_input_only=False):
+        """Applies suggested merges to graph dimensions based on predefined rules in the `suggested_merge_`
+        dictionary.
+        """
+        if not self.suggested_merge_:
+            return
+        for i in list(self.out_mp_.graph.input) + ([] if graph_input_only else list(self.out_mp_.graph.value_info)):
+            for d in i.type.tensor_type.shape.dim:
+                if d.dim_param in self.suggested_merge_:
+                    v = self.suggested_merge_[d.dim_param]
+                    if is_literal(v):
+                        d.dim_value = int(v)
+                    else:
+                        d.dim_param = v
+
+    def _preprocess(self, in_mp):
+        """Cache the model along with its graph inputs, initializers, and known value infos before inference."""
+        self.out_mp_ = in_mp
+        self.graph_inputs_ = {i.name: i for i in list(self.out_mp_.graph.input)}
+        self.initializers_ = {i.name: i for i in self.out_mp_.graph.initializer}
+        self.known_vi_ = {i.name: i for i in list(self.out_mp_.graph.input)}
+        self.known_vi_.update(
+            {
+                i.name: helper.make_tensor_value_info(i.name, i.data_type, list(i.dims))
+                for i in self.out_mp_.graph.initializer
+            }
+        )
+        self.known_vi_.update({i.name: i for i in list(self.out_mp_.graph.output)})
+
+    def _merge_symbols(self, dims):
+        """Merge dimension symbols, handling automatic merging and validation of symbolic dimensions."""
+        if any(type(d) != str for d in dims):  # noqa: E721
+            if not self.auto_merge_:
+                return None
+            unique_dims = list(set(dims))
+            is_int = [is_literal(d) for d in unique_dims]
+            assert sum(is_int) <= 1  # if there are more than 1 unique ints, something is wrong
+            if sum(is_int) == 1:
+                int_dim = is_int.index(1)
+                if self.verbose_ > 0:
+                    logger.debug(
+                        f"dim {unique_dims[:int_dim] + unique_dims[int_dim + 1 :]} has been merged with value {unique_dims[int_dim]}"
+                    )
+                self._check_merged_dims(unique_dims, allow_broadcast=False)
+                return unique_dims[int_dim]
+            else:
+                if self.verbose_ > 0:
+                    logger.debug(f"dim {unique_dims[1:]} has been merged with dim {unique_dims[0]}")
+                return dims[0]
+        if all(d == dims[0] for d in dims):
+            return dims[0]
+        merged = [self.suggested_merge_[d] if d in self.suggested_merge_ else d for d in dims]
+        if all(d == merged[0] for d in merged):
+            assert merged[0] in self.symbolic_dims_
+            return merged[0]
+        else:
+            return None
+
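+    # Example (hypothetical dims): _merge_symbols(["N", "N"]) -> "N"; _merge_symbols(["N", 8])
+    # -> 8 when auto_merge_ is enabled (the symbol merges with the literal), else None.
+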
+    # broadcast from right to left, and merge symbolic dims if needed
+    def _broadcast_shapes(self, shape1, shape2):
+        """Broadcast two shapes from right to left, merging symbolic dimensions if necessary."""
+        new_shape = []
+        rank1 = len(shape1)
+        rank2 = len(shape2)
+        new_rank = max(rank1, rank2)
+        for i in range(new_rank):
+            dim1 = shape1[rank1 - 1 - i] if i < rank1 else 1
+            dim2 = shape2[rank2 - 1 - i] if i < rank2 else 1
+            if dim1 in [1, dim2]:
+                new_dim = dim2
+            elif dim2 == 1:
+                new_dim = dim1
+            else:
+                new_dim = self._merge_symbols([dim1, dim2])
+                if not new_dim:
+                    # warning about unsupported broadcast when not auto merge
+                    # note that auto merge has the risk of incorrectly merging symbols when one of them is 1
+                    # for example, 'a' = 1, 'b' = 5 at runtime is valid broadcasting, but with auto merge 'a' == 'b'
+                    if self.auto_merge_:
+                        self._add_suggested_merge([dim1, dim2], apply=True)
+                    else:
+                        logger.warning(f"unsupported broadcast between {dim1!s} {dim2!s}")
+            new_shape = [new_dim, *new_shape]
+        return new_shape
+
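+    # Worked example (hypothetical dims): _broadcast_shapes(["b", 3, 1], [3, 5]) aligns from
+    # the right: 1 vs 5 -> 5, 3 vs 3 -> 3, "b" vs implicit 1 -> "b", giving ["b", 3, 5].
+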
+    def _get_shape(self, node, idx):
+        """Retrieve the shape of a tensor from a node's inputs based on known value info or initializers."""
+        name = node.input[idx]
+        if name in self.known_vi_:
+            vi = self.known_vi_[name]
+            return get_shape_from_value_info(vi)
+        else:
+            assert name in self.initializers_
+            return list(self.initializers_[name].dims)
+
+    def _try_get_shape(self, node, idx):
+        """Attempts to retrieve the shape of the input node at the specified index if available."""
+        if idx > len(node.input) - 1:
+            return None
+        name = node.input[idx]
+        if name in self.known_vi_:
+            vi = self.known_vi_[name]
+            return get_shape_from_value_info(vi)
+        if name in self.initializers_:
+            return list(self.initializers_[name].dims)
+        return None
+
+    def _get_shape_rank(self, node, idx):
+        """Return the rank (number of dimensions) of the shape of the input tensor at the specified index for a given
+        node.
+        """
+        return len(self._get_shape(node, idx))
+
+    def _get_sympy_shape(self, node, idx):
+        """Return the symbolic shape dimensions using SymPy for the given input tensor at the specified index for a
+        node.
+        """
+        sympy_shape = []
+        for d in self._get_shape(node, idx):
+            if type(d) == str:  # noqa: E721
+                sympy_shape.append(
+                    self.symbolic_dims_[d]
+                    if d in self.symbolic_dims_
+                    else sympy.Symbol(d, integer=True, nonnegative=True)
+                )
+            else:
+                assert None is not d
+                sympy_shape.append(d)
+        return sympy_shape
+
+    def _get_value(self, node, idx):
+        """Retrieve the value associated with a node's input index from sympy_data_ or initializers_."""
+        name = node.input[idx]
+        assert name in self.sympy_data_ or name in self.initializers_
+        return self.sympy_data_[name] if name in self.sympy_data_ else numpy_helper.to_array(self.initializers_[name])
+
+    def _try_get_value(self, node, idx):
+        """Try to retrieve the value associated with a node's input index from sympy_data_ or initializers_."""
+        if idx >= len(node.input):
+            return None
+        name = node.input[idx]
+        if name in self.sympy_data_ or name in self.initializers_:
+            return self._get_value(node, idx)
+        return None
+
+    def _update_computed_dims(self, new_sympy_shape):
+        """Update dimensions in new_sympy_shape based on suggested merges and computational expressions."""
+        for i, new_dim in enumerate(new_sympy_shape):
+            if not is_literal(new_dim) and type(new_dim) != str:  # noqa: E721
+                str_dim = pexpr(new_dim)
+                if str_dim in self.suggested_merge_:
+                    if not is_literal(self.suggested_merge_[str_dim]):
+                        new_sympy_shape[i] = self.symbolic_dims_[self.suggested_merge_[str_dim]]
+                elif str_dim not in self.symbolic_dims_:
+                    self.symbolic_dims_[str_dim] = new_dim
+
+    def _onnx_infer_single_node(self, node):
+        """Performs ONNX shape inference for a single node, skipping inference for specified operation types."""
+        skip_infer = node.op_type in {
+            "If",
+            "Loop",
+            "Scan",
+            "SplitToSequence",
+            "ZipMap",  # contrib ops
+            "Attention",
+            "BiasGelu",
+            "EmbedLayerNormalization",
+            "FastGelu",
+            "Gelu",
+            "GemmFastGelu",
+            "LayerNormalization",
+            "LongformerAttention",
+            "DequantizeLinear",
+            "QuantizeLinear",
+            "RelativePositionBias",
+            "RemovePadding",
+            "RestorePadding",
+            "SimplifiedLayerNormalization",
+            "SkipLayerNormalization",
+            "SkipSimplifiedLayerNormalization",
+            "PackedAttention",
+            "PythonOp",
+            "MultiHeadAttention",
+            "GroupNorm",
+            "SkipGroupNorm",
+            "BiasSplitGelu",
+            "BiasAdd",
+            "NhwcConv",
+            "QuickGelu",
+            "RotaryEmbedding",
+        }
+
+        if not skip_infer:
+            # Only pass initializers that satisfy the following conditions:
+            # (1) The operator needs the value of some input for shape inference.
+            #     For example, Unsqueeze in opset 13 uses the axes input to calculate the output shape.
+            # (2) opset version >= 9. In older versions, the onnx spec required initializers to be in graph inputs.
+            # (3) The initializer is not in graph input. That means the node input is "constant" in inference.
+            initializers = []
+            if (get_opset(self.out_mp_) >= 9) and (
+                node.op_type == "Unsqueeze" or node.op_type == "ReduceMax" or node.op_type == "ReduceMean"
+                or node.op_type == "DFT" or node.op_type == "ReduceL2" or node.op_type == "ReduceMin"
+            ):
+                initializers = [
+                    self.initializers_[name]
+                    for name in node.input
+                    if (name in self.initializers_ and name not in self.graph_inputs_)
+                ]
+
+            if (
+                node.op_type
+                in {
+                    "Add",
+                    "Sub",
+                    "Mul",
+                    "Div",
+                    "MatMul",
+                    "MatMulInteger",
+                    "MatMulInteger16",
+                    "Where",
+                    "Sum",
+                }
+                and node.output[0] in self.known_vi_
+            ):
+                vi = self.known_vi_[node.output[0]]
+                out_rank = len(get_shape_from_type_proto(vi.type))
+                in_shapes = [self._get_shape(node, i) for i in range(len(node.input))]
+                for d in range(out_rank - (2 if node.op_type in {"MatMul", "MatMulInteger", "MatMulInteger16"} else 0)):
+                    in_dims = [s[len(s) - out_rank + d] for s in in_shapes if len(s) + d >= out_rank]
+                    if len(in_dims) > 1:
+                        self._check_merged_dims(in_dims, allow_broadcast=True)
+
+            # run single node inference with self.known_vi_ shapes
+            tmp_graph = helper.make_graph(
+                [node],
+                "tmp",
+                [self.known_vi_[i] for i in node.input if i],
+                [make_named_value_info(i) for i in node.output],
+                initializers,
+            )
+
+            kwargs = {}
+            kwargs["opset_imports"] = self.out_mp_.opset_import
+            kwargs["ir_version"] = self.out_mp_.ir_version
+
+            model = helper.make_model(tmp_graph, **kwargs)
+            model = shape_inference.infer_shapes(model)
+
+        for i_o in range(len(node.output)):
+            o = node.output[i_o]
+            if o:  # skip optional output
+                if not skip_infer:  # guard: `model` exists only when single-node inference ran above
+                    out = model.graph.output[i_o]
+                    if not out.type.WhichOneof("value") and o in self.known_vi_:  # if empty and already had
+                        continue
+
+                vi = self.out_mp_.graph.value_info.add()
+                if not skip_infer:
+                    vi.CopyFrom(out)
+                else:
+                    vi.name = o
+                self.known_vi_[o] = vi
+
+    def _onnx_infer_subgraph(self, node, subgraph, use_node_input=True, inc_subgraph_id=True):
+        """Infer shapes and types within a subgraph for a given ONNX node using temporary graphs and known value
+        information.
+        """
+        if self.verbose_ > 2:
+            logger.debug(f"Inferencing subgraph of node {node.name} with output({node.output[0]}...): {node.op_type}")
+        # node inputs are not passed directly to the subgraph
+        # it's up to the node dispatcher to prepare subgraph input
+        # for example, with Scan/Loop, subgraph input shape would be trimmed from node input shape
+        # besides, inputs in subgraph could shadow implicit inputs
+        subgraph_inputs = {i.name for i in list(subgraph.initializer) + list(subgraph.input)}
+        subgraph_implicit_input = {name for name in self.known_vi_ if name not in subgraph_inputs}
+        tmp_graph = helper.make_graph(
+            list(subgraph.node),
+            "tmp",
+            list(subgraph.input) + [self.known_vi_[i] for i in subgraph_implicit_input],
+            [make_named_value_info(i.name) for i in subgraph.output],
+        )
+        tmp_graph.initializer.extend([i for i in self.out_mp_.graph.initializer if i.name in subgraph_implicit_input])
+        tmp_graph.initializer.extend(subgraph.initializer)
+        kwargs = {}
+        kwargs["opset_imports"] = self.out_mp_.opset_import
+        kwargs["ir_version"] = self.out_mp_.ir_version
+
+        model = helper.make_model(tmp_graph, **kwargs)
+
+        symbolic_shape_inference = SymbolicShapeInference(
+            self.int_max_,
+            self.auto_merge_,
+            self.guess_output_rank_,
+            self.verbose_,
+            prefix=f"{self.prefix_}_{self.subgraph_id_!s}",
+        )
+        if inc_subgraph_id:
+            self.subgraph_id_ += 1
+
+        symbolic_shape_inference._preprocess(model)
+        symbolic_shape_inference.suggested_merge_ = self.suggested_merge_.copy()
+        while symbolic_shape_inference.run_:
+            symbolic_shape_inference._infer_impl(self.sympy_data_.copy())
+        symbolic_shape_inference._update_output_from_vi()
+        if use_node_input:
+            # if subgraph uses node input, it needs to update to merged dims
+            subgraph.ClearField("input")
+            subgraph.input.extend(symbolic_shape_inference.out_mp_.graph.input[: len(node.input)])
+        subgraph.ClearField("output")
+        subgraph.output.extend(symbolic_shape_inference.out_mp_.graph.output)
+        subgraph.ClearField("value_info")
+        subgraph.value_info.extend(symbolic_shape_inference.out_mp_.graph.value_info)
+        subgraph.ClearField("node")
+        subgraph.node.extend(symbolic_shape_inference.out_mp_.graph.node)
+        # for new symbolic dims from subgraph output, add to main graph symbolic dims
+        subgraph_shapes = [get_shape_from_value_info(o) for o in symbolic_shape_inference.out_mp_.graph.output]
+        subgraph_new_symbolic_dims = {
+            d
+            for s in subgraph_shapes
+            if s
+            for d in s
+            if type(d) == str and d not in self.symbolic_dims_  # noqa: E721
+        }
+        new_dims = {}
+        for d in subgraph_new_symbolic_dims:
+            assert d in symbolic_shape_inference.symbolic_dims_
+            new_dims[d] = symbolic_shape_inference.symbolic_dims_[d]
+        self.symbolic_dims_.update(new_dims)
+        return symbolic_shape_inference
+
+    def _get_int_or_float_values(self, node, broadcast=False, allow_float_values=False):
+        """Extracts integer or float values from a node, with options for broadcasting and allowing float values."""
+
+        def int_or_float(value, allow_float_values):
+            """Keep the value as a float when allow_float_values is set and precision would be lost; otherwise cast to int."""
+            return value if allow_float_values and value % 1 != 0 else int(value)
+
+        values = [self._try_get_value(node, i) for i in range(len(node.input))]
+        if all(v is not None for v in values):
+            # some shape compute is in floating point, cast to int for sympy
+            for i, v in enumerate(values):
+                if type(v) != np.ndarray:
+                    continue
+                if len(v.shape) > 1:
+                    new_v = None  # ignore value for rank > 1
+                elif len(v.shape) == 0:
+                    new_v = int_or_float(v.item(), allow_float_values)
+                else:
+                    assert len(v.shape) == 1
+                    new_v = [int_or_float(vv, allow_float_values) for vv in v]
+                values[i] = new_v
+        values_len = [len(v) if isinstance(v, list) else 0 for v in values]
+        max_len = max(values_len)
+        if max_len >= 1 and broadcast:
+            # broadcast
+            for i, v in enumerate(values):
+                if v is None:
+                    continue  # don't broadcast if value is unknown
+                if isinstance(v, list):
+                    if len(v) < max_len:
+                        values[i] = v * max_len
+                    else:
+                        assert len(v) == max_len
+                else:
+                    values[i] = [v] * max_len
+        return values
+
+    def _compute_on_sympy_data(self, node, op_func):
+        """Calculate the result using Sympy data and a specified operation function."""
+        assert len(node.output) == 1
+
+        # Before mul & div operations,
+        # casting inputs into integers might lose the decimal part and reduce precision,
+        # so keep them as floats, finish the operation, then cast the result into an integer.
+        if node.op_type in {"Mul", "Div"}:
+            values = self._get_int_or_float_values(node, broadcast=True, allow_float_values=True)
+        else:
+            values = self._get_int_or_float_values(node, broadcast=True)
+        if all(v is not None for v in values):
+            is_list = [isinstance(v, list) for v in values]
+            as_list = any(is_list)
+            if as_list:
+                self.sympy_data_[node.output[0]] = [op_func(vs) for vs in zip(*values)]
+            else:
+                self.sympy_data_[node.output[0]] = op_func(values)
+
+    def _pass_on_sympy_data(self, node):
+        """Pass Sympy data through a node, validating a single input or a node op type of 'Reshape', 'Unsqueeze', or
+        'Squeeze'.
+        """
+        assert len(node.input) == 1 or node.op_type in {
+            "Reshape",
+            "Unsqueeze",
+            "Squeeze",
+        }
+        self._compute_on_sympy_data(node, lambda x: x[0])
+
+    def _pass_on_shape_and_type(self, node):
+        """Propagates the shape and type information from input to output for a given node."""
+        vi = self.known_vi_[node.output[0]]
+        vi.CopyFrom(
+            helper.make_tensor_value_info(
+                node.output[0],
+                get_elem_type_from_type_proto(self.known_vi_[node.input[0]].type),
+                self._get_shape(node, 0),
+            )
+        )
+
+    def _new_symbolic_dim(self, prefix, dim):
+        """Create and return a new symbolic dimension, handling literal values and caching for repeated uses."""
+        new_dim = f"{prefix}_d{dim}"
+        if new_dim in self.suggested_merge_:
+            v = self.suggested_merge_[new_dim]
+            new_symbolic_dim = sympy.Integer(int(v)) if is_literal(v) else v
+        else:
+            new_symbolic_dim = sympy.Symbol(new_dim, integer=True, nonnegative=True)
+            self.symbolic_dims_[new_dim] = new_symbolic_dim
+        return new_symbolic_dim
+
+    def _new_symbolic_dim_from_output(self, node, out_idx=0, dim=0):
+        """Generates a new symbolic dimension for a given node's output using the node's operation type, prefix, and
+        output index.
+        """
+        return self._new_symbolic_dim(
+            f"{node.op_type}{self.prefix_}_{list(self.out_mp_.graph.node).index(node)}_o{out_idx}_",
+            dim,
+        )
+
+    def _new_symbolic_shape(self, rank, node, out_idx=0):
+        """Generate a new symbolic shape for a node output based on its rank and index."""
+        return [self._new_symbolic_dim_from_output(node, out_idx, i) for i in range(rank)]
+
+    def _compute_conv_pool_shape(self, node, channels_last=False):
+        """Calculate the output shape of a convolutional or pooling layer node, optionally considering channels_last
+        format.
+        """
+        sympy_shape = self._get_sympy_shape(node, 0)
+        if len(node.input) > 1:
+            W_shape = self._get_sympy_shape(node, 1)
+            rank = len(W_shape) - 2  # number of spatial axes
+            kernel_shape = W_shape[-rank - 1 : -1] if channels_last else W_shape[-rank:]
+            sympy_shape[3 if channels_last else 1] = W_shape[0]
+        else:
+            W_shape = None
+            kernel_shape = get_attribute(node, "kernel_shape")
+            rank = len(kernel_shape)
+
+        assert len(sympy_shape) == rank + 2
+
+        # symbolic shape inference is only needed if the input has symbolic dims in spatial axes
+        spatial_shape = sympy_shape[-rank - 1 : -1] if channels_last else sympy_shape[-rank:]
+        is_symbolic_dims = [not is_literal(i) for i in spatial_shape]
+
+        if not any(is_symbolic_dims):
+            shape = get_shape_from_value_info(self.known_vi_[node.output[0]])
+            if len(shape) > 0:
+                assert len(sympy_shape) == len(shape)
+                if channels_last:
+                    sympy_shape[-rank - 1 : -1] = [sympy.Integer(d) for d in shape[-rank - 1 : -1]]
+                else:
+                    sympy_shape[-rank:] = [sympy.Integer(d) for d in shape[-rank:]]
+            return sympy_shape
+
+        dilations = get_attribute(node, "dilations", [1] * rank)
+        strides = get_attribute(node, "strides", [1] * rank)
+        effective_kernel_shape = [(k - 1) * d + 1 for k, d in zip(kernel_shape, dilations)]
+        pads = get_attribute(node, "pads")
+        if pads is None:
+            pads = [0] * (2 * rank)
+            auto_pad = get_attribute(node, "auto_pad", b"NOTSET").decode("utf-8")
+            if auto_pad not in {"VALID", "NOTSET"}:
+                try:
+                    residual = [sympy.Mod(d, s) for d, s in zip(sympy_shape[-rank:], strides)]
+                    total_pads = [
+                        max(0, (k - s) if r == 0 else (k - r))
+                        for k, s, r in zip(effective_kernel_shape, strides, residual)
+                    ]
+                except TypeError:  # sympy may throw TypeError: cannot determine truth value of Relational
+                    total_pads = [
+                        max(0, (k - s)) for k, s in zip(effective_kernel_shape, strides)
+                    ]  # assuming no residual if sympy throws error
+            elif auto_pad == "VALID":
+                total_pads = []
+            else:
+                total_pads = [0] * rank
+        else:
+            assert len(pads) == 2 * rank
+            total_pads = [p1 + p2 for p1, p2 in zip(pads[:rank], pads[rank:])]
+
+        ceil_mode = get_attribute(node, "ceil_mode", 0)
+        for i in range(rank):
+            effective_input_size = sympy_shape[-rank + i + (-1 if channels_last else 0)]
+            if len(total_pads) > 0:
+                effective_input_size = effective_input_size + total_pads[i]
+            if ceil_mode:
+                strided_kernel_positions = sympy.ceiling(
+                    (effective_input_size - effective_kernel_shape[i]) / strides[i]
+                )
+            else:
+                strided_kernel_positions = FloorDiv((effective_input_size - effective_kernel_shape[i]), strides[i])
+            sympy_shape[-rank + i + (-1 if channels_last else 0)] = strided_kernel_positions + 1
+        return sympy_shape
+
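+    # Worked example (hypothetical numbers): a 64x64 spatial input with kernel 3, stride 2,
+    # dilation 1, and pads summing to 2 per axis gives floor((64 + 2 - 3) / 2) + 1 = 32 per
+    # spatial dim; symbolic dims flow through the same FloorDiv expression unevaluated.
+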
+    def _check_merged_dims(self, dims, allow_broadcast=True):
+        """Checks merged dimensions for consistency, optionally allowing broadcasting."""
+        if allow_broadcast:
+            dims = [d for d in dims if not (is_literal(d) and int(d) <= 1)]
+        if any(d != dims[0] for d in dims):
+            self._add_suggested_merge(dims, apply=True)
+
+    def _compute_matmul_shape(self, node, output_dtype=None):
+        """Compute the output shape for a matrix multiplication operation based on input shapes and optionally infer
+        the output data type.
+        """
+        lhs_shape = self._get_shape(node, 0)
+        rhs_shape = self._get_shape(node, 1)
+        lhs_rank = len(lhs_shape)
+        rhs_rank = len(rhs_shape)
+        lhs_reduce_dim = 0
+        rhs_reduce_dim = 0
+        assert lhs_rank > 0 and rhs_rank > 0
+        if lhs_rank == 1 and rhs_rank == 1:
+            new_shape = []
+        elif lhs_rank == 1:
+            rhs_reduce_dim = -2
+            new_shape = [*rhs_shape[:rhs_reduce_dim], rhs_shape[-1]]
+        elif rhs_rank == 1:
+            lhs_reduce_dim = -1
+            new_shape = lhs_shape[:lhs_reduce_dim]
+        else:
+            lhs_reduce_dim = -1
+            rhs_reduce_dim = -2
+            new_shape = [
+                *self._broadcast_shapes(lhs_shape[:-2], rhs_shape[:-2]),
+                lhs_shape[-2],
+                rhs_shape[-1],
+            ]
+        # merge reduce dim
+        self._check_merged_dims(
+            [lhs_shape[lhs_reduce_dim], rhs_shape[rhs_reduce_dim]],
+            allow_broadcast=False,
+        )
+        if output_dtype is None:
+            # infer output_dtype from input type when not specified
+            output_dtype = self.known_vi_[node.input[0]].type.tensor_type.elem_type
+        vi = self.known_vi_[node.output[0]]
+        vi.CopyFrom(helper.make_tensor_value_info(node.output[0], output_dtype, new_shape))
+
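+    # Example (hypothetical shapes): lhs ["b", 1, 128, 64] @ rhs [8, 64, 256] broadcasts the
+    # batch dims to ["b", 8] and yields ["b", 8, 128, 256]; the 64/64 reduce dims are checked
+    # for consistency via _check_merged_dims.
+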
+    def _fuse_tensor_type(self, node, out_idx, dst_type, src_type):
+        """Update dst_tensor_type to be compatible with src_tensor_type when there is a dimension mismatch."""
+        dst_tensor_type = (
+            dst_type.sequence_type.elem_type.tensor_type if is_sequence(dst_type) else dst_type.tensor_type
+        )
+        src_tensor_type = (
+            src_type.sequence_type.elem_type.tensor_type if is_sequence(src_type) else src_type.tensor_type
+        )
+        if dst_tensor_type.elem_type != src_tensor_type.elem_type:
+            node_id = node.name or node.op_type
+            raise ValueError(
+                f"For node {node_id}, dst_tensor_type.elem_type != src_tensor_type.elem_type: "
+                f"{onnx.onnx_pb.TensorProto.DataType.Name(dst_tensor_type.elem_type)} vs "
+                f"{onnx.onnx_pb.TensorProto.DataType.Name(src_tensor_type.elem_type)}"
+            )
+        if dst_tensor_type.HasField("shape"):
+            for di, ds in enumerate(zip(dst_tensor_type.shape.dim, src_tensor_type.shape.dim)):
+                if ds[0] != ds[1]:
+                    # create a new symbolic dimension for node/out_idx/mismatched dim id in dst_tensor_type for tensor_type
+                    # for sequence_type, clear the dimension
+                    new_dim = onnx.TensorShapeProto.Dimension()
+                    if not is_sequence(dst_type):
+                        new_dim.dim_param = str(self._new_symbolic_dim_from_output(node, out_idx, di))
+                    dst_tensor_type.shape.dim[di].CopyFrom(new_dim)
+        else:
+            dst_tensor_type.CopyFrom(src_tensor_type)
+
+    def _infer_ArrayFeatureExtractor(self, node):
+        """Infer and update the shape and type information for the ArrayFeatureExtractor node using input data and
+        indices shapes.
+        """
+        data_shape = self._get_shape(node, 0)
+        indices_shape = self._get_shape(node, 1)
+        vi = self.known_vi_[node.output[0]]
+        vi.CopyFrom(
+            helper.make_tensor_value_info(
+                node.output[0],
+                self.known_vi_[node.input[0]].type.tensor_type.elem_type,
+                data_shape[:-1] + indices_shape,
+            )
+        )
+
+    def _infer_symbolic_compute_ops(self, node):
+        """Handles symbolic computation operations for the given node based on predefined functions."""
+        funcs = {
+            "Add": lambda l: l[0] + l[1],  # noqa: E741
+            "Div": lambda l: (
+                int(l[0] // l[1]) if isinstance(l[0] // l[1], float) else l[0] // l[1]
+            ),  # integer div in sympy
+            "Equal": lambda l: l[0] == l[1],  # noqa: E741
+            "Floor": lambda l: sympy.floor(l[0]),  # noqa: E741
+            "Max": lambda l: (
+                l[1]
+                if is_literal(l[0]) and int(l[0]) < -self.int_max_
+                else (l[0] if is_literal(l[1]) and int(l[1]) < -self.int_max_ else sympy.Max(l[0], l[1]))
+            ),
+            "Min": lambda l: (
+                l[1]
+                if is_literal(l[0]) and int(l[0]) > self.int_max_
+                else (l[0] if is_literal(l[1]) and int(l[1]) > self.int_max_ else sympy.Min(l[0], l[1]))
+            ),
+            "Mul": lambda l: (int(l[0] * l[1]) if isinstance(l[0] * l[1], float) else l[0] * l[1]),  # noqa: E741
+            "Sub": lambda l: l[0] - l[1],  # noqa: E741
+            "Where": lambda l: l[1] if l[0] else l[2],  # noqa: E741
+            "Neg": lambda l: -l[0],  # noqa: E741
+        }
+        assert node.op_type in funcs
+        self._compute_on_sympy_data(node, funcs[node.op_type])
+
+    def _infer_Cast(self, node):
+        """Pass the node's data through to its SymPy representation without alteration."""
+        self._pass_on_sympy_data(node)
+
+    def _infer_CategoryMapper(self, node):
+        """Infer and set output tensor type for ONNX CategoryMapper nodes based on input tensor type."""
+        input_type = self.known_vi_[node.input[0]].type.tensor_type.elem_type
+        if input_type == onnx.TensorProto.STRING:
+            output_type = onnx.TensorProto.INT64
+        else:
+            output_type = onnx.TensorProto.STRING
+        vi = self.known_vi_[node.output[0]]
+        vi.CopyFrom(helper.make_tensor_value_info(node.output[0], output_type, self._get_shape(node, 0)))
+
+    def _infer_Compress(self, node):
+        """Infer the output shape and type for the Compress operation based on input shape and axis attribute."""
+        input_shape = self._get_shape(node, 0)
+        # create a new symbolic dimension for Compress output
+        compress_len = str(self._new_symbolic_dim_from_output(node))
+        axis = get_attribute(node, "axis")
+        if axis is None:
+            # when axis is not specified, input is flattened before compress so output is 1D
+            output_shape = [compress_len]
+        else:
+            output_shape = input_shape
+            output_shape[handle_negative_axis(axis, len(input_shape))] = compress_len
+        vi = self.known_vi_[node.output[0]]
+        vi.CopyFrom(
+            helper.make_tensor_value_info(
+                node.output[0],
+                self.known_vi_[node.input[0]].type.tensor_type.elem_type,
+                output_shape,
+            )
+        )
+
+    def _infer_Concat(self, node):
+        """Infer the output shape and type for the Concat operation based on input node values."""
+        if any(i in self.sympy_data_ or i in self.initializers_ for i in node.input):
+            values = self._get_int_or_float_values(node)
+            if all(v is not None for v in values):
+                assert get_attribute(node, "axis") == 0
+                self.sympy_data_[node.output[0]] = []
+                for i in range(len(node.input)):
+                    value = values[i]
+                    if isinstance(value, list):
+                        self.sympy_data_[node.output[0]].extend(value)
+                    else:
+                        self.sympy_data_[node.output[0]].append(value)
+
+        sympy_shape = self._get_sympy_shape(node, 0)
+        axis = handle_negative_axis(get_attribute(node, "axis"), len(sympy_shape))
+        for i_idx in range(1, len(node.input)):
+            input_shape = self._get_sympy_shape(node, i_idx)
+            if input_shape:
+                sympy_shape[axis] = sympy_shape[axis] + input_shape[axis]
+        self._update_computed_dims(sympy_shape)
+        # merge symbolic dims for non-concat axes
+        for d in range(len(sympy_shape)):
+            if d == axis:
+                continue
+            dims = [self._get_shape(node, i_idx)[d] for i_idx in range(len(node.input)) if self._get_shape(node, i_idx)]
+            if all(d == dims[0] for d in dims):
+                continue
+            merged = self._merge_symbols(dims)
+            if type(merged) == str:  # noqa: E721
+                sympy_shape[d] = self.symbolic_dims_[merged] if merged else None
+            else:
+                sympy_shape[d] = merged
+        vi = self.known_vi_[node.output[0]]
+        vi.CopyFrom(
+            helper.make_tensor_value_info(
+                node.output[0],
+                self.known_vi_[node.input[0]].type.tensor_type.elem_type,
+                get_shape_from_sympy_shape(sympy_shape),
+            )
+        )
+
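+    # Example (hypothetical shapes): Concat(axis=1) over ["b", 3, 7] and ["b", 5, 7] sums the
+    # concat axis to give ["b", 8, 7]; dims on the non-concat axes are merged symbolically.
+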
+    def _infer_ConcatFromSequence(self, node):
+        """Infers the output shape and type info for ConcatFromSequence operation in a computational graph node."""
+        seq_shape = self._get_shape(node, 0)
+        new_axis = 1 if get_attribute(node, "new_axis") else 0
+        axis = handle_negative_axis(get_attribute(node, "axis"), len(seq_shape) + new_axis)
+        concat_dim = str(self._new_symbolic_dim_from_output(node, 0, axis))
+        new_shape = seq_shape
+        if new_axis:
+            new_shape = [*seq_shape[:axis], concat_dim, *seq_shape[axis:]]
+        else:
+            new_shape[axis] = concat_dim
+        vi = self.known_vi_[node.output[0]]
+        vi.CopyFrom(
+            helper.make_tensor_value_info(
+                node.output[0],
+                self.known_vi_[node.input[0]].type.sequence_type.elem_type.tensor_type.elem_type,
+                new_shape,
+            )
+        )
+
+    def _infer_Constant(self, node):
+        """Infer the constant value for a given node and store it in sympy_data_."""
+        t = get_attribute(node, "value")
+        # Lower constant nodes to initializers
+        t.name = node.output[0]
+        self.initializers_[node.output[0]] = t
+        self.sympy_data_[node.output[0]] = numpy_helper.to_array(t)
+
+    def _infer_ConstantOfShape(self, node):
+        """Infer the constant tensor of a given shape from a node and update sympy_data_."""
+        sympy_shape = self._get_int_or_float_values(node)[0]
+        vi = self.known_vi_[node.output[0]]
+        if sympy_shape is not None:
+            if type(sympy_shape) != list:  # noqa: E721
+                sympy_shape = [sympy_shape]
+            self._update_computed_dims(sympy_shape)
+            # update sympy data if output type is int, and shape is known
+            if vi.type.tensor_type.elem_type == onnx.TensorProto.INT64 and all(is_literal(x) for x in sympy_shape):
+                self.sympy_data_[node.output[0]] = np.ones(
+                    [int(x) for x in sympy_shape], dtype=np.int64
+                ) * numpy_helper.to_array(get_attribute(node, "value", 0))
+        else:
+            # create new dynamic shape
+            # note input0 is a 1D vector of shape, the new symbolic shape has the rank of the shape vector length
+            sympy_shape = self._new_symbolic_shape(self._get_shape(node, 0)[0], node)
+
+        vi.CopyFrom(
+            helper.make_tensor_value_info(
+                node.output[0],
+                vi.type.tensor_type.elem_type,
+                get_shape_from_sympy_shape(sympy_shape),
+            )
+        )
+
+    def _infer_Conv(self, node):
+        """Infers the shape of the output tensor for a convolution operation node and updates the known value info."""
+        sympy_shape = self._compute_conv_pool_shape(node)
+        self._update_computed_dims(sympy_shape)
+        vi = self.known_vi_[node.output[0]]
+        vi.CopyFrom(
+            helper.make_tensor_value_info(
+                node.output[0],
+                vi.type.tensor_type.elem_type,
+                get_shape_from_sympy_shape(sympy_shape),
+            )
+        )
+
+    def _infer_NhwcConv(self, node):
+        """Infer the shape of the output tensor for a convolution operation with NHWC format."""
+        sympy_shape = self._compute_conv_pool_shape(node, channels_last=True)
+        self._update_computed_dims(sympy_shape)
+        vi = self.known_vi_[node.output[0]]
+        vi.CopyFrom(
+            helper.make_tensor_value_info(
+                node.output[0],
+                self.known_vi_[node.input[0]].type.tensor_type.elem_type,
+                get_shape_from_sympy_shape(sympy_shape),
+            )
+        )
+
+    def _infer_DequantizeLinear(self, node):
+        """Infer output type and shape for the DequantizeLinear node based on input 1's scale data type."""
+        output_dtype = self.known_vi_[node.input[1]].type.tensor_type.elem_type
+
+        # Get the output shape from the first input.
+        output_shape = self._get_shape(node, 0)
+
+        vi = self.known_vi_[node.output[0]]
+        vi.CopyFrom(helper.make_tensor_value_info(node.output[0], output_dtype, output_shape))
+
+    def _infer_QuantizeLinear(self, node):
+        """Infer the output data type and shape for the QuantizeLinear ONNX node, defaulting to uint8 if not
+        specified.
+        """
+        # Use the type of the zero-point input if given; otherwise, default to uint8
+        output_dtype = onnx.TensorProto.UINT8
+        if len(node.input) > 2 and node.input[2]:
+            output_dtype = self.known_vi_[node.input[2]].type.tensor_type.elem_type
+
+        # Get the output shape from the first input.
+        output_shape = self._get_shape(node, 0)
+
+        vi = self.known_vi_[node.output[0]]
+        vi.CopyFrom(helper.make_tensor_value_info(node.output[0], output_dtype, output_shape))
+
+    def _infer_Einsum(self, node):
+        """Infer the output shape and type for the Einsum operation as per ONNX standards: https://github.com/onnx/onnx/blob/v1.18.0/onnx/defs/math/defs.cc#L2472."""
+        equation = get_attribute(node, "equation")
+        equation = equation.replace(b" ", b"")
+        mid_index = equation.find(b"->")
+        left_equation = equation[:mid_index] if mid_index != -1 else equation
+
+        num_operands = 0
+        num_ellipsis = 0
+        num_ellipsis_indices = 0
+        num_labels = 0
+        ellipsis_flag = True
+        dims_value = []
+        ellipsis_dims_value = []
+
+        label_maps = {}
+        repeated_labels = set()
+
+        terms = left_equation.split(b",")
+        for term in terms:
+            ellipsis_index = term.find(b"...")
+            shape = self._get_shape(node, num_operands)
+            rank = len(shape)
+            ellipsis_dims = 0
+            term_size = 0
+            num_illegal_char = 0
+
+            for i in range(len(term)):
+                if term[i] != 46:  # ASCII for '.'
+                    term_size = term_size + 1
+
+            index = 0
+            while index < len(term):
+                if index == ellipsis_index:
+                    ellipsis_dims = rank - term_size
+                    if ellipsis_flag:
+                        ellipsis_flag = False
+                        for i in range(ellipsis_dims):
+                            ellipsis_dims_value.append(shape[index + i - num_illegal_char])
+                    else:
+                        for i in range(ellipsis_dims):
+                            shape_dim = shape[index + i - num_illegal_char]
+                            current_dim = ellipsis_dims_value[i]
+                            ellipsis_dims_value[i] = max(current_dim, shape_dim)
+
+                    num_illegal_char += 3
+                    index += 3  # Skip all three characters in '...'
+                    continue
+
+                elif term[index] == 46:  # ASCII for '.'
+                    num_illegal_char += 1
+                    index += 1
+                    continue
+
+                char = term[index]
+                if char not in label_maps:
+                    label_maps[char] = num_labels
+                    dims_value.append(shape[index + ellipsis_dims - num_illegal_char])
+                    num_labels += 1
+                else:
+                    repeated_labels.add(char)
+
+                index += 1
+
+            if ellipsis_index != -1:
+                # If there is an ellipsis, the number of dimensions it represents
+                # must be total dims minus letter dimensions
+                if num_ellipsis == 0:
+                    if rank < term_size:
+                        raise ValueError("Ellipsis represents incompatible dimensions.")
+                    num_ellipsis_indices = rank - term_size
+                else:
+                    if num_ellipsis_indices != rank - term_size:
+                        raise ValueError("Ellipsis represents incompatible dimensions.")
+                num_ellipsis += 1
+            else:
+                if rank != term_size:
+                    raise ValueError(f"Rank of input {num_operands} does not match the equation indices.")
+            num_operands += 1
+
+        new_sympy_shape = []
+        if mid_index != -1:
+            right_equation = equation[mid_index + 2 :]
+            right_ellipsis_index = right_equation.find(b"...")
+            if right_ellipsis_index != -1:
+                for i in range(num_ellipsis_indices):
+                    new_sympy_shape.append(ellipsis_dims_value[i])
+            for c in right_equation:
+                if c != 46:  # c != b'.'
+                    new_sympy_shape.append(dims_value[label_maps[c]])
+        else:
+            for i in range(num_ellipsis_indices):
+                new_sympy_shape.append(ellipsis_dims_value[i])
+            for label, idx in label_maps.items():
+                if label not in repeated_labels:
+                    new_sympy_shape.append(dims_value[idx])
+
+        output_dtype = self.known_vi_[node.input[0]].type.tensor_type.elem_type
+        vi = self.known_vi_[node.output[0]]
+        vi.CopyFrom(helper.make_tensor_value_info(node.output[0], output_dtype, new_sympy_shape))
+
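+    # Example (hypothetical shapes): equation b"bij,bjk->bik" with inputs [2, 3, 4] and
+    # [2, 4, 5] maps labels b->2, i->3, j->4, k->5, so the inferred output shape is [2, 3, 5].
+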
+    def _infer_Expand(self, node):
+        """Infers and updates the output shape for the Expand operation based on broadcasted input shapes."""
+        expand_to_shape = as_list(self._try_get_value(node, 1), keep_none=True)
+        if expand_to_shape is not None:
+            # new_shape's dim can come from shape value
+            self._update_computed_dims(expand_to_shape)
+            shape = self._get_shape(node, 0)
+            new_shape = self._broadcast_shapes(shape, get_shape_from_sympy_shape(expand_to_shape))
+            vi = self.known_vi_[node.output[0]]
+            vi.CopyFrom(
+                helper.make_tensor_value_info(
+                    node.output[0],
+                    self.known_vi_[node.input[0]].type.tensor_type.elem_type,
+                    new_shape,
+                )
+            )
+
+    def _infer_Gather(self, node):
+        """Infer the output shape of the Gather operation based on the input data and indices shapes."""
+        data_shape = self._get_shape(node, 0)
+        axis = handle_negative_axis(get_attribute(node, "axis", 0), len(data_shape))
+        indices_shape = self._get_shape(node, 1)
+        vi = self.known_vi_[node.output[0]]
+        vi.CopyFrom(
+            helper.make_tensor_value_info(
+                node.output[0],
+                self.known_vi_[node.input[0]].type.tensor_type.elem_type,
+                data_shape[:axis] + indices_shape + data_shape[axis + 1 :],
+            )
+        )
+        # for 1D input, do some sympy compute
+        if node.input[0] in self.sympy_data_ and len(data_shape) == 1 and get_attribute(node, "axis", 0) == 0:
+            idx = self._try_get_value(node, 1)
+            if idx is not None:
+                data = self.sympy_data_[node.input[0]]
+                if type(data) == list:  # noqa: E721
+                    if type(idx) == np.ndarray and len(idx.shape) == 1:
+                        self.sympy_data_[node.output[0]] = [data[int(i)] for i in idx]
+                    else:
+                        self.sympy_data_[node.output[0]] = data[int(idx)]
+                else:
+                    assert idx in {0, -1}
+                    self.sympy_data_[node.output[0]] = data
+
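+    # Example (hypothetical shapes): Gather(axis=1) with data ["b", 10, 64] and indices [5]
+    # yields data[:1] + [5] + data[2:] = ["b", 5, 64].
+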
1281
+ def _infer_GatherElements(self, node):
1282
+ """Infers the output shape and type for the GatherElements node based on input tensors and updates the node's
1283
+ value information.
1284
+ """
1285
+ indices_shape = self._get_shape(node, 1)
1286
+ vi = self.known_vi_[node.output[0]]
1287
+ vi.CopyFrom(
1288
+ helper.make_tensor_value_info(
1289
+ node.output[0],
1290
+ self.known_vi_[node.input[0]].type.tensor_type.elem_type,
1291
+ indices_shape,
1292
+ )
1293
+ )
1294
+
1295
+    def _infer_GatherND(self, node):
+        """Infers the output shape and type for the GatherND operation based on input data and indices shapes."""
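+        # Example (illustrative): data shape [10, 4, 5], indices shape [6, 2] with
+        # batch_dims=0 (so last_index_dimension=2) gives output [6, 5], i.e.
+        # indices_shape[:-1] + data_shape[batch_dims + last_index_dimension:].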
+        data_shape = self._get_shape(node, 0)
+        data_rank = len(data_shape)
+        indices_shape = self._get_shape(node, 1)
+        len(indices_shape)
+        last_index_dimension = indices_shape[-1]
+        batch_dims = get_attribute(node, "batch_dims", 0)
+        assert (
+            is_literal(last_index_dimension)
+            and is_literal(batch_dims)
+            and (batch_dims + last_index_dimension) <= data_rank
+        )
+        new_shape = indices_shape[:-1] + data_shape[batch_dims + last_index_dimension :]
+        vi = self.known_vi_[node.output[0]]
+        vi.CopyFrom(
+            helper.make_tensor_value_info(
+                node.output[0],
+                self.known_vi_[node.input[0]].type.tensor_type.elem_type,
+                new_shape,
+            )
+        )
+
+    def _infer_If(self, node):
+        """Infer the output shape for an If node, handling constant conditions to ensure shape consistency between
+        branches.
+        """
+        subgraphs = [
+            get_attribute(node, "then_branch"),
+            get_attribute(node, "else_branch"),
+        ]
+        cond = self._try_get_value(node, 0)
+
+        for i_sub, subgraph in enumerate(subgraphs):
+            subgraph_infer = self._onnx_infer_subgraph(node, subgraph, use_node_input=False)
+            for i_out in range(len(node.output)):
+                vi = self.known_vi_[node.output[i_out]]
+                if i_sub == 0:
+                    vi.CopyFrom(subgraph.output[i_out])
+                    vi.name = node.output[i_out]
+                else:
+                    self._fuse_tensor_type(node, i_out, vi.type, subgraph.output[i_out].type)
+                # pass on sympy data from the branch selected by a constant condition
+                if (
+                    cond is not None
+                    and i_sub == (0 if as_scalar(cond) > 0 else 1)
+                    and subgraph.output[i_out].name in subgraph_infer.sympy_data_
+                ):
+                    self.sympy_data_[vi.name] = subgraph_infer.sympy_data_[subgraph.output[i_out].name]
+
+    def _infer_Loop(self, node):
+        """Infer the shape and type of variables produced by the 'Loop' operation in an ONNX graph."""
+        subgraph = get_attribute(node, "body")
+        assert len(subgraph.input) == len(node.input)
+        num_loop_carried = len(node.input) - 2  # minus the length and initial loop condition
+        # when sequence_type is used as loop carried input
+        # needs to run subgraph infer twice if the tensor shape in sequence contains None
+        for i, si in enumerate(subgraph.input):
+            si_name = si.name
+            si.CopyFrom(self.known_vi_[node.input[i]])
+            si.name = si_name
+
+        self._onnx_infer_subgraph(node, subgraph)
+
+        # check subgraph input/output for shape changes in loop carried variables
+        # for tensor_type, create new symbolic dim when changing, i.e., output = Concat(input, a)
+        # for sequence_type, propagate from output to input
+        need_second_infer = False
+        for i_out in range(1, num_loop_carried + 1):
+            so = subgraph.output[i_out]
+            so_shape = get_shape_from_value_info(so)
+            if is_sequence(so.type):
+                if so_shape and None in so_shape:
+                    # copy shape from output to input
+                    # note that loop input is [loop_len, cond, input_0, input_1, ...]
+                    # while loop output is [cond, output_0, output_1, ...]
+                    subgraph.input[i_out + 1].type.sequence_type.elem_type.CopyFrom(so.type.sequence_type.elem_type)
+                    need_second_infer = True
+            else:
+                si = subgraph.input[i_out + 1]
+                si_shape = get_shape_from_value_info(si)
+                for di, dims in enumerate(zip(si_shape, so_shape)):
+                    if dims[0] != dims[1]:
+                        new_dim = onnx.TensorShapeProto.Dimension()
+                        new_dim.dim_param = str(self._new_symbolic_dim_from_output(node, i_out, di))
+                        si.type.tensor_type.shape.dim[di].CopyFrom(new_dim)
+                        so.type.tensor_type.shape.dim[di].CopyFrom(new_dim)
+                        need_second_infer = True
+
+        if need_second_infer:
+            if self.verbose_ > 2:
+                logger.debug(
+                    f"Rerun Loop: {node.name}({node.output[0]}...), because of sequence in loop carried variables"
+                )
+            self._onnx_infer_subgraph(node, subgraph, inc_subgraph_id=False)
+
+        # create a new symbolic dimension for iteration dependent dimension
+        loop_iter_dim = str(self._new_symbolic_dim_from_output(node))
+        for i in range(len(node.output)):
+            vi = self.known_vi_[node.output[i]]
+            vi.CopyFrom(subgraph.output[i + 1])  # first subgraph output is condition, not in node output
+            if i >= num_loop_carried:
+                assert not is_sequence(vi.type)  # TODO: handle loop accumulation in sequence_type
+                subgraph_vi_dim = subgraph.output[i + 1].type.tensor_type.shape.dim
+                vi.type.tensor_type.shape.ClearField("dim")
+                vi_dim = vi.type.tensor_type.shape.dim
+                vi_dim.add().dim_param = loop_iter_dim
+                vi_dim.extend(list(subgraph_vi_dim))
+            vi.name = node.output[i]
+
+    def _infer_MatMul(self, node):
+        """Infer the output shape of a matrix multiplication node."""
+        self._compute_matmul_shape(node)
+
+    def _infer_MatMulInteger(self, node):
+        """Infer the output shape of an integer matrix multiplication node."""
+        self._compute_matmul_shape(node, onnx.TensorProto.INT32)
+
+    def _infer_NonMaxSuppression(self, node):
+        """Infer the output shape of a NonMaxSuppression node and update the value info."""
+        selected = str(self._new_symbolic_dim_from_output(node))
+        vi = self.known_vi_[node.output[0]]
+        vi.CopyFrom(helper.make_tensor_value_info(node.output[0], onnx.TensorProto.INT64, [selected, 3]))
+
+    def _infer_NonZero(self, node):
+        """Infer the output shape of a NonZero node and update the value info."""
+        input_rank = self._get_shape_rank(node, 0)
+        # create a new symbolic dimension for NonZero output
+        nz_len = str(self._new_symbolic_dim_from_output(node, 0, 1))
+        vi = self.known_vi_[node.output[0]]
+        vi.CopyFrom(helper.make_tensor_value_info(node.output[0], vi.type.tensor_type.elem_type, [input_rank, nz_len]))
+
+    def _infer_OneHot(self, node):
+        """Infer the shape and type of the output tensor for the OneHot node operation."""
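+        # Example (illustrative): input shape [N], depth=10, axis=-1 (normalized to 1)
+        # produces output shape [N, 10]; the depth dim is inserted at the axis position.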
+        sympy_shape = self._get_sympy_shape(node, 0)
+        depth = self._try_get_value(node, 1)
+        axis = get_attribute(node, "axis", -1)
+        axis = handle_negative_axis(axis, len(sympy_shape) + 1)
+        new_shape = get_shape_from_sympy_shape(
+            [
+                *sympy_shape[:axis],
+                depth if is_literal(depth) else self._new_symbolic_dim_from_output(node),
+                *sympy_shape[axis:],
+            ]
+        )
+        vi = self.known_vi_[node.output[0]]
+        vi.CopyFrom(
+            helper.make_tensor_value_info(
+                node.output[0],
+                self.known_vi_[node.input[2]].type.tensor_type.elem_type,
+                new_shape,
+            )
+        )
+
+    def _infer_Pad(self, node):
+        """Infers the output shape and type for the Pad operation based on ONNX node attributes and opset version."""
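+        # Example (illustrative): input shape [N, 3, 32, 32] with pads
+        # [0, 0, 1, 1, 0, 0, 1, 1] (begin pads per axis, then end pads) yields [N, 3, 34, 34].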
+        if get_opset(self.out_mp_) <= 10:
+            pads = get_attribute(node, "pads")
+        else:
+            pads = self._try_get_value(node, 1)
+
+        sympy_shape = self._get_sympy_shape(node, 0)
+        rank = len(sympy_shape)
+
+        if pads is not None:
+            assert len(pads) == 2 * rank
+            new_sympy_shape = [
+                d + pad_up + pad_down for d, pad_up, pad_down in zip(sympy_shape, pads[:rank], pads[rank:])
+            ]
+            self._update_computed_dims(new_sympy_shape)
+        else:
+            # dynamic pads, create new symbolic dimensions
+            new_sympy_shape = self._new_symbolic_shape(rank, node)
+        output_tp = self.known_vi_[node.input[0]].type.tensor_type.elem_type
+
+        vi = self.known_vi_[node.output[0]]
+        vi.CopyFrom(
+            helper.make_tensor_value_info(node.output[0], output_tp, get_shape_from_sympy_shape(new_sympy_shape))
+        )
+
+    def _infer_Pool(self, node):
+        """Infer and update dimensions for pooling layers based on the input node."""
+        sympy_shape = self._compute_conv_pool_shape(node)
+        self._update_computed_dims(sympy_shape)
+        for o in node.output:
+            if not o:
+                continue
+            vi = self.known_vi_[o]
+            vi.CopyFrom(
+                helper.make_tensor_value_info(
+                    o,
+                    vi.type.tensor_type.elem_type,
+                    get_shape_from_sympy_shape(sympy_shape),
+                )
+            )
+
+    def _infer_aten_bitwise_or(self, node):
+        """Infers the output shape for Aten bitwise OR operation based on input node shapes."""
+        shape0 = self._get_shape(node, 0)
+        shape1 = self._get_shape(node, 1)
+        new_shape = self._broadcast_shapes(shape0, shape1)
+        t0 = self.known_vi_[node.input[0]]
+        vi = self.known_vi_[node.output[0]]
+        vi.CopyFrom(helper.make_tensor_value_info(node.output[0], t0.type.tensor_type.elem_type, new_shape))
+
+    def _infer_aten_diagonal(self, node):
+        """Infers the shape of the diagonal of a tensor given a node, offset, and dimensions."""
+        sympy_shape = self._get_sympy_shape(node, 0)
+        rank = len(sympy_shape)
+        offset = self._try_get_value(node, 1)
+        dim1 = self._try_get_value(node, 2)
+        dim2 = self._try_get_value(node, 3)
+
+        assert offset is not None and dim1 is not None and dim2 is not None
+        dim1 = handle_negative_axis(dim1, rank)
+        dim2 = handle_negative_axis(dim2, rank)
+
+        new_shape = [val for dim, val in enumerate(sympy_shape) if dim not in {dim1, dim2}]
+        shape1 = sympy_shape[dim1]
+        shape2 = sympy_shape[dim2]
+        if offset >= 0:
+            diag_shape = sympy.Max(0, sympy.Min(shape1, shape2 - offset))
+        else:
+            diag_shape = sympy.Max(0, sympy.Min(shape1 + offset, shape2))
+        new_shape.append(diag_shape)
+
+        if node.output[0]:
+            vi = self.known_vi_[node.output[0]]
+            vi.CopyFrom(
+                helper.make_tensor_value_info(
+                    node.output[0],
+                    self.known_vi_[node.input[0]].type.tensor_type.elem_type,
+                    get_shape_from_sympy_shape(new_shape),
+                )
+            )
+
+    def _infer_aten_multinomial(self, node):
+        """Infers the output shape and type for the PyTorch multinomial operation in an ONNX graph node."""
+        sympy_shape = self._get_sympy_shape(node, 0)
+        rank = len(sympy_shape)
+        assert rank in {1, 2}
+        num_samples = self._try_get_value(node, 1)
+        di = rank - 1
+        last_dim = num_samples or str(self._new_symbolic_dim_from_output(node, 0, di))
+        output_shape = [*sympy_shape[:-1], last_dim]
+        vi = self.known_vi_[node.output[0]]
+        vi.CopyFrom(
+            helper.make_tensor_value_info(
+                node.output[0],
+                onnx.TensorProto.INT64,
+                get_shape_from_sympy_shape(output_shape),
+            )
+        )
+
+    def _infer_aten_pool2d(self, node):
+        """Infer the output shape of a 2D pooling operation in an ONNX graph node."""
+        sympy_shape = self._get_sympy_shape(node, 0)
+        assert len(sympy_shape) == 4
+        sympy_shape[-2:] = [self._new_symbolic_dim_from_output(node, 0, i) for i in {2, 3}]
+        self._update_computed_dims(sympy_shape)
+        for i, o in enumerate(node.output):
+            if not o:
+                continue
+            vi = self.known_vi_[o]
+            elem_type = onnx.TensorProto.INT64 if i == 1 else self.known_vi_[node.input[0]].type.tensor_type.elem_type
+            vi.CopyFrom(helper.make_tensor_value_info(o, elem_type, get_shape_from_sympy_shape(sympy_shape)))
+
+    def _infer_aten_minmax(self, node):
+        """Infer the output shape and type for the ATen MinMax operation in an ONNX node."""
+        vi = self.known_vi_[node.output[0]]
+        if len(node.input) == 1:
+            vi.CopyFrom(
+                helper.make_tensor_value_info(
+                    node.output[0],
+                    self.known_vi_[node.input[0]].type.tensor_type.elem_type,
+                    [],
+                )
+            )
+        else:
+            assert len(node.input) == 3
+            keepdim = self._try_get_value(node, 2)
+            assert keepdim is not None  # can only handle known keepdim case.
+            dim = self._try_get_value(node, 1)
+            if dim is None:
+                rank = self._get_shape_rank(node, 0)
+                output_shape = self._new_symbolic_shape(rank if keepdim else rank - 1, node)
+            else:
+                shape = self._get_sympy_shape(node, 0)
+                dim = handle_negative_axis(dim, len(shape))
+                output_shape = shape[:dim]
+                if keepdim:
+                    output_shape += [1]
+                output_shape += shape[dim + 1 :]
+
+            output_shape = get_shape_from_sympy_shape(output_shape)
+            vi.CopyFrom(
+                helper.make_tensor_value_info(
+                    node.output[0],
+                    self.known_vi_[node.input[0]].type.tensor_type.elem_type,
+                    output_shape,
+                )
+            )
+            vi1 = self.known_vi_[node.output[1]]
+            vi1.CopyFrom(helper.make_tensor_value_info(node.output[1], onnx.TensorProto.INT64, output_shape))
+
+    def _infer_aten_unfold(self, node):
+        """Infer the tensor shape for the 'aten::unfold' operation based on input shape and parameters dimension, size, and step."""
+        sympy_shape = self._get_sympy_shape(node, 0)
+        dimension = self._try_get_value(node, 1)
+        size = self._try_get_value(node, 2)
+        step = self._try_get_value(node, 3)
+        if dimension is not None and size is not None and step is not None:
+            assert dimension < len(sympy_shape)
+            sympy_shape[dimension] = (sympy_shape[dimension] - size) // step + 1
+            sympy_shape.append(size)
+        else:
+            rank = len(sympy_shape)
+            sympy_shape = self._new_symbolic_shape(rank + 1, node)
+        self._update_computed_dims(sympy_shape)
+        if node.output[0]:
+            vi = self.known_vi_[node.output[0]]
+            vi.CopyFrom(
+                helper.make_tensor_value_info(
+                    node.output[0],
+                    self.known_vi_[node.input[0]].type.tensor_type.elem_type,
+                    get_shape_from_sympy_shape(sympy_shape),
+                )
+            )
+
+    def _infer_aten_argmax(self, node):
+        """Infers the output shape for the ONNX ATen argmax operation."""
+        new_shape = None
+        if not node.input[1]:
+            # The argmax of the flattened input is returned.
+            new_shape = []
+        else:
+            dim = self._try_get_value(node, 1)
+            keepdim = self._try_get_value(node, 2)
+            if keepdim is not None:
+                sympy_shape = self._get_sympy_shape(node, 0)
+                if dim is not None:
+                    dim = handle_negative_axis(dim, len(sympy_shape))
+                    if keepdim:
+                        sympy_shape[dim] = 1
+                    else:
+                        del sympy_shape[dim]
+                else:
+                    rank = len(sympy_shape)
+                    sympy_shape = self._new_symbolic_shape(rank if keepdim else rank - 1, node)
+                self._update_computed_dims(sympy_shape)
+                new_shape = get_shape_from_sympy_shape(sympy_shape)
+        if node.output[0] and new_shape is not None:
+            vi = self.known_vi_[node.output[0]]
+            vi.CopyFrom(helper.make_tensor_value_info(node.output[0], onnx.TensorProto.INT64, new_shape))
+
+    def _infer_aten_group_norm(self, node):
+        """Infers the output shapes and types for the ATen GroupNorm operation based on the provided node
+        information.
+        """
+        self._propagate_shape_and_type(node)
+        input_shape = self._get_shape(node, 0)
+        N = input_shape[0] if input_shape is not None and len(input_shape) != 0 else None
+        group = self._try_get_value(node, 6)
+        output_dtype = self.known_vi_[node.input[0]].type.tensor_type.elem_type
+        for i in {1, 2}:
+            if node.output[i]:
+                vi = self.known_vi_[node.output[i]]
+                vi.CopyFrom(
+                    helper.make_tensor_value_info(
+                        node.output[i],
+                        output_dtype,
+                        [
+                            (N if N is not None else str(self._new_symbolic_dim_from_output(node, i, 0))),
+                            (
+                                as_scalar(group)
+                                if group is not None
+                                else str(self._new_symbolic_dim_from_output(node, i, 1))
+                            ),
+                        ],
+                    )
+                )
+
+    def _infer_aten_upsample(self, node):
+        """Infers the output shape for an aten::upsample operation based on the input shape and specified upsampling parameters."""
+        new_shape = None
+        input_shape = self._get_shape(node, 0)
+        if input_shape is not None:
+            new_shape = input_shape[:2]
+            output_size = self._try_get_value(node, 1)
+            if output_size is not None:
+                new_shape += [dim_size.item() if type(dim_size) == np.int64 else dim_size for dim_size in output_size]
+            else:
+                rank = len(input_shape)
+                new_shape += [str(self._new_symbolic_dim_from_output(node, 0, i)) for i in range(2, rank)]
+        if node.output[0] and new_shape is not None:
+            output_dtype = self.known_vi_[node.input[0]].type.tensor_type.elem_type
+            vi = self.known_vi_[node.output[0]]
+            vi.CopyFrom(helper.make_tensor_value_info(node.output[0], output_dtype, new_shape))
+
+    def _infer_BatchNormalization(self, node):
+        """Propagate the shape and type information for the BatchNormalization node."""
+        self._propagate_shape_and_type(node)
+
+        # this works both for opset < 14 and opset >= 14, since we check i < len(node.output) in the loop
+        for i in {1, 2, 3, 4}:
+            if i < len(node.output) and node.output[i]:
+                # all of these parameters have the same shape as the 1st input
+                self._propagate_shape_and_type(node, input_index=1, output_index=i)
+
+    def _infer_Range(self, node):
+        """Infers the shape and type for Range nodes based on the provided start, limit, and delta values."""
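+        # Example (illustrative): start=0, limit=10, delta=3 gives
+        # ceil((10 - 0) / 3) = 4 elements, i.e. [0, 3, 6, 9].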
+        vi = self.known_vi_[node.output[0]]
+        input_data = self._get_int_or_float_values(node)
+        if all(i is not None for i in input_data):
+            start = as_scalar(input_data[0])
+            limit = as_scalar(input_data[1])
+            delta = as_scalar(input_data[2])
+            new_sympy_shape = [sympy.Max(sympy.ceiling((limit - start) / delta), 0)]
+        else:
+            new_sympy_shape = [self._new_symbolic_dim_from_output(node)]
+        self._update_computed_dims(new_sympy_shape)
+        vi.CopyFrom(
+            helper.make_tensor_value_info(
+                node.output[0],
+                self.known_vi_[node.input[0]].type.tensor_type.elem_type,
+                get_shape_from_sympy_shape(new_sympy_shape),
+            )
+        )
+
+    def _infer_ReduceSum(self, node):
+        """Infer output shape for ReduceSum operation based on input shape, axes, and keep_dims attribute."""
+        keep_dims = get_attribute(node, "keepdims", 1)
+        if get_opset(self.out_mp_) >= 13 and len(node.input) > 1:
+            # ReduceSum changes axes to input[1] in opset 13
+            axes = self._try_get_value(node, 1)
+            vi = self.known_vi_[node.output[0]]
+            if axes is None:
+                assert keep_dims  # can only handle keep_dims==True when axes is unknown, by generating new ranks
+                vi.CopyFrom(
+                    helper.make_tensor_value_info(
+                        node.output[0],
+                        self.known_vi_[node.input[0]].type.tensor_type.elem_type,
+                        get_shape_from_sympy_shape(self._new_symbolic_shape(self._get_shape_rank(node, 0), node)),
+                    )
+                )
+            else:
+                shape = self._get_shape(node, 0)
+                output_shape = []
+                axes = [handle_negative_axis(a, len(shape)) for a in axes]
+                for i, d in enumerate(shape):
+                    if i in axes:
+                        if keep_dims:
+                            output_shape.append(1)
+                    else:
+                        output_shape.append(d)
+                vi.CopyFrom(
+                    helper.make_tensor_value_info(
+                        node.output[0],
+                        self.known_vi_[node.input[0]].type.tensor_type.elem_type,
+                        output_shape,
+                    )
+                )
+
+    def _infer_ReduceProd(self, node):
+        """Infer the ReduceProd operation on a node, considering axes and keep dimensions attributes."""
+        axes = get_attribute(node, "axes")
+        keep_dims = get_attribute(node, "keepdims", 1)
+        if keep_dims == 0 and axes == [0]:
+            data = self._get_int_or_float_values(node)[0]
+            if data is not None:
+                self.sympy_data_[node.output[0]] = sympy_reduce_product(data)
+
+    def _infer_RelativePositionBias(self, node):
+        """Infers the relative position bias for a given ONNX node."""
+        seq_len = self._try_get_value(node, 1)
+        real_seq_len = self._try_get_value(node, 2)
+        if seq_len is None or real_seq_len is None:
+            return
+        num_heads = self._get_sympy_shape(node, 0)[1]
+
+        new_shape = [1, num_heads, str(seq_len), str(real_seq_len)]
+
+        output_dtype = self.known_vi_[node.input[0]].type.tensor_type.elem_type
+        vi = self.known_vi_[node.output[0]]
+        vi.CopyFrom(helper.make_tensor_value_info(node.output[0], output_dtype, new_shape))
+
+    def _infer_Reshape(self, node):
+        """Infer the output shape for the Reshape operation based on the provided input shape and reshape parameters."""
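+        # Example (illustrative): input shape [2, 3, 4] with target [0, -1] keeps
+        # dim 0 (a `0` copies the input dim, giving 2) and defers dim 1 to
+        # total/non_deferred = 24/2 = 12, so the output shape is [2, 12].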
+        shape_value = self._try_get_value(node, 1)
+        vi = self.known_vi_[node.output[0]]
+        if shape_value is None:
+            shape_shape = self._get_shape(node, 1)
+            assert len(shape_shape) == 1
+            shape_rank = shape_shape[0]
+            assert is_literal(shape_rank)
+            vi.CopyFrom(
+                helper.make_tensor_value_info(
+                    node.output[0],
+                    vi.type.tensor_type.elem_type,
+                    get_shape_from_sympy_shape(self._new_symbolic_shape(shape_rank, node)),
+                )
+            )
+        else:
+            input_sympy_shape = self._get_sympy_shape(node, 0)
+            total = 1
+            for d in input_sympy_shape:
+                total = total * d
+            new_sympy_shape = []
+            deferred_dim_idx = -1
+            non_deferred_size = 1
+            for i, d in enumerate(shape_value):
+                if type(d) == sympy.Symbol or d != 0:
+                    new_sympy_shape.append(d)
+                else:
+                    new_sympy_shape.append(input_sympy_shape[i])
+                    non_deferred_size = non_deferred_size * input_sympy_shape[i]
+                if d == -1:
+                    deferred_dim_idx = i
+                elif d != 0:
+                    non_deferred_size = non_deferred_size * d
+
+            assert new_sympy_shape.count(-1) < 2
+            if -1 in new_sympy_shape:
+                new_dim = total // non_deferred_size
+                new_sympy_shape[deferred_dim_idx] = new_dim
+
+            self._update_computed_dims(new_sympy_shape)
+            vi.CopyFrom(
+                helper.make_tensor_value_info(
+                    node.output[0],
+                    vi.type.tensor_type.elem_type,
+                    get_shape_from_sympy_shape(new_sympy_shape),
+                )
+            )
+
+        self._pass_on_sympy_data(node)
+
+    def _infer_Resize(self, node):
+        """Infers and updates the shape of the output tensor for a Resize node based on scales or sizes."""
+        vi = self.known_vi_[node.output[0]]
+        input_sympy_shape = self._get_sympy_shape(node, 0)
+        if get_opset(self.out_mp_) <= 10:
+            scales = self._try_get_value(node, 1)
+            if scales is not None:
+                new_sympy_shape = [sympy.simplify(sympy.floor(d * s)) for d, s in zip(input_sympy_shape, scales)]
+                self._update_computed_dims(new_sympy_shape)
+                vi.CopyFrom(
+                    helper.make_tensor_value_info(
+                        node.output[0],
+                        self.known_vi_[node.input[0]].type.tensor_type.elem_type,
+                        get_shape_from_sympy_shape(new_sympy_shape),
+                    )
+                )
+        else:
+            roi = self._try_get_value(node, 1)
+            scales = self._try_get_value(node, 2)
+            sizes = self._try_get_value(node, 3)
+            if sizes is not None:
+                new_sympy_shape = [sympy.simplify(round(s)) for s in sizes]
+                self._update_computed_dims(new_sympy_shape)
+            elif scales is not None:
+                rank = len(scales)
+                if get_attribute(node, "coordinate_transformation_mode") == "tf_crop_and_resize":
+                    assert len(roi) == 2 * rank
+                    roi_start = list(roi)[:rank]
+                    roi_end = list(roi)[rank:]
+                else:
+                    roi_start = [0] * rank
+                    roi_end = [1] * rank
+                if isinstance(scales, np.ndarray):
+                    scales = scales.tolist()
+                else:
+                    scales = list(scales)
+                new_sympy_shape = [
+                    sympy.floor(d * (end - start) * scale + sympy.Rational(1, 2))
+                    for d, start, end, scale in zip(input_sympy_shape, roi_start, roi_end, scales)
+                ]
+                self._update_computed_dims(new_sympy_shape)
+            else:
+                new_sympy_shape = self._new_symbolic_shape(self._get_shape_rank(node, 0), node)
+
+            vi.CopyFrom(
+                helper.make_tensor_value_info(
+                    node.output[0],
+                    self.known_vi_[node.input[0]].type.tensor_type.elem_type,
+                    get_shape_from_sympy_shape(new_sympy_shape),
+                )
+            )
+
+    def _infer_Scan(self, node):
+        """Infer shape and type information for the ONNX 'Scan' operator node."""
+        subgraph = get_attribute(node, "body")
+        num_scan_inputs = get_attribute(node, "num_scan_inputs")
+        scan_input_axes = get_attribute(node, "scan_input_axes", [0] * num_scan_inputs)
+        num_scan_states = len(node.input) - num_scan_inputs
+        scan_input_axes = [
+            handle_negative_axis(ax, self._get_shape_rank(node, i + num_scan_states))
+            for i, ax in enumerate(scan_input_axes)
+        ]
+        # We may have cases where the subgraph has optional inputs that appear in both subgraph's input and initializer,
+        # but not in the node's input. In such cases, the input model might be invalid, but let's skip those optional inputs.
+        assert len(subgraph.input) >= len(node.input)
+        subgraph_inputs = subgraph.input[: len(node.input)]
+        for i, si in enumerate(subgraph_inputs):
+            subgraph_name = si.name
+            si.CopyFrom(self.known_vi_[node.input[i]])
+            if i >= num_scan_states:
+                scan_input_dim = si.type.tensor_type.shape.dim
+                scan_input_dim.remove(scan_input_dim[scan_input_axes[i - num_scan_states]])
+            si.name = subgraph_name
+        self._onnx_infer_subgraph(node, subgraph)
+        num_scan_outputs = len(node.output) - num_scan_states
+        scan_output_axes = get_attribute(node, "scan_output_axes", [0] * num_scan_outputs)
+        scan_input_dim = get_shape_from_type_proto(self.known_vi_[node.input[-1]].type)[scan_input_axes[-1]]
+        for i, o in enumerate(node.output):
+            vi = self.known_vi_[o]
+            if i >= num_scan_states:
+                shape = get_shape_from_type_proto(subgraph.output[i].type)
+                new_dim = handle_negative_axis(scan_output_axes[i - num_scan_states], len(shape) + 1)
+                shape = [*shape[:new_dim], scan_input_dim, *shape[new_dim:]]
+                vi.CopyFrom(helper.make_tensor_value_info(o, subgraph.output[i].type.tensor_type.elem_type, shape))
+            else:
+                vi.CopyFrom(subgraph.output[i])
+            vi.name = o
+
+    def _infer_ScatterElements(self, node):
+        """Infer the output shape and type for ScatterElements node and update known value infos."""
+        data_shape = self._get_shape(node, 0)
+        vi = self.known_vi_[node.output[0]]
+        vi.CopyFrom(
+            helper.make_tensor_value_info(
+                node.output[0],
+                self.known_vi_[node.input[0]].type.tensor_type.elem_type,
+                data_shape,
+            )
+        )
+
+    def _infer_SequenceAt(self, node):
+        """Infers the shape and type for the output of the 'SequenceAt' ONNX operation, handling symbolic dimensions if
+        necessary.
+        """
+        seq_shape = self._get_shape(node, 0)
+        if seq_shape is not None:
+            vi = self.known_vi_[node.output[0]]
+            for di, d in enumerate(seq_shape):
+                if d is not None:
+                    continue
+                new_dim = onnx.TensorShapeProto.Dimension()
+                new_dim.dim_param = str(self._new_symbolic_dim_from_output(node, 0, di))
+                vi.type.tensor_type.shape.dim[di].CopyFrom(new_dim)
+
+    def _infer_SequenceInsert(self, node):
+        """Work around ONNX's shape inference bug by inferring sequence insertion shapes and types for the provided
+        node.
+        """
+        vi_seq = self.known_vi_[node.input[0]]
+        vi_tensor = self.known_vi_[node.input[1]]
+        vi_out_seq = self.known_vi_[node.output[0]]
+        vi_out_seq.CopyFrom(vi_seq)
+        vi_out_seq.name = node.output[0]
+        self._fuse_tensor_type(node, 0, vi_out_seq.type, vi_tensor.type)
+
+    def _infer_Shape(self, node):
+        """Infers and sets the symbolic shape for the output node in the computation graph."""
+        start = get_attribute(node, "start", 0)
+        end = get_attribute(node, "end", None)
+
+        full_sympy_shape = self._get_sympy_shape(node, 0)
+        num_dims = len(full_sympy_shape)
+
+        if start < 0:
+            start = num_dims + start
+        if end is None:
+            end = num_dims
+        elif end < 0:
+            end = num_dims + end
+
+        assert 0 <= start <= end <= num_dims, (
+            f"Shape start/end invalid: start={start}, end={end}, total_dims={num_dims}"
+        )
+
+        target_sympy_shape = full_sympy_shape[start:end]
+        self.sympy_data_[node.output[0]] = target_sympy_shape
+
+    def _infer_Size(self, node):
+        """Infers and sets the size of the output node by computing the product of its shape in the computation
+        graph.
+        """
+        sympy_shape = self._get_sympy_shape(node, 0)
+        self.sympy_data_[node.output[0]] = sympy_reduce_product(sympy_shape)
+        self.known_vi_[node.output[0]].CopyFrom(
+            helper.make_tensor_value_info(node.output[0], onnx.TensorProto.INT64, [])
+        )
+
+    def _infer_Slice(self, node):
+        """Infer the shape and value information for the Slice node using SymPy and ONNX helper methods."""
+
+        # SymPy often cannot prove an inequality whose expression contains `sympy.Min(a, b)`,
+        # even when the relation holds for both `a` and `b`.
+        #
+        # When given `expr` of form `min(a, b) + ...`, this function returns `[a + ..., b + ...]`,
+        # so that we can prove inequalities for both expressions separately.
+        #
+        # If the number of `min(...)` subexpressions is not exactly one, this function just returns `[expr]`.
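+        # For example, to prove `min(a, b) + c >= 0` it suffices to prove both
+        # `a + c >= 0` and `b + c >= 0`, which is what splitting the Min enables (illustrative).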
+        def flatten_min(expr):
+            """Returns a list with expressions split by min() for inequality proof or original expr if no single min()
+            found.
+            """
+            assert isinstance(expr, sympy.Add), f"Expected a sum of two arguments, got {expr}"
+            min_positions = [idx for idx in range(len(expr.args)) if isinstance(expr.args[idx], sympy.Min)]
+            if len(min_positions) == 1:
+                min_pos = min_positions[0]
+
+                def replace_min_with_arg(arg_idx):
+                    """Replace the sympy.Min() function at a specified position in a sympy.Add() expression with one of
+                    its arguments.
+                    """
+                    replaced = list(expr.args)
+                    assert isinstance(replaced[min_pos], sympy.Min), (
+                        f"Expected a sympy.Min() at position {min_pos}, got {replaced[min_pos]}"
+                    )
+                    assert len(replaced[min_pos].args) == 2, (
+                        f"Expected a sympy.Min() with exactly 2 arguments, got {replaced[min_pos]}"
+                    )
+                    replaced[min_pos] = replaced[min_pos].args[arg_idx]
+                    return sympy.Add(*replaced)
+
+                return [
+                    replace_min_with_arg(0),
+                    replace_min_with_arg(1),
+                ]
+            return [expr]
+
+        def less_equal(x, y):
+            """Returns True if x is less than or equal to y, otherwise False."""
+            try:
+                return x <= y
+            except TypeError:
+                pass
+            try:
+                return y >= x
+            except TypeError:
+                pass
+            try:
+                return -x >= -y
+            except TypeError:
+                pass
+            try:
+                return -y <= -x
+            except TypeError:
+                pass
+            try:
+                return y - x >= 0
+            except TypeError:
+                # the last attempt; this may raise TypeError
+                return all(d >= 0 for d in flatten_min(y - x))
+
+        def handle_negative_index(index, bound):
+            """Normalizes a negative index to be in [0, bound)."""
+            try:
+                if not less_equal(0, index):
+                    if is_literal(index) and index <= -self.int_max_:
+                        # this case is handled separately
+                        return index
+                    return bound + index
+            except TypeError:
+                logger.warning(f"Cannot determine if {index} < 0")
+            return index
+
+        if get_opset(self.out_mp_) <= 9:
+            axes = get_attribute(node, "axes")
+            starts = get_attribute(node, "starts")
+            ends = get_attribute(node, "ends")
+            if not axes:
+                axes = list(range(len(starts)))
+            steps = [1] * len(axes)
+        else:
+            starts = as_list(self._try_get_value(node, 1), keep_none=True)
+            ends = as_list(self._try_get_value(node, 2), keep_none=True)
+            axes = self._try_get_value(node, 3)
+            steps = self._try_get_value(node, 4)
+            if axes is None and (starts is not None or ends is not None):
+                axes = list(range(len(starts if starts is not None else ends)))
+            if steps is None and (starts is not None or ends is not None):
+                steps = [1] * len(starts if starts is not None else ends)
+            axes = as_list(axes, keep_none=True)
+            steps = as_list(steps, keep_none=True)
+
+        new_sympy_shape = self._get_sympy_shape(node, 0)
+        if starts is None or ends is None:
+            if axes is None:
+                for i in range(len(new_sympy_shape)):
+                    new_sympy_shape[i] = self._new_symbolic_dim_from_output(node, 0, i)
+            else:
+                new_sympy_shape = get_shape_from_sympy_shape(new_sympy_shape)
+                for i in axes:
+                    new_sympy_shape[i] = self._new_symbolic_dim_from_output(node, 0, i)
+        else:
+            for i, s, e, t in zip(axes, starts, ends, steps):
+                if is_literal(e):
+                    e = handle_negative_index(e, new_sympy_shape[i])
+                if is_literal(e):
+                    if e >= self.int_max_:
+                        e = new_sympy_shape[i]
+                    elif e <= -self.int_max_:
+                        e = 0 if s > 0 else -1
+                    elif is_literal(new_sympy_shape[i]):
+                        if e < 0:
+                            e = max(0, e + new_sympy_shape[i])
+                        e = min(e, new_sympy_shape[i])
+                    else:
+                        if e > 0:
+                            e = (
+                                sympy.Min(e, new_sympy_shape[i]) if e > 1 else e
+                            )  # special case for slicing first to make computation easier
+                else:
+                    if is_literal(new_sympy_shape[i]):
+                        if new_sympy_shape[i] < 0:
+                            e = sympy.Min(e, new_sympy_shape[i])
+                    else:
+                        try:
+                            if not less_equal(e, new_sympy_shape[i]):
+                                e = new_sympy_shape[i]
+                        except Exception:
+                            if len(e.free_symbols) == 1:
+                                if try_solve((e - new_sympy_shape[i]) >= 0, next(iter(e.free_symbols))) is None:
+                                    logger.warning(
+                                        f"Unable to determine if {e} <= {new_sympy_shape[i]}, treat as equal"
+                                    )
+                                    e = new_sympy_shape[i]
+                            else:
+                                logger.warning(f"Unable to determine if {e} <= {new_sympy_shape[i]}, treat as equal")
+                                e = new_sympy_shape[i]
+
+                s = handle_negative_index(s, new_sympy_shape[i])
+                if is_literal(new_sympy_shape[i]) and is_literal(s):
+                    s = max(0, min(s, new_sympy_shape[i]))
+
+                new_sympy_shape[i] = sympy.simplify((e - s + t + (-1 if t > 0 else 1)) // t)
+
+            self._update_computed_dims(new_sympy_shape)
+
+        vi = self.known_vi_[node.output[0]]
+        vi.CopyFrom(
+            helper.make_tensor_value_info(
+                node.output[0],
+                vi.type.tensor_type.elem_type,
+                get_shape_from_sympy_shape(new_sympy_shape),
+            )
+        )
+
+        # handle sympy_data if needed, for slice in shape computation
+        if (
+            node.input[0] in self.sympy_data_
+            and [0] == axes
+            and starts is not None
+            and len(starts) == 1
+            and ends is not None
+            and len(ends) == 1
+            and steps is not None
+            and len(steps) == 1
+        ):
+            input_sympy_data = self.sympy_data_[node.input[0]]
+            if type(input_sympy_data) == list or (  # noqa: E721
+                type(input_sympy_data) == np.ndarray and len(input_sympy_data.shape) == 1
+            ):
+                self.sympy_data_[node.output[0]] = input_sympy_data[starts[0] : ends[0] : steps[0]]
+
+    def _infer_SoftmaxCrossEntropyLoss(self, node):
+        """Infer the softmax cross-entropy loss for a given node in the computation graph."""
+        vi = self.known_vi_[node.output[0]]
+        elem_type = self.known_vi_[node.input[0]].type.tensor_type.elem_type
+
+        # If the output type is explicitly specified in the attribute, use it as the output tensor type.
+        specified_output_type = get_attribute(node, "output_type", None)
+        if specified_output_type is not None:
+            elem_type = specified_output_type
+
+        vi.type.tensor_type.elem_type = elem_type
+        vi.type.tensor_type.shape.CopyFrom(onnx.TensorShapeProto())
+
+        if len(node.output) > 1:
+            data_shape = self._get_shape(node, 0)
+            vi = self.known_vi_[node.output[1]]
+            vi.CopyFrom(helper.make_tensor_value_info(vi.name, elem_type, data_shape))
+
+    def _infer_Split_Common(self, node, make_value_info_func):
+        """Infers the output shape for the Split operator given an ONNX node and a function to create tensor value
+        info.
+        """
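+        # Example (illustrative): input shape [6, 4] split on axis=0 into 3 outputs
+        # with no explicit 'split' value yields three outputs of shape [2, 4].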
+        input_sympy_shape = self._get_sympy_shape(node, 0)
+        axis = handle_negative_axis(get_attribute(node, "axis", 0), len(input_sympy_shape))
+        op_set = get_opset(self.out_mp_)
+
+        # Depending on the op-set version, 'split' is provided as an attribute or via the 2nd input
+        if op_set < 13:
+            split = get_attribute(node, "split")
+            assert self._try_get_value(node, 1) is None
+        else:
+            split = self._try_get_value(node, 1)
+            assert get_attribute(node, "split") is None
+
+        if split is None:
+            num_outputs = len(node.output)
+            split = [input_sympy_shape[axis] / sympy.Integer(num_outputs)] * num_outputs
+            self._update_computed_dims(split)
+        else:
+            split = [sympy.Integer(s) for s in split]
+
+        for i_o in range(len(split)):
+            vi = self.known_vi_[node.output[i_o]]
+            vi.CopyFrom(
+                make_value_info_func(
+                    node.output[i_o],
+                    self.known_vi_[node.input[0]].type.tensor_type.elem_type,
+                    get_shape_from_sympy_shape([*input_sympy_shape[:axis], split[i_o], *input_sympy_shape[axis + 1 :]]),
+                )
+            )
+            self.known_vi_[vi.name] = vi
+
+    def _infer_Split(self, node):
+        """Infers the output shapes and types for the Split operation node."""
+        self._infer_Split_Common(node, helper.make_tensor_value_info)
+
+    def _infer_SplitToSequence(self, node):
+        """Infers the output shapes and types for the SplitToSequence operation node."""
+        self._infer_Split_Common(node, helper.make_sequence_value_info)
+
+    def _infer_Squeeze(self, node):
+        """Infers the output shapes and types for the Squeeze operation node."""
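+        # Example (illustrative): input shape [1, 3, 1, 4] with axes=[0, 2] squeezes
+        # to [3, 4]; with no axes given, every dim equal to 1 is removed.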
+        input_shape = self._get_shape(node, 0)
+        op_set = get_opset(self.out_mp_)
+
+        # Depending on the op-set version, 'axes' are provided as an attribute or via the 2nd input
+        if op_set < 13:
+            axes = get_attribute(node, "axes")
+            assert self._try_get_value(node, 1) is None
+        else:
+            axes = self._try_get_value(node, 1)
+            assert get_attribute(node, "axes") is None
+
+        if axes is None:
+            # No axes have been provided (neither via attribute nor via input).
+            # In this case the 'Squeeze' op should remove all axes with dimension 1.
+            # For symbolic dimensions we guess they are !=1.
+            output_shape = [s for s in input_shape if s != 1]
+            if self.verbose_ > 0:
+                symbolic_dimensions = [s for s in input_shape if type(s) != int]  # noqa: E721
+                if symbolic_dimensions:
+                    logger.debug(
+                        f"Symbolic dimensions in input shape of op: '{node.op_type}' node: '{node.name}'. "
+                        f"Assuming the following dimensions are never equal to 1: {symbolic_dimensions}"
+                    )
+        else:
+            axes = [handle_negative_axis(a, len(input_shape)) for a in axes]
+            output_shape = []
+            for i in range(len(input_shape)):
+                if i not in axes:
+                    output_shape.append(input_shape[i])
+                else:
+                    assert input_shape[i] == 1 or type(input_shape[i]) != int  # noqa: E721
+                    if self.verbose_ > 0 and type(input_shape[i]) != int:  # noqa: E721
+                        logger.debug(
+                            f"Symbolic dimensions in input shape of op: '{node.op_type}' node: '{node.name}'. "
+                            f"Assuming the dimension '{input_shape[i]}' at index {i} of the input to be equal to 1."
+                        )
+
+        vi = self.known_vi_[node.output[0]]
+        vi.CopyFrom(
+            helper.make_tensor_value_info(
+                node.output[0],
+                self.known_vi_[node.input[0]].type.tensor_type.elem_type,
+                output_shape,
+            )
+        )
+        self._pass_on_sympy_data(node)
+
+    def _infer_Tile(self, node):
+        """Infers the output shape for the Tile operation in a computation graph based on input shape and repeat
+        values.
+        """
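+        # Example (illustrative): input shape [2, 3] with repeats [2, 2] tiles to
+        # [4, 6]; each dim is multiplied by its repeat count.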
+        repeats_value = self._try_get_value(node, 1)
+        new_sympy_shape = []
+        if repeats_value is not None:
+            input_sympy_shape = self._get_sympy_shape(node, 0)
+            for i, d in enumerate(input_sympy_shape):
+                new_dim = d * repeats_value[i]
+                new_sympy_shape.append(new_dim)
+            self._update_computed_dims(new_sympy_shape)
+        else:
+            new_sympy_shape = self._new_symbolic_shape(self._get_shape_rank(node, 0), node)
+        vi = self.known_vi_[node.output[0]]
+        vi.CopyFrom(
+            helper.make_tensor_value_info(
+                node.output[0],
+                vi.type.tensor_type.elem_type,
+                get_shape_from_sympy_shape(new_sympy_shape),
+            )
+        )
+
+    def _infer_TopK(self, node):
+        """Infers the output shape for the TopK operation in an ONNX graph node based on input shape and specified
+        axis.
+        """
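+        # Example (illustrative): input shape [N, 100] with k=5 and axis=-1 (normalized
+        # to 1) gives both outputs (values and indices) shape [N, 5].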
+        rank = self._get_shape_rank(node, 0)
+        axis = handle_negative_axis(get_attribute(node, "axis", -1), rank)
+        new_shape = self._get_shape(node, 0)
+
+        if get_opset(self.out_mp_) <= 9:
+            k = get_attribute(node, "k")
+        else:
+            k = self._get_int_or_float_values(node)[1]
+
+        k = self._new_symbolic_dim_from_output(node) if k is None else as_scalar(k)
+        if type(k) in {int, str}:
+            new_shape[axis] = k
+        else:
+            new_sympy_shape = self._get_sympy_shape(node, 0)
+            new_sympy_shape[axis] = k
+            self._update_computed_dims(
+                new_sympy_shape
+            )  # note that TopK dim could be computed in sympy_data, so need to update computed_dims when it enters shape
+            new_shape = get_shape_from_sympy_shape(new_sympy_shape)
+
+        for i_o in range(len(node.output)):
+            vi = self.known_vi_[node.output[i_o]]
+            vi.CopyFrom(helper.make_tensor_value_info(node.output[i_o], vi.type.tensor_type.elem_type, new_shape))
+
+    def _infer_Transpose(self, node):
+        """Infer and update the shape information for a Transpose node based on its input shape and permutation
+        attributes.
+        """
+        if node.input[0] in self.sympy_data_:
+            data_shape = self._get_shape(node, 0)
+            perm = get_attribute(node, "perm", reversed(list(range(len(data_shape)))))
+            input_data = self.sympy_data_[node.input[0]]
+            self.sympy_data_[node.output[0]] = (
+                np.transpose(np.array(input_data).reshape(*data_shape), axes=tuple(perm)).flatten().tolist()
+            )
+
+    def _infer_Unsqueeze(self, node):
+        """Infers the output shape for the Unsqueeze operation based on the input shape and operator set."""
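+        # Example (illustrative): input shape [3, 4] with axes=[0, 3] unsqueezes to
+        # [1, 3, 4, 1]; axes index into the output rank (len(input_shape) + len(axes)).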
+        input_shape = self._get_shape(node, 0)
+        op_set = get_opset(self.out_mp_)
+
+        # Depending on the op-set version, 'axes' are provided as an attribute or via the 2nd input
+        if op_set < 13:
+            axes = get_attribute(node, "axes")
+            assert self._try_get_value(node, 1) is None
+        else:
+            axes = self._try_get_value(node, 1)
+            assert get_attribute(node, "axes") is None
+
+        output_rank = len(input_shape) + len(axes)
+        axes = [handle_negative_axis(a, output_rank) for a in axes]
+
+        input_axis = 0
+        output_shape = []
+        for i in range(output_rank):
+            if i in axes:
+                output_shape.append(1)
+            else:
+                output_shape.append(input_shape[input_axis])
+                input_axis += 1
+
+        vi = self.known_vi_[node.output[0]]
+        vi.CopyFrom(
+            helper.make_tensor_value_info(
+                node.output[0],
+                self.known_vi_[node.input[0]].type.tensor_type.elem_type,
+                output_shape,
+            )
+        )
+
+        self._pass_on_sympy_data(node)
+
+    def _infer_ZipMap(self, node):
+        """Infer the type of keys for a ZipMap node based on its class labels attribute."""
+        map_key_type = None
+        if get_attribute(node, "classlabels_int64s") is not None:
+            map_key_type = onnx.TensorProto.INT64
+        elif get_attribute(node, "classlabels_strings") is not None:
+            map_key_type = onnx.TensorProto.STRING
+
+        assert map_key_type is not None
+        new_vi = onnx.ValueInfoProto()
+        new_vi.name = node.output[0]
+        new_vi.type.sequence_type.elem_type.map_type.value_type.tensor_type.elem_type = onnx.TensorProto.FLOAT
+        new_vi.type.sequence_type.elem_type.map_type.key_type = map_key_type
+        vi = self.known_vi_[node.output[0]]
+        vi.CopyFrom(new_vi)
+
+    def _infer_Attention(self, node):
+        """Infer shape and data type for ONNX Attention node outputs given input shapes and attributes."""
+        shape = self._get_shape(node, 0)
+        shape_weights = self._get_shape(node, 1)
+        shape_bias = self._try_get_shape(node, 2)
+        if shape_bias is not None:
+            assert len(shape_bias) == 1
+        tripled_hidden_size = shape_bias[0] if shape_bias is not None else shape_weights[1]
+        if shape and len(shape) == 3:
+            qkv_hidden_sizes_attr = get_attribute(node, "qkv_hidden_sizes")
+            if qkv_hidden_sizes_attr is not None:
+                assert len(qkv_hidden_sizes_attr) == 3
+                shape[2] = int(qkv_hidden_sizes_attr[2])
+            elif isinstance(tripled_hidden_size, int):
+                shape[2] = int(tripled_hidden_size / 3)
+            output_dtype = self.known_vi_[node.input[0]].type.tensor_type.elem_type
+            vi = self.known_vi_[node.output[0]]
+            vi.CopyFrom(helper.make_tensor_value_info(node.output[0], output_dtype, shape))
+
+            if len(node.output) > 1:
+                # input shape: (batch_size, sequence_length, hidden_size)
+                # past shape: (2, batch_size, num_heads, past_sequence_length, head_size)
+                # mask shape: (batch_size, total_sequence_length) or (batch_size, sequence_length, total_sequence_length) or (batch_size, 1, max_seq_len, max_seq_len)
+                # present shape: (2, batch_size, num_heads, total_sequence_length, head_size), where total_sequence_length=sequence_length+past_sequence_length
+                input_shape = self._get_shape(node, 0)
+                past_shape = self._get_shape(node, 4) if len(node.input) > 4 and node.input[4] else []
+                mask_shape = self._get_shape(node, 3) if len(node.input) > 3 and node.input[3] else []
+
+                if past_shape and len(past_shape) == 5:
+                    if mask_shape and len(mask_shape) in {2, 3}:
+                        past_shape[3] = mask_shape[-1]
+                    elif input_shape and len(input_shape) == 3:
+                        if isinstance(input_shape[1], int) and isinstance(past_shape[3], int):
+                            past_shape[3] = input_shape[1] + past_shape[3]
+                        else:
+                            past_shape[3] = f"{past_shape[3]}+{input_shape[1]}"
+                    vi = self.known_vi_[node.output[1]]
+                    vi.CopyFrom(helper.make_tensor_value_info(vi.name, output_dtype, past_shape))
+                else:
+                    num_heads = get_attribute(node, "num_heads")
+                    head_size = input_shape[2] // num_heads
+                    present_shape = [
+                        2,
+                        input_shape[0],
+                        num_heads,
+                        input_shape[1],
+                        head_size,
+                    ]
+                    vi = self.known_vi_[node.output[1]]
+                    vi.CopyFrom(helper.make_tensor_value_info(vi.name, output_dtype, present_shape))
+
+    def _infer_GatedRelativePositionBias(self, node):
+        """Infer the shape for gated relative position bias given the node attributes."""
+        # When padding is removed:
+        #   query_layer: (token_count, num_heads x head_size)
+        #   token_offset: (batch_size, seq_len)
+        # Otherwise:
+        #   query_layer: (batch_size, seq_len, num_heads x head_size)
+        #   token_offset: None
+        # Output shape: (batch_size, num_heads, seq_len, seq_len)
+        num_heads = get_attribute(node, "num_heads")
+
+        token_offset_shape = self._try_get_shape(node, 6)
+        if token_offset_shape is not None:
+            output_shape = [
+                token_offset_shape[0],
+                num_heads,
+                token_offset_shape[1],
+                token_offset_shape[1],
+            ]
+        else:
+            query_layer_shape = self._get_shape(node, 0)
+            assert query_layer_shape is not None and len(query_layer_shape) == 3
+            output_shape = [
+                query_layer_shape[0],
+                num_heads,
+                query_layer_shape[1],
+                query_layer_shape[1],
+            ]
+
+        output_dtype = self.known_vi_[node.input[0]].type.tensor_type.elem_type
+        vi = self.known_vi_[node.output[0]]
+        vi.CopyFrom(helper.make_tensor_value_info(node.output[0], output_dtype, output_shape))
+
+    def _infer_PackedAttention(self, node):
+        """Infer shape and data type for PackedAttention nodes in a given computational graph."""
+        shape = self._get_shape(node, 0)
+        shape_weights = self._get_shape(node, 1)
+        shape_bias = self._try_get_shape(node, 2)
+        if shape_bias is not None:
+            assert len(shape_bias) == 1
+        tripled_hidden_size = shape_bias[0] if shape_bias is not None else shape_weights[1]
+        if shape and len(shape) == 2:
+            qkv_hidden_sizes_attr = get_attribute(node, "qkv_hidden_sizes")
+            if qkv_hidden_sizes_attr is not None:
+                assert len(qkv_hidden_sizes_attr) == 3
+                shape[1] = int(qkv_hidden_sizes_attr[2])
+            elif isinstance(tripled_hidden_size, int):
+                shape[1] = int(tripled_hidden_size / 3)
+            output_dtype = self.known_vi_[node.input[0]].type.tensor_type.elem_type
+            vi = self.known_vi_[node.output[0]]
+            vi.CopyFrom(helper.make_tensor_value_info(node.output[0], output_dtype, shape))
+
+    def _infer_PackedMultiHeadAttention(self, node):
+        """Infer the output shape for PackedMultiHeadAttention node in the computational graph."""
+        shape_value = self._try_get_shape(node, 2)
+        if shape_value is not None and len(shape_value) == 2:
+            output_shape = shape_value
+        else:
+            shape_query = self._get_shape(node, 0)
+            assert shape_query is not None and len(shape_query) == 4
+            output_shape = [shape_query[0], shape_query[1] * shape_query[3]]
+
+        output_dtype = self.known_vi_[node.input[0]].type.tensor_type.elem_type
+        vi = self.known_vi_[node.output[0]]
+        vi.CopyFrom(helper.make_tensor_value_info(node.output[0], output_dtype, output_shape))
+
+    def _infer_MultiScaleDeformableAttnTRT(self, node):
+        """Infer the output shape for a MultiScaleDeformableAttnTRT node from its value and sampling-location inputs."""
+        shape_value = self._try_get_shape(node, 0)
+        sampling_locations = self._try_get_shape(node, 3)
+        output_shape = shape_value
+        output_shape[1] = sampling_locations[1]
+        output_dtype = self.known_vi_[node.input[0]].type.tensor_type.elem_type
+        vi = self.known_vi_[node.output[0]]
+        vi.CopyFrom(helper.make_tensor_value_info(node.output[0], output_dtype, output_shape))
+
+    def _infer_RemovePadding(self, node):
+        """Infers the shape and data type for the output tensor after removing padding."""
+        shape = self._get_shape(node, 0)
+        if shape and len(shape) == 3:
+            output_dtype = self.known_vi_[node.input[0]].type.tensor_type.elem_type
+            vi = self.known_vi_[node.output[0]]
+            vi.CopyFrom(helper.make_tensor_value_info(node.output[0], output_dtype, ["token_count", shape[2]]))
+
+            vi_token_offset = self.known_vi_[node.output[1]]
+            vi_token_offset.CopyFrom(
+                helper.make_tensor_value_info(node.output[1], onnx.TensorProto.INT32, [shape[0], shape[1]])
+            )
+
+            vi_cumulated_seq_len = self.known_vi_[node.output[2]]
+            vi_cumulated_seq_len.CopyFrom(
+                helper.make_tensor_value_info(node.output[2], onnx.TensorProto.INT32, ["batch_size + 1"])
+            )
+
+            vi_max_seq_len = self.known_vi_[node.output[3]]
+            vi_max_seq_len.CopyFrom(helper.make_tensor_value_info(node.output[3], onnx.TensorProto.INT32, [1]))
+
+    def _infer_RestorePadding(self, node):
+        """Infers the output shape and type for the RestorePadding operation."""
+        shape_input = self._get_shape(node, 0)
+        shape_token_offset = self._get_shape(node, 1)
+        if shape_input and len(shape_input) == 2 and shape_token_offset and len(shape_token_offset) == 2:
+            output_dtype = self.known_vi_[node.input[0]].type.tensor_type.elem_type
+            vi = self.known_vi_[node.output[0]]
+
+            output_shape = [
+                shape_token_offset[0],
+                shape_token_offset[1],
+                shape_input[1],
+            ]
+            vi.CopyFrom(helper.make_tensor_value_info(node.output[0], output_dtype, output_shape))
+
+    def _infer_BiasGelu(self, node):
+        """Propagate shape and type information for BiasGelu node during inference."""
+        self._propagate_shape_and_type(node)
+
+    def _infer_MultiHeadAttention(self, node):
+        """Propagate shape and type information for MultiHeadAttention node during inference."""
+        # Q, K and V without packing:
+        #   Input 0 (query) has shape (batch_size, sequence_length, hidden_size)
+        #   Input 1 (key) has shape (batch_size, kv_sequence_length, hidden_size) or (batch_size, num_heads, kv_sequence_length, head_size)
+        #   Input 2 (value) has shape (batch_size, kv_sequence_length, v_hidden_size) or (batch_size, num_heads, kv_sequence_length, head_size)
+        # Packed KV:
+        #   Input 0 (query) has shape (batch_size, sequence_length, hidden_size)
+        #   Input 1 (batch_size, kv_sequence_length, num_heads, 2, head_size)
+        #   Input 2 nullptr
+        # Packed QKV:
+        #   Input 0 (batch_size, sequence_length, num_heads, 3, head_size)
+        #   Input 1 nullptr
+        #   Input 2 nullptr
+
+        query_shape = self._get_shape(node, 0)
+        total_sequence_length = None
+        output_dtype = None
+        if query_shape is not None:
+            if len(query_shape) == 3:
+                key_shape = self._try_get_shape(node, 1)
+                # By default, hidden size is same for Q/K/V. Only need check v_hidden_size when value is provided.
+                output_shape = query_shape
+                if key_shape is not None and len(key_shape) == 3:
+                    value_shape = self._try_get_shape(node, 2)
+                    if value_shape is not None and len(value_shape) == 3:
+                        output_shape[2] = value_shape[2]
+                    total_sequence_length = key_shape[1]
+
+                output_dtype = self.known_vi_[node.input[0]].type.tensor_type.elem_type
+                vi = self.known_vi_[node.output[0]]
+                vi.CopyFrom(helper.make_tensor_value_info(node.output[0], output_dtype, output_shape))
+
+            elif len(query_shape) == 5:
+                if isinstance(query_shape[2], int) and isinstance(query_shape[4], int):
+                    output_shape = [
+                        query_shape[0],
+                        query_shape[1],
+                        query_shape[2] * query_shape[4],
+                    ]
+                else:
+                    output_shape = [
+                        query_shape[0],
+                        query_shape[1],
+                        f"{query_shape[2]}*{query_shape[4]}",
+                    ]
+
+                total_sequence_length = query_shape[1]
+
+                output_dtype = self.known_vi_[node.input[0]].type.tensor_type.elem_type
+                vi = self.known_vi_[node.output[0]]
+                vi.CopyFrom(helper.make_tensor_value_info(node.output[0], output_dtype, output_shape))
+
+            if len(node.output) > 1:
+                batch_size = query_shape[0]
+                num_heads = get_attribute(node, "num_heads")
+
+                head_size = None
+                if len(query_shape) == 3:
+                    head_size = (
+                        int(query_shape[2] / num_heads)
+                        if isinstance(query_shape[2], int)
+                        else f"{query_shape[2]}/{num_heads}"
+                    )
+                else:
+                    head_size = query_shape[4]
+
+                past_shape = self._try_get_shape(node, 6)
+
+                if past_shape is not None:
+                    if isinstance(past_shape[2], int) and isinstance(total_sequence_length, int):
+                        total_sequence_length = past_shape[2] + total_sequence_length
+                    else:
+                        total_sequence_length = f"{past_shape[2]}+{total_sequence_length}"
+
+                present_shape = [
+                    batch_size,
+                    num_heads,
+                    total_sequence_length,
+                    head_size,
+                ]
+
+                assert output_dtype is not None
+                if len(node.output) > 2 and node.output[1] and node.output[2]:
+                    vi = self.known_vi_[node.output[1]]
+                    vi.CopyFrom(helper.make_tensor_value_info(vi.name, output_dtype, present_shape))
+                    vi = self.known_vi_[node.output[2]]
+                    vi.CopyFrom(helper.make_tensor_value_info(vi.name, output_dtype, present_shape))
+
+    def _infer_DecoderMaskedMultiHeadAttention(self, node):
+        """Infers the output shape of the DecoderMaskedMultiHeadAttention node based on input shapes and attributes in
+        the computational graph.
+        """
+        # Q, K and V without packing:
+        #   Input 0 (query) has shape (batch_size, 1, hidden_size)
+        #   Input 5 (past_key), if present, has shape (batch_size, num_heads, max_sequence_length, head_size)
+
+        query_shape = self._get_shape(node, 0)
+        if query_shape is not None:
+            output_shape = query_shape
+            output_dtype = self.known_vi_[node.input[0]].type.tensor_type.elem_type
+            assert output_dtype is not None
+            vi = self.known_vi_[node.output[0]]
+            vi.CopyFrom(helper.make_tensor_value_info(node.output[0], output_dtype, output_shape))
+
+            if len(node.output) > 2 and node.output[1] and node.output[2]:
+                past_shape = self._try_get_shape(node, 5)
+                if past_shape is not None:
+                    vi = self.known_vi_[node.output[1]]
+                    vi.CopyFrom(helper.make_tensor_value_info(vi.name, output_dtype, past_shape))
+                    vi = self.known_vi_[node.output[2]]
+                    vi.CopyFrom(helper.make_tensor_value_info(vi.name, output_dtype, past_shape))
+
+ def _infer_FastGelu(self, node):
2666
+ """Infers the output shapes and types for the FastGelu node using shape propagation."""
2667
+ self._propagate_shape_and_type(node)
2668
+
2669
+ def _infer_Gelu(self, node):
2670
+ """Infers the output shapes and types for the Gelu node using shape propagation."""
2671
+ self._propagate_shape_and_type(node)
2672
+
2673
+ def _infer_QuickGelu(self, node):
2674
+ """Infers the output shapes and types for the QuickGelu node using shape propagation."""
2675
+ self._propagate_shape_and_type(node)
2676
+
2677
+ def _infer_GemmFastGelu(self, node):
2678
+ """Infers the output shapes and types for the GemmFastGelu node using matrix multiplication shape
2679
+ computation.
2680
+ """
2681
+ self._compute_matmul_shape(node)
2682
+
2683
+ def _infer_GemmFloat8(self, node):
2684
+ """Infers the output shapes and types for the GemmFloat8 node using matrix multiplication shape computation."""
2685
+ self._compute_matmul_shape(node)
2686
+
2687
+ def _infer_LayerNormalization(self, node):
2688
+ """Infers the output shapes and types for the LayerNormalization node, including handling mean and variance
2689
+ outputs.
2690
+ """
2691
+ self._propagate_shape_and_type(node)
2692
+ if len(node.output) > 1:
2693
+ axis = get_attribute(node, "axis")
2694
+ if axis is None:
2695
+ axis = -1
2696
+ x_shape = self._get_shape(node, 0)
2697
+ if x_shape is not None:
2698
+ rank = len(x_shape)
2699
+ axis = handle_negative_axis(axis, rank)
2700
+ mean_shape = x_shape[:axis] + [1 for _ in range(rank - axis)]
2701
+ mean_dtype = self.known_vi_[node.input[0]].type.tensor_type.elem_type
2702
+ if mean_dtype in {
2703
+ onnx.TensorProto.FLOAT16,
2704
+ onnx.TensorProto.BFLOAT16,
2705
+ }:
2706
+ mean_dtype = onnx.TensorProto.FLOAT
2707
+ vi = self.known_vi_[node.output[1]]
2708
+ vi.CopyFrom(helper.make_tensor_value_info(node.output[1], mean_dtype, mean_shape))
2709
+ if len(node.output) > 2:
2710
+ vi = self.known_vi_[node.output[2]]
2711
+ vi.CopyFrom(helper.make_tensor_value_info(node.output[2], mean_dtype, mean_shape))
2712
+
2713
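+    # Worked example (hypothetical shapes): for input X of shape (B, S, H) with the
+    # default axis=-1, the optional mean and inverse-std-dev outputs get shape
+    # (B, S, 1), and their element type is promoted to float32 when X is float16
+    # or bfloat16.
+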
+    def _infer_LongformerAttention(self, node):
+        """Infer and propagate shape and type information for a LongformerAttention node."""
+        self._propagate_shape_and_type(node)
+
+    def _infer_EmbedLayerNormalization(self, node):
+        """Infer and propagate shape and type information for an EmbedLayerNormalization node."""
+        input_ids_shape = self._get_shape(node, 0)
+        word_embedding_shape = self._get_shape(node, 2)
+        assert len(input_ids_shape) == 2 and len(word_embedding_shape) == 2
+        output_shape = [*input_ids_shape, word_embedding_shape[1]]
+
+        word_embedding_dtype = self.known_vi_[node.input[2]].type.tensor_type.elem_type
+        vi = self.known_vi_[node.output[0]]
+        vi.CopyFrom(helper.make_tensor_value_info(node.output[0], word_embedding_dtype, output_shape))
+
+        if len(node.output) > 1 and node.output[1]:
+            mask_index_shape = [input_ids_shape[0]]
+            vi = self.known_vi_[node.output[1]]
+            vi.CopyFrom(helper.make_tensor_value_info(node.output[1], onnx.TensorProto.INT32, mask_index_shape))
+
+        if len(node.output) > 2:
+            # Optional output: the sum computed before layer normalization is applied;
+            # its shape is the same as the main output's.
+            vi = self.known_vi_[node.output[2]]
+            vi.CopyFrom(helper.make_tensor_value_info(node.output[2], word_embedding_dtype, output_shape))
+
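+    # Worked example (hypothetical shapes): with input_ids (B, S) and a word
+    # embedding table (vocab_size, H), the main output is (B, S, H), the optional
+    # mask index is an int32 tensor of shape (B,), and the optional pre-layernorm
+    # sum reuses the (B, S, H) shape and the embedding dtype.
+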
+    def _infer_SkipLayerNormalization(self, node):
+        """Infer the output shape and type for a node with SkipLayerNormalization in an ONNX model."""
+        self._propagate_shape_and_type(node)
+
+        # If the SkipLayerNormalization node contains the optional
+        # output for inference, infer the shape and type for it too
+        if len(node.output) > 3:
+            self._propagate_shape_and_type(node, 0, 3)
+
+    def _infer_GroupNorm(self, node):
+        """Infer the shape and type for Group Normalization in an ONNX model."""
+        self._propagate_shape_and_type(node)
+
+    def _infer_SkipGroupNorm(self, node):
+        """Infer the shape and type for Skip Group Normalization in an ONNX model."""
+        self._propagate_shape_and_type(node, 0, 0)
+        if len(node.output) > 1:
+            self._propagate_shape_and_type(node, 0, 1)
+
+    def _infer_BiasSplitGelu(self, node):
+        """Infer the shape and type for Bias Split Gelu in an ONNX model."""
+        input_shape = self._get_shape(node, 0)
+        bias_shape = self._get_shape(node, 1)
+        if input_shape and bias_shape and isinstance(bias_shape[0], int):
+            output_shape = input_shape
+            output_shape[2] = int(bias_shape[0] / 2)
+            vi = self.known_vi_[node.output[0]]
+            output_dtype = self.known_vi_[node.input[0]].type.tensor_type.elem_type
+            vi.CopyFrom(helper.make_tensor_value_info(vi.name, output_dtype, output_shape))
+
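+    # Worked example (hypothetical shapes): BiasSplitGelu with input (B, S, D) and
+    # bias (D,) halves the last dimension, so the inferred output is (B, S, D // 2).
+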
+    def _infer_BiasAdd(self, node):
+        """Infer the output shape and type for a BiasAdd node by propagating input shape and type information."""
+        self._propagate_shape_and_type(node)
+
+    def _infer_RotaryEmbedding(self, node):
+        """Infer the output shape and type for a RotaryEmbedding node by appropriately propagating input shape and type
+        information.
+        """
+        if len(node.output) == 1:
+            self._propagate_shape_and_type(node)
+        elif len(node.output) == 2:
+            # Extraneous constant nodes output by a RotaryEmbedding function made with `export_modules_as_functions`
+            self._propagate_shape_and_type(node, input_index=1, output_index=0)
+            self._propagate_shape_and_type(node, input_index=0, output_index=1)  # true output
+        elif len(node.output) == 3:
+            # Extraneous constant nodes output by a RotaryEmbedding function made with `export_modules_as_functions`
+            self._propagate_shape_and_type(node, input_index=1, output_index=0)
+            self._propagate_shape_and_type(node, input_index=1, output_index=1)
+            self._propagate_shape_and_type(node, input_index=0, output_index=2)  # true output
+
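+    # In the multi-output cases above, only the last output is the real rotary
+    # result (propagated from input 0); the leading outputs are constants emitted
+    # by `export_modules_as_functions`, so they borrow the shape of input 1.
+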
+    def _infer_PythonOp(self, node):
+        """Infer and propagate the shape and type information for a PythonOp node in the computation graph."""
+        output_tensor_types = get_attribute(node, "output_tensor_types")
+        assert output_tensor_types, f"PythonOp '{node.name}' has no output_tensor_types attribute."
+        output_tensor_ranks = get_attribute(node, "output_tensor_ranks")
+        assert output_tensor_ranks, f"PythonOp '{node.name}' has no output_tensor_ranks attribute."
+
+        from onnxruntime.capi._pybind_state import get_shape_inference_function
+
+        func_name = get_attribute(node, "func_name").decode()
+        shape_inferer = get_shape_inference_function(func_name)
+
+        # Set the context output separately.
+        # The first output is torch.autograd.Function's context.
+        vi = self.known_vi_[node.output[0]]
+        vi.CopyFrom(helper.make_tensor_value_info(node.output[0], onnx.TensorProto.INT64, []))
+
+        if shape_inferer is not None:
+            input_shapes = []
+            input_dtypes = []
+            for input_index in range(len(node.input)):
+                shape = self._get_shape(node, input_index)
+                input_shapes.append(shape)
+                input_dtype = self.known_vi_[node.input[input_index]].type.tensor_type.elem_type
+                input_dtypes.append(input_dtype)
+            output_shapes, output_dtypes = shape_inferer(node, input_shapes, input_dtypes)
+            assert len(output_shapes) == len(output_dtypes) == (len(node.output) - 1), (
+                f"PythonOp '{func_name}' returned {len(output_shapes)} shapes and {len(output_dtypes)} dtypes, "
+                f"but expected {len(node.output) - 1} outputs."
+            )
+            for i in range(len(node.output) - 1):
+                output_index = i + 1
+                vi = self.known_vi_[node.output[output_index]]
+                vi.CopyFrom(
+                    helper.make_tensor_value_info(node.output[output_index], output_dtypes[i], output_shapes[i])
+                )
+        else:
+            # General shape inference for PythonOp.
+            # Outputs after torch.autograd.Function's context are tensors.
+            # We assume their ranks are fixed for different model inputs.
+            for i in range(len(node.output) - 1):
+                # Process the i-th tensor output.
+                vi = self.known_vi_[node.output[i + 1]]
+                sympy_shape = self._new_symbolic_shape(output_tensor_ranks[i], node)
+                shape = get_shape_from_sympy_shape(sympy_shape)
+                value_info = helper.make_tensor_value_info(node.output[i + 1], output_tensor_types[i], shape)
+                vi.CopyFrom(value_info)
+
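+    # A registered shape inference function is called with the node plus per-input
+    # shapes/dtypes and must return one shape and one dtype per tensor output,
+    # excluding the context output. A minimal sketch (hypothetical op):
+    #
+    #     def my_identity_like_inferer(node, input_shapes, input_dtypes):
+    #         # tensor outputs mirror the first input
+    #         return [input_shapes[0]], [input_dtypes[0]]
+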
+    def _propagate_shape_and_type(self, node, input_index=0, output_index=0):
+        """Propagates the shape and type information from input to output tensors in a given node."""
+        shape = self._get_shape(node, input_index)
+        output_dtype = self.known_vi_[node.input[input_index]].type.tensor_type.elem_type
+        vi = self.known_vi_[node.output[output_index]]
+        vi.CopyFrom(helper.make_tensor_value_info(node.output[output_index], output_dtype, shape))
+
+    def _is_none_dim(self, dim_value):
+        """Check if dimension value is a string representing an unknown dimension that is not in symbolic_dims_."""
+        if type(dim_value) != str:  # noqa: E721
+            return False
+        return dim_value not in self.symbolic_dims_ if "unk__" in dim_value else False
+
+    def _is_shape_contains_none_dim(self, out_shape):
+        """Return the first unknown ('unk__*') dimension found in the given shape, or None if there is none."""
+        for out in out_shape:
+            if self._is_none_dim(out):
+                return out
+        return None
+
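+    # Example: onnx >= 1.11 names uncertain dims "unk__<index>", so
+    # _is_none_dim("unk__7") is True until "unk__7" is registered in
+    # symbolic_dims_, while _is_none_dim("batch") is always False.
+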
+    def _infer_impl(self, start_sympy_data=None):
+        """Infer implementation details and update symbolic data and input symbols."""
+        self.sympy_data_ = start_sympy_data or {}
+        self._apply_suggested_merge(graph_input_only=True)
+        self.input_symbols_ = set()
+        for i in self.out_mp_.graph.input:
+            input_shape = get_shape_from_value_info(i)
+            if input_shape is None:
+                continue
+
+            if is_sequence(i.type):
+                input_dims = i.type.sequence_type.elem_type.tensor_type.shape.dim
+            else:
+                input_dims = i.type.tensor_type.shape.dim
+
+            for i_dim, dim in enumerate(input_shape):
+                if dim is None:
+                    # some models use None for a symbolic dim in inputs; replace it with a string
+                    input_dims[i_dim].dim_param = str(self._new_symbolic_dim(i.name, i_dim))
+
+            self.input_symbols_.update([d for d in input_shape if type(d) == str])  # noqa: E721
+
+        for s in self.input_symbols_:
+            if s in self.suggested_merge_:
+                s_merge = self.suggested_merge_[s]
+                assert s_merge in self.symbolic_dims_
+                self.symbolic_dims_[s] = self.symbolic_dims_[s_merge]
+            else:
+                # Since inputs are not produced by other ops, we can assume positivity
+                self.symbolic_dims_[s] = sympy.Symbol(s, integer=True, positive=True)
+        # compute the prerequisites of each node for topological sort;
+        # a node with subgraphs may depend on implicit inputs, which affects the sort
+        prereq_for_node = {}  # map from node to all its inputs, including implicit ones in subgraphs
+
+        def get_prereq(node):
+            """Compute and return the prerequisite inputs for a given node, including implicit inputs from subgraphs."""
+            names = {i for i in node.input if i}
+            subgraphs = []
+            if node.op_type == "If":
+                subgraphs = [
+                    get_attribute(node, "then_branch"),
+                    get_attribute(node, "else_branch"),
+                ]
+            elif node.op_type in {"Loop", "Scan"}:
+                subgraphs = [get_attribute(node, "body")]
+            for g in subgraphs:
+                g_outputs_and_initializers = {i.name for i in g.initializer}
+                g_prereq = set()
+                for n in g.node:
+                    g_outputs_and_initializers.update(n.output)
+                for n in g.node:
+                    g_prereq.update([i for i in get_prereq(n) if i not in g_outputs_and_initializers])
+                names.update(g_prereq)
+                # remove subgraph inputs from the prerequisites since those are local to the subgraph
+                for i in g.input:
+                    if i.name in names:
+                        names.remove(i.name)
+            return names
+
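+        # For example (hypothetical graph): if an If node's then_branch reads a
+        # tensor "state" produced outside the subgraph, get_prereq reports "state"
+        # as a prerequisite of the If node, while the branch's own inputs and
+        # intermediate outputs are filtered out as subgraph-local names.
+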
+        for n in self.out_mp_.graph.node:
+            prereq_for_node[n.output[0]] = get_prereq(n)
+
+        # Topologically sort the nodes; there might be dead nodes, so terminate once all graph outputs are reached.
+        sorted_nodes = []
+        sorted_known_vi = {i.name for i in list(self.out_mp_.graph.input) + list(self.out_mp_.graph.initializer)}
+        if any(o.name in sorted_known_vi for o in self.out_mp_.graph.output):
+            # Loop/Scan will have some graph output in graph inputs, so don't do topological sort
+            sorted_nodes = self.out_mp_.graph.node
+        else:
+            while any(o.name not in sorted_known_vi for o in self.out_mp_.graph.output):
+                old_sorted_nodes_len = len(sorted_nodes)
+                for node in self.out_mp_.graph.node:
+                    if node.output[0] not in sorted_known_vi and all(
+                        i in sorted_known_vi for i in prereq_for_node[node.output[0]] if i
+                    ):
+                        sorted_known_vi.update(node.output)
+                        sorted_nodes.append(node)
+                if old_sorted_nodes_len == len(sorted_nodes) and not all(
+                    o.name in sorted_known_vi for o in self.out_mp_.graph.output
+                ):
+                    raise Exception("Invalid model with cyclic graph")
+
+        for node in sorted_nodes:
+            assert all([i in self.known_vi_ for i in node.input if i])
+            self._onnx_infer_single_node(node)
+            known_aten_op = False
+            if node.op_type in self.dispatcher_:
+                self.dispatcher_[node.op_type](node)
+            elif node.op_type == "ConvTranspose":
+                # onnx shape inference for ops like ConvTranspose may produce an empty shape for symbolic inputs
+                # until symbolic compute is added for them;
+                # mark the output type as UNDEFINED to allow guessing the rank
+                vi = self.known_vi_[node.output[0]]
+                if len(vi.type.tensor_type.shape.dim) == 0:
+                    vi.type.tensor_type.elem_type = onnx.TensorProto.UNDEFINED
+            elif node.op_type == "ATen" and node.domain == "org.pytorch.aten":
+                for attr in node.attribute:
+                    # TODO: Is overload_name needed?
+                    if attr.name == "operator":
+                        aten_op_name = attr.s.decode("utf-8") if isinstance(attr.s, bytes) else attr.s
+                        if aten_op_name in self.aten_op_dispatcher_:
+                            known_aten_op = True
+                            self.aten_op_dispatcher_[aten_op_name](node)
+                        break
+
+            if self.verbose_ > 2:
+                logger.debug(node.op_type + ": " + node.name)
+                for i, name in enumerate(node.input):
+                    logger.debug(
+                        "  Input {}: {} {}".format(i, name, "initializer" if name in self.initializers_ else "")
+                    )
+
+            # onnx automatically merges dims with values, e.g. Mul(['aaa', 'bbb'], [1000, 1]) -> [1000, 'bbb'];
+            # symbolic shape inference needs to apply the merge 'aaa' -> 1000 in this case
+            if node.op_type in {
+                "Add",
+                "Sub",
+                "Mul",
+                "Div",
+                "MatMul",
+                "MatMulInteger",
+                "MatMulInteger16",
+                "Where",
+                "Sum",
+            }:
+                vi = self.known_vi_[node.output[0]]
+                out_rank = len(get_shape_from_type_proto(vi.type))
+                in_shapes = [self._get_shape(node, i) for i in range(len(node.input))]
+                for d in range(out_rank - (2 if node.op_type in {"MatMul", "MatMulInteger", "MatMulInteger16"} else 0)):
+                    in_dims = [s[len(s) - out_rank + d] for s in in_shapes if len(s) + d >= out_rank]
+                    if len(in_dims) > 1:
+                        self._check_merged_dims(in_dims, allow_broadcast=True)
+
+            for i_o in range(len(node.output)):
+                # Special cases:
+                # 1) We do not care about the training-related outputs of SkipLayerNormalization
+                # 2) We do not care about the extraneous constant outputs in RotaryEmbedding because
+                #    the RotaryEmbedding op created during export can be replaced by the RotaryEmbedding
+                #    contrib op
+                if node.op_type in {
+                    "SkipLayerNormalization",
+                    "SkipSimplifiedLayerNormalization",
+                } and i_o in {1, 2}:
+                    continue
+                if node.op_type == "RotaryEmbedding" and len(node.output) > 1:
+                    # Skip symbolic shape inference for RotaryEmbedding functions that have extraneous outputs
+                    # generated by `export_modules_as_functions`
+                    continue
+
+                vi = self.known_vi_[node.output[i_o]]
+                out_type = vi.type
+                out_type_kind = out_type.WhichOneof("value")
+
+                # do not process shapes for non-tensors
+                if out_type_kind not in {"tensor_type", "sparse_tensor_type", None}:
+                    if self.verbose_ > 2:
+                        if out_type_kind == "sequence_type":
+                            seq_cls_type = out_type.sequence_type.elem_type.WhichOneof("value")
+                            if seq_cls_type == "tensor_type":
+                                logger.debug(
+                                    "  {}: sequence of {} {}".format(
+                                        node.output[i_o],
+                                        str(get_shape_from_value_info(vi)),
+                                        onnx.TensorProto.DataType.Name(
+                                            vi.type.sequence_type.elem_type.tensor_type.elem_type
+                                        ),
+                                    )
+                                )
+                            else:
+                                logger.debug(f"  {node.output[i_o]}: sequence of {seq_cls_type}")
+                        else:
+                            logger.debug(f"  {node.output[i_o]}: {out_type_kind}")
+                    continue
+
+                out_shape = get_shape_from_value_info(vi)
+                out_type_undefined = out_type.tensor_type.elem_type == onnx.TensorProto.UNDEFINED
+                if self.verbose_ > 2:
+                    logger.debug(
+                        f"  {node.output[i_o]}: {out_shape!s} {onnx.TensorProto.DataType.Name(vi.type.tensor_type.elem_type)}"
+                    )
+                    if node.output[i_o] in self.sympy_data_:
+                        logger.debug("  Sympy Data: " + str(self.sympy_data_[node.output[i_o]]))
+
+                # onnx >= 1.11.0 uses unk__#index instead of None when a shape dim is uncertain
+                if (
+                    out_shape is not None and (None in out_shape or self._is_shape_contains_none_dim(out_shape))
+                ) or out_type_undefined:
+                    if self.auto_merge_:
+                        if node.op_type in {
+                            "Add",
+                            "Sub",
+                            "Mul",
+                            "Div",
+                            "MatMul",
+                            "MatMulInteger",
+                            "MatMulInteger16",
+                            "Concat",
+                            "Where",
+                            "Sum",
+                            "Equal",
+                            "Less",
+                            "Greater",
+                            "LessOrEqual",
+                            "GreaterOrEqual",
+                            "Min",
+                            "Max",
+                        }:
+                            shapes = [self._get_shape(node, i) for i in range(len(node.input))]
+                            if node.op_type in {
+                                "MatMul",
+                                "MatMulInteger",
+                                "MatMulInteger16",
+                            } and (None in out_shape or self._is_shape_contains_none_dim(out_shape)):
+                                if None in out_shape:
+                                    idx = out_shape.index(None)
+                                else:
+                                    idx = out_shape.index(self._is_shape_contains_none_dim(out_shape))
+                                dim_idx = [len(s) - len(out_shape) + idx for s in shapes]
+                                # auto merge for MatMul is only supported for dims < rank-2 when rank > 2
+                                assert len(shapes[0]) > 2 and dim_idx[0] < len(shapes[0]) - 2
+                                assert len(shapes[1]) > 2 and dim_idx[1] < len(shapes[1]) - 2
+                            elif node.op_type == "Expand":
+                                # auto merge for cases like Expand([min(batch, 1), min(seq, 512)], [batch, seq])
+                                shapes = [
+                                    self._get_shape(node, 0),
+                                    self._get_value(node, 1),
+                                ]
+                            else:
+                                shapes = []
+
+                            if shapes:
+                                for idx in range(len(out_shape)):
+                                    if out_shape[idx] is not None and not self._is_none_dim(out_shape[idx]):
+                                        continue
+                                    # note that the broadcasting rule aligns from right to left;
+                                    # if a tensor has a lower rank (dim_idx[idx] < 0), it broadcasts automatically and needs no merge
+                                    dim_idx = [len(s) - len(out_shape) + idx for s in shapes]
+                                    if dim_idx:
+                                        self._add_suggested_merge(
+                                            [
+                                                s[i] if is_literal(s[i]) else str(s[i])
+                                                for s, i in zip(shapes, dim_idx)
+                                                if i >= 0
+                                            ]
+                                        )
+                                self.run_ = True
+                            else:
+                                self.run_ = False
+                    else:
+                        self.run_ = False
+
+                    # create new dynamic dims for ops not handled by symbolic shape inference
+                    if not self.run_ and node.op_type not in self.dispatcher_ and not known_aten_op:
+                        is_unknown_op = out_type_undefined and (out_shape is None or len(out_shape) == 0)
+                        if is_unknown_op:
+                            # unknown op to ONNX, maybe from a higher opset or another domain;
+                            # only guess the output rank from input 0 when using the guess_output_rank option
+                            out_rank = self._get_shape_rank(node, 0) if self.guess_output_rank_ else -1
+                        else:
+                            # valid ONNX op, but not handled by symbolic shape inference, just assign dynamic shape
+                            out_rank = len(out_shape)
+
+                        if out_rank >= 0:
+                            new_shape = self._new_symbolic_shape(out_rank, node, i_o)
+                            if out_type_undefined:
+                                # guess output data type from input vi if not defined
+                                out_dtype = self.known_vi_[node.input[0]].type.tensor_type.elem_type
+                            else:
+                                # otherwise, use original data type
+                                out_dtype = vi.type.tensor_type.elem_type
+                            vi.CopyFrom(
+                                helper.make_tensor_value_info(
+                                    vi.name,
+                                    out_dtype,
+                                    get_shape_from_sympy_shape(new_shape),
+                                )
+                            )
+
+                            if self.verbose_ > 0:
+                                if is_unknown_op:
+                                    logger.debug(
+                                        f"Possible unknown op: {node.op_type} node: {node.name}, guessing {vi.name} shape"
+                                    )
+                                if self.verbose_ > 2:
+                                    logger.debug(f"  {node.output[i_o]}: {new_shape!s} {vi.type.tensor_type.elem_type}")
+                            self.run_ = True
+                            continue  # continue the inference after guess, no need to stop as no merge is needed
+
+                    if self.verbose_ > 0 or not self.auto_merge_ or out_type_undefined:
+                        logger.debug("Stopping at incomplete shape inference at " + node.op_type + ": " + node.name)
+                        logger.debug("node inputs:")
+                        for i in node.input:
+                            if i in self.known_vi_:
+                                logger.debug(self.known_vi_[i])
+                            else:
+                                logger.debug(f"not in known_vi_ for {i}")
+                        logger.debug("node outputs:")
+                        for o in node.output:
+                            if o in self.known_vi_:
+                                logger.debug(self.known_vi_[o])
+                            else:
+                                logger.debug(f"not in known_vi_ for {o}")
+                        if self.auto_merge_ and not out_type_undefined:
+                            logger.debug("Merging: " + str(self.suggested_merge_))
+                    return False
+
+        self.run_ = False
+        return True
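+
+    # Note: _infer_impl returns False when it stops at an unresolved shape. With
+    # auto_merge enabled it first records suggested dim merges and sets run_, so
+    # infer_shapes() below re-runs the pass; otherwise run_ stays False and
+    # infer_shapes() raises "Incomplete symbolic shape inference".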
+
+    def _update_output_from_vi(self):
+        """Update output attributes using the known value-info dictionary."""
+        for output in self.out_mp_.graph.output:
+            if output.name in self.known_vi_:
+                output.CopyFrom(self.known_vi_[output.name])
+
+    @staticmethod
+    def infer_shapes(in_mp, int_max=2**31 - 1, auto_merge=False, guess_output_rank=False, verbose=0):
+        """Perform symbolic shape inference on an ONNX model using the specified options to handle model shapes
+        efficiently.
+        """
+        onnx_opset = get_opset(in_mp)
+        if (not onnx_opset) or onnx_opset < 7:
+            logger.warning("Only models of onnx opset 7 and above are supported.")
+            return None
+        symbolic_shape_inference = SymbolicShapeInference(int_max, auto_merge, guess_output_rank, verbose)
+        all_shapes_inferred = False
+        symbolic_shape_inference._preprocess(in_mp)
+        while symbolic_shape_inference.run_:
+            all_shapes_inferred = symbolic_shape_inference._infer_impl()
+        symbolic_shape_inference._update_output_from_vi()
+        if not all_shapes_inferred:
+            raise Exception("Incomplete symbolic shape inference")
+        return symbolic_shape_inference.out_mp_
+
+
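+# Typical programmatic use (hypothetical paths), mirroring the CLI below:
+#
+#     import onnx
+#     model = onnx.load("model.onnx")
+#     inferred = SymbolicShapeInference.infer_shapes(model, auto_merge=True)
+#     if inferred is not None:
+#         onnx.save(inferred, "model_with_shapes.onnx")
+
+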
+def parse_arguments():
+    """Parses command-line arguments for ONNX model transformation options."""
+    parser = argparse.ArgumentParser()
+    parser.add_argument("--input", required=True, help="The input model file")
+    parser.add_argument("--output", help="The output model file")
+    parser.add_argument(
+        "--auto_merge",
+        help="Automatically merge symbolic dims when conflicts occur",
+        action="store_true",
+        default=False,
+    )
+    parser.add_argument(
+        "--int_max",
+        help="Maximum value for an integer to be treated as boundless for ops like Slice",
+        type=int,
+        default=2**31 - 1,
+    )
+    parser.add_argument(
+        "--guess_output_rank",
+        help="Guess the output rank to be the same as input 0 for unknown ops",
+        action="store_true",
+        default=False,
+    )
+    parser.add_argument(
+        "--verbose",
+        help="Verbosity of inference logs: 0 = off, 1 = warnings, 3 = detailed",
+        type=int,
+        default=0,
+    )
+    parser.add_argument(
+        "--save_as_external_data",
+        help="Save the ONNX model with external data",
+        action="store_true",
+        default=False,
+    )
+    parser.add_argument(
+        "--all_tensors_to_one_file",
+        help="Save all external data to one file",
+        action="store_true",
+        default=False,
+    )
+    parser.add_argument(
+        "--external_data_location",
+        help="The file location to save the external data",
+        default="./",
+    )
+    parser.add_argument(
+        "--external_data_size_threshold",
+        help="The size threshold (in bytes) above which tensors are saved as external data",
+        type=int,
+        default=1024,
+    )
+    return parser.parse_args()
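+
+# Example invocation (hypothetical script path and file names):
+#
+#     python symbolic_shape_infer.py --input model.onnx --output model_out.onnx --auto_merge --verbose 1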
+
+
+if __name__ == "__main__":
+    args = parse_arguments()
+    logger.info(f"input model: {args.input}")
+    if args.output:
+        logger.info(f"output model: {args.output}")
+    logger.info("Doing symbolic shape inference...")
+    out_mp = SymbolicShapeInference.infer_shapes(
+        onnx.load(args.input),
+        args.int_max,
+        args.auto_merge,
+        args.guess_output_rank,
+        args.verbose,
+    )
+    if args.output and out_mp:
+        if args.save_as_external_data:
+            onnx.save_model(
+                out_mp,
+                args.output,
+                save_as_external_data=True,
+                all_tensors_to_one_file=args.all_tensors_to_one_file,
+                location=args.external_data_location,
+                size_threshold=args.external_data_size_threshold,
+                convert_attribute=False,
+            )
+        else:
+            onnx.save(out_mp, args.output)
+        logger.info("Done!")