onnx-ir 0.1.1__py3-none-any.whl → 0.1.3__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release: this version of onnx-ir has been flagged as potentially problematic.
- onnx_ir/__init__.py +5 -2
- onnx_ir/_convenience/__init__.py +130 -4
- onnx_ir/_convenience/_constructors.py +6 -2
- onnx_ir/_core.py +283 -39
- onnx_ir/_enums.py +37 -25
- onnx_ir/_graph_containers.py +2 -2
- onnx_ir/_io.py +40 -4
- onnx_ir/_type_casting.py +2 -1
- onnx_ir/_version_utils.py +5 -48
- onnx_ir/convenience.py +3 -1
- onnx_ir/external_data.py +43 -3
- onnx_ir/passes/_pass_infra.py +1 -1
- onnx_ir/passes/common/__init__.py +4 -0
- onnx_ir/passes/common/_c_api_utils.py +1 -1
- onnx_ir/passes/common/common_subexpression_elimination.py +104 -75
- onnx_ir/passes/common/initializer_deduplication.py +56 -0
- onnx_ir/passes/common/onnx_checker.py +1 -1
- onnx_ir/passes/common/shape_inference.py +1 -1
- onnx_ir/passes/common/unused_removal.py +1 -1
- onnx_ir/serde.py +176 -6
- onnx_ir/tensor_adapters.py +62 -7
- {onnx_ir-0.1.1.dist-info → onnx_ir-0.1.3.dist-info}/METADATA +22 -4
- onnx_ir-0.1.3.dist-info/RECORD +43 -0
- onnx_ir-0.1.1.dist-info/RECORD +0 -42
- {onnx_ir-0.1.1.dist-info → onnx_ir-0.1.3.dist-info}/WHEEL +0 -0
- {onnx_ir-0.1.1.dist-info → onnx_ir-0.1.3.dist-info}/licenses/LICENSE +0 -0
- {onnx_ir-0.1.1.dist-info → onnx_ir-0.1.3.dist-info}/top_level.txt +0 -0
onnx_ir/_core.py
CHANGED
@@ -251,11 +251,11 @@ def _check_numpy_representation_type(array: np.ndarray, dtype: _enums.DataType)
         or corresponding dtypes from the ``ml_dtype`` package.
     """
     if dtype in _NON_NUMPY_NATIVE_TYPES:
-        if dtype.
+        if dtype.bitwidth == 16 and array.dtype not in (np.uint16, ml_dtypes.bfloat16):
             raise TypeError(
                 f"The numpy array dtype must be uint16 or ml_dtypes.bfloat16 (not {array.dtype}) for IR data type {dtype}."
             )
-        if dtype.
+        if dtype.bitwidth == 8 and array.dtype not in (
             np.uint8,
             ml_dtypes.float8_e4m3fnuz,
             ml_dtypes.float8_e4m3fn,
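Note: the rewritten check keys off the new ``DataType.bitwidth`` property rather than hard-coded type lists. A minimal sketch of what it accepts for bfloat16, assuming ``Tensor`` and ``DataType`` are re-exported at the package root::

    import ml_dtypes
    import numpy as np
    import onnx_ir as ir

    # bfloat16 data may be backed by ml_dtypes.bfloat16 or by raw uint16 bits;
    # the explicit dtype tells the IR how to interpret the buffer.
    data = np.array([1.0, 2.0], dtype=ml_dtypes.bfloat16)
    t1 = ir.Tensor(data, dtype=ir.DataType.BFLOAT16)

    bits = np.array([0x3F80], dtype=np.uint16)  # bit pattern of bfloat16 1.0
    t2 = ir.Tensor(bits, dtype=ir.DataType.BFLOAT16)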
@@ -385,9 +385,10 @@ class Tensor(TensorBase, _protocols.TensorProtocol, Generic[TArrayCompatible]):

         Args:
             value: The backing data of the tensor. It can be a numpy array compatible object or a DLPack compatible object.
-                When the dtype is not one of the numpy native dtypes, the value
-
-                when the value is a numpy array;
+                When the dtype is not one of the numpy native dtypes, the value can
+                be ``uint8`` (unpacked) or ml_dtypes types for 4-bit and 8-bit data types,
+                and ``uint16`` or ml_dtype.bfloat16 for bfloat16 when the value is a numpy array;
+                ``dtype`` must be specified in this case.
             dtype: The data type of the tensor. It can be None only when value is a numpy array.
                 Users are responsible for making sure the dtype matches the value when value is not a numpy array.
             shape: The shape of the tensor. If None, the shape is obtained from the value.
@@ -416,12 +417,16 @@ class Tensor(TensorBase, _protocols.TensorProtocol, Generic[TArrayCompatible]):
         else:
             self._shape = shape
             self._shape.freeze()
+        if isinstance(value, np.generic):
+            # Turn numpy scalar into a numpy array
+            value = np.array(value)  # type: ignore[assignment]
         if dtype is None:
             if isinstance(value, np.ndarray):
                 self._dtype = _enums.DataType.from_numpy(value.dtype)
             else:
                 raise ValueError(
-                    "The dtype must be specified when the value is not a numpy array."
+                    "The dtype must be specified when the value is not a numpy array. "
+                    "Value type: {type(value)}"
                 )
         else:
             if isinstance(value, np.ndarray):
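Note: with the ``np.generic`` promotion above, numpy scalars are now accepted directly; a quick sketch::

    import numpy as np
    import onnx_ir as ir

    t = ir.Tensor(np.float32(1.5))       # scalar is wrapped via np.array(...)
    assert t.dtype == ir.DataType.FLOAT  # dtype inferred from the numpy dtype

Non-numpy values still require an explicit ``dtype``, now reported with the offending value's type in the error message.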
@@ -502,7 +507,7 @@ class Tensor(TensorBase, _protocols.TensorProtocol, Generic[TArrayCompatible]):
             _enums.DataType.FLOAT4E2M1,
         }:
             # Pack the array into int4
-            array = _type_casting.
+            array = _type_casting.pack_4bitx2(array)
         else:
             assert self.dtype.itemsize == array.itemsize, "Bug: The itemsize should match"
         if not _IS_LITTLE_ENDIAN:
@@ -961,8 +966,154 @@ class LazyTensor(TensorBase, _protocols.TensorProtocol):  # pylint: disable=too-
         return self._evaluate().tobytes()


+class PackedTensor(TensorBase, _protocols.TensorProtocol, Generic[TArrayCompatible]):  # pylint: disable=too-many-ancestors
+    """A tensor that stores 4bit datatypes in packed format.
+
+    .. versionadded:: 0.1.2
+    """
+
+    __slots__ = (
+        "_dtype",
+        "_raw",
+        "_shape",
+    )
+
+    def __init__(
+        self,
+        value: TArrayCompatible,
+        dtype: _enums.DataType,
+        *,
+        shape: Shape | Sequence[int],
+        name: str | None = None,
+        doc_string: str | None = None,
+        metadata_props: dict[str, str] | None = None,
+    ) -> None:
+        """Initialize a tensor.
+
+        Args:
+            value: The backing data of the tensor. It can be a numpy array compatible object or a DLPack compatible object.
+                The value MUST be packed in an integer dtype.
+            dtype: The data type of the tensor. Must be one of INT4, UINT4, FLOAT4E2M1.
+            shape: The shape of the tensor.
+            name: The name of the tensor.
+            doc_string: The documentation string.
+            metadata_props: The metadata properties.
+
+        Raises:
+            TypeError: If the value is not a numpy array compatible or a DLPack compatible object.
+            TypeError: If the value is a numpy array and the dtype is not uint8 or one of the ml_dtypes dtypes.
+        """
+        super().__init__(name=name, doc_string=doc_string, metadata_props=metadata_props)
+        if not _compatible_with_numpy(value) and not _compatible_with_dlpack(value):
+            raise TypeError(f"Expected an array compatible object, got {type(value)}")
+        self._shape = Shape(shape)
+        self._shape.freeze()
+        if dtype.bitwidth != 4:
+            raise TypeError(
+                f"PackedTensor only supports INT4, UINT4, FLOAT4E2M1, but got {dtype}"
+            )
+        self._dtype = dtype
+        self._raw = value
+
+        if isinstance(value, np.ndarray):
+            if (
+                value.dtype == ml_dtypes.float4_e2m1fn
+                or value.dtype == ml_dtypes.uint4
+                or value.dtype == ml_dtypes.int4
+            ):
+                raise TypeError(
+                    f"PackedTensor expects the value to be packed, but got {value.dtype} which is not packed. "
+                    "Please pack the value or use `onnx_ir.Tensor`."
+                )
+            # Check after shape and dtype is set
+            if value.size != self.nbytes:
+                raise ValueError(
+                    f"Expected the packed array to be {self.nbytes} bytes (from shape {self.shape}), but got {value.nbytes} bytes"
+                )
+
+    def __array__(self, dtype: Any = None, copy: bool = False) -> np.ndarray:
+        return self.numpy()
+
+    def __dlpack__(self, *, stream: Any = None) -> Any:
+        if _compatible_with_dlpack(self._raw):
+            return self._raw.__dlpack__(stream=stream)
+        return self.__array__().__dlpack__(stream=stream)
+
+    def __dlpack_device__(self) -> tuple[int, int]:
+        if _compatible_with_dlpack(self._raw):
+            return self._raw.__dlpack_device__()
+        return self.__array__().__dlpack_device__()
+
+    def __repr__(self) -> str:
+        return f"{self._repr_base()}({self._raw!r}, name={self.name!r})"
+
+    @property
+    def dtype(self) -> _enums.DataType:
+        """The data type of the tensor. Immutable."""
+        return self._dtype
+
+    @property
+    def shape(self) -> Shape:
+        """The shape of the tensor. Immutable."""
+        return self._shape
+
+    @property
+    def raw(self) -> TArrayCompatible:
+        """Backing data of the tensor. Immutable."""
+        return self._raw  # type: ignore[return-value]
+
+    def numpy(self) -> np.ndarray:
+        """Return the tensor as a numpy array.
+
+        When the data type is not supported by numpy, the dtypes from the ``ml_dtype``
+        package are used. The values can be reinterpreted as bit representations
+        using the ``.view()`` method.
+        """
+        array = self.numpy_packed()
+        # ONNX IR returns the unpacked arrays
+        if self.dtype == _enums.DataType.INT4:
+            return _type_casting.unpack_int4(array, self.shape.numpy())
+        if self.dtype == _enums.DataType.UINT4:
+            return _type_casting.unpack_uint4(array, self.shape.numpy())
+        if self.dtype == _enums.DataType.FLOAT4E2M1:
+            return _type_casting.unpack_float4e2m1(array, self.shape.numpy())
+        raise TypeError(
+            f"PackedTensor only supports INT4, UINT4, FLOAT4E2M1, but got {self.dtype}"
+        )
+
+    def numpy_packed(self) -> npt.NDArray[np.uint8]:
+        """Return the tensor as a packed array."""
+        if isinstance(self._raw, np.ndarray) or _compatible_with_numpy(self._raw):
+            array = np.asarray(self._raw)
+        else:
+            assert _compatible_with_dlpack(self._raw), (
+                f"Bug: Expected DLPack or Numpy compatible objects, got {type(self._raw)}"
+            )
+            array = np.from_dlpack(self._raw)
+        if array.nbytes != self.nbytes:
+            raise ValueError(
+                f"Expected the packed array to be {self.nbytes} bytes (from shape {self.shape}), but got {array.nbytes} bytes"
+            )
+        return array.view(np.uint8)
+
+    def tobytes(self) -> bytes:
+        """Returns the value as bytes encoded in little endian.
+
+        Override this method for more efficient serialization when the raw
+        value is not a numpy array.
+        """
+        array = self.numpy_packed()
+        if not _IS_LITTLE_ENDIAN:
+            array = array.view(array.dtype.newbyteorder("<"))
+        return array.tobytes()
+
+
 class SymbolicDim(_protocols.SymbolicDimProtocol, _display.PrettyPrintable):
-    """Immutable symbolic dimension that can be shared across multiple shapes.
+    """Immutable symbolic dimension that can be shared across multiple shapes.
+
+    SymbolicDim is used to represent a symbolic (non-integer) dimension in a tensor shape.
+    It is immutable and can be compared or hashed.
+    """

     __slots__ = ("_value",)

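Note: a usage sketch for the new ``PackedTensor``, assuming it is re-exported at the package root and that packing is low-nibble first (the ONNX convention)::

    import numpy as np
    import onnx_ir as ir

    # Two uint4 elements packed into one byte: 0x21 -> low nibble 1, high nibble 2
    packed = np.array([0x21], dtype=np.uint8)
    t = ir.PackedTensor(packed, ir.DataType.UINT4, shape=[2], name="small")

    print(t.numpy())         # unpacked ml_dtypes.uint4 array: [1, 2]
    print(t.numpy_packed())  # raw packed bytes as uint8: [33]
    print(t.tobytes())       # b'!' -- serialized in packed little-endian form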
@@ -971,6 +1122,9 @@ class SymbolicDim(_protocols.SymbolicDimProtocol, _display.PrettyPrintable):

         Args:
             value: The value of the dimension. It should not be an int.
+
+        Raises:
+            TypeError: If value is an int.
         """
         if isinstance(value, int):
             raise TypeError(
@@ -980,15 +1134,18 @@ class SymbolicDim(_protocols.SymbolicDimProtocol, _display.PrettyPrintable):
         self._value = value

     def __eq__(self, other: object) -> bool:
+        """Check equality with another SymbolicDim or string/None."""
         if not isinstance(other, SymbolicDim):
             return self.value == other
         return self.value == other.value

     def __hash__(self) -> int:
+        """Return the hash of the symbolic dimension value."""
         return hash(self.value)

     @property
     def value(self) -> str | None:
+        """The value of the symbolic dimension (string or None)."""
         return self._value

     def __str__(self) -> str:
|
|
|
999
1156
|
|
|
1000
1157
|
|
|
1001
1158
|
def _is_int_compatible(value: object) -> TypeIs[SupportsInt]:
|
|
1002
|
-
"""
|
|
1159
|
+
"""Check if the value is compatible with int (i.e., can be safely cast to int).
|
|
1160
|
+
|
|
1161
|
+
Args:
|
|
1162
|
+
value: The value to check.
|
|
1163
|
+
|
|
1164
|
+
Returns:
|
|
1165
|
+
True if the value is an int or has an __int__ method, False otherwise.
|
|
1166
|
+
"""
|
|
1003
1167
|
if isinstance(value, int):
|
|
1004
1168
|
return True
|
|
1005
1169
|
if hasattr(value, "__int__"):
|
|
@@ -1011,7 +1175,17 @@ def _is_int_compatible(value: object) -> TypeIs[SupportsInt]:
 def _maybe_convert_to_symbolic_dim(
     dim: int | SupportsInt | SymbolicDim | str | None,
 ) -> SymbolicDim | int:
-    """Convert the value to a SymbolicDim if it is not an int.
+    """Convert the value to a SymbolicDim if it is not an int.
+
+    Args:
+        dim: The dimension value, which can be int, str, None, or SymbolicDim.
+
+    Returns:
+        An int or SymbolicDim instance.
+
+    Raises:
+        TypeError: If the value is not int, str, None, or SymbolicDim.
+    """
     if dim is None or isinstance(dim, str):
         return SymbolicDim(dim)
     if _is_int_compatible(dim):
|
|
|
1024
1198
|
|
|
1025
1199
|
|
|
1026
1200
|
class Shape(_protocols.ShapeProtocol, _display.PrettyPrintable):
|
|
1027
|
-
"""
|
|
1028
|
-
|
|
1029
|
-
The :class:`Shape` stores the dimensions of a tensor, which can be integers, None (unknown), or
|
|
1030
|
-
symbolic dimensions.
|
|
1201
|
+
"""Represents the shape of a tensor, including its dimensions and optional denotations.
|
|
1031
1202
|
|
|
1032
|
-
|
|
1203
|
+
The :class:`Shape` class stores the dimensions of a tensor, which can be integers, None (unknown), or
|
|
1204
|
+
symbolic dimensions. It provides methods for querying and manipulating the shape, as well as for comparing
|
|
1205
|
+
shapes to other shapes or plain Python lists.
|
|
1033
1206
|
|
|
1034
1207
|
A shape can be frozen (made immutable). When the shape is frozen, it cannot be
|
|
1035
1208
|
unfrozen, making it suitable to be shared across tensors or values.
|
|
1036
|
-
Call :
|
|
1209
|
+
Call :meth:`freeze` to freeze the shape.
|
|
1037
1210
|
|
|
1038
|
-
To update the dimension of a frozen shape, call :
|
|
1211
|
+
To update the dimension of a frozen shape, call :meth:`copy` to create a
|
|
1039
1212
|
new shape with the same dimensions that can be modified.
|
|
1040
1213
|
|
|
1041
|
-
Use :
|
|
1214
|
+
Use :meth:`get_denotation` and :meth:`set_denotation` to access and modify the denotations.
|
|
1042
1215
|
|
|
1043
1216
|
Example::
|
|
1044
1217
|
|
|
@@ -1066,7 +1239,7 @@ class Shape(_protocols.ShapeProtocol, _display.PrettyPrintable):
|
|
|
1066
1239
|
|
|
1067
1240
|
Attributes:
|
|
1068
1241
|
dims: A tuple of dimensions representing the shape.
|
|
1069
|
-
Each dimension can be an integer, None or a :class:`SymbolicDim`.
|
|
1242
|
+
Each dimension can be an integer, None, or a :class:`SymbolicDim`.
|
|
1070
1243
|
frozen: Indicates whether the shape is immutable. When frozen, the shape
|
|
1071
1244
|
cannot be modified or unfrozen.
|
|
1072
1245
|
"""
|
|
@@ -1121,7 +1294,7 @@ class Shape(_protocols.ShapeProtocol, _display.PrettyPrintable):
         """Whether the shape is frozen.

         When the shape is frozen, it cannot be unfrozen, making it suitable to be shared.
-        Call :
+        Call :meth:`freeze` to freeze the shape. Call :meth:`copy` to create a
         new shape with the same dimensions that can be modified.
         """
         return self._frozen
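Note: the freeze/copy workflow the docstrings describe, sketched (item assignment on a shape is assumed here for illustration)::

    import onnx_ir as ir

    shape = ir.Shape([1, "batch", None])  # int, symbolic, and unknown dims
    shape.freeze()                        # immutable from here on; safe to share
    mutable = shape.copy()                # a thawed copy with the same dims
    mutable[0] = 2                        # mutate the copy, not the original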
@@ -1289,19 +1462,24 @@ def _normalize_domain(domain: str) -> str:
 class Node(_protocols.NodeProtocol, _display.PrettyPrintable):
     """IR Node.

-
-
+    .. tip::
+        For a more convenient way (that supports Python objects
+        as attributes) to create a node, use the :func:`onnx_ir.node` constructor.
+
+    If ``graph`` is provided, the node will be added to the graph. Otherwise,
+    the user is responsible for calling ``graph.append(node)`` (or other mutation methods
     in :class:`Graph`) to add the node to the graph.

-    After the node is initialized, it will add itself as a user of
+    After the node is initialized, it will add itself as a user of its input values.

     The output values of the node are created during node initialization and are immutable.
-    To change the output values, create a new node and
-    the
-
+    To change the output values, create a new node and, for each use of the old outputs (``output.uses()``),
+    replace the input in the consuming node by calling :meth:`replace_input_with`.
+    You can also use the :func:`~onnx_ir.convenience.replace_all_uses_with` method
+    to replace all uses of the output values.

-    .. note
-        When the ``domain`` is
+    .. note::
+        When the ``domain`` is ``"ai.onnx"``, it is normalized to ``""``.
     """

     __slots__ = (
@@ -1339,7 +1517,7 @@ class Node(_protocols.NodeProtocol, _display.PrettyPrintable):

         Args:
             domain: The domain of the operator. For onnx operators, this is an empty string.
-                When it is
+                When it is ``"ai.onnx"``, it is normalized to ``""``.
             op_type: The name of the operator.
             inputs: The input values. When an input is ``None``, it is an empty input.
             attributes: The attributes. RefAttr can be used only when the node is defined in a Function.
@@ -1632,7 +1810,15 @@ class Node(_protocols.NodeProtocol, _display.PrettyPrintable):

     @property
     def attributes(self) -> _graph_containers.Attributes:
-        """The attributes of the node.
+        """The attributes of the node as ``dict[str, Attr]`` with additional access methods.
+
+        Use it as a dictionary with keys being the attribute names and values being the
+        :class:`Attr` objects.
+
+        Use ``node.attributes.add(attr)`` to add an attribute to the node.
+        Use ``node.attributes.get_int(name, default)`` to get an integer attribute value.
+        Refer to the :class:`~onnx_ir._graph_containers.Attributes` for more methods.
+        """
         return self._attributes

     @property
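Note: the documented access patterns on ``node.attributes``, sketched with the :func:`onnx_ir.node` constructor referenced in the tip above (names are illustrative)::

    import onnx_ir as ir

    x = ir.Value(name="x")
    node = ir.node("Cast", inputs=[x], attributes={"to": int(ir.DataType.FLOAT)})

    attr = node.attributes["to"]           # dict-style access returns an Attr
    to = node.attributes.get_int("to", 0)  # typed getter with a default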
@@ -1799,12 +1985,13 @@ class Value(_protocols.ValueProtocol, _display.PrettyPrintable):
     The index of the output of the node that produces the value can be accessed with
     :meth:`index`.

-    To find all the nodes that use this value as an input, call :meth:`uses`.
+    To find all the nodes that use this value as an input, call :meth:`uses`. Consuming
+    nodes can be obtained with :meth:`consumers`.

     To check if the value is an is an input, output or initializer of a graph,
     use :meth:`is_graph_input`, :meth:`is_graph_output` or :meth:`is_initializer`.

-    Use :
+    Use :attr:`graph` to get the graph that owns the value.
     """

     __slots__ = (
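Note: continuing the sketch above, the use-tracking accessors on a value look roughly like this::

    import onnx_ir as ir

    v = ir.Value(name="x")
    n = ir.node("Relu", inputs=[v])

    print(v.consumers())  # the nodes consuming v, here the Relu node
    for use in v.uses():  # each use pairs a consuming node with an input index
        print(use)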
@@ -2154,6 +2341,12 @@ class Graph(_protocols.GraphProtocol, Sequence[Node], _display.PrettyPrintable):
     seen as a Sequence of nodes and should be used as such. For example, to obtain
     all nodes as a list, call ``list(graph)``.

+    .. versionchanged:: 0.1.1
+        Values with non-none producers will be rejected as graph inputs or initializers.
+
+    .. versionadded:: 0.1.1
+        Added ``add`` method to initializers and attributes.
+
     Attributes:
         name: The name of the graph.
         inputs: The input values of the graph.
@@ -2221,7 +2414,7 @@ class Graph(_protocols.GraphProtocol, Sequence[Node], _display.PrettyPrintable):

     @property
     def initializers(self) -> _graph_containers.GraphInitializers:
-        """The initializers of the graph as a ``
+        """The initializers of the graph as a ``dict[str, Value]``.

         The keys are the names of the initializers. The values are the :class:`Value` objects.

@@ -2357,6 +2550,33 @@ class Graph(_protocols.GraphProtocol, Sequence[Node], _display.PrettyPrintable):
         # NOTE: This is a method specific to Graph, not required by the protocol unless proven
         return len(self)

+    def all_nodes(self) -> Iterator[Node]:
+        """Get all nodes in the graph and its subgraphs in O(#nodes + #attributes) time.
+
+        This is an alias for ``onnx_ir.traversal.RecursiveGraphIterator(graph)``.
+        Consider using
+        :class:`onnx_ir.traversal.RecursiveGraphIterator` for more advanced
+        traversals on nodes.
+
+        .. versionadded:: 0.1.2
+        """
+        # NOTE: This is a method specific to Graph, not required by the protocol unless proven
+        return onnx_ir.traversal.RecursiveGraphIterator(self)
+
+    def subgraphs(self) -> Iterator[Graph]:
+        """Get all subgraphs in the graph in O(#nodes + #attributes) time.
+
+        .. versionadded:: 0.1.2
+        """
+        seen_graphs: set[Graph] = set()
+        for node in onnx_ir.traversal.RecursiveGraphIterator(self):
+            graph = node.graph
+            if graph is self:
+                continue
+            if graph is not None and graph not in seen_graphs:
+                seen_graphs.add(graph)
+                yield graph
+
     # Mutation methods
     def append(self, node: Node, /) -> None:
         """Append a node to the graph in O(1) time.
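Note: a traversal sketch using the new helpers (file name hypothetical)::

    import onnx_ir as ir

    model = ir.load("model.onnx")
    for node in model.graph.all_nodes():      # nodes in the graph and all nested subgraphs
        print(node.op_type)
    for subgraph in model.graph.subgraphs():  # each If/Loop body graph, yielded once
        print(subgraph.name)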
@@ -2862,7 +3082,7 @@ Model(
         """Get all graphs and subgraphs in the model.

         This is a convenience method to traverse the model. Consider using
-
+        :class:`onnx_ir.traversal.RecursiveGraphIterator` for more advanced
         traversals on nodes.
         """
         # NOTE(justinchuby): Given
@@ -2871,11 +3091,8 @@ Model(
         # (3) Users familiar with onnxruntime optimization tools expect this method
         # I created this method as a core method instead of an iterator in
         # `traversal.py`.
-
-
-        if node.graph is not None and node.graph not in seen_graphs:
-            seen_graphs.add(node.graph)
-            yield node.graph
+        yield self.graph
+        yield from self.graph.subgraphs()


 class Function(_protocols.FunctionProtocol, Sequence[Node], _display.PrettyPrintable):
@@ -3009,6 +3226,33 @@ class Function(_protocols.FunctionProtocol, Sequence[Node], _display.PrettyPrintable):
     def metadata_props(self) -> dict[str, str]:
         return self._graph.metadata_props

+    def all_nodes(self) -> Iterator[Node]:
+        """Get all nodes in the graph and its subgraphs in O(#nodes + #attributes) time.
+
+        This is an alias for ``onnx_ir.traversal.RecursiveGraphIterator(graph)``.
+        Consider using
+        :class:`onnx_ir.traversal.RecursiveGraphIterator` for more advanced
+        traversals on nodes.
+
+        .. versionadded:: 0.1.2
+        """
+        # NOTE: This is a method specific to Graph, not required by the protocol unless proven
+        return onnx_ir.traversal.RecursiveGraphIterator(self)
+
+    def subgraphs(self) -> Iterator[Graph]:
+        """Get all subgraphs in the function in O(#nodes + #attributes) time.
+
+        .. versionadded:: 0.1.2
+        """
+        seen_graphs: set[Graph] = set()
+        for node in onnx_ir.traversal.RecursiveGraphIterator(self):
+            graph = node.graph
+            if graph is self._graph:
+                continue
+            if graph is not None and graph not in seen_graphs:
+                seen_graphs.add(graph)
+                yield graph
+
     # Mutation methods
     def append(self, node: Node, /) -> None:
         """Append a node to the function in O(1) time."""
onnx_ir/_enums.py
CHANGED
@@ -114,7 +114,20 @@ class DataType(enum.IntEnum):
     @property
     def itemsize(self) -> float:
         """Returns the size of the data type in bytes."""
-        return
+        return self.bitwidth / 8
+
+    @property
+    def bitwidth(self) -> int:
+        """Returns the bit width of the data type.
+
+        .. versionadded:: 0.1.2
+
+        Raises:
+            TypeError: If the data type is not supported.
+        """
+        if self not in _BITWIDTH_MAP:
+            raise TypeError(f"Bitwidth not available for ONNX data type: {self}")
+        return _BITWIDTH_MAP[self]

     def numpy(self) -> np.dtype:
         """Returns the numpy dtype for the ONNX data type.
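Note: since ``itemsize`` is now derived as ``bitwidth / 8``, sub-byte types report fractional sizes; the values below follow from the map in the next hunk::

    import onnx_ir as ir

    assert ir.DataType.FLOAT.bitwidth == 32
    assert ir.DataType.FLOAT.itemsize == 4.0
    assert ir.DataType.INT4.bitwidth == 4
    assert ir.DataType.INT4.itemsize == 0.5  # half a byte per element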
@@ -163,30 +176,29 @@ class DataType(enum.IntEnum):
         return self.__repr__()


-
-    DataType.FLOAT:
-    DataType.UINT8:
-    DataType.INT8:
-    DataType.UINT16:
-    DataType.INT16:
-    DataType.INT32:
-    DataType.INT64:
-    DataType.
-    DataType.
-    DataType.
-    DataType.
-    DataType.
-    DataType.
-    DataType.
-    DataType.
-    DataType.
-    DataType.
-    DataType.
-    DataType.
-    DataType.
-    DataType.
-    DataType.
-    DataType.FLOAT4E2M1: 0.5,
+_BITWIDTH_MAP = {
+    DataType.FLOAT: 32,
+    DataType.UINT8: 8,
+    DataType.INT8: 8,
+    DataType.UINT16: 16,
+    DataType.INT16: 16,
+    DataType.INT32: 32,
+    DataType.INT64: 64,
+    DataType.BOOL: 8,
+    DataType.FLOAT16: 16,
+    DataType.DOUBLE: 64,
+    DataType.UINT32: 32,
+    DataType.UINT64: 64,
+    DataType.COMPLEX64: 64,  # 2 * 32
+    DataType.COMPLEX128: 128,  # 2 * 64
+    DataType.BFLOAT16: 16,
+    DataType.FLOAT8E4M3FN: 8,
+    DataType.FLOAT8E4M3FNUZ: 8,
+    DataType.FLOAT8E5M2: 8,
+    DataType.FLOAT8E5M2FNUZ: 8,
+    DataType.UINT4: 4,
+    DataType.INT4: 4,
+    DataType.FLOAT4E2M1: 4,
 }

onnx_ir/_graph_containers.py
CHANGED
@@ -216,7 +216,7 @@ class GraphOutputs(_GraphIO):


 class GraphInitializers(collections.UserDict[str, "_core.Value"]):
-    """The initializers of a Graph."""
+    """The initializers of a Graph as ``dict[str, Value]`` with additional mutation methods."""

     def __init__(self, graph: _core.Graph, dict=None, /, **kwargs):
         # Perform checks first in _set_graph before modifying the data structure with super().__init__()
@@ -291,7 +291,7 @@ class GraphInitializers(collections.UserDict[str, "_core.Value"]):


 class Attributes(collections.UserDict[str, "_core.Attr"]):
-    """The attributes of a Node."""
+    """The attributes of a Node as ``dict[str, Attr]`` with additional access methods."""

     def __init__(self, attrs: Iterable[_core.Attr]):
         super().__init__({attr.name: attr for attr in attrs})
onnx_ir/_io.py
CHANGED
@@ -7,10 +7,11 @@ from __future__ import annotations
 __all__ = ["load", "save"]

 import os
+from typing import Callable

-import onnx
+import onnx  # noqa: TID251

-from onnx_ir import _core, serde
+from onnx_ir import _core, _protocols, serde
 from onnx_ir import external_data as _external_data
 from onnx_ir._polyfill import zip

@@ -43,6 +44,8 @@ def save(
     format: str | None = None,
     external_data: str | os.PathLike | None = None,
     size_threshold_bytes: int = 256,
+    callback: Callable[[_protocols.TensorProtocol, _external_data.CallbackInfo], None]
+    | None = None,
 ) -> None:
     """Save an ONNX model to a file.

@@ -52,6 +55,30 @@ def save(
     to load the newly saved model, or provide a different external data path that
     is not currently referenced by any tensors in the model.

+    .. tip::
+
+        A simple progress bar can be implemented by passing a callback function as the following::
+
+            import onnx_ir as ir
+            import tqdm
+
+            with tqdm.tqdm() as pbar:
+                total_set = False
+
+                def callback(tensor: ir.TensorProtocol, metadata: ir.external_data.CallbackInfo) -> None:
+                    nonlocal total_set
+                    if not total_set:
+                        pbar.total = metadata.total
+                        total_set = True
+
+                    pbar.update()
+                    pbar.set_description(f"Saving {tensor.name} ({tensor.dtype}, {tensor.shape}) at offset {metadata.offset}")
+
+                ir.save(
+                    ...,
+                    callback=callback,
+                )
+
     Args:
         model: The model to save.
         path: The path to save the model to. E.g. "model.onnx".
@@ -65,6 +92,8 @@ def save(
             it will be serialized in the ONNX Proto message.
         size_threshold_bytes: Save to external data if the tensor size in bytes is larger than this threshold.
             Effective only when ``external_data`` is set.
+        callback: A callback function that is called for each tensor that is saved to external data
+            for debugging or logging purposes.

     Raises:
         ValueError: If the external data path is an absolute path.
@@ -77,12 +106,19 @@ def save(
     base_dir = os.path.dirname(path)

     # Store the original initializer values so they can be restored if modify_model=False
-    initializer_values =
+    initializer_values: list[_core.Value] = []
+    for graph in model.graphs():
+        # Collect from all subgraphs as well
+        initializer_values.extend(graph.initializers.values())
     tensors = [v.const_value for v in initializer_values]

     try:
         model = _external_data.unload_from_model(
-            model,
+            model,
+            base_dir,
+            external_data,
+            size_threshold_bytes=size_threshold_bytes,
+            callback=callback,
         )
         proto = serde.serialize_model(model)
         onnx.save(proto, path, format=format)
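Note: putting the new parameters together, a save call with external data and a logging callback might look like this (paths hypothetical)::

    import onnx_ir as ir

    model = ir.load("model.onnx")

    def log_tensor(tensor: ir.TensorProtocol, info: ir.external_data.CallbackInfo) -> None:
        # Called once per tensor written to the external data file
        print(f"wrote {tensor.name} at offset {info.offset}")

    ir.save(
        model,
        "model_ext.onnx",
        external_data="model_ext.data",  # must be relative to the model directory
        size_threshold_bytes=1024,       # offload only tensors larger than 1 KiB
        callback=log_tensor,
    )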