onnxslim 0.1.69__py3-none-any.whl → 0.1.71__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
onnxslim/__init__.py CHANGED
@@ -2,9 +2,7 @@ import os
2
2
  import warnings
3
3
 
4
4
  from onnxslim.cli import slim
5
- from onnxslim.core.optimization import OptimizationSettings
6
5
  from onnxslim.core.pattern.registry import (
7
- DEFAULT_FUSION_PATTERNS,
8
6
  register_fusion_pattern,
9
7
  )
10
8
  from onnxslim.version import __version__
onnxslim/argparser.py CHANGED
@@ -2,10 +2,28 @@ import argparse
2
2
  import dataclasses
3
3
  from argparse import ArgumentDefaultsHelpFormatter, ArgumentParser
4
4
  from dataclasses import dataclass, field
5
- from typing import List, Optional, Type, Union, get_args, get_origin
6
-
7
- import onnxslim
8
-
5
+ from typing import List, Optional, Type, Union, get_args, get_origin, TypedDict, Dict, Literal
6
+
7
+ from .core.optimization import OptimizationSettings
8
+ from .core.pattern.registry import DEFAULT_FUSION_PATTERNS
9
+ from .version import __version__
10
+
11
+
12
+ class OnnxSlimKwargs(TypedDict, total=False):
13
+ model_check: bool
14
+ input_shapes: Dict[str, List[int]]
15
+ inputs: List[str]
16
+ outputs: List[str]
17
+ no_shape_infer: bool
18
+ skip_optimizations: List[str]
19
+ dtype: Literal["float16", "float32", "uint8", "int8"]
20
+ skip_fusion_patterns: List[str]
21
+ size_threshold: int
22
+ inspect: bool
23
+ dump_to_disk: bool
24
+ save_as_external_data: bool
25
+ model_check_inputs: Optional[List[str]]
26
+ verbose: bool
9
27
 
10
28
  def _get_inner_type(arg_type):
11
29
  if get_origin(arg_type) is Union:
@@ -42,14 +60,14 @@ class OptimizationArguments:
42
60
  default=None,
43
61
  metadata={
44
62
  "help": "whether to skip some optimizations",
45
- "choices": list(onnxslim.OptimizationSettings.keys()),
63
+ "choices": list(OptimizationSettings.keys()),
46
64
  },
47
65
  )
48
66
  skip_fusion_patterns: Optional[List[str]] = field(
49
67
  default=None,
50
68
  metadata={
51
69
  "help": "whether to skip the fusion of some patterns",
52
- "choices": list(onnxslim.DEFAULT_FUSION_PATTERNS.keys()),
70
+ "choices": list(DEFAULT_FUSION_PATTERNS.keys()),
53
71
  },
54
72
  )
55
73
  size_threshold: int = field(
@@ -173,7 +191,7 @@ class OnnxSlimArgumentParser(ArgumentParser):
173
191
  # Add positional arguments separately for ModelArguments
174
192
  self.parser.add_argument("input_model", help="input onnx model")
175
193
  self.parser.add_argument("output_model", nargs="?", default=None, help="output onnx model")
176
- self.parser.add_argument("-v", "--version", action="version", version=onnxslim.__version__)
194
+ self.parser.add_argument("-v", "--version", action="version", version=__version__)
177
195
 
178
196
  def parse_args_into_dataclasses(self):
179
197
  # Pre-parse arguments to check for `--inspect`
onnxslim/cli/_main.py CHANGED
@@ -2,8 +2,10 @@ from __future__ import annotations
2
2
 
3
3
  import onnx
4
4
 
5
+ from onnxslim.argparser import OnnxSlimKwargs
5
6
 
6
- def slim(model: str | onnx.ModelProto | list[str | onnx.ModelProto], *args, **kwargs):
7
+
8
+ def slim(model: str | onnx.ModelProto | list[str | onnx.ModelProto], *args, **kwargs: OnnxSlimKwargs):
7
9
  import os
8
10
  import time
9
11
  from pathlib import Path
@@ -57,10 +57,7 @@ def optimize_model(model: onnx.ModelProto | gs.Graph, skip_fusion_patterns: str
57
57
  if OptimizationSettings.graph_fusion:
58
58
  logger.debug("Start graph_fusion.")
59
59
  fusion_patterns = get_fusion_patterns(skip_fusion_patterns)
60
- fusion_pairs = find_matches(graph, fusion_patterns)
61
- for match in fusion_pairs.values():
62
- graph.replace_custom_layer(**match)
63
- graph.cleanup(remove_unused_graph_inputs=True).toposort()
60
+ graph_fusion(graph, fusion_patterns)
64
61
  logger.debug("Finish graph_fusion.")
65
62
  if OptimizationSettings.dead_node_elimination:
66
63
  logger.debug("Start dead_node_elimination.")
@@ -102,9 +99,21 @@ def replace_custom_layer(
102
99
  )
103
100
 
104
101
 
102
+ def graph_fusion(graph: Graph, fusion_patterns: dict, is_subgraph=False):
103
+ for subgraph in graph.subgraphs():
104
+ graph_fusion(subgraph, fusion_patterns, is_subgraph=True)
105
+
106
+ fusion_pairs = find_matches(graph, fusion_patterns)
107
+ for match in fusion_pairs.values():
108
+ graph.replace_custom_layer(**match)
109
+
110
+ graph.cleanup(remove_unused_graph_inputs=True if not is_subgraph else False).toposort()
111
+
112
+
105
113
  def find_matches(graph: Graph, fusion_patterns: dict):
106
114
  """Find matching patterns in the graph based on provided fusion patterns."""
107
115
  match_map = {}
116
+
108
117
  counter = Counter()
109
118
  for node in reversed(graph.nodes):
110
119
  if node.name not in match_map:
@@ -63,6 +63,9 @@ def subexpression_elimination(graph):
63
63
  """Perform subexpression elimination on a computational graph to optimize node operations."""
64
64
  nodes_by_op = {}
65
65
 
66
+ for subgraph in graph.subgraphs():
67
+ subexpression_elimination(subgraph)
68
+
66
69
  for node in graph.nodes:
67
70
  op = node.op
68
71
  if op not in nodes_by_op:
@@ -530,7 +530,7 @@ class SymbolicShapeInference:
530
530
  # (2) opset version >= 9. In older version, initializer is required in graph input by onnx spec.
531
531
  # (3) The initializer is not in graph input. The means the node input is "constant" in inference.
532
532
  initializers = []
533
- if (get_opset(self.out_mp_) >= 9) and node.op_type == "Unsqueeze":
533
+ if (get_opset(self.out_mp_) >= 9) and (node.op_type == "Unsqueeze" or node.op_type == "ReduceMax"):
534
534
  initializers = [
535
535
  self.initializers_[name]
536
536
  for name in node.input
@@ -1042,6 +1042,9 @@ class SymbolicShapeInference:
1042
1042
  def _infer_Constant(self, node): # noqa: N802
1043
1043
  """Infer the constant value for a given node and store it in sympy_data_."""
1044
1044
  t = get_attribute(node, "value")
1045
+ # Lower constant nodes to initializers
1046
+ t.name = node.output[0]
1047
+ self.initializers_[node.output[0]] = t
1045
1048
  self.sympy_data_[node.output[0]] = numpy_helper.to_array(t)
1046
1049
 
1047
1050
  def _infer_ConstantOfShape(self, node): # noqa: N802
@@ -1944,7 +1947,25 @@ class SymbolicShapeInference:
1944
1947
 
1945
1948
  def _infer_Shape(self, node): # noqa: N802
1946
1949
  """Infers and sets the symbolic shape for the output node in the computation graph."""
1947
- self.sympy_data_[node.output[0]] = self._get_sympy_shape(node, 0)
1950
+ start = get_attribute(node, "start", 0)
1951
+ end = get_attribute(node, "end", None)
1952
+
1953
+ full_sympy_shape = self._get_sympy_shape(node, 0)
1954
+ num_dims = len(full_sympy_shape)
1955
+
1956
+ if start < 0:
1957
+ start = num_dims + start
1958
+ if end is None:
1959
+ end = num_dims
1960
+ elif end < 0:
1961
+ end = num_dims + end
1962
+
1963
+ assert 0 <= start <= end <= num_dims, (
1964
+ f"reshape start/end invalid: start={start}, end={end}, total_dims={num_dims}"
1965
+ )
1966
+
1967
+ target_sympy_shape = full_sympy_shape[start:end]
1968
+ self.sympy_data_[node.output[0]] = target_sympy_shape
1948
1969
 
1949
1970
  def _infer_Size(self, node): # noqa: N802
1950
1971
  """Infers and sets the size of the output node by computing the product of its shape in the computation
onnxslim/utils.py CHANGED
@@ -599,8 +599,8 @@ def get_itemsize(dtype):
599
599
  ]:
600
600
  return 1
601
601
 
602
- print(dtype)
603
- raise
602
+ print(f"Unknown ONNX dtype: {dtype}")
603
+ raise ValueError(f"Unsupported TensorProto dtype: {dtype}")
604
604
 
605
605
 
606
606
  def calculate_tensor_size(tensor):
onnxslim/version.py CHANGED
@@ -1 +1 @@
1
- __version__ = "0.1.69"
1
+ __version__ = "0.1.71"
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.4
2
2
  Name: onnxslim
3
- Version: 0.1.69
3
+ Version: 0.1.71
4
4
  Summary: OnnxSlim: A Toolkit to Help Optimize Onnx Model
5
5
  Home-page: https://github.com/inisis/OnnxSlim
6
6
  Author: inisis
@@ -1,14 +1,14 @@
1
- onnxslim/__init__.py,sha256=-zmz0A4bd_t8VgNXZVMizIpKJswIlC99hr8juWYm2kc,660
1
+ onnxslim/__init__.py,sha256=ECHGdxzg4b-4SZaPhxM_KulBi-xDbVcVUbpJc8i6a60,571
2
2
  onnxslim/__main__.py,sha256=FgDcl6xX8kV_52rB-jPVsmGqidlVhkpe_YhXK75-nFU,75
3
- onnxslim/argparser.py,sha256=aLVOV7c_HmsvuuAeb1FMRyqXTPRtjwm-bGt4O-7AxMg,8350
4
- onnxslim/utils.py,sha256=uNh3Mr4DberXtKL6p649QrrLR7gYbldYuIm6QLo14Zg,28043
5
- onnxslim/version.py,sha256=3nwo1wJ5elZRHnCuCkc-pqzqRZO_38VMOqiKtCotf8M,23
3
+ onnxslim/argparser.py,sha256=pFv3nEZH2BiHO9ejS4Iq5ZuZ3GrpdyRQJypAyR0xF7w,8942
4
+ onnxslim/utils.py,sha256=1CDFOOvZVl6pj1K6eR44IagxbM44UBzZGTbcTgW_ovw,28122
5
+ onnxslim/version.py,sha256=ebu6Nblu_UmpciwG6xnbUJm-16F-ZA6L-sagDt37smo,23
6
6
  onnxslim/cli/__init__.py,sha256=kxK27cDgWotBOWRs86rbRQf_dtmniKr1GZJeasxfESE,42
7
- onnxslim/cli/_main.py,sha256=CPitq8khtOLCrJxvmam61ONppACfjQUzRshZtyabZak,5487
7
+ onnxslim/cli/_main.py,sha256=jEKv9q7y_y9g0GsrfXcnk_wyMVej6jhe9QNPChE7yTs,5550
8
8
  onnxslim/core/__init__.py,sha256=d6zfD4rcR51yJIFg1LpGQt2tK4g1WLmA87_7Gna89OU,8469
9
- onnxslim/core/optimization/__init__.py,sha256=cAzv2YoZMS2krWoylrqDd9HU0aAhSCAM29dZclKqOQo,4777
9
+ onnxslim/core/optimization/__init__.py,sha256=M72h_Jyj9j0bAWI5GRt5MkAnXledYWvH-y9f9FG36Nw,5020
10
10
  onnxslim/core/optimization/dead_node_elimination.py,sha256=xAO37JvW7Yb3mT5cXY6K6cWMAvbmFL0pczpcRQfd5Ek,7557
11
- onnxslim/core/optimization/subexpression_elimination.py,sha256=g_AypWGzW60QaHMqxFZ9elCo8CxqrshzUwtVSPhpgDA,2754
11
+ onnxslim/core/optimization/subexpression_elimination.py,sha256=wauPADo0msVifYGgJ_fiNmPSe7LrJ7A-D3ur5z1wif8,2838
12
12
  onnxslim/core/optimization/weight_tying.py,sha256=uAz_2AJa-s-budTcDXIr6ZcME1hXPdm54CzR-u850X4,2347
13
13
  onnxslim/core/pattern/__init__.py,sha256=mIqdX7JC7oiypOBbaCC5bG3km7x6MIvqrz4koXzRX1U,9434
14
14
  onnxslim/core/pattern/registry.py,sha256=Dv7FqS9E-zLfvWCTh682PNRyTFcZoYiuH2MnfjMNpec,831
@@ -29,7 +29,7 @@ onnxslim/core/pattern/fusion/reduce.py,sha256=dMC7CPlFglrJxugsJWjcc-jQCIa_GIbW1y
29
29
  onnxslim/misc/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
30
30
  onnxslim/misc/tabulate.py,sha256=pytOhVcWPW6B464BvG0rycbz8QA0MAeCSqQ7EO93tqs,99648
31
31
  onnxslim/third_party/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
32
- onnxslim/third_party/symbolic_shape_infer.py,sha256=bAyxkpF0ahHvdThvE9Lpbs70pqV6jAivfCcKQHlBTeE,153036
32
+ onnxslim/third_party/symbolic_shape_infer.py,sha256=mxpOiv0QYEWZ0Pm8NW7ABRv2X-L_luBRCCKDri0JlmA,153751
33
33
  onnxslim/third_party/_sympy/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
34
34
  onnxslim/third_party/_sympy/functions.py,sha256=FrrTHPkDhwh_EccLGugbcodeOjfGeM17T5_oVPUjOlA,8877
35
35
  onnxslim/third_party/_sympy/numbers.py,sha256=w1dJJcQkKRzLDCJMn70YTwPtrEWPFrhCJpAsZhukJOk,11383
@@ -55,10 +55,10 @@ onnxslim/third_party/onnx_graphsurgeon/logger/logger.py,sha256=L12rrwn33RHH-2WLv
55
55
  onnxslim/third_party/onnx_graphsurgeon/util/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
56
56
  onnxslim/third_party/onnx_graphsurgeon/util/exception.py,sha256=KrsHbKEQ4237UbjlODsUzvkXoAY72LZi23ApBeFANWg,786
57
57
  onnxslim/third_party/onnx_graphsurgeon/util/misc.py,sha256=kyxInD2SCRLU4wHMeiDEYEHB3871fGks6kQTuF9uATY,8960
58
- onnxslim-0.1.69.dist-info/licenses/LICENSE,sha256=oHZXw-yrBwdNVGu4JtlZhMgmQHKIZ7BJJlJdhu1HKvI,1062
59
- onnxslim-0.1.69.dist-info/METADATA,sha256=NMqsOJCdOhIfLwnZnOmPSlUw7PV1Vu8YhIoG1zyR-10,7621
60
- onnxslim-0.1.69.dist-info/WHEEL,sha256=_zCd3N1l69ArxyTb8rzEoP9TpbYXkqRFSNOD5OuxnTs,91
61
- onnxslim-0.1.69.dist-info/entry_points.txt,sha256=O2QgceCVeGeRhnxRSDRcGiFd0ZNfElwrTiRo1W2V7KA,47
62
- onnxslim-0.1.69.dist-info/top_level.txt,sha256=EWFTb99i0kc6cC9akqNKp88ipzg17_VZzYN7z1kQNlA,9
63
- onnxslim-0.1.69.dist-info/zip-safe,sha256=AbpHGcgLb-kRsJGnwFEktk7uzpZOCcBY74-YBdrKVGs,1
64
- onnxslim-0.1.69.dist-info/RECORD,,
58
+ onnxslim-0.1.71.dist-info/licenses/LICENSE,sha256=oHZXw-yrBwdNVGu4JtlZhMgmQHKIZ7BJJlJdhu1HKvI,1062
59
+ onnxslim-0.1.71.dist-info/METADATA,sha256=pYpwf2phxrDay3yTJceIS_MedaWzMTJxhGfdi9hbWcE,7621
60
+ onnxslim-0.1.71.dist-info/WHEEL,sha256=_zCd3N1l69ArxyTb8rzEoP9TpbYXkqRFSNOD5OuxnTs,91
61
+ onnxslim-0.1.71.dist-info/entry_points.txt,sha256=O2QgceCVeGeRhnxRSDRcGiFd0ZNfElwrTiRo1W2V7KA,47
62
+ onnxslim-0.1.71.dist-info/top_level.txt,sha256=EWFTb99i0kc6cC9akqNKp88ipzg17_VZzYN7z1kQNlA,9
63
+ onnxslim-0.1.71.dist-info/zip-safe,sha256=AbpHGcgLb-kRsJGnwFEktk7uzpZOCcBY74-YBdrKVGs,1
64
+ onnxslim-0.1.71.dist-info/RECORD,,