onnx-ir 0.1.2__tar.gz → 0.1.3__tar.gz
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release: this version of onnx-ir has been flagged as possibly problematic.
- {onnx_ir-0.1.2/src/onnx_ir.egg-info → onnx_ir-0.1.3}/PKG-INFO +1 -1
- {onnx_ir-0.1.2 → onnx_ir-0.1.3}/src/onnx_ir/__init__.py +1 -1
- {onnx_ir-0.1.2 → onnx_ir-0.1.3}/src/onnx_ir/_convenience/__init__.py +5 -0
- {onnx_ir-0.1.2 → onnx_ir-0.1.3}/src/onnx_ir/_core.py +25 -3
- {onnx_ir-0.1.2 → onnx_ir-0.1.3}/src/onnx_ir/_enums.py +2 -0
- {onnx_ir-0.1.2 → onnx_ir-0.1.3}/src/onnx_ir/passes/common/__init__.py +4 -0
- onnx_ir-0.1.3/src/onnx_ir/passes/common/common_subexpression_elimination.py +206 -0
- onnx_ir-0.1.3/src/onnx_ir/passes/common/initializer_deduplication.py +56 -0
- {onnx_ir-0.1.2 → onnx_ir-0.1.3}/src/onnx_ir/serde.py +5 -0
- {onnx_ir-0.1.2 → onnx_ir-0.1.3}/src/onnx_ir/tensor_adapters.py +62 -7
- {onnx_ir-0.1.2 → onnx_ir-0.1.3/src/onnx_ir.egg-info}/PKG-INFO +1 -1
- {onnx_ir-0.1.2 → onnx_ir-0.1.3}/src/onnx_ir.egg-info/SOURCES.txt +1 -0
- onnx_ir-0.1.2/src/onnx_ir/passes/common/common_subexpression_elimination.py +0 -177
- {onnx_ir-0.1.2 → onnx_ir-0.1.3}/LICENSE +0 -0
- {onnx_ir-0.1.2 → onnx_ir-0.1.3}/MANIFEST.in +0 -0
- {onnx_ir-0.1.2 → onnx_ir-0.1.3}/README.md +0 -0
- {onnx_ir-0.1.2 → onnx_ir-0.1.3}/pyproject.toml +0 -0
- {onnx_ir-0.1.2 → onnx_ir-0.1.3}/setup.cfg +0 -0
- {onnx_ir-0.1.2 → onnx_ir-0.1.3}/src/onnx_ir/_convenience/_constructors.py +0 -0
- {onnx_ir-0.1.2 → onnx_ir-0.1.3}/src/onnx_ir/_display.py +0 -0
- {onnx_ir-0.1.2 → onnx_ir-0.1.3}/src/onnx_ir/_graph_comparison.py +0 -0
- {onnx_ir-0.1.2 → onnx_ir-0.1.3}/src/onnx_ir/_graph_containers.py +0 -0
- {onnx_ir-0.1.2 → onnx_ir-0.1.3}/src/onnx_ir/_io.py +0 -0
- {onnx_ir-0.1.2 → onnx_ir-0.1.3}/src/onnx_ir/_linked_list.py +0 -0
- {onnx_ir-0.1.2 → onnx_ir-0.1.3}/src/onnx_ir/_metadata.py +0 -0
- {onnx_ir-0.1.2 → onnx_ir-0.1.3}/src/onnx_ir/_name_authority.py +0 -0
- {onnx_ir-0.1.2 → onnx_ir-0.1.3}/src/onnx_ir/_polyfill.py +0 -0
- {onnx_ir-0.1.2 → onnx_ir-0.1.3}/src/onnx_ir/_protocols.py +0 -0
- {onnx_ir-0.1.2 → onnx_ir-0.1.3}/src/onnx_ir/_tape.py +0 -0
- {onnx_ir-0.1.2 → onnx_ir-0.1.3}/src/onnx_ir/_thirdparty/asciichartpy.py +0 -0
- {onnx_ir-0.1.2 → onnx_ir-0.1.3}/src/onnx_ir/_type_casting.py +0 -0
- {onnx_ir-0.1.2 → onnx_ir-0.1.3}/src/onnx_ir/_version_utils.py +0 -0
- {onnx_ir-0.1.2 → onnx_ir-0.1.3}/src/onnx_ir/convenience.py +0 -0
- {onnx_ir-0.1.2 → onnx_ir-0.1.3}/src/onnx_ir/external_data.py +0 -0
- {onnx_ir-0.1.2 → onnx_ir-0.1.3}/src/onnx_ir/passes/__init__.py +0 -0
- {onnx_ir-0.1.2 → onnx_ir-0.1.3}/src/onnx_ir/passes/_pass_infra.py +0 -0
- {onnx_ir-0.1.2 → onnx_ir-0.1.3}/src/onnx_ir/passes/common/_c_api_utils.py +0 -0
- {onnx_ir-0.1.2 → onnx_ir-0.1.3}/src/onnx_ir/passes/common/clear_metadata_and_docstring.py +0 -0
- {onnx_ir-0.1.2 → onnx_ir-0.1.3}/src/onnx_ir/passes/common/constant_manipulation.py +0 -0
- {onnx_ir-0.1.2 → onnx_ir-0.1.3}/src/onnx_ir/passes/common/inliner.py +0 -0
- {onnx_ir-0.1.2 → onnx_ir-0.1.3}/src/onnx_ir/passes/common/onnx_checker.py +0 -0
- {onnx_ir-0.1.2 → onnx_ir-0.1.3}/src/onnx_ir/passes/common/shape_inference.py +0 -0
- {onnx_ir-0.1.2 → onnx_ir-0.1.3}/src/onnx_ir/passes/common/topological_sort.py +0 -0
- {onnx_ir-0.1.2 → onnx_ir-0.1.3}/src/onnx_ir/passes/common/unused_removal.py +0 -0
- {onnx_ir-0.1.2 → onnx_ir-0.1.3}/src/onnx_ir/tape.py +0 -0
- {onnx_ir-0.1.2 → onnx_ir-0.1.3}/src/onnx_ir/testing.py +0 -0
- {onnx_ir-0.1.2 → onnx_ir-0.1.3}/src/onnx_ir/traversal.py +0 -0
- {onnx_ir-0.1.2 → onnx_ir-0.1.3}/src/onnx_ir.egg-info/dependency_links.txt +0 -0
- {onnx_ir-0.1.2 → onnx_ir-0.1.3}/src/onnx_ir.egg-info/requires.txt +0 -0
- {onnx_ir-0.1.2 → onnx_ir-0.1.3}/src/onnx_ir.egg-info/top_level.txt +0 -0
src/onnx_ir/_convenience/__init__.py

@@ -323,6 +323,9 @@ def create_value_mapping(graph: _core.Graph) -> dict[str, _core.Value]:
     and the first value with that name is returned. Values with empty names
     are excluded from the mapping.
 
+    .. versionchanged:: 0.1.2
+        Values from subgraphs are now included in the mapping.
+
     Args:
         graph: The graph to extract the mapping from.
 
@@ -410,6 +413,8 @@ def get_const_tensor(
     it will propagate the shape and type of the constant tensor to the value
     if `propagate_shape_type` is set to True.
 
+    .. versionadded:: 0.1.2
+
    Args:
        value: The value to get the constant tensor from.
        propagate_shape_type: If True, the shape and type of the value will be

src/onnx_ir/_core.py

@@ -417,6 +417,9 @@ class Tensor(TensorBase, _protocols.TensorProtocol, Generic[TArrayCompatible]):
         else:
             self._shape = shape
             self._shape.freeze()
+        if isinstance(value, np.generic):
+            # Turn numpy scalar into a numpy array
+            value = np.array(value)  # type: ignore[assignment]
         if dtype is None:
             if isinstance(value, np.ndarray):
                 self._dtype = _enums.DataType.from_numpy(value.dtype)
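
For context, a minimal sketch of what the new ``np.generic`` handling enables (not part of the diff; assumes the package is importable as ``onnx_ir``)::

    import numpy as np
    import onnx_ir as ir

    # A numpy scalar (np.generic) is now converted to a 0-D array internally,
    # so a Tensor can be constructed directly from a scalar.
    t = ir.Tensor(np.float32(1.0))
    print(t.dtype)          # DataType.FLOAT
    print(t.numpy().shape)  # prints (), a 0-D array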
@@ -964,7 +967,10 @@ class LazyTensor(TensorBase, _protocols.TensorProtocol): # pylint: disable=too-
 
 
 class PackedTensor(TensorBase, _protocols.TensorProtocol, Generic[TArrayCompatible]):  # pylint: disable=too-many-ancestors
-    """A tensor that stores 4bit datatypes in packed format."""
+    """A tensor that stores 4bit datatypes in packed format.
+
+    .. versionadded:: 0.1.2
+    """
 
     __slots__ = (
         "_dtype",
@@ -2335,6 +2341,12 @@ class Graph(_protocols.GraphProtocol, Sequence[Node], _display.PrettyPrintable):
     seen as a Sequence of nodes and should be used as such. For example, to obtain
     all nodes as a list, call ``list(graph)``.
 
+    .. versionchanged:: 0.1.1
+        Values with non-none producers will be rejected as graph inputs or initializers.
+
+    .. versionadded:: 0.1.1
+        Added ``add`` method to initializers and attributes.
+
     Attributes:
         name: The name of the graph.
         inputs: The input values of the graph.
@@ -2545,12 +2557,17 @@ class Graph(_protocols.GraphProtocol, Sequence[Node], _display.PrettyPrintable):
         Consider using
         :class:`onnx_ir.traversal.RecursiveGraphIterator` for more advanced
         traversals on nodes.
+
+        .. versionadded:: 0.1.2
         """
         # NOTE: This is a method specific to Graph, not required by the protocol unless proven
         return onnx_ir.traversal.RecursiveGraphIterator(self)
 
     def subgraphs(self) -> Iterator[Graph]:
-        """Get all subgraphs in the graph in O(#nodes + #attributes) time."""
+        """Get all subgraphs in the graph in O(#nodes + #attributes) time.
+
+        .. versionadded:: 0.1.2
+        """
         seen_graphs: set[Graph] = set()
         for node in onnx_ir.traversal.RecursiveGraphIterator(self):
             graph = node.graph
@@ -3216,12 +3233,17 @@ class Function(_protocols.FunctionProtocol, Sequence[Node], _display.PrettyPrint
         Consider using
         :class:`onnx_ir.traversal.RecursiveGraphIterator` for more advanced
         traversals on nodes.
+
+        .. versionadded:: 0.1.2
         """
         # NOTE: This is a method specific to Graph, not required by the protocol unless proven
         return onnx_ir.traversal.RecursiveGraphIterator(self)
 
     def subgraphs(self) -> Iterator[Graph]:
-        """Get all subgraphs in the function in O(#nodes + #attributes) time."""
+        """Get all subgraphs in the function in O(#nodes + #attributes) time.
+
+        .. versionadded:: 0.1.2
+        """
         seen_graphs: set[Graph] = set()
         for node in onnx_ir.traversal.RecursiveGraphIterator(self):
             graph = node.graph
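
``Graph.subgraphs()`` and ``Function.subgraphs()`` are documented above as new in 0.1.2; a usage sketch (the model path is hypothetical)::

    import onnx_ir as ir

    model = ir.load("model_with_control_flow.onnx")  # hypothetical path
    for subgraph in model.graph.subgraphs():
        # Yields each subgraph (e.g. an If branch or Loop body) exactly once.
        print(subgraph.name, "nodes:", len(subgraph))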

src/onnx_ir/passes/common/__init__.py

@@ -6,6 +6,7 @@ __all__ = [
     "CheckerPass",
     "ClearMetadataAndDocStringPass",
     "CommonSubexpressionEliminationPass",
+    "DeduplicateInitializersPass",
     "InlinePass",
     "LiftConstantsToInitializersPass",
     "LiftSubgraphInitializersToMainGraphPass",
@@ -29,6 +30,9 @@ from onnx_ir.passes.common.constant_manipulation import (
     LiftSubgraphInitializersToMainGraphPass,
     RemoveInitializersFromInputsPass,
 )
+from onnx_ir.passes.common.initializer_deduplication import (
+    DeduplicateInitializersPass,
+)
 from onnx_ir.passes.common.inliner import InlinePass
 from onnx_ir.passes.common.onnx_checker import CheckerPass
 from onnx_ir.passes.common.shape_inference import ShapeInferencePass

src/onnx_ir/passes/common/common_subexpression_elimination.py (new file in 0.1.3)

@@ -0,0 +1,206 @@
+# Copyright (c) ONNX Project Contributors
+# SPDX-License-Identifier: Apache-2.0
+"""Eliminate common subexpression in ONNX graphs."""
+
+from __future__ import annotations
+
+__all__ = [
+    "CommonSubexpressionEliminationPass",
+]
+
+import logging
+from collections.abc import Sequence
+
+import onnx_ir as ir
+
+logger = logging.getLogger(__name__)
+
+
+class CommonSubexpressionEliminationPass(ir.passes.InPlacePass):
+    """Eliminate common subexpression in ONNX graphs.
+
+    .. versionadded:: 0.1.1
+
+    .. versionchanged:: 0.1.3
+        Constant nodes with values smaller than ``size_limit`` will be CSE'd.
+
+    Attributes:
+        size_limit: The maximum size of the tensor to be csed. If the tensor contains
+            number of elements larger than size_limit, it will not be cse'd. Default is 10.
+
+    """
+
+    def __init__(self, size_limit: int = 10):
+        """Initialize the CommonSubexpressionEliminationPass."""
+        super().__init__()
+        self.size_limit = size_limit
+
+    def call(self, model: ir.Model) -> ir.passes.PassResult:
+        """Return the same ir.Model but with CSE applied to the graph."""
+        graph = model.graph
+        modified = self._eliminate_common_subexpression(graph)
+
+        return ir.passes.PassResult(
+            model,
+            modified=modified,
+        )
+
+    def _eliminate_common_subexpression(self, graph: ir.Graph) -> bool:
+        """Eliminate common subexpression in ONNX graphs."""
+        modified: bool = False
+        # node to node identifier, length of outputs, inputs, and attributes
+        existing_node_info_to_the_node: dict[
+            tuple[
+                ir.OperatorIdentifier,
+                int,  # len(outputs)
+                tuple[int, ...],  # input ids
+                tuple[tuple[str, object], ...],  # attributes
+            ],
+            ir.Node,
+        ] = {}
+
+        for node in graph:
+            # Skip control flow ops like Loop and If.
+            control_flow_op: bool = False
+            # Skip large tensors to avoid cse weights and bias.
+            large_tensor: bool = False
+            # Use equality to check if the node is a common subexpression.
+            attributes = {}
+            for k, v in node.attributes.items():
+                # TODO(exporter team): CSE subgraphs.
+                # NOTE: control flow ops like Loop and If won't be CSEd
+                # because attribute: graph won't match.
+                if v.type in (ir.AttributeType.GRAPH, ir.AttributeType.GRAPHS):
+                    control_flow_op = True
+                    break
+                # The attribute value could be directly taken from the original
+                # protobuf, so we need to make a copy of it.
+                value = v.value
+                if v.type in (
+                    ir.AttributeType.INTS,
+                    ir.AttributeType.FLOATS,
+                    ir.AttributeType.STRINGS,
+                ):
+                    # For INT, FLOAT and STRING attributes, we convert them to tuples
+                    # to ensure they are hashable.
+                    value = tuple(value)
+                elif v.type is ir.AttributeType.TENSOR:
+                    if value.size > self.size_limit:
+                        # If the tensor is larger than the size limit, we skip it.
+                        large_tensor = True
+                        break
+                    np_value = value.numpy()
+
+                    value = (np_value.shape, str(np_value.dtype), np_value.tobytes())
+                attributes[k] = value
+
+            if control_flow_op:
+                # If the node is a control flow op, we skip it.
+                logger.debug("Skipping control flow op %s", node)
+                continue
+
+            if large_tensor:
+                # If the node has a large tensor, we skip it.
+                logger.debug("Skipping large tensor in node %s", node)
+                continue
+
+            if _is_non_deterministic_op(node):
+                # If the node is a non-deterministic op, we skip it.
+                logger.debug("Skipping non-deterministic op %s", node)
+                continue
+
+            node_info = (
+                node.op_identifier(),
+                len(node.outputs),
+                tuple(id(input) for input in node.inputs),
+                tuple(sorted(attributes.items())),
+            )
+            # Check if the node is a common subexpression.
+            if node_info in existing_node_info_to_the_node:
+                # If it is, this node has an existing node with the same
+                # operator, number of outputs, inputs, and attributes.
+                # We replace the node with the existing node.
+                modified = True
+                existing_node = existing_node_info_to_the_node[node_info]
+                _remove_node_and_replace_values(
+                    graph,
+                    remove_node=node,
+                    remove_values=node.outputs,
+                    new_values=existing_node.outputs,
+                )
+                logger.debug("Reusing node %s", existing_node)
+            else:
+                # If it is not, add to the mapping.
+                existing_node_info_to_the_node[node_info] = node
+        return modified
+
+
+def _remove_node_and_replace_values(
+    graph: ir.Graph,
+    /,
+    remove_node: ir.Node,
+    remove_values: Sequence[ir.Value],
+    new_values: Sequence[ir.Value],
+) -> None:
+    """Replaces nodes and values in the graph or function.
+
+    Args:
+        graph: The graph to replace nodes and values in.
+        remove_node: The node to remove.
+        remove_values: The values to replace.
+        new_values: The values to replace with.
+    """
+    # Reconnect the users of the deleted values to use the new values
+    ir.convenience.replace_all_uses_with(remove_values, new_values)
+    # Update graph/function outputs if the node generates output
+    if any(remove_value.is_graph_output() for remove_value in remove_values):
+        replacement_mapping = dict(zip(remove_values, new_values))
+        for idx, graph_output in enumerate(graph.outputs):
+            if graph_output in replacement_mapping:
+                new_value = replacement_mapping[graph_output]
+                if new_value.is_graph_output() or new_value.is_graph_input():
+                    # If the new value is also a graph input/output, we need to
+                    # create a Identity node to preserve the remove_value and
+                    # prevent from changing new_value name.
+                    identity_node = ir.node(
+                        "Identity",
+                        inputs=[new_value],
+                        outputs=[
+                            ir.Value(
+                                name=graph_output.name,
+                                type=graph_output.type,
+                                shape=graph_output.shape,
+                            )
+                        ],
+                    )
+                    # reuse the name of the graph output
+                    graph.outputs[idx] = identity_node.outputs[0]
+                    graph.insert_before(
+                        remove_node,
+                        identity_node,
+                    )
+                else:
+                    # if new_value is not graph output, we just
+                    # update it to use old_value name.
+                    new_value.name = graph_output.name
+                    graph.outputs[idx] = new_value
+
+    graph.remove(remove_node, safe=True)
+
+
+def _is_non_deterministic_op(node: ir.Node) -> bool:
+    non_deterministic_ops = frozenset(
+        {
+            "RandomUniform",
+            "RandomNormal",
+            "RandomUniformLike",
+            "RandomNormalLike",
+            "Multinomial",
+        }
+    )
+    return node.op_type in non_deterministic_ops and _is_onnx_domain(node.domain)
+
+
+def _is_onnx_domain(d: str) -> bool:
+    """Check if the domain is the ONNX domain."""
+    return d == ""
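
A usage sketch for the reworked pass (not part of the diff; assumes pass instances are callable on a model, consistent with the ``PassResult`` plumbing above, and the file paths are hypothetical)::

    import onnx_ir as ir
    from onnx_ir.passes.common import CommonSubexpressionEliminationPass

    model = ir.load("model.onnx")  # hypothetical path
    # size_limit bounds the element count of tensor attributes eligible for CSE.
    result = CommonSubexpressionEliminationPass(size_limit=10)(model)
    print("graph modified:", result.modified)
    ir.save(result.model, "model_cse.onnx")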

src/onnx_ir/passes/common/initializer_deduplication.py (new file in 0.1.3)

@@ -0,0 +1,56 @@
+# Copyright (c) ONNX Project Contributors
+# SPDX-License-Identifier: Apache-2.0
+"""Pass for removing duplicated initializer tensors from a graph."""
+
+from __future__ import annotations
+
+__all__ = [
+    "DeduplicateInitializersPass",
+]
+
+
+import onnx_ir as ir
+
+
+class DeduplicateInitializersPass(ir.passes.InPlacePass):
+    """Remove duplicated initializer tensors from the graph.
+
+    This pass detects initializers with identical shape, dtype, and content,
+    and replaces all duplicate references with a canonical one.
+
+    To deduplicate initializers from subgraphs, use :class:`~onnx_ir.passes.common.LiftSubgraphInitializersToMainGraphPass`
+    to lift the initializers to the main graph first before running pass.
+
+    .. versionadded:: 0.1.3
+    """
+
+    def __init__(self, size_limit: int = 1024):
+        super().__init__()
+        self.size_limit = size_limit
+
+    def call(self, model: ir.Model) -> ir.passes.PassResult:
+        graph = model.graph
+        initializers: dict[tuple[ir.DataType, tuple[int, ...], bytes], ir.Value] = {}
+        modified = False
+
+        for initializer in tuple(graph.initializers.values()):
+            # TODO(justinchuby): Handle subgraphs as well. For now users can lift initializers
+            # out from the main graph before running this pass.
+            const_val = initializer.const_value
+            if const_val is None:
+                # Skip if initializer has no constant value
+                continue
+
+            if const_val.size > self.size_limit:
+                continue
+
+            key = (const_val.dtype, tuple(const_val.shape), const_val.tobytes())
+            if key in initializers:
+                modified = True
+                ir.convenience.replace_all_uses_with(initializer, initializers[key])  # type: ignore[index]
+                assert initializer.name is not None
+                graph.initializers.pop(initializer.name)
+            else:
+                initializers[key] = initializer  # type: ignore[index]
+
+        return ir.passes.PassResult(model=model, modified=modified)
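
A matching sketch for the new deduplication pass (same assumptions as above; the path is hypothetical)::

    import onnx_ir as ir
    from onnx_ir.passes.common import DeduplicateInitializersPass

    model = ir.load("model.onnx")  # hypothetical path
    # Initializers with more than size_limit elements are left untouched.
    result = DeduplicateInitializersPass(size_limit=1024)(model)
    print("duplicates removed:", result.modified)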

src/onnx_ir/serde.py

@@ -200,6 +200,9 @@ def from_onnx_text(
 
     Read more about the textual representation at: https://onnx.ai/onnx/repo-docs/Syntax.html
 
+    .. versionchanged:: 0.1.2
+        Added the ``initializers`` argument.
+
     Args:
         model_text: The ONNX textual representation of the model.
         initializers: Tensors to be added as initializers. If provided, these tensors
@@ -237,6 +240,8 @@ def to_onnx_text(
 ) -> str:
     """Convert the IR model to the ONNX textual representation.
 
+    .. versionadded:: 0.1.2
+
     Args:
         model: The IR model to convert.
         exclude_initializers: If True, the initializers will not be included in the output.
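
A sketch exercising both serde entry points (assumes they are re-exported at the package root, as with the rest of the API; the textual model is a minimal example, and the ``initializers`` argument documented above is optional)::

    import onnx_ir as ir

    model_text = """
    <ir_version: 8, opset_import: ["" : 18]>
    agraph (float[N] x) => (float[N] y) {
        y = Identity(x)
    }
    """
    model = ir.from_onnx_text(model_text)
    print(ir.to_onnx_text(model))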

src/onnx_ir/tensor_adapters.py

@@ -29,6 +29,8 @@ Example::
 from __future__ import annotations
 
 __all__ = [
+    "from_torch_dtype",
+    "to_torch_dtype",
     "TorchTensor",
 ]
 
@@ -44,14 +46,17 @@ if TYPE_CHECKING:
     import torch
 
 
-class TorchTensor(_core.Tensor):
-    def __init__(
-        self, tensor: torch.Tensor, name: str | None = None, doc_string: str | None = None
-    ):
-        # Pass the tensor as the raw data to ir.Tensor's constructor
+_TORCH_DTYPE_TO_ONNX: dict[torch.dtype, ir.DataType] | None = None
+_ONNX_DTYPE_TO_TORCH: dict[ir.DataType, torch.dtype] | None = None
+
+
+def from_torch_dtype(dtype: torch.dtype) -> ir.DataType:
+    """Convert a PyTorch dtype to an ONNX IR DataType."""
+    global _TORCH_DTYPE_TO_ONNX
+    if _TORCH_DTYPE_TO_ONNX is None:
         import torch
 
-        _TORCH_DTYPE_TO_ONNX: dict[torch.dtype, ir.DataType] = {
+        _TORCH_DTYPE_TO_ONNX = {
             torch.bfloat16: ir.DataType.BFLOAT16,
             torch.bool: ir.DataType.BOOL,
             torch.complex128: ir.DataType.COMPLEX128,
@@ -72,8 +77,58 @@ class TorchTensor(_core.Tensor):
             torch.uint32: ir.DataType.UINT32,
             torch.uint64: ir.DataType.UINT64,
         }
+    if dtype not in _TORCH_DTYPE_TO_ONNX:
+        raise TypeError(
+            f"Unsupported PyTorch dtype '{dtype}'. "
+            "Please use a supported dtype from the list: "
+            f"{list(_TORCH_DTYPE_TO_ONNX.keys())}"
+        )
+    return _TORCH_DTYPE_TO_ONNX[dtype]
+
+
+def to_torch_dtype(dtype: ir.DataType) -> torch.dtype:
+    """Convert an ONNX IR DataType to a PyTorch dtype."""
+    global _ONNX_DTYPE_TO_TORCH
+    if _ONNX_DTYPE_TO_TORCH is None:
+        import torch
+
+        _ONNX_DTYPE_TO_TORCH = {
+            ir.DataType.BFLOAT16: torch.bfloat16,
+            ir.DataType.BOOL: torch.bool,
+            ir.DataType.COMPLEX128: torch.complex128,
+            ir.DataType.COMPLEX64: torch.complex64,
+            ir.DataType.FLOAT16: torch.float16,
+            ir.DataType.FLOAT: torch.float32,
+            ir.DataType.DOUBLE: torch.float64,
+            ir.DataType.FLOAT8E4M3FN: torch.float8_e4m3fn,
+            ir.DataType.FLOAT8E4M3FNUZ: torch.float8_e4m3fnuz,
+            ir.DataType.FLOAT8E5M2: torch.float8_e5m2,
+            ir.DataType.FLOAT8E5M2FNUZ: torch.float8_e5m2fnuz,
+            ir.DataType.INT16: torch.int16,
+            ir.DataType.INT32: torch.int32,
+            ir.DataType.INT64: torch.int64,
+            ir.DataType.INT8: torch.int8,
+            ir.DataType.UINT8: torch.uint8,
+            ir.DataType.UINT16: torch.uint16,
+            ir.DataType.UINT32: torch.uint32,
+            ir.DataType.UINT64: torch.uint64,
+        }
+    if dtype not in _ONNX_DTYPE_TO_TORCH:
+        raise TypeError(
+            f"Unsupported conversion from ONNX dtype '{dtype}' to torch. "
+            "Please use a supported dtype from the list: "
+            f"{list(_ONNX_DTYPE_TO_TORCH.keys())}"
+        )
+    return _ONNX_DTYPE_TO_TORCH[dtype]
+
+
+class TorchTensor(_core.Tensor):
+    def __init__(
+        self, tensor: torch.Tensor, name: str | None = None, doc_string: str | None = None
+    ):
+        # Pass the tensor as the raw data to ir.Tensor's constructor
         super().__init__(
-            tensor, dtype=_TORCH_DTYPE_TO_ONNX[tensor.dtype], name=name, doc_string=doc_string
+            tensor, dtype=from_torch_dtype(tensor.dtype), name=name, doc_string=doc_string
         )
 
     def numpy(self) -> npt.NDArray:
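
A sketch exercising the newly public helpers and the updated ``TorchTensor`` constructor (requires PyTorch; the dtype pairs are taken from the tables above)::

    import torch
    import onnx_ir as ir
    from onnx_ir.tensor_adapters import TorchTensor, from_torch_dtype, to_torch_dtype

    assert from_torch_dtype(torch.bfloat16) == ir.DataType.BFLOAT16
    assert to_torch_dtype(ir.DataType.FLOAT) is torch.float32

    # TorchTensor now derives its ONNX dtype via from_torch_dtype().
    t = TorchTensor(torch.ones(2, 2, dtype=torch.bfloat16), name="w")
    print(t.dtype)  # DataType.BFLOAT16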

src/onnx_ir.egg-info/SOURCES.txt

@@ -39,6 +39,7 @@ src/onnx_ir/passes/common/_c_api_utils.py
 src/onnx_ir/passes/common/clear_metadata_and_docstring.py
 src/onnx_ir/passes/common/common_subexpression_elimination.py
 src/onnx_ir/passes/common/constant_manipulation.py
+src/onnx_ir/passes/common/initializer_deduplication.py
 src/onnx_ir/passes/common/inliner.py
 src/onnx_ir/passes/common/onnx_checker.py
 src/onnx_ir/passes/common/shape_inference.py

onnx_ir-0.1.2/src/onnx_ir/passes/common/common_subexpression_elimination.py (removed; superseded by the 0.1.3 file above)

@@ -1,177 +0,0 @@
-# Copyright (c) ONNX Project Contributors
-# SPDX-License-Identifier: Apache-2.0
-"""Eliminate common subexpression in ONNX graphs."""
-
-from __future__ import annotations
-
-__all__ = [
-    "CommonSubexpressionEliminationPass",
-]
-
-import logging
-from collections.abc import Sequence
-
-import onnx_ir as ir
-
-logger = logging.getLogger(__name__)
-
-
-class CommonSubexpressionEliminationPass(ir.passes.InPlacePass):
-    """Eliminate common subexpression in ONNX graphs."""
-
-    def call(self, model: ir.Model) -> ir.passes.PassResult:
-        """Return the same ir.Model but with CSE applied to the graph."""
-        modified = False
-        graph = model.graph
-
-        modified = _eliminate_common_subexpression(graph, modified)
-
-        return ir.passes.PassResult(
-            model,
-            modified=modified,
-        )
-
-
-def _eliminate_common_subexpression(graph: ir.Graph, modified: bool) -> bool:
-    """Eliminate common subexpression in ONNX graphs."""
-    # node to node identifier, length of outputs, inputs, and attributes
-    existing_node_info_to_the_node: dict[
-        tuple[
-            ir.OperatorIdentifier,
-            int,  # len(outputs)
-            tuple[int, ...],  # input ids
-            tuple[tuple[str, object], ...],  # attributes
-        ],
-        ir.Node,
-    ] = {}
-
-    for node in graph:
-        # Skip control flow ops like Loop and If.
-        control_flow_op: bool = False
-        # Use equality to check if the node is a common subexpression.
-        attributes = {}
-        for k, v in node.attributes.items():
-            # TODO(exporter team): CSE subgraphs.
-            # NOTE: control flow ops like Loop and If won't be CSEd
-            # because attribute: graph won't match.
-            if v.type in (ir.AttributeType.GRAPH, ir.AttributeType.GRAPHS):
-                control_flow_op = True
-                logger.debug("Skipping control flow op %s", node)
-            # The attribute value could be directly taken from the original
-            # protobuf, so we need to make a copy of it.
-            value = v.value
-            if v.type in (
-                ir.AttributeType.INTS,
-                ir.AttributeType.FLOATS,
-                ir.AttributeType.STRINGS,
-            ):
-                # For INT, FLOAT and STRING attributes, we convert them to tuples
-                # to ensure they are hashable.
-                value = tuple(value)
-            attributes[k] = value
-
-        if control_flow_op:
-            # If the node is a control flow op, we skip it.
-            logger.debug("Skipping control flow op %s", node)
-            continue
-
-        if _is_non_deterministic_op(node):
-            # If the node is a non-deterministic op, we skip it.
-            logger.debug("Skipping non-deterministic op %s", node)
-            continue
-
-        node_info = (
-            node.op_identifier(),
-            len(node.outputs),
-            tuple(id(input) for input in node.inputs),
-            tuple(sorted(attributes.items())),
-        )
-        # Check if the node is a common subexpression.
-        if node_info in existing_node_info_to_the_node:
-            # If it is, this node has an existing node with the same
-            # operator, number of outputs, inputs, and attributes.
-            # We replace the node with the existing node.
-            modified = True
-            existing_node = existing_node_info_to_the_node[node_info]
-            _remove_node_and_replace_values(
-                graph,
-                remove_node=node,
-                remove_values=node.outputs,
-                new_values=existing_node.outputs,
-            )
-            logger.debug("Reusing node %s", existing_node)
-        else:
-            # If it is not, add to the mapping.
-            existing_node_info_to_the_node[node_info] = node
-    return modified
-
-
-def _remove_node_and_replace_values(
-    graph: ir.Graph,
-    /,
-    remove_node: ir.Node,
-    remove_values: Sequence[ir.Value],
-    new_values: Sequence[ir.Value],
-) -> None:
-    """Replaces nodes and values in the graph or function.
-
-    Args:
-        graph: The graph to replace nodes and values in.
-        remove_node: The node to remove.
-        remove_values: The values to replace.
-        new_values: The values to replace with.
-    """
-    # Reconnect the users of the deleted values to use the new values
-    ir.convenience.replace_all_uses_with(remove_values, new_values)
-    # Update graph/function outputs if the node generates output
-    if any(remove_value.is_graph_output() for remove_value in remove_values):
-        replacement_mapping = dict(zip(remove_values, new_values))
-        for idx, graph_output in enumerate(graph.outputs):
-            if graph_output in replacement_mapping:
-                new_value = replacement_mapping[graph_output]
-                if new_value.is_graph_output() or new_value.is_graph_input():
-                    # If the new value is also a graph input/output, we need to
-                    # create a Identity node to preserve the remove_value and
-                    # prevent from changing new_value name.
-                    identity_node = ir.node(
-                        "Identity",
-                        inputs=[new_value],
-                        outputs=[
-                            ir.Value(
-                                name=graph_output.name,
-                                type=graph_output.type,
-                                shape=graph_output.shape,
-                            )
-                        ],
-                    )
-                    # reuse the name of the graph output
-                    graph.outputs[idx] = identity_node.outputs[0]
-                    graph.insert_before(
-                        remove_node,
-                        identity_node,
-                    )
-                else:
-                    # if new_value is not graph output, we just
-                    # update it to use old_value name.
-                    new_value.name = graph_output.name
-                    graph.outputs[idx] = new_value
-
-    graph.remove(remove_node, safe=True)
-
-
-def _is_non_deterministic_op(node: ir.Node) -> bool:
-    non_deterministic_ops = frozenset(
-        {
-            "RandomUniform",
-            "RandomNormal",
-            "RandomUniformLike",
-            "RandomNormalLike",
-            "Multinomial",
-        }
-    )
-    return node.op_type in non_deterministic_ops and _is_onnx_domain(node.domain)
-
-
-def _is_onnx_domain(d: str) -> bool:
-    """Check if the domain is the ONNX domain."""
-    return d == ""