onnxslim 0.1.80__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- onnxslim/__init__.py +16 -0
- onnxslim/__main__.py +4 -0
- onnxslim/argparser.py +215 -0
- onnxslim/cli/__init__.py +1 -0
- onnxslim/cli/_main.py +180 -0
- onnxslim/core/__init__.py +219 -0
- onnxslim/core/optimization/__init__.py +146 -0
- onnxslim/core/optimization/dead_node_elimination.py +151 -0
- onnxslim/core/optimization/subexpression_elimination.py +76 -0
- onnxslim/core/optimization/weight_tying.py +59 -0
- onnxslim/core/pattern/__init__.py +249 -0
- onnxslim/core/pattern/elimination/__init__.py +5 -0
- onnxslim/core/pattern/elimination/concat.py +61 -0
- onnxslim/core/pattern/elimination/reshape.py +77 -0
- onnxslim/core/pattern/elimination/reshape_as.py +64 -0
- onnxslim/core/pattern/elimination/slice.py +108 -0
- onnxslim/core/pattern/elimination/unsqueeze.py +92 -0
- onnxslim/core/pattern/fusion/__init__.py +8 -0
- onnxslim/core/pattern/fusion/concat_reshape.py +50 -0
- onnxslim/core/pattern/fusion/convadd.py +70 -0
- onnxslim/core/pattern/fusion/convbn.py +86 -0
- onnxslim/core/pattern/fusion/convmul.py +69 -0
- onnxslim/core/pattern/fusion/gelu.py +47 -0
- onnxslim/core/pattern/fusion/gemm.py +330 -0
- onnxslim/core/pattern/fusion/padconv.py +89 -0
- onnxslim/core/pattern/fusion/reduce.py +67 -0
- onnxslim/core/pattern/registry.py +28 -0
- onnxslim/misc/__init__.py +0 -0
- onnxslim/misc/tabulate.py +2681 -0
- onnxslim/third_party/__init__.py +0 -0
- onnxslim/third_party/_sympy/__init__.py +0 -0
- onnxslim/third_party/_sympy/functions.py +205 -0
- onnxslim/third_party/_sympy/numbers.py +397 -0
- onnxslim/third_party/_sympy/printers.py +491 -0
- onnxslim/third_party/_sympy/solve.py +172 -0
- onnxslim/third_party/_sympy/symbol.py +102 -0
- onnxslim/third_party/onnx_graphsurgeon/__init__.py +15 -0
- onnxslim/third_party/onnx_graphsurgeon/exporters/__init__.py +1 -0
- onnxslim/third_party/onnx_graphsurgeon/exporters/base_exporter.py +33 -0
- onnxslim/third_party/onnx_graphsurgeon/exporters/onnx_exporter.py +432 -0
- onnxslim/third_party/onnx_graphsurgeon/graph_pattern/__init__.py +4 -0
- onnxslim/third_party/onnx_graphsurgeon/graph_pattern/graph_pattern.py +466 -0
- onnxslim/third_party/onnx_graphsurgeon/importers/__init__.py +1 -0
- onnxslim/third_party/onnx_graphsurgeon/importers/base_importer.py +33 -0
- onnxslim/third_party/onnx_graphsurgeon/importers/onnx_importer.py +558 -0
- onnxslim/third_party/onnx_graphsurgeon/ir/__init__.py +0 -0
- onnxslim/third_party/onnx_graphsurgeon/ir/function.py +274 -0
- onnxslim/third_party/onnx_graphsurgeon/ir/graph.py +1575 -0
- onnxslim/third_party/onnx_graphsurgeon/ir/node.py +266 -0
- onnxslim/third_party/onnx_graphsurgeon/ir/tensor.py +504 -0
- onnxslim/third_party/onnx_graphsurgeon/logger/__init__.py +1 -0
- onnxslim/third_party/onnx_graphsurgeon/logger/logger.py +261 -0
- onnxslim/third_party/onnx_graphsurgeon/util/__init__.py +0 -0
- onnxslim/third_party/onnx_graphsurgeon/util/exception.py +20 -0
- onnxslim/third_party/onnx_graphsurgeon/util/misc.py +252 -0
- onnxslim/third_party/symbolic_shape_infer.py +3273 -0
- onnxslim/utils.py +794 -0
- onnxslim/version.py +1 -0
- onnxslim-0.1.80.dist-info/METADATA +207 -0
- onnxslim-0.1.80.dist-info/RECORD +65 -0
- onnxslim-0.1.80.dist-info/WHEEL +5 -0
- onnxslim-0.1.80.dist-info/entry_points.txt +2 -0
- onnxslim-0.1.80.dist-info/licenses/LICENSE +21 -0
- onnxslim-0.1.80.dist-info/top_level.txt +1 -0
- onnxslim-0.1.80.dist-info/zip-safe +1 -0
|
@@ -0,0 +1,266 @@
|
|
|
1
|
+
#
|
|
2
|
+
# SPDX-FileCopyrightText: Copyright (c) 1993-2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
|
3
|
+
# SPDX-License-Identifier: Apache-2.0
|
|
4
|
+
#
|
|
5
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
6
|
+
# you may not use this file except in compliance with the License.
|
|
7
|
+
# You may obtain a copy of the License at
|
|
8
|
+
#
|
|
9
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
|
10
|
+
#
|
|
11
|
+
# Unless required by applicable law or agreed to in writing, software
|
|
12
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
13
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
14
|
+
# See the License for the specific language governing permissions and
|
|
15
|
+
# limitations under the License.
|
|
16
|
+
#
|
|
17
|
+
from __future__ import annotations
|
|
18
|
+
|
|
19
|
+
from collections import OrderedDict
|
|
20
|
+
from dataclasses import dataclass
|
|
21
|
+
|
|
22
|
+
from onnxslim.third_party.onnx_graphsurgeon.ir.tensor import Constant, Tensor, Variable
|
|
23
|
+
from onnxslim.third_party.onnx_graphsurgeon.logger import G_LOGGER
|
|
24
|
+
from onnxslim.third_party.onnx_graphsurgeon.util import misc
|
|
25
|
+
|
|
26
|
+
|
|
27
|
+
class Node:
    @dataclass
    class AttributeRef:
        """
        An AttributeRef is an attribute value which references an attribute in the parent function. A node's attribute
        can only be an AttributeRef if the node lives inside a Function.

        Args:
            name (str): The name of the referenced attribute in the parent Function.
            type (type): The attribute's type.
        """

        name: str
        type: type

    def __init__(
        self,
        op: str,
        name: str | None = None,
        attrs: dict[str, object] | None = None,
        inputs: list[Tensor] | None = None,
        outputs: list[Tensor] | None = None,
        domain: str | None = None,
    ):
        """
        A node represents an operation in a graph, and consumes zero or more Tensors, and produces zero or more Tensors.

        Args:
            op (str): The operation this node performs.

            name (str): The name of this node.
            attrs (Dict[str, object]): A dictionary that maps attribute names to their values.
            inputs (List[Tensor]): A list of zero or more input Tensors.
            outputs (List[Tensor]): A list of zero or more output Tensors.
            domain (str): The domain of this node.
        """
        self.op = op
        self.name = misc.default_value(name, "")
        self.attrs = misc.default_value(attrs, OrderedDict())
        # NOTE(review): the field_name values are deliberately "crossed" — this node's
        # `inputs` list appears to be kept in sync with each input tensor's `outputs`
        # list (and vice versa for `outputs`), so producer/consumer links stay
        # consistent from both sides. Confirm against misc.SynchronizedList.
        self.inputs = misc.SynchronizedList(self, field_name="outputs", initial=misc.default_value(inputs, []))
        self.outputs = misc.SynchronizedList(self, field_name="inputs", initial=misc.default_value(outputs, []))
        self.domain = domain

    def i(self, tensor_idx=0, producer_idx=0):
        """
        Convenience function to get a producer node of one of this node's input tensors. Note that the parameters are
        swapped compared to the o() function; this is because tensors are likely to have only a single producer.

        For example:
        ::

            assert node.i() == node.inputs[0].inputs[0]
            assert node.i(1, 2) == node.inputs[1].inputs[2]

        Args:
            tensor_idx (int): The index of the input tensor of this node. Defaults to 0.
            producer_idx (int): The index of the producer of the input tensor, if the tensor has multiple producers. Defaults to 0.

        Returns:
            Node: The specified producer (input) node.
        """
        return self.inputs[tensor_idx].inputs[producer_idx]

    def o(self, consumer_idx=0, tensor_idx=0):
        """
        Convenience function to get a consumer node of one of this node's output tensors.

        For example:
        ::

            assert node.o() == node.outputs[0].outputs[0]
            assert node.o(2, 1) == node.outputs[1].outputs[2]

        Args:
            consumer_idx (int): The index of the consumer of the input tensor. Defaults to 0.
            tensor_idx (int): The index of the output tensor of this node, if the node has multiple outputs. Defaults to 0.

        Returns:
            Node: The specified consumer (output) node.
        """
        return self.outputs[tensor_idx].outputs[consumer_idx]

    def subgraphs(self, recursive=False):
        """
        Convenience function to iterate over all subgraphs which are contained in this node. Node subgraphs are found in
        attributes of ONNX control flow nodes such as 'If' and 'Loop'.

        Args:
            recursive (bool): Whether to recurse into the subgraph nodes when looking for subgraphs. Defaults to False.

        Returns:
            A generator which iterates over this node's subgraphs.
        """
        # Imported locally to avoid a circular import between graph.py and node.py.
        from onnxslim.third_party.onnx_graphsurgeon.ir.graph import Graph

        visit_queue = [self]

        # This prevents infinite recursion in the (illegal) case of cyclical graphs.
        visited = set()

        while visit_queue:
            node = visit_queue.pop()
            for attr in node.attrs.values():
                # Subgraphs are attribute values of Graph type; track them by id()
                # since Graph equality may be expensive or undefined here.
                if isinstance(attr, Graph) and id(attr) not in visited:
                    visited.add(id(attr))
                    if recursive:
                        visit_queue.extend(attr.nodes)
                    yield attr

    def __setattr__(self, name, value):
        """Sets the attribute 'name' to 'value', handling special cases for 'inputs' and 'outputs' attributes."""
        if name in {"inputs", "outputs"}:
            try:
                attr = getattr(self, name)
                if value is attr:
                    # This can happen when using things like +=
                    # The __iadd__ is executed followed by an assignment
                    return

                # Mutate the existing list in place rather than rebinding, so the
                # tensor back-references managed by SynchronizedList stay consistent.
                attr.clear()
                attr.extend(value)
            except AttributeError:
                # First assignment (from __init__): the attribute does not exist yet.
                super().__setattr__(name, value)
        else:
            super().__setattr__(name, value)

    def copy(
        self,
        inputs: list[Tensor] | None = None,
        outputs: list[Tensor] | None = None,
        tensor_map=None,
    ):
        """
        Makes a shallow copy of this node, overriding input and output information.

        Note: Generally, you should only ever make a copy of a Graph.
        """
        # Imported locally to avoid a circular import between graph.py and node.py.
        from onnxslim.third_party.onnx_graphsurgeon.ir.graph import Graph

        new_attrs = OrderedDict()
        for name, attr in self.attrs.items():
            # Subgraph attributes are deep-copied through Graph.copy (remapping
            # tensors via tensor_map); all other attribute values are shared.
            new_attrs[name] = attr.copy(tensor_map) if isinstance(attr, Graph) else attr
        return Node(
            self.op,
            self.name,
            new_attrs,
            inputs=inputs,
            outputs=outputs,
            domain=self.domain,
        )

    def __str__(self):
        """Return a string representation of the object showing its name and operation."""
        ret = f"{self.name} ({self.op})"

        def add_io(name, io):
            """Add the input or output operations and their names to the string representation of the object."""
            nonlocal ret
            ret += f"\n\t{name}: ["
            for elem in io:
                ret += f"\n\t\t{elem}"
            ret += "\n\t]"

        add_io("Inputs", self.inputs)
        add_io("Outputs", self.outputs)

        if self.attrs:
            ret += f"\nAttributes: {self.attrs}"

        if self.domain:
            ret += f"\nDomain: {self.domain}"

        return ret

    def __repr__(self):
        """Return the string representation of this node (same as __str__)."""
        return self.__str__()

    def __eq__(self, other):
        """Check whether two nodes are equal by comparing name, attributes, op, inputs, and outputs."""
        G_LOGGER.verbose(f"Comparing node: {self.name} with {other.name}")
        # Cheap field comparisons first; tensor sequence comparisons are more expensive.
        attrs_match = self.name == other.name and self.op == other.op and self.attrs == other.attrs
        if not attrs_match:
            return False

        inputs_match = misc.sequences_equal(self.inputs, other.inputs)
        if not inputs_match:
            return False

        outputs_match = misc.sequences_equal(self.outputs, other.outputs)
        return self.domain == other.domain if outputs_match else False

    @property
    def users(self):
        """
        Return the consumers of this node's output tensors.

        NOTE(review): for an output tensor that is a graph output, this mutates the
        tensor (sets its `op` to "output") and appends the tensor itself to the
        result, so the returned list can mix Variables and Nodes — callers must
        handle both. Confirm this mixed contract is intentional.
        """
        users = []
        for output in self.outputs:  # output is a Variable
            if output.is_output:
                output.op = "output"
                users.append(output)
            users.extend(iter(output.outputs))
        return users

    @property
    def feeds(self):
        """Retrieve the list of nodes that provide inputs to the given node."""
        feeds = []
        for input in self.inputs:
            if len(input.inputs) == 0 and not isinstance(input, Constant):
                # Producer-less Variable (e.g. a graph input): report the tensor itself.
                feeds.append(input)
            elif isinstance(input, Constant):
                # Constants have no producer node; report the Constant itself.
                feeds.append(input)
            else:
                # NOTE(review): for a Split producer the tensor (not the Split node)
                # is reported — presumably so each Split output stays distinguishable.
                # Confirm against callers of `feeds`.
                feeds.extend(input if feed.op == "Split" else feed for feed in input.inputs)
        return feeds

    def erase(self, input_var_idx=0, output_var_idx=0):
        """
        Detach this node from the graph by bypassing it, then clear its inputs and outputs.

        If the selected input is a graph input, uses of the selected output are replaced
        with that input; otherwise uses of the selected input are replaced with the
        selected output. Does nothing if the selected input is not a Variable
        (e.g. a Constant).

        Args:
            input_var_idx (int): Index of the input tensor to splice through. Defaults to 0.
            output_var_idx (int): Index of the output tensor to splice through. Defaults to 0.
        """
        if isinstance(self.inputs[input_var_idx], Variable):
            if self.inputs[input_var_idx].is_input:
                self.outputs[output_var_idx].replace_all_uses_with(self.inputs[input_var_idx])
                self.inputs.clear()
                self.outputs.clear()
            else:
                self.inputs[input_var_idx].replace_all_uses_with(self.outputs[output_var_idx])
                self.inputs.clear()
                self.outputs.clear()

    def replace_all_uses_with(self, node: Node):
        """Replace all uses of this node with the given node."""
        for user in self.users:
            for inp in user.inputs:
                if inp in self.outputs:
                    # Rewire each consumer input that reads one of our outputs to the
                    # positionally-corresponding output of the replacement node.
                    for i, input in enumerate(user.inputs):
                        if input == inp:
                            user.inputs[i] = node.outputs[self.outputs.index(inp)]

        # Preserve graph-output identity: if our first output is a graph output,
        # the replacement node must produce that same Variable.
        if isinstance(self.outputs[0], Variable) and self.outputs[0].is_output:
            node.outputs[0] = self.outputs[0]

        # Fully disconnect this node from the graph.
        self.inputs.clear()
        self.outputs.clear()
|