onnx-ir 0.1.2__tar.gz → 0.1.4__tar.gz
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release: this version of onnx-ir might be problematic.
- {onnx_ir-0.1.2/src/onnx_ir.egg-info → onnx_ir-0.1.4}/PKG-INFO +1 -1
- {onnx_ir-0.1.2 → onnx_ir-0.1.4}/src/onnx_ir/__init__.py +1 -1
- {onnx_ir-0.1.2 → onnx_ir-0.1.4}/src/onnx_ir/_convenience/__init__.py +5 -0
- {onnx_ir-0.1.2 → onnx_ir-0.1.4}/src/onnx_ir/_core.py +36 -22
- {onnx_ir-0.1.2 → onnx_ir-0.1.4}/src/onnx_ir/_enums.py +44 -0
- onnx_ir-0.1.4/src/onnx_ir/_type_casting.py +50 -0
- {onnx_ir-0.1.2 → onnx_ir-0.1.4}/src/onnx_ir/passes/common/__init__.py +4 -0
- onnx_ir-0.1.4/src/onnx_ir/passes/common/common_subexpression_elimination.py +206 -0
- onnx_ir-0.1.4/src/onnx_ir/passes/common/initializer_deduplication.py +56 -0
- onnx_ir-0.1.4/src/onnx_ir/py.typed +1 -0
- {onnx_ir-0.1.2 → onnx_ir-0.1.4}/src/onnx_ir/serde.py +77 -46
- {onnx_ir-0.1.2 → onnx_ir-0.1.4}/src/onnx_ir/tensor_adapters.py +62 -7
- {onnx_ir-0.1.2 → onnx_ir-0.1.4/src/onnx_ir.egg-info}/PKG-INFO +1 -1
- {onnx_ir-0.1.2 → onnx_ir-0.1.4}/src/onnx_ir.egg-info/SOURCES.txt +2 -0
- onnx_ir-0.1.2/src/onnx_ir/_type_casting.py +0 -107
- onnx_ir-0.1.2/src/onnx_ir/passes/common/common_subexpression_elimination.py +0 -177
- {onnx_ir-0.1.2 → onnx_ir-0.1.4}/LICENSE +0 -0
- {onnx_ir-0.1.2 → onnx_ir-0.1.4}/MANIFEST.in +0 -0
- {onnx_ir-0.1.2 → onnx_ir-0.1.4}/README.md +0 -0
- {onnx_ir-0.1.2 → onnx_ir-0.1.4}/pyproject.toml +0 -0
- {onnx_ir-0.1.2 → onnx_ir-0.1.4}/setup.cfg +0 -0
- {onnx_ir-0.1.2 → onnx_ir-0.1.4}/src/onnx_ir/_convenience/_constructors.py +0 -0
- {onnx_ir-0.1.2 → onnx_ir-0.1.4}/src/onnx_ir/_display.py +0 -0
- {onnx_ir-0.1.2 → onnx_ir-0.1.4}/src/onnx_ir/_graph_comparison.py +0 -0
- {onnx_ir-0.1.2 → onnx_ir-0.1.4}/src/onnx_ir/_graph_containers.py +0 -0
- {onnx_ir-0.1.2 → onnx_ir-0.1.4}/src/onnx_ir/_io.py +0 -0
- {onnx_ir-0.1.2 → onnx_ir-0.1.4}/src/onnx_ir/_linked_list.py +0 -0
- {onnx_ir-0.1.2 → onnx_ir-0.1.4}/src/onnx_ir/_metadata.py +0 -0
- {onnx_ir-0.1.2 → onnx_ir-0.1.4}/src/onnx_ir/_name_authority.py +0 -0
- {onnx_ir-0.1.2 → onnx_ir-0.1.4}/src/onnx_ir/_polyfill.py +0 -0
- {onnx_ir-0.1.2 → onnx_ir-0.1.4}/src/onnx_ir/_protocols.py +0 -0
- {onnx_ir-0.1.2 → onnx_ir-0.1.4}/src/onnx_ir/_tape.py +0 -0
- {onnx_ir-0.1.2 → onnx_ir-0.1.4}/src/onnx_ir/_thirdparty/asciichartpy.py +0 -0
- {onnx_ir-0.1.2 → onnx_ir-0.1.4}/src/onnx_ir/_version_utils.py +0 -0
- {onnx_ir-0.1.2 → onnx_ir-0.1.4}/src/onnx_ir/convenience.py +0 -0
- {onnx_ir-0.1.2 → onnx_ir-0.1.4}/src/onnx_ir/external_data.py +0 -0
- {onnx_ir-0.1.2 → onnx_ir-0.1.4}/src/onnx_ir/passes/__init__.py +0 -0
- {onnx_ir-0.1.2 → onnx_ir-0.1.4}/src/onnx_ir/passes/_pass_infra.py +0 -0
- {onnx_ir-0.1.2 → onnx_ir-0.1.4}/src/onnx_ir/passes/common/_c_api_utils.py +0 -0
- {onnx_ir-0.1.2 → onnx_ir-0.1.4}/src/onnx_ir/passes/common/clear_metadata_and_docstring.py +0 -0
- {onnx_ir-0.1.2 → onnx_ir-0.1.4}/src/onnx_ir/passes/common/constant_manipulation.py +0 -0
- {onnx_ir-0.1.2 → onnx_ir-0.1.4}/src/onnx_ir/passes/common/inliner.py +0 -0
- {onnx_ir-0.1.2 → onnx_ir-0.1.4}/src/onnx_ir/passes/common/onnx_checker.py +0 -0
- {onnx_ir-0.1.2 → onnx_ir-0.1.4}/src/onnx_ir/passes/common/shape_inference.py +0 -0
- {onnx_ir-0.1.2 → onnx_ir-0.1.4}/src/onnx_ir/passes/common/topological_sort.py +0 -0
- {onnx_ir-0.1.2 → onnx_ir-0.1.4}/src/onnx_ir/passes/common/unused_removal.py +0 -0
- {onnx_ir-0.1.2 → onnx_ir-0.1.4}/src/onnx_ir/tape.py +0 -0
- {onnx_ir-0.1.2 → onnx_ir-0.1.4}/src/onnx_ir/testing.py +0 -0
- {onnx_ir-0.1.2 → onnx_ir-0.1.4}/src/onnx_ir/traversal.py +0 -0
- {onnx_ir-0.1.2 → onnx_ir-0.1.4}/src/onnx_ir.egg-info/dependency_links.txt +0 -0
- {onnx_ir-0.1.2 → onnx_ir-0.1.4}/src/onnx_ir.egg-info/requires.txt +0 -0
- {onnx_ir-0.1.2 → onnx_ir-0.1.4}/src/onnx_ir.egg-info/top_level.txt +0 -0
{onnx_ir-0.1.2 → onnx_ir-0.1.4}/src/onnx_ir/_convenience/__init__.py
@@ -323,6 +323,9 @@ def create_value_mapping(graph: _core.Graph) -> dict[str, _core.Value]:
     and the first value with that name is returned. Values with empty names
     are excluded from the mapping.
 
+    .. versionchanged:: 0.1.2
+        Values from subgraphs are now included in the mapping.
+
     Args:
         graph: The graph to extract the mapping from.
 
@@ -410,6 +413,8 @@ def get_const_tensor(
     it will propagate the shape and type of the constant tensor to the value
    if `propagate_shape_type` is set to True.
 
+    .. versionadded:: 0.1.2
+
     Args:
         value: The value to get the constant tensor from.
         propagate_shape_type: If True, the shape and type of the value will be
{onnx_ir-0.1.2 → onnx_ir-0.1.4}/src/onnx_ir/_core.py
@@ -417,6 +417,9 @@ class Tensor(TensorBase, _protocols.TensorProtocol, Generic[TArrayCompatible]):
         else:
             self._shape = shape
             self._shape.freeze()
+        if isinstance(value, np.generic):
+            # Turn numpy scalar into a numpy array
+            value = np.array(value)  # type: ignore[assignment]
         if dtype is None:
             if isinstance(value, np.ndarray):
                 self._dtype = _enums.DataType.from_numpy(value.dtype)
@@ -654,15 +657,13 @@ class ExternalTensor(TensorBase, _protocols.TensorProtocol): # pylint: disable=
             self._array = np.empty(self.shape.numpy(), dtype=self.dtype.numpy())
             return
         # Map the whole file into the memory
-        # TODO(justinchuby): Verify if this would exhaust the memory address space
         with open(self.path, "rb") as f:
             self.raw = mmap.mmap(
                 f.fileno(),
                 0,
                 access=mmap.ACCESS_READ,
             )
-
-        dt = np.dtype(self.dtype.numpy()).newbyteorder("<")
+
         if self.dtype in {
             _enums.DataType.INT4,
             _enums.DataType.UINT4,
@@ -672,16 +673,18 @@ class ExternalTensor(TensorBase, _protocols.TensorProtocol): # pylint: disable=
             dt = np.dtype(np.uint8).newbyteorder("<")
             count = self.size // 2 + self.size % 2
         else:
+            # Handle the byte order correctly by always using little endian
+            dt = np.dtype(self.dtype.numpy()).newbyteorder("<")
             count = self.size
+
         self._array = np.frombuffer(self.raw, dtype=dt, offset=self.offset or 0, count=count)
         shape = self.shape.numpy()
-
-        if self.dtype == _enums.DataType.INT4:
-            self._array = _type_casting.unpack_int4(self._array, shape)
-        elif self.dtype == _enums.DataType.UINT4:
-            self._array = _type_casting.unpack_uint4(self._array, shape)
-        elif self.dtype == _enums.DataType.FLOAT4E2M1:
-            self._array = _type_casting.unpack_float4e2m1(self._array, shape)
+
+        if self.dtype.bitwidth == 4:
+            # Unpack the 4bit arrays
+            self._array = _type_casting.unpack_4bitx2(self._array, shape).view(
+                self.dtype.numpy()
+            )
         else:
             self._array = self._array.reshape(shape)
 
@@ -964,7 +967,10 @@ class LazyTensor(TensorBase, _protocols.TensorProtocol): # pylint: disable=too-
 
 
 class PackedTensor(TensorBase, _protocols.TensorProtocol, Generic[TArrayCompatible]): # pylint: disable=too-many-ancestors
-    """A tensor that stores 4bit datatypes in packed format."""
+    """A tensor that stores 4bit datatypes in packed format.
+
+    .. versionadded:: 0.1.2
+    """
 
     __slots__ = (
         "_dtype",
@@ -1065,15 +1071,7 @@ class PackedTensor(TensorBase, _protocols.TensorProtocol, Generic[TArrayCompatib
         """
         array = self.numpy_packed()
         # ONNX IR returns the unpacked arrays
-        if self.dtype == _enums.DataType.INT4:
-            return _type_casting.unpack_int4(array, self.shape.numpy())
-        if self.dtype == _enums.DataType.UINT4:
-            return _type_casting.unpack_uint4(array, self.shape.numpy())
-        if self.dtype == _enums.DataType.FLOAT4E2M1:
-            return _type_casting.unpack_float4e2m1(array, self.shape.numpy())
-        raise TypeError(
-            f"PackedTensor only supports INT4, UINT4, FLOAT4E2M1, but got {self.dtype}"
-        )
+        return _type_casting.unpack_4bitx2(array, self.shape.numpy()).view(self.dtype.numpy())
 
     def numpy_packed(self) -> npt.NDArray[np.uint8]:
         """Return the tensor as a packed array."""
@@ -2335,6 +2333,12 @@ class Graph(_protocols.GraphProtocol, Sequence[Node], _display.PrettyPrintable):
     seen as a Sequence of nodes and should be used as such. For example, to obtain
     all nodes as a list, call ``list(graph)``.
 
+    .. versionchanged:: 0.1.1
+        Values with non-none producers will be rejected as graph inputs or initializers.
+
+    .. versionadded:: 0.1.1
+        Added ``add`` method to initializers and attributes.
+
     Attributes:
         name: The name of the graph.
         inputs: The input values of the graph.
@@ -2545,12 +2549,17 @@ class Graph(_protocols.GraphProtocol, Sequence[Node], _display.PrettyPrintable):
         Consider using
         :class:`onnx_ir.traversal.RecursiveGraphIterator` for more advanced
         traversals on nodes.
+
+        .. versionadded:: 0.1.2
         """
         # NOTE: This is a method specific to Graph, not required by the protocol unless proven
         return onnx_ir.traversal.RecursiveGraphIterator(self)
 
     def subgraphs(self) -> Iterator[Graph]:
-        """Get all subgraphs in the graph in O(#nodes + #attributes) time."""
+        """Get all subgraphs in the graph in O(#nodes + #attributes) time.
+
+        .. versionadded:: 0.1.2
+        """
         seen_graphs: set[Graph] = set()
         for node in onnx_ir.traversal.RecursiveGraphIterator(self):
             graph = node.graph
@@ -3216,12 +3225,17 @@ class Function(_protocols.FunctionProtocol, Sequence[Node], _display.PrettyPrint
         Consider using
         :class:`onnx_ir.traversal.RecursiveGraphIterator` for more advanced
         traversals on nodes.
+
+        .. versionadded:: 0.1.2
         """
         # NOTE: This is a method specific to Graph, not required by the protocol unless proven
         return onnx_ir.traversal.RecursiveGraphIterator(self)
 
     def subgraphs(self) -> Iterator[Graph]:
-        """Get all subgraphs in the function in O(#nodes + #attributes) time."""
+        """Get all subgraphs in the function in O(#nodes + #attributes) time.
+
+        .. versionadded:: 0.1.2
+        """
         seen_graphs: set[Graph] = set()
         for node in onnx_ir.traversal.RecursiveGraphIterator(self):
             graph = node.graph
{onnx_ir-0.1.2 → onnx_ir-0.1.4}/src/onnx_ir/_enums.py
@@ -120,6 +120,8 @@ class DataType(enum.IntEnum):
     def bitwidth(self) -> int:
         """Returns the bit width of the data type.
 
+        .. versionadded:: 0.1.2
+
         Raises:
             TypeError: If the data type is not supported.
         """
@@ -167,6 +169,48 @@ class DataType(enum.IntEnum):
             DataType.FLOAT4E2M1,
         }
 
+    def is_integer(self) -> bool:
+        """Returns True if the data type is an integer.
+
+        .. versionadded:: 0.1.4
+        """
+        return self in {
+            DataType.UINT8,
+            DataType.INT8,
+            DataType.UINT16,
+            DataType.INT16,
+            DataType.INT32,
+            DataType.INT64,
+            DataType.UINT32,
+            DataType.UINT64,
+            DataType.UINT4,
+            DataType.INT4,
+        }
+
+    def is_signed(self) -> bool:
+        """Returns True if the data type is a signed type.
+
+        .. versionadded:: 0.1.4
+        """
+        return self in {
+            DataType.FLOAT,
+            DataType.INT8,
+            DataType.INT16,
+            DataType.INT32,
+            DataType.INT64,
+            DataType.FLOAT16,
+            DataType.DOUBLE,
+            DataType.COMPLEX64,
+            DataType.COMPLEX128,
+            DataType.BFLOAT16,
+            DataType.FLOAT8E4M3FN,
+            DataType.FLOAT8E4M3FNUZ,
+            DataType.FLOAT8E5M2,
+            DataType.FLOAT8E5M2FNUZ,
+            DataType.INT4,
+            DataType.FLOAT4E2M1,
+        }
+
     def __repr__(self) -> str:
         return self.name
 
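For illustration, the two new predicates can be exercised directly (a minimal sketch against the public `onnx_ir.DataType` enum; the expected results follow from the sets above):

    import onnx_ir as ir

    assert ir.DataType.INT4.is_integer() and ir.DataType.INT4.is_signed()
    assert ir.DataType.UINT8.is_integer() and not ir.DataType.UINT8.is_signed()
    assert not ir.DataType.FLOAT.is_integer() and ir.DataType.FLOAT.is_signed()
    assert not ir.DataType.BOOL.is_integer() and not ir.DataType.BOOL.is_signed()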
onnx_ir-0.1.4/src/onnx_ir/_type_casting.py
@@ -0,0 +1,50 @@
+# Copyright (c) ONNX Project Contributors
+# SPDX-License-Identifier: Apache-2.0
+"""Numpy utilities for non-native type operation."""
+
+from __future__ import annotations
+
+import typing
+from collections.abc import Sequence
+
+import numpy as np
+
+if typing.TYPE_CHECKING:
+    import numpy.typing as npt
+
+
+def pack_4bitx2(array: np.ndarray) -> npt.NDArray[np.uint8]:
+    """Convert a numpy array to flatten, packed int4/uint4. Elements must be in the correct range."""
+    # Create a 1D copy
+    array_flat = array.ravel().view(np.uint8).copy()
+    size = array.size
+    odd_sized = size % 2 == 1
+    if odd_sized:
+        array_flat.resize([size + 1], refcheck=False)
+    array_flat &= 0x0F
+    array_flat[1::2] <<= 4
+    return array_flat[0::2] | array_flat[1::2]  # type: ignore[return-type]
+
+
+def unpack_4bitx2(data: npt.NDArray[np.uint8], dims: Sequence[int]) -> npt.NDArray[np.uint8]:
+    """Convert a packed uint4 array to unpacked uint4 array represented as uint8.
+
+    Args:
+        data: A numpy array.
+        dims: The dimensions are used to reshape the unpacked buffer.
+
+    Returns:
+        A numpy array of int8/uint8 reshaped to dims.
+    """
+    assert data.dtype == np.uint8, "Input data must be of type uint8"
+    result = np.empty([data.size * 2], dtype=data.dtype)
+    array_low = data & np.uint8(0x0F)
+    array_high = data & np.uint8(0xF0)
+    array_high >>= np.uint8(4)
+    result[0::2] = array_low
+    result[1::2] = array_high
+    if result.size == np.prod(dims) + 1:
+        # handle single-element padding due to odd number of elements
+        result = result[:-1]
+    result.resize(dims, refcheck=False)
+    return result
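To see the 4-bit packing scheme round-trip, here is a small sketch (requires only `numpy`; note that `_type_casting` is a private module, used here purely for illustration):

    import numpy as np
    from onnx_ir import _type_casting

    values = np.array([1, 2, 3, 4, 5], dtype=np.uint8)  # odd-sized on purpose
    packed = _type_casting.pack_4bitx2(values)  # 3 bytes: 0x21, 0x43, 0x05
    unpacked = _type_casting.unpack_4bitx2(packed, dims=[5])
    assert (unpacked == values).all()

The low nibble of each byte holds the even-indexed element and the high nibble the odd-indexed one; an odd-sized array is padded with a zero nibble that `unpack_4bitx2` trims off again.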
{onnx_ir-0.1.2 → onnx_ir-0.1.4}/src/onnx_ir/passes/common/__init__.py
@@ -6,6 +6,7 @@ __all__ = [
     "CheckerPass",
     "ClearMetadataAndDocStringPass",
     "CommonSubexpressionEliminationPass",
+    "DeduplicateInitializersPass",
     "InlinePass",
     "LiftConstantsToInitializersPass",
     "LiftSubgraphInitializersToMainGraphPass",
@@ -29,6 +30,9 @@ from onnx_ir.passes.common.constant_manipulation import (
     LiftSubgraphInitializersToMainGraphPass,
     RemoveInitializersFromInputsPass,
 )
+from onnx_ir.passes.common.initializer_deduplication import (
+    DeduplicateInitializersPass,
+)
 from onnx_ir.passes.common.inliner import InlinePass
 from onnx_ir.passes.common.onnx_checker import CheckerPass
 from onnx_ir.passes.common.shape_inference import ShapeInferencePass
onnx_ir-0.1.4/src/onnx_ir/passes/common/common_subexpression_elimination.py
@@ -0,0 +1,206 @@
+# Copyright (c) ONNX Project Contributors
+# SPDX-License-Identifier: Apache-2.0
+"""Eliminate common subexpression in ONNX graphs."""
+
+from __future__ import annotations
+
+__all__ = [
+    "CommonSubexpressionEliminationPass",
+]
+
+import logging
+from collections.abc import Sequence
+
+import onnx_ir as ir
+
+logger = logging.getLogger(__name__)
+
+
+class CommonSubexpressionEliminationPass(ir.passes.InPlacePass):
+    """Eliminate common subexpression in ONNX graphs.
+
+    .. versionadded:: 0.1.1
+
+    .. versionchanged:: 0.1.3
+        Constant nodes with values smaller than ``size_limit`` will be CSE'd.
+
+    Attributes:
+        size_limit: The maximum size of the tensor to be csed. If the tensor contains
+            number of elements larger than size_limit, it will not be cse'd. Default is 10.
+
+    """
+
+    def __init__(self, size_limit: int = 10):
+        """Initialize the CommonSubexpressionEliminationPass."""
+        super().__init__()
+        self.size_limit = size_limit
+
+    def call(self, model: ir.Model) -> ir.passes.PassResult:
+        """Return the same ir.Model but with CSE applied to the graph."""
+        graph = model.graph
+        modified = self._eliminate_common_subexpression(graph)
+
+        return ir.passes.PassResult(
+            model,
+            modified=modified,
+        )
+
+    def _eliminate_common_subexpression(self, graph: ir.Graph) -> bool:
+        """Eliminate common subexpression in ONNX graphs."""
+        modified: bool = False
+        # node to node identifier, length of outputs, inputs, and attributes
+        existing_node_info_to_the_node: dict[
+            tuple[
+                ir.OperatorIdentifier,
+                int,  # len(outputs)
+                tuple[int, ...],  # input ids
+                tuple[tuple[str, object], ...],  # attributes
+            ],
+            ir.Node,
+        ] = {}
+
+        for node in graph:
+            # Skip control flow ops like Loop and If.
+            control_flow_op: bool = False
+            # Skip large tensors to avoid cse weights and bias.
+            large_tensor: bool = False
+            # Use equality to check if the node is a common subexpression.
+            attributes = {}
+            for k, v in node.attributes.items():
+                # TODO(exporter team): CSE subgraphs.
+                # NOTE: control flow ops like Loop and If won't be CSEd
+                # because attribute: graph won't match.
+                if v.type in (ir.AttributeType.GRAPH, ir.AttributeType.GRAPHS):
+                    control_flow_op = True
+                    break
+                # The attribute value could be directly taken from the original
+                # protobuf, so we need to make a copy of it.
+                value = v.value
+                if v.type in (
+                    ir.AttributeType.INTS,
+                    ir.AttributeType.FLOATS,
+                    ir.AttributeType.STRINGS,
+                ):
+                    # For INT, FLOAT and STRING attributes, we convert them to tuples
+                    # to ensure they are hashable.
+                    value = tuple(value)
+                elif v.type is ir.AttributeType.TENSOR:
+                    if value.size > self.size_limit:
+                        # If the tensor is larger than the size limit, we skip it.
+                        large_tensor = True
+                        break
+                    np_value = value.numpy()
+
+                    value = (np_value.shape, str(np_value.dtype), np_value.tobytes())
+                attributes[k] = value
+
+            if control_flow_op:
+                # If the node is a control flow op, we skip it.
+                logger.debug("Skipping control flow op %s", node)
+                continue
+
+            if large_tensor:
+                # If the node has a large tensor, we skip it.
+                logger.debug("Skipping large tensor in node %s", node)
+                continue
+
+            if _is_non_deterministic_op(node):
+                # If the node is a non-deterministic op, we skip it.
+                logger.debug("Skipping non-deterministic op %s", node)
+                continue
+
+            node_info = (
+                node.op_identifier(),
+                len(node.outputs),
+                tuple(id(input) for input in node.inputs),
+                tuple(sorted(attributes.items())),
+            )
+            # Check if the node is a common subexpression.
+            if node_info in existing_node_info_to_the_node:
+                # If it is, this node has an existing node with the same
+                # operator, number of outputs, inputs, and attributes.
+                # We replace the node with the existing node.
+                modified = True
+                existing_node = existing_node_info_to_the_node[node_info]
+                _remove_node_and_replace_values(
+                    graph,
+                    remove_node=node,
+                    remove_values=node.outputs,
+                    new_values=existing_node.outputs,
+                )
+                logger.debug("Reusing node %s", existing_node)
+            else:
+                # If it is not, add to the mapping.
+                existing_node_info_to_the_node[node_info] = node
+        return modified
+
+
+def _remove_node_and_replace_values(
+    graph: ir.Graph,
+    /,
+    remove_node: ir.Node,
+    remove_values: Sequence[ir.Value],
+    new_values: Sequence[ir.Value],
+) -> None:
+    """Replaces nodes and values in the graph or function.
+
+    Args:
+        graph: The graph to replace nodes and values in.
+        remove_node: The node to remove.
+        remove_values: The values to replace.
+        new_values: The values to replace with.
+    """
+    # Reconnect the users of the deleted values to use the new values
+    ir.convenience.replace_all_uses_with(remove_values, new_values)
+    # Update graph/function outputs if the node generates output
+    if any(remove_value.is_graph_output() for remove_value in remove_values):
+        replacement_mapping = dict(zip(remove_values, new_values))
+        for idx, graph_output in enumerate(graph.outputs):
+            if graph_output in replacement_mapping:
+                new_value = replacement_mapping[graph_output]
+                if new_value.is_graph_output() or new_value.is_graph_input():
+                    # If the new value is also a graph input/output, we need to
+                    # create a Identity node to preserve the remove_value and
+                    # prevent from changing new_value name.
+                    identity_node = ir.node(
+                        "Identity",
+                        inputs=[new_value],
+                        outputs=[
+                            ir.Value(
+                                name=graph_output.name,
+                                type=graph_output.type,
+                                shape=graph_output.shape,
+                            )
+                        ],
+                    )
+                    # reuse the name of the graph output
+                    graph.outputs[idx] = identity_node.outputs[0]
+                    graph.insert_before(
+                        remove_node,
+                        identity_node,
+                    )
+                else:
+                    # if new_value is not graph output, we just
+                    # update it to use old_value name.
+                    new_value.name = graph_output.name
+                    graph.outputs[idx] = new_value
+
+    graph.remove(remove_node, safe=True)
+
+
+def _is_non_deterministic_op(node: ir.Node) -> bool:
+    non_deterministic_ops = frozenset(
+        {
+            "RandomUniform",
+            "RandomNormal",
+            "RandomUniformLike",
+            "RandomNormalLike",
+            "Multinomial",
+        }
+    )
+    return node.op_type in non_deterministic_ops and _is_onnx_domain(node.domain)
+
+
+def _is_onnx_domain(d: str) -> bool:
+    """Check if the domain is the ONNX domain."""
+    return d == ""
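A minimal usage sketch for the pass (the model path is hypothetical; passes in `onnx_ir` are callable and return a `PassResult`):

    import onnx_ir as ir
    from onnx_ir.passes.common import CommonSubexpressionEliminationPass

    model = ir.load("model.onnx")  # hypothetical path
    result = CommonSubexpressionEliminationPass(size_limit=10)(model)
    if result.modified:
        ir.save(result.model, "model_cse.onnx")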
onnx_ir-0.1.4/src/onnx_ir/passes/common/initializer_deduplication.py
@@ -0,0 +1,56 @@
+# Copyright (c) ONNX Project Contributors
+# SPDX-License-Identifier: Apache-2.0
+"""Pass for removing duplicated initializer tensors from a graph."""
+
+from __future__ import annotations
+
+__all__ = [
+    "DeduplicateInitializersPass",
+]
+
+
+import onnx_ir as ir
+
+
+class DeduplicateInitializersPass(ir.passes.InPlacePass):
+    """Remove duplicated initializer tensors from the graph.
+
+    This pass detects initializers with identical shape, dtype, and content,
+    and replaces all duplicate references with a canonical one.
+
+    To deduplicate initializers from subgraphs, use :class:`~onnx_ir.passes.common.LiftSubgraphInitializersToMainGraphPass`
+    to lift the initializers to the main graph first before running pass.
+
+    .. versionadded:: 0.1.3
+    """
+
+    def __init__(self, size_limit: int = 1024):
+        super().__init__()
+        self.size_limit = size_limit
+
+    def call(self, model: ir.Model) -> ir.passes.PassResult:
+        graph = model.graph
+        initializers: dict[tuple[ir.DataType, tuple[int, ...], bytes], ir.Value] = {}
+        modified = False
+
+        for initializer in tuple(graph.initializers.values()):
+            # TODO(justinchuby): Handle subgraphs as well. For now users can lift initializers
+            # out from the main graph before running this pass.
+            const_val = initializer.const_value
+            if const_val is None:
+                # Skip if initializer has no constant value
+                continue
+
+            if const_val.size > self.size_limit:
+                continue
+
+            key = (const_val.dtype, tuple(const_val.shape), const_val.tobytes())
+            if key in initializers:
+                modified = True
+                ir.convenience.replace_all_uses_with(initializer, initializers[key])  # type: ignore[index]
+                assert initializer.name is not None
+                graph.initializers.pop(initializer.name)
+            else:
+                initializers[key] = initializer  # type: ignore[index]
+
+        return ir.passes.PassResult(model=model, modified=modified)
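A corresponding sketch for the new pass (hypothetical model path; subgraph initializers are lifted first, as the docstring recommends):

    import onnx_ir as ir
    from onnx_ir.passes.common import (
        DeduplicateInitializersPass,
        LiftSubgraphInitializersToMainGraphPass,
    )

    model = ir.load("model.onnx")  # hypothetical path
    model = LiftSubgraphInitializersToMainGraphPass()(model).model
    result = DeduplicateInitializersPass(size_limit=1024)(model)
    print("deduplicated:", result.modified)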
onnx_ir-0.1.4/src/onnx_ir/py.typed
@@ -0,0 +1 @@
+
{onnx_ir-0.1.2 → onnx_ir-0.1.4}/src/onnx_ir/serde.py
@@ -74,7 +74,6 @@ from onnx_ir import _convenience, _core, _enums, _protocols, _type_casting
 
 if typing.TYPE_CHECKING:
     import google.protobuf.internal.containers as proto_containers
-    import numpy.typing as npt
 
 logger = logging.getLogger(__name__)
 
@@ -117,13 +116,6 @@ def _little_endian_dtype(dtype) -> np.dtype:
     return np.dtype(dtype).newbyteorder("<")
 
 
-def _unflatten_complex(
-    array: npt.NDArray[np.float32 | np.float64],
-) -> npt.NDArray[np.complex64 | np.complex128]:
-    """Convert the real representation of a complex dtype to the complex dtype."""
-    return array[::2] + 1j * array[1::2]
-
-
 @typing.overload
 def from_proto(proto: onnx.ModelProto) -> _core.Model: ...  # type: ignore[overload-overlap]
 @typing.overload
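The deleted helper is superseded by a plain dtype view in `numpy()` later in this diff; the two are equivalent on little-endian interleaved data, as this numpy-only sketch shows:

    import numpy as np

    interleaved = np.array([1.0, 2.0, 3.0, 4.0], dtype=np.float32)  # re, im, re, im
    old_way = interleaved[::2] + 1j * interleaved[1::2]  # what _unflatten_complex did
    new_way = interleaved.view(np.complex64)  # what serde does now
    assert np.array_equal(old_way, new_way)  # both give [1+2j, 3+4j]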
@@ -200,6 +192,9 @@ def from_onnx_text(
 
     Read more about the textual representation at: https://onnx.ai/onnx/repo-docs/Syntax.html
 
+    .. versionchanged:: 0.1.2
+        Added the ``initializers`` argument.
+
     Args:
         model_text: The ONNX textual representation of the model.
         initializers: Tensors to be added as initializers. If provided, these tensors
@@ -237,6 +232,8 @@ def to_onnx_text(
 ) -> str:
     """Convert the IR model to the ONNX textual representation.
 
+    .. versionadded:: 0.1.2
+
     Args:
         model: The IR model to convert.
         exclude_initializers: If True, the initializers will not be included in the output.
@@ -386,54 +383,88 @@ class TensorProtoTensor(_core.TensorBase): # pylint: disable=too-many-ancestors
                 "Cannot convert external tensor to numpy array. Use ir.ExternalTensor instead."
             )
 
+        shape = self._proto.dims
+
         if self._proto.HasField("raw_data"):
-
-
-
-
-
-
-
-
-
-
+            if dtype.bitwidth == 4:
+                return _type_casting.unpack_4bitx2(
+                    np.frombuffer(self._proto.raw_data, dtype=np.uint8), shape
+                ).view(dtype.numpy())
+            return np.frombuffer(
+                self._proto.raw_data, dtype=dtype.numpy().newbyteorder("<")
+            ).reshape(shape)
+        if dtype == _enums.DataType.STRING:
+            return np.array(self._proto.string_data).reshape(shape)
+        if self._proto.int32_data:
+            assert dtype in {
+                _enums.DataType.BFLOAT16,
+                _enums.DataType.BOOL,
+                _enums.DataType.FLOAT16,
+                _enums.DataType.FLOAT4E2M1,
                 _enums.DataType.FLOAT8E4M3FN,
                 _enums.DataType.FLOAT8E4M3FNUZ,
                 _enums.DataType.FLOAT8E5M2,
                 _enums.DataType.FLOAT8E5M2FNUZ,
-
-
-
-
-
+                _enums.DataType.INT16,
+                _enums.DataType.INT32,
+                _enums.DataType.INT4,
+                _enums.DataType.INT8,
+                _enums.DataType.UINT16,
+                _enums.DataType.UINT4,
+                _enums.DataType.UINT8,
+            }, f"Unsupported dtype {dtype} for int32_data"
+            array = np.array(self._proto.int32_data, dtype=_little_endian_dtype(np.int32))
+            if dtype.bitwidth == 32:
+                return array.reshape(shape)
+            if dtype.bitwidth == 16:
+                # Reinterpret the int32 as float16 or bfloat16
+                return array.astype(np.uint16).view(dtype.numpy()).reshape(shape)
+            if dtype.bitwidth == 8:
+                return array.astype(np.uint8).view(dtype.numpy()).reshape(shape)
+            if dtype.bitwidth == 4:
+                return _type_casting.unpack_4bitx2(array.astype(np.uint8), shape).view(
+                    dtype.numpy()
+                )
+            raise ValueError(
+                f"Unsupported dtype {dtype} for int32_data with bitwidth {dtype.bitwidth}"
+            )
+        if self._proto.int64_data:
+            assert dtype in {
+                _enums.DataType.INT64,
+            }, f"Unsupported dtype {dtype} for int64_data"
+            return np.array(
+                self._proto.int64_data, dtype=_little_endian_dtype(np.int64)
+            ).reshape(shape)
+        if self._proto.uint64_data:
+            assert dtype in {
+                _enums.DataType.UINT64,
+                _enums.DataType.UINT32,
+            }, f"Unsupported dtype {dtype} for uint64_data"
             array = np.array(self._proto.uint64_data, dtype=_little_endian_dtype(np.uint64))
-
+            if dtype == _enums.DataType.UINT32:
+                return array.astype(np.uint32).reshape(shape)
+            return array.reshape(shape)
+        if self._proto.float_data:
+            assert dtype in {
+                _enums.DataType.FLOAT,
+                _enums.DataType.COMPLEX64,
+            }, f"Unsupported dtype {dtype} for float_data"
             array = np.array(self._proto.float_data, dtype=_little_endian_dtype(np.float32))
             if dtype == _enums.DataType.COMPLEX64:
-                array = _unflatten_complex(array)
-
+                return array.view(np.complex64).reshape(shape)
+            return array.reshape(shape)
+        if self._proto.double_data:
+            assert dtype in {
+                _enums.DataType.DOUBLE,
+                _enums.DataType.COMPLEX128,
+            }, f"Unsupported dtype {dtype} for double_data"
             array = np.array(self._proto.double_data, dtype=_little_endian_dtype(np.float64))
             if dtype == _enums.DataType.COMPLEX128:
-                array = _unflatten_complex(array)
-
-
-
-
-            return np.array([], dtype=dtype.numpy())
-        else:
-            # Otherwise we return a size 0 array with the correct shape
-            return np.zeros(self._proto.dims, dtype=dtype.numpy())
-
-        if dtype == _enums.DataType.INT4:
-            return _type_casting.unpack_int4(array.astype(np.uint8), self._proto.dims)
-        elif dtype == _enums.DataType.UINT4:
-            return _type_casting.unpack_uint4(array.astype(np.uint8), self._proto.dims)
-        elif dtype == _enums.DataType.FLOAT4E2M1:
-            return _type_casting.unpack_float4e2m1(array.astype(np.uint8), self._proto.dims)
-        else:
-            # Otherwise convert to the correct dtype and reshape
-            # Note we cannot use view() here because the storage dtype may not be the same size as the target
-            return array.astype(dtype.numpy()).reshape(self._proto.dims)
+                return array.view(np.complex128).reshape(shape)
+            return array.reshape(shape)
+
+        # Empty tensor. We return a size 0 array with the correct shape
+        return np.zeros(shape, dtype=dtype.numpy())
 
     def tobytes(self) -> bytes:
         """Return the tensor as a byte string conformed to the ONNX specification, in little endian.
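The effect of the rewritten raw-data path on 4-bit types can be sketched as follows (hypothetical values; assumes onnx ≥ 1.16 for the UINT4 enum and that `from_proto` accepts a `TensorProto`, as its overloads suggest):

    import onnx
    import onnx.helper
    import onnx_ir as ir

    # Two packed bytes hold four UINT4 elements: 0x21 -> (1, 2), 0x43 -> (3, 4).
    proto = onnx.helper.make_tensor(
        "t", onnx.TensorProto.UINT4, dims=[4], vals=b"\x21\x43", raw=True
    )
    tensor = ir.from_proto(proto)
    print(tensor.numpy())  # expected [1 2 3 4], unpacked via unpack_4bitx2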
{onnx_ir-0.1.2 → onnx_ir-0.1.4}/src/onnx_ir/tensor_adapters.py
@@ -29,6 +29,8 @@ Example::
 from __future__ import annotations
 
 __all__ = [
+    "from_torch_dtype",
+    "to_torch_dtype",
     "TorchTensor",
 ]
 
@@ -44,14 +46,17 @@ if TYPE_CHECKING:
     import torch
 
 
-class TorchTensor(_core.Tensor):
-    def __init__(
-        self, tensor: torch.Tensor, name: str | None = None, doc_string: str | None = None
-    ):
-        # Pass the tensor as the raw data to ir.Tensor's constructor
+_TORCH_DTYPE_TO_ONNX: dict[torch.dtype, ir.DataType] | None = None
+_ONNX_DTYPE_TO_TORCH: dict[ir.DataType, torch.dtype] | None = None
+
+
+def from_torch_dtype(dtype: torch.dtype) -> ir.DataType:
+    """Convert a PyTorch dtype to an ONNX IR DataType."""
+    global _TORCH_DTYPE_TO_ONNX
+    if _TORCH_DTYPE_TO_ONNX is None:
         import torch
 
-        _TORCH_DTYPE_TO_ONNX: dict[torch.dtype, ir.DataType] = {
+        _TORCH_DTYPE_TO_ONNX = {
             torch.bfloat16: ir.DataType.BFLOAT16,
             torch.bool: ir.DataType.BOOL,
             torch.complex128: ir.DataType.COMPLEX128,
@@ -72,8 +77,58 @@ class TorchTensor(_core.Tensor):
             torch.uint32: ir.DataType.UINT32,
             torch.uint64: ir.DataType.UINT64,
         }
+    if dtype not in _TORCH_DTYPE_TO_ONNX:
+        raise TypeError(
+            f"Unsupported PyTorch dtype '{dtype}'. "
+            "Please use a supported dtype from the list: "
+            f"{list(_TORCH_DTYPE_TO_ONNX.keys())}"
+        )
+    return _TORCH_DTYPE_TO_ONNX[dtype]
+
+
+def to_torch_dtype(dtype: ir.DataType) -> torch.dtype:
+    """Convert an ONNX IR DataType to a PyTorch dtype."""
+    global _ONNX_DTYPE_TO_TORCH
+    if _ONNX_DTYPE_TO_TORCH is None:
+        import torch
+
+        _ONNX_DTYPE_TO_TORCH = {
+            ir.DataType.BFLOAT16: torch.bfloat16,
+            ir.DataType.BOOL: torch.bool,
+            ir.DataType.COMPLEX128: torch.complex128,
+            ir.DataType.COMPLEX64: torch.complex64,
+            ir.DataType.FLOAT16: torch.float16,
+            ir.DataType.FLOAT: torch.float32,
+            ir.DataType.DOUBLE: torch.float64,
+            ir.DataType.FLOAT8E4M3FN: torch.float8_e4m3fn,
+            ir.DataType.FLOAT8E4M3FNUZ: torch.float8_e4m3fnuz,
+            ir.DataType.FLOAT8E5M2: torch.float8_e5m2,
+            ir.DataType.FLOAT8E5M2FNUZ: torch.float8_e5m2fnuz,
+            ir.DataType.INT16: torch.int16,
+            ir.DataType.INT32: torch.int32,
+            ir.DataType.INT64: torch.int64,
+            ir.DataType.INT8: torch.int8,
+            ir.DataType.UINT8: torch.uint8,
+            ir.DataType.UINT16: torch.uint16,
+            ir.DataType.UINT32: torch.uint32,
+            ir.DataType.UINT64: torch.uint64,
+        }
+    if dtype not in _ONNX_DTYPE_TO_TORCH:
+        raise TypeError(
+            f"Unsupported conversion from ONNX dtype '{dtype}' to torch. "
+            "Please use a supported dtype from the list: "
+            f"{list(_ONNX_DTYPE_TO_TORCH.keys())}"
+        )
+    return _ONNX_DTYPE_TO_TORCH[dtype]
+
+
+class TorchTensor(_core.Tensor):
+    def __init__(
+        self, tensor: torch.Tensor, name: str | None = None, doc_string: str | None = None
+    ):
+        # Pass the tensor as the raw data to ir.Tensor's constructor
         super().__init__(
-            tensor, dtype=_TORCH_DTYPE_TO_ONNX[tensor.dtype], name=name, doc_string=doc_string
+            tensor, dtype=from_torch_dtype(tensor.dtype), name=name, doc_string=doc_string
         )
 
     def numpy(self) -> npt.NDArray:
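A usage sketch for the new conversion helpers and the relocated `TorchTensor` (assumes `torch` is installed; the tensor values are hypothetical):

    import torch
    import onnx_ir as ir
    from onnx_ir.tensor_adapters import TorchTensor, from_torch_dtype, to_torch_dtype

    assert from_torch_dtype(torch.bfloat16) == ir.DataType.BFLOAT16
    assert to_torch_dtype(ir.DataType.FLOAT8E5M2) == torch.float8_e5m2

    weight = TorchTensor(torch.ones(2, 2, dtype=torch.bfloat16), name="w")
    print(weight.dtype)  # DataType.BFLOAT16
    print(weight.numpy())  # numpy array with the matching ml_dtypes dtype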
{onnx_ir-0.1.2 → onnx_ir-0.1.4}/src/onnx_ir.egg-info/SOURCES.txt
@@ -19,6 +19,7 @@ src/onnx_ir/_type_casting.py
 src/onnx_ir/_version_utils.py
 src/onnx_ir/convenience.py
 src/onnx_ir/external_data.py
+src/onnx_ir/py.typed
 src/onnx_ir/serde.py
 src/onnx_ir/tape.py
 src/onnx_ir/tensor_adapters.py
@@ -39,6 +40,7 @@ src/onnx_ir/passes/common/_c_api_utils.py
 src/onnx_ir/passes/common/clear_metadata_and_docstring.py
 src/onnx_ir/passes/common/common_subexpression_elimination.py
 src/onnx_ir/passes/common/constant_manipulation.py
+src/onnx_ir/passes/common/initializer_deduplication.py
 src/onnx_ir/passes/common/inliner.py
 src/onnx_ir/passes/common/onnx_checker.py
 src/onnx_ir/passes/common/shape_inference.py
onnx_ir-0.1.2/src/onnx_ir/_type_casting.py
@@ -1,107 +0,0 @@
-# Copyright (c) ONNX Project Contributors
-# SPDX-License-Identifier: Apache-2.0
-"""Numpy utilities for non-native type operation."""
-# TODO(justinchuby): Upstream the logic to onnx
-
-from __future__ import annotations
-
-import typing
-from collections.abc import Sequence
-
-import ml_dtypes
-import numpy as np
-
-if typing.TYPE_CHECKING:
-    import numpy.typing as npt
-
-
-def pack_4bitx2(array: np.ndarray) -> npt.NDArray[np.uint8]:
-    """Convert a numpy array to flatten, packed int4/uint4. Elements must be in the correct range."""
-    # Create a 1D copy
-    array_flat = array.ravel().view(np.uint8).copy()
-    size = array.size
-    odd_sized = size % 2 == 1
-    if odd_sized:
-        array_flat.resize([size + 1], refcheck=False)
-    array_flat &= 0x0F
-    array_flat[1::2] <<= 4
-    return array_flat[0::2] | array_flat[1::2]  # type: ignore[return-type]
-
-
-def _unpack_uint4_as_uint8(
-    data: npt.NDArray[np.uint8], dims: Sequence[int]
-) -> npt.NDArray[np.uint8]:
-    """Convert a packed uint4 array to unpacked uint4 array represented as uint8.
-
-    Args:
-        data: A numpy array.
-        dims: The dimensions are used to reshape the unpacked buffer.
-
-    Returns:
-        A numpy array of int8/uint8 reshaped to dims.
-    """
-    assert data.dtype == np.uint8, "Input data must be of type uint8"
-    result = np.empty([data.size * 2], dtype=data.dtype)
-    array_low = data & np.uint8(0x0F)
-    array_high = data & np.uint8(0xF0)
-    array_high >>= np.uint8(4)
-    result[0::2] = array_low
-    result[1::2] = array_high
-    if result.size == np.prod(dims) + 1:
-        # handle single-element padding due to odd number of elements
-        result = result[:-1]
-    result.resize(dims, refcheck=False)
-    return result
-
-
-def unpack_uint4(
-    data: npt.NDArray[np.uint8], dims: Sequence[int]
-) -> npt.NDArray[ml_dtypes.uint4]:
-    """Convert a packed uint4 array to unpacked uint4 array represented as uint8.
-
-    Args:
-        data: A numpy array.
-        dims: The dimensions are used to reshape the unpacked buffer.
-
-    Returns:
-        A numpy array of int8/uint8 reshaped to dims.
-    """
-    return _unpack_uint4_as_uint8(data, dims).view(ml_dtypes.uint4)
-
-
-def _extend_int4_sign_bits(x: npt.NDArray[np.uint8]) -> npt.NDArray[np.int8]:
-    """Extend 4-bit signed integer to 8-bit signed integer."""
-    return np.where((x >> 3) == 0, x, x | 0xF0).astype(np.int8)
-
-
-def unpack_int4(
-    data: npt.NDArray[np.uint8], dims: Sequence[int]
-) -> npt.NDArray[ml_dtypes.int4]:
-    """Convert a packed (signed) int4 array to unpacked int4 array represented as int8.
-
-    The sign bit is extended to the most significant bit of the int8.
-
-    Args:
-        data: A numpy array.
-        dims: The dimensions are used to reshape the unpacked buffer.
-
-    Returns:
-        A numpy array of int8 reshaped to dims.
-    """
-    unpacked = _unpack_uint4_as_uint8(data, dims)
-    return _extend_int4_sign_bits(unpacked).view(ml_dtypes.int4)
-
-
-def unpack_float4e2m1(
-    data: npt.NDArray[np.uint8], dims: Sequence[int]
-) -> npt.NDArray[ml_dtypes.float4_e2m1fn]:
-    """Convert a packed float4e2m1 array to unpacked float4e2m1 array.
-
-    Args:
-        data: A numpy array.
-        dims: The dimensions are used to reshape the unpacked buffer.
-
-    Returns:
-        A numpy array of float32 reshaped to dims.
-    """
-    return _unpack_uint4_as_uint8(data, dims).view(ml_dtypes.float4_e2m1fn)
onnx_ir-0.1.2/src/onnx_ir/passes/common/common_subexpression_elimination.py
@@ -1,177 +0,0 @@
-# Copyright (c) ONNX Project Contributors
-# SPDX-License-Identifier: Apache-2.0
-"""Eliminate common subexpression in ONNX graphs."""
-
-from __future__ import annotations
-
-__all__ = [
-    "CommonSubexpressionEliminationPass",
-]
-
-import logging
-from collections.abc import Sequence
-
-import onnx_ir as ir
-
-logger = logging.getLogger(__name__)
-
-
-class CommonSubexpressionEliminationPass(ir.passes.InPlacePass):
-    """Eliminate common subexpression in ONNX graphs."""
-
-    def call(self, model: ir.Model) -> ir.passes.PassResult:
-        """Return the same ir.Model but with CSE applied to the graph."""
-        modified = False
-        graph = model.graph
-
-        modified = _eliminate_common_subexpression(graph, modified)
-
-        return ir.passes.PassResult(
-            model,
-            modified=modified,
-        )
-
-
-def _eliminate_common_subexpression(graph: ir.Graph, modified: bool) -> bool:
-    """Eliminate common subexpression in ONNX graphs."""
-    # node to node identifier, length of outputs, inputs, and attributes
-    existing_node_info_to_the_node: dict[
-        tuple[
-            ir.OperatorIdentifier,
-            int,  # len(outputs)
-            tuple[int, ...],  # input ids
-            tuple[tuple[str, object], ...],  # attributes
-        ],
-        ir.Node,
-    ] = {}
-
-    for node in graph:
-        # Skip control flow ops like Loop and If.
-        control_flow_op: bool = False
-        # Use equality to check if the node is a common subexpression.
-        attributes = {}
-        for k, v in node.attributes.items():
-            # TODO(exporter team): CSE subgraphs.
-            # NOTE: control flow ops like Loop and If won't be CSEd
-            # because attribute: graph won't match.
-            if v.type in (ir.AttributeType.GRAPH, ir.AttributeType.GRAPHS):
-                control_flow_op = True
-                logger.debug("Skipping control flow op %s", node)
-            # The attribute value could be directly taken from the original
-            # protobuf, so we need to make a copy of it.
-            value = v.value
-            if v.type in (
-                ir.AttributeType.INTS,
-                ir.AttributeType.FLOATS,
-                ir.AttributeType.STRINGS,
-            ):
-                # For INT, FLOAT and STRING attributes, we convert them to tuples
-                # to ensure they are hashable.
-                value = tuple(value)
-            attributes[k] = value
-
-        if control_flow_op:
-            # If the node is a control flow op, we skip it.
-            logger.debug("Skipping control flow op %s", node)
-            continue
-
-        if _is_non_deterministic_op(node):
-            # If the node is a non-deterministic op, we skip it.
-            logger.debug("Skipping non-deterministic op %s", node)
-            continue
-
-        node_info = (
-            node.op_identifier(),
-            len(node.outputs),
-            tuple(id(input) for input in node.inputs),
-            tuple(sorted(attributes.items())),
-        )
-        # Check if the node is a common subexpression.
-        if node_info in existing_node_info_to_the_node:
-            # If it is, this node has an existing node with the same
-            # operator, number of outputs, inputs, and attributes.
-            # We replace the node with the existing node.
-            modified = True
-            existing_node = existing_node_info_to_the_node[node_info]
-            _remove_node_and_replace_values(
-                graph,
-                remove_node=node,
-                remove_values=node.outputs,
-                new_values=existing_node.outputs,
-            )
-            logger.debug("Reusing node %s", existing_node)
-        else:
-            # If it is not, add to the mapping.
-            existing_node_info_to_the_node[node_info] = node
-    return modified
-
-
-def _remove_node_and_replace_values(
-    graph: ir.Graph,
-    /,
-    remove_node: ir.Node,
-    remove_values: Sequence[ir.Value],
-    new_values: Sequence[ir.Value],
-) -> None:
-    """Replaces nodes and values in the graph or function.
-
-    Args:
-        graph: The graph to replace nodes and values in.
-        remove_node: The node to remove.
-        remove_values: The values to replace.
-        new_values: The values to replace with.
-    """
-    # Reconnect the users of the deleted values to use the new values
-    ir.convenience.replace_all_uses_with(remove_values, new_values)
-    # Update graph/function outputs if the node generates output
-    if any(remove_value.is_graph_output() for remove_value in remove_values):
-        replacement_mapping = dict(zip(remove_values, new_values))
-        for idx, graph_output in enumerate(graph.outputs):
-            if graph_output in replacement_mapping:
-                new_value = replacement_mapping[graph_output]
-                if new_value.is_graph_output() or new_value.is_graph_input():
-                    # If the new value is also a graph input/output, we need to
-                    # create a Identity node to preserve the remove_value and
-                    # prevent from changing new_value name.
-                    identity_node = ir.node(
-                        "Identity",
-                        inputs=[new_value],
-                        outputs=[
-                            ir.Value(
-                                name=graph_output.name,
-                                type=graph_output.type,
-                                shape=graph_output.shape,
-                            )
-                        ],
-                    )
-                    # reuse the name of the graph output
-                    graph.outputs[idx] = identity_node.outputs[0]
-                    graph.insert_before(
-                        remove_node,
-                        identity_node,
-                    )
-                else:
-                    # if new_value is not graph output, we just
-                    # update it to use old_value name.
-                    new_value.name = graph_output.name
-                    graph.outputs[idx] = new_value
-
-    graph.remove(remove_node, safe=True)
-
-
-def _is_non_deterministic_op(node: ir.Node) -> bool:
-    non_deterministic_ops = frozenset(
-        {
-            "RandomUniform",
-            "RandomNormal",
-            "RandomUniformLike",
-            "RandomNormalLike",
-            "Multinomial",
-        }
-    )
-    return node.op_type in non_deterministic_ops and _is_onnx_domain(node.domain)
-
-
-def _is_onnx_domain(d: str) -> bool:
-    """Check if the domain is the ONNX domain."""
-    return d == ""