onnx-ir 0.1.15__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (53)
  1. onnx_ir/__init__.py +176 -0
  2. onnx_ir/_cloner.py +229 -0
  3. onnx_ir/_convenience/__init__.py +558 -0
  4. onnx_ir/_convenience/_constructors.py +291 -0
  5. onnx_ir/_convenience/_extractor.py +191 -0
  6. onnx_ir/_core.py +4435 -0
  7. onnx_ir/_display.py +54 -0
  8. onnx_ir/_enums.py +474 -0
  9. onnx_ir/_graph_comparison.py +23 -0
  10. onnx_ir/_graph_containers.py +373 -0
  11. onnx_ir/_io.py +133 -0
  12. onnx_ir/_linked_list.py +284 -0
  13. onnx_ir/_metadata.py +45 -0
  14. onnx_ir/_name_authority.py +72 -0
  15. onnx_ir/_polyfill.py +26 -0
  16. onnx_ir/_protocols.py +627 -0
  17. onnx_ir/_safetensors/__init__.py +510 -0
  18. onnx_ir/_tape.py +242 -0
  19. onnx_ir/_thirdparty/asciichartpy.py +310 -0
  20. onnx_ir/_type_casting.py +89 -0
  21. onnx_ir/_version_utils.py +48 -0
  22. onnx_ir/analysis/__init__.py +21 -0
  23. onnx_ir/analysis/_implicit_usage.py +74 -0
  24. onnx_ir/convenience.py +38 -0
  25. onnx_ir/external_data.py +459 -0
  26. onnx_ir/passes/__init__.py +41 -0
  27. onnx_ir/passes/_pass_infra.py +351 -0
  28. onnx_ir/passes/common/__init__.py +54 -0
  29. onnx_ir/passes/common/_c_api_utils.py +76 -0
  30. onnx_ir/passes/common/clear_metadata_and_docstring.py +60 -0
  31. onnx_ir/passes/common/common_subexpression_elimination.py +207 -0
  32. onnx_ir/passes/common/constant_manipulation.py +230 -0
  33. onnx_ir/passes/common/default_attributes.py +99 -0
  34. onnx_ir/passes/common/identity_elimination.py +120 -0
  35. onnx_ir/passes/common/initializer_deduplication.py +179 -0
  36. onnx_ir/passes/common/inliner.py +223 -0
  37. onnx_ir/passes/common/naming.py +280 -0
  38. onnx_ir/passes/common/onnx_checker.py +57 -0
  39. onnx_ir/passes/common/output_fix.py +141 -0
  40. onnx_ir/passes/common/shape_inference.py +112 -0
  41. onnx_ir/passes/common/topological_sort.py +37 -0
  42. onnx_ir/passes/common/unused_removal.py +215 -0
  43. onnx_ir/py.typed +1 -0
  44. onnx_ir/serde.py +2043 -0
  45. onnx_ir/tape.py +15 -0
  46. onnx_ir/tensor_adapters.py +210 -0
  47. onnx_ir/testing.py +197 -0
  48. onnx_ir/traversal.py +118 -0
  49. onnx_ir-0.1.15.dist-info/METADATA +68 -0
  50. onnx_ir-0.1.15.dist-info/RECORD +53 -0
  51. onnx_ir-0.1.15.dist-info/WHEEL +5 -0
  52. onnx_ir-0.1.15.dist-info/licenses/LICENSE +202 -0
  53. onnx_ir-0.1.15.dist-info/top_level.txt +1 -0
onnx_ir/_display.py ADDED
@@ -0,0 +1,54 @@
1
+ # Copyright (c) ONNX Project Contributors
2
+ # SPDX-License-Identifier: Apache-2.0
3
+ """Internal utilities for displaying the intermediate representation of a model.
4
+
5
+ NOTE: All third-party imports should be scoped and imported only when used to avoid
6
+ importing unnecessary dependencies.
7
+ """
8
+ # pylint: disable=import-outside-toplevel
9
+
10
+ from __future__ import annotations
11
+
12
+ from typing import Any
13
+
14
+
15
def require_rich() -> Any:
    """Return the ``rich`` module if it is installed, otherwise ``None``.

    This helper never raises: callers are expected to check for ``None`` and
    fall back to plain (non-rich) output.

    Returns:
        The imported ``rich`` module, or ``None`` when importing it fails with
        ``ImportError``.
    """
    try:
        import rich
    except ImportError:
        # rich is an optional dependency; signal its absence instead of raising.
        return None
    return rich
22
+
23
+
24
class PrettyPrintable:
    """Mixin providing a ``display`` method that prints ``str(self)``."""

    def display(self, *, page: bool = False) -> None:
        """Print ``str(self)``, using the ``rich`` library when it is available.

        Without ``rich``, the plain text is printed, followed by a colored tip
        suggesting ``pip install rich``. With ``rich``, the text is
        markup-escaped and printed through ``rich``, optionally inside a pager.

        Args:
            page: Whether to page the output. Only honored when ``rich`` is installed.
        """
        rich_module = require_rich()
        rendered = str(self)

        if rich_module is not None:
            import rich.markup

            # Escape so literal `[...]` in the text is not interpreted as rich markup.
            rendered = rich.markup.escape(rendered)
            if not page:
                rich.print(rendered)
                return

            import rich.console

            console = rich.console.Console()
            with console.pager():
                console.print(rendered)
            return

        print(rendered)
        # Cyan ANSI-colored hint about the optional dependency.
        print(
            f"\n\n\u001b[36mTip: Install the rich library with 'pip install rich' to pretty print this {self.__class__.__name__}.\u001b[0m"
        )
onnx_ir/_enums.py ADDED
@@ -0,0 +1,474 @@
1
+ # Copyright (c) ONNX Project Contributors
2
+ # SPDX-License-Identifier: Apache-2.0
3
+ """ONNX IR enums that matches the ONNX spec."""
4
+
5
+ from __future__ import annotations
6
+
7
+ import enum
8
+ from typing import Any
9
+
10
+ import ml_dtypes
11
+ import numpy as np
12
+
13
+
14
class AttributeType(enum.IntEnum):
    """Kinds of ONNX attributes; the integer values follow the ONNX spec."""

    UNDEFINED = 0
    FLOAT = 1
    INT = 2
    STRING = 3
    TENSOR = 4
    GRAPH = 5
    FLOATS = 6
    INTS = 7
    STRINGS = 8
    TENSORS = 9
    GRAPHS = 10
    SPARSE_TENSOR = 11
    SPARSE_TENSORS = 12
    TYPE_PROTO = 13
    TYPE_PROTOS = 14

    def __repr__(self) -> str:
        # Render just the member name, e.g. ``FLOAT``.
        return self.name

    # str() renders the same way as repr().
    __str__ = __repr__
38
+
39
+
40
class DataType(enum.IntEnum):
    """Enum for the data types of ONNX tensors, defined in ``onnx.TensorProto``."""

    # NOTE: Naming: It is tempting to use shorter and more modern names like f32, i64,
    # but we should stick to the names used in the ONNX spec for consistency.
    UNDEFINED = 0
    FLOAT = 1
    UINT8 = 2
    INT8 = 3
    UINT16 = 4
    INT16 = 5
    INT32 = 6
    INT64 = 7
    STRING = 8
    BOOL = 9
    FLOAT16 = 10
    DOUBLE = 11
    UINT32 = 12
    UINT64 = 13
    COMPLEX64 = 14
    COMPLEX128 = 15
    BFLOAT16 = 16
    FLOAT8E4M3FN = 17
    FLOAT8E4M3FNUZ = 18
    FLOAT8E5M2 = 19
    FLOAT8E5M2FNUZ = 20
    UINT4 = 21
    INT4 = 22
    FLOAT4E2M1 = 23
    FLOAT8E8M0 = 24
    UINT2 = 25
    INT2 = 26

    @classmethod
    def from_numpy(cls, dtype: np.dtype) -> DataType:
        """Returns the ONNX data type for the numpy dtype.

        Resolution order:

        1. The direct numpy/ml_dtypes -> ONNX mapping.
        2. Any string or bytes dtype maps to ``STRING``.
        3. Single-field structured dtypes that older ONNX versions used to
           represent custom element types (identified by their field name).

        Args:
            dtype: The numpy dtype to convert.

        Raises:
            TypeError: If the data type is not supported by ONNX.
        """
        if dtype in _NP_TYPE_TO_DATA_TYPE:
            return cls(_NP_TYPE_TO_DATA_TYPE[dtype])

        if np.issubdtype(dtype, np.str_) or np.issubdtype(dtype, np.bytes_):
            return DataType.STRING

        # Special cases for handling custom dtypes defined in ONNX (as of onnx 1.18)
        # Ref: https://github.com/onnx/onnx/blob/2d42b6a60a52e925e57c422593e88cc51890f58a/onnx/_custom_element_types.py
        # TODO(#137): Remove this when ONNX 1.19 is the minimum requirement
        if hasattr(dtype, "names"):
            custom_element_types = {
                ("bfloat16",): DataType.BFLOAT16,
                ("e4m3fn",): DataType.FLOAT8E4M3FN,
                ("e4m3fnuz",): DataType.FLOAT8E4M3FNUZ,
                ("e5m2",): DataType.FLOAT8E5M2,
                ("e5m2fnuz",): DataType.FLOAT8E5M2FNUZ,
                ("uint4",): DataType.UINT4,
                ("int4",): DataType.INT4,
                ("float4e2m1",): DataType.FLOAT4E2M1,
                ("int2",): DataType.INT2,
                ("uint2",): DataType.UINT2,
            }
            # `names` is None for non-structured dtypes, which is simply not found.
            if dtype.names in custom_element_types:
                return custom_element_types[dtype.names]
        raise TypeError(f"Unsupported numpy data type: {dtype}")

    @classmethod
    def from_short_name(cls, short_name: str) -> DataType:
        """Returns the ONNX data type for the short name (e.g. ``"f32"``).

        This is the inverse of :meth:`short_name`.

        Raises:
            TypeError: If the short name is not available for the data type.
        """
        if short_name not in _SHORT_NAME_TO_DATA_TYPE:
            raise TypeError(f"Unknown short name: {short_name}")
        return cls(_SHORT_NAME_TO_DATA_TYPE[short_name])

    @property
    def itemsize(self) -> float:
        """Returns the size of the data type in bytes.

        Computed as ``bitwidth / 8``, so sub-byte types yield fractional sizes
        (e.g. 0.5 for 4-bit types, 0.25 for 2-bit types).

        Raises:
            TypeError: If the bit width is not available for the data type.
        """
        return self.bitwidth / 8

    @property
    def bitwidth(self) -> int:
        """Returns the bit width of the data type.

        .. versionadded:: 0.1.2

        Raises:
            TypeError: If the data type is not supported.
        """
        if self not in _BITWIDTH_MAP:
            raise TypeError(f"Bitwidth not available for ONNX data type: {self}")
        return _BITWIDTH_MAP[self]

    @property
    def exponent_bitwidth(self) -> int:
        """Returns the bit width of the exponent for floating-point types.

        .. versionadded:: 0.1.8

        Raises:
            TypeError: If the data type is not supported.
        """
        if self.is_floating_point():
            return ml_dtypes.finfo(self.numpy()).nexp

        raise TypeError(f"Exponent not available for ONNX data type: {self}")

    @property
    def mantissa_bitwidth(self) -> int:
        """Returns the bit width of the mantissa for floating-point types.

        .. versionadded:: 0.1.8

        Raises:
            TypeError: If the data type is not supported.
        """
        if self.is_floating_point():
            return ml_dtypes.finfo(self.numpy()).nmant

        raise TypeError(f"Mantissa not available for ONNX data type: {self}")

    @property
    def eps(self) -> int | np.floating[Any]:
        """Returns the difference between 1.0 and the next smallest representable float larger than 1.0 for the ONNX data type.

        Returns 1 for integers.

        .. versionadded:: 0.1.8

        Raises:
            TypeError: If the data type is not a numeric data type.
        """
        if self.is_integer():
            return 1

        if self.is_floating_point():
            return ml_dtypes.finfo(self.numpy()).eps

        raise TypeError(f"Eps not available for ONNX data type: {self}")

    @property
    def tiny(self) -> int | np.floating[Any]:
        """Returns the smallest positive non-zero value for the ONNX data type.

        Returns 1 for integers.

        .. versionadded:: 0.1.8

        Raises:
            TypeError: If the data type is not a numeric data type.
        """
        if self.is_integer():
            return 1

        if self.is_floating_point():
            return ml_dtypes.finfo(self.numpy()).tiny

        raise TypeError(f"Tiny not available for ONNX data type: {self}")

    @property
    def min(self) -> int | np.floating[Any]:
        """Returns the minimum representable value for the ONNX data type.

        .. versionadded:: 0.1.8

        Raises:
            TypeError: If the data type is not a numeric data type.
        """
        if self.is_integer():
            return ml_dtypes.iinfo(self.numpy()).min

        if self.is_floating_point():
            return ml_dtypes.finfo(self.numpy()).min

        raise TypeError(f"Minimum not available for ONNX data type: {self}")

    @property
    def max(self) -> int | np.floating[Any]:
        """Returns the maximum representable value for the ONNX data type.

        .. versionadded:: 0.1.8

        Raises:
            TypeError: If the data type is not a numeric data type.
        """
        if self.is_integer():
            return ml_dtypes.iinfo(self.numpy()).max

        if self.is_floating_point():
            return ml_dtypes.finfo(self.numpy()).max

        raise TypeError(f"Maximum not available for ONNX data type: {self}")

    @property
    def precision(self) -> int:
        """Returns the precision for the ONNX dtype if supported.

        For floats returns the approximate number of decimal digits to which
        this kind of float is precise. Returns 0 for integers.

        .. versionadded:: 0.1.8

        Raises:
            TypeError: If the data type is not a numeric data type.
        """
        if self.is_integer():
            return 0

        if self.is_floating_point():
            return ml_dtypes.finfo(self.numpy()).precision

        raise TypeError(f"Precision not available for ONNX data type: {self}")

    @property
    def resolution(self) -> int | np.floating[Any]:
        """Returns the resolution for the ONNX dtype if supported.

        Returns the approximate decimal resolution of this type, i.e.,
        10**-precision. Returns 1 for integers.

        .. versionadded:: 0.1.8

        Raises:
            TypeError: If the data type is not a numeric data type.
        """
        if self.is_integer():
            return 1

        if self.is_floating_point():
            return ml_dtypes.finfo(self.numpy()).resolution

        raise TypeError(f"Resolution not available for ONNX data type: {self}")

    def numpy(self) -> np.dtype:
        """Returns the numpy dtype for the ONNX data type.

        Types that numpy lacks natively are represented with ml_dtypes dtypes.

        Raises:
            TypeError: If the data type is not supported by numpy.
        """
        if self not in _DATA_TYPE_TO_NP_TYPE:
            raise TypeError(f"Numpy does not support ONNX data type: {self}")
        return _DATA_TYPE_TO_NP_TYPE[self]

    def short_name(self) -> str:
        """Returns the short name of the data type.

        The short name is a string that is used to represent the data type in a more
        compact form. For example, the short name for `DataType.FLOAT` is "f32".
        To get the corresponding data type back, call ``from_short_name`` on a string.

        Naming reference: https://github.com/pytorch/pytorch/blob/4bead7b85ea4160243c74109e0ce9bb80686d016/torch/utils/_dtype_abbrs.py

        Raises:
            TypeError: If the short name is not available for the data type.
        """
        if self not in _DATA_TYPE_TO_SHORT_NAME:
            raise TypeError(f"Short name not available for ONNX data type: {self}")
        return _DATA_TYPE_TO_SHORT_NAME[self]

    def is_floating_point(self) -> bool:
        """Returns True if the data type is a (real) floating point type.

        Complex types are not included.
        """
        return self in {
            DataType.FLOAT,
            DataType.FLOAT16,
            DataType.DOUBLE,
            DataType.BFLOAT16,
            DataType.FLOAT8E4M3FN,
            DataType.FLOAT8E4M3FNUZ,
            DataType.FLOAT8E5M2,
            DataType.FLOAT8E5M2FNUZ,
            DataType.FLOAT4E2M1,
            DataType.FLOAT8E8M0,
        }

    def is_integer(self) -> bool:
        """Returns True if the data type is an integer.

        ``BOOL`` is not considered an integer.

        .. versionadded:: 0.1.4
        """
        return self in {
            DataType.UINT8,
            DataType.INT8,
            DataType.UINT16,
            DataType.INT16,
            DataType.INT32,
            DataType.INT64,
            DataType.UINT32,
            DataType.UINT64,
            DataType.UINT4,
            DataType.INT4,
            DataType.INT2,
            DataType.UINT2,
        }

    def is_signed(self) -> bool:
        """Returns True if the data type is a signed type.

        ``FLOAT8E8M0`` (E8M0, ``float8_e8m0fnu`` in ml_dtypes) consists only of
        an 8-bit exponent with no sign bit, so it is reported as unsigned.

        .. versionadded:: 0.1.4
        """
        return self in {
            DataType.FLOAT,
            DataType.INT8,
            DataType.INT16,
            DataType.INT32,
            DataType.INT64,
            DataType.FLOAT16,
            DataType.DOUBLE,
            DataType.COMPLEX64,
            DataType.COMPLEX128,
            DataType.BFLOAT16,
            DataType.FLOAT8E4M3FN,
            DataType.FLOAT8E4M3FNUZ,
            DataType.FLOAT8E5M2,
            DataType.FLOAT8E5M2FNUZ,
            DataType.INT4,
            DataType.FLOAT4E2M1,
            DataType.INT2,
        }

    def is_string(self) -> bool:
        """Returns True if the data type is a string type.

        .. versionadded:: 0.1.8
        """
        return self == DataType.STRING

    def __repr__(self) -> str:
        return self.name

    def __str__(self) -> str:
        return self.__repr__()
380
+
381
+
382
# Number of bits used to store one element of each data type.
# STRING and UNDEFINED have no entry, so DataType.bitwidth raises for them.
_BITWIDTH_MAP = dict(
    [
        (DataType.FLOAT, 32),
        (DataType.UINT8, 8),
        (DataType.INT8, 8),
        (DataType.UINT16, 16),
        (DataType.INT16, 16),
        (DataType.INT32, 32),
        (DataType.INT64, 64),
        (DataType.BOOL, 8),
        (DataType.FLOAT16, 16),
        (DataType.DOUBLE, 64),
        (DataType.UINT32, 32),
        (DataType.UINT64, 64),
        (DataType.COMPLEX64, 64),  # 2 * 32
        (DataType.COMPLEX128, 128),  # 2 * 64
        (DataType.BFLOAT16, 16),
        (DataType.FLOAT8E4M3FN, 8),
        (DataType.FLOAT8E4M3FNUZ, 8),
        (DataType.FLOAT8E5M2, 8),
        (DataType.FLOAT8E5M2FNUZ, 8),
        (DataType.UINT4, 4),
        (DataType.INT4, 4),
        (DataType.FLOAT4E2M1, 4),
        (DataType.FLOAT8E8M0, 8),
        (DataType.INT2, 2),
        (DataType.UINT2, 2),
    ]
)


# numpy dtype -> ONNX data type.
# We use ml_dtypes to support dtypes that are not in numpy.
_NP_TYPE_TO_DATA_TYPE = {
    np.dtype(dtype_spec): onnx_type
    for dtype_spec, onnx_type in (
        ("bool", DataType.BOOL),
        ("complex128", DataType.COMPLEX128),
        ("complex64", DataType.COMPLEX64),
        ("float16", DataType.FLOAT16),
        ("float32", DataType.FLOAT),
        ("float64", DataType.DOUBLE),
        ("int16", DataType.INT16),
        ("int32", DataType.INT32),
        ("int64", DataType.INT64),
        ("int8", DataType.INT8),
        ("object", DataType.STRING),
        ("uint16", DataType.UINT16),
        ("uint32", DataType.UINT32),
        ("uint64", DataType.UINT64),
        ("uint8", DataType.UINT8),
        (ml_dtypes.bfloat16, DataType.BFLOAT16),
        (ml_dtypes.float8_e4m3fn, DataType.FLOAT8E4M3FN),
        (ml_dtypes.float8_e4m3fnuz, DataType.FLOAT8E4M3FNUZ),
        (ml_dtypes.float8_e5m2, DataType.FLOAT8E5M2),
        (ml_dtypes.float8_e5m2fnuz, DataType.FLOAT8E5M2FNUZ),
        (ml_dtypes.float8_e8m0fnu, DataType.FLOAT8E8M0),
        (ml_dtypes.int4, DataType.INT4),
        (ml_dtypes.uint4, DataType.UINT4),
        (ml_dtypes.float4_e2m1fn, DataType.FLOAT4E2M1),
        (ml_dtypes.int2, DataType.INT2),
        (ml_dtypes.uint2, DataType.UINT2),
    )
}

# ONNX DataType to Numpy dtype (the inverse of the table above).
_DATA_TYPE_TO_NP_TYPE = dict(
    zip(_NP_TYPE_TO_DATA_TYPE.values(), _NP_TYPE_TO_DATA_TYPE.keys())
)

# ONNX data type -> compact short name (see DataType.short_name).
_DATA_TYPE_TO_SHORT_NAME = dict(
    [
        (DataType.UNDEFINED, "undefined"),
        (DataType.BFLOAT16, "bf16"),
        (DataType.DOUBLE, "f64"),
        (DataType.FLOAT, "f32"),
        (DataType.FLOAT16, "f16"),
        (DataType.FLOAT8E4M3FN, "f8e4m3fn"),
        (DataType.FLOAT8E5M2, "f8e5m2"),
        (DataType.FLOAT8E4M3FNUZ, "f8e4m3fnuz"),
        (DataType.FLOAT8E5M2FNUZ, "f8e5m2fnuz"),
        (DataType.FLOAT8E8M0, "f8e8m0"),
        (DataType.FLOAT4E2M1, "f4e2m1"),
        (DataType.COMPLEX64, "c64"),
        (DataType.COMPLEX128, "c128"),
        (DataType.INT2, "i2"),
        (DataType.INT4, "i4"),
        (DataType.INT8, "i8"),
        (DataType.INT16, "i16"),
        (DataType.INT32, "i32"),
        (DataType.INT64, "i64"),
        (DataType.BOOL, "b8"),
        (DataType.UINT2, "u2"),
        (DataType.UINT4, "u4"),
        (DataType.UINT8, "u8"),
        (DataType.UINT16, "u16"),
        (DataType.UINT32, "u32"),
        (DataType.UINT64, "u64"),
        (DataType.STRING, "s"),
    ]
)

# Short name -> ONNX data type (the inverse of the table above).
_SHORT_NAME_TO_DATA_TYPE = dict(
    zip(_DATA_TYPE_TO_SHORT_NAME.values(), _DATA_TYPE_TO_SHORT_NAME.keys())
)
onnx_ir/_graph_comparison.py ADDED
@@ -0,0 +1,23 @@
1
+ # Copyright (c) ONNX Project Contributors
2
+ # SPDX-License-Identifier: Apache-2.0
3
+ """Utilities for comparing IR graphs."""
4
+
5
+ from __future__ import annotations
6
+
7
+ from onnx_ir import _core
8
+
9
+ # NOTE(justinchuby): We need to ensure a graph has valid inputs and outputs
10
+ # NOTE(justinchuby): A graph may be specified with a set of inputs and outputs
11
+
12
+
13
+ def topologically_equal(graph1: _core.Graph, graph2: _core.Graph) -> bool:
14
+ """Return true if the two graphs are topologically equivalent, without considering initializers.
15
+
16
+ Args:
17
+ graph1: The first graph to compare.
18
+ graph2: The second graph to compare.
19
+
20
+ Returns:
21
+ True if the graphs are equal, False otherwise.
22
+ """
23
+ raise NotImplementedError()