onnx-ir 0.1.15__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- onnx_ir/__init__.py +176 -0
- onnx_ir/_cloner.py +229 -0
- onnx_ir/_convenience/__init__.py +558 -0
- onnx_ir/_convenience/_constructors.py +291 -0
- onnx_ir/_convenience/_extractor.py +191 -0
- onnx_ir/_core.py +4435 -0
- onnx_ir/_display.py +54 -0
- onnx_ir/_enums.py +474 -0
- onnx_ir/_graph_comparison.py +23 -0
- onnx_ir/_graph_containers.py +373 -0
- onnx_ir/_io.py +133 -0
- onnx_ir/_linked_list.py +284 -0
- onnx_ir/_metadata.py +45 -0
- onnx_ir/_name_authority.py +72 -0
- onnx_ir/_polyfill.py +26 -0
- onnx_ir/_protocols.py +627 -0
- onnx_ir/_safetensors/__init__.py +510 -0
- onnx_ir/_tape.py +242 -0
- onnx_ir/_thirdparty/asciichartpy.py +310 -0
- onnx_ir/_type_casting.py +89 -0
- onnx_ir/_version_utils.py +48 -0
- onnx_ir/analysis/__init__.py +21 -0
- onnx_ir/analysis/_implicit_usage.py +74 -0
- onnx_ir/convenience.py +38 -0
- onnx_ir/external_data.py +459 -0
- onnx_ir/passes/__init__.py +41 -0
- onnx_ir/passes/_pass_infra.py +351 -0
- onnx_ir/passes/common/__init__.py +54 -0
- onnx_ir/passes/common/_c_api_utils.py +76 -0
- onnx_ir/passes/common/clear_metadata_and_docstring.py +60 -0
- onnx_ir/passes/common/common_subexpression_elimination.py +207 -0
- onnx_ir/passes/common/constant_manipulation.py +230 -0
- onnx_ir/passes/common/default_attributes.py +99 -0
- onnx_ir/passes/common/identity_elimination.py +120 -0
- onnx_ir/passes/common/initializer_deduplication.py +179 -0
- onnx_ir/passes/common/inliner.py +223 -0
- onnx_ir/passes/common/naming.py +280 -0
- onnx_ir/passes/common/onnx_checker.py +57 -0
- onnx_ir/passes/common/output_fix.py +141 -0
- onnx_ir/passes/common/shape_inference.py +112 -0
- onnx_ir/passes/common/topological_sort.py +37 -0
- onnx_ir/passes/common/unused_removal.py +215 -0
- onnx_ir/py.typed +1 -0
- onnx_ir/serde.py +2043 -0
- onnx_ir/tape.py +15 -0
- onnx_ir/tensor_adapters.py +210 -0
- onnx_ir/testing.py +197 -0
- onnx_ir/traversal.py +118 -0
- onnx_ir-0.1.15.dist-info/METADATA +68 -0
- onnx_ir-0.1.15.dist-info/RECORD +53 -0
- onnx_ir-0.1.15.dist-info/WHEEL +5 -0
- onnx_ir-0.1.15.dist-info/licenses/LICENSE +202 -0
- onnx_ir-0.1.15.dist-info/top_level.txt +1 -0
|
@@ -0,0 +1,373 @@
|
|
|
1
|
+
# Copyright (c) ONNX Project Contributors
|
|
2
|
+
# SPDX-License-Identifier: Apache-2.0
|
|
3
|
+
"""Tracked containers for graph."""
|
|
4
|
+
|
|
5
|
+
# pylint: disable=protected-access
|
|
6
|
+
|
|
7
|
+
from __future__ import annotations
|
|
8
|
+
|
|
9
|
+
# Names exported by ``from ... import *``. Only the graph input/output
# containers are listed; GraphInitializers and Attributes are also defined in
# this module but are not part of the star-import surface.
__all__ = [
    "GraphInputs",
    "GraphOutputs",
]
|
|
13
|
+
|
|
14
|
+
import collections
|
|
15
|
+
import logging
|
|
16
|
+
from collections.abc import Iterable, Sequence
|
|
17
|
+
from typing import SupportsIndex, TypeVar
|
|
18
|
+
|
|
19
|
+
import onnx_ir
|
|
20
|
+
from onnx_ir import _core, _protocols
|
|
21
|
+
|
|
22
|
+
# Type variable for the ``default`` argument of the ``Attributes.get_*`` accessors,
# so that the return type reflects the type of the supplied default.
T = TypeVar("T")

# Module logger; GraphInitializers.__setitem__ logs when it names an unnamed value.
logger = logging.getLogger(__name__)
|
|
25
|
+
|
|
26
|
+
|
|
27
|
+
class _GraphIO(collections.UserList["_core.Value"]):
    """The inputs and outputs of a Graph.

    Base class for :class:`GraphInputs` and :class:`GraphOutputs`. It keeps the
    per-value graph bookkeeping in sync with the list contents through three
    hooks that subclasses implement: ``_check_invariance``, ``_set_graph`` and
    ``_maybe_unset_graph``.

    Mutating methods register new values with ``_set_graph`` *before* the
    underlying list is modified, and roll the registration back if anything
    fails. A rejected value therefore leaves both the list and the bookkeeping
    of all values unchanged. In-place concatenation and repetition are not
    supported.
    """

    def __init__(self, graph: _core.Graph, initlist=None):
        """Create the container.

        Args:
            graph: The graph that owns the values in this list.
            initlist: Optional iterable of initial values (may be a generator).
        """
        self._graph = graph
        # Use a ref counter to track the number of references to each value
        # in the input/output list. This is used to determine when to unset the graph
        # reference in the value.
        # Even though a duplicated value is invalid in inputs and not recommended in outputs,
        # it is still possible to have duplicated inputs/outputs in an ONNX graph so we
        # need to properly handle this case and maintain the graph reference properly.
        self._ref_counter: collections.Counter[_core.Value] = collections.Counter()
        if initlist is not None:
            initlist = tuple(initlist)  # Create a copy in case initlist is a generator
            self._set_graph_all(initlist)
        super().__init__(initlist)
        self._check_invariance()

    def _check_invariance(self) -> None:
        """Check the invariance of the graph."""
        raise NotImplementedError

    def _set_graph(self, value: _core.Value) -> None:
        """Set the graph for the value."""
        raise NotImplementedError

    def _maybe_unset_graph(self, value: _core.Value) -> None:
        """Unset the graph for the value."""
        raise NotImplementedError

    def _set_graph_all(self, values: Sequence[_core.Value]) -> None:
        """Call ``_set_graph`` on every value, all or nothing.

        If ``_set_graph`` raises for one of the values, the values registered
        so far are released again with ``_maybe_unset_graph`` before the
        exception propagates.
        """
        registered: list[_core.Value] = []
        try:
            for value in values:
                self._set_graph(value)
                registered.append(value)
        except BaseException:
            for value in reversed(registered):
                self._maybe_unset_graph(value)
            raise

    def append(self, item: _core.Value) -> None:
        """Add a new input/output to the graph."""
        # Perform checks first in _set_graph before modifying the data structure
        self._set_graph(item)
        super().append(item)
        self._check_invariance()

    def extend(self, other) -> None:
        """Extend the list of inputs or outputs.

        ``other`` may be any iterable (including a generator); it is
        materialized once. If any value is rejected, the list is unchanged.
        """
        other = tuple(other)
        self._set_graph_all(other)
        super().extend(other)
        self._check_invariance()

    def insert(self, i: int, item: _core.Value) -> None:
        """Insert an input/output to the graph."""
        # Validate and register before touching the list so that a rejected
        # value never ends up in the list (same ordering as ``append``).
        self._set_graph(item)
        super().insert(i, item)
        self._check_invariance()

    def pop(self, i: int = -1) -> _core.Value:
        """Remove an input/output from the graph."""
        value = super().pop(i)
        self._maybe_unset_graph(value)
        self._check_invariance()
        return value

    def remove(self, item: _core.Value) -> None:
        """Remove an input/output from the graph."""
        super().remove(item)
        self._maybe_unset_graph(item)
        self._check_invariance()

    def clear(self) -> None:
        """Clear the list."""
        for value in self.data:
            self._maybe_unset_graph(value)
        super().clear()

    def copy(self) -> list[_core.Value]:
        """Return a shallow copy of the list."""
        # This is a shallow copy, so the values are not copied, just the references
        return self.data.copy()

    def __setitem__(self, i, item) -> None:
        """Replace an input/output (or a slice of them) of the graph.

        Raises:
            TypeError: If ``i``/``item`` are not an (index, value) or
                (slice, iterable) pair.
        """
        if isinstance(i, slice) and isinstance(item, Iterable):
            # Materialize first: ``item`` may be a one-shot iterator that the
            # registration loop would otherwise exhaust before the assignment.
            new_values = tuple(item)
            old_values = self.data[i]
            # Register the new values before releasing the old ones so that a
            # rejected value leaves the list and all bookkeeping unchanged.
            self._set_graph_all(new_values)
            try:
                super().__setitem__(i, new_values)
            except BaseException:
                # e.g. an extended slice whose length does not match
                for value in reversed(new_values):
                    self._maybe_unset_graph(value)
                raise
            for value in old_values:
                self._maybe_unset_graph(value)
            self._check_invariance()
            return
        if isinstance(i, SupportsIndex):
            # Replace a single item; indexing first surfaces IndexError before
            # any bookkeeping is touched.
            old_value = self.data[i]
            self._set_graph(item)
            super().__setitem__(i, item)
            self._maybe_unset_graph(old_value)
            self._check_invariance()
            return

        raise TypeError(f"Invalid types for __setitem__: {type(i)} and {type(item)}")

    def __delitem__(self, i) -> None:
        """Remove an input/output (or a slice of them) from the graph."""
        removed = self.data[i]
        del self.data[i]
        # Release the graph reference exactly like ``pop``/``remove`` do.
        if isinstance(i, slice):
            for value in removed:
                self._maybe_unset_graph(value)
        else:
            self._maybe_unset_graph(removed)
        self._check_invariance()

    def __getitem__(self, i):
        """Get an input/output from the graph (a plain list for slices)."""
        return self.data[i]

    def _unimplemented(self, *_args, **_kwargs):
        """Unimplemented method."""
        raise RuntimeError("Method is not supported")

    # Operators that would copy or duplicate values without going through the
    # graph bookkeeping are rejected.
    __add__ = _unimplemented
    __radd__ = _unimplemented
    __iadd__ = _unimplemented
    __mul__ = _unimplemented
    __rmul__ = _unimplemented
    __imul__ = _unimplemented
|
|
136
|
+
|
|
137
|
+
|
|
138
|
+
class GraphInputs(_GraphIO):
    """The inputs of a Graph."""

    def _check_invariance(self) -> None:
        """Check that every input value points back to this graph.

        The check only runs when ``onnx_ir.DEBUG`` is enabled.

        Raises:
            ValueError: If a value in the list is not attached to this graph.
        """
        if not onnx_ir.DEBUG:
            return
        for value in self.data:
            if value._graph is self._graph:
                continue
            raise ValueError(
                f"Invariance error: Value '{value}' is not an input of the graph: {self._graph!r}"
            )

    def _set_graph(self, value: _core.Value) -> None:
        """Mark ``value`` as an input of this graph and count the reference.

        All checks run before any attribute of ``value`` is modified.

        Raises:
            ValueError: If ``value`` belongs to a different graph or is
                produced by a node.
        """
        if value._graph is not None and value._graph is not self._graph:
            raise ValueError(
                f"Value '{value}' is already owned by a different graph. Please remove the value from the previous graph first"
            )
        if value.producer() is not None:
            raise ValueError(
                f"Value '{value}' is produced by a node and cannot be an input to the graph. Please create new Values for graph inputs"
            )
        self._ref_counter[value] += 1
        value._is_graph_input = True
        value._graph = self._graph

    def _maybe_unset_graph(self, value: _core.Value) -> None:
        """Release one reference to ``value``.

        When the last reference is released, the input flag is cleared. The
        graph reference is dropped too unless the graph still owns the value.
        """
        assert value._graph is self._graph, "Bug: value does not belong to the graph"
        self._ref_counter[value] -= 1
        if self._ref_counter[value] > 0:
            # The value is still used by another graph input
            return
        # Drop the exhausted entry: collections.Counter keeps zero-count keys,
        # which would otherwise hold a strong reference to removed values.
        del self._ref_counter[value]
        value._is_graph_input = False
        if value._owned_by_graph():
            # Keep the graph reference if the graph still owns the value in
            # another role (e.g. as an output or an initializer)
            return
        value._graph = None
|
|
178
|
+
|
|
179
|
+
|
|
180
|
+
class GraphOutputs(_GraphIO):
    """The outputs of a Graph."""

    def _check_invariance(self) -> None:
        """Check that every output value points back to this graph.

        The check only runs when ``onnx_ir.DEBUG`` is enabled.

        Raises:
            ValueError: If a value in the list is not attached to this graph.
        """
        if not onnx_ir.DEBUG:
            return
        for value in self.data:
            if value._graph is self._graph:
                continue
            raise ValueError(
                f"Invariance error: Value '{value}' is not an output of the graph: {self._graph!r}"
            )

    def _set_graph(self, value: _core.Value) -> None:
        """Mark ``value`` as an output of this graph and count the reference.

        Raises:
            ValueError: If ``value`` already belongs to a different graph.
        """
        if value._graph is not None and value._graph is not self._graph:
            raise ValueError(
                f"Value '{value}' is already an output of a different graph. Please remove the value from the previous graph first"
            )
        self._ref_counter[value] += 1
        value._is_graph_output = True
        value._graph = self._graph

    def _maybe_unset_graph(self, value: _core.Value) -> None:
        """Release one reference to ``value``.

        When the last reference is released, the output flag is cleared. The
        graph reference is dropped too unless the graph still owns the value.
        """
        assert value._graph is self._graph, "Bug: value does not belong to the graph"
        self._ref_counter[value] -= 1
        if self._ref_counter[value] > 0:
            # The value is still used by another graph output
            return
        # Drop the exhausted entry: collections.Counter keeps zero-count keys,
        # which would otherwise hold a strong reference to removed values.
        del self._ref_counter[value]
        value._is_graph_output = False
        if value._owned_by_graph():
            # Keep the graph reference if the graph still owns the value in
            # another role (e.g. as an input or an initializer)
            return
        value._graph = None
|
|
216
|
+
|
|
217
|
+
|
|
218
|
+
class GraphInitializers(collections.UserDict[str, "_core.Value"]):
    """The initializers of a Graph as ``dict[str, Value]`` with additional mutation methods.

    Every value stored here is marked as an initializer of the owning graph.
    Assignments are validated before the mapping or any value's graph
    bookkeeping is modified.
    """

    def __init__(self, graph: _core.Graph, initial=None, /, **kwargs):
        """Create the container and register every value with ``graph``.

        Args:
            graph: The graph that owns the initializers.
            initial: Optional mapping (or iterable of key/value pairs) from
                initializer name to value. Positional-only.
            **kwargs: Additional initializers given as ``name=value``.

        Raises:
            TypeError: If an entry has a non-string key or a non-Value value.
            ValueError: If an entry fails validation. Values registered before
                the failure are released again, so none remain marked as
                initializers of ``graph``.
        """
        data = {}
        if initial is not None:
            data.update(initial)
        if kwargs:
            data.update(kwargs)
        self._graph = graph
        super().__init__()
        try:
            for key, value in data.items():
                # __setitem__ validates the entry before calling _set_graph.
                self[key] = value
        except BaseException:
            # Roll back so a failed construction does not leave values marked
            # as initializers of ``graph``.
            for value in self.data.values():
                self._maybe_unset_graph(value)
            self.data.clear()
            raise

    def _set_graph(self, value: _core.Value) -> None:
        """Mark ``value`` as an initializer of this graph.

        Raises:
            ValueError: If ``value`` already belongs to a different graph.
        """
        if value._graph is not None and value._graph is not self._graph:
            raise ValueError(
                f"Value '{value}' is already an initializer of a different graph. Please remove the value from the previous graph first"
            )
        value._is_initializer = True
        value._graph = self._graph

    def _maybe_unset_graph(self, value: _core.Value) -> None:
        """Clear the initializer flag and drop the graph reference if no longer owned."""
        assert value._graph is self._graph, "Bug: value does not belong to the graph"
        value._is_initializer = False
        if value._owned_by_graph():
            # Keep the graph reference if the graph still owns the value in
            # another role (e.g. as a graph input or output)
            return
        value._graph = None

    def __setitem__(self, key: str, value: _core.Value) -> None:
        """Set an initializer for the graph.

        An unnamed value is given ``key`` as its name. An existing initializer
        stored under ``key`` is replaced and released.

        Raises:
            TypeError: If ``value`` is not a Value or ``key`` is not a string.
            ValueError: If ``key`` is empty, does not match ``value.name``,
                ``value`` is produced by a node, or ``value`` belongs to a
                different graph.
        """
        if not isinstance(value, _core.Value):
            raise TypeError(f"value must be a Value object, not {type(value)}")
        if not isinstance(key, str):
            raise TypeError(f"Value name must be a string, not {type(key)}")
        if key == "":
            raise ValueError("Value name cannot be an empty string")
        if value.name and key != value.name:
            raise ValueError(
                f"Key '{key}' does not match the name of the value '{value.name}'. Please use the value.name as the key."
            )
        if value.producer() is not None:
            raise ValueError(
                f"Value '{value}' is produced by a node and cannot be a graph initializer"
            )
        if not value.name:
            # Name the value only after the checks above pass, so a rejected
            # value is not renamed.
            logger.info("Value %s does not have a name, setting it to '%s'", value, key)
            value.name = key
        old_value = self.data.get(key)
        # Register the new value before releasing the old one. If _set_graph
        # raises, neither the dictionary nor the previous initializer is modified.
        self._set_graph(value)
        if old_value is not None and old_value is not value:
            self._maybe_unset_graph(old_value)
        super().__setitem__(key, value)

    def __delitem__(self, key: str) -> None:
        """Delete an initializer from the graph."""
        value = self.data[key]
        # Must call _maybe_unset_graph before super().__delitem__ so that when there is an error,
        # the dictionary is not modified
        self._maybe_unset_graph(value)
        super().__delitem__(key)

    def add(self, value: _core.Value) -> None:
        """Add an initializer to the graph, keyed by its name."""
        self[value.name] = value  # type: ignore[index]
|
|
291
|
+
|
|
292
|
+
|
|
293
|
+
class Attributes(collections.UserDict[str, "_core.Attr"]):
    """Mapping from attribute name to ``Attr`` for a Node, with typed getters.

    Each ``get_*`` accessor returns ``default`` when ``key`` is absent and
    otherwise calls the matching ``as_*`` conversion on the stored ``Attr``.
    """

    def __init__(self, attrs: Iterable[_core.Attr]):
        by_name = {}
        for attr in attrs:
            # Later attributes with the same name replace earlier ones.
            by_name[attr.name] = attr
        super().__init__(by_name)

    def __setitem__(self, key: str, value: _core.Attr) -> None:
        """Store ``value`` under ``key`` after validating both types."""
        key_is_plain_str = type(key) is str
        if not key_is_plain_str:
            raise TypeError(f"Key must be a string, not {type(key)}")
        if isinstance(value, _core.Attr):
            super().__setitem__(key, value)
            return
        raise TypeError(f"Value must be an Attr, not {type(value)}")

    def add(self, value: _core.Attr) -> None:
        """Store ``value`` under its own name, replacing any existing entry."""
        name = value.name
        self[name] = value

    def get_int(self, key: str, default: T = None) -> int | T:  # type: ignore[assignment]
        """Return attribute ``key`` as an int, or ``default`` if absent."""
        if key not in self:
            return default
        return self[key].as_int()

    def get_float(self, key: str, default: T = None) -> float | T:  # type: ignore[assignment]
        """Return attribute ``key`` as a float, or ``default`` if absent."""
        if key not in self:
            return default
        return self[key].as_float()

    def get_string(self, key: str, default: T = None) -> str | T:  # type: ignore[assignment]
        """Return attribute ``key`` as a string, or ``default`` if absent."""
        if key not in self:
            return default
        return self[key].as_string()

    def get_tensor(self, key: str, default: T = None) -> _protocols.TensorProtocol | T:  # type: ignore[assignment]
        """Return attribute ``key`` as a tensor, or ``default`` if absent."""
        if key not in self:
            return default
        return self[key].as_tensor()

    def get_graph(self, key: str, default: T = None) -> _core.Graph | T:  # type: ignore[assignment]
        """Return attribute ``key`` as a graph, or ``default`` if absent."""
        if key not in self:
            return default
        return self[key].as_graph()

    def get_ints(self, key: str, default: T = None) -> Sequence[int] | T:  # type: ignore[assignment]
        """Return attribute ``key`` as a sequence of ints, or ``default`` if absent."""
        if key not in self:
            return default
        return self[key].as_ints()

    def get_floats(self, key: str, default: T = None) -> Sequence[float] | T:  # type: ignore[assignment]
        """Return attribute ``key`` as a sequence of floats, or ``default`` if absent."""
        if key not in self:
            return default
        return self[key].as_floats()

    def get_strings(self, key: str, default: T = None) -> Sequence[str] | T:  # type: ignore[assignment]
        """Return attribute ``key`` as a sequence of strings, or ``default`` if absent."""
        if key not in self:
            return default
        return self[key].as_strings()

    def get_tensors(self, key: str, default: T = None) -> Sequence[_protocols.TensorProtocol] | T:  # type: ignore[assignment]
        """Return attribute ``key`` as a sequence of tensors, or ``default`` if absent."""
        if key not in self:
            return default
        return self[key].as_tensors()

    def get_graphs(self, key: str, default: T = None) -> Sequence[_core.Graph] | T:  # type: ignore[assignment]
        """Return attribute ``key`` as a sequence of graphs, or ``default`` if absent."""
        if key not in self:
            return default
        return self[key].as_graphs()
|
onnx_ir/_io.py
ADDED
|
@@ -0,0 +1,133 @@
|
|
|
1
|
+
# Copyright (c) ONNX Project Contributors
|
|
2
|
+
# SPDX-License-Identifier: Apache-2.0
|
|
3
|
+
"""Load and save ONNX models."""
|
|
4
|
+
|
|
5
|
+
from __future__ import annotations
|
|
6
|
+
|
|
7
|
+
# Public API of this module: the model load/save entry points.
__all__ = ["load", "save"]
|
|
8
|
+
|
|
9
|
+
import os
|
|
10
|
+
from typing import Callable
|
|
11
|
+
|
|
12
|
+
import onnx # noqa: TID251
|
|
13
|
+
|
|
14
|
+
from onnx_ir import _core, _protocols, serde
|
|
15
|
+
from onnx_ir import external_data as _external_data
|
|
16
|
+
from onnx_ir._polyfill import zip
|
|
17
|
+
|
|
18
|
+
|
|
19
|
+
def load(path: str | os.PathLike, format: str | None = None) -> _core.Model:
    """Read an ONNX model from ``path`` and convert it to the IR.

    Args:
        path: Location of the ONNX file.
        format: File format (e.g. protobuf, textproto, json, etc.). When
            ``None``, the format is inferred from the file extension.

    Returns:
        The deserialized model. The base directory for external data is set
        to the directory containing ``path``, so relative external-data
        locations resolve correctly.
    """
    # External data is deliberately not loaded through ONNX: the IR handles
    # external data itself by memory-mapping the files directly.
    model = serde.deserialize_model(
        onnx.load(path, format=format, load_external_data=False)
    )
    # Anchor relative external-data paths at the model file's directory.
    _external_data.set_base_dir(model.graph, os.path.dirname(path))
    return model
|
|
39
|
+
|
|
40
|
+
|
|
41
|
+
def save(
    model: _core.Model,
    path: str | os.PathLike,
    format: str | None = None,
    external_data: str | os.PathLike | None = None,
    size_threshold_bytes: int = 256,
    callback: Callable[[_protocols.TensorProtocol, _external_data.CallbackInfo], None]
    | None = None,
) -> None:
    """Save an ONNX model to a file.

    The model remains unchanged after the call. If any existing external tensor
    references the provided ``external_data`` path, it will be invalidated
    after the external data is overwritten. To obtain a valid model, use :func:`load`
    to load the newly saved model, or provide a different external data path that
    is not currently referenced by any tensors in the model.

    .. tip::

        A simple progress bar can be implemented by passing a callback function as the following::

            import onnx_ir as ir
            import tqdm

            with tqdm.tqdm() as pbar:
                total_set = False

                def callback(tensor: ir.TensorProtocol, metadata: ir.external_data.CallbackInfo) -> None:
                    nonlocal total_set
                    if not total_set:
                        pbar.total = metadata.total
                        total_set = True

                    pbar.update()
                    pbar.set_description(f"Saving {tensor.name} ({tensor.dtype}, {tensor.shape}) at offset {metadata.offset}")

                ir.save(
                    ...,
                    callback=callback,
                )

    Args:
        model: The model to save.
        path: The path to save the model to. E.g. "model.onnx".
        format: The format of the file (e.g. ``protobuf``, ``textproto``, ``json``, etc.).
            If None, the format is inferred from the file extension.
        external_data: The path of the external data file, relative to the
            directory of ``path``. When specified, initializer tensors larger than
            ``size_threshold_bytes`` are written to this file instead of being
            embedded in the ONNX proto. If None, all tensors are saved unmodified:
            a tensor that is already external keeps its external information,
            and a tensor that is not external is serialized in the ONNX proto message.
        size_threshold_bytes: Save to external data if the tensor size in bytes is larger than this threshold.
            Effective only when ``external_data`` is set.
        callback: A callback function that is called for each tensor that is saved to external data
            for debugging or logging purposes.

    Raises:
        ValueError: If the external data path is an absolute path.
    """
    if external_data is None:
        # No external data requested: serialize every tensor as it is.
        proto = serde.serialize_model(model)
        onnx.save(proto, path, format=format)
        return

    if os.path.isabs(external_data):
        raise ValueError(
            f"The external data path must be relative to the ONNX file path, not '{external_data}'."
        )
    base_dir = os.path.dirname(path)

    # Record each initializer's tensor (main graph and all subgraphs) so it can be
    # restored afterwards, leaving the caller's model unchanged.
    initialized_values: list[_core.Value] = [
        value for graph in model.graphs() for value in graph.initializers.values()
    ]
    tensors = [value.const_value for value in initialized_values]

    try:
        unloaded_model = _external_data.unload_from_model(
            model,
            base_dir,
            external_data,
            size_threshold_bytes=size_threshold_bytes,
            callback=callback,
        )
        proto = serde.serialize_model(unloaded_model)
        onnx.save(proto, path, format=format)
    finally:
        # Restore the original initializer values so the model is unchanged
        for initializer, tensor in zip(initialized_values, tensors, strict=True):
            initializer.const_value = tensor
|