onnx_ir-0.1.15-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (53)
  1. onnx_ir/__init__.py +176 -0
  2. onnx_ir/_cloner.py +229 -0
  3. onnx_ir/_convenience/__init__.py +558 -0
  4. onnx_ir/_convenience/_constructors.py +291 -0
  5. onnx_ir/_convenience/_extractor.py +191 -0
  6. onnx_ir/_core.py +4435 -0
  7. onnx_ir/_display.py +54 -0
  8. onnx_ir/_enums.py +474 -0
  9. onnx_ir/_graph_comparison.py +23 -0
  10. onnx_ir/_graph_containers.py +373 -0
  11. onnx_ir/_io.py +133 -0
  12. onnx_ir/_linked_list.py +284 -0
  13. onnx_ir/_metadata.py +45 -0
  14. onnx_ir/_name_authority.py +72 -0
  15. onnx_ir/_polyfill.py +26 -0
  16. onnx_ir/_protocols.py +627 -0
  17. onnx_ir/_safetensors/__init__.py +510 -0
  18. onnx_ir/_tape.py +242 -0
  19. onnx_ir/_thirdparty/asciichartpy.py +310 -0
  20. onnx_ir/_type_casting.py +89 -0
  21. onnx_ir/_version_utils.py +48 -0
  22. onnx_ir/analysis/__init__.py +21 -0
  23. onnx_ir/analysis/_implicit_usage.py +74 -0
  24. onnx_ir/convenience.py +38 -0
  25. onnx_ir/external_data.py +459 -0
  26. onnx_ir/passes/__init__.py +41 -0
  27. onnx_ir/passes/_pass_infra.py +351 -0
  28. onnx_ir/passes/common/__init__.py +54 -0
  29. onnx_ir/passes/common/_c_api_utils.py +76 -0
  30. onnx_ir/passes/common/clear_metadata_and_docstring.py +60 -0
  31. onnx_ir/passes/common/common_subexpression_elimination.py +207 -0
  32. onnx_ir/passes/common/constant_manipulation.py +230 -0
  33. onnx_ir/passes/common/default_attributes.py +99 -0
  34. onnx_ir/passes/common/identity_elimination.py +120 -0
  35. onnx_ir/passes/common/initializer_deduplication.py +179 -0
  36. onnx_ir/passes/common/inliner.py +223 -0
  37. onnx_ir/passes/common/naming.py +280 -0
  38. onnx_ir/passes/common/onnx_checker.py +57 -0
  39. onnx_ir/passes/common/output_fix.py +141 -0
  40. onnx_ir/passes/common/shape_inference.py +112 -0
  41. onnx_ir/passes/common/topological_sort.py +37 -0
  42. onnx_ir/passes/common/unused_removal.py +215 -0
  43. onnx_ir/py.typed +1 -0
  44. onnx_ir/serde.py +2043 -0
  45. onnx_ir/tape.py +15 -0
  46. onnx_ir/tensor_adapters.py +210 -0
  47. onnx_ir/testing.py +197 -0
  48. onnx_ir/traversal.py +118 -0
  49. onnx_ir-0.1.15.dist-info/METADATA +68 -0
  50. onnx_ir-0.1.15.dist-info/RECORD +53 -0
  51. onnx_ir-0.1.15.dist-info/WHEEL +5 -0
  52. onnx_ir-0.1.15.dist-info/licenses/LICENSE +202 -0
  53. onnx_ir-0.1.15.dist-info/top_level.txt +1 -0
onnx_ir/_core.py ADDED
@@ -0,0 +1,4435 @@
# Copyright (c) ONNX Project Contributors
# SPDX-License-Identifier: Apache-2.0
"""Data structures for the intermediate representation."""

# NOTES for developers:
# NOTE: None of these classes will have a "to_onnx" or "from_protobuf" method because
# we cannot assume that the build toolchain has protoc installed, and we would like
# to keep this module protobuf-free. This way we separate the concerns of the IR
# and the serialization/deserialization.
#
# NOTE: Do not import pathlib in the IR. It is slow. Use os.path methods instead.

from __future__ import annotations

import abc
import contextlib
import dataclasses
import heapq
import logging
import math
import mmap
import os
import sys
import textwrap
import typing
from collections.abc import (
    Collection,
    Hashable,
    Iterable,
    Iterator,
    Mapping,
    MutableSequence,
    Sequence,
)
from collections.abc import (
    Set as AbstractSet,
)
from typing import (
    Any,
    Callable,
    ClassVar,
    Generic,
    NamedTuple,
    Protocol,
    SupportsInt,
    Union,
)

import ml_dtypes
import numpy as np
from typing_extensions import TypeIs, deprecated

import onnx_ir
from onnx_ir import (
    _display,
    _enums,
    _graph_containers,
    _linked_list,
    _metadata,
    _name_authority,
    _protocols,
    _type_casting,
)

if typing.TYPE_CHECKING:
    import numpy.typing as npt
    from typing_extensions import TypeGuard

TArrayCompatible = typing.TypeVar(
    "TArrayCompatible",
    bound=Union[_protocols.ArrayCompatible, _protocols.DLPackCompatible],
)

# System is little endian
_IS_LITTLE_ENDIAN = sys.byteorder == "little"
# Data types that are not supported by numpy
_NON_NUMPY_NATIVE_TYPES = frozenset(
    (
        _enums.DataType.BFLOAT16,
        _enums.DataType.FLOAT8E4M3FN,
        _enums.DataType.FLOAT8E4M3FNUZ,
        _enums.DataType.FLOAT8E5M2,
        _enums.DataType.FLOAT8E5M2FNUZ,
        _enums.DataType.FLOAT8E8M0,
        _enums.DataType.INT4,
        _enums.DataType.UINT4,
        _enums.DataType.FLOAT4E2M1,
        _enums.DataType.INT2,
        _enums.DataType.UINT2,
    )
)

logger = logging.getLogger(__name__)


def _compatible_with_numpy(obj: Any) -> TypeGuard[_protocols.ArrayCompatible]:
    """Use this function to check if an object is compatible with numpy.

    Avoid isinstance checks with the ArrayCompatible protocol for performance reasons.
    """
    return hasattr(obj, "__array__")


def _compatible_with_dlpack(obj: Any) -> TypeGuard[_protocols.DLPackCompatible]:
    """Use this function to check if an object is compatible with DLPack.

    Avoid isinstance checks with the DLPackCompatible protocol for performance reasons.
    """
    return hasattr(obj, "__dlpack__")


class TensorBase(abc.ABC, _protocols.TensorProtocol, _display.PrettyPrintable):
    """Shared convenience methods for classes implementing TensorProtocol."""

    __slots__ = (
        "_doc_string",
        "_metadata",
        "_metadata_props",
        "_name",
    )

    def __init__(
        self,
        name: str | None = None,
        doc_string: str | None = None,
        metadata_props: dict[str, str] | None = None,
    ) -> None:
        self._metadata: _metadata.MetadataStore | None = None
        self._metadata_props: dict[str, str] | None = metadata_props
        self._name: str | None = name
        self._doc_string: str | None = doc_string

    def _printable_type_shape(self) -> str:
        """Return a string representation of the shape and data type."""
        return f"{self.dtype},{self.shape}"

    def _repr_base(self) -> str:
        """Base string for the repr method.

        Example: Tensor<FLOAT,[5,42]>
        """
        return f"{self.__class__.__name__}<{self._printable_type_shape()}>"

    @property
    def name(self) -> str | None:
        """The name of the tensor."""
        return self._name

    @name.setter
    def name(self, value: str | None) -> None:
        self._name = value

    @property
    def doc_string(self) -> str | None:
        """The documentation string."""
        return self._doc_string

    @doc_string.setter
    def doc_string(self, value: str | None) -> None:
        self._doc_string = value

    @property
    def size(self) -> int:
        """The number of elements in the tensor."""
        return math.prod(self.shape.numpy())  # type: ignore[attr-defined]

    @property
    def nbytes(self) -> int:
        """The number of bytes in the tensor."""
        # Use math.ceil because when dtype is INT4, the itemsize is 0.5
        return math.ceil(self.dtype.itemsize * self.size)
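
    # A worked example of the ceiling above: an INT4 tensor with 5 elements has
    # itemsize 0.5, so nbytes == math.ceil(0.5 * 5) == 3 bytes.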

    @property
    def metadata_props(self) -> dict[str, str]:
        """The metadata properties of the tensor.

        The metadata properties are used to store additional information about the tensor.
        Unlike ``meta``, this property is serialized to the ONNX proto.
        """
        if self._metadata_props is None:
            self._metadata_props = {}
        return self._metadata_props

    @property
    def meta(self) -> _metadata.MetadataStore:
        """The metadata store for intermediate analysis.

        Write to the :attr:`metadata_props` if you would like the metadata to be serialized
        to the ONNX proto.
        """
        if self._metadata is None:
            self._metadata = _metadata.MetadataStore()
        return self._metadata

    def tofile(self, file) -> None:
        """Write the tensor to a binary file.

        This method writes the raw bytes of the tensor to a file-like object.
        The file-like object must have a ``write`` method that accepts bytes.

        .. versionadded:: 0.1.11

        Args:
            file: A file-like object with a ``write`` method that accepts bytes.
        """
        file.write(self.tobytes())

    def display(self, *, page: bool = False) -> None:
        rich = _display.require_rich()

        if rich is None:
            status_manager = contextlib.nullcontext()
        else:
            import rich.status  # type: ignore[import-not-found, no-redef] # pylint: disable=import-outside-toplevel

            status_manager = rich.status.Status(f"Computing tensor stats for {self!r}")

        from onnx_ir._thirdparty import (  # pylint: disable=import-outside-toplevel
            asciichartpy,
        )

        with status_manager:
            # Construct the text to display
            lines = []
            array = self.numpy().flatten()
            lines.append(repr(self))
            lines.append("")
            nan_values = np.isnan(array)
            nan_count = np.count_nonzero(nan_values)
            inf_count = np.count_nonzero(np.isinf(array))
            numbers = array[~nan_values]
            lines.append(
                f"Min: {np.min(numbers)}, Max: {np.max(numbers)}, "
                f"NaN count: {nan_count}, "
                f"Inf count: {inf_count}"
            )
            # Compute sparsity
            sparse_threshold = 1e-6
            # NOTE: count_nonzero() is faster than sum() for boolean arrays
            sparsity = np.count_nonzero(np.abs(array) < sparse_threshold) / array.size
            lines.append(f"Sparsity (abs<{sparse_threshold}): {sparsity:.2f}")

            # Compute histogram
            finite_numbers = array[np.isfinite(array)]
            lines.append("Histogram:")
            hist, bin_edges = np.histogram(finite_numbers, bins=80, density=False)
            lines.append(
                asciichartpy.plot(
                    hist, bin_edges=bin_edges, cfg={"height": 8, "format": "{:8.0f}"}
                )
            )

            text = "\n".join(lines)

        if rich is None:
            print(text)
        elif page:
            import rich.console  # type: ignore[import-not-found, no-redef] # pylint: disable=import-outside-toplevel

            console = rich.console.Console()
            with console.pager():
                console.print(text)
        else:
            rich.print(text)

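# A minimal usage sketch for ``tofile`` (illustrative only; uses the ``Tensor``
# class defined below and an in-memory buffer):
#
#     >>> import io
#     >>> t = Tensor(np.array([1, 2, 3], dtype=np.int64))
#     >>> buf = io.BytesIO()
#     >>> t.tofile(buf)
#     >>> len(buf.getvalue())  # 3 elements x 8 bytes each
#     24
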
def _check_numpy_representation_type(array: np.ndarray, dtype: _enums.DataType) -> None:
    """Check if the numpy array dtype matches the IR data type.

    When the dtype is not one of the numpy native dtypes, the value needs to be:

    - ``int8`` or ``uint8`` for int2, int4, with the sign bit extended to 8 bits.
    - ``uint8`` for uint2, uint4 or float4.
    - ``uint8`` for 8-bit data types.
    - ``uint16`` for bfloat16

    or corresponding dtypes from the ``ml_dtypes`` package.
    """
    if dtype in _NON_NUMPY_NATIVE_TYPES:
        if dtype.bitwidth == 16 and array.dtype not in (np.uint16, ml_dtypes.bfloat16):
            raise TypeError(
                f"The numpy array dtype must be uint16 or ml_dtypes.bfloat16 (not {array.dtype}) for IR data type {dtype}."
            )
        if dtype.bitwidth == 8 and array.dtype not in (
            np.uint8,
            ml_dtypes.float8_e4m3fnuz,
            ml_dtypes.float8_e4m3fn,
            ml_dtypes.float8_e5m2fnuz,
            ml_dtypes.float8_e5m2,
            ml_dtypes.float8_e8m0fnu,
        ):
            raise TypeError(
                f"The numpy array dtype must be uint8 or ml_dtypes.float8* (not {array.dtype}) for IR data type {dtype}."
            )
        if dtype == _enums.DataType.INT4:
            if array.dtype not in (np.int8, np.uint8, ml_dtypes.int4):
                raise TypeError(
                    f"The numpy array dtype must be int8 or uint8 or ml_dtypes.int4 (not {array.dtype}) for IR data type {dtype}."
                )
        if dtype == _enums.DataType.UINT4:
            if array.dtype not in (np.uint8, ml_dtypes.uint4):
                raise TypeError(
                    f"The numpy array dtype must be uint8 or ml_dtypes.uint4 (not {array.dtype}) for IR data type {dtype}."
                )
        if dtype == _enums.DataType.FLOAT4E2M1:
            if array.dtype not in (np.uint8, ml_dtypes.float4_e2m1fn):
                raise TypeError(
                    f"The numpy array dtype must be uint8 or ml_dtypes.float4_e2m1fn (not {array.dtype}) for IR data type {dtype}."
                )
        if dtype == _enums.DataType.INT2:
            if array.dtype not in (np.int8, np.uint8, ml_dtypes.int2):
                raise TypeError(
                    f"The numpy array dtype must be int8 or uint8 or ml_dtypes.int2 (not {array.dtype}) for IR data type {dtype}."
                )
        if dtype == _enums.DataType.UINT2:
            if array.dtype not in (np.uint8, ml_dtypes.uint2):
                raise TypeError(
                    f"The numpy array dtype must be uint8 or ml_dtypes.uint2 (not {array.dtype}) for IR data type {dtype}."
                )
        return

    try:
        dtype_numpy = _enums.DataType.from_numpy(array.dtype)
    except TypeError as e:
        raise TypeError(
            "Failed to convert the numpy dtype to an IR data type. "
            "If you are using a non-native dtype, be sure to specify the corresponding IR dtype when "
            "creating a Tensor."
        ) from e

    if dtype_numpy != dtype:
        raise TypeError(
            f"The numpy array dtype {array.dtype} does not match the IR data type {dtype}."
        )

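# An illustration of the rules above (illustrative sketch; 16256 is the uint16
# bit pattern of 1.0 in bfloat16):
#
#     >>> bits = np.array([16256], dtype=np.uint16)
#     >>> _check_numpy_representation_type(bits, _enums.DataType.BFLOAT16)  # passes
#     >>> _check_numpy_representation_type(bits, _enums.DataType.FLOAT)  # raises TypeError
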
def _maybe_view_np_array_with_ml_dtypes(
    array: np.ndarray, dtype: _enums.DataType
) -> np.ndarray:
    """Reinterpret the array when it is a bit representation of a dtype not supported by numpy.

    Args:
        array: The numpy array to reinterpret.
        dtype: The data type to reinterpret the array as.

    Returns:
        The array reinterpreted as the dtype.
    """
    if dtype == _enums.DataType.BFLOAT16:
        return array.view(ml_dtypes.bfloat16)
    if dtype == _enums.DataType.FLOAT8E4M3FN:
        return array.view(ml_dtypes.float8_e4m3fn)
    if dtype == _enums.DataType.FLOAT8E4M3FNUZ:
        return array.view(ml_dtypes.float8_e4m3fnuz)
    if dtype == _enums.DataType.FLOAT8E5M2:
        return array.view(ml_dtypes.float8_e5m2)
    if dtype == _enums.DataType.FLOAT8E5M2FNUZ:
        return array.view(ml_dtypes.float8_e5m2fnuz)
    if dtype == _enums.DataType.FLOAT8E8M0:
        return array.view(ml_dtypes.float8_e8m0fnu)
    if dtype == _enums.DataType.INT4:
        return array.view(ml_dtypes.int4)
    if dtype == _enums.DataType.UINT4:
        return array.view(ml_dtypes.uint4)
    if dtype == _enums.DataType.FLOAT4E2M1:
        return array.view(ml_dtypes.float4_e2m1fn)
    if dtype == _enums.DataType.INT2:
        return array.view(ml_dtypes.int2)
    if dtype == _enums.DataType.UINT2:
        return array.view(ml_dtypes.uint2)
    return array


def _supports_fileno(file: Any) -> bool:
    """Check if the file-like object supports fileno()."""
    if not hasattr(file, "fileno"):
        return False
    try:
        file.fileno()
    except Exception:  # pylint: disable=broad-except
        return False
    return True


def _create_np_array_for_byte_representation(tensor: Tensor) -> np.ndarray:
    """Create a numpy array for the byte representation of the tensor.

    This function is used for serializing the tensor to bytes. It handles the
    special cases for 2-bit and 4-bit data types and endianness.
    """
    array = tensor.numpy()
    if tensor.dtype in {
        _enums.DataType.INT4,
        _enums.DataType.UINT4,
        _enums.DataType.FLOAT4E2M1,
    }:
        # Pack the array into int4
        array = _type_casting.pack_4bitx2(array)
    elif tensor.dtype in {
        _enums.DataType.INT2,
        _enums.DataType.UINT2,
    }:
        # Pack the array into int2
        array = _type_casting.pack_2bitx4(array)
    else:
        assert tensor.dtype.itemsize == array.itemsize, "Bug: The itemsize should match"
        if not _IS_LITTLE_ENDIAN:
            array = array.astype(array.dtype.newbyteorder("<"))
    return array

class Tensor(TensorBase, _protocols.TensorProtocol, Generic[TArrayCompatible]):  # pylint: disable=too-many-ancestors
    """An immutable concrete tensor.

    This class is a wrapper around the raw tensor data. The raw tensor data can be a numpy array
    compatible object (e.g. ``np.ndarray``, ``torch.Tensor``) or a ``DLPack`` compatible object.
    The tensor is immutable and the data is not copied at initialization.

    To create a tensor from a numpy array::

        >>> import numpy as np
        >>> array = np.array([1, 2, 3])
        >>> tensor = Tensor(array)
        >>> # The tensor itself can be treated as a numpy array because it implements the __array__ method
        >>> np.allclose(tensor, array)
        True

    To get a numpy array from the tensor, call :meth:`numpy`. To convert the tensor
    to a byte string for serialization, call :meth:`tobytes`.

    It is recommended to check the size of the tensor first before accessing the
    underlying data, because accessing the data may be expensive and incur IO
    overhead.

    Subclass this class to efficiently handle different types of tensors from different frameworks.

    Attributes:
        name: The name of the tensor.
        shape: The shape of the tensor.
        dtype: The data type of the elements of the tensor. It is an :class:`ir.DataType` enum.
        doc_string: Documentation string.
        raw: The raw data behind this tensor. It can be anything.
        size: The number of elements in the tensor.
        nbytes: The number of bytes in the tensor.
        metadata_props: Metadata that will be serialized to the ONNX file.
        meta: Metadata store for graph transform passes.
    """

    __slots__ = (
        "_dtype",
        "_raw",
        "_shape",
    )

    def __init__(
        self,
        value: TArrayCompatible,
        dtype: _enums.DataType | None = None,
        *,
        shape: Shape | None = None,
        name: str | None = None,
        doc_string: str | None = None,
        metadata_props: dict[str, str] | None = None,
    ) -> None:
        """Initialize a tensor.

        Args:
            value: The backing data of the tensor. It can be a numpy array compatible object or a DLPack compatible object.
                When the dtype is not one of the numpy native dtypes, the value can
                be ``uint8`` (unpacked) or ml_dtypes types for 4-bit and 8-bit data types,
                and ``uint16`` or ``ml_dtypes.bfloat16`` for bfloat16 when the value is a numpy array;
                ``dtype`` must be specified in this case.
            dtype: The data type of the tensor. It can be None only when value is a numpy array.
                Users are responsible for making sure the dtype matches the value when value is not a numpy array.
            shape: The shape of the tensor. If None, the shape is obtained from the value.
            name: The name of the tensor.
            doc_string: The documentation string.
            metadata_props: The metadata properties.

        Raises:
            TypeError: If the value is not a numpy array compatible or a DLPack compatible object.
            TypeError: If the value is a numpy array and the dtype is specified but does not match the dtype of the array.
            ValueError: If the shape is not specified and the value does not have a shape attribute.
            ValueError: If the dtype is not specified and the value is not a numpy array.
        """
        super().__init__(name=name, doc_string=doc_string, metadata_props=metadata_props)
        # NOTE: We should not do any copying here for performance reasons
        if not _compatible_with_numpy(value) and not _compatible_with_dlpack(value):
            raise TypeError(f"Expected an array compatible object, got {type(value)}")
        if shape is None:
            # Obtain the shape from the value
            if not hasattr(value, "shape"):
                raise ValueError(
                    f"Expected an object with a shape attribute, but {type(value)} does not have shape. "
                    "Please specify the shape explicitly."
                )
            self._shape = Shape(getattr(value, "shape"), frozen=True)  # noqa: B009
        else:
            self._shape = shape
            self._shape.freeze()
        if isinstance(value, np.generic):
            # Turn numpy scalar into a numpy array
            value = np.array(value)  # type: ignore[assignment]
        if dtype is None:
            if isinstance(value, np.ndarray):
                self._dtype = _enums.DataType.from_numpy(value.dtype)
            else:
                raise ValueError(
                    "The dtype must be specified when the value is not a numpy array. "
                    f"Value type: {type(value)}"
                )
        else:
            if isinstance(value, np.ndarray):
                # Make sure the dtype matches the value
                _check_numpy_representation_type(value, dtype)
            # Users are responsible for making sure the dtype matches the value
            # when value is not a numpy array
            self._dtype = dtype

        # View the bfloat16, float8 and int2, int4 types using ml_dtypes
        if isinstance(value, np.ndarray):
            value = _maybe_view_np_array_with_ml_dtypes(value, self._dtype)  # type: ignore[assignment]

        self._raw = value

    def __array__(self, dtype: Any = None) -> np.ndarray:
        if isinstance(self._raw, np.ndarray) or _compatible_with_numpy(self._raw):
            return self._raw.__array__(dtype)
        assert _compatible_with_dlpack(self._raw), (
            f"Bug: Expected DLPack or Numpy compatible objects, got {type(self._raw)}"
        )
        return np.from_dlpack(self._raw)

    def __dlpack__(self, *, stream: Any = None) -> Any:
        if _compatible_with_dlpack(self._raw):
            return self._raw.__dlpack__(stream=stream)
        return self.__array__().__dlpack__(stream=stream)

    def __dlpack_device__(self) -> tuple[int, int]:
        if _compatible_with_dlpack(self._raw):
            return self._raw.__dlpack_device__()
        return self.__array__().__dlpack_device__()

    def __repr__(self) -> str:
        # Avoid multi-line repr
        tensor_lines = repr(self._raw).split("\n")
        tensor_text = " ".join(line.strip() for line in tensor_lines)
        return f"{self._repr_base()}({tensor_text}, name={self.name!r})"

    @property
    def dtype(self) -> _enums.DataType:
        """The data type of the tensor. Immutable."""
        return self._dtype

    @property
    def shape(self) -> Shape:
        """The shape of the tensor. Immutable."""
        return self._shape

    @property
    def raw(self) -> TArrayCompatible:
        """Backing data of the tensor. Immutable."""
        return self._raw  # type: ignore[return-value]

    def numpy(self) -> np.ndarray:
        """Return the tensor as a numpy array.

        When the data type is not supported by numpy, the dtypes from the ``ml_dtypes``
        package are used. The values can be reinterpreted as bit representations
        using the ``.view()`` method.
        """
        if isinstance(self._raw, np.ndarray):
            return self._raw
        # We do not cache the value to save memory
        return self.__array__()

    def tobytes(self) -> bytes:
        """Returns the value as bytes encoded in little endian.

        Override this method for more efficient serialization when the raw
        value is not a numpy array.
        """
        # TODO(justinchuby): Support DLPack
        array = _create_np_array_for_byte_representation(self)
        return array.tobytes()

    def tofile(self, file) -> None:
        """Write the tensor to a binary file.

        .. versionadded:: 0.1.11

        Args:
            file: A file-like object with a ``write`` method that accepts bytes, or has a ``fileno()`` method.
        """
        if isinstance(self._raw, np.ndarray) and _supports_fileno(file):
            # This is a duplication of tobytes() for handling special cases
            array = _create_np_array_for_byte_representation(self)
            array.tofile(file)
        else:
            file.write(self.tobytes())

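# A minimal sketch of creating a tensor whose dtype numpy does not support
# natively: back a BFLOAT16 tensor with its uint16 bit pattern and pass the IR
# dtype explicitly (assumes ml_dtypes, imported above):
#
#     >>> bits = np.array([16256, 49088], dtype=np.uint16)  # 1.0 and -1.5 in bfloat16
#     >>> t = Tensor(bits, dtype=_enums.DataType.BFLOAT16)
#     >>> t.numpy().tolist()  # viewed through ml_dtypes.bfloat16
#     [1.0, -1.5]
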
class ExternalTensor(TensorBase, _protocols.TensorProtocol):  # pylint: disable=too-many-ancestors
    """An immutable concrete tensor with its data stored on disk.

    This class uses memory mapping to avoid loading the tensor into memory,
    when the data type is supported by numpy. Otherwise, the tensor is loaded
    into memory lazily when accessed.

    Calling :attr:`shape` does not incur IO. Checking shape before loading
    the tensor is recommended if IO overhead and memory usage are a concern.

    To obtain an array, call :meth:`numpy`. To obtain the bytes,
    call :meth:`tobytes`. To write the data to a file, call :meth:`tofile`.

    The :attr:`location` must be a relative path conforming to the ONNX
    specification. Given the correct :attr:`base_dir`, the :attr:`path` is computed
    to be the full path to the data file. Users should expect that the :attr:`path`
    always leads to the correct file. At initialization, paths are not checked.
    It is the user's responsibility to ensure the paths are valid and accessible.

    Attributes:
        location: The location of the data file. It is the path relative to the base directory.
        base_dir: The base directory for the external data. It is used to resolve relative paths.
            At serialization, only the :attr:`location` is serialized into the "location" field of the ``TensorProto``.
        path: The path to the data file. This is computed by joining :attr:`base_dir` and :attr:`location`.
        offset: The offset in bytes from the start of the file.
        length: The length of the data in bytes.
        dtype: The data type of the tensor.
        shape: The shape of the tensor.
        name: The name of the tensor. It must be specified.
        doc_string: The documentation string.
        metadata_props: The metadata properties.
    """

    __slots__ = (
        "_array",
        "_base_dir",
        "_dtype",
        "_length",
        "_location",
        "_offset",
        "_shape",
        "_valid",
        "raw",
    )

    def __init__(
        self,
        location: os.PathLike | str,
        offset: int | None,
        length: int | None,
        dtype: _enums.DataType,
        *,
        shape: Shape,
        name: str,
        doc_string: str | None = None,
        metadata_props: dict[str, str] | None = None,
        base_dir: os.PathLike | str = "",
    ) -> None:
        """Initialize an external tensor.

        Args:
            location: The location of the data file. It is the path relative to the base directory.
            offset: The offset in bytes from the start of the file.
            length: The length of the data in bytes.
            dtype: The data type of the tensor.
            shape: The shape of the tensor.
            name: The name of the tensor.
            doc_string: The documentation string.
            metadata_props: The metadata properties.
            base_dir: The base directory for the external data. It is used to resolve relative paths.
        """
        super().__init__(name=name, doc_string=doc_string, metadata_props=metadata_props)
        # NOTE: Do not verify the location by default. This is because the location field
        # in the tensor proto can be anything and we would like deserialization from
        # proto to IR to not fail.
        if onnx_ir.DEBUG:
            if os.path.isabs(location):
                raise ValueError(
                    "The location must be a relative path. Please specify base_dir as well."
                )
        self._location = location
        self._base_dir = base_dir
        self._offset: int | None = offset
        self._length: int | None = length
        self._dtype: _enums.DataType = dtype
        self.name: str = name  # mutable
        self._shape: Shape = shape
        self._shape.freeze()
        self.doc_string: str | None = doc_string  # mutable
        self._array: np.ndarray | None = None
        self.raw: mmap.mmap | None = None
        self._metadata_props = metadata_props
        self._metadata: _metadata.MetadataStore | None = None
        self._valid = True

    @property
    def base_dir(self) -> str | os.PathLike:
        # Mutable
        return self._base_dir

    @base_dir.setter
    def base_dir(self, value: str | os.PathLike) -> None:
        self._base_dir = value

    @property
    def location(self) -> str | os.PathLike:
        # Immutable
        return self._location

    @property
    def path(self) -> str:
        # Immutable, computed
        return os.path.join(self._base_dir, self._location)

    @property
    def offset(self) -> int | None:
        # Immutable
        return self._offset

    @property
    def length(self) -> int | None:
        # Immutable
        return self._length

    @property
    def dtype(self) -> _enums.DataType:
        # Immutable
        return self._dtype

    @property
    def shape(self) -> Shape:
        # Immutable
        return self._shape

    def _load(self):
        self._check_validity()
        assert self._array is None, "Bug: The array should be loaded only once."
        if self.size == 0:
            # When the size is 0, mmap is impossible and meaningless
            self._array = np.empty(self.shape.numpy(), dtype=self.dtype.numpy())
            return
        # Map the whole file into the memory
        with open(self.path, "rb") as f:
            self.raw = mmap.mmap(
                f.fileno(),
                0,
                access=mmap.ACCESS_READ,
            )

        if self.dtype in {
            _enums.DataType.INT4,
            _enums.DataType.UINT4,
            _enums.DataType.FLOAT4E2M1,
            _enums.DataType.INT2,
            _enums.DataType.UINT2,
        }:
            # Use uint8 to read in the full byte. Otherwise ml_dtypes.int4 will clip the values
            # No need to set endianness for uint8
            dt = np.dtype(np.uint8)
            count = self.size // 2 + self.size % 2
        else:
            # Handle the byte order correctly by always using little endian
            dt = np.dtype(self.dtype.numpy()).newbyteorder("<")
            count = self.size

        self._array = np.frombuffer(self.raw, dtype=dt, offset=self.offset or 0, count=count)
        shape = self.shape.numpy()

        if self.dtype.bitwidth == 4:
            # Unpack the 4bit arrays
            self._array = _type_casting.unpack_4bitx2(self._array, shape).view(
                self.dtype.numpy()
            )
        elif self.dtype.bitwidth == 2:
            # Unpack the 2bit arrays
            self._array = _type_casting.unpack_2bitx4(self._array, shape).view(
                self.dtype.numpy()
            )
        else:
            self._array = self._array.reshape(shape)

    def __array__(self, dtype: Any = None) -> np.ndarray:
        self._check_validity()
        if self._array is None:
            self._load()
        assert self._array is not None
        return self._array.__array__(dtype)

    def __dlpack__(self, *, stream: Any = None) -> Any:
        raise NotImplementedError(
            "ExternalTensor does not support DLPack because it uses memory mapping. "
            "Call numpy() to get a numpy array instead."
        )

    def __dlpack_device__(self) -> tuple[int, int]:
        raise NotImplementedError(
            "ExternalTensor does not support DLPack because it uses memory mapping. "
            "Call numpy() to get a numpy array instead."
        )

    def __repr__(self) -> str:
        return (
            f"{self._repr_base()}(location='{self.location}', name={self.name!r}, "
            f"offset={self.offset!r}, length={self.length!r}, base_dir={self.base_dir!r})"
        )

    def numpy(self) -> np.ndarray:
        """Return the tensor as a numpy array.

        The data is memory mapped from disk and does not take up additional physical memory.
        """
        self._check_validity()
        if self._array is None:
            self._load()
        assert self._array is not None
        return self._array

    def tobytes(self) -> bytes:
        """Return the bytes of the tensor.

        This will load the tensor into memory.
        """
        self._check_validity()
        if self.raw is None:
            self._load()
        assert self.raw is not None
        offset = self._offset or 0
        length = self._length or self.nbytes
        return self.raw[offset : offset + length]

    def tofile(self, file) -> None:
        self._check_validity()
        with open(self.path, "rb") as src:
            if self._offset is not None:
                src.seek(self._offset)
            bytes_to_copy = self._length or self.nbytes
            chunk_size = 1024 * 1024  # 1MB
            while bytes_to_copy > 0:
                chunk = src.read(min(chunk_size, bytes_to_copy))
                file.write(chunk)
                bytes_to_copy -= len(chunk)

    def valid(self) -> bool:
        """Check if the tensor is valid.

        The external tensor is valid if it has not been invalidated.
        """
        return self._valid

    def _check_validity(self) -> None:
        if not self.valid():
            raise ValueError(
                f"The external tensor '{self!r}' is invalidated. The data may be corrupted or deleted."
            )

    def invalidate(self) -> None:
        """Invalidate the tensor.

        The external tensor is invalidated when the data is known to be corrupted or deleted.
        """
        self._valid = False

    def release(self) -> None:
        """Delete all references to the memory buffer and close the memory-mapped file."""
        self._array = None
        if self.raw is not None:
            self.raw.close()
            self.raw = None

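# A minimal usage sketch (the paths here are hypothetical; the data file would
# normally be written by the utilities in ``onnx_ir.external_data``):
#
#     >>> t = ExternalTensor(
#     ...     location="model.onnx.data",  # relative, per the ONNX spec
#     ...     offset=0,
#     ...     length=None,
#     ...     dtype=_enums.DataType.FLOAT,
#     ...     shape=Shape([2, 3]),
#     ...     name="weight",
#     ...     base_dir="/path/to/model_dir",
#     ... )
#     >>> t.shape  # inspecting the shape incurs no IO
#     Shape([2, 3])
#     >>> # t.numpy() would memory-map /path/to/model_dir/model.onnx.data on first access
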
class StringTensor(TensorBase, _protocols.TensorProtocol):  # pylint: disable=too-many-ancestors
    """Multidimensional array of strings (as binary data to match the string_data field in TensorProto)."""

    __slots__ = (
        "_raw",
        "_shape",
    )

    def __init__(
        self,
        value: Sequence[bytes] | npt.NDArray[np.bytes_],
        *,
        shape: Shape | None = None,
        name: str | None = None,
        doc_string: str | None = None,
        metadata_props: dict[str, str] | None = None,
    ) -> None:
        """Initialize a tensor.

        Args:
            value: The backing data of the tensor. It can be a numpy array or a Sequence of bytes.
            shape: The shape of the tensor. If None, the shape is obtained from the value.
            name: The name of the tensor.
            doc_string: The documentation string.
            metadata_props: The metadata properties.
        """
        super().__init__(name=name, doc_string=doc_string, metadata_props=metadata_props)
        if shape is None:
            if not hasattr(value, "shape"):
                raise ValueError(
                    f"Expected an object with a shape attribute, but {type(value)} does not have shape. "
                    "Please specify the shape explicitly."
                )
            self._shape = Shape(getattr(value, "shape"), frozen=True)  # noqa: B009
        else:
            self._shape = shape
            self._shape.freeze()
        self._raw = value

    def __array__(self, dtype: Any = None) -> np.ndarray:
        if isinstance(self._raw, np.ndarray):
            return self._raw
        assert isinstance(self._raw, Sequence), (
            f"Bug: Expected a sequence, got {type(self._raw)}"
        )
        return np.array(self._raw, dtype=dtype).reshape(self.shape.numpy())

    def __dlpack__(self, *, stream: Any = None) -> Any:
        del stream  # unused
        raise TypeError("StringTensor does not support DLPack")

    def __dlpack_device__(self) -> tuple[int, int]:
        raise TypeError("StringTensor does not support DLPack")

    def __repr__(self) -> str:
        return f"{self._repr_base()}({self._raw!r}, name={self.name!r})"

    @property
    def dtype(self) -> _enums.DataType:
        """The data type of the tensor. Immutable."""
        return _enums.DataType.STRING

    @property
    def shape(self) -> Shape:
        """The shape of the tensor. Immutable."""
        return self._shape

    @property
    def nbytes(self) -> int:
        """The number of bytes in the tensor."""
        return sum(len(string) for string in self.string_data())

    @property
    def raw(self) -> Sequence[bytes] | npt.NDArray[np.bytes_]:
        """Backing data of the tensor. Immutable."""
        return self._raw  # type: ignore[return-value]

    def numpy(self) -> npt.NDArray[np.bytes_]:
        """Return the tensor as a numpy array."""
        return self.__array__()

    def tobytes(self) -> bytes:
        raise ValueError("StringTensor does not support tobytes. Use 'string_data' instead.")

    def string_data(self) -> Sequence[bytes]:
        """Return the string data of the tensor."""
        if isinstance(self._raw, np.ndarray):
            return self._raw.flatten().tolist()
        return self._raw

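# A minimal usage sketch:
#
#     >>> st = StringTensor([b"hello", b"world"], shape=Shape([2]), name="labels")
#     >>> st.string_data()
#     [b'hello', b'world']
#     >>> st.nbytes  # total length of the encoded strings
#     10
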
class LazyTensor(TensorBase, _protocols.TensorProtocol):  # pylint: disable=too-many-ancestors
    """A tensor that lazily evaluates a function to get the actual tensor.

    This class takes a function returning an `ir.TensorProtocol`, a dtype, and a shape argument.
    The function is lazily evaluated to get the actual tensor when `tobytes()` or `numpy()` is called.

    Example::

        >>> import numpy as np
        >>> import onnx_ir as ir
        >>> weights = np.array([[1, 2, 3]])
        >>> def create_tensor():  # Delay applying transformations to the weights
        ...     weights_t = weights.transpose()
        ...     return ir.tensor(weights_t)
        >>> lazy_tensor = ir.LazyTensor(create_tensor, dtype=ir.DataType.INT64, shape=ir.Shape([3, 1]))
        >>> print(lazy_tensor.numpy())
        [[1]
         [2]
         [3]]

    Attributes:
        func: The function that returns the actual tensor.
        dtype: The data type of the tensor.
        shape: The shape of the tensor.
        cache: Whether to cache the result of the function. If False,
            the function is called every time the tensor content is accessed.
            If True, the function is called only once and the result is cached in memory.
            Default is False.
        name: The name of the tensor.
        doc_string: The documentation string.
        metadata_props: The metadata properties.
    """

    __slots__ = (
        "_dtype",
        "_func",
        "_shape",
        "_tensor",
        "cache",
    )

    def __init__(
        self,
        func: Callable[[], _protocols.TensorProtocol],
        dtype: _enums.DataType,
        shape: Shape,
        *,
        cache: bool = False,
        name: str | None = None,
        doc_string: str | None = None,
        metadata_props: dict[str, str] | None = None,
    ) -> None:
        """Initialize a lazy tensor.

        Args:
            func: The function that returns the actual tensor.
            dtype: The data type of the tensor.
            shape: The shape of the tensor.
            cache: Whether to cache the result of the function.
            name: The name of the tensor.
            doc_string: The documentation string.
            metadata_props: The metadata properties.
        """
        super().__init__(name=name, doc_string=doc_string, metadata_props=metadata_props)
        self._func = func
        self._dtype = dtype
        self._shape = shape
        self._tensor: _protocols.TensorProtocol | None = None
        self.cache = cache

    def _evaluate(self) -> _protocols.TensorProtocol:
        """Evaluate the function to get the actual tensor."""
        if not self.cache:
            return self._func()

        # Cache the tensor
        if self._tensor is None:
            self._tensor = self._func()
        return self._tensor

    def __array__(self, dtype: Any = None) -> np.ndarray:
        return self._evaluate().__array__(dtype)

    def __dlpack__(self, *, stream: Any = None) -> Any:
        return self._evaluate().__dlpack__(stream=stream)

    def __dlpack_device__(self) -> tuple[int, int]:
        return self._evaluate().__dlpack_device__()

    def __repr__(self) -> str:
        return f"{self._repr_base()}(func={self._func!r}, name={self.name!r})"

    @property
    def raw(self) -> Callable[[], _protocols.TensorProtocol]:
        return self._func

    @property
    def dtype(self) -> _enums.DataType:
        """The data type of the tensor. Immutable."""
        return self._dtype

    @property
    def shape(self) -> Shape:
        """The shape of the tensor. Immutable."""
        return self._shape

    def numpy(self) -> np.ndarray:
        """Return the tensor as a numpy array."""
        return self._evaluate().numpy()

    def tobytes(self) -> bytes:
        """Return the bytes of the tensor."""
        return self._evaluate().tobytes()

    def tofile(self, file) -> None:
        tensor = self._evaluate()
        if hasattr(tensor, "tofile"):
            # Some existing implementations of TensorProtocol
            # may not have tofile() as it was introduced in v0.1.11
            tensor.tofile(file)
        else:
            super().tofile(file)

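# A small sketch of the ``cache`` flag: with ``cache=True`` the callable runs
# once and the result is reused (illustrative):
#
#     >>> calls = []
#     >>> def make():
#     ...     calls.append(1)
#     ...     return Tensor(np.zeros((2,), dtype=np.float32))
#     >>> lt = LazyTensor(make, _enums.DataType.FLOAT, Shape([2]), cache=True)
#     >>> _ = lt.numpy(); _ = lt.numpy()
#     >>> len(calls)
#     1
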
class PackedTensor(TensorBase, _protocols.TensorProtocol, Generic[TArrayCompatible]):  # pylint: disable=too-many-ancestors
    """A tensor that stores 2-bit and 4-bit data types in packed format.

    .. versionadded:: 0.1.2
    """

    __slots__ = (
        "_dtype",
        "_raw",
        "_shape",
    )

    def __init__(
        self,
        value: TArrayCompatible,
        dtype: _enums.DataType,
        *,
        shape: Shape | Sequence[int],
        name: str | None = None,
        doc_string: str | None = None,
        metadata_props: dict[str, str] | None = None,
    ) -> None:
        """Initialize a tensor.

        Args:
            value: The backing data of the tensor. It can be a numpy array compatible object or a DLPack compatible object.
                The value MUST be packed in an integer dtype.
            dtype: The data type of the tensor. Must be one of INT2, UINT2, INT4, UINT4, FLOAT4E2M1.
            shape: The shape of the tensor.
            name: The name of the tensor.
            doc_string: The documentation string.
            metadata_props: The metadata properties.

        Raises:
            TypeError: If the value is not a numpy array compatible or a DLPack compatible object.
            TypeError: If the value is a numpy array and the dtype is not uint8 or one of the ml_dtypes dtypes.
        """
        super().__init__(name=name, doc_string=doc_string, metadata_props=metadata_props)
        if not _compatible_with_numpy(value) and not _compatible_with_dlpack(value):
            raise TypeError(f"Expected an array compatible object, got {type(value)}")
        self._shape = Shape(shape)
        self._shape.freeze()
        if dtype.bitwidth not in (2, 4):
            raise TypeError(
                f"PackedTensor only supports INT2, UINT2, INT4, UINT4, FLOAT4E2M1, but got {dtype}"
            )
        self._dtype = dtype
        self._raw = value

        if isinstance(value, np.ndarray):
            if (
                value.dtype == ml_dtypes.float4_e2m1fn
                or value.dtype == ml_dtypes.uint4
                or value.dtype == ml_dtypes.int4
                or value.dtype == ml_dtypes.uint2
                or value.dtype == ml_dtypes.int2
            ):
                raise TypeError(
                    f"PackedTensor expects the value to be packed, but got {value.dtype} which is not packed. "
                    "Please pack the value or use `onnx_ir.Tensor`."
                )
            # Check after shape and dtype is set
            if value.size != self.nbytes:
                raise ValueError(
                    f"Expected the packed array to be {self.nbytes} bytes (from shape {self.shape}), but got {value.nbytes} bytes"
                )

    def __array__(self, dtype: Any = None, copy: bool = False) -> np.ndarray:
        return self.numpy()

    def __dlpack__(self, *, stream: Any = None) -> Any:
        if _compatible_with_dlpack(self._raw):
            return self._raw.__dlpack__(stream=stream)
        return self.__array__().__dlpack__(stream=stream)

    def __dlpack_device__(self) -> tuple[int, int]:
        if _compatible_with_dlpack(self._raw):
            return self._raw.__dlpack_device__()
        return self.__array__().__dlpack_device__()

    def __repr__(self) -> str:
        return f"{self._repr_base()}({self._raw!r}, name={self.name!r})"

    @property
    def dtype(self) -> _enums.DataType:
        """The data type of the tensor. Immutable."""
        return self._dtype

    @property
    def shape(self) -> Shape:
        """The shape of the tensor. Immutable."""
        return self._shape

    @property
    def raw(self) -> TArrayCompatible:
        """Backing data of the tensor. Immutable."""
        return self._raw  # type: ignore[return-value]

    def numpy(self) -> np.ndarray:
        """Return the tensor as a numpy array.

        When the data type is not supported by numpy, the dtypes from the ``ml_dtypes``
        package are used. The values can be reinterpreted as bit representations
        using the ``.view()`` method.
        """
        array = self.numpy_packed()
        # ONNX IR returns the unpacked arrays
        return _type_casting.unpack_4bitx2(array, self.shape.numpy()).view(self.dtype.numpy())

    def numpy_packed(self) -> npt.NDArray[np.uint8]:
        """Return the tensor as a packed array."""
        if isinstance(self._raw, np.ndarray) or _compatible_with_numpy(self._raw):
            array = np.asarray(self._raw)
        else:
            assert _compatible_with_dlpack(self._raw), (
                f"Bug: Expected DLPack or Numpy compatible objects, got {type(self._raw)}"
            )
            array = np.from_dlpack(self._raw)
        if array.nbytes != self.nbytes:
            raise ValueError(
                f"Expected the packed array to be {self.nbytes} bytes (from shape {self.shape}), but got {array.nbytes} bytes"
            )
        return array.view(np.uint8)

    def tobytes(self) -> bytes:
        """Returns the value as bytes encoded in little endian.

        Override this method for more efficient serialization when the raw
        value is not a numpy array.
        """
        array = self.numpy_packed()
        if not _IS_LITTLE_ENDIAN:
            array = array.astype(array.dtype.newbyteorder("<"))
        return array.tobytes()

    def tofile(self, file) -> None:
        """Write the tensor to a binary file.

        .. versionadded:: 0.1.11

        Args:
            file: A file-like object with a ``write`` method that accepts bytes, or has a ``fileno()`` method.
        """
        if _supports_fileno(file):
            # This is a duplication of tobytes() for handling edge cases
            array = self.numpy_packed()
            if not _IS_LITTLE_ENDIAN:
                array = array.astype(array.dtype.newbyteorder("<"))
            array.tofile(file)
        else:
            file.write(self.tobytes())

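# A minimal usage sketch, assuming the ONNX packing convention in which the
# first element of each pair occupies the low nibble of a byte:
#
#     >>> packed = np.array([0x21, 0x43], dtype=np.uint8)  # nibble pairs (1, 2) and (3, 4)
#     >>> pt = PackedTensor(packed, _enums.DataType.UINT4, shape=[4])
#     >>> pt.numpy().tolist()
#     [1, 2, 3, 4]
#     >>> pt.tobytes()
#     b'!C'
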
class SymbolicDim(_protocols.SymbolicDimProtocol, _display.PrettyPrintable):
    """Immutable symbolic dimension that can be shared across multiple shapes.

    SymbolicDim is used to represent a symbolic (non-integer) dimension in a tensor shape.
    It is immutable and can be compared or hashed.
    """

    __slots__ = ("_value",)

    def __init__(self, value: str | None) -> None:
        """Initialize a symbolic dimension.

        Args:
            value: The value of the dimension. It should not be an int.

        Raises:
            TypeError: If value is an int.
        """
        if isinstance(value, int):
            raise TypeError(
                "The value of a SymbolicDim cannot be an int. "
                "If you are creating a Shape, use int directly instead of SymbolicDim."
            )
        self._value = value

    def __eq__(self, other: object) -> bool:
        """Check equality with another SymbolicDim or string/None."""
        if not isinstance(other, SymbolicDim):
            return self.value == other
        return self.value == other.value

    def __hash__(self) -> int:
        """Return the hash of the symbolic dimension value."""
        return hash(self.value)

    @property
    def value(self) -> str | None:
        """The value of the symbolic dimension (string or None)."""
        return self._value

    def __str__(self) -> str:
        return f"{self._value}"

    def __repr__(self) -> str:
        return f"{self.__class__.__name__}({self._value})"

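# A small sketch of SymbolicDim semantics:
#
#     >>> dim = SymbolicDim("batch")
#     >>> dim == "batch"  # compares equal to the plain string
#     True
#     >>> dim == SymbolicDim("batch")
#     True
#     >>> SymbolicDim(None) == None  # anonymous (unknown) dimension
#     True
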
def _is_int_compatible(value: object) -> TypeIs[SupportsInt]:
    """Check if the value is compatible with int (i.e., can be safely cast to int).

    Args:
        value: The value to check.

    Returns:
        True if the value is an int or has an __int__ method, False otherwise.
    """
    if isinstance(value, int):
        return True
    if hasattr(value, "__int__"):
        # For performance reasons, we do not use isinstance(value, SupportsInt)
        return True
    return False


def _maybe_convert_to_symbolic_dim(
    dim: int | SupportsInt | SymbolicDim | str | None,
) -> SymbolicDim | int:
    """Convert the value to a SymbolicDim if it is not an int.

    Args:
        dim: The dimension value, which can be int, str, None, or SymbolicDim.

    Returns:
        An int or SymbolicDim instance.

    Raises:
        TypeError: If the value is not int, str, None, or SymbolicDim.
    """
    if dim is None or isinstance(dim, str):
        return SymbolicDim(dim)
    if _is_int_compatible(dim):
        return int(dim)
    if isinstance(dim, SymbolicDim):
        return dim
    raise TypeError(
        f"Expected int, str, None or SymbolicDim, but value {dim!r} has type '{type(dim)}'"
    )

class Shape(_protocols.ShapeProtocol, _display.PrettyPrintable):
    """Represents the shape of a tensor, including its dimensions and optional denotations.

    The :class:`Shape` class stores the dimensions of a tensor, which can be integers, None (unknown), or
    symbolic dimensions. It provides methods for querying and manipulating the shape, as well as for comparing
    shapes to other shapes or plain Python lists.

    A shape can be frozen (made immutable). When the shape is frozen, it cannot be
    unfrozen, making it suitable to be shared across tensors or values.
    Call :meth:`freeze` to freeze the shape.

    To update the dimension of a frozen shape, call :meth:`copy` to create a
    new shape with the same dimensions that can be modified.

    Use :meth:`get_denotation` and :meth:`set_denotation` to access and modify the denotations.

    .. note::
        Two shapes can be compared for equality. Be careful when comparing shapes with
        unknown dimensions (``None``), as they may not be considered semantically equal
        even if all dimensions are the same. You can use :meth:`has_unknown_dim` to
        check if a shape has any unknown dimensions.

    Example::

        >>> import onnx_ir as ir
        >>> shape = ir.Shape(["B", None, 3])
        >>> shape.rank()
        3
        >>> shape.is_static()
        False
        >>> shape.is_dynamic()
        True
        >>> shape.is_static(dim=2)
        True
        >>> shape[0] = 1
        >>> shape[1] = 2
        >>> shape.dims
        (1, 2, 3)
        >>> shape == [1, 2, 3]
        True
        >>> shape.frozen
        False
        >>> shape.freeze()
        >>> shape.frozen
        True

    Attributes:
        dims: A tuple of dimensions representing the shape.
            Each dimension can be an integer, None, or a :class:`SymbolicDim`.
        frozen: Indicates whether the shape is immutable. When frozen, the shape
            cannot be modified or unfrozen.
    """

    __slots__ = ("_denotations", "_dims", "_frozen")
1384
+
1385
+ def __init__(
1386
+ self,
1387
+ dims: Iterable[int | SupportsInt | SymbolicDim | str | None],
1388
+ /,
1389
+ denotations: Iterable[str | None] | None = None,
1390
+ frozen: bool = False,
1391
+ ) -> None:
1392
+ """Initialize a shape.
1393
+
1394
+ Args:
1395
+ dims: The dimensions of the shape. Each dimension can be an integer or a
1396
+ SymbolicDim or any Python object. When a ``dim`` is not an integer or a
1397
+ SymbolicDim, it is converted to a SymbolicDim.
1398
+ denotations: The denotations of the dimensions. If None, the denotations are not set.
1399
+ Standard denotation can optionally be used to denote tensor
1400
+ dimensions with standard semantic descriptions to ensure
1401
+ that operations are applied to the correct axis of a tensor.
1402
+ Refer to https://github.com/onnx/onnx/blob/main/docs/DimensionDenotation.md#denotation-definition
1403
+ for pre-defined dimension denotations.
1404
+ frozen: If True, the shape is immutable and cannot be modified. This
1405
+ is useful when the shape is initialized by a Tensor or when the shape
1406
+ is shared across multiple tensors. The default is False.
1407
+ """
1408
+ self._dims: list[int | SymbolicDim] = [
1409
+ _maybe_convert_to_symbolic_dim(dim) for dim in dims
1410
+ ]
1411
+ self._denotations: list[str | None] = (
1412
+ list(denotations) if denotations is not None else [None] * len(self._dims)
1413
+ )
1414
+ if len(self._denotations) != len(self._dims):
1415
+ raise ValueError(
1416
+ "The number of denotations, when provided, must be equal to the number of dimensions."
1417
+ )
1418
+ self._frozen: bool = frozen
1419
+
1420
+ @property
1421
+ def dims(self) -> tuple[int | SymbolicDim, ...]:
1422
+ """All dimensions in the shape.
1423
+
1424
+ This property is read-only. Use __getitem__ and __setitem__ to modify the shape or create a new shape.
1425
+ """
1426
+ return tuple(self._dims)
1427
+
1428
+ @property
1429
+ def frozen(self) -> bool:
1430
+ """Whether the shape is frozen.
1431
+
1432
+ When the shape is frozen, it cannot be unfrozen, making it suitable to be shared.
1433
+ Call :meth:`freeze` to freeze the shape. Call :meth:`copy` to create a
1434
+ new shape with the same dimensions that can be modified.
1435
+ """
1436
+ return self._frozen
1437
+
1438
+ def freeze(self) -> None:
1439
+ """Freeze the shape.
1440
+
1441
+ When the shape is frozen, it cannot be unfrozen, making it suitable to be shared.
1442
+ """
1443
+ self._frozen = True
1444
+
1445
+ def copy(self, frozen: bool = False):
1446
+ """Return a copy of the shape."""
1447
+ return Shape(self._dims, self._denotations, frozen=frozen)
1448
+
1449
+ def rank(self) -> int:
1450
+ """The rank of the tensor this shape represents."""
1451
+ return len(self._dims)
1452
+
1453
+ def numpy(self) -> tuple[int, ...]:
1454
+ if any(not isinstance(dim, int) for dim in self._dims):
1455
+ raise ValueError(f"Cannot convert the shape {self} to a tuple of ints")
1456
+ return tuple(dim for dim in self._dims) # type: ignore
1457
+
1458
+ def __len__(self) -> int:
1459
+ return len(self._dims)
1460
+
1461
+ def __iter__(self) -> Iterator[int | SymbolicDim]:
1462
+ return iter(self._dims)
1463
+
1464
+ @typing.overload
1465
+ def __getitem__(self, index: int) -> int | SymbolicDim: ...
1466
+
1467
+ @typing.overload
1468
+ def __getitem__(self, index: slice) -> tuple[int | SymbolicDim, ...]: ...
1469
+
1470
+ def __getitem__(self, index):
1471
+ return tuple(self._dims)[index]
1472
+
1473
+ def __setitem__(self, index: int, value: int | SymbolicDim | str | None) -> None:
1474
+ """Set the dimension at the index.
1475
+
1476
+ Args:
1477
+ index: The index of the dimension.
1478
+ value: The value of the dimension.
1479
+
1480
+ Raises:
1481
+ TypeError: If the shape is frozen and cannot be modified.
1482
+ TypeError: If the value is not an int or SymbolicDim.
1483
+ """
1484
+ if self._frozen:
1485
+ raise TypeError("The shape is frozen and cannot be modified.")
1486
+
1487
+ self._dims[index] = _maybe_convert_to_symbolic_dim(value)
1488
+
1489
+ def get_denotation(self, index: int) -> str | None:
1490
+ """Return the denotation of the dimension at the index.
1491
+
1492
+ Args:
1493
+ index: The index of the dimension.
1494
+
1495
+ Returns:
1496
+ The denotation of the dimension.
1497
+ """
1498
+ return self._denotations[index]
1499
+
1500
+ def set_denotation(self, index: int, denotation: str | None) -> None:
1501
+ """Set the denotation of the dimension at the index.
1502
+
1503
+ Args:
1504
+ index: The index of the dimension.
1505
+ denotation: The denotation of the dimension.
1506
+ """
1507
+ self._denotations[index] = denotation
1508
+
1509
+ def __repr__(self) -> str:
1510
+ return f"{self.__class__.__name__}({self._dims!r})"
1511
+
1512
+ def __str__(self) -> str:
1513
+ """Return a string representation of the shape.
1514
+
1515
+ E.g. [n,1,3]
1516
+ """
1517
+ return f"[{','.join([str(dim) for dim in self._dims])}]"
1518
+
1519
+ def __eq__(self, other: object) -> bool:
1520
+ """Return True if the shapes are equal.
1521
+
1522
+ Two shapes are equal if all their dimensions are equal.
1523
+ """
1524
+ if isinstance(other, Shape):
1525
+ return self._dims == other._dims
1526
+ if not isinstance(other, Iterable):
1527
+ return False
1528
+ return self._dims == list(other)
1529
+
1530
+ def __ne__(self, other: object) -> bool:
1531
+ return not self.__eq__(other)
1532
+
1533
+ @typing.overload
1534
+ def is_static(self, dim: int) -> bool: # noqa: D418
1535
+ """Return True if the dimension is static."""
1536
+
1537
+ @typing.overload
1538
+ def is_static(self) -> bool: # noqa: D418
1539
+ """Return True if all dimensions are static."""
1540
+
1541
+ def is_static(self, dim=None) -> bool:
1542
+ """Return True if the dimension is static. If dim is None, return True if all dimensions are static."""
1543
+ if dim is None:
1544
+ return all(isinstance(d, int) for d in self._dims)
1545
+ return isinstance(self[dim], int)
1546
+
1547
+ @typing.overload
1548
+ def is_dynamic(self, dim: int) -> bool: # noqa: D418
1549
+ """Return True if the dimension is dynamic."""
1550
+
1551
+ @typing.overload
1552
+ def is_dynamic(self) -> bool: # noqa: D418
1553
+ """Return True if any dimension is dynamic."""
1554
+
1555
+ def is_dynamic(self, dim=None) -> bool:
1556
+ if dim is None:
1557
+ return not self.is_static()
1558
+ return not self.is_static(dim)
1559
+
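+ # For example (illustrative, assuming a shape with dims (2, "N")):
+ #     shape.is_static()    # False: one dimension is symbolic
+ #     shape.is_static(0)   # True
+ #     shape.is_dynamic(1)  # True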
1560
+ def is_unknown_dim(self, dim: int) -> bool:
1561
+ """Return True if the dimension is unknown (None).
1562
+
1563
+ A dynamic dimension without a symbolic name is considered unknown.
1564
+
1565
+ .. versionadded:: 0.1.10
1566
+
1567
+ Args:
1568
+ dim: The index of the dimension.
1569
+ """
1570
+ dim_obj = self._dims[dim]
1571
+ return isinstance(dim_obj, SymbolicDim) and dim_obj.value is None
1572
+
1573
+ def has_unknown_dim(self) -> bool:
1574
+ """Return True if any dimension is unknown (None).
1575
+
1576
+ You can use :meth:`is_unknown_dim` to check if a specific dimension is unknown.
1577
+
1578
+ .. versionadded:: 0.1.10
1579
+ """
1580
+ # We can use "in" directly because SymbolicDim implements __eq__ with None
1581
+ return None in self._dims
1582
+
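+ # Illustrative distinction (comment only): a named symbolic dim like "N" is
+ # dynamic but *not* unknown; only a SymbolicDim whose value is None is unknown.
+ #     shape = Shape(["N", None])  # assumes str/None dims are converted
+ #     shape.is_unknown_dim(0)     # False
+ #     shape.is_unknown_dim(1)     # True
+ #     shape.has_unknown_dim()     # True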
1583
+
1584
+ def _quoted(string: str) -> str:
1585
+ """Return a quoted string.
1586
+
1587
+ This function is used to quote value/node names in the IR for better readability.
1588
+ """
1589
+ return f'"{string}"'
1590
+
1591
+
1592
+ class Usage(NamedTuple):
1593
+ """A usage of a value in a node.
1594
+
1595
+ Attributes:
1596
+ node: The node that uses the value.
1597
+ idx: The input index of the value in the node.
1598
+ """
1599
+
1600
+ node: Node
1601
+ idx: int
1602
+
1603
+
1604
+ def _short_tensor_str_for_node(x: Value) -> str:
1605
+ if x.const_value is None:
1606
+ return ""
1607
+ if x.const_value.size <= 10:
1608
+ try:
1609
+ data = x.const_value.numpy().tolist()
1610
+ except Exception: # pylint: disable=broad-except
1611
+ return "{...}"
1612
+ return f"{{{data}}}"
1613
+ return "{...}"
1614
+
1615
+
1616
+ def _normalize_domain(domain: str) -> str:
1617
+ """Normalize 'ai.onnx' to ''."""
1618
+ return "" if domain == "ai.onnx" else domain
1619
+
1620
+
1621
+ class Node(_protocols.NodeProtocol, _display.PrettyPrintable):
1622
+ """IR Node.
1623
+
1624
+ .. tip::
1625
+ For a more convenient way (that supports Python objects
1626
+ as attributes) to create a node, use the :func:`onnx_ir.node` constructor.
1627
+
1628
+ If ``graph`` is provided, the node will be added to the graph. Otherwise,
1629
+ the user is responsible for calling ``graph.append(node)`` (or other mutation methods
1630
+ in :class:`Graph`) to add the node to the graph.
1631
+
1632
+ After the node is initialized, it will add itself as a user of its input values.
1633
+
1634
+ The output values of the node are created during node initialization and are immutable.
1635
+ To change the output values, create a new node and, for each use of the old outputs (``output.uses()``),
1636
+ replace the input in the consuming node by calling :meth:`replace_input_with`.
1637
+ You can also use the :func:`~onnx_ir.convenience.replace_all_uses_with` method
1638
+ to replace all uses of the output values.
1639
+
1640
+ .. note::
1641
+ When the ``domain`` is ``"ai.onnx"``, it is normalized to ``""``.
1642
+ """
1643
+
1644
+ __slots__ = (
1645
+ "_attributes",
1646
+ "_domain",
1647
+ "_graph",
1648
+ "_inputs",
1649
+ "_metadata",
1650
+ "_metadata_props",
1651
+ "_name",
1652
+ "_op_type",
1653
+ "_outputs",
1654
+ "_overload",
1655
+ "_version",
1656
+ "doc_string",
1657
+ )
1658
+
1659
+ def __init__(
1660
+ self,
1661
+ domain: str,
1662
+ op_type: str,
1663
+ inputs: Iterable[Value | None],
1664
+ attributes: Iterable[Attr] | Mapping[str, Attr] = (),
1665
+ *,
1666
+ overload: str = "",
1667
+ num_outputs: int | None = None,
1668
+ outputs: Sequence[Value] | None = None,
1669
+ version: int | None = None,
1670
+ graph: Graph | Function | None = None,
1671
+ name: str | None = None,
1672
+ doc_string: str | None = None,
1673
+ metadata_props: dict[str, str] | None = None,
1674
+ ):
1675
+ """Initialize a node and add it as a user of the input values.
1676
+
1677
+ Args:
1678
+ domain: The domain of the operator. For onnx operators, this is an empty string.
1679
+ When it is ``"ai.onnx"``, it is normalized to ``""``.
1680
+ op_type: The name of the operator.
1681
+ inputs: The input values. When an input is ``None``, it is an empty input.
1682
+ attributes: The attributes. RefAttr can be used only when the node is defined in a Function.
1683
+ overload: The overload name when the node is invoking a function.
1684
+ num_outputs: The number of outputs of the node. If not specified, the number is 1.
1685
+ outputs: The output values. If ``None``, the outputs are created during initialization.
1686
+ version: The version of the operator. If ``None``, the version is unspecified and will follow that of the graph.
1687
+ graph: The graph that the node belongs to. If ``None``, the node is not added to any graph.
1688
+ A `Node` must belong to zero or one graph. If a :class:`Function`, the underlying graph
1689
+ of the function is assigned to the node.
1690
+ name: The name of the node. If ``None``, the node is anonymous. The name may be
1691
+ set by a :class:`Graph` if ``graph`` is specified.
1692
+ doc_string: The documentation string.
1693
+ metadata_props: The metadata properties.
1694
+
1695
+ Raises:
1696
+ TypeError: If the attributes are not :class:`Attr`.
1697
+ ValueError: If ``num_outputs``, when not ``None``, is not the same as the length of the outputs.
1698
+ ValueError: If an output value is ``None``, when outputs is specified.
1699
+ ValueError: If an output value has a producer set already, when outputs is specified.
1700
+ """
1701
+ self._name = name
1702
+ self._domain: str = _normalize_domain(domain)
1703
+ self._op_type: str = op_type
1704
+ # NOTE: Make inputs immutable with the assumption that they are not mutated
1705
+ # very often. This way all mutations can be tracked.
1706
+ # If necessary, we can cache the inputs and outputs as tuples.
1707
+ self._inputs: tuple[Value | None, ...] = tuple(inputs)
1708
+ # Values belong to their defining nodes. The values list is immutable
1709
+ self._outputs: tuple[Value, ...] = self._create_outputs(num_outputs, outputs)
1710
+ if isinstance(attributes, Mapping):
1711
+ attributes = tuple(attributes.values())
1712
+ self._attributes: _graph_containers.Attributes = _graph_containers.Attributes(
1713
+ attributes
1714
+ )
1715
+ self._overload: str = overload
1716
+ # TODO(justinchuby): Potentially support a version range
1717
+ self._version: int | None = version
1718
+ self._metadata: _metadata.MetadataStore | None = None
1719
+ self._metadata_props: dict[str, str] | None = metadata_props
1720
+ # _graph is set by graph.append
1721
+ self._graph: Graph | None = None
1722
+ # Add the node to the graph if graph is specified
1723
+ if graph is not None:
1724
+ graph.append(self)
1725
+ self.doc_string = doc_string
1726
+
1727
+ # Add the node as a use of the inputs
1728
+ for i, input_value in enumerate(self._inputs):
1729
+ if input_value is not None:
1730
+ input_value._add_usage(self, i) # pylint: disable=protected-access
1731
+
1732
+ def _create_outputs(
1733
+ self, num_outputs: int | None, outputs: Sequence[Value] | None
1734
+ ) -> tuple[Value, ...]:
1735
+ """Check the parameters and create outputs for the node.
1736
+
1737
+ Args:
1738
+ num_outputs: The number of outputs of the node.
1739
+ outputs: The output values of the node.
1740
+
1741
+ Returns:
1742
+ The output values of the node.
1743
+
1744
+ Raises:
1745
+ ValueError: If `num_outputs`, when not None, is not the same as the length of the outputs.
1746
+ ValueError: If an output value is None.
1747
+ ValueError: If an output value has a producer set already.
1748
+ """
1749
+ # Check num_outputs and outputs are consistent
1750
+ if num_outputs is not None and outputs is not None and num_outputs != len(outputs):
1751
+ raise ValueError(
1752
+ "num_outputs must be the same as len(outputs) when num_outputs is specified."
1753
+ f"num_outputs: {num_outputs}, outputs: {outputs}"
1754
+ )
1755
+ # 1. If outputs is specified (can be empty []), use the outputs
1756
+ if outputs is not None:
1757
+ # Check all output values are valid first
1758
+ for output in outputs:
1759
+ if output is None:
1760
+ raise ValueError(f"Output value cannot be None. All outputs: {outputs}")
1761
+ if output.producer() is not None:
1762
+ raise ValueError(
1763
+ f"Supplied output value cannot have a producer when used for initializing a Node. "
1764
+ f"Output: {output}. All outputs: {outputs}"
1765
+ )
1766
+ result = []
1767
+ for i, output in enumerate(outputs):
1768
+ output._producer = self # pylint: disable=protected-access
1769
+ output._index = i # pylint: disable=protected-access
1770
+ result.append(output)
1771
+ return tuple(result)
1772
+
1773
+ # 2. If num_outputs is specified, create num_outputs outputs
1774
+ if num_outputs is None:
1775
+ # Default to 1 output
1776
+ num_outputs = 1
1777
+ assert num_outputs is not None
1778
+ return tuple(Value(self, index=i) for i in range(num_outputs))
1779
+
1780
+ def __str__(self) -> str:
1781
+ node_type_text = f"{self._domain}::{self._op_type}" + f":{self._overload}" * (
1782
+ self._overload != ""
1783
+ )
1784
+ inputs_text = (
1785
+ "("
1786
+ + ", ".join(
1787
+ [
1788
+ (
1789
+ f"%{_quoted(x.name) if x.name else 'anonymous:' + str(id(x))}{_short_tensor_str_for_node(x)}"
1790
+ if x is not None
1791
+ else "None"
1792
+ )
1793
+ for x in self._inputs
1794
+ ]
1795
+ )
1796
+ + ")"
1797
+ )
1798
+ attributes_text = (
1799
+ (" {" + ", ".join([f"{k}={v}" for k, v in self._attributes.items()]) + "}")
1800
+ if self._attributes
1801
+ else ""
1802
+ )
1803
+ outputs_text = ", ".join(str(x) for x in self._outputs)
1804
+
1805
+ return f"{outputs_text} ⬅️ {node_type_text}{inputs_text}{attributes_text}"
1806
+
1807
+ def __repr__(self) -> str:
1808
+ return (
1809
+ f"{self.__class__.__name__}(name={self._name!r}, domain={self._domain!r}, "
1810
+ f"op_type={self._op_type!r}, inputs={self._inputs!r}, attributes={self._attributes!r}, "
1811
+ f"overload={self._overload!r}, outputs={self._outputs!r}, "
1812
+ f"version={self._version!r}, doc_string={self.doc_string!r})"
1813
+ )
1814
+
1815
+ @property
1816
+ def name(self) -> str | None:
1817
+ """Optional name of the node."""
1818
+ return self._name
1819
+
1820
+ @name.setter
1821
+ def name(self, value: str | None) -> None:
1822
+ self._name = value
1823
+
1824
+ @property
1825
+ def domain(self) -> str:
1826
+ """The domain of the operator. For onnx operators, this is an empty string.
1827
+
1828
+ .. note:
1829
+ When domain is `"ai.onnx"`, it is normalized to `""`.
1830
+ """
1831
+ return self._domain
1832
+
1833
+ @domain.setter
1834
+ def domain(self, value: str) -> None:
1835
+ self._domain = _normalize_domain(value)
1836
+
1837
+ @property
1838
+ def version(self) -> int | None:
1839
+ """Opset version of the operator called.
1840
+
1841
+ If ``None``, the version is unspecified and will follow that of the graph.
1842
+ This property is special to ONNX IR to allow mixed opset usage in a graph
1843
+ for supporting more flexible graph transformations. It does not exist in the ONNX
1844
+ serialization (protobuf) spec.
1845
+ """
1846
+ return self._version
1847
+
1848
+ @version.setter
1849
+ def version(self, value: int | None) -> None:
1850
+ self._version = value
1851
+
1852
+ @property
1853
+ def op_type(self) -> str:
1854
+ """The name of the operator called."""
1855
+ return self._op_type
1856
+
1857
+ @op_type.setter
1858
+ def op_type(self, value: str) -> None:
1859
+ self._op_type = value
1860
+
1861
+ @property
1862
+ def overload(self) -> str:
1863
+ """The overload name when the node is invoking a function."""
1864
+ return self._overload
1865
+
1866
+ @overload.setter
1867
+ def overload(self, value: str) -> None:
1868
+ self._overload = value
1869
+
1870
+ @property
1871
+ def inputs(self) -> Sequence[Value | None]:
1872
+ """The input values of the node.
1873
+
1874
+ The inputs are immutable. To change the inputs, create a new node and
1875
+ replace the inputs of the using nodes of this node's outputs by calling
1876
+ :meth:`replace_input_with` on the using nodes of this node's outputs.
1877
+ """
1878
+ return self._inputs
1879
+
1880
+ @inputs.setter
1881
+ def inputs(self, _: Any) -> None:
1882
+ raise AttributeError(
1883
+ "Node.inputs cannot be assigned to. Please use 'resize_inputs' and "
1884
+ "'replace_input_with' instead."
1885
+ )
1886
+
1887
+ def resize_inputs(self, new_size: int, /) -> None:
1888
+ """Resize the inputs of the node.
1889
+
1890
+ If the new size is greater than the current size, new inputs are added as None.
1891
+ If the new size is less than the current size, the extra inputs are removed.
1892
+
1893
+ After ``inputs`` is resized, you can use :meth:`replace_input_with` to set the new inputs.
1894
+
1895
+ .. versionadded:: 0.1.13
1896
+
1897
+ Args:
1898
+ new_size: The new number of inputs.
1899
+ """
1900
+ current_size = len(self._inputs)
1901
+ if new_size == current_size:
1902
+ return
1903
+ if new_size < current_size:
1904
+ # Remove extra inputs
1905
+ for i in range(new_size, current_size):
1906
+ self.replace_input_with(i, None)
1907
+ self._inputs = self._inputs[:new_size]
1908
+ else:
1909
+ # Add new inputs as None
1910
+ self._inputs = self._inputs + (None,) * (new_size - current_size)
1911
+
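+ # Usage sketch (illustrative): grow the input list by one slot, then wire the
+ # new slot to a hypothetical value ``extra_input``.
+ #     node.resize_inputs(len(node.inputs) + 1)
+ #     node.replace_input_with(len(node.inputs) - 1, extra_input)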
1912
+ def predecessors(self) -> Sequence[Node]:
1913
+ """Return the predecessor nodes of the node, deduplicated, in a deterministic order."""
1914
+ # Use the ordered nature of a dictionary to deduplicate the nodes
1915
+ predecessors: dict[Node, None] = {}
1916
+ for value in self.inputs:
1917
+ if value is not None and (producer := value.producer()) is not None:
1918
+ predecessors[producer] = None
1919
+ return tuple(predecessors)
1920
+
1921
+ def successors(self) -> Sequence[Node]:
1922
+ """Return the successor nodes of the node, deduplicated, in a deterministic order."""
1923
+ # Use the ordered nature of a dictionary to deduplicate the nodes
1924
+ successors: dict[Node, None] = {}
1925
+ for value in self.outputs:
1926
+ assert value is not None, "Bug: Output values are not expected to be None"
1927
+ for usage in value.uses():
1928
+ successors[usage.node] = None
1929
+ return tuple(successors)
1930
+
1931
+ def replace_input_with(self, index: int, value: Value | None) -> None:
1932
+ """Replace an input with a new value."""
1933
+ if index < 0 or index >= len(self.inputs):
1934
+ raise ValueError(f"Index out of range: {index}")
1935
+ old_input = self.inputs[index]
1936
+ self._inputs = tuple(
1937
+ value if i == index else existing for i, existing in enumerate(self.inputs)
1938
+ )
1939
+ if old_input is not None:
1940
+ old_input._remove_usage(self, index) # pylint: disable=protected-access
1941
+ if value is not None:
1942
+ value._add_usage(self, index) # pylint: disable=protected-access
1943
+
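+ # Illustrative effect: after ``node.replace_input_with(0, new_value)``,
+ # ``new_value.uses()`` contains ``Usage(node, 0)`` and the old input no
+ # longer lists ``node`` among its users.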
1944
+ def prepend(self, /, nodes: Node | Iterable[Node]) -> None:
1945
+ """Insert a node before this node in the list of nodes in the graph.
1946
+
1947
+ It is the same as calling ``graph.insert_before(self, nodes)``.
1948
+
1949
+ Example::
1950
+
1951
+ Before: previous_node -> self
1952
+ previous_node' -> node -> next_node'
1953
+ After: previous_node -> node -> self
1954
+ previous_node' -> next_node'
1955
+
1956
+ Args:
1957
+ nodes: A node or a sequence of nodes to put before this node.
1958
+ """
1959
+ if self._graph is None:
1960
+ raise ValueError("The node to prepend to does not belong to any graph.")
1961
+ self._graph.insert_before(self, nodes)
1962
+
1963
+ def append(self, /, nodes: Node | Iterable[Node]) -> None:
1964
+ """Insert a node after this node in the list of nodes in the graph.
1965
+
1966
+ It is the same as calling ``graph.insert_after(self, nodes)``.
1967
+
1968
+ Example::
1969
+
1970
+ Before: previous_node -> self
1971
+ previous_node' -> node -> next_node'
1972
+ After: previous_node -> self -> node
1973
+ previous_node' -> next_node'
1974
+
1975
+ Args:
1976
+ nodes: A node or a sequence of nodes to put after this node.
1977
+ """
1978
+ if self._graph is None:
1979
+ raise ValueError("The node to append to does not belong to any graph.")
1980
+ self._graph.insert_after(self, nodes)
1981
+
1982
+ @property
1983
+ def outputs(self) -> Sequence[Value]:
1984
+ """The output values of the node.
1985
+
1986
+ The outputs are always attached to this node once initialized (immutable),
1987
+ except that the list can be resized to remove or add outputs.
1988
+
1989
+ Use :meth:`resize_outputs` to change the number of outputs of the node.
1990
+ """
1991
+ return self._outputs
1992
+
1993
+ @outputs.setter
1994
+ def outputs(self, _: Sequence[Value]) -> None:
1995
+ raise AttributeError(
1996
+ "Node.outputs cannot be assigned to. Please use 'resize_outputs' or create a new node instead."
1997
+ )
1998
+
1999
+ def resize_outputs(self, new_size: int, /) -> None:
2000
+ """Resize the outputs of the node.
2001
+
2002
+ If the new size is greater than the current size, new output values are created.
2003
+ If the new size is less than the current size, the extra output values are removed.
2004
+ The removed output values must not have any uses.
2005
+
2006
+ .. versionadded:: 0.1.13
2007
+
2008
+ Args:
2009
+ new_size: The new number of outputs.
2010
+
2011
+ Raises:
2012
+ ValueError: If the new size is less than the current size and
2013
+ the removed outputs have uses.
2014
+ """
2015
+ current_size = len(self._outputs)
2016
+ if new_size == current_size:
2017
+ return
2018
+ if new_size < current_size:
2019
+ # Check that the removed outputs have no uses
2020
+ for output in self._outputs[new_size:]:
2021
+ if output.uses():
2022
+ raise ValueError(
2023
+ f"Cannot remove output {output} because it has uses: {output.uses()}"
2024
+ )
2025
+ for output in self._outputs[new_size:]:
2026
+ # Detach the output from this node
2027
+ output._producer = None # pylint: disable=protected-access
2028
+ output._index = -1 # pylint: disable=protected-access
2029
+ self._outputs = self._outputs[:new_size]
2030
+ else:
2031
+ # Create new outputs
2032
+ new_outputs = [Value(self, index=i) for i in range(current_size, new_size)]
2033
+ self._outputs = self._outputs + tuple(new_outputs)
2034
+
2035
+ @property
2036
+ def attributes(self) -> _graph_containers.Attributes:
2037
+ """The attributes of the node as ``dict[str, Attr]`` with additional access methods.
2038
+
2039
+ Use it as a dictionary with keys being the attribute names and values being the
2040
+ :class:`Attr` objects.
2041
+
2042
+ Use ``node.attributes.add(attr)`` to add an attribute to the node.
2043
+ Use ``node.attributes.get_int(name, default)`` to get an integer attribute value.
2044
+ Refer to the :class:`~onnx_ir._graph_containers.Attributes` for more methods.
2045
+ """
2046
+ return self._attributes
2047
+
2048
+ @property
2049
+ def meta(self) -> _metadata.MetadataStore:
2050
+ """The metadata store for intermediate analysis.
2051
+
2052
+ Write to the :attr:`metadata_props` if you would like the metadata to be serialized
2053
+ to the ONNX proto.
2054
+ """
2055
+ if self._metadata is None:
2056
+ self._metadata = _metadata.MetadataStore()
2057
+ return self._metadata
2058
+
2059
+ @property
2060
+ def metadata_props(self) -> dict[str, str]:
2061
+ """The metadata properties of the node.
2062
+
2063
+ The metadata properties are used to store additional information about the node.
2064
+ Unlike ``meta``, this property is serialized to the ONNX proto.
2065
+ """
2066
+ if self._metadata_props is None:
2067
+ self._metadata_props = {}
2068
+ return self._metadata_props
2069
+
2070
+ @property
2071
+ def graph(self) -> Graph | None:
2072
+ """The graph that the node belongs to.
2073
+
2074
+ If the node is not added to any graph, this property is None.
2075
+ """
2076
+ return self._graph
2077
+
2078
+ @graph.setter
2079
+ def graph(self, value: Graph | None) -> None:
2080
+ self._graph = value
2081
+
2082
+ def op_identifier(self) -> _protocols.OperatorIdentifier:
2083
+ """Return the operator identifier of the node.
2084
+
2085
+ The operator identifier is a tuple of the domain, op_type and overload.
2086
+ """
2087
+ return self._domain, self._op_type, self._overload
2088
+
2089
+ def display(self, *, page: bool = False) -> None:
2090
+ """Pretty print the node.
2091
+
2092
+ This method is used for debugging and visualization purposes.
2093
+ """
2094
+ # Add the node's name to the displayed text
2095
+ print(f"Node: {self.name!r}")
2096
+ if self.doc_string:
2097
+ print(f"Doc: {self.doc_string}")
2098
+ super().display(page=page)
2099
+
2100
+
2101
+ class _TensorTypeBase(_protocols.TypeProtocol, _display.PrettyPrintable, Hashable):
2102
+ """Tensor types that are non recursive types."""
2103
+
2104
+ __slots__ = ("_dtype", "denotation")
2105
+
2106
+ def __init__(self, dtype: _enums.DataType, *, denotation: str | None = None) -> None:
2107
+ self._dtype = dtype
2108
+ self.denotation = denotation
2109
+
2110
+ @property
2111
+ def dtype(self) -> _enums.DataType:
2112
+ return self._dtype
2113
+
2114
+ @dtype.setter
2115
+ def dtype(self, value: _enums.DataType) -> None:
2116
+ self._dtype = value
2117
+
2118
+ @property
2119
+ def elem_type(self) -> _enums.DataType:
2120
+ """Return the element type of the tensor type."""
2121
+ return self.dtype
2122
+
2123
+ def __hash__(self) -> int:
2124
+ return hash(repr(self))
2125
+
2126
+ def __eq__(self, other: object) -> bool:
2127
+ if self.__class__ is not other.__class__:
2128
+ return False
2129
+ return self.dtype == other.dtype # type: ignore[attr-defined]
2130
+
2131
+ def __repr__(self) -> str:
2132
+ # Remove "Type" from name for display
2133
+ short_name = self.__class__.__name__[:-4]
2134
+ return f"{short_name}({self.dtype!r})"
2135
+
2136
+
2137
+ class TensorType(_TensorTypeBase):
2138
+ """A type that represents a tensor."""
2139
+
2140
+ def __str__(self) -> str:
2141
+ return f"{self.dtype}"
2142
+
2143
+
2144
+ class SparseTensorType(_TensorTypeBase):
2145
+ """A type that represents a sparse tensor."""
2146
+
2147
+
2148
+ class _RecursiveTypeBase(_protocols.TypeProtocol, _display.PrettyPrintable, Hashable):
2149
+ """Base for recursive types like Optional and Sequence."""
2150
+
2151
+ __slots__ = ("_elem_type", "denotation")
2152
+
2153
+ def __init__(
2154
+ self, elem_type: _protocols.TypeProtocol, *, denotation: str | None = None
2155
+ ) -> None:
2156
+ self._elem_type = elem_type
2157
+ self.denotation = denotation
2158
+
2159
+ @property
2160
+ def dtype(self) -> _enums.DataType:
2161
+ return self._elem_type.dtype
2162
+
2163
+ @dtype.setter
2164
+ def dtype(self, value: _enums.DataType) -> None:
2165
+ self._elem_type.dtype = value
2166
+
2167
+ @property
2168
+ def elem_type(self) -> _protocols.TypeProtocol:
2169
+ return self._elem_type
2170
+
2171
+ def __hash__(self) -> int:
2172
+ return hash(repr(self))
2173
+
2174
+ def __eq__(self, other: object) -> bool:
2175
+ if not isinstance(other, _RecursiveTypeBase):
2176
+ return False
2177
+ if self.__class__ != other.__class__:
2178
+ return False
2179
+ # Recursively compare the type of the elements
2180
+ return self.elem_type == other.elem_type
2181
+
2182
+ def __repr__(self) -> str:
2183
+ # Remove "Type" from name for display
2184
+ short_name = self.__class__.__name__[:-4]
2185
+ return f"{short_name}({self.elem_type!r})"
2186
+
2187
+
2188
+ class SequenceType(_RecursiveTypeBase):
2189
+ """A type that represents a sequence of elements."""
2190
+
2191
+
2192
+ class OptionalType(_RecursiveTypeBase):
2193
+ """A type that represents an optional element."""
2194
+
2195
+
2196
+ class _OpHandlerProtocol(Protocol):
2197
+ """Protocol for an object that can handle magic methods on Values.
2198
+
2199
+ .. note::
2200
+ Only the basic arithmetic magic methods are supported on Values.
2201
+
2202
+ Importantly, ``__eq__`` is not included because Values may need to be compared for identity.
2203
+ For consistency, none of the other comparison operators are included.
2204
+ """
2205
+
2206
+ def Add(self, lhs, rhs) -> Value: ... # noqa: N802
2207
+ def Sub(self, lhs, rhs) -> Value: ... # noqa: N802
2208
+ def Mul(self, lhs, rhs) -> Value: ... # noqa: N802
2209
+ def Div(self, lhs, rhs) -> Value: ... # noqa: N802
2210
+ def Neg(self, operand) -> Value: ... # noqa: N802
2211
+
2212
+
2213
+ def set_value_magic_handler(handler: _OpHandlerProtocol | None) -> _OpHandlerProtocol | None:
2214
+ """Set the magic handler for Value arithmetic methods.
2215
+
2216
+ Framework authors can implement custom context managers that set
2217
+ the magic handler to enable arithmetic operations on Values.
2218
+
2219
+ Example::
2220
+ class MyOpHandler:
2221
+ def Add(self, lhs, rhs):
2222
+ # Implement addition logic here
2223
+ pass
2224
+ ...
2225
+
2226
+ @contextlib.contextmanager
2227
+ def graph_context(graph):
2228
+ old_handler = onnx_ir.set_value_magic_handler(MyOpHandler(graph))
2229
+ try:
2230
+ yield
2231
+ finally:
2232
+ onnx_ir.set_value_magic_handler(old_handler)
2233
+
2234
+ Args:
2235
+ handler: The magic handler to set.
2236
+
2237
+ Returns:
2238
+ The previous magic handler.
2239
+ """
2240
+ old_handler = WithArithmeticMethods._magic_handler
2241
+ WithArithmeticMethods._magic_handler = handler
2242
+ return old_handler
2243
+
2244
+
2245
+ class WithArithmeticMethods:
2246
+ """Mixin class that adds arithmetic methods to Value.
2247
+
2248
+ This class is used to add arithmetic methods to Value that support arithmetic operations.
2249
+ """
2250
+
2251
+ _magic_handler: ClassVar[_OpHandlerProtocol | None] = None
2252
+
2253
+ def _get_magic_handler(self):
2254
+ if self._magic_handler is None:
2255
+ raise ValueError(
2256
+ "No magic handler is set. Please use 'onnx_ir.set_value_magic_handler' to set a handler."
2257
+ )
2258
+ return self._magic_handler
2259
+
2260
+ # Magic methods for arithmetic operations
2261
+ def __add__(self, other, /):
2262
+ return self._get_magic_handler().Add(self, other) # type: ignore[union-attr]
2263
+
2264
+ def __sub__(self, other, /):
2265
+ return self._get_magic_handler().Sub(self, other) # type: ignore[union-attr]
2266
+
2267
+ def __mul__(self, other, /):
2268
+ return self._get_magic_handler().Mul(self, other) # type: ignore[union-attr]
2269
+
2270
+ def __truediv__(self, other, /):
2271
+ return self._get_magic_handler().Div(self, other) # type: ignore[union-attr]
2272
+
2273
+ def __neg__(self):
2274
+ return self._get_magic_handler().Neg(self) # type: ignore[union-attr]
2275
+
2276
+ def __radd__(self, other, /):
2277
+ return self._get_magic_handler().Add(other, self) # type: ignore[union-attr]
2278
+
2279
+ def __rsub__(self, other, /):
2280
+ return self._get_magic_handler().Sub(other, self) # type: ignore[union-attr]
2281
+
2282
+ def __rmul__(self, other, /):
2283
+ return self._get_magic_handler().Mul(other, self) # type: ignore[union-attr]
2284
+
2285
+ def __rtruediv__(self, other, /):
2286
+ return self._get_magic_handler().Div(other, self) # type: ignore[union-attr]
2287
+
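+ # With a handler installed (see set_value_magic_handler above), ordinary
+ # arithmetic on Values dispatches to it. Illustrative, using the hypothetical
+ # ``graph_context`` from the docstring example:
+ #     with graph_context(graph):
+ #         c = a + b    # calls handler.Add(a, b) and returns a new Value
+ #         d = 1.0 - c  # reflected form; calls handler.Sub(1.0, c)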
2288
+
2289
+ class Value(WithArithmeticMethods, _protocols.ValueProtocol, _display.PrettyPrintable):
2290
+ """IR Value.
2291
+
2292
+ A value is a named entity that can be used to represent an input or output of a graph,
2293
+ a function, or a node. The information it stores generalizes over ``ValueInfoProto``
2294
+ in the ONNX specification.
2295
+
2296
+ A :class:`Value` is either not owned by any node or owned by exactly one node. When the value is not
2297
+ owned, it must be an input of a graph or a function. ``producer`` and ``index``
2298
+ are ``None``.
2299
+
2300
+ When the value is owned by a node, it is an output of the node.
2301
+ The node that produces the value can be accessed with :meth:`producer`.
2302
+ The index of the output of the node that produces the value can be accessed with
2303
+ :meth:`index`.
2304
+
2305
+ To find all the nodes that use this value as an input, call :meth:`uses`. Consuming
2306
+ nodes can be obtained with :meth:`consumers`.
2307
+
2308
+ To check if the value is an input, output or initializer of a graph,
2309
+ use :meth:`is_graph_input`, :meth:`is_graph_output` or :meth:`is_initializer`.
2310
+
2311
+ Use :attr:`graph` to get the graph that owns the value.
2312
+
2313
+ .. note:: Magic methods
2314
+ Only the basic arithmetic magic methods are supported on Values.
2315
+
2316
+ Importantly, ``__eq__`` is not included because Values may need to be compared for identity.
2317
+ For consistency, none of the other comparison operators are included.
2318
+
2319
+ .. versionadded:: 0.1.14
2320
+ Value now supports arithmetic magic methods when a handler is set via
2321
+ :func:`onnx_ir.set_value_magic_handler`.
2322
+ """
2323
+
2324
+ __slots__ = (
2325
+ "_const_value",
2326
+ "_graph",
2327
+ "_index",
2328
+ "_is_graph_input",
2329
+ "_is_graph_output",
2330
+ "_is_initializer",
2331
+ "_metadata",
2332
+ "_metadata_props",
2333
+ "_name",
2334
+ "_producer",
2335
+ "_shape",
2336
+ "_type",
2337
+ "_uses",
2338
+ "doc_string",
2339
+ )
2340
+
2341
+ def __init__(
2342
+ self,
2343
+ producer: Node | None = None,
2344
+ *,
2345
+ index: int | None = None,
2346
+ name: str | None = None,
2347
+ shape: Shape | None = None,
2348
+ type: _protocols.TypeProtocol | None = None,
2349
+ doc_string: str | None = None,
2350
+ const_value: _protocols.TensorProtocol | None = None,
2351
+ metadata_props: dict[str, str] | None = None,
2352
+ ) -> None:
2353
+ """Initialize a value.
2354
+
2355
+ When assigning a name to the value, the name of the backing `const_value` (Tensor)
2356
+ will also be updated. If the value is an initializer of a graph, the initializers
2357
+ dictionary of the graph will also be updated.
2358
+
2359
+ .. versionchanged:: 0.1.10
2360
+ Assigning a name to the value will also update the graph initializer entry
2361
+ if the value is an initializer of a graph.
2362
+
2363
+ Args:
2364
+ producer: The node that produces the value.
2365
+ It can be ``None`` when the value is created before its producer.
2366
+ index: The index of the output of the defining node.
2367
+ name: The name of the value.
2368
+ shape: The shape of the value.
2369
+ type: The type of the value.
2370
+ doc_string: The documentation string.
2371
+ const_value: The constant tensor if the value is constant.
2372
+ metadata_props: Metadata that will be serialized to the ONNX file.
2373
+ """
2374
+ self._producer: Node | None = producer
2375
+ self._index: int | None = index
2376
+ self._metadata: _metadata.MetadataStore | None = None
2377
+ self._metadata_props: dict[str, str] | None = metadata_props
2378
+
2379
+ self._name: str | None = name
2380
+ self._shape: Shape | None = shape
2381
+ self._type: _protocols.TypeProtocol | None = type
2382
+ # TODO(justinchuby): Handle initialization when a const value is provided
2383
+ # We can get shape and type information from the const value
2384
+ self._const_value = const_value
2385
+ # Use a collection of (Node, int) to store uses. This is needed
2386
+ # because a single use can use the same value multiple times.
2387
+ # Use a dictionary to preserve insertion order so that the visiting order is deterministic
2388
+ self._uses: dict[Usage, None] = {}
2389
+ self.doc_string = doc_string
2390
+
2391
+ # The graph this value belongs to. It is set *only* when the value is added as
2392
+ # a graph input, output or initializer.
2393
+ # The four properties can only be set by the Graph class (_GraphIO and GraphInitializers).
2394
+ self._graph: Graph | None = None
2395
+ self._is_graph_input: bool = False
2396
+ self._is_graph_output: bool = False
2397
+ self._is_initializer: bool = False
2398
+
2399
+ def __repr__(self) -> str:
2400
+ value_name = self.name if self.name else "anonymous:" + str(id(self))
2401
+ type_text = f", type={self.type!r}" if self.type is not None else ""
2402
+ shape_text = f", shape={self.shape!r}" if self.shape is not None else ""
2403
+ producer = self.producer()
2404
+ if producer is None:
2405
+ producer_text = ""
2406
+ elif producer.name is not None:
2407
+ producer_text = f", producer='{producer.name}'"
2408
+ else:
2409
+ producer_text = f", producer=anonymous_node:{id(producer)}"
2410
+ index_text = f", index={self.index()}" if self.index() is not None else ""
2411
+ const_value_text = self._constant_tensor_part()
2412
+ if const_value_text:
2413
+ const_value_text = f", const_value={const_value_text}"
2414
+ return f"{self.__class__.__name__}(name={value_name!r}{type_text}{shape_text}{producer_text}{index_text}{const_value_text})"
2415
+
2416
+ def __str__(self) -> str:
2417
+ value_name = self.name if self.name is not None else "anonymous:" + str(id(self))
2418
+ shape_text = str(self.shape) if self.shape is not None else "?"
2419
+ type_text = str(self.type) if self.type is not None else "?"
2420
+
2421
+ # Quote the name because in reality the names can have invalid characters
2422
+ # that make them hard to read
2423
+ return (
2424
+ f"%{_quoted(value_name)}<{type_text},{shape_text}>{self._constant_tensor_part()}"
2425
+ )
2426
+
2427
+ def _constant_tensor_part(self) -> str:
2428
+ """Display string for the constant tensor attached to str of Value."""
2429
+ if self.const_value is not None:
2430
+ # Only display when the const value is small
2431
+ if self.const_value.size <= 10:
2432
+ return f"{{{self.const_value}}}"
2433
+ else:
2434
+ return f"{{{self.const_value.__class__.__name__}(...)}}"
2435
+ return ""
2436
+
2437
+ @property
2438
+ def graph(self) -> Graph | None:
2439
+ """Return the graph that defines this value.
2440
+
2441
+ When the value is an input/output/initializer of a graph, the owning graph
2442
+ is that graph. When the value is an output of a node, the owning graph is the
2443
+ graph that the node belongs to. When the value is not owned by any graph,
2444
+ it returns ``None``.
2445
+ """
2446
+ if self._graph is not None:
2447
+ return self._graph
2448
+ if self._producer is not None:
2449
+ return self._producer.graph
2450
+ return None
2451
+
2452
+ def _owned_by_graph(self) -> bool:
2453
+ """Return True if the value is owned by a graph."""
2454
+ result = self._is_graph_input or self._is_graph_output or self._is_initializer
2455
+ if result:
2456
+ assert self._graph is not None
2457
+ return result
2458
+
2459
+ def producer(self) -> Node | None:
2460
+ """The node that produces this value.
2461
+
2462
+ When producer is ``None``, the value does not belong to a node, and is
2463
+ typically a graph input or an initializer. You can use :meth:`graph``
2464
+ to find the graph that owns this value. Use :meth:`is_graph_input`, :meth:`is_graph_output`
2465
+ or :meth:`is_initializer` to check if the value is an input, output or initializer of a graph.
2466
+ """
2467
+ return self._producer
2468
+
2469
+ def consumers(self) -> Sequence[Node]:
2470
+ """Return the nodes (deduplicated) that consume this value."""
2471
+ return tuple({usage.node: None for usage in self._uses})
2472
+
2473
+ def index(self) -> int | None:
2474
+ """The index of the output of the defining node."""
2475
+ return self._index
2476
+
2477
+ def uses(self) -> Collection[Usage]:
2478
+ """Return a set of uses of the value.
2479
+
2480
+ The set contains tuples of ``(Node, index)`` where the index is the index of the input
2481
+ of the node. For example, if ``node.inputs[1] == value``, then the use is ``(node, 1)``.
2482
+ """
2483
+ # Create a tuple for the collection so that iteration on it will not
2484
+ # be affected when the usage changes during graph mutation.
2485
+ # This adds a small overhead but is a better user experience than
2486
+ # having users call tuple().
2487
+ return tuple(self._uses)
2488
+
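+ # Illustrative traversal (comment only): each use is a (node, input index)
+ # pair, so a value's consumers can be rewired like so (this is exactly what
+ # :meth:`replace_all_uses_with` below does):
+ #     for user_node, index in value.uses():
+ #         user_node.replace_input_with(index, replacement)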
2489
+ def _add_usage(self, use: Node, index: int) -> None:
2490
+ """Add a usage of this value.
2491
+
2492
+ This is an internal method. It should only be called by the Node class.
2493
+ """
2494
+ self._uses[Usage(use, index)] = None
2495
+
2496
+ def _remove_usage(self, use: Node, index: int) -> None:
2497
+ """Remove a node from the uses of this value.
2498
+
2499
+ This is an internal method. It should only be called by the Node class.
2500
+ """
2501
+ self._uses.pop(Usage(use, index))
2502
+
2503
+ @property
2504
+ def name(self) -> str | None:
2505
+ return self._name
2506
+
2507
+ @name.setter
2508
+ def name(self, value: str | None) -> None:
2509
+ if self._name == value:
2510
+ return
2511
+
2512
+ # First check if renaming is valid. Do not change anything if it is invalid
2513
+ # to prevent the value from being in an inconsistent state.
2514
+ is_initializer = self.is_initializer()
2515
+ if is_initializer:
2516
+ if value is None:
2517
+ raise ValueError(
2518
+ "Initializer value cannot have name set to None. Please pop() the value from initializers first to do so."
2519
+ )
2520
+ graph = self._graph
2521
+ assert graph is not None
2522
+ if value in graph.initializers and graph.initializers[value] is not self:
2523
+ raise ValueError(
2524
+ f"Cannot rename initializer '{self}' to '{value}': an initializer with that name already exists."
2525
+ )
2526
+
2527
+ # Rename the backing constant tensor
2528
+ if self._const_value is not None:
2529
+ self._const_value.name = value
2530
+
2531
+ # Rename self
2532
+ old_name = self._name
2533
+ self._name = value
2534
+
2535
+ if is_initializer:
2536
+ # Rename the initializer entry in the graph
2537
+ assert value is not None, "debug: Should be guarded above"
2538
+ graph = self._graph
2539
+ assert graph is not None
2540
+ assert old_name is not None
2541
+ graph.initializers.pop(old_name)
2542
+ graph.initializers[value] = self
2543
+
2544
+ @property
2545
+ def type(self) -> _protocols.TypeProtocol | None:
2546
+ """The type of the tensor.
2547
+
2548
+ Example types can be ``TensorType``, ``SparseTensorType``, ``SequenceType``, ``OptionalType``.
2549
+ To obtain the data type of the tensor, use ``type.dtype`` or conveniently
2550
+ :attr:`dtype`.
2551
+ """
2552
+ return self._type
2553
+
2554
+ @type.setter
2555
+ def type(self, value: _protocols.TypeProtocol | None) -> None:
2556
+ self._type = value
2557
+
2558
+ @property
2559
+ def dtype(self) -> _enums.DataType | None:
2560
+ """The data type of the tensor."""
2561
+ if self._type is None:
2562
+ return None
2563
+ return self._type.dtype
2564
+
2565
+ @dtype.setter
2566
+ def dtype(self, value: _enums.DataType) -> None:
2567
+ """Set the data type of the tensor.
2568
+
2569
+ If the type is not set, it will be initialized to a new TensorType. To
2570
+ set the type as other types like ``SequenceType``, initialize the type
2571
+ then set :attr:`type` instead.
2572
+ """
2573
+ if self._type is None:
2574
+ self._type = TensorType(value)
2575
+ else:
2576
+ self._type.dtype = value
2577
+
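+ # Illustrative: setting ``dtype`` on a Value that has no type yet wraps the
+ # dtype in a new TensorType.
+ #     v = Value(name="x")
+ #     v.dtype = _enums.DataType.FLOAT  # v.type is now TensorType(FLOAT)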
2578
+ @property
2579
+ def shape(self) -> Shape | None:
2580
+ return self._shape
2581
+
2582
+ @shape.setter
2583
+ def shape(self, value: Shape | None) -> None:
2584
+ if value is None:
2585
+ self._shape = None
2586
+ return
2587
+ if isinstance(value, Shape):
2588
+ self._shape = value
2589
+ return
2590
+ raise TypeError(f"Expected value to be a Shape or None, got '{type(value)}'")
2591
+
2592
+ @property
2593
+ def const_value(
2594
+ self,
2595
+ ) -> _protocols.TensorProtocol | None:
2596
+ """The backing constant tensor for the value.
2597
+
2598
+ If the ``Value`` has a ``const_value`` and is part of a graph initializers
2599
+ dictionary, the value is an initialized value. Its ``const_value``
2600
+ will appear as an ``initializer`` in the GraphProto when serialized.
2601
+
2602
+ If the ``Value`` is not part of a graph initializers dictionary, the ``const_value``
2603
+ field will be ignored during serialization.
2604
+
2605
+ ``const_value`` can be backed by different raw data types, such as numpy arrays.
2606
+ The only guarantee is that it conforms to the TensorProtocol.
2607
+ """
2608
+ return self._const_value
2609
+
2610
+ @const_value.setter
2611
+ def const_value(
2612
+ self,
2613
+ value: _protocols.TensorProtocol | None,
2614
+ ) -> None:
2615
+ if onnx_ir.DEBUG:
2616
+ if value is not None and not isinstance(value, _protocols.TensorProtocol):
2617
+ raise TypeError(
2618
+ f"Expected value to be a TensorProtocol or None, got '{type(value)}'"
2619
+ )
2620
+ self._const_value = value
2621
+
2622
+ @property
2623
+ def meta(self) -> _metadata.MetadataStore:
2624
+ """The metadata store for intermediate analysis.
2625
+
2626
+ Write to the :attr:`metadata_props` if you would like the metadata to be serialized
2627
+ to the ONNX proto.
2628
+ """
2629
+ if self._metadata is None:
2630
+ self._metadata = _metadata.MetadataStore()
2631
+ return self._metadata
2632
+
2633
+ @property
2634
+ def metadata_props(self) -> dict[str, str]:
2635
+ """The metadata properties of the value.
2636
+
2637
+ The metadata properties are used to store additional information about the value.
2638
+ Unlike ``meta``, this property is serialized to the ONNX proto.
2639
+ """
2640
+ if self._metadata_props is None:
2641
+ self._metadata_props = {}
2642
+ return self._metadata_props
2643
+
2644
+ def is_graph_input(self) -> bool:
2645
+ """Whether the value is an input of a graph."""
2646
+ return self._is_graph_input
2647
+
2648
+ def is_graph_output(self) -> bool:
2649
+ """Whether the value is an output of a graph."""
2650
+ return self._is_graph_output
2651
+
2652
+ def is_initializer(self) -> bool:
2653
+ """Whether the value is an initializer of a graph."""
2654
+ return self._is_initializer
2655
+
2656
+ def replace_all_uses_with(
2657
+ self, replacement: Value, /, replace_graph_outputs: bool = False
2658
+ ) -> None:
2659
+ """Replace all uses of this value with another value.
2660
+
2661
+ .. tip::
2662
+ **Handling graph outputs**
2663
+
2664
+ To also replace graph outputs that reference the values being replaced, either
2665
+ set ``replace_graph_outputs`` to True, or manually update the graph outputs
2666
+ before calling this function to avoid an error being raised when ``replace_graph_outputs=False``.
2667
+
2668
+ Be careful when a value appears multiple times in the graph outputs -
2669
+ this is invalid. An identity node will need to be added for each duplicated
2670
+ output to ensure a valid ONNX graph.
2671
+
2672
+ You may also want to assign the name of this value to the replacement value
2673
+ to maintain the name when it is a graph output.
2674
+
2675
+ To replace usage of a sequence of values with another sequence of values, consider using
2676
+ :func:`onnx_ir.convenience.replace_all_uses_with`.
2677
+
2678
+ .. versionadded:: 0.1.12
2679
+
2680
+ Args:
2681
+ replacement: The value to replace all uses with.
2682
+ replace_graph_outputs: If True, graph outputs that reference this value
2683
+ will also be updated to reference the replacement.
2684
+
2685
+ Raises:
2686
+ ValueError: When ``replace_graph_outputs`` is False and the value to
2687
+ replace is a graph output.
2688
+ """
2689
+ # NOTE: Why we don't replace the value name when the value is an output:
2690
+ # When the replacement value is already an output of the graph, renaming it
2691
+ # to the name of this value will cause name conflicts. It is better to let
2692
+ # the user handle the renaming explicitly and insert identity nodes if needed.
2693
+ if self.is_graph_output():
2694
+ graph = self.graph
2695
+ assert graph is not None
2696
+
2697
+ if not replace_graph_outputs:
2698
+ raise ValueError(
2699
+ f"{self!r} is an output of graph {graph.name!r}. "
2700
+ "Set replace_graph_outputs=True or replace the graph output frist before "
2701
+ "calling replace_all_uses_with."
2702
+ )
2703
+
2704
+ for i, output in enumerate(graph.outputs):
2705
+ if output is self:
2706
+ graph.outputs[i] = replacement
2707
+
2708
+ for user_node, index in self.uses():
2709
+ user_node.replace_input_with(index, replacement)
2710
+
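+ # Minimal sketch (illustrative): reroute every consumer of ``old`` to ``new``,
+ # updating graph outputs as well; optionally carry over the output name as the
+ # docstring suggests.
+ #     old.replace_all_uses_with(new, replace_graph_outputs=True)
+ #     new.name = old.name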
2711
+ def merge_shapes(self, other: Shape | None, /) -> None:
2712
+ """Merge the shape of this value with another shape to update the existing shape, with the current shape's dimensions taking precedence.
2713
+
2714
+ Two dimensions are merged as follows:
2715
+
2716
+ * If both dimensions are equal, the merged dimension is the same.
2717
+ * If one dimension is SymbolicDim and the other is concrete, the merged dimension is the concrete one.
2718
+ * If both dimensions are SymbolicDim, a named symbolic dimension (non-None value) is preferred over an unnamed one (None value).
2719
+ * If both dimensions are concrete integers that differ, a ``ValueError`` is raised. If both are distinct named symbolic dimensions, the current shape's dimension takes precedence.
2720
+
2721
+ .. versionadded:: 0.1.14
2722
+
2723
+ Args:
2724
+ other: The other shape to merge with.
2725
+
2726
+ Returns:
2727
+ None. This value's ``shape`` is updated in place with the merged shape.
2728
+
2729
+ Raises:
2730
+ ValueError: If the shapes have different ranks.
2731
+ ValueError: If there are conflicting concrete dimensions.
2732
+ """
2733
+ if other is None:
2734
+ return
2735
+
2736
+ merged_shape = self.shape
2737
+ if merged_shape is None:
2738
+ self._shape = other.copy()
2739
+ return
2740
+
2741
+ if merged_shape.frozen:
2742
+ merged_shape = merged_shape.copy()
2743
+
2744
+ if len(merged_shape) != len(other):
2745
+ raise ValueError(f"Shapes must have the same rank, got self={self}, other={other}")
2746
+
2747
+ def merge_dims(dim1, dim2):
2748
+ if dim1 == dim2:
2749
+ return dim1
2750
+ if isinstance(dim1, int) and isinstance(dim2, int):
2751
+ raise ValueError( # noqa: TRY004
2752
+ f"Conflicting dimensions {dim1} and {dim2} when merging shapes "
2753
+ f"{self} and {other}."
2754
+ )
2755
+ if not isinstance(dim1, SymbolicDim):
2756
+ return dim1 # Prefer int value over symbolic dim
2757
+ if not isinstance(dim2, SymbolicDim):
2758
+ return dim2
2759
+ if dim1.value is None:
2760
+ return dim2
2761
+ return dim1
2762
+
2763
+ for i, (dim1, dim2) in enumerate(zip(merged_shape, other)):
2764
+ merged_shape[i] = merge_dims(dim1, dim2)
2765
+
2766
+ self._shape = merged_shape
2767
+
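+ # Worked example (illustrative; assumes Shape converts str/None dims): merging
+ # [N,None] with [2,M] yields [2,M] - the concrete 2 wins over the symbolic N,
+ # and the named M wins over the unknown (None) dimension.
+ #     value.shape = Shape(["N", None])
+ #     value.merge_shapes(Shape([2, "M"]))
+ #     # value.shape is now [2,M]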
2768
+
2769
+ @deprecated("Input is deprecated since 0.1.9. Use ir.val(...) instead.")
2770
+ def Input( # noqa: N802
2771
+ name: str | None = None,
2772
+ shape: Shape | None = None,
2773
+ type: _protocols.TypeProtocol | None = None,
2774
+ doc_string: str | None = None,
2775
+ ) -> Value:
2776
+ """Create an input of a Graph or a Function.
2777
+
2778
+ This is equivalent to calling ``Value(name=name, shape=shape, type=type, doc_string=doc_string)``.
2779
+ """
2780
+ # NOTE: The function name is capitalized to maintain API backward compatibility.
2781
+
2782
+ return Value(name=name, shape=shape, type=type, doc_string=doc_string)
2783
+
2784
+
2785
+ def _check_node_safe_to_remove(
2786
+ node: Node, to_remove: AbstractSet[Node], graph_outputs: AbstractSet[Value]
2787
+ ) -> None:
2788
+ """Check if a node is safe to remove.
2789
+
2790
+ 1. It checks to make sure there are no users of the node that are not
2791
+ to be removed before removing it.
2792
+ 2. It checks the node does not contribute to any graph outputs.
2793
+
2794
+ This check is typically O(1), assuming the number of uses of the node is small.
2795
+
2796
+ Args:
2797
+ node: The node to check.
2798
+ to_remove: A set of nodes that are to be removed.
2799
+ This set is used to check if the node is still being used by other
2800
+ nodes that are not to be removed.
2801
+ graph_outputs: A set of values that are outputs of the graph.
2802
+
2803
+ Raises:
2804
+ ValueError: If any output of the node is an output of the graph.
2805
+ ValueError: If the node is still being used by other nodes not to be removed.
2806
+ """
2807
+ for output in node.outputs:
2808
+ if output in graph_outputs:
2809
+ raise ValueError(
2810
+ f"Node '{node!r}' is still an output of the graph and cannot be removed when safe=True."
2811
+ )
2812
+ uses_not_to_remove = [user for user, _ in output.uses() if user not in to_remove]
2813
+ if uses_not_to_remove:
2814
+ raise ValueError(
2815
+ f"Output value '{output!r}' is still being used by other nodes that are not to be "
2816
+ f"removed. All of its users that is not being removed: {uses_not_to_remove!r}. "
2817
+ "Please make sure these nodes are no longer using the output value."
2818
+ )
2819
+
2820
+
2821
+ class Graph(_protocols.GraphProtocol, Sequence[Node], _display.PrettyPrintable):
2822
+ """IR Graph.
2823
+
2824
+ Graph represents a computation graph. In addition to the ONNX specification
2825
+ specified fields, it also contains a mapping of :attr:`opset_imports`. This
2826
+ allows different subgraphs to import different opsets. It is the responsibility
2827
+ of the deserializer to reconcile the different opsets.
2828
+
2829
+ The `nodes` are not guaranteed to be topologically sorted. But the
2830
+ iteration order should be deterministic across different runs. It is the
2831
+ responsibility of the user to maintain a topological order of the nodes.
2832
+
2833
+ Note that there is not a ``node`` attribute in the Graph. The Graph can be
2834
+ seen as a Sequence of nodes and should be used as such. For example, to obtain
2835
+ all nodes as a list, call ``list(graph)``.
2836
+
2837
+ .. versionchanged:: 0.1.1
2838
+ Values with non-None producers will be rejected as graph inputs or initializers.
2839
+
2840
+ .. versionadded:: 0.1.1
2841
+ Added ``add`` method to initializers and attributes.
2842
+
2843
+ Attributes:
2844
+ name: The name of the graph.
2845
+ inputs: The input values of the graph.
2846
+ outputs: The output values of the graph.
2847
+ initializers: The initializers in the graph.
2848
+ doc_string: Documentation string.
2849
+ opset_imports: Opsets imported by the graph.
2850
+ metadata_props: Metadata that will be serialized to the ONNX file.
2851
+ meta: Metadata store for graph transform passes.
2852
+ """
2853
+
2854
+ __slots__ = (
2855
+ "_doc_string",
2856
+ "_initializers",
2857
+ "_inputs",
2858
+ "_metadata",
2859
+ "_metadata_props",
2860
+ "_name_authority",
2861
+ "_nodes",
2862
+ "_opset_imports",
2863
+ "_outputs",
2864
+ "name",
2865
+ )
2866
+
2867
+ def __init__(
2868
+ self,
2869
+ inputs: Sequence[Value],
2870
+ outputs: Sequence[Value],
2871
+ *,
2872
+ nodes: Iterable[Node],
2873
+ initializers: Sequence[Value] = (),
2874
+ doc_string: str | None = None,
2875
+ opset_imports: dict[str, int] | None = None,
2876
+ name: str | None = None,
2877
+ metadata_props: dict[str, str] | None = None,
2878
+ ):
2879
+ self.name = name
2880
+
2881
+ # Private fields that are not to be accessed by any other classes
2882
+ self._inputs = _graph_containers.GraphInputs(self, inputs)
2883
+ self._outputs = _graph_containers.GraphOutputs(self, outputs)
2884
+ self._initializers = _graph_containers.GraphInitializers(
2885
+ self, {initializer.name: initializer for initializer in initializers}
2886
+ )
2887
+ self._doc_string = doc_string
2888
+ self._opset_imports = opset_imports or {}
2889
+ self._metadata: _metadata.MetadataStore | None = None
2890
+ self._metadata_props: dict[str, str] | None = metadata_props
2891
+ self._nodes: _linked_list.DoublyLinkedSet[Node] = _linked_list.DoublyLinkedSet()
2892
+ # Be sure to initialize the name authority before extending the nodes
2893
+ # because it is used to name the nodes and their outputs
2894
+ self._name_authority = _name_authority.NameAuthority()
2895
+ # TODO(justinchuby): Trigger again if inputs or initializers are modified.
2896
+ self._set_input_and_initializer_value_names_into_name_authority()
2897
+ # Call self.extend not self._nodes.extend so the graph reference is added to the nodes
2898
+ self.extend(nodes)
2899
+
2900
+ @property
2901
+ def inputs(self) -> MutableSequence[Value]:
2902
+ return self._inputs
2903
+
2904
+ @property
2905
+ def outputs(self) -> MutableSequence[Value]:
2906
+ return self._outputs
2907
+
2908
+ @property
2909
+ def initializers(self) -> _graph_containers.GraphInitializers:
2910
+ """The initializers of the graph as a ``dict[str, Value]``.
2911
+
2912
+ The keys are the names of the initializers. The values are the :class:`Value` objects.
2913
+
2914
+ This property additionally supports the ``add`` method, which takes a :class:`Value`
2915
+ and adds it to the initializers if it is not already present.
2916
+
2917
+ .. note::
2918
+ When setting an initializer with ``graph.initializers[key] = value``,
2919
+ if the value does not have a name, it will be assigned ``key`` as its name.
2920
+
2921
+ """
2922
+ return self._initializers
2923
+
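+ # Illustrative (comment only): assigning under a key names an unnamed value,
+ # per the note above.
+ #     graph.initializers["w"] = w  # if w.name was None, it becomes "w"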
2924
+ def register_initializer(self, value: Value) -> None:
2925
+ """Register an initializer to the graph.
2926
+
2927
+ This is a convenience method to register an initializer to the graph with
2928
+ checks.
2929
+
2930
+ Args:
2931
+ value: The :class:`Value` to register as an initializer of the graph.
2932
+ It must have its ``.const_value`` set.
2933
+
2934
+ Raises:
2935
+ ValueError: If a value of the same name that is not this value
2936
+ is already registered.
2937
+ ValueError: If the value does not have a name.
2938
+ ValueError: If the initializer is produced by a node.
2939
+ ValueError: If the value does not have its ``.const_value`` set.
2940
+ """
2941
+ if not value.name:
2942
+ raise ValueError(f"Initializer must have a name: {value!r}")
2943
+ if value.name in self._initializers:
2944
+ if self._initializers[value.name] is not value:
2945
+ raise ValueError(
2946
+ f"Initializer '{value.name}' is already registered, but"
2947
+ " it is not the same object: existing={self._initializers[value.name]!r},"
2948
+ f" new={value!r}"
2949
+ )
2950
+ if value.const_value is None:
2951
+ raise ValueError(
2952
+ f"Value '{value!r}' must have its const_value set to be an initializer."
2953
+ )
2954
+ self._initializers.add(value)
2955
+
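+ # Usage sketch (illustrative): ``weight`` is a hypothetical Value whose
+ # ``const_value`` holds a tensor; registration validates the name and tensor.
+ #     graph.register_initializer(weight)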
2956
+ @property
2957
+ def doc_string(self) -> str | None:
2958
+ return self._doc_string
2959
+
2960
+ @doc_string.setter
2961
+ def doc_string(self, value: str | None) -> None:
2962
+ self._doc_string = value
2963
+
2964
+ @property
2965
+ def opset_imports(self) -> dict[str, int]:
2966
+ return self._opset_imports
2967
+
2968
+ @typing.overload
2969
+ def __getitem__(self, index: int) -> Node: ...
2970
+ @typing.overload
2971
+ def __getitem__(self, index: slice) -> Sequence[Node]: ...
2972
+
2973
+ def __getitem__(self, index):
2974
+ return self._nodes[index]
2975
+
2976
+ def __len__(self) -> int:
2977
+ return len(self._nodes)
2978
+
2979
+ def __iter__(self) -> Iterator[Node]:
2980
+ return iter(self._nodes)
2981
+
2982
+ def __reversed__(self) -> Iterator[Node]:
2983
+ return reversed(self._nodes)
2984
+
2985
+ def _set_input_and_initializer_value_names_into_name_authority(self):
2986
+ for value in self.inputs:
2987
+ self._name_authority.register_or_name_value(value)
2988
+ for value in self.initializers.values():
2989
+ self._name_authority.register_or_name_value(value)
2990
+
2991
+ def _set_node_graph_to_self_and_assign_names(self, node: Node) -> Node:
2992
+ """Set the graph reference for the node and assign names to it and its outputs if they don't have one."""
2993
+ if node.graph is not None and node.graph is not self:
2994
+ raise ValueError(
2995
+ f"The node '{node!r}' belongs to another graph. Please remove it first with Graph.remove()."
2996
+ )
2997
+ # Give the node and its output values names if they don't have one
2998
+ self._name_authority.register_or_name_node(node)
2999
+ for value in node._outputs: # pylint: disable=protected-access
3000
+ self._name_authority.register_or_name_value(value)
3001
+ node.graph = self
3002
+ return node
3003
+
3004
+ def node(self, index_or_name: int | str, /) -> Node:
3005
+ """Get a node by index or name.
3006
+
3007
+ This is an O(n) operation. Getting nodes on the ends of the graph (0 or -1) is O(1).
3008
+
3009
+ .. note::
3010
+ If you need repeated random access, consider turning it into a list with ``list(graph)`` .
3011
+ Or a dictionary for repeated access by name: ``{node.name: node for node in graph}`` .
3012
+
3013
+ When a name is provided and there are multiple nodes with the same name,
3014
+ the first node with the name is returned.
3015
+
3016
+ Args:
3017
+ index_or_name: The index or name of the node.
3018
+
3019
+ Returns:
3020
+ The node if found.
3021
+
3022
+ Raises:
3023
+ IndexError: If the index is out of range.
3024
+ ValueError: If the node with the given name is not found.
3025
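+
+ Example (a sketch; the node name is hypothetical)::
+
+ first = graph.node(0)  # same as graph[0]
+ gemm = graph.node("gemm_1")  # first node named "gemm_1"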
+ """
3026
+ # NOTE: This is a method specific to Graph, not required by the protocol unless proven
3027
+ if isinstance(index_or_name, int):
3028
+ return self[index_or_name]
3029
+ for node in self:
3030
+ if node.name == index_or_name:
3031
+ return node
3032
+ raise ValueError(f"Node with name '{index_or_name}' not found.")
3033
+
3034
+ def num_nodes(self) -> int:
3035
+ """Get the number of nodes in the graph in O(1) time.
3036
+
3037
+ Note that this method returns the number of nodes this graph directly contains.
3038
+ It does not count nodes in subgraphs.
3039
+
3040
+ This is an alias for ``len(graph)``. Use this if you prefer a more descriptive
3041
+ name for readability.
3042
+ """
3043
+ # NOTE: This is a method specific to Graph, not required by the protocol unless proven
3044
+ return len(self)
3045
+
3046
+ def all_nodes(self) -> Iterator[Node]:
3047
+ """Get all nodes in the graph and its subgraphs in O(#nodes + #attributes) time.
3048
+
3049
+ This is an alias for ``onnx_ir.traversal.RecursiveGraphIterator(graph)``.
3050
+ Consider using
3051
+ :class:`onnx_ir.traversal.RecursiveGraphIterator` for more advanced
3052
+ traversals on nodes.
3053
+
3054
+ .. versionadded:: 0.1.2
3055
+ """
3056
+ # NOTE: This is a method specific to Graph, not required by the protocol unless proven
3057
+ return onnx_ir.traversal.RecursiveGraphIterator(self)
3058
+
3059
+ def subgraphs(self) -> Iterator[Graph]:
3060
+ """Get all subgraphs in the graph in O(#nodes + #attributes) time.
3061
+
3062
+ .. versionadded:: 0.1.2
3063
+ """
3064
+ # Use a dict to preserve order
3065
+ seen_graphs: dict[Graph, None] = {}
3066
+
3067
+ # Need to use the enter_graph callback so that empty subgraphs are collected
3068
+ def enter_subgraph(graph) -> None:
3069
+ if graph is self:
3070
+ return
3071
+ if not isinstance(graph, Graph):
3072
+ raise TypeError(
3073
+ f"Expected a Graph, got {type(graph)}. The model may be invalid"
3074
+ )
3075
+ if graph not in seen_graphs:
3076
+ seen_graphs[graph] = None
3077
+
3078
+ for _ in onnx_ir.traversal.RecursiveGraphIterator(self, enter_graph=enter_subgraph):
3079
+ pass
3080
+ yield from seen_graphs.keys()
3081
+
3082
+ def clone(self, allow_outer_scope_values: bool = False) -> Graph:
3083
+ """Create a deep copy of this graph in O(#nodes + #values) time.
3084
+
3085
+ All nodes, values, and subgraphs are cloned. The cloned graph will have
3086
+ the same structure as this graph, but all nodes and values will be different
3087
+ objects.
3088
+
3089
+ Tensors in initializers and constant values will be shared.
3090
+
3091
+ .. versionadded:: 0.1.14
3092
+ .. versionadded:: 0.1.15
3093
+ Added ``allow_outer_scope_values`` argument.
3094
+
3095
+ Args:
3096
+ allow_outer_scope_values: When True, values that are from outer scopes
3097
+ (not defined in this graph) will not be cloned. Instead, the cloned
3098
+ graph will reference the same outer scope values. This is useful
3099
+ when cloning subgraphs that reference values from the outer graph.
3100
+ When False (default), values from outer scopes will cause an error if they
3101
+ are referenced in the cloned graph.
3102
+
3103
+ Returns:
3104
+ A deep copy of this graph.
3105
+
3106
+ Raises:
3107
+ ValueError: If ``allow_outer_scope_values`` is False and the graph
3108
+ references values from outer scopes.
3109
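+
+ For example (a sketch)::
+
+ copy = graph.clone()
+ assert len(copy) == len(graph)
+ assert copy[0] is not graph[0]  # nodes are distinct objects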
+ """
3110
+ from onnx_ir import _cloner
3111
+
3112
+ cloner = _cloner.Cloner(
3113
+ attr_map={},
3114
+ value_map={},
3115
+ metadata_props={},
3116
+ resolve_ref_attrs=False,
3117
+ allow_outer_scope_values=allow_outer_scope_values,
3118
+ )
3119
+ return cloner.clone_graph(self)
3120
+
3121
+ # Mutation methods
3122
+ def append(self, node: Node, /) -> None:
3123
+ """Append a node to the graph in O(1) time.
3124
+
3125
+ Unique names will be assigned to the node and its values if any name is ``None``.
3126
+
3127
+ Args:
3128
+ node: The node to append.
3129
+
3130
+ Raises:
3131
+ ValueError: If the node belongs to another graph.
3132
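+
+ For example (a sketch, assuming ``node`` is a freshly constructed :class:`Node`)::
+
+ graph.append(node)
+ assert node.graph is graph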
+ """
3133
+ self._set_node_graph_to_self_and_assign_names(node)
3134
+ self._nodes.append(node)
3135
+
3136
+ def extend(self, nodes: Iterable[Node], /) -> None:
3137
+ """Extend the graph with the given nodes in O(#new_nodes) time.
3138
+
3139
+ Unique names will be assigned to the node and its values if any name is ``None``.
3140
+
3141
+ Args:
3142
+ nodes: The nodes to extend the graph with.
3143
+
3144
+ Raises:
3145
+ ValueError: If any node belongs to another graph.
3146
+ """
3147
+ nodes = [self._set_node_graph_to_self_and_assign_names(node) for node in nodes]
3148
+ self._nodes.extend(nodes)
3149
+
3150
+ def remove(self, nodes: Node | Iterable[Node], /, safe: bool = False) -> None:
3151
+ """Remove nodes from the graph in O(#num of nodes to remove) time.
3152
+
3153
+ If any errors are raised, the graph is not modified, so that it is never
3154
+ left in an inconsistent state.
3155
+
3156
+ Args:
3157
+ nodes: The node or nodes to remove.
3158
+ safe: If True, performs the following actions before removal:
3159
+
3160
+ 1. It checks to make sure there are no users of the node that are not
3161
+ to be removed before removing it.
3162
+ 2. It checks the node does not contribute to any graph outputs.
3163
+ 3. It removes references to all inputs so it is no longer a user of other nodes.
3164
+
3165
+ Raises:
3166
+ ValueError: If any node to remove does not belong to this graph.
3167
+ ValueError: (When ``safe=True``) If the node does not belong to this graph or if there are users of the node.
3168
+ ValueError: (When ``safe=True``) If the node is still being used by other nodes not to be removed.
3169
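+
+ A typical safe-removal sketch (assuming ``node`` has no remaining users and
+ does not contribute to any graph outputs)::
+
+ graph.remove(node, safe=True)
+ assert node.graph is None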
+ """
3170
+ if not isinstance(nodes, Iterable):
3171
+ nodes_set: AbstractSet[Node] = {nodes}
3172
+ else:
3173
+ nodes_set = frozenset(nodes)
3174
+ graph_outputs = frozenset(self.outputs)
3175
+ for node in nodes_set:
3176
+ if node.graph is not self:
3177
+ raise ValueError(f"The node '{node!r}' does not belong to this graph.")
3178
+ if safe:
3179
+ # Check 1, 2
3180
+ _check_node_safe_to_remove(node, nodes_set, graph_outputs)
3181
+ for node in nodes_set:
3182
+ if safe:
3183
+ # 3. Detach from all inputs so that it is no longer a user of other nodes
3184
+ for i in range(len(node.inputs)):
3185
+ node.replace_input_with(i, None)
3186
+ # Set attributes to remove the node from this graph
3187
+ node.graph = None
3188
+ self._nodes.remove(node)
3189
+
3190
+ def insert_after(self, node: Node, new_nodes: Iterable[Node] | Node, /) -> None:
3191
+ """Insert new nodes after the given node in O(#new_nodes) time.
3192
+
3193
+ Unique names will be assigned to the node and its values if any name is ``None``.
3194
+
3195
+ Args:
3196
+ node: The node to insert after.
3197
+ new_nodes: The new nodes to insert.
3198
+
3199
+ Raises:
3200
+ ValueError: If any node belongs to another graph.
3201
+ """
3202
+ if isinstance(new_nodes, Node):
3203
+ new_nodes = (new_nodes,)
3204
+ new_nodes = [self._set_node_graph_to_self_and_assign_names(node) for node in new_nodes]
3205
+ self._nodes.insert_after(node, new_nodes)
3206
+
3207
+ def insert_before(self, node: Node, new_nodes: Iterable[Node] | Node, /) -> None:
3208
+ """Insert new nodes before the given node in O(#new_nodes) time.
3209
+
3210
+ Unique names will be assigned to the node and its values if any name is ``None``.
3211
+
3212
+ Args:
3213
+ node: The node to insert before.
3214
+ new_nodes: The new nodes to insert.
3215
+
3216
+ Raises:
3217
+ ValueError: If any node belongs to another graph.
3218
+ """
3219
+ if isinstance(new_nodes, Node):
3220
+ new_nodes = (new_nodes,)
3221
+ new_nodes = [self._set_node_graph_to_self_and_assign_names(node) for node in new_nodes]
3222
+ self._nodes.insert_before(node, new_nodes)
3223
+
3224
+ def sort(self) -> None:
3225
+ """Perform a topological sort of this graph and all subgraphs in O(#nodes + #values) time.
3226
+
3227
+ This sort is stable. It preserves the original order as much as possible.
3228
+
3229
+ Reference: https://github.com/madelson/MedallionTopologicalSort#stable-sort
3230
+
3231
+ Raises:
3232
+ ValueError: If the graph contains a cycle, making topological sorting impossible.
3233
+ """
3234
+ # Obtain all nodes from the graph and its subgraphs for sorting
3235
+ nodes = list(onnx_ir.traversal.RecursiveGraphIterator(self))
3236
+ # Store the sorted nodes of each subgraph
3237
+ sorted_nodes_by_graph: dict[Graph, list[Node]] = {
3238
+ graph: [] for graph in {node.graph for node in nodes if node.graph is not None}
3239
+ }
3240
+ # TODO(justinchuby): Explain why we need to store direct predecessors and children and why
3241
+ # we only need to store the direct ones
3242
+
3243
+ # The depth of a node is defined as the number of direct children it has
3244
+ node_depth: dict[Node, int] = dict.fromkeys(nodes, 0)
3245
+ # Direct predecessors of a node
3246
+ node_predecessors: dict[Node, list[Node]] = {node: [] for node in nodes}
3247
+ # Store the negative index of the nodes because heapq is a min heap and we
3248
+ # want to pop the node with the largest index value first, effectively turning
3249
+ # it into a max heap
3250
+ neg_node_index: dict[Node, int] = {node: -i for i, node in enumerate(nodes)}
3251
+
3252
+ def add_predecessor(child: Node, predecessor: Node | None) -> None:
3253
+ """Add a predecessor of a node, and increment the depth of the predecessor."""
3254
+ if predecessor is None:
3255
+ return
3256
+ node_predecessors[child].append(predecessor)
3257
+ node_depth[predecessor] += 1
3258
+
3259
+ # 1. Build the direct predecessors of each node and the depth of each node
3260
+ # for sorting topologically using Kahn's algorithm.
3261
+ # Note that when a node contains graph attributes (aka. has subgraphs),
3262
+ # we consider all nodes in the subgraphs *predecessors* of this node. This
3263
+ # way we ensure the implicit dependencies of the subgraphs are captured
3264
+ # as predecessors of the node.
3265
+ for node in nodes:
3266
+ # All producers of input values are considered as direct predecessors.
3267
+ for input_value in node.inputs:
3268
+ if input_value is None:
3269
+ continue
3270
+ predecessor_node = input_value.producer()
3271
+ add_predecessor(node, predecessor_node)
3272
+ # All nodes in attribute graphs are considered as direct predecessors.
3273
+ for attr in node.attributes.values():
3274
+ if not isinstance(attr, Attr):
3275
+ continue
3276
+ # A nice thing about this algorithm is that we only need to record
3277
+ # direct predecessors. This continues to be true even with subgraphs.
3278
+ # When a node in a subgraph (a) contains its own subgraphs (b), the
3279
+ # nodes in subgraphs (b) are guaranteed to appear before the node
3280
+ # in (a).
3281
+ if attr.type == _enums.AttributeType.GRAPH:
3282
+ for predecessor_node in attr.value:
3283
+ add_predecessor(node, predecessor_node)
3284
+ elif attr.type == _enums.AttributeType.GRAPHS:
3285
+ for attribute_graph in attr.value:
3286
+ for predecessor_node in attribute_graph:
3287
+ add_predecessor(node, predecessor_node)
3288
+
3289
+ # 2. Priority Queue: Track nodes with zero direct children in a priority queue,
3290
+ # using NEGATIVE original index for ordering.
3291
+ # This ensures nodes appearing LATER in the original order are processed EARLIER.
3292
+ # We get REVERSED topological order of each subgraph.
3293
+ priority_queue: list[tuple[int, Node]] = [
3294
+ (neg_node_index[node], node) for node in nodes if node_depth[node] == 0
3295
+ ]
3296
+ heapq.heapify(priority_queue)
3297
+
3298
+ # 3. Topological Sort:
3299
+ num_of_sorted_nodes = 0
3300
+ while priority_queue:
3301
+ # Pop the node with the most negative index and add it to the sorted nodes by subgraph.
3302
+ _, current_node = heapq.heappop(priority_queue)
3303
+ assert current_node.graph is not None
3304
+ sorted_nodes_by_graph[current_node.graph].append(current_node)
3305
+ num_of_sorted_nodes += 1
3306
+ # Decrement the depth of its predecessors. If any predecessor node has zero direct children, push it into the queue.
3307
+ for predecessor_node in node_predecessors[current_node]:
3308
+ node_depth[predecessor_node] -= 1
3309
+ if node_depth[predecessor_node] == 0:
3310
+ heapq.heappush(
3311
+ priority_queue, (neg_node_index[predecessor_node], predecessor_node)
3312
+ )
3313
+
3314
+ # 4. Cycle Check: Ensure all nodes are processed. If not, raise a ValueError indicating a cycle.
3315
+ if num_of_sorted_nodes != len(nodes):
3316
+ raise ValueError("Graph contains a cycle, topological sort is not possible.")
3317
+
3318
+ # 5. Reverse: Reverse the sorted nodes of each subgraph to get the topological order.
3319
+ for graph, sorted_nodes in sorted_nodes_by_graph.items():
3320
+ # The graph container ensures all the nodes are unique so we can safely extend
3321
+ graph.extend(reversed(sorted_nodes))
3322
+
3323
+ # End of mutation methods
3324
+
3325
+ @property
3326
+ def meta(self) -> _metadata.MetadataStore:
3327
+ """The metadata store for intermediate analysis.
3328
+
3329
+ Write to the :attr:`metadata_props` if you would like the metadata to be serialized
3330
+ to the ONNX proto.
3331
+ """
3332
+ if self._metadata is None:
3333
+ self._metadata = _metadata.MetadataStore()
3334
+ return self._metadata
3335
+
3336
+ @property
3337
+ def metadata_props(self) -> dict[str, str]:
3338
+ """The metadata properties of the graph.
3339
+
3340
+ The metadata properties are used to store additional information about the graph.
3341
+ Unlike ``meta``, this property is serialized to the ONNX proto.
3342
+ """
3343
+ if self._metadata_props is None:
3344
+ self._metadata_props = {}
3345
+ return self._metadata_props
3346
+
3347
+ def __str__(self) -> str:
3348
+ return _graph_str(self)
3349
+
3350
+ def __repr__(self) -> str:
3351
+ return _graph_repr(self)
3352
+
3353
+
3354
+ def _graph_str(graph: Graph | GraphView) -> str:
3355
+ """Return a string representation of the graph."""
3356
+ # TODO(justinchuby): Show docstrings and metadata
3357
+ inputs_text = "\n" + ",\n".join(str(x) for x in graph.inputs)
3358
+ outputs_text = "\n" + ",\n".join(str(x) for x in graph.outputs)
3359
+ initializers_text = ",\n".join(str(x) for x in graph.initializers.values())
3360
+ if initializers_text:
3361
+ initializers_text = (
3362
+ "\ninitializers=(\n" + textwrap.indent(initializers_text, " " * 4) + "\n),"
3363
+ )
3364
+ signature = f"""\
3365
+ graph(
3366
+ name={graph.name or "anonymous_graph:" + str(id(graph))},
3367
+ inputs=({textwrap.indent(inputs_text, " " * 8)}
3368
+ ),
3369
+ outputs=({textwrap.indent(outputs_text, " " * 8)}
3370
+ ),{textwrap.indent(initializers_text, " " * 4)}
3371
+ )"""
3372
+ node_count = len(graph)
3373
+ number_width = len(str(node_count))
3374
+ node_lines = []
3375
+ for i, node in enumerate(graph):
3376
+ node_name = node.name if node.name else f":anonymous_node:{id(node)}"
3377
+ node_text = f"# {node_name}\n{node}"
3378
+ indented_node_text = textwrap.indent(node_text, " " * (number_width + 4))
3379
+ # Remove the leading spaces
3380
+ indented_node_text = indented_node_text.strip()
3381
+ node_lines.append(f"{i:>{number_width}} | {indented_node_text}")
3382
+ returns = ", ".join(str(x) for x in graph.outputs)
3383
+ body = (
3384
+ "{\n"
3385
+ + textwrap.indent("\n".join(node_lines), " " * 4)
3386
+ + textwrap.indent(f"\nreturn {returns}", " " * 4)
3387
+ + "\n}"
3388
+ )
3389
+
3390
+ return f"{signature} {body}"
3391
+
3392
+
3393
+ def _graph_repr(graph: Graph | GraphView) -> str:
3394
+ """Return an repr string of the graph."""
3395
+ inputs_text = "\n" + ",\n".join(str(x) for x in graph.inputs)
3396
+ outputs_text = "\n" + ",\n".join(str(x) for x in graph.outputs)
3397
+ initializers_text = ",\n".join(str(x) for x in graph.initializers.values())
3398
+ if initializers_text:
3399
+ initializers_text = (
3400
+ "\ninitializers=(\n" + textwrap.indent(initializers_text, " " * 4) + "\n),"
3401
+ )
3402
+ return f"""\
3403
+ {graph.__class__.__name__}(
3404
+ name={graph.name or "anonymous_graph:" + str(id(graph))!r},
3405
+ inputs=({textwrap.indent(inputs_text, " " * 8)}
3406
+ ),
3407
+ outputs=({textwrap.indent(outputs_text, " " * 8)}
3408
+ ),{textwrap.indent(initializers_text, " " * 4)}
3409
+ len()={len(graph)}
3410
+ )"""
3411
+
3412
+
3413
+ class GraphView(Sequence[Node], _display.PrettyPrintable):
3414
+ """A read-only view on a graph.
3415
+
3416
+ The GraphView is useful for analysis of a subgraph. It can be initialized
3417
+ with a subset of nodes from a :class:`Graph`. Creating GraphView does not
3418
+ change the ownership of the nodes, and so it is possible to create multiple
3419
+ GraphViews that contain the same nodes. If the underlying nodes / connections
3420
+ are mutated, the mutation will be reflected in all views as well.
3421
+
3422
+ The graph view can be serialized to ONNX::
3423
+
3424
+ graph_proto = ir.serde.serialize_graph(graph_view)
3425
+
3426
+ It can also be used to create a model::
3427
+
3428
+ model = ir.Model(graph_view, ir_version=8)
3429
+ model_proto = ir.serde.serialize_model(model)
3430
+
3431
+ The model created with a GraphView will have a fixed topology, and its graph
3432
+ will remain read-only as a GraphView. No copying will be done during the
3433
+ initialization process.
3434
+
3435
+ Attributes:
3436
+ name: The name of the graph.
3437
+ inputs: The input values of the graph.
3438
+ outputs: The output values of the graph.
3439
+ initializers: The initializers in the graph.
3440
+ doc_string: Documentation string.
3441
+ opset_imports: Opsets imported by the graph.
3442
+ metadata_props: Metadata that will be serialized to the ONNX file.
3443
+ meta: Metadata store for graph transform passes.
3444
+ """
3445
+
3446
+ __slots__ = (
3447
+ "_metadata",
3448
+ "_metadata_props",
3449
+ "doc_string",
3450
+ "initializers",
3451
+ "inputs",
3452
+ "name",
3453
+ "nodes",
3454
+ "opset_imports",
3455
+ "outputs",
3456
+ )
3457
+
3458
+ def __init__(
3459
+ self,
3460
+ inputs: Sequence[Value],
3461
+ outputs: Sequence[Value],
3462
+ *,
3463
+ nodes: Iterable[Node],
3464
+ initializers: Sequence[Value] = (),
3465
+ doc_string: str | None = None,
3466
+ opset_imports: dict[str, int] | None = None,
3467
+ name: str | None = None,
3468
+ metadata_props: dict[str, str] | None = None,
3469
+ ):
3470
+ self.name = name
3471
+ self.inputs = tuple(inputs)
3472
+ self.outputs = tuple(outputs)
3473
+ self.initializers: dict[str, Value] = {}
3474
+ for initializer in initializers:
3475
+ if not initializer.name:
3476
+ raise ValueError(f"Initializer must have a name: {initializer!r}")
3477
+ self.initializers[initializer.name] = initializer
3478
+ self.doc_string = doc_string
3479
+ self.opset_imports = opset_imports or {}
3480
+ self._metadata: _metadata.MetadataStore | None = None
3481
+ self._metadata_props: dict[str, str] | None = metadata_props
3482
+ self._nodes: tuple[Node, ...] = tuple(nodes)
3483
+
3484
+ @typing.overload
3485
+ def __getitem__(self, index: int) -> Node: ...
3486
+ @typing.overload
3487
+ def __getitem__(self, index: slice) -> Sequence[Node]: ...
3488
+
3489
+ def __getitem__(self, index):
3490
+ return self._nodes[index]
3491
+
3492
+ def __len__(self) -> int:
3493
+ return len(self._nodes)
3494
+
3495
+ def __iter__(self) -> Iterator[Node]:
3496
+ return iter(self._nodes)
3497
+
3498
+ def __reversed__(self) -> Iterator[Node]:
3499
+ return reversed(self._nodes)
3500
+
3501
+ @property
3502
+ def meta(self) -> _metadata.MetadataStore:
3503
+ """The metadata store for intermediate analysis.
3504
+
3505
+ Write to the :attr:`metadata_props` if you would like the metadata to be serialized
3506
+ to the ONNX proto.
3507
+ """
3508
+ if self._metadata is None:
3509
+ self._metadata = _metadata.MetadataStore()
3510
+ return self._metadata
3511
+
3512
+ @property
3513
+ def metadata_props(self) -> dict[str, str]:
3514
+ if self._metadata_props is None:
3515
+ self._metadata_props = {}
3516
+ return self._metadata_props
3517
+
3518
+ def __str__(self) -> str:
3519
+ return _graph_str(self)
3520
+
3521
+ def __repr__(self) -> str:
3522
+ return _graph_repr(self)
3523
+
3524
+ def clone(self) -> Graph:
3525
+ """Create a deep copy of this graph in O(#nodes + #values) time.
3526
+
3527
+ All nodes, values, and subgraphs are cloned. The cloned graph will have
3528
+ the same structure as this graph, but all nodes and values will be different
3529
+ objects.
3530
+
3531
+ Tensors in initializers and constant values will be shared.
3532
+
3533
+ .. versionadded:: 0.1.14
3534
+
3535
+ Returns:
3536
+ A deep copy of this graph.
3537
+ """
3538
+ from onnx_ir import _cloner
3539
+
3540
+ cloner = _cloner.Cloner(
3541
+ attr_map={},
3542
+ value_map={},
3543
+ metadata_props={},
3544
+ resolve_ref_attrs=False,
3545
+ )
3546
+ return cloner.clone_graph(self)
3547
+
3548
+
3549
+ class Model(_protocols.ModelProtocol, _display.PrettyPrintable):
3550
+ __slots__ = (
3551
+ "_functions",
3552
+ "_metadata",
3553
+ "_metadata_props",
3554
+ "doc_string",
3555
+ "domain",
3556
+ "graph",
3557
+ "ir_version",
3558
+ "model_version",
3559
+ "producer_name",
3560
+ "producer_version",
3561
+ )
3562
+ """IR Model.
3563
+
3564
+ A model is a container for a graph and metadata.
3565
+
3566
+ Attributes:
3567
+ graph: The graph of the model.
3568
+ ir_version: The version of the IR.
3569
+ producer_name: The name of the producer.
3570
+ producer_version: The version of the producer.
3571
+ domain: The domain of the model.
3572
+ model_version: The version of the model.
3573
+ doc_string: Documentation string.
3574
+ functions: The functions defined in the model.
3575
+ metadata_props: Metadata.
3576
+ """
3577
+
3578
+ def __init__(
3579
+ self,
3580
+ graph: Graph,
3581
+ *,
3582
+ ir_version: int,
3583
+ producer_name: str | None = None,
3584
+ producer_version: str | None = None,
3585
+ domain: str | None = None,
3586
+ model_version: int | None = None,
3587
+ doc_string: str | None = None,
3588
+ functions: Sequence[Function] = (),
3589
+ metadata_props: dict[str, str] | None = None,
3590
+ ) -> None:
3591
+ self.graph: Graph = graph
3592
+ self.ir_version = ir_version
3593
+ self.producer_name = producer_name
3594
+ self.producer_version = producer_version
3595
+ self.domain = domain
3596
+ self.model_version = model_version
3597
+ self.doc_string = doc_string
3598
+ self._functions = {func.identifier(): func for func in functions}
3599
+ self._metadata: _metadata.MetadataStore | None = None
3600
+ self._metadata_props: dict[str, str] | None = metadata_props
3601
+
3602
+ @property
3603
+ def functions(self) -> dict[_protocols.OperatorIdentifier, Function]:
3604
+ return self._functions
3605
+
3606
+ @property
3607
+ def opset_imports(self) -> dict[str, int]:
3608
+ return self.graph.opset_imports
3609
+
3610
+ @property
3611
+ def meta(self) -> _metadata.MetadataStore:
3612
+ """The metadata store for intermediate analysis.
3613
+
3614
+ Write to the :attr:`metadata_props` if you would like the metadata to be serialized
3615
+ to the ONNX proto.
3616
+ """
3617
+ if self._metadata is None:
3618
+ self._metadata = _metadata.MetadataStore()
3619
+ return self._metadata
3620
+
3621
+ @property
3622
+ def metadata_props(self) -> dict[str, str]:
3623
+ """The metadata properties of the model.
3624
+
3625
+ The metadata properties are used to store additional information about the model.
3626
+ Unlike ``meta``, this property is serialized to the ONNX proto.
3627
+ """
3628
+ if self._metadata_props is None:
3629
+ self._metadata_props = {}
3630
+ return self._metadata_props
3631
+
3632
+ def __str__(self) -> str:
3633
+ # TODO(justinchuby): Show docstrings and metadata
3634
+ signature = f"""\
3635
+ <
3636
+ ir_version={self.ir_version!r},
3637
+ opset_imports={self.opset_imports!r},
3638
+ producer_name={self.producer_name!r},
3639
+ producer_version={self.producer_version!r},
3640
+ domain={self.domain!r},
3641
+ model_version={self.model_version!r},
3642
+ >"""
3643
+ graph_text = str(self.graph)
3644
+ functions_text = "\n\n".join(str(func) for func in self.functions.values())
3645
+ return f"{signature}\n{graph_text}" + f"\n\n{functions_text}"
3646
+
3647
+ def __repr__(self) -> str:
3648
+ return f"""\
3649
+ Model(
3650
+ ir_version={self.ir_version!r},
3651
+ opset_imports={self.opset_imports!r},
3652
+ producer_name={self.producer_name!r},
3653
+ producer_version={self.producer_version!r},
3654
+ domain={self.domain!r},
3655
+ model_version={self.model_version!r},
3656
+ functions={self.functions!r},
3657
+ graph={textwrap.indent(repr(self.graph), " " * 4).strip()}
3658
+ )"""
3659
+
3660
+ def graphs(self) -> Iterable[Graph]:
3661
+ """Get all graphs and subgraphs in the model.
3662
+
3663
+ This is a convenience method to traverse the model. Consider using
3664
+ :class:`onnx_ir.traversal.RecursiveGraphIterator` for more advanced
3665
+ traversals on nodes.
3666
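+
+ For example (a sketch)::
+
+ names = [g.name for g in model.graphs()]  # main graph first, then subgraphs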
+ """
3667
+ # NOTE(justinchuby): Given
3668
+ # (1) how useful the method is
3669
+ # (2) I couldn't find an appropriate name for it in `traversal.py`
3670
+ # (3) Users familiar with onnxruntime optimization tools expect this method
3671
+ # I created this method as a core method instead of an iterator in
3672
+ # `traversal.py`.
3673
+ yield self.graph
3674
+ yield from self.graph.subgraphs()
3675
+
3676
+ def clone(self) -> Model:
3677
+ """Create a deep copy of this model.
3678
+
3679
+ All graphs, nodes, values, and subgraphs are cloned. The cloned model will have
3680
+ the same structure as this model, but all graphs, nodes, and values will be different
3681
+ objects.
3682
+
3683
+ Tensors in initializers and constant values will be shared.
3684
+
3685
+ .. versionadded:: 0.1.14
3686
+
3687
+ Returns:
3688
+ A deep copy of this model.
3689
+ """
3690
+ new_graph = self.graph.clone()
3691
+ new_functions = [func.clone() for func in self.functions.values()]
3692
+ new_model = Model(
3693
+ new_graph,
3694
+ ir_version=self.ir_version,
3695
+ producer_name=self.producer_name,
3696
+ producer_version=self.producer_version,
3697
+ domain=self.domain,
3698
+ model_version=self.model_version,
3699
+ doc_string=self.doc_string,
3700
+ functions=new_functions,
3701
+ metadata_props=dict(self.metadata_props),
3702
+ )
3703
+
3704
+ return new_model
3705
+
3706
+
3707
+ class Function(_protocols.FunctionProtocol, Sequence[Node], _display.PrettyPrintable):
3708
+ """IR functions.
3709
+
3710
+ Like a graph, a function can have nodes that are not topologically sorted. It is
3711
+ the responsibility of the user to maintain a topological order of the nodes.
3712
+
3713
+ Note that there is not a ``node`` attribute in the Function. The Function can be
3714
+ seen as a Sequence of nodes and should be used as such. For example, to obtain
3715
+ all nodes as a list, call ``list(function)``.
3716
+
3717
+ Attributes:
3718
+ name: The function name.
3719
+ domain: The domain this function is defined in.
3720
+ overload: The overload name when the function is overloaded.
3721
+ inputs: The input values of the function.
3722
+ attributes: The attributes this function defines.
3723
+ outputs: The output values of the function.
3724
+ opset_imports: Opsets imported by the function.
3725
+ doc_string: Documentation string.
3726
+ meta: Metadata store for graph transform passes.
3727
+ metadata_props: Metadata that will be serialized to the ONNX file.
3728
+ """
3729
+
3730
+ __slots__ = (
3731
+ "_attributes",
3732
+ "_domain",
3733
+ "_graph",
3734
+ "_name",
3735
+ "_overload",
3736
+ )
3737
+
3738
+ def __init__(
3739
+ self,
3740
+ domain: str,
3741
+ name: str,
3742
+ overload: str = "",
3743
+ *,
3744
+ # Ensure the inputs and outputs of the function belong to a graph
3745
+ # and not from an outer scope
3746
+ graph: Graph,
3747
+ attributes: Iterable[Attr] | Mapping[str, Attr],
3748
+ ) -> None:
3749
+ self._domain = domain
3750
+ self._name = name
3751
+ self._overload = overload
3752
+ self._graph = graph
3753
+ if isinstance(attributes, Mapping):
3754
+ attributes = tuple(attributes.values())
3755
+ self._attributes = _graph_containers.Attributes(attributes)
3756
+
3757
+ def identifier(self) -> _protocols.OperatorIdentifier:
3758
+ return self.domain, self.name, self.overload
3759
+
3760
+ @property
3761
+ def name(self) -> str:
3762
+ return self._name
3763
+
3764
+ @name.setter
3765
+ def name(self, value: str) -> None:
3766
+ self._name = value
3767
+
3768
+ @property
3769
+ def domain(self) -> str:
3770
+ return self._domain
3771
+
3772
+ @domain.setter
3773
+ def domain(self, value: str) -> None:
3774
+ self._domain = _normalize_domain(value)
3775
+
3776
+ @property
3777
+ def overload(self) -> str:
3778
+ return self._overload
3779
+
3780
+ @overload.setter
3781
+ def overload(self, value: str) -> None:
3782
+ self._overload = value
3783
+
3784
+ @property
3785
+ def inputs(self) -> MutableSequence[Value]:
3786
+ return self._graph.inputs
3787
+
3788
+ @property
3789
+ def outputs(self) -> MutableSequence[Value]:
3790
+ return self._graph.outputs
3791
+
3792
+ @property
3793
+ def attributes(self) -> _graph_containers.Attributes:
3794
+ return self._attributes
3795
+
3796
+ @property
3797
+ def graph(self) -> Graph:
3798
+ """The underlying Graph object that contains the nodes of this function.
3799
+
3800
+ Only use this graph for identity comparison::
3801
+
3802
+ if value.graph is function.graph:
3803
+ # Do something with the value that belongs to this function
3804
+
3805
+ Otherwise use the Function object directly to access the nodes and other properties.
3806
+
3807
+ .. versionadded:: 0.1.7
3808
+ """
3809
+ return self._graph
3810
+
3811
+ @typing.overload
3812
+ def __getitem__(self, index: int) -> Node: ...
3813
+ @typing.overload
3814
+ def __getitem__(self, index: slice) -> Sequence[Node]: ...
3815
+
3816
+ def __getitem__(self, index):
3817
+ return self._graph.__getitem__(index)
3818
+
3819
+ def __len__(self) -> int:
3820
+ return self._graph.__len__()
3821
+
3822
+ def __iter__(self) -> Iterator[Node]:
3823
+ return self._graph.__iter__()
3824
+
3825
+ def __reversed__(self) -> Iterator[Node]:
3826
+ return self._graph.__reversed__()
3827
+
3828
+ @property
3829
+ def doc_string(self) -> str | None:
3830
+ return self._graph.doc_string
3831
+
3832
+ @doc_string.setter
3833
+ def doc_string(self, value: str | None) -> None:
3834
+ self._graph.doc_string = value
3835
+
3836
+ @property
3837
+ def opset_imports(self) -> dict[str, int]:
3838
+ return self._graph.opset_imports
3839
+
3840
+ @property
3841
+ def meta(self) -> _metadata.MetadataStore:
3842
+ """The metadata store for intermediate analysis.
3843
+
3844
+ Write to the :attr:`metadata_props` if you would like the metadata to be serialized
3845
+ to the ONNX proto.
3846
+ """
3847
+ return self._graph.meta
3848
+
3849
+ @property
3850
+ def metadata_props(self) -> dict[str, str]:
3851
+ """The metadata properties of the function.
3852
+
3853
+ The metadata properties are used to store additional information about the function.
3854
+ Unlike ``meta``, this property is serialized to the ONNX proto.
3855
+ """
3856
+ return self._graph.metadata_props
3857
+
3858
+ def all_nodes(self) -> Iterator[Node]:
3859
+ """Get all nodes in the graph and its subgraphs in O(#nodes + #attributes) time.
3860
+
3861
+ This is an alias for ``onnx_ir.traversal.RecursiveGraphIterator(graph)``.
3862
+ Consider using
3863
+ :class:`onnx_ir.traversal.RecursiveGraphIterator` for more advanced
3864
+ traversals on nodes.
3865
+
3866
+ .. versionadded:: 0.1.2
3867
+ """
3868
+ # NOTE: This is a method specific to Graph, not required by the protocol unless proven
3869
+ return onnx_ir.traversal.RecursiveGraphIterator(self)
3870
+
3871
+ def subgraphs(self) -> Iterator[Graph]:
3872
+ """Get all subgraphs in the function in O(#nodes + #attributes) time.
3873
+
3874
+ .. versionadded:: 0.1.2
3875
+ """
3876
+ seen_graphs: dict[Graph, None] = {}
3877
+
3878
+ # Need to use the enter_graph callback so that empty subgraphs are collected
3879
+ def enter_subgraph(graph) -> None:
3880
+ if graph is self:
3881
+ return
3882
+ if not isinstance(graph, Graph):
3883
+ raise TypeError(
3884
+ f"Expected a Graph, got {type(graph)}. The model may be invalid"
3885
+ )
3886
+ if graph not in seen_graphs:
3887
+ seen_graphs[graph] = None
3888
+
3889
+ for _ in onnx_ir.traversal.RecursiveGraphIterator(self, enter_graph=enter_subgraph):
3890
+ pass
3891
+ yield from seen_graphs.keys()
3892
+
3893
+ def clone(self) -> Function:
3894
+ """Create a deep copy of this function in O(#nodes + #values) time.
3895
+
3896
+ All nodes, values, and subgraphs are cloned. The cloned function will have
3897
+ the same structure as this function, but all nodes and values will be different
3898
+ objects.
3899
+
3900
+ Tensors in initializers and constant values will be shared.
3901
+
3902
+ .. versionadded:: 0.1.14
3903
+
3904
+ Returns:
3905
+ A deep copy of this function.
3906
+ """
3907
+ from onnx_ir import _cloner
3908
+
3909
+ cloner = _cloner.Cloner(
3910
+ attr_map={},
3911
+ value_map={},
3912
+ metadata_props={},
3913
+ resolve_ref_attrs=False,
3914
+ )
3915
+ new_graph = cloner.clone_graph(self._graph)
3916
+ new_attributes = [
3917
+ cloner.clone_attr(attr.name, attr) for attr in self._attributes.values()
3918
+ ]
3919
+ return Function(
3920
+ domain=self._domain,
3921
+ name=self._name,
3922
+ overload=self._overload,
3923
+ graph=new_graph,
3924
+ attributes=new_attributes, # type: ignore
3925
+ )
3926
+
3927
+ # Mutation methods
3928
+ def append(self, node: Node, /) -> None:
3929
+ """Append a node to the function in O(1) time."""
3930
+ self._graph.append(node)
3931
+
3932
+ def extend(self, nodes: Iterable[Node], /) -> None:
3933
+ """Extend the function with the given nodes in O(#new_nodes) time."""
3934
+ self._graph.extend(nodes)
3935
+
3936
+ def remove(self, nodes: Node | Iterable[Node], /, safe: bool = False) -> None:
3937
+ """Remove nodes from the graph in O(#num of nodes) time.
3938
+
3939
+ If any errors are raised, the graph is not modified, so that it is never
3940
+ left in an inconsistent state.
3941
+
3942
+ Args:
3943
+ nodes: The node or nodes to remove.
3944
+ safe: If True, performs the following actions before removal:
3945
+
3946
+ 1. It checks to make sure there are no users of the node that are not
3947
+ to be removed before removing it.
3948
+ 2. It checks the node does not contribute to any graph outputs.
3949
+ 3. It removes references to all inputs so it is no longer a user of other nodes.
3950
+
3951
+ Raises:
3952
+ ValueError: If any node to remove does not belong to this graph.
3953
+ ValueError: (When ``safe=True``) If the node does not belong to this graph or if there are users of the node.
3954
+ ValueError: (When ``safe=True``) If the node is still being used by other nodes not to be removed.
3955
+ """
3956
+ self._graph.remove(nodes, safe=safe)
3957
+
3958
+ def insert_after(self, node: Node, new_nodes: Iterable[Node] | Node, /) -> None:
3959
+ """Insert new nodes after the given node in O(#new_nodes) time."""
3960
+ self._graph.insert_after(node, new_nodes)
3961
+
3962
+ def insert_before(self, node: Node, new_nodes: Iterable[Node] | Node, /) -> None:
3963
+ """Insert new nodes before the given node in O(#new_nodes) time."""
3964
+ self._graph.insert_before(node, new_nodes)
3965
+
3966
+ def sort(self) -> None:
3967
+ """Perform a topological sort of this graph and all subgraphs in O(#nodes + #values) time."""
3968
+ self._graph.sort()
3969
+
3970
+ # End of mutation methods
3971
+
3972
+ def __str__(self) -> str:
3973
+ full_name = f"{self.domain}::{self.name}" + f":{self.overload}" * (self.overload != "")
3974
+ inputs_text = ",\n".join(str(x) for x in self.inputs)
3975
+ outputs_text = ",\n".join(str(x) for x in self.outputs)
3976
+ attributes_text = ",\n".join(
3977
+ f"{attr.name}: {attr.type}" + f" = {attr.value}" * (attr.value is not None)
3978
+ for attr in self.attributes.values()
3979
+ )
3980
+ if attributes_text:
3981
+ attributes_text = (
3982
+ "\nattributes={\n" + textwrap.indent(attributes_text, " " * 4) + "\n}"
3983
+ )
3984
+ signature = f"""\
3985
+ <
3986
+ opset_imports={self.opset_imports!r},
3987
+ >
3988
+ def {full_name}(
3989
+ inputs=(
3990
+ {textwrap.indent(inputs_text, " " * 8)}
3991
+ ),{textwrap.indent(attributes_text, " " * 4)}
3992
+ outputs=(
3993
+ {textwrap.indent(outputs_text, " " * 8)}
3994
+ ),
3995
+ )"""
3996
+ node_count = len(self)
3997
+ number_width = len(str(node_count))
3998
+ node_lines = []
3999
+ for i, node in enumerate(self):
4000
+ node_name = node.name if node.name else f":anonymous_node:{id(node)}"
4001
+ node_text = f"# {node_name}\n{node}"
4002
+ indented_node_text = textwrap.indent(node_text, " " * (number_width + 4))
4003
+ # Remove the leading spaces
4004
+ indented_node_text = indented_node_text.strip()
4005
+ node_lines.append(f"{i:>{number_width}} | {indented_node_text}")
4006
+ returns = ", ".join(str(x) for x in self.outputs)
4007
+ body = (
4008
+ "{\n"
4009
+ + textwrap.indent("\n".join(node_lines), " " * 4)
4010
+ + textwrap.indent(f"\nreturn {returns}", " " * 4)
4011
+ + "\n}"
4012
+ )
4013
+
4014
+ return f"{signature} {body}"
4015
+
4016
+ def __repr__(self) -> str:
4017
+ return f"{self.__class__.__name__}({self.domain!r}, {self.name!r}, {self.overload!r}, inputs={self.inputs!r}, attributes={self.attributes!r}), outputs={self.outputs!r})"
4018
+
4019
+
4020
+ class Attr(
4021
+ _protocols.AttributeProtocol,
4022
+ _protocols.ReferenceAttributeProtocol,
4023
+ _display.PrettyPrintable,
4024
+ ):
4025
+ """Base class for ONNX attributes or references."""
4026
+
4027
+ __slots__ = ("_metadata", "_name", "_ref_attr_name", "_type", "_value", "doc_string")
4028
+
4029
+ def __init__(
4030
+ self,
4031
+ name: str,
4032
+ type: _enums.AttributeType,
4033
+ value: Any,
4034
+ ref_attr_name: str | None = None,
4035
+ *,
4036
+ doc_string: str | None = None,
4037
+ ) -> None:
4038
+ # Quick checks to ensure that INT and FLOAT attributes are stored as int and float,
4039
+ # not np.int32, np.float32, bool, etc.
4040
+ # This also allows errors to be raised at the time of construction instead of later
4041
+ # during serialization.
4042
+ # TODO(justinchuby): Use case matching when we drop support for Python 3.9
4043
+ if value is None:
4044
+ # Value can be None for reference attributes or when it is used as a
4045
+ # placeholder for schemas
4046
+ pass
4047
+ elif type == _enums.AttributeType.INT:
4048
+ value = int(value)
4049
+ elif type == _enums.AttributeType.FLOAT:
4050
+ value = float(value)
4051
+ elif type == _enums.AttributeType.INTS:
4052
+ value = tuple(int(v) for v in value)
4053
+ elif type == _enums.AttributeType.FLOATS:
4054
+ value = tuple(float(v) for v in value)
4055
+ elif type in {
4056
+ _enums.AttributeType.STRINGS,
4057
+ _enums.AttributeType.TENSORS,
4058
+ _enums.AttributeType.GRAPHS,
4059
+ _enums.AttributeType.TYPE_PROTOS,
4060
+ }:
4061
+ value = tuple(value)
4062
+
4063
+ self._name = name
4064
+ self._type = type
4065
+ self._value = value
4066
+ self._ref_attr_name = ref_attr_name
4067
+ self.doc_string = doc_string
4068
+ self._metadata: _metadata.MetadataStore | None = None
4069
+
4070
+ @property
4071
+ def name(self) -> str:
4072
+ return self._name
4073
+
4074
+ @name.setter
4075
+ def name(self, value: str) -> None:
4076
+ self._name = value
4077
+
4078
+ @property
4079
+ def type(self) -> _enums.AttributeType:
4080
+ return self._type
4081
+
4082
+ @property
4083
+ def value(self) -> Any:
4084
+ return self._value
4085
+
4086
+ @property
4087
+ def ref_attr_name(self) -> str | None:
4088
+ return self._ref_attr_name
4089
+
4090
+ @property
4091
+ def meta(self) -> _metadata.MetadataStore:
4092
+ """The metadata store for intermediate analysis.
4093
+
4094
+ Write to the :attr:`metadata_props` if you would like the metadata to be serialized
4095
+ to the ONNX proto.
4096
+ """
4097
+ if self._metadata is None:
4098
+ self._metadata = _metadata.MetadataStore()
4099
+ return self._metadata
4100
+
4101
+ def is_ref(self) -> bool:
4102
+ """Check if this attribute is a reference attribute."""
4103
+ return self.ref_attr_name is not None
4104
+
4105
+ def __eq__(self, other: object) -> bool:
4106
+ if not isinstance(other, _protocols.AttributeProtocol):
4107
+ return False
4108
+
4109
+ if self.name != other.name:
4110
+ return False
4111
+ if self.type != other.type:
4112
+ return False
4113
+ if self.value != other.value:
4114
+ return False
4115
+ if self.doc_string != other.doc_string:
4116
+ return False
4117
+ return True
4118
+
4119
+ def __str__(self) -> str:
4120
+ if self.is_ref():
4121
+ return f"@{self.ref_attr_name}"
4122
+ if self.type == _enums.AttributeType.GRAPH:
4123
+ return textwrap.indent("\n" + str(self.value), " " * 4)
4124
+ return repr(self.value)
4125
+
4126
+ def __repr__(self) -> str:
4127
+ if self.is_ref():
4128
+ return f"{self.__class__.__name__}({self.name!r}, {self.type!r}, ref_attr_name={self.ref_attr_name!r})"
4129
+ return f"{self.__class__.__name__}({self.name!r}, {self.type!r}, {self.value!r})"
4130
+
4131
+ # Well typed getters
4132
+ def as_float(self) -> float:
4133
+ """Get the attribute value as a float."""
4134
+ if self.type != _enums.AttributeType.FLOAT:
4135
+ raise TypeError(
4136
+ f"Attribute '{self.name}' is not of type FLOAT. Actual type: {self.type}"
4137
+ )
4138
+ # value is guaranteed to be a float in the constructor
4139
+ return self.value
4140
+
4141
+ def as_int(self) -> int:
4142
+ """Get the attribute value as an int."""
4143
+ if self.type != _enums.AttributeType.INT:
4144
+ raise TypeError(
4145
+ f"Attribute '{self.name}' is not of type INT. Actual type: {self.type}"
4146
+ )
4147
+ # value is guaranteed to be an int in the constructor
4148
+ return self.value
4149
+
4150
+ def as_string(self) -> str:
4151
+ """Get the attribute value as a string."""
4152
+ if self.type != _enums.AttributeType.STRING:
4153
+ raise TypeError(
4154
+ f"Attribute '{self.name}' is not of type STRING. Actual type: {self.type}"
4155
+ )
4156
+ value = self.value
4157
+ if not isinstance(value, str):
4158
+ raise TypeError(f"Value of attribute '{self!r}' is not a string.")
4159
+ return value
4160
+
4161
+ def as_tensor(self) -> _protocols.TensorProtocol:
4162
+ """Get the attribute value as a tensor."""
4163
+ if self.type != _enums.AttributeType.TENSOR:
4164
+ raise TypeError(
4165
+ f"Attribute '{self.name}' is not of type TENSOR. Actual type: {self.type}"
4166
+ )
4167
+ value = self.value
4168
+ if not isinstance(value, _protocols.TensorProtocol):
4169
+ raise TypeError(f"Value of attribute '{self!r}' is not a tensor.")
4170
+ return value
4171
+
4172
+ def as_graph(self) -> Graph:
4173
+ """Get the attribute value as a graph."""
4174
+ if self.type != _enums.AttributeType.GRAPH:
4175
+ raise TypeError(
4176
+ f"Attribute '{self.name}' is not of type GRAPH. Actual type: {self.type}"
4177
+ )
4178
+ value = self.value
4179
+ if not isinstance(value, Graph):
4180
+ raise TypeError(f"Value of attribute '{self!r}' is not a graph.")
4181
+ return value
4182
+
4183
+ def as_floats(self) -> tuple[float, ...]:
4184
+ """Get the attribute value as a sequence of floats."""
4185
+ if self.type != _enums.AttributeType.FLOATS:
4186
+ raise TypeError(
4187
+ f"Attribute '{self.name}' is not of type FLOATS. Actual type: {self.type}"
4188
+ )
4189
+ # value is guaranteed to be a sequence of float in the constructor
4190
+ return self.value
4191
+
4192
+ def as_ints(self) -> tuple[int, ...]:
4193
+ """Get the attribute value as a sequence of ints."""
4194
+ if self.type != _enums.AttributeType.INTS:
4195
+ raise TypeError(
4196
+ f"Attribute '{self.name}' is not of type INTS. Actual type: {self.type}"
4197
+ )
4198
+ # value is guaranteed to be a sequence of int in the constructor
4199
+ return self.value
4200
+
4201
+ def as_strings(self) -> tuple[str, ...]:
4202
+ """Get the attribute value as a sequence of strings."""
4203
+ if self.type != _enums.AttributeType.STRINGS:
4204
+ raise TypeError(
4205
+ f"Attribute '{self.name}' is not of type STRINGS. Actual type: {self.type}"
4206
+ )
4207
+ if onnx_ir.DEBUG:
4208
+ if not all(isinstance(x, str) for x in self.value):
4209
+ raise TypeError(f"Value of attribute '{self!r}' is not a Sequence of strings.")
4210
+ # value is guaranteed to be a sequence in the constructor
4211
+ return self.value
4212
+
4213
+ def as_tensors(self) -> tuple[_protocols.TensorProtocol, ...]:
4214
+ """Get the attribute value as a sequence of tensors."""
4215
+ if self.type != _enums.AttributeType.TENSORS:
4216
+ raise TypeError(
4217
+ f"Attribute '{self.name}' is not of type TENSORS. Actual type: {self.type}"
4218
+ )
4219
+ if onnx_ir.DEBUG:
4220
+ if not all(isinstance(x, _protocols.TensorProtocol) for x in self.value):
4221
+ raise TypeError(f"Value of attribute '{self!r}' is not a Sequence of tensors.")
4222
+ # value is guaranteed to be a sequence in the constructor
4223
+ return tuple(self.value)
4224
+
4225
+ def as_graphs(self) -> tuple[Graph, ...]:
4226
+ """Get the attribute value as a sequence of graphs."""
4227
+ if self.type != _enums.AttributeType.GRAPHS:
4228
+ raise TypeError(
4229
+ f"Attribute '{self.name}' is not of type GRAPHS. Actual type: {self.type}"
4230
+ )
4231
+ if onnx_ir.DEBUG:
4232
+ if not all(isinstance(x, Graph) for x in self.value):
4233
+ raise TypeError(f"Value of attribute '{self!r}' is not a Sequence of graphs.")
4234
+ # value is guaranteed to be a sequence in the constructor
4235
+ return tuple(self.value)
4236
+
4237
+
4238
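+ # Editorial sketch (not part of the released source): the typed ``as_*``
+ # getters validate the attribute type before returning the value, e.g.
+ #
+ # axis = Attr("axis", _enums.AttributeType.INT, 0)
+ # axis.as_int()  # -> 0
+ # axis.as_float()  # raises TypeError because the stored type is INT
+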
+ # NOTE: The following functions are just for convenience
4239
+
4240
+
4241
+ def RefAttr( # noqa: N802
4242
+ name: str,
4243
+ ref_attr_name: str,
4244
+ type: _enums.AttributeType,
4245
+ doc_string: str | None = None,
4246
+ ) -> Attr:
4247
+ """Create a reference attribute.
4248
+
4249
+ Args:
4250
+ name: The name of the attribute.
4251
+ ref_attr_name: The name of the referenced attribute.
4252
+ type: The type of the attribute.
4253
+ doc_string: Documentation string.
4254
+
4255
+ Returns:
4256
+ A reference attribute.
4257
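+
+ For example (a sketch)::
+
+ alpha_ref = RefAttr("alpha", "alpha", _enums.AttributeType.FLOAT)
+ assert alpha_ref.is_ref()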
+ """
4258
+ # NOTE: The function name is capitalized to maintain API backward compatibility.
4259
+ return Attr(name, type, None, ref_attr_name=ref_attr_name, doc_string=doc_string)
4260
+
4261
+
4262
+ def AttrFloat32(name: str, value: float | np.floating, doc_string: str | None = None) -> Attr: # noqa: N802
4263
+ """Create a float attribute."""
4264
+ # NOTE: The function name is capitalized to maintain API backward compatibility.
4265
+ return Attr(
4266
+ name,
4267
+ _enums.AttributeType.FLOAT,
4268
+ value,
4269
+ doc_string=doc_string,
4270
+ )
4271
+
4272
+
4273
+ def AttrInt64(name: str, value: int | np.integer, doc_string: str | None = None) -> Attr: # noqa: N802
4274
+ """Create an int attribute."""
4275
+ # NOTE: The function name is capitalized to maintain API backward compatibility.
4276
+ return Attr(
4277
+ name,
4278
+ _enums.AttributeType.INT,
4279
+ value,
4280
+ doc_string=doc_string,
4281
+ )
4282
+
4283
+
4284
+ def AttrString(name: str, value: str, doc_string: str | None = None) -> Attr: # noqa: N802
4285
+ """Create a str attribute."""
4286
+ # NOTE: The function name is capitalized to maintain API backward compatibility.
4287
+ return Attr(
4288
+ name,
4289
+ _enums.AttributeType.STRING,
4290
+ value,
4291
+ doc_string=doc_string,
4292
+ )
4293
+
4294
+
4295
+ def AttrTensor( # noqa: N802
4296
+ name: str, value: _protocols.TensorProtocol, doc_string: str | None = None
4297
+ ) -> Attr:
4298
+ """Create a tensor attribute."""
4299
+ # NOTE: The function name is capitalized to maintain API backward compatibility.
4300
+ return Attr(
4301
+ name,
4302
+ _enums.AttributeType.TENSOR,
4303
+ value,
4304
+ doc_string=doc_string,
4305
+ )
4306
+
4307
+
4308
+ def AttrGraph(name: str, value: Graph, doc_string: str | None = None) -> Attr: # noqa: N802
4309
+ """Create a graph attribute."""
4310
+ # NOTE: The function name is capitalized to maintain API backward compatibility.
4311
+ return Attr(
4312
+ name,
4313
+ _enums.AttributeType.GRAPH,
4314
+ value,
4315
+ doc_string=doc_string,
4316
+ )
4317
+
4318
+
4319
+ def AttrFloat32s(name: str, value: Sequence[float], doc_string: str | None = None) -> Attr: # noqa: N802
4320
+ """Create a float sequence attribute."""
4321
+ # NOTE: The function name is capitalized to maintain API backward compatibility.
4322
+ return Attr(
4323
+ name,
4324
+ _enums.AttributeType.FLOATS,
4325
+ value,
4326
+ doc_string=doc_string,
4327
+ )
4328
+
4329
+
4330
+ def AttrInt64s(name: str, value: Sequence[int], doc_string: str | None = None) -> Attr: # noqa: N802
4331
+ """Create an int sequence attribute."""
4332
+ # NOTE: The function name is capitalized to maintain API backward compatibility.
4333
+ return Attr(
4334
+ name,
4335
+ _enums.AttributeType.INTS,
4336
+ value,
4337
+ doc_string=doc_string,
4338
+ )
4339
+
4340
+
4341
+ def AttrStrings(name: str, value: Sequence[str], doc_string: str | None = None) -> Attr: # noqa: N802
4342
+ """Create a string sequence attribute."""
4343
+ # NOTE: The function name is capitalized to maintain API backward compatibility.
4344
+ return Attr(
4345
+ name,
4346
+ _enums.AttributeType.STRINGS,
4347
+ value,
4348
+ doc_string=doc_string,
4349
+ )
4350
+
4351
+
4352
+ def AttrTensors( # noqa: N802
4353
+ name: str, value: Sequence[_protocols.TensorProtocol], doc_string: str | None = None
4354
+ ) -> Attr:
4355
+ """Create a tensor sequence attribute."""
4356
+ # NOTE: The function name is capitalized to maintain API backward compatibility.
4357
+ return Attr(
4358
+ name,
4359
+ _enums.AttributeType.TENSORS,
4360
+ value,
4361
+ doc_string=doc_string,
4362
+ )
4363
+
4364
+
4365
+ def AttrGraphs(name: str, value: Sequence[Graph], doc_string: str | None = None) -> Attr: # noqa: N802
4366
+ """Create a graph sequence attribute."""
4367
+ # NOTE: The function name is capitalized to maintain API backward compatibility.
4368
+ return Attr(
4369
+ name,
4370
+ _enums.AttributeType.GRAPHS,
4371
+ value,
4372
+ doc_string=doc_string,
4373
+ )
4374
+
4375
+
4376
+ # NOTE: SparseTensor should be a sparse tensor proto
4377
+ def AttrSparseTensor( # noqa: N802
4378
+ name: str, value: _protocols.SparseTensorProtocol, doc_string: str | None = None
4379
+ ) -> Attr:
4380
+ """Create a sparse tensor attribute."""
4381
+ # NOTE: The function name is capitalized to maintain API backward compatibility.
4382
+ return Attr(
4383
+ name,
4384
+ _enums.AttributeType.SPARSE_TENSOR,
4385
+ value,
4386
+ doc_string=doc_string,
4387
+ )
4388
+
4389
+
4390
+ def AttrSparseTensors( # noqa: N802
4391
+ name: str, value: Sequence[_protocols.SparseTensorProtocol], doc_string: str | None = None
4392
+ ) -> Attr:
4393
+ """Create a sparse tensor sequence attribute."""
4394
+ # NOTE: The function name is capitalized to maintain API backward compatibility.
4395
+ return Attr(
4396
+ name,
4397
+ _enums.AttributeType.SPARSE_TENSORS,
4398
+ value,
4399
+ doc_string=doc_string,
4400
+ )
4401
+
4402
+
4403
+ @dataclasses.dataclass
4404
+ class TypeAndShape:
4405
+ """Type and shape.
4406
+
4407
+ Useful for constructing a type proto.
4408
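+
+ For example (a sketch; :class:`TensorType` and :class:`Shape` are defined in
+ this module)::
+
+ ts = TypeAndShape(TensorType(_enums.DataType.FLOAT), Shape([1, 3]))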
+ """
4409
+
4410
+ type: _protocols.TypeProtocol | None
4411
+ shape: Shape | None
4412
+
4413
+
4414
+ def AttrTypeProto(name: str, value: TypeAndShape, doc_string: str | None = None) -> Attr: # noqa: N802
4415
+ """Create a type attribute."""
4416
+ # NOTE: The function name is capitalized to maintain API backward compatibility.
4417
+ return Attr(
4418
+ name,
4419
+ _enums.AttributeType.TYPE_PROTO,
4420
+ value,
4421
+ doc_string=doc_string,
4422
+ )
4423
+
4424
+
4425
+ def AttrTypeProtos( # noqa: N802
4426
+ name: str, value: Sequence[TypeAndShape], doc_string: str | None = None
4427
+ ) -> Attr:
4428
+ """Create a type sequence attribute."""
4429
+ # NOTE: The function name is capitalized to maintain API backward compatibility.
4430
+ return Attr(
4431
+ name,
4432
+ _enums.AttributeType.TYPE_PROTOS,
4433
+ value,
4434
+ doc_string=doc_string,
4435
+ )