onnx-ir 0.1.15__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- onnx_ir/__init__.py +176 -0
- onnx_ir/_cloner.py +229 -0
- onnx_ir/_convenience/__init__.py +558 -0
- onnx_ir/_convenience/_constructors.py +291 -0
- onnx_ir/_convenience/_extractor.py +191 -0
- onnx_ir/_core.py +4435 -0
- onnx_ir/_display.py +54 -0
- onnx_ir/_enums.py +474 -0
- onnx_ir/_graph_comparison.py +23 -0
- onnx_ir/_graph_containers.py +373 -0
- onnx_ir/_io.py +133 -0
- onnx_ir/_linked_list.py +284 -0
- onnx_ir/_metadata.py +45 -0
- onnx_ir/_name_authority.py +72 -0
- onnx_ir/_polyfill.py +26 -0
- onnx_ir/_protocols.py +627 -0
- onnx_ir/_safetensors/__init__.py +510 -0
- onnx_ir/_tape.py +242 -0
- onnx_ir/_thirdparty/asciichartpy.py +310 -0
- onnx_ir/_type_casting.py +89 -0
- onnx_ir/_version_utils.py +48 -0
- onnx_ir/analysis/__init__.py +21 -0
- onnx_ir/analysis/_implicit_usage.py +74 -0
- onnx_ir/convenience.py +38 -0
- onnx_ir/external_data.py +459 -0
- onnx_ir/passes/__init__.py +41 -0
- onnx_ir/passes/_pass_infra.py +351 -0
- onnx_ir/passes/common/__init__.py +54 -0
- onnx_ir/passes/common/_c_api_utils.py +76 -0
- onnx_ir/passes/common/clear_metadata_and_docstring.py +60 -0
- onnx_ir/passes/common/common_subexpression_elimination.py +207 -0
- onnx_ir/passes/common/constant_manipulation.py +230 -0
- onnx_ir/passes/common/default_attributes.py +99 -0
- onnx_ir/passes/common/identity_elimination.py +120 -0
- onnx_ir/passes/common/initializer_deduplication.py +179 -0
- onnx_ir/passes/common/inliner.py +223 -0
- onnx_ir/passes/common/naming.py +280 -0
- onnx_ir/passes/common/onnx_checker.py +57 -0
- onnx_ir/passes/common/output_fix.py +141 -0
- onnx_ir/passes/common/shape_inference.py +112 -0
- onnx_ir/passes/common/topological_sort.py +37 -0
- onnx_ir/passes/common/unused_removal.py +215 -0
- onnx_ir/py.typed +1 -0
- onnx_ir/serde.py +2043 -0
- onnx_ir/tape.py +15 -0
- onnx_ir/tensor_adapters.py +210 -0
- onnx_ir/testing.py +197 -0
- onnx_ir/traversal.py +118 -0
- onnx_ir-0.1.15.dist-info/METADATA +68 -0
- onnx_ir-0.1.15.dist-info/RECORD +53 -0
- onnx_ir-0.1.15.dist-info/WHEEL +5 -0
- onnx_ir-0.1.15.dist-info/licenses/LICENSE +202 -0
- onnx_ir-0.1.15.dist-info/top_level.txt +1 -0
|
@@ -0,0 +1,207 @@
|
|
|
1
|
+
# Copyright (c) ONNX Project Contributors
|
|
2
|
+
# SPDX-License-Identifier: Apache-2.0
|
|
3
|
+
"""Eliminate common subexpression in ONNX graphs."""
|
|
4
|
+
|
|
5
|
+
from __future__ import annotations
|
|
6
|
+
|
|
7
|
+
__all__ = [
|
|
8
|
+
"CommonSubexpressionEliminationPass",
|
|
9
|
+
]
|
|
10
|
+
|
|
11
|
+
import logging
|
|
12
|
+
from collections.abc import Sequence
|
|
13
|
+
|
|
14
|
+
import onnx_ir as ir
|
|
15
|
+
|
|
16
|
+
logger = logging.getLogger(__name__)
|
|
17
|
+
|
|
18
|
+
|
|
19
|
+
class CommonSubexpressionEliminationPass(ir.passes.InPlacePass):
    """Eliminate common subexpression in ONNX graphs.

    Nodes of the main graph are visited in order. Two nodes are treated as the same
    expression when they have the same operator identifier, the same number of
    outputs, the very same input values (compared by identity) and equal attributes.
    The later node is removed and every use of its outputs is redirected to the
    outputs of the earlier node.

    A node is never CSE'd when it has a subgraph attribute (control flow ops such as
    ``If``/``Loop``), a tensor attribute with more than ``size_limit`` elements, an
    attribute value that cannot be hashed, or when it is a known non-deterministic
    ONNX op.

    .. versionadded:: 0.1.1

    .. versionchanged:: 0.1.3
        Constant nodes with values of at most ``size_limit`` elements will be CSE'd.

    Attributes:
        size_limit: The maximum size of the tensor to be CSE'd. If the tensor contains
            more elements than size_limit, it will not be CSE'd. Default is 10.
    """

    def __init__(self, size_limit: int = 10):
        """Initialize the CommonSubexpressionEliminationPass.

        Args:
            size_limit: Tensor attributes with more elements than this are not CSE'd.
        """
        super().__init__()
        self.size_limit = size_limit

    def call(self, model: ir.Model) -> ir.passes.PassResult:
        """Return the same ir.Model but with CSE applied to the main graph."""
        modified = self._eliminate_common_subexpression(model.graph)
        return ir.passes.PassResult(model, modified=modified)

    def _eliminate_common_subexpression(self, graph: ir.Graph) -> bool:
        """Eliminate common subexpressions in ``graph``.

        Returns:
            True if at least one node was removed.
        """
        modified: bool = False
        # Signature of a node -> the first node seen with that signature.
        existing_node_info_to_the_node: dict[
            tuple[
                ir.OperatorIdentifier,
                int,  # len(outputs)
                tuple[int, ...],  # input ids
                tuple[tuple[str, object], ...],  # attributes
            ],
            ir.Node,
        ] = {}

        for node in graph:
            attributes = self._attributes_key(node)
            if attributes is None:
                # The reason for skipping is logged in _attributes_key.
                continue

            if _is_non_deterministic_op(node):
                # Two random ops with identical inputs still produce different values.
                logger.debug("Skipping non-deterministic op %s", node)
                continue

            node_info = (
                node.op_identifier(),
                len(node.outputs),
                # Inputs are compared by identity. Because replacements are applied
                # before later nodes are visited, chains of duplicates collapse in one
                # pass (assuming the graph is topologically sorted).
                tuple(id(value) for value in node.inputs),
                attributes,
            )

            existing_node = existing_node_info_to_the_node.get(node_info)
            if existing_node is None:
                # First node with this signature: remember it for later matches.
                existing_node_info_to_the_node[node_info] = node
                continue

            # This node duplicates an earlier one with the same operator, number of
            # outputs, inputs, and attributes: reuse the earlier node's outputs.
            modified = True
            _remove_node_and_replace_values(
                graph,
                remove_node=node,
                remove_values=node.outputs,
                new_values=existing_node.outputs,
            )
            logger.debug("Reusing node %s", existing_node)
        return modified

    def _attributes_key(self, node: ir.Node) -> tuple[tuple[str, object], ...] | None:
        """Build a hashable, order-independent key from the node's attributes.

        Returns:
            A tuple of ``(name, hashable_value)`` pairs sorted by name, or None if the
            node must not be CSE'd (subgraph attribute, large tensor, or an attribute
            value that cannot be hashed).
        """
        attributes: dict[str, object] = {}
        for name, attr in node.attributes.items():
            # TODO(exporter team): CSE subgraphs.
            # NOTE: control flow ops like Loop and If are not CSE'd because their
            # graph attributes cannot be compared.
            if attr.type in (ir.AttributeType.GRAPH, ir.AttributeType.GRAPHS):
                logger.debug("Skipping control flow op %s", node)
                return None

            value = attr.value
            if attr.type in (
                ir.AttributeType.INTS,
                ir.AttributeType.FLOATS,
                ir.AttributeType.STRINGS,
            ):
                # For INTS, FLOATS and STRINGS attributes, convert the sequence to a
                # tuple so it is hashable.
                value = tuple(value)
            elif attr.type is ir.AttributeType.TENSOR:
                # Skip large tensors to avoid CSE-ing weights and biases.
                if value.size > self.size_limit:
                    logger.debug("Skipping large tensor in node %s", node)
                    return None
                value = self._tensor_key(value)
            elif attr.type is ir.AttributeType.TENSORS:
                # A list of tensors is not hashable; compare them by content instead.
                if any(tensor.size > self.size_limit for tensor in value):
                    logger.debug("Skipping large tensor in node %s", node)
                    return None
                value = tuple(self._tensor_key(tensor) for tensor in value)

            # Any other attribute kind is used as-is. If it cannot be hashed, it
            # cannot be part of a dict key, so skip the node instead of crashing.
            try:
                hash(value)
            except TypeError:
                logger.debug(
                    "Skipping node %s with unhashable attribute '%s'", node, name
                )
                return None
            attributes[name] = value
        return tuple(sorted(attributes.items()))

    @staticmethod
    def _tensor_key(tensor: ir.TensorProtocol) -> tuple[object, ...]:
        """Return a hashable ``(shape, dtype name, raw bytes)`` key for ``tensor``."""
        np_value = tensor.numpy()
        return (np_value.shape, str(np_value.dtype), np_value.tobytes())
|
|
136
|
+
|
|
137
|
+
|
|
138
|
+
def _remove_node_and_replace_values(
    graph: ir.Graph,
    /,
    remove_node: ir.Node,
    remove_values: Sequence[ir.Value],
    new_values: Sequence[ir.Value],
) -> None:
    """Replaces nodes and values in the graph or function.

    Every use of ``remove_values[i]`` is redirected to ``new_values[i]``, and then
    ``remove_node`` is removed from ``graph``.

    If a removed value is a graph output, the output keeps its name. If the
    replacement is not itself a graph input or output, it is renamed to the output's
    name. Otherwise, an ``Identity`` node is inserted before ``remove_node``. That
    node produces a fresh value with the original output's name, type and shape.
    A removed value listed several times in ``graph.outputs`` is replaced by the same
    value at every position.

    Args:
        graph: The graph to replace nodes and values in.
        remove_node: The node to remove.
        remove_values: The values to replace.
        new_values: The values to replace with, paired with ``remove_values`` by position.
    """
    # Update graph/function outputs if the node generates output
    if any(remove_value.is_graph_output() for remove_value in remove_values):
        replacement_mapping = dict(zip(remove_values, new_values))
        # Replacement chosen for each removed value on its first occurrence in
        # graph.outputs. Later occurrences reuse it; otherwise the renamed value
        # would look like a graph output, and a second Identity with a duplicate
        # name would be created.
        chosen_replacements: dict[ir.Value, ir.Value] = {}
        for idx, graph_output in enumerate(graph.outputs):
            if graph_output not in replacement_mapping:
                continue
            if graph_output in chosen_replacements:
                graph.outputs[idx] = chosen_replacements[graph_output]
                continue
            new_value = replacement_mapping[graph_output]
            if new_value.is_graph_output() or new_value.is_graph_input():
                # If the new value is also a graph input/output, create an Identity
                # node to preserve the removed output's name without renaming
                # new_value.
                identity_node = ir.node(
                    "Identity",
                    inputs=[new_value],
                    outputs=[
                        ir.Value(
                            name=graph_output.name,
                            type=graph_output.type,
                            shape=graph_output.shape,
                        )
                    ],
                )
                replacement = identity_node.outputs[0]
                # reuse the name of the graph output
                graph.outputs[idx] = replacement
                graph.insert_before(
                    remove_node,
                    identity_node,
                )
            else:
                # new_value is not a graph input/output, so it can simply take over
                # the removed output's name.
                new_value.name = graph_output.name
                replacement = new_value
                graph.outputs[idx] = replacement
            chosen_replacements[graph_output] = replacement

    # Reconnect the users of the deleted values to use the new values
    ir.convenience.replace_all_uses_with(remove_values, new_values)

    graph.remove(remove_node, safe=True)
|
|
190
|
+
|
|
191
|
+
|
|
192
|
+
# ONNX ops whose outputs are random draws. Two such nodes with identical inputs must
# not be merged, because each is expected to produce its own sample. ``Dropout`` is
# included conservatively: per the ONNX operator spec, it samples a random mask when
# its ``training_mode`` input is true.
_NON_DETERMINISTIC_OPS = frozenset(
    {
        "Bernoulli",
        "Dropout",
        "Multinomial",
        "RandomNormal",
        "RandomNormalLike",
        "RandomUniform",
        "RandomUniformLike",
    }
)


def _is_non_deterministic_op(node: ir.Node) -> bool:
    """Return whether ``node`` is a default-domain ONNX op that produces random outputs."""
    return node.op_type in _NON_DETERMINISTIC_OPS and _is_onnx_domain(node.domain)
|
|
203
|
+
|
|
204
|
+
|
|
205
|
+
def _is_onnx_domain(d: str) -> bool:
    """Check if the domain is the default ONNX domain.

    ONNX names the default operator set domain either ``""`` or ``"ai.onnx"``; both
    spellings are accepted.
    """
    return d in ("", "ai.onnx")
|
|
@@ -0,0 +1,230 @@
|
|
|
1
|
+
# Copyright (c) ONNX Project Contributors
|
|
2
|
+
# SPDX-License-Identifier: Apache-2.0
|
|
3
|
+
"""Lift constants to initializers."""
|
|
4
|
+
|
|
5
|
+
from __future__ import annotations
|
|
6
|
+
|
|
7
|
+
__all__ = [
|
|
8
|
+
"AddInitializersToInputsPass",
|
|
9
|
+
"LiftConstantsToInitializersPass",
|
|
10
|
+
"LiftSubgraphInitializersToMainGraphPass",
|
|
11
|
+
"RemoveInitializersFromInputsPass",
|
|
12
|
+
]
|
|
13
|
+
|
|
14
|
+
import logging
|
|
15
|
+
|
|
16
|
+
import numpy as np
|
|
17
|
+
|
|
18
|
+
import onnx_ir as ir
|
|
19
|
+
|
|
20
|
+
logger = logging.getLogger(__name__)
|
|
21
|
+
|
|
22
|
+
|
|
23
|
+
class LiftConstantsToInitializersPass(ir.passes.InPlacePass):
    """Lift constants to initializers.

    Every ``Constant`` node of the default ONNX domain, in the main graph and in its
    subgraphs, is replaced by an initializer registered on the graph that owns the
    node. The initializer reuses the name of the node's output. A node is skipped
    when its output is a graph output, when it does not have exactly one attribute,
    or when the resulting tensor is too small (see ``size_limit``).

    Attributes:
        lift_all_constants: Whether to lift all Constant nodes, including those that do
            not contain a tensor attribute (e.g. with value_ints etc.).
            Defaults to False, where only Constants with the ``value`` attribute are lifted.
        size_limit: The minimum size of the tensor to be lifted. If the tensor contains
            fewer elements than size_limit, it will not be lifted. Default is 16.
    """

    def __init__(self, lift_all_constants: bool = False, size_limit: int = 16):
        super().__init__()
        self.lift_all_constants = lift_all_constants
        self.size_limit = size_limit

    def call(self, model: ir.Model) -> ir.passes.PassResult:
        """Lift qualifying Constant nodes in all graphs of ``model`` to initializers."""
        count = 0
        for node in ir.traversal.RecursiveGraphIterator(model.graph):
            assert node.graph is not None
            # "" is the default ONNX domain and "ai.onnx" is its spelled-out alias.
            # "onnx.ai" is kept for backward compatibility with the previous check.
            if node.op_type != "Constant" or node.domain not in ("", "ai.onnx", "onnx.ai"):
                continue
            if node.outputs[0].is_graph_output():
                logger.debug(
                    "Constant node '%s' is used as output, so it can't be lifted.", node.name
                )
                continue
            constant_node_attribute = set(node.attributes.keys())
            if len(constant_node_attribute) != 1:
                logger.debug(
                    "Invalid constant node '%s' does not have exactly one attribute",
                    node.name,
                )
                continue

            attr_name, attr_value = next(iter(node.attributes.items()))
            initializer_name = node.outputs[0].name
            assert initializer_name is not None
            assert isinstance(attr_value, ir.Attr)
            tensor = self._constant_node_attribute_to_tensor(
                node, attr_name, attr_value, initializer_name
            )
            if tensor is None:
                # The reason of None is logged in _constant_node_attribute_to_tensor
                continue
            # Register an initializer with the tensor value
            initializer = ir.Value(
                name=initializer_name,
                shape=tensor.shape,  # type: ignore[arg-type]
                type=ir.TensorType(tensor.dtype),
                const_value=tensor,
                # Preserve metadata from Constant value into the onnx model
                metadata_props=node.outputs[0].metadata_props.copy(),
            )
            # Preserve value meta from the Constant output for intermediate analysis
            initializer.meta.update(node.outputs[0].meta)

            assert node.graph is not None
            node.graph.register_initializer(initializer)
            # Replace the constant node with the initializer
            node.outputs[0].replace_all_uses_with(initializer)
            node.graph.remove(node, safe=True)
            count += 1
            logger.debug(
                "Converted constant node '%s' to initializer '%s'", node.name, initializer_name
            )
        if count:
            logger.debug("Lifted %s constants to initializers", count)
        return ir.passes.PassResult(model, modified=bool(count))

    @staticmethod
    def _to_utf8_bytes(text: str | bytes) -> bytes:
        """Encode ``text`` as UTF-8; bytes are returned unchanged.

        ``np.array(some_str, dtype=np.bytes_)`` encodes with ASCII and raises
        ``UnicodeEncodeError`` on non-ASCII characters, so strings are encoded here
        explicitly before building the array.
        """
        return text.encode("utf-8") if isinstance(text, str) else text

    def _constant_node_attribute_to_tensor(
        self, node, attr_name: str, attr_value: ir.Attr, initializer_name: str
    ) -> ir.TensorProtocol | None:
        """Convert constant node attribute to tensor.

        Returns:
            The tensor, or None when the attribute is not lifted: either a non-``value``
            attribute while ``lift_all_constants`` is False, or a tensor with fewer than
            ``size_limit`` elements.

        Raises:
            ValueError: If the attribute name is not one of the supported Constant
                attributes.
        """
        if not self.lift_all_constants and attr_name != "value":
            logger.debug(
                "Constant node '%s' has non-tensor attribute '%s'", node.name, attr_name
            )
            return None

        tensor: ir.TensorProtocol
        if attr_name == "value":
            tensor = attr_value.as_tensor()
        elif attr_name == "value_int":
            tensor = ir.tensor(
                attr_value.as_int(), dtype=ir.DataType.INT64, name=initializer_name
            )
        elif attr_name == "value_ints":
            tensor = ir.tensor(
                attr_value.as_ints(), dtype=ir.DataType.INT64, name=initializer_name
            )
        elif attr_name == "value_float":
            tensor = ir.tensor(
                attr_value.as_float(), dtype=ir.DataType.FLOAT, name=initializer_name
            )
        elif attr_name == "value_floats":
            tensor = ir.tensor(
                attr_value.as_floats(), dtype=ir.DataType.FLOAT, name=initializer_name
            )
        elif attr_name == "value_string":
            # A single string becomes a 0-d string tensor.
            tensor = ir.StringTensor(
                np.array(self._to_utf8_bytes(attr_value.value), dtype=np.bytes_),
                name=initializer_name,
            )
        elif attr_name == "value_strings":
            tensor = ir.StringTensor(
                np.array(
                    [self._to_utf8_bytes(item) for item in attr_value.value],
                    dtype=np.bytes_,
                ),
                name=initializer_name,
            )
        else:
            raise ValueError(
                f"Unsupported constant node '{node.name}' attribute '{attr_name}'"
            )

        if tensor.size < self.size_limit:
            logger.debug(
                "Tensor from node '%s' has less than %s elements",
                node.name,
                self.size_limit,
            )
            return None
        return tensor
|
|
137
|
+
|
|
138
|
+
|
|
139
|
+
class LiftSubgraphInitializersToMainGraphPass(ir.passes.InPlacePass):
    """Lift subgraph initializers to main graph.

    Every initializer owned by a graph other than the main graph is moved into the
    main graph's initializers, so that it is available there for further processing
    or optimization. If the name is already taken by a main-graph initializer, the
    lifted one is renamed to ``<name>_<k>`` with the smallest counter ``k`` that is
    still free.

    Initializers that are also graph inputs or graph outputs of their subgraph are
    left in place.
    """

    def call(self, model: ir.Model) -> ir.passes.PassResult:
        main_graph = model.graph
        lifted = 0
        # Per original initializer name: the last numeric suffix handed out.
        suffix_counters: dict[str, int] = {}

        for subgraph in model.graphs():
            if subgraph is main_graph:
                continue
            # Snapshot the names: entries are popped from the mapping inside the loop.
            for name in list(subgraph.initializers):
                assert name is not None
                initializer = subgraph.initializers[name]

                if initializer.is_graph_input():
                    # Skip the ones that are also graph inputs
                    logger.debug(
                        "Initializer '%s' is also a graph input, so it can't be lifted",
                        initializer.name,
                    )
                    continue
                if initializer.is_graph_output():
                    logger.debug(
                        "Initializer '%s' is used as output, so it can't be lifted",
                        initializer.name,
                    )
                    continue

                subgraph.initializers.pop(name)

                # Pick a name that does not clash with the main graph's initializers.
                candidate = name
                while candidate in main_graph.initializers:
                    suffix_counters[name] = suffix_counters.get(name, 0) + 1
                    candidate = f"{name}_{suffix_counters[name]}"
                initializer.name = candidate

                main_graph.register_initializer(initializer)
                lifted += 1
                logger.debug(
                    "Lifted initializer '%s' from subgraph '%s' to main graph",
                    initializer.name,
                    subgraph.name,
                )
        return ir.passes.PassResult(model, modified=bool(lifted))
|
|
191
|
+
|
|
192
|
+
|
|
193
|
+
class RemoveInitializersFromInputsPass(ir.passes.InPlacePass):
    """Remove initializers from inputs.

    For every graph yielded by ``model.graphs()``, this pass removes from
    ``graph.inputs`` every value that is also registered in ``graph.initializers``.
    The initializers themselves are kept, and the relative order of the remaining
    inputs is preserved.
    """

    def call(self, model: ir.Model) -> ir.passes.PassResult:
        count = 0
        for graph in model.graphs():
            initializers = set(graph.initializers.values())
            new_inputs = [value for value in graph.inputs if value not in initializers]
            removed = len(graph.inputs) - len(new_inputs)
            if not removed:
                # Nothing to drop: leave the inputs untouched instead of clearing and
                # re-adding every one of them.
                continue
            graph.inputs.clear()
            graph.inputs.extend(new_inputs)
            count += removed
        logger.info("Removed %s initializers from graph inputs", count)
        return ir.passes.PassResult(model, modified=bool(count))
|
|
213
|
+
|
|
214
|
+
|
|
215
|
+
class AddInitializersToInputsPass(ir.passes.InPlacePass):
    """Add initializers to inputs.

    For every graph yielded by ``model.graphs()``, each value in
    ``graph.initializers`` that is not already listed in ``graph.inputs`` is appended
    to the end of ``graph.inputs``, in initializer order.
    """

    def call(self, model: ir.Model) -> ir.passes.PassResult:
        total_added = 0
        for graph in model.graphs():
            current_inputs = set(graph.inputs)
            missing = [
                value
                for value in graph.initializers.values()
                if value not in current_inputs
            ]
            if missing:
                graph.inputs.extend(missing)
                total_added += len(missing)
        logger.info("Added %s initializers to graph inputs", total_added)
        return ir.passes.PassResult(model, modified=bool(total_added))
|
|
@@ -0,0 +1,99 @@
|
|
|
1
|
+
# Copyright (c) ONNX Project Contributors
|
|
2
|
+
# SPDX-License-Identifier: Apache-2.0
|
|
3
|
+
"""Add default attributes to nodes that are missing optional attributes."""
|
|
4
|
+
|
|
5
|
+
from __future__ import annotations
|
|
6
|
+
|
|
7
|
+
__all__ = [
|
|
8
|
+
"AddDefaultAttributesPass",
|
|
9
|
+
]
|
|
10
|
+
|
|
11
|
+
import logging
|
|
12
|
+
|
|
13
|
+
import onnx # noqa: TID251
|
|
14
|
+
|
|
15
|
+
import onnx_ir as ir
|
|
16
|
+
|
|
17
|
+
logger = logging.getLogger(__name__)
|
|
18
|
+
|
|
19
|
+
|
|
20
|
+
def _has_valid_default(attr_def: onnx.defs.OpSchema.Attribute) -> bool:
    """Tell whether a schema attribute definition carries a usable default value.

    The default is usable when ``attr_def.default_value`` is truthy and its attribute
    type is not ``UNDEFINED``.
    """
    default_proto = attr_def.default_value
    if not default_proto:
        return False
    return bool(default_proto.type != onnx.AttributeProto.UNDEFINED)
|
|
25
|
+
|
|
26
|
+
|
|
27
|
+
class AddDefaultAttributesPass(ir.passes.InPlacePass):
    """Add default values for optional attributes that are not present in nodes.

    This pass iterates through all nodes in the model (the main graph with its
    subgraphs, and every model-local function) and for each node:

    1. Gets the ONNX schema for the operator
    2. For each optional attribute with a default value in the schema
    3. If the attribute is not present in the node, adds it with the default value

    Nodes inside a function are resolved against the function's own opset imports.
    For domains the function does not import, the model's opset imports are used.
    """

    def call(self, model: ir.Model) -> ir.passes.PassResult:
        """Main entry point for the add default attributes pass."""
        modified = False

        # Process all nodes in the model graph and subgraphs
        for node in ir.traversal.RecursiveGraphIterator(model.graph):
            if _add_default_attributes_to_node(node, model.graph.opset_imports):
                modified = True

        # Process nodes in functions
        for function in model.functions.values():
            # A function body declares its own opset imports, which may pin a
            # different version than the model does; those must pick the schema.
            opset_imports = {**model.graph.opset_imports, **function.opset_imports}
            for node in ir.traversal.RecursiveGraphIterator(function):
                if _add_default_attributes_to_node(node, opset_imports):
                    modified = True

        if modified:
            logger.info("AddDefaultAttributes pass modified the model")

        return ir.passes.PassResult(model, modified=modified)
|
|
55
|
+
|
|
56
|
+
|
|
57
|
+
def _add_default_attributes_to_node(node: ir.Node, opset_imports: dict[str, int]) -> bool:
    """Fill in schema defaults for the optional attributes missing on ``node``.

    The schema is looked up with ``node.version`` when it is set, otherwise with
    the version imported for ``node.domain`` in ``opset_imports``. The node is left
    untouched when no version can be determined or no schema is registered.

    Args:
        node: The node to update in place.
        opset_imports: Mapping from domain to imported opset version.

    Returns:
        True if at least one attribute was added.
    """
    opset_version = node.version
    if opset_version is None:
        if node.domain not in opset_imports:
            logger.warning(
                "OpSet version for domain '%s' not found. Skipping node %s",
                node.domain,
                node,
            )
            return False
        opset_version = opset_imports[node.domain]

    try:
        op_schema = onnx.defs.get_schema(node.op_type, opset_version, domain=node.domain)
    except onnx.defs.SchemaError:
        logger.debug(
            "Schema not found for %s, skipping default attribute addition",
            node,
        )
        return False

    added: list[str] = []
    for attr_name, attr_def in op_schema.attributes.items():
        # Only optional attributes that the node does not set and that the schema
        # actually gives a default for are filled in.
        if (
            attr_def.required
            or attr_name in node.attributes
            or not _has_valid_default(attr_def)
        ):
            continue
        node.attributes[attr_name] = ir.serde.deserialize_attribute(attr_def.default_value)
        logger.debug("Added default attribute '%s' to node %s", attr_name, node)
        added.append(attr_name)

    return bool(added)
|
|
@@ -0,0 +1,120 @@
|
|
|
1
|
+
# Copyright (c) ONNX Project Contributors
|
|
2
|
+
# SPDX-License-Identifier: Apache-2.0
|
|
3
|
+
"""Identity elimination pass for removing redundant Identity nodes."""
|
|
4
|
+
|
|
5
|
+
from __future__ import annotations
|
|
6
|
+
|
|
7
|
+
__all__ = [
|
|
8
|
+
"IdentityEliminationPass",
|
|
9
|
+
]
|
|
10
|
+
|
|
11
|
+
import logging
|
|
12
|
+
|
|
13
|
+
import onnx_ir as ir
|
|
14
|
+
|
|
15
|
+
logger = logging.getLogger(__name__)
|
|
16
|
+
|
|
17
|
+
|
|
18
|
+
def _merge_shapes(shape1: ir.Shape | None, shape2: ir.Shape | None) -> ir.Shape | None:
    """Combine two shape annotations of the same value into the most informative one.

    ``None`` means "shape unknown": if either argument is None the other one is
    returned unchanged. Otherwise the shapes are merged dimension by dimension.
    Equal dims are kept. A concrete (non-symbolic) dim wins over a symbolic one,
    with ``shape1`` taking precedence. Between two symbolic dims, a named one wins
    over an anonymous one (``value is None``).

    Raises:
        ValueError: If both shapes are known but have different ranks.
    """
    if shape1 is None or shape2 is None:
        return shape2 if shape1 is None else shape1

    rank1 = len(shape1)
    rank2 = len(shape2)
    if rank1 != rank2:
        raise ValueError(f"Shapes must have the same rank, got {rank1} and {rank2}.")

    merged_dims = []
    for left, right in zip(shape1, shape2):
        if left == right or not isinstance(left, ir.SymbolicDim):
            chosen = left
        elif not isinstance(right, ir.SymbolicDim):
            chosen = right
        else:
            chosen = right if left.value is None else left
        merged_dims.append(chosen)
    return ir.Shape(merged_dims)
|
|
39
|
+
|
|
40
|
+
|
|
41
|
+
class IdentityEliminationPass(ir.passes.InPlacePass):
    """Pass for eliminating redundant Identity nodes.

    This pass removes Identity nodes according to the following rules:

    1. For any node of the form `y = Identity(x)`, where `y` is not an output
       of any graph, replace all uses of `y` with a use of `x`, and remove the node.
    2. If `y` is an output of a graph, and `x` is neither an input nor an output of
       a graph, we can still do the elimination, but the value `x` should be renamed
       to be `y`.
    3. If `y` is a graph-output and `x` is a graph-input or already a graph-output,
       we cannot eliminate the node: renaming `x` would change the name of an
       existing graph input/output. It should be retained.
    """

    def call(self, model: ir.Model) -> ir.passes.PassResult:
        """Main entry point for the identity elimination pass."""
        modified = False

        # Use RecursiveGraphIterator to process all nodes in the model graph and subgraphs
        for node in ir.traversal.RecursiveGraphIterator(model.graph):
            if self._try_eliminate_identity_node(node):
                modified = True

        # Process nodes in functions
        for function in model.functions.values():
            for node in ir.traversal.RecursiveGraphIterator(function):
                if self._try_eliminate_identity_node(node):
                    modified = True

        if modified:
            logger.info("Identity elimination pass modified the model")

        return ir.passes.PassResult(model, modified=modified)

    def _try_eliminate_identity_node(self, node: ir.Node) -> bool:
        """Try to eliminate a single identity node. Returns True if modified."""
        # "" and "ai.onnx" both name the default ONNX domain.
        if node.op_type != "Identity" or node.domain not in ("", "ai.onnx"):
            return False

        if len(node.inputs) != 1 or len(node.outputs) != 1:
            # Invalid Identity node, skip
            return False

        input_value = node.inputs[0]
        output_value = node.outputs[0]

        if input_value is None:
            # Cannot eliminate if input is None
            return False

        # Get the graph that contains this node
        graph_like = node.graph
        assert graph_like is not None, "Node must be in a graph"

        output_is_graph_output = output_value.is_graph_output()

        # Case 3: output is a graph output and input is a graph input, or input is
        # itself already a graph output. Renaming the input to the output's name
        # would drop or duplicate an interface name, so keep the node.
        if output_is_graph_output and (
            input_value.is_graph_input() or input_value.is_graph_output()
        ):
            return False

        # Copy over shape/type if the output has more complete information
        input_value.shape = _merge_shapes(input_value.shape, output_value.shape)
        if input_value.type is None:
            input_value.type = output_value.type

        # Case 1 & 2 (merged): Eliminate the identity node
        # Replace all uses of output with input
        ir.convenience.replace_all_uses_with(
            output_value, input_value, replace_graph_outputs=True
        )

        # If output is a graph output, we need to rename input and update graph outputs
        if output_is_graph_output:
            # Update the input value to have the output's name
            input_value.name = output_value.name

        # Remove the identity node
        graph_like.remove(node, safe=True)
        logger.debug("Eliminated identity node: %s", node)
        return True
|