onnx-ir 0.1.15__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (53)
  1. onnx_ir/__init__.py +176 -0
  2. onnx_ir/_cloner.py +229 -0
  3. onnx_ir/_convenience/__init__.py +558 -0
  4. onnx_ir/_convenience/_constructors.py +291 -0
  5. onnx_ir/_convenience/_extractor.py +191 -0
  6. onnx_ir/_core.py +4435 -0
  7. onnx_ir/_display.py +54 -0
  8. onnx_ir/_enums.py +474 -0
  9. onnx_ir/_graph_comparison.py +23 -0
  10. onnx_ir/_graph_containers.py +373 -0
  11. onnx_ir/_io.py +133 -0
  12. onnx_ir/_linked_list.py +284 -0
  13. onnx_ir/_metadata.py +45 -0
  14. onnx_ir/_name_authority.py +72 -0
  15. onnx_ir/_polyfill.py +26 -0
  16. onnx_ir/_protocols.py +627 -0
  17. onnx_ir/_safetensors/__init__.py +510 -0
  18. onnx_ir/_tape.py +242 -0
  19. onnx_ir/_thirdparty/asciichartpy.py +310 -0
  20. onnx_ir/_type_casting.py +89 -0
  21. onnx_ir/_version_utils.py +48 -0
  22. onnx_ir/analysis/__init__.py +21 -0
  23. onnx_ir/analysis/_implicit_usage.py +74 -0
  24. onnx_ir/convenience.py +38 -0
  25. onnx_ir/external_data.py +459 -0
  26. onnx_ir/passes/__init__.py +41 -0
  27. onnx_ir/passes/_pass_infra.py +351 -0
  28. onnx_ir/passes/common/__init__.py +54 -0
  29. onnx_ir/passes/common/_c_api_utils.py +76 -0
  30. onnx_ir/passes/common/clear_metadata_and_docstring.py +60 -0
  31. onnx_ir/passes/common/common_subexpression_elimination.py +207 -0
  32. onnx_ir/passes/common/constant_manipulation.py +230 -0
  33. onnx_ir/passes/common/default_attributes.py +99 -0
  34. onnx_ir/passes/common/identity_elimination.py +120 -0
  35. onnx_ir/passes/common/initializer_deduplication.py +179 -0
  36. onnx_ir/passes/common/inliner.py +223 -0
  37. onnx_ir/passes/common/naming.py +280 -0
  38. onnx_ir/passes/common/onnx_checker.py +57 -0
  39. onnx_ir/passes/common/output_fix.py +141 -0
  40. onnx_ir/passes/common/shape_inference.py +112 -0
  41. onnx_ir/passes/common/topological_sort.py +37 -0
  42. onnx_ir/passes/common/unused_removal.py +215 -0
  43. onnx_ir/py.typed +1 -0
  44. onnx_ir/serde.py +2043 -0
  45. onnx_ir/tape.py +15 -0
  46. onnx_ir/tensor_adapters.py +210 -0
  47. onnx_ir/testing.py +197 -0
  48. onnx_ir/traversal.py +118 -0
  49. onnx_ir-0.1.15.dist-info/METADATA +68 -0
  50. onnx_ir-0.1.15.dist-info/RECORD +53 -0
  51. onnx_ir-0.1.15.dist-info/WHEEL +5 -0
  52. onnx_ir-0.1.15.dist-info/licenses/LICENSE +202 -0
  53. onnx_ir-0.1.15.dist-info/top_level.txt +1 -0
@@ -0,0 +1,179 @@
1
+ # Copyright (c) ONNX Project Contributors
2
+ # SPDX-License-Identifier: Apache-2.0
3
+ """Pass for removing duplicated initializer tensors from a graph."""
4
+
5
+ from __future__ import annotations
6
+
7
+ __all__ = ["DeduplicateInitializersPass", "DeduplicateHashedInitializersPass"]
8
+
9
+
10
+ import hashlib
11
+ import logging
12
+
13
+ import numpy as np
14
+
15
+ import onnx_ir as ir
16
+
17
+ logger = logging.getLogger(__name__)
18
+
19
+
20
def _should_skip_initializer(initializer: ir.Value, size_limit: int) -> bool:
    """Return True when *initializer* must be excluded from deduplication.

    An initializer is skipped when it doubles as a graph input or output,
    when it carries no constant value, or when its element count exceeds
    ``size_limit``.
    """
    if initializer.is_graph_input() or initializer.is_graph_output():
        # Removing or merging these would change the graph's public interface.
        logger.warning(
            "Skipped deduplication of initializer '%s' as it is a graph input or output",
            initializer.name,
        )
        return True

    tensor = initializer.const_value
    if tensor is None:
        # Without concrete data there is nothing to compare.
        logger.warning(
            "Skipped deduplication of initializer '%s' as it has no constant value. The model may contain invalid initializers",
            initializer.name,
        )
        return True

    if tensor.size > size_limit:
        # Deliberately leave large tensors alone (comparison cost / memory).
        logger.debug(
            "Skipped initializer '%s' as it exceeds the size limit of %d elements",
            initializer.name,
            size_limit,
        )
        return True

    return False
48
+
49
+
50
def _tobytes(val):
    """Serialize a tensor's content into a hashable ``bytes`` key.

    StringTensor does not support ``tobytes``; its content is read through
    ``string_data()`` instead. Since ``string_data`` yields a list of bytes
    (unhashable, so unusable as a dict key), the list is converted to a numpy
    array of fixed-width strings and flattened with ``tobytes``. Together with
    the tensor shape this is sufficient to identify tensors for deduplication.
    Note it differs from the serialized representation (``string_data``): each
    string is padded to the same width in the flattened output.
    """
    if not val.dtype.is_string():
        # Plain numeric tensors serialize directly.
        return val.tobytes()
    # Fixed-width string array -> single contiguous, hashable bytes object.
    return np.array(val.string_data()).tobytes()
66
+
67
+
68
class DeduplicateInitializersPass(ir.passes.InPlacePass):
    """Remove duplicated initializer tensors from the main graph and all subgraphs.

    This pass detects initializers with identical shape, dtype, and content,
    and replaces all duplicate references with a canonical one.

    Initializers are deduplicated within each graph. To deduplicate initializers
    in the model globally (across graphs), use :class:`~onnx_ir.passes.common.LiftSubgraphInitializersToMainGraphPass`
    to lift the initializers to the main graph first before running pass.

    .. versionadded:: 0.1.3
    .. versionchanged:: 0.1.7
        This pass now deduplicates initializers in subgraphs as well.
    """

    def __init__(self, size_limit: int = 1024):
        super().__init__()
        # Initializers with more than this many elements are not deduplicated.
        self.size_limit = size_limit

    def call(self, model: ir.Model) -> ir.passes.PassResult:
        """Deduplicate initializers graph-by-graph; returns whether anything changed."""
        modified = False

        for graph in model.graphs():
            # Maps (dtype, shape, raw bytes) -> the canonical initializer kept.
            seen: dict[tuple[ir.DataType, tuple[int, ...], bytes], ir.Value] = {}
            # Snapshot the values first: entries are popped from
            # graph.initializers inside the loop.
            for candidate in tuple(graph.initializers.values()):
                if _should_skip_initializer(candidate, self.size_limit):
                    continue

                tensor = candidate.const_value
                assert tensor is not None

                key = (tensor.dtype, tuple(tensor.shape), _tobytes(tensor))
                canonical = seen.get(key)
                if canonical is None:
                    # First occurrence becomes the canonical initializer.
                    seen[key] = candidate  # type: ignore[index]
                    continue

                modified = True
                candidate.replace_all_uses_with(canonical)
                assert candidate.name is not None
                graph.initializers.pop(candidate.name)
                logger.info(
                    "Replaced initializer '%s' with existing initializer '%s'",
                    candidate.name,
                    canonical.name,
                )

        return ir.passes.PassResult(model=model, modified=modified)
115
+
116
+
117
class DeduplicateHashedInitializersPass(ir.passes.InPlacePass):
    """Remove duplicated initializer tensors (using a hashed method) from the graph.

    This pass detects initializers with identical shape, dtype, and hashed content,
    and replaces all duplicate references with a canonical one.

    This pass should have a lower peak memory usage than :class:`DeduplicateInitializersPass`
    as it does not store the full tensor data in memory, but instead uses a hash of the tensor data.

    .. versionadded:: 0.1.7
    """

    def __init__(self, size_limit: int = 4 * 1024 * 1024 * 1024):
        super().__init__()
        # 4 GB default size limit for deduplication
        self.size_limit = size_limit

    def call(self, model: ir.Model) -> ir.passes.PassResult:
        """Deduplicate initializers per graph, keyed by (dtype, dims, sha512 digest)."""
        modified = False

        for graph in model.graphs():
            # Maps (dtype, dims, hex digest) -> canonical initializer kept so far.
            initializers: dict[tuple[ir.DataType, tuple[int, ...], str], ir.Value] = {}

            # Snapshot: entries are popped from graph.initializers inside the loop.
            for initializer in tuple(graph.initializers.values()):
                if _should_skip_initializer(initializer, self.size_limit):
                    continue

                const_val = initializer.const_value
                assert const_val is not None

                # Hash tensor data to avoid storing large amounts of data in memory
                # NOTE(review): hashlib consumes the numpy array through the buffer
                # protocol — presumably the array is contiguous and non-string here
                # (an object-dtype string array would not expose a buffer). Confirm
                # how string tensors are expected to flow through this pass.
                hashed = hashlib.sha512()
                tensor_data = const_val.numpy()
                hashed.update(tensor_data)
                tensor_digest = hashed.hexdigest()

                # NOTE(review): shape is read via shape.numpy() here, while
                # DeduplicateInitializersPass uses tuple(const_val.shape) — assumed
                # equivalent for static shapes; verify against ir.Shape.
                tensor_dims = tuple(const_val.shape.numpy())

                key = (const_val.dtype, tensor_dims, tensor_digest)

                if key in initializers:
                    # Guard against hash collisions: on a digest match, compare the
                    # actual bytes before merging; if they differ, keep both.
                    if _tobytes(initializers[key].const_value) != _tobytes(const_val):
                        logger.warning(
                            "Initializer deduplication failed: "
                            "hashes match but values differ with values %s and %s",
                            initializers[key],
                            initializer,
                        )
                        continue
                    modified = True
                    initializer_to_keep = initializers[key]  # type: ignore[index]
                    initializer.replace_all_uses_with(initializer_to_keep)
                    assert initializer.name is not None
                    graph.initializers.pop(initializer.name)
                    logger.info(
                        "Replaced initializer '%s' with existing initializer '%s'",
                        initializer.name,
                        initializer_to_keep.name,
                    )
                else:
                    # First occurrence becomes the canonical initializer.
                    initializers[key] = initializer  # type: ignore[index]

        return ir.passes.PassResult(model=model, modified=modified)
@@ -0,0 +1,223 @@
1
+ # Copyright (c) ONNX Project Contributors
2
+ # SPDX-License-Identifier: Apache-2.0
3
+ """Implementation of an inliner for onnx_ir."""
4
+
5
+ from __future__ import annotations
6
+
7
+ import dataclasses
8
+
9
+ __all__ = ["InlinePass", "InlinePassResult"]
10
+
11
+ from collections import defaultdict
12
+ from collections.abc import Iterable, Mapping, Sequence
13
+
14
+ import onnx_ir as ir
15
+ import onnx_ir.convenience as _ir_convenience
16
+ from onnx_ir import _cloner
17
+
18
+ # A replacement for a node specifies a list of nodes that replaces the original node,
19
+ # and a list of values that replaces the original node's outputs.
20
+
21
+ NodeReplacement = tuple[Sequence[ir.Node], Sequence[ir.Value]]
22
+
23
+ # A call stack is a list of identifiers of call sites, where the first element is the
24
+ # outermost call site, and the last element is the innermost call site. This is used
25
+ # primarily for generating unique names for values in the inlined functions.
26
+ CallSiteId = str
27
+ CallStack = list[CallSiteId]
28
+
29
+
30
def _make_unique_name(name: str, callstack: CallStack, used_names: set[str]) -> str:  # pylint: disable=unused-argument
    """Generate a unique name from a name, calling-context, and set of used names.

    If there is a name clash, we add a numeric suffix to the name to make
    it unique. We use the same strategy to make node names unique.

    TODO: We can use the callstack in generating a name for a value X in a function
    that is inlined into a graph. This is not yet implemented. Using the full callstack
    leads to very long and hard to read names. Some investigation is needed to find
    a good naming strategy that will produce useful names for debugging.
    """
    suffix = 1
    unique = name
    # First clash produces "<name>_2"; the bare name counts as occurrence 1.
    while unique in used_names:
        suffix += 1
        unique = f"{name}_{suffix}"
    used_names.add(unique)
    return unique
48
+
49
+
50
def _abbreviate(
    function_ids: Iterable[ir.OperatorIdentifier],
) -> dict[ir.OperatorIdentifier, str]:
    """Create a short unambiguous abbreviation for all function ids.

    Args:
        function_ids: The (domain, name, overload) identifiers to abbreviate.

    Returns:
        A mapping from each identifier to its abbreviation. The domain is
        omitted whenever doing so keeps the abbreviation unambiguous; a
        non-empty overload is appended with an underscore.
    """
    # Materialize once: each id's ambiguity test below scans the whole
    # collection, so a one-shot iterator (e.g. a generator) would otherwise be
    # exhausted after the first id, silently yielding a truncated result.
    all_ids = tuple(function_ids)

    def id_abbreviation(id: ir.OperatorIdentifier) -> str:
        """Create a short unambiguous abbreviation for a function id."""
        domain, name, overload = id
        # Omit the domain, if it remains unambiguous after omitting it.
        if any(x[0] != domain and x[1] == name and x[2] == overload for x in all_ids):
            short_domain = domain + "_"
        else:
            short_domain = ""
        if overload != "":
            return short_domain + name + "_" + overload
        return short_domain + name

    return {id: id_abbreviation(id) for id in all_ids}
68
+
69
+
70
@dataclasses.dataclass
class InlinePassResult(ir.passes.PassResult):
    """Result of :class:`InlinePass`, extending the base pass result."""

    # Number of inlined call sites per function identifier.
    id_count: dict[ir.OperatorIdentifier, int]
73
+
74
+
75
class InlinePass(ir.passes.InPlacePass):
    """Inline model local functions to the main graph and clear function definitions."""

    def __init__(self) -> None:
        super().__init__()
        # Function table and derived data; populated from the model in _reset().
        self._functions: dict[ir.OperatorIdentifier, ir.Function] = {}
        self._function_id_abbreviations: dict[ir.OperatorIdentifier, str] = {}
        self._opset_imports: dict[str, int] = {}
        # Names already taken in the graph being processed; consulted when
        # renaming inlined nodes/values to keep them unique.
        self.used_value_names: set[str] = set()
        self.used_node_names: set[str] = set()
        # Maps an inlined node to the call stack it came from, so nested calls
        # get a fully qualified context.
        self.node_context: dict[ir.Node, CallStack] = {}

    def _reset(self, model: ir.Model) -> None:
        """Re-initialize per-model state before a run."""
        self._functions = model.functions
        self._function_id_abbreviations = _abbreviate(self._functions.keys())
        self._opset_imports = model.opset_imports
        self.used_value_names = set()
        self.used_node_names = set()
        self.node_context = {}

    def call(self, model: ir.Model) -> InlinePassResult:
        """Inline all function calls in the main graph, then drop the definitions."""
        self._reset(model)
        id_count = self._inline_calls_in(model.graph)
        model.functions.clear()
        # NOTE(review): id_count only reflects calls found in the main graph's
        # node list; counts from recursive subgraph inlining are discarded in
        # _inline_calls_in, so `modified` may under-report — confirm intended.
        return InlinePassResult(model, modified=bool(id_count), id_count=id_count)

    def _instantiate_call(self, node: ir.Node, call_site_id: CallSiteId) -> NodeReplacement:
        """Clone a function body for one call site.

        Returns the replacement nodes and the values that replace the call
        node's outputs.

        Raises:
            ValueError: On opset version mismatch, graph-valued attributes,
                or more call inputs than function parameters.
        """
        id = node.op_identifier()
        function = self._functions[id]

        # check opset compatibility and update the opset imports
        for key, value in function.opset_imports.items():
            if key not in self._opset_imports:
                self._opset_imports[key] = value
            elif self._opset_imports[key] != value:
                raise ValueError(
                    f"Opset mismatch: {key} {self._opset_imports[key]} != {value}"
                )

        # Identify substitutions for both inputs and attributes of the function:
        attributes: Mapping[str, ir.Attr] = node.attributes
        # Function attribute declarations that carry a default value and were
        # not supplied at the call site.
        default_attr_values = {
            attr.name: attr
            for attr in function.attributes.values()
            if attr.name not in attributes and attr.value is not None
        }
        if default_attr_values:
            attributes = {**attributes, **default_attr_values}
        if any(
            attr.type in {ir.AttributeType.GRAPH, ir.AttributeType.GRAPHS}
            for attr in attributes.values()
        ):
            raise ValueError(
                "Inliner does not support graph attribute parameters to functions"
            )

        if len(node.inputs) > len(function.inputs):
            raise ValueError(f"Input mismatch: {len(node.inputs)} > {len(function.inputs)}")
        # Map formal parameters to actual inputs; trailing (omitted) formals
        # map to None.
        value_map = {}
        for i, input in enumerate(node.inputs):
            value_map[function.inputs[i]] = input
        for i in range(len(node.inputs), len(function.inputs)):
            value_map[function.inputs[i]] = None

        # Identify call-stack for node, used to generate unique names.
        call_stack = self.node_context.get(node, [])
        new_call_stack = [*call_stack, call_site_id]

        def rename(node: ir.Node) -> None:
            """Rename node/values in inlined node to ensure uniqueness in the inlined context."""
            node_name = node.name or "node"
            node.name = _make_unique_name(node_name, new_call_stack, self.used_node_names)
            for output in node.outputs:
                if output is not None:
                    output_name = output.name or "val"
                    output.name = _make_unique_name(
                        output_name, new_call_stack, self.used_value_names
                    )
            # Update context in case the new node is itself a call node that will be inlined.
            self.node_context[node] = new_call_stack

        cloner = _cloner.Cloner(
            attr_map=attributes,
            value_map=value_map,
            metadata_props=node.metadata_props,
            post_process=rename,
            resolve_ref_attrs=True,
        )

        # iterate over the nodes in the function, creating a copy of each node
        # and replacing inputs with the corresponding values in the value map.
        # Update the value map with the new values.

        nodes = [cloner.clone_node(node) for node in function]
        output_values = [value_map[output] for output in function.outputs]
        return nodes, output_values  # type: ignore

    def _inline_calls_in(self, graph: ir.Graph) -> dict[ir.OperatorIdentifier, int]:
        """Inline every function call in *graph* (and its subgraphs) in place.

        Returns the number of call sites found per function id, for this
        graph's top-level node list only.
        """
        # Reserve names already claimed by graph inputs and initializers.
        for input in graph.inputs:
            if input.name is not None:
                self.used_value_names.add(input.name)
        for initializer in graph.initializers:
            self.used_value_names.add(initializer)

        # Pre-processing:
        # * Count the number of times each function is called in the graph.
        #   This is used for disambiguating names of values in the inlined functions.
        # * And identify names of values that are used in the graph.
        id_count: dict[ir.OperatorIdentifier, int] = defaultdict(int)
        for node in graph:
            if node.name:
                self.used_node_names.add(node.name)
            id = node.op_identifier()
            if id in self._functions:
                id_count[id] += 1
            for output in node.outputs:
                if output.name is not None:
                    self.used_value_names.add(output.name)
        next_id: dict[ir.OperatorIdentifier, int] = defaultdict(int)
        # NOTE(review): the graph is mutated (nodes replaced) while being
        # iterated — presumably ir.Graph iteration tolerates this and visits
        # newly inserted nodes, which is what allows nested calls to be
        # inlined in the same sweep. Confirm against ir.Graph semantics.
        for node in graph:
            id = node.op_identifier()
            if id in self._functions:
                # If there are multiple calls to same function, we use a prefix to disambiguate
                # the different call-sites:
                if id_count[id] > 1:
                    call_site_prefix = f"_{next_id[id]}"
                    next_id[id] += 1
                else:
                    call_site_prefix = ""
                call_site = node.name or (
                    self._function_id_abbreviations[id] + call_site_prefix
                )
                nodes, values = self._instantiate_call(node, call_site)
                _ir_convenience.replace_nodes_and_values(
                    graph,
                    insertion_point=node,
                    old_nodes=[node],
                    new_nodes=nodes,
                    old_values=node.outputs,
                    new_values=values,
                )
            else:
                # Not a call node: recurse into any graph-valued attributes.
                for attr in node.attributes.values():
                    if attr.type == ir.AttributeType.GRAPH:
                        self._inline_calls_in(attr.as_graph())
                    elif attr.type == ir.AttributeType.GRAPHS:
                        for g in attr.as_graphs():
                            self._inline_calls_in(g)
        return id_count
@@ -0,0 +1,280 @@
1
+ # Copyright (c) ONNX Project Contributors
2
+ # SPDX-License-Identifier: Apache-2.0
3
+ """Name fix pass for ensuring unique names for all values and nodes."""
4
+
5
+ from __future__ import annotations
6
+
7
+ __all__ = [
8
+ "NameFixPass",
9
+ "NameGenerator",
10
+ "SimpleNameGenerator",
11
+ ]
12
+
13
+ import collections
14
+ import logging
15
+ from typing import Protocol
16
+
17
+ import onnx_ir as ir
18
+
19
+ logger = logging.getLogger(__name__)
20
+
21
+
22
class NameGenerator(Protocol):
    """Structural interface for strategies that propose node/value names.

    Implementations return a *preferred* name; :class:`NameFixPass` resolves
    clashes by appending numeric suffixes.
    """

    def generate_node_name(self, node: ir.Node) -> str:
        """Generate a preferred name for a node."""
        ...

    def generate_value_name(self, value: ir.Value) -> str:
        """Generate a preferred name for a value."""
        ...
30
+
31
+
32
class SimpleNameGenerator(NameGenerator):
    """Default naming strategy: keep an existing name, else a generic fallback."""

    def generate_node_name(self, node: ir.Node) -> str:
        """Generate a preferred name for a node."""
        if node.name:
            return node.name
        return "node"

    def generate_value_name(self, value: ir.Value) -> str:
        """Generate a preferred name for a value."""
        if value.name:
            return value.name
        return "v"
42
+
43
+
44
class NameFixPass(ir.passes.InPlacePass):
    """Pass for fixing names to ensure all values and nodes have unique names.

    This pass ensures that:
    1. Graph inputs and outputs have unique names (take precedence)
    2. All intermediate values have unique names (assign names to unnamed values)
    3. All values in subgraphs have unique names within their graph and parent graphs
    4. All nodes have unique names within their graph

    The pass maintains global uniqueness across the entire model.

    You can customize the name generation functions for nodes and values by passing
    a subclass of :class:`NameGenerator`.

    For example, you can use a custom naming scheme like this::

        class CustomNameGenerator:
            def custom_node_name(node: ir.Node) -> str:
                return f"custom_node_{node.op_type}"

            def custom_value_name(value: ir.Value) -> str:
                return f"custom_value_{value.type}"

        name_fix_pass = NameFixPass(name_generator=CustomNameGenerator())

    .. versionadded:: 0.1.6
    """

    def __init__(
        self,
        name_generator: NameGenerator | None = None,
    ) -> None:
        """Initialize the NameFixPass with custom name generation functions.

        Args:
            name_generator (NameGenerator, optional): An instance of a subclass of
                :class:`NameGenerator` to customize name generation for nodes and values.
                If not provided, defaults to a basic implementation that uses
                the node's or value's existing name or a generic name like "node" or "v".
        """
        super().__init__()
        self._name_generator = name_generator or SimpleNameGenerator()

    def call(self, model: ir.Model) -> ir.passes.PassResult:
        """Fix names in the main graph and all functions of *model*."""
        # Process the main graph
        modified = self._fix_graph_names(model.graph)

        # Process functions
        for function in model.functions.values():
            modified = self._fix_graph_names(function) or modified

        return ir.passes.PassResult(model, modified=modified)

    def _fix_graph_names(self, graph_like: ir.Graph | ir.Function) -> bool:
        """Fix names in a graph and return whether modifications were made."""
        modified = False

        # Set to track which values have been assigned names
        seen_values: set[ir.Value] = set()

        # The first set is a dummy placeholder so that there is always a [-1] scope for access
        # (even though we don't write to it)
        scoped_used_value_names: list[set[str]] = [set()]
        scoped_used_node_names: list[set[str]] = [set()]

        # Counters for generating unique names (using list to pass by reference)
        value_counter: collections.Counter[str] = collections.Counter()
        node_counter: collections.Counter[str] = collections.Counter()

        def enter_graph(graph_like) -> None:
            """Callback for entering a subgraph."""
            # Initialize new scopes with all names from the parent scope.
            # Value names inherit the parent scope (values are visible across
            # nested graphs); node names start fresh per graph.
            scoped_used_value_names.append(set(scoped_used_value_names[-1]))
            scoped_used_node_names.append(set())

            nonlocal modified

            # Step 1: Fix graph input names first (they have precedence)
            for input_value in graph_like.inputs:
                if self._process_value(
                    input_value, scoped_used_value_names[-1], seen_values, value_counter
                ):
                    modified = True

            # Step 2: Fix graph output names (they have precedence)
            for output_value in graph_like.outputs:
                if self._process_value(
                    output_value, scoped_used_value_names[-1], seen_values, value_counter
                ):
                    modified = True

            if isinstance(graph_like, ir.Graph):
                # For graphs, also fix initializers
                for initializer in graph_like.initializers.values():
                    if self._process_value(
                        initializer, scoped_used_value_names[-1], seen_values, value_counter
                    ):
                        modified = True

        def exit_graph(_) -> None:
            """Callback for exiting a subgraph."""
            # Pop the current scope
            scoped_used_value_names.pop()
            scoped_used_node_names.pop()

        # Step 3: Process all nodes and their values.
        # The iterator drives enter_graph/exit_graph around each (sub)graph.
        for node in ir.traversal.RecursiveGraphIterator(
            graph_like, enter_graph=enter_graph, exit_graph=exit_graph
        ):
            # Fix node name
            if not node.name:
                if self._assign_node_name(node, scoped_used_node_names[-1], node_counter):
                    modified = True
            else:
                if self._fix_duplicate_node_name(
                    node, scoped_used_node_names[-1], node_counter
                ):
                    modified = True

            # Fix input value names (only if not already processed)
            for input_value in node.inputs:
                if input_value is not None:
                    if self._process_value(
                        input_value, scoped_used_value_names[-1], seen_values, value_counter
                    ):
                        modified = True

            # Fix output value names (only if not already processed)
            for output_value in node.outputs:
                if self._process_value(
                    output_value, scoped_used_value_names[-1], seen_values, value_counter
                ):
                    modified = True

        return modified

    def _process_value(
        self,
        value: ir.Value,
        used_value_names: set[str],
        seen_values: set[ir.Value],
        value_counter: collections.Counter[str],
    ) -> bool:
        """Process a value only if it hasn't been processed before.

        Returns True when the value's name was assigned or changed.
        """
        if value in seen_values:
            return False

        modified = False

        if not value.name:
            modified = self._assign_value_name(value, used_value_names, value_counter)
        else:
            modified = self._fix_duplicate_value_name(value, used_value_names, value_counter)
            # initializers dictionary is updated automatically when the Value is renamed

        # Record the final name for this value
        assert value.name is not None
        seen_values.add(value)
        return modified

    def _assign_value_name(
        self, value: ir.Value, used_names: set[str], counter: collections.Counter[str]
    ) -> bool:
        """Assign a name to an unnamed value. Returns True if modified."""
        assert not value.name, (
            "value should not have a name already if function is called correctly"
        )

        preferred_name = self._name_generator.generate_value_name(value)
        value.name = _find_and_record_next_unique_name(preferred_name, used_names, counter)
        logger.debug("Assigned name %s to unnamed value", value.name)
        return True

    def _assign_node_name(
        self, node: ir.Node, used_names: set[str], counter: collections.Counter[str]
    ) -> bool:
        """Assign a name to an unnamed node. Returns True if modified."""
        assert not node.name, (
            "node should not have a name already if function is called correctly"
        )

        preferred_name = self._name_generator.generate_node_name(node)
        node.name = _find_and_record_next_unique_name(preferred_name, used_names, counter)
        logger.debug("Assigned name %s to unnamed node", node.name)
        return True

    def _fix_duplicate_value_name(
        self, value: ir.Value, used_names: set[str], counter: collections.Counter[str]
    ) -> bool:
        """Fix a value's name if it conflicts with existing names. Returns True if modified."""
        original_name = value.name

        assert original_name, (
            "value should have a name already if function is called correctly"
        )

        if original_name not in used_names:
            # Name is unique, just record it
            used_names.add(original_name)
            return False

        # If name is already used, make it unique
        base_name = self._name_generator.generate_value_name(value)
        value.name = _find_and_record_next_unique_name(base_name, used_names, counter)
        logger.debug("Renamed value from %s to %s for uniqueness", original_name, value.name)
        return True

    def _fix_duplicate_node_name(
        self, node: ir.Node, used_names: set[str], counter: collections.Counter[str]
    ) -> bool:
        """Fix a node's name if it conflicts with existing names. Returns True if modified."""
        original_name = node.name

        assert original_name, "node should have a name already if function is called correctly"

        if original_name not in used_names:
            # Name is unique, just record it
            used_names.add(original_name)
            return False

        # If name is already used, make it unique
        base_name = self._name_generator.generate_node_name(node)
        node.name = _find_and_record_next_unique_name(base_name, used_names, counter)
        logger.debug("Renamed node from %s to %s for uniqueness", original_name, node.name)
        return True
269
+
270
+
271
def _find_and_record_next_unique_name(
    preferred_name: str, used_names: set[str], counter: collections.Counter[str]
) -> str:
    """Return a unique variant of *preferred_name* and record it in *used_names*.

    The counter persists the suffix position per preferred name, so successive
    clashes on the same name produce increasing suffixes.
    """
    candidate = preferred_name
    while candidate in used_names:
        counter[preferred_name] += 1
        candidate = f"{preferred_name}_{counter[preferred_name]}"
    used_names.add(candidate)
    return candidate