onnxslim 0.1.82__py3-none-any.whl → 0.1.84__py3-none-any.whl

This diff compares the contents of two publicly released versions of this package, as published to a supported registry. It is provided for informational purposes only and reflects the versions exactly as they appear in the public registry.
Files changed (141)
  1. onnxslim/core/optimization/dead_node_elimination.py +85 -4
  2. onnxslim/core/pattern/elimination/slice.py +15 -8
  3. onnxslim/core/pattern/fusion/concat_reshape.py +3 -1
  4. onnxslim/core/pattern/fusion/convadd.py +23 -7
  5. onnxslim/core/pattern/fusion/convbn.py +24 -11
  6. onnxslim/core/pattern/fusion/convmul.py +26 -9
  7. onnxslim/core/pattern/fusion/gemm.py +7 -5
  8. onnxslim/core/pattern/fusion/padconv.py +5 -0
  9. onnxslim/core/shape_inference/__init__.py +378 -0
  10. onnxslim/core/shape_inference/aten_ops/__init__.py +16 -0
  11. onnxslim/core/shape_inference/aten_ops/argmax.py +47 -0
  12. onnxslim/core/shape_inference/aten_ops/bitwise_or.py +28 -0
  13. onnxslim/core/shape_inference/aten_ops/diagonal.py +52 -0
  14. onnxslim/core/shape_inference/aten_ops/embedding.py +23 -0
  15. onnxslim/core/shape_inference/aten_ops/group_norm.py +41 -0
  16. onnxslim/core/shape_inference/aten_ops/min_max.py +64 -0
  17. onnxslim/core/shape_inference/aten_ops/multinomial.py +39 -0
  18. onnxslim/core/shape_inference/aten_ops/numpy_t.py +22 -0
  19. onnxslim/core/shape_inference/aten_ops/pool2d.py +40 -0
  20. onnxslim/core/shape_inference/aten_ops/unfold.py +44 -0
  21. onnxslim/core/shape_inference/aten_ops/upsample.py +44 -0
  22. onnxslim/core/shape_inference/base.py +111 -0
  23. onnxslim/core/shape_inference/context.py +645 -0
  24. onnxslim/core/shape_inference/contrib_ops/__init__.py +8 -0
  25. onnxslim/core/shape_inference/contrib_ops/attention/__init__.py +15 -0
  26. onnxslim/core/shape_inference/contrib_ops/attention/attention.py +61 -0
  27. onnxslim/core/shape_inference/contrib_ops/attention/decoder_masked_mha.py +37 -0
  28. onnxslim/core/shape_inference/contrib_ops/attention/gated_relative_position_bias.py +35 -0
  29. onnxslim/core/shape_inference/contrib_ops/attention/longformer_attention.py +21 -0
  30. onnxslim/core/shape_inference/contrib_ops/attention/multi_head_attention.py +82 -0
  31. onnxslim/core/shape_inference/contrib_ops/attention/multi_scale_deformable_attn.py +29 -0
  32. onnxslim/core/shape_inference/contrib_ops/attention/packed_attention.py +39 -0
  33. onnxslim/core/shape_inference/contrib_ops/attention/packed_multi_head_attention.py +33 -0
  34. onnxslim/core/shape_inference/contrib_ops/attention/remove_padding.py +41 -0
  35. onnxslim/core/shape_inference/contrib_ops/attention/restore_padding.py +29 -0
  36. onnxslim/core/shape_inference/contrib_ops/misc/__init__.py +15 -0
  37. onnxslim/core/shape_inference/contrib_ops/misc/bias_add.py +21 -0
  38. onnxslim/core/shape_inference/contrib_ops/misc/bias_gelu.py +21 -0
  39. onnxslim/core/shape_inference/contrib_ops/misc/bias_split_gelu.py +30 -0
  40. onnxslim/core/shape_inference/contrib_ops/misc/fast_gelu.py +21 -0
  41. onnxslim/core/shape_inference/contrib_ops/misc/gelu.py +21 -0
  42. onnxslim/core/shape_inference/contrib_ops/misc/gemm_fast_gelu.py +21 -0
  43. onnxslim/core/shape_inference/contrib_ops/misc/gemm_float8.py +21 -0
  44. onnxslim/core/shape_inference/contrib_ops/misc/python_op.py +67 -0
  45. onnxslim/core/shape_inference/contrib_ops/misc/quick_gelu.py +21 -0
  46. onnxslim/core/shape_inference/contrib_ops/misc/rotary_embedding.py +31 -0
  47. onnxslim/core/shape_inference/contrib_ops/normalization/__init__.py +12 -0
  48. onnxslim/core/shape_inference/contrib_ops/normalization/embed_layer_normalization.py +41 -0
  49. onnxslim/core/shape_inference/contrib_ops/normalization/group_norm.py +21 -0
  50. onnxslim/core/shape_inference/contrib_ops/normalization/layer_normalization.py +42 -0
  51. onnxslim/core/shape_inference/contrib_ops/normalization/simplified_layer_normalization.py +23 -0
  52. onnxslim/core/shape_inference/contrib_ops/normalization/skip_group_norm.py +23 -0
  53. onnxslim/core/shape_inference/contrib_ops/normalization/skip_layer_normalization.py +26 -0
  54. onnxslim/core/shape_inference/contrib_ops/normalization/skip_simplified_layer_normalization.py +23 -0
  55. onnxslim/core/shape_inference/registry.py +90 -0
  56. onnxslim/core/shape_inference/standard_ops/__init__.py +11 -0
  57. onnxslim/core/shape_inference/standard_ops/control_flow/__init__.py +8 -0
  58. onnxslim/core/shape_inference/standard_ops/control_flow/if_op.py +43 -0
  59. onnxslim/core/shape_inference/standard_ops/control_flow/loop.py +74 -0
  60. onnxslim/core/shape_inference/standard_ops/control_flow/scan.py +54 -0
  61. onnxslim/core/shape_inference/standard_ops/math/__init__.py +20 -0
  62. onnxslim/core/shape_inference/standard_ops/math/_symbolic_compute.py +34 -0
  63. onnxslim/core/shape_inference/standard_ops/math/add.py +10 -0
  64. onnxslim/core/shape_inference/standard_ops/math/div.py +10 -0
  65. onnxslim/core/shape_inference/standard_ops/math/einsum.py +119 -0
  66. onnxslim/core/shape_inference/standard_ops/math/equal.py +10 -0
  67. onnxslim/core/shape_inference/standard_ops/math/floor.py +10 -0
  68. onnxslim/core/shape_inference/standard_ops/math/matmul.py +21 -0
  69. onnxslim/core/shape_inference/standard_ops/math/matmul_integer.py +23 -0
  70. onnxslim/core/shape_inference/standard_ops/math/max.py +10 -0
  71. onnxslim/core/shape_inference/standard_ops/math/min.py +10 -0
  72. onnxslim/core/shape_inference/standard_ops/math/mul.py +10 -0
  73. onnxslim/core/shape_inference/standard_ops/math/neg.py +10 -0
  74. onnxslim/core/shape_inference/standard_ops/math/reduce_prod.py +27 -0
  75. onnxslim/core/shape_inference/standard_ops/math/reduce_sum.py +53 -0
  76. onnxslim/core/shape_inference/standard_ops/math/sub.py +10 -0
  77. onnxslim/core/shape_inference/standard_ops/math/where.py +10 -0
  78. onnxslim/core/shape_inference/standard_ops/misc/__init__.py +22 -0
  79. onnxslim/core/shape_inference/standard_ops/misc/array_feature_extractor.py +32 -0
  80. onnxslim/core/shape_inference/standard_ops/misc/cast.py +21 -0
  81. onnxslim/core/shape_inference/standard_ops/misc/category_mapper.py +30 -0
  82. onnxslim/core/shape_inference/standard_ops/misc/compress.py +39 -0
  83. onnxslim/core/shape_inference/standard_ops/misc/constant.py +27 -0
  84. onnxslim/core/shape_inference/standard_ops/misc/constant_of_shape.py +45 -0
  85. onnxslim/core/shape_inference/standard_ops/misc/dequantize_linear.py +26 -0
  86. onnxslim/core/shape_inference/standard_ops/misc/non_max_suppression.py +26 -0
  87. onnxslim/core/shape_inference/standard_ops/misc/non_zero.py +26 -0
  88. onnxslim/core/shape_inference/standard_ops/misc/one_hot.py +42 -0
  89. onnxslim/core/shape_inference/standard_ops/misc/quantize_linear.py +29 -0
  90. onnxslim/core/shape_inference/standard_ops/misc/range.py +41 -0
  91. onnxslim/core/shape_inference/standard_ops/misc/relative_position_bias.py +31 -0
  92. onnxslim/core/shape_inference/standard_ops/misc/resize.py +74 -0
  93. onnxslim/core/shape_inference/standard_ops/misc/scatter_elements.py +31 -0
  94. onnxslim/core/shape_inference/standard_ops/misc/softmax_cross_entropy_loss.py +44 -0
  95. onnxslim/core/shape_inference/standard_ops/misc/top_k.py +44 -0
  96. onnxslim/core/shape_inference/standard_ops/nn/__init__.py +18 -0
  97. onnxslim/core/shape_inference/standard_ops/nn/all_reduce.py +9 -0
  98. onnxslim/core/shape_inference/standard_ops/nn/average_pool.py +40 -0
  99. onnxslim/core/shape_inference/standard_ops/nn/batch_normalization.py +26 -0
  100. onnxslim/core/shape_inference/standard_ops/nn/conv.py +33 -0
  101. onnxslim/core/shape_inference/standard_ops/nn/cum_sum.py +9 -0
  102. onnxslim/core/shape_inference/standard_ops/nn/identity.py +9 -0
  103. onnxslim/core/shape_inference/standard_ops/nn/max_pool.py +9 -0
  104. onnxslim/core/shape_inference/standard_ops/nn/memcpy_from_host.py +9 -0
  105. onnxslim/core/shape_inference/standard_ops/nn/memcpy_to_host.py +9 -0
  106. onnxslim/core/shape_inference/standard_ops/nn/moe.py +9 -0
  107. onnxslim/core/shape_inference/standard_ops/nn/nhwc_conv.py +33 -0
  108. onnxslim/core/shape_inference/standard_ops/nn/reciprocal.py +9 -0
  109. onnxslim/core/shape_inference/standard_ops/nn/round.py +9 -0
  110. onnxslim/core/shape_inference/standard_ops/sequence/__init__.py +10 -0
  111. onnxslim/core/shape_inference/standard_ops/sequence/concat_from_sequence.py +40 -0
  112. onnxslim/core/shape_inference/standard_ops/sequence/sequence_at.py +31 -0
  113. onnxslim/core/shape_inference/standard_ops/sequence/sequence_insert.py +26 -0
  114. onnxslim/core/shape_inference/standard_ops/sequence/split_to_sequence.py +24 -0
  115. onnxslim/core/shape_inference/standard_ops/sequence/zip_map.py +36 -0
  116. onnxslim/core/shape_inference/standard_ops/tensor/__init__.py +20 -0
  117. onnxslim/core/shape_inference/standard_ops/tensor/concat.py +62 -0
  118. onnxslim/core/shape_inference/standard_ops/tensor/expand.py +36 -0
  119. onnxslim/core/shape_inference/standard_ops/tensor/gather.py +48 -0
  120. onnxslim/core/shape_inference/standard_ops/tensor/gather_elements.py +31 -0
  121. onnxslim/core/shape_inference/standard_ops/tensor/gather_nd.py +42 -0
  122. onnxslim/core/shape_inference/standard_ops/tensor/pad.py +41 -0
  123. onnxslim/core/shape_inference/standard_ops/tensor/reshape.py +72 -0
  124. onnxslim/core/shape_inference/standard_ops/tensor/shape.py +38 -0
  125. onnxslim/core/shape_inference/standard_ops/tensor/size.py +29 -0
  126. onnxslim/core/shape_inference/standard_ops/tensor/slice.py +183 -0
  127. onnxslim/core/shape_inference/standard_ops/tensor/split.py +57 -0
  128. onnxslim/core/shape_inference/standard_ops/tensor/squeeze.py +69 -0
  129. onnxslim/core/shape_inference/standard_ops/tensor/tile.py +41 -0
  130. onnxslim/core/shape_inference/standard_ops/tensor/transpose.py +30 -0
  131. onnxslim/core/shape_inference/standard_ops/tensor/unsqueeze.py +54 -0
  132. onnxslim/core/shape_inference/utils.py +244 -0
  133. onnxslim/third_party/onnx_graphsurgeon/ir/graph.py +0 -103
  134. onnxslim/third_party/symbolic_shape_infer.py +73 -3156
  135. onnxslim/utils.py +4 -2
  136. {onnxslim-0.1.82.dist-info → onnxslim-0.1.84.dist-info}/METADATA +21 -11
  137. onnxslim-0.1.84.dist-info/RECORD +187 -0
  138. onnxslim-0.1.82.dist-info/RECORD +0 -63
  139. {onnxslim-0.1.82.dist-info → onnxslim-0.1.84.dist-info}/WHEEL +0 -0
  140. {onnxslim-0.1.82.dist-info → onnxslim-0.1.84.dist-info}/entry_points.txt +0 -0
  141. {onnxslim-0.1.82.dist-info → onnxslim-0.1.84.dist-info}/licenses/LICENSE +0 -0
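
The headline change in this release is the new onnxslim/core/shape_inference/ package, which replaces nearly all of the vendored onnxslim/third_party/symbolic_shape_infer.py (3,156 lines removed) with one small module per operator plus a shared base.py and registry.py. Every hunk below follows the same pattern: subclass ShapeHandler, declare an op_type, implement infer_shape(node, ctx), and register an instance at import time. A minimal sketch of that pattern with a hypothetical "MyOp" operator; the absolute import paths are assumed from the file layout above, since the handlers themselves use relative imports:

# Sketch only: "MyOp" is hypothetical, and the absolute import paths are
# assumed from the file layout above (the real handlers use relative
# imports such as `from ...base import ShapeHandler`).
from onnxslim.core.shape_inference.base import ShapeHandler
from onnxslim.core.shape_inference.registry import register_shape_handler


class MyOpHandler(ShapeHandler):
    """Handler for the hypothetical MyOp operator."""

    @property
    def op_type(self) -> str:
        return "MyOp"

    def infer_shape(self, node, ctx) -> None:
        # For a shape-preserving elementwise op, the output simply inherits
        # the input's shape and dtype, as in the Gelu-family handlers below.
        ctx.propagate_shape_and_type(node)


# Importing the module registers the handler as a side effect.
register_shape_handler(MyOpHandler())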
onnxslim/core/shape_inference/contrib_ops/attention/attention.py
@@ -0,0 +1,61 @@
+ # Copyright (c) Microsoft Corporation. All rights reserved.
+ # Licensed under the MIT License.
+
+ """Shape handler for Attention operator."""
+
+ from onnx import helper
+
+ from ...base import ShapeHandler
+ from ...registry import register_shape_handler
+ from ...utils import get_attribute
+
+
+ class AttentionHandler(ShapeHandler):
+     """Handler for Attention operator."""
+
+     @property
+     def op_type(self) -> str:
+         return "Attention"
+
+     def infer_shape(self, node, ctx) -> None:
+         shape = ctx.get_shape(node, 0)
+         shape_weights = ctx.get_shape(node, 1)
+         shape_bias = ctx.try_get_shape(node, 2)
+         if shape_bias is not None:
+             assert len(shape_bias) == 1
+         tripled_hidden_size = shape_bias[0] if shape_bias is not None else shape_weights[1]
+         if shape and len(shape) == 3:
+             qkv_hidden_sizes_attr = get_attribute(node, "qkv_hidden_sizes")
+             if qkv_hidden_sizes_attr is not None:
+                 assert len(qkv_hidden_sizes_attr) == 3
+                 shape[2] = int(qkv_hidden_sizes_attr[2])
+             elif isinstance(tripled_hidden_size, int):
+                 shape[2] = int(tripled_hidden_size / 3)
+             output_dtype = ctx.known_vi_[node.input[0]].type.tensor_type.elem_type
+             vi = ctx.known_vi_[node.output[0]]
+             vi.CopyFrom(helper.make_tensor_value_info(node.output[0], output_dtype, shape))
+
+             if len(node.output) > 1:
+                 input_shape = ctx.get_shape(node, 0)
+                 past_shape = ctx.get_shape(node, 4) if len(node.input) > 4 and node.input[4] else []
+                 mask_shape = ctx.get_shape(node, 3) if len(node.input) > 3 and node.input[3] else []
+
+                 if past_shape and len(past_shape) == 5:
+                     if mask_shape and len(mask_shape) in {2, 3}:
+                         past_shape[3] = mask_shape[-1]
+                     elif input_shape and len(input_shape) == 3:
+                         if isinstance(input_shape[1], int) and isinstance(past_shape[3], int):
+                             past_shape[3] = input_shape[1] + past_shape[3]
+                         else:
+                             past_shape[3] = f"{past_shape[3]}+{input_shape[1]}"
+                     vi = ctx.known_vi_[node.output[1]]
+                     vi.CopyFrom(helper.make_tensor_value_info(vi.name, output_dtype, past_shape))
+                 else:
+                     num_heads = get_attribute(node, "num_heads")
+                     head_size = input_shape[2] // num_heads
+                     present_shape = [2, input_shape[0], num_heads, input_shape[1], head_size]
+                     vi = ctx.known_vi_[node.output[1]]
+                     vi.CopyFrom(helper.make_tensor_value_info(vi.name, output_dtype, present_shape))
+
+
+ register_shape_handler(AttentionHandler())
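
For the no-past branch at the end of attention.py above, the present-state shape is plain arithmetic on the input shape and the num_heads attribute. A quick worked example with hypothetical dimensions:

# Hypothetical dimensions, for illustration only.
input_shape = [2, 8, 768]   # [batch_size, sequence_length, hidden_size]
num_heads = 12              # the node's "num_heads" attribute

head_size = input_shape[2] // num_heads                                    # 768 // 12 = 64
present_shape = [2, input_shape[0], num_heads, input_shape[1], head_size]
assert present_shape == [2, 2, 12, 8, 64]   # [2, batch, heads, seq, head_size]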
onnxslim/core/shape_inference/contrib_ops/attention/decoder_masked_mha.py
@@ -0,0 +1,37 @@
+ # Copyright (c) Microsoft Corporation. All rights reserved.
+ # Licensed under the MIT License.
+
+ """Shape handler for DecoderMaskedMultiHeadAttention operator."""
+
+ from onnx import helper
+
+ from ...base import ShapeHandler
+ from ...registry import register_shape_handler
+
+
+ class DecoderMaskedMultiHeadAttentionHandler(ShapeHandler):
+     """Handler for DecoderMaskedMultiHeadAttention operator."""
+
+     @property
+     def op_type(self) -> str:
+         return "DecoderMaskedMultiHeadAttention"
+
+     def infer_shape(self, node, ctx) -> None:
+         query_shape = ctx.get_shape(node, 0)
+         if query_shape is not None:
+             output_shape = query_shape
+             output_dtype = ctx.known_vi_[node.input[0]].type.tensor_type.elem_type
+             assert output_dtype is not None
+             vi = ctx.known_vi_[node.output[0]]
+             vi.CopyFrom(helper.make_tensor_value_info(node.output[0], output_dtype, output_shape))
+
+             if len(node.output) > 2 and node.output[1] and node.output[2]:
+                 past_shape = ctx.try_get_shape(node, 5)
+                 if past_shape is not None:
+                     vi = ctx.known_vi_[node.output[1]]
+                     vi.CopyFrom(helper.make_tensor_value_info(vi.name, output_dtype, past_shape))
+                     vi = ctx.known_vi_[node.output[2]]
+                     vi.CopyFrom(helper.make_tensor_value_info(vi.name, output_dtype, past_shape))
+
+
+ register_shape_handler(DecoderMaskedMultiHeadAttentionHandler())
onnxslim/core/shape_inference/contrib_ops/attention/gated_relative_position_bias.py
@@ -0,0 +1,35 @@
+ # Copyright (c) Microsoft Corporation. All rights reserved.
+ # Licensed under the MIT License.
+
+ """Shape handler for GatedRelativePositionBias operator."""
+
+ from onnx import helper
+
+ from ...base import ShapeHandler
+ from ...registry import register_shape_handler
+ from ...utils import get_attribute
+
+
+ class GatedRelativePositionBiasHandler(ShapeHandler):
+     """Handler for GatedRelativePositionBias operator."""
+
+     @property
+     def op_type(self) -> str:
+         return "GatedRelativePositionBias"
+
+     def infer_shape(self, node, ctx) -> None:
+         num_heads = get_attribute(node, "num_heads")
+         token_offset_shape = ctx.try_get_shape(node, 6)
+         if token_offset_shape is not None:
+             output_shape = [token_offset_shape[0], num_heads, token_offset_shape[1], token_offset_shape[1]]
+         else:
+             query_layer_shape = ctx.get_shape(node, 0)
+             assert query_layer_shape is not None and len(query_layer_shape) == 3
+             output_shape = [query_layer_shape[0], num_heads, query_layer_shape[1], query_layer_shape[1]]
+
+         output_dtype = ctx.known_vi_[node.input[0]].type.tensor_type.elem_type
+         vi = ctx.known_vi_[node.output[0]]
+         vi.CopyFrom(helper.make_tensor_value_info(node.output[0], output_dtype, output_shape))
+
+
+ register_shape_handler(GatedRelativePositionBiasHandler())
onnxslim/core/shape_inference/contrib_ops/attention/longformer_attention.py
@@ -0,0 +1,21 @@
+ # Copyright (c) Microsoft Corporation. All rights reserved.
+ # Licensed under the MIT License.
+
+ """Shape handler for LongformerAttention operator."""
+
+ from ...base import ShapeHandler
+ from ...registry import register_shape_handler
+
+
+ class LongformerAttentionHandler(ShapeHandler):
+     """Handler for LongformerAttention operator."""
+
+     @property
+     def op_type(self) -> str:
+         return "LongformerAttention"
+
+     def infer_shape(self, node, ctx) -> None:
+         ctx.propagate_shape_and_type(node)
+
+
+ register_shape_handler(LongformerAttentionHandler())
onnxslim/core/shape_inference/contrib_ops/attention/multi_head_attention.py
@@ -0,0 +1,82 @@
+ # Copyright (c) Microsoft Corporation. All rights reserved.
+ # Licensed under the MIT License.
+
+ """Shape handler for MultiHeadAttention operator."""
+
+ from onnx import helper
+
+ from ...base import ShapeHandler
+ from ...registry import register_shape_handler
+ from ...utils import get_attribute
+
+
+ class MultiHeadAttentionHandler(ShapeHandler):
+     """Handler for MultiHeadAttention operator."""
+
+     @property
+     def op_type(self) -> str:
+         return "MultiHeadAttention"
+
+     def infer_shape(self, node, ctx) -> None:
+         query_shape = ctx.get_shape(node, 0)
+         total_sequence_length = None
+         output_dtype = None
+         if query_shape is not None:
+             if len(query_shape) == 3:
+                 key_shape = ctx.try_get_shape(node, 1)
+                 output_shape = query_shape
+                 if key_shape is not None and len(key_shape) == 3:
+                     value_shape = ctx.try_get_shape(node, 2)
+                     if value_shape is not None and len(value_shape) == 3:
+                         output_shape[2] = value_shape[2]
+                     total_sequence_length = key_shape[1]
+
+                 output_dtype = ctx.known_vi_[node.input[0]].type.tensor_type.elem_type
+                 vi = ctx.known_vi_[node.output[0]]
+                 vi.CopyFrom(helper.make_tensor_value_info(node.output[0], output_dtype, output_shape))
+
+             elif len(query_shape) == 5:
+                 if isinstance(query_shape[2], int) and isinstance(query_shape[4], int):
+                     output_shape = [query_shape[0], query_shape[1], query_shape[2] * query_shape[4]]
+                 else:
+                     output_shape = [query_shape[0], query_shape[1], f"{query_shape[2]}*{query_shape[4]}"]
+
+                 total_sequence_length = query_shape[1]
+
+                 output_dtype = ctx.known_vi_[node.input[0]].type.tensor_type.elem_type
+                 vi = ctx.known_vi_[node.output[0]]
+                 vi.CopyFrom(helper.make_tensor_value_info(node.output[0], output_dtype, output_shape))
+
+             if len(node.output) > 1:
+                 batch_size = query_shape[0]
+                 num_heads = get_attribute(node, "num_heads")
+
+                 head_size = None
+                 if len(query_shape) == 3:
+                     head_size = (
+                         int(query_shape[2] / num_heads)
+                         if isinstance(query_shape[2], int)
+                         else f"{query_shape[2]}/{num_heads}"
+                     )
+                 else:
+                     head_size = query_shape[4]
+
+                 past_shape = ctx.try_get_shape(node, 6)
+
+                 if past_shape is not None:
+                     if isinstance(past_shape[2], int) and isinstance(total_sequence_length, int):
+                         total_sequence_length = past_shape[2] + total_sequence_length
+                     else:
+                         total_sequence_length = f"{past_shape[2]}+{total_sequence_length}"
+
+                 present_shape = [batch_size, num_heads, total_sequence_length, head_size]
+
+                 assert output_dtype is not None
+                 if len(node.output) > 2 and node.output[1] and node.output[2]:
+                     vi = ctx.known_vi_[node.output[1]]
+                     vi.CopyFrom(helper.make_tensor_value_info(vi.name, output_dtype, present_shape))
+                     vi = ctx.known_vi_[node.output[2]]
+                     vi.CopyFrom(helper.make_tensor_value_info(vi.name, output_dtype, present_shape))
+
+
+ register_shape_handler(MultiHeadAttentionHandler())
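
The multi_head_attention.py handler above mixes concrete and symbolic dimensions when accumulating total_sequence_length: integer dimensions are added, anything else falls back to a string expression. A short worked example with hypothetical dimensions, where "past_seq" stands in for a symbolic dimension:

# Hypothetical dimensions, for illustration only.
query_shape = [4, 16, 1024]          # 3D query: [batch, seq_len, hidden]
num_heads = 16
total_sequence_length = 16           # key_shape[1] when the key is also 3D

past_shape = [4, 16, "past_seq", 64] # past KV cache with a symbolic length
if isinstance(past_shape[2], int) and isinstance(total_sequence_length, int):
    total_sequence_length = past_shape[2] + total_sequence_length
else:
    # Symbolic dimensions become a string expression, as in the handler.
    total_sequence_length = f"{past_shape[2]}+{total_sequence_length}"

head_size = query_shape[2] // num_heads          # 1024 // 16 = 64
present_shape = [query_shape[0], num_heads, total_sequence_length, head_size]
assert present_shape == [4, 16, "past_seq+16", 64]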
onnxslim/core/shape_inference/contrib_ops/attention/multi_scale_deformable_attn.py
@@ -0,0 +1,29 @@
+ # Copyright (c) Microsoft Corporation. All rights reserved.
+ # Licensed under the MIT License.
+
+ """Shape handler for MultiScaleDeformableAttnTRT operator."""
+
+ from onnx import helper
+
+ from ...base import ShapeHandler
+ from ...registry import register_shape_handler
+
+
+ class MultiScaleDeformableAttnTRTHandler(ShapeHandler):
+     """Handler for MultiScaleDeformableAttnTRT operator."""
+
+     @property
+     def op_type(self) -> str:
+         return "MultiScaleDeformableAttnTRT"
+
+     def infer_shape(self, node, ctx) -> None:
+         shape_value = ctx.try_get_shape(node, 0)
+         sampling_locations = ctx.try_get_shape(node, 3)
+         output_shape = shape_value
+         output_shape[1] = sampling_locations[1]
+         output_dtype = ctx.known_vi_[node.input[0]].type.tensor_type.elem_type
+         vi = ctx.known_vi_[node.output[0]]
+         vi.CopyFrom(helper.make_tensor_value_info(node.output[0], output_dtype, output_shape))
+
+
+ register_shape_handler(MultiScaleDeformableAttnTRTHandler())
onnxslim/core/shape_inference/contrib_ops/attention/packed_attention.py
@@ -0,0 +1,39 @@
+ # Copyright (c) Microsoft Corporation. All rights reserved.
+ # Licensed under the MIT License.
+
+ """Shape handler for PackedAttention operator."""
+
+ from onnx import helper
+
+ from ...base import ShapeHandler
+ from ...registry import register_shape_handler
+ from ...utils import get_attribute
+
+
+ class PackedAttentionHandler(ShapeHandler):
+     """Handler for PackedAttention operator."""
+
+     @property
+     def op_type(self) -> str:
+         return "PackedAttention"
+
+     def infer_shape(self, node, ctx) -> None:
+         shape = ctx.get_shape(node, 0)
+         shape_weights = ctx.get_shape(node, 1)
+         shape_bias = ctx.try_get_shape(node, 2)
+         if shape_bias is not None:
+             assert len(shape_bias) == 1
+         tripled_hidden_size = shape_bias[0] if shape_bias is not None else shape_weights[1]
+         if shape and len(shape) == 2:
+             qkv_hidden_sizes_attr = get_attribute(node, "qkv_hidden_sizes")
+             if qkv_hidden_sizes_attr is not None:
+                 assert len(qkv_hidden_sizes_attr) == 3
+                 shape[1] = int(qkv_hidden_sizes_attr[2])
+             elif isinstance(tripled_hidden_size, int):
+                 shape[1] = int(tripled_hidden_size / 3)
+             output_dtype = ctx.known_vi_[node.input[0]].type.tensor_type.elem_type
+             vi = ctx.known_vi_[node.output[0]]
+             vi.CopyFrom(helper.make_tensor_value_info(node.output[0], output_dtype, shape))
+
+
+ register_shape_handler(PackedAttentionHandler())
onnxslim/core/shape_inference/contrib_ops/attention/packed_multi_head_attention.py
@@ -0,0 +1,33 @@
+ # Copyright (c) Microsoft Corporation. All rights reserved.
+ # Licensed under the MIT License.
+
+ """Shape handler for PackedMultiHeadAttention operator."""
+
+ from onnx import helper
+
+ from ...base import ShapeHandler
+ from ...registry import register_shape_handler
+
+
+ class PackedMultiHeadAttentionHandler(ShapeHandler):
+     """Handler for PackedMultiHeadAttention operator."""
+
+     @property
+     def op_type(self) -> str:
+         return "PackedMultiHeadAttention"
+
+     def infer_shape(self, node, ctx) -> None:
+         shape_value = ctx.try_get_shape(node, 2)
+         if shape_value is not None and len(shape_value) == 2:
+             output_shape = shape_value
+         else:
+             shape_query = ctx.get_shape(node, 0)
+             assert shape_query is not None and len(shape_query) == 4
+             output_shape = [shape_query[0], shape_query[1] * shape_query[3]]
+
+         output_dtype = ctx.known_vi_[node.input[0]].type.tensor_type.elem_type
+         vi = ctx.known_vi_[node.output[0]]
+         vi.CopyFrom(helper.make_tensor_value_info(node.output[0], output_dtype, output_shape))
+
+
+ register_shape_handler(PackedMultiHeadAttentionHandler())
onnxslim/core/shape_inference/contrib_ops/attention/remove_padding.py
@@ -0,0 +1,41 @@
+ # Copyright (c) Microsoft Corporation. All rights reserved.
+ # Licensed under the MIT License.
+
+ """Shape handler for RemovePadding operator."""
+
+ import onnx
+ from onnx import helper
+
+ from ...base import ShapeHandler
+ from ...registry import register_shape_handler
+
+
+ class RemovePaddingHandler(ShapeHandler):
+     """Handler for RemovePadding operator."""
+
+     @property
+     def op_type(self) -> str:
+         return "RemovePadding"
+
+     def infer_shape(self, node, ctx) -> None:
+         shape = ctx.get_shape(node, 0)
+         if shape and len(shape) == 3:
+             output_dtype = ctx.known_vi_[node.input[0]].type.tensor_type.elem_type
+             vi = ctx.known_vi_[node.output[0]]
+             vi.CopyFrom(helper.make_tensor_value_info(node.output[0], output_dtype, ["token_count", shape[2]]))
+
+             vi_token_offset = ctx.known_vi_[node.output[1]]
+             vi_token_offset.CopyFrom(
+                 helper.make_tensor_value_info(node.output[1], onnx.TensorProto.INT32, [shape[0], shape[1]])
+             )
+
+             vi_cumulated_seq_len = ctx.known_vi_[node.output[2]]
+             vi_cumulated_seq_len.CopyFrom(
+                 helper.make_tensor_value_info(node.output[2], onnx.TensorProto.INT32, ["batch_size + 1"])
+             )
+
+             vi_max_seq_len = ctx.known_vi_[node.output[3]]
+             vi_max_seq_len.CopyFrom(helper.make_tensor_value_info(node.output[3], onnx.TensorProto.INT32, [1]))
+
+
+ register_shape_handler(RemovePaddingHandler())
onnxslim/core/shape_inference/contrib_ops/attention/restore_padding.py
@@ -0,0 +1,29 @@
+ # Copyright (c) Microsoft Corporation. All rights reserved.
+ # Licensed under the MIT License.
+
+ """Shape handler for RestorePadding operator."""
+
+ from onnx import helper
+
+ from ...base import ShapeHandler
+ from ...registry import register_shape_handler
+
+
+ class RestorePaddingHandler(ShapeHandler):
+     """Handler for RestorePadding operator."""
+
+     @property
+     def op_type(self) -> str:
+         return "RestorePadding"
+
+     def infer_shape(self, node, ctx) -> None:
+         shape_input = ctx.get_shape(node, 0)
+         shape_token_offset = ctx.get_shape(node, 1)
+         if shape_input and len(shape_input) == 2 and shape_token_offset and len(shape_token_offset) == 2:
+             output_dtype = ctx.known_vi_[node.input[0]].type.tensor_type.elem_type
+             vi = ctx.known_vi_[node.output[0]]
+             output_shape = [shape_token_offset[0], shape_token_offset[1], shape_input[1]]
+             vi.CopyFrom(helper.make_tensor_value_info(node.output[0], output_dtype, output_shape))
+
+
+ register_shape_handler(RestorePaddingHandler())
onnxslim/core/shape_inference/contrib_ops/misc/__init__.py
@@ -0,0 +1,15 @@
+ # Copyright (c) Microsoft Corporation. All rights reserved.
+ # Licensed under the MIT License.
+
+ """Miscellaneous contrib operator shape handlers."""
+
+ from . import bias_gelu
+ from . import fast_gelu
+ from . import gelu
+ from . import quick_gelu
+ from . import gemm_fast_gelu
+ from . import gemm_float8
+ from . import bias_split_gelu
+ from . import bias_add
+ from . import rotary_embedding
+ from . import python_op
onnxslim/core/shape_inference/contrib_ops/misc/bias_add.py
@@ -0,0 +1,21 @@
+ # Copyright (c) Microsoft Corporation. All rights reserved.
+ # Licensed under the MIT License.
+
+ """Shape handler for BiasAdd operator."""
+
+ from ...base import ShapeHandler
+ from ...registry import register_shape_handler
+
+
+ class BiasAddHandler(ShapeHandler):
+     """Handler for BiasAdd operator."""
+
+     @property
+     def op_type(self) -> str:
+         return "BiasAdd"
+
+     def infer_shape(self, node, ctx) -> None:
+         ctx.propagate_shape_and_type(node)
+
+
+ register_shape_handler(BiasAddHandler())
onnxslim/core/shape_inference/contrib_ops/misc/bias_gelu.py
@@ -0,0 +1,21 @@
+ # Copyright (c) Microsoft Corporation. All rights reserved.
+ # Licensed under the MIT License.
+
+ """Shape handler for BiasGelu operator."""
+
+ from ...base import ShapeHandler
+ from ...registry import register_shape_handler
+
+
+ class BiasGeluHandler(ShapeHandler):
+     """Handler for BiasGelu operator."""
+
+     @property
+     def op_type(self) -> str:
+         return "BiasGelu"
+
+     def infer_shape(self, node, ctx) -> None:
+         ctx.propagate_shape_and_type(node)
+
+
+ register_shape_handler(BiasGeluHandler())
onnxslim/core/shape_inference/contrib_ops/misc/bias_split_gelu.py
@@ -0,0 +1,30 @@
+ # Copyright (c) Microsoft Corporation. All rights reserved.
+ # Licensed under the MIT License.
+
+ """Shape handler for BiasSplitGelu operator."""
+
+ from onnx import helper
+
+ from ...base import ShapeHandler
+ from ...registry import register_shape_handler
+
+
+ class BiasSplitGeluHandler(ShapeHandler):
+     """Handler for BiasSplitGelu operator."""
+
+     @property
+     def op_type(self) -> str:
+         return "BiasSplitGelu"
+
+     def infer_shape(self, node, ctx) -> None:
+         input_shape = ctx.get_shape(node, 0)
+         bias_shape = ctx.get_shape(node, 1)
+         if input_shape and bias_shape and isinstance(bias_shape[0], int):
+             output_shape = input_shape
+             output_shape[2] = int(bias_shape[0] / 2)
+             vi = ctx.known_vi_[node.output[0]]
+             output_dtype = ctx.known_vi_[node.input[0]].type.tensor_type.elem_type
+             vi.CopyFrom(helper.make_tensor_value_info(vi.name, output_dtype, output_shape))
+
+
+ register_shape_handler(BiasSplitGeluHandler())
onnxslim/core/shape_inference/contrib_ops/misc/fast_gelu.py
@@ -0,0 +1,21 @@
+ # Copyright (c) Microsoft Corporation. All rights reserved.
+ # Licensed under the MIT License.
+
+ """Shape handler for FastGelu operator."""
+
+ from ...base import ShapeHandler
+ from ...registry import register_shape_handler
+
+
+ class FastGeluHandler(ShapeHandler):
+     """Handler for FastGelu operator."""
+
+     @property
+     def op_type(self) -> str:
+         return "FastGelu"
+
+     def infer_shape(self, node, ctx) -> None:
+         ctx.propagate_shape_and_type(node)
+
+
+ register_shape_handler(FastGeluHandler())
onnxslim/core/shape_inference/contrib_ops/misc/gelu.py
@@ -0,0 +1,21 @@
+ # Copyright (c) Microsoft Corporation. All rights reserved.
+ # Licensed under the MIT License.
+
+ """Shape handler for Gelu operator."""
+
+ from ...base import ShapeHandler
+ from ...registry import register_shape_handler
+
+
+ class GeluHandler(ShapeHandler):
+     """Handler for Gelu operator."""
+
+     @property
+     def op_type(self) -> str:
+         return "Gelu"
+
+     def infer_shape(self, node, ctx) -> None:
+         ctx.propagate_shape_and_type(node)
+
+
+ register_shape_handler(GeluHandler())
onnxslim/core/shape_inference/contrib_ops/misc/gemm_fast_gelu.py
@@ -0,0 +1,21 @@
+ # Copyright (c) Microsoft Corporation. All rights reserved.
+ # Licensed under the MIT License.
+
+ """Shape handler for GemmFastGelu operator."""
+
+ from ...base import ShapeHandler
+ from ...registry import register_shape_handler
+
+
+ class GemmFastGeluHandler(ShapeHandler):
+     """Handler for GemmFastGelu operator."""
+
+     @property
+     def op_type(self) -> str:
+         return "GemmFastGelu"
+
+     def infer_shape(self, node, ctx) -> None:
+         ctx.compute_matmul_shape(node)
+
+
+ register_shape_handler(GemmFastGeluHandler())
onnxslim/core/shape_inference/contrib_ops/misc/gemm_float8.py
@@ -0,0 +1,21 @@
+ # Copyright (c) Microsoft Corporation. All rights reserved.
+ # Licensed under the MIT License.
+
+ """Shape handler for GemmFloat8 operator."""
+
+ from ...base import ShapeHandler
+ from ...registry import register_shape_handler
+
+
+ class GemmFloat8Handler(ShapeHandler):
+     """Handler for GemmFloat8 operator."""
+
+     @property
+     def op_type(self) -> str:
+         return "GemmFloat8"
+
+     def infer_shape(self, node, ctx) -> None:
+         ctx.compute_matmul_shape(node)
+
+
+ register_shape_handler(GemmFloat8Handler())
onnxslim/core/shape_inference/contrib_ops/misc/python_op.py
@@ -0,0 +1,67 @@
+ # Copyright (c) Microsoft Corporation. All rights reserved.
+ # Licensed under the MIT License.
+
+ """Shape handler for PythonOp operator."""
+
+ import onnx
+ from onnx import helper
+
+ from ...base import ShapeHandler
+ from ...registry import register_shape_handler
+ from ...utils import get_attribute, get_shape_from_sympy_shape
+
+
+ class PythonOpHandler(ShapeHandler):
+     """Handler for PythonOp operator."""
+
+     @property
+     def op_type(self) -> str:
+         return "PythonOp"
+
+     def infer_shape(self, node, ctx) -> None:
+         output_tensor_types = get_attribute(node, "output_tensor_types")
+         assert output_tensor_types, f"PythonOp '{node.name}' has no output_tensor_types attribute."
+         output_tensor_ranks = get_attribute(node, "output_tensor_ranks")
+         assert output_tensor_ranks, f"PythonOp '{node.name}' has no output_tensor_ranks attribute."
+
+         try:
+             from onnxruntime.capi._pybind_state import get_shape_inference_function
+
+             func_name = get_attribute(node, "func_name").decode()
+             shape_inferer = get_shape_inference_function(func_name)
+         except ImportError:
+             shape_inferer = None
+
+         # Set the context output separately.
+         # The first output is torch.autograd.Function's context.
+         vi = ctx.known_vi_[node.output[0]]
+         vi.CopyFrom(helper.make_tensor_value_info(node.output[0], onnx.TensorProto.INT64, []))
+
+         if shape_inferer is not None:
+             input_shapes = []
+             input_dtypes = []
+             for input_index in range(len(node.input)):
+                 shape = ctx.get_shape(node, input_index)
+                 input_shapes.append(shape)
+                 input_dtype = ctx.known_vi_[node.input[input_index]].type.tensor_type.elem_type
+                 input_dtypes.append(input_dtype)
+             output_shapes, output_dtypes = shape_inferer(node, input_shapes, input_dtypes)
+             assert len(output_shapes) == len(output_dtypes) == (len(node.output) - 1), (
+                 f"PythonOp '{func_name}' returned {len(output_shapes)} shapes and {len(output_dtypes)} dtypes, "
+                 f"but expected {len(node.output) - 1} outputs."
+             )
+             for i in range(len(node.output) - 1):
+                 output_index = i + 1
+                 vi = ctx.known_vi_[node.output[output_index]]
+                 vi.CopyFrom(helper.make_tensor_value_info(node.output[output_index], output_dtypes[i], output_shapes[i]))
+         else:
+             # General shape inference for PythonOp.
+             for i in range(len(node.output) - 1):
+                 vi = ctx.known_vi_[node.output[i + 1]]
+                 sympy_shape = ctx.new_symbolic_shape(output_tensor_ranks[i], node)
+                 shape = get_shape_from_sympy_shape(sympy_shape)
+                 value_info = helper.make_tensor_value_info(node.output[i + 1], output_tensor_types[i], shape)
+                 vi.CopyFrom(value_info)
+
+
+ register_shape_handler(PythonOpHandler())