onnx-ir 0.1.15__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- onnx_ir/__init__.py +176 -0
- onnx_ir/_cloner.py +229 -0
- onnx_ir/_convenience/__init__.py +558 -0
- onnx_ir/_convenience/_constructors.py +291 -0
- onnx_ir/_convenience/_extractor.py +191 -0
- onnx_ir/_core.py +4435 -0
- onnx_ir/_display.py +54 -0
- onnx_ir/_enums.py +474 -0
- onnx_ir/_graph_comparison.py +23 -0
- onnx_ir/_graph_containers.py +373 -0
- onnx_ir/_io.py +133 -0
- onnx_ir/_linked_list.py +284 -0
- onnx_ir/_metadata.py +45 -0
- onnx_ir/_name_authority.py +72 -0
- onnx_ir/_polyfill.py +26 -0
- onnx_ir/_protocols.py +627 -0
- onnx_ir/_safetensors/__init__.py +510 -0
- onnx_ir/_tape.py +242 -0
- onnx_ir/_thirdparty/asciichartpy.py +310 -0
- onnx_ir/_type_casting.py +89 -0
- onnx_ir/_version_utils.py +48 -0
- onnx_ir/analysis/__init__.py +21 -0
- onnx_ir/analysis/_implicit_usage.py +74 -0
- onnx_ir/convenience.py +38 -0
- onnx_ir/external_data.py +459 -0
- onnx_ir/passes/__init__.py +41 -0
- onnx_ir/passes/_pass_infra.py +351 -0
- onnx_ir/passes/common/__init__.py +54 -0
- onnx_ir/passes/common/_c_api_utils.py +76 -0
- onnx_ir/passes/common/clear_metadata_and_docstring.py +60 -0
- onnx_ir/passes/common/common_subexpression_elimination.py +207 -0
- onnx_ir/passes/common/constant_manipulation.py +230 -0
- onnx_ir/passes/common/default_attributes.py +99 -0
- onnx_ir/passes/common/identity_elimination.py +120 -0
- onnx_ir/passes/common/initializer_deduplication.py +179 -0
- onnx_ir/passes/common/inliner.py +223 -0
- onnx_ir/passes/common/naming.py +280 -0
- onnx_ir/passes/common/onnx_checker.py +57 -0
- onnx_ir/passes/common/output_fix.py +141 -0
- onnx_ir/passes/common/shape_inference.py +112 -0
- onnx_ir/passes/common/topological_sort.py +37 -0
- onnx_ir/passes/common/unused_removal.py +215 -0
- onnx_ir/py.typed +1 -0
- onnx_ir/serde.py +2043 -0
- onnx_ir/tape.py +15 -0
- onnx_ir/tensor_adapters.py +210 -0
- onnx_ir/testing.py +197 -0
- onnx_ir/traversal.py +118 -0
- onnx_ir-0.1.15.dist-info/METADATA +68 -0
- onnx_ir-0.1.15.dist-info/RECORD +53 -0
- onnx_ir-0.1.15.dist-info/WHEEL +5 -0
- onnx_ir-0.1.15.dist-info/licenses/LICENSE +202 -0
- onnx_ir-0.1.15.dist-info/top_level.txt +1 -0
onnx_ir/__init__.py
ADDED
|
@@ -0,0 +1,176 @@
|
|
|
1
|
+
# Copyright (c) ONNX Project Contributors
|
|
2
|
+
# SPDX-License-Identifier: Apache-2.0
|
|
3
|
+
"""In-memory intermediate representation for ONNX graphs."""
|
|
4
|
+
|
|
5
|
+
__all__ = [
|
|
6
|
+
# Modules
|
|
7
|
+
"serde",
|
|
8
|
+
"traversal",
|
|
9
|
+
"convenience",
|
|
10
|
+
"external_data",
|
|
11
|
+
"tape",
|
|
12
|
+
# IR classes
|
|
13
|
+
"Tensor",
|
|
14
|
+
"ExternalTensor",
|
|
15
|
+
"StringTensor",
|
|
16
|
+
"LazyTensor",
|
|
17
|
+
"PackedTensor",
|
|
18
|
+
"SymbolicDim",
|
|
19
|
+
"Shape",
|
|
20
|
+
"TensorType",
|
|
21
|
+
"OptionalType",
|
|
22
|
+
"SequenceType",
|
|
23
|
+
"SparseTensorType",
|
|
24
|
+
"TypeAndShape",
|
|
25
|
+
"Value",
|
|
26
|
+
"Attr",
|
|
27
|
+
"RefAttr",
|
|
28
|
+
"Node",
|
|
29
|
+
"Function",
|
|
30
|
+
"Graph",
|
|
31
|
+
"GraphView",
|
|
32
|
+
"Model",
|
|
33
|
+
# Constructors
|
|
34
|
+
"AttrFloat32",
|
|
35
|
+
"AttrFloat32s",
|
|
36
|
+
"AttrGraph",
|
|
37
|
+
"AttrGraphs",
|
|
38
|
+
"AttrInt64",
|
|
39
|
+
"AttrInt64s",
|
|
40
|
+
"AttrSparseTensor",
|
|
41
|
+
"AttrSparseTensors",
|
|
42
|
+
"AttrString",
|
|
43
|
+
"AttrStrings",
|
|
44
|
+
"AttrTensor",
|
|
45
|
+
"AttrTensors",
|
|
46
|
+
"AttrTypeProto",
|
|
47
|
+
"AttrTypeProtos",
|
|
48
|
+
"Input",
|
|
49
|
+
# Protocols
|
|
50
|
+
"ArrayCompatible",
|
|
51
|
+
"DLPackCompatible",
|
|
52
|
+
"TensorProtocol",
|
|
53
|
+
"ValueProtocol",
|
|
54
|
+
"ModelProtocol",
|
|
55
|
+
"NodeProtocol",
|
|
56
|
+
"GraphProtocol",
|
|
57
|
+
"GraphViewProtocol",
|
|
58
|
+
"AttributeProtocol",
|
|
59
|
+
"ReferenceAttributeProtocol",
|
|
60
|
+
"SparseTensorProtocol",
|
|
61
|
+
"SymbolicDimProtocol",
|
|
62
|
+
"ShapeProtocol",
|
|
63
|
+
"TypeProtocol",
|
|
64
|
+
"MapTypeProtocol",
|
|
65
|
+
"FunctionProtocol",
|
|
66
|
+
# Enums
|
|
67
|
+
"AttributeType",
|
|
68
|
+
"DataType",
|
|
69
|
+
# Types
|
|
70
|
+
"OperatorIdentifier",
|
|
71
|
+
# Protobuf compatible types
|
|
72
|
+
"TensorProtoTensor",
|
|
73
|
+
# Conversion functions
|
|
74
|
+
"from_proto",
|
|
75
|
+
"from_onnx_text",
|
|
76
|
+
"to_proto",
|
|
77
|
+
"to_onnx_text",
|
|
78
|
+
# Convenience constructors
|
|
79
|
+
"tensor",
|
|
80
|
+
"node",
|
|
81
|
+
"val",
|
|
82
|
+
# Pass infrastructure
|
|
83
|
+
"passes",
|
|
84
|
+
# IO
|
|
85
|
+
"load",
|
|
86
|
+
"save",
|
|
87
|
+
"save_safetensors",
|
|
88
|
+
# Flags
|
|
89
|
+
"DEBUG",
|
|
90
|
+
# Others
|
|
91
|
+
"set_value_magic_handler",
|
|
92
|
+
]
|
|
93
|
+
|
|
94
|
+
import types
|
|
95
|
+
|
|
96
|
+
from onnx_ir import convenience, external_data, passes, serde, tape, traversal
|
|
97
|
+
from onnx_ir._convenience._constructors import node, tensor, val
|
|
98
|
+
from onnx_ir._core import (
|
|
99
|
+
Attr,
|
|
100
|
+
AttrFloat32,
|
|
101
|
+
AttrFloat32s,
|
|
102
|
+
AttrGraph,
|
|
103
|
+
AttrGraphs,
|
|
104
|
+
AttrInt64,
|
|
105
|
+
AttrInt64s,
|
|
106
|
+
AttrSparseTensor,
|
|
107
|
+
AttrSparseTensors,
|
|
108
|
+
AttrString,
|
|
109
|
+
AttrStrings,
|
|
110
|
+
AttrTensor,
|
|
111
|
+
AttrTensors,
|
|
112
|
+
AttrTypeProto,
|
|
113
|
+
AttrTypeProtos,
|
|
114
|
+
ExternalTensor,
|
|
115
|
+
Function,
|
|
116
|
+
Graph,
|
|
117
|
+
GraphView,
|
|
118
|
+
Input,
|
|
119
|
+
LazyTensor,
|
|
120
|
+
Model,
|
|
121
|
+
Node,
|
|
122
|
+
OptionalType,
|
|
123
|
+
PackedTensor,
|
|
124
|
+
RefAttr,
|
|
125
|
+
SequenceType,
|
|
126
|
+
Shape,
|
|
127
|
+
SparseTensorType,
|
|
128
|
+
StringTensor,
|
|
129
|
+
SymbolicDim,
|
|
130
|
+
Tensor,
|
|
131
|
+
TensorType,
|
|
132
|
+
TypeAndShape,
|
|
133
|
+
Value,
|
|
134
|
+
set_value_magic_handler,
|
|
135
|
+
)
|
|
136
|
+
from onnx_ir._enums import (
|
|
137
|
+
AttributeType,
|
|
138
|
+
DataType,
|
|
139
|
+
)
|
|
140
|
+
from onnx_ir._io import load, save
|
|
141
|
+
from onnx_ir._protocols import (
|
|
142
|
+
ArrayCompatible,
|
|
143
|
+
AttributeProtocol,
|
|
144
|
+
DLPackCompatible,
|
|
145
|
+
FunctionProtocol,
|
|
146
|
+
GraphProtocol,
|
|
147
|
+
GraphViewProtocol,
|
|
148
|
+
MapTypeProtocol,
|
|
149
|
+
ModelProtocol,
|
|
150
|
+
NodeProtocol,
|
|
151
|
+
OperatorIdentifier,
|
|
152
|
+
ReferenceAttributeProtocol,
|
|
153
|
+
ShapeProtocol,
|
|
154
|
+
SparseTensorProtocol,
|
|
155
|
+
SymbolicDimProtocol,
|
|
156
|
+
TensorProtocol,
|
|
157
|
+
TypeProtocol,
|
|
158
|
+
ValueProtocol,
|
|
159
|
+
)
|
|
160
|
+
from onnx_ir._safetensors import save_safetensors
|
|
161
|
+
from onnx_ir.serde import TensorProtoTensor, from_onnx_text, from_proto, to_onnx_text, to_proto
|
|
162
|
+
|
|
163
|
+
DEBUG = False
|
|
164
|
+
|
|
165
|
+
|
|
166
|
+
def __set_module() -> None:
    """Set the module of all functions in this module to this public module."""
    module_globals = globals()
    for public_name in __all__:
        exported = module_globals[public_name]
        # types.GenericAlias objects are deliberately excluded; everything else
        # that carries a __module__ attribute is re-homed to this package so it
        # displays as part of the public API.
        if isinstance(exported, types.GenericAlias):
            continue
        if hasattr(exported, "__module__"):
            exported.__module__ = __name__
|
|
173
|
+
|
|
174
|
+
|
|
175
|
+
__set_module()
|
|
176
|
+
__version__ = "0.1.15"
|
onnx_ir/_cloner.py
ADDED
|
@@ -0,0 +1,229 @@
|
|
|
1
|
+
# Copyright (c) ONNX Project Contributors
|
|
2
|
+
# SPDX-License-Identifier: Apache-2.0
|
|
3
|
+
"""Logic for cloning graphs."""
|
|
4
|
+
|
|
5
|
+
from __future__ import annotations
|
|
6
|
+
|
|
7
|
+
import functools
|
|
8
|
+
import typing
|
|
9
|
+
from collections.abc import Callable, Mapping
|
|
10
|
+
from typing import TypeVar
|
|
11
|
+
|
|
12
|
+
from typing_extensions import Concatenate, ParamSpec
|
|
13
|
+
|
|
14
|
+
from onnx_ir import _core, _enums
|
|
15
|
+
|
|
16
|
+
P = ParamSpec("P")
|
|
17
|
+
R = TypeVar("R")
|
|
18
|
+
|
|
19
|
+
|
|
20
|
+
def _capture_error_context(
    func: Callable[Concatenate[Cloner, P], R],
) -> Callable[Concatenate[Cloner, P], R]:
    """Decorator to capture error context during cloning.

    Any exception escaping ``func`` is re-raised as a ``RuntimeError`` chained
    to the original, recording the failing method name and its arguments.
    """

    @functools.wraps(func)
    def wrapped(self: Cloner, *args: P.args, **kwargs: P.kwargs) -> R:
        try:
            result = func(self, *args, **kwargs)
        except Exception as exc:
            context = f"In {func.__name__} with args {args!r} and kwargs {kwargs!r}"
            raise RuntimeError(context) from exc
        return result

    return wrapped
|
|
35
|
+
|
|
36
|
+
|
|
37
|
+
class Cloner:
    """Utilities for creating a copy of IR objects with substitutions for attributes/input values.

    The cloner walks nodes and (sub)graphs, producing new ``_core`` objects
    while recording original->clone pairs in the shared value map so that
    cloned nodes are wired to the cloned counterparts of their original inputs.
    """

    def __init__(
        self,
        *,
        attr_map: Mapping[str, _core.Attr],
        value_map: dict[_core.Value, _core.Value | None],
        metadata_props: dict[str, str],
        post_process: Callable[[_core.Node], None] = lambda _: None,
        resolve_ref_attrs: bool = False,
        allow_outer_scope_values: bool = False,
    ) -> None:
        """Initializes the cloner.

        Args:
            attr_map: A mapping from attribute names to attributes to substitute, used when
                inlining functions.
            value_map: A mapping from original values to cloned values. If a value is not in
                this map, it is assumed to be a graph input and will be cloned as a new value.
                NOTE: this dict is mutated in place as cloning proceeds.
            metadata_props: Metadata properties to add to cloned nodes.
            post_process: A callback invoked after cloning each node, allowing for additional
                processing on the cloned node.
            resolve_ref_attrs: Whether to resolve reference attributes using the attr_map.
                Set to True when inlining functions.
            allow_outer_scope_values: When True, values that are from outer scopes
                (not defined in this graph) will not be cloned. Instead, the cloned
                graph will reference the same outer scope values. This is useful
                when cloning subgraphs that reference values from the outer graph.
                When False (default), values from outer scopes will cause an error if they
                are referenced in the cloned graph.
        """
        self._value_map = value_map
        self._attr_map = attr_map
        self._metadata_props = metadata_props
        self._post_process = post_process
        self._resolve_ref_attrs = resolve_ref_attrs
        self._allow_outer_scope_values = allow_outer_scope_values

    @_capture_error_context
    def _get_value(self, value: _core.Value) -> _core.Value | None:
        # Look up the already-cloned counterpart. Raises KeyError (wrapped into
        # RuntimeError by the decorator) if the value was never cloned.
        return self._value_map[value]

    @_capture_error_context
    def _clone_or_get_value(self, value: _core.Value) -> _core.Value:
        # Return the previously cloned counterpart if one exists; otherwise
        # clone `value` as a fresh, producer-less Value (graph input /
        # initializer case) and record it in the value map.
        if value in self._value_map:
            known_value = self._value_map[value]
            assert known_value is not None, f"BUG: Value {value} mapped to None in value map"
            return known_value
        # If the value is not in the value map, it must be a graph input.
        # Note: value.producer() may not be None when the value is an input of a GraphView
        new_value = _core.Value(
            name=value.name,
            type=value.type,
            shape=value.shape.copy() if value.shape is not None else None,
            doc_string=value.doc_string,
            const_value=value.const_value,  # tensor payload is shared, not copied
        )
        if value.metadata_props:
            new_value.metadata_props.update(value.metadata_props)
        if value.meta:
            new_value.meta.update(value.meta)
        self._value_map[value] = new_value
        return new_value

    @_capture_error_context
    def clone_attr(self, key: str, attr: _core.Attr) -> _core.Attr | None:
        """Clone an attribute, recursing into subgraphs and optionally resolving references.

        Returns None when a reference attribute has no binding in ``attr_map``
        (the ONNX representation of an omitted optional attribute).
        """
        if not attr.is_ref():
            # Concrete attribute: only GRAPH/GRAPHS need deep cloning; all
            # other attribute kinds are returned as-is (shared, not copied).
            if attr.type == _enums.AttributeType.GRAPH:
                graph = self.clone_graph(attr.as_graph())
                return _core.Attr(
                    key, _enums.AttributeType.GRAPH, graph, doc_string=attr.doc_string
                )
            elif attr.type == _enums.AttributeType.GRAPHS:
                graphs = [self.clone_graph(graph) for graph in attr.as_graphs()]
                return _core.Attr(
                    key, _enums.AttributeType.GRAPHS, graphs, doc_string=attr.doc_string
                )
            return attr

        assert attr.is_ref()
        if not self._resolve_ref_attrs:
            # Leave the reference untouched (e.g. plain graph cloning, not inlining).
            return attr

        ref_attr_name = attr.ref_attr_name
        if ref_attr_name is None:
            raise ValueError("Reference attribute must have a name")
        if ref_attr_name in self._attr_map:
            ref_attr = self._attr_map[ref_attr_name]
            if not ref_attr.is_ref():
                # Resolved to a concrete attribute: rebind it under this node's key.
                return _core.Attr(
                    key, ref_attr.type, ref_attr.value, doc_string=ref_attr.doc_string
                )

            # When inlining into a function, we resolve reference attributes to other reference
            # attributes declared in the parent scope.
            assert ref_attr.ref_attr_name is not None
            return _core.RefAttr(
                key, ref_attr.ref_attr_name, ref_attr.type, doc_string=ref_attr.doc_string
            )
        # Note that if a function has an attribute-parameter X, and a call (node) to the function
        # has no attribute X, all references to X in nodes inside the function body will be
        # removed. This is just the ONNX representation of optional-attributes.
        return None

    @_capture_error_context
    def clone_node(self, node: _core.Node) -> _core.Node:
        """Clone a node, remapping inputs through the value map and registering its outputs.

        Assumes nodes are visited in topological order so every non-outer-scope
        input is already present in the value map.
        """
        new_inputs: list[_core.Value | None] = []
        for input in node.inputs:
            if input is None:
                # Preserve optional (absent) inputs positionally.
                new_inputs.append(input)
            elif input not in self._value_map:
                # If the node input cannot be found in the value map, it must be an outer-scope
                # value, given that the nodes are sorted topologically.
                if not self._allow_outer_scope_values:
                    graph_name = (
                        input.graph.name or "<anonymous>" if input.graph else "<unknown>"
                    )
                    raise ValueError(
                        f"Value '{input}' used by node '{node}' is an outer-scope value (from graph '{graph_name}'), "
                        "but 'allow_outer_scope_values' is set to False. Consider creating a GraphView and add the value to its "
                        "inputs then clone, or setting 'allow_outer_scope_values' to True to allow referencing outer-scope values."
                    )
                # When preserving outer-scope values, pass them through unchanged instead of cloning.
                new_inputs.append(input)
            else:
                new_inputs.append(self._get_value(input))
        # Attributes resolving to None (unbound optional refs) are dropped.
        new_attributes = [
            new_value
            for key, value in node.attributes.items()
            if (new_value := self.clone_attr(key, value)) is not None
        ]

        new_metadata = {**self._metadata_props, **node.metadata_props}
        # TODO: For now, node metadata overrides callnode metadata if there is a conflict.
        # Do we need to preserve both?

        new_node = _core.Node(
            node.domain,
            node.op_type,
            new_inputs,
            new_attributes,
            overload=node.overload,
            num_outputs=len(node.outputs),
            version=node.version,
            name=node.name,
            doc_string=node.doc_string,
            metadata_props=new_metadata,
        )
        if node.meta:
            new_node.meta.update(node.meta)

        # Copy output properties
        for output, new_output in zip(node.outputs, new_node.outputs):
            # Register the clone so downstream nodes (and graph outputs) find it.
            self._value_map[output] = new_output
            new_output.name = output.name
            new_output.shape = output.shape.copy() if output.shape is not None else None
            new_output.type = output.type
            new_output.const_value = output.const_value
            new_output.doc_string = output.doc_string
            if output.metadata_props:
                new_output.metadata_props.update(output.metadata_props)
            if output.meta:
                new_output.meta.update(output.meta)

        self._post_process(new_node)
        return new_node

    @_capture_error_context
    def clone_graph(self, graph: _core.Graph | _core.GraphView) -> _core.Graph:
        """Clones a graph with shared TensorProtocols."""
        # Order matters: inputs and initializers must be registered in the
        # value map before nodes are cloned, and nodes before outputs are
        # looked up.
        input_values = [self._clone_or_get_value(v) for v in graph.inputs]
        initializers = [self._clone_or_get_value(v) for v in graph.initializers.values()]
        nodes = [self.clone_node(node) for node in graph]
        # Looks up already cloned values. Here we know graph outputs will not be None
        output_values = typing.cast(
            list["_core.Value"], [self._get_value(v) for v in graph.outputs]
        )

        new_graph = _core.Graph(
            input_values,
            output_values,
            nodes=nodes,
            initializers=initializers,
            doc_string=graph.doc_string,
            opset_imports=graph.opset_imports.copy(),
            name=graph.name,
        )
        if graph.metadata_props:
            new_graph.metadata_props.update(graph.metadata_props)
        if graph.meta:
            new_graph.meta.update(graph.meta)
        return new_graph