onnx-ir 0.1.15__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- onnx_ir/__init__.py +176 -0
- onnx_ir/_cloner.py +229 -0
- onnx_ir/_convenience/__init__.py +558 -0
- onnx_ir/_convenience/_constructors.py +291 -0
- onnx_ir/_convenience/_extractor.py +191 -0
- onnx_ir/_core.py +4435 -0
- onnx_ir/_display.py +54 -0
- onnx_ir/_enums.py +474 -0
- onnx_ir/_graph_comparison.py +23 -0
- onnx_ir/_graph_containers.py +373 -0
- onnx_ir/_io.py +133 -0
- onnx_ir/_linked_list.py +284 -0
- onnx_ir/_metadata.py +45 -0
- onnx_ir/_name_authority.py +72 -0
- onnx_ir/_polyfill.py +26 -0
- onnx_ir/_protocols.py +627 -0
- onnx_ir/_safetensors/__init__.py +510 -0
- onnx_ir/_tape.py +242 -0
- onnx_ir/_thirdparty/asciichartpy.py +310 -0
- onnx_ir/_type_casting.py +89 -0
- onnx_ir/_version_utils.py +48 -0
- onnx_ir/analysis/__init__.py +21 -0
- onnx_ir/analysis/_implicit_usage.py +74 -0
- onnx_ir/convenience.py +38 -0
- onnx_ir/external_data.py +459 -0
- onnx_ir/passes/__init__.py +41 -0
- onnx_ir/passes/_pass_infra.py +351 -0
- onnx_ir/passes/common/__init__.py +54 -0
- onnx_ir/passes/common/_c_api_utils.py +76 -0
- onnx_ir/passes/common/clear_metadata_and_docstring.py +60 -0
- onnx_ir/passes/common/common_subexpression_elimination.py +207 -0
- onnx_ir/passes/common/constant_manipulation.py +230 -0
- onnx_ir/passes/common/default_attributes.py +99 -0
- onnx_ir/passes/common/identity_elimination.py +120 -0
- onnx_ir/passes/common/initializer_deduplication.py +179 -0
- onnx_ir/passes/common/inliner.py +223 -0
- onnx_ir/passes/common/naming.py +280 -0
- onnx_ir/passes/common/onnx_checker.py +57 -0
- onnx_ir/passes/common/output_fix.py +141 -0
- onnx_ir/passes/common/shape_inference.py +112 -0
- onnx_ir/passes/common/topological_sort.py +37 -0
- onnx_ir/passes/common/unused_removal.py +215 -0
- onnx_ir/py.typed +1 -0
- onnx_ir/serde.py +2043 -0
- onnx_ir/tape.py +15 -0
- onnx_ir/tensor_adapters.py +210 -0
- onnx_ir/testing.py +197 -0
- onnx_ir/traversal.py +118 -0
- onnx_ir-0.1.15.dist-info/METADATA +68 -0
- onnx_ir-0.1.15.dist-info/RECORD +53 -0
- onnx_ir-0.1.15.dist-info/WHEEL +5 -0
- onnx_ir-0.1.15.dist-info/licenses/LICENSE +202 -0
- onnx_ir-0.1.15.dist-info/top_level.txt +1 -0
onnx_ir/_display.py
ADDED
|
@@ -0,0 +1,54 @@
|
|
|
1
|
+
# Copyright (c) ONNX Project Contributors
|
|
2
|
+
# SPDX-License-Identifier: Apache-2.0
|
|
3
|
+
"""Internal utilities for displaying the intermediate representation of a model.
|
|
4
|
+
|
|
5
|
+
NOTE: All third-party imports should be scoped and imported only when used to avoid
|
|
6
|
+
importing unnecessary dependencies.
|
|
7
|
+
"""
|
|
8
|
+
# pylint: disable=import-outside-toplevel
|
|
9
|
+
|
|
10
|
+
from __future__ import annotations
|
|
11
|
+
|
|
12
|
+
from typing import Any
|
|
13
|
+
|
|
14
|
+
|
|
15
|
+
def require_rich() -> Any:
    """Return the ``rich`` module if it can be imported, otherwise ``None``.

    This helper never raises: an :class:`ImportError` raised by ``import rich``
    is caught and converted into a ``None`` return, so callers can fall back to
    plain ``print`` when ``rich`` is not installed.

    Returns:
        The imported ``rich`` module, or ``None`` if importing it fails.
    """
    try:
        import rich
    except ImportError:
        return None
    return rich
|
|
22
|
+
|
|
23
|
+
|
|
24
|
+
class PrettyPrintable:
    """Mixin providing :meth:`display`, which pretty prints ``str(self)``."""

    def display(self, *, page: bool = False) -> None:
        """Pretty print the object.

        When ``rich`` is importable, the text is escaped (so literal ``[...]`` is
        not interpreted as rich markup) and printed with rich, optionally through
        a pager. Otherwise the plain text is printed with the builtin ``print``,
        followed by a colored tip suggesting that ``rich`` be installed.

        Args:
            page: Whether to page the output. Only used when ``rich`` is available.
        """
        rich_module = require_rich()
        rendered = str(self)

        if rich_module is not None:
            import rich.markup

            # Escape the text so that `[...]` is displayed literally instead of
            # being parsed as rich markup.
            escaped = rich.markup.escape(rendered)

            if not page:
                rich.print(escaped)
                return

            import rich.console

            pager_console = rich.console.Console()
            with pager_console.pager():
                pager_console.print(escaped)
            return

        print(rendered)
        # Hint printed in cyan via ANSI escape codes.
        print(
            f"\n\n\u001b[36mTip: Install the rich library with 'pip install rich' to pretty print this {self.__class__.__name__}.\u001b[0m"
        )
|
onnx_ir/_enums.py
ADDED
|
@@ -0,0 +1,474 @@
|
|
|
1
|
+
# Copyright (c) ONNX Project Contributors
|
|
2
|
+
# SPDX-License-Identifier: Apache-2.0
|
|
3
|
+
"""ONNX IR enums that matches the ONNX spec."""
|
|
4
|
+
|
|
5
|
+
from __future__ import annotations
|
|
6
|
+
|
|
7
|
+
import enum
|
|
8
|
+
from typing import Any
|
|
9
|
+
|
|
10
|
+
import ml_dtypes
|
|
11
|
+
import numpy as np
|
|
12
|
+
|
|
13
|
+
|
|
14
|
+
class AttributeType(enum.IntEnum):
    """Enum for the types of ONNX attributes.

    The integer values follow the ONNX spec's attribute type numbering.
    Both ``repr()`` and ``str()`` of a member return its bare name, e.g. ``"INTS"``.
    """

    UNDEFINED = 0
    FLOAT = 1
    INT = 2
    STRING = 3
    TENSOR = 4
    GRAPH = 5
    FLOATS = 6
    INTS = 7
    STRINGS = 8
    TENSORS = 9
    GRAPHS = 10
    SPARSE_TENSOR = 11
    SPARSE_TENSORS = 12
    TYPE_PROTO = 13
    TYPE_PROTOS = 14

    def __repr__(self) -> str:
        # Show only the member name, without the class prefix or value.
        return self.name

    def __str__(self) -> str:
        return self.name
|
|
38
|
+
|
|
39
|
+
|
|
40
|
+
class DataType(enum.IntEnum):
    """Enum for the data types of ONNX tensors, defined in ``onnx.TensorProto``."""

    # NOTE: Naming: It is tempting to use shorter and more modern names like f32, i64,
    # but we should stick to the names used in the ONNX spec for consistency.
    UNDEFINED = 0
    FLOAT = 1
    UINT8 = 2
    INT8 = 3
    UINT16 = 4
    INT16 = 5
    INT32 = 6
    INT64 = 7
    STRING = 8
    BOOL = 9
    FLOAT16 = 10
    DOUBLE = 11
    UINT32 = 12
    UINT64 = 13
    COMPLEX64 = 14
    COMPLEX128 = 15
    BFLOAT16 = 16
    FLOAT8E4M3FN = 17
    FLOAT8E4M3FNUZ = 18
    FLOAT8E5M2 = 19
    FLOAT8E5M2FNUZ = 20
    UINT4 = 21
    INT4 = 22
    FLOAT4E2M1 = 23
    FLOAT8E8M0 = 24
    UINT2 = 25
    INT2 = 26

    @classmethod
    def from_numpy(cls, dtype: np.dtype) -> DataType:
        """Returns the ONNX data type for the numpy dtype.

        Args:
            dtype: A numpy dtype. Besides the dtypes registered in the module's
                numpy-to-ONNX table, any ``str_``/``bytes_`` dtype maps to
                ``STRING``. Single-field structured dtypes, which older ONNX
                releases used to emulate custom element types, are recognized
                by their field name.

        Raises:
            TypeError: If the data type is not supported by ONNX.
        """
        if dtype in _NP_TYPE_TO_DATA_TYPE:
            return cls(_NP_TYPE_TO_DATA_TYPE[dtype])

        if np.issubdtype(dtype, np.str_) or np.issubdtype(dtype, np.bytes_):
            return DataType.STRING

        # Special cases for handling custom dtypes defined in ONNX (as of onnx 1.18)
        # Ref: https://github.com/onnx/onnx/blob/2d42b6a60a52e925e57c422593e88cc51890f58a/onnx/_custom_element_types.py
        # TODO(#137): Remove this when ONNX 1.19 is the minimum requirement
        field_names = getattr(dtype, "names", None)
        for legacy_field_names, data_type in (
            (("bfloat16",), DataType.BFLOAT16),
            (("e4m3fn",), DataType.FLOAT8E4M3FN),
            (("e4m3fnuz",), DataType.FLOAT8E4M3FNUZ),
            (("e5m2",), DataType.FLOAT8E5M2),
            (("e5m2fnuz",), DataType.FLOAT8E5M2FNUZ),
            (("uint4",), DataType.UINT4),
            (("int4",), DataType.INT4),
            (("float4e2m1",), DataType.FLOAT4E2M1),
            (("int2",), DataType.INT2),
            (("uint2",), DataType.UINT2),
        ):
            if field_names == legacy_field_names:
                return data_type
        raise TypeError(f"Unsupported numpy data type: {dtype}")

    @classmethod
    def from_short_name(cls, short_name: str) -> DataType:
        """Returns the ONNX data type for the short name.

        Args:
            short_name: A short name as produced by :meth:`short_name`, e.g. ``"f32"``.

        Raises:
            TypeError: If the short name is not available for the data type.
        """
        if short_name not in _SHORT_NAME_TO_DATA_TYPE:
            raise TypeError(f"Unknown short name: {short_name}")
        return cls(_SHORT_NAME_TO_DATA_TYPE[short_name])

    @property
    def itemsize(self) -> float:
        """Returns the size of the data type in bytes.

        Computed as ``bitwidth / 8``, so sub-byte types yield fractional sizes
        (e.g. ``0.5`` for 4-bit types).

        Raises:
            TypeError: If the bit width of the data type is not available.
        """
        return self.bitwidth / 8

    @property
    def bitwidth(self) -> int:
        """Returns the bit width of the data type.

        .. versionadded:: 0.1.2

        Raises:
            TypeError: If the data type is not supported.
        """
        if self not in _BITWIDTH_MAP:
            raise TypeError(f"Bitwidth not available for ONNX data type: {self}")
        return _BITWIDTH_MAP[self]

    @property
    def exponent_bitwidth(self) -> int:
        """Returns the bit width of the exponent for floating-point types.

        .. versionadded:: 0.1.8

        Raises:
            TypeError: If the data type is not supported.
        """
        if self.is_floating_point():
            return ml_dtypes.finfo(self.numpy()).nexp

        raise TypeError(f"Exponent not available for ONNX data type: {self}")

    @property
    def mantissa_bitwidth(self) -> int:
        """Returns the bit width of the mantissa for floating-point types.

        .. versionadded:: 0.1.8

        Raises:
            TypeError: If the data type is not supported.
        """
        if self.is_floating_point():
            return ml_dtypes.finfo(self.numpy()).nmant

        raise TypeError(f"Mantissa not available for ONNX data type: {self}")

    @property
    def eps(self) -> int | np.floating[Any]:
        """Returns the difference between 1.0 and the smallest representable float larger than 1.0 for the ONNX data type.

        Returns 1 for integers.

        .. versionadded:: 0.1.8

        Raises:
            TypeError: If the data type is not a numeric data type.
        """
        if self.is_integer():
            return 1

        if self.is_floating_point():
            return ml_dtypes.finfo(self.numpy()).eps

        raise TypeError(f"Eps not available for ONNX data type: {self}")

    @property
    def tiny(self) -> int | np.floating[Any]:
        """Returns the smallest positive non-zero value for the ONNX data type.

        Returns 1 for integers.

        .. versionadded:: 0.1.8

        Raises:
            TypeError: If the data type is not a numeric data type.
        """
        if self.is_integer():
            return 1

        if self.is_floating_point():
            return ml_dtypes.finfo(self.numpy()).tiny

        raise TypeError(f"Tiny not available for ONNX data type: {self}")

    @property
    def min(self) -> int | np.floating[Any]:
        """Returns the minimum representable value for the ONNX data type.

        .. versionadded:: 0.1.8

        Raises:
            TypeError: If the data type is not a numeric data type.
        """
        if self.is_integer():
            return ml_dtypes.iinfo(self.numpy()).min

        if self.is_floating_point():
            return ml_dtypes.finfo(self.numpy()).min

        raise TypeError(f"Minimum not available for ONNX data type: {self}")

    @property
    def max(self) -> int | np.floating[Any]:
        """Returns the maximum representable value for the ONNX data type.

        .. versionadded:: 0.1.8

        Raises:
            TypeError: If the data type is not a numeric data type.
        """
        if self.is_integer():
            return ml_dtypes.iinfo(self.numpy()).max

        if self.is_floating_point():
            return ml_dtypes.finfo(self.numpy()).max

        raise TypeError(f"Maximum not available for ONNX data type: {self}")

    @property
    def precision(self) -> int:
        """Returns the precision for the ONNX dtype if supported.

        For floats returns the approximate number of decimal digits to which
        this kind of float is precise. Returns 0 for integers.

        .. versionadded:: 0.1.8

        Raises:
            TypeError: If the data type is not a numeric data type.
        """
        if self.is_integer():
            return 0

        if self.is_floating_point():
            return ml_dtypes.finfo(self.numpy()).precision

        raise TypeError(f"Precision not available for ONNX data type: {self}")

    @property
    def resolution(self) -> int | np.floating[Any]:
        """Returns the resolution for the ONNX dtype if supported.

        Returns the approximate decimal resolution of this type, i.e.,
        10**-precision. Returns 1 for integers.

        .. versionadded:: 0.1.8

        Raises:
            TypeError: If the data type is not a numeric data type.
        """
        if self.is_integer():
            return 1

        if self.is_floating_point():
            return ml_dtypes.finfo(self.numpy()).resolution

        raise TypeError(f"Resolution not available for ONNX data type: {self}")

    def numpy(self) -> np.dtype:
        """Returns the numpy dtype for the ONNX data type.

        Raises:
            TypeError: If the data type is not supported by numpy.
        """
        if self not in _DATA_TYPE_TO_NP_TYPE:
            raise TypeError(f"Numpy does not support ONNX data type: {self}")
        return _DATA_TYPE_TO_NP_TYPE[self]

    def short_name(self) -> str:
        """Returns the short name of the data type.

        The short name is a string that is used to represent the data type in a more
        compact form. For example, the short name for `DataType.FLOAT` is "f32".
        To get the corresponding data type back, call ``from_short_name`` on a string.

        Naming reference: https://github.com/pytorch/pytorch/blob/4bead7b85ea4160243c74109e0ce9bb80686d016/torch/utils/_dtype_abbrs.py

        Raises:
            TypeError: If the short name is not available for the data type.
        """
        if self not in _DATA_TYPE_TO_SHORT_NAME:
            raise TypeError(f"Short name not available for ONNX data type: {self}")
        return _DATA_TYPE_TO_SHORT_NAME[self]

    def is_floating_point(self) -> bool:
        """Returns True if the data type is a (real) floating point type.

        Complex types are not considered floating point here.
        """
        return self in {
            DataType.FLOAT,
            DataType.FLOAT16,
            DataType.DOUBLE,
            DataType.BFLOAT16,
            DataType.FLOAT8E4M3FN,
            DataType.FLOAT8E4M3FNUZ,
            DataType.FLOAT8E5M2,
            DataType.FLOAT8E5M2FNUZ,
            DataType.FLOAT4E2M1,
            DataType.FLOAT8E8M0,
        }

    def is_integer(self) -> bool:
        """Returns True if the data type is an integer.

        BOOL is not considered an integer type.

        .. versionadded:: 0.1.4
        """
        return self in {
            DataType.UINT8,
            DataType.INT8,
            DataType.UINT16,
            DataType.INT16,
            DataType.INT32,
            DataType.INT64,
            DataType.UINT32,
            DataType.UINT64,
            DataType.UINT4,
            DataType.INT4,
            DataType.INT2,
            DataType.UINT2,
        }

    def is_signed(self) -> bool:
        """Returns True if the data type is a signed type.

        .. versionadded:: 0.1.4
        """
        # FLOAT8E8M0 (E8M0, ml_dtypes ``float8_e8m0fnu``) has no sign bit, so it
        # is unsigned and intentionally absent from this set.
        return self in {
            DataType.FLOAT,
            DataType.INT8,
            DataType.INT16,
            DataType.INT32,
            DataType.INT64,
            DataType.FLOAT16,
            DataType.DOUBLE,
            DataType.COMPLEX64,
            DataType.COMPLEX128,
            DataType.BFLOAT16,
            DataType.FLOAT8E4M3FN,
            DataType.FLOAT8E4M3FNUZ,
            DataType.FLOAT8E5M2,
            DataType.FLOAT8E5M2FNUZ,
            DataType.INT4,
            DataType.FLOAT4E2M1,
            DataType.INT2,
        }

    def is_string(self) -> bool:
        """Returns True if the data type is a string type.

        .. versionadded:: 0.1.8
        """
        return self == DataType.STRING

    def __repr__(self) -> str:
        return self.name

    def __str__(self) -> str:
        return self.__repr__()
|
|
380
|
+
|
|
381
|
+
|
|
382
|
+
# Number of bits used to store one element of each ONNX data type.
# STRING and UNDEFINED have no entry here.
_BITWIDTH_MAP = {
    DataType.FLOAT: 32,
    DataType.UINT8: 8,
    DataType.INT8: 8,
    DataType.UINT16: 16,
    DataType.INT16: 16,
    DataType.INT32: 32,
    DataType.INT64: 64,
    DataType.BOOL: 8,
    DataType.FLOAT16: 16,
    DataType.DOUBLE: 64,
    DataType.UINT32: 32,
    DataType.UINT64: 64,
    DataType.COMPLEX64: 64,  # real and imaginary parts, 32 bits each
    DataType.COMPLEX128: 128,  # real and imaginary parts, 64 bits each
    DataType.BFLOAT16: 16,
    DataType.FLOAT8E4M3FN: 8,
    DataType.FLOAT8E4M3FNUZ: 8,
    DataType.FLOAT8E5M2: 8,
    DataType.FLOAT8E5M2FNUZ: 8,
    DataType.UINT4: 4,
    DataType.INT4: 4,
    DataType.FLOAT4E2M1: 4,
    DataType.FLOAT8E8M0: 8,
    DataType.INT2: 2,
    DataType.UINT2: 2,
}


# numpy dtype -> ONNX DataType.
# ml_dtypes supplies the dtypes that numpy does not provide natively
# (bfloat16, the float8/float4 variants, int4/uint4 and int2/uint2).
_NP_TYPE_TO_DATA_TYPE = {
    np.dtype("bool"): DataType.BOOL,
    np.dtype("complex128"): DataType.COMPLEX128,
    np.dtype("complex64"): DataType.COMPLEX64,
    np.dtype("float16"): DataType.FLOAT16,
    np.dtype("float32"): DataType.FLOAT,
    np.dtype("float64"): DataType.DOUBLE,
    np.dtype("int16"): DataType.INT16,
    np.dtype("int32"): DataType.INT32,
    np.dtype("int64"): DataType.INT64,
    np.dtype("int8"): DataType.INT8,
    np.dtype("object"): DataType.STRING,
    np.dtype("uint16"): DataType.UINT16,
    np.dtype("uint32"): DataType.UINT32,
    np.dtype("uint64"): DataType.UINT64,
    np.dtype("uint8"): DataType.UINT8,
    np.dtype(ml_dtypes.bfloat16): DataType.BFLOAT16,
    np.dtype(ml_dtypes.float8_e4m3fn): DataType.FLOAT8E4M3FN,
    np.dtype(ml_dtypes.float8_e4m3fnuz): DataType.FLOAT8E4M3FNUZ,
    np.dtype(ml_dtypes.float8_e5m2): DataType.FLOAT8E5M2,
    np.dtype(ml_dtypes.float8_e5m2fnuz): DataType.FLOAT8E5M2FNUZ,
    np.dtype(ml_dtypes.float8_e8m0fnu): DataType.FLOAT8E8M0,
    np.dtype(ml_dtypes.int4): DataType.INT4,
    np.dtype(ml_dtypes.uint4): DataType.UINT4,
    np.dtype(ml_dtypes.float4_e2m1fn): DataType.FLOAT4E2M1,
    np.dtype(ml_dtypes.int2): DataType.INT2,
    np.dtype(ml_dtypes.uint2): DataType.UINT2,
}

# ONNX DataType -> numpy dtype: the inverse of the table above, in the same order.
_DATA_TYPE_TO_NP_TYPE = dict(
    zip(_NP_TYPE_TO_DATA_TYPE.values(), _NP_TYPE_TO_DATA_TYPE.keys())
)

# ONNX DataType -> compact short name (e.g. FLOAT -> "f32").
_DATA_TYPE_TO_SHORT_NAME = {
    DataType.UNDEFINED: "undefined",
    DataType.BFLOAT16: "bf16",
    DataType.DOUBLE: "f64",
    DataType.FLOAT: "f32",
    DataType.FLOAT16: "f16",
    DataType.FLOAT8E4M3FN: "f8e4m3fn",
    DataType.FLOAT8E5M2: "f8e5m2",
    DataType.FLOAT8E4M3FNUZ: "f8e4m3fnuz",
    DataType.FLOAT8E5M2FNUZ: "f8e5m2fnuz",
    DataType.FLOAT8E8M0: "f8e8m0",
    DataType.FLOAT4E2M1: "f4e2m1",
    DataType.COMPLEX64: "c64",
    DataType.COMPLEX128: "c128",
    DataType.INT2: "i2",
    DataType.INT4: "i4",
    DataType.INT8: "i8",
    DataType.INT16: "i16",
    DataType.INT32: "i32",
    DataType.INT64: "i64",
    DataType.BOOL: "b8",
    DataType.UINT2: "u2",
    DataType.UINT4: "u4",
    DataType.UINT8: "u8",
    DataType.UINT16: "u16",
    DataType.UINT32: "u32",
    DataType.UINT64: "u64",
    DataType.STRING: "s",
}

# Short name -> ONNX DataType: the inverse of the table above, in the same order.
_SHORT_NAME_TO_DATA_TYPE = dict(
    zip(_DATA_TYPE_TO_SHORT_NAME.values(), _DATA_TYPE_TO_SHORT_NAME.keys())
)
|
|
@@ -0,0 +1,23 @@
|
|
|
1
|
+
# Copyright (c) ONNX Project Contributors
|
|
2
|
+
# SPDX-License-Identifier: Apache-2.0
|
|
3
|
+
"""Utilities for comparing IR graphs."""
|
|
4
|
+
|
|
5
|
+
from __future__ import annotations
|
|
6
|
+
|
|
7
|
+
from onnx_ir import _core
|
|
8
|
+
|
|
9
|
+
# NOTE(justinchuby): We need to ensure a graph has valid inputs and outputs
|
|
10
|
+
# NOTE(justinchuby): A graph may be specified with a set of inputs and outputs
|
|
11
|
+
|
|
12
|
+
|
|
13
|
+
def topologically_equal(graph1: _core.Graph, graph2: _core.Graph) -> bool:
    """Return True if the two graphs are topologically equivalent, without considering initializers.

    .. note::
        This comparison is not implemented yet. Calling this function always
        raises :class:`NotImplementedError`.

    Args:
        graph1: The first graph to compare.
        graph2: The second graph to compare.

    Returns:
        True if the graphs are equal, False otherwise (once implemented).

    Raises:
        NotImplementedError: Always, because the comparison is not implemented yet.
    """
    raise NotImplementedError("topologically_equal is not implemented yet")
|