onnx-ir 0.1.1__py3-none-any.whl → 0.1.2__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- onnx_ir/__init__.py +5 -2
- onnx_ir/_convenience/__init__.py +125 -4
- onnx_ir/_convenience/_constructors.py +6 -2
- onnx_ir/_core.py +261 -39
- onnx_ir/_enums.py +35 -25
- onnx_ir/_graph_containers.py +2 -2
- onnx_ir/_io.py +40 -4
- onnx_ir/_type_casting.py +2 -1
- onnx_ir/_version_utils.py +5 -48
- onnx_ir/convenience.py +3 -1
- onnx_ir/external_data.py +43 -3
- onnx_ir/passes/_pass_infra.py +1 -1
- onnx_ir/passes/common/_c_api_utils.py +1 -1
- onnx_ir/passes/common/onnx_checker.py +1 -1
- onnx_ir/passes/common/shape_inference.py +1 -1
- onnx_ir/passes/common/unused_removal.py +1 -1
- onnx_ir/serde.py +171 -6
- {onnx_ir-0.1.1.dist-info → onnx_ir-0.1.2.dist-info}/METADATA +22 -4
- {onnx_ir-0.1.1.dist-info → onnx_ir-0.1.2.dist-info}/RECORD +22 -22
- {onnx_ir-0.1.1.dist-info → onnx_ir-0.1.2.dist-info}/WHEEL +0 -0
- {onnx_ir-0.1.1.dist-info → onnx_ir-0.1.2.dist-info}/licenses/LICENSE +0 -0
- {onnx_ir-0.1.1.dist-info → onnx_ir-0.1.2.dist-info}/top_level.txt +0 -0
onnx_ir/_core.py
CHANGED
@@ -251,11 +251,11 @@ def _check_numpy_representation_type(array: np.ndarray, dtype: _enums.DataType)
         or corresponding dtypes from the ``ml_dtype`` package.
     """
     if dtype in _NON_NUMPY_NATIVE_TYPES:
-        if dtype.
+        if dtype.bitwidth == 16 and array.dtype not in (np.uint16, ml_dtypes.bfloat16):
             raise TypeError(
                 f"The numpy array dtype must be uint16 or ml_dtypes.bfloat16 (not {array.dtype}) for IR data type {dtype}."
             )
-        if dtype.
+        if dtype.bitwidth == 8 and array.dtype not in (
            np.uint8,
            ml_dtypes.float8_e4m3fnuz,
            ml_dtypes.float8_e4m3fn,
@@ -385,9 +385,10 @@ class Tensor(TensorBase, _protocols.TensorProtocol, Generic[TArrayCompatible]):
 
         Args:
             value: The backing data of the tensor. It can be a numpy array compatible object or a DLPack compatible object.
-                When the dtype is not one of the numpy native dtypes, the value
-
-                when the value is a numpy array;
+                When the dtype is not one of the numpy native dtypes, the value can
+                be ``uint8`` (unpacked) or ml_dtypes types for 4-bit and 8-bit data types,
+                and ``uint16`` or ml_dtype.bfloat16 for bfloat16 when the value is a numpy array;
+                ``dtype`` must be specified in this case.
             dtype: The data type of the tensor. It can be None only when value is a numpy array.
                 Users are responsible for making sure the dtype matches the value when value is not a numpy array.
             shape: The shape of the tensor. If None, the shape is obtained from the value.
@@ -421,7 +422,8 @@ class Tensor(TensorBase, _protocols.TensorProtocol, Generic[TArrayCompatible]):
                 self._dtype = _enums.DataType.from_numpy(value.dtype)
             else:
                 raise ValueError(
-                    "The dtype must be specified when the value is not a numpy array."
+                    "The dtype must be specified when the value is not a numpy array. "
+                    "Value type: {type(value)}"
                 )
         else:
             if isinstance(value, np.ndarray):
@@ -502,7 +504,7 @@ class Tensor(TensorBase, _protocols.TensorProtocol, Generic[TArrayCompatible]):
             _enums.DataType.FLOAT4E2M1,
         }:
             # Pack the array into int4
-            array = _type_casting.
+            array = _type_casting.pack_4bitx2(array)
         else:
             assert self.dtype.itemsize == array.itemsize, "Bug: The itemsize should match"
         if not _IS_LITTLE_ENDIAN:
@@ -961,8 +963,151 @@ class LazyTensor(TensorBase, _protocols.TensorProtocol): # pylint: disable=too-
         return self._evaluate().tobytes()
 
 
+class PackedTensor(TensorBase, _protocols.TensorProtocol, Generic[TArrayCompatible]):  # pylint: disable=too-many-ancestors
+    """A tensor that stores 4bit datatypes in packed format."""
+
+    __slots__ = (
+        "_dtype",
+        "_raw",
+        "_shape",
+    )
+
+    def __init__(
+        self,
+        value: TArrayCompatible,
+        dtype: _enums.DataType,
+        *,
+        shape: Shape | Sequence[int],
+        name: str | None = None,
+        doc_string: str | None = None,
+        metadata_props: dict[str, str] | None = None,
+    ) -> None:
+        """Initialize a tensor.
+
+        Args:
+            value: The backing data of the tensor. It can be a numpy array compatible object or a DLPack compatible object.
+                The value MUST be packed in an integer dtype.
+            dtype: The data type of the tensor. Must be one of INT4, UINT4, FLOAT4E2M1.
+            shape: The shape of the tensor.
+            name: The name of the tensor.
+            doc_string: The documentation string.
+            metadata_props: The metadata properties.
+
+        Raises:
+            TypeError: If the value is not a numpy array compatible or a DLPack compatible object.
+            TypeError: If the value is a numpy array and the dtype is not uint8 or one of the ml_dtypes dtypes.
+        """
+        super().__init__(name=name, doc_string=doc_string, metadata_props=metadata_props)
+        if not _compatible_with_numpy(value) and not _compatible_with_dlpack(value):
+            raise TypeError(f"Expected an array compatible object, got {type(value)}")
+        self._shape = Shape(shape)
+        self._shape.freeze()
+        if dtype.bitwidth != 4:
+            raise TypeError(
+                f"PackedTensor only supports INT4, UINT4, FLOAT4E2M1, but got {dtype}"
+            )
+        self._dtype = dtype
+        self._raw = value
+
+        if isinstance(value, np.ndarray):
+            if (
+                value.dtype == ml_dtypes.float4_e2m1fn
+                or value.dtype == ml_dtypes.uint4
+                or value.dtype == ml_dtypes.int4
+            ):
+                raise TypeError(
+                    f"PackedTensor expects the value to be packed, but got {value.dtype} which is not packed. "
+                    "Please pack the value or use `onnx_ir.Tensor`."
+                )
+            # Check after shape and dtype is set
+            if value.size != self.nbytes:
+                raise ValueError(
+                    f"Expected the packed array to be {self.nbytes} bytes (from shape {self.shape}), but got {value.nbytes} bytes"
+                )
+
+    def __array__(self, dtype: Any = None, copy: bool = False) -> np.ndarray:
+        return self.numpy()
+
+    def __dlpack__(self, *, stream: Any = None) -> Any:
+        if _compatible_with_dlpack(self._raw):
+            return self._raw.__dlpack__(stream=stream)
+        return self.__array__().__dlpack__(stream=stream)
+
+    def __dlpack_device__(self) -> tuple[int, int]:
+        if _compatible_with_dlpack(self._raw):
+            return self._raw.__dlpack_device__()
+        return self.__array__().__dlpack_device__()
+
+    def __repr__(self) -> str:
+        return f"{self._repr_base()}({self._raw!r}, name={self.name!r})"
+
+    @property
+    def dtype(self) -> _enums.DataType:
+        """The data type of the tensor. Immutable."""
+        return self._dtype
+
+    @property
+    def shape(self) -> Shape:
+        """The shape of the tensor. Immutable."""
+        return self._shape
+
+    @property
+    def raw(self) -> TArrayCompatible:
+        """Backing data of the tensor. Immutable."""
+        return self._raw  # type: ignore[return-value]
+
+    def numpy(self) -> np.ndarray:
+        """Return the tensor as a numpy array.
+
+        When the data type is not supported by numpy, the dtypes from the ``ml_dtype``
+        package are used. The values can be reinterpreted as bit representations
+        using the ``.view()`` method.
+        """
+        array = self.numpy_packed()
+        # ONNX IR returns the unpacked arrays
+        if self.dtype == _enums.DataType.INT4:
+            return _type_casting.unpack_int4(array, self.shape.numpy())
+        if self.dtype == _enums.DataType.UINT4:
+            return _type_casting.unpack_uint4(array, self.shape.numpy())
+        if self.dtype == _enums.DataType.FLOAT4E2M1:
+            return _type_casting.unpack_float4e2m1(array, self.shape.numpy())
+        raise TypeError(
+            f"PackedTensor only supports INT4, UINT4, FLOAT4E2M1, but got {self.dtype}"
+        )
+
+    def numpy_packed(self) -> npt.NDArray[np.uint8]:
+        """Return the tensor as a packed array."""
+        if isinstance(self._raw, np.ndarray) or _compatible_with_numpy(self._raw):
+            array = np.asarray(self._raw)
+        else:
+            assert _compatible_with_dlpack(self._raw), (
+                f"Bug: Expected DLPack or Numpy compatible objects, got {type(self._raw)}"
+            )
+            array = np.from_dlpack(self._raw)
+        if array.nbytes != self.nbytes:
+            raise ValueError(
+                f"Expected the packed array to be {self.nbytes} bytes (from shape {self.shape}), but got {array.nbytes} bytes"
+            )
+        return array.view(np.uint8)
+
+    def tobytes(self) -> bytes:
+        """Returns the value as bytes encoded in little endian.
+
+        Override this method for more efficient serialization when the raw
+        value is not a numpy array.
+        """
+        array = self.numpy_packed()
+        if not _IS_LITTLE_ENDIAN:
+            array = array.view(array.dtype.newbyteorder("<"))
+        return array.tobytes()
+
+
 class SymbolicDim(_protocols.SymbolicDimProtocol, _display.PrettyPrintable):
-    """Immutable symbolic dimension that can be shared across multiple shapes.
+    """Immutable symbolic dimension that can be shared across multiple shapes.
+
+    SymbolicDim is used to represent a symbolic (non-integer) dimension in a tensor shape.
+    It is immutable and can be compared or hashed.
+    """
 
     __slots__ = ("_value",)
 
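The hunk above introduces ``PackedTensor``, which keeps 4-bit data packed two elements per byte and unpacks on demand. Below is a minimal usage sketch; the ``__init__.py`` changes in this release suggest a new top-level export, but treating the class as available via ``ir.PackedTensor`` is an assumption here.

```python
# Sketch only: assumes onnx_ir re-exports PackedTensor and DataType at the top level.
import numpy as np
import onnx_ir as ir

# Six UINT4 values packed two-per-byte (low nibble first): (1,2), (3,4), (5,6).
packed = np.array([0x21, 0x43, 0x65], dtype=np.uint8)
tensor = ir.PackedTensor(packed, ir.DataType.UINT4, shape=[6], name="weights")

print(tensor.numpy())         # unpacked ml_dtypes.uint4 array: [1 2 3 4 5 6]
print(tensor.numpy_packed())  # the raw packed uint8 bytes, shape (3,)
print(tensor.tobytes())       # b'!Ce' (0x21, 0x43, 0x65), little endian
```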
@@ -971,6 +1116,9 @@ class SymbolicDim(_protocols.SymbolicDimProtocol, _display.PrettyPrintable):
 
         Args:
             value: The value of the dimension. It should not be an int.
+
+        Raises:
+            TypeError: If value is an int.
         """
         if isinstance(value, int):
             raise TypeError(
@@ -980,15 +1128,18 @@ class SymbolicDim(_protocols.SymbolicDimProtocol, _display.PrettyPrintable):
         self._value = value
 
     def __eq__(self, other: object) -> bool:
+        """Check equality with another SymbolicDim or string/None."""
         if not isinstance(other, SymbolicDim):
             return self.value == other
         return self.value == other.value
 
     def __hash__(self) -> int:
+        """Return the hash of the symbolic dimension value."""
         return hash(self.value)
 
     @property
     def value(self) -> str | None:
+        """The value of the symbolic dimension (string or None)."""
         return self._value
 
     def __str__(self) -> str:
@@ -999,7 +1150,14 @@ class SymbolicDim(_protocols.SymbolicDimProtocol, _display.PrettyPrintable):
 
 
 def _is_int_compatible(value: object) -> TypeIs[SupportsInt]:
-    """
+    """Check if the value is compatible with int (i.e., can be safely cast to int).
+
+    Args:
+        value: The value to check.
+
+    Returns:
+        True if the value is an int or has an __int__ method, False otherwise.
+    """
     if isinstance(value, int):
         return True
     if hasattr(value, "__int__"):
@@ -1011,7 +1169,17 @@ def _is_int_compatible(value: object) -> TypeIs[SupportsInt]:
 def _maybe_convert_to_symbolic_dim(
     dim: int | SupportsInt | SymbolicDim | str | None,
 ) -> SymbolicDim | int:
-    """Convert the value to a SymbolicDim if it is not an int.
+    """Convert the value to a SymbolicDim if it is not an int.
+
+    Args:
+        dim: The dimension value, which can be int, str, None, or SymbolicDim.
+
+    Returns:
+        An int or SymbolicDim instance.
+
+    Raises:
+        TypeError: If the value is not int, str, None, or SymbolicDim.
+    """
     if dim is None or isinstance(dim, str):
         return SymbolicDim(dim)
     if _is_int_compatible(dim):
@@ -1024,21 +1192,20 @@ def _maybe_convert_to_symbolic_dim(
 
 
 class Shape(_protocols.ShapeProtocol, _display.PrettyPrintable):
-    """
-
-    The :class:`Shape` stores the dimensions of a tensor, which can be integers, None (unknown), or
-    symbolic dimensions.
+    """Represents the shape of a tensor, including its dimensions and optional denotations.
 
-
+    The :class:`Shape` class stores the dimensions of a tensor, which can be integers, None (unknown), or
+    symbolic dimensions. It provides methods for querying and manipulating the shape, as well as for comparing
+    shapes to other shapes or plain Python lists.
 
     A shape can be frozen (made immutable). When the shape is frozen, it cannot be
     unfrozen, making it suitable to be shared across tensors or values.
-    Call :
+    Call :meth:`freeze` to freeze the shape.
 
-    To update the dimension of a frozen shape, call :
+    To update the dimension of a frozen shape, call :meth:`copy` to create a
     new shape with the same dimensions that can be modified.
 
-    Use :
+    Use :meth:`get_denotation` and :meth:`set_denotation` to access and modify the denotations.
 
     Example::
 
@@ -1066,7 +1233,7 @@ class Shape(_protocols.ShapeProtocol, _display.PrettyPrintable):
 
     Attributes:
         dims: A tuple of dimensions representing the shape.
-            Each dimension can be an integer, None or a :class:`SymbolicDim`.
+            Each dimension can be an integer, None, or a :class:`SymbolicDim`.
         frozen: Indicates whether the shape is immutable. When frozen, the shape
             cannot be modified or unfrozen.
     """
@@ -1121,7 +1288,7 @@ class Shape(_protocols.ShapeProtocol, _display.PrettyPrintable):
         """Whether the shape is frozen.
 
         When the shape is frozen, it cannot be unfrozen, making it suitable to be shared.
-        Call :
+        Call :meth:`freeze` to freeze the shape. Call :meth:`copy` to create a
         new shape with the same dimensions that can be modified.
         """
         return self._frozen
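To make the freeze/copy contract concrete, here is a small sketch. It assumes ``Shape`` is exported at the package root and supports item assignment, which the docstring's "querying and manipulating" methods imply but this diff does not show directly.

```python
import onnx_ir as ir

shape = ir.Shape([1, "N", None])  # static, symbolic, and unknown dimensions
shape.freeze()                    # immutable from here on; safe to share across values
assert shape.frozen

mutable = shape.copy()            # frozen shapes are edited via a modifiable copy
mutable[0] = 2                    # assumed __setitem__ support
```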
@@ -1289,19 +1456,24 @@ def _normalize_domain(domain: str) -> str:
 class Node(_protocols.NodeProtocol, _display.PrettyPrintable):
     """IR Node.
 
-
-
+    .. tip::
+        For a more convenient way (that supports Python objects
+        as attributes) to create a node, use the :func:`onnx_ir.node` constructor.
+
+    If ``graph`` is provided, the node will be added to the graph. Otherwise,
+    the user is responsible for calling ``graph.append(node)`` (or other mutation methods
     in :class:`Graph`) to add the node to the graph.
 
-    After the node is initialized, it will add itself as a user of
+    After the node is initialized, it will add itself as a user of its input values.
 
     The output values of the node are created during node initialization and are immutable.
-    To change the output values, create a new node and
-    the
-
+    To change the output values, create a new node and, for each use of the old outputs (``output.uses()``),
+    replace the input in the consuming node by calling :meth:`replace_input_with`.
+    You can also use the :func:`~onnx_ir.convenience.replace_all_uses_with` method
+    to replace all uses of the output values.
 
-    .. note
-        When the ``domain`` is
+    .. note::
+        When the ``domain`` is ``"ai.onnx"``, it is normalized to ``""``.
     """
 
     __slots__ = (
@@ -1339,7 +1511,7 @@ class Node(_protocols.NodeProtocol, _display.PrettyPrintable):
 
         Args:
             domain: The domain of the operator. For onnx operators, this is an empty string.
-                When it is
+                When it is ``"ai.onnx"``, it is normalized to ``""``.
             op_type: The name of the operator.
             inputs: The input values. When an input is ``None``, it is an empty input.
             attributes: The attributes. RefAttr can be used only when the node is defined in a Function.
@@ -1632,7 +1804,15 @@ class Node(_protocols.NodeProtocol, _display.PrettyPrintable):
 
     @property
     def attributes(self) -> _graph_containers.Attributes:
-        """The attributes of the node.
+        """The attributes of the node as ``dict[str, Attr]`` with additional access methods.
+
+        Use it as a dictionary with keys being the attribute names and values being the
+        :class:`Attr` objects.
+
+        Use ``node.attributes.add(attr)`` to add an attribute to the node.
+        Use ``node.attributes.get_int(name, default)`` to get an integer attribute value.
+        Refer to the :class:`~onnx_ir._graph_containers.Attributes` for more methods.
+        """
         return self._attributes
 
     @property
@@ -1799,12 +1979,13 @@ class Value(_protocols.ValueProtocol, _display.PrettyPrintable):
     The index of the output of the node that produces the value can be accessed with
     :meth:`index`.
 
-    To find all the nodes that use this value as an input, call :meth:`uses`.
+    To find all the nodes that use this value as an input, call :meth:`uses`. Consuming
+    nodes can be obtained with :meth:`consumers`.
 
     To check if the value is an is an input, output or initializer of a graph,
     use :meth:`is_graph_input`, :meth:`is_graph_output` or :meth:`is_initializer`.
 
-    Use :
+    Use :attr:`graph` to get the graph that owns the value.
     """
 
     __slots__ = (
@@ -2221,7 +2402,7 @@ class Graph(_protocols.GraphProtocol, Sequence[Node], _display.PrettyPrintable):
 
     @property
     def initializers(self) -> _graph_containers.GraphInitializers:
-        """The initializers of the graph as a ``
+        """The initializers of the graph as a ``dict[str, Value]``.
 
         The keys are the names of the initializers. The values are the :class:`Value` objects.
 
@@ -2357,6 +2538,28 @@ class Graph(_protocols.GraphProtocol, Sequence[Node], _display.PrettyPrintable):
         # NOTE: This is a method specific to Graph, not required by the protocol unless proven
         return len(self)
 
+    def all_nodes(self) -> Iterator[Node]:
+        """Get all nodes in the graph and its subgraphs in O(#nodes + #attributes) time.
+
+        This is an alias for ``onnx_ir.traversal.RecursiveGraphIterator(graph)``.
+        Consider using
+        :class:`onnx_ir.traversal.RecursiveGraphIterator` for more advanced
+        traversals on nodes.
+        """
+        # NOTE: This is a method specific to Graph, not required by the protocol unless proven
+        return onnx_ir.traversal.RecursiveGraphIterator(self)
+
+    def subgraphs(self) -> Iterator[Graph]:
+        """Get all subgraphs in the graph in O(#nodes + #attributes) time."""
+        seen_graphs: set[Graph] = set()
+        for node in onnx_ir.traversal.RecursiveGraphIterator(self):
+            graph = node.graph
+            if graph is self:
+                continue
+            if graph is not None and graph not in seen_graphs:
+                seen_graphs.add(graph)
+                yield graph
+
     # Mutation methods
     def append(self, node: Node, /) -> None:
         """Append a node to the graph in O(1) time.
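A short sketch of how the new ``all_nodes`` and ``subgraphs`` helpers compare with plain iteration; ``"model.onnx"`` is a placeholder path, and any model with control-flow subgraphs would do.

```python
import onnx_ir as ir

model = ir.load("model.onnx")

top_level = list(model.graph)               # existing behavior: main-graph nodes only
everything = list(model.graph.all_nodes())  # recursive: includes If/Loop/Scan bodies

for subgraph in model.graph.subgraphs():    # each nested graph is yielded exactly once
    print(len(subgraph), "nodes")
```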
@@ -2862,7 +3065,7 @@ Model(
         """Get all graphs and subgraphs in the model.
 
         This is a convenience method to traverse the model. Consider using
-
+        :class:`onnx_ir.traversal.RecursiveGraphIterator` for more advanced
         traversals on nodes.
         """
         # NOTE(justinchuby): Given
@@ -2871,11 +3074,8 @@ Model(
         # (3) Users familiar with onnxruntime optimization tools expect this method
         # I created this method as a core method instead of an iterator in
         # `traversal.py`.
-
-
-            if node.graph is not None and node.graph not in seen_graphs:
-                seen_graphs.add(node.graph)
-                yield node.graph
+        yield self.graph
+        yield from self.graph.subgraphs()
 
 
 class Function(_protocols.FunctionProtocol, Sequence[Node], _display.PrettyPrintable):
@@ -3009,6 +3209,28 @@ class Function(_protocols.FunctionProtocol, Sequence[Node], _display.PrettyPrint
     def metadata_props(self) -> dict[str, str]:
         return self._graph.metadata_props
 
+    def all_nodes(self) -> Iterator[Node]:
+        """Get all nodes in the graph and its subgraphs in O(#nodes + #attributes) time.
+
+        This is an alias for ``onnx_ir.traversal.RecursiveGraphIterator(graph)``.
+        Consider using
+        :class:`onnx_ir.traversal.RecursiveGraphIterator` for more advanced
+        traversals on nodes.
+        """
+        # NOTE: This is a method specific to Graph, not required by the protocol unless proven
+        return onnx_ir.traversal.RecursiveGraphIterator(self)
+
+    def subgraphs(self) -> Iterator[Graph]:
+        """Get all subgraphs in the function in O(#nodes + #attributes) time."""
+        seen_graphs: set[Graph] = set()
+        for node in onnx_ir.traversal.RecursiveGraphIterator(self):
+            graph = node.graph
+            if graph is self._graph:
+                continue
+            if graph is not None and graph not in seen_graphs:
+                seen_graphs.add(graph)
+                yield graph
+
     # Mutation methods
     def append(self, node: Node, /) -> None:
         """Append a node to the function in O(1) time."""
onnx_ir/_enums.py
CHANGED
@@ -114,7 +114,18 @@ class DataType(enum.IntEnum):
     @property
     def itemsize(self) -> float:
         """Returns the size of the data type in bytes."""
-        return
+        return self.bitwidth / 8
+
+    @property
+    def bitwidth(self) -> int:
+        """Returns the bit width of the data type.
+
+        Raises:
+            TypeError: If the data type is not supported.
+        """
+        if self not in _BITWIDTH_MAP:
+            raise TypeError(f"Bitwidth not available for ONNX data type: {self}")
+        return _BITWIDTH_MAP[self]
 
     def numpy(self) -> np.dtype:
         """Returns the numpy dtype for the ONNX data type.
@@ -163,30 +174,29 @@ class DataType(enum.IntEnum):
         return self.__repr__()
 
 
-
-    DataType.FLOAT:
-    DataType.UINT8:
-    DataType.INT8:
-    DataType.UINT16:
-    DataType.INT16:
-    DataType.INT32:
-    DataType.INT64:
-    DataType.
-    DataType.
-    DataType.
-    DataType.
-    DataType.
-    DataType.
-    DataType.
-    DataType.
-    DataType.
-    DataType.
-    DataType.
-    DataType.
-    DataType.
-    DataType.
-    DataType.
-    DataType.FLOAT4E2M1: 0.5,
+_BITWIDTH_MAP = {
+    DataType.FLOAT: 32,
+    DataType.UINT8: 8,
+    DataType.INT8: 8,
+    DataType.UINT16: 16,
+    DataType.INT16: 16,
+    DataType.INT32: 32,
+    DataType.INT64: 64,
+    DataType.BOOL: 8,
+    DataType.FLOAT16: 16,
+    DataType.DOUBLE: 64,
+    DataType.UINT32: 32,
+    DataType.UINT64: 64,
+    DataType.COMPLEX64: 64,  # 2 * 32
+    DataType.COMPLEX128: 128,  # 2 * 64
+    DataType.BFLOAT16: 16,
+    DataType.FLOAT8E4M3FN: 8,
+    DataType.FLOAT8E4M3FNUZ: 8,
+    DataType.FLOAT8E5M2: 8,
+    DataType.FLOAT8E5M2FNUZ: 8,
+    DataType.UINT4: 4,
+    DataType.INT4: 4,
+    DataType.FLOAT4E2M1: 4,
 }
 
 
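With the float-valued itemsize table replaced by an integer ``_BITWIDTH_MAP``, ``itemsize`` is now derived as ``bitwidth / 8``, so sub-byte types keep their fractional sizes. A few values implied directly by the map above:

```python
import onnx_ir as ir

assert ir.DataType.INT4.bitwidth == 4
assert ir.DataType.INT4.itemsize == 0.5         # 4 / 8
assert ir.DataType.BFLOAT16.bitwidth == 16
assert ir.DataType.COMPLEX128.itemsize == 16.0  # 128 / 8
```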
onnx_ir/_graph_containers.py
CHANGED
@@ -216,7 +216,7 @@ class GraphOutputs(_GraphIO):
 
 
 class GraphInitializers(collections.UserDict[str, "_core.Value"]):
-    """The initializers of a Graph."""
+    """The initializers of a Graph as ``dict[str, Value]`` with additional mutation methods."""
 
     def __init__(self, graph: _core.Graph, dict=None, /, **kwargs):
         # Perform checks first in _set_graph before modifying the data structure with super().__init__()
@@ -291,7 +291,7 @@ class GraphInitializers(collections.UserDict[str, "_core.Value"]):
 
 
 class Attributes(collections.UserDict[str, "_core.Attr"]):
-    """The attributes of a Node."""
+    """The attributes of a Node as ``dict[str, Attr]`` with additional access methods."""
 
     def __init__(self, attrs: Iterable[_core.Attr]):
         super().__init__({attr.name: attr for attr in attrs})
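The reworded docstrings reflect that both containers are plain dicts with extra helpers. A sketch of the ``Attributes`` access patterns named in the ``_core.py`` docstring above; the ``onnx_ir.node`` constructor is referenced by this release, but the exact ``Value`` constructor arguments here are an assumption.

```python
import onnx_ir as ir

# onnx_ir.node accepts plain Python objects as attribute values (per the Node docstring).
node = ir.node("Softmax", inputs=[ir.Value(name="x")], attributes={"axis": -1})

axis = node.attributes.get_int("axis", 0)  # typed accessor with a default
attr = node.attributes["axis"]             # plain dict access returns the Attr object
```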
onnx_ir/_io.py
CHANGED
@@ -7,10 +7,11 @@ from __future__ import annotations
 __all__ = ["load", "save"]
 
 import os
+from typing import Callable
 
-import onnx
+import onnx  # noqa: TID251
 
-from onnx_ir import _core, serde
+from onnx_ir import _core, _protocols, serde
 from onnx_ir import external_data as _external_data
 from onnx_ir._polyfill import zip
 
@@ -43,6 +44,8 @@ def save(
     format: str | None = None,
     external_data: str | os.PathLike | None = None,
     size_threshold_bytes: int = 256,
+    callback: Callable[[_protocols.TensorProtocol, _external_data.CallbackInfo], None]
+    | None = None,
 ) -> None:
     """Save an ONNX model to a file.
 
@@ -52,6 +55,30 @@ def save(
     to load the newly saved model, or provide a different external data path that
     is not currently referenced by any tensors in the model.
 
+    .. tip::
+
+        A simple progress bar can be implemented by passing a callback function as the following::
+
+            import onnx_ir as ir
+            import tqdm
+
+            with tqdm.tqdm() as pbar:
+                total_set = False
+
+                def callback(tensor: ir.TensorProtocol, metadata: ir.external_data.CallbackInfo) -> None:
+                    nonlocal total_set
+                    if not total_set:
+                        pbar.total = metadata.total
+                        total_set = True
+
+                    pbar.update()
+                    pbar.set_description(f"Saving {tensor.name} ({tensor.dtype}, {tensor.shape}) at offset {metadata.offset}")
+
+                ir.save(
+                    ...,
+                    callback=callback,
+                )
+
     Args:
         model: The model to save.
         path: The path to save the model to. E.g. "model.onnx".
@@ -65,6 +92,8 @@ def save(
             it will be serialized in the ONNX Proto message.
         size_threshold_bytes: Save to external data if the tensor size in bytes is larger than this threshold.
             Effective only when ``external_data`` is set.
+        callback: A callback function that is called for each tensor that is saved to external data
+            for debugging or logging purposes.
 
     Raises:
         ValueError: If the external data path is an absolute path.
@@ -77,12 +106,19 @@ def save(
     base_dir = os.path.dirname(path)
 
     # Store the original initializer values so they can be restored if modify_model=False
-    initializer_values =
+    initializer_values: list[_core.Value] = []
+    for graph in model.graphs():
+        # Collect from all subgraphs as well
+        initializer_values.extend(graph.initializers.values())
     tensors = [v.const_value for v in initializer_values]
 
     try:
         model = _external_data.unload_from_model(
-            model,
+            model,
+            base_dir,
+            external_data,
+            size_threshold_bytes=size_threshold_bytes,
+            callback=callback,
         )
         proto = serde.serialize_model(model)
         onnx.save(proto, path, format=format)
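Besides the tqdm tip embedded in the docstring, the callback also works for plain logging. A minimal sketch, where ``model`` is assumed to be an ``ir.Model`` loaded or built elsewhere:

```python
import onnx_ir as ir

def log_write(tensor: ir.TensorProtocol, info: ir.external_data.CallbackInfo) -> None:
    # Invoked once per tensor written to the external data file.
    print(f"wrote {tensor.name} at offset {info.offset}")

ir.save(
    model,
    "model.onnx",
    external_data="model.data",   # path relative to the model file's directory
    size_threshold_bytes=1024,    # tensors smaller than this stay inline in the proto
    callback=log_write,
)
```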
onnx_ir/_type_casting.py
CHANGED
@@ -15,7 +15,7 @@ if typing.TYPE_CHECKING:
     import numpy.typing as npt
 
 
-def
+def pack_4bitx2(array: np.ndarray) -> npt.NDArray[np.uint8]:
     """Convert a numpy array to flatten, packed int4/uint4. Elements must be in the correct range."""
     # Create a 1D copy
     array_flat = array.ravel().view(np.uint8).copy()
@@ -40,6 +40,7 @@ def _unpack_uint4_as_uint8(
     Returns:
         A numpy array of int8/uint8 reshaped to dims.
     """
+    assert data.dtype == np.uint8, "Input data must be of type uint8"
     result = np.empty([data.size * 2], dtype=data.dtype)
     array_low = data & np.uint8(0x0F)
     array_high = data & np.uint8(0xF0)
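The packing convention these helpers rely on (two elements per byte, low nibble first, as suggested by the ``0x0F``/``0xF0`` masks above) can be illustrated with plain numpy. This is a standalone sketch of the scheme, not the library's own code, and it assumes all elements are already in range:

```python
import numpy as np

values = np.array([1, 2, 3, 4, 5, 6], dtype=np.uint8)  # each in [0, 15]

# Pack: even indices go to the low nibble, odd indices to the high nibble.
packed = (values[0::2] & 0x0F) | (values[1::2] << 4)
assert packed.tolist() == [0x21, 0x43, 0x65]

# Unpack: mask out the low nibble, shift down the high nibble, then interleave.
low = packed & np.uint8(0x0F)
high = (packed & np.uint8(0xF0)) >> 4
unpacked = np.empty(packed.size * 2, dtype=np.uint8)
unpacked[0::2] = low
unpacked[1::2] = high
assert (unpacked == values).all()
```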