onnxslim 0.1.82__py3-none-any.whl → 0.1.84__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (141):
  1. onnxslim/core/optimization/dead_node_elimination.py +85 -4
  2. onnxslim/core/pattern/elimination/slice.py +15 -8
  3. onnxslim/core/pattern/fusion/concat_reshape.py +3 -1
  4. onnxslim/core/pattern/fusion/convadd.py +23 -7
  5. onnxslim/core/pattern/fusion/convbn.py +24 -11
  6. onnxslim/core/pattern/fusion/convmul.py +26 -9
  7. onnxslim/core/pattern/fusion/gemm.py +7 -5
  8. onnxslim/core/pattern/fusion/padconv.py +5 -0
  9. onnxslim/core/shape_inference/__init__.py +378 -0
  10. onnxslim/core/shape_inference/aten_ops/__init__.py +16 -0
  11. onnxslim/core/shape_inference/aten_ops/argmax.py +47 -0
  12. onnxslim/core/shape_inference/aten_ops/bitwise_or.py +28 -0
  13. onnxslim/core/shape_inference/aten_ops/diagonal.py +52 -0
  14. onnxslim/core/shape_inference/aten_ops/embedding.py +23 -0
  15. onnxslim/core/shape_inference/aten_ops/group_norm.py +41 -0
  16. onnxslim/core/shape_inference/aten_ops/min_max.py +64 -0
  17. onnxslim/core/shape_inference/aten_ops/multinomial.py +39 -0
  18. onnxslim/core/shape_inference/aten_ops/numpy_t.py +22 -0
  19. onnxslim/core/shape_inference/aten_ops/pool2d.py +40 -0
  20. onnxslim/core/shape_inference/aten_ops/unfold.py +44 -0
  21. onnxslim/core/shape_inference/aten_ops/upsample.py +44 -0
  22. onnxslim/core/shape_inference/base.py +111 -0
  23. onnxslim/core/shape_inference/context.py +645 -0
  24. onnxslim/core/shape_inference/contrib_ops/__init__.py +8 -0
  25. onnxslim/core/shape_inference/contrib_ops/attention/__init__.py +15 -0
  26. onnxslim/core/shape_inference/contrib_ops/attention/attention.py +61 -0
  27. onnxslim/core/shape_inference/contrib_ops/attention/decoder_masked_mha.py +37 -0
  28. onnxslim/core/shape_inference/contrib_ops/attention/gated_relative_position_bias.py +35 -0
  29. onnxslim/core/shape_inference/contrib_ops/attention/longformer_attention.py +21 -0
  30. onnxslim/core/shape_inference/contrib_ops/attention/multi_head_attention.py +82 -0
  31. onnxslim/core/shape_inference/contrib_ops/attention/multi_scale_deformable_attn.py +29 -0
  32. onnxslim/core/shape_inference/contrib_ops/attention/packed_attention.py +39 -0
  33. onnxslim/core/shape_inference/contrib_ops/attention/packed_multi_head_attention.py +33 -0
  34. onnxslim/core/shape_inference/contrib_ops/attention/remove_padding.py +41 -0
  35. onnxslim/core/shape_inference/contrib_ops/attention/restore_padding.py +29 -0
  36. onnxslim/core/shape_inference/contrib_ops/misc/__init__.py +15 -0
  37. onnxslim/core/shape_inference/contrib_ops/misc/bias_add.py +21 -0
  38. onnxslim/core/shape_inference/contrib_ops/misc/bias_gelu.py +21 -0
  39. onnxslim/core/shape_inference/contrib_ops/misc/bias_split_gelu.py +30 -0
  40. onnxslim/core/shape_inference/contrib_ops/misc/fast_gelu.py +21 -0
  41. onnxslim/core/shape_inference/contrib_ops/misc/gelu.py +21 -0
  42. onnxslim/core/shape_inference/contrib_ops/misc/gemm_fast_gelu.py +21 -0
  43. onnxslim/core/shape_inference/contrib_ops/misc/gemm_float8.py +21 -0
  44. onnxslim/core/shape_inference/contrib_ops/misc/python_op.py +67 -0
  45. onnxslim/core/shape_inference/contrib_ops/misc/quick_gelu.py +21 -0
  46. onnxslim/core/shape_inference/contrib_ops/misc/rotary_embedding.py +31 -0
  47. onnxslim/core/shape_inference/contrib_ops/normalization/__init__.py +12 -0
  48. onnxslim/core/shape_inference/contrib_ops/normalization/embed_layer_normalization.py +41 -0
  49. onnxslim/core/shape_inference/contrib_ops/normalization/group_norm.py +21 -0
  50. onnxslim/core/shape_inference/contrib_ops/normalization/layer_normalization.py +42 -0
  51. onnxslim/core/shape_inference/contrib_ops/normalization/simplified_layer_normalization.py +23 -0
  52. onnxslim/core/shape_inference/contrib_ops/normalization/skip_group_norm.py +23 -0
  53. onnxslim/core/shape_inference/contrib_ops/normalization/skip_layer_normalization.py +26 -0
  54. onnxslim/core/shape_inference/contrib_ops/normalization/skip_simplified_layer_normalization.py +23 -0
  55. onnxslim/core/shape_inference/registry.py +90 -0
  56. onnxslim/core/shape_inference/standard_ops/__init__.py +11 -0
  57. onnxslim/core/shape_inference/standard_ops/control_flow/__init__.py +8 -0
  58. onnxslim/core/shape_inference/standard_ops/control_flow/if_op.py +43 -0
  59. onnxslim/core/shape_inference/standard_ops/control_flow/loop.py +74 -0
  60. onnxslim/core/shape_inference/standard_ops/control_flow/scan.py +54 -0
  61. onnxslim/core/shape_inference/standard_ops/math/__init__.py +20 -0
  62. onnxslim/core/shape_inference/standard_ops/math/_symbolic_compute.py +34 -0
  63. onnxslim/core/shape_inference/standard_ops/math/add.py +10 -0
  64. onnxslim/core/shape_inference/standard_ops/math/div.py +10 -0
  65. onnxslim/core/shape_inference/standard_ops/math/einsum.py +119 -0
  66. onnxslim/core/shape_inference/standard_ops/math/equal.py +10 -0
  67. onnxslim/core/shape_inference/standard_ops/math/floor.py +10 -0
  68. onnxslim/core/shape_inference/standard_ops/math/matmul.py +21 -0
  69. onnxslim/core/shape_inference/standard_ops/math/matmul_integer.py +23 -0
  70. onnxslim/core/shape_inference/standard_ops/math/max.py +10 -0
  71. onnxslim/core/shape_inference/standard_ops/math/min.py +10 -0
  72. onnxslim/core/shape_inference/standard_ops/math/mul.py +10 -0
  73. onnxslim/core/shape_inference/standard_ops/math/neg.py +10 -0
  74. onnxslim/core/shape_inference/standard_ops/math/reduce_prod.py +27 -0
  75. onnxslim/core/shape_inference/standard_ops/math/reduce_sum.py +53 -0
  76. onnxslim/core/shape_inference/standard_ops/math/sub.py +10 -0
  77. onnxslim/core/shape_inference/standard_ops/math/where.py +10 -0
  78. onnxslim/core/shape_inference/standard_ops/misc/__init__.py +22 -0
  79. onnxslim/core/shape_inference/standard_ops/misc/array_feature_extractor.py +32 -0
  80. onnxslim/core/shape_inference/standard_ops/misc/cast.py +21 -0
  81. onnxslim/core/shape_inference/standard_ops/misc/category_mapper.py +30 -0
  82. onnxslim/core/shape_inference/standard_ops/misc/compress.py +39 -0
  83. onnxslim/core/shape_inference/standard_ops/misc/constant.py +27 -0
  84. onnxslim/core/shape_inference/standard_ops/misc/constant_of_shape.py +45 -0
  85. onnxslim/core/shape_inference/standard_ops/misc/dequantize_linear.py +26 -0
  86. onnxslim/core/shape_inference/standard_ops/misc/non_max_suppression.py +26 -0
  87. onnxslim/core/shape_inference/standard_ops/misc/non_zero.py +26 -0
  88. onnxslim/core/shape_inference/standard_ops/misc/one_hot.py +42 -0
  89. onnxslim/core/shape_inference/standard_ops/misc/quantize_linear.py +29 -0
  90. onnxslim/core/shape_inference/standard_ops/misc/range.py +41 -0
  91. onnxslim/core/shape_inference/standard_ops/misc/relative_position_bias.py +31 -0
  92. onnxslim/core/shape_inference/standard_ops/misc/resize.py +74 -0
  93. onnxslim/core/shape_inference/standard_ops/misc/scatter_elements.py +31 -0
  94. onnxslim/core/shape_inference/standard_ops/misc/softmax_cross_entropy_loss.py +44 -0
  95. onnxslim/core/shape_inference/standard_ops/misc/top_k.py +44 -0
  96. onnxslim/core/shape_inference/standard_ops/nn/__init__.py +18 -0
  97. onnxslim/core/shape_inference/standard_ops/nn/all_reduce.py +9 -0
  98. onnxslim/core/shape_inference/standard_ops/nn/average_pool.py +40 -0
  99. onnxslim/core/shape_inference/standard_ops/nn/batch_normalization.py +26 -0
  100. onnxslim/core/shape_inference/standard_ops/nn/conv.py +33 -0
  101. onnxslim/core/shape_inference/standard_ops/nn/cum_sum.py +9 -0
  102. onnxslim/core/shape_inference/standard_ops/nn/identity.py +9 -0
  103. onnxslim/core/shape_inference/standard_ops/nn/max_pool.py +9 -0
  104. onnxslim/core/shape_inference/standard_ops/nn/memcpy_from_host.py +9 -0
  105. onnxslim/core/shape_inference/standard_ops/nn/memcpy_to_host.py +9 -0
  106. onnxslim/core/shape_inference/standard_ops/nn/moe.py +9 -0
  107. onnxslim/core/shape_inference/standard_ops/nn/nhwc_conv.py +33 -0
  108. onnxslim/core/shape_inference/standard_ops/nn/reciprocal.py +9 -0
  109. onnxslim/core/shape_inference/standard_ops/nn/round.py +9 -0
  110. onnxslim/core/shape_inference/standard_ops/sequence/__init__.py +10 -0
  111. onnxslim/core/shape_inference/standard_ops/sequence/concat_from_sequence.py +40 -0
  112. onnxslim/core/shape_inference/standard_ops/sequence/sequence_at.py +31 -0
  113. onnxslim/core/shape_inference/standard_ops/sequence/sequence_insert.py +26 -0
  114. onnxslim/core/shape_inference/standard_ops/sequence/split_to_sequence.py +24 -0
  115. onnxslim/core/shape_inference/standard_ops/sequence/zip_map.py +36 -0
  116. onnxslim/core/shape_inference/standard_ops/tensor/__init__.py +20 -0
  117. onnxslim/core/shape_inference/standard_ops/tensor/concat.py +62 -0
  118. onnxslim/core/shape_inference/standard_ops/tensor/expand.py +36 -0
  119. onnxslim/core/shape_inference/standard_ops/tensor/gather.py +48 -0
  120. onnxslim/core/shape_inference/standard_ops/tensor/gather_elements.py +31 -0
  121. onnxslim/core/shape_inference/standard_ops/tensor/gather_nd.py +42 -0
  122. onnxslim/core/shape_inference/standard_ops/tensor/pad.py +41 -0
  123. onnxslim/core/shape_inference/standard_ops/tensor/reshape.py +72 -0
  124. onnxslim/core/shape_inference/standard_ops/tensor/shape.py +38 -0
  125. onnxslim/core/shape_inference/standard_ops/tensor/size.py +29 -0
  126. onnxslim/core/shape_inference/standard_ops/tensor/slice.py +183 -0
  127. onnxslim/core/shape_inference/standard_ops/tensor/split.py +57 -0
  128. onnxslim/core/shape_inference/standard_ops/tensor/squeeze.py +69 -0
  129. onnxslim/core/shape_inference/standard_ops/tensor/tile.py +41 -0
  130. onnxslim/core/shape_inference/standard_ops/tensor/transpose.py +30 -0
  131. onnxslim/core/shape_inference/standard_ops/tensor/unsqueeze.py +54 -0
  132. onnxslim/core/shape_inference/utils.py +244 -0
  133. onnxslim/third_party/onnx_graphsurgeon/ir/graph.py +0 -103
  134. onnxslim/third_party/symbolic_shape_infer.py +73 -3156
  135. onnxslim/utils.py +4 -2
  136. {onnxslim-0.1.82.dist-info → onnxslim-0.1.84.dist-info}/METADATA +21 -11
  137. onnxslim-0.1.84.dist-info/RECORD +187 -0
  138. onnxslim-0.1.82.dist-info/RECORD +0 -63
  139. {onnxslim-0.1.82.dist-info → onnxslim-0.1.84.dist-info}/WHEEL +0 -0
  140. {onnxslim-0.1.82.dist-info → onnxslim-0.1.84.dist-info}/entry_points.txt +0 -0
  141. {onnxslim-0.1.82.dist-info → onnxslim-0.1.84.dist-info}/licenses/LICENSE +0 -0
@@ -0,0 +1,39 @@
1
+ # Copyright (c) Microsoft Corporation. All rights reserved.
2
+ # Licensed under the MIT License.
3
+
4
+ """Shape handler for ATen multinomial operator."""
5
+
6
+ import onnx
7
+ from onnx import helper
8
+
9
+ from ..base import ShapeHandler
10
+ from ..registry import register_aten_handler
11
+ from ..utils import get_shape_from_sympy_shape
12
+
13
+
14
class AtenMultinomialHandler(ShapeHandler):
    """Handler for the ATen ``multinomial`` operator.

    The output keeps every leading dimension of input 0 and replaces the last
    dimension with the sample count read from input 1. When that count is not
    statically known, a fresh symbolic dimension is created instead. The output
    element type is always INT64.
    """

    @property
    def op_type(self) -> str:
        return "multinomial"

    def infer_shape(self, node, ctx) -> None:
        """Infer the output shape and type of a ``multinomial`` node.

        Args:
            node: The ATen node. Input 0 must be rank 1 or 2 (asserted);
                input 1 is the number of samples.
            ctx: The InferenceContext providing access to shape information.
        """
        sympy_shape = ctx.get_sympy_shape(node, 0)
        rank = len(sympy_shape)
        assert rank in {1, 2}
        num_samples = ctx.try_get_value(node, 1)
        if num_samples is not None:
            # Explicit None check: a plain `or` would discard a known value of 0
            # (and depends on the truthiness of whatever try_get_value returns).
            last_dim = num_samples
        else:
            last_dim = str(ctx.new_symbolic_dim_from_output(node, 0, rank - 1))
        output_shape = [*sympy_shape[:-1], last_dim]
        vi = ctx.known_vi_[node.output[0]]
        vi.CopyFrom(
            helper.make_tensor_value_info(
                node.output[0],
                onnx.TensorProto.INT64,
                get_shape_from_sympy_shape(output_shape),
            )
        )


register_aten_handler(AtenMultinomialHandler())
@@ -0,0 +1,22 @@
1
+ # Copyright (c) Microsoft Corporation. All rights reserved.
2
+ # Licensed under the MIT License.
3
+
4
+ """Shape handler for ATen numpy_T operator."""
5
+
6
+ from ..base import ShapeHandler
7
+ from ..registry import register_aten_handler
8
+ from ..standard_ops.tensor.transpose import TransposeHandler
9
+
10
+
11
class AtenNumpyTHandler(ShapeHandler):
    """Shape handler for the ATen ``numpy_T`` operator.

    All inference work is delegated to ``TransposeHandler``; a fresh
    ``TransposeHandler`` instance is built on every call.
    """

    @property
    def op_type(self) -> str:
        return "numpy_T"

    def infer_shape(self, node, ctx) -> None:
        delegate = TransposeHandler()
        delegate.infer_shape(node, ctx)


register_aten_handler(AtenNumpyTHandler())
@@ -0,0 +1,40 @@
1
+ # Copyright (c) Microsoft Corporation. All rights reserved.
2
+ # Licensed under the MIT License.
3
+
4
+ """Shape handler for ATen pooling operators."""
5
+
6
+ import onnx
7
+ from onnx import helper
8
+
9
+ from ..base import ShapeHandler
10
+ from ..registry import register_aten_handler
11
+ from ..utils import get_shape_from_sympy_shape
12
+
13
+
14
class AtenPool2dHandler(ShapeHandler):
    """Handler for ATen 2-D pooling operators.

    Registered for ``max_pool2d_with_indices``, ``avg_pool2d`` and
    ``_adaptive_avg_pool2d``. Input 0 must be rank 4 (asserted). The first two
    dimensions are kept and the two trailing dimensions are replaced by fresh
    symbolic dimensions. Every non-empty output receives that shape: output 1 is
    typed INT64 (presumably the indices output of ``max_pool2d_with_indices``),
    and all other outputs take input 0's element type.
    """

    def __init__(self, op_name):
        """Create a handler registered under the ATen op name ``op_name``."""
        super().__init__()
        self._op_type = op_name

    @property
    def op_type(self) -> str:
        return self._op_type

    def infer_shape(self, node, ctx) -> None:
        sympy_shape = ctx.get_sympy_shape(node, 0)
        assert len(sympy_shape) == 4
        # Iterate an ordered tuple (not a set literal) so the symbolic dims are
        # guaranteed to be created for axis 2 first, then axis 3.
        sympy_shape[-2:] = [ctx.new_symbolic_dim_from_output(node, 0, i) for i in (2, 3)]
        ctx.update_computed_dims(sympy_shape)
        for i, o in enumerate(node.output):
            if not o:
                # Optional output not present in the graph.
                continue
            vi = ctx.known_vi_[o]
            elem_type = onnx.TensorProto.INT64 if i == 1 else ctx.known_vi_[node.input[0]].type.tensor_type.elem_type
            vi.CopyFrom(helper.make_tensor_value_info(o, elem_type, get_shape_from_sympy_shape(sympy_shape)))


register_aten_handler(AtenPool2dHandler("max_pool2d_with_indices"))
register_aten_handler(AtenPool2dHandler("avg_pool2d"))
register_aten_handler(AtenPool2dHandler("_adaptive_avg_pool2d"))
@@ -0,0 +1,44 @@
1
+ # Copyright (c) Microsoft Corporation. All rights reserved.
2
+ # Licensed under the MIT License.
3
+
4
+ """Shape handler for ATen unfold operator."""
5
+
6
+ from onnx import helper
7
+
8
+ from ..base import ShapeHandler
9
+ from ..registry import register_aten_handler
10
+ from ..utils import get_shape_from_sympy_shape
11
+
12
+
13
class AtenUnfoldHandler(ShapeHandler):
    """Shape handler for the ATen ``unfold`` operator.

    When inputs 1, 2 and 3 (dimension, size, step) are all constant, the axis
    at ``dimension`` becomes ``(dim - size) // step + 1`` and a trailing axis of
    length ``size`` is appended. Otherwise the output gets a brand-new symbolic
    shape whose rank is one more than the input's.
    """

    @property
    def op_type(self) -> str:
        return "unfold"

    def infer_shape(self, node, ctx) -> None:
        shape = ctx.get_sympy_shape(node, 0)
        axis, window, stride = (ctx.try_get_value(node, idx) for idx in (1, 2, 3))

        if any(value is None for value in (axis, window, stride)):
            # Unknown parameters: fall back to a fully symbolic output shape.
            shape = ctx.new_symbolic_shape(len(shape) + 1, node)
        else:
            assert axis < len(shape)
            shape[axis] = (shape[axis] - window) // stride + 1
            shape.append(window)

        ctx.update_computed_dims(shape)

        out_name = node.output[0]
        if not out_name:
            return
        target = ctx.known_vi_[out_name]
        target.CopyFrom(
            helper.make_tensor_value_info(
                out_name,
                ctx.known_vi_[node.input[0]].type.tensor_type.elem_type,
                get_shape_from_sympy_shape(shape),
            )
        )


register_aten_handler(AtenUnfoldHandler())
@@ -0,0 +1,44 @@
1
+ # Copyright (c) Microsoft Corporation. All rights reserved.
2
+ # Licensed under the MIT License.
3
+
4
+ """Shape handler for ATen upsample operators."""
5
+
6
+ import numpy as np
7
+ from onnx import helper
8
+
9
+ from ..base import ShapeHandler
10
+ from ..registry import register_aten_handler
11
+
12
+
13
class AtenUpsampleHandler(ShapeHandler):
    """Handler for ATen upsample operators.

    Registered for ``upsample_nearest{1,2,3}d`` and ``upsample_bicubic2d``.
    The first two input dimensions are kept. The remaining dimensions come from
    the constant ``output_size`` (input 1) when it is known; otherwise fresh
    symbolic dimensions are created for them. If input 0 has no known shape,
    nothing is written.
    """

    def __init__(self, op_name):
        """Create a handler registered under the ATen op name ``op_name``."""
        super().__init__()
        self._op_type = op_name

    @property
    def op_type(self) -> str:
        return self._op_type

    def infer_shape(self, node, ctx) -> None:
        new_shape = None
        input_shape = ctx.get_shape(node, 0)
        if input_shape is not None:
            new_shape = input_shape[:2]
            output_size = ctx.try_get_value(node, 1)
            if output_size is not None:
                # onnx.helper only accepts Python int / str / None dims, so any
                # NumPy integer scalar (int64, int32, ...) must become a plain int.
                new_shape += [
                    dim_size.item() if isinstance(dim_size, np.integer) else dim_size for dim_size in output_size
                ]
            else:
                rank = len(input_shape)
                new_shape += [str(ctx.new_symbolic_dim_from_output(node, 0, i)) for i in range(2, rank)]
        if node.output[0] and new_shape is not None:
            output_dtype = ctx.known_vi_[node.input[0]].type.tensor_type.elem_type
            vi = ctx.known_vi_[node.output[0]]
            vi.CopyFrom(helper.make_tensor_value_info(node.output[0], output_dtype, new_shape))


register_aten_handler(AtenUpsampleHandler("upsample_nearest1d"))
register_aten_handler(AtenUpsampleHandler("upsample_nearest2d"))
register_aten_handler(AtenUpsampleHandler("upsample_nearest3d"))
register_aten_handler(AtenUpsampleHandler("upsample_bicubic2d"))
@@ -0,0 +1,111 @@
1
+ # Copyright (c) Microsoft Corporation. All rights reserved.
2
+ # Licensed under the MIT License.
3
+
4
+ """Base classes for shape inference handlers."""
5
+
6
+ from abc import ABC, abstractmethod
7
+
8
+
9
class ShapeHandler(ABC):
    """Base class for per-operator shape inference.

    A concrete subclass names the ONNX operator it handles through
    :attr:`op_type` and implements :meth:`infer_shape`. Every handler also
    carries an inclusive opset range ``[min_opset, max_opset]`` that
    :meth:`supports_opset` tests against.
    """

    def __init__(self, min_opset=1, max_opset=999):
        """Record the inclusive range of supported opset versions.

        Args:
            min_opset: Lowest ONNX opset version supported (inclusive).
            max_opset: Highest ONNX opset version supported (inclusive).
        """
        self.min_opset, self.max_opset = min_opset, max_opset

    @property
    @abstractmethod
    def op_type(self) -> str:
        """The ONNX operator type handled, e.g. ``"Reshape"`` or ``"MatMul"``."""
        raise NotImplementedError

    def supports_opset(self, opset: int) -> bool:
        """Tell whether ``opset`` lies inside ``[min_opset, max_opset]``.

        Args:
            opset: The ONNX opset version to test.

        Returns:
            True when both range bounds are satisfied.
        """
        return self.min_opset <= opset and opset <= self.max_opset

    @abstractmethod
    def infer_shape(self, node, ctx) -> None:
        """Compute and record the output shapes of ``node``.

        Args:
            node: The ONNX node whose outputs are inferred.
            ctx: The InferenceContext providing access to shape information.
        """
        raise NotImplementedError
56
+
57
+
58
class PassthroughHandler(ShapeHandler):
    """Handler whose output shape and type are copied from the input.

    The copy itself is performed by ``ctx.pass_on_shape_and_type``. It is meant
    for operators such as Identity, Reciprocal or Round.
    """

    def __init__(self, op_type_name, min_opset=1, max_opset=999):
        """Bind the handler to an operator name and opset range.

        Args:
            op_type_name: Name of the operator this handler is registered for.
            min_opset: Lowest ONNX opset version supported (inclusive).
            max_opset: Highest ONNX opset version supported (inclusive).
        """
        super().__init__(min_opset=min_opset, max_opset=max_opset)
        self._op_type = op_type_name

    @property
    def op_type(self) -> str:
        """The operator name supplied at construction time."""
        return self._op_type

    def infer_shape(self, node, ctx) -> None:
        """Forward the input shape and element type to the output unchanged."""
        ctx.pass_on_shape_and_type(node)
83
+
84
+
85
class MultiOpHandler(ShapeHandler):
    """Handler that delegates inference to an arbitrary callable.

    Several operator types can share one inference function by each getting
    their own ``MultiOpHandler`` that wraps the same callable.
    """

    def __init__(self, op_type_name, handler_func, min_opset=1, max_opset=999):
        """Bind an operator name and opset range to a shape-inference callable.

        Args:
            op_type_name: Name of the operator this handler is registered for.
            handler_func: Callable invoked as ``handler_func(node, ctx)``.
            min_opset: Lowest ONNX opset version supported (inclusive).
            max_opset: Highest ONNX opset version supported (inclusive).
        """
        super().__init__(min_opset=min_opset, max_opset=max_opset)
        self._op_type = op_type_name
        self._handler_func = handler_func

    @property
    def op_type(self) -> str:
        """The operator name supplied at construction time."""
        return self._op_type

    def infer_shape(self, node, ctx) -> None:
        """Run the wrapped callable on ``node`` and ``ctx``."""
        delegate = self._handler_func
        delegate(node, ctx)