onnx-ir 0.0.1__py3-none-any.whl → 0.1.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- onnx_ir/__init__.py +23 -10
- onnx_ir/{_convenience.py → _convenience/__init__.py} +40 -102
- onnx_ir/_convenience/_constructors.py +213 -0
- onnx_ir/_core.py +857 -233
- onnx_ir/_display.py +2 -2
- onnx_ir/_enums.py +107 -5
- onnx_ir/_graph_comparison.py +2 -2
- onnx_ir/_graph_containers.py +268 -0
- onnx_ir/_io.py +57 -10
- onnx_ir/_linked_list.py +15 -7
- onnx_ir/_metadata.py +4 -3
- onnx_ir/_name_authority.py +2 -2
- onnx_ir/_polyfill.py +26 -0
- onnx_ir/_protocols.py +31 -13
- onnx_ir/_tape.py +139 -32
- onnx_ir/_thirdparty/asciichartpy.py +1 -4
- onnx_ir/_type_casting.py +18 -3
- onnx_ir/{_internal/version_utils.py → _version_utils.py} +2 -29
- onnx_ir/convenience.py +4 -2
- onnx_ir/external_data.py +401 -0
- onnx_ir/passes/__init__.py +8 -2
- onnx_ir/passes/_pass_infra.py +173 -56
- onnx_ir/passes/common/__init__.py +36 -0
- onnx_ir/passes/common/_c_api_utils.py +76 -0
- onnx_ir/passes/common/clear_metadata_and_docstring.py +60 -0
- onnx_ir/passes/common/constant_manipulation.py +232 -0
- onnx_ir/passes/common/inliner.py +331 -0
- onnx_ir/passes/common/onnx_checker.py +57 -0
- onnx_ir/passes/common/shape_inference.py +112 -0
- onnx_ir/passes/common/topological_sort.py +33 -0
- onnx_ir/passes/common/unused_removal.py +196 -0
- onnx_ir/serde.py +288 -124
- onnx_ir/tape.py +15 -0
- onnx_ir/tensor_adapters.py +122 -0
- onnx_ir/testing.py +197 -0
- onnx_ir/traversal.py +4 -3
- onnx_ir-0.1.0.dist-info/METADATA +53 -0
- onnx_ir-0.1.0.dist-info/RECORD +41 -0
- {onnx_ir-0.0.1.dist-info → onnx_ir-0.1.0.dist-info}/WHEEL +1 -1
- onnx_ir-0.1.0.dist-info/licenses/LICENSE +202 -0
- onnx_ir/_external_data.py +0 -323
- onnx_ir-0.0.1.dist-info/LICENSE +0 -22
- onnx_ir-0.0.1.dist-info/METADATA +0 -73
- onnx_ir-0.0.1.dist-info/RECORD +0 -26
- {onnx_ir-0.0.1.dist-info → onnx_ir-0.1.0.dist-info}/top_level.txt +0 -0
onnx_ir/passes/_pass_infra.py
CHANGED
@@ -1,5 +1,5 @@
-# Copyright (c)
-#
+# Copyright (c) ONNX Project Contributors
+# SPDX-License-Identifier: Apache-2.0
 #
 # This module implements some APIs described in
 # https://pytorch.org/executorch/stable/compiler-custom-compiler-passes.html
@@ -16,10 +16,14 @@ from __future__ import annotations
 
 import dataclasses
 import logging
-from
+from collections.abc import Sequence
+from typing import Literal, final
 
 __all__ = [
     "PassBase",
+    "Sequential",
+    "InPlacePass",
+    "FunctionalPass",
     "PassManager",
     "PassResult",
     # Errors
@@ -58,7 +62,7 @@ class PassResult:
 
     Attributes:
         model: The transformed model.
-        modified: Whether the model
+        modified: Whether the resulting model is different from the input model.
     """
 
     model: ir.Model
@@ -68,14 +72,89 @@ class PassResult:
 class PassBase(abc.ABC):
     """Base class for all passes.
 
-
-
+    ``in_place`` and ``changes_input`` properties and what they mean:
+
+    +------------+------------------+----------------------------+
+    |            | changes_input    | not changes_input          |
+    +------------+------------------+----------------------------+
+    | in_place   | in place         | Side-effect-only pass      |
+    +------------+------------------+----------------------------+
+    | not        | destructive      | functional                 |
+    | in_place   |                  |                            |
+    +------------+------------------+----------------------------+
     """
 
-
+    @property
+    @abc.abstractmethod
+    def in_place(self) -> bool:
+        """Whether the pass modifies the model in place and returns it.
+
+        If True, the pass will return the same model object that was passed in.
+        If False, the pass will return a new model object.
+        """
+        raise NotImplementedError
+
+    @property
+    @abc.abstractmethod
+    def changes_input(self) -> bool:
+        """Whether the pass modifies the input model."""
+        raise NotImplementedError
+
+    @property
+    def destructive(self) -> bool:
+        """Whether the pass will destroy the input model when ``in_place=False``.
 
-
-
+        A pass is destructive if it is not in place and it modifies the input model.
+        """
+        return not self.in_place and self.changes_input
+
+    def __call__(self, model_or_result: ir.Model | PassResult, /) -> PassResult:
+        if isinstance(model_or_result, PassResult):
+            model = model_or_result.model
+        else:
+            model = model_or_result
+        # Check preconditions
+        try:
+            self.requires(model)
+        except PreconditionError:
+            raise
+        except Exception as e:
+            raise PreconditionError(
+                f"Pre-condition for pass '{self.__class__.__name__}' failed"
+            ) from e
+
+        result = self.call(model)
+
+        # Check postconditions
+        try:
+            self.ensures(model)
+        except PostconditionError:
+            raise
+        except Exception as e:
+            raise PostconditionError(
+                f"Post-condition for pass '{self.__class__.__name__}' failed"
+            ) from e
+
+        if not isinstance(result, PassResult):
+            raise TypeError(
+                f"The result of the pass '{self.__class__.__name__}' should be type PassResult. "
+                "Please create one with ir.passes.PassResult()."
+            )
+
+        # Checks that the declared in-place property is respected
+        if self.in_place and result.model is not model:
+            raise PassError(
+                f"The pass '{self.__class__.__name__}' is declared in-place, "
+                "but the model returned is *not* the same object as the input model. "
+                "Pass developer: Pass should return the same model object or the in_place property should return False."
+            )
+        if not self.in_place and result.model is model:
+            raise PassError(
+                f"The pass '{self.__class__.__name__}' is declared not in-place, "
+                "but the model returned *is* the same object as the input model. "
+                "Pass developer: Pass should return a new model object or the in_place property should return True."
+            )
+        return result
 
     @abc.abstractmethod
     def call(self, model: ir.Model) -> PassResult:
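Note: a minimal sketch of what the ``__call__`` contract above implies for pass authors, assuming ``ir.passes`` re-exports ``PassBase`` and ``PassResult`` as the ``__all__`` list suggests; the ``NoOpPass`` name is invented here for illustration:

import onnx_ir as ir

class NoOpPass(ir.passes.PassBase):
    """Hypothetical side-effect-only pass: in place, does not change the input."""

    @property
    def in_place(self) -> bool:
        return True  # we return the very same model object, which __call__ verifies

    @property
    def changes_input(self) -> bool:
        return False  # nothing in the input model is mutated

    def call(self, model: ir.Model) -> ir.passes.PassResult:
        # Must return a PassResult, otherwise __call__ raises TypeError
        return ir.passes.PassResult(model, modified=False)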
@@ -97,76 +176,114 @@ class PassBase(abc.ABC):
         del model  # Unused
 
 
-class
+class InPlacePass(PassBase):
+    """A pass that modifies the input model in place and returns it."""
+
+    @property
+    @final
+    def in_place(self) -> Literal[True]:
+        """An in-place pass is in place."""
+        return True
+
+    @property
+    @final
+    def changes_input(self) -> Literal[True]:
+        """An in-place pass changes the input model."""
+        return True
+
+
+class FunctionalPass(PassBase):
+    """A pass that returns a new model but does not modify the input model."""
+
+    @property
+    @final
+    def in_place(self) -> Literal[False]:
+        """A functional pass is not in place."""
+        return False
+
+    @property
+    @final
+    def changes_input(self) -> Literal[False]:
+        """A functional pass does not change the input model."""
+        return False
+
+
+class Sequential(PassBase):
+    """Run a sequence of passes in order."""
+
+    def __init__(self, *passes: PassBase):
+        if not passes:
+            raise ValueError("Sequential must take at least one pass")
+        self.passes = passes
+        self._in_place = all(pass_.in_place for pass_ in passes)
+        # changes_input is decided by the first pass: if the first pass is in-place,
+        # or if it is not designed to be in-place but somehow changes the input (destructive),
+        # the whole sequence changes its input.
+        self._changes_input = self.passes[0].changes_input or self.passes[0].in_place
+
+    @property
+    def in_place(self) -> bool:
+        return self._in_place
+
+    @property
+    def changes_input(self) -> bool:
+        return self._changes_input
+
+    def call(self, model: ir.Model) -> PassResult:
+        modified = False
+        for i, pass_ in enumerate(self.passes):
+            logger.debug("Running the %s-th pass '%s'", i, pass_)
+            try:
+                pass_result = pass_(model)
+            except Exception as e:
+                prev_pass_names = [str(p) for p in self.passes[:i]]
+                raise PassError(
+                    f"An error occurred when running the '{pass_}' pass after the "
+                    f"following passes: {prev_pass_names}"
+                ) from e
+
+            model = pass_result.model
+            modified = modified or pass_result.modified
+
+        return PassResult(model, modified)
+
+
+class PassManager(Sequential):
     """Pass manager for the IR.
 
-    The PassManager is a
+    The PassManager is a Pass that runs a sequence of passes on a model.
 
     Attributes:
         passes: The passes to run.
-        check_invariants: Whether to check invariants before and after each pass.
         steps: The number of times to run the passes.
+        early_stop: Whether to stop running the passes if the graph stops changing.
     """
 
     def __init__(
         self,
         passes: Sequence[PassBase],
-        check_invariants: bool = False,
         steps: int = 1,
+        early_stop: bool = True,
     ):
         # TODO(justinchuby): Implement constraints
-
-        self.check_invariants = check_invariants
+        super().__init__(*passes)
         self.steps = steps
+        self.early_stop = early_stop
 
-    def
+    def call(self, model: ir.Model) -> PassResult:
         """Run the set of passes `steps` number of times or until the graph stops changing."""
         overall_modified = False
         for step in range(self.steps):
-
+            try:
+                # Call the call method of Sequential
+                step_result = super().call(model)
+            except Exception as e:
+                raise PassError(f"An error occurred at step {step}") from e
             model = step_result.model
             modified = step_result.modified
             overall_modified = overall_modified or modified
             # If the graph no longer changes, then we can stop running these passes
-            if not modified:
+            if not modified and self.early_stop:
                 logger.info("PassManager: No more graph changes detected after step %s", step)
                 break
         return PassResult(model, overall_modified)
-
-    def _run_one_step(self, model: ir.Model, step: int) -> PassResult:
-        modified = False
-        for i, pass_ in enumerate(self.passes):
-            logger.debug("Running the %s-th pass '%s', (step %s)", i, pass_, step)
-
-            # 1. Check preconditions
-            if self.check_invariants:
-                try:
-                    pass_.requires(model)
-                except Exception as e:
-                    raise PreconditionError(f"Pre-condition failed for {pass_}") from e
-
-            # 2. Run the pass
-            try:
-                pass_result = pass_(model)
-            except Exception as e:
-                prev_pass_names = [str(p) for p in self.passes[:i]]
-                raise PassError(
-                    f"An error occurred when running the '{pass_}' pass after the "
-                    f"following passes: {prev_pass_names} during step {step}"
-                ) from e
-            if not isinstance(pass_result, PassResult):
-                raise TypeError(
-                    f"The result of the pass {pass_} should be type PassResult."
-                    "Please create one with ir.passes.PassResult()."
-                )
-
-            model = pass_result.model
-            modified = modified or pass_result.modified
-
-            # 3. Check postconditions
-            if self.check_invariants:
-                try:
-                    pass_.ensures(model)
-                except Exception as e:
-                    raise PostconditionError(f"Post-condition failed for {pass_}") from e
-        return PassResult(model, modified)
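Note: a usage sketch of the new Sequential/PassManager machinery above, assuming an already-loaded ``ir.Model`` named ``model`` and the common passes introduced in the new ``onnx_ir.passes.common`` package below:

import onnx_ir as ir
from onnx_ir.passes.common import RemoveUnusedNodesPass, TopologicalSortPass

manager = ir.passes.PassManager(
    passes=[RemoveUnusedNodesPass(), TopologicalSortPass()],
    steps=2,          # run the whole sequence at most twice
    early_stop=True,  # stop once a full step reports modified=False
)
result = manager(model)  # model: ir.Model, assumed loaded elsewhere
print(result.modified)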
onnx_ir/passes/common/__init__.py
ADDED
@@ -0,0 +1,36 @@
+# Copyright (c) ONNX Project Contributors
+# SPDX-License-Identifier: Apache-2.0
+
+__all__ = [
+    "AddInitializersToInputsPass",
+    "CheckerPass",
+    "ClearMetadataAndDocStringPass",
+    "InlinePass",
+    "LiftConstantsToInitializersPass",
+    "LiftSubgraphInitializersToMainGraphPass",
+    "RemoveInitializersFromInputsPass",
+    "RemoveUnusedFunctionsPass",
+    "RemoveUnusedNodesPass",
+    "RemoveUnusedOpsetsPass",
+    "ShapeInferencePass",
+    "TopologicalSortPass",
+]
+
+from onnx_ir.passes.common.clear_metadata_and_docstring import (
+    ClearMetadataAndDocStringPass,
+)
+from onnx_ir.passes.common.constant_manipulation import (
+    AddInitializersToInputsPass,
+    LiftConstantsToInitializersPass,
+    LiftSubgraphInitializersToMainGraphPass,
+    RemoveInitializersFromInputsPass,
+)
+from onnx_ir.passes.common.inliner import InlinePass
+from onnx_ir.passes.common.onnx_checker import CheckerPass
+from onnx_ir.passes.common.shape_inference import ShapeInferencePass
+from onnx_ir.passes.common.topological_sort import TopologicalSortPass
+from onnx_ir.passes.common.unused_removal import (
+    RemoveUnusedFunctionsPass,
+    RemoveUnusedNodesPass,
+    RemoveUnusedOpsetsPass,
+)
onnx_ir/passes/common/_c_api_utils.py
ADDED
@@ -0,0 +1,76 @@
+# Copyright (c) ONNX Project Contributors
+# SPDX-License-Identifier: Apache-2.0
+"""Utilities for interfacing with onnx C APIs."""
+
+from __future__ import annotations
+
+import logging
+from typing import TYPE_CHECKING, Callable, TypeVar
+
+import onnx_ir as ir
+
+if TYPE_CHECKING:
+    import onnx
+
+
+logger = logging.getLogger(__name__)
+# Temporarily remove initializers larger than this size to keep model size down
+# for the onnx.shape_inference call because it needs to serialize the model
+_BIG_TENSOR_SIZE_LIMIT = 1000  # 1KB
+_R = TypeVar("_R")
+
+
+def call_onnx_api(func: Callable[[onnx.ModelProto], _R], model: ir.Model) -> _R:
+    """Call an ONNX C API function by temporarily removing initializers.
+
+    This is necessary because the ONNX C API does not support large models
+    with initializers that have large tensor values. The input model is left
+    unchanged whether or not the call succeeds.
+
+    Args:
+        func: Partially applied function that takes a model proto and returns anything.
+        model: The IR model to pass to the API function.
+
+    Returns:
+        The result of the API call.
+    """
+    # Store the original initializer values so they can be restored
+    initializer_values = tuple(model.graph.initializers.values())
+    tensors = {v.name: v.const_value for v in initializer_values}
+    original_inputs_len = len(model.graph.inputs)
+
+    # Turn the initializers into inputs and clear the initializers
+    # to limit the model size
+    for initializer in initializer_values:
+        # Make sure the initializer has its shape/type set
+        assert initializer.const_value is not None
+        if initializer.shape is None:
+            initializer.shape = initializer.const_value.shape  # type: ignore[assignment]
+        if initializer.dtype is None:
+            initializer.dtype = initializer.const_value.dtype
+        if initializer not in model.graph.inputs:
+            model.graph.inputs.append(initializer)
+        if initializer.const_value.nbytes > _BIG_TENSOR_SIZE_LIMIT:
+            # Temporarily remove the initializer value to reduce model size
+            # for onnx.shape_inference
+            initializer.const_value = None
+            assert initializer.name is not None
+            model.graph.initializers.pop(initializer.name)
+
+    proto = ir.serde.serialize_model(model)
+
+    try:
+        # Call the ONNX C API function
+        result = func(proto)
+    finally:
+        # Restore the original initializer values so the model is unchanged
+        for initializer in initializer_values:
+            initializer.const_value = tensors[initializer.name]
+            model.graph.register_initializer(initializer)
+
+        # Restore the original inputs
+        inputs = model.graph.inputs[:original_inputs_len]
+        model.graph.inputs.clear()
+        model.graph.inputs.extend(inputs)
+
+    return result
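Note: a sketch of how a caller might route a real onnx C API through call_onnx_api; functools.partial pre-binds keyword arguments so that func takes only the ModelProto, and the ``infer`` helper here is hypothetical:

import functools

import onnx
import onnx_ir as ir
from onnx_ir.passes.common import _c_api_utils

def infer(model: ir.Model) -> onnx.ModelProto:
    # onnx.shape_inference.infer_shapes takes a ModelProto first, so partial
    # application leaves exactly the signature call_onnx_api expects
    func = functools.partial(onnx.shape_inference.infer_shapes, check_type=True)
    return _c_api_utils.call_onnx_api(func, model)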
onnx_ir/passes/common/clear_metadata_and_docstring.py
ADDED
@@ -0,0 +1,60 @@
+# Copyright (c) ONNX Project Contributors
+# SPDX-License-Identifier: Apache-2.0
+"""Clear all metadata and docstrings from the model, graphs, nodes, and functions."""
+
+from __future__ import annotations
+
+__all__ = [
+    "ClearMetadataAndDocStringPass",
+]
+
+import logging
+
+import onnx_ir as ir
+
+logger = logging.getLogger(__name__)
+
+
+class ClearMetadataAndDocStringPass(ir.passes.InPlacePass):
+    """Clear all metadata and docstrings from the model, graphs, nodes, and functions."""
+
+    def call(self, model: ir.Model) -> ir.passes.PassResult:
+        # 0. TODO: Should we clear model metadata and docstring?
+
+        # 1. Clean up the graph and its nodes' metadata properties
+        modified = self._clear_graph_or_function_metadata_and_docstring(model.graph)
+
+        # 2. Clean up all of the functions' metadata properties
+        for function in model.functions.values():
+            modified = (
+                self._clear_graph_or_function_metadata_and_docstring(function) or modified
+            )
+        return ir.passes.PassResult(model, modified=modified)
+
+    def _clear_graph_or_function_metadata_and_docstring(
+        self,
+        graph_or_function: ir.Graph | ir.Function,
+    ) -> bool:
+        """Clear metadata and docstring from the graph or function."""
+        checked_graphs_or_functions: set[ir.Graph | ir.Function] = set()
+        modified = False
+        # Clean up all of the nodes' metadata properties
+        for node in ir.traversal.RecursiveGraphIterator(graph_or_function):
+            if node.metadata_props:
+                modified = True
+                logger.debug("Removed metadata from node %s", node.name)
+                node.metadata_props.clear()
+            node.doc_string = None
+
+            # Clean up the owning graph/function's metadata properties
+            # and doc_string if the graph/function is not already checked
+            assert node.graph is not None
+            if node.graph not in checked_graphs_or_functions and (
+                node.graph.metadata_props or node.graph.doc_string
+            ):
+                modified = True
+                logger.debug("Removed metadata from graph/function %s", node.graph.name)
+                node.graph.metadata_props.clear()
+                node.graph.doc_string = None
+                checked_graphs_or_functions.add(node.graph)
+        return modified
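Note: a usage sketch; because this is an InPlacePass, the returned model is the same object that went in (``model`` is assumed loaded elsewhere):

from onnx_ir.passes.common import ClearMetadataAndDocStringPass

result = ClearMetadataAndDocStringPass()(model)
assert result.model is model  # in-place contract enforced by PassBase.__call__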
onnx_ir/passes/common/constant_manipulation.py
ADDED
@@ -0,0 +1,232 @@
+# Copyright (c) ONNX Project Contributors
+# SPDX-License-Identifier: Apache-2.0
+"""Lift constants to initializers."""
+
+from __future__ import annotations
+
+__all__ = [
+    "AddInitializersToInputsPass",
+    "LiftConstantsToInitializersPass",
+    "LiftSubgraphInitializersToMainGraphPass",
+    "RemoveInitializersFromInputsPass",
+]
+
+import logging
+
+import numpy as np
+
+import onnx_ir as ir
+
+logger = logging.getLogger(__name__)
+
+
+class LiftConstantsToInitializersPass(ir.passes.InPlacePass):
+    """Lift constants to initializers.
+
+    Attributes:
+        lift_all_constants: Whether to lift all Constant nodes, including those that do not contain a tensor attribute (e.g. with value_ints etc.).
+            Defaults to False, in which case only Constants with the ``value`` attribute are lifted.
+        size_limit: The minimum size of the tensor to be lifted. If the tensor contains
+            fewer elements than size_limit, it will not be lifted. Default is 16.
+    """
+
+    def __init__(self, lift_all_constants: bool = False, size_limit: int = 16):
+        super().__init__()
+        self.lift_all_constants = lift_all_constants
+        self.size_limit = size_limit
+
+    def call(self, model: ir.Model) -> ir.passes.PassResult:
+        count = 0
+        for node in ir.traversal.RecursiveGraphIterator(model.graph):
+            assert node.graph is not None
+            if node.op_type != "Constant" or node.domain not in ("", "onnx.ai"):
+                continue
+            if node.outputs[0].is_graph_output():
+                logger.debug(
+                    "Constant node '%s' is used as output, so it can't be lifted.", node.name
+                )
+                continue
+            constant_node_attribute = set(node.attributes.keys())
+            if len(constant_node_attribute) != 1:
+                logger.debug(
+                    "Invalid constant node '%s' has more than one attribute", node.name
+                )
+                continue
+
+            attr_name, attr_value = next(iter(node.attributes.items()))
+            initializer_name = node.outputs[0].name
+            assert initializer_name is not None
+            assert isinstance(attr_value, ir.Attr)
+            tensor = self._constant_node_attribute_to_tensor(
+                node, attr_name, attr_value, initializer_name
+            )
+            if tensor is None:
+                # The reason for None is logged in _constant_node_attribute_to_tensor
+                continue
+            # Register an initializer with the tensor value
+            initializer = ir.Value(
+                name=initializer_name,
+                shape=tensor.shape,  # type: ignore[arg-type]
+                type=ir.TensorType(tensor.dtype),
+                const_value=tensor,
+            )
+            assert node.graph is not None
+            node.graph.register_initializer(initializer)
+            # Replace the constant node with the initializer
+            ir.convenience.replace_all_uses_with(node.outputs[0], initializer)
+            node.graph.remove(node, safe=True)
+            count += 1
+            logger.debug(
+                "Converted constant node '%s' to initializer '%s'", node.name, initializer_name
+            )
+        if count:
+            logger.debug("Lifted %s constants to initializers", count)
+        return ir.passes.PassResult(model, modified=bool(count))
+
+    def _constant_node_attribute_to_tensor(
+        self, node, attr_name: str, attr_value: ir.Attr, initializer_name: str
+    ) -> ir.TensorProtocol | None:
+        """Convert a constant node attribute to a tensor."""
+        if not self.lift_all_constants and attr_name != "value":
+            logger.debug(
+                "Constant node '%s' has non-tensor attribute '%s'", node.name, attr_name
+            )
+            return None
+
+        tensor: ir.TensorProtocol
+        if attr_name == "value":
+            tensor = attr_value.as_tensor()
+        elif attr_name == "value_int":
+            tensor = ir.tensor(
+                attr_value.as_int(), dtype=ir.DataType.INT64, name=initializer_name
+            )
+        elif attr_name == "value_ints":
+            tensor = ir.tensor(
+                attr_value.as_ints(), dtype=ir.DataType.INT64, name=initializer_name
+            )
+        elif attr_name == "value_float":
+            tensor = ir.tensor(
+                attr_value.as_float(), dtype=ir.DataType.FLOAT, name=initializer_name
+            )
+        elif attr_name == "value_floats":
+            tensor = ir.tensor(
+                attr_value.as_floats(), dtype=ir.DataType.FLOAT, name=initializer_name
+            )
+        elif attr_name in ("value_string", "value_strings"):
+            tensor = ir.StringTensor(
+                np.array(attr_value.value, dtype=np.bytes_), name=initializer_name
+            )
+        else:
+            raise ValueError(
+                f"Unsupported constant node '{node.name}' attribute '{attr_name}'"
+            )
+
+        if tensor.size < self.size_limit:
+            logger.debug(
+                "Tensor from node '%s' has fewer than %s elements",
+                node.name,
+                self.size_limit,
+            )
+            return None
+        return tensor
+
+
+class LiftSubgraphInitializersToMainGraphPass(ir.passes.InPlacePass):
+    """Lift subgraph initializers to the main graph.
+
+    This pass lifts the initializers of a subgraph to the main graph.
+    It is used to ensure that the initializers are available in the main graph
+    for further processing or optimization.
+
+    Initializers that are also graph inputs will not be lifted.
+
+    Preconditions:
+    - All initializers in the model must have unique names across the main graph and subgraphs.
+    """
+
+    def requires(self, model: ir.Model) -> None:
+        """Ensure all initializer names are unique."""
+        registered_initializer_names: set[str] = set()
+        duplicated_initializers: list[ir.Value] = []
+        for graph in model.graphs():
+            for initializer in graph.initializers.values():
+                if initializer.name is None:
+                    raise ir.passes.PreconditionError(
+                        f"Initializer name is None. Please ensure all initializers have unique names: {initializer!r}"
+                    )
+                if initializer.name in registered_initializer_names:
+                    duplicated_initializers.append(initializer)
+                else:
+                    registered_initializer_names.add(initializer.name)
+        if duplicated_initializers:
+            raise ir.passes.PreconditionError(
+                "Found duplicated initializers in the model. "
+                "Initializer names must be unique across the main graph and subgraphs. "
+                "Please ensure all initializers have unique names. Duplicated: "
+                f"{duplicated_initializers!r}"
+            )
+
+    def call(self, model: ir.Model) -> ir.passes.PassResult:
+        count = 0
+        for graph in model.graphs():
+            if graph is model.graph:
+                continue
+            for name in tuple(graph.initializers):
+                initializer = graph.initializers[name]
+                if initializer.is_graph_input():
+                    # Skip the ones that are also graph inputs
+                    logger.debug(
+                        "Initializer '%s' is also a graph input, so it can't be lifted",
+                        initializer.name,
+                    )
+                    continue
+                # Remove the initializer from the subgraph
+                graph.initializers.pop(name)
+                model.graph.register_initializer(initializer)
+                count += 1
+                logger.debug(
+                    "Lifted initializer '%s' from subgraph '%s' to main graph",
+                    initializer.name,
+                    graph.name,
+                )
+        return ir.passes.PassResult(model, modified=bool(count))
+
+
+class RemoveInitializersFromInputsPass(ir.passes.InPlacePass):
+    """Remove initializers from inputs.
+
+    This pass finds all graph inputs that are also initializers and removes them from the graph.inputs list.
+    """
+
+    def call(self, model: ir.Model) -> ir.passes.PassResult:
+        count = 0
+        for graph in model.graphs():
+            initializers = set(graph.initializers.values())
+            new_inputs = []
+            for input_value in graph.inputs:
+                if input_value in initializers:
+                    count += 1
+                else:
+                    new_inputs.append(input_value)
+            graph.inputs.clear()
+            graph.inputs.extend(new_inputs)
+        logger.info("Removed %s initializers from graph inputs", count)
+        return ir.passes.PassResult(model, modified=bool(count))
+
+
+class AddInitializersToInputsPass(ir.passes.InPlacePass):
+    """Add initializers to inputs.
+
+    This pass finds all initializers and adds them to the graph.inputs list if they are not already present.
+    """
+
+    def call(self, model: ir.Model) -> ir.passes.PassResult:
+        count = 0
+        for graph in model.graphs():
+            inputs_set = set(graph.inputs)
+            for initializer in graph.initializers.values():
+                if initializer not in inputs_set:
+                    graph.inputs.append(initializer)
+                    count += 1
+        logger.info("Added %s initializers to graph inputs", count)
+        return ir.passes.PassResult(model, modified=bool(count))
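Note: a sketch composing the constant-manipulation passes above with Sequential (assumed re-exported as ir.passes.Sequential per the __all__ in _pass_infra.py); ``model`` is assumed loaded elsewhere, and the size_limit of 1 is an illustrative choice:

import onnx_ir as ir
from onnx_ir.passes.common import (
    LiftConstantsToInitializersPass,
    RemoveInitializersFromInputsPass,
)

pipeline = ir.passes.Sequential(
    LiftConstantsToInitializersPass(lift_all_constants=True, size_limit=1),
    RemoveInitializersFromInputsPass(),
)
result = pipeline(model)
print(result.modified)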