onnx-ir 0.1.15__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (53) hide show
  1. onnx_ir/__init__.py +176 -0
  2. onnx_ir/_cloner.py +229 -0
  3. onnx_ir/_convenience/__init__.py +558 -0
  4. onnx_ir/_convenience/_constructors.py +291 -0
  5. onnx_ir/_convenience/_extractor.py +191 -0
  6. onnx_ir/_core.py +4435 -0
  7. onnx_ir/_display.py +54 -0
  8. onnx_ir/_enums.py +474 -0
  9. onnx_ir/_graph_comparison.py +23 -0
  10. onnx_ir/_graph_containers.py +373 -0
  11. onnx_ir/_io.py +133 -0
  12. onnx_ir/_linked_list.py +284 -0
  13. onnx_ir/_metadata.py +45 -0
  14. onnx_ir/_name_authority.py +72 -0
  15. onnx_ir/_polyfill.py +26 -0
  16. onnx_ir/_protocols.py +627 -0
  17. onnx_ir/_safetensors/__init__.py +510 -0
  18. onnx_ir/_tape.py +242 -0
  19. onnx_ir/_thirdparty/asciichartpy.py +310 -0
  20. onnx_ir/_type_casting.py +89 -0
  21. onnx_ir/_version_utils.py +48 -0
  22. onnx_ir/analysis/__init__.py +21 -0
  23. onnx_ir/analysis/_implicit_usage.py +74 -0
  24. onnx_ir/convenience.py +38 -0
  25. onnx_ir/external_data.py +459 -0
  26. onnx_ir/passes/__init__.py +41 -0
  27. onnx_ir/passes/_pass_infra.py +351 -0
  28. onnx_ir/passes/common/__init__.py +54 -0
  29. onnx_ir/passes/common/_c_api_utils.py +76 -0
  30. onnx_ir/passes/common/clear_metadata_and_docstring.py +60 -0
  31. onnx_ir/passes/common/common_subexpression_elimination.py +207 -0
  32. onnx_ir/passes/common/constant_manipulation.py +230 -0
  33. onnx_ir/passes/common/default_attributes.py +99 -0
  34. onnx_ir/passes/common/identity_elimination.py +120 -0
  35. onnx_ir/passes/common/initializer_deduplication.py +179 -0
  36. onnx_ir/passes/common/inliner.py +223 -0
  37. onnx_ir/passes/common/naming.py +280 -0
  38. onnx_ir/passes/common/onnx_checker.py +57 -0
  39. onnx_ir/passes/common/output_fix.py +141 -0
  40. onnx_ir/passes/common/shape_inference.py +112 -0
  41. onnx_ir/passes/common/topological_sort.py +37 -0
  42. onnx_ir/passes/common/unused_removal.py +215 -0
  43. onnx_ir/py.typed +1 -0
  44. onnx_ir/serde.py +2043 -0
  45. onnx_ir/tape.py +15 -0
  46. onnx_ir/tensor_adapters.py +210 -0
  47. onnx_ir/testing.py +197 -0
  48. onnx_ir/traversal.py +118 -0
  49. onnx_ir-0.1.15.dist-info/METADATA +68 -0
  50. onnx_ir-0.1.15.dist-info/RECORD +53 -0
  51. onnx_ir-0.1.15.dist-info/WHEEL +5 -0
  52. onnx_ir-0.1.15.dist-info/licenses/LICENSE +202 -0
  53. onnx_ir-0.1.15.dist-info/top_level.txt +1 -0
@@ -0,0 +1,373 @@
1
+ # Copyright (c) ONNX Project Contributors
2
+ # SPDX-License-Identifier: Apache-2.0
3
+ """Tracked containers for graph."""
4
+
5
+ # pylint: disable=protected-access
6
+
7
+ from __future__ import annotations
8
+
9
+ __all__ = [
10
+ "GraphInputs",
11
+ "GraphOutputs",
12
+ ]
13
+
14
+ import collections
15
+ import logging
16
+ from collections.abc import Iterable, Sequence
17
+ from typing import SupportsIndex, TypeVar
18
+
19
+ import onnx_ir
20
+ from onnx_ir import _core, _protocols
21
+
22
+ T = TypeVar("T")
23
+
24
+ logger = logging.getLogger(__name__)
25
+
26
+
27
+ class _GraphIO(collections.UserList["_core.Value"]):
28
+ """The inputs and outputs of a Graph."""
29
+
30
+ def __init__(self, graph: _core.Graph, initlist=None):
31
+ self._graph = graph
32
+ # Use a ref counter to track the number of references to each value
33
+ # in the input/output list. This is used to determine when to unset the graph
34
+ # reference in the value.
35
+ # Even though a duplicated value is invalid in inputs and not recommended in outputs,
36
+ # it is still possible to have duplicated inputs/outputs in an ONNX graph so we
37
+ # need to properly handle this case and maintain the graph reference properly.
38
+ self._ref_counter: collections.Counter[_core.Value] = collections.Counter()
39
+ if initlist is not None:
40
+ initlist = tuple(initlist) # Create a copy in case initlist is a generator
41
+ for value in initlist:
42
+ self._set_graph(value)
43
+ super().__init__(initlist)
44
+ self._check_invariance()
45
+
46
+ def _check_invariance(self) -> None:
47
+ """Check the invariance of the graph."""
48
+ raise NotImplementedError
49
+
50
+ def _set_graph(self, value: _core.Value) -> None:
51
+ """Set the graph for the value."""
52
+ raise NotImplementedError
53
+
54
+ def _maybe_unset_graph(self, value: _core.Value) -> None:
55
+ """Unset the graph for the value."""
56
+ raise NotImplementedError
57
+
58
+ def append(self, item: _core.Value) -> None:
59
+ """Add a new input to the graph."""
60
+ # Perform checks first in _set_graph before modifying the data structure
61
+ self._set_graph(item)
62
+ super().append(item)
63
+ self._check_invariance()
64
+
65
+ def extend(self, other) -> None:
66
+ """Extend the list of inputs or outputs."""
67
+ other = tuple(other)
68
+ for item in other:
69
+ self._set_graph(item)
70
+ super().extend(other)
71
+
72
+ def insert(self, i: int, item: _core.Value) -> None:
73
+ """Insert an input/output to the graph."""
74
+ super().insert(i, item)
75
+ self._set_graph(item)
76
+ self._check_invariance()
77
+
78
+ def pop(self, i: int = -1) -> _core.Value:
79
+ """Remove an input/output from the graph."""
80
+ value = super().pop(i)
81
+ self._maybe_unset_graph(value)
82
+ self._check_invariance()
83
+ return value
84
+
85
+ def remove(self, item: _core.Value) -> None:
86
+ """Remove an input/output from the graph."""
87
+ super().remove(item)
88
+ self._maybe_unset_graph(item)
89
+ self._check_invariance()
90
+
91
+ def clear(self) -> None:
92
+ """Clear the list."""
93
+ for value in self.data:
94
+ self._maybe_unset_graph(value)
95
+ super().clear()
96
+
97
+ def copy(self) -> list[_core.Value]:
98
+ """Return a shallow copy of the list."""
99
+ # This is a shallow copy, so the values are not copied, just the references
100
+ return self.data.copy()
101
+
102
+ def __setitem__(self, i, item) -> None:
103
+ """Replace an input/output to the node."""
104
+ if isinstance(item, Iterable) and isinstance(i, slice):
105
+ # Modify a slice of the list
106
+ for value in self.data[i]:
107
+ self._maybe_unset_graph(value)
108
+ for value in item:
109
+ self._set_graph(value)
110
+ super().__setitem__(i, item)
111
+ self._check_invariance()
112
+ return
113
+ elif isinstance(i, SupportsIndex):
114
+ # Replace a single item
115
+ self._maybe_unset_graph(self.data[i])
116
+ self._set_graph(item)
117
+ super().__setitem__(i, item)
118
+ self._check_invariance()
119
+ return
120
+
121
+ raise TypeError(f"Invalid types for __setitem__: {type(i)} and {type(item)}")
122
+
123
+ def __getitem__(self, i):
124
+ """Get an input/output from the graph."""
125
+ return self.data[i]
126
+
127
+ def _unimplemented(self, *_args, **_kwargs):
128
+ """Unimplemented method."""
129
+ raise RuntimeError("Method is not supported")
130
+
131
+ __add__ = _unimplemented
132
+ __radd__ = _unimplemented
133
+ __iadd__ = _unimplemented
134
+ __mul__ = _unimplemented
135
+ __rmul__ = _unimplemented
136
+
137
+
138
class GraphInputs(_GraphIO):
    """The inputs of a Graph."""

    def _check_invariance(self) -> None:
        """Verify that every stored value points back to this graph.

        The check only runs when ``onnx_ir.DEBUG`` is truthy.

        Raises:
            ValueError: If a stored value references a different graph (or none).
        """
        if not onnx_ir.DEBUG:
            return
        for value in self.data:
            if value._graph is not self._graph:
                raise ValueError(
                    f"Invariance error: Value '{value}' is not an input of the graph: {self._graph!r}"
                )

    def _set_graph(self, value: _core.Value) -> None:
        """Register ``value`` as an input of this graph.

        Raises:
            ValueError: If the value already references another graph, or if
                it is produced by a node.
        """
        if value._graph is not None and value._graph is not self._graph:
            raise ValueError(
                f"Value '{value}' is already owned by a different graph. Please remove the value from the previous graph first"
            )
        if value.producer() is not None:
            raise ValueError(
                f"Value '{value}' is produced by a node and cannot be an input to the graph. Please create new Values for graph inputs"
            )
        self._ref_counter[value] += 1
        value._is_graph_input = True
        value._graph = self._graph

    def _maybe_unset_graph(self, value: _core.Value) -> None:
        """Release one reference to ``value``; clear its input flag on the last one."""
        assert value._graph is self._graph, "Bug: value does not belong to the graph"
        remaining = self._ref_counter[value] - 1
        if remaining > 0:
            # The value is still used by another graph input
            self._ref_counter[value] = remaining
            return
        # Drop the entry instead of leaving a zero count behind, so the counter
        # does not keep a strong reference to a value that was removed.
        self._ref_counter.pop(value, None)
        value._is_graph_input = False
        if value._owned_by_graph():
            # Keep the graph reference while the value still reports being
            # owned by the graph through another role.
            return
        value._graph = None
178
+
179
+
180
class GraphOutputs(_GraphIO):
    """The outputs of a Graph."""

    def _check_invariance(self) -> None:
        """Verify that every stored value points back to this graph.

        The check only runs when ``onnx_ir.DEBUG`` is truthy.

        Raises:
            ValueError: If a stored value references a different graph (or none).
        """
        if not onnx_ir.DEBUG:
            return
        for value in self.data:
            if value._graph is not self._graph:
                raise ValueError(
                    f"Invariance error: Value '{value}' is not an output of the graph: {self._graph!r}"
                )

    def _set_graph(self, value: _core.Value) -> None:
        """Register ``value`` as an output of this graph.

        Raises:
            ValueError: If the value already references another graph.
        """
        if value._graph is not None and value._graph is not self._graph:
            raise ValueError(
                f"Value '{value}' is already an output of a different graph. Please remove the value from the previous graph first"
            )
        self._ref_counter[value] += 1
        value._is_graph_output = True
        value._graph = self._graph

    def _maybe_unset_graph(self, value: _core.Value) -> None:
        """Release one reference to ``value``; clear its output flag on the last one."""
        assert value._graph is self._graph, "Bug: value does not belong to the graph"
        remaining = self._ref_counter[value] - 1
        if remaining > 0:
            # The value is still used by another graph output
            self._ref_counter[value] = remaining
            return
        # Drop the entry instead of leaving a zero count behind, so the counter
        # does not keep a strong reference to a value that was removed.
        self._ref_counter.pop(value, None)
        value._is_graph_output = False
        if value._owned_by_graph():
            # Keep the graph reference while the value still reports being
            # owned by the graph through another role.
            return
        value._graph = None
216
+
217
+
218
class GraphInitializers(collections.UserDict[str, "_core.Value"]):
    """The initializers of a Graph as ``dict[str, Value]`` with additional mutation methods.

    Setting and deleting entries keeps each value's ``_is_initializer`` flag and
    ``_graph`` reference in sync with its membership in this mapping.
    """

    def __init__(self, graph: _core.Graph, dict=None, /, **kwargs):
        """Create the mapping for ``graph``, optionally seeded with entries.

        Args:
            graph: The graph that owns the initializers.
            dict: Optional mapping (or iterable of pairs) of initial entries.
            **kwargs: Additional initial entries given as keywords.
        """
        # Perform checks first in _set_graph before modifying the data structure with super().__init__()
        data = {}
        if dict is not None:
            data.update(dict)
        if kwargs:
            data.update(kwargs)
        self._graph = graph
        for value in data.values():
            self._set_graph(value)

        super().__init__(data)

    def _set_graph(self, value: _core.Value) -> None:
        """Mark ``value`` as an initializer of this graph.

        Raises:
            ValueError: If the value already references another graph.
        """
        if value._graph is not None and value._graph is not self._graph:
            raise ValueError(
                f"Value '{value}' is already an initializer of a different graph. Please remove the value from the previous graph first"
            )
        value._is_initializer = True
        value._graph = self._graph

    def _maybe_unset_graph(self, value: _core.Value) -> None:
        """Clear the initializer flag of ``value`` and drop its graph reference if unowned."""
        assert value._graph is self._graph, "Bug: value does not belong to the graph"
        value._is_initializer = False
        if value._owned_by_graph():
            # Keep the graph reference while the value still reports being
            # owned by the graph through another role.
            return
        value._graph = None

    def __setitem__(self, key: str, value: _core.Value) -> None:
        """Set an initializer for the graph.

        An unnamed value is given ``key`` as its name; otherwise ``key`` must
        equal ``value.name``.

        Raises:
            TypeError: If ``value`` is not a Value or ``key`` is not a string.
            ValueError: If ``key`` is empty, does not match the value's name,
                the value is produced by a node, or the value already
                references another graph.
        """
        if not isinstance(value, _core.Value):
            raise TypeError(f"value must be a Value object, not {type(value)}")
        if not isinstance(key, str):
            raise TypeError(f"Value name must be a string, not {type(key)}")
        if key == "":
            raise ValueError("Value name cannot be an empty string")
        if not value.name:
            logger.info("Value %s does not have a name, setting it to '%s'", value, key)
            value.name = key
        elif key != value.name:
            raise ValueError(
                f"Key '{key}' does not match the name of the value '{value.name}'. Please use the value.name as the key."
            )
        if value.producer() is not None:
            raise ValueError(
                f"Value '{value}' is produced by a node and cannot be a graph initializer"
            )
        old_value = self.data.get(key)
        # Must call _set_graph before releasing the old value and before
        # super().__setitem__ so that when there is an error, neither the
        # dictionary nor the previously stored value is modified
        self._set_graph(value)
        if old_value is not None and old_value is not value:
            # The key already exists with a different value: unset the old value
            self._maybe_unset_graph(old_value)
        super().__setitem__(key, value)

    def __delitem__(self, key: str) -> None:
        """Delete an initializer from the graph."""
        value = self.data[key]
        # Must call _maybe_unset_graph before super().__delitem__ so that when there is an error,
        # the dictionary is not modified
        self._maybe_unset_graph(value)
        super().__delitem__(key)

    def add(self, value: _core.Value) -> None:
        """Add an initializer to the graph, keyed by ``value.name``."""
        self[value.name] = value  # type: ignore[index]
291
+
292
+
293
class Attributes(collections.UserDict[str, "_core.Attr"]):
    """The attributes of a Node as ``dict[str, Attr]`` with additional access methods.

    Each ``get_*`` accessor returns the stored attribute converted through the
    matching ``Attr.as_*`` method, or ``default`` when ``key`` is absent.
    """

    def __init__(self, attrs: Iterable[_core.Attr]):
        """Build the mapping from an iterable of attributes, keyed by their names."""
        by_name = {attr.name: attr for attr in attrs}
        super().__init__(by_name)

    def __setitem__(self, key: str, value: _core.Attr) -> None:
        """Store ``value`` under ``key`` after validating both types.

        Raises:
            TypeError: If ``key`` is not exactly a ``str`` or ``value`` is not an Attr.
        """
        if type(key) is not str:
            raise TypeError(f"Key must be a string, not {type(key)}")
        if not isinstance(value, _core.Attr):
            raise TypeError(f"Value must be an Attr, not {type(value)}")
        super().__setitem__(key, value)

    def add(self, value: _core.Attr) -> None:
        """Store ``value`` under its own name."""
        self[value.name] = value

    def get_int(self, key: str, default: T = None) -> int | T:  # type: ignore[assignment]
        """Return the attribute as an integer, or ``default`` if it is missing."""
        return self[key].as_int() if key in self else default

    def get_float(self, key: str, default: T = None) -> float | T:  # type: ignore[assignment]
        """Return the attribute as a float, or ``default`` if it is missing."""
        return self[key].as_float() if key in self else default

    def get_string(self, key: str, default: T = None) -> str | T:  # type: ignore[assignment]
        """Return the attribute as a string, or ``default`` if it is missing."""
        return self[key].as_string() if key in self else default

    def get_tensor(self, key: str, default: T = None) -> _protocols.TensorProtocol | T:  # type: ignore[assignment]
        """Return the attribute as a tensor, or ``default`` if it is missing."""
        return self[key].as_tensor() if key in self else default

    def get_graph(self, key: str, default: T = None) -> _core.Graph | T:  # type: ignore[assignment]
        """Return the attribute as a graph, or ``default`` if it is missing."""
        return self[key].as_graph() if key in self else default

    def get_ints(self, key: str, default: T = None) -> Sequence[int] | T:  # type: ignore[assignment]
        """Return the attribute as a sequence of integers, or ``default`` if it is missing."""
        return self[key].as_ints() if key in self else default

    def get_floats(self, key: str, default: T = None) -> Sequence[float] | T:  # type: ignore[assignment]
        """Return the attribute as a sequence of floats, or ``default`` if it is missing."""
        return self[key].as_floats() if key in self else default

    def get_strings(self, key: str, default: T = None) -> Sequence[str] | T:  # type: ignore[assignment]
        """Return the attribute as a sequence of strings, or ``default`` if it is missing."""
        return self[key].as_strings() if key in self else default

    def get_tensors(
        self,
        key: str,
        default: T = None,  # type: ignore[assignment]
    ) -> Sequence[_protocols.TensorProtocol] | T:
        """Return the attribute as a sequence of tensors, or ``default`` if it is missing."""
        return self[key].as_tensors() if key in self else default

    def get_graphs(self, key: str, default: T = None) -> Sequence[_core.Graph] | T:  # type: ignore[assignment]
        """Return the attribute as a sequence of graphs, or ``default`` if it is missing."""
        return self[key].as_graphs() if key in self else default
onnx_ir/_io.py ADDED
@@ -0,0 +1,133 @@
1
+ # Copyright (c) ONNX Project Contributors
2
+ # SPDX-License-Identifier: Apache-2.0
3
+ """Load and save ONNX models."""
4
+
5
+ from __future__ import annotations
6
+
7
+ __all__ = ["load", "save"]
8
+
9
+ import os
10
+ from typing import Callable
11
+
12
+ import onnx # noqa: TID251
13
+
14
+ from onnx_ir import _core, _protocols, serde
15
+ from onnx_ir import external_data as _external_data
16
+ from onnx_ir._polyfill import zip
17
+
18
+
19
def load(path: str | os.PathLike, format: str | None = None) -> _core.Model:
    """Read an ONNX file from disk and return it as an IR model.

    Args:
        path: Location of the ONNX file.
        format: Serialization format of the file (e.g. protobuf, textproto,
            json, etc.). When None, the format is inferred from the file
            extension.

    Returns:
        The deserialized model.
    """
    # External data is deliberately not loaded by ONNX: the IR handles external
    # data itself by doing memory mapping directly.
    model = serde.deserialize_model(
        onnx.load(path, format=format, load_external_data=False)
    )
    # Relative external-data paths are resolved against the directory that
    # contains the ONNX file.
    containing_dir = os.path.dirname(path)
    _external_data.set_base_dir(model.graph, containing_dir)
    return model
39
+
40
+
41
def save(
    model: _core.Model,
    path: str | os.PathLike,
    format: str | None = None,
    external_data: str | os.PathLike | None = None,
    size_threshold_bytes: int = 256,
    callback: Callable[[_protocols.TensorProtocol, _external_data.CallbackInfo], None]
    | None = None,
) -> None:
    """Write an ONNX model to a file.

    The model is left unchanged by the call. If an existing external tensor
    references the given ``external_data`` path, that tensor is invalidated
    once the external data file is overwritten. To get a valid model back,
    either :func:`load` the newly saved model or choose an external data path
    that no tensor in the model currently references.

    .. tip::

        A simple progress bar can be implemented by passing a callback function as the following::

            import onnx_ir as ir
            import tqdm

            with tqdm.tqdm() as pbar:
                total_set = False

                def callback(tensor: ir.TensorProtocol, metadata: ir.external_data.CallbackInfo) -> None:
                    nonlocal total_set
                    if not total_set:
                        pbar.total = metadata.total
                        total_set = True

                    pbar.update()
                    pbar.set_description(f"Saving {tensor.name} ({tensor.dtype}, {tensor.shape}) at offset {metadata.offset}")

                ir.save(
                    ...,
                    callback=callback,
                )

    Args:
        model: The model to save.
        path: Destination of the model file, e.g. "model.onnx".
        format: Serialization format (e.g. ``protobuf``, ``textproto``, ``json``, etc.).
            When None, the format is inferred from the file extension.
        external_data: Path, relative to the ONNX file, where external data is
            written. When given, initializers of the model are converted to
            external data stored there. When None, tensors are saved
            unmodified: a tensor that is already external keeps its external
            information, and any other tensor is serialized into the ONNX
            Proto message.
        size_threshold_bytes: Only tensors larger than this many bytes are
            moved to external data. Effective only when ``external_data`` is set.
        callback: Called for each tensor that is saved to external data, for
            debugging or logging purposes.

    Raises:
        ValueError: If the external data path is an absolute path.
    """
    if external_data is None:
        # Nothing to externalize: serialize the model exactly as it is.
        onnx.save(serde.serialize_model(model), path, format=format)
        return

    if os.path.isabs(external_data):
        raise ValueError(
            f"The external data path must be relative to the ONNX file path, not '{external_data}'."
        )
    base_dir = os.path.dirname(path)

    # Snapshot every initializer (subgraphs included) together with its current
    # tensor so the caller's model can be put back exactly as it was.
    initialized_values: list[_core.Value] = [
        value for graph in model.graphs() for value in graph.initializers.values()
    ]
    tensors = [value.const_value for value in initialized_values]

    try:
        model = _external_data.unload_from_model(
            model,
            base_dir,
            external_data,
            size_threshold_bytes=size_threshold_bytes,
            callback=callback,
        )
        onnx.save(serde.serialize_model(model), path, format=format)
    finally:
        # Runs on success and on failure: restore the original tensors so the
        # model is unchanged.
        for initializer, tensor in zip(initialized_values, tensors, strict=True):
            initializer.const_value = tensor