onnx-ir 0.1.4__tar.gz → 0.1.6__tar.gz
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of onnx-ir might be problematic. Click here for more details.
- {onnx_ir-0.1.4/src/onnx_ir.egg-info → onnx_ir-0.1.6}/PKG-INFO +2 -2
- {onnx_ir-0.1.4 → onnx_ir-0.1.6}/README.md +1 -1
- {onnx_ir-0.1.4 → onnx_ir-0.1.6}/src/onnx_ir/__init__.py +1 -1
- {onnx_ir-0.1.4 → onnx_ir-0.1.6}/src/onnx_ir/_core.py +4 -0
- {onnx_ir-0.1.4 → onnx_ir-0.1.6}/src/onnx_ir/_enums.py +7 -0
- {onnx_ir-0.1.4 → onnx_ir-0.1.6}/src/onnx_ir/passes/common/__init__.py +6 -0
- onnx_ir-0.1.6/src/onnx_ir/passes/common/identity_elimination.py +97 -0
- onnx_ir-0.1.6/src/onnx_ir/passes/common/naming.py +286 -0
- {onnx_ir-0.1.4 → onnx_ir-0.1.6}/src/onnx_ir/serde.py +2 -0
- {onnx_ir-0.1.4 → onnx_ir-0.1.6}/src/onnx_ir/tensor_adapters.py +15 -0
- {onnx_ir-0.1.4 → onnx_ir-0.1.6}/src/onnx_ir/traversal.py +35 -0
- {onnx_ir-0.1.4 → onnx_ir-0.1.6/src/onnx_ir.egg-info}/PKG-INFO +2 -2
- {onnx_ir-0.1.4 → onnx_ir-0.1.6}/src/onnx_ir.egg-info/SOURCES.txt +2 -0
- {onnx_ir-0.1.4 → onnx_ir-0.1.6}/LICENSE +0 -0
- {onnx_ir-0.1.4 → onnx_ir-0.1.6}/MANIFEST.in +0 -0
- {onnx_ir-0.1.4 → onnx_ir-0.1.6}/pyproject.toml +0 -0
- {onnx_ir-0.1.4 → onnx_ir-0.1.6}/setup.cfg +0 -0
- {onnx_ir-0.1.4 → onnx_ir-0.1.6}/src/onnx_ir/_convenience/__init__.py +0 -0
- {onnx_ir-0.1.4 → onnx_ir-0.1.6}/src/onnx_ir/_convenience/_constructors.py +0 -0
- {onnx_ir-0.1.4 → onnx_ir-0.1.6}/src/onnx_ir/_display.py +0 -0
- {onnx_ir-0.1.4 → onnx_ir-0.1.6}/src/onnx_ir/_graph_comparison.py +0 -0
- {onnx_ir-0.1.4 → onnx_ir-0.1.6}/src/onnx_ir/_graph_containers.py +0 -0
- {onnx_ir-0.1.4 → onnx_ir-0.1.6}/src/onnx_ir/_io.py +0 -0
- {onnx_ir-0.1.4 → onnx_ir-0.1.6}/src/onnx_ir/_linked_list.py +0 -0
- {onnx_ir-0.1.4 → onnx_ir-0.1.6}/src/onnx_ir/_metadata.py +0 -0
- {onnx_ir-0.1.4 → onnx_ir-0.1.6}/src/onnx_ir/_name_authority.py +0 -0
- {onnx_ir-0.1.4 → onnx_ir-0.1.6}/src/onnx_ir/_polyfill.py +0 -0
- {onnx_ir-0.1.4 → onnx_ir-0.1.6}/src/onnx_ir/_protocols.py +0 -0
- {onnx_ir-0.1.4 → onnx_ir-0.1.6}/src/onnx_ir/_tape.py +0 -0
- {onnx_ir-0.1.4 → onnx_ir-0.1.6}/src/onnx_ir/_thirdparty/asciichartpy.py +0 -0
- {onnx_ir-0.1.4 → onnx_ir-0.1.6}/src/onnx_ir/_type_casting.py +0 -0
- {onnx_ir-0.1.4 → onnx_ir-0.1.6}/src/onnx_ir/_version_utils.py +0 -0
- {onnx_ir-0.1.4 → onnx_ir-0.1.6}/src/onnx_ir/convenience.py +0 -0
- {onnx_ir-0.1.4 → onnx_ir-0.1.6}/src/onnx_ir/external_data.py +0 -0
- {onnx_ir-0.1.4 → onnx_ir-0.1.6}/src/onnx_ir/passes/__init__.py +0 -0
- {onnx_ir-0.1.4 → onnx_ir-0.1.6}/src/onnx_ir/passes/_pass_infra.py +0 -0
- {onnx_ir-0.1.4 → onnx_ir-0.1.6}/src/onnx_ir/passes/common/_c_api_utils.py +0 -0
- {onnx_ir-0.1.4 → onnx_ir-0.1.6}/src/onnx_ir/passes/common/clear_metadata_and_docstring.py +0 -0
- {onnx_ir-0.1.4 → onnx_ir-0.1.6}/src/onnx_ir/passes/common/common_subexpression_elimination.py +0 -0
- {onnx_ir-0.1.4 → onnx_ir-0.1.6}/src/onnx_ir/passes/common/constant_manipulation.py +0 -0
- {onnx_ir-0.1.4 → onnx_ir-0.1.6}/src/onnx_ir/passes/common/initializer_deduplication.py +0 -0
- {onnx_ir-0.1.4 → onnx_ir-0.1.6}/src/onnx_ir/passes/common/inliner.py +0 -0
- {onnx_ir-0.1.4 → onnx_ir-0.1.6}/src/onnx_ir/passes/common/onnx_checker.py +0 -0
- {onnx_ir-0.1.4 → onnx_ir-0.1.6}/src/onnx_ir/passes/common/shape_inference.py +0 -0
- {onnx_ir-0.1.4 → onnx_ir-0.1.6}/src/onnx_ir/passes/common/topological_sort.py +0 -0
- {onnx_ir-0.1.4 → onnx_ir-0.1.6}/src/onnx_ir/passes/common/unused_removal.py +0 -0
- {onnx_ir-0.1.4 → onnx_ir-0.1.6}/src/onnx_ir/py.typed +0 -0
- {onnx_ir-0.1.4 → onnx_ir-0.1.6}/src/onnx_ir/tape.py +0 -0
- {onnx_ir-0.1.4 → onnx_ir-0.1.6}/src/onnx_ir/testing.py +0 -0
- {onnx_ir-0.1.4 → onnx_ir-0.1.6}/src/onnx_ir.egg-info/dependency_links.txt +0 -0
- {onnx_ir-0.1.4 → onnx_ir-0.1.6}/src/onnx_ir.egg-info/requires.txt +0 -0
- {onnx_ir-0.1.4 → onnx_ir-0.1.6}/src/onnx_ir.egg-info/top_level.txt +0 -0
|
@@ -1,6 +1,6 @@
|
|
|
1
1
|
Metadata-Version: 2.4
|
|
2
2
|
Name: onnx-ir
|
|
3
|
-
Version: 0.1.4
|
|
3
|
+
Version: 0.1.6
|
|
4
4
|
Summary: Efficient in-memory representation for ONNX
|
|
5
5
|
Author-email: ONNX Contributors <onnx-technical-discuss@lists.lfaidata.foundation>
|
|
6
6
|
License: Apache License v2.0
|
|
@@ -29,7 +29,7 @@ Dynamic: license-file
|
|
|
29
29
|
[](https://pypi.org/project/onnx-ir)
|
|
30
30
|
[](https://github.com/astral-sh/ruff)
|
|
31
31
|
[](https://codecov.io/gh/onnx/ir-py)
|
|
32
|
-
[](https://pepy.tech/projects/onnx-ir)
|
|
33
33
|
|
|
34
34
|
An in-memory IR that supports the full ONNX spec, designed for graph construction, analysis and transformation.
|
|
35
35
|
|
|
@@ -4,7 +4,7 @@
|
|
|
4
4
|
[](https://pypi.org/project/onnx-ir)
|
|
5
5
|
[](https://github.com/astral-sh/ruff)
|
|
6
6
|
[](https://codecov.io/gh/onnx/ir-py)
|
|
7
|
-
[](https://pepy.tech/projects/onnx-ir)
|
|
8
8
|
|
|
9
9
|
An in-memory IR that supports the full ONNX spec, designed for graph construction, analysis and transformation.
|
|
10
10
|
|
|
@@ -78,6 +78,7 @@ _NON_NUMPY_NATIVE_TYPES = frozenset(
|
|
|
78
78
|
_enums.DataType.FLOAT8E4M3FNUZ,
|
|
79
79
|
_enums.DataType.FLOAT8E5M2,
|
|
80
80
|
_enums.DataType.FLOAT8E5M2FNUZ,
|
|
81
|
+
_enums.DataType.FLOAT8E8M0,
|
|
81
82
|
_enums.DataType.INT4,
|
|
82
83
|
_enums.DataType.UINT4,
|
|
83
84
|
_enums.DataType.FLOAT4E2M1,
|
|
@@ -261,6 +262,7 @@ def _check_numpy_representation_type(array: np.ndarray, dtype: _enums.DataType)
|
|
|
261
262
|
ml_dtypes.float8_e4m3fn,
|
|
262
263
|
ml_dtypes.float8_e5m2fnuz,
|
|
263
264
|
ml_dtypes.float8_e5m2,
|
|
265
|
+
ml_dtypes.float8_e8m0fnu,
|
|
264
266
|
):
|
|
265
267
|
raise TypeError(
|
|
266
268
|
f"The numpy array dtype must be uint8 or ml_dtypes.float8* (not {array.dtype}) for IR data type {dtype}."
|
|
@@ -319,6 +321,8 @@ def _maybe_view_np_array_with_ml_dtypes(
|
|
|
319
321
|
return array.view(ml_dtypes.float8_e5m2)
|
|
320
322
|
if dtype == _enums.DataType.FLOAT8E5M2FNUZ:
|
|
321
323
|
return array.view(ml_dtypes.float8_e5m2fnuz)
|
|
324
|
+
if dtype == _enums.DataType.FLOAT8E8M0:
|
|
325
|
+
return array.view(ml_dtypes.float8_e8m0fnu)
|
|
322
326
|
if dtype == _enums.DataType.INT4:
|
|
323
327
|
return array.view(ml_dtypes.int4)
|
|
324
328
|
if dtype == _enums.DataType.UINT4:
|
|
@@ -65,6 +65,7 @@ class DataType(enum.IntEnum):
|
|
|
65
65
|
UINT4 = 21
|
|
66
66
|
INT4 = 22
|
|
67
67
|
FLOAT4E2M1 = 23
|
|
68
|
+
FLOAT8E8M0 = 24
|
|
68
69
|
|
|
69
70
|
@classmethod
|
|
70
71
|
def from_numpy(cls, dtype: np.dtype) -> DataType:
|
|
@@ -81,6 +82,7 @@ class DataType(enum.IntEnum):
|
|
|
81
82
|
|
|
82
83
|
# Special cases for handling custom dtypes defined in ONNX (as of onnx 1.18)
|
|
83
84
|
# Ref: https://github.com/onnx/onnx/blob/2d42b6a60a52e925e57c422593e88cc51890f58a/onnx/_custom_element_types.py
|
|
85
|
+
# TODO(#137): Remove this when ONNX 1.19 is the minimum requirement
|
|
84
86
|
if hasattr(dtype, "names"):
|
|
85
87
|
if dtype.names == ("bfloat16",):
|
|
86
88
|
return DataType.BFLOAT16
|
|
@@ -167,6 +169,7 @@ class DataType(enum.IntEnum):
|
|
|
167
169
|
DataType.FLOAT8E5M2,
|
|
168
170
|
DataType.FLOAT8E5M2FNUZ,
|
|
169
171
|
DataType.FLOAT4E2M1,
|
|
172
|
+
DataType.FLOAT8E8M0,
|
|
170
173
|
}
|
|
171
174
|
|
|
172
175
|
def is_integer(self) -> bool:
|
|
@@ -209,6 +212,7 @@ class DataType(enum.IntEnum):
|
|
|
209
212
|
DataType.FLOAT8E5M2FNUZ,
|
|
210
213
|
DataType.INT4,
|
|
211
214
|
DataType.FLOAT4E2M1,
|
|
215
|
+
DataType.FLOAT8E8M0,
|
|
212
216
|
}
|
|
213
217
|
|
|
214
218
|
def __repr__(self) -> str:
|
|
@@ -241,6 +245,7 @@ _BITWIDTH_MAP = {
|
|
|
241
245
|
DataType.UINT4: 4,
|
|
242
246
|
DataType.INT4: 4,
|
|
243
247
|
DataType.FLOAT4E2M1: 4,
|
|
248
|
+
DataType.FLOAT8E8M0: 8,
|
|
244
249
|
}
|
|
245
250
|
|
|
246
251
|
|
|
@@ -266,6 +271,7 @@ _NP_TYPE_TO_DATA_TYPE = {
|
|
|
266
271
|
np.dtype(ml_dtypes.float8_e4m3fnuz): DataType.FLOAT8E4M3FNUZ,
|
|
267
272
|
np.dtype(ml_dtypes.float8_e5m2): DataType.FLOAT8E5M2,
|
|
268
273
|
np.dtype(ml_dtypes.float8_e5m2fnuz): DataType.FLOAT8E5M2FNUZ,
|
|
274
|
+
np.dtype(ml_dtypes.float8_e8m0fnu): DataType.FLOAT8E8M0,
|
|
269
275
|
np.dtype(ml_dtypes.int4): DataType.INT4,
|
|
270
276
|
np.dtype(ml_dtypes.uint4): DataType.UINT4,
|
|
271
277
|
}
|
|
@@ -290,6 +296,7 @@ _DATA_TYPE_TO_SHORT_NAME = {
|
|
|
290
296
|
DataType.FLOAT8E5M2: "f8e5m2",
|
|
291
297
|
DataType.FLOAT8E4M3FNUZ: "f8e4m3fnuz",
|
|
292
298
|
DataType.FLOAT8E5M2FNUZ: "f8e5m2fnuz",
|
|
299
|
+
DataType.FLOAT8E8M0: "f8e8m0",
|
|
293
300
|
DataType.FLOAT4E2M1: "f4e2m1",
|
|
294
301
|
DataType.COMPLEX64: "c64",
|
|
295
302
|
DataType.COMPLEX128: "c128",
|
|
@@ -7,9 +7,11 @@ __all__ = [
|
|
|
7
7
|
"ClearMetadataAndDocStringPass",
|
|
8
8
|
"CommonSubexpressionEliminationPass",
|
|
9
9
|
"DeduplicateInitializersPass",
|
|
10
|
+
"IdentityEliminationPass",
|
|
10
11
|
"InlinePass",
|
|
11
12
|
"LiftConstantsToInitializersPass",
|
|
12
13
|
"LiftSubgraphInitializersToMainGraphPass",
|
|
14
|
+
"NameFixPass",
|
|
13
15
|
"RemoveInitializersFromInputsPass",
|
|
14
16
|
"RemoveUnusedFunctionsPass",
|
|
15
17
|
"RemoveUnusedNodesPass",
|
|
@@ -30,10 +32,14 @@ from onnx_ir.passes.common.constant_manipulation import (
|
|
|
30
32
|
LiftSubgraphInitializersToMainGraphPass,
|
|
31
33
|
RemoveInitializersFromInputsPass,
|
|
32
34
|
)
|
|
35
|
+
from onnx_ir.passes.common.identity_elimination import (
|
|
36
|
+
IdentityEliminationPass,
|
|
37
|
+
)
|
|
33
38
|
from onnx_ir.passes.common.initializer_deduplication import (
|
|
34
39
|
DeduplicateInitializersPass,
|
|
35
40
|
)
|
|
36
41
|
from onnx_ir.passes.common.inliner import InlinePass
|
|
42
|
+
from onnx_ir.passes.common.naming import NameFixPass
|
|
37
43
|
from onnx_ir.passes.common.onnx_checker import CheckerPass
|
|
38
44
|
from onnx_ir.passes.common.shape_inference import ShapeInferencePass
|
|
39
45
|
from onnx_ir.passes.common.topological_sort import TopologicalSortPass
|
|
@@ -0,0 +1,97 @@
|
|
|
1
|
+
# Copyright (c) ONNX Project Contributors
|
|
2
|
+
# SPDX-License-Identifier: Apache-2.0
|
|
3
|
+
"""Identity elimination pass for removing redundant Identity nodes."""
|
|
4
|
+
|
|
5
|
+
from __future__ import annotations
|
|
6
|
+
|
|
7
|
+
__all__ = [
|
|
8
|
+
"IdentityEliminationPass",
|
|
9
|
+
]
|
|
10
|
+
|
|
11
|
+
import logging
|
|
12
|
+
|
|
13
|
+
import onnx_ir as ir
|
|
14
|
+
|
|
15
|
+
logger = logging.getLogger(__name__)
|
|
16
|
+
|
|
17
|
+
|
|
18
|
+
class IdentityEliminationPass(ir.passes.InPlacePass):
    """Pass for eliminating redundant Identity nodes.

    This pass removes Identity nodes according to the following rules:

    1. For any node of the form `y = Identity(x)`, where `y` is not a graph
       output, replace all uses of `y` with a use of `x`, and remove the node.
    2. If `y` is a graph output, and `x` is neither a graph input nor a graph
       output, we can still do the elimination, but the value `x` is renamed
       to `y` so that the graph output keeps its name. If `x` is an initializer,
       it is re-registered in its graph's initializers under the new name.
    3. If `y` is a graph output and `x` is a graph input, we cannot eliminate
       the node. It is retained.
    4. If `y` is a graph output and `x` is also a graph output, the node is
       retained: renaming `x` would change the name of an existing graph output
       and leave two graph outputs with the same name.

    "Graph input" and "graph output" are as reported by
    ``Value.is_graph_input()`` and ``Value.is_graph_output()``.
    """

    def call(self, model: ir.Model) -> ir.passes.PassResult:
        """Remove redundant Identity nodes from the main graph, its subgraphs and functions."""
        modified = False

        # Use RecursiveGraphIterator to process all nodes in the model graph and subgraphs
        for node in ir.traversal.RecursiveGraphIterator(model.graph):
            if self._try_eliminate_identity_node(node):
                modified = True

        # Process nodes in functions (including any subgraphs they contain)
        for function in model.functions.values():
            for node in ir.traversal.RecursiveGraphIterator(function):
                if self._try_eliminate_identity_node(node):
                    modified = True

        if modified:
            logger.info("Identity elimination pass modified the model")

        return ir.passes.PassResult(model, modified=modified)

    def _try_eliminate_identity_node(self, node: ir.Node) -> bool:
        """Try to eliminate a single Identity node.

        Args:
            node: The node to inspect. Anything other than a default-domain
                ``Identity`` node with exactly one input and one output is skipped.

        Returns:
            True if the node was removed, False if it was left unchanged.
        """
        if node.op_type != "Identity" or node.domain != "":
            return False

        if len(node.inputs) != 1 or len(node.outputs) != 1:
            # Invalid Identity node, skip
            return False

        input_value = node.inputs[0]
        output_value = node.outputs[0]

        if input_value is None:
            # Cannot eliminate if input is None
            return False

        # Get the graph that contains this node
        graph_like = node.graph
        assert graph_like is not None, "Node must be in a graph"

        output_is_graph_output = output_value.is_graph_output()

        if output_is_graph_output:
            # Case 3: y is a graph output and x is a graph input - keep the node
            if input_value.is_graph_input():
                return False
            # Case 4: x and y are both graph outputs - keep the node; otherwise the
            # output currently named x would be renamed and two outputs would
            # share the same name
            if input_value.is_graph_output():
                return False

        # Case 1 & 2 (merged): Eliminate the identity node
        # Replace all uses of output with input
        ir.convenience.replace_all_uses_with(output_value, input_value)

        # If output is a graph output, rename input and update graph outputs
        if output_is_graph_output:
            original_output_name = output_value.name
            if input_value.is_initializer():
                # The initializer mapping is keyed by name, so the value has to be
                # re-registered under its new name after the rename
                initializer_graph = input_value.graph
                assert initializer_graph is not None
                old_input_name = input_value.name
                input_value.name = original_output_name
                initializer_graph.initializers.pop(old_input_name)
                initializer_graph.initializers.add(input_value)
            else:
                input_value.name = original_output_name

            # Update graph outputs to point to the input value
            for idx, graph_output in enumerate(graph_like.outputs):
                if graph_output is output_value:
                    graph_like.outputs[idx] = input_value

        # Remove the identity node
        graph_like.remove(node, safe=True)
        logger.debug("Eliminated identity node: %s", node)
        return True
|
|
@@ -0,0 +1,286 @@
|
|
|
1
|
+
# Copyright (c) ONNX Project Contributors
|
|
2
|
+
# SPDX-License-Identifier: Apache-2.0
|
|
3
|
+
"""Name fix pass for ensuring unique names for all values and nodes."""
|
|
4
|
+
|
|
5
|
+
from __future__ import annotations
|
|
6
|
+
|
|
7
|
+
__all__ = [
|
|
8
|
+
"NameFixPass",
|
|
9
|
+
"NameGenerator",
|
|
10
|
+
"SimpleNameGenerator",
|
|
11
|
+
]
|
|
12
|
+
|
|
13
|
+
import collections
|
|
14
|
+
import logging
|
|
15
|
+
from typing import Protocol
|
|
16
|
+
|
|
17
|
+
import onnx_ir as ir
|
|
18
|
+
|
|
19
|
+
logger = logging.getLogger(__name__)
|
|
20
|
+
|
|
21
|
+
|
|
22
|
+
class NameGenerator(Protocol):
    """Structural interface for objects that propose names for nodes and values.

    An implementation only suggests a preferred name. NameFixPass makes the
    suggestion unique by appending numeric suffixes when needed.
    """

    def generate_node_name(self, node: ir.Node) -> str:
        """Return the name this generator would like ``node`` to have."""
        ...

    def generate_value_name(self, value: ir.Value) -> str:
        """Return the name this generator would like ``value`` to have."""
        ...
|
|
30
|
+
|
|
31
|
+
|
|
32
|
+
class SimpleNameGenerator(NameGenerator):
    """Default name generator: keep an existing name, otherwise fall back to a generic one.

    Unnamed nodes are proposed as ``"node"`` and unnamed values as ``"v"``.
    """

    def generate_node_name(self, node: ir.Node) -> str:
        """Return ``node.name`` when it is non-empty, otherwise ``"node"``."""
        existing = node.name
        return existing if existing else "node"

    def generate_value_name(self, value: ir.Value) -> str:
        """Return ``value.name`` when it is non-empty, otherwise ``"v"``."""
        existing = value.name
        return existing if existing else "v"
|
|
42
|
+
|
|
43
|
+
|
|
44
|
+
class NameFixPass(ir.passes.InPlacePass):
    """Pass for fixing names to ensure all values and nodes have unique names.

    This pass ensures that:

    1. Graph inputs and outputs have unique names (take precedence)
    2. All intermediate values have unique names (assign names to unnamed values)
    3. All values in subgraphs have unique names within their graph and parent graphs
    4. All nodes have unique names within their graph

    The main graph and each function are processed independently. When a
    subgraph is entered, it starts from a copy of the value names its parent
    has recorded so far, so its values do not collide with those names. Names
    recorded inside a subgraph are discarded when it is exited, so sibling
    subgraphs may reuse the same names.

    You can customize the name generation functions for nodes and values by passing
    an object implementing :class:`NameGenerator`.

    For example, you can use a custom naming scheme like this::

        class CustomNameGenerator:
            def generate_node_name(self, node: ir.Node) -> str:
                return f"custom_node_{node.op_type}"

            def generate_value_name(self, value: ir.Value) -> str:
                return f"custom_value_{value.type}"

        name_fix_pass = NameFixPass(name_generator=CustomNameGenerator())

    .. versionadded:: 0.1.6
    """

    def __init__(
        self,
        name_generator: NameGenerator | None = None,
    ) -> None:
        """Initialize the NameFixPass with custom name generation functions.

        Args:
            name_generator (NameGenerator, optional): An object implementing
                :class:`NameGenerator` to customize name generation for nodes and values.
                If not provided, defaults to :class:`SimpleNameGenerator`, which uses
                the node's or value's existing name or a generic name like "node" or "v".
        """
        super().__init__()
        self._name_generator = name_generator or SimpleNameGenerator()

    def call(self, model: ir.Model) -> ir.passes.PassResult:
        """Fix names in the main graph and in every function of ``model``."""
        # Process the main graph
        modified = self._fix_graph_names(model.graph)

        # Process functions; each one gets its own, independent name scopes
        for function in model.functions.values():
            modified = self._fix_graph_names(function) or modified

        return ir.passes.PassResult(model, modified=modified)

    def _fix_graph_names(self, graph_like: ir.Graph | ir.Function) -> bool:
        """Fix names in a graph (including its subgraphs) or a function.

        Returns:
            True if any value or node was named or renamed.
        """
        modified = False

        # Values already processed; each value is handled at most once
        seen_values: set[ir.Value] = set()

        # Stacks of used names, one set per graph scope currently being visited.
        # The first set is a dummy placeholder so that there is always a [-1] scope
        # to copy from when the outermost graph is entered (nothing is written to it).
        scoped_used_value_names: list[set[str]] = [set()]
        scoped_used_node_names: list[set[str]] = [set()]

        # Per-preferred-name suffix counters, shared by every scope of this graph
        value_counter = collections.Counter()
        node_counter = collections.Counter()

        def enter_graph(current: ir.Graph | ir.Function) -> None:
            """Callback for entering a graph: open a new scope and name its boundary values."""
            nonlocal modified

            # Value names recorded by the parent stay reserved in the new scope;
            # node names only need to be unique within their own graph
            scoped_used_value_names.append(set(scoped_used_value_names[-1]))
            scoped_used_node_names.append(set())

            # Step 1: Fix graph input names first (they have precedence)
            for input_value in current.inputs:
                if self._process_value(
                    input_value, scoped_used_value_names[-1], seen_values, value_counter
                ):
                    modified = True

            # Step 2: Fix graph output names (they have precedence)
            for output_value in current.outputs:
                if self._process_value(
                    output_value, scoped_used_value_names[-1], seen_values, value_counter
                ):
                    modified = True

            if isinstance(current, ir.Graph):
                # For graphs, also fix initializers. Iterate over a snapshot: renaming
                # an initializer pops it from the mapping and adds it back under the
                # new key, which must not happen while iterating the live view.
                for initializer in list(current.initializers.values()):
                    if self._process_value(
                        initializer, scoped_used_value_names[-1], seen_values, value_counter
                    ):
                        modified = True

        def exit_graph(_: object) -> None:
            """Callback for exiting a graph: discard its scope."""
            scoped_used_value_names.pop()
            scoped_used_node_names.pop()

        # Step 3: Process all nodes and their values
        for node in ir.traversal.RecursiveGraphIterator(
            graph_like, enter_graph=enter_graph, exit_graph=exit_graph
        ):
            # Fix node name
            if not node.name:
                if self._assign_node_name(node, scoped_used_node_names[-1], node_counter):
                    modified = True
            elif self._fix_duplicate_node_name(node, scoped_used_node_names[-1], node_counter):
                modified = True

            # Fix input value names (only if not already processed)
            for input_value in node.inputs:
                if input_value is not None and self._process_value(
                    input_value, scoped_used_value_names[-1], seen_values, value_counter
                ):
                    modified = True

            # Fix output value names (only if not already processed)
            for output_value in node.outputs:
                if self._process_value(
                    output_value, scoped_used_value_names[-1], seen_values, value_counter
                ):
                    modified = True

        return modified

    def _process_value(
        self,
        value: ir.Value,
        used_value_names: set[str],
        seen_values: set[ir.Value],
        value_counter: collections.Counter,
    ) -> bool:
        """Name or rename ``value`` unless it has been processed before.

        Args:
            value: The value to process.
            used_value_names: Names taken in the current scope; updated in place.
            seen_values: Values processed so far; ``value`` is added to it.
            value_counter: Suffix counters used to build unique names.

        Returns:
            True if the value's name was changed.
        """
        if value in seen_values:
            return False

        if not value.name:
            modified = self._assign_value_name(value, used_value_names, value_counter)
        else:
            old_name = value.name
            modified = self._fix_duplicate_value_name(value, used_value_names, value_counter)
            if modified and value.is_initializer():
                # The initializer mapping is keyed by name; re-register the value
                # under its new name
                owning_graph = value.graph
                assert owning_graph is not None
                owning_graph.initializers.pop(old_name)
                owning_graph.initializers.add(value)

        # Record that this value now has its final name
        assert value.name is not None
        seen_values.add(value)
        return modified

    def _assign_value_name(
        self, value: ir.Value, used_names: set[str], counter: collections.Counter
    ) -> bool:
        """Assign a name to an unnamed value. Returns True if modified."""
        assert not value.name, (
            "value should not have a name already if function is called correctly"
        )

        preferred_name = self._name_generator.generate_value_name(value)
        value.name = _find_and_record_next_unique_name(preferred_name, used_names, counter)
        logger.debug("Assigned name %s to unnamed value", value.name)
        return True

    def _assign_node_name(
        self, node: ir.Node, used_names: set[str], counter: collections.Counter
    ) -> bool:
        """Assign a name to an unnamed node. Returns True if modified."""
        assert not node.name, (
            "node should not have a name already if function is called correctly"
        )

        preferred_name = self._name_generator.generate_node_name(node)
        node.name = _find_and_record_next_unique_name(preferred_name, used_names, counter)
        logger.debug("Assigned name %s to unnamed node", node.name)
        return True

    def _fix_duplicate_value_name(
        self, value: ir.Value, used_names: set[str], counter: collections.Counter
    ) -> bool:
        """Fix a value's name if it conflicts with existing names. Returns True if modified."""
        original_name = value.name

        assert original_name, (
            "value should have a name already if function is called correctly"
        )

        if original_name not in used_names:
            # Name is unique, just record it
            used_names.add(original_name)
            return False

        # If name is already used, make it unique
        base_name = self._name_generator.generate_value_name(value)
        value.name = _find_and_record_next_unique_name(base_name, used_names, counter)
        logger.debug("Renamed value from %s to %s for uniqueness", original_name, value.name)
        return True

    def _fix_duplicate_node_name(
        self, node: ir.Node, used_names: set[str], counter: collections.Counter
    ) -> bool:
        """Fix a node's name if it conflicts with existing names. Returns True if modified."""
        original_name = node.name

        assert original_name, "node should have a name already if function is called correctly"

        if original_name not in used_names:
            # Name is unique, just record it
            used_names.add(original_name)
            return False

        # If name is already used, make it unique
        base_name = self._name_generator.generate_node_name(node)
        node.name = _find_and_record_next_unique_name(base_name, used_names, counter)
        logger.debug("Renamed node from %s to %s for uniqueness", original_name, node.name)
        return True
|
|
275
|
+
|
|
276
|
+
|
|
277
|
+
def _find_and_record_next_unique_name(
|
|
278
|
+
preferred_name: str, used_names: set[str], counter: collections.Counter
|
|
279
|
+
) -> str:
|
|
280
|
+
"""Generate a unique name based on the preferred name and current counter."""
|
|
281
|
+
new_name = preferred_name
|
|
282
|
+
while new_name in used_names:
|
|
283
|
+
counter[preferred_name] += 1
|
|
284
|
+
new_name = f"{preferred_name}_{counter[preferred_name]}"
|
|
285
|
+
used_names.add(new_name)
|
|
286
|
+
return new_name
|
|
@@ -405,6 +405,7 @@ class TensorProtoTensor(_core.TensorBase): # pylint: disable=too-many-ancestors
|
|
|
405
405
|
_enums.DataType.FLOAT8E4M3FNUZ,
|
|
406
406
|
_enums.DataType.FLOAT8E5M2,
|
|
407
407
|
_enums.DataType.FLOAT8E5M2FNUZ,
|
|
408
|
+
_enums.DataType.FLOAT8E8M0,
|
|
408
409
|
_enums.DataType.INT16,
|
|
409
410
|
_enums.DataType.INT32,
|
|
410
411
|
_enums.DataType.INT4,
|
|
@@ -505,6 +506,7 @@ class TensorProtoTensor(_core.TensorBase): # pylint: disable=too-many-ancestors
|
|
|
505
506
|
_enums.DataType.FLOAT8E4M3FNUZ,
|
|
506
507
|
_enums.DataType.FLOAT8E5M2,
|
|
507
508
|
_enums.DataType.FLOAT8E5M2FNUZ,
|
|
509
|
+
_enums.DataType.FLOAT8E8M0,
|
|
508
510
|
_enums.DataType.INT4,
|
|
509
511
|
_enums.DataType.UINT4,
|
|
510
512
|
_enums.DataType.FLOAT4E2M1,
|
|
@@ -77,6 +77,10 @@ def from_torch_dtype(dtype: torch.dtype) -> ir.DataType:
|
|
|
77
77
|
torch.uint32: ir.DataType.UINT32,
|
|
78
78
|
torch.uint64: ir.DataType.UINT64,
|
|
79
79
|
}
|
|
80
|
+
if hasattr(torch, "float8_e8m0fnu"):
|
|
81
|
+
# torch.float8_e8m0fnu is available in PyTorch 2.7+
|
|
82
|
+
_TORCH_DTYPE_TO_ONNX[torch.float8_e8m0fnu] = ir.DataType.FLOAT8E8M0
|
|
83
|
+
|
|
80
84
|
if dtype not in _TORCH_DTYPE_TO_ONNX:
|
|
81
85
|
raise TypeError(
|
|
82
86
|
f"Unsupported PyTorch dtype '{dtype}'. "
|
|
@@ -113,7 +117,17 @@ def to_torch_dtype(dtype: ir.DataType) -> torch.dtype:
|
|
|
113
117
|
ir.DataType.UINT32: torch.uint32,
|
|
114
118
|
ir.DataType.UINT64: torch.uint64,
|
|
115
119
|
}
|
|
120
|
+
|
|
121
|
+
if hasattr(torch, "float8_e8m0fnu"):
|
|
122
|
+
# torch.float8_e8m0fnu is available in PyTorch 2.7+
|
|
123
|
+
_ONNX_DTYPE_TO_TORCH[ir.DataType.FLOAT8E8M0] = torch.float8_e8m0fnu
|
|
124
|
+
|
|
116
125
|
if dtype not in _ONNX_DTYPE_TO_TORCH:
|
|
126
|
+
if dtype == ir.DataType.FLOAT8E8M0:
|
|
127
|
+
raise ValueError(
|
|
128
|
+
"The requested DataType 'FLOAT8E8M0' is only supported in PyTorch 2.7+. "
|
|
129
|
+
"Please upgrade your PyTorch version to use this dtype."
|
|
130
|
+
)
|
|
117
131
|
raise TypeError(
|
|
118
132
|
f"Unsupported conversion from ONNX dtype '{dtype}' to torch. "
|
|
119
133
|
"Please use a supported dtype from the list: "
|
|
@@ -142,6 +156,7 @@ class TorchTensor(_core.Tensor):
|
|
|
142
156
|
ir.DataType.FLOAT8E4M3FNUZ,
|
|
143
157
|
ir.DataType.FLOAT8E5M2,
|
|
144
158
|
ir.DataType.FLOAT8E5M2FNUZ,
|
|
159
|
+
ir.DataType.FLOAT8E8M0,
|
|
145
160
|
}:
|
|
146
161
|
return self.raw.view(torch.uint8).numpy(force=True).view(self.dtype.numpy())
|
|
147
162
|
|
|
@@ -25,19 +25,33 @@ class RecursiveGraphIterator(Iterator[_core.Node], Reversible[_core.Node]):
|
|
|
25
25
|
*,
|
|
26
26
|
recursive: Callable[[_core.Node], bool] | None = None,
|
|
27
27
|
reverse: bool = False,
|
|
28
|
+
enter_graph: Callable[[GraphLike], None] | None = None,
|
|
29
|
+
exit_graph: Callable[[GraphLike], None] | None = None,
|
|
28
30
|
):
|
|
29
31
|
"""Iterate over the nodes in the graph, recursively visiting subgraphs.
|
|
30
32
|
|
|
33
|
+
This iterator allows for traversing the nodes of a graph and its subgraphs
|
|
34
|
+
in a depth-first manner. It supports optional callbacks for entering and exiting
|
|
35
|
+
subgraphs, as well as a callback `recursive` to determine whether to visit subgraphs
|
|
36
|
+
contained within nodes.
|
|
37
|
+
|
|
38
|
+
.. versionadded:: 0.1.6
|
|
39
|
+
Added the `enter_graph` and `exit_graph` callbacks.
|
|
40
|
+
|
|
31
41
|
Args:
|
|
32
42
|
graph_like: The graph to traverse.
|
|
33
43
|
recursive: A callback that determines whether to recursively visit the subgraphs
|
|
34
44
|
contained in a node. If not provided, all nodes in subgraphs are visited.
|
|
35
45
|
reverse: Whether to iterate in reverse order.
|
|
46
|
+
enter_graph: An optional callback that is called when entering a subgraph.
|
|
47
|
+
exit_graph: An optional callback that is called when exiting a subgraph.
|
|
36
48
|
"""
|
|
37
49
|
self._graph = graph_like
|
|
38
50
|
self._recursive = recursive
|
|
39
51
|
self._reverse = reverse
|
|
40
52
|
self._iterator = self._recursive_node_iter(graph_like)
|
|
53
|
+
self._enter_graph = enter_graph
|
|
54
|
+
self._exit_graph = exit_graph
|
|
41
55
|
|
|
42
56
|
def __iter__(self) -> Self:
|
|
43
57
|
self._iterator = self._recursive_node_iter(self._graph)
|
|
@@ -50,34 +64,55 @@ class RecursiveGraphIterator(Iterator[_core.Node], Reversible[_core.Node]):
|
|
|
50
64
|
self, graph: _core.Graph | _core.Function | _core.GraphView
|
|
51
65
|
) -> Iterator[_core.Node]:
|
|
52
66
|
iterable = reversed(graph) if self._reverse else graph
|
|
67
|
+
|
|
68
|
+
if self._enter_graph is not None:
|
|
69
|
+
self._enter_graph(graph)
|
|
70
|
+
|
|
53
71
|
for node in iterable: # type: ignore[union-attr]
|
|
54
72
|
yield node
|
|
55
73
|
if self._recursive is not None and not self._recursive(node):
|
|
56
74
|
continue
|
|
57
75
|
yield from self._iterate_subgraphs(node)
|
|
58
76
|
|
|
77
|
+
if self._exit_graph is not None:
|
|
78
|
+
self._exit_graph(graph)
|
|
79
|
+
|
|
59
80
|
def _iterate_subgraphs(self, node: _core.Node):
    """Yield the nodes of every subgraph held in ``node``'s graph attributes.

    Each subgraph is traversed by a nested :class:`RecursiveGraphIterator` that
    shares this iterator's ``recursive``, ``reverse``, ``enter_graph`` and
    ``exit_graph`` settings. When iterated, the nested iterator's
    ``_recursive_node_iter`` already invokes ``enter_graph`` before and
    ``exit_graph`` after the subgraph's nodes. The callbacks are therefore not
    called here as well; doing so would notify every subgraph twice.
    """
    for attr in node.attributes.values():
        if not isinstance(attr, _core.Attr):
            continue
        if attr.type == _enums.AttributeType.GRAPH:
            yield from RecursiveGraphIterator(
                attr.value,
                recursive=self._recursive,
                reverse=self._reverse,
                enter_graph=self._enter_graph,
                exit_graph=self._exit_graph,
            )
        elif attr.type == _enums.AttributeType.GRAPHS:
            graphs = reversed(attr.value) if self._reverse else attr.value
            for graph in graphs:
                yield from RecursiveGraphIterator(
                    graph,
                    recursive=self._recursive,
                    reverse=self._reverse,
                    enter_graph=self._enter_graph,
                    exit_graph=self._exit_graph,
                )
|
|
77
110
|
|
|
78
111
|
def __reversed__(self) -> Iterator[_core.Node]:
    """Return a new iterator over the same graph that walks it in the opposite direction.

    The ``recursive``, ``enter_graph`` and ``exit_graph`` settings are carried over.
    """
    opposite_direction = not self._reverse
    return RecursiveGraphIterator(
        self._graph,
        recursive=self._recursive,
        reverse=opposite_direction,
        enter_graph=self._enter_graph,
        exit_graph=self._exit_graph,
    )
|
|
@@ -1,6 +1,6 @@
|
|
|
1
1
|
Metadata-Version: 2.4
|
|
2
2
|
Name: onnx-ir
|
|
3
|
-
Version: 0.1.4
|
|
3
|
+
Version: 0.1.6
|
|
4
4
|
Summary: Efficient in-memory representation for ONNX
|
|
5
5
|
Author-email: ONNX Contributors <onnx-technical-discuss@lists.lfaidata.foundation>
|
|
6
6
|
License: Apache License v2.0
|
|
@@ -29,7 +29,7 @@ Dynamic: license-file
|
|
|
29
29
|
[](https://pypi.org/project/onnx-ir)
|
|
30
30
|
[](https://github.com/astral-sh/ruff)
|
|
31
31
|
[](https://codecov.io/gh/onnx/ir-py)
|
|
32
|
-
[](https://pepy.tech/projects/onnx-ir)
|
|
33
33
|
|
|
34
34
|
An in-memory IR that supports the full ONNX spec, designed for graph construction, analysis and transformation.
|
|
35
35
|
|
|
@@ -40,8 +40,10 @@ src/onnx_ir/passes/common/_c_api_utils.py
|
|
|
40
40
|
src/onnx_ir/passes/common/clear_metadata_and_docstring.py
|
|
41
41
|
src/onnx_ir/passes/common/common_subexpression_elimination.py
|
|
42
42
|
src/onnx_ir/passes/common/constant_manipulation.py
|
|
43
|
+
src/onnx_ir/passes/common/identity_elimination.py
|
|
43
44
|
src/onnx_ir/passes/common/initializer_deduplication.py
|
|
44
45
|
src/onnx_ir/passes/common/inliner.py
|
|
46
|
+
src/onnx_ir/passes/common/naming.py
|
|
45
47
|
src/onnx_ir/passes/common/onnx_checker.py
|
|
46
48
|
src/onnx_ir/passes/common/shape_inference.py
|
|
47
49
|
src/onnx_ir/passes/common/topological_sort.py
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
{onnx_ir-0.1.4 → onnx_ir-0.1.6}/src/onnx_ir/passes/common/common_subexpression_elimination.py
RENAMED
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|