onnx-ir 0.0.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of onnx-ir might be problematic. Click here for more details.
- onnx_ir/__init__.py +154 -0
- onnx_ir/_convenience.py +439 -0
- onnx_ir/_core.py +2875 -0
- onnx_ir/_display.py +49 -0
- onnx_ir/_enums.py +154 -0
- onnx_ir/_external_data.py +323 -0
- onnx_ir/_graph_comparison.py +23 -0
- onnx_ir/_internal/version_utils.py +118 -0
- onnx_ir/_io.py +50 -0
- onnx_ir/_linked_list.py +276 -0
- onnx_ir/_metadata.py +44 -0
- onnx_ir/_name_authority.py +72 -0
- onnx_ir/_protocols.py +598 -0
- onnx_ir/_tape.py +104 -0
- onnx_ir/_thirdparty/asciichartpy.py +313 -0
- onnx_ir/_type_casting.py +91 -0
- onnx_ir/convenience.py +32 -0
- onnx_ir/passes/__init__.py +33 -0
- onnx_ir/passes/_pass_infra.py +172 -0
- onnx_ir/serde.py +1551 -0
- onnx_ir/traversal.py +82 -0
- onnx_ir-0.0.1.dist-info/LICENSE +22 -0
- onnx_ir-0.0.1.dist-info/METADATA +73 -0
- onnx_ir-0.0.1.dist-info/RECORD +26 -0
- onnx_ir-0.0.1.dist-info/WHEEL +5 -0
- onnx_ir-0.0.1.dist-info/top_level.txt +1 -0
onnx_ir/__init__.py
ADDED
|
@@ -0,0 +1,154 @@
|
|
|
1
|
+
# Copyright (c) Microsoft Corporation.
|
|
2
|
+
# Licensed under the MIT License.
|
|
3
|
+
"""In-memory intermediate representation for ONNX graphs."""
|
|
4
|
+
|
|
5
|
+
# Public API of the ``onnx_ir`` package, i.e. the names exported by
# ``from onnx_ir import *``. ``__set_module`` further down in this file also
# iterates over these names.
__all__ = [
    # Modules
    "serde",
    # IR classes
    "Tensor",
    "ExternalTensor",
    "StringTensor",
    "SymbolicDim",
    "Shape",
    "TensorType",
    "OptionalType",
    "SequenceType",
    "SparseTensorType",
    "TypeAndShape",
    "Value",
    "Attr",
    "RefAttr",
    "Node",
    "Function",
    "Graph",
    "GraphView",
    "Model",
    # Constructors
    "AttrFloat32",
    "AttrFloat32s",
    "AttrGraph",
    "AttrGraphs",
    "AttrInt64",
    "AttrInt64s",
    "AttrSparseTensor",
    "AttrSparseTensors",
    "AttrString",
    "AttrStrings",
    "AttrTensor",
    "AttrTensors",
    "AttrTypeProto",
    "AttrTypeProtos",
    "Input",
    # Protocols
    "ArrayCompatible",
    "DLPackCompatible",
    "TensorProtocol",
    "ValueProtocol",
    "ModelProtocol",
    "NodeProtocol",
    "GraphProtocol",
    "GraphViewProtocol",
    "AttributeProtocol",
    "ReferenceAttributeProtocol",
    "SparseTensorProtocol",
    "SymbolicDimProtocol",
    "ShapeProtocol",
    "TypeProtocol",
    "MapTypeProtocol",
    "FunctionProtocol",
    # Enums
    "AttributeType",
    "DataType",
    # Types
    "OperatorIdentifier",
    # Protobuf compatible types
    "TensorProtoTensor",
    # Conversion functions
    "from_proto",
    "to_proto",
    # IR Tensor initializer
    "tensor",
    # Pass infrastructure (the ``passes`` and ``traversal`` submodules)
    "passes",
    "traversal",
    # IO
    "load",
    "save",
]
|
|
79
|
+
|
|
80
|
+
from onnx_ir import passes, serde, traversal
|
|
81
|
+
from onnx_ir._convenience import tensor
|
|
82
|
+
from onnx_ir._core import (
|
|
83
|
+
Attr,
|
|
84
|
+
AttrFloat32,
|
|
85
|
+
AttrFloat32s,
|
|
86
|
+
AttrGraph,
|
|
87
|
+
AttrGraphs,
|
|
88
|
+
AttrInt64,
|
|
89
|
+
AttrInt64s,
|
|
90
|
+
AttrSparseTensor,
|
|
91
|
+
AttrSparseTensors,
|
|
92
|
+
AttrString,
|
|
93
|
+
AttrStrings,
|
|
94
|
+
AttrTensor,
|
|
95
|
+
AttrTensors,
|
|
96
|
+
AttrTypeProto,
|
|
97
|
+
AttrTypeProtos,
|
|
98
|
+
ExternalTensor,
|
|
99
|
+
Function,
|
|
100
|
+
Graph,
|
|
101
|
+
GraphView,
|
|
102
|
+
Input,
|
|
103
|
+
Model,
|
|
104
|
+
Node,
|
|
105
|
+
OptionalType,
|
|
106
|
+
RefAttr,
|
|
107
|
+
SequenceType,
|
|
108
|
+
Shape,
|
|
109
|
+
SparseTensorType,
|
|
110
|
+
StringTensor,
|
|
111
|
+
SymbolicDim,
|
|
112
|
+
Tensor,
|
|
113
|
+
TensorType,
|
|
114
|
+
TypeAndShape,
|
|
115
|
+
Value,
|
|
116
|
+
)
|
|
117
|
+
from onnx_ir._enums import (
|
|
118
|
+
AttributeType,
|
|
119
|
+
DataType,
|
|
120
|
+
)
|
|
121
|
+
from onnx_ir._io import load, save
|
|
122
|
+
from onnx_ir._protocols import (
|
|
123
|
+
ArrayCompatible,
|
|
124
|
+
AttributeProtocol,
|
|
125
|
+
DLPackCompatible,
|
|
126
|
+
FunctionProtocol,
|
|
127
|
+
GraphProtocol,
|
|
128
|
+
GraphViewProtocol,
|
|
129
|
+
MapTypeProtocol,
|
|
130
|
+
ModelProtocol,
|
|
131
|
+
NodeProtocol,
|
|
132
|
+
OperatorIdentifier,
|
|
133
|
+
ReferenceAttributeProtocol,
|
|
134
|
+
ShapeProtocol,
|
|
135
|
+
SparseTensorProtocol,
|
|
136
|
+
SymbolicDimProtocol,
|
|
137
|
+
TensorProtocol,
|
|
138
|
+
TypeProtocol,
|
|
139
|
+
ValueProtocol,
|
|
140
|
+
)
|
|
141
|
+
from onnx_ir.serde import TensorProtoTensor, from_proto, to_proto
|
|
142
|
+
|
|
143
|
+
|
|
144
|
+
# Package-level debug flag, off by default. Nothing in this file reads it;
# presumably it is consulted elsewhere in the package (TODO confirm consumers).
DEBUG: bool = False
|
|
145
|
+
|
|
146
|
+
|
|
147
|
+
def __set_module() -> None:
    """Re-home every public object listed in ``__all__`` to this package.

    Assigning ``__module__ = __name__`` makes public classes and functions report
    this public package (for example in class reprs and generated docs) instead
    of the private submodule that actually defines them.

    Re-exported submodules (such as ``serde``) are skipped. ``__module__`` is
    not a meaningful attribute on a module object, and assigning it would only
    attach a stray attribute to that submodule.
    """
    # Local import so that ``types`` does not leak into the package namespace.
    import types

    global_dict = globals()
    for name in __all__:
        public_object = global_dict[name]
        if isinstance(public_object, types.ModuleType):
            continue
        public_object.__module__ = __name__
|
|
152
|
+
|
|
153
|
+
|
|
154
|
+
# Runs once at import time, after every name in ``__all__`` has been imported above.
__set_module()
|
onnx_ir/_convenience.py
ADDED
|
@@ -0,0 +1,439 @@
|
|
|
1
|
+
# Copyright (c) Microsoft Corporation.
|
|
2
|
+
# Licensed under the MIT License.
|
|
3
|
+
"""Convenience methods for constructing and manipulating the IR.
|
|
4
|
+
|
|
5
|
+
This is an internal only module. We should choose to expose some of the methods
|
|
6
|
+
in convenience.py after they are proven to be useful.
|
|
7
|
+
"""
|
|
8
|
+
|
|
9
|
+
from __future__ import annotations
|
|
10
|
+
|
|
11
|
+
# Names exported by ``from onnx_ir._convenience import *``. Other helpers defined
# in this module (``tensor``, ``create_value_mapping``, ``replace_nodes_and_values``)
# are not listed here and must be imported by name.
__all__ = [
    "convert_attribute",
    "convert_attributes",
    "replace_all_uses_with",
]
|
|
16
|
+
|
|
17
|
+
import typing
|
|
18
|
+
from typing import Mapping, Sequence, Union
|
|
19
|
+
|
|
20
|
+
import numpy as np
|
|
21
|
+
import onnx
|
|
22
|
+
|
|
23
|
+
from onnx_ir import _core, _enums, _protocols, serde
|
|
24
|
+
|
|
25
|
+
if typing.TYPE_CHECKING:
|
|
26
|
+
import numpy.typing as npt
|
|
27
|
+
|
|
28
|
+
# Python values accepted by ``convert_attribute`` / ``convert_attributes``. Without an
# explicit ``attr_type``, each value is mapped to an ONNX attribute type by
# ``_infer_attribute_type``. ``convert_attribute`` accepts ``None`` only together with
# an explicit ``attr_type``; ``convert_attributes`` skips ``None`` entries.
SupportedAttrTypes = Union[
    str,
    int,
    float,
    Sequence[int],
    Sequence[float],
    Sequence[str],
    _protocols.TensorProtocol,  # This includes all in-memory tensor types
    onnx.TensorProto,
    _core.Attr,
    _core.RefAttr,
    _protocols.GraphProtocol,
    Sequence[_protocols.GraphProtocol],
    _protocols.TypeProtocol,
    Sequence[_protocols.TypeProtocol],
    None,
]
|
|
45
|
+
|
|
46
|
+
|
|
47
|
+
def _infer_attribute_type(attr: SupportedAttrTypes) -> _enums.AttributeType:
    """Infer the attribute type based on the type of the Python object.

    Args:
        attr: The Python value whose ONNX attribute type should be inferred.

    Returns:
        The inferred attribute type.

    Raises:
        TypeError: If no supported attribute type matches ``attr``.

    .. note::
        ``bool`` is a subclass of ``int`` and is therefore inferred as ``INT``.
        An empty sequence satisfies every ``all(...)`` check vacuously and is
        inferred as ``INTS``; pass ``attr_type`` to :func:`convert_attribute`
        to choose a different type.
    """
    if isinstance(attr, int):
        return _enums.AttributeType.INT
    if isinstance(attr, float):
        return _enums.AttributeType.FLOAT
    if isinstance(attr, str):
        return _enums.AttributeType.STRING
    if isinstance(attr, (_core.Attr, _core.RefAttr)):
        return attr.type
    # Single tensors, graphs and types are recognized BEFORE the sequence checks
    # below. A graph object may itself be a Sequence (of nodes), so an empty graph
    # would otherwise pass ``all(isinstance(x, int) ...)`` vacuously and be
    # mis-inferred as INTS.
    # Within each tuple the Protocol comes last because isinstance checking on
    # Protocols can be slower than on concrete classes.
    if isinstance(attr, (_core.TensorBase, onnx.TensorProto, _protocols.TensorProtocol)):
        return _enums.AttributeType.TENSOR
    if isinstance(attr, (_core.Graph, _protocols.GraphProtocol)):
        return _enums.AttributeType.GRAPH
    if isinstance(
        attr,
        (_core.TensorType, _core.SequenceType, _core.OptionalType, _protocols.TypeProtocol),
    ):
        return _enums.AttributeType.TYPE_PROTO
    if isinstance(attr, Sequence):
        if all(isinstance(x, int) for x in attr):
            return _enums.AttributeType.INTS
        if all(isinstance(x, float) for x in attr):
            return _enums.AttributeType.FLOATS
        if all(isinstance(x, str) for x in attr):
            return _enums.AttributeType.STRINGS
        if all(isinstance(x, (_core.Graph, _protocols.GraphProtocol)) for x in attr):
            return _enums.AttributeType.GRAPHS
        if all(
            isinstance(
                x,
                (
                    _core.TensorType,
                    _core.SequenceType,
                    _core.OptionalType,
                    _protocols.TypeProtocol,
                ),
            )
            for x in attr
        ):
            return _enums.AttributeType.TYPE_PROTOS
    raise TypeError(f"Unsupported attribute type: '{type(attr)}'")


def convert_attribute(
    name: str,
    attr: SupportedAttrTypes,
    attr_type: _enums.AttributeType | None = None,
) -> _core.Attr | _core.RefAttr:
    """Convert a Python object to a _core.Attr object.

    This method is useful when constructing nodes with attributes. It infers the
    attribute type based on the type of the Python value.

    Args:
        name: The name of the attribute.
        attr: The value of the attribute. An existing ``Attr``/``RefAttr`` is
            returned unchanged after its name (and type, if given) are checked.
        attr_type: The type of the attribute. This is required when attr is None.
            When provided, it overrides the inferred type.

    Returns:
        An ``Attr`` (or ``RefAttr``) object.

    Raises:
        ValueError: If ``attr`` is ``None`` and ``attr_type`` is not provided, or if
            ``attr`` is an ``Attr``/``RefAttr`` whose name or type disagrees with
            ``name``/``attr_type``.
        TypeError: If the type of the attribute is not supported.
    """
    if attr is None:
        if attr_type is None:
            raise ValueError("attr_type must be provided when attr is None")
        return _core.Attr(name, attr_type, None)

    if isinstance(attr, (_core.Attr, _core.RefAttr)):
        if attr.name != name:
            raise ValueError(
                f"Attribute name '{attr.name}' does not match provided name '{name}'"
            )
        if attr_type is not None and attr.type != attr_type:
            raise ValueError(
                f"Attribute type '{attr.type}' does not match provided type '{attr_type}'"
            )
        return attr

    if attr_type is None:
        attr_type = _infer_attribute_type(attr)

    if attr_type == _enums.AttributeType.INT:
        return _core.AttrInt64(name, attr)  # type: ignore
    if attr_type == _enums.AttributeType.FLOAT:
        return _core.AttrFloat32(name, attr)  # type: ignore
    if attr_type == _enums.AttributeType.STRING:
        return _core.AttrString(name, attr)  # type: ignore
    if attr_type == _enums.AttributeType.INTS:
        return _core.AttrInt64s(name, attr)  # type: ignore
    if attr_type == _enums.AttributeType.FLOATS:
        return _core.AttrFloat32s(name, attr)  # type: ignore
    if attr_type == _enums.AttributeType.STRINGS:
        return _core.AttrStrings(name, attr)  # type: ignore
    if attr_type == _enums.AttributeType.TENSOR:
        if isinstance(attr, (_core.TensorBase, _protocols.TensorProtocol)):
            return _core.AttrTensor(name, attr)
        if isinstance(attr, onnx.TensorProto):
            # Wrap the proto so it satisfies the in-memory tensor interface.
            return _core.AttrTensor(name, serde.TensorProtoTensor(attr))
        # Any other object falls through to the TypeError at the end.
    if attr_type == _enums.AttributeType.GRAPH:
        return _core.AttrGraph(name, attr)  # type: ignore[arg-type]
    if attr_type == _enums.AttributeType.GRAPHS:
        return _core.AttrGraphs(name, attr)  # type: ignore[arg-type]
    if attr_type == _enums.AttributeType.TYPE_PROTO:
        return _core.AttrTypeProto(name, attr)  # type: ignore[arg-type]
    if attr_type == _enums.AttributeType.TYPE_PROTOS:
        return _core.AttrTypeProtos(name, attr)  # type: ignore[arg-type]
    raise TypeError(f"Unsupported attribute type: '{type(attr)}'")


def convert_attributes(
    attrs: Mapping[str, SupportedAttrTypes],
) -> list[_core.Attr | _core.RefAttr]:
    """Convert a dictionary of attributes to a list of _core.Attr objects.

    It infers the attribute type based on the type of the value. The supported
    types are: int, float, str, Sequence[int], Sequence[float], Sequence[str],
    tensors (:class:`_core.Tensor` and other in-memory tensors, or
    ``onnx.TensorProto``), graphs and sequences of graphs, types and sequences of
    types, and existing :class:`_core.Attr` / :class:`_core.RefAttr` objects.
    Entries whose value is ``None`` are skipped::

        >>> import onnx_ir as ir
        >>> import onnx
        >>> import numpy as np
        >>> attrs = {
        ...     "int": 1,
        ...     "float": 1.0,
        ...     "str": "hello",
        ...     "ints": [1, 2, 3],
        ...     "floats": [1.0, 2.0, 3.0],
        ...     "strings": ["hello", "world"],
        ...     "tensor": ir.Tensor(np.array([1.0, 2.0, 3.0])),
        ...     "tensor_proto":
        ...         onnx.TensorProto(
        ...             dims=[3],
        ...             data_type=onnx.TensorProto.FLOAT,
        ...             float_data=[1.0, 2.0, 3.0],
        ...             name="proto",
        ...         ),
        ...     "graph": ir.Graph([], [], nodes=[], name="graph0"),
        ...     "graphs": [ir.Graph([], [], nodes=[], name="graph1"), ir.Graph([], [], nodes=[], name="graph2")],
        ...     "type_proto": ir.TensorType(ir.DataType.FLOAT),
        ...     "type_protos": [ir.TensorType(ir.DataType.FLOAT), ir.TensorType(ir.DataType.FLOAT)],
        ... }
        >>> convert_attributes(attrs)
        [Attr('int', INT, 1), Attr('float', FLOAT, 1.0), Attr('str', STRING, 'hello'), Attr('ints', INTS, [1, 2, 3]), Attr('floats', FLOATS, [1.0, 2.0, 3.0]), Attr('strings', STRINGS, ['hello', 'world']), Attr('tensor', TENSOR, Tensor<DOUBLE,[3]>(array([1., 2., 3.]), name=None)), Attr('tensor_proto', TENSOR, TensorProtoTensor<FLOAT,[3]>(name='proto')), Attr('graph', GRAPH, Graph(
            name='graph0',
            inputs=(
        <BLANKLINE>
            ),
            outputs=(
        <BLANKLINE>
            ),
            len()=0
        )), Attr('graphs', GRAPHS, [Graph(
            name='graph1',
            inputs=(
        <BLANKLINE>
            ),
            outputs=(
        <BLANKLINE>
            ),
            len()=0
        ), Graph(
            name='graph2',
            inputs=(
        <BLANKLINE>
            ),
            outputs=(
        <BLANKLINE>
            ),
            len()=0
        )]), Attr('type_proto', TYPE_PROTO, Tensor(FLOAT)), Attr('type_protos', TYPE_PROTOS, [Tensor(FLOAT), Tensor(FLOAT)])]

    Args:
        attrs: A dictionary of {<attribute name>: <python objects>} to convert.

    Returns:
        A list of _core.Attr objects, in the iteration order of ``attrs``, with
        ``None``-valued entries omitted.
    """
    attributes: list[_core.Attr | _core.RefAttr] = []
    for name, attr in attrs.items():
        # None carries no type information here, so such entries are dropped
        # rather than passed to convert_attribute (which would raise).
        if attr is not None:
            attributes.append(convert_attribute(name, attr))
    return attributes
|
|
236
|
+
|
|
237
|
+
|
|
238
|
+
def replace_all_uses_with(
    values: _protocols.ValueProtocol | Sequence[_protocols.ValueProtocol],
    replacements: _protocols.ValueProtocol | Sequence[_protocols.ValueProtocol],
) -> None:
    """Redirect every consumer of ``values`` to the matching ``replacements``.

    Use this after swapping nodes out of a graph: the nodes that consumed the old
    outputs must be rewired to consume the outputs of the new nodes instead.

    For example, suppose we have the following graph::

        A -> {B, C}

    We want to replace the node A with a new node D::

        >>> import onnx_ir as ir
        >>> input = ir.Input("input")
        >>> node_a = ir.Node("", "A", [input])
        >>> node_b = ir.Node("", "B", node_a.outputs)
        >>> node_c = ir.Node("", "C", node_a.outputs)
        >>> node_d = ir.Node("", "D", [input])
        >>> replace_all_uses_with(node_a.outputs, node_d.outputs)
        >>> len(node_b.inputs)
        1
        >>> node_b.inputs[0].producer().op_type
        'D'
        >>> len(node_c.inputs)
        1
        >>> node_c.inputs[0].producer().op_type
        'D'
        >>> len(node_a.outputs[0].uses())
        0

    A single value or replacement is treated as a one-element sequence. Sequences
    are paired positionally: every user of the first value is rewired to the first
    replacement, every user of the second value to the second, and so on.

    .. note::
        Graph outputs are not updated by this function; reassign them yourself if
        any replaced value is a graph output. Remove the old nodes with
        ``graph.remove()`` once they are no longer needed.

    Args:
        values: The value or values to be replaced.
        replacements: The new value or values to use as inputs.

    Raises:
        ValueError: If ``values`` and ``replacements`` differ in length.
    """
    old_values = values if isinstance(values, Sequence) else (values,)
    new_values = replacements if isinstance(replacements, Sequence) else (replacements,)
    if len(old_values) != len(new_values):
        raise ValueError("The number of values and replacements must match.")
    for old_value, new_value in zip(old_values, new_values):
        # Snapshot the uses first: rewiring an input removes it from the old
        # value's use list while we would otherwise still be iterating it.
        uses_snapshot = tuple(old_value.uses())
        for consumer, input_index in uses_snapshot:
            consumer.replace_input_with(input_index, new_value)
|
|
292
|
+
|
|
293
|
+
|
|
294
|
+
def tensor(
    value: npt.ArrayLike
    | onnx.TensorProto
    | _protocols.DLPackCompatible
    | _protocols.ArrayCompatible,
    dtype: _enums.DataType | None = None,
    name: str | None = None,
    doc_string: str | None = None,
) -> _protocols.TensorProtocol:
    """Create a tensor value from an ArrayLike object or a TensorProto.

    The dtype must match the value. Reinterpretation of the value is
    not supported, unless if the value is a plain Python object, in which case
    it is converted to a numpy array with the given dtype.

    ``value`` can be a numpy array, a plain Python object, or a TensorProto.

    Example::

        >>> import onnx_ir as ir
        >>> import numpy as np
        >>> import ml_dtypes
        >>> import onnx
        >>> ir.tensor(np.array([1, 2, 3], dtype=np.int16))
        Tensor<INT16,[3]>(array([1, 2, 3], dtype=int16), name=None)
        >>> ir.tensor([1, 2, 3], dtype=ir.DataType.BFLOAT16)
        Tensor<BFLOAT16,[3]>(array([1, 2, 3], dtype=bfloat16), name=None)
        >>> tp_tensor = ir.tensor(onnx.helper.make_tensor("tensor", onnx.TensorProto.FLOAT, dims=[], vals=[0.5]))
        >>> tp_tensor.numpy()
        array(0.5, dtype=float32)

    Args:
        value: The data to create the tensor from. This is one of:
            an existing tensor, which is returned as-is; an ``onnx.TensorProto``,
            which is deserialized; an object supporting DLPack or the array
            protocol; or a plain Python object such as a nested list, which is
            converted with ``np.array``.
        dtype: The data type of the tensor.
        name: The name of the tensor. Ignored when ``value`` is already a tensor.
        doc_string: The documentation string of the tensor. Ignored when ``value``
            is already a tensor.

    Returns:
        A tensor value.

    Raises:
        ValueError: If the dtype does not match the value when value is not a plain Python
            object like ``list[int]``.
    """
    if isinstance(value, _protocols.TensorProtocol):
        if dtype is not None and dtype != value.dtype:
            raise ValueError(
                f"The dtype must match the value when value is a Tensor. dtype={dtype}, value.dtype={value.dtype}. "
                "You do not have to specify the dtype when value is a Tensor."
            )
        return value
    if isinstance(value, onnx.TensorProto):
        tensor_ = serde.deserialize_tensor(value)
        # Validate before touching the deserialized tensor's metadata.
        if dtype is not None and dtype != tensor_.dtype:
            raise ValueError(
                f"The dtype must match the value when value is a TensorProto. dtype={dtype}, value.data_type={tensor_.dtype}. "
                "You do not have to specify the dtype when value is a TensorProto."
            )
        if name is not None:
            tensor_.name = name
        if doc_string is not None:
            tensor_.doc_string = doc_string
        return tensor_
    if isinstance(value, (_protocols.DLPackCompatible, _protocols.ArrayCompatible)):
        return _core.Tensor(value, dtype=dtype, name=name, doc_string=doc_string)
    # Plain Python object: let numpy build the array, converting to the requested
    # dtype when one is given.
    numpy_dtype = dtype.numpy() if dtype is not None else None
    array = np.array(value, dtype=numpy_dtype)
    return _core.Tensor(
        array,
        dtype=dtype,
        shape=_core.Shape(array.shape),
        name=name,
        doc_string=doc_string,
    )
|
|
372
|
+
|
|
373
|
+
|
|
374
|
+
def create_value_mapping(graph: _core.Graph) -> dict[str, _core.Value]:
    """Map value names to the values defined at the top level of ``graph``.

    Entries are collected in this order: ``graph.initializers``, then
    ``graph.inputs``, then the outputs of each node in ``graph``. A later entry
    with the same name overwrites an earlier one. Values inside subgraphs are not
    visited. Inputs and node outputs whose name is ``None`` or ``""`` are left out.

    Args:
        graph: The graph whose values are collected.

    Returns:
        A dictionary from value name to value.
    """
    name_to_value: dict[str, _core.Value] = dict(graph.initializers)
    # Unnamed values (name is None or "") cannot serve as keys, so they are skipped.
    name_to_value.update(
        (graph_input.name, graph_input) for graph_input in graph.inputs if graph_input.name
    )
    name_to_value.update(
        (output.name, output)
        for node in graph
        for output in node.outputs
        if output.name
    )
    return name_to_value
|
|
398
|
+
|
|
399
|
+
|
|
400
|
+
def replace_nodes_and_values(
    graph_or_function: _core.Graph | _core.Function,
    /,
    insertion_point: _core.Node,
    old_nodes: Sequence[_core.Node],
    new_nodes: Sequence[_core.Node],
    old_values: Sequence[_core.Value],
    new_values: Sequence[_core.Value],
) -> None:
    """Replaces nodes and values in the graph or function.

    For each positional ``(old_value, new_value)`` pair:

    * ``type``, ``shape``, ``const_value`` and ``name`` are copied from
      ``old_value`` to ``new_value``;
    * every node input that used ``old_value`` is rewired to ``new_value``
      (via ``replace_all_uses_with``);
    * any graph/function output slot holding ``old_value`` is set to ``new_value``.

    Finally ``new_nodes`` are inserted after ``insertion_point``, and ``old_nodes``
    are removed with ``remove(..., safe=True)``.

    Args:
        graph_or_function: The graph or function to replace nodes and values in.
        insertion_point: The node to insert the new nodes after.
        old_nodes: The nodes to replace.
        new_nodes: The nodes to replace with.
        old_values: The values to replace.
        new_values: The values to replace with.

    Raises:
        ValueError: If ``old_values`` and ``new_values`` differ in length. This is
            checked before anything is modified.
    """
    # Validate up front so a mismatched call does not leave new_values partially
    # renamed/retyped before the error surfaces.
    if len(old_values) != len(new_values):
        raise ValueError("The number of values and replacements must match.")

    for old_value, new_value in zip(old_values, new_values):
        # Propagate relevant info from old value to new value
        # TODO(Rama): Perhaps this should be a separate utility function. Also, consider
        # merging old and new type/shape info.
        new_value.type = old_value.type
        new_value.shape = old_value.shape
        new_value.const_value = old_value.const_value
        new_value.name = old_value.name

    # Reconnect the users of the deleted values to use the new values
    replace_all_uses_with(old_values, new_values)
    # Update graph/function outputs if the node generates output
    replacement_mapping = dict(zip(old_values, new_values))
    for idx, graph_or_function_output in enumerate(graph_or_function.outputs):
        if graph_or_function_output in replacement_mapping:
            graph_or_function.outputs[idx] = replacement_mapping[graph_or_function_output]

    # insert new nodes after the index node
    graph_or_function.insert_after(insertion_point, new_nodes)
    graph_or_function.remove(old_nodes, safe=True)
|