onnx-ir 0.0.1__py3-none-any.whl → 0.1.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of onnx-ir might be problematic; see the registry's advisory page for more details.

Files changed (45)
  1. onnx_ir/__init__.py +23 -10
  2. onnx_ir/{_convenience.py → _convenience/__init__.py} +40 -102
  3. onnx_ir/_convenience/_constructors.py +213 -0
  4. onnx_ir/_core.py +857 -233
  5. onnx_ir/_display.py +2 -2
  6. onnx_ir/_enums.py +107 -5
  7. onnx_ir/_graph_comparison.py +2 -2
  8. onnx_ir/_graph_containers.py +268 -0
  9. onnx_ir/_io.py +57 -10
  10. onnx_ir/_linked_list.py +15 -7
  11. onnx_ir/_metadata.py +4 -3
  12. onnx_ir/_name_authority.py +2 -2
  13. onnx_ir/_polyfill.py +26 -0
  14. onnx_ir/_protocols.py +31 -13
  15. onnx_ir/_tape.py +139 -32
  16. onnx_ir/_thirdparty/asciichartpy.py +1 -4
  17. onnx_ir/_type_casting.py +18 -3
  18. onnx_ir/{_internal/version_utils.py → _version_utils.py} +2 -29
  19. onnx_ir/convenience.py +4 -2
  20. onnx_ir/external_data.py +401 -0
  21. onnx_ir/passes/__init__.py +8 -2
  22. onnx_ir/passes/_pass_infra.py +173 -56
  23. onnx_ir/passes/common/__init__.py +36 -0
  24. onnx_ir/passes/common/_c_api_utils.py +76 -0
  25. onnx_ir/passes/common/clear_metadata_and_docstring.py +60 -0
  26. onnx_ir/passes/common/constant_manipulation.py +232 -0
  27. onnx_ir/passes/common/inliner.py +331 -0
  28. onnx_ir/passes/common/onnx_checker.py +57 -0
  29. onnx_ir/passes/common/shape_inference.py +112 -0
  30. onnx_ir/passes/common/topological_sort.py +33 -0
  31. onnx_ir/passes/common/unused_removal.py +196 -0
  32. onnx_ir/serde.py +288 -124
  33. onnx_ir/tape.py +15 -0
  34. onnx_ir/tensor_adapters.py +122 -0
  35. onnx_ir/testing.py +197 -0
  36. onnx_ir/traversal.py +4 -3
  37. onnx_ir-0.1.0.dist-info/METADATA +53 -0
  38. onnx_ir-0.1.0.dist-info/RECORD +41 -0
  39. {onnx_ir-0.0.1.dist-info → onnx_ir-0.1.0.dist-info}/WHEEL +1 -1
  40. onnx_ir-0.1.0.dist-info/licenses/LICENSE +202 -0
  41. onnx_ir/_external_data.py +0 -323
  42. onnx_ir-0.0.1.dist-info/LICENSE +0 -22
  43. onnx_ir-0.0.1.dist-info/METADATA +0 -73
  44. onnx_ir-0.0.1.dist-info/RECORD +0 -26
  45. {onnx_ir-0.0.1.dist-info → onnx_ir-0.1.0.dist-info}/top_level.txt +0 -0
@@ -0,0 +1,331 @@
1
+ # Copyright (c) ONNX Project Contributors
2
+ # SPDX-License-Identifier: Apache-2.0
3
+ """Implementation of an inliner for onnx_ir."""
4
+
5
+ from __future__ import annotations
6
+
7
+ import dataclasses
8
+
9
+ __all__ = ["InlinePass", "InlinePassResult"]
10
+
11
+ from collections import defaultdict
12
+ from collections.abc import Iterable, Sequence
13
+
14
+ import onnx_ir as ir
15
+ import onnx_ir.convenience as _ir_convenience
16
+
17
+ # A replacement for a node specifies a list of nodes that replaces the original node,
18
+ # and a list of values that replaces the original node's outputs.
19
+
20
+ NodeReplacement = tuple[Sequence[ir.Node], Sequence[ir.Value]]
21
+
22
+ # A call stack is a list of identifiers of call sites, where the first element is the
23
+ # outermost call site, and the last element is the innermost call site. This is used
24
+ # primarily for generating unique names for values in the inlined functions.
25
+ CallSiteId = str
26
+ CallStack = list[CallSiteId]
27
+
28
+
29
+ def _make_unique_name(name: str, callstack: CallStack, used_names: set[str]) -> str: # pylint: disable=unused-argument
30
+ """Generate a unique name from a name, calling-context, and set of used names.
31
+
32
+ If there is a name clash, we add a numeric suffix to the name to make
33
+ it unique. We use the same strategy to make node names unique.
34
+
35
+ TODO: We can use the callstack in generating a name for a value X in a function
36
+ that is inlined into a graph. This is not yet implemented. Using the full callstack
37
+ leads to very long and hard to read names. Some investigation is needed to find
38
+ a good naming strategy that will produce useful names for debugging.
39
+ """
40
+ candidate = name
41
+ i = 1
42
+ while candidate in used_names:
43
+ i += 1
44
+ candidate = f"{name}_{i}"
45
+ used_names.add(candidate)
46
+ return candidate
47
+
48
+
49
class _CopyReplace:
    """Utilities for creating a copy of IR objects with substitutions for attributes/input values."""

    def __init__(
        self,
        inliner: InlinePass,
        attr_map: dict[str, ir.Attr],
        value_map: dict[ir.Value, ir.Value | None],
        metadata_props: dict[str, str],
        call_stack: CallStack,
    ) -> None:
        # attr_map: actual attribute values for the function's attribute parameters.
        # value_map: maps function-body values to the values replacing them at the call site.
        # metadata_props: metadata of the call node, merged into each cloned node.
        self._inliner = inliner
        self._value_map = value_map
        self._attr_map = attr_map
        self._metadata_props = metadata_props
        self._call_stack = call_stack

    def clone_value(self, value: ir.Value) -> ir.Value | None:
        """Return the replacement for ``value``, creating a fresh copy on first sight."""
        if value in self._value_map:
            return self._value_map[value]
        # If the value is not in the value map, it must be a graph input.
        assert value.producer() is None, f"Value {value} has no entry in the value map"
        new_value = ir.Value(
            name=value.name,
            type=value.type,
            shape=value.shape,
            doc_string=value.doc_string,
            const_value=value.const_value,
        )
        # Cache the copy so later references resolve to the same value.
        self._value_map[value] = new_value
        return new_value

    def clone_optional_value(self, value: ir.Value | None) -> ir.Value | None:
        """Like clone_value, but passes None through (omitted optional input)."""
        if value is None:
            return None
        return self.clone_value(value)

    def clone_attr(self, key: str, attr: ir.Attr) -> ir.Attr | None:
        """Clone an attribute, resolving reference-attributes via the attr map.

        Returns None when a reference-attribute has no binding (optional attribute
        omitted at the call site).
        """
        if not attr.is_ref():
            # Concrete attribute: only GRAPH/GRAPHS need deep cloning; everything
            # else is immutable enough to share.
            if attr.type == ir.AttributeType.GRAPH:
                graph = self.clone_graph(attr.as_graph())
                return ir.Attr(key, ir.AttributeType.GRAPH, graph, doc_string=attr.doc_string)
            elif attr.type == ir.AttributeType.GRAPHS:
                graphs = [self.clone_graph(graph) for graph in attr.as_graphs()]
                return ir.Attr(
                    key, ir.AttributeType.GRAPHS, graphs, doc_string=attr.doc_string
                )
            return attr
        assert attr.is_ref()
        ref_attr_name = attr.ref_attr_name
        if ref_attr_name in self._attr_map:
            ref_attr = self._attr_map[ref_attr_name]
            if not ref_attr.is_ref():
                # Bound to a concrete value: materialize it under this node's key.
                return ir.Attr(
                    key, ref_attr.type, ref_attr.value, doc_string=ref_attr.doc_string
                )
            # Bound to another reference: re-point the reference one level up.
            assert ref_attr.ref_attr_name is not None
            return ir.RefAttr(
                key, ref_attr.ref_attr_name, ref_attr.type, doc_string=ref_attr.doc_string
            )
        # Note that if a function has an attribute-parameter X, and a call (node) to the function
        # has no attribute X, all references to X in nodes inside the function body will be
        # removed. This is just the ONNX representation of optional-attributes.
        return None

    def clone_node(self, node: ir.Node) -> ir.Node:
        """Create a copy of ``node`` with inputs/attributes substituted and unique names."""
        new_inputs = [self.clone_optional_value(input) for input in node.inputs]
        new_attributes = [
            new_value
            for key, value in node.attributes.items()
            if (new_value := self.clone_attr(key, value)) is not None
        ]
        new_name = node.name
        if new_name is not None:
            new_name = _make_unique_name(
                new_name, self._call_stack, self._inliner.used_node_names
            )

        new_metadata = {**self._metadata_props, **node.metadata_props}
        # TODO: For now, node metadata overrides callnode metadata if there is a conflict.
        # Do we need to preserve both?

        new_node = ir.Node(
            node.domain,
            node.op_type,
            new_inputs,
            new_attributes,
            overload=node.overload,
            num_outputs=len(node.outputs),
            graph=None,
            name=new_name,
            doc_string=node.doc_string,  # type: ignore
            metadata_props=new_metadata,
        )
        # Record output replacements and give the new outputs unique names.
        new_outputs = new_node.outputs
        for i, output in enumerate(node.outputs):
            self._value_map[output] = new_outputs[i]
            old_name = output.name if output.name is not None else f"output_{i}"
            new_outputs[i].name = _make_unique_name(
                old_name, self._call_stack, self._inliner.used_value_names
            )

        # Remember under which call stack this node was created (for nested inlining).
        self._inliner.node_context[new_node] = self._call_stack

        return new_node

    def clone_graph(self, graph: ir.Graph) -> ir.Graph:
        """Deep-clone a (sub)graph, substituting values via the value map."""
        input_values = [self.clone_value(v) for v in graph.inputs]
        nodes = [self.clone_node(node) for node in graph]
        initializers = [self.clone_value(init) for init in graph.initializers.values()]
        output_values = [
            self.clone_value(v) for v in graph.outputs
        ]  # Looks up already cloned values

        return ir.Graph(
            input_values,  # type: ignore
            output_values,  # type: ignore
            nodes=nodes,
            initializers=initializers,  # type: ignore
            doc_string=graph.doc_string,
            opset_imports=graph.opset_imports,
            name=graph.name,
            metadata_props=graph.metadata_props,
        )
173
+
174
+
175
+ def _abbreviate(
176
+ function_ids: Iterable[ir.OperatorIdentifier],
177
+ ) -> dict[ir.OperatorIdentifier, str]:
178
+ """Create a short unambiguous abbreviation for all function ids."""
179
+
180
+ def id_abbreviation(id: ir.OperatorIdentifier) -> str:
181
+ """Create a short unambiguous abbreviation for a function id."""
182
+ domain, name, overload = id
183
+ # Omit the domain, if it remains unambiguous after omitting it.
184
+ if any(x[0] != domain and x[1] == name and x[2] == overload for x in function_ids):
185
+ short_domain = domain + "_"
186
+ else:
187
+ short_domain = ""
188
+ if overload != "":
189
+ return short_domain + name + "_" + overload
190
+ return short_domain + name
191
+
192
+ return {id: id_abbreviation(id) for id in function_ids}
193
+
194
+
195
@dataclasses.dataclass
class InlinePassResult(ir.passes.PassResult):
    """Result of :class:`InlinePass`, extending PassResult with inlining statistics."""

    # Number of call sites inlined, per function id.
    id_count: dict[ir.OperatorIdentifier, int]
198
+
199
+
200
class InlinePass(ir.passes.InPlacePass):
    """Inline model local functions to the main graph and clear function definitions."""

    def __init__(self) -> None:
        super().__init__()
        # Function definitions of the model being processed, keyed by op identifier.
        self._functions: dict[ir.OperatorIdentifier, ir.Function] = {}
        # Short unambiguous name per function id, used to build call-site ids.
        self._function_id_abbreviations: dict[ir.OperatorIdentifier, str] = {}
        # Opset imports of the model; extended with function opsets as they are inlined.
        self._opset_imports: dict[str, int] = {}
        # Names already in use, so generated names stay unique model-wide.
        self.used_value_names: set[str] = set()
        self.used_node_names: set[str] = set()
        # Call stack under which each (cloned) node was created.
        self.node_context: dict[ir.Node, CallStack] = {}

    def _reset(self, model: ir.Model) -> None:
        """Reset per-model state before processing ``model``."""
        self._functions = model.functions
        self._function_id_abbreviations = _abbreviate(self._functions.keys())
        self._opset_imports = model.opset_imports
        self.used_value_names = set()
        self.used_node_names = set()
        self.node_context = {}

    def call(self, model: ir.Model) -> InlinePassResult:
        """Inline all calls in the main graph, then clear the model's functions."""
        self._reset(model)
        id_count = self._inline_calls_in(model.graph)
        model.functions.clear()
        return InlinePassResult(model, modified=bool(id_count), id_count=id_count)

    def _instantiate_call(self, node: ir.Node, call_site_id: CallSiteId) -> NodeReplacement:
        """Instantiate the function called by ``node``.

        Returns the cloned function-body nodes and the values that replace
        the call node's outputs.
        """
        id = node.op_identifier()
        function = self._functions[id]

        # check opset compatibility and update the opset imports
        for key, value in function.opset_imports.items():
            if key not in self._opset_imports:
                self._opset_imports[key] = value
            elif self._opset_imports[key] != value:
                raise ValueError(
                    f"Opset mismatch: {key} {self._opset_imports[key]} != {value}"
                )

        # Identify substitutions for both inputs and attributes of the function:
        attributes: dict[str, ir.Attr] = node.attributes
        # Attribute parameters not supplied at the call site fall back to the
        # function's declared defaults (when present).
        default_attr_values = {
            attr.name: attr
            for attr in function.attributes.values()
            if attr.name not in attributes and attr.value is not None
        }
        if default_attr_values:
            attributes = {**attributes, **default_attr_values}
        if any(
            attr.type in {ir.AttributeType.GRAPH, ir.AttributeType.GRAPHS}
            for attr in attributes.values()
        ):
            raise ValueError(
                "Inliner does not support graph attribute parameters to functions"
            )

        if len(node.inputs) > len(function.inputs):
            raise ValueError(f"Input mismatch: {len(node.inputs)} > {len(function.inputs)}")
        value_map = {}
        for i, input in enumerate(node.inputs):
            value_map[function.inputs[i]] = input
        # Trailing formal inputs with no actual input map to None (omitted optionals).
        for i in range(len(node.inputs), len(function.inputs)):
            value_map[function.inputs[i]] = None

        # Identify call-stack for node, used to generate unique names.
        call_stack = self.node_context.get(node, [])
        new_call_stack = [*call_stack, call_site_id]

        cloner = _CopyReplace(self, attributes, value_map, node.metadata_props, new_call_stack)

        # iterate over the nodes in the function, creating a copy of each node
        # and replacing inputs with the corresponding values in the value map.
        # Update the value map with the new values.

        nodes = [cloner.clone_node(node) for node in function]
        output_values = [value_map[output] for output in function.outputs]
        return nodes, output_values  # type: ignore

    def _inline_calls_in(self, graph: ir.Graph) -> dict[ir.OperatorIdentifier, int]:
        """Inline all function calls in ``graph``, recursing into subgraph attributes.

        Returns the number of call sites inlined per function id.
        """
        # Seed the used-name set with graph inputs and initializers.
        for input in graph.inputs:
            if input.name is not None:
                self.used_value_names.add(input.name)
        for initializer in graph.initializers:
            self.used_value_names.add(initializer)

        # Pre-processing:
        # * Count the number of times each function is called in the graph.
        #   This is used for disambiguating names of values in the inlined functions.
        # * And identify names of values that are used in the graph.
        id_count: dict[ir.OperatorIdentifier, int] = defaultdict(int)
        for node in graph:
            if node.name:
                self.used_node_names.add(node.name)
            id = node.op_identifier()
            if id in self._functions:
                id_count[id] += 1
            for output in node.outputs:
                if output.name is not None:
                    self.used_value_names.add(output.name)
        next_id: dict[ir.OperatorIdentifier, int] = defaultdict(int)
        for node in graph:
            id = node.op_identifier()
            if id in self._functions:
                # If there are multiple calls to same function, we use a prefix to disambiguate
                # the different call-sites:
                if id_count[id] > 1:
                    call_site_prefix = f"_{next_id[id]}"
                    next_id[id] += 1
                else:
                    call_site_prefix = ""
                # Prefer the call node's own name as the call-site id.
                call_site = node.name or (
                    self._function_id_abbreviations[id] + call_site_prefix
                )
                nodes, values = self._instantiate_call(node, call_site)
                _ir_convenience.replace_nodes_and_values(
                    graph,
                    insertion_point=node,
                    old_nodes=[node],
                    new_nodes=nodes,
                    old_values=node.outputs,
                    new_values=values,
                )
            else:
                # Not a call to a model-local function: recurse into subgraphs.
                for attr in node.attributes.values():
                    if not isinstance(attr, ir.Attr):
                        continue
                    if attr.type == ir.AttributeType.GRAPH:
                        self._inline_calls_in(attr.as_graph())
                    elif attr.type == ir.AttributeType.GRAPHS:
                        for g in attr.as_graphs():
                            self._inline_calls_in(g)
        return id_count
@@ -0,0 +1,57 @@
1
+ # Copyright (c) ONNX Project Contributors
2
+ # SPDX-License-Identifier: Apache-2.0
3
+ """Passes for debugging purposes."""
4
+
5
+ from __future__ import annotations
6
+
7
+ __all__ = [
8
+ "CheckerPass",
9
+ ]
10
+
11
+ from typing import Literal
12
+
13
+ import onnx
14
+
15
+ import onnx_ir as ir
16
+ from onnx_ir.passes.common import _c_api_utils
17
+
18
+
19
class CheckerPass(ir.passes.PassBase):
    """Run onnx checker on the model."""

    def __init__(
        self,
        full_check: bool = False,
        skip_opset_compatibility_check: bool = False,
        check_custom_domain: bool = False,
    ):
        super().__init__()
        self.full_check = full_check
        self.skip_opset_compatibility_check = skip_opset_compatibility_check
        self.check_custom_domain = check_custom_domain

    @property
    def in_place(self) -> Literal[True]:
        """This pass does not create a new model."""
        return True

    @property
    def changes_input(self) -> Literal[False]:
        """This pass does not change the input model."""
        return False

    def call(self, model: ir.Model) -> ir.passes.PassResult:
        """Run the onnx checker on the model."""

        def _check_proto(proto: onnx.ModelProto) -> None:
            """Invoke onnx.checker with the configured options."""
            onnx.checker.check_model(
                proto,
                full_check=self.full_check,
                skip_opset_compatibility_check=self.skip_opset_compatibility_check,
                check_custom_domain=self.check_custom_domain,
            )

        _c_api_utils.call_onnx_api(func=_check_proto, model=model)
        # The model is not modified
        return ir.passes.PassResult(model, False)
@@ -0,0 +1,112 @@
1
+ # Copyright (c) ONNX Project Contributors
2
+ # SPDX-License-Identifier: Apache-2.0
3
+ """Shape inference pass using onnx.shape_inference."""
4
+
5
+ from __future__ import annotations
6
+
7
+ __all__ = [
8
+ "ShapeInferencePass",
9
+ "infer_shapes",
10
+ ]
11
+
12
+ import logging
13
+
14
+ import onnx
15
+
16
+ import onnx_ir as ir
17
+ from onnx_ir.passes.common import _c_api_utils
18
+
19
+ logger = logging.getLogger(__name__)
20
+
21
+
22
def _merge_func(model: ir.Model, inferred_proto: onnx.ModelProto) -> bool:
    """Merge the shape inferred model with the original model.

    Shapes and dtypes from the inferred model overwrite those in the original
    model's values when they differ and the inferred ones are not None.

    Args:
        model: The original IR model, updated in place.
        inferred_proto: The ONNX model with shapes and types inferred.

    Returns:
        True if any value's shape or dtype in ``model`` was modified.
    """
    inferred_model = ir.serde.deserialize_model(inferred_proto)
    modified = False
    for original_graph, inferred_graph in zip(model.graphs(), inferred_model.graphs()):
        original_values = ir.convenience.create_value_mapping(original_graph)
        inferred_values = ir.convenience.create_value_mapping(inferred_graph)
        for name, value in original_values.items():
            if name in inferred_values:
                inferred_value = inferred_values[name]
                if value.shape != inferred_value.shape and inferred_value.shape is not None:
                    value.shape = inferred_value.shape
                    modified = True
                if value.dtype != inferred_value.dtype and inferred_value.dtype is not None:
                    value.dtype = inferred_value.dtype
                    modified = True
            else:
                # Inference may drop values; warn rather than fail.
                logger.warning(
                    "Value %s not found in inferred graph %s", name, inferred_graph.name
                )
    return modified
51
+
52
+
53
class ShapeInferencePass(ir.passes.InPlacePass):
    """This pass performs shape inference on the graph."""

    def __init__(
        self, check_type: bool = True, strict_mode: bool = True, data_prop: bool = True
    ) -> None:
        """Initialize the shape inference pass.

        If inference fails, the model is left unchanged.

        Args:
            check_type: If True, check the types of the inputs and outputs.
            strict_mode: If True, use strict mode for shape inference.
            data_prop: If True, use data propagation for shape inference.
        """
        super().__init__()
        self.check_type = check_type
        self.strict_mode = strict_mode
        self.data_prop = data_prop

    def call(self, model: ir.Model) -> ir.passes.PassResult:
        """Run onnx shape inference and merge the results into ``model`` in place."""

        def partial_infer_shapes(proto: onnx.ModelProto) -> onnx.ModelProto:
            return onnx.shape_inference.infer_shapes(
                proto,
                check_type=self.check_type,
                strict_mode=self.strict_mode,
                data_prop=self.data_prop,
            )

        try:
            inferred_model_proto = _c_api_utils.call_onnx_api(partial_infer_shapes, model)
        except Exception as e:  # pylint: disable=broad-exception-caught
            # Fix: the original call had a "%s" placeholder with no positional
            # argument (exc_info is keyword-only), which made this logging call
            # itself raise a formatting error. Pass the exception as the arg.
            logger.warning(
                "Shape inference failed: %s. Model is left unchanged", e, exc_info=e
            )
            return ir.passes.PassResult(model, False)

        modified = _merge_func(model, inferred_model_proto)
        return ir.passes.PassResult(model, modified=modified)
90
+
91
+
92
def infer_shapes(
    model: ir.Model,
    *,
    check_type: bool = True,
    strict_mode: bool = True,
    data_prop: bool = True,
) -> ir.Model:
    """Perform shape inference on the model.

    Args:
        model: The model to perform shape inference on.
        check_type: If True, check the types of the inputs and outputs.
        strict_mode: If True, use strict mode for shape inference.
        data_prop: If True, use data propagation for shape inference.

    Returns:
        The model with shape inference applied.
    """
    # Delegate to the pass and unwrap the resulting model.
    inference_pass = ShapeInferencePass(
        check_type=check_type, strict_mode=strict_mode, data_prop=data_prop
    )
    result = inference_pass(model)
    return result.model
@@ -0,0 +1,33 @@
1
+ # Copyright (c) ONNX Project Contributors
2
+ # SPDX-License-Identifier: Apache-2.0
3
+ """Pass for topologically sorting the graphs."""
4
+
5
+ from __future__ import annotations
6
+
7
+ __all__ = [
8
+ "TopologicalSortPass",
9
+ ]
10
+
11
+
12
+ import onnx_ir as ir
13
+
14
+
15
class TopologicalSortPass(ir.passes.InPlacePass):
    """Topologically sort graphs and functions in a model."""

    def call(self, model: ir.Model) -> ir.passes.PassResult:
        """Sort the main graph and every function; report whether any node moved."""
        # Snapshot node order before and after sorting, across graph and functions.
        before = list(model.graph)
        model.graph.sort()
        after = list(model.graph)
        for function in model.functions.values():
            before.extend(function)
            function.sort()
            after.extend(function)

        # The model was modified iff some node's position changed.
        modified = any(old is not new for old, new in zip(before, after))
        return ir.passes.PassResult(model=model, modified=modified)