onnxslim-0.1.80-py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- onnxslim/__init__.py +16 -0
- onnxslim/__main__.py +4 -0
- onnxslim/argparser.py +215 -0
- onnxslim/cli/__init__.py +1 -0
- onnxslim/cli/_main.py +180 -0
- onnxslim/core/__init__.py +219 -0
- onnxslim/core/optimization/__init__.py +146 -0
- onnxslim/core/optimization/dead_node_elimination.py +151 -0
- onnxslim/core/optimization/subexpression_elimination.py +76 -0
- onnxslim/core/optimization/weight_tying.py +59 -0
- onnxslim/core/pattern/__init__.py +249 -0
- onnxslim/core/pattern/elimination/__init__.py +5 -0
- onnxslim/core/pattern/elimination/concat.py +61 -0
- onnxslim/core/pattern/elimination/reshape.py +77 -0
- onnxslim/core/pattern/elimination/reshape_as.py +64 -0
- onnxslim/core/pattern/elimination/slice.py +108 -0
- onnxslim/core/pattern/elimination/unsqueeze.py +92 -0
- onnxslim/core/pattern/fusion/__init__.py +8 -0
- onnxslim/core/pattern/fusion/concat_reshape.py +50 -0
- onnxslim/core/pattern/fusion/convadd.py +70 -0
- onnxslim/core/pattern/fusion/convbn.py +86 -0
- onnxslim/core/pattern/fusion/convmul.py +69 -0
- onnxslim/core/pattern/fusion/gelu.py +47 -0
- onnxslim/core/pattern/fusion/gemm.py +330 -0
- onnxslim/core/pattern/fusion/padconv.py +89 -0
- onnxslim/core/pattern/fusion/reduce.py +67 -0
- onnxslim/core/pattern/registry.py +28 -0
- onnxslim/misc/__init__.py +0 -0
- onnxslim/misc/tabulate.py +2681 -0
- onnxslim/third_party/__init__.py +0 -0
- onnxslim/third_party/_sympy/__init__.py +0 -0
- onnxslim/third_party/_sympy/functions.py +205 -0
- onnxslim/third_party/_sympy/numbers.py +397 -0
- onnxslim/third_party/_sympy/printers.py +491 -0
- onnxslim/third_party/_sympy/solve.py +172 -0
- onnxslim/third_party/_sympy/symbol.py +102 -0
- onnxslim/third_party/onnx_graphsurgeon/__init__.py +15 -0
- onnxslim/third_party/onnx_graphsurgeon/exporters/__init__.py +1 -0
- onnxslim/third_party/onnx_graphsurgeon/exporters/base_exporter.py +33 -0
- onnxslim/third_party/onnx_graphsurgeon/exporters/onnx_exporter.py +432 -0
- onnxslim/third_party/onnx_graphsurgeon/graph_pattern/__init__.py +4 -0
- onnxslim/third_party/onnx_graphsurgeon/graph_pattern/graph_pattern.py +466 -0
- onnxslim/third_party/onnx_graphsurgeon/importers/__init__.py +1 -0
- onnxslim/third_party/onnx_graphsurgeon/importers/base_importer.py +33 -0
- onnxslim/third_party/onnx_graphsurgeon/importers/onnx_importer.py +558 -0
- onnxslim/third_party/onnx_graphsurgeon/ir/__init__.py +0 -0
- onnxslim/third_party/onnx_graphsurgeon/ir/function.py +274 -0
- onnxslim/third_party/onnx_graphsurgeon/ir/graph.py +1575 -0
- onnxslim/third_party/onnx_graphsurgeon/ir/node.py +266 -0
- onnxslim/third_party/onnx_graphsurgeon/ir/tensor.py +504 -0
- onnxslim/third_party/onnx_graphsurgeon/logger/__init__.py +1 -0
- onnxslim/third_party/onnx_graphsurgeon/logger/logger.py +261 -0
- onnxslim/third_party/onnx_graphsurgeon/util/__init__.py +0 -0
- onnxslim/third_party/onnx_graphsurgeon/util/exception.py +20 -0
- onnxslim/third_party/onnx_graphsurgeon/util/misc.py +252 -0
- onnxslim/third_party/symbolic_shape_infer.py +3273 -0
- onnxslim/utils.py +794 -0
- onnxslim/version.py +1 -0
- onnxslim-0.1.80.dist-info/METADATA +207 -0
- onnxslim-0.1.80.dist-info/RECORD +65 -0
- onnxslim-0.1.80.dist-info/WHEEL +5 -0
- onnxslim-0.1.80.dist-info/entry_points.txt +2 -0
- onnxslim-0.1.80.dist-info/licenses/LICENSE +21 -0
- onnxslim-0.1.80.dist-info/top_level.txt +1 -0
- onnxslim-0.1.80.dist-info/zip-safe +1 -0
onnxslim/third_party/onnx_graphsurgeon/ir/function.py
@@ -0,0 +1,274 @@
+#
+# SPDX-FileCopyrightText: Copyright (c) 1993-2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+# SPDX-License-Identifier: Apache-2.0
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+from __future__ import annotations
+
+import copy
+from collections.abc import Sequence
+
+from onnxslim.third_party.onnx_graphsurgeon.ir.graph import Graph
+from onnxslim.third_party.onnx_graphsurgeon.ir.node import Node
+from onnxslim.third_party.onnx_graphsurgeon.ir.tensor import Tensor, Variable
+from onnxslim.third_party.onnx_graphsurgeon.logger import G_LOGGER
+from onnxslim.third_party.onnx_graphsurgeon.util import misc
+
+
+class Function(Graph):
+    """
+    Represents a local function, which is a default implementation of a Custom Op. This default implementation is
+    represented as a Graph of other Ops.
+
+    Functions are used in a model by creating a Node with the same name and domain as the function. This can be done
+    using the __call__() method of a Function, which creates this new node and appends it to a Graph. A Function is not
+    a subgraph of a Graph, and its Nodes, Tensors, and subgraphs are entirely separate from the main Graph.
+
+    Functions can be composed of other functions, but cyclical or recursive definitions are not allowed in ONNX.
+    """
+
+    DEFAULT_DOMAIN = "onnx_graphsurgeon"
+
+    def __init__(
+        self,
+        name: str,
+        domain: str | None = None,
+        nodes: Sequence[Node] | None = None,
+        inputs: Sequence[Tensor] | None = None,
+        outputs: Sequence[Tensor] | None = None,
+        doc_string: str | None = None,
+        opset: int | None = None,
+        import_domains: Sequence[onnx.OperatorSetIdProto] | None = None,
+        functions: Sequence[Function] | None = None,
+        attrs: dict | None = None,
+    ):
+        """
+        Args:
+            name (str): The name of the function.
+            domain (str): The domain/namespace of this function.
+            nodes (Sequence[Node]): A list of the nodes in this function.
+            inputs (Sequence[Tensor]): A list of graph input Tensors.
+            outputs (Sequence[Tensor]): A list of graph output Tensors.
+            doc_string (str): A doc_string for the function. Defaults to "".
+            opset (int): The ONNX opset used by nodes in this function.
+            import_domains (Sequence[onnx.OperatorSetIdProto]): The list of domains used by nodes in this function.
+            functions (Sequence[Function]): The list of functions in this model.
+            attrs (dict): A mapping of attribute names to their default values.
+                Nodes within this function can have attributes which take on the values of the Function attributes.
+                When a Function is instantiated into a Node, providing attributes to that Node will override the Function's
+                default attribute values. A default value of `None` means that the instantiated Node must provide the value
+                of that attribute (in other words, it is a required attribute).
+        """
+        self.domain = misc.default_value(domain, Function.DEFAULT_DOMAIN)
+        self.attrs = misc.default_value(attrs, {})
+
+        super().__init__(
+            nodes,
+            inputs,
+            outputs,
+            name=name,
+            doc_string=doc_string,
+            opset=opset,
+            import_domains=import_domains,
+            functions=functions,
+        )
+
+        # Properties of Graph that Function doesn't have.
+        del self.producer_name
+        del self.producer_version
+
+    @property
+    def unique_id(self):
+        """Returns a tuple which uniquely identifies this function."""
+        return (self.domain, self.name)
+
+    def cleanup(
+        self,
+        remove_unused_node_outputs=False,
+        recurse_subgraphs=True,
+        remove_unused_graph_inputs=False,
+        recurse_functions=False,
+    ):
+        """See Graph.cleanup(). The only difference is that 'recurse_functions' defaults to False, so that only this
+        Function is cleaned up.
+        """
+        if recurse_functions:
+            G_LOGGER.warning(
+                "Function.cleanup() called with recurse_functions=True, meaning that other functions will also be cleaned up."
+            )
+        return super().cleanup(
+            remove_unused_node_outputs=remove_unused_node_outputs,
+            recurse_subgraphs=recurse_subgraphs,
+            remove_unused_graph_inputs=remove_unused_graph_inputs,
+            recurse_functions=recurse_functions,
+        )
+
+    def fold_constants(self, recurse_functions=False, **kwargs):
+        """See Graph.fold_constants(). The only difference is that 'recurse_functions' defaults to False, so that only
+        this Function's constants are folded.
+        """
+        if recurse_functions:
+            G_LOGGER.warning(
+                "Function.fold_constants() called with recurse_functions=True, meaning that other functions will also be const-folded."
+            )
+        return super().fold_constants(recurse_functions=recurse_functions, **kwargs)
+
+    def toposort(
+        self,
+        recurse_subgraphs=True,
+        recurse_functions=False,
+        mode="nodes",
+    ):
+        """See Graph.toposort(). The only difference is that 'recurse_functions' defaults to False and mode defaults to
+        "nodes", so that by default only this function's nodes will be sorted.
+        """
+        if recurse_functions:
+            G_LOGGER.warning(
+                "Function.toposort() called with recurse_functions=True, meaning that other functions will be sorted."
+            )
+        return super().toposort(
+            recurse_subgraphs=recurse_subgraphs,
+            recurse_functions=recurse_functions,
+            mode=mode,
+        )
+
+    def __call__(self, graph, inputs=None, outputs=None, *args, **kwargs) -> list[Tensor]:
+        """
+        Creates a Node which is an instance of this function. The created node can be used in a Graph or another
+        Function.
+
+        The provided inputs are processed the same way as in Graph.layer().
+        If outputs are not provided, they are created based on the Function's outputs.
+
+        Args:
+            graph (Union[Graph, Function]): The Graph or Function to add the new node to.
+            inputs (List[Union[Tensor, str, numpy.ndarray]]): The list of inputs.
+            outputs (List[Union[Tensor, str, numpy.ndarray]]): The list of outputs.
+            attrs (Dict[str, Any]): A dictionary of attributes for the node.
+                The attribute names should be a subset of this Function's attribute names.
+            args/kwargs: These are passed directly to the constructor of Node.
+
+        Returns:
+            List[Tensor]: The output tensors of the node.
+        """
+        if inputs is not None and len(inputs) != len(self.inputs):
+            msg_template = "Function {} expects {} inputs, but was called with {} inputs."
+            G_LOGGER.warning(msg_template.format(self.name, len(self.inputs), len(inputs)))
+
+        new_output_indices = []
+        if outputs is None:
+            # Graph.layer() will create Tensors and make sure the names do not conflict.
+            outputs = [out.name for out in self.outputs]
+            new_output_indices = list(range(len(outputs)))
+        elif len(outputs) != len(self.outputs):
+            msg_template = "Function {} expects {} outputs, but was called with {} outputs."
+            G_LOGGER.warning(msg_template.format(self.name, len(self.outputs), len(outputs)))
+        else:
+            new_output_indices = [i for i in range(len(outputs)) if not isinstance(outputs[i], Tensor)]
+
+        attrs = kwargs.get("attrs", None)
+        if attrs is not None:
+            for attr_name, default_val in self.attrs.items():
+                if default_val is None and attr_name not in attrs:
+                    msg_template = "Function {} called without required attribute: {}"
+                    G_LOGGER.warning(msg_template.format(self.name, attr_name))
+
+        inputs = misc.default_value(inputs, [])
+        outputs = misc.default_value(outputs, [])
+        outputs = graph.layer(
+            *args,
+            **kwargs,
+            op=self.name,
+            domain=self.domain,
+            inputs=inputs,
+            outputs=outputs,
+        )
+
+        # For newly created output tensors, set their shape and dtype to match the Function definition.
+        for i in new_output_indices:
+            outputs[i].dtype = self.outputs[i].dtype
+            outputs[i].shape = self.outputs[i].shape
+
+        return outputs
+
+    def copy(self):
+        """
+        Copy the function.
+
+        This makes copies of all nodes and tensors in the function, but will not
+        do a deep-copy of weights or attributes (with the exception of ``Graph``
+        attributes, which will be copied using their ``copy`` method).
+
+        Returns:
+            Function: A copy of the function.
+        """
+        local_tensor_copies = {n: t.copy() for n, t in self.tensors().items()}
+
+        def get_tensor(name):
+            """Retrieve a tensor by name from the copied dictionary of tensors."""
+            return local_tensor_copies[name] if name else Variable.empty()
+
+        # Next, copy nodes, and update inputs/outputs
+        new_nodes = []
+        for node in self.nodes:
+            new_node = node.copy(
+                inputs=[get_tensor(inp.name) for inp in node.inputs],
+                outputs=[get_tensor(out.name) for out in node.outputs],
+                tensor_map=local_tensor_copies,
+            )
+            new_nodes.append(new_node)
+        new_func_inputs = [get_tensor(inp.name) for inp in self.inputs]
+        new_func_outputs = [get_tensor(out.name) for out in self.outputs]
+
+        new_attrs = {name: copy.copy(val) for name, val in self.attrs.items()}
+
+        return Function(
+            self.name,
+            self.domain,
+            nodes=new_nodes,
+            inputs=new_func_inputs,
+            outputs=new_func_outputs,
+            doc_string=self.doc_string,
+            opset=self.opset,
+            import_domains=self.import_domains,
+            functions=self.functions,
+            attrs=new_attrs,
+        )
+
+    def __eq__(self, other: Function):
+        """Checks equality of self with another Function object based on their attributes."""
+
+        def sequences_equal(seq1, seq2):
+            """Checks if two sequences are equal in length and elements."""
+            return len(seq1) == len(seq2) and all(elem1 == elem2 for elem1, elem2 in zip(seq1, seq2))
+
+        return (
+            self.unique_id == other.unique_id
+            and self.opset == other.opset
+            and self.import_domains == other.import_domains
+            and sequences_equal(self.inputs, other.inputs)
+            and sequences_equal(self.outputs, other.outputs)
+            and sequences_equal(self.nodes, other.nodes)
+        )
+
+    def __str__(self):
+        """Returns a string representation of the function including its name, domain, opset, inputs, nodes, and
+        outputs.
+        """
+        nodes_str = "\n".join([str(node) for node in self.nodes])
+        out = f"Function {self.name}, Domain {self.domain}, Opset {self.opset}"
+        out += f"\nInputs: {self.inputs}"
+        out += f"\nNodes: {nodes_str}"
+        out += f"\nOutputs: {self.outputs}"
+        return out
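
Taken together, the class docstring and __call__ describe the instantiation flow: calling a Function appends a Node (via Graph.layer()) whose op and domain match the function's name and domain, and fills in any missing output tensors from the function's declared outputs. The following is a minimal usage sketch, not part of the package; the tensor names, shapes, and the single-Identity body are illustrative assumptions.

import numpy as np

from onnxslim.third_party.onnx_graphsurgeon.ir.function import Function
from onnxslim.third_party.onnx_graphsurgeon.ir.graph import Graph
from onnxslim.third_party.onnx_graphsurgeon.ir.node import Node
from onnxslim.third_party.onnx_graphsurgeon.ir.tensor import Variable

# Function body: a single Identity op (hypothetical example).
x = Variable("x", dtype=np.float32, shape=(1, 3))
y = Variable("y", dtype=np.float32, shape=(1, 3))
body = Node(op="Identity", inputs=[x], outputs=[y])
func = Function("MyIdentity", opset=13, nodes=[body], inputs=[x], outputs=[y])

# Instantiating the function appends a Node with op="MyIdentity" and
# domain="onnx_graphsurgeon" (the DEFAULT_DOMAIN) to the model graph.
inp = Variable("inp", dtype=np.float32, shape=(1, 3))
graph = Graph(inputs=[inp], opset=13, functions=[func])
graph.outputs = func(graph, inputs=[inp])

Because outputs is omitted in the call, __call__ creates the output tensors itself, letting Graph.layer() deduplicate their names, and then copies the dtype and shape from the function definition, as in the loop over new_output_indices above.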