onnxslim 0.1.81__py3-none-any.whl → 0.1.83__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- onnxslim/core/optimization/dead_node_elimination.py +84 -3
- onnxslim/core/pattern/fusion/convadd.py +21 -1
- onnxslim/core/pattern/fusion/convbn.py +21 -4
- onnxslim/core/pattern/fusion/convmul.py +23 -5
- onnxslim/core/pattern/fusion/padconv.py +5 -0
- onnxslim/core/shape_inference/__init__.py +378 -0
- onnxslim/core/shape_inference/aten_ops/__init__.py +16 -0
- onnxslim/core/shape_inference/aten_ops/argmax.py +47 -0
- onnxslim/core/shape_inference/aten_ops/bitwise_or.py +28 -0
- onnxslim/core/shape_inference/aten_ops/diagonal.py +52 -0
- onnxslim/core/shape_inference/aten_ops/embedding.py +23 -0
- onnxslim/core/shape_inference/aten_ops/group_norm.py +41 -0
- onnxslim/core/shape_inference/aten_ops/min_max.py +64 -0
- onnxslim/core/shape_inference/aten_ops/multinomial.py +39 -0
- onnxslim/core/shape_inference/aten_ops/numpy_t.py +22 -0
- onnxslim/core/shape_inference/aten_ops/pool2d.py +40 -0
- onnxslim/core/shape_inference/aten_ops/unfold.py +44 -0
- onnxslim/core/shape_inference/aten_ops/upsample.py +44 -0
- onnxslim/core/shape_inference/base.py +111 -0
- onnxslim/core/shape_inference/context.py +645 -0
- onnxslim/core/shape_inference/contrib_ops/__init__.py +8 -0
- onnxslim/core/shape_inference/contrib_ops/attention/__init__.py +15 -0
- onnxslim/core/shape_inference/contrib_ops/attention/attention.py +61 -0
- onnxslim/core/shape_inference/contrib_ops/attention/decoder_masked_mha.py +37 -0
- onnxslim/core/shape_inference/contrib_ops/attention/gated_relative_position_bias.py +35 -0
- onnxslim/core/shape_inference/contrib_ops/attention/longformer_attention.py +21 -0
- onnxslim/core/shape_inference/contrib_ops/attention/multi_head_attention.py +82 -0
- onnxslim/core/shape_inference/contrib_ops/attention/multi_scale_deformable_attn.py +29 -0
- onnxslim/core/shape_inference/contrib_ops/attention/packed_attention.py +39 -0
- onnxslim/core/shape_inference/contrib_ops/attention/packed_multi_head_attention.py +33 -0
- onnxslim/core/shape_inference/contrib_ops/attention/remove_padding.py +41 -0
- onnxslim/core/shape_inference/contrib_ops/attention/restore_padding.py +29 -0
- onnxslim/core/shape_inference/contrib_ops/misc/__init__.py +15 -0
- onnxslim/core/shape_inference/contrib_ops/misc/bias_add.py +21 -0
- onnxslim/core/shape_inference/contrib_ops/misc/bias_gelu.py +21 -0
- onnxslim/core/shape_inference/contrib_ops/misc/bias_split_gelu.py +30 -0
- onnxslim/core/shape_inference/contrib_ops/misc/fast_gelu.py +21 -0
- onnxslim/core/shape_inference/contrib_ops/misc/gelu.py +21 -0
- onnxslim/core/shape_inference/contrib_ops/misc/gemm_fast_gelu.py +21 -0
- onnxslim/core/shape_inference/contrib_ops/misc/gemm_float8.py +21 -0
- onnxslim/core/shape_inference/contrib_ops/misc/python_op.py +67 -0
- onnxslim/core/shape_inference/contrib_ops/misc/quick_gelu.py +21 -0
- onnxslim/core/shape_inference/contrib_ops/misc/rotary_embedding.py +31 -0
- onnxslim/core/shape_inference/contrib_ops/normalization/__init__.py +12 -0
- onnxslim/core/shape_inference/contrib_ops/normalization/embed_layer_normalization.py +41 -0
- onnxslim/core/shape_inference/contrib_ops/normalization/group_norm.py +21 -0
- onnxslim/core/shape_inference/contrib_ops/normalization/layer_normalization.py +42 -0
- onnxslim/core/shape_inference/contrib_ops/normalization/simplified_layer_normalization.py +23 -0
- onnxslim/core/shape_inference/contrib_ops/normalization/skip_group_norm.py +23 -0
- onnxslim/core/shape_inference/contrib_ops/normalization/skip_layer_normalization.py +26 -0
- onnxslim/core/shape_inference/contrib_ops/normalization/skip_simplified_layer_normalization.py +23 -0
- onnxslim/core/shape_inference/registry.py +90 -0
- onnxslim/core/shape_inference/standard_ops/__init__.py +11 -0
- onnxslim/core/shape_inference/standard_ops/control_flow/__init__.py +8 -0
- onnxslim/core/shape_inference/standard_ops/control_flow/if_op.py +43 -0
- onnxslim/core/shape_inference/standard_ops/control_flow/loop.py +74 -0
- onnxslim/core/shape_inference/standard_ops/control_flow/scan.py +54 -0
- onnxslim/core/shape_inference/standard_ops/math/__init__.py +20 -0
- onnxslim/core/shape_inference/standard_ops/math/_symbolic_compute.py +34 -0
- onnxslim/core/shape_inference/standard_ops/math/add.py +10 -0
- onnxslim/core/shape_inference/standard_ops/math/div.py +10 -0
- onnxslim/core/shape_inference/standard_ops/math/einsum.py +119 -0
- onnxslim/core/shape_inference/standard_ops/math/equal.py +10 -0
- onnxslim/core/shape_inference/standard_ops/math/floor.py +10 -0
- onnxslim/core/shape_inference/standard_ops/math/matmul.py +21 -0
- onnxslim/core/shape_inference/standard_ops/math/matmul_integer.py +23 -0
- onnxslim/core/shape_inference/standard_ops/math/max.py +10 -0
- onnxslim/core/shape_inference/standard_ops/math/min.py +10 -0
- onnxslim/core/shape_inference/standard_ops/math/mul.py +10 -0
- onnxslim/core/shape_inference/standard_ops/math/neg.py +10 -0
- onnxslim/core/shape_inference/standard_ops/math/reduce_prod.py +27 -0
- onnxslim/core/shape_inference/standard_ops/math/reduce_sum.py +53 -0
- onnxslim/core/shape_inference/standard_ops/math/sub.py +10 -0
- onnxslim/core/shape_inference/standard_ops/math/where.py +10 -0
- onnxslim/core/shape_inference/standard_ops/misc/__init__.py +22 -0
- onnxslim/core/shape_inference/standard_ops/misc/array_feature_extractor.py +32 -0
- onnxslim/core/shape_inference/standard_ops/misc/cast.py +21 -0
- onnxslim/core/shape_inference/standard_ops/misc/category_mapper.py +30 -0
- onnxslim/core/shape_inference/standard_ops/misc/compress.py +39 -0
- onnxslim/core/shape_inference/standard_ops/misc/constant.py +27 -0
- onnxslim/core/shape_inference/standard_ops/misc/constant_of_shape.py +45 -0
- onnxslim/core/shape_inference/standard_ops/misc/dequantize_linear.py +26 -0
- onnxslim/core/shape_inference/standard_ops/misc/non_max_suppression.py +26 -0
- onnxslim/core/shape_inference/standard_ops/misc/non_zero.py +26 -0
- onnxslim/core/shape_inference/standard_ops/misc/one_hot.py +42 -0
- onnxslim/core/shape_inference/standard_ops/misc/quantize_linear.py +29 -0
- onnxslim/core/shape_inference/standard_ops/misc/range.py +41 -0
- onnxslim/core/shape_inference/standard_ops/misc/relative_position_bias.py +31 -0
- onnxslim/core/shape_inference/standard_ops/misc/resize.py +74 -0
- onnxslim/core/shape_inference/standard_ops/misc/scatter_elements.py +31 -0
- onnxslim/core/shape_inference/standard_ops/misc/softmax_cross_entropy_loss.py +44 -0
- onnxslim/core/shape_inference/standard_ops/misc/top_k.py +44 -0
- onnxslim/core/shape_inference/standard_ops/nn/__init__.py +18 -0
- onnxslim/core/shape_inference/standard_ops/nn/all_reduce.py +9 -0
- onnxslim/core/shape_inference/standard_ops/nn/average_pool.py +40 -0
- onnxslim/core/shape_inference/standard_ops/nn/batch_normalization.py +26 -0
- onnxslim/core/shape_inference/standard_ops/nn/conv.py +33 -0
- onnxslim/core/shape_inference/standard_ops/nn/cum_sum.py +9 -0
- onnxslim/core/shape_inference/standard_ops/nn/identity.py +9 -0
- onnxslim/core/shape_inference/standard_ops/nn/max_pool.py +9 -0
- onnxslim/core/shape_inference/standard_ops/nn/memcpy_from_host.py +9 -0
- onnxslim/core/shape_inference/standard_ops/nn/memcpy_to_host.py +9 -0
- onnxslim/core/shape_inference/standard_ops/nn/moe.py +9 -0
- onnxslim/core/shape_inference/standard_ops/nn/nhwc_conv.py +33 -0
- onnxslim/core/shape_inference/standard_ops/nn/reciprocal.py +9 -0
- onnxslim/core/shape_inference/standard_ops/nn/round.py +9 -0
- onnxslim/core/shape_inference/standard_ops/sequence/__init__.py +10 -0
- onnxslim/core/shape_inference/standard_ops/sequence/concat_from_sequence.py +40 -0
- onnxslim/core/shape_inference/standard_ops/sequence/sequence_at.py +31 -0
- onnxslim/core/shape_inference/standard_ops/sequence/sequence_insert.py +26 -0
- onnxslim/core/shape_inference/standard_ops/sequence/split_to_sequence.py +24 -0
- onnxslim/core/shape_inference/standard_ops/sequence/zip_map.py +36 -0
- onnxslim/core/shape_inference/standard_ops/tensor/__init__.py +20 -0
- onnxslim/core/shape_inference/standard_ops/tensor/concat.py +62 -0
- onnxslim/core/shape_inference/standard_ops/tensor/expand.py +36 -0
- onnxslim/core/shape_inference/standard_ops/tensor/gather.py +48 -0
- onnxslim/core/shape_inference/standard_ops/tensor/gather_elements.py +31 -0
- onnxslim/core/shape_inference/standard_ops/tensor/gather_nd.py +42 -0
- onnxslim/core/shape_inference/standard_ops/tensor/pad.py +41 -0
- onnxslim/core/shape_inference/standard_ops/tensor/reshape.py +72 -0
- onnxslim/core/shape_inference/standard_ops/tensor/shape.py +38 -0
- onnxslim/core/shape_inference/standard_ops/tensor/size.py +29 -0
- onnxslim/core/shape_inference/standard_ops/tensor/slice.py +183 -0
- onnxslim/core/shape_inference/standard_ops/tensor/split.py +57 -0
- onnxslim/core/shape_inference/standard_ops/tensor/squeeze.py +69 -0
- onnxslim/core/shape_inference/standard_ops/tensor/tile.py +41 -0
- onnxslim/core/shape_inference/standard_ops/tensor/transpose.py +30 -0
- onnxslim/core/shape_inference/standard_ops/tensor/unsqueeze.py +54 -0
- onnxslim/core/shape_inference/utils.py +244 -0
- onnxslim/third_party/symbolic_shape_infer.py +73 -3156
- onnxslim/utils.py +4 -2
- {onnxslim-0.1.81.dist-info → onnxslim-0.1.83.dist-info}/METADATA +21 -11
- onnxslim-0.1.83.dist-info/RECORD +187 -0
- onnxslim-0.1.81.dist-info/RECORD +0 -63
- {onnxslim-0.1.81.dist-info → onnxslim-0.1.83.dist-info}/WHEEL +0 -0
- {onnxslim-0.1.81.dist-info → onnxslim-0.1.83.dist-info}/entry_points.txt +0 -0
- {onnxslim-0.1.81.dist-info → onnxslim-0.1.83.dist-info}/licenses/LICENSE +0 -0
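
The headline change in 0.1.83 is structural: the vendored `symbolic_shape_infer.py` shrinks by roughly 3,100 lines (+73 -3156) and its per-operator logic moves into the new `onnxslim/core/shape_inference` package, one module per operator. Each module defines a handler class and self-registers it at import time. Below is a minimal sketch of that pattern as it appears in the hunks that follow; the real `ShapeHandler` base and registry live in `base.py` and `registry.py`, which are not shown in this excerpt, so the method bodies and the dict-based registry here are assumptions:

```python
# Minimal sketch of the handler/registry pattern visible in the hunks below.
# ShapeHandler, register_shape_handler, op_type and infer_shape are the names
# the new modules use; the registry internals are an assumption.
_HANDLERS = {}


class ShapeHandler:
    """One subclass per ONNX op type."""

    @property
    def op_type(self) -> str:
        raise NotImplementedError

    def infer_shape(self, node, ctx) -> None:
        """Write the inferred output ValueInfo for `node` into `ctx`."""
        raise NotImplementedError


def register_shape_handler(handler):
    # Each operator module calls this at import time,
    # e.g. register_shape_handler(SliceHandler()).
    _HANDLERS[handler.op_type] = handler
```

Representative hunks for the new tensor-op modules and shared utilities follow.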
onnxslim/core/shape_inference/standard_ops/tensor/slice.py (new file)

```diff
@@ -0,0 +1,183 @@
+# Copyright (c) Microsoft Corporation. All rights reserved.
+# Licensed under the MIT License.
+
+"""Shape handler for Slice operator."""
+
+import logging
+
+import numpy as np
+import sympy
+from onnx import helper
+
+from onnxslim.third_party._sympy.solve import try_solve
+
+from ...base import ShapeHandler
+from ...registry import register_shape_handler
+from ...utils import (
+    as_list,
+    get_attribute,
+    get_opset,
+    get_shape_from_sympy_shape,
+    is_literal,
+)
+
+logger = logging.getLogger(__name__)
+
+
+class SliceHandler(ShapeHandler):
+    """Handler for Slice operator."""
+
+    @property
+    def op_type(self) -> str:
+        return "Slice"
+
+    def infer_shape(self, node, ctx) -> None:
+        def flatten_min(expr):
+            """Returns a list with expressions split by min() for inequality proof."""
+            assert isinstance(expr, sympy.Add), f"Expected a sum of two arguments, got {expr}"
+            min_positions = [idx for idx in range(len(expr.args)) if isinstance(expr.args[idx], sympy.Min)]
+            if len(min_positions) == 1:
+                min_pos = min_positions[0]
+
+                def replace_min_with_arg(arg_idx):
+                    replaced = list(expr.args)
+                    assert isinstance(replaced[min_pos], sympy.Min)
+                    assert len(replaced[min_pos].args) == 2
+                    replaced[min_pos] = replaced[min_pos].args[arg_idx]
+                    return sympy.Add(*replaced)
+
+                return [replace_min_with_arg(0), replace_min_with_arg(1)]
+            return [expr]
+
+        def less_equal(x, y):
+            """Returns True if x is less than or equal to y."""
+            try:
+                return x <= y
+            except TypeError:
+                pass
+            try:
+                return y >= x
+            except TypeError:
+                pass
+            try:
+                return -x >= -y
+            except TypeError:
+                pass
+            try:
+                return -y <= -x
+            except TypeError:
+                pass
+            try:
+                return y - x >= 0
+            except TypeError:
+                return all(d >= 0 for d in flatten_min(y - x))
+
+        def handle_negative_index(index, bound):
+            """Normalizes a negative index to be in [0, bound)."""
+            try:
+                if not less_equal(0, index):
+                    if is_literal(index) and index <= -ctx.int_max_:
+                        return index
+                    return bound + index
+            except TypeError:
+                logger.warning(f"Cannot determine if {index} < 0")
+            return index
+
+        if get_opset(ctx.out_mp_) <= 9:
+            axes = get_attribute(node, "axes")
+            starts = get_attribute(node, "starts")
+            ends = get_attribute(node, "ends")
+            if not axes:
+                axes = list(range(len(starts)))
+            steps = [1] * len(axes)
+        else:
+            starts = as_list(ctx.try_get_value(node, 1), keep_none=True)
+            ends = as_list(ctx.try_get_value(node, 2), keep_none=True)
+            axes = ctx.try_get_value(node, 3)
+            steps = ctx.try_get_value(node, 4)
+            if axes is None and (starts is not None or ends is not None):
+                axes = list(range(len(starts if starts is not None else ends)))
+            if steps is None and (starts is not None or ends is not None):
+                steps = [1] * len(starts if starts is not None else ends)
+            axes = as_list(axes, keep_none=True)
+            steps = as_list(steps, keep_none=True)
+
+        new_sympy_shape = ctx.get_sympy_shape(node, 0)
+        if starts is None or ends is None:
+            if axes is None:
+                for i in range(len(new_sympy_shape)):
+                    new_sympy_shape[i] = ctx.new_symbolic_dim_from_output(node, 0, i)
+            else:
+                new_sympy_shape = get_shape_from_sympy_shape(new_sympy_shape)
+                for i in axes:
+                    new_sympy_shape[i] = ctx.new_symbolic_dim_from_output(node, 0, i)
+        else:
+            for i, s, e, t in zip(axes, starts, ends, steps):
+                if is_literal(e):
+                    e = handle_negative_index(e, new_sympy_shape[i])
+                if is_literal(e):
+                    if e >= ctx.int_max_:
+                        e = new_sympy_shape[i]
+                    elif e <= -ctx.int_max_:
+                        e = 0 if s > 0 else -1
+                    elif is_literal(new_sympy_shape[i]):
+                        if e < 0:
+                            e = max(0, e + new_sympy_shape[i])
+                        e = min(e, new_sympy_shape[i])
+                    else:
+                        if e > 0:
+                            e = sympy.Min(e, new_sympy_shape[i]) if e > 1 else e
+                else:
+                    if is_literal(new_sympy_shape[i]):
+                        if new_sympy_shape[i] < 0:
+                            e = sympy.Min(e, new_sympy_shape[i])
+                    else:
+                        try:
+                            if not less_equal(e, new_sympy_shape[i]):
+                                e = new_sympy_shape[i]
+                        except Exception:
+                            if len(e.free_symbols) == 1:
+                                if try_solve((e - new_sympy_shape[i]) >= 0, next(iter(e.free_symbols))) is None:
+                                    logger.warning(
+                                        f"Unable to solve if {e} <= {new_sympy_shape[i]}, treat as not equal"
+                                    )
+                            else:
+                                logger.warning(f"Unable to determine if {e} <= {new_sympy_shape[i]}, treat as equal")
+                                e = new_sympy_shape[i]
+
+                s = handle_negative_index(s, new_sympy_shape[i])
+                if is_literal(new_sympy_shape[i]) and is_literal(s):
+                    s = max(0, min(s, new_sympy_shape[i]))
+
+                new_sympy_shape[i] = sympy.simplify((e - s + t + (-1 if t > 0 else 1)) // t)
+
+            ctx.update_computed_dims(new_sympy_shape)
+
+        vi = ctx.known_vi_[node.output[0]]
+        vi.CopyFrom(
+            helper.make_tensor_value_info(
+                node.output[0],
+                vi.type.tensor_type.elem_type,
+                get_shape_from_sympy_shape(new_sympy_shape),
+            )
+        )
+
+        # handle sympy_data if needed, for slice in shape computation
+        if (
+            node.input[0] in ctx.sympy_data_
+            and [0] == axes
+            and starts is not None
+            and len(starts) == 1
+            and ends is not None
+            and len(ends) == 1
+            and steps is not None
+            and len(steps) == 1
+        ):
+            input_sympy_data = ctx.sympy_data_[node.input[0]]
+            if type(input_sympy_data) == list or (
+                type(input_sympy_data) == np.array and len(input_sympy_data.shape) == 1
+            ):
+                ctx.sympy_data_[node.output[0]] = input_sympy_data[starts[0] : ends[0] : steps[0]]
+
+
+register_shape_handler(SliceHandler())
```
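
The per-axis output length above is computed as `(e - s + t + (-1 if t > 0 else 1)) // t`, which is ceil((e - s) / t) for positive steps. A quick sanity check of that formula with concrete literals (illustrative values, not from the diff):

```python
import sympy


def slice_len(s, e, t):
    # Same expression the handler assigns to new_sympy_shape[i].
    return sympy.simplify((e - s + t + (-1 if t > 0 else 1)) // t)


assert slice_len(2, 9, 3) == 3    # picks elements 2, 5, 8
assert slice_len(0, 10, 1) == 10  # full axis
N = sympy.Symbol("N", positive=True)
print(slice_len(0, N, 2))         # floor(N/2 + 1/2): stays symbolic
```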
onnxslim/core/shape_inference/standard_ops/tensor/split.py (new file)

```diff
@@ -0,0 +1,57 @@
+# Copyright (c) Microsoft Corporation. All rights reserved.
+# Licensed under the MIT License.
+
+"""Shape handler for Split operator."""
+
+import sympy
+from onnx import helper
+
+from ...base import ShapeHandler
+from ...registry import register_shape_handler
+from ...utils import get_attribute, get_opset, get_shape_from_sympy_shape, handle_negative_axis
+
+
+class SplitHandler(ShapeHandler):
+    """Handler for Split operator."""
+
+    @property
+    def op_type(self) -> str:
+        return "Split"
+
+    def infer_shape(self, node, ctx) -> None:
+        infer_split_common(node, ctx, helper.make_tensor_value_info)
+
+
+def infer_split_common(node, ctx, make_value_info_func):
+    """Infers the output shape for the Split operator."""
+    input_sympy_shape = ctx.get_sympy_shape(node, 0)
+    axis = handle_negative_axis(get_attribute(node, "axis", 0), len(input_sympy_shape))
+    op_set = get_opset(ctx.out_mp_)
+
+    if op_set < 13:
+        split = get_attribute(node, "split")
+        assert ctx.try_get_value(node, 1) is None
+    else:
+        split = ctx.try_get_value(node, 1)
+        assert get_attribute(node, "split") is None
+
+    if split is None:
+        num_outputs = len(node.output)
+        split = [input_sympy_shape[axis] / sympy.Integer(num_outputs)] * num_outputs
+        ctx.update_computed_dims(split)
+    else:
+        split = [sympy.Integer(s) for s in split]
+
+    for i_o in range(len(split)):
+        vi = ctx.known_vi_[node.output[i_o]]
+        vi.CopyFrom(
+            make_value_info_func(
+                node.output[i_o],
+                ctx.known_vi_[node.input[0]].type.tensor_type.elem_type,
+                get_shape_from_sympy_shape([*input_sympy_shape[:axis], split[i_o], *input_sympy_shape[axis + 1 :]]),
+            )
+        )
+        ctx.known_vi_[vi.name] = vi
+
+
+register_shape_handler(SplitHandler())
```
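
When no explicit `split` is given, the axis is divided evenly across the outputs; with a symbolic axis the quotient stays a sympy expression, which is why it goes through `ctx.update_computed_dims`. For example (illustrative values, not from the diff):

```python
import sympy

# The `split is None` branch for a [6, C] input split into 3 outputs on axis 0.
input_sympy_shape = [sympy.Integer(6), sympy.Symbol("C")]
num_outputs = 3
split = [input_sympy_shape[0] / sympy.Integer(num_outputs)] * num_outputs
print(split)  # [2, 2, 2] -> each output shape is [2, C]
```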
onnxslim/core/shape_inference/standard_ops/tensor/squeeze.py (new file)

```diff
@@ -0,0 +1,69 @@
+# Copyright (c) Microsoft Corporation. All rights reserved.
+# Licensed under the MIT License.
+
+"""Shape handler for Squeeze operator."""
+
+import logging
+
+from onnx import helper
+
+from ...base import ShapeHandler
+from ...registry import register_shape_handler
+from ...utils import get_attribute, get_opset, handle_negative_axis
+
+logger = logging.getLogger(__name__)
+
+
+class SqueezeHandler(ShapeHandler):
+    """Handler for Squeeze operator."""
+
+    @property
+    def op_type(self) -> str:
+        return "Squeeze"
+
+    def infer_shape(self, node, ctx) -> None:
+        input_shape = ctx.get_shape(node, 0)
+        op_set = get_opset(ctx.out_mp_)
+
+        if op_set < 13:
+            axes = get_attribute(node, "axes")
+            assert ctx.try_get_value(node, 1) is None
+        else:
+            axes = ctx.try_get_value(node, 1)
+            assert get_attribute(node, "axes") is None
+
+        if axes is None:
+            output_shape = [s for s in input_shape if s != 1]
+            if ctx.verbose_ > 0:
+                symbolic_dimensions = [s for s in input_shape if type(s) != int]
+                if symbolic_dimensions:
+                    logger.debug(
+                        f"Symbolic dimensions in input shape of op: '{node.op_type}' node: '{node.name}'. "
+                        f"Assuming the following dimensions are never equal to 1: {symbolic_dimensions}"
+                    )
+        else:
+            axes = [handle_negative_axis(a, len(input_shape)) for a in axes]
+            output_shape = []
+            for i in range(len(input_shape)):
+                if i not in axes:
+                    output_shape.append(input_shape[i])
+                else:
+                    assert input_shape[i] == 1 or type(input_shape[i]) != int
+                    if ctx.verbose_ > 0 and type(input_shape[i]) != int:
+                        logger.debug(
+                            f"Symbolic dimensions in input shape of op: '{node.op_type}' node: '{node.name}'. "
+                            f"Assuming the dimension '{input_shape[i]}' at index {i} of the input to be equal to 1."
+                        )
+
+        vi = ctx.known_vi_[node.output[0]]
+        vi.CopyFrom(
+            helper.make_tensor_value_info(
+                node.output[0],
+                ctx.known_vi_[node.input[0]].type.tensor_type.elem_type,
+                output_shape,
+            )
+        )
+        ctx.pass_on_sympy_data(node)
+
+
+register_shape_handler(SqueezeHandler())
```
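
Note the asymmetry above: with no `axes` input, literal 1-dims are dropped while symbolic dims are kept under the logged assumption that they never equal 1; with explicit `axes`, a symbolic dim at a squeezed index is instead assumed to be 1. A one-line check of the default branch (illustrative):

```python
# axes is None: only literal 1s are squeezed; "N" is assumed never to be 1.
input_shape = ["N", 1, 3, 1]
output_shape = [s for s in input_shape if s != 1]
assert output_shape == ["N", 3]
```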
onnxslim/core/shape_inference/standard_ops/tensor/tile.py (new file)

```diff
@@ -0,0 +1,41 @@
+# Copyright (c) Microsoft Corporation. All rights reserved.
+# Licensed under the MIT License.
+
+"""Shape handler for Tile operator."""
+
+from onnx import helper
+
+from ...base import ShapeHandler
+from ...registry import register_shape_handler
+from ...utils import get_shape_from_sympy_shape
+
+
+class TileHandler(ShapeHandler):
+    """Handler for Tile operator."""
+
+    @property
+    def op_type(self) -> str:
+        return "Tile"
+
+    def infer_shape(self, node, ctx) -> None:
+        repeats_value = ctx.try_get_value(node, 1)
+        new_sympy_shape = []
+        if repeats_value is not None:
+            input_sympy_shape = ctx.get_sympy_shape(node, 0)
+            for i, d in enumerate(input_sympy_shape):
+                new_dim = d * repeats_value[i]
+                new_sympy_shape.append(new_dim)
+            ctx.update_computed_dims(new_sympy_shape)
+        else:
+            new_sympy_shape = ctx.new_symbolic_shape(ctx.get_shape_rank(node, 0), node)
+        vi = ctx.known_vi_[node.output[0]]
+        vi.CopyFrom(
+            helper.make_tensor_value_info(
+                node.output[0],
+                vi.type.tensor_type.elem_type,
+                get_shape_from_sympy_shape(new_sympy_shape),
+            )
+        )
+
+
+register_shape_handler(TileHandler())
```
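
With a constant `repeats` input each output dim is simply `d * repeats[i]`, so symbolic dims scale symbolically (illustrative):

```python
import sympy

input_sympy_shape = [sympy.Symbol("N"), sympy.Integer(4)]
repeats_value = [2, 3]
new_sympy_shape = [d * r for d, r in zip(input_sympy_shape, repeats_value)]
print(new_sympy_shape)  # [2*N, 12]
```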
onnxslim/core/shape_inference/standard_ops/tensor/transpose.py (new file)

```diff
@@ -0,0 +1,30 @@
+# Copyright (c) Microsoft Corporation. All rights reserved.
+# Licensed under the MIT License.
+
+"""Shape handler for Transpose operator."""
+
+import numpy as np
+
+from ...base import ShapeHandler
+from ...registry import register_shape_handler
+from ...utils import get_attribute
+
+
+class TransposeHandler(ShapeHandler):
+    """Handler for Transpose operator."""
+
+    @property
+    def op_type(self) -> str:
+        return "Transpose"
+
+    def infer_shape(self, node, ctx) -> None:
+        if node.input[0] in ctx.sympy_data_:
+            data_shape = ctx.get_shape(node, 0)
+            perm = get_attribute(node, "perm", reversed(list(range(len(data_shape)))))
+            input_data = ctx.sympy_data_[node.input[0]]
+            ctx.sympy_data_[node.output[0]] = (
+                np.transpose(np.array(input_data).reshape(*data_shape), axes=tuple(perm)).flatten().tolist()
+            )
+
+
+register_shape_handler(TransposeHandler())
```
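
This handler only propagates known constant data (`sympy_data_`) through the permutation; the default `perm` is the reversed axis order, matching the ONNX default. The core expression, checked on a constant 2x3 input (illustrative):

```python
import numpy as np

data_shape = [2, 3]
input_data = [1, 2, 3, 4, 5, 6]
perm = tuple(reversed(range(len(data_shape))))  # default perm: (1, 0)
out = np.transpose(np.array(input_data).reshape(*data_shape), axes=perm).flatten().tolist()
assert out == [1, 4, 2, 5, 3, 6]
```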
onnxslim/core/shape_inference/standard_ops/tensor/unsqueeze.py (new file)

```diff
@@ -0,0 +1,54 @@
+# Copyright (c) Microsoft Corporation. All rights reserved.
+# Licensed under the MIT License.
+
+"""Shape handler for Unsqueeze operator."""
+
+from onnx import helper
+
+from ...base import ShapeHandler
+from ...registry import register_shape_handler
+from ...utils import get_attribute, get_opset, handle_negative_axis
+
+
+class UnsqueezeHandler(ShapeHandler):
+    """Handler for Unsqueeze operator."""
+
+    @property
+    def op_type(self) -> str:
+        return "Unsqueeze"
+
+    def infer_shape(self, node, ctx) -> None:
+        input_shape = ctx.get_shape(node, 0)
+        op_set = get_opset(ctx.out_mp_)
+
+        if op_set < 13:
+            axes = get_attribute(node, "axes")
+            assert ctx.try_get_value(node, 1) is None
+        else:
+            axes = ctx.try_get_value(node, 1)
+            assert get_attribute(node, "axes") is None
+
+        output_rank = len(input_shape) + len(axes)
+        axes = [handle_negative_axis(a, output_rank) for a in axes]
+
+        input_axis = 0
+        output_shape = []
+        for i in range(output_rank):
+            if i in axes:
+                output_shape.append(1)
+            else:
+                output_shape.append(input_shape[input_axis])
+                input_axis += 1
+
+        vi = ctx.known_vi_[node.output[0]]
+        vi.CopyFrom(
+            helper.make_tensor_value_info(
+                node.output[0],
+                ctx.known_vi_[node.input[0]].type.tensor_type.elem_type,
+                output_shape,
+            )
+        )
+        ctx.pass_on_sympy_data(node)
+
+
+register_shape_handler(UnsqueezeHandler())
```
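
Axes are normalized against the output rank (`len(input_shape) + len(axes)`), so negative axes count from the end of the expanded shape. Tracing the loop above by hand (illustrative):

```python
# Unsqueeze a [3, 4] input at axes [0, -1]: output rank is 4, and -1
# normalizes to 3 against that output rank, giving [1, 3, 4, 1].
input_shape = [3, 4]
axes = [0, -1]
output_rank = len(input_shape) + len(axes)
axes = [a if a >= 0 else output_rank + a for a in axes]  # [0, 3]

input_axis, output_shape = 0, []
for i in range(output_rank):
    if i in axes:
        output_shape.append(1)
    else:
        output_shape.append(input_shape[input_axis])
        input_axis += 1
assert output_shape == [1, 3, 4, 1]
```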
onnxslim/core/shape_inference/utils.py (new file)

```diff
@@ -0,0 +1,244 @@
+# Copyright (c) Microsoft Corporation. All rights reserved.
+# Licensed under the MIT License.
+
+"""Utility functions for symbolic shape inference."""
+
+import numpy as np
+import onnx
+import sympy
+from onnx import helper, numpy_helper
+
+
+def get_attribute(node, attr_name, default_value=None):
+    """Retrieve the value of an attribute from an ONNX node.
+
+    Args:
+        node: The ONNX node.
+        attr_name: The name of the attribute to retrieve.
+        default_value: The default value if the attribute is not found.
+
+    Returns:
+        The attribute value or the default value.
+    """
+    found = [attr for attr in node.attribute if attr.name == attr_name]
+    return helper.get_attribute_value(found[0]) if found else default_value
+
+
+def get_dim_from_proto(dim):
+    """Retrieve the dimension value from the ONNX protobuf object.
+
+    Args:
+        dim: The ONNX TensorShapeProto.Dimension.
+
+    Returns:
+        The dimension value (int or str) or None.
+    """
+    return getattr(dim, dim.WhichOneof("value")) if type(dim.WhichOneof("value")) is str else None
+
+
+def is_sequence(type_proto):
+    """Check if the given ONNX proto type is a sequence.
+
+    Args:
+        type_proto: The ONNX TypeProto.
+
+    Returns:
+        True if the type is a sequence type.
+    """
+    cls_type = type_proto.WhichOneof("value")
+    assert cls_type in {"tensor_type", "sequence_type"}
+    return cls_type == "sequence_type"
+
+
+def get_shape_from_type_proto(type_proto):
+    """Extract the shape of a tensor from an ONNX type proto.
+
+    Args:
+        type_proto: The ONNX TypeProto.
+
+    Returns:
+        A list of dimension values or None if no shape is available.
+    """
+    assert not is_sequence(type_proto)
+    if type_proto.tensor_type.HasField("shape"):
+        return [get_dim_from_proto(d) for d in type_proto.tensor_type.shape.dim]
+    else:
+        return None
+
+
+def get_elem_type_from_type_proto(type_proto):
+    """Return the element type from a given TypeProto object.
+
+    Args:
+        type_proto: The ONNX TypeProto.
+
+    Returns:
+        The element type (e.g., TensorProto.FLOAT).
+    """
+    if is_sequence(type_proto):
+        return type_proto.sequence_type.elem_type.tensor_type.elem_type
+    else:
+        return type_proto.tensor_type.elem_type
+
+
+def get_shape_from_value_info(vi):
+    """Return the shape from the given ValueInfoProto object.
+
+    Args:
+        vi: The ONNX ValueInfoProto.
+
+    Returns:
+        A list of dimension values or None.
+    """
+    cls_type = vi.type.WhichOneof("value")
+    if cls_type is None:
+        return None
+    if not is_sequence(vi.type):
+        return get_shape_from_type_proto(vi.type)
+    if vi.type.sequence_type.elem_type.WhichOneof("value") == "tensor_type":
+        return get_shape_from_type_proto(vi.type.sequence_type.elem_type)
+    else:
+        return None
+
+
+def make_named_value_info(name):
+    """Create and return an ONNX ValueInfoProto object with the specified name.
+
+    Args:
+        name: The name for the ValueInfoProto.
+
+    Returns:
+        A new ValueInfoProto with the given name.
+    """
+    vi = onnx.ValueInfoProto()
+    vi.name = name
+    return vi
+
+
+def get_shape_from_sympy_shape(sympy_shape):
+    """Convert a sympy shape to a list with int, str, or None elements.
+
+    Args:
+        sympy_shape: A list of sympy expressions.
+
+    Returns:
+        A list of int, str, or None values.
+    """
+    return [None if i is None else (int(i) if is_literal(i) else str(i)) for i in sympy_shape]
+
+
+def is_literal(dim):
+    """Check if a dimension is a literal number.
+
+    Args:
+        dim: The dimension value to check.
+
+    Returns:
+        True if the dimension is a literal number.
+    """
+    return type(dim) in {int, np.int64, np.int32, sympy.Integer} or (hasattr(dim, "is_number") and dim.is_number)
+
+
+def handle_negative_axis(axis, rank):
+    """Convert a potentially negative axis to a positive axis.
+
+    Args:
+        axis: The axis value (can be negative).
+        rank: The total rank of the tensor.
+
+    Returns:
+        A non-negative axis value.
+    """
+    assert axis < rank and axis >= -rank
+    return axis if axis >= 0 else rank + axis
+
+
+def get_opset(mp, domain=None):
+    """Retrieve the opset version for a given model namespace.
+
+    Args:
+        mp: The ONNX ModelProto.
+        domain: The domain(s) to check. Defaults to common ONNX domains.
+
+    Returns:
+        The opset version or None if not found.
+    """
+    domain = domain or ["", "onnx", "ai.onnx"]
+    if type(domain) != list:
+        domain = [domain]
+    for opset in mp.opset_import:
+        if opset.domain in domain:
+            return opset.version
+    return None
+
+
+def as_scalar(x):
+    """Convert input to scalar if input is a list with a single item or a NumPy ndarray.
+
+    Args:
+        x: The input value.
+
+    Returns:
+        A scalar value.
+    """
+    if type(x) == list:
+        assert len(x) == 1
+        return x[0]
+    elif type(x) == np.ndarray:
+        return x.item()
+    else:
+        return x
+
+
+def as_list(x, keep_none):
+    """Convert input to list, optionally preserving None values.
+
+    Args:
+        x: The input value.
+        keep_none: If True, return None as-is instead of wrapping in list.
+
+    Returns:
+        A list or None.
+    """
+    if type(x) == list:
+        return x
+    elif type(x) == np.ndarray:
+        return list(x)
+    elif keep_none and x is None:
+        return None
+    else:
+        return [x]
+
+
+def sympy_reduce_product(x):
+    """Reduce a list or element to a product using Sympy's Integer.
+
+    Args:
+        x: A list or single value.
+
+    Returns:
+        The product as a sympy expression.
+    """
+    if type(x) == list:
+        value = sympy.Integer(1)
+        for v in x:
+            value = value * v
+    else:
+        value = x
+    return value
+
+
+def numpy_to_sympy(array):
+    """Convert a numpy array to a list of sympy values.
+
+    Args:
+        array: A numpy array.
+
+    Returns:
+        The converted list or value.
+    """
+    if isinstance(array, np.ndarray):
+        if array.ndim == 0:
+            return int(array.item())
+        return [int(x) for x in array.flatten()]
+    return array
```
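
A few of these helpers in action (assumes onnxslim 0.1.83 is installed; the import path comes from the file list above):

```python
import numpy as np
import sympy

from onnxslim.core.shape_inference.utils import (
    as_list,
    as_scalar,
    get_shape_from_sympy_shape,
    handle_negative_axis,
    sympy_reduce_product,
)

N = sympy.Symbol("N")
assert handle_negative_axis(-1, 4) == 3
assert get_shape_from_sympy_shape([N, sympy.Integer(3), None]) == ["N", 3, None]
assert as_list(np.array([1, 2]), keep_none=True) == [1, 2]
assert as_scalar(np.array(5)) == 5
assert sympy_reduce_product([N, 2]) == 2 * N
```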