onnxslim 0.1.80__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- onnxslim/__init__.py +16 -0
- onnxslim/__main__.py +4 -0
- onnxslim/argparser.py +215 -0
- onnxslim/cli/__init__.py +1 -0
- onnxslim/cli/_main.py +180 -0
- onnxslim/core/__init__.py +219 -0
- onnxslim/core/optimization/__init__.py +146 -0
- onnxslim/core/optimization/dead_node_elimination.py +151 -0
- onnxslim/core/optimization/subexpression_elimination.py +76 -0
- onnxslim/core/optimization/weight_tying.py +59 -0
- onnxslim/core/pattern/__init__.py +249 -0
- onnxslim/core/pattern/elimination/__init__.py +5 -0
- onnxslim/core/pattern/elimination/concat.py +61 -0
- onnxslim/core/pattern/elimination/reshape.py +77 -0
- onnxslim/core/pattern/elimination/reshape_as.py +64 -0
- onnxslim/core/pattern/elimination/slice.py +108 -0
- onnxslim/core/pattern/elimination/unsqueeze.py +92 -0
- onnxslim/core/pattern/fusion/__init__.py +8 -0
- onnxslim/core/pattern/fusion/concat_reshape.py +50 -0
- onnxslim/core/pattern/fusion/convadd.py +70 -0
- onnxslim/core/pattern/fusion/convbn.py +86 -0
- onnxslim/core/pattern/fusion/convmul.py +69 -0
- onnxslim/core/pattern/fusion/gelu.py +47 -0
- onnxslim/core/pattern/fusion/gemm.py +330 -0
- onnxslim/core/pattern/fusion/padconv.py +89 -0
- onnxslim/core/pattern/fusion/reduce.py +67 -0
- onnxslim/core/pattern/registry.py +28 -0
- onnxslim/misc/__init__.py +0 -0
- onnxslim/misc/tabulate.py +2681 -0
- onnxslim/third_party/__init__.py +0 -0
- onnxslim/third_party/_sympy/__init__.py +0 -0
- onnxslim/third_party/_sympy/functions.py +205 -0
- onnxslim/third_party/_sympy/numbers.py +397 -0
- onnxslim/third_party/_sympy/printers.py +491 -0
- onnxslim/third_party/_sympy/solve.py +172 -0
- onnxslim/third_party/_sympy/symbol.py +102 -0
- onnxslim/third_party/onnx_graphsurgeon/__init__.py +15 -0
- onnxslim/third_party/onnx_graphsurgeon/exporters/__init__.py +1 -0
- onnxslim/third_party/onnx_graphsurgeon/exporters/base_exporter.py +33 -0
- onnxslim/third_party/onnx_graphsurgeon/exporters/onnx_exporter.py +432 -0
- onnxslim/third_party/onnx_graphsurgeon/graph_pattern/__init__.py +4 -0
- onnxslim/third_party/onnx_graphsurgeon/graph_pattern/graph_pattern.py +466 -0
- onnxslim/third_party/onnx_graphsurgeon/importers/__init__.py +1 -0
- onnxslim/third_party/onnx_graphsurgeon/importers/base_importer.py +33 -0
- onnxslim/third_party/onnx_graphsurgeon/importers/onnx_importer.py +558 -0
- onnxslim/third_party/onnx_graphsurgeon/ir/__init__.py +0 -0
- onnxslim/third_party/onnx_graphsurgeon/ir/function.py +274 -0
- onnxslim/third_party/onnx_graphsurgeon/ir/graph.py +1575 -0
- onnxslim/third_party/onnx_graphsurgeon/ir/node.py +266 -0
- onnxslim/third_party/onnx_graphsurgeon/ir/tensor.py +504 -0
- onnxslim/third_party/onnx_graphsurgeon/logger/__init__.py +1 -0
- onnxslim/third_party/onnx_graphsurgeon/logger/logger.py +261 -0
- onnxslim/third_party/onnx_graphsurgeon/util/__init__.py +0 -0
- onnxslim/third_party/onnx_graphsurgeon/util/exception.py +20 -0
- onnxslim/third_party/onnx_graphsurgeon/util/misc.py +252 -0
- onnxslim/third_party/symbolic_shape_infer.py +3273 -0
- onnxslim/utils.py +794 -0
- onnxslim/version.py +1 -0
- onnxslim-0.1.80.dist-info/METADATA +207 -0
- onnxslim-0.1.80.dist-info/RECORD +65 -0
- onnxslim-0.1.80.dist-info/WHEEL +5 -0
- onnxslim-0.1.80.dist-info/entry_points.txt +2 -0
- onnxslim-0.1.80.dist-info/licenses/LICENSE +21 -0
- onnxslim-0.1.80.dist-info/top_level.txt +1 -0
- onnxslim-0.1.80.dist-info/zip-safe +1 -0
|
@@ -0,0 +1,261 @@
|
|
|
1
|
+
#
|
|
2
|
+
# SPDX-FileCopyrightText: Copyright (c) 1993-2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
|
3
|
+
# SPDX-License-Identifier: Apache-2.0
|
|
4
|
+
#
|
|
5
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
6
|
+
# you may not use this file except in compliance with the License.
|
|
7
|
+
# You may obtain a copy of the License at
|
|
8
|
+
#
|
|
9
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
|
10
|
+
#
|
|
11
|
+
# Unless required by applicable law or agreed to in writing, software
|
|
12
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
13
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
14
|
+
# See the License for the specific language governing permissions and
|
|
15
|
+
# limitations under the License.
|
|
16
|
+
#
|
|
17
|
+
|
|
18
|
+
import enum
|
|
19
|
+
import inspect
|
|
20
|
+
import os
|
|
21
|
+
import sys
|
|
22
|
+
import time
|
|
23
|
+
|
|
24
|
+
from onnxslim.third_party.onnx_graphsurgeon.util.exception import (
|
|
25
|
+
OnnxGraphSurgeonException,
|
|
26
|
+
)
|
|
27
|
+
|
|
28
|
+
|
|
29
|
+
# Context manager to apply indentation to messages
|
|
30
|
+
class LoggerIndent:
    """
    Context manager that sets ``logger.logging_indent`` to a fixed level for the duration of a ``with`` block.

    The indent that was in effect when this object was *constructed* is remembered and restored on exit.
    """

    def __init__(self, logger, indent):
        """Remember the target logger, the indent level to apply, and the level to restore afterwards."""
        self.logger = logger
        self.indent = indent
        self.old_indent = logger.logging_indent

    def __enter__(self):
        """Apply the configured indent level and return this context manager."""
        self.logger.logging_indent = self.indent
        return self

    def __exit__(self, exc_type, exc_value, traceback):
        """Restore the remembered indent level; any in-flight exception is propagated."""
        self.logger.logging_indent = self.old_indent
|
|
45
|
+
|
|
46
|
+
|
|
47
|
+
# Context manager to suppress messages
|
|
48
|
+
class LoggerSuppress:
    """
    Context manager that temporarily replaces ``logger.severity`` for the duration of a ``with`` block.

    The severity in effect when this object was *constructed* is remembered and restored on exit.
    """

    def __init__(self, logger, severity):
        """Remember the target logger, the temporary severity, and the severity to restore afterwards."""
        self.logger = logger
        self.severity = severity
        self.old_severity = logger.severity

    def __enter__(self):
        """Switch the logger to the temporary severity and return this context manager."""
        self.logger.severity = self.severity
        return self

    def __exit__(self, exc_type, exc_value, traceback):
        """Put the remembered severity back; any in-flight exception is propagated."""
        self.logger.severity = self.old_severity
|
|
63
|
+
|
|
64
|
+
|
|
65
|
+
class LogMode(enum.IntEnum):
    """How often ``Logger.log`` emits a given message."""

    # Emit the message every time it is logged.
    EACH = 0
    # Emit the message only once; later calls with the same message are skipped.
    ONCE = 1
|
|
68
|
+
|
|
69
|
+
|
|
70
|
+
class Logger:
    """
    Severity-filtered logger that prints formatted messages to stdout.

    Each message can be prefixed with a severity tag, a timestamp and the caller's file/line,
    indented via ``indent()``, and colored when the optional ``colored`` package is importable.
    """

    ULTRA_VERBOSE = -10
    VERBOSE = 0
    DEBUG = 10
    INFO = 20
    WARNING = 30
    ERROR = 40
    CRITICAL = 50

    SEVERITY_LETTER_MAPPING = {
        ULTRA_VERBOSE: "[UV]",
        VERBOSE: "[V]",
        DEBUG: "[D]",
        INFO: "[I]",
        WARNING: "[W]",
        ERROR: "[E]",
        CRITICAL: "[C]",
    }

    # Color names passed to `colored.fg` when colored output is enabled.
    SEVERITY_COLOR_MAPPING = {
        ULTRA_VERBOSE: "cyan",
        VERBOSE: "dark_gray",
        DEBUG: "light_gray",
        INFO: "light_green",
        WARNING: "light_yellow",
        ERROR: "red_1",
        CRITICAL: "red_1",
    }

    def __init__(self, severity=INFO, colors=True, letter=True, timestamp=False, line_info=False):
        """
        Logger.

        Args:
            severity (int): Messages below this severity are ignored. Defaults to Logger.INFO.
            colors (bool): Whether to use colored output (requires the ``colored`` package). Defaults to True.
            letter (bool): Whether to prepend each logging message with a tag indicating its severity. Defaults to True.
            timestamp (bool): Whether to include a timestamp in the logging output. Defaults to False.
            line_info (bool): Whether to include the caller's file and line number in the logging output. Defaults to False.
        """
        self._severity = severity
        self.logging_indent = 0
        # File names in line info are shown relative to this directory (two levels above this file).
        self.root_dir = os.path.abspath(os.path.join(os.path.dirname(__file__), os.pardir, os.pardir))
        # Hashes of messages already emitted with LogMode.ONCE.
        self.once_logged = set()
        self.colors = colors
        self.letter = letter
        self.timestamp = timestamp
        self.line_info = line_info
        self.logger_callbacks = []

    @property
    def severity(self):
        """Returns the logging severity level."""
        return self._severity

    @severity.setter
    def severity(self, value):
        """Sets the logging severity level and notifies every registered callback of the new value."""
        self._severity = value
        for callback in self.logger_callbacks:
            callback(self._severity)

    def register_callback(self, callback):
        """
        Registers a callback with the logger, which will be invoked when the logging severity is modified. The callback
        is guaranteed to be called at least once in the register_callback function.

        Args:
            callback (Callable[[int], Any]): A callback that accepts the current logger severity.
        """
        callback(self._severity)
        self.logger_callbacks.append(callback)

    def indent(self, level=1):
        """Returns a context manager that indents all strings logged by the specified amount."""
        return LoggerIndent(self, level + self.logging_indent)

    def suppress(self, severity=CRITICAL):
        """
        Returns a context manager that temporarily changes the severity of the logger for its duration.

        Args:
            severity (int): The severity to set the logger to. Defaults to Logger.CRITICAL, which suppresses
                every message below CRITICAL.
        """
        return LoggerSuppress(self, severity)

    # With mode=LogMode.ONCE, a given message is only logged a single time. Useful in loops.
    # `message` may be a callable that returns the message, so it is only generated if it is actually logged.
    def log(self, message, severity, mode=LogMode.EACH, stack_depth=2):
        """
        Logs a message with the specified severity.

        Args:
            message (object or Callable[[], object]): The message, or a zero-argument callable producing it.
                A callable is only invoked if the message is actually going to be printed.
            severity (int): One of the Logger severity levels, e.g. Logger.INFO.
            mode (LogMode): LogMode.EACH logs every time; LogMode.ONCE logs a given message at most once.
            stack_depth (int): Selects the frame reported when ``line_info`` is enabled. 2 refers to the
                direct caller of ``log``; each wrapper in between (e.g. ``info``) adds 1.
        """
        if not self._should_log(message, severity, mode):
            return

        if callable(message):
            message = message()

        # Capture the caller's frame here, in `log` itself, so the depth arithmetic does not depend on
        # how many helpers the formatting code goes through. Frame 0 is `log`, so the caller selected
        # by `stack_depth` sits at depth `stack_depth - 1`.
        caller_frame = None
        if self.line_info:
            try:
                caller_frame = sys._getframe(stack_depth - 1)
            except ValueError:  # call stack shallower than requested
                caller_frame = None

        prefix = self._get_prefix(severity, caller_frame)
        body = self._apply_indentation(str(message))
        print(self._apply_color(f"{prefix}{body}", severity))

    def _should_log(self, message, severity, mode):
        """Returns True if a message of the given severity and mode should be printed now."""
        if severity < self._severity:
            return False
        if mode == LogMode.ONCE:
            message_hash = hash(message)
            if message_hash in self.once_logged:
                return False
            # Only remember messages that are actually emitted, so a ONCE message that was filtered by
            # severity (e.g. inside `suppress()`) can still be logged later.
            self.once_logged.add(message_hash)
        return True

    def _get_prefix(self, severity, caller_frame):
        """Builds the severity tag / timestamp / file:line prefix for a message."""
        prefix = ""
        if self.letter:
            prefix += f"{Logger.SEVERITY_LETTER_MAPPING[severity]} "
        if self.timestamp:
            prefix += f"({time.strftime('%X')}) "
        if caller_frame is not None:
            prefix += self._get_line_info(caller_frame)
        return prefix

    def _get_line_info(self, frame):
        """Formats ``[file:line] `` for the given frame, relative to ``root_dir`` when possible."""
        filename = frame.f_code.co_filename
        try:
            filename = os.path.relpath(filename, self.root_dir)
        except ValueError:
            # os.path.relpath raises on Windows when the two paths are on different drives.
            filename = os.path.basename(filename)
        # If the file is not located under root_dir, use its basename instead of a long "../.." path.
        if os.pardir in filename:
            filename = os.path.basename(filename)
        return f"[{filename}:{frame.f_lineno}] "

    def _apply_indentation(self, message):
        """Indents each line in the message by the current logging_indent level (one tab per level)."""
        return "\n".join("\t" * self.logging_indent + line for line in str(message).splitlines())

    def _apply_color(self, message, severity):
        """Applies color formatting to the message if color support is enabled and available."""
        if not self.colors:
            return message
        try:
            import colored
        except ImportError:
            # Disable colors permanently *before* warning: the warning is formatted through this method
            # too, and re-enabling colors afterwards would retry the import and repeat the warning on
            # every subsequent log call.
            self.colors = False
            self.warning(
                "colored module is not installed, will not use colors when logging. To enable colors, please install the colored module: python3 -m pip install colored"
            )
            return message
        color = Logger.SEVERITY_COLOR_MAPPING[severity]
        return colored.stylize(message, [colored.fg(color)])

    def ultra_verbose(self, message, mode=LogMode.EACH):
        """Logs an ultra-verbose message (stack_depth=3 so line info points at the caller of this method)."""
        self.log(message, Logger.ULTRA_VERBOSE, mode=mode, stack_depth=3)

    def verbose(self, message, mode=LogMode.EACH):
        """Logs a verbose message (stack_depth=3 so line info points at the caller of this method)."""
        self.log(message, Logger.VERBOSE, mode=mode, stack_depth=3)

    def debug(self, message, mode=LogMode.EACH):
        """Logs a debug message (stack_depth=3 so line info points at the caller of this method)."""
        self.log(message, Logger.DEBUG, mode=mode, stack_depth=3)

    def info(self, message, mode=LogMode.EACH):
        """Logs an informational message (stack_depth=3 so line info points at the caller of this method)."""
        self.log(message, Logger.INFO, mode=mode, stack_depth=3)

    def warning(self, message, mode=LogMode.EACH):
        """Logs a warning message (stack_depth=3 so line info points at the caller of this method)."""
        self.log(message, Logger.WARNING, mode=mode, stack_depth=3)

    def error(self, message, mode=LogMode.EACH):
        """Logs an error message (stack_depth=3 so line info points at the caller of this method)."""
        self.log(message, Logger.ERROR, mode=mode, stack_depth=3)

    # Like error, but immediately raises.
    def critical(self, message):
        """
        Logs a critical message and raises an OnnxGraphSurgeonException carrying it.

        Raises:
            OnnxGraphSurgeonException: Always.
        """
        # Resolve a lazy (callable) message once, so the exception carries the text rather than the callable.
        if callable(message):
            message = message()
        self.log(message, Logger.CRITICAL, stack_depth=3)
        raise OnnxGraphSurgeonException(message) from None  # Erase exception chain
|
|
258
|
+
|
|
259
|
+
|
|
260
|
+
# Shared module-level logger instance. A `global` statement is unnecessary at module scope.
G_LOGGER = Logger()
|
|
File without changes
|
|
@@ -0,0 +1,20 @@
|
|
|
1
|
+
#
|
|
2
|
+
# SPDX-FileCopyrightText: Copyright (c) 1993-2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
|
3
|
+
# SPDX-License-Identifier: Apache-2.0
|
|
4
|
+
#
|
|
5
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
6
|
+
# you may not use this file except in compliance with the License.
|
|
7
|
+
# You may obtain a copy of the License at
|
|
8
|
+
#
|
|
9
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
|
10
|
+
#
|
|
11
|
+
# Unless required by applicable law or agreed to in writing, software
|
|
12
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
13
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
14
|
+
# See the License for the specific language governing permissions and
|
|
15
|
+
# limitations under the License.
|
|
16
|
+
#
|
|
17
|
+
|
|
18
|
+
|
|
19
|
+
class OnnxGraphSurgeonException(Exception):
    """Exception type raised by ONNX-GraphSurgeon (for example by ``Logger.critical``)."""
|
|
@@ -0,0 +1,252 @@
|
|
|
1
|
+
#
|
|
2
|
+
# SPDX-FileCopyrightText: Copyright (c) 1993-2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
|
3
|
+
# SPDX-License-Identifier: Apache-2.0
|
|
4
|
+
#
|
|
5
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
6
|
+
# you may not use this file except in compliance with the License.
|
|
7
|
+
# You may obtain a copy of the License at
|
|
8
|
+
#
|
|
9
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
|
10
|
+
#
|
|
11
|
+
# Unless required by applicable law or agreed to in writing, software
|
|
12
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
13
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
14
|
+
# See the License for the specific language governing permissions and
|
|
15
|
+
# limitations under the License.
|
|
16
|
+
#
|
|
17
|
+
from __future__ import annotations
|
|
18
|
+
|
|
19
|
+
from collections import OrderedDict
|
|
20
|
+
from collections.abc import Sequence
|
|
21
|
+
|
|
22
|
+
import numpy as np
|
|
23
|
+
from onnx import AttributeProto
|
|
24
|
+
|
|
25
|
+
from onnxslim.third_party.onnx_graphsurgeon.logger import G_LOGGER
|
|
26
|
+
|
|
27
|
+
|
|
28
|
+
# default_value exists to solve issues that might result from Python's normal default argument behavior.
|
|
29
|
+
# Specifically, consider the following class:
|
|
30
|
+
#
|
|
31
|
+
# class MyClass(object):
|
|
32
|
+
# def __init__(self, value=[]):
|
|
33
|
+
# self.value = value
|
|
34
|
+
#
|
|
35
|
+
# This leads to unwanted behavior when the default value is used:
|
|
36
|
+
#
|
|
37
|
+
# >>> x = MyClass()
|
|
38
|
+
# >>> x.value.append("SHOULD NOT BE IN Y")
|
|
39
|
+
# >>> y = MyClass()
|
|
40
|
+
# >>> y.value
|
|
41
|
+
# ['SHOULD NOT BE IN Y']
|
|
42
|
+
#
|
|
43
|
+
# If we rewrite the class using default value:
|
|
44
|
+
#
|
|
45
|
+
# class MyClass(object):
|
|
46
|
+
# def __init__(self, value=None):
|
|
47
|
+
# self.value = default_value(value, [])
|
|
48
|
+
#
|
|
49
|
+
# Then we get the desired behavior:
|
|
50
|
+
#
|
|
51
|
+
# >>> x = MyClass()
|
|
52
|
+
# >>> x.value.append("SHOULD NOT BE IN Y")
|
|
53
|
+
# >>> y = MyClass()
|
|
54
|
+
# >>> y.value
|
|
55
|
+
# []
|
|
56
|
+
def default_value(value, default):
    """Return ``default`` when ``value`` is None; otherwise return ``value`` unchanged (falsy values are kept)."""
    if value is None:
        return default
    return value
|
|
59
|
+
|
|
60
|
+
|
|
61
|
+
def combine_dicts(dict0, dict1):
    """
    Merge two dictionaries into a new OrderedDict.

    Entries of ``dict1`` take precedence over entries of ``dict0`` with the same key. If ``dict1``
    is None, ``dict0`` itself is returned (no copy is made).
    """
    if dict1 is None:
        return dict0
    merged = OrderedDict(dict0)
    merged.update(dict1)
    return merged
|
|
73
|
+
|
|
74
|
+
|
|
75
|
+
def unique_dicts(dict0, dict1):
    """
    Return the entries of ``dict0`` whose keys do not appear in ``dict1``.

    If ``dict1`` is empty or None, ``dict0`` itself is returned (no copy is made).
    """
    if not dict1:
        return dict0
    return {key: val for key, val in dict0.items() if key not in dict1}
|
|
82
|
+
|
|
83
|
+
|
|
84
|
+
def is_dynamic_dimension(dim):
    """
    Check whether a single shape dimension is dynamic.

    A dimension is static only if it is a non-negative integer. Both Python ``int`` and NumPy integer
    scalars (e.g. the elements produced by iterating a NumPy shape array) count as integers. Anything
    else, such as a symbolic name string, None, or a negative value, is treated as dynamic.

    Returns:
        bool: True if the dimension is dynamic.
    """
    return not isinstance(dim, (int, np.integer)) or bool(dim < 0)
|
|
87
|
+
|
|
88
|
+
|
|
89
|
+
def is_dynamic_shape(shape):
    """Return True as soon as one dimension of ``shape`` is dynamic according to ``is_dynamic_dimension``."""
    for dim in shape:
        if is_dynamic_dimension(dim):
            return True
    return False
|
|
92
|
+
|
|
93
|
+
|
|
94
|
+
def volume(obj):
    """Return the product of all elements of the iterable ``obj`` (1 for an empty iterable)."""
    product = 1
    for factor in obj:
        product *= factor
    return product
|
|
100
|
+
|
|
101
|
+
|
|
102
|
+
# Lookup tables between ONNX AttributeProto types and GraphSurgeon/Python types.
# They start empty and are filled lazily by _init_dicts().
_ONNX_ATTR_TYPE_TO_GS_TYPE, _GS_TYPE_TO_ONNX_ATTR_TYPE = {}, {}
|
|
104
|
+
|
|
105
|
+
|
|
106
|
+
# This method prevents circular import of Tensor and Graph
|
|
107
|
+
def _init_dicts():
    """Populate the ONNX-attribute-type <-> GraphSurgeon-type lookup tables the first time they are needed."""
    global _ONNX_ATTR_TYPE_TO_GS_TYPE
    global _GS_TYPE_TO_ONNX_ATTR_TYPE

    already_initialized = bool(_ONNX_ATTR_TYPE_TO_GS_TYPE) and bool(_GS_TYPE_TO_ONNX_ATTR_TYPE)
    if already_initialized:
        return

    # Imported here rather than at module level to avoid a circular import with the IR modules.
    from onnxslim.third_party.onnx_graphsurgeon.ir.graph import Graph
    from onnxslim.third_party.onnx_graphsurgeon.ir.tensor import Tensor

    forward = {
        AttributeProto.UNDEFINED: None,
        AttributeProto.FLOAT: float,
        AttributeProto.INT: int,
        AttributeProto.STRING: str,
        AttributeProto.TENSOR: Tensor,
        AttributeProto.GRAPH: Graph,
        AttributeProto.SPARSE_TENSOR: AttributeProto.SPARSE_TENSOR,
        AttributeProto.TYPE_PROTO: AttributeProto.TYPE_PROTO,
        AttributeProto.FLOATS: list[float],
        AttributeProto.INTS: list[int],
        AttributeProto.STRINGS: list[str],
        AttributeProto.TENSORS: list[Tensor],
        AttributeProto.GRAPHS: list[Graph],
        AttributeProto.SPARSE_TENSORS: AttributeProto.SPARSE_TENSORS,
        AttributeProto.TYPE_PROTOS: AttributeProto.TYPE_PROTOS,
    }
    _ONNX_ATTR_TYPE_TO_GS_TYPE = forward
    # The reverse table is the exact inverse of the forward table.
    _GS_TYPE_TO_ONNX_ATTR_TYPE = {gs_type: onnx_type for onnx_type, gs_type in forward.items()}
|
|
135
|
+
|
|
136
|
+
|
|
137
|
+
def convert_from_onnx_attr_type(onnx_attr_type):
    """
    Map an ``onnx.AttributeProto`` type value to its GraphSurgeon/Python type.

    Raises:
        KeyError: If ``onnx_attr_type`` has no entry in the lookup table.
    """
    _init_dicts()
    lookup = _ONNX_ATTR_TYPE_TO_GS_TYPE
    return lookup[onnx_attr_type]
|
|
141
|
+
|
|
142
|
+
|
|
143
|
+
def convert_to_onnx_attr_type(any_type):
    """
    Converts a GraphSurgeon/Python/NumPy type to its corresponding ``onnx.AttributeProto`` type.

    Exact matches in the lookup table are tried first. Otherwise, NumPy floating types map to FLOAT
    and NumPy integer types map to INT.

    Returns:
        The AttributeProto type value, or None (after logging a warning) if no mapping exists.
    """
    _init_dicts()
    try:
        return _GS_TYPE_TO_ONNX_ATTR_TYPE[any_type]
    except (KeyError, TypeError):
        # KeyError: no exact match. TypeError: `any_type` is unhashable. Either way, try NumPy below.
        pass
    try:
        if np.issubdtype(any_type, np.floating):
            return AttributeProto.FLOAT
        if np.issubdtype(any_type, np.integer):
            return AttributeProto.INT
    except TypeError:
        # np.issubdtype raises TypeError for objects NumPy cannot interpret as a dtype. Treat them as
        # unsupported and fall through to the warning instead of crashing.
        pass
    G_LOGGER.warning(f"Unable to convert {any_type} into an ONNX AttributeType")
    return None
|
|
153
|
+
|
|
154
|
+
|
|
155
|
+
# Special type of list that synchronizes contents with another list.
|
|
156
|
+
# Concrete example: Assume some node, n, contains an input tensor, t. If we remove t from n.inputs,
|
|
157
|
+
# we also need to remove n from t.outputs. To avoid having to do this manually, we use SynchronizedList,
|
|
158
|
+
# which takes an attribute name as a parameter, and then synchronizes to that attribute of each of its elements.
|
|
159
|
+
# So, in the example above, we can make n.inputs a synchronized list whose field_name is set to "outputs".
|
|
160
|
+
# See test_ir.TestNodeIO for functional tests
|
|
161
|
+
class SynchronizedList(list):
    """
    A list that keeps a back-reference list on each of its elements in sync with its own contents.

    Every element stored here must have a list attribute named ``field_name``. When an element is
    added, ``parent_obj`` is appended to that attribute. When an element is removed, one occurrence
    of ``parent_obj`` is removed from it.

    Args:
        parent_obj: The object that owns this list (e.g. a node owning its ``inputs``).
        field_name (str): Name of the list attribute on each element that mirrors membership.
        initial (Iterable): Initial elements; any iterable, including a generator, is accepted.
    """

    def __init__(self, parent_obj, field_name, initial):
        """Initialize a SynchronizedList with a parent object, a field name, and an initial set of elements."""
        self.parent_obj = parent_obj
        self.field_name = field_name
        self.extend(initial)

    def _add_to_elem(self, elem):
        """Append the parent_obj to the list attribute defined by field_name in the provided elem object."""
        list.append(getattr(elem, self.field_name), self.parent_obj)

    def _remove_from_elem(self, elem):
        """Remove the parent_obj from the list attribute defined by field_name in the provided elem object."""
        list.remove(getattr(elem, self.field_name), self.parent_obj)

    def __delitem__(self, index):
        """Delete the element(s) at ``index`` (int or slice) and drop their back-references to parent_obj."""
        if isinstance(index, slice):
            # A slice may cover several elements; each one must be unsynchronized.
            for elem in self[index]:
                self._remove_from_elem(elem)
        else:
            self._remove_from_elem(self[index])
        super().__delitem__(index)

    def __setitem__(self, index, elem):
        """
        Replace the element(s) at ``index`` (int or slice) and move the back-references to parent_obj
        from the old element(s) to the new one(s).
        """
        if isinstance(index, slice):
            new_elems = list(elem)  # materialize: the value may be a one-shot iterator
            old_elems = self[index]
            # Assign first: for an extended slice with a length mismatch this raises ValueError and
            # leaves both the list and the elements' back-references untouched.
            super().__setitem__(index, new_elems)
            for old in old_elems:
                self._remove_from_elem(old)
            for new in new_elems:
                self._add_to_elem(new)
            return
        self._remove_from_elem(self[index])
        super().__setitem__(index, elem)
        self._add_to_elem(elem)

    def append(self, x):
        """Append an element to the list and update its back-reference to parent_obj."""
        super().append(x)
        self._add_to_elem(x)

    def extend(self, iterable: Sequence[object]):
        """Extend the list with elements from an iterable and update each element's back-reference to parent_obj."""
        # Materialize first. Otherwise a generator would be exhausted by list.extend before the sync loop
        # runs, and `self.extend(self)` would iterate the already-extended list and over-synchronize.
        new_elems = list(iterable)
        super().extend(new_elems)
        for elem in new_elems:
            self._add_to_elem(elem)

    def insert(self, i, x):
        """Insert an element at a given position and update its back-reference to parent_obj."""
        super().insert(i, x)
        self._add_to_elem(x)

    def remove(self, x):
        """Remove an element from the list and drop its back-reference to parent_obj."""
        super().remove(x)
        self._remove_from_elem(x)

    def pop(self, i=-1):
        """Remove and return the element at index i (default last), dropping its back-reference to parent_obj."""
        elem = super().pop(i)
        self._remove_from_elem(elem)
        return elem

    def clear(self):
        """Clear all elements from the list and drop each element's back-reference to parent_obj."""
        for elem in self:
            self._remove_from_elem(elem)
        super().clear()

    def __add__(self, other_list: list[object]):
        """Concatenate with another list, returning a plain (unsynchronized) list."""
        return list(self) + list(other_list)

    def __iadd__(self, other_list: list[object]):
        """Append elements from another list in place (synchronized via ``extend``) and return self."""
        self.extend(other_list)
        return self

    def __copy__(self):
        """Return a plain, unsynchronized shallow copy of the current list."""
        return list(self)

    def __deepcopy__(self, memo):
        """Return a plain list of the same elements; note that the elements themselves are NOT deep-copied."""
        return list(self)
|
|
244
|
+
|
|
245
|
+
|
|
246
|
+
def sequences_equal(seq1, seq2):
    """Return True when both sequences have the same length and pairwise-equal elements (checked with ``==``)."""
    if len(seq1) != len(seq2):
        return False
    for left, right in zip(seq1, seq2):
        if not (left == right):
            return False
    return True
|