onnxslim 0.1.81__py3-none-any.whl → 0.1.83__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (137)
  1. onnxslim/core/optimization/dead_node_elimination.py +84 -3
  2. onnxslim/core/pattern/fusion/convadd.py +21 -1
  3. onnxslim/core/pattern/fusion/convbn.py +21 -4
  4. onnxslim/core/pattern/fusion/convmul.py +23 -5
  5. onnxslim/core/pattern/fusion/padconv.py +5 -0
  6. onnxslim/core/shape_inference/__init__.py +378 -0
  7. onnxslim/core/shape_inference/aten_ops/__init__.py +16 -0
  8. onnxslim/core/shape_inference/aten_ops/argmax.py +47 -0
  9. onnxslim/core/shape_inference/aten_ops/bitwise_or.py +28 -0
  10. onnxslim/core/shape_inference/aten_ops/diagonal.py +52 -0
  11. onnxslim/core/shape_inference/aten_ops/embedding.py +23 -0
  12. onnxslim/core/shape_inference/aten_ops/group_norm.py +41 -0
  13. onnxslim/core/shape_inference/aten_ops/min_max.py +64 -0
  14. onnxslim/core/shape_inference/aten_ops/multinomial.py +39 -0
  15. onnxslim/core/shape_inference/aten_ops/numpy_t.py +22 -0
  16. onnxslim/core/shape_inference/aten_ops/pool2d.py +40 -0
  17. onnxslim/core/shape_inference/aten_ops/unfold.py +44 -0
  18. onnxslim/core/shape_inference/aten_ops/upsample.py +44 -0
  19. onnxslim/core/shape_inference/base.py +111 -0
  20. onnxslim/core/shape_inference/context.py +645 -0
  21. onnxslim/core/shape_inference/contrib_ops/__init__.py +8 -0
  22. onnxslim/core/shape_inference/contrib_ops/attention/__init__.py +15 -0
  23. onnxslim/core/shape_inference/contrib_ops/attention/attention.py +61 -0
  24. onnxslim/core/shape_inference/contrib_ops/attention/decoder_masked_mha.py +37 -0
  25. onnxslim/core/shape_inference/contrib_ops/attention/gated_relative_position_bias.py +35 -0
  26. onnxslim/core/shape_inference/contrib_ops/attention/longformer_attention.py +21 -0
  27. onnxslim/core/shape_inference/contrib_ops/attention/multi_head_attention.py +82 -0
  28. onnxslim/core/shape_inference/contrib_ops/attention/multi_scale_deformable_attn.py +29 -0
  29. onnxslim/core/shape_inference/contrib_ops/attention/packed_attention.py +39 -0
  30. onnxslim/core/shape_inference/contrib_ops/attention/packed_multi_head_attention.py +33 -0
  31. onnxslim/core/shape_inference/contrib_ops/attention/remove_padding.py +41 -0
  32. onnxslim/core/shape_inference/contrib_ops/attention/restore_padding.py +29 -0
  33. onnxslim/core/shape_inference/contrib_ops/misc/__init__.py +15 -0
  34. onnxslim/core/shape_inference/contrib_ops/misc/bias_add.py +21 -0
  35. onnxslim/core/shape_inference/contrib_ops/misc/bias_gelu.py +21 -0
  36. onnxslim/core/shape_inference/contrib_ops/misc/bias_split_gelu.py +30 -0
  37. onnxslim/core/shape_inference/contrib_ops/misc/fast_gelu.py +21 -0
  38. onnxslim/core/shape_inference/contrib_ops/misc/gelu.py +21 -0
  39. onnxslim/core/shape_inference/contrib_ops/misc/gemm_fast_gelu.py +21 -0
  40. onnxslim/core/shape_inference/contrib_ops/misc/gemm_float8.py +21 -0
  41. onnxslim/core/shape_inference/contrib_ops/misc/python_op.py +67 -0
  42. onnxslim/core/shape_inference/contrib_ops/misc/quick_gelu.py +21 -0
  43. onnxslim/core/shape_inference/contrib_ops/misc/rotary_embedding.py +31 -0
  44. onnxslim/core/shape_inference/contrib_ops/normalization/__init__.py +12 -0
  45. onnxslim/core/shape_inference/contrib_ops/normalization/embed_layer_normalization.py +41 -0
  46. onnxslim/core/shape_inference/contrib_ops/normalization/group_norm.py +21 -0
  47. onnxslim/core/shape_inference/contrib_ops/normalization/layer_normalization.py +42 -0
  48. onnxslim/core/shape_inference/contrib_ops/normalization/simplified_layer_normalization.py +23 -0
  49. onnxslim/core/shape_inference/contrib_ops/normalization/skip_group_norm.py +23 -0
  50. onnxslim/core/shape_inference/contrib_ops/normalization/skip_layer_normalization.py +26 -0
  51. onnxslim/core/shape_inference/contrib_ops/normalization/skip_simplified_layer_normalization.py +23 -0
  52. onnxslim/core/shape_inference/registry.py +90 -0
  53. onnxslim/core/shape_inference/standard_ops/__init__.py +11 -0
  54. onnxslim/core/shape_inference/standard_ops/control_flow/__init__.py +8 -0
  55. onnxslim/core/shape_inference/standard_ops/control_flow/if_op.py +43 -0
  56. onnxslim/core/shape_inference/standard_ops/control_flow/loop.py +74 -0
  57. onnxslim/core/shape_inference/standard_ops/control_flow/scan.py +54 -0
  58. onnxslim/core/shape_inference/standard_ops/math/__init__.py +20 -0
  59. onnxslim/core/shape_inference/standard_ops/math/_symbolic_compute.py +34 -0
  60. onnxslim/core/shape_inference/standard_ops/math/add.py +10 -0
  61. onnxslim/core/shape_inference/standard_ops/math/div.py +10 -0
  62. onnxslim/core/shape_inference/standard_ops/math/einsum.py +119 -0
  63. onnxslim/core/shape_inference/standard_ops/math/equal.py +10 -0
  64. onnxslim/core/shape_inference/standard_ops/math/floor.py +10 -0
  65. onnxslim/core/shape_inference/standard_ops/math/matmul.py +21 -0
  66. onnxslim/core/shape_inference/standard_ops/math/matmul_integer.py +23 -0
  67. onnxslim/core/shape_inference/standard_ops/math/max.py +10 -0
  68. onnxslim/core/shape_inference/standard_ops/math/min.py +10 -0
  69. onnxslim/core/shape_inference/standard_ops/math/mul.py +10 -0
  70. onnxslim/core/shape_inference/standard_ops/math/neg.py +10 -0
  71. onnxslim/core/shape_inference/standard_ops/math/reduce_prod.py +27 -0
  72. onnxslim/core/shape_inference/standard_ops/math/reduce_sum.py +53 -0
  73. onnxslim/core/shape_inference/standard_ops/math/sub.py +10 -0
  74. onnxslim/core/shape_inference/standard_ops/math/where.py +10 -0
  75. onnxslim/core/shape_inference/standard_ops/misc/__init__.py +22 -0
  76. onnxslim/core/shape_inference/standard_ops/misc/array_feature_extractor.py +32 -0
  77. onnxslim/core/shape_inference/standard_ops/misc/cast.py +21 -0
  78. onnxslim/core/shape_inference/standard_ops/misc/category_mapper.py +30 -0
  79. onnxslim/core/shape_inference/standard_ops/misc/compress.py +39 -0
  80. onnxslim/core/shape_inference/standard_ops/misc/constant.py +27 -0
  81. onnxslim/core/shape_inference/standard_ops/misc/constant_of_shape.py +45 -0
  82. onnxslim/core/shape_inference/standard_ops/misc/dequantize_linear.py +26 -0
  83. onnxslim/core/shape_inference/standard_ops/misc/non_max_suppression.py +26 -0
  84. onnxslim/core/shape_inference/standard_ops/misc/non_zero.py +26 -0
  85. onnxslim/core/shape_inference/standard_ops/misc/one_hot.py +42 -0
  86. onnxslim/core/shape_inference/standard_ops/misc/quantize_linear.py +29 -0
  87. onnxslim/core/shape_inference/standard_ops/misc/range.py +41 -0
  88. onnxslim/core/shape_inference/standard_ops/misc/relative_position_bias.py +31 -0
  89. onnxslim/core/shape_inference/standard_ops/misc/resize.py +74 -0
  90. onnxslim/core/shape_inference/standard_ops/misc/scatter_elements.py +31 -0
  91. onnxslim/core/shape_inference/standard_ops/misc/softmax_cross_entropy_loss.py +44 -0
  92. onnxslim/core/shape_inference/standard_ops/misc/top_k.py +44 -0
  93. onnxslim/core/shape_inference/standard_ops/nn/__init__.py +18 -0
  94. onnxslim/core/shape_inference/standard_ops/nn/all_reduce.py +9 -0
  95. onnxslim/core/shape_inference/standard_ops/nn/average_pool.py +40 -0
  96. onnxslim/core/shape_inference/standard_ops/nn/batch_normalization.py +26 -0
  97. onnxslim/core/shape_inference/standard_ops/nn/conv.py +33 -0
  98. onnxslim/core/shape_inference/standard_ops/nn/cum_sum.py +9 -0
  99. onnxslim/core/shape_inference/standard_ops/nn/identity.py +9 -0
  100. onnxslim/core/shape_inference/standard_ops/nn/max_pool.py +9 -0
  101. onnxslim/core/shape_inference/standard_ops/nn/memcpy_from_host.py +9 -0
  102. onnxslim/core/shape_inference/standard_ops/nn/memcpy_to_host.py +9 -0
  103. onnxslim/core/shape_inference/standard_ops/nn/moe.py +9 -0
  104. onnxslim/core/shape_inference/standard_ops/nn/nhwc_conv.py +33 -0
  105. onnxslim/core/shape_inference/standard_ops/nn/reciprocal.py +9 -0
  106. onnxslim/core/shape_inference/standard_ops/nn/round.py +9 -0
  107. onnxslim/core/shape_inference/standard_ops/sequence/__init__.py +10 -0
  108. onnxslim/core/shape_inference/standard_ops/sequence/concat_from_sequence.py +40 -0
  109. onnxslim/core/shape_inference/standard_ops/sequence/sequence_at.py +31 -0
  110. onnxslim/core/shape_inference/standard_ops/sequence/sequence_insert.py +26 -0
  111. onnxslim/core/shape_inference/standard_ops/sequence/split_to_sequence.py +24 -0
  112. onnxslim/core/shape_inference/standard_ops/sequence/zip_map.py +36 -0
  113. onnxslim/core/shape_inference/standard_ops/tensor/__init__.py +20 -0
  114. onnxslim/core/shape_inference/standard_ops/tensor/concat.py +62 -0
  115. onnxslim/core/shape_inference/standard_ops/tensor/expand.py +36 -0
  116. onnxslim/core/shape_inference/standard_ops/tensor/gather.py +48 -0
  117. onnxslim/core/shape_inference/standard_ops/tensor/gather_elements.py +31 -0
  118. onnxslim/core/shape_inference/standard_ops/tensor/gather_nd.py +42 -0
  119. onnxslim/core/shape_inference/standard_ops/tensor/pad.py +41 -0
  120. onnxslim/core/shape_inference/standard_ops/tensor/reshape.py +72 -0
  121. onnxslim/core/shape_inference/standard_ops/tensor/shape.py +38 -0
  122. onnxslim/core/shape_inference/standard_ops/tensor/size.py +29 -0
  123. onnxslim/core/shape_inference/standard_ops/tensor/slice.py +183 -0
  124. onnxslim/core/shape_inference/standard_ops/tensor/split.py +57 -0
  125. onnxslim/core/shape_inference/standard_ops/tensor/squeeze.py +69 -0
  126. onnxslim/core/shape_inference/standard_ops/tensor/tile.py +41 -0
  127. onnxslim/core/shape_inference/standard_ops/tensor/transpose.py +30 -0
  128. onnxslim/core/shape_inference/standard_ops/tensor/unsqueeze.py +54 -0
  129. onnxslim/core/shape_inference/utils.py +244 -0
  130. onnxslim/third_party/symbolic_shape_infer.py +73 -3156
  131. onnxslim/utils.py +4 -2
  132. {onnxslim-0.1.81.dist-info → onnxslim-0.1.83.dist-info}/METADATA +21 -11
  133. onnxslim-0.1.83.dist-info/RECORD +187 -0
  134. onnxslim-0.1.81.dist-info/RECORD +0 -63
  135. {onnxslim-0.1.81.dist-info → onnxslim-0.1.83.dist-info}/WHEEL +0 -0
  136. {onnxslim-0.1.81.dist-info → onnxslim-0.1.83.dist-info}/entry_points.txt +0 -0
  137. {onnxslim-0.1.81.dist-info → onnxslim-0.1.83.dist-info}/licenses/LICENSE +0 -0
@@ -54,9 +54,16 @@ def dead_node_elimination(graph, is_subgraph=False):
                 node.inputs.insert(1, reshape_const)
                 logger.debug(f"replacing {node.op} op: {node.name}")
         elif node.op == "Slice":
-            if node.inputs[0].shape and node.outputs[0].shape and node.inputs[0].shape == node.outputs[0].shape:
-                node.erase()
-                logger.debug(f"removing {node.op} op: {node.name}")
+            if (node.inputs[0].shape and node.outputs[0].shape
+                    and node.inputs[0].shape == node.outputs[0].shape
+                    and all(isinstance(item, int) for item in node.inputs[0].shape)):
+
+                # Check if slice is a no-op by analyzing parameters directly
+                # Slice inputs: data, starts, ends, [axes], [steps]
+                if is_noop_slice(node):
+                    node.erase()
+                    logger.debug(f"removing {node.op} op: {node.name}")
+
         elif node.op == "Mul":
             if (isinstance(node.inputs[1], Constant) and isinstance(node.inputs[0], Variable)) or (
                 isinstance(node.inputs[0], Constant) and isinstance(node.inputs[1], Variable)
@@ -153,3 +160,77 @@ def get_constant_variable(node, return_idx=False):
     for idx, input in enumerate(list(node.inputs)):
         if isinstance(input, Constant):
             return (idx, input) if return_idx else input
+
+
+def is_noop_slice(node):
+    """Check if a Slice node is a no-op by analyzing its parameters directly.
+
+    A Slice is a no-op when it extracts the entire tensor, i.e., for each sliced axis:
+    - start == 0 (or an equivalent negative index)
+    - end >= dim_size (or an INT_MAX-like value)
+    - step == 1
+    """
+    # Slice inputs: data, starts, ends, [axes], [steps]
+    if len(node.inputs) < 3:
+        return False
+
+    data_shape = node.inputs[0].shape
+    if not data_shape or not all(isinstance(d, int) for d in data_shape):
+        return False
+
+    # Get starts and ends (required)
+    starts_input = node.inputs[1]
+    ends_input = node.inputs[2]
+
+    if not isinstance(starts_input, Constant) or not isinstance(ends_input, Constant):
+        return False
+
+    starts = starts_input.values.flatten().tolist()
+    ends = ends_input.values.flatten().tolist()
+
+    # Get axes (optional, defaults to [0, 1, 2, ...])
+    if len(node.inputs) > 3 and isinstance(node.inputs[3], Constant):
+        axes = node.inputs[3].values.flatten().tolist()
+    else:
+        axes = list(range(len(starts)))
+
+    # Get steps (optional, defaults to [1, 1, 1, ...])
+    if len(node.inputs) > 4 and isinstance(node.inputs[4], Constant):
+        steps = node.inputs[4].values.flatten().tolist()
+    else:
+        steps = [1] * len(starts)
+
+    # Check each axis
+    ndim = len(data_shape)
+    for start, end, axis, step in zip(starts, ends, axes, steps):
+        # Normalize negative axis
+        if axis < 0:
+            axis = ndim + axis
+
+        if axis < 0 or axis >= ndim:
+            return False
+
+        dim_size = data_shape[axis]
+
+        # Step must be 1 for no-op
+        if step != 1:
+            return False
+
+        # Normalize negative start index
+        if start < 0:
+            start = max(0, dim_size + start)
+
+        # Start must be 0
+        if start != 0:
+            return False
+
+        # Normalize negative end index
+        if end < 0:
+            end = dim_size + end
+
+        # End must cover the entire dimension
+        # Common patterns: end == dim_size, or end is a large value like INT_MAX
+        if end < dim_size:
+            return False
+
+    return True
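For intuition: a Slice whose effective start is 0, whose end reaches or exceeds the dimension size (exporters often emit an INT64_MAX sentinel for "slice to the end"), and whose step is 1 returns its input unchanged, which is exactly the case `is_noop_slice` detects. A minimal numpy sketch with illustrative shapes:

```python
import numpy as np

data = np.zeros((2, 3, 4))
INT64_MAX = 2**63 - 1  # common "slice to the end" sentinel in exported models

# starts=[0], ends=[INT64_MAX], axes=[1], steps=[1] selects the whole axis,
# so the Slice output equals its input and the node can be erased.
assert data[:, 0:INT64_MAX:1, :].shape == data.shape

# A negative start of -3 on a dim of size 3 normalizes to 0, so it is still a no-op.
assert data[:, -3:INT64_MAX:1, :].shape == data.shape
```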
@@ -27,12 +27,13 @@ class ConvAddMatcher(PatternMatcher):
         conv_weight = list(conv_node.inputs)[1]
         conv_node_users = conv_node.users
         node = self.add_0
+        oc_axis = 0 if conv_node.op == "Conv" else 1  # output_channel_axis
         if (
             len(conv_node_users) == 1
             and isinstance(node.inputs[1], gs.Constant)
             and isinstance(conv_weight, gs.Constant)
             and node.inputs[1].values.squeeze().ndim == 1
-            and node.inputs[1].values.squeeze().shape[0] == conv_weight.shape[0]
+            and node.inputs[1].values.squeeze().shape[0] == conv_weight.shape[oc_axis]
         ):
             add_node = node
             if len(conv_node.inputs) == 2:
@@ -66,5 +67,24 @@ class ConvAddMatcher(PatternMatcher):

         return match_case

+class ConvTransposeAddMatcher(ConvAddMatcher):
+    def __init__(self, priority):
+        """Initializes the ConvTransposeAddMatcher for fusing ConvTranspose and Add layers in an ONNX graph."""
+        pattern = Pattern(
+            """
+            input input 0 1 conv_0
+            ConvTranspose conv_0 1+ 1 input add_0
+            Add add_0 2 1 conv_0 ? output
+            output output 1 0 add_0
+            """
+        )
+        super(ConvAddMatcher, self).__init__(pattern, priority)
+
+    @property
+    def name(self):
+        """Returns the name of the FusionConvTransposeAdd pattern."""
+        return "FusionConvTransposeAdd"
+

 register_fusion_pattern(ConvAddMatcher(1))
+register_fusion_pattern(ConvTransposeAddMatcher(1))
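The new `oc_axis` accounts for ONNX weight layouts: Conv weights are shaped `(C_out, C_in/group, kH, kW)` while ConvTranspose weights are `(C_in, C_out/group, kH, kW)`, so the output-channel count sits at axis 0 for Conv and axis 1 for ConvTranspose. A small sketch of the precondition being checked (shapes are illustrative):

```python
import numpy as np

conv_w = np.zeros((8, 3, 3, 3))    # Conv weight: (C_out, C_in/group, kH, kW)
convt_w = np.zeros((3, 8, 3, 3))   # ConvTranspose weight: (C_in, C_out/group, kH, kW)
add_const = np.zeros(8)            # one Add value per output channel

for op, w in (("Conv", conv_w), ("ConvTranspose", convt_w)):
    oc_axis = 0 if op == "Conv" else 1
    # Mirrors the matcher's guard: the Add constant must be a 1-D vector
    # with one entry per output channel before it can be folded into the bias.
    assert add_const.ndim == 1 and add_const.shape[0] == w.shape[oc_axis]
```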
@@ -44,11 +44,9 @@ class ConvBatchNormMatcher(PatternMatcher):
             conv_transpose_bias = conv_transpose_node.inputs[2].values

             bn_var_rsqrt = bn_scale / np.sqrt(bn_running_var + bn_eps)
+            oc_axis = 0 if conv_transpose_node.op == "Conv" else 1  # output_channel_axis
             shape = [1] * len(conv_transpose_weight.shape)
-            if bn_node.i(0).op == "Conv":
-                shape[0] = -1
-            else:
-                shape[1] = -1
+            shape[oc_axis] = -1
             conv_w = conv_transpose_weight * bn_var_rsqrt.reshape(shape)
             conv_b = (conv_transpose_bias - bn_running_mean) * bn_var_rsqrt + bn_bias

@@ -82,5 +80,24 @@ class ConvBatchNormMatcher(PatternMatcher):

         return match_case

+class ConvTransposeBatchNormMatcher(ConvBatchNormMatcher):
+    def __init__(self, priority):
+        """Initializes the ConvTransposeBatchNormMatcher for fusing ConvTranspose and BatchNormalization layers in an ONNX graph."""
+        pattern = Pattern(
+            """
+            input input 0 1 conv_0
+            ConvTranspose conv_0 1+ 1 input bn_0
+            BatchNormalization bn_0 5 1 conv_0 ? ? ? ? output
+            output output 1 0 bn_0
+            """
+        )
+        super(ConvBatchNormMatcher, self).__init__(pattern, priority)
+
+    @property
+    def name(self):
+        """Returns the name of the FusionConvTransposeBN pattern."""
+        return "FusionConvTransposeBN"
+

 register_fusion_pattern(ConvBatchNormMatcher(1))
+register_fusion_pattern(ConvTransposeBatchNormMatcher(1))
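The folding itself is the standard batch-norm identity: with `k = gamma / sqrt(var + eps)`, applying BN to `conv(x) + b` equals a convolution whose weights are scaled by `k` along the output-channel axis and whose bias becomes `(b - mean) * k + beta`, which is what `conv_w` and `conv_b` compute above. A numpy check of the per-channel arithmetic (values are illustrative):

```python
import numpy as np

rng = np.random.default_rng(0)
c_out, eps = 4, 1e-5
y = rng.standard_normal((1, c_out, 8, 8))      # conv output without bias
b = rng.standard_normal(c_out)                 # conv bias
gamma, beta = rng.standard_normal(c_out), rng.standard_normal(c_out)
mean, var = rng.standard_normal(c_out), rng.random(c_out) + 0.1

def per_channel(v):
    """Broadcast a per-output-channel vector over an NCHW tensor."""
    return v.reshape(1, -1, 1, 1)

k = gamma / np.sqrt(var + eps)                 # bn_var_rsqrt in the diff
bn_out = (y + per_channel(b) - per_channel(mean)) * per_channel(k) + per_channel(beta)
folded = y * per_channel(k) + per_channel((b - mean) * k + beta)
assert np.allclose(bn_out, folded)
```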
@@ -28,11 +28,10 @@ class ConvMulMatcher(PatternMatcher):
         conv_weight = list(conv_node.inputs)[1]
         if len(conv_node.users) == 1 and conv_node.users[0] == mul_node and isinstance(mul_node.inputs[1], gs.Constant):
             mul_constant = mul_node.inputs[1].values
-
-            if mul_constant.squeeze().ndim == 1 and mul_constant.squeeze().shape[0] == conv_weight.shape[0]:
-                weight_shape = conv_weight.values.shape
-                reshape_shape = [-1] + [1] * (len(weight_shape) - 1)
-
+            oc_axis = 0 if conv_node.op == "Conv" else 1  # output_channel_axis
+            if mul_constant.squeeze().ndim == 1 and mul_constant.squeeze().shape[0] == conv_weight.shape[oc_axis]:
+                reshape_shape = [1] * len(conv_weight.values.shape)
+                reshape_shape[oc_axis] = -1
                 mul_scale_reshaped = mul_constant.squeeze().reshape(reshape_shape)
                 new_weight = conv_weight.values * mul_scale_reshaped

@@ -65,5 +64,24 @@ class ConvMulMatcher(PatternMatcher):

         return match_case

+class ConvTransposeMulMatcher(ConvMulMatcher):
+    def __init__(self, priority):
+        """Initializes the ConvTransposeMulMatcher for fusing ConvTranspose and Mul layers in an ONNX graph."""
+        pattern = Pattern(
+            """
+            input input 0 1 conv_0
+            ConvTranspose conv_0 1+ 1 input mul_0
+            Mul mul_0 2 1 conv_0 ? output
+            output output 1 0 mul_0
+            """
+        )
+        super(ConvMulMatcher, self).__init__(pattern, priority)
+
+    @property
+    def name(self):
+        """Returns the name of the FusionConvTransposeMul pattern."""
+        return "FusionConvTransposeMul"
+

 register_fusion_pattern(ConvMulMatcher(1))
+register_fusion_pattern(ConvTransposeMulMatcher(1))
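Folding the Mul works because convolution is linear in its weights: scaling output channel `o` after the conv is the same as scaling that channel's weight slice along `oc_axis` before it. A minimal check using a 1x1 convolution written as an einsum (Conv case, `oc_axis == 0`; shapes are illustrative):

```python
import numpy as np

rng = np.random.default_rng(0)
x = rng.standard_normal((1, 3, 8, 8))      # NCHW input
w = rng.standard_normal((4, 3))            # 1x1 Conv weight squeezed to (C_out, C_in)
s = rng.standard_normal(4)                 # per-output-channel Mul constant

def conv1x1(x, w):
    # out[n, o, h, w] = sum_c x[n, c, h, w] * w[o, c]
    return np.einsum("nchw,oc->nohw", x, w)

assert np.allclose(
    conv1x1(x, w) * s.reshape(1, -1, 1, 1),   # Conv followed by Mul
    conv1x1(x, w * s.reshape(-1, 1)),         # Conv with pre-scaled weights
)
```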
@@ -37,6 +37,8 @@ class PadConvMatcher(PatternMatcher):
         pad_node_users = pad_node.users

         pad_inputs = len(pad_node.inputs)
+        auto_pad = conv_node.attrs.get("auto_pad", "NOTSET")
+
         if pad_inputs < 3 or (
             (pad_inputs >= 3 and (isinstance(pad_node.inputs[2], gs.Constant) and pad_node.inputs[2].values == 0))
             or (pad_inputs >= 3 and (isinstance(pad_node.inputs[2], gs.Variable) and pad_node.inputs[2].name == ""))
@@ -45,6 +47,7 @@ class PadConvMatcher(PatternMatcher):
                 isinstance(pad_node.inputs[1], gs.Constant)
                 and pad_node.attrs.get("mode", "constant") == "constant"
                 and conv_node.inputs[1].shape
+                and auto_pad in {"NOTSET", "VALID"}
             ):
                 conv_weight_dim = len(conv_node.inputs[1].shape)
                 pad_value = pad_node.inputs[1].values.tolist()
@@ -74,6 +77,8 @@
                     pads = [pad + conv_pad for pad, conv_pad in zip(pads, conv_pads)]

                 attrs["pads"] = pads
+                conv_node.attrs.pop("auto_pad", None)
+
                 match_case[conv_node.name] = {
                     "op": "Conv",
                     "inputs": inputs,
@@ -0,0 +1,378 @@
+# Copyright (c) Microsoft Corporation. All rights reserved.
+# Licensed under the MIT License.
+
+"""
+Symbolic Shape Inference Module
+
+This module provides symbolic shape inference for ONNX models. It replaces the
+monolithic SymbolicShapeInference class with a modular, handler-based architecture.
+
+Usage:
+    from onnxslim.core.shape_inference import ShapeInferencer
+
+    model = onnx.load("model.onnx")
+    model_with_shapes = ShapeInferencer.infer_shapes(model)
+"""
+
+import logging
+
+import onnx
+import sympy
+from onnx import helper
+
+from .context import InferenceContext
+from .registry import get_all_aten_handlers, get_all_shape_handlers, get_aten_handler, get_shape_handler
+from .utils import (
+    get_attribute,
+    get_opset,
+    get_shape_from_type_proto,
+    get_shape_from_value_info,
+    is_literal,
+    is_sequence,
+)
+
+# Import all handlers to trigger registration
+from . import aten_ops  # noqa: F401
+from . import contrib_ops  # noqa: F401
+from . import standard_ops  # noqa: F401
+
+logger = logging.getLogger(__name__)
+
+
+class ShapeInferencer:
+    """Main class for performing symbolic shape inference on ONNX models."""
+
+    def __init__(self, int_max=2**31 - 1, auto_merge=False, guess_output_rank=False, verbose=0, prefix=""):
+        """Initialize the ShapeInferencer.
+
+        Args:
+            int_max: Maximum value for unbounded integers.
+            auto_merge: Whether to automatically merge conflicting dimensions.
+            guess_output_rank: Whether to guess output rank from input.
+            verbose: Logging verbosity level.
+            prefix: Prefix for generated symbolic dimension names.
+        """
+        self.int_max_ = int_max
+        self.auto_merge_ = auto_merge
+        self.guess_output_rank_ = guess_output_rank
+        self.verbose_ = verbose
+        self.prefix_ = prefix
+
+    def _infer_impl(self, ctx, start_sympy_data=None):
+        """Main inference implementation loop."""
+        ctx.sympy_data_ = start_sympy_data or {}
+        ctx.apply_suggested_merge(graph_input_only=True)
+        ctx.input_symbols_ = set()
+
+        # Process graph inputs
+        for i in ctx.out_mp_.graph.input:
+            input_shape = get_shape_from_value_info(i)
+            if input_shape is None:
+                continue
+
+            if is_sequence(i.type):
+                input_dims = i.type.sequence_type.elem_type.tensor_type.shape.dim
+            else:
+                input_dims = i.type.tensor_type.shape.dim
+
+            for i_dim, dim in enumerate(input_shape):
+                if dim is None:
+                    input_dims[i_dim].dim_param = str(ctx.new_symbolic_dim(i.name, i_dim))
+
+            ctx.input_symbols_.update([d for d in input_shape if type(d) == str])
+
+        for s in ctx.input_symbols_:
+            if s in ctx.suggested_merge_:
+                s_merge = ctx.suggested_merge_[s]
+                assert s_merge in ctx.symbolic_dims_
+                ctx.symbolic_dims_[s] = ctx.symbolic_dims_[s_merge]
+            else:
+                ctx.symbolic_dims_[s] = sympy.Symbol(s, integer=True, positive=True)
+
+        # Compute prerequisites for each node, for the topological sort
+        prereq_for_node = {}
+
+        def get_prereq(node):
+            names = {i for i in node.input if i}
+            subgraphs = []
+            if node.op_type == "If":
+                subgraphs = [get_attribute(node, "then_branch"), get_attribute(node, "else_branch")]
+            elif node.op_type in {"Loop", "Scan"}:
+                subgraphs = [get_attribute(node, "body")]
+            for g in subgraphs:
+                g_outputs_and_initializers = {i.name for i in g.initializer}
+                g_prereq = set()
+                for n in g.node:
+                    g_outputs_and_initializers.update(n.output)
+                for n in g.node:
+                    g_prereq.update([i for i in get_prereq(n) if i not in g_outputs_and_initializers])
+                names.update(g_prereq)
+                for i in g.input:
+                    if i.name in names:
+                        names.remove(i.name)
+            return names
+
+        for n in ctx.out_mp_.graph.node:
+            prereq_for_node[n.output[0]] = get_prereq(n)
+
+        # Topologically sort the nodes
+        sorted_nodes = []
+        sorted_known_vi = {i.name for i in list(ctx.out_mp_.graph.input) + list(ctx.out_mp_.graph.initializer)}
+        if any(o.name in sorted_known_vi for o in ctx.out_mp_.graph.output):
+            sorted_nodes = ctx.out_mp_.graph.node
+        else:
+            while any(o.name not in sorted_known_vi for o in ctx.out_mp_.graph.output):
+                old_sorted_nodes_len = len(sorted_nodes)
+                for node in ctx.out_mp_.graph.node:
+                    if node.output[0] not in sorted_known_vi and all(
+                        i in sorted_known_vi for i in prereq_for_node[node.output[0]] if i
+                    ):
+                        sorted_known_vi.update(node.output)
+                        sorted_nodes.append(node)
+                if old_sorted_nodes_len == len(sorted_nodes) and not all(
+                    o.name in sorted_known_vi for o in ctx.out_mp_.graph.output
+                ):
+                    raise Exception("Invalid model with cyclic graph")
+
+        # Get handlers
+        shape_handlers = get_all_shape_handlers()
+        aten_handlers = get_all_aten_handlers()
+
+        # Process each node
+        for node in sorted_nodes:
+            assert all(i in ctx.known_vi_ for i in node.input if i)
+            ctx.onnx_infer_single_node(node)
+            known_aten_op = False
+
+            # Try standard handlers first
+            handler = get_shape_handler(node.op_type)
+            if handler is not None:
+                handler.infer_shape(node, ctx)
+            elif node.op_type == "ConvTranspose":
+                vi = ctx.known_vi_[node.output[0]]
+                if len(vi.type.tensor_type.shape.dim) == 0:
+                    vi.type.tensor_type.elem_type = onnx.TensorProto.UNDEFINED
+            elif node.op_type == "ATen" and node.domain == "org.pytorch.aten":
+                for attr in node.attribute:
+                    if attr.name == "operator":
+                        aten_op_name = attr.s.decode("utf-8") if isinstance(attr.s, bytes) else attr.s
+                        aten_handler = get_aten_handler(aten_op_name)
+                        if aten_handler is not None:
+                            known_aten_op = True
+                            aten_handler.infer_shape(node, ctx)
+                        break
+
+            if ctx.verbose_ > 2:
+                logger.debug(node.op_type + ": " + node.name)
+                for i, name in enumerate(node.input):
+                    logger.debug(f"  Input {i}: {name} {'initializer' if name in ctx.initializers_ else ''}")
+
+            # Handle dimension merging for broadcast ops
+            if node.op_type in {
+                "Add",
+                "Sub",
+                "Mul",
+                "Div",
+                "MatMul",
+                "MatMulInteger",
+                "MatMulInteger16",
+                "Where",
+                "Sum",
+            }:
+                vi = ctx.known_vi_[node.output[0]]
+                out_rank = len(get_shape_from_type_proto(vi.type))
+                in_shapes = [ctx.get_shape(node, i) for i in range(len(node.input))]
+                for d in range(out_rank - (2 if node.op_type in {"MatMul", "MatMulInteger", "MatMulInteger16"} else 0)):
+                    in_dims = [s[len(s) - out_rank + d] for s in in_shapes if len(s) + d >= out_rank]
+                    if len(in_dims) > 1:
+                        ctx.check_merged_dims(in_dims, allow_broadcast=True)
+
+            # Process outputs
+            for i_o in range(len(node.output)):
+                if node.op_type in {"SkipLayerNormalization", "SkipSimplifiedLayerNormalization"} and i_o in {1, 2}:
+                    continue
+                if node.op_type == "RotaryEmbedding" and len(node.output) > 1:
+                    continue
+
+                vi = ctx.known_vi_[node.output[i_o]]
+                out_type = vi.type
+                out_type_kind = out_type.WhichOneof("value")
+
+                if out_type_kind not in {"tensor_type", "sparse_tensor_type", None}:
+                    if ctx.verbose_ > 2:
+                        if out_type_kind == "sequence_type":
+                            seq_cls_type = out_type.sequence_type.elem_type.WhichOneof("value")
+                            if seq_cls_type == "tensor_type":
+                                logger.debug(
+                                    f"  {node.output[i_o]}: sequence of {str(get_shape_from_value_info(vi))} "
+                                    f"{onnx.TensorProto.DataType.Name(vi.type.sequence_type.elem_type.tensor_type.elem_type)}"
+                                )
+                            else:
+                                logger.debug(f"  {node.output[i_o]}: sequence of {seq_cls_type}")
+                        else:
+                            logger.debug(f"  {node.output[i_o]}: {out_type_kind}")
+                    continue
+
+                out_shape = get_shape_from_value_info(vi)
+                out_type_undefined = out_type.tensor_type.elem_type == onnx.TensorProto.UNDEFINED
+                if ctx.verbose_ > 2:
+                    logger.debug(
+                        f"  {node.output[i_o]}: {out_shape!s} {onnx.TensorProto.DataType.Name(vi.type.tensor_type.elem_type)}"
+                    )
+                    if node.output[i_o] in ctx.sympy_data_:
+                        logger.debug("  Sympy Data: " + str(ctx.sympy_data_[node.output[i_o]]))
+
+                if (out_shape is not None and (None in out_shape or ctx.is_shape_contains_none_dim(out_shape))) or out_type_undefined:
+                    if ctx.auto_merge_:
+                        if node.op_type in {
+                            "Add",
+                            "Sub",
+                            "Mul",
+                            "Div",
+                            "MatMul",
+                            "MatMulInteger",
+                            "MatMulInteger16",
+                            "Concat",
+                            "Where",
+                            "Sum",
+                            "Equal",
+                            "Less",
+                            "Greater",
+                            "LessOrEqual",
+                            "GreaterOrEqual",
+                            "Min",
+                            "Max",
+                        }:
+                            shapes = [ctx.get_shape(node, i) for i in range(len(node.input))]
+                            if node.op_type in {"MatMul", "MatMulInteger", "MatMulInteger16"} and (
+                                None in out_shape or ctx.is_shape_contains_none_dim(out_shape)
+                            ):
+                                if None in out_shape:
+                                    idx = out_shape.index(None)
+                                else:
+                                    idx = out_shape.index(ctx.is_shape_contains_none_dim(out_shape))
+                                dim_idx = [len(s) - len(out_shape) + idx for s in shapes]
+                                assert len(shapes[0]) > 2 and dim_idx[0] < len(shapes[0]) - 2
+                                assert len(shapes[1]) > 2 and dim_idx[1] < len(shapes[1]) - 2
+                        elif node.op_type == "Expand":
+                            shapes = [ctx.get_shape(node, 0), ctx.get_value(node, 1)]
+                        else:
+                            shapes = []
+
+                        if shapes:
+                            for idx in range(len(out_shape)):
+                                if out_shape[idx] is not None and not ctx.is_none_dim(out_shape[idx]):
+                                    continue
+                                dim_idx = [len(s) - len(out_shape) + idx for s in shapes]
+                                if dim_idx:
+                                    ctx.add_suggested_merge(
+                                        [s[i] if is_literal(s[i]) else str(s[i]) for s, i in zip(shapes, dim_idx) if i >= 0]
+                                    )
+                            ctx.run_ = True
+                        else:
+                            ctx.run_ = False
+                    else:
+                        ctx.run_ = False
+
+                    if not ctx.run_ and handler is None and not known_aten_op:
+                        is_unknown_op = out_type_undefined and (out_shape is None or len(out_shape) == 0)
+                        if is_unknown_op:
+                            out_rank = ctx.get_shape_rank(node, 0) if ctx.guess_output_rank_ else -1
+                        else:
+                            out_rank = len(out_shape)
+
+                        if out_rank >= 0:
+                            new_shape = ctx.new_symbolic_shape(out_rank, node, i_o)
+                            if out_type_undefined:
+                                out_dtype = ctx.known_vi_[node.input[0]].type.tensor_type.elem_type
+                            else:
+                                out_dtype = vi.type.tensor_type.elem_type
+                            from .utils import get_shape_from_sympy_shape
+
+                            vi.CopyFrom(
+                                helper.make_tensor_value_info(vi.name, out_dtype, get_shape_from_sympy_shape(new_shape))
+                            )
+
+                            if ctx.verbose_ > 0:
+                                if is_unknown_op:
+                                    logger.debug(f"Possible unknown op: {node.op_type} node: {node.name}, guessing {vi.name} shape")
+                                if ctx.verbose_ > 2:
+                                    logger.debug(f"  {node.output[i_o]}: {new_shape!s} {vi.type.tensor_type.elem_type}")
+                            ctx.run_ = True
+                            continue
+
+                    if ctx.verbose_ > 0 or not ctx.auto_merge_ or out_type_undefined:
+                        logger.debug("Stopping at incomplete shape inference at " + node.op_type + ": " + node.name)
+                        logger.debug("node inputs:")
+                        for i in node.input:
+                            if i in ctx.known_vi_:
+                                logger.debug(ctx.known_vi_[i])
+                            else:
+                                logger.debug(f"not in known_vi_ for {i}")
+                        logger.debug("node outputs:")
+                        for o in node.output:
+                            if o in ctx.known_vi_:
+                                logger.debug(ctx.known_vi_[o])
+                            else:
+                                logger.debug(f"not in known_vi_ for {o}")
+                        if ctx.auto_merge_ and not out_type_undefined:
+                            logger.debug("Merging: " + str(ctx.suggested_merge_))
+                        return False
+
+        ctx.run_ = False
+        return True
+
+    def _update_output_from_vi(self, ctx):
+        """Update the graph outputs from the known value-info dictionary."""
+        for output in ctx.out_mp_.graph.output:
+            if output.name in ctx.known_vi_:
+                output.CopyFrom(ctx.known_vi_[output.name])
+
+    @staticmethod
+    def infer_shapes(in_mp, int_max=2**31 - 1, auto_merge=False, guess_output_rank=False, verbose=0):
+        """Perform symbolic shape inference on an ONNX model.
+
+        Args:
+            in_mp: The input ONNX ModelProto.
+            int_max: Maximum value for unbounded integers.
+            auto_merge: Whether to automatically merge conflicting dimensions.
+            guess_output_rank: Whether to guess output rank from input.
+            verbose: Logging verbosity level.
+
+        Returns:
+            The model with inferred shapes.
+
+        Raises:
+            Exception: If shape inference is incomplete.
+        """
+        onnx_opset = get_opset(in_mp)
+        if (not onnx_opset) or onnx_opset < 7:
+            logger.warning("Only support models of onnx opset 7 and above.")
+            return None
+
+        inferencer = ShapeInferencer(int_max, auto_merge, guess_output_rank, verbose)
+
+        # Create the inference context
+        ctx = InferenceContext(
+            in_mp,
+            int_max=int_max,
+            auto_merge=auto_merge,
+            guess_output_rank=guess_output_rank,
+            verbose=verbose,
+        )
+        ctx.preprocess()
+
+        all_shapes_inferred = False
+        while ctx.run_:
+            all_shapes_inferred = inferencer._infer_impl(ctx)
+
+        inferencer._update_output_from_vi(ctx)
+
+        if not all_shapes_inferred:
+            raise Exception("Incomplete symbolic shape inference")
+
+        return ctx.out_mp_
+
+
+# For backward compatibility
+SymbolicShapeInference = ShapeInferencer
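A short usage sketch based on the documented entry point (the file paths and flag choices are illustrative):

```python
import onnx

from onnxslim.core.shape_inference import ShapeInferencer

model = onnx.load("model.onnx")
# auto_merge resolves conflicting symbolic dims instead of stopping;
# verbose > 0 logs where inference halts if shapes remain incomplete.
inferred = ShapeInferencer.infer_shapes(model, auto_merge=True, verbose=0)
if inferred is not None:  # None is returned for models below opset 7
    onnx.save(inferred, "model_with_shapes.onnx")
```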
@@ -0,0 +1,16 @@
+# Copyright (c) Microsoft Corporation. All rights reserved.
+# Licensed under the MIT License.
+
+"""PyTorch ATen operator shape handlers."""
+
+from . import bitwise_or
+from . import diagonal
+from . import pool2d
+from . import min_max
+from . import multinomial
+from . import unfold
+from . import argmax
+from . import group_norm
+from . import upsample
+from . import embedding
+from . import numpy_t
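These imports exist purely for their side effect: importing each submodule runs its module-level registration, which is what the package-level "# Import all handlers to trigger registration" comment in `shape_inference/__init__.py` relies on. A generic illustration of the idiom, with hypothetical names rather than onnxslim's actual registry API:

```python
# registry side (hypothetical sketch)
_ATEN_HANDLERS = {}

def register_aten_handler(op_name):
    def decorator(cls):
        _ATEN_HANDLERS[op_name] = cls()  # runs at import time, filling the table
        return cls
    return decorator

# handler module side (hypothetical sketch); importing this module registers the handler
@register_aten_handler("argmax")
class ArgMaxHandler:
    def infer_shape(self, node, ctx):
        """Compute and record the output shape for an ATen argmax node."""
```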