onnx-ir 0.0.1__py3-none-any.whl → 0.1.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of onnx-ir might be problematic. Click here for more details.

Files changed (45) hide show
  1. onnx_ir/__init__.py +23 -10
  2. onnx_ir/{_convenience.py → _convenience/__init__.py} +40 -102
  3. onnx_ir/_convenience/_constructors.py +213 -0
  4. onnx_ir/_core.py +857 -233
  5. onnx_ir/_display.py +2 -2
  6. onnx_ir/_enums.py +107 -5
  7. onnx_ir/_graph_comparison.py +2 -2
  8. onnx_ir/_graph_containers.py +268 -0
  9. onnx_ir/_io.py +57 -10
  10. onnx_ir/_linked_list.py +15 -7
  11. onnx_ir/_metadata.py +4 -3
  12. onnx_ir/_name_authority.py +2 -2
  13. onnx_ir/_polyfill.py +26 -0
  14. onnx_ir/_protocols.py +31 -13
  15. onnx_ir/_tape.py +139 -32
  16. onnx_ir/_thirdparty/asciichartpy.py +1 -4
  17. onnx_ir/_type_casting.py +18 -3
  18. onnx_ir/{_internal/version_utils.py → _version_utils.py} +2 -29
  19. onnx_ir/convenience.py +4 -2
  20. onnx_ir/external_data.py +401 -0
  21. onnx_ir/passes/__init__.py +8 -2
  22. onnx_ir/passes/_pass_infra.py +173 -56
  23. onnx_ir/passes/common/__init__.py +36 -0
  24. onnx_ir/passes/common/_c_api_utils.py +76 -0
  25. onnx_ir/passes/common/clear_metadata_and_docstring.py +60 -0
  26. onnx_ir/passes/common/constant_manipulation.py +232 -0
  27. onnx_ir/passes/common/inliner.py +331 -0
  28. onnx_ir/passes/common/onnx_checker.py +57 -0
  29. onnx_ir/passes/common/shape_inference.py +112 -0
  30. onnx_ir/passes/common/topological_sort.py +33 -0
  31. onnx_ir/passes/common/unused_removal.py +196 -0
  32. onnx_ir/serde.py +288 -124
  33. onnx_ir/tape.py +15 -0
  34. onnx_ir/tensor_adapters.py +122 -0
  35. onnx_ir/testing.py +197 -0
  36. onnx_ir/traversal.py +4 -3
  37. onnx_ir-0.1.0.dist-info/METADATA +53 -0
  38. onnx_ir-0.1.0.dist-info/RECORD +41 -0
  39. {onnx_ir-0.0.1.dist-info → onnx_ir-0.1.0.dist-info}/WHEEL +1 -1
  40. onnx_ir-0.1.0.dist-info/licenses/LICENSE +202 -0
  41. onnx_ir/_external_data.py +0 -323
  42. onnx_ir-0.0.1.dist-info/LICENSE +0 -22
  43. onnx_ir-0.0.1.dist-info/METADATA +0 -73
  44. onnx_ir-0.0.1.dist-info/RECORD +0 -26
  45. {onnx_ir-0.0.1.dist-info → onnx_ir-0.1.0.dist-info}/top_level.txt +0 -0
@@ -0,0 +1,196 @@
1
+ # Copyright (c) ONNX Project Contributors
2
+ # SPDX-License-Identifier: Apache-2.0
3
+ from __future__ import annotations
4
+
5
+ __all__ = [
6
+ "RemoveUnusedNodesPass",
7
+ "RemoveUnusedFunctionsPass",
8
+ "RemoveUnusedOpsetsPass",
9
+ ]
10
+
11
+ import logging
12
+
13
+ import onnx
14
+
15
+ import onnx_ir as ir
16
+
17
+ logger = logging.getLogger(__name__)
18
+
19
+
20
def _remove_unused_optional_outputs(
    node: ir.Node, graph_outputs: frozenset[ir.Value], onnx_opset_version: int
) -> None:
    """Blank out the names of optional outputs of ``node`` that nothing consumes.

    An output counts as unused when it is not in ``graph_outputs`` and its
    ``uses()`` is empty. Such outputs are not deleted; their ``name`` is set to
    the empty string.

    Args:
        node: The node whose outputs are inspected and possibly renamed.
        graph_outputs: Values that must be preserved because they are outputs
            of the enclosing graph or function.
        onnx_opset_version: Opset version used to look up the operator schema.
    """
    # NOTE(review): the conventional alias of the default ONNX domain is
    # "ai.onnx"; confirm whether "onnx.ai" is intentional before changing it.
    if node.domain not in {"", "onnx.ai"}:
        return

    # Only the schema lookup is guarded: it is the call expected to fail for
    # unknown operators or unsupported opset versions.
    try:
        op_schema = onnx.defs.get_schema(node.op_type, onnx_opset_version, domain=node.domain)
    except Exception:  # pylint: disable=broad-exception-caught
        logger.info(
            "Failed to get schema for %s, skipping optional output removal",
            node,
            stack_info=True,
        )
        return

    if node.op_type == "BatchNormalization":
        # BatchNormalization op has 3 outputs: Y, running_mean, running_var
        # If running_mean and running_var are not used, remove them, and the training_mode attribute
        def is_used_output(i: int) -> bool:
            if i < len(node.outputs):
                val = node.outputs[i]
                return val in graph_outputs or bool(val.uses())
            return False

        if is_used_output(1) or is_used_output(2):
            return
        if len(node.outputs) > 1:
            node.outputs[1].name = ""
        if len(node.outputs) > 2:
            node.outputs[2].name = ""
        node.attributes.pop("training_mode", None)
        return

    optional_flags: list[bool] = []
    for formal_output in op_schema.outputs:
        # Current ops do not have optional outputs if they have variable number of outputs
        if formal_output.option == onnx.defs.OpSchema.FormalParameterOption.Variadic:
            return
        optional_flags.append(
            formal_output.option == onnx.defs.OpSchema.FormalParameterOption.Optional
        )

    # If no optional outputs in spec, skip delete operations
    if not any(optional_flags):
        return

    # zip() stops at the shorter sequence, so a node carrying more outputs than
    # the schema declares leaves the extras untouched instead of raising
    # IndexError.
    for out, is_optional in zip(node.outputs, optional_flags):
        if is_optional and out not in graph_outputs and not out.uses():
            out.name = ""
66
+
67
+
68
def _remove_unused_nodes_in_graph_like(function_or_graph: ir.Function | ir.Graph) -> int:
    """Remove nodes whose outputs are all unused and return how many were removed.

    Nodes are visited in reverse order. A node is removed when none of its
    outputs is an output of ``function_or_graph`` and none has any uses.
    Nodes that are kept get their unused optional outputs cleared (only when
    the default-domain opset import is present), and any graph-valued
    attributes are processed recursively.

    Args:
        function_or_graph: The graph or function to clean up in place.

    Returns:
        The number of nodes removed, including those removed from subgraphs.
    """
    kept_values = frozenset(function_or_graph.outputs)
    default_opset = function_or_graph.opset_imports.get("")
    removed = 0

    # Reverse iteration: presumably so that removing a consumer frees its
    # producers before they are visited -- confirm against Graph.remove.
    for node in reversed(function_or_graph):
        in_use = any(value in kept_values or value.uses() for value in node.outputs)
        if not in_use:
            function_or_graph.remove(node, safe=True)
            removed += 1
            continue

        if default_opset is not None:
            _remove_unused_optional_outputs(node, kept_values, default_opset)

        # Descend into subgraphs held by GRAPH / GRAPHS attributes.
        for attr in node.attributes.values():
            if not isinstance(attr, ir.Attr):
                continue
            if attr.type == ir.AttributeType.GRAPH:
                removed += _remove_unused_nodes_in_graph_like(attr.as_graph())
            elif attr.type == ir.AttributeType.GRAPHS:
                for subgraph in attr.as_graphs():
                    removed += _remove_unused_nodes_in_graph_like(subgraph)

    return removed
93
+
94
+
95
class RemoveUnusedNodesPass(ir.passes.InPlacePass):
    """Dead code elimination: drop unused nodes and unused initializers.

    The model signature (graph inputs and outputs) is left untouched, so the
    original contract of the model is preserved while unused nodes and
    initializers are removed.
    """

    def call(self, model: ir.Model) -> ir.passes.PassResult:
        """Clean the main graph, its initializers, and every function in place.

        Returns a ``PassResult`` whose ``modified`` flag is true when at least
        one node or initializer was removed.
        """
        graph = model.graph
        removed = _remove_unused_nodes_in_graph_like(graph)

        # Initializers that double as graph inputs or outputs belong to the
        # model signature and must survive even when nothing uses them.
        protected = frozenset(graph.outputs) | frozenset(graph.inputs)
        initializers = graph.initializers
        # Snapshot the values because the mapping is mutated inside the loop.
        for initializer in tuple(initializers.values()):
            if initializer.uses() or initializer in protected:
                continue
            assert initializer.name is not None
            del initializers[initializer.name]
            removed += 1

        for function in model.functions.values():
            removed += _remove_unused_nodes_in_graph_like(function)

        if removed:
            # The count covers removed initializers as well as removed nodes.
            logger.info("Removed %s unused nodes", removed)
        return ir.passes.PassResult(model, modified=bool(removed))
118
+
119
+
120
class RemoveUnusedFunctionsPass(ir.passes.InPlacePass):
    """Remove functions that are not reachable from the main graph.

    A function is considered used when a node in the main graph (including
    nested subgraphs) has an operator identifier matching it, or when it is
    called, directly or transitively, from another used function.
    """

    def __init__(self):
        super().__init__()
        # Identifiers of reachable functions; only populated while call() runs.
        self._used: set[ir.OperatorIdentifier] | None = None

    def call(self, model: ir.Model) -> ir.passes.PassResult:
        """Delete every unreachable function from ``model.functions`` in place."""
        self._used = set()
        try:
            for node in ir.traversal.RecursiveGraphIterator(model.graph):
                self._call_node(model, node)
            unused = set(model.functions) - self._used
        finally:
            # Always clear the per-run state, including on the early-return
            # path below and when traversal raises.
            self._used = None

        # Update the model to remove unused functions
        if not unused:
            logger.info("No unused functions to remove")
            return ir.passes.PassResult(model, modified=False)

        for op_identifier in unused:
            del model.functions[op_identifier]

        logger.info("Removed %s unused functions", len(unused))
        logger.debug("Functions left: %s", list(model.functions))
        logger.debug("Functions removed: %s", unused)

        return ir.passes.PassResult(model, modified=True)

    def _call_function(self, model: ir.Model, function: ir.Function) -> None:
        """Mark ``function`` as used and visit the nodes it contains."""
        assert self._used is not None
        if function.identifier() in self._used:
            # The function and its nodes are already recorded as used
            return
        # Record before descending so mutually recursive functions terminate.
        self._used.add(function.identifier())
        for node in ir.traversal.RecursiveGraphIterator(function):
            self._call_node(model, node)

    def _call_node(self, model: ir.Model, node: ir.Node) -> None:
        """Follow ``node`` into the model function it refers to, if any."""
        op_identifier = node.op_identifier()
        if op_identifier not in model.functions:
            return
        self._call_function(model, model.functions[op_identifier])
160
+
161
+
162
class RemoveUnusedOpsetsPass(ir.passes.InPlacePass):
    """Drop opset imports that no node refers to, for the model and its functions.

    Attributes:
        process_functions: Whether functions are cleaned up too. When True, unused
            opset imports are removed from every function as well as from the main
            graph. When False, only the main graph is processed.
    """

    def __init__(self, process_functions: bool = True):
        super().__init__()
        self.process_functions = process_functions

    def _process_graph_like(
        self, graph_like: ir.Graph | ir.Function, used_domains: set[str]
    ) -> bool:
        """Delete opset imports of ``graph_like`` whose domain is not used.

        ``used_domains`` is extended in place with the domain of every node
        found by the recursive traversal. Returns True when at least one import
        was deleted.
        """
        used_domains.update(
            node.domain for node in ir.traversal.RecursiveGraphIterator(graph_like)
        )
        # Collect first so the mapping is not mutated while being iterated.
        stale_domains = [
            domain for domain in graph_like.opset_imports if domain not in used_domains
        ]
        for domain in stale_domains:
            del graph_like.opset_imports[domain]
        return bool(stale_domains)

    def call(self, model: ir.Model) -> ir.passes.PassResult:
        """Prune opset imports on the main graph and, optionally, on each function."""
        # The default ("") domain is always retained, and so is the domain of
        # every function defined in the model.
        main_graph_domains = {""}
        main_graph_domains.update(function.domain for function in model.functions.values())
        modified = self._process_graph_like(model.graph, used_domains=main_graph_domains)

        if self.process_functions:
            for function in model.functions.values():
                # Each function starts from a fresh set holding only the default domain.
                if self._process_graph_like(function, used_domains={""}):
                    modified = True

        return ir.passes.PassResult(model, modified=modified)