onnx-ir 0.0.1__py3-none-any.whl → 0.1.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of onnx-ir might be problematic. Click here for more details.
- onnx_ir/__init__.py +23 -10
- onnx_ir/{_convenience.py → _convenience/__init__.py} +40 -102
- onnx_ir/_convenience/_constructors.py +213 -0
- onnx_ir/_core.py +857 -233
- onnx_ir/_display.py +2 -2
- onnx_ir/_enums.py +107 -5
- onnx_ir/_graph_comparison.py +2 -2
- onnx_ir/_graph_containers.py +268 -0
- onnx_ir/_io.py +57 -10
- onnx_ir/_linked_list.py +15 -7
- onnx_ir/_metadata.py +4 -3
- onnx_ir/_name_authority.py +2 -2
- onnx_ir/_polyfill.py +26 -0
- onnx_ir/_protocols.py +31 -13
- onnx_ir/_tape.py +139 -32
- onnx_ir/_thirdparty/asciichartpy.py +1 -4
- onnx_ir/_type_casting.py +18 -3
- onnx_ir/{_internal/version_utils.py → _version_utils.py} +2 -29
- onnx_ir/convenience.py +4 -2
- onnx_ir/external_data.py +401 -0
- onnx_ir/passes/__init__.py +8 -2
- onnx_ir/passes/_pass_infra.py +173 -56
- onnx_ir/passes/common/__init__.py +36 -0
- onnx_ir/passes/common/_c_api_utils.py +76 -0
- onnx_ir/passes/common/clear_metadata_and_docstring.py +60 -0
- onnx_ir/passes/common/constant_manipulation.py +232 -0
- onnx_ir/passes/common/inliner.py +331 -0
- onnx_ir/passes/common/onnx_checker.py +57 -0
- onnx_ir/passes/common/shape_inference.py +112 -0
- onnx_ir/passes/common/topological_sort.py +33 -0
- onnx_ir/passes/common/unused_removal.py +196 -0
- onnx_ir/serde.py +288 -124
- onnx_ir/tape.py +15 -0
- onnx_ir/tensor_adapters.py +122 -0
- onnx_ir/testing.py +197 -0
- onnx_ir/traversal.py +4 -3
- onnx_ir-0.1.0.dist-info/METADATA +53 -0
- onnx_ir-0.1.0.dist-info/RECORD +41 -0
- {onnx_ir-0.0.1.dist-info → onnx_ir-0.1.0.dist-info}/WHEEL +1 -1
- onnx_ir-0.1.0.dist-info/licenses/LICENSE +202 -0
- onnx_ir/_external_data.py +0 -323
- onnx_ir-0.0.1.dist-info/LICENSE +0 -22
- onnx_ir-0.0.1.dist-info/METADATA +0 -73
- onnx_ir-0.0.1.dist-info/RECORD +0 -26
- {onnx_ir-0.0.1.dist-info → onnx_ir-0.1.0.dist-info}/top_level.txt +0 -0
onnx_ir/serde.py
CHANGED
|
@@ -1,5 +1,5 @@
|
|
|
1
|
-
# Copyright (c)
|
|
2
|
-
#
|
|
1
|
+
# Copyright (c) ONNX Project Contributors
|
|
2
|
+
# SPDX-License-Identifier: Apache-2.0
|
|
3
3
|
"""Serialize and deserialize the intermediate representation to/from ONNX protos."""
|
|
4
4
|
|
|
5
5
|
# NOTES for developers:
|
|
@@ -14,12 +14,14 @@
|
|
|
14
14
|
from __future__ import annotations
|
|
15
15
|
|
|
16
16
|
import functools
|
|
17
|
+
import typing
|
|
17
18
|
|
|
18
19
|
__all__ = [
|
|
19
20
|
# Tensors
|
|
20
21
|
"TensorProtoTensor",
|
|
21
22
|
# Deserialization
|
|
22
23
|
"from_proto",
|
|
24
|
+
"from_onnx_text",
|
|
23
25
|
"deserialize_attribute",
|
|
24
26
|
"deserialize_dimension",
|
|
25
27
|
"deserialize_function",
|
|
@@ -29,6 +31,7 @@ __all__ = [
|
|
|
29
31
|
"deserialize_node",
|
|
30
32
|
"deserialize_opset_import",
|
|
31
33
|
"deserialize_tensor",
|
|
34
|
+
"deserialize_tensor_shape",
|
|
32
35
|
"deserialize_type_proto_for_shape",
|
|
33
36
|
"deserialize_type_proto_for_type",
|
|
34
37
|
"deserialize_value_info_proto",
|
|
@@ -59,14 +62,14 @@ __all__ = [
|
|
|
59
62
|
import collections
|
|
60
63
|
import logging
|
|
61
64
|
import os
|
|
62
|
-
import
|
|
63
|
-
from typing import Any, Callable
|
|
65
|
+
from collections.abc import Mapping, Sequence
|
|
66
|
+
from typing import Any, Callable
|
|
64
67
|
|
|
65
68
|
import numpy as np
|
|
66
69
|
import onnx
|
|
67
70
|
import onnx.external_data_helper
|
|
68
71
|
|
|
69
|
-
from onnx_ir import _core, _enums,
|
|
72
|
+
from onnx_ir import _core, _enums, _protocols, _type_casting
|
|
70
73
|
|
|
71
74
|
if typing.TYPE_CHECKING:
|
|
72
75
|
import google.protobuf.internal.containers as proto_containers
|
|
@@ -74,12 +77,11 @@ if typing.TYPE_CHECKING:
|
|
|
74
77
|
|
|
75
78
|
logger = logging.getLogger(__name__)
|
|
76
79
|
|
|
77
|
-
_PLEASE_CONTRIBUTE =
|
|
78
|
-
"Please contribute by creating a PR at https://github.com/microsoft/onnxscript."
|
|
79
|
-
)
|
|
80
|
+
_PLEASE_CONTRIBUTE = "Please contribute by creating a PR at https://github.com/onnx/onnx-ir."
|
|
80
81
|
_FUNCTION_VALUE_INFO_SUPPORTED_VERSION = (
|
|
81
82
|
10 # ONNX IR version where value info in functions was introduced
|
|
82
83
|
)
|
|
84
|
+
_QUANT_PARAMETER_TENSOR_NAMES_FIELD = "quant_parameter_tensor_names"
|
|
83
85
|
_T = typing.TypeVar("_T", bound=Callable[..., Any])
|
|
84
86
|
|
|
85
87
|
|
|
@@ -121,16 +123,35 @@ def _unflatten_complex(
|
|
|
121
123
|
return array[::2] + 1j * array[1::2]
|
|
122
124
|
|
|
123
125
|
|
|
124
|
-
|
|
125
|
-
|
|
126
|
-
|
|
127
|
-
|
|
128
|
-
|
|
129
|
-
|
|
130
|
-
|
|
131
|
-
|
|
132
|
-
|
|
133
|
-
) ->
|
|
126
|
+
@typing.overload
|
|
127
|
+
def from_proto(proto: onnx.ModelProto) -> _core.Model: ... # type: ignore[overload-overlap]
|
|
128
|
+
@typing.overload
|
|
129
|
+
def from_proto(proto: onnx.GraphProto) -> _core.Graph: ... # type: ignore[overload-overlap]
|
|
130
|
+
@typing.overload
|
|
131
|
+
def from_proto(proto: onnx.NodeProto) -> _core.Node: ... # type: ignore[overload-overlap]
|
|
132
|
+
@typing.overload
|
|
133
|
+
def from_proto(proto: onnx.TensorProto) -> _protocols.TensorProtocol: ... # type: ignore[overload-overlap]
|
|
134
|
+
@typing.overload
|
|
135
|
+
def from_proto(proto: onnx.AttributeProto) -> _core.Attr: ... # type: ignore[overload-overlap]
|
|
136
|
+
@typing.overload
|
|
137
|
+
def from_proto(proto: onnx.ValueInfoProto) -> _core.Value: ... # type: ignore[overload-overlap]
|
|
138
|
+
@typing.overload
|
|
139
|
+
def from_proto(proto: onnx.TypeProto) -> _core.TypeAndShape: ... # type: ignore[overload-overlap]
|
|
140
|
+
@typing.overload
|
|
141
|
+
def from_proto(proto: onnx.FunctionProto) -> _core.Function: ... # type: ignore[overload-overlap]
|
|
142
|
+
@typing.overload
|
|
143
|
+
def from_proto(proto: onnx.TensorShapeProto) -> _core.Shape: ... # type: ignore[overload-overlap]
|
|
144
|
+
@typing.overload
|
|
145
|
+
def from_proto( # type: ignore[overload-overlap]
|
|
146
|
+
proto: onnx.TensorShapeProto.Dimension,
|
|
147
|
+
) -> tuple[int | _core.SymbolicDim, str | None]: ...
|
|
148
|
+
@typing.overload
|
|
149
|
+
def from_proto(proto: Sequence[onnx.OperatorSetIdProto]) -> dict[str, int]: ... # type: ignore[overload-overlap]
|
|
150
|
+
@typing.overload
|
|
151
|
+
def from_proto(proto: Sequence[onnx.StringStringEntryProto]) -> dict[str, str]: ... # type: ignore[overload-overlap]
|
|
152
|
+
|
|
153
|
+
|
|
154
|
+
def from_proto(proto: object) -> object:
|
|
134
155
|
"""Deserialize an ONNX proto message to an IR object."""
|
|
135
156
|
if isinstance(proto, onnx.ModelProto):
|
|
136
157
|
return deserialize_model(proto)
|
|
@@ -151,24 +172,56 @@ def from_proto(
|
|
|
151
172
|
)
|
|
152
173
|
if isinstance(proto, onnx.FunctionProto):
|
|
153
174
|
return deserialize_function(proto)
|
|
175
|
+
if isinstance(proto, onnx.TensorShapeProto):
|
|
176
|
+
return deserialize_tensor_shape(proto)
|
|
177
|
+
if isinstance(proto, onnx.TensorShapeProto.Dimension):
|
|
178
|
+
return deserialize_dimension(proto)
|
|
179
|
+
if isinstance(proto, Sequence) and all(
|
|
180
|
+
isinstance(p, onnx.OperatorSetIdProto) for p in proto
|
|
181
|
+
):
|
|
182
|
+
return deserialize_opset_import(proto)
|
|
183
|
+
if isinstance(proto, Sequence) and all(
|
|
184
|
+
isinstance(p, onnx.StringStringEntryProto) for p in proto
|
|
185
|
+
):
|
|
186
|
+
return deserialize_metadata_props(proto)
|
|
154
187
|
raise NotImplementedError(
|
|
155
188
|
f"Deserialization of {type(proto)} in from_proto is not implemented. "
|
|
156
189
|
"Use a specific ir.serde.deserialize* function instead."
|
|
157
190
|
)
|
|
158
191
|
|
|
159
192
|
|
|
160
|
-
def
|
|
161
|
-
|
|
162
|
-
|
|
163
|
-
|
|
164
|
-
|
|
165
|
-
|
|
166
|
-
|
|
167
|
-
|
|
168
|
-
|
|
169
|
-
|
|
170
|
-
|
|
171
|
-
|
|
193
|
+
def from_onnx_text(model_text: str, /) -> _core.Model:
|
|
194
|
+
"""Convert the ONNX textual representation to an IR model.
|
|
195
|
+
|
|
196
|
+
Read more about the textual representation at: https://onnx.ai/onnx/repo-docs/Syntax.html
|
|
197
|
+
"""
|
|
198
|
+
proto = onnx.parser.parse_model(model_text)
|
|
199
|
+
return deserialize_model(proto)
|
|
200
|
+
|
|
201
|
+
|
|
202
|
+
@typing.overload
|
|
203
|
+
def to_proto(ir_object: _protocols.ModelProtocol) -> onnx.ModelProto: ... # type: ignore[overload-overlap]
|
|
204
|
+
@typing.overload
|
|
205
|
+
def to_proto(ir_object: _protocols.GraphProtocol) -> onnx.GraphProto: ... # type: ignore[overload-overlap]
|
|
206
|
+
@typing.overload
|
|
207
|
+
def to_proto(ir_object: _protocols.NodeProtocol) -> onnx.NodeProto: ... # type: ignore[overload-overlap]
|
|
208
|
+
@typing.overload
|
|
209
|
+
def to_proto(ir_object: _protocols.TensorProtocol) -> onnx.TensorProto: ... # type: ignore[overload-overlap]
|
|
210
|
+
@typing.overload
|
|
211
|
+
def to_proto(ir_object: _protocols.AttributeProtocol) -> onnx.AttributeProto: ... # type: ignore[overload-overlap]
|
|
212
|
+
@typing.overload
|
|
213
|
+
def to_proto(ir_object: _protocols.ReferenceAttributeProtocol) -> onnx.AttributeProto: ... # type: ignore[overload-overlap]
|
|
214
|
+
@typing.overload
|
|
215
|
+
def to_proto(ir_object: _protocols.ValueProtocol) -> onnx.ValueInfoProto: ... # type: ignore[overload-overlap]
|
|
216
|
+
@typing.overload
|
|
217
|
+
def to_proto(ir_object: _protocols.TypeProtocol) -> onnx.TypeProto: ... # type: ignore[overload-overlap]
|
|
218
|
+
@typing.overload
|
|
219
|
+
def to_proto(ir_object: _protocols.FunctionProtocol) -> onnx.FunctionProto: ... # type: ignore[overload-overlap]
|
|
220
|
+
@typing.overload
|
|
221
|
+
def to_proto(ir_object: _protocols.GraphViewProtocol) -> onnx.GraphProto: ... # type: ignore[overload-overlap]
|
|
222
|
+
|
|
223
|
+
|
|
224
|
+
def to_proto(ir_object: object) -> object:
|
|
172
225
|
"""Serialize an IR object to a proto."""
|
|
173
226
|
if isinstance(ir_object, _protocols.ModelProtocol):
|
|
174
227
|
return serialize_model(ir_object)
|
|
@@ -180,9 +233,10 @@ def to_proto(
|
|
|
180
233
|
return serialize_tensor(ir_object)
|
|
181
234
|
if isinstance(ir_object, _protocols.ValueProtocol):
|
|
182
235
|
return serialize_value(ir_object)
|
|
183
|
-
if isinstance(ir_object, _protocols.AttributeProtocol):
|
|
236
|
+
if isinstance(ir_object, _protocols.AttributeProtocol) and not ir_object.is_ref():
|
|
184
237
|
return serialize_attribute(ir_object)
|
|
185
238
|
if isinstance(ir_object, _protocols.ReferenceAttributeProtocol):
|
|
239
|
+
assert ir_object.is_ref()
|
|
186
240
|
return serialize_reference_attribute_into(onnx.AttributeProto(), ir_object)
|
|
187
241
|
if isinstance(ir_object, _protocols.TypeProtocol):
|
|
188
242
|
return serialize_type_into(onnx.TypeProto(), ir_object)
|
|
@@ -199,12 +253,11 @@ def to_proto(
|
|
|
199
253
|
class TensorProtoTensor(_core.TensorBase): # pylint: disable=too-many-ancestors
|
|
200
254
|
"""A tensor initialized from a tensor proto."""
|
|
201
255
|
|
|
256
|
+
__slots__ = ("_proto",)
|
|
257
|
+
|
|
202
258
|
def __init__(self, proto: onnx.TensorProto) -> None:
|
|
259
|
+
super().__init__(metadata_props=deserialize_metadata_props(proto.metadata_props))
|
|
203
260
|
self._proto = proto
|
|
204
|
-
self._metadata_props: dict[str, str] | None = deserialize_metadata_props(
|
|
205
|
-
proto.metadata_props
|
|
206
|
-
)
|
|
207
|
-
self._metadata: _metadata.MetadataStore | None = None
|
|
208
261
|
|
|
209
262
|
@property
|
|
210
263
|
def name(self) -> str:
|
|
@@ -225,7 +278,7 @@ class TensorProtoTensor(_core.TensorBase): # pylint: disable=too-many-ancestors
|
|
|
225
278
|
def dtype(self) -> _enums.DataType:
|
|
226
279
|
return _enums.DataType(self._proto.data_type)
|
|
227
280
|
|
|
228
|
-
@property
|
|
281
|
+
@property # type: ignore[misc]
|
|
229
282
|
def doc_string(self) -> str:
|
|
230
283
|
return self._proto.doc_string
|
|
231
284
|
|
|
@@ -234,9 +287,10 @@ class TensorProtoTensor(_core.TensorBase): # pylint: disable=too-many-ancestors
|
|
|
234
287
|
return self._proto
|
|
235
288
|
|
|
236
289
|
def __repr__(self) -> str:
|
|
237
|
-
|
|
238
|
-
|
|
239
|
-
|
|
290
|
+
if self.size <= 10:
|
|
291
|
+
tensor_lines = repr(self.numpy()).split("\n")
|
|
292
|
+
tensor_text = " ".join(line.strip() for line in tensor_lines)
|
|
293
|
+
return f"{self._repr_base()}({tensor_text}, name={self.name!r})"
|
|
240
294
|
return f"{self._repr_base()}(name={self.name!r})"
|
|
241
295
|
|
|
242
296
|
def __array__(self, dtype: Any = None) -> np.ndarray:
|
|
@@ -277,8 +331,7 @@ class TensorProtoTensor(_core.TensorBase): # pylint: disable=too-many-ancestors
|
|
|
277
331
|
raise ValueError("Cannot convert UNDEFINED tensor to numpy array.")
|
|
278
332
|
if self._proto.data_location == onnx.TensorProto.EXTERNAL:
|
|
279
333
|
raise ValueError(
|
|
280
|
-
"Cannot convert external tensor to numpy array. "
|
|
281
|
-
"Use ir.ExternalTensor instead."
|
|
334
|
+
"Cannot convert external tensor to numpy array. Use ir.ExternalTensor instead."
|
|
282
335
|
)
|
|
283
336
|
|
|
284
337
|
if self._proto.HasField("raw_data"):
|
|
@@ -323,6 +376,8 @@ class TensorProtoTensor(_core.TensorBase): # pylint: disable=too-many-ancestors
|
|
|
323
376
|
return _type_casting.unpack_int4(array.astype(np.uint8), self._proto.dims)
|
|
324
377
|
elif dtype == _enums.DataType.UINT4:
|
|
325
378
|
return _type_casting.unpack_uint4(array.astype(np.uint8), self._proto.dims)
|
|
379
|
+
elif dtype == _enums.DataType.FLOAT4E2M1:
|
|
380
|
+
return _type_casting.unpack_float4e2m1(array.astype(np.uint8), self._proto.dims)
|
|
326
381
|
else:
|
|
327
382
|
# Otherwise convert to the correct dtype and reshape
|
|
328
383
|
# Note we cannot use view() here because the storage dtype may not be the same size as the target
|
|
@@ -369,6 +424,7 @@ class TensorProtoTensor(_core.TensorBase): # pylint: disable=too-many-ancestors
|
|
|
369
424
|
_enums.DataType.FLOAT8E5M2FNUZ,
|
|
370
425
|
_enums.DataType.INT4,
|
|
371
426
|
_enums.DataType.UINT4,
|
|
427
|
+
_enums.DataType.FLOAT4E2M1,
|
|
372
428
|
}:
|
|
373
429
|
# uint4 and int4 values are already packed, even when stored as int32
|
|
374
430
|
# so we don't need to pack them again
|
|
@@ -393,23 +449,6 @@ class TensorProtoTensor(_core.TensorBase): # pylint: disable=too-many-ancestors
|
|
|
393
449
|
# For example, int32_data can be empty and still be a valid tensor.
|
|
394
450
|
return b""
|
|
395
451
|
|
|
396
|
-
@property
|
|
397
|
-
def meta(self) -> _metadata.MetadataStore:
|
|
398
|
-
"""The metadata store for intermediate analysis.
|
|
399
|
-
|
|
400
|
-
Write to the :attr:`metadata_props` if you would like the metadata to be serialized
|
|
401
|
-
to the ONNX proto.
|
|
402
|
-
"""
|
|
403
|
-
if self._metadata is None:
|
|
404
|
-
self._metadata = _metadata.MetadataStore()
|
|
405
|
-
return self._metadata
|
|
406
|
-
|
|
407
|
-
@property
|
|
408
|
-
def metadata_props(self) -> dict[str, str]:
|
|
409
|
-
if self._metadata_props is None:
|
|
410
|
-
self._metadata_props = {}
|
|
411
|
-
return self._metadata_props
|
|
412
|
-
|
|
413
452
|
|
|
414
453
|
def _get_field(proto: Any, field: str) -> Any:
|
|
415
454
|
if proto.HasField(field):
|
|
@@ -472,7 +511,7 @@ def deserialize_model(proto: onnx.ModelProto) -> _core.Model:
|
|
|
472
511
|
model_version=_get_field(proto, "model_version"),
|
|
473
512
|
doc_string=_get_field(proto, "doc_string"),
|
|
474
513
|
functions=functions,
|
|
475
|
-
|
|
514
|
+
metadata_props=deserialize_metadata_props(proto.metadata_props),
|
|
476
515
|
)
|
|
477
516
|
|
|
478
517
|
# Handle experimental value info for functions created by the dynamo exporter in IR version 9
|
|
@@ -541,6 +580,9 @@ def deserialize_graph(proto: onnx.GraphProto) -> _core.Graph:
|
|
|
541
580
|
|
|
542
581
|
Returns:
|
|
543
582
|
IR Graph.
|
|
583
|
+
|
|
584
|
+
.. versionadded:: 0.3
|
|
585
|
+
Support for *quantization_annotation* is added.
|
|
544
586
|
"""
|
|
545
587
|
return _deserialize_graph(proto, [])
|
|
546
588
|
|
|
@@ -561,44 +603,89 @@ def _deserialize_graph(
|
|
|
561
603
|
Returns:
|
|
562
604
|
IR Graph.
|
|
563
605
|
"""
|
|
606
|
+
# Process TensorAnnotation for quantization
|
|
607
|
+
quantization_annotations = {
|
|
608
|
+
annotation.tensor_name: annotation for annotation in proto.quantization_annotation
|
|
609
|
+
}
|
|
610
|
+
|
|
564
611
|
# Create values for initializers and inputs
|
|
565
612
|
initializer_tensors = [deserialize_tensor(tensor) for tensor in proto.initializer]
|
|
566
613
|
inputs = [_core.Input(info.name) for info in proto.input]
|
|
567
614
|
for info, value in zip(proto.input, inputs):
|
|
568
615
|
deserialize_value_info_proto(info, value)
|
|
569
616
|
|
|
617
|
+
# Add TensorAnnotation for inputs if they exist
|
|
618
|
+
if value.name in quantization_annotations:
|
|
619
|
+
_deserialize_quantization_annotation(quantization_annotations[value.name], value)
|
|
620
|
+
|
|
570
621
|
# Initialize the values dictionary for this graph scope with the inputs and initializers
|
|
571
622
|
values: dict[str, _core.Value] = {v.name: v for v in inputs} # type: ignore[misc]
|
|
623
|
+
|
|
624
|
+
# Enter the graph scope by pushing the values for this scope to the stack
|
|
572
625
|
scoped_values.append(values)
|
|
626
|
+
|
|
573
627
|
initializer_values = []
|
|
574
|
-
for tensor in initializer_tensors:
|
|
575
|
-
|
|
628
|
+
for i, tensor in enumerate(initializer_tensors):
|
|
629
|
+
initializer_name = tensor.name
|
|
630
|
+
if not initializer_name:
|
|
631
|
+
logger.warning(
|
|
632
|
+
"Initializer tensor must have a name but the %s-th initializer does not. Skipping this initializer.",
|
|
633
|
+
i,
|
|
634
|
+
)
|
|
635
|
+
continue
|
|
636
|
+
if initializer_name in values:
|
|
576
637
|
# The initializer is for an input
|
|
577
|
-
initializer_value = values[
|
|
638
|
+
initializer_value = values[initializer_name]
|
|
578
639
|
initializer_value.const_value = tensor
|
|
579
640
|
else:
|
|
580
641
|
# The initializer is for some other value. Create this value first
|
|
581
642
|
initializer_value = _core.Value(
|
|
582
643
|
None,
|
|
583
644
|
index=None,
|
|
584
|
-
name=
|
|
585
|
-
#
|
|
586
|
-
|
|
645
|
+
name=initializer_name,
|
|
646
|
+
# Include shape and type even if the shape or type is not provided as ValueInfoProto.
|
|
647
|
+
# Users expect initialized values to have shape and type information.
|
|
587
648
|
type=_core.TensorType(tensor.dtype),
|
|
649
|
+
shape=tensor.shape, # type: ignore[arg-type]
|
|
588
650
|
const_value=tensor,
|
|
589
651
|
)
|
|
590
|
-
|
|
652
|
+
if initializer_value.name in quantization_annotations:
|
|
653
|
+
_deserialize_quantization_annotation(
|
|
654
|
+
quantization_annotations[initializer_value.name], initializer_value
|
|
655
|
+
)
|
|
656
|
+
values[initializer_name] = initializer_value
|
|
591
657
|
initializer_values.append(initializer_value)
|
|
592
658
|
|
|
593
|
-
#
|
|
659
|
+
# Build the value info dictionary to allow for quick lookup for this graph scope
|
|
594
660
|
value_info = {info.name: info for info in proto.value_info}
|
|
595
661
|
|
|
596
662
|
# Deserialize nodes with all known values
|
|
597
|
-
nodes = [
|
|
663
|
+
nodes = [
|
|
664
|
+
_deserialize_node(node, scoped_values, value_info, quantization_annotations)
|
|
665
|
+
for node in proto.node
|
|
666
|
+
]
|
|
598
667
|
|
|
599
|
-
|
|
600
|
-
|
|
668
|
+
outputs = []
|
|
669
|
+
for info in proto.output:
|
|
670
|
+
# Fill in values for graph outputs
|
|
671
|
+
output_name = info.name
|
|
672
|
+
if output_name not in values:
|
|
673
|
+
# Handle (invalid) graph outputs that do not have any producers
|
|
674
|
+
logger.warning(
|
|
675
|
+
"Output '%s' is not produced by any node. The graph has an invalid output",
|
|
676
|
+
output_name,
|
|
677
|
+
)
|
|
678
|
+
value = _core.Value(name=output_name)
|
|
679
|
+
else:
|
|
680
|
+
# A valid, normal graph output
|
|
681
|
+
value = values[output_name]
|
|
682
|
+
# Fill in shape/type information
|
|
683
|
+
deserialize_value_info_proto(info, value)
|
|
684
|
+
outputs.append(value)
|
|
685
|
+
|
|
686
|
+
# Exit the graph scope by popping the values for this scope from the stack
|
|
601
687
|
scoped_values.pop()
|
|
688
|
+
|
|
602
689
|
return _core.Graph(
|
|
603
690
|
inputs,
|
|
604
691
|
outputs,
|
|
@@ -617,7 +704,10 @@ def deserialize_function(proto: onnx.FunctionProto) -> _core.Function:
|
|
|
617
704
|
value_info = {info.name: info for info in getattr(proto, "value_info", [])}
|
|
618
705
|
|
|
619
706
|
# TODO(justinchuby): Handle unsorted nodes
|
|
620
|
-
nodes = [
|
|
707
|
+
nodes = [
|
|
708
|
+
_deserialize_node(node, [values], value_info=value_info, quantization_annotations={})
|
|
709
|
+
for node in proto.node
|
|
710
|
+
]
|
|
621
711
|
outputs = [values[name] for name in proto.output]
|
|
622
712
|
graph = _core.Graph(
|
|
623
713
|
inputs,
|
|
@@ -631,6 +721,7 @@ def deserialize_function(proto: onnx.FunctionProto) -> _core.Function:
|
|
|
631
721
|
if hasattr(proto, "overload") and proto.overload
|
|
632
722
|
else ""
|
|
633
723
|
),
|
|
724
|
+
metadata_props=deserialize_metadata_props(proto.metadata_props),
|
|
634
725
|
)
|
|
635
726
|
attributes = [_deserialize_attribute(attr, []) for attr in proto.attribute_proto]
|
|
636
727
|
# Attributes without defaults
|
|
@@ -642,8 +733,7 @@ def deserialize_function(proto: onnx.FunctionProto) -> _core.Function:
|
|
|
642
733
|
name=proto.name,
|
|
643
734
|
overload=getattr(proto, "overload", ""),
|
|
644
735
|
graph=graph,
|
|
645
|
-
attributes=
|
|
646
|
-
metadata_props=deserialize_metadata_props(proto.metadata_props),
|
|
736
|
+
attributes=attributes,
|
|
647
737
|
)
|
|
648
738
|
|
|
649
739
|
|
|
@@ -662,29 +752,41 @@ def deserialize_value_info_proto(
|
|
|
662
752
|
return value
|
|
663
753
|
|
|
664
754
|
|
|
755
|
+
@_capture_errors(lambda proto, value: str(proto))
|
|
756
|
+
def _deserialize_quantization_annotation(
|
|
757
|
+
proto: onnx.TensorAnnotation, value: _core.Value
|
|
758
|
+
) -> None:
|
|
759
|
+
"""Deserialize a quantization_annotation as TensorAnnotation into a Value.
|
|
760
|
+
|
|
761
|
+
This function is marked private because we don't expect users to call it directly.
|
|
762
|
+
"""
|
|
763
|
+
value.meta[_QUANT_PARAMETER_TENSOR_NAMES_FIELD] = _deserialize_string_string_maps(
|
|
764
|
+
proto.quant_parameter_tensor_names
|
|
765
|
+
)
|
|
766
|
+
|
|
767
|
+
|
|
768
|
+
@_capture_errors(str)
|
|
769
|
+
def deserialize_tensor_shape(proto: onnx.TensorShapeProto) -> _core.Shape:
|
|
770
|
+
# This logic handles when the shape is [] as well
|
|
771
|
+
dim_protos = proto.dim
|
|
772
|
+
deserialized_dim_denotations = [
|
|
773
|
+
deserialize_dimension(dim_proto) for dim_proto in dim_protos
|
|
774
|
+
]
|
|
775
|
+
dims = [dim for dim, _ in deserialized_dim_denotations]
|
|
776
|
+
denotations = [denotation for _, denotation in deserialized_dim_denotations]
|
|
777
|
+
return _core.Shape(dims, denotations=denotations, frozen=True)
|
|
778
|
+
|
|
779
|
+
|
|
665
780
|
@_capture_errors(str)
|
|
666
781
|
def deserialize_type_proto_for_shape(proto: onnx.TypeProto) -> _core.Shape | None:
|
|
667
782
|
if proto.HasField("tensor_type"):
|
|
668
783
|
if (shape_proto := _get_field(proto.tensor_type, "shape")) is None:
|
|
669
784
|
return None
|
|
670
|
-
|
|
671
|
-
dim_protos = shape_proto.dim
|
|
672
|
-
deserialized_dim_denotations = [
|
|
673
|
-
deserialize_dimension(dim_proto) for dim_proto in dim_protos
|
|
674
|
-
]
|
|
675
|
-
dims = [dim for dim, _ in deserialized_dim_denotations]
|
|
676
|
-
denotations = [denotation for _, denotation in deserialized_dim_denotations]
|
|
677
|
-
return _core.Shape(dims, denotations=denotations, frozen=True)
|
|
785
|
+
return deserialize_tensor_shape(shape_proto)
|
|
678
786
|
if proto.HasField("sparse_tensor_type"):
|
|
679
787
|
if (shape_proto := _get_field(proto.sparse_tensor_type, "shape")) is None:
|
|
680
788
|
return None
|
|
681
|
-
|
|
682
|
-
deserialized_dim_denotations = [
|
|
683
|
-
deserialize_dimension(dim_proto) for dim_proto in dim_protos
|
|
684
|
-
]
|
|
685
|
-
dims = [dim for dim, _ in deserialized_dim_denotations]
|
|
686
|
-
denotations = [denotation for _, denotation in deserialized_dim_denotations]
|
|
687
|
-
return _core.Shape(dims, denotations=denotations, frozen=True)
|
|
789
|
+
return deserialize_tensor_shape(shape_proto)
|
|
688
790
|
if proto.HasField("sequence_type"):
|
|
689
791
|
if (elem_type := _get_field(proto.sequence_type, "elem_type")) is None:
|
|
690
792
|
return None
|
|
@@ -800,14 +902,17 @@ def deserialize_metadata_props(
|
|
|
800
902
|
return {entry.key: entry.value for entry in proto}
|
|
801
903
|
|
|
802
904
|
|
|
803
|
-
|
|
905
|
+
_deserialize_string_string_maps = deserialize_metadata_props
|
|
906
|
+
|
|
907
|
+
|
|
908
|
+
def deserialize_attribute(proto: onnx.AttributeProto) -> _core.Attr:
|
|
804
909
|
return _deserialize_attribute(proto, [])
|
|
805
910
|
|
|
806
911
|
|
|
807
912
|
@_capture_errors(lambda proto, scoped_values: str(proto))
|
|
808
913
|
def _deserialize_attribute(
|
|
809
914
|
proto: onnx.AttributeProto, scoped_values: list[dict[str, _core.Value]]
|
|
810
|
-
) -> _core.Attr
|
|
915
|
+
) -> _core.Attr:
|
|
811
916
|
name = proto.name
|
|
812
917
|
doc_string = _get_field(proto, "doc_string")
|
|
813
918
|
type_ = _enums.AttributeType(proto.type)
|
|
@@ -874,14 +979,17 @@ def _deserialize_attribute(
|
|
|
874
979
|
|
|
875
980
|
|
|
876
981
|
def deserialize_node(proto: onnx.NodeProto) -> _core.Node:
|
|
877
|
-
return _deserialize_node(
|
|
982
|
+
return _deserialize_node(
|
|
983
|
+
proto, scoped_values=[{}], value_info={}, quantization_annotations={}
|
|
984
|
+
)
|
|
878
985
|
|
|
879
986
|
|
|
880
|
-
@_capture_errors(lambda proto, scoped_values, value_info: str(proto))
|
|
987
|
+
@_capture_errors(lambda proto, scoped_values, value_info, quantization_annotations: str(proto))
|
|
881
988
|
def _deserialize_node(
|
|
882
989
|
proto: onnx.NodeProto,
|
|
883
990
|
scoped_values: list[dict[str, _core.Value]],
|
|
884
991
|
value_info: dict[str, onnx.ValueInfoProto],
|
|
992
|
+
quantization_annotations: dict[str, onnx.TensorAnnotation],
|
|
885
993
|
) -> _core.Node:
|
|
886
994
|
node_inputs: list[_core.Value | None] = []
|
|
887
995
|
for input_name in proto.input:
|
|
@@ -924,6 +1032,10 @@ def _deserialize_node(
|
|
|
924
1032
|
# Fill in shape/type information if they exist
|
|
925
1033
|
if input_name in value_info:
|
|
926
1034
|
deserialize_value_info_proto(value_info[input_name], value)
|
|
1035
|
+
if input_name in quantization_annotations:
|
|
1036
|
+
_deserialize_quantization_annotation(
|
|
1037
|
+
quantization_annotations[input_name], value
|
|
1038
|
+
)
|
|
927
1039
|
node_inputs.append(value)
|
|
928
1040
|
# We can only create the value in the current scope. If the subgraph is
|
|
929
1041
|
# referencing a value that is not in the current scope, it is impossible
|
|
@@ -965,6 +1077,8 @@ def _deserialize_node(
|
|
|
965
1077
|
proto.name,
|
|
966
1078
|
proto.op_type,
|
|
967
1079
|
)
|
|
1080
|
+
if output_name in quantization_annotations:
|
|
1081
|
+
_deserialize_quantization_annotation(quantization_annotations[output_name], value)
|
|
968
1082
|
node_outputs.append(value)
|
|
969
1083
|
return _core.Node(
|
|
970
1084
|
proto.domain,
|
|
@@ -1036,7 +1150,12 @@ def _should_create_value_info_for_value(value: _protocols.ValueProtocol) -> bool
|
|
|
1036
1150
|
True if value info should be created for the value.
|
|
1037
1151
|
"""
|
|
1038
1152
|
# No need to serialize value info if it is not set
|
|
1039
|
-
|
|
1153
|
+
if value.shape is None and value.type is None:
|
|
1154
|
+
return False
|
|
1155
|
+
if not value.name:
|
|
1156
|
+
logger.debug("Did not serialize '%s' because its name is empty", value)
|
|
1157
|
+
return False
|
|
1158
|
+
return True
|
|
1040
1159
|
|
|
1041
1160
|
|
|
1042
1161
|
def _serialize_experimental_value_info_for_function_ir9_into(
|
|
@@ -1063,7 +1182,7 @@ def _serialize_experimental_value_info_for_function_ir9_into(
|
|
|
1063
1182
|
|
|
1064
1183
|
for input in function.inputs:
|
|
1065
1184
|
if not input.name:
|
|
1066
|
-
|
|
1185
|
+
logger.warning(
|
|
1067
1186
|
"Function '%s': Value name not set for function input: %s",
|
|
1068
1187
|
function_qualified_name,
|
|
1069
1188
|
input,
|
|
@@ -1076,7 +1195,7 @@ def _serialize_experimental_value_info_for_function_ir9_into(
|
|
|
1076
1195
|
for node in function:
|
|
1077
1196
|
for node_output in node.outputs:
|
|
1078
1197
|
if not node_output.name:
|
|
1079
|
-
|
|
1198
|
+
logger.warning(
|
|
1080
1199
|
"Function '%s': Value name not set for node output: %s",
|
|
1081
1200
|
function_qualified_name,
|
|
1082
1201
|
node_output,
|
|
@@ -1107,23 +1226,46 @@ def _serialize_opset_imports_into(
|
|
|
1107
1226
|
opset_ids.add(domain=domain, version=version)
|
|
1108
1227
|
|
|
1109
1228
|
|
|
1110
|
-
def
|
|
1229
|
+
def _serialize_string_string_maps(
|
|
1111
1230
|
string_string_entries: proto_containers.RepeatedCompositeFieldContainer[
|
|
1112
1231
|
onnx.StringStringEntryProto
|
|
1113
1232
|
],
|
|
1114
1233
|
from_: Mapping[str, str],
|
|
1115
1234
|
) -> None:
|
|
1116
|
-
"""Serialize
|
|
1235
|
+
"""Serialize a <str, str> mapping into a repeated field of string-string entries.
|
|
1117
1236
|
|
|
1118
1237
|
Args:
|
|
1119
1238
|
string_string_entries: The repeated field to serialize into.
|
|
1120
|
-
from_: The mapping of
|
|
1239
|
+
from_: The mapping of a <str, str> mapping to serialize.
|
|
1121
1240
|
"""
|
|
1122
1241
|
# Sort names for deterministic serialization
|
|
1123
1242
|
for key in sorted(from_):
|
|
1124
1243
|
string_string_entries.add(key=key, value=from_[key])
|
|
1125
1244
|
|
|
1126
1245
|
|
|
1246
|
+
_serialize_metadata_props_into = _serialize_string_string_maps
|
|
1247
|
+
|
|
1248
|
+
|
|
1249
|
+
def _maybe_add_quantization_annotation(
|
|
1250
|
+
graph_proto: onnx.GraphProto, value: _protocols.ValueProtocol
|
|
1251
|
+
) -> None:
|
|
1252
|
+
if quantization_annotation := value.meta.get(_QUANT_PARAMETER_TENSOR_NAMES_FIELD):
|
|
1253
|
+
_serialize_tensor_annotation_into(
|
|
1254
|
+
graph_proto.quantization_annotation.add(), value.name, quantization_annotation
|
|
1255
|
+
)
|
|
1256
|
+
|
|
1257
|
+
|
|
1258
|
+
def _serialize_tensor_annotation_into(
|
|
1259
|
+
tensor_annotation_proto: onnx.TensorAnnotation,
|
|
1260
|
+
tensor_name: str,
|
|
1261
|
+
quant_parameter_tensor_names: dict[str, str],
|
|
1262
|
+
) -> None:
|
|
1263
|
+
tensor_annotation_proto.tensor_name = tensor_name
|
|
1264
|
+
_serialize_string_string_maps(
|
|
1265
|
+
tensor_annotation_proto.quant_parameter_tensor_names, quant_parameter_tensor_names
|
|
1266
|
+
)
|
|
1267
|
+
|
|
1268
|
+
|
|
1127
1269
|
def serialize_graph(
|
|
1128
1270
|
graph: _protocols.GraphProtocol | _protocols.GraphViewProtocol,
|
|
1129
1271
|
) -> onnx.GraphProto:
|
|
@@ -1159,29 +1301,41 @@ def serialize_graph_into(
|
|
|
1159
1301
|
graph_proto.doc_string = from_.doc_string
|
|
1160
1302
|
for input_ in from_.inputs:
|
|
1161
1303
|
serialize_value_into(graph_proto.input.add(), input_)
|
|
1304
|
+
if input_.name not in from_.initializers:
|
|
1305
|
+
# Annotations for initializers will be added below to avoid double adding
|
|
1306
|
+
# TODO(justinchuby): We should add a method is_initializer() on Value when
|
|
1307
|
+
# the initializer list is tracked
|
|
1308
|
+
_maybe_add_quantization_annotation(graph_proto, input_)
|
|
1309
|
+
input_names = {input_.name for input_ in from_.inputs}
|
|
1162
1310
|
# TODO(justinchuby): Support sparse_initializer
|
|
1163
|
-
for
|
|
1164
|
-
|
|
1311
|
+
for value in from_.initializers.values():
|
|
1312
|
+
_maybe_add_quantization_annotation(graph_proto, value)
|
|
1313
|
+
if _should_create_value_info_for_value(value) and value.name not in input_names:
|
|
1314
|
+
# Serialize information about all initializers into value_info,
|
|
1315
|
+
# except for those that are also graph inputs
|
|
1316
|
+
serialize_value_into(graph_proto.value_info.add(), value)
|
|
1317
|
+
if value.const_value is None:
|
|
1165
1318
|
# Skip initializers without constant values
|
|
1166
|
-
logger.warning(
|
|
1167
|
-
"Initializer '%s' does not have a constant value set.", initializer.name
|
|
1168
|
-
)
|
|
1319
|
+
logger.warning("Initializer '%s' does not have a constant value set.", value.name)
|
|
1169
1320
|
continue
|
|
1170
1321
|
# Make sure the tensor's name is the same as the value's name
|
|
1171
|
-
|
|
1172
|
-
serialize_tensor_into(graph_proto.initializer.add(), from_=
|
|
1322
|
+
value.const_value.name = value.name
|
|
1323
|
+
serialize_tensor_into(graph_proto.initializer.add(), from_=value.const_value)
|
|
1173
1324
|
for node in from_:
|
|
1174
1325
|
serialize_node_into(graph_proto.node.add(), from_=node)
|
|
1175
1326
|
for node_output in node.outputs:
|
|
1176
|
-
if not _should_create_value_info_for_value(node_output):
|
|
1177
|
-
# No need to serialize value info if it is not set
|
|
1178
|
-
continue
|
|
1179
1327
|
if node_output.is_graph_output():
|
|
1180
|
-
# No need to serialize
|
|
1328
|
+
# No need to serialize info for these outputs because they are handled as graph outputs
|
|
1181
1329
|
continue
|
|
1182
|
-
|
|
1330
|
+
_maybe_add_quantization_annotation(graph_proto, node_output)
|
|
1331
|
+
if not _should_create_value_info_for_value(node_output): # pylint: disable=no-else-continue
|
|
1332
|
+
# No need to serialize value info if it is not set
|
|
1333
|
+
continue
|
|
1334
|
+
else:
|
|
1335
|
+
serialize_value_into(graph_proto.value_info.add(), node_output)
|
|
1183
1336
|
for output in from_.outputs:
|
|
1184
1337
|
serialize_value_into(graph_proto.output.add(), from_=output)
|
|
1338
|
+
_maybe_add_quantization_annotation(graph_proto, output)
|
|
1185
1339
|
if from_.metadata_props:
|
|
1186
1340
|
_serialize_metadata_props_into(graph_proto.metadata_props, from_.metadata_props)
|
|
1187
1341
|
|
|
@@ -1269,6 +1423,23 @@ def serialize_node(node: _protocols.NodeProtocol) -> onnx.NodeProto:
|
|
|
1269
1423
|
return node_proto
|
|
1270
1424
|
|
|
1271
1425
|
|
|
1426
|
+
def _remove_trailing_outputs(
|
|
1427
|
+
outputs: Sequence[_protocols.ValueProtocol],
|
|
1428
|
+
) -> Sequence[_protocols.ValueProtocol]:
|
|
1429
|
+
"""Remove trailing outputs that have empty names.
|
|
1430
|
+
|
|
1431
|
+
Args:
|
|
1432
|
+
outputs: The outputs to remove trailing outputs from.
|
|
1433
|
+
|
|
1434
|
+
Returns:
|
|
1435
|
+
The outputs with trailing outputs removed.
|
|
1436
|
+
"""
|
|
1437
|
+
for i, output in enumerate(reversed(outputs)):
|
|
1438
|
+
if output.name:
|
|
1439
|
+
return outputs[: len(outputs) - i]
|
|
1440
|
+
return []
|
|
1441
|
+
|
|
1442
|
+
|
|
1272
1443
|
@_capture_errors(lambda node_proto, from_: repr(from_))
|
|
1273
1444
|
def serialize_node_into(node_proto: onnx.NodeProto, from_: _protocols.NodeProtocol) -> None:
|
|
1274
1445
|
node_proto.op_type = from_.op_type
|
|
@@ -1288,23 +1459,16 @@ def serialize_node_into(node_proto: onnx.NodeProto, from_: _protocols.NodeProtoc
|
|
|
1288
1459
|
node_proto.input.append("")
|
|
1289
1460
|
else:
|
|
1290
1461
|
node_proto.input.append(input_.name)
|
|
1291
|
-
|
|
1462
|
+
|
|
1463
|
+
# Do not include the trailing outputs that have empty names
|
|
1464
|
+
for output in _remove_trailing_outputs(from_.outputs):
|
|
1292
1465
|
node_proto.output.append(output.name)
|
|
1466
|
+
|
|
1293
1467
|
for attr in from_.attributes.values():
|
|
1294
|
-
if
|
|
1295
|
-
serialize_attribute_into(node_proto.attribute.add(), from_=attr)
|
|
1296
|
-
elif isinstance(attr, _core.RefAttr):
|
|
1297
|
-
serialize_reference_attribute_into(node_proto.attribute.add(), from_=attr)
|
|
1298
|
-
# Handle protocol attributes for completeness. We do not check them first because
|
|
1299
|
-
# calling isinstance on a protocol can be slow.
|
|
1300
|
-
# Most of the time, we will have Attr or RefAttr so the two branches below
|
|
1301
|
-
# will not be taken.
|
|
1302
|
-
elif isinstance(attr, _protocols.AttributeProtocol):
|
|
1303
|
-
serialize_attribute_into(node_proto.attribute.add(), from_=attr)
|
|
1304
|
-
elif isinstance(attr, _protocols.ReferenceAttributeProtocol):
|
|
1305
|
-
serialize_reference_attribute_into(node_proto.attribute.add(), from_=attr)
|
|
1468
|
+
if not attr.is_ref():
|
|
1469
|
+
serialize_attribute_into(node_proto.attribute.add(), from_=attr) # type: ignore[arg-type]
|
|
1306
1470
|
else:
|
|
1307
|
-
|
|
1471
|
+
serialize_reference_attribute_into(node_proto.attribute.add(), from_=attr) # type: ignore[arg-type]
|
|
1308
1472
|
|
|
1309
1473
|
|
|
1310
1474
|
def serialize_tensor(tensor: _protocols.TensorProtocol) -> onnx.TensorProto:
|