onnxslim-0.1.80-py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- onnxslim/__init__.py +16 -0
- onnxslim/__main__.py +4 -0
- onnxslim/argparser.py +215 -0
- onnxslim/cli/__init__.py +1 -0
- onnxslim/cli/_main.py +180 -0
- onnxslim/core/__init__.py +219 -0
- onnxslim/core/optimization/__init__.py +146 -0
- onnxslim/core/optimization/dead_node_elimination.py +151 -0
- onnxslim/core/optimization/subexpression_elimination.py +76 -0
- onnxslim/core/optimization/weight_tying.py +59 -0
- onnxslim/core/pattern/__init__.py +249 -0
- onnxslim/core/pattern/elimination/__init__.py +5 -0
- onnxslim/core/pattern/elimination/concat.py +61 -0
- onnxslim/core/pattern/elimination/reshape.py +77 -0
- onnxslim/core/pattern/elimination/reshape_as.py +64 -0
- onnxslim/core/pattern/elimination/slice.py +108 -0
- onnxslim/core/pattern/elimination/unsqueeze.py +92 -0
- onnxslim/core/pattern/fusion/__init__.py +8 -0
- onnxslim/core/pattern/fusion/concat_reshape.py +50 -0
- onnxslim/core/pattern/fusion/convadd.py +70 -0
- onnxslim/core/pattern/fusion/convbn.py +86 -0
- onnxslim/core/pattern/fusion/convmul.py +69 -0
- onnxslim/core/pattern/fusion/gelu.py +47 -0
- onnxslim/core/pattern/fusion/gemm.py +330 -0
- onnxslim/core/pattern/fusion/padconv.py +89 -0
- onnxslim/core/pattern/fusion/reduce.py +67 -0
- onnxslim/core/pattern/registry.py +28 -0
- onnxslim/misc/__init__.py +0 -0
- onnxslim/misc/tabulate.py +2681 -0
- onnxslim/third_party/__init__.py +0 -0
- onnxslim/third_party/_sympy/__init__.py +0 -0
- onnxslim/third_party/_sympy/functions.py +205 -0
- onnxslim/third_party/_sympy/numbers.py +397 -0
- onnxslim/third_party/_sympy/printers.py +491 -0
- onnxslim/third_party/_sympy/solve.py +172 -0
- onnxslim/third_party/_sympy/symbol.py +102 -0
- onnxslim/third_party/onnx_graphsurgeon/__init__.py +15 -0
- onnxslim/third_party/onnx_graphsurgeon/exporters/__init__.py +1 -0
- onnxslim/third_party/onnx_graphsurgeon/exporters/base_exporter.py +33 -0
- onnxslim/third_party/onnx_graphsurgeon/exporters/onnx_exporter.py +432 -0
- onnxslim/third_party/onnx_graphsurgeon/graph_pattern/__init__.py +4 -0
- onnxslim/third_party/onnx_graphsurgeon/graph_pattern/graph_pattern.py +466 -0
- onnxslim/third_party/onnx_graphsurgeon/importers/__init__.py +1 -0
- onnxslim/third_party/onnx_graphsurgeon/importers/base_importer.py +33 -0
- onnxslim/third_party/onnx_graphsurgeon/importers/onnx_importer.py +558 -0
- onnxslim/third_party/onnx_graphsurgeon/ir/__init__.py +0 -0
- onnxslim/third_party/onnx_graphsurgeon/ir/function.py +274 -0
- onnxslim/third_party/onnx_graphsurgeon/ir/graph.py +1575 -0
- onnxslim/third_party/onnx_graphsurgeon/ir/node.py +266 -0
- onnxslim/third_party/onnx_graphsurgeon/ir/tensor.py +504 -0
- onnxslim/third_party/onnx_graphsurgeon/logger/__init__.py +1 -0
- onnxslim/third_party/onnx_graphsurgeon/logger/logger.py +261 -0
- onnxslim/third_party/onnx_graphsurgeon/util/__init__.py +0 -0
- onnxslim/third_party/onnx_graphsurgeon/util/exception.py +20 -0
- onnxslim/third_party/onnx_graphsurgeon/util/misc.py +252 -0
- onnxslim/third_party/symbolic_shape_infer.py +3273 -0
- onnxslim/utils.py +794 -0
- onnxslim/version.py +1 -0
- onnxslim-0.1.80.dist-info/METADATA +207 -0
- onnxslim-0.1.80.dist-info/RECORD +65 -0
- onnxslim-0.1.80.dist-info/WHEEL +5 -0
- onnxslim-0.1.80.dist-info/entry_points.txt +2 -0
- onnxslim-0.1.80.dist-info/licenses/LICENSE +21 -0
- onnxslim-0.1.80.dist-info/top_level.txt +1 -0
- onnxslim-0.1.80.dist-info/zip-safe +1 -0
onnxslim/__init__.py
ADDED
@@ -0,0 +1,16 @@
+import os
+import warnings
+
+from onnxslim.cli import slim
+from onnxslim.core.pattern.registry import (
+    register_fusion_pattern,
+)
+from onnxslim.version import __version__
+
+if os.path.dirname(os.path.realpath(__file__)) == os.path.join(os.path.realpath(os.getcwd()), "onnxslim"):
+    message = (
+        "You are importing onnxslim within its own root folder ({}). "
+        "This is not expected to work and may give errors. Please exit the "
+        "onnxslim project source and relaunch your python interpreter."
+    )
+    warnings.warn(message.format(os.getcwd()))
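
Note: the package root above re-exports the public API. A minimal usage sketch based only on these exports (the model paths are hypothetical placeholders, not part of the diff):

    import onnxslim

    print(onnxslim.__version__)

    # With no output path, slim() returns the slimmed onnx.ModelProto;
    # with one, it also writes the result to disk (see onnxslim/cli/_main.py below).
    slimmed = onnxslim.slim("model.onnx")
    onnxslim.slim("model.onnx", "model_slim.onnx")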
onnxslim/__main__.py
ADDED
onnxslim/argparser.py
ADDED
@@ -0,0 +1,215 @@
+import argparse
+import dataclasses
+from argparse import ArgumentDefaultsHelpFormatter, ArgumentParser
+from dataclasses import dataclass, field
+from typing import List, Optional, Type, Union, get_args, get_origin, TypedDict, Dict, Literal
+
+from .core.optimization import OptimizationSettings
+from .core.pattern.registry import DEFAULT_FUSION_PATTERNS
+from .version import __version__
+
+
+class OnnxSlimKwargs(TypedDict, total=False):
+    model_check: bool
+    input_shapes: Dict[str, List[int]]
+    inputs: List[str]
+    outputs: List[str]
+    no_shape_infer: bool
+    skip_optimizations: List[str]
+    dtype: Literal["float16", "float32", "uint8", "int8"]
+    skip_fusion_patterns: List[str]
+    size_threshold: int
+    inspect: bool
+    dump_to_disk: bool
+    save_as_external_data: bool
+    model_check_inputs: Optional[List[str]]
+    verbose: bool
+
+def _get_inner_type(arg_type):
+    if get_origin(arg_type) is Union:
+        return next((t for t in get_args(arg_type) if t is not type(None)), str)
+    return arg_type
+
+
+@dataclass
+class ModelArguments:
+    """
+    Args:
+        model (Union[str, onnx.ModelProto]): The ONNX model to be slimmed. It can be either a file path or an `onnx.ModelProto` object.
+
+        output_model (str, optional): File path to save the slimmed model. If None, the model will not be saved.
+    """
+
+    input_model: str = field(metadata={"help": "input onnx model"})
+    output_model: Optional[str] = field(default=None, metadata={"help": "output onnx model"})
+
+
+@dataclass
+class OptimizationArguments:
+    """
+    Args:
+        no_shape_infer (bool, optional): Flag indicating whether to perform shape inference. Default is False.
+
+        no_constant_folding (bool, optional): Flag indicating whether to perform constant folding. Default is False.
+
+        skip_fusion_patterns (str, optional): String representing fusion patterns to skip. Default is None.
+    """
+
+    no_shape_infer: bool = field(default=False, metadata={"help": "whether to disable shape_infer, default false."})
+    skip_optimizations: Optional[List[str]] = field(
+        default=None,
+        metadata={
+            "help": "whether to skip some optimizations",
+            "choices": list(OptimizationSettings.keys()),
+        },
+    )
+    skip_fusion_patterns: Optional[List[str]] = field(
+        default=None,
+        metadata={
+            "help": "whether to skip the fusion of some patterns",
+            "choices": list(DEFAULT_FUSION_PATTERNS.keys()),
+        },
+    )
+    size_threshold: int = field(
+        default=None,
+        metadata={
+            "help": "size threshold in bytes, size larger than this value will not be folded, default None, which means fold all constants",
+        },
+    )
+
+
+@dataclass
+class ModificationArguments:
+    """
+    Args:
+        input_shapes (str, optional): String representing the input shapes. Default is None.
+
+        outputs (str, optional): String representing the outputs. Default is None.
+
+        dtype (str, optional): Data type. Default is None.
+
+        save_as_external_data (bool, optional): Flag indicating whether to split onnx as model and weight. Default is False.
+    """
+
+    input_shapes: Optional[List[str]] = field(
+        default=None,
+        metadata={
+            "help": "input shape of the model, INPUT_NAME:SHAPE, e.g. x:1,3,224,224 or x1:1,3,224,224 x2:1,3,224,224"
+        },
+    )
+    inputs: Optional[List[str]] = field(
+        default=None,
+        metadata={
+            "help": "input of the model, INPUT_NAME:DTYPE, e.g. y:fp32 or y1:fp32 y2:fp32. If dtype is not specified, the dtype of the input will be the same as the original model if it has dtype, otherwise it will be fp32, available dtype: fp16, fp32, int32"
+        },
+    )
+    outputs: Optional[List[str]] = field(
+        default=None,
+        metadata={
+            "help": "output of the model, OUTPUT_NAME:DTYPE, e.g. y:fp32 or y1:fp32 y2:fp32. If dtype is not specified, the dtype of the output will be the same as the original model if it has dtype, otherwise it will be fp32, available dtype: fp16, fp32, int32"
+        },
+    )
+    dtype: Optional[str] = field(
+        default=None, metadata={"help": "convert data format to fp16 or fp32.", "choices": ["fp16", "fp32"]}
+    )
+    save_as_external_data: bool = field(
+        default=False, metadata={"help": "split onnx as model and weight, default False."}
+    )
+
+
+@dataclass
+class CheckerArguments:
+    """
+    Args:
+        model_check (bool, optional): Flag indicating whether to perform model checking. Default is False.
+
+        model_check_inputs (str, optional): The shape or tensor used for model check. Default is None.
+
+        inspect (bool, optional): Flag indicating whether to inspect the model. Default is False.
+
+        dump_to_disk (bool, optional): Flag indicating whether to dump the model detail to disk. Default is False.
+
+        verbose (bool, optional): Flag indicating whether to print verbose logs. Default is False.
+    """
+
+    model_check: bool = field(default=False, metadata={"help": "enable model check"})
+    model_check_inputs: Optional[List[str]] = field(
+        default=None,
+        metadata={
+            "help": "Works only when model_check is enabled, Input shape of the model or numpy data path, INPUT_NAME:SHAPE or INPUT_NAME:DATAPATH, e.g. x:1,3,224,224 or x1:1,3,224,224 x2:data.npy. Useful when input shapes are dynamic."
+        },
+    )
+    inspect: bool = field(default=False, metadata={"help": "inspect model, default False."})
+    dump_to_disk: bool = field(default=False, metadata={"help": "dump model info to disk, default False."})
+    verbose: bool = field(default=False, metadata={"help": "verbose mode, default False."})
+
+
+class OnnxSlimArgumentParser(ArgumentParser):
+    def __init__(self, *argument_dataclasses: Type, **kwargs):
+        if "formatter_class" not in kwargs:
+            kwargs["formatter_class"] = ArgumentDefaultsHelpFormatter
+        super().__init__(**kwargs)
+        self.argument_dataclasses = argument_dataclasses
+        self.parser = argparse.ArgumentParser(
+            description="OnnxSlim: A Toolkit to Help Optimizer Onnx Model",
+            formatter_class=argparse.RawDescriptionHelpFormatter,
+        )
+        self._add_arguments()
+
+    def _add_arguments(self):
+        for dataclass_type in self.argument_dataclasses:
+            if dataclass_type is ModelArguments:
+                continue
+            for field_name, field_def in dataclass_type.__dataclass_fields__.items():
+                arg_type = _get_inner_type(field_def.type)
+                default_value = field_def.default if field_def.default is not field_def.default_factory else None
+                help_text = field_def.metadata.get("help", "")
+                nargs = "+" if get_origin(arg_type) == list else None
+                choices = field_def.metadata.get("choices", None)
+                if choices and default_value is not None and default_value not in choices:
+                    raise ValueError(
+                        f"Invalid default value '{default_value}' for argument '{field_name}'. Must be one of {choices}."
+                    )
+                arg_type = get_args(arg_type)[0] if get_args(arg_type) else arg_type
+                if arg_type == bool:
+                    self.parser.add_argument(
+                        f"--{field_name.replace('_', '-')}",
+                        action="store_true",
+                        default=default_value,
+                        help=help_text,
+                    )
+                else:
+                    self.parser.add_argument(
+                        f"--{field_name.replace('_', '-')}",
+                        type=arg_type,
+                        default=default_value,
+                        nargs=nargs,
+                        choices=choices,
+                        help=help_text,
+                    )
+
+        # Add positional arguments separately for ModelArguments
+        self.parser.add_argument("input_model", help="input onnx model")
+        self.parser.add_argument("output_model", nargs="?", default=None, help="output onnx model")
+        self.parser.add_argument("-v", "--version", action="version", version=__version__)
+
+    def parse_args_into_dataclasses(self):
+        # Pre-parse arguments to check for `--inspect`
+        pre_parsed_args, _ = self.parser.parse_known_args()
+        if pre_parsed_args.inspect:
+            for action in self.parser._actions:
+                if action.dest == "input_model":
+                    action.nargs = "+"
+                    break
+
+        args = self.parser.parse_args()
+        args_dict = vars(args)
+
+        outputs = []
+        for dtype in self.argument_dataclasses:
+            keys = {f.name for f in dataclasses.fields(dtype) if f.init}
+            inputs = {k: v for k, v in args_dict.items() if k in keys}
+            obj = dtype(**inputs)
+            outputs.append(obj)
+
+        return (*outputs,)
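
Note: the parser above maps every dataclass field to a --flag (underscores become hyphens, bool fields become store_true switches) and then splits the parsed namespace back into one instance per dataclass. A sketch of driving it directly, mirroring main() in onnxslim/cli/_main.py below; the argv values are illustrative:

    import sys

    from onnxslim.argparser import (
        CheckerArguments,
        ModelArguments,
        ModificationArguments,
        OnnxSlimArgumentParser,
        OptimizationArguments,
    )

    # Simulate a command line: positional input/output models plus one bool flag.
    sys.argv = ["onnxslim", "model.onnx", "model_slim.onnx", "--no-shape-infer"]

    parser = OnnxSlimArgumentParser(
        ModelArguments, OptimizationArguments, ModificationArguments, CheckerArguments
    )
    model_args, opt_args, mod_args, checker_args = parser.parse_args_into_dataclasses()
    print(model_args.input_model)   # model.onnx
    print(opt_args.no_shape_infer)  # True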
onnxslim/cli/__init__.py
ADDED
@@ -0,0 +1 @@
+from onnxslim.cli._main import main, slim
onnxslim/cli/_main.py
ADDED
@@ -0,0 +1,180 @@
+from __future__ import annotations
+
+import onnx
+
+from onnxslim.argparser import OnnxSlimKwargs
+
+
+def slim(model: str | onnx.ModelProto | list[str | onnx.ModelProto], *args, **kwargs: OnnxSlimKwargs):
+    import os
+    import time
+    from pathlib import Path
+
+    from onnxslim.core import (
+        OptimizationSettings,
+        convert_data_format,
+        freeze,
+        input_modification,
+        input_shape_modification,
+        optimize,
+        output_modification,
+        shape_infer,
+    )
+    from onnxslim.utils import (
+        TensorInfo,
+        check_onnx,
+        check_point,
+        check_result,
+        dump_model_info_to_disk,
+        init_logging,
+        onnxruntime_inference,
+        print_model_info_as_table,
+        save,
+        summarize_model,
+        update_outputs_dims,
+    )
+
+    output_model = args[0] if len(args) > 0 else kwargs.get("output_model", None)
+    model_check = kwargs.get("model_check", False)
+    input_shapes = kwargs.get("input_shapes", None)
+    inputs = kwargs.get("inputs", None)
+    outputs = kwargs.get("outputs", None)
+    no_shape_infer = kwargs.get("no_shape_infer", False)
+    skip_optimizations = kwargs.get("skip_optimizations", None)
+    dtype = kwargs.get("dtype", None)
+    skip_fusion_patterns = kwargs.get("skip_fusion_patterns", None)
+    size_threshold = kwargs.get("size_threshold", None)
+    size_threshold = int(size_threshold) if size_threshold else None
+    kwargs.get("inspect", False)
+    dump_to_disk = kwargs.get("dump_to_disk", False)
+    save_as_external_data = kwargs.get("save_as_external_data", False)
+    model_check_inputs = kwargs.get("model_check_inputs", None)
+    verbose = kwargs.get("verbose", False)
+
+    logger = init_logging(verbose)
+
+    MAX_ITER = int(os.getenv("ONNXSLIM_MAX_ITER")) if os.getenv("ONNXSLIM_MAX_ITER") else 10
+
+    start_time = time.time()
+
+    def get_info(model, inspect=False):
+        if isinstance(model, str):
+            model_name = Path(model).name
+            model = onnx.load(model)
+        else:
+            model_name = "OnnxModel"
+
+        freeze(model)
+
+        if not inspect:
+            return model_name, model
+
+        model_info = summarize_model(model, model_name)
+
+        return model_info
+
+    if isinstance(model, list):
+        model_info_list = [get_info(m, inspect=True) for m in model]
+
+        if dump_to_disk:
+            [dump_model_info_to_disk(info) for info in model_info_list]
+
+        print_model_info_as_table(model_info_list)
+
+        return
+    else:
+        model_name, model = get_info(model)
+        if output_model:
+            original_info = summarize_model(model, model_name)
+
+    if inputs:
+        model = input_modification(model, inputs)
+
+    if input_shapes:
+        model = input_shape_modification(model, input_shapes)
+
+    if outputs:
+        model = output_modification(model, outputs)
+
+    if model_check:
+        input_data_dict, raw_onnx_output, model = check_onnx(model, model_check_inputs)
+
+    output_info = {TensorInfo(o).name: TensorInfo(o).shape for o in model.graph.output}
+
+    if not no_shape_infer:
+        model = shape_infer(model)
+
+    OptimizationSettings.reset(skip_optimizations)
+    if OptimizationSettings.enabled():
+        graph_check_point = check_point(model)
+        while MAX_ITER > 0:
+            logger.debug(f"iter: {MAX_ITER}")
+            model = optimize(model, skip_fusion_patterns, size_threshold)
+            if not no_shape_infer:
+                model = shape_infer(model)
+            graph = check_point(model)
+            if graph == graph_check_point:
+                logger.debug(f"converged at iter: {MAX_ITER}")
+                break
+            else:
+                graph_check_point = graph
+
+            MAX_ITER -= 1
+
+    if dtype:
+        model = convert_data_format(model, dtype)
+
+    model = update_outputs_dims(model, output_dims=output_info)
+
+    if model_check:
+        slimmed_onnx_output, model = onnxruntime_inference(model, input_data_dict)
+        if not check_result(raw_onnx_output, slimmed_onnx_output):
+            return None
+
+    if not output_model:
+        return model
+
+    slimmed_info = summarize_model(model, output_model)
+    save(model, output_model, model_check, save_as_external_data, slimmed_info)
+
+    end_time = time.time()
+    elapsed_time = end_time - start_time
+    print_model_info_as_table(
+        [original_info, slimmed_info],
+        elapsed_time,
+    )
+
+
+def main():
+    """Entry point for the OnnxSlim toolkit, processes command-line arguments and passes them to the slim function."""
+    from onnxslim.argparser import (
+        CheckerArguments,
+        ModelArguments,
+        ModificationArguments,
+        OnnxSlimArgumentParser,
+        OptimizationArguments,
+    )
+
+    argument_parser = OnnxSlimArgumentParser(
+        ModelArguments, OptimizationArguments, ModificationArguments, CheckerArguments
+    )
+    model_args, optimization_args, modification_args, checker_args = argument_parser.parse_args_into_dataclasses()
+
+    if not checker_args.inspect and checker_args.dump_to_disk:
+        argument_parser.error("dump_to_disk can only be used with --inspect")
+
+    if not optimization_args.no_shape_infer:
+        from onnxslim.utils import check_onnx_compatibility, is_onnxruntime_available
+
+        if is_onnxruntime_available():
+            check_onnx_compatibility()
+
+    slim(
+        model_args.input_model,
+        model_args.output_model,
+        **optimization_args.__dict__,
+        **modification_args.__dict__,
+        **checker_args.__dict__,
+    )
+
+    return 0
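
Note: slim() above doubles as the programmatic API; its keyword arguments match OnnxSlimKwargs from onnxslim/argparser.py. A sketch (paths and shapes are placeholders; as the code above shows, input_shapes is consumed as a list of "NAME:dims" strings, matching the CLI help):

    from onnxslim import slim

    # With no output path, returns the slimmed onnx.ModelProto.
    model = slim("model.onnx")

    # Writes to disk, pins an input shape, and cross-checks outputs against
    # the original model via onnxruntime; skips saving and returns None on mismatch.
    slim(
        "model.onnx",
        "model_slim.onnx",
        input_shapes=["x:1,3,224,224"],
        model_check=True,
    )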
onnxslim/core/__init__.py
ADDED
@@ -0,0 +1,219 @@
+from __future__ import annotations
+
+import logging
+import os
+import tempfile
+from typing import Optional
+
+import numpy as np
+import onnx
+from onnx import checker
+
+import onnxslim.third_party.onnx_graphsurgeon as gs
+from onnxslim.core.optimization import OptimizationSettings, optimize_model
+from onnxslim.third_party.onnx_graphsurgeon.exporters.onnx_exporter import dtype_to_onnx
+from onnxslim.third_party.onnx_graphsurgeon.ir.tensor import Constant
+from onnxslim.third_party.symbolic_shape_infer import SymbolicShapeInference
+from onnxslim.utils import save
+
+logger = logging.getLogger("onnxslim")
+
+
+DEBUG = bool(os.getenv("ONNXSLIM_DEBUG"))
+AUTO_MERGE = True if os.getenv("ONNXSLIM_AUTO_MERGE") is None else bool(int(os.getenv("ONNXSLIM_AUTO_MERGE")))
+FORCE_ONNXRUNTIME_SHAPE_INFERENCE = bool(os.getenv("ONNXSLIM_FORCE_ONNXRUNTIME_SHAPE_INFERENCE"))
+
+
+def input_shape_modification(model: onnx.ModelProto, input_shapes: str) -> onnx.ModelProto:
+    """Modifies input tensor shapes in the ONNX model according to the specified input_shapes string."""
+    if not input_shapes:
+        return
+
+    graph = gs.import_onnx(model)
+    input_names = [input.name for input in graph.inputs]
+    tensors = graph.tensors()
+
+    for input_shape in input_shapes:
+        key, values = input_shape.rsplit(":", 1)
+        values_list = [int(value) for value in values.split(",")]
+        if key not in input_names:
+            raise Exception(f"Input name {key} not found in model, available keys: {' '.join(input_names)}")
+        tensors[key].shape = values_list
+
+    for tensor in tensors.values():
+        if tensor.name not in input_names:
+            if isinstance(tensor, Constant):
+                continue
+            tensor.shape = None
+
+    model = gs.export_onnx(graph)
+
+    return model
+
+
+def output_modification(model: onnx.ModelProto, outputs: str) -> onnx.ModelProto:
+    """Modifies the output layers of the ONNX model based on specified output names and data types."""
+    graph = gs.import_onnx(model)
+    graph.outputs.clear()
+    tensors = graph.tensors()
+    for output in outputs:
+        values = output.rsplit(":", 1)
+        if len(values) == 1:
+            key = values[0]
+            if key not in tensors.keys():
+                raise Exception(f"Output name {key} not found in model, available keys: {' '.join(tensors.keys())}")
+            dtype = tensors[key].dtype
+            if dtype is None:
+                dtype = np.float32
+                logger.warning(f"Output layer {key} has no dtype, set to default {dtype}")
+        else:
+            key, dtype = values
+            if dtype == "fp16":
+                dtype = np.float16
+            elif dtype == "fp32":
+                dtype = np.float32
+            elif dtype == "int32":
+                dtype = np.int32
+            elif dtype == "bool":
+                dtype = bool
+            else:
+                raise Exception(f"Output layer {key} assigned unsupported dtype {dtype}")
+
+        graph.outputs.append(tensors[key].to_variable(dtype=dtype, shape=tensors[key].shape))
+
+    graph.cleanup(remove_unused_graph_inputs=True).toposort()
+    model = gs.export_onnx(graph)
+
+    return model
+
+
+def input_modification(model: onnx.ModelProto, inputs: str) -> onnx.ModelProto:
+    """Modifies the input layers of the ONNX model based on specified input names and data types."""
+    graph = gs.import_onnx(model)
+    graph.inputs.clear()
+    tensors = graph.tensors()
+    for input in inputs:
+        values = input.rsplit(":", 1)
+        if len(values) == 1:
+            key = values[0]
+            if key not in tensors.keys():
+                raise Exception(f"Input name {key} not found in model, available keys: {' '.join(tensors.keys())}")
+            dtype = tensors[key].dtype
+            if dtype is None:
+                dtype = np.float32
+                logger.warning(f"Input layer {key} has no dtype, set to default {dtype}")
+        else:
+            key, dtype = values
+            if dtype == "fp16":
+                dtype = np.float16
+            elif dtype == "fp32":
+                dtype = np.float32
+            elif dtype == "int32":
+                dtype = np.int32
+            elif dtype == "bool":
+                dtype = bool
+            else:
+                raise Exception(f"Input layer {key} assigned unsupported dtype {dtype}")
+
+        graph.inputs.append(tensors[key].to_variable(dtype=dtype, shape=tensors[key].shape))
+
+    graph.cleanup(remove_unused_graph_inputs=True).toposort()
+    model = gs.export_onnx(graph)
+
+    return model
+
+
+def shape_infer(model: onnx.ModelProto):
+    """Infer tensor shapes in an ONNX model using symbolic and static shape inference techniques."""
+    logger.debug("Start shape inference.")
+    if FORCE_ONNXRUNTIME_SHAPE_INFERENCE:
+        logger.debug("force onnxruntime shape infer.")
+        return SymbolicShapeInference.infer_shapes(model, auto_merge=AUTO_MERGE)
+    try:
+        logger.debug("try onnxruntime shape infer.")
+        model = SymbolicShapeInference.infer_shapes(model, auto_merge=AUTO_MERGE)
+    except Exception as err:
+        logger.debug(f"onnxruntime shape infer failed, try onnx shape infer. {err}")
+        if model.ByteSize() >= checker.MAXIMUM_PROTOBUF:
+            tmp_dir = tempfile.TemporaryDirectory()
+            tmp_path = os.path.join(tmp_dir.name, "tmp.onnx")
+            tmp_infer_path = os.path.join(tmp_dir.name, "tmp_infer.onnx")
+            save(model, tmp_path)
+            onnx.shape_inference.infer_shapes_path(tmp_path, tmp_infer_path)
+            model = onnx.load(tmp_infer_path)
+        else:
+            model = onnx.shape_inference.infer_shapes(model)
+    if DEBUG:
+        onnx.save(model, "debug_shape_infer.onnx")
+    logger.debug("Finish shape inference.")
+    return model
+
+
+def optimize(model: onnx.ModelProto, skip_fusion_patterns: str | None = None, size_threshold: int | None = None):
+    """Optimize the given ONNX model with options to skip specific fusion patterns and return the optimized model."""
+    logger.debug("Start converting model to gs.")
+    graph = gs.import_onnx(model).toposort()
+    logger.debug("Finish converting model to gs.")
+    if OptimizationSettings.constant_folding:
+        logger.debug("Start constant folding.")
+        graph.fold_constants(size_threshold=size_threshold).cleanup().toposort()
+        logger.debug("Finish constant folding.")
+    logger.debug("Start optimize model.")
+    model = optimize_model(graph, skip_fusion_patterns)
+    logger.debug("Finish optimize model.")
+    if DEBUG:
+        onnx.save(model, "debug_slim.onnx")
+
+    return model
+
+
+def convert_data_format(model: onnx.ModelProto, dtype: str) -> onnx.ModelProto:
+    """Convert ONNX model data format to specified dtype, supporting 'fp16' and 'fp32'."""
+    if dtype == "fp16":
+        from onnxconverter_common import float16
+
+        model = float16.convert_float_to_float16(model)
+    elif dtype == "fp32":
+        graph = gs.import_onnx(model).toposort()
+
+        for node in graph.nodes:
+            if node.op == "Cast":
+                inp_dtype = next(input.dtype for input in node.inputs)
+                if inp_dtype in [np.float16, np.float32]:
+                    node.erase()
+                else:
+                    outp_dtype = next(output.dtype for output in node.outputs)
+                    if outp_dtype == np.float16:
+                        node.attrs["to"] = dtype_to_onnx(np.float32)
+                        node.outputs[0].dtype = np.float32
+            elif node.op == "ConstantOfShape":
+                if hasattr(node, "attrs") and "value" in node.attrs:
+                    if node.attrs["value"].dtype == np.float16:
+                        node.attrs["value"].values = node.attrs["value"].values.astype(np.float32)
+                        node.outputs[0].dtype = np.float32
+
+        for tensor in graph.tensors().values():
+            if isinstance(tensor, gs.Variable) and tensor.dtype == np.float16:
+                tensor.dtype = np.float32
+            elif isinstance(tensor, gs.Constant) and tensor.dtype == np.float16:
+                tensor.values = tensor.values.astype(np.float32)
+
+        graph.cleanup(remove_unused_graph_inputs=True).toposort()
+        model = gs.export_onnx(graph)
+
+    return model
+
+
+def freeze(model: onnx.ModelProto):
+    """Freeze the input layers of an ONNX model by removing the initializers from the input graph."""
+    inputs = model.graph.input
+    name_to_input = {}
+    for input in inputs:
+        if input.name in name_to_input:
+            logger.warning(f"Duplicate input name: {input.name}")
+        name_to_input[input.name] = input
+
+    for initializer in model.graph.initializer:
+        if initializer.name in name_to_input:
+            inputs.remove(name_to_input[initializer.name])
+            name_to_input.pop(initializer.name)
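
Note: DEBUG, AUTO_MERGE, and FORCE_ONNXRUNTIME_SHAPE_INFERENCE above are read from the environment once, at import time, so they must be set before onnxslim.core is first imported; ONNXSLIM_MAX_ITER is read inside slim() in onnxslim/cli/_main.py. A sketch with hypothetical paths:

    import os

    os.environ["ONNXSLIM_DEBUG"] = "1"       # dump debug_shape_infer.onnx / debug_slim.onnx
    os.environ["ONNXSLIM_AUTO_MERGE"] = "0"  # pass auto_merge=False to SymbolicShapeInference
    os.environ["ONNXSLIM_MAX_ITER"] = "3"    # cap the optimize / shape-infer loop

    from onnxslim import slim

    slim("model.onnx", "model_slim.onnx")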