onnx-ir 0.0.1__py3-none-any.whl → 0.1.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of onnx-ir might be problematic. Click here for more details.
- onnx_ir/__init__.py +23 -10
- onnx_ir/{_convenience.py → _convenience/__init__.py} +40 -102
- onnx_ir/_convenience/_constructors.py +213 -0
- onnx_ir/_core.py +857 -233
- onnx_ir/_display.py +2 -2
- onnx_ir/_enums.py +107 -5
- onnx_ir/_graph_comparison.py +2 -2
- onnx_ir/_graph_containers.py +268 -0
- onnx_ir/_io.py +57 -10
- onnx_ir/_linked_list.py +15 -7
- onnx_ir/_metadata.py +4 -3
- onnx_ir/_name_authority.py +2 -2
- onnx_ir/_polyfill.py +26 -0
- onnx_ir/_protocols.py +31 -13
- onnx_ir/_tape.py +139 -32
- onnx_ir/_thirdparty/asciichartpy.py +1 -4
- onnx_ir/_type_casting.py +18 -3
- onnx_ir/{_internal/version_utils.py → _version_utils.py} +2 -29
- onnx_ir/convenience.py +4 -2
- onnx_ir/external_data.py +401 -0
- onnx_ir/passes/__init__.py +8 -2
- onnx_ir/passes/_pass_infra.py +173 -56
- onnx_ir/passes/common/__init__.py +36 -0
- onnx_ir/passes/common/_c_api_utils.py +76 -0
- onnx_ir/passes/common/clear_metadata_and_docstring.py +60 -0
- onnx_ir/passes/common/constant_manipulation.py +232 -0
- onnx_ir/passes/common/inliner.py +331 -0
- onnx_ir/passes/common/onnx_checker.py +57 -0
- onnx_ir/passes/common/shape_inference.py +112 -0
- onnx_ir/passes/common/topological_sort.py +33 -0
- onnx_ir/passes/common/unused_removal.py +196 -0
- onnx_ir/serde.py +288 -124
- onnx_ir/tape.py +15 -0
- onnx_ir/tensor_adapters.py +122 -0
- onnx_ir/testing.py +197 -0
- onnx_ir/traversal.py +4 -3
- onnx_ir-0.1.0.dist-info/METADATA +53 -0
- onnx_ir-0.1.0.dist-info/RECORD +41 -0
- {onnx_ir-0.0.1.dist-info → onnx_ir-0.1.0.dist-info}/WHEEL +1 -1
- onnx_ir-0.1.0.dist-info/licenses/LICENSE +202 -0
- onnx_ir/_external_data.py +0 -323
- onnx_ir-0.0.1.dist-info/LICENSE +0 -22
- onnx_ir-0.0.1.dist-info/METADATA +0 -73
- onnx_ir-0.0.1.dist-info/RECORD +0 -26
- {onnx_ir-0.0.1.dist-info → onnx_ir-0.1.0.dist-info}/top_level.txt +0 -0
|
@@ -0,0 +1,196 @@
|
|
|
1
|
+
# Copyright (c) ONNX Project Contributors
|
|
2
|
+
# SPDX-License-Identifier: Apache-2.0
|
|
3
|
+
from __future__ import annotations
|
|
4
|
+
|
|
5
|
+
__all__ = [
|
|
6
|
+
"RemoveUnusedNodesPass",
|
|
7
|
+
"RemoveUnusedFunctionsPass",
|
|
8
|
+
"RemoveUnusedOpsetsPass",
|
|
9
|
+
]
|
|
10
|
+
|
|
11
|
+
import logging
|
|
12
|
+
|
|
13
|
+
import onnx
|
|
14
|
+
|
|
15
|
+
import onnx_ir as ir
|
|
16
|
+
|
|
17
|
+
# Module-level logger used by the removal passes in this file for info/debug reporting.
logger = logging.getLogger(__name__)
|
|
18
|
+
|
|
19
|
+
|
|
20
|
+
def _remove_unused_optional_outputs(
    node: ir.Node, graph_outputs: frozenset[ir.Value], onnx_opset_version: int
) -> None:
    """Blank out the names of optional outputs of ``node`` that nothing consumes.

    An output counts as unused when it is not in ``graph_outputs`` and its
    ``uses()`` is empty. Unused optional outputs get their name set to the empty
    string (the ONNX convention for an omitted optional output); the output
    objects themselves are left in place on the node.

    Only nodes in the default ONNX domain are considered. If the operator schema
    cannot be looked up, the node is left untouched and the failure is logged.

    Args:
        node: The node whose outputs are inspected and possibly renamed.
        graph_outputs: Outputs of the enclosing graph or function; these are
            never cleared.
        onnx_opset_version: Opset version used to look up the operator schema.
    """
    # NOTE(review): the conventional alias for the default domain is "ai.onnx";
    # confirm whether "onnx.ai" is intentional here.
    if node.domain not in {"", "onnx.ai"}:
        return

    try:
        op_schema = onnx.defs.get_schema(node.op_type, onnx_opset_version, domain=node.domain)
    except Exception:  # pylint: disable=broad-exception-caught
        logger.info(
            "Failed to get schema for %s, skipping optional output removal",
            node,
            stack_info=True,
        )
        return

    def is_unused(value: ir.Value) -> bool:
        return value not in graph_outputs and not value.uses()

    if node.op_type == "BatchNormalization":
        # BatchNormalization op has 3 outputs: Y, running_mean, running_var.
        # If running_mean and running_var are both unused, omit them and drop
        # the training_mode attribute as well.
        extra_outputs = node.outputs[1:3]
        if all(is_unused(value) for value in extra_outputs):
            for value in extra_outputs:
                value.name = ""
            node.attributes.pop("training_mode", None)
        return

    optional_flags = []
    for formal_output in op_schema.outputs:
        # Current ops do not have optional outputs if they have a variable
        # number of outputs, so there is nothing to do for variadic schemas.
        if formal_output.option == onnx.defs.OpSchema.FormalParameterOption.Variadic:
            return
        optional_flags.append(
            formal_output.option == onnx.defs.OpSchema.FormalParameterOption.Optional
        )

    # If the schema declares no optional outputs, skip the clearing loop entirely.
    if not any(optional_flags):
        return

    # zip() stops at the shorter sequence, so a node carrying more outputs than
    # the schema declares is handled without an IndexError.
    for out, is_optional in zip(node.outputs, optional_flags):
        if is_optional and is_unused(out):
            out.name = ""
|
|
66
|
+
|
|
67
|
+
|
|
68
|
+
def _remove_unused_nodes_in_graph_like(function_or_graph: ir.Function | ir.Graph) -> int:
    """Remove dead nodes from a graph or function body and return how many were removed.

    Nodes are visited in reverse order. A node is removed when none of its
    outputs is an output of ``function_or_graph`` and none has any uses. Nodes
    that are kept get their unused optional outputs cleared (when the container
    imports the default ONNX opset), and any graph-valued attributes they carry
    are processed recursively.

    Args:
        function_or_graph: The graph or function whose nodes are pruned in place.

    Returns:
        The number of nodes removed, including those removed from subgraphs.
    """
    protected = frozenset(function_or_graph.outputs)
    default_opset = function_or_graph.opset_imports.get("", None)
    removed = 0

    for node in reversed(function_or_graph):
        is_live = any(value in protected or value.uses() for value in node.outputs)
        if not is_live:
            function_or_graph.remove(node, safe=True)
            removed += 1
            continue

        if default_opset is not None:
            _remove_unused_optional_outputs(node, protected, default_opset)

        # Descend into subgraphs held by GRAPH / GRAPHS attributes of kept nodes.
        for attr in node.attributes.values():
            if not isinstance(attr, ir.Attr):
                continue
            if attr.type == ir.AttributeType.GRAPH:
                removed += _remove_unused_nodes_in_graph_like(attr.as_graph())
            elif attr.type == ir.AttributeType.GRAPHS:
                removed += sum(
                    _remove_unused_nodes_in_graph_like(subgraph)
                    for subgraph in attr.as_graphs()
                )

    return removed
|
|
93
|
+
|
|
94
|
+
|
|
95
|
+
class RemoveUnusedNodesPass(ir.passes.InPlacePass):
    """Pass for removing unused nodes and initializers (dead code elimination).

    The model signature (inputs and outputs) is left untouched: only nodes and
    initializers that nothing depends on are dropped, so the original contract
    of the model is preserved.
    """

    def call(self, model: ir.Model) -> ir.passes.PassResult:
        """Prune the main graph, its initializers, and every function in ``model``.

        Returns:
            A ``PassResult`` wrapping the same model; ``modified`` is True when
            at least one node or initializer was removed.
        """
        main_graph = model.graph
        removed = _remove_unused_nodes_in_graph_like(main_graph)

        # Initializers that double as graph inputs or outputs must be kept.
        protected = frozenset(main_graph.outputs) | frozenset(main_graph.inputs)
        initializers = main_graph.initializers
        # Snapshot the values because entries are deleted while looping.
        for init in tuple(initializers.values()):
            if init.uses() or init in protected:
                continue
            assert init.name is not None
            del initializers[init.name]
            removed += 1

        removed += sum(
            _remove_unused_nodes_in_graph_like(function)
            for function in model.functions.values()
        )

        if removed:
            logger.info("Removed %s unused nodes", removed)
        return ir.passes.PassResult(model, modified=bool(removed))
|
|
118
|
+
|
|
119
|
+
|
|
120
|
+
class RemoveUnusedFunctionsPass(ir.passes.InPlacePass):
    """Remove functions that are not reachable from the main graph.

    A function is considered used when a node in the main graph (including
    nested subgraphs) refers to it, or when a node inside an already-used
    function refers to it. Every other entry of ``model.functions`` is deleted.
    """

    def __init__(self):
        super().__init__()
        # Identifiers of functions found reachable; only populated while
        # ``call`` is running and None otherwise.
        self._used: set[ir.OperatorIdentifier] | None = None

    def call(self, model: ir.Model) -> ir.passes.PassResult:
        """Delete unreachable functions from ``model`` in place.

        Returns:
            A ``PassResult`` wrapping the same model; ``modified`` is True when
            at least one function was removed.
        """
        self._used = set()
        try:
            for node in ir.traversal.RecursiveGraphIterator(model.graph):
                self._call_node(model, node)
            unused = set(model.functions) - self._used
        finally:
            # Always drop the scratch state, including on the early-return and
            # error paths, so the pass never retains identifiers between runs.
            self._used = None

        if not unused:
            logger.info("No unused functions to remove")
            return ir.passes.PassResult(model, modified=False)

        # Update the model to remove unused functions
        for op_identifier in unused:
            del model.functions[op_identifier]

        logger.info("Removed %s unused functions", len(unused))
        logger.debug("Functions left: %s", list(model.functions))
        logger.debug("Functions removed: %s", unused)

        return ir.passes.PassResult(model, modified=True)

    def _call_function(self, model: ir.Model, function: ir.Function) -> None:
        """Record ``function`` as used and visit the nodes in its body."""
        assert self._used is not None
        identifier = function.identifier()
        if identifier in self._used:
            # The function and its nodes are already recorded as used
            return
        # Mark before descending so mutually-referencing functions terminate.
        self._used.add(identifier)
        for node in ir.traversal.RecursiveGraphIterator(function):
            self._call_node(model, node)

    def _call_node(self, model: ir.Model, node: ir.Node) -> None:
        """Follow ``node`` into the model function it refers to, if any."""
        op_identifier = node.op_identifier()
        if op_identifier not in model.functions:
            return
        self._call_function(model, model.functions[op_identifier])
|
|
160
|
+
|
|
161
|
+
|
|
162
|
+
class RemoveUnusedOpsetsPass(ir.passes.InPlacePass):
    """Remove unused opset imports from the model and functions.

    Attributes:
        process_functions: Whether to process functions in the model. When True,
            unused opset imports are removed from each function as well as from
            the main graph. When False, only the main graph is processed.
    """

    def __init__(self, process_functions: bool = True):
        super().__init__()
        self.process_functions = process_functions

    def _process_graph_like(
        self, graph_like: ir.Graph | ir.Function, used_domains: set[str]
    ) -> bool:
        """Drop opset imports of ``graph_like`` whose domain is not in use.

        The domain of every node reachable from ``graph_like`` is added to
        ``used_domains`` (the set is mutated), then every imported domain that
        is absent from that set is deleted.

        Returns:
            True when at least one opset import was removed.
        """
        used_domains.update(
            node.domain for node in ir.traversal.RecursiveGraphIterator(graph_like)
        )
        stale_domains = set(graph_like.opset_imports) - used_domains
        for domain in stale_domains:
            del graph_like.opset_imports[domain]
        return bool(stale_domains)

    def call(self, model: ir.Model) -> ir.passes.PassResult:
        """Prune opset imports on the main graph and, optionally, on each function.

        Returns:
            A ``PassResult`` wrapping the same model; ``modified`` is True when
            any opset import was removed.
        """
        # Always retain the onnx (default) domain, plus the domain of every
        # function defined in the model.
        used_domains = {""} | {function.domain for function in model.functions.values()}
        modified = self._process_graph_like(model.graph, used_domains=used_domains)

        if self.process_functions:
            for function in model.functions.values():
                # Each function starts from a fresh set holding only the default domain.
                if self._process_graph_like(function, used_domains={""}):
                    modified = True

        return ir.passes.PassResult(model, modified=modified)
|