onnx-ir 0.0.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of onnx-ir might be problematic. Click here for more details.

onnx_ir/serde.py ADDED
@@ -0,0 +1,1551 @@
1
+ # Copyright (c) Microsoft Corporation.
2
+ # Licensed under the MIT License.
3
+ """Serialize and deserialize the intermediate representation to/from ONNX protos."""
4
+
5
+ # NOTES for developers:
6
+ # NOTE: Do not import pathlib in the IR. It is slow. Use os.path methods instead.
7
+ #
8
+ # NOTE: Protobuf serialization
9
+ # Initializing a protobuf message with initialized protobuf messages incurs
10
+ # a copy and is slow. Instead, use proto.add() to add to a repeated field,
11
+ # or initialize the message first and then set the fields if the fields are
12
+ # plain Python objects.
13
+
14
+ from __future__ import annotations
15
+
16
+ import functools
17
+
18
+ __all__ = [
19
+ # Tensors
20
+ "TensorProtoTensor",
21
+ # Deserialization
22
+ "from_proto",
23
+ "deserialize_attribute",
24
+ "deserialize_dimension",
25
+ "deserialize_function",
26
+ "deserialize_graph",
27
+ "deserialize_metadata_props",
28
+ "deserialize_model",
29
+ "deserialize_node",
30
+ "deserialize_opset_import",
31
+ "deserialize_tensor",
32
+ "deserialize_type_proto_for_shape",
33
+ "deserialize_type_proto_for_type",
34
+ "deserialize_value_info_proto",
35
+ # Serialization
36
+ "to_proto",
37
+ "serialize_attribute_into",
38
+ "serialize_attribute",
39
+ "serialize_dimension_into",
40
+ "serialize_function_into",
41
+ "serialize_function",
42
+ "serialize_graph_into",
43
+ "serialize_graph",
44
+ "serialize_model_into",
45
+ "serialize_model",
46
+ "serialize_node_into",
47
+ "serialize_node",
48
+ "serialize_shape_into",
49
+ "serialize_reference_attribute_into",
50
+ "serialize_tensor_into",
51
+ "serialize_tensor",
52
+ "serialize_type_into",
53
+ "serialize_type",
54
+ "serialize_value_into",
55
+ "serialize_value",
56
+ "SerdeError",
57
+ ]
58
+
59
+ import collections
60
+ import logging
61
+ import os
62
+ import typing
63
+ from typing import Any, Callable, List, Mapping, Sequence
64
+
65
+ import numpy as np
66
+ import onnx
67
+ import onnx.external_data_helper
68
+
69
+ from onnx_ir import _core, _enums, _metadata, _protocols, _type_casting
70
+
71
+ if typing.TYPE_CHECKING:
72
+ import google.protobuf.internal.containers as proto_containers
73
+ import numpy.typing as npt
74
+
75
# Module-level logger, named after this module.
logger = logging.getLogger(__name__)

# Suffix appended to "not supported yet" error messages raised in this module.
_PLEASE_CONTRIBUTE = "Please contribute by creating a PR at https://github.com/microsoft/onnxscript."
# ONNX IR version where value info in functions was introduced
_FUNCTION_VALUE_INFO_SUPPORTED_VERSION = 10
# Type variable for any callable; lets decorators declare that they return
# the same callable type they receive.
_T = typing.TypeVar("_T", bound=Callable[..., Any])
84
+
85
+
86
class SerdeError(RuntimeError):
    """Raised when serializing or deserializing between the IR and ONNX protos fails."""
88
+
89
+
90
def _capture_errors(arg_capturer: Callable[..., str]) -> Callable[[_T], _T]:
    """Build a decorator that re-raises any exception as a :class:`SerdeError`.

    Args:
        arg_capturer: Called with the same positional and keyword arguments as
            the decorated function when that function raises. Its result is
            embedded in the error message to identify the failing input.

    Returns:
        A decorator that wraps a function; the original exception is chained
        as the cause of the raised ``SerdeError``.
    """

    def decorator(func: _T) -> _T:
        @functools.wraps(func)
        def wrapper(*args: Any, **kwargs: Any) -> Any:
            try:
                result = func(*args, **kwargs)
            except Exception as error:
                # Only describe the arguments on the failure path
                captured = arg_capturer(*args, **kwargs)
                raise SerdeError(
                    f"Error calling {func.__name__} with: {captured}"
                ) from error
            return result

        return wrapper  # type: ignore

    return decorator
106
+
107
+
108
def _little_endian_dtype(dtype) -> np.dtype:
    """Return ``dtype`` as a numpy dtype with little endian byte order.

    ONNX always stores raw_data in little endian, so the data must be
    interpreted as little endian on every platform, including big endian ones.
    """
    base_dtype = np.dtype(dtype)
    return base_dtype.newbyteorder("<")
115
+
116
+
117
def _unflatten_complex(
    array: npt.NDArray[np.float32 | np.float64],
) -> npt.NDArray[np.complex64 | np.complex128]:
    """Combine interleaved (real, imaginary) values into complex numbers.

    Even positions of ``array`` supply the real parts and odd positions the
    imaginary parts.
    """
    real_parts = array[::2]
    imaginary_parts = array[1::2]
    return real_parts + 1j * imaginary_parts
122
+
123
+
124
def from_proto(
    proto: onnx.ModelProto
    | onnx.GraphProto
    | onnx.NodeProto
    | onnx.TensorProto
    | onnx.AttributeProto
    | onnx.ValueInfoProto
    | onnx.TypeProto
    | onnx.FunctionProto,
) -> Any:
    """Deserialize an ONNX proto message to an IR object.

    The proto is dispatched on its message class to the matching
    ``deserialize_*`` function in this module.

    Raises:
        NotImplementedError: If the proto class is not one of the handled types.
    """
    # Checked in order; the first matching proto class wins.
    dispatch = (
        (onnx.ModelProto, deserialize_model),
        (onnx.GraphProto, deserialize_graph),
        (onnx.NodeProto, deserialize_node),
        (onnx.TensorProto, deserialize_tensor),
        (onnx.AttributeProto, deserialize_attribute),
        (onnx.ValueInfoProto, lambda p: deserialize_value_info_proto(p, None)),
        (
            onnx.TypeProto,
            lambda p: _core.TypeAndShape(
                deserialize_type_proto_for_type(p),
                deserialize_type_proto_for_shape(p),
            ),
        ),
        (onnx.FunctionProto, deserialize_function),
    )
    for proto_class, deserializer in dispatch:
        if isinstance(proto, proto_class):
            return deserializer(proto)
    raise NotImplementedError(
        f"Deserialization of {type(proto)} in from_proto is not implemented. "
        "Use a specific ir.serde.deserialize* function instead."
    )
158
+
159
+
160
def to_proto(
    ir_object: _protocols.ModelProtocol
    | _protocols.GraphProtocol
    | _protocols.NodeProtocol
    | _protocols.ValueProtocol
    | _protocols.AttributeProtocol
    | _protocols.ReferenceAttributeProtocol
    | _protocols.TensorProtocol
    | _protocols.TypeProtocol
    | _protocols.GraphViewProtocol
    | _protocols.FunctionProtocol,
) -> Any:
    """Serialize an IR object to a proto.

    The object is tested against each protocol in a fixed order and handed to
    the serializer of the first protocol it satisfies.

    Raises:
        NotImplementedError: If the object matches none of the handled protocols.
    """
    # The order of the checks is significant and must be preserved.
    dispatch = (
        (_protocols.ModelProtocol, serialize_model),
        (_protocols.GraphProtocol, serialize_graph),
        (_protocols.NodeProtocol, serialize_node),
        (_protocols.TensorProtocol, serialize_tensor),
        (_protocols.ValueProtocol, serialize_value),
        (_protocols.AttributeProtocol, serialize_attribute),
        (
            _protocols.ReferenceAttributeProtocol,
            lambda obj: serialize_reference_attribute_into(onnx.AttributeProto(), obj),
        ),
        (
            _protocols.TypeProtocol,
            lambda obj: serialize_type_into(onnx.TypeProto(), obj),
        ),
        (_protocols.GraphViewProtocol, serialize_graph),
        (_protocols.FunctionProtocol, serialize_function),
    )
    for protocol, serializer in dispatch:
        if isinstance(ir_object, protocol):
            return serializer(ir_object)
    raise NotImplementedError(
        f"Serialization of {type(ir_object)} in to_proto is not implemented. "
        "Use a specific ir.serde.serialize* function instead."
    )
197
+
198
+
199
class TensorProtoTensor(_core.TensorBase):  # pylint: disable=too-many-ancestors
    """A tensor initialized from a tensor proto.

    The proto is kept by reference; its data is only converted when
    :meth:`numpy` or :meth:`tobytes` is called.
    """

    def __init__(self, proto: onnx.TensorProto) -> None:
        self._proto = proto
        # None when the proto carries no metadata; created lazily on access
        self._metadata_props: dict[str, str] | None = deserialize_metadata_props(
            proto.metadata_props
        )
        self._metadata: _metadata.MetadataStore | None = None

    @property
    def name(self) -> str:
        """The name stored in the underlying proto."""
        return self._proto.name

    @name.setter
    def name(self, value: str | None) -> None:
        # Setting None clears the field on the proto instead of storing a value
        if value is None:
            self._proto.ClearField("name")
        else:
            self._proto.name = value

    @property
    def shape(self) -> _core.Shape:
        """A frozen shape built from the proto's ``dims``."""
        return _core.Shape(self._proto.dims, frozen=True)

    @property
    def dtype(self) -> _enums.DataType:
        """The data type declared by the proto's ``data_type`` field."""
        return _enums.DataType(self._proto.data_type)

    @property
    def doc_string(self) -> str:
        """The doc string stored in the underlying proto."""
        return self._proto.doc_string

    @property
    def raw(self) -> onnx.TensorProto:
        """The underlying tensor proto."""
        return self._proto

    def __repr__(self) -> str:
        # It is a little hard to display the content when there can be types
        # unsupported by numpy
        # Preferably we should display some content when the tensor is small
        return f"{self._repr_base()}(name={self.name!r})"

    def __array__(self, dtype: Any = None) -> np.ndarray:
        """Return the tensor as a numpy array, compatible with np.array."""
        return self.numpy().__array__(dtype)

    def __dlpack__(self, *, stream: Any = None) -> Any:
        return self.numpy().__dlpack__(stream=stream)

    def __dlpack_device__(self) -> tuple[int, int]:
        return self.numpy().__dlpack_device__()

    def numpy(self) -> np.ndarray:
        """Return the tensor as a numpy array.

        This is an improved version of onnx.numpy_helper.to_array.
        It first reads the data using the dtype corresponding to the tensor
        proto data field, then converts it to the correct dtype and shape.
        Special cases are bfloat16, complex and int4 where we need to
        reinterpret the data. Other types can simply be casted.

        When the data type is not supported by numpy, the dtypes from the ``ml_dtype``
        package are used. The values can be reinterpreted as bit representations
        using the ``.view()`` method.

        When the data type is a string, this method returns a numpy array
        of bytes instead of a numpy array of strings, to follow the ONNX
        specification.

        External tensors are not supported by this class. Use
        :class:`onnx_ir.ExternalTensor` instead.

        Raises:
            ValueError: If the data type is UNDEFINED.
            ValueError: If the tensor data is stored externally.
        """
        dtype = self.dtype
        if dtype == _enums.DataType.UNDEFINED:
            raise ValueError("Cannot convert UNDEFINED tensor to numpy array.")
        if self._proto.data_location == onnx.TensorProto.EXTERNAL:
            raise ValueError(
                "Cannot convert external tensor to numpy array. "
                "Use ir.ExternalTensor instead."
            )

        if self._proto.HasField("raw_data"):
            array = np.frombuffer(self._proto.raw_data, dtype=dtype.numpy().newbyteorder("<"))
            # Cannot return now, because we may need to unpack 4bit tensors
        elif dtype == _enums.DataType.STRING:
            return np.array(self._proto.string_data).reshape(self._proto.dims)
        elif self._proto.int32_data:
            array = np.array(self._proto.int32_data, dtype=_little_endian_dtype(np.int32))
            if dtype in {_enums.DataType.FLOAT16, _enums.DataType.BFLOAT16}:
                # Reinterpret the int32 as float16 or bfloat16
                array = array.astype(np.uint16).view(dtype.numpy())
            elif dtype in {
                _enums.DataType.FLOAT8E4M3FN,
                _enums.DataType.FLOAT8E4M3FNUZ,
                _enums.DataType.FLOAT8E5M2,
                _enums.DataType.FLOAT8E5M2FNUZ,
            }:
                # Reinterpret the low 8 bits of each int32 as a float8 value
                array = array.astype(np.uint8).view(dtype.numpy())
        elif self._proto.int64_data:
            array = np.array(self._proto.int64_data, dtype=_little_endian_dtype(np.int64))
        elif self._proto.uint64_data:
            array = np.array(self._proto.uint64_data, dtype=_little_endian_dtype(np.uint64))
        elif self._proto.float_data:
            array = np.array(self._proto.float_data, dtype=_little_endian_dtype(np.float32))
            if dtype == _enums.DataType.COMPLEX64:
                array = _unflatten_complex(array)
        elif self._proto.double_data:
            array = np.array(self._proto.double_data, dtype=_little_endian_dtype(np.float64))
            if dtype == _enums.DataType.COMPLEX128:
                array = _unflatten_complex(array)
        else:
            # Empty tensor
            if not self._proto.dims:
                # When dims are not present and there is no data, we return an empty array
                return np.array([], dtype=dtype.numpy())
            else:
                # Otherwise we return a size 0 array with the correct shape
                return np.zeros(self._proto.dims, dtype=dtype.numpy())

        if dtype == _enums.DataType.INT4:
            return _type_casting.unpack_int4(array.astype(np.uint8), self._proto.dims)
        elif dtype == _enums.DataType.UINT4:
            return _type_casting.unpack_uint4(array.astype(np.uint8), self._proto.dims)
        else:
            # Otherwise convert to the correct dtype and reshape
            # Note we cannot use view() here because the storage dtype may not be the same size as the target
            return array.astype(dtype.numpy()).reshape(self._proto.dims)

    def tobytes(self) -> bytes:
        """Return the tensor as a byte string conformed to the ONNX specification, in little endian.

        Raises:
            ValueError: If the tensor is a string tensor or an external tensor.
            ValueError: If the tensor is of UNDEFINED data type.
        """
        if self._proto.data_location == onnx.TensorProto.EXTERNAL:
            raise ValueError(
                "Cannot convert external tensor to bytes. Use ir.ExternalTensor instead."
            )
        if self.dtype == _enums.DataType.STRING:
            raise ValueError("Cannot convert string tensor to bytes.")
        if self.dtype == _enums.DataType.UNDEFINED:
            raise ValueError("Cannot convert UNDEFINED tensor to bytes.")

        if self._proto.HasField("raw_data"):
            return self._proto.raw_data
        if self._proto.float_data:
            return np.array(
                self._proto.float_data, dtype=_little_endian_dtype(np.float32)
            ).tobytes()
        if self._proto.int32_data:
            # Use an explicit little endian dtype so that the INT32 branch below
            # also produces little endian bytes on big endian platforms.
            array = np.array(self._proto.int32_data, dtype=_little_endian_dtype(np.int32))
            if self.dtype in {
                _enums.DataType.INT16,
                _enums.DataType.UINT16,
                _enums.DataType.FLOAT16,
                _enums.DataType.BFLOAT16,
            }:
                return array.astype(_little_endian_dtype(np.uint16)).tobytes()
            if self.dtype in {
                _enums.DataType.INT8,
                _enums.DataType.UINT8,
                _enums.DataType.BOOL,
                _enums.DataType.FLOAT8E4M3FN,
                _enums.DataType.FLOAT8E4M3FNUZ,
                _enums.DataType.FLOAT8E5M2,
                _enums.DataType.FLOAT8E5M2FNUZ,
                _enums.DataType.INT4,
                _enums.DataType.UINT4,
            }:
                # uint4 and int4 values are already packed, even when stored as int32
                # so we don't need to pack them again
                return array.astype(_little_endian_dtype(np.uint8)).tobytes()
            assert self.dtype == _enums.DataType.INT32
            return array.tobytes()
        if self._proto.int64_data:
            return np.array(
                self._proto.int64_data, dtype=_little_endian_dtype(np.int64)
            ).tobytes()
        if self._proto.double_data:
            return np.array(
                self._proto.double_data, dtype=_little_endian_dtype(np.float64)
            ).tobytes()
        if self._proto.uint64_data:
            array = np.array(self._proto.uint64_data, dtype=_little_endian_dtype(np.uint64))
            if self.dtype == _enums.DataType.UINT32:
                return array.astype(_little_endian_dtype(np.uint32)).tobytes()
            assert self.dtype == _enums.DataType.UINT64
            return array.tobytes()
        # The repeating fields can be empty and still valid.
        # For example, int32_data can be empty and still be a valid tensor.
        return b""

    @property
    def meta(self) -> _metadata.MetadataStore:
        """The metadata store for intermediate analysis.

        Write to the :attr:`metadata_props` if you would like the metadata to be serialized
        to the ONNX proto.
        """
        if self._metadata is None:
            self._metadata = _metadata.MetadataStore()
        return self._metadata

    @property
    def metadata_props(self) -> dict[str, str]:
        """The metadata properties; an empty dict is created on first access if none exist."""
        if self._metadata_props is None:
            self._metadata_props = {}
        return self._metadata_props
412
+
413
+
414
def _get_field(proto: Any, field: str) -> Any:
    """Return ``proto.<field>`` when ``proto.HasField(field)`` is true, else ``None``."""
    return getattr(proto, field) if proto.HasField(field) else None
418
+
419
+
420
+ # Deserialization
421
+
422
+
423
def deserialize_opset_import(
    protos: Sequence[onnx.OperatorSetIdProto],
) -> dict[str, int]:
    """Map each opset domain to its version.

    When a domain appears more than once, the last entry wins.
    """
    opset_imports: dict[str, int] = {}
    for opset_proto in protos:
        opset_imports[opset_proto.domain] = opset_proto.version
    return opset_imports
427
+
428
+
429
def _parse_experimental_function_value_info_name(
    name: str,
) -> tuple[str, str, str] | None:
    """Split an experimental function value info name into its components.

    The experimental format is:
        {function_domain}::{function_name}/{value_name}

    Args:
        name: The name stored in the value info.

    Returns:
        ``(function_domain, function_name, value_name)`` when ``name`` has
        exactly one "/" separator and its left side has exactly one "::"
        separator. None otherwise.
    """
    # Exactly one "/" must separate the function identifier from the value name
    if name.count("/") != 1:
        return None
    function, _, value_name = name.partition("/")
    # Exactly one "::" must separate the domain from the function name
    if function.count("::") != 1:
        return None
    # NOTE: There will not be overload because overloads are introduced in ONNX IR v10, which also
    # introduces the ValueInfoProto for functions
    function_domain, _, function_name = function.partition("::")
    return function_domain, function_name, value_name
456
+
457
+
458
def deserialize_model(proto: onnx.ModelProto) -> _core.Model:
    """Deserialize a model proto into an IR model.

    The main graph is deserialized first and receives the model's opset
    imports; the functions are deserialized afterwards.
    """
    main_graph = _deserialize_graph(proto.graph, [])
    main_graph.opset_imports.update(deserialize_opset_import(proto.opset_import))

    model = _core.Model(
        main_graph,
        ir_version=proto.ir_version,
        producer_name=_get_field(proto, "producer_name"),
        producer_version=_get_field(proto, "producer_version"),
        domain=_get_field(proto, "domain"),
        model_version=_get_field(proto, "model_version"),
        doc_string=_get_field(proto, "doc_string"),
        functions=[deserialize_function(function_proto) for function_proto in proto.functions],
        meta_data_props=deserialize_metadata_props(proto.metadata_props),
    )

    # Handle experimental value info for functions created by the dynamo exporter in IR version 9
    needs_experimental_value_info = (
        model.ir_version < _FUNCTION_VALUE_INFO_SUPPORTED_VERSION
    )
    if needs_experimental_value_info:
        _deserialized_experimental_value_info_for_function_ir9(
            model.functions, proto.graph.value_info
        )

    return model
485
+
486
+
487
def _deserialized_experimental_value_info_for_function_ir9(
    functions: Mapping[_protocols.OperatorIdentifier, _core.Function],
    value_info_protos: Sequence[onnx.ValueInfoProto],
) -> None:
    """Apply function value info stored in the main graph in an experimental format.

    The experimental format of the value info name is:
        {function_domain}::{function_name}/{value_name}

    Matching value infos are applied in place to the function inputs and to
    the outputs of the nodes in each function.

    Args:
        functions: The deserialized functions, keyed by operator identifier.
        value_info_protos: The value infos of the main graph.
    """
    # Group the value infos that belong to a known function by function ID and value name
    infos_by_function: collections.defaultdict[
        _protocols.OperatorIdentifier,
        dict[str, onnx.ValueInfoProto],
    ] = collections.defaultdict(dict)
    for info_proto in value_info_protos:
        parsed = _parse_experimental_function_value_info_name(info_proto.name)
        if parsed is None:
            # Not in the experimental format
            continue
        function_domain, function_name, value_name = parsed
        # TODO(justinchuby): Create a constructor for OperatorIdentifier so we don't create tuples manually
        # The overload is always empty in the experimental format
        function_id = (function_domain, function_name, "")
        if functions.get(function_id) is None:
            # Function not found
            logger.debug(
                "Function with ID '%s' not found in model functions. Value info '%s' will be ignored.",
                function_id,
                info_proto.name,
            )
            continue
        infos_by_function[function_id][value_name] = info_proto

    for function_id, function in functions.items():
        function_infos = infos_by_function[function_id]
        for function_input in function.inputs:
            if function_input.name in function_infos:
                deserialize_value_info_proto(
                    function_infos[function_input.name], function_input
                )
        for node in function:
            for node_output in node.outputs:
                if node_output.name in function_infos:
                    deserialize_value_info_proto(
                        function_infos[node_output.name],
                        node_output,
                    )
            # The function outputs are handled as well because they are also node outputs
534
+
535
+
536
def deserialize_graph(proto: onnx.GraphProto) -> _core.Graph:
    """Deserialize a graph proto as a top-level graph, recursively if needed.

    Args:
        proto: The graph proto to deserialize.

    Returns:
        IR Graph.
    """
    # A top-level graph starts with no enclosing scopes
    enclosing_scopes: list[dict[str, _core.Value]] = []
    return _deserialize_graph(proto, enclosing_scopes)
546
+
547
+
548
@_capture_errors(lambda proto, scoped_values: proto.name)
def _deserialize_graph(
    proto: onnx.GraphProto, scoped_values: list[dict[str, _core.Value]]
) -> _core.Graph:
    """Deserialize a graph proto, recursively if needed.

    Args:
        proto: The graph proto to deserialize.
        scoped_values: A list of dictionaries mapping value names to their corresponding Value objects.
            Every time we enter a new graph, a new scope is created and appended to this list to include
            all values defined in the scope. The scope is removed again before returning,
            also when deserialization fails.

    Returns:
        IR Graph.
    """
    # Create values for initializers and inputs
    initializer_tensors = [deserialize_tensor(tensor) for tensor in proto.initializer]
    inputs = [_core.Input(info.name) for info in proto.input]
    for info, value in zip(proto.input, inputs):
        deserialize_value_info_proto(info, value)

    # Initialize the values dictionary for this graph scope with the inputs and initializers
    values: dict[str, _core.Value] = {v.name: v for v in inputs}  # type: ignore[misc]
    scoped_values.append(values)
    try:
        initializer_values = []
        for tensor in initializer_tensors:
            if tensor.name in values:
                # The initializer is for an input
                initializer_value = values[tensor.name]
                initializer_value.const_value = tensor
            else:
                # The initializer is for some other value. Create this value first
                initializer_value = _core.Value(
                    None,
                    index=None,
                    name=tensor.name,
                    # TODO(justinchuby): Fix type hinting for shape and dtype
                    shape=tensor.shape,  # type: ignore
                    type=_core.TensorType(tensor.dtype),
                    const_value=tensor,
                )
                values[tensor.name] = initializer_value  # type: ignore[index]
            initializer_values.append(initializer_value)

        # Add ValueInfos for this graph scope
        value_info = {info.name: info for info in proto.value_info}

        # Deserialize nodes with all known values
        nodes = [_deserialize_node(node, scoped_values, value_info) for node in proto.node]

        # Fill in values for graph outputs
        outputs = [
            deserialize_value_info_proto(info, values[info.name]) for info in proto.output
        ]
    finally:
        # Always leave the caller-owned scope list as we found it, even on error
        scoped_values.pop()
    return _core.Graph(
        inputs,
        outputs,
        nodes=nodes,
        initializers=initializer_values,
        doc_string=_get_field(proto, "doc_string"),
        name=_get_field(proto, "name"),
        metadata_props=deserialize_metadata_props(proto.metadata_props),
    )
611
+
612
+
613
@_capture_errors(lambda proto: proto.name)
def deserialize_function(proto: onnx.FunctionProto) -> _core.Function:
    """Deserialize a function proto into an IR function.

    The function body becomes a graph named ``{name}_{domain}``, with
    ``__{overload}`` appended when the proto has a non-empty overload.

    Args:
        proto: The function proto to deserialize.

    Returns:
        IR Function.
    """
    inputs = [_core.Input(name) for name in proto.input]
    values: dict[str, _core.Value] = {v.name: v for v in inputs}  # type: ignore[misc]
    # Older protos may not have the value_info field
    value_info = {info.name: info for info in getattr(proto, "value_info", [])}

    # TODO(justinchuby): Handle unsorted nodes
    nodes = [_deserialize_node(node, [values], value_info=value_info) for node in proto.node]
    outputs = [values[name] for name in proto.output]

    # The conditional must be parenthesized: a conditional expression binds
    # more loosely than "+", so without the parentheses the whole name would
    # collapse to "" for functions without an overload.
    overload_suffix = (
        f"__{proto.overload}" if hasattr(proto, "overload") and proto.overload else ""
    )
    graph = _core.Graph(
        inputs,
        outputs,
        nodes=nodes,
        initializers=(),
        doc_string=_get_field(proto, "doc_string"),
        opset_imports=deserialize_opset_import(proto.opset_import),
        name=f"{proto.name}_{proto.domain}" + overload_suffix,
    )
    attributes = [_deserialize_attribute(attr, []) for attr in proto.attribute_proto]
    # Attributes without defaults
    attributes += [
        _core.Attr(name, _enums.AttributeType.UNDEFINED, None) for name in proto.attribute
    ]
    return _core.Function(
        domain=proto.domain,
        name=proto.name,
        overload=getattr(proto, "overload", ""),
        graph=graph,
        attributes=typing.cast(List[_core.Attr], attributes),
        metadata_props=deserialize_metadata_props(proto.metadata_props),
    )
648
+
649
+
650
@_capture_errors(lambda proto, value: str(proto))
def deserialize_value_info_proto(
    proto: onnx.ValueInfoProto, value: _core.Value | None
) -> _core.Value:
    """Copy shape, type, metadata and doc string from a value info proto onto a value.

    Args:
        proto: The value info proto to read from.
        value: The value to update in place. When None, a new value named
            after the proto is created.

    Returns:
        The updated (or newly created) value.
    """
    target = _core.Value(name=proto.name) if value is None else value
    type_proto = proto.type
    target.shape = deserialize_type_proto_for_shape(type_proto)
    target.type = deserialize_type_proto_for_type(type_proto)
    props = deserialize_metadata_props(proto.metadata_props)
    if props is not None:
        target.metadata_props.update(props)
    target.doc_string = _get_field(proto, "doc_string")
    return target
663
+
664
+
665
@_capture_errors(str)
def deserialize_type_proto_for_shape(proto: onnx.TypeProto) -> _core.Shape | None:
    """Extract the shape described by a type proto.

    Tensor and sparse tensor types yield their own shape; sequence and
    optional types yield the shape of their element type.

    Returns:
        A frozen shape, or None when the proto carries no shape information.

    Raises:
        NotImplementedError: If the proto is a map type.
    """
    for tensor_field in ("tensor_type", "sparse_tensor_type"):
        if not proto.HasField(tensor_field):
            continue
        shape_proto = _get_field(getattr(proto, tensor_field), "shape")
        if shape_proto is None:
            return None
        # This logic handles when the shape is [] as well
        dim_denotation_pairs = [
            deserialize_dimension(dim_proto) for dim_proto in shape_proto.dim
        ]
        return _core.Shape(
            [dim for dim, _ in dim_denotation_pairs],
            denotations=[denotation for _, denotation in dim_denotation_pairs],
            frozen=True,
        )

    for container_field in ("sequence_type", "optional_type"):
        if not proto.HasField(container_field):
            continue
        elem_type = _get_field(getattr(proto, container_field), "elem_type")
        if elem_type is None:
            return None
        return deserialize_type_proto_for_shape(elem_type)

    if proto.HasField("map_type"):
        # TODO(justinchuby): Do we need to support map types?
        raise NotImplementedError(f"Map types are not supported yet. {_PLEASE_CONTRIBUTE}")

    return None
701
+
702
+
703
@_capture_errors(str)
def deserialize_type_proto_for_type(
    proto: onnx.TypeProto,
) -> _protocols.TypeProtocol | None:
    """Extract the IR type described by a type proto.

    Returns:
        The IR type, or None when the proto has no type set or a tensor or
        sparse tensor type has no element type.

    Raises:
        ValueError: If a sequence or optional type has no (resolvable) element type.
        NotImplementedError: If the proto is a map type.
    """
    denotation = _get_field(proto, "denotation")
    if proto.HasField("tensor_type"):
        if (elem_type := _get_field(proto.tensor_type, "elem_type")) is None:
            return None
        return _core.TensorType(_enums.DataType(elem_type), denotation=denotation)
    if proto.HasField("sparse_tensor_type"):
        if (elem_type := _get_field(proto.sparse_tensor_type, "elem_type")) is None:
            return None
        return _core.SparseTensorType(_enums.DataType(elem_type), denotation=denotation)
    if proto.HasField("sequence_type"):
        # FIXME(justinchuby): Allow nested types being None
        if (elem_type := _get_field(proto.sequence_type, "elem_type")) is None:
            raise ValueError(f"SequenceTypeProto must have elem_type set: {proto}")
        nested_type = deserialize_type_proto_for_type(elem_type)
        if nested_type is None:
            raise ValueError(f"SequenceType must have elem_type set: {proto}")
        return _core.SequenceType(nested_type, denotation=denotation)
    if proto.HasField("optional_type"):
        # FIXME(justinchuby): Allow nested types being None
        # The messages name the optional type so the failing proto kind is reported correctly
        if (elem_type := _get_field(proto.optional_type, "elem_type")) is None:
            raise ValueError(f"OptionalTypeProto must have elem_type set: {proto}")
        nested_type = deserialize_type_proto_for_type(elem_type)
        if nested_type is None:
            raise ValueError(f"OptionalType must have elem_type set: {proto}")
        return _core.OptionalType(nested_type, denotation=denotation)
    if proto.HasField("map_type"):
        # TODO(justinchuby): Do we need to support map types?
        raise NotImplementedError(f"Map types are not supported yet. {_PLEASE_CONTRIBUTE}")

    return None
737
+
738
+
739
@_capture_errors(str)
def deserialize_dimension(
    proto: onnx.TensorShapeProto.Dimension,
) -> tuple[int | _core.SymbolicDim, str | None]:
    """Deserialize a dimension proto into (dimension, denotation).

    Args:
        proto: The dimension proto to deserialize.

    Returns:
        A tuple of the dimension and its denotation. The dimension is the
        ``dim_value`` integer, a symbolic dim wrapping ``dim_param``, or a
        symbolic dim of None when neither is set.
    """
    which_value = proto.WhichOneof("value")
    denotation = _get_field(proto, "denotation")
    if which_value == "dim_value":
        return proto.dim_value, denotation
    if which_value == "dim_param":
        return _core.SymbolicDim(proto.dim_param), denotation
    # Neither field of the oneof is set: the dimension is unknown
    return _core.SymbolicDim(None), denotation
760
+
761
+
762
@_capture_errors(lambda proto, base_path: proto.name)
def deserialize_tensor(
    proto: onnx.TensorProto, base_path: str | os.PathLike = ""
) -> _protocols.TensorProtocol:
    """Deserialize a tensor proto into an IR tensor.

    Args:
        proto: The tensor proto to deserialize.
        base_path: Passed as ``base_dir`` to the external tensor when the
            proto stores its data externally.

    Returns:
        An external tensor for externally stored data, a string tensor for
        the STRING data type, and a :class:`TensorProtoTensor` otherwise.
    """
    # TODO: Sanitize base_path
    stored_externally = proto.data_location == onnx.TensorProto.EXTERNAL
    if stored_externally:
        external_info = onnx.external_data_helper.ExternalDataInfo(proto)
        return _core.ExternalTensor(
            external_info.location,
            offset=external_info.offset,
            length=external_info.length,
            dtype=_enums.DataType(proto.data_type),
            base_dir=base_path,
            name=_get_field(proto, "name"),
            shape=_core.Shape(proto.dims),
            doc_string=_get_field(proto, "doc_string"),
            metadata_props=deserialize_metadata_props(proto.metadata_props),
        )

    if proto.data_type == _enums.DataType.STRING:
        return _core.StringTensor(
            proto.string_data,
            shape=_core.Shape(proto.dims),
            name=_get_field(proto, "name"),
            doc_string=_get_field(proto, "doc_string"),
            metadata_props=deserialize_metadata_props(proto.metadata_props),
        )

    return TensorProtoTensor(proto)
792
+
793
+
794
def deserialize_metadata_props(
    proto: Sequence[onnx.StringStringEntryProto],
) -> dict[str, str] | None:
    """Convert key/value entry protos to a dictionary.

    Returns:
        A dictionary mapping each entry key to its value, or None when there
        are no entries.
    """
    if len(proto) == 0:
        # Avoid creating an empty dictionary to save memory
        return None
    metadata_props: dict[str, str] = {}
    for entry in proto:
        metadata_props[entry.key] = entry.value
    return metadata_props
801
+
802
+
803
def deserialize_attribute(proto: onnx.AttributeProto) -> _core.Attr | _core.RefAttr:
    """Deserialize an attribute proto with no enclosing graph scopes."""
    enclosing_scopes: list[dict[str, _core.Value]] = []
    return _deserialize_attribute(proto, enclosing_scopes)
805
+
806
+
807
@_capture_errors(lambda proto, scoped_values: str(proto))
def _deserialize_attribute(
    proto: onnx.AttributeProto, scoped_values: list[dict[str, _core.Value]]
) -> _core.Attr | _core.RefAttr:
    """Convert an AttributeProto into the matching IR attribute object.

    Args:
        proto: The attribute proto to convert.
        scoped_values: Value scopes forwarded to ``_deserialize_graph`` when the
            attribute holds a graph or a list of graphs.

    Returns:
        A RefAttr when ``ref_attr_name`` is set on the proto; otherwise the IR
        attribute object selected by the attribute type.

    Raises:
        NotImplementedError: For SPARSE_TENSOR and SPARSE_TENSORS attributes.
        ValueError: For an attribute type that no branch below handles.
    """
    attr_name = proto.name
    doc = _get_field(proto, "doc_string")
    type_ = _enums.AttributeType(proto.type)
    referenced_name = _get_field(proto, "ref_attr_name")
    if referenced_name:
        # Reference attributes carry no value of their own.
        return _core.RefAttr(attr_name, referenced_name, type_, doc_string=doc)

    kinds = _enums.AttributeType

    # Scalar and list-of-scalar attributes.
    if type_ == kinds.INT:
        return _core.AttrInt64(attr_name, proto.i, doc_string=doc)
    if type_ == kinds.FLOAT:
        return _core.AttrFloat32(attr_name, proto.f, doc_string=doc)
    if type_ == kinds.STRING:
        text = proto.s.decode("utf-8")
        return _core.AttrString(attr_name, text, doc_string=doc)
    if type_ == kinds.INTS:
        return _core.AttrInt64s(attr_name, proto.ints, doc_string=doc)
    if type_ == kinds.FLOATS:
        return _core.AttrFloat32s(attr_name, proto.floats, doc_string=doc)
    if type_ == kinds.STRINGS:
        texts = [raw.decode("utf-8") for raw in proto.strings]
        return _core.AttrStrings(attr_name, texts, doc_string=doc)

    # Tensor and graph attributes.
    if type_ == kinds.TENSOR:
        tensor = deserialize_tensor(proto.t)
        return _core.AttrTensor(attr_name, tensor, doc_string=doc)
    if type_ == kinds.GRAPH:
        graph = _deserialize_graph(proto.g, scoped_values)
        return _core.AttrGraph(attr_name, graph, doc_string=doc)
    if type_ == kinds.TENSORS:
        tensors = [deserialize_tensor(t) for t in proto.tensors]
        return _core.AttrTensors(attr_name, tensors, doc_string=doc)
    if type_ == kinds.GRAPHS:
        graphs = [_deserialize_graph(g, scoped_values) for g in proto.graphs]
        return _core.AttrGraphs(attr_name, graphs, doc_string=doc)
    if type_ in (kinds.SPARSE_TENSOR, kinds.SPARSE_TENSORS):
        raise NotImplementedError(
            f"Sparse tensors are not supported yet. {_PLEASE_CONTRIBUTE}"
        )

    # Type attributes: the type is deserialized before the shape.
    if type_ == kinds.TYPE_PROTO:
        ir_type = deserialize_type_proto_for_type(proto.tp)
        shape = deserialize_type_proto_for_shape(proto.tp)
        type_and_shape = _core.TypeAndShape(ir_type, shape)
        return _core.AttrTypeProto(attr_name, type_and_shape, doc_string=doc)
    if type_ == kinds.TYPE_PROTOS:
        type_and_shapes = [
            _core.TypeAndShape(
                deserialize_type_proto_for_type(type_proto),
                deserialize_type_proto_for_shape(type_proto),
            )
            for type_proto in proto.type_protos
        ]
        return _core.AttrTypeProtos(attr_name, type_and_shapes, doc_string=doc)

    if type_ == kinds.UNDEFINED:
        return _core.Attr(attr_name, type_, None, doc_string=doc)
    raise ValueError(f"Unsupported attribute type: '{type_}'")
874
+
875
+
876
def deserialize_node(proto: onnx.NodeProto) -> _core.Node:
    """Deserialize a standalone NodeProto into an IR node.

    Args:
        proto: The node proto to deserialize.

    Returns:
        The deserialized node. No value info is supplied, so the values created
        for its inputs and outputs carry only their names.
    """
    # _deserialize_node reads and writes ``scoped_values[-1]`` for every named
    # input it cannot find and for every named output, so a standalone node
    # needs one (empty) scope; an empty list would raise IndexError.
    return _deserialize_node(proto, scoped_values=[{}], value_info={})
878
+
879
+
880
@_capture_errors(lambda proto, scoped_values, value_info: str(proto))
def _deserialize_node(
    proto: onnx.NodeProto,
    scoped_values: list[dict[str, _core.Value]],
    value_info: dict[str, onnx.ValueInfoProto],
) -> _core.Node:
    """Build an IR node from a NodeProto, resolving its inputs and outputs.

    Args:
        proto: The node proto to deserialize.
        scoped_values: Name-to-value mappings, outermost scope first. The last
            entry is the current scope; values created here are stored in it.
        value_info: Name-to-ValueInfoProto mapping used to fill in type/shape
            information on values that are created or reused here.

    Returns:
        The deserialized node.
    """
    node_inputs: list[_core.Value | None] = []
    for input_name in proto.input:
        if input_name == "":
            # An empty name denotes an omitted input.
            node_inputs.append(None)
            continue

        # Look the name up from the innermost scope outwards.
        for scope in reversed(scoped_values):
            if input_name in scope:
                node_inputs.append(scope[input_name])
                break
        else:
            # Not found anywhere: the graph may be unsorted, and the input may
            # be an initializer-to-be or the output of a node that comes later.
            # Create a placeholder value; the producing node picks it up from
            # the current scope when its outputs are built.
            logger.warning(
                "Input '%s' of node '%s(%s::%s:%s)' not found in any scope. "
                "The graph may be unsorted. Creating a new input (current depth: %s) .",
                input_name,
                proto.name,
                proto.domain,
                proto.op_type,
                getattr(proto, "overload", ""),
                len(scoped_values),
            )
            if len(scoped_values) > 1:
                logger.warning(
                    "Caveat: The value is created in the subgraph. If "
                    "the node is referencing a value that is not in the current graph, "
                    "it is impossible to create it in the correct scope.",
                )
            placeholder = _core.Value(name=input_name)
            if input_name in value_info:
                # Attach type/shape information when it is available.
                deserialize_value_info_proto(value_info[input_name], placeholder)
            node_inputs.append(placeholder)
            # The value can only be registered in the current scope, even if it
            # actually belongs to an enclosing one.
            scoped_values[-1][input_name] = placeholder

    node_outputs: list[_core.Value] = []
    for output_name in proto.output:
        if output_name == "":
            # An empty name denotes an omitted output.
            node_outputs.append(_core.Value(name=""))
            continue

        # A node can only produce values in its own scope. In an unsorted graph
        # the value may already exist because a consumer was parsed first; that
        # value is reused so this node becomes its producer. Otherwise (the
        # common, sorted case) the value is created and registered now.
        current_scope = scoped_values[-1]
        if output_name not in current_scope:
            current_scope[output_name] = _core.Value(name=output_name)
        produced = current_scope[output_name]

        if output_name in value_info:
            # Attach type/shape information when it is available.
            deserialize_value_info_proto(value_info[output_name], produced)
        else:
            logger.debug(
                "ValueInfoProto not found for output '%s' in node '%s' of type '%s'",
                output_name,
                proto.name,
                proto.op_type,
            )
        node_outputs.append(produced)

    attributes = [_deserialize_attribute(a, scoped_values) for a in proto.attribute]
    return _core.Node(
        proto.domain,
        proto.op_type,
        node_inputs,
        attributes,
        overload=getattr(proto, "overload", ""),
        outputs=node_outputs,
        name=proto.name,
        doc_string=_get_field(proto, "doc_string"),
        metadata_props=deserialize_metadata_props(proto.metadata_props),
    )
980
+
981
+
982
+ # Serialization
983
+
984
+
985
def serialize_model(model: _protocols.ModelProtocol) -> onnx.ModelProto:
    """Serialize an IR model into a newly created ModelProto."""
    model_proto = onnx.ModelProto()
    return serialize_model_into(model_proto, from_=model)
987
+
988
+
989
@_capture_errors(
    lambda model_proto, from_: (
        f"ir_version={from_.ir_version}, producer_name={from_.producer_name}, "
        f"producer_version={from_.producer_version}, domain={from_.domain}, "
    )
)
def serialize_model_into(
    model_proto: onnx.ModelProto, from_: _protocols.ModelProtocol
) -> onnx.ModelProto:
    """Write an IR model into an existing ONNX model proto.

    Args:
        model_proto: The proto to serialize into.
        from_: The model to serialize.

    Returns:
        ``model_proto``, after it has been filled in.
    """
    model_proto.ir_version = from_.ir_version
    # Optional fields are only written when they hold a truthy value.
    for field_name in (
        "producer_name",
        "producer_version",
        "domain",
        "model_version",
        "doc_string",
    ):
        field_value = getattr(from_, field_name)
        if field_value:
            setattr(model_proto, field_name, field_value)
    _serialize_opset_imports_into(model_proto.opset_import, from_.opset_imports)
    if from_.metadata_props:
        _serialize_metadata_props_into(model_proto.metadata_props, from_.metadata_props)
    serialize_graph_into(model_proto.graph, from_.graph)

    functions_hold_value_info = (
        from_.ir_version >= _FUNCTION_VALUE_INFO_SUPPORTED_VERSION
    )
    for func in from_.functions.values():
        serialize_function_into(
            model_proto.functions.add(),
            from_=func,
            create_value_info=functions_hold_value_info,
        )
        if functions_hold_value_info:
            continue
        # The function proto cannot hold value info at this IR version, so it
        # is stored in the main graph instead.
        _serialize_experimental_value_info_for_function_ir9_into(model_proto.graph, func)
    return model_proto
1027
+
1028
+
1029
def _should_create_value_info_for_value(value: _protocols.ValueProtocol) -> bool:
    """Tell whether a value carries anything worth storing as value info.

    Args:
        value: The value to check.

    Returns:
        True if the value has a shape or a type set; False if both are None.
    """
    return value.shape is not None or value.type is not None
1040
+
1041
+
1042
def _serialize_experimental_value_info_for_function_ir9_into(
    graph_proto: onnx.GraphProto, function: _protocols.FunctionProtocol
) -> None:
    """Serialize value info for functions in an experimental format for IR version 9.

    Because IRv9 and older does not have ValueInfoProto for functions, we give the value info
    special names and store them in the main graph instead.

    The experimental format is:
        {function_domain}::{function_name}/{value_name}

    Values without a name are skipped with a warning; values with neither a
    type nor a shape are skipped silently.

    Args:
        graph_proto: The graph proto to create ValueInfoProto in.
        function: The function to serialize.
    """
    # TODO(justinchuby): In the future, we can decide if it is a good idea to simply iterate over
    # all values in the function and call serialize_value_into instead.
    function_qualified_name = f"{function.domain}::{function.name}"

    def format_name(value_name: str) -> str:
        return f"{function_qualified_name}/{value_name}"

    for input_ in function.inputs:
        if not input_.name:
            # Use the module logger, like the rest of this file, rather than
            # the root logger.
            logger.warning(
                "Function '%s': Value name not set for function input: %s",
                function_qualified_name,
                input_,
            )
            continue
        if not _should_create_value_info_for_value(input_):
            # No need to serialize value info if it is not set
            continue
        serialize_value_into(
            graph_proto.value_info.add(), input_, name=format_name(input_.name)
        )
    for node in function:
        for node_output in node.outputs:
            if not node_output.name:
                logger.warning(
                    "Function '%s': Value name not set for node output: %s",
                    function_qualified_name,
                    node_output,
                )
                continue
            if not _should_create_value_info_for_value(node_output):
                # No need to serialize value info if it is not set
                continue
            serialize_value_into(
                graph_proto.value_info.add(),
                node_output,
                name=format_name(node_output.name),
            )
1093
+
1094
+
1095
def _serialize_opset_imports_into(
    opset_ids: proto_containers.RepeatedCompositeFieldContainer[onnx.OperatorSetIdProto],
    from_: Mapping[str, int],
) -> None:
    """Append one OperatorSetId entry per opset import.

    Args:
        opset_ids: The repeated field to serialize into.
        from_: The mapping of opset domains to versions to serialize.
    """
    # Entries are written in the mapping's own iteration order (not sorted).
    for domain in from_:
        opset_ids.add(domain=domain, version=from_[domain])
1108
+
1109
+
1110
def _serialize_metadata_props_into(
    string_string_entries: proto_containers.RepeatedCompositeFieldContainer[
        onnx.StringStringEntryProto
    ],
    from_: Mapping[str, str],
) -> None:
    """Append one string-string entry per metadata property, ordered by key.

    Args:
        string_string_entries: The repeated field to serialize into.
        from_: The mapping of metadata properties to serialize.
    """
    # Keys are sorted so that the serialized output is deterministic.
    ordered_keys = list(from_)
    ordered_keys.sort()
    for key in ordered_keys:
        string_string_entries.add(key=key, value=from_[key])
1125
+
1126
+
1127
def serialize_graph(
    graph: _protocols.GraphProtocol | _protocols.GraphViewProtocol,
) -> onnx.GraphProto:
    """Serialize a graph into a newly created :class:`onnx.GraphProto`.

    Initializers that do not have ``const_value`` set are skipped.

    Args:
        graph: The graph to be serialized.

    Returns:
        The serialized ONNX GraphProto object.
    """
    result = onnx.GraphProto()
    serialize_graph_into(result, from_=graph)
    return result
1143
+
1144
+
1145
@_capture_errors(
    lambda graph_proto, from_: (
        f"name={from_.name}, doc_string={from_.doc_string}, "
        f"len(inputs)={len(from_.inputs)}, len(initializers)={len(from_.initializers)}, "
        f"len(nodes)={len(from_)}, len(outputs)={len(from_.outputs)}, metadata_props={from_.metadata_props}"
    )
)
def serialize_graph_into(
    graph_proto: onnx.GraphProto,
    from_: _protocols.GraphProtocol | _protocols.GraphViewProtocol,
) -> None:
    """Write an IR graph into an existing GraphProto.

    Inputs, initializers, nodes (with value info for their outputs), outputs
    and metadata properties are written, in that order. Initializers without a
    constant value are skipped with a warning.

    Args:
        graph_proto: The proto to serialize into.
        from_: The graph (or graph view) to serialize.
    """
    if from_.name:
        graph_proto.name = from_.name
    if from_.doc_string:
        graph_proto.doc_string = from_.doc_string

    for graph_input in from_.inputs:
        serialize_value_into(graph_proto.input.add(), graph_input)

    # TODO(justinchuby): Support sparse_initializer
    for initializer in from_.initializers.values():
        tensor = initializer.const_value
        if tensor is None:
            # Nothing to store for an initializer that has no constant value.
            logger.warning(
                "Initializer '%s' does not have a constant value set.", initializer.name
            )
            continue
        # The stored tensor must carry the same name as the value it backs.
        tensor.name = initializer.name
        serialize_tensor_into(graph_proto.initializer.add(), from_=tensor)

    for node in from_:
        serialize_node_into(graph_proto.node.add(), from_=node)
        for produced in node.outputs:
            # Value info is written only for outputs that have a type or shape
            # and are not graph outputs (those are written with the outputs).
            if (
                _should_create_value_info_for_value(produced)
                and not produced.is_graph_output()
            ):
                serialize_value_into(graph_proto.value_info.add(), produced)

    for graph_output in from_.outputs:
        serialize_value_into(graph_proto.output.add(), from_=graph_output)

    if from_.metadata_props:
        _serialize_metadata_props_into(graph_proto.metadata_props, from_.metadata_props)
1187
+
1188
+
1189
def serialize_function(
    function: _protocols.FunctionProtocol, *, create_value_info: bool = True
) -> onnx.FunctionProto:
    """Serialize an IR function into a newly created FunctionProto.

    Args:
        function: The function to serialize.
        create_value_info: Whether to create ValueInfoProto for nodes in the function. This is supported
            starting from ONNX IR version 10.

    Returns:
        The serialized FunctionProto.
    """
    result = onnx.FunctionProto()
    serialize_function_into(result, from_=function, create_value_info=create_value_info)
    return result
1204
+
1205
+
1206
# The capturer mirrors the function's default for ``create_value_info`` so it
# accepts calls that omit the argument.
# NOTE(review): assumes _capture_errors forwards the call's arguments to the
# capturer -- confirm against its definition.
@_capture_errors(lambda function_proto, from_, create_value_info=True: repr(from_))
def serialize_function_into(
    function_proto: onnx.FunctionProto,
    from_: _protocols.FunctionProtocol,
    *,
    create_value_info: bool = True,
) -> None:
    """Serialize an IR function into a FunctionProto.

    Args:
        function_proto: The proto to serialize into.
        from_: The function to serialize.
        create_value_info: Whether to create ValueInfoProto for nodes in the function. This is supported
            starting from ONNX IR version 10.
    """
    if from_.domain:
        function_proto.domain = from_.domain
    if from_.name:
        function_proto.name = from_.name
    if from_.overload:
        function_proto.overload = from_.overload
    if from_.doc_string:
        function_proto.doc_string = from_.doc_string
    if from_.opset_imports:
        # A valid ONNX graph should have at least one opset import, that is
        # the default ONNX opset.
        # Here we check for emptiness before serializing to keep the logic consistent
        _serialize_opset_imports_into(function_proto.opset_import, from_.opset_imports)
    if from_.metadata_props:
        _serialize_metadata_props_into(function_proto.metadata_props, from_.metadata_props)

    for input_ in from_.inputs:
        function_proto.input.append(input_.name)
        # Value info is written only when requested and when the input has a
        # type or shape to record.
        if create_value_info and _should_create_value_info_for_value(input_):
            serialize_value_into(function_proto.value_info.add(), input_)

    for attr in from_.attributes.values():
        if attr.value is not None:
            serialize_attribute_into(function_proto.attribute_proto.add(), from_=attr)
        else:
            # ONNX does not record type information if the attribute does not have a default
            function_proto.attribute.append(attr.name)

    for func_output in from_.outputs:
        # No value info is written here: function outputs are also node
        # outputs and are covered by the node loop below.
        function_proto.output.append(func_output.name)

    for node in from_:
        serialize_node_into(function_proto.node.add(), from_=node)
        if not create_value_info:
            continue
        # Record value info for the node's outputs that have a type or shape.
        for node_output in node.outputs:
            if _should_create_value_info_for_value(node_output):
                serialize_value_into(function_proto.value_info.add(), node_output)
1264
+
1265
+
1266
def serialize_node(node: _protocols.NodeProtocol) -> onnx.NodeProto:
    """Serialize an IR node into a newly created NodeProto."""
    result = onnx.NodeProto()
    serialize_node_into(result, from_=node)
    return result
1270
+
1271
+
1272
@_capture_errors(lambda node_proto, from_: repr(from_))
def serialize_node_into(node_proto: onnx.NodeProto, from_: _protocols.NodeProtocol) -> None:
    """Write an IR node into an existing NodeProto.

    Args:
        node_proto: The proto to serialize into.
        from_: The node to serialize.

    Raises:
        TypeError: If an attribute is neither an attribute nor a reference
            attribute (by class or by protocol).
    """
    node_proto.op_type = from_.op_type
    # Optional string fields are written only when non-empty. For the domain,
    # "" means the default domain, so it does not need to be set.
    for field_name in ("domain", "name", "overload", "doc_string"):
        field_value = getattr(from_, field_name)
        if field_value:
            setattr(node_proto, field_name, field_value)
    if from_.metadata_props:
        _serialize_metadata_props_into(node_proto.metadata_props, from_.metadata_props)

    # A missing (None) input is encoded as an empty name.
    for node_input in from_.inputs:
        node_proto.input.append("" if node_input is None else node_input.name)
    for node_output in from_.outputs:
        node_proto.output.append(node_output.name)

    for attr in from_.attributes.values():
        # The concrete classes are tested first because isinstance checks
        # against a protocol can be slow; the protocol checks are a fallback
        # kept for completeness.
        if isinstance(attr, _core.Attr):
            write_attribute = serialize_attribute_into
        elif isinstance(attr, _core.RefAttr):
            write_attribute = serialize_reference_attribute_into
        elif isinstance(attr, _protocols.AttributeProtocol):
            write_attribute = serialize_attribute_into
        elif isinstance(attr, _protocols.ReferenceAttributeProtocol):
            write_attribute = serialize_reference_attribute_into
        else:
            raise TypeError(f"Unsupported attribute type: {type(attr)}")
        write_attribute(node_proto.attribute.add(), from_=attr)
1308
+
1309
+
1310
def serialize_tensor(tensor: _protocols.TensorProtocol) -> onnx.TensorProto:
    """Serialize an IR tensor into a newly created TensorProto."""
    result = onnx.TensorProto()
    serialize_tensor_into(result, from_=tensor)
    return result
1314
+
1315
+
1316
@_capture_errors(lambda tensor_proto, from_: repr(from_))
def serialize_tensor_into(
    tensor_proto: onnx.TensorProto, from_: _protocols.TensorProtocol
) -> None:
    """Write an IR tensor into an existing TensorProto.

    A TensorProtoTensor is copied from its underlying proto. An ExternalTensor
    is stored as a reference (location, offset, length) without reading its
    data. A StringTensor is stored in ``string_data``. Any other tensor is
    stored as raw bytes.

    Args:
        tensor_proto: The proto to serialize into.
        from_: The tensor to serialize.
    """
    if isinstance(from_, TensorProtoTensor):
        # Directly copy from the tensor proto if it is available
        tensor_proto.CopyFrom(from_.raw)
        if from_.metadata_props:
            _serialize_metadata_props_into(tensor_proto.metadata_props, from_.metadata_props)
        return

    if from_.name:
        tensor_proto.name = from_.name
    if from_.doc_string:
        tensor_proto.doc_string = from_.doc_string
    tensor_proto.data_type = from_.dtype.value
    tensor_proto.dims.extend(from_.shape.numpy())
    if isinstance(from_, _core.ExternalTensor):
        # Store external tensors as is
        tensor_proto.data_location = onnx.TensorProto.EXTERNAL
        for k, v in {
            "location": os.fspath(from_.location),
            "offset": from_.offset,
            "length": from_.length,
        }.items():
            if v is not None:
                entry = tensor_proto.external_data.add()
                entry.key = k
                entry.value = str(v)
    elif isinstance(from_, _core.StringTensor):
        tensor_proto.string_data.extend(from_.string_data())
    else:
        tensor_proto.raw_data = from_.tobytes()
    # Guarded like every other metadata_props write in this file, so a tensor
    # that reports None or an empty mapping writes nothing instead of failing.
    if from_.metadata_props:
        _serialize_metadata_props_into(tensor_proto.metadata_props, from_.metadata_props)
1350
+
1351
+
1352
def serialize_attribute(attribute: _protocols.AttributeProtocol) -> onnx.AttributeProto:
    """Serialize an IR attribute into a newly created AttributeProto."""
    result = onnx.AttributeProto()
    serialize_attribute_into(result, from_=attribute)
    return result
1356
+
1357
+
1358
@_capture_errors(lambda attribute_proto, from_: repr(from_))
def serialize_attribute_into(
    attribute_proto: onnx.AttributeProto, from_: _protocols.AttributeProtocol
) -> None:
    """Write an IR attribute (name, doc string, type and value) into a proto.

    Args:
        attribute_proto: The proto to serialize into.
        from_: The attribute to serialize.
    """
    attribute_proto.name = from_.name
    doc_string = from_.doc_string
    if doc_string:
        attribute_proto.doc_string = doc_string
    attr_type, attr_value = from_.type, from_.value
    _fill_in_value_for_attribute(attribute_proto, attr_type, attr_value)
1366
+
1367
+
1368
def _fill_in_value_for_attribute(
    attribute_proto: onnx.AttributeProto, type_: _enums.AttributeType, value: Any
) -> None:
    """Store ``value`` in the proto field matching ``type_`` and set the proto type.

    Args:
        attribute_proto: The proto to write the value and type into.
        type_: The attribute type; selects which proto field receives the value.
        value: The attribute value; its expected form per type is noted on each
            branch below.

    Raises:
        NotImplementedError: For SPARSE_TENSOR and SPARSE_TENSORS.
        TypeError: For any attribute type not handled below.
    """
    kinds = _enums.AttributeType

    if type_ == kinds.INT:
        # value: int
        attribute_proto.i = value
        attribute_proto.type = onnx.AttributeProto.INT
        return
    if type_ == kinds.FLOAT:
        # value: float
        attribute_proto.f = value
        attribute_proto.type = onnx.AttributeProto.FLOAT
        return
    if type_ == kinds.STRING:
        # value: str
        attribute_proto.s = value.encode("utf-8")
        attribute_proto.type = onnx.AttributeProto.STRING
        return
    if type_ == kinds.INTS:
        # value: Sequence[int]
        attribute_proto.ints.extend(value)
        attribute_proto.type = onnx.AttributeProto.INTS
        return
    if type_ == kinds.FLOATS:
        # value: Sequence[float]
        attribute_proto.floats.extend(value)
        attribute_proto.type = onnx.AttributeProto.FLOATS
        return
    if type_ == kinds.STRINGS:
        # value: Sequence[str]
        attribute_proto.strings.extend([text.encode("utf-8") for text in value])
        attribute_proto.type = onnx.AttributeProto.STRINGS
        return
    if type_ == kinds.TENSOR:
        # value: _protocols.TensorProtocol
        serialize_tensor_into(attribute_proto.t, value)
        attribute_proto.type = onnx.AttributeProto.TENSOR
        return
    if type_ == kinds.GRAPH:
        # value: _protocols.GraphProtocol
        serialize_graph_into(attribute_proto.g, value)
        attribute_proto.type = onnx.AttributeProto.GRAPH
        return
    if type_ == kinds.TENSORS:
        # value: Sequence[_protocols.TensorProtocol]
        for tensor in value:
            serialize_tensor_into(attribute_proto.tensors.add(), tensor)
        attribute_proto.type = onnx.AttributeProto.TENSORS
        return
    if type_ == kinds.GRAPHS:
        # value: Sequence[_protocols.GraphProtocol]
        for graph in value:
            serialize_graph_into(attribute_proto.graphs.add(), graph)
        attribute_proto.type = onnx.AttributeProto.GRAPHS
        return
    if type_ in (kinds.SPARSE_TENSOR, kinds.SPARSE_TENSORS):
        raise NotImplementedError(
            f"Sparse tensors are not supported yet. {_PLEASE_CONTRIBUTE}"
        )
    if type_ == kinds.TYPE_PROTO:
        # value: _core.TypeAndShape
        # The type has to be written before the shape.
        if value.type is not None:
            serialize_type_into(attribute_proto.tp, value.type)
        if value.shape is not None:
            serialize_shape_into(attribute_proto.tp, value.shape)
        attribute_proto.type = onnx.AttributeProto.TYPE_PROTO
        return
    if type_ == kinds.TYPE_PROTOS:
        # value: Sequence[_core.TypeAndShape]
        for type_and_shape in value:
            entry = attribute_proto.type_protos.add()
            # The type has to be written before the shape so that the shape
            # can be written to the leaf type proto.
            if type_and_shape.type is not None:
                serialize_type_into(entry, type_and_shape.type)
            if type_and_shape.shape is not None:
                serialize_shape_into(entry, type_and_shape.shape)
        attribute_proto.type = onnx.AttributeProto.TYPE_PROTOS
        return
    raise TypeError(f"Unsupported attribute type: {type_}")
1441
+
1442
+
1443
@_capture_errors(lambda attribute_proto, from_: repr(from_))
def serialize_reference_attribute_into(
    attribute_proto: onnx.AttributeProto, from_: _protocols.ReferenceAttributeProtocol
) -> None:
    """Write a reference attribute into an AttributeProto.

    The proto receives the attribute's name, the name of the attribute it
    refers to, its doc string (when non-empty) and its type; no value is set.

    Args:
        attribute_proto: The proto to serialize into.
        from_: The reference attribute to serialize.
    """
    attribute_proto.name = from_.name
    attribute_proto.ref_attr_name = from_.ref_attr_name
    doc_string = from_.doc_string
    if doc_string:
        attribute_proto.doc_string = doc_string
    proto_type = typing.cast(onnx.AttributeProto.AttributeType, from_.type.value)
    attribute_proto.type = proto_type
1452
+
1453
+
1454
def serialize_value(value: _protocols.ValueProtocol, *, name: str = "") -> onnx.ValueInfoProto:
    """Serialize a value into a newly created ValueInfoProto.

    Args:
        value: The value to serialize.
        name: A custom name to set for the value info. If not provided, the name from the value will be used.

    Returns:
        The serialized ValueInfoProto.
    """
    result = onnx.ValueInfoProto()
    serialize_value_into(result, value, name=name)
    return result
1465
+
1466
+
1467
# The capturer accepts the keyword-only ``name`` argument just like the
# function does, because callers in this file pass ``name=...``.
# NOTE(review): assumes _capture_errors forwards the call's arguments to the
# capturer -- confirm against its definition.
@_capture_errors(lambda value_info_proto, from_, *, name="": repr(from_))
def serialize_value_into(
    value_info_proto: onnx.ValueInfoProto,
    from_: _protocols.ValueProtocol,
    *,
    name: str = "",
) -> None:
    """Serialize a value into a ValueInfoProto.

    Args:
        value_info_proto: The proto to serialize into.
        from_: The value to serialize.
        name: A custom name to set for the value info. If not provided, the name from the value will be used.
    """
    if name:
        value_info_proto.name = name
    else:
        value_info_proto.name = from_.name
    if from_.metadata_props:
        _serialize_metadata_props_into(value_info_proto.metadata_props, from_.metadata_props)
    if from_.type is not None:
        serialize_type_into(value_info_proto.type, from_.type)
    # Need to create the type _before_ writing the shape so that the shape can be written to the leaf type proto
    if from_.shape is not None:
        serialize_shape_into(value_info_proto.type, from_.shape)
    if from_.doc_string:
        value_info_proto.doc_string = from_.doc_string
1494
+
1495
+
1496
@_capture_errors(lambda type_proto, from_: repr(from_))
def serialize_type_into(type_proto: onnx.TypeProto, from_: _protocols.TypeProtocol) -> None:
    """Write an IR type into an existing TypeProto.

    Tensor and sparse tensor types store their element data type; sequence and
    optional types recurse into their element type.

    Args:
        type_proto: The proto to serialize into.
        from_: The type to serialize.

    Raises:
        TypeError: If ``from_`` is not one of the supported type classes.
    """
    if from_.denotation:
        type_proto.denotation = from_.denotation

    if isinstance(from_, _core.TensorType):
        type_proto.tensor_type.elem_type = from_.dtype.value
        return
    if isinstance(from_, _core.SparseTensorType):
        type_proto.sparse_tensor_type.elem_type = from_.dtype.value
        return
    if isinstance(from_, _core.SequenceType):
        serialize_type_into(type_proto.sequence_type.elem_type, from_.elem_type)
        return
    if isinstance(from_, _core.OptionalType):
        serialize_type_into(type_proto.optional_type.elem_type, from_.elem_type)
        return
    raise TypeError(f"Unsupported type: {from_}")
1514
+
1515
+
1516
def serialize_type(type_protocol: _protocols.TypeProtocol) -> onnx.TypeProto:
    """Serialize an IR type into a newly created TypeProto."""
    result = onnx.TypeProto()
    serialize_type_into(result, from_=type_protocol)
    return result
1520
+
1521
+
1522
@_capture_errors(lambda type_proto, from_: repr(from_))
def serialize_shape_into(type_proto: onnx.TypeProto, from_: _protocols.ShapeProtocol) -> None:
    """Write a shape into the leaf type of an already populated TypeProto.

    The type must have been written first: the shape is stored on the innermost
    type whose ``elem_type`` is an integer, found by descending through nested
    element types. If no type has been set at some level, there is nowhere to
    store the shape; a warning is logged and nothing is written.

    Args:
        type_proto: The type proto whose leaf type receives the shape.
        from_: The shape to serialize.
    """
    current = type_proto
    while True:
        value_field = current.WhichOneof("value")
        if value_field is None:
            # No type is set at this level, so the shape has no field to live
            # in. Skip instead of failing on getattr(current, None).
            logger.warning(
                "The value type for shape %s is not known. "
                "Please set type for the value. Skipping serialization",
                from_,
            )
            return
        leaf_type = getattr(current, value_field)
        if isinstance(leaf_type.elem_type, int):
            # Found the leaf type that has the shape field
            break
        current = leaf_type.elem_type

    # When from_ is empty, we still need to set the shape field to an empty list by touching it
    leaf_type.shape.ClearField("dim")
    for i, dim in enumerate(from_):
        denotation = from_.get_denotation(i)
        serialize_dimension_into(leaf_type.shape.dim.add(), dim, denotation)
1536
+
1537
+
1538
# The capturer mirrors the function's default for ``denotation`` so it accepts
# calls that omit the argument.
# NOTE(review): assumes _capture_errors forwards the call's arguments to the
# capturer -- confirm against its definition.
@_capture_errors(lambda dim_proto, dim, denotation=None: repr(dim_proto))
def serialize_dimension_into(
    dim_proto: onnx.TensorShapeProto.Dimension,
    dim: int | _protocols.SymbolicDimProtocol,
    denotation: str | None = None,
) -> None:
    """Write one dimension into a TensorShapeProto.Dimension.

    An integer is stored as ``dim_value``. A symbolic dimension is stored as
    ``dim_param`` when its value is not None; otherwise neither field is set.

    Args:
        dim_proto: The dimension proto to serialize into.
        dim: The dimension to serialize.
        denotation: The denotation to record on the dimension, if any.
    """
    if denotation:
        dim_proto.denotation = denotation
    if isinstance(dim, int):
        dim_proto.dim_value = dim
    elif isinstance(dim, (_core.SymbolicDim, _protocols.SymbolicDimProtocol)):
        if dim.value is not None:
            # TODO(justinchuby): None is probably not a valid value for dim_param
            dim_proto.dim_param = str(dim.value)