onnx-ir 0.1.4 (tar.gz) → 0.1.6 (tar.gz)

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of onnx-ir might be problematic. See the package's registry page for more details.

Files changed (52)
  1. {onnx_ir-0.1.4/src/onnx_ir.egg-info → onnx_ir-0.1.6}/PKG-INFO +2 -2
  2. {onnx_ir-0.1.4 → onnx_ir-0.1.6}/README.md +1 -1
  3. {onnx_ir-0.1.4 → onnx_ir-0.1.6}/src/onnx_ir/__init__.py +1 -1
  4. {onnx_ir-0.1.4 → onnx_ir-0.1.6}/src/onnx_ir/_core.py +4 -0
  5. {onnx_ir-0.1.4 → onnx_ir-0.1.6}/src/onnx_ir/_enums.py +7 -0
  6. {onnx_ir-0.1.4 → onnx_ir-0.1.6}/src/onnx_ir/passes/common/__init__.py +6 -0
  7. onnx_ir-0.1.6/src/onnx_ir/passes/common/identity_elimination.py +97 -0
  8. onnx_ir-0.1.6/src/onnx_ir/passes/common/naming.py +286 -0
  9. {onnx_ir-0.1.4 → onnx_ir-0.1.6}/src/onnx_ir/serde.py +2 -0
  10. {onnx_ir-0.1.4 → onnx_ir-0.1.6}/src/onnx_ir/tensor_adapters.py +15 -0
  11. {onnx_ir-0.1.4 → onnx_ir-0.1.6}/src/onnx_ir/traversal.py +35 -0
  12. {onnx_ir-0.1.4 → onnx_ir-0.1.6/src/onnx_ir.egg-info}/PKG-INFO +2 -2
  13. {onnx_ir-0.1.4 → onnx_ir-0.1.6}/src/onnx_ir.egg-info/SOURCES.txt +2 -0
  14. {onnx_ir-0.1.4 → onnx_ir-0.1.6}/LICENSE +0 -0
  15. {onnx_ir-0.1.4 → onnx_ir-0.1.6}/MANIFEST.in +0 -0
  16. {onnx_ir-0.1.4 → onnx_ir-0.1.6}/pyproject.toml +0 -0
  17. {onnx_ir-0.1.4 → onnx_ir-0.1.6}/setup.cfg +0 -0
  18. {onnx_ir-0.1.4 → onnx_ir-0.1.6}/src/onnx_ir/_convenience/__init__.py +0 -0
  19. {onnx_ir-0.1.4 → onnx_ir-0.1.6}/src/onnx_ir/_convenience/_constructors.py +0 -0
  20. {onnx_ir-0.1.4 → onnx_ir-0.1.6}/src/onnx_ir/_display.py +0 -0
  21. {onnx_ir-0.1.4 → onnx_ir-0.1.6}/src/onnx_ir/_graph_comparison.py +0 -0
  22. {onnx_ir-0.1.4 → onnx_ir-0.1.6}/src/onnx_ir/_graph_containers.py +0 -0
  23. {onnx_ir-0.1.4 → onnx_ir-0.1.6}/src/onnx_ir/_io.py +0 -0
  24. {onnx_ir-0.1.4 → onnx_ir-0.1.6}/src/onnx_ir/_linked_list.py +0 -0
  25. {onnx_ir-0.1.4 → onnx_ir-0.1.6}/src/onnx_ir/_metadata.py +0 -0
  26. {onnx_ir-0.1.4 → onnx_ir-0.1.6}/src/onnx_ir/_name_authority.py +0 -0
  27. {onnx_ir-0.1.4 → onnx_ir-0.1.6}/src/onnx_ir/_polyfill.py +0 -0
  28. {onnx_ir-0.1.4 → onnx_ir-0.1.6}/src/onnx_ir/_protocols.py +0 -0
  29. {onnx_ir-0.1.4 → onnx_ir-0.1.6}/src/onnx_ir/_tape.py +0 -0
  30. {onnx_ir-0.1.4 → onnx_ir-0.1.6}/src/onnx_ir/_thirdparty/asciichartpy.py +0 -0
  31. {onnx_ir-0.1.4 → onnx_ir-0.1.6}/src/onnx_ir/_type_casting.py +0 -0
  32. {onnx_ir-0.1.4 → onnx_ir-0.1.6}/src/onnx_ir/_version_utils.py +0 -0
  33. {onnx_ir-0.1.4 → onnx_ir-0.1.6}/src/onnx_ir/convenience.py +0 -0
  34. {onnx_ir-0.1.4 → onnx_ir-0.1.6}/src/onnx_ir/external_data.py +0 -0
  35. {onnx_ir-0.1.4 → onnx_ir-0.1.6}/src/onnx_ir/passes/__init__.py +0 -0
  36. {onnx_ir-0.1.4 → onnx_ir-0.1.6}/src/onnx_ir/passes/_pass_infra.py +0 -0
  37. {onnx_ir-0.1.4 → onnx_ir-0.1.6}/src/onnx_ir/passes/common/_c_api_utils.py +0 -0
  38. {onnx_ir-0.1.4 → onnx_ir-0.1.6}/src/onnx_ir/passes/common/clear_metadata_and_docstring.py +0 -0
  39. {onnx_ir-0.1.4 → onnx_ir-0.1.6}/src/onnx_ir/passes/common/common_subexpression_elimination.py +0 -0
  40. {onnx_ir-0.1.4 → onnx_ir-0.1.6}/src/onnx_ir/passes/common/constant_manipulation.py +0 -0
  41. {onnx_ir-0.1.4 → onnx_ir-0.1.6}/src/onnx_ir/passes/common/initializer_deduplication.py +0 -0
  42. {onnx_ir-0.1.4 → onnx_ir-0.1.6}/src/onnx_ir/passes/common/inliner.py +0 -0
  43. {onnx_ir-0.1.4 → onnx_ir-0.1.6}/src/onnx_ir/passes/common/onnx_checker.py +0 -0
  44. {onnx_ir-0.1.4 → onnx_ir-0.1.6}/src/onnx_ir/passes/common/shape_inference.py +0 -0
  45. {onnx_ir-0.1.4 → onnx_ir-0.1.6}/src/onnx_ir/passes/common/topological_sort.py +0 -0
  46. {onnx_ir-0.1.4 → onnx_ir-0.1.6}/src/onnx_ir/passes/common/unused_removal.py +0 -0
  47. {onnx_ir-0.1.4 → onnx_ir-0.1.6}/src/onnx_ir/py.typed +0 -0
  48. {onnx_ir-0.1.4 → onnx_ir-0.1.6}/src/onnx_ir/tape.py +0 -0
  49. {onnx_ir-0.1.4 → onnx_ir-0.1.6}/src/onnx_ir/testing.py +0 -0
  50. {onnx_ir-0.1.4 → onnx_ir-0.1.6}/src/onnx_ir.egg-info/dependency_links.txt +0 -0
  51. {onnx_ir-0.1.4 → onnx_ir-0.1.6}/src/onnx_ir.egg-info/requires.txt +0 -0
  52. {onnx_ir-0.1.4 → onnx_ir-0.1.6}/src/onnx_ir.egg-info/top_level.txt +0 -0
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.4
2
2
  Name: onnx-ir
3
- Version: 0.1.4
3
+ Version: 0.1.6
4
4
  Summary: Efficient in-memory representation for ONNX
5
5
  Author-email: ONNX Contributors <onnx-technical-discuss@lists.lfaidata.foundation>
6
6
  License: Apache License v2.0
@@ -29,7 +29,7 @@ Dynamic: license-file
29
29
  [![PyPI - Python Version](https://img.shields.io/pypi/pyversions/onnx-ir.svg)](https://pypi.org/project/onnx-ir)
30
30
  [![Ruff](https://img.shields.io/endpoint?url=https://raw.githubusercontent.com/astral-sh/ruff/main/assets/badge/v2.json)](https://github.com/astral-sh/ruff)
31
31
  [![codecov](https://codecov.io/gh/onnx/ir-py/graph/badge.svg?token=SPQ3G9T78Z)](https://codecov.io/gh/onnx/ir-py)
32
- [![DeepWiki](https://img.shields.io/badge/DeepWiki-onnx%2Fir--py-blue.svg?logo=)](https://deepwiki.com/onnx/ir-py)
32
+ [![PyPI Downloads](https://static.pepy.tech/badge/onnx-ir/month)](https://pepy.tech/projects/onnx-ir)
33
33
 
34
34
  An in-memory IR that supports the full ONNX spec, designed for graph construction, analysis and transformation.
35
35
 
@@ -4,7 +4,7 @@
4
4
  [![PyPI - Python Version](https://img.shields.io/pypi/pyversions/onnx-ir.svg)](https://pypi.org/project/onnx-ir)
5
5
  [![Ruff](https://img.shields.io/endpoint?url=https://raw.githubusercontent.com/astral-sh/ruff/main/assets/badge/v2.json)](https://github.com/astral-sh/ruff)
6
6
  [![codecov](https://codecov.io/gh/onnx/ir-py/graph/badge.svg?token=SPQ3G9T78Z)](https://codecov.io/gh/onnx/ir-py)
7
- [![DeepWiki](https://img.shields.io/badge/DeepWiki-onnx%2Fir--py-blue.svg?logo=)](https://deepwiki.com/onnx/ir-py)
7
+ [![PyPI Downloads](https://static.pepy.tech/badge/onnx-ir/month)](https://pepy.tech/projects/onnx-ir)
8
8
 
9
9
  An in-memory IR that supports the full ONNX spec, designed for graph construction, analysis and transformation.
10
10
 
@@ -167,4 +167,4 @@ def __set_module() -> None:
167
167
 
168
168
 
169
169
  __set_module()
170
- __version__ = "0.1.4"
170
+ __version__ = "0.1.6"
@@ -78,6 +78,7 @@ _NON_NUMPY_NATIVE_TYPES = frozenset(
78
78
  _enums.DataType.FLOAT8E4M3FNUZ,
79
79
  _enums.DataType.FLOAT8E5M2,
80
80
  _enums.DataType.FLOAT8E5M2FNUZ,
81
+ _enums.DataType.FLOAT8E8M0,
81
82
  _enums.DataType.INT4,
82
83
  _enums.DataType.UINT4,
83
84
  _enums.DataType.FLOAT4E2M1,
@@ -261,6 +262,7 @@ def _check_numpy_representation_type(array: np.ndarray, dtype: _enums.DataType)
261
262
  ml_dtypes.float8_e4m3fn,
262
263
  ml_dtypes.float8_e5m2fnuz,
263
264
  ml_dtypes.float8_e5m2,
265
+ ml_dtypes.float8_e8m0fnu,
264
266
  ):
265
267
  raise TypeError(
266
268
  f"The numpy array dtype must be uint8 or ml_dtypes.float8* (not {array.dtype}) for IR data type {dtype}."
@@ -319,6 +321,8 @@ def _maybe_view_np_array_with_ml_dtypes(
319
321
  return array.view(ml_dtypes.float8_e5m2)
320
322
  if dtype == _enums.DataType.FLOAT8E5M2FNUZ:
321
323
  return array.view(ml_dtypes.float8_e5m2fnuz)
324
+ if dtype == _enums.DataType.FLOAT8E8M0:
325
+ return array.view(ml_dtypes.float8_e8m0fnu)
322
326
  if dtype == _enums.DataType.INT4:
323
327
  return array.view(ml_dtypes.int4)
324
328
  if dtype == _enums.DataType.UINT4:
@@ -65,6 +65,7 @@ class DataType(enum.IntEnum):
65
65
  UINT4 = 21
66
66
  INT4 = 22
67
67
  FLOAT4E2M1 = 23
68
+ FLOAT8E8M0 = 24
68
69
 
69
70
  @classmethod
70
71
  def from_numpy(cls, dtype: np.dtype) -> DataType:
@@ -81,6 +82,7 @@ class DataType(enum.IntEnum):
81
82
 
82
83
  # Special cases for handling custom dtypes defined in ONNX (as of onnx 1.18)
83
84
  # Ref: https://github.com/onnx/onnx/blob/2d42b6a60a52e925e57c422593e88cc51890f58a/onnx/_custom_element_types.py
85
+ # TODO(#137): Remove this when ONNX 1.19 is the minimum requirement
84
86
  if hasattr(dtype, "names"):
85
87
  if dtype.names == ("bfloat16",):
86
88
  return DataType.BFLOAT16
@@ -167,6 +169,7 @@ class DataType(enum.IntEnum):
167
169
  DataType.FLOAT8E5M2,
168
170
  DataType.FLOAT8E5M2FNUZ,
169
171
  DataType.FLOAT4E2M1,
172
+ DataType.FLOAT8E8M0,
170
173
  }
171
174
 
172
175
  def is_integer(self) -> bool:
@@ -209,6 +212,7 @@ class DataType(enum.IntEnum):
209
212
  DataType.FLOAT8E5M2FNUZ,
210
213
  DataType.INT4,
211
214
  DataType.FLOAT4E2M1,
215
+ DataType.FLOAT8E8M0,
212
216
  }
213
217
 
214
218
  def __repr__(self) -> str:
@@ -241,6 +245,7 @@ _BITWIDTH_MAP = {
241
245
  DataType.UINT4: 4,
242
246
  DataType.INT4: 4,
243
247
  DataType.FLOAT4E2M1: 4,
248
+ DataType.FLOAT8E8M0: 8,
244
249
  }
245
250
 
246
251
 
@@ -266,6 +271,7 @@ _NP_TYPE_TO_DATA_TYPE = {
266
271
  np.dtype(ml_dtypes.float8_e4m3fnuz): DataType.FLOAT8E4M3FNUZ,
267
272
  np.dtype(ml_dtypes.float8_e5m2): DataType.FLOAT8E5M2,
268
273
  np.dtype(ml_dtypes.float8_e5m2fnuz): DataType.FLOAT8E5M2FNUZ,
274
+ np.dtype(ml_dtypes.float8_e8m0fnu): DataType.FLOAT8E8M0,
269
275
  np.dtype(ml_dtypes.int4): DataType.INT4,
270
276
  np.dtype(ml_dtypes.uint4): DataType.UINT4,
271
277
  }
@@ -290,6 +296,7 @@ _DATA_TYPE_TO_SHORT_NAME = {
290
296
  DataType.FLOAT8E5M2: "f8e5m2",
291
297
  DataType.FLOAT8E4M3FNUZ: "f8e4m3fnuz",
292
298
  DataType.FLOAT8E5M2FNUZ: "f8e5m2fnuz",
299
+ DataType.FLOAT8E8M0: "f8e8m0",
293
300
  DataType.FLOAT4E2M1: "f4e2m1",
294
301
  DataType.COMPLEX64: "c64",
295
302
  DataType.COMPLEX128: "c128",
@@ -7,9 +7,11 @@ __all__ = [
7
7
  "ClearMetadataAndDocStringPass",
8
8
  "CommonSubexpressionEliminationPass",
9
9
  "DeduplicateInitializersPass",
10
+ "IdentityEliminationPass",
10
11
  "InlinePass",
11
12
  "LiftConstantsToInitializersPass",
12
13
  "LiftSubgraphInitializersToMainGraphPass",
14
+ "NameFixPass",
13
15
  "RemoveInitializersFromInputsPass",
14
16
  "RemoveUnusedFunctionsPass",
15
17
  "RemoveUnusedNodesPass",
@@ -30,10 +32,14 @@ from onnx_ir.passes.common.constant_manipulation import (
30
32
  LiftSubgraphInitializersToMainGraphPass,
31
33
  RemoveInitializersFromInputsPass,
32
34
  )
35
+ from onnx_ir.passes.common.identity_elimination import (
36
+ IdentityEliminationPass,
37
+ )
33
38
  from onnx_ir.passes.common.initializer_deduplication import (
34
39
  DeduplicateInitializersPass,
35
40
  )
36
41
  from onnx_ir.passes.common.inliner import InlinePass
42
+ from onnx_ir.passes.common.naming import NameFixPass
37
43
  from onnx_ir.passes.common.onnx_checker import CheckerPass
38
44
  from onnx_ir.passes.common.shape_inference import ShapeInferencePass
39
45
  from onnx_ir.passes.common.topological_sort import TopologicalSortPass
@@ -0,0 +1,97 @@
1
+ # Copyright (c) ONNX Project Contributors
2
+ # SPDX-License-Identifier: Apache-2.0
3
+ """Identity elimination pass for removing redundant Identity nodes."""
4
+
5
+ from __future__ import annotations
6
+
7
+ __all__ = [
8
+ "IdentityEliminationPass",
9
+ ]
10
+
11
+ import logging
12
+
13
+ import onnx_ir as ir
14
+
15
+ logger = logging.getLogger(__name__)
16
+
17
+
18
+ class IdentityEliminationPass(ir.passes.InPlacePass):
19
+ """Pass for eliminating redundant Identity nodes.
20
+
21
+ This pass removes Identity nodes according to the following rules:
22
+ 1. For any node of the form `y = Identity(x)`, where `y` is not an output
23
+ of any graph, replace all uses of `y` with a use of `x`, and remove the node.
24
+ 2. If `y` is an output of a graph, and `x` is not an input of any graph,
25
+ we can still do the elimination, but the value `x` should be renamed to be `y`.
26
+ 3. If `y` is a graph-output and `x` is a graph-input, we cannot eliminate
27
+ the node. It should be retained.
28
+ """
29
+
30
+ def call(self, model: ir.Model) -> ir.passes.PassResult:
31
+ """Main entry point for the identity elimination pass."""
32
+ modified = False
33
+
34
+ # Use RecursiveGraphIterator to process all nodes in the model graph and subgraphs
35
+ for node in ir.traversal.RecursiveGraphIterator(model.graph):
36
+ if self._try_eliminate_identity_node(node):
37
+ modified = True
38
+
39
+ # Process nodes in functions
40
+ for function in model.functions.values():
41
+ for node in ir.traversal.RecursiveGraphIterator(function):
42
+ if self._try_eliminate_identity_node(node):
43
+ modified = True
44
+
45
+ if modified:
46
+ logger.info("Identity elimination pass modified the model")
47
+
48
+ return ir.passes.PassResult(model, modified=modified)
49
+
50
+ def _try_eliminate_identity_node(self, node: ir.Node) -> bool:
51
+ """Try to eliminate a single identity node. Returns True if modified."""
52
+ if node.op_type != "Identity" or node.domain != "":
53
+ return False
54
+
55
+ if len(node.inputs) != 1 or len(node.outputs) != 1:
56
+ # Invalid Identity node, skip
57
+ return False
58
+
59
+ input_value = node.inputs[0]
60
+ output_value = node.outputs[0]
61
+
62
+ if input_value is None:
63
+ # Cannot eliminate if input is None
64
+ return False
65
+
66
+ # Get the graph that contains this node
67
+ graph_like = node.graph
68
+ assert graph_like is not None, "Node must be in a graph"
69
+
70
+ output_is_graph_output = output_value.is_graph_output()
71
+ input_is_graph_input = input_value.is_graph_input()
72
+
73
+ # Case 3: Both output is graph output and input is graph input - keep the node
74
+ if output_is_graph_output and input_is_graph_input:
75
+ return False
76
+
77
+ # Case 1 & 2 (merged): Eliminate the identity node
78
+ # Replace all uses of output with input
79
+ ir.convenience.replace_all_uses_with(output_value, input_value)
80
+
81
+ # If output is a graph output, we need to rename input and update graph outputs
82
+ if output_is_graph_output:
83
+ # Store the original output name
84
+ original_output_name = output_value.name
85
+
86
+ # Update the input value to have the output's name
87
+ input_value.name = original_output_name
88
+
89
+ # Update graph outputs to point to the input value
90
+ for idx, graph_output in enumerate(graph_like.outputs):
91
+ if graph_output is output_value:
92
+ graph_like.outputs[idx] = input_value
93
+
94
+ # Remove the identity node
95
+ graph_like.remove(node, safe=True)
96
+ logger.debug("Eliminated identity node: %s", node)
97
+ return True
@@ -0,0 +1,286 @@
1
+ # Copyright (c) ONNX Project Contributors
2
+ # SPDX-License-Identifier: Apache-2.0
3
+ """Name fix pass for ensuring unique names for all values and nodes."""
4
+
5
+ from __future__ import annotations
6
+
7
+ __all__ = [
8
+ "NameFixPass",
9
+ "NameGenerator",
10
+ "SimpleNameGenerator",
11
+ ]
12
+
13
+ import collections
14
+ import logging
15
+ from typing import Protocol
16
+
17
+ import onnx_ir as ir
18
+
19
+ logger = logging.getLogger(__name__)
20
+
21
+
22
+ class NameGenerator(Protocol):
23
+ def generate_node_name(self, node: ir.Node) -> str:
24
+ """Generate a preferred name for a node."""
25
+ ...
26
+
27
+ def generate_value_name(self, value: ir.Value) -> str:
28
+ """Generate a preferred name for a value."""
29
+ ...
30
+
31
+
32
+ class SimpleNameGenerator(NameGenerator):
33
+ """Base class for name generation functions."""
34
+
35
+ def generate_node_name(self, node: ir.Node) -> str:
36
+ """Generate a preferred name for a node."""
37
+ return node.name or "node"
38
+
39
+ def generate_value_name(self, value: ir.Value) -> str:
40
+ """Generate a preferred name for a value."""
41
+ return value.name or "v"
42
+
43
+
44
+ class NameFixPass(ir.passes.InPlacePass):
45
+ """Pass for fixing names to ensure all values and nodes have unique names.
46
+
47
+ This pass ensures that:
48
+ 1. Graph inputs and outputs have unique names (take precedence)
49
+ 2. All intermediate values have unique names (assign names to unnamed values)
50
+ 3. All values in subgraphs have unique names within their graph and parent graphs
51
+ 4. All nodes have unique names within their graph
52
+
53
+ The pass maintains global uniqueness across the entire model.
54
+
55
+ You can customize the name generation functions for nodes and values by passing
56
+ a subclass of :class:`NameGenerator`.
57
+
58
+ For example, you can use a custom naming scheme like this::
59
+
60
+ class CustomNameGenerator:
61
+ def custom_node_name(node: ir.Node) -> str:
62
+ return f"custom_node_{node.op_type}"
63
+
64
+ def custom_value_name(value: ir.Value) -> str:
65
+ return f"custom_value_{value.type}"
66
+
67
+ name_fix_pass = NameFixPass(nameGenerator=CustomNameGenerator())
68
+
69
+ .. versionadded:: 0.1.6
70
+ """
71
+
72
+ def __init__(
73
+ self,
74
+ name_generator: NameGenerator | None = None,
75
+ ) -> None:
76
+ """Initialize the NameFixPass with custom name generation functions.
77
+
78
+ Args:
79
+ name_generator (NameGenerator, optional): An instance of a subclass of
80
+ :class:`NameGenerator` to customize name generation for nodes and values.
81
+ If not provided, defaults to a basic implementation that uses
82
+ the node's or value's existing name or a generic name like "node" or "v".
83
+ """
84
+ super().__init__()
85
+ self._name_generator = name_generator or SimpleNameGenerator()
86
+
87
+ def call(self, model: ir.Model) -> ir.passes.PassResult:
88
+ # Process the main graph
89
+ modified = self._fix_graph_names(model.graph)
90
+
91
+ # Process functions
92
+ for function in model.functions.values():
93
+ modified = self._fix_graph_names(function) or modified
94
+
95
+ return ir.passes.PassResult(model, modified=modified)
96
+
97
+ def _fix_graph_names(self, graph_like: ir.Graph | ir.Function) -> bool:
98
+ """Fix names in a graph and return whether modifications were made."""
99
+ modified = False
100
+
101
+ # Set to track which values have been assigned names
102
+ seen_values: set[ir.Value] = set()
103
+
104
+ # The first set is a dummy placeholder so that there is always a [-1] scope for access
105
+ # (even though we don't write to it)
106
+ scoped_used_value_names: list[set[str]] = [set()]
107
+ scoped_used_node_names: list[set[str]] = [set()]
108
+
109
+ # Counters for generating unique names (using list to pass by reference)
110
+ value_counter = collections.Counter()
111
+ node_counter = collections.Counter()
112
+
113
+ def enter_graph(graph_like) -> None:
114
+ """Callback for entering a subgraph."""
115
+ # Initialize new scopes with all names from the parent scope
116
+ scoped_used_value_names.append(set(scoped_used_value_names[-1]))
117
+ scoped_used_node_names.append(set())
118
+
119
+ nonlocal modified
120
+
121
+ # Step 1: Fix graph input names first (they have precedence)
122
+ for input_value in graph_like.inputs:
123
+ if self._process_value(
124
+ input_value, scoped_used_value_names[-1], seen_values, value_counter
125
+ ):
126
+ modified = True
127
+
128
+ # Step 2: Fix graph output names (they have precedence)
129
+ for output_value in graph_like.outputs:
130
+ if self._process_value(
131
+ output_value, scoped_used_value_names[-1], seen_values, value_counter
132
+ ):
133
+ modified = True
134
+
135
+ if isinstance(graph_like, ir.Graph):
136
+ # For graphs, also fix initializers
137
+ for initializer in graph_like.initializers.values():
138
+ if self._process_value(
139
+ initializer, scoped_used_value_names[-1], seen_values, value_counter
140
+ ):
141
+ modified = True
142
+
143
+ def exit_graph(_) -> None:
144
+ """Callback for exiting a subgraph."""
145
+ # Pop the current scope
146
+ scoped_used_value_names.pop()
147
+ scoped_used_node_names.pop()
148
+
149
+ # Step 3: Process all nodes and their values
150
+ for node in ir.traversal.RecursiveGraphIterator(
151
+ graph_like, enter_graph=enter_graph, exit_graph=exit_graph
152
+ ):
153
+ # Fix node name
154
+ if not node.name:
155
+ if self._assign_node_name(node, scoped_used_node_names[-1], node_counter):
156
+ modified = True
157
+ else:
158
+ if self._fix_duplicate_node_name(
159
+ node, scoped_used_node_names[-1], node_counter
160
+ ):
161
+ modified = True
162
+
163
+ # Fix input value names (only if not already processed)
164
+ for input_value in node.inputs:
165
+ if input_value is not None:
166
+ if self._process_value(
167
+ input_value, scoped_used_value_names[-1], seen_values, value_counter
168
+ ):
169
+ modified = True
170
+
171
+ # Fix output value names (only if not already processed)
172
+ for output_value in node.outputs:
173
+ if self._process_value(
174
+ output_value, scoped_used_value_names[-1], seen_values, value_counter
175
+ ):
176
+ modified = True
177
+
178
+ return modified
179
+
180
+ def _process_value(
181
+ self,
182
+ value: ir.Value,
183
+ used_value_names: set[str],
184
+ seen_values: set[ir.Value],
185
+ value_counter: collections.Counter,
186
+ ) -> bool:
187
+ """Process a value only if it hasn't been processed before."""
188
+ if value in seen_values:
189
+ return False
190
+
191
+ modified = False
192
+
193
+ if not value.name:
194
+ modified = self._assign_value_name(value, used_value_names, value_counter)
195
+ else:
196
+ old_name = value.name
197
+ modified = self._fix_duplicate_value_name(value, used_value_names, value_counter)
198
+ if modified:
199
+ assert value.graph is not None
200
+ if value.is_initializer():
201
+ value.graph.initializers.pop(old_name)
202
+ # Add the initializer back with the new name
203
+ value.graph.initializers.add(value)
204
+
205
+ # Record the final name for this value
206
+ assert value.name is not None
207
+ seen_values.add(value)
208
+ return modified
209
+
210
+ def _assign_value_name(
211
+ self, value: ir.Value, used_names: set[str], counter: collections.Counter
212
+ ) -> bool:
213
+ """Assign a name to an unnamed value. Returns True if modified."""
214
+ assert not value.name, (
215
+ "value should not have a name already if function is called correctly"
216
+ )
217
+
218
+ preferred_name = self._name_generator.generate_value_name(value)
219
+ value.name = _find_and_record_next_unique_name(preferred_name, used_names, counter)
220
+ logger.debug("Assigned name %s to unnamed value", value.name)
221
+ return True
222
+
223
+ def _assign_node_name(
224
+ self, node: ir.Node, used_names: set[str], counter: collections.Counter
225
+ ) -> bool:
226
+ """Assign a name to an unnamed node. Returns True if modified."""
227
+ assert not node.name, (
228
+ "node should not have a name already if function is called correctly"
229
+ )
230
+
231
+ preferred_name = self._name_generator.generate_node_name(node)
232
+ node.name = _find_and_record_next_unique_name(preferred_name, used_names, counter)
233
+ logger.debug("Assigned name %s to unnamed node", node.name)
234
+ return True
235
+
236
+ def _fix_duplicate_value_name(
237
+ self, value: ir.Value, used_names: set[str], counter: collections.Counter
238
+ ) -> bool:
239
+ """Fix a value's name if it conflicts with existing names. Returns True if modified."""
240
+ original_name = value.name
241
+
242
+ assert original_name, (
243
+ "value should have a name already if function is called correctly"
244
+ )
245
+
246
+ if original_name not in used_names:
247
+ # Name is unique, just record it
248
+ used_names.add(original_name)
249
+ return False
250
+
251
+ # If name is already used, make it unique
252
+ base_name = self._name_generator.generate_value_name(value)
253
+ value.name = _find_and_record_next_unique_name(base_name, used_names, counter)
254
+ logger.debug("Renamed value from %s to %s for uniqueness", original_name, value.name)
255
+ return True
256
+
257
+ def _fix_duplicate_node_name(
258
+ self, node: ir.Node, used_names: set[str], counter: collections.Counter
259
+ ) -> bool:
260
+ """Fix a node's name if it conflicts with existing names. Returns True if modified."""
261
+ original_name = node.name
262
+
263
+ assert original_name, "node should have a name already if function is called correctly"
264
+
265
+ if original_name not in used_names:
266
+ # Name is unique, just record it
267
+ used_names.add(original_name)
268
+ return False
269
+
270
+ # If name is already used, make it unique
271
+ base_name = self._name_generator.generate_node_name(node)
272
+ node.name = _find_and_record_next_unique_name(base_name, used_names, counter)
273
+ logger.debug("Renamed node from %s to %s for uniqueness", original_name, node.name)
274
+ return True
275
+
276
+
277
+ def _find_and_record_next_unique_name(
278
+ preferred_name: str, used_names: set[str], counter: collections.Counter
279
+ ) -> str:
280
+ """Generate a unique name based on the preferred name and current counter."""
281
+ new_name = preferred_name
282
+ while new_name in used_names:
283
+ counter[preferred_name] += 1
284
+ new_name = f"{preferred_name}_{counter[preferred_name]}"
285
+ used_names.add(new_name)
286
+ return new_name
@@ -405,6 +405,7 @@ class TensorProtoTensor(_core.TensorBase): # pylint: disable=too-many-ancestors
405
405
  _enums.DataType.FLOAT8E4M3FNUZ,
406
406
  _enums.DataType.FLOAT8E5M2,
407
407
  _enums.DataType.FLOAT8E5M2FNUZ,
408
+ _enums.DataType.FLOAT8E8M0,
408
409
  _enums.DataType.INT16,
409
410
  _enums.DataType.INT32,
410
411
  _enums.DataType.INT4,
@@ -505,6 +506,7 @@ class TensorProtoTensor(_core.TensorBase): # pylint: disable=too-many-ancestors
505
506
  _enums.DataType.FLOAT8E4M3FNUZ,
506
507
  _enums.DataType.FLOAT8E5M2,
507
508
  _enums.DataType.FLOAT8E5M2FNUZ,
509
+ _enums.DataType.FLOAT8E8M0,
508
510
  _enums.DataType.INT4,
509
511
  _enums.DataType.UINT4,
510
512
  _enums.DataType.FLOAT4E2M1,
@@ -77,6 +77,10 @@ def from_torch_dtype(dtype: torch.dtype) -> ir.DataType:
77
77
  torch.uint32: ir.DataType.UINT32,
78
78
  torch.uint64: ir.DataType.UINT64,
79
79
  }
80
+ if hasattr(torch, "float8_e8m0fnu"):
81
+ # torch.float8_e8m0fnu is available in PyTorch 2.7+
82
+ _TORCH_DTYPE_TO_ONNX[torch.float8_e8m0fnu] = ir.DataType.FLOAT8E8M0
83
+
80
84
  if dtype not in _TORCH_DTYPE_TO_ONNX:
81
85
  raise TypeError(
82
86
  f"Unsupported PyTorch dtype '{dtype}'. "
@@ -113,7 +117,17 @@ def to_torch_dtype(dtype: ir.DataType) -> torch.dtype:
113
117
  ir.DataType.UINT32: torch.uint32,
114
118
  ir.DataType.UINT64: torch.uint64,
115
119
  }
120
+
121
+ if hasattr(torch, "float8_e8m0fnu"):
122
+ # torch.float8_e8m0fnu is available in PyTorch 2.7+
123
+ _ONNX_DTYPE_TO_TORCH[ir.DataType.FLOAT8E8M0] = torch.float8_e8m0fnu
124
+
116
125
  if dtype not in _ONNX_DTYPE_TO_TORCH:
126
+ if dtype == ir.DataType.FLOAT8E8M0:
127
+ raise ValueError(
128
+ "The requested DataType 'FLOAT8E8M0' is only supported in PyTorch 2.7+. "
129
+ "Please upgrade your PyTorch version to use this dtype."
130
+ )
117
131
  raise TypeError(
118
132
  f"Unsupported conversion from ONNX dtype '{dtype}' to torch. "
119
133
  "Please use a supported dtype from the list: "
@@ -142,6 +156,7 @@ class TorchTensor(_core.Tensor):
142
156
  ir.DataType.FLOAT8E4M3FNUZ,
143
157
  ir.DataType.FLOAT8E5M2,
144
158
  ir.DataType.FLOAT8E5M2FNUZ,
159
+ ir.DataType.FLOAT8E8M0,
145
160
  }:
146
161
  return self.raw.view(torch.uint8).numpy(force=True).view(self.dtype.numpy())
147
162
 
@@ -25,19 +25,33 @@ class RecursiveGraphIterator(Iterator[_core.Node], Reversible[_core.Node]):
25
25
  *,
26
26
  recursive: Callable[[_core.Node], bool] | None = None,
27
27
  reverse: bool = False,
28
+ enter_graph: Callable[[GraphLike], None] | None = None,
29
+ exit_graph: Callable[[GraphLike], None] | None = None,
28
30
  ):
29
31
  """Iterate over the nodes in the graph, recursively visiting subgraphs.
30
32
 
33
+ This iterator allows for traversing the nodes of a graph and its subgraphs
34
+ in a depth-first manner. It supports optional callbacks for entering and exiting
35
+ subgraphs, as well as a callback `recursive` to determine whether to visit subgraphs
36
+ contained within nodes.
37
+
38
+ .. versionadded:: 0.1.6
39
+ Added the `enter_graph` and `exit_graph` callbacks.
40
+
31
41
  Args:
32
42
  graph_like: The graph to traverse.
33
43
  recursive: A callback that determines whether to recursively visit the subgraphs
34
44
  contained in a node. If not provided, all nodes in subgraphs are visited.
35
45
  reverse: Whether to iterate in reverse order.
46
+ enter_graph: An optional callback that is called when entering a subgraph.
47
+ exit_graph: An optional callback that is called when exiting a subgraph.
36
48
  """
37
49
  self._graph = graph_like
38
50
  self._recursive = recursive
39
51
  self._reverse = reverse
40
52
  self._iterator = self._recursive_node_iter(graph_like)
53
+ self._enter_graph = enter_graph
54
+ self._exit_graph = exit_graph
41
55
 
42
56
  def __iter__(self) -> Self:
43
57
  self._iterator = self._recursive_node_iter(self._graph)
@@ -50,34 +64,55 @@ class RecursiveGraphIterator(Iterator[_core.Node], Reversible[_core.Node]):
50
64
  self, graph: _core.Graph | _core.Function | _core.GraphView
51
65
  ) -> Iterator[_core.Node]:
52
66
  iterable = reversed(graph) if self._reverse else graph
67
+
68
+ if self._enter_graph is not None:
69
+ self._enter_graph(graph)
70
+
53
71
  for node in iterable: # type: ignore[union-attr]
54
72
  yield node
55
73
  if self._recursive is not None and not self._recursive(node):
56
74
  continue
57
75
  yield from self._iterate_subgraphs(node)
58
76
 
77
+ if self._exit_graph is not None:
78
+ self._exit_graph(graph)
79
+
59
80
  def _iterate_subgraphs(self, node: _core.Node):
60
81
  for attr in node.attributes.values():
61
82
  if not isinstance(attr, _core.Attr):
62
83
  continue
63
84
  if attr.type == _enums.AttributeType.GRAPH:
85
+ if self._enter_graph is not None:
86
+ self._enter_graph(attr.value)
64
87
  yield from RecursiveGraphIterator(
65
88
  attr.value,
66
89
  recursive=self._recursive,
67
90
  reverse=self._reverse,
91
+ enter_graph=self._enter_graph,
92
+ exit_graph=self._exit_graph,
68
93
  )
94
+ if self._exit_graph is not None:
95
+ self._exit_graph(attr.value)
69
96
  elif attr.type == _enums.AttributeType.GRAPHS:
70
97
  graphs = reversed(attr.value) if self._reverse else attr.value
71
98
  for graph in graphs:
99
+ if self._enter_graph is not None:
100
+ self._enter_graph(graph)
72
101
  yield from RecursiveGraphIterator(
73
102
  graph,
74
103
  recursive=self._recursive,
75
104
  reverse=self._reverse,
105
+ enter_graph=self._enter_graph,
106
+ exit_graph=self._exit_graph,
76
107
  )
108
+ if self._exit_graph is not None:
109
+ self._exit_graph(graph)
77
110
 
78
111
  def __reversed__(self) -> Iterator[_core.Node]:
79
112
  return RecursiveGraphIterator(
80
113
  self._graph,
81
114
  recursive=self._recursive,
82
115
  reverse=not self._reverse,
116
+ enter_graph=self._enter_graph,
117
+ exit_graph=self._exit_graph,
83
118
  )
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.4
2
2
  Name: onnx-ir
3
- Version: 0.1.4
3
+ Version: 0.1.6
4
4
  Summary: Efficient in-memory representation for ONNX
5
5
  Author-email: ONNX Contributors <onnx-technical-discuss@lists.lfaidata.foundation>
6
6
  License: Apache License v2.0
@@ -29,7 +29,7 @@ Dynamic: license-file
29
29
  [![PyPI - Python Version](https://img.shields.io/pypi/pyversions/onnx-ir.svg)](https://pypi.org/project/onnx-ir)
30
30
  [![Ruff](https://img.shields.io/endpoint?url=https://raw.githubusercontent.com/astral-sh/ruff/main/assets/badge/v2.json)](https://github.com/astral-sh/ruff)
31
31
  [![codecov](https://codecov.io/gh/onnx/ir-py/graph/badge.svg?token=SPQ3G9T78Z)](https://codecov.io/gh/onnx/ir-py)
32
- [![DeepWiki](https://img.shields.io/badge/DeepWiki-onnx%2Fir--py-blue.svg?logo=)](https://deepwiki.com/onnx/ir-py)
32
+ [![PyPI Downloads](https://static.pepy.tech/badge/onnx-ir/month)](https://pepy.tech/projects/onnx-ir)
33
33
 
34
34
  An in-memory IR that supports the full ONNX spec, designed for graph construction, analysis and transformation.
35
35
 
@@ -40,8 +40,10 @@ src/onnx_ir/passes/common/_c_api_utils.py
40
40
  src/onnx_ir/passes/common/clear_metadata_and_docstring.py
41
41
  src/onnx_ir/passes/common/common_subexpression_elimination.py
42
42
  src/onnx_ir/passes/common/constant_manipulation.py
43
+ src/onnx_ir/passes/common/identity_elimination.py
43
44
  src/onnx_ir/passes/common/initializer_deduplication.py
44
45
  src/onnx_ir/passes/common/inliner.py
46
+ src/onnx_ir/passes/common/naming.py
45
47
  src/onnx_ir/passes/common/onnx_checker.py
46
48
  src/onnx_ir/passes/common/shape_inference.py
47
49
  src/onnx_ir/passes/common/topological_sort.py
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes