onnx-ir 0.1.15__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (53) hide show
  1. onnx_ir/__init__.py +176 -0
  2. onnx_ir/_cloner.py +229 -0
  3. onnx_ir/_convenience/__init__.py +558 -0
  4. onnx_ir/_convenience/_constructors.py +291 -0
  5. onnx_ir/_convenience/_extractor.py +191 -0
  6. onnx_ir/_core.py +4435 -0
  7. onnx_ir/_display.py +54 -0
  8. onnx_ir/_enums.py +474 -0
  9. onnx_ir/_graph_comparison.py +23 -0
  10. onnx_ir/_graph_containers.py +373 -0
  11. onnx_ir/_io.py +133 -0
  12. onnx_ir/_linked_list.py +284 -0
  13. onnx_ir/_metadata.py +45 -0
  14. onnx_ir/_name_authority.py +72 -0
  15. onnx_ir/_polyfill.py +26 -0
  16. onnx_ir/_protocols.py +627 -0
  17. onnx_ir/_safetensors/__init__.py +510 -0
  18. onnx_ir/_tape.py +242 -0
  19. onnx_ir/_thirdparty/asciichartpy.py +310 -0
  20. onnx_ir/_type_casting.py +89 -0
  21. onnx_ir/_version_utils.py +48 -0
  22. onnx_ir/analysis/__init__.py +21 -0
  23. onnx_ir/analysis/_implicit_usage.py +74 -0
  24. onnx_ir/convenience.py +38 -0
  25. onnx_ir/external_data.py +459 -0
  26. onnx_ir/passes/__init__.py +41 -0
  27. onnx_ir/passes/_pass_infra.py +351 -0
  28. onnx_ir/passes/common/__init__.py +54 -0
  29. onnx_ir/passes/common/_c_api_utils.py +76 -0
  30. onnx_ir/passes/common/clear_metadata_and_docstring.py +60 -0
  31. onnx_ir/passes/common/common_subexpression_elimination.py +207 -0
  32. onnx_ir/passes/common/constant_manipulation.py +230 -0
  33. onnx_ir/passes/common/default_attributes.py +99 -0
  34. onnx_ir/passes/common/identity_elimination.py +120 -0
  35. onnx_ir/passes/common/initializer_deduplication.py +179 -0
  36. onnx_ir/passes/common/inliner.py +223 -0
  37. onnx_ir/passes/common/naming.py +280 -0
  38. onnx_ir/passes/common/onnx_checker.py +57 -0
  39. onnx_ir/passes/common/output_fix.py +141 -0
  40. onnx_ir/passes/common/shape_inference.py +112 -0
  41. onnx_ir/passes/common/topological_sort.py +37 -0
  42. onnx_ir/passes/common/unused_removal.py +215 -0
  43. onnx_ir/py.typed +1 -0
  44. onnx_ir/serde.py +2043 -0
  45. onnx_ir/tape.py +15 -0
  46. onnx_ir/tensor_adapters.py +210 -0
  47. onnx_ir/testing.py +197 -0
  48. onnx_ir/traversal.py +118 -0
  49. onnx_ir-0.1.15.dist-info/METADATA +68 -0
  50. onnx_ir-0.1.15.dist-info/RECORD +53 -0
  51. onnx_ir-0.1.15.dist-info/WHEEL +5 -0
  52. onnx_ir-0.1.15.dist-info/licenses/LICENSE +202 -0
  53. onnx_ir-0.1.15.dist-info/top_level.txt +1 -0
@@ -0,0 +1,351 @@
1
+ # Copyright (c) ONNX Project Contributors
2
+ # SPDX-License-Identifier: Apache-2.0
3
+ #
4
+ # This module implements some APIs described in
5
+ # https://pytorch.org/executorch/stable/compiler-custom-compiler-passes.html
6
+ # for the ONNX IR.
7
+ # The classes {PassResult and PassManager} are derived from
8
+ # https://github.com/pytorch/pytorch/blob/1e47c7b11b312b47a621efd547f5c90081f0d9cb/torch/fx/passes/infra/pass_base.py#L12
9
+ # and
10
+ # https://github.com/pytorch/pytorch/blob/1e47c7b11b312b47a621efd547f5c90081f0d9cb/torch/fx/passes/infra/pass_manager.py#L147
11
+ # The original code is licensed under the PyTorch License https://github.com/pytorch/pytorch/blob/main/LICENSE
12
+
13
+ """Passes infrastructure for the IR."""
14
+
15
+ from __future__ import annotations
16
+
17
+ import dataclasses
18
+ import logging
19
+ from collections.abc import Sequence
20
+ from typing import Literal, final
21
+
22
+ __all__ = [
23
+ "PassBase",
24
+ "Sequential",
25
+ "InPlacePass",
26
+ "FunctionalPass",
27
+ "PassManager",
28
+ "PassResult",
29
+ "functionalize",
30
+ # Errors
31
+ "InvariantError",
32
+ "PreconditionError",
33
+ "PostconditionError",
34
+ "PassError",
35
+ ]
36
+
37
+ import abc
38
+
39
+ import onnx_ir as ir
40
+
41
+ logger = logging.getLogger(__name__)
42
+
43
+
44
class InvariantError(Exception):
    """Signals that an internal invariant of a pass has been violated."""
46
+
47
+
48
class PreconditionError(InvariantError):
    """Signals that a pass's pre-condition did not hold."""
50
+
51
+
52
class PostconditionError(InvariantError):
    """Signals that a pass's post-condition did not hold."""
54
+
55
+
56
class PassError(RuntimeError):
    """Signals that something went wrong while a pass was executing."""
58
+
59
+
60
@dataclasses.dataclass
class PassResult:
    """Outcome returned by every pass.

    Attributes:
        model: The model produced by the pass (may be the input object for
            in-place passes).
        modified: True when the returned model differs from the input model.
    """

    model: ir.Model
    modified: bool
71
+
72
+
73
class PassBase(abc.ABC):
    """Base class for all passes.

    ``in_place`` and ``changes_input`` properties and what they mean:

    +------------+------------------+----------------------------+
    |            | changes_inputs   | not changes_inputs         |
    +------------+------------------+----------------------------+
    | in_place   | in place         | Side-effect-only pass      |
    +------------+------------------+----------------------------+
    | not        | destructive      | functional                 |
    | in_place   |                  |                            |
    +------------+------------------+----------------------------+
    """

    @property
    @abc.abstractmethod
    def in_place(self) -> bool:
        """Whether the pass modifies the model in place and returns it.

        If True, the pass will return the same model object that was passed in.
        If False, the pass will return a new model object.
        """
        raise NotImplementedError

    @property
    @abc.abstractmethod
    def changes_input(self) -> bool:
        """Whether the pass modifies input model."""
        raise NotImplementedError

    @property
    def destructive(self) -> bool:
        """Whether the pass will destroy the input model when ``in_place=False``.

        A pass is destructive if it is not in place and it modifies the input model.
        """
        return not self.in_place and self.changes_input

    def __call__(self, model_or_result: ir.Model | PassResult, /) -> PassResult:
        """Run the pass with pre/post-condition checks and result validation.

        Accepts either a model or a previous :class:`PassResult` so passes can
        be chained directly.

        Raises:
            PreconditionError: If :meth:`requires` fails.
            PostconditionError: If :meth:`ensures` fails.
            TypeError: If :meth:`call` does not return a :class:`PassResult`.
            PassError: If the declared ``in_place`` property is not respected.
        """
        if isinstance(model_or_result, PassResult):
            model = model_or_result.model
        else:
            model = model_or_result
        # Check preconditions
        try:
            self.requires(model)
        except PreconditionError:
            raise
        except Exception as e:
            raise PreconditionError(
                f"Pre-condition for pass '{self.__class__.__name__}' failed"
            ) from e

        result = self.call(model)

        # Validate the result type *before* dereferencing ``result.model``.
        # (Previously the postcondition check ran first, so a pass returning a
        # non-PassResult raised an opaque AttributeError instead of this TypeError.)
        if not isinstance(result, PassResult):
            raise TypeError(
                f"The result of the pass '{self.__class__.__name__}' should be type PassResult. "
                "Please create one with ir.passes.PassResult()."
            )

        # Check postconditions
        try:
            self.ensures(result.model)
        except PostconditionError:
            raise
        except Exception as e:
            raise PostconditionError(
                f"Post-condition for pass '{self.__class__.__name__}' failed"
            ) from e

        # Checks that the declared in-place property is respected
        if self.in_place and result.model is not model:
            raise PassError(
                f"The pass '{self.__class__.__name__}' is declared in-place, "
                "but the model returned is *not* the same object as the input model. "
                "Pass developer: Pass should return the same model object or the in_place property should return False."
            )
        if not self.in_place and result.model is model:
            raise PassError(
                f"The pass '{self.__class__.__name__}' is declared not in-place, "
                "but the model returned *is* the same object as the input model. "
                "Pass developer: Pass should return a new model object or the in_place property should return True."
            )
        return result

    @abc.abstractmethod
    def call(self, model: ir.Model) -> PassResult:
        """The main entry point for the pass."""
        ...

    def requires(self, model: ir.Model) -> None:
        """Pre-conditions for the pass.

        This is optional to implement, will be called before call() if run by a pass manager.
        """
        del model  # Unused

    def ensures(self, model: ir.Model) -> None:
        """Post-conditions for the pass.

        This is optional to implement, will be called after call() if run by a pass manager.
        """
        del model  # Unused
178
+
179
+
180
class InPlacePass(PassBase):
    """Convenience base for passes that mutate and return the input model."""

    @property
    @final
    def in_place(self) -> Literal[True]:
        """Always True: the input model object is returned."""
        return True

    @property
    @final
    def changes_input(self) -> Literal[True]:
        """Always True: mutating in place necessarily changes the input."""
        return True
194
+
195
+
196
class FunctionalPass(PassBase):
    """Convenience base for passes that leave the input untouched and return a new model."""

    @property
    @final
    def in_place(self) -> Literal[False]:
        """Always False: a fresh model object is returned."""
        return False

    @property
    @final
    def changes_input(self) -> Literal[False]:
        """Always False: the caller's model is never mutated."""
        return False
210
+
211
+
212
class Sequential(PassBase):
    """Run a sequence of passes in order, feeding each result into the next.

    Example::
        import onnx_ir as ir
        import onnx_ir.passes.common as common_passes

        passes = ir.passes.Sequential(
            common_passes.DeduplicateHashedInitializersPass(size_limit=1024 * 1024),
            common_passes.CommonSubexpressionEliminationPass(),
            common_passes.ClearMetadataAndDocStringPass(),
        )
        result = passes(model)
    """

    def __init__(self, *passes: PassBase):
        """Store the passes and pre-compute the combined pass properties."""
        if not passes:
            raise ValueError("Sequential must take at least one pass")
        self.passes = passes
        # The whole sequence is in place only when every member pass is.
        self._in_place = all(member.in_place for member in passes)
        # The reason changes_inputs is decided by the first pass is that if the
        # first pass is either in-place, or if it is not designed to be
        # in-place but somehow changes the input (destructive), this pass
        # sequence will change inputs.
        first = self.passes[0]
        self._changes_input = first.changes_input or first.in_place

    @property
    def in_place(self) -> bool:
        """True only when every pass in the sequence is in place."""
        return self._in_place

    @property
    def changes_input(self) -> bool:
        """Whether running the sequence mutates the caller's model."""
        return self._changes_input

    def call(self, model: ir.Model) -> PassResult:
        """Run each pass in order, threading the model through the sequence."""
        modified = False
        for i, pass_ in enumerate(self.passes):
            logger.debug("Running the %s-th pass '%s'", i, pass_)
            try:
                pass_result = pass_(model)
            except Exception as e:
                # Name the passes that already ran to make debugging easier.
                prev_pass_names = [str(p) for p in self.passes[:i]]
                raise PassError(
                    f"An error occurred when running the '{pass_}' pass after the "
                    f"following passes: {prev_pass_names}"
                ) from e
            model = pass_result.model
            modified = pass_result.modified or modified
        return PassResult(model, modified)
262
+
263
+
264
class PassManager(Sequential):
    """A :class:`Sequential` that can repeat its passes and stop early.

    Example::
        import onnx_ir as ir
        import onnx_ir.passes.common as common_passes

        model = ir.load("model.onnx")
        passes = ir.passes.PassManager(
            [
                # Pass managers can be nested
                ir.passes.PassManager(
                    [
                        common_passes.DeduplicateHashedInitializersPass(size_limit=1024 * 1024),
                        common_passes.CommonSubexpressionEliminationPass(),
                    ],
                    steps=2,
                    early_stop=True,
                ),
                common_passes.ClearMetadataAndDocStringPass(),
            ],
            steps=2,
            early_stop=False,
        )

        # Apply the passes to the model
        result = passes(model)

    Attributes:
        passes: The passes to run.
        steps: The number of times to run the passes.
        early_stop: Whether to stop running the passes if the graph stops changing.
    """

    def __init__(
        self,
        passes: Sequence[PassBase],
        steps: int = 1,
        early_stop: bool = True,
    ):
        # TODO(justinchuby): Implement constraints
        super().__init__(*passes)
        self.steps = steps
        self.early_stop = early_stop

    def call(self, model: ir.Model) -> PassResult:
        """Run the set of passes `steps` number of times or until the graph stops changing."""
        overall_modified = False
        for step in range(self.steps):
            try:
                # Delegate one full iteration to Sequential.call
                step_result = super().call(model)
            except Exception as e:
                raise PassError(f"An error occurred at step {step}") from e
            model = step_result.model
            overall_modified = step_result.modified or overall_modified
            # A fixed point was reached; further iterations would be no-ops.
            if self.early_stop and not step_result.modified:
                logger.info("PassManager: No more graph changes detected after step %s", step)
                break
        return PassResult(model, overall_modified)
328
+
329
+
330
class _FunctionalPassWrapper(FunctionalPass):
    """Adapter that runs another pass on a clone of the input model."""

    def __init__(self, inner_pass: PassBase) -> None:
        self._inner_pass = inner_pass

    def call(self, model: ir.Model) -> PassResult:
        # Clone first so the wrapped pass can never mutate the caller's model.
        cloned = model.clone()
        return self._inner_pass(cloned)
336
+
337
+
338
def functionalize(pass_instance: PassBase) -> FunctionalPass:
    """Wrap a pass so it behaves functionally.

    The returned pass clones the input model before delegating to
    ``pass_instance``, so the caller's model is never mutated.

    .. versionadded:: 0.1.14

    Args:
        pass_instance: The pass to convert.

    Returns:
        A functional pass.
    """
    return _FunctionalPassWrapper(pass_instance)
@@ -0,0 +1,54 @@
1
+ # Copyright (c) ONNX Project Contributors
2
+ # SPDX-License-Identifier: Apache-2.0
3
+
4
+ __all__ = [
5
+ "AddDefaultAttributesPass",
6
+ "AddInitializersToInputsPass",
7
+ "CheckerPass",
8
+ "ClearMetadataAndDocStringPass",
9
+ "CommonSubexpressionEliminationPass",
10
+ "DeduplicateHashedInitializersPass",
11
+ "DeduplicateInitializersPass",
12
+ "IdentityEliminationPass",
13
+ "InlinePass",
14
+ "LiftConstantsToInitializersPass",
15
+ "LiftSubgraphInitializersToMainGraphPass",
16
+ "NameFixPass",
17
+ "OutputFixPass",
18
+ "RemoveInitializersFromInputsPass",
19
+ "RemoveUnusedFunctionsPass",
20
+ "RemoveUnusedNodesPass",
21
+ "RemoveUnusedOpsetsPass",
22
+ "ShapeInferencePass",
23
+ "TopologicalSortPass",
24
+ ]
25
+
26
+ from onnx_ir.passes.common.clear_metadata_and_docstring import (
27
+ ClearMetadataAndDocStringPass,
28
+ )
29
+ from onnx_ir.passes.common.common_subexpression_elimination import (
30
+ CommonSubexpressionEliminationPass,
31
+ )
32
+ from onnx_ir.passes.common.constant_manipulation import (
33
+ AddInitializersToInputsPass,
34
+ LiftConstantsToInitializersPass,
35
+ LiftSubgraphInitializersToMainGraphPass,
36
+ RemoveInitializersFromInputsPass,
37
+ )
38
+ from onnx_ir.passes.common.default_attributes import AddDefaultAttributesPass
39
+ from onnx_ir.passes.common.identity_elimination import IdentityEliminationPass
40
+ from onnx_ir.passes.common.initializer_deduplication import (
41
+ DeduplicateHashedInitializersPass,
42
+ DeduplicateInitializersPass,
43
+ )
44
+ from onnx_ir.passes.common.inliner import InlinePass
45
+ from onnx_ir.passes.common.naming import NameFixPass
46
+ from onnx_ir.passes.common.onnx_checker import CheckerPass
47
+ from onnx_ir.passes.common.output_fix import OutputFixPass
48
+ from onnx_ir.passes.common.shape_inference import ShapeInferencePass
49
+ from onnx_ir.passes.common.topological_sort import TopologicalSortPass
50
+ from onnx_ir.passes.common.unused_removal import (
51
+ RemoveUnusedFunctionsPass,
52
+ RemoveUnusedNodesPass,
53
+ RemoveUnusedOpsetsPass,
54
+ )
@@ -0,0 +1,76 @@
1
+ # Copyright (c) ONNX Project Contributors
2
+ # SPDX-License-Identifier: Apache-2.0
3
+ """Utilities for interfacing with onnx C APIs."""
4
+
5
+ from __future__ import annotations
6
+
7
+ import logging
8
+ from typing import TYPE_CHECKING, Callable, TypeVar
9
+
10
+ import onnx_ir as ir
11
+
12
+ if TYPE_CHECKING:
13
+ import onnx # noqa: TID251
14
+
15
+
16
+ logger = logging.getLogger(__name__)
17
+ # Temporarily remove initializers larger than this size to keep model size down
18
+ # for the onnx.shape_inference call because it needs to serialize the model
19
+ _BIG_TENSOR_SIZE_LIMIT = 1000 # 1KB
20
+ _R = TypeVar("_R")
21
+
22
+
23
def call_onnx_api(func: Callable[[onnx.ModelProto], _R], model: ir.Model) -> _R:
    """Call an ONNX C API function by temporarily removing initializers.

    This is necessary because the ONNX C API does not support large models
    with initializers that have large tensor values. The input model is left
    unchanged whether or not the call succeeds.

    Args:
        func: Partially applied function that takes a model proto and returns anything.
        model: The IR model to pass to the API function.

    Returns:
        The resulting ModelProto that contains the result of the API call.
    """
    # Store the original initializer values so they can be restored
    initializer_values = tuple(model.graph.initializers.values())
    # Keyed by initializer name; used in the finally-block to restore const_value.
    tensors = {v.name: v.const_value for v in initializer_values}
    # Remember how many graph inputs existed so appended initializers can be dropped later.
    original_inputs_len = len(model.graph.inputs)

    # Turn the initializers into inputs and clear the initializers
    # to limit the model size
    for initializer in initializer_values:
        # Make sure the initializer has its shape/type set
        assert initializer.const_value is not None
        if initializer.shape is None:
            initializer.shape = initializer.const_value.shape  # type: ignore[assignment]
        if initializer.dtype is None:
            initializer.dtype = initializer.const_value.dtype
        # Promote to a graph input so the serialized proto still declares the
        # value's type/shape even after its tensor data is dropped below.
        if initializer not in model.graph.inputs:
            model.graph.inputs.append(initializer)
        if initializer.const_value.nbytes > _BIG_TENSOR_SIZE_LIMIT:
            # Temporarily remove the initializer value to reduce model size
            # for onnx.shape_inference
            initializer.const_value = None
            assert initializer.name is not None
            model.graph.initializers.pop(initializer.name)

    # Serialize AFTER stripping the big tensors so the proto stays small.
    proto = ir.serde.serialize_model(model)

    try:
        # Call the ONNX C API function
        result = func(proto)
    finally:
        # Restore the original initializer values so the model is unchanged.
        # NOTE(review): register_initializer is called for every initializer,
        # including the small ones that were never popped — presumably it
        # replaces/no-ops for entries already present; confirm against the
        # graph container implementation.
        for initializer in initializer_values:
            initializer.const_value = tensors[initializer.name]
            model.graph.register_initializer(initializer)

        # Restore the original inputs (drop the initializers appended above).
        inputs = model.graph.inputs[:original_inputs_len]
        model.graph.inputs.clear()
        model.graph.inputs.extend(inputs)

    return result
@@ -0,0 +1,60 @@
1
+ # Copyright (c) ONNX Project Contributors
2
+ # SPDX-License-Identifier: Apache-2.0
3
+ """Clear all metadata and docstring from the model, graphs, nodes, and functions."""
4
+
5
+ from __future__ import annotations
6
+
7
+ __all__ = [
8
+ "ClearMetadataAndDocStringPass",
9
+ ]
10
+
11
+ import logging
12
+
13
+ import onnx_ir as ir
14
+
15
+ logger = logging.getLogger(__name__)
16
+
17
+
18
+ class ClearMetadataAndDocStringPass(ir.passes.InPlacePass):
19
+ """Clear all metadata and docstring from the model, graphs, nodes, and functions."""
20
+
21
+ def call(self, model: ir.Model) -> ir.passes.PassResult:
22
+ # 0. TODO: Should we clean model metadata and docstring?
23
+
24
+ # 1. Clean up the graph and the belonged nodes metadata properties
25
+ modified = self._clear_graph_or_function_metadata_and_docstring(model.graph)
26
+
27
+ # 2. Clean up all of the functions metadata properties
28
+ for function in model.functions.values():
29
+ modified = (
30
+ self._clear_graph_or_function_metadata_and_docstring(function) or modified
31
+ )
32
+ return ir.passes.PassResult(model, modified=modified)
33
+
34
+ def _clear_graph_or_function_metadata_and_docstring(
35
+ self,
36
+ graph_or_function: ir.Graph | ir.Function,
37
+ ) -> bool:
38
+ """Clear metadata and docstring from the graph or function."""
39
+ checked_graphs_or_functions: set[ir.Graph | ir.Function] = set()
40
+ modified = False
41
+ # Clean up all of the nodes metadata properties
42
+ for node in ir.traversal.RecursiveGraphIterator(graph_or_function):
43
+ if node.metadata_props:
44
+ modified = True
45
+ logger.debug("Removed metadata from %s nodes", node.name)
46
+ node.metadata_props.clear()
47
+ node.doc_string = None
48
+
49
+ # Clean up the owning graph/function metadata properties
50
+ # and doc_string if the graph/function is not already checked
51
+ assert node.graph is not None
52
+ if node.graph not in checked_graphs_or_functions and (
53
+ node.graph.metadata_props or node.graph.doc_string
54
+ ):
55
+ modified = True
56
+ logger.debug("Removed metadata from %s graph/function", node.graph.name)
57
+ node.graph.metadata_props.clear()
58
+ node.graph.doc_string = None
59
+ checked_graphs_or_functions.add(node.graph)
60
+ return modified