onnx-ir 0.1.15__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- onnx_ir/__init__.py +176 -0
- onnx_ir/_cloner.py +229 -0
- onnx_ir/_convenience/__init__.py +558 -0
- onnx_ir/_convenience/_constructors.py +291 -0
- onnx_ir/_convenience/_extractor.py +191 -0
- onnx_ir/_core.py +4435 -0
- onnx_ir/_display.py +54 -0
- onnx_ir/_enums.py +474 -0
- onnx_ir/_graph_comparison.py +23 -0
- onnx_ir/_graph_containers.py +373 -0
- onnx_ir/_io.py +133 -0
- onnx_ir/_linked_list.py +284 -0
- onnx_ir/_metadata.py +45 -0
- onnx_ir/_name_authority.py +72 -0
- onnx_ir/_polyfill.py +26 -0
- onnx_ir/_protocols.py +627 -0
- onnx_ir/_safetensors/__init__.py +510 -0
- onnx_ir/_tape.py +242 -0
- onnx_ir/_thirdparty/asciichartpy.py +310 -0
- onnx_ir/_type_casting.py +89 -0
- onnx_ir/_version_utils.py +48 -0
- onnx_ir/analysis/__init__.py +21 -0
- onnx_ir/analysis/_implicit_usage.py +74 -0
- onnx_ir/convenience.py +38 -0
- onnx_ir/external_data.py +459 -0
- onnx_ir/passes/__init__.py +41 -0
- onnx_ir/passes/_pass_infra.py +351 -0
- onnx_ir/passes/common/__init__.py +54 -0
- onnx_ir/passes/common/_c_api_utils.py +76 -0
- onnx_ir/passes/common/clear_metadata_and_docstring.py +60 -0
- onnx_ir/passes/common/common_subexpression_elimination.py +207 -0
- onnx_ir/passes/common/constant_manipulation.py +230 -0
- onnx_ir/passes/common/default_attributes.py +99 -0
- onnx_ir/passes/common/identity_elimination.py +120 -0
- onnx_ir/passes/common/initializer_deduplication.py +179 -0
- onnx_ir/passes/common/inliner.py +223 -0
- onnx_ir/passes/common/naming.py +280 -0
- onnx_ir/passes/common/onnx_checker.py +57 -0
- onnx_ir/passes/common/output_fix.py +141 -0
- onnx_ir/passes/common/shape_inference.py +112 -0
- onnx_ir/passes/common/topological_sort.py +37 -0
- onnx_ir/passes/common/unused_removal.py +215 -0
- onnx_ir/py.typed +1 -0
- onnx_ir/serde.py +2043 -0
- onnx_ir/tape.py +15 -0
- onnx_ir/tensor_adapters.py +210 -0
- onnx_ir/testing.py +197 -0
- onnx_ir/traversal.py +118 -0
- onnx_ir-0.1.15.dist-info/METADATA +68 -0
- onnx_ir-0.1.15.dist-info/RECORD +53 -0
- onnx_ir-0.1.15.dist-info/WHEEL +5 -0
- onnx_ir-0.1.15.dist-info/licenses/LICENSE +202 -0
- onnx_ir-0.1.15.dist-info/top_level.txt +1 -0
|
@@ -0,0 +1,291 @@
|
|
|
1
|
+
# Copyright (c) ONNX Project Contributors
|
|
2
|
+
# SPDX-License-Identifier: Apache-2.0
|
|
3
|
+
"""Convenience constructors for IR objects."""
|
|
4
|
+
|
|
5
|
+
from __future__ import annotations
|
|
6
|
+
|
|
7
|
+
# Public constructors exported by this module. ``val`` is a public convenience
# constructor (used as ``ir.val``) and belongs here alongside ``tensor`` and ``node``.
__all__ = [
    "tensor",
    "node",
    "val",
]
|
|
11
|
+
|
|
12
|
+
import typing
|
|
13
|
+
from collections.abc import Mapping, Sequence
|
|
14
|
+
|
|
15
|
+
import numpy as np
|
|
16
|
+
import onnx # noqa: TID251
|
|
17
|
+
|
|
18
|
+
from onnx_ir import _convenience, _core, _enums, _protocols, serde, tensor_adapters
|
|
19
|
+
|
|
20
|
+
if typing.TYPE_CHECKING:
|
|
21
|
+
import numpy.typing as npt
|
|
22
|
+
|
|
23
|
+
import onnx_ir as ir
|
|
24
|
+
|
|
25
|
+
|
|
26
|
+
def tensor(
    value: npt.ArrayLike | onnx.TensorProto | ir.DLPackCompatible | ir.ArrayCompatible,
    dtype: ir.DataType | None = None,
    name: str | None = None,
    doc_string: str | None = None,
) -> _protocols.TensorProtocol:
    """Create a tensor value from an ArrayLike object or a TensorProto.

    The dtype must match the value. Reinterpretation of the value is
    not supported, unless the value is a plain Python object, in which case
    it is converted to a numpy array with the given dtype.

    ``value`` can be an existing tensor, a TensorProto, a ``torch.Tensor``, an object
    implementing the DLPack or ``__array__`` protocols (e.g. a numpy array), or a plain
    Python object such as a number, a string, or a (nested) list of them.

    .. warning::
        For 4bit dtypes, the value must be unpacked. Use :class:`~onnx_ir.PackedTensor`
        to create a tensor with packed data.

    Example::

        >>> import onnx_ir as ir
        >>> import numpy as np
        >>> import ml_dtypes
        >>> import onnx
        >>> ir.tensor(np.array([1, 2, 3], dtype=np.int16))
        Tensor<INT16,[3]>(array([1, 2, 3], dtype=int16), name=None)
        >>> ir.tensor([1, 2, 3], dtype=ir.DataType.BFLOAT16)
        Tensor<BFLOAT16,[3]>(array([1, 2, 3], dtype=bfloat16), name=None)
        >>> tp_tensor = ir.tensor(onnx.helper.make_tensor("tensor", onnx.TensorProto.FLOAT, dims=[], vals=[0.5]))
        >>> tp_tensor.numpy()
        array(0.5, dtype=float32)
        >>> import torch
        >>> ir.tensor(torch.tensor([1.0, 2.0]), name="torch_tensor")
        TorchTensor<FLOAT,[2]>(tensor([1., 2.]), name='torch_tensor')

    Args:
        value: The value to create the tensor from.
        dtype: The data type of the tensor. For plain Python objects the value is
            converted to this dtype; for every other kind of value it must match the
            dtype of ``value``.
        name: The name of the tensor. Ignored when ``value`` already is a tensor.
        doc_string: The documentation string of the tensor. Ignored when ``value``
            already is a tensor.

    Returns:
        A tensor value.

    Raises:
        ValueError: If the dtype does not match the value when value is not a plain Python
            object like ``list[int]``.
        ValueError: If ``dtype`` is not given and ``value`` is an empty (non-string) sequence.
        TypeError: If ``dtype`` is given for a plain Python object and is not a
            :class:`~onnx_ir.DataType`.
    """
    if isinstance(value, _protocols.TensorProtocol):
        if dtype is not None and dtype != value.dtype:
            raise ValueError(
                f"The dtype must match the value when value is a Tensor. dtype={dtype}, value.dtype={value.dtype}. "
                "You do not have to specify the dtype when value is a Tensor."
            )
        return value
    if isinstance(value, onnx.TensorProto):
        tensor_ = serde.deserialize_tensor(value)
        if name is not None:
            tensor_.name = name
        if doc_string is not None:
            tensor_.doc_string = doc_string
        if dtype is not None and dtype != tensor_.dtype:
            # The two literals are concatenated, so the first one must end with ". ".
            raise ValueError(
                f"The dtype must match the value when value is a TensorProto. dtype={dtype}, value.data_type={tensor_.dtype}. "
                "You do not have to specify the dtype when value is a TensorProto."
            )
        return tensor_
    if str(type(value)) == "<class 'torch.Tensor'>":
        # NOTE: We use str(type(...)) and do not import torch for type checking
        # as it creates overhead during import
        torch_tensor = tensor_adapters.TorchTensor(value, name=name, doc_string=doc_string)  # type: ignore[arg-type]
        # Honor the documented contract: a given dtype must match the value.
        if dtype is not None and dtype != torch_tensor.dtype:
            raise ValueError(
                f"The dtype must match the value when value is a torch.Tensor. dtype={dtype}, value.dtype={torch_tensor.dtype}. "
                "You do not have to specify the dtype when value is a torch.Tensor."
            )
        return torch_tensor
    if isinstance(value, (_protocols.DLPackCompatible, _protocols.ArrayCompatible)):
        return _core.Tensor(value, dtype=dtype, name=name, doc_string=doc_string)

    # Plain (numerical) Python object. Determine the numpy dtype and use np.array to construct the tensor.
    # A str is itself a Sequence, so it is tracked separately: "" is a valid scalar string,
    # not an empty sequence.
    is_str_scalar = isinstance(value, str)
    if dtype is not None:
        if not isinstance(dtype, _enums.DataType):
            raise TypeError(f"dtype must be an instance of DataType. dtype={dtype}")
        numpy_dtype = dtype.numpy()
    elif isinstance(value, Sequence) and not is_str_scalar and not value:
        raise ValueError("dtype must be specified when value is an empty sequence.")
    elif isinstance(value, int) and not isinstance(value, bool):
        # Specify int64 for ints because on Windows this may be int32
        numpy_dtype = np.dtype(np.int64)
    elif isinstance(value, float):
        # If the value is a single float, we use np.float32 as the default dtype
        numpy_dtype = np.dtype(np.float32)
    elif isinstance(value, Sequence) and not is_str_scalar and value:
        if all((isinstance(elem, int) and not isinstance(elem, bool)) for elem in value):
            numpy_dtype = np.dtype(np.int64)
        elif all(isinstance(elem, float) for elem in value):
            # If the value is a sequence of floats, we use np.float32 as the default dtype
            numpy_dtype = np.dtype(np.float32)
        else:
            numpy_dtype = None
    else:
        numpy_dtype = None

    array = np.array(value, dtype=numpy_dtype)

    # Handle string tensors by encoding them. A unicode ("U") array also covers nested
    # lists of strings, which the flat element check below does not detect.
    if (
        array.dtype.kind == "U"
        or is_str_scalar
        or (isinstance(value, Sequence) and value and all(isinstance(elem, str) for elem in value))
    ):
        # np.char.encode is available on both NumPy 1.x and 2.x (np.strings is 2.x only).
        array = np.char.encode(array, encoding="utf-8")
        return _core.StringTensor(
            array,
            shape=_core.Shape(array.shape),
            name=name,
            doc_string=doc_string,
        )

    return _core.Tensor(
        array,
        dtype=dtype,
        shape=_core.Shape(array.shape),
        name=name,
        doc_string=doc_string,
    )
|
|
145
|
+
|
|
146
|
+
|
|
147
|
+
def node(
    op_type: str,
    inputs: Sequence[ir.Value | None],
    attributes: Mapping[str, _convenience.SupportedAttrTypes] | None = None,
    *,
    domain: str = "",
    overload: str = "",
    num_outputs: int | None = None,
    outputs: Sequence[ir.Value] | None = None,
    version: int | None = None,
    graph: ir.Graph | None = None,
    name: str | None = None,
    doc_string: str | None = None,
    metadata_props: dict[str, str] | None = None,
) -> ir.Node:
    """Create a :class:`~onnx_ir.Node`.

    Convenience wrapper around the Node constructor that accepts plain Python
    objects as attribute values and converts them to IR attributes.

    Example::

        >>> import onnx_ir as ir
        >>> input_a = ir.val("A", shape=[1, 2], type=ir.TensorType(ir.DataType.INT32))
        >>> input_b = ir.val("B", shape=[1, 2], type=ir.TensorType(ir.DataType.INT32))
        >>> node = ir.node(
        ...     "SomeOp",
        ...     inputs=[input_a, input_b],
        ...     attributes={"alpha": 1.0, "some_list": [1, 2, 3]},
        ...     domain="some.domain",
        ...     name="node_name"
        ... )
        >>> node.op_type
        'SomeOp'

    Args:
        op_type: The name of the operator.
        inputs: The input values. A ``None`` entry denotes an empty (omitted) input.
        attributes: Mapping from attribute name to a Python value; converted with
            ``_convenience.convert_attributes``. RefAttr can be used only when the node
            is defined in a Function.
        domain: The domain of the operator. For onnx operators, this is an empty string.
        overload: The overload name when the node is invoking a function.
        num_outputs: The number of outputs of the node. If not specified, the number is 1.
        outputs: The output values. If None, the outputs are created during initialization.
        version: The version of the operator. If None, the version is unspecified and will
            follow that of the graph.
        graph: The graph that the node belongs to. If None, the node is not added to any
            graph. A `Node` must belong to zero or one graph.
        name: The name of the node. If None, the node is anonymous.
        doc_string: The documentation string.
        metadata_props: The metadata properties.

    Returns:
        A node with the given op_type and inputs.
    """
    # No attributes means an empty tuple; otherwise convert the Python values.
    converted_attrs: Sequence[ir.Attr] = (
        () if attributes is None else _convenience.convert_attributes(attributes)
    )
    return _core.Node(
        op_type=op_type,
        domain=domain,
        overload=overload,
        inputs=inputs,
        attributes=converted_attrs,
        outputs=outputs,
        num_outputs=num_outputs,
        graph=graph,
        version=version,
        name=name,
        doc_string=doc_string,
        metadata_props=metadata_props,
    )
|
|
218
|
+
|
|
219
|
+
|
|
220
|
+
def val(
    name: str | None,
    dtype: ir.DataType | None = None,
    shape: ir.Shape | Sequence[int | str | None] | None = None,
    *,
    type: ir.TypeProtocol | None = None,
    const_value: ir.TensorProtocol | None = None,
    metadata_props: dict[str, str] | None = None,
) -> ir.Value:
    """Create a :class:`~onnx_ir.Value` with the given name and type.

    A relaxed alternative to constructing a Value directly. Instead of building a
    :class:`~onnx_ir.TypeProtocol` and a :class:`~onnx_ir.Shape` first, the caller may
    pass ``dtype`` as a :class:`~onnx_ir.DataType` and ``shape`` as a plain sequence of
    integers or symbolic dimensions.

    Example::

        >>> import onnx_ir as ir
        >>> t = ir.val("x", ir.DataType.FLOAT, ["N", 42, 3])
        >>> t.name
        'x'
        >>> t.type
        Tensor(FLOAT)
        >>> t.shape
        Shape([SymbolicDim(N), 42, 3])

    .. versionadded:: 0.1.9

    Args:
        name: The name of the value.
        dtype: Element type of the value's TensorType; consulted only when ``type`` is None.
        shape: The shape of the value.
        type: The type of the value. When given, ``dtype`` is not used to build the type.
        const_value: A constant tensor initializing the value (i.e. an initializer). The
            type and shape are taken from this tensor; any ``type``, ``dtype`` or
            ``shape`` also given must agree with it.
        metadata_props: The metadata properties that will be serialized to the ONNX proto.

    Returns:
        A Value object.

    Raises:
        ValueError: If ``const_value`` is given and ``type``, ``dtype`` or ``shape``
            disagrees with it.
    """
    if const_value is None:
        # Plain value: an explicit ``type`` takes precedence over ``dtype``.
        resolved_type = type
        if resolved_type is None and dtype is not None:
            resolved_type = _core.TensorType(dtype)
        resolved_shape = shape
        if resolved_shape is not None and not isinstance(resolved_shape, _core.Shape):
            resolved_shape = _core.Shape(resolved_shape)
        return _core.Value(
            name=name, type=resolved_type, shape=resolved_shape, metadata_props=metadata_props
        )

    # Initializer: the tensor is the source of truth for type and shape.
    tensor_type = _core.TensorType(const_value.dtype)
    if type is not None and type != tensor_type:
        raise ValueError(
            f"The type does not match the const_value. type={type} but const_value has type {tensor_type}. "
            "You do not have to specify the type when const_value is provided."
        )
    if dtype is not None and dtype != const_value.dtype:
        raise ValueError(
            f"The dtype does not match the const_value. dtype={dtype} but const_value has dtype {const_value.dtype}. "
            "You do not have to specify the dtype when const_value is provided."
        )
    if shape is not None and _core.Shape(shape) != const_value.shape:
        raise ValueError(
            f"The shape does not match the const_value. shape={shape} but const_value has shape {const_value.shape}. "
            "You do not have to specify the shape when const_value is provided."
        )
    return _core.Value(
        name=name,
        type=tensor_type,
        shape=_core.Shape(const_value.shape),  # type: ignore
        const_value=const_value,
        metadata_props=metadata_props,
    )
|
|
@@ -0,0 +1,191 @@
|
|
|
1
|
+
# Copyright (c) ONNX Project Contributors
|
|
2
|
+
# SPDX-License-Identifier: Apache-2.0
|
|
3
|
+
"""Utilities for extracting subgraphs from a graph."""
|
|
4
|
+
|
|
5
|
+
from __future__ import annotations
|
|
6
|
+
|
|
7
|
+
import itertools
|
|
8
|
+
from collections.abc import Collection, Sequence
|
|
9
|
+
from typing import Union
|
|
10
|
+
|
|
11
|
+
import onnx_ir as ir
|
|
12
|
+
|
|
13
|
+
# The graph-like containers accepted by ``extract``: a graph, a function, or a graph
# view. Members are written as string forward references.
GraphLike = Union["ir.Graph", "ir.Function", "ir.GraphView"]
|
|
14
|
+
|
|
15
|
+
|
|
16
|
+
def _collect_all_external_values(parent_graph: ir.Graph, graph: ir.Graph) -> set[ir.Value]:
    """Collect the values of ``parent_graph`` referenced from inside ``graph``.

    Every node reachable through ``ir.traversal.RecursiveGraphIterator(graph)`` (which
    includes nodes of nested subgraphs) is inspected. Each non-empty node input whose
    ``graph`` is ``parent_graph`` is collected.

    Args:
        parent_graph: The parent graph to which collected values must belong.
        graph: The (sub)graph whose node inputs are scanned.

    Returns:
        A set of :class:`~onnx_ir.Value` objects belonging to ``parent_graph``.
    """
    return {
        inp
        for n in ir.traversal.RecursiveGraphIterator(graph)
        for inp in n.inputs
        if inp is not None and inp.graph is parent_graph
    }
|
|
34
|
+
|
|
35
|
+
|
|
36
|
+
def _find_subgraph_bounded_by_values(
    graph: GraphLike,
    inputs: Collection[ir.Value],
    outputs: Collection[ir.Value],
    parent_graph: ir.Graph,
) -> tuple[list[ir.Node], Collection[ir.Value]]:
    """Finds the subgraph bounded by the given inputs and outputs.

    The search walks backwards from ``outputs`` through value producers and stops at
    ``inputs``. For every visited node, the values of ``parent_graph`` captured by its
    ``GRAPH``/``GRAPHS`` attributes are walked as well, so their producers are included.

    A value reached by the walk that is not produced by a node of ``graph`` (it has no
    producer, or its producer lies outside ``graph``) is on the input frontier. Each
    frontier value must be one of ``inputs`` or an initializer; otherwise the subgraph
    is not bounded.

    Args:
        graph: The graph to search.
        inputs: The inputs to the subgraph.
        outputs: The outputs of the subgraph.
        parent_graph: The parent graph of the subgraph.

    Returns:
        The nodes of the subgraph, in their original order in ``graph``, and the set of
        initializers used.

    Raises:
        ValueError: If the subgraph is not properly bounded by the given inputs and outputs.
    """
    inputs_set = set(inputs)
    if isinstance(graph, ir.Function):
        initialized_values: set[ir.Value] = set()
    else:
        initialized_values = {val for val in inputs_set if val.is_initializer()}
    node_index = {node: idx for idx, node in enumerate(graph)}
    all_nodes: list[ir.Node] = []
    value_stack: list[ir.Value] = [*outputs]
    visited_nodes: set[ir.Node] = set()
    # Seeding with the inputs makes the walk stop at them.
    visited_values: set[ir.Value] = set(inputs_set)
    # Reached values that are not produced by any node of ``graph``.
    input_frontier: set[ir.Value] = set()

    while value_stack:
        value = value_stack.pop()
        if value in visited_values:
            continue
        visited_values.add(value)
        if value.is_initializer():
            # Record the initializer
            initialized_values.add(value)

        node = value.producer()
        if node is None or node not in node_index:
            # Graph input, initializer, or a value produced outside ``graph``
            # (e.g. outside a GraphView): do not walk past it, validate it below.
            input_frontier.add(value)
            continue
        if node in visited_nodes:
            continue
        visited_nodes.add(node)
        all_nodes.append(node)
        value_stack.extend(
            inp for inp in node.inputs if inp is not None and inp not in visited_values
        )
        for attr in node.attributes.values():
            if attr.type == ir.AttributeType.GRAPH:
                subgraphs: Sequence[ir.Graph] = (attr.as_graph(),)
            elif attr.type == ir.AttributeType.GRAPHS:
                subgraphs = attr.as_graphs()
            else:
                continue
            for subgraph in subgraphs:
                value_stack.extend(
                    captured
                    for captured in _collect_all_external_values(parent_graph, subgraph)
                    if captured not in visited_values
                )

    # Initializers are allowed on the frontier. Any other frontier value is an
    # unspecified graph input: this covers direct node inputs, values captured by
    # subgraphs, and requested outputs that have no producer.
    unspecified_graph_inputs = [
        val
        for val in sorted(input_frontier, key=lambda v: v.name or "")
        if val not in inputs_set and not val.is_initializer()
    ]
    if unspecified_graph_inputs:
        value_names = [val.name or "<None>" for val in unspecified_graph_inputs]
        raise ValueError(
            "The subgraph is not properly bounded by the specified inputs and outputs. "
            f"The following graph inputs are required but not provided: {', '.join(value_names)}"
        )

    # Preserve the original order; every collected node is in ``node_index`` by construction.
    all_nodes.sort(key=node_index.__getitem__)
    return all_nodes, initialized_values
|
|
127
|
+
|
|
128
|
+
|
|
129
|
+
def extract(
    graph_like: GraphLike,
    /,
    inputs: Sequence[ir.Value | str],
    outputs: Sequence[ir.Value | str],
) -> ir.Graph:
    """Extracts a subgraph from the given graph-like object.

    .. versionadded:: 0.1.14

    Args:
        graph_like: The graph-like object to extract from.
        inputs: The inputs to the subgraph. Can be Value objects or their names.
        outputs: The outputs of the subgraph. Can be Value objects or their names.

    Returns:
        The extracted subgraph as a new :class:`~onnx_ir.Graph` object (a clone, so the
        original graph is not modified).

    Raises:
        ValueError: If any of the inputs or outputs are not found in the graph.
        ValueError: If no outputs are given, or the first output does not belong to a graph.
        ValueError: If the subgraph is not properly bounded by the given inputs and outputs.
    """
    if isinstance(graph_like, ir.Function):
        graph: ir.Graph | ir.GraphView = graph_like.graph
    else:
        graph = graph_like
    values = ir.convenience.create_value_mapping(graph, include_subgraphs=False)
    # Values of a GraphView belong to the underlying graph, so ownership is not checked for views.
    is_graph_view = isinstance(graph_like, ir.GraphView)
    for val in itertools.chain(inputs, outputs):
        if isinstance(val, ir.Value):
            if not is_graph_view and val.graph is not graph:
                graph_name = graph.name if graph.name is not None else "unnamed graph"
                raise ValueError(
                    f"Value '{val}' does not belong to the given "
                    f"{graph_like.__class__.__name__} ({graph_name})."
                )
        elif val not in values:
            raise ValueError(f"Value with name '{val}' not found in the graph.")

    input_vals = [values[val] if isinstance(val, str) else val for val in inputs]
    output_vals = [values[val] if isinstance(val, str) else val for val in outputs]
    # Find the owning graph of the outputs to set as the parent graph
    if not output_vals:
        raise ValueError("At least one output must be provided to extract a subgraph.")
    parent_graph = output_vals[0].graph
    if parent_graph is None:
        # Raised explicitly rather than asserted so the check survives ``python -O``.
        raise ValueError(
            f"Output value '{output_vals[0]}' does not belong to any graph; "
            "cannot extract a subgraph."
        )
    extracted_nodes, initialized_values = _find_subgraph_bounded_by_values(
        graph_like, input_vals, output_vals, parent_graph=parent_graph
    )

    graph_view = ir.GraphView(
        input_vals,
        output_vals,
        nodes=extracted_nodes,
        # The initializers come back as an unordered collection; sort them by name so
        # the extracted graph has a deterministic initializer order.
        initializers=tuple(sorted(initialized_values, key=lambda v: v.name or "")),
        doc_string=graph_like.doc_string,
        opset_imports=graph_like.opset_imports,
        name=graph_like.name,
        metadata_props=graph_like.metadata_props,
    )

    return graph_view.clone()
|