onnx-ir 0.0.1__py3-none-any.whl → 0.1.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of onnx-ir might be problematic; see the package's registry page for more details.

Files changed (46)
  1. onnx_ir/__init__.py +23 -10
  2. onnx_ir/{_convenience.py → _convenience/__init__.py} +40 -102
  3. onnx_ir/_convenience/_constructors.py +213 -0
  4. onnx_ir/_core.py +874 -257
  5. onnx_ir/_display.py +2 -2
  6. onnx_ir/_enums.py +107 -5
  7. onnx_ir/_graph_comparison.py +2 -2
  8. onnx_ir/_graph_containers.py +373 -0
  9. onnx_ir/_io.py +57 -10
  10. onnx_ir/_linked_list.py +15 -7
  11. onnx_ir/_metadata.py +4 -3
  12. onnx_ir/_name_authority.py +2 -2
  13. onnx_ir/_polyfill.py +26 -0
  14. onnx_ir/_protocols.py +31 -13
  15. onnx_ir/_tape.py +139 -32
  16. onnx_ir/_thirdparty/asciichartpy.py +1 -4
  17. onnx_ir/_type_casting.py +18 -3
  18. onnx_ir/{_internal/version_utils.py → _version_utils.py} +2 -29
  19. onnx_ir/convenience.py +4 -2
  20. onnx_ir/external_data.py +401 -0
  21. onnx_ir/passes/__init__.py +8 -2
  22. onnx_ir/passes/_pass_infra.py +173 -56
  23. onnx_ir/passes/common/__init__.py +40 -0
  24. onnx_ir/passes/common/_c_api_utils.py +76 -0
  25. onnx_ir/passes/common/clear_metadata_and_docstring.py +60 -0
  26. onnx_ir/passes/common/common_subexpression_elimination.py +177 -0
  27. onnx_ir/passes/common/constant_manipulation.py +217 -0
  28. onnx_ir/passes/common/inliner.py +332 -0
  29. onnx_ir/passes/common/onnx_checker.py +57 -0
  30. onnx_ir/passes/common/shape_inference.py +112 -0
  31. onnx_ir/passes/common/topological_sort.py +33 -0
  32. onnx_ir/passes/common/unused_removal.py +196 -0
  33. onnx_ir/serde.py +288 -124
  34. onnx_ir/tape.py +15 -0
  35. onnx_ir/tensor_adapters.py +122 -0
  36. onnx_ir/testing.py +197 -0
  37. onnx_ir/traversal.py +4 -3
  38. onnx_ir-0.1.1.dist-info/METADATA +53 -0
  39. onnx_ir-0.1.1.dist-info/RECORD +42 -0
  40. {onnx_ir-0.0.1.dist-info → onnx_ir-0.1.1.dist-info}/WHEEL +1 -1
  41. onnx_ir-0.1.1.dist-info/licenses/LICENSE +202 -0
  42. onnx_ir/_external_data.py +0 -323
  43. onnx_ir-0.0.1.dist-info/LICENSE +0 -22
  44. onnx_ir-0.0.1.dist-info/METADATA +0 -73
  45. onnx_ir-0.0.1.dist-info/RECORD +0 -26
  46. {onnx_ir-0.0.1.dist-info → onnx_ir-0.1.1.dist-info}/top_level.txt +0 -0
onnx_ir/_display.py CHANGED
@@ -1,5 +1,5 @@
1
- # Copyright (c) Microsoft Corporation.
2
- # Licensed under the MIT License.
1
+ # Copyright (c) ONNX Project Contributors
2
+ # SPDX-License-Identifier: Apache-2.0
3
3
  """Internal utilities for displaying the intermediate representation of a model.
4
4
 
5
5
  NOTE: All third-party imports should be scoped and imported only when used to avoid
onnx_ir/_enums.py CHANGED
@@ -1,5 +1,5 @@
1
- # Copyright (c) Microsoft Corporation.
2
- # Licensed under the MIT License.
1
+ # Copyright (c) ONNX Project Contributors
2
+ # SPDX-License-Identifier: Apache-2.0
3
3
  """ONNX IR enums that matches the ONNX spec."""
4
4
 
5
5
  from __future__ import annotations
@@ -64,6 +64,7 @@ class DataType(enum.IntEnum):
64
64
  FLOAT8E5M2FNUZ = 20
65
65
  UINT4 = 21
66
66
  INT4 = 22
67
+ FLOAT4E2M1 = 23
67
68
 
68
69
  @classmethod
69
70
  def from_numpy(cls, dtype: np.dtype) -> DataType:
@@ -72,9 +73,43 @@ class DataType(enum.IntEnum):
72
73
  Raises:
73
74
  TypeError: If the data type is not supported by ONNX.
74
75
  """
75
- if dtype not in _NP_TYPE_TO_DATA_TYPE:
76
- raise TypeError(f"Unsupported numpy data type: {dtype}")
77
- return cls(_NP_TYPE_TO_DATA_TYPE[dtype])
76
+ if dtype in _NP_TYPE_TO_DATA_TYPE:
77
+ return cls(_NP_TYPE_TO_DATA_TYPE[dtype])
78
+
79
+ if np.issubdtype(dtype, np.str_):
80
+ return DataType.STRING
81
+
82
+ # Special cases for handling custom dtypes defined in ONNX (as of onnx 1.18)
83
+ # Ref: https://github.com/onnx/onnx/blob/2d42b6a60a52e925e57c422593e88cc51890f58a/onnx/_custom_element_types.py
84
+ if hasattr(dtype, "names"):
85
+ if dtype.names == ("bfloat16",):
86
+ return DataType.BFLOAT16
87
+ if dtype.names == ("e4m3fn",):
88
+ return DataType.FLOAT8E4M3FN
89
+ if dtype.names == ("e4m3fnuz",):
90
+ return DataType.FLOAT8E4M3FNUZ
91
+ if dtype.names == ("e5m2",):
92
+ return DataType.FLOAT8E5M2
93
+ if dtype.names == ("e5m2fnuz",):
94
+ return DataType.FLOAT8E5M2FNUZ
95
+ if dtype.names == ("uint4",):
96
+ return DataType.UINT4
97
+ if dtype.names == ("int4",):
98
+ return DataType.INT4
99
+ if dtype.names == ("float4e2m1",):
100
+ return DataType.FLOAT4E2M1
101
+ raise TypeError(f"Unsupported numpy data type: {dtype}")
102
+
103
+ @classmethod
104
+ def from_short_name(cls, short_name: str) -> DataType:
105
+ """Returns the ONNX data type for the short name.
106
+
107
+ Raises:
108
+ TypeError: If the short name is not available for the data type.
109
+ """
110
+ if short_name not in _SHORT_NAME_TO_DATA_TYPE:
111
+ raise TypeError(f"Unknown short name: {short_name}")
112
+ return cls(_SHORT_NAME_TO_DATA_TYPE[short_name])
78
113
 
79
114
  @property
80
115
  def itemsize(self) -> float:
@@ -91,6 +126,36 @@ class DataType(enum.IntEnum):
91
126
  raise TypeError(f"Numpy does not support ONNX data type: {self}")
92
127
  return _DATA_TYPE_TO_NP_TYPE[self]
93
128
 
129
+ def short_name(self) -> str:
130
+ """Returns the short name of the data type.
131
+
132
+ The short name is a string that is used to represent the data type in a more
133
+ compact form. For example, the short name for `DataType.FLOAT` is "f32".
134
+ To get the corresponding data type back, call ``from_short_name`` on a string.
135
+
136
+ Naming reference: https://github.com/pytorch/pytorch/blob/4bead7b85ea4160243c74109e0ce9bb80686d016/torch/utils/_dtype_abbrs.py
137
+
138
+ Raises:
139
+ TypeError: If the short name is not available for the data type.
140
+ """
141
+ if self not in _DATA_TYPE_TO_SHORT_NAME:
142
+ raise TypeError(f"Short name not available for ONNX data type: {self}")
143
+ return _DATA_TYPE_TO_SHORT_NAME[self]
144
+
145
+ def is_floating_point(self) -> bool:
146
+ """Returns True if the data type is a floating point type."""
147
+ return self in {
148
+ DataType.FLOAT,
149
+ DataType.FLOAT16,
150
+ DataType.DOUBLE,
151
+ DataType.BFLOAT16,
152
+ DataType.FLOAT8E4M3FN,
153
+ DataType.FLOAT8E4M3FNUZ,
154
+ DataType.FLOAT8E5M2,
155
+ DataType.FLOAT8E5M2FNUZ,
156
+ DataType.FLOAT4E2M1,
157
+ }
158
+
94
159
  def __repr__(self) -> str:
95
160
  return self.name
96
161
 
@@ -121,6 +186,7 @@ _ITEMSIZE_MAP = {
121
186
  DataType.FLOAT8E5M2FNUZ: 1,
122
187
  DataType.UINT4: 0.5,
123
188
  DataType.INT4: 0.5,
189
+ DataType.FLOAT4E2M1: 0.5,
124
190
  }
125
191
 
126
192
 
@@ -150,5 +216,41 @@ _NP_TYPE_TO_DATA_TYPE = {
150
216
  np.dtype(ml_dtypes.uint4): DataType.UINT4,
151
217
  }
152
218
 
219
+ # TODO(after min req for ml_dtypes>=0.5): Move this inside _NP_TYPE_TO_DATA_TYPE
220
+ _NP_TYPE_TO_DATA_TYPE.update(
221
+ {np.dtype(ml_dtypes.float4_e2m1fn): DataType.FLOAT4E2M1}
222
+ if hasattr(ml_dtypes, "float4_e2m1fn")
223
+ else {}
224
+ )
225
+
153
226
  # ONNX DataType to Numpy dtype.
154
227
  _DATA_TYPE_TO_NP_TYPE = {v: k for k, v in _NP_TYPE_TO_DATA_TYPE.items()}
228
+
229
+ _DATA_TYPE_TO_SHORT_NAME = {
230
+ DataType.UNDEFINED: "undefined",
231
+ DataType.BFLOAT16: "bf16",
232
+ DataType.DOUBLE: "f64",
233
+ DataType.FLOAT: "f32",
234
+ DataType.FLOAT16: "f16",
235
+ DataType.FLOAT8E4M3FN: "f8e4m3fn",
236
+ DataType.FLOAT8E5M2: "f8e5m2",
237
+ DataType.FLOAT8E4M3FNUZ: "f8e4m3fnuz",
238
+ DataType.FLOAT8E5M2FNUZ: "f8e5m2fnuz",
239
+ DataType.FLOAT4E2M1: "f4e2m1",
240
+ DataType.COMPLEX64: "c64",
241
+ DataType.COMPLEX128: "c128",
242
+ DataType.INT4: "i4",
243
+ DataType.INT8: "i8",
244
+ DataType.INT16: "i16",
245
+ DataType.INT32: "i32",
246
+ DataType.INT64: "i64",
247
+ DataType.BOOL: "b8",
248
+ DataType.UINT4: "u4",
249
+ DataType.UINT8: "u8",
250
+ DataType.UINT16: "u16",
251
+ DataType.UINT32: "u32",
252
+ DataType.UINT64: "u64",
253
+ DataType.STRING: "s",
254
+ }
255
+
256
+ _SHORT_NAME_TO_DATA_TYPE = {v: k for k, v in _DATA_TYPE_TO_SHORT_NAME.items()}
onnx_ir/_graph_comparison.py CHANGED
@@ -1,5 +1,5 @@
1
- # Copyright (c) Microsoft Corporation.
2
- # Licensed under the MIT License.
1
+ # Copyright (c) ONNX Project Contributors
2
+ # SPDX-License-Identifier: Apache-2.0
3
3
  """Utilities for comparing IR graphs."""
4
4
 
5
5
  from __future__ import annotations
onnx_ir/_graph_containers.py ADDED
@@ -0,0 +1,373 @@
1
+ # Copyright (c) ONNX Project Contributors
2
+ # SPDX-License-Identifier: Apache-2.0
3
+ """Tracked containers for graph."""
4
+
5
+ # pylint: disable=protected-access
6
+
7
+ from __future__ import annotations
8
+
9
+ __all__ = [
10
+ "GraphInputs",
11
+ "GraphOutputs",
12
+ ]
13
+
14
+ import collections
15
+ import logging
16
+ from collections.abc import Iterable, Sequence
17
+ from typing import SupportsIndex, TypeVar
18
+
19
+ import onnx_ir
20
+ from onnx_ir import _core, _protocols
21
+
22
+ T = TypeVar("T")
23
+
24
+ logger = logging.getLogger(__name__)
25
+
26
+
27
+ class _GraphIO(collections.UserList["_core.Value"]):
28
+ """The inputs and outputs of a Graph."""
29
+
30
+ def __init__(self, graph: _core.Graph, initlist=None):
31
+ self._graph = graph
32
+ # Use a ref counter to track the number of references to each value
33
+ # in the input/output list. This is used to determine when to unset the graph
34
+ # reference in the value.
35
+ # Even though a duplicated value is invalid in inputs and not recommended in outputs,
36
+ # it is still possible to have duplicated inputs/outputs in an ONNX graph so we
37
+ # need to properly handle this case and maintain the graph reference properly.
38
+ self._ref_counter: collections.Counter[_core.Value] = collections.Counter()
39
+ if initlist is not None:
40
+ initlist = tuple(initlist) # Create a copy in case initlist is a generator
41
+ for value in initlist:
42
+ self._set_graph(value)
43
+ super().__init__(initlist)
44
+ self._check_invariance()
45
+
46
+ def _check_invariance(self) -> None:
47
+ """Check the invariance of the graph."""
48
+ raise NotImplementedError
49
+
50
+ def _set_graph(self, value: _core.Value) -> None:
51
+ """Set the graph for the value."""
52
+ raise NotImplementedError
53
+
54
+ def _maybe_unset_graph(self, value: _core.Value) -> None:
55
+ """Unset the graph for the value."""
56
+ raise NotImplementedError
57
+
58
+ def append(self, item: _core.Value) -> None:
59
+ """Add a new input to the graph."""
60
+ # Perform checks first in _set_graph before modifying the data structure
61
+ self._set_graph(item)
62
+ super().append(item)
63
+ self._check_invariance()
64
+
65
+ def extend(self, other) -> None:
66
+ """Extend the list of inputs or outputs."""
67
+ other = tuple(other)
68
+ for item in other:
69
+ self._set_graph(item)
70
+ super().extend(other)
71
+
72
+ def insert(self, i: int, item: _core.Value) -> None:
73
+ """Insert an input/output to the graph."""
74
+ super().insert(i, item)
75
+ self._set_graph(item)
76
+ self._check_invariance()
77
+
78
+ def pop(self, i: int = -1) -> _core.Value:
79
+ """Remove an input/output from the graph."""
80
+ value = super().pop(i)
81
+ self._maybe_unset_graph(value)
82
+ self._check_invariance()
83
+ return value
84
+
85
+ def remove(self, item: _core.Value) -> None:
86
+ """Remove an input/output from the graph."""
87
+ super().remove(item)
88
+ self._maybe_unset_graph(item)
89
+ self._check_invariance()
90
+
91
+ def clear(self) -> None:
92
+ """Clear the list."""
93
+ for value in self.data:
94
+ self._maybe_unset_graph(value)
95
+ super().clear()
96
+
97
+ def copy(self) -> list[_core.Value]:
98
+ """Return a shallow copy of the list."""
99
+ # This is a shallow copy, so the values are not copied, just the references
100
+ return self.data.copy()
101
+
102
+ def __setitem__(self, i, item) -> None:
103
+ """Replace an input/output to the node."""
104
+ if isinstance(item, Iterable) and isinstance(i, slice):
105
+ # Modify a slice of the list
106
+ for value in self.data[i]:
107
+ self._maybe_unset_graph(value)
108
+ for value in item:
109
+ self._set_graph(value)
110
+ super().__setitem__(i, item)
111
+ self._check_invariance()
112
+ return
113
+ elif isinstance(i, SupportsIndex):
114
+ # Replace a single item
115
+ self._maybe_unset_graph(self.data[i])
116
+ self._set_graph(item)
117
+ super().__setitem__(i, item)
118
+ self._check_invariance()
119
+ return
120
+
121
+ raise TypeError(f"Invalid types for __setitem__: {type(i)} and {type(item)}")
122
+
123
+ def __getitem__(self, i):
124
+ """Get an input/output from the graph."""
125
+ return self.data[i]
126
+
127
+ def _unimplemented(self, *_args, **_kwargs):
128
+ """Unimplemented method."""
129
+ raise RuntimeError("Method is not supported")
130
+
131
+ __add__ = _unimplemented
132
+ __radd__ = _unimplemented
133
+ __iadd__ = _unimplemented
134
+ __mul__ = _unimplemented
135
+ __rmul__ = _unimplemented
136
+
137
+
138
+ class GraphInputs(_GraphIO):
139
+ """The inputs of a Graph."""
140
+
141
+ def _check_invariance(self) -> None:
142
+ """Check the invariance of the graph."""
143
+ if not onnx_ir.DEBUG:
144
+ return
145
+ for value in self.data:
146
+ if value._graph is self._graph:
147
+ continue
148
+ raise ValueError(
149
+ f"Invariance error: Value '{value}' is not an input of the graph: {self._graph!r}"
150
+ )
151
+
152
+ def _set_graph(self, value: _core.Value) -> None:
153
+ """Set the graph for the value."""
154
+ if value._graph is not None and value._graph is not self._graph:
155
+ raise ValueError(
156
+ f"Value '{value}' is already owned by a different graph. Please remove the value from the previous graph first"
157
+ )
158
+ if value.producer() is not None:
159
+ raise ValueError(
160
+ f"Value '{value}' is produced by a node and cannot be an input to the graph. Please create new Values for graph inputs"
161
+ )
162
+ self._ref_counter[value] += 1
163
+ value._is_graph_input = True
164
+ value._graph = self._graph
165
+
166
+ def _maybe_unset_graph(self, value: _core.Value) -> None:
167
+ """Unset the graph for the value."""
168
+ assert value._graph is self._graph, "Bug: value does not belong to the graph"
169
+ self._ref_counter[value] -= 1
170
+ if self._ref_counter[value] > 0:
171
+ # The value is still used by another graph input
172
+ return
173
+ value._is_graph_input = False
174
+ if value._owned_by_graph():
175
+ # Keep the graph reference if the value is still an input or an initializer
176
+ return
177
+ value._graph = None
178
+
179
+
180
+ class GraphOutputs(_GraphIO):
181
+ """The outputs of a Graph."""
182
+
183
+ def _check_invariance(self) -> None:
184
+ """Check the invariance of the graph."""
185
+ if not onnx_ir.DEBUG:
186
+ return
187
+ for value in self.data:
188
+ if value._graph is self._graph:
189
+ continue
190
+ raise ValueError(
191
+ f"Invariance error: Value '{value}' is not an output of the graph: {self._graph!r}"
192
+ )
193
+
194
+ def _set_graph(self, value: _core.Value) -> None:
195
+ """Set the graph for the value."""
196
+ if value._graph is not None and value._graph is not self._graph:
197
+ raise ValueError(
198
+ f"Value '{value}' is already an output of a different graph. Please remove the value from the previous graph first"
199
+ )
200
+ self._ref_counter[value] += 1
201
+ value._is_graph_output = True
202
+ value._graph = self._graph
203
+
204
+ def _maybe_unset_graph(self, value: _core.Value) -> None:
205
+ """Unset the graph for the value."""
206
+ assert value._graph is self._graph, "Bug: value does not belong to the graph"
207
+ self._ref_counter[value] -= 1
208
+ if self._ref_counter[value] > 0:
209
+ # The value is still used by another graph input
210
+ return
211
+ value._is_graph_output = False
212
+ if value._owned_by_graph():
213
+ # Keep the graph reference if the value is still an input or an initializer
214
+ return
215
+ value._graph = None
216
+
217
+
218
+ class GraphInitializers(collections.UserDict[str, "_core.Value"]):
219
+ """The initializers of a Graph."""
220
+
221
+ def __init__(self, graph: _core.Graph, dict=None, /, **kwargs):
222
+ # Perform checks first in _set_graph before modifying the data structure with super().__init__()
223
+ data = {}
224
+ if dict is not None:
225
+ data.update(dict)
226
+ if kwargs:
227
+ data.update(kwargs)
228
+ self._graph = graph
229
+ for value in data.values():
230
+ self._set_graph(value)
231
+
232
+ super().__init__(data)
233
+
234
+ def _set_graph(self, value: _core.Value) -> None:
235
+ """Set the graph for the value."""
236
+ if value._graph is not None and value._graph is not self._graph:
237
+ raise ValueError(
238
+ f"Value '{value}' is already an initializer of a different graph. Please remove the value from the previous graph first"
239
+ )
240
+ value._is_initializer = True
241
+ value._graph = self._graph
242
+
243
+ def _maybe_unset_graph(self, value: _core.Value) -> None:
244
+ """Unset the graph for the value."""
245
+ assert value._graph is self._graph, "Bug: value does not belong to the graph"
246
+ value._is_initializer = False
247
+ if value._owned_by_graph():
248
+ # Keep the graph reference if the value is still an input or an initializer
249
+ return
250
+ value._graph = None
251
+
252
+ def __setitem__(self, key: str, value: _core.Value) -> None:
253
+ """Set an initializer for the graph."""
254
+ if not isinstance(value, _core.Value):
255
+ raise TypeError(f"value must be a Value object, not {type(value)}")
256
+ if not isinstance(key, str):
257
+ raise TypeError(f"Value name must be a string, not {type(key)}")
258
+ if key == "":
259
+ raise ValueError("Value name cannot be an empty string")
260
+ if not value.name:
261
+ logger.info("Value %s does not have a name, setting it to '%s'", value, key)
262
+ value.name = key
263
+ elif key != value.name:
264
+ raise ValueError(
265
+ f"Key '{key}' does not match the name of the value '{value.name}'. Please use the value.name as the key."
266
+ )
267
+ if value.producer() is not None:
268
+ raise ValueError(
269
+ f"Value '{value}' is produced by a node and cannot be a graph initializer"
270
+ )
271
+ if key in self.data:
272
+ # If the key already exists, unset the old value
273
+ old_value = self.data[key]
274
+ self._maybe_unset_graph(old_value)
275
+ # Must call _set_graph before super().__setitem__ so that when there is an error,
276
+ # the dictionary is not modified
277
+ self._set_graph(value)
278
+ super().__setitem__(key, value)
279
+
280
+ def __delitem__(self, key: str) -> None:
281
+ """Delete an initializer from the graph."""
282
+ value = self.data[key]
283
+ # Must call _maybe_unset_graph before super().__delitem__ so that when there is an error,
284
+ # the dictionary is not modified
285
+ self._maybe_unset_graph(value)
286
+ super().__delitem__(key)
287
+
288
+ def add(self, value: _core.Value) -> None:
289
+ """Add an initializer to the graph."""
290
+ self[value.name] = value # type: ignore[index]
291
+
292
+
293
+ class Attributes(collections.UserDict[str, "_core.Attr"]):
294
+ """The attributes of a Node."""
295
+
296
+ def __init__(self, attrs: Iterable[_core.Attr]):
297
+ super().__init__({attr.name: attr for attr in attrs})
298
+
299
+ def __setitem__(self, key: str, value: _core.Attr) -> None:
300
+ """Set an attribute for the node."""
301
+ if type(key) is not str:
302
+ raise TypeError(f"Key must be a string, not {type(key)}")
303
+ if not isinstance(value, _core.Attr):
304
+ raise TypeError(f"Value must be an Attr, not {type(value)}")
305
+ super().__setitem__(key, value)
306
+
307
+ def add(self, value: _core.Attr) -> None:
308
+ """Add an attribute to the node."""
309
+ self[value.name] = value
310
+
311
+ def get_int(self, key: str, default: T = None) -> int | T: # type: ignore[assignment]
312
+ """Get the integer value of the attribute."""
313
+ if key in self:
314
+ return self[key].as_int()
315
+ return default
316
+
317
+ def get_float(self, key: str, default: T = None) -> float | T: # type: ignore[assignment]
318
+ """Get the float value of the attribute."""
319
+ if key in self:
320
+ return self[key].as_float()
321
+ return default
322
+
323
+ def get_string(self, key: str, default: T = None) -> str | T: # type: ignore[assignment]
324
+ """Get the string value of the attribute."""
325
+ if key in self:
326
+ return self[key].as_string()
327
+ return default
328
+
329
+ def get_tensor(self, key: str, default: T = None) -> _protocols.TensorProtocol | T: # type: ignore[assignment]
330
+ """Get the tensor value of the attribute."""
331
+ if key in self:
332
+ return self[key].as_tensor()
333
+ return default
334
+
335
+ def get_graph(self, key: str, default: T = None) -> _core.Graph | T: # type: ignore[assignment]
336
+ """Get the graph value of the attribute."""
337
+ if key in self:
338
+ return self[key].as_graph()
339
+ return default
340
+
341
+ def get_ints(self, key: str, default: T = None) -> Sequence[int] | T: # type: ignore[assignment]
342
+ """Get the Sequence of integers from the attribute."""
343
+ if key in self:
344
+ return self[key].as_ints()
345
+ return default
346
+
347
+ def get_floats(self, key: str, default: T = None) -> Sequence[float] | T: # type: ignore[assignment]
348
+ """Get the Sequence of floats from the attribute."""
349
+ if key in self:
350
+ return self[key].as_floats()
351
+ return default
352
+
353
+ def get_strings(self, key: str, default: T = None) -> Sequence[str] | T: # type: ignore[assignment]
354
+ """Get the Sequence of strings from the attribute."""
355
+ if key in self:
356
+ return self[key].as_strings()
357
+ return default
358
+
359
+ def get_tensors(
360
+ self,
361
+ key: str,
362
+ default: T = None, # type: ignore[assignment]
363
+ ) -> Sequence[_protocols.TensorProtocol] | T:
364
+ """Get the Sequence of tensors from the attribute."""
365
+ if key in self:
366
+ return self[key].as_tensors()
367
+ return default
368
+
369
+ def get_graphs(self, key: str, default: T = None) -> Sequence[_core.Graph] | T: # type: ignore[assignment]
370
+ """Get the Sequence of graphs from the attribute."""
371
+ if key in self:
372
+ return self[key].as_graphs()
373
+ return default
onnx_ir/_io.py CHANGED
@@ -1,5 +1,5 @@
1
- # Copyright (c) Microsoft Corporation.
2
- # Licensed under the MIT License.
1
+ # Copyright (c) ONNX Project Contributors
2
+ # SPDX-License-Identifier: Apache-2.0
3
3
  """Load and save ONNX models."""
4
4
 
5
5
  from __future__ import annotations
@@ -10,7 +10,9 @@ import os
10
10
 
11
11
  import onnx
12
12
 
13
- from onnx_ir import _core, _external_data, serde
13
+ from onnx_ir import _core, serde
14
+ from onnx_ir import external_data as _external_data
15
+ from onnx_ir._polyfill import zip
14
16
 
15
17
 
16
18
  def load(path: str | os.PathLike, format: str | None = None) -> _core.Model:
@@ -35,16 +37,61 @@ def load(path: str | os.PathLike, format: str | None = None) -> _core.Model:
35
37
  return model
36
38
 
37
39
 
38
- def save(model: _core.Model, path: str | os.PathLike, format: str | None = None) -> None:
40
+ def save(
41
+ model: _core.Model,
42
+ path: str | os.PathLike,
43
+ format: str | None = None,
44
+ external_data: str | os.PathLike | None = None,
45
+ size_threshold_bytes: int = 256,
46
+ ) -> None:
39
47
  """Save an ONNX model to a file.
40
48
 
49
+ The model remains unchanged after the call. If any existing external tensor
50
+ references the provided ``external_data`` path, it will be invalidated
51
+ after the external data is overwritten. To obtain a valid model, use :func:`load`
52
+ to load the newly saved model, or provide a different external data path that
53
+ is not currently referenced by any tensors in the model.
54
+
41
55
  Args:
42
56
  model: The model to save.
43
- path: The path to save the model to.
44
- format: The format of the file (e.g. protobuf, textproto, json, etc.).
57
+ path: The path to save the model to. E.g. "model.onnx".
58
+ format: The format of the file (e.g. ``protobuf``, ``textproto``, ``json``, etc.).
45
59
  If None, the format is inferred from the file extension.
60
+ external_data: The relative path to save external data to. When specified,
61
+ all initializers in the model will be converted to external data and
62
+ saved to the specified directory. If None, all tensors will be saved unmodified.
63
+ That is, if a tensor in the model is already external, it will be saved
64
+ with the same external information; if the tensor is not external,
65
+ it will be serialized in the ONNX Proto message.
66
+ size_threshold_bytes: Save to external data if the tensor size in bytes is larger than this threshold.
67
+ Effective only when ``external_data`` is set.
68
+
69
+ Raises:
70
+ ValueError: If the external data path is an absolute path.
46
71
  """
47
- proto = serde.serialize_model(model)
48
- onnx.save(proto, path, format=format)
49
- # TODO(justinchuby): Handle external data when the relative path has changed
50
- # TODO(justinchuby): Handle off loading external data to disk when saving
72
+ if external_data is not None:
73
+ if os.path.isabs(external_data):
74
+ raise ValueError(
75
+ f"The external data path must be relative to the ONNX file path, not '{external_data}'."
76
+ )
77
+ base_dir = os.path.dirname(path)
78
+
79
+ # Store the original initializer values so they can be restored if modify_model=False
80
+ initializer_values = tuple(model.graph.initializers.values())
81
+ tensors = [v.const_value for v in initializer_values]
82
+
83
+ try:
84
+ model = _external_data.unload_from_model(
85
+ model, base_dir, external_data, size_threshold_bytes=size_threshold_bytes
86
+ )
87
+ proto = serde.serialize_model(model)
88
+ onnx.save(proto, path, format=format)
89
+
90
+ finally:
91
+ # Restore the original initializer values so the model is unchanged
92
+ for initializer, tensor in zip(initializer_values, tensors, strict=True):
93
+ initializer.const_value = tensor
94
+
95
+ else:
96
+ proto = serde.serialize_model(model)
97
+ onnx.save(proto, path, format=format)
onnx_ir/_linked_list.py CHANGED
@@ -1,10 +1,11 @@
1
- # Copyright (c) Microsoft Corporation.
2
- # Licensed under the MIT License.
1
+ # Copyright (c) ONNX Project Contributors
2
+ # SPDX-License-Identifier: Apache-2.0
3
3
  """Mutable list for nodes in a graph with safe mutation properties."""
4
4
 
5
5
  from __future__ import annotations
6
6
 
7
- from typing import Generic, Iterable, Iterator, Sequence, TypeVar
7
+ from collections.abc import Iterable, Iterator, Sequence
8
+ from typing import Generic, TypeVar, overload
8
9
 
9
10
  T = TypeVar("T")
10
11
 
@@ -131,16 +132,23 @@ class DoublyLinkedSet(Sequence[T], Generic[T]):
131
132
  box = box.prev
132
133
 
133
134
  def __len__(self) -> int:
134
- assert self._length == len(
135
- self._value_ids_to_boxes
136
- ), "Bug in the implementation: length mismatch"
135
+ assert self._length == len(self._value_ids_to_boxes), (
136
+ "Bug in the implementation: length mismatch"
137
+ )
137
138
  return self._length
138
139
 
139
- def __getitem__(self, index: int) -> T:
140
+ @overload
141
+ def __getitem__(self, index: int) -> T: ...
142
+ @overload
143
+ def __getitem__(self, index: slice) -> Sequence[T]: ...
144
+
145
+ def __getitem__(self, index):
140
146
  """Get the node at the given index.
141
147
 
142
148
  Complexity is O(n).
143
149
  """
150
+ if isinstance(index, slice):
151
+ return tuple(self)[index]
144
152
  if index >= self._length or index < -self._length:
145
153
  raise IndexError(
146
154
  f"Index out of range: {index} not in range [-{self._length}, {self._length})"
onnx_ir/_metadata.py CHANGED
@@ -1,11 +1,12 @@
1
- # Copyright (c) Microsoft Corporation.
2
- # Licensed under the MIT License.
1
+ # Copyright (c) ONNX Project Contributors
2
+ # SPDX-License-Identifier: Apache-2.0
3
3
  """Class for storing metadata about the IR objects."""
4
4
 
5
5
  from __future__ import annotations
6
6
 
7
7
  import collections
8
- from typing import Any, Mapping
8
+ from collections.abc import Mapping
9
+ from typing import Any
9
10
 
10
11
 
11
12
  class MetadataStore(collections.UserDict):