onnx-ir 0.0.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of onnx-ir might be problematic. Click here for more details.

onnx_ir/_core.py ADDED
@@ -0,0 +1,2875 @@
1
+ # Copyright (c) Microsoft Corporation.
2
+ # Licensed under the MIT License.
3
+ """Data structures for the intermediate representation."""
4
+
5
+ # NOTES for developers:
6
+ # NOTE: None of these classes will have a "to_onnx" or "from_protobuf" method because
7
+ # we cannot assume that the build tool chain has protoc installed and we would like
8
+ # to keep this module protobuf free. This way we separate the concerns of the IR
9
+ # and the serialization/deserialization.
10
+ #
11
+ # NOTE: Do not import pathlib in the IR. It is slow. Use os.path methods instead.
12
+
13
+ from __future__ import annotations
14
+
15
+ import abc
16
+ import contextlib
17
+ import dataclasses
18
+ import heapq
19
+ import math
20
+ import mmap
21
+ import os
22
+ import sys
23
+ import textwrap
24
+ import typing
25
+ from typing import (
26
+ AbstractSet,
27
+ Any,
28
+ Collection,
29
+ Generic,
30
+ Hashable,
31
+ Iterable,
32
+ Iterator,
33
+ OrderedDict,
34
+ Sequence,
35
+ Union,
36
+ )
37
+
38
+ import ml_dtypes
39
+ import numpy as np
40
+
41
+ import onnx_ir
42
+ from onnx_ir import (
43
+ _display,
44
+ _enums,
45
+ _linked_list,
46
+ _metadata,
47
+ _name_authority,
48
+ _protocols,
49
+ _type_casting,
50
+ )
51
+
52
+ if typing.TYPE_CHECKING:
53
+ import numpy.typing as npt
54
+ from typing_extensions import TypeGuard
55
+
56
# Type variable for the raw backing value of a ``Tensor``: any object that
# satisfies the array-compatible protocol or the DLPack-compatible protocol.
TArrayCompatible = typing.TypeVar(
    "TArrayCompatible",
    bound=Union[_protocols.ArrayCompatible, _protocols.DLPackCompatible],
)

# Whether the host byte order is little endian.
_IS_LITTLE_ENDIAN = (sys.byteorder == "little")

# IR data types that have no native numpy dtype.
_NON_NUMPY_NATIVE_TYPES = frozenset(
    {
        _enums.DataType.BFLOAT16,
        _enums.DataType.FLOAT8E4M3FN,
        _enums.DataType.FLOAT8E4M3FNUZ,
        _enums.DataType.FLOAT8E5M2,
        _enums.DataType.FLOAT8E5M2FNUZ,
        _enums.DataType.INT4,
        _enums.DataType.UINT4,
    }
)
75
+
76
+
77
+ def _compatible_with_numpy(obj: Any) -> TypeGuard[_protocols.ArrayCompatible]:
78
+ """Use this function to check if an object is compatible with numpy.
79
+
80
+ Avoid isinstance checks with the ArrayCompatible protocol for performance reasons.
81
+ """
82
+ return hasattr(obj, "__array__")
83
+
84
+
85
+ def _compatible_with_dlpack(obj: Any) -> TypeGuard[_protocols.DLPackCompatible]:
86
+ """Use this function to check if an object is compatible with DLPack.
87
+
88
+ Avoid isinstance checks with the DLPackCompatible protocol for performance reasons.
89
+ """
90
+ return hasattr(obj, "__dlpack__")
91
+
92
+
93
class TensorBase(abc.ABC, _protocols.TensorProtocol, _display.PrettyPrintable):
    """Convenience shared methods for classes implementing TensorProtocol.

    Subclasses are expected to provide ``dtype``, ``shape`` and ``numpy()``;
    everything here is derived from those.
    """

    __slots__ = ()

    def _printable_type_shape(self) -> str:
        """Return a string representation of the shape and data type."""
        return f"{self.dtype},{self.shape}"

    def _repr_base(self) -> str:
        """Base string for the repr method.

        Example: Tensor<FLOAT,[5,42]>
        """
        return f"{self.__class__.__name__}<{self._printable_type_shape()}>"

    @property
    def size(self) -> int:
        """The number of elements in the tensor.

        Returns a plain Python ``int``. A rank-0 (scalar) shape yields 1.
        """
        # math.prod always returns an int for int inputs, including 1 for an
        # empty shape. np.prod returns the float 1.0 for ``()`` and a numpy
        # scalar otherwise, which breaks callers that need an integer count.
        return math.prod(self.shape.numpy())  # type: ignore[attr-defined]

    @property
    def nbytes(self) -> int:
        """The number of bytes in the tensor."""
        # Use math.ceil because when dtype is INT4, the itemsize is 0.5
        return math.ceil(self.dtype.itemsize * self.size)

    def display(self, *, page: bool = False) -> None:
        """Print summary statistics and an ASCII histogram of the tensor values.

        The report contains the repr, min/max (ignoring NaN), NaN and Inf counts,
        the fraction of near-zero values and an 80-bin histogram of the finite values.

        Args:
            page: When True and ``rich`` is available, show the report in a pager.
        """
        rich = _display.require_rich()

        if rich is None:
            # No rich: there is no progress indicator to show while computing.
            status_manager = contextlib.nullcontext()
        else:
            import rich.status  # type: ignore[import-not-found, no-redef]  # pylint: disable=import-outside-toplevel

            status_manager = rich.status.Status(f"Computing tensor stats for {self!r}")

        from onnx_ir._thirdparty import (  # pylint: disable=import-outside-toplevel
            asciichartpy,
        )

        with status_manager:
            # Construct the text to display
            lines = []
            array = self.numpy().flatten()
            lines.append(repr(self))
            lines.append("")
            nan_values = np.isnan(array)
            nan_count = np.count_nonzero(nan_values)
            inf_count = np.count_nonzero(np.isinf(array))
            numbers = array[~nan_values]
            lines.append(
                f"Min: {np.min(numbers)}, Max: {np.max(numbers)}, "
                f"NaN count: {nan_count}, "
                f"Inf count: {inf_count}"
            )
            # Compute sparsity
            sparse_threshold = 1e-6
            # NOTE: count_nonzero() is faster than sum() for boolean arrays
            sparsity = np.count_nonzero(np.abs(array) < sparse_threshold) / array.size
            lines.append(f"Sparsity (abs<{sparse_threshold}): {sparsity:.2f}")

            # Compute histogram
            finite_numbers = array[np.isfinite(array)]
            lines.append("Histogram:")
            hist, bin_edges = np.histogram(finite_numbers, bins=80, density=False)
            lines.append(
                asciichartpy.plot(
                    hist, bin_edges=bin_edges, cfg={"height": 8, "format": "{:8.0f}"}
                )
            )

            text = "\n".join(lines)

        if rich is None:
            print(text)
        elif page:
            import rich.console  # type: ignore[import-not-found, no-redef]  # pylint: disable=import-outside-toplevel

            console = rich.console.Console()
            with console.pager():
                console.print(text)
        else:
            rich.print(text)
177
+
178
+
179
def _check_numpy_representation_type(array: np.ndarray, dtype: _enums.DataType) -> None:
    """Check if the numpy array dtype matches the IR data type.

    When the dtype is not one of the numpy native dtypes, the value needs to be:

    - ``int8`` or ``uint8`` for int4, with the sign bit extended to 8 bits.
    - ``uint8`` for uint4.
    - ``uint8`` for 8-bit data types.
    - ``uint16`` for bfloat16

    or corresponding dtypes from the ``ml_dtype`` package.

    Args:
        array: The numpy array whose dtype is validated.
        dtype: The IR data type the array is supposed to represent.

    Raises:
        TypeError: If the array dtype is not an accepted representation of ``dtype``,
            or if the array dtype cannot be converted to an IR data type.
    """
    if dtype in _NON_NUMPY_NATIVE_TYPES:
        if dtype.itemsize == 2 and array.dtype not in (np.uint16, ml_dtypes.bfloat16):
            raise TypeError(
                f"The numpy array dtype must be uint16 or ml_dtypes.bfloat16 (not {array.dtype}) for IR data type {dtype}."
            )
        # The accepted float8 dtypes must be the same ones the array is later
        # viewed as: float8_e4m3fnuz, not the unrelated float8_e4m3b11fnuz format.
        if dtype.itemsize == 1 and array.dtype not in (
            np.uint8,
            ml_dtypes.float8_e4m3fnuz,
            ml_dtypes.float8_e4m3fn,
            ml_dtypes.float8_e5m2fnuz,
            ml_dtypes.float8_e5m2,
        ):
            raise TypeError(
                f"The numpy array dtype must be uint8 or ml_dtypes.float8* (not {array.dtype}) for IR data type {dtype}."
            )
        if dtype == _enums.DataType.INT4:
            if array.dtype not in (np.int8, np.uint8, ml_dtypes.int4):
                raise TypeError(
                    f"The numpy array dtype must be int8 or uint8 or ml_dtypes.int4 (not {array.dtype}) for IR data type {dtype}."
                )
        if dtype == _enums.DataType.UINT4:
            if array.dtype not in (np.uint8, ml_dtypes.uint4):
                raise TypeError(
                    f"The numpy array dtype must be uint8 or ml_dtypes.uint4 (not {array.dtype}) for IR data type {dtype}."
                )
        return

    # Numpy-native IR dtype: the array dtype must map to exactly this IR dtype.
    try:
        dtype_numpy = _enums.DataType.from_numpy(array.dtype)
    except TypeError as e:
        raise TypeError(
            "Failed to convert the numpy dtype to an IR data type. "
            "If you are using a non-native dtype, be sure to specify the corresponding IR dtype when "
            "creating a Tensor."
        ) from e

    if dtype_numpy != dtype:
        raise TypeError(
            f"The numpy array dtype {array.dtype} does not match the IR data type {dtype}."
        )
231
+
232
+
233
def _maybe_view_np_array_with_ml_dtypes(
    array: np.ndarray, dtype: _enums.DataType
) -> np.ndarray:
    """View the array with an ``ml_dtypes`` dtype when the IR dtype has no numpy equivalent.

    Args:
        array: The numpy array to reinterpret.
        dtype: The data type to reinterpret the array as.

    Returns:
        ``array.view(<ml_dtypes dtype>)`` for bfloat16, the float8 variants, int4
        and uint4; the input array unchanged for every other dtype.
    """
    # (IR dtype, attribute name in ml_dtypes). The attribute is looked up only
    # for the entry that matches.
    view_table = (
        (_enums.DataType.BFLOAT16, "bfloat16"),
        (_enums.DataType.FLOAT8E4M3FN, "float8_e4m3fn"),
        (_enums.DataType.FLOAT8E4M3FNUZ, "float8_e4m3fnuz"),
        (_enums.DataType.FLOAT8E5M2, "float8_e5m2"),
        (_enums.DataType.FLOAT8E5M2FNUZ, "float8_e5m2fnuz"),
        (_enums.DataType.INT4, "int4"),
        (_enums.DataType.UINT4, "uint4"),
    )
    for ir_dtype, ml_dtype_name in view_table:
        if dtype == ir_dtype:
            return array.view(getattr(ml_dtypes, ml_dtype_name))
    return array
260
+
261
+
262
class Tensor(TensorBase, _protocols.TensorProtocol, Generic[TArrayCompatible]):  # pylint: disable=too-many-ancestors
    """An immutable concrete tensor.

    This class is a wrapper around the raw tensor data. The raw tensor data can be a numpy array
    compatible object (e.g. ``np.ndarray``, ``torch.Tensor``) or a ``DLPack`` compatible object.
    The tensor is immutable and the data is not copied at initialization.

    To create a tensor from a numpy array::

        >>> import numpy as np
        >>> array = np.array([1, 2, 3])
        >>> tensor = Tensor(array)
        >>> # The tensor itself can be treated as a numpy array because it implements the __array__ method
        >>> np.allclose(tensor, array)
        True

    To get a numpy array from the tensor, call :meth:`numpy`. To convert the tensor
    to a byte string for serialization, call :meth:`tobytes`.

    It is recommended to check the size of the tensor first before accessing the
    underlying data, because accessing the data may be expensive and incur IO
    overhead.

    Subclass this class to efficiently handle different types of tensors from different frameworks.

    Attributes:
        name: The name of the tensor.
        shape: The shape of the tensor.
        dtype: The data type of the elements of the tensor. It is an :class:`ir.DataType` enum.
        doc_string: Documentation string.
        raw: The raw data behind this tensor. It can be anything.
        size: The number of elements in the tensor.
        nbytes: The number of bytes in the tensor.
        metadata_props: Metadata that will be serialized to the ONNX file.
        meta: Metadata store for graph transform passes.
    """

    __slots__ = (
        "_dtype",
        "_metadata",
        "_metadata_props",
        "_raw",
        "_shape",
        "doc_string",
        "name",
    )

    def __init__(
        self,
        value: TArrayCompatible,
        dtype: _enums.DataType | None = None,
        *,
        shape: Shape | None = None,
        name: str | None = None,
        doc_string: str | None = None,
        metadata_props: dict[str, str] | None = None,
    ) -> None:
        """Initialize a tensor.

        Args:
            value: The backing data of the tensor. It can be a numpy array compatible object or a DLPack compatible object.
                When the dtype is not one of the numpy native dtypes, the value needs
                to be ``uint8`` for 4-bit and 8-bit data types, and ``uint16`` for bfloat16
                when the value is a numpy array; ``dtype`` must be specified in this case.
            dtype: The data type of the tensor. It can be None only when value is a numpy array.
                Users are responsible for making sure the dtype matches the value when value is not a numpy array.
            shape: The shape of the tensor. If None, the shape is obtained from the value.
                A shape passed in explicitly is marked frozen in place.
            name: The name of the tensor.
            doc_string: The documentation string.
            metadata_props: The metadata properties.

        Raises:
            TypeError: If the value is not a numpy array compatible or a DLPack compatible object.
            TypeError: If the value is a numpy array and the dtype is specified but does not match the dtype of the array.
            ValueError: If the shape is not specified and the value does not have a shape attribute.
            ValueError: If the dtype is not specified and the value is not a numpy array.
        """
        # NOTE: We should not do any copying here for performance reasons
        if not _compatible_with_numpy(value) and not _compatible_with_dlpack(value):
            raise TypeError(f"Expected an array compatible object, got {type(value)}")
        if shape is None:
            # Obtain the shape from the value
            if not hasattr(value, "shape"):
                raise ValueError(
                    f"Expected an object with a shape attribute, but {type(value)} does not have shape. "
                    "Please specify the shape explicitly."
                )
            self._shape = Shape(getattr(value, "shape"), frozen=True)  # noqa: B009
        else:
            self._shape = shape
            self._shape._frozen = True
        if dtype is None:
            if isinstance(value, np.ndarray):
                self._dtype = _enums.DataType.from_numpy(value.dtype)
            else:
                raise ValueError(
                    "The dtype must be specified when the value is not a numpy array."
                )
        else:
            if isinstance(value, np.ndarray):
                # Make sure the dtype matches the value
                _check_numpy_representation_type(value, dtype)
            # Users are responsible for making sure the dtype matches the value
            # when value is not a numpy array
            self._dtype = dtype

        # View the bfloat16, float8 and int4 types using ml_dtypes
        if isinstance(value, np.ndarray):
            value = _maybe_view_np_array_with_ml_dtypes(value, self._dtype)  # type: ignore[assignment]

        self._raw = value
        self.name = name
        self.doc_string = doc_string
        # Both metadata containers are created lazily on first access.
        self._metadata: _metadata.MetadataStore | None = None
        self._metadata_props = metadata_props

    def __array__(self, dtype: Any = None) -> np.ndarray:
        """Return the data as a numpy array via the raw value's ``__array__`` or DLPack."""
        if isinstance(self._raw, np.ndarray) or _compatible_with_numpy(self._raw):
            return self._raw.__array__(dtype)
        assert _compatible_with_dlpack(
            self._raw
        ), f"Bug: Expected DLPack or Numpy compatible objects, got {type(self._raw)}"
        return np.from_dlpack(self._raw)

    def __dlpack__(self, *, stream: Any = None) -> Any:
        """Export through the raw value's DLPack hook, or through its numpy array."""
        if _compatible_with_dlpack(self._raw):
            return self._raw.__dlpack__(stream=stream)
        return self.__array__().__dlpack__(stream=stream)

    def __dlpack_device__(self) -> tuple[int, int]:
        """Return the DLPack device of the raw value, or of its numpy array."""
        if _compatible_with_dlpack(self._raw):
            return self._raw.__dlpack_device__()
        return self.__array__().__dlpack_device__()

    def __repr__(self) -> str:
        return f"{self._repr_base()}({self._raw!r}, name={self.name!r})"

    @property
    def dtype(self) -> _enums.DataType:
        """The data type of the tensor. Immutable."""
        return self._dtype

    @property
    def shape(self) -> Shape:
        """The shape of the tensor. Immutable."""
        return self._shape

    @property
    def raw(self) -> TArrayCompatible:
        """Backing data of the tensor. Immutable."""
        return self._raw  # type: ignore[return-value]

    def numpy(self) -> np.ndarray:
        """Return the tensor as a numpy array.

        When the data type is not supported by numpy, the dtypes from the ``ml_dtype``
        package are used. The values can be reinterpreted as bit representations
        using the ``.view()`` method.
        """
        if isinstance(self._raw, np.ndarray):
            return self._raw
        # We do not cache the value to save memory
        return self.__array__()

    def tobytes(self) -> bytes:
        """Returns the value as bytes encoded in little endian.

        4-bit data types are packed with ``_type_casting.pack_int4`` before serialization.

        Override this method for more efficient serialization when the raw
        value is not a numpy array.
        """
        # TODO(justinchuby): Support DLPack
        array = self.numpy()
        if self.dtype in {_enums.DataType.INT4, _enums.DataType.UINT4}:
            # Pack the array into int4
            array = _type_casting.pack_int4(array)
        else:
            assert self.dtype.itemsize == array.itemsize, "Bug: The itemsize should match"
        if not _IS_LITTLE_ENDIAN:
            # astype converts the values so the underlying bytes are really
            # swapped. view() would only relabel the dtype and leave the
            # native (big endian) bytes untouched.
            array = array.astype(array.dtype.newbyteorder("<"))
        return array.tobytes()

    @property
    def metadata_props(self) -> dict[str, str]:
        """The metadata properties; an empty dict is created on first access."""
        if self._metadata_props is None:
            self._metadata_props = {}
        return self._metadata_props

    @property
    def meta(self) -> _metadata.MetadataStore:
        """The metadata store for intermediate analysis.

        Write to the :attr:`metadata_props` if you would like the metadata to be serialized
        to the ONNX proto.
        """
        if self._metadata is None:
            self._metadata = _metadata.MetadataStore()
        return self._metadata
459
+
460
+
461
class ExternalTensor(TensorBase, _protocols.TensorProtocol):  # pylint: disable=too-many-ancestors
    """An immutable concrete tensor with its data store on disk.

    This class uses memory mapping to avoid loading the tensor into memory,
    when the data type is supported by numpy. Otherwise, the tensor is loaded
    into memory lazily when accessed.

    Calling :attr:`shape` does not incur IO. Checking shape before loading
    the tensor is recommended if IO overhead and memory usage is a concern.

    To obtain an array, call :meth:`numpy`. To obtain the bytes,
    call :meth:`tobytes`.

    The :attr:`location` must be a relative path conforming to the ONNX
    specification. Given the correct :attr:`base_dir`, the :attr:`path` is computed
    to be the full path to the data file. Users should expect that the :attr:`path`
    always leads to the correct file. At initialization, paths are not checked.
    It is the user's responsibility to ensure the paths are valid and accessible.

    Attributes:
        location: The location of the data file. It is the path relative to the base directory.
        base_dir: The base directory for the external data. It is used to resolve relative paths.
            At serialization, only the :attr:`location` is serialized into the "location" field of the ``TensorProto``.
        path: The path to the data file. This is computed by joining :attr:`base_dir` and :attr:`location`.
        offset: The offset in bytes from the start of the file.
        length: The length of the data in bytes.
        dtype: The data type of the tensor.
        shape: The shape of the tensor.
        name: The name of the tensor. It must be specified.
        doc_string: The documentation string.
        metadata_props: The metadata properties.
    """

    __slots__ = (
        "_array",
        "_base_dir",
        "_dtype",
        "_length",
        "_location",
        "_metadata",
        "_metadata_props",
        "_offset",
        "_shape",
        "doc_string",
        "name",
        "raw",
    )

    def __init__(
        self,
        location: os.PathLike | str,
        offset: int | None,
        length: int | None,
        dtype: _enums.DataType,
        *,
        shape: Shape,
        name: str,
        doc_string: str | None = None,
        metadata_props: dict[str, str] | None = None,
        base_dir: os.PathLike | str = "",
    ) -> None:
        """Initialize an external tensor.

        Args:
            location: The location of the data file. It is the path relative to the base directory.
            offset: The offset in bytes from the start of the file.
            length: The length of the data in bytes.
            dtype: The data type of the tensor.
            shape: The shape of the tensor. It is marked frozen in place.
            name: The name of the tensor.
            doc_string: The documentation string.
            metadata_props: The metadata properties.
            base_dir: The base directory for the external data. It is used to resolve relative paths.

        Raises:
            ValueError: If ``onnx_ir.DEBUG`` is set and ``location`` is an absolute path.
        """
        # NOTE: Do not verify the location by default. This is because the location field
        # in the tensor proto can be anything and we would like deserialization from
        # proto to IR to not fail.
        if onnx_ir.DEBUG:
            if os.path.isabs(location):
                raise ValueError(
                    "The location must be a relative path. Please specify base_dir as well."
                )
        self._location = location
        self._base_dir = base_dir
        self._offset: int | None = offset
        self._length: int | None = length
        self._dtype: _enums.DataType = dtype
        self.name: str = name  # mutable
        self._shape: Shape = shape
        self._shape._frozen = True
        self.doc_string: str | None = doc_string  # mutable
        # The array and the memory map are created lazily by _load().
        self._array: np.ndarray | None = None
        self.raw: mmap.mmap | None = None
        self._metadata_props = metadata_props
        self._metadata: _metadata.MetadataStore | None = None

    @property
    def base_dir(self) -> str | os.PathLike:
        # Mutable
        return self._base_dir

    @base_dir.setter
    def base_dir(self, value: str | os.PathLike) -> None:
        self._base_dir = value

    @property
    def location(self) -> str | os.PathLike:
        # Immutable
        return self._location

    @property
    def path(self) -> str:
        # Immutable, computed
        return os.path.join(self._base_dir, self._location)

    @property
    def offset(self) -> int | None:
        # Immutable
        return self._offset

    @property
    def length(self) -> int | None:
        # Immutable
        return self._length

    @property
    def dtype(self) -> _enums.DataType:
        # Immutable
        return self._dtype

    @property
    def shape(self) -> Shape:
        # Immutable
        return self._shape

    def _load(self):
        """Memory-map the data file and build the numpy array over it.

        For an empty tensor no file is opened: an empty array is created and
        :attr:`raw` stays None.
        """
        assert self._array is None, "Bug: The array should be loaded only once."
        if self.size == 0:
            # When the size is 0, mmap is impossible and meaningless
            self._array = np.empty(self.shape.numpy(), dtype=self.dtype.numpy())
            return
        # Map the whole file into the memory
        # TODO(justinchuby): Verify if this would exhaust the memory address space
        with open(self.path, "rb") as f:
            self.raw = mmap.mmap(
                f.fileno(),
                0,
                access=mmap.ACCESS_READ,
            )
        # Handle the byte order correctly by always using little endian
        dt = np.dtype(self.dtype.numpy()).newbyteorder("<")
        if self.dtype in {_enums.DataType.INT4, _enums.DataType.UINT4}:
            # Use uint8 to read in the full byte. Otherwise ml_dtypes.int4 will clip the values
            dt = np.dtype(np.uint8).newbyteorder("<")
            # Two 4-bit elements per byte, rounded up.
            count = self.size // 2 + self.size % 2
        else:
            count = self.size
        self._array = np.frombuffer(self.raw, dtype=dt, offset=self.offset or 0, count=count)
        shape = self.shape.numpy()
        if self.dtype == _enums.DataType.INT4:
            # Unpack the int4 arrays
            self._array = _type_casting.unpack_int4(self._array, shape)
        elif self.dtype == _enums.DataType.UINT4:
            self._array = _type_casting.unpack_uint4(self._array, shape)
        else:
            self._array = self._array.reshape(shape)

    def __array__(self, dtype: Any = None) -> np.ndarray:
        if self._array is None:
            self._load()
        assert self._array is not None
        return self._array.__array__(dtype)

    def __dlpack__(self, *, stream: Any = None) -> Any:
        raise NotImplementedError(
            "ExternalTensor does not support DLPack because it uses memory mapping. "
            "Call numpy() to get a numpy array instead."
        )

    def __dlpack_device__(self) -> tuple[int, int]:
        raise NotImplementedError(
            "ExternalTensor does not support DLPack because it uses memory mapping. "
            "Call numpy() to get a numpy array instead."
        )

    def __repr__(self) -> str:
        return (
            f"{self._repr_base()}(location='{self.location}', name={self.name!r}, "
            f"offset={self.offset!r}, length={self.length!r}, base_dir={self.base_dir!r})"
        )

    def numpy(self) -> np.ndarray:
        """Return the tensor as a numpy array.

        The data will be memory mapped and will not take up physical memory space.
        """
        if self._array is None:
            self._load()
        assert self._array is not None
        return self._array

    def tobytes(self) -> bytes:
        """Return the bytes of the tensor.

        This will load the tensor into memory. An empty tensor yields ``b""``
        without touching the data file.
        """
        if self.raw is None:
            if self.size == 0:
                # _load() never creates a memory map for an empty tensor, so
                # there is nothing to slice. Calling _load() again after
                # numpy() would also trip its "loaded only once" assertion.
                return b""
            self._load()
        assert self.raw is not None
        offset = self._offset or 0
        length = self._length or self.nbytes
        return self.raw[offset : offset + length]

    def release(self) -> None:
        """Delete all references to the memory buffer and close the memory-mapped file."""
        self._array = None
        if self.raw is not None:
            self.raw.close()
            self.raw = None

    @property
    def metadata_props(self) -> dict[str, str]:
        """The metadata properties; an empty dict is created on first access."""
        if self._metadata_props is None:
            self._metadata_props = {}
        return self._metadata_props

    @property
    def meta(self) -> _metadata.MetadataStore:
        """The metadata store for intermediate analysis.

        Write to the :attr:`metadata_props` if you would like the metadata to be serialized
        to the ONNX proto.
        """
        if self._metadata is None:
            self._metadata = _metadata.MetadataStore()
        return self._metadata
697
+
698
+
699
class StringTensor(TensorBase, _protocols.TensorProtocol):  # pylint: disable=too-many-ancestors
    """Multidimensional array of strings (as binary data to match the string_data field in TensorProto)."""

    __slots__ = (
        "_metadata",
        "_metadata_props",
        "_raw",
        "_shape",
        "doc_string",
        "name",
    )

    def __init__(
        self,
        value: Sequence[bytes] | npt.NDArray[np.bytes_],
        *,
        shape: Shape | None = None,
        name: str | None = None,
        doc_string: str | None = None,
        metadata_props: dict[str, str] | None = None,
    ) -> None:
        """Create a string tensor.

        Args:
            value: The backing data of the tensor. It can be a numpy array or a Sequence of bytes.
            shape: The shape of the tensor. If None, the shape is obtained from the value.
                A shape passed in explicitly is marked frozen in place.
            name: The name of the tensor.
            doc_string: The documentation string.
            metadata_props: The metadata properties.

        Raises:
            ValueError: If ``shape`` is None and ``value`` has no ``shape`` attribute.
        """
        if shape is not None:
            self._shape = shape
            self._shape._frozen = True
        else:
            # No explicit shape: it has to come from the value itself.
            if not hasattr(value, "shape"):
                raise ValueError(
                    f"Expected an object with a shape attribute, but {type(value)} does not have shape. "
                    "Please specify the shape explicitly."
                )
            self._shape = Shape(getattr(value, "shape"), frozen=True)  # noqa: B009
        self._raw = value
        self.name = name
        self.doc_string = doc_string
        # Both metadata containers are created lazily on first access.
        self._metadata: _metadata.MetadataStore | None = None
        self._metadata_props = metadata_props

    def __array__(self, dtype: Any = None) -> np.ndarray:
        """Return the data as a numpy array.

        A numpy backing value is returned as is; a sequence is converted and
        reshaped to the tensor shape.
        """
        backing = self._raw
        if not isinstance(backing, np.ndarray):
            assert isinstance(
                backing, Sequence
            ), f"Bug: Expected a sequence, got {type(self._raw)}"
            converted = np.array(backing, dtype=dtype)
            return converted.reshape(self.shape.numpy())
        return backing

    def __dlpack__(self, *, stream: Any = None) -> Any:
        del stream  # unused
        raise TypeError("StringTensor does not support DLPack")

    def __dlpack_device__(self) -> tuple[int, int]:
        raise TypeError("StringTensor does not support DLPack")

    def __repr__(self) -> str:
        prefix = self._repr_base()
        return f"{prefix}({self._raw!r}, name={self.name!r})"

    @property
    def dtype(self) -> _enums.DataType:
        """The data type of the tensor, always ``DataType.STRING``. Immutable."""
        return _enums.DataType.STRING

    @property
    def shape(self) -> Shape:
        """The shape of the tensor. Immutable."""
        return self._shape

    @property
    def raw(self) -> Sequence[bytes] | npt.NDArray[np.bytes_]:
        """Backing data of the tensor. Immutable."""
        return self._raw  # type: ignore[return-value]

    def numpy(self) -> npt.NDArray[np.bytes_]:
        """Return the tensor as a numpy array."""
        return self.__array__()

    def tobytes(self) -> bytes:
        """Always raises: string tensors are serialized through :meth:`string_data`."""
        raise ValueError("StringTensor does not support tobytes. Use 'string_data' instead.")

    def string_data(self) -> Sequence[bytes]:
        """Return the string data of the tensor.

        A numpy backing value is flattened into a list; a sequence is returned as is.
        """
        backing = self._raw
        return backing.flatten().tolist() if isinstance(backing, np.ndarray) else backing

    @property
    def metadata_props(self) -> dict[str, str]:
        """The metadata properties; an empty dict is created on first access."""
        if self._metadata_props is not None:
            return self._metadata_props
        self._metadata_props = {}
        return self._metadata_props

    @property
    def meta(self) -> _metadata.MetadataStore:
        """The metadata store for intermediate analysis.

        Write to the :attr:`metadata_props` if you would like the metadata to be serialized
        to the ONNX proto.
        """
        if self._metadata is not None:
            return self._metadata
        self._metadata = _metadata.MetadataStore()
        return self._metadata
807
+
808
+
809
class SymbolicDim(_protocols.SymbolicDimProtocol, _display.PrettyPrintable):
    """A non-integer dimension, identified by a string name or None."""

    __slots__ = ("_value",)

    def __init__(self, value: str | None) -> None:
        """Create a symbolic dimension.

        Args:
            value: The value of the dimension. It should not be an int.

        Raises:
            TypeError: If ``value`` is an int.
        """
        if not isinstance(value, int):
            self._value = value
            return
        raise TypeError(
            "The value of a SymbolicDim cannot be an int. "
            "If you are creating a Shape, use int directly instead of SymbolicDim."
        )

    def __eq__(self, other: object) -> bool:
        """Compare by value; a non-SymbolicDim operand is compared to the value directly."""
        if isinstance(other, SymbolicDim):
            return self.value == other.value
        return self.value == other

    def __hash__(self) -> int:
        """Hash of the underlying value, consistent with ``__eq__``."""
        return hash(self.value)

    @property
    def value(self) -> str | None:
        """The underlying value of the dimension."""
        return self._value

    def __str__(self) -> str:
        return f"{self._value}"

    def __repr__(self) -> str:
        class_name = self.__class__.__name__
        return f"{class_name}({self._value})"
842
+
843
+
844
class Shape(_protocols.ShapeProtocol, _display.PrettyPrintable):
    """An ordered list of dimensions with one optional denotation per dimension.

    Every dimension is stored either as an ``int`` or as a :class:`SymbolicDim`.
    A shape can be created frozen, in which case :meth:`__setitem__` refuses to
    change its dimensions.
    """

    # ``_denotations`` is assigned in ``__init__`` and so is declared here
    # together with the other instance attributes.
    __slots__ = ("_denotations", "_dims", "_frozen")

    def __init__(
        self,
        dims: Iterable[int | SymbolicDim | str | None],
        /,
        denotations: Iterable[str | None] | None = None,
        frozen: bool = False,
    ) -> None:
        """Initialize a shape.

        Args:
            dims: The dimensions of the shape. Each dimension can be an integer or a
                SymbolicDim or any Python object. When a ``dim`` is not an integer or a
                SymbolicDim, it is converted to a SymbolicDim.
            denotations: The denotations of the dimensions. If None, the denotations are not set.
                Standard denotation can optionally be used to denote tensor
                dimensions with standard semantic descriptions to ensure
                that operations are applied to the correct axis of a tensor.
                Refer to https://github.com/onnx/onnx/blob/main/docs/DimensionDenotation.md#denotation-definition
                for pre-defined dimension denotations.
            frozen: If True, the shape is immutable and cannot be modified. This
                is useful when the shape is initialized by a Tensor.

        Raises:
            ValueError: If ``denotations`` is provided and its length differs from
                the number of dimensions.
        """
        self._dims: list[int | SymbolicDim] = [
            SymbolicDim(dim) if not isinstance(dim, (int, SymbolicDim)) else dim
            for dim in dims
        ]
        self._denotations: list[str | None] = (
            list(denotations) if denotations is not None else [None] * len(self._dims)
        )
        if len(self._denotations) != len(self._dims):
            raise ValueError(
                "The number of denotations, when provided, must be equal to the number of dimensions."
            )
        self._frozen: bool = frozen

    def copy(self) -> Shape:
        """Return a copy of the shape.

        The copy is built from this shape's dimensions, denotations and frozen flag.
        """
        return Shape(self._dims, self._denotations, self._frozen)

    @property
    def dims(self) -> tuple[int | SymbolicDim, ...]:
        """All dimensions in the shape.

        This property is read-only. Use __getitem__ and __setitem__ to modify the shape or create a new shape.
        """
        return tuple(self._dims)

    def rank(self) -> int:
        """The rank of the shape."""
        return len(self._dims)

    def numpy(self) -> tuple[int, ...]:
        """Return the dimensions as a tuple of ints.

        Raises:
            ValueError: If any dimension is not an ``int``.
        """
        if any(not isinstance(dim, int) for dim in self._dims):
            raise ValueError(f"Cannot convert the shape {self} to a tuple of ints")
        return tuple(self._dims)  # type: ignore

    def __len__(self) -> int:
        return len(self._dims)

    def __iter__(self) -> Iterator[int | SymbolicDim]:
        return iter(self._dims)

    @typing.overload
    def __getitem__(self, index: int) -> int | SymbolicDim: ...

    @typing.overload
    def __getitem__(self, index: slice) -> tuple[int | SymbolicDim, ...]: ...

    def __getitem__(self, index):
        # Index a tuple snapshot so that slicing yields a tuple, not a list.
        return tuple(self._dims)[index]

    def __setitem__(self, index: int, value: int | SymbolicDim | str | None) -> None:
        """Set the dimension at the index.

        A ``str`` or ``None`` value is wrapped in a :class:`SymbolicDim` first.

        Args:
            index: The index of the dimension.
            value: The value of the dimension.

        Raises:
            TypeError: If the shape is frozen and cannot be modified.
            TypeError: If the value is not an int, str, None or SymbolicDim.
        """
        if self._frozen:
            raise TypeError("The shape is frozen and cannot be modified.")
        if isinstance(value, str) or value is None:
            value = SymbolicDim(value)
        if not isinstance(value, (int, SymbolicDim)):
            raise TypeError(f"Expected int, str, None or SymbolicDim, got '{type(value)}'")

        self._dims[index] = value

    def get_denotation(self, index: int) -> str | None:
        """Return the denotation of the dimension at the index.

        Args:
            index: The index of the dimension.

        Returns:
            The denotation of the dimension.
        """
        return self._denotations[index]

    def set_denotation(self, index: int, denotation: str | None) -> None:
        """Set the denotation of the dimension at the index.

        Args:
            index: The index of the dimension.
            denotation: The denotation of the dimension.
        """
        # NOTE(review): unlike __setitem__, this does not consult ``_frozen``;
        # confirm whether denotations of a frozen shape are meant to stay mutable.
        self._denotations[index] = denotation

    def __repr__(self) -> str:
        return f"{self.__class__.__name__}({self._dims!r})"

    def __str__(self) -> str:
        """Return a string representation of the shape.

        E.g. [n,1,3]
        """
        return f"[{','.join([str(dim) for dim in self._dims])}]"

    def __eq__(self, other: object) -> bool:
        """Return True if the shapes are equal.

        Two shapes are equal if all their dimensions are equal. Denotations are
        not compared. Any other iterable is compared element-wise against the
        dimensions; non-iterables are never equal to a shape.
        """
        if isinstance(other, Shape):
            return self._dims == other._dims
        if not isinstance(other, Iterable):
            return False
        return self._dims == list(other)

    def __ne__(self, other: object) -> bool:
        return not self.__eq__(other)
981
+
982
+
983
+ def _quoted(string: str) -> str:
984
+ """Return a quoted string.
985
+
986
+ This function is used to quote value/node names in the IR for better readability.
987
+ """
988
+ return f'"{string}"'
989
+
990
+
991
class Node(_protocols.NodeProtocol, _display.PrettyPrintable):
    """IR Node.

    If the ``graph`` is provided, the node will be added to the graph. Otherwise,
    the user is responsible for calling ``graph.append(node)`` (or other mutation
    methods in :class:`Graph`) to add the node to the graph.

    After the node is initialized, it will add itself as a user of the input values.

    The output values of the node are created during node initialization and are immutable.
    To change the output values, create a new node and replace each of the inputs of
    ``output.uses()`` with the new output values by calling :meth:`replace_input_with`
    on the using nodes of this node's outputs.
    """

    __slots__ = (
        "_attributes",
        "_domain",
        "_graph",
        "_inputs",
        "_metadata",
        "_metadata_props",
        "_name",
        "_op_type",
        "_outputs",
        "_overload",
        "_version",
        "doc_string",
    )

    def __init__(
        self,
        domain: str,
        op_type: str,
        inputs: Iterable[Value | None],
        attributes: Iterable[Attr | RefAttr] = (),
        *,
        overload: str = "",
        num_outputs: int | None = None,
        outputs: Sequence[Value] | None = None,
        version: int | None = None,
        graph: Graph | None = None,
        name: str | None = None,
        doc_string: str | None = None,
        metadata_props: dict[str, str] | None = None,
    ):
        """Initialize a node and add it as a user of the input values.

        Args:
            domain: The domain of the operator. For onnx operators, this is an empty string.
            op_type: The name of the operator.
            inputs: The input values. When an input is None, it is an empty input.
            attributes: The attributes. RefAttr can be used only when the node is defined in a Function.
            overload: The overload name when the node is invoking a function.
            num_outputs: The number of outputs of the node. If not specified, the number is 1.
            outputs: The output values. If None, the outputs are created during initialization.
            version: The version of the operator. If None, the version is unspecified and will follow that of the graph.
            graph: The graph that the node belongs to. If None, the node is not added to any graph.
                A `Node` must belong to zero or one graph.
            name: The name of the node. If None, the node is anonymous.
            doc_string: The documentation string.
            metadata_props: The metadata properties.

        Raises:
            TypeError: If the attributes are not Attr or RefAttr.
            ValueError: If `num_outputs`, when not None, is not the same as the length of the outputs.
            ValueError: If an output value is None, when outputs is specified.
            ValueError: If an output value has a producer set already, when outputs is specified.
        """
        self._name = name
        self._domain: str = domain
        self._op_type: str = op_type
        # NOTE: Make inputs immutable with the assumption that they are not mutated
        # very often. This way all mutations can be tracked.
        # If necessary, we can cache the inputs and outputs as tuples.
        self._inputs: tuple[Value | None, ...] = tuple(inputs)
        # Values belong to their defining nodes. The values list is immutable
        self._outputs: tuple[Value, ...] = self._create_outputs(num_outputs, outputs)
        attributes = tuple(attributes)
        # Only the first element is inspected: this is a cheap guard against the
        # mistake described in the message (passing a dict, which iterates as names).
        if attributes and not isinstance(attributes[0], (Attr, RefAttr)):
            raise TypeError(
                f"Expected the attributes to be Attr or RefAttr, got {type(attributes[0])}. "
                "If you are copying the attributes from another node, make sure you call "
                "node.attributes.values() because it is a dictionary."
            )
        self._attributes: OrderedDict[str, Attr | RefAttr] = OrderedDict(
            (attr.name, attr) for attr in attributes
        )
        self._overload: str = overload
        # TODO(justinchuby): Potentially support a version range
        self._version: int | None = version
        self._metadata: _metadata.MetadataStore | None = None
        self._metadata_props: dict[str, str] | None = metadata_props
        self._graph: Graph | None = graph
        self.doc_string = doc_string

        # Add the node as a use of the inputs
        for i, input_value in enumerate(self._inputs):
            if input_value is not None:
                input_value._add_usage(self, i)  # pylint: disable=protected-access

        # Add the node to the graph if graph is specified
        if self._graph is not None:
            self._graph.append(self)

    def _create_outputs(
        self, num_outputs: int | None, outputs: Sequence[Value] | None
    ) -> tuple[Value, ...]:
        """Check the parameters and create outputs for the node.

        Args:
            num_outputs: The number of outputs of the node.
            outputs: The output values of the node.

        Returns:
            The output values of the node.

        Raises:
            ValueError: If `num_outputs`, when not None, is not the same as the length of the outputs.
            ValueError: If an output value is None.
            ValueError: If an output value has a producer set already.
        """
        # Check num_outputs and outputs are consistent
        if num_outputs is not None and outputs is not None and num_outputs != len(outputs):
            raise ValueError(
                "num_outputs must be the same as len(outputs) when num_outputs is specified. "
                f"num_outputs: {num_outputs}, outputs: {outputs}"
            )
        # 1. If outputs is specified (can be empty []), use the outputs
        if outputs is not None:
            # Check all output values are valid first, so that no value is
            # modified when any of them is rejected.
            for output in outputs:
                if output is None:
                    raise ValueError(f"Output value cannot be None. All outputs: {outputs}")
                if output.producer() is not None:
                    raise ValueError(
                        f"Supplied output value cannot have a producer when used for initializing a Node. "
                        f"Output: {output}. All outputs: {outputs}"
                    )
            for i, output in enumerate(outputs):
                output._producer = self  # pylint: disable=protected-access
                output._index = i  # pylint: disable=protected-access
            return tuple(outputs)

        # 2. If num_outputs is specified, create num_outputs outputs
        if num_outputs is None:
            # Default to 1 output
            num_outputs = 1
        return tuple(Value(self, index=i) for i in range(num_outputs))

    def __str__(self) -> str:
        """Render the node as ``outputs ⬅️ domain::op_type[:overload](inputs) {attributes}``."""
        node_type_text = f"{self._domain}::{self._op_type}"
        if self._overload != "":
            node_type_text += f":{self._overload}"
        inputs_text = (
            "("
            + ", ".join(
                [
                    (
                        f"%{_quoted(x.name) if x.name else 'anonymous:' + str(id(x))}"
                        if x is not None
                        else "None"
                    )
                    for x in self._inputs
                ]
            )
            + ")"
        )
        attributes_text = (
            (" {" + ", ".join([f"{k}={v}" for k, v in self._attributes.items()]) + "}")
            if self._attributes
            else ""
        )
        outputs_text = ", ".join(str(x) for x in self._outputs)

        return f"{outputs_text} ⬅️ {node_type_text}{inputs_text}{attributes_text}"

    def __repr__(self) -> str:
        return (
            f"{self.__class__.__name__}(name={self._name!r}, domain={self._domain!r}, "
            f"op_type={self._op_type!r}, inputs={self._inputs!r}, attributes={self._attributes!r}, "
            f"overload={self._overload!r}, outputs={self._outputs!r}, "
            f"version={self._version!r}, doc_string={self.doc_string!r})"
        )

    @property
    def name(self) -> str | None:
        """The name of the node, or None when the node is anonymous."""
        return self._name

    @name.setter
    def name(self, value: str | None) -> None:
        self._name = value

    @property
    def domain(self) -> str:
        """The domain of the operator."""
        return self._domain

    @domain.setter
    def domain(self, value: str) -> None:
        self._domain = value

    @property
    def version(self) -> int | None:
        """The version of the operator, or None when unspecified."""
        return self._version

    @version.setter
    def version(self, value: int | None) -> None:
        self._version = value

    @property
    def op_type(self) -> str:
        """The name of the operator."""
        return self._op_type

    @op_type.setter
    def op_type(self, value: str) -> None:
        self._op_type = value

    @property
    def overload(self) -> str:
        """The overload name of the operator."""
        return self._overload

    @overload.setter
    def overload(self, value: str) -> None:
        self._overload = value

    @property
    def inputs(self) -> Sequence[Value | None]:
        """The input values, stored as a tuple. Use :meth:`replace_input_with` to change one."""
        return self._inputs

    @inputs.setter
    def inputs(self, _: Any) -> None:
        raise AttributeError(
            "Directly mutating the input sequence is unsupported. Please use Node.replace_input_with() instead."
        )

    def replace_input_with(self, index: int, value: Value | None) -> None:
        """Replace an input with a new value.

        The usage records on both values are updated: the old input (if any) drops
        this ``(node, index)`` use and the new value (if any) gains it.

        Args:
            index: The position of the input to replace. Negative indices are rejected.
            value: The new input, or None for an empty input.

        Raises:
            ValueError: If ``index`` is out of range.
        """
        if index < 0 or index >= len(self.inputs):
            raise ValueError(f"Index out of range: {index}")
        old_input = self.inputs[index]
        self._inputs = tuple(
            value if i == index else old_input for i, old_input in enumerate(self.inputs)
        )
        if old_input is not None:
            old_input._remove_usage(self, index)  # pylint: disable=protected-access
        if value is not None:
            value._add_usage(self, index)  # pylint: disable=protected-access

    def prepend(self, /, nodes: Node | Iterable[Node]) -> None:
        """Insert a node before this node in the list of nodes in the graph.

        It is the same as calling ``graph.insert_before(self, nodes)``.

        Example::

            Before: previous_node -> self
                    previous_node' -> node -> next_node'
            After:  previous_node -> node -> self
                    previous_node' -> next_node'

        Args:
            nodes: A node or a sequence of nodes to put before this node.

        Raises:
            ValueError: If this node does not belong to any graph.
        """
        if self._graph is None:
            raise ValueError("The node to prepend to does not belong to any graph.")
        self._graph.insert_before(self, nodes)

    def append(self, /, nodes: Node | Iterable[Node]) -> None:
        """Insert a node after this node in the list of nodes in the graph.

        It is the same as calling ``graph.insert_after(self, nodes)``.

        Example::

            Before: previous_node -> self
                    previous_node' -> node -> next_node'
            After:  previous_node -> self -> node
                    previous_node' -> next_node'

        Args:
            nodes: A node or a sequence of nodes to put after this node.

        Raises:
            ValueError: If this node does not belong to any graph.
        """
        if self._graph is None:
            raise ValueError("The node to append to does not belong to any graph.")
        self._graph.insert_after(self, nodes)

    @property
    def outputs(self) -> Sequence[Value]:
        """The output values, stored as a tuple. They cannot be reassigned."""
        return self._outputs

    @outputs.setter
    def outputs(self, _: Sequence[Value]) -> None:
        raise AttributeError("outputs is immutable. Please create a new node instead.")

    @property
    def attributes(self) -> OrderedDict[str, Attr | RefAttr]:
        """The attributes of the node, keyed by attribute name."""
        return self._attributes

    @property
    def meta(self) -> _metadata.MetadataStore:
        """The metadata store for intermediate analysis.

        Write to the :attr:`metadata_props` if you would like the metadata to be serialized
        to the ONNX proto.
        """
        # The store is created on first access.
        if self._metadata is None:
            self._metadata = _metadata.MetadataStore()
        return self._metadata

    @property
    def metadata_props(self) -> dict[str, str]:
        """The metadata properties; an empty dict is created on first access if none were given."""
        if self._metadata_props is None:
            self._metadata_props = {}
        return self._metadata_props

    @property
    def graph(self) -> Graph | None:
        """The graph this node is recorded as belonging to, or None."""
        return self._graph

    @graph.setter
    def graph(self, value: Graph | None) -> None:
        self._graph = value

    def op_identifier(self) -> _protocols.OperatorIdentifier:
        """Return the ``(domain, op_type, overload)`` triple of this node."""
        return self.domain, self.op_type, self.overload

    def display(self, *, page: bool = False) -> None:
        """Print the node's name and doc string (when set), then defer to the base ``display``."""
        # Add the node's name to the displayed text
        print(f"Node: {self.name!r}")
        if self.doc_string:
            print(f"Doc: {self.doc_string}")
        super().display(page=page)
1326
+
1327
+
1328
class _TensorTypeBase(_protocols.TypeProtocol, _display.PrettyPrintable, Hashable):
    """Base for the non-recursive tensor types.

    Instances carry a data type and an optional denotation. Equality and the
    repr-based hash take only the class and the data type into account.
    """

    __slots__ = ("_dtype", "denotation")

    def __init__(self, dtype: _enums.DataType, *, denotation: str | None = None) -> None:
        self._dtype = dtype
        self.denotation = denotation

    @property
    def dtype(self) -> _enums.DataType:
        """The data type of the tensor type."""
        return self._dtype

    @dtype.setter
    def dtype(self, value: _enums.DataType) -> None:
        self._dtype = value

    @property
    def elem_type(self) -> _enums.DataType:
        """Return the element type of the tensor type, which is its ``dtype``."""
        return self.dtype

    def __hash__(self) -> int:
        # Hash the repr, which is built from the class name and the dtype.
        text = repr(self)
        return hash(text)

    def __eq__(self, other: object) -> bool:
        # Only instances of exactly the same class can be equal.
        same_class = self.__class__ is other.__class__
        return same_class and self.dtype == other.dtype  # type: ignore[attr-defined]

    def __repr__(self) -> str:
        # Remove "Type" from name for display
        class_name = self.__class__.__name__
        short_name = class_name[: -len("Type")]
        return f"{short_name}({self.dtype!r})"
1362
+
1363
+
1364
class TensorType(_TensorTypeBase):
    """A type that represents a tensor."""

    def __str__(self) -> str:
        """Return the data type formatted as text."""
        # format(x) with an empty format spec is what an f-string placeholder does.
        return format(self.dtype)
1369
+
1370
+
1371
class SparseTensorType(_TensorTypeBase):
    """A type that represents a sparse tensor.

    All behavior is inherited from ``_TensorTypeBase``.
    """
1373
+
1374
+
1375
class _RecursiveTypeBase(_protocols.TypeProtocol, _display.PrettyPrintable, Hashable):
    """Base for recursive types like Optional and Sequence.

    A recursive type wraps another type (``elem_type``). Reading or writing
    ``dtype`` is forwarded to the wrapped type.
    """

    __slots__ = ("_elem_type", "denotation")

    def __init__(
        self, elem_type: _protocols.TypeProtocol, *, denotation: str | None = None
    ) -> None:
        self._elem_type = elem_type
        self.denotation = denotation

    @property
    def dtype(self) -> _enums.DataType:
        """The data type reported by the wrapped element type."""
        inner = self._elem_type
        return inner.dtype

    @dtype.setter
    def dtype(self, value: _enums.DataType) -> None:
        inner = self._elem_type
        inner.dtype = value

    @property
    def elem_type(self) -> _protocols.TypeProtocol:
        """The wrapped element type."""
        return self._elem_type

    def __hash__(self) -> int:
        # Hash the repr, which nests the repr of the element type.
        text = repr(self)
        return hash(text)

    def __eq__(self, other: object) -> bool:
        if not isinstance(other, _RecursiveTypeBase) or self.__class__ != other.__class__:
            return False
        # Recursively compare the type of the elements
        return self.elem_type == other.elem_type

    def __repr__(self) -> str:
        # Remove "Type" from name for display
        class_name = self.__class__.__name__
        short_name = class_name[: -len("Type")]
        return f"{short_name}({self.elem_type!r})"
1413
+
1414
+
1415
class SequenceType(_RecursiveTypeBase):
    """A type that represents a sequence of elements.

    All behavior is inherited from ``_RecursiveTypeBase``.
    """
1417
+
1418
+
1419
class OptionalType(_RecursiveTypeBase):
    """A type that represents an optional element.

    All behavior is inherited from ``_RecursiveTypeBase``.
    """
1421
+
1422
+
1423
class Value(_protocols.ValueProtocol, _display.PrettyPrintable):
    """IR Value.

    A value is a named entity that can be used to represent an input or output of a graph,
    a function, or a node. The information it stores generalizes over ``ValueInfoProto``
    in the ONNX specification.

    A :class:`Value` is either not owned or owned by exactly one node. When the value is
    not owned, it must be an input of a graph or a function, and ``producer`` and
    ``index`` are ``None``.

    When the value is owned by a node, it is an output of the node.
    The node that produces the value can be accessed with :meth:`producer`.
    The index of the output of the node that produces the value can be accessed with
    :meth:`index`.

    To find all the nodes that use this value as an input, call :meth:`uses`.

    To check if the value is an output of a graph, call :meth:`is_graph_output`.

    Attributes:
        name: The name of the value. A value is always named when it is part of a graph.
        shape: The shape of the value.
        type: The type of the value.
        metadata_props: Metadata.
    """

    __slots__ = (
        "_const_value",
        "_index",
        "_metadata",
        "_metadata_props",
        "_name",
        "_producer",
        "_shape",
        "_type",
        "_uses",
        "doc_string",
    )

    def __init__(
        self,
        producer: Node | None = None,
        *,
        index: int | None = None,
        name: str | None = None,
        shape: Shape | None = None,
        type: _protocols.TypeProtocol | None = None,
        doc_string: str | None = None,
        const_value: _protocols.TensorProtocol | None = None,
    ) -> None:
        """Initialize a value.

        Args:
            producer: The node that produces the value.
                It can be ``None`` when the value is initialized before its producer.
            index: The index of the output of the defining node.
            name: The name of the value.
            shape: The shape of the value.
            type: The type of the value.
            doc_string: The documentation string.
            const_value: The constant tensor if the value is constant.
        """
        self._producer: Node | None = producer
        self._index: int | None = index
        self._metadata: _metadata.MetadataStore | None = None
        self._metadata_props: dict[str, str] | None = None

        self._name: str | None = name
        self._shape: Shape | None = shape
        self._type: _protocols.TypeProtocol | None = type
        # TODO(justinchuby): Handle initialization when a const value is provided
        # We can get shape and type information from the const value
        self._const_value = const_value
        # Uses are recorded as (Node, input index) pairs because one node can
        # consume the same value at several input positions. A dict (rather than
        # a set) keeps insertion order so that iteration is deterministic.
        self._uses: dict[tuple[Node, int], None] = {}
        self.doc_string = doc_string

    def __repr__(self) -> str:
        value_name = self.name or "anonymous:" + str(id(self))
        producer = self.producer()
        if producer is None:
            producer_text = "None"
        elif producer.name is None:
            producer_text = f"anonymous_node:{id(producer)}"
        else:
            producer_text = producer.name
        return (
            f"{self.__class__.__name__}({value_name!r}, type={self.type!r}, "
            f"shape={self.shape}, producer={producer_text}, index={self.index()})"
        )

    def __str__(self) -> str:
        value_name = self.name
        if value_name is None:
            value_name = "anonymous:" + str(id(self))
        shape_text = "?" if self.shape is None else str(self.shape)
        type_text = "?" if self.type is None else str(self.type)

        # Quote the name because in reality the names can have invalid characters
        # that make them hard to read
        return f"%{_quoted(value_name)}<{type_text},{shape_text}>"

    def producer(self) -> Node | None:
        """The node that produces this value.

        When producer is ``None``, the value does not belong to a node, and is
        typically a graph input or an initializer.
        """
        return self._producer

    def index(self) -> int | None:
        """The index of the output of the defining node."""
        return self._index

    def uses(self) -> Collection[tuple[Node, int]]:
        """Return the uses of the value.

        Each use is a tuple of ``(Node, index)`` where the index is the index of the input
        of the node. For example, if ``node.inputs[1] == value``, then the use is ``(node, 1)``.
        """
        return self._uses.keys()

    def _add_usage(self, use: Node, index: int) -> None:
        """Record that ``use`` consumes this value at input position ``index``.

        This is an internal method. It should only be called by the Node class.
        """
        self._uses[(use, index)] = None

    def _remove_usage(self, use: Node, index: int) -> None:
        """Forget that ``use`` consumes this value at input position ``index``.

        This is an internal method. It should only be called by the Node class.
        """
        self._uses.pop((use, index))

    @property
    def name(self) -> str | None:
        """The name of the value. Setting it also renames the constant tensor, if one is attached."""
        return self._name

    @name.setter
    def name(self, value: str | None) -> None:
        tensor = self._const_value
        if tensor is not None:
            tensor.name = value
        self._name = value

    @property
    def type(self) -> _protocols.TypeProtocol | None:
        """The type of the tensor.

        Example types can be ``TensorType``, ``SparseTensorType``, ``SequenceType``, ``OptionalType``.
        To obtain the data type of the tensor, use ``type.dtype`` or conveniently
        :attr:`dtype`.
        """
        return self._type

    @type.setter
    def type(self, value: _protocols.TypeProtocol | None) -> None:
        self._type = value

    @property
    def dtype(self) -> _enums.DataType | None:
        """The data type of the tensor, or None when no type is set."""
        current_type = self._type
        return None if current_type is None else current_type.dtype

    @dtype.setter
    def dtype(self, value: _enums.DataType) -> None:
        """Set the data type of the tensor.

        If the type is not set, it will be initialized to a new TensorType. To
        set the type as other types like ``SequenceType``, initialize the type
        then set :attr:`type` instead.
        """
        if self._type is not None:
            self._type.dtype = value
            return
        self._type = TensorType(value)

    @property
    def shape(self) -> Shape | None:
        """The shape of the value, or None when it is not set."""
        return self._shape

    @shape.setter
    def shape(self, value: Shape | None) -> None:
        if value is not None and not isinstance(value, Shape):
            raise TypeError(f"Expected value to be a Shape or None, got '{type(value)}'")
        self._shape = value

    @property
    def const_value(
        self,
    ) -> _protocols.TensorProtocol | None:
        """A concrete value.

        The value can be backed by different raw data types, such as numpy arrays.
        The only guarantee is that it conforms to TensorProtocol.
        """
        return self._const_value

    @const_value.setter
    def const_value(
        self,
        value: _protocols.TensorProtocol | None,
    ) -> None:
        # The type check runs only when the package-level DEBUG flag is set.
        if (
            onnx_ir.DEBUG
            and value is not None
            and not isinstance(value, _protocols.TensorProtocol)
        ):
            raise TypeError(
                f"Expected value to be a TensorProtocol or None, got '{type(value)}'"
            )
        self._const_value = value

    @property
    def meta(self) -> _metadata.MetadataStore:
        """The metadata store for intermediate analysis.

        Write to the :attr:`metadata_props` if you would like the metadata to be serialized
        to the ONNX proto.
        """
        # The store is created on first access.
        if self._metadata is None:
            self._metadata = _metadata.MetadataStore()
        return self._metadata

    @property
    def metadata_props(self) -> dict[str, str]:
        """The metadata properties; an empty dict is created on first access."""
        if self._metadata_props is None:
            self._metadata_props = {}
        return self._metadata_props

    def is_graph_output(self) -> bool:
        """Whether the value is an output of a graph."""
        producer = self.producer()
        if producer is None:
            return False
        graph = producer.graph
        if graph is None:
            return False
        # Cannot use `in` because __eq__ may be defined by subclasses, even though
        # it is not recommended
        for output in graph.outputs:
            if output is self:
                return True
        return False
1664
+
1665
+
1666
def Input(
    name: str | None = None,
    shape: Shape | None = None,
    type: _protocols.TypeProtocol | None = None,
    doc_string: str | None = None,
) -> Value:
    """Create an input of a Graph or a Function.

    The result is a :class:`Value` built from the given arguments only, i.e.
    ``Value(name=name, shape=shape, type=type, doc_string=doc_string)``.
    """
    # NOTE: The function name is capitalized to maintain API backward compatibility.
    graph_input = Value(name=name, shape=shape, type=type, doc_string=doc_string)
    return graph_input
1680
+
1681
+
1682
+ def _check_node_safe_to_remove(
1683
+ node: Node, to_remove: AbstractSet[Node], graph_outputs: AbstractSet[Value]
1684
+ ) -> None:
1685
+ """Check if a node is safe to remove.
1686
+
1687
+ 1. It checks to make sure there are no users of the node that are not
1688
+ to be removed before removing it.
1689
+ 2. It checks the node does not contribute to any graph outputs.
1690
+
1691
+ This check is typically O(1) assuming the number of uses of the node is small
1692
+
1693
+ Args:
1694
+ node: The node to check.
1695
+ to_remove: A set of nodes that are to be removed.
1696
+ This set is used to check if the node is still being used by other
1697
+ nodes that are not to be removed.
1698
+ graph_outputs: A set of values that are outputs of the graph.
1699
+
1700
+ Raises:
1701
+ ValueError: If the node does not belong to this graph or if there are users of the node.
1702
+ ValueError: If the node is still being used by other nodes not to be removed.
1703
+ """
1704
+ for output in node.outputs:
1705
+ if output in graph_outputs:
1706
+ raise ValueError(
1707
+ f"Node '{node!r}' is still an output of the graph and cannot be removed when safe=True."
1708
+ )
1709
+ uses_not_to_remove = [user for user, _ in output.uses() if user not in to_remove]
1710
+ if uses_not_to_remove:
1711
+ raise ValueError(
1712
+ f"Output value '{output!r}' is still being used by other nodes that are not to be "
1713
+ f"removed. All of its users that is not being removed: {uses_not_to_remove!r}. "
1714
+ "Please make sure these nodes are no longer using the output value."
1715
+ )
1716
+
1717
+
1718
class Graph(_protocols.GraphProtocol, Sequence[Node], _display.PrettyPrintable):
    """IR Graph.

    Graph represents a computation graph. In addition to the ONNX specification
    specified fields, it also contains a mapping of :attr:`opset_imports`. This
    allows different subgraphs to import different opsets. It is the responsibility
    of the deserializer to reconcile the different opsets.

    The `nodes` are not guaranteed to be topologically sorted. But the
    iteration order should be deterministic across different runs. It is the
    responsibility of the user to maintain a topological order of the nodes.

    Note that there is not a ``node`` attribute in the Graph. The Graph can be
    seen as a Sequence of nodes and should be used as such. For example, to obtain
    all nodes as a list, call ``list(graph)``.

    Attributes:
        name: The name of the graph.
        inputs: The input values of the graph.
        outputs: The output values of the graph.
        initializers: The initializers in the graph.
        doc_string: Documentation string.
        opset_imports: Opsets imported by the graph.
        metadata_props: Metadata that will be serialized to the ONNX file.
        meta: Metadata store for graph transform passes.
    """

    __slots__ = (
        "_doc_string",
        "_initializers",
        "_inputs",
        "_metadata",
        "_metadata_props",
        "_name_authority",
        "_nodes",
        "_opset_imports",
        "_outputs",
        "name",
    )

    def __init__(
        self,
        inputs: Sequence[Value],
        outputs: Sequence[Value],
        *,
        nodes: Iterable[Node],
        initializers: Sequence[Value] = (),
        doc_string: str | None = None,
        opset_imports: dict[str, int] | None = None,
        name: str | None = None,
        metadata_props: dict[str, str] | None = None,
    ):
        """Initialize the graph.

        Args:
            inputs: The input values of the graph. Copied into a new list.
            outputs: The output values of the graph. Copied into a new list.
            nodes: The nodes to add to the graph, in order.
            initializers: The initializer values. Each must have a name.
            doc_string: Documentation string.
            opset_imports: Opsets imported by the graph.
            name: The name of the graph.
            metadata_props: Metadata that will be serialized to the ONNX file.

        Raises:
            TypeError: If an initializer is a string (typically from iterating
                over another graph's ``initializers`` dictionary directly).
            ValueError: If an initializer does not have a name, or if any node
                belongs to another graph.
        """
        self.name = name

        # Private fields that are not to be accessed by any other classes
        self._inputs = list(inputs)
        self._outputs = list(outputs)
        self._initializers = {}
        for initializer in initializers:
            if isinstance(initializer, str):
                raise TypeError(
                    "Initializer must be a Value, not a string. "
                    "If you are copying the initializers from another graph, "
                    "make sure you call graph.initializers.values() because it is a dictionary."
                )
            if initializer.name is None:
                raise ValueError(f"Initializer must have a name: {initializer}")
            self._initializers[initializer.name] = initializer
        self._doc_string = doc_string
        self._opset_imports = opset_imports or {}
        self._metadata: _metadata.MetadataStore | None = None
        self._metadata_props: dict[str, str] | None = metadata_props
        self._nodes: _linked_list.DoublyLinkedSet[Node] = _linked_list.DoublyLinkedSet()
        # Be sure to initialize the name authority before extending the nodes
        # because it is used to name the nodes and their outputs
        self._name_authority = _name_authority.NameAuthority()
        # Call self.extend not self._nodes.extend so the graph reference is added to the nodes
        self.extend(nodes)

    @property
    def inputs(self) -> list[Value]:
        """The (mutable) list of graph input values."""
        return self._inputs

    @property
    def outputs(self) -> list[Value]:
        """The (mutable) list of graph output values."""
        return self._outputs

    @property
    def initializers(self) -> dict[str, Value]:
        """The (mutable) mapping from initializer name to initializer value."""
        return self._initializers

    @property
    def doc_string(self) -> str | None:
        """The documentation string of the graph."""
        return self._doc_string

    @doc_string.setter
    def doc_string(self, value: str | None) -> None:
        self._doc_string = value

    @property
    def opset_imports(self) -> dict[str, int]:
        """The (mutable) mapping from opset domain to version."""
        return self._opset_imports

    def __getitem__(self, index: int) -> Node:
        return self._nodes[index]

    def __len__(self) -> int:
        return len(self._nodes)

    def __iter__(self) -> Iterator[Node]:
        return iter(self._nodes)

    def __reversed__(self) -> Iterator[Node]:
        return reversed(self._nodes)

    def _set_node_graph_to_self_and_assign_names(self, node: Node) -> Node:
        """Set the graph reference for the node and assign names to it and its outputs if they don't have one.

        Args:
            node: The node to take ownership of.

        Returns:
            The same node, for convenient use in comprehensions.

        Raises:
            ValueError: If the node already belongs to a different graph.
        """
        if node.graph is not None and node.graph is not self:
            raise ValueError(
                f"The node '{node!r}' belongs to another graph. Please remove it first with Graph.remove()."
            )
        # Give the node and its output values names if they do not have one
        self._name_authority.register_or_name_node(node)
        for value in node._outputs:  # pylint: disable=protected-access
            self._name_authority.register_or_name_value(value)
        node.graph = self
        return node

    def node(self, index_or_name: int | str, /) -> Node:
        """Get a node by index or name.

        This is an O(n) operation. Getting nodes on the ends of the graph (0 or -1) is O(1).

        .. note::
            If you need repeated random access, consider turning it into a list with ``list(graph)`` .
            Or a dictionary for repeated access by name: ``{node.name: node for node in graph}`` .

        When a name is provided and if there are multiple nodes with the same name,
        the first node with the name is returned.

        Args:
            index_or_name: The index or name of the node.

        Returns:
            The node if found.

        Raises:
            IndexError: If the index is out of range.
            ValueError: If the node with the given name is not found.
        """
        # NOTE: This is a method specific to Graph, not required by the protocol unless proven
        if isinstance(index_or_name, int):
            return self[index_or_name]
        for node in self:
            if node.name == index_or_name:
                return node
        raise ValueError(f"Node with name '{index_or_name}' not found.")

    def num_nodes(self) -> int:
        """Get the number of nodes in the graph in O(1) time.

        Note that this method returns the number of nodes this graph directly contains.
        It does not count nodes in subgraphs.

        This is an alias for ``len(graph)``. Use this if you prefer a more descriptive
        name for readability.
        """
        # NOTE: This is a method specific to Graph, not required by the protocol unless proven
        return len(self)

    # Mutation methods
    def append(self, node: Node, /) -> None:
        """Append a node to the graph in O(1) time.

        Unique names will be assigned to the node and its values if any name is ``None``.

        Args:
            node: The node to append.

        Raises:
            ValueError: If the node belongs to another graph.
        """
        self._set_node_graph_to_self_and_assign_names(node)
        self._nodes.append(node)

    def extend(self, nodes: Iterable[Node], /) -> None:
        """Extend the graph with the given nodes in O(#new_nodes) time.

        Unique names will be assigned to the node and its values if any name is ``None``.

        Args:
            nodes: The nodes to extend the graph with.

        Raises:
            ValueError: If any node belongs to another graph.
        """
        nodes = [self._set_node_graph_to_self_and_assign_names(node) for node in nodes]
        self._nodes.extend(nodes)

    def remove(self, nodes: Node | Iterable[Node], /, safe: bool = False) -> None:
        """Remove nodes from the graph in O(#num of nodes to remove) time.

        If any errors are raised, to ensure the graph is not left in an inconsistent state,
        the graph is not modified.

        Args:
            nodes: The node (or nodes) to remove.
            safe: If True, performs the following actions before removal:

                1. It checks to make sure there are no users of the node that are not
                    to be removed before removing it.
                2. It checks the node does not contribute to any graph outputs.
                3. It removes references to all inputs so it is no longer a user of other nodes.

        Raises:
            ValueError: If any node to remove does not belong to this graph.
            ValueError: (When ``safe=True``) If the node does not belong to this graph or if there are users of the node.
            ValueError: (When ``safe=True``) If the node is still being used by other nodes not to be removed.
        """
        if not isinstance(nodes, Iterable):
            nodes_set: AbstractSet[Node] = {nodes}
        else:
            nodes_set = frozenset(nodes)
        graph_outputs = frozenset(self.outputs)
        # First pass: validate everything so that a failure leaves the graph untouched
        for node in nodes_set:
            if node.graph is not self:
                raise ValueError(f"The node '{node!r}' does not belong to this graph.")
            if safe:
                # Check 1, 2
                _check_node_safe_to_remove(node, nodes_set, graph_outputs)
        # Second pass: actually detach and remove the nodes
        for node in nodes_set:
            if safe:
                # 3. Detach from all inputs so that it is no longer a user of other nodes
                for i in range(len(node.inputs)):
                    node.replace_input_with(i, None)
            # Set attributes to remove the node from this graph
            node.graph = None
            self._nodes.remove(node)

    def insert_after(self, node: Node, new_nodes: Iterable[Node] | Node, /) -> None:
        """Insert new nodes after the given node in O(#new_nodes) time.

        Unique names will be assigned to the node and its values if any name is ``None``.

        Args:
            node: The node to insert after.
            new_nodes: The new nodes to insert.

        Raises:
            ValueError: If any node belongs to another graph.
        """
        if isinstance(new_nodes, Node):
            new_nodes = (new_nodes,)
        # Use a distinct loop variable so the anchor ``node`` is visibly untouched
        new_nodes = [
            self._set_node_graph_to_self_and_assign_names(new_node) for new_node in new_nodes
        ]
        self._nodes.insert_after(node, new_nodes)

    def insert_before(self, node: Node, new_nodes: Iterable[Node] | Node, /) -> None:
        """Insert new nodes before the given node in O(#new_nodes) time.

        Unique names will be assigned to the node and its values if any name is ``None``.

        Args:
            node: The node to insert before.
            new_nodes: The new nodes to insert.

        Raises:
            ValueError: If any node belongs to another graph.
        """
        if isinstance(new_nodes, Node):
            new_nodes = (new_nodes,)
        # Use a distinct loop variable so the anchor ``node`` is visibly untouched
        new_nodes = [
            self._set_node_graph_to_self_and_assign_names(new_node) for new_node in new_nodes
        ]
        self._nodes.insert_before(node, new_nodes)

    def sort(self) -> None:
        """Perform a topological sort of this graph and all subgraphs in O(#nodes + #values) time.

        This sort is stable. It preserves the original order as much as possible.

        Reference: https://github.com/madelson/MedallionTopologicalSort#stable-sort

        Raises:
            ValueError: If the graph contains a cycle, making topological sorting impossible.
        """
        # Obtain all nodes from the graph and its subgraphs for sorting
        nodes = list(onnx_ir.traversal.RecursiveGraphIterator(self))
        # Store the sorted nodes of each subgraph
        sorted_nodes_by_graph: dict[Graph, list[Node]] = {
            graph: [] for graph in {node.graph for node in nodes if node.graph is not None}
        }
        # TODO: Explain why we need to store direct predecessors and children and why
        # we only need to store the direct ones

        # The depth of a node is defined as the number of direct children it has
        node_depth: dict[Node, int] = dict.fromkeys(nodes, 0)
        # Direct predecessors of a node
        node_predecessors: dict[Node, list[Node]] = {node: [] for node in nodes}
        # Store the negative index of the nodes because heapq is a min heap and we
        # want to pop the node with largest index value first, effectively turning
        # it to a max heap
        neg_node_index: dict[Node, int] = {node: -i for i, node in enumerate(nodes)}

        def add_predecessor(child: Node, predecessor: Node | None) -> None:
            """Add a predecessor of a node, and increment the depth of the predecessor."""
            if predecessor is None:
                return
            if predecessor not in node_depth:
                # The producer is not among the nodes being sorted (presumably a
                # node in an outer scope when sort() is called on a subgraph).
                # It imposes no ordering constraint on the nodes we own, so skip
                # it instead of failing with a KeyError below.
                return
            node_predecessors[child].append(predecessor)
            node_depth[predecessor] += 1

        # 1. Build the direct predecessors of each node and the depth of each node
        # for sorting topologically using Kahn's algorithm.
        # Note that when a node contains graph attributes (aka. has subgraphs),
        # we consider all nodes in the subgraphs *predecessors* of this node. This
        # way we ensure the implicit dependencies of the subgraphs are captured
        # as predecessors of the node.
        for node in nodes:
            # All producers of input values are considered as direct predecessors.
            for input_value in node.inputs:
                if input_value is None:
                    continue
                predecessor_node = input_value.producer()
                add_predecessor(node, predecessor_node)
            # All nodes in attribute graphs are considered as direct predecessors.
            for attr in node.attributes.values():
                if not isinstance(attr, Attr):
                    continue
                # A nice thing about this algorithm is that we only need to record
                # direct predecessors. This continues to be true even with subgraphs.
                # When a node in a subgraph (a) contains its own subgraphs (b), the
                # node in subgraphs (b) are guaranteed to appear before the node
                # in (a).
                if attr.type == _enums.AttributeType.GRAPH:
                    for predecessor_node in attr.value:
                        add_predecessor(node, predecessor_node)
                elif attr.type == _enums.AttributeType.GRAPHS:
                    for attribute_graph in attr.value:
                        for predecessor_node in attribute_graph:
                            add_predecessor(node, predecessor_node)

        # 2. Priority Queue: Track nodes with zero direct children in a priority queue,
        # using NEGATIVE original index for ordering.
        # This ensures nodes appearing LATER in the original order are processed EARLIER.
        # We get REVERSED topological order of each subgraph.
        priority_queue: list[tuple[int, Node]] = [
            (neg_node_index[node], node) for node in nodes if node_depth[node] == 0
        ]
        heapq.heapify(priority_queue)

        # 3. Topological Sort:
        num_of_sorted_nodes = 0
        while priority_queue:
            # Pop the node with the most negative index and add it to the sorted nodes by subgraph.
            _, current_node = heapq.heappop(priority_queue)
            assert current_node.graph is not None
            sorted_nodes_by_graph[current_node.graph].append(current_node)
            num_of_sorted_nodes += 1
            # Decrement the depth of its predecessors. If any predecessor node has zero direct children, push it into the queue.
            for predecessor_node in node_predecessors[current_node]:
                node_depth[predecessor_node] -= 1
                if node_depth[predecessor_node] == 0:
                    heapq.heappush(
                        priority_queue, (neg_node_index[predecessor_node], predecessor_node)
                    )

        # 4. Cycle Check: Ensure all nodes are processed. If not, raise a ValueError indicating a cycle.
        if num_of_sorted_nodes != len(nodes):
            raise ValueError("Graph contains a cycle, topological sort is not possible.")

        # 5. Reverse: Reverse the sorted nodes of each subgraph to get the topological order.
        for graph, sorted_nodes in sorted_nodes_by_graph.items():
            # The graph container ensures all the nodes are unique so we can safely extend
            graph.extend(reversed(sorted_nodes))

    # End of mutation methods

    @property
    def meta(self) -> _metadata.MetadataStore:
        """The metadata store for intermediate analysis.

        Write to the :attr:`metadata_props` if you would like the metadata to be serialized
        to the ONNX proto.
        """
        if self._metadata is None:
            self._metadata = _metadata.MetadataStore()
        return self._metadata

    @property
    def metadata_props(self) -> dict[str, str]:
        """The metadata properties; created lazily as an empty dict on first access."""
        if self._metadata_props is None:
            self._metadata_props = {}
        return self._metadata_props

    def __str__(self) -> str:
        return _graph_str(self)

    def __repr__(self) -> str:
        return _graph_repr(self)
2114
+
2115
+
2116
+ def _graph_str(graph: Graph | GraphView) -> str:
2117
+ """Return a string representation of the graph."""
2118
+ # TODO(justinchuby): Show docstrings and metadata
2119
+ inputs_text = "\n" + ",\n".join(str(x) for x in graph.inputs)
2120
+ outputs_text = "\n" + ",\n".join(str(x) for x in graph.outputs)
2121
+ initializers_text = ",\n".join(str(x) for x in graph.initializers.values())
2122
+ if initializers_text:
2123
+ initializers_text = (
2124
+ "\ninitializers=(\n" + textwrap.indent(initializers_text, " " * 4) + "\n),"
2125
+ )
2126
+ signature = f"""\
2127
+ graph(
2128
+ name={graph.name or 'anonymous_graph:' + str(id(graph))},
2129
+ inputs=({textwrap.indent(inputs_text, ' ' * 8)}
2130
+ ),
2131
+ outputs=({textwrap.indent(outputs_text, ' ' * 8)}
2132
+ ),{textwrap.indent(initializers_text, ' ' * 4)}
2133
+ )"""
2134
+ node_count = len(graph)
2135
+ number_width = len(str(node_count))
2136
+ node_lines = []
2137
+ for i, node in enumerate(graph):
2138
+ node_name = node.name if node.name else f":anonymous_node:{id(node)}"
2139
+ node_text = f"# {node_name}\n{node}"
2140
+ indented_node_text = textwrap.indent(node_text, " " * (number_width + 4))
2141
+ # Remove the leading spaces
2142
+ indented_node_text = indented_node_text.strip()
2143
+ node_lines.append(f"{i:>{number_width}} | {indented_node_text}")
2144
+ returns = ", ".join(str(x) for x in graph.outputs)
2145
+ body = (
2146
+ "{\n"
2147
+ + textwrap.indent("\n".join(node_lines), " " * 4)
2148
+ + textwrap.indent(f"\nreturn {returns}", " " * 4)
2149
+ + "\n}"
2150
+ )
2151
+
2152
+ return f"{signature} {body}"
2153
+
2154
+
2155
+ def _graph_repr(graph: Graph | GraphView) -> str:
2156
+ """Return an repr string of the graph."""
2157
+ inputs_text = "\n" + ",\n".join(str(x) for x in graph.inputs)
2158
+ outputs_text = "\n" + ",\n".join(str(x) for x in graph.outputs)
2159
+ initializers_text = ",\n".join(str(x) for x in graph.initializers.values())
2160
+ if initializers_text:
2161
+ initializers_text = (
2162
+ "\ninitializers=(\n" + textwrap.indent(initializers_text, " " * 4) + "\n),"
2163
+ )
2164
+ return f"""\
2165
+ {graph.__class__.__name__}(
2166
+ name={graph.name or 'anonymous_graph:' + str(id(graph))!r},
2167
+ inputs=({textwrap.indent(inputs_text, ' ' * 8)}
2168
+ ),
2169
+ outputs=({textwrap.indent(outputs_text, ' ' * 8)}
2170
+ ),{textwrap.indent(initializers_text, ' ' * 4)}
2171
+ len()={len(graph)}
2172
+ )"""
2173
+
2174
+
2175
class GraphView(Sequence[Node], _display.PrettyPrintable):
    """A read-only view on a graph.

    The GraphView is useful for analysis of a subgraph. It can be initialized
    with a subset of nodes from a :class:`Graph`. Creating GraphView does not
    change the ownership of the nodes, and so it is possible to create multiple
    GraphViews that contain the same nodes. If the underlying nodes / connections
    are mutated, the mutation will be reflected in all views as well.

    The graph view can be serialized to ONNX::

        graph_proto = ir.serde.serialize_graph(graph_view)

    It can also be used to create a model::

        model = ir.Model(graph_view, ir_version=8)
        model_proto = ir.serde.serialize_model(model)

    The model created with a GraphView will have a fixed topology, and its graph
    will remain read-only as a GraphView. No copying will be done during the
    initialization process.

    Attributes:
        name: The name of the graph.
        inputs: The input values of the graph.
        outputs: The output values of the graph.
        initializers: The initializers in the graph.
        doc_string: Documentation string.
        opset_imports: Opsets imported by the graph.
        metadata_props: Metadata that will be serialized to the ONNX file.
        meta: Metadata store for graph transform passes.
    """

    # NOTE: the node tuple is stored in ``_nodes``; the slot name must match the
    # attribute that __init__ assigns.
    __slots__ = (
        "_metadata",
        "_metadata_props",
        "_nodes",
        "doc_string",
        "initializers",
        "inputs",
        "name",
        "opset_imports",
        "outputs",
    )

    def __init__(
        self,
        inputs: Sequence[Value],
        outputs: Sequence[Value],
        *,
        nodes: Iterable[Node],
        initializers: Sequence[_protocols.ValueProtocol] = (),
        doc_string: str | None = None,
        opset_imports: dict[str, int] | None = None,
        name: str | None = None,
        metadata_props: dict[str, str] | None = None,
    ):
        """Initialize the view.

        Args:
            inputs: The input values of the view. Stored as a tuple.
            outputs: The output values of the view. Stored as a tuple.
            nodes: The nodes in the view. Stored as a tuple; the nodes'
                ``graph`` references are not modified.
            initializers: The initializer values. Each must have a name.
            doc_string: Documentation string.
            opset_imports: Opsets imported by the graph.
            name: The name of the graph.
            metadata_props: Metadata that will be serialized to the ONNX file.

        Raises:
            ValueError: If an initializer does not have a name.
        """
        self.name = name
        self.inputs = tuple(inputs)
        self.outputs = tuple(outputs)
        # Validate and index the initializers in a single pass
        initializers_by_name = {}
        for initializer in initializers:
            if initializer.name is None:
                raise ValueError(f"Initializer must have a name: {initializer}")
            initializers_by_name[initializer.name] = initializer
        self.initializers = initializers_by_name
        self.doc_string = doc_string
        self.opset_imports = opset_imports or {}
        self._metadata: _metadata.MetadataStore | None = None
        self._metadata_props: dict[str, str] | None = metadata_props
        self._nodes: tuple[Node, ...] = tuple(nodes)

    def __getitem__(self, index: int) -> Node:
        return self._nodes[index]

    def __len__(self) -> int:
        return len(self._nodes)

    def __iter__(self) -> Iterator[Node]:
        return iter(self._nodes)

    def __reversed__(self) -> Iterator[Node]:
        return reversed(self._nodes)

    @property
    def meta(self) -> _metadata.MetadataStore:
        """The metadata store for intermediate analysis.

        Write to the :attr:`metadata_props` if you would like the metadata to be serialized
        to the ONNX proto.
        """
        if self._metadata is None:
            self._metadata = _metadata.MetadataStore()
        return self._metadata

    @property
    def metadata_props(self) -> dict[str, str]:
        """The metadata properties; created lazily as an empty dict on first access."""
        if self._metadata_props is None:
            self._metadata_props = {}
        return self._metadata_props

    def __str__(self) -> str:
        return _graph_str(self)

    def __repr__(self) -> str:
        return _graph_repr(self)
2279
+
2280
+
2281
class Model(_protocols.ModelProtocol, _display.PrettyPrintable):
    """IR Model.

    A model is a container for a graph and metadata.

    Attributes:
        graph: The graph of the model.
        ir_version: The version of the IR.
        producer_name: The name of the producer.
        producer_version: The version of the producer.
        domain: The domain of the model.
        model_version: The version of the model.
        doc_string: Documentation string.
        functions: The functions defined in the model.
        metadata_props: Metadata.
        meta: Metadata store for intermediate analysis.
    """

    # The docstring above must come first in the class body: a string literal
    # placed after another statement is not picked up as ``__doc__``.
    __slots__ = (
        "_functions",
        "_metadata",
        "_metadata_props",
        "doc_string",
        "domain",
        "graph",
        "ir_version",
        "model_version",
        "producer_name",
        "producer_version",
    )

    def __init__(
        self,
        graph: Graph,
        *,
        ir_version: int,
        producer_name: str | None = None,
        producer_version: str | None = None,
        domain: str | None = None,
        model_version: int | None = None,
        doc_string: str | None = None,
        functions: Sequence[Function] = (),
        meta_data_props: dict[str, str] | None = None,
    ) -> None:
        """Initialize the model.

        Args:
            graph: The main graph of the model.
            ir_version: The version of the IR.
            producer_name: The name of the producer.
            producer_version: The version of the producer.
            domain: The domain of the model.
            model_version: The version of the model.
            doc_string: Documentation string.
            functions: The functions defined in the model. They are indexed by
                their ``identifier()``.
            meta_data_props: Initial value for :attr:`metadata_props`.
        """
        self.graph: Graph = graph
        self.ir_version = ir_version
        self.producer_name = producer_name
        self.producer_version = producer_version
        self.domain = domain
        self.model_version = model_version
        self.doc_string = doc_string
        self._functions = {func.identifier(): func for func in functions}
        self._metadata: _metadata.MetadataStore | None = None
        self._metadata_props: dict[str, str] | None = meta_data_props

    @property
    def functions(self) -> dict[_protocols.OperatorIdentifier, Function]:
        """The (mutable) mapping from function identifier to function."""
        return self._functions

    @property
    def opset_imports(self) -> dict[str, int]:
        """The opset imports of the model, which are those of its main graph."""
        return self.graph.opset_imports

    @property
    def meta(self) -> _metadata.MetadataStore:
        """The metadata store for intermediate analysis.

        Write to the :attr:`metadata_props` if you would like the metadata to be serialized
        to the ONNX proto.
        """
        if self._metadata is None:
            self._metadata = _metadata.MetadataStore()
        return self._metadata

    @property
    def metadata_props(self) -> dict[str, str]:
        """The metadata properties; created lazily as an empty dict on first access."""
        if self._metadata_props is None:
            self._metadata_props = {}
        return self._metadata_props

    def __str__(self) -> str:
        # TODO(justinchuby): Show docstrings and metadata
        signature = f"""\
<
    ir_version={self.ir_version!r},
    opset_imports={self.opset_imports!r},
    producer_name={self.producer_name!r},
    producer_version={self.producer_version!r},
    domain={self.domain!r},
    model_version={self.model_version!r},
>"""
        graph_text = str(self.graph)
        functions_text = "\n\n".join(str(func) for func in self.functions.values())
        return f"{signature}\n{graph_text}" + f"\n\n{functions_text}"

    def __repr__(self) -> str:
        return f"""\
Model(
    ir_version={self.ir_version!r},
    opset_imports={self.opset_imports!r},
    producer_name={self.producer_name!r},
    producer_version={self.producer_version!r},
    domain={self.domain!r},
    model_version={self.model_version!r},
    functions={self.functions!r},
    graph={textwrap.indent(repr(self.graph), ' ' * 4).strip()}
)"""
2386
+
2387
+
2388
+ class Function(_protocols.FunctionProtocol, Sequence[Node], _display.PrettyPrintable):
2389
+ """IR functions.
2390
+
2391
+ Like a graph, a function can have nodes that are not topologically sorted. It is
2392
+ the responsibility of the user to maintain a topological order of the nodes.
2393
+
2394
+ Note that there is not a ``node`` attribute in the Function. The Function can be
2395
+ seen as a Sequence of nodes and should be used as such. For example, to obtain
2396
+ all nodes as a list, call ``list(function)``.
2397
+
2398
+ Attributes:
2399
+ name: The function name.
2400
+ domain: The domain this function is defined in.
2401
+ overload: The overload name when the function is overloaded.
2402
+ inputs: The input values of the function.
2403
+ attributes: The attributes this function defines.
2404
+ outputs: The output values of the function.
2405
+ opset_imports: Opsets imported by the function.
2406
+ doc_string: Documentation string.
2407
+ metadata_props: Metadata that will be serialized to the ONNX file.
2408
+ meta: Metadata store for graph transform passes.
2409
+ """
2410
+
2411
+ __slots__ = (
2412
+ "_attributes",
2413
+ "_domain",
2414
+ "_graph",
2415
+ "_metadata",
2416
+ "_metadata_props",
2417
+ "_name",
2418
+ "_overload",
2419
+ )
2420
+
2421
+ def __init__(
2422
+ self,
2423
+ domain: str,
2424
+ name: str,
2425
+ overload: str = "",
2426
+ *,
2427
+ # Ensure the inputs and outputs of the function belong to a graph
2428
+ # and not from an outer scope
2429
+ graph: Graph,
2430
+ attributes: Sequence[Attr],
2431
+ metadata_props: dict[str, str] | None = None,
2432
+ ) -> None:
2433
+ self._domain = domain
2434
+ self._name = name
2435
+ self._overload = overload
2436
+ self._graph = graph
2437
+ self._attributes = OrderedDict((attr.name, attr) for attr in attributes)
2438
+ self._metadata: _metadata.MetadataStore | None = None
2439
+ self._metadata_props: dict[str, str] | None = metadata_props
2440
+
2441
+ def identifier(self) -> _protocols.OperatorIdentifier:
2442
+ return self.domain, self.name, self.overload
2443
+
2444
+ @property
2445
+ def name(self) -> str:
2446
+ return self._name
2447
+
2448
+ @name.setter
2449
+ def name(self, value: str) -> None:
2450
+ self._name = value
2451
+
2452
+ @property
2453
+ def domain(self) -> str:
2454
+ return self._domain
2455
+
2456
+ @domain.setter
2457
+ def domain(self, value: str) -> None:
2458
+ self._domain = value
2459
+
2460
+ @property
2461
+ def overload(self) -> str:
2462
+ return self._overload
2463
+
2464
+ @overload.setter
2465
+ def overload(self, value: str) -> None:
2466
+ self._overload = value
2467
+
2468
+ @property
2469
+ def inputs(self) -> list[Value]:
2470
+ return self._graph.inputs
2471
+
2472
+ @property
2473
+ def outputs(self) -> list[Value]:
2474
+ return self._graph.outputs
2475
+
2476
+ @property
2477
+ def attributes(self) -> OrderedDict[str, Attr]:
2478
+ return self._attributes
2479
+
2480
+ def __getitem__(self, index: int) -> Node:
2481
+ return self._graph.__getitem__(index)
2482
+
2483
+ def __len__(self) -> int:
2484
+ return self._graph.__len__()
2485
+
2486
+ def __iter__(self) -> Iterator[Node]:
2487
+ return self._graph.__iter__()
2488
+
2489
+ def __reversed__(self) -> Iterator[Node]:
2490
+ return self._graph.__reversed__()
2491
+
2492
+ @property
2493
+ def doc_string(self) -> str | None:
2494
+ return self._graph.doc_string
2495
+
2496
+ @doc_string.setter
2497
+ def doc_string(self, value: str | None) -> None:
2498
+ self._graph.doc_string = value
2499
+
2500
+ @property
2501
+ def opset_imports(self) -> dict[str, int]:
2502
+ return self._graph.opset_imports
2503
+
2504
+ @property
2505
+ def meta(self) -> _metadata.MetadataStore:
2506
+ """The metadata store for intermediate analysis.
2507
+
2508
+ Write to the :attr:`metadata_props` if you would like the metadata to be serialized
2509
+ to the ONNX proto.
2510
+ """
2511
+ if self._metadata is None:
2512
+ self._metadata = _metadata.MetadataStore()
2513
+ return self._metadata
2514
+
2515
+ @property
2516
+ def metadata_props(self) -> dict[str, str]:
2517
+ if self._metadata_props is None:
2518
+ self._metadata_props = {}
2519
+ return self._metadata_props
2520
+
2521
+ # Mutation methods
2522
+ def append(self, node: Node, /) -> None:
2523
+ """Append a node to the function in O(1) time."""
2524
+ self._graph.append(node)
2525
+
2526
+ def extend(self, nodes: Iterable[Node], /) -> None:
2527
+ """Extend the function with the given nodes in O(#new_nodes) time."""
2528
+ self._graph.extend(nodes)
2529
+
2530
+ def remove(self, nodes: Node | Iterable[Node], /, safe: bool = False) -> None:
2531
+ """Remove nodes from the graph in O(#num of nodes) time.
2532
+
2533
+ If any errors are raise, to ensure the graph is not left in an inconsistent state,
2534
+ the graph is not modified.
2535
+
2536
+ Args:
2537
+ nodes: The node to remove.
2538
+ safe: If True, performs the following actions before removal:
2539
+
2540
+ 1. It checks to make sure there are no users of the node that are not
2541
+ to be removed before removing it.
2542
+ 2. It checks the node does not contribute to any graph outputs.
2543
+ 3. It removes references to all inputs so it is no longer a user of other nodes.
2544
+
2545
+ Raises:
2546
+ ValueError: If any node to remove does not belong to this graph.
2547
+ ValueError: (When ``safe=True``) If the node does not belong to this graph or if there are users of the node.
2548
+ ValueError: (When ``safe=True``) If the node is still being used by other nodes not to be removed.
2549
+ """
2550
+ self._graph.remove(nodes, safe=safe)
2551
+
2552
+ def insert_after(self, node: Node, new_nodes: Iterable[Node], /) -> None:
2553
+ """Insert new nodes after the given node in O(#new_nodes) time."""
2554
+ self._graph.insert_after(node, new_nodes)
2555
+
2556
+ def insert_before(self, node: Node, new_nodes: Iterable[Node], /) -> None:
2557
+ """Insert new nodes before the given node in O(#new_nodes) time."""
2558
+ self._graph.insert_before(node, new_nodes)
2559
+
2560
+ def sort(self) -> None:
2561
+ """Perform a topological sort of this graph and all subgraphs in O(#nodes + #values) time."""
2562
+ self._graph.sort()
2563
+
2564
+ # End of mutation methods
2565
+
2566
def __str__(self) -> str:
    """Render the function as multi-line, human-readable text.

    The result is a signature section (opset imports, qualified name, inputs,
    optional attributes, outputs) followed by a brace-delimited body that
    lists every node with its index and ends with a ``return`` line.
    """
    # The ":<overload>" suffix is only present when an overload is set.
    qualified_name = f"{self.domain}::{self.name}"
    if self.overload != "":
        qualified_name += f":{self.overload}"

    inputs_text = ",\n".join(str(value) for value in self.inputs)
    outputs_text = ",\n".join(str(value) for value in self.outputs)

    # One "name: type[ = value]" entry per attribute.
    attribute_entries = []
    for attr in self.attributes.values():
        entry = f"{attr.name}: {attr.type}"
        if attr.value is not None:
            entry += f" = {attr.value}"
        attribute_entries.append(entry)
    attributes_text = ",\n".join(attribute_entries)
    if attributes_text:
        attributes_text = (
            "\nattributes={\n" + textwrap.indent(attributes_text, " " * 4) + "\n}"
        )

    # NOTE(review): the 4-space indents below are chosen to line up with the
    # 8/4-space indents applied to the interpolated text -- confirm against
    # the original literal.
    signature = "\n".join(
        [
            "<",
            f"    opset_imports={self.opset_imports!r},",
            ">",
            f"def {qualified_name}(",
            "    inputs=(",
            textwrap.indent(inputs_text, " " * 8),
            "    )," + textwrap.indent(attributes_text, " " * 4),
            "    outputs=(",
            textwrap.indent(outputs_text, " " * 8),
            "    ),",
            ")",
        ]
    )

    # Width of the index column is derived from the node count.
    index_width = len(str(len(self)))
    numbered_nodes = []
    for index, node in enumerate(self):
        label = node.name or f":anonymous_node:{id(node)}"
        # Indent continuation lines past the "<index> | " prefix, then strip so
        # the first line starts directly after the prefix.
        block = textwrap.indent(f"# {label}\n{node}", " " * (index_width + 4)).strip()
        numbered_nodes.append(f"{index:>{index_width}} | {block}")

    returns = ", ".join(str(value) for value in self.outputs)
    body = "".join(
        (
            "{\n",
            textwrap.indent("\n".join(numbered_nodes), " " * 4),
            textwrap.indent(f"\nreturn {returns}", " " * 4),
            "\n}",
        )
    )

    return f"{signature} {body}"
2609
+
2610
def __repr__(self) -> str:
    """Return a constructor-style debug representation.

    The text has the form
    ``ClassName(domain, name, overload, inputs=..., attributes=..., outputs=...)``
    with every component rendered through ``repr``.

    The previous format string closed the parenthesis right after
    ``attributes=...``, which left ``outputs=...`` outside the call and the
    overall text with unbalanced parentheses.
    """
    return (
        f"{self.__class__.__name__}({self.domain!r}, {self.name!r}, {self.overload!r}, "
        f"inputs={self.inputs!r}, attributes={self.attributes!r}, outputs={self.outputs!r})"
    )
2612
+
2613
+
2614
class RefAttr(_protocols.ReferenceAttributeProtocol, _display.PrettyPrintable):
    """Reference attribute.

    Holds the attribute's own name, the name stored as ``ref_attr_name``, the
    attribute type, and an optional doc string. ``name``, ``ref_attr_name``
    and ``type`` are exposed as read/write properties backed by private slots.
    """

    __slots__ = ("_name", "_ref_attr_name", "_type", "doc_string")

    def __init__(
        self,
        name: str,
        ref_attr_name: str,
        type: _enums.AttributeType,
        *,
        doc_string: str | None = None,
    ) -> None:
        """Store the given fields without validation.

        Args:
            name: Name of this attribute.
            ref_attr_name: Value stored as the referenced attribute name.
            type: Attribute type.
            doc_string: Optional documentation string (keyword-only).
        """
        self.doc_string = doc_string
        self._type = type
        self._ref_attr_name = ref_attr_name
        self._name = name

    @property
    def name(self) -> str:
        """Name of this attribute."""
        return self._name

    @name.setter
    def name(self, value: str) -> None:
        self._name = value

    @property
    def ref_attr_name(self) -> str:
        """Name stored as the referenced attribute."""
        return self._ref_attr_name

    @ref_attr_name.setter
    def ref_attr_name(self, value: str) -> None:
        self._ref_attr_name = value

    @property
    def type(self) -> _enums.AttributeType:
        """Attribute type."""
        return self._type

    @type.setter
    def type(self, value: _enums.AttributeType) -> None:
        self._type = value

    def __repr__(self) -> str:
        """Return ``ClassName(name, type, ref_attr_name=...)``; the doc string is omitted."""
        class_name = self.__class__.__name__
        return (
            f"{class_name}({self._name!r}, {self._type!r}, "
            f"ref_attr_name={self.ref_attr_name!r})"
        )
2658
+
2659
+
2660
class Attr(_protocols.AttributeProtocol, _display.PrettyPrintable):
    """Base class for ONNX attributes.

    An attribute is a plain record of ``name``, ``type``, ``value`` and an
    optional ``doc_string``; all four are stored as given, without validation.
    """

    __slots__ = ("doc_string", "name", "type", "value")

    def __init__(
        self,
        name: str,
        type: _enums.AttributeType,
        value: Any,
        *,
        doc_string: str | None = None,
    ):
        """Store the given fields.

        Args:
            name: Name of the attribute.
            type: Attribute type tag.
            value: Payload of the attribute.
            doc_string: Optional documentation string (keyword-only).
        """
        self.doc_string = doc_string
        self.value = value
        self.type = type
        self.name = name

    def __eq__(self, other: object) -> bool:
        """Compare field by field against another attribute.

        Returns ``False`` when ``other`` is not an ``AttributeProtocol``
        instance or when any of ``name``, ``type``, ``value`` or
        ``doc_string`` differs (checked in that order).
        """
        if not isinstance(other, _protocols.AttributeProtocol):
            return False

        for field_name in ("name", "type", "value", "doc_string"):
            if getattr(self, field_name) != getattr(other, field_name):
                return False
        return True

    def __str__(self) -> str:
        """Return ``str(value)``; graph values start on a new line, indented by four spaces."""
        is_graph = self.type == _enums.AttributeType.GRAPH
        text = str(self.value)
        return textwrap.indent("\n" + text, " " * 4) if is_graph else text

    def __repr__(self) -> str:
        """Return ``ClassName(name, type, value)``; the doc string is omitted."""
        class_name = self.__class__.__name__
        return f"{class_name}({self.name!r}, {self.type!r}, {self.value!r})"
2699
+
2700
+
2701
+ # NOTE: The following functions are just for convenience
2702
def AttrFloat32(name: str, value: float, doc_string: str | None = None) -> Attr:
    """Create a float attribute.

    Args:
        name: Name of the attribute.
        value: Value stored on the attribute; passed through unchanged.
        doc_string: Optional documentation string for the attribute.

    Returns:
        A new ``Attr`` whose type is ``AttributeType.FLOAT``.
    """
    # NOTE: The capitalized function name is kept for API backward compatibility.
    attr_type = _enums.AttributeType.FLOAT
    return Attr(name, attr_type, value, doc_string=doc_string)
2711
+
2712
+
2713
def AttrInt64(name: str, value: int, doc_string: str | None = None) -> Attr:
    """Create an int attribute.

    Args:
        name: Name of the attribute.
        value: Value stored on the attribute; passed through unchanged.
        doc_string: Optional documentation string for the attribute.

    Returns:
        A new ``Attr`` whose type is ``AttributeType.INT``.
    """
    # NOTE: The capitalized function name is kept for API backward compatibility.
    attr_type = _enums.AttributeType.INT
    return Attr(name, attr_type, value, doc_string=doc_string)
2722
+
2723
+
2724
def AttrString(name: str, value: str, doc_string: str | None = None) -> Attr:
    """Create a str attribute.

    Args:
        name: Name of the attribute.
        value: Value stored on the attribute; passed through unchanged.
        doc_string: Optional documentation string for the attribute.

    Returns:
        A new ``Attr`` whose type is ``AttributeType.STRING``.
    """
    # NOTE: The capitalized function name is kept for API backward compatibility.
    attr_type = _enums.AttributeType.STRING
    return Attr(name, attr_type, value, doc_string=doc_string)
2733
+
2734
+
2735
def AttrTensor(
    name: str, value: _protocols.TensorProtocol, doc_string: str | None = None
) -> Attr:
    """Create a tensor attribute.

    Args:
        name: Name of the attribute.
        value: Tensor stored on the attribute; passed through unchanged.
        doc_string: Optional documentation string for the attribute.

    Returns:
        A new ``Attr`` whose type is ``AttributeType.TENSOR``.
    """
    # NOTE: The capitalized function name is kept for API backward compatibility.
    attr_type = _enums.AttributeType.TENSOR
    return Attr(name, attr_type, value, doc_string=doc_string)
2746
+
2747
+
2748
def AttrGraph(name: str, value: Graph, doc_string: str | None = None) -> Attr:
    """Create a graph attribute.

    Args:
        name: Name of the attribute.
        value: Graph stored on the attribute; passed through unchanged.
        doc_string: Optional documentation string for the attribute.

    Returns:
        A new ``Attr`` whose type is ``AttributeType.GRAPH``.
    """
    # NOTE: The capitalized function name is kept for API backward compatibility.
    attr_type = _enums.AttributeType.GRAPH
    return Attr(name, attr_type, value, doc_string=doc_string)
2757
+
2758
+
2759
def AttrFloat32s(name: str, value: Sequence[float], doc_string: str | None = None) -> Attr:
    """Create a float sequence attribute.

    Args:
        name: Name of the attribute.
        value: Sequence stored on the attribute; passed through unchanged.
        doc_string: Optional documentation string for the attribute.

    Returns:
        A new ``Attr`` whose type is ``AttributeType.FLOATS``.
    """
    # NOTE: The capitalized function name is kept for API backward compatibility.
    attr_type = _enums.AttributeType.FLOATS
    return Attr(name, attr_type, value, doc_string=doc_string)
2768
+
2769
+
2770
def AttrInt64s(name: str, value: Sequence[int], doc_string: str | None = None) -> Attr:
    """Create an int sequence attribute.

    Args:
        name: Name of the attribute.
        value: Sequence stored on the attribute; passed through unchanged.
        doc_string: Optional documentation string for the attribute.

    Returns:
        A new ``Attr`` whose type is ``AttributeType.INTS``.
    """
    # NOTE: The capitalized function name is kept for API backward compatibility.
    attr_type = _enums.AttributeType.INTS
    return Attr(name, attr_type, value, doc_string=doc_string)
2779
+
2780
+
2781
def AttrStrings(name: str, value: Sequence[str], doc_string: str | None = None) -> Attr:
    """Create a string sequence attribute.

    Args:
        name: Name of the attribute.
        value: Sequence stored on the attribute; passed through unchanged.
        doc_string: Optional documentation string for the attribute.

    Returns:
        A new ``Attr`` whose type is ``AttributeType.STRINGS``.
    """
    # NOTE: The capitalized function name is kept for API backward compatibility.
    attr_type = _enums.AttributeType.STRINGS
    return Attr(name, attr_type, value, doc_string=doc_string)
2790
+
2791
+
2792
def AttrTensors(
    name: str, value: Sequence[_protocols.TensorProtocol], doc_string: str | None = None
) -> Attr:
    """Create a tensor sequence attribute.

    Args:
        name: Name of the attribute.
        value: Sequence of tensors stored on the attribute; passed through unchanged.
        doc_string: Optional documentation string for the attribute.

    Returns:
        A new ``Attr`` whose type is ``AttributeType.TENSORS``.
    """
    # NOTE: The capitalized function name is kept for API backward compatibility.
    attr_type = _enums.AttributeType.TENSORS
    return Attr(name, attr_type, value, doc_string=doc_string)
2803
+
2804
+
2805
def AttrGraphs(name: str, value: Sequence[Graph], doc_string: str | None = None) -> Attr:
    """Create a graph sequence attribute.

    Args:
        name: Name of the attribute.
        value: Sequence of graphs stored on the attribute; passed through unchanged.
        doc_string: Optional documentation string for the attribute.

    Returns:
        A new ``Attr`` whose type is ``AttributeType.GRAPHS``.
    """
    # NOTE: The capitalized function name is kept for API backward compatibility.
    attr_type = _enums.AttributeType.GRAPHS
    return Attr(name, attr_type, value, doc_string=doc_string)
2814
+
2815
+
2816
+ # NOTE: SparseTensor should be a sparse tensor proto
2817
def AttrSparseTensor(
    name: str, value: _protocols.SparseTensorProtocol, doc_string: str | None = None
) -> Attr:
    """Create a sparse tensor attribute.

    Args:
        name: Name of the attribute.
        value: Sparse tensor stored on the attribute; passed through unchanged.
        doc_string: Optional documentation string for the attribute.

    Returns:
        A new ``Attr`` whose type is ``AttributeType.SPARSE_TENSOR``.
    """
    # NOTE: The capitalized function name is kept for API backward compatibility.
    attr_type = _enums.AttributeType.SPARSE_TENSOR
    return Attr(name, attr_type, value, doc_string=doc_string)
2828
+
2829
+
2830
def AttrSparseTensors(
    name: str, value: Sequence[_protocols.SparseTensorProtocol], doc_string: str | None = None
) -> Attr:
    """Create a sparse tensor sequence attribute.

    Args:
        name: Name of the attribute.
        value: Sequence of sparse tensors stored on the attribute; passed
            through unchanged.
        doc_string: Optional documentation string for the attribute.

    Returns:
        A new ``Attr`` whose type is ``AttributeType.SPARSE_TENSORS``.
    """
    # NOTE: The capitalized function name is kept for API backward compatibility.
    attr_type = _enums.AttributeType.SPARSE_TENSORS
    return Attr(name, attr_type, value, doc_string=doc_string)
2841
+
2842
+
2843
@dataclasses.dataclass
class TypeAndShape:
    """Pair of a type and a shape.

    Useful for constructing a type proto. Both fields are required
    constructor arguments, and either may be ``None``.
    """

    # Type component of the pair, or None.
    type: _protocols.TypeProtocol | None
    # Shape component of the pair, or None.
    shape: Shape | None
2852
+
2853
+
2854
def AttrTypeProto(name: str, value: TypeAndShape, doc_string: str | None = None) -> Attr:
    """Create a type attribute.

    Args:
        name: Name of the attribute.
        value: Type-and-shape pair stored on the attribute; passed through unchanged.
        doc_string: Optional documentation string for the attribute.

    Returns:
        A new ``Attr`` whose type is ``AttributeType.TYPE_PROTO``.
    """
    # NOTE: The capitalized function name is kept for API backward compatibility.
    attr_type = _enums.AttributeType.TYPE_PROTO
    return Attr(name, attr_type, value, doc_string=doc_string)
2863
+
2864
+
2865
def AttrTypeProtos(
    name: str, value: Sequence[TypeAndShape], doc_string: str | None = None
) -> Attr:
    """Create a type sequence attribute.

    Args:
        name: Name of the attribute.
        value: Sequence of type-and-shape pairs stored on the attribute;
            passed through unchanged.
        doc_string: Optional documentation string for the attribute.

    Returns:
        A new ``Attr`` whose type is ``AttributeType.TYPE_PROTOS``.
    """
    # NOTE: The capitalized function name is kept for API backward compatibility.
    attr_type = _enums.AttributeType.TYPE_PROTOS
    return Attr(name, attr_type, value, doc_string=doc_string)