onnx-ir 0.1.15__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (53)
  1. onnx_ir/__init__.py +176 -0
  2. onnx_ir/_cloner.py +229 -0
  3. onnx_ir/_convenience/__init__.py +558 -0
  4. onnx_ir/_convenience/_constructors.py +291 -0
  5. onnx_ir/_convenience/_extractor.py +191 -0
  6. onnx_ir/_core.py +4435 -0
  7. onnx_ir/_display.py +54 -0
  8. onnx_ir/_enums.py +474 -0
  9. onnx_ir/_graph_comparison.py +23 -0
  10. onnx_ir/_graph_containers.py +373 -0
  11. onnx_ir/_io.py +133 -0
  12. onnx_ir/_linked_list.py +284 -0
  13. onnx_ir/_metadata.py +45 -0
  14. onnx_ir/_name_authority.py +72 -0
  15. onnx_ir/_polyfill.py +26 -0
  16. onnx_ir/_protocols.py +627 -0
  17. onnx_ir/_safetensors/__init__.py +510 -0
  18. onnx_ir/_tape.py +242 -0
  19. onnx_ir/_thirdparty/asciichartpy.py +310 -0
  20. onnx_ir/_type_casting.py +89 -0
  21. onnx_ir/_version_utils.py +48 -0
  22. onnx_ir/analysis/__init__.py +21 -0
  23. onnx_ir/analysis/_implicit_usage.py +74 -0
  24. onnx_ir/convenience.py +38 -0
  25. onnx_ir/external_data.py +459 -0
  26. onnx_ir/passes/__init__.py +41 -0
  27. onnx_ir/passes/_pass_infra.py +351 -0
  28. onnx_ir/passes/common/__init__.py +54 -0
  29. onnx_ir/passes/common/_c_api_utils.py +76 -0
  30. onnx_ir/passes/common/clear_metadata_and_docstring.py +60 -0
  31. onnx_ir/passes/common/common_subexpression_elimination.py +207 -0
  32. onnx_ir/passes/common/constant_manipulation.py +230 -0
  33. onnx_ir/passes/common/default_attributes.py +99 -0
  34. onnx_ir/passes/common/identity_elimination.py +120 -0
  35. onnx_ir/passes/common/initializer_deduplication.py +179 -0
  36. onnx_ir/passes/common/inliner.py +223 -0
  37. onnx_ir/passes/common/naming.py +280 -0
  38. onnx_ir/passes/common/onnx_checker.py +57 -0
  39. onnx_ir/passes/common/output_fix.py +141 -0
  40. onnx_ir/passes/common/shape_inference.py +112 -0
  41. onnx_ir/passes/common/topological_sort.py +37 -0
  42. onnx_ir/passes/common/unused_removal.py +215 -0
  43. onnx_ir/py.typed +1 -0
  44. onnx_ir/serde.py +2043 -0
  45. onnx_ir/tape.py +15 -0
  46. onnx_ir/tensor_adapters.py +210 -0
  47. onnx_ir/testing.py +197 -0
  48. onnx_ir/traversal.py +118 -0
  49. onnx_ir-0.1.15.dist-info/METADATA +68 -0
  50. onnx_ir-0.1.15.dist-info/RECORD +53 -0
  51. onnx_ir-0.1.15.dist-info/WHEEL +5 -0
  52. onnx_ir-0.1.15.dist-info/licenses/LICENSE +202 -0
  53. onnx_ir-0.1.15.dist-info/top_level.txt +1 -0
onnx_ir/__init__.py ADDED
@@ -0,0 +1,176 @@
1
+ # Copyright (c) ONNX Project Contributors
2
+ # SPDX-License-Identifier: Apache-2.0
3
+ """In-memory intermediate representation for ONNX graphs."""
4
+
5
# Public API of the ``onnx_ir`` package. ``__set_module`` iterates this list and
# looks each entry up in ``globals()``, so every name listed here must be bound
# in this module (an unbound entry raises ``KeyError`` at import time).
__all__ = [
    # Modules
    "serde",
    "traversal",
    "convenience",
    "external_data",
    "tape",
    # IR classes
    "Tensor",
    "ExternalTensor",
    "StringTensor",
    "LazyTensor",
    "PackedTensor",
    "SymbolicDim",
    "Shape",
    "TensorType",
    "OptionalType",
    "SequenceType",
    "SparseTensorType",
    "TypeAndShape",
    "Value",
    "Attr",
    "RefAttr",
    "Node",
    "Function",
    "Graph",
    "GraphView",
    "Model",
    # Constructors
    "AttrFloat32",
    "AttrFloat32s",
    "AttrGraph",
    "AttrGraphs",
    "AttrInt64",
    "AttrInt64s",
    "AttrSparseTensor",
    "AttrSparseTensors",
    "AttrString",
    "AttrStrings",
    "AttrTensor",
    "AttrTensors",
    "AttrTypeProto",
    "AttrTypeProtos",
    "Input",
    # Protocols
    "ArrayCompatible",
    "DLPackCompatible",
    "TensorProtocol",
    "ValueProtocol",
    "ModelProtocol",
    "NodeProtocol",
    "GraphProtocol",
    "GraphViewProtocol",
    "AttributeProtocol",
    "ReferenceAttributeProtocol",
    "SparseTensorProtocol",
    "SymbolicDimProtocol",
    "ShapeProtocol",
    "TypeProtocol",
    "MapTypeProtocol",
    "FunctionProtocol",
    # Enums
    "AttributeType",
    "DataType",
    # Types
    "OperatorIdentifier",
    # Protobuf compatible types
    "TensorProtoTensor",
    # Conversion functions
    "from_proto",
    "from_onnx_text",
    "to_proto",
    "to_onnx_text",
    # Convenience constructors
    "tensor",
    "node",
    "val",
    # Pass infrastructure
    "passes",
    # IO
    "load",
    "save",
    "save_safetensors",
    # Flags
    "DEBUG",
    # Others
    "set_value_magic_handler",
]
93
+
94
+ import types
95
+
96
+ from onnx_ir import convenience, external_data, passes, serde, tape, traversal
97
+ from onnx_ir._convenience._constructors import node, tensor, val
98
+ from onnx_ir._core import (
99
+ Attr,
100
+ AttrFloat32,
101
+ AttrFloat32s,
102
+ AttrGraph,
103
+ AttrGraphs,
104
+ AttrInt64,
105
+ AttrInt64s,
106
+ AttrSparseTensor,
107
+ AttrSparseTensors,
108
+ AttrString,
109
+ AttrStrings,
110
+ AttrTensor,
111
+ AttrTensors,
112
+ AttrTypeProto,
113
+ AttrTypeProtos,
114
+ ExternalTensor,
115
+ Function,
116
+ Graph,
117
+ GraphView,
118
+ Input,
119
+ LazyTensor,
120
+ Model,
121
+ Node,
122
+ OptionalType,
123
+ PackedTensor,
124
+ RefAttr,
125
+ SequenceType,
126
+ Shape,
127
+ SparseTensorType,
128
+ StringTensor,
129
+ SymbolicDim,
130
+ Tensor,
131
+ TensorType,
132
+ TypeAndShape,
133
+ Value,
134
+ set_value_magic_handler,
135
+ )
136
+ from onnx_ir._enums import (
137
+ AttributeType,
138
+ DataType,
139
+ )
140
+ from onnx_ir._io import load, save
141
+ from onnx_ir._protocols import (
142
+ ArrayCompatible,
143
+ AttributeProtocol,
144
+ DLPackCompatible,
145
+ FunctionProtocol,
146
+ GraphProtocol,
147
+ GraphViewProtocol,
148
+ MapTypeProtocol,
149
+ ModelProtocol,
150
+ NodeProtocol,
151
+ OperatorIdentifier,
152
+ ReferenceAttributeProtocol,
153
+ ShapeProtocol,
154
+ SparseTensorProtocol,
155
+ SymbolicDimProtocol,
156
+ TensorProtocol,
157
+ TypeProtocol,
158
+ ValueProtocol,
159
+ )
160
+ from onnx_ir._safetensors import save_safetensors
161
+ from onnx_ir.serde import TensorProtoTensor, from_onnx_text, from_proto, to_onnx_text, to_proto
162
+
163
# Package-level debug flag, exported through ``__all__``. Nothing in this module
# reads it — NOTE(review): confirm which components consult ``onnx_ir.DEBUG``.
DEBUG: bool = False
164
+
165
+
166
def __set_module() -> None:
    """Make every exported object report this package as its module.

    Walks ``__all__`` and, for each exported object that exposes a
    ``__module__`` attribute, overwrites that attribute with this module's
    ``__name__``. Objects without ``__module__`` (such as submodules) and
    parameterized generic aliases (``types.GenericAlias``) are left untouched.
    Every name in ``__all__`` must be bound in this module; a missing name
    raises ``KeyError``.
    """
    namespace = globals()
    for exported_name in __all__:
        exported = namespace[exported_name]
        if not hasattr(exported, "__module__"):
            continue
        if isinstance(exported, types.GenericAlias):
            continue
        exported.__module__ = __name__
173
+
174
+
175
# Rewrite ``__module__`` on the exported objects so they report this public
# package rather than the private submodules that define them.
__set_module()
# Package version string; matches the 0.1.15 distribution this file ships in.
__version__ = "0.1.15"
onnx_ir/_cloner.py ADDED
@@ -0,0 +1,229 @@
1
+ # Copyright (c) ONNX Project Contributors
2
+ # SPDX-License-Identifier: Apache-2.0
3
+ """Logic for cloning graphs."""
4
+
5
+ from __future__ import annotations
6
+
7
+ import functools
8
+ import typing
9
+ from collections.abc import Callable, Mapping
10
+ from typing import TypeVar
11
+
12
+ from typing_extensions import Concatenate, ParamSpec
13
+
14
+ from onnx_ir import _core, _enums
15
+
16
# Type variables used by ``_capture_error_context`` to keep the signature of the
# wrapped ``Cloner`` methods: ``P`` captures the parameters after ``self`` and
# ``R`` the return type.
P = ParamSpec("P")
R = TypeVar("R")
18
+
19
+
20
+ def _capture_error_context(
21
+ func: Callable[Concatenate[Cloner, P], R],
22
+ ) -> Callable[Concatenate[Cloner, P], R]:
23
+ """Decorator to capture error context during cloning."""
24
+
25
+ @functools.wraps(func)
26
+ def wrapper(self: Cloner, *args: P.args, **kwargs: P.kwargs) -> R:
27
+ try:
28
+ return func(self, *args, **kwargs)
29
+ except Exception as e:
30
+ raise RuntimeError(
31
+ f"In {func.__name__} with args {args!r} and kwargs {kwargs!r}"
32
+ ) from e
33
+
34
+ return wrapper
35
+
36
+
37
class Cloner:
    """Creates copies of IR graphs, nodes and values with optional substitutions.

    One ``value_map`` (original value -> copy) is shared by every method,
    including the recursive calls made for graph-valued attributes. Nodes inside
    a cloned subgraph therefore resolve values already cloned in the enclosing
    graph.
    """

    def __init__(
        self,
        *,
        attr_map: Mapping[str, _core.Attr],
        value_map: dict[_core.Value, _core.Value | None],
        metadata_props: dict[str, str],
        post_process: Callable[[_core.Node], None] = lambda _: None,
        resolve_ref_attrs: bool = False,
        allow_outer_scope_values: bool = False,
    ) -> None:
        """Configure the cloner.

        Args:
            attr_map: Attributes, keyed by name, that reference attributes are
                resolved against (used when inlining functions).
            value_map: Mapping from original values to their clones; it is
                updated in place as values are cloned. Graph inputs and
                initializers missing from it are copied on demand. A node input
                mapped to ``None`` becomes an absent (``None``) input on the
                cloned node.
            metadata_props: Metadata merged into every cloned node. The node's
                own metadata wins on key conflicts.
            post_process: Callback run on each node right after it is cloned.
            resolve_ref_attrs: If True, reference attributes are replaced using
                ``attr_map``. Set this when inlining functions.
            allow_outer_scope_values: If True, node inputs that come from outside
                the graph being cloned are reused as-is. If False (default), such
                inputs are rejected with a ``ValueError``, which the
                error-context decorator wraps in a ``RuntimeError``.
        """
        self._attr_map = attr_map
        self._value_map = value_map
        self._metadata_props = metadata_props
        self._post_process = post_process
        self._resolve_ref_attrs = resolve_ref_attrs
        self._allow_outer_scope_values = allow_outer_scope_values

    @_capture_error_context
    def _get_value(self, value: _core.Value) -> _core.Value | None:
        """Return the clone registered for ``value``; fails if none is registered."""
        mapped = self._value_map[value]
        return mapped

    @_capture_error_context
    def _clone_or_get_value(self, value: _core.Value) -> _core.Value:
        """Return the existing clone of ``value``, or create and register a new one.

        A value without a registered clone is treated as a graph input (or
        initializer) and copied field by field; its ``const_value`` is shared.
        The producer is not checked, because inputs of a ``GraphView`` may
        still have one.
        """
        if value not in self._value_map:
            copied = _core.Value(
                name=value.name,
                type=value.type,
                shape=None if value.shape is None else value.shape.copy(),
                doc_string=value.doc_string,
                const_value=value.const_value,
            )
            if value.metadata_props:
                copied.metadata_props.update(value.metadata_props)
            if value.meta:
                copied.meta.update(value.meta)
            self._value_map[value] = copied
            return copied

        existing = self._value_map[value]
        assert existing is not None, f"BUG: Value {value} mapped to None in value map"
        return existing

    @_capture_error_context
    def clone_attr(self, key: str, attr: _core.Attr) -> _core.Attr | None:
        """Return the attribute to place on a cloned node under ``key``.

        GRAPH and GRAPHS attributes are deep-cloned. Other concrete attributes
        are returned unchanged (shared). Reference attributes are returned
        unchanged unless ``resolve_ref_attrs`` is set. In that case they are
        replaced by the matching ``attr_map`` entry, or dropped (``None``) when
        there is no entry.
        """
        if attr.is_ref():
            if not self._resolve_ref_attrs:
                return attr

            target_name = attr.ref_attr_name
            if target_name is None:
                raise ValueError("Reference attribute must have a name")
            if target_name not in self._attr_map:
                # The function has an attribute-parameter that the call node did
                # not provide. References to it inside the body are removed,
                # which is how ONNX represents optional attributes.
                return None

            resolved = self._attr_map[target_name]
            if resolved.is_ref():
                # Inlining into another function: the call-site attribute is
                # itself a reference to an attribute of the enclosing scope.
                assert resolved.ref_attr_name is not None
                return _core.RefAttr(
                    key, resolved.ref_attr_name, resolved.type, doc_string=resolved.doc_string
                )
            return _core.Attr(key, resolved.type, resolved.value, doc_string=resolved.doc_string)

        if attr.type == _enums.AttributeType.GRAPH:
            return _core.Attr(
                key,
                _enums.AttributeType.GRAPH,
                self.clone_graph(attr.as_graph()),
                doc_string=attr.doc_string,
            )
        if attr.type == _enums.AttributeType.GRAPHS:
            cloned_graphs = list(map(self.clone_graph, attr.as_graphs()))
            return _core.Attr(
                key, _enums.AttributeType.GRAPHS, cloned_graphs, doc_string=attr.doc_string
            )
        return attr

    @_capture_error_context
    def clone_node(self, node: _core.Node) -> _core.Node:
        """Clone ``node`` and register its outputs in the value map.

        Inputs are looked up in the value map. Because nodes are cloned in
        topological order, an input that is still unknown is an outer-scope
        value: it is reused as-is when ``allow_outer_scope_values`` is set, and
        rejected otherwise.
        """
        cloned_inputs: list[_core.Value | None] = []
        for original_input in node.inputs:
            if original_input is None:
                cloned_inputs.append(None)
                continue
            if original_input in self._value_map:
                cloned_inputs.append(self._get_value(original_input))
                continue
            if not self._allow_outer_scope_values:
                graph_name = (
                    (original_input.graph.name or "<anonymous>")
                    if original_input.graph
                    else "<unknown>"
                )
                raise ValueError(
                    f"Value '{original_input}' used by node '{node}' is an outer-scope value (from graph '{graph_name}'), "
                    "but 'allow_outer_scope_values' is set to False. Consider creating a GraphView and add the value to its "
                    "inputs then clone, or setting 'allow_outer_scope_values' to True to allow referencing outer-scope values."
                )
            # Outer-scope values are shared with the original graph, not copied.
            cloned_inputs.append(original_input)

        cloned_attributes = []
        for attr_name, original_attr in node.attributes.items():
            cloned_attr = self.clone_attr(attr_name, original_attr)
            if cloned_attr is not None:
                cloned_attributes.append(cloned_attr)

        # The cloned node's own metadata overrides the call-site metadata on
        # conflicts. TODO: decide whether both should be preserved.
        merged_metadata = dict(self._metadata_props)
        merged_metadata.update(node.metadata_props)

        new_node = _core.Node(
            node.domain,
            node.op_type,
            cloned_inputs,
            cloned_attributes,
            overload=node.overload,
            num_outputs=len(node.outputs),
            version=node.version,
            name=node.name,
            doc_string=node.doc_string,
            metadata_props=merged_metadata,
        )
        if node.meta:
            new_node.meta.update(node.meta)

        # Register each output and copy its properties onto the new output.
        for source, target in zip(node.outputs, new_node.outputs):
            self._value_map[source] = target
            target.name = source.name
            target.shape = None if source.shape is None else source.shape.copy()
            target.type = source.type
            target.const_value = source.const_value
            target.doc_string = source.doc_string
            if source.metadata_props:
                target.metadata_props.update(source.metadata_props)
            if source.meta:
                target.meta.update(source.meta)

        self._post_process(new_node)
        return new_node

    @_capture_error_context
    def clone_graph(self, graph: _core.Graph | _core.GraphView) -> _core.Graph:
        """Clone ``graph``. Tensors (``const_value``) are shared, not copied."""
        cloned_inputs = list(map(self._clone_or_get_value, graph.inputs))
        cloned_initializers = list(map(self._clone_or_get_value, graph.initializers.values()))
        cloned_nodes = list(map(self.clone_node, graph))
        # Outputs must already be registered (as inputs, initializers or node
        # outputs). They are assumed not to map to None.
        cloned_outputs = typing.cast(
            list["_core.Value"], [self._get_value(out) for out in graph.outputs]
        )

        new_graph = _core.Graph(
            cloned_inputs,
            cloned_outputs,
            nodes=cloned_nodes,
            initializers=cloned_initializers,
            doc_string=graph.doc_string,
            opset_imports=graph.opset_imports.copy(),
            name=graph.name,
        )
        if graph.metadata_props:
            new_graph.metadata_props.update(graph.metadata_props)
        if graph.meta:
            new_graph.meta.update(graph.meta)
        return new_graph