onnx-ir 0.0.1__py3-none-any.whl → 0.1.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of onnx-ir might be problematic. Click here for more details.
- onnx_ir/__init__.py +23 -10
- onnx_ir/{_convenience.py → _convenience/__init__.py} +40 -102
- onnx_ir/_convenience/_constructors.py +213 -0
- onnx_ir/_core.py +874 -257
- onnx_ir/_display.py +2 -2
- onnx_ir/_enums.py +107 -5
- onnx_ir/_graph_comparison.py +2 -2
- onnx_ir/_graph_containers.py +373 -0
- onnx_ir/_io.py +57 -10
- onnx_ir/_linked_list.py +15 -7
- onnx_ir/_metadata.py +4 -3
- onnx_ir/_name_authority.py +2 -2
- onnx_ir/_polyfill.py +26 -0
- onnx_ir/_protocols.py +31 -13
- onnx_ir/_tape.py +139 -32
- onnx_ir/_thirdparty/asciichartpy.py +1 -4
- onnx_ir/_type_casting.py +18 -3
- onnx_ir/{_internal/version_utils.py → _version_utils.py} +2 -29
- onnx_ir/convenience.py +4 -2
- onnx_ir/external_data.py +401 -0
- onnx_ir/passes/__init__.py +8 -2
- onnx_ir/passes/_pass_infra.py +173 -56
- onnx_ir/passes/common/__init__.py +40 -0
- onnx_ir/passes/common/_c_api_utils.py +76 -0
- onnx_ir/passes/common/clear_metadata_and_docstring.py +60 -0
- onnx_ir/passes/common/common_subexpression_elimination.py +177 -0
- onnx_ir/passes/common/constant_manipulation.py +217 -0
- onnx_ir/passes/common/inliner.py +332 -0
- onnx_ir/passes/common/onnx_checker.py +57 -0
- onnx_ir/passes/common/shape_inference.py +112 -0
- onnx_ir/passes/common/topological_sort.py +33 -0
- onnx_ir/passes/common/unused_removal.py +196 -0
- onnx_ir/serde.py +288 -124
- onnx_ir/tape.py +15 -0
- onnx_ir/tensor_adapters.py +122 -0
- onnx_ir/testing.py +197 -0
- onnx_ir/traversal.py +4 -3
- onnx_ir-0.1.1.dist-info/METADATA +53 -0
- onnx_ir-0.1.1.dist-info/RECORD +42 -0
- {onnx_ir-0.0.1.dist-info → onnx_ir-0.1.1.dist-info}/WHEEL +1 -1
- onnx_ir-0.1.1.dist-info/licenses/LICENSE +202 -0
- onnx_ir/_external_data.py +0 -323
- onnx_ir-0.0.1.dist-info/LICENSE +0 -22
- onnx_ir-0.0.1.dist-info/METADATA +0 -73
- onnx_ir-0.0.1.dist-info/RECORD +0 -26
- {onnx_ir-0.0.1.dist-info → onnx_ir-0.1.1.dist-info}/top_level.txt +0 -0
onnx_ir/_display.py
CHANGED
|
@@ -1,5 +1,5 @@
|
|
|
1
|
-
# Copyright (c)
|
|
2
|
-
#
|
|
1
|
+
# Copyright (c) ONNX Project Contributors
|
|
2
|
+
# SPDX-License-Identifier: Apache-2.0
|
|
3
3
|
"""Internal utilities for displaying the intermediate representation of a model.
|
|
4
4
|
|
|
5
5
|
NOTE: All third-party imports should be scoped and imported only when used to avoid
|
onnx_ir/_enums.py
CHANGED
|
@@ -1,5 +1,5 @@
|
|
|
1
|
-
# Copyright (c)
|
|
2
|
-
#
|
|
1
|
+
# Copyright (c) ONNX Project Contributors
|
|
2
|
+
# SPDX-License-Identifier: Apache-2.0
|
|
3
3
|
"""ONNX IR enums that match the ONNX spec."""
|
|
4
4
|
|
|
5
5
|
from __future__ import annotations
|
|
@@ -64,6 +64,7 @@ class DataType(enum.IntEnum):
|
|
|
64
64
|
FLOAT8E5M2FNUZ = 20
|
|
65
65
|
UINT4 = 21
|
|
66
66
|
INT4 = 22
|
|
67
|
+
FLOAT4E2M1 = 23
|
|
67
68
|
|
|
68
69
|
@classmethod
|
|
69
70
|
def from_numpy(cls, dtype: np.dtype) -> DataType:
|
|
@@ -72,9 +73,43 @@ class DataType(enum.IntEnum):
|
|
|
72
73
|
Raises:
|
|
73
74
|
TypeError: If the data type is not supported by ONNX.
|
|
74
75
|
"""
|
|
75
|
-
if dtype
|
|
76
|
-
|
|
77
|
-
|
|
76
|
+
if dtype in _NP_TYPE_TO_DATA_TYPE:
|
|
77
|
+
return cls(_NP_TYPE_TO_DATA_TYPE[dtype])
|
|
78
|
+
|
|
79
|
+
if np.issubdtype(dtype, np.str_):
|
|
80
|
+
return DataType.STRING
|
|
81
|
+
|
|
82
|
+
# Special cases for handling custom dtypes defined in ONNX (as of onnx 1.18)
|
|
83
|
+
# Ref: https://github.com/onnx/onnx/blob/2d42b6a60a52e925e57c422593e88cc51890f58a/onnx/_custom_element_types.py
|
|
84
|
+
if hasattr(dtype, "names"):
|
|
85
|
+
if dtype.names == ("bfloat16",):
|
|
86
|
+
return DataType.BFLOAT16
|
|
87
|
+
if dtype.names == ("e4m3fn",):
|
|
88
|
+
return DataType.FLOAT8E4M3FN
|
|
89
|
+
if dtype.names == ("e4m3fnuz",):
|
|
90
|
+
return DataType.FLOAT8E4M3FNUZ
|
|
91
|
+
if dtype.names == ("e5m2",):
|
|
92
|
+
return DataType.FLOAT8E5M2
|
|
93
|
+
if dtype.names == ("e5m2fnuz",):
|
|
94
|
+
return DataType.FLOAT8E5M2FNUZ
|
|
95
|
+
if dtype.names == ("uint4",):
|
|
96
|
+
return DataType.UINT4
|
|
97
|
+
if dtype.names == ("int4",):
|
|
98
|
+
return DataType.INT4
|
|
99
|
+
if dtype.names == ("float4e2m1",):
|
|
100
|
+
return DataType.FLOAT4E2M1
|
|
101
|
+
raise TypeError(f"Unsupported numpy data type: {dtype}")
|
|
102
|
+
|
|
103
|
+
@classmethod
|
|
104
|
+
def from_short_name(cls, short_name: str) -> DataType:
|
|
105
|
+
"""Returns the ONNX data type for the short name.
|
|
106
|
+
|
|
107
|
+
Raises:
|
|
108
|
+
TypeError: If the short name does not correspond to a known data type.
|
|
109
|
+
"""
|
|
110
|
+
if short_name not in _SHORT_NAME_TO_DATA_TYPE:
|
|
111
|
+
raise TypeError(f"Unknown short name: {short_name}")
|
|
112
|
+
return cls(_SHORT_NAME_TO_DATA_TYPE[short_name])
|
|
78
113
|
|
|
79
114
|
@property
|
|
80
115
|
def itemsize(self) -> float:
|
|
@@ -91,6 +126,36 @@ class DataType(enum.IntEnum):
|
|
|
91
126
|
raise TypeError(f"Numpy does not support ONNX data type: {self}")
|
|
92
127
|
return _DATA_TYPE_TO_NP_TYPE[self]
|
|
93
128
|
|
|
129
|
+
def short_name(self) -> str:
|
|
130
|
+
"""Returns the short name of the data type.
|
|
131
|
+
|
|
132
|
+
The short name is a string that is used to represent the data type in a more
|
|
133
|
+
compact form. For example, the short name for `DataType.FLOAT` is "f32".
|
|
134
|
+
To get the corresponding data type back, call ``from_short_name`` on a string.
|
|
135
|
+
|
|
136
|
+
Naming reference: https://github.com/pytorch/pytorch/blob/4bead7b85ea4160243c74109e0ce9bb80686d016/torch/utils/_dtype_abbrs.py
|
|
137
|
+
|
|
138
|
+
Raises:
|
|
139
|
+
TypeError: If the short name is not available for the data type.
|
|
140
|
+
"""
|
|
141
|
+
if self not in _DATA_TYPE_TO_SHORT_NAME:
|
|
142
|
+
raise TypeError(f"Short name not available for ONNX data type: {self}")
|
|
143
|
+
return _DATA_TYPE_TO_SHORT_NAME[self]
|
|
144
|
+
|
|
145
|
+
def is_floating_point(self) -> bool:
|
|
146
|
+
"""Returns True if the data type is a floating point type."""
|
|
147
|
+
return self in {
|
|
148
|
+
DataType.FLOAT,
|
|
149
|
+
DataType.FLOAT16,
|
|
150
|
+
DataType.DOUBLE,
|
|
151
|
+
DataType.BFLOAT16,
|
|
152
|
+
DataType.FLOAT8E4M3FN,
|
|
153
|
+
DataType.FLOAT8E4M3FNUZ,
|
|
154
|
+
DataType.FLOAT8E5M2,
|
|
155
|
+
DataType.FLOAT8E5M2FNUZ,
|
|
156
|
+
DataType.FLOAT4E2M1,
|
|
157
|
+
}
|
|
158
|
+
|
|
94
159
|
def __repr__(self) -> str:
|
|
95
160
|
return self.name
|
|
96
161
|
|
|
@@ -121,6 +186,7 @@ _ITEMSIZE_MAP = {
|
|
|
121
186
|
DataType.FLOAT8E5M2FNUZ: 1,
|
|
122
187
|
DataType.UINT4: 0.5,
|
|
123
188
|
DataType.INT4: 0.5,
|
|
189
|
+
DataType.FLOAT4E2M1: 0.5,
|
|
124
190
|
}
|
|
125
191
|
|
|
126
192
|
|
|
@@ -150,5 +216,41 @@ _NP_TYPE_TO_DATA_TYPE = {
|
|
|
150
216
|
np.dtype(ml_dtypes.uint4): DataType.UINT4,
|
|
151
217
|
}
|
|
152
218
|
|
|
219
|
+
# TODO(after min req for ml_dtypes>=0.5): Move this inside _NP_TYPE_TO_DATA_TYPE
|
|
220
|
+
_NP_TYPE_TO_DATA_TYPE.update(
|
|
221
|
+
{np.dtype(ml_dtypes.float4_e2m1fn): DataType.FLOAT4E2M1}
|
|
222
|
+
if hasattr(ml_dtypes, "float4_e2m1fn")
|
|
223
|
+
else {}
|
|
224
|
+
)
|
|
225
|
+
|
|
153
226
|
# ONNX DataType to Numpy dtype.
|
|
154
227
|
_DATA_TYPE_TO_NP_TYPE = {v: k for k, v in _NP_TYPE_TO_DATA_TYPE.items()}
|
|
228
|
+
|
|
229
|
+
_DATA_TYPE_TO_SHORT_NAME = {
|
|
230
|
+
DataType.UNDEFINED: "undefined",
|
|
231
|
+
DataType.BFLOAT16: "bf16",
|
|
232
|
+
DataType.DOUBLE: "f64",
|
|
233
|
+
DataType.FLOAT: "f32",
|
|
234
|
+
DataType.FLOAT16: "f16",
|
|
235
|
+
DataType.FLOAT8E4M3FN: "f8e4m3fn",
|
|
236
|
+
DataType.FLOAT8E5M2: "f8e5m2",
|
|
237
|
+
DataType.FLOAT8E4M3FNUZ: "f8e4m3fnuz",
|
|
238
|
+
DataType.FLOAT8E5M2FNUZ: "f8e5m2fnuz",
|
|
239
|
+
DataType.FLOAT4E2M1: "f4e2m1",
|
|
240
|
+
DataType.COMPLEX64: "c64",
|
|
241
|
+
DataType.COMPLEX128: "c128",
|
|
242
|
+
DataType.INT4: "i4",
|
|
243
|
+
DataType.INT8: "i8",
|
|
244
|
+
DataType.INT16: "i16",
|
|
245
|
+
DataType.INT32: "i32",
|
|
246
|
+
DataType.INT64: "i64",
|
|
247
|
+
DataType.BOOL: "b8",
|
|
248
|
+
DataType.UINT4: "u4",
|
|
249
|
+
DataType.UINT8: "u8",
|
|
250
|
+
DataType.UINT16: "u16",
|
|
251
|
+
DataType.UINT32: "u32",
|
|
252
|
+
DataType.UINT64: "u64",
|
|
253
|
+
DataType.STRING: "s",
|
|
254
|
+
}
|
|
255
|
+
|
|
256
|
+
_SHORT_NAME_TO_DATA_TYPE = {v: k for k, v in _DATA_TYPE_TO_SHORT_NAME.items()}
|
onnx_ir/_graph_comparison.py
CHANGED
onnx_ir/_graph_containers.py
CHANGED
|
@@ -0,0 +1,373 @@
|
|
|
1
|
+
# Copyright (c) ONNX Project Contributors
|
|
2
|
+
# SPDX-License-Identifier: Apache-2.0
|
|
3
|
+
"""Tracked containers for graph."""
|
|
4
|
+
|
|
5
|
+
# pylint: disable=protected-access
|
|
6
|
+
|
|
7
|
+
from __future__ import annotations
|
|
8
|
+
|
|
9
|
+
__all__ = [
|
|
10
|
+
"GraphInputs",
|
|
11
|
+
"GraphOutputs",
|
|
12
|
+
]
|
|
13
|
+
|
|
14
|
+
import collections
|
|
15
|
+
import logging
|
|
16
|
+
from collections.abc import Iterable, Sequence
|
|
17
|
+
from typing import SupportsIndex, TypeVar
|
|
18
|
+
|
|
19
|
+
import onnx_ir
|
|
20
|
+
from onnx_ir import _core, _protocols
|
|
21
|
+
|
|
22
|
+
T = TypeVar("T")
|
|
23
|
+
|
|
24
|
+
logger = logging.getLogger(__name__)
|
|
25
|
+
|
|
26
|
+
|
|
27
|
+
class _GraphIO(collections.UserList["_core.Value"]):
|
|
28
|
+
"""The inputs and outputs of a Graph."""
|
|
29
|
+
|
|
30
|
+
def __init__(self, graph: _core.Graph, initlist=None):
|
|
31
|
+
self._graph = graph
|
|
32
|
+
# Use a ref counter to track the number of references to each value
|
|
33
|
+
# in the input/output list. This is used to determine when to unset the graph
|
|
34
|
+
# reference in the value.
|
|
35
|
+
# Even though a duplicated value is invalid in inputs and not recommended in outputs,
|
|
36
|
+
# it is still possible to have duplicated inputs/outputs in an ONNX graph so we
|
|
37
|
+
# need to properly handle this case and maintain the graph reference properly.
|
|
38
|
+
self._ref_counter: collections.Counter[_core.Value] = collections.Counter()
|
|
39
|
+
if initlist is not None:
|
|
40
|
+
initlist = tuple(initlist) # Create a copy in case initlist is a generator
|
|
41
|
+
for value in initlist:
|
|
42
|
+
self._set_graph(value)
|
|
43
|
+
super().__init__(initlist)
|
|
44
|
+
self._check_invariance()
|
|
45
|
+
|
|
46
|
+
def _check_invariance(self) -> None:
|
|
47
|
+
"""Check the invariance of the graph."""
|
|
48
|
+
raise NotImplementedError
|
|
49
|
+
|
|
50
|
+
def _set_graph(self, value: _core.Value) -> None:
|
|
51
|
+
"""Set the graph for the value."""
|
|
52
|
+
raise NotImplementedError
|
|
53
|
+
|
|
54
|
+
def _maybe_unset_graph(self, value: _core.Value) -> None:
|
|
55
|
+
"""Unset the graph for the value."""
|
|
56
|
+
raise NotImplementedError
|
|
57
|
+
|
|
58
|
+
def append(self, item: _core.Value) -> None:
|
|
59
|
+
"""Add a new input/output to the graph."""
|
|
60
|
+
# Perform checks first in _set_graph before modifying the data structure
|
|
61
|
+
self._set_graph(item)
|
|
62
|
+
super().append(item)
|
|
63
|
+
self._check_invariance()
|
|
64
|
+
|
|
65
|
+
def extend(self, other) -> None:
|
|
66
|
+
"""Extend the list of inputs or outputs."""
|
|
67
|
+
other = tuple(other)
|
|
68
|
+
for item in other:
|
|
69
|
+
self._set_graph(item)
|
|
70
|
+
super().extend(other)
|
|
71
|
+
|
|
72
|
+
def insert(self, i: int, item: _core.Value) -> None:
|
|
73
|
+
"""Insert an input/output to the graph."""
|
|
74
|
+
super().insert(i, item)
|
|
75
|
+
self._set_graph(item)
|
|
76
|
+
self._check_invariance()
|
|
77
|
+
|
|
78
|
+
def pop(self, i: int = -1) -> _core.Value:
|
|
79
|
+
"""Remove an input/output from the graph."""
|
|
80
|
+
value = super().pop(i)
|
|
81
|
+
self._maybe_unset_graph(value)
|
|
82
|
+
self._check_invariance()
|
|
83
|
+
return value
|
|
84
|
+
|
|
85
|
+
def remove(self, item: _core.Value) -> None:
|
|
86
|
+
"""Remove an input/output from the graph."""
|
|
87
|
+
super().remove(item)
|
|
88
|
+
self._maybe_unset_graph(item)
|
|
89
|
+
self._check_invariance()
|
|
90
|
+
|
|
91
|
+
def clear(self) -> None:
|
|
92
|
+
"""Clear the list."""
|
|
93
|
+
for value in self.data:
|
|
94
|
+
self._maybe_unset_graph(value)
|
|
95
|
+
super().clear()
|
|
96
|
+
|
|
97
|
+
def copy(self) -> list[_core.Value]:
|
|
98
|
+
"""Return a shallow copy of the list."""
|
|
99
|
+
# This is a shallow copy, so the values are not copied, just the references
|
|
100
|
+
return self.data.copy()
|
|
101
|
+
|
|
102
|
+
def __setitem__(self, i, item) -> None:
|
|
103
|
+
"""Replace an input/output of the graph."""
|
|
104
|
+
if isinstance(item, Iterable) and isinstance(i, slice):
|
|
105
|
+
# Modify a slice of the list
|
|
106
|
+
for value in self.data[i]:
|
|
107
|
+
self._maybe_unset_graph(value)
|
|
108
|
+
for value in item:
|
|
109
|
+
self._set_graph(value)
|
|
110
|
+
super().__setitem__(i, item)
|
|
111
|
+
self._check_invariance()
|
|
112
|
+
return
|
|
113
|
+
elif isinstance(i, SupportsIndex):
|
|
114
|
+
# Replace a single item
|
|
115
|
+
self._maybe_unset_graph(self.data[i])
|
|
116
|
+
self._set_graph(item)
|
|
117
|
+
super().__setitem__(i, item)
|
|
118
|
+
self._check_invariance()
|
|
119
|
+
return
|
|
120
|
+
|
|
121
|
+
raise TypeError(f"Invalid types for __setitem__: {type(i)} and {type(item)}")
|
|
122
|
+
|
|
123
|
+
def __getitem__(self, i):
|
|
124
|
+
"""Get an input/output from the graph."""
|
|
125
|
+
return self.data[i]
|
|
126
|
+
|
|
127
|
+
def _unimplemented(self, *_args, **_kwargs):
|
|
128
|
+
"""Unimplemented method."""
|
|
129
|
+
raise RuntimeError("Method is not supported")
|
|
130
|
+
|
|
131
|
+
__add__ = _unimplemented
|
|
132
|
+
__radd__ = _unimplemented
|
|
133
|
+
__iadd__ = _unimplemented
|
|
134
|
+
__mul__ = _unimplemented
|
|
135
|
+
__rmul__ = _unimplemented
|
|
136
|
+
|
|
137
|
+
|
|
138
|
+
class GraphInputs(_GraphIO):
|
|
139
|
+
"""The inputs of a Graph."""
|
|
140
|
+
|
|
141
|
+
def _check_invariance(self) -> None:
|
|
142
|
+
"""Check the invariance of the graph."""
|
|
143
|
+
if not onnx_ir.DEBUG:
|
|
144
|
+
return
|
|
145
|
+
for value in self.data:
|
|
146
|
+
if value._graph is self._graph:
|
|
147
|
+
continue
|
|
148
|
+
raise ValueError(
|
|
149
|
+
f"Invariance error: Value '{value}' is not an input of the graph: {self._graph!r}"
|
|
150
|
+
)
|
|
151
|
+
|
|
152
|
+
def _set_graph(self, value: _core.Value) -> None:
|
|
153
|
+
"""Set the graph for the value."""
|
|
154
|
+
if value._graph is not None and value._graph is not self._graph:
|
|
155
|
+
raise ValueError(
|
|
156
|
+
f"Value '{value}' is already owned by a different graph. Please remove the value from the previous graph first"
|
|
157
|
+
)
|
|
158
|
+
if value.producer() is not None:
|
|
159
|
+
raise ValueError(
|
|
160
|
+
f"Value '{value}' is produced by a node and cannot be an input to the graph. Please create new Values for graph inputs"
|
|
161
|
+
)
|
|
162
|
+
self._ref_counter[value] += 1
|
|
163
|
+
value._is_graph_input = True
|
|
164
|
+
value._graph = self._graph
|
|
165
|
+
|
|
166
|
+
def _maybe_unset_graph(self, value: _core.Value) -> None:
|
|
167
|
+
"""Unset the graph for the value."""
|
|
168
|
+
assert value._graph is self._graph, "Bug: value does not belong to the graph"
|
|
169
|
+
self._ref_counter[value] -= 1
|
|
170
|
+
if self._ref_counter[value] > 0:
|
|
171
|
+
# The value is still used by another graph input
|
|
172
|
+
return
|
|
173
|
+
value._is_graph_input = False
|
|
174
|
+
if value._owned_by_graph():
|
|
175
|
+
# Keep the graph reference if the value is still owned by the graph in another role (e.g. output or initializer)
|
|
176
|
+
return
|
|
177
|
+
value._graph = None
|
|
178
|
+
|
|
179
|
+
|
|
180
|
+
class GraphOutputs(_GraphIO):
|
|
181
|
+
"""The outputs of a Graph."""
|
|
182
|
+
|
|
183
|
+
def _check_invariance(self) -> None:
|
|
184
|
+
"""Check the invariance of the graph."""
|
|
185
|
+
if not onnx_ir.DEBUG:
|
|
186
|
+
return
|
|
187
|
+
for value in self.data:
|
|
188
|
+
if value._graph is self._graph:
|
|
189
|
+
continue
|
|
190
|
+
raise ValueError(
|
|
191
|
+
f"Invariance error: Value '{value}' is not an output of the graph: {self._graph!r}"
|
|
192
|
+
)
|
|
193
|
+
|
|
194
|
+
def _set_graph(self, value: _core.Value) -> None:
|
|
195
|
+
"""Set the graph for the value."""
|
|
196
|
+
if value._graph is not None and value._graph is not self._graph:
|
|
197
|
+
raise ValueError(
|
|
198
|
+
f"Value '{value}' is already an output of a different graph. Please remove the value from the previous graph first"
|
|
199
|
+
)
|
|
200
|
+
self._ref_counter[value] += 1
|
|
201
|
+
value._is_graph_output = True
|
|
202
|
+
value._graph = self._graph
|
|
203
|
+
|
|
204
|
+
def _maybe_unset_graph(self, value: _core.Value) -> None:
|
|
205
|
+
"""Unset the graph for the value."""
|
|
206
|
+
assert value._graph is self._graph, "Bug: value does not belong to the graph"
|
|
207
|
+
self._ref_counter[value] -= 1
|
|
208
|
+
if self._ref_counter[value] > 0:
|
|
209
|
+
# The value is still used by another graph output
|
|
210
|
+
return
|
|
211
|
+
value._is_graph_output = False
|
|
212
|
+
if value._owned_by_graph():
|
|
213
|
+
# Keep the graph reference if the value is still an input or an initializer
|
|
214
|
+
return
|
|
215
|
+
value._graph = None
|
|
216
|
+
|
|
217
|
+
|
|
218
|
+
class GraphInitializers(collections.UserDict[str, "_core.Value"]):
|
|
219
|
+
"""The initializers of a Graph."""
|
|
220
|
+
|
|
221
|
+
def __init__(self, graph: _core.Graph, dict=None, /, **kwargs):
|
|
222
|
+
# Perform checks first in _set_graph before modifying the data structure with super().__init__()
|
|
223
|
+
data = {}
|
|
224
|
+
if dict is not None:
|
|
225
|
+
data.update(dict)
|
|
226
|
+
if kwargs:
|
|
227
|
+
data.update(kwargs)
|
|
228
|
+
self._graph = graph
|
|
229
|
+
for value in data.values():
|
|
230
|
+
self._set_graph(value)
|
|
231
|
+
|
|
232
|
+
super().__init__(data)
|
|
233
|
+
|
|
234
|
+
def _set_graph(self, value: _core.Value) -> None:
|
|
235
|
+
"""Set the graph for the value."""
|
|
236
|
+
if value._graph is not None and value._graph is not self._graph:
|
|
237
|
+
raise ValueError(
|
|
238
|
+
f"Value '{value}' is already an initializer of a different graph. Please remove the value from the previous graph first"
|
|
239
|
+
)
|
|
240
|
+
value._is_initializer = True
|
|
241
|
+
value._graph = self._graph
|
|
242
|
+
|
|
243
|
+
def _maybe_unset_graph(self, value: _core.Value) -> None:
|
|
244
|
+
"""Unset the graph for the value."""
|
|
245
|
+
assert value._graph is self._graph, "Bug: value does not belong to the graph"
|
|
246
|
+
value._is_initializer = False
|
|
247
|
+
if value._owned_by_graph():
|
|
248
|
+
# Keep the graph reference if the value is still owned by the graph in another role (e.g. input or output)
|
|
249
|
+
return
|
|
250
|
+
value._graph = None
|
|
251
|
+
|
|
252
|
+
def __setitem__(self, key: str, value: _core.Value) -> None:
|
|
253
|
+
"""Set an initializer for the graph."""
|
|
254
|
+
if not isinstance(value, _core.Value):
|
|
255
|
+
raise TypeError(f"value must be a Value object, not {type(value)}")
|
|
256
|
+
if not isinstance(key, str):
|
|
257
|
+
raise TypeError(f"Value name must be a string, not {type(key)}")
|
|
258
|
+
if key == "":
|
|
259
|
+
raise ValueError("Value name cannot be an empty string")
|
|
260
|
+
if not value.name:
|
|
261
|
+
logger.info("Value %s does not have a name, setting it to '%s'", value, key)
|
|
262
|
+
value.name = key
|
|
263
|
+
elif key != value.name:
|
|
264
|
+
raise ValueError(
|
|
265
|
+
f"Key '{key}' does not match the name of the value '{value.name}'. Please use the value.name as the key."
|
|
266
|
+
)
|
|
267
|
+
if value.producer() is not None:
|
|
268
|
+
raise ValueError(
|
|
269
|
+
f"Value '{value}' is produced by a node and cannot be a graph initializer"
|
|
270
|
+
)
|
|
271
|
+
if key in self.data:
|
|
272
|
+
# If the key already exists, unset the old value
|
|
273
|
+
old_value = self.data[key]
|
|
274
|
+
self._maybe_unset_graph(old_value)
|
|
275
|
+
# Must call _set_graph before super().__setitem__ so that when there is an error,
|
|
276
|
+
# the dictionary is not modified
|
|
277
|
+
self._set_graph(value)
|
|
278
|
+
super().__setitem__(key, value)
|
|
279
|
+
|
|
280
|
+
def __delitem__(self, key: str) -> None:
|
|
281
|
+
"""Delete an initializer from the graph."""
|
|
282
|
+
value = self.data[key]
|
|
283
|
+
# Must call _maybe_unset_graph before super().__delitem__ so that when there is an error,
|
|
284
|
+
# the dictionary is not modified
|
|
285
|
+
self._maybe_unset_graph(value)
|
|
286
|
+
super().__delitem__(key)
|
|
287
|
+
|
|
288
|
+
def add(self, value: _core.Value) -> None:
|
|
289
|
+
"""Add an initializer to the graph."""
|
|
290
|
+
self[value.name] = value # type: ignore[index]
|
|
291
|
+
|
|
292
|
+
|
|
293
|
+
class Attributes(collections.UserDict[str, "_core.Attr"]):
|
|
294
|
+
"""The attributes of a Node."""
|
|
295
|
+
|
|
296
|
+
def __init__(self, attrs: Iterable[_core.Attr]):
|
|
297
|
+
super().__init__({attr.name: attr for attr in attrs})
|
|
298
|
+
|
|
299
|
+
def __setitem__(self, key: str, value: _core.Attr) -> None:
|
|
300
|
+
"""Set an attribute for the node."""
|
|
301
|
+
if type(key) is not str:
|
|
302
|
+
raise TypeError(f"Key must be a string, not {type(key)}")
|
|
303
|
+
if not isinstance(value, _core.Attr):
|
|
304
|
+
raise TypeError(f"Value must be an Attr, not {type(value)}")
|
|
305
|
+
super().__setitem__(key, value)
|
|
306
|
+
|
|
307
|
+
def add(self, value: _core.Attr) -> None:
|
|
308
|
+
"""Add an attribute to the node."""
|
|
309
|
+
self[value.name] = value
|
|
310
|
+
|
|
311
|
+
def get_int(self, key: str, default: T = None) -> int | T: # type: ignore[assignment]
|
|
312
|
+
"""Get the integer value of the attribute."""
|
|
313
|
+
if key in self:
|
|
314
|
+
return self[key].as_int()
|
|
315
|
+
return default
|
|
316
|
+
|
|
317
|
+
def get_float(self, key: str, default: T = None) -> float | T: # type: ignore[assignment]
|
|
318
|
+
"""Get the float value of the attribute."""
|
|
319
|
+
if key in self:
|
|
320
|
+
return self[key].as_float()
|
|
321
|
+
return default
|
|
322
|
+
|
|
323
|
+
def get_string(self, key: str, default: T = None) -> str | T: # type: ignore[assignment]
|
|
324
|
+
"""Get the string value of the attribute."""
|
|
325
|
+
if key in self:
|
|
326
|
+
return self[key].as_string()
|
|
327
|
+
return default
|
|
328
|
+
|
|
329
|
+
def get_tensor(self, key: str, default: T = None) -> _protocols.TensorProtocol | T: # type: ignore[assignment]
|
|
330
|
+
"""Get the tensor value of the attribute."""
|
|
331
|
+
if key in self:
|
|
332
|
+
return self[key].as_tensor()
|
|
333
|
+
return default
|
|
334
|
+
|
|
335
|
+
def get_graph(self, key: str, default: T = None) -> _core.Graph | T: # type: ignore[assignment]
|
|
336
|
+
"""Get the graph value of the attribute."""
|
|
337
|
+
if key in self:
|
|
338
|
+
return self[key].as_graph()
|
|
339
|
+
return default
|
|
340
|
+
|
|
341
|
+
def get_ints(self, key: str, default: T = None) -> Sequence[int] | T: # type: ignore[assignment]
|
|
342
|
+
"""Get the Sequence of integers from the attribute."""
|
|
343
|
+
if key in self:
|
|
344
|
+
return self[key].as_ints()
|
|
345
|
+
return default
|
|
346
|
+
|
|
347
|
+
def get_floats(self, key: str, default: T = None) -> Sequence[float] | T: # type: ignore[assignment]
|
|
348
|
+
"""Get the Sequence of floats from the attribute."""
|
|
349
|
+
if key in self:
|
|
350
|
+
return self[key].as_floats()
|
|
351
|
+
return default
|
|
352
|
+
|
|
353
|
+
def get_strings(self, key: str, default: T = None) -> Sequence[str] | T: # type: ignore[assignment]
|
|
354
|
+
"""Get the Sequence of strings from the attribute."""
|
|
355
|
+
if key in self:
|
|
356
|
+
return self[key].as_strings()
|
|
357
|
+
return default
|
|
358
|
+
|
|
359
|
+
def get_tensors(
|
|
360
|
+
self,
|
|
361
|
+
key: str,
|
|
362
|
+
default: T = None, # type: ignore[assignment]
|
|
363
|
+
) -> Sequence[_protocols.TensorProtocol] | T:
|
|
364
|
+
"""Get the Sequence of tensors from the attribute."""
|
|
365
|
+
if key in self:
|
|
366
|
+
return self[key].as_tensors()
|
|
367
|
+
return default
|
|
368
|
+
|
|
369
|
+
def get_graphs(self, key: str, default: T = None) -> Sequence[_core.Graph] | T: # type: ignore[assignment]
|
|
370
|
+
"""Get the Sequence of graphs from the attribute."""
|
|
371
|
+
if key in self:
|
|
372
|
+
return self[key].as_graphs()
|
|
373
|
+
return default
|
onnx_ir/_io.py
CHANGED
|
@@ -1,5 +1,5 @@
|
|
|
1
|
-
# Copyright (c)
|
|
2
|
-
#
|
|
1
|
+
# Copyright (c) ONNX Project Contributors
|
|
2
|
+
# SPDX-License-Identifier: Apache-2.0
|
|
3
3
|
"""Load and save ONNX models."""
|
|
4
4
|
|
|
5
5
|
from __future__ import annotations
|
|
@@ -10,7 +10,9 @@ import os
|
|
|
10
10
|
|
|
11
11
|
import onnx
|
|
12
12
|
|
|
13
|
-
from onnx_ir import _core,
|
|
13
|
+
from onnx_ir import _core, serde
|
|
14
|
+
from onnx_ir import external_data as _external_data
|
|
15
|
+
from onnx_ir._polyfill import zip
|
|
14
16
|
|
|
15
17
|
|
|
16
18
|
def load(path: str | os.PathLike, format: str | None = None) -> _core.Model:
|
|
@@ -35,16 +37,61 @@ def load(path: str | os.PathLike, format: str | None = None) -> _core.Model:
|
|
|
35
37
|
return model
|
|
36
38
|
|
|
37
39
|
|
|
38
|
-
def save(
|
|
40
|
+
def save(
|
|
41
|
+
model: _core.Model,
|
|
42
|
+
path: str | os.PathLike,
|
|
43
|
+
format: str | None = None,
|
|
44
|
+
external_data: str | os.PathLike | None = None,
|
|
45
|
+
size_threshold_bytes: int = 256,
|
|
46
|
+
) -> None:
|
|
39
47
|
"""Save an ONNX model to a file.
|
|
40
48
|
|
|
49
|
+
The model remains unchanged after the call. If any existing external tensor
|
|
50
|
+
references the provided ``external_data`` path, it will be invalidated
|
|
51
|
+
after the external data is overwritten. To obtain a valid model, use :func:`load`
|
|
52
|
+
to load the newly saved model, or provide a different external data path that
|
|
53
|
+
is not currently referenced by any tensors in the model.
|
|
54
|
+
|
|
41
55
|
Args:
|
|
42
56
|
model: The model to save.
|
|
43
|
-
path: The path to save the model to.
|
|
44
|
-
format: The format of the file (e.g. protobuf
|
|
57
|
+
path: The path to save the model to. E.g. "model.onnx".
|
|
58
|
+
format: The format of the file (e.g. ``protobuf``, ``textproto``, ``json``, etc.).
|
|
45
59
|
If None, the format is inferred from the file extension.
|
|
60
|
+
external_data: The relative path to save external data to. When specified,
|
|
61
|
+
all initializers in the model will be converted to external data and
|
|
62
|
+
saved to the specified directory. If None, all tensors will be saved unmodified.
|
|
63
|
+
That is, if a tensor in the model is already external, it will be saved
|
|
64
|
+
with the same external information; if the tensor is not external,
|
|
65
|
+
it will be serialized in the ONNX Proto message.
|
|
66
|
+
size_threshold_bytes: Save to external data if the tensor size in bytes is larger than this threshold.
|
|
67
|
+
Effective only when ``external_data`` is set.
|
|
68
|
+
|
|
69
|
+
Raises:
|
|
70
|
+
ValueError: If the external data path is an absolute path.
|
|
46
71
|
"""
|
|
47
|
-
|
|
48
|
-
|
|
49
|
-
|
|
50
|
-
|
|
72
|
+
if external_data is not None:
|
|
73
|
+
if os.path.isabs(external_data):
|
|
74
|
+
raise ValueError(
|
|
75
|
+
f"The external data path must be relative to the ONNX file path, not '{external_data}'."
|
|
76
|
+
)
|
|
77
|
+
base_dir = os.path.dirname(path)
|
|
78
|
+
|
|
79
|
+
# Store the original initializer values so they can be restored after saving, leaving the model unchanged
|
|
80
|
+
initializer_values = tuple(model.graph.initializers.values())
|
|
81
|
+
tensors = [v.const_value for v in initializer_values]
|
|
82
|
+
|
|
83
|
+
try:
|
|
84
|
+
model = _external_data.unload_from_model(
|
|
85
|
+
model, base_dir, external_data, size_threshold_bytes=size_threshold_bytes
|
|
86
|
+
)
|
|
87
|
+
proto = serde.serialize_model(model)
|
|
88
|
+
onnx.save(proto, path, format=format)
|
|
89
|
+
|
|
90
|
+
finally:
|
|
91
|
+
# Restore the original initializer values so the model is unchanged
|
|
92
|
+
for initializer, tensor in zip(initializer_values, tensors, strict=True):
|
|
93
|
+
initializer.const_value = tensor
|
|
94
|
+
|
|
95
|
+
else:
|
|
96
|
+
proto = serde.serialize_model(model)
|
|
97
|
+
onnx.save(proto, path, format=format)
|
onnx_ir/_linked_list.py
CHANGED
|
@@ -1,10 +1,11 @@
|
|
|
1
|
-
# Copyright (c)
|
|
2
|
-
#
|
|
1
|
+
# Copyright (c) ONNX Project Contributors
|
|
2
|
+
# SPDX-License-Identifier: Apache-2.0
|
|
3
3
|
"""Mutable list for nodes in a graph with safe mutation properties."""
|
|
4
4
|
|
|
5
5
|
from __future__ import annotations
|
|
6
6
|
|
|
7
|
-
from
|
|
7
|
+
from collections.abc import Iterable, Iterator, Sequence
|
|
8
|
+
from typing import Generic, TypeVar, overload
|
|
8
9
|
|
|
9
10
|
T = TypeVar("T")
|
|
10
11
|
|
|
@@ -131,16 +132,23 @@ class DoublyLinkedSet(Sequence[T], Generic[T]):
|
|
|
131
132
|
box = box.prev
|
|
132
133
|
|
|
133
134
|
def __len__(self) -> int:
|
|
134
|
-
assert self._length == len(
|
|
135
|
-
|
|
136
|
-
)
|
|
135
|
+
assert self._length == len(self._value_ids_to_boxes), (
|
|
136
|
+
"Bug in the implementation: length mismatch"
|
|
137
|
+
)
|
|
137
138
|
return self._length
|
|
138
139
|
|
|
139
|
-
|
|
140
|
+
@overload
|
|
141
|
+
def __getitem__(self, index: int) -> T: ...
|
|
142
|
+
@overload
|
|
143
|
+
def __getitem__(self, index: slice) -> Sequence[T]: ...
|
|
144
|
+
|
|
145
|
+
def __getitem__(self, index):
|
|
140
146
|
"""Get the node at the given index.
|
|
141
147
|
|
|
142
148
|
Complexity is O(n).
|
|
143
149
|
"""
|
|
150
|
+
if isinstance(index, slice):
|
|
151
|
+
return tuple(self)[index]
|
|
144
152
|
if index >= self._length or index < -self._length:
|
|
145
153
|
raise IndexError(
|
|
146
154
|
f"Index out of range: {index} not in range [-{self._length}, {self._length})"
|
onnx_ir/_metadata.py
CHANGED
|
@@ -1,11 +1,12 @@
|
|
|
1
|
-
# Copyright (c)
|
|
2
|
-
#
|
|
1
|
+
# Copyright (c) ONNX Project Contributors
|
|
2
|
+
# SPDX-License-Identifier: Apache-2.0
|
|
3
3
|
"""Class for storing metadata about the IR objects."""
|
|
4
4
|
|
|
5
5
|
from __future__ import annotations
|
|
6
6
|
|
|
7
7
|
import collections
|
|
8
|
-
from
|
|
8
|
+
from collections.abc import Mapping
|
|
9
|
+
from typing import Any
|
|
9
10
|
|
|
10
11
|
|
|
11
12
|
class MetadataStore(collections.UserDict):
|