onnxslim 0.1.80__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- onnxslim/__init__.py +16 -0
- onnxslim/__main__.py +4 -0
- onnxslim/argparser.py +215 -0
- onnxslim/cli/__init__.py +1 -0
- onnxslim/cli/_main.py +180 -0
- onnxslim/core/__init__.py +219 -0
- onnxslim/core/optimization/__init__.py +146 -0
- onnxslim/core/optimization/dead_node_elimination.py +151 -0
- onnxslim/core/optimization/subexpression_elimination.py +76 -0
- onnxslim/core/optimization/weight_tying.py +59 -0
- onnxslim/core/pattern/__init__.py +249 -0
- onnxslim/core/pattern/elimination/__init__.py +5 -0
- onnxslim/core/pattern/elimination/concat.py +61 -0
- onnxslim/core/pattern/elimination/reshape.py +77 -0
- onnxslim/core/pattern/elimination/reshape_as.py +64 -0
- onnxslim/core/pattern/elimination/slice.py +108 -0
- onnxslim/core/pattern/elimination/unsqueeze.py +92 -0
- onnxslim/core/pattern/fusion/__init__.py +8 -0
- onnxslim/core/pattern/fusion/concat_reshape.py +50 -0
- onnxslim/core/pattern/fusion/convadd.py +70 -0
- onnxslim/core/pattern/fusion/convbn.py +86 -0
- onnxslim/core/pattern/fusion/convmul.py +69 -0
- onnxslim/core/pattern/fusion/gelu.py +47 -0
- onnxslim/core/pattern/fusion/gemm.py +330 -0
- onnxslim/core/pattern/fusion/padconv.py +89 -0
- onnxslim/core/pattern/fusion/reduce.py +67 -0
- onnxslim/core/pattern/registry.py +28 -0
- onnxslim/misc/__init__.py +0 -0
- onnxslim/misc/tabulate.py +2681 -0
- onnxslim/third_party/__init__.py +0 -0
- onnxslim/third_party/_sympy/__init__.py +0 -0
- onnxslim/third_party/_sympy/functions.py +205 -0
- onnxslim/third_party/_sympy/numbers.py +397 -0
- onnxslim/third_party/_sympy/printers.py +491 -0
- onnxslim/third_party/_sympy/solve.py +172 -0
- onnxslim/third_party/_sympy/symbol.py +102 -0
- onnxslim/third_party/onnx_graphsurgeon/__init__.py +15 -0
- onnxslim/third_party/onnx_graphsurgeon/exporters/__init__.py +1 -0
- onnxslim/third_party/onnx_graphsurgeon/exporters/base_exporter.py +33 -0
- onnxslim/third_party/onnx_graphsurgeon/exporters/onnx_exporter.py +432 -0
- onnxslim/third_party/onnx_graphsurgeon/graph_pattern/__init__.py +4 -0
- onnxslim/third_party/onnx_graphsurgeon/graph_pattern/graph_pattern.py +466 -0
- onnxslim/third_party/onnx_graphsurgeon/importers/__init__.py +1 -0
- onnxslim/third_party/onnx_graphsurgeon/importers/base_importer.py +33 -0
- onnxslim/third_party/onnx_graphsurgeon/importers/onnx_importer.py +558 -0
- onnxslim/third_party/onnx_graphsurgeon/ir/__init__.py +0 -0
- onnxslim/third_party/onnx_graphsurgeon/ir/function.py +274 -0
- onnxslim/third_party/onnx_graphsurgeon/ir/graph.py +1575 -0
- onnxslim/third_party/onnx_graphsurgeon/ir/node.py +266 -0
- onnxslim/third_party/onnx_graphsurgeon/ir/tensor.py +504 -0
- onnxslim/third_party/onnx_graphsurgeon/logger/__init__.py +1 -0
- onnxslim/third_party/onnx_graphsurgeon/logger/logger.py +261 -0
- onnxslim/third_party/onnx_graphsurgeon/util/__init__.py +0 -0
- onnxslim/third_party/onnx_graphsurgeon/util/exception.py +20 -0
- onnxslim/third_party/onnx_graphsurgeon/util/misc.py +252 -0
- onnxslim/third_party/symbolic_shape_infer.py +3273 -0
- onnxslim/utils.py +794 -0
- onnxslim/version.py +1 -0
- onnxslim-0.1.80.dist-info/METADATA +207 -0
- onnxslim-0.1.80.dist-info/RECORD +65 -0
- onnxslim-0.1.80.dist-info/WHEEL +5 -0
- onnxslim-0.1.80.dist-info/entry_points.txt +2 -0
- onnxslim-0.1.80.dist-info/licenses/LICENSE +21 -0
- onnxslim-0.1.80.dist-info/top_level.txt +1 -0
- onnxslim-0.1.80.dist-info/zip-safe +1 -0
|
@@ -0,0 +1,558 @@
|
|
|
1
|
+
#
|
|
2
|
+
# SPDX-FileCopyrightText: Copyright (c) 1993-2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
|
3
|
+
# SPDX-License-Identifier: Apache-2.0
|
|
4
|
+
#
|
|
5
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
6
|
+
# you may not use this file except in compliance with the License.
|
|
7
|
+
# You may obtain a copy of the License at
|
|
8
|
+
#
|
|
9
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
|
10
|
+
#
|
|
11
|
+
# Unless required by applicable law or agreed to in writing, software
|
|
12
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
13
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
14
|
+
# See the License for the specific language governing permissions and
|
|
15
|
+
# limitations under the License.
|
|
16
|
+
#
|
|
17
|
+
from __future__ import annotations
|
|
18
|
+
|
|
19
|
+
import copy
|
|
20
|
+
from collections import OrderedDict
|
|
21
|
+
from typing import Any
|
|
22
|
+
|
|
23
|
+
import numpy as np
|
|
24
|
+
import onnx
|
|
25
|
+
import onnx.numpy_helper
|
|
26
|
+
|
|
27
|
+
from onnxslim.third_party.onnx_graphsurgeon.importers.base_importer import BaseImporter
|
|
28
|
+
from onnxslim.third_party.onnx_graphsurgeon.ir.function import Function
|
|
29
|
+
from onnxslim.third_party.onnx_graphsurgeon.ir.graph import Graph
|
|
30
|
+
from onnxslim.third_party.onnx_graphsurgeon.ir.node import Node
|
|
31
|
+
from onnxslim.third_party.onnx_graphsurgeon.ir.tensor import (
|
|
32
|
+
Constant,
|
|
33
|
+
LazyValues,
|
|
34
|
+
SparseValues,
|
|
35
|
+
Tensor,
|
|
36
|
+
Variable,
|
|
37
|
+
)
|
|
38
|
+
from onnxslim.third_party.onnx_graphsurgeon.logger import G_LOGGER, LogMode
|
|
39
|
+
from onnxslim.third_party.onnx_graphsurgeon.util import misc
|
|
40
|
+
|
|
41
|
+
# Maps values from the AttributeType enum to their string representations, e.g., {1: "FLOAT"}.
# Built once at import time by inverting the onnx enum's name->value mapping.
ATTR_TYPE_MAPPING = {v: k for k, v in onnx.AttributeProto.AttributeType.items()}

# Maps an ONNX attribute type name to the corresponding AttributeProto field name,
# i.e. the Python property on the proto that holds the payload for that type.
ONNX_PYTHON_ATTR_MAPPING = {
    "FLOAT": "f",
    "INT": "i",
    "STRING": "s",
    "TENSOR": "t",
    "GRAPH": "g",
    "FLOATS": "floats",
    "INTS": "ints",
    "STRINGS": "strings",
}
|
|
55
|
+
|
|
56
|
+
|
|
57
|
+
def get_onnx_tensor_shape(onnx_tensor: onnx.ValueInfoProto | onnx.TensorProto) -> tuple | None:
    """Return the shape of an ONNX tensor as a tuple of dimensions.

    Each element is an ``int`` for a concrete dimension, a ``str`` for a named
    symbolic dimension (``dim_param``), or ``None`` for a dimension that is
    present but unknown. Returns ``None`` when the tensor carries no shape
    information at all.

    Args:
        onnx_tensor: A TensorProto, SparseTensorProto, or ValueInfoProto.

    Returns:
        The shape as a tuple, or None if no shape is available.
    """
    # Fix: the original annotated/documented the return as list[int], but the
    # function has always returned a tuple (possibly containing str/None) or None.
    shape = None
    if isinstance(onnx_tensor, (onnx.TensorProto, onnx.SparseTensorProto)):
        # Initializer protos always carry concrete integer dims.
        shape = tuple(onnx_tensor.dims)
    elif onnx_tensor.type.tensor_type.HasField("shape"):
        shape = []
        for dim in onnx_tensor.type.tensor_type.shape.dim:
            if dim.HasField("dim_param"):
                shape.append(dim.dim_param)  # Symbolic (named) dimension.
            elif dim.HasField("dim_value"):
                shape.append(dim.dim_value)  # Concrete dimension.
            else:
                shape.append(None)  # Dimension present but unspecified.
        shape = tuple(shape)
    return shape
|
|
74
|
+
|
|
75
|
+
|
|
76
|
+
def get_dtype_name(onnx_type):
    """Look up the string name of an ONNX data type given its integer code."""
    name_by_value = {value: name for name, value in onnx.TensorProto.DataType.items()}
    return name_by_value[onnx_type]
|
|
79
|
+
|
|
80
|
+
|
|
81
|
+
def get_itemsize(dtype):
    """Return the byte size of a single element for a given ONNX data type."""
    numpy_dtype = get_numpy_type(dtype)
    if numpy_dtype is not None:
        # NumPy knows this type; let it report the element size.
        return np.dtype(numpy_dtype).itemsize

    # bfloat16 has no NumPy equivalent but is 2 bytes wide.
    if dtype == onnx.TensorProto.BFLOAT16:
        return 2

    # Every float8 variant occupies exactly one byte.
    one_byte_types = {
        onnx.TensorProto.FLOAT8E4M3FN,
        onnx.TensorProto.FLOAT8E4M3FNUZ,
        onnx.TensorProto.FLOAT8E5M2,
        onnx.TensorProto.FLOAT8E5M2FNUZ,
    }
    if dtype in one_byte_types:
        return 1

    G_LOGGER.critical(f"Unsupported type: {dtype}")
|
|
98
|
+
|
|
99
|
+
|
|
100
|
+
def get_numpy_type(onnx_type):
    """Convert an ONNX tensor type to a corresponding NumPy type, if supported."""
    if not isinstance(onnx_type, int):
        # Already a NumPy type; pass it through unchanged.
        return onnx_type

    # TENSOR_TYPE_TO_NP_TYPE maps types unsupported by NumPy to random other types.
    # This obviously breaks things, so we need to treat this as a special case.
    numpy_unsupported_types = {
        onnx.TensorProto.BFLOAT16,
        onnx.TensorProto.FLOAT8E4M3FN,
        onnx.TensorProto.FLOAT8E4M3FNUZ,
        onnx.TensorProto.FLOAT8E5M2,
        onnx.TensorProto.FLOAT8E5M2FNUZ,
    }

    if onnx_type in numpy_unsupported_types or onnx_type not in onnx.helper.get_all_tensor_dtypes():
        return None
    return onnx.helper.tensor_dtype_to_np_dtype(onnx_type)
|
|
119
|
+
|
|
120
|
+
|
|
121
|
+
def get_onnx_tensor_dtype(
    onnx_tensor: onnx.ValueInfoProto | onnx.TensorProto,
) -> np.dtype | onnx.TensorProto.DataType:
    """Determine the NumPy dtype or ONNX tensor data type from an ONNX tensor.

    Falls back to the raw ONNX type code (with a one-time warning) when the
    type has no NumPy equivalent.
    """
    if isinstance(onnx_tensor, onnx.TensorProto):
        onnx_dtype = onnx_tensor.data_type
    elif isinstance(onnx_tensor, onnx.SparseTensorProto):
        onnx_dtype = onnx_tensor.values.data_type
    else:
        # ValueInfoProto: the element type lives on whichever oneof field is set.
        type_proto = onnx_tensor.type
        if type_proto.HasField("tensor_type"):
            onnx_dtype = type_proto.tensor_type.elem_type
        elif type_proto.HasField("sequence_type"):
            onnx_dtype = type_proto.sequence_type.elem_type.tensor_type.elem_type
        elif type_proto.HasField("map_type"):
            onnx_dtype = type_proto.map_type.value_type
        elif type_proto.HasField("optional_type"):
            onnx_dtype = type_proto.optional_type.elem_type
        elif type_proto.HasField("sparse_tensor_type"):
            onnx_dtype = type_proto.sparse_tensor_type.elem_type
        else:
            onnx_dtype = type_proto.opaque_type

    numpy_dtype = get_numpy_type(onnx_dtype)
    if numpy_dtype is not None:
        return numpy_dtype

    G_LOGGER.warning(
        f"Could not convert: {get_dtype_name(onnx_dtype)} to a corresponding NumPy type. "
        f"The original ONNX type will be preserved. ",
        mode=LogMode.ONCE,
    )
    return onnx_dtype
|
|
152
|
+
|
|
153
|
+
|
|
154
|
+
# NOTE(review): this definition is immediately shadowed by an identical
# redefinition of `get_onnx_tensor_type` below; only the later one takes
# effect at runtime. One of the two should be removed.
def get_onnx_tensor_type(onnx_tensor: onnx.ValueInfoProto | onnx.TensorProto) -> str:
    """Determine the ONNX tensor type category ("tensor_type", "sequence_type",
    etc.) from a given ONNX TensorProto or ValueInfoProto.

    Returns None when no recognized type field is set.
    """
    if isinstance(onnx_tensor, onnx.TensorProto):
        # Initializers are always plain tensors.
        return "tensor_type"
    elif onnx_tensor.type.HasField("tensor_type"):
        return "tensor_type"
    elif onnx_tensor.type.HasField("sequence_type"):
        return "sequence_type"
    elif onnx_tensor.type.HasField("map_type"):
        return "map_type"
    elif onnx_tensor.type.HasField("optional_type"):
        return "optional_type"
    elif onnx_tensor.type.HasField("opaque_type"):
        return "opaque_type"
    elif onnx_tensor.type.HasField("sparse_tensor_type"):
        return "sparse_tensor_type"
    else:
        return None
|
|
172
|
+
|
|
173
|
+
|
|
174
|
+
def get_onnx_tensor_type(onnx_tensor: onnx.ValueInfoProto | onnx.TensorProto) -> str:
    """Identify and return the specific data type category of a given ONNX tensor.

    Returns the name of the TypeProto field that is set ("tensor_type",
    "sequence_type", ...), or None if none of the known fields is present.
    """
    if isinstance(onnx_tensor, onnx.TensorProto):
        # Initializers are always plain tensors.
        return "tensor_type"

    # Probe the known TypeProto fields in the same order as the original
    # elif ladder; the first one that is set names the category.
    candidate_fields = (
        "tensor_type",
        "sequence_type",
        "map_type",
        "optional_type",
        "opaque_type",
        "sparse_tensor_type",
    )
    for field_name in candidate_fields:
        if onnx_tensor.type.HasField(field_name):
            return field_name
    return None
|
|
192
|
+
|
|
193
|
+
|
|
194
|
+
class OnnxImporter(BaseImporter):
    """Imports ONNX protobuf objects (models, graphs, functions, nodes, tensors,
    attributes) into the graphsurgeon IR (Graph, Function, Node, Tensor).

    All methods are static; the class is a namespace for the import pipeline.
    """

    @staticmethod
    def get_opset(model_or_func: onnx.ModelProto | onnx.FunctionProto):
        """Return the ONNX opset version for the given ONNX model or function, or None if the information is
        unavailable.
        """
        class_name = "Function" if isinstance(model_or_func, onnx.FunctionProto) else "Model"
        try:
            # Only the default ONNX domain ("" or "ai.onnx") determines the opset.
            for importer in OnnxImporter.get_import_domains(model_or_func):
                if importer.domain in {"", "ai.onnx"}:
                    return importer.version
            G_LOGGER.warning(f"{class_name} does not contain ONNX domain opset information! Using default opset.")
            return None
        except Exception:
            G_LOGGER.warning(f"{class_name} does not contain opset information! Using default opset.")
            return None

    @staticmethod
    def get_import_domains(model_or_func: onnx.ModelProto | onnx.FunctionProto):
        """Retrieves the import domains (opset_import) from an ONNX model or function."""
        return model_or_func.opset_import

    @staticmethod
    def get_ir_version(model_or_func: onnx.ModelProto | onnx.FunctionProto):
        """Retrieves the ir_version from an ONNX model or function, or None if absent.

        FunctionProto has no ir_version field, hence the broad except.
        """
        try:
            return model_or_func.ir_version
        except Exception:
            return None

    @staticmethod
    def import_tensor(onnx_tensor: onnx.ValueInfoProto | onnx.TensorProto | onnx.SparseTensorProto) -> Tensor:
        """Converts an ONNX tensor into a corresponding internal Tensor representation.

        TensorProto/SparseTensorProto become Constants (with lazily-loaded values);
        ValueInfoProto becomes a Variable.
        """
        if isinstance(onnx_tensor, onnx.SparseTensorProto):
            return Constant(
                name=onnx_tensor.values.name,
                values=SparseValues(onnx_tensor),
                data_location=onnx_tensor.values.data_location,
            )
        elif isinstance(onnx_tensor, onnx.TensorProto):
            # data_location distinguishes externally-stored weights; None if unset.
            data_location = int(onnx_tensor.data_location) if onnx_tensor.HasField("data_location") else None
            return Constant(
                name=onnx_tensor.name,
                values=LazyValues(onnx_tensor),
                data_location=data_location,
            )
        else:
            # A ValueInfoProto inside a subgraph might not have shape & type specified.
            tensor = Variable(onnx_tensor.name)
            if onnx_tensor.type.ByteSize() > 0:
                tensor.dtype = get_onnx_tensor_dtype(onnx_tensor)
                tensor.shape = get_onnx_tensor_shape(onnx_tensor)
                tensor.type = get_onnx_tensor_type(onnx_tensor)
            return tensor

    @staticmethod
    def import_attributes(
        onnx_attributes: list[onnx.AttributeProto],
        tensor_map: OrderedDict[str, Tensor],
        subgraph_tensor_map: OrderedDict[str, Tensor],
        opset: int,
        import_domains: onnx.OperatorSetIdProto,
    ) -> OrderedDict[str, Any]:
        """Import ONNX attribute values into Python dictionary format, handling various ONNX attribute types.

        Unsupported or unrecognized attribute types are skipped with a warning.
        """
        attr_dict = OrderedDict()
        for attr in onnx_attributes:
            # The closure captures the current `attr`; it is invoked immediately
            # below, so the late-binding of `attr` is not a problem here.
            def process_attr(attr_str: str):
                """Process an ONNX attribute based on its type, handling strings, tensors, graphs, and numeric
                sequences.
                """
                if attr.ref_attr_name:
                    # Reference to an enclosing function's attribute, not a literal value.
                    attr_type = misc.convert_from_onnx_attr_type(attr.type)
                    return Node.AttributeRef(attr.ref_attr_name, attr_type)
                processed = getattr(attr, ONNX_PYTHON_ATTR_MAPPING[attr_str])
                if attr_str == "STRING":
                    processed = processed.decode()
                elif attr_str == "TENSOR":
                    processed = OnnxImporter.import_tensor(processed)
                elif attr_str == "GRAPH":
                    # Subgraphs may reference outer-graph tensors, so both maps are visible.
                    processed = OnnxImporter.import_graph(
                        processed,
                        misc.combine_dicts(tensor_map, subgraph_tensor_map),
                        opset=opset,
                        import_domains=import_domains,
                    )
                elif attr_str in {"FLOATS", "INTS"}:
                    processed = list(processed)
                elif attr_str == "STRINGS":
                    processed = [p.decode() for p in processed]
                return processed

            if attr.type in ATTR_TYPE_MAPPING:
                attr_str = ATTR_TYPE_MAPPING[attr.type]
                if attr_str in ONNX_PYTHON_ATTR_MAPPING:
                    attr_dict[attr.name] = process_attr(attr_str)
                else:
                    G_LOGGER.warning(f"Attribute of type {attr_str} is currently unsupported. Skipping attribute.")
            else:
                G_LOGGER.warning(
                    f"Attribute type: {attr.type} was not recognized. Was the graph generated with a newer IR version than the installed `onnx` package? Skipping attribute."
                )
        return attr_dict

    @staticmethod
    def import_node(
        onnx_node: onnx.NodeProto,
        tensor_map: OrderedDict[str, Tensor],
        subgraph_tensor_map: OrderedDict[str, Tensor],
        opset,
        import_domains: onnx.OperatorSetIdProto,
    ) -> Node:
        # Optional inputs/outputs are represented by empty tensors. All other tensors should already have been populated during shape inference.
        """Parse ONNX node, mapping its attributes and tensors for model integration."""

        def get_tensor(name: str, check_outer_graph=True):
            """Retrieve a tensor by its name, prioritizing the subgraph tensor map and optionally checking the outer
            graph.
            """
            if name in subgraph_tensor_map:
                return subgraph_tensor_map[name]

            if check_outer_graph and name in tensor_map:
                return tensor_map[name]

            if not name:
                # Empty tensors are not tracked by the graph, as these represent optional inputs/outputs that have been omitted.
                G_LOGGER.verbose("Generating empty tensor")
                return Variable.empty()

            G_LOGGER.verbose(
                f"Tensor: {name} was not generated during shape inference, or shape inference was not run on this model. Creating a new Tensor."
            )
            subgraph_tensor_map[name] = Variable(name)
            return subgraph_tensor_map[name]

        # Retrieve Tensors for node inputs/outputs. Only empty tensors should need to be newly added.
        def retrieve_node_inputs() -> list[Tensor]:
            inputs = []  # List[Tensor]
            for input_name in onnx_node.input:
                inputs.append(get_tensor(input_name))
            return inputs

        def retrieve_node_outputs() -> list[Tensor]:
            outputs = []  # List[Tensor]
            for output_name in onnx_node.output:
                # Node outputs cannot come from the outer graph, they must be created within the inner graph.
                outputs.append(get_tensor(output_name, check_outer_graph=False))
            return outputs

        attributes = OnnxImporter.import_attributes(
            onnx_node.attribute, tensor_map, subgraph_tensor_map, opset, import_domains
        )

        return Node(
            op=onnx_node.op_type,
            name=onnx_node.name,
            attrs=attributes,
            inputs=retrieve_node_inputs(),
            outputs=retrieve_node_outputs(),
            domain=onnx_node.domain if onnx_node.HasField("domain") else None,
        )

    @staticmethod
    def import_function(
        onnx_function: onnx.FunctionProto,
        model_opset: int | None = None,
        model_import_domains: onnx.OperatorSetIdProto = None,
    ) -> Function:
        """Imports an ONNX function to a Function object using the model opset and import domains.

        The function's own opset/import domains take precedence; the model's
        values are used as a fallback.
        """
        opset = OnnxImporter.get_opset(onnx_function) or model_opset
        import_domains = OnnxImporter.get_import_domains(onnx_function) or model_import_domains
        subgraph_tensor_map = OrderedDict()  # Tensors in this function

        def make_tensor(name: str) -> Tensor:
            # Functions have no outer graph: every tensor is a locally-scoped Variable.
            if name not in subgraph_tensor_map:
                subgraph_tensor_map[name] = Variable(name)
            return subgraph_tensor_map[name]

        function_inputs = [make_tensor(inp) for inp in onnx_function.input]
        function_outputs = [make_tensor(out) for out in onnx_function.output]
        nodes = [
            OnnxImporter.import_node(onnx_node, {}, subgraph_tensor_map, opset, import_domains)
            for onnx_node in onnx_function.node
        ]

        # Declared attributes default to None; attribute_proto entries override
        # them with concrete default values.
        attributes = {}
        if onnx_function.attribute:
            attributes = {attr_name: None for attr_name in onnx_function.attribute}
        if onnx_function.attribute_proto:
            attrs_with_default_value = OnnxImporter.import_attributes(
                onnx_function.attribute_proto,
                None,
                subgraph_tensor_map,
                opset,
                import_domains,
            )
            attributes.update(attrs_with_default_value)

        return Function(
            onnx_function.name,
            onnx_function.domain,
            nodes=nodes,
            inputs=function_inputs,
            outputs=function_outputs,
            doc_string=onnx_function.doc_string,
            opset=opset,
            import_domains=import_domains,
            attrs=attributes,
        )

    @staticmethod
    def import_graph(
        onnx_graph: onnx.GraphProto,
        tensor_map: OrderedDict[str, Tensor] | None = None,
        opset=None,
        import_domains: onnx.OperatorSetIdProto = None,
        ir_version=None,
        producer_name: str | None = None,
        producer_version: str | None = None,
        functions: list[Function] | None = None,
        metadata_props=None,
    ) -> Graph:
        """
        Imports a Graph from an ONNX Graph.

        Args:
            onnx_graph (onnx.GraphProto): The ONNX graph to import.

            tensor_map (OrderedDict[str, Tensor]): A mapping of tensor names to Tensors. This is generally only useful for subgraph import.
            opset (int): The ONNX opset to use for this graph.
            import_domains (onnx.OperatorSetIdProto): The operator set domains imported by the model.
            ir_version (int): The IR version of the source model, if known.
            producer_name (str): The name of the tool used to generate the model. Defaults to "".
            producer_version (str): The version of the generating tool. Defaults to "".
            functions (List[Function]): The list of custom functions which are available to use in the model.
            metadata_props: Metadata properties carried over from the source model.
        """
        functions = misc.default_value(functions, [])
        tensor_map = copy.copy(misc.default_value(tensor_map, OrderedDict()))  # Outer graph tensors, read-only
        subgraph_tensor_map = OrderedDict()  # Tensors in this subgraph

        # Retrieves a Tensor from subgraph_tensor_map or the outer graph (tensor_map) if present, otherwise imports the tensor
        # If overwrite=True, this function will overwrite previously imported tensors
        # if the new tensor has more information available.
        def get_tensor(
            onnx_tensor: onnx.ValueInfoProto | onnx.TensorProto | onnx.SparseTensorProto,
            overwrite=False,
            check_outer_graph=True,
        ) -> Tensor:
            if isinstance(onnx_tensor, onnx.SparseTensorProto):
                name = onnx_tensor.values.name
            else:
                name = onnx_tensor.name
            # Prioritize the subgraph even if check_outer_graph is set
            if name in subgraph_tensor_map:
                if overwrite:
                    # Merge: only fill in dtype/shape that were previously missing.
                    tensor = OnnxImporter.import_tensor(onnx_tensor)
                    if isinstance(subgraph_tensor_map[name], Variable):
                        subgraph_tensor_map[name].dtype = subgraph_tensor_map[name].dtype or tensor.dtype
                        subgraph_tensor_map[name].shape = subgraph_tensor_map[name].shape or tensor.shape
                return subgraph_tensor_map[name]

            if check_outer_graph and name in tensor_map:
                return tensor_map[name]

            subgraph_tensor_map[name] = OnnxImporter.import_tensor(onnx_tensor)
            return subgraph_tensor_map[name]

        # Import initializers contents into Constants.
        G_LOGGER.verbose("Importing initializers")
        for initializer in onnx_graph.initializer:
            get_tensor(initializer)
        for initializer in onnx_graph.sparse_initializer:
            get_tensor(initializer)

        # Import all tensors whose shapes are known. Tensors may be repeated, and some of these
        # duplicates may not include shape/dtype information, so overwrite is set to True
        # so that we can capture all the information available about the tensor
        G_LOGGER.verbose("Importing tensors with known shapes")
        for tensor in onnx_graph.value_info:
            get_tensor(tensor, overwrite=True)

        # Import graph inputs and outputs. Initializers are not considered to be inputs.
        # Graph inputs and outputs can never come from the outer graph!
        initializer_names = set(
            [tensor.name for tensor in onnx_graph.initializer]
            + [tensor.values.name for tensor in onnx_graph.sparse_initializer]
        )
        G_LOGGER.verbose("Importing graph inputs")
        graph_inputs = []  # List[Tensor]
        for inp in onnx_graph.input:
            if inp.name not in initializer_names:
                tensor = get_tensor(inp, check_outer_graph=False)
                tensor.is_input = True
                graph_inputs.append(tensor)

        G_LOGGER.verbose("Importing graph outputs")
        graph_outputs = []  # List[Tensor]
        for out in onnx_graph.output:
            tensor = get_tensor(out, check_outer_graph=False, overwrite=True)
            tensor.is_output = True
            graph_outputs.append(tensor)

        G_LOGGER.verbose("Importing nodes")
        nodes = []  # List[Node]
        for onnx_node in onnx_graph.node:
            node = OnnxImporter.import_node(onnx_node, tensor_map, subgraph_tensor_map, opset, import_domains)
            nodes.append(node)

        return Graph(
            nodes=nodes,
            inputs=graph_inputs,
            outputs=graph_outputs,
            name=onnx_graph.name,
            doc_string=onnx_graph.doc_string,
            producer_name=producer_name,
            producer_version=producer_version,
            opset=opset,
            import_domains=import_domains,
            ir_version=ir_version,
            functions=functions,
            metadata_props=metadata_props,
        )
|
|
515
|
+
|
|
516
|
+
|
|
517
|
+
def import_onnx(onnx_model: onnx.ModelProto) -> Graph:
    """
    Import an onnx-graphsurgeon Graph from the provided ONNX model.

    Args:
        onnx_model (onnx.ModelProto): The ONNX model.

    Returns:
        Graph: A corresponding onnx-graphsurgeon Graph.
    """
    model_opset = OnnxImporter.get_opset(onnx_model)
    model_ir_version = OnnxImporter.get_ir_version(onnx_model)
    model_import_domains = OnnxImporter.get_import_domains(onnx_model)
    functions: list[Function] = [
        OnnxImporter.import_function(
            onnx_function,
            model_opset=model_opset,
            model_import_domains=model_import_domains,
        )
        for onnx_function in onnx_model.functions
    ]

    # Functions are identified by their name and domain.
    # Make sure that no two Functions share the same name and domain.
    function_unique_ids = set()
    for func in functions:
        unique_id = func.unique_id
        if unique_id in function_unique_ids:
            msg = "Model contains duplicate function definitions with "
            msg += f'name="{func.name}" and domain="{func.domain}"'
            G_LOGGER.warning(msg)
        # Fix: previously nothing was ever added to this set, so the duplicate
        # warning above could never fire.
        function_unique_ids.add(unique_id)

    return OnnxImporter.import_graph(
        onnx_model.graph,
        opset=model_opset,
        import_domains=model_import_domains,
        ir_version=model_ir_version,
        producer_name=onnx_model.producer_name,
        producer_version=onnx_model.producer_version,
        functions=functions,
        metadata_props=onnx_model.metadata_props,
    )
|
|
File without changes
|