onnx-ir 0.1.1-py3-none-any.whl → 0.1.3-py3-none-any.whl

This diff shows the changes between two publicly released versions of the package, as they appear in their public registry. It is provided for informational purposes only.


onnx_ir/_core.py CHANGED
@@ -251,11 +251,11 @@ def _check_numpy_representation_type(array: np.ndarray, dtype: _enums.DataType)
         or corresponding dtypes from the ``ml_dtype`` package.
     """
     if dtype in _NON_NUMPY_NATIVE_TYPES:
-        if dtype.itemsize == 2 and array.dtype not in (np.uint16, ml_dtypes.bfloat16):
+        if dtype.bitwidth == 16 and array.dtype not in (np.uint16, ml_dtypes.bfloat16):
             raise TypeError(
                 f"The numpy array dtype must be uint16 or ml_dtypes.bfloat16 (not {array.dtype}) for IR data type {dtype}."
             )
-        if dtype.itemsize == 1 and array.dtype not in (
+        if dtype.bitwidth == 8 and array.dtype not in (
             np.uint8,
             ml_dtypes.float8_e4m3fnuz,
             ml_dtypes.float8_e4m3fn,
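Note: the check now keys off the dtype's bit width instead of its byte size, which is fractional for 4-bit types. A hedged sketch of what the validation accepts for BFLOAT16 (constructor usage per the `Tensor` docstring below):

    import ml_dtypes
    import numpy as np
    import onnx_ir as ir

    # Either backing passes the 16-bit check: raw uint16 bit patterns or
    # ml_dtypes-typed values. 0x3F80 is the bfloat16 bit pattern of 1.0.
    bits = np.array([0x3F80], dtype=np.uint16)
    typed = np.array([1.0], dtype=ml_dtypes.bfloat16)
    t1 = ir.Tensor(bits, dtype=ir.DataType.BFLOAT16)  # dtype required for uint16
    t2 = ir.Tensor(typed, dtype=ir.DataType.BFLOAT16)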
@@ -385,9 +385,10 @@ class Tensor(TensorBase, _protocols.TensorProtocol, Generic[TArrayCompatible]):
 
         Args:
             value: The backing data of the tensor. It can be a numpy array compatible object or a DLPack compatible object.
-                When the dtype is not one of the numpy native dtypes, the value needs
-                to be ``uint8`` for 4-bit and 8-bit data types, and ``uint16`` for bfloat16
-                when the value is a numpy array; ``dtype`` must be specified in this case.
+                When the dtype is not one of the numpy native dtypes, the value can
+                be ``uint8`` (unpacked) or ml_dtypes types for 4-bit and 8-bit data types,
+                and ``uint16`` or ml_dtype.bfloat16 for bfloat16 when the value is a numpy array;
+                ``dtype`` must be specified in this case.
             dtype: The data type of the tensor. It can be None only when value is a numpy array.
                 Users are responsible for making sure the dtype matches the value when value is not a numpy array.
             shape: The shape of the tensor. If None, the shape is obtained from the value.
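Note: a hedged sketch of the unpacked-``uint8`` convention the revised docstring describes, using a 4-bit tensor (one value per byte before serialization):

    import numpy as np
    import onnx_ir as ir

    unpacked = np.array([1, 2, 3, 15], dtype=np.uint8)  # four UINT4 values
    t = ir.Tensor(unpacked, dtype=ir.DataType.UINT4)
    # Serialization packs two 4-bit values per byte: 4 values -> 2 bytes.
    assert len(t.tobytes()) == 2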
@@ -416,12 +417,16 @@ class Tensor(TensorBase, _protocols.TensorProtocol, Generic[TArrayCompatible]):
         else:
             self._shape = shape
             self._shape.freeze()
+        if isinstance(value, np.generic):
+            # Turn numpy scalar into a numpy array
+            value = np.array(value)  # type: ignore[assignment]
         if dtype is None:
             if isinstance(value, np.ndarray):
                 self._dtype = _enums.DataType.from_numpy(value.dtype)
             else:
                 raise ValueError(
-                    "The dtype must be specified when the value is not a numpy array."
+                    "The dtype must be specified when the value is not a numpy array. "
+                    "Value type: {type(value)}"
                 )
         else:
             if isinstance(value, np.ndarray):
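Note: with the new ``np.generic`` branch, numpy scalars no longer need manual wrapping; a minimal sketch. (Also noting verbatim from the release: the added "Value type: {type(value)}" line is a plain string without an ``f`` prefix, so the placeholder will not interpolate.)

    import numpy as np
    import onnx_ir as ir

    t = ir.Tensor(np.float32(1.0))  # a numpy scalar, not an ndarray
    print(t.dtype, t.shape)         # FLOAT and a 0-D shape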
@@ -502,7 +507,7 @@ class Tensor(TensorBase, _protocols.TensorProtocol, Generic[TArrayCompatible]):
             _enums.DataType.FLOAT4E2M1,
         }:
             # Pack the array into int4
-            array = _type_casting.pack_int4(array)
+            array = _type_casting.pack_4bitx2(array)
         else:
             assert self.dtype.itemsize == array.itemsize, "Bug: The itemsize should match"
         if not _IS_LITTLE_ENDIAN:
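Note: the renamed helper reflects that the same two-nibble packing serves INT4, UINT4, and FLOAT4E2M1. A minimal sketch of the scheme (this ``pack_4bitx2`` is an illustrative stand-in, not the library's implementation; ONNX stores the first element in the low nibble):

    import numpy as np

    def pack_4bitx2(array: np.ndarray) -> np.ndarray:
        """Pack two 4-bit values per byte, first element in the low nibble."""
        flat = array.astype(np.uint8).ravel()
        if flat.size % 2:  # pad odd-length input with a zero nibble
            flat = np.append(flat, np.uint8(0))
        return (flat[0::2] & 0x0F) | ((flat[1::2] & 0x0F) << 4)

    packed = pack_4bitx2(np.array([1, 2, 3, 15], dtype=np.uint8))
    assert packed.tolist() == [0x21, 0xF3]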
@@ -961,8 +966,154 @@ class LazyTensor(TensorBase, _protocols.TensorProtocol):  # pylint: disable=too-
         return self._evaluate().tobytes()
 
 
+class PackedTensor(TensorBase, _protocols.TensorProtocol, Generic[TArrayCompatible]):  # pylint: disable=too-many-ancestors
+    """A tensor that stores 4bit datatypes in packed format.
+
+    .. versionadded:: 0.1.2
+    """
+
+    __slots__ = (
+        "_dtype",
+        "_raw",
+        "_shape",
+    )
+
+    def __init__(
+        self,
+        value: TArrayCompatible,
+        dtype: _enums.DataType,
+        *,
+        shape: Shape | Sequence[int],
+        name: str | None = None,
+        doc_string: str | None = None,
+        metadata_props: dict[str, str] | None = None,
+    ) -> None:
+        """Initialize a tensor.
+
+        Args:
+            value: The backing data of the tensor. It can be a numpy array compatible object or a DLPack compatible object.
+                The value MUST be packed in an integer dtype.
+            dtype: The data type of the tensor. Must be one of INT4, UINT4, FLOAT4E2M1.
+            shape: The shape of the tensor.
+            name: The name of the tensor.
+            doc_string: The documentation string.
+            metadata_props: The metadata properties.
+
+        Raises:
+            TypeError: If the value is not a numpy array compatible or a DLPack compatible object.
+            TypeError: If the value is a numpy array and the dtype is not uint8 or one of the ml_dtypes dtypes.
+        """
+        super().__init__(name=name, doc_string=doc_string, metadata_props=metadata_props)
+        if not _compatible_with_numpy(value) and not _compatible_with_dlpack(value):
+            raise TypeError(f"Expected an array compatible object, got {type(value)}")
+        self._shape = Shape(shape)
+        self._shape.freeze()
+        if dtype.bitwidth != 4:
+            raise TypeError(
+                f"PackedTensor only supports INT4, UINT4, FLOAT4E2M1, but got {dtype}"
+            )
+        self._dtype = dtype
+        self._raw = value
+
+        if isinstance(value, np.ndarray):
+            if (
+                value.dtype == ml_dtypes.float4_e2m1fn
+                or value.dtype == ml_dtypes.uint4
+                or value.dtype == ml_dtypes.int4
+            ):
+                raise TypeError(
+                    f"PackedTensor expects the value to be packed, but got {value.dtype} which is not packed. "
+                    "Please pack the value or use `onnx_ir.Tensor`."
+                )
+            # Check after shape and dtype is set
+            if value.size != self.nbytes:
+                raise ValueError(
+                    f"Expected the packed array to be {self.nbytes} bytes (from shape {self.shape}), but got {value.nbytes} bytes"
+                )
+
+    def __array__(self, dtype: Any = None, copy: bool = False) -> np.ndarray:
+        return self.numpy()
+
+    def __dlpack__(self, *, stream: Any = None) -> Any:
+        if _compatible_with_dlpack(self._raw):
+            return self._raw.__dlpack__(stream=stream)
+        return self.__array__().__dlpack__(stream=stream)
+
+    def __dlpack_device__(self) -> tuple[int, int]:
+        if _compatible_with_dlpack(self._raw):
+            return self._raw.__dlpack_device__()
+        return self.__array__().__dlpack_device__()
+
+    def __repr__(self) -> str:
+        return f"{self._repr_base()}({self._raw!r}, name={self.name!r})"
+
+    @property
+    def dtype(self) -> _enums.DataType:
+        """The data type of the tensor. Immutable."""
+        return self._dtype
+
+    @property
+    def shape(self) -> Shape:
+        """The shape of the tensor. Immutable."""
+        return self._shape
+
+    @property
+    def raw(self) -> TArrayCompatible:
+        """Backing data of the tensor. Immutable."""
+        return self._raw  # type: ignore[return-value]
+
+    def numpy(self) -> np.ndarray:
+        """Return the tensor as a numpy array.
+
+        When the data type is not supported by numpy, the dtypes from the ``ml_dtype``
+        package are used. The values can be reinterpreted as bit representations
+        using the ``.view()`` method.
+        """
+        array = self.numpy_packed()
+        # ONNX IR returns the unpacked arrays
+        if self.dtype == _enums.DataType.INT4:
+            return _type_casting.unpack_int4(array, self.shape.numpy())
+        if self.dtype == _enums.DataType.UINT4:
+            return _type_casting.unpack_uint4(array, self.shape.numpy())
+        if self.dtype == _enums.DataType.FLOAT4E2M1:
+            return _type_casting.unpack_float4e2m1(array, self.shape.numpy())
+        raise TypeError(
+            f"PackedTensor only supports INT4, UINT4, FLOAT4E2M1, but got {self.dtype}"
+        )
+
+    def numpy_packed(self) -> npt.NDArray[np.uint8]:
+        """Return the tensor as a packed array."""
+        if isinstance(self._raw, np.ndarray) or _compatible_with_numpy(self._raw):
+            array = np.asarray(self._raw)
+        else:
+            assert _compatible_with_dlpack(self._raw), (
+                f"Bug: Expected DLPack or Numpy compatible objects, got {type(self._raw)}"
+            )
+            array = np.from_dlpack(self._raw)
+        if array.nbytes != self.nbytes:
+            raise ValueError(
+                f"Expected the packed array to be {self.nbytes} bytes (from shape {self.shape}), but got {array.nbytes} bytes"
+            )
+        return array.view(np.uint8)
+
+    def tobytes(self) -> bytes:
+        """Returns the value as bytes encoded in little endian.
+
+        Override this method for more efficient serialization when the raw
+        value is not a numpy array.
+        """
+        array = self.numpy_packed()
+        if not _IS_LITTLE_ENDIAN:
+            array = array.view(array.dtype.newbyteorder("<"))
+        return array.tobytes()
+
+
 class SymbolicDim(_protocols.SymbolicDimProtocol, _display.PrettyPrintable):
-    """Immutable symbolic dimension that can be shared across multiple shapes."""
+    """Immutable symbolic dimension that can be shared across multiple shapes.
+
+    SymbolicDim is used to represent a symbolic (non-integer) dimension in a tensor shape.
+    It is immutable and can be compared or hashed.
+    """
 
     __slots__ = ("_value",)
 
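Note: a hedged usage sketch for the new class, reusing the packed bytes from the packing sketch above:

    import numpy as np
    import onnx_ir as ir

    packed = np.array([0x21, 0xF3], dtype=np.uint8)  # four UINT4 values
    t = ir.PackedTensor(packed, ir.DataType.UINT4, shape=[4])
    print(t.numpy())         # unpacked values [1, 2, 3, 15]
    print(t.numpy_packed())  # the raw packed uint8 buffer
    assert t.tobytes() == packed.tobytes()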
@@ -971,6 +1122,9 @@ class SymbolicDim(_protocols.SymbolicDimProtocol, _display.PrettyPrintable):
 
         Args:
             value: The value of the dimension. It should not be an int.
+
+        Raises:
+            TypeError: If value is an int.
         """
         if isinstance(value, int):
             raise TypeError(
@@ -980,15 +1134,18 @@ class SymbolicDim(_protocols.SymbolicDimProtocol, _display.PrettyPrintable):
         self._value = value
 
     def __eq__(self, other: object) -> bool:
+        """Check equality with another SymbolicDim or string/None."""
         if not isinstance(other, SymbolicDim):
             return self.value == other
         return self.value == other.value
 
     def __hash__(self) -> int:
+        """Return the hash of the symbolic dimension value."""
         return hash(self.value)
 
     @property
     def value(self) -> str | None:
+        """The value of the symbolic dimension (string or None)."""
         return self._value
 
     def __str__(self) -> str:
@@ -999,7 +1156,14 @@ class SymbolicDim(_protocols.SymbolicDimProtocol, _display.PrettyPrintable):
 
 
 def _is_int_compatible(value: object) -> TypeIs[SupportsInt]:
-    """Return True if the value is int compatible."""
+    """Check if the value is compatible with int (i.e., can be safely cast to int).
+
+    Args:
+        value: The value to check.
+
+    Returns:
+        True if the value is an int or has an __int__ method, False otherwise.
+    """
     if isinstance(value, int):
         return True
     if hasattr(value, "__int__"):
@@ -1011,7 +1175,17 @@ def _is_int_compatible(value: object) -> TypeIs[SupportsInt]:
 def _maybe_convert_to_symbolic_dim(
     dim: int | SupportsInt | SymbolicDim | str | None,
 ) -> SymbolicDim | int:
-    """Convert the value to a SymbolicDim if it is not an int."""
+    """Convert the value to a SymbolicDim if it is not an int.
+
+    Args:
+        dim: The dimension value, which can be int, str, None, or SymbolicDim.
+
+    Returns:
+        An int or SymbolicDim instance.
+
+    Raises:
+        TypeError: If the value is not int, str, None, or SymbolicDim.
+    """
     if dim is None or isinstance(dim, str):
         return SymbolicDim(dim)
     if _is_int_compatible(dim):
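Note: a quick sketch of the conversion rules this private helper documents, as exercised through ``Shape``:

    import onnx_ir as ir

    shape = ir.Shape(["batch", 3, None])
    print(shape[0])  # SymbolicDim created from the string "batch"
    print(shape[1])  # 3; int-compatible values stay plain ints
    print(shape[2])  # SymbolicDim created from None (unknown dimension)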
@@ -1024,21 +1198,20 @@ def _maybe_convert_to_symbolic_dim(
 
 
 class Shape(_protocols.ShapeProtocol, _display.PrettyPrintable):
-    """The shape of a tensor, including its dimensions and optional denotations.
-
-    The :class:`Shape` stores the dimensions of a tensor, which can be integers, None (unknown), or
-    symbolic dimensions.
+    """Represents the shape of a tensor, including its dimensions and optional denotations.
 
-    A shape can be compared to another shape or plain Python list.
+    The :class:`Shape` class stores the dimensions of a tensor, which can be integers, None (unknown), or
+    symbolic dimensions. It provides methods for querying and manipulating the shape, as well as for comparing
+    shapes to other shapes or plain Python lists.
 
     A shape can be frozen (made immutable). When the shape is frozen, it cannot be
     unfrozen, making it suitable to be shared across tensors or values.
-    Call :method:`freeze` to freeze the shape.
+    Call :meth:`freeze` to freeze the shape.
 
-    To update the dimension of a frozen shape, call :method:`copy` to create a
+    To update the dimension of a frozen shape, call :meth:`copy` to create a
     new shape with the same dimensions that can be modified.
 
-    Use :method:`get_denotation` and :method:`set_denotation` to access and modify the denotations.
+    Use :meth:`get_denotation` and :meth:`set_denotation` to access and modify the denotations.
 
     Example::
 
@@ -1066,7 +1239,7 @@ class Shape(_protocols.ShapeProtocol, _display.PrettyPrintable):
 
     Attributes:
         dims: A tuple of dimensions representing the shape.
-            Each dimension can be an integer, None or a :class:`SymbolicDim`.
+            Each dimension can be an integer, None, or a :class:`SymbolicDim`.
         frozen: Indicates whether the shape is immutable. When frozen, the shape
             cannot be modified or unfrozen.
     """
@@ -1121,7 +1294,7 @@ class Shape(_protocols.ShapeProtocol, _display.PrettyPrintable):
         """Whether the shape is frozen.
 
         When the shape is frozen, it cannot be unfrozen, making it suitable to be shared.
-        Call :method:`freeze` to freeze the shape. Call :method:`copy` to create a
+        Call :meth:`freeze` to freeze the shape. Call :meth:`copy` to create a
         new shape with the same dimensions that can be modified.
         """
         return self._frozen
@@ -1289,19 +1462,24 @@ def _normalize_domain(domain: str) -> str:
 class Node(_protocols.NodeProtocol, _display.PrettyPrintable):
     """IR Node.
 
-    If the ``graph`` is provided, the node will be added to the graph. Otherwise,
-    user is responsible to call ``graph.append(node)`` (or other mutation methods
+    .. tip::
+        For a more convenient way (that supports Python objects
+        as attributes) to create a node, use the :func:`onnx_ir.node` constructor.
+
+    If ``graph`` is provided, the node will be added to the graph. Otherwise,
+    the user is responsible for calling ``graph.append(node)`` (or other mutation methods
     in :class:`Graph`) to add the node to the graph.
 
-    After the node is initialized, it will add itself as a user of the input values.
+    After the node is initialized, it will add itself as a user of its input values.
 
     The output values of the node are created during node initialization and are immutable.
-    To change the output values, create a new node and replace the each of the inputs of ``output.uses()`` with
-    the new output values by calling :meth:`replace_input_with` on the using nodes
-    of this node's outputs.
+    To change the output values, create a new node and, for each use of the old outputs (``output.uses()``),
+    replace the input in the consuming node by calling :meth:`replace_input_with`.
+    You can also use the :func:`~onnx_ir.convenience.replace_all_uses_with` method
+    to replace all uses of the output values.
 
-    .. note:
-        When the ``domain`` is `"ai.onnx"`, it is normalized to `""`.
+    .. note::
+        When the ``domain`` is ``"ai.onnx"``, it is normalized to ``""``.
     """
 
     __slots__ = (
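Note: a hedged sketch of the two construction styles; the exact keyword set of :func:`onnx_ir.node` is assumed here:

    import onnx_ir as ir

    x = ir.Value(name="x")
    # Convenience constructor: accepts plain Python objects as attributes.
    add = ir.node("Add", inputs=[x, x])
    # Core class: domain and op_type are explicit positional arguments.
    mul = ir.Node("", "Mul", inputs=list(add.outputs))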
@@ -1339,7 +1517,7 @@ class Node(_protocols.NodeProtocol, _display.PrettyPrintable):
 
         Args:
             domain: The domain of the operator. For onnx operators, this is an empty string.
-                When it is `"ai.onnx"`, it is normalized to `""`.
+                When it is ``"ai.onnx"``, it is normalized to ``""``.
             op_type: The name of the operator.
             inputs: The input values. When an input is ``None``, it is an empty input.
             attributes: The attributes. RefAttr can be used only when the node is defined in a Function.
@@ -1632,7 +1810,15 @@ class Node(_protocols.NodeProtocol, _display.PrettyPrintable):
 
     @property
     def attributes(self) -> _graph_containers.Attributes:
-        """The attributes of the node."""
+        """The attributes of the node as ``dict[str, Attr]`` with additional access methods.
+
+        Use it as a dictionary with keys being the attribute names and values being the
+        :class:`Attr` objects.
+
+        Use ``node.attributes.add(attr)`` to add an attribute to the node.
+        Use ``node.attributes.get_int(name, default)`` to get an integer attribute value.
+        Refer to the :class:`~onnx_ir._graph_containers.Attributes` for more methods.
+        """
         return self._attributes
 
     @property
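Note: the helpers named in the docstring, in a short sketch (the ``axis`` attribute is arbitrary; ``get_int`` falls back to the default when the attribute is absent):

    import onnx_ir as ir

    node = ir.node("Concat", inputs=[], attributes={"axis": 0})
    axis = node.attributes.get_int("axis", -1)        # 0
    node.attributes.add(ir.AttrInt64("new_axis", 1))  # keyed by attribute name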
@@ -1799,12 +1985,13 @@ class Value(_protocols.ValueProtocol, _display.PrettyPrintable):
     The index of the output of the node that produces the value can be accessed with
     :meth:`index`.
 
-    To find all the nodes that use this value as an input, call :meth:`uses`.
+    To find all the nodes that use this value as an input, call :meth:`uses`. Consuming
+    nodes can be obtained with :meth:`consumers`.
 
     To check if the value is an is an input, output or initializer of a graph,
     use :meth:`is_graph_input`, :meth:`is_graph_output` or :meth:`is_initializer`.
 
-    Use :meth:`graph` to get the graph that owns the value.
+    Use :attr:`graph` to get the graph that owns the value.
     """
 
     __slots__ = (
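Note: a sketch of the two query styles (``value`` is any ``ir.Value`` in a graph):

    for node, input_index in value.uses():  # (consumer node, input slot)
        print(node.op_type, input_index)
    for node in value.consumers():          # just the consuming nodes
        print(node.op_type)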
@@ -2154,6 +2341,12 @@ class Graph(_protocols.GraphProtocol, Sequence[Node], _display.PrettyPrintable):
     seen as a Sequence of nodes and should be used as such. For example, to obtain
     all nodes as a list, call ``list(graph)``.
 
+    .. versionchanged:: 0.1.1
+        Values with non-none producers will be rejected as graph inputs or initializers.
+
+    .. versionadded:: 0.1.1
+        Added ``add`` method to initializers and attributes.
+
     Attributes:
         name: The name of the graph.
         inputs: The input values of the graph.
@@ -2221,7 +2414,7 @@ class Graph(_protocols.GraphProtocol, Sequence[Node], _display.PrettyPrintable):
 
     @property
     def initializers(self) -> _graph_containers.GraphInitializers:
-        """The initializers of the graph as a ``MutableMapping[str, Value]``.
+        """The initializers of the graph as a ``dict[str, Value]``.
 
         The keys are the names of the initializers. The values are the :class:`Value` objects.
 
@@ -2357,6 +2550,33 @@ class Graph(_protocols.GraphProtocol, Sequence[Node], _display.PrettyPrintable):
         # NOTE: This is a method specific to Graph, not required by the protocol unless proven
         return len(self)
 
+    def all_nodes(self) -> Iterator[Node]:
+        """Get all nodes in the graph and its subgraphs in O(#nodes + #attributes) time.
+
+        This is an alias for ``onnx_ir.traversal.RecursiveGraphIterator(graph)``.
+        Consider using
+        :class:`onnx_ir.traversal.RecursiveGraphIterator` for more advanced
+        traversals on nodes.
+
+        .. versionadded:: 0.1.2
+        """
+        # NOTE: This is a method specific to Graph, not required by the protocol unless proven
+        return onnx_ir.traversal.RecursiveGraphIterator(self)
+
+    def subgraphs(self) -> Iterator[Graph]:
+        """Get all subgraphs in the graph in O(#nodes + #attributes) time.
+
+        .. versionadded:: 0.1.2
+        """
+        seen_graphs: set[Graph] = set()
+        for node in onnx_ir.traversal.RecursiveGraphIterator(self):
+            graph = node.graph
+            if graph is self:
+                continue
+            if graph is not None and graph not in seen_graphs:
+                seen_graphs.add(graph)
+                yield graph
+
     # Mutation methods
     def append(self, node: Node, /) -> None:
         """Append a node to the graph in O(1) time.
@@ -2862,7 +3082,7 @@ Model(
         """Get all graphs and subgraphs in the model.
 
         This is a convenience method to traverse the model. Consider using
-        `onnx_ir.traversal.RecursiveGraphIterator` for more advanced
+        :class:`onnx_ir.traversal.RecursiveGraphIterator` for more advanced
         traversals on nodes.
         """
         # NOTE(justinchuby): Given
@@ -2871,11 +3091,8 @@ Model(
         # (3) Users familiar with onnxruntime optimization tools expect this method
         # I created this method as a core method instead of an iterator in
         # `traversal.py`.
-        seen_graphs: set[Graph] = set()
-        for node in onnx_ir.traversal.RecursiveGraphIterator(self.graph):
-            if node.graph is not None and node.graph not in seen_graphs:
-                seen_graphs.add(node.graph)
-                yield node.graph
+        yield self.graph
+        yield from self.graph.subgraphs()
 
 
 class Function(_protocols.FunctionProtocol, Sequence[Node], _display.PrettyPrintable):
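Note: one observable effect of this rewrite is that the main graph is now always yielded first (``model`` is any ``ir.Model``):

    graphs = list(model.graphs())
    assert graphs[0] is model.graph  # main graph first, then nested subgraphs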
@@ -3009,6 +3226,33 @@ class Function(_protocols.FunctionProtocol, Sequence[Node], _display.PrettyPrint
     def metadata_props(self) -> dict[str, str]:
         return self._graph.metadata_props
 
+    def all_nodes(self) -> Iterator[Node]:
+        """Get all nodes in the graph and its subgraphs in O(#nodes + #attributes) time.
+
+        This is an alias for ``onnx_ir.traversal.RecursiveGraphIterator(graph)``.
+        Consider using
+        :class:`onnx_ir.traversal.RecursiveGraphIterator` for more advanced
+        traversals on nodes.
+
+        .. versionadded:: 0.1.2
+        """
+        # NOTE: This is a method specific to Graph, not required by the protocol unless proven
+        return onnx_ir.traversal.RecursiveGraphIterator(self)
+
+    def subgraphs(self) -> Iterator[Graph]:
+        """Get all subgraphs in the function in O(#nodes + #attributes) time.
+
+        .. versionadded:: 0.1.2
+        """
+        seen_graphs: set[Graph] = set()
+        for node in onnx_ir.traversal.RecursiveGraphIterator(self):
+            graph = node.graph
+            if graph is self._graph:
+                continue
+            if graph is not None and graph not in seen_graphs:
+                seen_graphs.add(graph)
+                yield graph
+
     # Mutation methods
     def append(self, node: Node, /) -> None:
         """Append a node to the function in O(1) time."""
onnx_ir/_enums.py CHANGED
@@ -114,7 +114,20 @@ class DataType(enum.IntEnum):
     @property
     def itemsize(self) -> float:
         """Returns the size of the data type in bytes."""
-        return _ITEMSIZE_MAP[self]
+        return self.bitwidth / 8
+
+    @property
+    def bitwidth(self) -> int:
+        """Returns the bit width of the data type.
+
+        .. versionadded:: 0.1.2
+
+        Raises:
+            TypeError: If the data type is not supported.
+        """
+        if self not in _BITWIDTH_MAP:
+            raise TypeError(f"Bitwidth not available for ONNX data type: {self}")
+        return _BITWIDTH_MAP[self]
 
     def numpy(self) -> np.dtype:
         """Returns the numpy dtype for the ONNX data type.
@@ -163,30 +176,29 @@ class DataType(enum.IntEnum):
         return self.__repr__()
 
 
-_ITEMSIZE_MAP = {
-    DataType.FLOAT: 4,
-    DataType.UINT8: 1,
-    DataType.INT8: 1,
-    DataType.UINT16: 2,
-    DataType.INT16: 2,
-    DataType.INT32: 4,
-    DataType.INT64: 8,
-    DataType.STRING: 1,
-    DataType.BOOL: 1,
-    DataType.FLOAT16: 2,
-    DataType.DOUBLE: 8,
-    DataType.UINT32: 4,
-    DataType.UINT64: 8,
-    DataType.COMPLEX64: 8,
-    DataType.COMPLEX128: 16,
-    DataType.BFLOAT16: 2,
-    DataType.FLOAT8E4M3FN: 1,
-    DataType.FLOAT8E4M3FNUZ: 1,
-    DataType.FLOAT8E5M2: 1,
-    DataType.FLOAT8E5M2FNUZ: 1,
-    DataType.UINT4: 0.5,
-    DataType.INT4: 0.5,
-    DataType.FLOAT4E2M1: 0.5,
+_BITWIDTH_MAP = {
+    DataType.FLOAT: 32,
+    DataType.UINT8: 8,
+    DataType.INT8: 8,
+    DataType.UINT16: 16,
+    DataType.INT16: 16,
+    DataType.INT32: 32,
+    DataType.INT64: 64,
+    DataType.BOOL: 8,
+    DataType.FLOAT16: 16,
+    DataType.DOUBLE: 64,
+    DataType.UINT32: 32,
+    DataType.UINT64: 64,
+    DataType.COMPLEX64: 64,  # 2 * 32
+    DataType.COMPLEX128: 128,  # 2 * 64
+    DataType.BFLOAT16: 16,
+    DataType.FLOAT8E4M3FN: 8,
+    DataType.FLOAT8E4M3FNUZ: 8,
+    DataType.FLOAT8E5M2: 8,
+    DataType.FLOAT8E5M2FNUZ: 8,
+    DataType.UINT4: 4,
+    DataType.INT4: 4,
+    DataType.FLOAT4E2M1: 4,
 }
 
 
onnx_ir/_graph_containers.py CHANGED

@@ -216,7 +216,7 @@ class GraphOutputs(_GraphIO):
 
 
 class GraphInitializers(collections.UserDict[str, "_core.Value"]):
-    """The initializers of a Graph."""
+    """The initializers of a Graph as ``dict[str, Value]`` with additional mutation methods."""
 
     def __init__(self, graph: _core.Graph, dict=None, /, **kwargs):
         # Perform checks first in _set_graph before modifying the data structure with super().__init__()
@@ -291,7 +291,7 @@ class GraphInitializers(collections.UserDict[str, "_core.Value"]):
 
 
 class Attributes(collections.UserDict[str, "_core.Attr"]):
-    """The attributes of a Node."""
+    """The attributes of a Node as ``dict[str, Attr]`` with additional access methods."""
 
     def __init__(self, attrs: Iterable[_core.Attr]):
         super().__init__({attr.name: attr for attr in attrs})
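Note: a small sketch of the ``add`` helper referenced by the 0.1.1 notes in ``_core.py`` above (``graph`` is any ``ir.Graph``; the initializer value is illustrative):

    import numpy as np
    import onnx_ir as ir

    weight = ir.Value(
        name="weight",
        const_value=ir.tensor(np.zeros((2, 2), dtype=np.float32)),
    )
    graph.initializers.add(weight)  # registers the value under its name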
onnx_ir/_io.py CHANGED
@@ -7,10 +7,11 @@ from __future__ import annotations
 __all__ = ["load", "save"]
 
 import os
+from typing import Callable
 
-import onnx
+import onnx  # noqa: TID251
 
-from onnx_ir import _core, serde
+from onnx_ir import _core, _protocols, serde
 from onnx_ir import external_data as _external_data
 from onnx_ir._polyfill import zip
 
@@ -43,6 +44,8 @@ def save(
     format: str | None = None,
     external_data: str | os.PathLike | None = None,
     size_threshold_bytes: int = 256,
+    callback: Callable[[_protocols.TensorProtocol, _external_data.CallbackInfo], None]
+    | None = None,
 ) -> None:
     """Save an ONNX model to a file.
 
@@ -52,6 +55,30 @@ def save(
     to load the newly saved model, or provide a different external data path that
     is not currently referenced by any tensors in the model.
 
+    .. tip::
+
+        A simple progress bar can be implemented by passing a callback function as the following::
+
+            import onnx_ir as ir
+            import tqdm
+
+            with tqdm.tqdm() as pbar:
+                total_set = False
+
+                def callback(tensor: ir.TensorProtocol, metadata: ir.external_data.CallbackInfo) -> None:
+                    nonlocal total_set
+                    if not total_set:
+                        pbar.total = metadata.total
+                        total_set = True
+
+                    pbar.update()
+                    pbar.set_description(f"Saving {tensor.name} ({tensor.dtype}, {tensor.shape}) at offset {metadata.offset}")
+
+                ir.save(
+                    ...,
+                    callback=callback,
+                )
+
     Args:
         model: The model to save.
         path: The path to save the model to. E.g. "model.onnx".
@@ -65,6 +92,8 @@ def save(
             it will be serialized in the ONNX Proto message.
         size_threshold_bytes: Save to external data if the tensor size in bytes is larger than this threshold.
             Effective only when ``external_data`` is set.
+        callback: A callback function that is called for each tensor that is saved to external data
+            for debugging or logging purposes.
 
     Raises:
         ValueError: If the external data path is an absolute path.
70
99
  ValueError: If the external data path is an absolute path.
@@ -77,12 +106,19 @@ def save(
77
106
  base_dir = os.path.dirname(path)
78
107
 
79
108
  # Store the original initializer values so they can be restored if modify_model=False
80
- initializer_values = tuple(model.graph.initializers.values())
109
+ initializer_values: list[_core.Value] = []
110
+ for graph in model.graphs():
111
+ # Collect from all subgraphs as well
112
+ initializer_values.extend(graph.initializers.values())
81
113
  tensors = [v.const_value for v in initializer_values]
82
114
 
83
115
  try:
84
116
  model = _external_data.unload_from_model(
85
- model, base_dir, external_data, size_threshold_bytes=size_threshold_bytes
117
+ model,
118
+ base_dir,
119
+ external_data,
120
+ size_threshold_bytes=size_threshold_bytes,
121
+ callback=callback,
86
122
  )
87
123
  proto = serde.serialize_model(model)
88
124
  onnx.save(proto, path, format=format)