onnx-ir 0.0.1__py3-none-any.whl → 0.1.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of onnx-ir might be problematic. Click here for more details.
- onnx_ir/__init__.py +23 -10
- onnx_ir/{_convenience.py → _convenience/__init__.py} +40 -102
- onnx_ir/_convenience/_constructors.py +213 -0
- onnx_ir/_core.py +874 -257
- onnx_ir/_display.py +2 -2
- onnx_ir/_enums.py +107 -5
- onnx_ir/_graph_comparison.py +2 -2
- onnx_ir/_graph_containers.py +373 -0
- onnx_ir/_io.py +57 -10
- onnx_ir/_linked_list.py +15 -7
- onnx_ir/_metadata.py +4 -3
- onnx_ir/_name_authority.py +2 -2
- onnx_ir/_polyfill.py +26 -0
- onnx_ir/_protocols.py +31 -13
- onnx_ir/_tape.py +139 -32
- onnx_ir/_thirdparty/asciichartpy.py +1 -4
- onnx_ir/_type_casting.py +18 -3
- onnx_ir/{_internal/version_utils.py → _version_utils.py} +2 -29
- onnx_ir/convenience.py +4 -2
- onnx_ir/external_data.py +401 -0
- onnx_ir/passes/__init__.py +8 -2
- onnx_ir/passes/_pass_infra.py +173 -56
- onnx_ir/passes/common/__init__.py +40 -0
- onnx_ir/passes/common/_c_api_utils.py +76 -0
- onnx_ir/passes/common/clear_metadata_and_docstring.py +60 -0
- onnx_ir/passes/common/common_subexpression_elimination.py +177 -0
- onnx_ir/passes/common/constant_manipulation.py +217 -0
- onnx_ir/passes/common/inliner.py +332 -0
- onnx_ir/passes/common/onnx_checker.py +57 -0
- onnx_ir/passes/common/shape_inference.py +112 -0
- onnx_ir/passes/common/topological_sort.py +33 -0
- onnx_ir/passes/common/unused_removal.py +196 -0
- onnx_ir/serde.py +288 -124
- onnx_ir/tape.py +15 -0
- onnx_ir/tensor_adapters.py +122 -0
- onnx_ir/testing.py +197 -0
- onnx_ir/traversal.py +4 -3
- onnx_ir-0.1.1.dist-info/METADATA +53 -0
- onnx_ir-0.1.1.dist-info/RECORD +42 -0
- {onnx_ir-0.0.1.dist-info → onnx_ir-0.1.1.dist-info}/WHEEL +1 -1
- onnx_ir-0.1.1.dist-info/licenses/LICENSE +202 -0
- onnx_ir/_external_data.py +0 -323
- onnx_ir-0.0.1.dist-info/LICENSE +0 -22
- onnx_ir-0.0.1.dist-info/METADATA +0 -73
- onnx_ir-0.0.1.dist-info/RECORD +0 -26
- {onnx_ir-0.0.1.dist-info → onnx_ir-0.1.1.dist-info}/top_level.txt +0 -0
|
@@ -0,0 +1,217 @@
|
|
|
1
|
+
# Copyright (c) ONNX Project Contributors
|
|
2
|
+
# SPDX-License-Identifier: Apache-2.0
|
|
3
|
+
"""Lift constants to initializers."""
|
|
4
|
+
|
|
5
|
+
from __future__ import annotations
|
|
6
|
+
|
|
7
|
+
__all__ = [
|
|
8
|
+
"AddInitializersToInputsPass",
|
|
9
|
+
"LiftConstantsToInitializersPass",
|
|
10
|
+
"LiftSubgraphInitializersToMainGraphPass",
|
|
11
|
+
"RemoveInitializersFromInputsPass",
|
|
12
|
+
]
|
|
13
|
+
|
|
14
|
+
import logging
|
|
15
|
+
|
|
16
|
+
import numpy as np
|
|
17
|
+
|
|
18
|
+
import onnx_ir as ir
|
|
19
|
+
|
|
20
|
+
logger = logging.getLogger(__name__)
|
|
21
|
+
|
|
22
|
+
|
|
23
|
+
class LiftConstantsToInitializersPass(ir.passes.InPlacePass):
    """Lift constants to initializers.

    Each standard-domain ``Constant`` node that carries exactly one attribute is
    replaced by an initializer registered on the graph that owns the node.
    Constants whose output is a graph output are left in place.

    Attributes:
        lift_all_constants: Whether to lift all Constant nodes, including those that
            do not contain a tensor attribute (e.g. with value_ints etc.)
            Default to False, where only Constants with the ``value`` attribute are lifted.
        size_limit: The minimum size of the tensor to be lifted. If the tensor contains
            number of elements less than size_limit, it will not be lifted. Default is 16.
    """

    def __init__(self, lift_all_constants: bool = False, size_limit: int = 16):
        super().__init__()
        self.lift_all_constants = lift_all_constants
        self.size_limit = size_limit

    def call(self, model: ir.Model) -> ir.passes.PassResult:
        """Replace eligible Constant nodes in ``model`` (subgraphs included) with initializers.

        Returns:
            A pass result whose ``modified`` flag is True when at least one
            constant was lifted.
        """
        count = 0
        for node in ir.traversal.RecursiveGraphIterator(model.graph):
            assert node.graph is not None
            # The default ONNX operator domain is spelled either "" or "ai.onnx".
            # "onnx.ai" is kept only because this pass accepted it previously.
            if node.op_type != "Constant" or node.domain not in ("", "ai.onnx", "onnx.ai"):
                continue
            if node.outputs[0].is_graph_output():
                logger.debug(
                    "Constant node '%s' is used as output, so it can't be lifted.", node.name
                )
                continue
            constant_node_attribute = set(node.attributes.keys())
            if len(constant_node_attribute) != 1:
                # Covers both "no attribute" and "several attributes".
                logger.debug(
                    "Invalid constant node '%s' does not have exactly one attribute",
                    node.name,
                )
                continue

            attr_name, attr_value = next(iter(node.attributes.items()))
            initializer_name = node.outputs[0].name
            assert initializer_name is not None
            assert isinstance(attr_value, ir.Attr)
            tensor = self._constant_node_attribute_to_tensor(
                node, attr_name, attr_value, initializer_name
            )
            if tensor is None:
                # The reason of None is logged in _constant_node_attribute_to_tensor
                continue
            # Register an initializer with the tensor value
            initializer = ir.Value(
                name=initializer_name,
                shape=tensor.shape,  # type: ignore[arg-type]
                type=ir.TensorType(tensor.dtype),
                const_value=tensor,
            )
            assert node.graph is not None
            node.graph.register_initializer(initializer)
            # Replace the constant node with the initializer
            ir.convenience.replace_all_uses_with(node.outputs[0], initializer)
            node.graph.remove(node, safe=True)
            count += 1
            logger.debug(
                "Converted constant node '%s' to initializer '%s'", node.name, initializer_name
            )
        if count:
            logger.debug("Lifted %s constants to initializers", count)
        return ir.passes.PassResult(model, modified=bool(count))

    def _constant_node_attribute_to_tensor(
        self, node, attr_name: str, attr_value: ir.Attr, initializer_name: str
    ) -> ir.TensorProtocol | None:
        """Convert constant node attribute to tensor.

        Args:
            node: The Constant node; only its name is used, for log and error messages.
            attr_name: Name of the node's single attribute.
            attr_value: The attribute holding the constant payload.
            initializer_name: Name given to tensors built from non-tensor attributes.

        Returns:
            The tensor to register as an initializer, or None when the constant
            must not be lifted (non-``value`` attribute while ``lift_all_constants``
            is off, or fewer than ``size_limit`` elements).

        Raises:
            ValueError: If ``attr_name`` is not one of the handled attribute names.
        """
        if not self.lift_all_constants and attr_name != "value":
            logger.debug(
                "Constant node '%s' has non-tensor attribute '%s'", node.name, attr_name
            )
            return None

        tensor: ir.TensorProtocol
        if attr_name == "value":
            tensor = attr_value.as_tensor()
        elif attr_name == "value_int":
            tensor = ir.tensor(
                attr_value.as_int(), dtype=ir.DataType.INT64, name=initializer_name
            )
        elif attr_name == "value_ints":
            tensor = ir.tensor(
                attr_value.as_ints(), dtype=ir.DataType.INT64, name=initializer_name
            )
        elif attr_name == "value_float":
            tensor = ir.tensor(
                attr_value.as_float(), dtype=ir.DataType.FLOAT, name=initializer_name
            )
        elif attr_name == "value_floats":
            tensor = ir.tensor(
                attr_value.as_floats(), dtype=ir.DataType.FLOAT, name=initializer_name
            )
        elif attr_name in ("value_string", "value_strings"):
            tensor = ir.StringTensor(
                np.array(attr_value.value, dtype=np.bytes_), name=initializer_name
            )
        else:
            raise ValueError(
                f"Unsupported constant node '{node.name}' attribute '{attr_name}'"
            )

        if tensor.size < self.size_limit:
            logger.debug(
                "Tensor from node '%s' has less than %s elements",
                node.name,
                self.size_limit,
            )
            return None
        return tensor
|
|
132
|
+
|
|
133
|
+
|
|
134
|
+
class LiftSubgraphInitializersToMainGraphPass(ir.passes.InPlacePass):
    """Lift subgraph initializers to main graph.

    This pass lifts the initializers of a subgraph to the main graph.
    It is used to ensure that the initializers are available in the main graph
    for further processing or optimization.

    Initializers that are also graph inputs will not be lifted.
    """

    def call(self, model: ir.Model) -> ir.passes.PassResult:
        """Move every liftable subgraph initializer into ``model.graph``.

        An initializer is renamed to ``<name>_<k>`` when its name was already
        lifted during this run or is already used by a main-graph initializer.

        Returns:
            A pass result whose ``modified`` flag is True when at least one
            initializer was lifted.
        """
        count = 0
        main_initializers = model.graph.initializers
        # Maps an original name to the next numeric suffix to try for it.
        registered_initializer_names: dict[str, int] = {}
        for graph in model.graphs():
            if graph is model.graph:
                continue
            # Snapshot the names: entries are popped while walking them.
            for name in tuple(graph.initializers):
                initializer = graph.initializers[name]
                if initializer.is_graph_input():
                    # Skip the ones that are also graph inputs
                    logger.debug(
                        "Initializer '%s' is also a graph input, so it can't be lifted",
                        initializer.name,
                    )
                    continue
                # Remove the initializer from the subgraph
                graph.initializers.pop(name)
                # To avoid name conflicts, we need to rename the initializer
                # to a unique name in the main graph. A conflict can come from a
                # name lifted earlier in this run or from an initializer the main
                # graph owned before the pass started.
                if name in registered_initializer_names or name in main_initializers:
                    name_count = registered_initializer_names.get(name, 1)
                    candidate = f"{name}_{name_count}"
                    # Skip suffixes that would clash with an existing entry.
                    while candidate in main_initializers:
                        name_count += 1
                        candidate = f"{name}_{name_count}"
                    initializer.name = candidate
                    registered_initializer_names[name] = name_count + 1
                else:
                    assert initializer.name is not None
                    registered_initializer_names[initializer.name] = 1
                model.graph.register_initializer(initializer)
                count += 1
                logger.debug(
                    "Lifted initializer '%s' from subgraph '%s' to main graph",
                    initializer.name,
                    graph.name,
                )
        return ir.passes.PassResult(model, modified=bool(count))
|
|
178
|
+
|
|
179
|
+
|
|
180
|
+
class RemoveInitializersFromInputsPass(ir.passes.InPlacePass):
    """Remove initializers from inputs.

    For every graph of the model, inputs that are also registered in
    ``graph.initializers`` are dropped from ``graph.inputs``; the remaining
    inputs keep their relative order.
    """

    def call(self, model: ir.Model) -> ir.passes.PassResult:
        """Rebuild each graph's input list without its initializers.

        Returns:
            A pass result whose ``modified`` flag is True when at least one
            input was removed.
        """
        removed = 0
        for graph in model.graphs():
            initializer_values = set(graph.initializers.values())
            kept_inputs = []
            for candidate in graph.inputs:
                if candidate not in initializer_values:
                    kept_inputs.append(candidate)
                    continue
                removed += 1
            # The input list is always rewritten, even when nothing was dropped.
            graph.inputs.clear()
            graph.inputs.extend(kept_inputs)
        logger.info("Removed %s initializers from graph inputs", removed)
        return ir.passes.PassResult(model, modified=bool(removed))
|
|
200
|
+
|
|
201
|
+
|
|
202
|
+
class AddInitializersToInputsPass(ir.passes.InPlacePass):
    """Add initializers to inputs.

    For every graph of the model, each initializer that is not yet part of
    ``graph.inputs`` is appended to that list.
    """

    def call(self, model: ir.Model) -> ir.passes.PassResult:
        """Append missing initializers to the inputs of each graph.

        Returns:
            A pass result whose ``modified`` flag is True when at least one
            initializer was appended.
        """
        added = 0
        for graph in model.graphs():
            # Membership is checked against the inputs present before appending.
            already_inputs = set(graph.inputs)
            missing = [
                candidate
                for candidate in graph.initializers.values()
                if candidate not in already_inputs
            ]
            for candidate in missing:
                graph.inputs.append(candidate)
            added += len(missing)
        logger.info("Added %s initializers to graph inputs", added)
        return ir.passes.PassResult(model, modified=bool(added))
|
|
@@ -0,0 +1,332 @@
|
|
|
1
|
+
# Copyright (c) ONNX Project Contributors
|
|
2
|
+
# SPDX-License-Identifier: Apache-2.0
|
|
3
|
+
"""Implementation of an inliner for onnx_ir."""
|
|
4
|
+
|
|
5
|
+
from __future__ import annotations
|
|
6
|
+
|
|
7
|
+
import dataclasses
|
|
8
|
+
|
|
9
|
+
__all__ = ["InlinePass", "InlinePassResult"]
|
|
10
|
+
|
|
11
|
+
from collections import defaultdict
|
|
12
|
+
from collections.abc import Iterable, Mapping, Sequence
|
|
13
|
+
|
|
14
|
+
import onnx_ir as ir
|
|
15
|
+
import onnx_ir.convenience as _ir_convenience
|
|
16
|
+
|
|
17
|
+
# A replacement for a node specifies a list of nodes that replaces the original node,
|
|
18
|
+
# and a list of values that replaces the original node's outputs.
|
|
19
|
+
|
|
20
|
+
NodeReplacement = tuple[Sequence[ir.Node], Sequence[ir.Value]]
|
|
21
|
+
|
|
22
|
+
# A call stack is a list of identifiers of call sites, where the first element is the
|
|
23
|
+
# outermost call site, and the last element is the innermost call site. This is used
|
|
24
|
+
# primarily for generating unique names for values in the inlined functions.
|
|
25
|
+
CallSiteId = str
|
|
26
|
+
CallStack = list[CallSiteId]
|
|
27
|
+
|
|
28
|
+
|
|
29
|
+
def _make_unique_name(name: str, callstack: CallStack, used_names: set[str]) -> str:  # pylint: disable=unused-argument
    """Return a name not present in ``used_names`` and record it there.

    ``name`` itself is returned when it is free. Otherwise the first free
    ``<name>_<k>`` is returned, trying ``k = 2, 3, ...`` in order. The chosen
    name is added to ``used_names`` before returning.

    TODO: The callstack is currently ignored. It could be used to build the name
    of a value X of a function that is inlined into a graph, but the full
    callstack leads to very long and hard to read names. Some investigation is
    needed to find a naming strategy that produces useful names for debugging.
    """
    if name not in used_names:
        used_names.add(name)
        return name

    suffix = 2
    unique = f"{name}_{suffix}"
    while unique in used_names:
        suffix += 1
        unique = f"{name}_{suffix}"
    used_names.add(unique)
    return unique
|
|
47
|
+
|
|
48
|
+
|
|
49
|
+
class _CopyReplace:
    """Clone IR values, attributes, nodes and graphs while applying substitutions.

    Values are substituted through ``value_map`` (which is extended as new
    values are created) and reference attributes through ``attr_map``.
    """

    def __init__(
        self,
        inliner: InlinePass,
        attr_map: Mapping[str, ir.Attr],
        value_map: dict[ir.Value, ir.Value | None],
        metadata_props: dict[str, str],
        call_stack: CallStack,
    ) -> None:
        self._inliner = inliner
        self._value_map = value_map
        self._attr_map = attr_map
        self._metadata_props = metadata_props
        self._call_stack = call_stack

    def clone_value(self, value: ir.Value) -> ir.Value | None:
        """Return the substitute recorded for ``value``, creating one if needed.

        A value without an entry must have no producer; it is copied field by
        field and the copy is remembered in the value map.
        """
        try:
            return self._value_map[value]
        except KeyError:
            pass
        # If the value is not in the value map, it must be a graph input.
        assert value.producer() is None, f"Value {value} has no entry in the value map"
        copied = ir.Value(
            name=value.name,
            type=value.type,
            shape=value.shape,
            doc_string=value.doc_string,
            const_value=value.const_value,
        )
        self._value_map[value] = copied
        return copied

    def clone_optional_value(self, value: ir.Value | None) -> ir.Value | None:
        """Like :meth:`clone_value`, but passes ``None`` through unchanged."""
        return None if value is None else self.clone_value(value)

    def clone_attr(self, key: str, attr: ir.Attr) -> ir.Attr | None:
        """Return the attribute to store under ``key`` in a cloned node.

        Graph-valued attributes are cloned, other plain attributes are returned
        as-is, and reference attributes are resolved through the attribute map.
        ``None`` means the attribute is dropped.
        """
        if attr.is_ref():
            target_name = attr.ref_attr_name
            assert target_name is not None, "Reference attribute must have a name"
            if target_name not in self._attr_map:
                # Note that if a function has an attribute-parameter X, and a call (node)
                # to the function has no attribute X, all references to X in nodes inside
                # the function body will be removed. This is just the ONNX representation
                # of optional-attributes.
                return None
            bound = self._attr_map[target_name]
            if bound.is_ref():
                # The bound attribute is itself a reference: forward it under ``key``.
                assert bound.ref_attr_name is not None
                return ir.RefAttr(
                    key, bound.ref_attr_name, bound.type, doc_string=bound.doc_string
                )
            return ir.Attr(key, bound.type, bound.value, doc_string=bound.doc_string)

        kind = attr.type
        if kind == ir.AttributeType.GRAPH:
            cloned_graph = self.clone_graph(attr.as_graph())
            return ir.Attr(
                key, ir.AttributeType.GRAPH, cloned_graph, doc_string=attr.doc_string
            )
        if kind == ir.AttributeType.GRAPHS:
            cloned_graphs = [self.clone_graph(subgraph) for subgraph in attr.as_graphs()]
            return ir.Attr(
                key, ir.AttributeType.GRAPHS, cloned_graphs, doc_string=attr.doc_string
            )
        return attr

    def clone_node(self, node: ir.Node) -> ir.Node:
        """Create a detached copy of ``node`` with substituted inputs and attributes.

        The copy and its outputs receive names made unique against the
        inliner's used-name sets, and the copy is recorded in the inliner's
        node context with this cloner's call stack.
        """
        cloned_inputs = [self.clone_optional_value(operand) for operand in node.inputs]

        cloned_attributes = []
        for key, original_attr in node.attributes.items():
            cloned_attr = self.clone_attr(key, original_attr)
            if cloned_attr is not None:
                cloned_attributes.append(cloned_attr)

        cloned_name = node.name
        if cloned_name is not None:
            cloned_name = _make_unique_name(
                cloned_name, self._call_stack, self._inliner.used_node_names
            )

        # TODO: For now, node metadata overrides callnode metadata if there is a conflict.
        # Do we need to preserve both?
        merged_metadata = dict(self._metadata_props)
        merged_metadata.update(node.metadata_props)

        cloned_node = ir.Node(
            node.domain,
            node.op_type,
            cloned_inputs,
            cloned_attributes,
            overload=node.overload,
            num_outputs=len(node.outputs),
            graph=None,
            name=cloned_name,
            doc_string=node.doc_string,  # type: ignore
            metadata_props=merged_metadata,
        )

        cloned_outputs = cloned_node.outputs
        for index, original_output in enumerate(node.outputs):
            cloned_output = cloned_outputs[index]
            self._value_map[original_output] = cloned_output
            base_name = original_output.name
            if base_name is None:
                base_name = f"output_{index}"
            cloned_output.name = _make_unique_name(
                base_name, self._call_stack, self._inliner.used_value_names
            )

        self._inliner.node_context[cloned_node] = self._call_stack
        return cloned_node

    def clone_graph(self, graph: ir.Graph) -> ir.Graph:
        """Create a copy of ``graph`` built from cloned inputs, nodes, initializers and outputs."""
        # Order matters: inputs first, then nodes, then initializers; the
        # outputs are cloned last so they resolve to already cloned values.
        cloned_inputs = [self.clone_value(graph_input) for graph_input in graph.inputs]
        cloned_nodes = [self.clone_node(member) for member in graph]
        cloned_initializers = [
            self.clone_value(initializer) for initializer in graph.initializers.values()
        ]
        cloned_outputs = [self.clone_value(graph_output) for graph_output in graph.outputs]

        return ir.Graph(
            cloned_inputs,  # type: ignore
            cloned_outputs,  # type: ignore
            nodes=cloned_nodes,
            initializers=cloned_initializers,  # type: ignore
            doc_string=graph.doc_string,
            opset_imports=graph.opset_imports,
            name=graph.name,
            metadata_props=graph.metadata_props,
        )
|
|
174
|
+
|
|
175
|
+
|
|
176
|
+
def _abbreviate(
    function_ids: Iterable[ir.OperatorIdentifier],
) -> dict[ir.OperatorIdentifier, str]:
    """Create a short unambiguous abbreviation for all function ids.

    Args:
        function_ids: ``(domain, name, overload)`` identifiers. Any iterable is
            accepted, including one-shot iterators.

    Returns:
        A mapping from each identifier to ``[<domain>_]<name>[_<overload>]``.
        The domain prefix is only present when another identifier shares the
        same name and overload under a different domain; the overload suffix
        is only present when the overload is non-empty.
    """
    # Materialize once: the ids are walked twice below, which would silently
    # exhaust a generator or other one-shot iterator.
    all_ids = tuple(function_ids)

    # Collect the domains seen for every (name, overload) pair so that the
    # ambiguity check is a single lookup instead of a scan over all ids.
    domains_by_key: dict[tuple[str, str], set[str]] = {}
    for domain, name, overload in all_ids:
        domains_by_key.setdefault((name, overload), set()).add(domain)

    def id_abbreviation(function_id: ir.OperatorIdentifier) -> str:
        """Create a short unambiguous abbreviation for a function id."""
        domain, name, overload = function_id
        # Omit the domain, if it remains unambiguous after omitting it.
        is_ambiguous = len(domains_by_key[(name, overload)]) > 1
        short_domain = domain + "_" if is_ambiguous else ""
        if overload != "":
            return short_domain + name + "_" + overload
        return short_domain + name

    return {function_id: id_abbreviation(function_id) for function_id in all_ids}
|
|
194
|
+
|
|
195
|
+
|
|
196
|
+
@dataclasses.dataclass
class InlinePassResult(ir.passes.PassResult):
    """Pass result returned by the inliner, extended with call statistics.

    Attributes:
        id_count: Call counts reported by the inliner, keyed by the operator
            identifier of the called function.
    """

    id_count: dict[ir.OperatorIdentifier, int]
|
|
199
|
+
|
|
200
|
+
|
|
201
|
+
class InlinePass(ir.passes.InPlacePass):
    """Inline model local functions to the main graph and clear function definitions."""

    def __init__(self) -> None:
        super().__init__()
        self._functions: dict[ir.OperatorIdentifier, ir.Function] = {}
        self._function_id_abbreviations: dict[ir.OperatorIdentifier, str] = {}
        self._opset_imports: dict[str, int] = {}
        self.used_value_names: set[str] = set()
        self.used_node_names: set[str] = set()
        self.node_context: dict[ir.Node, CallStack] = {}

    def _reset(self, model: ir.Model) -> None:
        """Bind the pass state to ``model`` and forget names from earlier runs."""
        # These are references to the model's own containers, so opset imports
        # added while inlining are written directly into the model.
        self._functions = model.functions
        self._function_id_abbreviations = _abbreviate(self._functions.keys())
        self._opset_imports = model.opset_imports
        self.used_value_names = set()
        self.used_node_names = set()
        self.node_context = {}

    def call(self, model: ir.Model) -> InlinePassResult:
        """Inline every function call reachable from the main graph.

        Returns:
            The result with ``id_count`` holding the per-function call counts,
            including calls found in nested subgraphs. ``modified`` is True
            when ``id_count`` is non-empty.
        """
        self._reset(model)
        id_count = self._inline_calls_in(model.graph)
        model.functions.clear()
        return InlinePassResult(model, modified=bool(id_count), id_count=id_count)

    def _instantiate_call(self, node: ir.Node, call_site_id: CallSiteId) -> NodeReplacement:
        """Build the replacement for one function-call node.

        Args:
            node: The call node; its operator identifier must name a model function.
            call_site_id: Identifier appended to the node's call stack.

        Returns:
            The cloned function-body nodes and the values replacing the call
            node's outputs.

        Raises:
            ValueError: On an opset version conflict between the function and the
                model, on graph-valued attribute parameters, or when the call
                passes more inputs than the function declares.
        """
        op_id = node.op_identifier()
        function = self._functions[op_id]

        # check opset compatibility and update the opset imports
        for key, value in function.opset_imports.items():
            if key not in self._opset_imports:
                self._opset_imports[key] = value
            elif self._opset_imports[key] != value:
                raise ValueError(
                    f"Opset mismatch: {key} {self._opset_imports[key]} != {value}"
                )

        # Identify substitutions for both inputs and attributes of the function:
        attributes: Mapping[str, ir.Attr] = node.attributes
        # Function attributes with a value act as defaults for attributes the
        # call node does not set.
        default_attr_values = {
            attr.name: attr
            for attr in function.attributes.values()
            if attr.name not in attributes and attr.value is not None
        }
        if default_attr_values:
            attributes = {**attributes, **default_attr_values}
        if any(
            attr.type in {ir.AttributeType.GRAPH, ir.AttributeType.GRAPHS}
            for attr in attributes.values()
        ):
            raise ValueError(
                "Inliner does not support graph attribute parameters to functions"
            )

        if len(node.inputs) > len(function.inputs):
            raise ValueError(f"Input mismatch: {len(node.inputs)} > {len(function.inputs)}")
        value_map: dict[ir.Value, ir.Value | None] = {}
        for i, actual in enumerate(node.inputs):
            value_map[function.inputs[i]] = actual
        # Trailing function inputs the call does not supply map to None.
        for i in range(len(node.inputs), len(function.inputs)):
            value_map[function.inputs[i]] = None

        # Identify call-stack for node, used to generate unique names.
        call_stack = self.node_context.get(node, [])
        new_call_stack = [*call_stack, call_site_id]

        cloner = _CopyReplace(self, attributes, value_map, node.metadata_props, new_call_stack)

        # iterate over the nodes in the function, creating a copy of each node
        # and replacing inputs with the corresponding values in the value map.
        # Update the value map with the new values.

        nodes = [cloner.clone_node(body_node) for body_node in function]
        output_values = [value_map[output] for output in function.outputs]
        return nodes, output_values  # type: ignore

    def _inline_calls_in(self, graph: ir.Graph) -> dict[ir.OperatorIdentifier, int]:
        """Inline all function calls in ``graph`` and, recursively, in its subgraphs.

        Returns:
            Per-function call counts: the calls found directly in ``graph``
            before inlining, plus the counts reported for nested subgraphs.
        """
        for graph_input in graph.inputs:
            if graph_input.name is not None:
                self.used_value_names.add(graph_input.name)
        for initializer_name in graph.initializers:
            self.used_value_names.add(initializer_name)

        # Pre-processing:
        # * Count the number of times each function is called in the graph.
        #   This is used for disambiguating names of values in the inlined functions.
        # * And identify names of values that are used in the graph.
        id_count: dict[ir.OperatorIdentifier, int] = defaultdict(int)
        for node in graph:
            if node.name:
                self.used_node_names.add(node.name)
            op_id = node.op_identifier()
            if op_id in self._functions:
                id_count[op_id] += 1
            for output in node.outputs:
                if output.name is not None:
                    self.used_value_names.add(output.name)

        next_id: dict[ir.OperatorIdentifier, int] = defaultdict(int)
        # Counts coming from subgraphs are accumulated separately and merged
        # only after the loop, so that they cannot influence the call-site
        # prefix decision below, which depends on this graph's own counts.
        nested_counts: dict[ir.OperatorIdentifier, int] = defaultdict(int)
        for node in graph:
            op_id = node.op_identifier()
            if op_id in self._functions:
                # If there are multiple calls to same function, we use a prefix to disambiguate
                # the different call-sites:
                if id_count[op_id] > 1:
                    call_site_prefix = f"_{next_id[op_id]}"
                    next_id[op_id] += 1
                else:
                    call_site_prefix = ""
                call_site = node.name or (
                    self._function_id_abbreviations[op_id] + call_site_prefix
                )
                nodes, values = self._instantiate_call(node, call_site)
                _ir_convenience.replace_nodes_and_values(
                    graph,
                    insertion_point=node,
                    old_nodes=[node],
                    new_nodes=nodes,
                    old_values=node.outputs,
                    new_values=values,
                )
            else:
                for attr in node.attributes.values():
                    if not isinstance(attr, ir.Attr):
                        continue
                    if attr.type == ir.AttributeType.GRAPH:
                        subgraphs = [attr.as_graph()]
                    elif attr.type == ir.AttributeType.GRAPHS:
                        subgraphs = attr.as_graphs()
                    else:
                        continue
                    for subgraph in subgraphs:
                        # Keep the counts of calls inlined inside the subgraph so
                        # the caller can tell that the model was changed.
                        for nested_id, nested_n in self._inline_calls_in(subgraph).items():
                            nested_counts[nested_id] += nested_n

        for nested_id, nested_n in nested_counts.items():
            id_count[nested_id] += nested_n
        return id_count
|
|
@@ -0,0 +1,57 @@
|
|
|
1
|
+
# Copyright (c) ONNX Project Contributors
|
|
2
|
+
# SPDX-License-Identifier: Apache-2.0
|
|
3
|
+
"""Passes for debugging purposes."""
|
|
4
|
+
|
|
5
|
+
from __future__ import annotations
|
|
6
|
+
|
|
7
|
+
__all__ = [
|
|
8
|
+
"CheckerPass",
|
|
9
|
+
]
|
|
10
|
+
|
|
11
|
+
from typing import Literal
|
|
12
|
+
|
|
13
|
+
import onnx
|
|
14
|
+
|
|
15
|
+
import onnx_ir as ir
|
|
16
|
+
from onnx_ir.passes.common import _c_api_utils
|
|
17
|
+
|
|
18
|
+
|
|
19
|
+
class CheckerPass(ir.passes.PassBase):
    """Validate a model with ``onnx.checker.check_model``; the model is returned unchanged."""

    @property
    def in_place(self) -> Literal[True]:
        """Always True: no new model is created by this pass."""
        return True

    @property
    def changes_input(self) -> Literal[False]:
        """Always False: the input model is left as it is."""
        return False

    def __init__(
        self,
        full_check: bool = False,
        skip_opset_compatibility_check: bool = False,
        check_custom_domain: bool = False,
    ):
        """Store the options forwarded to ``onnx.checker.check_model``."""
        super().__init__()
        self.full_check = full_check
        self.skip_opset_compatibility_check = skip_opset_compatibility_check
        self.check_custom_domain = check_custom_domain

    def call(self, model: ir.Model) -> ir.passes.PassResult:
        """Run the onnx checker on ``model`` and report it as not modified."""

        def _check_proto(proto: onnx.ModelProto) -> None:
            """Check ``proto`` using the options configured on this pass."""
            checker_options = {
                "full_check": self.full_check,
                "skip_opset_compatibility_check": self.skip_opset_compatibility_check,
                "check_custom_domain": self.check_custom_domain,
            }
            onnx.checker.check_model(proto, **checker_options)

        _c_api_utils.call_onnx_api(func=_check_proto, model=model)
        # The model is not modified
        return ir.passes.PassResult(model, modified=False)
|