onnxslim 0.1.80__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (65) hide show
  1. onnxslim/__init__.py +16 -0
  2. onnxslim/__main__.py +4 -0
  3. onnxslim/argparser.py +215 -0
  4. onnxslim/cli/__init__.py +1 -0
  5. onnxslim/cli/_main.py +180 -0
  6. onnxslim/core/__init__.py +219 -0
  7. onnxslim/core/optimization/__init__.py +146 -0
  8. onnxslim/core/optimization/dead_node_elimination.py +151 -0
  9. onnxslim/core/optimization/subexpression_elimination.py +76 -0
  10. onnxslim/core/optimization/weight_tying.py +59 -0
  11. onnxslim/core/pattern/__init__.py +249 -0
  12. onnxslim/core/pattern/elimination/__init__.py +5 -0
  13. onnxslim/core/pattern/elimination/concat.py +61 -0
  14. onnxslim/core/pattern/elimination/reshape.py +77 -0
  15. onnxslim/core/pattern/elimination/reshape_as.py +64 -0
  16. onnxslim/core/pattern/elimination/slice.py +108 -0
  17. onnxslim/core/pattern/elimination/unsqueeze.py +92 -0
  18. onnxslim/core/pattern/fusion/__init__.py +8 -0
  19. onnxslim/core/pattern/fusion/concat_reshape.py +50 -0
  20. onnxslim/core/pattern/fusion/convadd.py +70 -0
  21. onnxslim/core/pattern/fusion/convbn.py +86 -0
  22. onnxslim/core/pattern/fusion/convmul.py +69 -0
  23. onnxslim/core/pattern/fusion/gelu.py +47 -0
  24. onnxslim/core/pattern/fusion/gemm.py +330 -0
  25. onnxslim/core/pattern/fusion/padconv.py +89 -0
  26. onnxslim/core/pattern/fusion/reduce.py +67 -0
  27. onnxslim/core/pattern/registry.py +28 -0
  28. onnxslim/misc/__init__.py +0 -0
  29. onnxslim/misc/tabulate.py +2681 -0
  30. onnxslim/third_party/__init__.py +0 -0
  31. onnxslim/third_party/_sympy/__init__.py +0 -0
  32. onnxslim/third_party/_sympy/functions.py +205 -0
  33. onnxslim/third_party/_sympy/numbers.py +397 -0
  34. onnxslim/third_party/_sympy/printers.py +491 -0
  35. onnxslim/third_party/_sympy/solve.py +172 -0
  36. onnxslim/third_party/_sympy/symbol.py +102 -0
  37. onnxslim/third_party/onnx_graphsurgeon/__init__.py +15 -0
  38. onnxslim/third_party/onnx_graphsurgeon/exporters/__init__.py +1 -0
  39. onnxslim/third_party/onnx_graphsurgeon/exporters/base_exporter.py +33 -0
  40. onnxslim/third_party/onnx_graphsurgeon/exporters/onnx_exporter.py +432 -0
  41. onnxslim/third_party/onnx_graphsurgeon/graph_pattern/__init__.py +4 -0
  42. onnxslim/third_party/onnx_graphsurgeon/graph_pattern/graph_pattern.py +466 -0
  43. onnxslim/third_party/onnx_graphsurgeon/importers/__init__.py +1 -0
  44. onnxslim/third_party/onnx_graphsurgeon/importers/base_importer.py +33 -0
  45. onnxslim/third_party/onnx_graphsurgeon/importers/onnx_importer.py +558 -0
  46. onnxslim/third_party/onnx_graphsurgeon/ir/__init__.py +0 -0
  47. onnxslim/third_party/onnx_graphsurgeon/ir/function.py +274 -0
  48. onnxslim/third_party/onnx_graphsurgeon/ir/graph.py +1575 -0
  49. onnxslim/third_party/onnx_graphsurgeon/ir/node.py +266 -0
  50. onnxslim/third_party/onnx_graphsurgeon/ir/tensor.py +504 -0
  51. onnxslim/third_party/onnx_graphsurgeon/logger/__init__.py +1 -0
  52. onnxslim/third_party/onnx_graphsurgeon/logger/logger.py +261 -0
  53. onnxslim/third_party/onnx_graphsurgeon/util/__init__.py +0 -0
  54. onnxslim/third_party/onnx_graphsurgeon/util/exception.py +20 -0
  55. onnxslim/third_party/onnx_graphsurgeon/util/misc.py +252 -0
  56. onnxslim/third_party/symbolic_shape_infer.py +3273 -0
  57. onnxslim/utils.py +794 -0
  58. onnxslim/version.py +1 -0
  59. onnxslim-0.1.80.dist-info/METADATA +207 -0
  60. onnxslim-0.1.80.dist-info/RECORD +65 -0
  61. onnxslim-0.1.80.dist-info/WHEEL +5 -0
  62. onnxslim-0.1.80.dist-info/entry_points.txt +2 -0
  63. onnxslim-0.1.80.dist-info/licenses/LICENSE +21 -0
  64. onnxslim-0.1.80.dist-info/top_level.txt +1 -0
  65. onnxslim-0.1.80.dist-info/zip-safe +1 -0
@@ -0,0 +1,466 @@
1
+ #
2
+ # SPDX-FileCopyrightText: Copyright (c) 1993-2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
3
+ # SPDX-License-Identifier: Apache-2.0
4
+ #
5
+ # Licensed under the Apache License, Version 2.0 (the "License");
6
+ # you may not use this file except in compliance with the License.
7
+ # You may obtain a copy of the License at
8
+ #
9
+ # http://www.apache.org/licenses/LICENSE-2.0
10
+ #
11
+ # Unless required by applicable law or agreed to in writing, software
12
+ # distributed under the License is distributed on an "AS IS" BASIS,
13
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14
+ # See the License for the specific language governing permissions and
15
+ # limitations under the License.
16
+ #
17
+ from __future__ import annotations
18
+
19
+ from onnxslim.third_party.onnx_graphsurgeon.ir.graph import Constant, Graph, Node
20
+ from onnxslim.third_party.onnx_graphsurgeon.logger import G_LOGGER
21
+
22
+
23
class PatternMapping(dict):
    """
    Represents the result of matching a graph pattern against an ONNX graph.

    A mapping plays one of two roles:

    * Single op node: ``onnx_node`` is the matched node and ``inputs``/``outputs`` are that node's own
      tensor lists.
    * Subpattern: ``onnx_node`` is ``None``; the dict items map pattern node names to nested
      ``PatternMapping`` objects, and ``inputs``/``outputs``/``constants`` are filled in while matching.
    """

    def __init__(self, onnx_node=None) -> None:
        """
        Initialize a mapping, optionally bound to a single matched node.

        Args:
            onnx_node (gs.Node, optional): The matched node for a single-op mapping. When given, ``inputs`` and
                ``outputs`` refer to the node's own ``inputs``/``outputs`` lists (they are not copied).
        """
        super().__init__()
        self.onnx_node = onnx_node

        # For a single-op mapping these alias the node's tensor lists rather than copying them.
        self.inputs = onnx_node.inputs if onnx_node is not None else []
        self.outputs = onnx_node.outputs if onnx_node is not None else []

        self.constants = {}  # constant name -> onnx tensor mapping

    @staticmethod
    def _set_slot(slots, onnx_tensor, index):
        """
        Store ``onnx_tensor`` at ``slots[index]``, padding ``slots`` with ``None`` when it is too short.

        Returns:
            bool: False if the slot already holds a tensor with a different name (left unchanged), else True.
        """
        if index >= len(slots):
            slots.extend([None] * (index - len(slots) + 1))
        current = slots[index]
        if current is not None and current.name != onnx_tensor.name:
            return False  # This slot has already been bound to another onnx tensor
        slots[index] = onnx_tensor
        return True

    def set_input_onnx_tensor(self, onnx_tensor, index):
        """
        Bind ``onnx_tensor`` to pattern input ``index``, extending ``inputs`` if necessary.

        Returns:
            bool: False if that input is already bound to a tensor with a different name, True otherwise.
        """
        return self._set_slot(self.inputs, onnx_tensor, index)

    def set_output_onnx_tensor(self, onnx_tensor, index):
        """
        Bind ``onnx_tensor`` to pattern output ``index``, extending ``outputs`` if necessary.

        Returns:
            bool: False if that output is already bound to a tensor with a different name, True otherwise.
        """
        return self._set_slot(self.outputs, onnx_tensor, index)

    def set_constant_onnx_tensor(self, onnx_tensor, name):
        """
        Bind ``onnx_tensor`` to the pattern constant called ``name``.

        Returns:
            bool: False if ``name`` is already bound to a tensor with a different name, True otherwise.
        """
        bound = self.constants.get(name)
        if bound is not None and bound.name != onnx_tensor.name:
            return False
        self.constants[name] = onnx_tensor
        return True

    def _get_node(self):
        """Return the ONNX node associated with this mapping (``None`` for a subpattern mapping)."""
        return self.onnx_node

    def get(self, name: str):
        """
        Retrieve a pattern-to-graph mapping given the pattern node name.

        Unlike ``dict.get``, this takes no default and raises ``KeyError`` when ``name`` is not mapped.

        Args:
            name (str): The name of the pattern node. The pattern node can be a single op node or a subpattern.

        Returns:
            PatternMapping for a subpattern node or gs.Node for a single op node.
        """
        sub_mapping = self[name]
        return sub_mapping.onnx_node if sub_mapping.onnx_node is not None else sub_mapping

    def __str__(self) -> str:
        """Return the matched node's name, or ``{key: value, ...}`` of nested mappings for a subpattern."""
        if self.onnx_node is None:
            return "{" + ", ".join(f"{key}: {value!s}" for key, value in self.items()) + "}"
        return self.onnx_node.name
87
+
88
+
89
class GraphPattern:
    """
    Represent a graph pattern.

    A pattern is either a single op node (``op`` is set; these are created internally by ``add``) or a
    composite pattern whose ``nodes`` are ``GraphPattern`` instances (single ops or nested subpatterns) wired
    together through the integer tensor ids returned by ``variable``, ``constant`` and ``add``.

    Matching (``match_all``) starts from the first consumer of the pattern's first ``variable()`` tensor, so a
    composite pattern needs at least one variable input. A pattern node matches only ONNX nodes with exactly
    as many inputs/outputs as the pattern node declares.

    Example:
    ::

        pattern = GraphPattern()
        x = pattern.variable()
        w = pattern.constant()
        conv = pattern.add("conv", "Conv", inputs=[x, w])  # Conv without a bias input
        leaky_relu = pattern.add(
            "leaky_relu", "LeakyRelu", inputs=[conv], check_func=lambda node: node.attrs["alpha"] < 1.0
        )
        pattern.set_output_tensors([leaky_relu])
    """

    def __init__(self) -> None:
        """Initialize an empty pattern; populate it with ``variable``, ``constant``, ``add`` and ``set_output_tensors``."""
        self.op = None  # op type (str) when this is a single op node, else None
        self.check_func = None  # optional extra predicate for a single op node
        # pattern node name -> GraphPattern node (single op or subpattern)
        self.nodes: dict[str, GraphPattern] = {}
        # pattern node name -> ids of the tensors it consumes, in input order
        self.node_inputs: dict[str, list[int]] = {}
        # pattern node name -> ids of the tensors it produces, in output order
        self.node_outputs: dict[str, list[int]] = {}
        self.num_tensors = 0  # number of tensor ids allocated so far
        self.tensor_inputs: dict[int, list[str]] = {}  # tensor id -> producing node name (empty or one entry)
        self.tensor_outputs: dict[int, list[str]] = {}  # tensor id -> consuming node names, in add() order
        self.input_tensors: list[int] = []  # ids of the pattern's variable (input) tensors
        self.output_tensors: list[int] = []  # ids of the pattern's output tensors
        # tensor id -> tensor name of constant tensors.
        self.constant_tensors: dict[int, str] = {}

    def _add_tensor(self, input_node=None) -> int:
        """Assigns a unique tensor ID, tracks its input node if provided, and initializes output node tracking."""
        tensor_id = self.num_tensors
        self.tensor_inputs[tensor_id] = [] if input_node is None else [input_node]
        self.tensor_outputs[tensor_id] = []
        self.num_tensors += 1
        return tensor_id

    def variable(self) -> int:
        """
        Add a variable tensor without an input node - this tensor will be an input tensor of this graph pattern.

        Return:
            int: the tensor id.
        """
        tensor_id = self._add_tensor()
        self.input_tensors.append(tensor_id)
        return tensor_id

    def constant(self, name=None) -> int:
        """
        Add a constant tensor. If name is not provided, a default name will be assigned.

        During matching, a constant tensor only pairs with a ``Constant`` onnx tensor.

        Args:
            name(str): the constant tensor name

        Return:
            int: the tensor id.
        """
        tensor_id = self._add_tensor()
        if name is None:
            name = f"unnamed_constant_tensor_{tensor_id}"
        self.constant_tensors[tensor_id] = name
        return tensor_id

    def set_output_tensors(self, output_tensors) -> None:
        """Set the pattern's output tensors; every id must have been allocated by this pattern."""
        for tensor_id in output_tensors:
            assert tensor_id in self.tensor_inputs
        self.output_tensors = output_tensors

    def _init_single_node(self, op, check_func=None) -> None:
        """Turn this pattern into a single op node of type ``op`` with an optional ``check_func`` predicate."""
        self.op = op
        self.check_func = check_func

    def add(
        self,
        name: str,
        op: GraphPattern | str,
        check_func=None,
        inputs=None,
        num_output_tensors=1,
    ):
        """
        Add an op node or a subpattern node to the current pattern.

        Args:
            name (str): the node name.
            op (Union[GraphPattern, str]): the GraphPattern instance if adding a subpattern node or the op name if adding a single op node.
            check_func (function): the callback function for additional matching rules of an op node if adding a single op node.
                Ignored when ``op`` is a GraphPattern.
            inputs (list): the list of input tensors. If this node is a sub-pattern, the sequence of this list should align with the sequence of the sub-pattern's input tensors.
                The list is stored as-is (not copied).
            num_output_tensors (int): number of output tensors

        Return:
            tuple(int) or int or None: output tensors.
        """
        assert self.op is None  # a single op node cannot contain other nodes
        assert name not in self.nodes

        if inputs is None:
            inputs = []

        if isinstance(op, str):
            op_name = op
            op = GraphPattern()
            op._init_single_node(op_name, check_func)

        self.nodes[name] = op
        self.node_inputs[name] = inputs
        self.node_outputs[name] = [self._add_tensor(input_node=name) for _ in range(num_output_tensors)]

        for input_tensor in inputs:
            self.tensor_outputs[input_tensor].append(name)

        outputs = self.node_outputs[name]
        if not outputs:
            return None
        if len(outputs) == 1:
            return outputs[0]
        return tuple(outputs)

    def _get_inbound(self, tensor_index):
        """
        Return ``(tensor_id, node_name)`` for pattern input ``tensor_index`` and the first node consuming it.

        Returns ``(None, None)`` if the index is out of range or the input tensor has no consumer.
        """
        if len(self.input_tensors) > tensor_index:
            tensor_id = self.input_tensors[tensor_index]
            if self.tensor_outputs[tensor_id]:
                return tensor_id, self.tensor_outputs[tensor_id][0]
        return None, None

    def _get_outbound(self, tensor_index):
        """
        Return ``(tensor_id, node_name)`` for pattern output ``tensor_index`` and the node producing it.

        Returns ``(None, None)`` if the index is out of range or the output tensor has no producer.
        """
        if len(self.output_tensors) > tensor_index:
            tensor_id = self.output_tensors[tensor_index]
            if self.tensor_inputs[tensor_id]:
                return tensor_id, self.tensor_inputs[tensor_id][0]
        return None, None

    def _single_node_match(self, onnx_node: Node) -> bool:
        """Match the ONNX node with the pattern node based on op type and optional check_func criteria."""
        assert self.op is not None
        with G_LOGGER.indent():
            if self.op != onnx_node.op:
                G_LOGGER.info(
                    f"No match because: Op did not match. Node op was: {onnx_node.op} but pattern op was: {self.op}."
                )
                return False
            if self.check_func is not None and not self.check_func(onnx_node):
                G_LOGGER.info("No match because: check_func returned false.")
                return False
            G_LOGGER.info(f"Single node is matched: {self.op}, {onnx_node.name}")
            return True

    def _get_tensor_index_for_node(self, node: str, tensor_id: int, is_node_input: bool):
        """Return the position of ``tensor_id`` among ``node``'s inputs (or outputs); raises ValueError if absent."""
        if is_node_input:
            return self.node_inputs[node].index(tensor_id)
        return self.node_outputs[node].index(tensor_id)

    def get_inbound_or_outbound_onnx_node(self, mapping: PatternMapping, is_inbound: bool, tensor_index: int):
        """
        Resolve the matched ONNX node that sits at a boundary of this pattern.

        For a single op node this is simply the mapped node. For a subpattern, follow pattern input
        ``tensor_index`` to its first consumer (``is_inbound=True``) or pattern output ``tensor_index`` to its
        producer (``is_inbound=False``) and recurse into that node's mapping.

        Returns:
            The ONNX node, or None if the boundary tensor has no attached pattern node.
        """
        if self.op is not None:
            return mapping._get_node()
        if is_inbound:
            inbound_tensor, inbound_node = self._get_inbound(tensor_index)
            if inbound_node is not None:
                return self.nodes[inbound_node].get_inbound_or_outbound_onnx_node(
                    mapping[inbound_node],
                    is_inbound=True,
                    tensor_index=self._get_tensor_index_for_node(inbound_node, inbound_tensor, is_node_input=True),
                )
        else:
            outbound_tensor, outbound_node = self._get_outbound(tensor_index)
            if outbound_node is not None:
                return self.nodes[outbound_node].get_inbound_or_outbound_onnx_node(
                    mapping[outbound_node],
                    is_inbound=False,
                    tensor_index=self._get_tensor_index_for_node(outbound_node, outbound_tensor, is_node_input=False),
                )
        return None

    def match(
        self,
        onnx_node: Node,
        from_inbound: bool,
        from_tensor_index: int,
        mapped_onnx_nodes: set,
        onnx_graph_output_tensors: set,
    ):
        """
        Match an onnx node and its subgraph with the current pattern.

        Args:
            onnx_node (Node): The node at which matching starts.
            from_inbound (bool): True if ``onnx_node`` was reached through one of this pattern's input tensors,
                False if through one of its output tensors.
            from_tensor_index (int): Index into ``input_tensors`` (or ``output_tensors``) of that tensor.
            mapped_onnx_nodes (set): ids of onnx nodes already claimed during this attempt; updated in place.
            onnx_graph_output_tensors (set): names of the onnx graph's output tensors.

        Returns:
            PatternMapping if the subgraph matches, otherwise None.
        """
        if onnx_node.id in mapped_onnx_nodes:
            return None
        if self.op is not None:  # is single node
            if not self._single_node_match(onnx_node):
                return None

            mapped_onnx_nodes.add(onnx_node.id)
            return PatternMapping(onnx_node=onnx_node)

        initial_node = None
        if from_inbound:
            from_tensor, initial_node = self._get_inbound(from_tensor_index)
        else:
            from_tensor, initial_node = self._get_outbound(from_tensor_index)
        assert initial_node is not None

        mapping = PatternMapping()
        match = self._match_node(
            initial_node,
            onnx_node,
            from_tensor,
            mapping,
            mapped_onnx_nodes,
            onnx_graph_output_tensors,
            from_inbound,
        )
        return mapping if match else None

    def _match_node(
        self,
        node_name: str,
        onnx_node: Node,
        from_tensor: int,
        mapping: PatternMapping,
        mapped_onnx_nodes: set,
        onnx_graph_output_tensors: set,
        from_inbound: bool,
    ) -> bool:
        """
        Depth-first match of pattern node ``node_name`` (and everything reachable from it) against ``onnx_node``.

        The pattern node may be a single op or a subpattern. ``from_tensor`` is the pattern tensor through which
        ``onnx_node`` was reached; ``from_inbound`` says whether it is an input (True) or output (False) of
        ``node_name``. Matched sub-mappings are stored into ``mapping``. Returns True on a full match.
        """
        with G_LOGGER.indent():
            G_LOGGER.info(f"Checking node: {onnx_node.name} against pattern node: {node_name}.")
            tensor_index_for_node = self._get_tensor_index_for_node(node_name, from_tensor, is_node_input=from_inbound)
            subgraph_mapping = self.nodes[node_name].match(
                onnx_node,
                from_inbound,
                tensor_index_for_node,
                mapped_onnx_nodes,
                onnx_graph_output_tensors,
            )
            if subgraph_mapping is not None:
                mapping[node_name] = subgraph_mapping
            else:
                return False

            # --- Walk the node's inputs (towards producers). ---
            input_onnx_tensors = subgraph_mapping.inputs
            if len(input_onnx_tensors) != len(self.node_inputs[node_name]):
                return False  # Number of node inputs should equal to number of input onnx tensors of the node.
            for node_input_tensor, onnx_tensor in zip(self.node_inputs[node_name], input_onnx_tensors):
                if onnx_tensor is None:
                    return False
                # Pattern input tensor: bind it (or verify an existing binding) and stop.
                if node_input_tensor in self.input_tensors:
                    if not mapping.set_input_onnx_tensor(onnx_tensor, self.input_tensors.index(node_input_tensor)):
                        return False  # this tensor is mapped to another onnx tensor
                    continue
                # Pattern constant: the onnx tensor must be a Constant and bind consistently by name.
                if node_input_tensor in self.constant_tensors:
                    if not isinstance(onnx_tensor, Constant):
                        return False  # constant tensor not match
                    if not mapping.set_constant_onnx_tensor(onnx_tensor, self.constant_tensors[node_input_tensor]):
                        # this constant tensor is mapped to another onnx tensor
                        return False
                    continue
                # Internal tensor: its producer must match too.
                if len(self.tensor_inputs[node_input_tensor]) != len(onnx_tensor.inputs):
                    return False
                for input_node, input_onnx_node in zip(self.tensor_inputs[node_input_tensor], onnx_tensor.inputs):
                    # dfs ends when revisiting a node. We need to check if the edges are matched.
                    if input_node in mapping:
                        outbound_tensor_index = self._get_tensor_index_for_node(
                            input_node, node_input_tensor, is_node_input=False
                        )
                        outbound_onnx_node_of_input_node = self.nodes[input_node].get_inbound_or_outbound_onnx_node(
                            mapping[input_node],
                            is_inbound=False,
                            tensor_index=outbound_tensor_index,
                        )
                        if (
                            outbound_onnx_node_of_input_node is None
                            or outbound_onnx_node_of_input_node.name != input_onnx_node.name
                        ):
                            return False
                        continue
                    match = self._match_node(
                        input_node,
                        input_onnx_node,
                        node_input_tensor,
                        mapping,
                        mapped_onnx_nodes,
                        onnx_graph_output_tensors,
                        from_inbound=False,
                    )
                    if not match:
                        return False

            # --- Walk the node's outputs (towards consumers). ---
            output_onnx_tensors = subgraph_mapping.outputs
            if len(output_onnx_tensors) != len(self.node_outputs[node_name]):
                return False  # Number of node outputs should be equal to number of output onnx tensors of the node.
            for node_output_tensor, onnx_tensor in zip(self.node_outputs[node_name], output_onnx_tensors):
                if onnx_tensor is None:
                    return False
                # Pattern output tensor: bind it (or verify an existing binding) and stop.
                if node_output_tensor in self.output_tensors:
                    if not mapping.set_output_onnx_tensor(onnx_tensor, self.output_tensors.index(node_output_tensor)):
                        return False  # this tensor is mapped to another onnx tensor
                    continue
                if onnx_tensor.name in onnx_graph_output_tensors:
                    return False  # The pattern tensor is not an output but the onnx tensor is an output tensor of the onnx graph.

                # Consumers are paired positionally, so the onnx tensor must have exactly as many consumers as the
                # pattern tensor, in the same order.
                if len(self.tensor_outputs[node_output_tensor]) != len(onnx_tensor.outputs):
                    return False
                for output_node, output_onnx_node in zip(self.tensor_outputs[node_output_tensor], onnx_tensor.outputs):
                    # dfs ends when revisiting a node. We need to check if the edges are matched.
                    if output_node in mapping:
                        inbound_tensor_index = self._get_tensor_index_for_node(
                            output_node, node_output_tensor, is_node_input=True
                        )
                        inbound_onnx_node_of_output_node = self.nodes[output_node].get_inbound_or_outbound_onnx_node(
                            mapping[output_node],
                            is_inbound=True,
                            tensor_index=inbound_tensor_index,
                        )
                        if (
                            inbound_onnx_node_of_output_node is None
                            or inbound_onnx_node_of_output_node.name != output_onnx_node.name
                        ):
                            return False
                        continue
                    match = self._match_node(
                        output_node,
                        output_onnx_node,
                        node_output_tensor,
                        mapping,
                        mapped_onnx_nodes,
                        onnx_graph_output_tensors,
                        from_inbound=True,
                    )
                    if not match:
                        return False
            return True

    def match_all(self, graph: Graph) -> list[PatternMapping]:
        """
        Find all the matched instances of subgraph with the current pattern in the given graph.

        Every node of ``graph`` is tried as the match for the first consumer of the pattern's first input
        tensor; each attempt starts with a fresh set of claimed nodes, so returned matches may overlap.

        Args:
            graph (Graph): the graph to match.

        Return:
            List[PatternMapping]: list of mappings.
        """
        mappings = []
        onnx_graph_output_tensors = {tensor.name for tensor in graph.outputs}
        with graph.node_ids():
            for node in graph.nodes:
                G_LOGGER.info("Start a subgraph matching...")
                mapped_onnx_nodes = set()
                mapping = self.match(
                    node,
                    from_inbound=True,
                    from_tensor_index=0,
                    mapped_onnx_nodes=mapped_onnx_nodes,
                    onnx_graph_output_tensors=onnx_graph_output_tensors,
                )
                if mapping is not None:
                    G_LOGGER.info("Found a matched subgraph!")
                    mappings.append(mapping)
        return mappings
@@ -0,0 +1 @@
1
+ from onnxslim.third_party.onnx_graphsurgeon.importers.base_importer import BaseImporter
@@ -0,0 +1,33 @@
1
+ #
2
+ # SPDX-FileCopyrightText: Copyright (c) 1993-2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
3
+ # SPDX-License-Identifier: Apache-2.0
4
+ #
5
+ # Licensed under the Apache License, Version 2.0 (the "License");
6
+ # you may not use this file except in compliance with the License.
7
+ # You may obtain a copy of the License at
8
+ #
9
+ # http://www.apache.org/licenses/LICENSE-2.0
10
+ #
11
+ # Unless required by applicable law or agreed to in writing, software
12
+ # distributed under the License is distributed on an "AS IS" BASIS,
13
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14
+ # See the License for the specific language governing permissions and
15
+ # limitations under the License.
16
+ #
17
+
18
+ from onnxslim.third_party.onnx_graphsurgeon.ir.graph import Graph
19
+
20
+
21
class BaseImporter:
    """Interface for converters that produce an onnx-graphsurgeon ``Graph`` from some source graph format."""

    @staticmethod
    def import_graph(graph) -> Graph:
        """
        Convert ``graph`` into an onnx-graphsurgeon graph.

        The base implementation performs no conversion and only raises; concrete importers override it.

        Args:
            graph (object): Source graph to convert, for example an ``onnx.GraphProto``.

        Returns:
            Graph: The equivalent onnx-graphsurgeon graph.

        Raises:
            NotImplementedError: Always, from this base implementation.
        """
        message = "BaseImporter is an abstract class"
        raise NotImplementedError(message)