onnxslim 0.1.82__py3-none-any.whl → 0.1.84__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- onnxslim/core/optimization/dead_node_elimination.py +85 -4
- onnxslim/core/pattern/elimination/slice.py +15 -8
- onnxslim/core/pattern/fusion/concat_reshape.py +3 -1
- onnxslim/core/pattern/fusion/convadd.py +23 -7
- onnxslim/core/pattern/fusion/convbn.py +24 -11
- onnxslim/core/pattern/fusion/convmul.py +26 -9
- onnxslim/core/pattern/fusion/gemm.py +7 -5
- onnxslim/core/pattern/fusion/padconv.py +5 -0
- onnxslim/core/shape_inference/__init__.py +378 -0
- onnxslim/core/shape_inference/aten_ops/__init__.py +16 -0
- onnxslim/core/shape_inference/aten_ops/argmax.py +47 -0
- onnxslim/core/shape_inference/aten_ops/bitwise_or.py +28 -0
- onnxslim/core/shape_inference/aten_ops/diagonal.py +52 -0
- onnxslim/core/shape_inference/aten_ops/embedding.py +23 -0
- onnxslim/core/shape_inference/aten_ops/group_norm.py +41 -0
- onnxslim/core/shape_inference/aten_ops/min_max.py +64 -0
- onnxslim/core/shape_inference/aten_ops/multinomial.py +39 -0
- onnxslim/core/shape_inference/aten_ops/numpy_t.py +22 -0
- onnxslim/core/shape_inference/aten_ops/pool2d.py +40 -0
- onnxslim/core/shape_inference/aten_ops/unfold.py +44 -0
- onnxslim/core/shape_inference/aten_ops/upsample.py +44 -0
- onnxslim/core/shape_inference/base.py +111 -0
- onnxslim/core/shape_inference/context.py +645 -0
- onnxslim/core/shape_inference/contrib_ops/__init__.py +8 -0
- onnxslim/core/shape_inference/contrib_ops/attention/__init__.py +15 -0
- onnxslim/core/shape_inference/contrib_ops/attention/attention.py +61 -0
- onnxslim/core/shape_inference/contrib_ops/attention/decoder_masked_mha.py +37 -0
- onnxslim/core/shape_inference/contrib_ops/attention/gated_relative_position_bias.py +35 -0
- onnxslim/core/shape_inference/contrib_ops/attention/longformer_attention.py +21 -0
- onnxslim/core/shape_inference/contrib_ops/attention/multi_head_attention.py +82 -0
- onnxslim/core/shape_inference/contrib_ops/attention/multi_scale_deformable_attn.py +29 -0
- onnxslim/core/shape_inference/contrib_ops/attention/packed_attention.py +39 -0
- onnxslim/core/shape_inference/contrib_ops/attention/packed_multi_head_attention.py +33 -0
- onnxslim/core/shape_inference/contrib_ops/attention/remove_padding.py +41 -0
- onnxslim/core/shape_inference/contrib_ops/attention/restore_padding.py +29 -0
- onnxslim/core/shape_inference/contrib_ops/misc/__init__.py +15 -0
- onnxslim/core/shape_inference/contrib_ops/misc/bias_add.py +21 -0
- onnxslim/core/shape_inference/contrib_ops/misc/bias_gelu.py +21 -0
- onnxslim/core/shape_inference/contrib_ops/misc/bias_split_gelu.py +30 -0
- onnxslim/core/shape_inference/contrib_ops/misc/fast_gelu.py +21 -0
- onnxslim/core/shape_inference/contrib_ops/misc/gelu.py +21 -0
- onnxslim/core/shape_inference/contrib_ops/misc/gemm_fast_gelu.py +21 -0
- onnxslim/core/shape_inference/contrib_ops/misc/gemm_float8.py +21 -0
- onnxslim/core/shape_inference/contrib_ops/misc/python_op.py +67 -0
- onnxslim/core/shape_inference/contrib_ops/misc/quick_gelu.py +21 -0
- onnxslim/core/shape_inference/contrib_ops/misc/rotary_embedding.py +31 -0
- onnxslim/core/shape_inference/contrib_ops/normalization/__init__.py +12 -0
- onnxslim/core/shape_inference/contrib_ops/normalization/embed_layer_normalization.py +41 -0
- onnxslim/core/shape_inference/contrib_ops/normalization/group_norm.py +21 -0
- onnxslim/core/shape_inference/contrib_ops/normalization/layer_normalization.py +42 -0
- onnxslim/core/shape_inference/contrib_ops/normalization/simplified_layer_normalization.py +23 -0
- onnxslim/core/shape_inference/contrib_ops/normalization/skip_group_norm.py +23 -0
- onnxslim/core/shape_inference/contrib_ops/normalization/skip_layer_normalization.py +26 -0
- onnxslim/core/shape_inference/contrib_ops/normalization/skip_simplified_layer_normalization.py +23 -0
- onnxslim/core/shape_inference/registry.py +90 -0
- onnxslim/core/shape_inference/standard_ops/__init__.py +11 -0
- onnxslim/core/shape_inference/standard_ops/control_flow/__init__.py +8 -0
- onnxslim/core/shape_inference/standard_ops/control_flow/if_op.py +43 -0
- onnxslim/core/shape_inference/standard_ops/control_flow/loop.py +74 -0
- onnxslim/core/shape_inference/standard_ops/control_flow/scan.py +54 -0
- onnxslim/core/shape_inference/standard_ops/math/__init__.py +20 -0
- onnxslim/core/shape_inference/standard_ops/math/_symbolic_compute.py +34 -0
- onnxslim/core/shape_inference/standard_ops/math/add.py +10 -0
- onnxslim/core/shape_inference/standard_ops/math/div.py +10 -0
- onnxslim/core/shape_inference/standard_ops/math/einsum.py +119 -0
- onnxslim/core/shape_inference/standard_ops/math/equal.py +10 -0
- onnxslim/core/shape_inference/standard_ops/math/floor.py +10 -0
- onnxslim/core/shape_inference/standard_ops/math/matmul.py +21 -0
- onnxslim/core/shape_inference/standard_ops/math/matmul_integer.py +23 -0
- onnxslim/core/shape_inference/standard_ops/math/max.py +10 -0
- onnxslim/core/shape_inference/standard_ops/math/min.py +10 -0
- onnxslim/core/shape_inference/standard_ops/math/mul.py +10 -0
- onnxslim/core/shape_inference/standard_ops/math/neg.py +10 -0
- onnxslim/core/shape_inference/standard_ops/math/reduce_prod.py +27 -0
- onnxslim/core/shape_inference/standard_ops/math/reduce_sum.py +53 -0
- onnxslim/core/shape_inference/standard_ops/math/sub.py +10 -0
- onnxslim/core/shape_inference/standard_ops/math/where.py +10 -0
- onnxslim/core/shape_inference/standard_ops/misc/__init__.py +22 -0
- onnxslim/core/shape_inference/standard_ops/misc/array_feature_extractor.py +32 -0
- onnxslim/core/shape_inference/standard_ops/misc/cast.py +21 -0
- onnxslim/core/shape_inference/standard_ops/misc/category_mapper.py +30 -0
- onnxslim/core/shape_inference/standard_ops/misc/compress.py +39 -0
- onnxslim/core/shape_inference/standard_ops/misc/constant.py +27 -0
- onnxslim/core/shape_inference/standard_ops/misc/constant_of_shape.py +45 -0
- onnxslim/core/shape_inference/standard_ops/misc/dequantize_linear.py +26 -0
- onnxslim/core/shape_inference/standard_ops/misc/non_max_suppression.py +26 -0
- onnxslim/core/shape_inference/standard_ops/misc/non_zero.py +26 -0
- onnxslim/core/shape_inference/standard_ops/misc/one_hot.py +42 -0
- onnxslim/core/shape_inference/standard_ops/misc/quantize_linear.py +29 -0
- onnxslim/core/shape_inference/standard_ops/misc/range.py +41 -0
- onnxslim/core/shape_inference/standard_ops/misc/relative_position_bias.py +31 -0
- onnxslim/core/shape_inference/standard_ops/misc/resize.py +74 -0
- onnxslim/core/shape_inference/standard_ops/misc/scatter_elements.py +31 -0
- onnxslim/core/shape_inference/standard_ops/misc/softmax_cross_entropy_loss.py +44 -0
- onnxslim/core/shape_inference/standard_ops/misc/top_k.py +44 -0
- onnxslim/core/shape_inference/standard_ops/nn/__init__.py +18 -0
- onnxslim/core/shape_inference/standard_ops/nn/all_reduce.py +9 -0
- onnxslim/core/shape_inference/standard_ops/nn/average_pool.py +40 -0
- onnxslim/core/shape_inference/standard_ops/nn/batch_normalization.py +26 -0
- onnxslim/core/shape_inference/standard_ops/nn/conv.py +33 -0
- onnxslim/core/shape_inference/standard_ops/nn/cum_sum.py +9 -0
- onnxslim/core/shape_inference/standard_ops/nn/identity.py +9 -0
- onnxslim/core/shape_inference/standard_ops/nn/max_pool.py +9 -0
- onnxslim/core/shape_inference/standard_ops/nn/memcpy_from_host.py +9 -0
- onnxslim/core/shape_inference/standard_ops/nn/memcpy_to_host.py +9 -0
- onnxslim/core/shape_inference/standard_ops/nn/moe.py +9 -0
- onnxslim/core/shape_inference/standard_ops/nn/nhwc_conv.py +33 -0
- onnxslim/core/shape_inference/standard_ops/nn/reciprocal.py +9 -0
- onnxslim/core/shape_inference/standard_ops/nn/round.py +9 -0
- onnxslim/core/shape_inference/standard_ops/sequence/__init__.py +10 -0
- onnxslim/core/shape_inference/standard_ops/sequence/concat_from_sequence.py +40 -0
- onnxslim/core/shape_inference/standard_ops/sequence/sequence_at.py +31 -0
- onnxslim/core/shape_inference/standard_ops/sequence/sequence_insert.py +26 -0
- onnxslim/core/shape_inference/standard_ops/sequence/split_to_sequence.py +24 -0
- onnxslim/core/shape_inference/standard_ops/sequence/zip_map.py +36 -0
- onnxslim/core/shape_inference/standard_ops/tensor/__init__.py +20 -0
- onnxslim/core/shape_inference/standard_ops/tensor/concat.py +62 -0
- onnxslim/core/shape_inference/standard_ops/tensor/expand.py +36 -0
- onnxslim/core/shape_inference/standard_ops/tensor/gather.py +48 -0
- onnxslim/core/shape_inference/standard_ops/tensor/gather_elements.py +31 -0
- onnxslim/core/shape_inference/standard_ops/tensor/gather_nd.py +42 -0
- onnxslim/core/shape_inference/standard_ops/tensor/pad.py +41 -0
- onnxslim/core/shape_inference/standard_ops/tensor/reshape.py +72 -0
- onnxslim/core/shape_inference/standard_ops/tensor/shape.py +38 -0
- onnxslim/core/shape_inference/standard_ops/tensor/size.py +29 -0
- onnxslim/core/shape_inference/standard_ops/tensor/slice.py +183 -0
- onnxslim/core/shape_inference/standard_ops/tensor/split.py +57 -0
- onnxslim/core/shape_inference/standard_ops/tensor/squeeze.py +69 -0
- onnxslim/core/shape_inference/standard_ops/tensor/tile.py +41 -0
- onnxslim/core/shape_inference/standard_ops/tensor/transpose.py +30 -0
- onnxslim/core/shape_inference/standard_ops/tensor/unsqueeze.py +54 -0
- onnxslim/core/shape_inference/utils.py +244 -0
- onnxslim/third_party/onnx_graphsurgeon/ir/graph.py +0 -103
- onnxslim/third_party/symbolic_shape_infer.py +73 -3156
- onnxslim/utils.py +4 -2
- {onnxslim-0.1.82.dist-info → onnxslim-0.1.84.dist-info}/METADATA +21 -11
- onnxslim-0.1.84.dist-info/RECORD +187 -0
- onnxslim-0.1.82.dist-info/RECORD +0 -63
- {onnxslim-0.1.82.dist-info → onnxslim-0.1.84.dist-info}/WHEEL +0 -0
- {onnxslim-0.1.82.dist-info → onnxslim-0.1.84.dist-info}/entry_points.txt +0 -0
- {onnxslim-0.1.82.dist-info → onnxslim-0.1.84.dist-info}/licenses/LICENSE +0 -0
|
@@ -0,0 +1,645 @@
|
|
|
1
|
+
# Copyright (c) Microsoft Corporation. All rights reserved.
|
|
2
|
+
# Licensed under the MIT License.
|
|
3
|
+
|
|
4
|
+
"""InferenceContext class for managing shape inference state."""
|
|
5
|
+
|
|
6
|
+
import logging
|
|
7
|
+
|
|
8
|
+
import numpy as np
|
|
9
|
+
import onnx
|
|
10
|
+
import sympy
|
|
11
|
+
from onnx import helper, numpy_helper, shape_inference
|
|
12
|
+
|
|
13
|
+
from onnxslim.third_party._sympy.functions import FloorDiv
|
|
14
|
+
from onnxslim.third_party._sympy.printers import PythonPrinter as _PythonPrinter
|
|
15
|
+
|
|
16
|
+
from .utils import (
|
|
17
|
+
as_list,
|
|
18
|
+
as_scalar,
|
|
19
|
+
get_attribute,
|
|
20
|
+
get_elem_type_from_type_proto,
|
|
21
|
+
get_opset,
|
|
22
|
+
get_shape_from_sympy_shape,
|
|
23
|
+
get_shape_from_type_proto,
|
|
24
|
+
get_shape_from_value_info,
|
|
25
|
+
handle_negative_axis,
|
|
26
|
+
is_literal,
|
|
27
|
+
is_sequence,
|
|
28
|
+
make_named_value_info,
|
|
29
|
+
)
|
|
30
|
+
|
|
31
|
+
logger = logging.getLogger(__name__)
|
|
32
|
+
|
|
33
|
+
|
|
34
|
+
class PythonPrinter(_PythonPrinter):
    """Printer that renders sympy expressions as Python source text.

    Thin wrapper over the third-party printer. The ``simplify`` and ``p``
    keyword flags are accepted for call-site compatibility but are ignored.
    """

    def doprint(self, expr: sympy.Expr, *, simplify: bool = True, p: bool = True) -> str:
        # Delegate directly to the base printer; the keyword flags are
        # intentionally unused (kept only so existing callers may pass them).
        return super().doprint(expr)


# Module-level shortcut used to stringify sympy dimensions (e.g. in
# InferenceContext.update_computed_dims).
pexpr = PythonPrinter().doprint
|
|
42
|
+
|
|
43
|
+
|
|
44
|
+
class InferenceContext:
|
|
45
|
+
"""Context object that encapsulates all state for shape inference.
|
|
46
|
+
|
|
47
|
+
This class provides access to:
|
|
48
|
+
- Known value info (known_vi_)
|
|
49
|
+
- Symbolic dimensions (symbolic_dims_)
|
|
50
|
+
- Sympy computed data (sympy_data_)
|
|
51
|
+
- Initializers (initializers_)
|
|
52
|
+
- Graph inputs (graph_inputs_)
|
|
53
|
+
- Model opset and other configuration
|
|
54
|
+
"""
|
|
55
|
+
|
|
56
|
+
def __init__(
    self,
    out_mp,
    int_max=2**31 - 1,
    auto_merge=False,
    guess_output_rank=False,
    verbose=0,
    prefix="",
):
    """Build a fresh inference context around an ONNX model.

    Args:
        out_mp: The ONNX ModelProto being processed.
        int_max: Upper bound used for unbounded integer dimensions.
        auto_merge: Automatically merge conflicting symbolic dimensions.
        guess_output_rank: Guess output rank from input when unknown.
        verbose: Logging verbosity level.
        prefix: Prefix for generated symbolic dimension names.
    """
    # Caller-supplied configuration.
    self.out_mp_ = out_mp
    self.int_max_ = int_max
    self.auto_merge_ = auto_merge
    self.guess_output_rank_ = guess_output_rank
    self.verbose_ = verbose
    self.prefix_ = prefix
    self.subgraph_id_ = 0

    # Mutable inference state; populated later by preprocess().
    self.known_vi_ = {}        # tensor name -> ValueInfoProto
    self.symbolic_dims_ = {}   # dim name -> sympy expression
    self.sympy_data_ = {}      # tensor name -> known (symbolic) value
    self.initializers_ = {}    # tensor name -> TensorProto
    self.graph_inputs_ = {}    # input name -> ValueInfoProto
    self.input_symbols_ = set()
    self.suggested_merge_ = {}  # dim name -> merge target (name or literal)
    self.run_ = True
|
|
92
|
+
|
|
93
|
+
@property
def opset(self):
    """ONNX opset version declared by the model being processed."""
    return get_opset(self.out_mp_)
|
|
97
|
+
|
|
98
|
+
def preprocess(self):
    """Populate the lookup tables (inputs, initializers, value infos) from the model graph."""
    graph = self.out_mp_.graph
    self.graph_inputs_ = {vi.name: vi for vi in graph.input}
    self.initializers_ = {init.name: init for init in graph.initializer}
    # Seed known value infos with graph inputs, then let initializers and
    # finally graph outputs override entries with the same name.
    self.known_vi_ = {vi.name: vi for vi in graph.input}
    for init in graph.initializer:
        self.known_vi_[init.name] = helper.make_tensor_value_info(init.name, init.data_type, list(init.dims))
    for vi in graph.output:
        self.known_vi_[vi.name] = vi
|
|
110
|
+
|
|
111
|
+
# Shape retrieval methods
|
|
112
|
+
def get_shape(self, node, idx):
    """Return the shape of input *idx* of *node*.

    Looks up known value infos first, then falls back to the initializer
    table; the tensor must be known in one of the two.
    """
    name = node.input[idx]
    vi = self.known_vi_.get(name)
    if vi is not None:
        return get_shape_from_value_info(vi)
    # Not a tracked value info -- it must be a graph initializer.
    assert name in self.initializers_
    return list(self.initializers_[name].dims)
|
|
121
|
+
|
|
122
|
+
def try_get_shape(self, node, idx):
    """Like get_shape, but returns None when the input is absent or unknown."""
    if idx >= len(node.input):
        return None
    name = node.input[idx]
    if name in self.known_vi_:
        return get_shape_from_value_info(self.known_vi_[name])
    if name in self.initializers_:
        return list(self.initializers_[name].dims)
    return None
|
|
133
|
+
|
|
134
|
+
def get_shape_rank(self, node, idx):
    """Rank (number of dimensions) of input *idx* of *node*."""
    return len(self.get_shape(node, idx))
|
|
137
|
+
|
|
138
|
+
def get_sympy_shape(self, node, idx):
    """Shape of input *idx* with string dims replaced by sympy symbols."""
    dims = []
    for dim in self.get_shape(node, idx):
        if isinstance(dim, str):
            # Reuse a previously created symbol for this name when available;
            # otherwise mint a fresh nonnegative integer symbol.
            cached = self.symbolic_dims_.get(dim)
            dims.append(cached if cached is not None else sympy.Symbol(dim, integer=True, nonnegative=True))
        else:
            assert dim is not None
            dims.append(dim)
    return dims
|
|
152
|
+
|
|
153
|
+
# Value retrieval methods
|
|
154
|
+
def get_value(self, node, idx):
    """Known value of input *idx*; prefers sympy data over initializers."""
    name = node.input[idx]
    if name in self.sympy_data_:
        return self.sympy_data_[name]
    # Must be backed by an initializer when not symbolically known.
    assert name in self.initializers_
    return numpy_helper.to_array(self.initializers_[name])
|
|
159
|
+
|
|
160
|
+
def try_get_value(self, node, idx):
    """get_value, but returns None when the input is absent or unknown."""
    if idx >= len(node.input):
        return None
    name = node.input[idx]
    if name not in self.sympy_data_ and name not in self.initializers_:
        return None
    return self.get_value(node, idx)
|
|
168
|
+
|
|
169
|
+
# Symbolic dimension management
|
|
170
|
+
def new_symbolic_dim(self, prefix, dim):
    """Create (or resolve via suggested merges) the symbolic dim ``{prefix}_d{dim}``."""
    name = f"{prefix}_d{dim}"
    if name in self.suggested_merge_:
        merged = self.suggested_merge_[name]
        # A literal merge target collapses the symbol to a concrete integer.
        symbol = sympy.Integer(int(merged)) if is_literal(merged) else merged
    else:
        symbol = sympy.Symbol(name, integer=True, nonnegative=True)
        # Only freshly minted symbols are registered under their own name.
        self.symbolic_dims_[name] = symbol
    return symbol
|
|
180
|
+
|
|
181
|
+
def new_symbolic_dim_from_output(self, node, out_idx=0, dim=0):
    """Symbolic dim named after *node*'s graph position and output index.

    NOTE(review): the list().index() lookup is O(n) over all graph nodes per
    call -- kept to match existing naming behaviour.
    """
    node_pos = list(self.out_mp_.graph.node).index(node)
    prefix = f"{node.op_type}{self.prefix_}_{node_pos}_o{out_idx}_"
    return self.new_symbolic_dim(prefix, dim)
|
|
187
|
+
|
|
188
|
+
def new_symbolic_shape(self, rank, node, out_idx=0):
    """List of *rank* fresh symbolic dims for output *out_idx* of *node*."""
    return [self.new_symbolic_dim_from_output(node, out_idx, axis) for axis in range(rank)]
|
|
191
|
+
|
|
192
|
+
def update_computed_dims(self, new_sympy_shape):
    """Register computed sympy dims and apply suggested merges, in place."""
    for i, dim in enumerate(new_sympy_shape):
        # Literals and plain strings need no registration.
        if is_literal(dim) or isinstance(dim, str):
            continue
        key = pexpr(dim)
        if key in self.suggested_merge_:
            merged = self.suggested_merge_[key]
            # Only non-literal merge targets are substituted back into the shape.
            if not is_literal(merged):
                new_sympy_shape[i] = self.symbolic_dims_[merged]
        elif key not in self.symbolic_dims_:
            self.symbolic_dims_[key] = dim
|
|
202
|
+
|
|
203
|
+
# Dimension merging
|
|
204
|
+
def add_suggested_merge(self, symbols, apply=False):
    """Record that the given dims (known symbol names or literals) should merge.

    Picks one canonical target (literal > graph-input symbol > plain symbol >
    shortest name) and maps every other symbol to it in suggested_merge_.
    """
    assert all((type(s) == str and s in self.symbolic_dims_) or is_literal(s) for s in symbols)
    symbols = set(symbols)
    # Follow already-recorded merges so we operate on canonical representatives.
    for src, dst in self.suggested_merge_.items():
        if src in symbols:
            symbols.remove(src)
            symbols.add(dst)
    # if there is literal, map to it first
    map_to = next((s for s in symbols if is_literal(s)), None)
    # when no literals, map to input symbolic dims, then existing symbolic dims
    if map_to is None:
        map_to = next((s for s in symbols if s in self.input_symbols_), None)
    if map_to is None:
        map_to = next((s for s in symbols if type(self.symbolic_dims_[s]) == sympy.Symbol), None)
    # when nothing to map to, use the shorter one
    if map_to is None:
        if self.verbose_ > 0:
            logger.warning(f"Potential unsafe merge between symbolic expressions: ({','.join(symbols)})")
        map_to = min(symbols, key=len)
    symbols.remove(map_to)

    for sym in symbols:
        if sym == map_to:
            continue
        if is_literal(map_to) and is_literal(sym):
            assert int(map_to) == int(sym)
        self.suggested_merge_[sym] = int(map_to) if is_literal(map_to) else map_to
        # Re-point any earlier merges that targeted this symbol.
        for src, dst in self.suggested_merge_.items():
            if dst == sym:
                self.suggested_merge_[src] = map_to
    if apply and self.auto_merge_:
        self.apply_suggested_merge()
|
|
249
|
+
|
|
250
|
+
def apply_suggested_merge(self, graph_input_only=False):
    """Rewrite graph dim_param/dim_value entries according to suggested merges."""
    if not self.suggested_merge_:
        return
    value_infos = list(self.out_mp_.graph.input)
    if not graph_input_only:
        value_infos += list(self.out_mp_.graph.value_info)
    for vi in value_infos:
        for dim in vi.type.tensor_type.shape.dim:
            if dim.dim_param not in self.suggested_merge_:
                continue
            target = self.suggested_merge_[dim.dim_param]
            if is_literal(target):
                # Literal target: the dimension becomes a concrete value.
                dim.dim_value = int(target)
            else:
                dim.dim_param = target
|
|
262
|
+
|
|
263
|
+
def merge_symbols(self, dims):
    """Merge a list of dims into one, honoring auto-merge; None when impossible."""
    if any(not isinstance(d, str) for d in dims):
        # Mixed symbolic/concrete dims can only be merged under auto-merge.
        if not self.auto_merge_:
            return None
        unique_dims = list(set(dims))
        is_int = [is_literal(d) for d in unique_dims]
        # At most one distinct literal may be present.
        assert sum(is_int) <= 1
        if sum(is_int) == 1:
            int_dim = is_int.index(1)
            if self.verbose_ > 0:
                logger.debug(
                    f"dim {unique_dims[:int_dim] + unique_dims[int_dim + 1 :]} has been merged with value {unique_dims[int_dim]}"
                )
            self.check_merged_dims(unique_dims, allow_broadcast=False)
            return unique_dims[int_dim]
        if self.verbose_ > 0:
            logger.debug(f"dim {unique_dims[1:]} has been merged with dim {unique_dims[0]}")
        return dims[0]
    if all(d == dims[0] for d in dims):
        return dims[0]
    # Map each name through the suggested-merge table, then require agreement.
    canonical = [self.suggested_merge_.get(d, d) for d in dims]
    if all(c == canonical[0] for c in canonical):
        assert canonical[0] in self.symbolic_dims_
        return canonical[0]
    return None
|
|
291
|
+
|
|
292
|
+
def check_merged_dims(self, dims, allow_broadcast=True):
    """Verify merged dims agree; register a suggested merge when they differ."""
    if allow_broadcast:
        # Broadcastable literals (<= 1) never conflict, so drop them first.
        dims = [d for d in dims if not (is_literal(d) and int(d) <= 1)]
    mismatched = any(d != dims[0] for d in dims)
    if mismatched:
        self.add_suggested_merge(dims, apply=True)
|
|
298
|
+
|
|
299
|
+
# Broadcasting
|
|
300
|
+
def broadcast_shapes(self, shape1, shape2):
    """Right-aligned (NumPy-style) broadcast of two shapes."""
    rank1, rank2 = len(shape1), len(shape2)
    reversed_result = []
    for i in range(max(rank1, rank2)):
        # Missing leading dimensions behave like size 1.
        dim1 = shape1[rank1 - 1 - i] if i < rank1 else 1
        dim2 = shape2[rank2 - 1 - i] if i < rank2 else 1
        if dim1 == 1 or dim1 == dim2:
            merged = dim2
        elif dim2 == 1:
            merged = dim1
        else:
            merged = self.merge_symbols([dim1, dim2])
            if not merged:
                if self.auto_merge_:
                    self.add_suggested_merge([dim1, dim2], apply=True)
                else:
                    logger.warning(f"unsupported broadcast between {dim1!s} {dim2!s}")
        reversed_result.append(merged)
    return reversed_result[::-1]
|
|
322
|
+
|
|
323
|
+
# Shape computations
|
|
324
|
+
def compute_conv_pool_shape(self, node, channels_last=False):
    """Compute the symbolic output shape of a Conv/Pool-style node.

    Mirrors the ONNX conv/pool shape formula, honoring dilations, strides,
    pads, auto_pad and ceil_mode. *channels_last* selects NHWC layout.
    """
    sympy_shape = self.get_sympy_shape(node, 0)
    if len(node.input) > 1:
        # Kernel shape comes from the weight tensor; output channels from W[0].
        W_shape = self.get_sympy_shape(node, 1)
        rank = len(W_shape) - 2  # number of spatial axes
        kernel_shape = W_shape[-rank - 1 : -1] if channels_last else W_shape[-rank:]
        sympy_shape[3 if channels_last else 1] = W_shape[0]
    else:
        W_shape = None
        kernel_shape = get_attribute(node, "kernel_shape")
        rank = len(kernel_shape)

    assert len(sympy_shape) == rank + 2

    spatial_shape = sympy_shape[-rank - 1 : -1] if channels_last else sympy_shape[-rank:]
    is_symbolic_dims = [not is_literal(i) for i in spatial_shape]

    if not any(is_symbolic_dims):
        # Fully concrete spatial dims: trust the already-known output shape.
        shape = get_shape_from_value_info(self.known_vi_[node.output[0]])
        if shape:
            assert len(sympy_shape) == len(shape)
            if channels_last:
                sympy_shape[-rank - 1 : -1] = [sympy.Integer(d) for d in shape[-rank - 1 : -1]]
            else:
                sympy_shape[-rank:] = [sympy.Integer(d) for d in shape[-rank:]]
            return sympy_shape

    dilations = get_attribute(node, "dilations", [1] * rank)
    strides = get_attribute(node, "strides", [1] * rank)
    effective_kernel_shape = [(k - 1) * d + 1 for k, d in zip(kernel_shape, dilations)]
    pads = get_attribute(node, "pads")
    if pads is None:
        pads = [0] * (2 * rank)
        auto_pad = get_attribute(node, "auto_pad", b"NOTSET").decode("utf-8")
        if auto_pad not in {"VALID", "NOTSET"}:
            # SAME_UPPER / SAME_LOWER: derive total padding from the stride residual.
            try:
                residual = [sympy.Mod(d, s) for d, s in zip(sympy_shape[-rank:], strides)]
                total_pads = [
                    max(0, (k - s) if r == 0 else (k - r))
                    for k, s, r in zip(effective_kernel_shape, strides, residual)
                ]
            except TypeError:  # sympy comparison against symbolic residuals may fail
                total_pads = [max(0, (k - s)) for k, s in zip(effective_kernel_shape, strides)]
        elif auto_pad == "VALID":
            total_pads = []
        else:
            total_pads = [0] * rank
    else:
        assert len(pads) == 2 * rank
        total_pads = [p1 + p2 for p1, p2 in zip(pads[:rank], pads[rank:])]

    ceil_mode = get_attribute(node, "ceil_mode", 0)
    for i in range(rank):
        axis = -rank + i + (-1 if channels_last else 0)
        effective_input_size = sympy_shape[axis]
        if total_pads:
            effective_input_size = effective_input_size + total_pads[i]
        # Number of kernel placements along this axis (ceil vs floor division).
        if ceil_mode:
            strided_kernel_positions = sympy.ceiling(
                (effective_input_size - effective_kernel_shape[i]) / strides[i]
            )
        else:
            strided_kernel_positions = FloorDiv((effective_input_size - effective_kernel_shape[i]), strides[i])
        sympy_shape[axis] = strided_kernel_positions + 1
    return sympy_shape
|
|
389
|
+
|
|
390
|
+
def compute_matmul_shape(self, node, output_dtype=None):
    """Infer and record the output shape of a MatMul-style node.

    Handles the ONNX MatMul rank rules (vector/vector, vector/matrix,
    matrix/vector, batched matrix/matrix with broadcast batch dims).
    """
    lhs_shape = self.get_shape(node, 0)
    rhs_shape = self.get_shape(node, 1)
    lhs_rank, rhs_rank = len(lhs_shape), len(rhs_shape)
    lhs_reduce_dim = 0
    rhs_reduce_dim = 0
    assert lhs_rank > 0 and rhs_rank > 0
    if lhs_rank == 1 and rhs_rank == 1:
        # Dot product: scalar output.
        new_shape = []
    elif lhs_rank == 1:
        rhs_reduce_dim = -2
        new_shape = [*rhs_shape[:rhs_reduce_dim], rhs_shape[-1]]
    elif rhs_rank == 1:
        lhs_reduce_dim = -1
        new_shape = lhs_shape[:lhs_reduce_dim]
    else:
        lhs_reduce_dim = -1
        rhs_reduce_dim = -2
        # Batch dims broadcast; trailing two dims follow matrix-multiply rules.
        new_shape = [
            *self.broadcast_shapes(lhs_shape[:-2], rhs_shape[:-2]),
            lhs_shape[-2],
            rhs_shape[-1],
        ]
    # The contracted dimensions must agree exactly.
    self.check_merged_dims(
        [lhs_shape[lhs_reduce_dim], rhs_shape[rhs_reduce_dim]],
        allow_broadcast=False,
    )
    if output_dtype is None:
        # Infer output type from the first input when not given.
        output_dtype = self.known_vi_[node.input[0]].type.tensor_type.elem_type
    vi = self.known_vi_[node.output[0]]
    vi.CopyFrom(helper.make_tensor_value_info(node.output[0], output_dtype, new_shape))
|
|
423
|
+
|
|
424
|
+
# Value operations
|
|
425
|
+
def get_int_or_float_values(self, node, broadcast=False, allow_float_values=False):
    """Extract the node's input values as ints (or floats when allowed).

    Unknown inputs yield None; rank>1 arrays are unsupported and yield None.
    With *broadcast*, scalars and short lists are tiled to the longest list.
    """

    def coerce(value):
        # Keep genuine fractional floats only when allowed; otherwise truncate.
        return value if allow_float_values and value % 1 != 0 else int(value)

    values = [self.try_get_value(node, i) for i in range(len(node.input))]
    if all(v is not None for v in values):
        # Normalize numpy arrays into python scalars / flat lists.
        for i, v in enumerate(values):
            if type(v) != np.ndarray:
                continue
            if len(v.shape) > 1:
                values[i] = None
            elif len(v.shape) == 0:
                values[i] = coerce(v.item())
            else:
                assert len(v.shape) == 1
                values[i] = [coerce(item) for item in v]
    max_len = max(len(v) if isinstance(v, list) else 0 for v in values)
    if broadcast and max_len >= 1:
        for i, v in enumerate(values):
            if v is None:
                continue
            if not isinstance(v, list):
                values[i] = [v] * max_len
            elif len(v) < max_len:
                values[i] = v * max_len
            else:
                assert len(v) == max_len
    return values
|
|
458
|
+
|
|
459
|
+
def compute_on_sympy_data(self, node, op_func):
    """Evaluate *op_func* over the node's known input values and record the
    result in ``sympy_data_`` under the node's single output name.

    Args:
        node: Node whose inputs are evaluated; must have exactly one output.
        op_func: Callable taking the tuple/list of per-input values.

    If any input value is unknown, nothing is recorded.
    """
    assert len(node.output) == 1

    # Mul/Div may legitimately produce fractional results, so keep floats there.
    if node.op_type in {"Mul", "Div"}:
        values = self.get_int_or_float_values(node, broadcast=True, allow_float_values=True)
    else:
        values = self.get_int_or_float_values(node, broadcast=True)
    if all(v is not None for v in values):
        is_list = [isinstance(v, list) for v in values]
        # Renamed from `as_list`: the original local shadowed the `as_list`
        # helper imported from .utils at module level.
        elementwise = any(is_list)
        if elementwise:
            # Apply op_func across zipped positions of the broadcast lists.
            self.sympy_data_[node.output[0]] = [op_func(vs) for vs in zip(*values)]
        else:
            self.sympy_data_[node.output[0]] = op_func(values)
|
|
474
|
+
|
|
475
|
+
def pass_on_sympy_data(self, node):
    """Forward the first input's known value to the node's output unchanged."""
    # Only single-input ops, or shape-manipulating ops whose extra inputs do
    # not affect the carried value, may pass data through verbatim.
    assert len(node.input) == 1 or node.op_type in {"Reshape", "Unsqueeze", "Squeeze"}
    self.compute_on_sympy_data(node, lambda x: x[0])
|
|
483
|
+
|
|
484
|
+
# Shape propagation
|
|
485
|
+
def pass_on_shape_and_type(self, node):
    """Copy input 0's shape and element type onto output 0."""
    elem_type = get_elem_type_from_type_proto(self.known_vi_[node.input[0]].type)
    shape = self.get_shape(node, 0)
    vi = self.known_vi_[node.output[0]]
    vi.CopyFrom(helper.make_tensor_value_info(node.output[0], elem_type, shape))
|
|
495
|
+
|
|
496
|
+
def propagate_shape_and_type(self, node, input_index=0, output_index=0):
    """Propagate shape and dtype from one input tensor to one output tensor."""
    shape = self.get_shape(node, input_index)
    output_dtype = self.known_vi_[node.input[input_index]].type.tensor_type.elem_type
    out_name = node.output[output_index]
    self.known_vi_[out_name].CopyFrom(helper.make_tensor_value_info(out_name, output_dtype, shape))
|
|
502
|
+
|
|
503
|
+
def fuse_tensor_type(self, node, out_idx, dst_type, src_type):
    """Update dst_tensor_type to be compatible with src_tensor_type.

    Used when merging value infos coming from two sources (e.g. subgraph
    branches): element types must agree exactly; mismatched dimensions are
    replaced in-place by a fresh symbolic dim (tensors) or cleared
    (sequences).

    Raises:
        ValueError: if the two element types differ.
    """
    # Unwrap sequence types so both sides are compared as plain tensor types.
    dst_tensor_type = (
        dst_type.sequence_type.elem_type.tensor_type if is_sequence(dst_type) else dst_type.tensor_type
    )
    src_tensor_type = (
        src_type.sequence_type.elem_type.tensor_type if is_sequence(src_type) else src_type.tensor_type
    )
    if dst_tensor_type.elem_type != src_tensor_type.elem_type:
        node_id = node.name or node.op_type
        raise ValueError(
            f"For node {node_id}, dst_tensor_type.elem_type != src_tensor_type.elem_type: "
            f"{onnx.onnx_pb.TensorProto.DataType.Name(dst_tensor_type.elem_type)} vs "
            f"{onnx.onnx_pb.TensorProto.DataType.Name(src_tensor_type.elem_type)}"
        )
    if dst_tensor_type.HasField("shape"):
        # Merge dimension-by-dimension; zip stops at the shorter shape, so
        # extra trailing dims on either side are left untouched.
        for di, ds in enumerate(zip(dst_tensor_type.shape.dim, src_tensor_type.shape.dim)):
            if ds[0] != ds[1]:
                # Mismatch: for tensors, invent a new symbolic dim derived from
                # this node/output/axis; for sequences, replace with a fresh
                # (unset, i.e. unknown) dimension.
                new_dim = onnx.TensorShapeProto.Dimension()
                if not is_sequence(dst_type):
                    new_dim.dim_param = str(self.new_symbolic_dim_from_output(node, out_idx, di))
                dst_tensor_type.shape.dim[di].CopyFrom(new_dim)
    else:
        # dst carries no shape information yet — adopt src's tensor type wholesale.
        dst_tensor_type.CopyFrom(src_tensor_type)
|
|
527
|
+
|
|
528
|
+
# ONNX inference helpers
|
|
529
|
+
def onnx_infer_single_node(self, node):
    """Performs ONNX shape inference for a single node.

    Builds a throwaway single-node model around `node`, runs
    `onnx.shape_inference.infer_shapes` on it, and merges the inferred
    output value_infos into `self.known_vi_` / `self.out_mp_`.  Ops that
    the standard inference cannot handle (control flow, contrib ops) are
    skipped here — they only get a named placeholder value_info and are
    handled by dedicated dispatchers elsewhere.
    """
    # Ops onnx.shape_inference cannot (or should not) infer directly.
    skip_infer = node.op_type in {
        "If",
        "Loop",
        "Scan",
        "SplitToSequence",
        "ZipMap",
        # contrib ops
        "Attention",
        "BiasGelu",
        "EmbedLayerNormalization",
        "FastGelu",
        "Gelu",
        "GemmFastGelu",
        "LayerNormalization",
        "LongformerAttention",
        "DequantizeLinear",
        "QuantizeLinear",
        "RelativePositionBias",
        "RemovePadding",
        "RestorePadding",
        "SimplifiedLayerNormalization",
        "SkipLayerNormalization",
        "SkipSimplifiedLayerNormalization",
        "PackedAttention",
        "PythonOp",
        "MultiHeadAttention",
        "GroupNorm",
        "SkipGroupNorm",
        "BiasSplitGelu",
        "BiasAdd",
        "NhwcConv",
        "QuickGelu",
        "RotaryEmbedding",
    }

    if not skip_infer:
        initializers = []
        # These ops read attribute-like data (axes, dft_length, ...) from
        # inputs since opset 9+, so the corresponding initializers must be
        # visible to onnx shape inference.  (Set membership replaces the
        # original `==`-or-chain; same ops, same behavior.)
        if get_opset(self.out_mp_) >= 9 and node.op_type in {
            "Unsqueeze",
            "ReduceMax",
            "ReduceMean",
            "DFT",
            "ReduceL2",
            "ReduceMin",
        }:
            initializers = [
                self.initializers_[name]
                for name in node.input
                if (name in self.initializers_ and name not in self.graph_inputs_)
            ]

        # For multidirectional-broadcast ops, validate that the input dims
        # are broadcast-compatible before running inference.
        if (
            node.op_type
            in {
                "Add",
                "Sub",
                "Mul",
                "Div",
                "MatMul",
                "MatMulInteger",
                "MatMulInteger16",
                "Where",
                "Sum",
            }
            and node.output[0] in self.known_vi_
        ):
            vi = self.known_vi_[node.output[0]]
            out_rank = len(get_shape_from_type_proto(vi.type))
            in_shapes = [self.get_shape(node, i) for i in range(len(node.input))]
            # MatMul-like ops broadcast only over the batch dims (all but the last two).
            for d in range(out_rank - (2 if node.op_type in {"MatMul", "MatMulInteger", "MatMulInteger16"} else 0)):
                in_dims = [s[len(s) - out_rank + d] for s in in_shapes if len(s) + d >= out_rank]
                if len(in_dims) > 1:
                    self.check_merged_dims(in_dims, allow_broadcast=True)

        tmp_graph = helper.make_graph(
            [node],
            "tmp",
            [self.known_vi_[i] for i in node.input if i],
            [make_named_value_info(i) for i in node.output],
            initializers,
        )

        kwargs = {}
        kwargs["opset_imports"] = self.out_mp_.opset_import
        kwargs["ir_version"] = self.out_mp_.ir_version

        model = helper.make_model(tmp_graph, **kwargs)
        model = shape_inference.infer_shapes(model)

    for i_o in range(len(node.output)):
        o = node.output[i_o]
        if not o:  # skip optional (empty-named) outputs
            continue
        # BUGFIX: `model` only exists on the non-skip path; previously
        # `model.graph.output[i_o]` was dereferenced unconditionally, which
        # raised NameError for every skipped op reaching this loop.
        if not skip_infer:
            out = model.graph.output[i_o]
            # Inference produced no type info and we already know this value:
            # keep the existing value_info instead of clobbering it.
            if not out.type.WhichOneof("value") and o in self.known_vi_:
                continue

        vi = self.out_mp_.graph.value_info.add()
        if not skip_infer:
            vi.CopyFrom(out)
        else:
            vi.name = o  # placeholder; a dedicated dispatcher fills it in
        self.known_vi_[o] = vi
|
|
632
|
+
|
|
633
|
+
# Helper methods for checking none dims
|
|
634
|
+
def is_none_dim(self, dim_value):
    """Return True if `dim_value` is an unresolved auto-generated symbolic dim.

    A dim counts as "none"/unknown when it is a string containing the
    "unk__" marker (onnx shape inference's auto-name) that has no binding
    recorded in `self.symbolic_dims_`.  Non-string dims (ints, None) are
    never "none" in this sense.
    """
    # isinstance instead of `type(...) != str`: idiomatic, and also accepts
    # str subclasses.
    if not isinstance(dim_value, str):
        return False
    return "unk__" in dim_value and dim_value not in self.symbolic_dims_
|
|
639
|
+
|
|
640
|
+
def is_shape_contains_none_dim(self, out_shape):
    """Return the first unknown dimension found in `out_shape`, else None."""
    # Lazily scan; short-circuits on the first hit just like the explicit loop.
    return next((dim for dim in out_shape if self.is_none_dim(dim)), None)
|
|
@@ -0,0 +1,15 @@
|
|
|
1
|
+
# Copyright (c) Microsoft Corporation. All rights reserved.
|
|
2
|
+
# Licensed under the MIT License.
|
|
3
|
+
|
|
4
|
+
"""Attention-related contrib operator shape handlers."""
|
|
5
|
+
|
|
6
|
+
from . import attention
|
|
7
|
+
from . import multi_head_attention
|
|
8
|
+
from . import packed_attention
|
|
9
|
+
from . import packed_multi_head_attention
|
|
10
|
+
from . import gated_relative_position_bias
|
|
11
|
+
from . import multi_scale_deformable_attn
|
|
12
|
+
from . import longformer_attention
|
|
13
|
+
from . import decoder_masked_mha
|
|
14
|
+
from . import remove_padding
|
|
15
|
+
from . import restore_padding
|