onnx-ir 0.0.1 (py3-none-any.whl) → 0.1.1 (py3-none-any.whl)

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of onnx-ir might be problematic. Click here for more details.

Files changed (46)
  1. onnx_ir/__init__.py +23 -10
  2. onnx_ir/{_convenience.py → _convenience/__init__.py} +40 -102
  3. onnx_ir/_convenience/_constructors.py +213 -0
  4. onnx_ir/_core.py +874 -257
  5. onnx_ir/_display.py +2 -2
  6. onnx_ir/_enums.py +107 -5
  7. onnx_ir/_graph_comparison.py +2 -2
  8. onnx_ir/_graph_containers.py +373 -0
  9. onnx_ir/_io.py +57 -10
  10. onnx_ir/_linked_list.py +15 -7
  11. onnx_ir/_metadata.py +4 -3
  12. onnx_ir/_name_authority.py +2 -2
  13. onnx_ir/_polyfill.py +26 -0
  14. onnx_ir/_protocols.py +31 -13
  15. onnx_ir/_tape.py +139 -32
  16. onnx_ir/_thirdparty/asciichartpy.py +1 -4
  17. onnx_ir/_type_casting.py +18 -3
  18. onnx_ir/{_internal/version_utils.py → _version_utils.py} +2 -29
  19. onnx_ir/convenience.py +4 -2
  20. onnx_ir/external_data.py +401 -0
  21. onnx_ir/passes/__init__.py +8 -2
  22. onnx_ir/passes/_pass_infra.py +173 -56
  23. onnx_ir/passes/common/__init__.py +40 -0
  24. onnx_ir/passes/common/_c_api_utils.py +76 -0
  25. onnx_ir/passes/common/clear_metadata_and_docstring.py +60 -0
  26. onnx_ir/passes/common/common_subexpression_elimination.py +177 -0
  27. onnx_ir/passes/common/constant_manipulation.py +217 -0
  28. onnx_ir/passes/common/inliner.py +332 -0
  29. onnx_ir/passes/common/onnx_checker.py +57 -0
  30. onnx_ir/passes/common/shape_inference.py +112 -0
  31. onnx_ir/passes/common/topological_sort.py +33 -0
  32. onnx_ir/passes/common/unused_removal.py +196 -0
  33. onnx_ir/serde.py +288 -124
  34. onnx_ir/tape.py +15 -0
  35. onnx_ir/tensor_adapters.py +122 -0
  36. onnx_ir/testing.py +197 -0
  37. onnx_ir/traversal.py +4 -3
  38. onnx_ir-0.1.1.dist-info/METADATA +53 -0
  39. onnx_ir-0.1.1.dist-info/RECORD +42 -0
  40. {onnx_ir-0.0.1.dist-info → onnx_ir-0.1.1.dist-info}/WHEEL +1 -1
  41. onnx_ir-0.1.1.dist-info/licenses/LICENSE +202 -0
  42. onnx_ir/_external_data.py +0 -323
  43. onnx_ir-0.0.1.dist-info/LICENSE +0 -22
  44. onnx_ir-0.0.1.dist-info/METADATA +0 -73
  45. onnx_ir-0.0.1.dist-info/RECORD +0 -26
  46. {onnx_ir-0.0.1.dist-info → onnx_ir-0.1.1.dist-info}/top_level.txt +0 -0
onnx_ir/serde.py CHANGED
@@ -1,5 +1,5 @@
1
- # Copyright (c) Microsoft Corporation.
2
- # Licensed under the MIT License.
1
+ # Copyright (c) ONNX Project Contributors
2
+ # SPDX-License-Identifier: Apache-2.0
3
3
  """Serialize and deserialize the intermediate representation to/from ONNX protos."""
4
4
 
5
5
  # NOTES for developers:
@@ -14,12 +14,14 @@
14
14
  from __future__ import annotations
15
15
 
16
16
  import functools
17
+ import typing
17
18
 
18
19
  __all__ = [
19
20
  # Tensors
20
21
  "TensorProtoTensor",
21
22
  # Deserialization
22
23
  "from_proto",
24
+ "from_onnx_text",
23
25
  "deserialize_attribute",
24
26
  "deserialize_dimension",
25
27
  "deserialize_function",
@@ -29,6 +31,7 @@ __all__ = [
29
31
  "deserialize_node",
30
32
  "deserialize_opset_import",
31
33
  "deserialize_tensor",
34
+ "deserialize_tensor_shape",
32
35
  "deserialize_type_proto_for_shape",
33
36
  "deserialize_type_proto_for_type",
34
37
  "deserialize_value_info_proto",
@@ -59,14 +62,14 @@ __all__ = [
59
62
  import collections
60
63
  import logging
61
64
  import os
62
- import typing
63
- from typing import Any, Callable, List, Mapping, Sequence
65
+ from collections.abc import Mapping, Sequence
66
+ from typing import Any, Callable
64
67
 
65
68
  import numpy as np
66
69
  import onnx
67
70
  import onnx.external_data_helper
68
71
 
69
- from onnx_ir import _core, _enums, _metadata, _protocols, _type_casting
72
+ from onnx_ir import _core, _enums, _protocols, _type_casting
70
73
 
71
74
  if typing.TYPE_CHECKING:
72
75
  import google.protobuf.internal.containers as proto_containers
@@ -74,12 +77,11 @@ if typing.TYPE_CHECKING:
74
77
 
75
78
  logger = logging.getLogger(__name__)
76
79
 
77
- _PLEASE_CONTRIBUTE = (
78
- "Please contribute by creating a PR at https://github.com/microsoft/onnxscript."
79
- )
80
+ _PLEASE_CONTRIBUTE = "Please contribute by creating a PR at https://github.com/onnx/onnx-ir."
80
81
  _FUNCTION_VALUE_INFO_SUPPORTED_VERSION = (
81
82
  10 # ONNX IR version where value info in functions was introduced
82
83
  )
84
+ _QUANT_PARAMETER_TENSOR_NAMES_FIELD = "quant_parameter_tensor_names"
83
85
  _T = typing.TypeVar("_T", bound=Callable[..., Any])
84
86
 
85
87
 
@@ -121,16 +123,35 @@ def _unflatten_complex(
121
123
  return array[::2] + 1j * array[1::2]
122
124
 
123
125
 
124
- def from_proto(
125
- proto: onnx.ModelProto
126
- | onnx.GraphProto
127
- | onnx.NodeProto
128
- | onnx.TensorProto
129
- | onnx.AttributeProto
130
- | onnx.ValueInfoProto
131
- | onnx.TypeProto
132
- | onnx.FunctionProto,
133
- ) -> Any:
126
+ @typing.overload
127
+ def from_proto(proto: onnx.ModelProto) -> _core.Model: ... # type: ignore[overload-overlap]
128
+ @typing.overload
129
+ def from_proto(proto: onnx.GraphProto) -> _core.Graph: ... # type: ignore[overload-overlap]
130
+ @typing.overload
131
+ def from_proto(proto: onnx.NodeProto) -> _core.Node: ... # type: ignore[overload-overlap]
132
+ @typing.overload
133
+ def from_proto(proto: onnx.TensorProto) -> _protocols.TensorProtocol: ... # type: ignore[overload-overlap]
134
+ @typing.overload
135
+ def from_proto(proto: onnx.AttributeProto) -> _core.Attr: ... # type: ignore[overload-overlap]
136
+ @typing.overload
137
+ def from_proto(proto: onnx.ValueInfoProto) -> _core.Value: ... # type: ignore[overload-overlap]
138
+ @typing.overload
139
+ def from_proto(proto: onnx.TypeProto) -> _core.TypeAndShape: ... # type: ignore[overload-overlap]
140
+ @typing.overload
141
+ def from_proto(proto: onnx.FunctionProto) -> _core.Function: ... # type: ignore[overload-overlap]
142
+ @typing.overload
143
+ def from_proto(proto: onnx.TensorShapeProto) -> _core.Shape: ... # type: ignore[overload-overlap]
144
+ @typing.overload
145
+ def from_proto( # type: ignore[overload-overlap]
146
+ proto: onnx.TensorShapeProto.Dimension,
147
+ ) -> tuple[int | _core.SymbolicDim, str | None]: ...
148
+ @typing.overload
149
+ def from_proto(proto: Sequence[onnx.OperatorSetIdProto]) -> dict[str, int]: ... # type: ignore[overload-overlap]
150
+ @typing.overload
151
+ def from_proto(proto: Sequence[onnx.StringStringEntryProto]) -> dict[str, str]: ... # type: ignore[overload-overlap]
152
+
153
+
154
+ def from_proto(proto: object) -> object:
134
155
  """Deserialize an ONNX proto message to an IR object."""
135
156
  if isinstance(proto, onnx.ModelProto):
136
157
  return deserialize_model(proto)
@@ -151,24 +172,56 @@ def from_proto(
151
172
  )
152
173
  if isinstance(proto, onnx.FunctionProto):
153
174
  return deserialize_function(proto)
175
+ if isinstance(proto, onnx.TensorShapeProto):
176
+ return deserialize_tensor_shape(proto)
177
+ if isinstance(proto, onnx.TensorShapeProto.Dimension):
178
+ return deserialize_dimension(proto)
179
+ if isinstance(proto, Sequence) and all(
180
+ isinstance(p, onnx.OperatorSetIdProto) for p in proto
181
+ ):
182
+ return deserialize_opset_import(proto)
183
+ if isinstance(proto, Sequence) and all(
184
+ isinstance(p, onnx.StringStringEntryProto) for p in proto
185
+ ):
186
+ return deserialize_metadata_props(proto)
154
187
  raise NotImplementedError(
155
188
  f"Deserialization of {type(proto)} in from_proto is not implemented. "
156
189
  "Use a specific ir.serde.deserialize* function instead."
157
190
  )
158
191
 
159
192
 
160
- def to_proto(
161
- ir_object: _protocols.ModelProtocol
162
- | _protocols.GraphProtocol
163
- | _protocols.NodeProtocol
164
- | _protocols.ValueProtocol
165
- | _protocols.AttributeProtocol
166
- | _protocols.ReferenceAttributeProtocol
167
- | _protocols.TensorProtocol
168
- | _protocols.TypeProtocol
169
- | _protocols.GraphViewProtocol
170
- | _protocols.FunctionProtocol,
171
- ) -> Any:
193
+ def from_onnx_text(model_text: str, /) -> _core.Model:
194
+ """Convert the ONNX textual representation to an IR model.
195
+
196
+ Read more about the textual representation at: https://onnx.ai/onnx/repo-docs/Syntax.html
197
+ """
198
+ proto = onnx.parser.parse_model(model_text)
199
+ return deserialize_model(proto)
200
+
201
+
202
+ @typing.overload
203
+ def to_proto(ir_object: _protocols.ModelProtocol) -> onnx.ModelProto: ... # type: ignore[overload-overlap]
204
+ @typing.overload
205
+ def to_proto(ir_object: _protocols.GraphProtocol) -> onnx.GraphProto: ... # type: ignore[overload-overlap]
206
+ @typing.overload
207
+ def to_proto(ir_object: _protocols.NodeProtocol) -> onnx.NodeProto: ... # type: ignore[overload-overlap]
208
+ @typing.overload
209
+ def to_proto(ir_object: _protocols.TensorProtocol) -> onnx.TensorProto: ... # type: ignore[overload-overlap]
210
+ @typing.overload
211
+ def to_proto(ir_object: _protocols.AttributeProtocol) -> onnx.AttributeProto: ... # type: ignore[overload-overlap]
212
+ @typing.overload
213
+ def to_proto(ir_object: _protocols.ReferenceAttributeProtocol) -> onnx.AttributeProto: ... # type: ignore[overload-overlap]
214
+ @typing.overload
215
+ def to_proto(ir_object: _protocols.ValueProtocol) -> onnx.ValueInfoProto: ... # type: ignore[overload-overlap]
216
+ @typing.overload
217
+ def to_proto(ir_object: _protocols.TypeProtocol) -> onnx.TypeProto: ... # type: ignore[overload-overlap]
218
+ @typing.overload
219
+ def to_proto(ir_object: _protocols.FunctionProtocol) -> onnx.FunctionProto: ... # type: ignore[overload-overlap]
220
+ @typing.overload
221
+ def to_proto(ir_object: _protocols.GraphViewProtocol) -> onnx.GraphProto: ... # type: ignore[overload-overlap]
222
+
223
+
224
+ def to_proto(ir_object: object) -> object:
172
225
  """Serialize an IR object to a proto."""
173
226
  if isinstance(ir_object, _protocols.ModelProtocol):
174
227
  return serialize_model(ir_object)
@@ -180,9 +233,10 @@ def to_proto(
180
233
  return serialize_tensor(ir_object)
181
234
  if isinstance(ir_object, _protocols.ValueProtocol):
182
235
  return serialize_value(ir_object)
183
- if isinstance(ir_object, _protocols.AttributeProtocol):
236
+ if isinstance(ir_object, _protocols.AttributeProtocol) and not ir_object.is_ref():
184
237
  return serialize_attribute(ir_object)
185
238
  if isinstance(ir_object, _protocols.ReferenceAttributeProtocol):
239
+ assert ir_object.is_ref()
186
240
  return serialize_reference_attribute_into(onnx.AttributeProto(), ir_object)
187
241
  if isinstance(ir_object, _protocols.TypeProtocol):
188
242
  return serialize_type_into(onnx.TypeProto(), ir_object)
@@ -199,12 +253,11 @@ def to_proto(
199
253
  class TensorProtoTensor(_core.TensorBase): # pylint: disable=too-many-ancestors
200
254
  """A tensor initialized from a tensor proto."""
201
255
 
256
+ __slots__ = ("_proto",)
257
+
202
258
  def __init__(self, proto: onnx.TensorProto) -> None:
259
+ super().__init__(metadata_props=deserialize_metadata_props(proto.metadata_props))
203
260
  self._proto = proto
204
- self._metadata_props: dict[str, str] | None = deserialize_metadata_props(
205
- proto.metadata_props
206
- )
207
- self._metadata: _metadata.MetadataStore | None = None
208
261
 
209
262
  @property
210
263
  def name(self) -> str:
@@ -225,7 +278,7 @@ class TensorProtoTensor(_core.TensorBase): # pylint: disable=too-many-ancestors
225
278
  def dtype(self) -> _enums.DataType:
226
279
  return _enums.DataType(self._proto.data_type)
227
280
 
228
- @property
281
+ @property # type: ignore[misc]
229
282
  def doc_string(self) -> str:
230
283
  return self._proto.doc_string
231
284
 
@@ -234,9 +287,10 @@ class TensorProtoTensor(_core.TensorBase): # pylint: disable=too-many-ancestors
234
287
  return self._proto
235
288
 
236
289
  def __repr__(self) -> str:
237
- # It is a little hard to display the content when there can be types
238
- # unsupported by numpy
239
- # Preferably we should display some content when the tensor is small
290
+ if self.size <= 10:
291
+ tensor_lines = repr(self.numpy()).split("\n")
292
+ tensor_text = " ".join(line.strip() for line in tensor_lines)
293
+ return f"{self._repr_base()}({tensor_text}, name={self.name!r})"
240
294
  return f"{self._repr_base()}(name={self.name!r})"
241
295
 
242
296
  def __array__(self, dtype: Any = None) -> np.ndarray:
@@ -277,8 +331,7 @@ class TensorProtoTensor(_core.TensorBase): # pylint: disable=too-many-ancestors
277
331
  raise ValueError("Cannot convert UNDEFINED tensor to numpy array.")
278
332
  if self._proto.data_location == onnx.TensorProto.EXTERNAL:
279
333
  raise ValueError(
280
- "Cannot convert external tensor to numpy array. "
281
- "Use ir.ExternalTensor instead."
334
+ "Cannot convert external tensor to numpy array. Use ir.ExternalTensor instead."
282
335
  )
283
336
 
284
337
  if self._proto.HasField("raw_data"):
@@ -323,6 +376,8 @@ class TensorProtoTensor(_core.TensorBase): # pylint: disable=too-many-ancestors
323
376
  return _type_casting.unpack_int4(array.astype(np.uint8), self._proto.dims)
324
377
  elif dtype == _enums.DataType.UINT4:
325
378
  return _type_casting.unpack_uint4(array.astype(np.uint8), self._proto.dims)
379
+ elif dtype == _enums.DataType.FLOAT4E2M1:
380
+ return _type_casting.unpack_float4e2m1(array.astype(np.uint8), self._proto.dims)
326
381
  else:
327
382
  # Otherwise convert to the correct dtype and reshape
328
383
  # Note we cannot use view() here because the storage dtype may not be the same size as the target
@@ -369,6 +424,7 @@ class TensorProtoTensor(_core.TensorBase): # pylint: disable=too-many-ancestors
369
424
  _enums.DataType.FLOAT8E5M2FNUZ,
370
425
  _enums.DataType.INT4,
371
426
  _enums.DataType.UINT4,
427
+ _enums.DataType.FLOAT4E2M1,
372
428
  }:
373
429
  # uint4 and int4 values are already packed, even when stored as int32
374
430
  # so we don't need to pack them again
@@ -393,23 +449,6 @@ class TensorProtoTensor(_core.TensorBase): # pylint: disable=too-many-ancestors
393
449
  # For example, int32_data can be empty and still be a valid tensor.
394
450
  return b""
395
451
 
396
- @property
397
- def meta(self) -> _metadata.MetadataStore:
398
- """The metadata store for intermediate analysis.
399
-
400
- Write to the :attr:`metadata_props` if you would like the metadata to be serialized
401
- to the ONNX proto.
402
- """
403
- if self._metadata is None:
404
- self._metadata = _metadata.MetadataStore()
405
- return self._metadata
406
-
407
- @property
408
- def metadata_props(self) -> dict[str, str]:
409
- if self._metadata_props is None:
410
- self._metadata_props = {}
411
- return self._metadata_props
412
-
413
452
 
414
453
  def _get_field(proto: Any, field: str) -> Any:
415
454
  if proto.HasField(field):
@@ -472,7 +511,7 @@ def deserialize_model(proto: onnx.ModelProto) -> _core.Model:
472
511
  model_version=_get_field(proto, "model_version"),
473
512
  doc_string=_get_field(proto, "doc_string"),
474
513
  functions=functions,
475
- meta_data_props=deserialize_metadata_props(proto.metadata_props),
514
+ metadata_props=deserialize_metadata_props(proto.metadata_props),
476
515
  )
477
516
 
478
517
  # Handle experimental value info for functions created by the dynamo exporter in IR version 9
@@ -541,6 +580,9 @@ def deserialize_graph(proto: onnx.GraphProto) -> _core.Graph:
541
580
 
542
581
  Returns:
543
582
  IR Graph.
583
+
584
+ .. versionadded:: 0.3
585
+ Support for *quantization_annotation* is added.
544
586
  """
545
587
  return _deserialize_graph(proto, [])
546
588
 
@@ -561,44 +603,89 @@ def _deserialize_graph(
561
603
  Returns:
562
604
  IR Graph.
563
605
  """
606
+ # Process TensorAnnotation for quantization
607
+ quantization_annotations = {
608
+ annotation.tensor_name: annotation for annotation in proto.quantization_annotation
609
+ }
610
+
564
611
  # Create values for initializers and inputs
565
612
  initializer_tensors = [deserialize_tensor(tensor) for tensor in proto.initializer]
566
613
  inputs = [_core.Input(info.name) for info in proto.input]
567
614
  for info, value in zip(proto.input, inputs):
568
615
  deserialize_value_info_proto(info, value)
569
616
 
617
+ # Add TensorAnnotation for inputs if they exist
618
+ if value.name in quantization_annotations:
619
+ _deserialize_quantization_annotation(quantization_annotations[value.name], value)
620
+
570
621
  # Initialize the values dictionary for this graph scope with the inputs and initializers
571
622
  values: dict[str, _core.Value] = {v.name: v for v in inputs} # type: ignore[misc]
623
+
624
+ # Enter the graph scope by pushing the values for this scope to the stack
572
625
  scoped_values.append(values)
626
+
573
627
  initializer_values = []
574
- for tensor in initializer_tensors:
575
- if tensor.name in values:
628
+ for i, tensor in enumerate(initializer_tensors):
629
+ initializer_name = tensor.name
630
+ if not initializer_name:
631
+ logger.warning(
632
+ "Initializer tensor must have a name but the %s-th initializer does not. Skipping this initializer.",
633
+ i,
634
+ )
635
+ continue
636
+ if initializer_name in values:
576
637
  # The initializer is for an input
577
- initializer_value = values[tensor.name]
638
+ initializer_value = values[initializer_name]
578
639
  initializer_value.const_value = tensor
579
640
  else:
580
641
  # The initializer is for some other value. Create this value first
581
642
  initializer_value = _core.Value(
582
643
  None,
583
644
  index=None,
584
- name=tensor.name,
585
- # TODO(justinchuby): Fix type hinting for shape and dtype
586
- shape=tensor.shape, # type: ignore
645
+ name=initializer_name,
646
+ # Include shape and type even if the shape or type is not provided as ValueInfoProto.
647
+ # Users expect initialized values to have shape and type information.
587
648
  type=_core.TensorType(tensor.dtype),
649
+ shape=tensor.shape, # type: ignore[arg-type]
588
650
  const_value=tensor,
589
651
  )
590
- values[tensor.name] = initializer_value # type: ignore[index]
652
+ if initializer_value.name in quantization_annotations:
653
+ _deserialize_quantization_annotation(
654
+ quantization_annotations[initializer_value.name], initializer_value
655
+ )
656
+ values[initializer_name] = initializer_value
591
657
  initializer_values.append(initializer_value)
592
658
 
593
- # Add ValueInfos for this graph scope
659
+ # Build the value info dictionary to allow for quick lookup for this graph scope
594
660
  value_info = {info.name: info for info in proto.value_info}
595
661
 
596
662
  # Deserialize nodes with all known values
597
- nodes = [_deserialize_node(node, scoped_values, value_info) for node in proto.node]
663
+ nodes = [
664
+ _deserialize_node(node, scoped_values, value_info, quantization_annotations)
665
+ for node in proto.node
666
+ ]
598
667
 
599
- # Fill in values for graph outputs
600
- outputs = [deserialize_value_info_proto(info, values[info.name]) for info in proto.output]
668
+ outputs = []
669
+ for info in proto.output:
670
+ # Fill in values for graph outputs
671
+ output_name = info.name
672
+ if output_name not in values:
673
+ # Handle (invalid) graph outputs that do not have any producers
674
+ logger.warning(
675
+ "Output '%s' is not produced by any node. The graph has an invalid output",
676
+ output_name,
677
+ )
678
+ value = _core.Value(name=output_name)
679
+ else:
680
+ # A valid, normal graph output
681
+ value = values[output_name]
682
+ # Fill in shape/type information
683
+ deserialize_value_info_proto(info, value)
684
+ outputs.append(value)
685
+
686
+ # Exit the graph scope by popping the values for this scope from the stack
601
687
  scoped_values.pop()
688
+
602
689
  return _core.Graph(
603
690
  inputs,
604
691
  outputs,
@@ -617,7 +704,10 @@ def deserialize_function(proto: onnx.FunctionProto) -> _core.Function:
617
704
  value_info = {info.name: info for info in getattr(proto, "value_info", [])}
618
705
 
619
706
  # TODO(justinchuby): Handle unsorted nodes
620
- nodes = [_deserialize_node(node, [values], value_info=value_info) for node in proto.node]
707
+ nodes = [
708
+ _deserialize_node(node, [values], value_info=value_info, quantization_annotations={})
709
+ for node in proto.node
710
+ ]
621
711
  outputs = [values[name] for name in proto.output]
622
712
  graph = _core.Graph(
623
713
  inputs,
@@ -631,6 +721,7 @@ def deserialize_function(proto: onnx.FunctionProto) -> _core.Function:
631
721
  if hasattr(proto, "overload") and proto.overload
632
722
  else ""
633
723
  ),
724
+ metadata_props=deserialize_metadata_props(proto.metadata_props),
634
725
  )
635
726
  attributes = [_deserialize_attribute(attr, []) for attr in proto.attribute_proto]
636
727
  # Attributes without defaults
@@ -642,8 +733,7 @@ def deserialize_function(proto: onnx.FunctionProto) -> _core.Function:
642
733
  name=proto.name,
643
734
  overload=getattr(proto, "overload", ""),
644
735
  graph=graph,
645
- attributes=typing.cast(List[_core.Attr], attributes),
646
- metadata_props=deserialize_metadata_props(proto.metadata_props),
736
+ attributes=attributes,
647
737
  )
648
738
 
649
739
 
@@ -662,29 +752,41 @@ def deserialize_value_info_proto(
662
752
  return value
663
753
 
664
754
 
755
+ @_capture_errors(lambda proto, value: str(proto))
756
+ def _deserialize_quantization_annotation(
757
+ proto: onnx.TensorAnnotation, value: _core.Value
758
+ ) -> None:
759
+ """Deserialize a quantization_annotation as TensorAnnotation into a Value.
760
+
761
+ This function is marked private because we don't expect users to call it directly.
762
+ """
763
+ value.meta[_QUANT_PARAMETER_TENSOR_NAMES_FIELD] = _deserialize_string_string_maps(
764
+ proto.quant_parameter_tensor_names
765
+ )
766
+
767
+
768
+ @_capture_errors(str)
769
+ def deserialize_tensor_shape(proto: onnx.TensorShapeProto) -> _core.Shape:
770
+ # This logic handles when the shape is [] as well
771
+ dim_protos = proto.dim
772
+ deserialized_dim_denotations = [
773
+ deserialize_dimension(dim_proto) for dim_proto in dim_protos
774
+ ]
775
+ dims = [dim for dim, _ in deserialized_dim_denotations]
776
+ denotations = [denotation for _, denotation in deserialized_dim_denotations]
777
+ return _core.Shape(dims, denotations=denotations, frozen=True)
778
+
779
+
665
780
  @_capture_errors(str)
666
781
  def deserialize_type_proto_for_shape(proto: onnx.TypeProto) -> _core.Shape | None:
667
782
  if proto.HasField("tensor_type"):
668
783
  if (shape_proto := _get_field(proto.tensor_type, "shape")) is None:
669
784
  return None
670
- # This logic handles when the shape is [] as well
671
- dim_protos = shape_proto.dim
672
- deserialized_dim_denotations = [
673
- deserialize_dimension(dim_proto) for dim_proto in dim_protos
674
- ]
675
- dims = [dim for dim, _ in deserialized_dim_denotations]
676
- denotations = [denotation for _, denotation in deserialized_dim_denotations]
677
- return _core.Shape(dims, denotations=denotations, frozen=True)
785
+ return deserialize_tensor_shape(shape_proto)
678
786
  if proto.HasField("sparse_tensor_type"):
679
787
  if (shape_proto := _get_field(proto.sparse_tensor_type, "shape")) is None:
680
788
  return None
681
- dim_protos = shape_proto.dim
682
- deserialized_dim_denotations = [
683
- deserialize_dimension(dim_proto) for dim_proto in dim_protos
684
- ]
685
- dims = [dim for dim, _ in deserialized_dim_denotations]
686
- denotations = [denotation for _, denotation in deserialized_dim_denotations]
687
- return _core.Shape(dims, denotations=denotations, frozen=True)
789
+ return deserialize_tensor_shape(shape_proto)
688
790
  if proto.HasField("sequence_type"):
689
791
  if (elem_type := _get_field(proto.sequence_type, "elem_type")) is None:
690
792
  return None
@@ -800,14 +902,17 @@ def deserialize_metadata_props(
800
902
  return {entry.key: entry.value for entry in proto}
801
903
 
802
904
 
803
- def deserialize_attribute(proto: onnx.AttributeProto) -> _core.Attr | _core.RefAttr:
905
+ _deserialize_string_string_maps = deserialize_metadata_props
906
+
907
+
908
+ def deserialize_attribute(proto: onnx.AttributeProto) -> _core.Attr:
804
909
  return _deserialize_attribute(proto, [])
805
910
 
806
911
 
807
912
  @_capture_errors(lambda proto, scoped_values: str(proto))
808
913
  def _deserialize_attribute(
809
914
  proto: onnx.AttributeProto, scoped_values: list[dict[str, _core.Value]]
810
- ) -> _core.Attr | _core.RefAttr:
915
+ ) -> _core.Attr:
811
916
  name = proto.name
812
917
  doc_string = _get_field(proto, "doc_string")
813
918
  type_ = _enums.AttributeType(proto.type)
@@ -874,14 +979,17 @@ def _deserialize_attribute(
874
979
 
875
980
 
876
981
  def deserialize_node(proto: onnx.NodeProto) -> _core.Node:
877
- return _deserialize_node(proto, scoped_values=[], value_info={})
982
+ return _deserialize_node(
983
+ proto, scoped_values=[{}], value_info={}, quantization_annotations={}
984
+ )
878
985
 
879
986
 
880
- @_capture_errors(lambda proto, scoped_values, value_info: str(proto))
987
+ @_capture_errors(lambda proto, scoped_values, value_info, quantization_annotations: str(proto))
881
988
  def _deserialize_node(
882
989
  proto: onnx.NodeProto,
883
990
  scoped_values: list[dict[str, _core.Value]],
884
991
  value_info: dict[str, onnx.ValueInfoProto],
992
+ quantization_annotations: dict[str, onnx.TensorAnnotation],
885
993
  ) -> _core.Node:
886
994
  node_inputs: list[_core.Value | None] = []
887
995
  for input_name in proto.input:
@@ -924,6 +1032,10 @@ def _deserialize_node(
924
1032
  # Fill in shape/type information if they exist
925
1033
  if input_name in value_info:
926
1034
  deserialize_value_info_proto(value_info[input_name], value)
1035
+ if input_name in quantization_annotations:
1036
+ _deserialize_quantization_annotation(
1037
+ quantization_annotations[input_name], value
1038
+ )
927
1039
  node_inputs.append(value)
928
1040
  # We can only create the value in the current scope. If the subgraph is
929
1041
  # referencing a value that is not in the current scope, it is impossible
@@ -965,6 +1077,8 @@ def _deserialize_node(
965
1077
  proto.name,
966
1078
  proto.op_type,
967
1079
  )
1080
+ if output_name in quantization_annotations:
1081
+ _deserialize_quantization_annotation(quantization_annotations[output_name], value)
968
1082
  node_outputs.append(value)
969
1083
  return _core.Node(
970
1084
  proto.domain,
@@ -1036,7 +1150,12 @@ def _should_create_value_info_for_value(value: _protocols.ValueProtocol) -> bool
1036
1150
  True if value info should be created for the value.
1037
1151
  """
1038
1152
  # No need to serialize value info if it is not set
1039
- return not (value.shape is None and value.type is None)
1153
+ if value.shape is None and value.type is None:
1154
+ return False
1155
+ if not value.name:
1156
+ logger.debug("Did not serialize '%s' because its name is empty", value)
1157
+ return False
1158
+ return True
1040
1159
 
1041
1160
 
1042
1161
  def _serialize_experimental_value_info_for_function_ir9_into(
@@ -1063,7 +1182,7 @@ def _serialize_experimental_value_info_for_function_ir9_into(
1063
1182
 
1064
1183
  for input in function.inputs:
1065
1184
  if not input.name:
1066
- logging.warning(
1185
+ logger.warning(
1067
1186
  "Function '%s': Value name not set for function input: %s",
1068
1187
  function_qualified_name,
1069
1188
  input,
@@ -1076,7 +1195,7 @@ def _serialize_experimental_value_info_for_function_ir9_into(
1076
1195
  for node in function:
1077
1196
  for node_output in node.outputs:
1078
1197
  if not node_output.name:
1079
- logging.warning(
1198
+ logger.warning(
1080
1199
  "Function '%s': Value name not set for node output: %s",
1081
1200
  function_qualified_name,
1082
1201
  node_output,
@@ -1107,23 +1226,46 @@ def _serialize_opset_imports_into(
1107
1226
  opset_ids.add(domain=domain, version=version)
1108
1227
 
1109
1228
 
1110
- def _serialize_metadata_props_into(
1229
+ def _serialize_string_string_maps(
1111
1230
  string_string_entries: proto_containers.RepeatedCompositeFieldContainer[
1112
1231
  onnx.StringStringEntryProto
1113
1232
  ],
1114
1233
  from_: Mapping[str, str],
1115
1234
  ) -> None:
1116
- """Serialize metadata properties into a repeated field of string-string entries.
1235
+ """Serialize a <str, str> mapping into a repeated field of string-string entries.
1117
1236
 
1118
1237
  Args:
1119
1238
  string_string_entries: The repeated field to serialize into.
1120
- from_: The mapping of metadata properties to serialize.
1239
+ from_: The mapping of a <str, str> mapping to serialize.
1121
1240
  """
1122
1241
  # Sort names for deterministic serialization
1123
1242
  for key in sorted(from_):
1124
1243
  string_string_entries.add(key=key, value=from_[key])
1125
1244
 
1126
1245
 
1246
+ _serialize_metadata_props_into = _serialize_string_string_maps
1247
+
1248
+
1249
+ def _maybe_add_quantization_annotation(
1250
+ graph_proto: onnx.GraphProto, value: _protocols.ValueProtocol
1251
+ ) -> None:
1252
+ if quantization_annotation := value.meta.get(_QUANT_PARAMETER_TENSOR_NAMES_FIELD):
1253
+ _serialize_tensor_annotation_into(
1254
+ graph_proto.quantization_annotation.add(), value.name, quantization_annotation
1255
+ )
1256
+
1257
+
1258
+ def _serialize_tensor_annotation_into(
1259
+ tensor_annotation_proto: onnx.TensorAnnotation,
1260
+ tensor_name: str,
1261
+ quant_parameter_tensor_names: dict[str, str],
1262
+ ) -> None:
1263
+ tensor_annotation_proto.tensor_name = tensor_name
1264
+ _serialize_string_string_maps(
1265
+ tensor_annotation_proto.quant_parameter_tensor_names, quant_parameter_tensor_names
1266
+ )
1267
+
1268
+
1127
1269
  def serialize_graph(
1128
1270
  graph: _protocols.GraphProtocol | _protocols.GraphViewProtocol,
1129
1271
  ) -> onnx.GraphProto:
@@ -1159,29 +1301,41 @@ def serialize_graph_into(
1159
1301
  graph_proto.doc_string = from_.doc_string
1160
1302
  for input_ in from_.inputs:
1161
1303
  serialize_value_into(graph_proto.input.add(), input_)
1304
+ if input_.name not in from_.initializers:
1305
+ # Annotations for initializers will be added below to avoid double adding
1306
+ # TODO(justinchuby): We should add a method is_initializer() on Value when
1307
+ # the initializer list is tracked
1308
+ _maybe_add_quantization_annotation(graph_proto, input_)
1309
+ input_names = {input_.name for input_ in from_.inputs}
1162
1310
  # TODO(justinchuby): Support sparse_initializer
1163
- for initializer in from_.initializers.values():
1164
- if initializer.const_value is None:
1311
+ for value in from_.initializers.values():
1312
+ _maybe_add_quantization_annotation(graph_proto, value)
1313
+ if _should_create_value_info_for_value(value) and value.name not in input_names:
1314
+ # Serialize information about all initializers into value_info,
1315
+ # except for those that are also graph inputs
1316
+ serialize_value_into(graph_proto.value_info.add(), value)
1317
+ if value.const_value is None:
1165
1318
  # Skip initializers without constant values
1166
- logger.warning(
1167
- "Initializer '%s' does not have a constant value set.", initializer.name
1168
- )
1319
+ logger.warning("Initializer '%s' does not have a constant value set.", value.name)
1169
1320
  continue
1170
1321
  # Make sure the tensor's name is the same as the value's name
1171
- initializer.const_value.name = initializer.name
1172
- serialize_tensor_into(graph_proto.initializer.add(), from_=initializer.const_value)
1322
+ value.const_value.name = value.name
1323
+ serialize_tensor_into(graph_proto.initializer.add(), from_=value.const_value)
1173
1324
  for node in from_:
1174
1325
  serialize_node_into(graph_proto.node.add(), from_=node)
1175
1326
  for node_output in node.outputs:
1176
- if not _should_create_value_info_for_value(node_output):
1177
- # No need to serialize value info if it is not set
1178
- continue
1179
1327
  if node_output.is_graph_output():
1180
- # No need to serialize value info for these outputs because they are also graph outputs
1328
+ # No need to serialize info for these outputs because they are handled as graph outputs
1181
1329
  continue
1182
- serialize_value_into(graph_proto.value_info.add(), node_output)
1330
+ _maybe_add_quantization_annotation(graph_proto, node_output)
1331
+ if not _should_create_value_info_for_value(node_output): # pylint: disable=no-else-continue
1332
+ # No need to serialize value info if it is not set
1333
+ continue
1334
+ else:
1335
+ serialize_value_into(graph_proto.value_info.add(), node_output)
1183
1336
  for output in from_.outputs:
1184
1337
  serialize_value_into(graph_proto.output.add(), from_=output)
1338
+ _maybe_add_quantization_annotation(graph_proto, output)
1185
1339
  if from_.metadata_props:
1186
1340
  _serialize_metadata_props_into(graph_proto.metadata_props, from_.metadata_props)
1187
1341
 
@@ -1269,6 +1423,23 @@ def serialize_node(node: _protocols.NodeProtocol) -> onnx.NodeProto:
1269
1423
  return node_proto
1270
1424
 
1271
1425
 
1426
+ def _remove_trailing_outputs(
1427
+ outputs: Sequence[_protocols.ValueProtocol],
1428
+ ) -> Sequence[_protocols.ValueProtocol]:
1429
+ """Remove trailing outputs that have empty names.
1430
+
1431
+ Args:
1432
+ outputs: The outputs to remove trailing outputs from.
1433
+
1434
+ Returns:
1435
+ The outputs with trailing outputs removed.
1436
+ """
1437
+ for i, output in enumerate(reversed(outputs)):
1438
+ if output.name:
1439
+ return outputs[: len(outputs) - i]
1440
+ return []
1441
+
1442
+
1272
1443
  @_capture_errors(lambda node_proto, from_: repr(from_))
1273
1444
  def serialize_node_into(node_proto: onnx.NodeProto, from_: _protocols.NodeProtocol) -> None:
1274
1445
  node_proto.op_type = from_.op_type
@@ -1288,23 +1459,16 @@ def serialize_node_into(node_proto: onnx.NodeProto, from_: _protocols.NodeProtoc
1288
1459
  node_proto.input.append("")
1289
1460
  else:
1290
1461
  node_proto.input.append(input_.name)
1291
- for output in from_.outputs:
1462
+
1463
+ # Do not include the trailing outputs that have empty names
1464
+ for output in _remove_trailing_outputs(from_.outputs):
1292
1465
  node_proto.output.append(output.name)
1466
+
1293
1467
  for attr in from_.attributes.values():
1294
- if isinstance(attr, _core.Attr):
1295
- serialize_attribute_into(node_proto.attribute.add(), from_=attr)
1296
- elif isinstance(attr, _core.RefAttr):
1297
- serialize_reference_attribute_into(node_proto.attribute.add(), from_=attr)
1298
- # Handle protocol attributes for completeness. We do not check them first because
1299
- # calling isinstance on a protocol can be slow.
1300
- # Most of the time, we will have Attr or RefAttr so the two branches below
1301
- # will not be taken.
1302
- elif isinstance(attr, _protocols.AttributeProtocol):
1303
- serialize_attribute_into(node_proto.attribute.add(), from_=attr)
1304
- elif isinstance(attr, _protocols.ReferenceAttributeProtocol):
1305
- serialize_reference_attribute_into(node_proto.attribute.add(), from_=attr)
1468
+ if not attr.is_ref():
1469
+ serialize_attribute_into(node_proto.attribute.add(), from_=attr) # type: ignore[arg-type]
1306
1470
  else:
1307
- raise TypeError(f"Unsupported attribute type: {type(attr)}")
1471
+ serialize_reference_attribute_into(node_proto.attribute.add(), from_=attr) # type: ignore[arg-type]
1308
1472
 
1309
1473
 
1310
1474
  def serialize_tensor(tensor: _protocols.TensorProtocol) -> onnx.TensorProto: