onnxslim 0.1.83__py3-none-any.whl → 0.1.84__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -39,6 +39,16 @@ class SlicePatternMatcher(PatternMatcher):
39
39
  first_slice_node_axes = first_slice_node_inputs[3].values.tolist()
40
40
  first_slice_node_steps = first_slice_node_inputs[4].values.tolist()
41
41
 
42
+ # Check all users upfront before modifying the graph.
43
+ # If any user has overlapping axes, skip the optimization entirely
44
+ # to avoid corrupting the graph (fixes GitHub issue #277).
45
+ for user_node in first_slice_node_users:
46
+ second_slice_node_inputs = list(user_node.inputs)
47
+ second_slice_node_axes = second_slice_node_inputs[3].values.tolist()
48
+ new_axes = first_slice_node_axes + second_slice_node_axes
49
+ if len(new_axes) != len(set(new_axes)):
50
+ return match_case
51
+
42
52
  for user_node in first_slice_node_users:
43
53
  second_slice_node = user_node
44
54
  second_slice_node_inputs = list(second_slice_node.inputs)
@@ -52,33 +62,30 @@ class SlicePatternMatcher(PatternMatcher):
52
62
  new_axes = first_slice_node_axes + second_slice_node_axes
53
63
  new_steps = first_slice_node_steps + second_slice_node_steps
54
64
 
55
- if len(new_axes) != len(set(new_axes)):
56
- continue
57
-
58
65
  inputs = []
66
+ output_name = second_slice_node.outputs[0].name
59
67
  inputs.extend(
60
68
  (
61
69
  next(iter(first_slice_node.inputs)),
62
70
  gs.Constant(
63
- second_slice_node_inputs[1].name + "_starts",
71
+ output_name + "_starts",
64
72
  values=np.array(new_starts, dtype=np.int64),
65
73
  ),
66
74
  gs.Constant(
67
- second_slice_node_inputs[2].name + "_ends",
75
+ output_name + "_ends",
68
76
  values=np.array(new_ends, dtype=np.int64),
69
77
  ),
70
78
  gs.Constant(
71
- second_slice_node_inputs[3].name + "_axes",
79
+ output_name + "_axes",
72
80
  values=np.array(new_axes, dtype=np.int64),
73
81
  ),
74
82
  gs.Constant(
75
- second_slice_node_inputs[4].name + "_steps",
83
+ output_name + "_steps",
76
84
  values=np.array(new_steps, dtype=np.int64),
77
85
  ),
78
86
  )
79
87
  )
80
88
  outputs = list(second_slice_node.outputs)
81
-
82
89
  first_slice_node.outputs.clear()
83
90
  second_slice_node.inputs.clear()
84
91
  second_slice_node.outputs.clear()
@@ -36,9 +36,11 @@ class ConcatReshapeMatcher(PatternMatcher):
36
36
  def rewrite(self, opset=11):
37
37
  match_case = {}
38
38
  concat_node = self.concat_0
39
+ reshape_node = self.reshape_0
39
40
  index = next(idx for idx, i in enumerate(concat_node.inputs) if isinstance(i, gs.Variable))
41
+ output_name = reshape_node.outputs[0].name
40
42
  constant = gs.Constant(
41
- concat_node.inputs[index].name + "_fixed",
43
+ output_name + "_fixed",
42
44
  values=np.array([-1], dtype=np.int64),
43
45
  )
44
46
  concat_node.inputs.pop(index)
@@ -44,12 +44,8 @@ class ConvAddMatcher(PatternMatcher):
44
44
  inputs = []
45
45
  inputs.append(next(iter(conv_node.inputs)))
46
46
  inputs.append(conv_weight)
47
- weight_name = list(conv_node.inputs)[1].name
48
- if weight_name.endswith("weight"):
49
- bias_name = f"{weight_name[:-6]}bias"
50
- else:
51
- bias_name = f"{weight_name}_bias"
52
- inputs.append(gs.Constant(bias_name, values=conv_bias))
47
+ output_name = add_node.outputs[0].name
48
+ inputs.append(gs.Constant(output_name + "_bias", values=conv_bias))
53
49
  outputs = list(add_node.outputs)
54
50
 
55
51
  conv_node.outputs.clear()
@@ -52,15 +52,11 @@ class ConvBatchNormMatcher(PatternMatcher):
52
52
 
53
53
  inputs = []
54
54
  inputs.append(next(iter(conv_transpose_node.inputs)))
55
- weight_name = list(conv_transpose_node.inputs)[1].name
56
- if weight_name.endswith("weight"):
57
- bias_name = f"{weight_name[:-6]}bias"
58
- else:
59
- bias_name = f"{weight_name}_bias"
55
+ output_name = bn_node.outputs[0].name
60
56
  inputs.extend(
61
57
  (
62
- gs.Constant(weight_name + "_weight", values=conv_w),
63
- gs.Constant(bias_name, values=conv_b),
58
+ gs.Constant(output_name + "_weight", values=conv_w),
59
+ gs.Constant(output_name + "_bias", values=conv_b),
64
60
  )
65
61
  )
66
62
  outputs = list(bn_node.outputs)
@@ -38,14 +38,13 @@ class ConvMulMatcher(PatternMatcher):
38
38
  inputs = []
39
39
  inputs.append(next(iter(conv_node.inputs)))
40
40
 
41
- weight_name = list(conv_node.inputs)[1].name
42
- inputs.append(gs.Constant(weight_name, values=new_weight))
41
+ output_name = mul_node.outputs[0].name
42
+ inputs.append(gs.Constant(output_name + "_weight", values=new_weight))
43
43
 
44
44
  if len(conv_node.inputs) == 3:
45
45
  conv_bias = conv_node.inputs[2].values
46
46
  new_bias = conv_bias * mul_constant.squeeze()
47
- bias_name = list(conv_node.inputs)[2].name
48
- inputs.append(gs.Constant(bias_name, values=new_bias))
47
+ inputs.append(gs.Constant(output_name + "_bias", values=new_bias))
49
48
 
50
49
  outputs = list(mul_node.outputs)
51
50
 
@@ -76,7 +76,7 @@ class MatMulAddPatternMatcher(PatternMatcher):
76
76
  output_variable.outputs.remove(add_node)
77
77
 
78
78
  matmul_bias_transpose_constant = gs.Constant(
79
- matmul_bias_variable.name, values=matmul_bias_variable.values.T
79
+ f"{matmul_node.name}_weight", values=matmul_bias_variable.values.T
80
80
  )
81
81
 
82
82
  inputs = []
@@ -143,7 +143,7 @@ class MatMulAddPatternMatcher(PatternMatcher):
143
143
  output_variable.outputs.remove(add_node)
144
144
 
145
145
  matmul_bias_transpose_constant = gs.Constant(
146
- matmul_bias_variable.name, values=matmul_bias_variable.values.T
146
+ f"{matmul_node.name}_weight", values=matmul_bias_variable.values.T
147
147
  )
148
148
 
149
149
  inputs = []
@@ -235,14 +235,15 @@ class GemmMulPatternMatcher(PatternMatcher):
235
235
  gemm_weight_fused = gemm_weight * mul_weight[:, None]
236
236
  else:
237
237
  gemm_weight_fused = gemm_weight * mul_weight
238
- gemm_weight_fused_constant = gs.Constant(gemm_weight_constant.name + "_fused", values=gemm_weight_fused)
238
+ output_name = reshape_node.outputs[0].name
239
+ gemm_weight_fused_constant = gs.Constant(output_name + "_weight_fused", values=gemm_weight_fused)
239
240
  gemm_node.inputs[1] = gemm_weight_fused_constant
240
241
 
241
242
  if gemm_bias_constant:
242
243
  gemm_bias = gemm_bias_constant.values
243
244
  mul_bias = mul_bias_variable.values
244
245
  gemm_bias_fused = gemm_bias * mul_bias
245
- gemm_bias_fused_constant = gs.Constant(gemm_bias_constant.name + "_fused", values=gemm_bias_fused)
246
+ gemm_bias_fused_constant = gs.Constant(output_name + "_bias_fused", values=gemm_bias_fused)
246
247
  gemm_node.inputs[2] = gemm_bias_fused_constant
247
248
 
248
249
  mul_node.replace_all_uses_with(reshape_node)
@@ -312,7 +313,8 @@ class GemmAddPatternMatcher(PatternMatcher):
312
313
  and add_bias.ndim <= 2
313
314
  ):
314
315
  gemm_bias_fused = gemm_bias + add_bias
315
- gemm_bias_fused_constant = gs.Constant(gemm_bias_constant.name + "_fused", values=gemm_bias_fused)
316
+ output_name = reshape_node.outputs[0].name
317
+ gemm_bias_fused_constant = gs.Constant(output_name + "_bias_fused", values=gemm_bias_fused)
316
318
  gemm_node.inputs[2] = gemm_bias_fused_constant
317
319
  else:
318
320
  return match_case
@@ -794,109 +794,6 @@ class Graph:
794
794
  tensor.to_constant(arr)
795
795
  tensor.inputs.clear()
796
796
 
797
- # Pass 2: Run shape-tensor cast elision
798
- def run_cast_elision(node):
799
- """Perform cast elision optimization on an ONNX node to eliminate unnecessary cast operations."""
800
- import onnx
801
-
802
- # Search for Cast(s) (from int -> float) -> intermediate operator (with float constants) -> Cast(s) (back to int)
803
- # This pattern is problematic for TensorRT since these operations may be performed on Shape Tensors, which
804
- # are not allowed to be floating point type. Attempt to fold the pattern here
805
- VALID_CAST_ELISION_OPS = {
806
- "Add",
807
- "Sub",
808
- "Mul",
809
- "Div",
810
- "Max",
811
- "Min",
812
- "Equal",
813
- "Greater",
814
- "Less",
815
- "Concat",
816
- }
817
-
818
- if node.op not in VALID_CAST_ELISION_OPS:
819
- return
820
-
821
- # If the uncasted outputs of this node have any consumers other than "Cast" nodes,
822
- # then we cannot elide the cast.
823
- for out_tensor in node.outputs:
824
- if out_tensor in self.outputs:
825
- return
826
-
827
- if any(out_node.op != "Cast" for out_node in out_tensor.outputs):
828
- return
829
-
830
- # Get list of input nodes that cast to float32
831
- inp_casts = [
832
- inp_node
833
- for inp_tensor in node.inputs
834
- for inp_node in inp_tensor.inputs
835
- if inp_node.op == "Cast" and inp_node.attrs["to"] == onnx.TensorProto.DataType.FLOAT
836
- ]
837
-
838
- # No cast nodes found, return early
839
- if not inp_casts:
840
- return
841
-
842
- # Ensure that all input cast nodes are casting from the same type
843
- inp_dtypes = [dtype_to_onnx(inp_cast.inputs[0].dtype) for inp_cast in inp_casts]
844
- if len(set(inp_dtypes)) != 1:
845
- return
846
-
847
- final_type = inp_dtypes[0]
848
-
849
- # Get list of output nodes that cast to int32 or int64
850
- out_casts = [
851
- out_node
852
- for out_tensor in node.outputs
853
- for out_node in out_tensor.outputs
854
- if out_node.op == "Cast"
855
- and out_node.attrs["to"] in {onnx.TensorProto.DataType.INT32, onnx.TensorProto.DataType.INT64}
856
- ]
857
-
858
- # No cast node found on outputs, return early
859
- if not out_casts:
860
- return
861
-
862
- # Ensure that all output cast nodes are casting to the same type and that this
863
- # matches the original type before the inputs were casted.
864
- out_dtypes = [out_cast.attrs["to"] for out_cast in out_casts]
865
- if len(set(out_dtypes)) != 1 or out_dtypes[0] != final_type:
866
- return
867
-
868
- # If all checks passed, reconnect inputs/outputs to the consumers/producers
869
- # of the Cast nodes.
870
- # Note that we need to be careful in how we rebind tensors since they may
871
- # be used by multiple nodes. Thus, it is not necessarily safe to assume that
872
- # `cast_node.inputs[0].outputs[0] == cast_node`.
873
- for index, inp in enumerate(node.inputs):
874
- if isinstance(inp, Constant):
875
- inp.values = inp.values.astype(onnx.helper.tensor_dtype_to_np_dtype(final_type))
876
-
877
- for cast in inp_casts:
878
- if cast.outputs[0] == inp:
879
- node.inputs[index] = cast.inputs[0]
880
-
881
- for index, out in enumerate(node.outputs):
882
- for cast in out_casts:
883
- if cast.inputs[0] == out:
884
- out_tensor = cast.outputs[0]
885
- out_tensor.inputs.clear() # Disconnect from Cast
886
- node.outputs[index] = out_tensor
887
-
888
- if fold_shapes:
889
- # Perform shape tensor cast elision prior to most other folding
890
- G_LOGGER.debug(f"Performing shape tensor cast elision in {self.name}")
891
- try:
892
- with self.node_ids():
893
- for node in self.nodes:
894
- run_cast_elision(node)
895
- except Exception as err:
896
- if not error_ok:
897
- raise err
898
- G_LOGGER.warning("'{:}' routine failed with: {:}".format("Shape tensor cast elision", err))
899
-
900
797
  # Note that most of the remaining passes operate on a clone of the original graph.
901
798
  # Pass 3: Find all descendants of constant tensors
902
799
 
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.4
2
2
  Name: onnxslim
3
- Version: 0.1.83
3
+ Version: 0.1.84
4
4
  Summary: OnnxSlim: A Toolkit to Help Optimize Onnx Model
5
5
  Project-URL: homepage, https://github.com/inisis/OnnxSlim
6
6
  Project-URL: issues, https://github.com/inisis/OnnxSlim/issues
@@ -16,15 +16,15 @@ onnxslim/core/pattern/elimination/__init__.py,sha256=C9EwJj7DQmaXVvGx6wxvqvCdQGE
16
16
  onnxslim/core/pattern/elimination/concat.py,sha256=RmN3B0qtVixE_7QfgxsJHj2MUPOEdp8oxrcFN2oSR5Q,2261
17
17
  onnxslim/core/pattern/elimination/reshape.py,sha256=XwvuPAZnXCCEwJb2n1guigstnsl3wlxGygytH3GZXN8,3109
18
18
  onnxslim/core/pattern/elimination/reshape_as.py,sha256=FI3LYR0pzbp2pDmaX13duHrQ4uqwaKNu4bG78en-7wY,2034
19
- onnxslim/core/pattern/elimination/slice.py,sha256=moZibU-TbtdwtmGIUwyjnjf3oRCeCBcQq0M1gY5ZWDk,5033
19
+ onnxslim/core/pattern/elimination/slice.py,sha256=aOfxc7h4mottkK78gq8qoKYtLWBwnxoa7lnY1Z15hSc,5547
20
20
  onnxslim/core/pattern/elimination/unsqueeze.py,sha256=v7Rin3qB6F49ETrxXWEQQxUgtlF18nvHb6JFarf0kwQ,3855
21
21
  onnxslim/core/pattern/fusion/__init__.py,sha256=3ajHvRurL7WHL4tfNsBoLQh6Sq2fyiqH-VsPuftYMGg,183
22
- onnxslim/core/pattern/fusion/concat_reshape.py,sha256=LvknixTAsSUqUkGSuoEA1QpC-TmBrsx6AHZoeT0gTbI,1615
23
- onnxslim/core/pattern/fusion/convadd.py,sha256=ONORwlZbQ1kYJVAnCyGY6KLIicOOELmKm7-l2vbe078,3245
24
- onnxslim/core/pattern/fusion/convbn.py,sha256=ZsVDuAxe41f_eN9rt2psJLKQyzGMjO2RCcX9FKRNM1Y,4118
25
- onnxslim/core/pattern/fusion/convmul.py,sha256=aqq2fMtnMt7cXgQxdwu2hIk2kl-SI7FwpyCxtt9lT1w,3380
22
+ onnxslim/core/pattern/fusion/concat_reshape.py,sha256=9q1cPpOpO7s87k0r9qUFuLLMuTGJXOkOX3l7Xl1KiAQ,1685
23
+ onnxslim/core/pattern/fusion/convadd.py,sha256=4nOB6OGbKIBaM2nlxSdnOP_Ayer-1O7hu_hdaXVzF8M,3082
24
+ onnxslim/core/pattern/fusion/convbn.py,sha256=e8EXGSWmlBFrM1tkTTZIXaLwSXh82V3XKie4D2cm1nY,3944
25
+ onnxslim/core/pattern/fusion/convmul.py,sha256=2QbbqxtzATXZMsCtcP4EcZQ1vj8Rb1yFFSiC72zG22Q,3335
26
26
  onnxslim/core/pattern/fusion/gelu.py,sha256=uR67AJ_tL1gboY6VsTdqajHxW3Pbu656UMhCe1mQZDY,1469
27
- onnxslim/core/pattern/fusion/gemm.py,sha256=Ti9yZAfEprFRvW1FiAD0zvewELOJbRjposIk3yjjXfQ,12928
27
+ onnxslim/core/pattern/fusion/gemm.py,sha256=-Fdp3FkD54Kw1yC-2FXQ1NzaSvr4IRxmR7ObL5_cJTI,13035
28
28
  onnxslim/core/pattern/fusion/padconv.py,sha256=oF-Z4tlyu-AAWJMQDoszNITNgd2mb0vAg2gi0RwQuMo,3838
29
29
  onnxslim/core/pattern/fusion/reduce.py,sha256=dMC7CPlFglrJxugsJWjcc-jQCIa_GIbW1y9K2FRvvcE,2755
30
30
  onnxslim/core/shape_inference/__init__.py,sha256=iMAX6y6LsR8S3DOpeshPaMQLS3Plj4zYBdSaLGRYIts,16833
@@ -172,7 +172,7 @@ onnxslim/third_party/onnx_graphsurgeon/importers/base_importer.py,sha256=ESIul1p
172
172
  onnxslim/third_party/onnx_graphsurgeon/importers/onnx_importer.py,sha256=qa86Ne8yWCmpoAPBWV2lV1hlCvnQ6UPe-M1JXSfnMqM,23097
173
173
  onnxslim/third_party/onnx_graphsurgeon/ir/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
174
174
  onnxslim/third_party/onnx_graphsurgeon/ir/function.py,sha256=X1Rd1ZQlHhK6crg788a-LCmQSzv446LGfw376_Cz8Co,11820
175
- onnxslim/third_party/onnx_graphsurgeon/ir/graph.py,sha256=RU1luTR5sGMbPbRXtKGsYtBIv5BlXZOo7gU6bv0L5FY,70494
175
+ onnxslim/third_party/onnx_graphsurgeon/ir/graph.py,sha256=BEHXQoQMYclhEld5_o2MeA9zgpPZECfe6J9VenhiPgk,66101
176
176
  onnxslim/third_party/onnx_graphsurgeon/ir/node.py,sha256=lHrJCNRhtPRZrE7vuvQkG_wfEsJzDW7Wf-T_kr4OJHI,9996
177
177
  onnxslim/third_party/onnx_graphsurgeon/ir/tensor.py,sha256=bypjlsVp1qByPhJRbTSjSrPpoatmMykjnJ9_cnnmz9Y,19265
178
178
  onnxslim/third_party/onnx_graphsurgeon/logger/__init__.py,sha256=b6lAvvrKZKNtCZOgcvz2Aj9lUO5mw5JM8UFP5BqBOnQ,83
@@ -180,8 +180,8 @@ onnxslim/third_party/onnx_graphsurgeon/logger/logger.py,sha256=L12rrwn33RHH-2WLv
180
180
  onnxslim/third_party/onnx_graphsurgeon/util/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
181
181
  onnxslim/third_party/onnx_graphsurgeon/util/exception.py,sha256=KrsHbKEQ4237UbjlODsUzvkXoAY72LZi23ApBeFANWg,786
182
182
  onnxslim/third_party/onnx_graphsurgeon/util/misc.py,sha256=kyxInD2SCRLU4wHMeiDEYEHB3871fGks6kQTuF9uATY,8960
183
- onnxslim-0.1.83.dist-info/METADATA,sha256=Npm1SQ2CnsjAh0NF6Z5twoqjiu9IJLfrLRh4KkvEALo,10651
184
- onnxslim-0.1.83.dist-info/WHEEL,sha256=WLgqFyCfm_KASv4WHyYy0P3pM_m7J5L9k2skdKLirC8,87
185
- onnxslim-0.1.83.dist-info/entry_points.txt,sha256=O2QgceCVeGeRhnxRSDRcGiFd0ZNfElwrTiRo1W2V7KA,47
186
- onnxslim-0.1.83.dist-info/licenses/LICENSE,sha256=oHZXw-yrBwdNVGu4JtlZhMgmQHKIZ7BJJlJdhu1HKvI,1062
187
- onnxslim-0.1.83.dist-info/RECORD,,
183
+ onnxslim-0.1.84.dist-info/METADATA,sha256=ZoGC6wTTau3dqyYbSA8rtLL8ghV8TV5KwDXSmZw9yjo,10651
184
+ onnxslim-0.1.84.dist-info/WHEEL,sha256=WLgqFyCfm_KASv4WHyYy0P3pM_m7J5L9k2skdKLirC8,87
185
+ onnxslim-0.1.84.dist-info/entry_points.txt,sha256=O2QgceCVeGeRhnxRSDRcGiFd0ZNfElwrTiRo1W2V7KA,47
186
+ onnxslim-0.1.84.dist-info/licenses/LICENSE,sha256=oHZXw-yrBwdNVGu4JtlZhMgmQHKIZ7BJJlJdhu1HKvI,1062
187
+ onnxslim-0.1.84.dist-info/RECORD,,