onnxslim 0.1.81__py3-none-any.whl → 0.1.83__py3-none-any.whl
This diff compares publicly available package versions as released to one of the supported registries. It is provided for informational purposes only and reflects the changes between the versions as they appear in their public registries.
- onnxslim/core/optimization/dead_node_elimination.py +84 -3
- onnxslim/core/pattern/fusion/convadd.py +21 -1
- onnxslim/core/pattern/fusion/convbn.py +21 -4
- onnxslim/core/pattern/fusion/convmul.py +23 -5
- onnxslim/core/pattern/fusion/padconv.py +5 -0
- onnxslim/core/shape_inference/__init__.py +378 -0
- onnxslim/core/shape_inference/aten_ops/__init__.py +16 -0
- onnxslim/core/shape_inference/aten_ops/argmax.py +47 -0
- onnxslim/core/shape_inference/aten_ops/bitwise_or.py +28 -0
- onnxslim/core/shape_inference/aten_ops/diagonal.py +52 -0
- onnxslim/core/shape_inference/aten_ops/embedding.py +23 -0
- onnxslim/core/shape_inference/aten_ops/group_norm.py +41 -0
- onnxslim/core/shape_inference/aten_ops/min_max.py +64 -0
- onnxslim/core/shape_inference/aten_ops/multinomial.py +39 -0
- onnxslim/core/shape_inference/aten_ops/numpy_t.py +22 -0
- onnxslim/core/shape_inference/aten_ops/pool2d.py +40 -0
- onnxslim/core/shape_inference/aten_ops/unfold.py +44 -0
- onnxslim/core/shape_inference/aten_ops/upsample.py +44 -0
- onnxslim/core/shape_inference/base.py +111 -0
- onnxslim/core/shape_inference/context.py +645 -0
- onnxslim/core/shape_inference/contrib_ops/__init__.py +8 -0
- onnxslim/core/shape_inference/contrib_ops/attention/__init__.py +15 -0
- onnxslim/core/shape_inference/contrib_ops/attention/attention.py +61 -0
- onnxslim/core/shape_inference/contrib_ops/attention/decoder_masked_mha.py +37 -0
- onnxslim/core/shape_inference/contrib_ops/attention/gated_relative_position_bias.py +35 -0
- onnxslim/core/shape_inference/contrib_ops/attention/longformer_attention.py +21 -0
- onnxslim/core/shape_inference/contrib_ops/attention/multi_head_attention.py +82 -0
- onnxslim/core/shape_inference/contrib_ops/attention/multi_scale_deformable_attn.py +29 -0
- onnxslim/core/shape_inference/contrib_ops/attention/packed_attention.py +39 -0
- onnxslim/core/shape_inference/contrib_ops/attention/packed_multi_head_attention.py +33 -0
- onnxslim/core/shape_inference/contrib_ops/attention/remove_padding.py +41 -0
- onnxslim/core/shape_inference/contrib_ops/attention/restore_padding.py +29 -0
- onnxslim/core/shape_inference/contrib_ops/misc/__init__.py +15 -0
- onnxslim/core/shape_inference/contrib_ops/misc/bias_add.py +21 -0
- onnxslim/core/shape_inference/contrib_ops/misc/bias_gelu.py +21 -0
- onnxslim/core/shape_inference/contrib_ops/misc/bias_split_gelu.py +30 -0
- onnxslim/core/shape_inference/contrib_ops/misc/fast_gelu.py +21 -0
- onnxslim/core/shape_inference/contrib_ops/misc/gelu.py +21 -0
- onnxslim/core/shape_inference/contrib_ops/misc/gemm_fast_gelu.py +21 -0
- onnxslim/core/shape_inference/contrib_ops/misc/gemm_float8.py +21 -0
- onnxslim/core/shape_inference/contrib_ops/misc/python_op.py +67 -0
- onnxslim/core/shape_inference/contrib_ops/misc/quick_gelu.py +21 -0
- onnxslim/core/shape_inference/contrib_ops/misc/rotary_embedding.py +31 -0
- onnxslim/core/shape_inference/contrib_ops/normalization/__init__.py +12 -0
- onnxslim/core/shape_inference/contrib_ops/normalization/embed_layer_normalization.py +41 -0
- onnxslim/core/shape_inference/contrib_ops/normalization/group_norm.py +21 -0
- onnxslim/core/shape_inference/contrib_ops/normalization/layer_normalization.py +42 -0
- onnxslim/core/shape_inference/contrib_ops/normalization/simplified_layer_normalization.py +23 -0
- onnxslim/core/shape_inference/contrib_ops/normalization/skip_group_norm.py +23 -0
- onnxslim/core/shape_inference/contrib_ops/normalization/skip_layer_normalization.py +26 -0
- onnxslim/core/shape_inference/contrib_ops/normalization/skip_simplified_layer_normalization.py +23 -0
- onnxslim/core/shape_inference/registry.py +90 -0
- onnxslim/core/shape_inference/standard_ops/__init__.py +11 -0
- onnxslim/core/shape_inference/standard_ops/control_flow/__init__.py +8 -0
- onnxslim/core/shape_inference/standard_ops/control_flow/if_op.py +43 -0
- onnxslim/core/shape_inference/standard_ops/control_flow/loop.py +74 -0
- onnxslim/core/shape_inference/standard_ops/control_flow/scan.py +54 -0
- onnxslim/core/shape_inference/standard_ops/math/__init__.py +20 -0
- onnxslim/core/shape_inference/standard_ops/math/_symbolic_compute.py +34 -0
- onnxslim/core/shape_inference/standard_ops/math/add.py +10 -0
- onnxslim/core/shape_inference/standard_ops/math/div.py +10 -0
- onnxslim/core/shape_inference/standard_ops/math/einsum.py +119 -0
- onnxslim/core/shape_inference/standard_ops/math/equal.py +10 -0
- onnxslim/core/shape_inference/standard_ops/math/floor.py +10 -0
- onnxslim/core/shape_inference/standard_ops/math/matmul.py +21 -0
- onnxslim/core/shape_inference/standard_ops/math/matmul_integer.py +23 -0
- onnxslim/core/shape_inference/standard_ops/math/max.py +10 -0
- onnxslim/core/shape_inference/standard_ops/math/min.py +10 -0
- onnxslim/core/shape_inference/standard_ops/math/mul.py +10 -0
- onnxslim/core/shape_inference/standard_ops/math/neg.py +10 -0
- onnxslim/core/shape_inference/standard_ops/math/reduce_prod.py +27 -0
- onnxslim/core/shape_inference/standard_ops/math/reduce_sum.py +53 -0
- onnxslim/core/shape_inference/standard_ops/math/sub.py +10 -0
- onnxslim/core/shape_inference/standard_ops/math/where.py +10 -0
- onnxslim/core/shape_inference/standard_ops/misc/__init__.py +22 -0
- onnxslim/core/shape_inference/standard_ops/misc/array_feature_extractor.py +32 -0
- onnxslim/core/shape_inference/standard_ops/misc/cast.py +21 -0
- onnxslim/core/shape_inference/standard_ops/misc/category_mapper.py +30 -0
- onnxslim/core/shape_inference/standard_ops/misc/compress.py +39 -0
- onnxslim/core/shape_inference/standard_ops/misc/constant.py +27 -0
- onnxslim/core/shape_inference/standard_ops/misc/constant_of_shape.py +45 -0
- onnxslim/core/shape_inference/standard_ops/misc/dequantize_linear.py +26 -0
- onnxslim/core/shape_inference/standard_ops/misc/non_max_suppression.py +26 -0
- onnxslim/core/shape_inference/standard_ops/misc/non_zero.py +26 -0
- onnxslim/core/shape_inference/standard_ops/misc/one_hot.py +42 -0
- onnxslim/core/shape_inference/standard_ops/misc/quantize_linear.py +29 -0
- onnxslim/core/shape_inference/standard_ops/misc/range.py +41 -0
- onnxslim/core/shape_inference/standard_ops/misc/relative_position_bias.py +31 -0
- onnxslim/core/shape_inference/standard_ops/misc/resize.py +74 -0
- onnxslim/core/shape_inference/standard_ops/misc/scatter_elements.py +31 -0
- onnxslim/core/shape_inference/standard_ops/misc/softmax_cross_entropy_loss.py +44 -0
- onnxslim/core/shape_inference/standard_ops/misc/top_k.py +44 -0
- onnxslim/core/shape_inference/standard_ops/nn/__init__.py +18 -0
- onnxslim/core/shape_inference/standard_ops/nn/all_reduce.py +9 -0
- onnxslim/core/shape_inference/standard_ops/nn/average_pool.py +40 -0
- onnxslim/core/shape_inference/standard_ops/nn/batch_normalization.py +26 -0
- onnxslim/core/shape_inference/standard_ops/nn/conv.py +33 -0
- onnxslim/core/shape_inference/standard_ops/nn/cum_sum.py +9 -0
- onnxslim/core/shape_inference/standard_ops/nn/identity.py +9 -0
- onnxslim/core/shape_inference/standard_ops/nn/max_pool.py +9 -0
- onnxslim/core/shape_inference/standard_ops/nn/memcpy_from_host.py +9 -0
- onnxslim/core/shape_inference/standard_ops/nn/memcpy_to_host.py +9 -0
- onnxslim/core/shape_inference/standard_ops/nn/moe.py +9 -0
- onnxslim/core/shape_inference/standard_ops/nn/nhwc_conv.py +33 -0
- onnxslim/core/shape_inference/standard_ops/nn/reciprocal.py +9 -0
- onnxslim/core/shape_inference/standard_ops/nn/round.py +9 -0
- onnxslim/core/shape_inference/standard_ops/sequence/__init__.py +10 -0
- onnxslim/core/shape_inference/standard_ops/sequence/concat_from_sequence.py +40 -0
- onnxslim/core/shape_inference/standard_ops/sequence/sequence_at.py +31 -0
- onnxslim/core/shape_inference/standard_ops/sequence/sequence_insert.py +26 -0
- onnxslim/core/shape_inference/standard_ops/sequence/split_to_sequence.py +24 -0
- onnxslim/core/shape_inference/standard_ops/sequence/zip_map.py +36 -0
- onnxslim/core/shape_inference/standard_ops/tensor/__init__.py +20 -0
- onnxslim/core/shape_inference/standard_ops/tensor/concat.py +62 -0
- onnxslim/core/shape_inference/standard_ops/tensor/expand.py +36 -0
- onnxslim/core/shape_inference/standard_ops/tensor/gather.py +48 -0
- onnxslim/core/shape_inference/standard_ops/tensor/gather_elements.py +31 -0
- onnxslim/core/shape_inference/standard_ops/tensor/gather_nd.py +42 -0
- onnxslim/core/shape_inference/standard_ops/tensor/pad.py +41 -0
- onnxslim/core/shape_inference/standard_ops/tensor/reshape.py +72 -0
- onnxslim/core/shape_inference/standard_ops/tensor/shape.py +38 -0
- onnxslim/core/shape_inference/standard_ops/tensor/size.py +29 -0
- onnxslim/core/shape_inference/standard_ops/tensor/slice.py +183 -0
- onnxslim/core/shape_inference/standard_ops/tensor/split.py +57 -0
- onnxslim/core/shape_inference/standard_ops/tensor/squeeze.py +69 -0
- onnxslim/core/shape_inference/standard_ops/tensor/tile.py +41 -0
- onnxslim/core/shape_inference/standard_ops/tensor/transpose.py +30 -0
- onnxslim/core/shape_inference/standard_ops/tensor/unsqueeze.py +54 -0
- onnxslim/core/shape_inference/utils.py +244 -0
- onnxslim/third_party/symbolic_shape_infer.py +73 -3156
- onnxslim/utils.py +4 -2
- {onnxslim-0.1.81.dist-info → onnxslim-0.1.83.dist-info}/METADATA +21 -11
- onnxslim-0.1.83.dist-info/RECORD +187 -0
- onnxslim-0.1.81.dist-info/RECORD +0 -63
- {onnxslim-0.1.81.dist-info → onnxslim-0.1.83.dist-info}/WHEEL +0 -0
- {onnxslim-0.1.81.dist-info → onnxslim-0.1.83.dist-info}/entry_points.txt +0 -0
- {onnxslim-0.1.81.dist-info → onnxslim-0.1.83.dist-info}/licenses/LICENSE +0 -0
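The headline change in this release is the removal of most of the vendored onnxslim/third_party/symbolic_shape_infer.py (3,156 lines deleted) in favor of a per-operator handler package under onnxslim/core/shape_inference. Judging from the handler modules reproduced below, each operator gets its own ShapeHandler subclass exposing an op_type property and an infer_shape(node, ctx) method, and registers an instance at import time via register_shape_handler(). The sketch below only illustrates how such a base class and registry could fit together; the real base.py and registry.py are part of this diff but their bodies are not shown here, so the _HANDLERS table and get_shape_handler() lookup are illustrative assumptions, not the library's actual API.

# Illustrative sketch only. The ShapeHandler interface and the
# register_shape_handler() call are taken from the handler modules below;
# _HANDLERS and get_shape_handler() are hypothetical names.
from abc import ABC, abstractmethod


class ShapeHandler(ABC):
    """Interface implemented by every per-operator shape handler."""

    @property
    @abstractmethod
    def op_type(self) -> str:
        """ONNX op type this handler covers, e.g. "Attention"."""

    @abstractmethod
    def infer_shape(self, node, ctx) -> None:
        """Write inferred value_info for the node's outputs into ctx.known_vi_."""


_HANDLERS: dict = {}  # hypothetical module-level table keyed by op type


def register_shape_handler(handler: ShapeHandler) -> None:
    """Called at module import time by each handler file (see the hunks below)."""
    _HANDLERS[handler.op_type] = handler


def get_shape_handler(op_type: str):
    """Hypothetical lookup a shape-inference driver would use per node."""
    return _HANDLERS.get(op_type)
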
onnxslim/core/shape_inference/contrib_ops/attention/attention.py
@@ -0,0 +1,61 @@
+# Copyright (c) Microsoft Corporation. All rights reserved.
+# Licensed under the MIT License.
+
+"""Shape handler for Attention operator."""
+
+from onnx import helper
+
+from ...base import ShapeHandler
+from ...registry import register_shape_handler
+from ...utils import get_attribute
+
+
+class AttentionHandler(ShapeHandler):
+    """Handler for Attention operator."""
+
+    @property
+    def op_type(self) -> str:
+        return "Attention"
+
+    def infer_shape(self, node, ctx) -> None:
+        shape = ctx.get_shape(node, 0)
+        shape_weights = ctx.get_shape(node, 1)
+        shape_bias = ctx.try_get_shape(node, 2)
+        if shape_bias is not None:
+            assert len(shape_bias) == 1
+        tripled_hidden_size = shape_bias[0] if shape_bias is not None else shape_weights[1]
+        if shape and len(shape) == 3:
+            qkv_hidden_sizes_attr = get_attribute(node, "qkv_hidden_sizes")
+            if qkv_hidden_sizes_attr is not None:
+                assert len(qkv_hidden_sizes_attr) == 3
+                shape[2] = int(qkv_hidden_sizes_attr[2])
+            elif isinstance(tripled_hidden_size, int):
+                shape[2] = int(tripled_hidden_size / 3)
+            output_dtype = ctx.known_vi_[node.input[0]].type.tensor_type.elem_type
+            vi = ctx.known_vi_[node.output[0]]
+            vi.CopyFrom(helper.make_tensor_value_info(node.output[0], output_dtype, shape))
+
+            if len(node.output) > 1:
+                input_shape = ctx.get_shape(node, 0)
+                past_shape = ctx.get_shape(node, 4) if len(node.input) > 4 and node.input[4] else []
+                mask_shape = ctx.get_shape(node, 3) if len(node.input) > 3 and node.input[3] else []
+
+                if past_shape and len(past_shape) == 5:
+                    if mask_shape and len(mask_shape) in {2, 3}:
+                        past_shape[3] = mask_shape[-1]
+                    elif input_shape and len(input_shape) == 3:
+                        if isinstance(input_shape[1], int) and isinstance(past_shape[3], int):
+                            past_shape[3] = input_shape[1] + past_shape[3]
+                        else:
+                            past_shape[3] = f"{past_shape[3]}+{input_shape[1]}"
+                    vi = ctx.known_vi_[node.output[1]]
+                    vi.CopyFrom(helper.make_tensor_value_info(vi.name, output_dtype, past_shape))
+                else:
+                    num_heads = get_attribute(node, "num_heads")
+                    head_size = input_shape[2] // num_heads
+                    present_shape = [2, input_shape[0], num_heads, input_shape[1], head_size]
+                    vi = ctx.known_vi_[node.output[1]]
+                    vi.CopyFrom(helper.make_tensor_value_info(vi.name, output_dtype, present_shape))
+
+
+register_shape_handler(AttentionHandler())
onnxslim/core/shape_inference/contrib_ops/attention/decoder_masked_mha.py
@@ -0,0 +1,37 @@
+# Copyright (c) Microsoft Corporation. All rights reserved.
+# Licensed under the MIT License.
+
+"""Shape handler for DecoderMaskedMultiHeadAttention operator."""
+
+from onnx import helper
+
+from ...base import ShapeHandler
+from ...registry import register_shape_handler
+
+
+class DecoderMaskedMultiHeadAttentionHandler(ShapeHandler):
+    """Handler for DecoderMaskedMultiHeadAttention operator."""
+
+    @property
+    def op_type(self) -> str:
+        return "DecoderMaskedMultiHeadAttention"
+
+    def infer_shape(self, node, ctx) -> None:
+        query_shape = ctx.get_shape(node, 0)
+        if query_shape is not None:
+            output_shape = query_shape
+            output_dtype = ctx.known_vi_[node.input[0]].type.tensor_type.elem_type
+            assert output_dtype is not None
+            vi = ctx.known_vi_[node.output[0]]
+            vi.CopyFrom(helper.make_tensor_value_info(node.output[0], output_dtype, output_shape))
+
+            if len(node.output) > 2 and node.output[1] and node.output[2]:
+                past_shape = ctx.try_get_shape(node, 5)
+                if past_shape is not None:
+                    vi = ctx.known_vi_[node.output[1]]
+                    vi.CopyFrom(helper.make_tensor_value_info(vi.name, output_dtype, past_shape))
+                    vi = ctx.known_vi_[node.output[2]]
+                    vi.CopyFrom(helper.make_tensor_value_info(vi.name, output_dtype, past_shape))
+
+
+register_shape_handler(DecoderMaskedMultiHeadAttentionHandler())
onnxslim/core/shape_inference/contrib_ops/attention/gated_relative_position_bias.py
@@ -0,0 +1,35 @@
+# Copyright (c) Microsoft Corporation. All rights reserved.
+# Licensed under the MIT License.
+
+"""Shape handler for GatedRelativePositionBias operator."""
+
+from onnx import helper
+
+from ...base import ShapeHandler
+from ...registry import register_shape_handler
+from ...utils import get_attribute
+
+
+class GatedRelativePositionBiasHandler(ShapeHandler):
+    """Handler for GatedRelativePositionBias operator."""
+
+    @property
+    def op_type(self) -> str:
+        return "GatedRelativePositionBias"
+
+    def infer_shape(self, node, ctx) -> None:
+        num_heads = get_attribute(node, "num_heads")
+        token_offset_shape = ctx.try_get_shape(node, 6)
+        if token_offset_shape is not None:
+            output_shape = [token_offset_shape[0], num_heads, token_offset_shape[1], token_offset_shape[1]]
+        else:
+            query_layer_shape = ctx.get_shape(node, 0)
+            assert query_layer_shape is not None and len(query_layer_shape) == 3
+            output_shape = [query_layer_shape[0], num_heads, query_layer_shape[1], query_layer_shape[1]]
+
+        output_dtype = ctx.known_vi_[node.input[0]].type.tensor_type.elem_type
+        vi = ctx.known_vi_[node.output[0]]
+        vi.CopyFrom(helper.make_tensor_value_info(node.output[0], output_dtype, output_shape))
+
+
+register_shape_handler(GatedRelativePositionBiasHandler())
onnxslim/core/shape_inference/contrib_ops/attention/longformer_attention.py
@@ -0,0 +1,21 @@
+# Copyright (c) Microsoft Corporation. All rights reserved.
+# Licensed under the MIT License.
+
+"""Shape handler for LongformerAttention operator."""
+
+from ...base import ShapeHandler
+from ...registry import register_shape_handler
+
+
+class LongformerAttentionHandler(ShapeHandler):
+    """Handler for LongformerAttention operator."""
+
+    @property
+    def op_type(self) -> str:
+        return "LongformerAttention"
+
+    def infer_shape(self, node, ctx) -> None:
+        ctx.propagate_shape_and_type(node)
+
+
+register_shape_handler(LongformerAttentionHandler())
onnxslim/core/shape_inference/contrib_ops/attention/multi_head_attention.py
@@ -0,0 +1,82 @@
+# Copyright (c) Microsoft Corporation. All rights reserved.
+# Licensed under the MIT License.
+
+"""Shape handler for MultiHeadAttention operator."""
+
+from onnx import helper
+
+from ...base import ShapeHandler
+from ...registry import register_shape_handler
+from ...utils import get_attribute
+
+
+class MultiHeadAttentionHandler(ShapeHandler):
+    """Handler for MultiHeadAttention operator."""
+
+    @property
+    def op_type(self) -> str:
+        return "MultiHeadAttention"
+
+    def infer_shape(self, node, ctx) -> None:
+        query_shape = ctx.get_shape(node, 0)
+        total_sequence_length = None
+        output_dtype = None
+        if query_shape is not None:
+            if len(query_shape) == 3:
+                key_shape = ctx.try_get_shape(node, 1)
+                output_shape = query_shape
+                if key_shape is not None and len(key_shape) == 3:
+                    value_shape = ctx.try_get_shape(node, 2)
+                    if value_shape is not None and len(value_shape) == 3:
+                        output_shape[2] = value_shape[2]
+                    total_sequence_length = key_shape[1]
+
+                output_dtype = ctx.known_vi_[node.input[0]].type.tensor_type.elem_type
+                vi = ctx.known_vi_[node.output[0]]
+                vi.CopyFrom(helper.make_tensor_value_info(node.output[0], output_dtype, output_shape))
+
+            elif len(query_shape) == 5:
+                if isinstance(query_shape[2], int) and isinstance(query_shape[4], int):
+                    output_shape = [query_shape[0], query_shape[1], query_shape[2] * query_shape[4]]
+                else:
+                    output_shape = [query_shape[0], query_shape[1], f"{query_shape[2]}*{query_shape[4]}"]
+
+                total_sequence_length = query_shape[1]
+
+                output_dtype = ctx.known_vi_[node.input[0]].type.tensor_type.elem_type
+                vi = ctx.known_vi_[node.output[0]]
+                vi.CopyFrom(helper.make_tensor_value_info(node.output[0], output_dtype, output_shape))
+
+            if len(node.output) > 1:
+                batch_size = query_shape[0]
+                num_heads = get_attribute(node, "num_heads")
+
+                head_size = None
+                if len(query_shape) == 3:
+                    head_size = (
+                        int(query_shape[2] / num_heads)
+                        if isinstance(query_shape[2], int)
+                        else f"{query_shape[2]}/{num_heads}"
+                    )
+                else:
+                    head_size = query_shape[4]
+
+                past_shape = ctx.try_get_shape(node, 6)
+
+                if past_shape is not None:
+                    if isinstance(past_shape[2], int) and isinstance(total_sequence_length, int):
+                        total_sequence_length = past_shape[2] + total_sequence_length
+                    else:
+                        total_sequence_length = f"{past_shape[2]}+{total_sequence_length}"
+
+                present_shape = [batch_size, num_heads, total_sequence_length, head_size]
+
+                assert output_dtype is not None
+                if len(node.output) > 2 and node.output[1] and node.output[2]:
+                    vi = ctx.known_vi_[node.output[1]]
+                    vi.CopyFrom(helper.make_tensor_value_info(vi.name, output_dtype, present_shape))
+                    vi = ctx.known_vi_[node.output[2]]
+                    vi.CopyFrom(helper.make_tensor_value_info(vi.name, output_dtype, present_shape))
+
+
+register_shape_handler(MultiHeadAttentionHandler())
onnxslim/core/shape_inference/contrib_ops/attention/multi_scale_deformable_attn.py
@@ -0,0 +1,29 @@
+# Copyright (c) Microsoft Corporation. All rights reserved.
+# Licensed under the MIT License.
+
+"""Shape handler for MultiScaleDeformableAttnTRT operator."""
+
+from onnx import helper
+
+from ...base import ShapeHandler
+from ...registry import register_shape_handler
+
+
+class MultiScaleDeformableAttnTRTHandler(ShapeHandler):
+    """Handler for MultiScaleDeformableAttnTRT operator."""
+
+    @property
+    def op_type(self) -> str:
+        return "MultiScaleDeformableAttnTRT"
+
+    def infer_shape(self, node, ctx) -> None:
+        shape_value = ctx.try_get_shape(node, 0)
+        sampling_locations = ctx.try_get_shape(node, 3)
+        output_shape = shape_value
+        output_shape[1] = sampling_locations[1]
+        output_dtype = ctx.known_vi_[node.input[0]].type.tensor_type.elem_type
+        vi = ctx.known_vi_[node.output[0]]
+        vi.CopyFrom(helper.make_tensor_value_info(node.output[0], output_dtype, output_shape))
+
+
+register_shape_handler(MultiScaleDeformableAttnTRTHandler())
onnxslim/core/shape_inference/contrib_ops/attention/packed_attention.py
@@ -0,0 +1,39 @@
+# Copyright (c) Microsoft Corporation. All rights reserved.
+# Licensed under the MIT License.
+
+"""Shape handler for PackedAttention operator."""
+
+from onnx import helper
+
+from ...base import ShapeHandler
+from ...registry import register_shape_handler
+from ...utils import get_attribute
+
+
+class PackedAttentionHandler(ShapeHandler):
+    """Handler for PackedAttention operator."""
+
+    @property
+    def op_type(self) -> str:
+        return "PackedAttention"
+
+    def infer_shape(self, node, ctx) -> None:
+        shape = ctx.get_shape(node, 0)
+        shape_weights = ctx.get_shape(node, 1)
+        shape_bias = ctx.try_get_shape(node, 2)
+        if shape_bias is not None:
+            assert len(shape_bias) == 1
+        tripled_hidden_size = shape_bias[0] if shape_bias is not None else shape_weights[1]
+        if shape and len(shape) == 2:
+            qkv_hidden_sizes_attr = get_attribute(node, "qkv_hidden_sizes")
+            if qkv_hidden_sizes_attr is not None:
+                assert len(qkv_hidden_sizes_attr) == 3
+                shape[1] = int(qkv_hidden_sizes_attr[2])
+            elif isinstance(tripled_hidden_size, int):
+                shape[1] = int(tripled_hidden_size / 3)
+            output_dtype = ctx.known_vi_[node.input[0]].type.tensor_type.elem_type
+            vi = ctx.known_vi_[node.output[0]]
+            vi.CopyFrom(helper.make_tensor_value_info(node.output[0], output_dtype, shape))
+
+
+register_shape_handler(PackedAttentionHandler())
onnxslim/core/shape_inference/contrib_ops/attention/packed_multi_head_attention.py
@@ -0,0 +1,33 @@
+# Copyright (c) Microsoft Corporation. All rights reserved.
+# Licensed under the MIT License.
+
+"""Shape handler for PackedMultiHeadAttention operator."""
+
+from onnx import helper
+
+from ...base import ShapeHandler
+from ...registry import register_shape_handler
+
+
+class PackedMultiHeadAttentionHandler(ShapeHandler):
+    """Handler for PackedMultiHeadAttention operator."""
+
+    @property
+    def op_type(self) -> str:
+        return "PackedMultiHeadAttention"
+
+    def infer_shape(self, node, ctx) -> None:
+        shape_value = ctx.try_get_shape(node, 2)
+        if shape_value is not None and len(shape_value) == 2:
+            output_shape = shape_value
+        else:
+            shape_query = ctx.get_shape(node, 0)
+            assert shape_query is not None and len(shape_query) == 4
+            output_shape = [shape_query[0], shape_query[1] * shape_query[3]]
+
+        output_dtype = ctx.known_vi_[node.input[0]].type.tensor_type.elem_type
+        vi = ctx.known_vi_[node.output[0]]
+        vi.CopyFrom(helper.make_tensor_value_info(node.output[0], output_dtype, output_shape))
+
+
+register_shape_handler(PackedMultiHeadAttentionHandler())
onnxslim/core/shape_inference/contrib_ops/attention/remove_padding.py
@@ -0,0 +1,41 @@
+# Copyright (c) Microsoft Corporation. All rights reserved.
+# Licensed under the MIT License.
+
+"""Shape handler for RemovePadding operator."""
+
+import onnx
+from onnx import helper
+
+from ...base import ShapeHandler
+from ...registry import register_shape_handler
+
+
+class RemovePaddingHandler(ShapeHandler):
+    """Handler for RemovePadding operator."""
+
+    @property
+    def op_type(self) -> str:
+        return "RemovePadding"
+
+    def infer_shape(self, node, ctx) -> None:
+        shape = ctx.get_shape(node, 0)
+        if shape and len(shape) == 3:
+            output_dtype = ctx.known_vi_[node.input[0]].type.tensor_type.elem_type
+            vi = ctx.known_vi_[node.output[0]]
+            vi.CopyFrom(helper.make_tensor_value_info(node.output[0], output_dtype, ["token_count", shape[2]]))
+
+            vi_token_offset = ctx.known_vi_[node.output[1]]
+            vi_token_offset.CopyFrom(
+                helper.make_tensor_value_info(node.output[1], onnx.TensorProto.INT32, [shape[0], shape[1]])
+            )
+
+            vi_cumulated_seq_len = ctx.known_vi_[node.output[2]]
+            vi_cumulated_seq_len.CopyFrom(
+                helper.make_tensor_value_info(node.output[2], onnx.TensorProto.INT32, ["batch_size + 1"])
+            )
+
+            vi_max_seq_len = ctx.known_vi_[node.output[3]]
+            vi_max_seq_len.CopyFrom(helper.make_tensor_value_info(node.output[3], onnx.TensorProto.INT32, [1]))
+
+
+register_shape_handler(RemovePaddingHandler())
onnxslim/core/shape_inference/contrib_ops/attention/restore_padding.py
@@ -0,0 +1,29 @@
+# Copyright (c) Microsoft Corporation. All rights reserved.
+# Licensed under the MIT License.
+
+"""Shape handler for RestorePadding operator."""
+
+from onnx import helper
+
+from ...base import ShapeHandler
+from ...registry import register_shape_handler
+
+
+class RestorePaddingHandler(ShapeHandler):
+    """Handler for RestorePadding operator."""
+
+    @property
+    def op_type(self) -> str:
+        return "RestorePadding"
+
+    def infer_shape(self, node, ctx) -> None:
+        shape_input = ctx.get_shape(node, 0)
+        shape_token_offset = ctx.get_shape(node, 1)
+        if shape_input and len(shape_input) == 2 and shape_token_offset and len(shape_token_offset) == 2:
+            output_dtype = ctx.known_vi_[node.input[0]].type.tensor_type.elem_type
+            vi = ctx.known_vi_[node.output[0]]
+            output_shape = [shape_token_offset[0], shape_token_offset[1], shape_input[1]]
+            vi.CopyFrom(helper.make_tensor_value_info(node.output[0], output_dtype, output_shape))
+
+
+register_shape_handler(RestorePaddingHandler())
onnxslim/core/shape_inference/contrib_ops/misc/__init__.py
@@ -0,0 +1,15 @@
+# Copyright (c) Microsoft Corporation. All rights reserved.
+# Licensed under the MIT License.
+
+"""Miscellaneous contrib operator shape handlers."""
+
+from . import bias_gelu
+from . import fast_gelu
+from . import gelu
+from . import quick_gelu
+from . import gemm_fast_gelu
+from . import gemm_float8
+from . import bias_split_gelu
+from . import bias_add
+from . import rotary_embedding
+from . import python_op
onnxslim/core/shape_inference/contrib_ops/misc/bias_add.py
@@ -0,0 +1,21 @@
+# Copyright (c) Microsoft Corporation. All rights reserved.
+# Licensed under the MIT License.
+
+"""Shape handler for BiasAdd operator."""
+
+from ...base import ShapeHandler
+from ...registry import register_shape_handler
+
+
+class BiasAddHandler(ShapeHandler):
+    """Handler for BiasAdd operator."""
+
+    @property
+    def op_type(self) -> str:
+        return "BiasAdd"
+
+    def infer_shape(self, node, ctx) -> None:
+        ctx.propagate_shape_and_type(node)
+
+
+register_shape_handler(BiasAddHandler())
onnxslim/core/shape_inference/contrib_ops/misc/bias_gelu.py
@@ -0,0 +1,21 @@
+# Copyright (c) Microsoft Corporation. All rights reserved.
+# Licensed under the MIT License.
+
+"""Shape handler for BiasGelu operator."""
+
+from ...base import ShapeHandler
+from ...registry import register_shape_handler
+
+
+class BiasGeluHandler(ShapeHandler):
+    """Handler for BiasGelu operator."""
+
+    @property
+    def op_type(self) -> str:
+        return "BiasGelu"
+
+    def infer_shape(self, node, ctx) -> None:
+        ctx.propagate_shape_and_type(node)
+
+
+register_shape_handler(BiasGeluHandler())
onnxslim/core/shape_inference/contrib_ops/misc/bias_split_gelu.py
@@ -0,0 +1,30 @@
+# Copyright (c) Microsoft Corporation. All rights reserved.
+# Licensed under the MIT License.
+
+"""Shape handler for BiasSplitGelu operator."""
+
+from onnx import helper
+
+from ...base import ShapeHandler
+from ...registry import register_shape_handler
+
+
+class BiasSplitGeluHandler(ShapeHandler):
+    """Handler for BiasSplitGelu operator."""
+
+    @property
+    def op_type(self) -> str:
+        return "BiasSplitGelu"
+
+    def infer_shape(self, node, ctx) -> None:
+        input_shape = ctx.get_shape(node, 0)
+        bias_shape = ctx.get_shape(node, 1)
+        if input_shape and bias_shape and isinstance(bias_shape[0], int):
+            output_shape = input_shape
+            output_shape[2] = int(bias_shape[0] / 2)
+            vi = ctx.known_vi_[node.output[0]]
+            output_dtype = ctx.known_vi_[node.input[0]].type.tensor_type.elem_type
+            vi.CopyFrom(helper.make_tensor_value_info(vi.name, output_dtype, output_shape))
+
+
+register_shape_handler(BiasSplitGeluHandler())
onnxslim/core/shape_inference/contrib_ops/misc/fast_gelu.py
@@ -0,0 +1,21 @@
+# Copyright (c) Microsoft Corporation. All rights reserved.
+# Licensed under the MIT License.
+
+"""Shape handler for FastGelu operator."""
+
+from ...base import ShapeHandler
+from ...registry import register_shape_handler
+
+
+class FastGeluHandler(ShapeHandler):
+    """Handler for FastGelu operator."""
+
+    @property
+    def op_type(self) -> str:
+        return "FastGelu"
+
+    def infer_shape(self, node, ctx) -> None:
+        ctx.propagate_shape_and_type(node)
+
+
+register_shape_handler(FastGeluHandler())
onnxslim/core/shape_inference/contrib_ops/misc/gelu.py
@@ -0,0 +1,21 @@
+# Copyright (c) Microsoft Corporation. All rights reserved.
+# Licensed under the MIT License.
+
+"""Shape handler for Gelu operator."""
+
+from ...base import ShapeHandler
+from ...registry import register_shape_handler
+
+
+class GeluHandler(ShapeHandler):
+    """Handler for Gelu operator."""
+
+    @property
+    def op_type(self) -> str:
+        return "Gelu"
+
+    def infer_shape(self, node, ctx) -> None:
+        ctx.propagate_shape_and_type(node)
+
+
+register_shape_handler(GeluHandler())
onnxslim/core/shape_inference/contrib_ops/misc/gemm_fast_gelu.py
@@ -0,0 +1,21 @@
+# Copyright (c) Microsoft Corporation. All rights reserved.
+# Licensed under the MIT License.
+
+"""Shape handler for GemmFastGelu operator."""
+
+from ...base import ShapeHandler
+from ...registry import register_shape_handler
+
+
+class GemmFastGeluHandler(ShapeHandler):
+    """Handler for GemmFastGelu operator."""
+
+    @property
+    def op_type(self) -> str:
+        return "GemmFastGelu"
+
+    def infer_shape(self, node, ctx) -> None:
+        ctx.compute_matmul_shape(node)
+
+
+register_shape_handler(GemmFastGeluHandler())
onnxslim/core/shape_inference/contrib_ops/misc/gemm_float8.py
@@ -0,0 +1,21 @@
+# Copyright (c) Microsoft Corporation. All rights reserved.
+# Licensed under the MIT License.
+
+"""Shape handler for GemmFloat8 operator."""
+
+from ...base import ShapeHandler
+from ...registry import register_shape_handler
+
+
+class GemmFloat8Handler(ShapeHandler):
+    """Handler for GemmFloat8 operator."""
+
+    @property
+    def op_type(self) -> str:
+        return "GemmFloat8"
+
+    def infer_shape(self, node, ctx) -> None:
+        ctx.compute_matmul_shape(node)
+
+
+register_shape_handler(GemmFloat8Handler())
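The Gelu-family handlers above delegate all shape work to two context helpers: ctx.propagate_shape_and_type() for elementwise ops and ctx.compute_matmul_shape() for the Gemm-flavored ones. Those helpers live in the new context.py (+645 lines), which is not reproduced in this diff, so the standalone functions below are only a rough illustration of the shape arithmetic they are expected to perform; the real helpers also write dtypes and value_info into ctx.known_vi_.

# Rough, standalone illustration; not the actual context.py implementation.
from typing import List, Union

Dim = Union[int, str]  # symbolic dimensions are carried as strings


def propagate_shape(input_shape: List[Dim]) -> List[Dim]:
    # Elementwise ops (Gelu, FastGelu, QuickGelu, BiasAdd, ...): the output
    # simply inherits input 0's shape.
    return list(input_shape)


def matmul_shape(lhs: List[Dim], rhs: List[Dim]) -> List[Dim]:
    # Matmul-like ops (GemmFastGelu, GemmFloat8): leading batch dims are kept
    # and the last two dims contract as [..., M, K] x [K, N] -> [..., M, N].
    assert len(lhs) >= 2 and len(rhs) >= 2
    batch = lhs[:-2] if len(lhs) >= len(rhs) else rhs[:-2]
    return [*batch, lhs[-2], rhs[-1]]


print(propagate_shape(["batch", 77, 768]))            # ['batch', 77, 768]
print(matmul_shape(["batch", 77, 768], [768, 3072]))  # ['batch', 77, 3072]
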
onnxslim/core/shape_inference/contrib_ops/misc/python_op.py
@@ -0,0 +1,67 @@
+# Copyright (c) Microsoft Corporation. All rights reserved.
+# Licensed under the MIT License.
+
+"""Shape handler for PythonOp operator."""
+
+import onnx
+from onnx import helper
+
+from ...base import ShapeHandler
+from ...registry import register_shape_handler
+from ...utils import get_attribute, get_shape_from_sympy_shape
+
+
+class PythonOpHandler(ShapeHandler):
+    """Handler for PythonOp operator."""
+
+    @property
+    def op_type(self) -> str:
+        return "PythonOp"
+
+    def infer_shape(self, node, ctx) -> None:
+        output_tensor_types = get_attribute(node, "output_tensor_types")
+        assert output_tensor_types, f"PythonOp '{node.name}' has no output_tensor_types attribute."
+        output_tensor_ranks = get_attribute(node, "output_tensor_ranks")
+        assert output_tensor_ranks, f"PythonOp '{node.name}' has no output_tensor_ranks attribute."
+
+        try:
+            from onnxruntime.capi._pybind_state import get_shape_inference_function
+
+            func_name = get_attribute(node, "func_name").decode()
+            shape_inferer = get_shape_inference_function(func_name)
+        except ImportError:
+            shape_inferer = None
+
+        # Set the context output separately.
+        # The first output is torch.autograd.Function's context.
+        vi = ctx.known_vi_[node.output[0]]
+        vi.CopyFrom(helper.make_tensor_value_info(node.output[0], onnx.TensorProto.INT64, []))
+
+        if shape_inferer is not None:
+            input_shapes = []
+            input_dtypes = []
+            for input_index in range(len(node.input)):
+                shape = ctx.get_shape(node, input_index)
+                input_shapes.append(shape)
+                input_dtype = ctx.known_vi_[node.input[input_index]].type.tensor_type.elem_type
+                input_dtypes.append(input_dtype)
+            output_shapes, output_dtypes = shape_inferer(node, input_shapes, input_dtypes)
+            assert len(output_shapes) == len(output_dtypes) == (len(node.output) - 1), (
+                f"PythonOp '{func_name}' returned {len(output_shapes)} shapes and {len(output_dtypes)} dtypes, "
+                f"but expected {len(node.output) - 1} outputs."
+            )
+            for i in range(len(node.output) - 1):
+                output_index = i + 1
+                vi = ctx.known_vi_[node.output[output_index]]
+                vi.CopyFrom(helper.make_tensor_value_info(node.output[output_index], output_dtypes[i], output_shapes[i]))
+        else:
+            # General shape inference for PythonOp.
+            for i in range(len(node.output) - 1):
+                vi = ctx.known_vi_[node.output[i + 1]]
+                sympy_shape = ctx.new_symbolic_shape(output_tensor_ranks[i], node)
+                shape = get_shape_from_sympy_shape(sympy_shape)
+                value_info = helper.make_tensor_value_info(node.output[i + 1], output_tensor_types[i], shape)
+                vi.CopyFrom(value_info)
+
+
+register_shape_handler(PythonOpHandler())