onnx-ir 0.1.15 (onnx_ir-0.1.15-py3-none-any.whl)

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (53)
  1. onnx_ir/__init__.py +176 -0
  2. onnx_ir/_cloner.py +229 -0
  3. onnx_ir/_convenience/__init__.py +558 -0
  4. onnx_ir/_convenience/_constructors.py +291 -0
  5. onnx_ir/_convenience/_extractor.py +191 -0
  6. onnx_ir/_core.py +4435 -0
  7. onnx_ir/_display.py +54 -0
  8. onnx_ir/_enums.py +474 -0
  9. onnx_ir/_graph_comparison.py +23 -0
  10. onnx_ir/_graph_containers.py +373 -0
  11. onnx_ir/_io.py +133 -0
  12. onnx_ir/_linked_list.py +284 -0
  13. onnx_ir/_metadata.py +45 -0
  14. onnx_ir/_name_authority.py +72 -0
  15. onnx_ir/_polyfill.py +26 -0
  16. onnx_ir/_protocols.py +627 -0
  17. onnx_ir/_safetensors/__init__.py +510 -0
  18. onnx_ir/_tape.py +242 -0
  19. onnx_ir/_thirdparty/asciichartpy.py +310 -0
  20. onnx_ir/_type_casting.py +89 -0
  21. onnx_ir/_version_utils.py +48 -0
  22. onnx_ir/analysis/__init__.py +21 -0
  23. onnx_ir/analysis/_implicit_usage.py +74 -0
  24. onnx_ir/convenience.py +38 -0
  25. onnx_ir/external_data.py +459 -0
  26. onnx_ir/passes/__init__.py +41 -0
  27. onnx_ir/passes/_pass_infra.py +351 -0
  28. onnx_ir/passes/common/__init__.py +54 -0
  29. onnx_ir/passes/common/_c_api_utils.py +76 -0
  30. onnx_ir/passes/common/clear_metadata_and_docstring.py +60 -0
  31. onnx_ir/passes/common/common_subexpression_elimination.py +207 -0
  32. onnx_ir/passes/common/constant_manipulation.py +230 -0
  33. onnx_ir/passes/common/default_attributes.py +99 -0
  34. onnx_ir/passes/common/identity_elimination.py +120 -0
  35. onnx_ir/passes/common/initializer_deduplication.py +179 -0
  36. onnx_ir/passes/common/inliner.py +223 -0
  37. onnx_ir/passes/common/naming.py +280 -0
  38. onnx_ir/passes/common/onnx_checker.py +57 -0
  39. onnx_ir/passes/common/output_fix.py +141 -0
  40. onnx_ir/passes/common/shape_inference.py +112 -0
  41. onnx_ir/passes/common/topological_sort.py +37 -0
  42. onnx_ir/passes/common/unused_removal.py +215 -0
  43. onnx_ir/py.typed +1 -0
  44. onnx_ir/serde.py +2043 -0
  45. onnx_ir/tape.py +15 -0
  46. onnx_ir/tensor_adapters.py +210 -0
  47. onnx_ir/testing.py +197 -0
  48. onnx_ir/traversal.py +118 -0
  49. onnx_ir-0.1.15.dist-info/METADATA +68 -0
  50. onnx_ir-0.1.15.dist-info/RECORD +53 -0
  51. onnx_ir-0.1.15.dist-info/WHEEL +5 -0
  52. onnx_ir-0.1.15.dist-info/licenses/LICENSE +202 -0
  53. onnx_ir-0.1.15.dist-info/top_level.txt +1 -0
onnx_ir/_convenience/_constructors.py
@@ -0,0 +1,291 @@
# Copyright (c) ONNX Project Contributors
# SPDX-License-Identifier: Apache-2.0
"""Convenience constructors for IR objects."""

from __future__ import annotations

__all__ = [
    "tensor",
    "node",
]

import typing
from collections.abc import Mapping, Sequence

import numpy as np
import onnx  # noqa: TID251

from onnx_ir import _convenience, _core, _enums, _protocols, serde, tensor_adapters

if typing.TYPE_CHECKING:
    import numpy.typing as npt

    import onnx_ir as ir


def tensor(
    value: npt.ArrayLike | onnx.TensorProto | ir.DLPackCompatible | ir.ArrayCompatible,
    dtype: ir.DataType | None = None,
    name: str | None = None,
    doc_string: str | None = None,
) -> _protocols.TensorProtocol:
    """Create a tensor value from an ArrayLike object or a TensorProto.

    The dtype must match the value. Reinterpretation of the value is
    not supported, unless the value is a plain Python object, in which case
    it is converted to a numpy array with the given dtype.

    ``value`` can be a numpy array, a plain Python object, or a TensorProto.

    .. warning::
        For 4bit dtypes, the value must be unpacked. Use :class:`~onnx_ir.PackedTensor`
        to create a tensor with packed data.

    Example::

        >>> import onnx_ir as ir
        >>> import numpy as np
        >>> import ml_dtypes
        >>> import onnx
        >>> ir.tensor(np.array([1, 2, 3], dtype=np.int16))
        Tensor<INT16,[3]>(array([1, 2, 3], dtype=int16), name=None)
        >>> ir.tensor([1, 2, 3], dtype=ir.DataType.BFLOAT16)
        Tensor<BFLOAT16,[3]>(array([1, 2, 3], dtype=bfloat16), name=None)
        >>> tp_tensor = ir.tensor(onnx.helper.make_tensor("tensor", onnx.TensorProto.FLOAT, dims=[], vals=[0.5]))
        >>> tp_tensor.numpy()
        array(0.5, dtype=float32)
        >>> import torch
        >>> ir.tensor(torch.tensor([1.0, 2.0]), name="torch_tensor")
        TorchTensor<FLOAT,[2]>(tensor([1., 2.]), name='torch_tensor')

    Args:
        value: The object to create the tensor from. Can be an array-like object, a TensorProto, or a plain Python object.
        dtype: The data type of the tensor.
        name: The name of the tensor.
        doc_string: The documentation string of the tensor.

    Returns:
        A tensor value.

    Raises:
        ValueError: If the dtype does not match the value when value is not a plain Python
            object like ``list[int]``.
    """
    if isinstance(value, _protocols.TensorProtocol):
        if dtype is not None and dtype != value.dtype:
            raise ValueError(
                f"The dtype must match the value when value is a Tensor. dtype={dtype}, value.dtype={value.dtype}. "
                "You do not have to specify the dtype when value is a Tensor."
            )
        return value
    if isinstance(value, onnx.TensorProto):
        tensor_ = serde.deserialize_tensor(value)
        if name is not None:
            tensor_.name = name
        if doc_string is not None:
            tensor_.doc_string = doc_string
        if dtype is not None and dtype != tensor_.dtype:
            raise ValueError(
                f"The dtype must match the value when value is a TensorProto. dtype={dtype}, value.data_type={tensor_.dtype}. "
                "You do not have to specify the dtype when value is a TensorProto."
            )
        return tensor_
    elif str(type(value)) == "<class 'torch.Tensor'>":
        # NOTE: We use str(type(...)) and do not import torch for type checking
        # as it creates overhead during import
        return tensor_adapters.TorchTensor(value, name=name, doc_string=doc_string)  # type: ignore[arg-type]
    elif isinstance(value, (_protocols.DLPackCompatible, _protocols.ArrayCompatible)):
        return _core.Tensor(value, dtype=dtype, name=name, doc_string=doc_string)

    # Plain (numerical) Python object. Determine the numpy dtype and use np.array to construct the tensor
    if dtype is not None:
        if not isinstance(dtype, _enums.DataType):
            raise TypeError(f"dtype must be an instance of DataType. dtype={dtype}")
        numpy_dtype = dtype.numpy()
    elif isinstance(value, Sequence) and not value:
        raise ValueError("dtype must be specified when value is an empty sequence.")
    elif isinstance(value, int) and not isinstance(value, bool):
        # Specify int64 for ints because on Windows this may be int32
        numpy_dtype = np.dtype(np.int64)
    elif isinstance(value, float):
        # If the value is a single float, we use np.float32 as the default dtype
        numpy_dtype = np.dtype(np.float32)
    elif isinstance(value, Sequence) and value:
        if all((isinstance(elem, int) and not isinstance(elem, bool)) for elem in value):
            numpy_dtype = np.dtype(np.int64)
        elif all(isinstance(elem, float) for elem in value):
            # If the value is a sequence of floats, we use np.float32 as the default dtype
            numpy_dtype = np.dtype(np.float32)
        else:
            numpy_dtype = None
    else:
        numpy_dtype = None

    array = np.array(value, dtype=numpy_dtype)

    # Handle string tensors by encoding them
    if isinstance(value, str) or (
        isinstance(value, Sequence) and value and all(isinstance(elem, str) for elem in value)
    ):
        array = np.strings.encode(array, encoding="utf-8")
        return _core.StringTensor(
            array,
            shape=_core.Shape(array.shape),
            name=name,
            doc_string=doc_string,
        )

    return _core.Tensor(
        array,
        dtype=dtype,
        shape=_core.Shape(array.shape),
        name=name,
        doc_string=doc_string,
    )


def node(
    op_type: str,
    inputs: Sequence[ir.Value | None],
    attributes: Mapping[str, _convenience.SupportedAttrTypes] | None = None,
    *,
    domain: str = "",
    overload: str = "",
    num_outputs: int | None = None,
    outputs: Sequence[ir.Value] | None = None,
    version: int | None = None,
    graph: ir.Graph | None = None,
    name: str | None = None,
    doc_string: str | None = None,
    metadata_props: dict[str, str] | None = None,
) -> ir.Node:
    """Create a :class:`~onnx_ir.Node`.

    This is a convenience constructor for creating a Node that supports Python
    objects as attributes.

    Example::

        >>> import onnx_ir as ir
        >>> input_a = ir.val("A", shape=[1, 2], type=ir.TensorType(ir.DataType.INT32))
        >>> input_b = ir.val("B", shape=[1, 2], type=ir.TensorType(ir.DataType.INT32))
        >>> node = ir.node(
        ...     "SomeOp",
        ...     inputs=[input_a, input_b],
        ...     attributes={"alpha": 1.0, "some_list": [1, 2, 3]},
        ...     domain="some.domain",
        ...     name="node_name"
        ... )
        >>> node.op_type
        'SomeOp'

    Args:
        op_type: The name of the operator.
        inputs: The input values. When an input is None, it is an empty input.
        attributes: The attributes. RefAttr can be used only when the node is defined in a Function.
        overload: The overload name when the node is invoking a function.
        domain: The domain of the operator. For onnx operators, this is an empty string.
        num_outputs: The number of outputs of the node. If not specified, the number is 1.
        outputs: The output values. If None, the outputs are created during initialization.
        version: The version of the operator. If None, the version is unspecified and will follow that of the graph.
        graph: The graph that the node belongs to. If None, the node is not added to any graph.
            A `Node` must belong to zero or one graph.
        name: The name of the node. If None, the node is anonymous.
        doc_string: The documentation string.
        metadata_props: The metadata properties.

    Returns:
        A node with the given op_type and inputs.
    """
    if attributes is None:
        attrs: Sequence[ir.Attr] = ()
    else:
        attrs = _convenience.convert_attributes(attributes)
    return _core.Node(
        domain=domain,
        op_type=op_type,
        inputs=inputs,
        attributes=attrs,
        overload=overload,
        num_outputs=num_outputs,
        outputs=outputs,
        version=version,
        graph=graph,
        name=name,
        doc_string=doc_string,
        metadata_props=metadata_props,
    )


def val(
    name: str | None,
    dtype: ir.DataType | None = None,
    shape: ir.Shape | Sequence[int | str | None] | None = None,
    *,
    type: ir.TypeProtocol | None = None,
    const_value: ir.TensorProtocol | None = None,
    metadata_props: dict[str, str] | None = None,
) -> ir.Value:
    """Create a :class:`~onnx_ir.Value` with the given name and type.

    This is a convenience constructor for creating a Value that allows you to specify
    dtype and shape in a more relaxed manner. Whereas to create a Value directly, you
    need to create a :class:`~onnx_ir.TypeProtocol` and :class:`~onnx_ir.Shape` object
    first, this function allows you to specify dtype as a :class:`~onnx_ir.DataType`
    and shape as a sequence of integers or symbolic dimensions.

    Example::

        >>> import onnx_ir as ir
        >>> t = ir.val("x", ir.DataType.FLOAT, ["N", 42, 3])
        >>> t.name
        'x'
        >>> t.type
        Tensor(FLOAT)
        >>> t.shape
        Shape([SymbolicDim(N), 42, 3])

    .. versionadded:: 0.1.9

    Args:
        name: The name of the value.
        dtype: The data type of the TensorType of the value. This is used only when type is None.
        shape: The shape of the value.
        type: The type of the value. Only one of dtype and type can be specified.
        const_value: The constant tensor that initializes the value. Supply this argument
            when you want to create an initializer. The type and shape can be obtained from the tensor.
        metadata_props: The metadata properties that will be serialized to the ONNX proto.

    Returns:
        A Value object.
    """
    if const_value is not None:
        const_tensor_type = _core.TensorType(const_value.dtype)
        if type is not None and type != const_tensor_type:
            raise ValueError(
                f"The type does not match the const_value. type={type} but const_value has type {const_tensor_type}. "
                "You do not have to specify the type when const_value is provided."
            )
        if dtype is not None and dtype != const_value.dtype:
            raise ValueError(
                f"The dtype does not match the const_value. dtype={dtype} but const_value has dtype {const_value.dtype}. "
                "You do not have to specify the dtype when const_value is provided."
            )
        if shape is not None and _core.Shape(shape) != const_value.shape:
            raise ValueError(
                f"The shape does not match the const_value. shape={shape} but const_value has shape {const_value.shape}. "
                "You do not have to specify the shape when const_value is provided."
            )
        return _core.Value(
            name=name,
            type=const_tensor_type,
            shape=_core.Shape(const_value.shape),  # type: ignore
            const_value=const_value,
            metadata_props=metadata_props,
        )

    if type is None and dtype is not None:
        type = _core.TensorType(dtype)
    if shape is not None and not isinstance(shape, _core.Shape):
        shape = _core.Shape(shape)
    return _core.Value(name=name, type=type, shape=shape, metadata_props=metadata_props)
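
Taken together, tensor(), node(), and val() cover most hand-built IR fragments. Below is a minimal usage sketch that composes only the constructors defined in this file; the op type, tensor values, and names are illustrative placeholders, not anything shipped in the package.

    import numpy as np
    import onnx_ir as ir

    # An initializer: a Value whose const_value is a Tensor built by ir.tensor().
    # Its type and shape are derived from the tensor, so they are not passed explicitly.
    weight = ir.val("W", const_value=ir.tensor(np.ones((2, 2), dtype=np.float32), name="W"))

    # A graph input with a symbolic batch dimension and an explicit dtype.
    x = ir.val("X", dtype=ir.DataType.FLOAT, shape=["N", 2])

    # A node wiring the two together; its single output Value is created by the constructor.
    matmul = ir.node("MatMul", inputs=[x, weight], name="matmul_node")
    y = matmul.outputs[0]

The resulting output value can then be named, typed, and used as a graph output when the surrounding graph is assembled.
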
onnx_ir/_convenience/_extractor.py
@@ -0,0 +1,191 @@
# Copyright (c) ONNX Project Contributors
# SPDX-License-Identifier: Apache-2.0
"""Utilities for extracting subgraphs from a graph."""

from __future__ import annotations

import itertools
from collections.abc import Collection, Sequence
from typing import Union

import onnx_ir as ir

GraphLike = Union["ir.Graph", "ir.Function", "ir.GraphView"]


def _collect_all_external_values(parent_graph: ir.Graph, graph: ir.Graph) -> set[ir.Value]:
    """Collects all values used in ``graph`` that belong to ``parent_graph``.

    Args:
        parent_graph: The parent graph to which collected values must belong.
        graph: The graph-like object to collect values from.

    Returns:
        A set of :class:`~onnx_ir.Value` objects belonging to ``parent_graph``.
    """
    values: set[ir.Value] = set()
    for node in ir.traversal.RecursiveGraphIterator(graph):
        for val in node.inputs:
            if val is None:
                continue
            if val.graph is parent_graph:
                values.add(val)
    return values


def _find_subgraph_bounded_by_values(
    graph: GraphLike,
    inputs: Collection[ir.Value],
    outputs: Collection[ir.Value],
    parent_graph: ir.Graph,
) -> tuple[list[ir.Node], Collection[ir.Value]]:
    """Finds the subgraph bounded by the given inputs and outputs.

    Args:
        graph: The graph to search.
        inputs: The inputs to the subgraph.
        outputs: The outputs of the subgraph.
        parent_graph: The parent graph of the subgraph.

    Returns:
        A list of nodes in the subgraph and the initializers used.

    Raises:
        ValueError: If the subgraph is not properly bounded by the given inputs and outputs.
    """
    if isinstance(graph, ir.Function):
        initialized_values: set[ir.Value] = set()
    else:
        initialized_values = {val for val in inputs if val.is_initializer()}
    node_index = {node: idx for idx, node in enumerate(graph)}
    all_nodes = []
    value_stack: list[ir.Value] = [*outputs]
    visited_nodes: set[ir.Node] = set()
    visited_values: set[ir.Value] = set(inputs)

    while value_stack:
        value = value_stack.pop()
        if value in visited_values:
            continue
        if value.is_initializer():
            # Record the initializer
            initialized_values.add(value)

        visited_values.add(value)

        if (node := value.producer()) is not None:
            if node not in visited_nodes:
                visited_nodes.add(node)
                all_nodes.append(node)
                for input in node.inputs:
                    if input not in visited_values and input is not None:
                        value_stack.append(input)
                for attr in node.attributes.values():
                    if attr.type == ir.AttributeType.GRAPH:
                        values = _collect_all_external_values(parent_graph, attr.as_graph())
                        for val in values:
                            if val not in visited_values:
                                value_stack.append(val)
                    elif attr.type == ir.AttributeType.GRAPHS:
                        for g in attr.as_graphs():
                            values = _collect_all_external_values(parent_graph, g)
                            for val in values:
                                if val not in visited_values:
                                    value_stack.append(val)

    # Validate that the subgraph is properly bounded
    # Collect all values at the input frontier (used by the subgraph but not produced by it)
    # The frontier can only contain graph inputs or initializers (values with no producer)
    input_frontier: set[ir.Value] = set()
    for node in visited_nodes:
        for input_val in node.inputs:
            if input_val is None:
                continue
            # If this value is not produced by any node in the subgraph
            producer = input_val.producer()
            if producer is None or producer not in visited_nodes:
                input_frontier.add(input_val)

    # Check for graph inputs that weren't specified in the inputs parameter
    # (initializers are allowed, but unspecified graph inputs mean the subgraph is unbounded)
    unspecified_graph_inputs: list[ir.Value] = []
    inputs_set = set(inputs)
    for val in sorted(input_frontier, key=lambda v: v.name or ""):
        if val not in inputs_set and not val.is_initializer():
            unspecified_graph_inputs.append(val)

    if unspecified_graph_inputs:
        value_names = [val.name or "<None>" for val in unspecified_graph_inputs]
        raise ValueError(
            f"The subgraph is not properly bounded by the specified inputs and outputs. "
            f"The following graph inputs are required but not provided: {', '.join(value_names)}"
        )

    # Preserve the original order
    all_nodes.sort(key=lambda n: node_index[n])
    return all_nodes, initialized_values


def extract(
    graph_like: GraphLike,
    /,
    inputs: Sequence[ir.Value | str],
    outputs: Sequence[ir.Value | str],
) -> ir.Graph:
    """Extracts a subgraph from the given graph-like object.

    .. versionadded:: 0.1.14

    Args:
        graph_like: The graph-like object to extract from.
        inputs: The inputs to the subgraph. Can be Value objects or their names.
        outputs: The outputs of the subgraph. Can be Value objects or their names.

    Returns:
        The extracted subgraph as a new :class:`~onnx_ir.Graph` object.

    Raises:
        ValueError: If any of the inputs or outputs are not found in the graph.
        ValueError: If the subgraph is not properly bounded by the given inputs and outputs.
    """
    if isinstance(graph_like, ir.Function):
        graph: ir.Graph | ir.GraphView = graph_like.graph
    else:
        graph = graph_like
    values = ir.convenience.create_value_mapping(graph, include_subgraphs=False)
    is_graph_view = isinstance(graph_like, ir.GraphView)
    for val in itertools.chain(inputs, outputs):
        if isinstance(val, ir.Value):
            if not is_graph_view and val.graph is not graph:
                graph_name = graph.name if graph.name is not None else "unnamed graph"
                raise ValueError(
                    f"Value '{val}' does not belong to the given "
                    f"{graph_like.__class__.__name__} ({graph_name})."
                )
        else:
            if val not in values:
                raise ValueError(f"Value with name '{val}' not found in the graph.")

    input_vals = [values[val] if isinstance(val, str) else val for val in inputs]
    output_vals = [values[val] if isinstance(val, str) else val for val in outputs]
    # Find the owning graph of the outputs to set as the parent graph
    if not output_vals:
        raise ValueError("At least one output must be provided to extract a subgraph.")
    parent_graph = output_vals[0].graph
    assert parent_graph is not None
    extracted_nodes, initialized_values = _find_subgraph_bounded_by_values(
        graph_like, input_vals, output_vals, parent_graph=parent_graph
    )

    graph_view = ir.GraphView(
        input_vals,
        output_vals,
        nodes=extracted_nodes,
        initializers=tuple(initialized_values),
        doc_string=graph_like.doc_string,
        opset_imports=graph_like.opset_imports,
        name=graph_like.name,
        metadata_props=graph_like.metadata_props,
    )

    return graph_view.clone()
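
As a usage sketch for extract() above: given a loaded model, a sub-model can be carved out by naming the boundary values. The file path and value names below are placeholders, and the sketch assumes extract() is re-exported through the public ir.convenience namespace (in this wheel it is defined in the private onnx_ir/_convenience/_extractor.py module).

    import onnx_ir as ir

    model = ir.load("model.onnx")  # placeholder path
    # Everything needed to compute "logits" from "hidden_in" is copied into a new graph.
    sub = ir.convenience.extract(model.graph, inputs=["hidden_in"], outputs=["logits"])

    # The result is a cloned, self-contained ir.Graph; its nodes can be iterated as usual.
    for n in sub:
        print(n.op_type)

Because the function returns graph_view.clone(), the extracted graph does not reuse the original Node and Value objects, so it can be modified or serialized independently of the source model.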