onnxslim 0.1.80__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- onnxslim/__init__.py +16 -0
- onnxslim/__main__.py +4 -0
- onnxslim/argparser.py +215 -0
- onnxslim/cli/__init__.py +1 -0
- onnxslim/cli/_main.py +180 -0
- onnxslim/core/__init__.py +219 -0
- onnxslim/core/optimization/__init__.py +146 -0
- onnxslim/core/optimization/dead_node_elimination.py +151 -0
- onnxslim/core/optimization/subexpression_elimination.py +76 -0
- onnxslim/core/optimization/weight_tying.py +59 -0
- onnxslim/core/pattern/__init__.py +249 -0
- onnxslim/core/pattern/elimination/__init__.py +5 -0
- onnxslim/core/pattern/elimination/concat.py +61 -0
- onnxslim/core/pattern/elimination/reshape.py +77 -0
- onnxslim/core/pattern/elimination/reshape_as.py +64 -0
- onnxslim/core/pattern/elimination/slice.py +108 -0
- onnxslim/core/pattern/elimination/unsqueeze.py +92 -0
- onnxslim/core/pattern/fusion/__init__.py +8 -0
- onnxslim/core/pattern/fusion/concat_reshape.py +50 -0
- onnxslim/core/pattern/fusion/convadd.py +70 -0
- onnxslim/core/pattern/fusion/convbn.py +86 -0
- onnxslim/core/pattern/fusion/convmul.py +69 -0
- onnxslim/core/pattern/fusion/gelu.py +47 -0
- onnxslim/core/pattern/fusion/gemm.py +330 -0
- onnxslim/core/pattern/fusion/padconv.py +89 -0
- onnxslim/core/pattern/fusion/reduce.py +67 -0
- onnxslim/core/pattern/registry.py +28 -0
- onnxslim/misc/__init__.py +0 -0
- onnxslim/misc/tabulate.py +2681 -0
- onnxslim/third_party/__init__.py +0 -0
- onnxslim/third_party/_sympy/__init__.py +0 -0
- onnxslim/third_party/_sympy/functions.py +205 -0
- onnxslim/third_party/_sympy/numbers.py +397 -0
- onnxslim/third_party/_sympy/printers.py +491 -0
- onnxslim/third_party/_sympy/solve.py +172 -0
- onnxslim/third_party/_sympy/symbol.py +102 -0
- onnxslim/third_party/onnx_graphsurgeon/__init__.py +15 -0
- onnxslim/third_party/onnx_graphsurgeon/exporters/__init__.py +1 -0
- onnxslim/third_party/onnx_graphsurgeon/exporters/base_exporter.py +33 -0
- onnxslim/third_party/onnx_graphsurgeon/exporters/onnx_exporter.py +432 -0
- onnxslim/third_party/onnx_graphsurgeon/graph_pattern/__init__.py +4 -0
- onnxslim/third_party/onnx_graphsurgeon/graph_pattern/graph_pattern.py +466 -0
- onnxslim/third_party/onnx_graphsurgeon/importers/__init__.py +1 -0
- onnxslim/third_party/onnx_graphsurgeon/importers/base_importer.py +33 -0
- onnxslim/third_party/onnx_graphsurgeon/importers/onnx_importer.py +558 -0
- onnxslim/third_party/onnx_graphsurgeon/ir/__init__.py +0 -0
- onnxslim/third_party/onnx_graphsurgeon/ir/function.py +274 -0
- onnxslim/third_party/onnx_graphsurgeon/ir/graph.py +1575 -0
- onnxslim/third_party/onnx_graphsurgeon/ir/node.py +266 -0
- onnxslim/third_party/onnx_graphsurgeon/ir/tensor.py +504 -0
- onnxslim/third_party/onnx_graphsurgeon/logger/__init__.py +1 -0
- onnxslim/third_party/onnx_graphsurgeon/logger/logger.py +261 -0
- onnxslim/third_party/onnx_graphsurgeon/util/__init__.py +0 -0
- onnxslim/third_party/onnx_graphsurgeon/util/exception.py +20 -0
- onnxslim/third_party/onnx_graphsurgeon/util/misc.py +252 -0
- onnxslim/third_party/symbolic_shape_infer.py +3273 -0
- onnxslim/utils.py +794 -0
- onnxslim/version.py +1 -0
- onnxslim-0.1.80.dist-info/METADATA +207 -0
- onnxslim-0.1.80.dist-info/RECORD +65 -0
- onnxslim-0.1.80.dist-info/WHEEL +5 -0
- onnxslim-0.1.80.dist-info/entry_points.txt +2 -0
- onnxslim-0.1.80.dist-info/licenses/LICENSE +21 -0
- onnxslim-0.1.80.dist-info/top_level.txt +1 -0
- onnxslim-0.1.80.dist-info/zip-safe +1 -0
|
@@ -0,0 +1,466 @@
|
|
|
1
|
+
#
|
|
2
|
+
# SPDX-FileCopyrightText: Copyright (c) 1993-2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
|
3
|
+
# SPDX-License-Identifier: Apache-2.0
|
|
4
|
+
#
|
|
5
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
6
|
+
# you may not use this file except in compliance with the License.
|
|
7
|
+
# You may obtain a copy of the License at
|
|
8
|
+
#
|
|
9
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
|
10
|
+
#
|
|
11
|
+
# Unless required by applicable law or agreed to in writing, software
|
|
12
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
13
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
14
|
+
# See the License for the specific language governing permissions and
|
|
15
|
+
# limitations under the License.
|
|
16
|
+
#
|
|
17
|
+
from __future__ import annotations
|
|
18
|
+
|
|
19
|
+
from onnxslim.third_party.onnx_graphsurgeon.ir.graph import Constant, Graph, Node
|
|
20
|
+
from onnxslim.third_party.onnx_graphsurgeon.logger import G_LOGGER
|
|
21
|
+
|
|
22
|
+
|
|
23
|
+
class PatternMapping(dict):
    """Mapping from pattern node names to their matched ONNX nodes or nested sub-pattern mappings."""

    def __init__(self, onnx_node=None) -> None:
        """Create a mapping, optionally bound to a single matched ONNX node whose tensors it aliases."""
        super().__init__()
        self.onnx_node = onnx_node

        if onnx_node is None:
            # Composite (sub-pattern) mapping: tensor slots are filled lazily.
            self.inputs = []
            self.outputs = []
        else:
            # Single-node mapping: alias the matched node's own tensor lists.
            self.inputs = onnx_node.inputs
            self.outputs = onnx_node.outputs

        # constant name -> onnx tensor mapping
        self.constants = {}

    def set_input_onnx_tensor(self, onnx_tensor, index):
        """Record `onnx_tensor` as the pattern input at `index`, padding with None; False on a name conflict."""
        while len(self.inputs) <= index:
            self.inputs.append(None)
        occupant = self.inputs[index]
        if occupant is not None and occupant.name != onnx_tensor.name:
            # Slot already claimed by a different onnx tensor.
            return False
        self.inputs[index] = onnx_tensor
        return True

    def set_output_onnx_tensor(self, onnx_tensor, index):
        """Record `onnx_tensor` as the pattern output at `index`, padding with None; False on a name conflict."""
        while len(self.outputs) <= index:
            self.outputs.append(None)
        occupant = self.outputs[index]
        if occupant is not None and occupant.name != onnx_tensor.name:
            # Slot already claimed by a different onnx tensor.
            return False
        self.outputs[index] = onnx_tensor
        return True

    def set_constant_onnx_tensor(self, onnx_tensor, name):
        """Bind `onnx_tensor` to the constant slot `name`; False if bound to a differently-named tensor."""
        existing = self.constants.get(name)
        if existing is not None and existing.name != onnx_tensor.name:
            return False
        self.constants[name] = onnx_tensor
        return True

    def _get_node(self):
        """Return the ONNX node this mapping is bound to (None for composite mappings)."""
        return self.onnx_node

    def get(self, name: str):
        """
        Retrieve a pattern-to-graph mapping given the pattern node name.

        Args:
            name (str): The name of the pattern node. The pattern node can be a single op node or a subpattern.

        Returns:
            PatternMapping for a subpattern node or gs.Node for a single op node.
        """
        entry = self[name]
        if entry.onnx_node is not None:
            return entry.onnx_node
        return entry

    def __str__(self) -> str:
        """Return the matched node's name, or a dict-like rendering of nested mappings."""
        if self.onnx_node is not None:
            return self.onnx_node.name
        body = ", ".join(f"{key}: {value!s}" for key, value in self.items())
        return "{" + body + "}"
|
|
87
|
+
|
|
88
|
+
|
|
89
|
+
class GraphPattern:
    """
    Represent a graph pattern.

    Example:
    ::

        pattern = GraphPattern()
        conv = pattern.add("Conv")
        leaky_relu = pattern.add("LeakyReLU", inputs=[conv], check_func=lambda node: node.attrs["alpha"] < 1.0)
    """

    def __init__(self) -> None:
        """Initializes an empty graph pattern; a pattern is either a single op node or a composite of nodes."""
        # Non-None `op` marks this instance as a single-op pattern node (see _init_single_node).
        self.op = None  # op (str)
        self.check_func = None  # callback function for additional matching rules on a single node
        # pattern node name -> GraphPattern nodes (single or subpattern)
        self.nodes: dict[str, GraphPattern] = {}
        # pattern node name -> input tensor ids
        self.node_inputs: dict[str, list[int]] = {}
        # pattern node name -> output tensor ids
        self.node_outputs: dict[str, list[int]] = {}
        self.num_tensors = 0  # number of all tensors in the pattern
        self.tensor_inputs: dict[int, list[str]] = {}  # tensor id -> producer node name(s)
        self.tensor_outputs: dict[int, list[str]] = {}  # tensor id -> consumer node names
        self.input_tensors: list[int] = []  # a list of input tensor ids of this pattern
        self.output_tensors: list[int] = []  # a list of output tensor ids of this pattern
        # tensor id -> tensor name of constant tensors.
        self.constant_tensors: dict[int, str] = {}

    def _add_tensor(self, input_node=None) -> int:
        """Assigns a unique tensor ID, tracks its producer node if provided, and initializes consumer tracking."""
        tensor_id = self.num_tensors
        self.tensor_inputs[tensor_id] = []
        if input_node is not None:
            self.tensor_inputs[tensor_id].append(input_node)
        self.tensor_outputs[tensor_id] = []

        self.num_tensors += 1
        return tensor_id

    def variable(self) -> int:
        """
        Add a variable tensor without a input node - This tensor will be an input tensor of this graph pattern.

        Return:
            int: the tensor id.
        """
        tensor_id = self._add_tensor()
        self.input_tensors.append(tensor_id)
        return tensor_id

    def constant(self, name=None) -> int:
        """
        Add a constant tensor. If name is not provided, a default name will be assigned.

        Args:
            name(str): the constant tensor name

        Return:
            int: the tensor id.
        """
        tensor_id = self._add_tensor()
        if name is None:
            name = f"unnamed_constant_tensor_{tensor_id}"
        self.constant_tensors[tensor_id] = name
        return tensor_id

    def set_output_tensors(self, output_tensors) -> None:
        """Sets the graph pattern's output tensors based on provided tensor IDs (each must already exist)."""
        for tensor_id in output_tensors:
            assert tensor_id in self.tensor_inputs
        self.output_tensors = output_tensors

    def _init_single_node(self, op, check_func=None) -> None:
        """Initialize attributes for a single operation node and optionally set a validation function."""
        self.op = op
        self.check_func = check_func

    def add(
        self,
        name: str,
        op: GraphPattern | str,
        check_func=None,
        inputs=None,
        num_output_tensors=1,
    ):
        """
        Add an op node or a subpattern node to the current pattern.

        Args:
            name (str): the node name.
            op (Union[GraphPattern, str]): the GraphPattern instance if adding a subpattern node or the op name if adding a single op node.
            check_func (function): the callback function for additional matching rules of an op node if adding a single op node.
            inputs (list): the list of input tensors. If this node is a sub-pattern, the sequence of this list should align with the sequence of the sub-pattern's input tensors.
            num_output_tensors (int): number of output tensors

        Return:
            tuple(int) or int or None: output tensors.
        """
        # Only composite patterns may contain nodes; a single-op pattern cannot.
        assert self.op is None
        assert name not in self.nodes

        if inputs is None:
            inputs = []

        # Wrap a plain op-name string into a single-node GraphPattern.
        if isinstance(op, str):
            op_name = op
            op = GraphPattern()
            op._init_single_node(op_name, check_func)

        self.nodes[name] = op

        self.node_inputs[name] = inputs

        # Allocate one fresh tensor per output, produced by this node.
        self.node_outputs[name] = []
        for _ in range(num_output_tensors):
            self.node_outputs[name].append(self._add_tensor(input_node=name))

        # Register this node as a consumer of each of its input tensors.
        for input in inputs:
            self.tensor_outputs[input].append(name)

        if len(self.node_outputs[name]) == 0:
            return None
        elif len(self.node_outputs[name]) == 1:
            return self.node_outputs[name][0]
        return tuple(self.node_outputs[name])

    def _get_inbound(self, tensor_index):
        """Retrieve the tensor id and first consumer node of the pattern input at `tensor_index`, or (None, None)."""
        if len(self.input_tensors) > tensor_index:
            tensor_id = self.input_tensors[tensor_index]
            if len(self.tensor_outputs[tensor_id]):
                inbound_node = self.tensor_outputs[tensor_id][0]
                return tensor_id, inbound_node
        return None, None

    def _get_outbound(self, tensor_index):
        """Retrieve the tensor id and producer node of the pattern output at `tensor_index`, or (None, None)."""
        if len(self.output_tensors) > tensor_index:
            tensor_id = self.output_tensors[tensor_index]
            if len(self.tensor_inputs[tensor_id]):
                outbound_node = self.tensor_inputs[tensor_id][0]
                return tensor_id, outbound_node
        return None, None

    def _single_node_match(self, onnx_node: Node) -> bool:
        """Match the ONNX node with this single-op pattern node based on op type and optional check_func criteria."""
        assert self.op is not None
        with G_LOGGER.indent():
            if self.op != onnx_node.op:
                G_LOGGER.info(
                    f"No match because: Op did not match. Node op was: {onnx_node.op} but pattern op was: {self.op}."
                )
                return False
            if self.check_func is not None and not self.check_func(onnx_node):
                G_LOGGER.info("No match because: check_func returned false.")
                return False
            G_LOGGER.info(f"Single node is matched: {self.op}, {onnx_node.name}")
        return True

    def _get_tensor_index_for_node(self, node: str, tensor_id: int, is_node_input: bool):
        """Returns the positional index of `tensor_id` among `node`'s inputs or outputs."""
        if is_node_input:
            return self.node_inputs[node].index(tensor_id)
        else:
            return self.node_outputs[node].index(tensor_id)

    def get_inbound_or_outbound_onnx_node(self, mapping: PatternMapping, is_inbound: bool, tensor_index: int):
        """Resolve (recursively, through sub-patterns) the concrete ONNX node at the given boundary tensor index."""
        # A single-op pattern maps directly to exactly one ONNX node.
        if self.op is not None:
            return mapping._get_node()
        if is_inbound:
            inbound_tensor, inbound_node = self._get_inbound(tensor_index)
            if inbound_node is not None:
                return self.nodes[inbound_node].get_inbound_or_outbound_onnx_node(
                    mapping[inbound_node],
                    is_inbound=True,
                    tensor_index=self._get_tensor_index_for_node(inbound_node, inbound_tensor, is_node_input=True),
                )

        else:
            outbound_tensor, outbound_node = self._get_outbound(tensor_index)
            if outbound_node is not None:
                return self.nodes[outbound_node].get_inbound_or_outbound_onnx_node(
                    mapping[outbound_node],
                    is_inbound=False,
                    tensor_index=self._get_tensor_index_for_node(outbound_node, outbound_tensor, is_node_input=False),
                )
        return None

    # Match an onnx node and its subgraph with the current pattern.
    def match(
        self,
        onnx_node: Node,
        from_inbound: bool,
        from_tensor_index: int,
        mapped_onnx_nodes: set,
        onnx_graph_output_tensors: set,
    ):
        """Matches an ONNX node and its subgraph to the current graph pattern; returns a PatternMapping or None."""
        # Each onnx node may participate in at most one mapping per search.
        if onnx_node.id in mapped_onnx_nodes:
            return None
        if self.op is not None:  # is single node
            if not self._single_node_match(onnx_node):
                return None

            mapped_onnx_nodes.add(onnx_node.id)
            return PatternMapping(onnx_node=onnx_node)
        # Composite pattern: pick the pattern node adjacent to the boundary tensor
        # we are entering from, and start the DFS there.
        initial_node = None
        if from_inbound:
            from_tensor, initial_node = self._get_inbound(from_tensor_index)
        else:
            from_tensor, initial_node = self._get_outbound(from_tensor_index)
        assert initial_node is not None

        mapping = PatternMapping()
        match = self._match_node(
            initial_node,
            onnx_node,
            from_tensor,
            mapping,
            mapped_onnx_nodes,
            onnx_graph_output_tensors,
            from_inbound,
        )
        return mapping if match else None

    # Match an onnx node and its subgraph with a starting pattern node(can be a subpattern node or a single node) and its subgraph. This is the actual dfs.
    def _match_node(
        self,
        node_name: str,
        onnx_node: Node,
        from_tensor: int,
        mapping: PatternMapping,
        mapped_onnx_nodes: set,
        onnx_graph_output_tensors: set,
        from_inbound: bool,
    ) -> bool:
        """DFS step: match `onnx_node` against pattern node `node_name`, then expand along its input/output edges."""
        with G_LOGGER.indent():
            G_LOGGER.info(f"Checking node: {onnx_node.name} against pattern node: {node_name}.")
        tensor_index_for_node = self._get_tensor_index_for_node(node_name, from_tensor, is_node_input=from_inbound)
        subgraph_mapping = self.nodes[node_name].match(
            onnx_node,
            from_inbound,
            tensor_index_for_node,
            mapped_onnx_nodes,
            onnx_graph_output_tensors,
        )
        if subgraph_mapping is not None:
            mapping[node_name] = subgraph_mapping
        else:
            return False

        input_onnx_tensors = subgraph_mapping.inputs
        if len(input_onnx_tensors) != len(self.node_inputs[node_name]):
            return False  # Number of node inputs should equal to number of input onnx tensors of the node.
        for node_input_tensor, onnx_tensor in zip(self.node_inputs[node_name], input_onnx_tensors):
            if onnx_tensor is None:
                return False
            # tensor paired up.
            if node_input_tensor in self.input_tensors:
                if not mapping.set_input_onnx_tensor(onnx_tensor, self.input_tensors.index(node_input_tensor)):
                    return False  # this tensor is mapped to another onnx tensor
                continue
            if node_input_tensor in self.constant_tensors:
                if not isinstance(onnx_tensor, Constant):
                    return False  # constant tensor not match
                if not mapping.set_constant_onnx_tensor(onnx_tensor, self.constant_tensors[node_input_tensor]):
                    # this constant tensor is mapped to another onnx tensor
                    return False
                continue
            # Interior tensor: its producers must correspond one-to-one with the onnx tensor's producers.
            if len(self.tensor_inputs[node_input_tensor]) != len(onnx_tensor.inputs):
                return False
            for input_node, input_onnx_node in zip(self.tensor_inputs[node_input_tensor], onnx_tensor.inputs):
                # dfs ends when revisiting a node. We need to check if the edges are matched.
                if input_node in mapping:
                    outbound_tensor_index = self._get_tensor_index_for_node(
                        input_node, node_input_tensor, is_node_input=False
                    )
                    outbound_onnx_node_of_input_node = self.nodes[input_node].get_inbound_or_outbound_onnx_node(
                        mapping[input_node],
                        is_inbound=False,
                        tensor_index=outbound_tensor_index,
                    )
                    if (
                        outbound_onnx_node_of_input_node is None
                        or outbound_onnx_node_of_input_node.name != input_onnx_node.name
                    ):
                        return False
                    continue
                # Recurse upstream (toward producers).
                match = self._match_node(
                    input_node,
                    input_onnx_node,
                    node_input_tensor,
                    mapping,
                    mapped_onnx_nodes,
                    onnx_graph_output_tensors,
                    from_inbound=False,
                )
                if not match:
                    return False

        output_onnx_tensors = subgraph_mapping.outputs
        if len(output_onnx_tensors) != len(self.node_outputs[node_name]):
            return False  # Number of node outputs should be equal to number of output onnx tensors of the node.
        for node_output_tensor, onnx_tensor in zip(self.node_outputs[node_name], output_onnx_tensors):
            if onnx_tensor is None:
                return False
            # tensor matched
            if node_output_tensor in self.output_tensors:
                if not mapping.set_output_onnx_tensor(onnx_tensor, self.output_tensors.index(node_output_tensor)):
                    return False  # this tensor is mapped to another onnx tensor
                continue
            if onnx_tensor.name in onnx_graph_output_tensors:
                return False  # The pattern tensor is not an output but the onnx tensor is an output tensor of the onnx graph.

            # For sub-patterns, each input tensor can only have 1 output node. Otherwise the following test will fail.
            if len(self.tensor_outputs[node_output_tensor]) != len(onnx_tensor.outputs):
                return False
            for output_node, output_onnx_node in zip(self.tensor_outputs[node_output_tensor], onnx_tensor.outputs):
                # dfs ends when revisiting a node. We need to check if the edges are matched.
                if output_node in mapping:
                    inbound_tensor_index = self._get_tensor_index_for_node(
                        output_node, node_output_tensor, is_node_input=True
                    )
                    inbound_onnx_node_of_output_node = self.nodes[output_node].get_inbound_or_outbound_onnx_node(
                        mapping[output_node],
                        is_inbound=True,
                        tensor_index=inbound_tensor_index,
                    )
                    if (
                        inbound_onnx_node_of_output_node is None
                        or inbound_onnx_node_of_output_node.name != output_onnx_node.name
                    ):
                        return False
                    continue
                # Recurse downstream (toward consumers).
                match = self._match_node(
                    output_node,
                    output_onnx_node,
                    node_output_tensor,
                    mapping,
                    mapped_onnx_nodes,
                    onnx_graph_output_tensors,
                    from_inbound=True,
                )
                if not match:
                    return False
        return True

    def match_all(self, graph: Graph) -> list[PatternMapping]:
        """
        Find all the matched instances of subgraph with the current pattern in the given graph.

        Args:
            graph (Graph): the graph to match.

        Return:
            List[PatternMapping]: list of mappings.
        """
        mappings = []
        onnx_graph_output_tensors = {tensor.name for tensor in graph.outputs}
        with graph.node_ids():
            # Try every graph node as a potential anchor for the pattern's first input.
            for node in graph.nodes:
                G_LOGGER.info("Start a subgraph matching...")
                # Fresh per-attempt set: a node may appear in at most one mapping per attempt.
                mapped_onnx_nodes = set()
                mapping = self.match(
                    node,
                    from_inbound=True,
                    from_tensor_index=0,
                    mapped_onnx_nodes=mapped_onnx_nodes,
                    onnx_graph_output_tensors=onnx_graph_output_tensors,
                )
                if mapping is not None:
                    G_LOGGER.info("Found a matched subgraph!")
                    mappings.append(mapping)
        return mappings
|
|
@@ -0,0 +1 @@
|
|
|
1
|
+
from onnxslim.third_party.onnx_graphsurgeon.importers.base_importer import BaseImporter
|
|
@@ -0,0 +1,33 @@
|
|
|
1
|
+
#
|
|
2
|
+
# SPDX-FileCopyrightText: Copyright (c) 1993-2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
|
3
|
+
# SPDX-License-Identifier: Apache-2.0
|
|
4
|
+
#
|
|
5
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
6
|
+
# you may not use this file except in compliance with the License.
|
|
7
|
+
# You may obtain a copy of the License at
|
|
8
|
+
#
|
|
9
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
|
10
|
+
#
|
|
11
|
+
# Unless required by applicable law or agreed to in writing, software
|
|
12
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
13
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
14
|
+
# See the License for the specific language governing permissions and
|
|
15
|
+
# limitations under the License.
|
|
16
|
+
#
|
|
17
|
+
|
|
18
|
+
from onnxslim.third_party.onnx_graphsurgeon.ir.graph import Graph
|
|
19
|
+
|
|
20
|
+
|
|
21
|
+
class BaseImporter:
    """Abstract base for importers that translate a source graph into an onnx-graphsurgeon Graph."""

    @staticmethod
    def import_graph(graph) -> Graph:
        """
        Import a graph from some source graph representation.

        Args:
            graph (object): The source graph to import. For example, this might be an onnx.GraphProto.

        Returns:
            Graph: The equivalent onnx-graphsurgeon graph.

        Raises:
            NotImplementedError: Always; concrete importers must override this method.
        """
        raise NotImplementedError("BaseImporter is an abstract class")
|