onnx-ir 0.1.0__py3-none-any.whl → 0.1.2__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release: this version of onnx-ir has been flagged as possibly problematic by the registry.

onnx_ir/_core.py CHANGED
@@ -22,13 +22,12 @@ import os
  import sys
  import textwrap
  import typing
- from collections import OrderedDict
  from collections.abc import (
      Collection,
      Hashable,
      Iterable,
      Iterator,
-     MutableMapping,
+     Mapping,
      MutableSequence,
      Sequence,
  )
@@ -252,11 +251,11 @@ def _check_numpy_representation_type(array: np.ndarray, dtype: _enums.DataType)
      or corresponding dtypes from the ``ml_dtypes`` package.
      """
      if dtype in _NON_NUMPY_NATIVE_TYPES:
-         if dtype.itemsize == 2 and array.dtype not in (np.uint16, ml_dtypes.bfloat16):
+         if dtype.bitwidth == 16 and array.dtype not in (np.uint16, ml_dtypes.bfloat16):
              raise TypeError(
                  f"The numpy array dtype must be uint16 or ml_dtypes.bfloat16 (not {array.dtype}) for IR data type {dtype}."
              )
-         if dtype.itemsize == 1 and array.dtype not in (
+         if dtype.bitwidth == 8 and array.dtype not in (
              np.uint8,
              ml_dtypes.float8_e4m3fnuz,
              ml_dtypes.float8_e4m3fn,
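
The switch from itemsize to bitwidth makes the check exact for sub-byte types: a 4-bit dtype has no whole-byte itemsize, while bitwidth is unambiguous for every dtype. A minimal sketch of the distinction, assuming the package imports as onnx_ir:

    import onnx_ir as ir

    print(ir.DataType.BFLOAT16.bitwidth)  # 16 (two bytes)
    print(ir.DataType.INT4.bitwidth)      # 4 (half a byte; not expressible as a whole itemsize)
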
@@ -386,9 +385,10 @@ class Tensor(TensorBase, _protocols.TensorProtocol, Generic[TArrayCompatible]):

      Args:
          value: The backing data of the tensor. It can be a numpy array compatible object or a DLPack compatible object.
-             When the dtype is not one of the numpy native dtypes, the value needs
-             to be ``uint8`` for 4-bit and 8-bit data types, and ``uint16`` for bfloat16
-             when the value is a numpy array; ``dtype`` must be specified in this case.
+             When the dtype is not one of the numpy native dtypes, the value can
+             be ``uint8`` (unpacked) or ml_dtypes types for 4-bit and 8-bit data types,
+             and ``uint16`` or ml_dtypes.bfloat16 for bfloat16 when the value is a numpy array;
+             ``dtype`` must be specified in this case.
          dtype: The data type of the tensor. It can be None only when value is a numpy array.
              Users are responsible for making sure the dtype matches the value when value is not a numpy array.
          shape: The shape of the tensor. If None, the shape is obtained from the value.
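
The relaxed docstring reflects that bfloat16 values can now be given either as raw uint16 bits or as ml_dtypes.bfloat16. A hedged sketch (assuming the package imports as onnx_ir and ml_dtypes is installed):

    import ml_dtypes
    import numpy as np
    import onnx_ir as ir

    # Raw bit patterns: 0x3F80 is 1.0 and 0x4000 is 2.0 in bfloat16.
    bits = np.array([0x3F80, 0x4000], dtype=np.uint16)
    t1 = ir.Tensor(bits, dtype=ir.DataType.BFLOAT16)

    # Or use the ml_dtypes dtype directly; dtype= is still required here.
    typed = np.array([1.0, 2.0], dtype=ml_dtypes.bfloat16)
    t2 = ir.Tensor(typed, dtype=ir.DataType.BFLOAT16)
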
@@ -422,7 +422,8 @@ class Tensor(TensorBase, _protocols.TensorProtocol, Generic[TArrayCompatible]):
                  self._dtype = _enums.DataType.from_numpy(value.dtype)
              else:
                  raise ValueError(
-                     "The dtype must be specified when the value is not a numpy array."
+                     "The dtype must be specified when the value is not a numpy array. "
+                     "Value type: {type(value)}"
                  )
          else:
              if isinstance(value, np.ndarray):
@@ -503,7 +504,7 @@ class Tensor(TensorBase, _protocols.TensorProtocol, Generic[TArrayCompatible]):
              _enums.DataType.FLOAT4E2M1,
          }:
              # Pack the array into int4
-             array = _type_casting.pack_int4(array)
+             array = _type_casting.pack_4bitx2(array)
          else:
              assert self.dtype.itemsize == array.itemsize, "Bug: The itemsize should match"
              if not _IS_LITTLE_ENDIAN:
@@ -962,8 +963,151 @@ class LazyTensor(TensorBase, _protocols.TensorProtocol):  # pylint: disable=too-many-ancestors
          return self._evaluate().tobytes()


+ class PackedTensor(TensorBase, _protocols.TensorProtocol, Generic[TArrayCompatible]):  # pylint: disable=too-many-ancestors
+     """A tensor that stores 4bit datatypes in packed format."""
+
+     __slots__ = (
+         "_dtype",
+         "_raw",
+         "_shape",
+     )
+
+     def __init__(
+         self,
+         value: TArrayCompatible,
+         dtype: _enums.DataType,
+         *,
+         shape: Shape | Sequence[int],
+         name: str | None = None,
+         doc_string: str | None = None,
+         metadata_props: dict[str, str] | None = None,
+     ) -> None:
+         """Initialize a tensor.
+
+         Args:
+             value: The backing data of the tensor. It can be a numpy array compatible object or a DLPack compatible object.
+                 The value MUST be packed in an integer dtype.
+             dtype: The data type of the tensor. Must be one of INT4, UINT4, FLOAT4E2M1.
+             shape: The shape of the tensor.
+             name: The name of the tensor.
+             doc_string: The documentation string.
+             metadata_props: The metadata properties.
+
+         Raises:
+             TypeError: If the value is not a numpy array compatible or a DLPack compatible object.
+             TypeError: If the value is a numpy array and the dtype is not uint8 or one of the ml_dtypes dtypes.
+         """
+         super().__init__(name=name, doc_string=doc_string, metadata_props=metadata_props)
+         if not _compatible_with_numpy(value) and not _compatible_with_dlpack(value):
+             raise TypeError(f"Expected an array compatible object, got {type(value)}")
+         self._shape = Shape(shape)
+         self._shape.freeze()
+         if dtype.bitwidth != 4:
+             raise TypeError(
+                 f"PackedTensor only supports INT4, UINT4, FLOAT4E2M1, but got {dtype}"
+             )
+         self._dtype = dtype
+         self._raw = value
+
+         if isinstance(value, np.ndarray):
+             if (
+                 value.dtype == ml_dtypes.float4_e2m1fn
+                 or value.dtype == ml_dtypes.uint4
+                 or value.dtype == ml_dtypes.int4
+             ):
+                 raise TypeError(
+                     f"PackedTensor expects the value to be packed, but got {value.dtype} which is not packed. "
+                     "Please pack the value or use `onnx_ir.Tensor`."
+                 )
+             # Check after shape and dtype is set
+             if value.size != self.nbytes:
+                 raise ValueError(
+                     f"Expected the packed array to be {self.nbytes} bytes (from shape {self.shape}), but got {value.nbytes} bytes"
+                 )
+
+     def __array__(self, dtype: Any = None, copy: bool = False) -> np.ndarray:
+         return self.numpy()
+
+     def __dlpack__(self, *, stream: Any = None) -> Any:
+         if _compatible_with_dlpack(self._raw):
+             return self._raw.__dlpack__(stream=stream)
+         return self.__array__().__dlpack__(stream=stream)
+
+     def __dlpack_device__(self) -> tuple[int, int]:
+         if _compatible_with_dlpack(self._raw):
+             return self._raw.__dlpack_device__()
+         return self.__array__().__dlpack_device__()
+
+     def __repr__(self) -> str:
+         return f"{self._repr_base()}({self._raw!r}, name={self.name!r})"
+
+     @property
+     def dtype(self) -> _enums.DataType:
+         """The data type of the tensor. Immutable."""
+         return self._dtype
+
+     @property
+     def shape(self) -> Shape:
+         """The shape of the tensor. Immutable."""
+         return self._shape
+
+     @property
+     def raw(self) -> TArrayCompatible:
+         """Backing data of the tensor. Immutable."""
+         return self._raw  # type: ignore[return-value]
+
+     def numpy(self) -> np.ndarray:
+         """Return the tensor as a numpy array.
+
+         When the data type is not supported by numpy, the dtypes from the ``ml_dtypes``
+         package are used. The values can be reinterpreted as bit representations
+         using the ``.view()`` method.
+         """
+         array = self.numpy_packed()
+         # ONNX IR returns the unpacked arrays
+         if self.dtype == _enums.DataType.INT4:
+             return _type_casting.unpack_int4(array, self.shape.numpy())
+         if self.dtype == _enums.DataType.UINT4:
+             return _type_casting.unpack_uint4(array, self.shape.numpy())
+         if self.dtype == _enums.DataType.FLOAT4E2M1:
+             return _type_casting.unpack_float4e2m1(array, self.shape.numpy())
+         raise TypeError(
+             f"PackedTensor only supports INT4, UINT4, FLOAT4E2M1, but got {self.dtype}"
+         )
+
+     def numpy_packed(self) -> npt.NDArray[np.uint8]:
+         """Return the tensor as a packed array."""
+         if isinstance(self._raw, np.ndarray) or _compatible_with_numpy(self._raw):
+             array = np.asarray(self._raw)
+         else:
+             assert _compatible_with_dlpack(self._raw), (
+                 f"Bug: Expected DLPack or Numpy compatible objects, got {type(self._raw)}"
+             )
+             array = np.from_dlpack(self._raw)
+         if array.nbytes != self.nbytes:
+             raise ValueError(
+                 f"Expected the packed array to be {self.nbytes} bytes (from shape {self.shape}), but got {array.nbytes} bytes"
+             )
+         return array.view(np.uint8)
+
+     def tobytes(self) -> bytes:
+         """Returns the value as bytes encoded in little endian.
+
+         Override this method for more efficient serialization when the raw
+         value is not a numpy array.
+         """
+         array = self.numpy_packed()
+         if not _IS_LITTLE_ENDIAN:
+             array = array.view(array.dtype.newbyteorder("<"))
+         return array.tobytes()
+
+
  class SymbolicDim(_protocols.SymbolicDimProtocol, _display.PrettyPrintable):
-     """Immutable symbolic dimension that can be shared across multiple shapes."""
+     """Immutable symbolic dimension that can be shared across multiple shapes.
+
+     SymbolicDim is used to represent a symbolic (non-integer) dimension in a tensor shape.
+     It is immutable and can be compared or hashed.
+     """

      __slots__ = ("_value",)

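
A hedged usage sketch for the new PackedTensor class added above, assuming it is re-exported at the package root like the other tensor classes:

    import numpy as np
    import onnx_ir as ir

    # Four UINT4 values (1, 2, 3, 4) packed two per byte, low nibble first:
    # 0x21 holds (1, 2) and 0x43 holds (3, 4).
    packed = np.array([0x21, 0x43], dtype=np.uint8)
    t = ir.PackedTensor(packed, ir.DataType.UINT4, shape=[4])

    print(t.numpy())         # unpacked ml_dtypes.uint4 values [1, 2, 3, 4]
    print(t.numpy_packed())  # the raw packed uint8 buffer
    print(t.tobytes())       # b'\x21\x43'
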
@@ -972,6 +1116,9 @@ class SymbolicDim(_protocols.SymbolicDimProtocol, _display.PrettyPrintable):

          Args:
              value: The value of the dimension. It should not be an int.
+
+         Raises:
+             TypeError: If value is an int.
          """
          if isinstance(value, int):
              raise TypeError(
@@ -981,15 +1128,18 @@ class SymbolicDim(_protocols.SymbolicDimProtocol, _display.PrettyPrintable):
          self._value = value

      def __eq__(self, other: object) -> bool:
+         """Check equality with another SymbolicDim or string/None."""
          if not isinstance(other, SymbolicDim):
              return self.value == other
          return self.value == other.value

      def __hash__(self) -> int:
+         """Return the hash of the symbolic dimension value."""
          return hash(self.value)

      @property
      def value(self) -> str | None:
+         """The value of the symbolic dimension (string or None)."""
          return self._value

      def __str__(self) -> str:
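
The equality and hashing semantics documented above, in short (assuming SymbolicDim is importable from the package root):

    import onnx_ir as ir

    dim = ir.SymbolicDim("batch")
    assert dim == "batch"                  # compares against plain strings
    assert dim == ir.SymbolicDim("batch")  # and against other SymbolicDims
    assert hash(dim) == hash("batch")      # hashes by the underlying value
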
@@ -1000,7 +1150,14 @@ class SymbolicDim(_protocols.SymbolicDimProtocol, _display.PrettyPrintable):


  def _is_int_compatible(value: object) -> TypeIs[SupportsInt]:
-     """Return True if the value is int compatible."""
+     """Check if the value is compatible with int (i.e., can be safely cast to int).
+
+     Args:
+         value: The value to check.
+
+     Returns:
+         True if the value is an int or has an __int__ method, False otherwise.
+     """
      if isinstance(value, int):
          return True
      if hasattr(value, "__int__"):
@@ -1012,7 +1169,17 @@ def _is_int_compatible(value: object) -> TypeIs[SupportsInt]:

  def _maybe_convert_to_symbolic_dim(
      dim: int | SupportsInt | SymbolicDim | str | None,
  ) -> SymbolicDim | int:
-     """Convert the value to a SymbolicDim if it is not an int."""
+     """Convert the value to a SymbolicDim if it is not an int.
+
+     Args:
+         dim: The dimension value, which can be int, str, None, or SymbolicDim.
+
+     Returns:
+         An int or SymbolicDim instance.
+
+     Raises:
+         TypeError: If the value is not int, str, None, or SymbolicDim.
+     """
      if dim is None or isinstance(dim, str):
          return SymbolicDim(dim)
      if _is_int_compatible(dim):
@@ -1025,21 +1192,20 @@


  class Shape(_protocols.ShapeProtocol, _display.PrettyPrintable):
-     """The shape of a tensor, including its dimensions and optional denotations.
+     """Represents the shape of a tensor, including its dimensions and optional denotations.

-     The :class:`Shape` stores the dimensions of a tensor, which can be integers, None (unknown), or
-     symbolic dimensions.
-
-     A shape can be compared to another shape or plain Python list.
+     The :class:`Shape` class stores the dimensions of a tensor, which can be integers, None (unknown), or
+     symbolic dimensions. It provides methods for querying and manipulating the shape, as well as for comparing
+     shapes to other shapes or plain Python lists.

      A shape can be frozen (made immutable). When the shape is frozen, it cannot be
      unfrozen, making it suitable to be shared across tensors or values.
-     Call :method:`freeze` to freeze the shape.
+     Call :meth:`freeze` to freeze the shape.

-     To update the dimension of a frozen shape, call :method:`copy` to create a
+     To update the dimension of a frozen shape, call :meth:`copy` to create a
      new shape with the same dimensions that can be modified.

-     Use :method:`get_denotation` and :method:`set_denotation` to access and modify the denotations.
+     Use :meth:`get_denotation` and :meth:`set_denotation` to access and modify the denotations.

      Example::
@@ -1067,7 +1233,7 @@ class Shape(_protocols.ShapeProtocol, _display.PrettyPrintable):

      Attributes:
          dims: A tuple of dimensions representing the shape.
-             Each dimension can be an integer, None or a :class:`SymbolicDim`.
+             Each dimension can be an integer, None, or a :class:`SymbolicDim`.
          frozen: Indicates whether the shape is immutable. When frozen, the shape
              cannot be modified or unfrozen.
      """
@@ -1122,7 +1288,7 @@ class Shape(_protocols.ShapeProtocol, _display.PrettyPrintable):
          """Whether the shape is frozen.

          When the shape is frozen, it cannot be unfrozen, making it suitable to be shared.
-         Call :method:`freeze` to freeze the shape. Call :method:`copy` to create a
+         Call :meth:`freeze` to freeze the shape. Call :meth:`copy` to create a
          new shape with the same dimensions that can be modified.
          """
          return self._frozen
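
A short sketch of the freeze/copy workflow these docstrings describe (assuming Shape is exported at the package root):

    import onnx_ir as ir

    shape = ir.Shape([1, "N", None])
    shape.freeze()          # now immutable and safe to share
    assert shape.frozen
    mutable = shape.copy()  # editable copy of a frozen shape
    mutable[0] = 2
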
@@ -1290,19 +1456,24 @@ def _normalize_domain(domain: str) -> str:
  class Node(_protocols.NodeProtocol, _display.PrettyPrintable):
      """IR Node.

-     If the ``graph`` is provided, the node will be added to the graph. Otherwise,
-     user is responsible to call ``graph.append(node)`` (or other mutation methods
+     .. tip::
+         For a more convenient way (that supports Python objects
+         as attributes) to create a node, use the :func:`onnx_ir.node` constructor.
+
+     If ``graph`` is provided, the node will be added to the graph. Otherwise,
+     the user is responsible for calling ``graph.append(node)`` (or other mutation methods
      in :class:`Graph`) to add the node to the graph.

-     After the node is initialized, it will add itself as a user of the input values.
+     After the node is initialized, it will add itself as a user of its input values.

      The output values of the node are created during node initialization and are immutable.
-     To change the output values, create a new node and replace the each of the inputs of ``output.uses()`` with
-     the new output values by calling :meth:`replace_input_with` on the using nodes
-     of this node's outputs.
+     To change the output values, create a new node and, for each use of the old outputs (``output.uses()``),
+     replace the input in the consuming node by calling :meth:`replace_input_with`.
+     You can also use the :func:`~onnx_ir.convenience.replace_all_uses_with` method
+     to replace all uses of the output values.

-     .. note:
-         When the ``domain`` is `"ai.onnx"`, it is normalized to `""`.
+     .. note::
+         When the ``domain`` is ``"ai.onnx"``, it is normalized to ``""``.
      """

      __slots__ = (
@@ -1325,7 +1496,7 @@ class Node(_protocols.NodeProtocol, _display.PrettyPrintable):
          domain: str,
          op_type: str,
          inputs: Iterable[Value | None],
-         attributes: Iterable[Attr] = (),
+         attributes: Iterable[Attr] | Mapping[str, Attr] = (),
          *,
          overload: str = "",
          num_outputs: int | None = None,
@@ -1340,7 +1511,7 @@ class Node(_protocols.NodeProtocol, _display.PrettyPrintable):

          Args:
              domain: The domain of the operator. For onnx operators, this is an empty string.
-                 When it is `"ai.onnx"`, it is normalized to `""`.
+                 When it is ``"ai.onnx"``, it is normalized to ``""``.
              op_type: The name of the operator.
              inputs: The input values. When an input is ``None``, it is an empty input.
              attributes: The attributes. RefAttr can be used only when the node is defined in a Function.
@@ -1371,15 +1542,10 @@ class Node(_protocols.NodeProtocol, _display.PrettyPrintable):
          self._inputs: tuple[Value | None, ...] = tuple(inputs)
          # Values belong to their defining nodes. The values list is immutable
          self._outputs: tuple[Value, ...] = self._create_outputs(num_outputs, outputs)
-         attributes = tuple(attributes)
-         if attributes and not isinstance(attributes[0], Attr):
-             raise TypeError(
-                 f"Expected the attributes to be Attr, got {type(attributes[0])}. "
-                 "If you are copying the attributes from another node, make sure you call "
-                 "node.attributes.values() because it is a dictionary."
-             )
-         self._attributes: OrderedDict[str, Attr] = OrderedDict(
-             (attr.name, attr) for attr in attributes
+         if isinstance(attributes, Mapping):
+             attributes = tuple(attributes.values())
+         self._attributes: _graph_containers.Attributes = _graph_containers.Attributes(
+             attributes
          )
          self._overload: str = overload
          # TODO(justinchuby): Potentially support a version range
@@ -1637,8 +1803,16 @@ class Node(_protocols.NodeProtocol, _display.PrettyPrintable):
          raise AttributeError("outputs is immutable. Please create a new node instead.")

      @property
-     def attributes(self) -> OrderedDict[str, Attr]:
-         """The attributes of the node."""
+     def attributes(self) -> _graph_containers.Attributes:
+         """The attributes of the node as ``dict[str, Attr]`` with additional access methods.
+
+         Use it as a dictionary with keys being the attribute names and values being the
+         :class:`Attr` objects.
+
+         Use ``node.attributes.add(attr)`` to add an attribute to the node.
+         Use ``node.attributes.get_int(name, default)`` to get an integer attribute value.
+         Refer to :class:`~onnx_ir._graph_containers.Attributes` for more methods.
+         """
          return self._attributes

      @property
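
A hedged sketch of the new Attributes container in use; ir.node, ir.AttrInt64, get_int, and add are assumed to behave as the docstrings above describe:

    import onnx_ir as ir

    a = ir.Value(name="a")
    b = ir.Value(name="b")
    concat = ir.node("Concat", inputs=[a, b], attributes={"axis": 0})

    axis = concat.attributes.get_int("axis", 0)     # typed accessor with a default
    concat.attributes.add(ir.AttrInt64("axis", 1))  # add or replace an Attr object
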
@@ -1805,12 +1979,13 @@ class Value(_protocols.ValueProtocol, _display.PrettyPrintable):
      The index of the output of the node that produces the value can be accessed with
      :meth:`index`.

-     To find all the nodes that use this value as an input, call :meth:`uses`.
+     To find all the nodes that use this value as an input, call :meth:`uses`. Consuming
+     nodes can be obtained with :meth:`consumers`.

      To check if the value is an input, output or initializer of a graph,
      use :meth:`is_graph_input`, :meth:`is_graph_output` or :meth:`is_initializer`.

-     Use :meth:`graph` to get the graph that owns the value.
+     Use :attr:`graph` to get the graph that owns the value.
      """

      __slots__ = (
@@ -2201,17 +2376,9 @@ class Graph(_protocols.GraphProtocol, Sequence[Node], _display.PrettyPrintable):
          # Private fields that are not to be accessed by any other classes
          self._inputs = _graph_containers.GraphInputs(self, inputs)
          self._outputs = _graph_containers.GraphOutputs(self, outputs)
-         self._initializers = _graph_containers.GraphInitializers(self)
-         for initializer in initializers:
-             if isinstance(initializer, str):
-                 raise TypeError(
-                     "Initializer must be a Value, not a string. "
-                     "If you are copying the initializers from another graph, "
-                     "make sure you call graph.initializers.values() because it is a dictionary."
-                 )
-             if initializer.name is None:
-                 raise ValueError(f"Initializer must have a name: {initializer}")
-             self._initializers[initializer.name] = initializer
+         self._initializers = _graph_containers.GraphInitializers(
+             self, {initializer.name: initializer for initializer in initializers}
+         )
          self._doc_string = doc_string
          self._opset_imports = opset_imports or {}
          self._metadata: _metadata.MetadataStore | None = None
@@ -2234,7 +2401,19 @@ class Graph(_protocols.GraphProtocol, Sequence[Node], _display.PrettyPrintable):
          return self._outputs

      @property
-     def initializers(self) -> MutableMapping[str, Value]:
+     def initializers(self) -> _graph_containers.GraphInitializers:
+         """The initializers of the graph as a ``dict[str, Value]``.
+
+         The keys are the names of the initializers. The values are the :class:`Value` objects.
+
+         This property additionally supports the ``add`` method, which takes a :class:`Value`
+         and adds it to the initializers if it is not already present.
+
+         .. note::
+             When setting an initializer with ``graph.initializers[key] = value``,
+             if the value does not have a name, it will be assigned ``key`` as its name.
+
+         """
          return self._initializers

      def register_initializer(self, value: Value) -> None:
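
A sketch of the documented initializer behaviors; the Graph/Value/tensor construction here is illustrative, not taken from the diff:

    import numpy as np
    import onnx_ir as ir

    graph = ir.Graph(inputs=[], outputs=[], nodes=[])
    w = ir.Value(name="w", const_value=ir.tensor(np.zeros((2, 2), dtype=np.float32)))
    graph.initializers.add(w)  # the new add() keys by the value's own name

    b = ir.Value(const_value=ir.tensor(np.ones(2, dtype=np.float32)))
    graph.initializers["b"] = b
    assert b.name == "b"       # unnamed values receive the key as their name
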
@@ -2263,15 +2442,11 @@ class Graph(_protocols.GraphProtocol, Sequence[Node], _display.PrettyPrintable):
                  " it is not the same object: existing={self._initializers[value.name]!r},"
                  f" new={value!r}"
              )
-         if value.producer() is not None:
-             raise ValueError(
-                 f"Value '{value!r}' is produced by a node and cannot be an initializer."
-             )
          if value.const_value is None:
              raise ValueError(
                  f"Value '{value!r}' must have its const_value set to be an initializer."
              )
-         self._initializers[value.name] = value
+         self._initializers.add(value)

      @property
      def doc_string(self) -> str | None:
@@ -2363,6 +2538,28 @@ class Graph(_protocols.GraphProtocol, Sequence[Node], _display.PrettyPrintable):
          # NOTE: This is a method specific to Graph, not required by the protocol unless proven
          return len(self)

+     def all_nodes(self) -> Iterator[Node]:
+         """Get all nodes in the graph and its subgraphs in O(#nodes + #attributes) time.
+
+         This is an alias for ``onnx_ir.traversal.RecursiveGraphIterator(graph)``.
+         Consider using
+         :class:`onnx_ir.traversal.RecursiveGraphIterator` for more advanced
+         traversals on nodes.
+         """
+         # NOTE: This is a method specific to Graph, not required by the protocol unless proven
+         return onnx_ir.traversal.RecursiveGraphIterator(self)
+
+     def subgraphs(self) -> Iterator[Graph]:
+         """Get all subgraphs in the graph in O(#nodes + #attributes) time."""
+         seen_graphs: set[Graph] = set()
+         for node in onnx_ir.traversal.RecursiveGraphIterator(self):
+             graph = node.graph
+             if graph is self:
+                 continue
+             if graph is not None and graph not in seen_graphs:
+                 seen_graphs.add(graph)
+                 yield graph
+
      # Mutation methods
      def append(self, node: Node, /) -> None:
          """Append a node to the graph in O(1) time.
@@ -2701,7 +2898,7 @@ class GraphView(Sequence[Node], _display.PrettyPrintable):
          outputs: Sequence[Value],
          *,
          nodes: Iterable[Node],
-         initializers: Sequence[_protocols.ValueProtocol] = (),
+         initializers: Sequence[Value] = (),
          doc_string: str | None = None,
          opset_imports: dict[str, int] | None = None,
          name: str | None = None,
@@ -2710,10 +2907,7 @@ class GraphView(Sequence[Node], _display.PrettyPrintable):
          self.name = name
          self.inputs = tuple(inputs)
          self.outputs = tuple(outputs)
-         for initializer in initializers:
-             if initializer.name is None:
-                 raise ValueError(f"Initializer must have a name: {initializer}")
-         self.initializers = {tensor.name: tensor for tensor in initializers}
+         self.initializers = {initializer.name: initializer for initializer in initializers}
          self.doc_string = doc_string
          self.opset_imports = opset_imports or {}
          self._metadata: _metadata.MetadataStore | None = None
@@ -2871,7 +3065,7 @@ Model(
          """Get all graphs and subgraphs in the model.

          This is a convenience method to traverse the model. Consider using
-         `onnx_ir.traversal.RecursiveGraphIterator` for more advanced
+         :class:`onnx_ir.traversal.RecursiveGraphIterator` for more advanced
          traversals on nodes.
          """
          # NOTE(justinchuby): Given
@@ -2880,11 +3074,8 @@ Model(
          # (3) Users familiar with onnxruntime optimization tools expect this method
          # I created this method as a core method instead of an iterator in
          # `traversal.py`.
-         seen_graphs: set[Graph] = set()
-         for node in onnx_ir.traversal.RecursiveGraphIterator(self.graph):
-             if node.graph is not None and node.graph not in seen_graphs:
-                 seen_graphs.add(node.graph)
-                 yield node.graph
+         yield self.graph
+         yield from self.graph.subgraphs()


  class Function(_protocols.FunctionProtocol, Sequence[Node], _display.PrettyPrintable):
@@ -2927,13 +3118,15 @@ class Function(_protocols.FunctionProtocol, Sequence[Node], _display.PrettyPrintable):
          # Ensure the inputs and outputs of the function belong to a graph
          # and not from an outer scope
          graph: Graph,
-         attributes: Sequence[Attr],
+         attributes: Iterable[Attr] | Mapping[str, Attr],
      ) -> None:
          self._domain = domain
          self._name = name
          self._overload = overload
          self._graph = graph
-         self._attributes = OrderedDict((attr.name, attr) for attr in attributes)
+         if isinstance(attributes, Mapping):
+             attributes = tuple(attributes.values())
+         self._attributes = _graph_containers.Attributes(attributes)

      def identifier(self) -> _protocols.OperatorIdentifier:
          return self.domain, self.name, self.overload
@@ -2971,7 +3164,7 @@ class Function(_protocols.FunctionProtocol, Sequence[Node], _display.PrettyPrintable):
          return self._graph.outputs

      @property
-     def attributes(self) -> OrderedDict[str, Attr]:
+     def attributes(self) -> _graph_containers.Attributes:
          return self._attributes

      @typing.overload
@@ -3016,6 +3209,28 @@ class Function(_protocols.FunctionProtocol, Sequence[Node], _display.PrettyPrintable):
      def metadata_props(self) -> dict[str, str]:
          return self._graph.metadata_props

+     def all_nodes(self) -> Iterator[Node]:
+         """Get all nodes in the graph and its subgraphs in O(#nodes + #attributes) time.
+
+         This is an alias for ``onnx_ir.traversal.RecursiveGraphIterator(graph)``.
+         Consider using
+         :class:`onnx_ir.traversal.RecursiveGraphIterator` for more advanced
+         traversals on nodes.
+         """
+         # NOTE: This is a method specific to Graph, not required by the protocol unless proven
+         return onnx_ir.traversal.RecursiveGraphIterator(self)
+
+     def subgraphs(self) -> Iterator[Graph]:
+         """Get all subgraphs in the function in O(#nodes + #attributes) time."""
+         seen_graphs: set[Graph] = set()
+         for node in onnx_ir.traversal.RecursiveGraphIterator(self):
+             graph = node.graph
+             if graph is self._graph:
+                 continue
+             if graph is not None and graph not in seen_graphs:
+                 seen_graphs.add(graph)
+                 yield graph
+
      # Mutation methods
      def append(self, node: Node, /) -> None:
          """Append a node to the function in O(1) time."""