onnx-ir 0.1.15__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- onnx_ir/__init__.py +176 -0
- onnx_ir/_cloner.py +229 -0
- onnx_ir/_convenience/__init__.py +558 -0
- onnx_ir/_convenience/_constructors.py +291 -0
- onnx_ir/_convenience/_extractor.py +191 -0
- onnx_ir/_core.py +4435 -0
- onnx_ir/_display.py +54 -0
- onnx_ir/_enums.py +474 -0
- onnx_ir/_graph_comparison.py +23 -0
- onnx_ir/_graph_containers.py +373 -0
- onnx_ir/_io.py +133 -0
- onnx_ir/_linked_list.py +284 -0
- onnx_ir/_metadata.py +45 -0
- onnx_ir/_name_authority.py +72 -0
- onnx_ir/_polyfill.py +26 -0
- onnx_ir/_protocols.py +627 -0
- onnx_ir/_safetensors/__init__.py +510 -0
- onnx_ir/_tape.py +242 -0
- onnx_ir/_thirdparty/asciichartpy.py +310 -0
- onnx_ir/_type_casting.py +89 -0
- onnx_ir/_version_utils.py +48 -0
- onnx_ir/analysis/__init__.py +21 -0
- onnx_ir/analysis/_implicit_usage.py +74 -0
- onnx_ir/convenience.py +38 -0
- onnx_ir/external_data.py +459 -0
- onnx_ir/passes/__init__.py +41 -0
- onnx_ir/passes/_pass_infra.py +351 -0
- onnx_ir/passes/common/__init__.py +54 -0
- onnx_ir/passes/common/_c_api_utils.py +76 -0
- onnx_ir/passes/common/clear_metadata_and_docstring.py +60 -0
- onnx_ir/passes/common/common_subexpression_elimination.py +207 -0
- onnx_ir/passes/common/constant_manipulation.py +230 -0
- onnx_ir/passes/common/default_attributes.py +99 -0
- onnx_ir/passes/common/identity_elimination.py +120 -0
- onnx_ir/passes/common/initializer_deduplication.py +179 -0
- onnx_ir/passes/common/inliner.py +223 -0
- onnx_ir/passes/common/naming.py +280 -0
- onnx_ir/passes/common/onnx_checker.py +57 -0
- onnx_ir/passes/common/output_fix.py +141 -0
- onnx_ir/passes/common/shape_inference.py +112 -0
- onnx_ir/passes/common/topological_sort.py +37 -0
- onnx_ir/passes/common/unused_removal.py +215 -0
- onnx_ir/py.typed +1 -0
- onnx_ir/serde.py +2043 -0
- onnx_ir/tape.py +15 -0
- onnx_ir/tensor_adapters.py +210 -0
- onnx_ir/testing.py +197 -0
- onnx_ir/traversal.py +118 -0
- onnx_ir-0.1.15.dist-info/METADATA +68 -0
- onnx_ir-0.1.15.dist-info/RECORD +53 -0
- onnx_ir-0.1.15.dist-info/WHEEL +5 -0
- onnx_ir-0.1.15.dist-info/licenses/LICENSE +202 -0
- onnx_ir-0.1.15.dist-info/top_level.txt +1 -0
|
@@ -0,0 +1,558 @@
|
|
|
1
|
+
# Copyright (c) ONNX Project Contributors
|
|
2
|
+
# SPDX-License-Identifier: Apache-2.0
|
|
3
|
+
"""Convenience methods for constructing and manipulating the IR.
|
|
4
|
+
|
|
5
|
+
This is an internal only module. We should choose to expose some of the methods
|
|
6
|
+
in convenience.py after they are proven to be useful.
|
|
7
|
+
"""
|
|
8
|
+
|
|
9
|
+
from __future__ import annotations
|
|
10
|
+
|
|
11
|
+
# Public names of this module. Every entry is a function defined below;
# helpers whose names start with an underscore are intentionally not listed.
__all__ = [
    "convert_attribute",
    "convert_attributes",
    "replace_all_uses_with",
    "create_value_mapping",
    "replace_nodes_and_values",
    "get_const_tensor",
]
|
|
19
|
+
|
|
20
|
+
import logging
|
|
21
|
+
from collections.abc import Iterable, Mapping, Sequence
|
|
22
|
+
from typing import Union
|
|
23
|
+
|
|
24
|
+
import numpy as np
|
|
25
|
+
import onnx # noqa: TID251
|
|
26
|
+
|
|
27
|
+
from onnx_ir import _core, _enums, _protocols, serde, traversal
|
|
28
|
+
|
|
29
|
+
# Python value types accepted by ``convert_attribute`` / ``convert_attributes``.
# ``None`` is part of the union: ``convert_attribute`` turns it into an Attr
# without a value (an explicit ``attr_type`` is then required), while
# ``convert_attributes`` skips such entries.
SupportedAttrTypes = Union[
    str,
    int,
    float,
    Sequence[int],
    Sequence[float],
    Sequence[str],
    _protocols.TensorProtocol,  # This includes all in-memory tensor types
    onnx.TensorProto,
    _core.Attr,
    _protocols.GraphProtocol,
    Sequence[_protocols.GraphProtocol],
    onnx.GraphProto,
    _protocols.TypeProtocol,
    Sequence[_protocols.TypeProtocol],
    None,
]


# Module-level logger; used below to warn about ambiguous empty sequences and
# about shape/type mismatches when propagating constant tensor information.
logger = logging.getLogger(__name__)
|
|
49
|
+
|
|
50
|
+
|
|
51
|
+
def _infer_attribute_type(attr: SupportedAttrTypes) -> _enums.AttributeType:
    """Map a Python value to the attribute type that represents it.

    Scalars are matched first (``int``, ``float``, ``str``), then
    :class:`_core.Attr` (which reports its own type), graphs, tensors and type
    objects. Any other ``Sequence`` is classified by the common type of its
    elements; an empty sequence is ambiguous and defaults to ``INTS`` after
    logging a warning.

    Args:
        attr: The Python object whose attribute type should be inferred.

    Returns:
        The inferred attribute type.

    Raises:
        TypeError: If neither the value nor all of its elements match one of
            the supported types.
    """
    kinds = _enums.AttributeType
    # Protocol classes come last in each tuple because isinstance checking on
    # Protocols can be slower than on concrete classes.
    graph_like = (_core.Graph, onnx.GraphProto, _protocols.GraphProtocol)
    tensor_like = (_core.TensorBase, onnx.TensorProto, _protocols.TensorProtocol)
    type_like = (
        _core.TensorType,
        _core.SequenceType,
        _core.OptionalType,
        _protocols.TypeProtocol,
    )

    for scalar_type, scalar_kind in (
        (int, kinds.INT),
        (float, kinds.FLOAT),
        (str, kinds.STRING),
    ):
        if isinstance(attr, scalar_type):
            return scalar_kind

    if isinstance(attr, _core.Attr):
        return attr.type

    for object_types, object_kind in (
        (graph_like, kinds.GRAPH),
        (tensor_like, kinds.TENSOR),
        (type_like, kinds.TYPE_PROTO),
    ):
        if isinstance(attr, object_types):
            return object_kind

    if isinstance(attr, Sequence):
        if not attr:
            logger.warning(
                "Attribute type is ambiguous because it is an empty sequence. "
                "Please create an Attr with an explicit type. Defaulted to INTS"
            )
            return kinds.INTS
        # The first element type shared by every item decides the result.
        for element_types, sequence_kind in (
            (int, kinds.INTS),
            (float, kinds.FLOATS),
            (str, kinds.STRINGS),
            (tensor_like, kinds.TENSORS),
            (graph_like, kinds.GRAPHS),
            (type_like, kinds.TYPE_PROTOS),
        ):
            if all(isinstance(item, element_types) for item in attr):
                return sequence_kind

    raise TypeError(f"Unsupported attribute type: '{type(attr)}'")
|
|
108
|
+
|
|
109
|
+
|
|
110
|
+
def convert_attribute(
    name: str,
    attr: SupportedAttrTypes,
    attr_type: _enums.AttributeType | None = None,
) -> _core.Attr:
    """Convert a Python object to a _core.Attr object.

    This method is useful when constructing nodes with attributes. It infers the
    attribute type based on the type of the Python value.

    ``onnx.TensorProto`` and ``onnx.GraphProto`` inputs (alone or inside a
    sequence) are deserialized with :mod:`serde` before being stored, so the
    resulting attribute always holds IR objects.

    Args:
        name: The name of the attribute.
        attr: The value of the attribute. An existing ``Attr`` is returned
            as-is after its name (and type, when ``attr_type`` is given) has
            been checked.
        attr_type: The type of the attribute. This is required when attr is None.
            When provided, it overrides the inferred type.

    Returns:
        A ``Attr`` object.

    Raises:
        ValueError: If ``attr`` is ``None`` and ``attr_type`` is not provided,
            or if ``attr`` is an ``Attr`` whose name or type does not match
            the provided ``name`` / ``attr_type``.
        TypeError: If the type of the attribute is not supported.
    """
    if attr is None:
        if attr_type is None:
            raise ValueError("attr_type must be provided when attr is None")
        return _core.Attr(name, attr_type, None)

    if isinstance(attr, _core.Attr):
        if attr.name != name:
            raise ValueError(
                f"Attribute name '{attr.name}' does not match provided name '{name}'"
            )
        if attr_type is not None and attr.type != attr_type:
            raise ValueError(
                f"Attribute type '{attr.type}' does not match provided type '{attr_type}'"
            )
        return attr

    if attr_type is None:
        attr_type = _infer_attribute_type(attr)

    if attr_type == _enums.AttributeType.INT:
        return _core.AttrInt64(name, attr)  # type: ignore
    if attr_type == _enums.AttributeType.FLOAT:
        return _core.AttrFloat32(name, attr)  # type: ignore
    if attr_type == _enums.AttributeType.STRING:
        return _core.AttrString(name, attr)  # type: ignore
    if attr_type == _enums.AttributeType.INTS:
        return _core.AttrInt64s(name, attr)  # type: ignore
    if attr_type == _enums.AttributeType.FLOATS:
        return _core.AttrFloat32s(name, attr)  # type: ignore
    if attr_type == _enums.AttributeType.STRINGS:
        return _core.AttrStrings(name, attr)  # type: ignore
    if attr_type == _enums.AttributeType.TENSOR:
        if isinstance(attr, (_core.TensorBase, _protocols.TensorProtocol)):
            return _core.AttrTensor(name, attr)
        if isinstance(attr, onnx.TensorProto):
            return _core.AttrTensor(name, serde.deserialize_tensor(attr))
        # Any other value falls through to the TypeError at the end.
    if attr_type == _enums.AttributeType.TENSORS:
        # The elements of a TENSORS attribute must be tensors, not Attr
        # objects: protobuf tensors are deserialized and stored directly,
        # mirroring how the GRAPHS branch handles onnx.GraphProto.
        tensors = [
            serde.deserialize_tensor(t) if isinstance(t, onnx.TensorProto) else t
            for t in attr  # type: ignore[union-attr]
        ]
        return _core.AttrTensors(name, tensors)  # type: ignore[arg-type]
    if attr_type == _enums.AttributeType.GRAPH:
        if isinstance(attr, onnx.GraphProto):
            attr = serde.deserialize_graph(attr)
        return _core.AttrGraph(name, attr)  # type: ignore[arg-type]
    if attr_type == _enums.AttributeType.GRAPHS:
        graphs = [
            serde.deserialize_graph(graph) if isinstance(graph, onnx.GraphProto) else graph
            for graph in attr  # type: ignore[union-attr]
        ]
        return _core.AttrGraphs(name, graphs)  # type: ignore[arg-type]
    if attr_type == _enums.AttributeType.TYPE_PROTO:
        return _core.AttrTypeProto(name, attr)  # type: ignore[arg-type]
    if attr_type == _enums.AttributeType.TYPE_PROTOS:
        return _core.AttrTypeProtos(name, attr)  # type: ignore[arg-type]
    raise TypeError(f"Unsupported attribute type: '{type(attr)}'")
|
|
194
|
+
|
|
195
|
+
|
|
196
|
+
def convert_attributes(
    attrs: Mapping[str, SupportedAttrTypes],
) -> list[_core.Attr]:
    """Build a list of _core.Attr objects from a name-to-value mapping.

    Every value is passed to :func:`convert_attribute`, which infers the
    attribute type from the Python type of the value. Entries whose value is
    ``None`` are skipped. The example below covers the supported kinds of
    values::

        >>> import onnx_ir as ir
        >>> import onnx
        >>> import numpy as np
        >>> attrs = {
        ...     "int": 1,
        ...     "float": 1.0,
        ...     "str": "hello",
        ...     "ints": [1, 2, 3],
        ...     "floats": [1.0, 2.0, 3.0],
        ...     "strings": ["hello", "world"],
        ...     "tensor": ir.Tensor(np.array([1.0, 2.0, 3.0])),
        ...     "tensor_proto":
        ...         onnx.TensorProto(
        ...             dims=[3],
        ...             data_type=onnx.TensorProto.FLOAT,
        ...             float_data=[1.0, 2.0, 3.0],
        ...             name="proto",
        ...         ),
        ...     "graph": ir.Graph([], [], nodes=[], name="graph0"),
        ...     "graphs": [ir.Graph([], [], nodes=[], name="graph1"), ir.Graph([], [], nodes=[], name="graph2")],
        ...     "type_proto": ir.TensorType(ir.DataType.FLOAT),
        ...     "type_protos": [ir.TensorType(ir.DataType.FLOAT), ir.TensorType(ir.DataType.FLOAT)],
        ... }
        >>> convert_attributes(attrs)
        [Attr('int', INT, 1), Attr('float', FLOAT, 1.0), Attr('str', STRING, 'hello'), Attr('ints', INTS, (1, 2, 3)), Attr('floats', FLOATS, (1.0, 2.0, 3.0)), Attr('strings', STRINGS, ('hello', 'world')), Attr('tensor', TENSOR, Tensor<DOUBLE,[3]>(array([1., 2., 3.]), name=None)), Attr('tensor_proto', TENSOR, TensorProtoTensor<FLOAT,[3]>(array([1., 2., 3.], dtype=float32), name='proto')), Attr('graph', GRAPH, Graph(
            name='graph0',
            inputs=(
        <BLANKLINE>
            ),
            outputs=(
        <BLANKLINE>
            ),
            len()=0
        )), Attr('graphs', GRAPHS, (Graph(
            name='graph1',
            inputs=(
        <BLANKLINE>
            ),
            outputs=(
        <BLANKLINE>
            ),
            len()=0
        ), Graph(
            name='graph2',
            inputs=(
        <BLANKLINE>
            ),
            outputs=(
        <BLANKLINE>
            ),
            len()=0
        ))), Attr('type_proto', TYPE_PROTO, Tensor(FLOAT)), Attr('type_protos', TYPE_PROTOS, (Tensor(FLOAT), Tensor(FLOAT)))]

    .. important::
        An empty sequence should be created with an explicit type by initializing
        an Attr object with an attribute type to avoid type ambiguity. For example::

            ir.Attr("empty", [], type=ir.AttributeType.INTS)

    Args:
        attrs: A dictionary of {<attribute name>: <python objects>} to convert.

    Returns:
        A list of :class:`_core.Attr` objects, in the iteration order of
        ``attrs``, without the entries whose value is ``None``.

    Raises:
        TypeError: If an attribute type is not supported.
    """
    return [
        convert_attribute(attr_name, attr_value)
        for attr_name, attr_value in attrs.items()
        if attr_value is not None
    ]
|
|
278
|
+
|
|
279
|
+
|
|
280
|
+
def replace_all_uses_with(
    values: _protocols.ValueProtocol | Sequence[_protocols.ValueProtocol],
    replacements: _protocols.ValueProtocol | Sequence[_protocols.ValueProtocol],
    replace_graph_outputs: bool = False,
) -> None:
    """Redirect every use of the given values to the replacement values.

    This is useful when nodes in the graph are replaced with new nodes, where
    the old users need to be updated to use the outputs of the new nodes.

    For example, suppose we have the following graph::

        A -> {B, C}

    We want to replace the node A with a new node D::

        >>> import onnx_ir as ir
        >>> input = ir.val("input")
        >>> node_a = ir.Node("", "A", [input])
        >>> node_b = ir.Node("", "B", node_a.outputs)
        >>> node_c = ir.Node("", "C", node_a.outputs)
        >>> node_d = ir.Node("", "D", [input])
        >>> replace_all_uses_with(node_a.outputs, node_d.outputs)
        >>> len(node_b.inputs)
        1
        >>> node_b.inputs[0].producer().op_type
        'D'
        >>> len(node_c.inputs)
        1
        >>> node_c.inputs[0].producer().op_type
        'D'
        >>> len(node_a.outputs[0].uses())
        0

    When values and replacements are sequences, they are zipped into pairs. All
    users of the first value are replaced with the first replacement, and so on.
    A single value (anything that is not a ``Sequence``) is treated as a
    one-element sequence.

    .. note::
        Be sure to remove the old nodes from the graph using ``graph.remove()``
        if they are no longer needed, or use :class:`onnx_ir.passes.common.RemoveUnusedNodesPass`
        to remove all unused nodes in the graph.

    .. tip::
        **Handling graph outputs**

        To also replace graph outputs that reference the values being replaced, either
        set ``replace_graph_outputs`` to True, or manually update the graph outputs
        before calling this function to avoid an error being raised when ``replace_graph_outputs=False``.

        Be careful when a value appears multiple times in the graph outputs -
        this is invalid. An identity node will need to be added on each duplicated
        outputs to ensure a valid ONNX graph.

        You may also want to assign the name of this value to the replacement value
        to maintain the name when it is a graph output.

    .. versionadded:: 0.1.12
        The ``replace_graph_outputs`` parameter is added.

    .. versionadded:: 0.1.12
        ValueError is raised when ``replace_graph_outputs`` is False and the value to
        replace is a graph output.

    Args:
        values: The value or values to be replaced.
        replacements: The new value or values to use as inputs.
        replace_graph_outputs: If True, graph outputs that reference the values
            being replaced will also be updated to reference the replacements.

    Raises:
        ValueError: When the number of values and replacements differ, or when
            ``replace_graph_outputs`` is False and the value to replace is a
            graph output.
    """
    # Normalise both arguments so that single values and sequences share one path.
    old_values = values if isinstance(values, Sequence) else (values,)
    new_values = replacements if isinstance(replacements, Sequence) else (replacements,)
    if len(old_values) != len(new_values):
        raise ValueError("The number of values and replacements must match.")
    for old_value, new_value in zip(old_values, new_values):
        old_value.replace_all_uses_with(
            new_value, replace_graph_outputs=replace_graph_outputs
        )
|
|
361
|
+
|
|
362
|
+
|
|
363
|
+
def create_value_mapping(
    graph: _core.Graph | _core.GraphView | _core.Function,
    *,
    include_subgraphs: bool = True,
) -> dict[str, _core.Value]:
    """Collect the named values of a graph into a name-to-value dictionary.

    Values are gathered from the initializers (unless ``graph`` is a
    :class:`_core.Function`), the graph inputs, and the inputs and outputs of
    every visited node. When a name occurs more than once, the first value
    registered under that name is kept. Values with empty names are excluded
    from the mapping.

    .. versionchanged:: 0.1.2
        Values from subgraphs are now included in the mapping.

    .. versionadded:: 0.1.14
        The ``include_subgraphs`` parameter.

    Args:
        graph: The graph to extract the mapping from.
        include_subgraphs: If True, values from subgraphs are included in the mapping.

    Returns:
        A dictionary mapping names to values.
    """
    mapping: dict[str, _core.Value] = {}
    if not isinstance(graph, _core.Function):
        mapping.update(graph.initializers)

    def _register(candidate: _core.Value) -> None:
        # The names of the values can be None or "", which we need to exclude.
        # setdefault keeps the first value seen for a given name.
        if candidate.name:
            mapping.setdefault(candidate.name, candidate)

    for graph_input in graph.inputs:
        _register(graph_input)

    nodes: Iterable[_core.Node] = (
        traversal.RecursiveGraphIterator(graph) if include_subgraphs else graph
    )
    for node in nodes:
        for operand in node.inputs:
            # Falsy entries (such as missing inputs) carry no value to record.
            if operand:
                _register(operand)
        for result in node.outputs:
            _register(result)
    return mapping
|
|
417
|
+
|
|
418
|
+
|
|
419
|
+
def replace_nodes_and_values(
    graph_or_function: _core.Graph | _core.Function,
    /,
    insertion_point: _core.Node,
    old_nodes: Sequence[_core.Node],
    new_nodes: Sequence[_core.Node],
    old_values: Sequence[_core.Value],
    new_values: Sequence[_core.Value],
) -> None:
    """Replaces nodes and values in the graph or function.

    For each ``(old_value, new_value)`` pair, the type, shape, constant value
    and name of the old value are copied onto the new value when they are set
    on the old value. All uses of the old values (including graph outputs) are
    then redirected to the new values, the new nodes are inserted after
    ``insertion_point``, and the old nodes are removed.

    Args:
        graph_or_function: The graph or function to replace nodes and values in.
        insertion_point: The node to insert the new nodes after.
        old_nodes: The nodes to replace.
        new_nodes: The nodes to replace with.
        old_values: The values to replace.
        new_values: The values to replace with.

    Raises:
        ValueError: If ``old_values`` and ``new_values`` differ in length. The
            check happens before anything is modified.
    """
    # Validate up front: otherwise the metadata of some new values would
    # already be overwritten by the time the mismatch is detected while
    # reconnecting the users.
    if len(old_values) != len(new_values):
        raise ValueError("The number of values and replacements must match.")

    for old_value, new_value in zip(old_values, new_values):
        # Propagate relevant info from old value to new value
        # TODO(Rama): Perhaps this should be a separate utility function.
        new_value.type = old_value.type if old_value.type is not None else new_value.type
        new_value.shape = old_value.shape if old_value.shape is not None else new_value.shape
        new_value.const_value = (
            old_value.const_value
            if old_value.const_value is not None
            else new_value.const_value
        )
        new_value.name = old_value.name if old_value.name is not None else new_value.name

    # Reconnect the users of the deleted values to use the new values
    replace_all_uses_with(old_values, new_values, replace_graph_outputs=True)

    # insert new nodes after the index node
    graph_or_function.insert_after(insertion_point, new_nodes)
    graph_or_function.remove(old_nodes, safe=True)
|
|
456
|
+
|
|
457
|
+
|
|
458
|
+
def get_const_tensor(
    value: _core.Value, propagate_shape_type: bool = False
) -> _protocols.TensorProtocol | None:
    """Return the constant tensor behind a value, or None if there is none.

    The tensor is taken from ``value.const_value`` when that is set (as in the
    case of an initializer). Otherwise, if the value is produced by a
    ``Constant`` node in the default ONNX domain, the tensor is built from the
    single attribute of that node and named after the value.

    This function will not alter the ``const_value`` of the value, but
    it will propagate the shape and type of the constant tensor to the value
    if `propagate_shape_type` is set to True.

    .. versionadded:: 0.1.2

    Args:
        value: The value to get the constant tensor from.
        propagate_shape_type: If True, the shape and type of the constant
            tensor are written to the Value; a warning is logged when they
            differ from a shape or type that is already set.

    Returns:
        The constant tensor if it exists, otherwise None. None is also
        returned when the Constant node's attribute is a reference attribute.

    Raises:
        ValueError: If the Constant node does not have exactly one output or
            one attribute, or if the attribute name is not supported.
    """
    tensor = value.const_value
    if tensor is None:
        producer = value.producer()
        if producer is None:
            # Potentially a graph input
            return None
        if not (producer.op_type == "Constant" and producer.domain == ""):
            # Not a Constant node or not in the ONNX domain
            return None
        if len(producer.outputs) != 1:
            raise ValueError(
                f"Constant node '{producer.name}' must have exactly one output, "
                f"but has {len(producer.outputs)} outputs."
            )
        if len(producer.attributes) != 1:
            raise ValueError(
                f"Constant node '{producer.name}' must have exactly one attribute, "
                f"but has {len(producer.attributes)} attributes."
            )

        attr_name, attr_value = next(iter(producer.attributes.items()))
        if attr_value.is_ref():
            # TODO: Make it easier to resolve a reference attribute.
            # For now we just return None
            return None

        output = producer.outputs[0]
        # NumPy dtype used for each numeric ``value_*`` attribute.
        numeric_dtypes = {
            "value_float": np.float32,
            "value_floats": np.float32,
            "value_int": np.int64,
            "value_ints": np.int64,
        }
        if attr_name in numeric_dtypes:
            tensor = _core.Tensor(
                np.array(attr_value.value, dtype=numeric_dtypes[attr_name]),
                name=output.name,
            )
        elif attr_name in ("value_string", "value_strings"):
            tensor = _core.StringTensor(
                np.array(attr_value.value, dtype=np.bytes_), name=output.name
            )
        elif attr_name == "value":
            tensor = attr_value.as_tensor()
        else:
            raise ValueError(
                f"Unsupported attribute '{attr_name}' in Constant node '{producer.name}'. "
                "Expected one of 'value_float', 'value_floats', 'value_int', "
                "'value_ints', 'value_string', 'value_strings', or 'value'."
            )
        # Assign the name of the constant value to the tensor
        tensor.name = value.name

    if tensor is None or not propagate_shape_type:
        return tensor

    # Propagate the shape and type of the tensor to the value
    if value.shape is not None and value.shape != tensor.shape:
        logger.warning(
            "Value '%s' has a shape %s that differs from "
            "the constant tensor's shape %s. The value's shape will be updated.",
            value,
            value.shape,
            tensor.shape,
        )
    value.shape = tensor.shape  # type: ignore[assignment]

    tensor_type = _core.TensorType(tensor.dtype)
    if value.type is not None and value.type != tensor_type:
        logger.warning(
            "Value '%s' has a type '%s' that differs from "
            "the constant tensor's type '%s'. The value's type will be updated.",
            value,
            value.type,
            tensor_type,
        )
    value.type = tensor_type
    return tensor
|