onnx-ir 0.0.1__py3-none-any.whl → 0.1.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of onnx-ir might be problematic; see the registry's advisory for this release for more details.
- onnx_ir/__init__.py +23 -10
- onnx_ir/{_convenience.py → _convenience/__init__.py} +40 -102
- onnx_ir/_convenience/_constructors.py +213 -0
- onnx_ir/_core.py +857 -233
- onnx_ir/_display.py +2 -2
- onnx_ir/_enums.py +107 -5
- onnx_ir/_graph_comparison.py +2 -2
- onnx_ir/_graph_containers.py +268 -0
- onnx_ir/_io.py +57 -10
- onnx_ir/_linked_list.py +15 -7
- onnx_ir/_metadata.py +4 -3
- onnx_ir/_name_authority.py +2 -2
- onnx_ir/_polyfill.py +26 -0
- onnx_ir/_protocols.py +31 -13
- onnx_ir/_tape.py +139 -32
- onnx_ir/_thirdparty/asciichartpy.py +1 -4
- onnx_ir/_type_casting.py +18 -3
- onnx_ir/{_internal/version_utils.py → _version_utils.py} +2 -29
- onnx_ir/convenience.py +4 -2
- onnx_ir/external_data.py +401 -0
- onnx_ir/passes/__init__.py +8 -2
- onnx_ir/passes/_pass_infra.py +173 -56
- onnx_ir/passes/common/__init__.py +36 -0
- onnx_ir/passes/common/_c_api_utils.py +76 -0
- onnx_ir/passes/common/clear_metadata_and_docstring.py +60 -0
- onnx_ir/passes/common/constant_manipulation.py +232 -0
- onnx_ir/passes/common/inliner.py +331 -0
- onnx_ir/passes/common/onnx_checker.py +57 -0
- onnx_ir/passes/common/shape_inference.py +112 -0
- onnx_ir/passes/common/topological_sort.py +33 -0
- onnx_ir/passes/common/unused_removal.py +196 -0
- onnx_ir/serde.py +288 -124
- onnx_ir/tape.py +15 -0
- onnx_ir/tensor_adapters.py +122 -0
- onnx_ir/testing.py +197 -0
- onnx_ir/traversal.py +4 -3
- onnx_ir-0.1.0.dist-info/METADATA +53 -0
- onnx_ir-0.1.0.dist-info/RECORD +41 -0
- {onnx_ir-0.0.1.dist-info → onnx_ir-0.1.0.dist-info}/WHEEL +1 -1
- onnx_ir-0.1.0.dist-info/licenses/LICENSE +202 -0
- onnx_ir/_external_data.py +0 -323
- onnx_ir-0.0.1.dist-info/LICENSE +0 -22
- onnx_ir-0.0.1.dist-info/METADATA +0 -73
- onnx_ir-0.0.1.dist-info/RECORD +0 -26
- {onnx_ir-0.0.1.dist-info → onnx_ir-0.1.0.dist-info}/top_level.txt +0 -0
onnx_ir/_core.py
CHANGED
|
@@ -1,5 +1,5 @@
|
|
|
1
|
-
# Copyright (c)
|
|
2
|
-
#
|
|
1
|
+
# Copyright (c) ONNX Project Contributors
|
|
2
|
+
# SPDX-License-Identifier: Apache-2.0
|
|
3
3
|
"""data structures for the intermediate representation."""
|
|
4
4
|
|
|
5
5
|
# NOTES for developers:
|
|
@@ -22,26 +22,37 @@ import os
|
|
|
22
22
|
import sys
|
|
23
23
|
import textwrap
|
|
24
24
|
import typing
|
|
25
|
-
from
|
|
26
|
-
|
|
27
|
-
Any,
|
|
25
|
+
from collections import OrderedDict
|
|
26
|
+
from collections.abc import (
|
|
28
27
|
Collection,
|
|
29
|
-
Generic,
|
|
30
28
|
Hashable,
|
|
31
29
|
Iterable,
|
|
32
30
|
Iterator,
|
|
33
|
-
|
|
31
|
+
MutableMapping,
|
|
32
|
+
MutableSequence,
|
|
34
33
|
Sequence,
|
|
34
|
+
)
|
|
35
|
+
from collections.abc import (
|
|
36
|
+
Set as AbstractSet,
|
|
37
|
+
)
|
|
38
|
+
from typing import (
|
|
39
|
+
Any,
|
|
40
|
+
Callable,
|
|
41
|
+
Generic,
|
|
42
|
+
NamedTuple,
|
|
43
|
+
SupportsInt,
|
|
35
44
|
Union,
|
|
36
45
|
)
|
|
37
46
|
|
|
38
47
|
import ml_dtypes
|
|
39
48
|
import numpy as np
|
|
49
|
+
from typing_extensions import TypeIs
|
|
40
50
|
|
|
41
51
|
import onnx_ir
|
|
42
52
|
from onnx_ir import (
|
|
43
53
|
_display,
|
|
44
54
|
_enums,
|
|
55
|
+
_graph_containers,
|
|
45
56
|
_linked_list,
|
|
46
57
|
_metadata,
|
|
47
58
|
_name_authority,
|
|
@@ -70,6 +81,7 @@ _NON_NUMPY_NATIVE_TYPES = frozenset(
|
|
|
70
81
|
_enums.DataType.FLOAT8E5M2FNUZ,
|
|
71
82
|
_enums.DataType.INT4,
|
|
72
83
|
_enums.DataType.UINT4,
|
|
84
|
+
_enums.DataType.FLOAT4E2M1,
|
|
73
85
|
)
|
|
74
86
|
)
|
|
75
87
|
|
|
@@ -93,7 +105,23 @@ def _compatible_with_dlpack(obj: Any) -> TypeGuard[_protocols.DLPackCompatible]:
|
|
|
93
105
|
class TensorBase(abc.ABC, _protocols.TensorProtocol, _display.PrettyPrintable):
|
|
94
106
|
"""Convenience Shared methods for classes implementing TensorProtocol."""
|
|
95
107
|
|
|
96
|
-
__slots__ = (
|
|
108
|
+
__slots__ = (
|
|
109
|
+
"_doc_string",
|
|
110
|
+
"_metadata",
|
|
111
|
+
"_metadata_props",
|
|
112
|
+
"_name",
|
|
113
|
+
)
|
|
114
|
+
|
|
115
|
+
def __init__(
|
|
116
|
+
self,
|
|
117
|
+
name: str | None = None,
|
|
118
|
+
doc_string: str | None = None,
|
|
119
|
+
metadata_props: dict[str, str] | None = None,
|
|
120
|
+
) -> None:
|
|
121
|
+
self._metadata: _metadata.MetadataStore | None = None
|
|
122
|
+
self._metadata_props: dict[str, str] | None = metadata_props
|
|
123
|
+
self._name: str | None = name
|
|
124
|
+
self._doc_string: str | None = doc_string
|
|
97
125
|
|
|
98
126
|
def _printable_type_shape(self) -> str:
|
|
99
127
|
"""Return a string representation of the shape and data type."""
|
|
@@ -106,10 +134,28 @@ class TensorBase(abc.ABC, _protocols.TensorProtocol, _display.PrettyPrintable):
|
|
|
106
134
|
"""
|
|
107
135
|
return f"{self.__class__.__name__}<{self._printable_type_shape()}>"
|
|
108
136
|
|
|
137
|
+
@property
|
|
138
|
+
def name(self) -> str | None:
|
|
139
|
+
"""The name of the tensor."""
|
|
140
|
+
return self._name
|
|
141
|
+
|
|
142
|
+
@name.setter
|
|
143
|
+
def name(self, value: str | None) -> None:
|
|
144
|
+
self._name = value
|
|
145
|
+
|
|
146
|
+
@property
|
|
147
|
+
def doc_string(self) -> str | None:
|
|
148
|
+
"""The documentation string."""
|
|
149
|
+
return self._doc_string
|
|
150
|
+
|
|
151
|
+
@doc_string.setter
|
|
152
|
+
def doc_string(self, value: str | None) -> None:
|
|
153
|
+
self._doc_string = value
|
|
154
|
+
|
|
109
155
|
@property
|
|
110
156
|
def size(self) -> int:
|
|
111
157
|
"""The number of elements in the tensor."""
|
|
112
|
-
return
|
|
158
|
+
return math.prod(self.shape.numpy()) # type: ignore[attr-defined]
|
|
113
159
|
|
|
114
160
|
@property
|
|
115
161
|
def nbytes(self) -> int:
|
|
@@ -117,6 +163,23 @@ class TensorBase(abc.ABC, _protocols.TensorProtocol, _display.PrettyPrintable):
|
|
|
117
163
|
# Use math.ceil because when dtype is INT4, the itemsize is 0.5
|
|
118
164
|
return math.ceil(self.dtype.itemsize * self.size)
|
|
119
165
|
|
|
166
|
+
@property
|
|
167
|
+
def metadata_props(self) -> dict[str, str]:
|
|
168
|
+
if self._metadata_props is None:
|
|
169
|
+
self._metadata_props = {}
|
|
170
|
+
return self._metadata_props
|
|
171
|
+
|
|
172
|
+
@property
|
|
173
|
+
def meta(self) -> _metadata.MetadataStore:
|
|
174
|
+
"""The metadata store for intermediate analysis.
|
|
175
|
+
|
|
176
|
+
Write to the :attr:`metadata_props` if you would like the metadata to be serialized
|
|
177
|
+
to the ONNX proto.
|
|
178
|
+
"""
|
|
179
|
+
if self._metadata is None:
|
|
180
|
+
self._metadata = _metadata.MetadataStore()
|
|
181
|
+
return self._metadata
|
|
182
|
+
|
|
120
183
|
def display(self, *, page: bool = False) -> None:
|
|
121
184
|
rich = _display.require_rich()
|
|
122
185
|
|
|
@@ -182,7 +245,7 @@ def _check_numpy_representation_type(array: np.ndarray, dtype: _enums.DataType)
|
|
|
182
245
|
When the dtype is not one of the numpy native dtypes, the value needs need to be:
|
|
183
246
|
|
|
184
247
|
- ``int8`` or ``uint8`` for int4, with the sign bit extended to 8 bits.
|
|
185
|
-
- ``uint8`` for uint4.
|
|
248
|
+
- ``uint8`` for uint4 or float4.
|
|
186
249
|
- ``uint8`` for 8-bit data types.
|
|
187
250
|
- ``uint16`` for bfloat16
|
|
188
251
|
|
|
@@ -195,7 +258,7 @@ def _check_numpy_representation_type(array: np.ndarray, dtype: _enums.DataType)
|
|
|
195
258
|
)
|
|
196
259
|
if dtype.itemsize == 1 and array.dtype not in (
|
|
197
260
|
np.uint8,
|
|
198
|
-
ml_dtypes.
|
|
261
|
+
ml_dtypes.float8_e4m3fnuz,
|
|
199
262
|
ml_dtypes.float8_e4m3fn,
|
|
200
263
|
ml_dtypes.float8_e5m2fnuz,
|
|
201
264
|
ml_dtypes.float8_e5m2,
|
|
@@ -213,6 +276,11 @@ def _check_numpy_representation_type(array: np.ndarray, dtype: _enums.DataType)
|
|
|
213
276
|
raise TypeError(
|
|
214
277
|
f"The numpy array dtype must be uint8 or or ml_dtypes.uint4 (not {array.dtype}) for IR data type {dtype}."
|
|
215
278
|
)
|
|
279
|
+
if dtype == _enums.DataType.FLOAT4E2M1:
|
|
280
|
+
if array.dtype not in (np.uint8, ml_dtypes.float4_e2m1fn):
|
|
281
|
+
raise TypeError(
|
|
282
|
+
f"The numpy array dtype must be uint8 or ml_dtypes.float4_e2m1fn (not {array.dtype}) for IR data type {dtype}."
|
|
283
|
+
)
|
|
216
284
|
return
|
|
217
285
|
|
|
218
286
|
try:
|
|
@@ -256,6 +324,8 @@ def _maybe_view_np_array_with_ml_dtypes(
|
|
|
256
324
|
return array.view(ml_dtypes.int4)
|
|
257
325
|
if dtype == _enums.DataType.UINT4:
|
|
258
326
|
return array.view(ml_dtypes.uint4)
|
|
327
|
+
if dtype == _enums.DataType.FLOAT4E2M1:
|
|
328
|
+
return array.view(ml_dtypes.float4_e2m1fn)
|
|
259
329
|
return array
|
|
260
330
|
|
|
261
331
|
|
|
@@ -298,12 +368,8 @@ class Tensor(TensorBase, _protocols.TensorProtocol, Generic[TArrayCompatible]):
|
|
|
298
368
|
|
|
299
369
|
__slots__ = (
|
|
300
370
|
"_dtype",
|
|
301
|
-
"_metadata",
|
|
302
|
-
"_metadata_props",
|
|
303
371
|
"_raw",
|
|
304
372
|
"_shape",
|
|
305
|
-
"doc_string",
|
|
306
|
-
"name",
|
|
307
373
|
)
|
|
308
374
|
|
|
309
375
|
def __init__(
|
|
@@ -322,7 +388,7 @@ class Tensor(TensorBase, _protocols.TensorProtocol, Generic[TArrayCompatible]):
|
|
|
322
388
|
value: The backing data of the tensor. It can be a numpy array compatible object or a DLPack compatible object.
|
|
323
389
|
When the dtype is not one of the numpy native dtypes, the value needs
|
|
324
390
|
to be ``uint8`` for 4-bit and 8-bit data types, and ``uint16`` for bfloat16
|
|
325
|
-
when the value is a numpy array;
|
|
391
|
+
when the value is a numpy array; ``dtype`` must be specified in this case.
|
|
326
392
|
dtype: The data type of the tensor. It can be None only when value is a numpy array.
|
|
327
393
|
Users are responsible for making sure the dtype matches the value when value is not a numpy array.
|
|
328
394
|
shape: The shape of the tensor. If None, the shape is obtained from the value.
|
|
@@ -336,6 +402,7 @@ class Tensor(TensorBase, _protocols.TensorProtocol, Generic[TArrayCompatible]):
|
|
|
336
402
|
ValueError: If the shape is not specified and the value does not have a shape attribute.
|
|
337
403
|
ValueError: If the dtype is not specified and the value is not a numpy array.
|
|
338
404
|
"""
|
|
405
|
+
super().__init__(name=name, doc_string=doc_string, metadata_props=metadata_props)
|
|
339
406
|
# NOTE: We should not do any copying here for performance reasons
|
|
340
407
|
if not _compatible_with_numpy(value) and not _compatible_with_dlpack(value):
|
|
341
408
|
raise TypeError(f"Expected an array compatible object, got {type(value)}")
|
|
@@ -349,7 +416,7 @@ class Tensor(TensorBase, _protocols.TensorProtocol, Generic[TArrayCompatible]):
|
|
|
349
416
|
self._shape = Shape(getattr(value, "shape"), frozen=True) # noqa: B009
|
|
350
417
|
else:
|
|
351
418
|
self._shape = shape
|
|
352
|
-
self._shape.
|
|
419
|
+
self._shape.freeze()
|
|
353
420
|
if dtype is None:
|
|
354
421
|
if isinstance(value, np.ndarray):
|
|
355
422
|
self._dtype = _enums.DataType.from_numpy(value.dtype)
|
|
@@ -370,17 +437,13 @@ class Tensor(TensorBase, _protocols.TensorProtocol, Generic[TArrayCompatible]):
|
|
|
370
437
|
value = _maybe_view_np_array_with_ml_dtypes(value, self._dtype) # type: ignore[assignment]
|
|
371
438
|
|
|
372
439
|
self._raw = value
|
|
373
|
-
self.name = name
|
|
374
|
-
self.doc_string = doc_string
|
|
375
|
-
self._metadata: _metadata.MetadataStore | None = None
|
|
376
|
-
self._metadata_props = metadata_props
|
|
377
440
|
|
|
378
441
|
def __array__(self, dtype: Any = None) -> np.ndarray:
|
|
379
442
|
if isinstance(self._raw, np.ndarray) or _compatible_with_numpy(self._raw):
|
|
380
443
|
return self._raw.__array__(dtype)
|
|
381
|
-
assert _compatible_with_dlpack(
|
|
382
|
-
self._raw
|
|
383
|
-
)
|
|
444
|
+
assert _compatible_with_dlpack(self._raw), (
|
|
445
|
+
f"Bug: Expected DLPack or Numpy compatible objects, got {type(self._raw)}"
|
|
446
|
+
)
|
|
384
447
|
return np.from_dlpack(self._raw)
|
|
385
448
|
|
|
386
449
|
def __dlpack__(self, *, stream: Any = None) -> Any:
|
|
@@ -394,7 +457,10 @@ class Tensor(TensorBase, _protocols.TensorProtocol, Generic[TArrayCompatible]):
|
|
|
394
457
|
return self.__array__().__dlpack_device__()
|
|
395
458
|
|
|
396
459
|
def __repr__(self) -> str:
|
|
397
|
-
|
|
460
|
+
# Avoid multi-line repr
|
|
461
|
+
tensor_lines = repr(self._raw).split("\n")
|
|
462
|
+
tensor_text = " ".join(line.strip() for line in tensor_lines)
|
|
463
|
+
return f"{self._repr_base()}({tensor_text}, name={self.name!r})"
|
|
398
464
|
|
|
399
465
|
@property
|
|
400
466
|
def dtype(self) -> _enums.DataType:
|
|
@@ -431,7 +497,11 @@ class Tensor(TensorBase, _protocols.TensorProtocol, Generic[TArrayCompatible]):
|
|
|
431
497
|
"""
|
|
432
498
|
# TODO(justinchuby): Support DLPack
|
|
433
499
|
array = self.numpy()
|
|
434
|
-
if self.dtype in {
|
|
500
|
+
if self.dtype in {
|
|
501
|
+
_enums.DataType.INT4,
|
|
502
|
+
_enums.DataType.UINT4,
|
|
503
|
+
_enums.DataType.FLOAT4E2M1,
|
|
504
|
+
}:
|
|
435
505
|
# Pack the array into int4
|
|
436
506
|
array = _type_casting.pack_int4(array)
|
|
437
507
|
else:
|
|
@@ -440,23 +510,6 @@ class Tensor(TensorBase, _protocols.TensorProtocol, Generic[TArrayCompatible]):
|
|
|
440
510
|
array = array.view(array.dtype.newbyteorder("<"))
|
|
441
511
|
return array.tobytes()
|
|
442
512
|
|
|
443
|
-
@property
|
|
444
|
-
def metadata_props(self) -> dict[str, str]:
|
|
445
|
-
if self._metadata_props is None:
|
|
446
|
-
self._metadata_props = {}
|
|
447
|
-
return self._metadata_props
|
|
448
|
-
|
|
449
|
-
@property
|
|
450
|
-
def meta(self) -> _metadata.MetadataStore:
|
|
451
|
-
"""The metadata store for intermediate analysis.
|
|
452
|
-
|
|
453
|
-
Write to the :attr:`metadata_props` if you would like the metadata to be serialized
|
|
454
|
-
to the ONNX proto.
|
|
455
|
-
"""
|
|
456
|
-
if self._metadata is None:
|
|
457
|
-
self._metadata = _metadata.MetadataStore()
|
|
458
|
-
return self._metadata
|
|
459
|
-
|
|
460
513
|
|
|
461
514
|
class ExternalTensor(TensorBase, _protocols.TensorProtocol): # pylint: disable=too-many-ancestors
|
|
462
515
|
"""An immutable concrete tensor with its data store on disk.
|
|
@@ -497,12 +550,9 @@ class ExternalTensor(TensorBase, _protocols.TensorProtocol): # pylint: disable=
|
|
|
497
550
|
"_dtype",
|
|
498
551
|
"_length",
|
|
499
552
|
"_location",
|
|
500
|
-
"_metadata",
|
|
501
|
-
"_metadata_props",
|
|
502
553
|
"_offset",
|
|
503
554
|
"_shape",
|
|
504
|
-
"
|
|
505
|
-
"name",
|
|
555
|
+
"_valid",
|
|
506
556
|
"raw",
|
|
507
557
|
)
|
|
508
558
|
|
|
@@ -532,6 +582,7 @@ class ExternalTensor(TensorBase, _protocols.TensorProtocol): # pylint: disable=
|
|
|
532
582
|
metadata_props: The metadata properties.
|
|
533
583
|
base_dir: The base directory for the external data. It is used to resolve relative paths.
|
|
534
584
|
"""
|
|
585
|
+
super().__init__(name=name, doc_string=doc_string, metadata_props=metadata_props)
|
|
535
586
|
# NOTE: Do not verify the location by default. This is because the location field
|
|
536
587
|
# in the tensor proto can be anything and we would like deserialization from
|
|
537
588
|
# proto to IR to not fail.
|
|
@@ -547,12 +598,13 @@ class ExternalTensor(TensorBase, _protocols.TensorProtocol): # pylint: disable=
|
|
|
547
598
|
self._dtype: _enums.DataType = dtype
|
|
548
599
|
self.name: str = name # mutable
|
|
549
600
|
self._shape: Shape = shape
|
|
550
|
-
self._shape.
|
|
601
|
+
self._shape.freeze()
|
|
551
602
|
self.doc_string: str | None = doc_string # mutable
|
|
552
603
|
self._array: np.ndarray | None = None
|
|
553
604
|
self.raw: mmap.mmap | None = None
|
|
554
605
|
self._metadata_props = metadata_props
|
|
555
606
|
self._metadata: _metadata.MetadataStore | None = None
|
|
607
|
+
self._valid = True
|
|
556
608
|
|
|
557
609
|
@property
|
|
558
610
|
def base_dir(self) -> str | os.PathLike:
|
|
@@ -594,6 +646,7 @@ class ExternalTensor(TensorBase, _protocols.TensorProtocol): # pylint: disable=
|
|
|
594
646
|
return self._shape
|
|
595
647
|
|
|
596
648
|
def _load(self):
|
|
649
|
+
self._check_validity()
|
|
597
650
|
assert self._array is None, "Bug: The array should be loaded only once."
|
|
598
651
|
if self.size == 0:
|
|
599
652
|
# When the size is 0, mmap is impossible and meaningless
|
|
@@ -609,7 +662,11 @@ class ExternalTensor(TensorBase, _protocols.TensorProtocol): # pylint: disable=
|
|
|
609
662
|
)
|
|
610
663
|
# Handle the byte order correctly by always using little endian
|
|
611
664
|
dt = np.dtype(self.dtype.numpy()).newbyteorder("<")
|
|
612
|
-
if self.dtype in {
|
|
665
|
+
if self.dtype in {
|
|
666
|
+
_enums.DataType.INT4,
|
|
667
|
+
_enums.DataType.UINT4,
|
|
668
|
+
_enums.DataType.FLOAT4E2M1,
|
|
669
|
+
}:
|
|
613
670
|
# Use uint8 to read in the full byte. Otherwise ml_dtypes.int4 will clip the values
|
|
614
671
|
dt = np.dtype(np.uint8).newbyteorder("<")
|
|
615
672
|
count = self.size // 2 + self.size % 2
|
|
@@ -622,10 +679,13 @@ class ExternalTensor(TensorBase, _protocols.TensorProtocol): # pylint: disable=
|
|
|
622
679
|
self._array = _type_casting.unpack_int4(self._array, shape)
|
|
623
680
|
elif self.dtype == _enums.DataType.UINT4:
|
|
624
681
|
self._array = _type_casting.unpack_uint4(self._array, shape)
|
|
682
|
+
elif self.dtype == _enums.DataType.FLOAT4E2M1:
|
|
683
|
+
self._array = _type_casting.unpack_float4e2m1(self._array, shape)
|
|
625
684
|
else:
|
|
626
685
|
self._array = self._array.reshape(shape)
|
|
627
686
|
|
|
628
687
|
def __array__(self, dtype: Any = None) -> np.ndarray:
|
|
688
|
+
self._check_validity()
|
|
629
689
|
if self._array is None:
|
|
630
690
|
self._load()
|
|
631
691
|
assert self._array is not None
|
|
@@ -654,6 +714,7 @@ class ExternalTensor(TensorBase, _protocols.TensorProtocol): # pylint: disable=
|
|
|
654
714
|
|
|
655
715
|
The data will be memory mapped into memory and will not taken up physical memory space.
|
|
656
716
|
"""
|
|
717
|
+
self._check_validity()
|
|
657
718
|
if self._array is None:
|
|
658
719
|
self._load()
|
|
659
720
|
assert self._array is not None
|
|
@@ -664,6 +725,7 @@ class ExternalTensor(TensorBase, _protocols.TensorProtocol): # pylint: disable=
|
|
|
664
725
|
|
|
665
726
|
This will load the tensor into memory.
|
|
666
727
|
"""
|
|
728
|
+
self._check_validity()
|
|
667
729
|
if self.raw is None:
|
|
668
730
|
self._load()
|
|
669
731
|
assert self.raw is not None
|
|
@@ -671,6 +733,26 @@ class ExternalTensor(TensorBase, _protocols.TensorProtocol): # pylint: disable=
|
|
|
671
733
|
length = self._length or self.nbytes
|
|
672
734
|
return self.raw[offset : offset + length]
|
|
673
735
|
|
|
736
|
+
def valid(self) -> bool:
|
|
737
|
+
"""Check if the tensor is valid.
|
|
738
|
+
|
|
739
|
+
The external tensor is valid if it has not been invalidated.
|
|
740
|
+
"""
|
|
741
|
+
return self._valid
|
|
742
|
+
|
|
743
|
+
def _check_validity(self) -> None:
|
|
744
|
+
if not self.valid():
|
|
745
|
+
raise ValueError(
|
|
746
|
+
f"The external tensor '{self!r}' is invalidated. The data may be corrupted or deleted."
|
|
747
|
+
)
|
|
748
|
+
|
|
749
|
+
def invalidate(self) -> None:
|
|
750
|
+
"""Invalidate the tensor.
|
|
751
|
+
|
|
752
|
+
The external tensor is invalidated when the data is known to be corrupted or deleted.
|
|
753
|
+
"""
|
|
754
|
+
self._valid = False
|
|
755
|
+
|
|
674
756
|
def release(self) -> None:
|
|
675
757
|
"""Delete all references to the memory buffer and close the memory-mapped file."""
|
|
676
758
|
self._array = None
|
|
@@ -678,34 +760,13 @@ class ExternalTensor(TensorBase, _protocols.TensorProtocol): # pylint: disable=
|
|
|
678
760
|
self.raw.close()
|
|
679
761
|
self.raw = None
|
|
680
762
|
|
|
681
|
-
@property
|
|
682
|
-
def metadata_props(self) -> dict[str, str]:
|
|
683
|
-
if self._metadata_props is None:
|
|
684
|
-
self._metadata_props = {}
|
|
685
|
-
return self._metadata_props
|
|
686
|
-
|
|
687
|
-
@property
|
|
688
|
-
def meta(self) -> _metadata.MetadataStore:
|
|
689
|
-
"""The metadata store for intermediate analysis.
|
|
690
|
-
|
|
691
|
-
Write to the :attr:`metadata_props` if you would like the metadata to be serialized
|
|
692
|
-
to the ONNX proto.
|
|
693
|
-
"""
|
|
694
|
-
if self._metadata is None:
|
|
695
|
-
self._metadata = _metadata.MetadataStore()
|
|
696
|
-
return self._metadata
|
|
697
|
-
|
|
698
763
|
|
|
699
764
|
class StringTensor(TensorBase, _protocols.TensorProtocol): # pylint: disable=too-many-ancestors
|
|
700
765
|
"""Multidimensional array of strings (as binary data to match the string_data field in TensorProto)."""
|
|
701
766
|
|
|
702
767
|
__slots__ = (
|
|
703
|
-
"_metadata",
|
|
704
|
-
"_metadata_props",
|
|
705
768
|
"_raw",
|
|
706
769
|
"_shape",
|
|
707
|
-
"doc_string",
|
|
708
|
-
"name",
|
|
709
770
|
)
|
|
710
771
|
|
|
711
772
|
def __init__(
|
|
@@ -726,6 +787,7 @@ class StringTensor(TensorBase, _protocols.TensorProtocol): # pylint: disable=to
|
|
|
726
787
|
doc_string: The documentation string.
|
|
727
788
|
metadata_props: The metadata properties.
|
|
728
789
|
"""
|
|
790
|
+
super().__init__(name=name, doc_string=doc_string, metadata_props=metadata_props)
|
|
729
791
|
if shape is None:
|
|
730
792
|
if not hasattr(value, "shape"):
|
|
731
793
|
raise ValueError(
|
|
@@ -735,19 +797,15 @@ class StringTensor(TensorBase, _protocols.TensorProtocol): # pylint: disable=to
|
|
|
735
797
|
self._shape = Shape(getattr(value, "shape"), frozen=True) # noqa: B009
|
|
736
798
|
else:
|
|
737
799
|
self._shape = shape
|
|
738
|
-
self._shape.
|
|
800
|
+
self._shape.freeze()
|
|
739
801
|
self._raw = value
|
|
740
|
-
self.name = name
|
|
741
|
-
self.doc_string = doc_string
|
|
742
|
-
self._metadata: _metadata.MetadataStore | None = None
|
|
743
|
-
self._metadata_props = metadata_props
|
|
744
802
|
|
|
745
803
|
def __array__(self, dtype: Any = None) -> np.ndarray:
|
|
746
804
|
if isinstance(self._raw, np.ndarray):
|
|
747
805
|
return self._raw
|
|
748
|
-
assert isinstance(
|
|
749
|
-
self._raw
|
|
750
|
-
)
|
|
806
|
+
assert isinstance(self._raw, Sequence), (
|
|
807
|
+
f"Bug: Expected a sequence, got {type(self._raw)}"
|
|
808
|
+
)
|
|
751
809
|
return np.array(self._raw, dtype=dtype).reshape(self.shape.numpy())
|
|
752
810
|
|
|
753
811
|
def __dlpack__(self, *, stream: Any = None) -> Any:
|
|
@@ -788,25 +846,125 @@ class StringTensor(TensorBase, _protocols.TensorProtocol): # pylint: disable=to
|
|
|
788
846
|
return self._raw.flatten().tolist()
|
|
789
847
|
return self._raw
|
|
790
848
|
|
|
849
|
+
|
|
850
|
+
class LazyTensor(TensorBase, _protocols.TensorProtocol): # pylint: disable=too-many-ancestors
|
|
851
|
+
"""A tensor that lazily evaluates a function to get the actual tensor.
|
|
852
|
+
|
|
853
|
+
This class takes a function returning an `ir.TensorProtocol`, a dtype, and a shape argument.
|
|
854
|
+
The function is lazily evaluated to get the actual tensor when `tobytes()` or `numpy()` is called.
|
|
855
|
+
|
|
856
|
+
Example::
|
|
857
|
+
|
|
858
|
+
>>> import numpy as np
|
|
859
|
+
>>> import onnx_ir as ir
|
|
860
|
+
>>> weights = np.array([[1, 2, 3]])
|
|
861
|
+
>>> def create_tensor(): # Delay applying transformations to the weights
|
|
862
|
+
... weights_t = weights.transpose()
|
|
863
|
+
... return ir.tensor(weights_t)
|
|
864
|
+
>>> lazy_tensor = ir.LazyTensor(create_tensor, dtype=ir.DataType.INT64, shape=ir.Shape([1, 3]))
|
|
865
|
+
>>> print(lazy_tensor.numpy())
|
|
866
|
+
[[1]
|
|
867
|
+
[2]
|
|
868
|
+
[3]]
|
|
869
|
+
|
|
870
|
+
Attributes:
|
|
871
|
+
func: The function that returns the actual tensor.
|
|
872
|
+
dtype: The data type of the tensor.
|
|
873
|
+
shape: The shape of the tensor.
|
|
874
|
+
cache: Whether to cache the result of the function. If False,
|
|
875
|
+
the function is called every time the tensor content is accessed.
|
|
876
|
+
If True, the function is called only once and the result is cached in memory.
|
|
877
|
+
Default is False.
|
|
878
|
+
name: The name of the tensor.
|
|
879
|
+
doc_string: The documentation string.
|
|
880
|
+
metadata_props: The metadata properties.
|
|
881
|
+
"""
|
|
882
|
+
|
|
883
|
+
__slots__ = (
|
|
884
|
+
"_dtype",
|
|
885
|
+
"_func",
|
|
886
|
+
"_shape",
|
|
887
|
+
"_tensor",
|
|
888
|
+
"cache",
|
|
889
|
+
)
|
|
890
|
+
|
|
891
|
+
def __init__(
|
|
892
|
+
self,
|
|
893
|
+
func: Callable[[], _protocols.TensorProtocol],
|
|
894
|
+
dtype: _enums.DataType,
|
|
895
|
+
shape: Shape,
|
|
896
|
+
*,
|
|
897
|
+
cache: bool = False,
|
|
898
|
+
name: str | None = None,
|
|
899
|
+
doc_string: str | None = None,
|
|
900
|
+
metadata_props: dict[str, str] | None = None,
|
|
901
|
+
) -> None:
|
|
902
|
+
"""Initialize a lazy tensor.
|
|
903
|
+
|
|
904
|
+
Args:
|
|
905
|
+
func: The function that returns the actual tensor.
|
|
906
|
+
dtype: The data type of the tensor.
|
|
907
|
+
shape: The shape of the tensor.
|
|
908
|
+
cache: Whether to cache the result of the function.
|
|
909
|
+
name: The name of the tensor.
|
|
910
|
+
doc_string: The documentation string.
|
|
911
|
+
metadata_props: The metadata properties.
|
|
912
|
+
"""
|
|
913
|
+
super().__init__(name=name, doc_string=doc_string, metadata_props=metadata_props)
|
|
914
|
+
self._func = func
|
|
915
|
+
self._dtype = dtype
|
|
916
|
+
self._shape = shape
|
|
917
|
+
self._tensor: _protocols.TensorProtocol | None = None
|
|
918
|
+
self.cache = cache
|
|
919
|
+
|
|
920
|
+
def _evaluate(self) -> _protocols.TensorProtocol:
|
|
921
|
+
"""Evaluate the function to get the actual tensor."""
|
|
922
|
+
if not self.cache:
|
|
923
|
+
return self._func()
|
|
924
|
+
|
|
925
|
+
# Cache the tensor
|
|
926
|
+
if self._tensor is None:
|
|
927
|
+
self._tensor = self._func()
|
|
928
|
+
return self._tensor
|
|
929
|
+
|
|
930
|
+
def __array__(self, dtype: Any = None) -> np.ndarray:
|
|
931
|
+
return self._evaluate().__array__(dtype)
|
|
932
|
+
|
|
933
|
+
def __dlpack__(self, *, stream: Any = None) -> Any:
|
|
934
|
+
return self._evaluate().__dlpack__(stream=stream)
|
|
935
|
+
|
|
936
|
+
def __dlpack_device__(self) -> tuple[int, int]:
|
|
937
|
+
return self._evaluate().__dlpack_device__()
|
|
938
|
+
|
|
939
|
+
def __repr__(self) -> str:
|
|
940
|
+
return f"{self._repr_base()}(func={self._func!r}, name={self.name!r})"
|
|
941
|
+
|
|
791
942
|
@property
|
|
792
|
-
def
|
|
793
|
-
|
|
794
|
-
self._metadata_props = {}
|
|
795
|
-
return self._metadata_props
|
|
943
|
+
def raw(self) -> Callable[[], _protocols.TensorProtocol]:
|
|
944
|
+
return self._func
|
|
796
945
|
|
|
797
946
|
@property
|
|
798
|
-
def
|
|
799
|
-
"""The
|
|
947
|
+
def dtype(self) -> _enums.DataType:
|
|
948
|
+
"""The data type of the tensor. Immutable."""
|
|
949
|
+
return self._dtype
|
|
800
950
|
|
|
801
|
-
|
|
802
|
-
|
|
803
|
-
"""
|
|
804
|
-
|
|
805
|
-
|
|
806
|
-
|
|
951
|
+
@property
|
|
952
|
+
def shape(self) -> Shape:
|
|
953
|
+
"""The shape of the tensor. Immutable."""
|
|
954
|
+
return self._shape
|
|
955
|
+
|
|
956
|
+
def numpy(self) -> np.ndarray:
|
|
957
|
+
"""Return the tensor as a numpy array."""
|
|
958
|
+
return self._evaluate().numpy()
|
|
959
|
+
|
|
960
|
+
def tobytes(self) -> bytes:
|
|
961
|
+
"""Return the bytes of the tensor."""
|
|
962
|
+
return self._evaluate().tobytes()
|
|
807
963
|
|
|
808
964
|
|
|
809
965
|
class SymbolicDim(_protocols.SymbolicDimProtocol, _display.PrettyPrintable):
|
|
966
|
+
"""Immutable symbolic dimension that can be shared across multiple shapes."""
|
|
967
|
+
|
|
810
968
|
__slots__ = ("_value",)
|
|
811
969
|
|
|
812
970
|
def __init__(self, value: str | None) -> None:
|
|
@@ -841,12 +999,84 @@ class SymbolicDim(_protocols.SymbolicDimProtocol, _display.PrettyPrintable):
|
|
|
841
999
|
return f"{self.__class__.__name__}({self._value})"
|
|
842
1000
|
|
|
843
1001
|
|
|
1002
|
+
def _is_int_compatible(value: object) -> TypeIs[SupportsInt]:
|
|
1003
|
+
"""Return True if the value is int compatible."""
|
|
1004
|
+
if isinstance(value, int):
|
|
1005
|
+
return True
|
|
1006
|
+
if hasattr(value, "__int__"):
|
|
1007
|
+
# For performance reasons, we do not use isinstance(value, SupportsInt)
|
|
1008
|
+
return True
|
|
1009
|
+
return False
|
|
1010
|
+
|
|
1011
|
+
|
|
1012
|
+
def _maybe_convert_to_symbolic_dim(
|
|
1013
|
+
dim: int | SupportsInt | SymbolicDim | str | None,
|
|
1014
|
+
) -> SymbolicDim | int:
|
|
1015
|
+
"""Convert the value to a SymbolicDim if it is not an int."""
|
|
1016
|
+
if dim is None or isinstance(dim, str):
|
|
1017
|
+
return SymbolicDim(dim)
|
|
1018
|
+
if _is_int_compatible(dim):
|
|
1019
|
+
return int(dim)
|
|
1020
|
+
if isinstance(dim, SymbolicDim):
|
|
1021
|
+
return dim
|
|
1022
|
+
raise TypeError(
|
|
1023
|
+
f"Expected int, str, None or SymbolicDim, but value {dim!r} has type '{type(dim)}'"
|
|
1024
|
+
)
|
|
1025
|
+
|
|
1026
|
+
|
|
844
1027
|
class Shape(_protocols.ShapeProtocol, _display.PrettyPrintable):
|
|
1028
|
+
"""The shape of a tensor, including its dimensions and optional denotations.
|
|
1029
|
+
|
|
1030
|
+
The :class:`Shape` stores the dimensions of a tensor, which can be integers, None (unknown), or
|
|
1031
|
+
symbolic dimensions.
|
|
1032
|
+
|
|
1033
|
+
A shape can be compared to another shape or plain Python list.
|
|
1034
|
+
|
|
1035
|
+
A shape can be frozen (made immutable). When the shape is frozen, it cannot be
|
|
1036
|
+
unfrozen, making it suitable to be shared across tensors or values.
|
|
1037
|
+
Call :method:`freeze` to freeze the shape.
|
|
1038
|
+
|
|
1039
|
+
To update the dimension of a frozen shape, call :method:`copy` to create a
|
|
1040
|
+
new shape with the same dimensions that can be modified.
|
|
1041
|
+
|
|
1042
|
+
Use :method:`get_denotation` and :method:`set_denotation` to access and modify the denotations.
|
|
1043
|
+
|
|
1044
|
+
Example::
|
|
1045
|
+
|
|
1046
|
+
>>> import onnx_ir as ir
|
|
1047
|
+
>>> shape = ir.Shape(["B", None, 3])
|
|
1048
|
+
>>> shape.rank()
|
|
1049
|
+
3
|
|
1050
|
+
>>> shape.is_static()
|
|
1051
|
+
False
|
|
1052
|
+
>>> shape.is_dynamic()
|
|
1053
|
+
True
|
|
1054
|
+
>>> shape.is_static(dim=2)
|
|
1055
|
+
True
|
|
1056
|
+
>>> shape[0] = 1
|
|
1057
|
+
>>> shape[1] = 2
|
|
1058
|
+
>>> shape.dims
|
|
1059
|
+
(1, 2, 3)
|
|
1060
|
+
>>> shape == [1, 2, 3]
|
|
1061
|
+
True
|
|
1062
|
+
>>> shape.frozen
|
|
1063
|
+
False
|
|
1064
|
+
>>> shape.freeze()
|
|
1065
|
+
>>> shape.frozen
|
|
1066
|
+
True
|
|
1067
|
+
|
|
1068
|
+
Attributes:
|
|
1069
|
+
dims: A tuple of dimensions representing the shape.
|
|
1070
|
+
Each dimension can be an integer, None or a :class:`SymbolicDim`.
|
|
1071
|
+
frozen: Indicates whether the shape is immutable. When frozen, the shape
|
|
1072
|
+
cannot be modified or unfrozen.
|
|
1073
|
+
"""
|
|
1074
|
+
|
|
845
1075
|
__slots__ = ("_dims", "_frozen")
|
|
846
1076
|
|
|
847
1077
|
def __init__(
|
|
848
1078
|
self,
|
|
849
|
-
dims: Iterable[int | SymbolicDim | str | None],
|
|
1079
|
+
dims: Iterable[int | SupportsInt | SymbolicDim | str | None],
|
|
850
1080
|
/,
|
|
851
1081
|
denotations: Iterable[str | None] | None = None,
|
|
852
1082
|
frozen: bool = False,
|
|
@@ -864,11 +1094,11 @@ class Shape(_protocols.ShapeProtocol, _display.PrettyPrintable):
|
|
|
864
1094
|
Refer to https://github.com/onnx/onnx/blob/main/docs/DimensionDenotation.md#denotation-definition
|
|
865
1095
|
for pre-defined dimension denotations.
|
|
866
1096
|
frozen: If True, the shape is immutable and cannot be modified. This
|
|
867
|
-
is useful when the shape is initialized by a Tensor
|
|
1097
|
+
is useful when the shape is initialized by a Tensor or when the shape
|
|
1098
|
+
is shared across multiple tensors. The default is False.
|
|
868
1099
|
"""
|
|
869
1100
|
self._dims: list[int | SymbolicDim] = [
|
|
870
|
-
|
|
871
|
-
for dim in dims
|
|
1101
|
+
_maybe_convert_to_symbolic_dim(dim) for dim in dims
|
|
872
1102
|
]
|
|
873
1103
|
self._denotations: list[str | None] = (
|
|
874
1104
|
list(denotations) if denotations is not None else [None] * len(self._dims)
|
|
@@ -879,10 +1109,6 @@ class Shape(_protocols.ShapeProtocol, _display.PrettyPrintable):
|
|
|
879
1109
|
)
|
|
880
1110
|
self._frozen: bool = frozen
|
|
881
1111
|
|
|
882
|
-
def copy(self):
|
|
883
|
-
"""Return a copy of the shape."""
|
|
884
|
-
return Shape(self._dims, self._denotations, self._frozen)
|
|
885
|
-
|
|
886
1112
|
@property
|
|
887
1113
|
def dims(self) -> tuple[int | SymbolicDim, ...]:
|
|
888
1114
|
"""All dimensions in the shape.
|
|
@@ -891,8 +1117,29 @@ class Shape(_protocols.ShapeProtocol, _display.PrettyPrintable):
|
|
|
891
1117
|
"""
|
|
892
1118
|
return tuple(self._dims)
|
|
893
1119
|
|
|
1120
|
+
@property
|
|
1121
|
+
def frozen(self) -> bool:
|
|
1122
|
+
"""Whether the shape is frozen.
|
|
1123
|
+
|
|
1124
|
+
When the shape is frozen, it cannot be unfrozen, making it suitable to be shared.
|
|
1125
|
+
Call :method:`freeze` to freeze the shape. Call :method:`copy` to create a
|
|
1126
|
+
new shape with the same dimensions that can be modified.
|
|
1127
|
+
"""
|
|
1128
|
+
return self._frozen
|
|
1129
|
+
|
|
1130
|
+
def freeze(self) -> None:
|
|
1131
|
+
"""Freeze the shape.
|
|
1132
|
+
|
|
1133
|
+
When the shape is frozen, it cannot be unfrozen, making it suitable to be shared.
|
|
1134
|
+
"""
|
|
1135
|
+
self._frozen = True
|
|
1136
|
+
|
|
1137
|
+
def copy(self, frozen: bool = False):
|
|
1138
|
+
"""Return a copy of the shape."""
|
|
1139
|
+
return Shape(self._dims, self._denotations, frozen=frozen)
|
|
1140
|
+
|
|
894
1141
|
def rank(self) -> int:
|
|
895
|
-
"""The rank of the shape."""
|
|
1142
|
+
"""The rank of the tensor this shape represents."""
|
|
896
1143
|
return len(self._dims)
|
|
897
1144
|
|
|
898
1145
|
def numpy(self) -> tuple[int, ...]:
|
|
@@ -928,12 +1175,8 @@ class Shape(_protocols.ShapeProtocol, _display.PrettyPrintable):
|
|
|
928
1175
|
"""
|
|
929
1176
|
if self._frozen:
|
|
930
1177
|
raise TypeError("The shape is frozen and cannot be modified.")
|
|
931
|
-
if isinstance(value, str) or value is None:
|
|
932
|
-
value = SymbolicDim(value)
|
|
933
|
-
if not isinstance(value, (int, SymbolicDim)):
|
|
934
|
-
raise TypeError(f"Expected int, str, None or SymbolicDim, got '{type(value)}'")
|
|
935
1178
|
|
|
936
|
-
self._dims[index] = value
|
|
1179
|
+
self._dims[index] = _maybe_convert_to_symbolic_dim(value)
|
|
937
1180
|
|
|
938
1181
|
def get_denotation(self, index: int) -> str | None:
|
|
939
1182
|
"""Return the denotation of the dimension at the index.
|
|
@@ -968,7 +1211,7 @@ class Shape(_protocols.ShapeProtocol, _display.PrettyPrintable):
|
|
|
968
1211
|
def __eq__(self, other: object) -> bool:
|
|
969
1212
|
"""Return True if the shapes are equal.
|
|
970
1213
|
|
|
971
|
-
Two shapes are
|
|
1214
|
+
Two shapes are equal if all their dimensions are equal.
|
|
972
1215
|
"""
|
|
973
1216
|
if isinstance(other, Shape):
|
|
974
1217
|
return self._dims == other._dims
|
|
@@ -979,6 +1222,33 @@ class Shape(_protocols.ShapeProtocol, _display.PrettyPrintable):
|
|
|
979
1222
|
def __ne__(self, other: object) -> bool:
|
|
980
1223
|
return not self.__eq__(other)
|
|
981
1224
|
|
|
1225
|
+
@typing.overload
|
|
1226
|
+
def is_static(self, dim: int) -> bool: # noqa: D418
|
|
1227
|
+
"""Return True if the dimension is static."""
|
|
1228
|
+
|
|
1229
|
+
@typing.overload
|
|
1230
|
+
def is_static(self) -> bool: # noqa: D418
|
|
1231
|
+
"""Return True if all dimensions are static."""
|
|
1232
|
+
|
|
1233
|
+
def is_static(self, dim=None) -> bool:
|
|
1234
|
+
"""Return True if the dimension is static. If dim is None, return True if all dimensions are static."""
|
|
1235
|
+
if dim is None:
|
|
1236
|
+
return all(isinstance(dim, int) for dim in self._dims)
|
|
1237
|
+
return isinstance(self[dim], int)
|
|
1238
|
+
|
|
1239
|
+
@typing.overload
|
|
1240
|
+
def is_dynamic(self, dim: int) -> bool: # noqa: D418
|
|
1241
|
+
"""Return True if the dimension is dynamic."""
|
|
1242
|
+
|
|
1243
|
+
@typing.overload
|
|
1244
|
+
def is_dynamic(self) -> bool: # noqa: D418
|
|
1245
|
+
"""Return True if any dimension is dynamic."""
|
|
1246
|
+
|
|
1247
|
+
def is_dynamic(self, dim=None) -> bool:
|
|
1248
|
+
if dim is None:
|
|
1249
|
+
return not self.is_static()
|
|
1250
|
+
return not self.is_static(dim)
|
|
1251
|
+
|
|
982
1252
|
|
|
983
1253
|
def _quoted(string: str) -> str:
|
|
984
1254
|
"""Return a quoted string.
|
|
@@ -988,6 +1258,35 @@ def _quoted(string: str) -> str:
|
|
|
988
1258
|
return f'"{string}"'
|
|
989
1259
|
|
|
990
1260
|
|
|
1261
|
+
class Usage(NamedTuple):
|
|
1262
|
+
"""A usage of a value in a node.
|
|
1263
|
+
|
|
1264
|
+
Attributes:
|
|
1265
|
+
node: The node that uses the value.
|
|
1266
|
+
idx: The input index of the value in the node.
|
|
1267
|
+
"""
|
|
1268
|
+
|
|
1269
|
+
node: Node
|
|
1270
|
+
idx: int
|
|
1271
|
+
|
|
1272
|
+
|
|
1273
|
+
def _short_tensor_str_for_node(x: Value) -> str:
|
|
1274
|
+
if x.const_value is None:
|
|
1275
|
+
return ""
|
|
1276
|
+
if x.const_value.size <= 10:
|
|
1277
|
+
try:
|
|
1278
|
+
data = x.const_value.numpy().tolist()
|
|
1279
|
+
except Exception: # pylint: disable=broad-except
|
|
1280
|
+
return "{...}"
|
|
1281
|
+
return f"{{{data}}}"
|
|
1282
|
+
return "{...}"
|
|
1283
|
+
|
|
1284
|
+
|
|
1285
|
+
def _normalize_domain(domain: str) -> str:
|
|
1286
|
+
"""Normalize 'ai.onnx' to ''."""
|
|
1287
|
+
return "" if domain == "ai.onnx" else domain
|
|
1288
|
+
|
|
1289
|
+
|
|
991
1290
|
class Node(_protocols.NodeProtocol, _display.PrettyPrintable):
|
|
992
1291
|
"""IR Node.
|
|
993
1292
|
|
|
@@ -1001,6 +1300,9 @@ class Node(_protocols.NodeProtocol, _display.PrettyPrintable):
|
|
|
1001
1300
|
To change the output values, create a new node and replace the each of the inputs of ``output.uses()`` with
|
|
1002
1301
|
the new output values by calling :meth:`replace_input_with` on the using nodes
|
|
1003
1302
|
of this node's outputs.
|
|
1303
|
+
|
|
1304
|
+
.. note:
|
|
1305
|
+
When the ``domain`` is `"ai.onnx"`, it is normalized to `""`.
|
|
1004
1306
|
"""
|
|
1005
1307
|
|
|
1006
1308
|
__slots__ = (
|
|
@@ -1023,13 +1325,13 @@ class Node(_protocols.NodeProtocol, _display.PrettyPrintable):
|
|
|
1023
1325
|
domain: str,
|
|
1024
1326
|
op_type: str,
|
|
1025
1327
|
inputs: Iterable[Value | None],
|
|
1026
|
-
attributes: Iterable[Attr
|
|
1328
|
+
attributes: Iterable[Attr] = (),
|
|
1027
1329
|
*,
|
|
1028
1330
|
overload: str = "",
|
|
1029
1331
|
num_outputs: int | None = None,
|
|
1030
1332
|
outputs: Sequence[Value] | None = None,
|
|
1031
1333
|
version: int | None = None,
|
|
1032
|
-
graph: Graph | None = None,
|
|
1334
|
+
graph: Graph | Function | None = None,
|
|
1033
1335
|
name: str | None = None,
|
|
1034
1336
|
doc_string: str | None = None,
|
|
1035
1337
|
metadata_props: dict[str, str] | None = None,
|
|
@@ -1038,27 +1340,30 @@ class Node(_protocols.NodeProtocol, _display.PrettyPrintable):
|
|
|
1038
1340
|
|
|
1039
1341
|
Args:
|
|
1040
1342
|
domain: The domain of the operator. For onnx operators, this is an empty string.
|
|
1343
|
+
When it is `"ai.onnx"`, it is normalized to `""`.
|
|
1041
1344
|
op_type: The name of the operator.
|
|
1042
|
-
inputs: The input values. When an input is None
|
|
1345
|
+
inputs: The input values. When an input is ``None``, it is an empty input.
|
|
1043
1346
|
attributes: The attributes. RefAttr can be used only when the node is defined in a Function.
|
|
1044
1347
|
overload: The overload name when the node is invoking a function.
|
|
1045
1348
|
num_outputs: The number of outputs of the node. If not specified, the number is 1.
|
|
1046
|
-
outputs: The output values. If None
|
|
1047
|
-
version: The version of the operator. If None
|
|
1048
|
-
graph: The graph that the node belongs to. If None
|
|
1049
|
-
A `Node` must belong to zero or one graph.
|
|
1050
|
-
|
|
1349
|
+
outputs: The output values. If ``None``, the outputs are created during initialization.
|
|
1350
|
+
version: The version of the operator. If ``None``, the version is unspecified and will follow that of the graph.
|
|
1351
|
+
graph: The graph that the node belongs to. If ``None``, the node is not added to any graph.
|
|
1352
|
+
A `Node` must belong to zero or one graph. If a :class:`Function`, the underlying graph
|
|
1353
|
+
of the function is assigned to the node.
|
|
1354
|
+
name: The name of the node. If ``None``, the node is anonymous. The name may be
|
|
1355
|
+
set by a :class:`Graph` if ``graph`` is specified.
|
|
1051
1356
|
doc_string: The documentation string.
|
|
1052
1357
|
metadata_props: The metadata properties.
|
|
1053
1358
|
|
|
1054
1359
|
Raises:
|
|
1055
|
-
TypeError: If the attributes are not Attr
|
|
1056
|
-
ValueError: If
|
|
1057
|
-
ValueError: If an output value is None
|
|
1360
|
+
TypeError: If the attributes are not :class:`Attr`.
|
|
1361
|
+
ValueError: If ``num_outputs``, when not ``None``, is not the same as the length of the outputs.
|
|
1362
|
+
ValueError: If an output value is ``None``, when outputs is specified.
|
|
1058
1363
|
ValueError: If an output value has a producer set already, when outputs is specified.
|
|
1059
1364
|
"""
|
|
1060
1365
|
self._name = name
|
|
1061
|
-
self._domain: str = domain
|
|
1366
|
+
self._domain: str = _normalize_domain(domain)
|
|
1062
1367
|
self._op_type: str = op_type
|
|
1063
1368
|
# NOTE: Make inputs immutable with the assumption that they are not mutated
|
|
1064
1369
|
# very often. This way all mutations can be tracked.
|
|
@@ -1067,13 +1372,13 @@ class Node(_protocols.NodeProtocol, _display.PrettyPrintable):
|
|
|
1067
1372
|
# Values belong to their defining nodes. The values list is immutable
|
|
1068
1373
|
self._outputs: tuple[Value, ...] = self._create_outputs(num_outputs, outputs)
|
|
1069
1374
|
attributes = tuple(attributes)
|
|
1070
|
-
if attributes and not isinstance(attributes[0],
|
|
1375
|
+
if attributes and not isinstance(attributes[0], Attr):
|
|
1071
1376
|
raise TypeError(
|
|
1072
|
-
f"Expected the attributes to be Attr
|
|
1377
|
+
f"Expected the attributes to be Attr, got {type(attributes[0])}. "
|
|
1073
1378
|
"If you are copying the attributes from another node, make sure you call "
|
|
1074
1379
|
"node.attributes.values() because it is a dictionary."
|
|
1075
1380
|
)
|
|
1076
|
-
self._attributes: OrderedDict[str, Attr
|
|
1381
|
+
self._attributes: OrderedDict[str, Attr] = OrderedDict(
|
|
1077
1382
|
(attr.name, attr) for attr in attributes
|
|
1078
1383
|
)
|
|
1079
1384
|
self._overload: str = overload
|
|
@@ -1081,7 +1386,11 @@ class Node(_protocols.NodeProtocol, _display.PrettyPrintable):
|
|
|
1081
1386
|
self._version: int | None = version
|
|
1082
1387
|
self._metadata: _metadata.MetadataStore | None = None
|
|
1083
1388
|
self._metadata_props: dict[str, str] | None = metadata_props
|
|
1084
|
-
|
|
1389
|
+
# _graph is set by graph.append
|
|
1390
|
+
self._graph: Graph | None = None
|
|
1391
|
+
# Add the node to the graph if graph is specified
|
|
1392
|
+
if graph is not None:
|
|
1393
|
+
graph.append(self)
|
|
1085
1394
|
self.doc_string = doc_string
|
|
1086
1395
|
|
|
1087
1396
|
# Add the node as a use of the inputs
|
|
@@ -1089,10 +1398,6 @@ class Node(_protocols.NodeProtocol, _display.PrettyPrintable):
|
|
|
1089
1398
|
if input_value is not None:
|
|
1090
1399
|
input_value._add_usage(self, i) # pylint: disable=protected-access
|
|
1091
1400
|
|
|
1092
|
-
# Add the node to the graph if graph is specified
|
|
1093
|
-
if self._graph is not None:
|
|
1094
|
-
self._graph.append(self)
|
|
1095
|
-
|
|
1096
1401
|
def _create_outputs(
|
|
1097
1402
|
self, num_outputs: int | None, outputs: Sequence[Value] | None
|
|
1098
1403
|
) -> tuple[Value, ...]:
|
|
@@ -1150,7 +1455,7 @@ class Node(_protocols.NodeProtocol, _display.PrettyPrintable):
|
|
|
1150
1455
|
+ ", ".join(
|
|
1151
1456
|
[
|
|
1152
1457
|
(
|
|
1153
|
-
f"%{_quoted(x.name) if x.name else 'anonymous:' + str(id(x))}"
|
|
1458
|
+
f"%{_quoted(x.name) if x.name else 'anonymous:' + str(id(x))}{_short_tensor_str_for_node(x)}"
|
|
1154
1459
|
if x is not None
|
|
1155
1460
|
else "None"
|
|
1156
1461
|
)
|
|
@@ -1178,6 +1483,7 @@ class Node(_protocols.NodeProtocol, _display.PrettyPrintable):
|
|
|
1178
1483
|
|
|
1179
1484
|
@property
|
|
1180
1485
|
def name(self) -> str | None:
|
|
1486
|
+
"""Optional name of the node."""
|
|
1181
1487
|
return self._name
|
|
1182
1488
|
|
|
1183
1489
|
@name.setter
|
|
@@ -1186,14 +1492,26 @@ class Node(_protocols.NodeProtocol, _display.PrettyPrintable):
|
|
|
1186
1492
|
|
|
1187
1493
|
@property
|
|
1188
1494
|
def domain(self) -> str:
|
|
1495
|
+
"""The domain of the operator. For onnx operators, this is an empty string.
|
|
1496
|
+
|
|
1497
|
+
.. note:
|
|
1498
|
+
When domain is `"ai.onnx"`, it is normalized to `""`.
|
|
1499
|
+
"""
|
|
1189
1500
|
return self._domain
|
|
1190
1501
|
|
|
1191
1502
|
@domain.setter
|
|
1192
1503
|
def domain(self, value: str) -> None:
|
|
1193
|
-
self._domain = value
|
|
1504
|
+
self._domain = _normalize_domain(value)
|
|
1194
1505
|
|
|
1195
1506
|
@property
|
|
1196
1507
|
def version(self) -> int | None:
|
|
1508
|
+
"""Opset version of the operator called.
|
|
1509
|
+
|
|
1510
|
+
If ``None``, the version is unspecified and will follow that of the graph.
|
|
1511
|
+
This property is special to ONNX IR to allow mixed opset usage in a graph
|
|
1512
|
+
for supporting more flexible graph transformations. It does not exist in the ONNX
|
|
1513
|
+
serialization (protobuf) spec.
|
|
1514
|
+
"""
|
|
1197
1515
|
return self._version
|
|
1198
1516
|
|
|
1199
1517
|
@version.setter
|
|
@@ -1202,6 +1520,7 @@ class Node(_protocols.NodeProtocol, _display.PrettyPrintable):
|
|
|
1202
1520
|
|
|
1203
1521
|
@property
|
|
1204
1522
|
def op_type(self) -> str:
|
|
1523
|
+
"""The name of the operator called."""
|
|
1205
1524
|
return self._op_type
|
|
1206
1525
|
|
|
1207
1526
|
@op_type.setter
|
|
@@ -1210,6 +1529,7 @@ class Node(_protocols.NodeProtocol, _display.PrettyPrintable):
|
|
|
1210
1529
|
|
|
1211
1530
|
@property
|
|
1212
1531
|
def overload(self) -> str:
|
|
1532
|
+
"""The overload name when the node is invoking a function."""
|
|
1213
1533
|
return self._overload
|
|
1214
1534
|
|
|
1215
1535
|
@overload.setter
|
|
@@ -1218,6 +1538,12 @@ class Node(_protocols.NodeProtocol, _display.PrettyPrintable):
|
|
|
1218
1538
|
|
|
1219
1539
|
@property
|
|
1220
1540
|
def inputs(self) -> Sequence[Value | None]:
|
|
1541
|
+
"""The input values of the node.
|
|
1542
|
+
|
|
1543
|
+
The inputs are immutable. To change the inputs, create a new node and
|
|
1544
|
+
replace the inputs of the using nodes of this node's outputs by calling
|
|
1545
|
+
:meth:`replace_input_with` on the using nodes of this node's outputs.
|
|
1546
|
+
"""
|
|
1221
1547
|
return self._inputs
|
|
1222
1548
|
|
|
1223
1549
|
@inputs.setter
|
|
@@ -1226,6 +1552,25 @@ class Node(_protocols.NodeProtocol, _display.PrettyPrintable):
|
|
|
1226
1552
|
"Directly mutating the input sequence is unsupported. Please use Node.replace_input_with() instead."
|
|
1227
1553
|
)
|
|
1228
1554
|
|
|
1555
|
+
def predecessors(self) -> Sequence[Node]:
|
|
1556
|
+
"""Return the predecessor nodes of the node, deduplicated, in a deterministic order."""
|
|
1557
|
+
# Use the ordered nature of a dictionary to deduplicate the nodes
|
|
1558
|
+
predecessors: dict[Node, None] = {}
|
|
1559
|
+
for value in self.inputs:
|
|
1560
|
+
if value is not None and (producer := value.producer()) is not None:
|
|
1561
|
+
predecessors[producer] = None
|
|
1562
|
+
return tuple(predecessors)
|
|
1563
|
+
|
|
1564
|
+
def successors(self) -> Sequence[Node]:
|
|
1565
|
+
"""Return the successor nodes of the node, deduplicated, in a deterministic order."""
|
|
1566
|
+
# Use the ordered nature of a dictionary to deduplicate the nodes
|
|
1567
|
+
successors: dict[Node, None] = {}
|
|
1568
|
+
for value in self.outputs:
|
|
1569
|
+
assert value is not None, "Bug: Output values are not expected to be None"
|
|
1570
|
+
for usage in value.uses():
|
|
1571
|
+
successors[usage.node] = None
|
|
1572
|
+
return tuple(successors)
|
|
1573
|
+
|
|
1229
1574
|
def replace_input_with(self, index: int, value: Value | None) -> None:
|
|
1230
1575
|
"""Replace an input with a new value."""
|
|
1231
1576
|
if index < 0 or index >= len(self.inputs):
|
|
@@ -1279,6 +1624,12 @@ class Node(_protocols.NodeProtocol, _display.PrettyPrintable):
|
|
|
1279
1624
|
|
|
1280
1625
|
@property
|
|
1281
1626
|
def outputs(self) -> Sequence[Value]:
|
|
1627
|
+
"""The output values of the node.
|
|
1628
|
+
|
|
1629
|
+
The outputs are immutable. To change the outputs, create a new node and
|
|
1630
|
+
replace the inputs of the using nodes of this node's outputs by calling
|
|
1631
|
+
:meth:`replace_input_with` on the using nodes of this node's outputs.
|
|
1632
|
+
"""
|
|
1282
1633
|
return self._outputs
|
|
1283
1634
|
|
|
1284
1635
|
@outputs.setter
|
|
@@ -1286,7 +1637,8 @@ class Node(_protocols.NodeProtocol, _display.PrettyPrintable):
|
|
|
1286
1637
|
raise AttributeError("outputs is immutable. Please create a new node instead.")
|
|
1287
1638
|
|
|
1288
1639
|
@property
|
|
1289
|
-
def attributes(self) -> OrderedDict[str, Attr
|
|
1640
|
+
def attributes(self) -> OrderedDict[str, Attr]:
|
|
1641
|
+
"""The attributes of the node."""
|
|
1290
1642
|
return self._attributes
|
|
1291
1643
|
|
|
1292
1644
|
@property
|
|
@@ -1302,12 +1654,21 @@ class Node(_protocols.NodeProtocol, _display.PrettyPrintable):
|
|
|
1302
1654
|
|
|
1303
1655
|
@property
|
|
1304
1656
|
def metadata_props(self) -> dict[str, str]:
|
|
1657
|
+
"""The metadata properties of the node.
|
|
1658
|
+
|
|
1659
|
+
The metadata properties are used to store additional information about the node.
|
|
1660
|
+
Unlike ``meta``, this property is serialized to the ONNX proto.
|
|
1661
|
+
"""
|
|
1305
1662
|
if self._metadata_props is None:
|
|
1306
1663
|
self._metadata_props = {}
|
|
1307
1664
|
return self._metadata_props
|
|
1308
1665
|
|
|
1309
1666
|
@property
|
|
1310
1667
|
def graph(self) -> Graph | None:
|
|
1668
|
+
"""The graph that the node belongs to.
|
|
1669
|
+
|
|
1670
|
+
If the node is not added to any graph, this property is None.
|
|
1671
|
+
"""
|
|
1311
1672
|
return self._graph
|
|
1312
1673
|
|
|
1313
1674
|
@graph.setter
|
|
@@ -1315,9 +1676,17 @@ class Node(_protocols.NodeProtocol, _display.PrettyPrintable):
|
|
|
1315
1676
|
self._graph = value
|
|
1316
1677
|
|
|
1317
1678
|
def op_identifier(self) -> _protocols.OperatorIdentifier:
|
|
1679
|
+
"""Return the operator identifier of the node.
|
|
1680
|
+
|
|
1681
|
+
The operator identifier is a tuple of the domain, op_type and overload.
|
|
1682
|
+
"""
|
|
1318
1683
|
return self.domain, self.op_type, self.overload
|
|
1319
1684
|
|
|
1320
1685
|
def display(self, *, page: bool = False) -> None:
|
|
1686
|
+
"""Pretty print the node.
|
|
1687
|
+
|
|
1688
|
+
This method is used for debugging and visualization purposes.
|
|
1689
|
+
"""
|
|
1321
1690
|
# Add the node's name to the displayed text
|
|
1322
1691
|
print(f"Node: {self.name!r}")
|
|
1323
1692
|
if self.doc_string:
|
|
@@ -1344,7 +1713,7 @@ class _TensorTypeBase(_protocols.TypeProtocol, _display.PrettyPrintable, Hashabl
|
|
|
1344
1713
|
|
|
1345
1714
|
@property
|
|
1346
1715
|
def elem_type(self) -> _enums.DataType:
|
|
1347
|
-
"""Return the element type of the tensor type"""
|
|
1716
|
+
"""Return the element type of the tensor type."""
|
|
1348
1717
|
return self.dtype
|
|
1349
1718
|
|
|
1350
1719
|
def __hash__(self) -> int:
|
|
@@ -1438,18 +1807,19 @@ class Value(_protocols.ValueProtocol, _display.PrettyPrintable):
|
|
|
1438
1807
|
|
|
1439
1808
|
To find all the nodes that use this value as an input, call :meth:`uses`.
|
|
1440
1809
|
|
|
1441
|
-
To check if the value is an output of a graph,
|
|
1810
|
+
To check if the value is an is an input, output or initializer of a graph,
|
|
1811
|
+
use :meth:`is_graph_input`, :meth:`is_graph_output` or :meth:`is_initializer`.
|
|
1442
1812
|
|
|
1443
|
-
|
|
1444
|
-
name: The name of the value. A value is always named when it is part of a graph.
|
|
1445
|
-
shape: The shape of the value.
|
|
1446
|
-
type: The type of the value.
|
|
1447
|
-
metadata_props: Metadata.
|
|
1813
|
+
Use :meth:`graph` to get the graph that owns the value.
|
|
1448
1814
|
"""
|
|
1449
1815
|
|
|
1450
1816
|
__slots__ = (
|
|
1451
1817
|
"_const_value",
|
|
1818
|
+
"_graph",
|
|
1452
1819
|
"_index",
|
|
1820
|
+
"_is_graph_input",
|
|
1821
|
+
"_is_graph_output",
|
|
1822
|
+
"_is_initializer",
|
|
1453
1823
|
"_metadata",
|
|
1454
1824
|
"_metadata_props",
|
|
1455
1825
|
"_name",
|
|
@@ -1497,19 +1867,33 @@ class Value(_protocols.ValueProtocol, _display.PrettyPrintable):
|
|
|
1497
1867
|
# Use a collection of (Node, int) to store uses. This is needed
|
|
1498
1868
|
# because a single use can use the same value multiple times.
|
|
1499
1869
|
# Use a dictionary to preserve insertion order so that the visiting order is deterministic
|
|
1500
|
-
self._uses: dict[
|
|
1870
|
+
self._uses: dict[Usage, None] = {}
|
|
1501
1871
|
self.doc_string = doc_string
|
|
1502
1872
|
|
|
1873
|
+
# The graph this value belongs to. It is set *only* when the value is added as
|
|
1874
|
+
# a graph input, output or initializer.
|
|
1875
|
+
# The four properties can only be set by the Graph class (_GraphIO and GraphInitializers).
|
|
1876
|
+
self._graph: Graph | None = None
|
|
1877
|
+
self._is_graph_input: bool = False
|
|
1878
|
+
self._is_graph_output: bool = False
|
|
1879
|
+
self._is_initializer: bool = False
|
|
1880
|
+
|
|
1503
1881
|
def __repr__(self) -> str:
|
|
1504
1882
|
value_name = self.name if self.name else "anonymous:" + str(id(self))
|
|
1883
|
+
type_text = f", type={self.type!r}" if self.type is not None else ""
|
|
1884
|
+
shape_text = f", shape={self.shape!r}" if self.shape is not None else ""
|
|
1505
1885
|
producer = self.producer()
|
|
1506
1886
|
if producer is None:
|
|
1507
|
-
producer_text = "
|
|
1887
|
+
producer_text = ""
|
|
1508
1888
|
elif producer.name is not None:
|
|
1509
|
-
producer_text = producer.name
|
|
1889
|
+
producer_text = f", producer='{producer.name}'"
|
|
1510
1890
|
else:
|
|
1511
|
-
producer_text = f"anonymous_node:{id(producer)}"
|
|
1512
|
-
|
|
1891
|
+
producer_text = f", producer=anonymous_node:{id(producer)}"
|
|
1892
|
+
index_text = f", index={self.index()}" if self.index() is not None else ""
|
|
1893
|
+
const_value_text = self._constant_tensor_part()
|
|
1894
|
+
if const_value_text:
|
|
1895
|
+
const_value_text = f", const_value={const_value_text}"
|
|
1896
|
+
return f"{self.__class__.__name__}(name={value_name!r}{type_text}{shape_text}{producer_text}{index_text}{const_value_text})"
|
|
1513
1897
|
|
|
1514
1898
|
def __str__(self) -> str:
|
|
1515
1899
|
value_name = self.name if self.name is not None else "anonymous:" + str(id(self))
|
|
@@ -1518,41 +1902,85 @@ class Value(_protocols.ValueProtocol, _display.PrettyPrintable):
|
|
|
1518
1902
|
|
|
1519
1903
|
# Quote the name because in reality the names can have invalid characters
|
|
1520
1904
|
# that make them hard to read
|
|
1521
|
-
return
|
|
1905
|
+
return (
|
|
1906
|
+
f"%{_quoted(value_name)}<{type_text},{shape_text}>{self._constant_tensor_part()}"
|
|
1907
|
+
)
|
|
1908
|
+
|
|
1909
|
+
def _constant_tensor_part(self) -> str:
|
|
1910
|
+
"""Display string for the constant tensor attached to str of Value."""
|
|
1911
|
+
if self.const_value is not None:
|
|
1912
|
+
# Only display when the const value is small
|
|
1913
|
+
if self.const_value.size <= 10:
|
|
1914
|
+
return f"{{{self.const_value}}}"
|
|
1915
|
+
else:
|
|
1916
|
+
return f"{{{self.const_value.__class__.__name__}(...)}}"
|
|
1917
|
+
return ""
|
|
1918
|
+
|
|
1919
|
+
@property
|
|
1920
|
+
def graph(self) -> Graph | None:
|
|
1921
|
+
"""Return the graph that defines this value.
|
|
1922
|
+
|
|
1923
|
+
When the value is an input/output/initializer of a graph, the owning graph
|
|
1924
|
+
is that graph. When the value is an output of a node, the owning graph is the
|
|
1925
|
+
graph that the node belongs to. When the value is not owned by any graph,
|
|
1926
|
+
it returns ``None``.
|
|
1927
|
+
"""
|
|
1928
|
+
if self._graph is not None:
|
|
1929
|
+
return self._graph
|
|
1930
|
+
if self._producer is not None:
|
|
1931
|
+
return self._producer.graph
|
|
1932
|
+
return None
|
|
1933
|
+
|
|
1934
|
+
def _owned_by_graph(self) -> bool:
|
|
1935
|
+
"""Return True if the value is owned by a graph."""
|
|
1936
|
+
result = self._is_graph_input or self._is_graph_output or self._is_initializer
|
|
1937
|
+
if result:
|
|
1938
|
+
assert self._graph is not None
|
|
1939
|
+
return result
|
|
1522
1940
|
|
|
1523
1941
|
def producer(self) -> Node | None:
|
|
1524
1942
|
"""The node that produces this value.
|
|
1525
1943
|
|
|
1526
1944
|
When producer is ``None``, the value does not belong to a node, and is
|
|
1527
|
-
typically a graph input or an initializer.
|
|
1945
|
+
typically a graph input or an initializer. You can use :meth:`graph``
|
|
1946
|
+
to find the graph that owns this value. Use :meth:`is_graph_input`, :meth:`is_graph_output`
|
|
1947
|
+
or :meth:`is_initializer` to check if the value is an input, output or initializer of a graph.
|
|
1528
1948
|
"""
|
|
1529
1949
|
return self._producer
|
|
1530
1950
|
|
|
1951
|
+
def consumers(self) -> Sequence[Node]:
|
|
1952
|
+
"""Return the nodes (deduplicated) that consume this value."""
|
|
1953
|
+
return tuple({usage.node: None for usage in self._uses})
|
|
1954
|
+
|
|
1531
1955
|
def index(self) -> int | None:
|
|
1532
1956
|
"""The index of the output of the defining node."""
|
|
1533
1957
|
return self._index
|
|
1534
1958
|
|
|
1535
|
-
def uses(self) -> Collection[
|
|
1959
|
+
def uses(self) -> Collection[Usage]:
|
|
1536
1960
|
"""Return a set of uses of the value.
|
|
1537
1961
|
|
|
1538
1962
|
The set contains tuples of ``(Node, index)`` where the index is the index of the input
|
|
1539
1963
|
of the node. For example, if ``node.inputs[1] == value``, then the use is ``(node, 1)``.
|
|
1540
1964
|
"""
|
|
1541
|
-
|
|
1965
|
+
# Create a tuple for the collection so that iteration on will will not
|
|
1966
|
+
# be affected when the usage changes during graph mutation.
|
|
1967
|
+
# This adds a small overhead but is better a user experience than
|
|
1968
|
+
# having users call tuple().
|
|
1969
|
+
return tuple(self._uses)
|
|
1542
1970
|
|
|
1543
1971
|
def _add_usage(self, use: Node, index: int) -> None:
|
|
1544
1972
|
"""Add a usage of this value.
|
|
1545
1973
|
|
|
1546
1974
|
This is an internal method. It should only be called by the Node class.
|
|
1547
1975
|
"""
|
|
1548
|
-
self._uses[(use, index)] = None
|
|
1976
|
+
self._uses[Usage(use, index)] = None
|
|
1549
1977
|
|
|
1550
1978
|
def _remove_usage(self, use: Node, index: int) -> None:
|
|
1551
1979
|
"""Remove a node from the uses of this value.
|
|
1552
1980
|
|
|
1553
1981
|
This is an internal method. It should only be called by the Node class.
|
|
1554
1982
|
"""
|
|
1555
|
-
self._uses.pop((use, index))
|
|
1983
|
+
self._uses.pop(Usage(use, index))
|
|
1556
1984
|
|
|
1557
1985
|
@property
|
|
1558
1986
|
def name(self) -> str | None:
|
|
@@ -1652,15 +2080,17 @@ class Value(_protocols.ValueProtocol, _display.PrettyPrintable):
|
|
|
1652
2080
|
self._metadata_props = {}
|
|
1653
2081
|
return self._metadata_props
|
|
1654
2082
|
|
|
2083
|
+
def is_graph_input(self) -> bool:
|
|
2084
|
+
"""Whether the value is an input of a graph."""
|
|
2085
|
+
return self._is_graph_input
|
|
2086
|
+
|
|
1655
2087
|
def is_graph_output(self) -> bool:
|
|
1656
2088
|
"""Whether the value is an output of a graph."""
|
|
1657
|
-
|
|
1658
|
-
|
|
1659
|
-
|
|
1660
|
-
|
|
1661
|
-
|
|
1662
|
-
# it is not recommended
|
|
1663
|
-
return any(output is self for output in graph.outputs)
|
|
2089
|
+
return self._is_graph_output
|
|
2090
|
+
|
|
2091
|
+
def is_initializer(self) -> bool:
|
|
2092
|
+
"""Whether the value is an initializer of a graph."""
|
|
2093
|
+
return self._is_initializer
|
|
1664
2094
|
|
|
1665
2095
|
|
|
1666
2096
|
def Input(
|
|
@@ -1673,7 +2103,6 @@ def Input(
|
|
|
1673
2103
|
|
|
1674
2104
|
This is equivalent to calling ``Value(name=name, shape=shape, type=type, doc_string=doc_string)``.
|
|
1675
2105
|
"""
|
|
1676
|
-
|
|
1677
2106
|
# NOTE: The function name is capitalized to maintain API backward compatibility.
|
|
1678
2107
|
|
|
1679
2108
|
return Value(name=name, shape=shape, type=type, doc_string=doc_string)
|
|
@@ -1770,9 +2199,9 @@ class Graph(_protocols.GraphProtocol, Sequence[Node], _display.PrettyPrintable):
|
|
|
1770
2199
|
self.name = name
|
|
1771
2200
|
|
|
1772
2201
|
# Private fields that are not to be accessed by any other classes
|
|
1773
|
-
self._inputs =
|
|
1774
|
-
self._outputs =
|
|
1775
|
-
self._initializers =
|
|
2202
|
+
self._inputs = _graph_containers.GraphInputs(self, inputs)
|
|
2203
|
+
self._outputs = _graph_containers.GraphOutputs(self, outputs)
|
|
2204
|
+
self._initializers = _graph_containers.GraphInitializers(self)
|
|
1776
2205
|
for initializer in initializers:
|
|
1777
2206
|
if isinstance(initializer, str):
|
|
1778
2207
|
raise TypeError(
|
|
@@ -1791,21 +2220,59 @@ class Graph(_protocols.GraphProtocol, Sequence[Node], _display.PrettyPrintable):
|
|
|
1791
2220
|
# Be sure the initialize the name authority before extending the nodes
|
|
1792
2221
|
# because it is used to name the nodes and their outputs
|
|
1793
2222
|
self._name_authority = _name_authority.NameAuthority()
|
|
2223
|
+
# TODO(justinchuby): Trigger again if inputs or initializers are modified.
|
|
2224
|
+
self._set_input_and_initializer_value_names_into_name_authority()
|
|
1794
2225
|
# Call self.extend not self._nodes.extend so the graph reference is added to the nodes
|
|
1795
2226
|
self.extend(nodes)
|
|
1796
2227
|
|
|
1797
2228
|
@property
|
|
1798
|
-
def inputs(self) ->
|
|
2229
|
+
def inputs(self) -> MutableSequence[Value]:
|
|
1799
2230
|
return self._inputs
|
|
1800
2231
|
|
|
1801
2232
|
@property
|
|
1802
|
-
def outputs(self) ->
|
|
2233
|
+
def outputs(self) -> MutableSequence[Value]:
|
|
1803
2234
|
return self._outputs
|
|
1804
2235
|
|
|
1805
2236
|
@property
|
|
1806
|
-
def initializers(self) ->
|
|
2237
|
+
def initializers(self) -> MutableMapping[str, Value]:
|
|
1807
2238
|
return self._initializers
|
|
1808
2239
|
|
|
2240
|
+
def register_initializer(self, value: Value) -> None:
|
|
2241
|
+
"""Register an initializer to the graph.
|
|
2242
|
+
|
|
2243
|
+
This is a convenience method to register an initializer to the graph with
|
|
2244
|
+
checks.
|
|
2245
|
+
|
|
2246
|
+
Args:
|
|
2247
|
+
value: The :class:`Value` to register as an initializer of the graph.
|
|
2248
|
+
It must have its ``.const_value`` set.
|
|
2249
|
+
|
|
2250
|
+
Raises:
|
|
2251
|
+
ValueError: If a value of the same name that is not this value
|
|
2252
|
+
is already registered.
|
|
2253
|
+
ValueError: If the value does not have a name.
|
|
2254
|
+
ValueError: If the initializer is produced by a node.
|
|
2255
|
+
ValueError: If the value does not have its ``.const_value`` set.
|
|
2256
|
+
"""
|
|
2257
|
+
if not value.name:
|
|
2258
|
+
raise ValueError(f"Initializer must have a name: {value!r}")
|
|
2259
|
+
if value.name in self._initializers:
|
|
2260
|
+
if self._initializers[value.name] is not value:
|
|
2261
|
+
raise ValueError(
|
|
2262
|
+
f"Initializer '{value.name}' is already registered, but"
|
|
2263
|
+
" it is not the same object: existing={self._initializers[value.name]!r},"
|
|
2264
|
+
f" new={value!r}"
|
|
2265
|
+
)
|
|
2266
|
+
if value.producer() is not None:
|
|
2267
|
+
raise ValueError(
|
|
2268
|
+
f"Value '{value!r}' is produced by a node and cannot be an initializer."
|
|
2269
|
+
)
|
|
2270
|
+
if value.const_value is None:
|
|
2271
|
+
raise ValueError(
|
|
2272
|
+
f"Value '{value!r}' must have its const_value set to be an initializer."
|
|
2273
|
+
)
|
|
2274
|
+
self._initializers[value.name] = value
|
|
2275
|
+
|
|
1809
2276
|
@property
|
|
1810
2277
|
def doc_string(self) -> str | None:
|
|
1811
2278
|
return self._doc_string
|
|
@@ -1818,7 +2285,12 @@ class Graph(_protocols.GraphProtocol, Sequence[Node], _display.PrettyPrintable):
|
|
|
1818
2285
|
def opset_imports(self) -> dict[str, int]:
|
|
1819
2286
|
return self._opset_imports
|
|
1820
2287
|
|
|
1821
|
-
|
|
2288
|
+
@typing.overload
|
|
2289
|
+
def __getitem__(self, index: int) -> Node: ...
|
|
2290
|
+
@typing.overload
|
|
2291
|
+
def __getitem__(self, index: slice) -> Sequence[Node]: ...
|
|
2292
|
+
|
|
2293
|
+
def __getitem__(self, index):
|
|
1822
2294
|
return self._nodes[index]
|
|
1823
2295
|
|
|
1824
2296
|
def __len__(self) -> int:
|
|
@@ -1830,6 +2302,12 @@ class Graph(_protocols.GraphProtocol, Sequence[Node], _display.PrettyPrintable):
|
|
|
1830
2302
|
def __reversed__(self) -> Iterator[Node]:
|
|
1831
2303
|
return reversed(self._nodes)
|
|
1832
2304
|
|
|
2305
|
+
def _set_input_and_initializer_value_names_into_name_authority(self):
|
|
2306
|
+
for value in self.inputs:
|
|
2307
|
+
self._name_authority.register_or_name_value(value)
|
|
2308
|
+
for value in self.initializers.values():
|
|
2309
|
+
self._name_authority.register_or_name_value(value)
|
|
2310
|
+
|
|
1833
2311
|
def _set_node_graph_to_self_and_assign_names(self, node: Node) -> Node:
|
|
1834
2312
|
"""Set the graph reference for the node and assign names to it and its outputs if they don't have one."""
|
|
1835
2313
|
if node.graph is not None and node.graph is not self:
|
|
@@ -1993,7 +2471,7 @@ class Graph(_protocols.GraphProtocol, Sequence[Node], _display.PrettyPrintable):
|
|
|
1993
2471
|
|
|
1994
2472
|
This sort is stable. It preserves the original order as much as possible.
|
|
1995
2473
|
|
|
1996
|
-
|
|
2474
|
+
Reference: https://github.com/madelson/MedallionTopologicalSort#stable-sort
|
|
1997
2475
|
|
|
1998
2476
|
Raises:
|
|
1999
2477
|
ValueError: If the graph contains a cycle, making topological sorting impossible.
|
|
@@ -2004,7 +2482,7 @@ class Graph(_protocols.GraphProtocol, Sequence[Node], _display.PrettyPrintable):
|
|
|
2004
2482
|
sorted_nodes_by_graph: dict[Graph, list[Node]] = {
|
|
2005
2483
|
graph: [] for graph in {node.graph for node in nodes if node.graph is not None}
|
|
2006
2484
|
}
|
|
2007
|
-
# TODO: Explain why we need to store direct predecessors and children and why
|
|
2485
|
+
# TODO(justinchuby): Explain why we need to store direct predecessors and children and why
|
|
2008
2486
|
# we only need to store the direct ones
|
|
2009
2487
|
|
|
2010
2488
|
# The depth of a node is defined as the number of direct children it has
|
|
@@ -2024,7 +2502,7 @@ class Graph(_protocols.GraphProtocol, Sequence[Node], _display.PrettyPrintable):
|
|
|
2024
2502
|
node_depth[predecessor] += 1
|
|
2025
2503
|
|
|
2026
2504
|
# 1. Build the direct predecessors of each node and the depth of each node
|
|
2027
|
-
# for sorting
|
|
2505
|
+
# for sorting topologically using Kahn's algorithm.
|
|
2028
2506
|
# Note that when a node contains graph attributes (aka. has subgraphs),
|
|
2029
2507
|
# we consider all nodes in the subgraphs *predecessors* of this node. This
|
|
2030
2508
|
# way we ensure the implicit dependencies of the subgraphs are captured
|
|
@@ -2125,11 +2603,11 @@ def _graph_str(graph: Graph | GraphView) -> str:
|
|
|
2125
2603
|
)
|
|
2126
2604
|
signature = f"""\
|
|
2127
2605
|
graph(
|
|
2128
|
-
name={graph.name or
|
|
2129
|
-
inputs=({textwrap.indent(inputs_text,
|
|
2606
|
+
name={graph.name or "anonymous_graph:" + str(id(graph))},
|
|
2607
|
+
inputs=({textwrap.indent(inputs_text, " " * 8)}
|
|
2130
2608
|
),
|
|
2131
|
-
outputs=({textwrap.indent(outputs_text,
|
|
2132
|
-
),{textwrap.indent(initializers_text,
|
|
2609
|
+
outputs=({textwrap.indent(outputs_text, " " * 8)}
|
|
2610
|
+
),{textwrap.indent(initializers_text, " " * 4)}
|
|
2133
2611
|
)"""
|
|
2134
2612
|
node_count = len(graph)
|
|
2135
2613
|
number_width = len(str(node_count))
|
|
@@ -2163,11 +2641,11 @@ def _graph_repr(graph: Graph | GraphView) -> str:
|
|
|
2163
2641
|
)
|
|
2164
2642
|
return f"""\
|
|
2165
2643
|
{graph.__class__.__name__}(
|
|
2166
|
-
name={graph.name or
|
|
2167
|
-
inputs=({textwrap.indent(inputs_text,
|
|
2644
|
+
name={graph.name or "anonymous_graph:" + str(id(graph))!r},
|
|
2645
|
+
inputs=({textwrap.indent(inputs_text, " " * 8)}
|
|
2168
2646
|
),
|
|
2169
|
-
outputs=({textwrap.indent(outputs_text,
|
|
2170
|
-
),{textwrap.indent(initializers_text,
|
|
2647
|
+
outputs=({textwrap.indent(outputs_text, " " * 8)}
|
|
2648
|
+
),{textwrap.indent(initializers_text, " " * 4)}
|
|
2171
2649
|
len()={len(graph)}
|
|
2172
2650
|
)"""
|
|
2173
2651
|
|
|
@@ -2242,7 +2720,12 @@ class GraphView(Sequence[Node], _display.PrettyPrintable):
|
|
|
2242
2720
|
self._metadata_props: dict[str, str] | None = metadata_props
|
|
2243
2721
|
self._nodes: tuple[Node, ...] = tuple(nodes)
|
|
2244
2722
|
|
|
2245
|
-
|
|
2723
|
+
@typing.overload
|
|
2724
|
+
def __getitem__(self, index: int) -> Node: ...
|
|
2725
|
+
@typing.overload
|
|
2726
|
+
def __getitem__(self, index: slice) -> Sequence[Node]: ...
|
|
2727
|
+
|
|
2728
|
+
def __getitem__(self, index):
|
|
2246
2729
|
return self._nodes[index]
|
|
2247
2730
|
|
|
2248
2731
|
def __len__(self) -> int:
|
|
@@ -2318,7 +2801,7 @@ class Model(_protocols.ModelProtocol, _display.PrettyPrintable):
|
|
|
2318
2801
|
model_version: int | None = None,
|
|
2319
2802
|
doc_string: str | None = None,
|
|
2320
2803
|
functions: Sequence[Function] = (),
|
|
2321
|
-
|
|
2804
|
+
metadata_props: dict[str, str] | None = None,
|
|
2322
2805
|
) -> None:
|
|
2323
2806
|
self.graph: Graph = graph
|
|
2324
2807
|
self.ir_version = ir_version
|
|
@@ -2329,7 +2812,7 @@ class Model(_protocols.ModelProtocol, _display.PrettyPrintable):
|
|
|
2329
2812
|
self.doc_string = doc_string
|
|
2330
2813
|
self._functions = {func.identifier(): func for func in functions}
|
|
2331
2814
|
self._metadata: _metadata.MetadataStore | None = None
|
|
2332
|
-
self._metadata_props: dict[str, str] | None =
|
|
2815
|
+
self._metadata_props: dict[str, str] | None = metadata_props
|
|
2333
2816
|
|
|
2334
2817
|
@property
|
|
2335
2818
|
def functions(self) -> dict[_protocols.OperatorIdentifier, Function]:
|
|
@@ -2381,9 +2864,28 @@ Model(
|
|
|
2381
2864
|
domain={self.domain!r},
|
|
2382
2865
|
model_version={self.model_version!r},
|
|
2383
2866
|
functions={self.functions!r},
|
|
2384
|
-
graph={textwrap.indent(repr(self.graph),
|
|
2867
|
+
graph={textwrap.indent(repr(self.graph), " " * 4).strip()}
|
|
2385
2868
|
)"""
|
|
2386
2869
|
|
|
2870
|
+
def graphs(self) -> Iterable[Graph]:
|
|
2871
|
+
"""Get all graphs and subgraphs in the model.
|
|
2872
|
+
|
|
2873
|
+
This is a convenience method to traverse the model. Consider using
|
|
2874
|
+
`onnx_ir.traversal.RecursiveGraphIterator` for more advanced
|
|
2875
|
+
traversals on nodes.
|
|
2876
|
+
"""
|
|
2877
|
+
# NOTE(justinchuby): Given
|
|
2878
|
+
# (1) how useful the method is
|
|
2879
|
+
# (2) I couldn't find an appropriate name for it in `traversal.py`
|
|
2880
|
+
# (3) Users familiar with onnxruntime optimization tools expect this method
|
|
2881
|
+
# I created this method as a core method instead of an iterator in
|
|
2882
|
+
# `traversal.py`.
|
|
2883
|
+
seen_graphs: set[Graph] = set()
|
|
2884
|
+
for node in onnx_ir.traversal.RecursiveGraphIterator(self.graph):
|
|
2885
|
+
if node.graph is not None and node.graph not in seen_graphs:
|
|
2886
|
+
seen_graphs.add(node.graph)
|
|
2887
|
+
yield node.graph
|
|
2888
|
+
|
|
2387
2889
|
|
|
2388
2890
|
class Function(_protocols.FunctionProtocol, Sequence[Node], _display.PrettyPrintable):
|
|
2389
2891
|
"""IR functions.
|
|
@@ -2404,16 +2906,14 @@ class Function(_protocols.FunctionProtocol, Sequence[Node], _display.PrettyPrint
|
|
|
2404
2906
|
outputs: The output values of the function.
|
|
2405
2907
|
opset_imports: Opsets imported by the function.
|
|
2406
2908
|
doc_string: Documentation string.
|
|
2407
|
-
metadata_props: Metadata that will be serialized to the ONNX file.
|
|
2408
2909
|
meta: Metadata store for graph transform passes.
|
|
2910
|
+
metadata_props: Metadata that will be serialized to the ONNX file.
|
|
2409
2911
|
"""
|
|
2410
2912
|
|
|
2411
2913
|
__slots__ = (
|
|
2412
2914
|
"_attributes",
|
|
2413
2915
|
"_domain",
|
|
2414
2916
|
"_graph",
|
|
2415
|
-
"_metadata",
|
|
2416
|
-
"_metadata_props",
|
|
2417
2917
|
"_name",
|
|
2418
2918
|
"_overload",
|
|
2419
2919
|
)
|
|
@@ -2428,15 +2928,12 @@ class Function(_protocols.FunctionProtocol, Sequence[Node], _display.PrettyPrint
|
|
|
2428
2928
|
# and not from an outer scope
|
|
2429
2929
|
graph: Graph,
|
|
2430
2930
|
attributes: Sequence[Attr],
|
|
2431
|
-
metadata_props: dict[str, str] | None = None,
|
|
2432
2931
|
) -> None:
|
|
2433
2932
|
self._domain = domain
|
|
2434
2933
|
self._name = name
|
|
2435
2934
|
self._overload = overload
|
|
2436
2935
|
self._graph = graph
|
|
2437
2936
|
self._attributes = OrderedDict((attr.name, attr) for attr in attributes)
|
|
2438
|
-
self._metadata: _metadata.MetadataStore | None = None
|
|
2439
|
-
self._metadata_props: dict[str, str] | None = metadata_props
|
|
2440
2937
|
|
|
2441
2938
|
def identifier(self) -> _protocols.OperatorIdentifier:
|
|
2442
2939
|
return self.domain, self.name, self.overload
|
|
@@ -2455,7 +2952,7 @@ class Function(_protocols.FunctionProtocol, Sequence[Node], _display.PrettyPrint
|
|
|
2455
2952
|
|
|
2456
2953
|
@domain.setter
|
|
2457
2954
|
def domain(self, value: str) -> None:
|
|
2458
|
-
self._domain = value
|
|
2955
|
+
self._domain = _normalize_domain(value)
|
|
2459
2956
|
|
|
2460
2957
|
@property
|
|
2461
2958
|
def overload(self) -> str:
|
|
@@ -2466,18 +2963,23 @@ class Function(_protocols.FunctionProtocol, Sequence[Node], _display.PrettyPrint
|
|
|
2466
2963
|
self._overload = value
|
|
2467
2964
|
|
|
2468
2965
|
@property
|
|
2469
|
-
def inputs(self) ->
|
|
2966
|
+
def inputs(self) -> MutableSequence[Value]:
|
|
2470
2967
|
return self._graph.inputs
|
|
2471
2968
|
|
|
2472
2969
|
@property
|
|
2473
|
-
def outputs(self) ->
|
|
2970
|
+
def outputs(self) -> MutableSequence[Value]:
|
|
2474
2971
|
return self._graph.outputs
|
|
2475
2972
|
|
|
2476
2973
|
@property
|
|
2477
2974
|
def attributes(self) -> OrderedDict[str, Attr]:
|
|
2478
2975
|
return self._attributes
|
|
2479
2976
|
|
|
2480
|
-
|
|
2977
|
+
@typing.overload
|
|
2978
|
+
def __getitem__(self, index: int) -> Node: ...
|
|
2979
|
+
@typing.overload
|
|
2980
|
+
def __getitem__(self, index: slice) -> Sequence[Node]: ...
|
|
2981
|
+
|
|
2982
|
+
def __getitem__(self, index):
|
|
2481
2983
|
return self._graph.__getitem__(index)
|
|
2482
2984
|
|
|
2483
2985
|
def __len__(self) -> int:
|
|
@@ -2508,15 +3010,11 @@ class Function(_protocols.FunctionProtocol, Sequence[Node], _display.PrettyPrint
|
|
|
2508
3010
|
Write to the :attr:`metadata_props` if you would like the metadata to be serialized
|
|
2509
3011
|
to the ONNX proto.
|
|
2510
3012
|
"""
|
|
2511
|
-
|
|
2512
|
-
self._metadata = _metadata.MetadataStore()
|
|
2513
|
-
return self._metadata
|
|
3013
|
+
return self._graph.meta
|
|
2514
3014
|
|
|
2515
3015
|
@property
|
|
2516
3016
|
def metadata_props(self) -> dict[str, str]:
|
|
2517
|
-
|
|
2518
|
-
self._metadata_props = {}
|
|
2519
|
-
return self._metadata_props
|
|
3017
|
+
return self._graph.metadata_props
|
|
2520
3018
|
|
|
2521
3019
|
# Mutation methods
|
|
2522
3020
|
def append(self, node: Node, /) -> None:
|
|
@@ -2549,11 +3047,11 @@ class Function(_protocols.FunctionProtocol, Sequence[Node], _display.PrettyPrint
|
|
|
2549
3047
|
"""
|
|
2550
3048
|
self._graph.remove(nodes, safe=safe)
|
|
2551
3049
|
|
|
2552
|
-
def insert_after(self, node: Node, new_nodes: Iterable[Node], /) -> None:
|
|
3050
|
+
def insert_after(self, node: Node, new_nodes: Iterable[Node] | Node, /) -> None:
|
|
2553
3051
|
"""Insert new nodes after the given node in O(#new_nodes) time."""
|
|
2554
3052
|
self._graph.insert_after(node, new_nodes)
|
|
2555
3053
|
|
|
2556
|
-
def insert_before(self, node: Node, new_nodes: Iterable[Node], /) -> None:
|
|
3054
|
+
def insert_before(self, node: Node, new_nodes: Iterable[Node] | Node, /) -> None:
|
|
2557
3055
|
"""Insert new nodes before the given node in O(#new_nodes) time."""
|
|
2558
3056
|
self._graph.insert_before(node, new_nodes)
|
|
2559
3057
|
|
|
@@ -2581,10 +3079,10 @@ class Function(_protocols.FunctionProtocol, Sequence[Node], _display.PrettyPrint
|
|
|
2581
3079
|
>
|
|
2582
3080
|
def {full_name}(
|
|
2583
3081
|
inputs=(
|
|
2584
|
-
{textwrap.indent(inputs_text,
|
|
2585
|
-
),{textwrap.indent(attributes_text,
|
|
3082
|
+
{textwrap.indent(inputs_text, " " * 8)}
|
|
3083
|
+
),{textwrap.indent(attributes_text, " " * 4)}
|
|
2586
3084
|
outputs=(
|
|
2587
|
-
{textwrap.indent(outputs_text,
|
|
3085
|
+
{textwrap.indent(outputs_text, " " * 8)}
|
|
2588
3086
|
),
|
|
2589
3087
|
)"""
|
|
2590
3088
|
node_count = len(self)
|
|
@@ -2611,22 +3109,28 @@ def {full_name}(
|
|
|
2611
3109
|
return f"{self.__class__.__name__}({self.domain!r}, {self.name!r}, {self.overload!r}, inputs={self.inputs!r}, attributes={self.attributes!r}), outputs={self.outputs!r})"
|
|
2612
3110
|
|
|
2613
3111
|
|
|
2614
|
-
class
|
|
2615
|
-
|
|
3112
|
+
class Attr(
|
|
3113
|
+
_protocols.AttributeProtocol,
|
|
3114
|
+
_protocols.ReferenceAttributeProtocol,
|
|
3115
|
+
_display.PrettyPrintable,
|
|
3116
|
+
):
|
|
3117
|
+
"""Base class for ONNX attributes or references."""
|
|
2616
3118
|
|
|
2617
|
-
__slots__ = ("_name", "_ref_attr_name", "_type", "doc_string")
|
|
3119
|
+
__slots__ = ("_name", "_ref_attr_name", "_type", "_value", "doc_string")
|
|
2618
3120
|
|
|
2619
3121
|
def __init__(
|
|
2620
3122
|
self,
|
|
2621
3123
|
name: str,
|
|
2622
|
-
ref_attr_name: str,
|
|
2623
3124
|
type: _enums.AttributeType,
|
|
3125
|
+
value: Any,
|
|
3126
|
+
ref_attr_name: str | None = None,
|
|
2624
3127
|
*,
|
|
2625
3128
|
doc_string: str | None = None,
|
|
2626
3129
|
) -> None:
|
|
2627
3130
|
self._name = name
|
|
2628
|
-
self._ref_attr_name = ref_attr_name
|
|
2629
3131
|
self._type = type
|
|
3132
|
+
self._value = value
|
|
3133
|
+
self._ref_attr_name = ref_attr_name
|
|
2630
3134
|
self.doc_string = doc_string
|
|
2631
3135
|
|
|
2632
3136
|
@property
|
|
@@ -2637,43 +3141,21 @@ class RefAttr(_protocols.ReferenceAttributeProtocol, _display.PrettyPrintable):
|
|
|
2637
3141
|
def name(self, value: str) -> None:
|
|
2638
3142
|
self._name = value
|
|
2639
3143
|
|
|
2640
|
-
@property
|
|
2641
|
-
def ref_attr_name(self) -> str:
|
|
2642
|
-
return self._ref_attr_name
|
|
2643
|
-
|
|
2644
|
-
@ref_attr_name.setter
|
|
2645
|
-
def ref_attr_name(self, value: str) -> None:
|
|
2646
|
-
self._ref_attr_name = value
|
|
2647
|
-
|
|
2648
3144
|
@property
|
|
2649
3145
|
def type(self) -> _enums.AttributeType:
|
|
2650
3146
|
return self._type
|
|
2651
3147
|
|
|
2652
|
-
@
|
|
2653
|
-
def
|
|
2654
|
-
self.
|
|
2655
|
-
|
|
2656
|
-
def __repr__(self) -> str:
|
|
2657
|
-
return f"{self.__class__.__name__}({self._name!r}, {self._type!r}, ref_attr_name={self.ref_attr_name!r})"
|
|
2658
|
-
|
|
2659
|
-
|
|
2660
|
-
class Attr(_protocols.AttributeProtocol, _display.PrettyPrintable):
|
|
2661
|
-
"""Base class for ONNX attributes."""
|
|
3148
|
+
@property
|
|
3149
|
+
def value(self) -> Any:
|
|
3150
|
+
return self._value
|
|
2662
3151
|
|
|
2663
|
-
|
|
3152
|
+
@property
|
|
3153
|
+
def ref_attr_name(self) -> str | None:
|
|
3154
|
+
return self._ref_attr_name
|
|
2664
3155
|
|
|
2665
|
-
def
|
|
2666
|
-
|
|
2667
|
-
|
|
2668
|
-
type: _enums.AttributeType,
|
|
2669
|
-
value: Any,
|
|
2670
|
-
*,
|
|
2671
|
-
doc_string: str | None = None,
|
|
2672
|
-
):
|
|
2673
|
-
self.name = name
|
|
2674
|
-
self.type = type
|
|
2675
|
-
self.value = value
|
|
2676
|
-
self.doc_string = doc_string
|
|
3156
|
+
def is_ref(self) -> bool:
|
|
3157
|
+
"""Check if this attribute is a reference attribute."""
|
|
3158
|
+
return self.ref_attr_name is not None
|
|
2677
3159
|
|
|
2678
3160
|
def __eq__(self, other: object) -> bool:
|
|
2679
3161
|
if not isinstance(other, _protocols.AttributeProtocol):
|
|
@@ -2690,15 +3172,157 @@ class Attr(_protocols.AttributeProtocol, _display.PrettyPrintable):
|
|
|
2690
3172
|
return True
|
|
2691
3173
|
|
|
2692
3174
|
def __str__(self) -> str:
|
|
3175
|
+
if self.is_ref():
|
|
3176
|
+
return f"@{self.ref_attr_name}"
|
|
2693
3177
|
if self.type == _enums.AttributeType.GRAPH:
|
|
2694
3178
|
return textwrap.indent("\n" + str(self.value), " " * 4)
|
|
2695
3179
|
return str(self.value)
|
|
2696
3180
|
|
|
2697
3181
|
def __repr__(self) -> str:
|
|
3182
|
+
if self.is_ref():
|
|
3183
|
+
return f"{self.__class__.__name__}({self.name!r}, {self.type!r}, ref_attr_name={self.ref_attr_name!r})"
|
|
2698
3184
|
return f"{self.__class__.__name__}({self.name!r}, {self.type!r}, {self.value!r})"
|
|
2699
3185
|
|
|
3186
|
+
# Well typed getters
|
|
3187
|
+
def as_float(self) -> float:
|
|
3188
|
+
"""Get the attribute value as a float."""
|
|
3189
|
+
if self.type != _enums.AttributeType.FLOAT:
|
|
3190
|
+
raise TypeError(
|
|
3191
|
+
f"Attribute '{self.name}' is not of type FLOAT. Actual type: {self.type}"
|
|
3192
|
+
)
|
|
3193
|
+
# Do not use isinstance check because it may prevent np.float32 etc. from being used
|
|
3194
|
+
return float(self.value)
|
|
3195
|
+
|
|
3196
|
+
def as_int(self) -> int:
|
|
3197
|
+
"""Get the attribute value as an int."""
|
|
3198
|
+
if self.type != _enums.AttributeType.INT:
|
|
3199
|
+
raise TypeError(
|
|
3200
|
+
f"Attribute '{self.name}' is not of type INT. Actual type: {self.type}"
|
|
3201
|
+
)
|
|
3202
|
+
# Do not use isinstance check because it may prevent np.int32 etc. from being used
|
|
3203
|
+
return int(self.value)
|
|
3204
|
+
|
|
3205
|
+
def as_string(self) -> str:
|
|
3206
|
+
"""Get the attribute value as a string."""
|
|
3207
|
+
if self.type != _enums.AttributeType.STRING:
|
|
3208
|
+
raise TypeError(
|
|
3209
|
+
f"Attribute '{self.name}' is not of type STRING. Actual type: {self.type}"
|
|
3210
|
+
)
|
|
3211
|
+
if not isinstance(self.value, str):
|
|
3212
|
+
raise TypeError(f"Value of attribute '{self!r}' is not a string.")
|
|
3213
|
+
return self.value
|
|
3214
|
+
|
|
3215
|
+
def as_tensor(self) -> _protocols.TensorProtocol:
|
|
3216
|
+
"""Get the attribute value as a tensor."""
|
|
3217
|
+
if self.type != _enums.AttributeType.TENSOR:
|
|
3218
|
+
raise TypeError(
|
|
3219
|
+
f"Attribute '{self.name}' is not of type TENSOR. Actual type: {self.type}"
|
|
3220
|
+
)
|
|
3221
|
+
if not isinstance(self.value, _protocols.TensorProtocol):
|
|
3222
|
+
raise TypeError(f"Value of attribute '{self!r}' is not a tensor.")
|
|
3223
|
+
return self.value
|
|
3224
|
+
|
|
3225
|
+
def as_graph(self) -> Graph:
|
|
3226
|
+
"""Get the attribute value as a graph."""
|
|
3227
|
+
if self.type != _enums.AttributeType.GRAPH:
|
|
3228
|
+
raise TypeError(
|
|
3229
|
+
f"Attribute '{self.name}' is not of type GRAPH. Actual type: {self.type}"
|
|
3230
|
+
)
|
|
3231
|
+
if not isinstance(self.value, Graph):
|
|
3232
|
+
raise TypeError(f"Value of attribute '{self!r}' is not a graph.")
|
|
3233
|
+
return self.value
|
|
3234
|
+
|
|
3235
|
+
def as_floats(self) -> Sequence[float]:
|
|
3236
|
+
"""Get the attribute value as a sequence of floats."""
|
|
3237
|
+
if self.type != _enums.AttributeType.FLOATS:
|
|
3238
|
+
raise TypeError(
|
|
3239
|
+
f"Attribute '{self.name}' is not of type FLOATS. Actual type: {self.type}"
|
|
3240
|
+
)
|
|
3241
|
+
if not isinstance(self.value, Sequence):
|
|
3242
|
+
raise TypeError(f"Value of attribute '{self!r}' is not a Sequence.")
|
|
3243
|
+
# Do not use isinstance check on elements because it may prevent np.int32 etc. from being used
|
|
3244
|
+
# Create a copy of the list to prevent mutation
|
|
3245
|
+
return [float(v) for v in self.value]
|
|
3246
|
+
|
|
3247
|
+
def as_ints(self) -> Sequence[int]:
|
|
3248
|
+
"""Get the attribute value as a sequence of ints."""
|
|
3249
|
+
if self.type != _enums.AttributeType.INTS:
|
|
3250
|
+
raise TypeError(
|
|
3251
|
+
f"Attribute '{self.name}' is not of type INTS. Actual type: {self.type}"
|
|
3252
|
+
)
|
|
3253
|
+
if not isinstance(self.value, Sequence):
|
|
3254
|
+
raise TypeError(f"Value of attribute '{self!r}' is not a Sequence.")
|
|
3255
|
+
# Do not use isinstance check on elements because it may prevent np.int32 etc. from being used
|
|
3256
|
+
# Create a copy of the list to prevent mutation
|
|
3257
|
+
return list(self.value)
|
|
3258
|
+
|
|
3259
|
+
def as_strings(self) -> Sequence[str]:
|
|
3260
|
+
"""Get the attribute value as a sequence of strings."""
|
|
3261
|
+
if self.type != _enums.AttributeType.STRINGS:
|
|
3262
|
+
raise TypeError(
|
|
3263
|
+
f"Attribute '{self.name}' is not of type STRINGS. Actual type: {self.type}"
|
|
3264
|
+
)
|
|
3265
|
+
if not isinstance(self.value, Sequence):
|
|
3266
|
+
raise TypeError(f"Value of attribute '{self!r}' is not a Sequence.")
|
|
3267
|
+
if onnx_ir.DEBUG:
|
|
3268
|
+
if not all(isinstance(x, str) for x in self.value):
|
|
3269
|
+
raise TypeError(f"Value of attribute '{self!r}' is not a Sequence of strings.")
|
|
3270
|
+
# Create a copy of the list to prevent mutation
|
|
3271
|
+
return list(self.value)
|
|
3272
|
+
|
|
3273
|
+
def as_tensors(self) -> Sequence[_protocols.TensorProtocol]:
|
|
3274
|
+
"""Get the attribute value as a sequence of tensors."""
|
|
3275
|
+
if self.type != _enums.AttributeType.TENSORS:
|
|
3276
|
+
raise TypeError(
|
|
3277
|
+
f"Attribute '{self.name}' is not of type TENSORS. Actual type: {self.type}"
|
|
3278
|
+
)
|
|
3279
|
+
if not isinstance(self.value, Sequence):
|
|
3280
|
+
raise TypeError(f"Value of attribute '{self!r}' is not a Sequence.")
|
|
3281
|
+
if onnx_ir.DEBUG:
|
|
3282
|
+
if not all(isinstance(x, _protocols.TensorProtocol) for x in self.value):
|
|
3283
|
+
raise TypeError(f"Value of attribute '{self!r}' is not a Sequence of tensors.")
|
|
3284
|
+
# Create a copy of the list to prevent mutation
|
|
3285
|
+
return list(self.value)
|
|
3286
|
+
|
|
3287
|
+
def as_graphs(self) -> Sequence[Graph]:
|
|
3288
|
+
"""Get the attribute value as a sequence of graphs."""
|
|
3289
|
+
if self.type != _enums.AttributeType.GRAPHS:
|
|
3290
|
+
raise TypeError(
|
|
3291
|
+
f"Attribute '{self.name}' is not of type GRAPHS. Actual type: {self.type}"
|
|
3292
|
+
)
|
|
3293
|
+
if not isinstance(self.value, Sequence):
|
|
3294
|
+
raise TypeError(f"Value of attribute '{self!r}' is not a Sequence.")
|
|
3295
|
+
if onnx_ir.DEBUG:
|
|
3296
|
+
if not all(isinstance(x, Graph) for x in self.value):
|
|
3297
|
+
raise TypeError(f"Value of attribute '{self!r}' is not a Sequence of graphs.")
|
|
3298
|
+
# Create a copy of the list to prevent mutation
|
|
3299
|
+
return list(self.value)
|
|
3300
|
+
|
|
2700
3301
|
|
|
2701
3302
|
# NOTE: The following functions are just for convenience
|
|
3303
|
+
|
|
3304
|
+
|
|
3305
|
+
def RefAttr(
|
|
3306
|
+
name: str,
|
|
3307
|
+
ref_attr_name: str,
|
|
3308
|
+
type: _enums.AttributeType,
|
|
3309
|
+
doc_string: str | None = None,
|
|
3310
|
+
) -> Attr:
|
|
3311
|
+
"""Create a reference attribute.
|
|
3312
|
+
|
|
3313
|
+
Args:
|
|
3314
|
+
name: The name of the attribute.
|
|
3315
|
+
type: The type of the attribute.
|
|
3316
|
+
ref_attr_name: The name of the referenced attribute.
|
|
3317
|
+
doc_string: Documentation string.
|
|
3318
|
+
|
|
3319
|
+
Returns:
|
|
3320
|
+
A reference attribute.
|
|
3321
|
+
"""
|
|
3322
|
+
# NOTE: The function name is capitalized to maintain API backward compatibility.
|
|
3323
|
+
return Attr(name, type, None, ref_attr_name=ref_attr_name, doc_string=doc_string)
|
|
3324
|
+
|
|
3325
|
+
|
|
2702
3326
|
def AttrFloat32(name: str, value: float, doc_string: str | None = None) -> Attr:
|
|
2703
3327
|
"""Create a float attribute."""
|
|
2704
3328
|
# NOTE: The function name is capitalized to maintain API backward compatibility.
|