onnx-ir 0.1.4.tar.gz → 0.1.5.tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of onnx-ir might be problematic. See the package's registry page for more details.

Files changed (51):
  1. {onnx_ir-0.1.4/src/onnx_ir.egg-info → onnx_ir-0.1.5}/PKG-INFO +2 -1
  2. {onnx_ir-0.1.4 → onnx_ir-0.1.5}/README.md +1 -0
  3. {onnx_ir-0.1.4 → onnx_ir-0.1.5}/src/onnx_ir/__init__.py +1 -1
  4. {onnx_ir-0.1.4 → onnx_ir-0.1.5}/src/onnx_ir/_core.py +4 -0
  5. {onnx_ir-0.1.4 → onnx_ir-0.1.5}/src/onnx_ir/_enums.py +7 -0
  6. {onnx_ir-0.1.4 → onnx_ir-0.1.5}/src/onnx_ir/passes/common/__init__.py +4 -0
  7. onnx_ir-0.1.5/src/onnx_ir/passes/common/identity_elimination.py +97 -0
  8. {onnx_ir-0.1.4 → onnx_ir-0.1.5}/src/onnx_ir/serde.py +2 -0
  9. {onnx_ir-0.1.4 → onnx_ir-0.1.5}/src/onnx_ir/tensor_adapters.py +3 -0
  10. {onnx_ir-0.1.4 → onnx_ir-0.1.5/src/onnx_ir.egg-info}/PKG-INFO +2 -1
  11. {onnx_ir-0.1.4 → onnx_ir-0.1.5}/src/onnx_ir.egg-info/SOURCES.txt +1 -0
  12. {onnx_ir-0.1.4 → onnx_ir-0.1.5}/LICENSE +0 -0
  13. {onnx_ir-0.1.4 → onnx_ir-0.1.5}/MANIFEST.in +0 -0
  14. {onnx_ir-0.1.4 → onnx_ir-0.1.5}/pyproject.toml +0 -0
  15. {onnx_ir-0.1.4 → onnx_ir-0.1.5}/setup.cfg +0 -0
  16. {onnx_ir-0.1.4 → onnx_ir-0.1.5}/src/onnx_ir/_convenience/__init__.py +0 -0
  17. {onnx_ir-0.1.4 → onnx_ir-0.1.5}/src/onnx_ir/_convenience/_constructors.py +0 -0
  18. {onnx_ir-0.1.4 → onnx_ir-0.1.5}/src/onnx_ir/_display.py +0 -0
  19. {onnx_ir-0.1.4 → onnx_ir-0.1.5}/src/onnx_ir/_graph_comparison.py +0 -0
  20. {onnx_ir-0.1.4 → onnx_ir-0.1.5}/src/onnx_ir/_graph_containers.py +0 -0
  21. {onnx_ir-0.1.4 → onnx_ir-0.1.5}/src/onnx_ir/_io.py +0 -0
  22. {onnx_ir-0.1.4 → onnx_ir-0.1.5}/src/onnx_ir/_linked_list.py +0 -0
  23. {onnx_ir-0.1.4 → onnx_ir-0.1.5}/src/onnx_ir/_metadata.py +0 -0
  24. {onnx_ir-0.1.4 → onnx_ir-0.1.5}/src/onnx_ir/_name_authority.py +0 -0
  25. {onnx_ir-0.1.4 → onnx_ir-0.1.5}/src/onnx_ir/_polyfill.py +0 -0
  26. {onnx_ir-0.1.4 → onnx_ir-0.1.5}/src/onnx_ir/_protocols.py +0 -0
  27. {onnx_ir-0.1.4 → onnx_ir-0.1.5}/src/onnx_ir/_tape.py +0 -0
  28. {onnx_ir-0.1.4 → onnx_ir-0.1.5}/src/onnx_ir/_thirdparty/asciichartpy.py +0 -0
  29. {onnx_ir-0.1.4 → onnx_ir-0.1.5}/src/onnx_ir/_type_casting.py +0 -0
  30. {onnx_ir-0.1.4 → onnx_ir-0.1.5}/src/onnx_ir/_version_utils.py +0 -0
  31. {onnx_ir-0.1.4 → onnx_ir-0.1.5}/src/onnx_ir/convenience.py +0 -0
  32. {onnx_ir-0.1.4 → onnx_ir-0.1.5}/src/onnx_ir/external_data.py +0 -0
  33. {onnx_ir-0.1.4 → onnx_ir-0.1.5}/src/onnx_ir/passes/__init__.py +0 -0
  34. {onnx_ir-0.1.4 → onnx_ir-0.1.5}/src/onnx_ir/passes/_pass_infra.py +0 -0
  35. {onnx_ir-0.1.4 → onnx_ir-0.1.5}/src/onnx_ir/passes/common/_c_api_utils.py +0 -0
  36. {onnx_ir-0.1.4 → onnx_ir-0.1.5}/src/onnx_ir/passes/common/clear_metadata_and_docstring.py +0 -0
  37. {onnx_ir-0.1.4 → onnx_ir-0.1.5}/src/onnx_ir/passes/common/common_subexpression_elimination.py +0 -0
  38. {onnx_ir-0.1.4 → onnx_ir-0.1.5}/src/onnx_ir/passes/common/constant_manipulation.py +0 -0
  39. {onnx_ir-0.1.4 → onnx_ir-0.1.5}/src/onnx_ir/passes/common/initializer_deduplication.py +0 -0
  40. {onnx_ir-0.1.4 → onnx_ir-0.1.5}/src/onnx_ir/passes/common/inliner.py +0 -0
  41. {onnx_ir-0.1.4 → onnx_ir-0.1.5}/src/onnx_ir/passes/common/onnx_checker.py +0 -0
  42. {onnx_ir-0.1.4 → onnx_ir-0.1.5}/src/onnx_ir/passes/common/shape_inference.py +0 -0
  43. {onnx_ir-0.1.4 → onnx_ir-0.1.5}/src/onnx_ir/passes/common/topological_sort.py +0 -0
  44. {onnx_ir-0.1.4 → onnx_ir-0.1.5}/src/onnx_ir/passes/common/unused_removal.py +0 -0
  45. {onnx_ir-0.1.4 → onnx_ir-0.1.5}/src/onnx_ir/py.typed +0 -0
  46. {onnx_ir-0.1.4 → onnx_ir-0.1.5}/src/onnx_ir/tape.py +0 -0
  47. {onnx_ir-0.1.4 → onnx_ir-0.1.5}/src/onnx_ir/testing.py +0 -0
  48. {onnx_ir-0.1.4 → onnx_ir-0.1.5}/src/onnx_ir/traversal.py +0 -0
  49. {onnx_ir-0.1.4 → onnx_ir-0.1.5}/src/onnx_ir.egg-info/dependency_links.txt +0 -0
  50. {onnx_ir-0.1.4 → onnx_ir-0.1.5}/src/onnx_ir.egg-info/requires.txt +0 -0
  51. {onnx_ir-0.1.4 → onnx_ir-0.1.5}/src/onnx_ir.egg-info/top_level.txt +0 -0
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.4
2
2
  Name: onnx-ir
3
- Version: 0.1.4
3
+ Version: 0.1.5
4
4
  Summary: Efficient in-memory representation for ONNX
5
5
  Author-email: ONNX Contributors <onnx-technical-discuss@lists.lfaidata.foundation>
6
6
  License: Apache License v2.0
@@ -30,6 +30,7 @@ Dynamic: license-file
30
30
  [![Ruff](https://img.shields.io/endpoint?url=https://raw.githubusercontent.com/astral-sh/ruff/main/assets/badge/v2.json)](https://github.com/astral-sh/ruff)
31
31
  [![codecov](https://codecov.io/gh/onnx/ir-py/graph/badge.svg?token=SPQ3G9T78Z)](https://codecov.io/gh/onnx/ir-py)
32
32
  [![DeepWiki](https://img.shields.io/badge/DeepWiki-onnx%2Fir--py-blue.svg?logo=)](https://deepwiki.com/onnx/ir-py)
33
+ [![PyPI Downloads](https://static.pepy.tech/badge/onnx-ir/month)](https://pepy.tech/projects/onnx-ir)
33
34
 
34
35
  An in-memory IR that supports the full ONNX spec, designed for graph construction, analysis and transformation.
35
36
 
@@ -5,6 +5,7 @@
5
5
  [![Ruff](https://img.shields.io/endpoint?url=https://raw.githubusercontent.com/astral-sh/ruff/main/assets/badge/v2.json)](https://github.com/astral-sh/ruff)
6
6
  [![codecov](https://codecov.io/gh/onnx/ir-py/graph/badge.svg?token=SPQ3G9T78Z)](https://codecov.io/gh/onnx/ir-py)
7
7
  [![DeepWiki](https://img.shields.io/badge/DeepWiki-onnx%2Fir--py-blue.svg?logo=)](https://deepwiki.com/onnx/ir-py)
8
+ [![PyPI Downloads](https://static.pepy.tech/badge/onnx-ir/month)](https://pepy.tech/projects/onnx-ir)
8
9
 
9
10
  An in-memory IR that supports the full ONNX spec, designed for graph construction, analysis and transformation.
10
11
 
@@ -167,4 +167,4 @@ def __set_module() -> None:
167
167
 
168
168
 
169
169
  __set_module()
170
- __version__ = "0.1.4"
170
+ __version__ = "0.1.5"
@@ -78,6 +78,7 @@ _NON_NUMPY_NATIVE_TYPES = frozenset(
78
78
  _enums.DataType.FLOAT8E4M3FNUZ,
79
79
  _enums.DataType.FLOAT8E5M2,
80
80
  _enums.DataType.FLOAT8E5M2FNUZ,
81
+ _enums.DataType.FLOAT8E8M0,
81
82
  _enums.DataType.INT4,
82
83
  _enums.DataType.UINT4,
83
84
  _enums.DataType.FLOAT4E2M1,
@@ -261,6 +262,7 @@ def _check_numpy_representation_type(array: np.ndarray, dtype: _enums.DataType)
261
262
  ml_dtypes.float8_e4m3fn,
262
263
  ml_dtypes.float8_e5m2fnuz,
263
264
  ml_dtypes.float8_e5m2,
265
+ ml_dtypes.float8_e8m0fnu,
264
266
  ):
265
267
  raise TypeError(
266
268
  f"The numpy array dtype must be uint8 or ml_dtypes.float8* (not {array.dtype}) for IR data type {dtype}."
@@ -319,6 +321,8 @@ def _maybe_view_np_array_with_ml_dtypes(
319
321
  return array.view(ml_dtypes.float8_e5m2)
320
322
  if dtype == _enums.DataType.FLOAT8E5M2FNUZ:
321
323
  return array.view(ml_dtypes.float8_e5m2fnuz)
324
+ if dtype == _enums.DataType.FLOAT8E8M0:
325
+ return array.view(ml_dtypes.float8_e8m0fnu)
322
326
  if dtype == _enums.DataType.INT4:
323
327
  return array.view(ml_dtypes.int4)
324
328
  if dtype == _enums.DataType.UINT4:
@@ -65,6 +65,7 @@ class DataType(enum.IntEnum):
65
65
  UINT4 = 21
66
66
  INT4 = 22
67
67
  FLOAT4E2M1 = 23
68
+ FLOAT8E8M0 = 24
68
69
 
69
70
  @classmethod
70
71
  def from_numpy(cls, dtype: np.dtype) -> DataType:
@@ -81,6 +82,7 @@ class DataType(enum.IntEnum):
81
82
 
82
83
  # Special cases for handling custom dtypes defined in ONNX (as of onnx 1.18)
83
84
  # Ref: https://github.com/onnx/onnx/blob/2d42b6a60a52e925e57c422593e88cc51890f58a/onnx/_custom_element_types.py
85
+ # TODO(#137): Remove this when ONNX 1.19 is the minimum requirement
84
86
  if hasattr(dtype, "names"):
85
87
  if dtype.names == ("bfloat16",):
86
88
  return DataType.BFLOAT16
@@ -167,6 +169,7 @@ class DataType(enum.IntEnum):
167
169
  DataType.FLOAT8E5M2,
168
170
  DataType.FLOAT8E5M2FNUZ,
169
171
  DataType.FLOAT4E2M1,
172
+ DataType.FLOAT8E8M0,
170
173
  }
171
174
 
172
175
  def is_integer(self) -> bool:
@@ -209,6 +212,7 @@ class DataType(enum.IntEnum):
209
212
  DataType.FLOAT8E5M2FNUZ,
210
213
  DataType.INT4,
211
214
  DataType.FLOAT4E2M1,
215
+ DataType.FLOAT8E8M0,
212
216
  }
213
217
 
214
218
  def __repr__(self) -> str:
@@ -241,6 +245,7 @@ _BITWIDTH_MAP = {
241
245
  DataType.UINT4: 4,
242
246
  DataType.INT4: 4,
243
247
  DataType.FLOAT4E2M1: 4,
248
+ DataType.FLOAT8E8M0: 8,
244
249
  }
245
250
 
246
251
 
@@ -266,6 +271,7 @@ _NP_TYPE_TO_DATA_TYPE = {
266
271
  np.dtype(ml_dtypes.float8_e4m3fnuz): DataType.FLOAT8E4M3FNUZ,
267
272
  np.dtype(ml_dtypes.float8_e5m2): DataType.FLOAT8E5M2,
268
273
  np.dtype(ml_dtypes.float8_e5m2fnuz): DataType.FLOAT8E5M2FNUZ,
274
+ np.dtype(ml_dtypes.float8_e8m0fnu): DataType.FLOAT8E8M0,
269
275
  np.dtype(ml_dtypes.int4): DataType.INT4,
270
276
  np.dtype(ml_dtypes.uint4): DataType.UINT4,
271
277
  }
@@ -290,6 +296,7 @@ _DATA_TYPE_TO_SHORT_NAME = {
290
296
  DataType.FLOAT8E5M2: "f8e5m2",
291
297
  DataType.FLOAT8E4M3FNUZ: "f8e4m3fnuz",
292
298
  DataType.FLOAT8E5M2FNUZ: "f8e5m2fnuz",
299
+ DataType.FLOAT8E8M0: "f8e8m0",
293
300
  DataType.FLOAT4E2M1: "f4e2m1",
294
301
  DataType.COMPLEX64: "c64",
295
302
  DataType.COMPLEX128: "c128",
@@ -7,6 +7,7 @@ __all__ = [
7
7
  "ClearMetadataAndDocStringPass",
8
8
  "CommonSubexpressionEliminationPass",
9
9
  "DeduplicateInitializersPass",
10
+ "IdentityEliminationPass",
10
11
  "InlinePass",
11
12
  "LiftConstantsToInitializersPass",
12
13
  "LiftSubgraphInitializersToMainGraphPass",
@@ -30,6 +31,9 @@ from onnx_ir.passes.common.constant_manipulation import (
30
31
  LiftSubgraphInitializersToMainGraphPass,
31
32
  RemoveInitializersFromInputsPass,
32
33
  )
34
+ from onnx_ir.passes.common.identity_elimination import (
35
+ IdentityEliminationPass,
36
+ )
33
37
  from onnx_ir.passes.common.initializer_deduplication import (
34
38
  DeduplicateInitializersPass,
35
39
  )
@@ -0,0 +1,97 @@
1
+ # Copyright (c) ONNX Project Contributors
2
+ # SPDX-License-Identifier: Apache-2.0
3
+ """Identity elimination pass for removing redundant Identity nodes."""
4
+
5
+ from __future__ import annotations
6
+
7
+ __all__ = [
8
+ "IdentityEliminationPass",
9
+ ]
10
+
11
+ import logging
12
+
13
+ import onnx_ir as ir
14
+
15
+ logger = logging.getLogger(__name__)
16
+
17
+
18
+ class IdentityEliminationPass(ir.passes.InPlacePass):
19
+ """Pass for eliminating redundant Identity nodes.
20
+
21
+ This pass removes Identity nodes according to the following rules:
22
+ 1. For any node of the form `y = Identity(x)`, where `y` is not an output
23
+ of any graph, replace all uses of `y` with a use of `x`, and remove the node.
24
+ 2. If `y` is an output of a graph, and `x` is not an input of any graph,
25
+ we can still do the elimination, but the value `x` should be renamed to be `y`.
26
+ 3. If `y` is a graph-output and `x` is a graph-input, we cannot eliminate
27
+ the node. It should be retained.
28
+ """
29
+
30
+ def call(self, model: ir.Model) -> ir.passes.PassResult:
31
+ """Main entry point for the identity elimination pass."""
32
+ modified = False
33
+
34
+ # Use RecursiveGraphIterator to process all nodes in the model graph and subgraphs
35
+ for node in ir.traversal.RecursiveGraphIterator(model.graph):
36
+ if self._try_eliminate_identity_node(node):
37
+ modified = True
38
+
39
+ # Process nodes in functions
40
+ for function in model.functions.values():
41
+ for node in ir.traversal.RecursiveGraphIterator(function):
42
+ if self._try_eliminate_identity_node(node):
43
+ modified = True
44
+
45
+ if modified:
46
+ logger.info("Identity elimination pass modified the model")
47
+
48
+ return ir.passes.PassResult(model, modified=modified)
49
+
50
+ def _try_eliminate_identity_node(self, node: ir.Node) -> bool:
51
+ """Try to eliminate a single identity node. Returns True if modified."""
52
+ if node.op_type != "Identity" or node.domain != "":
53
+ return False
54
+
55
+ if len(node.inputs) != 1 or len(node.outputs) != 1:
56
+ # Invalid Identity node, skip
57
+ return False
58
+
59
+ input_value = node.inputs[0]
60
+ output_value = node.outputs[0]
61
+
62
+ if input_value is None:
63
+ # Cannot eliminate if input is None
64
+ return False
65
+
66
+ # Get the graph that contains this node
67
+ graph_like = node.graph
68
+ assert graph_like is not None, "Node must be in a graph"
69
+
70
+ output_is_graph_output = output_value.is_graph_output()
71
+ input_is_graph_input = input_value.is_graph_input()
72
+
73
+ # Case 3: Both output is graph output and input is graph input - keep the node
74
+ if output_is_graph_output and input_is_graph_input:
75
+ return False
76
+
77
+ # Case 1 & 2 (merged): Eliminate the identity node
78
+ # Replace all uses of output with input
79
+ ir.convenience.replace_all_uses_with(output_value, input_value)
80
+
81
+ # If output is a graph output, we need to rename input and update graph outputs
82
+ if output_is_graph_output:
83
+ # Store the original output name
84
+ original_output_name = output_value.name
85
+
86
+ # Update the input value to have the output's name
87
+ input_value.name = original_output_name
88
+
89
+ # Update graph outputs to point to the input value
90
+ for idx, graph_output in enumerate(graph_like.outputs):
91
+ if graph_output is output_value:
92
+ graph_like.outputs[idx] = input_value
93
+
94
+ # Remove the identity node
95
+ graph_like.remove(node, safe=True)
96
+ logger.debug("Eliminated identity node: %s", node)
97
+ return True
@@ -405,6 +405,7 @@ class TensorProtoTensor(_core.TensorBase): # pylint: disable=too-many-ancestors
405
405
  _enums.DataType.FLOAT8E4M3FNUZ,
406
406
  _enums.DataType.FLOAT8E5M2,
407
407
  _enums.DataType.FLOAT8E5M2FNUZ,
408
+ _enums.DataType.FLOAT8E8M0,
408
409
  _enums.DataType.INT16,
409
410
  _enums.DataType.INT32,
410
411
  _enums.DataType.INT4,
@@ -505,6 +506,7 @@ class TensorProtoTensor(_core.TensorBase): # pylint: disable=too-many-ancestors
505
506
  _enums.DataType.FLOAT8E4M3FNUZ,
506
507
  _enums.DataType.FLOAT8E5M2,
507
508
  _enums.DataType.FLOAT8E5M2FNUZ,
509
+ _enums.DataType.FLOAT8E8M0,
508
510
  _enums.DataType.INT4,
509
511
  _enums.DataType.UINT4,
510
512
  _enums.DataType.FLOAT4E2M1,
@@ -68,6 +68,7 @@ def from_torch_dtype(dtype: torch.dtype) -> ir.DataType:
68
68
  torch.float8_e4m3fnuz: ir.DataType.FLOAT8E4M3FNUZ,
69
69
  torch.float8_e5m2: ir.DataType.FLOAT8E5M2,
70
70
  torch.float8_e5m2fnuz: ir.DataType.FLOAT8E5M2FNUZ,
71
+ torch.float8_e8m0fnu: ir.DataType.FLOAT8E8M0,
71
72
  torch.int16: ir.DataType.INT16,
72
73
  torch.int32: ir.DataType.INT32,
73
74
  torch.int64: ir.DataType.INT64,
@@ -104,6 +105,7 @@ def to_torch_dtype(dtype: ir.DataType) -> torch.dtype:
104
105
  ir.DataType.FLOAT8E4M3FNUZ: torch.float8_e4m3fnuz,
105
106
  ir.DataType.FLOAT8E5M2: torch.float8_e5m2,
106
107
  ir.DataType.FLOAT8E5M2FNUZ: torch.float8_e5m2fnuz,
108
+ ir.DataType.FLOAT8E8M0: torch.float8_e8m0fnu,
107
109
  ir.DataType.INT16: torch.int16,
108
110
  ir.DataType.INT32: torch.int32,
109
111
  ir.DataType.INT64: torch.int64,
@@ -142,6 +144,7 @@ class TorchTensor(_core.Tensor):
142
144
  ir.DataType.FLOAT8E4M3FNUZ,
143
145
  ir.DataType.FLOAT8E5M2,
144
146
  ir.DataType.FLOAT8E5M2FNUZ,
147
+ ir.DataType.FLOAT8E8M0,
145
148
  }:
146
149
  return self.raw.view(torch.uint8).numpy(force=True).view(self.dtype.numpy())
147
150
 
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.4
2
2
  Name: onnx-ir
3
- Version: 0.1.4
3
+ Version: 0.1.5
4
4
  Summary: Efficient in-memory representation for ONNX
5
5
  Author-email: ONNX Contributors <onnx-technical-discuss@lists.lfaidata.foundation>
6
6
  License: Apache License v2.0
@@ -30,6 +30,7 @@ Dynamic: license-file
30
30
  [![Ruff](https://img.shields.io/endpoint?url=https://raw.githubusercontent.com/astral-sh/ruff/main/assets/badge/v2.json)](https://github.com/astral-sh/ruff)
31
31
  [![codecov](https://codecov.io/gh/onnx/ir-py/graph/badge.svg?token=SPQ3G9T78Z)](https://codecov.io/gh/onnx/ir-py)
32
32
  [![DeepWiki](https://img.shields.io/badge/DeepWiki-onnx%2Fir--py-blue.svg?logo=)](https://deepwiki.com/onnx/ir-py)
33
+ [![PyPI Downloads](https://static.pepy.tech/badge/onnx-ir/month)](https://pepy.tech/projects/onnx-ir)
33
34
 
34
35
  An in-memory IR that supports the full ONNX spec, designed for graph construction, analysis and transformation.
35
36
 
@@ -40,6 +40,7 @@ src/onnx_ir/passes/common/_c_api_utils.py
40
40
  src/onnx_ir/passes/common/clear_metadata_and_docstring.py
41
41
  src/onnx_ir/passes/common/common_subexpression_elimination.py
42
42
  src/onnx_ir/passes/common/constant_manipulation.py
43
+ src/onnx_ir/passes/common/identity_elimination.py
43
44
  src/onnx_ir/passes/common/initializer_deduplication.py
44
45
  src/onnx_ir/passes/common/inliner.py
45
46
  src/onnx_ir/passes/common/onnx_checker.py
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes