onnx-ir 0.0.1__py3-none-any.whl → 0.1.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of onnx-ir might be problematic; see the package registry page for more details.

Files changed (46)
  1. onnx_ir/__init__.py +23 -10
  2. onnx_ir/{_convenience.py → _convenience/__init__.py} +40 -102
  3. onnx_ir/_convenience/_constructors.py +213 -0
  4. onnx_ir/_core.py +874 -257
  5. onnx_ir/_display.py +2 -2
  6. onnx_ir/_enums.py +107 -5
  7. onnx_ir/_graph_comparison.py +2 -2
  8. onnx_ir/_graph_containers.py +373 -0
  9. onnx_ir/_io.py +57 -10
  10. onnx_ir/_linked_list.py +15 -7
  11. onnx_ir/_metadata.py +4 -3
  12. onnx_ir/_name_authority.py +2 -2
  13. onnx_ir/_polyfill.py +26 -0
  14. onnx_ir/_protocols.py +31 -13
  15. onnx_ir/_tape.py +139 -32
  16. onnx_ir/_thirdparty/asciichartpy.py +1 -4
  17. onnx_ir/_type_casting.py +18 -3
  18. onnx_ir/{_internal/version_utils.py → _version_utils.py} +2 -29
  19. onnx_ir/convenience.py +4 -2
  20. onnx_ir/external_data.py +401 -0
  21. onnx_ir/passes/__init__.py +8 -2
  22. onnx_ir/passes/_pass_infra.py +173 -56
  23. onnx_ir/passes/common/__init__.py +40 -0
  24. onnx_ir/passes/common/_c_api_utils.py +76 -0
  25. onnx_ir/passes/common/clear_metadata_and_docstring.py +60 -0
  26. onnx_ir/passes/common/common_subexpression_elimination.py +177 -0
  27. onnx_ir/passes/common/constant_manipulation.py +217 -0
  28. onnx_ir/passes/common/inliner.py +332 -0
  29. onnx_ir/passes/common/onnx_checker.py +57 -0
  30. onnx_ir/passes/common/shape_inference.py +112 -0
  31. onnx_ir/passes/common/topological_sort.py +33 -0
  32. onnx_ir/passes/common/unused_removal.py +196 -0
  33. onnx_ir/serde.py +288 -124
  34. onnx_ir/tape.py +15 -0
  35. onnx_ir/tensor_adapters.py +122 -0
  36. onnx_ir/testing.py +197 -0
  37. onnx_ir/traversal.py +4 -3
  38. onnx_ir-0.1.1.dist-info/METADATA +53 -0
  39. onnx_ir-0.1.1.dist-info/RECORD +42 -0
  40. {onnx_ir-0.0.1.dist-info → onnx_ir-0.1.1.dist-info}/WHEEL +1 -1
  41. onnx_ir-0.1.1.dist-info/licenses/LICENSE +202 -0
  42. onnx_ir/_external_data.py +0 -323
  43. onnx_ir-0.0.1.dist-info/LICENSE +0 -22
  44. onnx_ir-0.0.1.dist-info/METADATA +0 -73
  45. onnx_ir-0.0.1.dist-info/RECORD +0 -26
  46. {onnx_ir-0.0.1.dist-info → onnx_ir-0.1.1.dist-info}/top_level.txt +0 -0
@@ -0,0 +1,112 @@
1
+ # Copyright (c) ONNX Project Contributors
2
+ # SPDX-License-Identifier: Apache-2.0
3
+ """Shape inference pass using onnx.shape_inference."""
4
+
5
+ from __future__ import annotations
6
+
7
+ __all__ = [
8
+ "ShapeInferencePass",
9
+ "infer_shapes",
10
+ ]
11
+
12
+ import logging
13
+
14
+ import onnx
15
+
16
+ import onnx_ir as ir
17
+ from onnx_ir.passes.common import _c_api_utils
18
+
19
+ logger = logging.getLogger(__name__)
20
+
21
+
22
def _merge_func(model: ir.Model, inferred_proto: onnx.ModelProto) -> bool:
    """Merge the shape inferred model with the original model.

    The inferred proto is deserialized. The graphs of both models are then
    paired positionally, in the order produced by ``Model.graphs()``. For every
    value of an original graph, the shape and dtype are copied from the
    same-named value of the paired inferred graph.

    Args:
        model: The original IR model. It is updated in place.
        inferred_proto: The ONNX model with shapes and types inferred.

    Returns:
        True if at least one value's shape or dtype in ``model`` was updated,
        False otherwise.
    """
    inferred_model = ir.serde.deserialize_model(inferred_proto)
    modified = False
    for original_graph, inferred_graph in zip(model.graphs(), inferred_model.graphs()):
        original_values = ir.convenience.create_value_mapping(original_graph)
        inferred_values = ir.convenience.create_value_mapping(inferred_graph)
        for name, value in original_values.items():
            if name not in inferred_values:
                logger.warning(
                    "Value %s not found in inferred graph %s", name, inferred_graph.name
                )
                continue
            inferred_value = inferred_values[name]
            # Only overwrite when inference produced information, so an existing
            # annotation is never replaced by None.
            if inferred_value.shape is not None and value.shape != inferred_value.shape:
                value.shape = inferred_value.shape
                modified = True
            if inferred_value.dtype is not None and value.dtype != inferred_value.dtype:
                value.dtype = inferred_value.dtype
                modified = True
    return modified
51
+
52
+
53
class ShapeInferencePass(ir.passes.InPlacePass):
    """This pass performs shape inference on the graph.

    Inference is delegated to ``onnx.shape_inference.infer_shapes``, called via
    ``_c_api_utils.call_onnx_api``. The inferred shapes and dtypes are then
    merged back into the IR model in place.
    """

    def __init__(
        self, check_type: bool = True, strict_mode: bool = True, data_prop: bool = True
    ) -> None:
        """Initialize the shape inference pass.

        If inference fails, the model is left unchanged.

        Args:
            check_type: If True, check the types of the inputs and outputs.
            strict_mode: If True, use strict mode for shape inference.
            data_prop: If True, use data propagation for shape inference.
        """
        super().__init__()
        self.check_type = check_type
        self.strict_mode = strict_mode
        self.data_prop = data_prop

    def call(self, model: ir.Model) -> ir.passes.PassResult:
        """Run ONNX shape inference and merge the results into ``model``.

        Returns:
            A PassResult wrapping ``model``. ``modified`` is True only when at
            least one shape or dtype was updated. It is False when inference
            raised.
        """

        def partial_infer_shapes(proto: onnx.ModelProto) -> onnx.ModelProto:
            return onnx.shape_inference.infer_shapes(
                proto,
                check_type=self.check_type,
                strict_mode=self.strict_mode,
                data_prop=self.data_prop,
            )

        try:
            inferred_model_proto = _c_api_utils.call_onnx_api(partial_infer_shapes, model)
        except Exception as e:  # pylint: disable=broad-exception-caught
            # ``e`` fills the %s placeholder in the message; exc_info attaches the
            # traceback.
            logger.warning(
                "Shape inference failed: %s. Model is left unchanged", e, exc_info=e
            )
            return ir.passes.PassResult(model, False)

        modified = _merge_func(model, inferred_model_proto)
        return ir.passes.PassResult(model, modified=modified)
90
+
91
+
92
def infer_shapes(
    model: ir.Model,
    *,
    check_type: bool = True,
    strict_mode: bool = True,
    data_prop: bool = True,
) -> ir.Model:
    """Run ``ShapeInferencePass`` over ``model`` and return the resulting model.

    This is a functional convenience wrapper around the pass. All options are
    forwarded to the pass unchanged.

    Args:
        model: The model to perform shape inference on.
        check_type: If True, check the types of the inputs and outputs.
        strict_mode: If True, use strict mode for shape inference.
        data_prop: If True, use data propagation for shape inference.

    Returns:
        The ``model`` attribute of the pass result, i.e. the model with inferred
        shapes and dtypes merged in.
    """
    shape_pass = ShapeInferencePass(
        check_type=check_type,
        strict_mode=strict_mode,
        data_prop=data_prop,
    )
    pass_result = shape_pass(model)
    return pass_result.model
@@ -0,0 +1,33 @@
1
+ # Copyright (c) ONNX Project Contributors
2
+ # SPDX-License-Identifier: Apache-2.0
3
+ """Pass for topologically sorting the graphs."""
4
+
5
+ from __future__ import annotations
6
+
7
+ __all__ = [
8
+ "TopologicalSortPass",
9
+ ]
10
+
11
+
12
+ import onnx_ir as ir
13
+
14
+
15
class TopologicalSortPass(ir.passes.InPlacePass):
    """Topologically sort graphs and functions in a model.

    The main graph and every model-local function are sorted in place by calling
    their ``sort()`` method. The returned ``PassResult.modified`` is True when
    the node order differs after sorting anywhere. This includes the node order
    inside subgraphs stored in node attributes.
    """

    def call(self, model: ir.Model) -> ir.passes.PassResult:
        graph_likes: list[ir.Graph | ir.Function] = [
            model.graph,
            *model.functions.values(),
        ]

        def snapshot() -> list[ir.Node]:
            # Walk subgraphs too. A top-level-only snapshot would not notice
            # reordering of nodes nested inside subgraphs.
            return [
                node
                for graph_like in graph_likes
                for node in ir.traversal.RecursiveGraphIterator(graph_like)
            ]

        original_nodes = snapshot()
        for graph_like in graph_likes:
            graph_like.sort()
        sorted_nodes = snapshot()

        # Compare node identities position by position to detect any reordering.
        modified = len(original_nodes) != len(sorted_nodes) or any(
            before is not after for before, after in zip(original_nodes, sorted_nodes)
        )
        return ir.passes.PassResult(model=model, modified=modified)
@@ -0,0 +1,196 @@
1
+ # Copyright (c) ONNX Project Contributors
2
+ # SPDX-License-Identifier: Apache-2.0
3
+ from __future__ import annotations
4
+
5
+ __all__ = [
6
+ "RemoveUnusedNodesPass",
7
+ "RemoveUnusedFunctionsPass",
8
+ "RemoveUnusedOpsetsPass",
9
+ ]
10
+
11
+ import logging
12
+
13
+ import onnx
14
+
15
+ import onnx_ir as ir
16
+
17
+ logger = logging.getLogger(__name__)
18
+
19
+
20
def _remove_unused_optional_outputs(
    node: ir.Node, graph_outputs: frozenset[ir.Value], onnx_opset_version: int
) -> None:
    """Drop unused optional outputs of a default-domain ONNX node, in place.

    An output is dropped by setting its name to ``""``. An output is dropped
    only when all of the following hold:

    - it is not in ``graph_outputs``;
    - it has no uses;
    - the operator schema declares it ``Optional``.

    The node is left untouched in these cases:

    - it is outside the default ONNX domain;
    - its schema cannot be retrieved;
    - its schema has a variadic output.

    ``BatchNormalization`` is special-cased. When neither output 1 nor output 2
    is used, both are dropped and the ``training_mode`` attribute is removed.

    Args:
        node: The node to update.
        graph_outputs: Outputs of the enclosing graph or function. These are
            always kept.
        onnx_opset_version: Version of the default ONNX opset, used to look up
            the operator schema.
    """
    # "" and "ai.onnx" both name the default ONNX operator-set domain.
    if node.domain not in {"", "ai.onnx"}:
        return
    try:
        # Look up in the canonical "" domain so that either alias resolves.
        op_schema = onnx.defs.get_schema(node.op_type, onnx_opset_version, domain="")
    except Exception:  # pylint: disable=broad-exception-caught
        logger.info(
            "Failed to get schema for %s, skipping optional output removal",
            node,
            stack_info=True,
        )
        return

    if node.op_type == "BatchNormalization":
        # BatchNormalization op has 3 outputs: Y, running_mean, running_var
        # If running_mean and running_var are not used, remove them, and the training_mode attribute
        def is_used_output(i: int) -> bool:
            if i < len(node.outputs):
                val = node.outputs[i]
                return val in graph_outputs or bool(val.uses())
            return False

        if is_used_output(1) or is_used_output(2):
            return
        if len(node.outputs) > 1:
            node.outputs[1].name = ""
        if len(node.outputs) > 2:
            node.outputs[2].name = ""
        node.attributes.pop("training_mode", None)
        return

    param_option = onnx.defs.OpSchema.FormalParameterOption
    optional_info: list[bool] = []
    for formal_output in op_schema.outputs:
        # Current ops do not have optional outputs if they have variable number of outputs
        if formal_output.option == param_option.Variadic:
            return
        optional_info.append(formal_output.option == param_option.Optional)
    # Nothing to delete when the schema declares no optional outputs.
    if not any(optional_info):
        return

    # zip() stops at the shorter sequence. A node carrying more outputs than its
    # schema declares therefore cannot trigger an IndexError here.
    for out, is_optional in zip(node.outputs, optional_info):
        if is_optional and out not in graph_outputs and not out.uses():
            out.name = ""
66
+
67
+
68
def _remove_unused_nodes_in_graph_like(function_or_graph: ir.Function | ir.Graph) -> int:
    """Remove dead nodes from a graph or function, recursing into subgraphs.

    A node is dead when none of its outputs is an output of
    ``function_or_graph`` and none of its outputs has a use. Dead nodes are
    removed with ``remove(..., safe=True)``.

    Nodes are visited from last to first. Assuming the graph is topologically
    sorted, consumers are examined before their producers.

    For each surviving node:

    - unused optional outputs are pruned, but only when the container imports
      the default ("") opset;
    - every GRAPH / GRAPHS attribute is processed recursively.

    Returns:
        The total number of nodes removed, including nodes removed from subgraphs.
    """
    kept_outputs = frozenset(function_or_graph.outputs)
    default_opset = function_or_graph.opset_imports.get("", None)
    removed = 0
    for node in reversed(function_or_graph):
        is_dead = not any(out in kept_outputs or out.uses() for out in node.outputs)
        if is_dead:
            function_or_graph.remove(node, safe=True)
            removed += 1
            continue

        if default_opset is not None:
            _remove_unused_optional_outputs(node, kept_outputs, default_opset)
        for attr in node.attributes.values():
            # Skip anything that is not a concrete attribute (e.g. reference attributes).
            if not isinstance(attr, ir.Attr):
                continue
            if attr.type == ir.AttributeType.GRAPH:
                removed += _remove_unused_nodes_in_graph_like(attr.as_graph())
            elif attr.type == ir.AttributeType.GRAPHS:
                removed += sum(
                    _remove_unused_nodes_in_graph_like(subgraph)
                    for subgraph in attr.as_graphs()
                )
    return removed
93
+
94
+
95
class RemoveUnusedNodesPass(ir.passes.InPlacePass):
    """Pass for removing unused nodes and initializers (dead code elimination).

    This pass does not modify the model signature (inputs and outputs). It ensures
    that unused nodes and initializers are removed while preserving the original
    contract of the model.

    Dead nodes are removed from the main graph (including its subgraphs) and
    from every model-local function. Unused initializers are removed from the
    main graph only.
    """

    def call(self, model: ir.Model) -> ir.passes.PassResult:
        removed_nodes = _remove_unused_nodes_in_graph_like(model.graph)

        # Initializers are pruned after node removal, so initializers whose only
        # consumers were just removed are also dropped.
        graph_outputs = frozenset(model.graph.outputs)
        graph_inputs = frozenset(model.graph.inputs)
        initializers = model.graph.initializers
        removed_initializers = 0
        # Iterate over a copy because entries are deleted inside the loop.
        for init in list(initializers.values()):
            if not (init.uses() or init in graph_outputs or init in graph_inputs):
                assert init.name is not None
                del initializers[init.name]
                removed_initializers += 1

        for function in model.functions.values():
            removed_nodes += _remove_unused_nodes_in_graph_like(function)

        # Report nodes and initializers separately so the log is accurate.
        if removed_nodes or removed_initializers:
            logger.info(
                "Removed %s unused nodes and %s unused initializers",
                removed_nodes,
                removed_initializers,
            )
        return ir.passes.PassResult(
            model, modified=bool(removed_nodes or removed_initializers)
        )
118
+
119
+
120
class RemoveUnusedFunctionsPass(ir.passes.InPlacePass):
    """Remove model-local functions that are not reachable from the main graph.

    A function counts as used when it is referenced, by its operator identifier,
    from either of these places:

    - a node of the main graph, including nodes in its subgraphs;
    - a node of another used function.

    Every other entry in ``model.functions`` is deleted.
    """

    def __init__(self):
        super().__init__()
        # Identifiers of functions reached so far. Non-None only while ``call`` runs.
        self._used: set[ir.OperatorIdentifier] | None = None

    def call(self, model: ir.Model) -> ir.passes.PassResult:
        self._used = set()
        try:
            for node in ir.traversal.RecursiveGraphIterator(model.graph):
                self._call_node(model, node)
            used = self._used
        finally:
            # Reset on every exit path (including errors), so that the pass
            # instance carries no state between runs.
            self._used = None

        # Update the model to remove unused functions
        unused = set(model.functions) - used
        if not unused:
            logger.info("No unused functions to remove")
            return ir.passes.PassResult(model, modified=False)

        for op_identifier in unused:
            del model.functions[op_identifier]

        logger.info("Removed %s unused functions", len(unused))
        logger.debug("Functions left: %s", list(model.functions))
        logger.debug("Functions removed: %s", unused)
        return ir.passes.PassResult(model, modified=True)

    def _call_function(self, model: ir.Model, function: ir.Function) -> None:
        """Mark ``function`` as used and visit the nodes it contains."""
        assert self._used is not None
        if function.identifier() in self._used:
            # The function and its nodes are already recorded as used
            return
        self._used.add(function.identifier())
        for node in ir.traversal.RecursiveGraphIterator(function):
            self._call_node(model, node)

    def _call_node(self, model: ir.Model, node: ir.Node) -> None:
        """If ``node`` calls a model-local function, mark that function as used."""
        op_identifier = node.op_identifier()
        if op_identifier not in model.functions:
            return
        self._call_function(model, model.functions[op_identifier])
160
+
161
+
162
class RemoveUnusedOpsetsPass(ir.passes.InPlacePass):
    """Remove unused opset imports from the model and functions.

    An opset import is kept when some node (including nodes in subgraphs) uses
    its domain. The default domain ``""`` is always kept. The main graph also
    keeps the domains of all model-local functions.

    Attributes:
        process_functions: Whether to process functions in the model. If True, the pass will
            remove unused opset imports from functions as well. If False, only the main graph
            will be processed.
    """

    def __init__(self, process_functions: bool = True):
        super().__init__()
        self.process_functions = process_functions

    def _process_graph_like(
        self, graph_like: ir.Graph | ir.Function, used_domains: set[str]
    ) -> bool:
        """Delete the opset imports of ``graph_like`` whose domain no node uses.

        ``used_domains`` is extended in place with the domain of every visited node.

        Returns:
            True if at least one opset import was deleted.
        """
        used_domains.update(
            node.domain for node in ir.traversal.RecursiveGraphIterator(graph_like)
        )
        stale_domains = [
            domain for domain in graph_like.opset_imports if domain not in used_domains
        ]
        for domain in stale_domains:
            del graph_like.opset_imports[domain]
        return len(stale_domains) > 0

    def call(self, model: ir.Model) -> ir.passes.PassResult:
        # The onnx (default) domain is always retained. Domains of model-local
        # functions count as used by the main graph, so its imports for them stay.
        graph_domains = {""} | {function.domain for function in model.functions.values()}
        modified = self._process_graph_like(model.graph, used_domains=graph_domains)

        if self.process_functions:
            for function in model.functions.values():
                # Each function starts from a fresh set that only retains "".
                if self._process_graph_like(function, used_domains={""}):
                    modified = True

        return ir.passes.PassResult(model, modified=modified)