onnx-ir 0.1.0__py3-none-any.whl → 0.1.2__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of onnx-ir might be problematic. Click here for more details.
- onnx_ir/__init__.py +5 -2
- onnx_ir/_convenience/__init__.py +125 -4
- onnx_ir/_convenience/_constructors.py +6 -2
- onnx_ir/_core.py +291 -76
- onnx_ir/_enums.py +35 -25
- onnx_ir/_graph_containers.py +114 -9
- onnx_ir/_io.py +40 -4
- onnx_ir/_type_casting.py +2 -1
- onnx_ir/_version_utils.py +5 -48
- onnx_ir/convenience.py +3 -1
- onnx_ir/external_data.py +43 -3
- onnx_ir/passes/_pass_infra.py +1 -1
- onnx_ir/passes/common/__init__.py +4 -0
- onnx_ir/passes/common/_c_api_utils.py +1 -1
- onnx_ir/passes/common/common_subexpression_elimination.py +177 -0
- onnx_ir/passes/common/constant_manipulation.py +10 -25
- onnx_ir/passes/common/inliner.py +4 -3
- onnx_ir/passes/common/onnx_checker.py +1 -1
- onnx_ir/passes/common/shape_inference.py +1 -1
- onnx_ir/passes/common/unused_removal.py +1 -1
- onnx_ir/serde.py +171 -6
- {onnx_ir-0.1.0.dist-info → onnx_ir-0.1.2.dist-info}/METADATA +22 -4
- onnx_ir-0.1.2.dist-info/RECORD +42 -0
- onnx_ir-0.1.0.dist-info/RECORD +0 -41
- {onnx_ir-0.1.0.dist-info → onnx_ir-0.1.2.dist-info}/WHEEL +0 -0
- {onnx_ir-0.1.0.dist-info → onnx_ir-0.1.2.dist-info}/licenses/LICENSE +0 -0
- {onnx_ir-0.1.0.dist-info → onnx_ir-0.1.2.dist-info}/top_level.txt +0 -0
|
@@ -0,0 +1,177 @@
|
|
|
1
|
+
# Copyright (c) ONNX Project Contributors
|
|
2
|
+
# SPDX-License-Identifier: Apache-2.0
|
|
3
|
+
"""Eliminate common subexpression in ONNX graphs."""
|
|
4
|
+
|
|
5
|
+
from __future__ import annotations
|
|
6
|
+
|
|
7
|
+
__all__ = [
|
|
8
|
+
"CommonSubexpressionEliminationPass",
|
|
9
|
+
]
|
|
10
|
+
|
|
11
|
+
import logging
|
|
12
|
+
from collections.abc import Sequence
|
|
13
|
+
|
|
14
|
+
import onnx_ir as ir
|
|
15
|
+
|
|
16
|
+
logger = logging.getLogger(__name__)
|
|
17
|
+
|
|
18
|
+
|
|
19
|
+
class CommonSubexpressionEliminationPass(ir.passes.InPlacePass):
    """In-place pass that merges duplicated computations in the main graph.

    Nodes of ``model.graph`` that share the same operator, inputs and
    attributes are collapsed into one by ``_eliminate_common_subexpression``.
    """

    def call(self, model: ir.Model) -> ir.passes.PassResult:
        """Apply CSE to ``model.graph`` and report whether anything changed.

        Args:
            model: The model whose main graph is rewritten in place.

        Returns:
            A ``PassResult`` wrapping the same model object, with ``modified``
            set to the flag returned by the elimination helper.
        """
        was_modified = _eliminate_common_subexpression(model.graph, False)
        return ir.passes.PassResult(model, modified=was_modified)
|
|
33
|
+
|
|
34
|
+
|
|
35
|
+
def _eliminate_common_subexpression(graph: ir.Graph, modified: bool) -> bool:
    """Eliminate common subexpressions among the nodes of ``graph``.

    Two nodes are treated as the same subexpression when they have the same
    operator identifier, the same number of outputs, the very same input
    value objects (compared by ``id``) and equal attributes. The first such
    node is kept; every later duplicate is removed and its outputs are
    redirected to the outputs of the kept node.

    Nodes that carry GRAPH/GRAPHS attributes, and nodes for which
    ``_is_non_deterministic_op`` returns True, are never merged.

    Args:
        graph: The graph to rewrite in place.
        modified: Initial value of the "modified" flag; it is returned
            unchanged when no node is eliminated.

    Returns:
        True if at least one node was eliminated, otherwise ``modified``.
    """
    # Maps (operator identifier, len(outputs), input ids, attributes) to the
    # first node seen with that signature.
    existing_node_info_to_the_node: dict[
        tuple[
            ir.OperatorIdentifier,
            int,  # len(outputs)
            tuple[int, ...],  # input ids
            tuple[tuple[str, object], ...],  # attributes
        ],
        ir.Node,
    ] = {}

    for node in graph:
        # Skip control flow ops like Loop and If.
        control_flow_op = False
        # Use equality to check if the node is a common subexpression.
        attributes = {}
        for k, v in node.attributes.items():
            # TODO(exporter team): CSE subgraphs.
            # NOTE: control flow ops like Loop and If won't be CSEd
            # because attribute: graph won't match.
            if v.type in (ir.AttributeType.GRAPH, ir.AttributeType.GRAPHS):
                control_flow_op = True
                # The node is skipped below, so the remaining attributes do
                # not need to be collected (and the skip is logged only once).
                break
            # The attribute value could be directly taken from the original
            # protobuf, so we need to make a copy of it.
            value = v.value
            if v.type in (
                ir.AttributeType.INTS,
                ir.AttributeType.FLOATS,
                ir.AttributeType.STRINGS,
            ):
                # For INTS, FLOATS and STRINGS attributes, we convert them to
                # tuples to ensure they are hashable.
                value = tuple(value)
            attributes[k] = value

        if control_flow_op:
            # If the node is a control flow op, we skip it.
            logger.debug("Skipping control flow op %s", node)
            continue

        if _is_non_deterministic_op(node):
            # If the node is a non-deterministic op, we skip it.
            logger.debug("Skipping non-deterministic op %s", node)
            continue

        node_info = (
            node.op_identifier(),
            len(node.outputs),
            # Inputs are matched by object identity, not by name or content.
            tuple(id(input_value) for input_value in node.inputs),
            tuple(sorted(attributes.items())),
        )
        # Check if the node is a common subexpression.
        if node_info in existing_node_info_to_the_node:
            # If it is, this node has an existing node with the same
            # operator, number of outputs, inputs, and attributes.
            # We replace the node with the existing node.
            modified = True
            existing_node = existing_node_info_to_the_node[node_info]
            _remove_node_and_replace_values(
                graph,
                remove_node=node,
                remove_values=node.outputs,
                new_values=existing_node.outputs,
            )
            logger.debug("Reusing node %s", existing_node)
        else:
            # If it is not, add to the mapping.
            existing_node_info_to_the_node[node_info] = node
    return modified
|
|
107
|
+
|
|
108
|
+
|
|
109
|
+
def _remove_node_and_replace_values(
    graph: ir.Graph,
    /,
    remove_node: ir.Node,
    remove_values: Sequence[ir.Value],
    new_values: Sequence[ir.Value],
) -> None:
    """Remove ``remove_node`` and redirect its values to ``new_values``.

    Args:
        graph: The graph that owns ``remove_node``.
        remove_node: The node to delete from the graph.
        remove_values: The values being retired (paired by position with
            ``new_values``).
        new_values: The values that take over from ``remove_values``.
    """
    # Point every consumer of the retired values at their replacements.
    ir.convenience.replace_all_uses_with(remove_values, new_values)

    # Graph outputs only need patching when a retired value is one of them.
    if any(old_value.is_graph_output() for old_value in remove_values):
        old_to_new = dict(zip(remove_values, new_values))
        for position, output in enumerate(graph.outputs):
            if output not in old_to_new:
                continue
            replacement = old_to_new[output]
            if not (replacement.is_graph_output() or replacement.is_graph_input()):
                # The replacement is free to be renamed: let it take over the
                # retired output's name and slot directly.
                replacement.name = output.name
                graph.outputs[position] = replacement
                continue
            # The replacement is itself a graph input/output, so its name must
            # stay as it is. Route it through an Identity node whose output
            # carries the retired output's name, type and shape.
            passthrough = ir.Value(
                name=output.name,
                type=output.type,
                shape=output.shape,
            )
            identity_node = ir.node(
                "Identity",
                inputs=[replacement],
                outputs=[passthrough],
            )
            graph.outputs[position] = identity_node.outputs[0]
            graph.insert_before(remove_node, identity_node)

    graph.remove(remove_node, safe=True)
|
|
160
|
+
|
|
161
|
+
|
|
162
|
+
# ONNX-domain operators that draw random values. Two such nodes with identical
# inputs and attributes are still expected to produce independent results, so
# they must never be merged by CSE.
_NON_DETERMINISTIC_OPS = frozenset(
    {
        "Bernoulli",
        "Multinomial",
        "RandomNormal",
        "RandomNormalLike",
        "RandomUniform",
        "RandomUniformLike",
    }
)


def _is_non_deterministic_op(node: ir.Node) -> bool:
    """Return True if ``node`` is an ONNX-domain random-number operator.

    Args:
        node: The node to inspect; only ``op_type`` and ``domain`` are read.

    Returns:
        True when the op type is one of the known random operators and the
        node's domain is the ONNX domain, False otherwise.
    """
    return node.op_type in _NON_DETERMINISTIC_OPS and _is_onnx_domain(node.domain)
|
|
173
|
+
|
|
174
|
+
|
|
175
|
+
def _is_onnx_domain(d: str) -> bool:
|
|
176
|
+
"""Check if the domain is the ONNX domain."""
|
|
177
|
+
return d == ""
|
|
@@ -139,35 +139,11 @@ class LiftSubgraphInitializersToMainGraphPass(ir.passes.InPlacePass):
|
|
|
139
139
|
for further processing or optimization.
|
|
140
140
|
|
|
141
141
|
Initializers that are also graph inputs will not be lifted.
|
|
142
|
-
|
|
143
|
-
Preconditions:
|
|
144
|
-
- All initializers in the model must have unique names across the main graph and subgraphs.
|
|
145
142
|
"""
|
|
146
143
|
|
|
147
|
-
def requires(self, model: ir.Model) -> None:
|
|
148
|
-
"""Ensure all initializer names are unique."""
|
|
149
|
-
registered_initializer_names: set[str] = set()
|
|
150
|
-
duplicated_initializers: list[ir.Value] = []
|
|
151
|
-
for graph in model.graphs():
|
|
152
|
-
for initializer in graph.initializers.values():
|
|
153
|
-
if initializer.name is None:
|
|
154
|
-
raise ir.passes.PreconditionError(
|
|
155
|
-
f"Initializer name is None. Please ensure all initializers have unique names: {initializer!r}"
|
|
156
|
-
)
|
|
157
|
-
if initializer.name in registered_initializer_names:
|
|
158
|
-
duplicated_initializers.append(initializer)
|
|
159
|
-
else:
|
|
160
|
-
registered_initializer_names.add(initializer.name)
|
|
161
|
-
if duplicated_initializers:
|
|
162
|
-
raise ir.passes.PreconditionError(
|
|
163
|
-
"Found duplicated initializers in the model. "
|
|
164
|
-
"Initializer name must be unique across the main graph and subgraphs. "
|
|
165
|
-
"Please ensure all initializers have unique names. Duplicated: "
|
|
166
|
-
f"{duplicated_initializers!r}"
|
|
167
|
-
)
|
|
168
|
-
|
|
169
144
|
def call(self, model: ir.Model) -> ir.passes.PassResult:
|
|
170
145
|
count = 0
|
|
146
|
+
registered_initializer_names: dict[str, int] = {}
|
|
171
147
|
for graph in model.graphs():
|
|
172
148
|
if graph is model.graph:
|
|
173
149
|
continue
|
|
@@ -182,6 +158,15 @@ class LiftSubgraphInitializersToMainGraphPass(ir.passes.InPlacePass):
|
|
|
182
158
|
continue
|
|
183
159
|
# Remove the initializer from the subgraph
|
|
184
160
|
graph.initializers.pop(name)
|
|
161
|
+
# To avoid name conflicts, we need to rename the initializer
|
|
162
|
+
# to a unique name in the main graph
|
|
163
|
+
if name in registered_initializer_names:
|
|
164
|
+
name_count = registered_initializer_names[name]
|
|
165
|
+
initializer.name = f"{name}_{name_count}"
|
|
166
|
+
registered_initializer_names[name] = name_count + 1
|
|
167
|
+
else:
|
|
168
|
+
assert initializer.name is not None
|
|
169
|
+
registered_initializer_names[initializer.name] = 1
|
|
185
170
|
model.graph.register_initializer(initializer)
|
|
186
171
|
count += 1
|
|
187
172
|
logger.debug(
|
onnx_ir/passes/common/inliner.py
CHANGED
|
@@ -9,7 +9,7 @@ import dataclasses
|
|
|
9
9
|
__all__ = ["InlinePass", "InlinePassResult"]
|
|
10
10
|
|
|
11
11
|
from collections import defaultdict
|
|
12
|
-
from collections.abc import Iterable, Sequence
|
|
12
|
+
from collections.abc import Iterable, Mapping, Sequence
|
|
13
13
|
|
|
14
14
|
import onnx_ir as ir
|
|
15
15
|
import onnx_ir.convenience as _ir_convenience
|
|
@@ -52,7 +52,7 @@ class _CopyReplace:
|
|
|
52
52
|
def __init__(
|
|
53
53
|
self,
|
|
54
54
|
inliner: InlinePass,
|
|
55
|
-
attr_map:
|
|
55
|
+
attr_map: Mapping[str, ir.Attr],
|
|
56
56
|
value_map: dict[ir.Value, ir.Value | None],
|
|
57
57
|
metadata_props: dict[str, str],
|
|
58
58
|
call_stack: CallStack,
|
|
@@ -96,6 +96,7 @@ class _CopyReplace:
|
|
|
96
96
|
return attr
|
|
97
97
|
assert attr.is_ref()
|
|
98
98
|
ref_attr_name = attr.ref_attr_name
|
|
99
|
+
assert ref_attr_name is not None, "Reference attribute must have a name"
|
|
99
100
|
if ref_attr_name in self._attr_map:
|
|
100
101
|
ref_attr = self._attr_map[ref_attr_name]
|
|
101
102
|
if not ref_attr.is_ref():
|
|
@@ -237,7 +238,7 @@ class InlinePass(ir.passes.InPlacePass):
|
|
|
237
238
|
)
|
|
238
239
|
|
|
239
240
|
# Identify substitutions for both inputs and attributes of the function:
|
|
240
|
-
attributes:
|
|
241
|
+
attributes: Mapping[str, ir.Attr] = node.attributes
|
|
241
242
|
default_attr_values = {
|
|
242
243
|
attr.name: attr
|
|
243
244
|
for attr in function.attributes.values()
|
onnx_ir/serde.py
CHANGED
|
@@ -37,6 +37,7 @@ __all__ = [
|
|
|
37
37
|
"deserialize_value_info_proto",
|
|
38
38
|
# Serialization
|
|
39
39
|
"to_proto",
|
|
40
|
+
"to_onnx_text",
|
|
40
41
|
"serialize_attribute_into",
|
|
41
42
|
"serialize_attribute",
|
|
42
43
|
"serialize_dimension_into",
|
|
@@ -62,14 +63,14 @@ __all__ = [
|
|
|
62
63
|
import collections
|
|
63
64
|
import logging
|
|
64
65
|
import os
|
|
65
|
-
from collections.abc import Mapping, Sequence
|
|
66
|
+
from collections.abc import Iterable, Mapping, Sequence
|
|
66
67
|
from typing import Any, Callable
|
|
67
68
|
|
|
68
69
|
import numpy as np
|
|
69
|
-
import onnx
|
|
70
|
-
import onnx.external_data_helper
|
|
70
|
+
import onnx # noqa: TID251
|
|
71
|
+
import onnx.external_data_helper # noqa: TID251
|
|
71
72
|
|
|
72
|
-
from onnx_ir import _core, _enums, _protocols, _type_casting
|
|
73
|
+
from onnx_ir import _convenience, _core, _enums, _protocols, _type_casting
|
|
73
74
|
|
|
74
75
|
if typing.TYPE_CHECKING:
|
|
75
76
|
import google.protobuf.internal.containers as proto_containers
|
|
@@ -190,13 +191,64 @@ def from_proto(proto: object) -> object:
|
|
|
190
191
|
)
|
|
191
192
|
|
|
192
193
|
|
|
193
|
-
def from_onnx_text(
|
|
194
|
+
def from_onnx_text(
    model_text: str,
    /,
    initializers: Iterable[_protocols.TensorProtocol] | None = None,
) -> _core.Model:
    """Convert the ONNX textual representation to an IR model.

    Read more about the textual representation at: https://onnx.ai/onnx/repo-docs/Syntax.html

    Args:
        model_text: The ONNX textual representation of the model.
        initializers: Tensors to be added as initializers. If provided, each
            tensor is attached as the constant value of the model value with
            the same name, and that value is registered as an initializer of
            the main graph.

    Returns:
        The IR model corresponding to the ONNX textual representation.

    Raises:
        ValueError: If a tensor in `initializers` has no name, or if its name
            does not match any value in the model.
    """
    proto = onnx.parser.parse_model(model_text)
    model = deserialize_model(proto)
    if not initializers:
        # Nothing to attach, so skip building the value mapping altogether.
        return model

    values = _convenience.create_value_mapping(model.graph)
    for tensor in initializers:
        name = tensor.name
        if not name:
            raise ValueError(
                "Initializer tensor must have a name. "
                f"Please provide a name for the initializer: {tensor}"
            )
        if name not in values:
            raise ValueError(f"Value '{name}' does not exist in model.")
        initializer = values[name]
        initializer.const_value = tensor
        model.graph.register_initializer(initializer)
    return model
|
|
233
|
+
|
|
234
|
+
|
|
235
|
+
def to_onnx_text(
    model: _protocols.ModelProtocol, /, exclude_initializers: bool = False
) -> str:
    """Render an IR model in the ONNX textual syntax.

    Args:
        model: The IR model to render.
        exclude_initializers: When True, the initializer list of the
            serialized top-level graph is emptied before printing.

    Returns:
        The text produced by ``onnx.printer.to_text`` for the serialized model.
    """
    model_proto = serialize_model(model)
    if exclude_initializers:
        # Clears the repeated field in place on the freshly serialized proto;
        # the IR model passed in is not touched by this statement.
        del model_proto.graph.initializer[:]
    return onnx.printer.to_text(model_proto)
|
|
200
252
|
|
|
201
253
|
|
|
202
254
|
@typing.overload
|
|
@@ -462,6 +514,14 @@ def _get_field(proto: Any, field: str) -> Any:
|
|
|
462
514
|
def deserialize_opset_import(
    protos: Sequence[onnx.OperatorSetIdProto],
) -> dict[str, int]:
    """Build the opset-import mapping from OperatorSetIdProto entries.

    Args:
        protos: The OperatorSetIdProto objects to read ``domain`` and
            ``version`` from.

    Returns:
        A dictionary keyed by domain with the version as value. When a domain
        appears more than once, the last entry wins.
    """
    opset_imports: dict[str, int] = {}
    for opset_proto in protos:
        opset_imports[opset_proto.domain] = opset_proto.version
    return opset_imports
|
|
466
526
|
|
|
467
527
|
|
|
@@ -495,6 +555,14 @@ def _parse_experimental_function_value_info_name(
|
|
|
495
555
|
|
|
496
556
|
|
|
497
557
|
def deserialize_model(proto: onnx.ModelProto) -> _core.Model:
|
|
558
|
+
"""Deserialize an ONNX ModelProto into an IR Model.
|
|
559
|
+
|
|
560
|
+
Args:
|
|
561
|
+
proto: The ONNX ModelProto to deserialize.
|
|
562
|
+
|
|
563
|
+
Returns:
|
|
564
|
+
An IR Model object representing the ONNX model.
|
|
565
|
+
"""
|
|
498
566
|
graph = _deserialize_graph(proto.graph, [])
|
|
499
567
|
graph.opset_imports.update(deserialize_opset_import(proto.opset_import))
|
|
500
568
|
|
|
@@ -699,6 +767,14 @@ def _deserialize_graph(
|
|
|
699
767
|
|
|
700
768
|
@_capture_errors(lambda proto: proto.name)
|
|
701
769
|
def deserialize_function(proto: onnx.FunctionProto) -> _core.Function:
|
|
770
|
+
"""Deserialize an ONNX FunctionProto into an IR Function.
|
|
771
|
+
|
|
772
|
+
Args:
|
|
773
|
+
proto: The ONNX FunctionProto to deserialize.
|
|
774
|
+
|
|
775
|
+
Returns:
|
|
776
|
+
An IR Function object representing the ONNX function.
|
|
777
|
+
"""
|
|
702
778
|
inputs = [_core.Input(name) for name in proto.input]
|
|
703
779
|
values: dict[str, _core.Value] = {v.name: v for v in inputs} # type: ignore[misc]
|
|
704
780
|
value_info = {info.name: info for info in getattr(proto, "value_info", [])}
|
|
@@ -741,6 +817,15 @@ def deserialize_function(proto: onnx.FunctionProto) -> _core.Function:
|
|
|
741
817
|
def deserialize_value_info_proto(
|
|
742
818
|
proto: onnx.ValueInfoProto, value: _core.Value | None
|
|
743
819
|
) -> _core.Value:
|
|
820
|
+
"""Deserialize an ONNX ValueInfoProto into an IR Value.
|
|
821
|
+
|
|
822
|
+
Args:
|
|
823
|
+
proto: The ONNX ValueInfoProto to deserialize.
|
|
824
|
+
value: An existing Value to update, or None to create a new one.
|
|
825
|
+
|
|
826
|
+
Returns:
|
|
827
|
+
An IR Value object with type and shape information populated from the proto.
|
|
828
|
+
"""
|
|
744
829
|
if value is None:
|
|
745
830
|
value = _core.Value(name=proto.name)
|
|
746
831
|
value.shape = deserialize_type_proto_for_shape(proto.type)
|
|
@@ -767,6 +852,14 @@ def _deserialize_quantization_annotation(
|
|
|
767
852
|
|
|
768
853
|
@_capture_errors(str)
|
|
769
854
|
def deserialize_tensor_shape(proto: onnx.TensorShapeProto) -> _core.Shape:
|
|
855
|
+
"""Deserialize an ONNX TensorShapeProto into an IR Shape.
|
|
856
|
+
|
|
857
|
+
Args:
|
|
858
|
+
proto: The ONNX TensorShapeProto to deserialize.
|
|
859
|
+
|
|
860
|
+
Returns:
|
|
861
|
+
An IR Shape object representing the tensor shape.
|
|
862
|
+
"""
|
|
770
863
|
# This logic handles when the shape is [] as well
|
|
771
864
|
dim_protos = proto.dim
|
|
772
865
|
deserialized_dim_denotations = [
|
|
@@ -779,6 +872,14 @@ def deserialize_tensor_shape(proto: onnx.TensorShapeProto) -> _core.Shape:
|
|
|
779
872
|
|
|
780
873
|
@_capture_errors(str)
|
|
781
874
|
def deserialize_type_proto_for_shape(proto: onnx.TypeProto) -> _core.Shape | None:
|
|
875
|
+
"""Extract and deserialize shape information from an ONNX TypeProto.
|
|
876
|
+
|
|
877
|
+
Args:
|
|
878
|
+
proto: The ONNX TypeProto to extract shape from.
|
|
879
|
+
|
|
880
|
+
Returns:
|
|
881
|
+
An IR Shape object if shape information is present, None otherwise.
|
|
882
|
+
"""
|
|
782
883
|
if proto.HasField("tensor_type"):
|
|
783
884
|
if (shape_proto := _get_field(proto.tensor_type, "shape")) is None:
|
|
784
885
|
return None
|
|
@@ -806,6 +907,14 @@ def deserialize_type_proto_for_shape(proto: onnx.TypeProto) -> _core.Shape | Non
|
|
|
806
907
|
def deserialize_type_proto_for_type(
|
|
807
908
|
proto: onnx.TypeProto,
|
|
808
909
|
) -> _protocols.TypeProtocol | None:
|
|
910
|
+
"""Extract and deserialize type information from an ONNX TypeProto.
|
|
911
|
+
|
|
912
|
+
Args:
|
|
913
|
+
proto: The ONNX TypeProto to extract type from.
|
|
914
|
+
|
|
915
|
+
Returns:
|
|
916
|
+
An IR type object (TensorType, SequenceType, etc.) if type information is present, None otherwise.
|
|
917
|
+
"""
|
|
809
918
|
denotation = _get_field(proto, "denotation")
|
|
810
919
|
if proto.HasField("tensor_type"):
|
|
811
920
|
if (elem_type := _get_field(proto.tensor_type, "elem_type")) is None:
|
|
@@ -906,6 +1015,14 @@ _deserialize_string_string_maps = deserialize_metadata_props
|
|
|
906
1015
|
|
|
907
1016
|
|
|
908
1017
|
def deserialize_attribute(proto: onnx.AttributeProto) -> _core.Attr:
    """Convert a standalone ONNX AttributeProto into an IR attribute.

    Delegates to ``_deserialize_attribute`` with an empty list as its second
    argument.

    Args:
        proto: The ONNX AttributeProto to convert.

    Returns:
        The IR attribute built from ``proto``.
    """
    empty_scopes: list = []
    return _deserialize_attribute(proto, empty_scopes)
|
|
910
1027
|
|
|
911
1028
|
|
|
@@ -979,6 +1096,14 @@ def _deserialize_attribute(
|
|
|
979
1096
|
|
|
980
1097
|
|
|
981
1098
|
def deserialize_node(proto: onnx.NodeProto) -> _core.Node:
    """Convert a standalone ONNX NodeProto into an IR node.

    Delegates to ``_deserialize_node`` with a single empty value scope and no
    value info or quantization annotations.

    Args:
        proto: The ONNX NodeProto to convert.

    Returns:
        The IR node built from ``proto``.
    """
    return _deserialize_node(
        proto,
        scoped_values=[{}],
        value_info={},
        quantization_annotations={},
    )
|
|
@@ -1097,6 +1222,14 @@ def _deserialize_node(
|
|
|
1097
1222
|
|
|
1098
1223
|
|
|
1099
1224
|
def serialize_model(model: _protocols.ModelProtocol) -> onnx.ModelProto:
    """Convert an IR model into a new ONNX ModelProto.

    Args:
        model: The IR model to convert.

    Returns:
        The result of ``serialize_model_into`` applied to a freshly created
        ``onnx.ModelProto``.
    """
    model_proto = onnx.ModelProto()
    return serialize_model_into(model_proto, from_=model)
|
|
1101
1234
|
|
|
1102
1235
|
|
|
@@ -1418,6 +1551,14 @@ def serialize_function_into(
|
|
|
1418
1551
|
|
|
1419
1552
|
|
|
1420
1553
|
def serialize_node(node: _protocols.NodeProtocol) -> onnx.NodeProto:
    """Convert an IR node into a new ONNX NodeProto.

    Args:
        node: The IR node to convert.

    Returns:
        A freshly created ``onnx.NodeProto`` populated by
        ``serialize_node_into``.
    """
    proto = onnx.NodeProto()
    serialize_node_into(proto, from_=node)
    return proto
|
|
@@ -1472,6 +1613,14 @@ def serialize_node_into(node_proto: onnx.NodeProto, from_: _protocols.NodeProtoc
|
|
|
1472
1613
|
|
|
1473
1614
|
|
|
1474
1615
|
def serialize_tensor(tensor: _protocols.TensorProtocol) -> onnx.TensorProto:
    """Convert an IR tensor into a new ONNX TensorProto.

    Args:
        tensor: The IR tensor to convert.

    Returns:
        A freshly created ``onnx.TensorProto`` populated by
        ``serialize_tensor_into``.
    """
    proto = onnx.TensorProto()
    serialize_tensor_into(proto, from_=tensor)
    return proto
|
|
@@ -1514,6 +1663,14 @@ def serialize_tensor_into(
|
|
|
1514
1663
|
|
|
1515
1664
|
|
|
1516
1665
|
def serialize_attribute(attribute: _protocols.AttributeProtocol) -> onnx.AttributeProto:
    """Convert an IR attribute into a new ONNX AttributeProto.

    Args:
        attribute: The IR attribute to convert.

    Returns:
        A freshly created ``onnx.AttributeProto`` populated by
        ``serialize_attribute_into``.
    """
    proto = onnx.AttributeProto()
    serialize_attribute_into(proto, from_=attribute)
    return proto
|
|
@@ -1678,6 +1835,14 @@ def serialize_type_into(type_proto: onnx.TypeProto, from_: _protocols.TypeProtoc
|
|
|
1678
1835
|
|
|
1679
1836
|
|
|
1680
1837
|
def serialize_type(type_protocol: _protocols.TypeProtocol) -> onnx.TypeProto:
    """Convert an IR type into a new ONNX TypeProto.

    Args:
        type_protocol: The IR type to convert.

    Returns:
        A freshly created ``onnx.TypeProto`` populated by
        ``serialize_type_into``.
    """
    proto = onnx.TypeProto()
    serialize_type_into(proto, from_=type_protocol)
    return proto
|
|
@@ -1,12 +1,12 @@
|
|
|
1
1
|
Metadata-Version: 2.4
|
|
2
2
|
Name: onnx-ir
|
|
3
|
-
Version: 0.1.0
|
|
3
|
+
Version: 0.1.2
|
|
4
4
|
Summary: Efficient in-memory representation for ONNX
|
|
5
5
|
Author-email: ONNX Contributors <onnx-technical-discuss@lists.lfaidata.foundation>
|
|
6
6
|
License: Apache License v2.0
|
|
7
|
-
Project-URL: Homepage, https://onnx.ai/
|
|
8
|
-
Project-URL: Issues, https://github.com/onnx/
|
|
9
|
-
Project-URL: Repository, https://github.com/onnx/
|
|
7
|
+
Project-URL: Homepage, https://onnx.ai/ir-py
|
|
8
|
+
Project-URL: Issues, https://github.com/onnx/ir-py/issues
|
|
9
|
+
Project-URL: Repository, https://github.com/onnx/ir-py
|
|
10
10
|
Classifier: Development Status :: 4 - Beta
|
|
11
11
|
Classifier: Programming Language :: Python :: 3.9
|
|
12
12
|
Classifier: Programming Language :: Python :: 3.10
|
|
@@ -33,6 +33,24 @@ Dynamic: license-file
|
|
|
33
33
|
|
|
34
34
|
An in-memory IR that supports the full ONNX spec, designed for graph construction, analysis and transformation.
|
|
35
35
|
|
|
36
|
+
## Getting Started
|
|
37
|
+
|
|
38
|
+
[onnx-ir documentation](https://onnx.ai/ir-py/)
|
|
39
|
+
|
|
40
|
+
### Installation
|
|
41
|
+
|
|
42
|
+
Via pip:
|
|
43
|
+
|
|
44
|
+
```
|
|
45
|
+
pip install onnx-ir
|
|
46
|
+
```
|
|
47
|
+
|
|
48
|
+
Or from source:
|
|
49
|
+
|
|
50
|
+
```
|
|
51
|
+
pip install git+https://github.com/onnx/ir-py.git
|
|
52
|
+
```
|
|
53
|
+
|
|
36
54
|
## Features ✨
|
|
37
55
|
|
|
38
56
|
- Full ONNX spec support: all valid models representable by ONNX protobuf, and a subset of invalid models (so you can load and fix them).
|
|
@@ -0,0 +1,42 @@
|
|
|
1
|
+
onnx_ir/__init__.py,sha256=aeEp01Z1OnAVSQyJ_ejLbVnvI2BJE2mTbnRgIvkTDG8,3424
|
|
2
|
+
onnx_ir/_core.py,sha256=F1mQwdSSV8Y6T_yE01nQLnlBtoMs8bnchCNB3ykSaLM,137413
|
|
3
|
+
onnx_ir/_display.py,sha256=230bMN_hVy47Ug3HkA4o5Tf5Hr21AnBEoq5w0fxjyTs,1300
|
|
4
|
+
onnx_ir/_enums.py,sha256=JX2-uUNtPQgPM6F23BslHllvR-2uZAHvKpzy2iPGg7o,8219
|
|
5
|
+
onnx_ir/_graph_comparison.py,sha256=8_D1gu547eCDotEUqxfIJhUGU_Ufhfji7sfsSraOj3g,727
|
|
6
|
+
onnx_ir/_graph_containers.py,sha256=PRKrshRZ5rzWCgRs1TefzJq9n8wyo7OqeKy3XxMhyys,14265
|
|
7
|
+
onnx_ir/_io.py,sha256=GWwA4XOZ-ZX1cgibgaYD0K0O5d9LX21ZwcBN02Wrh04,5205
|
|
8
|
+
onnx_ir/_linked_list.py,sha256=PXVcbHLMXHLZ6DxZnElnJLWfhBPvYcXUxM8Y3K4J7lM,10592
|
|
9
|
+
onnx_ir/_metadata.py,sha256=lzmCaYy4kAJrPW-PSGKF4a78LisxF0hiofySNQ3Mwhg,1544
|
|
10
|
+
onnx_ir/_name_authority.py,sha256=PnoV9TRgMLussZNufWavJXosDWx5avPfldVjMWEEz18,3036
|
|
11
|
+
onnx_ir/_polyfill.py,sha256=LzAGBKQbVDlURC0tgQgaxgkYU4rESgCYnqVs-u-Vsx8,887
|
|
12
|
+
onnx_ir/_protocols.py,sha256=M29sIOAvtdlis3QtBvCQPH4pnvSwhJCQNCvs3IrN9FY,21276
|
|
13
|
+
onnx_ir/_tape.py,sha256=nEGY6VZVKuB8FDyXeYr0MTq8j7E4HKOE2yN8qpz4ia0,7007
|
|
14
|
+
onnx_ir/_type_casting.py,sha256=8iZDVrNAx_FwRVt48G4tkzIOFu3I6AsETpH3fdxcyEI,3387
|
|
15
|
+
onnx_ir/_version_utils.py,sha256=bZThuE7meVHFOY1DLsmss9WshVIp9iig7udGfDbVaK4,1333
|
|
16
|
+
onnx_ir/convenience.py,sha256=0B1epuXZCSmY4FbW2vaYfR-t5ubxBZ1UruiytHs-zFw,917
|
|
17
|
+
onnx_ir/external_data.py,sha256=rXHtRU-9tjAt10Iervhr5lsI6Dtv-EhR7J4brxppImA,18079
|
|
18
|
+
onnx_ir/serde.py,sha256=M2w-D2boYEHi96tA_4eQUGRaV4CYhxoHi-5FPmVEjRk,74968
|
|
19
|
+
onnx_ir/tape.py,sha256=4FyfAHmVhQoMsfHMYnBwP2azi6UF6b6pj--ercObqZs,350
|
|
20
|
+
onnx_ir/tensor_adapters.py,sha256=J2z0gxkxwZqBrob1pYT53lgz1XQ1r7kCxhoSZa5NHaQ,4469
|
|
21
|
+
onnx_ir/testing.py,sha256=WTrjf2joWizDWaYMJlV1KjZMQw7YmZ8NvuBTVn1uY6s,8803
|
|
22
|
+
onnx_ir/traversal.py,sha256=Z69wzYBNljn1S7PhVTYgwMftrfsdEBLoa0JYteOhLL0,2863
|
|
23
|
+
onnx_ir/_convenience/__init__.py,sha256=5g5IlMfozyQcGBMsJW9HtYWzEGe5uAQVrWrDQsnMk1Q,19059
|
|
24
|
+
onnx_ir/_convenience/_constructors.py,sha256=5GhlYy_xCE2ng7l_4cNx06WQsNDyvS-0U1HgOpPKJEk,8347
|
|
25
|
+
onnx_ir/_thirdparty/asciichartpy.py,sha256=afQ0fsqko2uYRPAR4TZBrQxvCb4eN8lxZ2yDFbVQq_s,10533
|
|
26
|
+
onnx_ir/passes/__init__.py,sha256=M_Tcl_-qGSNPluFIvOoeDyh0qAwNayaYyXDS5UJUJPQ,764
|
|
27
|
+
onnx_ir/passes/_pass_infra.py,sha256=xIOw_zZIuOqD4Z_wZ4OvsqXfh2IZMoMlDp1xQ_MPQlc,9567
|
|
28
|
+
onnx_ir/passes/common/__init__.py,sha256=aHjx2y7L7LJChixmKsSUCdiaTP1u-zSmcmEISduqeG4,1335
|
|
29
|
+
onnx_ir/passes/common/_c_api_utils.py,sha256=g6riA6xNGVWaO5YjVHZ0krrfslWHmRlryRkwB8X56cg,2907
|
|
30
|
+
onnx_ir/passes/common/clear_metadata_and_docstring.py,sha256=YwouLfsNFSaTuGd7uMOGjdvVwG9yHQTkSphUgDlM0ME,2365
|
|
31
|
+
onnx_ir/passes/common/common_subexpression_elimination.py,sha256=WMsTAI-12A3iVqptmWw0tiBmGwVKsls5VNxZEbjvp2A,6527
|
|
32
|
+
onnx_ir/passes/common/constant_manipulation.py,sha256=_fGDwn0Axl2Q8APfc2m_mLMH28T-Mc9kIlpzBXoe3q4,8779
|
|
33
|
+
onnx_ir/passes/common/inliner.py,sha256=wBoO6yXt6F1AObQjYZHMQ0wn3YH681N4HQQVyaMAYd4,13702
|
|
34
|
+
onnx_ir/passes/common/onnx_checker.py,sha256=_sPmJ2ff9pDB1g9q7082BL6fyubomRaj6svE0cCyDew,1691
|
|
35
|
+
onnx_ir/passes/common/shape_inference.py,sha256=LVdvxjeKtcIEbPcb6mKisxoPJOOawzsm3tzk5j9xqeM,3992
|
|
36
|
+
onnx_ir/passes/common/topological_sort.py,sha256=Vcu1YhBdfRX4LROr0NScjB1Pwz2DjBFD0Z_GxqaxPF8,999
|
|
37
|
+
onnx_ir/passes/common/unused_removal.py,sha256=cBNqaqGnUVyCWxsD7hBzYk4qSglVPo3SmHAvkUo5-Oc,7613
|
|
38
|
+
onnx_ir-0.1.2.dist-info/licenses/LICENSE,sha256=z8d0m5b2O9McPEK1xHG_dWgUBT6EfBDz6wA0F7xSPTA,11358
|
|
39
|
+
onnx_ir-0.1.2.dist-info/METADATA,sha256=p5ZfMwXTfm96hlrRDgvdBGzARfeLVpAWhZ_nNSq3m78,4782
|
|
40
|
+
onnx_ir-0.1.2.dist-info/WHEEL,sha256=_zCd3N1l69ArxyTb8rzEoP9TpbYXkqRFSNOD5OuxnTs,91
|
|
41
|
+
onnx_ir-0.1.2.dist-info/top_level.txt,sha256=W5tROO93YjO0XRxIdjMy4wocp-5st5GiI2ukvW7UhDo,8
|
|
42
|
+
onnx_ir-0.1.2.dist-info/RECORD,,
|