onnxslim 0.1.82__py3-none-any.whl → 0.1.84__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- onnxslim/core/optimization/dead_node_elimination.py +85 -4
- onnxslim/core/pattern/elimination/slice.py +15 -8
- onnxslim/core/pattern/fusion/concat_reshape.py +3 -1
- onnxslim/core/pattern/fusion/convadd.py +23 -7
- onnxslim/core/pattern/fusion/convbn.py +24 -11
- onnxslim/core/pattern/fusion/convmul.py +26 -9
- onnxslim/core/pattern/fusion/gemm.py +7 -5
- onnxslim/core/pattern/fusion/padconv.py +5 -0
- onnxslim/core/shape_inference/__init__.py +378 -0
- onnxslim/core/shape_inference/aten_ops/__init__.py +16 -0
- onnxslim/core/shape_inference/aten_ops/argmax.py +47 -0
- onnxslim/core/shape_inference/aten_ops/bitwise_or.py +28 -0
- onnxslim/core/shape_inference/aten_ops/diagonal.py +52 -0
- onnxslim/core/shape_inference/aten_ops/embedding.py +23 -0
- onnxslim/core/shape_inference/aten_ops/group_norm.py +41 -0
- onnxslim/core/shape_inference/aten_ops/min_max.py +64 -0
- onnxslim/core/shape_inference/aten_ops/multinomial.py +39 -0
- onnxslim/core/shape_inference/aten_ops/numpy_t.py +22 -0
- onnxslim/core/shape_inference/aten_ops/pool2d.py +40 -0
- onnxslim/core/shape_inference/aten_ops/unfold.py +44 -0
- onnxslim/core/shape_inference/aten_ops/upsample.py +44 -0
- onnxslim/core/shape_inference/base.py +111 -0
- onnxslim/core/shape_inference/context.py +645 -0
- onnxslim/core/shape_inference/contrib_ops/__init__.py +8 -0
- onnxslim/core/shape_inference/contrib_ops/attention/__init__.py +15 -0
- onnxslim/core/shape_inference/contrib_ops/attention/attention.py +61 -0
- onnxslim/core/shape_inference/contrib_ops/attention/decoder_masked_mha.py +37 -0
- onnxslim/core/shape_inference/contrib_ops/attention/gated_relative_position_bias.py +35 -0
- onnxslim/core/shape_inference/contrib_ops/attention/longformer_attention.py +21 -0
- onnxslim/core/shape_inference/contrib_ops/attention/multi_head_attention.py +82 -0
- onnxslim/core/shape_inference/contrib_ops/attention/multi_scale_deformable_attn.py +29 -0
- onnxslim/core/shape_inference/contrib_ops/attention/packed_attention.py +39 -0
- onnxslim/core/shape_inference/contrib_ops/attention/packed_multi_head_attention.py +33 -0
- onnxslim/core/shape_inference/contrib_ops/attention/remove_padding.py +41 -0
- onnxslim/core/shape_inference/contrib_ops/attention/restore_padding.py +29 -0
- onnxslim/core/shape_inference/contrib_ops/misc/__init__.py +15 -0
- onnxslim/core/shape_inference/contrib_ops/misc/bias_add.py +21 -0
- onnxslim/core/shape_inference/contrib_ops/misc/bias_gelu.py +21 -0
- onnxslim/core/shape_inference/contrib_ops/misc/bias_split_gelu.py +30 -0
- onnxslim/core/shape_inference/contrib_ops/misc/fast_gelu.py +21 -0
- onnxslim/core/shape_inference/contrib_ops/misc/gelu.py +21 -0
- onnxslim/core/shape_inference/contrib_ops/misc/gemm_fast_gelu.py +21 -0
- onnxslim/core/shape_inference/contrib_ops/misc/gemm_float8.py +21 -0
- onnxslim/core/shape_inference/contrib_ops/misc/python_op.py +67 -0
- onnxslim/core/shape_inference/contrib_ops/misc/quick_gelu.py +21 -0
- onnxslim/core/shape_inference/contrib_ops/misc/rotary_embedding.py +31 -0
- onnxslim/core/shape_inference/contrib_ops/normalization/__init__.py +12 -0
- onnxslim/core/shape_inference/contrib_ops/normalization/embed_layer_normalization.py +41 -0
- onnxslim/core/shape_inference/contrib_ops/normalization/group_norm.py +21 -0
- onnxslim/core/shape_inference/contrib_ops/normalization/layer_normalization.py +42 -0
- onnxslim/core/shape_inference/contrib_ops/normalization/simplified_layer_normalization.py +23 -0
- onnxslim/core/shape_inference/contrib_ops/normalization/skip_group_norm.py +23 -0
- onnxslim/core/shape_inference/contrib_ops/normalization/skip_layer_normalization.py +26 -0
- onnxslim/core/shape_inference/contrib_ops/normalization/skip_simplified_layer_normalization.py +23 -0
- onnxslim/core/shape_inference/registry.py +90 -0
- onnxslim/core/shape_inference/standard_ops/__init__.py +11 -0
- onnxslim/core/shape_inference/standard_ops/control_flow/__init__.py +8 -0
- onnxslim/core/shape_inference/standard_ops/control_flow/if_op.py +43 -0
- onnxslim/core/shape_inference/standard_ops/control_flow/loop.py +74 -0
- onnxslim/core/shape_inference/standard_ops/control_flow/scan.py +54 -0
- onnxslim/core/shape_inference/standard_ops/math/__init__.py +20 -0
- onnxslim/core/shape_inference/standard_ops/math/_symbolic_compute.py +34 -0
- onnxslim/core/shape_inference/standard_ops/math/add.py +10 -0
- onnxslim/core/shape_inference/standard_ops/math/div.py +10 -0
- onnxslim/core/shape_inference/standard_ops/math/einsum.py +119 -0
- onnxslim/core/shape_inference/standard_ops/math/equal.py +10 -0
- onnxslim/core/shape_inference/standard_ops/math/floor.py +10 -0
- onnxslim/core/shape_inference/standard_ops/math/matmul.py +21 -0
- onnxslim/core/shape_inference/standard_ops/math/matmul_integer.py +23 -0
- onnxslim/core/shape_inference/standard_ops/math/max.py +10 -0
- onnxslim/core/shape_inference/standard_ops/math/min.py +10 -0
- onnxslim/core/shape_inference/standard_ops/math/mul.py +10 -0
- onnxslim/core/shape_inference/standard_ops/math/neg.py +10 -0
- onnxslim/core/shape_inference/standard_ops/math/reduce_prod.py +27 -0
- onnxslim/core/shape_inference/standard_ops/math/reduce_sum.py +53 -0
- onnxslim/core/shape_inference/standard_ops/math/sub.py +10 -0
- onnxslim/core/shape_inference/standard_ops/math/where.py +10 -0
- onnxslim/core/shape_inference/standard_ops/misc/__init__.py +22 -0
- onnxslim/core/shape_inference/standard_ops/misc/array_feature_extractor.py +32 -0
- onnxslim/core/shape_inference/standard_ops/misc/cast.py +21 -0
- onnxslim/core/shape_inference/standard_ops/misc/category_mapper.py +30 -0
- onnxslim/core/shape_inference/standard_ops/misc/compress.py +39 -0
- onnxslim/core/shape_inference/standard_ops/misc/constant.py +27 -0
- onnxslim/core/shape_inference/standard_ops/misc/constant_of_shape.py +45 -0
- onnxslim/core/shape_inference/standard_ops/misc/dequantize_linear.py +26 -0
- onnxslim/core/shape_inference/standard_ops/misc/non_max_suppression.py +26 -0
- onnxslim/core/shape_inference/standard_ops/misc/non_zero.py +26 -0
- onnxslim/core/shape_inference/standard_ops/misc/one_hot.py +42 -0
- onnxslim/core/shape_inference/standard_ops/misc/quantize_linear.py +29 -0
- onnxslim/core/shape_inference/standard_ops/misc/range.py +41 -0
- onnxslim/core/shape_inference/standard_ops/misc/relative_position_bias.py +31 -0
- onnxslim/core/shape_inference/standard_ops/misc/resize.py +74 -0
- onnxslim/core/shape_inference/standard_ops/misc/scatter_elements.py +31 -0
- onnxslim/core/shape_inference/standard_ops/misc/softmax_cross_entropy_loss.py +44 -0
- onnxslim/core/shape_inference/standard_ops/misc/top_k.py +44 -0
- onnxslim/core/shape_inference/standard_ops/nn/__init__.py +18 -0
- onnxslim/core/shape_inference/standard_ops/nn/all_reduce.py +9 -0
- onnxslim/core/shape_inference/standard_ops/nn/average_pool.py +40 -0
- onnxslim/core/shape_inference/standard_ops/nn/batch_normalization.py +26 -0
- onnxslim/core/shape_inference/standard_ops/nn/conv.py +33 -0
- onnxslim/core/shape_inference/standard_ops/nn/cum_sum.py +9 -0
- onnxslim/core/shape_inference/standard_ops/nn/identity.py +9 -0
- onnxslim/core/shape_inference/standard_ops/nn/max_pool.py +9 -0
- onnxslim/core/shape_inference/standard_ops/nn/memcpy_from_host.py +9 -0
- onnxslim/core/shape_inference/standard_ops/nn/memcpy_to_host.py +9 -0
- onnxslim/core/shape_inference/standard_ops/nn/moe.py +9 -0
- onnxslim/core/shape_inference/standard_ops/nn/nhwc_conv.py +33 -0
- onnxslim/core/shape_inference/standard_ops/nn/reciprocal.py +9 -0
- onnxslim/core/shape_inference/standard_ops/nn/round.py +9 -0
- onnxslim/core/shape_inference/standard_ops/sequence/__init__.py +10 -0
- onnxslim/core/shape_inference/standard_ops/sequence/concat_from_sequence.py +40 -0
- onnxslim/core/shape_inference/standard_ops/sequence/sequence_at.py +31 -0
- onnxslim/core/shape_inference/standard_ops/sequence/sequence_insert.py +26 -0
- onnxslim/core/shape_inference/standard_ops/sequence/split_to_sequence.py +24 -0
- onnxslim/core/shape_inference/standard_ops/sequence/zip_map.py +36 -0
- onnxslim/core/shape_inference/standard_ops/tensor/__init__.py +20 -0
- onnxslim/core/shape_inference/standard_ops/tensor/concat.py +62 -0
- onnxslim/core/shape_inference/standard_ops/tensor/expand.py +36 -0
- onnxslim/core/shape_inference/standard_ops/tensor/gather.py +48 -0
- onnxslim/core/shape_inference/standard_ops/tensor/gather_elements.py +31 -0
- onnxslim/core/shape_inference/standard_ops/tensor/gather_nd.py +42 -0
- onnxslim/core/shape_inference/standard_ops/tensor/pad.py +41 -0
- onnxslim/core/shape_inference/standard_ops/tensor/reshape.py +72 -0
- onnxslim/core/shape_inference/standard_ops/tensor/shape.py +38 -0
- onnxslim/core/shape_inference/standard_ops/tensor/size.py +29 -0
- onnxslim/core/shape_inference/standard_ops/tensor/slice.py +183 -0
- onnxslim/core/shape_inference/standard_ops/tensor/split.py +57 -0
- onnxslim/core/shape_inference/standard_ops/tensor/squeeze.py +69 -0
- onnxslim/core/shape_inference/standard_ops/tensor/tile.py +41 -0
- onnxslim/core/shape_inference/standard_ops/tensor/transpose.py +30 -0
- onnxslim/core/shape_inference/standard_ops/tensor/unsqueeze.py +54 -0
- onnxslim/core/shape_inference/utils.py +244 -0
- onnxslim/third_party/onnx_graphsurgeon/ir/graph.py +0 -103
- onnxslim/third_party/symbolic_shape_infer.py +73 -3156
- onnxslim/utils.py +4 -2
- {onnxslim-0.1.82.dist-info → onnxslim-0.1.84.dist-info}/METADATA +21 -11
- onnxslim-0.1.84.dist-info/RECORD +187 -0
- onnxslim-0.1.82.dist-info/RECORD +0 -63
- {onnxslim-0.1.82.dist-info → onnxslim-0.1.84.dist-info}/WHEEL +0 -0
- {onnxslim-0.1.82.dist-info → onnxslim-0.1.84.dist-info}/entry_points.txt +0 -0
- {onnxslim-0.1.82.dist-info → onnxslim-0.1.84.dist-info}/licenses/LICENSE +0 -0
|
@@ -0,0 +1,39 @@
|
|
|
1
|
+
# Copyright (c) Microsoft Corporation. All rights reserved.
|
|
2
|
+
# Licensed under the MIT License.
|
|
3
|
+
|
|
4
|
+
"""Shape handler for ATen multinomial operator."""
|
|
5
|
+
|
|
6
|
+
import onnx
|
|
7
|
+
from onnx import helper
|
|
8
|
+
|
|
9
|
+
from ..base import ShapeHandler
|
|
10
|
+
from ..registry import register_aten_handler
|
|
11
|
+
from ..utils import get_shape_from_sympy_shape
|
|
12
|
+
|
|
13
|
+
|
|
14
|
+
class AtenMultinomialHandler(ShapeHandler):
    """Infer the output shape of the ATen ``multinomial`` operator.

    The output keeps every leading dimension of the input distribution
    tensor and replaces the last dimension with the number of drawn
    samples, falling back to a fresh symbolic dimension when that count
    is not statically known. The output element type is always INT64.
    """

    @property
    def op_type(self) -> str:
        """Operator name handled by this class."""
        return "multinomial"

    def infer_shape(self, node, ctx) -> None:
        """Compute and record the INT64 output shape for *node*."""
        in_shape = ctx.get_sympy_shape(node, 0)
        ndim = len(in_shape)
        # multinomial accepts a probability vector or a batch of them.
        assert ndim in {1, 2}
        sample_count = ctx.try_get_value(node, 1)
        if not sample_count:
            # Sample count unknown at inference time: use a symbolic dim
            # derived from the output's last axis.
            sample_count = str(ctx.new_symbolic_dim_from_output(node, 0, ndim - 1))
        result_shape = list(in_shape[:-1]) + [sample_count]
        out_vi = ctx.known_vi_[node.output[0]]
        out_vi.CopyFrom(
            helper.make_tensor_value_info(
                node.output[0],
                onnx.TensorProto.INT64,
                get_shape_from_sympy_shape(result_shape),
            )
        )


register_aten_handler(AtenMultinomialHandler())
|
|
@@ -0,0 +1,22 @@
|
|
|
1
|
+
# Copyright (c) Microsoft Corporation. All rights reserved.
|
|
2
|
+
# Licensed under the MIT License.
|
|
3
|
+
|
|
4
|
+
"""Shape handler for ATen numpy_T operator."""
|
|
5
|
+
|
|
6
|
+
from ..base import ShapeHandler
|
|
7
|
+
from ..registry import register_aten_handler
|
|
8
|
+
from ..standard_ops.tensor.transpose import TransposeHandler
|
|
9
|
+
|
|
10
|
+
|
|
11
|
+
class AtenNumpyTHandler(ShapeHandler):
    """Shape handler for the ATen ``numpy_T`` operator.

    ``numpy_T`` is a transpose, so shape inference is delegated directly
    to the standard ``Transpose`` handler.
    """

    @property
    def op_type(self) -> str:
        """Operator name handled by this class."""
        return "numpy_T"

    def infer_shape(self, node, ctx) -> None:
        """Delegate shape inference to the Transpose logic."""
        transpose_handler = TransposeHandler()
        transpose_handler.infer_shape(node, ctx)


register_aten_handler(AtenNumpyTHandler())
|
|
@@ -0,0 +1,40 @@
|
|
|
1
|
+
# Copyright (c) Microsoft Corporation. All rights reserved.
|
|
2
|
+
# Licensed under the MIT License.
|
|
3
|
+
|
|
4
|
+
"""Shape handler for ATen pooling operators."""
|
|
5
|
+
|
|
6
|
+
import onnx
|
|
7
|
+
from onnx import helper
|
|
8
|
+
|
|
9
|
+
from ..base import ShapeHandler
|
|
10
|
+
from ..registry import register_aten_handler
|
|
11
|
+
from ..utils import get_shape_from_sympy_shape
|
|
12
|
+
|
|
13
|
+
|
|
14
|
+
class AtenPool2dHandler(ShapeHandler):
    """Shape handler for ATen 2-D pooling operators.

    Covers ``max_pool2d_with_indices``, ``avg_pool2d`` and
    ``_adaptive_avg_pool2d``. The two spatial dimensions of the NCHW
    input are replaced with fresh symbolic dimensions because the pooled
    extent depends on kernel/stride attributes not resolved here.
    """

    def __init__(self, op_name):
        """Create a handler bound to a specific pooling op name.

        Args:
            op_name: The ATen operator name this instance handles.
        """
        super().__init__()
        self._op_type = op_name

    @property
    def op_type(self) -> str:
        """Operator name handled by this instance."""
        return self._op_type

    def infer_shape(self, node, ctx) -> None:
        """Infer NCHW output shapes for every produced output.

        Output index 1 (the indices output of ``max_pool2d_with_indices``)
        is typed INT64; all other outputs keep the input element type.
        """
        sympy_shape = ctx.get_sympy_shape(node, 0)
        assert len(sympy_shape) == 4  # expects an NCHW input
        # Iterate an ordered tuple, not a set literal: set iteration order
        # is a CPython implementation detail and the index here determines
        # which symbolic dimension is assigned to H versus W.
        sympy_shape[-2:] = [ctx.new_symbolic_dim_from_output(node, 0, i) for i in (2, 3)]
        ctx.update_computed_dims(sympy_shape)
        for i, o in enumerate(node.output):
            if not o:
                continue  # optional output slot left empty
            vi = ctx.known_vi_[o]
            elem_type = onnx.TensorProto.INT64 if i == 1 else ctx.known_vi_[node.input[0]].type.tensor_type.elem_type
            vi.CopyFrom(helper.make_tensor_value_info(o, elem_type, get_shape_from_sympy_shape(sympy_shape)))


register_aten_handler(AtenPool2dHandler("max_pool2d_with_indices"))
register_aten_handler(AtenPool2dHandler("avg_pool2d"))
register_aten_handler(AtenPool2dHandler("_adaptive_avg_pool2d"))
|
|
@@ -0,0 +1,44 @@
|
|
|
1
|
+
# Copyright (c) Microsoft Corporation. All rights reserved.
|
|
2
|
+
# Licensed under the MIT License.
|
|
3
|
+
|
|
4
|
+
"""Shape handler for ATen unfold operator."""
|
|
5
|
+
|
|
6
|
+
from onnx import helper
|
|
7
|
+
|
|
8
|
+
from ..base import ShapeHandler
|
|
9
|
+
from ..registry import register_aten_handler
|
|
10
|
+
from ..utils import get_shape_from_sympy_shape
|
|
11
|
+
|
|
12
|
+
|
|
13
|
+
class AtenUnfoldHandler(ShapeHandler):
    """Shape handler for the ATen ``unfold`` operator.

    ``unfold`` extracts sliding windows of ``size`` elements with stride
    ``step`` along ``dimension``; the window size becomes a new trailing
    axis of the output.
    """

    @property
    def op_type(self) -> str:
        """Operator name handled by this class."""
        return "unfold"

    def infer_shape(self, node, ctx) -> None:
        """Infer the output shape, symbolically when inputs are unknown."""
        sympy_shape = ctx.get_sympy_shape(node, 0)
        dimension = ctx.try_get_value(node, 1)
        size = ctx.try_get_value(node, 2)
        step = ctx.try_get_value(node, 3)
        all_known = dimension is not None and size is not None and step is not None
        if all_known:
            assert dimension < len(sympy_shape)
            # Number of windows along the unfolded axis; the window size
            # is then appended as a new trailing dimension.
            sympy_shape[dimension] = (sympy_shape[dimension] - size) // step + 1
            sympy_shape.append(size)
        else:
            # Any argument unknown: fall back to a fully symbolic shape
            # with one extra axis for the window.
            sympy_shape = ctx.new_symbolic_shape(len(sympy_shape) + 1, node)
        ctx.update_computed_dims(sympy_shape)
        if node.output[0]:
            out_vi = ctx.known_vi_[node.output[0]]
            out_vi.CopyFrom(
                helper.make_tensor_value_info(
                    node.output[0],
                    ctx.known_vi_[node.input[0]].type.tensor_type.elem_type,
                    get_shape_from_sympy_shape(sympy_shape),
                )
            )


register_aten_handler(AtenUnfoldHandler())
|
|
@@ -0,0 +1,44 @@
|
|
|
1
|
+
# Copyright (c) Microsoft Corporation. All rights reserved.
|
|
2
|
+
# Licensed under the MIT License.
|
|
3
|
+
|
|
4
|
+
"""Shape handler for ATen upsample operators."""
|
|
5
|
+
|
|
6
|
+
import numpy as np
|
|
7
|
+
from onnx import helper
|
|
8
|
+
|
|
9
|
+
from ..base import ShapeHandler
|
|
10
|
+
from ..registry import register_aten_handler
|
|
11
|
+
|
|
12
|
+
|
|
13
|
+
class AtenUpsampleHandler(ShapeHandler):
    """Shape handler for ATen upsample operators.

    Covers ``upsample_nearest1d/2d/3d`` and ``upsample_bicubic2d``. The
    batch and channel dimensions are carried over from the input; the
    spatial dimensions come from the ``output_size`` argument when it is
    statically known, otherwise fresh symbolic dimensions are used.
    """

    def __init__(self, op_name):
        """Create a handler bound to a specific upsample op name.

        Args:
            op_name: The ATen operator name this instance handles.
        """
        super().__init__()
        self._op_type = op_name

    @property
    def op_type(self) -> str:
        """Operator name handled by this instance."""
        return self._op_type

    def infer_shape(self, node, ctx) -> None:
        """Infer and record the output shape for *node*.

        Skips shape propagation when the input shape is unavailable or
        the output slot is empty.
        """
        new_shape = None
        input_shape = ctx.get_shape(node, 0)
        if input_shape is not None:
            # Keep N and C; spatial dims are replaced below.
            new_shape = input_shape[:2]
            output_size = ctx.try_get_value(node, 1)
            if output_size is not None:
                # Use isinstance (not type == np.int64) and accept any numpy
                # integer width, normalizing numpy scalars to Python ints so
                # the value_info carries concrete dimensions.
                new_shape += [dim_size.item() if isinstance(dim_size, np.integer) else dim_size for dim_size in output_size]
            else:
                rank = len(input_shape)
                new_shape += [str(ctx.new_symbolic_dim_from_output(node, 0, i)) for i in range(2, rank)]
        if node.output[0] and new_shape is not None:
            output_dtype = ctx.known_vi_[node.input[0]].type.tensor_type.elem_type
            vi = ctx.known_vi_[node.output[0]]
            vi.CopyFrom(helper.make_tensor_value_info(node.output[0], output_dtype, new_shape))


register_aten_handler(AtenUpsampleHandler("upsample_nearest1d"))
register_aten_handler(AtenUpsampleHandler("upsample_nearest2d"))
register_aten_handler(AtenUpsampleHandler("upsample_nearest3d"))
register_aten_handler(AtenUpsampleHandler("upsample_bicubic2d"))
|
|
@@ -0,0 +1,111 @@
|
|
|
1
|
+
# Copyright (c) Microsoft Corporation. All rights reserved.
|
|
2
|
+
# Licensed under the MIT License.
|
|
3
|
+
|
|
4
|
+
"""Base classes for shape inference handlers."""
|
|
5
|
+
|
|
6
|
+
from abc import ABC, abstractmethod
|
|
7
|
+
|
|
8
|
+
|
|
9
|
+
class ShapeHandler(ABC):
    """Abstract base class for shape inference handlers.

    A concrete handler implements shape inference for exactly one ONNX
    operator type and declares the inclusive opset range it supports.
    """

    def __init__(self, min_opset=1, max_opset=999):
        """Store the inclusive opset range supported by this handler.

        Args:
            min_opset: Lowest ONNX opset version this handler supports.
            max_opset: Highest ONNX opset version this handler supports.
        """
        self.min_opset = min_opset
        self.max_opset = max_opset

    @property
    @abstractmethod
    def op_type(self) -> str:
        """Return the supported operator type string (e.g. ``"Reshape"``)."""
        raise NotImplementedError

    def supports_opset(self, opset: int) -> bool:
        """Return True when *opset* lies within the supported range.

        Args:
            opset: The ONNX opset version to check.
        """
        if opset < self.min_opset:
            return False
        return opset <= self.max_opset

    @abstractmethod
    def infer_shape(self, node, ctx) -> None:
        """Infer output shapes for *node*.

        Args:
            node: The ONNX node to infer shapes for.
            ctx: The InferenceContext providing access to shape information.
        """
        raise NotImplementedError
|
|
56
|
+
|
|
57
|
+
|
|
58
|
+
class PassthroughHandler(ShapeHandler):
    """Handler for operators whose output shape equals the input shape.

    Used for shape-preserving operators such as Identity, Reciprocal and
    Round.
    """

    def __init__(self, op_type_name, min_opset=1, max_opset=999):
        """Bind this passthrough handler to a concrete operator name.

        Args:
            op_type_name: The operator type handled by this instance.
            min_opset: Lowest ONNX opset version this handler supports.
            max_opset: Highest ONNX opset version this handler supports.
        """
        super().__init__(min_opset, max_opset)
        self._op_type = op_type_name

    @property
    def op_type(self) -> str:
        """Operator type handled by this instance."""
        return self._op_type

    def infer_shape(self, node, ctx) -> None:
        """Copy the input's shape and element type onto the output."""
        ctx.pass_on_shape_and_type(node)
|
|
83
|
+
|
|
84
|
+
|
|
85
|
+
class MultiOpHandler(ShapeHandler):
    """Handler that delegates shape inference to a plain function.

    Lets several operator types that share one piece of inference logic
    be registered as separate handler instances wrapping the same
    function.
    """

    def __init__(self, op_type_name, handler_func, min_opset=1, max_opset=999):
        """Bind an operator name to a shape-inference function.

        Args:
            op_type_name: The operator type handled by this instance.
            handler_func: Callable ``(node, ctx)`` performing the inference.
            min_opset: Lowest ONNX opset version this handler supports.
            max_opset: Highest ONNX opset version this handler supports.
        """
        super().__init__(min_opset, max_opset)
        self._op_type = op_type_name
        self._handler_func = handler_func

    @property
    def op_type(self) -> str:
        """Operator type handled by this instance."""
        return self._op_type

    def infer_shape(self, node, ctx) -> None:
        """Invoke the wrapped function to perform shape inference."""
        self._handler_func(node, ctx)