onnxslim 0.1.80__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (65)
  1. onnxslim/__init__.py +16 -0
  2. onnxslim/__main__.py +4 -0
  3. onnxslim/argparser.py +215 -0
  4. onnxslim/cli/__init__.py +1 -0
  5. onnxslim/cli/_main.py +180 -0
  6. onnxslim/core/__init__.py +219 -0
  7. onnxslim/core/optimization/__init__.py +146 -0
  8. onnxslim/core/optimization/dead_node_elimination.py +151 -0
  9. onnxslim/core/optimization/subexpression_elimination.py +76 -0
  10. onnxslim/core/optimization/weight_tying.py +59 -0
  11. onnxslim/core/pattern/__init__.py +249 -0
  12. onnxslim/core/pattern/elimination/__init__.py +5 -0
  13. onnxslim/core/pattern/elimination/concat.py +61 -0
  14. onnxslim/core/pattern/elimination/reshape.py +77 -0
  15. onnxslim/core/pattern/elimination/reshape_as.py +64 -0
  16. onnxslim/core/pattern/elimination/slice.py +108 -0
  17. onnxslim/core/pattern/elimination/unsqueeze.py +92 -0
  18. onnxslim/core/pattern/fusion/__init__.py +8 -0
  19. onnxslim/core/pattern/fusion/concat_reshape.py +50 -0
  20. onnxslim/core/pattern/fusion/convadd.py +70 -0
  21. onnxslim/core/pattern/fusion/convbn.py +86 -0
  22. onnxslim/core/pattern/fusion/convmul.py +69 -0
  23. onnxslim/core/pattern/fusion/gelu.py +47 -0
  24. onnxslim/core/pattern/fusion/gemm.py +330 -0
  25. onnxslim/core/pattern/fusion/padconv.py +89 -0
  26. onnxslim/core/pattern/fusion/reduce.py +67 -0
  27. onnxslim/core/pattern/registry.py +28 -0
  28. onnxslim/misc/__init__.py +0 -0
  29. onnxslim/misc/tabulate.py +2681 -0
  30. onnxslim/third_party/__init__.py +0 -0
  31. onnxslim/third_party/_sympy/__init__.py +0 -0
  32. onnxslim/third_party/_sympy/functions.py +205 -0
  33. onnxslim/third_party/_sympy/numbers.py +397 -0
  34. onnxslim/third_party/_sympy/printers.py +491 -0
  35. onnxslim/third_party/_sympy/solve.py +172 -0
  36. onnxslim/third_party/_sympy/symbol.py +102 -0
  37. onnxslim/third_party/onnx_graphsurgeon/__init__.py +15 -0
  38. onnxslim/third_party/onnx_graphsurgeon/exporters/__init__.py +1 -0
  39. onnxslim/third_party/onnx_graphsurgeon/exporters/base_exporter.py +33 -0
  40. onnxslim/third_party/onnx_graphsurgeon/exporters/onnx_exporter.py +432 -0
  41. onnxslim/third_party/onnx_graphsurgeon/graph_pattern/__init__.py +4 -0
  42. onnxslim/third_party/onnx_graphsurgeon/graph_pattern/graph_pattern.py +466 -0
  43. onnxslim/third_party/onnx_graphsurgeon/importers/__init__.py +1 -0
  44. onnxslim/third_party/onnx_graphsurgeon/importers/base_importer.py +33 -0
  45. onnxslim/third_party/onnx_graphsurgeon/importers/onnx_importer.py +558 -0
  46. onnxslim/third_party/onnx_graphsurgeon/ir/__init__.py +0 -0
  47. onnxslim/third_party/onnx_graphsurgeon/ir/function.py +274 -0
  48. onnxslim/third_party/onnx_graphsurgeon/ir/graph.py +1575 -0
  49. onnxslim/third_party/onnx_graphsurgeon/ir/node.py +266 -0
  50. onnxslim/third_party/onnx_graphsurgeon/ir/tensor.py +504 -0
  51. onnxslim/third_party/onnx_graphsurgeon/logger/__init__.py +1 -0
  52. onnxslim/third_party/onnx_graphsurgeon/logger/logger.py +261 -0
  53. onnxslim/third_party/onnx_graphsurgeon/util/__init__.py +0 -0
  54. onnxslim/third_party/onnx_graphsurgeon/util/exception.py +20 -0
  55. onnxslim/third_party/onnx_graphsurgeon/util/misc.py +252 -0
  56. onnxslim/third_party/symbolic_shape_infer.py +3273 -0
  57. onnxslim/utils.py +794 -0
  58. onnxslim/version.py +1 -0
  59. onnxslim-0.1.80.dist-info/METADATA +207 -0
  60. onnxslim-0.1.80.dist-info/RECORD +65 -0
  61. onnxslim-0.1.80.dist-info/WHEEL +5 -0
  62. onnxslim-0.1.80.dist-info/entry_points.txt +2 -0
  63. onnxslim-0.1.80.dist-info/licenses/LICENSE +21 -0
  64. onnxslim-0.1.80.dist-info/top_level.txt +1 -0
  65. onnxslim-0.1.80.dist-info/zip-safe +1 -0
@@ -0,0 +1,558 @@
1
+ #
2
+ # SPDX-FileCopyrightText: Copyright (c) 1993-2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
3
+ # SPDX-License-Identifier: Apache-2.0
4
+ #
5
+ # Licensed under the Apache License, Version 2.0 (the "License");
6
+ # you may not use this file except in compliance with the License.
7
+ # You may obtain a copy of the License at
8
+ #
9
+ # http://www.apache.org/licenses/LICENSE-2.0
10
+ #
11
+ # Unless required by applicable law or agreed to in writing, software
12
+ # distributed under the License is distributed on an "AS IS" BASIS,
13
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14
+ # See the License for the specific language governing permissions and
15
+ # limitations under the License.
16
+ #
17
+ from __future__ import annotations
18
+
19
+ import copy
20
+ from collections import OrderedDict
21
+ from typing import Any
22
+
23
+ import numpy as np
24
+ import onnx
25
+ import onnx.numpy_helper
26
+
27
+ from onnxslim.third_party.onnx_graphsurgeon.importers.base_importer import BaseImporter
28
+ from onnxslim.third_party.onnx_graphsurgeon.ir.function import Function
29
+ from onnxslim.third_party.onnx_graphsurgeon.ir.graph import Graph
30
+ from onnxslim.third_party.onnx_graphsurgeon.ir.node import Node
31
+ from onnxslim.third_party.onnx_graphsurgeon.ir.tensor import (
32
+ Constant,
33
+ LazyValues,
34
+ SparseValues,
35
+ Tensor,
36
+ Variable,
37
+ )
38
+ from onnxslim.third_party.onnx_graphsurgeon.logger import G_LOGGER, LogMode
39
+ from onnxslim.third_party.onnx_graphsurgeon.util import misc
40
+
41
# Reverse lookup for the ``onnx.AttributeProto.AttributeType`` enum: maps each
# numeric attribute type to its enum name, e.g. {1: "FLOAT"}.
ATTR_TYPE_MAPPING = {
    type_value: type_name for type_name, type_value in onnx.AttributeProto.AttributeType.items()
}

# Maps an ONNX attribute type name to the ``onnx.AttributeProto`` field that holds
# its value. Attribute types absent from this table are skipped (with a warning)
# by ``OnnxImporter.import_attributes``.
ONNX_PYTHON_ATTR_MAPPING = {
    "FLOAT": "f",
    "INT": "i",
    "STRING": "s",
    "TENSOR": "t",
    "GRAPH": "g",
    "FLOATS": "floats",
    "INTS": "ints",
    "STRINGS": "strings",
}
55
+
56
+
57
def get_onnx_tensor_shape(
    onnx_tensor: onnx.ValueInfoProto | onnx.TensorProto | onnx.SparseTensorProto,
) -> tuple[int | str | None, ...] | None:
    """Return the shape of an ONNX tensor, or ``None`` when no shape is recorded.

    * ``TensorProto`` / ``SparseTensorProto``: the shape is ``tuple(dims)``.
    * ``ValueInfoProto``: the shape is read from ``type.tensor_type.shape``. Each
      dimension becomes its symbolic name (``dim_param``), its concrete size
      (``dim_value``), or ``None`` when neither is set. ``None`` is returned when
      no tensor shape is present (e.g. non-tensor value infos).

    Note that the result is a tuple, not a list, and may contain strings/None.
    """
    if isinstance(onnx_tensor, (onnx.TensorProto, onnx.SparseTensorProto)):
        return tuple(onnx_tensor.dims)

    tensor_type = onnx_tensor.type.tensor_type
    if not tensor_type.HasField("shape"):
        return None

    shape = []
    for dim in tensor_type.shape.dim:
        if dim.HasField("dim_param"):
            shape.append(dim.dim_param)
        elif dim.HasField("dim_value"):
            shape.append(dim.dim_value)
        else:
            # Dimension present but neither symbolic nor concrete: unknown size.
            shape.append(None)
    return tuple(shape)
74
+
75
+
76
def get_dtype_name(onnx_type):
    """Return the ``onnx.TensorProto.DataType`` enum name for ``onnx_type``.

    For example ``1`` maps to ``"FLOAT"``. Raises ``KeyError`` for values that are
    not members of the installed onnx package's ``DataType`` enum.
    """
    name_by_value = {}
    for dtype_name, dtype_value in onnx.TensorProto.DataType.items():
        name_by_value[dtype_value] = dtype_name
    return name_by_value[onnx_type]
79
+
80
+
81
def get_itemsize(dtype):
    """Return how many bytes a single element of ``dtype`` occupies.

    ``dtype`` may be a NumPy type or an integer ``onnx.TensorProto`` data type.

    * Anything ``get_numpy_type`` can map is measured with ``np.dtype(...).itemsize``.
    * BFLOAT16 (2 bytes) and the four FLOAT8 variants (1 byte each) are sized by
      hand, because ``get_numpy_type`` reports no NumPy equivalent for them.
    * Any other type is passed to ``G_LOGGER.critical``.
    """
    numpy_equivalent = get_numpy_type(dtype)
    if numpy_equivalent is not None:
        return np.dtype(numpy_equivalent).itemsize

    # Byte sizes of ONNX types that have no NumPy counterpart.
    non_numpy_sizes = {
        onnx.TensorProto.BFLOAT16: 2,
        onnx.TensorProto.FLOAT8E4M3FN: 1,
        onnx.TensorProto.FLOAT8E4M3FNUZ: 1,
        onnx.TensorProto.FLOAT8E5M2: 1,
        onnx.TensorProto.FLOAT8E5M2FNUZ: 1,
    }
    if dtype in non_numpy_sizes:
        return non_numpy_sizes[dtype]
    G_LOGGER.critical(f"Unsupported type: {dtype}")
98
+
99
+
100
def get_numpy_type(onnx_type):
    """Map an ONNX ``TensorProto`` data type to its NumPy dtype.

    Non-integer inputs are assumed to already be NumPy types and are returned
    unchanged. Integer inputs are converted with
    ``onnx.helper.tensor_dtype_to_np_dtype``. ``None`` is returned in two cases:

    * the data type is unknown to ``onnx.helper.get_all_tensor_dtypes()``;
    * the type is BFLOAT16 or a FLOAT8 variant, which NumPy does not support.
    """
    if isinstance(onnx_type, int):
        # onnx's dtype table maps types NumPy lacks onto unrelated stand-in types.
        # Using those would break things, so report "no equivalent" instead.
        lacks_numpy_equivalent = onnx_type in (
            onnx.TensorProto.BFLOAT16,
            onnx.TensorProto.FLOAT8E4M3FN,
            onnx.TensorProto.FLOAT8E4M3FNUZ,
            onnx.TensorProto.FLOAT8E5M2,
            onnx.TensorProto.FLOAT8E5M2FNUZ,
        )
        if lacks_numpy_equivalent or onnx_type not in onnx.helper.get_all_tensor_dtypes():
            return None
        return onnx.helper.tensor_dtype_to_np_dtype(onnx_type)

    # Already a NumPy type
    return onnx_type
119
+
120
+
121
def get_onnx_tensor_dtype(
    onnx_tensor: onnx.ValueInfoProto | onnx.TensorProto | onnx.SparseTensorProto,
) -> np.dtype | onnx.TensorProto.DataType:
    """Determine the NumPy dtype, or the raw ONNX data type, of an ONNX tensor.

    The ONNX element type is read as follows:

    * ``TensorProto``: ``data_type``.
    * ``SparseTensorProto``: ``values.data_type``.
    * ``ValueInfoProto``: whichever ``TypeProto`` variant is set.

    The result is converted with ``get_numpy_type``. If there is no NumPy
    equivalent, a one-time warning is logged and the original ONNX value is
    returned unchanged.

    NOTE(review): for ``map_type``/``optional_type``/``opaque_type`` the value read
    is a protobuf message rather than an int. ``get_numpy_type`` passes non-int
    values through unchanged, so such tensors get that message as their dtype.
    Confirm downstream code expects this.
    """
    if isinstance(onnx_tensor, onnx.TensorProto):
        onnx_dtype = onnx_tensor.data_type
    elif isinstance(onnx_tensor, onnx.SparseTensorProto):
        onnx_dtype = onnx_tensor.values.data_type
    elif onnx_tensor.type.HasField("tensor_type"):
        onnx_dtype = onnx_tensor.type.tensor_type.elem_type
    elif onnx_tensor.type.HasField("sequence_type"):
        onnx_dtype = onnx_tensor.type.sequence_type.elem_type.tensor_type.elem_type
    elif onnx_tensor.type.HasField("map_type"):
        onnx_dtype = onnx_tensor.type.map_type.value_type
    elif onnx_tensor.type.HasField("optional_type"):
        onnx_dtype = onnx_tensor.type.optional_type.elem_type
    elif onnx_tensor.type.HasField("sparse_tensor_type"):
        onnx_dtype = onnx_tensor.type.sparse_tensor_type.elem_type
    else:
        onnx_dtype = onnx_tensor.type.opaque_type

    dtype = get_numpy_type(onnx_dtype)
    if dtype is not None:
        return dtype

    # get_dtype_name raises KeyError for data-type values that the installed onnx
    # package does not know. Fall back to the raw value so that this warning
    # (whose purpose is to *preserve* the type) cannot abort the import.
    try:
        dtype_name = get_dtype_name(onnx_dtype)
    except KeyError:
        dtype_name = str(onnx_dtype)

    G_LOGGER.warning(
        f"Could not convert: {dtype_name} to a corresponding NumPy type. "
        "The original ONNX type will be preserved. ",
        mode=LogMode.ONCE,
    )
    return onnx_dtype
152
+
153
+
154
def get_onnx_tensor_type(
    onnx_tensor: onnx.ValueInfoProto | onnx.TensorProto | onnx.SparseTensorProto,
) -> str | None:
    """Return the name of the ``TypeProto`` variant that describes ``onnx_tensor``.

    * ``TensorProto``: ``"tensor_type"``.
    * ``SparseTensorProto``: ``"sparse_tensor_type"``.
    * ``ValueInfoProto``: the name of whichever ``TypeProto`` field is set.
    * ``None`` when none of the known variants is set.

    ``OnnxImporter.import_tensor`` stores this string as a Variable's ``type``.
    """
    if isinstance(onnx_tensor, onnx.TensorProto):
        return "tensor_type"
    if isinstance(onnx_tensor, onnx.SparseTensorProto):
        # Consistent with get_onnx_tensor_shape/get_onnx_tensor_dtype, which also accept sparse tensors.
        return "sparse_tensor_type"

    for type_field in (
        "tensor_type",
        "sequence_type",
        "map_type",
        "optional_type",
        "opaque_type",
        "sparse_tensor_type",
    ):
        if onnx_tensor.type.HasField(type_field):
            return type_field
    return None
172
+
173
+
174
# (A duplicate definition of ``get_onnx_tensor_type`` with identical logic was removed
# here; it silently redefined the function defined above — flake8/ruff F811.)
192
+
193
+
194
class OnnxImporter(BaseImporter):
    """Import ONNX protobuf objects (models, functions, graphs, nodes, tensors) into the IR.

    All methods are static; the class is a ``BaseImporter`` subclass used as a namespace.
    """

    @staticmethod
    def get_opset(model_or_func: onnx.ModelProto | onnx.FunctionProto):
        """Return the ONNX opset version for the given ONNX model or function.

        The version comes from the first ``opset_import`` entry whose domain is ``""`` or
        ``"ai.onnx"``. Returns ``None``, after logging a warning, when no such entry exists or
        the opset imports cannot be read.
        """
        class_name = "Function" if isinstance(model_or_func, onnx.FunctionProto) else "Model"
        try:
            for importer in OnnxImporter.get_import_domains(model_or_func):
                if importer.domain in {"", "ai.onnx"}:
                    return importer.version
            G_LOGGER.warning(f"{class_name} does not contain ONNX domain opset information! Using default opset.")
            return None
        except Exception:
            # Deliberately best-effort: missing/unreadable opset info falls back to the default opset.
            G_LOGGER.warning(f"{class_name} does not contain opset information! Using default opset.")
            return None

    @staticmethod
    def get_import_domains(model_or_func: onnx.ModelProto | onnx.FunctionProto):
        """Return the ``opset_import`` entries of an ONNX model or function."""
        return model_or_func.opset_import

    @staticmethod
    def get_ir_version(model_or_func: onnx.ModelProto | onnx.FunctionProto):
        """Return the ``ir_version`` of an ONNX model, or ``None`` if the object has none.

        ``onnx.FunctionProto`` does not define ``ir_version``, so functions yield ``None``.
        """
        try:
            return model_or_func.ir_version
        except AttributeError:
            return None

    @staticmethod
    def import_tensor(onnx_tensor: onnx.ValueInfoProto | onnx.TensorProto | onnx.SparseTensorProto) -> Tensor:
        """Convert a single ONNX tensor description into an IR tensor.

        * ``SparseTensorProto``: a ``Constant`` backed by ``SparseValues``, named after ``values.name``.
        * ``TensorProto``: a ``Constant`` backed by ``LazyValues``. ``data_location`` is kept only when
          it is explicitly set on the proto.
        * ``ValueInfoProto``: a ``Variable``. dtype, shape and type category are filled in only when the
          proto carries type information.
        """
        if isinstance(onnx_tensor, onnx.SparseTensorProto):
            return Constant(
                name=onnx_tensor.values.name,
                values=SparseValues(onnx_tensor),
                data_location=onnx_tensor.values.data_location,
            )
        if isinstance(onnx_tensor, onnx.TensorProto):
            data_location = int(onnx_tensor.data_location) if onnx_tensor.HasField("data_location") else None
            return Constant(
                name=onnx_tensor.name,
                values=LazyValues(onnx_tensor),
                data_location=data_location,
            )

        # A ValueInfoProto inside a subgraph might not have shape & type specified.
        tensor = Variable(onnx_tensor.name)
        if onnx_tensor.type.ByteSize() > 0:
            tensor.dtype = get_onnx_tensor_dtype(onnx_tensor)
            tensor.shape = get_onnx_tensor_shape(onnx_tensor)
            tensor.type = get_onnx_tensor_type(onnx_tensor)
        return tensor

    @staticmethod
    def import_attributes(
        onnx_attributes: list[onnx.AttributeProto],
        tensor_map: OrderedDict[str, Tensor],
        subgraph_tensor_map: OrderedDict[str, Tensor],
        opset: int,
        import_domains: onnx.OperatorSetIdProto,
    ) -> OrderedDict[str, Any]:
        """Import ONNX attributes into an ordered ``{name: value}`` dictionary.

        Attribute values are handled as follows:

        * Strings are decoded.
        * Tensors become IR tensors.
        * Graphs are imported recursively, seeing both ``tensor_map`` and ``subgraph_tensor_map``.
        * Numeric sequences become lists.
        * References to function attributes (``ref_attr_name``) become ``Node.AttributeRef``.

        Attribute types not listed in ``ONNX_PYTHON_ATTR_MAPPING`` or unknown to the installed
        onnx package are skipped with a warning.
        """

        def process_attr(attr: onnx.AttributeProto, attr_str: str):
            """Convert one attribute whose type name is ``attr_str`` into its Python value."""
            if attr.ref_attr_name:
                attr_type = misc.convert_from_onnx_attr_type(attr.type)
                return Node.AttributeRef(attr.ref_attr_name, attr_type)
            processed = getattr(attr, ONNX_PYTHON_ATTR_MAPPING[attr_str])
            if attr_str == "STRING":
                processed = processed.decode()
            elif attr_str == "TENSOR":
                processed = OnnxImporter.import_tensor(processed)
            elif attr_str == "GRAPH":
                processed = OnnxImporter.import_graph(
                    processed,
                    misc.combine_dicts(tensor_map, subgraph_tensor_map),
                    opset=opset,
                    import_domains=import_domains,
                )
            elif attr_str in {"FLOATS", "INTS"}:
                processed = list(processed)
            elif attr_str == "STRINGS":
                processed = [p.decode() for p in processed]
            return processed

        attr_dict = OrderedDict()
        for attr in onnx_attributes:
            if attr.type not in ATTR_TYPE_MAPPING:
                G_LOGGER.warning(
                    f"Attribute type: {attr.type} was not recognized. Was the graph generated with a newer IR version than the installed `onnx` package? Skipping attribute."
                )
                continue
            attr_str = ATTR_TYPE_MAPPING[attr.type]
            if attr_str in ONNX_PYTHON_ATTR_MAPPING:
                attr_dict[attr.name] = process_attr(attr, attr_str)
            else:
                G_LOGGER.warning(f"Attribute of type {attr_str} is currently unsupported. Skipping attribute.")
        return attr_dict

    @staticmethod
    def import_node(
        onnx_node: onnx.NodeProto,
        tensor_map: OrderedDict[str, Tensor],
        subgraph_tensor_map: OrderedDict[str, Tensor],
        opset,
        import_domains: onnx.OperatorSetIdProto,
    ) -> Node:
        """Parse an ONNX node into a ``Node``, resolving its input/output tensors by name.

        Optional inputs/outputs are represented by empty tensors. All other tensors should already
        have been populated during shape inference. Any name that is still unknown is created as a
        new ``Variable`` and registered in ``subgraph_tensor_map``.
        """

        def get_tensor(name: str, check_outer_graph=True):
            """Look ``name`` up in the subgraph map first, then (optionally) in the outer-graph map."""
            if name in subgraph_tensor_map:
                return subgraph_tensor_map[name]

            if check_outer_graph and name in tensor_map:
                return tensor_map[name]

            if not name:
                # Empty tensors are not tracked by the graph, as these represent optional inputs/outputs that have been omitted.
                G_LOGGER.verbose("Generating empty tensor")
                return Variable.empty()

            G_LOGGER.verbose(
                f"Tensor: {name} was not generated during shape inference, or shape inference was not run on this model. Creating a new Tensor."
            )
            subgraph_tensor_map[name] = Variable(name)
            return subgraph_tensor_map[name]

        attributes = OnnxImporter.import_attributes(
            onnx_node.attribute, tensor_map, subgraph_tensor_map, opset, import_domains
        )
        # Resolve inputs before outputs: get_tensor may register new tensors in
        # subgraph_tensor_map, whose insertion order follows resolution order.
        inputs = [get_tensor(input_name) for input_name in onnx_node.input]
        # Node outputs cannot come from the outer graph, they must be created within the inner graph.
        outputs = [get_tensor(output_name, check_outer_graph=False) for output_name in onnx_node.output]

        return Node(
            op=onnx_node.op_type,
            name=onnx_node.name,
            attrs=attributes,
            inputs=inputs,
            outputs=outputs,
            domain=onnx_node.domain if onnx_node.HasField("domain") else None,
        )

    @staticmethod
    def import_function(
        onnx_function: onnx.FunctionProto,
        model_opset: int | None = None,
        model_import_domains: onnx.OperatorSetIdProto = None,
    ) -> Function:
        """Import an ONNX function into a ``Function`` object.

        The function's own opset and opset imports take precedence. The model-level values are
        used when the function does not provide them. Declared attribute names without a default
        map to ``None``. Attributes with defaults (``attribute_proto``) are imported with their values.
        """
        opset = OnnxImporter.get_opset(onnx_function) or model_opset
        import_domains = OnnxImporter.get_import_domains(onnx_function) or model_import_domains
        subgraph_tensor_map = OrderedDict()  # Tensors in this function

        def make_tensor(name: str) -> Tensor:
            """Return the function-local Variable called ``name``, creating it on first use."""
            if name not in subgraph_tensor_map:
                subgraph_tensor_map[name] = Variable(name)
            return subgraph_tensor_map[name]

        function_inputs = [make_tensor(inp) for inp in onnx_function.input]
        function_outputs = [make_tensor(out) for out in onnx_function.output]
        # Function bodies are imported with an empty outer tensor map.
        nodes = [
            OnnxImporter.import_node(onnx_node, {}, subgraph_tensor_map, opset, import_domains)
            for onnx_node in onnx_function.node
        ]

        attributes = dict.fromkeys(onnx_function.attribute)
        if onnx_function.attribute_proto:
            # Pass an empty outer tensor map rather than None, matching the node import above.
            # import_attributes forwards it to misc.combine_dicts for GRAPH-valued defaults,
            # which presumably expects a mapping.
            attrs_with_default_value = OnnxImporter.import_attributes(
                onnx_function.attribute_proto,
                OrderedDict(),
                subgraph_tensor_map,
                opset,
                import_domains,
            )
            attributes.update(attrs_with_default_value)

        return Function(
            onnx_function.name,
            onnx_function.domain,
            nodes=nodes,
            inputs=function_inputs,
            outputs=function_outputs,
            doc_string=onnx_function.doc_string,
            opset=opset,
            import_domains=import_domains,
            attrs=attributes,
        )

    @staticmethod
    def import_graph(
        onnx_graph: onnx.GraphProto,
        tensor_map: OrderedDict[str, Tensor] | None = None,
        opset=None,
        import_domains: onnx.OperatorSetIdProto = None,
        ir_version=None,
        producer_name: str | None = None,
        producer_version: str | None = None,
        functions: list[Function] | None = None,
        metadata_props=None,
    ) -> Graph:
        """
        Imports a Graph from an ONNX Graph.

        Args:
            onnx_graph (onnx.GraphProto): The ONNX graph to import.
            tensor_map (OrderedDict[str, Tensor]): A mapping of tensor names to Tensors of enclosing
                graphs. This is generally only useful for subgraph import. It is copied, not modified.
            opset (int): The ONNX opset to use for this graph.
            import_domains (onnx.OperatorSetIdProto): The opset imports forwarded to nodes and the Graph.
            ir_version (int): The IR version forwarded to the Graph.
            producer_name (str): The name of the tool used to generate the model. Defaults to None.
            producer_version (str): The version of the generating tool. Defaults to None.
            functions (List[Function]): The list of custom functions which are available to use in the model.
            metadata_props: Model metadata, forwarded unchanged to the Graph.

        Returns:
            Graph: The imported graph.
        """
        functions = misc.default_value(functions, [])
        tensor_map = copy.copy(misc.default_value(tensor_map, OrderedDict()))  # Outer graph tensors, read-only
        subgraph_tensor_map = OrderedDict()  # Tensors in this subgraph

        def get_tensor(
            onnx_tensor: onnx.ValueInfoProto | onnx.TensorProto | onnx.SparseTensorProto,
            overwrite=False,
            check_outer_graph=True,
        ) -> Tensor:
            """Return the Tensor for ``onnx_tensor``, importing it into this subgraph if needed.

            Lookup order: this subgraph, then (if ``check_outer_graph``) the outer graph.
            With ``overwrite=True``, an already-imported Variable is enriched with dtype/shape
            information from ``onnx_tensor`` where its own information is missing.
            """
            if isinstance(onnx_tensor, onnx.SparseTensorProto):
                name = onnx_tensor.values.name
            else:
                name = onnx_tensor.name

            # Prioritize the subgraph even if check_outer_graph is set
            if name in subgraph_tensor_map:
                existing = subgraph_tensor_map[name]
                if overwrite and isinstance(existing, Variable):
                    tensor = OnnxImporter.import_tensor(onnx_tensor)
                    # ``or`` keeps the existing dtype unless it is falsy (None, or ONNX data type 0).
                    existing.dtype = existing.dtype or tensor.dtype
                    # Compare the shape against None explicitly: a scalar shape ``()`` is falsy and
                    # ``or`` would discard it in favour of a missing (None) shape.
                    if existing.shape is None:
                        existing.shape = tensor.shape
                return existing

            if check_outer_graph and name in tensor_map:
                return tensor_map[name]

            subgraph_tensor_map[name] = OnnxImporter.import_tensor(onnx_tensor)
            return subgraph_tensor_map[name]

        # Import initializers contents into Constants.
        G_LOGGER.verbose("Importing initializers")
        for initializer in onnx_graph.initializer:
            get_tensor(initializer)
        for initializer in onnx_graph.sparse_initializer:
            get_tensor(initializer)

        # Import all tensors whose shapes are known. Tensors may be repeated, and some of these
        # duplicates may not include shape/dtype information, so overwrite is set to True
        # so that we can capture all the information available about the tensor
        G_LOGGER.verbose("Importing tensors with known shapes")
        for tensor in onnx_graph.value_info:
            get_tensor(tensor, overwrite=True)

        # Import graph inputs and outputs. Initializers are not considered to be inputs.
        # Graph inputs and outputs can never come from the outer graph!
        initializer_names = {tensor.name for tensor in onnx_graph.initializer} | {
            tensor.values.name for tensor in onnx_graph.sparse_initializer
        }
        G_LOGGER.verbose("Importing graph inputs")
        graph_inputs = []  # List[Tensor]
        for inp in onnx_graph.input:
            if inp.name not in initializer_names:
                tensor = get_tensor(inp, check_outer_graph=False)
                tensor.is_input = True
                graph_inputs.append(tensor)

        G_LOGGER.verbose("Importing graph outputs")
        graph_outputs = []  # List[Tensor]
        for out in onnx_graph.output:
            tensor = get_tensor(out, check_outer_graph=False, overwrite=True)
            tensor.is_output = True
            graph_outputs.append(tensor)

        G_LOGGER.verbose("Importing nodes")
        nodes = [
            OnnxImporter.import_node(onnx_node, tensor_map, subgraph_tensor_map, opset, import_domains)
            for onnx_node in onnx_graph.node
        ]

        return Graph(
            nodes=nodes,
            inputs=graph_inputs,
            outputs=graph_outputs,
            name=onnx_graph.name,
            doc_string=onnx_graph.doc_string,
            producer_name=producer_name,
            producer_version=producer_version,
            opset=opset,
            import_domains=import_domains,
            ir_version=ir_version,
            functions=functions,
            metadata_props=metadata_props,
        )
515
+
516
+
517
def import_onnx(onnx_model: onnx.ModelProto) -> Graph:
    """
    Import an onnx-graphsurgeon Graph from the provided ONNX model.

    The model's functions are imported first, using the model opset and opset imports as
    fallbacks. A warning is logged for every function whose (name, domain) identity repeats an
    earlier one. The main graph is then imported with those functions attached.

    Args:
        onnx_model (onnx.ModelProto): The ONNX model.

    Returns:
        Graph: A corresponding onnx-graphsurgeon Graph.
    """
    model_opset = OnnxImporter.get_opset(onnx_model)
    model_ir_version = OnnxImporter.get_ir_version(onnx_model)
    model_import_domains = OnnxImporter.get_import_domains(onnx_model)
    functions: list[Function] = [
        OnnxImporter.import_function(
            onnx_function,
            model_opset=model_opset,
            model_import_domains=model_import_domains,
        )
        for onnx_function in onnx_model.functions
    ]

    # Functions are identified by their name and domain.
    # Make sure that no two Functions share the same name and domain.
    function_unique_ids = set()
    for func in functions:
        unique_id = func.unique_id
        if unique_id in function_unique_ids:
            msg = "Model contains duplicate function definitions with "
            msg += f'name="{func.name}" and domain="{func.domain}"'
            G_LOGGER.warning(msg)
        # Record every id seen so that later definitions with the same id are detected.
        function_unique_ids.add(unique_id)

    return OnnxImporter.import_graph(
        onnx_model.graph,
        opset=model_opset,
        import_domains=model_import_domains,
        ir_version=model_ir_version,
        producer_name=onnx_model.producer_name,
        producer_version=onnx_model.producer_version,
        functions=functions,
        metadata_props=onnx_model.metadata_props,
    )
File without changes