onnx-ir 0.0.1__py3-none-any.whl → 0.1.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.

Files changed (46)
  1. onnx_ir/__init__.py +23 -10
  2. onnx_ir/{_convenience.py → _convenience/__init__.py} +40 -102
  3. onnx_ir/_convenience/_constructors.py +213 -0
  4. onnx_ir/_core.py +874 -257
  5. onnx_ir/_display.py +2 -2
  6. onnx_ir/_enums.py +107 -5
  7. onnx_ir/_graph_comparison.py +2 -2
  8. onnx_ir/_graph_containers.py +373 -0
  9. onnx_ir/_io.py +57 -10
  10. onnx_ir/_linked_list.py +15 -7
  11. onnx_ir/_metadata.py +4 -3
  12. onnx_ir/_name_authority.py +2 -2
  13. onnx_ir/_polyfill.py +26 -0
  14. onnx_ir/_protocols.py +31 -13
  15. onnx_ir/_tape.py +139 -32
  16. onnx_ir/_thirdparty/asciichartpy.py +1 -4
  17. onnx_ir/_type_casting.py +18 -3
  18. onnx_ir/{_internal/version_utils.py → _version_utils.py} +2 -29
  19. onnx_ir/convenience.py +4 -2
  20. onnx_ir/external_data.py +401 -0
  21. onnx_ir/passes/__init__.py +8 -2
  22. onnx_ir/passes/_pass_infra.py +173 -56
  23. onnx_ir/passes/common/__init__.py +40 -0
  24. onnx_ir/passes/common/_c_api_utils.py +76 -0
  25. onnx_ir/passes/common/clear_metadata_and_docstring.py +60 -0
  26. onnx_ir/passes/common/common_subexpression_elimination.py +177 -0
  27. onnx_ir/passes/common/constant_manipulation.py +217 -0
  28. onnx_ir/passes/common/inliner.py +332 -0
  29. onnx_ir/passes/common/onnx_checker.py +57 -0
  30. onnx_ir/passes/common/shape_inference.py +112 -0
  31. onnx_ir/passes/common/topological_sort.py +33 -0
  32. onnx_ir/passes/common/unused_removal.py +196 -0
  33. onnx_ir/serde.py +288 -124
  34. onnx_ir/tape.py +15 -0
  35. onnx_ir/tensor_adapters.py +122 -0
  36. onnx_ir/testing.py +197 -0
  37. onnx_ir/traversal.py +4 -3
  38. onnx_ir-0.1.1.dist-info/METADATA +53 -0
  39. onnx_ir-0.1.1.dist-info/RECORD +42 -0
  40. {onnx_ir-0.0.1.dist-info → onnx_ir-0.1.1.dist-info}/WHEEL +1 -1
  41. onnx_ir-0.1.1.dist-info/licenses/LICENSE +202 -0
  42. onnx_ir/_external_data.py +0 -323
  43. onnx_ir-0.0.1.dist-info/LICENSE +0 -22
  44. onnx_ir-0.0.1.dist-info/METADATA +0 -73
  45. onnx_ir-0.0.1.dist-info/RECORD +0 -26
  46. {onnx_ir-0.0.1.dist-info → onnx_ir-0.1.1.dist-info}/top_level.txt +0 -0
onnx_ir/_core.py CHANGED
@@ -1,5 +1,5 @@
- # Copyright (c) Microsoft Corporation.
- # Licensed under the MIT License.
+ # Copyright (c) ONNX Project Contributors
+ # SPDX-License-Identifier: Apache-2.0
  """data structures for the intermediate representation."""
 
  # NOTES for developers:
@@ -22,26 +22,36 @@ import os
  import sys
  import textwrap
  import typing
- from typing import (
- AbstractSet,
- Any,
+ from collections.abc import (
  Collection,
- Generic,
  Hashable,
  Iterable,
  Iterator,
- OrderedDict,
+ Mapping,
+ MutableSequence,
  Sequence,
+ )
+ from collections.abc import (
+ Set as AbstractSet,
+ )
+ from typing import (
+ Any,
+ Callable,
+ Generic,
+ NamedTuple,
+ SupportsInt,
  Union,
  )
 
  import ml_dtypes
  import numpy as np
+ from typing_extensions import TypeIs
 
  import onnx_ir
  from onnx_ir import (
  _display,
  _enums,
+ _graph_containers,
  _linked_list,
  _metadata,
  _name_authority,
@@ -70,6 +80,7 @@ _NON_NUMPY_NATIVE_TYPES = frozenset(
  _enums.DataType.FLOAT8E5M2FNUZ,
  _enums.DataType.INT4,
  _enums.DataType.UINT4,
+ _enums.DataType.FLOAT4E2M1,
  )
  )
 
@@ -93,7 +104,23 @@ def _compatible_with_dlpack(obj: Any) -> TypeGuard[_protocols.DLPackCompatible]:
  class TensorBase(abc.ABC, _protocols.TensorProtocol, _display.PrettyPrintable):
  """Convenience Shared methods for classes implementing TensorProtocol."""
 
- __slots__ = ()
+ __slots__ = (
+ "_doc_string",
+ "_metadata",
+ "_metadata_props",
+ "_name",
+ )
+
+ def __init__(
+ self,
+ name: str | None = None,
+ doc_string: str | None = None,
+ metadata_props: dict[str, str] | None = None,
+ ) -> None:
+ self._metadata: _metadata.MetadataStore | None = None
+ self._metadata_props: dict[str, str] | None = metadata_props
+ self._name: str | None = name
+ self._doc_string: str | None = doc_string
 
  def _printable_type_shape(self) -> str:
  """Return a string representation of the shape and data type."""
@@ -106,10 +133,28 @@ class TensorBase(abc.ABC, _protocols.TensorProtocol, _display.PrettyPrintable):
  """
  return f"{self.__class__.__name__}<{self._printable_type_shape()}>"
 
+ @property
+ def name(self) -> str | None:
+ """The name of the tensor."""
+ return self._name
+
+ @name.setter
+ def name(self, value: str | None) -> None:
+ self._name = value
+
+ @property
+ def doc_string(self) -> str | None:
+ """The documentation string."""
+ return self._doc_string
+
+ @doc_string.setter
+ def doc_string(self, value: str | None) -> None:
+ self._doc_string = value
+
  @property
  def size(self) -> int:
  """The number of elements in the tensor."""
- return np.prod(self.shape.numpy()) # type: ignore[return-value,attr-defined]
+ return math.prod(self.shape.numpy()) # type: ignore[attr-defined]
 
  @property
  def nbytes(self) -> int:
@@ -117,6 +162,23 @@ class TensorBase(abc.ABC, _protocols.TensorProtocol, _display.PrettyPrintable):
  # Use math.ceil because when dtype is INT4, the itemsize is 0.5
  return math.ceil(self.dtype.itemsize * self.size)
 
+ @property
+ def metadata_props(self) -> dict[str, str]:
+ if self._metadata_props is None:
+ self._metadata_props = {}
+ return self._metadata_props
+
+ @property
+ def meta(self) -> _metadata.MetadataStore:
+ """The metadata store for intermediate analysis.
+
+ Write to the :attr:`metadata_props` if you would like the metadata to be serialized
+ to the ONNX proto.
+ """
+ if self._metadata is None:
+ self._metadata = _metadata.MetadataStore()
+ return self._metadata
+
  def display(self, *, page: bool = False) -> None:
  rich = _display.require_rich()
 
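A minimal sketch (not part of the diff) of what the `size`/`nbytes` arithmetic above implies for sub-byte types, assuming an INT4 tensor with five elements and the 0.5-byte `itemsize` mentioned in the comment:

    >>> import math
    >>> math.prod((5,))      # size: math.prod returns a plain int, unlike np.prod
    5
    >>> math.ceil(0.5 * 5)   # nbytes: five INT4 elements pack into three bytes
    3
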
@@ -182,7 +244,7 @@ def _check_numpy_representation_type(array: np.ndarray, dtype: _enums.DataType)
  When the dtype is not one of the numpy native dtypes, the value needs need to be:
 
  - ``int8`` or ``uint8`` for int4, with the sign bit extended to 8 bits.
- - ``uint8`` for uint4.
+ - ``uint8`` for uint4 or float4.
  - ``uint8`` for 8-bit data types.
  - ``uint16`` for bfloat16
 
@@ -195,7 +257,7 @@ def _check_numpy_representation_type(array: np.ndarray, dtype: _enums.DataType)
  )
  if dtype.itemsize == 1 and array.dtype not in (
  np.uint8,
- ml_dtypes.float8_e4m3b11fnuz,
+ ml_dtypes.float8_e4m3fnuz,
  ml_dtypes.float8_e4m3fn,
  ml_dtypes.float8_e5m2fnuz,
  ml_dtypes.float8_e5m2,
@@ -213,6 +275,11 @@ def _check_numpy_representation_type(array: np.ndarray, dtype: _enums.DataType)
  raise TypeError(
  f"The numpy array dtype must be uint8 or or ml_dtypes.uint4 (not {array.dtype}) for IR data type {dtype}."
  )
+ if dtype == _enums.DataType.FLOAT4E2M1:
+ if array.dtype not in (np.uint8, ml_dtypes.float4_e2m1fn):
+ raise TypeError(
+ f"The numpy array dtype must be uint8 or ml_dtypes.float4_e2m1fn (not {array.dtype}) for IR data type {dtype}."
+ )
  return
 
  try:
@@ -256,6 +323,8 @@ def _maybe_view_np_array_with_ml_dtypes(
  return array.view(ml_dtypes.int4)
  if dtype == _enums.DataType.UINT4:
  return array.view(ml_dtypes.uint4)
+ if dtype == _enums.DataType.FLOAT4E2M1:
+ return array.view(ml_dtypes.float4_e2m1fn)
  return array
 
 
@@ -298,12 +367,8 @@ class Tensor(TensorBase, _protocols.TensorProtocol, Generic[TArrayCompatible]):
 
  __slots__ = (
  "_dtype",
- "_metadata",
- "_metadata_props",
  "_raw",
  "_shape",
- "doc_string",
- "name",
  )
 
  def __init__(
@@ -322,7 +387,7 @@ class Tensor(TensorBase, _protocols.TensorProtocol, Generic[TArrayCompatible]):
  value: The backing data of the tensor. It can be a numpy array compatible object or a DLPack compatible object.
  When the dtype is not one of the numpy native dtypes, the value needs
  to be ``uint8`` for 4-bit and 8-bit data types, and ``uint16`` for bfloat16
- when the value is a numpy array; :param:`dtype` must be specified in this case.
+ when the value is a numpy array; ``dtype`` must be specified in this case.
  dtype: The data type of the tensor. It can be None only when value is a numpy array.
  Users are responsible for making sure the dtype matches the value when value is not a numpy array.
  shape: The shape of the tensor. If None, the shape is obtained from the value.
@@ -336,6 +401,7 @@ class Tensor(TensorBase, _protocols.TensorProtocol, Generic[TArrayCompatible]):
  ValueError: If the shape is not specified and the value does not have a shape attribute.
  ValueError: If the dtype is not specified and the value is not a numpy array.
  """
+ super().__init__(name=name, doc_string=doc_string, metadata_props=metadata_props)
  # NOTE: We should not do any copying here for performance reasons
  if not _compatible_with_numpy(value) and not _compatible_with_dlpack(value):
  raise TypeError(f"Expected an array compatible object, got {type(value)}")
@@ -349,7 +415,7 @@ class Tensor(TensorBase, _protocols.TensorProtocol, Generic[TArrayCompatible]):
  self._shape = Shape(getattr(value, "shape"), frozen=True) # noqa: B009
  else:
  self._shape = shape
- self._shape._frozen = True
+ self._shape.freeze()
  if dtype is None:
  if isinstance(value, np.ndarray):
  self._dtype = _enums.DataType.from_numpy(value.dtype)
@@ -370,17 +436,13 @@ class Tensor(TensorBase, _protocols.TensorProtocol, Generic[TArrayCompatible]):
  value = _maybe_view_np_array_with_ml_dtypes(value, self._dtype) # type: ignore[assignment]
 
  self._raw = value
- self.name = name
- self.doc_string = doc_string
- self._metadata: _metadata.MetadataStore | None = None
- self._metadata_props = metadata_props
 
  def __array__(self, dtype: Any = None) -> np.ndarray:
  if isinstance(self._raw, np.ndarray) or _compatible_with_numpy(self._raw):
  return self._raw.__array__(dtype)
- assert _compatible_with_dlpack(
- self._raw
- ), f"Bug: Expected DLPack or Numpy compatible objects, got {type(self._raw)}"
+ assert _compatible_with_dlpack(self._raw), (
+ f"Bug: Expected DLPack or Numpy compatible objects, got {type(self._raw)}"
+ )
  return np.from_dlpack(self._raw)
 
  def __dlpack__(self, *, stream: Any = None) -> Any:
@@ -394,7 +456,10 @@ class Tensor(TensorBase, _protocols.TensorProtocol, Generic[TArrayCompatible]):
  return self.__array__().__dlpack_device__()
 
  def __repr__(self) -> str:
- return f"{self._repr_base()}({self._raw!r}, name={self.name!r})"
+ # Avoid multi-line repr
+ tensor_lines = repr(self._raw).split("\n")
+ tensor_text = " ".join(line.strip() for line in tensor_lines)
+ return f"{self._repr_base()}({tensor_text}, name={self.name!r})"
 
  @property
  def dtype(self) -> _enums.DataType:
@@ -431,7 +496,11 @@ class Tensor(TensorBase, _protocols.TensorProtocol, Generic[TArrayCompatible]):
  """
  # TODO(justinchuby): Support DLPack
  array = self.numpy()
- if self.dtype in {_enums.DataType.INT4, _enums.DataType.UINT4}:
+ if self.dtype in {
+ _enums.DataType.INT4,
+ _enums.DataType.UINT4,
+ _enums.DataType.FLOAT4E2M1,
+ }:
  # Pack the array into int4
  array = _type_casting.pack_int4(array)
  else:
@@ -440,23 +509,6 @@ class Tensor(TensorBase, _protocols.TensorProtocol, Generic[TArrayCompatible]):
  array = array.view(array.dtype.newbyteorder("<"))
  return array.tobytes()
 
- @property
- def metadata_props(self) -> dict[str, str]:
- if self._metadata_props is None:
- self._metadata_props = {}
- return self._metadata_props
-
- @property
- def meta(self) -> _metadata.MetadataStore:
- """The metadata store for intermediate analysis.
-
- Write to the :attr:`metadata_props` if you would like the metadata to be serialized
- to the ONNX proto.
- """
- if self._metadata is None:
- self._metadata = _metadata.MetadataStore()
- return self._metadata
-
 
  class ExternalTensor(TensorBase, _protocols.TensorProtocol): # pylint: disable=too-many-ancestors
  """An immutable concrete tensor with its data store on disk.
497
549
  "_dtype",
498
550
  "_length",
499
551
  "_location",
500
- "_metadata",
501
- "_metadata_props",
502
552
  "_offset",
503
553
  "_shape",
504
- "doc_string",
505
- "name",
554
+ "_valid",
506
555
  "raw",
507
556
  )
508
557
 
@@ -532,6 +581,7 @@ class ExternalTensor(TensorBase, _protocols.TensorProtocol): # pylint: disable=
532
581
  metadata_props: The metadata properties.
533
582
  base_dir: The base directory for the external data. It is used to resolve relative paths.
534
583
  """
584
+ super().__init__(name=name, doc_string=doc_string, metadata_props=metadata_props)
535
585
  # NOTE: Do not verify the location by default. This is because the location field
536
586
  # in the tensor proto can be anything and we would like deserialization from
537
587
  # proto to IR to not fail.
@@ -547,12 +597,13 @@ class ExternalTensor(TensorBase, _protocols.TensorProtocol): # pylint: disable=
547
597
  self._dtype: _enums.DataType = dtype
548
598
  self.name: str = name # mutable
549
599
  self._shape: Shape = shape
550
- self._shape._frozen = True
600
+ self._shape.freeze()
551
601
  self.doc_string: str | None = doc_string # mutable
552
602
  self._array: np.ndarray | None = None
553
603
  self.raw: mmap.mmap | None = None
554
604
  self._metadata_props = metadata_props
555
605
  self._metadata: _metadata.MetadataStore | None = None
606
+ self._valid = True
556
607
 
557
608
  @property
558
609
  def base_dir(self) -> str | os.PathLike:
@@ -594,6 +645,7 @@ class ExternalTensor(TensorBase, _protocols.TensorProtocol): # pylint: disable=
594
645
  return self._shape
595
646
 
596
647
  def _load(self):
648
+ self._check_validity()
597
649
  assert self._array is None, "Bug: The array should be loaded only once."
598
650
  if self.size == 0:
599
651
  # When the size is 0, mmap is impossible and meaningless
@@ -609,7 +661,11 @@ class ExternalTensor(TensorBase, _protocols.TensorProtocol): # pylint: disable=
609
661
  )
610
662
  # Handle the byte order correctly by always using little endian
611
663
  dt = np.dtype(self.dtype.numpy()).newbyteorder("<")
612
- if self.dtype in {_enums.DataType.INT4, _enums.DataType.UINT4}:
664
+ if self.dtype in {
665
+ _enums.DataType.INT4,
666
+ _enums.DataType.UINT4,
667
+ _enums.DataType.FLOAT4E2M1,
668
+ }:
613
669
  # Use uint8 to read in the full byte. Otherwise ml_dtypes.int4 will clip the values
614
670
  dt = np.dtype(np.uint8).newbyteorder("<")
615
671
  count = self.size // 2 + self.size % 2
@@ -622,10 +678,13 @@ class ExternalTensor(TensorBase, _protocols.TensorProtocol): # pylint: disable=
622
678
  self._array = _type_casting.unpack_int4(self._array, shape)
623
679
  elif self.dtype == _enums.DataType.UINT4:
624
680
  self._array = _type_casting.unpack_uint4(self._array, shape)
681
+ elif self.dtype == _enums.DataType.FLOAT4E2M1:
682
+ self._array = _type_casting.unpack_float4e2m1(self._array, shape)
625
683
  else:
626
684
  self._array = self._array.reshape(shape)
627
685
 
628
686
  def __array__(self, dtype: Any = None) -> np.ndarray:
687
+ self._check_validity()
629
688
  if self._array is None:
630
689
  self._load()
631
690
  assert self._array is not None
@@ -654,6 +713,7 @@ class ExternalTensor(TensorBase, _protocols.TensorProtocol): # pylint: disable=
654
713
 
655
714
  The data will be memory mapped into memory and will not taken up physical memory space.
656
715
  """
716
+ self._check_validity()
657
717
  if self._array is None:
658
718
  self._load()
659
719
  assert self._array is not None
@@ -664,6 +724,7 @@ class ExternalTensor(TensorBase, _protocols.TensorProtocol): # pylint: disable=
664
724
 
665
725
  This will load the tensor into memory.
666
726
  """
727
+ self._check_validity()
667
728
  if self.raw is None:
668
729
  self._load()
669
730
  assert self.raw is not None
@@ -671,6 +732,26 @@ class ExternalTensor(TensorBase, _protocols.TensorProtocol): # pylint: disable=
671
732
  length = self._length or self.nbytes
672
733
  return self.raw[offset : offset + length]
673
734
 
735
+ def valid(self) -> bool:
736
+ """Check if the tensor is valid.
737
+
738
+ The external tensor is valid if it has not been invalidated.
739
+ """
740
+ return self._valid
741
+
742
+ def _check_validity(self) -> None:
743
+ if not self.valid():
744
+ raise ValueError(
745
+ f"The external tensor '{self!r}' is invalidated. The data may be corrupted or deleted."
746
+ )
747
+
748
+ def invalidate(self) -> None:
749
+ """Invalidate the tensor.
750
+
751
+ The external tensor is invalidated when the data is known to be corrupted or deleted.
752
+ """
753
+ self._valid = False
754
+
674
755
  def release(self) -> None:
675
756
  """Delete all references to the memory buffer and close the memory-mapped file."""
676
757
  self._array = None
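A minimal sketch (not part of the diff) of the invalidation protocol added above, assuming `tensor` is an existing ir.ExternalTensor whose backing file is about to be rewritten:

    >>> tensor.valid()       # True until explicitly invalidated
    True
    >>> tensor.release()     # close the memory map before touching the file
    >>> tensor.invalidate()  # later reads raise instead of returning stale bytes
    >>> tensor.tobytes()     # now raises ValueError via _check_validity()
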
@@ -678,34 +759,13 @@ class ExternalTensor(TensorBase, _protocols.TensorProtocol): # pylint: disable=
  self.raw.close()
  self.raw = None
 
- @property
- def metadata_props(self) -> dict[str, str]:
- if self._metadata_props is None:
- self._metadata_props = {}
- return self._metadata_props
-
- @property
- def meta(self) -> _metadata.MetadataStore:
- """The metadata store for intermediate analysis.
-
- Write to the :attr:`metadata_props` if you would like the metadata to be serialized
- to the ONNX proto.
- """
- if self._metadata is None:
- self._metadata = _metadata.MetadataStore()
- return self._metadata
-
 
  class StringTensor(TensorBase, _protocols.TensorProtocol): # pylint: disable=too-many-ancestors
  """Multidimensional array of strings (as binary data to match the string_data field in TensorProto)."""
 
  __slots__ = (
- "_metadata",
- "_metadata_props",
  "_raw",
  "_shape",
- "doc_string",
- "name",
  )
 
  def __init__(
@@ -726,6 +786,7 @@ class StringTensor(TensorBase, _protocols.TensorProtocol): # pylint: disable=to
  doc_string: The documentation string.
  metadata_props: The metadata properties.
  """
+ super().__init__(name=name, doc_string=doc_string, metadata_props=metadata_props)
  if shape is None:
  if not hasattr(value, "shape"):
  raise ValueError(
@@ -735,19 +796,15 @@ class StringTensor(TensorBase, _protocols.TensorProtocol): # pylint: disable=to
  self._shape = Shape(getattr(value, "shape"), frozen=True) # noqa: B009
  else:
  self._shape = shape
- self._shape._frozen = True
+ self._shape.freeze()
  self._raw = value
- self.name = name
- self.doc_string = doc_string
- self._metadata: _metadata.MetadataStore | None = None
- self._metadata_props = metadata_props
 
  def __array__(self, dtype: Any = None) -> np.ndarray:
  if isinstance(self._raw, np.ndarray):
  return self._raw
- assert isinstance(
- self._raw, Sequence
- ), f"Bug: Expected a sequence, got {type(self._raw)}"
+ assert isinstance(self._raw, Sequence), (
+ f"Bug: Expected a sequence, got {type(self._raw)}"
+ )
  return np.array(self._raw, dtype=dtype).reshape(self.shape.numpy())
 
  def __dlpack__(self, *, stream: Any = None) -> Any:
@@ -788,25 +845,125 @@ class StringTensor(TensorBase, _protocols.TensorProtocol): # pylint: disable=to
  return self._raw.flatten().tolist()
  return self._raw
 
+
+ class LazyTensor(TensorBase, _protocols.TensorProtocol): # pylint: disable=too-many-ancestors
+ """A tensor that lazily evaluates a function to get the actual tensor.
+
+ This class takes a function returning an `ir.TensorProtocol`, a dtype, and a shape argument.
+ The function is lazily evaluated to get the actual tensor when `tobytes()` or `numpy()` is called.
+
+ Example::
+
+ >>> import numpy as np
+ >>> import onnx_ir as ir
+ >>> weights = np.array([[1, 2, 3]])
+ >>> def create_tensor(): # Delay applying transformations to the weights
+ ... weights_t = weights.transpose()
+ ... return ir.tensor(weights_t)
+ >>> lazy_tensor = ir.LazyTensor(create_tensor, dtype=ir.DataType.INT64, shape=ir.Shape([1, 3]))
+ >>> print(lazy_tensor.numpy())
+ [[1]
+ [2]
+ [3]]
+
+ Attributes:
+ func: The function that returns the actual tensor.
+ dtype: The data type of the tensor.
+ shape: The shape of the tensor.
+ cache: Whether to cache the result of the function. If False,
+ the function is called every time the tensor content is accessed.
+ If True, the function is called only once and the result is cached in memory.
+ Default is False.
+ name: The name of the tensor.
+ doc_string: The documentation string.
+ metadata_props: The metadata properties.
+ """
+
+ __slots__ = (
+ "_dtype",
+ "_func",
+ "_shape",
+ "_tensor",
+ "cache",
+ )
+
+ def __init__(
+ self,
+ func: Callable[[], _protocols.TensorProtocol],
+ dtype: _enums.DataType,
+ shape: Shape,
+ *,
+ cache: bool = False,
+ name: str | None = None,
+ doc_string: str | None = None,
+ metadata_props: dict[str, str] | None = None,
+ ) -> None:
+ """Initialize a lazy tensor.
+
+ Args:
+ func: The function that returns the actual tensor.
+ dtype: The data type of the tensor.
+ shape: The shape of the tensor.
+ cache: Whether to cache the result of the function.
+ name: The name of the tensor.
+ doc_string: The documentation string.
+ metadata_props: The metadata properties.
+ """
+ super().__init__(name=name, doc_string=doc_string, metadata_props=metadata_props)
+ self._func = func
+ self._dtype = dtype
+ self._shape = shape
+ self._tensor: _protocols.TensorProtocol | None = None
+ self.cache = cache
+
+ def _evaluate(self) -> _protocols.TensorProtocol:
+ """Evaluate the function to get the actual tensor."""
+ if not self.cache:
+ return self._func()
+
+ # Cache the tensor
+ if self._tensor is None:
+ self._tensor = self._func()
+ return self._tensor
+
+ def __array__(self, dtype: Any = None) -> np.ndarray:
+ return self._evaluate().__array__(dtype)
+
+ def __dlpack__(self, *, stream: Any = None) -> Any:
+ return self._evaluate().__dlpack__(stream=stream)
+
+ def __dlpack_device__(self) -> tuple[int, int]:
+ return self._evaluate().__dlpack_device__()
+
+ def __repr__(self) -> str:
+ return f"{self._repr_base()}(func={self._func!r}, name={self.name!r})"
+
  @property
- def metadata_props(self) -> dict[str, str]:
- if self._metadata_props is None:
- self._metadata_props = {}
- return self._metadata_props
+ def raw(self) -> Callable[[], _protocols.TensorProtocol]:
+ return self._func
 
  @property
- def meta(self) -> _metadata.MetadataStore:
- """The metadata store for intermediate analysis.
+ def dtype(self) -> _enums.DataType:
+ """The data type of the tensor. Immutable."""
+ return self._dtype
 
- Write to the :attr:`metadata_props` if you would like the metadata to be serialized
- to the ONNX proto.
- """
- if self._metadata is None:
- self._metadata = _metadata.MetadataStore()
- return self._metadata
+ @property
+ def shape(self) -> Shape:
+ """The shape of the tensor. Immutable."""
+ return self._shape
+
+ def numpy(self) -> np.ndarray:
+ """Return the tensor as a numpy array."""
+ return self._evaluate().numpy()
+
+ def tobytes(self) -> bytes:
+ """Return the bytes of the tensor."""
+ return self._evaluate().tobytes()
 
 
  class SymbolicDim(_protocols.SymbolicDimProtocol, _display.PrettyPrintable):
+ """Immutable symbolic dimension that can be shared across multiple shapes."""
+
  __slots__ = ("_value",)
 
  def __init__(self, value: str | None) -> None:
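A minimal sketch (not part of the diff) of the `cache` flag on the LazyTensor class above; the counting function below is hypothetical:

    >>> calls = 0
    >>> def counting_func():
    ...     global calls
    ...     calls += 1
    ...     return ir.tensor(np.array([1, 2, 3]))
    >>> t = ir.LazyTensor(counting_func, dtype=ir.DataType.INT64, shape=ir.Shape([3]), cache=True)
    >>> _ = t.numpy()
    >>> _ = t.tobytes()
    >>> calls   # evaluated once; with cache=False this would be 2
    1
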
@@ -841,12 +998,84 @@ class SymbolicDim(_protocols.SymbolicDimProtocol, _display.PrettyPrintable):
  return f"{self.__class__.__name__}({self._value})"
 
 
+ def _is_int_compatible(value: object) -> TypeIs[SupportsInt]:
+ """Return True if the value is int compatible."""
+ if isinstance(value, int):
+ return True
+ if hasattr(value, "__int__"):
+ # For performance reasons, we do not use isinstance(value, SupportsInt)
+ return True
+ return False
+
+
+ def _maybe_convert_to_symbolic_dim(
+ dim: int | SupportsInt | SymbolicDim | str | None,
+ ) -> SymbolicDim | int:
+ """Convert the value to a SymbolicDim if it is not an int."""
+ if dim is None or isinstance(dim, str):
+ return SymbolicDim(dim)
+ if _is_int_compatible(dim):
+ return int(dim)
+ if isinstance(dim, SymbolicDim):
+ return dim
+ raise TypeError(
+ f"Expected int, str, None or SymbolicDim, but value {dim!r} has type '{type(dim)}'"
+ )
+
+
  class Shape(_protocols.ShapeProtocol, _display.PrettyPrintable):
+ """The shape of a tensor, including its dimensions and optional denotations.
+
+ The :class:`Shape` stores the dimensions of a tensor, which can be integers, None (unknown), or
+ symbolic dimensions.
+
+ A shape can be compared to another shape or plain Python list.
+
+ A shape can be frozen (made immutable). When the shape is frozen, it cannot be
+ unfrozen, making it suitable to be shared across tensors or values.
+ Call :method:`freeze` to freeze the shape.
+
+ To update the dimension of a frozen shape, call :method:`copy` to create a
+ new shape with the same dimensions that can be modified.
+
+ Use :method:`get_denotation` and :method:`set_denotation` to access and modify the denotations.
+
+ Example::
+
+ >>> import onnx_ir as ir
+ >>> shape = ir.Shape(["B", None, 3])
+ >>> shape.rank()
+ 3
+ >>> shape.is_static()
+ False
+ >>> shape.is_dynamic()
+ True
+ >>> shape.is_static(dim=2)
+ True
+ >>> shape[0] = 1
+ >>> shape[1] = 2
+ >>> shape.dims
+ (1, 2, 3)
+ >>> shape == [1, 2, 3]
+ True
+ >>> shape.frozen
+ False
+ >>> shape.freeze()
+ >>> shape.frozen
+ True
+
+ Attributes:
+ dims: A tuple of dimensions representing the shape.
+ Each dimension can be an integer, None or a :class:`SymbolicDim`.
+ frozen: Indicates whether the shape is immutable. When frozen, the shape
+ cannot be modified or unfrozen.
+ """
+
  __slots__ = ("_dims", "_frozen")
 
  def __init__(
  self,
- dims: Iterable[int | SymbolicDim | str | None],
+ dims: Iterable[int | SupportsInt | SymbolicDim | str | None],
  /,
  denotations: Iterable[str | None] | None = None,
  frozen: bool = False,
@@ -864,11 +1093,11 @@ class Shape(_protocols.ShapeProtocol, _display.PrettyPrintable):
  Refer to https://github.com/onnx/onnx/blob/main/docs/DimensionDenotation.md#denotation-definition
  for pre-defined dimension denotations.
  frozen: If True, the shape is immutable and cannot be modified. This
- is useful when the shape is initialized by a Tensor.
+ is useful when the shape is initialized by a Tensor or when the shape
+ is shared across multiple tensors. The default is False.
  """
  self._dims: list[int | SymbolicDim] = [
- SymbolicDim(dim) if not isinstance(dim, (int, SymbolicDim)) else dim
- for dim in dims
+ _maybe_convert_to_symbolic_dim(dim) for dim in dims
  ]
  self._denotations: list[str | None] = (
  list(denotations) if denotations is not None else [None] * len(self._dims)
@@ -879,10 +1108,6 @@ class Shape(_protocols.ShapeProtocol, _display.PrettyPrintable):
  )
  self._frozen: bool = frozen
 
- def copy(self):
- """Return a copy of the shape."""
- return Shape(self._dims, self._denotations, self._frozen)
-
  @property
  def dims(self) -> tuple[int | SymbolicDim, ...]:
  """All dimensions in the shape.
@@ -891,8 +1116,29 @@ class Shape(_protocols.ShapeProtocol, _display.PrettyPrintable):
  """
  return tuple(self._dims)
 
+ @property
+ def frozen(self) -> bool:
+ """Whether the shape is frozen.
+
+ When the shape is frozen, it cannot be unfrozen, making it suitable to be shared.
+ Call :method:`freeze` to freeze the shape. Call :method:`copy` to create a
+ new shape with the same dimensions that can be modified.
+ """
+ return self._frozen
+
+ def freeze(self) -> None:
+ """Freeze the shape.
+
+ When the shape is frozen, it cannot be unfrozen, making it suitable to be shared.
+ """
+ self._frozen = True
+
+ def copy(self, frozen: bool = False):
+ """Return a copy of the shape."""
+ return Shape(self._dims, self._denotations, frozen=frozen)
+
  def rank(self) -> int:
- """The rank of the shape."""
+ """The rank of the tensor this shape represents."""
  return len(self._dims)
 
  def numpy(self) -> tuple[int, ...]:
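Note that `copy()` now returns an unfrozen shape by default, where the removed implementation carried the source's frozen flag over. A minimal sketch (not part of the diff):

    >>> frozen = ir.Shape([1, 2, 3], frozen=True)
    >>> editable = frozen.copy()   # unfrozen by default in 0.1.1
    >>> editable[0] = "batch"      # would raise TypeError on the frozen original
    >>> (editable.frozen, frozen.frozen)
    (False, True)
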
@@ -928,12 +1174,8 @@ class Shape(_protocols.ShapeProtocol, _display.PrettyPrintable):
  """
  if self._frozen:
  raise TypeError("The shape is frozen and cannot be modified.")
- if isinstance(value, str) or value is None:
- value = SymbolicDim(value)
- if not isinstance(value, (int, SymbolicDim)):
- raise TypeError(f"Expected int, str, None or SymbolicDim, got '{type(value)}'")
 
- self._dims[index] = value
+ self._dims[index] = _maybe_convert_to_symbolic_dim(value)
 
  def get_denotation(self, index: int) -> str | None:
  """Return the denotation of the dimension at the index.
@@ -968,7 +1210,7 @@ class Shape(_protocols.ShapeProtocol, _display.PrettyPrintable):
  def __eq__(self, other: object) -> bool:
  """Return True if the shapes are equal.
 
- Two shapes are eqaul if all their dimensions are equal.
+ Two shapes are equal if all their dimensions are equal.
  """
  if isinstance(other, Shape):
  return self._dims == other._dims
@@ -979,6 +1221,33 @@ class Shape(_protocols.ShapeProtocol, _display.PrettyPrintable):
  def __ne__(self, other: object) -> bool:
  return not self.__eq__(other)
 
+ @typing.overload
+ def is_static(self, dim: int) -> bool: # noqa: D418
+ """Return True if the dimension is static."""
+
+ @typing.overload
+ def is_static(self) -> bool: # noqa: D418
+ """Return True if all dimensions are static."""
+
+ def is_static(self, dim=None) -> bool:
+ """Return True if the dimension is static. If dim is None, return True if all dimensions are static."""
+ if dim is None:
+ return all(isinstance(dim, int) for dim in self._dims)
+ return isinstance(self[dim], int)
+
+ @typing.overload
+ def is_dynamic(self, dim: int) -> bool: # noqa: D418
+ """Return True if the dimension is dynamic."""
+
+ @typing.overload
+ def is_dynamic(self) -> bool: # noqa: D418
+ """Return True if any dimension is dynamic."""
+
+ def is_dynamic(self, dim=None) -> bool:
+ if dim is None:
+ return not self.is_static()
+ return not self.is_static(dim)
+
 
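With `_maybe_convert_to_symbolic_dim` in place, `Shape` accepts anything int-like alongside strings and None. A minimal sketch (not part of the diff):

    >>> import numpy as np
    >>> shape = ir.Shape([np.int64(2), "N", None])   # SupportsInt values go through int()
    >>> shape[0], shape.rank()
    (2, 3)
    >>> shape.is_dynamic(1)                          # "N" became a SymbolicDim
    True
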
  def _quoted(string: str) -> str:
  """Return a quoted string.
@@ -988,6 +1257,35 @@ def _quoted(string: str) -> str:
  return f'"{string}"'
 
 
+ class Usage(NamedTuple):
+ """A usage of a value in a node.
+
+ Attributes:
+ node: The node that uses the value.
+ idx: The input index of the value in the node.
+ """
+
+ node: Node
+ idx: int
+
+
+ def _short_tensor_str_for_node(x: Value) -> str:
+ if x.const_value is None:
+ return ""
+ if x.const_value.size <= 10:
+ try:
+ data = x.const_value.numpy().tolist()
+ except Exception: # pylint: disable=broad-except
+ return "{...}"
+ return f"{{{data}}}"
+ return "{...}"
+
+
+ def _normalize_domain(domain: str) -> str:
+ """Normalize 'ai.onnx' to ''."""
+ return "" if domain == "ai.onnx" else domain
+
+
  class Node(_protocols.NodeProtocol, _display.PrettyPrintable):
  """IR Node.
 
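A minimal sketch (not part of the diff) of the domain normalization performed by `_normalize_domain`, assuming `a` and `b` are existing ir.Values:

    >>> node = ir.Node("ai.onnx", "Add", inputs=[a, b])
    >>> node.domain          # "ai.onnx" is normalized to the empty string
    ''
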
@@ -1001,6 +1299,9 @@ class Node(_protocols.NodeProtocol, _display.PrettyPrintable):
  To change the output values, create a new node and replace the each of the inputs of ``output.uses()`` with
  the new output values by calling :meth:`replace_input_with` on the using nodes
  of this node's outputs.
+
+ .. note:
+ When the ``domain`` is `"ai.onnx"`, it is normalized to `""`.
  """
 
  __slots__ = (
@@ -1023,13 +1324,13 @@ class Node(_protocols.NodeProtocol, _display.PrettyPrintable):
  domain: str,
  op_type: str,
  inputs: Iterable[Value | None],
- attributes: Iterable[Attr | RefAttr] = (),
+ attributes: Iterable[Attr] | Mapping[str, Attr] = (),
  *,
  overload: str = "",
  num_outputs: int | None = None,
  outputs: Sequence[Value] | None = None,
  version: int | None = None,
- graph: Graph | None = None,
+ graph: Graph | Function | None = None,
  name: str | None = None,
  doc_string: str | None = None,
  metadata_props: dict[str, str] | None = None,
@@ -1038,27 +1339,30 @@ class Node(_protocols.NodeProtocol, _display.PrettyPrintable):
 
  Args:
  domain: The domain of the operator. For onnx operators, this is an empty string.
+ When it is `"ai.onnx"`, it is normalized to `""`.
  op_type: The name of the operator.
- inputs: The input values. When an input is None, it is an empty input.
+ inputs: The input values. When an input is ``None``, it is an empty input.
  attributes: The attributes. RefAttr can be used only when the node is defined in a Function.
  overload: The overload name when the node is invoking a function.
  num_outputs: The number of outputs of the node. If not specified, the number is 1.
- outputs: The output values. If None, the outputs are created during initialization.
- version: The version of the operator. If None, the version is unspecified and will follow that of the graph.
- graph: The graph that the node belongs to. If None, the node is not added to any graph.
- A `Node` must belong to zero or one graph.
- name: The name of the node. If None, the node is anonymous.
+ outputs: The output values. If ``None``, the outputs are created during initialization.
+ version: The version of the operator. If ``None``, the version is unspecified and will follow that of the graph.
+ graph: The graph that the node belongs to. If ``None``, the node is not added to any graph.
+ A `Node` must belong to zero or one graph. If a :class:`Function`, the underlying graph
+ of the function is assigned to the node.
+ name: The name of the node. If ``None``, the node is anonymous. The name may be
+ set by a :class:`Graph` if ``graph`` is specified.
  doc_string: The documentation string.
  metadata_props: The metadata properties.
 
  Raises:
- TypeError: If the attributes are not Attr or RefAttr.
- ValueError: If `num_outputs`, when not None, is not the same as the length of the outputs.
- ValueError: If an output value is None, when outputs is specified.
+ TypeError: If the attributes are not :class:`Attr`.
+ ValueError: If ``num_outputs``, when not ``None``, is not the same as the length of the outputs.
+ ValueError: If an output value is ``None``, when outputs is specified.
  ValueError: If an output value has a producer set already, when outputs is specified.
  """
  self._name = name
- self._domain: str = domain
+ self._domain: str = _normalize_domain(domain)
  self._op_type: str = op_type
  # NOTE: Make inputs immutable with the assumption that they are not mutated
  # very often. This way all mutations can be tracked.
@@ -1066,22 +1370,21 @@ class Node(_protocols.NodeProtocol, _display.PrettyPrintable):
  self._inputs: tuple[Value | None, ...] = tuple(inputs)
  # Values belong to their defining nodes. The values list is immutable
  self._outputs: tuple[Value, ...] = self._create_outputs(num_outputs, outputs)
- attributes = tuple(attributes)
- if attributes and not isinstance(attributes[0], (Attr, RefAttr)):
- raise TypeError(
- f"Expected the attributes to be Attr or RefAttr, got {type(attributes[0])}. "
- "If you are copying the attributes from another node, make sure you call "
- "node.attributes.values() because it is a dictionary."
- )
- self._attributes: OrderedDict[str, Attr | RefAttr] = OrderedDict(
- (attr.name, attr) for attr in attributes
+ if isinstance(attributes, Mapping):
+ attributes = tuple(attributes.values())
+ self._attributes: _graph_containers.Attributes = _graph_containers.Attributes(
+ attributes
  )
  self._overload: str = overload
  # TODO(justinchuby): Potentially support a version range
  self._version: int | None = version
  self._metadata: _metadata.MetadataStore | None = None
  self._metadata_props: dict[str, str] | None = metadata_props
- self._graph: Graph | None = graph
+ # _graph is set by graph.append
+ self._graph: Graph | None = None
+ # Add the node to the graph if graph is specified
+ if graph is not None:
+ graph.append(self)
  self.doc_string = doc_string
 
  # Add the node as a use of the inputs
@@ -1089,10 +1392,6 @@ class Node(_protocols.NodeProtocol, _display.PrettyPrintable):
  if input_value is not None:
  input_value._add_usage(self, i) # pylint: disable=protected-access
 
- # Add the node to the graph if graph is specified
- if self._graph is not None:
- self._graph.append(self)
-
  def _create_outputs(
  self, num_outputs: int | None, outputs: Sequence[Value] | None
  ) -> tuple[Value, ...]:
@@ -1150,7 +1449,7 @@ class Node(_protocols.NodeProtocol, _display.PrettyPrintable):
  + ", ".join(
  [
  (
- f"%{_quoted(x.name) if x.name else 'anonymous:' + str(id(x))}"
+ f"%{_quoted(x.name) if x.name else 'anonymous:' + str(id(x))}{_short_tensor_str_for_node(x)}"
  if x is not None
  else "None"
  )
@@ -1178,6 +1477,7 @@ class Node(_protocols.NodeProtocol, _display.PrettyPrintable):
 
  @property
  def name(self) -> str | None:
+ """Optional name of the node."""
  return self._name
 
  @name.setter
@@ -1186,14 +1486,26 @@ class Node(_protocols.NodeProtocol, _display.PrettyPrintable):
 
  @property
  def domain(self) -> str:
+ """The domain of the operator. For onnx operators, this is an empty string.
+
+ .. note:
+ When domain is `"ai.onnx"`, it is normalized to `""`.
+ """
  return self._domain
 
  @domain.setter
  def domain(self, value: str) -> None:
- self._domain = value
+ self._domain = _normalize_domain(value)
 
  @property
  def version(self) -> int | None:
+ """Opset version of the operator called.
+
+ If ``None``, the version is unspecified and will follow that of the graph.
+ This property is special to ONNX IR to allow mixed opset usage in a graph
+ for supporting more flexible graph transformations. It does not exist in the ONNX
+ serialization (protobuf) spec.
+ """
  return self._version
 
  @version.setter
@@ -1202,6 +1514,7 @@ class Node(_protocols.NodeProtocol, _display.PrettyPrintable):
 
  @property
  def op_type(self) -> str:
+ """The name of the operator called."""
  return self._op_type
 
  @op_type.setter
@@ -1210,6 +1523,7 @@ class Node(_protocols.NodeProtocol, _display.PrettyPrintable):
 
  @property
  def overload(self) -> str:
+ """The overload name when the node is invoking a function."""
  return self._overload
 
  @overload.setter
@@ -1218,6 +1532,12 @@ class Node(_protocols.NodeProtocol, _display.PrettyPrintable):
 
  @property
  def inputs(self) -> Sequence[Value | None]:
+ """The input values of the node.
+
+ The inputs are immutable. To change the inputs, create a new node and
+ replace the inputs of the using nodes of this node's outputs by calling
+ :meth:`replace_input_with` on the using nodes of this node's outputs.
+ """
  return self._inputs
 
  @inputs.setter
@@ -1226,6 +1546,25 @@ class Node(_protocols.NodeProtocol, _display.PrettyPrintable):
  "Directly mutating the input sequence is unsupported. Please use Node.replace_input_with() instead."
  )
 
+ def predecessors(self) -> Sequence[Node]:
+ """Return the predecessor nodes of the node, deduplicated, in a deterministic order."""
+ # Use the ordered nature of a dictionary to deduplicate the nodes
+ predecessors: dict[Node, None] = {}
+ for value in self.inputs:
+ if value is not None and (producer := value.producer()) is not None:
+ predecessors[producer] = None
+ return tuple(predecessors)
+
+ def successors(self) -> Sequence[Node]:
+ """Return the successor nodes of the node, deduplicated, in a deterministic order."""
+ # Use the ordered nature of a dictionary to deduplicate the nodes
+ successors: dict[Node, None] = {}
+ for value in self.outputs:
+ assert value is not None, "Bug: Output values are not expected to be None"
+ for usage in value.uses():
+ successors[usage.node] = None
+ return tuple(successors)
+
  def replace_input_with(self, index: int, value: Value | None) -> None:
  """Replace an input with a new value."""
  if index < 0 or index >= len(self.inputs):
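A minimal sketch (not part of the diff) of the new graph-navigation helpers, assuming `add` and `mul` are existing nodes where `mul` consumes `add`'s output:

    >>> add.successors() == (mul,)
    True
    >>> mul.predecessors() == (add,)
    True
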
@@ -1279,6 +1618,12 @@ class Node(_protocols.NodeProtocol, _display.PrettyPrintable):
 
  @property
  def outputs(self) -> Sequence[Value]:
+ """The output values of the node.
+
+ The outputs are immutable. To change the outputs, create a new node and
+ replace the inputs of the using nodes of this node's outputs by calling
+ :meth:`replace_input_with` on the using nodes of this node's outputs.
+ """
  return self._outputs
 
  @outputs.setter
@@ -1286,7 +1631,8 @@ class Node(_protocols.NodeProtocol, _display.PrettyPrintable):
  raise AttributeError("outputs is immutable. Please create a new node instead.")
 
  @property
- def attributes(self) -> OrderedDict[str, Attr | RefAttr]:
+ def attributes(self) -> _graph_containers.Attributes:
+ """The attributes of the node."""
  return self._attributes
 
  @property
@@ -1302,12 +1648,21 @@ class Node(_protocols.NodeProtocol, _display.PrettyPrintable):
 
  @property
  def metadata_props(self) -> dict[str, str]:
+ """The metadata properties of the node.
+
+ The metadata properties are used to store additional information about the node.
+ Unlike ``meta``, this property is serialized to the ONNX proto.
+ """
  if self._metadata_props is None:
  self._metadata_props = {}
  return self._metadata_props
 
  @property
  def graph(self) -> Graph | None:
+ """The graph that the node belongs to.
+
+ If the node is not added to any graph, this property is None.
+ """
  return self._graph
 
  @graph.setter
@@ -1315,9 +1670,17 @@ class Node(_protocols.NodeProtocol, _display.PrettyPrintable):
  self._graph = value
 
  def op_identifier(self) -> _protocols.OperatorIdentifier:
+ """Return the operator identifier of the node.
+
+ The operator identifier is a tuple of the domain, op_type and overload.
+ """
  return self.domain, self.op_type, self.overload
 
  def display(self, *, page: bool = False) -> None:
+ """Pretty print the node.
+
+ This method is used for debugging and visualization purposes.
+ """
  # Add the node's name to the displayed text
  print(f"Node: {self.name!r}")
  if self.doc_string:
@@ -1344,7 +1707,7 @@ class _TensorTypeBase(_protocols.TypeProtocol, _display.PrettyPrintable, Hashab
 
  @property
  def elem_type(self) -> _enums.DataType:
- """Return the element type of the tensor type"""
+ """Return the element type of the tensor type."""
  return self.dtype
 
  def __hash__(self) -> int:
@@ -1438,18 +1801,19 @@ class Value(_protocols.ValueProtocol, _display.PrettyPrintable):
 
  To find all the nodes that use this value as an input, call :meth:`uses`.
 
- To check if the value is an output of a graph, call :meth:`is_graph_output`.
+ To check if the value is an is an input, output or initializer of a graph,
+ use :meth:`is_graph_input`, :meth:`is_graph_output` or :meth:`is_initializer`.
 
- Attributes:
- name: The name of the value. A value is always named when it is part of a graph.
- shape: The shape of the value.
- type: The type of the value.
- metadata_props: Metadata.
+ Use :meth:`graph` to get the graph that owns the value.
  """
 
  __slots__ = (
  "_const_value",
+ "_graph",
  "_index",
+ "_is_graph_input",
+ "_is_graph_output",
+ "_is_initializer",
  "_metadata",
  "_metadata_props",
  "_name",
@@ -1497,19 +1861,33 @@ class Value(_protocols.ValueProtocol, _display.PrettyPrintable):
  # Use a collection of (Node, int) to store uses. This is needed
  # because a single use can use the same value multiple times.
  # Use a dictionary to preserve insertion order so that the visiting order is deterministic
- self._uses: dict[tuple[Node, int], None] = {}
+ self._uses: dict[Usage, None] = {}
  self.doc_string = doc_string
 
+ # The graph this value belongs to. It is set *only* when the value is added as
+ # a graph input, output or initializer.
+ # The four properties can only be set by the Graph class (_GraphIO and GraphInitializers).
+ self._graph: Graph | None = None
+ self._is_graph_input: bool = False
+ self._is_graph_output: bool = False
+ self._is_initializer: bool = False
+
  def __repr__(self) -> str:
  value_name = self.name if self.name else "anonymous:" + str(id(self))
+ type_text = f", type={self.type!r}" if self.type is not None else ""
+ shape_text = f", shape={self.shape!r}" if self.shape is not None else ""
  producer = self.producer()
  if producer is None:
- producer_text = "None"
+ producer_text = ""
  elif producer.name is not None:
- producer_text = producer.name
+ producer_text = f", producer='{producer.name}'"
  else:
- producer_text = f"anonymous_node:{id(producer)}"
- return f"{self.__class__.__name__}({value_name!r}, type={self.type!r}, shape={self.shape}, producer={producer_text}, index={self.index()})"
+ producer_text = f", producer=anonymous_node:{id(producer)}"
+ index_text = f", index={self.index()}" if self.index() is not None else ""
+ const_value_text = self._constant_tensor_part()
+ if const_value_text:
+ const_value_text = f", const_value={const_value_text}"
+ return f"{self.__class__.__name__}(name={value_name!r}{type_text}{shape_text}{producer_text}{index_text}{const_value_text})"
 
  def __str__(self) -> str:
  value_name = self.name if self.name is not None else "anonymous:" + str(id(self))
@@ -1518,41 +1896,85 @@ class Value(_protocols.ValueProtocol, _display.PrettyPrintable):
 
  # Quote the name because in reality the names can have invalid characters
  # that make them hard to read
- return f"%{_quoted(value_name)}<{type_text},{shape_text}>"
+ return (
+ f"%{_quoted(value_name)}<{type_text},{shape_text}>{self._constant_tensor_part()}"
+ )
+
+ def _constant_tensor_part(self) -> str:
+ """Display string for the constant tensor attached to str of Value."""
+ if self.const_value is not None:
+ # Only display when the const value is small
+ if self.const_value.size <= 10:
+ return f"{{{self.const_value}}}"
+ else:
+ return f"{{{self.const_value.__class__.__name__}(...)}}"
+ return ""
+
+ @property
+ def graph(self) -> Graph | None:
+ """Return the graph that defines this value.
+
+ When the value is an input/output/initializer of a graph, the owning graph
+ is that graph. When the value is an output of a node, the owning graph is the
+ graph that the node belongs to. When the value is not owned by any graph,
+ it returns ``None``.
+ """
+ if self._graph is not None:
+ return self._graph
+ if self._producer is not None:
+ return self._producer.graph
+ return None
+
+ def _owned_by_graph(self) -> bool:
+ """Return True if the value is owned by a graph."""
+ result = self._is_graph_input or self._is_graph_output or self._is_initializer
+ if result:
+ assert self._graph is not None
+ return result
 
  def producer(self) -> Node | None:
  """The node that produces this value.
 
  When producer is ``None``, the value does not belong to a node, and is
- typically a graph input or an initializer.
+ typically a graph input or an initializer. You can use :meth:`graph``
+ to find the graph that owns this value. Use :meth:`is_graph_input`, :meth:`is_graph_output`
+ or :meth:`is_initializer` to check if the value is an input, output or initializer of a graph.
  """
  return self._producer
 
+ def consumers(self) -> Sequence[Node]:
+ """Return the nodes (deduplicated) that consume this value."""
+ return tuple({usage.node: None for usage in self._uses})
+
  def index(self) -> int | None:
  """The index of the output of the defining node."""
  return self._index
 
- def uses(self) -> Collection[tuple[Node, int]]:
+ def uses(self) -> Collection[Usage]:
  """Return a set of uses of the value.
 
  The set contains tuples of ``(Node, index)`` where the index is the index of the input
  of the node. For example, if ``node.inputs[1] == value``, then the use is ``(node, 1)``.
  """
- return self._uses.keys()
+ # Create a tuple for the collection so that iteration on will will not
+ # be affected when the usage changes during graph mutation.
+ # This adds a small overhead but is better a user experience than
+ # having users call tuple().
+ return tuple(self._uses)
 
  def _add_usage(self, use: Node, index: int) -> None:
  """Add a usage of this value.
 
  This is an internal method. It should only be called by the Node class.
  """
- self._uses[(use, index)] = None
+ self._uses[Usage(use, index)] = None
 
  def _remove_usage(self, use: Node, index: int) -> None:
  """Remove a node from the uses of this value.
 
  This is an internal method. It should only be called by the Node class.
  """
- self._uses.pop((use, index))
+ self._uses.pop(Usage(use, index))
 
  @property
  def name(self) -> str | None:
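Because `uses()` now yields `Usage` named tuples, call sites can use attribute access instead of positional unpacking. A minimal sketch (not part of the diff), assuming `value` is an existing ir.Value:

    >>> for usage in value.uses():
    ...     print(usage.node.op_type, usage.idx)   # previously unpacked as (node, idx)
    >>> distinct_nodes = value.consumers()          # same nodes, deduplicated
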
@@ -1652,15 +2074,17 @@ class Value(_protocols.ValueProtocol, _display.PrettyPrintable):
  self._metadata_props = {}
  return self._metadata_props
 
+ def is_graph_input(self) -> bool:
+ """Whether the value is an input of a graph."""
+ return self._is_graph_input
+
  def is_graph_output(self) -> bool:
  """Whether the value is an output of a graph."""
- if (producer := self.producer()) is None:
- return False
- if (graph := producer.graph) is None:
- return False
- # Cannot use `in` because __eq__ may be defined by subclasses, even though
- # it is not recommended
- return any(output is self for output in graph.outputs)
+ return self._is_graph_output
+
+ def is_initializer(self) -> bool:
+ """Whether the value is an initializer of a graph."""
+ return self._is_initializer
 
 
  def Input(
@@ -1673,7 +2097,6 @@ def Input(
 
  This is equivalent to calling ``Value(name=name, shape=shape, type=type, doc_string=doc_string)``.
  """
-
  # NOTE: The function name is capitalized to maintain API backward compatibility.
 
  return Value(name=name, shape=shape, type=type, doc_string=doc_string)
@@ -1770,19 +2193,11 @@ class Graph(_protocols.GraphProtocol, Sequence[Node], _display.PrettyPrintable):
  self.name = name
 
  # Private fields that are not to be accessed by any other classes
- self._inputs = list(inputs)
- self._outputs = list(outputs)
- self._initializers = {}
- for initializer in initializers:
- if isinstance(initializer, str):
- raise TypeError(
- "Initializer must be a Value, not a string. "
- "If you are copying the initializers from another graph, "
- "make sure you call graph.initializers.values() because it is a dictionary."
- )
- if initializer.name is None:
- raise ValueError(f"Initializer must have a name: {initializer}")
- self._initializers[initializer.name] = initializer
+ self._inputs = _graph_containers.GraphInputs(self, inputs)
+ self._outputs = _graph_containers.GraphOutputs(self, outputs)
+ self._initializers = _graph_containers.GraphInitializers(
+ self, {initializer.name: initializer for initializer in initializers}
+ )
  self._doc_string = doc_string
  self._opset_imports = opset_imports or {}
  self._metadata: _metadata.MetadataStore | None = None
@@ -1791,21 +2206,67 @@ class Graph(_protocols.GraphProtocol, Sequence[Node], _display.PrettyPrintable):
1791
2206
  # Be sure the initialize the name authority before extending the nodes
1792
2207
  # because it is used to name the nodes and their outputs
1793
2208
  self._name_authority = _name_authority.NameAuthority()
2209
+ # TODO(justinchuby): Trigger again if inputs or initializers are modified.
2210
+ self._set_input_and_initializer_value_names_into_name_authority()
1794
2211
  # Call self.extend not self._nodes.extend so the graph reference is added to the nodes
1795
2212
  self.extend(nodes)
1796
2213
 
1797
2214
  @property
1798
- def inputs(self) -> list[Value]:
2215
+ def inputs(self) -> MutableSequence[Value]:
1799
2216
  return self._inputs
1800
2217
 
1801
2218
  @property
1802
- def outputs(self) -> list[Value]:
2219
+ def outputs(self) -> MutableSequence[Value]:
1803
2220
  return self._outputs
1804
2221
 
1805
2222
  @property
1806
- def initializers(self) -> dict[str, Value]:
2223
+ def initializers(self) -> _graph_containers.GraphInitializers:
2224
+ """The initializers of the graph as a ``MutableMapping[str, Value]``.
2225
+
2226
+ The keys are the names of the initializers. The values are the :class:`Value` objects.
2227
+
2228
+ This property additionally supports the ``add`` method, which takes a :class:`Value`
2229
+ and adds it to the initializers if it is not already present.
2230
+
2231
+ .. note::
2232
+ When setting an initializer with ``graph.initializers[key] = value``,
2233
+ if the value does not have a name, it will be assigned ``key`` as its name.
2234
+
2235
+ """
1807
2236
  return self._initializers
1808
2237
 
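An illustrative sketch of the documented mapping semantics, on a hypothetical graph built from scratch:

import onnx_ir as ir

g = ir.Graph(inputs=[], outputs=[], nodes=[], name="demo")
w = ir.Value(const_value=ir.tensor([1.0, 2.0], dtype=ir.DataType.FLOAT))
g.initializers["w"] = w
assert w.name == "w"          # the key was assigned as the value's name
assert w.is_initializer()
g.initializers.add(w)         # no-op: already registered under "w"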
2238
+ def register_initializer(self, value: Value) -> None:
2239
+ """Register an initializer to the graph.
2240
+
2241
+ This is a convenience method to register an initializer to the graph with
2242
+ checks.
2243
+
2244
+ Args:
2245
+ value: The :class:`Value` to register as an initializer of the graph.
2246
+ It must have its ``.const_value`` set.
2247
+
2248
+ Raises:
2249
+ ValueError: If a value of the same name that is not this value
2250
+ is already registered.
2251
+ ValueError: If the value does not have a name.
2252
+ ValueError: If the initializer is produced by a node.
2253
+ ValueError: If the value does not have its ``.const_value`` set.
2254
+ """
2255
+ if not value.name:
2256
+ raise ValueError(f"Initializer must have a name: {value!r}")
2257
+ if value.name in self._initializers:
2258
+ if self._initializers[value.name] is not value:
2259
+ raise ValueError(
2260
+ f"Initializer '{value.name}' is already registered, but"
2261
+ f" it is not the same object: existing={self._initializers[value.name]!r},"
2262
+ f" new={value!r}"
2263
+ )
2264
+ if value.const_value is None:
2265
+ raise ValueError(
2266
+ f"Value '{value!r}' must have its const_value set to be an initializer."
2267
+ )
2268
+ self._initializers.add(value)
2269
+
1809
2270
  @property
1810
2271
  def doc_string(self) -> str | None:
1811
2272
  return self._doc_string
@@ -1818,7 +2279,12 @@ class Graph(_protocols.GraphProtocol, Sequence[Node], _display.PrettyPrintable):
1818
2279
  def opset_imports(self) -> dict[str, int]:
1819
2280
  return self._opset_imports
1820
2281
 
1821
- def __getitem__(self, index: int) -> Node:
2282
+ @typing.overload
2283
+ def __getitem__(self, index: int) -> Node: ...
2284
+ @typing.overload
2285
+ def __getitem__(self, index: slice) -> Sequence[Node]: ...
2286
+
2287
+ def __getitem__(self, index):
1822
2288
  return self._nodes[index]
1823
2289
 
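With the new overloads, a slice returns a node sequence; a small sketch against a hypothetical `graph`:

first = graph[0]       # Node, as before
middle = graph[1:3]    # Sequence[Node], new with the slice overload
tail = graph[-2:]      # negative indices behave like list slicing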
1824
2290
  def __len__(self) -> int:
@@ -1830,6 +2296,12 @@ class Graph(_protocols.GraphProtocol, Sequence[Node], _display.PrettyPrintable):
1830
2296
  def __reversed__(self) -> Iterator[Node]:
1831
2297
  return reversed(self._nodes)
1832
2298
 
2299
+ def _set_input_and_initializer_value_names_into_name_authority(self):
2300
+ for value in self.inputs:
2301
+ self._name_authority.register_or_name_value(value)
2302
+ for value in self.initializers.values():
2303
+ self._name_authority.register_or_name_value(value)
2304
+
1833
2305
  def _set_node_graph_to_self_and_assign_names(self, node: Node) -> Node:
1834
2306
  """Set the graph reference for the node and assign names to it and its outputs if they don't have one."""
1835
2307
  if node.graph is not None and node.graph is not self:
@@ -1993,7 +2465,7 @@ class Graph(_protocols.GraphProtocol, Sequence[Node], _display.PrettyPrintable):
1993
2465
 
1994
2466
  This sort is stable. It preserves the original order as much as possible.
1995
2467
 
1996
- Referece: https://github.com/madelson/MedallionTopologicalSort#stable-sort
2468
+ Reference: https://github.com/madelson/MedallionTopologicalSort#stable-sort
1997
2469
 
1998
2470
  Raises:
1999
2471
  ValueError: If the graph contains a cycle, making topological sorting impossible.
@@ -2004,7 +2476,7 @@ class Graph(_protocols.GraphProtocol, Sequence[Node], _display.PrettyPrintable):
2004
2476
  sorted_nodes_by_graph: dict[Graph, list[Node]] = {
2005
2477
  graph: [] for graph in {node.graph for node in nodes if node.graph is not None}
2006
2478
  }
2007
- # TODO: Explain why we need to store direct predecessors and children and why
2479
+ # TODO(justinchuby): Explain why we need to store direct predecessors and children and why
2008
2480
  # we only need to store the direct ones
2009
2481
 
2010
2482
  # The depth of a node is defined as the number of direct children it has
@@ -2024,7 +2496,7 @@ class Graph(_protocols.GraphProtocol, Sequence[Node], _display.PrettyPrintable):
2024
2496
  node_depth[predecessor] += 1
2025
2497
 
2026
2498
  # 1. Build the direct predecessors of each node and the depth of each node
2027
- # for sorting topolocally using Kahn's algorithm.
2499
+ # for sorting topologically using Kahn's algorithm.
2028
2500
  # Note that when a node contains graph attributes (aka. has subgraphs),
2029
2501
  # we consider all nodes in the subgraphs *predecessors* of this node. This
2030
2502
  # way we ensure the implicit dependencies of the subgraphs are captured
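For readers unfamiliar with the approach referenced in the comments above, here is a generic, self-contained sketch of a stable Kahn's algorithm. It is simplified relative to the pass in this diff: no subgraph handling, and `predecessors` is assumed precomputed.

import heapq

def stable_topological_order(nodes, predecessors):
    """Order `nodes` so predecessors come first; ties broken by original position."""
    position = {node: i for i, node in enumerate(nodes)}
    indegree = {node: len(predecessors[node]) for node in nodes}
    successors = {node: [] for node in nodes}
    for node in nodes:
        for pred in predecessors[node]:
            successors[pred].append(node)
    # A min-heap keyed on original position keeps the sort stable: among all
    # ready nodes, the one that appeared first in `nodes` is emitted first.
    ready = [position[node] for node in nodes if indegree[node] == 0]
    heapq.heapify(ready)
    order = []
    while ready:
        node = nodes[heapq.heappop(ready)]
        order.append(node)
        for succ in successors[node]:
            indegree[succ] -= 1
            if indegree[succ] == 0:
                heapq.heappush(ready, position[succ])
    if len(order) != len(nodes):
        raise ValueError("graph contains a cycle")
    return order

# e.g. stable_topological_order(["a", "b", "c"], {"a": [], "b": ["c"], "c": []})
# returns ["a", "c", "b"]: "c" must precede "b", everything else keeps its order.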
@@ -2125,11 +2597,11 @@ def _graph_str(graph: Graph | GraphView) -> str:
2125
2597
  )
2126
2598
  signature = f"""\
2127
2599
  graph(
2128
- name={graph.name or 'anonymous_graph:' + str(id(graph))},
2129
- inputs=({textwrap.indent(inputs_text, ' ' * 8)}
2600
+ name={graph.name or "anonymous_graph:" + str(id(graph))},
2601
+ inputs=({textwrap.indent(inputs_text, " " * 8)}
2130
2602
  ),
2131
- outputs=({textwrap.indent(outputs_text, ' ' * 8)}
2132
- ),{textwrap.indent(initializers_text, ' ' * 4)}
2603
+ outputs=({textwrap.indent(outputs_text, " " * 8)}
2604
+ ),{textwrap.indent(initializers_text, " " * 4)}
2133
2605
  )"""
2134
2606
  node_count = len(graph)
2135
2607
  number_width = len(str(node_count))
@@ -2163,11 +2635,11 @@ def _graph_repr(graph: Graph | GraphView) -> str:
2163
2635
  )
2164
2636
  return f"""\
2165
2637
  {graph.__class__.__name__}(
2166
- name={graph.name or 'anonymous_graph:' + str(id(graph))!r},
2167
- inputs=({textwrap.indent(inputs_text, ' ' * 8)}
2638
+ name={graph.name or "anonymous_graph:" + str(id(graph))!r},
2639
+ inputs=({textwrap.indent(inputs_text, " " * 8)}
2168
2640
  ),
2169
- outputs=({textwrap.indent(outputs_text, ' ' * 8)}
2170
- ),{textwrap.indent(initializers_text, ' ' * 4)}
2641
+ outputs=({textwrap.indent(outputs_text, " " * 8)}
2642
+ ),{textwrap.indent(initializers_text, " " * 4)}
2171
2643
  len()={len(graph)}
2172
2644
  )"""
2173
2645
 
@@ -2223,7 +2695,7 @@ class GraphView(Sequence[Node], _display.PrettyPrintable):
2223
2695
  outputs: Sequence[Value],
2224
2696
  *,
2225
2697
  nodes: Iterable[Node],
2226
- initializers: Sequence[_protocols.ValueProtocol] = (),
2698
+ initializers: Sequence[Value] = (),
2227
2699
  doc_string: str | None = None,
2228
2700
  opset_imports: dict[str, int] | None = None,
2229
2701
  name: str | None = None,
@@ -2232,17 +2704,19 @@ class GraphView(Sequence[Node], _display.PrettyPrintable):
2232
2704
  self.name = name
2233
2705
  self.inputs = tuple(inputs)
2234
2706
  self.outputs = tuple(outputs)
2235
- for initializer in initializers:
2236
- if initializer.name is None:
2237
- raise ValueError(f"Initializer must have a name: {initializer}")
2238
- self.initializers = {tensor.name: tensor for tensor in initializers}
2707
+ self.initializers = {initializer.name: initializer for initializer in initializers}
2239
2708
  self.doc_string = doc_string
2240
2709
  self.opset_imports = opset_imports or {}
2241
2710
  self._metadata: _metadata.MetadataStore | None = None
2242
2711
  self._metadata_props: dict[str, str] | None = metadata_props
2243
2712
  self._nodes: tuple[Node, ...] = tuple(nodes)
2244
2713
 
2245
- def __getitem__(self, index: int) -> Node:
2714
+ @typing.overload
2715
+ def __getitem__(self, index: int) -> Node: ...
2716
+ @typing.overload
2717
+ def __getitem__(self, index: slice) -> Sequence[Node]: ...
2718
+
2719
+ def __getitem__(self, index):
2246
2720
  return self._nodes[index]
2247
2721
 
2248
2722
  def __len__(self) -> int:
@@ -2318,7 +2792,7 @@ class Model(_protocols.ModelProtocol, _display.PrettyPrintable):
2318
2792
  model_version: int | None = None,
2319
2793
  doc_string: str | None = None,
2320
2794
  functions: Sequence[Function] = (),
2321
- meta_data_props: dict[str, str] | None = None,
2795
+ metadata_props: dict[str, str] | None = None,
2322
2796
  ) -> None:
2323
2797
  self.graph: Graph = graph
2324
2798
  self.ir_version = ir_version
@@ -2329,7 +2803,7 @@ class Model(_protocols.ModelProtocol, _display.PrettyPrintable):
2329
2803
  self.doc_string = doc_string
2330
2804
  self._functions = {func.identifier(): func for func in functions}
2331
2805
  self._metadata: _metadata.MetadataStore | None = None
2332
- self._metadata_props: dict[str, str] | None = meta_data_props
2806
+ self._metadata_props: dict[str, str] | None = metadata_props
2333
2807
 
2334
2808
  @property
2335
2809
  def functions(self) -> dict[_protocols.OperatorIdentifier, Function]:
@@ -2381,9 +2855,28 @@ Model(
2381
2855
  domain={self.domain!r},
2382
2856
  model_version={self.model_version!r},
2383
2857
  functions={self.functions!r},
2384
- graph={textwrap.indent(repr(self.graph), ' ' * 4).strip()}
2858
+ graph={textwrap.indent(repr(self.graph), " " * 4).strip()}
2385
2859
  )"""
2386
2860
 
2861
+ def graphs(self) -> Iterable[Graph]:
2862
+ """Get all graphs and subgraphs in the model.
2863
+
2864
+ This is a convenience method to traverse the model. Consider using
2865
+ `onnx_ir.traversal.RecursiveGraphIterator` for more advanced
2866
+ traversals on nodes.
2867
+ """
2868
+ # NOTE(justinchuby): Given
2869
+ # (1) how useful the method is
2870
+ # (2) I couldn't find an appropriate name for it in `traversal.py`
2871
+ # (3) Users familiar with onnxruntime optimization tools expect this method
2872
+ # I created this method as a core method instead of an iterator in
2873
+ # `traversal.py`.
2874
+ seen_graphs: set[Graph] = set()
2875
+ for node in onnx_ir.traversal.RecursiveGraphIterator(self.graph):
2876
+ if node.graph is not None and node.graph not in seen_graphs:
2877
+ seen_graphs.add(node.graph)
2878
+ yield node.graph
2879
+
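Typical usage, sketched (the file path is illustrative):

import onnx_ir as ir

model = ir.load("model.onnx")
for graph in model.graphs():    # the main graph and subgraphs reached from it
    print(graph.name, "nodes:", len(graph))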
2387
2880
 
2388
2881
  class Function(_protocols.FunctionProtocol, Sequence[Node], _display.PrettyPrintable):
2389
2882
  """IR functions.
@@ -2404,16 +2897,14 @@ class Function(_protocols.FunctionProtocol, Sequence[Node], _display.PrettyPrint
2404
2897
  outputs: The output values of the function.
2405
2898
  opset_imports: Opsets imported by the function.
2406
2899
  doc_string: Documentation string.
2407
- metadata_props: Metadata that will be serialized to the ONNX file.
2408
2900
  meta: Metadata store for graph transform passes.
2901
+ metadata_props: Metadata that will be serialized to the ONNX file.
2409
2902
  """
2410
2903
 
2411
2904
  __slots__ = (
2412
2905
  "_attributes",
2413
2906
  "_domain",
2414
2907
  "_graph",
2415
- "_metadata",
2416
- "_metadata_props",
2417
2908
  "_name",
2418
2909
  "_overload",
2419
2910
  )
@@ -2427,16 +2918,15 @@ class Function(_protocols.FunctionProtocol, Sequence[Node], _display.PrettyPrint
2427
2918
  # Ensure the inputs and outputs of the function belong to a graph
2428
2919
  # and not from an outer scope
2429
2920
  graph: Graph,
2430
- attributes: Sequence[Attr],
2431
- metadata_props: dict[str, str] | None = None,
2921
+ attributes: Iterable[Attr] | Mapping[str, Attr],
2432
2922
  ) -> None:
2433
2923
  self._domain = domain
2434
2924
  self._name = name
2435
2925
  self._overload = overload
2436
2926
  self._graph = graph
2437
- self._attributes = OrderedDict((attr.name, attr) for attr in attributes)
2438
- self._metadata: _metadata.MetadataStore | None = None
2439
- self._metadata_props: dict[str, str] | None = metadata_props
2927
+ if isinstance(attributes, Mapping):
2928
+ attributes = tuple(attributes.values())
2929
+ self._attributes = _graph_containers.Attributes(attributes)
2440
2930
 
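Both forms are now accepted by the constructor; a sketch with hypothetical domain and names:

import onnx_ir as ir

body = ir.Graph(inputs=[], outputs=[], nodes=[], name="f_body")
alpha = ir.Attr("alpha", ir.AttributeType.FLOAT, 1.0)

f = ir.Function("custom.domain", "f", graph=body, attributes=[alpha])
# ...or equivalently, a name -> Attr mapping:
# ir.Function("custom.domain", "f", graph=body, attributes={"alpha": alpha})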
2441
2931
  def identifier(self) -> _protocols.OperatorIdentifier:
2442
2932
  return self.domain, self.name, self.overload
@@ -2455,7 +2945,7 @@ class Function(_protocols.FunctionProtocol, Sequence[Node], _display.PrettyPrint
2455
2945
 
2456
2946
  @domain.setter
2457
2947
  def domain(self, value: str) -> None:
2458
- self._domain = value
2948
+ self._domain = _normalize_domain(value)
2459
2949
 
2460
2950
  @property
2461
2951
  def overload(self) -> str:
@@ -2466,18 +2956,23 @@ class Function(_protocols.FunctionProtocol, Sequence[Node], _display.PrettyPrint
2466
2956
  self._overload = value
2467
2957
 
2468
2958
  @property
2469
- def inputs(self) -> list[Value]:
2959
+ def inputs(self) -> MutableSequence[Value]:
2470
2960
  return self._graph.inputs
2471
2961
 
2472
2962
  @property
2473
- def outputs(self) -> list[Value]:
2963
+ def outputs(self) -> MutableSequence[Value]:
2474
2964
  return self._graph.outputs
2475
2965
 
2476
2966
  @property
2477
- def attributes(self) -> OrderedDict[str, Attr]:
2967
+ def attributes(self) -> _graph_containers.Attributes:
2478
2968
  return self._attributes
2479
2969
 
2480
- def __getitem__(self, index: int) -> Node:
2970
+ @typing.overload
2971
+ def __getitem__(self, index: int) -> Node: ...
2972
+ @typing.overload
2973
+ def __getitem__(self, index: slice) -> Sequence[Node]: ...
2974
+
2975
+ def __getitem__(self, index):
2481
2976
  return self._graph.__getitem__(index)
2482
2977
 
2483
2978
  def __len__(self) -> int:
@@ -2508,15 +3003,11 @@ class Function(_protocols.FunctionProtocol, Sequence[Node], _display.PrettyPrint
2508
3003
  Write to the :attr:`metadata_props` if you would like the metadata to be serialized
2509
3004
  to the ONNX proto.
2510
3005
  """
2511
- if self._metadata is None:
2512
- self._metadata = _metadata.MetadataStore()
2513
- return self._metadata
3006
+ return self._graph.meta
2514
3007
 
2515
3008
  @property
2516
3009
  def metadata_props(self) -> dict[str, str]:
2517
- if self._metadata_props is None:
2518
- self._metadata_props = {}
2519
- return self._metadata_props
3010
+ return self._graph.metadata_props
2520
3011
 
2521
3012
  # Mutation methods
2522
3013
  def append(self, node: Node, /) -> None:
@@ -2549,11 +3040,11 @@ class Function(_protocols.FunctionProtocol, Sequence[Node], _display.PrettyPrint
2549
3040
  """
2550
3041
  self._graph.remove(nodes, safe=safe)
2551
3042
 
2552
- def insert_after(self, node: Node, new_nodes: Iterable[Node], /) -> None:
3043
+ def insert_after(self, node: Node, new_nodes: Iterable[Node] | Node, /) -> None:
2553
3044
  """Insert new nodes after the given node in O(#new_nodes) time."""
2554
3045
  self._graph.insert_after(node, new_nodes)
2555
3046
 
2556
- def insert_before(self, node: Node, new_nodes: Iterable[Node], /) -> None:
3047
+ def insert_before(self, node: Node, new_nodes: Iterable[Node] | Node, /) -> None:
2557
3048
  """Insert new nodes before the given node in O(#new_nodes) time."""
2558
3049
  self._graph.insert_before(node, new_nodes)
2559
3050
 
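The relaxed signature removes a common wrapping step; a sketch with hypothetical nodes:

function.insert_before(anchor, new_node)           # a single Node now works
function.insert_after(anchor, [node_a, node_b])    # iterables still work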
@@ -2581,10 +3072,10 @@ class Function(_protocols.FunctionProtocol, Sequence[Node], _display.PrettyPrint
2581
3072
  >
2582
3073
  def {full_name}(
2583
3074
  inputs=(
2584
- {textwrap.indent(inputs_text, ' ' * 8)}
2585
- ),{textwrap.indent(attributes_text, ' ' * 4)}
3075
+ {textwrap.indent(inputs_text, " " * 8)}
3076
+ ),{textwrap.indent(attributes_text, " " * 4)}
2586
3077
  outputs=(
2587
- {textwrap.indent(outputs_text, ' ' * 8)}
3078
+ {textwrap.indent(outputs_text, " " * 8)}
2588
3079
  ),
2589
3080
  )"""
2590
3081
  node_count = len(self)
@@ -2611,22 +3102,28 @@ def {full_name}(
2611
3102
  return f"{self.__class__.__name__}({self.domain!r}, {self.name!r}, {self.overload!r}, inputs={self.inputs!r}, attributes={self.attributes!r}, outputs={self.outputs!r})"
2612
3103
 
2613
3104
 
2614
- class RefAttr(_protocols.ReferenceAttributeProtocol, _display.PrettyPrintable):
2615
- """Reference attribute."""
3105
+ class Attr(
3106
+ _protocols.AttributeProtocol,
3107
+ _protocols.ReferenceAttributeProtocol,
3108
+ _display.PrettyPrintable,
3109
+ ):
3110
+ """Base class for ONNX attributes or references."""
2616
3111
 
2617
- __slots__ = ("_name", "_ref_attr_name", "_type", "doc_string")
3112
+ __slots__ = ("_name", "_ref_attr_name", "_type", "_value", "doc_string")
2618
3113
 
2619
3114
  def __init__(
2620
3115
  self,
2621
3116
  name: str,
2622
- ref_attr_name: str,
2623
3117
  type: _enums.AttributeType,
3118
+ value: Any,
3119
+ ref_attr_name: str | None = None,
2624
3120
  *,
2625
3121
  doc_string: str | None = None,
2626
3122
  ) -> None:
2627
3123
  self._name = name
2628
- self._ref_attr_name = ref_attr_name
2629
3124
  self._type = type
3125
+ self._value = value
3126
+ self._ref_attr_name = ref_attr_name
2630
3127
  self.doc_string = doc_string
2631
3128
 
2632
3129
  @property
@@ -2637,43 +3134,21 @@ class RefAttr(_protocols.ReferenceAttributeProtocol, _display.PrettyPrintable):
2637
3134
  def name(self, value: str) -> None:
2638
3135
  self._name = value
2639
3136
 
2640
- @property
2641
- def ref_attr_name(self) -> str:
2642
- return self._ref_attr_name
2643
-
2644
- @ref_attr_name.setter
2645
- def ref_attr_name(self, value: str) -> None:
2646
- self._ref_attr_name = value
2647
-
2648
3137
  @property
2649
3138
  def type(self) -> _enums.AttributeType:
2650
3139
  return self._type
2651
3140
 
2652
- @type.setter
2653
- def type(self, value: _enums.AttributeType) -> None:
2654
- self._type = value
2655
-
2656
- def __repr__(self) -> str:
2657
- return f"{self.__class__.__name__}({self._name!r}, {self._type!r}, ref_attr_name={self.ref_attr_name!r})"
2658
-
3141
+ @property
3142
+ def value(self) -> Any:
3143
+ return self._value
2659
3144
 
2660
- class Attr(_protocols.AttributeProtocol, _display.PrettyPrintable):
2661
- """Base class for ONNX attributes."""
3145
+ @property
3146
+ def ref_attr_name(self) -> str | None:
3147
+ return self._ref_attr_name
2662
3148
 
2663
- __slots__ = ("doc_string", "name", "type", "value")
2664
-
2665
- def __init__(
2666
- self,
2667
- name: str,
2668
- type: _enums.AttributeType,
2669
- value: Any,
2670
- *,
2671
- doc_string: str | None = None,
2672
- ):
2673
- self.name = name
2674
- self.type = type
2675
- self.value = value
2676
- self.doc_string = doc_string
3149
+ def is_ref(self) -> bool:
3150
+ """Check if this attribute is a reference attribute."""
3151
+ return self.ref_attr_name is not None
2677
3152
 
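A runnable sketch of the merged class, contrasting a plain attribute with a reference:

import onnx_ir as ir

a = ir.Attr("alpha", ir.AttributeType.FLOAT, 0.5)
assert not a.is_ref()

r = ir.Attr("alpha", ir.AttributeType.FLOAT, None, ref_attr_name="outer_alpha")
assert r.is_ref()
print(r)        # "@outer_alpha", per __str__ below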
2678
3153
  def __eq__(self, other: object) -> bool:
2679
3154
  if not isinstance(other, _protocols.AttributeProtocol):
@@ -2690,15 +3165,157 @@ class Attr(_protocols.AttributeProtocol, _display.PrettyPrintable):
2690
3165
  return True
2691
3166
 
2692
3167
  def __str__(self) -> str:
3168
+ if self.is_ref():
3169
+ return f"@{self.ref_attr_name}"
2693
3170
  if self.type == _enums.AttributeType.GRAPH:
2694
3171
  return textwrap.indent("\n" + str(self.value), " " * 4)
2695
3172
  return str(self.value)
2696
3173
 
2697
3174
  def __repr__(self) -> str:
3175
+ if self.is_ref():
3176
+ return f"{self.__class__.__name__}({self.name!r}, {self.type!r}, ref_attr_name={self.ref_attr_name!r})"
2698
3177
  return f"{self.__class__.__name__}({self.name!r}, {self.type!r}, {self.value!r})"
2699
3178
 
3179
+ # Well-typed getters
3180
+ def as_float(self) -> float:
3181
+ """Get the attribute value as a float."""
3182
+ if self.type != _enums.AttributeType.FLOAT:
3183
+ raise TypeError(
3184
+ f"Attribute '{self.name}' is not of type FLOAT. Actual type: {self.type}"
3185
+ )
3186
+ # Do not use isinstance check because it may prevent np.float32 etc. from being used
3187
+ return float(self.value)
3188
+
3189
+ def as_int(self) -> int:
3190
+ """Get the attribute value as an int."""
3191
+ if self.type != _enums.AttributeType.INT:
3192
+ raise TypeError(
3193
+ f"Attribute '{self.name}' is not of type INT. Actual type: {self.type}"
3194
+ )
3195
+ # Do not use isinstance check because it may prevent np.int32 etc. from being used
3196
+ return int(self.value)
3197
+
3198
+ def as_string(self) -> str:
3199
+ """Get the attribute value as a string."""
3200
+ if self.type != _enums.AttributeType.STRING:
3201
+ raise TypeError(
3202
+ f"Attribute '{self.name}' is not of type STRING. Actual type: {self.type}"
3203
+ )
3204
+ if not isinstance(self.value, str):
3205
+ raise TypeError(f"Value of attribute '{self!r}' is not a string.")
3206
+ return self.value
3207
+
3208
+ def as_tensor(self) -> _protocols.TensorProtocol:
3209
+ """Get the attribute value as a tensor."""
3210
+ if self.type != _enums.AttributeType.TENSOR:
3211
+ raise TypeError(
3212
+ f"Attribute '{self.name}' is not of type TENSOR. Actual type: {self.type}"
3213
+ )
3214
+ if not isinstance(self.value, _protocols.TensorProtocol):
3215
+ raise TypeError(f"Value of attribute '{self!r}' is not a tensor.")
3216
+ return self.value
3217
+
3218
+ def as_graph(self) -> Graph:
3219
+ """Get the attribute value as a graph."""
3220
+ if self.type != _enums.AttributeType.GRAPH:
3221
+ raise TypeError(
3222
+ f"Attribute '{self.name}' is not of type GRAPH. Actual type: {self.type}"
3223
+ )
3224
+ if not isinstance(self.value, Graph):
3225
+ raise TypeError(f"Value of attribute '{self!r}' is not a graph.")
3226
+ return self.value
3227
+
3228
+ def as_floats(self) -> Sequence[float]:
3229
+ """Get the attribute value as a sequence of floats."""
3230
+ if self.type != _enums.AttributeType.FLOATS:
3231
+ raise TypeError(
3232
+ f"Attribute '{self.name}' is not of type FLOATS. Actual type: {self.type}"
3233
+ )
3234
+ if not isinstance(self.value, Sequence):
3235
+ raise TypeError(f"Value of attribute '{self!r}' is not a Sequence.")
3236
+ # Do not use isinstance check on elements because it may prevent np.int32 etc. from being used
3237
+ # Create a copy of the list to prevent mutation
3238
+ return [float(v) for v in self.value]
3239
+
3240
+ def as_ints(self) -> Sequence[int]:
3241
+ """Get the attribute value as a sequence of ints."""
3242
+ if self.type != _enums.AttributeType.INTS:
3243
+ raise TypeError(
3244
+ f"Attribute '{self.name}' is not of type INTS. Actual type: {self.type}"
3245
+ )
3246
+ if not isinstance(self.value, Sequence):
3247
+ raise TypeError(f"Value of attribute '{self!r}' is not a Sequence.")
3248
+ # Do not use isinstance check on elements because it may prevent np.int32 etc. from being used
3249
+ # Create a copy of the list to prevent mutation
3250
+ return list(self.value)
3251
+
3252
+ def as_strings(self) -> Sequence[str]:
3253
+ """Get the attribute value as a sequence of strings."""
3254
+ if self.type != _enums.AttributeType.STRINGS:
3255
+ raise TypeError(
3256
+ f"Attribute '{self.name}' is not of type STRINGS. Actual type: {self.type}"
3257
+ )
3258
+ if not isinstance(self.value, Sequence):
3259
+ raise TypeError(f"Value of attribute '{self!r}' is not a Sequence.")
3260
+ if onnx_ir.DEBUG:
3261
+ if not all(isinstance(x, str) for x in self.value):
3262
+ raise TypeError(f"Value of attribute '{self!r}' is not a Sequence of strings.")
3263
+ # Create a copy of the list to prevent mutation
3264
+ return list(self.value)
3265
+
3266
+ def as_tensors(self) -> Sequence[_protocols.TensorProtocol]:
3267
+ """Get the attribute value as a sequence of tensors."""
3268
+ if self.type != _enums.AttributeType.TENSORS:
3269
+ raise TypeError(
3270
+ f"Attribute '{self.name}' is not of type TENSORS. Actual type: {self.type}"
3271
+ )
3272
+ if not isinstance(self.value, Sequence):
3273
+ raise TypeError(f"Value of attribute '{self!r}' is not a Sequence.")
3274
+ if onnx_ir.DEBUG:
3275
+ if not all(isinstance(x, _protocols.TensorProtocol) for x in self.value):
3276
+ raise TypeError(f"Value of attribute '{self!r}' is not a Sequence of tensors.")
3277
+ # Create a copy of the list to prevent mutation
3278
+ return list(self.value)
3279
+
3280
+ def as_graphs(self) -> Sequence[Graph]:
3281
+ """Get the attribute value as a sequence of graphs."""
3282
+ if self.type != _enums.AttributeType.GRAPHS:
3283
+ raise TypeError(
3284
+ f"Attribute '{self.name}' is not of type GRAPHS. Actual type: {self.type}"
3285
+ )
3286
+ if not isinstance(self.value, Sequence):
3287
+ raise TypeError(f"Value of attribute '{self!r}' is not a Sequence.")
3288
+ if onnx_ir.DEBUG:
3289
+ if not all(isinstance(x, Graph) for x in self.value):
3290
+ raise TypeError(f"Value of attribute '{self!r}' is not a Sequence of graphs.")
3291
+ # Create a copy of the list to prevent mutation
3292
+ return list(self.value)
3293
+
2700
3294
 
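A runnable sketch of the typed getters:

import onnx_ir as ir

strides = ir.Attr("strides", ir.AttributeType.INTS, [1, 1])
assert strides.as_ints() == [1, 1]     # returns a list copy, safe to mutate

try:
    strides.as_floats()                # wrong type: INTS, not FLOATS
except TypeError as e:
    print(e)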
2701
3295
  # NOTE: The following functions are just for convenience
3296
+
3297
+
3298
+ def RefAttr(
3299
+ name: str,
3300
+ ref_attr_name: str,
3301
+ type: _enums.AttributeType,
3302
+ doc_string: str | None = None,
3303
+ ) -> Attr:
3304
+ """Create a reference attribute.
3305
+
3306
+ Args:
3307
+ name: The name of the attribute.
3308
+ type: The type of the attribute.
3309
+ ref_attr_name: The name of the referenced attribute.
3310
+ doc_string: Documentation string.
3311
+
3312
+ Returns:
3313
+ A reference attribute.
3314
+ """
3315
+ # NOTE: The function name is capitalized to maintain API backward compatibility.
3316
+ return Attr(name, type, None, ref_attr_name=ref_attr_name, doc_string=doc_string)
3317
+
3318
+
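A sketch of the factory, equivalent to constructing ``Attr`` with ``ref_attr_name`` directly:

import onnx_ir as ir

ref = ir.RefAttr("alpha", "fn_alpha", ir.AttributeType.FLOAT)
assert ref.is_ref() and ref.value is None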
2702
3319
  def AttrFloat32(name: str, value: float, doc_string: str | None = None) -> Attr:
2703
3320
  """Create a float attribute."""
2704
3321
  # NOTE: The function name is capitalized to maintain API backward compatibility.