onnx-ir 0.1.5__tar.gz → 0.1.7__tar.gz
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of onnx-ir might be problematic. Click here for more details.
- {onnx_ir-0.1.5/src/onnx_ir.egg-info → onnx_ir-0.1.7}/PKG-INFO +2 -4
- {onnx_ir-0.1.5 → onnx_ir-0.1.7}/README.md +0 -1
- {onnx_ir-0.1.5 → onnx_ir-0.1.7}/pyproject.toml +3 -3
- {onnx_ir-0.1.5 → onnx_ir-0.1.7}/src/onnx_ir/__init__.py +1 -1
- {onnx_ir-0.1.5 → onnx_ir-0.1.7}/src/onnx_ir/_convenience/__init__.py +49 -32
- {onnx_ir-0.1.5 → onnx_ir-0.1.7}/src/onnx_ir/_core.py +60 -16
- {onnx_ir-0.1.5 → onnx_ir-0.1.7}/src/onnx_ir/passes/common/__init__.py +4 -0
- {onnx_ir-0.1.5 → onnx_ir-0.1.7}/src/onnx_ir/passes/common/constant_manipulation.py +15 -7
- {onnx_ir-0.1.5 → onnx_ir-0.1.7}/src/onnx_ir/passes/common/identity_elimination.py +1 -0
- onnx_ir-0.1.7/src/onnx_ir/passes/common/initializer_deduplication.py +167 -0
- onnx_ir-0.1.7/src/onnx_ir/passes/common/naming.py +286 -0
- {onnx_ir-0.1.5 → onnx_ir-0.1.7}/src/onnx_ir/serde.py +94 -34
- {onnx_ir-0.1.5 → onnx_ir-0.1.7}/src/onnx_ir/tensor_adapters.py +14 -2
- {onnx_ir-0.1.5 → onnx_ir-0.1.7}/src/onnx_ir/traversal.py +35 -0
- {onnx_ir-0.1.5 → onnx_ir-0.1.7/src/onnx_ir.egg-info}/PKG-INFO +2 -4
- {onnx_ir-0.1.5 → onnx_ir-0.1.7}/src/onnx_ir.egg-info/SOURCES.txt +1 -0
- onnx_ir-0.1.5/src/onnx_ir/passes/common/initializer_deduplication.py +0 -56
- {onnx_ir-0.1.5 → onnx_ir-0.1.7}/LICENSE +0 -0
- {onnx_ir-0.1.5 → onnx_ir-0.1.7}/MANIFEST.in +0 -0
- {onnx_ir-0.1.5 → onnx_ir-0.1.7}/setup.cfg +0 -0
- {onnx_ir-0.1.5 → onnx_ir-0.1.7}/src/onnx_ir/_convenience/_constructors.py +0 -0
- {onnx_ir-0.1.5 → onnx_ir-0.1.7}/src/onnx_ir/_display.py +0 -0
- {onnx_ir-0.1.5 → onnx_ir-0.1.7}/src/onnx_ir/_enums.py +0 -0
- {onnx_ir-0.1.5 → onnx_ir-0.1.7}/src/onnx_ir/_graph_comparison.py +0 -0
- {onnx_ir-0.1.5 → onnx_ir-0.1.7}/src/onnx_ir/_graph_containers.py +0 -0
- {onnx_ir-0.1.5 → onnx_ir-0.1.7}/src/onnx_ir/_io.py +0 -0
- {onnx_ir-0.1.5 → onnx_ir-0.1.7}/src/onnx_ir/_linked_list.py +0 -0
- {onnx_ir-0.1.5 → onnx_ir-0.1.7}/src/onnx_ir/_metadata.py +0 -0
- {onnx_ir-0.1.5 → onnx_ir-0.1.7}/src/onnx_ir/_name_authority.py +0 -0
- {onnx_ir-0.1.5 → onnx_ir-0.1.7}/src/onnx_ir/_polyfill.py +0 -0
- {onnx_ir-0.1.5 → onnx_ir-0.1.7}/src/onnx_ir/_protocols.py +0 -0
- {onnx_ir-0.1.5 → onnx_ir-0.1.7}/src/onnx_ir/_tape.py +0 -0
- {onnx_ir-0.1.5 → onnx_ir-0.1.7}/src/onnx_ir/_thirdparty/asciichartpy.py +0 -0
- {onnx_ir-0.1.5 → onnx_ir-0.1.7}/src/onnx_ir/_type_casting.py +0 -0
- {onnx_ir-0.1.5 → onnx_ir-0.1.7}/src/onnx_ir/_version_utils.py +0 -0
- {onnx_ir-0.1.5 → onnx_ir-0.1.7}/src/onnx_ir/convenience.py +0 -0
- {onnx_ir-0.1.5 → onnx_ir-0.1.7}/src/onnx_ir/external_data.py +0 -0
- {onnx_ir-0.1.5 → onnx_ir-0.1.7}/src/onnx_ir/passes/__init__.py +0 -0
- {onnx_ir-0.1.5 → onnx_ir-0.1.7}/src/onnx_ir/passes/_pass_infra.py +0 -0
- {onnx_ir-0.1.5 → onnx_ir-0.1.7}/src/onnx_ir/passes/common/_c_api_utils.py +0 -0
- {onnx_ir-0.1.5 → onnx_ir-0.1.7}/src/onnx_ir/passes/common/clear_metadata_and_docstring.py +0 -0
- {onnx_ir-0.1.5 → onnx_ir-0.1.7}/src/onnx_ir/passes/common/common_subexpression_elimination.py +0 -0
- {onnx_ir-0.1.5 → onnx_ir-0.1.7}/src/onnx_ir/passes/common/inliner.py +0 -0
- {onnx_ir-0.1.5 → onnx_ir-0.1.7}/src/onnx_ir/passes/common/onnx_checker.py +0 -0
- {onnx_ir-0.1.5 → onnx_ir-0.1.7}/src/onnx_ir/passes/common/shape_inference.py +0 -0
- {onnx_ir-0.1.5 → onnx_ir-0.1.7}/src/onnx_ir/passes/common/topological_sort.py +0 -0
- {onnx_ir-0.1.5 → onnx_ir-0.1.7}/src/onnx_ir/passes/common/unused_removal.py +0 -0
- {onnx_ir-0.1.5 → onnx_ir-0.1.7}/src/onnx_ir/py.typed +0 -0
- {onnx_ir-0.1.5 → onnx_ir-0.1.7}/src/onnx_ir/tape.py +0 -0
- {onnx_ir-0.1.5 → onnx_ir-0.1.7}/src/onnx_ir/testing.py +0 -0
- {onnx_ir-0.1.5 → onnx_ir-0.1.7}/src/onnx_ir.egg-info/dependency_links.txt +0 -0
- {onnx_ir-0.1.5 → onnx_ir-0.1.7}/src/onnx_ir.egg-info/requires.txt +0 -0
- {onnx_ir-0.1.5 → onnx_ir-0.1.7}/src/onnx_ir.egg-info/top_level.txt +0 -0
|
@@ -1,9 +1,9 @@
|
|
|
1
1
|
Metadata-Version: 2.4
|
|
2
2
|
Name: onnx-ir
|
|
3
|
-
Version: 0.1.5
|
|
3
|
+
Version: 0.1.7
|
|
4
4
|
Summary: Efficient in-memory representation for ONNX
|
|
5
5
|
Author-email: ONNX Contributors <onnx-technical-discuss@lists.lfaidata.foundation>
|
|
6
|
-
License: Apache
|
|
6
|
+
License-Expression: Apache-2.0
|
|
7
7
|
Project-URL: Homepage, https://onnx.ai/ir-py
|
|
8
8
|
Project-URL: Issues, https://github.com/onnx/ir-py/issues
|
|
9
9
|
Project-URL: Repository, https://github.com/onnx/ir-py
|
|
@@ -13,7 +13,6 @@ Classifier: Programming Language :: Python :: 3.10
|
|
|
13
13
|
Classifier: Programming Language :: Python :: 3.11
|
|
14
14
|
Classifier: Programming Language :: Python :: 3.12
|
|
15
15
|
Classifier: Programming Language :: Python :: 3.13
|
|
16
|
-
Classifier: License :: OSI Approved :: Apache Software License
|
|
17
16
|
Requires-Python: >=3.9
|
|
18
17
|
Description-Content-Type: text/markdown
|
|
19
18
|
License-File: LICENSE
|
|
@@ -29,7 +28,6 @@ Dynamic: license-file
|
|
|
29
28
|
[](https://pypi.org/project/onnx-ir)
|
|
30
29
|
[](https://github.com/astral-sh/ruff)
|
|
31
30
|
[](https://codecov.io/gh/onnx/ir-py)
|
|
32
|
-
[](https://deepwiki.com/onnx/ir-py)
|
|
33
31
|
[](https://pepy.tech/projects/onnx-ir)
|
|
34
32
|
|
|
35
33
|
An in-memory IR that supports the full ONNX spec, designed for graph construction, analysis and transformation.
|
|
@@ -4,7 +4,6 @@
|
|
|
4
4
|
[](https://pypi.org/project/onnx-ir)
|
|
5
5
|
[](https://github.com/astral-sh/ruff)
|
|
6
6
|
[](https://codecov.io/gh/onnx/ir-py)
|
|
7
|
-
[](https://deepwiki.com/onnx/ir-py)
|
|
8
7
|
[](https://pepy.tech/projects/onnx-ir)
|
|
9
8
|
|
|
10
9
|
An in-memory IR that supports the full ONNX spec, designed for graph construction, analysis and transformation.
|
|
@@ -1,5 +1,5 @@
|
|
|
1
1
|
[build-system]
|
|
2
|
-
requires = ["setuptools>=
|
|
2
|
+
requires = ["setuptools>=77"]
|
|
3
3
|
build-backend = "setuptools.build_meta"
|
|
4
4
|
|
|
5
5
|
[project]
|
|
@@ -11,7 +11,8 @@ authors = [
|
|
|
11
11
|
]
|
|
12
12
|
readme = "README.md"
|
|
13
13
|
requires-python = ">=3.9"
|
|
14
|
-
license =
|
|
14
|
+
license = "Apache-2.0"
|
|
15
|
+
license-files = ["LICEN[CS]E*"]
|
|
15
16
|
classifiers = [
|
|
16
17
|
"Development Status :: 4 - Beta",
|
|
17
18
|
"Programming Language :: Python :: 3.9",
|
|
@@ -19,7 +20,6 @@ classifiers = [
|
|
|
19
20
|
"Programming Language :: Python :: 3.11",
|
|
20
21
|
"Programming Language :: Python :: 3.12",
|
|
21
22
|
"Programming Language :: Python :: 3.13",
|
|
22
|
-
"License :: OSI Approved :: Apache Software License",
|
|
23
23
|
]
|
|
24
24
|
dependencies = ["numpy", "onnx>=1.16", "typing_extensions>=4.10", "ml_dtypes"]
|
|
25
25
|
|
|
@@ -58,44 +58,52 @@ def _infer_attribute_type(attr: SupportedAttrTypes) -> _enums.AttributeType:
|
|
|
58
58
|
return _enums.AttributeType.STRING
|
|
59
59
|
if isinstance(attr, _core.Attr):
|
|
60
60
|
return attr.type
|
|
61
|
-
if isinstance(attr, Sequence) and all(isinstance(x, int) for x in attr):
|
|
62
|
-
return _enums.AttributeType.INTS
|
|
63
|
-
if isinstance(attr, Sequence) and all(isinstance(x, float) for x in attr):
|
|
64
|
-
return _enums.AttributeType.FLOATS
|
|
65
|
-
if isinstance(attr, Sequence) and all(isinstance(x, str) for x in attr):
|
|
66
|
-
return _enums.AttributeType.STRINGS
|
|
61
|
+
if isinstance(attr, (_core.Graph, onnx.GraphProto, _protocols.GraphProtocol)):
|
|
62
|
+
return _enums.AttributeType.GRAPH
|
|
67
63
|
if isinstance(attr, (_core.TensorBase, onnx.TensorProto, _protocols.TensorProtocol)):
|
|
68
64
|
# Be sure to check TensorProtocol last because isinstance checking on Protocols can be slower
|
|
69
65
|
return _enums.AttributeType.TENSOR
|
|
70
|
-
if isinstance(attr, Sequence) and all(
|
|
71
|
-
isinstance(x, (_core.TensorBase, onnx.TensorProto, _protocols.TensorProtocol))
|
|
72
|
-
for x in attr
|
|
73
|
-
):
|
|
74
|
-
return _enums.AttributeType.TENSORS
|
|
75
|
-
if isinstance(attr, (_core.Graph, onnx.GraphProto, _protocols.GraphProtocol)):
|
|
76
|
-
return _enums.AttributeType.GRAPH
|
|
77
|
-
if isinstance(attr, Sequence) and all(
|
|
78
|
-
isinstance(x, (_core.Graph, onnx.GraphProto, _protocols.GraphProtocol)) for x in attr
|
|
79
|
-
):
|
|
80
|
-
return _enums.AttributeType.GRAPHS
|
|
81
66
|
if isinstance(
|
|
82
67
|
attr,
|
|
83
68
|
(_core.TensorType, _core.SequenceType, _core.OptionalType, _protocols.TypeProtocol),
|
|
84
69
|
):
|
|
85
70
|
return _enums.AttributeType.TYPE_PROTO
|
|
86
|
-
if isinstance(attr, Sequence)
|
|
87
|
-
|
|
88
|
-
|
|
89
|
-
|
|
90
|
-
|
|
91
|
-
|
|
92
|
-
|
|
93
|
-
|
|
94
|
-
|
|
95
|
-
)
|
|
96
|
-
|
|
97
|
-
|
|
98
|
-
|
|
71
|
+
if isinstance(attr, Sequence):
|
|
72
|
+
if not attr:
|
|
73
|
+
logger.warning(
|
|
74
|
+
"Attribute type is ambiguous because it is an empty sequence. "
|
|
75
|
+
"Please create an Attr with an explicit type. Defaulted to INTS"
|
|
76
|
+
)
|
|
77
|
+
return _enums.AttributeType.INTS
|
|
78
|
+
if all(isinstance(x, int) for x in attr):
|
|
79
|
+
return _enums.AttributeType.INTS
|
|
80
|
+
if all(isinstance(x, float) for x in attr):
|
|
81
|
+
return _enums.AttributeType.FLOATS
|
|
82
|
+
if all(isinstance(x, str) for x in attr):
|
|
83
|
+
return _enums.AttributeType.STRINGS
|
|
84
|
+
if all(
|
|
85
|
+
isinstance(x, (_core.TensorBase, onnx.TensorProto, _protocols.TensorProtocol))
|
|
86
|
+
for x in attr
|
|
87
|
+
):
|
|
88
|
+
return _enums.AttributeType.TENSORS
|
|
89
|
+
if all(
|
|
90
|
+
isinstance(x, (_core.Graph, onnx.GraphProto, _protocols.GraphProtocol))
|
|
91
|
+
for x in attr
|
|
92
|
+
):
|
|
93
|
+
return _enums.AttributeType.GRAPHS
|
|
94
|
+
if all(
|
|
95
|
+
isinstance(
|
|
96
|
+
x,
|
|
97
|
+
(
|
|
98
|
+
_core.TensorType,
|
|
99
|
+
_core.SequenceType,
|
|
100
|
+
_core.OptionalType,
|
|
101
|
+
_protocols.TypeProtocol,
|
|
102
|
+
),
|
|
103
|
+
)
|
|
104
|
+
for x in attr
|
|
105
|
+
):
|
|
106
|
+
return _enums.AttributeType.TYPE_PROTOS
|
|
99
107
|
raise TypeError(f"Unsupported attribute type: '{type(attr)}'")
|
|
100
108
|
|
|
101
109
|
|
|
@@ -218,7 +226,7 @@ def convert_attributes(
|
|
|
218
226
|
... "type_protos": [ir.TensorType(ir.DataType.FLOAT), ir.TensorType(ir.DataType.FLOAT)],
|
|
219
227
|
... }
|
|
220
228
|
>>> convert_attributes(attrs)
|
|
221
|
-
[Attr('int', INT, 1), Attr('float', FLOAT, 1.0), Attr('str', STRING, 'hello'), Attr('ints', INTS, [1, 2, 3]), Attr('floats', FLOATS, [1.0, 2.0, 3.0]), Attr('strings', STRINGS, ['hello', 'world']), Attr('tensor', TENSOR, Tensor<DOUBLE,[3]>(array([1., 2., 3.]), name=None)), Attr('tensor_proto', TENSOR, TensorProtoTensor<FLOAT,[3]>(array([1., 2., 3.], dtype=float32), name='proto')), Attr('graph',
|
|
229
|
+
[Attr('int', INT, 1), Attr('float', FLOAT, 1.0), Attr('str', STRING, 'hello'), Attr('ints', INTS, [1, 2, 3]), Attr('floats', FLOATS, [1.0, 2.0, 3.0]), Attr('strings', STRINGS, ['hello', 'world']), Attr('tensor', TENSOR, Tensor<DOUBLE,[3]>(array([1., 2., 3.]), name=None)), Attr('tensor_proto', TENSOR, TensorProtoTensor<FLOAT,[3]>(array([1., 2., 3.], dtype=float32), name='proto')), Attr('graph', GRAPH, Graph(
|
|
222
230
|
name='graph0',
|
|
223
231
|
inputs=(
|
|
224
232
|
<BLANKLINE>
|
|
@@ -247,11 +255,20 @@ def convert_attributes(
|
|
|
247
255
|
len()=0
|
|
248
256
|
)]), Attr('type_proto', TYPE_PROTO, Tensor(FLOAT)), Attr('type_protos', TYPE_PROTOS, [Tensor(FLOAT), Tensor(FLOAT)])]
|
|
249
257
|
|
|
258
|
+
.. important::
|
|
259
|
+
An empty sequence should be created with an explicit type by initializing
|
|
260
|
+
an Attr object with an attribute type to avoid type ambiguity. For example::
|
|
261
|
+
|
|
262
|
+
ir.Attr("empty", [], type=ir.AttributeType.INTS)
|
|
263
|
+
|
|
250
264
|
Args:
|
|
251
265
|
attrs: A dictionary of {<attribute name>: <python objects>} to convert.
|
|
252
266
|
|
|
253
267
|
Returns:
|
|
254
|
-
A list of _core.Attr objects.
|
|
268
|
+
A list of :class:`_core.Attr` objects.
|
|
269
|
+
|
|
270
|
+
Raises:
|
|
271
|
+
TypeError: If an attribute type is not supported.
|
|
255
272
|
"""
|
|
256
273
|
attributes: list[_core.Attr] = []
|
|
257
274
|
for name, attr in attrs.items():
|
|
@@ -2564,14 +2564,23 @@ class Graph(_protocols.GraphProtocol, Sequence[Node], _display.PrettyPrintable):
|
|
|
2564
2564
|
|
|
2565
2565
|
.. versionadded:: 0.1.2
|
|
2566
2566
|
"""
|
|
2567
|
-
|
|
2568
|
-
|
|
2569
|
-
|
|
2567
|
+
# Use a dict to preserve order
|
|
2568
|
+
seen_graphs: dict[Graph, None] = {}
|
|
2569
|
+
|
|
2570
|
+
# Need to use the enter_graph callback so that empty subgraphs are collected
|
|
2571
|
+
def enter_subgraph(graph) -> None:
|
|
2570
2572
|
if graph is self:
|
|
2571
|
-
|
|
2572
|
-
if
|
|
2573
|
-
|
|
2574
|
-
|
|
2573
|
+
return
|
|
2574
|
+
if not isinstance(graph, Graph):
|
|
2575
|
+
raise TypeError(
|
|
2576
|
+
f"Expected a Graph, got {type(graph)}. The model may be invalid"
|
|
2577
|
+
)
|
|
2578
|
+
if graph not in seen_graphs:
|
|
2579
|
+
seen_graphs[graph] = None
|
|
2580
|
+
|
|
2581
|
+
for _ in onnx_ir.traversal.RecursiveGraphIterator(self, enter_graph=enter_subgraph):
|
|
2582
|
+
pass
|
|
2583
|
+
yield from seen_graphs.keys()
|
|
2575
2584
|
|
|
2576
2585
|
# Mutation methods
|
|
2577
2586
|
def append(self, node: Node, /) -> None:
|
|
@@ -3180,6 +3189,21 @@ class Function(_protocols.FunctionProtocol, Sequence[Node], _display.PrettyPrint
|
|
|
3180
3189
|
def attributes(self) -> _graph_containers.Attributes:
|
|
3181
3190
|
return self._attributes
|
|
3182
3191
|
|
|
3192
|
+
@property
|
|
3193
|
+
def graph(self) -> Graph:
|
|
3194
|
+
"""The underlying Graph object that contains the nodes of this function.
|
|
3195
|
+
|
|
3196
|
+
Only use this graph for identity comparison::
|
|
3197
|
+
|
|
3198
|
+
if value.graph is function.graph:
|
|
3199
|
+
# Do something with the value that belongs to this function
|
|
3200
|
+
|
|
3201
|
+
Otherwise use the Function object directly to access the nodes and other properties.
|
|
3202
|
+
|
|
3203
|
+
.. versionadded:: 0.1.7
|
|
3204
|
+
"""
|
|
3205
|
+
return self._graph
|
|
3206
|
+
|
|
3183
3207
|
@typing.overload
|
|
3184
3208
|
def __getitem__(self, index: int) -> Node: ...
|
|
3185
3209
|
@typing.overload
|
|
@@ -3240,14 +3264,22 @@ class Function(_protocols.FunctionProtocol, Sequence[Node], _display.PrettyPrint
|
|
|
3240
3264
|
|
|
3241
3265
|
.. versionadded:: 0.1.2
|
|
3242
3266
|
"""
|
|
3243
|
-
seen_graphs:
|
|
3244
|
-
|
|
3245
|
-
|
|
3246
|
-
|
|
3247
|
-
|
|
3248
|
-
|
|
3249
|
-
|
|
3250
|
-
|
|
3267
|
+
seen_graphs: dict[Graph, None] = {}
|
|
3268
|
+
|
|
3269
|
+
# Need to use the enter_graph callback so that empty subgraphs are collected
|
|
3270
|
+
def enter_subgraph(graph) -> None:
|
|
3271
|
+
if graph is self:
|
|
3272
|
+
return
|
|
3273
|
+
if not isinstance(graph, Graph):
|
|
3274
|
+
raise TypeError(
|
|
3275
|
+
f"Expected a Graph, got {type(graph)}. The model may be invalid"
|
|
3276
|
+
)
|
|
3277
|
+
if graph not in seen_graphs:
|
|
3278
|
+
seen_graphs[graph] = None
|
|
3279
|
+
|
|
3280
|
+
for _ in onnx_ir.traversal.RecursiveGraphIterator(self, enter_graph=enter_subgraph):
|
|
3281
|
+
pass
|
|
3282
|
+
yield from seen_graphs.keys()
|
|
3251
3283
|
|
|
3252
3284
|
# Mutation methods
|
|
3253
3285
|
def append(self, node: Node, /) -> None:
|
|
@@ -3349,7 +3381,7 @@ class Attr(
|
|
|
3349
3381
|
):
|
|
3350
3382
|
"""Base class for ONNX attributes or references."""
|
|
3351
3383
|
|
|
3352
|
-
__slots__ = ("_name", "_ref_attr_name", "_type", "_value", "doc_string")
|
|
3384
|
+
__slots__ = ("_metadata", "_name", "_ref_attr_name", "_type", "_value", "doc_string")
|
|
3353
3385
|
|
|
3354
3386
|
def __init__(
|
|
3355
3387
|
self,
|
|
@@ -3365,6 +3397,7 @@ class Attr(
|
|
|
3365
3397
|
self._value = value
|
|
3366
3398
|
self._ref_attr_name = ref_attr_name
|
|
3367
3399
|
self.doc_string = doc_string
|
|
3400
|
+
self._metadata: _metadata.MetadataStore | None = None
|
|
3368
3401
|
|
|
3369
3402
|
@property
|
|
3370
3403
|
def name(self) -> str:
|
|
@@ -3386,6 +3419,17 @@ class Attr(
|
|
|
3386
3419
|
def ref_attr_name(self) -> str | None:
|
|
3387
3420
|
return self._ref_attr_name
|
|
3388
3421
|
|
|
3422
|
+
@property
|
|
3423
|
+
def meta(self) -> _metadata.MetadataStore:
|
|
3424
|
+
"""The metadata store for intermediate analysis.
|
|
3425
|
+
|
|
3426
|
+
Write to the :attr:`metadata_props` if you would like the metadata to be serialized
|
|
3427
|
+
to the ONNX proto.
|
|
3428
|
+
"""
|
|
3429
|
+
if self._metadata is None:
|
|
3430
|
+
self._metadata = _metadata.MetadataStore()
|
|
3431
|
+
return self._metadata
|
|
3432
|
+
|
|
3389
3433
|
def is_ref(self) -> bool:
|
|
3390
3434
|
"""Check if this attribute is a reference attribute."""
|
|
3391
3435
|
return self.ref_attr_name is not None
|
|
@@ -6,11 +6,13 @@ __all__ = [
|
|
|
6
6
|
"CheckerPass",
|
|
7
7
|
"ClearMetadataAndDocStringPass",
|
|
8
8
|
"CommonSubexpressionEliminationPass",
|
|
9
|
+
"DeduplicateHashedInitializersPass",
|
|
9
10
|
"DeduplicateInitializersPass",
|
|
10
11
|
"IdentityEliminationPass",
|
|
11
12
|
"InlinePass",
|
|
12
13
|
"LiftConstantsToInitializersPass",
|
|
13
14
|
"LiftSubgraphInitializersToMainGraphPass",
|
|
15
|
+
"NameFixPass",
|
|
14
16
|
"RemoveInitializersFromInputsPass",
|
|
15
17
|
"RemoveUnusedFunctionsPass",
|
|
16
18
|
"RemoveUnusedNodesPass",
|
|
@@ -35,9 +37,11 @@ from onnx_ir.passes.common.identity_elimination import (
|
|
|
35
37
|
IdentityEliminationPass,
|
|
36
38
|
)
|
|
37
39
|
from onnx_ir.passes.common.initializer_deduplication import (
|
|
40
|
+
DeduplicateHashedInitializersPass,
|
|
38
41
|
DeduplicateInitializersPass,
|
|
39
42
|
)
|
|
40
43
|
from onnx_ir.passes.common.inliner import InlinePass
|
|
44
|
+
from onnx_ir.passes.common.naming import NameFixPass
|
|
41
45
|
from onnx_ir.passes.common.onnx_checker import CheckerPass
|
|
42
46
|
from onnx_ir.passes.common.shape_inference import ShapeInferencePass
|
|
43
47
|
from onnx_ir.passes.common.topological_sort import TopologicalSortPass
|
|
@@ -148,6 +148,7 @@ class LiftSubgraphInitializersToMainGraphPass(ir.passes.InPlacePass):
|
|
|
148
148
|
if graph is model.graph:
|
|
149
149
|
continue
|
|
150
150
|
for name in tuple(graph.initializers):
|
|
151
|
+
assert name is not None
|
|
151
152
|
initializer = graph.initializers[name]
|
|
152
153
|
if initializer.is_graph_input():
|
|
153
154
|
# Skip the ones that are also graph inputs
|
|
@@ -156,17 +157,24 @@ class LiftSubgraphInitializersToMainGraphPass(ir.passes.InPlacePass):
|
|
|
156
157
|
initializer.name,
|
|
157
158
|
)
|
|
158
159
|
continue
|
|
160
|
+
if initializer.is_graph_output():
|
|
161
|
+
logger.debug(
|
|
162
|
+
"Initializer '%s' is used as output, so it can't be lifted",
|
|
163
|
+
initializer.name,
|
|
164
|
+
)
|
|
165
|
+
continue
|
|
159
166
|
# Remove the initializer from the subgraph
|
|
160
167
|
graph.initializers.pop(name)
|
|
161
168
|
# To avoid name conflicts, we need to rename the initializer
|
|
162
169
|
# to a unique name in the main graph
|
|
163
|
-
|
|
164
|
-
|
|
165
|
-
|
|
166
|
-
|
|
167
|
-
|
|
168
|
-
|
|
169
|
-
registered_initializer_names[
|
|
170
|
+
new_name = name
|
|
171
|
+
while new_name in model.graph.initializers:
|
|
172
|
+
if name in registered_initializer_names:
|
|
173
|
+
registered_initializer_names[name] += 1
|
|
174
|
+
else:
|
|
175
|
+
registered_initializer_names[name] = 1
|
|
176
|
+
new_name = f"{name}_{registered_initializer_names[name]}"
|
|
177
|
+
initializer.name = new_name
|
|
170
178
|
model.graph.register_initializer(initializer)
|
|
171
179
|
count += 1
|
|
172
180
|
logger.debug(
|
|
@@ -19,6 +19,7 @@ class IdentityEliminationPass(ir.passes.InPlacePass):
|
|
|
19
19
|
"""Pass for eliminating redundant Identity nodes.
|
|
20
20
|
|
|
21
21
|
This pass removes Identity nodes according to the following rules:
|
|
22
|
+
|
|
22
23
|
1. For any node of the form `y = Identity(x)`, where `y` is not an output
|
|
23
24
|
of any graph, replace all uses of `y` with a use of `x`, and remove the node.
|
|
24
25
|
2. If `y` is an output of a graph, and `x` is not an input of any graph,
|
|
@@ -0,0 +1,167 @@
|
|
|
1
|
+
# Copyright (c) ONNX Project Contributors
|
|
2
|
+
# SPDX-License-Identifier: Apache-2.0
|
|
3
|
+
"""Pass for removing duplicated initializer tensors from a graph."""
|
|
4
|
+
|
|
5
|
+
from __future__ import annotations
|
|
6
|
+
|
|
7
|
+
__all__ = ["DeduplicateInitializersPass", "DeduplicateHashedInitializersPass"]
|
|
8
|
+
|
|
9
|
+
|
|
10
|
+
import hashlib
|
|
11
|
+
import logging
|
|
12
|
+
|
|
13
|
+
import onnx_ir as ir
|
|
14
|
+
|
|
15
|
+
logger = logging.getLogger(__name__)
|
|
16
|
+
|
|
17
|
+
|
|
18
|
+
def _should_skip_initializer(initializer: ir.Value, size_limit: int) -> bool:
|
|
19
|
+
"""Check if the initializer should be skipped for deduplication."""
|
|
20
|
+
if initializer.is_graph_input() or initializer.is_graph_output():
|
|
21
|
+
# Skip graph inputs and outputs
|
|
22
|
+
logger.warning(
|
|
23
|
+
"Skipped deduplication of initializer '%s' as it is a graph input or output",
|
|
24
|
+
initializer.name,
|
|
25
|
+
)
|
|
26
|
+
return True
|
|
27
|
+
|
|
28
|
+
const_val = initializer.const_value
|
|
29
|
+
if const_val is None:
|
|
30
|
+
# Skip if initializer has no constant value
|
|
31
|
+
logger.warning(
|
|
32
|
+
"Skipped deduplication of initializer '%s' as it has no constant value. The model may contain invalid initializers",
|
|
33
|
+
initializer.name,
|
|
34
|
+
)
|
|
35
|
+
return True
|
|
36
|
+
|
|
37
|
+
if const_val.size > size_limit:
|
|
38
|
+
# Skip if the initializer is larger than the size limit
|
|
39
|
+
logger.debug(
|
|
40
|
+
"Skipped initializer '%s' as it exceeds the size limit of %d elements",
|
|
41
|
+
initializer.name,
|
|
42
|
+
size_limit,
|
|
43
|
+
)
|
|
44
|
+
return True
|
|
45
|
+
|
|
46
|
+
if const_val.dtype == ir.DataType.STRING:
|
|
47
|
+
# Skip string initializers as they don't have a bytes representation
|
|
48
|
+
logger.warning(
|
|
49
|
+
"Skipped deduplication of string initializer '%s' (unsupported yet)",
|
|
50
|
+
initializer.name,
|
|
51
|
+
)
|
|
52
|
+
return True
|
|
53
|
+
return False
|
|
54
|
+
|
|
55
|
+
|
|
56
|
+
class DeduplicateInitializersPass(ir.passes.InPlacePass):
|
|
57
|
+
"""Remove duplicated initializer tensors from the main graph and all subgraphs.
|
|
58
|
+
|
|
59
|
+
This pass detects initializers with identical shape, dtype, and content,
|
|
60
|
+
and replaces all duplicate references with a canonical one.
|
|
61
|
+
|
|
62
|
+
Initializers are deduplicated within each graph. To deduplicate initializers
|
|
63
|
+
in the model globally (across graphs), use :class:`~onnx_ir.passes.common.LiftSubgraphInitializersToMainGraphPass`
|
|
64
|
+
to lift the initializers to the main graph first before running pass.
|
|
65
|
+
|
|
66
|
+
.. versionadded:: 0.1.3
|
|
67
|
+
.. versionchanged:: 0.1.7
|
|
68
|
+
This pass now deduplicates initializers in subgraphs as well.
|
|
69
|
+
"""
|
|
70
|
+
|
|
71
|
+
def __init__(self, size_limit: int = 1024):
|
|
72
|
+
super().__init__()
|
|
73
|
+
self.size_limit = size_limit
|
|
74
|
+
|
|
75
|
+
def call(self, model: ir.Model) -> ir.passes.PassResult:
|
|
76
|
+
modified = False
|
|
77
|
+
|
|
78
|
+
for graph in model.graphs():
|
|
79
|
+
initializers: dict[tuple[ir.DataType, tuple[int, ...], bytes], ir.Value] = {}
|
|
80
|
+
for initializer in tuple(graph.initializers.values()):
|
|
81
|
+
if _should_skip_initializer(initializer, self.size_limit):
|
|
82
|
+
continue
|
|
83
|
+
|
|
84
|
+
const_val = initializer.const_value
|
|
85
|
+
assert const_val is not None
|
|
86
|
+
|
|
87
|
+
key = (const_val.dtype, tuple(const_val.shape), const_val.tobytes())
|
|
88
|
+
if key in initializers:
|
|
89
|
+
modified = True
|
|
90
|
+
initializer_to_keep = initializers[key] # type: ignore[index]
|
|
91
|
+
ir.convenience.replace_all_uses_with(initializer, initializer_to_keep)
|
|
92
|
+
assert initializer.name is not None
|
|
93
|
+
graph.initializers.pop(initializer.name)
|
|
94
|
+
logger.info(
|
|
95
|
+
"Replaced initializer '%s' with existing initializer '%s'",
|
|
96
|
+
initializer.name,
|
|
97
|
+
initializer_to_keep.name,
|
|
98
|
+
)
|
|
99
|
+
else:
|
|
100
|
+
initializers[key] = initializer # type: ignore[index]
|
|
101
|
+
|
|
102
|
+
return ir.passes.PassResult(model=model, modified=modified)
|
|
103
|
+
|
|
104
|
+
|
|
105
|
+
class DeduplicateHashedInitializersPass(ir.passes.InPlacePass):
|
|
106
|
+
"""Remove duplicated initializer tensors (using a hashed method) from the graph.
|
|
107
|
+
|
|
108
|
+
This pass detects initializers with identical shape, dtype, and hashed content,
|
|
109
|
+
and replaces all duplicate references with a canonical one.
|
|
110
|
+
|
|
111
|
+
This pass should have a lower peak memory usage than :class:`DeduplicateInitializersPass`
|
|
112
|
+
as it does not store the full tensor data in memory, but instead uses a hash of the tensor data.
|
|
113
|
+
|
|
114
|
+
.. versionadded:: 0.1.7
|
|
115
|
+
"""
|
|
116
|
+
|
|
117
|
+
def __init__(self, size_limit: int = 4 * 1024 * 1024 * 1024):
|
|
118
|
+
super().__init__()
|
|
119
|
+
# 4 GB default size limit for deduplication
|
|
120
|
+
self.size_limit = size_limit
|
|
121
|
+
|
|
122
|
+
def call(self, model: ir.Model) -> ir.passes.PassResult:
|
|
123
|
+
modified = False
|
|
124
|
+
|
|
125
|
+
for graph in model.graphs():
|
|
126
|
+
initializers: dict[tuple[ir.DataType, tuple[int, ...], str], ir.Value] = {}
|
|
127
|
+
|
|
128
|
+
for initializer in tuple(graph.initializers.values()):
|
|
129
|
+
if _should_skip_initializer(initializer, self.size_limit):
|
|
130
|
+
continue
|
|
131
|
+
|
|
132
|
+
const_val = initializer.const_value
|
|
133
|
+
assert const_val is not None
|
|
134
|
+
|
|
135
|
+
# Hash tensor data to avoid storing large amounts of data in memory
|
|
136
|
+
hashed = hashlib.sha512()
|
|
137
|
+
tensor_data = const_val.numpy()
|
|
138
|
+
hashed.update(tensor_data)
|
|
139
|
+
tensor_digest = hashed.hexdigest()
|
|
140
|
+
|
|
141
|
+
tensor_dims = tuple(const_val.shape.numpy())
|
|
142
|
+
|
|
143
|
+
key = (const_val.dtype, tensor_dims, tensor_digest)
|
|
144
|
+
|
|
145
|
+
if key in initializers:
|
|
146
|
+
if initializers[key].const_value.tobytes() != const_val.tobytes():
|
|
147
|
+
logger.warning(
|
|
148
|
+
"Initializer deduplication failed: "
|
|
149
|
+
"hashes match but values differ with values %s and %s",
|
|
150
|
+
initializers[key],
|
|
151
|
+
initializer,
|
|
152
|
+
)
|
|
153
|
+
continue
|
|
154
|
+
modified = True
|
|
155
|
+
initializer_to_keep = initializers[key] # type: ignore[index]
|
|
156
|
+
ir.convenience.replace_all_uses_with(initializer, initializer_to_keep)
|
|
157
|
+
assert initializer.name is not None
|
|
158
|
+
graph.initializers.pop(initializer.name)
|
|
159
|
+
logger.info(
|
|
160
|
+
"Replaced initializer '%s' with existing initializer '%s'",
|
|
161
|
+
initializer.name,
|
|
162
|
+
initializer_to_keep.name,
|
|
163
|
+
)
|
|
164
|
+
else:
|
|
165
|
+
initializers[key] = initializer # type: ignore[index]
|
|
166
|
+
|
|
167
|
+
return ir.passes.PassResult(model=model, modified=modified)
|