onnx-ir 0.0.1__py3-none-any.whl → 0.1.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of onnx-ir might be problematic. Click here for more details.
- onnx_ir/__init__.py +23 -10
- onnx_ir/{_convenience.py → _convenience/__init__.py} +40 -102
- onnx_ir/_convenience/_constructors.py +213 -0
- onnx_ir/_core.py +874 -257
- onnx_ir/_display.py +2 -2
- onnx_ir/_enums.py +107 -5
- onnx_ir/_graph_comparison.py +2 -2
- onnx_ir/_graph_containers.py +373 -0
- onnx_ir/_io.py +57 -10
- onnx_ir/_linked_list.py +15 -7
- onnx_ir/_metadata.py +4 -3
- onnx_ir/_name_authority.py +2 -2
- onnx_ir/_polyfill.py +26 -0
- onnx_ir/_protocols.py +31 -13
- onnx_ir/_tape.py +139 -32
- onnx_ir/_thirdparty/asciichartpy.py +1 -4
- onnx_ir/_type_casting.py +18 -3
- onnx_ir/{_internal/version_utils.py → _version_utils.py} +2 -29
- onnx_ir/convenience.py +4 -2
- onnx_ir/external_data.py +401 -0
- onnx_ir/passes/__init__.py +8 -2
- onnx_ir/passes/_pass_infra.py +173 -56
- onnx_ir/passes/common/__init__.py +40 -0
- onnx_ir/passes/common/_c_api_utils.py +76 -0
- onnx_ir/passes/common/clear_metadata_and_docstring.py +60 -0
- onnx_ir/passes/common/common_subexpression_elimination.py +177 -0
- onnx_ir/passes/common/constant_manipulation.py +217 -0
- onnx_ir/passes/common/inliner.py +332 -0
- onnx_ir/passes/common/onnx_checker.py +57 -0
- onnx_ir/passes/common/shape_inference.py +112 -0
- onnx_ir/passes/common/topological_sort.py +33 -0
- onnx_ir/passes/common/unused_removal.py +196 -0
- onnx_ir/serde.py +288 -124
- onnx_ir/tape.py +15 -0
- onnx_ir/tensor_adapters.py +122 -0
- onnx_ir/testing.py +197 -0
- onnx_ir/traversal.py +4 -3
- onnx_ir-0.1.1.dist-info/METADATA +53 -0
- onnx_ir-0.1.1.dist-info/RECORD +42 -0
- {onnx_ir-0.0.1.dist-info → onnx_ir-0.1.1.dist-info}/WHEEL +1 -1
- onnx_ir-0.1.1.dist-info/licenses/LICENSE +202 -0
- onnx_ir/_external_data.py +0 -323
- onnx_ir-0.0.1.dist-info/LICENSE +0 -22
- onnx_ir-0.0.1.dist-info/METADATA +0 -73
- onnx_ir-0.0.1.dist-info/RECORD +0 -26
- {onnx_ir-0.0.1.dist-info → onnx_ir-0.1.1.dist-info}/top_level.txt +0 -0
|
@@ -0,0 +1,112 @@
|
|
|
1
|
+
# Copyright (c) ONNX Project Contributors
|
|
2
|
+
# SPDX-License-Identifier: Apache-2.0
|
|
3
|
+
"""Shape inference pass using onnx.shape_inference."""
|
|
4
|
+
|
|
5
|
+
from __future__ import annotations
|
|
6
|
+
|
|
7
|
+
__all__ = [
|
|
8
|
+
"ShapeInferencePass",
|
|
9
|
+
"infer_shapes",
|
|
10
|
+
]
|
|
11
|
+
|
|
12
|
+
import logging
|
|
13
|
+
|
|
14
|
+
import onnx
|
|
15
|
+
|
|
16
|
+
import onnx_ir as ir
|
|
17
|
+
from onnx_ir.passes.common import _c_api_utils
|
|
18
|
+
|
|
19
|
+
logger = logging.getLogger(__name__)
|
|
20
|
+
|
|
21
|
+
|
|
22
|
+
def _merge_func(model: ir.Model, inferred_proto: onnx.ModelProto) -> bool:
    """Copy inferred shapes and dtypes from ``inferred_proto`` onto ``model``.

    The inferred proto is deserialized, its graphs are paired positionally with
    the graphs of ``model``, and values are matched by name. A shape or dtype is
    copied only when the inferred one is not None and differs from the current
    one. Values missing from the inferred graph are reported with a warning.

    Args:
        model: The original IR model. It is updated in place.
        inferred_proto: The ONNX model with shapes and types inferred.

    Returns:
        True if at least one shape or dtype in ``model`` was updated,
        False otherwise.
    """
    inferred_model = ir.serde.deserialize_model(inferred_proto)
    changed = False
    graph_pairs = zip(model.graphs(), inferred_model.graphs())
    for graph, inferred_graph in graph_pairs:
        current_by_name = ir.convenience.create_value_mapping(graph)
        inferred_by_name = ir.convenience.create_value_mapping(inferred_graph)
        for name, value in current_by_name.items():
            if name not in inferred_by_name:
                logger.warning(
                    "Value %s not found in inferred graph %s", name, inferred_graph.name
                )
                continue

            candidate = inferred_by_name[name]
            # Never overwrite existing information with "unknown" (None).
            if value.shape != candidate.shape and candidate.shape is not None:
                value.shape = candidate.shape
                changed = True
            if value.dtype != candidate.dtype and candidate.dtype is not None:
                value.dtype = candidate.dtype
                changed = True
    return changed
|
|
51
|
+
|
|
52
|
+
|
|
53
|
+
class ShapeInferencePass(ir.passes.InPlacePass):
    """This pass performs shape inference on the graph.

    Inference is delegated to ``onnx.shape_inference.infer_shapes``; the
    inferred shapes and dtypes are then merged back into the IR model.
    """

    def __init__(
        self, check_type: bool = True, strict_mode: bool = True, data_prop: bool = True
    ) -> None:
        """Initialize the shape inference pass.

        If inference fails, the model is left unchanged.

        Args:
            check_type: If True, check the types of the inputs and outputs.
            strict_mode: If True, use strict mode for shape inference.
            data_prop: If True, use data propagation for shape inference.
        """
        super().__init__()
        self.check_type = check_type
        self.strict_mode = strict_mode
        self.data_prop = data_prop

    def call(self, model: ir.Model) -> ir.passes.PassResult:
        """Run ONNX shape inference and merge the results into ``model``.

        Args:
            model: The model to annotate. It is updated in place.

        Returns:
            A PassResult holding ``model``. ``modified`` is True only when a
            shape or dtype was actually updated; it is False when inference
            raised and the model was left untouched.
        """

        def partial_infer_shapes(proto: onnx.ModelProto) -> onnx.ModelProto:
            # Bind the configured options so the helper only needs the proto.
            return onnx.shape_inference.infer_shapes(
                proto,
                check_type=self.check_type,
                strict_mode=self.strict_mode,
                data_prop=self.data_prop,
            )

        try:
            inferred_model_proto = _c_api_utils.call_onnx_api(partial_infer_shapes, model)
        except Exception as e:  # pylint: disable=broad-exception-caught
            # Best effort: report the failure and hand back the untouched model.
            # The exception must be passed as a positional argument, otherwise
            # logging skips %-formatting and emits a literal "%s".
            logger.warning(
                "Shape inference failed: %s. Model is left unchanged", e, exc_info=e
            )
            return ir.passes.PassResult(model, False)

        modified = _merge_func(model, inferred_model_proto)
        return ir.passes.PassResult(model, modified=modified)
|
|
90
|
+
|
|
91
|
+
|
|
92
|
+
def infer_shapes(
    model: ir.Model,
    *,
    check_type: bool = True,
    strict_mode: bool = True,
    data_prop: bool = True,
) -> ir.Model:
    """Run shape inference on ``model`` and return the resulting model.

    Convenience wrapper that builds a ``ShapeInferencePass`` with the given
    options, applies it, and unwraps the model from the pass result.

    Args:
        model: The model to perform shape inference on.
        check_type: If True, check the types of the inputs and outputs.
        strict_mode: If True, use strict mode for shape inference.
        data_prop: If True, use data propagation for shape inference.

    Returns:
        The model with shape inference applied.
    """
    shape_pass = ShapeInferencePass(
        check_type=check_type,
        strict_mode=strict_mode,
        data_prop=data_prop,
    )
    result = shape_pass(model)
    return result.model
|
|
@@ -0,0 +1,33 @@
|
|
|
1
|
+
# Copyright (c) ONNX Project Contributors
|
|
2
|
+
# SPDX-License-Identifier: Apache-2.0
|
|
3
|
+
"""Pass for topologically sorting the graphs."""
|
|
4
|
+
|
|
5
|
+
from __future__ import annotations
|
|
6
|
+
|
|
7
|
+
__all__ = [
|
|
8
|
+
"TopologicalSortPass",
|
|
9
|
+
]
|
|
10
|
+
|
|
11
|
+
|
|
12
|
+
import onnx_ir as ir
|
|
13
|
+
|
|
14
|
+
|
|
15
|
+
class TopologicalSortPass(ir.passes.InPlacePass):
    """Topologically sort graphs and functions in a model."""

    def call(self, model: ir.Model) -> ir.passes.PassResult:
        """Sort the main graph and every function in place.

        Args:
            model: The model whose graph and functions are sorted.

        Returns:
            A PassResult holding ``model``. ``modified`` is True when the node
            order seen after sorting differs from the order seen before.
        """
        # Snapshot the node order before and after each sort so the two
        # sequences can be compared position by position afterwards.
        order_before = list(model.graph)
        model.graph.sort()
        order_after = list(model.graph)

        for func in model.functions.values():
            order_before.extend(func)
            func.sort()
            order_after.extend(func)

        # Any position holding a different node object means the order changed.
        modified = any(
            old is not new for old, new in zip(order_before, order_after)
        )
        return ir.passes.PassResult(model=model, modified=modified)
|
|
@@ -0,0 +1,196 @@
|
|
|
1
|
+
# Copyright (c) ONNX Project Contributors
|
|
2
|
+
# SPDX-License-Identifier: Apache-2.0
|
|
3
|
+
from __future__ import annotations
|
|
4
|
+
|
|
5
|
+
__all__ = [
|
|
6
|
+
"RemoveUnusedNodesPass",
|
|
7
|
+
"RemoveUnusedFunctionsPass",
|
|
8
|
+
"RemoveUnusedOpsetsPass",
|
|
9
|
+
]
|
|
10
|
+
|
|
11
|
+
import logging
|
|
12
|
+
|
|
13
|
+
import onnx
|
|
14
|
+
|
|
15
|
+
import onnx_ir as ir
|
|
16
|
+
|
|
17
|
+
logger = logging.getLogger(__name__)
|
|
18
|
+
|
|
19
|
+
|
|
20
|
+
def _remove_unused_optional_outputs(
    node: ir.Node, graph_outputs: frozenset[ir.Value], onnx_opset_version: int
) -> None:
    """Blank the names of optional outputs of ``node`` that nothing consumes.

    Only nodes in the default ONNX domain are considered. An output counts as
    unused when it is not a graph output and has no uses. Unused optional
    outputs get their name set to ``""``; the node itself is kept.

    Args:
        node: The node whose outputs are inspected and possibly renamed.
        graph_outputs: Outputs of the enclosing graph; these are never cleared.
        onnx_opset_version: Opset version used to look up the operator schema.
    """
    try:
        # NOTE(review): the ONNX spec spells the default domain alias
        # "ai.onnx"; "onnx.ai" looks like a typo. Confirm before changing.
        if node.domain not in {"", "onnx.ai"}:
            return
        op_schema = onnx.defs.get_schema(node.op_type, onnx_opset_version, domain=node.domain)
    except Exception:  # pylint: disable=broad-exception-caught
        # Best effort: without a schema there is no way to tell which outputs
        # are optional, so leave the node untouched.
        logger.info(
            "Failed to get schema for %s, skipping optional output removal",
            node,
            stack_info=True,
        )
        return

    if node.op_type == "BatchNormalization":
        # BatchNormalization op has 3 outputs: Y, running_mean, running_var
        # If running_mean and running_var are not used, remove them, and the training_mode attribute
        def is_used_output(i: int) -> bool:
            if i < len(node.outputs):
                val = node.outputs[i]
                return val in graph_outputs or bool(val.uses())
            return False

        if is_used_output(1) or is_used_output(2):
            return
        if len(node.outputs) > 1:
            node.outputs[1].name = ""
        if len(node.outputs) > 2:
            node.outputs[2].name = ""
        node.attributes.pop("training_mode", None)
        return

    optional_info = []
    for o in op_schema.outputs:
        # Current ops do not have optional outputs if they have variable number of outputs
        if o.option == onnx.defs.OpSchema.FormalParameterOption.Variadic:
            return
        optional_info.append(o.option == onnx.defs.OpSchema.FormalParameterOption.Optional)
    # If no optional outputs in spec, skip delete operations.
    # (The previous check measured the length of the list, i.e. the number of
    # schema outputs, so it never detected "no optional outputs".)
    if not any(optional_info):
        return

    # zip() stops at the shorter sequence, so a node carrying more outputs than
    # the schema declares no longer raises IndexError.
    for out, is_optional in zip(node.outputs, optional_info):
        if is_optional and out not in graph_outputs and not out.uses():
            out.name = ""
|
|
66
|
+
|
|
67
|
+
|
|
68
|
+
def _remove_unused_nodes_in_graph_like(function_or_graph: ir.Function | ir.Graph) -> int:
    """Remove nodes whose outputs are all unused, recursing into subgraphs.

    Nodes are visited in reverse order. A node is removed when none of its
    outputs is a graph output or has any use. For nodes that are kept, unused
    optional outputs are cleared (when the default-domain opset version is
    known) and graph-valued attributes are processed recursively.

    Args:
        function_or_graph: The graph or function to clean up in place.

    Returns:
        The number of nodes removed, including those removed from subgraphs.
    """
    outputs = frozenset(function_or_graph.outputs)
    default_opset = function_or_graph.opset_imports.get("", None)
    removed = 0

    for node in reversed(function_or_graph):
        in_use = any(out in outputs or out.uses() for out in node.outputs)
        if not in_use:
            function_or_graph.remove(node, safe=True)
            removed += 1
            continue

        if default_opset is not None:
            _remove_unused_optional_outputs(node, outputs, default_opset)

        for attr in node.attributes.values():
            if not isinstance(attr, ir.Attr):
                continue
            if attr.type == ir.AttributeType.GRAPH:
                subgraphs = [attr.as_graph()]
            elif attr.type == ir.AttributeType.GRAPHS:
                subgraphs = attr.as_graphs()
            else:
                continue
            for subgraph in subgraphs:
                removed += _remove_unused_nodes_in_graph_like(subgraph)

    return removed
|
|
93
|
+
|
|
94
|
+
|
|
95
|
+
class RemoveUnusedNodesPass(ir.passes.InPlacePass):
    """Pass for removing unused nodes and initializers (dead code elimination).

    The model signature (inputs and outputs) is never modified: only nodes and
    initializers that nothing refers to are dropped, so the original contract
    of the model is preserved.
    """

    def call(self, model: ir.Model) -> ir.passes.PassResult:
        """Remove unused nodes and initializers from the graph and functions.

        Args:
            model: The model to clean up in place.

        Returns:
            A PassResult holding ``model``. ``modified`` is True when at least
            one node or initializer was removed.
        """
        removed = _remove_unused_nodes_in_graph_like(model.graph)

        graph_outputs = frozenset(model.graph.outputs)
        graph_inputs = frozenset(model.graph.inputs)
        initializers = model.graph.initializers
        # Iterate over a copy because entries are deleted during the loop.
        for init in list(initializers.values()):
            if init.uses() or init in graph_outputs or init in graph_inputs:
                continue
            assert init.name is not None
            del initializers[init.name]
            removed += 1

        for func in model.functions.values():
            removed += _remove_unused_nodes_in_graph_like(func)

        if removed:
            logger.info("Removed %s unused nodes", removed)
        return ir.passes.PassResult(model, modified=bool(removed))
|
|
118
|
+
|
|
119
|
+
|
|
120
|
+
class RemoveUnusedFunctionsPass(ir.passes.InPlacePass):
    """Remove functions that are not reachable from the main graph.

    A function is considered used when a node in the main graph (including
    nested subgraphs), or in another used function, has a matching operator
    identifier.
    """

    def __init__(self):
        super().__init__()
        # Scratch set of reachable function identifiers; only populated while
        # ``call`` is running.
        self._used: set[ir.OperatorIdentifier] | None = None

    def call(self, model: ir.Model) -> ir.passes.PassResult:
        """Delete every function of ``model`` that is never called.

        Args:
            model: The model to clean up in place.

        Returns:
            A PassResult holding ``model``. ``modified`` is True when at least
            one function was removed.
        """
        self._used = set()
        try:
            for node in ir.traversal.RecursiveGraphIterator(model.graph):
                self._call_node(model, node)
            unused = set(model.functions) - self._used
        finally:
            # Always drop the scratch state, including on the early-return and
            # error paths, so nothing stale survives between runs.
            self._used = None

        # Update the model to remove unused functions
        if not unused:
            logger.info("No unused functions to remove")
            return ir.passes.PassResult(model, modified=False)

        for op_identifier in unused:
            del model.functions[op_identifier]

        logger.info("Removed %s unused functions", len(unused))
        logger.debug("Functions left: %s", list(model.functions))
        logger.debug("Functions removed: %s", unused)

        return ir.passes.PassResult(model, modified=bool(unused))

    def _call_function(self, model: ir.Model, function: ir.Function) -> None:
        """Mark ``function`` as used and visit the nodes it contains."""
        assert self._used is not None
        if function.identifier() in self._used:
            # The function and its nodes are already recorded as used
            return
        # Record before descending so a revisit of the same function
        # returns immediately.
        self._used.add(function.identifier())
        for node in ir.traversal.RecursiveGraphIterator(function):
            self._call_node(model, node)

    def _call_node(self, model: ir.Model, node: ir.Node) -> None:
        """Follow ``node`` into the model function it calls, if any."""
        op_identifier = node.op_identifier()
        if op_identifier not in model.functions:
            return
        self._call_function(model, model.functions[op_identifier])
|
|
160
|
+
|
|
161
|
+
|
|
162
|
+
class RemoveUnusedOpsetsPass(ir.passes.InPlacePass):
    """Remove unused opset imports from the model and functions.

    Attributes:
        process_functions: Whether to process functions in the model. If True, the pass will
            remove unused opset imports from functions as well. If False, only the main graph
            will be processed.
    """

    def __init__(self, process_functions: bool = True):
        super().__init__()
        self.process_functions = process_functions

    def _process_graph_like(
        self, graph_like: ir.Graph | ir.Function, used_domains: set[str]
    ) -> bool:
        """Drop opset imports of ``graph_like`` whose domain no node uses.

        Args:
            graph_like: The graph or function whose opset imports are pruned.
            used_domains: Domains to keep regardless; the domain of every node
                found in ``graph_like`` is added to this set in place.

        Returns:
            True if at least one opset import was removed.
        """
        used_domains.update(
            node.domain for node in ir.traversal.RecursiveGraphIterator(graph_like)
        )
        stale = set(graph_like.opset_imports) - used_domains
        for domain in stale:
            del graph_like.opset_imports[domain]
        return bool(stale)

    def call(self, model: ir.Model) -> ir.passes.PassResult:
        """Prune unused opset imports from the main graph and, optionally, functions.

        Args:
            model: The model to clean up in place.

        Returns:
            A PassResult holding ``model``. ``modified`` is True when any opset
            import was removed.
        """
        # By default always retain the onnx (default) domain, plus the domain
        # of every function defined in the model.
        used_domains = {""}
        used_domains.update(func.domain for func in model.functions.values())
        modified = self._process_graph_like(model.graph, used_domains=used_domains)

        if self.process_functions:
            for func in model.functions.values():
                # Each function is judged on its own nodes only.
                if self._process_graph_like(func, used_domains={""}):
                    modified = True

        return ir.passes.PassResult(model, modified=modified)
|