onnx_ir-0.1.15-py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- onnx_ir/__init__.py +176 -0
- onnx_ir/_cloner.py +229 -0
- onnx_ir/_convenience/__init__.py +558 -0
- onnx_ir/_convenience/_constructors.py +291 -0
- onnx_ir/_convenience/_extractor.py +191 -0
- onnx_ir/_core.py +4435 -0
- onnx_ir/_display.py +54 -0
- onnx_ir/_enums.py +474 -0
- onnx_ir/_graph_comparison.py +23 -0
- onnx_ir/_graph_containers.py +373 -0
- onnx_ir/_io.py +133 -0
- onnx_ir/_linked_list.py +284 -0
- onnx_ir/_metadata.py +45 -0
- onnx_ir/_name_authority.py +72 -0
- onnx_ir/_polyfill.py +26 -0
- onnx_ir/_protocols.py +627 -0
- onnx_ir/_safetensors/__init__.py +510 -0
- onnx_ir/_tape.py +242 -0
- onnx_ir/_thirdparty/asciichartpy.py +310 -0
- onnx_ir/_type_casting.py +89 -0
- onnx_ir/_version_utils.py +48 -0
- onnx_ir/analysis/__init__.py +21 -0
- onnx_ir/analysis/_implicit_usage.py +74 -0
- onnx_ir/convenience.py +38 -0
- onnx_ir/external_data.py +459 -0
- onnx_ir/passes/__init__.py +41 -0
- onnx_ir/passes/_pass_infra.py +351 -0
- onnx_ir/passes/common/__init__.py +54 -0
- onnx_ir/passes/common/_c_api_utils.py +76 -0
- onnx_ir/passes/common/clear_metadata_and_docstring.py +60 -0
- onnx_ir/passes/common/common_subexpression_elimination.py +207 -0
- onnx_ir/passes/common/constant_manipulation.py +230 -0
- onnx_ir/passes/common/default_attributes.py +99 -0
- onnx_ir/passes/common/identity_elimination.py +120 -0
- onnx_ir/passes/common/initializer_deduplication.py +179 -0
- onnx_ir/passes/common/inliner.py +223 -0
- onnx_ir/passes/common/naming.py +280 -0
- onnx_ir/passes/common/onnx_checker.py +57 -0
- onnx_ir/passes/common/output_fix.py +141 -0
- onnx_ir/passes/common/shape_inference.py +112 -0
- onnx_ir/passes/common/topological_sort.py +37 -0
- onnx_ir/passes/common/unused_removal.py +215 -0
- onnx_ir/py.typed +1 -0
- onnx_ir/serde.py +2043 -0
- onnx_ir/tape.py +15 -0
- onnx_ir/tensor_adapters.py +210 -0
- onnx_ir/testing.py +197 -0
- onnx_ir/traversal.py +118 -0
- onnx_ir-0.1.15.dist-info/METADATA +68 -0
- onnx_ir-0.1.15.dist-info/RECORD +53 -0
- onnx_ir-0.1.15.dist-info/WHEEL +5 -0
- onnx_ir-0.1.15.dist-info/licenses/LICENSE +202 -0
- onnx_ir-0.1.15.dist-info/top_level.txt +1 -0

onnx_ir/passes/common/initializer_deduplication.py
@@ -0,0 +1,179 @@
+# Copyright (c) ONNX Project Contributors
+# SPDX-License-Identifier: Apache-2.0
+"""Pass for removing duplicated initializer tensors from a graph."""
+
+from __future__ import annotations
+
+__all__ = ["DeduplicateInitializersPass", "DeduplicateHashedInitializersPass"]
+
+
+import hashlib
+import logging
+
+import numpy as np
+
+import onnx_ir as ir
+
+logger = logging.getLogger(__name__)
+
+
+def _should_skip_initializer(initializer: ir.Value, size_limit: int) -> bool:
+    """Check if the initializer should be skipped for deduplication."""
+    if initializer.is_graph_input() or initializer.is_graph_output():
+        # Skip graph inputs and outputs
+        logger.warning(
+            "Skipped deduplication of initializer '%s' as it is a graph input or output",
+            initializer.name,
+        )
+        return True
+
+    const_val = initializer.const_value
+    if const_val is None:
+        # Skip if initializer has no constant value
+        logger.warning(
+            "Skipped deduplication of initializer '%s' as it has no constant value. The model may contain invalid initializers",
+            initializer.name,
+        )
+        return True
+
+    if const_val.size > size_limit:
+        # Skip if the initializer is larger than the size limit
+        logger.debug(
+            "Skipped initializer '%s' as it exceeds the size limit of %d elements",
+            initializer.name,
+            size_limit,
+        )
+        return True
+    return False
+
+
+def _tobytes(val):
+    """StringTensor does not support tobytes. Use 'string_data' instead.
+
+    However, 'string_data' yields a list of bytes which cannot be hashed, i.e.,
+    cannot be used to index into a dict. To generate keys for identifying
+    tensors in initializer deduplication the following converts the list of
+    bytes to an array of fixed-length strings which can be flattened into a
+    bytes-string. This, together with the tensor shape, is sufficient for
+    identifying tensors for deduplication, but it differs from the
+    representation used for serializing tensors (that is string_data) by adding
+    padding bytes so that each string occupies the same number of consecutive
+    bytes in the flattened .tobytes representation.
+    """
+    if val.dtype.is_string():
+        return np.array(val.string_data()).tobytes()
+    return val.tobytes()
+
+
+class DeduplicateInitializersPass(ir.passes.InPlacePass):
+    """Remove duplicated initializer tensors from the main graph and all subgraphs.
+
+    This pass detects initializers with identical shape, dtype, and content,
+    and replaces all duplicate references with a canonical one.
+
+    Initializers are deduplicated within each graph. To deduplicate initializers
+    in the model globally (across graphs), use :class:`~onnx_ir.passes.common.LiftSubgraphInitializersToMainGraphPass`
+    to lift the initializers to the main graph first before running this pass.
+
+    .. versionadded:: 0.1.3
+    .. versionchanged:: 0.1.7
+        This pass now deduplicates initializers in subgraphs as well.
+    """
+
+    def __init__(self, size_limit: int = 1024):
+        super().__init__()
+        self.size_limit = size_limit
+
+    def call(self, model: ir.Model) -> ir.passes.PassResult:
+        modified = False
+
+        for graph in model.graphs():
+            initializers: dict[tuple[ir.DataType, tuple[int, ...], bytes], ir.Value] = {}
+            for initializer in tuple(graph.initializers.values()):
+                if _should_skip_initializer(initializer, self.size_limit):
+                    continue
+
+                const_val = initializer.const_value
+                assert const_val is not None
+
+                key = (const_val.dtype, tuple(const_val.shape), _tobytes(const_val))
+                if key in initializers:
+                    modified = True
+                    initializer_to_keep = initializers[key]  # type: ignore[index]
+                    initializer.replace_all_uses_with(initializer_to_keep)
+                    assert initializer.name is not None
+                    graph.initializers.pop(initializer.name)
+                    logger.info(
+                        "Replaced initializer '%s' with existing initializer '%s'",
+                        initializer.name,
+                        initializer_to_keep.name,
+                    )
+                else:
+                    initializers[key] = initializer  # type: ignore[index]
+
+        return ir.passes.PassResult(model=model, modified=modified)
+
+
+class DeduplicateHashedInitializersPass(ir.passes.InPlacePass):
+    """Remove duplicated initializer tensors (using a hashed method) from the graph.
+
+    This pass detects initializers with identical shape, dtype, and hashed content,
+    and replaces all duplicate references with a canonical one.
+
+    This pass should have a lower peak memory usage than :class:`DeduplicateInitializersPass`
+    as it does not store the full tensor data in memory, but instead uses a hash of the tensor data.
+
+    .. versionadded:: 0.1.7
+    """
+
+    def __init__(self, size_limit: int = 4 * 1024 * 1024 * 1024):
+        super().__init__()
+        # 4 GB default size limit for deduplication
+        self.size_limit = size_limit
+
+    def call(self, model: ir.Model) -> ir.passes.PassResult:
+        modified = False
+
+        for graph in model.graphs():
+            initializers: dict[tuple[ir.DataType, tuple[int, ...], str], ir.Value] = {}
+
+            for initializer in tuple(graph.initializers.values()):
+                if _should_skip_initializer(initializer, self.size_limit):
+                    continue
+
+                const_val = initializer.const_value
+                assert const_val is not None
+
+                # Hash tensor data to avoid storing large amounts of data in memory
+                hashed = hashlib.sha512()
+                tensor_data = const_val.numpy()
+                hashed.update(tensor_data)
+                tensor_digest = hashed.hexdigest()
+
+                tensor_dims = tuple(const_val.shape.numpy())
+
+                key = (const_val.dtype, tensor_dims, tensor_digest)
+
+                if key in initializers:
+                    if _tobytes(initializers[key].const_value) != _tobytes(const_val):
+                        logger.warning(
+                            "Initializer deduplication failed: "
+                            "hashes match but values differ with values %s and %s",
+                            initializers[key],
+                            initializer,
+                        )
+                        continue
+                    modified = True
+                    initializer_to_keep = initializers[key]  # type: ignore[index]
+                    initializer.replace_all_uses_with(initializer_to_keep)
+                    assert initializer.name is not None
+                    graph.initializers.pop(initializer.name)
+                    logger.info(
+                        "Replaced initializer '%s' with existing initializer '%s'",
+                        initializer.name,
+                        initializer_to_keep.name,
+                    )
+                else:
+                    initializers[key] = initializer  # type: ignore[index]
+
+        return ir.passes.PassResult(model=model, modified=modified)
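
Both deduplication passes above are ordinary in-place passes. The sketch below shows how they might be driven; the model paths are hypothetical, and it assumes that ir.load, ir.save, and calling a pass instance directly behave as the _io.py and _pass_infra.py modules listed above suggest:

    import onnx_ir as ir
    from onnx_ir.passes.common.initializer_deduplication import (
        DeduplicateHashedInitializersPass,
        DeduplicateInitializersPass,
    )

    model = ir.load("model.onnx")  # hypothetical input path

    # Exact-byte deduplication; skips initializers above 1024 elements by default.
    result = DeduplicateInitializersPass()(model)
    print("modified:", result.modified)

    # Hash-based deduplication keeps peak memory lower for large initializers
    # (default size limit is 4 * 1024**3 elements); here the limit is lowered.
    result = DeduplicateHashedInitializersPass(size_limit=2**20)(model)
    print("modified:", result.modified)

    ir.save(result.model, "model_dedup.onnx")  # hypothetical output path

The exact-byte variant keys on the raw tensor bytes, so it is restricted to small initializers by default; the hashed variant trades a SHA-512 digest per tensor for a much higher default size limit.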

onnx_ir/passes/common/inliner.py
@@ -0,0 +1,223 @@
+# Copyright (c) ONNX Project Contributors
+# SPDX-License-Identifier: Apache-2.0
+"""Implementation of an inliner for onnx_ir."""
+
+from __future__ import annotations
+
+import dataclasses
+
+__all__ = ["InlinePass", "InlinePassResult"]
+
+from collections import defaultdict
+from collections.abc import Iterable, Mapping, Sequence
+
+import onnx_ir as ir
+import onnx_ir.convenience as _ir_convenience
+from onnx_ir import _cloner
+
+# A replacement for a node specifies a list of nodes that replaces the original node,
+# and a list of values that replaces the original node's outputs.
+
+NodeReplacement = tuple[Sequence[ir.Node], Sequence[ir.Value]]
+
+# A call stack is a list of identifiers of call sites, where the first element is the
+# outermost call site, and the last element is the innermost call site. This is used
+# primarily for generating unique names for values in the inlined functions.
+CallSiteId = str
+CallStack = list[CallSiteId]
+
+
+def _make_unique_name(name: str, callstack: CallStack, used_names: set[str]) -> str:  # pylint: disable=unused-argument
+    """Generate a unique name from a name, calling-context, and set of used names.
+
+    If there is a name clash, we add a numeric suffix to the name to make
+    it unique. We use the same strategy to make node names unique.
+
+    TODO: We can use the callstack in generating a name for a value X in a function
+    that is inlined into a graph. This is not yet implemented. Using the full callstack
+    leads to very long and hard to read names. Some investigation is needed to find
+    a good naming strategy that will produce useful names for debugging.
+    """
+    candidate = name
+    i = 1
+    while candidate in used_names:
+        i += 1
+        candidate = f"{name}_{i}"
+    used_names.add(candidate)
+    return candidate
+
+
+def _abbreviate(
+    function_ids: Iterable[ir.OperatorIdentifier],
+) -> dict[ir.OperatorIdentifier, str]:
+    """Create a short unambiguous abbreviation for all function ids."""
+
+    def id_abbreviation(id: ir.OperatorIdentifier) -> str:
+        """Create a short unambiguous abbreviation for a function id."""
+        domain, name, overload = id
+        # Omit the domain, if it remains unambiguous after omitting it.
+        if any(x[0] != domain and x[1] == name and x[2] == overload for x in function_ids):
+            short_domain = domain + "_"
+        else:
+            short_domain = ""
+        if overload != "":
+            return short_domain + name + "_" + overload
+        return short_domain + name
+
+    return {id: id_abbreviation(id) for id in function_ids}
+
+
+@dataclasses.dataclass
+class InlinePassResult(ir.passes.PassResult):
+    id_count: dict[ir.OperatorIdentifier, int]
+
+
+class InlinePass(ir.passes.InPlacePass):
+    """Inline model local functions to the main graph and clear function definitions."""
+
+    def __init__(self) -> None:
+        super().__init__()
+        self._functions: dict[ir.OperatorIdentifier, ir.Function] = {}
+        self._function_id_abbreviations: dict[ir.OperatorIdentifier, str] = {}
+        self._opset_imports: dict[str, int] = {}
+        self.used_value_names: set[str] = set()
+        self.used_node_names: set[str] = set()
+        self.node_context: dict[ir.Node, CallStack] = {}
+
+    def _reset(self, model: ir.Model) -> None:
+        self._functions = model.functions
+        self._function_id_abbreviations = _abbreviate(self._functions.keys())
+        self._opset_imports = model.opset_imports
+        self.used_value_names = set()
+        self.used_node_names = set()
+        self.node_context = {}
+
+    def call(self, model: ir.Model) -> InlinePassResult:
+        self._reset(model)
+        id_count = self._inline_calls_in(model.graph)
+        model.functions.clear()
+        return InlinePassResult(model, modified=bool(id_count), id_count=id_count)
+
+    def _instantiate_call(self, node: ir.Node, call_site_id: CallSiteId) -> NodeReplacement:
+        id = node.op_identifier()
+        function = self._functions[id]
+
+        # check opset compatibility and update the opset imports
+        for key, value in function.opset_imports.items():
+            if key not in self._opset_imports:
+                self._opset_imports[key] = value
+            elif self._opset_imports[key] != value:
+                raise ValueError(
+                    f"Opset mismatch: {key} {self._opset_imports[key]} != {value}"
+                )
+
+        # Identify substitutions for both inputs and attributes of the function:
+        attributes: Mapping[str, ir.Attr] = node.attributes
+        default_attr_values = {
+            attr.name: attr
+            for attr in function.attributes.values()
+            if attr.name not in attributes and attr.value is not None
+        }
+        if default_attr_values:
+            attributes = {**attributes, **default_attr_values}
+        if any(
+            attr.type in {ir.AttributeType.GRAPH, ir.AttributeType.GRAPHS}
+            for attr in attributes.values()
+        ):
+            raise ValueError(
+                "Inliner does not support graph attribute parameters to functions"
+            )
+
+        if len(node.inputs) > len(function.inputs):
+            raise ValueError(f"Input mismatch: {len(node.inputs)} > {len(function.inputs)}")
+        value_map = {}
+        for i, input in enumerate(node.inputs):
+            value_map[function.inputs[i]] = input
+        for i in range(len(node.inputs), len(function.inputs)):
+            value_map[function.inputs[i]] = None
+
+        # Identify call-stack for node, used to generate unique names.
+        call_stack = self.node_context.get(node, [])
+        new_call_stack = [*call_stack, call_site_id]
+
+        def rename(node: ir.Node) -> None:
+            """Rename node/values in inlined node to ensure uniqueness in the inlined context."""
+            node_name = node.name or "node"
+            node.name = _make_unique_name(node_name, new_call_stack, self.used_node_names)
+            for output in node.outputs:
+                if output is not None:
+                    output_name = output.name or "val"
+                    output.name = _make_unique_name(
+                        output_name, new_call_stack, self.used_value_names
+                    )
+            # Update context in case the new node is itself a call node that will be inlined.
+            self.node_context[node] = new_call_stack
+
+        cloner = _cloner.Cloner(
+            attr_map=attributes,
+            value_map=value_map,
+            metadata_props=node.metadata_props,
+            post_process=rename,
+            resolve_ref_attrs=True,
+        )
+
+        # iterate over the nodes in the function, creating a copy of each node
+        # and replacing inputs with the corresponding values in the value map.
+        # Update the value map with the new values.
+
+        nodes = [cloner.clone_node(node) for node in function]
+        output_values = [value_map[output] for output in function.outputs]
+        return nodes, output_values  # type: ignore
+
+    def _inline_calls_in(self, graph: ir.Graph) -> dict[ir.OperatorIdentifier, int]:
+        for input in graph.inputs:
+            if input.name is not None:
+                self.used_value_names.add(input.name)
+        for initializer in graph.initializers:
+            self.used_value_names.add(initializer)
+
+        # Pre-processing:
+        # * Count the number of times each function is called in the graph.
+        #   This is used for disambiguating names of values in the inlined functions.
+        # * And identify names of values that are used in the graph.
+        id_count: dict[ir.OperatorIdentifier, int] = defaultdict(int)
+        for node in graph:
+            if node.name:
+                self.used_node_names.add(node.name)
+            id = node.op_identifier()
+            if id in self._functions:
+                id_count[id] += 1
+            for output in node.outputs:
+                if output.name is not None:
+                    self.used_value_names.add(output.name)
+        next_id: dict[ir.OperatorIdentifier, int] = defaultdict(int)
+        for node in graph:
+            id = node.op_identifier()
+            if id in self._functions:
+                # If there are multiple calls to same function, we use a prefix to disambiguate
+                # the different call-sites:
+                if id_count[id] > 1:
+                    call_site_prefix = f"_{next_id[id]}"
+                    next_id[id] += 1
+                else:
+                    call_site_prefix = ""
+                call_site = node.name or (
+                    self._function_id_abbreviations[id] + call_site_prefix
+                )
+                nodes, values = self._instantiate_call(node, call_site)
+                _ir_convenience.replace_nodes_and_values(
+                    graph,
+                    insertion_point=node,
+                    old_nodes=[node],
+                    new_nodes=nodes,
+                    old_values=node.outputs,
+                    new_values=values,
+                )
+            else:
+                for attr in node.attributes.values():
+                    if attr.type == ir.AttributeType.GRAPH:
+                        self._inline_calls_in(attr.as_graph())
+                    elif attr.type == ir.AttributeType.GRAPHS:
+                        for g in attr.as_graphs():
+                            self._inline_calls_in(g)
+        return id_count
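
InlinePass returns an InlinePassResult whose id_count field maps each inlined function identifier to the number of call sites that were expanded. Under the same assumptions as the previous sketch (hypothetical model path, ir.load available, pass instances callable directly), usage might look like:

    import onnx_ir as ir
    from onnx_ir.passes.common.inliner import InlinePass

    model = ir.load("model_with_functions.onnx")  # hypothetical path

    result = InlinePass()(model)

    # id_count maps each inlined function identifier to its number of call sites.
    for function_id, count in result.id_count.items():
        print(function_id, "inlined at", count, "call site(s)")

    # The pass clears the model's function definitions after inlining them.
    assert len(result.model.functions) == 0

Because the pass clears model.functions, a second run finds no call sites to expand and reports modified=False.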

onnx_ir/passes/common/naming.py
@@ -0,0 +1,280 @@
+# Copyright (c) ONNX Project Contributors
+# SPDX-License-Identifier: Apache-2.0
+"""Name fix pass for ensuring unique names for all values and nodes."""
+
+from __future__ import annotations
+
+__all__ = [
+    "NameFixPass",
+    "NameGenerator",
+    "SimpleNameGenerator",
+]
+
+import collections
+import logging
+from typing import Protocol
+
+import onnx_ir as ir
+
+logger = logging.getLogger(__name__)
+
+
+class NameGenerator(Protocol):
+    def generate_node_name(self, node: ir.Node) -> str:
+        """Generate a preferred name for a node."""
+        ...
+
+    def generate_value_name(self, value: ir.Value) -> str:
+        """Generate a preferred name for a value."""
+        ...
+
+
+class SimpleNameGenerator(NameGenerator):
+    """Base class for name generation functions."""
+
+    def generate_node_name(self, node: ir.Node) -> str:
+        """Generate a preferred name for a node."""
+        return node.name or "node"
+
+    def generate_value_name(self, value: ir.Value) -> str:
+        """Generate a preferred name for a value."""
+        return value.name or "v"
+
+
+class NameFixPass(ir.passes.InPlacePass):
+    """Pass for fixing names to ensure all values and nodes have unique names.
+
+    This pass ensures that:
+    1. Graph inputs and outputs have unique names (take precedence)
+    2. All intermediate values have unique names (assign names to unnamed values)
+    3. All values in subgraphs have unique names within their graph and parent graphs
+    4. All nodes have unique names within their graph
+
+    The pass maintains global uniqueness across the entire model.
+
+    You can customize the name generation functions for nodes and values by passing
+    a subclass of :class:`NameGenerator`.
+
+    For example, you can use a custom naming scheme like this::
+
+        class CustomNameGenerator:
+            def generate_node_name(self, node: ir.Node) -> str:
+                return f"custom_node_{node.op_type}"
+
+            def generate_value_name(self, value: ir.Value) -> str:
+                return f"custom_value_{value.type}"
+
+        name_fix_pass = NameFixPass(name_generator=CustomNameGenerator())
+
+    .. versionadded:: 0.1.6
+    """
+
+    def __init__(
+        self,
+        name_generator: NameGenerator | None = None,
+    ) -> None:
+        """Initialize the NameFixPass with custom name generation functions.
+
+        Args:
+            name_generator (NameGenerator, optional): An instance of a subclass of
+                :class:`NameGenerator` to customize name generation for nodes and values.
+                If not provided, defaults to a basic implementation that uses
+                the node's or value's existing name or a generic name like "node" or "v".
+        """
+        super().__init__()
+        self._name_generator = name_generator or SimpleNameGenerator()
+
+    def call(self, model: ir.Model) -> ir.passes.PassResult:
+        # Process the main graph
+        modified = self._fix_graph_names(model.graph)
+
+        # Process functions
+        for function in model.functions.values():
+            modified = self._fix_graph_names(function) or modified
+
+        return ir.passes.PassResult(model, modified=modified)
+
+    def _fix_graph_names(self, graph_like: ir.Graph | ir.Function) -> bool:
+        """Fix names in a graph and return whether modifications were made."""
+        modified = False
+
+        # Set to track which values have been assigned names
+        seen_values: set[ir.Value] = set()
+
+        # The first set is a dummy placeholder so that there is always a [-1] scope for access
+        # (even though we don't write to it)
+        scoped_used_value_names: list[set[str]] = [set()]
+        scoped_used_node_names: list[set[str]] = [set()]
+
+        # Counters for generating unique names (using list to pass by reference)
+        value_counter: collections.Counter[str] = collections.Counter()
+        node_counter: collections.Counter[str] = collections.Counter()
+
+        def enter_graph(graph_like) -> None:
+            """Callback for entering a subgraph."""
+            # Initialize new scopes with all names from the parent scope
+            scoped_used_value_names.append(set(scoped_used_value_names[-1]))
+            scoped_used_node_names.append(set())
+
+            nonlocal modified
+
+            # Step 1: Fix graph input names first (they have precedence)
+            for input_value in graph_like.inputs:
+                if self._process_value(
+                    input_value, scoped_used_value_names[-1], seen_values, value_counter
+                ):
+                    modified = True
+
+            # Step 2: Fix graph output names (they have precedence)
+            for output_value in graph_like.outputs:
+                if self._process_value(
+                    output_value, scoped_used_value_names[-1], seen_values, value_counter
+                ):
+                    modified = True
+
+            if isinstance(graph_like, ir.Graph):
+                # For graphs, also fix initializers
+                for initializer in graph_like.initializers.values():
+                    if self._process_value(
+                        initializer, scoped_used_value_names[-1], seen_values, value_counter
+                    ):
+                        modified = True
+
+        def exit_graph(_) -> None:
+            """Callback for exiting a subgraph."""
+            # Pop the current scope
+            scoped_used_value_names.pop()
+            scoped_used_node_names.pop()
+
+        # Step 3: Process all nodes and their values
+        for node in ir.traversal.RecursiveGraphIterator(
+            graph_like, enter_graph=enter_graph, exit_graph=exit_graph
+        ):
+            # Fix node name
+            if not node.name:
+                if self._assign_node_name(node, scoped_used_node_names[-1], node_counter):
+                    modified = True
+            else:
+                if self._fix_duplicate_node_name(
+                    node, scoped_used_node_names[-1], node_counter
+                ):
+                    modified = True
+
+            # Fix input value names (only if not already processed)
+            for input_value in node.inputs:
+                if input_value is not None:
+                    if self._process_value(
+                        input_value, scoped_used_value_names[-1], seen_values, value_counter
+                    ):
+                        modified = True
+
+            # Fix output value names (only if not already processed)
+            for output_value in node.outputs:
+                if self._process_value(
+                    output_value, scoped_used_value_names[-1], seen_values, value_counter
+                ):
+                    modified = True
+
+        return modified
+
+    def _process_value(
+        self,
+        value: ir.Value,
+        used_value_names: set[str],
+        seen_values: set[ir.Value],
+        value_counter: collections.Counter[str],
+    ) -> bool:
+        """Process a value only if it hasn't been processed before."""
+        if value in seen_values:
+            return False
+
+        modified = False
+
+        if not value.name:
+            modified = self._assign_value_name(value, used_value_names, value_counter)
+        else:
+            modified = self._fix_duplicate_value_name(value, used_value_names, value_counter)
+            # initializers dictionary is updated automatically when the Value is renamed
+
+        # Record the final name for this value
+        assert value.name is not None
+        seen_values.add(value)
+        return modified
+
+    def _assign_value_name(
+        self, value: ir.Value, used_names: set[str], counter: collections.Counter[str]
+    ) -> bool:
+        """Assign a name to an unnamed value. Returns True if modified."""
+        assert not value.name, (
+            "value should not have a name already if function is called correctly"
+        )
+
+        preferred_name = self._name_generator.generate_value_name(value)
+        value.name = _find_and_record_next_unique_name(preferred_name, used_names, counter)
+        logger.debug("Assigned name %s to unnamed value", value.name)
+        return True
+
+    def _assign_node_name(
+        self, node: ir.Node, used_names: set[str], counter: collections.Counter[str]
+    ) -> bool:
+        """Assign a name to an unnamed node. Returns True if modified."""
+        assert not node.name, (
+            "node should not have a name already if function is called correctly"
+        )
+
+        preferred_name = self._name_generator.generate_node_name(node)
+        node.name = _find_and_record_next_unique_name(preferred_name, used_names, counter)
+        logger.debug("Assigned name %s to unnamed node", node.name)
+        return True
+
+    def _fix_duplicate_value_name(
+        self, value: ir.Value, used_names: set[str], counter: collections.Counter[str]
+    ) -> bool:
+        """Fix a value's name if it conflicts with existing names. Returns True if modified."""
+        original_name = value.name
+
+        assert original_name, (
+            "value should have a name already if function is called correctly"
+        )
+
+        if original_name not in used_names:
+            # Name is unique, just record it
+            used_names.add(original_name)
+            return False
+
+        # If name is already used, make it unique
+        base_name = self._name_generator.generate_value_name(value)
+        value.name = _find_and_record_next_unique_name(base_name, used_names, counter)
+        logger.debug("Renamed value from %s to %s for uniqueness", original_name, value.name)
+        return True
+
+    def _fix_duplicate_node_name(
+        self, node: ir.Node, used_names: set[str], counter: collections.Counter[str]
+    ) -> bool:
+        """Fix a node's name if it conflicts with existing names. Returns True if modified."""
+        original_name = node.name
+
+        assert original_name, "node should have a name already if function is called correctly"
+
+        if original_name not in used_names:
+            # Name is unique, just record it
+            used_names.add(original_name)
+            return False
+
+        # If name is already used, make it unique
+        base_name = self._name_generator.generate_node_name(node)
+        node.name = _find_and_record_next_unique_name(base_name, used_names, counter)
+        logger.debug("Renamed node from %s to %s for uniqueness", original_name, node.name)
+        return True
+
+
+def _find_and_record_next_unique_name(
+    preferred_name: str, used_names: set[str], counter: collections.Counter[str]
+) -> str:
+    """Generate a unique name based on the preferred name and current counter."""
+    new_name = preferred_name
+    while new_name in used_names:
+        counter[preferred_name] += 1
+        new_name = f"{preferred_name}_{counter[preferred_name]}"
+    used_names.add(new_name)
+    return new_name
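
NameFixPass accepts any object satisfying the NameGenerator protocol. The sketch below shows a custom generator that prefers op-type-based node names, subclassing SimpleNameGenerator so value naming keeps the default behaviour; the model path is hypothetical and the same loading assumptions as in the earlier sketches apply:

    import onnx_ir as ir
    from onnx_ir.passes.common.naming import NameFixPass, SimpleNameGenerator

    class OpTypeNameGenerator(SimpleNameGenerator):
        """Name unnamed nodes after their op type, e.g. "conv_node"."""

        def generate_node_name(self, node: ir.Node) -> str:
            return node.name or f"{node.op_type.lower()}_node"

    model = ir.load("model.onnx")  # hypothetical path
    result = NameFixPass(name_generator=OpTypeNameGenerator())(model)
    print("names changed:", result.modified)

Existing names are kept whenever they are already unique; the generator is only consulted for unnamed nodes and values or to resolve collisions.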