onnx-ir 0.1.0__py3-none-any.whl → 0.1.2__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of onnx-ir might be problematic. Click here for more details.
- onnx_ir/__init__.py +5 -2
- onnx_ir/_convenience/__init__.py +125 -4
- onnx_ir/_convenience/_constructors.py +6 -2
- onnx_ir/_core.py +291 -76
- onnx_ir/_enums.py +35 -25
- onnx_ir/_graph_containers.py +114 -9
- onnx_ir/_io.py +40 -4
- onnx_ir/_type_casting.py +2 -1
- onnx_ir/_version_utils.py +5 -48
- onnx_ir/convenience.py +3 -1
- onnx_ir/external_data.py +43 -3
- onnx_ir/passes/_pass_infra.py +1 -1
- onnx_ir/passes/common/__init__.py +4 -0
- onnx_ir/passes/common/_c_api_utils.py +1 -1
- onnx_ir/passes/common/common_subexpression_elimination.py +177 -0
- onnx_ir/passes/common/constant_manipulation.py +10 -25
- onnx_ir/passes/common/inliner.py +4 -3
- onnx_ir/passes/common/onnx_checker.py +1 -1
- onnx_ir/passes/common/shape_inference.py +1 -1
- onnx_ir/passes/common/unused_removal.py +1 -1
- onnx_ir/serde.py +171 -6
- {onnx_ir-0.1.0.dist-info → onnx_ir-0.1.2.dist-info}/METADATA +22 -4
- onnx_ir-0.1.2.dist-info/RECORD +42 -0
- onnx_ir-0.1.0.dist-info/RECORD +0 -41
- {onnx_ir-0.1.0.dist-info → onnx_ir-0.1.2.dist-info}/WHEEL +0 -0
- {onnx_ir-0.1.0.dist-info → onnx_ir-0.1.2.dist-info}/licenses/LICENSE +0 -0
- {onnx_ir-0.1.0.dist-info → onnx_ir-0.1.2.dist-info}/top_level.txt +0 -0
onnx_ir/_core.py
CHANGED
|
@@ -22,13 +22,12 @@ import os
|
|
|
22
22
|
import sys
|
|
23
23
|
import textwrap
|
|
24
24
|
import typing
|
|
25
|
-
from collections import OrderedDict
|
|
26
25
|
from collections.abc import (
|
|
27
26
|
Collection,
|
|
28
27
|
Hashable,
|
|
29
28
|
Iterable,
|
|
30
29
|
Iterator,
|
|
31
|
-
|
|
30
|
+
Mapping,
|
|
32
31
|
MutableSequence,
|
|
33
32
|
Sequence,
|
|
34
33
|
)
|
|
@@ -252,11 +251,11 @@ def _check_numpy_representation_type(array: np.ndarray, dtype: _enums.DataType)
|
|
|
252
251
|
or corresponding dtypes from the ``ml_dtype`` package.
|
|
253
252
|
"""
|
|
254
253
|
if dtype in _NON_NUMPY_NATIVE_TYPES:
|
|
255
|
-
if dtype.
|
|
254
|
+
if dtype.bitwidth == 16 and array.dtype not in (np.uint16, ml_dtypes.bfloat16):
|
|
256
255
|
raise TypeError(
|
|
257
256
|
f"The numpy array dtype must be uint16 or ml_dtypes.bfloat16 (not {array.dtype}) for IR data type {dtype}."
|
|
258
257
|
)
|
|
259
|
-
if dtype.
|
|
258
|
+
if dtype.bitwidth == 8 and array.dtype not in (
|
|
260
259
|
np.uint8,
|
|
261
260
|
ml_dtypes.float8_e4m3fnuz,
|
|
262
261
|
ml_dtypes.float8_e4m3fn,
|
|
@@ -386,9 +385,10 @@ class Tensor(TensorBase, _protocols.TensorProtocol, Generic[TArrayCompatible]):
|
|
|
386
385
|
|
|
387
386
|
Args:
|
|
388
387
|
value: The backing data of the tensor. It can be a numpy array compatible object or a DLPack compatible object.
|
|
389
|
-
When the dtype is not one of the numpy native dtypes, the value
|
|
390
|
-
|
|
391
|
-
when the value is a numpy array;
|
|
388
|
+
When the dtype is not one of the numpy native dtypes, the value can
|
|
389
|
+
be ``uint8`` (unpacked) or ml_dtypes types for 4-bit and 8-bit data types,
|
|
390
|
+
and ``uint16`` or ml_dtype.bfloat16 for bfloat16 when the value is a numpy array;
|
|
391
|
+
``dtype`` must be specified in this case.
|
|
392
392
|
dtype: The data type of the tensor. It can be None only when value is a numpy array.
|
|
393
393
|
Users are responsible for making sure the dtype matches the value when value is not a numpy array.
|
|
394
394
|
shape: The shape of the tensor. If None, the shape is obtained from the value.
|
|
@@ -422,7 +422,8 @@ class Tensor(TensorBase, _protocols.TensorProtocol, Generic[TArrayCompatible]):
|
|
|
422
422
|
self._dtype = _enums.DataType.from_numpy(value.dtype)
|
|
423
423
|
else:
|
|
424
424
|
raise ValueError(
|
|
425
|
-
"The dtype must be specified when the value is not a numpy array."
|
|
425
|
+
"The dtype must be specified when the value is not a numpy array. "
|
|
426
|
+
"Value type: {type(value)}"
|
|
426
427
|
)
|
|
427
428
|
else:
|
|
428
429
|
if isinstance(value, np.ndarray):
|
|
@@ -503,7 +504,7 @@ class Tensor(TensorBase, _protocols.TensorProtocol, Generic[TArrayCompatible]):
|
|
|
503
504
|
_enums.DataType.FLOAT4E2M1,
|
|
504
505
|
}:
|
|
505
506
|
# Pack the array into int4
|
|
506
|
-
array = _type_casting.
|
|
507
|
+
array = _type_casting.pack_4bitx2(array)
|
|
507
508
|
else:
|
|
508
509
|
assert self.dtype.itemsize == array.itemsize, "Bug: The itemsize should match"
|
|
509
510
|
if not _IS_LITTLE_ENDIAN:
|
|
@@ -962,8 +963,151 @@ class LazyTensor(TensorBase, _protocols.TensorProtocol): # pylint: disable=too-
|
|
|
962
963
|
return self._evaluate().tobytes()
|
|
963
964
|
|
|
964
965
|
|
|
966
|
+
class PackedTensor(TensorBase, _protocols.TensorProtocol, Generic[TArrayCompatible]):  # pylint: disable=too-many-ancestors
    """A tensor that stores 4bit datatypes in packed format.

    The backing value holds the packed bytes. :meth:`numpy` returns the unpacked
    elements (via the ``_type_casting.unpack_*`` helpers), while :meth:`numpy_packed`,
    :meth:`tobytes` and the DLPack export expose the packed ``uint8`` data.
    """

    __slots__ = (
        "_dtype",
        "_raw",
        "_shape",
    )

    def __init__(
        self,
        value: TArrayCompatible,
        dtype: _enums.DataType,
        *,
        shape: Shape | Sequence[int],
        name: str | None = None,
        doc_string: str | None = None,
        metadata_props: dict[str, str] | None = None,
    ) -> None:
        """Initialize a tensor.

        Args:
            value: The backing data of the tensor. It can be a numpy array compatible object or a DLPack compatible object.
                The value MUST be packed in an integer dtype.
            dtype: The data type of the tensor. Must be one of INT4, UINT4, FLOAT4E2M1.
            shape: The (logical, unpacked) shape of the tensor.
            name: The name of the tensor.
            doc_string: The documentation string.
            metadata_props: The metadata properties.

        Raises:
            TypeError: If the value is not a numpy array compatible or a DLPack compatible object.
            TypeError: If ``dtype`` is not a 4-bit data type.
            TypeError: If the value is a numpy array whose dtype is one of the unpacked
                4-bit ml_dtypes dtypes (int4, uint4, float4_e2m1fn).
            ValueError: If the value is a numpy array whose size in bytes differs from
                the number of bytes implied by ``shape`` and ``dtype``.
        """
        super().__init__(name=name, doc_string=doc_string, metadata_props=metadata_props)
        if not _compatible_with_numpy(value) and not _compatible_with_dlpack(value):
            raise TypeError(f"Expected an array compatible object, got {type(value)}")
        self._shape = Shape(shape)
        self._shape.freeze()
        if dtype.bitwidth != 4:
            raise TypeError(
                f"PackedTensor only supports INT4, UINT4, FLOAT4E2M1, but got {dtype}"
            )
        self._dtype = dtype
        self._raw = value

        if isinstance(value, np.ndarray):
            # ml_dtypes 4-bit arrays store one element per item, i.e. they are unpacked.
            if (
                value.dtype == ml_dtypes.float4_e2m1fn
                or value.dtype == ml_dtypes.uint4
                or value.dtype == ml_dtypes.int4
            ):
                raise TypeError(
                    f"PackedTensor expects the value to be packed, but got {value.dtype} which is not packed. "
                    "Please pack the value or use `onnx_ir.Tensor`."
                )
            # Check after shape and dtype are set (self.nbytes presumably depends on both).
            # Compare byte counts, not element counts: this matches the error message and
            # numpy_packed(), and stays correct when the packed data uses a wider integer dtype.
            if value.nbytes != self.nbytes:
                raise ValueError(
                    f"Expected the packed array to be {self.nbytes} bytes (from shape {self.shape}), but got {value.nbytes} bytes"
                )

    def __array__(self, dtype: Any = None, copy: bool = False) -> np.ndarray:
        """Return the unpacked values as a numpy array, cast to ``dtype`` when requested."""
        array = self.numpy()
        if dtype is not None:
            # Honor the dtype requested through the __array__ protocol.
            array = array.astype(dtype, copy=False)
        return array

    def __dlpack__(self, *, stream: Any = None) -> Any:
        """Export the packed data through the DLPack protocol.

        When the raw backing object supports DLPack, it is exported directly (packed).
        Otherwise the packed ``uint8`` view from :meth:`numpy_packed` is exported. The
        unpacked ml_dtypes array is not used because NumPy's DLPack export does not
        support custom (ml_dtypes) dtypes, and because the raw path also yields packed
        data.
        """
        if _compatible_with_dlpack(self._raw):
            return self._raw.__dlpack__(stream=stream)
        return self.numpy_packed().__dlpack__(stream=stream)

    def __dlpack_device__(self) -> tuple[int, int]:
        """Return the DLPack device of the object that :meth:`__dlpack__` exports."""
        if _compatible_with_dlpack(self._raw):
            return self._raw.__dlpack_device__()
        return self.numpy_packed().__dlpack_device__()

    def __repr__(self) -> str:
        return f"{self._repr_base()}({self._raw!r}, name={self.name!r})"

    @property
    def dtype(self) -> _enums.DataType:
        """The data type of the tensor. Immutable."""
        return self._dtype

    @property
    def shape(self) -> Shape:
        """The shape of the tensor. Immutable."""
        return self._shape

    @property
    def raw(self) -> TArrayCompatible:
        """Backing data of the tensor. Immutable."""
        return self._raw  # type: ignore[return-value]

    def numpy(self) -> np.ndarray:
        """Return the tensor as a numpy array.

        When the data type is not supported by numpy, the dtypes from the ``ml_dtype``
        package are used. The values can be reinterpreted as bit representations
        using the ``.view()`` method.
        """
        array = self.numpy_packed()
        # ONNX IR returns the unpacked arrays
        if self.dtype == _enums.DataType.INT4:
            return _type_casting.unpack_int4(array, self.shape.numpy())
        if self.dtype == _enums.DataType.UINT4:
            return _type_casting.unpack_uint4(array, self.shape.numpy())
        if self.dtype == _enums.DataType.FLOAT4E2M1:
            return _type_casting.unpack_float4e2m1(array, self.shape.numpy())
        raise TypeError(
            f"PackedTensor only supports INT4, UINT4, FLOAT4E2M1, but got {self.dtype}"
        )

    def numpy_packed(self) -> npt.NDArray[np.uint8]:
        """Return the tensor as a packed array.

        Raises:
            ValueError: If the backing data's size in bytes differs from the number
                of bytes implied by the tensor's shape and dtype.
        """
        if isinstance(self._raw, np.ndarray) or _compatible_with_numpy(self._raw):
            array = np.asarray(self._raw)
        else:
            assert _compatible_with_dlpack(self._raw), (
                f"Bug: Expected DLPack or Numpy compatible objects, got {type(self._raw)}"
            )
            array = np.from_dlpack(self._raw)
        if array.nbytes != self.nbytes:
            raise ValueError(
                f"Expected the packed array to be {self.nbytes} bytes (from shape {self.shape}), but got {array.nbytes} bytes"
            )
        return array.view(np.uint8)

    def tobytes(self) -> bytes:
        """Returns the value as bytes encoded in little endian.

        Override this method for more efficient serialization when the raw
        value is not a numpy array.
        """
        # numpy_packed() returns a uint8 view. Single-byte items have no byte order,
        # so the bytes are already in little-endian order on any host.
        return self.numpy_packed().tobytes()
|
|
1103
|
+
|
|
1104
|
+
|
|
965
1105
|
class SymbolicDim(_protocols.SymbolicDimProtocol, _display.PrettyPrintable):
|
|
966
|
-
"""Immutable symbolic dimension that can be shared across multiple shapes.
|
|
1106
|
+
"""Immutable symbolic dimension that can be shared across multiple shapes.
|
|
1107
|
+
|
|
1108
|
+
SymbolicDim is used to represent a symbolic (non-integer) dimension in a tensor shape.
|
|
1109
|
+
It is immutable and can be compared or hashed.
|
|
1110
|
+
"""
|
|
967
1111
|
|
|
968
1112
|
__slots__ = ("_value",)
|
|
969
1113
|
|
|
@@ -972,6 +1116,9 @@ class SymbolicDim(_protocols.SymbolicDimProtocol, _display.PrettyPrintable):
|
|
|
972
1116
|
|
|
973
1117
|
Args:
|
|
974
1118
|
value: The value of the dimension. It should not be an int.
|
|
1119
|
+
|
|
1120
|
+
Raises:
|
|
1121
|
+
TypeError: If value is an int.
|
|
975
1122
|
"""
|
|
976
1123
|
if isinstance(value, int):
|
|
977
1124
|
raise TypeError(
|
|
@@ -981,15 +1128,18 @@ class SymbolicDim(_protocols.SymbolicDimProtocol, _display.PrettyPrintable):
|
|
|
981
1128
|
self._value = value
|
|
982
1129
|
|
|
983
1130
|
def __eq__(self, other: object) -> bool:
    """Compare by underlying value.

    ``other`` may be another :class:`SymbolicDim` (its value is compared) or any
    other object, which is compared directly against this dimension's value
    (e.g. a string or ``None``).
    """
    rhs = other.value if isinstance(other, SymbolicDim) else other
    return self.value == rhs
|
|
987
1135
|
|
|
988
1136
|
def __hash__(self) -> int:
    """Hash by the underlying value, mirroring how ``__eq__`` compares values."""
    underlying = self.value
    return hash(underlying)
|
|
990
1139
|
|
|
991
1140
|
@property
def value(self) -> str | None:
    """The stored symbolic value of this dimension: a string or ``None``."""
    stored = self._value
    return stored
|
|
994
1144
|
|
|
995
1145
|
def __str__(self) -> str:
|
|
@@ -1000,7 +1150,14 @@ class SymbolicDim(_protocols.SymbolicDimProtocol, _display.PrettyPrintable):
|
|
|
1000
1150
|
|
|
1001
1151
|
|
|
1002
1152
|
def _is_int_compatible(value: object) -> TypeIs[SupportsInt]:
|
|
1003
|
-
"""
|
|
1153
|
+
"""Check if the value is compatible with int (i.e., can be safely cast to int).
|
|
1154
|
+
|
|
1155
|
+
Args:
|
|
1156
|
+
value: The value to check.
|
|
1157
|
+
|
|
1158
|
+
Returns:
|
|
1159
|
+
True if the value is an int or has an __int__ method, False otherwise.
|
|
1160
|
+
"""
|
|
1004
1161
|
if isinstance(value, int):
|
|
1005
1162
|
return True
|
|
1006
1163
|
if hasattr(value, "__int__"):
|
|
@@ -1012,7 +1169,17 @@ def _is_int_compatible(value: object) -> TypeIs[SupportsInt]:
|
|
|
1012
1169
|
def _maybe_convert_to_symbolic_dim(
|
|
1013
1170
|
dim: int | SupportsInt | SymbolicDim | str | None,
|
|
1014
1171
|
) -> SymbolicDim | int:
|
|
1015
|
-
"""Convert the value to a SymbolicDim if it is not an int.
|
|
1172
|
+
"""Convert the value to a SymbolicDim if it is not an int.
|
|
1173
|
+
|
|
1174
|
+
Args:
|
|
1175
|
+
dim: The dimension value, which can be int, str, None, or SymbolicDim.
|
|
1176
|
+
|
|
1177
|
+
Returns:
|
|
1178
|
+
An int or SymbolicDim instance.
|
|
1179
|
+
|
|
1180
|
+
Raises:
|
|
1181
|
+
TypeError: If the value is not int, str, None, or SymbolicDim.
|
|
1182
|
+
"""
|
|
1016
1183
|
if dim is None or isinstance(dim, str):
|
|
1017
1184
|
return SymbolicDim(dim)
|
|
1018
1185
|
if _is_int_compatible(dim):
|
|
@@ -1025,21 +1192,20 @@ def _maybe_convert_to_symbolic_dim(
|
|
|
1025
1192
|
|
|
1026
1193
|
|
|
1027
1194
|
class Shape(_protocols.ShapeProtocol, _display.PrettyPrintable):
|
|
1028
|
-
"""
|
|
1195
|
+
"""Represents the shape of a tensor, including its dimensions and optional denotations.
|
|
1029
1196
|
|
|
1030
|
-
The :class:`Shape` stores the dimensions of a tensor, which can be integers, None (unknown), or
|
|
1031
|
-
symbolic dimensions.
|
|
1032
|
-
|
|
1033
|
-
A shape can be compared to another shape or plain Python list.
|
|
1197
|
+
The :class:`Shape` class stores the dimensions of a tensor, which can be integers, None (unknown), or
|
|
1198
|
+
symbolic dimensions. It provides methods for querying and manipulating the shape, as well as for comparing
|
|
1199
|
+
shapes to other shapes or plain Python lists.
|
|
1034
1200
|
|
|
1035
1201
|
A shape can be frozen (made immutable). When the shape is frozen, it cannot be
|
|
1036
1202
|
unfrozen, making it suitable to be shared across tensors or values.
|
|
1037
|
-
Call :
|
|
1203
|
+
Call :meth:`freeze` to freeze the shape.
|
|
1038
1204
|
|
|
1039
|
-
To update the dimension of a frozen shape, call :
|
|
1205
|
+
To update the dimension of a frozen shape, call :meth:`copy` to create a
|
|
1040
1206
|
new shape with the same dimensions that can be modified.
|
|
1041
1207
|
|
|
1042
|
-
Use :
|
|
1208
|
+
Use :meth:`get_denotation` and :meth:`set_denotation` to access and modify the denotations.
|
|
1043
1209
|
|
|
1044
1210
|
Example::
|
|
1045
1211
|
|
|
@@ -1067,7 +1233,7 @@ class Shape(_protocols.ShapeProtocol, _display.PrettyPrintable):
|
|
|
1067
1233
|
|
|
1068
1234
|
Attributes:
|
|
1069
1235
|
dims: A tuple of dimensions representing the shape.
|
|
1070
|
-
Each dimension can be an integer, None or a :class:`SymbolicDim`.
|
|
1236
|
+
Each dimension can be an integer, None, or a :class:`SymbolicDim`.
|
|
1071
1237
|
frozen: Indicates whether the shape is immutable. When frozen, the shape
|
|
1072
1238
|
cannot be modified or unfrozen.
|
|
1073
1239
|
"""
|
|
@@ -1122,7 +1288,7 @@ class Shape(_protocols.ShapeProtocol, _display.PrettyPrintable):
|
|
|
1122
1288
|
"""Whether the shape is frozen.
|
|
1123
1289
|
|
|
1124
1290
|
When the shape is frozen, it cannot be unfrozen, making it suitable to be shared.
|
|
1125
|
-
Call :
|
|
1291
|
+
Call :meth:`freeze` to freeze the shape. Call :meth:`copy` to create a
|
|
1126
1292
|
new shape with the same dimensions that can be modified.
|
|
1127
1293
|
"""
|
|
1128
1294
|
return self._frozen
|
|
@@ -1290,19 +1456,24 @@ def _normalize_domain(domain: str) -> str:
|
|
|
1290
1456
|
class Node(_protocols.NodeProtocol, _display.PrettyPrintable):
|
|
1291
1457
|
"""IR Node.
|
|
1292
1458
|
|
|
1293
|
-
|
|
1294
|
-
|
|
1459
|
+
.. tip::
|
|
1460
|
+
For a more convenient way (that supports Python objects
|
|
1461
|
+
as attributes) to create a node, use the :func:`onnx_ir.node` constructor.
|
|
1462
|
+
|
|
1463
|
+
If ``graph`` is provided, the node will be added to the graph. Otherwise,
|
|
1464
|
+
the user is responsible for calling ``graph.append(node)`` (or other mutation methods
|
|
1295
1465
|
in :class:`Graph`) to add the node to the graph.
|
|
1296
1466
|
|
|
1297
|
-
After the node is initialized, it will add itself as a user of
|
|
1467
|
+
After the node is initialized, it will add itself as a user of its input values.
|
|
1298
1468
|
|
|
1299
1469
|
The output values of the node are created during node initialization and are immutable.
|
|
1300
|
-
To change the output values, create a new node and
|
|
1301
|
-
the
|
|
1302
|
-
|
|
1470
|
+
To change the output values, create a new node and, for each use of the old outputs (``output.uses()``),
|
|
1471
|
+
replace the input in the consuming node by calling :meth:`replace_input_with`.
|
|
1472
|
+
You can also use the :func:`~onnx_ir.convenience.replace_all_uses_with` method
|
|
1473
|
+
to replace all uses of the output values.
|
|
1303
1474
|
|
|
1304
|
-
.. note
|
|
1305
|
-
When the ``domain`` is
|
|
1475
|
+
.. note::
|
|
1476
|
+
When the ``domain`` is ``"ai.onnx"``, it is normalized to ``""``.
|
|
1306
1477
|
"""
|
|
1307
1478
|
|
|
1308
1479
|
__slots__ = (
|
|
@@ -1325,7 +1496,7 @@ class Node(_protocols.NodeProtocol, _display.PrettyPrintable):
|
|
|
1325
1496
|
domain: str,
|
|
1326
1497
|
op_type: str,
|
|
1327
1498
|
inputs: Iterable[Value | None],
|
|
1328
|
-
attributes: Iterable[Attr] = (),
|
|
1499
|
+
attributes: Iterable[Attr] | Mapping[str, Attr] = (),
|
|
1329
1500
|
*,
|
|
1330
1501
|
overload: str = "",
|
|
1331
1502
|
num_outputs: int | None = None,
|
|
@@ -1340,7 +1511,7 @@ class Node(_protocols.NodeProtocol, _display.PrettyPrintable):
|
|
|
1340
1511
|
|
|
1341
1512
|
Args:
|
|
1342
1513
|
domain: The domain of the operator. For onnx operators, this is an empty string.
|
|
1343
|
-
When it is
|
|
1514
|
+
When it is ``"ai.onnx"``, it is normalized to ``""``.
|
|
1344
1515
|
op_type: The name of the operator.
|
|
1345
1516
|
inputs: The input values. When an input is ``None``, it is an empty input.
|
|
1346
1517
|
attributes: The attributes. RefAttr can be used only when the node is defined in a Function.
|
|
@@ -1371,15 +1542,10 @@ class Node(_protocols.NodeProtocol, _display.PrettyPrintable):
|
|
|
1371
1542
|
self._inputs: tuple[Value | None, ...] = tuple(inputs)
|
|
1372
1543
|
# Values belong to their defining nodes. The values list is immutable
|
|
1373
1544
|
self._outputs: tuple[Value, ...] = self._create_outputs(num_outputs, outputs)
|
|
1374
|
-
|
|
1375
|
-
|
|
1376
|
-
|
|
1377
|
-
|
|
1378
|
-
"If you are copying the attributes from another node, make sure you call "
|
|
1379
|
-
"node.attributes.values() because it is a dictionary."
|
|
1380
|
-
)
|
|
1381
|
-
self._attributes: OrderedDict[str, Attr] = OrderedDict(
|
|
1382
|
-
(attr.name, attr) for attr in attributes
|
|
1545
|
+
if isinstance(attributes, Mapping):
|
|
1546
|
+
attributes = tuple(attributes.values())
|
|
1547
|
+
self._attributes: _graph_containers.Attributes = _graph_containers.Attributes(
|
|
1548
|
+
attributes
|
|
1383
1549
|
)
|
|
1384
1550
|
self._overload: str = overload
|
|
1385
1551
|
# TODO(justinchuby): Potentially support a version range
|
|
@@ -1637,8 +1803,16 @@ class Node(_protocols.NodeProtocol, _display.PrettyPrintable):
|
|
|
1637
1803
|
raise AttributeError("outputs is immutable. Please create a new node instead.")
|
|
1638
1804
|
|
|
1639
1805
|
@property
|
|
1640
|
-
def attributes(self) ->
|
|
1641
|
-
"""The attributes of the node.
|
|
1806
|
+
def attributes(self) -> _graph_containers.Attributes:
|
|
1807
|
+
"""The attributes of the node as ``dict[str, Attr]`` with additional access methods.
|
|
1808
|
+
|
|
1809
|
+
Use it as a dictionary with keys being the attribute names and values being the
|
|
1810
|
+
:class:`Attr` objects.
|
|
1811
|
+
|
|
1812
|
+
Use ``node.attributes.add(attr)`` to add an attribute to the node.
|
|
1813
|
+
Use ``node.attributes.get_int(name, default)`` to get an integer attribute value.
|
|
1814
|
+
Refer to the :class:`~onnx_ir._graph_containers.Attributes` for more methods.
|
|
1815
|
+
"""
|
|
1642
1816
|
return self._attributes
|
|
1643
1817
|
|
|
1644
1818
|
@property
|
|
@@ -1805,12 +1979,13 @@ class Value(_protocols.ValueProtocol, _display.PrettyPrintable):
|
|
|
1805
1979
|
The index of the output of the node that produces the value can be accessed with
|
|
1806
1980
|
:meth:`index`.
|
|
1807
1981
|
|
|
1808
|
-
To find all the nodes that use this value as an input, call :meth:`uses`.
|
|
1982
|
+
To find all the nodes that use this value as an input, call :meth:`uses`. Consuming
|
|
1983
|
+
nodes can be obtained with :meth:`consumers`.
|
|
1809
1984
|
|
|
1810
1985
|
To check if the value is an is an input, output or initializer of a graph,
|
|
1811
1986
|
use :meth:`is_graph_input`, :meth:`is_graph_output` or :meth:`is_initializer`.
|
|
1812
1987
|
|
|
1813
|
-
Use :
|
|
1988
|
+
Use :attr:`graph` to get the graph that owns the value.
|
|
1814
1989
|
"""
|
|
1815
1990
|
|
|
1816
1991
|
__slots__ = (
|
|
@@ -2201,17 +2376,9 @@ class Graph(_protocols.GraphProtocol, Sequence[Node], _display.PrettyPrintable):
|
|
|
2201
2376
|
# Private fields that are not to be accessed by any other classes
|
|
2202
2377
|
self._inputs = _graph_containers.GraphInputs(self, inputs)
|
|
2203
2378
|
self._outputs = _graph_containers.GraphOutputs(self, outputs)
|
|
2204
|
-
self._initializers = _graph_containers.GraphInitializers(
|
|
2205
|
-
|
|
2206
|
-
|
|
2207
|
-
raise TypeError(
|
|
2208
|
-
"Initializer must be a Value, not a string. "
|
|
2209
|
-
"If you are copying the initializers from another graph, "
|
|
2210
|
-
"make sure you call graph.initializers.values() because it is a dictionary."
|
|
2211
|
-
)
|
|
2212
|
-
if initializer.name is None:
|
|
2213
|
-
raise ValueError(f"Initializer must have a name: {initializer}")
|
|
2214
|
-
self._initializers[initializer.name] = initializer
|
|
2379
|
+
self._initializers = _graph_containers.GraphInitializers(
|
|
2380
|
+
self, {initializer.name: initializer for initializer in initializers}
|
|
2381
|
+
)
|
|
2215
2382
|
self._doc_string = doc_string
|
|
2216
2383
|
self._opset_imports = opset_imports or {}
|
|
2217
2384
|
self._metadata: _metadata.MetadataStore | None = None
|
|
@@ -2234,7 +2401,19 @@ class Graph(_protocols.GraphProtocol, Sequence[Node], _display.PrettyPrintable):
|
|
|
2234
2401
|
return self._outputs
|
|
2235
2402
|
|
|
2236
2403
|
@property
|
|
2237
|
-
def initializers(self) ->
|
|
2404
|
+
def initializers(self) -> _graph_containers.GraphInitializers:
|
|
2405
|
+
"""The initializers of the graph as a ``dict[str, Value]``.
|
|
2406
|
+
|
|
2407
|
+
The keys are the names of the initializers. The values are the :class:`Value` objects.
|
|
2408
|
+
|
|
2409
|
+
This property additionally supports the ``add`` method, which takes a :class:`Value`
|
|
2410
|
+
and adds it to the initializers if it is not already present.
|
|
2411
|
+
|
|
2412
|
+
.. note::
|
|
2413
|
+
When setting an initializer with ``graph.initializers[key] = value``,
|
|
2414
|
+
if the value does not have a name, it will be assigned ``key`` as its name.
|
|
2415
|
+
|
|
2416
|
+
"""
|
|
2238
2417
|
return self._initializers
|
|
2239
2418
|
|
|
2240
2419
|
def register_initializer(self, value: Value) -> None:
|
|
@@ -2263,15 +2442,11 @@ class Graph(_protocols.GraphProtocol, Sequence[Node], _display.PrettyPrintable):
|
|
|
2263
2442
|
" it is not the same object: existing={self._initializers[value.name]!r},"
|
|
2264
2443
|
f" new={value!r}"
|
|
2265
2444
|
)
|
|
2266
|
-
if value.producer() is not None:
|
|
2267
|
-
raise ValueError(
|
|
2268
|
-
f"Value '{value!r}' is produced by a node and cannot be an initializer."
|
|
2269
|
-
)
|
|
2270
2445
|
if value.const_value is None:
|
|
2271
2446
|
raise ValueError(
|
|
2272
2447
|
f"Value '{value!r}' must have its const_value set to be an initializer."
|
|
2273
2448
|
)
|
|
2274
|
-
self._initializers
|
|
2449
|
+
self._initializers.add(value)
|
|
2275
2450
|
|
|
2276
2451
|
@property
|
|
2277
2452
|
def doc_string(self) -> str | None:
|
|
@@ -2363,6 +2538,28 @@ class Graph(_protocols.GraphProtocol, Sequence[Node], _display.PrettyPrintable):
|
|
|
2363
2538
|
# NOTE: This is a method specific to Graph, not required by the protocol unless proven
|
|
2364
2539
|
return len(self)
|
|
2365
2540
|
|
|
2541
|
+
def all_nodes(self) -> Iterator[Node]:
    """Iterate over every node of this graph, descending into its subgraphs.

    Runs in O(#nodes + #attributes) time. Equivalent to
    ``onnx_ir.traversal.RecursiveGraphIterator(graph)``; use
    :class:`onnx_ir.traversal.RecursiveGraphIterator` directly when a more
    advanced traversal over nodes is needed.
    """
    # Graph-specific convenience; not required by the protocol unless a need is proven.
    traversal = onnx_ir.traversal
    return traversal.RecursiveGraphIterator(self)
|
|
2551
|
+
|
|
2552
|
+
def subgraphs(self) -> Iterator[Graph]:
    """Yield each distinct subgraph nested inside this graph, in O(#nodes + #attributes) time.

    The graph itself is never yielded, and a subgraph is yielded only once, the
    first time one of its nodes is reached by the recursive traversal.
    """
    yielded: set[Graph] = set()
    for node in onnx_ir.traversal.RecursiveGraphIterator(self):
        owner = node.graph
        # Skip detached nodes, nodes of this graph, and subgraphs already reported.
        if owner is None or owner is self or owner in yielded:
            continue
        yielded.add(owner)
        yield owner
|
|
2562
|
+
|
|
2366
2563
|
# Mutation methods
|
|
2367
2564
|
def append(self, node: Node, /) -> None:
|
|
2368
2565
|
"""Append a node to the graph in O(1) time.
|
|
@@ -2701,7 +2898,7 @@ class GraphView(Sequence[Node], _display.PrettyPrintable):
|
|
|
2701
2898
|
outputs: Sequence[Value],
|
|
2702
2899
|
*,
|
|
2703
2900
|
nodes: Iterable[Node],
|
|
2704
|
-
initializers: Sequence[
|
|
2901
|
+
initializers: Sequence[Value] = (),
|
|
2705
2902
|
doc_string: str | None = None,
|
|
2706
2903
|
opset_imports: dict[str, int] | None = None,
|
|
2707
2904
|
name: str | None = None,
|
|
@@ -2710,10 +2907,7 @@ class GraphView(Sequence[Node], _display.PrettyPrintable):
|
|
|
2710
2907
|
self.name = name
|
|
2711
2908
|
self.inputs = tuple(inputs)
|
|
2712
2909
|
self.outputs = tuple(outputs)
|
|
2713
|
-
for initializer in initializers
|
|
2714
|
-
if initializer.name is None:
|
|
2715
|
-
raise ValueError(f"Initializer must have a name: {initializer}")
|
|
2716
|
-
self.initializers = {tensor.name: tensor for tensor in initializers}
|
|
2910
|
+
self.initializers = {initializer.name: initializer for initializer in initializers}
|
|
2717
2911
|
self.doc_string = doc_string
|
|
2718
2912
|
self.opset_imports = opset_imports or {}
|
|
2719
2913
|
self._metadata: _metadata.MetadataStore | None = None
|
|
@@ -2871,7 +3065,7 @@ Model(
|
|
|
2871
3065
|
"""Get all graphs and subgraphs in the model.
|
|
2872
3066
|
|
|
2873
3067
|
This is a convenience method to traverse the model. Consider using
|
|
2874
|
-
|
|
3068
|
+
:class:`onnx_ir.traversal.RecursiveGraphIterator` for more advanced
|
|
2875
3069
|
traversals on nodes.
|
|
2876
3070
|
"""
|
|
2877
3071
|
# NOTE(justinchuby): Given
|
|
@@ -2880,11 +3074,8 @@ Model(
|
|
|
2880
3074
|
# (3) Users familiar with onnxruntime optimization tools expect this method
|
|
2881
3075
|
# I created this method as a core method instead of an iterator in
|
|
2882
3076
|
# `traversal.py`.
|
|
2883
|
-
|
|
2884
|
-
|
|
2885
|
-
if node.graph is not None and node.graph not in seen_graphs:
|
|
2886
|
-
seen_graphs.add(node.graph)
|
|
2887
|
-
yield node.graph
|
|
3077
|
+
yield self.graph
|
|
3078
|
+
yield from self.graph.subgraphs()
|
|
2888
3079
|
|
|
2889
3080
|
|
|
2890
3081
|
class Function(_protocols.FunctionProtocol, Sequence[Node], _display.PrettyPrintable):
|
|
@@ -2927,13 +3118,15 @@ class Function(_protocols.FunctionProtocol, Sequence[Node], _display.PrettyPrint
|
|
|
2927
3118
|
# Ensure the inputs and outputs of the function belong to a graph
|
|
2928
3119
|
# and not from an outer scope
|
|
2929
3120
|
graph: Graph,
|
|
2930
|
-
attributes:
|
|
3121
|
+
attributes: Iterable[Attr] | Mapping[str, Attr],
|
|
2931
3122
|
) -> None:
|
|
2932
3123
|
self._domain = domain
|
|
2933
3124
|
self._name = name
|
|
2934
3125
|
self._overload = overload
|
|
2935
3126
|
self._graph = graph
|
|
2936
|
-
|
|
3127
|
+
if isinstance(attributes, Mapping):
|
|
3128
|
+
attributes = tuple(attributes.values())
|
|
3129
|
+
self._attributes = _graph_containers.Attributes(attributes)
|
|
2937
3130
|
|
|
2938
3131
|
def identifier(self) -> _protocols.OperatorIdentifier:
|
|
2939
3132
|
return self.domain, self.name, self.overload
|
|
@@ -2971,7 +3164,7 @@ class Function(_protocols.FunctionProtocol, Sequence[Node], _display.PrettyPrint
|
|
|
2971
3164
|
return self._graph.outputs
|
|
2972
3165
|
|
|
2973
3166
|
@property
|
|
2974
|
-
def attributes(self) ->
|
|
3167
|
+
def attributes(self) -> _graph_containers.Attributes:
|
|
2975
3168
|
return self._attributes
|
|
2976
3169
|
|
|
2977
3170
|
@typing.overload
|
|
@@ -3016,6 +3209,28 @@ class Function(_protocols.FunctionProtocol, Sequence[Node], _display.PrettyPrint
|
|
|
3016
3209
|
def metadata_props(self) -> dict[str, str]:
|
|
3017
3210
|
return self._graph.metadata_props
|
|
3018
3211
|
|
|
3212
|
+
def all_nodes(self) -> Iterator[Node]:
    """Get all nodes in the function and its subgraphs in O(#nodes + #attributes) time.

    This is an alias for ``onnx_ir.traversal.RecursiveGraphIterator(function)``.
    Consider using
    :class:`onnx_ir.traversal.RecursiveGraphIterator` for more advanced
    traversals on nodes.
    """
    # NOTE: This is a method specific to Function, not required by the protocol unless proven
    return onnx_ir.traversal.RecursiveGraphIterator(self)
|
|
3222
|
+
|
|
3223
|
+
def subgraphs(self) -> Iterator[Graph]:
    """Yield each distinct subgraph nested inside this function, in O(#nodes + #attributes) time.

    The function's own body graph (``self._graph``) is never yielded, and a
    subgraph is yielded only once, the first time one of its nodes is reached.
    """
    body = self._graph
    yielded: set[Graph] = set()
    for node in onnx_ir.traversal.RecursiveGraphIterator(self):
        owner = node.graph
        # Skip detached nodes, nodes of the function body, and subgraphs already reported.
        if owner is None or owner is body or owner in yielded:
            continue
        yielded.add(owner)
        yield owner
|
|
3233
|
+
|
|
3019
3234
|
# Mutation methods
|
|
3020
3235
|
def append(self, node: Node, /) -> None:
|
|
3021
3236
|
"""Append a node to the function in O(1) time."""
|