onnxslim-0.1.80-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (65)
  1. onnxslim/__init__.py +16 -0
  2. onnxslim/__main__.py +4 -0
  3. onnxslim/argparser.py +215 -0
  4. onnxslim/cli/__init__.py +1 -0
  5. onnxslim/cli/_main.py +180 -0
  6. onnxslim/core/__init__.py +219 -0
  7. onnxslim/core/optimization/__init__.py +146 -0
  8. onnxslim/core/optimization/dead_node_elimination.py +151 -0
  9. onnxslim/core/optimization/subexpression_elimination.py +76 -0
  10. onnxslim/core/optimization/weight_tying.py +59 -0
  11. onnxslim/core/pattern/__init__.py +249 -0
  12. onnxslim/core/pattern/elimination/__init__.py +5 -0
  13. onnxslim/core/pattern/elimination/concat.py +61 -0
  14. onnxslim/core/pattern/elimination/reshape.py +77 -0
  15. onnxslim/core/pattern/elimination/reshape_as.py +64 -0
  16. onnxslim/core/pattern/elimination/slice.py +108 -0
  17. onnxslim/core/pattern/elimination/unsqueeze.py +92 -0
  18. onnxslim/core/pattern/fusion/__init__.py +8 -0
  19. onnxslim/core/pattern/fusion/concat_reshape.py +50 -0
  20. onnxslim/core/pattern/fusion/convadd.py +70 -0
  21. onnxslim/core/pattern/fusion/convbn.py +86 -0
  22. onnxslim/core/pattern/fusion/convmul.py +69 -0
  23. onnxslim/core/pattern/fusion/gelu.py +47 -0
  24. onnxslim/core/pattern/fusion/gemm.py +330 -0
  25. onnxslim/core/pattern/fusion/padconv.py +89 -0
  26. onnxslim/core/pattern/fusion/reduce.py +67 -0
  27. onnxslim/core/pattern/registry.py +28 -0
  28. onnxslim/misc/__init__.py +0 -0
  29. onnxslim/misc/tabulate.py +2681 -0
  30. onnxslim/third_party/__init__.py +0 -0
  31. onnxslim/third_party/_sympy/__init__.py +0 -0
  32. onnxslim/third_party/_sympy/functions.py +205 -0
  33. onnxslim/third_party/_sympy/numbers.py +397 -0
  34. onnxslim/third_party/_sympy/printers.py +491 -0
  35. onnxslim/third_party/_sympy/solve.py +172 -0
  36. onnxslim/third_party/_sympy/symbol.py +102 -0
  37. onnxslim/third_party/onnx_graphsurgeon/__init__.py +15 -0
  38. onnxslim/third_party/onnx_graphsurgeon/exporters/__init__.py +1 -0
  39. onnxslim/third_party/onnx_graphsurgeon/exporters/base_exporter.py +33 -0
  40. onnxslim/third_party/onnx_graphsurgeon/exporters/onnx_exporter.py +432 -0
  41. onnxslim/third_party/onnx_graphsurgeon/graph_pattern/__init__.py +4 -0
  42. onnxslim/third_party/onnx_graphsurgeon/graph_pattern/graph_pattern.py +466 -0
  43. onnxslim/third_party/onnx_graphsurgeon/importers/__init__.py +1 -0
  44. onnxslim/third_party/onnx_graphsurgeon/importers/base_importer.py +33 -0
  45. onnxslim/third_party/onnx_graphsurgeon/importers/onnx_importer.py +558 -0
  46. onnxslim/third_party/onnx_graphsurgeon/ir/__init__.py +0 -0
  47. onnxslim/third_party/onnx_graphsurgeon/ir/function.py +274 -0
  48. onnxslim/third_party/onnx_graphsurgeon/ir/graph.py +1575 -0
  49. onnxslim/third_party/onnx_graphsurgeon/ir/node.py +266 -0
  50. onnxslim/third_party/onnx_graphsurgeon/ir/tensor.py +504 -0
  51. onnxslim/third_party/onnx_graphsurgeon/logger/__init__.py +1 -0
  52. onnxslim/third_party/onnx_graphsurgeon/logger/logger.py +261 -0
  53. onnxslim/third_party/onnx_graphsurgeon/util/__init__.py +0 -0
  54. onnxslim/third_party/onnx_graphsurgeon/util/exception.py +20 -0
  55. onnxslim/third_party/onnx_graphsurgeon/util/misc.py +252 -0
  56. onnxslim/third_party/symbolic_shape_infer.py +3273 -0
  57. onnxslim/utils.py +794 -0
  58. onnxslim/version.py +1 -0
  59. onnxslim-0.1.80.dist-info/METADATA +207 -0
  60. onnxslim-0.1.80.dist-info/RECORD +65 -0
  61. onnxslim-0.1.80.dist-info/WHEEL +5 -0
  62. onnxslim-0.1.80.dist-info/entry_points.txt +2 -0
  63. onnxslim-0.1.80.dist-info/licenses/LICENSE +21 -0
  64. onnxslim-0.1.80.dist-info/top_level.txt +1 -0
  65. onnxslim-0.1.80.dist-info/zip-safe +1 -0
onnxslim/__init__.py ADDED
@@ -0,0 +1,16 @@
+ import os
+ import warnings
+
+ from onnxslim.cli import slim
+ from onnxslim.core.pattern.registry import (
+     register_fusion_pattern,
+ )
+ from onnxslim.version import __version__
+
+ if os.path.dirname(os.path.realpath(__file__)) == os.path.join(os.path.realpath(os.getcwd()), "onnxslim"):
+     message = (
+         "You are importing onnxslim within its own root folder ({}). "
+         "This is not expected to work and may give errors. Please exit the "
+         "onnxslim project source and relaunch your python interpreter."
+     )
+     warnings.warn(message.format(os.getcwd()))
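
Since `slim` is re-exported at the package root, the programmatic entry point is a one-liner; a minimal usage sketch (the model paths are hypothetical):

```python
# Minimal usage sketch of the top-level API re-exported above; paths are hypothetical.
import onnx
import onnxslim

# Path in, path out: loads "model.onnx", optimizes it, and saves "model_slim.onnx".
onnxslim.slim("model.onnx", "model_slim.onnx")

# ModelProto in, ModelProto out: without an output path, slim() returns the model.
model = onnx.load("model.onnx")
slimmed = onnxslim.slim(model)
```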
onnxslim/__main__.py ADDED
@@ -0,0 +1,4 @@
+ from onnxslim.cli._main import main
+
+ if __name__ == "__main__":
+     main()
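
Defining `__main__.py` makes the package runnable as a module, so `python -m onnxslim model.onnx model_slim.onnx` behaves like the console script declared in `entry_points.txt`.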
onnxslim/argparser.py ADDED
@@ -0,0 +1,215 @@
+ import argparse
+ import dataclasses
+ from argparse import ArgumentDefaultsHelpFormatter, ArgumentParser
+ from dataclasses import dataclass, field
+ from typing import List, Optional, Type, Union, get_args, get_origin, TypedDict, Dict, Literal
+
+ from .core.optimization import OptimizationSettings
+ from .core.pattern.registry import DEFAULT_FUSION_PATTERNS
+ from .version import __version__
+
+
+ class OnnxSlimKwargs(TypedDict, total=False):
+     model_check: bool
+     input_shapes: Dict[str, List[int]]
+     inputs: List[str]
+     outputs: List[str]
+     no_shape_infer: bool
+     skip_optimizations: List[str]
+     dtype: Literal["float16", "float32", "uint8", "int8"]
+     skip_fusion_patterns: List[str]
+     size_threshold: int
+     inspect: bool
+     dump_to_disk: bool
+     save_as_external_data: bool
+     model_check_inputs: Optional[List[str]]
+     verbose: bool
+
+ def _get_inner_type(arg_type):
+     if get_origin(arg_type) is Union:
+         return next((t for t in get_args(arg_type) if t is not type(None)), str)
+     return arg_type
+
+
+ @dataclass
+ class ModelArguments:
+     """
+     Args:
+         model (Union[str, onnx.ModelProto]): The ONNX model to be slimmed. It can be either a file path or an `onnx.ModelProto` object.
+
+         output_model (str, optional): File path to save the slimmed model. If None, the model will not be saved.
+     """
+
+     input_model: str = field(metadata={"help": "input onnx model"})
+     output_model: Optional[str] = field(default=None, metadata={"help": "output onnx model"})
+
+
+ @dataclass
+ class OptimizationArguments:
+     """
+     Args:
+         no_shape_infer (bool, optional): Flag indicating whether to perform shape inference. Default is False.
+
+         no_constant_folding (bool, optional): Flag indicating whether to perform constant folding. Default is False.
+
+         skip_fusion_patterns (str, optional): String representing fusion patterns to skip. Default is None.
+     """
+
+     no_shape_infer: bool = field(default=False, metadata={"help": "whether to disable shape_infer, default false."})
+     skip_optimizations: Optional[List[str]] = field(
+         default=None,
+         metadata={
+             "help": "whether to skip some optimizations",
+             "choices": list(OptimizationSettings.keys()),
+         },
+     )
+     skip_fusion_patterns: Optional[List[str]] = field(
+         default=None,
+         metadata={
+             "help": "whether to skip the fusion of some patterns",
+             "choices": list(DEFAULT_FUSION_PATTERNS.keys()),
+         },
+     )
+     size_threshold: int = field(
+         default=None,
+         metadata={
+             "help": "size threshold in bytes, size larger than this value will not be folded, default None, which means fold all constants",
+         },
+     )
+
+
+ @dataclass
+ class ModificationArguments:
+     """
+     Args:
+         input_shapes (str, optional): String representing the input shapes. Default is None.
+
+         outputs (str, optional): String representing the outputs. Default is None.
+
+         dtype (str, optional): Data type. Default is None.
+
+         save_as_external_data (bool, optional): Flag indicating whether to split onnx as model and weight. Default is False.
+     """
+
+     input_shapes: Optional[List[str]] = field(
+         default=None,
+         metadata={
+             "help": "input shape of the model, INPUT_NAME:SHAPE, e.g. x:1,3,224,224 or x1:1,3,224,224 x2:1,3,224,224"
+         },
+     )
+     inputs: Optional[List[str]] = field(
+         default=None,
+         metadata={
+             "help": "input of the model, INPUT_NAME:DTYPE, e.g. y:fp32 or y1:fp32 y2:fp32. If dtype is not specified, the dtype of the input will be the same as the original model if it has dtype, otherwise it will be fp32, available dtype: fp16, fp32, int32"
+         },
+     )
+     outputs: Optional[List[str]] = field(
+         default=None,
+         metadata={
+             "help": "output of the model, OUTPUT_NAME:DTYPE, e.g. y:fp32 or y1:fp32 y2:fp32. If dtype is not specified, the dtype of the output will be the same as the original model if it has dtype, otherwise it will be fp32, available dtype: fp16, fp32, int32"
+         },
+     )
+     dtype: Optional[str] = field(
+         default=None, metadata={"help": "convert data format to fp16 or fp32.", "choices": ["fp16", "fp32"]}
+     )
+     save_as_external_data: bool = field(
+         default=False, metadata={"help": "split onnx as model and weight, default False."}
+     )
+
+
+ @dataclass
+ class CheckerArguments:
+     """
+     Args:
+         model_check (bool, optional): Flag indicating whether to perform model checking. Default is False.
+
+         model_check_inputs (str, optional): The shape or tensor used for model check. Default is None.
+
+         inspect (bool, optional): Flag indicating whether to inspect the model. Default is False.
+
+         dump_to_disk (bool, optional): Flag indicating whether to dump the model detail to disk. Default is False.
+
+         verbose (bool, optional): Flag indicating whether to print verbose logs. Default is False.
+     """
+
+     model_check: bool = field(default=False, metadata={"help": "enable model check"})
+     model_check_inputs: Optional[List[str]] = field(
+         default=None,
+         metadata={
+             "help": "Works only when model_check is enabled, Input shape of the model or numpy data path, INPUT_NAME:SHAPE or INPUT_NAME:DATAPATH, e.g. x:1,3,224,224 or x1:1,3,224,224 x2:data.npy. Useful when input shapes are dynamic."
+         },
+     )
+     inspect: bool = field(default=False, metadata={"help": "inspect model, default False."})
+     dump_to_disk: bool = field(default=False, metadata={"help": "dump model info to disk, default False."})
+     verbose: bool = field(default=False, metadata={"help": "verbose mode, default False."})
+
+
+ class OnnxSlimArgumentParser(ArgumentParser):
+     def __init__(self, *argument_dataclasses: Type, **kwargs):
+         if "formatter_class" not in kwargs:
+             kwargs["formatter_class"] = ArgumentDefaultsHelpFormatter
+         super().__init__(**kwargs)
+         self.argument_dataclasses = argument_dataclasses
+         self.parser = argparse.ArgumentParser(
+             description="OnnxSlim: A Toolkit to Help Optimizer Onnx Model",
+             formatter_class=argparse.RawDescriptionHelpFormatter,
+         )
+         self._add_arguments()
+
+     def _add_arguments(self):
+         for dataclass_type in self.argument_dataclasses:
+             if dataclass_type is ModelArguments:
+                 continue
+             for field_name, field_def in dataclass_type.__dataclass_fields__.items():
+                 arg_type = _get_inner_type(field_def.type)
+                 default_value = field_def.default if field_def.default is not field_def.default_factory else None
+                 help_text = field_def.metadata.get("help", "")
+                 nargs = "+" if get_origin(arg_type) == list else None
+                 choices = field_def.metadata.get("choices", None)
+                 if choices and default_value is not None and default_value not in choices:
+                     raise ValueError(
+                         f"Invalid default value '{default_value}' for argument '{field_name}'. Must be one of {choices}."
+                     )
+                 arg_type = get_args(arg_type)[0] if get_args(arg_type) else arg_type
+                 if arg_type == bool:
+                     self.parser.add_argument(
+                         f"--{field_name.replace('_', '-')}",
+                         action="store_true",
+                         default=default_value,
+                         help=help_text,
+                     )
+                 else:
+                     self.parser.add_argument(
+                         f"--{field_name.replace('_', '-')}",
+                         type=arg_type,
+                         default=default_value,
+                         nargs=nargs,
+                         choices=choices,
+                         help=help_text,
+                     )
+
+         # Add positional arguments separately for ModelArguments
+         self.parser.add_argument("input_model", help="input onnx model")
+         self.parser.add_argument("output_model", nargs="?", default=None, help="output onnx model")
+         self.parser.add_argument("-v", "--version", action="version", version=__version__)
+
+     def parse_args_into_dataclasses(self):
+         # Pre-parse arguments to check for `--inspect`
+         pre_parsed_args, _ = self.parser.parse_known_args()
+         if pre_parsed_args.inspect:
+             for action in self.parser._actions:
+                 if action.dest == "input_model":
+                     action.nargs = "+"
+                     break
+
+         args = self.parser.parse_args()
+         args_dict = vars(args)
+
+         outputs = []
+         for dtype in self.argument_dataclasses:
+             keys = {f.name for f in dataclasses.fields(dtype) if f.init}
+             inputs = {k: v for k, v in args_dict.items() if k in keys}
+             obj = dtype(**inputs)
+             outputs.append(obj)
+
+         return (*outputs,)
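
Taken together, each dataclass field becomes a `--flag` (underscores turn into hyphens) while `input_model` and `output_model` stay positional. A sketch of the wiring, mirroring how `main()` in `onnxslim/cli/_main.py` uses this parser:

```python
# Sketch mirroring main() in onnxslim/cli/_main.py: build the parser from the
# argument dataclasses, then split the parsed args back into dataclass instances.
from onnxslim.argparser import (
    CheckerArguments,
    ModelArguments,
    ModificationArguments,
    OnnxSlimArgumentParser,
    OptimizationArguments,
)

parser = OnnxSlimArgumentParser(
    ModelArguments, OptimizationArguments, ModificationArguments, CheckerArguments
)
# e.g. argv: model.onnx model_slim.onnx --no-shape-infer --dtype fp16
model_args, opt_args, mod_args, checker_args = parser.parse_args_into_dataclasses()
```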
onnxslim/cli/__init__.py ADDED
@@ -0,0 +1 @@
+ from onnxslim.cli._main import main, slim
onnxslim/cli/_main.py ADDED
@@ -0,0 +1,180 @@
+ from __future__ import annotations
+
+ import onnx
+
+ from onnxslim.argparser import OnnxSlimKwargs
+
+
+ def slim(model: str | onnx.ModelProto | list[str | onnx.ModelProto], *args, **kwargs: OnnxSlimKwargs):
+     import os
+     import time
+     from pathlib import Path
+
+     from onnxslim.core import (
+         OptimizationSettings,
+         convert_data_format,
+         freeze,
+         input_modification,
+         input_shape_modification,
+         optimize,
+         output_modification,
+         shape_infer,
+     )
+     from onnxslim.utils import (
+         TensorInfo,
+         check_onnx,
+         check_point,
+         check_result,
+         dump_model_info_to_disk,
+         init_logging,
+         onnxruntime_inference,
+         print_model_info_as_table,
+         save,
+         summarize_model,
+         update_outputs_dims,
+     )
+
+     output_model = args[0] if len(args) > 0 else kwargs.get("output_model", None)
+     model_check = kwargs.get("model_check", False)
+     input_shapes = kwargs.get("input_shapes", None)
+     inputs = kwargs.get("inputs", None)
+     outputs = kwargs.get("outputs", None)
+     no_shape_infer = kwargs.get("no_shape_infer", False)
+     skip_optimizations = kwargs.get("skip_optimizations", None)
+     dtype = kwargs.get("dtype", None)
+     skip_fusion_patterns = kwargs.get("skip_fusion_patterns", None)
+     size_threshold = kwargs.get("size_threshold", None)
+     size_threshold = int(size_threshold) if size_threshold else None
+     kwargs.get("inspect", False)
+     dump_to_disk = kwargs.get("dump_to_disk", False)
+     save_as_external_data = kwargs.get("save_as_external_data", False)
+     model_check_inputs = kwargs.get("model_check_inputs", None)
+     verbose = kwargs.get("verbose", False)
+
+     logger = init_logging(verbose)
+
+     MAX_ITER = int(os.getenv("ONNXSLIM_MAX_ITER")) if os.getenv("ONNXSLIM_MAX_ITER") else 10
+
+     start_time = time.time()
+
+     def get_info(model, inspect=False):
+         if isinstance(model, str):
+             model_name = Path(model).name
+             model = onnx.load(model)
+         else:
+             model_name = "OnnxModel"
+
+         freeze(model)
+
+         if not inspect:
+             return model_name, model
+
+         model_info = summarize_model(model, model_name)
+
+         return model_info
+
+     if isinstance(model, list):
+         model_info_list = [get_info(m, inspect=True) for m in model]
+
+         if dump_to_disk:
+             [dump_model_info_to_disk(info) for info in model_info_list]
+
+         print_model_info_as_table(model_info_list)
+
+         return
+     else:
+         model_name, model = get_info(model)
+         if output_model:
+             original_info = summarize_model(model, model_name)
+
+     if inputs:
+         model = input_modification(model, inputs)
+
+     if input_shapes:
+         model = input_shape_modification(model, input_shapes)
+
+     if outputs:
+         model = output_modification(model, outputs)
+
+     if model_check:
+         input_data_dict, raw_onnx_output, model = check_onnx(model, model_check_inputs)
+
+     output_info = {TensorInfo(o).name: TensorInfo(o).shape for o in model.graph.output}
+
+     if not no_shape_infer:
+         model = shape_infer(model)
+
+     OptimizationSettings.reset(skip_optimizations)
+     if OptimizationSettings.enabled():
+         graph_check_point = check_point(model)
+         while MAX_ITER > 0:
+             logger.debug(f"iter: {MAX_ITER}")
+             model = optimize(model, skip_fusion_patterns, size_threshold)
+             if not no_shape_infer:
+                 model = shape_infer(model)
+             graph = check_point(model)
+             if graph == graph_check_point:
+                 logger.debug(f"converged at iter: {MAX_ITER}")
+                 break
+             else:
+                 graph_check_point = graph
+
+             MAX_ITER -= 1
+
+     if dtype:
+         model = convert_data_format(model, dtype)
+
+     model = update_outputs_dims(model, output_dims=output_info)
+
+     if model_check:
+         slimmed_onnx_output, model = onnxruntime_inference(model, input_data_dict)
+         if not check_result(raw_onnx_output, slimmed_onnx_output):
+             return None
+
+     if not output_model:
+         return model
+
+     slimmed_info = summarize_model(model, output_model)
+     save(model, output_model, model_check, save_as_external_data, slimmed_info)
+
+     end_time = time.time()
+     elapsed_time = end_time - start_time
+     print_model_info_as_table(
+         [original_info, slimmed_info],
+         elapsed_time,
+     )
+
+
+ def main():
+     """Entry point for the OnnxSlim toolkit, processes command-line arguments and passes them to the slim function."""
+     from onnxslim.argparser import (
+         CheckerArguments,
+         ModelArguments,
+         ModificationArguments,
+         OnnxSlimArgumentParser,
+         OptimizationArguments,
+     )
+
+     argument_parser = OnnxSlimArgumentParser(
+         ModelArguments, OptimizationArguments, ModificationArguments, CheckerArguments
+     )
+     model_args, optimization_args, modification_args, checker_args = argument_parser.parse_args_into_dataclasses()
+
+     if not checker_args.inspect and checker_args.dump_to_disk:
+         argument_parser.error("dump_to_disk can only be used with --inspect")
+
+     if not optimization_args.no_shape_infer:
+         from onnxslim.utils import check_onnx_compatibility, is_onnxruntime_available
+
+         if is_onnxruntime_available():
+             check_onnx_compatibility()
+
+     slim(
+         model_args.input_model,
+         model_args.output_model,
+         **optimization_args.__dict__,
+         **modification_args.__dict__,
+         **checker_args.__dict__,
+     )
+
+     return 0
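
The keyword surface of `slim()` matches `OnnxSlimKwargs` from `argparser.py`; a hedged sketch of a keyword-style call (the values are illustrative only):

```python
# Illustrative keyword-style call; every key below maps to a kwargs.get() in slim().
from onnxslim import slim

slim(
    "model.onnx",
    output_model="model_slim.onnx",
    model_check=True,            # run the slimmed model against the original via onnxruntime
    skip_fusion_patterns=None,   # or a list of pattern names from DEFAULT_FUSION_PATTERNS
    size_threshold=1024 * 1024,  # constants larger than this many bytes are not folded
    verbose=False,
)
```

Note that the optimization loop is bounded by the `ONNXSLIM_MAX_ITER` environment variable (default 10) and exits early once the graph checkpoint stops changing between passes.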
onnxslim/core/__init__.py ADDED
@@ -0,0 +1,219 @@
+ from __future__ import annotations
+
+ import logging
+ import os
+ import tempfile
+ from typing import Optional
+
+ import numpy as np
+ import onnx
+ from onnx import checker
+
+ import onnxslim.third_party.onnx_graphsurgeon as gs
+ from onnxslim.core.optimization import OptimizationSettings, optimize_model
+ from onnxslim.third_party.onnx_graphsurgeon.exporters.onnx_exporter import dtype_to_onnx
+ from onnxslim.third_party.onnx_graphsurgeon.ir.tensor import Constant
+ from onnxslim.third_party.symbolic_shape_infer import SymbolicShapeInference
+ from onnxslim.utils import save
+
+ logger = logging.getLogger("onnxslim")
+
+
+ DEBUG = bool(os.getenv("ONNXSLIM_DEBUG"))
+ AUTO_MERGE = True if os.getenv("ONNXSLIM_AUTO_MERGE") is None else bool(int(os.getenv("ONNXSLIM_AUTO_MERGE")))
+ FORCE_ONNXRUNTIME_SHAPE_INFERENCE = bool(os.getenv("ONNXSLIM_FORCE_ONNXRUNTIME_SHAPE_INFERENCE"))
+
+
+ def input_shape_modification(model: onnx.ModelProto, input_shapes: str) -> onnx.ModelProto:
+     """Modifies input tensor shapes in the ONNX model according to the specified input_shapes string."""
+     if not input_shapes:
+         return
+
+     graph = gs.import_onnx(model)
+     input_names = [input.name for input in graph.inputs]
+     tensors = graph.tensors()
+
+     for input_shape in input_shapes:
+         key, values = input_shape.rsplit(":", 1)
+         values_list = [int(value) for value in values.split(",")]
+         if key not in input_names:
+             raise Exception(f"Input name {key} not found in model, available keys: {' '.join(input_names)}")
+         tensors[key].shape = values_list
+
+     for tensor in tensors.values():
+         if tensor.name not in input_names:
+             if isinstance(tensor, Constant):
+                 continue
+             tensor.shape = None
+
+     model = gs.export_onnx(graph)
+
+     return model
+
+
+ def output_modification(model: onnx.ModelProto, outputs: str) -> onnx.ModelProto:
+     """Modifies the output layers of the ONNX model based on specified output names and data types."""
+     graph = gs.import_onnx(model)
+     graph.outputs.clear()
+     tensors = graph.tensors()
+     for output in outputs:
+         values = output.rsplit(":", 1)
+         if len(values) == 1:
+             key = values[0]
+             if key not in tensors.keys():
+                 raise Exception(f"Output name {key} not found in model, available keys: {' '.join(tensors.keys())}")
+             dtype = tensors[key].dtype
+             if dtype is None:
+                 dtype = np.float32
+                 logger.warning(f"Output layer {key} has no dtype, set to default {dtype}")
+         else:
+             key, dtype = values
+             if dtype == "fp16":
+                 dtype = np.float16
+             elif dtype == "fp32":
+                 dtype = np.float32
+             elif dtype == "int32":
+                 dtype = np.int32
+             elif dtype == "bool":
+                 dtype = bool
+             else:
+                 raise Exception(f"Output layer {key} assigned unsupported dtype {dtype}")
+
+         graph.outputs.append(tensors[key].to_variable(dtype=dtype, shape=tensors[key].shape))
+
+     graph.cleanup(remove_unused_graph_inputs=True).toposort()
+     model = gs.export_onnx(graph)
+
+     return model
+
+
+ def input_modification(model: onnx.ModelProto, inputs: str) -> onnx.ModelProto:
+     """Modifies the input layers of the ONNX model based on specified input names and data types."""
+     graph = gs.import_onnx(model)
+     graph.inputs.clear()
+     tensors = graph.tensors()
+     for input in inputs:
+         values = input.rsplit(":", 1)
+         if len(values) == 1:
+             key = values[0]
+             if key not in tensors.keys():
+                 raise Exception(f"Input name {key} not found in model, available keys: {' '.join(tensors.keys())}")
+             dtype = tensors[key].dtype
+             if dtype is None:
+                 dtype = np.float32
+                 logger.warning(f"Input layer {key} has no dtype, set to default {dtype}")
+         else:
+             key, dtype = values
+             if dtype == "fp16":
+                 dtype = np.float16
+             elif dtype == "fp32":
+                 dtype = np.float32
+             elif dtype == "int32":
+                 dtype = np.int32
+             elif dtype == "bool":
+                 dtype = bool
+             else:
+                 raise Exception(f"Input layer {key} assigned unsupported dtype {dtype}")
+
+         graph.inputs.append(tensors[key].to_variable(dtype=dtype, shape=tensors[key].shape))
+
+     graph.cleanup(remove_unused_graph_inputs=True).toposort()
+     model = gs.export_onnx(graph)
+
+     return model
+
+
+ def shape_infer(model: onnx.ModelProto):
+     """Infer tensor shapes in an ONNX model using symbolic and static shape inference techniques."""
+     logger.debug("Start shape inference.")
+     if FORCE_ONNXRUNTIME_SHAPE_INFERENCE:
+         logger.debug("force onnxruntime shape infer.")
+         return SymbolicShapeInference.infer_shapes(model, auto_merge=AUTO_MERGE)
+     try:
+         logger.debug("try onnxruntime shape infer.")
+         model = SymbolicShapeInference.infer_shapes(model, auto_merge=AUTO_MERGE)
+     except Exception as err:
+         logger.debug(f"onnxruntime shape infer failed, try onnx shape infer. {err}")
+         if model.ByteSize() >= checker.MAXIMUM_PROTOBUF:
+             tmp_dir = tempfile.TemporaryDirectory()
+             tmp_path = os.path.join(tmp_dir.name, "tmp.onnx")
+             tmp_infer_path = os.path.join(tmp_dir.name, "tmp_infer.onnx")
+             save(model, tmp_path)
+             onnx.shape_inference.infer_shapes_path(tmp_path, tmp_infer_path)
+             model = onnx.load(tmp_infer_path)
+         else:
+             model = onnx.shape_inference.infer_shapes(model)
+     if DEBUG:
+         onnx.save(model, "debug_shape_infer.onnx")
+     logger.debug("Finish shape inference.")
+     return model
+
+
+ def optimize(model: onnx.ModelProto, skip_fusion_patterns: str | None = None, size_threshold: int | None = None):
+     """Optimize the given ONNX model with options to skip specific fusion patterns and return the optimized model."""
+     logger.debug("Start converting model to gs.")
+     graph = gs.import_onnx(model).toposort()
+     logger.debug("Finish converting model to gs.")
+     if OptimizationSettings.constant_folding:
+         logger.debug("Start constant folding.")
+         graph.fold_constants(size_threshold=size_threshold).cleanup().toposort()
+         logger.debug("Finish constant folding.")
+     logger.debug("Start optimize model.")
+     model = optimize_model(graph, skip_fusion_patterns)
+     logger.debug("Finish optimize model.")
+     if DEBUG:
+         onnx.save(model, "debug_slim.onnx")
+
+     return model
+
+
+ def convert_data_format(model: onnx.ModelProto, dtype: str) -> onnx.ModelProto:
+     """Convert ONNX model data format to specified dtype, supporting 'fp16' and 'fp32'."""
+     if dtype == "fp16":
+         from onnxconverter_common import float16
+
+         model = float16.convert_float_to_float16(model)
+     elif dtype == "fp32":
+         graph = gs.import_onnx(model).toposort()
+
+         for node in graph.nodes:
+             if node.op == "Cast":
+                 inp_dtype = next(input.dtype for input in node.inputs)
+                 if inp_dtype in [np.float16, np.float32]:
+                     node.erase()
+                 else:
+                     outp_dtype = next(output.dtype for output in node.outputs)
+                     if outp_dtype == np.float16:
+                         node.attrs["to"] = dtype_to_onnx(np.float32)
+                         node.outputs[0].dtype = np.float32
+             elif node.op == "ConstantOfShape":
+                 if hasattr(node, "attrs") and "value" in node.attrs:
+                     if node.attrs["value"].dtype == np.float16:
+                         node.attrs["value"].values = node.attrs["value"].values.astype(np.float32)
+                         node.outputs[0].dtype = np.float32
+
+         for tensor in graph.tensors().values():
+             if isinstance(tensor, gs.Variable) and tensor.dtype == np.float16:
+                 tensor.dtype = np.float32
+             elif isinstance(tensor, gs.Constant) and tensor.dtype == np.float16:
+                 tensor.values = tensor.values.astype(np.float32)
+
+         graph.cleanup(remove_unused_graph_inputs=True).toposort()
+         model = gs.export_onnx(graph)
+
+     return model
+
+
+ def freeze(model: onnx.ModelProto):
+     """Freeze the input layers of an ONNX model by removing the initializers from the input graph."""
+     inputs = model.graph.input
+     name_to_input = {}
+     for input in inputs:
+         if input.name in name_to_input:
+             logger.warning(f"Duplicate input name: {input.name}")
+         name_to_input[input.name] = input
+
+     for initializer in model.graph.initializer:
+         if initializer.name in name_to_input:
+             inputs.remove(name_to_input[initializer.name])
+             name_to_input.pop(initializer.name)
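
`onnxslim/core/__init__.py` reads its debug and shape-inference toggles from the environment at import time, so they must be set before the module is imported; a sketch:

```python
# Sketch: environment toggles consumed at import time by onnxslim.core.
import os

os.environ["ONNXSLIM_DEBUG"] = "1"  # any non-empty value dumps debug_shape_infer.onnx / debug_slim.onnx
os.environ["ONNXSLIM_AUTO_MERGE"] = "0"  # parsed with int(): "0" disables auto_merge in symbolic shape inference
os.environ["ONNXSLIM_FORCE_ONNXRUNTIME_SHAPE_INFERENCE"] = "1"  # any non-empty value forces SymbolicShapeInference, with no onnx fallback

import onnxslim  # noqa: E402  (imported after setting the toggles on purpose)
```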