onnx-ir 0.1.1__py3-none-any.whl → 0.1.2__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release. This version of onnx-ir might be problematic.

onnx_ir/_core.py CHANGED
@@ -251,11 +251,11 @@ def _check_numpy_representation_type(array: np.ndarray, dtype: _enums.DataType)
         or corresponding dtypes from the ``ml_dtypes`` package.
     """
     if dtype in _NON_NUMPY_NATIVE_TYPES:
-        if dtype.itemsize == 2 and array.dtype not in (np.uint16, ml_dtypes.bfloat16):
+        if dtype.bitwidth == 16 and array.dtype not in (np.uint16, ml_dtypes.bfloat16):
             raise TypeError(
                 f"The numpy array dtype must be uint16 or ml_dtypes.bfloat16 (not {array.dtype}) for IR data type {dtype}."
             )
-        if dtype.itemsize == 1 and array.dtype not in (
+        if dtype.bitwidth == 8 and array.dtype not in (
             np.uint8,
             ml_dtypes.float8_e4m3fnuz,
             ml_dtypes.float8_e4m3fn,
@@ -385,9 +385,10 @@ class Tensor(TensorBase, _protocols.TensorProtocol, Generic[TArrayCompatible]):

         Args:
             value: The backing data of the tensor. It can be a numpy array compatible object or a DLPack compatible object.
-                When the dtype is not one of the numpy native dtypes, the value needs
-                to be ``uint8`` for 4-bit and 8-bit data types, and ``uint16`` for bfloat16
-                when the value is a numpy array; ``dtype`` must be specified in this case.
+                When the dtype is not one of the numpy native dtypes, the value can
+                be ``uint8`` (unpacked) or ml_dtypes types for 4-bit and 8-bit data types,
+                and ``uint16`` or ``ml_dtypes.bfloat16`` for bfloat16 when the value is a numpy array;
+                ``dtype`` must be specified in this case.
             dtype: The data type of the tensor. It can be None only when value is a numpy array.
                 Users are responsible for making sure the dtype matches the value when value is not a numpy array.
             shape: The shape of the tensor. If None, the shape is obtained from the value.
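The relaxed dtype contract above can be exercised as follows. This is a hedged sketch against the public ``onnx_ir`` API; the array values are illustrative::

    import ml_dtypes
    import numpy as np
    import onnx_ir as ir

    # Backed directly by ml_dtypes.bfloat16; dtype must still be given
    # explicitly because bfloat16 is not a numpy-native dtype.
    data = np.array([1.0, 2.0, 3.0], dtype=ml_dtypes.bfloat16)
    tensor = ir.Tensor(data, dtype=ir.DataType.BFLOAT16)

    # Equivalent: the same bit pattern as a raw uint16 array is also accepted.
    tensor_from_bits = ir.Tensor(data.view(np.uint16), dtype=ir.DataType.BFLOAT16)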
@@ -421,7 +422,8 @@ class Tensor(TensorBase, _protocols.TensorProtocol, Generic[TArrayCompatible]):
                 self._dtype = _enums.DataType.from_numpy(value.dtype)
             else:
                 raise ValueError(
-                    "The dtype must be specified when the value is not a numpy array."
+                    "The dtype must be specified when the value is not a numpy array. "
+                    f"Value type: {type(value)}"
                 )
         else:
             if isinstance(value, np.ndarray):
@@ -502,7 +504,7 @@ class Tensor(TensorBase, _protocols.TensorProtocol, Generic[TArrayCompatible]):
             _enums.DataType.FLOAT4E2M1,
         }:
             # Pack the array into int4
-            array = _type_casting.pack_int4(array)
+            array = _type_casting.pack_4bitx2(array)
         else:
             assert self.dtype.itemsize == array.itemsize, "Bug: The itemsize should match"
         if not _IS_LITTLE_ENDIAN:
@@ -961,8 +963,151 @@ class LazyTensor(TensorBase, _protocols.TensorProtocol): # pylint: disable=too-
         return self._evaluate().tobytes()


+class PackedTensor(TensorBase, _protocols.TensorProtocol, Generic[TArrayCompatible]):  # pylint: disable=too-many-ancestors
+    """A tensor that stores 4bit datatypes in packed format."""
+
+    __slots__ = (
+        "_dtype",
+        "_raw",
+        "_shape",
+    )
+
+    def __init__(
+        self,
+        value: TArrayCompatible,
+        dtype: _enums.DataType,
+        *,
+        shape: Shape | Sequence[int],
+        name: str | None = None,
+        doc_string: str | None = None,
+        metadata_props: dict[str, str] | None = None,
+    ) -> None:
+        """Initialize a tensor.
+
+        Args:
+            value: The backing data of the tensor. It can be a numpy array compatible object or a DLPack compatible object.
+                The value MUST be packed in an integer dtype.
+            dtype: The data type of the tensor. Must be one of INT4, UINT4, FLOAT4E2M1.
+            shape: The shape of the tensor.
+            name: The name of the tensor.
+            doc_string: The documentation string.
+            metadata_props: The metadata properties.
+
+        Raises:
+            TypeError: If the value is not a numpy array compatible or a DLPack compatible object.
+            TypeError: If the value is a numpy array and the dtype is not uint8 or one of the ml_dtypes dtypes.
+        """
+        super().__init__(name=name, doc_string=doc_string, metadata_props=metadata_props)
+        if not _compatible_with_numpy(value) and not _compatible_with_dlpack(value):
+            raise TypeError(f"Expected an array compatible object, got {type(value)}")
+        self._shape = Shape(shape)
+        self._shape.freeze()
+        if dtype.bitwidth != 4:
+            raise TypeError(
+                f"PackedTensor only supports INT4, UINT4, FLOAT4E2M1, but got {dtype}"
+            )
+        self._dtype = dtype
+        self._raw = value
+
+        if isinstance(value, np.ndarray):
+            if (
+                value.dtype == ml_dtypes.float4_e2m1fn
+                or value.dtype == ml_dtypes.uint4
+                or value.dtype == ml_dtypes.int4
+            ):
+                raise TypeError(
+                    f"PackedTensor expects the value to be packed, but got {value.dtype} which is not packed. "
+                    "Please pack the value or use `onnx_ir.Tensor`."
+                )
+            # Check after shape and dtype is set
+            if value.size != self.nbytes:
+                raise ValueError(
+                    f"Expected the packed array to be {self.nbytes} bytes (from shape {self.shape}), but got {value.nbytes} bytes"
+                )
+
+    def __array__(self, dtype: Any = None, copy: bool = False) -> np.ndarray:
+        return self.numpy()
+
+    def __dlpack__(self, *, stream: Any = None) -> Any:
+        if _compatible_with_dlpack(self._raw):
+            return self._raw.__dlpack__(stream=stream)
+        return self.__array__().__dlpack__(stream=stream)
+
+    def __dlpack_device__(self) -> tuple[int, int]:
+        if _compatible_with_dlpack(self._raw):
+            return self._raw.__dlpack_device__()
+        return self.__array__().__dlpack_device__()
+
+    def __repr__(self) -> str:
+        return f"{self._repr_base()}({self._raw!r}, name={self.name!r})"
+
+    @property
+    def dtype(self) -> _enums.DataType:
+        """The data type of the tensor. Immutable."""
+        return self._dtype
+
+    @property
+    def shape(self) -> Shape:
+        """The shape of the tensor. Immutable."""
+        return self._shape
+
+    @property
+    def raw(self) -> TArrayCompatible:
+        """Backing data of the tensor. Immutable."""
+        return self._raw  # type: ignore[return-value]
+
+    def numpy(self) -> np.ndarray:
+        """Return the tensor as a numpy array.
+
+        When the data type is not supported by numpy, the dtypes from the ``ml_dtypes``
+        package are used. The values can be reinterpreted as bit representations
+        using the ``.view()`` method.
+        """
+        array = self.numpy_packed()
+        # ONNX IR returns the unpacked arrays
+        if self.dtype == _enums.DataType.INT4:
+            return _type_casting.unpack_int4(array, self.shape.numpy())
+        if self.dtype == _enums.DataType.UINT4:
+            return _type_casting.unpack_uint4(array, self.shape.numpy())
+        if self.dtype == _enums.DataType.FLOAT4E2M1:
+            return _type_casting.unpack_float4e2m1(array, self.shape.numpy())
+        raise TypeError(
+            f"PackedTensor only supports INT4, UINT4, FLOAT4E2M1, but got {self.dtype}"
+        )
+
+    def numpy_packed(self) -> npt.NDArray[np.uint8]:
+        """Return the tensor as a packed array."""
+        if isinstance(self._raw, np.ndarray) or _compatible_with_numpy(self._raw):
+            array = np.asarray(self._raw)
+        else:
+            assert _compatible_with_dlpack(self._raw), (
+                f"Bug: Expected DLPack or Numpy compatible objects, got {type(self._raw)}"
+            )
+            array = np.from_dlpack(self._raw)
+        if array.nbytes != self.nbytes:
+            raise ValueError(
+                f"Expected the packed array to be {self.nbytes} bytes (from shape {self.shape}), but got {array.nbytes} bytes"
+            )
+        return array.view(np.uint8)
+
+    def tobytes(self) -> bytes:
+        """Returns the value as bytes encoded in little endian.
+
+        Override this method for more efficient serialization when the raw
+        value is not a numpy array.
+        """
+        array = self.numpy_packed()
+        if not _IS_LITTLE_ENDIAN:
+            array = array.view(array.dtype.newbyteorder("<"))
+        return array.tobytes()
+
+
 class SymbolicDim(_protocols.SymbolicDimProtocol, _display.PrettyPrintable):
-    """Immutable symbolic dimension that can be shared across multiple shapes."""
+    """Immutable symbolic dimension that can be shared across multiple shapes.
+
+    SymbolicDim is used to represent a symbolic (non-integer) dimension in a tensor shape.
+    It is immutable and can be compared or hashed.
+    """

     __slots__ = ("_value",)
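A hedged sketch of how the new ``PackedTensor`` above is meant to be used, assuming it is exported as ``onnx_ir.PackedTensor``. Per ONNX 4-bit packing, the low nibble holds the first element; the byte values are illustrative::

    import numpy as np
    import onnx_ir as ir

    # Four uint4 values [1, 2, 3, 4] packed two per byte:
    # 0x21 -> (1, 2), 0x43 -> (3, 4)
    packed = np.array([0x21, 0x43], dtype=np.uint8)
    # Assumes the class is re-exported as ir.PackedTensor.
    tensor = ir.PackedTensor(packed, ir.DataType.UINT4, shape=[4])

    print(tensor.numpy())         # unpacked values: [1 2 3 4]
    print(tensor.numpy_packed())  # raw packed bytes: [33 67]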
@@ -971,6 +1116,9 @@ class SymbolicDim(_protocols.SymbolicDimProtocol, _display.PrettyPrintable):

         Args:
             value: The value of the dimension. It should not be an int.
+
+        Raises:
+            TypeError: If value is an int.
         """
         if isinstance(value, int):
             raise TypeError(
@@ -980,15 +1128,18 @@ class SymbolicDim(_protocols.SymbolicDimProtocol, _display.PrettyPrintable):
         self._value = value

     def __eq__(self, other: object) -> bool:
+        """Check equality with another SymbolicDim or string/None."""
         if not isinstance(other, SymbolicDim):
             return self.value == other
         return self.value == other.value

     def __hash__(self) -> int:
+        """Return the hash of the symbolic dimension value."""
         return hash(self.value)

     @property
     def value(self) -> str | None:
+        """The value of the symbolic dimension (string or None)."""
         return self._value

     def __str__(self) -> str:
@@ -999,7 +1150,14 @@ class SymbolicDim(_protocols.SymbolicDimProtocol, _display.PrettyPrintable):


 def _is_int_compatible(value: object) -> TypeIs[SupportsInt]:
-    """Return True if the value is int compatible."""
+    """Check if the value is compatible with int (i.e., can be safely cast to int).
+
+    Args:
+        value: The value to check.
+
+    Returns:
+        True if the value is an int or has an __int__ method, False otherwise.
+    """
     if isinstance(value, int):
         return True
     if hasattr(value, "__int__"):
@@ -1011,7 +1169,17 @@ def _is_int_compatible(value: object) -> TypeIs[SupportsInt]:
 def _maybe_convert_to_symbolic_dim(
     dim: int | SupportsInt | SymbolicDim | str | None,
 ) -> SymbolicDim | int:
-    """Convert the value to a SymbolicDim if it is not an int."""
+    """Convert the value to a SymbolicDim if it is not an int.
+
+    Args:
+        dim: The dimension value, which can be int, str, None, or SymbolicDim.
+
+    Returns:
+        An int or SymbolicDim instance.
+
+    Raises:
+        TypeError: If the value is not int, str, None, or SymbolicDim.
+    """
     if dim is None or isinstance(dim, str):
         return SymbolicDim(dim)
     if _is_int_compatible(dim):
@@ -1024,21 +1192,20 @@ def _maybe_convert_to_symbolic_dim(


 class Shape(_protocols.ShapeProtocol, _display.PrettyPrintable):
-    """The shape of a tensor, including its dimensions and optional denotations.
-
-    The :class:`Shape` stores the dimensions of a tensor, which can be integers, None (unknown), or
-    symbolic dimensions.
+    """Represents the shape of a tensor, including its dimensions and optional denotations.

-    A shape can be compared to another shape or plain Python list.
+    The :class:`Shape` class stores the dimensions of a tensor, which can be integers, None (unknown), or
+    symbolic dimensions. It provides methods for querying and manipulating the shape, as well as for comparing
+    shapes to other shapes or plain Python lists.

     A shape can be frozen (made immutable). When the shape is frozen, it cannot be
     unfrozen, making it suitable to be shared across tensors or values.
-    Call :method:`freeze` to freeze the shape.
+    Call :meth:`freeze` to freeze the shape.

-    To update the dimension of a frozen shape, call :method:`copy` to create a
+    To update the dimension of a frozen shape, call :meth:`copy` to create a
     new shape with the same dimensions that can be modified.

-    Use :method:`get_denotation` and :method:`set_denotation` to access and modify the denotations.
+    Use :meth:`get_denotation` and :meth:`set_denotation` to access and modify the denotations.

     Example::

@@ -1066,7 +1233,7 @@ class Shape(_protocols.ShapeProtocol, _display.PrettyPrintable):

     Attributes:
         dims: A tuple of dimensions representing the shape.
-            Each dimension can be an integer, None or a :class:`SymbolicDim`.
+            Each dimension can be an integer, None, or a :class:`SymbolicDim`.
         frozen: Indicates whether the shape is immutable. When frozen, the shape
             cannot be modified or unfrozen.
     """
@@ -1121,7 +1288,7 @@ class Shape(_protocols.ShapeProtocol, _display.PrettyPrintable):
         """Whether the shape is frozen.

         When the shape is frozen, it cannot be unfrozen, making it suitable to be shared.
-        Call :method:`freeze` to freeze the shape. Call :method:`copy` to create a
+        Call :meth:`freeze` to freeze the shape. Call :meth:`copy` to create a
         new shape with the same dimensions that can be modified.
         """
         return self._frozen
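A hedged sketch of the freeze/copy workflow the docstring describes; ``"N"`` is an arbitrary symbolic dimension name::

    import onnx_ir as ir

    shape = ir.Shape([1, "N", None])
    shape.freeze()
    assert shape.frozen  # frozen shapes cannot be unfrozen

    # To change a dimension, copy first to obtain a mutable shape.
    mutable = shape.copy()
    mutable[0] = 2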
@@ -1289,19 +1456,24 @@ def _normalize_domain(domain: str) -> str:
 class Node(_protocols.NodeProtocol, _display.PrettyPrintable):
     """IR Node.

-    If the ``graph`` is provided, the node will be added to the graph. Otherwise,
-    user is responsible to call ``graph.append(node)`` (or other mutation methods
+    .. tip::
+        For a more convenient way (that supports Python objects
+        as attributes) to create a node, use the :func:`onnx_ir.node` constructor.
+
+    If ``graph`` is provided, the node will be added to the graph. Otherwise,
+    the user is responsible for calling ``graph.append(node)`` (or other mutation methods
     in :class:`Graph`) to add the node to the graph.

-    After the node is initialized, it will add itself as a user of the input values.
+    After the node is initialized, it will add itself as a user of its input values.

     The output values of the node are created during node initialization and are immutable.
-    To change the output values, create a new node and replace the each of the inputs of ``output.uses()`` with
-    the new output values by calling :meth:`replace_input_with` on the using nodes
-    of this node's outputs.
+    To change the output values, create a new node and, for each use of the old outputs (``output.uses()``),
+    replace the input in the consuming node by calling :meth:`replace_input_with`.
+    You can also use the :func:`~onnx_ir.convenience.replace_all_uses_with` method
+    to replace all uses of the output values.

-    .. note:
-        When the ``domain`` is `"ai.onnx"`, it is normalized to `""`.
+    .. note::
+        When the ``domain`` is ``"ai.onnx"``, it is normalized to ``""``.
     """

     __slots__ = (
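A hedged sketch contrasting the core ``Node`` constructor with the :func:`onnx_ir.node` convenience constructor mentioned in the tip; the operator names and the keyword arguments reflect my reading of the API and should be treated as assumptions::

    import onnx_ir as ir

    x = ir.Value(name="x", type=ir.TensorType(ir.DataType.FLOAT))

    # Core constructor: domain, op_type, inputs.
    relu = ir.Node("", "Relu", inputs=[x])

    # Convenience constructor: plain Python objects become attributes.
    flat = ir.node("Flatten", inputs=relu.outputs, attributes={"axis": 1})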
@@ -1339,7 +1511,7 @@ class Node(_protocols.NodeProtocol, _display.PrettyPrintable):

     Args:
         domain: The domain of the operator. For onnx operators, this is an empty string.
-            When it is `"ai.onnx"`, it is normalized to `""`.
+            When it is ``"ai.onnx"``, it is normalized to ``""``.
         op_type: The name of the operator.
         inputs: The input values. When an input is ``None``, it is an empty input.
         attributes: The attributes. RefAttr can be used only when the node is defined in a Function.
@@ -1632,7 +1804,15 @@ class Node(_protocols.NodeProtocol, _display.PrettyPrintable):

     @property
     def attributes(self) -> _graph_containers.Attributes:
-        """The attributes of the node."""
+        """The attributes of the node as ``dict[str, Attr]`` with additional access methods.
+
+        Use it as a dictionary with keys being the attribute names and values being the
+        :class:`Attr` objects.
+
+        Use ``node.attributes.add(attr)`` to add an attribute to the node.
+        Use ``node.attributes.get_int(name, default)`` to get an integer attribute value.
+        Refer to :class:`~onnx_ir._graph_containers.Attributes` for more methods.
+        """
         return self._attributes

     @property
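A hedged sketch of the dictionary-style and typed access described above (``Flatten`` and the attribute values are illustrative)::

    import onnx_ir as ir

    x = ir.Value(name="x")
    node = ir.node("Flatten", inputs=[x], attributes={"axis": 1})

    axis_attr = node.attributes["axis"]        # an Attr object
    axis = node.attributes.get_int("axis", 0)  # typed access with a default -> 1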
@@ -1799,12 +1979,13 @@ class Value(_protocols.ValueProtocol, _display.PrettyPrintable):
     The index of the output of the node that produces the value can be accessed with
     :meth:`index`.

-    To find all the nodes that use this value as an input, call :meth:`uses`.
+    To find all the nodes that use this value as an input, call :meth:`uses`. Consuming
+    nodes can be obtained with :meth:`consumers`.

     To check if the value is an input, output or initializer of a graph,
     use :meth:`is_graph_input`, :meth:`is_graph_output` or :meth:`is_initializer`.

-    Use :meth:`graph` to get the graph that owns the value.
+    Use :attr:`graph` to get the graph that owns the value.
     """

     __slots__ = (
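A hedged sketch of the :meth:`uses` / :meth:`consumers` distinction mentioned above, assuming ``uses()`` yields (node, input_index) pairs::

    import onnx_ir as ir

    x = ir.Value(name="x")
    relu = ir.node("Relu", inputs=[x])
    neg = ir.node("Neg", inputs=[x])

    for node, index in x.uses():               # each use site with its input slot
        print(node.op_type, index)
    print([n.op_type for n in x.consumers()])  # just the consuming nodes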
@@ -2221,7 +2402,7 @@ class Graph(_protocols.GraphProtocol, Sequence[Node], _display.PrettyPrintable):

     @property
     def initializers(self) -> _graph_containers.GraphInitializers:
-        """The initializers of the graph as a ``MutableMapping[str, Value]``.
+        """The initializers of the graph as a ``dict[str, Value]``.

         The keys are the names of the initializers. The values are the :class:`Value` objects.

@@ -2357,6 +2538,28 @@ class Graph(_protocols.GraphProtocol, Sequence[Node], _display.PrettyPrintable):
         # NOTE: This is a method specific to Graph, not required by the protocol unless proven
         return len(self)

+    def all_nodes(self) -> Iterator[Node]:
+        """Get all nodes in the graph and its subgraphs in O(#nodes + #attributes) time.
+
+        This is an alias for ``onnx_ir.traversal.RecursiveGraphIterator(graph)``.
+        Consider using
+        :class:`onnx_ir.traversal.RecursiveGraphIterator` for more advanced
+        traversals on nodes.
+        """
+        # NOTE: This is a method specific to Graph, not required by the protocol unless proven
+        return onnx_ir.traversal.RecursiveGraphIterator(self)
+
+    def subgraphs(self) -> Iterator[Graph]:
+        """Get all subgraphs in the graph in O(#nodes + #attributes) time."""
+        seen_graphs: set[Graph] = set()
+        for node in onnx_ir.traversal.RecursiveGraphIterator(self):
+            graph = node.graph
+            if graph is self:
+                continue
+            if graph is not None and graph not in seen_graphs:
+                seen_graphs.add(graph)
+                yield graph
+
     # Mutation methods
     def append(self, node: Node, /) -> None:
         """Append a node to the graph in O(1) time.
@@ -2862,7 +3065,7 @@ Model(
         """Get all graphs and subgraphs in the model.

         This is a convenience method to traverse the model. Consider using
-        `onnx_ir.traversal.RecursiveGraphIterator` for more advanced
+        :class:`onnx_ir.traversal.RecursiveGraphIterator` for more advanced
         traversals on nodes.
         """
         # NOTE(justinchuby): Given
@@ -2871,11 +3074,8 @@ Model(
         # (3) Users familiar with onnxruntime optimization tools expect this method
         # I created this method as a core method instead of an iterator in
         # `traversal.py`.
-        seen_graphs: set[Graph] = set()
-        for node in onnx_ir.traversal.RecursiveGraphIterator(self.graph):
-            if node.graph is not None and node.graph not in seen_graphs:
-                seen_graphs.add(node.graph)
-                yield node.graph
+        yield self.graph
+        yield from self.graph.subgraphs()


 class Function(_protocols.FunctionProtocol, Sequence[Node], _display.PrettyPrintable):
3009
3209
  def metadata_props(self) -> dict[str, str]:
3010
3210
  return self._graph.metadata_props
3011
3211
 
3212
+ def all_nodes(self) -> Iterator[Node]:
3213
+ """Get all nodes in the graph and its subgraphs in O(#nodes + #attributes) time.
3214
+
3215
+ This is an alias for ``onnx_ir.traversal.RecursiveGraphIterator(graph)``.
3216
+ Consider using
3217
+ :class:`onnx_ir.traversal.RecursiveGraphIterator` for more advanced
3218
+ traversals on nodes.
3219
+ """
3220
+ # NOTE: This is a method specific to Graph, not required by the protocol unless proven
3221
+ return onnx_ir.traversal.RecursiveGraphIterator(self)
3222
+
3223
+ def subgraphs(self) -> Iterator[Graph]:
3224
+ """Get all subgraphs in the function in O(#nodes + #attributes) time."""
3225
+ seen_graphs: set[Graph] = set()
3226
+ for node in onnx_ir.traversal.RecursiveGraphIterator(self):
3227
+ graph = node.graph
3228
+ if graph is self._graph:
3229
+ continue
3230
+ if graph is not None and graph not in seen_graphs:
3231
+ seen_graphs.add(graph)
3232
+ yield graph
3233
+
3012
3234
  # Mutation methods
3013
3235
  def append(self, node: Node, /) -> None:
3014
3236
  """Append a node to the function in O(1) time."""
onnx_ir/_enums.py CHANGED
@@ -114,7 +114,18 @@ class DataType(enum.IntEnum):
     @property
     def itemsize(self) -> float:
         """Returns the size of the data type in bytes."""
-        return _ITEMSIZE_MAP[self]
+        return self.bitwidth / 8
+
+    @property
+    def bitwidth(self) -> int:
+        """Returns the bit width of the data type.
+
+        Raises:
+            TypeError: If the data type is not supported.
+        """
+        if self not in _BITWIDTH_MAP:
+            raise TypeError(f"Bitwidth not available for ONNX data type: {self}")
+        return _BITWIDTH_MAP[self]

     def numpy(self) -> np.dtype:
         """Returns the numpy dtype for the ONNX data type.
@@ -163,30 +174,29 @@ class DataType(enum.IntEnum):
         return self.__repr__()


-_ITEMSIZE_MAP = {
-    DataType.FLOAT: 4,
-    DataType.UINT8: 1,
-    DataType.INT8: 1,
-    DataType.UINT16: 2,
-    DataType.INT16: 2,
-    DataType.INT32: 4,
-    DataType.INT64: 8,
-    DataType.STRING: 1,
-    DataType.BOOL: 1,
-    DataType.FLOAT16: 2,
-    DataType.DOUBLE: 8,
-    DataType.UINT32: 4,
-    DataType.UINT64: 8,
-    DataType.COMPLEX64: 8,
-    DataType.COMPLEX128: 16,
-    DataType.BFLOAT16: 2,
-    DataType.FLOAT8E4M3FN: 1,
-    DataType.FLOAT8E4M3FNUZ: 1,
-    DataType.FLOAT8E5M2: 1,
-    DataType.FLOAT8E5M2FNUZ: 1,
-    DataType.UINT4: 0.5,
-    DataType.INT4: 0.5,
-    DataType.FLOAT4E2M1: 0.5,
+_BITWIDTH_MAP = {
+    DataType.FLOAT: 32,
+    DataType.UINT8: 8,
+    DataType.INT8: 8,
+    DataType.UINT16: 16,
+    DataType.INT16: 16,
+    DataType.INT32: 32,
+    DataType.INT64: 64,
+    DataType.BOOL: 8,
+    DataType.FLOAT16: 16,
+    DataType.DOUBLE: 64,
+    DataType.UINT32: 32,
+    DataType.UINT64: 64,
+    DataType.COMPLEX64: 64,  # 2 * 32
+    DataType.COMPLEX128: 128,  # 2 * 64
+    DataType.BFLOAT16: 16,
+    DataType.FLOAT8E4M3FN: 8,
+    DataType.FLOAT8E4M3FNUZ: 8,
+    DataType.FLOAT8E5M2: 8,
+    DataType.FLOAT8E5M2FNUZ: 8,
+    DataType.UINT4: 4,
+    DataType.INT4: 4,
+    DataType.FLOAT4E2M1: 4,
 }

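A hedged sketch of the new property and the ``itemsize`` now derived from it. The values follow the map above; note that ``STRING`` is no longer in the map, so ``bitwidth`` (and hence ``itemsize``) raises for it::

    import onnx_ir as ir

    assert ir.DataType.FLOAT.bitwidth == 32
    assert ir.DataType.FLOAT.itemsize == 4.0

    # Sub-byte types report an exact bit width and a fractional byte size.
    assert ir.DataType.INT4.bitwidth == 4
    assert ir.DataType.INT4.itemsize == 0.5

    try:
        ir.DataType.STRING.bitwidth
    except TypeError as e:
        print(e)  # Bitwidth not available for ONNX data type: ...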
onnx_ir/_graph_containers.py CHANGED
@@ -216,7 +216,7 @@ class GraphOutputs(_GraphIO):


 class GraphInitializers(collections.UserDict[str, "_core.Value"]):
-    """The initializers of a Graph."""
+    """The initializers of a Graph as ``dict[str, Value]`` with additional mutation methods."""

     def __init__(self, graph: _core.Graph, dict=None, /, **kwargs):
         # Perform checks first in _set_graph before modifying the data structure with super().__init__()
@@ -291,7 +291,7 @@ class GraphInitializers(collections.UserDict[str, "_core.Value"]):


 class Attributes(collections.UserDict[str, "_core.Attr"]):
-    """The attributes of a Node."""
+    """The attributes of a Node as ``dict[str, Attr]`` with additional access methods."""

     def __init__(self, attrs: Iterable[_core.Attr]):
         super().__init__({attr.name: attr for attr in attrs})
onnx_ir/_io.py CHANGED
@@ -7,10 +7,11 @@ from __future__ import annotations
 __all__ = ["load", "save"]

 import os
+from typing import Callable

-import onnx
+import onnx  # noqa: TID251

-from onnx_ir import _core, serde
+from onnx_ir import _core, _protocols, serde
 from onnx_ir import external_data as _external_data
 from onnx_ir._polyfill import zip

@@ -43,6 +44,8 @@ def save(
     format: str | None = None,
     external_data: str | os.PathLike | None = None,
     size_threshold_bytes: int = 256,
+    callback: Callable[[_protocols.TensorProtocol, _external_data.CallbackInfo], None]
+    | None = None,
 ) -> None:
     """Save an ONNX model to a file.

@@ -52,6 +55,30 @@ def save(
     to load the newly saved model, or provide a different external data path that
     is not currently referenced by any tensors in the model.

+    .. tip::
+
+        A simple progress bar can be implemented by passing a callback function as follows::
+
+            import onnx_ir as ir
+            import tqdm
+
+            with tqdm.tqdm() as pbar:
+                total_set = False
+
+                def callback(tensor: ir.TensorProtocol, metadata: ir.external_data.CallbackInfo) -> None:
+                    nonlocal total_set
+                    if not total_set:
+                        pbar.total = metadata.total
+                        total_set = True
+
+                    pbar.update()
+                    pbar.set_description(f"Saving {tensor.name} ({tensor.dtype}, {tensor.shape}) at offset {metadata.offset}")
+
+                ir.save(
+                    ...,
+                    callback=callback,
+                )
+
     Args:
         model: The model to save.
         path: The path to save the model to. E.g. "model.onnx".
@@ -65,6 +92,8 @@ def save(
             it will be serialized in the ONNX Proto message.
         size_threshold_bytes: Save to external data if the tensor size in bytes is larger than this threshold.
             Effective only when ``external_data`` is set.
+        callback: A callback function that is called for each tensor that is saved to external data
+            for debugging or logging purposes.

     Raises:
         ValueError: If the external data path is an absolute path.
@@ -77,12 +106,19 @@ def save(
     base_dir = os.path.dirname(path)

     # Store the original initializer values so they can be restored if modify_model=False
-    initializer_values = tuple(model.graph.initializers.values())
+    initializer_values: list[_core.Value] = []
+    for graph in model.graphs():
+        # Collect from all subgraphs as well
+        initializer_values.extend(graph.initializers.values())
     tensors = [v.const_value for v in initializer_values]

     try:
         model = _external_data.unload_from_model(
-            model, base_dir, external_data, size_threshold_bytes=size_threshold_bytes
+            model,
+            base_dir,
+            external_data,
+            size_threshold_bytes=size_threshold_bytes,
+            callback=callback,
         )
         proto = serde.serialize_model(model)
         onnx.save(proto, path, format=format)
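A hedged end-to-end sketch of the extended ``save``; the file names are placeholders::

    import onnx_ir as ir

    model = ir.load("model.onnx")  # placeholder path

    def log_tensor(tensor, info) -> None:
        # Called once per tensor written to the external data file.
        print(f"offset={info.offset} name={tensor.name}")

    # Initializers above 1 KiB, now including those in subgraphs,
    # are written to model.data next to the output file.
    ir.save(
        model,
        "model_ext.onnx",
        external_data="model.data",
        size_threshold_bytes=1024,
        callback=log_tensor,
    )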
onnx_ir/_type_casting.py CHANGED
@@ -15,7 +15,7 @@ if typing.TYPE_CHECKING:
     import numpy.typing as npt


-def pack_int4(array: np.ndarray) -> npt.NDArray[np.uint8]:
+def pack_4bitx2(array: np.ndarray) -> npt.NDArray[np.uint8]:
     """Convert a numpy array to a flattened, packed int4/uint4 array. Elements must be in the correct range."""
     # Create a 1D copy
     array_flat = array.ravel().view(np.uint8).copy()
@@ -40,6 +40,7 @@ def _unpack_uint4_as_uint8(
     Returns:
         A numpy array of int8/uint8 reshaped to dims.
     """
+    assert data.dtype == np.uint8, "Input data must be of type uint8"
     result = np.empty([data.size * 2], dtype=data.dtype)
     array_low = data & np.uint8(0x0F)
     array_high = data & np.uint8(0xF0)
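A hedged sketch of the renamed packer. It is private API in ``_type_casting``; the input here is unpacked 4-bit values stored one per byte::

    import numpy as np
    from onnx_ir import _type_casting

    # Unpacked 4-bit values, one per uint8 byte.
    values = np.array([1, 2, 3, 4], dtype=np.uint8)

    # Two elements per byte, low nibble first: [1, 2, 3, 4] -> [0x21, 0x43]
    packed = _type_casting.pack_4bitx2(values)
    print([hex(b) for b in packed])  # ['0x21', '0x43']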