onnxslim 0.1.69__py3-none-any.whl → 0.1.71__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- onnxslim/__init__.py +0 -2
- onnxslim/argparser.py +25 -7
- onnxslim/cli/_main.py +3 -1
- onnxslim/core/optimization/__init__.py +13 -4
- onnxslim/core/optimization/subexpression_elimination.py +3 -0
- onnxslim/third_party/symbolic_shape_infer.py +23 -2
- onnxslim/utils.py +2 -2
- onnxslim/version.py +1 -1
- {onnxslim-0.1.69.dist-info → onnxslim-0.1.71.dist-info}/METADATA +1 -1
- {onnxslim-0.1.69.dist-info → onnxslim-0.1.71.dist-info}/RECORD +15 -15
- {onnxslim-0.1.69.dist-info → onnxslim-0.1.71.dist-info}/WHEEL +0 -0
- {onnxslim-0.1.69.dist-info → onnxslim-0.1.71.dist-info}/entry_points.txt +0 -0
- {onnxslim-0.1.69.dist-info → onnxslim-0.1.71.dist-info}/licenses/LICENSE +0 -0
- {onnxslim-0.1.69.dist-info → onnxslim-0.1.71.dist-info}/top_level.txt +0 -0
- {onnxslim-0.1.69.dist-info → onnxslim-0.1.71.dist-info}/zip-safe +0 -0
onnxslim/__init__.py
CHANGED
|
@@ -2,9 +2,7 @@ import os
|
|
|
2
2
|
import warnings
|
|
3
3
|
|
|
4
4
|
from onnxslim.cli import slim
|
|
5
|
-
from onnxslim.core.optimization import OptimizationSettings
|
|
6
5
|
from onnxslim.core.pattern.registry import (
|
|
7
|
-
DEFAULT_FUSION_PATTERNS,
|
|
8
6
|
register_fusion_pattern,
|
|
9
7
|
)
|
|
10
8
|
from onnxslim.version import __version__
|
onnxslim/argparser.py
CHANGED
|
@@ -2,10 +2,28 @@ import argparse
|
|
|
2
2
|
import dataclasses
|
|
3
3
|
from argparse import ArgumentDefaultsHelpFormatter, ArgumentParser
|
|
4
4
|
from dataclasses import dataclass, field
|
|
5
|
-
from typing import List, Optional, Type, Union, get_args, get_origin
|
|
6
|
-
|
|
7
|
-
import onnxslim
|
|
8
|
-
|
|
5
|
+
from typing import List, Optional, Type, Union, get_args, get_origin, TypedDict, Dict, Literal
|
|
6
|
+
|
|
7
|
+
from .core.optimization import OptimizationSettings
|
|
8
|
+
from .core.pattern.registry import DEFAULT_FUSION_PATTERNS
|
|
9
|
+
from .version import __version__
|
|
10
|
+
|
|
11
|
+
|
|
12
|
+
class OnnxSlimKwargs(TypedDict, total=False):
|
|
13
|
+
model_check: bool
|
|
14
|
+
input_shapes: Dict[str, List[int]]
|
|
15
|
+
inputs: List[str]
|
|
16
|
+
outputs: List[str]
|
|
17
|
+
no_shape_infer: bool
|
|
18
|
+
skip_optimizations: List[str]
|
|
19
|
+
dtype: Literal["float16", "float32", "uint8", "int8"]
|
|
20
|
+
skip_fusion_patterns: List[str]
|
|
21
|
+
size_threshold: int
|
|
22
|
+
inspect: bool
|
|
23
|
+
dump_to_disk: bool
|
|
24
|
+
save_as_external_data: bool
|
|
25
|
+
model_check_inputs: Optional[List[str]]
|
|
26
|
+
verbose: bool
|
|
9
27
|
|
|
10
28
|
def _get_inner_type(arg_type):
|
|
11
29
|
if get_origin(arg_type) is Union:
|
|
@@ -42,14 +60,14 @@ class OptimizationArguments:
|
|
|
42
60
|
default=None,
|
|
43
61
|
metadata={
|
|
44
62
|
"help": "whether to skip some optimizations",
|
|
45
|
-
"choices": list(onnxslim.OptimizationSettings.keys()),
|
|
63
|
+
"choices": list(OptimizationSettings.keys()),
|
|
46
64
|
},
|
|
47
65
|
)
|
|
48
66
|
skip_fusion_patterns: Optional[List[str]] = field(
|
|
49
67
|
default=None,
|
|
50
68
|
metadata={
|
|
51
69
|
"help": "whether to skip the fusion of some patterns",
|
|
52
|
-
"choices": list(onnxslim.DEFAULT_FUSION_PATTERNS.keys()),
|
|
70
|
+
"choices": list(DEFAULT_FUSION_PATTERNS.keys()),
|
|
53
71
|
},
|
|
54
72
|
)
|
|
55
73
|
size_threshold: int = field(
|
|
@@ -173,7 +191,7 @@ class OnnxSlimArgumentParser(ArgumentParser):
|
|
|
173
191
|
# Add positional arguments separately for ModelArguments
|
|
174
192
|
self.parser.add_argument("input_model", help="input onnx model")
|
|
175
193
|
self.parser.add_argument("output_model", nargs="?", default=None, help="output onnx model")
|
|
176
|
-
self.parser.add_argument("-v", "--version", action="version", version=onnxslim.__version__)
|
|
194
|
+
self.parser.add_argument("-v", "--version", action="version", version=__version__)
|
|
177
195
|
|
|
178
196
|
def parse_args_into_dataclasses(self):
|
|
179
197
|
# Pre-parse arguments to check for `--inspect`
|
onnxslim/cli/_main.py
CHANGED
|
@@ -2,8 +2,10 @@ from __future__ import annotations
|
|
|
2
2
|
|
|
3
3
|
import onnx
|
|
4
4
|
|
|
5
|
+
from onnxslim.argparser import OnnxSlimKwargs
|
|
5
6
|
|
|
6
|
-
|
|
7
|
+
|
|
8
|
+
def slim(model: str | onnx.ModelProto | list[str | onnx.ModelProto], *args, **kwargs: OnnxSlimKwargs):
|
|
7
9
|
import os
|
|
8
10
|
import time
|
|
9
11
|
from pathlib import Path
|
|
@@ -57,10 +57,7 @@ def optimize_model(model: onnx.ModelProto | gs.Graph, skip_fusion_patterns: str
|
|
|
57
57
|
if OptimizationSettings.graph_fusion:
|
|
58
58
|
logger.debug("Start graph_fusion.")
|
|
59
59
|
fusion_patterns = get_fusion_patterns(skip_fusion_patterns)
|
|
60
|
-
|
|
61
|
-
for match in fusion_pairs.values():
|
|
62
|
-
graph.replace_custom_layer(**match)
|
|
63
|
-
graph.cleanup(remove_unused_graph_inputs=True).toposort()
|
|
60
|
+
graph_fusion(graph, fusion_patterns)
|
|
64
61
|
logger.debug("Finish graph_fusion.")
|
|
65
62
|
if OptimizationSettings.dead_node_elimination:
|
|
66
63
|
logger.debug("Start dead_node_elimination.")
|
|
@@ -102,9 +99,21 @@ def replace_custom_layer(
|
|
|
102
99
|
)
|
|
103
100
|
|
|
104
101
|
|
|
102
|
+
def graph_fusion(graph: Graph, fusion_patterns: dict, is_subgraph=False):
|
|
103
|
+
for subgraph in graph.subgraphs():
|
|
104
|
+
graph_fusion(subgraph, fusion_patterns, is_subgraph=True)
|
|
105
|
+
|
|
106
|
+
fusion_pairs = find_matches(graph, fusion_patterns)
|
|
107
|
+
for match in fusion_pairs.values():
|
|
108
|
+
graph.replace_custom_layer(**match)
|
|
109
|
+
|
|
110
|
+
graph.cleanup(remove_unused_graph_inputs=True if not is_subgraph else False).toposort()
|
|
111
|
+
|
|
112
|
+
|
|
105
113
|
def find_matches(graph: Graph, fusion_patterns: dict):
|
|
106
114
|
"""Find matching patterns in the graph based on provided fusion patterns."""
|
|
107
115
|
match_map = {}
|
|
116
|
+
|
|
108
117
|
counter = Counter()
|
|
109
118
|
for node in reversed(graph.nodes):
|
|
110
119
|
if node.name not in match_map:
|
|
@@ -63,6 +63,9 @@ def subexpression_elimination(graph):
|
|
|
63
63
|
"""Perform subexpression elimination on a computational graph to optimize node operations."""
|
|
64
64
|
nodes_by_op = {}
|
|
65
65
|
|
|
66
|
+
for subgraph in graph.subgraphs():
|
|
67
|
+
subexpression_elimination(subgraph)
|
|
68
|
+
|
|
66
69
|
for node in graph.nodes:
|
|
67
70
|
op = node.op
|
|
68
71
|
if op not in nodes_by_op:
|
|
@@ -530,7 +530,7 @@ class SymbolicShapeInference:
|
|
|
530
530
|
# (2) opset version >= 9. In older version, initializer is required in graph input by onnx spec.
|
|
531
531
|
# (3) The initializer is not in graph input. The means the node input is "constant" in inference.
|
|
532
532
|
initializers = []
|
|
533
|
-
if (get_opset(self.out_mp_) >= 9) and node.op_type == "Unsqueeze":
|
|
533
|
+
if (get_opset(self.out_mp_) >= 9) and (node.op_type == "Unsqueeze" or node.op_type == "ReduceMax"):
|
|
534
534
|
initializers = [
|
|
535
535
|
self.initializers_[name]
|
|
536
536
|
for name in node.input
|
|
@@ -1042,6 +1042,9 @@ class SymbolicShapeInference:
|
|
|
1042
1042
|
def _infer_Constant(self, node): # noqa: N802
|
|
1043
1043
|
"""Infer the constant value for a given node and store it in sympy_data_."""
|
|
1044
1044
|
t = get_attribute(node, "value")
|
|
1045
|
+
# Lower constant nodes to initializers
|
|
1046
|
+
t.name = node.output[0]
|
|
1047
|
+
self.initializers_[node.output[0]] = t
|
|
1045
1048
|
self.sympy_data_[node.output[0]] = numpy_helper.to_array(t)
|
|
1046
1049
|
|
|
1047
1050
|
def _infer_ConstantOfShape(self, node): # noqa: N802
|
|
@@ -1944,7 +1947,25 @@ class SymbolicShapeInference:
|
|
|
1944
1947
|
|
|
1945
1948
|
def _infer_Shape(self, node): # noqa: N802
|
|
1946
1949
|
"""Infers and sets the symbolic shape for the output node in the computation graph."""
|
|
1947
|
-
|
|
1950
|
+
start = get_attribute(node, "start", 0)
|
|
1951
|
+
end = get_attribute(node, "end", None)
|
|
1952
|
+
|
|
1953
|
+
full_sympy_shape = self._get_sympy_shape(node, 0)
|
|
1954
|
+
num_dims = len(full_sympy_shape)
|
|
1955
|
+
|
|
1956
|
+
if start < 0:
|
|
1957
|
+
start = num_dims + start
|
|
1958
|
+
if end is None:
|
|
1959
|
+
end = num_dims
|
|
1960
|
+
elif end < 0:
|
|
1961
|
+
end = num_dims + end
|
|
1962
|
+
|
|
1963
|
+
assert 0 <= start <= end <= num_dims, (
|
|
1964
|
+
f"reshape start/end invalid: start={start}, end={end}, total_dims={num_dims}"
|
|
1965
|
+
)
|
|
1966
|
+
|
|
1967
|
+
target_sympy_shape = full_sympy_shape[start:end]
|
|
1968
|
+
self.sympy_data_[node.output[0]] = target_sympy_shape
|
|
1948
1969
|
|
|
1949
1970
|
def _infer_Size(self, node): # noqa: N802
|
|
1950
1971
|
"""Infers and sets the size of the output node by computing the product of its shape in the computation
|
onnxslim/utils.py
CHANGED
onnxslim/version.py
CHANGED
|
@@ -1 +1 @@
|
|
|
1
|
-
__version__ = "0.1.69"
|
|
1
|
+
__version__ = "0.1.71"
|
|
@@ -1,14 +1,14 @@
|
|
|
1
|
-
onnxslim/__init__.py,sha256
|
|
1
|
+
onnxslim/__init__.py,sha256=ECHGdxzg4b-4SZaPhxM_KulBi-xDbVcVUbpJc8i6a60,571
|
|
2
2
|
onnxslim/__main__.py,sha256=FgDcl6xX8kV_52rB-jPVsmGqidlVhkpe_YhXK75-nFU,75
|
|
3
|
-
onnxslim/argparser.py,sha256=
|
|
4
|
-
onnxslim/utils.py,sha256=
|
|
5
|
-
onnxslim/version.py,sha256=
|
|
3
|
+
onnxslim/argparser.py,sha256=pFv3nEZH2BiHO9ejS4Iq5ZuZ3GrpdyRQJypAyR0xF7w,8942
|
|
4
|
+
onnxslim/utils.py,sha256=1CDFOOvZVl6pj1K6eR44IagxbM44UBzZGTbcTgW_ovw,28122
|
|
5
|
+
onnxslim/version.py,sha256=ebu6Nblu_UmpciwG6xnbUJm-16F-ZA6L-sagDt37smo,23
|
|
6
6
|
onnxslim/cli/__init__.py,sha256=kxK27cDgWotBOWRs86rbRQf_dtmniKr1GZJeasxfESE,42
|
|
7
|
-
onnxslim/cli/_main.py,sha256=
|
|
7
|
+
onnxslim/cli/_main.py,sha256=jEKv9q7y_y9g0GsrfXcnk_wyMVej6jhe9QNPChE7yTs,5550
|
|
8
8
|
onnxslim/core/__init__.py,sha256=d6zfD4rcR51yJIFg1LpGQt2tK4g1WLmA87_7Gna89OU,8469
|
|
9
|
-
onnxslim/core/optimization/__init__.py,sha256=
|
|
9
|
+
onnxslim/core/optimization/__init__.py,sha256=M72h_Jyj9j0bAWI5GRt5MkAnXledYWvH-y9f9FG36Nw,5020
|
|
10
10
|
onnxslim/core/optimization/dead_node_elimination.py,sha256=xAO37JvW7Yb3mT5cXY6K6cWMAvbmFL0pczpcRQfd5Ek,7557
|
|
11
|
-
onnxslim/core/optimization/subexpression_elimination.py,sha256=
|
|
11
|
+
onnxslim/core/optimization/subexpression_elimination.py,sha256=wauPADo0msVifYGgJ_fiNmPSe7LrJ7A-D3ur5z1wif8,2838
|
|
12
12
|
onnxslim/core/optimization/weight_tying.py,sha256=uAz_2AJa-s-budTcDXIr6ZcME1hXPdm54CzR-u850X4,2347
|
|
13
13
|
onnxslim/core/pattern/__init__.py,sha256=mIqdX7JC7oiypOBbaCC5bG3km7x6MIvqrz4koXzRX1U,9434
|
|
14
14
|
onnxslim/core/pattern/registry.py,sha256=Dv7FqS9E-zLfvWCTh682PNRyTFcZoYiuH2MnfjMNpec,831
|
|
@@ -29,7 +29,7 @@ onnxslim/core/pattern/fusion/reduce.py,sha256=dMC7CPlFglrJxugsJWjcc-jQCIa_GIbW1y
|
|
|
29
29
|
onnxslim/misc/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
|
|
30
30
|
onnxslim/misc/tabulate.py,sha256=pytOhVcWPW6B464BvG0rycbz8QA0MAeCSqQ7EO93tqs,99648
|
|
31
31
|
onnxslim/third_party/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
|
|
32
|
-
onnxslim/third_party/symbolic_shape_infer.py,sha256=
|
|
32
|
+
onnxslim/third_party/symbolic_shape_infer.py,sha256=mxpOiv0QYEWZ0Pm8NW7ABRv2X-L_luBRCCKDri0JlmA,153751
|
|
33
33
|
onnxslim/third_party/_sympy/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
|
|
34
34
|
onnxslim/third_party/_sympy/functions.py,sha256=FrrTHPkDhwh_EccLGugbcodeOjfGeM17T5_oVPUjOlA,8877
|
|
35
35
|
onnxslim/third_party/_sympy/numbers.py,sha256=w1dJJcQkKRzLDCJMn70YTwPtrEWPFrhCJpAsZhukJOk,11383
|
|
@@ -55,10 +55,10 @@ onnxslim/third_party/onnx_graphsurgeon/logger/logger.py,sha256=L12rrwn33RHH-2WLv
|
|
|
55
55
|
onnxslim/third_party/onnx_graphsurgeon/util/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
|
|
56
56
|
onnxslim/third_party/onnx_graphsurgeon/util/exception.py,sha256=KrsHbKEQ4237UbjlODsUzvkXoAY72LZi23ApBeFANWg,786
|
|
57
57
|
onnxslim/third_party/onnx_graphsurgeon/util/misc.py,sha256=kyxInD2SCRLU4wHMeiDEYEHB3871fGks6kQTuF9uATY,8960
|
|
58
|
-
onnxslim-0.1.
|
|
59
|
-
onnxslim-0.1.
|
|
60
|
-
onnxslim-0.1.
|
|
61
|
-
onnxslim-0.1.
|
|
62
|
-
onnxslim-0.1.
|
|
63
|
-
onnxslim-0.1.
|
|
64
|
-
onnxslim-0.1.
|
|
58
|
+
onnxslim-0.1.71.dist-info/licenses/LICENSE,sha256=oHZXw-yrBwdNVGu4JtlZhMgmQHKIZ7BJJlJdhu1HKvI,1062
|
|
59
|
+
onnxslim-0.1.71.dist-info/METADATA,sha256=pYpwf2phxrDay3yTJceIS_MedaWzMTJxhGfdi9hbWcE,7621
|
|
60
|
+
onnxslim-0.1.71.dist-info/WHEEL,sha256=_zCd3N1l69ArxyTb8rzEoP9TpbYXkqRFSNOD5OuxnTs,91
|
|
61
|
+
onnxslim-0.1.71.dist-info/entry_points.txt,sha256=O2QgceCVeGeRhnxRSDRcGiFd0ZNfElwrTiRo1W2V7KA,47
|
|
62
|
+
onnxslim-0.1.71.dist-info/top_level.txt,sha256=EWFTb99i0kc6cC9akqNKp88ipzg17_VZzYN7z1kQNlA,9
|
|
63
|
+
onnxslim-0.1.71.dist-info/zip-safe,sha256=AbpHGcgLb-kRsJGnwFEktk7uzpZOCcBY74-YBdrKVGs,1
|
|
64
|
+
onnxslim-0.1.71.dist-info/RECORD,,
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|