onnx-ir 0.0.1__py3-none-any.whl → 0.1.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Files changed (45)
  1. onnx_ir/__init__.py +23 -10
  2. onnx_ir/{_convenience.py → _convenience/__init__.py} +40 -102
  3. onnx_ir/_convenience/_constructors.py +213 -0
  4. onnx_ir/_core.py +857 -233
  5. onnx_ir/_display.py +2 -2
  6. onnx_ir/_enums.py +107 -5
  7. onnx_ir/_graph_comparison.py +2 -2
  8. onnx_ir/_graph_containers.py +268 -0
  9. onnx_ir/_io.py +57 -10
  10. onnx_ir/_linked_list.py +15 -7
  11. onnx_ir/_metadata.py +4 -3
  12. onnx_ir/_name_authority.py +2 -2
  13. onnx_ir/_polyfill.py +26 -0
  14. onnx_ir/_protocols.py +31 -13
  15. onnx_ir/_tape.py +139 -32
  16. onnx_ir/_thirdparty/asciichartpy.py +1 -4
  17. onnx_ir/_type_casting.py +18 -3
  18. onnx_ir/{_internal/version_utils.py → _version_utils.py} +2 -29
  19. onnx_ir/convenience.py +4 -2
  20. onnx_ir/external_data.py +401 -0
  21. onnx_ir/passes/__init__.py +8 -2
  22. onnx_ir/passes/_pass_infra.py +173 -56
  23. onnx_ir/passes/common/__init__.py +36 -0
  24. onnx_ir/passes/common/_c_api_utils.py +76 -0
  25. onnx_ir/passes/common/clear_metadata_and_docstring.py +60 -0
  26. onnx_ir/passes/common/constant_manipulation.py +232 -0
  27. onnx_ir/passes/common/inliner.py +331 -0
  28. onnx_ir/passes/common/onnx_checker.py +57 -0
  29. onnx_ir/passes/common/shape_inference.py +112 -0
  30. onnx_ir/passes/common/topological_sort.py +33 -0
  31. onnx_ir/passes/common/unused_removal.py +196 -0
  32. onnx_ir/serde.py +288 -124
  33. onnx_ir/tape.py +15 -0
  34. onnx_ir/tensor_adapters.py +122 -0
  35. onnx_ir/testing.py +197 -0
  36. onnx_ir/traversal.py +4 -3
  37. onnx_ir-0.1.0.dist-info/METADATA +53 -0
  38. onnx_ir-0.1.0.dist-info/RECORD +41 -0
  39. {onnx_ir-0.0.1.dist-info → onnx_ir-0.1.0.dist-info}/WHEEL +1 -1
  40. onnx_ir-0.1.0.dist-info/licenses/LICENSE +202 -0
  41. onnx_ir/_external_data.py +0 -323
  42. onnx_ir-0.0.1.dist-info/LICENSE +0 -22
  43. onnx_ir-0.0.1.dist-info/METADATA +0 -73
  44. onnx_ir-0.0.1.dist-info/RECORD +0 -26
  45. {onnx_ir-0.0.1.dist-info → onnx_ir-0.1.0.dist-info}/top_level.txt +0 -0
onnx_ir/_core.py CHANGED
@@ -1,5 +1,5 @@
- # Copyright (c) Microsoft Corporation.
- # Licensed under the MIT License.
+ # Copyright (c) ONNX Project Contributors
+ # SPDX-License-Identifier: Apache-2.0
  """data structures for the intermediate representation."""

  # NOTES for developers:
@@ -22,26 +22,37 @@ import os
  import sys
  import textwrap
  import typing
- from typing import (
-     AbstractSet,
-     Any,
+ from collections import OrderedDict
+ from collections.abc import (
      Collection,
-     Generic,
      Hashable,
      Iterable,
      Iterator,
-     OrderedDict,
+     MutableMapping,
+     MutableSequence,
      Sequence,
+ )
+ from collections.abc import (
+     Set as AbstractSet,
+ )
+ from typing import (
+     Any,
+     Callable,
+     Generic,
+     NamedTuple,
+     SupportsInt,
      Union,
  )

  import ml_dtypes
  import numpy as np
+ from typing_extensions import TypeIs

  import onnx_ir
  from onnx_ir import (
      _display,
      _enums,
+     _graph_containers,
      _linked_list,
      _metadata,
      _name_authority,
@@ -70,6 +81,7 @@ _NON_NUMPY_NATIVE_TYPES = frozenset(
          _enums.DataType.FLOAT8E5M2FNUZ,
          _enums.DataType.INT4,
          _enums.DataType.UINT4,
+         _enums.DataType.FLOAT4E2M1,
      )
  )

@@ -93,7 +105,23 @@ def _compatible_with_dlpack(obj: Any) -> TypeGuard[_protocols.DLPackCompatible]:
  class TensorBase(abc.ABC, _protocols.TensorProtocol, _display.PrettyPrintable):
      """Convenience Shared methods for classes implementing TensorProtocol."""

-     __slots__ = ()
+     __slots__ = (
+         "_doc_string",
+         "_metadata",
+         "_metadata_props",
+         "_name",
+     )
+
+     def __init__(
+         self,
+         name: str | None = None,
+         doc_string: str | None = None,
+         metadata_props: dict[str, str] | None = None,
+     ) -> None:
+         self._metadata: _metadata.MetadataStore | None = None
+         self._metadata_props: dict[str, str] | None = metadata_props
+         self._name: str | None = name
+         self._doc_string: str | None = doc_string

      def _printable_type_shape(self) -> str:
          """Return a string representation of the shape and data type."""
@@ -106,10 +134,28 @@ class TensorBase(abc.ABC, _protocols.TensorProtocol, _display.PrettyPrintable):
          """
          return f"{self.__class__.__name__}<{self._printable_type_shape()}>"

+     @property
+     def name(self) -> str | None:
+         """The name of the tensor."""
+         return self._name
+
+     @name.setter
+     def name(self, value: str | None) -> None:
+         self._name = value
+
+     @property
+     def doc_string(self) -> str | None:
+         """The documentation string."""
+         return self._doc_string
+
+     @doc_string.setter
+     def doc_string(self, value: str | None) -> None:
+         self._doc_string = value
+
      @property
      def size(self) -> int:
          """The number of elements in the tensor."""
-         return np.prod(self.shape.numpy())  # type: ignore[return-value,attr-defined]
+         return math.prod(self.shape.numpy())  # type: ignore[attr-defined]

      @property
      def nbytes(self) -> int:
@@ -117,6 +163,23 @@ class TensorBase(abc.ABC, _protocols.TensorProtocol, _display.PrettyPrintable):
          # Use math.ceil because when dtype is INT4, the itemsize is 0.5
          return math.ceil(self.dtype.itemsize * self.size)

+     @property
+     def metadata_props(self) -> dict[str, str]:
+         if self._metadata_props is None:
+             self._metadata_props = {}
+         return self._metadata_props
+
+     @property
+     def meta(self) -> _metadata.MetadataStore:
+         """The metadata store for intermediate analysis.
+
+         Write to the :attr:`metadata_props` if you would like the metadata to be serialized
+         to the ONNX proto.
+         """
+         if self._metadata is None:
+             self._metadata = _metadata.MetadataStore()
+         return self._metadata
+
      def display(self, *, page: bool = False) -> None:
          rich = _display.require_rich()

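The metadata accessors now live on TensorBase, so every tensor subclass gets both stores. A minimal sketch of the intended split between the two (the keys used here are made up)::

    import numpy as np
    import onnx_ir as ir

    t = ir.Tensor(np.array([1.0], dtype=np.float32), name="t")
    t.meta["visited"] = True               # transient analysis state; never serialized
    t.metadata_props["source"] = "init"    # serialized into the ONNX proto by serde
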
@@ -182,7 +245,7 @@ def _check_numpy_representation_type(array: np.ndarray, dtype: _enums.DataType)
      When the dtype is not one of the numpy native dtypes, the value needs need to be:

      - ``int8`` or ``uint8`` for int4, with the sign bit extended to 8 bits.
-     - ``uint8`` for uint4.
+     - ``uint8`` for uint4 or float4.
      - ``uint8`` for 8-bit data types.
      - ``uint16`` for bfloat16

@@ -195,7 +258,7 @@ def _check_numpy_representation_type(array: np.ndarray, dtype: _enums.DataType)
          )
      if dtype.itemsize == 1 and array.dtype not in (
          np.uint8,
-         ml_dtypes.float8_e4m3b11fnuz,
+         ml_dtypes.float8_e4m3fnuz,
          ml_dtypes.float8_e4m3fn,
          ml_dtypes.float8_e5m2fnuz,
          ml_dtypes.float8_e5m2,
@@ -213,6 +276,11 @@ def _check_numpy_representation_type(array: np.ndarray, dtype: _enums.DataType)
              raise TypeError(
                  f"The numpy array dtype must be uint8 or or ml_dtypes.uint4 (not {array.dtype}) for IR data type {dtype}."
              )
+         if dtype == _enums.DataType.FLOAT4E2M1:
+             if array.dtype not in (np.uint8, ml_dtypes.float4_e2m1fn):
+                 raise TypeError(
+                     f"The numpy array dtype must be uint8 or ml_dtypes.float4_e2m1fn (not {array.dtype}) for IR data type {dtype}."
+                 )
          return

      try:
@@ -256,6 +324,8 @@ def _maybe_view_np_array_with_ml_dtypes(
          return array.view(ml_dtypes.int4)
      if dtype == _enums.DataType.UINT4:
          return array.view(ml_dtypes.uint4)
+     if dtype == _enums.DataType.FLOAT4E2M1:
+         return array.view(ml_dtypes.float4_e2m1fn)
      return array


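FLOAT4E2M1 is threaded through the same uint8-carrier path as the other 4-bit types. A sketch of what that enables, assuming `ml_dtypes` exposes `float4_e2m1fn` as recent releases do::

    import ml_dtypes
    import numpy as np
    import onnx_ir as ir

    # One element per byte (unpacked); the IR dtype reinterprets the carrier array.
    carrier = np.zeros((2, 2), dtype=np.uint8)
    t = ir.Tensor(carrier, dtype=ir.DataType.FLOAT4E2M1)
    assert t.numpy().dtype == ml_dtypes.float4_e2m1fn
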
@@ -298,12 +368,8 @@ class Tensor(TensorBase, _protocols.TensorProtocol, Generic[TArrayCompatible]):

      __slots__ = (
          "_dtype",
-         "_metadata",
-         "_metadata_props",
          "_raw",
          "_shape",
-         "doc_string",
-         "name",
      )

      def __init__(
  def __init__(
@@ -322,7 +388,7 @@ class Tensor(TensorBase, _protocols.TensorProtocol, Generic[TArrayCompatible]):
322
388
  value: The backing data of the tensor. It can be a numpy array compatible object or a DLPack compatible object.
323
389
  When the dtype is not one of the numpy native dtypes, the value needs
324
390
  to be ``uint8`` for 4-bit and 8-bit data types, and ``uint16`` for bfloat16
325
- when the value is a numpy array; :param:`dtype` must be specified in this case.
391
+ when the value is a numpy array; ``dtype`` must be specified in this case.
326
392
  dtype: The data type of the tensor. It can be None only when value is a numpy array.
327
393
  Users are responsible for making sure the dtype matches the value when value is not a numpy array.
328
394
  shape: The shape of the tensor. If None, the shape is obtained from the value.
@@ -336,6 +402,7 @@ class Tensor(TensorBase, _protocols.TensorProtocol, Generic[TArrayCompatible]):
              ValueError: If the shape is not specified and the value does not have a shape attribute.
              ValueError: If the dtype is not specified and the value is not a numpy array.
          """
+         super().__init__(name=name, doc_string=doc_string, metadata_props=metadata_props)
          # NOTE: We should not do any copying here for performance reasons
          if not _compatible_with_numpy(value) and not _compatible_with_dlpack(value):
              raise TypeError(f"Expected an array compatible object, got {type(value)}")
@@ -349,7 +416,7 @@ class Tensor(TensorBase, _protocols.TensorProtocol, Generic[TArrayCompatible]):
              self._shape = Shape(getattr(value, "shape"), frozen=True)  # noqa: B009
          else:
              self._shape = shape
-             self._shape._frozen = True
+             self._shape.freeze()
          if dtype is None:
              if isinstance(value, np.ndarray):
                  self._dtype = _enums.DataType.from_numpy(value.dtype)
@@ -370,17 +437,13 @@ class Tensor(TensorBase, _protocols.TensorProtocol, Generic[TArrayCompatible]):
              value = _maybe_view_np_array_with_ml_dtypes(value, self._dtype)  # type: ignore[assignment]

          self._raw = value
-         self.name = name
-         self.doc_string = doc_string
-         self._metadata: _metadata.MetadataStore | None = None
-         self._metadata_props = metadata_props

      def __array__(self, dtype: Any = None) -> np.ndarray:
          if isinstance(self._raw, np.ndarray) or _compatible_with_numpy(self._raw):
              return self._raw.__array__(dtype)
-         assert _compatible_with_dlpack(
-             self._raw
-         ), f"Bug: Expected DLPack or Numpy compatible objects, got {type(self._raw)}"
+         assert _compatible_with_dlpack(self._raw), (
+             f"Bug: Expected DLPack or Numpy compatible objects, got {type(self._raw)}"
+         )
          return np.from_dlpack(self._raw)

      def __dlpack__(self, *, stream: Any = None) -> Any:
  def __dlpack__(self, *, stream: Any = None) -> Any:
@@ -394,7 +457,10 @@ class Tensor(TensorBase, _protocols.TensorProtocol, Generic[TArrayCompatible]):
394
457
  return self.__array__().__dlpack_device__()
395
458
 
396
459
  def __repr__(self) -> str:
397
- return f"{self._repr_base()}({self._raw!r}, name={self.name!r})"
460
+ # Avoid multi-line repr
461
+ tensor_lines = repr(self._raw).split("\n")
462
+ tensor_text = " ".join(line.strip() for line in tensor_lines)
463
+ return f"{self._repr_base()}({tensor_text}, name={self.name!r})"
398
464
 
399
465
  @property
400
466
  def dtype(self) -> _enums.DataType:
@@ -431,7 +497,11 @@ class Tensor(TensorBase, _protocols.TensorProtocol, Generic[TArrayCompatible]):
          """
          # TODO(justinchuby): Support DLPack
          array = self.numpy()
-         if self.dtype in {_enums.DataType.INT4, _enums.DataType.UINT4}:
+         if self.dtype in {
+             _enums.DataType.INT4,
+             _enums.DataType.UINT4,
+             _enums.DataType.FLOAT4E2M1,
+         }:
              # Pack the array into int4
              array = _type_casting.pack_int4(array)
          else:
@@ -440,23 +510,6 @@ class Tensor(TensorBase, _protocols.TensorProtocol, Generic[TArrayCompatible]):
              array = array.view(array.dtype.newbyteorder("<"))
          return array.tobytes()

-     @property
-     def metadata_props(self) -> dict[str, str]:
-         if self._metadata_props is None:
-             self._metadata_props = {}
-         return self._metadata_props
-
-     @property
-     def meta(self) -> _metadata.MetadataStore:
-         """The metadata store for intermediate analysis.
-
-         Write to the :attr:`metadata_props` if you would like the metadata to be serialized
-         to the ONNX proto.
-         """
-         if self._metadata is None:
-             self._metadata = _metadata.MetadataStore()
-         return self._metadata
-

  class ExternalTensor(TensorBase, _protocols.TensorProtocol):  # pylint: disable=too-many-ancestors
      """An immutable concrete tensor with its data store on disk.
@@ -497,12 +550,9 @@ class ExternalTensor(TensorBase, _protocols.TensorProtocol): # pylint: disable=
          "_dtype",
          "_length",
          "_location",
-         "_metadata",
-         "_metadata_props",
          "_offset",
          "_shape",
-         "doc_string",
-         "name",
+         "_valid",
          "raw",
      )

@@ -532,6 +582,7 @@ class ExternalTensor(TensorBase, _protocols.TensorProtocol): # pylint: disable=
              metadata_props: The metadata properties.
              base_dir: The base directory for the external data. It is used to resolve relative paths.
          """
+         super().__init__(name=name, doc_string=doc_string, metadata_props=metadata_props)
          # NOTE: Do not verify the location by default. This is because the location field
          # in the tensor proto can be anything and we would like deserialization from
          # proto to IR to not fail.
@@ -547,12 +598,13 @@ class ExternalTensor(TensorBase, _protocols.TensorProtocol): # pylint: disable=
          self._dtype: _enums.DataType = dtype
          self.name: str = name  # mutable
          self._shape: Shape = shape
-         self._shape._frozen = True
+         self._shape.freeze()
          self.doc_string: str | None = doc_string  # mutable
          self._array: np.ndarray | None = None
          self.raw: mmap.mmap | None = None
          self._metadata_props = metadata_props
          self._metadata: _metadata.MetadataStore | None = None
+         self._valid = True

      @property
      def base_dir(self) -> str | os.PathLike:
@@ -594,6 +646,7 @@ class ExternalTensor(TensorBase, _protocols.TensorProtocol): # pylint: disable=
          return self._shape

      def _load(self):
+         self._check_validity()
          assert self._array is None, "Bug: The array should be loaded only once."
          if self.size == 0:
              # When the size is 0, mmap is impossible and meaningless
@@ -609,7 +662,11 @@ class ExternalTensor(TensorBase, _protocols.TensorProtocol): # pylint: disable=
          )
          # Handle the byte order correctly by always using little endian
          dt = np.dtype(self.dtype.numpy()).newbyteorder("<")
-         if self.dtype in {_enums.DataType.INT4, _enums.DataType.UINT4}:
+         if self.dtype in {
+             _enums.DataType.INT4,
+             _enums.DataType.UINT4,
+             _enums.DataType.FLOAT4E2M1,
+         }:
              # Use uint8 to read in the full byte. Otherwise ml_dtypes.int4 will clip the values
              dt = np.dtype(np.uint8).newbyteorder("<")
              count = self.size // 2 + self.size % 2
@@ -622,10 +679,13 @@ class ExternalTensor(TensorBase, _protocols.TensorProtocol): # pylint: disable=
              self._array = _type_casting.unpack_int4(self._array, shape)
          elif self.dtype == _enums.DataType.UINT4:
              self._array = _type_casting.unpack_uint4(self._array, shape)
+         elif self.dtype == _enums.DataType.FLOAT4E2M1:
+             self._array = _type_casting.unpack_float4e2m1(self._array, shape)
          else:
              self._array = self._array.reshape(shape)

      def __array__(self, dtype: Any = None) -> np.ndarray:
+         self._check_validity()
          if self._array is None:
              self._load()
          assert self._array is not None
@@ -654,6 +714,7 @@ class ExternalTensor(TensorBase, _protocols.TensorProtocol): # pylint: disable=

          The data will be memory mapped into memory and will not taken up physical memory space.
          """
+         self._check_validity()
          if self._array is None:
              self._load()
          assert self._array is not None
@@ -664,6 +725,7 @@ class ExternalTensor(TensorBase, _protocols.TensorProtocol): # pylint: disable=

          This will load the tensor into memory.
          """
+         self._check_validity()
          if self.raw is None:
              self._load()
          assert self.raw is not None
@@ -671,6 +733,26 @@ class ExternalTensor(TensorBase, _protocols.TensorProtocol): # pylint: disable=
          length = self._length or self.nbytes
          return self.raw[offset : offset + length]

+     def valid(self) -> bool:
+         """Check if the tensor is valid.
+
+         The external tensor is valid if it has not been invalidated.
+         """
+         return self._valid
+
+     def _check_validity(self) -> None:
+         if not self.valid():
+             raise ValueError(
+                 f"The external tensor '{self!r}' is invalidated. The data may be corrupted or deleted."
+             )
+
+     def invalidate(self) -> None:
+         """Invalidate the tensor.
+
+         The external tensor is invalidated when the data is known to be corrupted or deleted.
+         """
+         self._valid = False
+
      def release(self) -> None:
          """Delete all references to the memory buffer and close the memory-mapped file."""
@@ -678,34 +760,13 @@ class ExternalTensor(TensorBase, _protocols.TensorProtocol): # pylint: disable=
          self.raw.close()
          self.raw = None

-     @property
-     def metadata_props(self) -> dict[str, str]:
-         if self._metadata_props is None:
-             self._metadata_props = {}
-         return self._metadata_props
-
-     @property
-     def meta(self) -> _metadata.MetadataStore:
-         """The metadata store for intermediate analysis.
-
-         Write to the :attr:`metadata_props` if you would like the metadata to be serialized
-         to the ONNX proto.
-         """
-         if self._metadata is None:
-             self._metadata = _metadata.MetadataStore()
-         return self._metadata
-

  class StringTensor(TensorBase, _protocols.TensorProtocol):  # pylint: disable=too-many-ancestors
      """Multidimensional array of strings (as binary data to match the string_data field in TensorProto)."""

      __slots__ = (
-         "_metadata",
-         "_metadata_props",
          "_raw",
          "_shape",
-         "doc_string",
-         "name",
      )

      def __init__(
787
  doc_string: The documentation string.
727
788
  metadata_props: The metadata properties.
728
789
  """
790
+ super().__init__(name=name, doc_string=doc_string, metadata_props=metadata_props)
729
791
  if shape is None:
730
792
  if not hasattr(value, "shape"):
731
793
  raise ValueError(
@@ -735,19 +797,15 @@ class StringTensor(TensorBase, _protocols.TensorProtocol): # pylint: disable=to
              self._shape = Shape(getattr(value, "shape"), frozen=True)  # noqa: B009
          else:
              self._shape = shape
-             self._shape._frozen = True
+             self._shape.freeze()
          self._raw = value
-         self.name = name
-         self.doc_string = doc_string
-         self._metadata: _metadata.MetadataStore | None = None
-         self._metadata_props = metadata_props

      def __array__(self, dtype: Any = None) -> np.ndarray:
          if isinstance(self._raw, np.ndarray):
              return self._raw
-         assert isinstance(
-             self._raw, Sequence
-         ), f"Bug: Expected a sequence, got {type(self._raw)}"
+         assert isinstance(self._raw, Sequence), (
+             f"Bug: Expected a sequence, got {type(self._raw)}"
+         )
          return np.array(self._raw, dtype=dtype).reshape(self.shape.numpy())

      def __dlpack__(self, *, stream: Any = None) -> Any:
@@ -788,25 +846,125 @@ class StringTensor(TensorBase, _protocols.TensorProtocol): # pylint: disable=to
              return self._raw.flatten().tolist()
          return self._raw

+
+ class LazyTensor(TensorBase, _protocols.TensorProtocol):  # pylint: disable=too-many-ancestors
+     """A tensor that lazily evaluates a function to get the actual tensor.
+
+     This class takes a function returning an `ir.TensorProtocol`, a dtype, and a shape argument.
+     The function is lazily evaluated to get the actual tensor when `tobytes()` or `numpy()` is called.
+
+     Example::
+
+         >>> import numpy as np
+         >>> import onnx_ir as ir
+         >>> weights = np.array([[1, 2, 3]])
+         >>> def create_tensor():  # Delay applying transformations to the weights
+         ...     weights_t = weights.transpose()
+         ...     return ir.tensor(weights_t)
+         >>> lazy_tensor = ir.LazyTensor(create_tensor, dtype=ir.DataType.INT64, shape=ir.Shape([1, 3]))
+         >>> print(lazy_tensor.numpy())
+         [[1]
+          [2]
+          [3]]
+
+     Attributes:
+         func: The function that returns the actual tensor.
+         dtype: The data type of the tensor.
+         shape: The shape of the tensor.
+         cache: Whether to cache the result of the function. If False,
+             the function is called every time the tensor content is accessed.
+             If True, the function is called only once and the result is cached in memory.
+             Default is False.
+         name: The name of the tensor.
+         doc_string: The documentation string.
+         metadata_props: The metadata properties.
+     """
+
+     __slots__ = (
+         "_dtype",
+         "_func",
+         "_shape",
+         "_tensor",
+         "cache",
+     )
+
+     def __init__(
+         self,
+         func: Callable[[], _protocols.TensorProtocol],
+         dtype: _enums.DataType,
+         shape: Shape,
+         *,
+         cache: bool = False,
+         name: str | None = None,
+         doc_string: str | None = None,
+         metadata_props: dict[str, str] | None = None,
+     ) -> None:
+         """Initialize a lazy tensor.
+
+         Args:
+             func: The function that returns the actual tensor.
+             dtype: The data type of the tensor.
+             shape: The shape of the tensor.
+             cache: Whether to cache the result of the function.
+             name: The name of the tensor.
+             doc_string: The documentation string.
+             metadata_props: The metadata properties.
+         """
+         super().__init__(name=name, doc_string=doc_string, metadata_props=metadata_props)
+         self._func = func
+         self._dtype = dtype
+         self._shape = shape
+         self._tensor: _protocols.TensorProtocol | None = None
+         self.cache = cache
+
+     def _evaluate(self) -> _protocols.TensorProtocol:
+         """Evaluate the function to get the actual tensor."""
+         if not self.cache:
+             return self._func()
+
+         # Cache the tensor
+         if self._tensor is None:
+             self._tensor = self._func()
+         return self._tensor
+
+     def __array__(self, dtype: Any = None) -> np.ndarray:
+         return self._evaluate().__array__(dtype)
+
+     def __dlpack__(self, *, stream: Any = None) -> Any:
+         return self._evaluate().__dlpack__(stream=stream)
+
+     def __dlpack_device__(self) -> tuple[int, int]:
+         return self._evaluate().__dlpack_device__()
+
+     def __repr__(self) -> str:
+         return f"{self._repr_base()}(func={self._func!r}, name={self.name!r})"
+
      @property
-     def metadata_props(self) -> dict[str, str]:
-         if self._metadata_props is None:
-             self._metadata_props = {}
-         return self._metadata_props
+     def raw(self) -> Callable[[], _protocols.TensorProtocol]:
+         return self._func

      @property
-     def meta(self) -> _metadata.MetadataStore:
-         """The metadata store for intermediate analysis.
+     def dtype(self) -> _enums.DataType:
+         """The data type of the tensor. Immutable."""
+         return self._dtype

-         Write to the :attr:`metadata_props` if you would like the metadata to be serialized
-         to the ONNX proto.
-         """
-         if self._metadata is None:
-             self._metadata = _metadata.MetadataStore()
-         return self._metadata
+     @property
+     def shape(self) -> Shape:
+         """The shape of the tensor. Immutable."""
+         return self._shape
+
+     def numpy(self) -> np.ndarray:
+         """Return the tensor as a numpy array."""
+         return self._evaluate().numpy()
+
+     def tobytes(self) -> bytes:
+         """Return the bytes of the tensor."""
+         return self._evaluate().tobytes()


  class SymbolicDim(_protocols.SymbolicDimProtocol, _display.PrettyPrintable):
+     """Immutable symbolic dimension that can be shared across multiple shapes."""
+
      __slots__ = ("_value",)

      def __init__(self, value: str | None) -> None:
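The cache flag is the interesting knob: with cache=False every access re-runs the callback, while cache=True memoizes the first result. A small sketch (``ir.tensor`` is the convenience constructor added in this release)::

    import numpy as np
    import onnx_ir as ir

    calls = {"n": 0}

    def make():
        calls["n"] += 1
        return ir.tensor(np.arange(6, dtype=np.int64).reshape(2, 3))

    lazy = ir.LazyTensor(make, dtype=ir.DataType.INT64, shape=ir.Shape([2, 3]), cache=True)
    lazy.numpy()
    lazy.tobytes()
    assert calls["n"] == 1  # evaluated once, then served from the cache
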
@@ -841,12 +999,84 @@ class SymbolicDim(_protocols.SymbolicDimProtocol, _display.PrettyPrintable):
          return f"{self.__class__.__name__}({self._value})"


+ def _is_int_compatible(value: object) -> TypeIs[SupportsInt]:
+     """Return True if the value is int compatible."""
+     if isinstance(value, int):
+         return True
+     if hasattr(value, "__int__"):
+         # For performance reasons, we do not use isinstance(value, SupportsInt)
+         return True
+     return False
+
+
+ def _maybe_convert_to_symbolic_dim(
+     dim: int | SupportsInt | SymbolicDim | str | None,
+ ) -> SymbolicDim | int:
+     """Convert the value to a SymbolicDim if it is not an int."""
+     if dim is None or isinstance(dim, str):
+         return SymbolicDim(dim)
+     if _is_int_compatible(dim):
+         return int(dim)
+     if isinstance(dim, SymbolicDim):
+         return dim
+     raise TypeError(
+         f"Expected int, str, None or SymbolicDim, but value {dim!r} has type '{type(dim)}'"
+     )
+
+
  class Shape(_protocols.ShapeProtocol, _display.PrettyPrintable):
+     """The shape of a tensor, including its dimensions and optional denotations.
+
+     The :class:`Shape` stores the dimensions of a tensor, which can be integers, None (unknown), or
+     symbolic dimensions.
+
+     A shape can be compared to another shape or plain Python list.
+
+     A shape can be frozen (made immutable). When the shape is frozen, it cannot be
+     unfrozen, making it suitable to be shared across tensors or values.
+     Call :method:`freeze` to freeze the shape.
+
+     To update the dimension of a frozen shape, call :method:`copy` to create a
+     new shape with the same dimensions that can be modified.
+
+     Use :method:`get_denotation` and :method:`set_denotation` to access and modify the denotations.
+
+     Example::
+
+         >>> import onnx_ir as ir
+         >>> shape = ir.Shape(["B", None, 3])
+         >>> shape.rank()
+         3
+         >>> shape.is_static()
+         False
+         >>> shape.is_dynamic()
+         True
+         >>> shape.is_static(dim=2)
+         True
+         >>> shape[0] = 1
+         >>> shape[1] = 2
+         >>> shape.dims
+         (1, 2, 3)
+         >>> shape == [1, 2, 3]
+         True
+         >>> shape.frozen
+         False
+         >>> shape.freeze()
+         >>> shape.frozen
+         True
+
+     Attributes:
+         dims: A tuple of dimensions representing the shape.
+             Each dimension can be an integer, None or a :class:`SymbolicDim`.
+         frozen: Indicates whether the shape is immutable. When frozen, the shape
+             cannot be modified or unfrozen.
+     """
+
      __slots__ = ("_dims", "_frozen")

      def __init__(
          self,
-         dims: Iterable[int | SymbolicDim | str | None],
+         dims: Iterable[int | SupportsInt | SymbolicDim | str | None],
          /,
          denotations: Iterable[str | None] | None = None,
          frozen: bool = False,
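_maybe_convert_to_symbolic_dim also widens what a dimension may be: anything supporting __int__ (e.g. np.int64) is coerced to a plain int, while str/None become SymbolicDim. A sketch, assuming SymbolicDim is re-exported at the package root as the docstrings suggest::

    import numpy as np
    import onnx_ir as ir

    shape = ir.Shape([np.int64(2), "N", None])
    assert shape[0] == 2 and isinstance(shape[0], int)  # coerced via __int__
    assert isinstance(shape[1], ir.SymbolicDim)         # named symbolic dim
    assert isinstance(shape[2], ir.SymbolicDim)         # unknown dim
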
@@ -864,11 +1094,11 @@ class Shape(_protocols.ShapeProtocol, _display.PrettyPrintable):
              Refer to https://github.com/onnx/onnx/blob/main/docs/DimensionDenotation.md#denotation-definition
              for pre-defined dimension denotations.
          frozen: If True, the shape is immutable and cannot be modified. This
-             is useful when the shape is initialized by a Tensor.
+             is useful when the shape is initialized by a Tensor or when the shape
+             is shared across multiple tensors. The default is False.
          """
          self._dims: list[int | SymbolicDim] = [
-             SymbolicDim(dim) if not isinstance(dim, (int, SymbolicDim)) else dim
-             for dim in dims
+             _maybe_convert_to_symbolic_dim(dim) for dim in dims
          ]
          self._denotations: list[str | None] = (
              list(denotations) if denotations is not None else [None] * len(self._dims)
@@ -879,10 +1109,6 @@ class Shape(_protocols.ShapeProtocol, _display.PrettyPrintable):
          )
          self._frozen: bool = frozen

-     def copy(self):
-         """Return a copy of the shape."""
-         return Shape(self._dims, self._denotations, self._frozen)
-
      @property
      def dims(self) -> tuple[int | SymbolicDim, ...]:
          """All dimensions in the shape.
@@ -891,8 +1117,29 @@ class Shape(_protocols.ShapeProtocol, _display.PrettyPrintable):
          """
          return tuple(self._dims)

+     @property
+     def frozen(self) -> bool:
+         """Whether the shape is frozen.
+
+         When the shape is frozen, it cannot be unfrozen, making it suitable to be shared.
+         Call :method:`freeze` to freeze the shape. Call :method:`copy` to create a
+         new shape with the same dimensions that can be modified.
+         """
+         return self._frozen
+
+     def freeze(self) -> None:
+         """Freeze the shape.
+
+         When the shape is frozen, it cannot be unfrozen, making it suitable to be shared.
+         """
+         self._frozen = True
+
+     def copy(self, frozen: bool = False):
+         """Return a copy of the shape."""
+         return Shape(self._dims, self._denotations, frozen=frozen)
+
      def rank(self) -> int:
-         """The rank of the shape."""
+         """The rank of the tensor this shape represents."""
          return len(self._dims)

      def numpy(self) -> tuple[int, ...]:
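freeze() is one-way; copy() is the escape hatch for editing a shared shape. Sketch::

    import onnx_ir as ir

    frozen = ir.Shape([1, 3], frozen=True)
    try:
        frozen[0] = 2
    except TypeError:
        pass  # frozen shapes reject mutation

    mutable = frozen.copy()  # copies are unfrozen unless frozen=True is passed
    mutable[0] = 2
    assert mutable == [2, 3]
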
@@ -928,12 +1175,8 @@ class Shape(_protocols.ShapeProtocol, _display.PrettyPrintable):
          """
          if self._frozen:
              raise TypeError("The shape is frozen and cannot be modified.")
-         if isinstance(value, str) or value is None:
-             value = SymbolicDim(value)
-         if not isinstance(value, (int, SymbolicDim)):
-             raise TypeError(f"Expected int, str, None or SymbolicDim, got '{type(value)}'")

-         self._dims[index] = value
+         self._dims[index] = _maybe_convert_to_symbolic_dim(value)

      def get_denotation(self, index: int) -> str | None:
          """Return the denotation of the dimension at the index.
@@ -968,7 +1211,7 @@ class Shape(_protocols.ShapeProtocol, _display.PrettyPrintable):
      def __eq__(self, other: object) -> bool:
          """Return True if the shapes are equal.

-         Two shapes are eqaul if all their dimensions are equal.
+         Two shapes are equal if all their dimensions are equal.
          """
          if isinstance(other, Shape):
              return self._dims == other._dims
@@ -979,6 +1222,33 @@ class Shape(_protocols.ShapeProtocol, _display.PrettyPrintable):
      def __ne__(self, other: object) -> bool:
          return not self.__eq__(other)

+     @typing.overload
+     def is_static(self, dim: int) -> bool:  # noqa: D418
+         """Return True if the dimension is static."""
+
+     @typing.overload
+     def is_static(self) -> bool:  # noqa: D418
+         """Return True if all dimensions are static."""
+
+     def is_static(self, dim=None) -> bool:
+         """Return True if the dimension is static. If dim is None, return True if all dimensions are static."""
+         if dim is None:
+             return all(isinstance(dim, int) for dim in self._dims)
+         return isinstance(self[dim], int)
+
+     @typing.overload
+     def is_dynamic(self, dim: int) -> bool:  # noqa: D418
+         """Return True if the dimension is dynamic."""
+
+     @typing.overload
+     def is_dynamic(self) -> bool:  # noqa: D418
+         """Return True if any dimension is dynamic."""
+
+     def is_dynamic(self, dim=None) -> bool:
+         if dim is None:
+             return not self.is_static()
+         return not self.is_static(dim)
+

  def _quoted(string: str) -> str:
      """Return a quoted string.
@@ -988,6 +1258,35 @@ def _quoted(string: str) -> str:
      return f'"{string}"'


+ class Usage(NamedTuple):
+     """A usage of a value in a node.
+
+     Attributes:
+         node: The node that uses the value.
+         idx: The input index of the value in the node.
+     """
+
+     node: Node
+     idx: int
+
+
+ def _short_tensor_str_for_node(x: Value) -> str:
+     if x.const_value is None:
+         return ""
+     if x.const_value.size <= 10:
+         try:
+             data = x.const_value.numpy().tolist()
+         except Exception:  # pylint: disable=broad-except
+             return "{...}"
+         return f"{{{data}}}"
+     return "{...}"
+
+
+ def _normalize_domain(domain: str) -> str:
+     """Normalize 'ai.onnx' to ''."""
+     return "" if domain == "ai.onnx" else domain
+
+
  class Node(_protocols.NodeProtocol, _display.PrettyPrintable):
      """IR Node.

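Domain normalization means `""` and `"ai.onnx"` can no longer diverge in comparisons or opset lookups. Sketch::

    import onnx_ir as ir

    node = ir.Node("ai.onnx", "Add", inputs=[None, None])
    assert node.domain == ""    # normalized at construction
    node.domain = "ai.onnx"
    assert node.domain == ""    # and again on assignment
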
@@ -1001,6 +1300,9 @@ class Node(_protocols.NodeProtocol, _display.PrettyPrintable):
      To change the output values, create a new node and replace the each of the inputs of ``output.uses()`` with
      the new output values by calling :meth:`replace_input_with` on the using nodes
      of this node's outputs.
+
+     .. note:
+         When the ``domain`` is `"ai.onnx"`, it is normalized to `""`.
      """

      __slots__ = (
@@ -1023,13 +1325,13 @@ class Node(_protocols.NodeProtocol, _display.PrettyPrintable):
          domain: str,
          op_type: str,
          inputs: Iterable[Value | None],
-         attributes: Iterable[Attr | RefAttr] = (),
+         attributes: Iterable[Attr] = (),
          *,
          overload: str = "",
          num_outputs: int | None = None,
          outputs: Sequence[Value] | None = None,
          version: int | None = None,
-         graph: Graph | None = None,
+         graph: Graph | Function | None = None,
          name: str | None = None,
          doc_string: str | None = None,
          metadata_props: dict[str, str] | None = None,
@@ -1038,27 +1340,30 @@ class Node(_protocols.NodeProtocol, _display.PrettyPrintable):

          Args:
              domain: The domain of the operator. For onnx operators, this is an empty string.
+                 When it is `"ai.onnx"`, it is normalized to `""`.
              op_type: The name of the operator.
-             inputs: The input values. When an input is None, it is an empty input.
+             inputs: The input values. When an input is ``None``, it is an empty input.
              attributes: The attributes. RefAttr can be used only when the node is defined in a Function.
              overload: The overload name when the node is invoking a function.
              num_outputs: The number of outputs of the node. If not specified, the number is 1.
-             outputs: The output values. If None, the outputs are created during initialization.
-             version: The version of the operator. If None, the version is unspecified and will follow that of the graph.
-             graph: The graph that the node belongs to. If None, the node is not added to any graph.
-                 A `Node` must belong to zero or one graph.
-             name: The name of the node. If None, the node is anonymous.
+             outputs: The output values. If ``None``, the outputs are created during initialization.
+             version: The version of the operator. If ``None``, the version is unspecified and will follow that of the graph.
+             graph: The graph that the node belongs to. If ``None``, the node is not added to any graph.
+                 A `Node` must belong to zero or one graph. If a :class:`Function`, the underlying graph
+                 of the function is assigned to the node.
+             name: The name of the node. If ``None``, the node is anonymous. The name may be
+                 set by a :class:`Graph` if ``graph`` is specified.
              doc_string: The documentation string.
              metadata_props: The metadata properties.

          Raises:
-             TypeError: If the attributes are not Attr or RefAttr.
-             ValueError: If `num_outputs`, when not None, is not the same as the length of the outputs.
-             ValueError: If an output value is None, when outputs is specified.
+             TypeError: If the attributes are not :class:`Attr`.
+             ValueError: If ``num_outputs``, when not ``None``, is not the same as the length of the outputs.
+             ValueError: If an output value is ``None``, when outputs is specified.
              ValueError: If an output value has a producer set already, when outputs is specified.
          """
          self._name = name
-         self._domain: str = domain
+         self._domain: str = _normalize_domain(domain)
          self._op_type: str = op_type
          # NOTE: Make inputs immutable with the assumption that they are not mutated
          # very often. This way all mutations can be tracked.
@@ -1067,13 +1372,13 @@ class Node(_protocols.NodeProtocol, _display.PrettyPrintable):

          # Values belong to their defining nodes. The values list is immutable
          self._outputs: tuple[Value, ...] = self._create_outputs(num_outputs, outputs)
          attributes = tuple(attributes)
-         if attributes and not isinstance(attributes[0], (Attr, RefAttr)):
+         if attributes and not isinstance(attributes[0], Attr):
              raise TypeError(
-                 f"Expected the attributes to be Attr or RefAttr, got {type(attributes[0])}. "
+                 f"Expected the attributes to be Attr, got {type(attributes[0])}. "
                  "If you are copying the attributes from another node, make sure you call "
                  "node.attributes.values() because it is a dictionary."
              )
-         self._attributes: OrderedDict[str, Attr | RefAttr] = OrderedDict(
+         self._attributes: OrderedDict[str, Attr] = OrderedDict(
              (attr.name, attr) for attr in attributes
          )
@@ -1081,7 +1386,11 @@ class Node(_protocols.NodeProtocol, _display.PrettyPrintable):
          self._version: int | None = version
          self._metadata: _metadata.MetadataStore | None = None
          self._metadata_props: dict[str, str] | None = metadata_props
-         self._graph: Graph | None = graph
+         # _graph is set by graph.append
+         self._graph: Graph | None = None
+         # Add the node to the graph if graph is specified
+         if graph is not None:
+             graph.append(self)
          self.doc_string = doc_string

          # Add the node as a use of the inputs
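Routing `graph=` through `graph.append()` means the graph's name authority can name the node (and its outputs) at construction time. Sketch (the constructor arguments are illustrative)::

    import onnx_ir as ir

    graph = ir.Graph(inputs=[], outputs=[], nodes=[], opset_imports={"": 20})
    node = ir.Node("", "Constant", inputs=[], graph=graph)
    assert node.graph is graph
    assert node in graph  # Graph is a Sequence[Node]
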
@@ -1089,10 +1398,6 @@ class Node(_protocols.NodeProtocol, _display.PrettyPrintable):
              if input_value is not None:
                  input_value._add_usage(self, i)  # pylint: disable=protected-access

-         # Add the node to the graph if graph is specified
-         if self._graph is not None:
-             self._graph.append(self)
-
      def _create_outputs(
          self, num_outputs: int | None, outputs: Sequence[Value] | None
      ) -> tuple[Value, ...]:
@@ -1150,7 +1455,7 @@ class Node(_protocols.NodeProtocol, _display.PrettyPrintable):
              + ", ".join(
                  [
                      (
-                         f"%{_quoted(x.name) if x.name else 'anonymous:' + str(id(x))}"
+                         f"%{_quoted(x.name) if x.name else 'anonymous:' + str(id(x))}{_short_tensor_str_for_node(x)}"
                          if x is not None
                          else "None"
                      )
@@ -1178,6 +1483,7 @@ class Node(_protocols.NodeProtocol, _display.PrettyPrintable):

      @property
      def name(self) -> str | None:
+         """Optional name of the node."""
          return self._name

      @name.setter
@@ -1186,14 +1492,26 @@ class Node(_protocols.NodeProtocol, _display.PrettyPrintable):

      @property
      def domain(self) -> str:
+         """The domain of the operator. For onnx operators, this is an empty string.
+
+         .. note:
+             When domain is `"ai.onnx"`, it is normalized to `""`.
+         """
          return self._domain

      @domain.setter
      def domain(self, value: str) -> None:
-         self._domain = value
+         self._domain = _normalize_domain(value)

      @property
      def version(self) -> int | None:
+         """Opset version of the operator called.
+
+         If ``None``, the version is unspecified and will follow that of the graph.
+         This property is special to ONNX IR to allow mixed opset usage in a graph
+         for supporting more flexible graph transformations. It does not exist in the ONNX
+         serialization (protobuf) spec.
+         """
          return self._version

      @version.setter
@@ -1202,6 +1520,7 @@ class Node(_protocols.NodeProtocol, _display.PrettyPrintable):

      @property
      def op_type(self) -> str:
+         """The name of the operator called."""
          return self._op_type

      @op_type.setter
@@ -1210,6 +1529,7 @@ class Node(_protocols.NodeProtocol, _display.PrettyPrintable):

      @property
      def overload(self) -> str:
+         """The overload name when the node is invoking a function."""
          return self._overload

      @overload.setter
@@ -1218,6 +1538,12 @@ class Node(_protocols.NodeProtocol, _display.PrettyPrintable):

      @property
      def inputs(self) -> Sequence[Value | None]:
+         """The input values of the node.
+
+         The inputs are immutable. To change the inputs, create a new node and
+         replace the inputs of the using nodes of this node's outputs by calling
+         :meth:`replace_input_with` on the using nodes of this node's outputs.
+         """
          return self._inputs

      @inputs.setter
@@ -1226,6 +1552,25 @@ class Node(_protocols.NodeProtocol, _display.PrettyPrintable):
              "Directly mutating the input sequence is unsupported. Please use Node.replace_input_with() instead."
          )

+     def predecessors(self) -> Sequence[Node]:
+         """Return the predecessor nodes of the node, deduplicated, in a deterministic order."""
+         # Use the ordered nature of a dictionary to deduplicate the nodes
+         predecessors: dict[Node, None] = {}
+         for value in self.inputs:
+             if value is not None and (producer := value.producer()) is not None:
+                 predecessors[producer] = None
+         return tuple(predecessors)
+
+     def successors(self) -> Sequence[Node]:
+         """Return the successor nodes of the node, deduplicated, in a deterministic order."""
+         # Use the ordered nature of a dictionary to deduplicate the nodes
+         successors: dict[Node, None] = {}
+         for value in self.outputs:
+             assert value is not None, "Bug: Output values are not expected to be None"
+             for usage in value.uses():
+                 successors[usage.node] = None
+         return tuple(successors)
+
      def replace_input_with(self, index: int, value: Value | None) -> None:
          """Replace an input with a new value."""
          if index < 0 or index >= len(self.inputs):
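predecessors()/successors() give deterministic, deduplicated one-hop traversal without touching uses() directly. Sketch::

    import onnx_ir as ir

    x = ir.Value(name="x")
    a = ir.Node("", "Relu", inputs=[x])
    b = ir.Node("", "Neg", inputs=[a.outputs[0]])
    assert b.predecessors() == (a,)
    assert a.successors() == (b,)
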
@@ -1279,6 +1624,12 @@ class Node(_protocols.NodeProtocol, _display.PrettyPrintable):

      @property
      def outputs(self) -> Sequence[Value]:
+         """The output values of the node.
+
+         The outputs are immutable. To change the outputs, create a new node and
+         replace the inputs of the using nodes of this node's outputs by calling
+         :meth:`replace_input_with` on the using nodes of this node's outputs.
+         """
          return self._outputs

      @outputs.setter
@@ -1286,7 +1637,8 @@ class Node(_protocols.NodeProtocol, _display.PrettyPrintable):
          raise AttributeError("outputs is immutable. Please create a new node instead.")

      @property
-     def attributes(self) -> OrderedDict[str, Attr | RefAttr]:
+     def attributes(self) -> OrderedDict[str, Attr]:
+         """The attributes of the node."""
          return self._attributes

      @property
@@ -1302,12 +1654,21 @@ class Node(_protocols.NodeProtocol, _display.PrettyPrintable):

      @property
      def metadata_props(self) -> dict[str, str]:
+         """The metadata properties of the node.
+
+         The metadata properties are used to store additional information about the node.
+         Unlike ``meta``, this property is serialized to the ONNX proto.
+         """
          if self._metadata_props is None:
              self._metadata_props = {}
          return self._metadata_props

      @property
      def graph(self) -> Graph | None:
+         """The graph that the node belongs to.
+
+         If the node is not added to any graph, this property is None.
+         """
          return self._graph

      @graph.setter
@@ -1315,9 +1676,17 @@ class Node(_protocols.NodeProtocol, _display.PrettyPrintable):
          self._graph = value

      def op_identifier(self) -> _protocols.OperatorIdentifier:
+         """Return the operator identifier of the node.
+
+         The operator identifier is a tuple of the domain, op_type and overload.
+         """
          return self.domain, self.op_type, self.overload

      def display(self, *, page: bool = False) -> None:
+         """Pretty print the node.
+
+         This method is used for debugging and visualization purposes.
+         """
          # Add the node's name to the displayed text
          print(f"Node: {self.name!r}")
          if self.doc_string:
@@ -1344,7 +1713,7 @@ class _TensorTypeBase(_protocols.TypeProtocol, _display.PrettyPrintable, Hashabl

      @property
      def elem_type(self) -> _enums.DataType:
-         """Return the element type of the tensor type"""
+         """Return the element type of the tensor type."""
          return self.dtype

      def __hash__(self) -> int:
@@ -1438,18 +1807,19 @@ class Value(_protocols.ValueProtocol, _display.PrettyPrintable):

      To find all the nodes that use this value as an input, call :meth:`uses`.

-     To check if the value is an output of a graph, call :meth:`is_graph_output`.
+     To check if the value is an is an input, output or initializer of a graph,
+     use :meth:`is_graph_input`, :meth:`is_graph_output` or :meth:`is_initializer`.

-     Attributes:
-         name: The name of the value. A value is always named when it is part of a graph.
-         shape: The shape of the value.
-         type: The type of the value.
-         metadata_props: Metadata.
+     Use :meth:`graph` to get the graph that owns the value.
      """

      __slots__ = (
          "_const_value",
+         "_graph",
          "_index",
+         "_is_graph_input",
+         "_is_graph_output",
+         "_is_initializer",
          "_metadata",
          "_metadata_props",
          "_name",
@@ -1497,19 +1867,33 @@ class Value(_protocols.ValueProtocol, _display.PrettyPrintable):
          # Use a collection of (Node, int) to store uses. This is needed
          # because a single use can use the same value multiple times.
          # Use a dictionary to preserve insertion order so that the visiting order is deterministic
-         self._uses: dict[tuple[Node, int], None] = {}
+         self._uses: dict[Usage, None] = {}
          self.doc_string = doc_string

+         # The graph this value belongs to. It is set *only* when the value is added as
+         # a graph input, output or initializer.
+         # The four properties can only be set by the Graph class (_GraphIO and GraphInitializers).
+         self._graph: Graph | None = None
+         self._is_graph_input: bool = False
+         self._is_graph_output: bool = False
+         self._is_initializer: bool = False
+
      def __repr__(self) -> str:
          value_name = self.name if self.name else "anonymous:" + str(id(self))
+         type_text = f", type={self.type!r}" if self.type is not None else ""
+         shape_text = f", shape={self.shape!r}" if self.shape is not None else ""
          producer = self.producer()
          if producer is None:
-             producer_text = "None"
+             producer_text = ""
          elif producer.name is not None:
-             producer_text = producer.name
+             producer_text = f", producer='{producer.name}'"
          else:
-             producer_text = f"anonymous_node:{id(producer)}"
-         return f"{self.__class__.__name__}({value_name!r}, type={self.type!r}, shape={self.shape}, producer={producer_text}, index={self.index()})"
+             producer_text = f", producer=anonymous_node:{id(producer)}"
+         index_text = f", index={self.index()}" if self.index() is not None else ""
+         const_value_text = self._constant_tensor_part()
+         if const_value_text:
+             const_value_text = f", const_value={const_value_text}"
+         return f"{self.__class__.__name__}(name={value_name!r}{type_text}{shape_text}{producer_text}{index_text}{const_value_text})"

      def __str__(self) -> str:
          value_name = self.name if self.name is not None else "anonymous:" + str(id(self))
@@ -1518,41 +1902,85 @@ class Value(_protocols.ValueProtocol, _display.PrettyPrintable):

          # Quote the name because in reality the names can have invalid characters
          # that make them hard to read
-         return f"%{_quoted(value_name)}<{type_text},{shape_text}>"
+         return (
+             f"%{_quoted(value_name)}<{type_text},{shape_text}>{self._constant_tensor_part()}"
+         )
+
+     def _constant_tensor_part(self) -> str:
+         """Display string for the constant tensor attached to str of Value."""
+         if self.const_value is not None:
+             # Only display when the const value is small
+             if self.const_value.size <= 10:
+                 return f"{{{self.const_value}}}"
+             else:
+                 return f"{{{self.const_value.__class__.__name__}(...)}}"
+         return ""
+
+     @property
+     def graph(self) -> Graph | None:
+         """Return the graph that defines this value.
+
+         When the value is an input/output/initializer of a graph, the owning graph
+         is that graph. When the value is an output of a node, the owning graph is the
+         graph that the node belongs to. When the value is not owned by any graph,
+         it returns ``None``.
+         """
+         if self._graph is not None:
+             return self._graph
+         if self._producer is not None:
+             return self._producer.graph
+         return None
+
+     def _owned_by_graph(self) -> bool:
+         """Return True if the value is owned by a graph."""
+         result = self._is_graph_input or self._is_graph_output or self._is_initializer
+         if result:
+             assert self._graph is not None
+         return result

      def producer(self) -> Node | None:
          """The node that produces this value.

          When producer is ``None``, the value does not belong to a node, and is
-         typically a graph input or an initializer.
+         typically a graph input or an initializer. You can use :meth:`graph``
+         to find the graph that owns this value. Use :meth:`is_graph_input`, :meth:`is_graph_output`
+         or :meth:`is_initializer` to check if the value is an input, output or initializer of a graph.
          """
          return self._producer

+     def consumers(self) -> Sequence[Node]:
+         """Return the nodes (deduplicated) that consume this value."""
+         return tuple({usage.node: None for usage in self._uses})
+
      def index(self) -> int | None:
          """The index of the output of the defining node."""
          return self._index

-     def uses(self) -> Collection[tuple[Node, int]]:
+     def uses(self) -> Collection[Usage]:
          """Return a set of uses of the value.

          The set contains tuples of ``(Node, index)`` where the index is the index of the input
          of the node. For example, if ``node.inputs[1] == value``, then the use is ``(node, 1)``.
          """
-         return self._uses.keys()
+         # Create a tuple for the collection so that iteration on will will not
+         # be affected when the usage changes during graph mutation.
+         # This adds a small overhead but is better a user experience than
+         # having users call tuple().
+         return tuple(self._uses)

      def _add_usage(self, use: Node, index: int) -> None:
          """Add a usage of this value.

          This is an internal method. It should only be called by the Node class.
          """
-         self._uses[(use, index)] = None
+         self._uses[Usage(use, index)] = None

      def _remove_usage(self, use: Node, index: int) -> None:
          """Remove a node from the uses of this value.

          This is an internal method. It should only be called by the Node class.
          """
-         self._uses.pop((use, index))
+         self._uses.pop(Usage(use, index))

      @property
      def name(self) -> str | None:
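Because uses() now snapshots into a tuple of Usage entries, a rewiring loop cannot invalidate its own iterator. Sketch::

    import onnx_ir as ir

    x = ir.Value(name="x")
    y = ir.Value(name="y")
    node = ir.Node("", "Identity", inputs=[x])
    for usage in x.uses():  # tuple snapshot; safe to mutate while iterating
        usage.node.replace_input_with(usage.idx, y)
    assert x.uses() == () and node.inputs[0] is y
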
@@ -1652,15 +2080,17 @@ class Value(_protocols.ValueProtocol, _display.PrettyPrintable):
              self._metadata_props = {}
          return self._metadata_props

+     def is_graph_input(self) -> bool:
+         """Whether the value is an input of a graph."""
+         return self._is_graph_input
+
      def is_graph_output(self) -> bool:
          """Whether the value is an output of a graph."""
-         if (producer := self.producer()) is None:
-             return False
-         if (graph := producer.graph) is None:
-             return False
-         # Cannot use `in` because __eq__ may be defined by subclasses, even though
-         # it is not recommended
-         return any(output is self for output in graph.outputs)
+         return self._is_graph_output
+
+     def is_initializer(self) -> bool:
+         """Whether the value is an initializer of a graph."""
+         return self._is_initializer


  def Input(
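Graph membership is now a set of flags maintained by the graph containers rather than a scan over graph.outputs. Sketch::

    import onnx_ir as ir

    inp = ir.Value(name="x")
    graph = ir.Graph(inputs=[inp], outputs=[], nodes=[], opset_imports={"": 20})
    assert inp.is_graph_input()
    assert not inp.is_graph_output() and not inp.is_initializer()
    assert inp.graph is graph
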
@@ -1673,7 +2103,6 @@ def Input(

      This is equivalent to calling ``Value(name=name, shape=shape, type=type, doc_string=doc_string)``.
      """
-
      # NOTE: The function name is capitalized to maintain API backward compatibility.

      return Value(name=name, shape=shape, type=type, doc_string=doc_string)
@@ -1770,9 +2199,9 @@ class Graph(_protocols.GraphProtocol, Sequence[Node], _display.PrettyPrintable):
          self.name = name

          # Private fields that are not to be accessed by any other classes
-         self._inputs = list(inputs)
-         self._outputs = list(outputs)
-         self._initializers = {}
+         self._inputs = _graph_containers.GraphInputs(self, inputs)
+         self._outputs = _graph_containers.GraphOutputs(self, outputs)
+         self._initializers = _graph_containers.GraphInitializers(self)
          for initializer in initializers:
              if isinstance(initializer, str):
                  raise TypeError(
@@ -1791,21 +2220,59 @@ class Graph(_protocols.GraphProtocol, Sequence[Node], _display.PrettyPrintable):
1791
2220
  # Be sure to initialize the name authority before extending the nodes
1792
2221
  # because it is used to name the nodes and their outputs
1793
2222
  self._name_authority = _name_authority.NameAuthority()
2223
+ # TODO(justinchuby): Trigger again if inputs or initializers are modified.
2224
+ self._set_input_and_initializer_value_names_into_name_authority()
1794
2225
  # Call self.extend not self._nodes.extend so the graph reference is added to the nodes
1795
2226
  self.extend(nodes)
1796
2227
 
1797
2228
  @property
1798
- def inputs(self) -> list[Value]:
2229
+ def inputs(self) -> MutableSequence[Value]:
1799
2230
  return self._inputs
1800
2231
 
1801
2232
  @property
1802
- def outputs(self) -> list[Value]:
2233
+ def outputs(self) -> MutableSequence[Value]:
1803
2234
  return self._outputs
1804
2235
 
1805
2236
  @property
1806
- def initializers(self) -> dict[str, Value]:
2237
+ def initializers(self) -> MutableMapping[str, Value]:
1807
2238
  return self._initializers
1808
2239
 
2240
+ def register_initializer(self, value: Value) -> None:
2241
+ """Register an initializer to the graph.
2242
+
2243
+ This is a convenience method that registers an initializer with the graph
2244
+ after validation checks.
2245
+
2246
+ Args:
2247
+ value: The :class:`Value` to register as an initializer of the graph.
2248
+ It must have its ``.const_value`` set.
2249
+
2250
+ Raises:
2251
+ ValueError: If a value of the same name that is not this value
2252
+ is already registered.
2253
+ ValueError: If the value does not have a name.
2254
+ ValueError: If the initializer is produced by a node.
2255
+ ValueError: If the value does not have its ``.const_value`` set.
2256
+ """
2257
+ if not value.name:
2258
+ raise ValueError(f"Initializer must have a name: {value!r}")
2259
+ if value.name in self._initializers:
2260
+ if self._initializers[value.name] is not value:
2261
+ raise ValueError(
2262
+ f"Initializer '{value.name}' is already registered, but"
2263
+ " it is not the same object: existing={self._initializers[value.name]!r},"
2264
+ f" new={value!r}"
2265
+ )
2266
+ if value.producer() is not None:
2267
+ raise ValueError(
2268
+ f"Value '{value!r}' is produced by a node and cannot be an initializer."
2269
+ )
2270
+ if value.const_value is None:
2271
+ raise ValueError(
2272
+ f"Value '{value!r}' must have its const_value set to be an initializer."
2273
+ )
2274
+ self._initializers[value.name] = value
2275
+
1809
2276
  @property
1810
2277
  def doc_string(self) -> str | None:
1811
2278
  return self._doc_string
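A minimal usage sketch (hypothetical `graph`; assumes the `ir.tensor` convenience constructor added in `_convenience/_constructors.py`):

    import numpy as np
    import onnx_ir as ir

    w = ir.Value(name="w")
    w.const_value = ir.tensor(np.zeros((2, 2), dtype=np.float32))  # required before registering
    graph.register_initializer(w)  # raises if unnamed, produced by a node, or missing const_value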
@@ -1818,7 +2285,12 @@ class Graph(_protocols.GraphProtocol, Sequence[Node], _display.PrettyPrintable):
1818
2285
  def opset_imports(self) -> dict[str, int]:
1819
2286
  return self._opset_imports
1820
2287
 
1821
- def __getitem__(self, index: int) -> Node:
2288
+ @typing.overload
2289
+ def __getitem__(self, index: int) -> Node: ...
2290
+ @typing.overload
2291
+ def __getitem__(self, index: slice) -> Sequence[Node]: ...
2292
+
2293
+ def __getitem__(self, index):
1822
2294
  return self._nodes[index]
1823
2295
 
1824
2296
  def __len__(self) -> int:
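With the overloads, static checkers see the precise return type for each index form. A minimal sketch (hypothetical `graph` with at least three nodes):

    first = graph[0]          # typed as Node
    first_three = graph[0:3]  # typed as Sequence[Node]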
@@ -1830,6 +2302,12 @@ class Graph(_protocols.GraphProtocol, Sequence[Node], _display.PrettyPrintable):
1830
2302
  def __reversed__(self) -> Iterator[Node]:
1831
2303
  return reversed(self._nodes)
1832
2304
 
2305
+ def _set_input_and_initializer_value_names_into_name_authority(self):
2306
+ for value in self.inputs:
2307
+ self._name_authority.register_or_name_value(value)
2308
+ for value in self.initializers.values():
2309
+ self._name_authority.register_or_name_value(value)
2310
+
1833
2311
  def _set_node_graph_to_self_and_assign_names(self, node: Node) -> Node:
1834
2312
  """Set the graph reference for the node and assign names to it and its outputs if they don't have one."""
1835
2313
  if node.graph is not None and node.graph is not self:
@@ -1993,7 +2471,7 @@ class Graph(_protocols.GraphProtocol, Sequence[Node], _display.PrettyPrintable):
1993
2471
 
1994
2472
  This sort is stable. It preserves the original order as much as possible.
1995
2473
 
1996
- Referece: https://github.com/madelson/MedallionTopologicalSort#stable-sort
2474
+ Reference: https://github.com/madelson/MedallionTopologicalSort#stable-sort
1997
2475
 
1998
2476
  Raises:
1999
2477
  ValueError: If the graph contains a cycle, making topological sorting impossible.
@@ -2004,7 +2482,7 @@ class Graph(_protocols.GraphProtocol, Sequence[Node], _display.PrettyPrintable):
2004
2482
  sorted_nodes_by_graph: dict[Graph, list[Node]] = {
2005
2483
  graph: [] for graph in {node.graph for node in nodes if node.graph is not None}
2006
2484
  }
2007
- # TODO: Explain why we need to store direct predecessors and children and why
2485
+ # TODO(justinchuby): Explain why we need to store direct predecessors and children and why
2008
2486
  # we only need to store the direct ones
2009
2487
 
2010
2488
  # The depth of a node is defined as the number of direct children it has
@@ -2024,7 +2502,7 @@ class Graph(_protocols.GraphProtocol, Sequence[Node], _display.PrettyPrintable):
2024
2502
  node_depth[predecessor] += 1
2025
2503
 
2026
2504
  # 1. Build the direct predecessors of each node and the depth of each node
2027
- # for sorting topolocally using Kahn's algorithm.
2505
+ # for sorting topologically using Kahn's algorithm.
2028
2506
  # Note that when a node contains graph attributes (aka. has subgraphs),
2029
2507
  # we consider all nodes in the subgraphs *predecessors* of this node. This
2030
2508
  # way we ensure the implicit dependencies of the subgraphs are captured
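For reference, a generic form of Kahn's algorithm over an explicit successor map; this is an illustrative sketch, not the library's implementation (which additionally folds subgraph nodes into the predecessor sets as described above):

    from collections import deque

    def kahn_sort(successors: dict[str, list[str]]) -> list[str]:
        # Count direct predecessors (in-degree) of every node
        indegree = {n: 0 for n in successors}
        for node, children in successors.items():
            for child in children:
                indegree[child] += 1
        queue = deque(n for n, d in indegree.items() if d == 0)
        order = []
        while queue:
            node = queue.popleft()
            order.append(node)
            for child in successors[node]:
                indegree[child] -= 1
                if indegree[child] == 0:
                    queue.append(child)
        if len(order) != len(successors):
            raise ValueError("graph contains a cycle")  # mirrors the ValueError documented above
        return order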
@@ -2125,11 +2603,11 @@ def _graph_str(graph: Graph | GraphView) -> str:
2125
2603
  )
2126
2604
  signature = f"""\
2127
2605
  graph(
2128
- name={graph.name or 'anonymous_graph:' + str(id(graph))},
2129
- inputs=({textwrap.indent(inputs_text, ' ' * 8)}
2606
+ name={graph.name or "anonymous_graph:" + str(id(graph))},
2607
+ inputs=({textwrap.indent(inputs_text, " " * 8)}
2130
2608
  ),
2131
- outputs=({textwrap.indent(outputs_text, ' ' * 8)}
2132
- ),{textwrap.indent(initializers_text, ' ' * 4)}
2609
+ outputs=({textwrap.indent(outputs_text, " " * 8)}
2610
+ ),{textwrap.indent(initializers_text, " " * 4)}
2133
2611
  )"""
2134
2612
  node_count = len(graph)
2135
2613
  number_width = len(str(node_count))
@@ -2163,11 +2641,11 @@ def _graph_repr(graph: Graph | GraphView) -> str:
2163
2641
  )
2164
2642
  return f"""\
2165
2643
  {graph.__class__.__name__}(
2166
- name={graph.name or 'anonymous_graph:' + str(id(graph))!r},
2167
- inputs=({textwrap.indent(inputs_text, ' ' * 8)}
2644
+ name={graph.name or "anonymous_graph:" + str(id(graph))!r},
2645
+ inputs=({textwrap.indent(inputs_text, " " * 8)}
2168
2646
  ),
2169
- outputs=({textwrap.indent(outputs_text, ' ' * 8)}
2170
- ),{textwrap.indent(initializers_text, ' ' * 4)}
2647
+ outputs=({textwrap.indent(outputs_text, " " * 8)}
2648
+ ),{textwrap.indent(initializers_text, " " * 4)}
2171
2649
  len()={len(graph)}
2172
2650
  )"""
2173
2651
 
@@ -2242,7 +2720,12 @@ class GraphView(Sequence[Node], _display.PrettyPrintable):
2242
2720
  self._metadata_props: dict[str, str] | None = metadata_props
2243
2721
  self._nodes: tuple[Node, ...] = tuple(nodes)
2244
2722
 
2245
- def __getitem__(self, index: int) -> Node:
2723
+ @typing.overload
2724
+ def __getitem__(self, index: int) -> Node: ...
2725
+ @typing.overload
2726
+ def __getitem__(self, index: slice) -> Sequence[Node]: ...
2727
+
2728
+ def __getitem__(self, index):
2246
2729
  return self._nodes[index]
2247
2730
 
2248
2731
  def __len__(self) -> int:
@@ -2318,7 +2801,7 @@ class Model(_protocols.ModelProtocol, _display.PrettyPrintable):
2318
2801
  model_version: int | None = None,
2319
2802
  doc_string: str | None = None,
2320
2803
  functions: Sequence[Function] = (),
2321
- meta_data_props: dict[str, str] | None = None,
2804
+ metadata_props: dict[str, str] | None = None,
2322
2805
  ) -> None:
2323
2806
  self.graph: Graph = graph
2324
2807
  self.ir_version = ir_version
@@ -2329,7 +2812,7 @@ class Model(_protocols.ModelProtocol, _display.PrettyPrintable):
2329
2812
  self.doc_string = doc_string
2330
2813
  self._functions = {func.identifier(): func for func in functions}
2331
2814
  self._metadata: _metadata.MetadataStore | None = None
2332
- self._metadata_props: dict[str, str] | None = meta_data_props
2815
+ self._metadata_props: dict[str, str] | None = metadata_props
2333
2816
 
2334
2817
  @property
2335
2818
  def functions(self) -> dict[_protocols.OperatorIdentifier, Function]:
@@ -2381,9 +2864,28 @@ Model(
2381
2864
  domain={self.domain!r},
2382
2865
  model_version={self.model_version!r},
2383
2866
  functions={self.functions!r},
2384
- graph={textwrap.indent(repr(self.graph), ' ' * 4).strip()}
2867
+ graph={textwrap.indent(repr(self.graph), " " * 4).strip()}
2385
2868
  )"""
2386
2869
 
2870
+ def graphs(self) -> Iterable[Graph]:
2871
+ """Get all graphs and subgraphs in the model.
2872
+
2873
+ This is a convenience method to traverse the model. Consider using
2874
+ `onnx_ir.traversal.RecursiveGraphIterator` for more advanced
2875
+ traversals on nodes.
2876
+ """
2877
+ # NOTE(justinchuby): Given
2878
+ # (1) how useful the method is
2879
+ # (2) I couldn't find an appropriate name for it in `traversal.py`
2880
+ # (3) Users familiar with onnxruntime optimization tools expect this method
2881
+ # I created this method as a core method instead of an iterator in
2882
+ # `traversal.py`.
2883
+ seen_graphs: set[Graph] = set()
2884
+ for node in onnx_ir.traversal.RecursiveGraphIterator(self.graph):
2885
+ if node.graph is not None and node.graph not in seen_graphs:
2886
+ seen_graphs.add(node.graph)
2887
+ yield node.graph
2888
+
2387
2889
 
2388
2890
  class Function(_protocols.FunctionProtocol, Sequence[Node], _display.PrettyPrintable):
2389
2891
  """IR functions.
@@ -2404,16 +2906,14 @@ class Function(_protocols.FunctionProtocol, Sequence[Node], _display.PrettyPrint
2404
2906
  outputs: The output values of the function.
2405
2907
  opset_imports: Opsets imported by the function.
2406
2908
  doc_string: Documentation string.
2407
- metadata_props: Metadata that will be serialized to the ONNX file.
2408
2909
  meta: Metadata store for graph transform passes.
2910
+ metadata_props: Metadata that will be serialized to the ONNX file.
2409
2911
  """
2410
2912
 
2411
2913
  __slots__ = (
2412
2914
  "_attributes",
2413
2915
  "_domain",
2414
2916
  "_graph",
2415
- "_metadata",
2416
- "_metadata_props",
2417
2917
  "_name",
2418
2918
  "_overload",
2419
2919
  )
@@ -2428,15 +2928,12 @@ class Function(_protocols.FunctionProtocol, Sequence[Node], _display.PrettyPrint
2428
2928
  # and not from an outer scope
2429
2929
  graph: Graph,
2430
2930
  attributes: Sequence[Attr],
2431
- metadata_props: dict[str, str] | None = None,
2432
2931
  ) -> None:
2433
2932
  self._domain = domain
2434
2933
  self._name = name
2435
2934
  self._overload = overload
2436
2935
  self._graph = graph
2437
2936
  self._attributes = OrderedDict((attr.name, attr) for attr in attributes)
2438
- self._metadata: _metadata.MetadataStore | None = None
2439
- self._metadata_props: dict[str, str] | None = metadata_props
2440
2937
 
2441
2938
  def identifier(self) -> _protocols.OperatorIdentifier:
2442
2939
  return self.domain, self.name, self.overload
@@ -2455,7 +2952,7 @@ class Function(_protocols.FunctionProtocol, Sequence[Node], _display.PrettyPrint
2455
2952
 
2456
2953
  @domain.setter
2457
2954
  def domain(self, value: str) -> None:
2458
- self._domain = value
2955
+ self._domain = _normalize_domain(value)
2459
2956
 
2460
2957
  @property
2461
2958
  def overload(self) -> str:
@@ -2466,18 +2963,23 @@ class Function(_protocols.FunctionProtocol, Sequence[Node], _display.PrettyPrint
2466
2963
  self._overload = value
2467
2964
 
2468
2965
  @property
2469
- def inputs(self) -> list[Value]:
2966
+ def inputs(self) -> MutableSequence[Value]:
2470
2967
  return self._graph.inputs
2471
2968
 
2472
2969
  @property
2473
- def outputs(self) -> list[Value]:
2970
+ def outputs(self) -> MutableSequence[Value]:
2474
2971
  return self._graph.outputs
2475
2972
 
2476
2973
  @property
2477
2974
  def attributes(self) -> OrderedDict[str, Attr]:
2478
2975
  return self._attributes
2479
2976
 
2480
- def __getitem__(self, index: int) -> Node:
2977
+ @typing.overload
2978
+ def __getitem__(self, index: int) -> Node: ...
2979
+ @typing.overload
2980
+ def __getitem__(self, index: slice) -> Sequence[Node]: ...
2981
+
2982
+ def __getitem__(self, index):
2481
2983
  return self._graph.__getitem__(index)
2482
2984
 
2483
2985
  def __len__(self) -> int:
@@ -2508,15 +3010,11 @@ class Function(_protocols.FunctionProtocol, Sequence[Node], _display.PrettyPrint
2508
3010
  Write to the :attr:`metadata_props` if you would like the metadata to be serialized
2509
3011
  to the ONNX proto.
2510
3012
  """
2511
- if self._metadata is None:
2512
- self._metadata = _metadata.MetadataStore()
2513
- return self._metadata
3013
+ return self._graph.meta
2514
3014
 
2515
3015
  @property
2516
3016
  def metadata_props(self) -> dict[str, str]:
2517
- if self._metadata_props is None:
2518
- self._metadata_props = {}
2519
- return self._metadata_props
3017
+ return self._graph.metadata_props
2520
3018
 
2521
3019
  # Mutation methods
2522
3020
  def append(self, node: Node, /) -> None:
@@ -2549,11 +3047,11 @@ class Function(_protocols.FunctionProtocol, Sequence[Node], _display.PrettyPrint
2549
3047
  """
2550
3048
  self._graph.remove(nodes, safe=safe)
2551
3049
 
2552
- def insert_after(self, node: Node, new_nodes: Iterable[Node], /) -> None:
3050
+ def insert_after(self, node: Node, new_nodes: Iterable[Node] | Node, /) -> None:
2553
3051
  """Insert new nodes after the given node in O(#new_nodes) time."""
2554
3052
  self._graph.insert_after(node, new_nodes)
2555
3053
 
2556
- def insert_before(self, node: Node, new_nodes: Iterable[Node], /) -> None:
3054
+ def insert_before(self, node: Node, new_nodes: Iterable[Node] | Node, /) -> None:
2557
3055
  """Insert new nodes before the given node in O(#new_nodes) time."""
2558
3056
  self._graph.insert_before(node, new_nodes)
2559
3057
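Since `new_nodes` may now be a single `Node`, one-node edits no longer need a wrapping list. A minimal sketch (hypothetical `function` and `conv`; assumes the `ir.node` constructor from `_convenience/_constructors.py`):

    import onnx_ir as ir

    relu = ir.node("Relu", inputs=[conv.outputs[0]])
    function.insert_after(conv, relu)  # previously required [relu]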
 
@@ -2581,10 +3079,10 @@ class Function(_protocols.FunctionProtocol, Sequence[Node], _display.PrettyPrint
2581
3079
  >
2582
3080
  def {full_name}(
2583
3081
  inputs=(
2584
- {textwrap.indent(inputs_text, ' ' * 8)}
2585
- ),{textwrap.indent(attributes_text, ' ' * 4)}
3082
+ {textwrap.indent(inputs_text, " " * 8)}
3083
+ ),{textwrap.indent(attributes_text, " " * 4)}
2586
3084
  outputs=(
2587
- {textwrap.indent(outputs_text, ' ' * 8)}
3085
+ {textwrap.indent(outputs_text, " " * 8)}
2588
3086
  ),
2589
3087
  )"""
2590
3088
  node_count = len(self)
@@ -2611,22 +3109,28 @@ def {full_name}(
2611
3109
  return f"{self.__class__.__name__}({self.domain!r}, {self.name!r}, {self.overload!r}, inputs={self.inputs!r}, attributes={self.attributes!r}), outputs={self.outputs!r})"
2612
3110
 
2613
3111
 
2614
- class RefAttr(_protocols.ReferenceAttributeProtocol, _display.PrettyPrintable):
2615
- """Reference attribute."""
3112
+ class Attr(
3113
+ _protocols.AttributeProtocol,
3114
+ _protocols.ReferenceAttributeProtocol,
3115
+ _display.PrettyPrintable,
3116
+ ):
3117
+ """Base class for ONNX attributes or references."""
2616
3118
 
2617
- __slots__ = ("_name", "_ref_attr_name", "_type", "doc_string")
3119
+ __slots__ = ("_name", "_ref_attr_name", "_type", "_value", "doc_string")
2618
3120
 
2619
3121
  def __init__(
2620
3122
  self,
2621
3123
  name: str,
2622
- ref_attr_name: str,
2623
3124
  type: _enums.AttributeType,
3125
+ value: Any,
3126
+ ref_attr_name: str | None = None,
2624
3127
  *,
2625
3128
  doc_string: str | None = None,
2626
3129
  ) -> None:
2627
3130
  self._name = name
2628
- self._ref_attr_name = ref_attr_name
2629
3131
  self._type = type
3132
+ self._value = value
3133
+ self._ref_attr_name = ref_attr_name
2630
3134
  self.doc_string = doc_string
2631
3135
 
2632
3136
  @property
@@ -2637,43 +3141,21 @@ class RefAttr(_protocols.ReferenceAttributeProtocol, _display.PrettyPrintable):
2637
3141
  def name(self, value: str) -> None:
2638
3142
  self._name = value
2639
3143
 
2640
- @property
2641
- def ref_attr_name(self) -> str:
2642
- return self._ref_attr_name
2643
-
2644
- @ref_attr_name.setter
2645
- def ref_attr_name(self, value: str) -> None:
2646
- self._ref_attr_name = value
2647
-
2648
3144
  @property
2649
3145
  def type(self) -> _enums.AttributeType:
2650
3146
  return self._type
2651
3147
 
2652
- @type.setter
2653
- def type(self, value: _enums.AttributeType) -> None:
2654
- self._type = value
2655
-
2656
- def __repr__(self) -> str:
2657
- return f"{self.__class__.__name__}({self._name!r}, {self._type!r}, ref_attr_name={self.ref_attr_name!r})"
2658
-
2659
-
2660
- class Attr(_protocols.AttributeProtocol, _display.PrettyPrintable):
2661
- """Base class for ONNX attributes."""
3148
+ @property
3149
+ def value(self) -> Any:
3150
+ return self._value
2662
3151
 
2663
- __slots__ = ("doc_string", "name", "type", "value")
3152
+ @property
3153
+ def ref_attr_name(self) -> str | None:
3154
+ return self._ref_attr_name
2664
3155
 
2665
- def __init__(
2666
- self,
2667
- name: str,
2668
- type: _enums.AttributeType,
2669
- value: Any,
2670
- *,
2671
- doc_string: str | None = None,
2672
- ):
2673
- self.name = name
2674
- self.type = type
2675
- self.value = value
2676
- self.doc_string = doc_string
3156
+ def is_ref(self) -> bool:
3157
+ """Check if this attribute is a reference attribute."""
3158
+ return self.ref_attr_name is not None
2677
3159
 
2678
3160
  def __eq__(self, other: object) -> bool:
2679
3161
  if not isinstance(other, _protocols.AttributeProtocol):
@@ -2690,15 +3172,157 @@ class Attr(_protocols.AttributeProtocol, _display.PrettyPrintable):
2690
3172
  return True
2691
3173
 
2692
3174
  def __str__(self) -> str:
3175
+ if self.is_ref():
3176
+ return f"@{self.ref_attr_name}"
2693
3177
  if self.type == _enums.AttributeType.GRAPH:
2694
3178
  return textwrap.indent("\n" + str(self.value), " " * 4)
2695
3179
  return str(self.value)
2696
3180
 
2697
3181
  def __repr__(self) -> str:
3182
+ if self.is_ref():
3183
+ return f"{self.__class__.__name__}({self.name!r}, {self.type!r}, ref_attr_name={self.ref_attr_name!r})"
2698
3184
  return f"{self.__class__.__name__}({self.name!r}, {self.type!r}, {self.value!r})"
2699
3185
 
3186
+ # Well typed getters
3187
+ def as_float(self) -> float:
3188
+ """Get the attribute value as a float."""
3189
+ if self.type != _enums.AttributeType.FLOAT:
3190
+ raise TypeError(
3191
+ f"Attribute '{self.name}' is not of type FLOAT. Actual type: {self.type}"
3192
+ )
3193
+ # Do not use isinstance check because it may prevent np.float32 etc. from being used
3194
+ return float(self.value)
3195
+
3196
+ def as_int(self) -> int:
3197
+ """Get the attribute value as an int."""
3198
+ if self.type != _enums.AttributeType.INT:
3199
+ raise TypeError(
3200
+ f"Attribute '{self.name}' is not of type INT. Actual type: {self.type}"
3201
+ )
3202
+ # Do not use isinstance check because it may prevent np.int32 etc. from being used
3203
+ return int(self.value)
3204
+
3205
+ def as_string(self) -> str:
3206
+ """Get the attribute value as a string."""
3207
+ if self.type != _enums.AttributeType.STRING:
3208
+ raise TypeError(
3209
+ f"Attribute '{self.name}' is not of type STRING. Actual type: {self.type}"
3210
+ )
3211
+ if not isinstance(self.value, str):
3212
+ raise TypeError(f"Value of attribute '{self!r}' is not a string.")
3213
+ return self.value
3214
+
3215
+ def as_tensor(self) -> _protocols.TensorProtocol:
3216
+ """Get the attribute value as a tensor."""
3217
+ if self.type != _enums.AttributeType.TENSOR:
3218
+ raise TypeError(
3219
+ f"Attribute '{self.name}' is not of type TENSOR. Actual type: {self.type}"
3220
+ )
3221
+ if not isinstance(self.value, _protocols.TensorProtocol):
3222
+ raise TypeError(f"Value of attribute '{self!r}' is not a tensor.")
3223
+ return self.value
3224
+
3225
+ def as_graph(self) -> Graph:
3226
+ """Get the attribute value as a graph."""
3227
+ if self.type != _enums.AttributeType.GRAPH:
3228
+ raise TypeError(
3229
+ f"Attribute '{self.name}' is not of type GRAPH. Actual type: {self.type}"
3230
+ )
3231
+ if not isinstance(self.value, Graph):
3232
+ raise TypeError(f"Value of attribute '{self!r}' is not a graph.")
3233
+ return self.value
3234
+
3235
+ def as_floats(self) -> Sequence[float]:
3236
+ """Get the attribute value as a sequence of floats."""
3237
+ if self.type != _enums.AttributeType.FLOATS:
3238
+ raise TypeError(
3239
+ f"Attribute '{self.name}' is not of type FLOATS. Actual type: {self.type}"
3240
+ )
3241
+ if not isinstance(self.value, Sequence):
3242
+ raise TypeError(f"Value of attribute '{self!r}' is not a Sequence.")
3243
+ # Do not use isinstance check on elements because it may prevent np.float32 etc. from being used
3244
+ # Create a copy of the list to prevent mutation
3245
+ return [float(v) for v in self.value]
3246
+
3247
+ def as_ints(self) -> Sequence[int]:
3248
+ """Get the attribute value as a sequence of ints."""
3249
+ if self.type != _enums.AttributeType.INTS:
3250
+ raise TypeError(
3251
+ f"Attribute '{self.name}' is not of type INTS. Actual type: {self.type}"
3252
+ )
3253
+ if not isinstance(self.value, Sequence):
3254
+ raise TypeError(f"Value of attribute '{self!r}' is not a Sequence.")
3255
+ # Do not use isinstance check on elements because it may prevent np.int32 etc. from being used
3256
+ # Create a copy of the list to prevent mutation
3257
+ return list(self.value)
3258
+
3259
+ def as_strings(self) -> Sequence[str]:
3260
+ """Get the attribute value as a sequence of strings."""
3261
+ if self.type != _enums.AttributeType.STRINGS:
3262
+ raise TypeError(
3263
+ f"Attribute '{self.name}' is not of type STRINGS. Actual type: {self.type}"
3264
+ )
3265
+ if not isinstance(self.value, Sequence):
3266
+ raise TypeError(f"Value of attribute '{self!r}' is not a Sequence.")
3267
+ if onnx_ir.DEBUG:
3268
+ if not all(isinstance(x, str) for x in self.value):
3269
+ raise TypeError(f"Value of attribute '{self!r}' is not a Sequence of strings.")
3270
+ # Create a copy of the list to prevent mutation
3271
+ return list(self.value)
3272
+
3273
+ def as_tensors(self) -> Sequence[_protocols.TensorProtocol]:
3274
+ """Get the attribute value as a sequence of tensors."""
3275
+ if self.type != _enums.AttributeType.TENSORS:
3276
+ raise TypeError(
3277
+ f"Attribute '{self.name}' is not of type TENSORS. Actual type: {self.type}"
3278
+ )
3279
+ if not isinstance(self.value, Sequence):
3280
+ raise TypeError(f"Value of attribute '{self!r}' is not a Sequence.")
3281
+ if onnx_ir.DEBUG:
3282
+ if not all(isinstance(x, _protocols.TensorProtocol) for x in self.value):
3283
+ raise TypeError(f"Value of attribute '{self!r}' is not a Sequence of tensors.")
3284
+ # Create a copy of the list to prevent mutation
3285
+ return list(self.value)
3286
+
3287
+ def as_graphs(self) -> Sequence[Graph]:
3288
+ """Get the attribute value as a sequence of graphs."""
3289
+ if self.type != _enums.AttributeType.GRAPHS:
3290
+ raise TypeError(
3291
+ f"Attribute '{self.name}' is not of type GRAPHS. Actual type: {self.type}"
3292
+ )
3293
+ if not isinstance(self.value, Sequence):
3294
+ raise TypeError(f"Value of attribute '{self!r}' is not a Sequence.")
3295
+ if onnx_ir.DEBUG:
3296
+ if not all(isinstance(x, Graph) for x in self.value):
3297
+ raise TypeError(f"Value of attribute '{self!r}' is not a Sequence of graphs.")
3298
+ # Create a copy of the list to prevent mutation
3299
+ return list(self.value)
3300
+
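The typed getters replace ad-hoc casts at call sites. A minimal sketch (hypothetical `node` with a Conv-style attribute):

    strides = node.attributes["strides"].as_ints()  # list[int] copy; raises TypeError on mismatch
    try:
        node.attributes["strides"].as_string()
    except TypeError:
        pass  # the attribute type is INTS, not STRING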
2700
3301
 
2701
3302
  # NOTE: The following functions are just for convenience
3303
+
3304
+
3305
+ def RefAttr(
3306
+ name: str,
3307
+ ref_attr_name: str,
3308
+ type: _enums.AttributeType,
3309
+ doc_string: str | None = None,
3310
+ ) -> Attr:
3311
+ """Create a reference attribute.
3312
+
3313
+ Args:
3314
+ name: The name of the attribute.
3315
+ ref_attr_name: The name of the referenced attribute.
3316
+ type: The type of the attribute.
3317
+ doc_string: Documentation string.
3318
+
3319
+ Returns:
3320
+ A reference attribute.
3321
+ """
3322
+ # NOTE: The function name is capitalized to maintain API backward compatibility.
3323
+ return Attr(name, type, None, ref_attr_name=ref_attr_name, doc_string=doc_string)
3324
+
3325
+
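A minimal sketch of the unified class (attribute names hypothetical): the factory returns a plain `Attr` whose `ref_attr_name` marks it as a reference:

    alpha = RefAttr("alpha", "alpha_attr", _enums.AttributeType.FLOAT)
    assert alpha.is_ref() and alpha.value is None
    concrete = AttrFloat32("alpha", 0.5)
    assert not concrete.is_ref()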
2702
3326
  def AttrFloat32(name: str, value: float, doc_string: str | None = None) -> Attr:
2703
3327
  """Create a float attribute."""
2704
3328
  # NOTE: The function name is capitalized to maintain API backward compatibility.