onnxslim 0.1.82__py3-none-any.whl → 0.1.84__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (141)
  1. onnxslim/core/optimization/dead_node_elimination.py +85 -4
  2. onnxslim/core/pattern/elimination/slice.py +15 -8
  3. onnxslim/core/pattern/fusion/concat_reshape.py +3 -1
  4. onnxslim/core/pattern/fusion/convadd.py +23 -7
  5. onnxslim/core/pattern/fusion/convbn.py +24 -11
  6. onnxslim/core/pattern/fusion/convmul.py +26 -9
  7. onnxslim/core/pattern/fusion/gemm.py +7 -5
  8. onnxslim/core/pattern/fusion/padconv.py +5 -0
  9. onnxslim/core/shape_inference/__init__.py +378 -0
  10. onnxslim/core/shape_inference/aten_ops/__init__.py +16 -0
  11. onnxslim/core/shape_inference/aten_ops/argmax.py +47 -0
  12. onnxslim/core/shape_inference/aten_ops/bitwise_or.py +28 -0
  13. onnxslim/core/shape_inference/aten_ops/diagonal.py +52 -0
  14. onnxslim/core/shape_inference/aten_ops/embedding.py +23 -0
  15. onnxslim/core/shape_inference/aten_ops/group_norm.py +41 -0
  16. onnxslim/core/shape_inference/aten_ops/min_max.py +64 -0
  17. onnxslim/core/shape_inference/aten_ops/multinomial.py +39 -0
  18. onnxslim/core/shape_inference/aten_ops/numpy_t.py +22 -0
  19. onnxslim/core/shape_inference/aten_ops/pool2d.py +40 -0
  20. onnxslim/core/shape_inference/aten_ops/unfold.py +44 -0
  21. onnxslim/core/shape_inference/aten_ops/upsample.py +44 -0
  22. onnxslim/core/shape_inference/base.py +111 -0
  23. onnxslim/core/shape_inference/context.py +645 -0
  24. onnxslim/core/shape_inference/contrib_ops/__init__.py +8 -0
  25. onnxslim/core/shape_inference/contrib_ops/attention/__init__.py +15 -0
  26. onnxslim/core/shape_inference/contrib_ops/attention/attention.py +61 -0
  27. onnxslim/core/shape_inference/contrib_ops/attention/decoder_masked_mha.py +37 -0
  28. onnxslim/core/shape_inference/contrib_ops/attention/gated_relative_position_bias.py +35 -0
  29. onnxslim/core/shape_inference/contrib_ops/attention/longformer_attention.py +21 -0
  30. onnxslim/core/shape_inference/contrib_ops/attention/multi_head_attention.py +82 -0
  31. onnxslim/core/shape_inference/contrib_ops/attention/multi_scale_deformable_attn.py +29 -0
  32. onnxslim/core/shape_inference/contrib_ops/attention/packed_attention.py +39 -0
  33. onnxslim/core/shape_inference/contrib_ops/attention/packed_multi_head_attention.py +33 -0
  34. onnxslim/core/shape_inference/contrib_ops/attention/remove_padding.py +41 -0
  35. onnxslim/core/shape_inference/contrib_ops/attention/restore_padding.py +29 -0
  36. onnxslim/core/shape_inference/contrib_ops/misc/__init__.py +15 -0
  37. onnxslim/core/shape_inference/contrib_ops/misc/bias_add.py +21 -0
  38. onnxslim/core/shape_inference/contrib_ops/misc/bias_gelu.py +21 -0
  39. onnxslim/core/shape_inference/contrib_ops/misc/bias_split_gelu.py +30 -0
  40. onnxslim/core/shape_inference/contrib_ops/misc/fast_gelu.py +21 -0
  41. onnxslim/core/shape_inference/contrib_ops/misc/gelu.py +21 -0
  42. onnxslim/core/shape_inference/contrib_ops/misc/gemm_fast_gelu.py +21 -0
  43. onnxslim/core/shape_inference/contrib_ops/misc/gemm_float8.py +21 -0
  44. onnxslim/core/shape_inference/contrib_ops/misc/python_op.py +67 -0
  45. onnxslim/core/shape_inference/contrib_ops/misc/quick_gelu.py +21 -0
  46. onnxslim/core/shape_inference/contrib_ops/misc/rotary_embedding.py +31 -0
  47. onnxslim/core/shape_inference/contrib_ops/normalization/__init__.py +12 -0
  48. onnxslim/core/shape_inference/contrib_ops/normalization/embed_layer_normalization.py +41 -0
  49. onnxslim/core/shape_inference/contrib_ops/normalization/group_norm.py +21 -0
  50. onnxslim/core/shape_inference/contrib_ops/normalization/layer_normalization.py +42 -0
  51. onnxslim/core/shape_inference/contrib_ops/normalization/simplified_layer_normalization.py +23 -0
  52. onnxslim/core/shape_inference/contrib_ops/normalization/skip_group_norm.py +23 -0
  53. onnxslim/core/shape_inference/contrib_ops/normalization/skip_layer_normalization.py +26 -0
  54. onnxslim/core/shape_inference/contrib_ops/normalization/skip_simplified_layer_normalization.py +23 -0
  55. onnxslim/core/shape_inference/registry.py +90 -0
  56. onnxslim/core/shape_inference/standard_ops/__init__.py +11 -0
  57. onnxslim/core/shape_inference/standard_ops/control_flow/__init__.py +8 -0
  58. onnxslim/core/shape_inference/standard_ops/control_flow/if_op.py +43 -0
  59. onnxslim/core/shape_inference/standard_ops/control_flow/loop.py +74 -0
  60. onnxslim/core/shape_inference/standard_ops/control_flow/scan.py +54 -0
  61. onnxslim/core/shape_inference/standard_ops/math/__init__.py +20 -0
  62. onnxslim/core/shape_inference/standard_ops/math/_symbolic_compute.py +34 -0
  63. onnxslim/core/shape_inference/standard_ops/math/add.py +10 -0
  64. onnxslim/core/shape_inference/standard_ops/math/div.py +10 -0
  65. onnxslim/core/shape_inference/standard_ops/math/einsum.py +119 -0
  66. onnxslim/core/shape_inference/standard_ops/math/equal.py +10 -0
  67. onnxslim/core/shape_inference/standard_ops/math/floor.py +10 -0
  68. onnxslim/core/shape_inference/standard_ops/math/matmul.py +21 -0
  69. onnxslim/core/shape_inference/standard_ops/math/matmul_integer.py +23 -0
  70. onnxslim/core/shape_inference/standard_ops/math/max.py +10 -0
  71. onnxslim/core/shape_inference/standard_ops/math/min.py +10 -0
  72. onnxslim/core/shape_inference/standard_ops/math/mul.py +10 -0
  73. onnxslim/core/shape_inference/standard_ops/math/neg.py +10 -0
  74. onnxslim/core/shape_inference/standard_ops/math/reduce_prod.py +27 -0
  75. onnxslim/core/shape_inference/standard_ops/math/reduce_sum.py +53 -0
  76. onnxslim/core/shape_inference/standard_ops/math/sub.py +10 -0
  77. onnxslim/core/shape_inference/standard_ops/math/where.py +10 -0
  78. onnxslim/core/shape_inference/standard_ops/misc/__init__.py +22 -0
  79. onnxslim/core/shape_inference/standard_ops/misc/array_feature_extractor.py +32 -0
  80. onnxslim/core/shape_inference/standard_ops/misc/cast.py +21 -0
  81. onnxslim/core/shape_inference/standard_ops/misc/category_mapper.py +30 -0
  82. onnxslim/core/shape_inference/standard_ops/misc/compress.py +39 -0
  83. onnxslim/core/shape_inference/standard_ops/misc/constant.py +27 -0
  84. onnxslim/core/shape_inference/standard_ops/misc/constant_of_shape.py +45 -0
  85. onnxslim/core/shape_inference/standard_ops/misc/dequantize_linear.py +26 -0
  86. onnxslim/core/shape_inference/standard_ops/misc/non_max_suppression.py +26 -0
  87. onnxslim/core/shape_inference/standard_ops/misc/non_zero.py +26 -0
  88. onnxslim/core/shape_inference/standard_ops/misc/one_hot.py +42 -0
  89. onnxslim/core/shape_inference/standard_ops/misc/quantize_linear.py +29 -0
  90. onnxslim/core/shape_inference/standard_ops/misc/range.py +41 -0
  91. onnxslim/core/shape_inference/standard_ops/misc/relative_position_bias.py +31 -0
  92. onnxslim/core/shape_inference/standard_ops/misc/resize.py +74 -0
  93. onnxslim/core/shape_inference/standard_ops/misc/scatter_elements.py +31 -0
  94. onnxslim/core/shape_inference/standard_ops/misc/softmax_cross_entropy_loss.py +44 -0
  95. onnxslim/core/shape_inference/standard_ops/misc/top_k.py +44 -0
  96. onnxslim/core/shape_inference/standard_ops/nn/__init__.py +18 -0
  97. onnxslim/core/shape_inference/standard_ops/nn/all_reduce.py +9 -0
  98. onnxslim/core/shape_inference/standard_ops/nn/average_pool.py +40 -0
  99. onnxslim/core/shape_inference/standard_ops/nn/batch_normalization.py +26 -0
  100. onnxslim/core/shape_inference/standard_ops/nn/conv.py +33 -0
  101. onnxslim/core/shape_inference/standard_ops/nn/cum_sum.py +9 -0
  102. onnxslim/core/shape_inference/standard_ops/nn/identity.py +9 -0
  103. onnxslim/core/shape_inference/standard_ops/nn/max_pool.py +9 -0
  104. onnxslim/core/shape_inference/standard_ops/nn/memcpy_from_host.py +9 -0
  105. onnxslim/core/shape_inference/standard_ops/nn/memcpy_to_host.py +9 -0
  106. onnxslim/core/shape_inference/standard_ops/nn/moe.py +9 -0
  107. onnxslim/core/shape_inference/standard_ops/nn/nhwc_conv.py +33 -0
  108. onnxslim/core/shape_inference/standard_ops/nn/reciprocal.py +9 -0
  109. onnxslim/core/shape_inference/standard_ops/nn/round.py +9 -0
  110. onnxslim/core/shape_inference/standard_ops/sequence/__init__.py +10 -0
  111. onnxslim/core/shape_inference/standard_ops/sequence/concat_from_sequence.py +40 -0
  112. onnxslim/core/shape_inference/standard_ops/sequence/sequence_at.py +31 -0
  113. onnxslim/core/shape_inference/standard_ops/sequence/sequence_insert.py +26 -0
  114. onnxslim/core/shape_inference/standard_ops/sequence/split_to_sequence.py +24 -0
  115. onnxslim/core/shape_inference/standard_ops/sequence/zip_map.py +36 -0
  116. onnxslim/core/shape_inference/standard_ops/tensor/__init__.py +20 -0
  117. onnxslim/core/shape_inference/standard_ops/tensor/concat.py +62 -0
  118. onnxslim/core/shape_inference/standard_ops/tensor/expand.py +36 -0
  119. onnxslim/core/shape_inference/standard_ops/tensor/gather.py +48 -0
  120. onnxslim/core/shape_inference/standard_ops/tensor/gather_elements.py +31 -0
  121. onnxslim/core/shape_inference/standard_ops/tensor/gather_nd.py +42 -0
  122. onnxslim/core/shape_inference/standard_ops/tensor/pad.py +41 -0
  123. onnxslim/core/shape_inference/standard_ops/tensor/reshape.py +72 -0
  124. onnxslim/core/shape_inference/standard_ops/tensor/shape.py +38 -0
  125. onnxslim/core/shape_inference/standard_ops/tensor/size.py +29 -0
  126. onnxslim/core/shape_inference/standard_ops/tensor/slice.py +183 -0
  127. onnxslim/core/shape_inference/standard_ops/tensor/split.py +57 -0
  128. onnxslim/core/shape_inference/standard_ops/tensor/squeeze.py +69 -0
  129. onnxslim/core/shape_inference/standard_ops/tensor/tile.py +41 -0
  130. onnxslim/core/shape_inference/standard_ops/tensor/transpose.py +30 -0
  131. onnxslim/core/shape_inference/standard_ops/tensor/unsqueeze.py +54 -0
  132. onnxslim/core/shape_inference/utils.py +244 -0
  133. onnxslim/third_party/onnx_graphsurgeon/ir/graph.py +0 -103
  134. onnxslim/third_party/symbolic_shape_infer.py +73 -3156
  135. onnxslim/utils.py +4 -2
  136. {onnxslim-0.1.82.dist-info → onnxslim-0.1.84.dist-info}/METADATA +21 -11
  137. onnxslim-0.1.84.dist-info/RECORD +187 -0
  138. onnxslim-0.1.82.dist-info/RECORD +0 -63
  139. {onnxslim-0.1.82.dist-info → onnxslim-0.1.84.dist-info}/WHEEL +0 -0
  140. {onnxslim-0.1.82.dist-info → onnxslim-0.1.84.dist-info}/entry_points.txt +0 -0
  141. {onnxslim-0.1.82.dist-info → onnxslim-0.1.84.dist-info}/licenses/LICENSE +0 -0
@@ -794,109 +794,6 @@ class Graph:
794
794
  tensor.to_constant(arr)
795
795
  tensor.inputs.clear()
796
796
 
797
- # Pass 2: Run shape-tensor cast elision
798
- def run_cast_elision(node):
799
- """Perform cast elision optimization on an ONNX node to eliminate unnecessary cast operations."""
800
- import onnx
801
-
802
- # Search for Cast(s) (from int -> float) -> intermediate operator (with float constants) -> Cast(s) (back to int)
803
- # This pattern is problematic for TensorRT since these operations may be performed on Shape Tensors, which
804
- # are not allowed to be floating point type. Attempt to fold the pattern here
805
- VALID_CAST_ELISION_OPS = {
806
- "Add",
807
- "Sub",
808
- "Mul",
809
- "Div",
810
- "Max",
811
- "Min",
812
- "Equal",
813
- "Greater",
814
- "Less",
815
- "Concat",
816
- }
817
-
818
- if node.op not in VALID_CAST_ELISION_OPS:
819
- return
820
-
821
- # If the uncasted outputs of this node have any consumers other than "Cast" nodes,
822
- # then we cannot elide the cast.
823
- for out_tensor in node.outputs:
824
- if out_tensor in self.outputs:
825
- return
826
-
827
- if any(out_node.op != "Cast" for out_node in out_tensor.outputs):
828
- return
829
-
830
- # Get list of input nodes that cast to float32
831
- inp_casts = [
832
- inp_node
833
- for inp_tensor in node.inputs
834
- for inp_node in inp_tensor.inputs
835
- if inp_node.op == "Cast" and inp_node.attrs["to"] == onnx.TensorProto.DataType.FLOAT
836
- ]
837
-
838
- # No cast nodes found, return early
839
- if not inp_casts:
840
- return
841
-
842
- # Ensure that all input cast nodes are casting from the same type
843
- inp_dtypes = [dtype_to_onnx(inp_cast.inputs[0].dtype) for inp_cast in inp_casts]
844
- if len(set(inp_dtypes)) != 1:
845
- return
846
-
847
- final_type = inp_dtypes[0]
848
-
849
- # Get list of output nodes that cast to int32 or int64
850
- out_casts = [
851
- out_node
852
- for out_tensor in node.outputs
853
- for out_node in out_tensor.outputs
854
- if out_node.op == "Cast"
855
- and out_node.attrs["to"] in {onnx.TensorProto.DataType.INT32, onnx.TensorProto.DataType.INT64}
856
- ]
857
-
858
- # No cast node found on outputs, return early
859
- if not out_casts:
860
- return
861
-
862
- # Ensure that all output cast nodes are casting to the same type and that this
863
- # matches the original type before the inputs were casted.
864
- out_dtypes = [out_cast.attrs["to"] for out_cast in out_casts]
865
- if len(set(out_dtypes)) != 1 or out_dtypes[0] != final_type:
866
- return
867
-
868
- # If all checks passed, reconnect inputs/outputs to the consumers/producers
869
- # of the Cast nodes.
870
- # Note that we need to be careful in how we rebind tensors since they may
871
- # be used by multiple nodes. Thus, it is not necessarily safe to assume that
872
- # `cast_node.inputs[0].outputs[0] == cast_node`.
873
- for index, inp in enumerate(node.inputs):
874
- if isinstance(inp, Constant):
875
- inp.values = inp.values.astype(onnx.helper.tensor_dtype_to_np_dtype(final_type))
876
-
877
- for cast in inp_casts:
878
- if cast.outputs[0] == inp:
879
- node.inputs[index] = cast.inputs[0]
880
-
881
- for index, out in enumerate(node.outputs):
882
- for cast in out_casts:
883
- if cast.inputs[0] == out:
884
- out_tensor = cast.outputs[0]
885
- out_tensor.inputs.clear() # Disconnect from Cast
886
- node.outputs[index] = out_tensor
887
-
888
- if fold_shapes:
889
- # Perform shape tensor cast elision prior to most other folding
890
- G_LOGGER.debug(f"Performing shape tensor cast elision in {self.name}")
891
- try:
892
- with self.node_ids():
893
- for node in self.nodes:
894
- run_cast_elision(node)
895
- except Exception as err:
896
- if not error_ok:
897
- raise err
898
- G_LOGGER.warning("'{:}' routine failed with: {:}".format("Shape tensor cast elision", err))
899
-
900
797
  # Note that most of the remaining passes operate on a clone of the original graph.
901
798
  # Pass 3: Find all descendants of constant tensors
902
799