onnx-ir 0.1.15__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- onnx_ir/__init__.py +176 -0
- onnx_ir/_cloner.py +229 -0
- onnx_ir/_convenience/__init__.py +558 -0
- onnx_ir/_convenience/_constructors.py +291 -0
- onnx_ir/_convenience/_extractor.py +191 -0
- onnx_ir/_core.py +4435 -0
- onnx_ir/_display.py +54 -0
- onnx_ir/_enums.py +474 -0
- onnx_ir/_graph_comparison.py +23 -0
- onnx_ir/_graph_containers.py +373 -0
- onnx_ir/_io.py +133 -0
- onnx_ir/_linked_list.py +284 -0
- onnx_ir/_metadata.py +45 -0
- onnx_ir/_name_authority.py +72 -0
- onnx_ir/_polyfill.py +26 -0
- onnx_ir/_protocols.py +627 -0
- onnx_ir/_safetensors/__init__.py +510 -0
- onnx_ir/_tape.py +242 -0
- onnx_ir/_thirdparty/asciichartpy.py +310 -0
- onnx_ir/_type_casting.py +89 -0
- onnx_ir/_version_utils.py +48 -0
- onnx_ir/analysis/__init__.py +21 -0
- onnx_ir/analysis/_implicit_usage.py +74 -0
- onnx_ir/convenience.py +38 -0
- onnx_ir/external_data.py +459 -0
- onnx_ir/passes/__init__.py +41 -0
- onnx_ir/passes/_pass_infra.py +351 -0
- onnx_ir/passes/common/__init__.py +54 -0
- onnx_ir/passes/common/_c_api_utils.py +76 -0
- onnx_ir/passes/common/clear_metadata_and_docstring.py +60 -0
- onnx_ir/passes/common/common_subexpression_elimination.py +207 -0
- onnx_ir/passes/common/constant_manipulation.py +230 -0
- onnx_ir/passes/common/default_attributes.py +99 -0
- onnx_ir/passes/common/identity_elimination.py +120 -0
- onnx_ir/passes/common/initializer_deduplication.py +179 -0
- onnx_ir/passes/common/inliner.py +223 -0
- onnx_ir/passes/common/naming.py +280 -0
- onnx_ir/passes/common/onnx_checker.py +57 -0
- onnx_ir/passes/common/output_fix.py +141 -0
- onnx_ir/passes/common/shape_inference.py +112 -0
- onnx_ir/passes/common/topological_sort.py +37 -0
- onnx_ir/passes/common/unused_removal.py +215 -0
- onnx_ir/py.typed +1 -0
- onnx_ir/serde.py +2043 -0
- onnx_ir/tape.py +15 -0
- onnx_ir/tensor_adapters.py +210 -0
- onnx_ir/testing.py +197 -0
- onnx_ir/traversal.py +118 -0
- onnx_ir-0.1.15.dist-info/METADATA +68 -0
- onnx_ir-0.1.15.dist-info/RECORD +53 -0
- onnx_ir-0.1.15.dist-info/WHEEL +5 -0
- onnx_ir-0.1.15.dist-info/licenses/LICENSE +202 -0
- onnx_ir-0.1.15.dist-info/top_level.txt +1 -0
onnx_ir/serde.py
ADDED
|
@@ -0,0 +1,2043 @@
|
|
|
1
|
+
# Copyright (c) ONNX Project Contributors
|
|
2
|
+
# SPDX-License-Identifier: Apache-2.0
|
|
3
|
+
"""Serialize and deserialize the intermediate representation to/from ONNX protos."""
|
|
4
|
+
|
|
5
|
+
# NOTES for developers:
|
|
6
|
+
# NOTE: Do not import pathlib in the IR. It is slow. Use os.path methods instead.
|
|
7
|
+
#
|
|
8
|
+
# NOTE: Protobuf serialization
|
|
9
|
+
# Initializing a protobuf message with initialized protobuf messages incurs
|
|
10
|
+
# a copy and is slow. Instead, use proto.add() to add to a repeated field,
|
|
11
|
+
# or initialize the message first and then set the fields if the fields are
|
|
12
|
+
# plain Python objects.
|
|
13
|
+
|
|
14
|
+
from __future__ import annotations
|
|
15
|
+
|
|
16
|
+
import functools
|
|
17
|
+
import typing
|
|
18
|
+
|
|
19
|
+
__all__ = [
|
|
20
|
+
# Tensors
|
|
21
|
+
"TensorProtoTensor",
|
|
22
|
+
# Deserialization
|
|
23
|
+
"from_proto",
|
|
24
|
+
"from_onnx_text",
|
|
25
|
+
"deserialize_attribute",
|
|
26
|
+
"deserialize_dimension",
|
|
27
|
+
"deserialize_function",
|
|
28
|
+
"deserialize_graph",
|
|
29
|
+
"deserialize_metadata_props",
|
|
30
|
+
"deserialize_model",
|
|
31
|
+
"deserialize_node",
|
|
32
|
+
"deserialize_opset_import",
|
|
33
|
+
"deserialize_tensor",
|
|
34
|
+
"deserialize_tensor_shape",
|
|
35
|
+
"deserialize_type_proto_for_shape",
|
|
36
|
+
"deserialize_type_proto_for_type",
|
|
37
|
+
"deserialize_value_info_proto",
|
|
38
|
+
# Serialization
|
|
39
|
+
"to_proto",
|
|
40
|
+
"to_onnx_text",
|
|
41
|
+
"serialize_attribute_into",
|
|
42
|
+
"serialize_attribute",
|
|
43
|
+
"serialize_dimension_into",
|
|
44
|
+
"serialize_function_into",
|
|
45
|
+
"serialize_function",
|
|
46
|
+
"serialize_graph_into",
|
|
47
|
+
"serialize_graph",
|
|
48
|
+
"serialize_model_into",
|
|
49
|
+
"serialize_model",
|
|
50
|
+
"serialize_node_into",
|
|
51
|
+
"serialize_node",
|
|
52
|
+
"serialize_shape_into",
|
|
53
|
+
"serialize_reference_attribute_into",
|
|
54
|
+
"serialize_reference_attribute",
|
|
55
|
+
"serialize_tensor_into",
|
|
56
|
+
"serialize_tensor",
|
|
57
|
+
"serialize_type_into",
|
|
58
|
+
"serialize_type",
|
|
59
|
+
"serialize_value_into",
|
|
60
|
+
"serialize_value",
|
|
61
|
+
"SerdeError",
|
|
62
|
+
]
|
|
63
|
+
|
|
64
|
+
import collections
|
|
65
|
+
import logging
|
|
66
|
+
import os
|
|
67
|
+
from collections.abc import Iterable, Mapping, Sequence
|
|
68
|
+
from typing import Any, Callable
|
|
69
|
+
|
|
70
|
+
import numpy as np
|
|
71
|
+
import onnx # noqa: TID251
|
|
72
|
+
import onnx.external_data_helper # noqa: TID251
|
|
73
|
+
|
|
74
|
+
from onnx_ir import _convenience, _core, _enums, _protocols, _type_casting
|
|
75
|
+
|
|
76
|
+
if typing.TYPE_CHECKING:
|
|
77
|
+
import google.protobuf.internal.containers as proto_containers
|
|
78
|
+
|
|
79
|
+
# Module-level logger named after this module.
logger = logging.getLogger(__name__)

# Call-to-action text pointing at the project repository.
_PLEASE_CONTRIBUTE = (
    "Please contribute by creating a PR at https://github.com/onnx/onnx-ir."
)
# ONNX IR version where value info in functions was introduced.
_FUNCTION_VALUE_INFO_SUPPORTED_VERSION = 10
_QUANT_PARAMETER_TENSOR_NAMES_FIELD = "quant_parameter_tensor_names"
# Type variable bound to callables; lets decorators preserve the wrapped signature.
_T = typing.TypeVar("_T", bound=Callable[..., Any])
|
|
87
|
+
|
|
88
|
+
|
|
89
|
+
class SerdeError(RuntimeError):
    """Raised when serializing to, or deserializing from, an ONNX proto fails."""
|
|
91
|
+
|
|
92
|
+
|
|
93
|
+
def _capture_errors(arg_capturer: Callable[..., str]) -> Callable[[_T], _T]:
    """Create a decorator that re-raises any failure as a :class:`SerdeError`.

    The resulting error message names the decorated function and includes a
    short description of the call produced by ``arg_capturer``. The original
    exception is chained as the cause so the full stack is preserved.

    Args:
        arg_capturer: Callable invoked with the same positional and keyword
            arguments as the decorated function. It returns the text that
            identifies the call in the error message.

    Returns:
        A decorator that wraps a function with the error-capturing behavior.
    """

    def decorator(func: _T) -> _T:
        @functools.wraps(func)
        def wrapper(*args: Any, **kwargs: Any) -> Any:
            try:
                return func(*args, **kwargs)
            except Exception as e:
                try:
                    captured = arg_capturer(*args, **kwargs)
                except Exception:  # pylint: disable=broad-exception-caught
                    # Describing the call must never hide the real failure:
                    # fall back to a placeholder instead of raising a new error.
                    captured = "<failed to capture arguments>"
                raise SerdeError(
                    f"Error calling {func.__name__} with: {captured}"
                ) from e

        return wrapper  # type: ignore

    return decorator
|
|
109
|
+
|
|
110
|
+
|
|
111
|
+
def _little_endian_dtype(dtype) -> np.dtype:
    """Return ``dtype`` as a numpy dtype with little-endian byte order.

    ONNX always stores ``raw_data`` in little endian, so the data has to be
    interpreted as little endian on every platform, including big-endian ones.
    """
    resolved = np.dtype(dtype)
    return resolved.newbyteorder("<")
|
|
118
|
+
|
|
119
|
+
|
|
120
|
+
@typing.overload
def from_proto(proto: onnx.ModelProto) -> _core.Model: ...  # type: ignore[overload-overlap]
@typing.overload
def from_proto(proto: onnx.GraphProto) -> _core.Graph: ...  # type: ignore[overload-overlap]
@typing.overload
def from_proto(proto: onnx.NodeProto) -> _core.Node: ...  # type: ignore[overload-overlap]
@typing.overload
def from_proto(proto: onnx.TensorProto) -> _protocols.TensorProtocol: ...  # type: ignore[overload-overlap]
@typing.overload
def from_proto(proto: onnx.AttributeProto) -> _core.Attr: ...  # type: ignore[overload-overlap]
@typing.overload
def from_proto(proto: onnx.ValueInfoProto) -> _core.Value: ...  # type: ignore[overload-overlap]
@typing.overload
def from_proto(proto: onnx.TypeProto) -> _core.TypeAndShape: ...  # type: ignore[overload-overlap]
@typing.overload
def from_proto(proto: onnx.FunctionProto) -> _core.Function: ...  # type: ignore[overload-overlap]
@typing.overload
def from_proto(proto: onnx.TensorShapeProto) -> _core.Shape: ...  # type: ignore[overload-overlap]
@typing.overload
def from_proto(  # type: ignore[overload-overlap]
    proto: onnx.TensorShapeProto.Dimension,
) -> tuple[int | _core.SymbolicDim, str | None]: ...
@typing.overload
def from_proto(proto: Sequence[onnx.OperatorSetIdProto]) -> dict[str, int]: ...  # type: ignore[overload-overlap]
@typing.overload
def from_proto(proto: Sequence[onnx.StringStringEntryProto]) -> dict[str, str]: ...  # type: ignore[overload-overlap]


def from_proto(proto: object) -> object:
    """Deserialize an ONNX proto message to the matching IR object.

    The concrete ``deserialize_*`` function is selected from the runtime type
    of ``proto``. Sequences are accepted when every element is an
    ``OperatorSetIdProto`` (opset imports) or a ``StringStringEntryProto``
    (metadata props).

    Args:
        proto: The proto message, or sequence of proto messages, to convert.

    Returns:
        The IR object produced by the selected deserializer.

    Raises:
        NotImplementedError: If no deserializer is registered for the type.
    """
    # Message classes whose deserializer takes the proto as its only argument.
    single_argument_handlers = (
        (onnx.ModelProto, deserialize_model),
        (onnx.GraphProto, deserialize_graph),
        (onnx.NodeProto, deserialize_node),
        (onnx.TensorProto, deserialize_tensor),
        (onnx.AttributeProto, deserialize_attribute),
        (onnx.FunctionProto, deserialize_function),
        (onnx.TensorShapeProto, deserialize_tensor_shape),
        (onnx.TensorShapeProto.Dimension, deserialize_dimension),
    )
    for message_type, handler in single_argument_handlers:
        if isinstance(proto, message_type):
            return handler(proto)

    if isinstance(proto, onnx.ValueInfoProto):
        # No existing value to update: a new one is created.
        return deserialize_value_info_proto(proto, None)
    if isinstance(proto, onnx.TypeProto):
        ir_type = deserialize_type_proto_for_type(proto)
        ir_shape = deserialize_type_proto_for_shape(proto)
        return _core.TypeAndShape(ir_type, ir_shape)

    if isinstance(proto, Sequence):
        if all(isinstance(entry, onnx.OperatorSetIdProto) for entry in proto):
            return deserialize_opset_import(proto)
        if all(isinstance(entry, onnx.StringStringEntryProto) for entry in proto):
            return deserialize_metadata_props(proto)

    raise NotImplementedError(
        f"Deserialization of {type(proto)} in from_proto is not implemented. "
        "Use a specific ir.serde.deserialize* function instead."
    )
|
|
185
|
+
|
|
186
|
+
|
|
187
|
+
def from_onnx_text(
    model_text: str,
    /,
    initializers: Iterable[_protocols.TensorProtocol] | None = None,
) -> _core.Model:
    """Convert the ONNX textual representation to an IR model.

    Read more about the textual representation at: https://onnx.ai/onnx/repo-docs/Syntax.html

    .. versionchanged:: 0.1.2
        Added the ``initializers`` argument.

    Args:
        model_text: The ONNX textual representation of the model.
        initializers: Tensors to be added as initializers. If provided, these tensors
            will be added to the model as initializers. If a name does not exist in the model,
            a ValueError will be raised.

    Returns:
        The IR model corresponding to the ONNX textual representation.

    Raises:
        ValueError: If a tensor in `initializers` has no name, or if its name
            does not match any value in the model.
    """
    proto = onnx.parser.parse_model(model_text)
    model = deserialize_model(proto)
    if not initializers:
        # Nothing to attach: skip building the name -> value mapping entirely.
        return model

    # Map value names to values so each tensor can be attached by name.
    values = _convenience.create_value_mapping(model.graph)
    for tensor in initializers:
        name = tensor.name
        if not name:
            raise ValueError(
                "Initializer tensor must have a name. "
                f"Please provide a name for the initializer: {tensor}"
            )
        if name not in values:
            raise ValueError(f"Value '{name}' does not exist in model.")
        initializer = values[name]
        initializer.const_value = tensor
        model.graph.register_initializer(initializer)
    return model
|
|
229
|
+
|
|
230
|
+
|
|
231
|
+
def to_onnx_text(
    model: _protocols.ModelProtocol, /, exclude_initializers: bool = False
) -> str:
    """Render the IR model in the ONNX textual representation.

    .. versionadded:: 0.1.2

    Args:
        model: The IR model to convert.
        exclude_initializers: When True, the graph initializers are removed from
            the serialized proto before it is printed.

    Returns:
        The ONNX textual representation of the model.
    """
    model_proto = serialize_model(model)
    if exclude_initializers:
        # Drop every initializer from the serialized copy; the IR model is untouched.
        del model_proto.graph.initializer[:]
    return onnx.printer.to_text(model_proto)
|
|
250
|
+
|
|
251
|
+
|
|
252
|
+
@typing.overload
def to_proto(ir_object: _protocols.ModelProtocol) -> onnx.ModelProto: ...  # type: ignore[overload-overlap]
@typing.overload
def to_proto(ir_object: _protocols.GraphProtocol) -> onnx.GraphProto: ...  # type: ignore[overload-overlap]
@typing.overload
def to_proto(ir_object: _protocols.NodeProtocol) -> onnx.NodeProto: ...  # type: ignore[overload-overlap]
@typing.overload
def to_proto(ir_object: _protocols.TensorProtocol) -> onnx.TensorProto: ...  # type: ignore[overload-overlap]
@typing.overload
def to_proto(ir_object: _protocols.AttributeProtocol) -> onnx.AttributeProto: ...  # type: ignore[overload-overlap]
@typing.overload
def to_proto(ir_object: _protocols.ReferenceAttributeProtocol) -> onnx.AttributeProto: ...  # type: ignore[overload-overlap]
@typing.overload
def to_proto(ir_object: _protocols.ValueProtocol) -> onnx.ValueInfoProto: ...  # type: ignore[overload-overlap]
@typing.overload
def to_proto(ir_object: _protocols.TypeProtocol) -> onnx.TypeProto: ...  # type: ignore[overload-overlap]
@typing.overload
def to_proto(ir_object: _protocols.FunctionProtocol) -> onnx.FunctionProto: ...  # type: ignore[overload-overlap]
@typing.overload
def to_proto(ir_object: _protocols.GraphViewProtocol) -> onnx.GraphProto: ...  # type: ignore[overload-overlap]


def to_proto(ir_object: object) -> object:
    """Serialize an IR object to the matching ONNX proto.

    The ``serialize_*`` function is selected with ``isinstance`` checks against
    the IR protocols. The checks run in a fixed order and the first match wins,
    so the order below must be preserved.

    Args:
        ir_object: The IR object to serialize.

    Returns:
        The proto produced by the selected serializer.

    Raises:
        NotImplementedError: If no serializer is registered for the type.
    """
    if isinstance(ir_object, _protocols.ModelProtocol):
        return serialize_model(ir_object)
    if isinstance(ir_object, _protocols.GraphProtocol):
        return serialize_graph(ir_object)
    if isinstance(ir_object, _protocols.NodeProtocol):
        return serialize_node(ir_object)
    if isinstance(ir_object, _protocols.TensorProtocol):
        return serialize_tensor(ir_object)
    if isinstance(ir_object, _protocols.ValueProtocol):
        return serialize_value(ir_object)
    if isinstance(ir_object, _protocols.AttributeProtocol) and not ir_object.is_ref():
        return serialize_attribute(ir_object)
    if isinstance(ir_object, _protocols.ReferenceAttributeProtocol):
        assert ir_object.is_ref()
        return serialize_reference_attribute(ir_object)
    if isinstance(ir_object, _protocols.TypeProtocol):
        # Use the value-returning serializer like every other branch, so the
        # caller gets the TypeProto promised by the overload rather than the
        # result of an in-place "_into" helper.
        return serialize_type(ir_object)
    if isinstance(ir_object, _protocols.GraphViewProtocol):
        return serialize_graph(ir_object)
    if isinstance(ir_object, _protocols.FunctionProtocol):
        return serialize_function(ir_object)
    raise NotImplementedError(
        f"Serialization of {type(ir_object)} in to_proto is not implemented. "
        "Use a specific ir.serde.serialize* function instead."
    )
|
|
301
|
+
|
|
302
|
+
|
|
303
|
+
class TensorProtoTensor(_core.TensorBase):  # pylint: disable=too-many-ancestors
    """A tensor initialized from a tensor proto.

    The proto is kept by reference: name, shape, dtype and doc string are read
    from it on access, and the data is decoded only when :meth:`numpy` or
    :meth:`tobytes` is called.
    """

    __slots__ = ("_proto",)

    def __init__(self, proto: onnx.TensorProto) -> None:
        """Wrap ``proto`` and deserialize its metadata props for the base class."""
        super().__init__(metadata_props=deserialize_metadata_props(proto.metadata_props))
        self._proto = proto

    @property
    def name(self) -> str:
        """The name stored in the underlying proto."""
        return self._proto.name

    @name.setter
    def name(self, value: str | None) -> None:
        # None clears the proto field instead of storing a value.
        if value is None:
            self._proto.ClearField("name")
        else:
            self._proto.name = value

    @property
    def shape(self) -> _core.Shape:
        """A frozen shape built from the proto ``dims``."""
        return _core.Shape(self._proto.dims, frozen=True)

    @property
    def dtype(self) -> _enums.DataType:
        """The IR data type corresponding to the proto ``data_type``."""
        return _enums.DataType(self._proto.data_type)

    @property  # type: ignore[misc]
    def doc_string(self) -> str:
        """The doc string stored in the underlying proto."""
        return self._proto.doc_string

    @property
    def raw(self) -> onnx.TensorProto:
        """The underlying tensor proto."""
        return self._proto

    def __repr__(self) -> str:
        # Small tensors (at most 10 elements) also show their values on one line.
        if self.size <= 10:
            tensor_lines = repr(self.numpy()).split("\n")
            tensor_text = " ".join(line.strip() for line in tensor_lines)
            return f"{self._repr_base()}({tensor_text}, name={self.name!r})"
        return f"{self._repr_base()}(name={self.name!r})"

    def __array__(self, dtype: Any = None) -> np.ndarray:
        """Return the tensor as a numpy array, compatible with np.array."""
        return self.numpy().__array__(dtype)

    def __dlpack__(self, *, stream: Any = None) -> Any:
        """Export through DLPack by delegating to the numpy array."""
        return self.numpy().__dlpack__(stream=stream)

    def __dlpack_device__(self) -> tuple[int, int]:
        """Return the DLPack device of the numpy array."""
        return self.numpy().__dlpack_device__()

    def numpy(self) -> np.ndarray:
        """Return the tensor as a numpy array.

        This is an improved version of onnx.numpy_helper.to_array.
        It first reads the data using the dtype corresponding to the tensor
        proto data field, then converts it to the correct dtype and shape.
        Special cases are bfloat16, complex and int4 where we need to
        reinterpret the data. Other types can simply be casted.

        When the data type is not supported by numpy, the dtypes from the ``ml_dtypes``
        package are used. The values can be reinterpreted as bit representations
        using the ``.view()`` method.

        When the data type is a string, this method returns a numpy array
        of bytes instead of a numpy array of strings, to follow the ONNX
        specification.

        External tensors are not supported by this class. Use
        :class:`onnx_ir.ExternalTensor` instead.

        Raises:
            ValueError: If the data type is UNDEFINED.
            ValueError: If the tensor data is stored externally.
            ValueError: If ``int32_data`` holds a dtype with an unsupported bit width.
        """
        dtype = self.dtype
        if dtype == _enums.DataType.UNDEFINED:
            raise ValueError("Cannot convert UNDEFINED tensor to numpy array.")
        if self._proto.data_location == onnx.TensorProto.EXTERNAL:
            raise ValueError(
                "Cannot convert external tensor to numpy array. Use ir.ExternalTensor instead."
            )

        shape = self._proto.dims

        if self._proto.HasField("raw_data"):
            # Sub-byte types are packed several values per byte and need unpacking.
            if dtype.bitwidth == 4:
                return _type_casting.unpack_4bitx2(
                    np.frombuffer(self._proto.raw_data, dtype=np.uint8), shape
                ).view(dtype.numpy())
            if dtype.bitwidth == 2:
                return _type_casting.unpack_2bitx4(
                    np.frombuffer(self._proto.raw_data, dtype=np.uint8), shape
                ).view(dtype.numpy())
            # raw_data is always little endian, whatever the host byte order is.
            return np.frombuffer(
                self._proto.raw_data, dtype=dtype.numpy().newbyteorder("<")
            ).reshape(shape)
        if dtype == _enums.DataType.STRING:
            return np.array(self._proto.string_data).reshape(shape)
        if self._proto.int32_data:
            assert dtype in {
                _enums.DataType.BFLOAT16,
                _enums.DataType.BOOL,
                _enums.DataType.FLOAT16,
                _enums.DataType.FLOAT4E2M1,
                _enums.DataType.FLOAT8E4M3FN,
                _enums.DataType.FLOAT8E4M3FNUZ,
                _enums.DataType.FLOAT8E5M2,
                _enums.DataType.FLOAT8E5M2FNUZ,
                _enums.DataType.FLOAT8E8M0,
                _enums.DataType.INT16,
                _enums.DataType.INT32,
                _enums.DataType.INT2,
                _enums.DataType.INT4,
                _enums.DataType.INT8,
                _enums.DataType.UINT16,
                _enums.DataType.UINT2,
                _enums.DataType.UINT4,
                _enums.DataType.UINT8,
            }, f"Unsupported dtype {dtype} for int32_data"
            array = np.array(self._proto.int32_data, dtype=_little_endian_dtype(np.int32))
            if dtype.bitwidth == 32:
                return array.reshape(shape)
            if dtype.bitwidth == 16:
                # Reinterpret the int32 as float16 or bfloat16
                return array.astype(np.uint16).view(dtype.numpy()).reshape(shape)
            if dtype.bitwidth == 8:
                return array.astype(np.uint8).view(dtype.numpy()).reshape(shape)
            if dtype.bitwidth == 4:
                return _type_casting.unpack_4bitx2(array.astype(np.uint8), shape).view(
                    dtype.numpy()
                )
            if dtype.bitwidth == 2:
                return _type_casting.unpack_2bitx4(array.astype(np.uint8), shape).view(
                    dtype.numpy()
                )
            raise ValueError(
                f"Unsupported dtype {dtype} for int32_data with bitwidth {dtype.bitwidth}"
            )
        if self._proto.int64_data:
            assert dtype in {
                _enums.DataType.INT64,
            }, f"Unsupported dtype {dtype} for int64_data"
            return np.array(
                self._proto.int64_data, dtype=_little_endian_dtype(np.int64)
            ).reshape(shape)
        if self._proto.uint64_data:
            assert dtype in {
                _enums.DataType.UINT64,
                _enums.DataType.UINT32,
            }, f"Unsupported dtype {dtype} for uint64_data"
            array = np.array(self._proto.uint64_data, dtype=_little_endian_dtype(np.uint64))
            if dtype == _enums.DataType.UINT32:
                return array.astype(np.uint32).reshape(shape)
            return array.reshape(shape)
        if self._proto.float_data:
            assert dtype in {
                _enums.DataType.FLOAT,
                _enums.DataType.COMPLEX64,
            }, f"Unsupported dtype {dtype} for float_data"
            array = np.array(self._proto.float_data, dtype=_little_endian_dtype(np.float32))
            if dtype == _enums.DataType.COMPLEX64:
                # The buffer holds little-endian floats, so the complex view must be
                # little endian too; a native view would be wrong on big-endian hosts.
                return array.view(_little_endian_dtype(np.complex64)).reshape(shape)
            return array.reshape(shape)
        if self._proto.double_data:
            assert dtype in {
                _enums.DataType.DOUBLE,
                _enums.DataType.COMPLEX128,
            }, f"Unsupported dtype {dtype} for double_data"
            array = np.array(self._proto.double_data, dtype=_little_endian_dtype(np.float64))
            if dtype == _enums.DataType.COMPLEX128:
                # Same byte-order reasoning as the COMPLEX64 case above.
                return array.view(_little_endian_dtype(np.complex128)).reshape(shape)
            return array.reshape(shape)

        # Empty tensor. We return a size 0 array with the correct shape
        return np.zeros(shape, dtype=dtype.numpy())

    def tobytes(self) -> bytes:
        """Return the tensor as a byte string conformed to the ONNX specification, in little endian.

        Raises:
            ValueError: If the tensor is a string tensor or an external tensor.
            ValueError: If the tensor is of UNDEFINED data type.
        """
        if self._proto.data_location == onnx.TensorProto.EXTERNAL:
            raise ValueError(
                "Cannot convert external tensor to bytes. Use ir.ExternalTensor instead."
            )
        if self.dtype == _enums.DataType.STRING:
            raise ValueError("Cannot convert string tensor to bytes.")
        if self.dtype == _enums.DataType.UNDEFINED:
            raise ValueError("Cannot convert UNDEFINED tensor to bytes.")

        if self._proto.HasField("raw_data"):
            return self._proto.raw_data
        if self._proto.float_data:
            return np.array(
                self._proto.float_data, dtype=_little_endian_dtype(np.float32)
            ).tobytes()
        if self._proto.int32_data:
            # Build the array as little endian, like every other branch, so the
            # INT32 fall-through below emits little-endian bytes on any host.
            array = np.array(self._proto.int32_data, dtype=_little_endian_dtype(np.int32))
            if self.dtype in {
                _enums.DataType.INT16,
                _enums.DataType.UINT16,
                _enums.DataType.FLOAT16,
                _enums.DataType.BFLOAT16,
            }:
                return array.astype(_little_endian_dtype(np.uint16)).tobytes()
            if self.dtype in {
                _enums.DataType.INT8,
                _enums.DataType.UINT8,
                _enums.DataType.BOOL,
                _enums.DataType.FLOAT8E4M3FN,
                _enums.DataType.FLOAT8E4M3FNUZ,
                _enums.DataType.FLOAT8E5M2,
                _enums.DataType.FLOAT8E5M2FNUZ,
                _enums.DataType.FLOAT8E8M0,
                _enums.DataType.INT2,
                _enums.DataType.INT4,
                _enums.DataType.UINT2,
                _enums.DataType.UINT4,
                _enums.DataType.FLOAT4E2M1,
            }:
                # uint2, uint4, int2 and int4 values are already packed, even when stored as int32
                # so we don't need to pack them again
                return array.astype(_little_endian_dtype(np.uint8)).tobytes()
            assert self.dtype == _enums.DataType.INT32
            return array.tobytes()
        if self._proto.int64_data:
            return np.array(
                self._proto.int64_data, dtype=_little_endian_dtype(np.int64)
            ).tobytes()
        if self._proto.double_data:
            return np.array(
                self._proto.double_data, dtype=_little_endian_dtype(np.float64)
            ).tobytes()
        if self._proto.uint64_data:
            array = np.array(self._proto.uint64_data, dtype=_little_endian_dtype(np.uint64))
            if self.dtype == _enums.DataType.UINT32:
                return array.astype(_little_endian_dtype(np.uint32)).tobytes()
            assert self.dtype == _enums.DataType.UINT64
            return array.tobytes()
        # The repeating fields can be empty and still valid.
        # For example, int32_data can be empty and still be a valid tensor.
        return b""
|
|
549
|
+
|
|
550
|
+
|
|
551
|
+
def _get_field(proto: Any, field: str) -> Any:
    """Return the value of ``field`` on ``proto``, or None when the field is not set."""
    return getattr(proto, field) if proto.HasField(field) else None
|
|
555
|
+
|
|
556
|
+
|
|
557
|
+
# Deserialization
|
|
558
|
+
|
|
559
|
+
|
|
560
|
+
def deserialize_opset_import(
    protos: Sequence[onnx.OperatorSetIdProto],
) -> dict[str, int]:
    """Build the opset imports mapping from a sequence of OperatorSetIdProto.

    Args:
        protos: The sequence of ONNX OperatorSetIdProto objects.

    Returns:
        A dictionary mapping domain strings to version integers. When a domain
        appears more than once, the last entry wins.
    """
    opset_versions: dict[str, int] = {}
    for opset_proto in protos:
        opset_versions[opset_proto.domain] = opset_proto.version
    return opset_versions
|
|
572
|
+
|
|
573
|
+
|
|
574
|
+
def _parse_experimental_function_value_info_name(
    name: str,
) -> tuple[str, str, str] | None:
    """Split a function value info name into its domain, function and value parts.

    The experimental format is:
        {function_domain}::{function_name}/{value_name}

    Args:
        name: The name stored in the value info.

    Returns:
        A ``(function_domain, function_name, value_name)`` tuple when ``name``
        has exactly one ``/`` separator and the part before it has exactly one
        ``::`` separator. None otherwise.
    """
    try:
        # Tuple unpacking raises ValueError unless each split yields exactly two parts.
        qualified_function, value_name = name.split("/")
        # NOTE: There will not be overload because overloads are introduced in ONNX IR v10,
        # which also introduces the ValueInfoProto for functions
        function_domain, function_name = qualified_function.split("::")
    except ValueError:
        return None
    return function_domain, function_name, value_name
|
|
601
|
+
|
|
602
|
+
|
|
603
|
+
def deserialize_model(proto: onnx.ModelProto) -> _core.Model:
    """Build an IR Model from an ONNX ModelProto.

    The main graph is deserialized first and receives the model's opset
    imports, then every function proto is deserialized. For IR versions older
    than the one that introduced function value info, value info stored in the
    experimental naming format on the main graph is applied to the functions.

    Args:
        proto: The ONNX ModelProto to deserialize.

    Returns:
        An IR Model object representing the ONNX model.
    """
    main_graph = _deserialize_graph(proto.graph, [])
    main_graph.opset_imports.update(deserialize_opset_import(proto.opset_import))

    model_functions = [deserialize_function(function_proto) for function_proto in proto.functions]

    # Optional scalar fields are passed as None when they are not set on the proto.
    optional_fields = {
        field_name: _get_field(proto, field_name)
        for field_name in (
            "producer_name",
            "producer_version",
            "domain",
            "model_version",
            "doc_string",
        )
    }
    model = _core.Model(
        main_graph,
        ir_version=proto.ir_version,
        functions=model_functions,
        metadata_props=deserialize_metadata_props(proto.metadata_props),
        **optional_fields,
    )

    # Handle experimental value info for functions created by the dynamo exporter in IR version 9
    if model.ir_version < _FUNCTION_VALUE_INFO_SUPPORTED_VERSION:
        _deserialized_experimental_value_info_for_function_ir9(
            model.functions, proto.graph.value_info
        )

    return model
|
|
638
|
+
|
|
639
|
+
|
|
640
|
+
def _deserialized_experimental_value_info_for_function_ir9(
    functions: Mapping[_protocols.OperatorIdentifier, _core.Function],
    value_info_protos: Sequence[onnx.ValueInfoProto],
) -> None:
    """Apply value info stored in the experimental format to function values.

    The experimental format is:
        {function_domain}::{function_name}/{value_name}

    Value infos whose name does not follow the format are skipped. Those that
    name a function missing from ``functions`` are skipped with a debug log.

    Args:
        functions: The model functions, keyed by operator identifier.
        value_info_protos: The value info protos taken from the main graph.
    """
    # Step 1: group the matching value infos by function, then by value name.
    value_infos_by_function: collections.defaultdict[
        _protocols.OperatorIdentifier,
        dict[str, onnx.ValueInfoProto],
    ] = collections.defaultdict(dict)
    for info_proto in value_info_protos:
        parsed_name = _parse_experimental_function_value_info_name(info_proto.name)
        if parsed_name is None:
            continue
        function_domain, function_name, value_name = parsed_name
        # The experimental format carries no overload, so it is always empty.
        # TODO(justinchuby): Create a constructor for OperatorIdentifier so we don't create tuples manually
        function_id = (function_domain, function_name, "")
        if functions.get(function_id) is None:
            # Function not found
            logger.debug(
                "Function with ID '%s' not found in model functions. Value info '%s' will be ignored.",
                function_id,
                info_proto.name,
            )
            continue
        value_infos_by_function[function_id][value_name] = info_proto

    # Step 2: update the inputs and node outputs of each function in place.
    for function_id, function in functions.items():
        known_infos = value_infos_by_function[function_id]
        for function_input in function.inputs:
            if function_input.name in known_infos:
                deserialize_value_info_proto(known_infos[function_input.name], function_input)
        for node in function:
            for node_output in node.outputs:
                if node_output.name in known_infos:
                    deserialize_value_info_proto(known_infos[node_output.name], node_output)
            # The function outputs are handled as well because they are also node outputs
|
|
687
|
+
|
|
688
|
+
|
|
689
|
+
def deserialize_graph(proto: onnx.GraphProto) -> _core.Graph:
    """Convert an ONNX GraphProto (and any nested subgraphs) into an IR Graph.

    Args:
        proto: The graph proto to deserialize.

    Returns:
        IR Graph.

    .. versionadded:: 0.1.3
        Support for `quantization_annotation` is added.
    """
    # A top-level graph has no enclosing scopes, so start from an empty scope stack.
    enclosing_scopes: list[dict[str, _core.Value]] = []
    return _deserialize_graph(proto, enclosing_scopes)
|
|
702
|
+
|
|
703
|
+
|
|
704
|
+
@_capture_errors(lambda proto, scoped_values: proto.name)
def _deserialize_graph(
    proto: onnx.GraphProto, scoped_values: list[dict[str, _core.Value]]
) -> _core.Graph:
    """Deserialize a graph proto, recursively if needed.

    Args:
        proto: The graph proto to deserialize.
        scoped_values: A list of dictionaries mapping value names to their corresponding
            Value objects. Every time we enter a new graph, a new scope is created and
            appended to this list to include all values defined in the scope. The scope
            is removed again before this function returns or raises.

    Returns:
        IR Graph.
    """
    # Process TensorAnnotation for quantization
    quantization_annotations = {
        annotation.tensor_name: annotation for annotation in proto.quantization_annotation
    }

    # Create values for inputs
    inputs = [_core.Value(name=info.name) for info in proto.input]
    for info, value in zip(proto.input, inputs):
        deserialize_value_info_proto(info, value)

        # Add TensorAnnotation for inputs if they exist
        if value.name in quantization_annotations:
            _deserialize_quantization_annotation(quantization_annotations[value.name], value)

    # Initialize the values dictionary for this graph scope with the inputs.
    # Initializers and node outputs are added to it below.
    values: dict[str, _core.Value] = {v.name: v for v in inputs}  # type: ignore[misc]

    # Enter the graph scope by pushing the values for this scope to the stack
    scoped_values.append(values)
    try:
        # Build the value info dictionary to allow for quick lookup for this graph scope
        value_info = {info.name: info for info in proto.value_info}

        # Create values for initializers
        initializer_tensors = [deserialize_tensor(tensor) for tensor in proto.initializer]
        initializer_values = []
        for i, tensor in enumerate(initializer_tensors):
            initializer_name = tensor.name
            if not initializer_name:
                logger.warning(
                    "Initializer tensor must have a name but the %s-th initializer does not. Skipping this initializer.",
                    i,
                )
                continue
            if initializer_name in values:
                # The initializer is for an input
                initializer_value = values[initializer_name]
                initializer_value.const_value = tensor
            else:
                # The initializer is for some other value. Create this value first
                initializer_value = _core.Value(
                    None,
                    index=None,
                    name=initializer_name,
                    # Include shape and type even if the shape or type is not provided as ValueInfoProto.
                    # Users expect initialized values to have shape and type information.
                    type=_core.TensorType(tensor.dtype),
                    shape=tensor.shape,  # type: ignore[arg-type]
                    const_value=tensor,
                )
                if initializer_name in value_info:
                    deserialize_value_info_proto(value_info[initializer_name], initializer_value)
                if initializer_value.name in quantization_annotations:
                    _deserialize_quantization_annotation(
                        quantization_annotations[initializer_value.name], initializer_value
                    )
                values[initializer_name] = initializer_value
            initializer_values.append(initializer_value)

        # Declare values for all node outputs from this graph scope. This is necessary
        # to handle the case where a node in a subgraph uses a value that is declared out
        # of order in the outer graph. Declaring the values first allows us to find the
        # values later when deserializing the nodes in subgraphs.
        for node in proto.node:
            _declare_node_outputs(
                node,
                values,
                value_info=value_info,
                quantization_annotations=quantization_annotations,
            )

        # Deserialize nodes with all known values
        nodes = [
            _deserialize_node(node, scoped_values, value_info, quantization_annotations)
            for node in proto.node
        ]

        outputs = []
        for info in proto.output:
            # Fill in values for graph outputs
            output_name = info.name
            if output_name not in values:
                # Handle (invalid) graph outputs that do not have any producers
                logger.warning(
                    "Output '%s' is not produced by any node. The graph has an invalid output",
                    output_name,
                )
                value = _core.Value(name=output_name)
            else:
                # A valid, normal graph output
                value = values[output_name]
            # Fill in shape/type information
            deserialize_value_info_proto(info, value)
            outputs.append(value)
    finally:
        # Exit the graph scope by popping the values for this scope from the stack.
        # Doing this in ``finally`` keeps the caller's scope stack balanced even when
        # deserialization of this graph fails part-way through.
        scoped_values.pop()

    return _core.Graph(
        inputs,
        outputs,
        nodes=nodes,
        initializers=initializer_values,
        doc_string=_get_field(proto, "doc_string"),
        name=_get_field(proto, "name"),
        metadata_props=deserialize_metadata_props(proto.metadata_props),
    )
|
|
827
|
+
|
|
828
|
+
|
|
829
|
+
def _declare_node_outputs(
    proto: onnx.NodeProto,
    current_value_scope: dict[str, _core.Value],
    value_info: dict[str, onnx.ValueInfoProto],
    quantization_annotations: dict[str, onnx.TensorAnnotation],
) -> None:
    """Register a Value for every named output of a node in the given scope.

    Nodes in a subgraph may reference values from the outer graph that are produced
    by nodes appearing later in the node list. Creating all output values up front
    makes them discoverable when the nodes (and their subgraphs) are deserialized.

    Args:
        proto: The ONNX NodeProto whose outputs are declared.
        current_value_scope: Name-to-Value mapping of the current graph scope. It is
            updated in place with the newly created values.
        value_info: Name-to-ValueInfoProto mapping used to fill in shape/type information.
        quantization_annotations: Name-to-TensorAnnotation mapping applied to matching outputs.

    Raises:
        ValueError: If an output name is redeclared in the current graph scope.
    """
    for output_name in proto.output:
        # Unnamed outputs are not registered in the scope.
        if output_name == "":
            continue

        if output_name in current_value_scope:
            previous = current_value_scope[output_name]
            raise ValueError(
                f"Output '{output_name}' is redeclared in the current graph scope. "
                f"Original declaration {previous}. "
                f"New declaration: by operator '{proto.op_type}' of node '{proto.name}'. "
                "The model is invalid"
            )

        # Create the value and make it visible in the scope before populating it.
        declared = _core.Value(name=output_name)
        current_value_scope[output_name] = declared

        # Shape/type information is optional.
        if output_name not in value_info:
            logger.debug(
                "ValueInfoProto not found for output '%s' in node '%s' of type '%s'",
                output_name,
                proto.name,
                proto.op_type,
            )
        else:
            deserialize_value_info_proto(value_info[output_name], declared)

        if output_name in quantization_annotations:
            annotation = quantization_annotations[output_name]
            _deserialize_quantization_annotation(annotation, declared)
|
|
876
|
+
|
|
877
|
+
|
|
878
|
+
@_capture_errors(lambda proto: proto.name)
def deserialize_function(proto: onnx.FunctionProto) -> _core.Function:
    """Deserialize an ONNX FunctionProto into an IR Function.

    The function body is stored in a Graph named ``"{name}_{domain}"``, with
    ``"__{overload}"`` appended when the proto has a non-empty overload.

    Args:
        proto: The ONNX FunctionProto to deserialize.

    Returns:
        An IR Function object representing the ONNX function.

    Raises:
        KeyError: If a name in ``proto.output`` is neither a function input nor
            produced by a node in the function.
    """
    inputs = [_core.Value(name=name) for name in proto.input]
    values: dict[str, _core.Value] = {v.name: v for v in inputs}  # type: ignore[misc]
    # ``value_info`` is read with getattr and a default because the field may be
    # absent on the proto class.
    value_info = {info.name: info for info in getattr(proto, "value_info", [])}

    # Declare all node outputs first so that nodes can reference values that are
    # produced later in the node list.
    for node in proto.node:
        _declare_node_outputs(
            node,
            values,
            value_info=value_info,
            quantization_annotations={},
        )

    nodes = [
        _deserialize_node(node, [values], value_info=value_info, quantization_annotations={})
        for node in proto.node
    ]
    outputs = [values[name] for name in proto.output]

    overload = getattr(proto, "overload", "")
    # Build the graph name explicitly. The previous single expression
    # ``a + b if cond else ""`` applied the condition to the whole concatenation,
    # which left the graph name empty whenever the function had no overload.
    graph_name = f"{proto.name}_{proto.domain}"
    if overload:
        graph_name += f"__{overload}"

    graph = _core.Graph(
        inputs,
        outputs,
        nodes=nodes,
        initializers=(),
        doc_string=_get_field(proto, "doc_string"),
        opset_imports=deserialize_opset_import(proto.opset_import),
        name=graph_name,
        metadata_props=deserialize_metadata_props(proto.metadata_props),
    )
    attributes = [_deserialize_attribute(attr, []) for attr in proto.attribute_proto]
    # Attributes without defaults
    attributes += [
        _core.Attr(name, _enums.AttributeType.UNDEFINED, None) for name in proto.attribute
    ]
    return _core.Function(
        domain=proto.domain,
        name=proto.name,
        overload=overload,
        graph=graph,
        attributes=attributes,
    )
|
|
931
|
+
|
|
932
|
+
|
|
933
|
+
@_capture_errors(lambda proto, value: str(proto))
def deserialize_value_info_proto(
    proto: onnx.ValueInfoProto, value: _core.Value | None
) -> _core.Value:
    """Populate an IR Value from an ONNX ValueInfoProto.

    The value's shape, type and doc string are overwritten with what the proto
    provides; metadata props from the proto are merged into the existing ones.

    Args:
        proto: The ONNX ValueInfoProto to read from.
        value: The Value to update. When None, a new Value named after the proto is created.

    Returns:
        The updated (or newly created) Value.
    """
    target = _core.Value(name=proto.name) if value is None else value

    type_proto = proto.type
    target.shape = deserialize_type_proto_for_shape(type_proto)
    target.type = deserialize_type_proto_for_type(type_proto)

    # ``deserialize_metadata_props`` returns None when there is nothing to merge.
    props = deserialize_metadata_props(proto.metadata_props)
    if props is not None:
        target.metadata_props.update(props)

    target.doc_string = _get_field(proto, "doc_string")
    return target
|
|
955
|
+
|
|
956
|
+
|
|
957
|
+
@_capture_errors(lambda proto, value: str(proto))
def _deserialize_quantization_annotation(
    proto: onnx.TensorAnnotation, value: _core.Value
) -> None:
    """Store the quantization parameter tensor names of a TensorAnnotation on a Value.

    The converted ``quant_parameter_tensor_names`` entries are written to
    ``value.meta`` under ``_QUANT_PARAMETER_TENSOR_NAMES_FIELD``.

    This function is private because users are not expected to call it directly.
    """
    parameter_tensor_names = _deserialize_string_string_maps(
        proto.quant_parameter_tensor_names
    )
    value.meta[_QUANT_PARAMETER_TENSOR_NAMES_FIELD] = parameter_tensor_names
|
|
968
|
+
|
|
969
|
+
|
|
970
|
+
@_capture_errors(str)
def deserialize_tensor_shape(proto: onnx.TensorShapeProto) -> _core.Shape:
    """Convert an ONNX TensorShapeProto into a frozen IR Shape.

    Args:
        proto: The ONNX TensorShapeProto to deserialize.

    Returns:
        An IR Shape object holding the dimensions and their denotations.
    """
    dims = []
    denotations = []
    # A proto without dimensions simply yields two empty lists, i.e. the shape [].
    for dim_proto in proto.dim:
        dim, denotation = deserialize_dimension(dim_proto)
        dims.append(dim)
        denotations.append(denotation)
    return _core.Shape(dims, denotations=denotations, frozen=True)
|
|
988
|
+
|
|
989
|
+
|
|
990
|
+
@_capture_errors(str)
def deserialize_type_proto_for_shape(proto: onnx.TypeProto) -> _core.Shape | None:
    """Read the shape carried by an ONNX TypeProto, if any.

    Tensor and sparse tensor types provide the shape directly; sequence and
    optional types are unwrapped recursively through their element type.

    Args:
        proto: The ONNX TypeProto to extract shape from.

    Returns:
        An IR Shape object if shape information is present, None otherwise.

    Raises:
        NotImplementedError: If the proto holds a map type.
    """
    # Types that store a shape of their own.
    for field_name in ("tensor_type", "sparse_tensor_type"):
        if proto.HasField(field_name):
            shape_proto = _get_field(getattr(proto, field_name), "shape")
            if shape_proto is None:
                return None
            return deserialize_tensor_shape(shape_proto)

    # Wrapper types: the shape is the one of the wrapped element type.
    for field_name in ("sequence_type", "optional_type"):
        if proto.HasField(field_name):
            elem_type = _get_field(getattr(proto, field_name), "elem_type")
            if elem_type is None:
                return None
            return deserialize_type_proto_for_shape(elem_type)

    if proto.HasField("map_type"):
        # TODO(justinchuby): Do we need to support map types?
        raise NotImplementedError(f"Map types are not supported yet. {_PLEASE_CONTRIBUTE}")

    return None
|
|
1021
|
+
|
|
1022
|
+
|
|
1023
|
+
@_capture_errors(str)
def deserialize_type_proto_for_type(
    proto: onnx.TypeProto,
) -> _protocols.TypeProtocol | None:
    """Extract and deserialize type information from an ONNX TypeProto.

    Args:
        proto: The ONNX TypeProto to extract type from.

    Returns:
        An IR type object (TensorType, SparseTensorType, SequenceType or OptionalType)
        if type information is present, None otherwise.

    Raises:
        ValueError: If a sequence or optional type does not specify its element type.
        NotImplementedError: If the proto holds a map type.
    """
    denotation = _get_field(proto, "denotation")
    if proto.HasField("tensor_type"):
        if (elem_type := _get_field(proto.tensor_type, "elem_type")) is None:
            return None
        return _core.TensorType(_enums.DataType(elem_type), denotation=denotation)
    if proto.HasField("sparse_tensor_type"):
        if (elem_type := _get_field(proto.sparse_tensor_type, "elem_type")) is None:
            return None
        return _core.SparseTensorType(_enums.DataType(elem_type), denotation=denotation)
    if proto.HasField("sequence_type"):
        # FIXME(justinchuby): Allow nested types being None
        if (elem_type := _get_field(proto.sequence_type, "elem_type")) is None:
            raise ValueError(f"SequenceTypeProto must have elem_type set: {proto}")
        nested_type = deserialize_type_proto_for_type(elem_type)
        if nested_type is None:
            raise ValueError(f"SequenceType must have elem_type set: {proto}")
        return _core.SequenceType(nested_type, denotation=denotation)
    if proto.HasField("optional_type"):
        # FIXME(justinchuby): Allow nested types being None
        # These messages name the optional type; they previously said "Sequence..."
        # because they were copied from the branch above.
        if (elem_type := _get_field(proto.optional_type, "elem_type")) is None:
            raise ValueError(f"OptionalTypeProto must have elem_type set: {proto}")
        nested_type = deserialize_type_proto_for_type(elem_type)
        if nested_type is None:
            raise ValueError(f"OptionalType must have elem_type set: {proto}")
        return _core.OptionalType(nested_type, denotation=denotation)
    if proto.HasField("map_type"):
        # TODO(justinchuby): Do we need to support map types?
        raise NotImplementedError(f"Map types are not supported yet. {_PLEASE_CONTRIBUTE}")

    return None
|
|
1065
|
+
|
|
1066
|
+
|
|
1067
|
+
@_capture_errors(str)
def deserialize_dimension(
    proto: onnx.TensorShapeProto.Dimension,
) -> tuple[int | _core.SymbolicDim, str | None]:
    """Convert a dimension proto into a ``(dimension, denotation)`` pair.

    Args:
        proto: The dimension proto to deserialize.

    Returns:
        A tuple of the dimension and its denotation. The dimension is the raw
        ``dim_value`` when that field is set, ``SymbolicDim(dim_param)`` when
        ``dim_param`` is set, and ``SymbolicDim(None)`` otherwise.
    """
    populated_field = proto.WhichOneof("value")
    denotation = _get_field(proto, "denotation")

    if populated_field == "dim_value":
        return proto.dim_value, denotation
    if populated_field == "dim_param":
        return _core.SymbolicDim(proto.dim_param), denotation
    # Neither member of the "value" oneof is set: the dimension is unknown.
    return _core.SymbolicDim(None), denotation
|
|
1088
|
+
|
|
1089
|
+
|
|
1090
|
+
@_capture_errors(lambda proto, base_path: proto.name)
def deserialize_tensor(
    proto: onnx.TensorProto, base_path: str | os.PathLike = ""
) -> _protocols.TensorProtocol:
    """Wrap an ONNX TensorProto in the matching IR tensor class.

    Args:
        proto: The ONNX TensorProto to deserialize.
        base_path: Passed as ``base_dir`` to the ExternalTensor that is created
            when the tensor data is stored externally.

    Returns:
        An ExternalTensor when ``proto.data_location`` is EXTERNAL, a StringTensor
        when the data type is STRING, and a TensorProtoTensor otherwise.
    """
    # TODO: Sanitize base_path
    stored_externally = proto.data_location == onnx.TensorProto.EXTERNAL
    if stored_externally:
        info = onnx.external_data_helper.ExternalDataInfo(proto)
        return _core.ExternalTensor(
            info.location,
            offset=info.offset,
            length=info.length,
            dtype=_enums.DataType(proto.data_type),
            base_dir=base_path,
            name=_get_field(proto, "name"),
            shape=_core.Shape(proto.dims),
            doc_string=_get_field(proto, "doc_string"),
            metadata_props=deserialize_metadata_props(proto.metadata_props),
        )

    if proto.data_type != _enums.DataType.STRING:
        # All remaining tensors are wrapped without further conversion here.
        return TensorProtoTensor(proto)

    tensor_name = _get_field(proto, "name")
    tensor_doc = _get_field(proto, "doc_string")
    tensor_props = deserialize_metadata_props(proto.metadata_props)
    return _core.StringTensor(
        proto.string_data,
        shape=_core.Shape(proto.dims),
        name=tensor_name,
        doc_string=tensor_doc,
        metadata_props=tensor_props,
    )
|
|
1120
|
+
|
|
1121
|
+
|
|
1122
|
+
def deserialize_metadata_props(
    proto: Sequence[onnx.StringStringEntryProto],
) -> dict[str, str] | None:
    """Convert a sequence of key/value entries into a dictionary.

    Args:
        proto: Entries exposing ``key`` and ``value`` attributes.

    Returns:
        A dictionary mapping each entry's key to its value (a later entry wins
        when a key repeats), or None when ``proto`` is empty.
    """
    if len(proto) > 0:
        props: dict[str, str] = {}
        for entry in proto:
            props[entry.key] = entry.value
        return props
    # Avoid creating an empty dictionary to save memory
    return None
|
|
1129
|
+
|
|
1130
|
+
|
|
1131
|
+
# Alias of ``deserialize_metadata_props`` for key/value entry lists that are not
# metadata props; this file uses it to convert ``quant_parameter_tensor_names``.
_deserialize_string_string_maps = deserialize_metadata_props
|
|
1132
|
+
|
|
1133
|
+
|
|
1134
|
+
def deserialize_attribute(proto: onnx.AttributeProto) -> _core.Attr:
    """Convert an ONNX AttributeProto into an IR Attribute.

    Args:
        proto: The ONNX AttributeProto to deserialize.

    Returns:
        An IR Attribute object representing the ONNX attribute.
    """
    # A stand-alone attribute has no enclosing graph scopes to resolve values from.
    no_scopes: list[dict[str, _core.Value]] = []
    return _deserialize_attribute(proto, no_scopes)
|
|
1144
|
+
|
|
1145
|
+
|
|
1146
|
+
@_capture_errors(lambda proto, scoped_values: str(proto))
def _deserialize_attribute(
    proto: onnx.AttributeProto, scoped_values: list[dict[str, _core.Value]]
) -> _core.Attr:
    """Convert an ONNX AttributeProto into the matching IR attribute object.

    Args:
        proto: The ONNX AttributeProto to deserialize.
        scoped_values: The stack of value scopes handed to ``_deserialize_graph``
            when the attribute holds one or more subgraphs.

    Returns:
        The IR attribute. A reference attribute (non-empty ``ref_attr_name``)
        is returned as a RefAttr regardless of its type.

    Raises:
        NotImplementedError: For sparse tensor attributes.
        ValueError: For attribute types not handled below.
    """
    attr_name = proto.name
    doc = _get_field(proto, "doc_string")
    attr_type = _enums.AttributeType(proto.type)
    referenced_name = _get_field(proto, "ref_attr_name")
    if referenced_name:
        return _core.RefAttr(attr_name, referenced_name, attr_type, doc_string=doc)

    kinds = _enums.AttributeType

    # Scalar attributes
    if attr_type == kinds.INT:
        return _core.AttrInt64(attr_name, proto.i, doc_string=doc)
    if attr_type == kinds.FLOAT:
        return _core.AttrFloat32(attr_name, proto.f, doc_string=doc)
    if attr_type == kinds.STRING:
        try:
            return _core.AttrString(attr_name, proto.s.decode("utf-8"), doc_string=doc)
        except UnicodeDecodeError:
            # Even though onnx.ai/onnx/repo-docs/IR.html#attributes requires the attribute
            # for strings to be utf-8 encoded bytes, custom ops may still store arbitrary data there
            logger.warning(
                "Attribute %r contains invalid UTF-8 bytes. ONNX spec requires string attributes "
                "to be UTF-8 encoded so the model is invalid. We will skip decoding the attribute and "
                "use the bytes as attribute value",
                attr_name,
            )
            return _core.Attr(attr_name, attr_type, proto.s, doc_string=doc)

    # List attributes of scalars
    if attr_type == kinds.INTS:
        return _core.AttrInt64s(attr_name, proto.ints, doc_string=doc)
    if attr_type == kinds.FLOATS:
        return _core.AttrFloat32s(attr_name, proto.floats, doc_string=doc)
    if attr_type == kinds.STRINGS:
        decoded_strings = [raw.decode("utf-8") for raw in proto.strings]
        return _core.AttrStrings(attr_name, decoded_strings, doc_string=doc)

    # Tensors and graphs
    if attr_type == kinds.TENSOR:
        return _core.AttrTensor(attr_name, deserialize_tensor(proto.t), doc_string=doc)
    if attr_type == kinds.GRAPH:
        subgraph = _deserialize_graph(proto.g, scoped_values)
        return _core.AttrGraph(attr_name, subgraph, doc_string=doc)
    if attr_type == kinds.TENSORS:
        tensors = [deserialize_tensor(tensor_proto) for tensor_proto in proto.tensors]
        return _core.AttrTensors(attr_name, tensors, doc_string=doc)
    if attr_type == kinds.GRAPHS:
        subgraphs = [_deserialize_graph(graph_proto, scoped_values) for graph_proto in proto.graphs]
        return _core.AttrGraphs(attr_name, subgraphs, doc_string=doc)

    if attr_type in (kinds.SPARSE_TENSOR, kinds.SPARSE_TENSORS):
        raise NotImplementedError(
            f"Sparse tensors are not supported yet. {_PLEASE_CONTRIBUTE}"
        )

    # Type protos: each is converted into a (type, shape) pair.
    if attr_type == kinds.TYPE_PROTO:
        single = _core.TypeAndShape(
            deserialize_type_proto_for_type(proto.tp),
            deserialize_type_proto_for_shape(proto.tp),
        )
        return _core.AttrTypeProto(attr_name, single, doc_string=doc)
    if attr_type == kinds.TYPE_PROTOS:
        pairs = [
            _core.TypeAndShape(
                deserialize_type_proto_for_type(type_proto),
                deserialize_type_proto_for_shape(type_proto),
            )
            for type_proto in proto.type_protos
        ]
        return _core.AttrTypeProtos(attr_name, pairs, doc_string=doc)

    if attr_type == kinds.UNDEFINED:
        return _core.Attr(attr_name, attr_type, None, doc_string=doc)
    raise ValueError(f"Unsupported attribute type: '{attr_type}'")
|
|
1225
|
+
|
|
1226
|
+
|
|
1227
|
+
def deserialize_node(proto: onnx.NodeProto) -> _core.Node:
    """Convert a stand-alone ONNX NodeProto into an IR Node.

    The node is deserialized in a fresh scope of its own, without value info or
    quantization annotations.

    Args:
        proto: The ONNX NodeProto to deserialize.

    Returns:
        An IR Node object representing the ONNX node.
    """
    node_scope: dict[str, _core.Value] = {}
    # The outputs must be declared before the node itself is deserialized.
    _declare_node_outputs(
        proto,
        node_scope,
        value_info={},
        quantization_annotations={},
    )
    scopes = [node_scope]
    return _deserialize_node(
        proto, scoped_values=scopes, value_info={}, quantization_annotations={}
    )
|
|
1246
|
+
|
|
1247
|
+
|
|
1248
|
+
@_capture_errors(lambda proto, scoped_values, value_info, quantization_annotations: str(proto))
def _deserialize_node(
    proto: onnx.NodeProto,
    scoped_values: list[dict[str, _core.Value]],
    value_info: dict[str, onnx.ValueInfoProto],
    quantization_annotations: dict[str, onnx.TensorAnnotation],
) -> _core.Node:
    """Build an IR Node from a NodeProto, resolving its inputs through the scope stack.

    Args:
        proto: The ONNX NodeProto to deserialize.
        scoped_values: Stack of name-to-Value scopes; the last entry is the current
            graph. Inputs are searched from the innermost scope outwards.
        value_info: Name-to-ValueInfoProto mapping, used for inputs that have to be created here.
        quantization_annotations: Name-to-TensorAnnotation mapping, used for inputs
            that have to be created here.

    Returns:
        The deserialized IR Node.
    """
    resolved_inputs: list[_core.Value | None] = []
    for input_name in proto.input:
        if input_name == "":
            # Empty input
            resolved_inputs.append(None)
            continue

        # Search the scopes from the innermost one outwards.
        for scope in reversed(scoped_values):
            if input_name in scope:
                resolved_inputs.append(scope[input_name])
                break
        else:
            # The value is not declared anywhere, so the graph is invalid. A new value
            # is still created for the node so that the model can be fixed later.
            logger.warning(
                "Input '%s' of node '%s' (%s::%s:%s) cannot be found in any scope. "
                "The model is invalid but we will still create a new input for the node (current depth: %s)",
                input_name,
                proto.name,
                proto.domain,
                proto.op_type,
                getattr(proto, "overload", ""),
                len(scoped_values),
            )
            if len(scoped_values) > 1:
                logger.warning(
                    "Caveat: The value is created in the subgraph. If "
                    "the node is referencing a value that is not in the current graph, "
                    "it is impossible to create it in the correct scope.",
                )
            placeholder = _core.Value(name=input_name)
            # Fill in shape/type information if they exist
            if input_name in value_info:
                deserialize_value_info_proto(value_info[input_name], placeholder)
            if input_name in quantization_annotations:
                _deserialize_quantization_annotation(
                    quantization_annotations[input_name], placeholder
                )
            resolved_inputs.append(placeholder)
            # The value can only be registered in the current scope. If the node meant
            # a value from another scope, that cannot be reconstructed here.
            scoped_values[-1][input_name] = placeholder

    # Collect the output values of the node.
    produced_outputs: list[_core.Value] = []
    for output_name in proto.output:
        if output_name == "":
            # Empty output
            produced_outputs.append(_core.Value(name=""))
            continue

        # Outputs are expected to be declared already in the current scope by
        # _declare_node_outputs, and a node can only produce values in its own scope,
        # so only the innermost scope is consulted.
        #
        # Note that a value is always owned by its producing node: the node created
        # below assumes ownership of these values, and ownership cannot be
        # transferred to another node afterwards.
        innermost_scope = scoped_values[-1]
        assert output_name in innermost_scope, (
            f"Output '{output_name}' not found in the current scope. This is unexpected"
        )
        produced_outputs.append(innermost_scope[output_name])

    return _core.Node(
        proto.domain,
        proto.op_type,
        resolved_inputs,
        [_deserialize_attribute(attribute, scoped_values) for attribute in proto.attribute],
        overload=getattr(proto, "overload", ""),
        outputs=produced_outputs,
        name=proto.name,
        doc_string=_get_field(proto, "doc_string"),
        metadata_props=deserialize_metadata_props(proto.metadata_props),
    )
|
|
1342
|
+
|
|
1343
|
+
|
|
1344
|
+
# Serialization
|
|
1345
|
+
|
|
1346
|
+
|
|
1347
|
+
def serialize_model(model: _protocols.ModelProtocol) -> onnx.ModelProto:
    """Convert an IR Model into a newly created ONNX ModelProto.

    Args:
        model: The IR Model to serialize.

    Returns:
        The serialized ONNX ModelProto object.
    """
    destination = onnx.ModelProto()
    return serialize_model_into(destination, from_=model)
|
|
1357
|
+
|
|
1358
|
+
|
|
1359
|
+
@_capture_errors(
    lambda model_proto, from_: (
        f"ir_version={from_.ir_version}, producer_name={from_.producer_name}, "
        f"producer_version={from_.producer_version}, domain={from_.domain}, "
    )
)
def serialize_model_into(
    model_proto: onnx.ModelProto, from_: _protocols.ModelProtocol
) -> onnx.ModelProto:
    """Write an IR model into an existing ONNX model proto.

    Args:
        model_proto: The proto to fill in.
        from_: The IR model to read from.

    Returns:
        ``model_proto``, after it has been populated.
    """
    model_proto.ir_version = from_.ir_version

    # These fields are only written when the model has a truthy value for them.
    optional_fields = (
        "producer_name",
        "producer_version",
        "domain",
        "model_version",
        "doc_string",
    )
    for field_name in optional_fields:
        field_value = getattr(from_, field_name)
        if field_value:
            setattr(model_proto, field_name, field_value)

    # Sort names for deterministic serialization
    # NOTE(review): no sorting happens in this function; presumably the helper
    # below does it — confirm.
    _serialize_opset_imports_into(model_proto.opset_import, from_.opset_imports)
    if from_.metadata_props:
        _serialize_metadata_props_into(model_proto.metadata_props, from_.metadata_props)
    serialize_graph_into(model_proto.graph, from_.graph)

    # Below _FUNCTION_VALUE_INFO_SUPPORTED_VERSION the value info of function
    # values is written to the main graph instead of into the functions.
    functions_hold_value_info = from_.ir_version >= _FUNCTION_VALUE_INFO_SUPPORTED_VERSION
    for func in from_.functions.values():
        serialize_function_into(
            model_proto.functions.add(),
            from_=func,
            create_value_info=functions_hold_value_info,
        )
        if functions_hold_value_info:
            continue
        # Create them in the main graph instead
        _serialize_experimental_value_info_for_function_ir9_into(model_proto.graph, func)
    return model_proto
|
|
1397
|
+
|
|
1398
|
+
|
|
1399
|
+
def _should_create_value_info_for_value(value: _protocols.ValueProtocol) -> bool:
    """Decide whether a ValueInfoProto is worth emitting for ``value``.

    Args:
        value: The value to inspect.

    Returns:
        True when the value has a name and at least one of shape, type,
        metadata props or doc string set; False otherwise.
    """
    carries_info = (
        value.shape is not None
        or value.type is not None
        or bool(value.metadata_props)
        or bool(value.doc_string)
    )
    if not carries_info:
        # Nothing to record for this value
        return False
    if value.name:
        return True
    logger.debug("Did not serialize '%s' because its name is empty", value)
    return False
|
|
1420
|
+
|
|
1421
|
+
|
|
1422
|
+
def _serialize_experimental_value_info_for_function_ir9_into(
    graph_proto: onnx.GraphProto, function: _protocols.FunctionProtocol
) -> None:
    """Store a function's value info in the main graph using an experimental naming scheme.

    IRv9 and older have no ValueInfoProto field on functions, so the entries are
    written to ``graph_proto.value_info`` under a qualified name of the form::

        {function_domain}::{function_name}/{value_name}

    Function inputs are processed first, followed by the outputs of every node in
    the function. Values without a name are reported with a warning and skipped.

    Args:
        graph_proto: The graph proto that receives the ValueInfoProto entries.
        function: The function whose values are serialized.
    """
    # TODO(justinchuby): In the future, we can decide if it is a good idea to simply iterate over
    # all values in the function and call serialize_value_into instead.
    function_qualified_name = f"{function.domain}::{function.name}"

    def emit(value: _protocols.ValueProtocol, missing_name_message: str) -> None:
        # Write one value under its qualified name, or warn when it has no name.
        if not value.name:
            logger.warning(missing_name_message, function_qualified_name, value)
            return
        if not _should_create_value_info_for_value(value):
            # No need to serialize value info if it is not set
            return
        serialize_value_into(
            graph_proto.value_info.add(),
            value,
            name=f"{function_qualified_name}/{value.name}",
        )

    for func_input in function.inputs:
        emit(func_input, "Function '%s': Value name not set for function input: %s")
    for node in function:
        for node_output in node.outputs:
            emit(node_output, "Function '%s': Value name not set for node output: %s")
|
|
1473
|
+
|
|
1474
|
+
|
|
1475
|
+
def _serialize_opset_imports_into(
    opset_ids: proto_containers.RepeatedCompositeFieldContainer[onnx.OperatorSetIdProto],
    from_: Mapping[str, int],
) -> None:
    """Append one OperatorSetId entry per opset import.

    Args:
        opset_ids: The repeated field that receives the entries.
        from_: Mapping from opset domain to opset version.
    """
    # Entries follow the mapping's own iteration order; no sorting is applied here.
    for domain in from_:
        opset_ids.add(domain=domain, version=from_[domain])
|
|
1488
|
+
|
|
1489
|
+
|
|
1490
|
+
def _serialize_string_string_maps(
    string_string_entries: proto_containers.RepeatedCompositeFieldContainer[
        onnx.StringStringEntryProto
    ],
    from_: Mapping[str, str],
) -> None:
    """Append the items of a <str, str> mapping as string-string entries.

    Args:
        string_string_entries: The repeated field that receives the entries.
        from_: The <str, str> mapping to serialize.
    """
    # Keys are unique, so ordering the items orders by key; this keeps output deterministic.
    for key, value in sorted(from_.items()):
        string_string_entries.add(key=key, value=value)
|
|
1505
|
+
|
|
1506
|
+
|
|
1507
|
+
# Alias: metadata props are written with the generic <str, str> map serializer,
# exposed under a name that states the intent at the call sites.
_serialize_metadata_props_into = _serialize_string_string_maps
|
|
1508
|
+
|
|
1509
|
+
|
|
1510
|
+
def _maybe_add_quantization_annotation(
    graph_proto: onnx.GraphProto, value: _protocols.ValueProtocol
) -> None:
    """Add a quantization annotation for ``value`` when its meta holds one.

    The annotation is looked up in ``value.meta``; nothing is added when the entry
    is missing or empty.

    Args:
        graph_proto: The graph proto that receives the annotation.
        value: The value whose meta is inspected.
    """
    quant_parameter_tensor_names = value.meta.get(_QUANT_PARAMETER_TENSOR_NAMES_FIELD)
    if not quant_parameter_tensor_names:
        return
    _serialize_tensor_annotation_into(
        graph_proto.quantization_annotation.add(), value.name, quant_parameter_tensor_names
    )
|
|
1517
|
+
|
|
1518
|
+
|
|
1519
|
+
def _serialize_tensor_annotation_into(
    tensor_annotation_proto: onnx.TensorAnnotation,
    tensor_name: str,
    quant_parameter_tensor_names: dict[str, str],
) -> None:
    """Fill a TensorAnnotation proto with a tensor name and its quantization parameter names.

    Args:
        tensor_annotation_proto: The proto to fill in.
        tensor_name: The name written to ``tensor_name``.
        quant_parameter_tensor_names: The <str, str> mapping written to
            ``quant_parameter_tensor_names``.
    """
    tensor_annotation_proto.tensor_name = tensor_name
    entries = tensor_annotation_proto.quant_parameter_tensor_names
    _serialize_string_string_maps(entries, quant_parameter_tensor_names)
|
|
1528
|
+
|
|
1529
|
+
|
|
1530
|
+
def serialize_graph(
    graph: _protocols.GraphProtocol | _protocols.GraphViewProtocol,
) -> onnx.GraphProto:
    """Convert a graph (or graph view) into a new :class:`onnx.GraphProto`.

    Initializers that do not have ``const_value`` set are skipped.

    Args:
        graph: The graph to convert.

    Returns:
        The newly created and populated GraphProto.
    """
    result = onnx.GraphProto()
    serialize_graph_into(result, from_=graph)
    return result
|
|
1546
|
+
|
|
1547
|
+
|
|
1548
|
+
@_capture_errors(
    lambda graph_proto, from_: (
        f"name={from_.name}, doc_string={from_.doc_string}, "
        f"len(inputs)={len(from_.inputs)}, len(initializers)={len(from_.initializers)}, "
        f"len(nodes)={len(from_)}, len(outputs)={len(from_.outputs)}, metadata_props={from_.metadata_props}"
    )
)
def serialize_graph_into(
    graph_proto: onnx.GraphProto,
    from_: _protocols.GraphProtocol | _protocols.GraphViewProtocol,
) -> None:
    """Populate ``graph_proto`` from an IR graph or graph view.

    Content is written in this order: name and doc string, inputs, initializers,
    nodes (with value info for their non-output values), outputs, metadata props.

    Args:
        graph_proto: The proto that receives the serialized content.
        from_: The graph to serialize.
    """
    if from_.name:
        graph_proto.name = from_.name
    if from_.doc_string:
        graph_proto.doc_string = from_.doc_string

    graph_input_names = set()
    for graph_input in from_.inputs:
        serialize_value_into(graph_proto.input.add(), graph_input)
        if graph_input.name not in from_.initializers:
            # Inputs that are also initializers get their annotation in the
            # initializer loop below, which avoids adding it twice.
            _maybe_add_quantization_annotation(graph_proto, graph_input)
        graph_input_names.add(graph_input.name)

    # TODO(justinchuby): Support sparse_initializer
    for initializer in from_.initializers.values():
        _maybe_add_quantization_annotation(graph_proto, initializer)
        if (
            _should_create_value_info_for_value(initializer)
            and initializer.name not in graph_input_names
        ):
            # Initializers get a value_info entry unless they are also graph
            # inputs, which were already described above.
            serialize_value_into(graph_proto.value_info.add(), initializer)
        tensor = initializer.const_value
        if tensor is None:
            # Nothing to write for an initializer without a constant value
            logger.warning("Initializer '%s' does not have a constant value set.", initializer.name)
            continue
        # Keep the tensor's name in sync with the name of the value that owns it
        tensor.name = initializer.name
        serialize_tensor_into(graph_proto.initializer.add(), from_=tensor)

    for node in from_:
        serialize_node_into(graph_proto.node.add(), from_=node)
        for node_output in node.outputs:
            if node_output.is_graph_output():
                # Graph outputs are described in the output loop below
                continue
            _maybe_add_quantization_annotation(graph_proto, node_output)
            if _should_create_value_info_for_value(node_output):
                serialize_value_into(graph_proto.value_info.add(), node_output)

    for graph_output in from_.outputs:
        serialize_value_into(graph_proto.output.add(), from_=graph_output)
        _maybe_add_quantization_annotation(graph_proto, graph_output)

    if from_.metadata_props:
        _serialize_metadata_props_into(graph_proto.metadata_props, from_.metadata_props)
|
|
1600
|
+
|
|
1601
|
+
|
|
1602
|
+
def serialize_function(
    function: _protocols.FunctionProtocol, *, create_value_info: bool = True
) -> onnx.FunctionProto:
    """Convert an IR function into a new :class:`onnx.FunctionProto`.

    Args:
        function: The function to convert.
        create_value_info: Whether to create ValueInfoProto for nodes in the function. This is supported
            starting from ONNX IR version 10.

    Returns:
        The newly created and populated FunctionProto.
    """
    result = onnx.FunctionProto()
    serialize_function_into(result, from_=function, create_value_info=create_value_info)
    return result
|
|
1617
|
+
|
|
1618
|
+
|
|
1619
|
+
@_capture_errors(lambda function_proto, from_, create_value_info: repr(from_))
def serialize_function_into(
    function_proto: onnx.FunctionProto,
    from_: _protocols.FunctionProtocol,
    *,
    create_value_info: bool = True,
) -> None:
    """Populate ``function_proto`` from an IR function.

    Args:
        function_proto: The proto that receives the serialized content.
        from_: The function to serialize.
        create_value_info: Whether to create ValueInfoProto for nodes in the function. This is supported
            starting from ONNX IR version 10.
    """
    # Scalar fields are only written when they hold a truthy value.
    for field_name in ("domain", "name", "overload", "doc_string"):
        field_value = getattr(from_, field_name)
        if field_value:
            setattr(function_proto, field_name, field_value)

    if from_.opset_imports:
        # A valid ONNX graph should have at least one opset import, that is
        # the default ONNX opset.
        # Here we check for emptiness before serializing to keep the logic consistent
        _serialize_opset_imports_into(function_proto.opset_import, from_.opset_imports)
    if from_.metadata_props:
        _serialize_metadata_props_into(function_proto.metadata_props, from_.metadata_props)

    for func_input in from_.inputs:
        function_proto.input.append(func_input.name)
        # Value info is written only when the value has something to record
        # and the caller asked for it.
        if _should_create_value_info_for_value(func_input) and create_value_info:
            serialize_value_into(function_proto.value_info.add(), func_input)

    for attr in from_.attributes.values():
        if attr.value is None:
            # ONNX does not record type information if the attribute does not have a default
            function_proto.attribute.append(attr.name)
        else:
            serialize_attribute_into(function_proto.attribute_proto.add(), from_=attr)

    for func_output in from_.outputs:
        # Function outputs are also node outputs, so their value info is
        # covered by the node loop below.
        function_proto.output.append(func_output.name)

    for node in from_:
        serialize_node_into(function_proto.node.add(), from_=node)
        for node_output in node.outputs:
            if _should_create_value_info_for_value(node_output) and create_value_info:
                serialize_value_into(function_proto.value_info.add(), node_output)
|
|
1677
|
+
|
|
1678
|
+
|
|
1679
|
+
def serialize_node(node: _protocols.NodeProtocol) -> onnx.NodeProto:
    """Convert an IR node into a new :class:`onnx.NodeProto`.

    Args:
        node: The IR node to convert.

    Returns:
        The newly created and populated NodeProto.
    """
    result = onnx.NodeProto()
    serialize_node_into(result, from_=node)
    return result
|
|
1691
|
+
|
|
1692
|
+
|
|
1693
|
+
def _remove_trailing_outputs(
    outputs: Sequence[_protocols.ValueProtocol],
) -> Sequence[_protocols.ValueProtocol]:
    """Drop the outputs at the end of the sequence whose names are empty.

    Unnamed outputs that are followed by a named one are kept.

    Args:
        outputs: The outputs to trim.

    Returns:
        A slice of ``outputs`` ending at the last named output, or an empty list
        when no output has a name.
    """
    keep = len(outputs)
    # Walk backwards until the last output that has a name
    while keep > 0 and not outputs[keep - 1].name:
        keep -= 1
    if keep == 0:
        return []
    return outputs[:keep]
|
|
1708
|
+
|
|
1709
|
+
|
|
1710
|
+
@_capture_errors(lambda node_proto, from_: repr(from_))
def serialize_node_into(node_proto: onnx.NodeProto, from_: _protocols.NodeProtocol) -> None:
    """Populate ``node_proto`` from an IR node.

    Args:
        node_proto: The proto that receives the serialized content.
        from_: The node to serialize.
    """
    node_proto.op_type = from_.op_type
    # Empty fields are left unset; an unset domain stands for the default domain.
    for field_name in ("domain", "name", "overload", "doc_string"):
        field_value = getattr(from_, field_name)
        if field_value:
            setattr(node_proto, field_name, field_value)
    if from_.metadata_props:
        _serialize_metadata_props_into(node_proto.metadata_props, from_.metadata_props)

    for node_input in from_.inputs:
        # A missing input is recorded as an empty name
        node_proto.input.append("" if node_input is None else node_input.name)

    # Do not include the trailing outputs that have empty names
    for node_output in _remove_trailing_outputs(from_.outputs):
        node_proto.output.append(node_output.name)

    for attr in from_.attributes.values():
        if attr.is_ref():
            serialize_reference_attribute_into(node_proto.attribute.add(), from_=attr)  # type: ignore[arg-type]
        else:
            serialize_attribute_into(node_proto.attribute.add(), from_=attr)  # type: ignore[arg-type]
|
|
1739
|
+
|
|
1740
|
+
|
|
1741
|
+
def serialize_tensor(tensor: _protocols.TensorProtocol) -> onnx.TensorProto:
    """Convert an IR tensor into a new :class:`onnx.TensorProto`.

    Args:
        tensor: The IR tensor to convert.

    Returns:
        The newly created and populated TensorProto.
    """
    result = onnx.TensorProto()
    serialize_tensor_into(result, from_=tensor)
    return result
|
|
1753
|
+
|
|
1754
|
+
|
|
1755
|
+
@_capture_errors(lambda tensor_proto, from_: repr(from_))
def serialize_tensor_into(
    tensor_proto: onnx.TensorProto, from_: _protocols.TensorProtocol
) -> None:
    """Populate ``tensor_proto`` from an IR tensor.

    A ``TensorProtoTensor`` is copied from its underlying proto. Any other tensor
    has its name, doc string, data type and dims written, followed by its payload:
    external data entries, string data, or raw bytes.

    Args:
        tensor_proto: The proto that receives the serialized content.
        from_: The tensor to serialize.
    """
    if isinstance(from_, TensorProtoTensor):
        # The backing proto is available, so copy it directly
        tensor_proto.CopyFrom(from_.raw)
        # NOTE(review): if `from_.raw` already carries metadata_props, CopyFrom brings
        # them over and the call below appends further entries -- confirm this cannot
        # produce duplicated keys.
        if from_.metadata_props:
            _serialize_metadata_props_into(tensor_proto.metadata_props, from_.metadata_props)
        return

    if from_.name:
        tensor_proto.name = from_.name
    if from_.doc_string:
        tensor_proto.doc_string = from_.doc_string
    tensor_proto.data_type = from_.dtype.value
    tensor_proto.dims.extend(from_.shape.numpy())

    if isinstance(from_, _core.ExternalTensor):
        # External tensors keep pointing at their external storage
        tensor_proto.data_location = onnx.TensorProto.EXTERNAL
        external_fields = (
            ("location", os.fspath(from_.location)),
            ("offset", from_.offset),
            ("length", from_.length),
        )
        for key, field in external_fields:
            if field is None:
                # Unset fields are omitted from external_data
                continue
            tensor_proto.external_data.add(key=key, value=str(field))
    elif isinstance(from_, _core.StringTensor):
        tensor_proto.string_data.extend(from_.string_data())
    else:
        tensor_proto.raw_data = from_.tobytes()
    _serialize_metadata_props_into(tensor_proto.metadata_props, from_.metadata_props)
|
|
1789
|
+
|
|
1790
|
+
|
|
1791
|
+
def serialize_attribute(attribute: _protocols.AttributeProtocol) -> onnx.AttributeProto:
    """Convert an IR attribute into a new :class:`onnx.AttributeProto`.

    Args:
        attribute: The IR attribute to convert.

    Returns:
        The newly created and populated AttributeProto.
    """
    result = onnx.AttributeProto()
    serialize_attribute_into(result, from_=attribute)
    return result
|
|
1803
|
+
|
|
1804
|
+
|
|
1805
|
+
@_capture_errors(lambda attribute_proto, from_: repr(from_))
def serialize_attribute_into(
    attribute_proto: onnx.AttributeProto, from_: _protocols.AttributeProtocol
) -> None:
    """Populate ``attribute_proto`` with the name, doc string, type and value of an attribute.

    Args:
        attribute_proto: The proto that receives the serialized content.
        from_: The attribute to serialize.
    """
    attribute_proto.name = from_.name
    doc_string = from_.doc_string
    if doc_string:
        attribute_proto.doc_string = doc_string
    # The value field and the proto type are filled in together
    _fill_in_value_for_attribute(attribute_proto, from_.type, from_.value)
|
|
1813
|
+
|
|
1814
|
+
|
|
1815
|
+
def _fill_in_value_for_attribute(
    attribute_proto: onnx.AttributeProto, type_: _enums.AttributeType, value: Any
) -> None:
    """Write ``value`` into the field of ``attribute_proto`` selected by ``type_``.

    The proto's ``type`` field is set to the matching ``onnx.AttributeProto`` enum
    once the value has been written.

    Args:
        attribute_proto: The attribute proto to fill in.
        type_: The IR attribute type; it decides which proto field receives the value.
        value: The attribute value. Its expected Python type depends on ``type_``
            (see the comment in each branch). For STRING and STRINGS, ``bytes``
            entries are written as-is with a warning instead of being encoded.

    Raises:
        NotImplementedError: If ``type_`` is SPARSE_TENSOR or SPARSE_TENSORS.
        TypeError: If ``type_`` is not a supported attribute type.
    """
    if type_ == _enums.AttributeType.INT:
        # value: int
        attribute_proto.i = value
        attribute_proto.type = onnx.AttributeProto.INT
    elif type_ == _enums.AttributeType.FLOAT:
        # value: float
        attribute_proto.f = value
        attribute_proto.type = onnx.AttributeProto.FLOAT
    elif type_ == _enums.AttributeType.STRING:
        # value: str
        if isinstance(value, bytes):
            # Even though onnx.ai/onnx/repo-docs/IR.html#attributes requires the attribute
            # for strings to be utf-8 encoded bytes, custom ops may still store arbitrary data there
            logger.warning(
                "Value in attribute %r should be a string but is instead bytes. ONNX "
                "spec requires string attributes to be UTF-8 encoded so the model is invalid. "
                "We will skip encoding the attribute and use the bytes as attribute value",
                attribute_proto.name,
            )
            attribute_proto.s = value
        else:
            attribute_proto.s = value.encode("utf-8")
        attribute_proto.type = onnx.AttributeProto.STRING
    elif type_ == _enums.AttributeType.INTS:
        # value: Sequence[int]
        attribute_proto.ints.extend(value)
        attribute_proto.type = onnx.AttributeProto.INTS
    elif type_ == _enums.AttributeType.FLOATS:
        # value: Sequence[float]
        attribute_proto.floats.extend(value)
        attribute_proto.type = onnx.AttributeProto.FLOATS
    elif type_ == _enums.AttributeType.STRINGS:
        # value: Sequence[str]
        # Mirror the STRING branch: bytes entries are passed through untouched
        # instead of failing on the missing ``encode`` method.
        encoded_strings = []
        found_bytes = False
        for item in value:
            if isinstance(item, bytes):
                found_bytes = True
                encoded_strings.append(item)
            else:
                encoded_strings.append(item.encode("utf-8"))
        if found_bytes:
            logger.warning(
                "Values in attribute %r should be strings but some are bytes. ONNX "
                "spec requires string attributes to be UTF-8 encoded so the model is invalid. "
                "We will skip encoding those entries and use the bytes as attribute values",
                attribute_proto.name,
            )
        attribute_proto.strings.extend(encoded_strings)
        attribute_proto.type = onnx.AttributeProto.STRINGS
    elif type_ == _enums.AttributeType.TENSOR:
        # value: _protocols.TensorProtocol
        serialize_tensor_into(attribute_proto.t, value)
        attribute_proto.type = onnx.AttributeProto.TENSOR
    elif type_ == _enums.AttributeType.GRAPH:
        # value: _protocols.GraphProtocol
        serialize_graph_into(attribute_proto.g, value)
        attribute_proto.type = onnx.AttributeProto.GRAPH
    elif type_ == _enums.AttributeType.TENSORS:
        # value: Sequence[_protocols.TensorProtocol]
        for tensor in value:
            serialize_tensor_into(attribute_proto.tensors.add(), tensor)
        attribute_proto.type = onnx.AttributeProto.TENSORS
    elif type_ == _enums.AttributeType.GRAPHS:
        # value: Sequence[_protocols.GraphProtocol]
        for graph in value:
            serialize_graph_into(attribute_proto.graphs.add(), graph)
        attribute_proto.type = onnx.AttributeProto.GRAPHS
    elif type_ == _enums.AttributeType.SPARSE_TENSOR:
        raise NotImplementedError(
            f"Sparse tensors are not supported yet. {_PLEASE_CONTRIBUTE}"
        )
    elif type_ == _enums.AttributeType.SPARSE_TENSORS:
        raise NotImplementedError(
            f"Sparse tensors are not supported yet. {_PLEASE_CONTRIBUTE}"
        )
    elif type_ == _enums.AttributeType.TYPE_PROTO:
        # value: _core.TypeAndShape
        if value.type is not None:
            serialize_type_into(attribute_proto.tp, value.type)
        # Need to create the type _before_ writing the shape
        if value.shape is not None:
            serialize_shape_into(attribute_proto.tp, value.shape)
        attribute_proto.type = onnx.AttributeProto.TYPE_PROTO
    elif type_ == _enums.AttributeType.TYPE_PROTOS:
        for ir_type in value:
            # ir_type: _core.TypeAndShape
            type_proto = attribute_proto.type_protos.add()
            if ir_type.type is not None:
                serialize_type_into(type_proto, ir_type.type)
            # Need to create the type _before_ writing the shape so that the shape can be written to the leaf type proto
            if ir_type.shape is not None:
                serialize_shape_into(type_proto, ir_type.shape)
        attribute_proto.type = onnx.AttributeProto.TYPE_PROTOS
    else:
        raise TypeError(f"Unsupported attribute type: {type_}")
|
|
1899
|
+
|
|
1900
|
+
|
|
1901
|
+
@_capture_errors(lambda attribute_proto, from_: repr(from_))
def serialize_reference_attribute_into(
    attribute_proto: onnx.AttributeProto, from_: _protocols.ReferenceAttributeProtocol
) -> None:
    """Populate ``attribute_proto`` from a reference attribute.

    Only the name, the referenced attribute name, the doc string and the type are
    written; no value field is filled in.

    Args:
        attribute_proto: The proto that receives the serialized content.
        from_: The reference attribute to serialize.
    """
    attribute_proto.name = from_.name
    attribute_proto.ref_attr_name = from_.ref_attr_name
    doc_string = from_.doc_string
    if doc_string:
        attribute_proto.doc_string = doc_string
    proto_type = typing.cast(onnx.AttributeProto.AttributeType, from_.type.value)
    attribute_proto.type = proto_type
|
|
1910
|
+
|
|
1911
|
+
|
|
1912
|
+
def serialize_reference_attribute(
    attr: _protocols.ReferenceAttributeProtocol,
) -> onnx.AttributeProto:
    """Convert a reference attribute into a new :class:`onnx.AttributeProto`.

    Args:
        attr: The reference attribute to convert.

    Returns:
        The newly created and populated AttributeProto.
    """
    result = onnx.AttributeProto()
    serialize_reference_attribute_into(result, attr)
    return result
|
|
1918
|
+
|
|
1919
|
+
|
|
1920
|
+
def serialize_value(value: _protocols.ValueProtocol, *, name: str = "") -> onnx.ValueInfoProto:
    """Convert a value into a new :class:`onnx.ValueInfoProto`.

    Args:
        value: The value to serialize.
        name: A custom name to set for the value info. If not provided, the name from the value will be used.

    Returns:
        The newly created and populated ValueInfoProto.
    """
    result = onnx.ValueInfoProto()
    serialize_value_into(result, value, name=name)
    return result
|
|
1931
|
+
|
|
1932
|
+
|
|
1933
|
+
@_capture_errors(lambda value_info_proto, from_, name="": repr(from_))
def serialize_value_into(
    value_info_proto: onnx.ValueInfoProto,
    from_: _protocols.ValueProtocol,
    *,
    name: str = "",
) -> None:
    """Populate ``value_info_proto`` from a value.

    Args:
        value_info_proto: The proto that receives the serialized content.
        from_: The value to serialize.
        name: A custom name to set for the value info. If not provided, the name from the value will be used.
    """
    # A non-empty custom name takes precedence over the value's own name
    value_info_proto.name = name or from_.name
    if from_.metadata_props:
        _serialize_metadata_props_into(value_info_proto.metadata_props, from_.metadata_props)
    value_type = from_.type
    if value_type is not None:
        serialize_type_into(value_info_proto.type, value_type)
    # The type must exist _before_ the shape is written so that the shape can be
    # written to the leaf type proto
    value_shape = from_.shape
    if value_shape is not None:
        serialize_shape_into(value_info_proto.type, value_shape)
    if from_.doc_string:
        value_info_proto.doc_string = from_.doc_string
|
|
1960
|
+
|
|
1961
|
+
|
|
1962
|
+
@_capture_errors(lambda type_proto, from_: repr(from_))
def serialize_type_into(type_proto: onnx.TypeProto, from_: _protocols.TypeProtocol) -> None:
    """Populate ``type_proto`` from an IR type.

    Tensor and sparse tensor types write their element data type. Sequence and
    optional types recurse into their element type.

    Args:
        type_proto: The proto that receives the serialized content.
        from_: The IR type to serialize.

    Raises:
        TypeError: If ``from_`` is not one of the handled type classes.
    """
    if from_.denotation:
        type_proto.denotation = from_.denotation

    if isinstance(from_, _core.TensorType):
        type_proto.tensor_type.elem_type = from_.dtype.value
        return
    if isinstance(from_, _core.SparseTensorType):
        type_proto.sparse_tensor_type.elem_type = from_.dtype.value
        return
    if isinstance(from_, _core.SequenceType):
        serialize_type_into(type_proto.sequence_type.elem_type, from_.elem_type)
        return
    if isinstance(from_, _core.OptionalType):
        serialize_type_into(type_proto.optional_type.elem_type, from_.elem_type)
        return
    raise TypeError(f"Unsupported type: {from_}")
|
|
1980
|
+
|
|
1981
|
+
|
|
1982
|
+
def serialize_type(type_protocol: _protocols.TypeProtocol) -> onnx.TypeProto:
    """Convert an IR type into a new :class:`onnx.TypeProto`.

    Args:
        type_protocol: The IR type to convert.

    Returns:
        The newly created and populated TypeProto.
    """
    result = onnx.TypeProto()
    serialize_type_into(result, from_=type_protocol)
    return result
|
|
1994
|
+
|
|
1995
|
+
|
|
1996
|
+
@_capture_errors(lambda type_proto, from_: repr(from_))
def serialize_shape_into(type_proto: onnx.TypeProto, from_: _protocols.ShapeProtocol) -> None:
    """Write a shape into the leaf type of ``type_proto``.

    The ``value`` oneof is followed through nested element types until a type whose
    ``elem_type`` is an int is reached; the shape is written there. If a type along
    the way has no ``value`` set, a warning is logged and nothing is written.

    Args:
        type_proto: The type proto to write the shape into. Its type must already be set.
        from_: The shape to serialize.
    """
    current_proto = type_proto
    while True:
        field_name = current_proto.WhichOneof("value")
        if field_name is None:
            # We cannot write the shape because we do not know where to write it
            logger.warning(
                # TODO(justinchuby): Show more context about the value when move everything to an object
                "The value type for shape %s is not known. Please set type for the value. Skipping serialization",
                from_,
            )
            return
        leaf_type = getattr(current_proto, field_name)
        if isinstance(leaf_type.elem_type, int):
            # Reached the leaf type that has the shape field
            break
        current_proto = leaf_type.elem_type

    # When from is empty, we still need to set the shape field to an empty list by touching it
    leaf_type.shape.ClearField("dim")
    for index, dim in enumerate(from_):
        denotation = from_.get_denotation(index)
        serialize_dimension_into(leaf_type.shape.dim.add(), dim, denotation)
|
|
2025
|
+
|
|
2026
|
+
|
|
2027
|
+
@_capture_errors(lambda dim_proto, dim, denotation: repr(dim_proto))
|
|
2028
|
+
def serialize_dimension_into(
|
|
2029
|
+
dim_proto: onnx.TensorShapeProto.Dimension,
|
|
2030
|
+
dim: int | _protocols.SymbolicDimProtocol,
|
|
2031
|
+
denotation: str | None = None,
|
|
2032
|
+
) -> None:
|
|
2033
|
+
if denotation:
|
|
2034
|
+
dim_proto.denotation = denotation
|
|
2035
|
+
if isinstance(dim, int):
|
|
2036
|
+
dim_proto.dim_value = dim
|
|
2037
|
+
elif isinstance(dim, (_core.SymbolicDim, _protocols.SymbolicDimProtocol)):
|
|
2038
|
+
if dim.value is not None:
|
|
2039
|
+
dim_proto.dim_param = str(dim.value)
|
|
2040
|
+
# NOTE: None is a valid value for symbolic dimension:
|
|
2041
|
+
# A dimension MAY have neither dim_value nor dim_param set. Such a dimension
|
|
2042
|
+
# represents an unknown dimension unrelated to other unknown dimensions.
|
|
2043
|
+
# Here we will just leave the dim_proto empty.
|