onnxslim 0.1.82__py3-none-any.whl → 0.1.84__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- onnxslim/core/optimization/dead_node_elimination.py +85 -4
- onnxslim/core/pattern/elimination/slice.py +15 -8
- onnxslim/core/pattern/fusion/concat_reshape.py +3 -1
- onnxslim/core/pattern/fusion/convadd.py +23 -7
- onnxslim/core/pattern/fusion/convbn.py +24 -11
- onnxslim/core/pattern/fusion/convmul.py +26 -9
- onnxslim/core/pattern/fusion/gemm.py +7 -5
- onnxslim/core/pattern/fusion/padconv.py +5 -0
- onnxslim/core/shape_inference/__init__.py +378 -0
- onnxslim/core/shape_inference/aten_ops/__init__.py +16 -0
- onnxslim/core/shape_inference/aten_ops/argmax.py +47 -0
- onnxslim/core/shape_inference/aten_ops/bitwise_or.py +28 -0
- onnxslim/core/shape_inference/aten_ops/diagonal.py +52 -0
- onnxslim/core/shape_inference/aten_ops/embedding.py +23 -0
- onnxslim/core/shape_inference/aten_ops/group_norm.py +41 -0
- onnxslim/core/shape_inference/aten_ops/min_max.py +64 -0
- onnxslim/core/shape_inference/aten_ops/multinomial.py +39 -0
- onnxslim/core/shape_inference/aten_ops/numpy_t.py +22 -0
- onnxslim/core/shape_inference/aten_ops/pool2d.py +40 -0
- onnxslim/core/shape_inference/aten_ops/unfold.py +44 -0
- onnxslim/core/shape_inference/aten_ops/upsample.py +44 -0
- onnxslim/core/shape_inference/base.py +111 -0
- onnxslim/core/shape_inference/context.py +645 -0
- onnxslim/core/shape_inference/contrib_ops/__init__.py +8 -0
- onnxslim/core/shape_inference/contrib_ops/attention/__init__.py +15 -0
- onnxslim/core/shape_inference/contrib_ops/attention/attention.py +61 -0
- onnxslim/core/shape_inference/contrib_ops/attention/decoder_masked_mha.py +37 -0
- onnxslim/core/shape_inference/contrib_ops/attention/gated_relative_position_bias.py +35 -0
- onnxslim/core/shape_inference/contrib_ops/attention/longformer_attention.py +21 -0
- onnxslim/core/shape_inference/contrib_ops/attention/multi_head_attention.py +82 -0
- onnxslim/core/shape_inference/contrib_ops/attention/multi_scale_deformable_attn.py +29 -0
- onnxslim/core/shape_inference/contrib_ops/attention/packed_attention.py +39 -0
- onnxslim/core/shape_inference/contrib_ops/attention/packed_multi_head_attention.py +33 -0
- onnxslim/core/shape_inference/contrib_ops/attention/remove_padding.py +41 -0
- onnxslim/core/shape_inference/contrib_ops/attention/restore_padding.py +29 -0
- onnxslim/core/shape_inference/contrib_ops/misc/__init__.py +15 -0
- onnxslim/core/shape_inference/contrib_ops/misc/bias_add.py +21 -0
- onnxslim/core/shape_inference/contrib_ops/misc/bias_gelu.py +21 -0
- onnxslim/core/shape_inference/contrib_ops/misc/bias_split_gelu.py +30 -0
- onnxslim/core/shape_inference/contrib_ops/misc/fast_gelu.py +21 -0
- onnxslim/core/shape_inference/contrib_ops/misc/gelu.py +21 -0
- onnxslim/core/shape_inference/contrib_ops/misc/gemm_fast_gelu.py +21 -0
- onnxslim/core/shape_inference/contrib_ops/misc/gemm_float8.py +21 -0
- onnxslim/core/shape_inference/contrib_ops/misc/python_op.py +67 -0
- onnxslim/core/shape_inference/contrib_ops/misc/quick_gelu.py +21 -0
- onnxslim/core/shape_inference/contrib_ops/misc/rotary_embedding.py +31 -0
- onnxslim/core/shape_inference/contrib_ops/normalization/__init__.py +12 -0
- onnxslim/core/shape_inference/contrib_ops/normalization/embed_layer_normalization.py +41 -0
- onnxslim/core/shape_inference/contrib_ops/normalization/group_norm.py +21 -0
- onnxslim/core/shape_inference/contrib_ops/normalization/layer_normalization.py +42 -0
- onnxslim/core/shape_inference/contrib_ops/normalization/simplified_layer_normalization.py +23 -0
- onnxslim/core/shape_inference/contrib_ops/normalization/skip_group_norm.py +23 -0
- onnxslim/core/shape_inference/contrib_ops/normalization/skip_layer_normalization.py +26 -0
- onnxslim/core/shape_inference/contrib_ops/normalization/skip_simplified_layer_normalization.py +23 -0
- onnxslim/core/shape_inference/registry.py +90 -0
- onnxslim/core/shape_inference/standard_ops/__init__.py +11 -0
- onnxslim/core/shape_inference/standard_ops/control_flow/__init__.py +8 -0
- onnxslim/core/shape_inference/standard_ops/control_flow/if_op.py +43 -0
- onnxslim/core/shape_inference/standard_ops/control_flow/loop.py +74 -0
- onnxslim/core/shape_inference/standard_ops/control_flow/scan.py +54 -0
- onnxslim/core/shape_inference/standard_ops/math/__init__.py +20 -0
- onnxslim/core/shape_inference/standard_ops/math/_symbolic_compute.py +34 -0
- onnxslim/core/shape_inference/standard_ops/math/add.py +10 -0
- onnxslim/core/shape_inference/standard_ops/math/div.py +10 -0
- onnxslim/core/shape_inference/standard_ops/math/einsum.py +119 -0
- onnxslim/core/shape_inference/standard_ops/math/equal.py +10 -0
- onnxslim/core/shape_inference/standard_ops/math/floor.py +10 -0
- onnxslim/core/shape_inference/standard_ops/math/matmul.py +21 -0
- onnxslim/core/shape_inference/standard_ops/math/matmul_integer.py +23 -0
- onnxslim/core/shape_inference/standard_ops/math/max.py +10 -0
- onnxslim/core/shape_inference/standard_ops/math/min.py +10 -0
- onnxslim/core/shape_inference/standard_ops/math/mul.py +10 -0
- onnxslim/core/shape_inference/standard_ops/math/neg.py +10 -0
- onnxslim/core/shape_inference/standard_ops/math/reduce_prod.py +27 -0
- onnxslim/core/shape_inference/standard_ops/math/reduce_sum.py +53 -0
- onnxslim/core/shape_inference/standard_ops/math/sub.py +10 -0
- onnxslim/core/shape_inference/standard_ops/math/where.py +10 -0
- onnxslim/core/shape_inference/standard_ops/misc/__init__.py +22 -0
- onnxslim/core/shape_inference/standard_ops/misc/array_feature_extractor.py +32 -0
- onnxslim/core/shape_inference/standard_ops/misc/cast.py +21 -0
- onnxslim/core/shape_inference/standard_ops/misc/category_mapper.py +30 -0
- onnxslim/core/shape_inference/standard_ops/misc/compress.py +39 -0
- onnxslim/core/shape_inference/standard_ops/misc/constant.py +27 -0
- onnxslim/core/shape_inference/standard_ops/misc/constant_of_shape.py +45 -0
- onnxslim/core/shape_inference/standard_ops/misc/dequantize_linear.py +26 -0
- onnxslim/core/shape_inference/standard_ops/misc/non_max_suppression.py +26 -0
- onnxslim/core/shape_inference/standard_ops/misc/non_zero.py +26 -0
- onnxslim/core/shape_inference/standard_ops/misc/one_hot.py +42 -0
- onnxslim/core/shape_inference/standard_ops/misc/quantize_linear.py +29 -0
- onnxslim/core/shape_inference/standard_ops/misc/range.py +41 -0
- onnxslim/core/shape_inference/standard_ops/misc/relative_position_bias.py +31 -0
- onnxslim/core/shape_inference/standard_ops/misc/resize.py +74 -0
- onnxslim/core/shape_inference/standard_ops/misc/scatter_elements.py +31 -0
- onnxslim/core/shape_inference/standard_ops/misc/softmax_cross_entropy_loss.py +44 -0
- onnxslim/core/shape_inference/standard_ops/misc/top_k.py +44 -0
- onnxslim/core/shape_inference/standard_ops/nn/__init__.py +18 -0
- onnxslim/core/shape_inference/standard_ops/nn/all_reduce.py +9 -0
- onnxslim/core/shape_inference/standard_ops/nn/average_pool.py +40 -0
- onnxslim/core/shape_inference/standard_ops/nn/batch_normalization.py +26 -0
- onnxslim/core/shape_inference/standard_ops/nn/conv.py +33 -0
- onnxslim/core/shape_inference/standard_ops/nn/cum_sum.py +9 -0
- onnxslim/core/shape_inference/standard_ops/nn/identity.py +9 -0
- onnxslim/core/shape_inference/standard_ops/nn/max_pool.py +9 -0
- onnxslim/core/shape_inference/standard_ops/nn/memcpy_from_host.py +9 -0
- onnxslim/core/shape_inference/standard_ops/nn/memcpy_to_host.py +9 -0
- onnxslim/core/shape_inference/standard_ops/nn/moe.py +9 -0
- onnxslim/core/shape_inference/standard_ops/nn/nhwc_conv.py +33 -0
- onnxslim/core/shape_inference/standard_ops/nn/reciprocal.py +9 -0
- onnxslim/core/shape_inference/standard_ops/nn/round.py +9 -0
- onnxslim/core/shape_inference/standard_ops/sequence/__init__.py +10 -0
- onnxslim/core/shape_inference/standard_ops/sequence/concat_from_sequence.py +40 -0
- onnxslim/core/shape_inference/standard_ops/sequence/sequence_at.py +31 -0
- onnxslim/core/shape_inference/standard_ops/sequence/sequence_insert.py +26 -0
- onnxslim/core/shape_inference/standard_ops/sequence/split_to_sequence.py +24 -0
- onnxslim/core/shape_inference/standard_ops/sequence/zip_map.py +36 -0
- onnxslim/core/shape_inference/standard_ops/tensor/__init__.py +20 -0
- onnxslim/core/shape_inference/standard_ops/tensor/concat.py +62 -0
- onnxslim/core/shape_inference/standard_ops/tensor/expand.py +36 -0
- onnxslim/core/shape_inference/standard_ops/tensor/gather.py +48 -0
- onnxslim/core/shape_inference/standard_ops/tensor/gather_elements.py +31 -0
- onnxslim/core/shape_inference/standard_ops/tensor/gather_nd.py +42 -0
- onnxslim/core/shape_inference/standard_ops/tensor/pad.py +41 -0
- onnxslim/core/shape_inference/standard_ops/tensor/reshape.py +72 -0
- onnxslim/core/shape_inference/standard_ops/tensor/shape.py +38 -0
- onnxslim/core/shape_inference/standard_ops/tensor/size.py +29 -0
- onnxslim/core/shape_inference/standard_ops/tensor/slice.py +183 -0
- onnxslim/core/shape_inference/standard_ops/tensor/split.py +57 -0
- onnxslim/core/shape_inference/standard_ops/tensor/squeeze.py +69 -0
- onnxslim/core/shape_inference/standard_ops/tensor/tile.py +41 -0
- onnxslim/core/shape_inference/standard_ops/tensor/transpose.py +30 -0
- onnxslim/core/shape_inference/standard_ops/tensor/unsqueeze.py +54 -0
- onnxslim/core/shape_inference/utils.py +244 -0
- onnxslim/third_party/onnx_graphsurgeon/ir/graph.py +0 -103
- onnxslim/third_party/symbolic_shape_infer.py +73 -3156
- onnxslim/utils.py +4 -2
- {onnxslim-0.1.82.dist-info → onnxslim-0.1.84.dist-info}/METADATA +21 -11
- onnxslim-0.1.84.dist-info/RECORD +187 -0
- onnxslim-0.1.82.dist-info/RECORD +0 -63
- {onnxslim-0.1.82.dist-info → onnxslim-0.1.84.dist-info}/WHEEL +0 -0
- {onnxslim-0.1.82.dist-info → onnxslim-0.1.84.dist-info}/entry_points.txt +0 -0
- {onnxslim-0.1.82.dist-info → onnxslim-0.1.84.dist-info}/licenses/LICENSE +0 -0
|
@@ -0,0 +1,378 @@
|
|
|
1
|
+
# Copyright (c) Microsoft Corporation. All rights reserved.
|
|
2
|
+
# Licensed under the MIT License.
|
|
3
|
+
|
|
4
|
+
"""
|
|
5
|
+
Symbolic Shape Inference Module
|
|
6
|
+
|
|
7
|
+
This module provides symbolic shape inference for ONNX models. It replaces the
|
|
8
|
+
monolithic SymbolicShapeInference class with a modular, handler-based architecture.
|
|
9
|
+
|
|
10
|
+
Usage:
|
|
11
|
+
import onnx
from onnxslim.core.shape_inference import ShapeInferencer
|
|
12
|
+
|
|
13
|
+
model = onnx.load("model.onnx")
|
|
14
|
+
model_with_shapes = ShapeInferencer.infer_shapes(model)
|
|
15
|
+
"""
|
|
16
|
+
|
|
17
|
+
import logging
|
|
18
|
+
|
|
19
|
+
import onnx
|
|
20
|
+
import sympy
|
|
21
|
+
from onnx import helper
|
|
22
|
+
|
|
23
|
+
from .context import InferenceContext
|
|
24
|
+
from .registry import get_all_aten_handlers, get_all_shape_handlers, get_aten_handler, get_shape_handler
|
|
25
|
+
from .utils import (
|
|
26
|
+
get_attribute,
|
|
27
|
+
get_opset,
|
|
28
|
+
get_shape_from_type_proto,
|
|
29
|
+
get_shape_from_value_info,
|
|
30
|
+
is_literal,
|
|
31
|
+
is_sequence,
|
|
32
|
+
)
|
|
33
|
+
|
|
34
|
+
# Import all handlers to trigger registration
|
|
35
|
+
from . import aten_ops # noqa: F401
|
|
36
|
+
from . import contrib_ops # noqa: F401
|
|
37
|
+
from . import standard_ops # noqa: F401
|
|
38
|
+
|
|
39
|
+
# Module-scoped logger; all shape-inference diagnostics below are emitted through it.
logger = logging.getLogger(name=__name__)
|
|
40
|
+
|
|
41
|
+
|
|
42
|
+
class ShapeInferencer:
    """Main class for performing symbolic shape inference on ONNX models.

    Per-op logic lives in handlers looked up through the registry. This class
    drives the inference passes: it names unknown graph-input dims, orders the
    nodes topologically, dispatches every node to its handler and, when an
    output shape is still incomplete, either records suggested dim merges
    (``auto_merge``) or assigns fresh symbolic dims to that output.
    """

    # MatMul-like ops: the trailing two dims are matrix dims and are excluded
    # from the broadcast consistency check.
    _MATMUL_OPS = frozenset({"MatMul", "MatMulInteger", "MatMulInteger16"})
    # Ops whose right-aligned input dims are passed to ctx.check_merged_dims.
    _BROADCAST_CHECK_OPS = frozenset({"Add", "Sub", "Mul", "Div", "Where", "Sum"}) | _MATMUL_OPS
    # Ops for which incomplete output dims produce suggested merges under auto_merge.
    _AUTO_MERGE_OPS = _BROADCAST_CHECK_OPS | frozenset(
        {"Concat", "Equal", "Less", "Greater", "LessOrEqual", "GreaterOrEqual", "Min", "Max"}
    )
    # Ops whose outputs 1 and 2 are not checked for completeness.
    _SKIP_LAYER_NORM_OPS = frozenset({"SkipLayerNormalization", "SkipSimplifiedLayerNormalization"})

    def __init__(self, int_max=2**31 - 1, auto_merge=False, guess_output_rank=False, verbose=0, prefix=""):
        """Initialize the ShapeInferencer.

        Args:
            int_max: Maximum value for unbounded integers.
            auto_merge: Whether to automatically merge conflicting dimensions.
            guess_output_rank: Whether to guess output rank from input.
            verbose: Logging verbosity level.
            prefix: Prefix for generated symbolic dimension names.
        """
        self.int_max_ = int_max
        self.auto_merge_ = auto_merge
        self.guess_output_rank_ = guess_output_rank
        self.verbose_ = verbose
        self.prefix_ = prefix

    @staticmethod
    def _register_input_symbols(ctx):
        """Name unknown graph-input dims and bind each input symbol to a sympy symbol.

        Dims reported as ``None`` get a fresh ``dim_param`` from
        ``ctx.new_symbolic_dim``. Every string dim becomes either the symbol it
        was suggested to merge into, or a new positive integer sympy Symbol.
        """
        ctx.input_symbols_ = set()
        for graph_input in ctx.out_mp_.graph.input:
            input_shape = get_shape_from_value_info(graph_input)
            if input_shape is None:
                continue

            if is_sequence(graph_input.type):
                input_dims = graph_input.type.sequence_type.elem_type.tensor_type.shape.dim
            else:
                input_dims = graph_input.type.tensor_type.shape.dim

            for i_dim, dim in enumerate(input_shape):
                if dim is None:
                    input_dims[i_dim].dim_param = str(ctx.new_symbolic_dim(graph_input.name, i_dim))

            # input_shape was read before the renaming above, so only
            # pre-existing string dims are collected here.
            ctx.input_symbols_.update(d for d in input_shape if isinstance(d, str))

        for s in ctx.input_symbols_:
            if s in ctx.suggested_merge_:
                s_merge = ctx.suggested_merge_[s]
                assert s_merge in ctx.symbolic_dims_
                ctx.symbolic_dims_[s] = ctx.symbolic_dims_[s_merge]
            else:
                ctx.symbolic_dims_[s] = sympy.Symbol(s, integer=True, positive=True)

    @classmethod
    def _get_prereq(cls, node):
        """Return the names *node* depends on, including implicit subgraph inputs.

        For If/Loop/Scan nodes, names consumed inside the subgraph(s) that are
        not produced there (by a node or an initializer) are added. The
        subgraph's own formal inputs are then removed.
        """
        names = {i for i in node.input if i}
        if node.op_type == "If":
            subgraphs = [get_attribute(node, "then_branch"), get_attribute(node, "else_branch")]
        elif node.op_type in {"Loop", "Scan"}:
            subgraphs = [get_attribute(node, "body")]
        else:
            subgraphs = []

        for g in subgraphs:
            g_outputs_and_initializers = {init.name for init in g.initializer}
            for n in g.node:
                g_outputs_and_initializers.update(n.output)
            g_prereq = set()
            for n in g.node:
                g_prereq.update(i for i in cls._get_prereq(n) if i not in g_outputs_and_initializers)
            names.update(g_prereq)
            # Subgraph inputs are local to the subgraph, not outer dependencies.
            for g_input in g.input:
                names.discard(g_input.name)
        return names

    @classmethod
    def _topological_sort(cls, graph):
        """Return the nodes of *graph* in dependency order.

        If any graph output is already a graph input or initializer (as in
        Loop/Scan bodies), the original node order is returned unchanged.
        Sorting stops once every graph output is reachable, so dead nodes may
        be left out.

        Raises:
            Exception: If a pass makes no progress before all outputs are reached.
        """
        prereq_for_node = {n.output[0]: cls._get_prereq(n) for n in graph.node}
        known = {i.name for i in graph.input} | {i.name for i in graph.initializer}
        if any(o.name in known for o in graph.output):
            return list(graph.node)

        sorted_nodes = []
        while not all(o.name in known for o in graph.output):
            count_before = len(sorted_nodes)
            for node in graph.node:
                if node.output[0] not in known and all(i in known for i in prereq_for_node[node.output[0]] if i):
                    known.update(node.output)
                    sorted_nodes.append(node)
            if count_before == len(sorted_nodes) and not all(o.name in known for o in graph.output):
                raise Exception("Invalid model with cyclic graph")
        return sorted_nodes

    @staticmethod
    def _dispatch(node, ctx):
        """Run the registered shape handler for *node*.

        Returns:
            A ``(handler, known_aten_op)`` tuple. ``handler`` is the standard
            handler that ran, or None. ``known_aten_op`` is True when an ATen
            handler ran.
        """
        handler = get_shape_handler(node.op_type)
        if handler is not None:
            handler.infer_shape(node, ctx)
            return handler, False

        if node.op_type == "ConvTranspose":
            # An empty inferred shape is marked UNDEFINED so the output rank can
            # be guessed later.
            vi = ctx.known_vi_[node.output[0]]
            if len(vi.type.tensor_type.shape.dim) == 0:
                vi.type.tensor_type.elem_type = onnx.TensorProto.UNDEFINED
        elif node.op_type == "ATen" and node.domain == "org.pytorch.aten":
            for attr in node.attribute:
                if attr.name == "operator":
                    aten_op_name = attr.s.decode("utf-8") if isinstance(attr.s, bytes) else attr.s
                    aten_handler = get_aten_handler(aten_op_name)
                    if aten_handler is not None:
                        aten_handler.infer_shape(node, ctx)
                        return None, True
                    break
        return None, False

    @classmethod
    def _check_broadcast_dims(cls, node, ctx):
        """Check right-aligned input dims of broadcasting ops with ctx.check_merged_dims."""
        if node.op_type not in cls._BROADCAST_CHECK_OPS:
            return
        vi = ctx.known_vi_[node.output[0]]
        out_rank = len(get_shape_from_type_proto(vi.type))
        in_shapes = [ctx.get_shape(node, i) for i in range(len(node.input))]
        num_checked = out_rank - (2 if node.op_type in cls._MATMUL_OPS else 0)
        for d in range(num_checked):
            in_dims = [s[len(s) - out_rank + d] for s in in_shapes if len(s) + d >= out_rank]
            if len(in_dims) > 1:
                ctx.check_merged_dims(in_dims, allow_broadcast=True)

    @classmethod
    def _suggest_merges(cls, node, ctx, out_shape):
        """Record suggested dim merges for the incomplete dims of *out_shape*.

        Returns:
            True if the op is eligible for auto-merge (another pass may help),
            otherwise False. A missing output shape (``None``) is never
            eligible, because there are no dims to align.
        """
        if out_shape is None:
            return False

        if node.op_type in cls._AUTO_MERGE_OPS:
            shapes = [ctx.get_shape(node, i) for i in range(len(node.input))]
            if node.op_type in cls._MATMUL_OPS and (None in out_shape or ctx.is_shape_contains_none_dim(out_shape)):
                if None in out_shape:
                    idx = out_shape.index(None)
                else:
                    idx = out_shape.index(ctx.is_shape_contains_none_dim(out_shape))
                dim_idx = [len(s) - len(out_shape) + idx for s in shapes]
                # Auto-merge for MatMul is only supported on batch dims (rank > 2).
                assert len(shapes[0]) > 2 and dim_idx[0] < len(shapes[0]) - 2
                assert len(shapes[1]) > 2 and dim_idx[1] < len(shapes[1]) - 2
        elif node.op_type == "Expand":
            shapes = [ctx.get_shape(node, 0), ctx.get_value(node, 1)]
        else:
            shapes = []

        if not shapes:
            return False

        for idx in range(len(out_shape)):
            if out_shape[idx] is not None and not ctx.is_none_dim(out_shape[idx]):
                continue
            # Broadcasting aligns dims from the right.
            dim_idx = [len(s) - len(out_shape) + idx for s in shapes]
            if dim_idx:
                ctx.add_suggested_merge(
                    [s[i] if is_literal(s[i]) else str(s[i]) for s, i in zip(shapes, dim_idx) if i >= 0]
                )
        return True

    def _infer_impl(self, ctx, start_sympy_data=None):
        """Run one inference pass over the graph.

        Args:
            ctx: The InferenceContext holding the model and inference state.
            start_sympy_data: Optional initial mapping for ``ctx.sympy_data_``.

        Returns:
            True when no incomplete output shape was hit; ``ctx.run_`` is then
            reset to False. False when the pass stopped at an incomplete output;
            ``ctx.run_`` is left True only if merges were suggested, so another
            pass may make progress.
        """
        from .utils import get_shape_from_sympy_shape

        ctx.sympy_data_ = start_sympy_data or {}
        ctx.apply_suggested_merge(graph_input_only=True)
        self._register_input_symbols(ctx)

        for node in self._topological_sort(ctx.out_mp_.graph):
            assert all(i in ctx.known_vi_ for i in node.input if i)
            ctx.onnx_infer_single_node(node)
            handler, known_aten_op = self._dispatch(node, ctx)

            if ctx.verbose_ > 2:
                logger.debug(node.op_type + ": " + node.name)
                for i, name in enumerate(node.input):
                    logger.debug(f" Input {i}: {name} {'initializer' if name in ctx.initializers_ else ''}")

            self._check_broadcast_dims(node, ctx)

            for i_o, output_name in enumerate(node.output):
                if node.op_type in self._SKIP_LAYER_NORM_OPS and i_o in {1, 2}:
                    continue
                # Exported RotaryEmbedding functions with extra outputs are skipped.
                if node.op_type == "RotaryEmbedding" and len(node.output) > 1:
                    continue

                vi = ctx.known_vi_[output_name]
                out_type = vi.type
                out_type_kind = out_type.WhichOneof("value")

                # Shapes are only tracked for (sparse) tensors.
                if out_type_kind not in {"tensor_type", "sparse_tensor_type", None}:
                    if ctx.verbose_ > 2:
                        if out_type_kind == "sequence_type":
                            seq_cls_type = out_type.sequence_type.elem_type.WhichOneof("value")
                            if seq_cls_type == "tensor_type":
                                logger.debug(
                                    f" {output_name}: sequence of {str(get_shape_from_value_info(vi))} "
                                    f"{onnx.TensorProto.DataType.Name(vi.type.sequence_type.elem_type.tensor_type.elem_type)}"
                                )
                            else:
                                logger.debug(f" {output_name}: sequence of {seq_cls_type}")
                        else:
                            logger.debug(f" {output_name}: {out_type_kind}")
                    continue

                out_shape = get_shape_from_value_info(vi)
                out_type_undefined = out_type.tensor_type.elem_type == onnx.TensorProto.UNDEFINED
                if ctx.verbose_ > 2:
                    logger.debug(
                        f" {output_name}: {out_shape!s} {onnx.TensorProto.DataType.Name(vi.type.tensor_type.elem_type)}"
                    )
                    if output_name in ctx.sympy_data_:
                        logger.debug(" Sympy Data: " + str(ctx.sympy_data_[output_name]))

                has_unknown_dim = out_shape is not None and (
                    None in out_shape or ctx.is_shape_contains_none_dim(out_shape)
                )
                if not (has_unknown_dim or out_type_undefined):
                    continue

                ctx.run_ = self._suggest_merges(node, ctx, out_shape) if ctx.auto_merge_ else False

                # Ops without a dedicated handler: assign fresh symbolic dims.
                if not ctx.run_ and handler is None and not known_aten_op:
                    is_unknown_op = out_type_undefined and (out_shape is None or len(out_shape) == 0)
                    if is_unknown_op:
                        # Unknown to ONNX; guess the rank from input 0 only if requested.
                        out_rank = ctx.get_shape_rank(node, 0) if ctx.guess_output_rank_ else -1
                    else:
                        out_rank = len(out_shape)

                    if out_rank >= 0:
                        new_shape = ctx.new_symbolic_shape(out_rank, node, i_o)
                        if out_type_undefined:
                            out_dtype = ctx.known_vi_[node.input[0]].type.tensor_type.elem_type
                        else:
                            out_dtype = vi.type.tensor_type.elem_type
                        vi.CopyFrom(
                            helper.make_tensor_value_info(vi.name, out_dtype, get_shape_from_sympy_shape(new_shape))
                        )

                        if ctx.verbose_ > 0:
                            if is_unknown_op:
                                logger.debug(
                                    f"Possible unknown op: {node.op_type} node: {node.name}, guessing {vi.name} shape"
                                )
                            if ctx.verbose_ > 2:
                                logger.debug(f" {output_name}: {new_shape!s} {vi.type.tensor_type.elem_type}")
                        ctx.run_ = True
                        continue

                if ctx.verbose_ > 0 or not ctx.auto_merge_ or out_type_undefined:
                    logger.debug("Stopping at incomplete shape inference at " + node.op_type + ": " + node.name)
                    logger.debug("node inputs:")
                    for i in node.input:
                        if i in ctx.known_vi_:
                            logger.debug(ctx.known_vi_[i])
                        else:
                            logger.debug(f"not in known_vi_ for {i}")
                    logger.debug("node outputs:")
                    for o in node.output:
                        if o in ctx.known_vi_:
                            logger.debug(ctx.known_vi_[o])
                        else:
                            logger.debug(f"not in known_vi_ for {o}")
                    if ctx.auto_merge_ and not out_type_undefined:
                        logger.debug("Merging: " + str(ctx.suggested_merge_))
                return False

        ctx.run_ = False
        return True

    def _update_output_from_vi(self, ctx):
        """Copy the inferred value info of every known graph output into the graph outputs."""
        for output in ctx.out_mp_.graph.output:
            if output.name in ctx.known_vi_:
                output.CopyFrom(ctx.known_vi_[output.name])

    @staticmethod
    def infer_shapes(in_mp, int_max=2**31 - 1, auto_merge=False, guess_output_rank=False, verbose=0):
        """Perform symbolic shape inference on an ONNX model.

        Args:
            in_mp: The input ONNX ModelProto.
            int_max: Maximum value for unbounded integers.
            auto_merge: Whether to automatically merge conflicting dimensions.
            guess_output_rank: Whether to guess output rank from input.
            verbose: Logging verbosity level.

        Returns:
            The model with inferred shapes, or None when the model's opset is
            missing or below 7.

        Raises:
            Exception: If shape inference is incomplete.
        """
        onnx_opset = get_opset(in_mp)
        if (not onnx_opset) or onnx_opset < 7:
            logger.warning("Only support models of onnx opset 7 and above.")
            return None

        inferencer = ShapeInferencer(int_max, auto_merge, guess_output_rank, verbose)

        ctx = InferenceContext(
            in_mp,
            int_max=int_max,
            auto_merge=auto_merge,
            guess_output_rank=guess_output_rank,
            verbose=verbose,
        )
        ctx.preprocess()

        # Repeat passes while a pass asks for another one (ctx.run_).
        all_shapes_inferred = False
        while ctx.run_:
            all_shapes_inferred = inferencer._infer_impl(ctx)

        inferencer._update_output_from_vi(ctx)

        if not all_shapes_inferred:
            raise Exception("Incomplete symbolic shape inference")

        return ctx.out_mp_
|
|
375
|
+
|
|
376
|
+
|
|
377
|
+
# For backward compatibility: code that still refers to the old
# ``SymbolicShapeInference`` name gets ``ShapeInferencer``.
SymbolicShapeInference = ShapeInferencer
|
|
@@ -0,0 +1,16 @@
|
|
|
1
|
+
# Copyright (c) Microsoft Corporation. All rights reserved.
|
|
2
|
+
# Licensed under the MIT License.
|
|
3
|
+
|
|
4
|
+
"""PyTorch ATen operator shape handlers."""
|
|
5
|
+
|
|
6
|
+
from . import bitwise_or
|
|
7
|
+
from . import diagonal
|
|
8
|
+
from . import pool2d
|
|
9
|
+
from . import min_max
|
|
10
|
+
from . import multinomial
|
|
11
|
+
from . import unfold
|
|
12
|
+
from . import argmax
|
|
13
|
+
from . import group_norm
|
|
14
|
+
from . import upsample
|
|
15
|
+
from . import embedding
|
|
16
|
+
from . import numpy_t
|
|
@@ -0,0 +1,47 @@
|
|
|
1
|
+
# Copyright (c) Microsoft Corporation. All rights reserved.
|
|
2
|
+
# Licensed under the MIT License.
|
|
3
|
+
|
|
4
|
+
"""Shape handler for ATen argmax operator."""
|
|
5
|
+
|
|
6
|
+
import onnx
|
|
7
|
+
from onnx import helper
|
|
8
|
+
|
|
9
|
+
from ..base import ShapeHandler
|
|
10
|
+
from ..registry import register_aten_handler
|
|
11
|
+
from ..utils import get_shape_from_sympy_shape, handle_negative_axis
|
|
12
|
+
|
|
13
|
+
|
|
14
|
+
class AtenArgmaxHandler(ShapeHandler):
    """Handler for ATen argmax operator.

    Reads input 1 as ``dim`` and input 2 as ``keepdim``. The output element
    type is always INT64.
    """

    @property
    def op_type(self) -> str:
        return "argmax"

    def infer_shape(self, node, ctx) -> None:
        new_shape = None
        # A missing dim input (empty name, or omitted entirely as a trailing
        # optional input) means the argmax of the flattened input: a scalar.
        if len(node.input) < 2 or not node.input[1]:
            new_shape = []
        else:
            dim = ctx.try_get_value(node, 1)
            keepdim = ctx.try_get_value(node, 2)
            # Without a constant keepdim the output rank is unknown; this
            # handler then leaves the output value info untouched.
            if keepdim is not None:
                sympy_shape = ctx.get_sympy_shape(node, 0)
                if dim is not None:
                    dim = handle_negative_axis(dim, len(sympy_shape))
                    if keepdim:
                        sympy_shape[dim] = 1
                    else:
                        del sympy_shape[dim]
                else:
                    # Reduced axis unknown: only the rank is known, so use fresh symbols.
                    rank = len(sympy_shape)
                    sympy_shape = ctx.new_symbolic_shape(rank if keepdim else rank - 1, node)
                ctx.update_computed_dims(sympy_shape)
                new_shape = get_shape_from_sympy_shape(sympy_shape)
        if node.output[0] and new_shape is not None:
            vi = ctx.known_vi_[node.output[0]]
            vi.CopyFrom(helper.make_tensor_value_info(node.output[0], onnx.TensorProto.INT64, new_shape))


register_aten_handler(AtenArgmaxHandler())
|
|
@@ -0,0 +1,28 @@
|
|
|
1
|
+
# Copyright (c) Microsoft Corporation. All rights reserved.
|
|
2
|
+
# Licensed under the MIT License.
|
|
3
|
+
|
|
4
|
+
"""Shape handler for ATen bitwise_or operator."""
|
|
5
|
+
|
|
6
|
+
from onnx import helper
|
|
7
|
+
|
|
8
|
+
from ..base import ShapeHandler
|
|
9
|
+
from ..registry import register_aten_handler
|
|
10
|
+
|
|
11
|
+
|
|
12
|
+
class AtenBitwiseOrHandler(ShapeHandler):
    """Shape handler for the ATen ``bitwise_or`` operator.

    The output shape is ``ctx.broadcast_shapes`` of the two input shapes, and
    the element type is copied from input 0.
    """

    @property
    def op_type(self) -> str:
        return "bitwise_or"

    def infer_shape(self, node, ctx) -> None:
        lhs_shape, rhs_shape = (ctx.get_shape(node, idx) for idx in (0, 1))
        broadcast = ctx.broadcast_shapes(lhs_shape, rhs_shape)
        elem_type = ctx.known_vi_[node.input[0]].type.tensor_type.elem_type
        out_name = node.output[0]
        ctx.known_vi_[out_name].CopyFrom(helper.make_tensor_value_info(out_name, elem_type, broadcast))


register_aten_handler(AtenBitwiseOrHandler())
|
|
@@ -0,0 +1,52 @@
|
|
|
1
|
+
# Copyright (c) Microsoft Corporation. All rights reserved.
|
|
2
|
+
# Licensed under the MIT License.
|
|
3
|
+
|
|
4
|
+
"""Shape handler for ATen diagonal operator."""
|
|
5
|
+
|
|
6
|
+
import sympy
|
|
7
|
+
from onnx import helper
|
|
8
|
+
|
|
9
|
+
from ..base import ShapeHandler
|
|
10
|
+
from ..registry import register_aten_handler
|
|
11
|
+
from ..utils import get_shape_from_sympy_shape, handle_negative_axis
|
|
12
|
+
|
|
13
|
+
|
|
14
|
+
class AtenDiagonalHandler(ShapeHandler):
    """Shape handler for the ATen ``diagonal`` operator.

    Inputs 1-3 (offset, dim1, dim2) must be constant. The two selected axes
    are dropped, and one trailing axis holding the diagonal length is appended.
    """

    @property
    def op_type(self) -> str:
        return "diagonal"

    def infer_shape(self, node, ctx) -> None:
        in_shape = ctx.get_sympy_shape(node, 0)
        ndim = len(in_shape)
        offset, axis_a, axis_b = (ctx.try_get_value(node, k) for k in (1, 2, 3))

        assert offset is not None and axis_a is not None and axis_b is not None
        axis_a = handle_negative_axis(axis_a, ndim)
        axis_b = handle_negative_axis(axis_b, ndim)

        dropped = {axis_a, axis_b}
        out_shape = [size for axis, size in enumerate(in_shape) if axis not in dropped]
        size_a, size_b = in_shape[axis_a], in_shape[axis_b]
        # A non-negative offset shortens the diagonal along axis_b, a negative one along axis_a.
        if offset >= 0:
            diag_len = sympy.Max(0, sympy.Min(size_a, size_b - offset))
        else:
            diag_len = sympy.Max(0, sympy.Min(size_a + offset, size_b))
        out_shape.append(diag_len)

        out_name = node.output[0]
        if not out_name:
            return
        target = ctx.known_vi_[out_name]
        elem_type = ctx.known_vi_[node.input[0]].type.tensor_type.elem_type
        target.CopyFrom(helper.make_tensor_value_info(out_name, elem_type, get_shape_from_sympy_shape(out_shape)))


register_aten_handler(AtenDiagonalHandler())
|
|
@@ -0,0 +1,23 @@
|
|
|
1
|
+
# Copyright (c) Microsoft Corporation. All rights reserved.
|
|
2
|
+
# Licensed under the MIT License.
|
|
3
|
+
|
|
4
|
+
"""Shape handler for ATen embedding operator."""
|
|
5
|
+
|
|
6
|
+
from ..base import ShapeHandler
|
|
7
|
+
from ..registry import register_aten_handler
|
|
8
|
+
from ..standard_ops.tensor.gather import GatherHandler
|
|
9
|
+
|
|
10
|
+
|
|
11
|
+
class AtenEmbeddingHandler(ShapeHandler):
    """Shape handler for the ATen ``embedding`` operator.

    Inference is delegated unchanged to ``GatherHandler``.
    """

    @property
    def op_type(self) -> str:
        return "embedding"

    def infer_shape(self, node, ctx) -> None:
        gather_handler = GatherHandler()
        gather_handler.infer_shape(node, ctx)


register_aten_handler(AtenEmbeddingHandler())
|
|
@@ -0,0 +1,41 @@
|
|
|
1
|
+
# Copyright (c) Microsoft Corporation. All rights reserved.
|
|
2
|
+
# Licensed under the MIT License.
|
|
3
|
+
|
|
4
|
+
"""Shape handler for ATen native_group_norm operator."""
|
|
5
|
+
|
|
6
|
+
from onnx import helper
|
|
7
|
+
|
|
8
|
+
from ..base import ShapeHandler
|
|
9
|
+
from ..registry import register_aten_handler
|
|
10
|
+
from ..utils import as_scalar
|
|
11
|
+
|
|
12
|
+
|
|
13
|
+
class AtenGroupNormHandler(ShapeHandler):
    """Handler for ATen native_group_norm operator.

    Output 0 receives the shape and type of input 0 via
    ``ctx.propagate_shape_and_type``. Outputs 1 and 2 each get shape
    ``[N, group]`` with input 0's element type. Those two outputs are
    presumably the per-group statistics (mean / rstd); TODO confirm.

    Values used for the shape:
    - ``N`` is the leading dimension of input 0 when its shape is known.
    - ``group`` is the constant value of input 6.
    - Either value, if unknown, is replaced by a fresh symbolic dimension.
    """

    @property
    def op_type(self) -> str:
        return "native_group_norm"

    def infer_shape(self, node, ctx) -> None:
        ctx.propagate_shape_and_type(node)
        input_shape = ctx.get_shape(node, 0)
        # Use the leading input dimension as N when the input shape is known and non-scalar.
        N = input_shape[0] if input_shape is not None and len(input_shape) != 0 else None
        group = ctx.try_get_value(node, 6)
        output_dtype = ctx.known_vi_[node.input[0]].type.tensor_type.elem_type
        # Iterate an ordered tuple rather than a set literal. Skip outputs the
        # node does not declare at all (trailing optional outputs may be
        # omitted), which previously raised IndexError, and outputs whose
        # name is empty.
        for i in (1, 2):
            if i >= len(node.output) or not node.output[i]:
                continue
            vi = ctx.known_vi_[node.output[i]]
            vi.CopyFrom(
                helper.make_tensor_value_info(
                    node.output[i],
                    output_dtype,
                    [
                        (N if N is not None else str(ctx.new_symbolic_dim_from_output(node, i, 0))),
                        (as_scalar(group) if group is not None else str(ctx.new_symbolic_dim_from_output(node, i, 1))),
                    ],
                )
            )


register_aten_handler(AtenGroupNormHandler())
|
|
@@ -0,0 +1,64 @@
|
|
|
1
|
+
# Copyright (c) Microsoft Corporation. All rights reserved.
|
|
2
|
+
# Licensed under the MIT License.
|
|
3
|
+
|
|
4
|
+
"""Shape handler for ATen min/max operators."""
|
|
5
|
+
|
|
6
|
+
import onnx
|
|
7
|
+
from onnx import helper
|
|
8
|
+
|
|
9
|
+
from ..base import ShapeHandler
|
|
10
|
+
from ..registry import register_aten_handler
|
|
11
|
+
from ..utils import get_shape_from_sympy_shape, handle_negative_axis
|
|
12
|
+
|
|
13
|
+
|
|
14
|
+
class AtenMinMaxHandler(ShapeHandler):
    """Handler for ATen min/max operators.

    One instance is created per operator name; the module registers one for
    ``"max"`` and one for ``"min"``. The name passed to the constructor is
    returned as ``op_type``.
    """

    def __init__(self, op_name):
        super().__init__()
        self._op_type = op_name

    @property
    def op_type(self) -> str:
        return self._op_type

    def infer_shape(self, node, ctx) -> None:
        """Infer output shapes for a min/max node.

        - One input: output 0 is a scalar (shape ``[]``) with input 0's
          element type.
        - Three inputs, read as (input, dim, keepdim):
          - ``keepdim`` must be statically known.
          - When ``dim`` is known, that axis is dropped, or kept as size 1
            when ``keepdim`` is true.
          - When ``dim`` is unknown, a fresh symbolic shape of the resulting
            rank is used.
          - Output 1, presumably the arg-indices (TODO confirm), gets the
            same shape with INT64 element type.

        Raises:
            AssertionError: if the input count is neither 1 nor 3, or if
                ``keepdim`` is not statically known.
        """
        vi = ctx.known_vi_[node.output[0]]
        if len(node.input) == 1:
            vi.CopyFrom(
                helper.make_tensor_value_info(
                    node.output[0],
                    ctx.known_vi_[node.input[0]].type.tensor_type.elem_type,
                    [],
                )
            )
        else:
            assert len(node.input) == 3
            keepdim = ctx.try_get_value(node, 2)
            assert keepdim is not None  # only a statically known keepdim is supported
            dim = ctx.try_get_value(node, 1)
            if dim is None:
                # Axis unknown: only the output rank can be determined.
                rank = ctx.get_shape_rank(node, 0)
                output_shape = ctx.new_symbolic_shape(rank if keepdim else rank - 1, node)
            else:
                shape = ctx.get_sympy_shape(node, 0)
                dim = handle_negative_axis(dim, len(shape))
                # Slicing yields a new list, so `shape` itself is not mutated below.
                output_shape = shape[:dim]
                if keepdim:
                    output_shape += [1]
                output_shape += shape[dim + 1 :]

            output_shape = get_shape_from_sympy_shape(output_shape)
            vi.CopyFrom(
                helper.make_tensor_value_info(
                    node.output[0],
                    ctx.known_vi_[node.input[0]].type.tensor_type.elem_type,
                    output_shape,
                )
            )
            # Output 1 is optional. Previously a node with a single output
            # raised IndexError here, and an empty name was looked up in known_vi_.
            if len(node.output) > 1 and node.output[1]:
                vi1 = ctx.known_vi_[node.output[1]]
                vi1.CopyFrom(helper.make_tensor_value_info(node.output[1], onnx.TensorProto.INT64, output_shape))


register_aten_handler(AtenMinMaxHandler("max"))
register_aten_handler(AtenMinMaxHandler("min"))
|