onnx-ir 0.1.15__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (53)
  1. onnx_ir/__init__.py +176 -0
  2. onnx_ir/_cloner.py +229 -0
  3. onnx_ir/_convenience/__init__.py +558 -0
  4. onnx_ir/_convenience/_constructors.py +291 -0
  5. onnx_ir/_convenience/_extractor.py +191 -0
  6. onnx_ir/_core.py +4435 -0
  7. onnx_ir/_display.py +54 -0
  8. onnx_ir/_enums.py +474 -0
  9. onnx_ir/_graph_comparison.py +23 -0
  10. onnx_ir/_graph_containers.py +373 -0
  11. onnx_ir/_io.py +133 -0
  12. onnx_ir/_linked_list.py +284 -0
  13. onnx_ir/_metadata.py +45 -0
  14. onnx_ir/_name_authority.py +72 -0
  15. onnx_ir/_polyfill.py +26 -0
  16. onnx_ir/_protocols.py +627 -0
  17. onnx_ir/_safetensors/__init__.py +510 -0
  18. onnx_ir/_tape.py +242 -0
  19. onnx_ir/_thirdparty/asciichartpy.py +310 -0
  20. onnx_ir/_type_casting.py +89 -0
  21. onnx_ir/_version_utils.py +48 -0
  22. onnx_ir/analysis/__init__.py +21 -0
  23. onnx_ir/analysis/_implicit_usage.py +74 -0
  24. onnx_ir/convenience.py +38 -0
  25. onnx_ir/external_data.py +459 -0
  26. onnx_ir/passes/__init__.py +41 -0
  27. onnx_ir/passes/_pass_infra.py +351 -0
  28. onnx_ir/passes/common/__init__.py +54 -0
  29. onnx_ir/passes/common/_c_api_utils.py +76 -0
  30. onnx_ir/passes/common/clear_metadata_and_docstring.py +60 -0
  31. onnx_ir/passes/common/common_subexpression_elimination.py +207 -0
  32. onnx_ir/passes/common/constant_manipulation.py +230 -0
  33. onnx_ir/passes/common/default_attributes.py +99 -0
  34. onnx_ir/passes/common/identity_elimination.py +120 -0
  35. onnx_ir/passes/common/initializer_deduplication.py +179 -0
  36. onnx_ir/passes/common/inliner.py +223 -0
  37. onnx_ir/passes/common/naming.py +280 -0
  38. onnx_ir/passes/common/onnx_checker.py +57 -0
  39. onnx_ir/passes/common/output_fix.py +141 -0
  40. onnx_ir/passes/common/shape_inference.py +112 -0
  41. onnx_ir/passes/common/topological_sort.py +37 -0
  42. onnx_ir/passes/common/unused_removal.py +215 -0
  43. onnx_ir/py.typed +1 -0
  44. onnx_ir/serde.py +2043 -0
  45. onnx_ir/tape.py +15 -0
  46. onnx_ir/tensor_adapters.py +210 -0
  47. onnx_ir/testing.py +197 -0
  48. onnx_ir/traversal.py +118 -0
  49. onnx_ir-0.1.15.dist-info/METADATA +68 -0
  50. onnx_ir-0.1.15.dist-info/RECORD +53 -0
  51. onnx_ir-0.1.15.dist-info/WHEEL +5 -0
  52. onnx_ir-0.1.15.dist-info/licenses/LICENSE +202 -0
  53. onnx_ir-0.1.15.dist-info/top_level.txt +1 -0
onnx_ir/passes/common/common_subexpression_elimination.py
@@ -0,0 +1,207 @@
+ # Copyright (c) ONNX Project Contributors
+ # SPDX-License-Identifier: Apache-2.0
+ """Eliminate common subexpressions in ONNX graphs."""
+
+ from __future__ import annotations
+
+ __all__ = [
+     "CommonSubexpressionEliminationPass",
+ ]
+
+ import logging
+ from collections.abc import Sequence
+
+ import onnx_ir as ir
+
+ logger = logging.getLogger(__name__)
+
+
+ class CommonSubexpressionEliminationPass(ir.passes.InPlacePass):
+     """Eliminate common subexpressions in ONNX graphs.
+
+     .. versionadded:: 0.1.1
+
+     .. versionchanged:: 0.1.3
+         Constant nodes with values smaller than ``size_limit`` will be CSE'd.
+
+     Attributes:
+         size_limit: The maximum size of a tensor to be CSE'd. If the tensor contains
+             more elements than size_limit, it will not be CSE'd. Default is 10.
+
+     """
+
+     def __init__(self, size_limit: int = 10):
+         """Initialize the CommonSubexpressionEliminationPass."""
+         super().__init__()
+         self.size_limit = size_limit
+
+     def call(self, model: ir.Model) -> ir.passes.PassResult:
+         """Return the same ir.Model but with CSE applied to the graph."""
+         graph = model.graph
+         modified = self._eliminate_common_subexpression(graph)
+
+         return ir.passes.PassResult(
+             model,
+             modified=modified,
+         )
+
+     def _eliminate_common_subexpression(self, graph: ir.Graph) -> bool:
+         """Eliminate common subexpressions in ONNX graphs."""
+         modified: bool = False
+         # Map from (node identifier, number of outputs, input ids, attributes) to the node
+         existing_node_info_to_the_node: dict[
+             tuple[
+                 ir.OperatorIdentifier,
+                 int,  # len(outputs)
+                 tuple[int, ...],  # input ids
+                 tuple[tuple[str, object], ...],  # attributes
+             ],
+             ir.Node,
+         ] = {}
+
+         for node in graph:
+             # Skip control flow ops like Loop and If.
+             control_flow_op: bool = False
+             # Skip large tensors to avoid CSE'ing weights and biases.
+             large_tensor: bool = False
+             # Use equality to check if the node is a common subexpression.
+             attributes = {}
+             for k, v in node.attributes.items():
+                 # TODO(exporter team): CSE subgraphs.
+                 # NOTE: Control flow ops like Loop and If won't be CSE'd
+                 # because the graph attribute won't match.
+                 if v.type in (ir.AttributeType.GRAPH, ir.AttributeType.GRAPHS):
+                     control_flow_op = True
+                     break
+                 # The attribute value could be directly taken from the original
+                 # protobuf, so we need to make a copy of it.
+                 value = v.value
+                 if v.type in (
+                     ir.AttributeType.INTS,
+                     ir.AttributeType.FLOATS,
+                     ir.AttributeType.STRINGS,
+                 ):
+                     # For INTS, FLOATS and STRINGS attributes, we convert them to
+                     # tuples to ensure they are hashable.
+                     value = tuple(value)
+                 elif v.type is ir.AttributeType.TENSOR:
+                     if value.size > self.size_limit:
+                         # If the tensor is larger than the size limit, we skip it.
+                         large_tensor = True
+                         break
+                     np_value = value.numpy()
+
+                     value = (np_value.shape, str(np_value.dtype), np_value.tobytes())
+                 attributes[k] = value
+
+             if control_flow_op:
+                 # If the node is a control flow op, we skip it.
+                 logger.debug("Skipping control flow op %s", node)
+                 continue
+
+             if large_tensor:
+                 # If the node has a large tensor, we skip it.
+                 logger.debug("Skipping large tensor in node %s", node)
+                 continue
+
+             if _is_non_deterministic_op(node):
+                 # If the node is a non-deterministic op, we skip it.
+                 logger.debug("Skipping non-deterministic op %s", node)
+                 continue
+
+             node_info = (
+                 node.op_identifier(),
+                 len(node.outputs),
+                 tuple(id(input) for input in node.inputs),
+                 tuple(sorted(attributes.items())),
+             )
+             # Check if the node is a common subexpression.
+             if node_info in existing_node_info_to_the_node:
+                 # If it is, an existing node with the same operator,
+                 # number of outputs, inputs, and attributes already exists.
+                 # We replace this node with the existing node.
+                 modified = True
+                 existing_node = existing_node_info_to_the_node[node_info]
+                 _remove_node_and_replace_values(
+                     graph,
+                     remove_node=node,
+                     remove_values=node.outputs,
+                     new_values=existing_node.outputs,
+                 )
+                 logger.debug("Reusing node %s", existing_node)
+             else:
+                 # If it is not, add it to the mapping.
+                 existing_node_info_to_the_node[node_info] = node
+         return modified
+
+
+ def _remove_node_and_replace_values(
+     graph: ir.Graph,
+     /,
+     remove_node: ir.Node,
+     remove_values: Sequence[ir.Value],
+     new_values: Sequence[ir.Value],
+ ) -> None:
+     """Replace nodes and values in the graph or function.
+
+     Args:
+         graph: The graph to replace nodes and values in.
+         remove_node: The node to remove.
+         remove_values: The values to replace.
+         new_values: The values to replace with.
+     """
+     # Update graph/function outputs if the node generates an output
+     if any(remove_value.is_graph_output() for remove_value in remove_values):
+         replacement_mapping = dict(zip(remove_values, new_values))
+         for idx, graph_output in enumerate(graph.outputs):
+             if graph_output in replacement_mapping:
+                 new_value = replacement_mapping[graph_output]
+                 if new_value.is_graph_output() or new_value.is_graph_input():
+                     # If the new value is also a graph input/output, we need to
+                     # create an Identity node to preserve the remove_value and
+                     # prevent changing the new_value name.
+                     identity_node = ir.node(
+                         "Identity",
+                         inputs=[new_value],
+                         outputs=[
+                             ir.Value(
+                                 name=graph_output.name,
+                                 type=graph_output.type,
+                                 shape=graph_output.shape,
+                             )
+                         ],
+                     )
+                     # Reuse the name of the graph output
+                     graph.outputs[idx] = identity_node.outputs[0]
+                     graph.insert_before(
+                         remove_node,
+                         identity_node,
+                     )
+                 else:
+                     # If new_value is not a graph output, we just
+                     # update it to use the old value's name.
+                     new_value.name = graph_output.name
+                     graph.outputs[idx] = new_value
+
+     # Reconnect the users of the deleted values to use the new values
+     ir.convenience.replace_all_uses_with(remove_values, new_values)
+
+     graph.remove(remove_node, safe=True)
+
+
+ def _is_non_deterministic_op(node: ir.Node) -> bool:
+     non_deterministic_ops = frozenset(
+         {
+             "RandomUniform",
+             "RandomNormal",
+             "RandomUniformLike",
+             "RandomNormalLike",
+             "Multinomial",
+         }
+     )
+     return node.op_type in non_deterministic_ops and _is_onnx_domain(node.domain)
+
+
+ def _is_onnx_domain(d: str) -> bool:
+     """Check if the domain is the ONNX domain."""
+     return d == ""
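
A minimal usage sketch for this pass (the file paths are placeholders; ir.load and ir.save are the I/O helpers from onnx_ir/_io.py listed above, and a pass instance is applied by calling it on a model):

    import onnx_ir as ir
    from onnx_ir.passes.common.common_subexpression_elimination import (
        CommonSubexpressionEliminationPass,
    )

    model = ir.load("model.onnx")  # placeholder input path
    result = CommonSubexpressionEliminationPass(size_limit=10)(model)
    print("CSE modified the graph:", result.modified)
    ir.save(result.model, "model_cse.onnx")  # placeholder output path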
onnx_ir/passes/common/constant_manipulation.py
@@ -0,0 +1,230 @@
+ # Copyright (c) ONNX Project Contributors
+ # SPDX-License-Identifier: Apache-2.0
+ """Lift constants to initializers."""
+
+ from __future__ import annotations
+
+ __all__ = [
+     "AddInitializersToInputsPass",
+     "LiftConstantsToInitializersPass",
+     "LiftSubgraphInitializersToMainGraphPass",
+     "RemoveInitializersFromInputsPass",
+ ]
+
+ import logging
+
+ import numpy as np
+
+ import onnx_ir as ir
+
+ logger = logging.getLogger(__name__)
+
+
+ class LiftConstantsToInitializersPass(ir.passes.InPlacePass):
+     """Lift constants to initializers.
+
+     Attributes:
+         lift_all_constants: Whether to lift all Constant nodes, including those that do not contain a tensor attribute (e.g. with value_ints etc.).
+             Defaults to False, where only Constants with the ``value`` attribute are lifted.
+         size_limit: The minimum size of the tensor to be lifted. If the tensor contains
+             fewer elements than size_limit, it will not be lifted. Default is 16.
+     """
+
+     def __init__(self, lift_all_constants: bool = False, size_limit: int = 16):
+         super().__init__()
+         self.lift_all_constants = lift_all_constants
+         self.size_limit = size_limit
+
+     def call(self, model: ir.Model) -> ir.passes.PassResult:
+         count = 0
+         for node in ir.traversal.RecursiveGraphIterator(model.graph):
+             assert node.graph is not None
+             if node.op_type != "Constant" or node.domain not in ("", "onnx.ai"):
+                 continue
+             if node.outputs[0].is_graph_output():
+                 logger.debug(
+                     "Constant node '%s' is used as output, so it can't be lifted.", node.name
+                 )
+                 continue
+             constant_node_attribute = set(node.attributes.keys())
+             if len(constant_node_attribute) != 1:
+                 logger.debug(
+                     "Invalid constant node '%s' has more than one attribute", node.name
+                 )
+                 continue
+
+             attr_name, attr_value = next(iter(node.attributes.items()))
+             initializer_name = node.outputs[0].name
+             assert initializer_name is not None
+             assert isinstance(attr_value, ir.Attr)
+             tensor = self._constant_node_attribute_to_tensor(
+                 node, attr_name, attr_value, initializer_name
+             )
+             if tensor is None:
+                 # The reason for None is logged in _constant_node_attribute_to_tensor
+                 continue
+             # Register an initializer with the tensor value
+             initializer = ir.Value(
+                 name=initializer_name,
+                 shape=tensor.shape,  # type: ignore[arg-type]
+                 type=ir.TensorType(tensor.dtype),
+                 const_value=tensor,
+                 # Preserve metadata from the Constant value in the ONNX model
+                 metadata_props=node.outputs[0].metadata_props.copy(),
+             )
+             # Preserve value meta from the Constant output for intermediate analysis
+             initializer.meta.update(node.outputs[0].meta)
+
+             assert node.graph is not None
+             node.graph.register_initializer(initializer)
+             # Replace the constant node with the initializer
+             node.outputs[0].replace_all_uses_with(initializer)
+             node.graph.remove(node, safe=True)
+             count += 1
+             logger.debug(
+                 "Converted constant node '%s' to initializer '%s'", node.name, initializer_name
+             )
+         if count:
+             logger.debug("Lifted %s constants to initializers", count)
+         return ir.passes.PassResult(model, modified=bool(count))
+
+     def _constant_node_attribute_to_tensor(
+         self, node, attr_name: str, attr_value: ir.Attr, initializer_name: str
+     ) -> ir.TensorProtocol | None:
+         """Convert a constant node attribute to a tensor."""
+         if not self.lift_all_constants and attr_name != "value":
+             logger.debug(
+                 "Constant node '%s' has non-tensor attribute '%s'", node.name, attr_name
+             )
+             return None
+
+         tensor: ir.TensorProtocol
+         if attr_name == "value":
+             tensor = attr_value.as_tensor()
+         elif attr_name == "value_int":
+             tensor = ir.tensor(
+                 attr_value.as_int(), dtype=ir.DataType.INT64, name=initializer_name
+             )
+         elif attr_name == "value_ints":
+             tensor = ir.tensor(
+                 attr_value.as_ints(), dtype=ir.DataType.INT64, name=initializer_name
+             )
+         elif attr_name == "value_float":
+             tensor = ir.tensor(
+                 attr_value.as_float(), dtype=ir.DataType.FLOAT, name=initializer_name
+             )
+         elif attr_name == "value_floats":
+             tensor = ir.tensor(
+                 attr_value.as_floats(), dtype=ir.DataType.FLOAT, name=initializer_name
+             )
+         elif attr_name in ("value_string", "value_strings"):
+             tensor = ir.StringTensor(
+                 np.array(attr_value.value, dtype=np.bytes_), name=initializer_name
+             )
+         else:
+             raise ValueError(
+                 f"Unsupported constant node '{node.name}' attribute '{attr_name}'"
+             )
+
+         if tensor.size < self.size_limit:
+             logger.debug(
+                 "Tensor from node '%s' has fewer than %s elements",
+                 node.name,
+                 self.size_limit,
+             )
+             return None
+         return tensor
+
+
+ class LiftSubgraphInitializersToMainGraphPass(ir.passes.InPlacePass):
+     """Lift subgraph initializers to the main graph.
+
+     This pass lifts the initializers of a subgraph to the main graph.
+     It is used to ensure that the initializers are available in the main graph
+     for further processing or optimization.
+
+     Initializers that are also graph inputs will not be lifted.
+     """
+
+     def call(self, model: ir.Model) -> ir.passes.PassResult:
+         count = 0
+         registered_initializer_names: dict[str, int] = {}
+         for graph in model.graphs():
+             if graph is model.graph:
+                 continue
+             for name in tuple(graph.initializers):
+                 assert name is not None
+                 initializer = graph.initializers[name]
+                 if initializer.is_graph_input():
+                     # Skip the ones that are also graph inputs
+                     logger.debug(
+                         "Initializer '%s' is also a graph input, so it can't be lifted",
+                         initializer.name,
+                     )
+                     continue
+                 if initializer.is_graph_output():
+                     logger.debug(
+                         "Initializer '%s' is used as output, so it can't be lifted",
+                         initializer.name,
+                     )
+                     continue
+                 # Remove the initializer from the subgraph
+                 graph.initializers.pop(name)
+                 # To avoid name conflicts, we need to rename the initializer
+                 # to a unique name in the main graph
+                 new_name = name
+                 while new_name in model.graph.initializers:
+                     if name in registered_initializer_names:
+                         registered_initializer_names[name] += 1
+                     else:
+                         registered_initializer_names[name] = 1
+                     new_name = f"{name}_{registered_initializer_names[name]}"
+                 initializer.name = new_name
+                 model.graph.register_initializer(initializer)
+                 count += 1
+                 logger.debug(
+                     "Lifted initializer '%s' from subgraph '%s' to main graph",
+                     initializer.name,
+                     graph.name,
+                 )
+         return ir.passes.PassResult(model, modified=bool(count))
+
+
+ class RemoveInitializersFromInputsPass(ir.passes.InPlacePass):
+     """Remove initializers from inputs.
+
+     This pass finds all graph inputs that are also initializers and removes them from the graph.inputs list.
+     """
+
+     def call(self, model: ir.Model) -> ir.passes.PassResult:
+         count = 0
+         for graph in model.graphs():
+             initializers = set(graph.initializers.values())
+             new_inputs = []
+             for input_value in graph.inputs:
+                 if input_value in initializers:
+                     count += 1
+                 else:
+                     new_inputs.append(input_value)
+             graph.inputs.clear()
+             graph.inputs.extend(new_inputs)
+         logger.info("Removed %s initializers from graph inputs", count)
+         return ir.passes.PassResult(model, modified=bool(count))
+
+
+ class AddInitializersToInputsPass(ir.passes.InPlacePass):
+     """Add initializers to inputs.
+
+     This pass finds all initializers and adds them to the graph.inputs list if they are not already present.
+     """
+
+     def call(self, model: ir.Model) -> ir.passes.PassResult:
+         count = 0
+         for graph in model.graphs():
+             inputs_set = set(graph.inputs)
+             for initializer in graph.initializers.values():
+                 if initializer not in inputs_set:
+                     graph.inputs.append(initializer)
+                     count += 1
+         logger.info("Added %s initializers to graph inputs", count)
+         return ir.passes.PassResult(model, modified=bool(count))
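
The four passes in this file compose into a typical cleanup pipeline. A sketch, assuming ir.passes.PassManager from _pass_infra.py above chains passes and returns a combined PassResult (the file path is a placeholder):

    import onnx_ir as ir
    from onnx_ir.passes.common.constant_manipulation import (
        LiftConstantsToInitializersPass,
        LiftSubgraphInitializersToMainGraphPass,
        RemoveInitializersFromInputsPass,
    )

    model = ir.load("model.onnx")  # placeholder input path
    pipeline = ir.passes.PassManager(
        [
            # Turn small Constant nodes into initializers first...
            LiftConstantsToInitializersPass(lift_all_constants=True, size_limit=16),
            # ...then hoist subgraph initializers into the main graph...
            LiftSubgraphInitializersToMainGraphPass(),
            # ...and drop initializers that are still listed as graph inputs.
            RemoveInitializersFromInputsPass(),
        ]
    )
    result = pipeline(model)
    print("pipeline modified the model:", result.modified)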
onnx_ir/passes/common/default_attributes.py
@@ -0,0 +1,99 @@
+ # Copyright (c) ONNX Project Contributors
+ # SPDX-License-Identifier: Apache-2.0
+ """Add default attributes to nodes that are missing optional attributes."""
+
+ from __future__ import annotations
+
+ __all__ = [
+     "AddDefaultAttributesPass",
+ ]
+
+ import logging
+
+ import onnx  # noqa: TID251
+
+ import onnx_ir as ir
+
+ logger = logging.getLogger(__name__)
+
+
+ def _has_valid_default(attr_def: onnx.defs.OpSchema.Attribute) -> bool:
+     """Check if an attribute definition has a valid default value."""
+     return bool(
+         attr_def.default_value and attr_def.default_value.type != onnx.AttributeProto.UNDEFINED
+     )
+
+
+ class AddDefaultAttributesPass(ir.passes.InPlacePass):
+     """Add default values for optional attributes that are not present in nodes.
+
+     This pass iterates through all nodes in the model and for each node:
+     1. Gets the ONNX schema for the operator
+     2. Iterates over each optional attribute that has a default value in the schema
+     3. Adds the attribute with its default value if it is not present in the node
+     """
+
+     def call(self, model: ir.Model) -> ir.passes.PassResult:
+         """Main entry point for the add default attributes pass."""
+         modified = False
+
+         # Process all nodes in the model graph and subgraphs
+         for node in ir.traversal.RecursiveGraphIterator(model.graph):
+             if _add_default_attributes_to_node(node, model.graph.opset_imports):
+                 modified = True
+
+         # Process nodes in functions
+         for function in model.functions.values():
+             for node in ir.traversal.RecursiveGraphIterator(function):
+                 if _add_default_attributes_to_node(node, model.graph.opset_imports):
+                     modified = True
+
+         if modified:
+             logger.info("AddDefaultAttributes pass modified the model")
+
+         return ir.passes.PassResult(model, modified=modified)
+
+
+ def _add_default_attributes_to_node(node: ir.Node, opset_imports: dict[str, int]) -> bool:
+     """Add default attributes to a single node. Returns True if modified."""
+     # Get the operator schema
+     if node.version is not None:
+         opset_version = node.version
+     elif node.domain in opset_imports:
+         opset_version = opset_imports[node.domain]
+     else:
+         logger.warning(
+             "OpSet version for domain '%s' not found. Skipping node %s",
+             node.domain,
+             node,
+         )
+         return False
+
+     try:
+         op_schema = onnx.defs.get_schema(node.op_type, opset_version, domain=node.domain)
+     except onnx.defs.SchemaError:
+         logger.debug(
+             "Schema not found for %s, skipping default attribute addition",
+             node,
+         )
+         return False
+
+     modified = False
+     # Iterate through all attributes in the schema
+     for attr_name, attr_def in op_schema.attributes.items():
+         # Skip if the attribute is required or already present in the node
+         if attr_def.required or attr_name in node.attributes:
+             continue
+
+         # Skip if the attribute doesn't have a default value
+         if not _has_valid_default(attr_def):
+             continue
+
+         # Create an IR Attr from the ONNX AttributeProto default value
+         default_attr_proto = attr_def.default_value
+         default_attr = ir.serde.deserialize_attribute(default_attr_proto)
+         node.attributes[attr_name] = default_attr
+         logger.debug("Added default attribute '%s' to node %s", attr_name, node)
+         modified = True
+
+     return modified
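
A minimal sketch of running this pass (placeholder path; the Conv example assumes the standard ONNX schema, where auto_pad defaults to "NOTSET"):

    import onnx_ir as ir
    from onnx_ir.passes.common.default_attributes import AddDefaultAttributesPass

    model = ir.load("model.onnx")  # placeholder input path
    result = AddDefaultAttributesPass()(model)
    # Optional attributes that were omitted now carry their schema defaults,
    # e.g. a Conv node gains auto_pad="NOTSET" if it was absent.
    print("defaults added:", result.modified)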
onnx_ir/passes/common/identity_elimination.py
@@ -0,0 +1,120 @@
+ # Copyright (c) ONNX Project Contributors
+ # SPDX-License-Identifier: Apache-2.0
+ """Identity elimination pass for removing redundant Identity nodes."""
+
+ from __future__ import annotations
+
+ __all__ = [
+     "IdentityEliminationPass",
+ ]
+
+ import logging
+
+ import onnx_ir as ir
+
+ logger = logging.getLogger(__name__)
+
+
+ def _merge_shapes(shape1: ir.Shape | None, shape2: ir.Shape | None) -> ir.Shape | None:
+     def merge_dims(dim1, dim2):
+         if dim1 == dim2:
+             return dim1
+         if not isinstance(dim1, ir.SymbolicDim):
+             return dim1  # Prefer int value over symbolic dim
+         if not isinstance(dim2, ir.SymbolicDim):
+             return dim2
+         if dim1.value is None:
+             return dim2
+         return dim1
+
+     if shape1 is None:
+         return shape2
+     if shape2 is None:
+         return shape1
+     if len(shape1) != len(shape2):
+         raise ValueError(
+             f"Shapes must have the same rank, got {len(shape1)} and {len(shape2)}."
+         )
+     return ir.Shape([merge_dims(dim1, dim2) for dim1, dim2 in zip(shape1, shape2)])
+
+
+ class IdentityEliminationPass(ir.passes.InPlacePass):
+     """Pass for eliminating redundant Identity nodes.
+
+     This pass removes Identity nodes according to the following rules:
+
+     1. For any node of the form `y = Identity(x)`, where `y` is not an output
+        of any graph, replace all uses of `y` with a use of `x`, and remove the node.
+     2. If `y` is an output of a graph, and `x` is not an input of any graph,
+        we can still do the elimination, but the value `x` should be renamed to `y`.
+     3. If `y` is a graph output and `x` is a graph input, we cannot eliminate
+        the node. It should be retained.
+     """
+
+     def call(self, model: ir.Model) -> ir.passes.PassResult:
+         """Main entry point for the identity elimination pass."""
+         modified = False
+
+         # Use RecursiveGraphIterator to process all nodes in the model graph and subgraphs
+         for node in ir.traversal.RecursiveGraphIterator(model.graph):
+             if self._try_eliminate_identity_node(node):
+                 modified = True
+
+         # Process nodes in functions
+         for function in model.functions.values():
+             for node in ir.traversal.RecursiveGraphIterator(function):
+                 if self._try_eliminate_identity_node(node):
+                     modified = True
+
+         if modified:
+             logger.info("Identity elimination pass modified the model")
+
+         return ir.passes.PassResult(model, modified=modified)
+
+     def _try_eliminate_identity_node(self, node: ir.Node) -> bool:
+         """Try to eliminate a single Identity node. Returns True if modified."""
+         if node.op_type != "Identity" or node.domain != "":
+             return False
+
+         if len(node.inputs) != 1 or len(node.outputs) != 1:
+             # Invalid Identity node, skip
+             return False
+
+         input_value = node.inputs[0]
+         output_value = node.outputs[0]
+
+         if input_value is None:
+             # Cannot eliminate if the input is None
+             return False
+
+         # Get the graph that contains this node
+         graph_like = node.graph
+         assert graph_like is not None, "Node must be in a graph"
+
+         output_is_graph_output = output_value.is_graph_output()
+         input_is_graph_input = input_value.is_graph_input()
+
+         # Case 3: the output is a graph output and the input is a graph input - keep the node
+         if output_is_graph_output and input_is_graph_input:
+             return False
+
+         # Copy over shape/type if the output has more complete information
+         input_value.shape = _merge_shapes(input_value.shape, output_value.shape)
+         if input_value.type is None:
+             input_value.type = output_value.type
+
+         # Cases 1 & 2 (merged): eliminate the Identity node by
+         # replacing all uses of the output with the input
+         ir.convenience.replace_all_uses_with(
+             output_value, input_value, replace_graph_outputs=True
+         )
+
+         # If the output is a graph output, rename the input and update graph outputs
+         if output_is_graph_output:
+             # Update the input value to have the output's name
+             input_value.name = output_value.name
+
+         # Remove the identity node
+         graph_like.remove(node, safe=True)
+         logger.debug("Eliminated identity node: %s", node)
+         return True
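
A minimal usage sketch (placeholder paths as before):

    import onnx_ir as ir
    from onnx_ir.passes.common.identity_elimination import IdentityEliminationPass

    model = ir.load("model.onnx")  # placeholder input path
    result = IdentityEliminationPass()(model)
    if result.modified:
        # Redundant Identity nodes were removed (rules 1 and 2 above); Identity
        # nodes that bridge a graph input to a graph output remain (rule 3).
        ir.save(result.model, "model_no_identity.onnx")  # placeholder output path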