onnx-ir 0.0.1__py3-none-any.whl → 0.1.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of onnx-ir might be problematic.
- onnx_ir/__init__.py +23 -10
- onnx_ir/{_convenience.py → _convenience/__init__.py} +40 -102
- onnx_ir/_convenience/_constructors.py +213 -0
- onnx_ir/_core.py +874 -257
- onnx_ir/_display.py +2 -2
- onnx_ir/_enums.py +107 -5
- onnx_ir/_graph_comparison.py +2 -2
- onnx_ir/_graph_containers.py +373 -0
- onnx_ir/_io.py +57 -10
- onnx_ir/_linked_list.py +15 -7
- onnx_ir/_metadata.py +4 -3
- onnx_ir/_name_authority.py +2 -2
- onnx_ir/_polyfill.py +26 -0
- onnx_ir/_protocols.py +31 -13
- onnx_ir/_tape.py +139 -32
- onnx_ir/_thirdparty/asciichartpy.py +1 -4
- onnx_ir/_type_casting.py +18 -3
- onnx_ir/{_internal/version_utils.py → _version_utils.py} +2 -29
- onnx_ir/convenience.py +4 -2
- onnx_ir/external_data.py +401 -0
- onnx_ir/passes/__init__.py +8 -2
- onnx_ir/passes/_pass_infra.py +173 -56
- onnx_ir/passes/common/__init__.py +40 -0
- onnx_ir/passes/common/_c_api_utils.py +76 -0
- onnx_ir/passes/common/clear_metadata_and_docstring.py +60 -0
- onnx_ir/passes/common/common_subexpression_elimination.py +177 -0
- onnx_ir/passes/common/constant_manipulation.py +217 -0
- onnx_ir/passes/common/inliner.py +332 -0
- onnx_ir/passes/common/onnx_checker.py +57 -0
- onnx_ir/passes/common/shape_inference.py +112 -0
- onnx_ir/passes/common/topological_sort.py +33 -0
- onnx_ir/passes/common/unused_removal.py +196 -0
- onnx_ir/serde.py +288 -124
- onnx_ir/tape.py +15 -0
- onnx_ir/tensor_adapters.py +122 -0
- onnx_ir/testing.py +197 -0
- onnx_ir/traversal.py +4 -3
- onnx_ir-0.1.1.dist-info/METADATA +53 -0
- onnx_ir-0.1.1.dist-info/RECORD +42 -0
- {onnx_ir-0.0.1.dist-info → onnx_ir-0.1.1.dist-info}/WHEEL +1 -1
- onnx_ir-0.1.1.dist-info/licenses/LICENSE +202 -0
- onnx_ir/_external_data.py +0 -323
- onnx_ir-0.0.1.dist-info/LICENSE +0 -22
- onnx_ir-0.0.1.dist-info/METADATA +0 -73
- onnx_ir-0.0.1.dist-info/RECORD +0 -26
- {onnx_ir-0.0.1.dist-info → onnx_ir-0.1.1.dist-info}/top_level.txt +0 -0
onnx_ir/_core.py
CHANGED
@@ -1,5 +1,5 @@
-# Copyright (c)
-#
+# Copyright (c) ONNX Project Contributors
+# SPDX-License-Identifier: Apache-2.0
 """data structures for the intermediate representation."""
 
 # NOTES for developers:
@@ -22,26 +22,36 @@ import os
 import sys
 import textwrap
 import typing
-from
-    AbstractSet,
-    Any,
+from collections.abc import (
     Collection,
-    Generic,
     Hashable,
     Iterable,
     Iterator,
-
+    Mapping,
+    MutableSequence,
     Sequence,
+)
+from collections.abc import (
+    Set as AbstractSet,
+)
+from typing import (
+    Any,
+    Callable,
+    Generic,
+    NamedTuple,
+    SupportsInt,
     Union,
 )
 
 import ml_dtypes
 import numpy as np
+from typing_extensions import TypeIs
 
 import onnx_ir
 from onnx_ir import (
     _display,
     _enums,
+    _graph_containers,
     _linked_list,
     _metadata,
     _name_authority,
@@ -70,6 +80,7 @@ _NON_NUMPY_NATIVE_TYPES = frozenset(
         _enums.DataType.FLOAT8E5M2FNUZ,
         _enums.DataType.INT4,
         _enums.DataType.UINT4,
+        _enums.DataType.FLOAT4E2M1,
     )
 )
 
@@ -93,7 +104,23 @@ def _compatible_with_dlpack(obj: Any) -> TypeGuard[_protocols.DLPackCompatible]:
 class TensorBase(abc.ABC, _protocols.TensorProtocol, _display.PrettyPrintable):
     """Convenience Shared methods for classes implementing TensorProtocol."""
 
-    __slots__ = (
+    __slots__ = (
+        "_doc_string",
+        "_metadata",
+        "_metadata_props",
+        "_name",
+    )
+
+    def __init__(
+        self,
+        name: str | None = None,
+        doc_string: str | None = None,
+        metadata_props: dict[str, str] | None = None,
+    ) -> None:
+        self._metadata: _metadata.MetadataStore | None = None
+        self._metadata_props: dict[str, str] | None = metadata_props
+        self._name: str | None = name
+        self._doc_string: str | None = doc_string
 
     def _printable_type_shape(self) -> str:
         """Return a string representation of the shape and data type."""
@@ -106,10 +133,28 @@ class TensorBase(abc.ABC, _protocols.TensorProtocol, _display.PrettyPrintable):
         """
         return f"{self.__class__.__name__}<{self._printable_type_shape()}>"
 
+    @property
+    def name(self) -> str | None:
+        """The name of the tensor."""
+        return self._name
+
+    @name.setter
+    def name(self, value: str | None) -> None:
+        self._name = value
+
+    @property
+    def doc_string(self) -> str | None:
+        """The documentation string."""
+        return self._doc_string
+
+    @doc_string.setter
+    def doc_string(self, value: str | None) -> None:
+        self._doc_string = value
+
     @property
     def size(self) -> int:
         """The number of elements in the tensor."""
-        return
+        return math.prod(self.shape.numpy())  # type: ignore[attr-defined]
 
     @property
     def nbytes(self) -> int:
@@ -117,6 +162,23 @@ class TensorBase(abc.ABC, _protocols.TensorProtocol, _display.PrettyPrintable):
         # Use math.ceil because when dtype is INT4, the itemsize is 0.5
         return math.ceil(self.dtype.itemsize * self.size)
 
+    @property
+    def metadata_props(self) -> dict[str, str]:
+        if self._metadata_props is None:
+            self._metadata_props = {}
+        return self._metadata_props
+
+    @property
+    def meta(self) -> _metadata.MetadataStore:
+        """The metadata store for intermediate analysis.
+
+        Write to the :attr:`metadata_props` if you would like the metadata to be serialized
+        to the ONNX proto.
+        """
+        if self._metadata is None:
+            self._metadata = _metadata.MetadataStore()
+        return self._metadata
+
     def display(self, *, page: bool = False) -> None:
         rich = _display.require_rich()
 
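The hunks above hoist name/doc_string/metadata handling into TensorBase, so every tensor class now exposes both `metadata_props` (serialized) and `meta` (in-memory only). A minimal sketch of the distinction, assuming onnx-ir 0.1.1 is importable as `onnx_ir` and with made-up keys for illustration:

    import numpy as np
    import onnx_ir as ir

    tensor = ir.Tensor(np.array([1.0, 2.0], dtype=np.float32), name="w")
    tensor.metadata_props["source"] = "fine-tuned"  # serialized into the ONNX proto
    tensor.meta["visited"] = True  # scratch space for analysis passes; never serialized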
@@ -182,7 +244,7 @@ def _check_numpy_representation_type(array: np.ndarray, dtype: _enums.DataType)
     When the dtype is not one of the numpy native dtypes, the value needs need to be:
 
     - ``int8`` or ``uint8`` for int4, with the sign bit extended to 8 bits.
-    - ``uint8`` for uint4.
+    - ``uint8`` for uint4 or float4.
    - ``uint8`` for 8-bit data types.
     - ``uint16`` for bfloat16
 
@@ -195,7 +257,7 @@ def _check_numpy_representation_type(array: np.ndarray, dtype: _enums.DataType)
            )
         if dtype.itemsize == 1 and array.dtype not in (
             np.uint8,
-            ml_dtypes.
+            ml_dtypes.float8_e4m3fnuz,
             ml_dtypes.float8_e4m3fn,
             ml_dtypes.float8_e5m2fnuz,
             ml_dtypes.float8_e5m2,
@@ -213,6 +275,11 @@ def _check_numpy_representation_type(array: np.ndarray, dtype: _enums.DataType)
                raise TypeError(
                     f"The numpy array dtype must be uint8 or or ml_dtypes.uint4 (not {array.dtype}) for IR data type {dtype}."
                 )
+        if dtype == _enums.DataType.FLOAT4E2M1:
+            if array.dtype not in (np.uint8, ml_dtypes.float4_e2m1fn):
+                raise TypeError(
+                    f"The numpy array dtype must be uint8 or ml_dtypes.float4_e2m1fn (not {array.dtype}) for IR data type {dtype}."
+                )
         return
 
     try:
@@ -256,6 +323,8 @@ def _maybe_view_np_array_with_ml_dtypes(
         return array.view(ml_dtypes.int4)
     if dtype == _enums.DataType.UINT4:
         return array.view(ml_dtypes.uint4)
+    if dtype == _enums.DataType.FLOAT4E2M1:
+        return array.view(ml_dtypes.float4_e2m1fn)
     return array
 
 
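Together these hunks make FLOAT4E2M1 a first-class 4-bit dtype. A hedged construction sketch (assumes an ml_dtypes version that ships float4_e2m1fn; per the check above, packed uint8 bytes would also be accepted):

    import ml_dtypes
    import numpy as np
    import onnx_ir as ir

    # Back the tensor with an ml_dtypes.float4_e2m1fn array
    arr = np.array([0.5, 1.0, 1.5], dtype=ml_dtypes.float4_e2m1fn)
    t = ir.Tensor(arr, dtype=ir.DataType.FLOAT4E2M1)
    data = t.tobytes()  # serialized packed, two 4-bit values per byte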
@@ -298,12 +367,8 @@ class Tensor(TensorBase, _protocols.TensorProtocol, Generic[TArrayCompatible]):
 
     __slots__ = (
         "_dtype",
-        "_metadata",
-        "_metadata_props",
         "_raw",
         "_shape",
-        "doc_string",
-        "name",
     )
 
     def __init__(
@@ -322,7 +387,7 @@ class Tensor(TensorBase, _protocols.TensorProtocol, Generic[TArrayCompatible]):
             value: The backing data of the tensor. It can be a numpy array compatible object or a DLPack compatible object.
                 When the dtype is not one of the numpy native dtypes, the value needs
                 to be ``uint8`` for 4-bit and 8-bit data types, and ``uint16`` for bfloat16
-                when the value is a numpy array;
+                when the value is a numpy array; ``dtype`` must be specified in this case.
             dtype: The data type of the tensor. It can be None only when value is a numpy array.
                 Users are responsible for making sure the dtype matches the value when value is not a numpy array.
             shape: The shape of the tensor. If None, the shape is obtained from the value.
@@ -336,6 +401,7 @@ class Tensor(TensorBase, _protocols.TensorProtocol, Generic[TArrayCompatible]):
             ValueError: If the shape is not specified and the value does not have a shape attribute.
             ValueError: If the dtype is not specified and the value is not a numpy array.
         """
+        super().__init__(name=name, doc_string=doc_string, metadata_props=metadata_props)
         # NOTE: We should not do any copying here for performance reasons
         if not _compatible_with_numpy(value) and not _compatible_with_dlpack(value):
             raise TypeError(f"Expected an array compatible object, got {type(value)}")
@@ -349,7 +415,7 @@ class Tensor(TensorBase, _protocols.TensorProtocol, Generic[TArrayCompatible]):
             self._shape = Shape(getattr(value, "shape"), frozen=True)  # noqa: B009
         else:
             self._shape = shape
-            self._shape.
+            self._shape.freeze()
         if dtype is None:
             if isinstance(value, np.ndarray):
                 self._dtype = _enums.DataType.from_numpy(value.dtype)
@@ -370,17 +436,13 @@ class Tensor(TensorBase, _protocols.TensorProtocol, Generic[TArrayCompatible]):
             value = _maybe_view_np_array_with_ml_dtypes(value, self._dtype)  # type: ignore[assignment]
 
         self._raw = value
-        self.name = name
-        self.doc_string = doc_string
-        self._metadata: _metadata.MetadataStore | None = None
-        self._metadata_props = metadata_props
 
     def __array__(self, dtype: Any = None) -> np.ndarray:
         if isinstance(self._raw, np.ndarray) or _compatible_with_numpy(self._raw):
             return self._raw.__array__(dtype)
-        assert _compatible_with_dlpack(
-            self._raw
-        )
+        assert _compatible_with_dlpack(self._raw), (
+            f"Bug: Expected DLPack or Numpy compatible objects, got {type(self._raw)}"
+        )
         return np.from_dlpack(self._raw)
 
     def __dlpack__(self, *, stream: Any = None) -> Any:
@@ -394,7 +456,10 @@ class Tensor(TensorBase, _protocols.TensorProtocol, Generic[TArrayCompatible]):
         return self.__array__().__dlpack_device__()
 
     def __repr__(self) -> str:
-
+        # Avoid multi-line repr
+        tensor_lines = repr(self._raw).split("\n")
+        tensor_text = " ".join(line.strip() for line in tensor_lines)
+        return f"{self._repr_base()}({tensor_text}, name={self.name!r})"
 
     @property
     def dtype(self) -> _enums.DataType:
@@ -431,7 +496,11 @@ class Tensor(TensorBase, _protocols.TensorProtocol, Generic[TArrayCompatible]):
         """
         # TODO(justinchuby): Support DLPack
         array = self.numpy()
-        if self.dtype in {
+        if self.dtype in {
+            _enums.DataType.INT4,
+            _enums.DataType.UINT4,
+            _enums.DataType.FLOAT4E2M1,
+        }:
             # Pack the array into int4
             array = _type_casting.pack_int4(array)
         else:
@@ -440,23 +509,6 @@ class Tensor(TensorBase, _protocols.TensorProtocol, Generic[TArrayCompatible]):
             array = array.view(array.dtype.newbyteorder("<"))
         return array.tobytes()
 
-    @property
-    def metadata_props(self) -> dict[str, str]:
-        if self._metadata_props is None:
-            self._metadata_props = {}
-        return self._metadata_props
-
-    @property
-    def meta(self) -> _metadata.MetadataStore:
-        """The metadata store for intermediate analysis.
-
-        Write to the :attr:`metadata_props` if you would like the metadata to be serialized
-        to the ONNX proto.
-        """
-        if self._metadata is None:
-            self._metadata = _metadata.MetadataStore()
-        return self._metadata
-
 
 class ExternalTensor(TensorBase, _protocols.TensorProtocol):  # pylint: disable=too-many-ancestors
     """An immutable concrete tensor with its data store on disk.
@@ -497,12 +549,9 @@ class ExternalTensor(TensorBase, _protocols.TensorProtocol):  # pylint: disable=
         "_dtype",
         "_length",
         "_location",
-        "_metadata",
-        "_metadata_props",
         "_offset",
         "_shape",
-        "
-        "name",
+        "_valid",
         "raw",
     )
 
@@ -532,6 +581,7 @@ class ExternalTensor(TensorBase, _protocols.TensorProtocol):  # pylint: disable=
             metadata_props: The metadata properties.
             base_dir: The base directory for the external data. It is used to resolve relative paths.
         """
+        super().__init__(name=name, doc_string=doc_string, metadata_props=metadata_props)
         # NOTE: Do not verify the location by default. This is because the location field
         # in the tensor proto can be anything and we would like deserialization from
         # proto to IR to not fail.
@@ -547,12 +597,13 @@ class ExternalTensor(TensorBase, _protocols.TensorProtocol):  # pylint: disable=
         self._dtype: _enums.DataType = dtype
         self.name: str = name  # mutable
         self._shape: Shape = shape
-        self._shape.
+        self._shape.freeze()
         self.doc_string: str | None = doc_string  # mutable
         self._array: np.ndarray | None = None
         self.raw: mmap.mmap | None = None
         self._metadata_props = metadata_props
         self._metadata: _metadata.MetadataStore | None = None
+        self._valid = True
 
     @property
     def base_dir(self) -> str | os.PathLike:
@@ -594,6 +645,7 @@ class ExternalTensor(TensorBase, _protocols.TensorProtocol):  # pylint: disable=
         return self._shape
 
     def _load(self):
+        self._check_validity()
         assert self._array is None, "Bug: The array should be loaded only once."
         if self.size == 0:
             # When the size is 0, mmap is impossible and meaningless
@@ -609,7 +661,11 @@ class ExternalTensor(TensorBase, _protocols.TensorProtocol):  # pylint: disable=
         )
         # Handle the byte order correctly by always using little endian
         dt = np.dtype(self.dtype.numpy()).newbyteorder("<")
-        if self.dtype in {
+        if self.dtype in {
+            _enums.DataType.INT4,
+            _enums.DataType.UINT4,
+            _enums.DataType.FLOAT4E2M1,
+        }:
             # Use uint8 to read in the full byte. Otherwise ml_dtypes.int4 will clip the values
             dt = np.dtype(np.uint8).newbyteorder("<")
             count = self.size // 2 + self.size % 2
@@ -622,10 +678,13 @@ class ExternalTensor(TensorBase, _protocols.TensorProtocol):  # pylint: disable=
                 self._array = _type_casting.unpack_int4(self._array, shape)
             elif self.dtype == _enums.DataType.UINT4:
                 self._array = _type_casting.unpack_uint4(self._array, shape)
+            elif self.dtype == _enums.DataType.FLOAT4E2M1:
+                self._array = _type_casting.unpack_float4e2m1(self._array, shape)
         else:
             self._array = self._array.reshape(shape)
 
     def __array__(self, dtype: Any = None) -> np.ndarray:
+        self._check_validity()
         if self._array is None:
             self._load()
         assert self._array is not None
@@ -654,6 +713,7 @@ class ExternalTensor(TensorBase, _protocols.TensorProtocol):  # pylint: disable=
 
         The data will be memory mapped into memory and will not taken up physical memory space.
         """
+        self._check_validity()
         if self._array is None:
             self._load()
         assert self._array is not None
@@ -664,6 +724,7 @@ class ExternalTensor(TensorBase, _protocols.TensorProtocol):  # pylint: disable=
 
         This will load the tensor into memory.
         """
+        self._check_validity()
         if self.raw is None:
             self._load()
         assert self.raw is not None
@@ -671,6 +732,26 @@ class ExternalTensor(TensorBase, _protocols.TensorProtocol):  # pylint: disable=
         length = self._length or self.nbytes
         return self.raw[offset : offset + length]
 
+    def valid(self) -> bool:
+        """Check if the tensor is valid.
+
+        The external tensor is valid if it has not been invalidated.
+        """
+        return self._valid
+
+    def _check_validity(self) -> None:
+        if not self.valid():
+            raise ValueError(
+                f"The external tensor '{self!r}' is invalidated. The data may be corrupted or deleted."
+            )
+
+    def invalidate(self) -> None:
+        """Invalidate the tensor.
+
+        The external tensor is invalidated when the data is known to be corrupted or deleted.
+        """
+        self._valid = False
+
     def release(self) -> None:
         """Delete all references to the memory buffer and close the memory-mapped file."""
         self._array = None
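The validity flag gives file-rewriting passes (such as the new external_data.py helpers) a way to poison stale handles rather than let them read overwritten bytes. A hedged sketch, where `ext_tensor` is assumed to be an existing ir.ExternalTensor whose backing file was just rewritten:

    ext_tensor.invalidate()        # mark the old handle unusable
    assert not ext_tensor.valid()
    ext_tensor.numpy()             # now raises ValueError instead of returning stale data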
@@ -678,34 +759,13 @@ class ExternalTensor(TensorBase, _protocols.TensorProtocol):  # pylint: disable=
             self.raw.close()
         self.raw = None
 
-    @property
-    def metadata_props(self) -> dict[str, str]:
-        if self._metadata_props is None:
-            self._metadata_props = {}
-        return self._metadata_props
-
-    @property
-    def meta(self) -> _metadata.MetadataStore:
-        """The metadata store for intermediate analysis.
-
-        Write to the :attr:`metadata_props` if you would like the metadata to be serialized
-        to the ONNX proto.
-        """
-        if self._metadata is None:
-            self._metadata = _metadata.MetadataStore()
-        return self._metadata
-
 
 class StringTensor(TensorBase, _protocols.TensorProtocol):  # pylint: disable=too-many-ancestors
     """Multidimensional array of strings (as binary data to match the string_data field in TensorProto)."""
 
     __slots__ = (
-        "_metadata",
-        "_metadata_props",
         "_raw",
         "_shape",
-        "doc_string",
-        "name",
     )
 
     def __init__(
@@ -726,6 +786,7 @@ class StringTensor(TensorBase, _protocols.TensorProtocol):  # pylint: disable=to
             doc_string: The documentation string.
             metadata_props: The metadata properties.
         """
+        super().__init__(name=name, doc_string=doc_string, metadata_props=metadata_props)
         if shape is None:
             if not hasattr(value, "shape"):
                 raise ValueError(
@@ -735,19 +796,15 @@ class StringTensor(TensorBase, _protocols.TensorProtocol):  # pylint: disable=to
             self._shape = Shape(getattr(value, "shape"), frozen=True)  # noqa: B009
         else:
             self._shape = shape
-            self._shape.
+            self._shape.freeze()
         self._raw = value
-        self.name = name
-        self.doc_string = doc_string
-        self._metadata: _metadata.MetadataStore | None = None
-        self._metadata_props = metadata_props
 
     def __array__(self, dtype: Any = None) -> np.ndarray:
         if isinstance(self._raw, np.ndarray):
             return self._raw
-        assert isinstance(
-            self._raw
-        )
+        assert isinstance(self._raw, Sequence), (
+            f"Bug: Expected a sequence, got {type(self._raw)}"
+        )
         return np.array(self._raw, dtype=dtype).reshape(self.shape.numpy())
 
     def __dlpack__(self, *, stream: Any = None) -> Any:
@@ -788,25 +845,125 @@ class StringTensor(TensorBase, _protocols.TensorProtocol):  # pylint: disable=to
         return self._raw.flatten().tolist()
         return self._raw
 
+
+class LazyTensor(TensorBase, _protocols.TensorProtocol):  # pylint: disable=too-many-ancestors
+    """A tensor that lazily evaluates a function to get the actual tensor.
+
+    This class takes a function returning an `ir.TensorProtocol`, a dtype, and a shape argument.
+    The function is lazily evaluated to get the actual tensor when `tobytes()` or `numpy()` is called.
+
+    Example::
+
+        >>> import numpy as np
+        >>> import onnx_ir as ir
+        >>> weights = np.array([[1, 2, 3]])
+        >>> def create_tensor():  # Delay applying transformations to the weights
+        ...     weights_t = weights.transpose()
+        ...     return ir.tensor(weights_t)
+        >>> lazy_tensor = ir.LazyTensor(create_tensor, dtype=ir.DataType.INT64, shape=ir.Shape([1, 3]))
+        >>> print(lazy_tensor.numpy())
+        [[1]
+         [2]
+         [3]]
+
+    Attributes:
+        func: The function that returns the actual tensor.
+        dtype: The data type of the tensor.
+        shape: The shape of the tensor.
+        cache: Whether to cache the result of the function. If False,
+            the function is called every time the tensor content is accessed.
+            If True, the function is called only once and the result is cached in memory.
+            Default is False.
+        name: The name of the tensor.
+        doc_string: The documentation string.
+        metadata_props: The metadata properties.
+    """
+
+    __slots__ = (
+        "_dtype",
+        "_func",
+        "_shape",
+        "_tensor",
+        "cache",
+    )
+
+    def __init__(
+        self,
+        func: Callable[[], _protocols.TensorProtocol],
+        dtype: _enums.DataType,
+        shape: Shape,
+        *,
+        cache: bool = False,
+        name: str | None = None,
+        doc_string: str | None = None,
+        metadata_props: dict[str, str] | None = None,
+    ) -> None:
+        """Initialize a lazy tensor.
+
+        Args:
+            func: The function that returns the actual tensor.
+            dtype: The data type of the tensor.
+            shape: The shape of the tensor.
+            cache: Whether to cache the result of the function.
+            name: The name of the tensor.
+            doc_string: The documentation string.
+            metadata_props: The metadata properties.
+        """
+        super().__init__(name=name, doc_string=doc_string, metadata_props=metadata_props)
+        self._func = func
+        self._dtype = dtype
+        self._shape = shape
+        self._tensor: _protocols.TensorProtocol | None = None
+        self.cache = cache
+
+    def _evaluate(self) -> _protocols.TensorProtocol:
+        """Evaluate the function to get the actual tensor."""
+        if not self.cache:
+            return self._func()
+
+        # Cache the tensor
+        if self._tensor is None:
+            self._tensor = self._func()
+        return self._tensor
+
+    def __array__(self, dtype: Any = None) -> np.ndarray:
+        return self._evaluate().__array__(dtype)
+
+    def __dlpack__(self, *, stream: Any = None) -> Any:
+        return self._evaluate().__dlpack__(stream=stream)
+
+    def __dlpack_device__(self) -> tuple[int, int]:
+        return self._evaluate().__dlpack_device__()
+
+    def __repr__(self) -> str:
+        return f"{self._repr_base()}(func={self._func!r}, name={self.name!r})"
+
     @property
-    def
-
-        self._metadata_props = {}
-        return self._metadata_props
+    def raw(self) -> Callable[[], _protocols.TensorProtocol]:
+        return self._func
 
     @property
-    def
-        """The
+    def dtype(self) -> _enums.DataType:
+        """The data type of the tensor. Immutable."""
+        return self._dtype
 
-
-
-        """
-
-
-
+    @property
+    def shape(self) -> Shape:
+        """The shape of the tensor. Immutable."""
+        return self._shape
+
+    def numpy(self) -> np.ndarray:
+        """Return the tensor as a numpy array."""
+        return self._evaluate().numpy()
+
+    def tobytes(self) -> bytes:
+        """Return the bytes of the tensor."""
+        return self._evaluate().tobytes()
 
 
 class SymbolicDim(_protocols.SymbolicDimProtocol, _display.PrettyPrintable):
+    """Immutable symbolic dimension that can be shared across multiple shapes."""
+
     __slots__ = ("_value",)
 
     def __init__(self, value: str | None) -> None:
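The docstring above demonstrates the default cache=False path, where the factory runs on every access. A sketch of the cache=True path (the call counter is purely illustrative):

    import numpy as np
    import onnx_ir as ir

    calls = 0

    def make_tensor():
        global calls
        calls += 1
        return ir.tensor(np.zeros((2, 2), dtype=np.float32))

    lazy = ir.LazyTensor(make_tensor, dtype=ir.DataType.FLOAT, shape=ir.Shape([2, 2]), cache=True)
    lazy.numpy()
    lazy.tobytes()
    assert calls == 1  # with the default cache=False this would be 2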
@@ -841,12 +998,84 @@ class SymbolicDim(_protocols.SymbolicDimProtocol, _display.PrettyPrintable):
         return f"{self.__class__.__name__}({self._value})"
 
 
+def _is_int_compatible(value: object) -> TypeIs[SupportsInt]:
+    """Return True if the value is int compatible."""
+    if isinstance(value, int):
+        return True
+    if hasattr(value, "__int__"):
+        # For performance reasons, we do not use isinstance(value, SupportsInt)
+        return True
+    return False
+
+
+def _maybe_convert_to_symbolic_dim(
+    dim: int | SupportsInt | SymbolicDim | str | None,
+) -> SymbolicDim | int:
+    """Convert the value to a SymbolicDim if it is not an int."""
+    if dim is None or isinstance(dim, str):
+        return SymbolicDim(dim)
+    if _is_int_compatible(dim):
+        return int(dim)
+    if isinstance(dim, SymbolicDim):
+        return dim
+    raise TypeError(
+        f"Expected int, str, None or SymbolicDim, but value {dim!r} has type '{type(dim)}'"
+    )
+
+
 class Shape(_protocols.ShapeProtocol, _display.PrettyPrintable):
+    """The shape of a tensor, including its dimensions and optional denotations.
+
+    The :class:`Shape` stores the dimensions of a tensor, which can be integers, None (unknown), or
+    symbolic dimensions.
+
+    A shape can be compared to another shape or plain Python list.
+
+    A shape can be frozen (made immutable). When the shape is frozen, it cannot be
+    unfrozen, making it suitable to be shared across tensors or values.
+    Call :method:`freeze` to freeze the shape.
+
+    To update the dimension of a frozen shape, call :method:`copy` to create a
+    new shape with the same dimensions that can be modified.
+
+    Use :method:`get_denotation` and :method:`set_denotation` to access and modify the denotations.
+
+    Example::
+
+        >>> import onnx_ir as ir
+        >>> shape = ir.Shape(["B", None, 3])
+        >>> shape.rank()
+        3
+        >>> shape.is_static()
+        False
+        >>> shape.is_dynamic()
+        True
+        >>> shape.is_static(dim=2)
+        True
+        >>> shape[0] = 1
+        >>> shape[1] = 2
+        >>> shape.dims
+        (1, 2, 3)
+        >>> shape == [1, 2, 3]
+        True
+        >>> shape.frozen
+        False
+        >>> shape.freeze()
+        >>> shape.frozen
+        True
+
+    Attributes:
+        dims: A tuple of dimensions representing the shape.
+            Each dimension can be an integer, None or a :class:`SymbolicDim`.
+        frozen: Indicates whether the shape is immutable. When frozen, the shape
+            cannot be modified or unfrozen.
+    """
+
     __slots__ = ("_dims", "_frozen")
 
     def __init__(
         self,
-        dims: Iterable[int | SymbolicDim | str | None],
+        dims: Iterable[int | SupportsInt | SymbolicDim | str | None],
         /,
         denotations: Iterable[str | None] | None = None,
         frozen: bool = False,
@@ -864,11 +1093,11 @@ class Shape(_protocols.ShapeProtocol, _display.PrettyPrintable):
                 Refer to https://github.com/onnx/onnx/blob/main/docs/DimensionDenotation.md#denotation-definition
                 for pre-defined dimension denotations.
             frozen: If True, the shape is immutable and cannot be modified. This
-                is useful when the shape is initialized by a Tensor
+                is useful when the shape is initialized by a Tensor or when the shape
+                is shared across multiple tensors. The default is False.
         """
         self._dims: list[int | SymbolicDim] = [
-
-            for dim in dims
+            _maybe_convert_to_symbolic_dim(dim) for dim in dims
         ]
         self._denotations: list[str | None] = (
             list(denotations) if denotations is not None else [None] * len(self._dims)
@@ -879,10 +1108,6 @@ class Shape(_protocols.ShapeProtocol, _display.PrettyPrintable):
         )
         self._frozen: bool = frozen
 
-    def copy(self):
-        """Return a copy of the shape."""
-        return Shape(self._dims, self._denotations, self._frozen)
-
     @property
     def dims(self) -> tuple[int | SymbolicDim, ...]:
         """All dimensions in the shape.
@@ -891,8 +1116,29 @@ class Shape(_protocols.ShapeProtocol, _display.PrettyPrintable):
         """
         return tuple(self._dims)
 
+    @property
+    def frozen(self) -> bool:
+        """Whether the shape is frozen.
+
+        When the shape is frozen, it cannot be unfrozen, making it suitable to be shared.
+        Call :method:`freeze` to freeze the shape. Call :method:`copy` to create a
+        new shape with the same dimensions that can be modified.
+        """
+        return self._frozen
+
+    def freeze(self) -> None:
+        """Freeze the shape.
+
+        When the shape is frozen, it cannot be unfrozen, making it suitable to be shared.
+        """
+        self._frozen = True
+
+    def copy(self, frozen: bool = False):
+        """Return a copy of the shape."""
+        return Shape(self._dims, self._denotations, frozen=frozen)
+
     def rank(self) -> int:
-        """The rank of the shape."""
+        """The rank of the tensor this shape represents."""
         return len(self._dims)
 
     def numpy(self) -> tuple[int, ...]:
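Freezing is one-way by design; the escape hatch is copy(), which now defaults to an unfrozen copy. A minimal sketch:

    import onnx_ir as ir

    shape = ir.Shape(["N", 3])
    shape.freeze()             # safe to share across values now
    # shape[0] = 1             # would raise TypeError: the shape is frozen
    editable = shape.copy()    # copy() defaults to frozen=False
    editable[0] = 1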
@@ -928,12 +1174,8 @@ class Shape(_protocols.ShapeProtocol, _display.PrettyPrintable):
         """
         if self._frozen:
             raise TypeError("The shape is frozen and cannot be modified.")
-        if isinstance(value, str) or value is None:
-            value = SymbolicDim(value)
-        if not isinstance(value, (int, SymbolicDim)):
-            raise TypeError(f"Expected int, str, None or SymbolicDim, got '{type(value)}'")
 
-        self._dims[index] = value
+        self._dims[index] = _maybe_convert_to_symbolic_dim(value)
 
     def get_denotation(self, index: int) -> str | None:
         """Return the denotation of the dimension at the index.
@@ -968,7 +1210,7 @@ class Shape(_protocols.ShapeProtocol, _display.PrettyPrintable):
     def __eq__(self, other: object) -> bool:
         """Return True if the shapes are equal.
 
-        Two shapes are
+        Two shapes are equal if all their dimensions are equal.
         """
         if isinstance(other, Shape):
             return self._dims == other._dims
@@ -979,6 +1221,33 @@ class Shape(_protocols.ShapeProtocol, _display.PrettyPrintable):
     def __ne__(self, other: object) -> bool:
         return not self.__eq__(other)
 
+    @typing.overload
+    def is_static(self, dim: int) -> bool:  # noqa: D418
+        """Return True if the dimension is static."""
+
+    @typing.overload
+    def is_static(self) -> bool:  # noqa: D418
+        """Return True if all dimensions are static."""
+
+    def is_static(self, dim=None) -> bool:
+        """Return True if the dimension is static. If dim is None, return True if all dimensions are static."""
+        if dim is None:
+            return all(isinstance(dim, int) for dim in self._dims)
+        return isinstance(self[dim], int)
+
+    @typing.overload
+    def is_dynamic(self, dim: int) -> bool:  # noqa: D418
+        """Return True if the dimension is dynamic."""
+
+    @typing.overload
+    def is_dynamic(self) -> bool:  # noqa: D418
+        """Return True if any dimension is dynamic."""
+
+    def is_dynamic(self, dim=None) -> bool:
+        if dim is None:
+            return not self.is_static()
+        return not self.is_static(dim)
+
 
 def _quoted(string: str) -> str:
     """Return a quoted string.
@@ -988,6 +1257,35 @@ def _quoted(string: str) -> str:
     return f'"{string}"'
 
 
+class Usage(NamedTuple):
+    """A usage of a value in a node.
+
+    Attributes:
+        node: The node that uses the value.
+        idx: The input index of the value in the node.
+    """
+
+    node: Node
+    idx: int
+
+
+def _short_tensor_str_for_node(x: Value) -> str:
+    if x.const_value is None:
+        return ""
+    if x.const_value.size <= 10:
+        try:
+            data = x.const_value.numpy().tolist()
+        except Exception:  # pylint: disable=broad-except
+            return "{...}"
+        return f"{{{data}}}"
+    return "{...}"
+
+
+def _normalize_domain(domain: str) -> str:
+    """Normalize 'ai.onnx' to ''."""
+    return "" if domain == "ai.onnx" else domain
+
+
 class Node(_protocols.NodeProtocol, _display.PrettyPrintable):
     """IR Node.
 
@@ -1001,6 +1299,9 @@ class Node(_protocols.NodeProtocol, _display.PrettyPrintable):
     To change the output values, create a new node and replace the each of the inputs of ``output.uses()`` with
     the new output values by calling :meth:`replace_input_with` on the using nodes
     of this node's outputs.
+
+    .. note:
+        When the ``domain`` is `"ai.onnx"`, it is normalized to `""`.
     """
 
     __slots__ = (
@@ -1023,13 +1324,13 @@ class Node(_protocols.NodeProtocol, _display.PrettyPrintable):
         domain: str,
         op_type: str,
         inputs: Iterable[Value | None],
-        attributes: Iterable[Attr |
+        attributes: Iterable[Attr] | Mapping[str, Attr] = (),
         *,
         overload: str = "",
         num_outputs: int | None = None,
         outputs: Sequence[Value] | None = None,
         version: int | None = None,
-        graph: Graph | None = None,
+        graph: Graph | Function | None = None,
         name: str | None = None,
         doc_string: str | None = None,
         metadata_props: dict[str, str] | None = None,
@@ -1038,27 +1339,30 @@ class Node(_protocols.NodeProtocol, _display.PrettyPrintable):
 
         Args:
             domain: The domain of the operator. For onnx operators, this is an empty string.
+                When it is `"ai.onnx"`, it is normalized to `""`.
             op_type: The name of the operator.
-            inputs: The input values. When an input is None
+            inputs: The input values. When an input is ``None``, it is an empty input.
             attributes: The attributes. RefAttr can be used only when the node is defined in a Function.
             overload: The overload name when the node is invoking a function.
             num_outputs: The number of outputs of the node. If not specified, the number is 1.
-            outputs: The output values. If None
-            version: The version of the operator. If None
-            graph: The graph that the node belongs to. If None
-                A `Node` must belong to zero or one graph.
-
+            outputs: The output values. If ``None``, the outputs are created during initialization.
+            version: The version of the operator. If ``None``, the version is unspecified and will follow that of the graph.
+            graph: The graph that the node belongs to. If ``None``, the node is not added to any graph.
+                A `Node` must belong to zero or one graph. If a :class:`Function`, the underlying graph
+                of the function is assigned to the node.
+            name: The name of the node. If ``None``, the node is anonymous. The name may be
+                set by a :class:`Graph` if ``graph`` is specified.
             doc_string: The documentation string.
             metadata_props: The metadata properties.
 
         Raises:
-            TypeError: If the attributes are not Attr
-            ValueError: If
-            ValueError: If an output value is None
+            TypeError: If the attributes are not :class:`Attr`.
+            ValueError: If ``num_outputs``, when not ``None``, is not the same as the length of the outputs.
+            ValueError: If an output value is ``None``, when outputs is specified.
             ValueError: If an output value has a producer set already, when outputs is specified.
         """
         self._name = name
-        self._domain: str = domain
+        self._domain: str = _normalize_domain(domain)
         self._op_type: str = op_type
         # NOTE: Make inputs immutable with the assumption that they are not mutated
         # very often. This way all mutations can be tracked.
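With _normalize_domain applied in the constructor and the domain setter, callers no longer need to remember that the default ONNX domain is spelled "". A sketch (the value names are hypothetical):

    import onnx_ir as ir

    a = ir.Value(name="a")
    b = ir.Value(name="b")
    node = ir.Node("ai.onnx", "Add", inputs=(a, b))
    assert node.domain == ""  # "ai.onnx" normalized to the default domain
    assert node.op_identifier() == ("", "Add", "")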
@@ -1066,22 +1370,21 @@ class Node(_protocols.NodeProtocol, _display.PrettyPrintable):
         self._inputs: tuple[Value | None, ...] = tuple(inputs)
         # Values belong to their defining nodes. The values list is immutable
         self._outputs: tuple[Value, ...] = self._create_outputs(num_outputs, outputs)
-
-
-
-
-            "If you are copying the attributes from another node, make sure you call "
-            "node.attributes.values() because it is a dictionary."
-        )
-        self._attributes: OrderedDict[str, Attr | RefAttr] = OrderedDict(
-            (attr.name, attr) for attr in attributes
+        if isinstance(attributes, Mapping):
+            attributes = tuple(attributes.values())
+        self._attributes: _graph_containers.Attributes = _graph_containers.Attributes(
+            attributes
         )
         self._overload: str = overload
         # TODO(justinchuby): Potentially support a version range
         self._version: int | None = version
         self._metadata: _metadata.MetadataStore | None = None
         self._metadata_props: dict[str, str] | None = metadata_props
-
+        # _graph is set by graph.append
+        self._graph: Graph | None = None
+        # Add the node to the graph if graph is specified
+        if graph is not None:
+            graph.append(self)
         self.doc_string = doc_string
 
         # Add the node as a use of the inputs
@@ -1089,10 +1392,6 @@ class Node(_protocols.NodeProtocol, _display.PrettyPrintable):
             if input_value is not None:
                 input_value._add_usage(self, i)  # pylint: disable=protected-access
 
-        # Add the node to the graph if graph is specified
-        if self._graph is not None:
-            self._graph.append(self)
-
     def _create_outputs(
         self, num_outputs: int | None, outputs: Sequence[Value] | None
     ) -> tuple[Value, ...]:
@@ -1150,7 +1449,7 @@ class Node(_protocols.NodeProtocol, _display.PrettyPrintable):
             + ", ".join(
                 [
                     (
-                        f"%{_quoted(x.name) if x.name else 'anonymous:' + str(id(x))}"
+                        f"%{_quoted(x.name) if x.name else 'anonymous:' + str(id(x))}{_short_tensor_str_for_node(x)}"
                         if x is not None
                         else "None"
                     )
@@ -1178,6 +1477,7 @@ class Node(_protocols.NodeProtocol, _display.PrettyPrintable):
 
     @property
     def name(self) -> str | None:
+        """Optional name of the node."""
         return self._name
 
     @name.setter
@@ -1186,14 +1486,26 @@ class Node(_protocols.NodeProtocol, _display.PrettyPrintable):
 
     @property
     def domain(self) -> str:
+        """The domain of the operator. For onnx operators, this is an empty string.
+
+        .. note:
+            When domain is `"ai.onnx"`, it is normalized to `""`.
+        """
         return self._domain
 
     @domain.setter
     def domain(self, value: str) -> None:
-        self._domain = value
+        self._domain = _normalize_domain(value)
 
     @property
     def version(self) -> int | None:
+        """Opset version of the operator called.
+
+        If ``None``, the version is unspecified and will follow that of the graph.
+        This property is special to ONNX IR to allow mixed opset usage in a graph
+        for supporting more flexible graph transformations. It does not exist in the ONNX
+        serialization (protobuf) spec.
+        """
         return self._version
 
     @version.setter
@@ -1202,6 +1514,7 @@ class Node(_protocols.NodeProtocol, _display.PrettyPrintable):
 
     @property
     def op_type(self) -> str:
+        """The name of the operator called."""
         return self._op_type
 
     @op_type.setter
@@ -1210,6 +1523,7 @@ class Node(_protocols.NodeProtocol, _display.PrettyPrintable):
 
     @property
     def overload(self) -> str:
+        """The overload name when the node is invoking a function."""
         return self._overload
 
     @overload.setter
@@ -1218,6 +1532,12 @@ class Node(_protocols.NodeProtocol, _display.PrettyPrintable):
 
     @property
     def inputs(self) -> Sequence[Value | None]:
+        """The input values of the node.
+
+        The inputs are immutable. To change the inputs, create a new node and
+        replace the inputs of the using nodes of this node's outputs by calling
+        :meth:`replace_input_with` on the using nodes of this node's outputs.
+        """
         return self._inputs
 
     @inputs.setter
@@ -1226,6 +1546,25 @@ class Node(_protocols.NodeProtocol, _display.PrettyPrintable):
             "Directly mutating the input sequence is unsupported. Please use Node.replace_input_with() instead."
         )
 
+    def predecessors(self) -> Sequence[Node]:
+        """Return the predecessor nodes of the node, deduplicated, in a deterministic order."""
+        # Use the ordered nature of a dictionary to deduplicate the nodes
+        predecessors: dict[Node, None] = {}
+        for value in self.inputs:
+            if value is not None and (producer := value.producer()) is not None:
+                predecessors[producer] = None
+        return tuple(predecessors)
+
+    def successors(self) -> Sequence[Node]:
+        """Return the successor nodes of the node, deduplicated, in a deterministic order."""
+        # Use the ordered nature of a dictionary to deduplicate the nodes
+        successors: dict[Node, None] = {}
+        for value in self.outputs:
+            assert value is not None, "Bug: Output values are not expected to be None"
+            for usage in value.uses():
+                successors[usage.node] = None
+        return tuple(successors)
+
     def replace_input_with(self, index: int, value: Value | None) -> None:
         """Replace an input with a new value."""
         if index < 0 or index >= len(self.inputs):
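predecessors() and successors() expose direct dataflow neighbors without manual bookkeeping over inputs and uses. A sketch, assuming `node` is an ir.Node that belongs to a graph:

    def neighborhood(node: ir.Node) -> list[ir.Node]:
        # Producers of this node's inputs plus consumers of its outputs,
        # each deduplicated and returned in deterministic order.
        return [*node.predecessors(), *node.successors()]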
@@ -1279,6 +1618,12 @@ class Node(_protocols.NodeProtocol, _display.PrettyPrintable):
|
|
|
1279
1618
|
|
|
1280
1619
|
@property
|
|
1281
1620
|
def outputs(self) -> Sequence[Value]:
|
|
1621
|
+
"""The output values of the node.
|
|
1622
|
+
|
|
1623
|
+
The outputs are immutable. To change the outputs, create a new node and
|
|
1624
|
+
replace the inputs of the using nodes of this node's outputs by calling
|
|
1625
|
+
:meth:`replace_input_with` on the using nodes of this node's outputs.
|
|
1626
|
+
"""
|
|
1282
1627
|
return self._outputs
|
|
1283
1628
|
|
|
1284
1629
|
@outputs.setter
|
|
@@ -1286,7 +1631,8 @@ class Node(_protocols.NodeProtocol, _display.PrettyPrintable):
|
|
|
1286
1631
|
raise AttributeError("outputs is immutable. Please create a new node instead.")
|
|
1287
1632
|
|
|
1288
1633
|
@property
|
|
1289
|
-
def attributes(self) ->
|
|
1634
|
+
def attributes(self) -> _graph_containers.Attributes:
|
|
1635
|
+
"""The attributes of the node."""
|
|
1290
1636
|
return self._attributes
|
|
1291
1637
|
|
|
1292
1638
|
@property
|
|
@@ -1302,12 +1648,21 @@ class Node(_protocols.NodeProtocol, _display.PrettyPrintable):
|
|
|
1302
1648
|
|
|
1303
1649
|
@property
|
|
1304
1650
|
def metadata_props(self) -> dict[str, str]:
|
|
1651
|
+
"""The metadata properties of the node.
|
|
1652
|
+
|
|
1653
|
+
The metadata properties are used to store additional information about the node.
|
|
1654
|
+
Unlike ``meta``, this property is serialized to the ONNX proto.
|
|
1655
|
+
"""
|
|
1305
1656
|
if self._metadata_props is None:
|
|
1306
1657
|
self._metadata_props = {}
|
|
1307
1658
|
return self._metadata_props
|
|
1308
1659
|
|
|
1309
1660
|
@property
|
|
1310
1661
|
def graph(self) -> Graph | None:
|
|
1662
|
+
"""The graph that the node belongs to.
|
|
1663
|
+
|
|
1664
|
+
If the node is not added to any graph, this property is None.
|
|
1665
|
+
"""
|
|
1311
1666
|
return self._graph
|
|
1312
1667
|
|
|
1313
1668
|
@graph.setter
|
|
@@ -1315,9 +1670,17 @@ class Node(_protocols.NodeProtocol, _display.PrettyPrintable):
|
|
|
1315
1670
|
self._graph = value
|
|
1316
1671
|
|
|
1317
1672
|
def op_identifier(self) -> _protocols.OperatorIdentifier:
|
|
1673
|
+
"""Return the operator identifier of the node.
|
|
1674
|
+
|
|
1675
|
+
The operator identifier is a tuple of the domain, op_type and overload.
|
|
1676
|
+
"""
|
|
1318
1677
|
return self.domain, self.op_type, self.overload
|
|
1319
1678
|
|
|
1320
1679
|
def display(self, *, page: bool = False) -> None:
|
|
1680
|
+
"""Pretty print the node.
|
|
1681
|
+
|
|
1682
|
+
This method is used for debugging and visualization purposes.
|
|
1683
|
+
"""
|
|
1321
1684
|
# Add the node's name to the displayed text
|
|
1322
1685
|
print(f"Node: {self.name!r}")
|
|
1323
1686
|
if self.doc_string:
|
|
@@ -1344,7 +1707,7 @@ class _TensorTypeBase(_protocols.TypeProtocol, _display.PrettyPrintable, Hashabl
|
|
|
1344
1707
|
|
|
1345
1708
|
@property
|
|
1346
1709
|
def elem_type(self) -> _enums.DataType:
|
|
1347
|
-
"""Return the element type of the tensor type"""
|
|
1710
|
+
"""Return the element type of the tensor type."""
|
|
1348
1711
|
return self.dtype
|
|
1349
1712
|
|
|
1350
1713
|
def __hash__(self) -> int:
|
|
@@ -1438,18 +1801,19 @@ class Value(_protocols.ValueProtocol, _display.PrettyPrintable):
|
|
|
1438
1801
|
|
|
1439
1802
|
To find all the nodes that use this value as an input, call :meth:`uses`.
|
|
1440
1803
|
|
|
1441
|
-
To check if the value is an output of a graph,
|
|
1804
|
+
To check if the value is an is an input, output or initializer of a graph,
|
|
1805
|
+
use :meth:`is_graph_input`, :meth:`is_graph_output` or :meth:`is_initializer`.
|
|
1442
1806
|
|
|
1443
|
-
|
|
1444
|
-
name: The name of the value. A value is always named when it is part of a graph.
|
|
1445
|
-
shape: The shape of the value.
|
|
1446
|
-
type: The type of the value.
|
|
1447
|
-
metadata_props: Metadata.
|
|
1807
|
+
Use :meth:`graph` to get the graph that owns the value.
|
|
1448
1808
|
"""
|
|
1449
1809
|
|
|
1450
1810
|
__slots__ = (
|
|
1451
1811
|
"_const_value",
|
|
1812
|
+
"_graph",
|
|
1452
1813
|
"_index",
|
|
1814
|
+
"_is_graph_input",
|
|
1815
|
+
"_is_graph_output",
|
|
1816
|
+
"_is_initializer",
|
|
1453
1817
|
"_metadata",
|
|
1454
1818
|
"_metadata_props",
|
|
1455
1819
|
"_name",
|
|
@@ -1497,19 +1861,33 @@ class Value(_protocols.ValueProtocol, _display.PrettyPrintable):
|
|
|
1497
1861
|
# Use a collection of (Node, int) to store uses. This is needed
|
|
1498
1862
|
# because a single use can use the same value multiple times.
|
|
1499
1863
|
# Use a dictionary to preserve insertion order so that the visiting order is deterministic
|
|
1500
|
-
self._uses: dict[
|
|
1864
|
+
self._uses: dict[Usage, None] = {}
|
|
1501
1865
|
self.doc_string = doc_string
|
|
1502
1866
|
|
|
1867
|
+
# The graph this value belongs to. It is set *only* when the value is added as
|
|
1868
|
+
# a graph input, output or initializer.
|
|
1869
|
+
# The four properties can only be set by the Graph class (_GraphIO and GraphInitializers).
|
|
1870
|
+
self._graph: Graph | None = None
|
|
1871
|
+
self._is_graph_input: bool = False
|
|
1872
|
+
self._is_graph_output: bool = False
|
|
1873
|
+
self._is_initializer: bool = False
|
|
1874
|
+
|
|
1503
1875
|
def __repr__(self) -> str:
|
|
1504
1876
|
value_name = self.name if self.name else "anonymous:" + str(id(self))
|
|
1877
|
+
type_text = f", type={self.type!r}" if self.type is not None else ""
|
|
1878
|
+
shape_text = f", shape={self.shape!r}" if self.shape is not None else ""
|
|
1505
1879
|
producer = self.producer()
|
|
1506
1880
|
if producer is None:
|
|
1507
|
-
producer_text = "
|
|
1881
|
+
producer_text = ""
|
|
1508
1882
|
elif producer.name is not None:
|
|
1509
|
-
producer_text = producer.name
|
|
1883
|
+
producer_text = f", producer='{producer.name}'"
|
|
1510
1884
|
else:
|
|
1511
|
-
producer_text = f"anonymous_node:{id(producer)}"
|
|
1512
|
-
|
|
1885
|
+
producer_text = f", producer=anonymous_node:{id(producer)}"
|
|
1886
|
+
index_text = f", index={self.index()}" if self.index() is not None else ""
|
|
1887
|
+
const_value_text = self._constant_tensor_part()
|
|
1888
|
+
if const_value_text:
|
|
1889
|
+
const_value_text = f", const_value={const_value_text}"
|
|
1890
|
+
return f"{self.__class__.__name__}(name={value_name!r}{type_text}{shape_text}{producer_text}{index_text}{const_value_text})"
|
|
1513
1891
|
|
|
1514
1892
|
def __str__(self) -> str:
|
|
1515
1893
|
value_name = self.name if self.name is not None else "anonymous:" + str(id(self))
|
|
@@ -1518,41 +1896,85 @@ class Value(_protocols.ValueProtocol, _display.PrettyPrintable):
 
         # Quote the name because in reality the names can have invalid characters
         # that make them hard to read
-        return f"%{_quoted(value_name)}<{type_text},{shape_text}>"
+        return (
+            f"%{_quoted(value_name)}<{type_text},{shape_text}>{self._constant_tensor_part()}"
+        )
+
+    def _constant_tensor_part(self) -> str:
+        """Display string for the constant tensor attached to str of Value."""
+        if self.const_value is not None:
+            # Only display when the const value is small
+            if self.const_value.size <= 10:
+                return f"{{{self.const_value}}}"
+            else:
+                return f"{{{self.const_value.__class__.__name__}(...)}}"
+        return ""
+
+    @property
+    def graph(self) -> Graph | None:
+        """Return the graph that defines this value.
+
+        When the value is an input/output/initializer of a graph, the owning graph
+        is that graph. When the value is an output of a node, the owning graph is the
+        graph that the node belongs to. When the value is not owned by any graph,
+        it returns ``None``.
+        """
+        if self._graph is not None:
+            return self._graph
+        if self._producer is not None:
+            return self._producer.graph
+        return None
+
+    def _owned_by_graph(self) -> bool:
+        """Return True if the value is owned by a graph."""
+        result = self._is_graph_input or self._is_graph_output or self._is_initializer
+        if result:
+            assert self._graph is not None
+        return result
 
     def producer(self) -> Node | None:
         """The node that produces this value.
 
         When producer is ``None``, the value does not belong to a node, and is
-        typically a graph input or an initializer.
+        typically a graph input or an initializer. You can use :meth:`graph`
+        to find the graph that owns this value. Use :meth:`is_graph_input`, :meth:`is_graph_output`
+        or :meth:`is_initializer` to check if the value is an input, output or initializer of a graph.
         """
         return self._producer
 
+    def consumers(self) -> Sequence[Node]:
+        """Return the nodes (deduplicated) that consume this value."""
+        return tuple({usage.node: None for usage in self._uses})
+
     def index(self) -> int | None:
         """The index of the output of the defining node."""
         return self._index
 
-    def uses(self) -> Collection[tuple[Node, int]]:
+    def uses(self) -> Collection[Usage]:
         """Return a set of uses of the value.
 
         The set contains tuples of ``(Node, index)`` where the index is the index of the input
         of the node. For example, if ``node.inputs[1] == value``, then the use is ``(node, 1)``.
         """
-        return self._uses.keys()
+        # Create a tuple for the collection so that iteration on it will not
+        # be affected when the usage changes during graph mutation.
+        # This adds a small overhead but is a better user experience than
+        # having users call tuple().
+        return tuple(self._uses)
 
     def _add_usage(self, use: Node, index: int) -> None:
        """Add a usage of this value.
 
         This is an internal method. It should only be called by the Node class.
         """
-        self._uses[(use, index)] = None
+        self._uses[Usage(use, index)] = None
 
     def _remove_usage(self, use: Node, index: int) -> None:
         """Remove a node from the uses of this value.
 
         This is an internal method. It should only be called by the Node class.
         """
-        self._uses.pop((use, index))
+        self._uses.pop(Usage(use, index))
 
     @property
     def name(self) -> str | None:
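The hunk above replaces plain ``(node, index)`` tuples with a ``Usage`` named tuple and adds ``consumers()``. A minimal sketch of the resulting API (assuming ``val`` is a ``Value`` in an already-built graph, and that the named tuple's fields are ``node`` and ``idx``, consistent with the ``usage.node`` access shown above):

    for usage in val.uses():
        # Each use pairs the consuming node with the input slot it occupies
        print(usage.node.op_type, usage.idx)

    for node in val.consumers():
        # Deduplicated: a node that consumes `val` through two inputs appears once
        print(node.op_type)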
@@ -1652,15 +2074,17 @@ class Value(_protocols.ValueProtocol, _display.PrettyPrintable):
         self._metadata_props = {}
         return self._metadata_props
 
+    def is_graph_input(self) -> bool:
+        """Whether the value is an input of a graph."""
+        return self._is_graph_input
+
     def is_graph_output(self) -> bool:
         """Whether the value is an output of a graph."""
-        if (producer := self.producer()) is None:
-            return False
-        if (graph := producer.graph) is None:
-            return False
-        # Cannot use `in` because __eq__ may be defined by subclasses, even though
-        # it is not recommended
-        return any(output is self for output in graph.outputs)
+        return self._is_graph_output
+
+    def is_initializer(self) -> bool:
+        """Whether the value is an initializer of a graph."""
+        return self._is_initializer
 
 
 def Input(
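With the flags introduced above, these membership checks become O(1) lookups instead of linear scans over ``graph.outputs``. A small sketch (assuming ``g`` is a ``Graph`` constructed with at least one input):

    x = g.inputs[0]
    assert x.is_graph_input()
    assert not x.is_initializer()
    assert x.graph is g  # back-reference recorded when x was added as a graph input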
@@ -1673,7 +2097,6 @@ def Input(
 
     This is equivalent to calling ``Value(name=name, shape=shape, type=type, doc_string=doc_string)``.
     """
-
     # NOTE: The function name is capitalized to maintain API backward compatibility.
 
     return Value(name=name, shape=shape, type=type, doc_string=doc_string)
@@ -1770,19 +2193,11 @@ class Graph(_protocols.GraphProtocol, Sequence[Node], _display.PrettyPrintable):
         self.name = name
 
         # Private fields that are not to be accessed by any other classes
-        self._inputs = list(inputs)
-        self._outputs = list(outputs)
-        self._initializers = {}
-        for initializer in initializers:
-            if isinstance(initializer, str):
-                raise TypeError(
-                    "Initializer must be a Value, not a string. "
-                    "If you are copying the initializers from another graph, "
-                    "make sure you call graph.initializers.values() because it is a dictionary."
-                )
-            if initializer.name is None:
-                raise ValueError(f"Initializer must have a name: {initializer}")
-            self._initializers[initializer.name] = initializer
+        self._inputs = _graph_containers.GraphInputs(self, inputs)
+        self._outputs = _graph_containers.GraphOutputs(self, outputs)
+        self._initializers = _graph_containers.GraphInitializers(
+            self, {initializer.name: initializer for initializer in initializers}
+        )
         self._doc_string = doc_string
         self._opset_imports = opset_imports or {}
         self._metadata: _metadata.MetadataStore | None = None
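The plain list and dict fields are replaced by container objects from the new ``_graph_containers`` module, which keep the ``Value`` back-references from the previous hunks in sync. A sketch of the observable behavior (assuming ``import onnx_ir as ir``, a ``Graph`` named ``g``, and that removal clears the flag, as the containers are designed to do):

    v = ir.Value(name="x", type=ir.TensorType(ir.DataType.FLOAT))
    g.inputs.append(v)  # GraphInputs records the ownership on the value
    assert v.is_graph_input() and v.graph is g
    g.inputs.remove(v)  # removing it clears the flag again
    assert not v.is_graph_input()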
@@ -1791,21 +2206,67 @@ class Graph(_protocols.GraphProtocol, Sequence[Node], _display.PrettyPrintable):
         # Be sure to initialize the name authority before extending the nodes
         # because it is used to name the nodes and their outputs
         self._name_authority = _name_authority.NameAuthority()
+        # TODO(justinchuby): Trigger again if inputs or initializers are modified.
+        self._set_input_and_initializer_value_names_into_name_authority()
         # Call self.extend not self._nodes.extend so the graph reference is added to the nodes
         self.extend(nodes)
 
     @property
-    def inputs(self) -> list[Value]:
+    def inputs(self) -> MutableSequence[Value]:
         return self._inputs
 
     @property
-    def outputs(self) -> list[Value]:
+    def outputs(self) -> MutableSequence[Value]:
         return self._outputs
 
     @property
-    def initializers(self) -> dict[str, Value]:
+    def initializers(self) -> _graph_containers.GraphInitializers:
+        """The initializers of the graph as a ``MutableMapping[str, Value]``.
+
+        The keys are the names of the initializers. The values are the :class:`Value` objects.
+
+        This property additionally supports the ``add`` method, which takes a :class:`Value`
+        and adds it to the initializers if it is not already present.
+
+        .. note::
+            When setting an initializer with ``graph.initializers[key] = value``,
+            if the value does not have a name, it will be assigned ``key`` as its name.
+        """
         return self._initializers
 
+    def register_initializer(self, value: Value) -> None:
+        """Register an initializer to the graph.
+
+        This is a convenience method to register an initializer to the graph with
+        checks.
+
+        Args:
+            value: The :class:`Value` to register as an initializer of the graph.
+                It must have its ``.const_value`` set.
+
+        Raises:
+            ValueError: If a value of the same name that is not this value
+                is already registered.
+            ValueError: If the value does not have a name.
+            ValueError: If the initializer is produced by a node.
+            ValueError: If the value does not have its ``.const_value`` set.
+        """
+        if not value.name:
+            raise ValueError(f"Initializer must have a name: {value!r}")
+        if value.name in self._initializers:
+            if self._initializers[value.name] is not value:
+                raise ValueError(
+                    f"Initializer '{value.name}' is already registered, but"
+                    f" it is not the same object: existing={self._initializers[value.name]!r},"
+                    f" new={value!r}"
+                )
+        if value.const_value is None:
+            raise ValueError(
+                f"Value '{value!r}' must have its const_value set to be an initializer."
+            )
+        self._initializers.add(value)
+
     @property
     def doc_string(self) -> str | None:
         return self._doc_string
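A sketch of the new convenience method (assuming ``import numpy as np``, ``import onnx_ir as ir``, and a ``Graph`` named ``g``):

    w = ir.Value(name="W", type=ir.TensorType(ir.DataType.FLOAT), shape=ir.Shape([2, 2]))
    w.const_value = ir.Tensor(np.eye(2, dtype=np.float32))
    g.register_initializer(w)  # validates the name and const_value before adding
    assert g.initializers["W"] is w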
@@ -1818,7 +2279,12 @@ class Graph(_protocols.GraphProtocol, Sequence[Node], _display.PrettyPrintable):
     def opset_imports(self) -> dict[str, int]:
         return self._opset_imports
 
-    def __getitem__(self, index: int) -> Node:
+    @typing.overload
+    def __getitem__(self, index: int) -> Node: ...
+    @typing.overload
+    def __getitem__(self, index: slice) -> Sequence[Node]: ...
+
+    def __getitem__(self, index):
         return self._nodes[index]
 
     def __len__(self) -> int:
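``Graph``, ``GraphView``, and ``Function`` all gain the same overload pair in this release, so both index forms are precisely typed. A sketch (with ``g`` a ``Graph``):

    first = g[0]     # typed as Node
    window = g[1:3]  # typed as Sequence[Node]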
@@ -1830,6 +2296,12 @@ class Graph(_protocols.GraphProtocol, Sequence[Node], _display.PrettyPrintable):
     def __reversed__(self) -> Iterator[Node]:
         return reversed(self._nodes)
 
+    def _set_input_and_initializer_value_names_into_name_authority(self):
+        for value in self.inputs:
+            self._name_authority.register_or_name_value(value)
+        for value in self.initializers.values():
+            self._name_authority.register_or_name_value(value)
+
     def _set_node_graph_to_self_and_assign_names(self, node: Node) -> Node:
         """Set the graph reference for the node and assign names to it and its outputs if they don't have one."""
         if node.graph is not None and node.graph is not self:
@@ -1993,7 +2465,7 @@ class Graph(_protocols.GraphProtocol, Sequence[Node], _display.PrettyPrintable):
 
         This sort is stable. It preserves the original order as much as possible.
 
-        Referece: https://github.com/madelson/MedallionTopologicalSort#stable-sort
+        Reference: https://github.com/madelson/MedallionTopologicalSort#stable-sort
 
         Raises:
             ValueError: If the graph contains a cycle, making topological sorting impossible.
@@ -2004,7 +2476,7 @@ class Graph(_protocols.GraphProtocol, Sequence[Node], _display.PrettyPrintable):
         sorted_nodes_by_graph: dict[Graph, list[Node]] = {
             graph: [] for graph in {node.graph for node in nodes if node.graph is not None}
         }
-        # TODO: Explain why we need to store direct predecessors and children and why
+        # TODO(justinchuby): Explain why we need to store direct predecessors and children and why
         # we only need to store the direct ones
 
         # The depth of a node is defined as the number of direct children it has
@@ -2024,7 +2496,7 @@ class Graph(_protocols.GraphProtocol, Sequence[Node], _display.PrettyPrintable):
                     node_depth[predecessor] += 1
 
         # 1. Build the direct predecessors of each node and the depth of each node
-        # for sorting
+        # for sorting topologically using Kahn's algorithm.
         # Note that when a node contains graph attributes (aka. has subgraphs),
         # we consider all nodes in the subgraphs *predecessors* of this node. This
         # way we ensure the implicit dependencies of the subgraphs are captured
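For reference, the algorithm the comment now names is classic Kahn's algorithm. A generic sketch follows (not the package's implementation, which additionally tracks node depth for stability and treats subgraph nodes as predecessors):

    from collections import deque

    def kahn_topological_sort(nodes, successors):
        # successors[n] lists the nodes that directly depend on n
        in_degree = {n: 0 for n in nodes}
        for n in nodes:
            for m in successors[n]:
                in_degree[m] += 1
        # Seeding in input order with FIFO processing preserves much of the original order
        queue = deque(n for n in nodes if in_degree[n] == 0)
        order = []
        while queue:
            n = queue.popleft()
            order.append(n)
            for m in successors[n]:
                in_degree[m] -= 1
                if in_degree[m] == 0:
                    queue.append(m)
        if len(order) != len(nodes):
            raise ValueError("graph contains a cycle")
        return order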
@@ -2125,11 +2597,11 @@ def _graph_str(graph: Graph | GraphView) -> str:
     )
     signature = f"""\
 graph(
-    name={graph.name or 'anonymous_graph:' + str(id(graph))},
-    inputs=({textwrap.indent(inputs_text, ' ' * 8)}
+    name={graph.name or "anonymous_graph:" + str(id(graph))},
+    inputs=({textwrap.indent(inputs_text, " " * 8)}
     ),
-    outputs=({textwrap.indent(outputs_text, ' ' * 8)}
-    ),{textwrap.indent(initializers_text, ' ' * 4)}
+    outputs=({textwrap.indent(outputs_text, " " * 8)}
+    ),{textwrap.indent(initializers_text, " " * 4)}
 )"""
     node_count = len(graph)
     number_width = len(str(node_count))
@@ -2163,11 +2635,11 @@ def _graph_repr(graph: Graph | GraphView) -> str:
     )
     return f"""\
 {graph.__class__.__name__}(
-    name={graph.name or 'anonymous_graph:' + str(id(graph))!r},
-    inputs=({textwrap.indent(inputs_text, ' ' * 8)}
+    name={graph.name or "anonymous_graph:" + str(id(graph))!r},
+    inputs=({textwrap.indent(inputs_text, " " * 8)}
     ),
-    outputs=({textwrap.indent(outputs_text, ' ' * 8)}
-    ),{textwrap.indent(initializers_text, ' ' * 4)}
+    outputs=({textwrap.indent(outputs_text, " " * 8)}
+    ),{textwrap.indent(initializers_text, " " * 4)}
     len()={len(graph)}
 )"""
@@ -2223,7 +2695,7 @@ class GraphView(Sequence[Node], _display.PrettyPrintable):
         outputs: Sequence[Value],
         *,
         nodes: Iterable[Node],
-        initializers: Sequence[_protocols.TensorProtocol] = (),
+        initializers: Sequence[Value] = (),
         doc_string: str | None = None,
         opset_imports: dict[str, int] | None = None,
         name: str | None = None,
@@ -2232,17 +2704,19 @@ class GraphView(Sequence[Node], _display.PrettyPrintable):
         self.name = name
         self.inputs = tuple(inputs)
         self.outputs = tuple(outputs)
-        for initializer in initializers:
-            if initializer.name is None:
-                raise ValueError(f"Initializer must have a name: {initializer}")
-        self.initializers = {tensor.name: tensor for tensor in initializers}
+        self.initializers = {initializer.name: initializer for initializer in initializers}
         self.doc_string = doc_string
         self.opset_imports = opset_imports or {}
         self._metadata: _metadata.MetadataStore | None = None
         self._metadata_props: dict[str, str] | None = metadata_props
         self._nodes: tuple[Node, ...] = tuple(nodes)
 
-    def __getitem__(self, index: int) -> Node:
+    @typing.overload
+    def __getitem__(self, index: int) -> Node: ...
+    @typing.overload
+    def __getitem__(self, index: slice) -> Sequence[Node]: ...
+
+    def __getitem__(self, index):
         return self._nodes[index]
 
     def __len__(self) -> int:
@@ -2318,7 +2792,7 @@ class Model(_protocols.ModelProtocol, _display.PrettyPrintable):
         model_version: int | None = None,
         doc_string: str | None = None,
         functions: Sequence[Function] = (),
-        meta_data_props: dict[str, str] | None = None,
+        metadata_props: dict[str, str] | None = None,
     ) -> None:
         self.graph: Graph = graph
         self.ir_version = ir_version
@@ -2329,7 +2803,7 @@ class Model(_protocols.ModelProtocol, _display.PrettyPrintable):
         self.doc_string = doc_string
         self._functions = {func.identifier(): func for func in functions}
         self._metadata: _metadata.MetadataStore | None = None
-        self._metadata_props: dict[str, str] | None = meta_data_props
+        self._metadata_props: dict[str, str] | None = metadata_props
 
     @property
     def functions(self) -> dict[_protocols.OperatorIdentifier, Function]:
@@ -2381,9 +2855,28 @@ Model(
     domain={self.domain!r},
     model_version={self.model_version!r},
     functions={self.functions!r},
-    graph={textwrap.indent(repr(self.graph), ' ' * 4).strip()}
+    graph={textwrap.indent(repr(self.graph), " " * 4).strip()}
 )"""
 
+    def graphs(self) -> Iterable[Graph]:
+        """Get all graphs and subgraphs in the model.
+
+        This is a convenience method to traverse the model. Consider using
+        `onnx_ir.traversal.RecursiveGraphIterator` for more advanced
+        traversals on nodes.
+        """
+        # NOTE(justinchuby): Given
+        # (1) how useful the method is
+        # (2) I couldn't find an appropriate name for it in `traversal.py`
+        # (3) Users familiar with onnxruntime optimization tools expect this method
+        # I created this method as a core method instead of an iterator in
+        # `traversal.py`.
+        seen_graphs: set[Graph] = set()
+        for node in onnx_ir.traversal.RecursiveGraphIterator(self.graph):
+            if node.graph is not None and node.graph not in seen_graphs:
+                seen_graphs.add(node.graph)
+                yield node.graph
+
 
 class Function(_protocols.FunctionProtocol, Sequence[Node], _display.PrettyPrintable):
     """IR functions.
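A sketch of the new traversal helper (assuming ``model`` is a loaded ``Model``):

    for graph in model.graphs():
        # Yields the main graph plus any subgraphs reachable through node
        # attributes, e.g. If branches and Loop bodies
        print(graph.name, "nodes:", len(graph))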
@@ -2404,16 +2897,14 @@ class Function(_protocols.FunctionProtocol, Sequence[Node], _display.PrettyPrintable):
         outputs: The output values of the function.
         opset_imports: Opsets imported by the function.
         doc_string: Documentation string.
-        metadata_props: Metadata that will be serialized to the ONNX file.
         meta: Metadata store for graph transform passes.
+        metadata_props: Metadata that will be serialized to the ONNX file.
     """
 
     __slots__ = (
         "_attributes",
         "_domain",
         "_graph",
-        "_metadata",
-        "_metadata_props",
         "_name",
         "_overload",
     )
@@ -2427,16 +2918,15 @@ class Function(_protocols.FunctionProtocol, Sequence[Node], _display.PrettyPrintable):
         # Ensure the inputs and outputs of the function belong to a graph
         # and not from an outer scope
         graph: Graph,
-        attributes: Sequence[Attr],
-        metadata_props: dict[str, str] | None = None,
+        attributes: Iterable[Attr] | Mapping[str, Attr],
     ) -> None:
         self._domain = domain
         self._name = name
         self._overload = overload
         self._graph = graph
-        self._metadata: _metadata.MetadataStore | None = None
-        self._metadata_props: dict[str, str] | None = metadata_props
-        self._attributes = OrderedDict((attr.name, attr) for attr in attributes)
+        if isinstance(attributes, Mapping):
+            attributes = tuple(attributes.values())
+        self._attributes = _graph_containers.Attributes(attributes)
 
     def identifier(self) -> _protocols.OperatorIdentifier:
         return self.domain, self.name, self.overload
@@ -2455,7 +2945,7 @@ class Function(_protocols.FunctionProtocol, Sequence[Node], _display.PrettyPrintable):
 
     @domain.setter
     def domain(self, value: str) -> None:
-        self._domain = value
+        self._domain = _normalize_domain(value)
 
     @property
     def overload(self) -> str:
@@ -2466,18 +2956,23 @@ class Function(_protocols.FunctionProtocol, Sequence[Node], _display.PrettyPrintable):
         self._overload = value
 
     @property
-    def inputs(self) -> list[Value]:
+    def inputs(self) -> MutableSequence[Value]:
         return self._graph.inputs
 
     @property
-    def outputs(self) -> list[Value]:
+    def outputs(self) -> MutableSequence[Value]:
         return self._graph.outputs
 
     @property
-    def attributes(self) -> OrderedDict[str, Attr]:
+    def attributes(self) -> _graph_containers.Attributes:
         return self._attributes
 
-    def __getitem__(self, index: int) -> Node:
+    @typing.overload
+    def __getitem__(self, index: int) -> Node: ...
+    @typing.overload
+    def __getitem__(self, index: slice) -> Sequence[Node]: ...
+
+    def __getitem__(self, index):
         return self._graph.__getitem__(index)
 
     def __len__(self) -> int:
@@ -2508,15 +3003,11 @@ class Function(_protocols.FunctionProtocol, Sequence[Node], _display.PrettyPrintable):
         Write to the :attr:`metadata_props` if you would like the metadata to be serialized
         to the ONNX proto.
         """
-        if self._metadata is None:
-            self._metadata = _metadata.MetadataStore()
-        return self._metadata
+        return self._graph.meta
 
     @property
     def metadata_props(self) -> dict[str, str]:
-        if self._metadata_props is None:
-            self._metadata_props = {}
-        return self._metadata_props
+        return self._graph.metadata_props
 
     # Mutation methods
     def append(self, node: Node, /) -> None:
@@ -2549,11 +3040,11 @@ class Function(_protocols.FunctionProtocol, Sequence[Node], _display.PrettyPrintable):
         """
         self._graph.remove(nodes, safe=safe)
 
-    def insert_after(self, node: Node, new_nodes: Iterable[Node], /) -> None:
+    def insert_after(self, node: Node, new_nodes: Iterable[Node] | Node, /) -> None:
         """Insert new nodes after the given node in O(#new_nodes) time."""
         self._graph.insert_after(node, new_nodes)
 
-    def insert_before(self, node: Node, new_nodes: Iterable[Node], /) -> None:
+    def insert_before(self, node: Node, new_nodes: Iterable[Node] | Node, /) -> None:
         """Insert new nodes before the given node in O(#new_nodes) time."""
         self._graph.insert_before(node, new_nodes)
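Both helpers now accept a bare node as well as an iterable. A small ergonomic sketch (assuming ``import onnx_ir as ir``, a ``Function`` named ``fn``, and ``node`` one of its nodes):

    relu = ir.Node("", "Relu", inputs=[node.outputs[0]])
    fn.insert_after(node, relu)  # previously this required wrapping: [relu]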
@@ -2581,10 +3072,10 @@ class Function(_protocols.FunctionProtocol, Sequence[Node], _display.PrettyPrintable):
 >
 def {full_name}(
     inputs=(
-        {textwrap.indent(inputs_text, ' ' * 8)}
-    ),{textwrap.indent(attributes_text, ' ' * 4)}
+        {textwrap.indent(inputs_text, " " * 8)}
+    ),{textwrap.indent(attributes_text, " " * 4)}
     outputs=(
-        {textwrap.indent(outputs_text, ' ' * 8)}
+        {textwrap.indent(outputs_text, " " * 8)}
     ),
 )"""
         node_count = len(self)
@@ -2611,22 +3102,28 @@ def {full_name}(
         return f"{self.__class__.__name__}({self.domain!r}, {self.name!r}, {self.overload!r}, inputs={self.inputs!r}, attributes={self.attributes!r}), outputs={self.outputs!r})"
 
 
-class RefAttr(_protocols.ReferenceAttributeProtocol, _display.PrettyPrintable):
-    """Reference attribute."""
+class Attr(
+    _protocols.AttributeProtocol,
+    _protocols.ReferenceAttributeProtocol,
+    _display.PrettyPrintable,
+):
+    """Base class for ONNX attributes or references."""
 
-    __slots__ = ("_name", "_ref_attr_name", "_type", "doc_string")
+    __slots__ = ("_name", "_ref_attr_name", "_type", "_value", "doc_string")
 
     def __init__(
         self,
         name: str,
-        ref_attr_name: str,
         type: _enums.AttributeType,
+        value: Any,
+        ref_attr_name: str | None = None,
         *,
         doc_string: str | None = None,
     ) -> None:
         self._name = name
-        self._ref_attr_name = ref_attr_name
         self._type = type
+        self._value = value
+        self._ref_attr_name = ref_attr_name
         self.doc_string = doc_string
 
     @property
@@ -2637,43 +3134,21 @@ class RefAttr(_protocols.ReferenceAttributeProtocol, _display.PrettyPrintable):
     def name(self, value: str) -> None:
         self._name = value
 
-    @property
-    def ref_attr_name(self) -> str:
-        return self._ref_attr_name
-
-    @ref_attr_name.setter
-    def ref_attr_name(self, value: str) -> None:
-        self._ref_attr_name = value
-
     @property
     def type(self) -> _enums.AttributeType:
         return self._type
 
-    @type.setter
-    def type(self, value: _enums.AttributeType) -> None:
-        self._type = value
-
-    def __repr__(self) -> str:
-        return f"{self.__class__.__name__}({self._name!r}, {self._type!r}, ref_attr_name={self.ref_attr_name!r})"
-
+    @property
+    def value(self) -> Any:
+        return self._value
 
-class Attr(_protocols.AttributeProtocol, _display.PrettyPrintable):
-    """Base class for ONNX attributes."""
+    @property
+    def ref_attr_name(self) -> str | None:
+        return self._ref_attr_name
 
-
-
-    def __init__(
-        self,
-        name: str,
-        type: _enums.AttributeType,
-        value: Any,
-        *,
-        doc_string: str | None = None,
-    ):
-        self.name = name
-        self.type = type
-        self.value = value
-        self.doc_string = doc_string
+    def is_ref(self) -> bool:
+        """Check if this attribute is a reference attribute."""
+        return self.ref_attr_name is not None
 
     def __eq__(self, other: object) -> bool:
         if not isinstance(other, _protocols.AttributeProtocol):
@@ -2690,15 +3165,157 @@ class Attr(_protocols.AttributeProtocol, _display.PrettyPrintable):
         return True
 
     def __str__(self) -> str:
+        if self.is_ref():
+            return f"@{self.ref_attr_name}"
         if self.type == _enums.AttributeType.GRAPH:
             return textwrap.indent("\n" + str(self.value), " " * 4)
         return str(self.value)
 
     def __repr__(self) -> str:
+        if self.is_ref():
+            return f"{self.__class__.__name__}({self.name!r}, {self.type!r}, ref_attr_name={self.ref_attr_name!r})"
         return f"{self.__class__.__name__}({self.name!r}, {self.type!r}, {self.value!r})"
 
+    # Well typed getters
+    def as_float(self) -> float:
+        """Get the attribute value as a float."""
+        if self.type != _enums.AttributeType.FLOAT:
+            raise TypeError(
+                f"Attribute '{self.name}' is not of type FLOAT. Actual type: {self.type}"
+            )
+        # Do not use isinstance check because it may prevent np.float32 etc. from being used
+        return float(self.value)
+
+    def as_int(self) -> int:
+        """Get the attribute value as an int."""
+        if self.type != _enums.AttributeType.INT:
+            raise TypeError(
+                f"Attribute '{self.name}' is not of type INT. Actual type: {self.type}"
+            )
+        # Do not use isinstance check because it may prevent np.int32 etc. from being used
+        return int(self.value)
+
+    def as_string(self) -> str:
+        """Get the attribute value as a string."""
+        if self.type != _enums.AttributeType.STRING:
+            raise TypeError(
+                f"Attribute '{self.name}' is not of type STRING. Actual type: {self.type}"
+            )
+        if not isinstance(self.value, str):
+            raise TypeError(f"Value of attribute '{self!r}' is not a string.")
+        return self.value
+
+    def as_tensor(self) -> _protocols.TensorProtocol:
+        """Get the attribute value as a tensor."""
+        if self.type != _enums.AttributeType.TENSOR:
+            raise TypeError(
+                f"Attribute '{self.name}' is not of type TENSOR. Actual type: {self.type}"
+            )
+        if not isinstance(self.value, _protocols.TensorProtocol):
+            raise TypeError(f"Value of attribute '{self!r}' is not a tensor.")
+        return self.value
+
+    def as_graph(self) -> Graph:
+        """Get the attribute value as a graph."""
+        if self.type != _enums.AttributeType.GRAPH:
+            raise TypeError(
+                f"Attribute '{self.name}' is not of type GRAPH. Actual type: {self.type}"
+            )
+        if not isinstance(self.value, Graph):
+            raise TypeError(f"Value of attribute '{self!r}' is not a graph.")
+        return self.value
+
+    def as_floats(self) -> Sequence[float]:
+        """Get the attribute value as a sequence of floats."""
+        if self.type != _enums.AttributeType.FLOATS:
+            raise TypeError(
+                f"Attribute '{self.name}' is not of type FLOATS. Actual type: {self.type}"
+            )
+        if not isinstance(self.value, Sequence):
+            raise TypeError(f"Value of attribute '{self!r}' is not a Sequence.")
+        # Do not use isinstance check on elements because it may prevent np.float32 etc. from being used
+        # Create a copy of the list to prevent mutation
+        return [float(v) for v in self.value]
+
+    def as_ints(self) -> Sequence[int]:
+        """Get the attribute value as a sequence of ints."""
+        if self.type != _enums.AttributeType.INTS:
+            raise TypeError(
+                f"Attribute '{self.name}' is not of type INTS. Actual type: {self.type}"
+            )
+        if not isinstance(self.value, Sequence):
+            raise TypeError(f"Value of attribute '{self!r}' is not a Sequence.")
+        # Do not use isinstance check on elements because it may prevent np.int32 etc. from being used
+        # Create a copy of the list to prevent mutation
+        return list(self.value)
+
+    def as_strings(self) -> Sequence[str]:
+        """Get the attribute value as a sequence of strings."""
+        if self.type != _enums.AttributeType.STRINGS:
+            raise TypeError(
+                f"Attribute '{self.name}' is not of type STRINGS. Actual type: {self.type}"
+            )
+        if not isinstance(self.value, Sequence):
+            raise TypeError(f"Value of attribute '{self!r}' is not a Sequence.")
+        if onnx_ir.DEBUG:
+            if not all(isinstance(x, str) for x in self.value):
+                raise TypeError(f"Value of attribute '{self!r}' is not a Sequence of strings.")
+        # Create a copy of the list to prevent mutation
+        return list(self.value)
+
+    def as_tensors(self) -> Sequence[_protocols.TensorProtocol]:
+        """Get the attribute value as a sequence of tensors."""
+        if self.type != _enums.AttributeType.TENSORS:
+            raise TypeError(
+                f"Attribute '{self.name}' is not of type TENSORS. Actual type: {self.type}"
+            )
+        if not isinstance(self.value, Sequence):
+            raise TypeError(f"Value of attribute '{self!r}' is not a Sequence.")
+        if onnx_ir.DEBUG:
+            if not all(isinstance(x, _protocols.TensorProtocol) for x in self.value):
+                raise TypeError(f"Value of attribute '{self!r}' is not a Sequence of tensors.")
+        # Create a copy of the list to prevent mutation
+        return list(self.value)
+
+    def as_graphs(self) -> Sequence[Graph]:
+        """Get the attribute value as a sequence of graphs."""
+        if self.type != _enums.AttributeType.GRAPHS:
+            raise TypeError(
+                f"Attribute '{self.name}' is not of type GRAPHS. Actual type: {self.type}"
+            )
+        if not isinstance(self.value, Sequence):
+            raise TypeError(f"Value of attribute '{self!r}' is not a Sequence.")
+        if onnx_ir.DEBUG:
+            if not all(isinstance(x, Graph) for x in self.value):
+                raise TypeError(f"Value of attribute '{self!r}' is not a Sequence of graphs.")
+        # Create a copy of the list to prevent mutation
+        return list(self.value)
+
 
 # NOTE: The following functions are just for convenience
+
+
+def RefAttr(
+    name: str,
+    ref_attr_name: str,
+    type: _enums.AttributeType,
+    doc_string: str | None = None,
+) -> Attr:
+    """Create a reference attribute.
+
+    Args:
+        name: The name of the attribute.
+        type: The type of the attribute.
+        ref_attr_name: The name of the referenced attribute.
+        doc_string: Documentation string.
+
+    Returns:
+        A reference attribute.
+    """
+    # NOTE: The function name is capitalized to maintain API backward compatibility.
+    return Attr(name, type, None, ref_attr_name=ref_attr_name, doc_string=doc_string)
+
+
 def AttrFloat32(name: str, value: float, doc_string: str | None = None) -> Attr:
     """Create a float attribute."""
     # NOTE: The function name is capitalized to maintain API backward compatibility.