onnxslim 0.1.82__py3-none-any.whl → 0.1.84__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- onnxslim/core/optimization/dead_node_elimination.py +85 -4
- onnxslim/core/pattern/elimination/slice.py +15 -8
- onnxslim/core/pattern/fusion/concat_reshape.py +3 -1
- onnxslim/core/pattern/fusion/convadd.py +23 -7
- onnxslim/core/pattern/fusion/convbn.py +24 -11
- onnxslim/core/pattern/fusion/convmul.py +26 -9
- onnxslim/core/pattern/fusion/gemm.py +7 -5
- onnxslim/core/pattern/fusion/padconv.py +5 -0
- onnxslim/core/shape_inference/__init__.py +378 -0
- onnxslim/core/shape_inference/aten_ops/__init__.py +16 -0
- onnxslim/core/shape_inference/aten_ops/argmax.py +47 -0
- onnxslim/core/shape_inference/aten_ops/bitwise_or.py +28 -0
- onnxslim/core/shape_inference/aten_ops/diagonal.py +52 -0
- onnxslim/core/shape_inference/aten_ops/embedding.py +23 -0
- onnxslim/core/shape_inference/aten_ops/group_norm.py +41 -0
- onnxslim/core/shape_inference/aten_ops/min_max.py +64 -0
- onnxslim/core/shape_inference/aten_ops/multinomial.py +39 -0
- onnxslim/core/shape_inference/aten_ops/numpy_t.py +22 -0
- onnxslim/core/shape_inference/aten_ops/pool2d.py +40 -0
- onnxslim/core/shape_inference/aten_ops/unfold.py +44 -0
- onnxslim/core/shape_inference/aten_ops/upsample.py +44 -0
- onnxslim/core/shape_inference/base.py +111 -0
- onnxslim/core/shape_inference/context.py +645 -0
- onnxslim/core/shape_inference/contrib_ops/__init__.py +8 -0
- onnxslim/core/shape_inference/contrib_ops/attention/__init__.py +15 -0
- onnxslim/core/shape_inference/contrib_ops/attention/attention.py +61 -0
- onnxslim/core/shape_inference/contrib_ops/attention/decoder_masked_mha.py +37 -0
- onnxslim/core/shape_inference/contrib_ops/attention/gated_relative_position_bias.py +35 -0
- onnxslim/core/shape_inference/contrib_ops/attention/longformer_attention.py +21 -0
- onnxslim/core/shape_inference/contrib_ops/attention/multi_head_attention.py +82 -0
- onnxslim/core/shape_inference/contrib_ops/attention/multi_scale_deformable_attn.py +29 -0
- onnxslim/core/shape_inference/contrib_ops/attention/packed_attention.py +39 -0
- onnxslim/core/shape_inference/contrib_ops/attention/packed_multi_head_attention.py +33 -0
- onnxslim/core/shape_inference/contrib_ops/attention/remove_padding.py +41 -0
- onnxslim/core/shape_inference/contrib_ops/attention/restore_padding.py +29 -0
- onnxslim/core/shape_inference/contrib_ops/misc/__init__.py +15 -0
- onnxslim/core/shape_inference/contrib_ops/misc/bias_add.py +21 -0
- onnxslim/core/shape_inference/contrib_ops/misc/bias_gelu.py +21 -0
- onnxslim/core/shape_inference/contrib_ops/misc/bias_split_gelu.py +30 -0
- onnxslim/core/shape_inference/contrib_ops/misc/fast_gelu.py +21 -0
- onnxslim/core/shape_inference/contrib_ops/misc/gelu.py +21 -0
- onnxslim/core/shape_inference/contrib_ops/misc/gemm_fast_gelu.py +21 -0
- onnxslim/core/shape_inference/contrib_ops/misc/gemm_float8.py +21 -0
- onnxslim/core/shape_inference/contrib_ops/misc/python_op.py +67 -0
- onnxslim/core/shape_inference/contrib_ops/misc/quick_gelu.py +21 -0
- onnxslim/core/shape_inference/contrib_ops/misc/rotary_embedding.py +31 -0
- onnxslim/core/shape_inference/contrib_ops/normalization/__init__.py +12 -0
- onnxslim/core/shape_inference/contrib_ops/normalization/embed_layer_normalization.py +41 -0
- onnxslim/core/shape_inference/contrib_ops/normalization/group_norm.py +21 -0
- onnxslim/core/shape_inference/contrib_ops/normalization/layer_normalization.py +42 -0
- onnxslim/core/shape_inference/contrib_ops/normalization/simplified_layer_normalization.py +23 -0
- onnxslim/core/shape_inference/contrib_ops/normalization/skip_group_norm.py +23 -0
- onnxslim/core/shape_inference/contrib_ops/normalization/skip_layer_normalization.py +26 -0
- onnxslim/core/shape_inference/contrib_ops/normalization/skip_simplified_layer_normalization.py +23 -0
- onnxslim/core/shape_inference/registry.py +90 -0
- onnxslim/core/shape_inference/standard_ops/__init__.py +11 -0
- onnxslim/core/shape_inference/standard_ops/control_flow/__init__.py +8 -0
- onnxslim/core/shape_inference/standard_ops/control_flow/if_op.py +43 -0
- onnxslim/core/shape_inference/standard_ops/control_flow/loop.py +74 -0
- onnxslim/core/shape_inference/standard_ops/control_flow/scan.py +54 -0
- onnxslim/core/shape_inference/standard_ops/math/__init__.py +20 -0
- onnxslim/core/shape_inference/standard_ops/math/_symbolic_compute.py +34 -0
- onnxslim/core/shape_inference/standard_ops/math/add.py +10 -0
- onnxslim/core/shape_inference/standard_ops/math/div.py +10 -0
- onnxslim/core/shape_inference/standard_ops/math/einsum.py +119 -0
- onnxslim/core/shape_inference/standard_ops/math/equal.py +10 -0
- onnxslim/core/shape_inference/standard_ops/math/floor.py +10 -0
- onnxslim/core/shape_inference/standard_ops/math/matmul.py +21 -0
- onnxslim/core/shape_inference/standard_ops/math/matmul_integer.py +23 -0
- onnxslim/core/shape_inference/standard_ops/math/max.py +10 -0
- onnxslim/core/shape_inference/standard_ops/math/min.py +10 -0
- onnxslim/core/shape_inference/standard_ops/math/mul.py +10 -0
- onnxslim/core/shape_inference/standard_ops/math/neg.py +10 -0
- onnxslim/core/shape_inference/standard_ops/math/reduce_prod.py +27 -0
- onnxslim/core/shape_inference/standard_ops/math/reduce_sum.py +53 -0
- onnxslim/core/shape_inference/standard_ops/math/sub.py +10 -0
- onnxslim/core/shape_inference/standard_ops/math/where.py +10 -0
- onnxslim/core/shape_inference/standard_ops/misc/__init__.py +22 -0
- onnxslim/core/shape_inference/standard_ops/misc/array_feature_extractor.py +32 -0
- onnxslim/core/shape_inference/standard_ops/misc/cast.py +21 -0
- onnxslim/core/shape_inference/standard_ops/misc/category_mapper.py +30 -0
- onnxslim/core/shape_inference/standard_ops/misc/compress.py +39 -0
- onnxslim/core/shape_inference/standard_ops/misc/constant.py +27 -0
- onnxslim/core/shape_inference/standard_ops/misc/constant_of_shape.py +45 -0
- onnxslim/core/shape_inference/standard_ops/misc/dequantize_linear.py +26 -0
- onnxslim/core/shape_inference/standard_ops/misc/non_max_suppression.py +26 -0
- onnxslim/core/shape_inference/standard_ops/misc/non_zero.py +26 -0
- onnxslim/core/shape_inference/standard_ops/misc/one_hot.py +42 -0
- onnxslim/core/shape_inference/standard_ops/misc/quantize_linear.py +29 -0
- onnxslim/core/shape_inference/standard_ops/misc/range.py +41 -0
- onnxslim/core/shape_inference/standard_ops/misc/relative_position_bias.py +31 -0
- onnxslim/core/shape_inference/standard_ops/misc/resize.py +74 -0
- onnxslim/core/shape_inference/standard_ops/misc/scatter_elements.py +31 -0
- onnxslim/core/shape_inference/standard_ops/misc/softmax_cross_entropy_loss.py +44 -0
- onnxslim/core/shape_inference/standard_ops/misc/top_k.py +44 -0
- onnxslim/core/shape_inference/standard_ops/nn/__init__.py +18 -0
- onnxslim/core/shape_inference/standard_ops/nn/all_reduce.py +9 -0
- onnxslim/core/shape_inference/standard_ops/nn/average_pool.py +40 -0
- onnxslim/core/shape_inference/standard_ops/nn/batch_normalization.py +26 -0
- onnxslim/core/shape_inference/standard_ops/nn/conv.py +33 -0
- onnxslim/core/shape_inference/standard_ops/nn/cum_sum.py +9 -0
- onnxslim/core/shape_inference/standard_ops/nn/identity.py +9 -0
- onnxslim/core/shape_inference/standard_ops/nn/max_pool.py +9 -0
- onnxslim/core/shape_inference/standard_ops/nn/memcpy_from_host.py +9 -0
- onnxslim/core/shape_inference/standard_ops/nn/memcpy_to_host.py +9 -0
- onnxslim/core/shape_inference/standard_ops/nn/moe.py +9 -0
- onnxslim/core/shape_inference/standard_ops/nn/nhwc_conv.py +33 -0
- onnxslim/core/shape_inference/standard_ops/nn/reciprocal.py +9 -0
- onnxslim/core/shape_inference/standard_ops/nn/round.py +9 -0
- onnxslim/core/shape_inference/standard_ops/sequence/__init__.py +10 -0
- onnxslim/core/shape_inference/standard_ops/sequence/concat_from_sequence.py +40 -0
- onnxslim/core/shape_inference/standard_ops/sequence/sequence_at.py +31 -0
- onnxslim/core/shape_inference/standard_ops/sequence/sequence_insert.py +26 -0
- onnxslim/core/shape_inference/standard_ops/sequence/split_to_sequence.py +24 -0
- onnxslim/core/shape_inference/standard_ops/sequence/zip_map.py +36 -0
- onnxslim/core/shape_inference/standard_ops/tensor/__init__.py +20 -0
- onnxslim/core/shape_inference/standard_ops/tensor/concat.py +62 -0
- onnxslim/core/shape_inference/standard_ops/tensor/expand.py +36 -0
- onnxslim/core/shape_inference/standard_ops/tensor/gather.py +48 -0
- onnxslim/core/shape_inference/standard_ops/tensor/gather_elements.py +31 -0
- onnxslim/core/shape_inference/standard_ops/tensor/gather_nd.py +42 -0
- onnxslim/core/shape_inference/standard_ops/tensor/pad.py +41 -0
- onnxslim/core/shape_inference/standard_ops/tensor/reshape.py +72 -0
- onnxslim/core/shape_inference/standard_ops/tensor/shape.py +38 -0
- onnxslim/core/shape_inference/standard_ops/tensor/size.py +29 -0
- onnxslim/core/shape_inference/standard_ops/tensor/slice.py +183 -0
- onnxslim/core/shape_inference/standard_ops/tensor/split.py +57 -0
- onnxslim/core/shape_inference/standard_ops/tensor/squeeze.py +69 -0
- onnxslim/core/shape_inference/standard_ops/tensor/tile.py +41 -0
- onnxslim/core/shape_inference/standard_ops/tensor/transpose.py +30 -0
- onnxslim/core/shape_inference/standard_ops/tensor/unsqueeze.py +54 -0
- onnxslim/core/shape_inference/utils.py +244 -0
- onnxslim/third_party/onnx_graphsurgeon/ir/graph.py +0 -103
- onnxslim/third_party/symbolic_shape_infer.py +73 -3156
- onnxslim/utils.py +4 -2
- {onnxslim-0.1.82.dist-info → onnxslim-0.1.84.dist-info}/METADATA +21 -11
- onnxslim-0.1.84.dist-info/RECORD +187 -0
- onnxslim-0.1.82.dist-info/RECORD +0 -63
- {onnxslim-0.1.82.dist-info → onnxslim-0.1.84.dist-info}/WHEEL +0 -0
- {onnxslim-0.1.82.dist-info → onnxslim-0.1.84.dist-info}/entry_points.txt +0 -0
- {onnxslim-0.1.82.dist-info → onnxslim-0.1.84.dist-info}/licenses/LICENSE +0 -0
|
@@ -794,109 +794,6 @@ class Graph:
|
|
|
794
794
|
tensor.to_constant(arr)
|
|
795
795
|
tensor.inputs.clear()
|
|
796
796
|
|
|
797
|
-
# Pass 2: Run shape-tensor cast elision
|
|
798
|
-
def run_cast_elision(node):
|
|
799
|
-
"""Perform cast elision optimization on an ONNX node to eliminate unnecessary cast operations."""
|
|
800
|
-
import onnx
|
|
801
|
-
|
|
802
|
-
# Search for Cast(s) (from int -> float) -> intermediate operator (with float constants) -> Cast(s) (back to int)
|
|
803
|
-
# This pattern is problematic for TensorRT since these operations may be performed on Shape Tensors, which
|
|
804
|
-
# are not allowed to be floating point type. Attempt to fold the pattern here
|
|
805
|
-
VALID_CAST_ELISION_OPS = {
|
|
806
|
-
"Add",
|
|
807
|
-
"Sub",
|
|
808
|
-
"Mul",
|
|
809
|
-
"Div",
|
|
810
|
-
"Max",
|
|
811
|
-
"Min",
|
|
812
|
-
"Equal",
|
|
813
|
-
"Greater",
|
|
814
|
-
"Less",
|
|
815
|
-
"Concat",
|
|
816
|
-
}
|
|
817
|
-
|
|
818
|
-
if node.op not in VALID_CAST_ELISION_OPS:
|
|
819
|
-
return
|
|
820
|
-
|
|
821
|
-
# If the uncasted outputs of this node have any consumers other than "Cast" nodes,
|
|
822
|
-
# then we cannot elide the cast.
|
|
823
|
-
for out_tensor in node.outputs:
|
|
824
|
-
if out_tensor in self.outputs:
|
|
825
|
-
return
|
|
826
|
-
|
|
827
|
-
if any(out_node.op != "Cast" for out_node in out_tensor.outputs):
|
|
828
|
-
return
|
|
829
|
-
|
|
830
|
-
# Get list of input nodes that cast to float32
|
|
831
|
-
inp_casts = [
|
|
832
|
-
inp_node
|
|
833
|
-
for inp_tensor in node.inputs
|
|
834
|
-
for inp_node in inp_tensor.inputs
|
|
835
|
-
if inp_node.op == "Cast" and inp_node.attrs["to"] == onnx.TensorProto.DataType.FLOAT
|
|
836
|
-
]
|
|
837
|
-
|
|
838
|
-
# No cast nodes found, return early
|
|
839
|
-
if not inp_casts:
|
|
840
|
-
return
|
|
841
|
-
|
|
842
|
-
# Ensure that all input cast nodes are casting from the same type
|
|
843
|
-
inp_dtypes = [dtype_to_onnx(inp_cast.inputs[0].dtype) for inp_cast in inp_casts]
|
|
844
|
-
if len(set(inp_dtypes)) != 1:
|
|
845
|
-
return
|
|
846
|
-
|
|
847
|
-
final_type = inp_dtypes[0]
|
|
848
|
-
|
|
849
|
-
# Get list of output nodes that cast to int32 or int64
|
|
850
|
-
out_casts = [
|
|
851
|
-
out_node
|
|
852
|
-
for out_tensor in node.outputs
|
|
853
|
-
for out_node in out_tensor.outputs
|
|
854
|
-
if out_node.op == "Cast"
|
|
855
|
-
and out_node.attrs["to"] in {onnx.TensorProto.DataType.INT32, onnx.TensorProto.DataType.INT64}
|
|
856
|
-
]
|
|
857
|
-
|
|
858
|
-
# No cast node found on outputs, return early
|
|
859
|
-
if not out_casts:
|
|
860
|
-
return
|
|
861
|
-
|
|
862
|
-
# Ensure that all output cast nodes are casting to the same type and that this
|
|
863
|
-
# matches the original type before the inputs were casted.
|
|
864
|
-
out_dtypes = [out_cast.attrs["to"] for out_cast in out_casts]
|
|
865
|
-
if len(set(out_dtypes)) != 1 or out_dtypes[0] != final_type:
|
|
866
|
-
return
|
|
867
|
-
|
|
868
|
-
# If all checks passed, reconnect inputs/outputs to the consumers/producers
|
|
869
|
-
# of the Cast nodes.
|
|
870
|
-
# Note that we need to be careful in how we rebind tensors since they may
|
|
871
|
-
# be used by multiple nodes. Thus, it is not necessarily safe to assume that
|
|
872
|
-
# `cast_node.inputs[0].outputs[0] == cast_node`.
|
|
873
|
-
for index, inp in enumerate(node.inputs):
|
|
874
|
-
if isinstance(inp, Constant):
|
|
875
|
-
inp.values = inp.values.astype(onnx.helper.tensor_dtype_to_np_dtype(final_type))
|
|
876
|
-
|
|
877
|
-
for cast in inp_casts:
|
|
878
|
-
if cast.outputs[0] == inp:
|
|
879
|
-
node.inputs[index] = cast.inputs[0]
|
|
880
|
-
|
|
881
|
-
for index, out in enumerate(node.outputs):
|
|
882
|
-
for cast in out_casts:
|
|
883
|
-
if cast.inputs[0] == out:
|
|
884
|
-
out_tensor = cast.outputs[0]
|
|
885
|
-
out_tensor.inputs.clear() # Disconnect from Cast
|
|
886
|
-
node.outputs[index] = out_tensor
|
|
887
|
-
|
|
888
|
-
if fold_shapes:
|
|
889
|
-
# Perform shape tensor cast elision prior to most other folding
|
|
890
|
-
G_LOGGER.debug(f"Performing shape tensor cast elision in {self.name}")
|
|
891
|
-
try:
|
|
892
|
-
with self.node_ids():
|
|
893
|
-
for node in self.nodes:
|
|
894
|
-
run_cast_elision(node)
|
|
895
|
-
except Exception as err:
|
|
896
|
-
if not error_ok:
|
|
897
|
-
raise err
|
|
898
|
-
G_LOGGER.warning("'{:}' routine failed with: {:}".format("Shape tensor cast elision", err))
|
|
899
|
-
|
|
900
797
|
# Note that most of the remaining passes operate on a clone of the original graph.
|
|
901
798
|
# Pass 3: Find all descendants of constant tensors
|
|
902
799
|
|