onnx-ir 0.1.15__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (53) hide show
  1. onnx_ir/__init__.py +176 -0
  2. onnx_ir/_cloner.py +229 -0
  3. onnx_ir/_convenience/__init__.py +558 -0
  4. onnx_ir/_convenience/_constructors.py +291 -0
  5. onnx_ir/_convenience/_extractor.py +191 -0
  6. onnx_ir/_core.py +4435 -0
  7. onnx_ir/_display.py +54 -0
  8. onnx_ir/_enums.py +474 -0
  9. onnx_ir/_graph_comparison.py +23 -0
  10. onnx_ir/_graph_containers.py +373 -0
  11. onnx_ir/_io.py +133 -0
  12. onnx_ir/_linked_list.py +284 -0
  13. onnx_ir/_metadata.py +45 -0
  14. onnx_ir/_name_authority.py +72 -0
  15. onnx_ir/_polyfill.py +26 -0
  16. onnx_ir/_protocols.py +627 -0
  17. onnx_ir/_safetensors/__init__.py +510 -0
  18. onnx_ir/_tape.py +242 -0
  19. onnx_ir/_thirdparty/asciichartpy.py +310 -0
  20. onnx_ir/_type_casting.py +89 -0
  21. onnx_ir/_version_utils.py +48 -0
  22. onnx_ir/analysis/__init__.py +21 -0
  23. onnx_ir/analysis/_implicit_usage.py +74 -0
  24. onnx_ir/convenience.py +38 -0
  25. onnx_ir/external_data.py +459 -0
  26. onnx_ir/passes/__init__.py +41 -0
  27. onnx_ir/passes/_pass_infra.py +351 -0
  28. onnx_ir/passes/common/__init__.py +54 -0
  29. onnx_ir/passes/common/_c_api_utils.py +76 -0
  30. onnx_ir/passes/common/clear_metadata_and_docstring.py +60 -0
  31. onnx_ir/passes/common/common_subexpression_elimination.py +207 -0
  32. onnx_ir/passes/common/constant_manipulation.py +230 -0
  33. onnx_ir/passes/common/default_attributes.py +99 -0
  34. onnx_ir/passes/common/identity_elimination.py +120 -0
  35. onnx_ir/passes/common/initializer_deduplication.py +179 -0
  36. onnx_ir/passes/common/inliner.py +223 -0
  37. onnx_ir/passes/common/naming.py +280 -0
  38. onnx_ir/passes/common/onnx_checker.py +57 -0
  39. onnx_ir/passes/common/output_fix.py +141 -0
  40. onnx_ir/passes/common/shape_inference.py +112 -0
  41. onnx_ir/passes/common/topological_sort.py +37 -0
  42. onnx_ir/passes/common/unused_removal.py +215 -0
  43. onnx_ir/py.typed +1 -0
  44. onnx_ir/serde.py +2043 -0
  45. onnx_ir/tape.py +15 -0
  46. onnx_ir/tensor_adapters.py +210 -0
  47. onnx_ir/testing.py +197 -0
  48. onnx_ir/traversal.py +118 -0
  49. onnx_ir-0.1.15.dist-info/METADATA +68 -0
  50. onnx_ir-0.1.15.dist-info/RECORD +53 -0
  51. onnx_ir-0.1.15.dist-info/WHEEL +5 -0
  52. onnx_ir-0.1.15.dist-info/licenses/LICENSE +202 -0
  53. onnx_ir-0.1.15.dist-info/top_level.txt +1 -0
onnx_ir/_linked_list.py ADDED
@@ -0,0 +1,284 @@
1
+ # Copyright (c) ONNX Project Contributors
2
+ # SPDX-License-Identifier: Apache-2.0
3
+ """Mutable list for nodes in a graph with safe mutation properties."""
4
+
5
+ from __future__ import annotations
6
+
7
+ from collections.abc import Iterable, Iterator, Sequence
8
+ from typing import Generic, TypeVar, overload
9
+
10
+ T = TypeVar("T")
11
+
12
+
13
+ class _LinkBox(Generic[T]):
14
+ """A link in a doubly linked list that has a reference to the actual object in the link.
15
+
16
+ The :class:`_LinkBox` is a container for the actual object in the list. It is used to
17
+ maintain the links between the elements in the linked list. The actual object is stored in the
18
+ :attr:`value` attribute.
19
+
20
+ By using a separate container for the actual object, we can safely remove the object from the
21
+ list without losing the links. This allows us to remove the object from the list during
22
+ iteration and place the object into a different list without breaking any chains.
23
+
24
+ This is an internal class and should only be initialized by the :class:`DoublyLinkedSet`.
25
+
26
+ Attributes:
27
+ prev: The previous box in the list.
28
+ next: The next box in the list.
29
+ erased: A flag to indicate if the box has been removed from the list.
30
+ owning_list: The :class:`DoublyLinkedSet` to which the box belongs.
31
+ value: The actual object in the list.
32
+ """
33
+
34
+ __slots__ = ("next", "owning_list", "prev", "value")
35
+
36
+ def __init__(self, owner: DoublyLinkedSet[T], value: T | None) -> None:
37
+ """Create a new link box.
38
+
39
+ Args:
40
+ owner: The linked list to which this box belongs.
41
+ value: The value to be stored in the link box. When the value is None,
42
+ the link box is considered erased (default). The root box of the list
43
+ should be created with a None value.
44
+ """
45
+ self.prev: _LinkBox[T] = self
46
+ self.next: _LinkBox[T] = self
47
+ self.value: T | None = value
48
+ self.owning_list: DoublyLinkedSet[T] = owner
49
+
50
+ @property
51
+ def erased(self) -> bool:
52
+ return self.value is None
53
+
54
+ def erase(self) -> None:
55
+ """Remove the link from the list and detach the value from the box."""
56
+ if self.value is None:
57
+ raise ValueError("_LinkBox is already erased")
58
+ # Update the links
59
+ prev, next_ = self.prev, self.next
60
+ prev.next, next_.prev = next_, prev
61
+ # Detach the value
62
+ self.value = None
63
+
64
+ def __repr__(self) -> str:
65
+ return f"_LinkBox({self.value!r}, erased={self.erased}, prev={self.prev.value!r}, next={self.next.value!r})"
66
+
67
+
68
class DoublyLinkedSet(Sequence[T], Generic[T]):
    """A doubly linked ordered set of nodes.

    The container can be viewed as a set as it does not allow duplicate values. The order of the
    elements is maintained. One can typically treat it as a doubly linked list with list-like
    methods implemented.

    Adding and removing elements from the set during iteration is safe. Moving elements
    from one set to another is also safe.

    During the iteration:
    - If new elements are inserted after the current node, the iterator will
        iterate over them as well.
    - If new elements are inserted before the current node, they will
        not be iterated over in this iteration.
    - If the current node is lifted and inserted in a different location,
        iteration will start from the "next" node at the _original_ location.

    Time complexity:
        Inserting and removing nodes from the set is O(1). Accessing nodes by index is O(n),
        although accessing nodes at either end of the set is O(1). I.e.
        ``linked_set[0]`` and ``linked_set[-1]`` are O(1). ``value in linked_set`` is O(1)
        when the value is stored in the set; a miss falls back to an O(n) scan.

    Membership is tracked by object identity (``id(value)``), not by hashing or equality:
    values do not need to be hashable, and two equal-but-distinct objects can both be
    stored. ``None`` is not a valid value in the set.
    """

    __slots__ = ("_length", "_root", "_value_ids_to_boxes")

    def __init__(self, values: Iterable[T] | None = None) -> None:
        # The list is circular and anchored by a sentinel root box whose value is None.
        # The root is never one of the list's values; having it removes the empty-list /
        # head / tail special cases from the mutation code.
        root_ = _LinkBox(self, None)
        self._root: _LinkBox[T] = root_
        self._length = 0
        # Maps id(value) -> the live box holding that value.
        self._value_ids_to_boxes: dict[int, _LinkBox[T]] = {}
        if values is not None:
            self.extend(values)

    def __iter__(self) -> Iterator[T]:
        """Iterate over the elements in the list.

        - If new elements are inserted after the current node, the iterator will
            iterate over them as well.
        - If new elements are inserted before the current node, they will
            not be iterated over in this iteration.
        - If the current node is lifted and inserted in a different location,
            iteration will start from the "next" node at the _original_ location.
        """
        box = self._root.next
        while box is not self._root:
            if box.owning_list is not self:
                raise RuntimeError(f"Element {box!r} is not in the list")
            # Erased boxes are still reachable through a box the iterator was parked on;
            # skip them but keep following their (preserved) next pointers.
            if not box.erased:
                assert box.value is not None
                yield box.value
            box = box.next

    def __reversed__(self) -> Iterator[T]:
        """Iterate over the elements in the list in reverse order."""
        box = self._root.prev
        while box is not self._root:
            # Same consistency check as __iter__.
            if box.owning_list is not self:
                raise RuntimeError(f"Element {box!r} is not in the list")
            if not box.erased:
                assert box.value is not None
                yield box.value
            box = box.prev

    def __len__(self) -> int:
        assert self._length == len(self._value_ids_to_boxes), (
            "Bug in the implementation: length mismatch"
        )
        return self._length

    def __contains__(self, value: object) -> bool:
        """Return whether ``value`` is in the set.

        A value stored in the set is found in O(1) through the identity map. Otherwise this
        falls back to the inherited :meth:`Sequence.__contains__` scan (comparing with
        ``is`` and ``==``), so the result matches the inherited implementation.
        """
        if id(value) in self._value_ids_to_boxes:
            return True
        return super().__contains__(value)

    @overload
    def __getitem__(self, index: int) -> T: ...
    @overload
    def __getitem__(self, index: slice) -> Sequence[T]: ...

    def __getitem__(self, index):
        """Get the node at the given index.

        Complexity is O(n). Negative indices walk backwards from the end, so
        ``linked_set[0]`` and ``linked_set[-1]`` are O(1). Slices return a tuple.

        Raises:
            IndexError: If ``index`` is outside ``[-len(self), len(self))``.
        """
        if isinstance(index, slice):
            return tuple(self)[index]
        if index >= self._length or index < -self._length:
            raise IndexError(
                f"Index out of range: {index} not in range [-{self._length}, {self._length})"
            )
        if index < 0:
            # Look up from the end of the list
            iterator: Iterator[T] = reversed(self)
            steps = -index - 1
        else:
            iterator = iter(self)
            steps = index
        for _ in range(steps):
            next(iterator)
        return next(iterator)

    def index(self, value: object, start: int = 0, stop: int | None = None) -> int:
        """Return the index of the first element that ``is`` or ``==`` ``value``.

        Same semantics as :meth:`Sequence.index` (including negative ``start``/``stop``),
        but done in a single O(n) pass. The inherited version calls the O(n)
        ``__getitem__`` once per position, which is O(n^2).

        Raises:
            ValueError: If no matching element is found in ``[start, stop)``.
        """
        length = len(self)
        if start is not None and start < 0:
            start = max(length + start, 0)
        if stop is not None and stop < 0:
            stop += length
        for position, item in enumerate(self):
            if stop is not None and position >= stop:
                break
            if position >= start and (item is value or item == value):
                return position
        raise ValueError(f"{value!r} is not in the list")

    def _insert_one_after(
        self,
        box: _LinkBox[T],
        new_value: T,
    ) -> _LinkBox[T]:
        """Insert a new value after the given box.

        All insertion methods should call this method to ensure that the list is updated correctly.

        Example::
            Before: A  <->  B  <->  C
                     ^v0     ^v1     ^v2
            Call: _insert_one_after(B, v3)
            After:  A  <->  B  <->  new_box  <->  C
                     ^v0     ^v1       ^v3         ^v2

        If ``new_value`` is already in the list it is first removed, so the call moves it.
        Inserting a value right after its own box is a no-op.

        Args:
            box: The box after which the new value is to be inserted.
            new_value: The new value to be inserted.

        Returns:
            The box now holding ``new_value``.

        Raises:
            TypeError: If ``new_value`` is None.
            ValueError: If ``box`` does not belong to this list.
        """
        if new_value is None:
            raise TypeError(f"{self.__class__.__name__} does not support None values")
        if box.value is new_value:
            # Do nothing if the new value is the same as the old value
            return box
        if box.owning_list is not self:
            raise ValueError(f"Value {box.value!r} is not in the list")

        if (new_value_id := id(new_value)) in self._value_ids_to_boxes:
            # If the value is already in the list, remove it first (this turns the
            # insertion into a move and keeps the set free of duplicates).
            self.remove(new_value)

        # Create a new _LinkBox for the new value
        new_box = _LinkBox(self, new_value)
        # original_box <=> original_next
        # becomes
        # original_box <=> new_box <=> original_next
        original_next = box.next
        box.next = new_box
        new_box.prev = box
        new_box.next = original_next
        original_next.prev = new_box

        # Be sure to update the length and mapping
        self._length += 1
        self._value_ids_to_boxes[new_value_id] = new_box

        return new_box

    def _insert_many_after(
        self,
        box: _LinkBox[T],
        new_values: Iterable[T],
    ) -> None:
        """Insert multiple new values after the given box, preserving their order."""
        insertion_point = box
        for new_value in new_values:
            # Each value goes after the previously inserted one.
            insertion_point = self._insert_one_after(insertion_point, new_value)

    def remove(self, value: T) -> None:
        """Remove a node from the list.

        Raises:
            ValueError: If ``value`` (by identity) is not in the list.
        """
        if (value_id := id(value)) not in self._value_ids_to_boxes:
            raise ValueError(f"Value {value!r} is not in the list")
        box = self._value_ids_to_boxes[value_id]
        # Remove the link box and detach the value from the box
        box.erase()

        # Be sure to update the length and mapping
        self._length -= 1
        del self._value_ids_to_boxes[value_id]

    def append(self, value: T) -> None:
        """Append a node to the list (moving it to the end if already present)."""
        _ = self._insert_one_after(self._root.prev, value)

    def extend(
        self,
        values: Iterable[T],
    ) -> None:
        """Append each of ``values`` in order."""
        for value in values:
            self.append(value)

    def insert_after(
        self,
        value: T,
        new_values: Iterable[T],
    ) -> None:
        """Insert new nodes after the given node.

        Args:
            value: The value after which the new values are to be inserted.
            new_values: The new values to be inserted.

        Raises:
            ValueError: If ``value`` is not in the list.
        """
        if (value_id := id(value)) not in self._value_ids_to_boxes:
            raise ValueError(f"Value {value!r} is not in the list")
        insertion_point = self._value_ids_to_boxes[value_id]
        self._insert_many_after(insertion_point, new_values)

    def insert_before(
        self,
        value: T,
        new_values: Iterable[T],
    ) -> None:
        """Insert new nodes before the given node.

        Args:
            value: The value before which the new values are to be inserted.
            new_values: The new values to be inserted.

        Raises:
            ValueError: If ``value`` is not in the list.
        """
        if (value_id := id(value)) not in self._value_ids_to_boxes:
            raise ValueError(f"Value {value!r} is not in the list")
        # Inserting before a box == inserting after its predecessor (possibly the root).
        insertion_point = self._value_ids_to_boxes[value_id].prev
        self._insert_many_after(insertion_point, new_values)

    def __repr__(self) -> str:
        return f"DoublyLinkedSet({list(self)})"
onnx_ir/_metadata.py ADDED
@@ -0,0 +1,45 @@
1
+ # Copyright (c) ONNX Project Contributors
2
+ # SPDX-License-Identifier: Apache-2.0
3
+ """Class for storing metadata about the IR objects."""
4
+
5
+ from __future__ import annotations
6
+
7
+ import collections
8
+ from collections.abc import Mapping
9
+ from typing import Any
10
+
11
+
12
class MetadataStore(collections.UserDict):
    """Class for storing metadata about the IR objects.

    Metadata is stored as key-value pairs. The keys are strings and the values
    can be any Python object.

    The metadata store also supports marking keys as invalid. This is useful
    when a pass wants to mark a key that needs to be recomputed. Assigning a
    value to a key (``store[key] = value``) marks that key as valid again.
    """

    def __init__(self, data: Mapping[str, Any] | None = None, /) -> None:
        # The invalid-key set must exist before ``UserDict.__init__`` runs: it fills
        # the store via ``self.update(data)``, which calls our ``__setitem__``, and
        # that method touches ``self._invalid_keys``.
        self._invalid_keys: set[str] = set()
        super().__init__(data)

    def __setitem__(self, key: str, item: Any) -> None:
        self.data[key] = item
        # A freshly assigned value is considered up to date.
        self._invalid_keys.discard(key)

    def __copy__(self) -> MetadataStore:
        # The default shallow copy copies ``__dict__`` shallowly, so the copy would
        # share ``_invalid_keys`` with this store; ``UserDict.copy`` then re-assigns
        # every key on the copy, which would clear the *original's* invalid marks.
        # Give the copy its own data dict and its own invalid-key set instead.
        inst = self.__class__.__new__(self.__class__)
        inst.__dict__.update(self.__dict__)
        inst.__dict__["data"] = self.data.copy()
        inst.__dict__["_invalid_keys"] = set(self._invalid_keys)
        return inst

    def copy(self) -> MetadataStore:
        """Return a shallow copy that keeps an independent record of invalid keys."""
        return self.__copy__()

    def invalidate(self, key: str) -> None:
        """Mark ``key`` as invalid (needing recomputation).

        The key does not have to be present in the store.
        """
        self._invalid_keys.add(key)

    def is_valid(self, key: str) -> bool:
        """Returns whether the value is valid.

        Note that default values (None) are not necessarily invalid. For example,
        a shape that is unknown (None) may be still valid if shape inference has
        determined that the shape is unknown.

        Whether a value is valid is solely determined by the user that sets the value.
        """
        return key not in self._invalid_keys

    def __repr__(self) -> str:
        return f"{self.__class__.__name__}({self.data!r}, invalid_keys={self._invalid_keys!r})"
onnx_ir/_name_authority.py ADDED
@@ -0,0 +1,72 @@
1
+ # Copyright (c) ONNX Project Contributors
2
+ # SPDX-License-Identifier: Apache-2.0
3
+ """Auxiliary class for managing names in the IR."""
4
+
5
+ from __future__ import annotations
6
+
7
+ from onnx_ir import _core
8
+
9
+
10
class NameAuthority:
    """Hands out unique names for values and nodes in the IR.

    Generated value names look like ``val_{value_counter}`` and generated node names
    look like ``node_{op_type}_{node_counter}``. A counter advances every time a
    candidate name is produced, and candidates that are already known are skipped.

    Every name the authority generates, and every pre-existing name it is asked to
    register, is remembered so that no later generated name duplicates it.

    .. note::
        A remembered name is never released, even if its node/value is later removed
        from the graph. Making such names reusable would require tracking which names
        are no longer in use, which is not implemented yet.

        Objects that already carry a name keep it unchanged when added to the graph.
        Ensuring that such user-provided names are unique is the caller's job
        (typically by running a name-fixing pass on the graph).

    TODO(justinchuby): Describe the pass when we have a reference implementation.
    """

    def __init__(self):
        # Names seen so far (generated or registered), per kind.
        self._value_names: set[str] = set()
        self._node_names: set[str] = set()
        # Monotonic counters used as the numeric suffix of generated names.
        self._value_counter = 0
        self._node_counter = 0

    def _unique_value_name(self) -> str:
        """Return the next ``val_<n>`` name that is not among the known value names."""
        while (candidate := f"val_{self._value_counter}") in self._value_names:
            self._value_counter += 1
        # The returned candidate also consumes a counter slot.
        self._value_counter += 1
        return candidate

    def _unique_node_name(self, op_type: str) -> str:
        """Return the next ``node_<op_type>_<n>`` name not among the known node names."""
        while (candidate := f"node_{op_type}_{self._node_counter}") in self._node_names:
            self._node_counter += 1
        self._node_counter += 1
        return candidate

    def register_or_name_value(self, value: _core.Value) -> None:
        """Name ``value`` if it is unnamed, then remember its name.

        A value that already has a name keeps it; the name is only recorded so that
        generated names will avoid it.
        """
        # TODO(justinchuby): Record names of the initializers and graph inputs
        unnamed = value.name is None
        if unnamed:
            value.name = self._unique_value_name()
        # Existing names are never rewritten: deciding when a name is no longer in use
        # is costly once nodes can be removed, and unused names cannot be reserved
        # because users may want to use them.
        self._value_names.add(value.name)

    def register_or_name_node(self, node: _core.Node) -> None:
        """Name ``node`` (using its ``op_type``) if it is unnamed, then remember its name.

        A node that already has a name keeps it; the name is only recorded so that
        generated names will avoid it.
        """
        unnamed = node.name is None
        if unnamed:
            node.name = self._unique_node_name(node.op_type)
        # Existing names are never rewritten: deciding when a name is no longer in use
        # is costly once nodes can be removed, and unused names cannot be reserved
        # because users may want to use them.
        self._node_names.add(node.name)
onnx_ir/_polyfill.py ADDED
@@ -0,0 +1,26 @@
1
+ # Copyright (c) ONNX Project Contributors
2
+ # SPDX-License-Identifier: Apache-2.0
3
+ """Polyfill for Python builtin functions."""
4
+
5
+ import sys
6
+ from collections.abc import Sequence
7
+ from typing import Any
8
+
9
if sys.version_info < (3, 10):
    # The builtin zip only gained ``strict=True`` in Python 3.10, so emulate it here.
    # TODO: Remove this polyfill when we drop support for Python 3.9
    _python_zip = zip

    def zip(a: Sequence[Any], b: Sequence[Any], strict: bool = False):
        """Polyfill for Python's zip function.

        This is a special version which only supports two Sequence inputs. The length
        check is done eagerly, when the function is called.

        Raises:
            ValueError: If the sequences have different lengths and strict is True.
        """
        lengths_differ = len(a) != len(b)
        if lengths_differ and strict:
            raise ValueError("zip() argument lengths must be equal")
        return _python_zip(a, b)

else:
    # The builtin already supports ``strict``; re-export it from this module.
    zip = zip  # pylint: disable=self-assigning-variable