onnx-ir 0.0.1__py3-none-any.whl → 0.1.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of onnx-ir might be problematic. Click here for more details.

Files changed (45) hide show
  1. onnx_ir/__init__.py +23 -10
  2. onnx_ir/{_convenience.py → _convenience/__init__.py} +40 -102
  3. onnx_ir/_convenience/_constructors.py +213 -0
  4. onnx_ir/_core.py +857 -233
  5. onnx_ir/_display.py +2 -2
  6. onnx_ir/_enums.py +107 -5
  7. onnx_ir/_graph_comparison.py +2 -2
  8. onnx_ir/_graph_containers.py +268 -0
  9. onnx_ir/_io.py +57 -10
  10. onnx_ir/_linked_list.py +15 -7
  11. onnx_ir/_metadata.py +4 -3
  12. onnx_ir/_name_authority.py +2 -2
  13. onnx_ir/_polyfill.py +26 -0
  14. onnx_ir/_protocols.py +31 -13
  15. onnx_ir/_tape.py +139 -32
  16. onnx_ir/_thirdparty/asciichartpy.py +1 -4
  17. onnx_ir/_type_casting.py +18 -3
  18. onnx_ir/{_internal/version_utils.py → _version_utils.py} +2 -29
  19. onnx_ir/convenience.py +4 -2
  20. onnx_ir/external_data.py +401 -0
  21. onnx_ir/passes/__init__.py +8 -2
  22. onnx_ir/passes/_pass_infra.py +173 -56
  23. onnx_ir/passes/common/__init__.py +36 -0
  24. onnx_ir/passes/common/_c_api_utils.py +76 -0
  25. onnx_ir/passes/common/clear_metadata_and_docstring.py +60 -0
  26. onnx_ir/passes/common/constant_manipulation.py +232 -0
  27. onnx_ir/passes/common/inliner.py +331 -0
  28. onnx_ir/passes/common/onnx_checker.py +57 -0
  29. onnx_ir/passes/common/shape_inference.py +112 -0
  30. onnx_ir/passes/common/topological_sort.py +33 -0
  31. onnx_ir/passes/common/unused_removal.py +196 -0
  32. onnx_ir/serde.py +288 -124
  33. onnx_ir/tape.py +15 -0
  34. onnx_ir/tensor_adapters.py +122 -0
  35. onnx_ir/testing.py +197 -0
  36. onnx_ir/traversal.py +4 -3
  37. onnx_ir-0.1.0.dist-info/METADATA +53 -0
  38. onnx_ir-0.1.0.dist-info/RECORD +41 -0
  39. {onnx_ir-0.0.1.dist-info → onnx_ir-0.1.0.dist-info}/WHEEL +1 -1
  40. onnx_ir-0.1.0.dist-info/licenses/LICENSE +202 -0
  41. onnx_ir/_external_data.py +0 -323
  42. onnx_ir-0.0.1.dist-info/LICENSE +0 -22
  43. onnx_ir-0.0.1.dist-info/METADATA +0 -73
  44. onnx_ir-0.0.1.dist-info/RECORD +0 -26
  45. {onnx_ir-0.0.1.dist-info → onnx_ir-0.1.0.dist-info}/top_level.txt +0 -0
onnx_ir/_display.py CHANGED
@@ -1,5 +1,5 @@
1
- # Copyright (c) Microsoft Corporation.
2
- # Licensed under the MIT License.
1
+ # Copyright (c) ONNX Project Contributors
2
+ # SPDX-License-Identifier: Apache-2.0
3
3
  """Internal utilities for displaying the intermediate representation of a model.
4
4
 
5
5
  NOTE: All third-party imports should be scoped and imported only when used to avoid
onnx_ir/_enums.py CHANGED
@@ -1,5 +1,5 @@
1
- # Copyright (c) Microsoft Corporation.
2
- # Licensed under the MIT License.
1
+ # Copyright (c) ONNX Project Contributors
2
+ # SPDX-License-Identifier: Apache-2.0
3
3
  """ONNX IR enums that matches the ONNX spec."""
4
4
 
5
5
  from __future__ import annotations
@@ -64,6 +64,7 @@ class DataType(enum.IntEnum):
64
64
  FLOAT8E5M2FNUZ = 20
65
65
  UINT4 = 21
66
66
  INT4 = 22
67
+ FLOAT4E2M1 = 23
67
68
 
68
69
  @classmethod
69
70
  def from_numpy(cls, dtype: np.dtype) -> DataType:
@@ -72,9 +73,43 @@ class DataType(enum.IntEnum):
72
73
  Raises:
73
74
  TypeError: If the data type is not supported by ONNX.
74
75
  """
75
- if dtype not in _NP_TYPE_TO_DATA_TYPE:
76
- raise TypeError(f"Unsupported numpy data type: {dtype}")
77
- return cls(_NP_TYPE_TO_DATA_TYPE[dtype])
76
+ if dtype in _NP_TYPE_TO_DATA_TYPE:
77
+ return cls(_NP_TYPE_TO_DATA_TYPE[dtype])
78
+
79
+ if np.issubdtype(dtype, np.str_):
80
+ return DataType.STRING
81
+
82
+ # Special cases for handling custom dtypes defined in ONNX (as of onnx 1.18)
83
+ # Ref: https://github.com/onnx/onnx/blob/2d42b6a60a52e925e57c422593e88cc51890f58a/onnx/_custom_element_types.py
84
+ if hasattr(dtype, "names"):
85
+ if dtype.names == ("bfloat16",):
86
+ return DataType.BFLOAT16
87
+ if dtype.names == ("e4m3fn",):
88
+ return DataType.FLOAT8E4M3FN
89
+ if dtype.names == ("e4m3fnuz",):
90
+ return DataType.FLOAT8E4M3FNUZ
91
+ if dtype.names == ("e5m2",):
92
+ return DataType.FLOAT8E5M2
93
+ if dtype.names == ("e5m2fnuz",):
94
+ return DataType.FLOAT8E5M2FNUZ
95
+ if dtype.names == ("uint4",):
96
+ return DataType.UINT4
97
+ if dtype.names == ("int4",):
98
+ return DataType.INT4
99
+ if dtype.names == ("float4e2m1",):
100
+ return DataType.FLOAT4E2M1
101
+ raise TypeError(f"Unsupported numpy data type: {dtype}")
102
+
103
+ @classmethod
104
+ def from_short_name(cls, short_name: str) -> DataType:
105
+ """Returns the ONNX data type for the short name.
106
+
107
+ Raises:
108
+ TypeError: If the short name does not correspond to any known data type.
109
+ """
110
+ if short_name not in _SHORT_NAME_TO_DATA_TYPE:
111
+ raise TypeError(f"Unknown short name: {short_name}")
112
+ return cls(_SHORT_NAME_TO_DATA_TYPE[short_name])
78
113
 
79
114
  @property
80
115
  def itemsize(self) -> float:
@@ -91,6 +126,36 @@ class DataType(enum.IntEnum):
91
126
  raise TypeError(f"Numpy does not support ONNX data type: {self}")
92
127
  return _DATA_TYPE_TO_NP_TYPE[self]
93
128
 
129
+ def short_name(self) -> str:
130
+ """Returns the short name of the data type.
131
+
132
+ The short name is a string that is used to represent the data type in a more
133
+ compact form. For example, the short name for `DataType.FLOAT` is "f32".
134
+ To get the corresponding data type back, call ``from_short_name`` on a string.
135
+
136
+ Naming reference: https://github.com/pytorch/pytorch/blob/4bead7b85ea4160243c74109e0ce9bb80686d016/torch/utils/_dtype_abbrs.py
137
+
138
+ Raises:
139
+ TypeError: If the short name is not available for the data type.
140
+ """
141
+ if self not in _DATA_TYPE_TO_SHORT_NAME:
142
+ raise TypeError(f"Short name not available for ONNX data type: {self}")
143
+ return _DATA_TYPE_TO_SHORT_NAME[self]
144
+
145
+ def is_floating_point(self) -> bool:
146
+ """Returns True if the data type is a floating point type."""
147
+ return self in {
148
+ DataType.FLOAT,
149
+ DataType.FLOAT16,
150
+ DataType.DOUBLE,
151
+ DataType.BFLOAT16,
152
+ DataType.FLOAT8E4M3FN,
153
+ DataType.FLOAT8E4M3FNUZ,
154
+ DataType.FLOAT8E5M2,
155
+ DataType.FLOAT8E5M2FNUZ,
156
+ DataType.FLOAT4E2M1,
157
+ }
158
+
94
159
  def __repr__(self) -> str:
95
160
  return self.name
96
161
 
@@ -121,6 +186,7 @@ _ITEMSIZE_MAP = {
121
186
  DataType.FLOAT8E5M2FNUZ: 1,
122
187
  DataType.UINT4: 0.5,
123
188
  DataType.INT4: 0.5,
189
+ DataType.FLOAT4E2M1: 0.5,
124
190
  }
125
191
 
126
192
 
@@ -150,5 +216,41 @@ _NP_TYPE_TO_DATA_TYPE = {
150
216
  np.dtype(ml_dtypes.uint4): DataType.UINT4,
151
217
  }
152
218
 
219
+ # TODO(after min req for ml_dtypes>=0.5): Move this inside _NP_TYPE_TO_DATA_TYPE
220
+ _NP_TYPE_TO_DATA_TYPE.update(
221
+ {np.dtype(ml_dtypes.float4_e2m1fn): DataType.FLOAT4E2M1}
222
+ if hasattr(ml_dtypes, "float4_e2m1fn")
223
+ else {}
224
+ )
225
+
153
226
  # ONNX DataType to Numpy dtype.
154
227
  _DATA_TYPE_TO_NP_TYPE = {v: k for k, v in _NP_TYPE_TO_DATA_TYPE.items()}
228
+
229
+ _DATA_TYPE_TO_SHORT_NAME = {
230
+ DataType.UNDEFINED: "undefined",
231
+ DataType.BFLOAT16: "bf16",
232
+ DataType.DOUBLE: "f64",
233
+ DataType.FLOAT: "f32",
234
+ DataType.FLOAT16: "f16",
235
+ DataType.FLOAT8E4M3FN: "f8e4m3fn",
236
+ DataType.FLOAT8E5M2: "f8e5m2",
237
+ DataType.FLOAT8E4M3FNUZ: "f8e4m3fnuz",
238
+ DataType.FLOAT8E5M2FNUZ: "f8e5m2fnuz",
239
+ DataType.FLOAT4E2M1: "f4e2m1",
240
+ DataType.COMPLEX64: "c64",
241
+ DataType.COMPLEX128: "c128",
242
+ DataType.INT4: "i4",
243
+ DataType.INT8: "i8",
244
+ DataType.INT16: "i16",
245
+ DataType.INT32: "i32",
246
+ DataType.INT64: "i64",
247
+ DataType.BOOL: "b8",
248
+ DataType.UINT4: "u4",
249
+ DataType.UINT8: "u8",
250
+ DataType.UINT16: "u16",
251
+ DataType.UINT32: "u32",
252
+ DataType.UINT64: "u64",
253
+ DataType.STRING: "s",
254
+ }
255
+
256
+ _SHORT_NAME_TO_DATA_TYPE = {v: k for k, v in _DATA_TYPE_TO_SHORT_NAME.items()}
@@ -1,5 +1,5 @@
1
- # Copyright (c) Microsoft Corporation.
2
- # Licensed under the MIT License.
1
+ # Copyright (c) ONNX Project Contributors
2
+ # SPDX-License-Identifier: Apache-2.0
3
3
  """Utilities for comparing IR graphs."""
4
4
 
5
5
  from __future__ import annotations
@@ -0,0 +1,268 @@
1
+ # Copyright (c) ONNX Project Contributors
2
+ # SPDX-License-Identifier: Apache-2.0
3
+ """Tracked containers for graph."""
4
+
5
+ # pylint: disable=protected-access
6
+
7
+ from __future__ import annotations
8
+
9
+ __all__ = [
10
+ "GraphInputs",
11
+ "GraphOutputs",
12
+ ]
13
+
14
+ import collections
15
+ from collections.abc import Iterable
16
+ from typing import TYPE_CHECKING, SupportsIndex
17
+
18
+ import onnx_ir
19
+
20
+ if TYPE_CHECKING:
21
+ from onnx_ir import _core
22
+
23
+
24
+ class _GraphIO(collections.UserList["_core.Value"]):
25
+ """The inputs and outputs of a Graph."""
26
+
27
+ def __init__(self, graph: _core.Graph, initlist=None):
28
+ self._graph = graph
29
+ # Use a ref counter to track the number of references to each value
30
+ # in the input/output list. This is used to determine when to unset the graph
31
+ # reference in the value.
32
+ # Even though a duplicated value is invalid in inputs and not recommended in outputs,
33
+ # it is still possible to have duplicated inputs/outputs in an ONNX graph so we
34
+ # need to properly handle this case and maintain the graph reference properly.
35
+ self._ref_counter: collections.Counter[_core.Value] = collections.Counter()
36
+ if initlist is not None:
37
+ initlist = tuple(initlist) # Create a copy in case initlist is a generator
38
+ for value in initlist:
39
+ self._set_graph(value)
40
+ super().__init__(initlist)
41
+ self._check_invariance()
42
+
43
+ def _check_invariance(self) -> None:
44
+ """Check the invariance of the graph."""
45
+ raise NotImplementedError
46
+
47
+ def _set_graph(self, value: _core.Value) -> None:
48
+ """Set the graph for the value."""
49
+ raise NotImplementedError
50
+
51
+ def _maybe_unset_graph(self, value: _core.Value) -> None:
52
+ """Unset the graph for the value."""
53
+ raise NotImplementedError
54
+
55
+ def append(self, item: _core.Value) -> None:
56
+ """Add a new input to the graph."""
57
+ # Perform checks first in _set_graph before modifying the data structure
58
+ self._set_graph(item)
59
+ super().append(item)
60
+ self._check_invariance()
61
+
62
+ def extend(self, other) -> None:
63
+ """Extend the list of inputs or outputs."""
64
+ other = tuple(other)
65
+ for item in other:
66
+ self._set_graph(item)
67
+ super().extend(other)
68
+
69
+ def insert(self, i: int, item: _core.Value) -> None:
70
+ """Insert an input/output to the graph."""
71
+ super().insert(i, item)
72
+ self._set_graph(item)
73
+ self._check_invariance()
74
+
75
+ def pop(self, i: int = -1) -> _core.Value:
76
+ """Remove an input/output from the graph."""
77
+ value = super().pop(i)
78
+ self._maybe_unset_graph(value)
79
+ self._check_invariance()
80
+ return value
81
+
82
+ def remove(self, item: _core.Value) -> None:
83
+ """Remove an input/output from the graph."""
84
+ super().remove(item)
85
+ self._maybe_unset_graph(item)
86
+ self._check_invariance()
87
+
88
+ def clear(self) -> None:
89
+ """Clear the list."""
90
+ for value in self.data:
91
+ self._maybe_unset_graph(value)
92
+ super().clear()
93
+
94
+ def copy(self) -> list[_core.Value]:
95
+ """Return a shallow copy of the list."""
96
+ # This is a shallow copy, so the values are not copied, just the references
97
+ return self.data.copy()
98
+
99
+ def __setitem__(self, i, item) -> None:
100
+ """Replace an input/output to the node."""
101
+ if isinstance(item, Iterable) and isinstance(i, slice):
102
+ # Modify a slice of the list
103
+ for value in self.data[i]:
104
+ self._maybe_unset_graph(value)
105
+ for value in item:
106
+ self._set_graph(value)
107
+ super().__setitem__(i, item)
108
+ self._check_invariance()
109
+ return
110
+ elif isinstance(i, SupportsIndex):
111
+ # Replace a single item
112
+ self._maybe_unset_graph(self.data[i])
113
+ self._set_graph(item)
114
+ super().__setitem__(i, item)
115
+ self._check_invariance()
116
+ return
117
+
118
+ raise TypeError(f"Invalid types for __setitem__: {type(i)} and {type(item)}")
119
+
120
+ def __getitem__(self, i):
121
+ """Get an input/output from the graph."""
122
+ return self.data[i]
123
+
124
+ def _unimplemented(self, *_args, **_kwargs):
125
+ """Unimplemented method."""
126
+ raise RuntimeError("Method is not supported")
127
+
128
+ __add__ = _unimplemented
129
+ __radd__ = _unimplemented
130
+ __iadd__ = _unimplemented
131
+ __mul__ = _unimplemented
132
+ __rmul__ = _unimplemented
133
+
134
+
135
+ class GraphInputs(_GraphIO):
136
+ """The inputs of a Graph."""
137
+
138
+ def _check_invariance(self) -> None:
139
+ """Check the invariance of the graph."""
140
+ if not onnx_ir.DEBUG:
141
+ return
142
+ for value in self.data:
143
+ if value._graph is self._graph:
144
+ continue
145
+ raise ValueError(
146
+ f"Invariance error: Value '{value}' is not an input of the graph: {self._graph!r}"
147
+ )
148
+
149
+ def _set_graph(self, value: _core.Value) -> None:
150
+ """Set the graph for the value."""
151
+ if value._graph is not None and value._graph is not self._graph:
152
+ raise ValueError(
153
+ f"Value '{value}' is already owned by a different graph. Please remove the value from the previous graph first"
154
+ )
155
+ self._ref_counter[value] += 1
156
+ value._is_graph_input = True
157
+ value._graph = self._graph
158
+
159
+ def _maybe_unset_graph(self, value: _core.Value) -> None:
160
+ """Unset the graph for the value."""
161
+ assert value._graph is self._graph, "Bug: value does not belong to the graph"
162
+ self._ref_counter[value] -= 1
163
+ if self._ref_counter[value] > 0:
164
+ # The value is still used by another graph input
165
+ return
166
+ value._is_graph_input = False
167
+ if value._owned_by_graph():
168
+ # Keep the graph reference if the value is still owned by the graph in another role (presumably as an output or an initializer)
169
+ return
170
+ value._graph = None
171
+
172
+
173
+ class GraphOutputs(_GraphIO):
174
+ """The outputs of a Graph."""
175
+
176
+ def _check_invariance(self) -> None:
177
+ """Check the invariance of the graph."""
178
+ if not onnx_ir.DEBUG:
179
+ return
180
+ for value in self.data:
181
+ if value._graph is self._graph:
182
+ continue
183
+ raise ValueError(
184
+ f"Invariance error: Value '{value}' is not an output of the graph: {self._graph!r}"
185
+ )
186
+
187
+ def _set_graph(self, value: _core.Value) -> None:
188
+ """Set the graph for the value."""
189
+ if value._graph is not None and value._graph is not self._graph:
190
+ raise ValueError(
191
+ f"Value '{value}' is already an output of a different graph. Please remove the value from the previous graph first"
192
+ )
193
+ self._ref_counter[value] += 1
194
+ value._is_graph_output = True
195
+ value._graph = self._graph
196
+
197
+ def _maybe_unset_graph(self, value: _core.Value) -> None:
198
+ """Unset the graph for the value."""
199
+ assert value._graph is self._graph, "Bug: value does not belong to the graph"
200
+ self._ref_counter[value] -= 1
201
+ if self._ref_counter[value] > 0:
202
+ # The value is still used by another graph output
203
+ return
204
+ value._is_graph_output = False
205
+ if value._owned_by_graph():
206
+ # Keep the graph reference if the value is still an input or an initializer
207
+ return
208
+ value._graph = None
209
+
210
+
211
+ class GraphInitializers(collections.UserDict[str, "_core.Value"]):
212
+ """The initializers of a Graph."""
213
+
214
+ def __init__(self, graph: _core.Graph, dict=None, /, **kwargs):
215
+ # Perform checks first in _set_graph before modifying the data structure with super().__init__()
216
+ data = {}
217
+ if dict is not None:
218
+ data.update(dict)
219
+ if kwargs:
220
+ data.update(kwargs)
221
+ self._graph = graph
222
+ for value in data.values():
223
+ self._set_graph(value)
224
+
225
+ super().__init__(data)
226
+
227
+ def _set_graph(self, value: _core.Value) -> None:
228
+ """Set the graph for the value."""
229
+ if value._graph is not None and value._graph is not self._graph:
230
+ raise ValueError(
231
+ f"Value '{value}' is already an initializer of a different graph. Please remove the value from the previous graph first"
232
+ )
233
+ value._is_initializer = True
234
+ value._graph = self._graph
235
+
236
+ def _maybe_unset_graph(self, value: _core.Value) -> None:
237
+ """Unset the graph for the value."""
238
+ assert value._graph is self._graph, "Bug: value does not belong to the graph"
239
+ value._is_initializer = False
240
+ if value._owned_by_graph():
241
+ # Keep the graph reference if the value is still owned by the graph in another role (presumably as an input or an output)
242
+ return
243
+ value._graph = None
244
+
245
+ def __setitem__(self, key: str, value: _core.Value) -> None:
246
+ """Set an initializer for the graph."""
247
+ if key != value.name:
248
+ raise ValueError(
249
+ f"Key '{key}' does not match the name of the value '{value.name}'"
250
+ )
251
+ if not isinstance(key, str):
252
+ raise TypeError(f"Key must be a string, not {type(key)}")
253
+ if key in self.data:
254
+ # If the key already exists, unset the old value
255
+ old_value = self.data[key]
256
+ self._maybe_unset_graph(old_value)
257
+ # Must call _set_graph before super().__setitem__ so that when there is an error,
258
+ # the dictionary is not modified
259
+ self._set_graph(value)
260
+ super().__setitem__(key, value)
261
+
262
+ def __delitem__(self, key: str) -> None:
263
+ """Delete an initializer from the graph."""
264
+ value = self.data[key]
265
+ # Must call _maybe_unset_graph before super().__delitem__ so that when there is an error,
266
+ # the dictionary is not modified
267
+ self._maybe_unset_graph(value)
268
+ super().__delitem__(key)
onnx_ir/_io.py CHANGED
@@ -1,5 +1,5 @@
1
- # Copyright (c) Microsoft Corporation.
2
- # Licensed under the MIT License.
1
+ # Copyright (c) ONNX Project Contributors
2
+ # SPDX-License-Identifier: Apache-2.0
3
3
  """Load and save ONNX models."""
4
4
 
5
5
  from __future__ import annotations
@@ -10,7 +10,9 @@ import os
10
10
 
11
11
  import onnx
12
12
 
13
- from onnx_ir import _core, _external_data, serde
13
+ from onnx_ir import _core, serde
14
+ from onnx_ir import external_data as _external_data
15
+ from onnx_ir._polyfill import zip
14
16
 
15
17
 
16
18
  def load(path: str | os.PathLike, format: str | None = None) -> _core.Model:
@@ -35,16 +37,61 @@ def load(path: str | os.PathLike, format: str | None = None) -> _core.Model:
35
37
  return model
36
38
 
37
39
 
38
- def save(model: _core.Model, path: str | os.PathLike, format: str | None = None) -> None:
40
+ def save(
41
+ model: _core.Model,
42
+ path: str | os.PathLike,
43
+ format: str | None = None,
44
+ external_data: str | os.PathLike | None = None,
45
+ size_threshold_bytes: int = 256,
46
+ ) -> None:
39
47
  """Save an ONNX model to a file.
40
48
 
49
+ The model remains unchanged after the call. If any existing external tensor
50
+ references the provided ``external_data`` path, it will be invalidated
51
+ after the external data is overwritten. To obtain a valid model, use :func:`load`
52
+ to load the newly saved model, or provide a different external data path that
53
+ is not currently referenced by any tensors in the model.
54
+
41
55
  Args:
42
56
  model: The model to save.
43
- path: The path to save the model to.
44
- format: The format of the file (e.g. protobuf, textproto, json, etc.).
57
+ path: The path to save the model to. E.g. "model.onnx".
58
+ format: The format of the file (e.g. ``protobuf``, ``textproto``, ``json``, etc.).
45
59
  If None, the format is inferred from the file extension.
60
+ external_data: The relative path to save external data to. When specified,
61
+ initializers larger than ``size_threshold_bytes`` will be converted to external data and
62
+ saved to the specified directory. If None, all tensors will be saved unmodified.
63
+ That is, if a tensor in the model is already external, it will be saved
64
+ with the same external information; if the tensor is not external,
65
+ it will be serialized in the ONNX Proto message.
66
+ size_threshold_bytes: Save to external data if the tensor size in bytes is larger than this threshold.
67
+ Effective only when ``external_data`` is set.
68
+
69
+ Raises:
70
+ ValueError: If the external data path is an absolute path.
46
71
  """
47
- proto = serde.serialize_model(model)
48
- onnx.save(proto, path, format=format)
49
- # TODO(justinchuby): Handle external data when the relative path has changed
50
- # TODO(justinchuby): Handle off loading external data to disk when saving
72
+ if external_data is not None:
73
+ if os.path.isabs(external_data):
74
+ raise ValueError(
75
+ f"The external data path must be relative to the ONNX file path, not '{external_data}'."
76
+ )
77
+ base_dir = os.path.dirname(path)
78
+
79
+ # Store the original initializer values so they can be restored after saving, leaving the caller's model unchanged
80
+ initializer_values = tuple(model.graph.initializers.values())
81
+ tensors = [v.const_value for v in initializer_values]
82
+
83
+ try:
84
+ model = _external_data.unload_from_model(
85
+ model, base_dir, external_data, size_threshold_bytes=size_threshold_bytes
86
+ )
87
+ proto = serde.serialize_model(model)
88
+ onnx.save(proto, path, format=format)
89
+
90
+ finally:
91
+ # Restore the original initializer values so the model is unchanged
92
+ for initializer, tensor in zip(initializer_values, tensors, strict=True):
93
+ initializer.const_value = tensor
94
+
95
+ else:
96
+ proto = serde.serialize_model(model)
97
+ onnx.save(proto, path, format=format)
onnx_ir/_linked_list.py CHANGED
@@ -1,10 +1,11 @@
1
- # Copyright (c) Microsoft Corporation.
2
- # Licensed under the MIT License.
1
+ # Copyright (c) ONNX Project Contributors
2
+ # SPDX-License-Identifier: Apache-2.0
3
3
  """Mutable list for nodes in a graph with safe mutation properties."""
4
4
 
5
5
  from __future__ import annotations
6
6
 
7
- from typing import Generic, Iterable, Iterator, Sequence, TypeVar
7
+ from collections.abc import Iterable, Iterator, Sequence
8
+ from typing import Generic, TypeVar, overload
8
9
 
9
10
  T = TypeVar("T")
10
11
 
@@ -131,16 +132,23 @@ class DoublyLinkedSet(Sequence[T], Generic[T]):
131
132
  box = box.prev
132
133
 
133
134
  def __len__(self) -> int:
134
- assert self._length == len(
135
- self._value_ids_to_boxes
136
- ), "Bug in the implementation: length mismatch"
135
+ assert self._length == len(self._value_ids_to_boxes), (
136
+ "Bug in the implementation: length mismatch"
137
+ )
137
138
  return self._length
138
139
 
139
- def __getitem__(self, index: int) -> T:
140
+ @overload
141
+ def __getitem__(self, index: int) -> T: ...
142
+ @overload
143
+ def __getitem__(self, index: slice) -> Sequence[T]: ...
144
+
145
+ def __getitem__(self, index):
140
146
  """Get the node at the given index.
141
147
 
142
148
  Complexity is O(n).
143
149
  """
150
+ if isinstance(index, slice):
151
+ return tuple(self)[index]
144
152
  if index >= self._length or index < -self._length:
145
153
  raise IndexError(
146
154
  f"Index out of range: {index} not in range [-{self._length}, {self._length})"
onnx_ir/_metadata.py CHANGED
@@ -1,11 +1,12 @@
1
- # Copyright (c) Microsoft Corporation.
2
- # Licensed under the MIT License.
1
+ # Copyright (c) ONNX Project Contributors
2
+ # SPDX-License-Identifier: Apache-2.0
3
3
  """Class for storing metadata about the IR objects."""
4
4
 
5
5
  from __future__ import annotations
6
6
 
7
7
  import collections
8
- from typing import Any, Mapping
8
+ from collections.abc import Mapping
9
+ from typing import Any
9
10
 
10
11
 
11
12
  class MetadataStore(collections.UserDict):
@@ -1,5 +1,5 @@
1
- # Copyright (c) Microsoft Corporation.
2
- # Licensed under the MIT License.
1
+ # Copyright (c) ONNX Project Contributors
2
+ # SPDX-License-Identifier: Apache-2.0
3
3
  """Auxiliary class for managing names in the IR."""
4
4
 
5
5
  from __future__ import annotations
onnx_ir/_polyfill.py ADDED
@@ -0,0 +1,26 @@
1
+ # Copyright (c) ONNX Project Contributors
2
+ # SPDX-License-Identifier: Apache-2.0
3
+ """Polyfill for Python builtin functions."""
4
+
5
+ import sys
6
+ from collections.abc import Sequence
7
+ from typing import Any
8
+
9
+ if sys.version_info >= (3, 10):
10
+ zip = zip # pylint: disable=self-assigning-variable
11
+ else:
12
+ # zip(..., strict=True) was added in Python 3.10
13
+ # TODO: Remove this polyfill when we drop support for Python 3.9
14
+ _python_zip = zip
15
+
16
+ def zip(a: Sequence[Any], b: Sequence[Any], strict: bool = False):
17
+ """Polyfill for Python's zip function.
18
+
19
+ This is a special version which only supports two Sequence inputs.
20
+
21
+ Raises:
22
+ ValueError: If the iterables have different lengths and strict is True.
23
+ """
24
+ if len(a) != len(b) and strict:
25
+ raise ValueError("zip() argument lengths must be equal")
26
+ return _python_zip(a, b)