onnx-ir 0.0.1__py3-none-any.whl → 0.1.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of onnx-ir might be problematic. Click here for more details.
- onnx_ir/__init__.py +23 -10
- onnx_ir/{_convenience.py → _convenience/__init__.py} +40 -102
- onnx_ir/_convenience/_constructors.py +213 -0
- onnx_ir/_core.py +857 -233
- onnx_ir/_display.py +2 -2
- onnx_ir/_enums.py +107 -5
- onnx_ir/_graph_comparison.py +2 -2
- onnx_ir/_graph_containers.py +268 -0
- onnx_ir/_io.py +57 -10
- onnx_ir/_linked_list.py +15 -7
- onnx_ir/_metadata.py +4 -3
- onnx_ir/_name_authority.py +2 -2
- onnx_ir/_polyfill.py +26 -0
- onnx_ir/_protocols.py +31 -13
- onnx_ir/_tape.py +139 -32
- onnx_ir/_thirdparty/asciichartpy.py +1 -4
- onnx_ir/_type_casting.py +18 -3
- onnx_ir/{_internal/version_utils.py → _version_utils.py} +2 -29
- onnx_ir/convenience.py +4 -2
- onnx_ir/external_data.py +401 -0
- onnx_ir/passes/__init__.py +8 -2
- onnx_ir/passes/_pass_infra.py +173 -56
- onnx_ir/passes/common/__init__.py +36 -0
- onnx_ir/passes/common/_c_api_utils.py +76 -0
- onnx_ir/passes/common/clear_metadata_and_docstring.py +60 -0
- onnx_ir/passes/common/constant_manipulation.py +232 -0
- onnx_ir/passes/common/inliner.py +331 -0
- onnx_ir/passes/common/onnx_checker.py +57 -0
- onnx_ir/passes/common/shape_inference.py +112 -0
- onnx_ir/passes/common/topological_sort.py +33 -0
- onnx_ir/passes/common/unused_removal.py +196 -0
- onnx_ir/serde.py +288 -124
- onnx_ir/tape.py +15 -0
- onnx_ir/tensor_adapters.py +122 -0
- onnx_ir/testing.py +197 -0
- onnx_ir/traversal.py +4 -3
- onnx_ir-0.1.0.dist-info/METADATA +53 -0
- onnx_ir-0.1.0.dist-info/RECORD +41 -0
- {onnx_ir-0.0.1.dist-info → onnx_ir-0.1.0.dist-info}/WHEEL +1 -1
- onnx_ir-0.1.0.dist-info/licenses/LICENSE +202 -0
- onnx_ir/_external_data.py +0 -323
- onnx_ir-0.0.1.dist-info/LICENSE +0 -22
- onnx_ir-0.0.1.dist-info/METADATA +0 -73
- onnx_ir-0.0.1.dist-info/RECORD +0 -26
- {onnx_ir-0.0.1.dist-info → onnx_ir-0.1.0.dist-info}/top_level.txt +0 -0
|
@@ -0,0 +1,331 @@
|
|
|
1
|
+
# Copyright (c) ONNX Project Contributors
|
|
2
|
+
# SPDX-License-Identifier: Apache-2.0
|
|
3
|
+
"""Implementation of an inliner for onnx_ir."""
|
|
4
|
+
|
|
5
|
+
from __future__ import annotations
|
|
6
|
+
|
|
7
|
+
import dataclasses
|
|
8
|
+
|
|
9
|
+
__all__ = ["InlinePass", "InlinePassResult"]
|
|
10
|
+
|
|
11
|
+
from collections import defaultdict
|
|
12
|
+
from collections.abc import Iterable, Sequence
|
|
13
|
+
|
|
14
|
+
import onnx_ir as ir
|
|
15
|
+
import onnx_ir.convenience as _ir_convenience
|
|
16
|
+
|
|
17
|
+
# Call-site bookkeeping.
#
# A call site is identified by a string. A call stack lists the call sites that led
# to a node, outermost first and innermost last. The stack is handed to name
# generation (``_make_unique_name``), which does not use it yet.
CallSiteId = str
CallStack = list[CallSiteId]

# The result of inlining a single call node: the nodes that take its place, and the
# values that take the place of its outputs.
NodeReplacement = tuple[Sequence[ir.Node], Sequence[ir.Value]]
|
|
27
|
+
|
|
28
|
+
|
|
29
|
+
def _make_unique_name(name: str, callstack: CallStack, used_names: set[str]) -> str: # pylint: disable=unused-argument
|
|
30
|
+
"""Generate a unique name from a name, calling-context, and set of used names.
|
|
31
|
+
|
|
32
|
+
If there is a name clash, we add a numeric suffix to the name to make
|
|
33
|
+
it unique. We use the same strategy to make node names unique.
|
|
34
|
+
|
|
35
|
+
TODO: We can use the callstack in generating a name for a value X in a function
|
|
36
|
+
that is inlined into a graph. This is not yet implemented. Using the full callstack
|
|
37
|
+
leads to very long and hard to read names. Some investigation is needed to find
|
|
38
|
+
a good naming strategy that will produce useful names for debugging.
|
|
39
|
+
"""
|
|
40
|
+
candidate = name
|
|
41
|
+
i = 1
|
|
42
|
+
while candidate in used_names:
|
|
43
|
+
i += 1
|
|
44
|
+
candidate = f"{name}_{i}"
|
|
45
|
+
used_names.add(candidate)
|
|
46
|
+
return candidate
|
|
47
|
+
|
|
48
|
+
|
|
49
|
+
class _CopyReplace:
    """Utilities for creating a copy of IR objects with substitutions for attributes/input values.

    One instance is created per inlined call site (see ``InlinePass._instantiate_call``).
    It clones the nodes and subgraphs of a function body, replacing:

    * values present in ``value_map`` (initially the function's formal inputs) with
      their mapped replacements, and
    * reference attributes with the matching entries of ``attr_map``.

    Cloned node names and output value names are made unique against the inliner's
    ``used_node_names`` and ``used_value_names`` sets.
    """

    def __init__(
        self,
        inliner: InlinePass,
        attr_map: dict[str, ir.Attr],
        value_map: dict[ir.Value, ir.Value | None],
        metadata_props: dict[str, str],
        call_stack: CallStack,
    ) -> None:
        """Set up a cloner for one call site.

        Args:
            inliner: The owning pass. Supplies the sets of used names and the
                ``node_context`` mapping, which is updated for every cloned node.
            attr_map: Attributes visible at the call site, keyed by name; used to
                resolve reference attributes inside the function body.
            value_map: Mapping from original values to their replacements. It is
                extended in place as values and node outputs are cloned.
            metadata_props: Metadata of the call node, merged into every cloned node.
            call_stack: Call-site identifiers leading to this instantiation.
        """
        self._inliner = inliner
        self._value_map = value_map
        self._attr_map = attr_map
        self._metadata_props = metadata_props
        self._call_stack = call_stack

    def clone_value(self, value: ir.Value) -> ir.Value | None:
        """Return the replacement for ``value``, creating one if necessary.

        A value already in the value map is returned as mapped; this may be
        ``None`` for an omitted optional function input. Any other value must
        have no producer (e.g. a subgraph input or initializer). A fresh copy of
        it is created, recorded in the value map and returned.
        """
        if value in self._value_map:
            return self._value_map[value]
        # If the value is not in the value map, it must be a graph input.
        assert value.producer() is None, f"Value {value} has no entry in the value map"
        new_value = ir.Value(
            name=value.name,
            type=value.type,
            shape=value.shape,
            doc_string=value.doc_string,
            const_value=value.const_value,
        )
        self._value_map[value] = new_value
        return new_value

    def clone_optional_value(self, value: ir.Value | None) -> ir.Value | None:
        """Like :meth:`clone_value`, but passes ``None`` (an omitted input) through."""
        if value is None:
            return None
        return self.clone_value(value)

    def clone_attr(self, key: str, attr: ir.Attr) -> ir.Attr | None:
        """Clone attribute ``attr`` (named ``key``) for use in an inlined node.

        * GRAPH / GRAPHS attributes are copied with :meth:`clone_graph`.
        * Other concrete attributes are returned as-is.
        * A reference attribute is resolved against the attribute map. A concrete
          match yields a concrete attribute named ``key``; a reference match yields
          a new reference attribute named ``key``.
        * If the referenced attribute is absent from the map, ``None`` is returned
          so that the caller drops the attribute.
        """
        if not attr.is_ref():
            if attr.type == ir.AttributeType.GRAPH:
                graph = self.clone_graph(attr.as_graph())
                return ir.Attr(key, ir.AttributeType.GRAPH, graph, doc_string=attr.doc_string)
            elif attr.type == ir.AttributeType.GRAPHS:
                graphs = [self.clone_graph(graph) for graph in attr.as_graphs()]
                return ir.Attr(
                    key, ir.AttributeType.GRAPHS, graphs, doc_string=attr.doc_string
                )
            return attr
        assert attr.is_ref()
        ref_attr_name = attr.ref_attr_name
        if ref_attr_name in self._attr_map:
            ref_attr = self._attr_map[ref_attr_name]
            if not ref_attr.is_ref():
                return ir.Attr(
                    key, ref_attr.type, ref_attr.value, doc_string=ref_attr.doc_string
                )
            assert ref_attr.ref_attr_name is not None
            return ir.RefAttr(
                key, ref_attr.ref_attr_name, ref_attr.type, doc_string=ref_attr.doc_string
            )
        # Note that if a function has an attribute-parameter X, and a call (node) to the function
        # has no attribute X, all references to X in nodes inside the function body will be
        # removed. This is just the ONNX representation of optional-attributes.
        return None

    def clone_node(self, node: ir.Node) -> ir.Node:
        """Clone ``node`` with inputs and attributes substituted.

        The clone gets a unique name (if the original had one), unique output
        names, and the call node's metadata merged under its own. Its outputs are
        recorded in the value map so that later nodes refer to them, and the
        current call stack is recorded in the inliner's ``node_context``.
        """
        new_inputs = [self.clone_optional_value(inp) for inp in node.inputs]
        new_attributes = [
            new_value
            for key, value in node.attributes.items()
            if (new_value := self.clone_attr(key, value)) is not None
        ]
        new_name = node.name
        if new_name is not None:
            new_name = _make_unique_name(
                new_name, self._call_stack, self._inliner.used_node_names
            )

        new_metadata = {**self._metadata_props, **node.metadata_props}
        # TODO: For now, node metadata overrides callnode metadata if there is a conflict.
        # Do we need to preserve both?

        new_node = ir.Node(
            node.domain,
            node.op_type,
            new_inputs,
            new_attributes,
            overload=node.overload,
            num_outputs=len(node.outputs),
            graph=None,
            name=new_name,
            doc_string=node.doc_string,  # type: ignore
            metadata_props=new_metadata,
        )
        new_outputs = new_node.outputs
        for i, output in enumerate(node.outputs):
            self._value_map[output] = new_outputs[i]
            old_name = output.name if output.name is not None else f"output_{i}"
            new_outputs[i].name = _make_unique_name(
                old_name, self._call_stack, self._inliner.used_value_names
            )

        self._inliner.node_context[new_node] = self._call_stack

        return new_node

    def clone_graph(self, graph: ir.Graph) -> ir.Graph:
        """Copy ``graph`` (e.g. a subgraph attribute) with substitutions applied.

        Inputs are cloned first, then nodes, then initializers and outputs.
        Initializers and outputs that nodes already reference resolve to the
        values created while cloning those nodes.
        """
        input_values = [self.clone_value(v) for v in graph.inputs]
        nodes = [self.clone_node(node) for node in graph]
        initializers = [self.clone_value(init) for init in graph.initializers.values()]
        # Looks up already cloned values.
        output_values = [self.clone_value(v) for v in graph.outputs]

        return ir.Graph(
            input_values,  # type: ignore
            output_values,  # type: ignore
            nodes=nodes,
            initializers=initializers,  # type: ignore
            doc_string=graph.doc_string,
            # Copy the mutable mappings: passing the originals would make every
            # inlined copy share (and alias mutations with) the dictionaries of
            # the function body's subgraph.
            opset_imports=dict(graph.opset_imports),
            name=graph.name,
            metadata_props=dict(graph.metadata_props),
        )
|
|
173
|
+
|
|
174
|
+
|
|
175
|
+
def _abbreviate(
|
|
176
|
+
function_ids: Iterable[ir.OperatorIdentifier],
|
|
177
|
+
) -> dict[ir.OperatorIdentifier, str]:
|
|
178
|
+
"""Create a short unambiguous abbreviation for all function ids."""
|
|
179
|
+
|
|
180
|
+
def id_abbreviation(id: ir.OperatorIdentifier) -> str:
|
|
181
|
+
"""Create a short unambiguous abbreviation for a function id."""
|
|
182
|
+
domain, name, overload = id
|
|
183
|
+
# Omit the domain, if it remains unambiguous after omitting it.
|
|
184
|
+
if any(x[0] != domain and x[1] == name and x[2] == overload for x in function_ids):
|
|
185
|
+
short_domain = domain + "_"
|
|
186
|
+
else:
|
|
187
|
+
short_domain = ""
|
|
188
|
+
if overload != "":
|
|
189
|
+
return short_domain + name + "_" + overload
|
|
190
|
+
return short_domain + name
|
|
191
|
+
|
|
192
|
+
return {id: id_abbreviation(id) for id in function_ids}
|
|
193
|
+
|
|
194
|
+
|
|
195
|
+
@dataclasses.dataclass()
class InlinePassResult(ir.passes.PassResult):
    """Result of :class:`InlinePass`.

    Extends ``ir.passes.PassResult`` with per-function call counts.
    """

    # Number of call sites counted for each function id while inlining.
    id_count: dict[ir.OperatorIdentifier, int]
|
|
198
|
+
|
|
199
|
+
|
|
200
|
+
class InlinePass(ir.passes.InPlacePass):
    """Inline model local functions to the main graph and clear function definitions.

    Every node whose op identifier matches a model-local function is replaced by a
    copy of the function body. Graph-valued attributes of other nodes are
    processed recursively. Nodes created by inlining are tracked with their call
    stack in ``node_context``.

    NOTE(review): expansion of calls nested inside function bodies relies on graph
    iteration visiting nodes inserted after the current node; confirm against the
    ``ir.Graph`` iteration semantics.
    """

    def __init__(self) -> None:
        super().__init__()
        # Functions available for inlining, keyed by (domain, name, overload).
        self._functions: dict[ir.OperatorIdentifier, ir.Function] = {}
        # Short names used to build call-site identifiers.
        self._function_id_abbreviations: dict[ir.OperatorIdentifier, str] = {}
        # Model opset imports, extended with the imports of inlined functions.
        self._opset_imports: dict[str, int] = {}
        # Names already in use; cloned values/nodes are renamed to avoid them.
        self.used_value_names: set[str] = set()
        self.used_node_names: set[str] = set()
        # Call stack under which each inlined node was created.
        self.node_context: dict[ir.Node, CallStack] = {}

    def _reset(self, model: ir.Model) -> None:
        """Initialize per-run state for ``model``."""
        # These alias the model's own containers, so e.g. opset imports added while
        # inlining are written directly into the model.
        self._functions = model.functions
        self._function_id_abbreviations = _abbreviate(self._functions.keys())
        self._opset_imports = model.opset_imports
        self.used_value_names = set()
        self.used_node_names = set()
        self.node_context = {}

    def call(self, model: ir.Model) -> InlinePassResult:
        """Inline all function calls in ``model`` and remove its function definitions."""
        self._reset(model)
        # Dropping function definitions is itself a change to the model, even when
        # none of them was called, so remember whether there were any.
        had_functions = bool(model.functions)
        id_count = self._inline_calls_in(model.graph)
        model.functions.clear()
        return InlinePassResult(
            model, modified=bool(id_count) or had_functions, id_count=id_count
        )

    def _instantiate_call(self, node: ir.Node, call_site_id: CallSiteId) -> NodeReplacement:
        """Create an inlined copy of the function called by ``node``.

        Args:
            node: A node whose op identifier names a model-local function.
            call_site_id: Identifier of this call site; appended to the call stack
                recorded for ``node``.

        Returns:
            The cloned function nodes, and the values replacing ``node``'s outputs.

        Raises:
            ValueError: If the function's opset imports conflict with the model's,
                if any attribute passed to the function is graph-valued, or if the
                node has more inputs than the function declares.
        """
        op_id = node.op_identifier()
        function = self._functions[op_id]

        # Check opset compatibility and add missing imports to the model.
        for domain, version in function.opset_imports.items():
            if domain not in self._opset_imports:
                self._opset_imports[domain] = version
            elif self._opset_imports[domain] != version:
                raise ValueError(
                    f"Opset mismatch: {domain} {self._opset_imports[domain]} != {version}"
                )

        # Identify substitutions for both inputs and attributes of the function.
        attributes: dict[str, ir.Attr] = node.attributes
        # Attribute parameters with a default fill in for those the call omits.
        default_attr_values = {
            attr.name: attr
            for attr in function.attributes.values()
            if attr.name not in attributes and attr.value is not None
        }
        if default_attr_values:
            attributes = {**attributes, **default_attr_values}
        if any(
            attr.type in {ir.AttributeType.GRAPH, ir.AttributeType.GRAPHS}
            for attr in attributes.values()
        ):
            raise ValueError(
                "Inliner does not support graph attribute parameters to functions"
            )

        if len(node.inputs) > len(function.inputs):
            raise ValueError(f"Input mismatch: {len(node.inputs)} > {len(function.inputs)}")
        # Bind formal inputs to actual inputs; trailing formals the call omits map to None.
        value_map: dict[ir.Value, ir.Value | None] = {}
        for i, formal in enumerate(function.inputs):
            value_map[formal] = node.inputs[i] if i < len(node.inputs) else None

        # Identify call-stack for node, used to generate unique names.
        call_stack = self.node_context.get(node, [])
        new_call_stack = [*call_stack, call_site_id]

        cloner = _CopyReplace(self, attributes, value_map, node.metadata_props, new_call_stack)

        # Clone each function node, substituting inputs via the value map; cloning
        # extends the value map with the new node outputs.
        nodes = [cloner.clone_node(fn_node) for fn_node in function]
        output_values = [value_map[output] for output in function.outputs]
        return nodes, output_values  # type: ignore

    def _inline_calls_in(self, graph: ir.Graph) -> dict[ir.OperatorIdentifier, int]:
        """Inline all function calls in ``graph`` and in its subgraphs.

        Returns:
            The number of call sites counted for each function id, including
            those found in subgraphs.
        """
        for graph_input in graph.inputs:
            if graph_input.name is not None:
                self.used_value_names.add(graph_input.name)
        # Iterating the initializers mapping yields the initializer names.
        for initializer_name in graph.initializers:
            self.used_value_names.add(initializer_name)

        # Pre-processing:
        # * Count the number of times each function is called in the graph.
        #   This is used for disambiguating names of values in the inlined functions.
        # * And identify names of values that are used in the graph.
        id_count: dict[ir.OperatorIdentifier, int] = defaultdict(int)
        for node in graph:
            if node.name:
                self.used_node_names.add(node.name)
            op_id = node.op_identifier()
            if op_id in self._functions:
                id_count[op_id] += 1
            for output in node.outputs:
                if output.name is not None:
                    self.used_value_names.add(output.name)

        # Counts from subgraphs are collected separately so that they do not
        # influence call-site naming below; they are merged into the result at the end.
        subgraph_count: dict[ir.OperatorIdentifier, int] = defaultdict(int)
        next_id: dict[ir.OperatorIdentifier, int] = defaultdict(int)
        for node in graph:
            op_id = node.op_identifier()
            if op_id in self._functions:
                # If there are multiple calls to same function, we use a prefix to disambiguate
                # the different call-sites:
                if id_count[op_id] > 1:
                    call_site_prefix = f"_{next_id[op_id]}"
                    next_id[op_id] += 1
                else:
                    call_site_prefix = ""
                call_site = node.name or (
                    self._function_id_abbreviations[op_id] + call_site_prefix
                )
                nodes, values = self._instantiate_call(node, call_site)
                _ir_convenience.replace_nodes_and_values(
                    graph,
                    insertion_point=node,
                    old_nodes=[node],
                    new_nodes=nodes,
                    old_values=node.outputs,
                    new_values=values,
                )
            else:
                for attr in node.attributes.values():
                    if not isinstance(attr, ir.Attr):
                        continue
                    if attr.type == ir.AttributeType.GRAPH:
                        subgraphs = [attr.as_graph()]
                    elif attr.type == ir.AttributeType.GRAPHS:
                        subgraphs = attr.as_graphs()
                    else:
                        continue
                    for subgraph in subgraphs:
                        # Previously the recursive counts were discarded, so calls
                        # inlined inside subgraphs were missing from the result.
                        for sub_id, count in self._inline_calls_in(subgraph).items():
                            subgraph_count[sub_id] += count

        for sub_id, count in subgraph_count.items():
            id_count[sub_id] += count
        return id_count
|
|
@@ -0,0 +1,57 @@
|
|
|
1
|
+
# Copyright (c) ONNX Project Contributors
|
|
2
|
+
# SPDX-License-Identifier: Apache-2.0
|
|
3
|
+
"""Passes for debugging purposes."""
|
|
4
|
+
|
|
5
|
+
from __future__ import annotations
|
|
6
|
+
|
|
7
|
+
__all__ = [
|
|
8
|
+
"CheckerPass",
|
|
9
|
+
]
|
|
10
|
+
|
|
11
|
+
from typing import Literal
|
|
12
|
+
|
|
13
|
+
import onnx
|
|
14
|
+
|
|
15
|
+
import onnx_ir as ir
|
|
16
|
+
from onnx_ir.passes.common import _c_api_utils
|
|
17
|
+
|
|
18
|
+
|
|
19
|
+
class CheckerPass(ir.passes.PassBase):
    """Validate a model with ``onnx.checker.check_model``.

    The check runs through ``_c_api_utils.call_onnx_api``, whose callback receives
    an ``onnx.ModelProto``. The IR model itself is never modified, and exceptions
    raised by the checker are not caught here.
    """

    def __init__(
        self,
        full_check: bool = False,
        skip_opset_compatibility_check: bool = False,
        check_custom_domain: bool = False,
    ):
        """Store the checker options.

        Args:
            full_check: Forwarded to ``onnx.checker.check_model``.
            skip_opset_compatibility_check: Forwarded to ``onnx.checker.check_model``.
            check_custom_domain: Forwarded to ``onnx.checker.check_model``.
        """
        super().__init__()
        self.full_check = full_check
        self.skip_opset_compatibility_check = skip_opset_compatibility_check
        self.check_custom_domain = check_custom_domain

    @property
    def in_place(self) -> Literal[True]:
        """The pass works on the given model instead of producing a new one."""
        return True

    @property
    def changes_input(self) -> Literal[False]:
        """The pass never alters the model it receives."""
        return False

    def call(self, model: ir.Model) -> ir.passes.PassResult:
        """Run the onnx checker on the model."""
        checker_options = {
            "full_check": self.full_check,
            "skip_opset_compatibility_check": self.skip_opset_compatibility_check,
            "check_custom_domain": self.check_custom_domain,
        }

        def _check(proto: onnx.ModelProto) -> None:
            """Invoke the ONNX checker on the serialized model."""
            onnx.checker.check_model(proto, **checker_options)

        _c_api_utils.call_onnx_api(_check, model)
        # Checking leaves the model untouched.
        return ir.passes.PassResult(model=model, modified=False)
|
|
@@ -0,0 +1,112 @@
|
|
|
1
|
+
# Copyright (c) ONNX Project Contributors
|
|
2
|
+
# SPDX-License-Identifier: Apache-2.0
|
|
3
|
+
"""Shape inference pass using onnx.shape_inference."""
|
|
4
|
+
|
|
5
|
+
from __future__ import annotations
|
|
6
|
+
|
|
7
|
+
__all__ = [
|
|
8
|
+
"ShapeInferencePass",
|
|
9
|
+
"infer_shapes",
|
|
10
|
+
]
|
|
11
|
+
|
|
12
|
+
import logging
|
|
13
|
+
|
|
14
|
+
import onnx
|
|
15
|
+
|
|
16
|
+
import onnx_ir as ir
|
|
17
|
+
from onnx_ir.passes.common import _c_api_utils
|
|
18
|
+
|
|
19
|
+
logger = logging.getLogger(__name__)
|
|
20
|
+
|
|
21
|
+
|
|
22
|
+
def _merge_func(model: ir.Model, inferred_proto: onnx.ModelProto) -> bool:
    """Merge the shape inferred model with the original model.

    Graphs of the two models are paired positionally (``zip`` over ``graphs()``),
    and values within a pair are matched by name. For each matched value, the
    inferred shape and dtype replace the original ones when they differ and are
    not ``None``. Values missing from the inferred graph are logged and left as-is.

    Args:
        model: The original IR model. Its values are updated in place.
        inferred_proto: The ONNX model with shapes and types inferred.

    Returns:
        True if any value's shape or dtype was updated, False otherwise.
    """
    inferred_model = ir.serde.deserialize_model(inferred_proto)
    modified = False
    # NOTE(review): zip stops at the shorter sequence; this assumes both models
    # enumerate the same graphs in the same order.
    for original_graph, inferred_graph in zip(model.graphs(), inferred_model.graphs()):
        original_values = ir.convenience.create_value_mapping(original_graph)
        inferred_values = ir.convenience.create_value_mapping(inferred_graph)
        for name, value in original_values.items():
            if name not in inferred_values:
                logger.warning(
                    "Value %s not found in inferred graph %s", name, inferred_graph.name
                )
                continue
            inferred_value = inferred_values[name]
            if value.shape != inferred_value.shape and inferred_value.shape is not None:
                value.shape = inferred_value.shape
                modified = True
            if value.dtype != inferred_value.dtype and inferred_value.dtype is not None:
                value.dtype = inferred_value.dtype
                modified = True
    return modified
|
|
51
|
+
|
|
52
|
+
|
|
53
|
+
class ShapeInferencePass(ir.passes.InPlacePass):
    """This pass performs shape inference on the graph.

    ONNX shape inference runs on a proto form of the model, and the inferred
    shapes and dtypes are merged back into the IR model in place.
    """

    def __init__(
        self, check_type: bool = True, strict_mode: bool = True, data_prop: bool = True
    ) -> None:
        """Initialize the shape inference pass.

        If inference fails, the model is left unchanged.

        Args:
            check_type: If True, check the types of the inputs and outputs.
            strict_mode: If True, use strict mode for shape inference.
            data_prop: If True, use data propagation for shape inference.
        """
        super().__init__()
        self.check_type = check_type
        self.strict_mode = strict_mode
        self.data_prop = data_prop

    def call(self, model: ir.Model) -> ir.passes.PassResult:
        """Infer shapes/dtypes and merge them into ``model``.

        Returns:
            A result whose ``modified`` flag says whether any shape or dtype
            changed. On inference failure a warning is logged and the result
            reports no modification.
        """

        def partial_infer_shapes(proto: onnx.ModelProto) -> onnx.ModelProto:
            return onnx.shape_inference.infer_shapes(
                proto,
                check_type=self.check_type,
                strict_mode=self.strict_mode,
                data_prop=self.data_prop,
            )

        try:
            inferred_model_proto = _c_api_utils.call_onnx_api(partial_infer_shapes, model)
        except Exception as e:  # pylint: disable=broad-exception-caught
            # Best effort: a failed inference must not break the pipeline. The
            # exception is passed as the %s argument (previously it was missing,
            # so a literal "%s" was logged) and as exc_info for the traceback.
            logger.warning(
                "Shape inference failed: %s. Model is left unchanged", e, exc_info=e
            )
            return ir.passes.PassResult(model, False)

        modified = _merge_func(model, inferred_model_proto)
        return ir.passes.PassResult(model, modified=modified)
|
|
90
|
+
|
|
91
|
+
|
|
92
|
+
def infer_shapes(
    model: ir.Model,
    *,
    check_type: bool = True,
    strict_mode: bool = True,
    data_prop: bool = True,
) -> ir.Model:
    """Run ONNX shape inference on ``model``.

    Convenience wrapper that builds a :class:`ShapeInferencePass` with the given
    options, applies it, and returns the model carried by the pass result.

    Args:
        model: The model to perform shape inference on.
        check_type: If True, check the types of the inputs and outputs.
        strict_mode: If True, use strict mode for shape inference.
        data_prop: If True, use data propagation for shape inference.

    Returns:
        The model with shape inference applied.
    """
    shape_pass = ShapeInferencePass(
        check_type=check_type,
        strict_mode=strict_mode,
        data_prop=data_prop,
    )
    result = shape_pass(model)
    return result.model
|
|
@@ -0,0 +1,33 @@
|
|
|
1
|
+
# Copyright (c) ONNX Project Contributors
|
|
2
|
+
# SPDX-License-Identifier: Apache-2.0
|
|
3
|
+
"""Pass for topologically sorting the graphs."""
|
|
4
|
+
|
|
5
|
+
from __future__ import annotations
|
|
6
|
+
|
|
7
|
+
__all__ = [
|
|
8
|
+
"TopologicalSortPass",
|
|
9
|
+
]
|
|
10
|
+
|
|
11
|
+
|
|
12
|
+
import onnx_ir as ir
|
|
13
|
+
|
|
14
|
+
|
|
15
|
+
class TopologicalSortPass(ir.passes.InPlacePass):
    """Topologically sort the main graph and every function of a model.

    The pass reports a modification only if sorting changed the order of at least
    one node.
    """

    def call(self, model: ir.Model) -> ir.passes.PassResult:
        containers = [model.graph, *model.functions.values()]
        before: list = []
        after: list = []
        for container in containers:
            # Snapshot the order, sort in place, then snapshot the new order.
            before.extend(container)
            container.sort()
            after.extend(container)

        # Compare node identities position by position.
        modified = any(old is not new for old, new in zip(before, after))
        return ir.passes.PassResult(model=model, modified=modified)
|