onnx-ir 0.1.6__tar.gz → 0.1.8__tar.gz
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {onnx_ir-0.1.6/src/onnx_ir.egg-info → onnx_ir-0.1.8}/PKG-INFO +7 -4
- {onnx_ir-0.1.6 → onnx_ir-0.1.8}/README.md +5 -1
- {onnx_ir-0.1.6 → onnx_ir-0.1.8}/pyproject.toml +3 -3
- {onnx_ir-0.1.6 → onnx_ir-0.1.8}/src/onnx_ir/__init__.py +1 -1
- {onnx_ir-0.1.6 → onnx_ir-0.1.8}/src/onnx_ir/_convenience/__init__.py +49 -32
- {onnx_ir-0.1.6 → onnx_ir-0.1.8}/src/onnx_ir/_core.py +65 -16
- {onnx_ir-0.1.6 → onnx_ir-0.1.8}/src/onnx_ir/_enums.py +146 -1
- {onnx_ir-0.1.6 → onnx_ir-0.1.8}/src/onnx_ir/passes/common/__init__.py +2 -0
- {onnx_ir-0.1.6 → onnx_ir-0.1.8}/src/onnx_ir/passes/common/constant_manipulation.py +15 -7
- {onnx_ir-0.1.6 → onnx_ir-0.1.8}/src/onnx_ir/passes/common/identity_elimination.py +1 -0
- onnx_ir-0.1.8/src/onnx_ir/passes/common/initializer_deduplication.py +179 -0
- {onnx_ir-0.1.6 → onnx_ir-0.1.8}/src/onnx_ir/passes/common/naming.py +1 -1
- {onnx_ir-0.1.6 → onnx_ir-0.1.8}/src/onnx_ir/serde.py +97 -36
- {onnx_ir-0.1.6 → onnx_ir-0.1.8/src/onnx_ir.egg-info}/PKG-INFO +7 -4
- onnx_ir-0.1.6/src/onnx_ir/passes/common/initializer_deduplication.py +0 -56
- {onnx_ir-0.1.6 → onnx_ir-0.1.8}/LICENSE +0 -0
- {onnx_ir-0.1.6 → onnx_ir-0.1.8}/MANIFEST.in +0 -0
- {onnx_ir-0.1.6 → onnx_ir-0.1.8}/setup.cfg +0 -0
- {onnx_ir-0.1.6 → onnx_ir-0.1.8}/src/onnx_ir/_convenience/_constructors.py +0 -0
- {onnx_ir-0.1.6 → onnx_ir-0.1.8}/src/onnx_ir/_display.py +0 -0
- {onnx_ir-0.1.6 → onnx_ir-0.1.8}/src/onnx_ir/_graph_comparison.py +0 -0
- {onnx_ir-0.1.6 → onnx_ir-0.1.8}/src/onnx_ir/_graph_containers.py +0 -0
- {onnx_ir-0.1.6 → onnx_ir-0.1.8}/src/onnx_ir/_io.py +0 -0
- {onnx_ir-0.1.6 → onnx_ir-0.1.8}/src/onnx_ir/_linked_list.py +0 -0
- {onnx_ir-0.1.6 → onnx_ir-0.1.8}/src/onnx_ir/_metadata.py +0 -0
- {onnx_ir-0.1.6 → onnx_ir-0.1.8}/src/onnx_ir/_name_authority.py +0 -0
- {onnx_ir-0.1.6 → onnx_ir-0.1.8}/src/onnx_ir/_polyfill.py +0 -0
- {onnx_ir-0.1.6 → onnx_ir-0.1.8}/src/onnx_ir/_protocols.py +0 -0
- {onnx_ir-0.1.6 → onnx_ir-0.1.8}/src/onnx_ir/_tape.py +0 -0
- {onnx_ir-0.1.6 → onnx_ir-0.1.8}/src/onnx_ir/_thirdparty/asciichartpy.py +0 -0
- {onnx_ir-0.1.6 → onnx_ir-0.1.8}/src/onnx_ir/_type_casting.py +0 -0
- {onnx_ir-0.1.6 → onnx_ir-0.1.8}/src/onnx_ir/_version_utils.py +0 -0
- {onnx_ir-0.1.6 → onnx_ir-0.1.8}/src/onnx_ir/convenience.py +0 -0
- {onnx_ir-0.1.6 → onnx_ir-0.1.8}/src/onnx_ir/external_data.py +0 -0
- {onnx_ir-0.1.6 → onnx_ir-0.1.8}/src/onnx_ir/passes/__init__.py +0 -0
- {onnx_ir-0.1.6 → onnx_ir-0.1.8}/src/onnx_ir/passes/_pass_infra.py +0 -0
- {onnx_ir-0.1.6 → onnx_ir-0.1.8}/src/onnx_ir/passes/common/_c_api_utils.py +0 -0
- {onnx_ir-0.1.6 → onnx_ir-0.1.8}/src/onnx_ir/passes/common/clear_metadata_and_docstring.py +0 -0
- {onnx_ir-0.1.6 → onnx_ir-0.1.8}/src/onnx_ir/passes/common/common_subexpression_elimination.py +0 -0
- {onnx_ir-0.1.6 → onnx_ir-0.1.8}/src/onnx_ir/passes/common/inliner.py +0 -0
- {onnx_ir-0.1.6 → onnx_ir-0.1.8}/src/onnx_ir/passes/common/onnx_checker.py +0 -0
- {onnx_ir-0.1.6 → onnx_ir-0.1.8}/src/onnx_ir/passes/common/shape_inference.py +0 -0
- {onnx_ir-0.1.6 → onnx_ir-0.1.8}/src/onnx_ir/passes/common/topological_sort.py +0 -0
- {onnx_ir-0.1.6 → onnx_ir-0.1.8}/src/onnx_ir/passes/common/unused_removal.py +0 -0
- {onnx_ir-0.1.6 → onnx_ir-0.1.8}/src/onnx_ir/py.typed +0 -0
- {onnx_ir-0.1.6 → onnx_ir-0.1.8}/src/onnx_ir/tape.py +0 -0
- {onnx_ir-0.1.6 → onnx_ir-0.1.8}/src/onnx_ir/tensor_adapters.py +0 -0
- {onnx_ir-0.1.6 → onnx_ir-0.1.8}/src/onnx_ir/testing.py +0 -0
- {onnx_ir-0.1.6 → onnx_ir-0.1.8}/src/onnx_ir/traversal.py +0 -0
- {onnx_ir-0.1.6 → onnx_ir-0.1.8}/src/onnx_ir.egg-info/SOURCES.txt +0 -0
- {onnx_ir-0.1.6 → onnx_ir-0.1.8}/src/onnx_ir.egg-info/dependency_links.txt +0 -0
- {onnx_ir-0.1.6 → onnx_ir-0.1.8}/src/onnx_ir.egg-info/requires.txt +0 -0
- {onnx_ir-0.1.6 → onnx_ir-0.1.8}/src/onnx_ir.egg-info/top_level.txt +0 -0
{onnx_ir-0.1.6/src/onnx_ir.egg-info → onnx_ir-0.1.8}/PKG-INFO

@@ -1,9 +1,9 @@
 Metadata-Version: 2.4
 Name: onnx-ir
-Version: 0.1.6
+Version: 0.1.8
 Summary: Efficient in-memory representation for ONNX
 Author-email: ONNX Contributors <onnx-technical-discuss@lists.lfaidata.foundation>
-License: Apache …
+License-Expression: Apache-2.0
 Project-URL: Homepage, https://onnx.ai/ir-py
 Project-URL: Issues, https://github.com/onnx/ir-py/issues
 Project-URL: Repository, https://github.com/onnx/ir-py
@@ -13,7 +13,6 @@ Classifier: Programming Language :: Python :: 3.10
 Classifier: Programming Language :: Python :: 3.11
 Classifier: Programming Language :: Python :: 3.12
 Classifier: Programming Language :: Python :: 3.13
-Classifier: License :: OSI Approved :: Apache Software License
 Requires-Python: >=3.9
 Description-Content-Type: text/markdown
 License-File: LICENSE
@@ -23,7 +22,7 @@ Requires-Dist: typing_extensions>=4.10
 Requires-Dist: ml_dtypes
 Dynamic: license-file
 
-# ONNX IR
+# <img src="docs/_static/logo-light.png" alt="ONNX IR" width="250"/>
 
 [](https://pypi.org/project/onnx-ir)
 [](https://pypi.org/project/onnx-ir)
@@ -61,6 +60,10 @@ pip install git+https://github.com/onnx/ir-py.git
 - Pythonic and familiar APIs: Classes define Pythonic apis and still map to ONNX protobuf concepts in an intuitive way.
 - No protobuf dependency: The IR does not require protobuf once the model is converted to the IR representation, decoupling from the serialization format.
 
+## Concept Diagram
+
+[concept diagram image]
+
 ## Code Organization 🗺️
 
 - [`_protocols.py`](src/onnx_ir/_protocols.py): Interfaces defined for all entities in the IR.
{onnx_ir-0.1.6 → onnx_ir-0.1.8}/README.md

@@ -1,4 +1,4 @@
-# ONNX IR
+# <img src="docs/_static/logo-light.png" alt="ONNX IR" width="250"/>
 
 [](https://pypi.org/project/onnx-ir)
 [](https://pypi.org/project/onnx-ir)
@@ -36,6 +36,10 @@ pip install git+https://github.com/onnx/ir-py.git
 - Pythonic and familiar APIs: Classes define Pythonic apis and still map to ONNX protobuf concepts in an intuitive way.
 - No protobuf dependency: The IR does not require protobuf once the model is converted to the IR representation, decoupling from the serialization format.
 
+## Concept Diagram
+
+[concept diagram image]
+
 ## Code Organization 🗺️
 
 - [`_protocols.py`](src/onnx_ir/_protocols.py): Interfaces defined for all entities in the IR.
{onnx_ir-0.1.6 → onnx_ir-0.1.8}/pyproject.toml

@@ -1,5 +1,5 @@
 [build-system]
-requires = ["setuptools>=…
+requires = ["setuptools>=77"]
 build-backend = "setuptools.build_meta"
 
 [project]
@@ -11,7 +11,8 @@ authors = [
 ]
 readme = "README.md"
 requires-python = ">=3.9"
-license = …
+license = "Apache-2.0"
+license-files = ["LICEN[CS]E*"]
 classifiers = [
     "Development Status :: 4 - Beta",
     "Programming Language :: Python :: 3.9",
@@ -19,7 +20,6 @@ classifiers = [
     "Programming Language :: Python :: 3.11",
     "Programming Language :: Python :: 3.12",
     "Programming Language :: Python :: 3.13",
-    "License :: OSI Approved :: Apache Software License",
 ]
 dependencies = ["numpy", "onnx>=1.16", "typing_extensions>=4.10", "ml_dtypes"]
{onnx_ir-0.1.6 → onnx_ir-0.1.8}/src/onnx_ir/_convenience/__init__.py

@@ -58,44 +58,52 @@ def _infer_attribute_type(attr: SupportedAttrTypes) -> _enums.AttributeType:
         return _enums.AttributeType.STRING
     if isinstance(attr, _core.Attr):
         return attr.type
-    if isinstance(attr, Sequence) and all(isinstance(x, int) for x in attr):
-        return _enums.AttributeType.INTS
-    if isinstance(attr, Sequence) and all(isinstance(x, float) for x in attr):
-        return _enums.AttributeType.FLOATS
-    if isinstance(attr, Sequence) and all(isinstance(x, str) for x in attr):
-        return _enums.AttributeType.STRINGS
+    if isinstance(attr, (_core.Graph, onnx.GraphProto, _protocols.GraphProtocol)):
+        return _enums.AttributeType.GRAPH
     if isinstance(attr, (_core.TensorBase, onnx.TensorProto, _protocols.TensorProtocol)):
         # Be sure to check TensorProtocol last because isinstance checking on Protocols can be slower
         return _enums.AttributeType.TENSOR
-    if isinstance(attr, Sequence) and all(
-        isinstance(x, (_core.TensorBase, onnx.TensorProto, _protocols.TensorProtocol))
-        for x in attr
-    ):
-        return _enums.AttributeType.TENSORS
-    if isinstance(attr, (_core.Graph, onnx.GraphProto, _protocols.GraphProtocol)):
-        return _enums.AttributeType.GRAPH
-    if isinstance(attr, Sequence) and all(
-        isinstance(x, (_core.Graph, onnx.GraphProto, _protocols.GraphProtocol)) for x in attr
-    ):
-        return _enums.AttributeType.GRAPHS
     if isinstance(
         attr,
         (_core.TensorType, _core.SequenceType, _core.OptionalType, _protocols.TypeProtocol),
     ):
         return _enums.AttributeType.TYPE_PROTO
-    if isinstance(attr, Sequence) and all(
-        isinstance(
-            x,
-            (
-                _core.TensorType,
-                _core.SequenceType,
-                _core.OptionalType,
-                _protocols.TypeProtocol,
-            ),
-        )
-        for x in attr
-    ):
-        return _enums.AttributeType.TYPE_PROTOS
+    if isinstance(attr, Sequence):
+        if not attr:
+            logger.warning(
+                "Attribute type is ambiguous because it is an empty sequence. "
+                "Please create an Attr with an explicit type. Defaulted to INTS"
+            )
+            return _enums.AttributeType.INTS
+        if all(isinstance(x, int) for x in attr):
+            return _enums.AttributeType.INTS
+        if all(isinstance(x, float) for x in attr):
+            return _enums.AttributeType.FLOATS
+        if all(isinstance(x, str) for x in attr):
+            return _enums.AttributeType.STRINGS
+        if all(
+            isinstance(x, (_core.TensorBase, onnx.TensorProto, _protocols.TensorProtocol))
+            for x in attr
+        ):
+            return _enums.AttributeType.TENSORS
+        if all(
+            isinstance(x, (_core.Graph, onnx.GraphProto, _protocols.GraphProtocol))
+            for x in attr
+        ):
+            return _enums.AttributeType.GRAPHS
+        if all(
+            isinstance(
+                x,
+                (
+                    _core.TensorType,
+                    _core.SequenceType,
+                    _core.OptionalType,
+                    _protocols.TypeProtocol,
+                ),
+            )
+            for x in attr
+        ):
+            return _enums.AttributeType.TYPE_PROTOS
     raise TypeError(f"Unsupported attribute type: '{type(attr)}'")
@@ -218,7 +226,7 @@ def convert_attributes(
    ... "type_protos": [ir.TensorType(ir.DataType.FLOAT), ir.TensorType(ir.DataType.FLOAT)],
    ... }
    >>> convert_attributes(attrs)
-    [Attr('int', INT, 1), Attr('float', FLOAT, 1.0), Attr('str', STRING, 'hello'), Attr('ints', INTS, [1, 2, 3]), Attr('floats', FLOATS, [1.0, 2.0, 3.0]), Attr('strings', STRINGS, ['hello', 'world']), Attr('tensor', TENSOR, Tensor<DOUBLE,[3]>(array([1., 2., 3.]), name=None)), Attr('tensor_proto', TENSOR, TensorProtoTensor<FLOAT,[3]>(array([1., 2., 3.], dtype=float32), name='proto')), Attr('graph', …
+    [Attr('int', INT, 1), Attr('float', FLOAT, 1.0), Attr('str', STRING, 'hello'), Attr('ints', INTS, [1, 2, 3]), Attr('floats', FLOATS, [1.0, 2.0, 3.0]), Attr('strings', STRINGS, ['hello', 'world']), Attr('tensor', TENSOR, Tensor<DOUBLE,[3]>(array([1., 2., 3.]), name=None)), Attr('tensor_proto', TENSOR, TensorProtoTensor<FLOAT,[3]>(array([1., 2., 3.], dtype=float32), name='proto')), Attr('graph', GRAPH, Graph(
        name='graph0',
        inputs=(
    <BLANKLINE>
@@ -247,11 +255,20 @@ def convert_attributes(
        len()=0
    )]), Attr('type_proto', TYPE_PROTO, Tensor(FLOAT)), Attr('type_protos', TYPE_PROTOS, [Tensor(FLOAT), Tensor(FLOAT)])]
 
+    .. important::
+        An empty sequence should be created with an explicit type by initializing
+        an Attr object with an attribute type to avoid type ambiguity. For example::
+
+            ir.Attr("empty", [], type=ir.AttributeType.INTS)
+
    Args:
        attrs: A dictionary of {<attribute name>: <python objects>} to convert.
 
    Returns:
-        A list of _core.Attr objects.
+        A list of :class:`_core.Attr` objects.
+
+    Raises:
+        TypeError: If an attribute type is not supported.
    """
    attributes: list[_core.Attr] = []
    for name, attr in attrs.items():
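The empty-sequence fallback above is easiest to see end to end in a short sketch. This assumes `convert_attributes` is re-exported via `onnx_ir.convenience`; the `ir.Attr(...)` form is taken directly from the docstring above:

```python
import onnx_ir as ir

# An empty sequence is ambiguous; per the hunk above it defaults to INTS
# and logs a warning.
(attr,) = ir.convenience.convert_attributes({"empty": []})
print(attr.type)  # AttributeType.INTS

# Preferred: state the type explicitly so no inference is needed.
explicit = ir.Attr("empty", [], type=ir.AttributeType.INTS)
```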
{onnx_ir-0.1.6 → onnx_ir-0.1.8}/src/onnx_ir/_core.py

@@ -836,6 +836,11 @@ class StringTensor(TensorBase, _protocols.TensorProtocol):  # pylint: disable=to…
         """The shape of the tensor. Immutable."""
         return self._shape
 
+    @property
+    def nbytes(self) -> int:
+        """The number of bytes in the tensor."""
+        return sum(len(string) for string in self.string_data())
+
     @property
     def raw(self) -> Sequence[bytes] | npt.NDArray[np.bytes_]:
         """Backing data of the tensor. Immutable."""
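A minimal sketch of what the new `nbytes` reports; the constructor arguments and the assumption that `StringTensor` and `Shape` are exported at the package root are mine, while the summing behavior comes from the hunk above:

```python
import onnx_ir as ir

# Hypothetical construction; nbytes sums the byte length of every element.
tensor = ir.StringTensor([b"ab", b"cde"], shape=ir.Shape([2]))
assert tensor.nbytes == 5
```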
@@ -2564,14 +2569,23 @@ class Graph(_protocols.GraphProtocol, Sequence[Node], _display.PrettyPrintable):
 
         .. versionadded:: 0.1.2
         """
-        seen_graphs: set[Graph] = set()
-        for node in onnx_ir.traversal.RecursiveGraphIterator(self):
-            graph = node.graph
+        # Use a dict to preserve order
+        seen_graphs: dict[Graph, None] = {}
+
+        # Need to use the enter_graph callback so that empty subgraphs are collected
+        def enter_subgraph(graph) -> None:
             if graph is self:
-                continue
-            if graph is not None and graph not in seen_graphs:
-                seen_graphs.add(graph)
-                yield graph
+                return
+            if not isinstance(graph, Graph):
+                raise TypeError(
+                    f"Expected a Graph, got {type(graph)}. The model may be invalid"
+                )
+            if graph not in seen_graphs:
+                seen_graphs[graph] = None
+
+        for _ in onnx_ir.traversal.RecursiveGraphIterator(self, enter_graph=enter_subgraph):
+            pass
+        yield from seen_graphs.keys()
 
     # Mutation methods
     def append(self, node: Node, /) -> None:
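For orientation, a small sketch of how the reworked `subgraphs()` is consumed; the input path is an assumption, and the comment restates the rationale given in the hunk above:

```python
import onnx_ir as ir

model = ir.load("model.onnx")  # assumed: a model with If/Loop branch subgraphs
# enter_graph-based collection means even empty branch subgraphs are yielded.
for subgraph in model.graph.subgraphs():
    print(subgraph.name, "nodes:", len(subgraph))
```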
@@ -3180,6 +3194,21 @@ class Function(_protocols.FunctionProtocol, Sequence[Node], _display.PrettyPrintable):
     def attributes(self) -> _graph_containers.Attributes:
         return self._attributes
 
+    @property
+    def graph(self) -> Graph:
+        """The underlying Graph object that contains the nodes of this function.
+
+        Only use this graph for identity comparison::
+
+            if value.graph is function.graph:
+                # Do something with the value that belongs to this function
+
+        Otherwise use the Function object directly to access the nodes and other properties.
+
+        .. versionadded:: 0.1.7
+        """
+        return self._graph
+
     @typing.overload
     def __getitem__(self, index: int) -> Node: ...
     @typing.overload
@@ -3240,14 +3269,22 @@ class Function(_protocols.FunctionProtocol, Sequence[Node], _display.PrettyPrintable):
 
         .. versionadded:: 0.1.2
         """
-        seen_graphs: set[Graph] = set()
-        for node in onnx_ir.traversal.RecursiveGraphIterator(self):
-            graph = node.graph
-            if graph is self._graph:
-                continue
-            if graph is not None and graph not in seen_graphs:
-                seen_graphs.add(graph)
-                yield graph
+        seen_graphs: dict[Graph, None] = {}
+
+        # Need to use the enter_graph callback so that empty subgraphs are collected
+        def enter_subgraph(graph) -> None:
+            if graph is self:
+                return
+            if not isinstance(graph, Graph):
+                raise TypeError(
+                    f"Expected a Graph, got {type(graph)}. The model may be invalid"
+                )
+            if graph not in seen_graphs:
+                seen_graphs[graph] = None
+
+        for _ in onnx_ir.traversal.RecursiveGraphIterator(self, enter_graph=enter_subgraph):
+            pass
+        yield from seen_graphs.keys()
 
     # Mutation methods
     def append(self, node: Node, /) -> None:
@@ -3349,7 +3386,7 @@ class Attr(
 ):
     """Base class for ONNX attributes or references."""
 
-    __slots__ = ("_name", "_ref_attr_name", "_type", "_value", "doc_string")
+    __slots__ = ("_metadata", "_name", "_ref_attr_name", "_type", "_value", "doc_string")
 
     def __init__(
         self,
@@ -3365,6 +3402,7 @@ class Attr(
         self._value = value
         self._ref_attr_name = ref_attr_name
         self.doc_string = doc_string
+        self._metadata: _metadata.MetadataStore | None = None
 
     @property
     def name(self) -> str:
@@ -3386,6 +3424,17 @@ class Attr(
     def ref_attr_name(self) -> str | None:
         return self._ref_attr_name
 
+    @property
+    def meta(self) -> _metadata.MetadataStore:
+        """The metadata store for intermediate analysis.
+
+        Write to the :attr:`metadata_props` if you would like the metadata to be serialized
+        to the ONNX proto.
+        """
+        if self._metadata is None:
+            self._metadata = _metadata.MetadataStore()
+        return self._metadata
+
     def is_ref(self) -> bool:
         """Check if this attribute is a reference attribute."""
         return self.ref_attr_name is not None
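A short sketch of the new `Attr.meta` store. The constructor form mirrors the docstring example earlier in this diff; the key name and the dict-style access on `MetadataStore` are assumptions:

```python
import onnx_ir as ir

attr = ir.Attr("alpha", 0.5, type=ir.AttributeType.FLOAT)
# Lazily created, in-memory only; use metadata_props for anything that
# must survive serialization to the ONNX proto.
attr.meta["visited"] = True  # hypothetical analysis flag
print(attr.meta["visited"])
```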
{onnx_ir-0.1.6 → onnx_ir-0.1.8}/src/onnx_ir/_enums.py

@@ -5,6 +5,7 @@
 from __future__ import annotations
 
 import enum
+from typing import Any
 
 import ml_dtypes
 import numpy as np
@@ -77,7 +78,7 @@ class DataType(enum.IntEnum):
         if dtype in _NP_TYPE_TO_DATA_TYPE:
             return cls(_NP_TYPE_TO_DATA_TYPE[dtype])
 
-        if np.issubdtype(dtype, np.str_):
+        if np.issubdtype(dtype, np.str_) or np.issubdtype(dtype, np.bytes_):
             return DataType.STRING
 
         # Special cases for handling custom dtypes defined in ONNX (as of onnx 1.18)
@@ -131,6 +132,146 @@ class DataType(enum.IntEnum):
             raise TypeError(f"Bitwidth not available for ONNX data type: {self}")
         return _BITWIDTH_MAP[self]
 
+    @property
+    def exponent_bitwidth(self) -> int:
+        """Returns the bit width of the exponent for floating-point types.
+
+        .. versionadded:: 0.1.8
+
+        Raises:
+            TypeError: If the data type is not supported.
+        """
+        if self.is_floating_point():
+            return ml_dtypes.finfo(self.numpy()).nexp
+
+        raise TypeError(f"Exponent not available for ONNX data type: {self}")
+
+    @property
+    def mantissa_bitwidth(self) -> int:
+        """Returns the bit width of the mantissa for floating-point types.
+
+        .. versionadded:: 0.1.8
+
+        Raises:
+            TypeError: If the data type is not supported.
+        """
+        if self.is_floating_point():
+            return ml_dtypes.finfo(self.numpy()).nmant
+
+        raise TypeError(f"Mantissa not available for ONNX data type: {self}")
+
+    @property
+    def eps(self) -> int | np.floating[Any]:
+        """Returns the difference between 1.0 and the next smallest representable float larger than 1.0 for the ONNX data type.
+
+        Returns 1 for integers.
+
+        .. versionadded:: 0.1.8
+
+        Raises:
+            TypeError: If the data type is not a numeric data type.
+        """
+        if self.is_integer():
+            return 1
+
+        if self.is_floating_point():
+            return ml_dtypes.finfo(self.numpy()).eps
+
+        raise TypeError(f"Eps not available for ONNX data type: {self}")
+
+    @property
+    def tiny(self) -> int | np.floating[Any]:
+        """Returns the smallest positive non-zero value for the ONNX data type.
+
+        Returns 1 for integers.
+
+        .. versionadded:: 0.1.8
+
+        Raises:
+            TypeError: If the data type is not a numeric data type.
+        """
+        if self.is_integer():
+            return 1
+
+        if self.is_floating_point():
+            return ml_dtypes.finfo(self.numpy()).tiny
+
+        raise TypeError(f"Tiny not available for ONNX data type: {self}")
+
+    @property
+    def min(self) -> int | np.floating[Any]:
+        """Returns the minimum representable value for the ONNX data type.
+
+        .. versionadded:: 0.1.8
+
+        Raises:
+            TypeError: If the data type is not a numeric data type.
+        """
+        if self.is_integer():
+            return ml_dtypes.iinfo(self.numpy()).min
+
+        if self.is_floating_point():
+            return ml_dtypes.finfo(self.numpy()).min
+
+        raise TypeError(f"Minimum not available for ONNX data type: {self}")
+
+    @property
+    def max(self) -> int | np.floating[Any]:
+        """Returns the maximum representable value for the ONNX data type.
+
+        .. versionadded:: 0.1.8
+
+        Raises:
+            TypeError: If the data type is not a numeric data type.
+        """
+        if self.is_integer():
+            return ml_dtypes.iinfo(self.numpy()).max
+
+        if self.is_floating_point():
+            return ml_dtypes.finfo(self.numpy()).max
+
+        raise TypeError(f"Maximum not available for ONNX data type: {self}")
+
+    @property
+    def precision(self) -> int:
+        """Returns the precision for the ONNX dtype if supported.
+
+        For floats returns the approximate number of decimal digits to which
+        this kind of float is precise. Returns 0 for integers.
+
+        .. versionadded:: 0.1.8
+
+        Raises:
+            TypeError: If the data type is not a numeric data type.
+        """
+        if self.is_integer():
+            return 0
+
+        if self.is_floating_point():
+            return ml_dtypes.finfo(self.numpy()).precision
+
+        raise TypeError(f"Precision not available for ONNX data type: {self}")
+
+    @property
+    def resolution(self) -> int | np.floating[Any]:
+        """Returns the resolution for the ONNX dtype if supported.
+
+        Returns the approximate decimal resolution of this type, i.e.,
+        10**-precision. Returns 1 for integers.
+
+        .. versionadded:: 0.1.8
+
+        Raises:
+            TypeError: If the data type is not a numeric data type.
+        """
+        if self.is_integer():
+            return 1
+
+        if self.is_floating_point():
+            return ml_dtypes.finfo(self.numpy()).resolution
+
+        raise TypeError(f"Resolution not available for ONNX data type: {self}")
+
     def numpy(self) -> np.dtype:
         """Returns the numpy dtype for the ONNX data type.
 
@@ -215,6 +356,10 @@ class DataType(enum.IntEnum):
         DataType.FLOAT8E8M0,
     }
 
+    def is_string(self) -> bool:
+        """Returns True if the data type is a string type."""
+        return self == DataType.STRING
+
     def __repr__(self) -> str:
         return self.name
 
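A quick sketch exercising the new numeric-info properties; the printed values are the standard IEEE/ml_dtypes figures for these formats:

```python
import onnx_ir as ir

fp16 = ir.DataType.FLOAT16
print(fp16.exponent_bitwidth, fp16.mantissa_bitwidth)  # 5 10
print(fp16.min, fp16.max)                              # -65504.0 65504.0
print(fp16.eps)                                        # 0.000977 (2**-10)
print(ir.DataType.INT8.min, ir.DataType.INT8.max)      # -128 127
print(ir.DataType.INT8.precision, ir.DataType.INT8.resolution)  # 0 1
print(ir.DataType.STRING.is_string())                  # True
```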
{onnx_ir-0.1.6 → onnx_ir-0.1.8}/src/onnx_ir/passes/common/__init__.py

@@ -6,6 +6,7 @@ __all__ = [
     "CheckerPass",
     "ClearMetadataAndDocStringPass",
     "CommonSubexpressionEliminationPass",
+    "DeduplicateHashedInitializersPass",
     "DeduplicateInitializersPass",
     "IdentityEliminationPass",
     "InlinePass",
@@ -36,6 +37,7 @@ from onnx_ir.passes.common.identity_elimination import (
     IdentityEliminationPass,
 )
 from onnx_ir.passes.common.initializer_deduplication import (
+    DeduplicateHashedInitializersPass,
     DeduplicateInitializersPass,
 )
 from onnx_ir.passes.common.inliner import InlinePass
{onnx_ir-0.1.6 → onnx_ir-0.1.8}/src/onnx_ir/passes/common/constant_manipulation.py

@@ -148,6 +148,7 @@ class LiftSubgraphInitializersToMainGraphPass(ir.passes.InPlacePass):
             if graph is model.graph:
                 continue
             for name in tuple(graph.initializers):
+                assert name is not None
                 initializer = graph.initializers[name]
                 if initializer.is_graph_input():
                     # Skip the ones that are also graph inputs
@@ -156,17 +157,24 @@ class LiftSubgraphInitializersToMainGraphPass(ir.passes.InPlacePass):
                         initializer.name,
                     )
                     continue
+                if initializer.is_graph_output():
+                    logger.debug(
+                        "Initializer '%s' is used as output, so it can't be lifted",
+                        initializer.name,
+                    )
+                    continue
                 # Remove the initializer from the subgraph
                 graph.initializers.pop(name)
                 # To avoid name conflicts, we need to rename the initializer
                 # to a unique name in the main graph
-                …
-                registered_initializer_names[…
+                new_name = name
+                while new_name in model.graph.initializers:
+                    if name in registered_initializer_names:
+                        registered_initializer_names[name] += 1
+                    else:
+                        registered_initializer_names[name] = 1
+                    new_name = f"{name}_{registered_initializer_names[name]}"
+                initializer.name = new_name
                 model.graph.register_initializer(initializer)
                 count += 1
             logger.debug(
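As a usage sketch (the input path is an assumption; the pass name matches the class in the hunk above):

```python
import onnx_ir as ir
from onnx_ir.passes.common import LiftSubgraphInitializersToMainGraphPass

model = ir.load("model.onnx")  # assumed input path
result = LiftSubgraphInitializersToMainGraphPass()(model)
# Initializers that are graph inputs or (now) graph outputs stay put;
# lifted ones get a `_<n>` suffix when their name collides in the main graph.
print("modified:", result.modified)
```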
{onnx_ir-0.1.6 → onnx_ir-0.1.8}/src/onnx_ir/passes/common/identity_elimination.py

@@ -19,6 +19,7 @@ class IdentityEliminationPass(ir.passes.InPlacePass):
     """Pass for eliminating redundant Identity nodes.
 
     This pass removes Identity nodes according to the following rules:
+
     1. For any node of the form `y = Identity(x)`, where `y` is not an output
        of any graph, replace all uses of `y` with a use of `x`, and remove the node.
     2. If `y` is an output of a graph, and `x` is not an input of any graph,
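A minimal sketch of running the pass (the input path is an assumption; rule 1 is quoted from the docstring above):

```python
import onnx_ir as ir
from onnx_ir.passes.common import IdentityEliminationPass

model = ir.load("model.onnx")  # assumed input path
result = IdentityEliminationPass()(model)
# Rule 1: y = Identity(x) where y is not a graph output -> uses of y
# are rewired to x and the Identity node is removed.
print("modified:", result.modified)
```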
onnx_ir-0.1.8/src/onnx_ir/passes/common/initializer_deduplication.py (new file)

@@ -0,0 +1,179 @@
+# Copyright (c) ONNX Project Contributors
+# SPDX-License-Identifier: Apache-2.0
+"""Pass for removing duplicated initializer tensors from a graph."""
+
+from __future__ import annotations
+
+__all__ = ["DeduplicateInitializersPass", "DeduplicateHashedInitializersPass"]
+
+
+import hashlib
+import logging
+
+import numpy as np
+
+import onnx_ir as ir
+
+logger = logging.getLogger(__name__)
+
+
+def _should_skip_initializer(initializer: ir.Value, size_limit: int) -> bool:
+    """Check if the initializer should be skipped for deduplication."""
+    if initializer.is_graph_input() or initializer.is_graph_output():
+        # Skip graph inputs and outputs
+        logger.warning(
+            "Skipped deduplication of initializer '%s' as it is a graph input or output",
+            initializer.name,
+        )
+        return True
+
+    const_val = initializer.const_value
+    if const_val is None:
+        # Skip if initializer has no constant value
+        logger.warning(
+            "Skipped deduplication of initializer '%s' as it has no constant value. The model may contain invalid initializers",
+            initializer.name,
+        )
+        return True
+
+    if const_val.size > size_limit:
+        # Skip if the initializer is larger than the size limit
+        logger.debug(
+            "Skipped initializer '%s' as it exceeds the size limit of %d elements",
+            initializer.name,
+            size_limit,
+        )
+        return True
+    return False
+
+
+def _tobytes(val):
+    """StringTensor does not support tobytes. Use 'string_data' instead.
+
+    However, 'string_data' yields a list of bytes which cannot be hashed, i.e.,
+    cannot be used to index into a dict. To generate keys for identifying
+    tensors in initializer deduplication the following converts the list of
+    bytes to an array of fixed-length strings which can be flattened into a
+    bytes-string. This, together with the tensor shape, is sufficient for
+    identifying tensors for deduplication, but it differs from the
+    representation used for serializing tensors (that is string_data) by adding
+    padding bytes so that each string occupies the same number of consecutive
+    bytes in the flattened .tobytes representation.
+    """
+    if val.dtype.is_string():
+        return np.array(val.string_data()).tobytes()
+    return val.tobytes()
+
+
+class DeduplicateInitializersPass(ir.passes.InPlacePass):
+    """Remove duplicated initializer tensors from the main graph and all subgraphs.
+
+    This pass detects initializers with identical shape, dtype, and content,
+    and replaces all duplicate references with a canonical one.
+
+    Initializers are deduplicated within each graph. To deduplicate initializers
+    in the model globally (across graphs), use :class:`~onnx_ir.passes.common.LiftSubgraphInitializersToMainGraphPass`
+    to lift the initializers to the main graph first before running pass.
+
+    .. versionadded:: 0.1.3
+    .. versionchanged:: 0.1.7
+        This pass now deduplicates initializers in subgraphs as well.
+    """
+
+    def __init__(self, size_limit: int = 1024):
+        super().__init__()
+        self.size_limit = size_limit
+
+    def call(self, model: ir.Model) -> ir.passes.PassResult:
+        modified = False
+
+        for graph in model.graphs():
+            initializers: dict[tuple[ir.DataType, tuple[int, ...], bytes], ir.Value] = {}
+            for initializer in tuple(graph.initializers.values()):
+                if _should_skip_initializer(initializer, self.size_limit):
+                    continue
+
+                const_val = initializer.const_value
+                assert const_val is not None
+
+                key = (const_val.dtype, tuple(const_val.shape), _tobytes(const_val))
+                if key in initializers:
+                    modified = True
+                    initializer_to_keep = initializers[key]  # type: ignore[index]
+                    ir.convenience.replace_all_uses_with(initializer, initializer_to_keep)
+                    assert initializer.name is not None
+                    graph.initializers.pop(initializer.name)
+                    logger.info(
+                        "Replaced initializer '%s' with existing initializer '%s'",
+                        initializer.name,
+                        initializer_to_keep.name,
+                    )
+                else:
+                    initializers[key] = initializer  # type: ignore[index]
+
+        return ir.passes.PassResult(model=model, modified=modified)
+
+
+class DeduplicateHashedInitializersPass(ir.passes.InPlacePass):
+    """Remove duplicated initializer tensors (using a hashed method) from the graph.
+
+    This pass detects initializers with identical shape, dtype, and hashed content,
+    and replaces all duplicate references with a canonical one.
+
+    This pass should have a lower peak memory usage than :class:`DeduplicateInitializersPass`
+    as it does not store the full tensor data in memory, but instead uses a hash of the tensor data.
+
+    .. versionadded:: 0.1.7
+    """
+
+    def __init__(self, size_limit: int = 4 * 1024 * 1024 * 1024):
+        super().__init__()
+        # 4 GB default size limit for deduplication
+        self.size_limit = size_limit
+
+    def call(self, model: ir.Model) -> ir.passes.PassResult:
+        modified = False
+
+        for graph in model.graphs():
+            initializers: dict[tuple[ir.DataType, tuple[int, ...], str], ir.Value] = {}
+
+            for initializer in tuple(graph.initializers.values()):
+                if _should_skip_initializer(initializer, self.size_limit):
+                    continue
+
+                const_val = initializer.const_value
+                assert const_val is not None
+
+                # Hash tensor data to avoid storing large amounts of data in memory
+                hashed = hashlib.sha512()
+                tensor_data = const_val.numpy()
+                hashed.update(tensor_data)
+                tensor_digest = hashed.hexdigest()
+
+                tensor_dims = tuple(const_val.shape.numpy())
+
+                key = (const_val.dtype, tensor_dims, tensor_digest)
+
+                if key in initializers:
+                    if _tobytes(initializers[key].const_value) != _tobytes(const_val):
+                        logger.warning(
+                            "Initializer deduplication failed: "
+                            "hashes match but values differ with values %s and %s",
+                            initializers[key],
+                            initializer,
+                        )
+                        continue
+                    modified = True
+                    initializer_to_keep = initializers[key]  # type: ignore[index]
+                    ir.convenience.replace_all_uses_with(initializer, initializer_to_keep)
+                    assert initializer.name is not None
+                    graph.initializers.pop(initializer.name)
+                    logger.info(
+                        "Replaced initializer '%s' with existing initializer '%s'",
+                        initializer.name,
+                        initializer_to_keep.name,
+                    )
+                else:
+                    initializers[key] = initializer  # type: ignore[index]
+
+        return ir.passes.PassResult(model=model, modified=modified)
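A usage sketch for the new hashed variant (the input path is an assumption); it trades exact byte-key comparison for SHA-512 digests, re-checking raw bytes only on a digest hit:

```python
import onnx_ir as ir
from onnx_ir.passes.common import DeduplicateHashedInitializersPass

model = ir.load("large_model.onnx")  # assumed input path
# Keys are (dtype, dims, sha512 hex digest), so peak memory stays low;
# if two digests collide but the bytes differ, the pass logs and skips.
result = DeduplicateHashedInitializersPass()(model)
print("modified:", result.modified)
```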
{onnx_ir-0.1.6 → onnx_ir-0.1.8}/src/onnx_ir/passes/common/naming.py

@@ -64,7 +64,7 @@ class NameFixPass(ir.passes.InPlacePass):
             def custom_value_name(value: ir.Value) -> str:
                 return f"custom_value_{value.type}"
 
-        name_fix_pass = NameFixPass(…
+        name_fix_pass = NameFixPass(name_generator=CustomNameGenerator())
 
     .. versionadded:: 0.1.6
     """
{onnx_ir-0.1.6 → onnx_ir-0.1.8}/src/onnx_ir/serde.py

@@ -682,8 +682,8 @@ def deserialize_graph(proto: onnx.GraphProto) -> _core.Graph:
     Returns:
         IR Graph.
 
-    .. versionadded:: 0.3
-        Support for …
+    .. versionadded:: 0.1.3
+        Support for `quantization_annotation` is added.
     """
     return _deserialize_graph(proto, [])
 
@@ -760,6 +760,18 @@ def _deserialize_graph(
     # Build the value info dictionary to allow for quick lookup for this graph scope
     value_info = {info.name: info for info in proto.value_info}
 
+    # Declare values for all node outputs from this graph scope. This is necessary
+    # to handle the case where a node in a subgraph uses a value that is declared out
+    # of order in the outer graph. Declaring the values first allows us to find the
+    # values later when deserializing the nodes in subgraphs.
+    for node in proto.node:
+        _declare_node_outputs(
+            node,
+            values,
+            value_info=value_info,
+            quantization_annotations=quantization_annotations,
+        )
+
     # Deserialize nodes with all known values
     nodes = [
         _deserialize_node(node, scoped_values, value_info, quantization_annotations)
@@ -798,6 +810,55 @@ def _deserialize_graph(
     )
 
 
+def _declare_node_outputs(
+    proto: onnx.NodeProto,
+    current_value_scope: dict[str, _core.Value],
+    value_info: dict[str, onnx.ValueInfoProto],
+    quantization_annotations: dict[str, onnx.TensorAnnotation],
+) -> None:
+    """Declare outputs for a node in the current graph scope.
+
+    This is necessary to handle the case where a node in a subgraph uses a value that is declared
+    out of order in the outer graph. Declaring the values first allows us to find the values later
+    when deserializing the nodes in subgraphs.
+
+    Args:
+        proto: The ONNX NodeProto to declare outputs for.
+        current_value_scope: The current scope of values, mapping value names to their corresponding Value objects.
+        value_info: A dictionary mapping value names to their corresponding ValueInfoProto.
+        quantization_annotations: A dictionary mapping tensor names to their corresponding TensorAnnotation.
+
+    Raises:
+        ValueError: If an output name is redeclared in the current graph scope.
+    """
+    for output_name in proto.output:
+        if output_name == "":
+            continue
+        if output_name in current_value_scope:
+            raise ValueError(
+                f"Output '{output_name}' is redeclared in the current graph scope. "
+                f"Original declaration {current_value_scope[output_name]}. "
+                f"New declaration: by operator '{proto.op_type}' of node '{proto.name}'. "
+                "The model is invalid"
+            )
+
+        # Create the value and add it to the current scope.
+        value = _core.Value(name=output_name)
+        current_value_scope[output_name] = value
+        # Fill in shape/type information if they exist
+        if output_name in value_info:
+            deserialize_value_info_proto(value_info[output_name], value)
+        else:
+            logger.debug(
+                "ValueInfoProto not found for output '%s' in node '%s' of type '%s'",
+                output_name,
+                proto.name,
+                proto.op_type,
+            )
+        if output_name in quantization_annotations:
+            _deserialize_quantization_annotation(quantization_annotations[output_name], value)
+
+
 @_capture_errors(lambda proto: proto.name)
 def deserialize_function(proto: onnx.FunctionProto) -> _core.Function:
     """Deserialize an ONNX FunctionProto into an IR Function.
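To see why outputs are pre-declared, a small sketch with an unsorted GraphProto; the names and shapes are made up, the `onnx.helper` calls are the standard onnx APIs, and this flattens the subgraph case described above into a single graph:

```python
from onnx import helper, TensorProto
import onnx_ir as ir

x = helper.make_tensor_value_info("x", TensorProto.FLOAT, [2])
z = helper.make_tensor_value_info("z", TensorProto.FLOAT, [2])
# 'mid' is consumed by the first node but produced by the second:
# the node list is deliberately not topologically sorted.
n1 = helper.make_node("Relu", ["mid"], ["z"])
n2 = helper.make_node("Neg", ["x"], ["mid"])
graph_proto = helper.make_graph([n1, n2], "unsorted", [x], [z])

# Pre-declaring node outputs lets 'mid' resolve to a single Value even
# though its producer appears later in the list.
graph = ir.serde.deserialize_graph(graph_proto)
```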
@@ -812,7 +873,14 @@ def deserialize_function(proto: onnx.FunctionProto) -> _core.Function:
     values: dict[str, _core.Value] = {v.name: v for v in inputs}  # type: ignore[misc]
     value_info = {info.name: info for info in getattr(proto, "value_info", [])}
 
-    …
+    for node in proto.node:
+        _declare_node_outputs(
+            node,
+            values,
+            value_info=value_info,
+            quantization_annotations={},
+        )
+
     nodes = [
         _deserialize_node(node, [values], value_info=value_info, quantization_annotations={})
         for node in proto.node
@@ -1137,8 +1205,15 @@ def deserialize_node(proto: onnx.NodeProto) -> _core.Node:
     Returns:
         An IR Node object representing the ONNX node.
     """
+    value_scope: dict[str, _core.Value] = {}
+    _declare_node_outputs(
+        proto,
+        value_scope,
+        value_info={},
+        quantization_annotations={},
+    )
     return _deserialize_node(
-        proto, scoped_values=[…
+        proto, scoped_values=[value_scope], value_info={}, quantization_annotations={}
     )
 
 
@@ -1161,18 +1236,18 @@ def _deserialize_node(
         for values in reversed(scoped_values):
             if input_name not in values:
                 continue
+
             node_inputs.append(values[input_name])
             found = True
            del values  # Remove the reference so it is not used by mistake
             break
         if not found:
-            # If the input is not found, we know the graph …
-            # …
-            # …
-            # Nodes need to check the value pool for potentially initialized outputs
+            # If the input is not found, we know the graph is invalid because the value
+            # is not declared. We will still create a new input for the node so that
+            # it can be fixed later.
             logger.warning(
-                "Input '%s' of node '%s(%s::%s:%s)…
-                "The …
+                "Input '%s' of node '%s' (%s::%s:%s) cannot be found in any scope. "
+                "The model is invalid but we will still create a new input for the node (current depth: %s)",
                 input_name,
                 proto.name,
                 proto.domain,
@@ -1208,35 +1283,22 @@ def _deserialize_node(
             node_outputs.append(_core.Value(name=""))
             continue
 
-        # …
+        # The outputs should already be declared in the current scope by _declare_node_outputs.
+        #
+        # When the graph is unsorted, we may be able to find the output already created
         # as an input to some other nodes in the current scope.
         # Note that a value is always owned by the producing node. Even though a value
         # can be created when parsing inputs of other nodes, the new node created here
         # that produces the value will assume ownership. It is then impossible to transfer
         # the ownership to any other node.
-        …
+        #
         # The output can only be found in the current scope. It is impossible for
         # a node to produce an output that is not in its own scope.
         current_scope = scoped_values[-1]
-        …
-        # Create the value and add it to the current scope.
-        value = _core.Value(name=output_name)
-        current_scope[output_name] = value
-        # Fill in shape/type information if they exist
-        if output_name in value_info:
-            deserialize_value_info_proto(value_info[output_name], value)
-        else:
-            logger.debug(
-                "ValueInfoProto not found for output '%s' in node '%s' of type '%s'",
-                output_name,
-                proto.name,
-                proto.op_type,
-            )
-        if output_name in quantization_annotations:
-            _deserialize_quantization_annotation(quantization_annotations[output_name], value)
+        assert output_name in current_scope, (
+            f"Output '{output_name}' not found in the current scope. This is unexpected"
+        )
+        value = current_scope[output_name]
         node_outputs.append(value)
     return _core.Node(
         proto.domain,
@@ -1469,8 +1531,6 @@ def serialize_graph_into(
         serialize_value_into(graph_proto.input.add(), input_)
         if input_.name not in from_.initializers:
             # Annotations for initializers will be added below to avoid double adding
-            # TODO(justinchuby): We should add a method is_initializer() on Value when
-            # the initializer list is tracked
             _maybe_add_quantization_annotation(graph_proto, input_)
     input_names = {input_.name for input_ in from_.inputs}
     # TODO(justinchuby): Support sparse_initializer
@@ -1724,11 +1784,12 @@ def _fill_in_value_for_attribute(
 ) -> None:
     if type_ == _enums.AttributeType.INT:
         # value: int
-        attribute_proto.i = value
+        # Cast bool to int, for example
+        attribute_proto.i = int(value)
         attribute_proto.type = onnx.AttributeProto.INT
     elif type_ == _enums.AttributeType.FLOAT:
         # value: float
-        attribute_proto.f = value
+        attribute_proto.f = float(value)
         attribute_proto.type = onnx.AttributeProto.FLOAT
     elif type_ == _enums.AttributeType.STRING:
         # value: str
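The INT and FLOAT branches now cast before assigning to the proto fields; a tiny plain-Python illustration of what that buys:

```python
import numpy as np

# int()/float() are applied before assignment, so Python bools and
# NumPy scalars become plain ints/floats on the wire:
print(int(True), int(np.int32(7)))  # 1 7  -> what attribute_proto.i receives
print(float(np.float16(0.5)))       # 0.5  -> what attribute_proto.f receives
```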
@@ -1818,7 +1879,7 @@ def serialize_value(value: _protocols.ValueProtocol, *, name: str = "") -> onnx.ValueInfoProto:
     return value_info_proto
 
 
-@_capture_errors(lambda value_info_proto, from_: repr(from_))
+@_capture_errors(lambda value_info_proto, from_, name="": repr(from_))
 def serialize_value_into(
     value_info_proto: onnx.ValueInfoProto,
     from_: _protocols.ValueProtocol,
{onnx_ir-0.1.6 → onnx_ir-0.1.8/src/onnx_ir.egg-info}/PKG-INFO

@@ -1,9 +1,9 @@
 Metadata-Version: 2.4
 Name: onnx-ir
-Version: 0.1.6
+Version: 0.1.8
 Summary: Efficient in-memory representation for ONNX
 Author-email: ONNX Contributors <onnx-technical-discuss@lists.lfaidata.foundation>
-License: Apache …
+License-Expression: Apache-2.0
 Project-URL: Homepage, https://onnx.ai/ir-py
 Project-URL: Issues, https://github.com/onnx/ir-py/issues
 Project-URL: Repository, https://github.com/onnx/ir-py
@@ -13,7 +13,6 @@ Classifier: Programming Language :: Python :: 3.10
 Classifier: Programming Language :: Python :: 3.11
 Classifier: Programming Language :: Python :: 3.12
 Classifier: Programming Language :: Python :: 3.13
-Classifier: License :: OSI Approved :: Apache Software License
 Requires-Python: >=3.9
 Description-Content-Type: text/markdown
 License-File: LICENSE
@@ -23,7 +22,7 @@ Requires-Dist: typing_extensions>=4.10
 Requires-Dist: ml_dtypes
 Dynamic: license-file
 
-# ONNX IR
+# <img src="docs/_static/logo-light.png" alt="ONNX IR" width="250"/>
 
 [](https://pypi.org/project/onnx-ir)
 [](https://pypi.org/project/onnx-ir)
@@ -61,6 +60,10 @@ pip install git+https://github.com/onnx/ir-py.git
 - Pythonic and familiar APIs: Classes define Pythonic apis and still map to ONNX protobuf concepts in an intuitive way.
 - No protobuf dependency: The IR does not require protobuf once the model is converted to the IR representation, decoupling from the serialization format.
 
+## Concept Diagram
+
+[concept diagram image]
+
 ## Code Organization 🗺️
 
 - [`_protocols.py`](src/onnx_ir/_protocols.py): Interfaces defined for all entities in the IR.
onnx_ir-0.1.6/src/onnx_ir/passes/common/initializer_deduplication.py (deleted)

@@ -1,56 +0,0 @@
-# Copyright (c) ONNX Project Contributors
-# SPDX-License-Identifier: Apache-2.0
-"""Pass for removing duplicated initializer tensors from a graph."""
-
-from __future__ import annotations
-
-__all__ = [
-    "DeduplicateInitializersPass",
-]
-
-
-import onnx_ir as ir
-
-
-class DeduplicateInitializersPass(ir.passes.InPlacePass):
-    """Remove duplicated initializer tensors from the graph.
-
-    This pass detects initializers with identical shape, dtype, and content,
-    and replaces all duplicate references with a canonical one.
-
-    To deduplicate initializers from subgraphs, use :class:`~onnx_ir.passes.common.LiftSubgraphInitializersToMainGraphPass`
-    to lift the initializers to the main graph first before running pass.
-
-    .. versionadded:: 0.1.3
-    """
-
-    def __init__(self, size_limit: int = 1024):
-        super().__init__()
-        self.size_limit = size_limit
-
-    def call(self, model: ir.Model) -> ir.passes.PassResult:
-        graph = model.graph
-        initializers: dict[tuple[ir.DataType, tuple[int, ...], bytes], ir.Value] = {}
-        modified = False
-
-        for initializer in tuple(graph.initializers.values()):
-            # TODO(justinchuby): Handle subgraphs as well. For now users can lift initializers
-            # out from the main graph before running this pass.
-            const_val = initializer.const_value
-            if const_val is None:
-                # Skip if initializer has no constant value
-                continue
-
-            if const_val.size > self.size_limit:
-                continue
-
-            key = (const_val.dtype, tuple(const_val.shape), const_val.tobytes())
-            if key in initializers:
-                modified = True
-                ir.convenience.replace_all_uses_with(initializer, initializers[key])  # type: ignore[index]
-                assert initializer.name is not None
-                graph.initializers.pop(initializer.name)
-            else:
-                initializers[key] = initializer  # type: ignore[index]
-
-        return ir.passes.PassResult(model=model, modified=modified)