onnx-ir 0.1.1__py3-none-any.whl → 0.1.3__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- onnx_ir/__init__.py +5 -2
- onnx_ir/_convenience/__init__.py +130 -4
- onnx_ir/_convenience/_constructors.py +6 -2
- onnx_ir/_core.py +283 -39
- onnx_ir/_enums.py +37 -25
- onnx_ir/_graph_containers.py +2 -2
- onnx_ir/_io.py +40 -4
- onnx_ir/_type_casting.py +2 -1
- onnx_ir/_version_utils.py +5 -48
- onnx_ir/convenience.py +3 -1
- onnx_ir/external_data.py +43 -3
- onnx_ir/passes/_pass_infra.py +1 -1
- onnx_ir/passes/common/__init__.py +4 -0
- onnx_ir/passes/common/_c_api_utils.py +1 -1
- onnx_ir/passes/common/common_subexpression_elimination.py +104 -75
- onnx_ir/passes/common/initializer_deduplication.py +56 -0
- onnx_ir/passes/common/onnx_checker.py +1 -1
- onnx_ir/passes/common/shape_inference.py +1 -1
- onnx_ir/passes/common/unused_removal.py +1 -1
- onnx_ir/serde.py +176 -6
- onnx_ir/tensor_adapters.py +62 -7
- {onnx_ir-0.1.1.dist-info → onnx_ir-0.1.3.dist-info}/METADATA +22 -4
- onnx_ir-0.1.3.dist-info/RECORD +43 -0
- onnx_ir-0.1.1.dist-info/RECORD +0 -42
- {onnx_ir-0.1.1.dist-info → onnx_ir-0.1.3.dist-info}/WHEEL +0 -0
- {onnx_ir-0.1.1.dist-info → onnx_ir-0.1.3.dist-info}/licenses/LICENSE +0 -0
- {onnx_ir-0.1.1.dist-info → onnx_ir-0.1.3.dist-info}/top_level.txt +0 -0
onnx_ir/_type_casting.py
CHANGED
```diff
@@ -15,7 +15,7 @@ if typing.TYPE_CHECKING:
     import numpy.typing as npt
 
 
-def …
+def pack_4bitx2(array: np.ndarray) -> npt.NDArray[np.uint8]:
     """Convert a numpy array to flatten, packed int4/uint4. Elements must be in the correct range."""
     # Create a 1D copy
     array_flat = array.ravel().view(np.uint8).copy()
@@ -40,6 +40,7 @@ def _unpack_uint4_as_uint8(
     Returns:
         A numpy array of int8/uint8 reshaped to dims.
     """
+    assert data.dtype == np.uint8, "Input data must be of type uint8"
     result = np.empty([data.size * 2], dtype=data.dtype)
     array_low = data & np.uint8(0x0F)
     array_high = data & np.uint8(0xF0)
```
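The packing layout is two elements per byte, low nibble first, as the `0x0F`/`0xF0` masks above suggest. A minimal numpy sketch of that round trip (the helper names and the pad-to-even step are illustrative, not the library's code):

```python
import numpy as np

def pack_nibbles(array: np.ndarray) -> np.ndarray:
    """Pack pairs of 4-bit values into single bytes, low nibble first."""
    flat = array.ravel().astype(np.uint8)
    if flat.size % 2 == 1:
        # Odd-length input: pad with a zero nibble (an assumption for illustration).
        flat = np.append(flat, np.uint8(0))
    return (flat[0::2] & 0x0F) | ((flat[1::2] & 0x0F) << 4)

def unpack_nibbles(packed: np.ndarray) -> np.ndarray:
    """Inverse: expand each byte back into two uint4 values."""
    low = packed & np.uint8(0x0F)          # same masks as in the diff above
    high = (packed & np.uint8(0xF0)) >> 4
    return np.stack([low, high], axis=-1).ravel()

packed = pack_nibbles(np.array([1, 2, 3, 4], dtype=np.uint8))
assert packed.tolist() == [0x21, 0x43]
assert unpack_nibbles(packed).tolist() == [1, 2, 3, 4]
```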
onnx_ir/_version_utils.py
CHANGED
```diff
@@ -2,6 +2,7 @@
 # SPDX-License-Identifier: Apache-2.0
 """Version utils for testing."""
 
+# pylint: disable=import-outside-toplevel
 from __future__ import annotations
 
 import packaging.version
@@ -9,7 +10,7 @@ import packaging.version
 
 def onnx_older_than(version: str) -> bool:
     """Returns True if the ONNX version is older than the given version."""
-    import onnx  #
+    import onnx  # noqa: TID251
 
     return (
         packaging.version.parse(onnx.__version__).release
@@ -19,7 +20,7 @@ def onnx_older_than(version: str) -> bool:
 
 def torch_older_than(version: str) -> bool:
     """Returns True if the torch version is older than the given version."""
-    import torch
+    import torch
 
     return (
         packaging.version.parse(torch.__version__).release
@@ -27,42 +28,9 @@ def torch_older_than(version: str) -> bool:
     )
 
 
-def transformers_older_than(version: str) -> bool | None:
-    """Returns True if the transformers version is older than the given version."""
-    try:
-        import transformers  # pylint: disable=import-outside-toplevel
-    except ImportError:
-        return None
-
-    return (
-        packaging.version.parse(transformers.__version__).release
-        < packaging.version.parse(version).release
-    )
-
-
-def is_onnxruntime_training() -> bool:
-    """Returns True if the onnxruntime is onnxruntime-training."""
-    try:
-        from onnxruntime import training  # pylint: disable=import-outside-toplevel
-
-        assert training
-    except ImportError:
-        # onnxruntime not training
-        return False
-
-    try:
-        from onnxruntime.capi.onnxruntime_pybind11_state import (  # pylint: disable=import-outside-toplevel
-            OrtValueVector,
-        )
-    except ImportError:
-        return False
-
-    return hasattr(OrtValueVector, "push_back_batch")
-
-
 def onnxruntime_older_than(version: str) -> bool:
     """Returns True if the onnxruntime version is older than the given version."""
-    import onnxruntime
+    import onnxruntime
 
     return (
         packaging.version.parse(onnxruntime.__version__).release
@@ -72,20 +40,9 @@ def onnxruntime_older_than(version: str) -> bool:
 
 def numpy_older_than(version: str) -> bool:
     """Returns True if the numpy version is older than the given version."""
-    import numpy
+    import numpy
 
     return (
         packaging.version.parse(numpy.__version__).release
         < packaging.version.parse(version).release
     )
-
-
-def has_transformers():
-    """Tells if transformers is installed."""
-    try:
-        import transformers  # pylint: disable=import-outside-toplevel
-
-        assert transformers
-        return True  # noqa
-    except ImportError:
-        return False
```
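All of these helpers compare only the release tuple of a version, so pre-release suffixes such as `rc1` are ignored. A small sketch of that comparison (the `older_than` helper here is hypothetical, written only to illustrate the pattern the remaining functions share):

```python
import packaging.version

def older_than(installed: str, required: str) -> bool:
    """The comparison used by the helpers above: release tuples only."""
    return (
        packaging.version.parse(installed).release
        < packaging.version.parse(required).release
    )

assert older_than("1.16.0", "1.17")         # (1, 16, 0) < (1, 17)
assert not older_than("1.17.0rc1", "1.17")  # the rc1 suffix is ignored
```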
onnx_ir/convenience.py
CHANGED
```diff
@@ -7,15 +7,17 @@ from __future__ import annotations
 __all__ = [
     "convert_attribute",
     "convert_attributes",
+    "create_value_mapping",
+    "get_const_tensor",
     "replace_all_uses_with",
     "replace_nodes_and_values",
-    "create_value_mapping",
 ]
 
 from onnx_ir._convenience import (
     convert_attribute,
     convert_attributes,
     create_value_mapping,
+    get_const_tensor,
     replace_all_uses_with,
     replace_nodes_and_values,
 )
```
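`get_const_tensor` is newly re-exported here. A hedged usage sketch, assuming it resolves a value to its constant tensor and returns `None` when the value is not statically known:

```python
import onnx_ir as ir

def const_as_numpy(value: ir.Value):
    """Return a value's constant data as a numpy array, or None if unknown."""
    tensor = ir.convenience.get_const_tensor(value)  # assumed to return None on failure
    return None if tensor is None else tensor.numpy()
```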
onnx_ir/external_data.py
CHANGED
```diff
@@ -4,12 +4,15 @@
 
 from __future__ import annotations
 
+from typing import Callable
+
 __all__ = [
     "set_base_dir",
     "unload_from_model",
     "load_to_model",
     "convert_tensors_to_external",
     "convert_tensors_from_external",
+    "CallbackInfo",
 ]
 
 import dataclasses
@@ -48,6 +51,21 @@ class _ExternalDataInfo:
     length: int
 
 
+@dataclasses.dataclass
+class CallbackInfo:
+    """A class that shares information about a tensor that is to be saved as external data for callback functions.
+
+    Attributes:
+        total: The total number of tensors to save.
+        index: The index of the tensor being saved.
+        offset: The offset of the tensor in the external data file.
+    """
+
+    total: int
+    index: int
+    offset: int
+
+
 def _all_tensors(
     graph: _core.Graph | _core.GraphView, include_attributes: bool = False
 ) -> Iterator[_protocols.TensorProtocol]:
@@ -157,6 +175,7 @@ def _write_external_data(
     tensors: Sequence[_protocols.TensorProtocol],
     external_data_infos: Sequence[_ExternalDataInfo],
     file_path: str | os.PathLike,
+    callback: Callable[[_protocols.TensorProtocol, CallbackInfo], None] | None = None,
 ) -> None:
     """Write tensor data to an external file according to information stored in ExternalDataInfo objects.
 
@@ -164,12 +183,26 @@ def _write_external_data(
         tensors: Tensors to be written as external data.
         external_data_infos: External data information stored for each tensor to be written as external data.
         file_path: Location to which external data is to be stored.
+        callback: A callback function that is called for each tensor that is saved to external data
+            for debugging or logging purposes.
     """
-…
+    tensors_count = len(tensors)
+    assert tensors_count == len(external_data_infos), (
         "Number of tensors and external data infos should match"
     )
     with open(file_path, "wb") as data_file:
-        for tensor, tensor_info in
+        for i, (tensor, tensor_info) in enumerate(
+            zip(tensors, external_data_infos, strict=True)
+        ):
+            if callback is not None:
+                callback(
+                    tensor,
+                    CallbackInfo(
+                        total=tensors_count,
+                        index=i,
+                        offset=tensor_info.offset,
+                    ),
+                )
             current_offset = tensor_info.offset
             assert tensor is not None
             raw_data = tensor.tobytes()
@@ -228,6 +261,7 @@ def convert_tensors_to_external(
     tensors: Sequence[_protocols.TensorProtocol],
     base_dir: str | os.PathLike,
     relative_path: str | os.PathLike,
+    callback: Callable[[_protocols.TensorProtocol, CallbackInfo], None] | None = None,
 ) -> list[_core.ExternalTensor]:
     """Convert a sequence of any TensorProtocol tensors to external tensors.
 
@@ -238,6 +272,8 @@ def convert_tensors_to_external(
         tensors: Tensors to be converted to external tensors. They can be external tensors themselves.
         base_dir: Path of base directory.
         relative_path: Path to which external data is to be stored, relative to the ONNX file.
+        callback: A callback function that is called for each tensor that is saved to external data
+            for debugging or logging purposes.
 
     Returns:
         A list of external tensors derived from a list of input tensors. The order
@@ -285,7 +321,7 @@ def convert_tensors_to_external(
         external_info = _compute_external_data_info(tensor, current_offset)
         external_data_infos.append(external_info)
         current_offset = external_info.offset + external_info.length
-    _write_external_data(sorted_tensors, external_data_infos, path)
+    _write_external_data(sorted_tensors, external_data_infos, path, callback=callback)
 
     # Create external tensor objects
     external_tensors: list[_core.ExternalTensor] = [
@@ -336,6 +372,7 @@ def unload_from_model(
     relative_path: str | os.PathLike,
     *,
     size_threshold_bytes: int = 0,
+    callback: Callable[[_protocols.TensorProtocol, CallbackInfo], None] | None = None,
 ) -> _core.Model:
     """Convert all initializers equal or above size_threshold_bytes to external tensors in-place and save data to a single data file.
 
@@ -356,6 +393,8 @@ def unload_from_model(
         relative_path: Path to which external data is to be stored, relative to the ONNX file.
             E.g. "model.data"
         size_threshold_bytes: Save to external data if the tensor size in bytes is larger than this threshold.
+        callback: A callback function that is called for each tensor that is saved to external data
+            for debugging or logging purposes.
 
     Returns:
         An ir.Model with all initializer data equal or above ``size_threshold_bytes``
@@ -384,6 +423,7 @@ def unload_from_model(
         [v.const_value for v in initializers_to_become_external],  # type: ignore[misc]
         base_dir=base_dir,
         relative_path=relative_path,
+        callback=callback,
     )
 
     # Replace the initializer values with external tensors and save the model
```
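The new `callback` hook receives the tensor and a `CallbackInfo(total, index, offset)` record just before each tensor's bytes are written, which makes progress reporting straightforward. A sketch of wiring it into `unload_from_model` (file names are placeholders; the argument layout follows the hunks above):

```python
import onnx_ir as ir
from onnx_ir import external_data

def report_progress(tensor, info: external_data.CallbackInfo) -> None:
    # Invoked once per tensor, before its bytes are written.
    print(f"[{info.index + 1}/{info.total}] {tensor.name} at offset {info.offset}")

model = ir.load("model.onnx")  # placeholder input path
model = external_data.unload_from_model(
    model,
    base_dir=".",
    relative_path="model.data",
    size_threshold_bytes=1024,
    callback=report_progress,
)
ir.save(model, "model_external.onnx")  # placeholder output path
```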
onnx_ir/passes/common/__init__.py
CHANGED
```diff
@@ -6,6 +6,7 @@ __all__ = [
     "CheckerPass",
     "ClearMetadataAndDocStringPass",
     "CommonSubexpressionEliminationPass",
+    "DeduplicateInitializersPass",
     "InlinePass",
     "LiftConstantsToInitializersPass",
     "LiftSubgraphInitializersToMainGraphPass",
@@ -29,6 +30,9 @@ from onnx_ir.passes.common.constant_manipulation import (
     LiftSubgraphInitializersToMainGraphPass,
     RemoveInitializersFromInputsPass,
 )
+from onnx_ir.passes.common.initializer_deduplication import (
+    DeduplicateInitializersPass,
+)
 from onnx_ir.passes.common.inliner import InlinePass
 from onnx_ir.passes.common.onnx_checker import CheckerPass
 from onnx_ir.passes.common.shape_inference import ShapeInferencePass
```
onnx_ir/passes/common/common_subexpression_elimination.py
CHANGED
```diff
@@ -17,93 +17,122 @@ logger = logging.getLogger(__name__)
 
 
 class CommonSubexpressionEliminationPass(ir.passes.InPlacePass):
-    """Eliminate common subexpression in ONNX graphs.
+    """Eliminate common subexpression in ONNX graphs.
+
+    .. versionadded:: 0.1.1
+
+    .. versionchanged:: 0.1.3
+        Constant nodes with values smaller than ``size_limit`` will be CSE'd.
+
+    Attributes:
+        size_limit: The maximum size of the tensor to be csed. If the tensor contains
+            number of elements larger than size_limit, it will not be cse'd. Default is 10.
+
+    """
+
+    def __init__(self, size_limit: int = 10):
+        """Initialize the CommonSubexpressionEliminationPass."""
+        super().__init__()
+        self.size_limit = size_limit
 
     def call(self, model: ir.Model) -> ir.passes.PassResult:
         """Return the same ir.Model but with CSE applied to the graph."""
-        modified = False
         graph = model.graph
-
-        modified = _eliminate_common_subexpression(graph, modified)
+        modified = self._eliminate_common_subexpression(graph)
 
         return ir.passes.PassResult(
             model,
             modified=modified,
         )
 
-…
+    def _eliminate_common_subexpression(self, graph: ir.Graph) -> bool:
+        """Eliminate common subexpression in ONNX graphs."""
+        modified: bool = False
+        # node to node identifier, length of outputs, inputs, and attributes
+        existing_node_info_to_the_node: dict[
+            tuple[
+                ir.OperatorIdentifier,
+                int,  # len(outputs)
+                tuple[int, ...],  # input ids
+                tuple[tuple[str, object], ...],  # attributes
+            ],
+            ir.Node,
+        ] = {}
+
+        for node in graph:
+            # Skip control flow ops like Loop and If.
+            control_flow_op: bool = False
+            # Skip large tensors to avoid cse weights and bias.
+            large_tensor: bool = False
+            # Use equality to check if the node is a common subexpression.
+            attributes = {}
+            for k, v in node.attributes.items():
+                # TODO(exporter team): CSE subgraphs.
+                # NOTE: control flow ops like Loop and If won't be CSEd
+                # because attribute: graph won't match.
+                if v.type in (ir.AttributeType.GRAPH, ir.AttributeType.GRAPHS):
+                    control_flow_op = True
+                    break
+                # The attribute value could be directly taken from the original
+                # protobuf, so we need to make a copy of it.
+                value = v.value
+                if v.type in (
+                    ir.AttributeType.INTS,
+                    ir.AttributeType.FLOATS,
+                    ir.AttributeType.STRINGS,
+                ):
+                    # For INT, FLOAT and STRING attributes, we convert them to tuples
+                    # to ensure they are hashable.
+                    value = tuple(value)
+                elif v.type is ir.AttributeType.TENSOR:
+                    if value.size > self.size_limit:
+                        # If the tensor is larger than the size limit, we skip it.
+                        large_tensor = True
+                        break
+                    np_value = value.numpy()
+
+                    value = (np_value.shape, str(np_value.dtype), np_value.tobytes())
+                attributes[k] = value
+
+            if control_flow_op:
+                # If the node is a control flow op, we skip it.
                 logger.debug("Skipping control flow op %s", node)
-…
-            ):
-                # …
-…
-            if _is_non_deterministic_op(node):
-                # If the node is a non-deterministic op, we skip it.
-                logger.debug("Skipping non-deterministic op %s", node)
-                continue
-
-            node_info = (
-                node.op_identifier(),
-                len(node.outputs),
-                tuple(id(input) for input in node.inputs),
-                tuple(sorted(attributes.items())),
-            )
-            # Check if the node is a common subexpression.
-            if node_info in existing_node_info_to_the_node:
-                # If it is, this node has an existing node with the same
-                # operator, number of outputs, inputs, and attributes.
-                # We replace the node with the existing node.
-                modified = True
-                existing_node = existing_node_info_to_the_node[node_info]
-                _remove_node_and_replace_values(
-                    graph,
-                    remove_node=node,
-                    remove_values=node.outputs,
-                    new_values=existing_node.outputs,
+                continue
+
+            if large_tensor:
+                # If the node has a large tensor, we skip it.
+                logger.debug("Skipping large tensor in node %s", node)
+                continue
+
+            if _is_non_deterministic_op(node):
+                # If the node is a non-deterministic op, we skip it.
+                logger.debug("Skipping non-deterministic op %s", node)
+                continue
+
+            node_info = (
+                node.op_identifier(),
+                len(node.outputs),
+                tuple(id(input) for input in node.inputs),
+                tuple(sorted(attributes.items())),
             )
-…
+            # Check if the node is a common subexpression.
+            if node_info in existing_node_info_to_the_node:
+                # If it is, this node has an existing node with the same
+                # operator, number of outputs, inputs, and attributes.
+                # We replace the node with the existing node.
+                modified = True
+                existing_node = existing_node_info_to_the_node[node_info]
+                _remove_node_and_replace_values(
+                    graph,
+                    remove_node=node,
+                    remove_values=node.outputs,
+                    new_values=existing_node.outputs,
+                )
+                logger.debug("Reusing node %s", existing_node)
+            else:
+                # If it is not, add to the mapping.
+                existing_node_info_to_the_node[node_info] = node
+        return modified
 
 
 def _remove_node_and_replace_values(
```
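With the new `size_limit` knob, constant-carrying nodes are merged only when the tensor holds at most `size_limit` elements (default 10), so weights and biases stay untouched. A usage sketch (placeholder model path; the import follows the `passes/common/__init__.py` diff above, and pass instances are assumed callable per the pass infra):

```python
import onnx_ir as ir
from onnx_ir.passes.common import CommonSubexpressionEliminationPass

model = ir.load("model.onnx")  # placeholder path
result = CommonSubexpressionEliminationPass(size_limit=10)(model)
print("CSE modified the graph:", result.modified)
```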
onnx_ir/passes/common/initializer_deduplication.py
ADDED
```diff
@@ -0,0 +1,56 @@
+# Copyright (c) ONNX Project Contributors
+# SPDX-License-Identifier: Apache-2.0
+"""Pass for removing duplicated initializer tensors from a graph."""
+
+from __future__ import annotations
+
+__all__ = [
+    "DeduplicateInitializersPass",
+]
+
+
+import onnx_ir as ir
+
+
+class DeduplicateInitializersPass(ir.passes.InPlacePass):
+    """Remove duplicated initializer tensors from the graph.
+
+    This pass detects initializers with identical shape, dtype, and content,
+    and replaces all duplicate references with a canonical one.
+
+    To deduplicate initializers from subgraphs, use :class:`~onnx_ir.passes.common.LiftSubgraphInitializersToMainGraphPass`
+    to lift the initializers to the main graph first before running pass.
+
+    .. versionadded:: 0.1.3
+    """
+
+    def __init__(self, size_limit: int = 1024):
+        super().__init__()
+        self.size_limit = size_limit
+
+    def call(self, model: ir.Model) -> ir.passes.PassResult:
+        graph = model.graph
+        initializers: dict[tuple[ir.DataType, tuple[int, ...], bytes], ir.Value] = {}
+        modified = False
+
+        for initializer in tuple(graph.initializers.values()):
+            # TODO(justinchuby): Handle subgraphs as well. For now users can lift initializers
+            # out from the main graph before running this pass.
+            const_val = initializer.const_value
+            if const_val is None:
+                # Skip if initializer has no constant value
+                continue
+
+            if const_val.size > self.size_limit:
+                continue
+
+            key = (const_val.dtype, tuple(const_val.shape), const_val.tobytes())
+            if key in initializers:
+                modified = True
+                ir.convenience.replace_all_uses_with(initializer, initializers[key])  # type: ignore[index]
+                assert initializer.name is not None
+                graph.initializers.pop(initializer.name)
+            else:
+                initializers[key] = initializer  # type: ignore[index]
+
+        return ir.passes.PassResult(model=model, modified=modified)
```
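Deduplication keys on `(dtype, shape, tobytes())`, so two initializers collapse only when their raw bytes match exactly, and `size_limit` (default 1024 elements) bounds the hashing cost. A usage sketch under the same assumptions as the CSE example above:

```python
import onnx_ir as ir
from onnx_ir.passes.common import DeduplicateInitializersPass

model = ir.load("model.onnx")  # placeholder path
result = DeduplicateInitializersPass(size_limit=1024)(model)
if result.modified:
    ir.save(result.model, "model_deduped.onnx")  # placeholder output path
```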