onnx-ir 0.1.15__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (53)
  1. onnx_ir/__init__.py +176 -0
  2. onnx_ir/_cloner.py +229 -0
  3. onnx_ir/_convenience/__init__.py +558 -0
  4. onnx_ir/_convenience/_constructors.py +291 -0
  5. onnx_ir/_convenience/_extractor.py +191 -0
  6. onnx_ir/_core.py +4435 -0
  7. onnx_ir/_display.py +54 -0
  8. onnx_ir/_enums.py +474 -0
  9. onnx_ir/_graph_comparison.py +23 -0
  10. onnx_ir/_graph_containers.py +373 -0
  11. onnx_ir/_io.py +133 -0
  12. onnx_ir/_linked_list.py +284 -0
  13. onnx_ir/_metadata.py +45 -0
  14. onnx_ir/_name_authority.py +72 -0
  15. onnx_ir/_polyfill.py +26 -0
  16. onnx_ir/_protocols.py +627 -0
  17. onnx_ir/_safetensors/__init__.py +510 -0
  18. onnx_ir/_tape.py +242 -0
  19. onnx_ir/_thirdparty/asciichartpy.py +310 -0
  20. onnx_ir/_type_casting.py +89 -0
  21. onnx_ir/_version_utils.py +48 -0
  22. onnx_ir/analysis/__init__.py +21 -0
  23. onnx_ir/analysis/_implicit_usage.py +74 -0
  24. onnx_ir/convenience.py +38 -0
  25. onnx_ir/external_data.py +459 -0
  26. onnx_ir/passes/__init__.py +41 -0
  27. onnx_ir/passes/_pass_infra.py +351 -0
  28. onnx_ir/passes/common/__init__.py +54 -0
  29. onnx_ir/passes/common/_c_api_utils.py +76 -0
  30. onnx_ir/passes/common/clear_metadata_and_docstring.py +60 -0
  31. onnx_ir/passes/common/common_subexpression_elimination.py +207 -0
  32. onnx_ir/passes/common/constant_manipulation.py +230 -0
  33. onnx_ir/passes/common/default_attributes.py +99 -0
  34. onnx_ir/passes/common/identity_elimination.py +120 -0
  35. onnx_ir/passes/common/initializer_deduplication.py +179 -0
  36. onnx_ir/passes/common/inliner.py +223 -0
  37. onnx_ir/passes/common/naming.py +280 -0
  38. onnx_ir/passes/common/onnx_checker.py +57 -0
  39. onnx_ir/passes/common/output_fix.py +141 -0
  40. onnx_ir/passes/common/shape_inference.py +112 -0
  41. onnx_ir/passes/common/topological_sort.py +37 -0
  42. onnx_ir/passes/common/unused_removal.py +215 -0
  43. onnx_ir/py.typed +1 -0
  44. onnx_ir/serde.py +2043 -0
  45. onnx_ir/tape.py +15 -0
  46. onnx_ir/tensor_adapters.py +210 -0
  47. onnx_ir/testing.py +197 -0
  48. onnx_ir/traversal.py +118 -0
  49. onnx_ir-0.1.15.dist-info/METADATA +68 -0
  50. onnx_ir-0.1.15.dist-info/RECORD +53 -0
  51. onnx_ir-0.1.15.dist-info/WHEEL +5 -0
  52. onnx_ir-0.1.15.dist-info/licenses/LICENSE +202 -0
  53. onnx_ir-0.1.15.dist-info/top_level.txt +1 -0
@@ -0,0 +1,57 @@
1
+ # Copyright (c) ONNX Project Contributors
2
+ # SPDX-License-Identifier: Apache-2.0
3
+ """Passes for debugging purposes."""
4
+
5
+ from __future__ import annotations
6
+
7
+ __all__ = [
8
+ "CheckerPass",
9
+ ]
10
+
11
+ from typing import Literal
12
+
13
+ import onnx # noqa: TID251
14
+
15
+ import onnx_ir as ir
16
+ from onnx_ir.passes.common import _c_api_utils
17
+
18
+
19
class CheckerPass(ir.passes.PassBase):
    """Run the ONNX checker (``onnx.checker.check_model``) on the model."""

    def __init__(
        self,
        full_check: bool = False,
        skip_opset_compatibility_check: bool = False,
        check_custom_domain: bool = False,
    ):
        """Configure the checker.

        Args:
            full_check: Forwarded to ``onnx.checker.check_model``.
            skip_opset_compatibility_check: Forwarded to ``onnx.checker.check_model``.
            check_custom_domain: Forwarded to ``onnx.checker.check_model``.
        """
        super().__init__()
        self.full_check = full_check
        self.skip_opset_compatibility_check = skip_opset_compatibility_check
        self.check_custom_domain = check_custom_domain

    @property
    def in_place(self) -> Literal[True]:
        """This pass does not create a new model."""
        return True

    @property
    def changes_input(self) -> Literal[False]:
        """This pass does not change the input model."""
        return False

    def call(self, model: ir.Model) -> ir.passes.PassResult:
        """Serialize the model and validate it with the onnx checker."""

        def _check_model_proto(proto: onnx.ModelProto) -> None:
            # Forward all configured options to the C-API checker.
            onnx.checker.check_model(
                proto,
                full_check=self.full_check,
                skip_opset_compatibility_check=self.skip_opset_compatibility_check,
                check_custom_domain=self.check_custom_domain,
            )

        _c_api_utils.call_onnx_api(func=_check_model_proto, model=model)
        # The model is not modified
        return ir.passes.PassResult(model, False)
@@ -0,0 +1,141 @@
1
+ # Copyright (c) ONNX Project Contributors
2
+ # SPDX-License-Identifier: Apache-2.0
3
+ """Output fix pass for adding Identity nodes.
4
+
5
+ - Graph inputs are directly used as outputs (without any intermediate nodes).
6
+ - A value is used multiple times as a graph output (ensuring each output is unique).
7
+
8
+ This ensures compliance with the ONNX specification for valid output configurations.
9
+ """
10
+
11
+ from __future__ import annotations
12
+
13
+ __all__ = [
14
+ "OutputFixPass",
15
+ ]
16
+
17
+ import logging
18
+
19
+ import onnx_ir as ir
20
+
21
+ logger = logging.getLogger(__name__)
22
+
23
+
24
class OutputFixPass(ir.passes.InPlacePass):
    """Pass for adding Identity nodes to fix invalid output configurations.

    Two situations are repaired by inserting Identity nodes:

    - A graph input is used directly as a graph output (no processing node in
      between). The ONNX specification forbids this, so an Identity node is
      inserted between the input and the output.
    - The same value appears more than once in the output list. Each repeated
      occurrence (after the first) is rerouted through its own Identity node so
      that every output value is unique, as the ONNX specification requires.

    Both the main graph and all subgraphs (e.g. in control flow operators) are
    processed.

    Example transformations:
        Direct input-to-output:
            Before: input -> (direct connection) -> output
            After:  input -> Identity -> output

        Duplicate outputs:
            Before: value -> [output1, output2]
            After:  value -> output1, value -> Identity -> output2
    """

    def call(self, model: ir.Model) -> ir.passes.PassResult:
        """Fix the outputs of the main graph and of every function."""
        modified = False

        # The main graph first, then each function, mirroring the output rules.
        for graph_like in (model.graph, *model.functions.values()):
            duplicates_fixed = _alias_multi_used_outputs(graph_like)
            directs_fixed = _alias_direct_outputs(graph_like)
            modified = modified or duplicates_fixed or directs_fixed

        return ir.passes.PassResult(model, modified=modified)
66
+
67
+
68
def _alias_multi_used_outputs(graph_like: ir.Graph | ir.Function) -> bool:
    """Make every graph-output slot refer to a distinct value.

    Each repeated occurrence of a value in the output list (all but the first)
    is rerouted through a fresh Identity node. Subgraphs are processed too.

    Returns:
        True if any Identity node was inserted.
    """
    modified = False

    for graph in (graph_like, *graph_like.subgraphs()):
        first_occurrences: set[ir.Value] = set()

        for index, value in enumerate(graph.outputs):
            if value not in first_occurrences:
                # First time this value shows up as an output: leave it alone.
                first_occurrences.add(value)
                continue

            # Duplicate occurrence: route this slot through an Identity node.
            alias_node = ir.node("Identity", inputs=[value])
            alias = alias_node.outputs[0]

            # Carry over metadata from the original output value.
            # TODO: Use a better unique naming strategy if needed
            alias.name = f"{value.name}_alias_{index}"
            alias.shape = value.shape
            alias.type = value.type
            alias.metadata_props.update(value.metadata_props)
            alias.doc_string = value.doc_string

            # Register the node and rewire the output slot.
            graph.append(alias_node)
            graph.outputs[index] = alias
            logger.debug(
                "Added Identity node for graph output '%s' used multiple times", value
            )
            modified = True
    return modified
103
+
104
+
105
def _alias_direct_outputs(graph_like: ir.Graph | ir.Function) -> bool:
    """Insert Identity nodes where a graph input is used directly as an output.

    The Identity output takes over the original output name; the input value is
    renamed with an ``_orig`` suffix. Subgraphs are processed too.

    Returns:
        True if any Identity node was inserted.
    """
    modified = False

    for graph in (graph_like, *graph_like.subgraphs()):
        # Snapshot the offending slots before mutating the output list.
        direct_slots = [
            (index, value)
            for index, value in enumerate(graph.outputs)
            if value.is_graph_input()
        ]

        for index, value in direct_slots:
            alias_node = ir.node("Identity", inputs=[value])
            alias = alias_node.outputs[0]

            # The alias keeps the public output name and all metadata.
            alias.name = value.name
            alias.shape = value.shape
            alias.type = value.type
            alias.metadata_props.update(value.metadata_props)
            alias.doc_string = value.doc_string

            # Rename the input value so the two names do not collide.
            # TODO: Use a better unique naming strategy if needed
            value.name = f"{value.name}_orig"

            # Register the node and rewire the output slot.
            graph.append(alias_node)
            graph.outputs[index] = alias

            logger.debug("Added Identity node for graph input '%s' used as output", value)
            modified = True

    return modified
@@ -0,0 +1,112 @@
1
+ # Copyright (c) ONNX Project Contributors
2
+ # SPDX-License-Identifier: Apache-2.0
3
+ """Shape inference pass using onnx.shape_inference."""
4
+
5
+ from __future__ import annotations
6
+
7
+ __all__ = [
8
+ "ShapeInferencePass",
9
+ "infer_shapes",
10
+ ]
11
+
12
+ import logging
13
+
14
+ import onnx # noqa: TID251
15
+
16
+ import onnx_ir as ir
17
+ from onnx_ir.passes.common import _c_api_utils
18
+
19
+ logger = logging.getLogger(__name__)
20
+
21
+
22
def _merge_func(model: ir.Model, inferred_proto: onnx.ModelProto) -> bool:
    """Merge the shape inferred model with the original model.

    Shapes and dtypes from ``inferred_proto`` are copied onto the values of
    ``model`` in place, matched by value name, graph by graph. A value is only
    updated when the inferred model provides a differing, non-None shape/dtype.

    Args:
        model: The original IR model; modified in place.
        inferred_proto: The ONNX model with shapes and types inferred.

    Returns:
        True if any shape or dtype in ``model`` was updated.
        (Fixed: the original docstring incorrectly claimed a tuple was returned.)
    """
    inferred_model = ir.serde.deserialize_model(inferred_proto)
    modified = False
    for original_graph, inferred_graph in zip(model.graphs(), inferred_model.graphs()):
        original_values = ir.convenience.create_value_mapping(original_graph)
        inferred_values = ir.convenience.create_value_mapping(inferred_graph)
        for name, value in original_values.items():
            if name in inferred_values:
                inferred_value = inferred_values[name]
                # Only overwrite with new, non-None information.
                if value.shape != inferred_value.shape and inferred_value.shape is not None:
                    value.shape = inferred_value.shape
                    modified = True
                if value.dtype != inferred_value.dtype and inferred_value.dtype is not None:
                    value.dtype = inferred_value.dtype
                    modified = True
            else:
                logger.warning(
                    "Value %s not found in inferred graph %s", name, inferred_graph.name
                )
    return modified
51
+
52
+
53
class ShapeInferencePass(ir.passes.InPlacePass):
    """This pass performs shape inference on the graph."""

    def __init__(
        self, check_type: bool = True, strict_mode: bool = True, data_prop: bool = True
    ) -> None:
        """Initialize the shape inference pass.

        If inference fails, the model is left unchanged.

        Args:
            check_type: If True, check the types of the inputs and outputs.
            strict_mode: If True, use strict mode for shape inference.
            data_prop: If True, use data propagation for shape inference.
        """
        super().__init__()
        self.check_type = check_type
        self.strict_mode = strict_mode
        self.data_prop = data_prop

    def call(self, model: ir.Model) -> ir.passes.PassResult:
        """Run onnx shape inference and merge the results back into the model."""

        def partial_infer_shapes(proto: onnx.ModelProto) -> onnx.ModelProto:
            return onnx.shape_inference.infer_shapes(
                proto,
                check_type=self.check_type,
                strict_mode=self.strict_mode,
                data_prop=self.data_prop,
            )

        try:
            inferred_model_proto = _c_api_utils.call_onnx_api(partial_infer_shapes, model)
        except Exception as e:  # pylint: disable=broad-exception-caught
            # BUG FIX: the original call had a "%s" placeholder but passed no
            # positional argument (only exc_info=e), producing a logging
            # formatting error at emit time. Pass the exception for "%s" and
            # keep exc_info so the traceback is still logged.
            logger.warning(
                "Shape inference failed: %s. Model is left unchanged", e, exc_info=e
            )
            return ir.passes.PassResult(model, False)

        modified = _merge_func(model, inferred_model_proto)
        return ir.passes.PassResult(model, modified=modified)
90
+
91
+
92
def infer_shapes(
    model: ir.Model,
    *,
    check_type: bool = True,
    strict_mode: bool = True,
    data_prop: bool = True,
) -> ir.Model:
    """Perform shape inference on the model.

    A functional convenience wrapper around :class:`ShapeInferencePass`.

    Args:
        model: The model to perform shape inference on.
        check_type: If True, check the types of the inputs and outputs.
        strict_mode: If True, use strict mode for shape inference.
        data_prop: If True, use data propagation for shape inference.

    Returns:
        The model with shape inference applied.
    """
    inference_pass = ShapeInferencePass(
        check_type=check_type, strict_mode=strict_mode, data_prop=data_prop
    )
    result = inference_pass(model)
    return result.model
@@ -0,0 +1,37 @@
1
+ # Copyright (c) ONNX Project Contributors
2
+ # SPDX-License-Identifier: Apache-2.0
3
+ """Pass for topologically sorting the graphs."""
4
+
5
+ from __future__ import annotations
6
+
7
+ __all__ = [
8
+ "TopologicalSortPass",
9
+ ]
10
+
11
+
12
+ import onnx_ir as ir
13
+
14
+
15
class TopologicalSortPass(ir.passes.InPlacePass):
    """Topologically sort graphs and functions in a model.

    The sort is stable, preserving the relative order of nodes that are not
    dependent on each other. Read more at :meth:`onnx_ir.Graph.sort`.
    """

    def call(self, model: ir.Model) -> ir.passes.PassResult:
        """Sort the main graph and every function, reporting whether order changed."""
        # Snapshot node order before sorting so we can detect changes.
        before: list[ir.Node] = list(model.graph)
        model.graph.sort()
        after: list[ir.Node] = list(model.graph)
        for function in model.functions.values():
            before.extend(function)
            function.sort()
            after.extend(function)

        # Sorting never adds or removes nodes, so a positional identity
        # comparison tells us whether anything moved.
        modified = any(old is not new for old, new in zip(before, after))
        return ir.passes.PassResult(model=model, modified=modified)
@@ -0,0 +1,215 @@
1
+ # Copyright (c) ONNX Project Contributors
2
+ # SPDX-License-Identifier: Apache-2.0
3
+ from __future__ import annotations
4
+
5
+ __all__ = [
6
+ "RemoveUnusedNodesPass",
7
+ "RemoveUnusedFunctionsPass",
8
+ "RemoveUnusedOpsetsPass",
9
+ ]
10
+
11
+ import logging
12
+
13
+ import onnx # noqa: TID251
14
+
15
+ import onnx_ir as ir
16
+
17
+ logger = logging.getLogger(__name__)
18
+
19
+
20
def _remove_unused_optional_outputs(
    node: ir.Node, graph_outputs: frozenset[ir.Value], onnx_opset_version: int
) -> None:
    """Clear and trim unused optional outputs of ``node``.

    Outputs that the op schema declares optional, that are not graph outputs,
    and that have no consumers are renamed to "" and trailing empty-named
    outputs are trimmed. ``BatchNormalization`` is special-cased: when its
    ``running_mean``/``running_var`` outputs are unused, both are cleared and
    the ``training_mode`` attribute is dropped.

    Args:
        node: The node whose outputs may be trimmed; modified in place.
        graph_outputs: The enclosing graph's outputs (these are never trimmed).
        onnx_opset_version: Opset version used to look up the op schema.
    """
    try:
        # BUG FIX: the official ONNX domain spellings are "" and "ai.onnx";
        # the original compared against "onnx.ai", which never matches a real
        # node domain.
        if node.domain not in {"", "ai.onnx"}:
            return
        op_schema = onnx.defs.get_schema(node.op_type, onnx_opset_version, domain=node.domain)
    except Exception:  # pylint: disable=broad-exception-caught
        logger.info(
            "Failed to get schema for %s, skipping optional output removal",
            node,
            stack_info=True,
        )
        return

    if node.op_type == "BatchNormalization":
        # BatchNormalization op has 3 outputs: Y, running_mean, running_var
        # If running_mean and running_var are not used, remove them, and the training_mode attribute
        def is_used_output(i: int) -> bool:
            if i < len(node.outputs):
                val = node.outputs[i]
                return val in graph_outputs or bool(val.uses())
            return False

        if is_used_output(1) or is_used_output(2):
            return
        if len(node.outputs) > 1:
            node.outputs[1].name = ""
        if len(node.outputs) > 2:
            node.outputs[2].name = ""
        node.attributes.pop("training_mode", None)
        return

    optional_info = []
    for o in op_schema.outputs:
        # Current ops do not have optional outputs if they have variable number of outputs
        if o.option == onnx.defs.OpSchema.FormalParameterOption.Variadic:
            return
        optional_info.append(o.option == onnx.defs.OpSchema.FormalParameterOption.Optional)
    # If no optional outputs in spec, skip delete operations.
    # BUG FIX: the original tested `len([o == 1 for o in optional_info]) == 0`,
    # which is the total number of schema outputs and thus (almost) never 0;
    # the intent is to skip when the schema declares no optional outputs.
    if not any(optional_info):
        return

    for i, out in enumerate(node.outputs):
        if out not in graph_outputs and (not out.uses()) and optional_info[i] is True:
            out.name = ""

    # Remove trailing outputs with empty names by counting backwards
    new_output_count = len(node.outputs)
    for i in reversed(range(len(node.outputs))):
        if not node.outputs[i].name:
            new_output_count -= 1
        else:
            break
    node.resize_outputs(new_output_count)
75
+
76
+
77
def _remove_trailing_empty_inputs(node: ir.Node) -> None:
    """Trim trailing ``None`` entries from the node's input list."""
    inputs = node.inputs
    keep = len(inputs)
    # Walk backwards while the last remaining input is a hole (None).
    while keep and inputs[keep - 1] is None:
        keep -= 1
    node.resize_inputs(keep)
86
+
87
+
88
def _remove_unused_nodes_in_graph_like(function_or_graph: ir.Function | ir.Graph) -> int:
    """Remove nodes whose outputs are all unused; return the number removed.

    Iterates the node list in reverse so that removing a consumer can make its
    producers removable within the same sweep. Nodes that are kept still get
    tidied: trailing None inputs are trimmed, unused optional outputs are
    removed (when the default-domain opset version is known), and subgraph
    attributes are processed recursively.
    """
    graph_outputs = frozenset(function_or_graph.outputs)
    # Opset version of the default ("") ONNX domain, or None if not imported.
    onnx_opset_version = function_or_graph.opset_imports.get("", None)
    count = 0
    # Reverse order: a node removed here may leave its producers unused too.
    for node in reversed(function_or_graph):
        removable = True
        for output in node.outputs:
            # Keep the node if any output is a graph output or has a consumer.
            if output in graph_outputs or output.uses():
                removable = False
                break
        if removable:
            # NOTE(review): safe=True presumably detaches uses before removal —
            # confirm against ir.Graph.remove.
            function_or_graph.remove(node, safe=True)
            count += 1
        else:
            _remove_trailing_empty_inputs(node)
            if onnx_opset_version is not None:
                _remove_unused_optional_outputs(node, graph_outputs, onnx_opset_version)
            # Recurse into subgraphs carried as attributes (e.g. If/Loop bodies).
            for attr in node.attributes.values():
                if attr.type == ir.AttributeType.GRAPH:
                    count += _remove_unused_nodes_in_graph_like(attr.as_graph())
                elif attr.type == ir.AttributeType.GRAPHS:
                    for graph in attr.as_graphs():
                        count += _remove_unused_nodes_in_graph_like(graph)
    return count
112
+
113
+
114
class RemoveUnusedNodesPass(ir.passes.InPlacePass):
    """Pass for removing unused nodes and initializers (dead code elimination).

    This pass does not modify the model signature (inputs and outputs). It ensures
    that unused nodes and initializers are removed while preserving the original
    contract of the model.
    """

    def call(self, model: ir.Model) -> ir.passes.PassResult:
        """Remove dead nodes, dead initializers, and dead function-body nodes."""
        removed = _remove_unused_nodes_in_graph_like(model.graph)

        # An initializer is dead when it has no consumers and is neither a
        # graph input nor a graph output.
        graph_outputs = frozenset(model.graph.outputs)
        graph_inputs = frozenset(model.graph.inputs)
        initializers = model.graph.initializers
        for initializer in list(initializers.values()):
            if initializer.uses() or initializer in graph_outputs or initializer in graph_inputs:
                continue
            assert initializer.name is not None
            del initializers[initializer.name]
            removed += 1

        for function in model.functions.values():
            removed += _remove_unused_nodes_in_graph_like(function)

        if removed:
            logger.info("Removed %s unused nodes", removed)
        return ir.passes.PassResult(model, modified=bool(removed))
137
+
138
+
139
class RemoveUnusedFunctionsPass(ir.passes.InPlacePass):
    """Delete model functions that are not reachable from the main graph."""

    def __init__(self):
        super().__init__()
        # Identifiers of functions known to be reachable; populated during call().
        self._used: set[ir.OperatorIdentifier] | None = None

    def call(self, model: ir.Model) -> ir.passes.PassResult:
        self._used = set()
        # Seed reachability from every node in the main graph and its subgraphs.
        for node in ir.traversal.RecursiveGraphIterator(model.graph):
            self._call_node(model, node)

        # Update the model to remove unused functions
        unused = set(model.functions) - self._used
        if not unused:
            logger.info("No unused functions to remove")
            return ir.passes.PassResult(model, modified=False)

        for identifier in unused:
            del model.functions[identifier]

        logger.info("Removed %s unused functions", len(unused))
        logger.debug("Functions left: %s", list(model.functions))
        logger.debug("Functions removed: %s", unused)

        self._used = None
        return ir.passes.PassResult(model, modified=bool(unused))

    def _call_function(self, model: ir.Model, function: ir.Function) -> None:
        """Mark ``function`` reachable and recurse into the functions it calls."""
        assert self._used is not None
        identifier = function.identifier()
        if identifier in self._used:
            # Already visited; its callees were recorded on the first visit.
            return
        self._used.add(identifier)
        for node in ir.traversal.RecursiveGraphIterator(function):
            self._call_node(model, node)

    def _call_node(self, model: ir.Model, node: ir.Node) -> None:
        """If ``node`` invokes a model function, mark that function reachable."""
        identifier = node.op_identifier()
        if identifier in model.functions:
            self._call_function(model, model.functions[identifier])
179
+
180
+
181
class RemoveUnusedOpsetsPass(ir.passes.InPlacePass):
    """Remove unused opset imports from the model and functions.

    Attributes:
        process_functions: Whether to process functions in the model. If True, the pass will
            remove unused opset imports from functions as well. If False, only the main graph
            will be processed.
    """

    def __init__(self, process_functions: bool = True):
        super().__init__()
        self.process_functions = process_functions

    def _process_graph_like(
        self, graph_like: ir.Graph | ir.Function, used_domains: set[str]
    ) -> bool:
        """Drop opset imports of ``graph_like`` whose domain no node uses."""
        for node in ir.traversal.RecursiveGraphIterator(graph_like):
            used_domains.add(node.domain)
        stale = set(graph_like.opset_imports) - used_domains
        for domain in stale:
            del graph_like.opset_imports[domain]
        return bool(stale)

    def call(self, model: ir.Model) -> ir.passes.PassResult:
        # Function domains count as used at the model level, since nodes in
        # the main graph may call into them. The default domain is always kept.
        model_domains = {""}
        for function in model.functions.values():
            model_domains.add(function.domain)
        modified = self._process_graph_like(model.graph, used_domains=model_domains)

        if self.process_functions:
            for function in model.functions.values():
                modified |= self._process_graph_like(function, used_domains={""})

        return ir.passes.PassResult(model, modified=modified)
onnx_ir/py.typed ADDED
@@ -0,0 +1 @@
1
+