onnxslim 0.1.80__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- onnxslim/__init__.py +16 -0
- onnxslim/__main__.py +4 -0
- onnxslim/argparser.py +215 -0
- onnxslim/cli/__init__.py +1 -0
- onnxslim/cli/_main.py +180 -0
- onnxslim/core/__init__.py +219 -0
- onnxslim/core/optimization/__init__.py +146 -0
- onnxslim/core/optimization/dead_node_elimination.py +151 -0
- onnxslim/core/optimization/subexpression_elimination.py +76 -0
- onnxslim/core/optimization/weight_tying.py +59 -0
- onnxslim/core/pattern/__init__.py +249 -0
- onnxslim/core/pattern/elimination/__init__.py +5 -0
- onnxslim/core/pattern/elimination/concat.py +61 -0
- onnxslim/core/pattern/elimination/reshape.py +77 -0
- onnxslim/core/pattern/elimination/reshape_as.py +64 -0
- onnxslim/core/pattern/elimination/slice.py +108 -0
- onnxslim/core/pattern/elimination/unsqueeze.py +92 -0
- onnxslim/core/pattern/fusion/__init__.py +8 -0
- onnxslim/core/pattern/fusion/concat_reshape.py +50 -0
- onnxslim/core/pattern/fusion/convadd.py +70 -0
- onnxslim/core/pattern/fusion/convbn.py +86 -0
- onnxslim/core/pattern/fusion/convmul.py +69 -0
- onnxslim/core/pattern/fusion/gelu.py +47 -0
- onnxslim/core/pattern/fusion/gemm.py +330 -0
- onnxslim/core/pattern/fusion/padconv.py +89 -0
- onnxslim/core/pattern/fusion/reduce.py +67 -0
- onnxslim/core/pattern/registry.py +28 -0
- onnxslim/misc/__init__.py +0 -0
- onnxslim/misc/tabulate.py +2681 -0
- onnxslim/third_party/__init__.py +0 -0
- onnxslim/third_party/_sympy/__init__.py +0 -0
- onnxslim/third_party/_sympy/functions.py +205 -0
- onnxslim/third_party/_sympy/numbers.py +397 -0
- onnxslim/third_party/_sympy/printers.py +491 -0
- onnxslim/third_party/_sympy/solve.py +172 -0
- onnxslim/third_party/_sympy/symbol.py +102 -0
- onnxslim/third_party/onnx_graphsurgeon/__init__.py +15 -0
- onnxslim/third_party/onnx_graphsurgeon/exporters/__init__.py +1 -0
- onnxslim/third_party/onnx_graphsurgeon/exporters/base_exporter.py +33 -0
- onnxslim/third_party/onnx_graphsurgeon/exporters/onnx_exporter.py +432 -0
- onnxslim/third_party/onnx_graphsurgeon/graph_pattern/__init__.py +4 -0
- onnxslim/third_party/onnx_graphsurgeon/graph_pattern/graph_pattern.py +466 -0
- onnxslim/third_party/onnx_graphsurgeon/importers/__init__.py +1 -0
- onnxslim/third_party/onnx_graphsurgeon/importers/base_importer.py +33 -0
- onnxslim/third_party/onnx_graphsurgeon/importers/onnx_importer.py +558 -0
- onnxslim/third_party/onnx_graphsurgeon/ir/__init__.py +0 -0
- onnxslim/third_party/onnx_graphsurgeon/ir/function.py +274 -0
- onnxslim/third_party/onnx_graphsurgeon/ir/graph.py +1575 -0
- onnxslim/third_party/onnx_graphsurgeon/ir/node.py +266 -0
- onnxslim/third_party/onnx_graphsurgeon/ir/tensor.py +504 -0
- onnxslim/third_party/onnx_graphsurgeon/logger/__init__.py +1 -0
- onnxslim/third_party/onnx_graphsurgeon/logger/logger.py +261 -0
- onnxslim/third_party/onnx_graphsurgeon/util/__init__.py +0 -0
- onnxslim/third_party/onnx_graphsurgeon/util/exception.py +20 -0
- onnxslim/third_party/onnx_graphsurgeon/util/misc.py +252 -0
- onnxslim/third_party/symbolic_shape_infer.py +3273 -0
- onnxslim/utils.py +794 -0
- onnxslim/version.py +1 -0
- onnxslim-0.1.80.dist-info/METADATA +207 -0
- onnxslim-0.1.80.dist-info/RECORD +65 -0
- onnxslim-0.1.80.dist-info/WHEEL +5 -0
- onnxslim-0.1.80.dist-info/entry_points.txt +2 -0
- onnxslim-0.1.80.dist-info/licenses/LICENSE +21 -0
- onnxslim-0.1.80.dist-info/top_level.txt +1 -0
- onnxslim-0.1.80.dist-info/zip-safe +1 -0
|
@@ -0,0 +1,1575 @@
|
|
|
1
|
+
#
|
|
2
|
+
# SPDX-FileCopyrightText: Copyright (c) 1993-2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
|
3
|
+
# SPDX-License-Identifier: Apache-2.0
|
|
4
|
+
#
|
|
5
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
6
|
+
# you may not use this file except in compliance with the License.
|
|
7
|
+
# You may obtain a copy of the License at
|
|
8
|
+
#
|
|
9
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
|
10
|
+
#
|
|
11
|
+
# Unless required by applicable law or agreed to in writing, software
|
|
12
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
13
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
14
|
+
# See the License for the specific language governing permissions and
|
|
15
|
+
# limitations under the License.
|
|
16
|
+
#
|
|
17
|
+
from __future__ import annotations
|
|
18
|
+
|
|
19
|
+
import copy
|
|
20
|
+
import numbers
|
|
21
|
+
from collections import OrderedDict, defaultdict
|
|
22
|
+
from collections.abc import Sequence
|
|
23
|
+
|
|
24
|
+
import numpy as np
|
|
25
|
+
|
|
26
|
+
from onnxslim.third_party.onnx_graphsurgeon.ir.node import Node
|
|
27
|
+
from onnxslim.third_party.onnx_graphsurgeon.ir.tensor import Constant, Tensor, Variable
|
|
28
|
+
from onnxslim.third_party.onnx_graphsurgeon.logger import G_LOGGER, LogMode
|
|
29
|
+
from onnxslim.third_party.onnx_graphsurgeon.util import misc
|
|
30
|
+
|
|
31
|
+
|
|
32
|
+
class NodeIDAdder:
    """Context manager that temporarily tags every node in a graph with an `id` attribute."""

    def __init__(self, graph):
        """Store the graph whose nodes will be tagged."""
        self.graph = graph

    def __enter__(self):
        """Attach a unique integer `id` to each node when the context is entered."""
        # The position in the node list is used so that the same object appearing
        # more than once still counts as distinct nodes.
        for position, graph_node in enumerate(self.graph.nodes):
            graph_node.id = position

    def __exit__(self, exc_type, exc_value, traceback):
        """Strip the temporary `id` attribute from each node when the context exits."""
        for graph_node in self.graph.nodes:
            del graph_node.id
|
|
47
|
+
|
|
48
|
+
|
|
49
|
+
class Graph:
|
|
50
|
+
"""Represents a graph containing nodes and tensors."""
|
|
51
|
+
|
|
52
|
+
# Opset assumed when a Graph is constructed without an explicit `opset` argument.
DEFAULT_OPSET = 11
# Maps opset number -> {function name -> function} for ops registered only for specific opsets.
OPSET_FUNC_MAP = defaultdict(dict)  # Ops registered for specific opsets.
# Maps function name -> function for ops registered for ALL opsets.
GLOBAL_FUNC_MAP = {}  # Ops registered for ALL opsets.
|
|
55
|
+
|
|
56
|
+
@staticmethod
def register(opsets=None):
    """
    Decorator that registers a function on the Graph class for a group of opsets. Once
    registered, the function can be invoked like an ordinary member function.

    For example:
    ::

        @Graph.register()
        def add(self, a, b):
            '''Registers a function with the Graph class for the specified group of opsets for dynamic access as a member function.'''
            return self.layer(op="Add", inputs=[a, b], outputs=["add_out_gs"])

        graph.add(a, b)

    Args:
        opsets (Sequence[int]):
            A group of opsets for which to register the function. Multiple functions with the same
            name may be registered simultaneously if they are registered for different opsets.
            Registering a function with a duplicate name for the same opsets will overwrite any
            function previously registered for those opsets. By default, the function is
            registered for all opsets.
    """

    def decorator(fn):
        """Record `fn` in the appropriate registry and hand it back unchanged."""
        fn_name = fn.__name__
        # A real Graph attribute always wins over registered functions (see __getattr__),
        # so warn when the registration can never be reached.
        if hasattr(Graph, fn_name):
            G_LOGGER.warning(
                f"Registered function: {fn_name} is hidden by a Graph attribute or function with the same name. "
                "This function will never be called!"
            )

        if opsets is None:
            # No opsets given: the function is available for every opset.
            Graph.GLOBAL_FUNC_MAP[fn_name] = fn
        else:
            for target_opset in opsets:
                Graph.OPSET_FUNC_MAP[target_opset][fn_name] = fn
        return fn

    return decorator
|
|
100
|
+
|
|
101
|
+
def __init__(
    self,
    nodes: Sequence[Node] | None = None,
    inputs: Sequence[Tensor] | None = None,
    outputs: Sequence[Tensor] | None = None,
    name=None,
    doc_string=None,
    opset=None,
    import_domains=None,
    ir_version=None,
    producer_name: str | None = None,
    producer_version: str | None = None,
    functions: Sequence[Function] | None = None,
    metadata_props=None,
):
    """
    Args:
        nodes (Sequence[Node]): A list of the nodes in this graph.
        inputs (Sequence[Tensor]): A list of graph input Tensors.
        outputs (Sequence[Tensor]): A list of graph output Tensors.
        name (str): The name of the graph. Defaults to "onnx_graphsurgeon_graph".
        doc_string (str): A doc_string for the graph. Defaults to "".
        opset (int): The ONNX opset to use when exporting this graph. Defaults to Graph.DEFAULT_OPSET.
        import_domains: Opset import entries for the model; stored as given (may be None).
        ir_version: The ONNX IR version to record for the model; stored as given (may be None).
        producer_name (str): The name of the tool used to generate the model. Defaults to "".
        producer_version (str): The version of the generating tool. Defaults to "".
        functions (Sequence[Function]): Local functions shared across the model; stored in a new list.
        metadata_props: Model metadata properties; stored as given (may be None).
    """
    self.nodes = misc.default_value(nodes, [])
    # inputs/outputs are materialized as lists; __setattr__ also enforces this on later assignment.
    self.inputs = list(misc.default_value(inputs, []))
    self.outputs = list(misc.default_value(outputs, []))

    self.name = misc.default_value(name, "onnx_graphsurgeon_graph")
    self.__name__ = self.name

    self.doc_string = misc.default_value(doc_string, "")
    self.opset = misc.default_value(opset, Graph.DEFAULT_OPSET)
    self.producer_name = misc.default_value(producer_name, "")
    self.producer_version = misc.default_value(producer_version, "")
    self.metadata_props = metadata_props
    self.import_domains = import_domains
    self.ir_version = ir_version
    # Counter used by the layer() function to generate unique tensor/node names.
    self.name_idx = 0

    # In ONNX, the same list of Functions is shared between all Graphs & Functions in a model.
    # Protect the list object with an underscore as self._functions.
    # Users should access/modify/set this list via graph.functions.
    self._functions = list(misc.default_value(functions, []))
    self._merge_subgraph_functions()

    # Printing graphs can be very expensive, so defer formatting via a lambda.
    G_LOGGER.ultra_verbose(lambda: f"Created Graph: {self}")
|
|
152
|
+
|
|
153
|
+
def __getattr__(self, name):
    """Dynamically resolve attribute access.

    Real attributes are returned as usual. Otherwise, the name is looked up among
    registered functions and local functions, in priority order:
    1. Functions registered for this graph's specific opset.
    2. Functions registered globally (for all opsets).
    3. Local Functions of the model whose name matches.
    If multiple candidates match, the highest-priority one is used (with a warning).
    If none match, the original AttributeError is re-raised.
    """
    try:
        return super().__getattribute__(name)
    except AttributeError as err:
        # Collect every candidate so we can warn the user if the name is ambiguous.
        methods = []
        method_descs = []

        # Opset specific ops always take priority over global ops.
        if self.opset in Graph.OPSET_FUNC_MAP and name in Graph.OPSET_FUNC_MAP[self.opset]:
            methods.append(Graph.OPSET_FUNC_MAP[self.opset][name])
            method_descs.append(f'GraphSurgeon-registered function "{name}" with opset {self.opset}')

        # Registered ops take priority over Local Functions.
        if name in Graph.GLOBAL_FUNC_MAP:
            methods.append(Graph.GLOBAL_FUNC_MAP[name])
            method_descs.append(f'Local Function lookup happens last; see below.' if False else f'GraphSurgeon-registered function "{name}"')

        for func in self.functions:
            if func.name == name:
                methods.append(func.__call__)
                method_descs.append(f'Local Function "{func.name}" with domain "{func.domain}"')

        if methods:
            if len(methods) > 1:
                msg_template = (
                    "Method name {} is overloaded with the following candidates: {}. " + "Choosing candidate {}"
                )
                G_LOGGER.warning(
                    message=msg_template.format(name, method_descs, method_descs[0]),
                    mode=LogMode.ONCE,
                )
            # Bind the graph as the first argument, mimicking a bound method.
            return lambda *args, **kwargs: methods[0](self, *args, **kwargs)

        # Nothing matched: report whether the name exists under some other opset to aid debugging.
        found_in_other_opsets = {opset for opset, opset_map in Graph.OPSET_FUNC_MAP.items() if name in opset_map}

        G_LOGGER.error(
            f"Function: '{name}' was not registered for opset {self.opset}. "
            + (
                f"Note: '{name}' was registered for opsets: {found_in_other_opsets}."
                if found_in_other_opsets
                else ""
            )
        )
        raise err
|
|
199
|
+
|
|
200
|
+
def __setattr__(self, name, value):
    """Assign `value` to attribute `name`, materializing 'inputs'/'outputs' as plain lists."""
    # Graph inputs/outputs must always be stored as real lists, whatever iterable was given.
    if name == "inputs" or name == "outputs":
        value = list(value)
    return super().__setattr__(name, value)
|
|
205
|
+
|
|
206
|
+
@property
def functions(self) -> list[Function]:
    """Returns the list of subgraph functions associated with this graph.

    Note: this list object is shared between all Graphs & Functions in the model
    (see __init__); mutate it in place or assign through the setter so the change
    propagates everywhere.
    """
    return self._functions
|
|
210
|
+
|
|
211
|
+
@functions.setter
def functions(self, new_fns: Sequence[Function]):
    """Replace the contents of the shared function list with `new_fns`."""
    # The underlying list object is shared by this graph, its subgraphs, and its
    # functions, so it must be mutated in place rather than rebound to a new list —
    # that way every holder of the list sees the new value.
    self._functions[:] = list(new_fns)
|
|
219
|
+
|
|
220
|
+
def __eq__(self, other: Graph):
    """Two graphs are equal when their nodes, inputs, outputs, opset, and import domains all match."""
    # Compare the three sequences in order, bailing out on the first mismatch.
    sequence_pairs = (
        (self.nodes, other.nodes),
        (self.inputs, other.inputs),
        (self.outputs, other.outputs),
    )
    for mine, theirs in sequence_pairs:
        if not misc.sequences_equal(mine, theirs):
            return False

    return self.opset == other.opset and self.import_domains == other.import_domains
|
|
233
|
+
|
|
234
|
+
def node_ids(self):
    """
    Supply a context manager under which every Node in this Graph carries a unique
    integer `id` attribute.

    For example:
    ::

        with graph.node_ids():
            assert graph.nodes[0].id != graph.nodes[1].id

    Returns:
        NodeIDAdder: A context manager that supplies unique integer IDs for Nodes.
    """
    return NodeIDAdder(self)
|
|
248
|
+
|
|
249
|
+
# Gets the node ID for a node. All internal code should use this instead of accessing `node.id` directly.
|
|
250
|
+
def _get_node_id(self, node):
|
|
251
|
+
"""Gets the node ID for a node, ensuring all internal code uses this instead of directly accessing `node.id`."""
|
|
252
|
+
try:
|
|
253
|
+
return node.id
|
|
254
|
+
except AttributeError:
|
|
255
|
+
G_LOGGER.critical(
|
|
256
|
+
f"Encountered a node not in the graph:\n{node}.\n\n"
|
|
257
|
+
"To fix this, please append the node to this graph's `nodes` attribute."
|
|
258
|
+
)
|
|
259
|
+
|
|
260
|
+
# A tensor is local if it is produced in this graph, or is explicitly a graph input.
def _local_tensors(self):
    """Map tensor name -> tensor for every tensor local to this graph: non-empty node
    outputs, explicit graph inputs, and constants.
    """
    local_map = {}
    # Non-empty outputs of this graph's nodes are produced locally.
    for node in self.nodes:
        for out in node.outputs:
            if not out.is_empty():
                local_map[out.name] = out
    # Explicit graph inputs are local by definition.
    for inp in self.inputs:
        local_map[inp.name] = inp
    # Constants are always local to the graph that owns them.
    for tensor in self.tensors().values():
        if isinstance(tensor, Constant):
            local_map[tensor.name] = tensor
    return local_map
|
|
269
|
+
|
|
270
|
+
# Returns tensors used by this graph which are not present in the graph.
# These may come from an outer graph for example.
def _foreign_tensors(self):
    """Returns tensors used by this graph which are not present in the graph, potentially from an outer graph.

    Returns:
        Dict[str, Tensor]: Mapping of tensor name to tensor for every foreign tensor.
    """
    local_tensors = self._local_tensors()
    foreign_tensors = {}

    def is_foreign_tensor(tensor):
        """Check if a tensor is foreign by verifying its absence in local tensors."""
        return tensor.name not in local_tensors

    for node in self.nodes:
        foreign_tensors.update({t.name: t for t in node.inputs if is_foreign_tensor(t)})

        # Recurse into control-flow subgraphs: their foreign tensors may be foreign
        # to this graph too.
        for subgraph in node.subgraphs():
            subgraph_foreign_tensors = subgraph._foreign_tensors()
            # Some of the foreign tensors from a subgraph may come from this graph.
            subgraph_foreign_tensors = {
                t.name: t for t in subgraph_foreign_tensors.values() if is_foreign_tensor(t)
            }
            foreign_tensors.update(subgraph_foreign_tensors)

    return foreign_tensors
|
|
293
|
+
|
|
294
|
+
def _get_used_node_ids(self):
    """Compute which nodes and tensors contribute to this graph's outputs.

    Traverses backwards from the graph outputs through producer nodes.
    Must be called inside a `node_ids()` context so `_get_node_id` works.

    Returns:
        Tuple[Set[int], List[Tensor]]: The set of used node IDs and the list of used tensors.
    """
    local_tensors = self._local_tensors()

    # Stateful filter: keeps empty tensors (optional/omitted ones), drops foreign
    # tensors, and drops tensors already seen so the worklist never revisits them.
    class IgnoreDupAndForeign:
        def __init__(self, initial_tensors=None):
            """Initialize IgnoreDupAndForeign with an optional list of initial tensors."""
            tensors = misc.default_value(initial_tensors, [])
            self.seen_tensors = {tensor.name for tensor in tensors}

        def __call__(self, tensor):
            """Return True if the tensor should be kept, False if it should be filtered out."""
            # False if it should be filtered out.
            if tensor.is_empty():
                return True
            elif tensor.name not in local_tensors:
                return False
            elif tensor.name not in self.seen_tensors:
                self.seen_tensors.add(tensor.name)
                return True
            return False

    # Traverse backwards from outputs to find all used nodes.
    ignore_tensors = IgnoreDupAndForeign()
    used_tensors = list(filter(ignore_tensors, self.outputs))
    used_node_ids = set()

    # Worklist traversal: `used_tensors` grows while we walk it by index.
    index = 0
    while index < len(used_tensors):
        used_tensor = used_tensors[index]
        index += 1
        for node in used_tensor.inputs:
            # Must cast to list here, otherwise node_used_tensors will be SynchronizedList!
            node_used_tensors = list(node.inputs)

            # If a node includes a subgraph, get any tensors that it uses from the outer graph.
            for subgraph in node.subgraphs():
                node_used_tensors += list(subgraph._foreign_tensors().values())

            used_node_ids.add(self._get_node_id(node))
            used_tensors.extend(filter(ignore_tensors, node_used_tensors))
    return used_node_ids, used_tensors
|
|
336
|
+
|
|
337
|
+
def _merge_subgraph_functions(self):
    """Merge function lists of subgraphs into the parent graph's function list.

    Subgraphs and local functions may each carry a different function list than the
    parent graph; after this call, all of them share one merged list.
    """
    func_ids = {func.unique_id for func in self.functions}

    def absorb_function_list(func_list):
        """Absorb and merge unique functions from a provided function list into the parent graph's function list."""
        for func in func_list:
            if func.unique_id not in func_ids:
                self.functions.append(func)
                func_ids.add(func.unique_id)
        return self.functions

    # Absorb the function lists of every subgraph (of this graph and of each local
    # function), then point each subgraph at the merged list.
    for graph in [*self.functions, self]:
        for subgraph in graph.subgraphs(recursive=True):
            new_list = absorb_function_list(subgraph.functions)
            subgraph._functions = new_list

    # Finally, make every local function share the merged list as well.
    for func in self.functions:
        func._functions = absorb_function_list(func.functions)
|
|
357
|
+
|
|
358
|
+
def subgraphs(self, recursive=False):
    """
    Iterate over every subgraph contained in this graph. Subgraphs are found in the
    attributes of ONNX control flow nodes such as 'If' and 'Loop'.

    Args:
        recursive (bool): Whether to recursively search this graph's subgraphs for more subgraphs. Defaults to False.

    Returns:
        A generator which iterates over the subgraphs contained in this graph.
    """
    for owner_node in self.nodes:
        for nested_graph in owner_node.subgraphs(recursive=recursive):
            yield nested_graph
|
|
371
|
+
|
|
372
|
+
def cleanup(
    self,
    remove_unused_node_outputs=False,
    recurse_subgraphs=True,
    remove_unused_graph_inputs=False,
    recurse_functions=True,
):
    """
    Removes unused nodes and tensors from the graph. A node or tensor is considered unused if it does not contribute
    to any of the graph outputs.

    Additionally, any producer nodes of graph input tensors, as well as consumer nodes of graph output
    tensors that are not in the graph, are removed from the graph.

    *Note: This function will never modify graph output tensors.*

    Args:
        remove_unused_node_outputs (bool): Whether to remove unused output tensors of nodes. This will never remove
            empty-tensor (i.e. optional, but omitted) outputs. Defaults to False.
        recurse_subgraphs (bool):
            Whether to recursively cleanup subgraphs.
            Defaults to True.
        remove_unused_graph_inputs (bool):
            Whether to remove unused graph inputs.
            Defaults to False.
        recurse_functions (bool):
            Whether to also clean up this graph's local functions.
            Defaults to True.

    Returns:
        self
    """

    def cleanup_subgraphs():
        """Clean up each immediate subgraph with the same node-output settings; input
        removal and function cleanup are handled only at the top level.
        """
        for subgraph in self.subgraphs():
            subgraph.cleanup(
                remove_unused_node_outputs=remove_unused_node_outputs,
                recurse_subgraphs=recurse_subgraphs,
                remove_unused_graph_inputs=False,
                recurse_functions=False,  # Only cleanup functions once
            )

    if recurse_subgraphs:
        cleanup_subgraphs()

    if recurse_functions:
        for func in self.functions:
            func.cleanup(
                remove_unused_node_outputs=remove_unused_node_outputs,
                recurse_subgraphs=recurse_subgraphs,
                remove_unused_graph_inputs=remove_unused_graph_inputs,
                recurse_functions=False,  # No infinite recursion
            )

    G_LOGGER.verbose(f"Cleaning up {self.name}")

    with self.node_ids():
        # Graph input producers must be removed first so used_node_ids is correct.
        for inp in self.inputs:
            inp.inputs.clear()

        used_node_ids, used_tensors = self._get_used_node_ids()

        # Keep an input if it is used, or if input pruning is disabled.
        inputs = []
        for inp in self.inputs:
            if inp in used_tensors or not remove_unused_graph_inputs:
                inputs.append(inp)
            else:
                G_LOGGER.debug(f"Removing unused input: {inp}")
        self.inputs = inputs

        nodes = []

        # Keep only nodes that contribute to the outputs; detach the rest so their
        # tensors no longer reference them.
        for node in self.nodes:
            if self._get_node_id(node) in used_node_ids:
                nodes.append(node)
            else:
                node.inputs.clear()
                node.outputs.clear()
                G_LOGGER.ultra_verbose(f"Removing unused node: {node}")

        # Remove any hanging tensors - tensors without outputs
        if remove_unused_node_outputs:
            graph_output_names = {tensor.name for tensor in self.outputs}
            for node in nodes:

                def is_hanging_tensor(tensor):
                    """A tensor hangs if it is non-empty, has no consumers, and is not
                    a graph output.
                    """
                    return (
                        not tensor.is_empty() and len(tensor.outputs) == 0 and tensor.name not in graph_output_names
                    )

                # Collect first, then remove, to avoid mutating while iterating.
                to_remove = [out for out in node.outputs if is_hanging_tensor(out)]
                for out in to_remove:
                    if out in node.outputs:
                        node.outputs.remove(out)

        self.nodes = nodes

    return self
|
|
477
|
+
|
|
478
|
+
def toposort(
    self,
    recurse_subgraphs=True,
    recurse_functions=True,
    mode="full",
):
    """
    Topologically sort the graph in place.

    Args:
        recurse_subgraphs (bool):
            Whether to recursively topologically sort subgraphs.
            Only applicable when mode="full" or mode="nodes".
            Defaults to True.
        recurse_functions (bool):
            Whether to topologically sort the nodes of this graph's functions.
            Only applicable when mode="full" or mode="nodes".
            Defaults to True.
        mode (str):
            Whether to reorder this graph's list of nodes, list of functions, or both.
            Possible values:
            - "full": Topologically sort the list of nodes and the list of functions.
            - "nodes": Only sort the list of nodes.
            - "functions": Only sort the list of functions.
            Defaults to "full".

    Returns:
        self
    """
    ALLOWED_MODES = ["full", "nodes", "functions"]
    if mode not in ALLOWED_MODES:
        G_LOGGER.critical(f'Mode "{mode}" not in {ALLOWED_MODES}')

    sort_nodes = mode in {"full", "nodes"}
    sort_functions = mode in {"full", "functions"}

    if sort_nodes and recurse_functions:
        for func in self.functions:
            func.toposort(recurse_subgraphs=recurse_subgraphs, mode="nodes")

    if sort_nodes and recurse_subgraphs:
        for subgraph in self.subgraphs():
            subgraph.toposort(recurse_subgraphs=True, recurse_functions=False, mode="nodes")

    G_LOGGER.debug(f"Topologically sorting {self.name}")

    # Keeps track of a node and its level in the graph hierarchy.
    # 0 corresponds to an input node, N corresponds to a node with N layers of inputs.
    class HierarchyDescriptor:
        def __init__(self, node_or_func, level=None):
            """Pair a node (or function) with its computed level in the hierarchy."""
            self.node_or_func = node_or_func
            self.level = level

        def __lt__(self, other):
            """Order descriptors by hierarchy level so sorted() yields topological order."""
            return self.level < other.level

    hierarchy_levels = {}  # Dict[int, HierarchyDescriptor]

    local_tensors = self._local_tensors()
    # Populated only when sorting functions; maps (domain, op) -> Function.
    func_id_to_func = {}

    def get_id(node_or_func):
        """Returns the unique ID for a Node object or a function."""
        if isinstance(node_or_func, Node):
            return self._get_node_id(node_or_func)
        return node_or_func.unique_id

    def get_hierarchy_level(node_or_func, visited=None):
        """Recursively compute the hierarchy level of a node or function, memoizing
        results in `hierarchy_levels`. `visited` tracks the active recursion path.
        """
        visited = misc.default_value(visited, set())
        visited.add(get_id(node_or_func))

        def get_inputs(node_or_func):
            """Find all nodes (or functions) used by a given node or function."""

            def get_used_nodes(node):
                """Find all nodes that are used as inputs by a given node."""
                inputs = {}

                def add_local_producers(tensor):
                    """Add the producer nodes of a locally-produced tensor to `inputs`."""
                    nonlocal inputs
                    if tensor.name in local_tensors:
                        for inp_node in tensor.inputs:
                            inputs[self._get_node_id(inp_node)] = inp_node

                for tensor in node.inputs:
                    add_local_producers(tensor)

                # If a node includes a subgraph, get any tensors that it uses from the outer graph.
                for subgraph in node.subgraphs():
                    for tensor in subgraph._foreign_tensors().values():
                        add_local_producers(tensor)

                return inputs.values()

            # Find all functions used in this list of nodes.
            def get_used_funcs(nodes):
                """Return a dictionary of functions used in the provided list of nodes."""
                inputs = {}
                for subgraph in self.subgraphs():
                    inputs.update(get_used_funcs(subgraph.nodes))
                for node in nodes:
                    func_id = (node.domain, node.op)
                    if func_id in func_id_to_func:
                        inputs[func_id] = func_id_to_func[func_id]
                return inputs

            if isinstance(node_or_func, Node):
                inputs = get_used_nodes(node_or_func)
            else:
                inputs = get_used_funcs(node_or_func.nodes).values()
            return inputs

        # Memoized result from an earlier traversal.
        if get_id(node_or_func) in hierarchy_levels:
            return hierarchy_levels[get_id(node_or_func)].level

        # The level of a node is the level of its highest input + 1.
        max_input_level = max(
            [get_hierarchy_level(inp, visited=visited) for inp in get_inputs(node_or_func)] + [-1]
        )
        visited.remove(get_id(node_or_func))

        hierarchy_levels[get_id(node_or_func)] = HierarchyDescriptor(node_or_func, level=max_input_level + 1)
        return max_input_level + 1

    if sort_nodes:
        # Node IDs are required by get_id() during level computation.
        with self.node_ids():
            for node in self.nodes:
                hierarchy_levels[get_id(node)] = HierarchyDescriptor(node, level=get_hierarchy_level(node))
            self.nodes = [hd.node_or_func for hd in sorted(hierarchy_levels.values())]

    if sort_functions:
        self._merge_subgraph_functions()
        func_id_to_func.update({func.unique_id: func for func in self.functions})
        hierarchy_levels.clear()
        for func in self.functions:
            hierarchy_levels[func.unique_id] = HierarchyDescriptor(func, level=get_hierarchy_level(func))
        self.functions = [hd.node_or_func for hd in sorted(hierarchy_levels.values())]

    return self
|
|
623
|
+
|
|
624
|
+
def tensors(self, check_duplicates=False):
    """
    Build a name -> tensor mapping of every tensor used by this graph by walking all
    nodes and the graph I/O. Empty tensors are omitted from this map.

    Tensors are guaranteed to be in order of the nodes in the graph. Hence, if the graph is topologically sorted, the tensor map will be too.

    Args:
        check_duplicates (bool): Whether to fail if multiple tensors with the same name are encountered.

    Raises:
        OnnxGraphSurgeonException: If check_duplicates is True and multiple distinct tensors in the graph share the same name.

    Returns:
        OrderedDict[str, Tensor]: A mapping of tensor names to tensors.
    """
    tensor_map = OrderedDict()

    def record(tensor):
        """Insert `tensor` into the map, flagging name collisions between distinct objects."""
        if tensor.is_empty():
            return
        existing = tensor_map.get(tensor.name)
        if existing is not None and existing is not tensor:
            msg = f"Found distinct tensors that share the same name:\n[id: {id(existing)}] {existing}\n[id: {id(tensor)}] {tensor}\n"
            msg += f"Note: Producer node(s) of first tensor:\n{existing.inputs}\nProducer node(s) of second tensor:\n{tensor.inputs}"

            if check_duplicates:
                G_LOGGER.critical(msg)
            # G_LOGGER.warning(msg)

        tensor_map[tensor.name] = tensor

    # I/O tensors may not be attached to nodes, so walk them explicitly.
    for graph_input in self.inputs:
        record(graph_input)

    for node in self.nodes:
        for tensor in node.inputs + node.outputs:
            record(tensor)

    for graph_output in self.outputs:
        record(graph_output)

    return tensor_map
|
|
667
|
+
|
|
668
|
+
def fold_constants(
|
|
669
|
+
self,
|
|
670
|
+
fold_shapes=True,
|
|
671
|
+
recurse_subgraphs=True,
|
|
672
|
+
partitioning=None,
|
|
673
|
+
error_ok=True,
|
|
674
|
+
flatten_subgraphs=True,
|
|
675
|
+
size_threshold=None,
|
|
676
|
+
should_exclude_node=None,
|
|
677
|
+
recurse_functions=True,
|
|
678
|
+
):
|
|
679
|
+
"""
|
|
680
|
+
Folds constants in-place in the graph. The graph's nodes and functions must be topologically sorted prior to
|
|
681
|
+
calling this function (see `toposort()`).
|
|
682
|
+
|
|
683
|
+
This function will not remove constants after folding them. In order to get rid of
|
|
684
|
+
these hanging nodes, you can run the `cleanup()` function.
|
|
685
|
+
|
|
686
|
+
*Note: Due to how this function is implemented, the graph must be exportable to ONNX,
|
|
687
|
+
and evaluable in ONNX-Runtime. Additionally, ONNX-Runtime must be installed.*
|
|
688
|
+
|
|
689
|
+
Args:
|
|
690
|
+
fold_shapes (bool):
|
|
691
|
+
Whether to fold `Shape` nodes in the graph.
|
|
692
|
+
This requires shapes to be inferred in the graph, and can only fold
|
|
693
|
+
static shapes.
|
|
694
|
+
Defaults to True.
|
|
695
|
+
recurse_subgraphs (bool):
|
|
696
|
+
Whether to recursively fold constants in subgraphs.
|
|
697
|
+
Defaults to True.
|
|
698
|
+
partitioning (Union[str, None]):
|
|
699
|
+
Whether/How to partition the graph so that errors in folding one
|
|
700
|
+
part of a model do not affect other parts. Available modes are:
|
|
701
|
+
|
|
702
|
+
- None: Do not partition the graph. If inference fails, no constants are folded.
|
|
703
|
+
- "basic": Partition the graph. If inference fails in one partition, other partitions will
|
|
704
|
+
remain unaffected.
|
|
705
|
+
- "recursive": Partition the graph recursively. If inference fails in a partition, the partition
|
|
706
|
+
will be further partitioned.
|
|
707
|
+
|
|
708
|
+
Defaults to None.
|
|
709
|
+
error_ok (bool):
|
|
710
|
+
Whether inference errors should be suppressed.
|
|
711
|
+
When this is False, any errors encountered during inference will be re-raised.
|
|
712
|
+
Defaults to True.
|
|
713
|
+
flatten_subgraphs (bool):
|
|
714
|
+
Whether to flatten subgraphs where possible. For example, `If` nodes with a constant condition
|
|
715
|
+
can be flattened into the parent graph.
|
|
716
|
+
size_threshold (int):
|
|
717
|
+
The maximum size threshold, in bytes, for which to fold constants.
|
|
718
|
+
Any tensors larger than this value will not be folded.
|
|
719
|
+
Set to ``None`` to disable the size threshold and always fold constants.
|
|
720
|
+
For example, some models may apply ops like `Tile` or `Expand` to constants, which can
|
|
721
|
+
result in very large tensors. Rather than pre-computing those constants and bloating
|
|
722
|
+
the model size, it may be desirable to skip folding them and allow them to be computed
|
|
723
|
+
at runtime.
|
|
724
|
+
Defaults to None.
|
|
725
|
+
should_exclude_node (Callable[[gs.Node], bool]):
|
|
726
|
+
A callable that accepts an onnx-graphsurgeon node from the graph and reports whether it should
|
|
727
|
+
be excluded from folding. This is only called for nodes which are otherwise foldable.
|
|
728
|
+
Note that preventing a node from being folded also prevents its consumers from being folded.
|
|
729
|
+
Defaults to a callable that always returns False.
|
|
730
|
+
recurse_functions (bool):
|
|
731
|
+
Whether to fold constants in this graph's Functions.
|
|
732
|
+
Defaults to True.
|
|
733
|
+
|
|
734
|
+
Returns:
|
|
735
|
+
self
|
|
736
|
+
"""
|
|
737
|
+
from onnxslim.third_party.onnx_graphsurgeon.exporters.onnx_exporter import (
|
|
738
|
+
dtype_to_onnx,
|
|
739
|
+
export_onnx,
|
|
740
|
+
)
|
|
741
|
+
|
|
742
|
+
custom_should_exclude_node = misc.default_value(should_exclude_node, lambda node: False)
|
|
743
|
+
|
|
744
|
+
# Don't fold nodes with attribute values which are variable.
|
|
745
|
+
def should_exclude_node(node):
|
|
746
|
+
"""Determine if an ONNX graph node should be excluded based on its attributes."""
|
|
747
|
+
for attr_val in node.attrs.values():
|
|
748
|
+
if isinstance(attr_val, Node.AttributeRef):
|
|
749
|
+
return True
|
|
750
|
+
return custom_should_exclude_node(node)
|
|
751
|
+
|
|
752
|
+
PARTITIONING_MODES = [None, "basic", "recursive"]
|
|
753
|
+
if partitioning not in PARTITIONING_MODES:
|
|
754
|
+
G_LOGGER.critical(f"Argument for parameter 'partitioning' must be one of: {PARTITIONING_MODES}")
|
|
755
|
+
ORT_PROVIDERS = ["CPUExecutionProvider"]
|
|
756
|
+
|
|
757
|
+
G_LOGGER.debug(f"Folding constants in {self.name}")
|
|
758
|
+
|
|
759
|
+
# We apply constant folding in 5 passes:
|
|
760
|
+
# Pass 1 lowers 'Constant' nodes into Constant tensors.
|
|
761
|
+
# Pass 2 elides casts applied to shape tensors. This is done separately from other shape folding
|
|
762
|
+
# since it operates on the original graph rather than a clone.
|
|
763
|
+
# Pass 3 finds all Constant tensors in the graph, then finds all descendants which are dependent
|
|
764
|
+
# only on constants.
|
|
765
|
+
# Pass 4 searches for Shape nodes that have variable inputs (i.e. not marked const in pass 1)
|
|
766
|
+
# and turns them into Constants iff the input has a statically known shape.
|
|
767
|
+
# Pass 5 computes the descendants determined in Pass 3 using ONNX-Runtime and replaces them in the graph.
|
|
768
|
+
|
|
769
|
+
# Pass 1: Lower constant nodes
|
|
770
|
+
for tensor in self.tensors().values():
|
|
771
|
+
if len(tensor.inputs) == 1:
|
|
772
|
+
node = tensor.inputs[0]
|
|
773
|
+
if node.op == "Constant" and tensor.outputs:
|
|
774
|
+
if len(node.attrs) != 1:
|
|
775
|
+
G_LOGGER.warning("Constant node must contain exactly one attribute")
|
|
776
|
+
continue
|
|
777
|
+
attr_name, attr_val = next(iter(node.attrs.items()))
|
|
778
|
+
allowed_attrs = {
|
|
779
|
+
"value",
|
|
780
|
+
"value_float",
|
|
781
|
+
"value_floats",
|
|
782
|
+
"value_int",
|
|
783
|
+
"value_ints",
|
|
784
|
+
}
|
|
785
|
+
if attr_name not in allowed_attrs:
|
|
786
|
+
G_LOGGER.warning(f"Unsupported attribute for Constant node: {attr_name}")
|
|
787
|
+
continue
|
|
788
|
+
if isinstance(attr_val, Node.AttributeRef):
|
|
789
|
+
continue
|
|
790
|
+
elif isinstance(attr_val, Constant):
|
|
791
|
+
arr = attr_val._values # Using ._values avoids copying
|
|
792
|
+
else:
|
|
793
|
+
arr = np.array(attr_val, dtype=tensor.dtype)
|
|
794
|
+
tensor.to_constant(arr)
|
|
795
|
+
tensor.inputs.clear()
|
|
796
|
+
|
|
797
|
+
# Pass 2: Run shape-tensor cast elision
|
|
798
|
+
def run_cast_elision(node):
|
|
799
|
+
"""Perform cast elision optimization on an ONNX node to eliminate unnecessary cast operations."""
|
|
800
|
+
import onnx
|
|
801
|
+
|
|
802
|
+
# Search for Cast(s) (from int -> float) -> intermediate operator (with float constants) -> Cast(s) (back to int)
|
|
803
|
+
# This pattern is problematic for TensorRT since these operations may be performed on Shape Tensors, which
|
|
804
|
+
# are not allowed to be floating point type. Attempt to fold the pattern here
|
|
805
|
+
VALID_CAST_ELISION_OPS = {
|
|
806
|
+
"Add",
|
|
807
|
+
"Sub",
|
|
808
|
+
"Mul",
|
|
809
|
+
"Div",
|
|
810
|
+
"Max",
|
|
811
|
+
"Min",
|
|
812
|
+
"Equal",
|
|
813
|
+
"Greater",
|
|
814
|
+
"Less",
|
|
815
|
+
"Concat",
|
|
816
|
+
}
|
|
817
|
+
|
|
818
|
+
if node.op not in VALID_CAST_ELISION_OPS:
|
|
819
|
+
return
|
|
820
|
+
|
|
821
|
+
# If the uncasted outputs of this node have any consumers other than "Cast" nodes,
|
|
822
|
+
# then we cannot elide the cast.
|
|
823
|
+
for out_tensor in node.outputs:
|
|
824
|
+
if out_tensor in self.outputs:
|
|
825
|
+
return
|
|
826
|
+
|
|
827
|
+
if any(out_node.op != "Cast" for out_node in out_tensor.outputs):
|
|
828
|
+
return
|
|
829
|
+
|
|
830
|
+
# Get list of input nodes that cast to float32
|
|
831
|
+
inp_casts = [
|
|
832
|
+
inp_node
|
|
833
|
+
for inp_tensor in node.inputs
|
|
834
|
+
for inp_node in inp_tensor.inputs
|
|
835
|
+
if inp_node.op == "Cast" and inp_node.attrs["to"] == onnx.TensorProto.DataType.FLOAT
|
|
836
|
+
]
|
|
837
|
+
|
|
838
|
+
# No cast nodes found, return early
|
|
839
|
+
if not inp_casts:
|
|
840
|
+
return
|
|
841
|
+
|
|
842
|
+
# Ensure that all input cast nodes are casting from the same type
|
|
843
|
+
inp_dtypes = [dtype_to_onnx(inp_cast.inputs[0].dtype) for inp_cast in inp_casts]
|
|
844
|
+
if len(set(inp_dtypes)) != 1:
|
|
845
|
+
return
|
|
846
|
+
|
|
847
|
+
final_type = inp_dtypes[0]
|
|
848
|
+
|
|
849
|
+
# Get list of output nodes that cast to int32 or int64
|
|
850
|
+
out_casts = [
|
|
851
|
+
out_node
|
|
852
|
+
for out_tensor in node.outputs
|
|
853
|
+
for out_node in out_tensor.outputs
|
|
854
|
+
if out_node.op == "Cast"
|
|
855
|
+
and out_node.attrs["to"] in {onnx.TensorProto.DataType.INT32, onnx.TensorProto.DataType.INT64}
|
|
856
|
+
]
|
|
857
|
+
|
|
858
|
+
# No cast node found on outputs, return early
|
|
859
|
+
if not out_casts:
|
|
860
|
+
return
|
|
861
|
+
|
|
862
|
+
# Ensure that all output cast nodes are casting to the same type and that this
|
|
863
|
+
# matches the original type before the inputs were casted.
|
|
864
|
+
out_dtypes = [out_cast.attrs["to"] for out_cast in out_casts]
|
|
865
|
+
if len(set(out_dtypes)) != 1 or out_dtypes[0] != final_type:
|
|
866
|
+
return
|
|
867
|
+
|
|
868
|
+
# If all checks passed, reconnect inputs/outputs to the consumers/producers
|
|
869
|
+
# of the Cast nodes.
|
|
870
|
+
# Note that we need to be careful in how we rebind tensors since they may
|
|
871
|
+
# be used by multiple nodes. Thus, it is not necessarily safe to assume that
|
|
872
|
+
# `cast_node.inputs[0].outputs[0] == cast_node`.
|
|
873
|
+
for index, inp in enumerate(node.inputs):
|
|
874
|
+
if isinstance(inp, Constant):
|
|
875
|
+
inp.values = inp.values.astype(onnx.helper.tensor_dtype_to_np_dtype(final_type))
|
|
876
|
+
|
|
877
|
+
for cast in inp_casts:
|
|
878
|
+
if cast.outputs[0] == inp:
|
|
879
|
+
node.inputs[index] = cast.inputs[0]
|
|
880
|
+
|
|
881
|
+
for index, out in enumerate(node.outputs):
|
|
882
|
+
for cast in out_casts:
|
|
883
|
+
if cast.inputs[0] == out:
|
|
884
|
+
out_tensor = cast.outputs[0]
|
|
885
|
+
out_tensor.inputs.clear() # Disconnect from Cast
|
|
886
|
+
node.outputs[index] = out_tensor
|
|
887
|
+
|
|
888
|
+
if fold_shapes:
|
|
889
|
+
# Perform shape tensor cast elision prior to most other folding
|
|
890
|
+
G_LOGGER.debug(f"Performing shape tensor cast elision in {self.name}")
|
|
891
|
+
try:
|
|
892
|
+
with self.node_ids():
|
|
893
|
+
for node in self.nodes:
|
|
894
|
+
run_cast_elision(node)
|
|
895
|
+
except Exception as err:
|
|
896
|
+
if not error_ok:
|
|
897
|
+
raise err
|
|
898
|
+
G_LOGGER.warning("'{:}' routine failed with: {:}".format("Shape tensor cast elision", err))
|
|
899
|
+
|
|
900
|
+
# Note that most of the remaining passes operate on a clone of the original graph.
|
|
901
|
+
# Pass 3: Find all descendants of constant tensors
|
|
902
|
+
|
|
903
|
+
graph_clone = self.copy()
|
|
904
|
+
clone_tensors = graph_clone.tensors()
|
|
905
|
+
|
|
906
|
+
# If 'self' is a Function, then these fields need to be set so it can be exported as an ONNX Graph.
|
|
907
|
+
graph_clone.producer_name = ""
|
|
908
|
+
graph_clone.producer_version = ""
|
|
909
|
+
|
|
910
|
+
def update_foldable_outputs(graph_constants):
|
|
911
|
+
"""Updates the graph's outputs to ensure certain operations remain foldable."""
|
|
912
|
+
|
|
913
|
+
def is_foldable(node):
|
|
914
|
+
"""Determines if a given node operation is foldable based on its type."""
|
|
915
|
+
NO_FOLD_OPS = {
|
|
916
|
+
"QuantizeLinear",
|
|
917
|
+
"DequantizeLinear",
|
|
918
|
+
"DynamicQuantizeLinear",
|
|
919
|
+
"SequenceEmpty",
|
|
920
|
+
}
|
|
921
|
+
if node.op in NO_FOLD_OPS:
|
|
922
|
+
return False
|
|
923
|
+
|
|
924
|
+
def all_tensors_const(tensors):
|
|
925
|
+
"""Check if all tensors in a given list are constants in the graph."""
|
|
926
|
+
return all(t.name in graph_constants for t in tensors if not t.is_empty())
|
|
927
|
+
|
|
928
|
+
if not all_tensors_const(node.inputs):
|
|
929
|
+
return False
|
|
930
|
+
|
|
931
|
+
all_subgraph_foreign_tensors_const = True
|
|
932
|
+
for subgraph in node.subgraphs():
|
|
933
|
+
foreign_tensors = subgraph._foreign_tensors().values()
|
|
934
|
+
all_subgraph_foreign_tensors_const &= all_tensors_const(foreign_tensors)
|
|
935
|
+
|
|
936
|
+
return all_subgraph_foreign_tensors_const and not should_exclude_node(node)
|
|
937
|
+
|
|
938
|
+
# Walks along the outputs of graph_constants to see if they can also be computed statically.
|
|
939
|
+
# Since the graph is topologically sorted, this should find all constant nodes in the graph.
|
|
940
|
+
for node in graph_clone.nodes:
|
|
941
|
+
if is_foldable(node):
|
|
942
|
+
graph_constants.update({out.name: out for out in node.outputs})
|
|
943
|
+
return graph_constants
|
|
944
|
+
|
|
945
|
+
graph_constants = {}
|
|
946
|
+
for name, tensor in clone_tensors.items():
|
|
947
|
+
if isinstance(tensor, Constant):
|
|
948
|
+
if any((t.op == "Gather" and t.inputs.index(tensor) == 0) for t in tensor.outputs):
|
|
949
|
+
if len(tensor.outputs) <= 1:
|
|
950
|
+
graph_constants[name] = tensor
|
|
951
|
+
else:
|
|
952
|
+
graph_constants[name] = tensor
|
|
953
|
+
|
|
954
|
+
graph_constants = update_foldable_outputs(graph_constants)
|
|
955
|
+
|
|
956
|
+
# Pass 4: Shape Folding
|
|
957
|
+
|
|
958
|
+
def get_producer(tensor, op):
|
|
959
|
+
"""Get the producer of the specified tensor iff it matches op."""
|
|
960
|
+
if len(tensor.inputs) != 1:
|
|
961
|
+
return None
|
|
962
|
+
|
|
963
|
+
node = tensor.inputs[0]
|
|
964
|
+
return None if node.op != op else node
|
|
965
|
+
|
|
966
|
+
def get_input(node, index=0):
|
|
967
|
+
"""Get the input tensor of a node iff the input tensor is not already marked a graph constant."""
|
|
968
|
+
if node is None:
|
|
969
|
+
return None
|
|
970
|
+
|
|
971
|
+
inp = node.inputs[index]
|
|
972
|
+
|
|
973
|
+
# If the input was already found to be a constant, it will be folded anyway.
|
|
974
|
+
return None if inp.name in graph_constants else inp
|
|
975
|
+
|
|
976
|
+
def get_scalar_value(tensor):
|
|
977
|
+
"""Gets the scalar value of a constant tensor with a single item."""
|
|
978
|
+
return next(iter(tensor.values)) if tensor.shape else tensor.values
|
|
979
|
+
|
|
980
|
+
def fold_shape(tensor):
|
|
981
|
+
"""Returns the input tensor shape if available, otherwise returns None.
|
|
982
|
+
Handles Shape node with optional 'start' and 'end' attributes (opset 15+).
|
|
983
|
+
"""
|
|
984
|
+
shape_node = get_producer(tensor, "Shape")
|
|
985
|
+
inp = get_input(shape_node)
|
|
986
|
+
if inp is None:
|
|
987
|
+
return None
|
|
988
|
+
|
|
989
|
+
if inp.shape is None or misc.is_dynamic_shape(inp.shape):
|
|
990
|
+
return None
|
|
991
|
+
|
|
992
|
+
full_shape = inp.shape
|
|
993
|
+
num_dims = len(full_shape)
|
|
994
|
+
|
|
995
|
+
# Get start and end attributes (default: start=0, end=None means full shape)
|
|
996
|
+
start = shape_node.attrs.get("start", 0)
|
|
997
|
+
end = shape_node.attrs.get("end", None)
|
|
998
|
+
|
|
999
|
+
# Handle negative indices
|
|
1000
|
+
if start < 0:
|
|
1001
|
+
start = num_dims + start
|
|
1002
|
+
if end is None:
|
|
1003
|
+
end = num_dims
|
|
1004
|
+
elif end < 0:
|
|
1005
|
+
end = num_dims + end
|
|
1006
|
+
|
|
1007
|
+
# Clamp to valid range
|
|
1008
|
+
start = max(0, min(start, num_dims))
|
|
1009
|
+
end = max(0, min(end, num_dims))
|
|
1010
|
+
|
|
1011
|
+
if start > end:
|
|
1012
|
+
return None
|
|
1013
|
+
|
|
1014
|
+
target_shape = full_shape[start:end]
|
|
1015
|
+
return np.array(target_shape, dtype=np.int64)
|
|
1016
|
+
|
|
1017
|
+
def fold_shape_gather(tensor):
|
|
1018
|
+
"""Retrieves and returns the shape of the input tensor as a NumPy array, otherwise returns None.
|
|
1019
|
+
Handles Shape node with optional 'start' and 'end' attributes (opset 15+).
|
|
1020
|
+
"""
|
|
1021
|
+
gather = get_producer(tensor, "Gather")
|
|
1022
|
+
if gather is None:
|
|
1023
|
+
return None
|
|
1024
|
+
|
|
1025
|
+
data = gather.inputs[0]
|
|
1026
|
+
indices_tensor = gather.inputs[1]
|
|
1027
|
+
|
|
1028
|
+
shape_node = get_producer(data, "Shape")
|
|
1029
|
+
inp = get_input(shape_node)
|
|
1030
|
+
if inp is None or inp.shape is None:
|
|
1031
|
+
return None
|
|
1032
|
+
|
|
1033
|
+
if not isinstance(indices_tensor, Constant):
|
|
1034
|
+
return None
|
|
1035
|
+
|
|
1036
|
+
# Get the shape slice from Shape node (considering start/end attributes)
|
|
1037
|
+
full_shape = inp.shape
|
|
1038
|
+
num_dims = len(full_shape)
|
|
1039
|
+
|
|
1040
|
+
start = shape_node.attrs.get("start", 0)
|
|
1041
|
+
end = shape_node.attrs.get("end", None)
|
|
1042
|
+
|
|
1043
|
+
if start < 0:
|
|
1044
|
+
start = num_dims + start
|
|
1045
|
+
if end is None:
|
|
1046
|
+
end = num_dims
|
|
1047
|
+
elif end < 0:
|
|
1048
|
+
end = num_dims + end
|
|
1049
|
+
|
|
1050
|
+
start = max(0, min(start, num_dims))
|
|
1051
|
+
end = max(0, min(end, num_dims))
|
|
1052
|
+
|
|
1053
|
+
if start > end:
|
|
1054
|
+
return None
|
|
1055
|
+
|
|
1056
|
+
shape_slice = full_shape[start:end]
|
|
1057
|
+
|
|
1058
|
+
indices = indices_tensor.values
|
|
1059
|
+
if not indices.shape: # Scalar-case
|
|
1060
|
+
idx = int(indices)
|
|
1061
|
+
# Handle negative indices relative to shape_slice
|
|
1062
|
+
if idx < 0:
|
|
1063
|
+
idx = len(shape_slice) + idx
|
|
1064
|
+
if idx < 0 or idx >= len(shape_slice):
|
|
1065
|
+
return None
|
|
1066
|
+
shape = shape_slice[idx]
|
|
1067
|
+
if misc.is_dynamic_dimension(shape):
|
|
1068
|
+
return None
|
|
1069
|
+
else:
|
|
1070
|
+
shape = []
|
|
1071
|
+
for index in indices:
|
|
1072
|
+
idx = int(index)
|
|
1073
|
+
# Handle negative indices relative to shape_slice
|
|
1074
|
+
if idx < 0:
|
|
1075
|
+
idx = len(shape_slice) + idx
|
|
1076
|
+
if idx < 0 or idx >= len(shape_slice):
|
|
1077
|
+
return None
|
|
1078
|
+
shape.append(shape_slice[idx])
|
|
1079
|
+
if misc.is_dynamic_shape(shape):
|
|
1080
|
+
return None
|
|
1081
|
+
|
|
1082
|
+
return np.array(shape, dtype=np.int64)
|
|
1083
|
+
|
|
1084
|
+
def fold_shape_slice(tensor):
|
|
1085
|
+
"""Fold tensor shape slice information into a NumPy array of int64 type.
|
|
1086
|
+
Handles Shape node with optional 'start' and 'end' attributes (opset 15+).
|
|
1087
|
+
"""
|
|
1088
|
+
slice_node = get_producer(tensor, "Slice")
|
|
1089
|
+
if slice_node is None:
|
|
1090
|
+
return None
|
|
1091
|
+
|
|
1092
|
+
data = slice_node.inputs[0]
|
|
1093
|
+
|
|
1094
|
+
if len(slice_node.inputs) >= 3:
|
|
1095
|
+
starts, ends = slice_node.inputs[1:3]
|
|
1096
|
+
if any(not isinstance(t, Constant) for t in [starts, ends]):
|
|
1097
|
+
return None
|
|
1098
|
+
starts, ends = get_scalar_value(starts), get_scalar_value(ends)
|
|
1099
|
+
elif "starts" in slice_node.attrs and "ends" in slice_node.attrs:
|
|
1100
|
+
starts, ends = slice_node.attrs["starts"][0], slice_node.attrs["ends"][0]
|
|
1101
|
+
else:
|
|
1102
|
+
return None
|
|
1103
|
+
|
|
1104
|
+
shape_node = get_producer(data, "Shape")
|
|
1105
|
+
inp = get_input(shape_node)
|
|
1106
|
+
if inp is None or inp.shape is None:
|
|
1107
|
+
return None
|
|
1108
|
+
|
|
1109
|
+
# For shape tensors, we can only slice on the 0th dimension.
|
|
1110
|
+
if len(slice_node.inputs) > 3:
|
|
1111
|
+
axes = slice_node.inputs[3]
|
|
1112
|
+
if not isinstance(axes, Constant):
|
|
1113
|
+
return None
|
|
1114
|
+
|
|
1115
|
+
if get_scalar_value(axes) != 0:
|
|
1116
|
+
return None
|
|
1117
|
+
elif "axes" in slice_node.attrs:
|
|
1118
|
+
if slice_node.attrs["axes"][0] != 0:
|
|
1119
|
+
return None
|
|
1120
|
+
|
|
1121
|
+
steps = 1
|
|
1122
|
+
if len(slice_node.inputs) > 4:
|
|
1123
|
+
steps = slice_node.inputs[4]
|
|
1124
|
+
if not isinstance(steps, Constant):
|
|
1125
|
+
return None
|
|
1126
|
+
steps = get_scalar_value(steps)
|
|
1127
|
+
elif "steps" in slice_node.attrs:
|
|
1128
|
+
steps = slice_node.attrs["steps"][0]
|
|
1129
|
+
|
|
1130
|
+
# Get the shape slice from Shape node (considering start/end attributes)
|
|
1131
|
+
full_shape = inp.shape
|
|
1132
|
+
num_dims = len(full_shape)
|
|
1133
|
+
|
|
1134
|
+
shape_start = shape_node.attrs.get("start", 0)
|
|
1135
|
+
shape_end = shape_node.attrs.get("end", None)
|
|
1136
|
+
|
|
1137
|
+
if shape_start < 0:
|
|
1138
|
+
shape_start = num_dims + shape_start
|
|
1139
|
+
if shape_end is None:
|
|
1140
|
+
shape_end = num_dims
|
|
1141
|
+
elif shape_end < 0:
|
|
1142
|
+
shape_end = num_dims + shape_end
|
|
1143
|
+
|
|
1144
|
+
shape_start = max(0, min(shape_start, num_dims))
|
|
1145
|
+
shape_end = max(0, min(shape_end, num_dims))
|
|
1146
|
+
|
|
1147
|
+
if shape_start > shape_end:
|
|
1148
|
+
return None
|
|
1149
|
+
|
|
1150
|
+
shape_slice = full_shape[shape_start:shape_end]
|
|
1151
|
+
|
|
1152
|
+
# Apply the Slice operation on the shape_slice
|
|
1153
|
+
shape = shape_slice[starts:ends:steps]
|
|
1154
|
+
if misc.is_dynamic_shape(shape):
|
|
1155
|
+
return None
|
|
1156
|
+
|
|
1157
|
+
return np.array(shape, dtype=np.int64)
|
|
1158
|
+
|
|
1159
|
+
if fold_shapes:
|
|
1160
|
+
# NOTE: The order of shape folding passes is important to maximize how much we fold (phase-ordering problem).
|
|
1161
|
+
SHAPE_FOLD_FUNCS = {fold_shape_gather, fold_shape_slice, fold_shape}
|
|
1162
|
+
for shape_fold_func in SHAPE_FOLD_FUNCS:
|
|
1163
|
+
try:
|
|
1164
|
+
for tensor in clone_tensors.values():
|
|
1165
|
+
shape_of = shape_fold_func(tensor)
|
|
1166
|
+
|
|
1167
|
+
if shape_of is not None:
|
|
1168
|
+
G_LOGGER.ultra_verbose(f"Folding shape tensor: {tensor.name} to: {shape_of}")
|
|
1169
|
+
graph_constants[tensor.name] = tensor.to_constant(shape_of)
|
|
1170
|
+
graph_constants[tensor.name].inputs.clear()
|
|
1171
|
+
except Exception as err:
|
|
1172
|
+
if not error_ok:
|
|
1173
|
+
raise err
|
|
1174
|
+
G_LOGGER.warning(f"'{shape_fold_func.__name__}' routine failed with:\n{err}")
|
|
1175
|
+
else:
|
|
1176
|
+
graph_constants = update_foldable_outputs(graph_constants)
|
|
1177
|
+
|
|
1178
|
+
# Pass 5: Evaluate all tensors descended from constants with ONNX-Runtime and replace them with constant values.
|
|
1179
|
+
|
|
1180
|
+
def partition_and_infer(subgraph):
|
|
1181
|
+
"""Evaluates and partitions the subgraph to infer constant values using ONNX-Runtime."""
|
|
1182
|
+
|
|
1183
|
+
def get_out_node_ids():
|
|
1184
|
+
"""Gets the final output nodes, identifying producer nodes of graph output tensors with no other
|
|
1185
|
+
outputs.
|
|
1186
|
+
"""
|
|
1187
|
+
with subgraph.node_ids():
|
|
1188
|
+
out_node_ids = set()
|
|
1189
|
+
for out in subgraph.outputs:
|
|
1190
|
+
if not out.outputs and not isinstance(out, Constant):
|
|
1191
|
+
for n_inp in out.inputs:
|
|
1192
|
+
out_node_ids.add(subgraph._get_node_id(n_inp))
|
|
1193
|
+
return out_node_ids
|
|
1194
|
+
|
|
1195
|
+
# Compute each output node in a separate subgraph.
|
|
1196
|
+
out_node_ids = get_out_node_ids()
|
|
1197
|
+
constant_values = {}
|
|
1198
|
+
|
|
1199
|
+
for index in out_node_ids: # Have to use index since 'node' is not in part
|
|
1200
|
+
part = subgraph.copy()
|
|
1201
|
+
out_node = part.nodes[index]
|
|
1202
|
+
part.outputs = out_node.outputs
|
|
1203
|
+
part.name = f"Folding: {[out.name for out in part.outputs]}"
|
|
1204
|
+
part.cleanup(remove_unused_graph_inputs=True)
|
|
1205
|
+
names = [out.name for out in part.outputs]
|
|
1206
|
+
|
|
1207
|
+
try:
|
|
1208
|
+
# Determining types is not trivial, and ONNX-RT does its own type inference.
|
|
1209
|
+
import onnxruntime as onnxrt
|
|
1210
|
+
|
|
1211
|
+
sess = onnxrt.InferenceSession(
|
|
1212
|
+
export_onnx(part, do_type_check=False).SerializeToString(),
|
|
1213
|
+
providers=ORT_PROVIDERS,
|
|
1214
|
+
)
|
|
1215
|
+
values = sess.run(names, {})
|
|
1216
|
+
except Exception as err:
|
|
1217
|
+
G_LOGGER.warning(f"Inference failed for subgraph: {part.name}. Note: Error was:\n{err}")
|
|
1218
|
+
if partitioning == "recursive":
|
|
1219
|
+
G_LOGGER.verbose("Attempting to recursively partition subgraph")
|
|
1220
|
+
# Partition failed, peel off last node.
|
|
1221
|
+
# We only need to remove one node, so avoid doing an expensive call to cleanup()
|
|
1222
|
+
part.outputs = out_node.inputs
|
|
1223
|
+
del part.nodes[part.nodes.index(out_node)]
|
|
1224
|
+
out_node.outputs.clear()
|
|
1225
|
+
out_node.inputs.clear()
|
|
1226
|
+
else:
|
|
1227
|
+
G_LOGGER.info("You may see better results if you set partitioning='recursive'")
|
|
1228
|
+
if not error_ok:
|
|
1229
|
+
raise err
|
|
1230
|
+
|
|
1231
|
+
constant_values.update(partition_and_infer(part))
|
|
1232
|
+
else:
|
|
1233
|
+
constant_values.update(dict(zip(names, values)))
|
|
1234
|
+
|
|
1235
|
+
return constant_values
|
|
1236
|
+
|
|
1237
|
+
# Only evaluate foldable values that have non-foldable outputs or are graph outputs.
|
|
1238
|
+
# Otherwise, if all the outputs are foldable, then we can just evaluate the outputs directly.
|
|
1239
|
+
# Additionally, if we can determine tensor size, do not evaluate tensors whose sizes exceed the size threshold.
|
|
1240
|
+
def should_eval_foldable(tensor):
|
|
1241
|
+
"""Determine if foldable values should be evaluated based on output nature and tensor size constraints."""
|
|
1242
|
+
from onnxslim.third_party.onnx_graphsurgeon.importers.onnx_importer import (
|
|
1243
|
+
get_itemsize,
|
|
1244
|
+
)
|
|
1245
|
+
|
|
1246
|
+
non_const = not isinstance(tensor, Constant)
|
|
1247
|
+
is_graph_output = not tensor.outputs
|
|
1248
|
+
has_non_foldable_outputs = any(out.name not in graph_constants for out in tensor.outputs)
|
|
1249
|
+
exceeds_size_threshold = (
|
|
1250
|
+
tensor.shape is not None
|
|
1251
|
+
and not misc.is_dynamic_shape(tensor.shape)
|
|
1252
|
+
and tensor.dtype is not None
|
|
1253
|
+
and size_threshold is not None
|
|
1254
|
+
) and (misc.volume(tensor.shape) * get_itemsize(tensor.dtype) > size_threshold)
|
|
1255
|
+
|
|
1256
|
+
return non_const and (is_graph_output or has_non_foldable_outputs) and not exceeds_size_threshold
|
|
1257
|
+
|
|
1258
|
+
graph_clone.outputs = [t for t in graph_constants.values() if should_eval_foldable(t)]
|
|
1259
|
+
G_LOGGER.debug(f"Folding tensors: {graph_clone.outputs}")
|
|
1260
|
+
graph_clone.cleanup(remove_unused_graph_inputs=True, recurse_functions=False)
|
|
1261
|
+
|
|
1262
|
+
# Using ._values avoids a deep copy of the values.
|
|
1263
|
+
constant_values = {
|
|
1264
|
+
name: tensor._values for name, tensor in graph_constants.items() if isinstance(tensor, Constant)
|
|
1265
|
+
}
|
|
1266
|
+
if graph_clone.outputs:
|
|
1267
|
+
if partitioning:
|
|
1268
|
+
constant_values.update(partition_and_infer(graph_clone))
|
|
1269
|
+
else:
|
|
1270
|
+
names = [t.name for t in graph_clone.outputs]
|
|
1271
|
+
try:
|
|
1272
|
+
import os
|
|
1273
|
+
import tempfile
|
|
1274
|
+
|
|
1275
|
+
import onnx
|
|
1276
|
+
import onnxruntime as onnxrt
|
|
1277
|
+
|
|
1278
|
+
onnx_model = export_onnx(graph_clone, do_type_check=False)
|
|
1279
|
+
if onnx_model.ByteSize() >= onnx.checker.MAXIMUM_PROTOBUF:
|
|
1280
|
+
tmp_dir = tempfile.TemporaryDirectory()
|
|
1281
|
+
tmp_path = os.path.join(tmp_dir.name, "tmp.onnx")
|
|
1282
|
+
location = f"{os.path.basename(tmp_path)}.data"
|
|
1283
|
+
if os.path.exists(location):
|
|
1284
|
+
os.remove(location)
|
|
1285
|
+
onnx.save(
|
|
1286
|
+
onnx_model,
|
|
1287
|
+
tmp_path,
|
|
1288
|
+
save_as_external_data=True,
|
|
1289
|
+
all_tensors_to_one_file=True,
|
|
1290
|
+
location=location,
|
|
1291
|
+
)
|
|
1292
|
+
onnx_model = tmp_path
|
|
1293
|
+
else:
|
|
1294
|
+
onnx_model = onnx_model.SerializeToString()
|
|
1295
|
+
sess = onnxrt.InferenceSession(
|
|
1296
|
+
onnx_model,
|
|
1297
|
+
providers=ORT_PROVIDERS,
|
|
1298
|
+
)
|
|
1299
|
+
values = sess.run(names, {})
|
|
1300
|
+
constant_values.update(dict(zip(names, values)))
|
|
1301
|
+
except Exception as err:
|
|
1302
|
+
G_LOGGER.warning(
|
|
1303
|
+
"Inference failed. You may want to try enabling partitioning to see better results. "
|
|
1304
|
+
f"Note: Error was:\n{err}"
|
|
1305
|
+
)
|
|
1306
|
+
G_LOGGER.verbose(f"Note: Graph was:\n{graph_clone}")
|
|
1307
|
+
if not error_ok:
|
|
1308
|
+
raise
|
|
1309
|
+
elif not constant_values:
|
|
1310
|
+
G_LOGGER.debug(
|
|
1311
|
+
f"Could not find any nodes in this graph ({self.name}) that can be folded. "
|
|
1312
|
+
"This could mean that constant folding has already been run on this graph. "
|
|
1313
|
+
"Skipping."
|
|
1314
|
+
)
|
|
1315
|
+
|
|
1316
|
+
# Finally, replace the Variables in the original graph with constants.
|
|
1317
|
+
large_tensors = {}
|
|
1318
|
+
if constant_values:
|
|
1319
|
+
graph_tensors = self.tensors()
|
|
1320
|
+
for name, values in constant_values.items():
|
|
1321
|
+
tensor = graph_tensors[name]
|
|
1322
|
+
if isinstance(tensor, Constant) or not tensor.outputs:
|
|
1323
|
+
# No need to fold tensors that are already constant.
|
|
1324
|
+
continue
|
|
1325
|
+
|
|
1326
|
+
if size_threshold is not None and values.nbytes > size_threshold:
|
|
1327
|
+
G_LOGGER.debug(
|
|
1328
|
+
f"Will not fold: '{name}' since its size in bytes ({values.nbytes}) exceeds the size threshold ({size_threshold})"
|
|
1329
|
+
)
|
|
1330
|
+
continue
|
|
1331
|
+
elif size_threshold is None and values.nbytes > (1 << 20):
|
|
1332
|
+
large_tensors[name] = values.nbytes
|
|
1333
|
+
|
|
1334
|
+
tensor.to_constant(values)
|
|
1335
|
+
tensor.inputs.clear() # Constants do not need inputs
|
|
1336
|
+
|
|
1337
|
+
if large_tensors:
|
|
1338
|
+
large_tensors_mib = {
|
|
1339
|
+
tensor_name: f"{value // (1 << 20)} MiB" for tensor_name, value in large_tensors.items()
|
|
1340
|
+
}
|
|
1341
|
+
G_LOGGER.warning(
|
|
1342
|
+
"It looks like this model contains foldable nodes that produce large outputs.\n"
|
|
1343
|
+
"In order to avoid bloating the model, you may want to set a constant-folding size threshold.\n"
|
|
1344
|
+
f"Note: Large tensors and their corresponding sizes were: {large_tensors_mib}",
|
|
1345
|
+
mode=LogMode.ONCE,
|
|
1346
|
+
)
|
|
1347
|
+
|
|
1348
|
+
# Folding subgraphs after the outer graph can lead to better folding.
|
|
1349
|
+
def fold_subgraphs():
|
|
1350
|
+
"""Folds constants within subgraphs of the outer computational graph for optimization."""
|
|
1351
|
+
for subgraph in self.subgraphs():
|
|
1352
|
+
subgraph.fold_constants(
|
|
1353
|
+
fold_shapes=fold_shapes,
|
|
1354
|
+
recurse_subgraphs=recurse_subgraphs,
|
|
1355
|
+
partitioning=partitioning,
|
|
1356
|
+
error_ok=error_ok,
|
|
1357
|
+
flatten_subgraphs=flatten_subgraphs,
|
|
1358
|
+
size_threshold=size_threshold,
|
|
1359
|
+
recurse_functions=False, # Functions are folded later
|
|
1360
|
+
)
|
|
1361
|
+
|
|
1362
|
+
if recurse_subgraphs:
|
|
1363
|
+
fold_subgraphs()
|
|
1364
|
+
|
|
1365
|
+
if flatten_subgraphs:
|
|
1366
|
+
# Flatten conditional subgraphs
|
|
1367
|
+
index = 0
|
|
1368
|
+
while index < len(self.nodes):
|
|
1369
|
+
node = self.nodes[index]
|
|
1370
|
+
if node.op == "If" and isinstance(node.inputs[0], Constant):
|
|
1371
|
+
G_LOGGER.debug(f"Flattening conditional: {node.name}")
|
|
1372
|
+
cond = get_scalar_value(node.inputs[0])
|
|
1373
|
+
subgraph = node.attrs["then_branch"] if cond else node.attrs["else_branch"]
|
|
1374
|
+
# Need to add a suffix to subgraph tensors so they don't collide with outer graph tensors
|
|
1375
|
+
for tensor in subgraph._local_tensors().values():
|
|
1376
|
+
tensor.name += f"_subg_{index}_{subgraph.name}"
|
|
1377
|
+
|
|
1378
|
+
# The subgraph outputs correspond to the If node outputs. Only the latter are visible
|
|
1379
|
+
# in the parent graph, so we rebind the producer nodes of the subgraph outputs to point
|
|
1380
|
+
# to the output tensors of the If instead.
|
|
1381
|
+
node_outputs = list(node.outputs)
|
|
1382
|
+
for node_out, subgraph_out in zip(node_outputs, subgraph.outputs):
|
|
1383
|
+
node_out.inputs.clear()
|
|
1384
|
+
for producer in subgraph_out.inputs:
|
|
1385
|
+
for tensor_idx, out_tensor in enumerate(producer.outputs):
|
|
1386
|
+
if out_tensor == subgraph_out:
|
|
1387
|
+
producer.outputs[tensor_idx] = node_out
|
|
1388
|
+
|
|
1389
|
+
# Copy subgraph nodes into parent graph at the index of the If.
|
|
1390
|
+
del self.nodes[index]
|
|
1391
|
+
self.nodes[index:index] = subgraph.nodes
|
|
1392
|
+
index += len(subgraph.nodes) - 1
|
|
1393
|
+
|
|
1394
|
+
index += 1
|
|
1395
|
+
|
|
1396
|
+
if recurse_functions:
|
|
1397
|
+
# Nodes which are constant-folded but not cleaned up can result in errors during inference,
|
|
1398
|
+
# so process functions in reverse topological order.
|
|
1399
|
+
for func in reversed(self.functions):
|
|
1400
|
+
func.fold_constants(
|
|
1401
|
+
fold_shapes=fold_shapes,
|
|
1402
|
+
recurse_subgraphs=recurse_subgraphs,
|
|
1403
|
+
partitioning=partitioning,
|
|
1404
|
+
error_ok=error_ok,
|
|
1405
|
+
flatten_subgraphs=flatten_subgraphs,
|
|
1406
|
+
size_threshold=size_threshold,
|
|
1407
|
+
should_exclude_node=should_exclude_node,
|
|
1408
|
+
recurse_functions=False, # No infinite recursion
|
|
1409
|
+
)
|
|
1410
|
+
|
|
1411
|
+
return self
|
|
1412
|
+
|
|
1413
|
+
def _generate_name(self, prefix: str, existing_names: set):
    """Produce a fresh name of the form ``<prefix>_<index>`` that does not appear in ``existing_names``.

    The running counter (``self.name_idx``) is advanced once per candidate tried,
    so repeated calls keep yielding new suffixes even for the same prefix.
    """
    candidate = f"{prefix}_{self.name_idx}"
    self.name_idx += 1
    while candidate in existing_names:
        # Collision with an existing name: advance the counter and retry.
        candidate = f"{prefix}_{self.name_idx}"
        self.name_idx += 1
    return candidate
|
|
1424
|
+
|
|
1425
|
+
def layer(self, inputs=None, outputs=None, *args, **kwargs):
    """
    Create a node, append it to this graph, and build its input/output tensors as needed.

    Elements of the input and output lists may be of several types:

    - ``Tensor``:
        Used as-is in the node's inputs/outputs; the caller must ensure the
        provided Tensors have unique names.
    - ``str``:
        A new ``Variable`` is created, named by appending an index to the string
        so the generated name is unique.
    - ``numpy.ndarray``:
        A ``Constant`` is created with the name prefix "onnx_graphsurgeon_constant"
        plus a uniquifying index.
    - ``Union[List[Number], Tuple[Number]]`` (or a bare number):
        A 1D ``Constant`` is created with the name prefix
        "onnx_graphsurgeon_lst_constant" plus a uniquifying index. The dtype is
        ``np.float32`` if any value is a float, else ``np.int64``.

    Args:
        inputs (List[Union[Tensor, str, numpy.ndarray]]): The list of inputs
        outputs (List[Union[Tensor, str, numpy.ndarray]]): The list of outputs
        args/kwargs: Passed directly to the ``Node`` constructor

    Returns:
        List[Tensor]: The output tensors of the node
    """
    inputs = misc.default_value(inputs, [])
    outputs = misc.default_value(outputs, [])

    def process_io(io, existing_names):
        """Convert each element of ``io`` to a Tensor/Variable/Constant, generating unique names."""
        processed = []
        for item in io:
            if isinstance(item, Tensor):
                # Already a tensor — adopt it verbatim.
                processed.append(item)
            elif isinstance(item, str):
                # A string becomes a Variable named by uniquifying the string.
                processed.append(Variable(name=self._generate_name(item, existing_names)))
            elif isinstance(item, np.ndarray):
                processed.append(
                    Constant(
                        name=self._generate_name("onnx_graphsurgeon_constant", existing_names),
                        values=item,
                    )
                )
            elif isinstance(item, (list, tuple, numbers.Number)):
                # Promote to float32 if any value is a float; otherwise int64.
                if isinstance(item, (list, tuple)):
                    dtype = np.float32 if any(isinstance(x, float) for x in item) else np.int64
                else:
                    dtype = np.float32 if isinstance(item, float) else np.int64
                processed.append(
                    Constant(
                        name=self._generate_name("onnx_graphsurgeon_lst_constant", existing_names),
                        values=np.array(item, dtype=dtype),
                    )
                )
            else:
                G_LOGGER.critical(
                    f"Unrecognized type passed to Graph.layer: {item}.\n"
                    "\tHint: Did you forget to unpack a list with `*`?\n"
                    "\tPlease use Tensors, strings, or NumPy arrays."
                )
            if processed[-1].name:
                # Reserve the name so later elements (and the other io list) cannot collide.
                existing_names.add(processed[-1].name)
        return processed

    existing_names = set(self.tensors().keys())  # set for fast lookup
    inputs = process_io(inputs, existing_names)
    outputs = process_io(outputs, existing_names)

    if "name" not in kwargs:
        # Auto-name the node, avoiding clashes with existing node names.
        kwargs["name"] = self._generate_name("onnx_graphsurgeon_node", {node.name for node in self.nodes})

    node = Node(*args, **kwargs, inputs=inputs, outputs=outputs)
    self.nodes.append(node)
    return node.outputs
|
|
1503
|
+
|
|
1504
|
+
def copy(self, tensor_map: OrderedDict[str, Tensor] | None = None):
    """
    Copy the graph.

    All nodes and tensors are copied, but weights and attributes are not
    deep-copied (``Graph``-valued attributes are the exception — they are
    copied via their own ``copy`` method).

    Args:
        tensor_map (OrderedDict[str, Tensor]):
            A mapping of tensor names to tensors from the outer graph.
            This should be ``None`` if this is the outer-most graph.

    Returns:
        Graph: A copy of the graph.
    """
    # Reconstruct every tensor first, detached from any inputs/outputs.
    tensor_map = copy.copy(misc.default_value(tensor_map, {}))

    # When cloning a subgraph on its own, `tensors()` is needed to pick up all
    # required tensors — including those produced by outer graphs.
    local_tensor_copies = {name: tensor.copy() for name, tensor in self.tensors().items()}
    # Copies already made by the outer graph take priority over those...
    local_tensor_copies.update(tensor_map)
    # ...and locally produced tensors take precedence over everything else.
    for name, tensor in self._local_tensors().items():
        local_tensor_copies[name] = tensor.copy()

    def get_tensor(name):
        """Look up a copied tensor by name; an empty name yields an empty Variable."""
        if not name:
            return Variable.empty()
        return local_tensor_copies[name]

    # Copy the nodes, rebinding their inputs/outputs to the copied tensors.
    new_nodes = [
        node.copy(
            inputs=[get_tensor(inp.name) for inp in node.inputs],
            outputs=[get_tensor(out.name) for out in node.outputs],
            tensor_map=local_tensor_copies,
        )
        for node in self.nodes
    ]

    new_graph_inputs = [get_tensor(inp.name) for inp in self.inputs]
    new_graph_outputs = [get_tensor(out.name) for out in self.outputs]
    return Graph(
        nodes=new_nodes,
        inputs=new_graph_inputs,
        outputs=new_graph_outputs,
        name=copy.copy(self.name),
        doc_string=copy.copy(self.doc_string),
        opset=copy.copy(self.opset),
        import_domains=self.import_domains,
        ir_version=self.ir_version,
        functions=copy.copy(self.functions),
    )
|
|
1559
|
+
|
|
1560
|
+
def __str__(self):
    """Human-readable summary of the graph: name, opset, local functions, inputs, nodes, and outputs."""
    function_names = ",".join(str(func.name) for func in self.functions)
    node_lines = "\n".join(str(node) for node in self.nodes)
    return (
        f"Graph {self.name} (Opset {self.opset})"
        f"\nLocal Functions: [{function_names}]"
        f"\nInputs: {self.inputs}"
        f"\nNodes: {node_lines}"
        f"\nOutputs: {self.outputs}"
    )
|
|
1572
|
+
|
|
1573
|
+
def __repr__(self):
    """Return the same human-readable representation as ``__str__`` (name, opset, functions, I/O, nodes)."""
    return self.__str__()
|