onnxslim 0.1.81__py3-none-any.whl → 0.1.83__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (137)
  1. onnxslim/core/optimization/dead_node_elimination.py +84 -3
  2. onnxslim/core/pattern/fusion/convadd.py +21 -1
  3. onnxslim/core/pattern/fusion/convbn.py +21 -4
  4. onnxslim/core/pattern/fusion/convmul.py +23 -5
  5. onnxslim/core/pattern/fusion/padconv.py +5 -0
  6. onnxslim/core/shape_inference/__init__.py +378 -0
  7. onnxslim/core/shape_inference/aten_ops/__init__.py +16 -0
  8. onnxslim/core/shape_inference/aten_ops/argmax.py +47 -0
  9. onnxslim/core/shape_inference/aten_ops/bitwise_or.py +28 -0
  10. onnxslim/core/shape_inference/aten_ops/diagonal.py +52 -0
  11. onnxslim/core/shape_inference/aten_ops/embedding.py +23 -0
  12. onnxslim/core/shape_inference/aten_ops/group_norm.py +41 -0
  13. onnxslim/core/shape_inference/aten_ops/min_max.py +64 -0
  14. onnxslim/core/shape_inference/aten_ops/multinomial.py +39 -0
  15. onnxslim/core/shape_inference/aten_ops/numpy_t.py +22 -0
  16. onnxslim/core/shape_inference/aten_ops/pool2d.py +40 -0
  17. onnxslim/core/shape_inference/aten_ops/unfold.py +44 -0
  18. onnxslim/core/shape_inference/aten_ops/upsample.py +44 -0
  19. onnxslim/core/shape_inference/base.py +111 -0
  20. onnxslim/core/shape_inference/context.py +645 -0
  21. onnxslim/core/shape_inference/contrib_ops/__init__.py +8 -0
  22. onnxslim/core/shape_inference/contrib_ops/attention/__init__.py +15 -0
  23. onnxslim/core/shape_inference/contrib_ops/attention/attention.py +61 -0
  24. onnxslim/core/shape_inference/contrib_ops/attention/decoder_masked_mha.py +37 -0
  25. onnxslim/core/shape_inference/contrib_ops/attention/gated_relative_position_bias.py +35 -0
  26. onnxslim/core/shape_inference/contrib_ops/attention/longformer_attention.py +21 -0
  27. onnxslim/core/shape_inference/contrib_ops/attention/multi_head_attention.py +82 -0
  28. onnxslim/core/shape_inference/contrib_ops/attention/multi_scale_deformable_attn.py +29 -0
  29. onnxslim/core/shape_inference/contrib_ops/attention/packed_attention.py +39 -0
  30. onnxslim/core/shape_inference/contrib_ops/attention/packed_multi_head_attention.py +33 -0
  31. onnxslim/core/shape_inference/contrib_ops/attention/remove_padding.py +41 -0
  32. onnxslim/core/shape_inference/contrib_ops/attention/restore_padding.py +29 -0
  33. onnxslim/core/shape_inference/contrib_ops/misc/__init__.py +15 -0
  34. onnxslim/core/shape_inference/contrib_ops/misc/bias_add.py +21 -0
  35. onnxslim/core/shape_inference/contrib_ops/misc/bias_gelu.py +21 -0
  36. onnxslim/core/shape_inference/contrib_ops/misc/bias_split_gelu.py +30 -0
  37. onnxslim/core/shape_inference/contrib_ops/misc/fast_gelu.py +21 -0
  38. onnxslim/core/shape_inference/contrib_ops/misc/gelu.py +21 -0
  39. onnxslim/core/shape_inference/contrib_ops/misc/gemm_fast_gelu.py +21 -0
  40. onnxslim/core/shape_inference/contrib_ops/misc/gemm_float8.py +21 -0
  41. onnxslim/core/shape_inference/contrib_ops/misc/python_op.py +67 -0
  42. onnxslim/core/shape_inference/contrib_ops/misc/quick_gelu.py +21 -0
  43. onnxslim/core/shape_inference/contrib_ops/misc/rotary_embedding.py +31 -0
  44. onnxslim/core/shape_inference/contrib_ops/normalization/__init__.py +12 -0
  45. onnxslim/core/shape_inference/contrib_ops/normalization/embed_layer_normalization.py +41 -0
  46. onnxslim/core/shape_inference/contrib_ops/normalization/group_norm.py +21 -0
  47. onnxslim/core/shape_inference/contrib_ops/normalization/layer_normalization.py +42 -0
  48. onnxslim/core/shape_inference/contrib_ops/normalization/simplified_layer_normalization.py +23 -0
  49. onnxslim/core/shape_inference/contrib_ops/normalization/skip_group_norm.py +23 -0
  50. onnxslim/core/shape_inference/contrib_ops/normalization/skip_layer_normalization.py +26 -0
  51. onnxslim/core/shape_inference/contrib_ops/normalization/skip_simplified_layer_normalization.py +23 -0
  52. onnxslim/core/shape_inference/registry.py +90 -0
  53. onnxslim/core/shape_inference/standard_ops/__init__.py +11 -0
  54. onnxslim/core/shape_inference/standard_ops/control_flow/__init__.py +8 -0
  55. onnxslim/core/shape_inference/standard_ops/control_flow/if_op.py +43 -0
  56. onnxslim/core/shape_inference/standard_ops/control_flow/loop.py +74 -0
  57. onnxslim/core/shape_inference/standard_ops/control_flow/scan.py +54 -0
  58. onnxslim/core/shape_inference/standard_ops/math/__init__.py +20 -0
  59. onnxslim/core/shape_inference/standard_ops/math/_symbolic_compute.py +34 -0
  60. onnxslim/core/shape_inference/standard_ops/math/add.py +10 -0
  61. onnxslim/core/shape_inference/standard_ops/math/div.py +10 -0
  62. onnxslim/core/shape_inference/standard_ops/math/einsum.py +119 -0
  63. onnxslim/core/shape_inference/standard_ops/math/equal.py +10 -0
  64. onnxslim/core/shape_inference/standard_ops/math/floor.py +10 -0
  65. onnxslim/core/shape_inference/standard_ops/math/matmul.py +21 -0
  66. onnxslim/core/shape_inference/standard_ops/math/matmul_integer.py +23 -0
  67. onnxslim/core/shape_inference/standard_ops/math/max.py +10 -0
  68. onnxslim/core/shape_inference/standard_ops/math/min.py +10 -0
  69. onnxslim/core/shape_inference/standard_ops/math/mul.py +10 -0
  70. onnxslim/core/shape_inference/standard_ops/math/neg.py +10 -0
  71. onnxslim/core/shape_inference/standard_ops/math/reduce_prod.py +27 -0
  72. onnxslim/core/shape_inference/standard_ops/math/reduce_sum.py +53 -0
  73. onnxslim/core/shape_inference/standard_ops/math/sub.py +10 -0
  74. onnxslim/core/shape_inference/standard_ops/math/where.py +10 -0
  75. onnxslim/core/shape_inference/standard_ops/misc/__init__.py +22 -0
  76. onnxslim/core/shape_inference/standard_ops/misc/array_feature_extractor.py +32 -0
  77. onnxslim/core/shape_inference/standard_ops/misc/cast.py +21 -0
  78. onnxslim/core/shape_inference/standard_ops/misc/category_mapper.py +30 -0
  79. onnxslim/core/shape_inference/standard_ops/misc/compress.py +39 -0
  80. onnxslim/core/shape_inference/standard_ops/misc/constant.py +27 -0
  81. onnxslim/core/shape_inference/standard_ops/misc/constant_of_shape.py +45 -0
  82. onnxslim/core/shape_inference/standard_ops/misc/dequantize_linear.py +26 -0
  83. onnxslim/core/shape_inference/standard_ops/misc/non_max_suppression.py +26 -0
  84. onnxslim/core/shape_inference/standard_ops/misc/non_zero.py +26 -0
  85. onnxslim/core/shape_inference/standard_ops/misc/one_hot.py +42 -0
  86. onnxslim/core/shape_inference/standard_ops/misc/quantize_linear.py +29 -0
  87. onnxslim/core/shape_inference/standard_ops/misc/range.py +41 -0
  88. onnxslim/core/shape_inference/standard_ops/misc/relative_position_bias.py +31 -0
  89. onnxslim/core/shape_inference/standard_ops/misc/resize.py +74 -0
  90. onnxslim/core/shape_inference/standard_ops/misc/scatter_elements.py +31 -0
  91. onnxslim/core/shape_inference/standard_ops/misc/softmax_cross_entropy_loss.py +44 -0
  92. onnxslim/core/shape_inference/standard_ops/misc/top_k.py +44 -0
  93. onnxslim/core/shape_inference/standard_ops/nn/__init__.py +18 -0
  94. onnxslim/core/shape_inference/standard_ops/nn/all_reduce.py +9 -0
  95. onnxslim/core/shape_inference/standard_ops/nn/average_pool.py +40 -0
  96. onnxslim/core/shape_inference/standard_ops/nn/batch_normalization.py +26 -0
  97. onnxslim/core/shape_inference/standard_ops/nn/conv.py +33 -0
  98. onnxslim/core/shape_inference/standard_ops/nn/cum_sum.py +9 -0
  99. onnxslim/core/shape_inference/standard_ops/nn/identity.py +9 -0
  100. onnxslim/core/shape_inference/standard_ops/nn/max_pool.py +9 -0
  101. onnxslim/core/shape_inference/standard_ops/nn/memcpy_from_host.py +9 -0
  102. onnxslim/core/shape_inference/standard_ops/nn/memcpy_to_host.py +9 -0
  103. onnxslim/core/shape_inference/standard_ops/nn/moe.py +9 -0
  104. onnxslim/core/shape_inference/standard_ops/nn/nhwc_conv.py +33 -0
  105. onnxslim/core/shape_inference/standard_ops/nn/reciprocal.py +9 -0
  106. onnxslim/core/shape_inference/standard_ops/nn/round.py +9 -0
  107. onnxslim/core/shape_inference/standard_ops/sequence/__init__.py +10 -0
  108. onnxslim/core/shape_inference/standard_ops/sequence/concat_from_sequence.py +40 -0
  109. onnxslim/core/shape_inference/standard_ops/sequence/sequence_at.py +31 -0
  110. onnxslim/core/shape_inference/standard_ops/sequence/sequence_insert.py +26 -0
  111. onnxslim/core/shape_inference/standard_ops/sequence/split_to_sequence.py +24 -0
  112. onnxslim/core/shape_inference/standard_ops/sequence/zip_map.py +36 -0
  113. onnxslim/core/shape_inference/standard_ops/tensor/__init__.py +20 -0
  114. onnxslim/core/shape_inference/standard_ops/tensor/concat.py +62 -0
  115. onnxslim/core/shape_inference/standard_ops/tensor/expand.py +36 -0
  116. onnxslim/core/shape_inference/standard_ops/tensor/gather.py +48 -0
  117. onnxslim/core/shape_inference/standard_ops/tensor/gather_elements.py +31 -0
  118. onnxslim/core/shape_inference/standard_ops/tensor/gather_nd.py +42 -0
  119. onnxslim/core/shape_inference/standard_ops/tensor/pad.py +41 -0
  120. onnxslim/core/shape_inference/standard_ops/tensor/reshape.py +72 -0
  121. onnxslim/core/shape_inference/standard_ops/tensor/shape.py +38 -0
  122. onnxslim/core/shape_inference/standard_ops/tensor/size.py +29 -0
  123. onnxslim/core/shape_inference/standard_ops/tensor/slice.py +183 -0
  124. onnxslim/core/shape_inference/standard_ops/tensor/split.py +57 -0
  125. onnxslim/core/shape_inference/standard_ops/tensor/squeeze.py +69 -0
  126. onnxslim/core/shape_inference/standard_ops/tensor/tile.py +41 -0
  127. onnxslim/core/shape_inference/standard_ops/tensor/transpose.py +30 -0
  128. onnxslim/core/shape_inference/standard_ops/tensor/unsqueeze.py +54 -0
  129. onnxslim/core/shape_inference/utils.py +244 -0
  130. onnxslim/third_party/symbolic_shape_infer.py +73 -3156
  131. onnxslim/utils.py +4 -2
  132. {onnxslim-0.1.81.dist-info → onnxslim-0.1.83.dist-info}/METADATA +21 -11
  133. onnxslim-0.1.83.dist-info/RECORD +187 -0
  134. onnxslim-0.1.81.dist-info/RECORD +0 -63
  135. {onnxslim-0.1.81.dist-info → onnxslim-0.1.83.dist-info}/WHEEL +0 -0
  136. {onnxslim-0.1.81.dist-info → onnxslim-0.1.83.dist-info}/entry_points.txt +0 -0
  137. {onnxslim-0.1.81.dist-info → onnxslim-0.1.83.dist-info}/licenses/LICENSE +0 -0
onnxslim/core/shape_inference/context.py
@@ -0,0 +1,645 @@
+ # Copyright (c) Microsoft Corporation. All rights reserved.
+ # Licensed under the MIT License.
+
+ """InferenceContext class for managing shape inference state."""
+
+ import logging
+
+ import numpy as np
+ import onnx
+ import sympy
+ from onnx import helper, numpy_helper, shape_inference
+
+ from onnxslim.third_party._sympy.functions import FloorDiv
+ from onnxslim.third_party._sympy.printers import PythonPrinter as _PythonPrinter
+
+ from .utils import (
+     as_list,
+     as_scalar,
+     get_attribute,
+     get_elem_type_from_type_proto,
+     get_opset,
+     get_shape_from_sympy_shape,
+     get_shape_from_type_proto,
+     get_shape_from_value_info,
+     handle_negative_axis,
+     is_literal,
+     is_sequence,
+     make_named_value_info,
+ )
+
+ logger = logging.getLogger(__name__)
+
+
+ class PythonPrinter(_PythonPrinter):
+     """Custom Python printer for sympy expressions."""
+
+     def doprint(self, expr: sympy.Expr, *, simplify: bool = True, p: bool = True) -> str:
+         return super().doprint(expr)
+
+
+ pexpr = PythonPrinter().doprint
+
+
+ class InferenceContext:
+     """Context object that encapsulates all state for shape inference.
+
+     This class provides access to:
+     - Known value info (known_vi_)
+     - Symbolic dimensions (symbolic_dims_)
+     - Sympy computed data (sympy_data_)
+     - Initializers (initializers_)
+     - Graph inputs (graph_inputs_)
+     - Model opset and other configuration
+     """
+
+     def __init__(
+         self,
+         out_mp,
+         int_max=2**31 - 1,
+         auto_merge=False,
+         guess_output_rank=False,
+         verbose=0,
+         prefix="",
+     ):
+         """Initialize the inference context.
+
+         Args:
+             out_mp: The ONNX ModelProto being processed.
+             int_max: Maximum value for unbounded integers.
+             auto_merge: Whether to automatically merge conflicting dimensions.
+             guess_output_rank: Whether to guess output rank from input.
+             verbose: Logging verbosity level.
+             prefix: Prefix for generated symbolic dimension names.
+         """
+         self.out_mp_ = out_mp
+         self.int_max_ = int_max
+         self.auto_merge_ = auto_merge
+         self.guess_output_rank_ = guess_output_rank
+         self.verbose_ = verbose
+         self.prefix_ = prefix
+         self.subgraph_id_ = 0
+
+         # State that needs to be initialized
+         self.known_vi_ = {}
+         self.symbolic_dims_ = {}
+         self.sympy_data_ = {}
+         self.initializers_ = {}
+         self.graph_inputs_ = {}
+         self.input_symbols_ = set()
+         self.suggested_merge_ = {}
+         self.run_ = True
+
+     @property
+     def opset(self):
+         """Get the ONNX opset version of the model."""
+         return get_opset(self.out_mp_)
+
+     def preprocess(self):
+         """Initialize data structures from the model."""
+         self.graph_inputs_ = {i.name: i for i in list(self.out_mp_.graph.input)}
+         self.initializers_ = {i.name: i for i in self.out_mp_.graph.initializer}
+         self.known_vi_ = {i.name: i for i in list(self.out_mp_.graph.input)}
+         self.known_vi_.update(
+             {
+                 i.name: helper.make_tensor_value_info(i.name, i.data_type, list(i.dims))
+                 for i in self.out_mp_.graph.initializer
+             }
+         )
+         self.known_vi_.update({i.name: i for i in list(self.out_mp_.graph.output)})
+
+     # Shape retrieval methods
+     def get_shape(self, node, idx):
+         """Retrieve the shape of a tensor from a node's inputs."""
+         name = node.input[idx]
+         if name in self.known_vi_:
+             vi = self.known_vi_[name]
+             return get_shape_from_value_info(vi)
+         else:
+             assert name in self.initializers_
+             return list(self.initializers_[name].dims)
+
+     def try_get_shape(self, node, idx):
+         """Attempts to retrieve the shape of the input node at the specified index."""
+         if idx > len(node.input) - 1:
+             return None
+         name = node.input[idx]
+         if name in self.known_vi_:
+             vi = self.known_vi_[name]
+             return get_shape_from_value_info(vi)
+         if name in self.initializers_:
+             return list(self.initializers_[name].dims)
+         return None
+
+     def get_shape_rank(self, node, idx):
+         """Return the rank (number of dimensions) of the input tensor."""
+         return len(self.get_shape(node, idx))
+
+     def get_sympy_shape(self, node, idx):
+         """Return the symbolic shape dimensions using SymPy."""
+         sympy_shape = []
+         for d in self.get_shape(node, idx):
+             if type(d) == str:
+                 sympy_shape.append(
+                     self.symbolic_dims_[d]
+                     if d in self.symbolic_dims_
+                     else sympy.Symbol(d, integer=True, nonnegative=True)
+                 )
+             else:
+                 assert None is not d
+                 sympy_shape.append(d)
+         return sympy_shape
+
+     # Value retrieval methods
+     def get_value(self, node, idx):
+         """Retrieve the value associated with a node's input index."""
+         name = node.input[idx]
+         assert name in self.sympy_data_ or name in self.initializers_
+         return self.sympy_data_[name] if name in self.sympy_data_ else numpy_helper.to_array(self.initializers_[name])
+
+     def try_get_value(self, node, idx):
+         """Try to retrieve the value associated with a node's input index."""
+         if idx >= len(node.input):
+             return None
+         name = node.input[idx]
+         if name in self.sympy_data_ or name in self.initializers_:
+             return self.get_value(node, idx)
+         return None
+
+     # Symbolic dimension management
+     def new_symbolic_dim(self, prefix, dim):
+         """Create and return a new symbolic dimension."""
+         new_dim = f"{prefix}_d{dim}"
+         if new_dim in self.suggested_merge_:
+             v = self.suggested_merge_[new_dim]
+             new_symbolic_dim = sympy.Integer(int(v)) if is_literal(v) else v
+         else:
+             new_symbolic_dim = sympy.Symbol(new_dim, integer=True, nonnegative=True)
+             self.symbolic_dims_[new_dim] = new_symbolic_dim
+         return new_symbolic_dim
+
+     def new_symbolic_dim_from_output(self, node, out_idx=0, dim=0):
+         """Generates a new symbolic dimension for a given node's output."""
+         return self.new_symbolic_dim(
+             f"{node.op_type}{self.prefix_}_{list(self.out_mp_.graph.node).index(node)}_o{out_idx}_",
+             dim,
+         )
+
+     def new_symbolic_shape(self, rank, node, out_idx=0):
+         """Generate a new symbolic shape for a node output based on its rank."""
+         return [self.new_symbolic_dim_from_output(node, out_idx, i) for i in range(rank)]
+
+     def update_computed_dims(self, new_sympy_shape):
+         """Update dimensions in new_sympy_shape based on suggested merges."""
+         for i, new_dim in enumerate(new_sympy_shape):
+             if not is_literal(new_dim) and type(new_dim) != str:
+                 str_dim = pexpr(new_dim)
+                 if str_dim in self.suggested_merge_:
+                     if not is_literal(self.suggested_merge_[str_dim]):
+                         new_sympy_shape[i] = self.symbolic_dims_[self.suggested_merge_[str_dim]]
+                 elif str_dim not in self.symbolic_dims_:
+                     self.symbolic_dims_[str_dim] = new_dim
+
+     # Dimension merging
+     def add_suggested_merge(self, symbols, apply=False):
+         """Add suggested merges for input symbols."""
+         assert all((type(s) == str and s in self.symbolic_dims_) or is_literal(s) for s in symbols)
+         symbols = set(symbols)
+         for k, v in self.suggested_merge_.items():
+             if k in symbols:
+                 symbols.remove(k)
+                 symbols.add(v)
+         map_to = None
+         # if there is literal, map to it first
+         for s in symbols:
+             if is_literal(s):
+                 map_to = s
+                 break
+         # when no literals, map to input symbolic dims, then existing symbolic dims
+         if map_to is None:
+             for s in symbols:
+                 if s in self.input_symbols_:
+                     map_to = s
+                     break
+         if map_to is None:
+             for s in symbols:
+                 if type(self.symbolic_dims_[s]) == sympy.Symbol:
+                     map_to = s
+                     break
+         # when nothing to map to, use the shorter one
+         if map_to is None:
+             if self.verbose_ > 0:
+                 logger.warning(f"Potential unsafe merge between symbolic expressions: ({','.join(symbols)})")
+             symbols_list = list(symbols)
+             lens = [len(s) for s in symbols_list]
+             map_to = symbols_list[lens.index(min(lens))]
+             symbols.remove(map_to)
+
+         for s in symbols:
+             if s == map_to:
+                 continue
+             if is_literal(map_to) and is_literal(s):
+                 assert int(map_to) == int(s)
+             self.suggested_merge_[s] = int(map_to) if is_literal(map_to) else map_to
+             for k, v in self.suggested_merge_.items():
+                 if v == s:
+                     self.suggested_merge_[k] = map_to
+         if apply and self.auto_merge_:
+             self.apply_suggested_merge()
+
+     def apply_suggested_merge(self, graph_input_only=False):
+         """Applies suggested merges to graph dimensions."""
+         if not self.suggested_merge_:
+             return
+         for i in list(self.out_mp_.graph.input) + ([] if graph_input_only else list(self.out_mp_.graph.value_info)):
+             for d in i.type.tensor_type.shape.dim:
+                 if d.dim_param in self.suggested_merge_:
+                     v = self.suggested_merge_[d.dim_param]
+                     if is_literal(v):
+                         d.dim_value = int(v)
+                     else:
+                         d.dim_param = v
+
+     def merge_symbols(self, dims):
+         """Merge dimension symbols, handling automatic merging and validation."""
+         if any(type(d) != str for d in dims):
+             if not self.auto_merge_:
+                 return None
+             unique_dims = list(set(dims))
+             is_int = [is_literal(d) for d in unique_dims]
+             assert sum(is_int) <= 1
+             if sum(is_int) == 1:
+                 int_dim = is_int.index(1)
+                 if self.verbose_ > 0:
+                     logger.debug(
+                         f"dim {unique_dims[:int_dim] + unique_dims[int_dim + 1 :]} has been merged with value {unique_dims[int_dim]}"
+                     )
+                 self.check_merged_dims(unique_dims, allow_broadcast=False)
+                 return unique_dims[int_dim]
+             else:
+                 if self.verbose_ > 0:
+                     logger.debug(f"dim {unique_dims[1:]} has been merged with dim {unique_dims[0]}")
+                 return dims[0]
+         if all(d == dims[0] for d in dims):
+             return dims[0]
+         merged = [self.suggested_merge_[d] if d in self.suggested_merge_ else d for d in dims]
+         if all(d == merged[0] for d in merged):
+             assert merged[0] in self.symbolic_dims_
+             return merged[0]
+         else:
+             return None
+
+     def check_merged_dims(self, dims, allow_broadcast=True):
+         """Checks merged dimensions for consistency."""
+         if allow_broadcast:
+             dims = [d for d in dims if not (is_literal(d) and int(d) <= 1)]
+         if any(d != dims[0] for d in dims):
+             self.add_suggested_merge(dims, apply=True)
+
+     # Broadcasting
+     def broadcast_shapes(self, shape1, shape2):
+         """Broadcast two shapes from right to left."""
+         new_shape = []
+         rank1 = len(shape1)
+         rank2 = len(shape2)
+         new_rank = max(rank1, rank2)
+         for i in range(new_rank):
+             dim1 = shape1[rank1 - 1 - i] if i < rank1 else 1
+             dim2 = shape2[rank2 - 1 - i] if i < rank2 else 1
+             if dim1 in [1, dim2]:
+                 new_dim = dim2
+             elif dim2 == 1:
+                 new_dim = dim1
+             else:
+                 new_dim = self.merge_symbols([dim1, dim2])
+                 if not new_dim:
+                     if self.auto_merge_:
+                         self.add_suggested_merge([dim1, dim2], apply=True)
+                     else:
+                         logger.warning(f"unsupported broadcast between {dim1!s} {dim2!s}")
+             new_shape = [new_dim, *new_shape]
+         return new_shape
+
+     # Shape computations
+     def compute_conv_pool_shape(self, node, channels_last=False):
+         """Calculate the output shape of a convolutional or pooling layer."""
+         sympy_shape = self.get_sympy_shape(node, 0)
+         if len(node.input) > 1:
+             W_shape = self.get_sympy_shape(node, 1)
+             rank = len(W_shape) - 2
+             kernel_shape = W_shape[-rank - 1 : -1] if channels_last else W_shape[-rank:]
+             sympy_shape[3 if channels_last else 1] = W_shape[0]
+         else:
+             W_shape = None
+             kernel_shape = get_attribute(node, "kernel_shape")
+             rank = len(kernel_shape)
+
+         assert len(sympy_shape) == rank + 2
+
+         spatial_shape = sympy_shape[-rank - 1 : -1] if channels_last else sympy_shape[-rank:]
+         is_symbolic_dims = [not is_literal(i) for i in spatial_shape]
+
+         if not any(is_symbolic_dims):
+             shape = get_shape_from_value_info(self.known_vi_[node.output[0]])
+             if len(shape) > 0:
+                 assert len(sympy_shape) == len(shape)
+                 if channels_last:
+                     sympy_shape[-rank - 1 : -1] = [sympy.Integer(d) for d in shape[-rank - 1 : -1]]
+                 else:
+                     sympy_shape[-rank:] = [sympy.Integer(d) for d in shape[-rank:]]
+                 return sympy_shape
+
+         dilations = get_attribute(node, "dilations", [1] * rank)
+         strides = get_attribute(node, "strides", [1] * rank)
+         effective_kernel_shape = [(k - 1) * d + 1 for k, d in zip(kernel_shape, dilations)]
+         pads = get_attribute(node, "pads")
+         if pads is None:
+             pads = [0] * (2 * rank)
+             auto_pad = get_attribute(node, "auto_pad", b"NOTSET").decode("utf-8")
+             if auto_pad not in {"VALID", "NOTSET"}:
+                 try:
+                     residual = [sympy.Mod(d, s) for d, s in zip(sympy_shape[-rank:], strides)]
+                     total_pads = [
+                         max(0, (k - s) if r == 0 else (k - r))
+                         for k, s, r in zip(effective_kernel_shape, strides, residual)
+                     ]
+                 except TypeError:
+                     total_pads = [max(0, (k - s)) for k, s in zip(effective_kernel_shape, strides)]
+             elif auto_pad == "VALID":
+                 total_pads = []
+             else:
+                 total_pads = [0] * rank
+         else:
+             assert len(pads) == 2 * rank
+             total_pads = [p1 + p2 for p1, p2 in zip(pads[:rank], pads[rank:])]
+
+         ceil_mode = get_attribute(node, "ceil_mode", 0)
+         for i in range(rank):
+             effective_input_size = sympy_shape[-rank + i + (-1 if channels_last else 0)]
+             if len(total_pads) > 0:
+                 effective_input_size = effective_input_size + total_pads[i]
+             if ceil_mode:
+                 strided_kernel_positions = sympy.ceiling(
+                     (effective_input_size - effective_kernel_shape[i]) / strides[i]
+                 )
+             else:
+                 strided_kernel_positions = FloorDiv((effective_input_size - effective_kernel_shape[i]), strides[i])
+             sympy_shape[-rank + i + (-1 if channels_last else 0)] = strided_kernel_positions + 1
+         return sympy_shape
+
+     def compute_matmul_shape(self, node, output_dtype=None):
+         """Compute the output shape for a matrix multiplication operation."""
+         lhs_shape = self.get_shape(node, 0)
+         rhs_shape = self.get_shape(node, 1)
+         lhs_rank = len(lhs_shape)
+         rhs_rank = len(rhs_shape)
+         lhs_reduce_dim = 0
+         rhs_reduce_dim = 0
+         assert lhs_rank > 0 and rhs_rank > 0
+         if lhs_rank == 1 and rhs_rank == 1:
+             new_shape = []
+         elif lhs_rank == 1:
+             rhs_reduce_dim = -2
+             new_shape = [*rhs_shape[:rhs_reduce_dim], rhs_shape[-1]]
+         elif rhs_rank == 1:
+             lhs_reduce_dim = -1
+             new_shape = lhs_shape[:lhs_reduce_dim]
+         else:
+             lhs_reduce_dim = -1
+             rhs_reduce_dim = -2
+             new_shape = [
+                 *self.broadcast_shapes(lhs_shape[:-2], rhs_shape[:-2]),
+                 lhs_shape[-2],
+                 rhs_shape[-1],
+             ]
+         self.check_merged_dims(
+             [lhs_shape[lhs_reduce_dim], rhs_shape[rhs_reduce_dim]],
+             allow_broadcast=False,
+         )
+         if output_dtype is None:
+             output_dtype = self.known_vi_[node.input[0]].type.tensor_type.elem_type
+         vi = self.known_vi_[node.output[0]]
+         vi.CopyFrom(helper.make_tensor_value_info(node.output[0], output_dtype, new_shape))
+
+     # Value operations
+     def get_int_or_float_values(self, node, broadcast=False, allow_float_values=False):
+         """Extracts integer or float values from a node."""
+
+         def int_or_float(value, allow_float_values):
+             return value if allow_float_values and value % 1 != 0 else int(value)
+
+         values = [self.try_get_value(node, i) for i in range(len(node.input))]
+         if all(v is not None for v in values):
+             for i, v in enumerate(values):
+                 if type(v) != np.ndarray:
+                     continue
+                 if len(v.shape) > 1:
+                     new_v = None
+                 elif len(v.shape) == 0:
+                     new_v = int_or_float(v.item(), allow_float_values)
+                 else:
+                     assert len(v.shape) == 1
+                     new_v = [int_or_float(vv, allow_float_values) for vv in v]
+                 values[i] = new_v
+         values_len = [len(v) if isinstance(v, list) else 0 for v in values]
+         max_len = max(values_len)
+         if max_len >= 1 and broadcast:
+             for i, v in enumerate(values):
+                 if v is None:
+                     continue
+                 if isinstance(v, list):
+                     if len(v) < max_len:
+                         values[i] = v * max_len
+                     else:
+                         assert len(v) == max_len
+                 else:
+                     values[i] = [v] * max_len
+         return values
+
+     def compute_on_sympy_data(self, node, op_func):
+         """Calculate the result using Sympy data and a specified operation function."""
+         assert len(node.output) == 1
+
+         if node.op_type in {"Mul", "Div"}:
+             values = self.get_int_or_float_values(node, broadcast=True, allow_float_values=True)
+         else:
+             values = self.get_int_or_float_values(node, broadcast=True)
+         if all(v is not None for v in values):
+             is_list = [isinstance(v, list) for v in values]
+             as_list = any(is_list)
+             if as_list:
+                 self.sympy_data_[node.output[0]] = [op_func(vs) for vs in zip(*values)]
+             else:
+                 self.sympy_data_[node.output[0]] = op_func(values)
+
+     def pass_on_sympy_data(self, node):
+         """Pass Sympy data through a node."""
+         assert len(node.input) == 1 or node.op_type in {
+             "Reshape",
+             "Unsqueeze",
+             "Squeeze",
+         }
+         self.compute_on_sympy_data(node, lambda x: x[0])
+
+     # Shape propagation
+     def pass_on_shape_and_type(self, node):
+         """Propagates the shape and type information from input to output."""
+         vi = self.known_vi_[node.output[0]]
+         vi.CopyFrom(
+             helper.make_tensor_value_info(
+                 node.output[0],
+                 get_elem_type_from_type_proto(self.known_vi_[node.input[0]].type),
+                 self.get_shape(node, 0),
+             )
+         )
+
+     def propagate_shape_and_type(self, node, input_index=0, output_index=0):
+         """Propagates the shape and type information from input to output tensors."""
+         shape = self.get_shape(node, input_index)
+         output_dtype = self.known_vi_[node.input[input_index]].type.tensor_type.elem_type
+         vi = self.known_vi_[node.output[output_index]]
+         vi.CopyFrom(helper.make_tensor_value_info(node.output[output_index], output_dtype, shape))
+
+     def fuse_tensor_type(self, node, out_idx, dst_type, src_type):
+         """Update dst_tensor_type to be compatible with src_tensor_type."""
+         dst_tensor_type = (
+             dst_type.sequence_type.elem_type.tensor_type if is_sequence(dst_type) else dst_type.tensor_type
+         )
+         src_tensor_type = (
+             src_type.sequence_type.elem_type.tensor_type if is_sequence(src_type) else src_type.tensor_type
+         )
+         if dst_tensor_type.elem_type != src_tensor_type.elem_type:
+             node_id = node.name or node.op_type
+             raise ValueError(
+                 f"For node {node_id}, dst_tensor_type.elem_type != src_tensor_type.elem_type: "
+                 f"{onnx.onnx_pb.TensorProto.DataType.Name(dst_tensor_type.elem_type)} vs "
+                 f"{onnx.onnx_pb.TensorProto.DataType.Name(src_tensor_type.elem_type)}"
+             )
+         if dst_tensor_type.HasField("shape"):
+             for di, ds in enumerate(zip(dst_tensor_type.shape.dim, src_tensor_type.shape.dim)):
+                 if ds[0] != ds[1]:
+                     new_dim = onnx.TensorShapeProto.Dimension()
+                     if not is_sequence(dst_type):
+                         new_dim.dim_param = str(self.new_symbolic_dim_from_output(node, out_idx, di))
+                     dst_tensor_type.shape.dim[di].CopyFrom(new_dim)
+         else:
+             dst_tensor_type.CopyFrom(src_tensor_type)
+
+     # ONNX inference helpers
+     def onnx_infer_single_node(self, node):
+         """Performs ONNX shape inference for a single node."""
+         skip_infer = node.op_type in {
+             "If",
+             "Loop",
+             "Scan",
+             "SplitToSequence",
+             "ZipMap",
+             "Attention",
+             "BiasGelu",
+             "EmbedLayerNormalization",
+             "FastGelu",
+             "Gelu",
+             "GemmFastGelu",
+             "LayerNormalization",
+             "LongformerAttention",
+             "DequantizeLinear",
+             "QuantizeLinear",
+             "RelativePositionBias",
+             "RemovePadding",
+             "RestorePadding",
+             "SimplifiedLayerNormalization",
+             "SkipLayerNormalization",
+             "SkipSimplifiedLayerNormalization",
+             "PackedAttention",
+             "PythonOp",
+             "MultiHeadAttention",
+             "GroupNorm",
+             "SkipGroupNorm",
+             "BiasSplitGelu",
+             "BiasAdd",
+             "NhwcConv",
+             "QuickGelu",
+             "RotaryEmbedding",
+         }
+
+         if not skip_infer:
+             initializers = []
+             if (get_opset(self.out_mp_) >= 9) and (
+                 node.op_type == "Unsqueeze"
+                 or node.op_type == "ReduceMax"
+                 or node.op_type == "ReduceMean"
+                 or node.op_type == "DFT"
+                 or node.op_type == "ReduceL2"
+                 or node.op_type == "ReduceMin"
+             ):
+                 initializers = [
+                     self.initializers_[name]
+                     for name in node.input
+                     if (name in self.initializers_ and name not in self.graph_inputs_)
+                 ]
+
+             if (
+                 node.op_type
+                 in {
+                     "Add",
+                     "Sub",
+                     "Mul",
+                     "Div",
+                     "MatMul",
+                     "MatMulInteger",
+                     "MatMulInteger16",
+                     "Where",
+                     "Sum",
+                 }
+                 and node.output[0] in self.known_vi_
+             ):
+                 vi = self.known_vi_[node.output[0]]
+                 out_rank = len(get_shape_from_type_proto(vi.type))
+                 in_shapes = [self.get_shape(node, i) for i in range(len(node.input))]
+                 for d in range(out_rank - (2 if node.op_type in {"MatMul", "MatMulInteger", "MatMulInteger16"} else 0)):
+                     in_dims = [s[len(s) - out_rank + d] for s in in_shapes if len(s) + d >= out_rank]
+                     if len(in_dims) > 1:
+                         self.check_merged_dims(in_dims, allow_broadcast=True)
+
+             tmp_graph = helper.make_graph(
+                 [node],
+                 "tmp",
+                 [self.known_vi_[i] for i in node.input if i],
+                 [make_named_value_info(i) for i in node.output],
+                 initializers,
+             )
+
+             kwargs = {}
+             kwargs["opset_imports"] = self.out_mp_.opset_import
+             kwargs["ir_version"] = self.out_mp_.ir_version
+
+             model = helper.make_model(tmp_graph, **kwargs)
+             model = shape_inference.infer_shapes(model)
+
+         for i_o in range(len(node.output)):
+             o = node.output[i_o]
+             if o:
+                 out = model.graph.output[i_o]
+                 if not out.type.WhichOneof("value") and o in self.known_vi_:
+                     continue
+
+                 vi = self.out_mp_.graph.value_info.add()
+                 if not skip_infer:
+                     vi.CopyFrom(out)
+                 else:
+                     vi.name = o
+                 self.known_vi_[o] = vi
+
+     # Helper methods for checking none dims
+     def is_none_dim(self, dim_value):
+         """Check if dimension value is unknown."""
+         if type(dim_value) != str:
+             return False
+         return dim_value not in self.symbolic_dims_ if "unk__" in dim_value else False
+
+     def is_shape_contains_none_dim(self, out_shape):
+         """Check if any dimension in the given shape is unknown."""
+         for out in out_shape:
+             if self.is_none_dim(out):
+                 return out
+         return None
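The new InferenceContext gathers the state (known value infos, symbolic dimensions, sympy-computed values, initializers, graph inputs) that the per-op handlers under standard_ops/, contrib_ops/ and aten_ops/ share. The snippet below is only a rough illustration of driving the context directly; the actual entry point is the slimmed-down onnxslim/third_party/symbolic_shape_infer.py, which is not shown in this diff, and the model path and node choice are placeholders.

    import onnx

    from onnxslim.core.shape_inference.context import InferenceContext

    # Placeholder model path; any ONNX model works for this sketch.
    model = onnx.load("model.onnx")

    ctx = InferenceContext(model, auto_merge=True)
    ctx.preprocess()  # populate known_vi_, initializers_ and graph_inputs_ from the graph

    node = model.graph.node[0]
    print(ctx.opset)                    # opset version read from the ModelProto
    print(ctx.try_get_shape(node, 0))   # shape of the node's first input, or None if unknown
    print(ctx.broadcast_shapes(["batch", 1, 8], [4, 8]))  # right-to-left -> ['batch', 4, 8]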
onnxslim/core/shape_inference/contrib_ops/__init__.py
@@ -0,0 +1,8 @@
+ # Copyright (c) Microsoft Corporation. All rights reserved.
+ # Licensed under the MIT License.
+
+ """Contrib/Custom ONNX operator shape handlers."""
+
+ from . import attention
+ from . import normalization
+ from . import misc
onnxslim/core/shape_inference/contrib_ops/attention/__init__.py
@@ -0,0 +1,15 @@
+ # Copyright (c) Microsoft Corporation. All rights reserved.
+ # Licensed under the MIT License.
+
+ """Attention-related contrib operator shape handlers."""
+
+ from . import attention
+ from . import multi_head_attention
+ from . import packed_attention
+ from . import packed_multi_head_attention
+ from . import gated_relative_position_bias
+ from . import multi_scale_deformable_attn
+ from . import longformer_attention
+ from . import decoder_masked_mha
+ from . import remove_padding
+ from . import restore_padding