onnx-ir 0.0.1__py3-none-any.whl → 0.1.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of onnx-ir might be problematic. Click here for more details.

Files changed (46) hide show
  1. onnx_ir/__init__.py +23 -10
  2. onnx_ir/{_convenience.py → _convenience/__init__.py} +40 -102
  3. onnx_ir/_convenience/_constructors.py +213 -0
  4. onnx_ir/_core.py +874 -257
  5. onnx_ir/_display.py +2 -2
  6. onnx_ir/_enums.py +107 -5
  7. onnx_ir/_graph_comparison.py +2 -2
  8. onnx_ir/_graph_containers.py +373 -0
  9. onnx_ir/_io.py +57 -10
  10. onnx_ir/_linked_list.py +15 -7
  11. onnx_ir/_metadata.py +4 -3
  12. onnx_ir/_name_authority.py +2 -2
  13. onnx_ir/_polyfill.py +26 -0
  14. onnx_ir/_protocols.py +31 -13
  15. onnx_ir/_tape.py +139 -32
  16. onnx_ir/_thirdparty/asciichartpy.py +1 -4
  17. onnx_ir/_type_casting.py +18 -3
  18. onnx_ir/{_internal/version_utils.py → _version_utils.py} +2 -29
  19. onnx_ir/convenience.py +4 -2
  20. onnx_ir/external_data.py +401 -0
  21. onnx_ir/passes/__init__.py +8 -2
  22. onnx_ir/passes/_pass_infra.py +173 -56
  23. onnx_ir/passes/common/__init__.py +40 -0
  24. onnx_ir/passes/common/_c_api_utils.py +76 -0
  25. onnx_ir/passes/common/clear_metadata_and_docstring.py +60 -0
  26. onnx_ir/passes/common/common_subexpression_elimination.py +177 -0
  27. onnx_ir/passes/common/constant_manipulation.py +217 -0
  28. onnx_ir/passes/common/inliner.py +332 -0
  29. onnx_ir/passes/common/onnx_checker.py +57 -0
  30. onnx_ir/passes/common/shape_inference.py +112 -0
  31. onnx_ir/passes/common/topological_sort.py +33 -0
  32. onnx_ir/passes/common/unused_removal.py +196 -0
  33. onnx_ir/serde.py +288 -124
  34. onnx_ir/tape.py +15 -0
  35. onnx_ir/tensor_adapters.py +122 -0
  36. onnx_ir/testing.py +197 -0
  37. onnx_ir/traversal.py +4 -3
  38. onnx_ir-0.1.1.dist-info/METADATA +53 -0
  39. onnx_ir-0.1.1.dist-info/RECORD +42 -0
  40. {onnx_ir-0.0.1.dist-info → onnx_ir-0.1.1.dist-info}/WHEEL +1 -1
  41. onnx_ir-0.1.1.dist-info/licenses/LICENSE +202 -0
  42. onnx_ir/_external_data.py +0 -323
  43. onnx_ir-0.0.1.dist-info/LICENSE +0 -22
  44. onnx_ir-0.0.1.dist-info/METADATA +0 -73
  45. onnx_ir-0.0.1.dist-info/RECORD +0 -26
  46. {onnx_ir-0.0.1.dist-info → onnx_ir-0.1.1.dist-info}/top_level.txt +0 -0
@@ -0,0 +1,217 @@
1
+ # Copyright (c) ONNX Project Contributors
2
+ # SPDX-License-Identifier: Apache-2.0
3
+ """Lift constants to initializers."""
4
+
5
+ from __future__ import annotations
6
+
7
# Public passes exported by this module.
__all__ = [
    "AddInitializersToInputsPass",
    "LiftConstantsToInitializersPass",
    "LiftSubgraphInitializersToMainGraphPass",
    "RemoveInitializersFromInputsPass",
]

import logging

import numpy as np

import onnx_ir as ir

# Module-level logger; the passes below report what they lifted, skipped, added,
# or removed through it (debug/info level).
logger = logging.getLogger(__name__)
21
+
22
+
23
class LiftConstantsToInitializersPass(ir.passes.InPlacePass):
    """Lift constants to initializers.

    Every ``Constant`` node in the default (``""`` / ``"onnx.ai"``) domain, in the main
    graph or in any subgraph, is replaced by an initializer registered on the graph that
    owns the node. A Constant is kept as-is when its output is a graph output, when it
    does not carry exactly one attribute, or when the resulting tensor is filtered out by
    ``lift_all_constants`` / ``size_limit``.

    Attributes:
        lift_all_constants: Whether to lift all Constant nodes, including those that do
            not contain a tensor attribute (e.g. with ``value_ints`` etc.).
            Defaults to False, where only Constants with the ``value`` attribute are
            lifted.
        size_limit: The minimum size of the tensor to be lifted. If the tensor contains
            fewer than ``size_limit`` elements, it will not be lifted. Defaults to 16.
    """

    def __init__(self, lift_all_constants: bool = False, size_limit: int = 16):
        super().__init__()
        self.lift_all_constants = lift_all_constants
        self.size_limit = size_limit

    def call(self, model: ir.Model) -> ir.passes.PassResult:
        count = 0
        for node in ir.traversal.RecursiveGraphIterator(model.graph):
            graph = node.graph
            assert graph is not None
            if node.op_type != "Constant" or node.domain not in ("", "onnx.ai"):
                continue
            if node.outputs[0].is_graph_output():
                logger.debug(
                    "Constant node '%s' is used as output, so it can't be lifted.", node.name
                )
                continue
            # A valid Constant carries exactly one attribute. This check also covers the
            # attribute-less case, which the ``next(iter(...))`` below cannot handle.
            if len(node.attributes) != 1:
                logger.debug(
                    "Invalid constant node '%s' does not have exactly one attribute",
                    node.name,
                )
                continue

            attr_name, attr_value = next(iter(node.attributes.items()))
            initializer_name = node.outputs[0].name
            assert initializer_name is not None
            assert isinstance(attr_value, ir.Attr)
            tensor = self._constant_node_attribute_to_tensor(
                node, attr_name, attr_value, initializer_name
            )
            if tensor is None:
                # The reason of None is logged in _constant_node_attribute_to_tensor
                continue
            # Register an initializer with the tensor value
            initializer = ir.Value(
                name=initializer_name,
                shape=tensor.shape,  # type: ignore[arg-type]
                type=ir.TensorType(tensor.dtype),
                const_value=tensor,
            )
            graph.register_initializer(initializer)
            # Replace the constant node with the initializer
            ir.convenience.replace_all_uses_with(node.outputs[0], initializer)
            graph.remove(node, safe=True)
            count += 1
            logger.debug(
                "Converted constant node '%s' to initializer '%s'", node.name, initializer_name
            )
        if count:
            logger.debug("Lifted %s constants to initializers", count)
        return ir.passes.PassResult(model, modified=bool(count))

    @staticmethod
    def _encode_utf8(value):
        """Encode a ``str`` (or each ``str`` in a sequence) to UTF-8 ``bytes``.

        ``bytes`` items are passed through unchanged.
        """
        if isinstance(value, str):
            return value.encode("utf-8")
        if isinstance(value, bytes):
            return value
        return [item.encode("utf-8") if isinstance(item, str) else item for item in value]

    def _constant_node_attribute_to_tensor(
        self, node, attr_name: str, attr_value: ir.Attr, initializer_name: str
    ) -> ir.TensorProtocol | None:
        """Convert constant node attribute to tensor.

        Args:
            node: The Constant node owning the attribute (used in log/error messages).
            attr_name: Name of the node's single attribute, e.g. ``"value"``.
            attr_value: The attribute itself.
            initializer_name: Name given to tensors built from non-``value`` attributes.

        Returns:
            The tensor to use for the initializer, or None when the attribute is skipped:
            either ``lift_all_constants`` is False and the attribute is not ``value``, or
            the tensor has fewer than ``size_limit`` elements.

        Raises:
            ValueError: If ``attr_name`` is not one of the attribute names handled below.
        """
        if not self.lift_all_constants and attr_name != "value":
            logger.debug(
                "Constant node '%s' has non-tensor attribute '%s'", node.name, attr_name
            )
            return None

        tensor: ir.TensorProtocol
        if attr_name == "value":
            tensor = attr_value.as_tensor()
        elif attr_name == "value_int":
            tensor = ir.tensor(
                attr_value.as_int(), dtype=ir.DataType.INT64, name=initializer_name
            )
        elif attr_name == "value_ints":
            tensor = ir.tensor(
                attr_value.as_ints(), dtype=ir.DataType.INT64, name=initializer_name
            )
        elif attr_name == "value_float":
            tensor = ir.tensor(
                attr_value.as_float(), dtype=ir.DataType.FLOAT, name=initializer_name
            )
        elif attr_name == "value_floats":
            tensor = ir.tensor(
                attr_value.as_floats(), dtype=ir.DataType.FLOAT, name=initializer_name
            )
        elif attr_name in ("value_string", "value_strings"):
            # ONNX strings are UTF-8, but numpy converts ``str`` to ``bytes_`` with the
            # ASCII codec (raising on any other character), so encode explicitly first.
            tensor = ir.StringTensor(
                np.array(self._encode_utf8(attr_value.value), dtype=np.bytes_),
                name=initializer_name,
            )
        else:
            raise ValueError(
                f"Unsupported constant node '{node.name}' attribute '{attr_name}'"
            )

        if tensor.size < self.size_limit:
            logger.debug(
                "Tensor from node '%s' has less than %s elements",
                node.name,
                self.size_limit,
            )
            return None
        return tensor
132
+
133
+
134
class LiftSubgraphInitializersToMainGraphPass(ir.passes.InPlacePass):
    """Lift subgraph initializers to main graph.

    This pass lifts the initializers of a subgraph to the main graph.
    It is used to ensure that the initializers are available in the main graph
    for further processing or optimization.

    Initializers that are also graph inputs will not be lifted. A lifted initializer
    whose name is already taken by a main-graph initializer (including one lifted
    earlier by this pass) is renamed with a numeric ``_{k}`` suffix.
    """

    def call(self, model: ir.Model) -> ir.passes.PassResult:
        count = 0
        # Last numeric suffix handed out per original name, so repeated clashes keep
        # counting from where they left off.
        registered_initializer_names: dict[str, int] = {}
        main_initializers = model.graph.initializers
        for graph in model.graphs():
            if graph is model.graph:
                continue
            for name in tuple(graph.initializers):
                initializer = graph.initializers[name]
                if initializer.is_graph_input():
                    # Skip the ones that are also graph inputs
                    logger.debug(
                        "Initializer '%s' is also a graph input, so it can't be lifted",
                        initializer.name,
                    )
                    continue
                # Remove the initializer from the subgraph
                graph.initializers.pop(name)
                # Choose a name not yet used by any main-graph initializer. Checking the
                # live mapping covers the main graph's own initializers as well as names
                # produced by earlier lifts (register_initializer rejects duplicates).
                # NOTE: clashes with non-initializer values are not checked here.
                suffix = registered_initializer_names.get(name, 0)
                new_name = name
                while new_name in main_initializers:
                    suffix += 1
                    new_name = f"{name}_{suffix}"
                registered_initializer_names[name] = suffix
                if new_name != name:
                    initializer.name = new_name
                model.graph.register_initializer(initializer)
                count += 1
                logger.debug(
                    "Lifted initializer '%s' from subgraph '%s' to main graph",
                    initializer.name,
                    graph.name,
                )
        return ir.passes.PassResult(model, modified=bool(count))
178
+
179
+
180
class RemoveInitializersFromInputsPass(ir.passes.InPlacePass):
    """Remove initializers from inputs.

    For the main graph and every subgraph, this pass removes from ``graph.inputs``
    every value that is also registered in that graph's ``initializers``. The relative
    order of the remaining inputs is preserved. Graphs with nothing to remove are left
    untouched.
    """

    def call(self, model: ir.Model) -> ir.passes.PassResult:
        count = 0
        for graph in model.graphs():
            initializers = set(graph.initializers.values())
            new_inputs = [value for value in graph.inputs if value not in initializers]
            removed = len(graph.inputs) - len(new_inputs)
            if not removed:
                # Nothing to drop: avoid needlessly rebuilding this graph's inputs.
                continue
            graph.inputs.clear()
            graph.inputs.extend(new_inputs)
            count += removed
        logger.info("Removed %s initializers from graph inputs", count)
        return ir.passes.PassResult(model, modified=bool(count))
200
+
201
+
202
class AddInitializersToInputsPass(ir.passes.InPlacePass):
    """Add initializers to inputs.

    For the main graph and each subgraph, every initializer not already listed in
    ``graph.inputs`` is appended to it, following the initializer iteration order.
    """

    def call(self, model: ir.Model) -> ir.passes.PassResult:
        added = 0
        for graph in model.graphs():
            already_inputs = set(graph.inputs)
            pending = [
                value for value in graph.initializers.values() if value not in already_inputs
            ]
            for value in pending:
                graph.inputs.append(value)
            added += len(pending)
        logger.info("Added %s initializers to graph inputs", added)
        return ir.passes.PassResult(model, modified=bool(added))
@@ -0,0 +1,332 @@
1
+ # Copyright (c) ONNX Project Contributors
2
+ # SPDX-License-Identifier: Apache-2.0
3
+ """Implementation of an inliner for onnx_ir."""
4
+
5
+ from __future__ import annotations
6
+
7
+ import dataclasses
8
+
9
+ __all__ = ["InlinePass", "InlinePassResult"]
10
+
11
+ from collections import defaultdict
12
+ from collections.abc import Iterable, Mapping, Sequence
13
+
14
+ import onnx_ir as ir
15
+ import onnx_ir.convenience as _ir_convenience
16
+
17
+ # A replacement for a node specifies a list of nodes that replaces the original node,
18
+ # and a list of values that replaces the original node's outputs.
19
+
20
+ NodeReplacement = tuple[Sequence[ir.Node], Sequence[ir.Value]]
21
+
22
+ # A call stack is a list of identifiers of call sites, where the first element is the
23
+ # outermost call site, and the last element is the innermost call site. This is used
24
+ # primarily for generating unique names for values in the inlined functions.
25
+ CallSiteId = str
26
+ CallStack = list[CallSiteId]
27
+
28
+
29
def _make_unique_name(name: str, callstack: CallStack, used_names: set[str]) -> str:  # pylint: disable=unused-argument
    """Return a name derived from ``name`` that is not yet in ``used_names``, and reserve it.

    ``name`` itself is returned when it is free; otherwise the first free candidate among
    ``name_2``, ``name_3``, ... is chosen. The chosen name is added to ``used_names``.
    Node names are made unique with the same scheme.

    TODO: We can use the callstack in generating a name for a value X in a function
    that is inlined into a graph. This is not yet implemented. Using the full callstack
    leads to very long and hard to read names. Some investigation is needed to find
    a good naming strategy that will produce useful names for debugging.
    """
    candidate = name
    suffix = 2
    while candidate in used_names:
        candidate = f"{name}_{suffix}"
        suffix += 1
    used_names.add(candidate)
    return candidate
47
+
48
+
49
class _CopyReplace:
    """Utilities for creating a copy of IR objects with substitutions for attributes/input values.

    One instance is created per inlined call (see ``InlinePass._instantiate_call``):
    ``value_map`` maps callee values to their replacements, ``attr_map`` supplies the
    attributes used to resolve reference attributes, ``metadata_props`` is merged into
    every cloned node, and ``call_stack`` is recorded for every cloned node in the
    inliner's ``node_context``.
    """

    def __init__(
        self,
        inliner: InlinePass,
        attr_map: Mapping[str, ir.Attr],
        value_map: dict[ir.Value, ir.Value | None],
        metadata_props: dict[str, str],
        call_stack: CallStack,
    ) -> None:
        self._inliner = inliner
        self._value_map = value_map
        self._attr_map = attr_map
        self._metadata_props = metadata_props
        self._call_stack = call_stack

    def clone_value(self, value: ir.Value) -> ir.Value | None:
        """Return the replacement for ``value``, creating (and memoizing) a copy if needed.

        A value missing from the map must not be produced by a node, e.g. an input or
        initializer of a subgraph being cloned.
        """
        if value in self._value_map:
            return self._value_map[value]
        # If the value is not in the value map, it must be a graph input (or initializer).
        assert value.producer() is None, f"Value {value} has no entry in the value map"
        new_value = ir.Value(
            name=value.name,
            type=value.type,
            shape=value.shape,
            doc_string=value.doc_string,
            const_value=value.const_value,
        )
        self._value_map[value] = new_value
        return new_value

    def clone_optional_value(self, value: ir.Value | None) -> ir.Value | None:
        """Like :meth:`clone_value`, but passes ``None`` (an omitted input) through."""
        if value is None:
            return None
        return self.clone_value(value)

    def clone_attr(self, key: str, attr: ir.Attr) -> ir.Attr | None:
        """Return the attribute to attach under ``key`` on a cloned node, or None to drop it.

        Graph-valued attributes are cloned recursively; reference attributes are resolved
        against ``attr_map``.
        """
        if not attr.is_ref():
            if attr.type == ir.AttributeType.GRAPH:
                graph = self.clone_graph(attr.as_graph())
                return ir.Attr(key, ir.AttributeType.GRAPH, graph, doc_string=attr.doc_string)
            elif attr.type == ir.AttributeType.GRAPHS:
                graphs = [self.clone_graph(graph) for graph in attr.as_graphs()]
                return ir.Attr(
                    key, ir.AttributeType.GRAPHS, graphs, doc_string=attr.doc_string
                )
            return attr
        assert attr.is_ref()
        ref_attr_name = attr.ref_attr_name
        assert ref_attr_name is not None, "Reference attribute must have a name"
        if ref_attr_name in self._attr_map:
            ref_attr = self._attr_map[ref_attr_name]
            if not ref_attr.is_ref():
                return ir.Attr(
                    key, ref_attr.type, ref_attr.value, doc_string=ref_attr.doc_string
                )
            assert ref_attr.ref_attr_name is not None
            return ir.RefAttr(
                key, ref_attr.ref_attr_name, ref_attr.type, doc_string=ref_attr.doc_string
            )
        # Note that if a function has an attribute-parameter X, and a call (node) to the function
        # has no attribute X, all references to X in nodes inside the function body will be
        # removed. This is just the ONNX representation of optional-attributes.
        return None

    def clone_node(self, node: ir.Node) -> ir.Node:
        """Clone ``node`` with substituted inputs/attributes and uniquely named outputs.

        The clone's outputs are recorded in the value map so later nodes (and the
        function outputs) resolve to them.
        """
        new_inputs = [self.clone_optional_value(input) for input in node.inputs]
        new_attributes = [
            new_value
            for key, value in node.attributes.items()
            if (new_value := self.clone_attr(key, value)) is not None
        ]
        new_name = node.name
        if new_name is not None:
            new_name = _make_unique_name(
                new_name, self._call_stack, self._inliner.used_node_names
            )

        new_metadata = {**self._metadata_props, **node.metadata_props}
        # TODO: For now, node metadata overrides callnode metadata if there is a conflict.
        # Do we need to preserve both?

        new_node = ir.Node(
            node.domain,
            node.op_type,
            new_inputs,
            new_attributes,
            overload=node.overload,
            num_outputs=len(node.outputs),
            graph=None,
            name=new_name,
            doc_string=node.doc_string,  # type: ignore
            metadata_props=new_metadata,
        )
        new_outputs = new_node.outputs
        for i, output in enumerate(node.outputs):
            self._value_map[output] = new_outputs[i]
            old_name = output.name if output.name is not None else f"output_{i}"
            new_outputs[i].name = _make_unique_name(
                old_name, self._call_stack, self._inliner.used_value_names
            )

        self._inliner.node_context[new_node] = self._call_stack

        return new_node

    def clone_graph(self, graph: ir.Graph) -> ir.Graph:
        """Clone a subgraph (e.g. a control-flow body) found inside the function body."""
        input_values = [self.clone_value(v) for v in graph.inputs]
        nodes = [self.clone_node(node) for node in graph]
        initializers = [self.clone_value(init) for init in graph.initializers.values()]
        output_values = [
            self.clone_value(v) for v in graph.outputs
        ]  # Looks up already cloned values

        return ir.Graph(
            input_values,  # type: ignore
            output_values,  # type: ignore
            nodes=nodes,
            initializers=initializers,  # type: ignore
            doc_string=graph.doc_string,
            # Copy the mutable mappings: passing the source dicts would make every clone
            # (one per call site) share them with the function body's graph.
            opset_imports=dict(graph.opset_imports),
            name=graph.name,
            metadata_props=dict(graph.metadata_props),
        )
174
+
175
+
176
def _abbreviate(
    function_ids: Iterable[ir.OperatorIdentifier],
) -> dict[ir.OperatorIdentifier, str]:
    """Create a short unambiguous abbreviation for all function ids.

    The domain is omitted from an id's abbreviation unless another id has the same
    ``(name, overload)`` pair under a different domain, in which case ``domain_`` is
    prefixed. A non-empty overload is appended as ``_overload``.

    Args:
        function_ids: ``(domain, name, overload)`` identifiers. Any iterable is
            accepted, including a one-shot iterator; it is consumed exactly once.

    Returns:
        A mapping from each identifier to its abbreviation.
    """
    ids = list(function_ids)
    # Distinct domains per (name, overload): ambiguity becomes a lookup instead of a
    # rescan of every id for every id.
    domains_by_signature: dict[tuple[str, str], set[str]] = defaultdict(set)
    for domain, name, overload in ids:
        domains_by_signature[(name, overload)].add(domain)

    def id_abbreviation(op_id: ir.OperatorIdentifier) -> str:
        """Create a short unambiguous abbreviation for a function id."""
        domain, name, overload = op_id
        # Keep the domain only if the same (name, overload) exists in another domain.
        ambiguous = len(domains_by_signature[(name, overload)]) > 1
        short_domain = domain + "_" if ambiguous else ""
        if overload != "":
            return short_domain + name + "_" + overload
        return short_domain + name

    return {op_id: id_abbreviation(op_id) for op_id in ids}
194
+
195
+
196
@dataclasses.dataclass()
class InlinePassResult(ir.passes.PassResult):
    """Result of :class:`InlinePass`.

    Besides the fields inherited from ``ir.passes.PassResult``, it carries the number of
    call sites the pass counted for each model-local function.
    """

    # (domain, name, overload) identifier of a function -> number of counted call sites.
    id_count: dict[ir.OperatorIdentifier, int]
199
+
200
+
201
class InlinePass(ir.passes.InPlacePass):
    """Inline model local functions to the main graph and clear function definitions.

    Calls are inlined in the main graph and, recursively, in the subgraphs attached to
    its non-call nodes. ``model.functions`` is cleared afterwards whether or not every
    function was called.

    Public attributes (reset at the start of every run):
        used_value_names: Value names already taken; names of values created by
            inlining are made unique against this set.
        used_node_names: Node names already taken, used the same way for node names.
        node_context: Call stack recorded for each node created by inlining.
    """

    def __init__(self) -> None:
        super().__init__()
        self._functions: dict[ir.OperatorIdentifier, ir.Function] = {}
        self._function_id_abbreviations: dict[ir.OperatorIdentifier, str] = {}
        self._opset_imports: dict[str, int] = {}
        self.used_value_names: set[str] = set()
        self.used_node_names: set[str] = set()
        self.node_context: dict[ir.Node, CallStack] = {}

    def _reset(self, model: ir.Model) -> None:
        """Bind per-model state and clear the naming/context bookkeeping."""
        self._functions = model.functions
        self._function_id_abbreviations = _abbreviate(self._functions.keys())
        # Kept as an alias (not a copy) so opsets merged in by _instantiate_call are
        # written to the model's opset imports (assumes the property returns the live dict).
        self._opset_imports = model.opset_imports
        self.used_value_names = set()
        self.used_node_names = set()
        self.node_context = {}

    def call(self, model: ir.Model) -> InlinePassResult:
        self._reset(model)
        # Clearing a non-empty function table changes the model even if no call site
        # was inlined, so it must be reported as a modification.
        had_functions = bool(model.functions)
        id_count = self._inline_calls_in(model.graph)
        model.functions.clear()
        return InlinePassResult(
            model, modified=bool(id_count) or had_functions, id_count=id_count
        )

    def _instantiate_call(self, node: ir.Node, call_site_id: CallSiteId) -> NodeReplacement:
        """Clone the body of the function called by ``node``.

        Args:
            node: The call node.
            call_site_id: Identifier appended to the node's call stack.

        Returns:
            The cloned nodes, and the values replacing ``node``'s outputs in function
            output order.

        Raises:
            ValueError: If the function's opset version for a domain differs from the
                model's, if a graph-valued attribute is passed to the function, or if
                the node has more inputs than the function declares.
        """
        op_id = node.op_identifier()
        function = self._functions[op_id]

        # check opset compatibility and update the opset imports
        for domain, version in function.opset_imports.items():
            if domain not in self._opset_imports:
                self._opset_imports[domain] = version
            elif self._opset_imports[domain] != version:
                raise ValueError(
                    f"Opset mismatch: {domain} {self._opset_imports[domain]} != {version}"
                )

        # Identify substitutions for both inputs and attributes of the function:
        attributes: Mapping[str, ir.Attr] = node.attributes
        default_attr_values = {
            attr.name: attr
            for attr in function.attributes.values()
            if attr.name not in attributes and attr.value is not None
        }
        if default_attr_values:
            attributes = {**attributes, **default_attr_values}
        if any(
            attr.type in {ir.AttributeType.GRAPH, ir.AttributeType.GRAPHS}
            for attr in attributes.values()
        ):
            raise ValueError(
                "Inliner does not support graph attribute parameters to functions"
            )

        if len(node.inputs) > len(function.inputs):
            raise ValueError(f"Input mismatch: {len(node.inputs)} > {len(function.inputs)}")
        # Formal inputs not supplied by the call are mapped to None (omitted inputs).
        num_actual = len(node.inputs)
        value_map: dict[ir.Value, ir.Value | None] = {}
        for i, formal in enumerate(function.inputs):
            value_map[formal] = node.inputs[i] if i < num_actual else None

        # Identify call-stack for node, used to generate unique names.
        call_stack = self.node_context.get(node, [])
        new_call_stack = [*call_stack, call_site_id]

        cloner = _CopyReplace(self, attributes, value_map, node.metadata_props, new_call_stack)

        # Clone each function node, substituting inputs via the value map; cloning
        # records the new outputs in the value map as it goes.
        nodes = [cloner.clone_node(body_node) for body_node in function]
        output_values = [value_map[output] for output in function.outputs]
        return nodes, output_values  # type: ignore

    def _inline_calls_in(self, graph: ir.Graph) -> dict[ir.OperatorIdentifier, int]:
        """Inline calls to model-local functions in ``graph`` and its subgraphs.

        Returns:
            Number of inlined call sites per function identifier, including calls
            inlined inside nested subgraphs.
        """
        for graph_input in graph.inputs:
            if graph_input.name is not None:
                self.used_value_names.add(graph_input.name)
        # Iterating the initializers mapping yields initializer names.
        self.used_value_names.update(graph.initializers)

        # Pre-processing:
        # * Count the number of times each function is called in the graph.
        #   This is used for disambiguating the call-site ids of the inlined calls.
        # * And identify names of nodes and values that are used in the graph.
        id_count: dict[ir.OperatorIdentifier, int] = defaultdict(int)
        for node in graph:
            if node.name:
                self.used_node_names.add(node.name)
            op_id = node.op_identifier()
            if op_id in self._functions:
                id_count[op_id] += 1
            for output in node.outputs:
                if output.name is not None:
                    self.used_value_names.add(output.name)

        inlined_count: dict[ir.OperatorIdentifier, int] = defaultdict(int)
        next_id: dict[ir.OperatorIdentifier, int] = defaultdict(int)
        for node in graph:
            op_id = node.op_identifier()
            if op_id in self._functions:
                # If there are multiple calls to same function, we use a prefix to disambiguate
                # the different call-sites:
                if id_count[op_id] > 1:
                    call_site_prefix = f"_{next_id[op_id]}"
                    next_id[op_id] += 1
                else:
                    call_site_prefix = ""
                call_site = node.name or (
                    self._function_id_abbreviations[op_id] + call_site_prefix
                )
                nodes, values = self._instantiate_call(node, call_site)
                _ir_convenience.replace_nodes_and_values(
                    graph,
                    insertion_point=node,
                    old_nodes=[node],
                    new_nodes=nodes,
                    old_values=node.outputs,
                    new_values=values,
                )
                inlined_count[op_id] += 1
            else:
                for attr in node.attributes.values():
                    if not isinstance(attr, ir.Attr):
                        continue
                    if attr.type == ir.AttributeType.GRAPH:
                        subgraphs: Sequence[ir.Graph] = (attr.as_graph(),)
                    elif attr.type == ir.AttributeType.GRAPHS:
                        subgraphs = attr.as_graphs()
                    else:
                        continue
                    # Fold subgraph results in so the caller sees every inlined call.
                    for subgraph in subgraphs:
                        for sub_id, sub_count in self._inline_calls_in(subgraph).items():
                            inlined_count[sub_id] += sub_count
        return inlined_count
@@ -0,0 +1,57 @@
1
+ # Copyright (c) ONNX Project Contributors
2
+ # SPDX-License-Identifier: Apache-2.0
3
+ """Passes for debugging purposes."""
4
+
5
+ from __future__ import annotations
6
+
7
# Public passes exported by this module.
__all__ = [
    "CheckerPass",
]

from typing import Literal

import onnx

import onnx_ir as ir
from onnx_ir.passes.common import _c_api_utils
17
+
18
+
19
class CheckerPass(ir.passes.PassBase):
    """Run onnx checker on the model.

    The model is passed to ``onnx.checker.check_model`` through
    ``_c_api_utils.call_onnx_api``, using the options configured on this instance.
    The pass always reports the model as unmodified.
    """

    def __init__(
        self,
        full_check: bool = False,
        skip_opset_compatibility_check: bool = False,
        check_custom_domain: bool = False,
    ):
        super().__init__()
        self.full_check = full_check
        self.skip_opset_compatibility_check = skip_opset_compatibility_check
        self.check_custom_domain = check_custom_domain

    @property
    def in_place(self) -> Literal[True]:
        """Always True: the returned result wraps the same model object."""
        return True

    @property
    def changes_input(self) -> Literal[False]:
        """Always False: checking does not mutate the input model."""
        return False

    def call(self, model: ir.Model) -> ir.passes.PassResult:
        """Run the onnx checker on the model."""
        checker_options = {
            "full_check": self.full_check,
            "skip_opset_compatibility_check": self.skip_opset_compatibility_check,
            "check_custom_domain": self.check_custom_domain,
        }

        def _check_proto(proto: onnx.ModelProto) -> None:
            """Invoke the ONNX checker on the serialized model with the stored options."""
            onnx.checker.check_model(proto, **checker_options)

        _c_api_utils.call_onnx_api(func=_check_proto, model=model)
        # Checking is read-only, hence "not modified".
        return ir.passes.PassResult(model, False)