onnx-ir 0.1.5.tar.gz → 0.1.7.tar.gz

This diff shows the content of publicly released versions of the package as published to a supported registry. It is provided for informational purposes only and reflects the changes between the two versions as they appear in the registry.


Files changed (53)
  1. {onnx_ir-0.1.5/src/onnx_ir.egg-info → onnx_ir-0.1.7}/PKG-INFO +2 -4
  2. {onnx_ir-0.1.5 → onnx_ir-0.1.7}/README.md +0 -1
  3. {onnx_ir-0.1.5 → onnx_ir-0.1.7}/pyproject.toml +3 -3
  4. {onnx_ir-0.1.5 → onnx_ir-0.1.7}/src/onnx_ir/__init__.py +1 -1
  5. {onnx_ir-0.1.5 → onnx_ir-0.1.7}/src/onnx_ir/_convenience/__init__.py +49 -32
  6. {onnx_ir-0.1.5 → onnx_ir-0.1.7}/src/onnx_ir/_core.py +60 -16
  7. {onnx_ir-0.1.5 → onnx_ir-0.1.7}/src/onnx_ir/passes/common/__init__.py +4 -0
  8. {onnx_ir-0.1.5 → onnx_ir-0.1.7}/src/onnx_ir/passes/common/constant_manipulation.py +15 -7
  9. {onnx_ir-0.1.5 → onnx_ir-0.1.7}/src/onnx_ir/passes/common/identity_elimination.py +1 -0
  10. onnx_ir-0.1.7/src/onnx_ir/passes/common/initializer_deduplication.py +167 -0
  11. onnx_ir-0.1.7/src/onnx_ir/passes/common/naming.py +286 -0
  12. {onnx_ir-0.1.5 → onnx_ir-0.1.7}/src/onnx_ir/serde.py +94 -34
  13. {onnx_ir-0.1.5 → onnx_ir-0.1.7}/src/onnx_ir/tensor_adapters.py +14 -2
  14. {onnx_ir-0.1.5 → onnx_ir-0.1.7}/src/onnx_ir/traversal.py +35 -0
  15. {onnx_ir-0.1.5 → onnx_ir-0.1.7/src/onnx_ir.egg-info}/PKG-INFO +2 -4
  16. {onnx_ir-0.1.5 → onnx_ir-0.1.7}/src/onnx_ir.egg-info/SOURCES.txt +1 -0
  17. onnx_ir-0.1.5/src/onnx_ir/passes/common/initializer_deduplication.py +0 -56
  18. {onnx_ir-0.1.5 → onnx_ir-0.1.7}/LICENSE +0 -0
  19. {onnx_ir-0.1.5 → onnx_ir-0.1.7}/MANIFEST.in +0 -0
  20. {onnx_ir-0.1.5 → onnx_ir-0.1.7}/setup.cfg +0 -0
  21. {onnx_ir-0.1.5 → onnx_ir-0.1.7}/src/onnx_ir/_convenience/_constructors.py +0 -0
  22. {onnx_ir-0.1.5 → onnx_ir-0.1.7}/src/onnx_ir/_display.py +0 -0
  23. {onnx_ir-0.1.5 → onnx_ir-0.1.7}/src/onnx_ir/_enums.py +0 -0
  24. {onnx_ir-0.1.5 → onnx_ir-0.1.7}/src/onnx_ir/_graph_comparison.py +0 -0
  25. {onnx_ir-0.1.5 → onnx_ir-0.1.7}/src/onnx_ir/_graph_containers.py +0 -0
  26. {onnx_ir-0.1.5 → onnx_ir-0.1.7}/src/onnx_ir/_io.py +0 -0
  27. {onnx_ir-0.1.5 → onnx_ir-0.1.7}/src/onnx_ir/_linked_list.py +0 -0
  28. {onnx_ir-0.1.5 → onnx_ir-0.1.7}/src/onnx_ir/_metadata.py +0 -0
  29. {onnx_ir-0.1.5 → onnx_ir-0.1.7}/src/onnx_ir/_name_authority.py +0 -0
  30. {onnx_ir-0.1.5 → onnx_ir-0.1.7}/src/onnx_ir/_polyfill.py +0 -0
  31. {onnx_ir-0.1.5 → onnx_ir-0.1.7}/src/onnx_ir/_protocols.py +0 -0
  32. {onnx_ir-0.1.5 → onnx_ir-0.1.7}/src/onnx_ir/_tape.py +0 -0
  33. {onnx_ir-0.1.5 → onnx_ir-0.1.7}/src/onnx_ir/_thirdparty/asciichartpy.py +0 -0
  34. {onnx_ir-0.1.5 → onnx_ir-0.1.7}/src/onnx_ir/_type_casting.py +0 -0
  35. {onnx_ir-0.1.5 → onnx_ir-0.1.7}/src/onnx_ir/_version_utils.py +0 -0
  36. {onnx_ir-0.1.5 → onnx_ir-0.1.7}/src/onnx_ir/convenience.py +0 -0
  37. {onnx_ir-0.1.5 → onnx_ir-0.1.7}/src/onnx_ir/external_data.py +0 -0
  38. {onnx_ir-0.1.5 → onnx_ir-0.1.7}/src/onnx_ir/passes/__init__.py +0 -0
  39. {onnx_ir-0.1.5 → onnx_ir-0.1.7}/src/onnx_ir/passes/_pass_infra.py +0 -0
  40. {onnx_ir-0.1.5 → onnx_ir-0.1.7}/src/onnx_ir/passes/common/_c_api_utils.py +0 -0
  41. {onnx_ir-0.1.5 → onnx_ir-0.1.7}/src/onnx_ir/passes/common/clear_metadata_and_docstring.py +0 -0
  42. {onnx_ir-0.1.5 → onnx_ir-0.1.7}/src/onnx_ir/passes/common/common_subexpression_elimination.py +0 -0
  43. {onnx_ir-0.1.5 → onnx_ir-0.1.7}/src/onnx_ir/passes/common/inliner.py +0 -0
  44. {onnx_ir-0.1.5 → onnx_ir-0.1.7}/src/onnx_ir/passes/common/onnx_checker.py +0 -0
  45. {onnx_ir-0.1.5 → onnx_ir-0.1.7}/src/onnx_ir/passes/common/shape_inference.py +0 -0
  46. {onnx_ir-0.1.5 → onnx_ir-0.1.7}/src/onnx_ir/passes/common/topological_sort.py +0 -0
  47. {onnx_ir-0.1.5 → onnx_ir-0.1.7}/src/onnx_ir/passes/common/unused_removal.py +0 -0
  48. {onnx_ir-0.1.5 → onnx_ir-0.1.7}/src/onnx_ir/py.typed +0 -0
  49. {onnx_ir-0.1.5 → onnx_ir-0.1.7}/src/onnx_ir/tape.py +0 -0
  50. {onnx_ir-0.1.5 → onnx_ir-0.1.7}/src/onnx_ir/testing.py +0 -0
  51. {onnx_ir-0.1.5 → onnx_ir-0.1.7}/src/onnx_ir.egg-info/dependency_links.txt +0 -0
  52. {onnx_ir-0.1.5 → onnx_ir-0.1.7}/src/onnx_ir.egg-info/requires.txt +0 -0
  53. {onnx_ir-0.1.5 → onnx_ir-0.1.7}/src/onnx_ir.egg-info/top_level.txt +0 -0

PKG-INFO
@@ -1,9 +1,9 @@
  Metadata-Version: 2.4
  Name: onnx-ir
- Version: 0.1.5
+ Version: 0.1.7
  Summary: Efficient in-memory representation for ONNX
  Author-email: ONNX Contributors <onnx-technical-discuss@lists.lfaidata.foundation>
- License: Apache License v2.0
+ License-Expression: Apache-2.0
  Project-URL: Homepage, https://onnx.ai/ir-py
  Project-URL: Issues, https://github.com/onnx/ir-py/issues
  Project-URL: Repository, https://github.com/onnx/ir-py
@@ -13,7 +13,6 @@ Classifier: Programming Language :: Python :: 3.10
  Classifier: Programming Language :: Python :: 3.11
  Classifier: Programming Language :: Python :: 3.12
  Classifier: Programming Language :: Python :: 3.13
- Classifier: License :: OSI Approved :: Apache Software License
  Requires-Python: >=3.9
  Description-Content-Type: text/markdown
  License-File: LICENSE
@@ -29,7 +28,6 @@ Dynamic: license-file
  [![PyPI - Python Version](https://img.shields.io/pypi/pyversions/onnx-ir.svg)](https://pypi.org/project/onnx-ir)
  [![Ruff](https://img.shields.io/endpoint?url=https://raw.githubusercontent.com/astral-sh/ruff/main/assets/badge/v2.json)](https://github.com/astral-sh/ruff)
  [![codecov](https://codecov.io/gh/onnx/ir-py/graph/badge.svg?token=SPQ3G9T78Z)](https://codecov.io/gh/onnx/ir-py)
- [![DeepWiki](https://img.shields.io/badge/DeepWiki-onnx%2Fir--py-blue.svg?logo=data:image/png;base64,iVBORw0KGgoAAAANSUhEUgAAACwAAAAyCAYAAAAnWDnqAAAAAXNSR0IArs4c6QAAA05JREFUaEPtmUtyEzEQhtWTQyQLHNak2AB7ZnyXZMEjXMGeK/AIi+QuHrMnbChYY7MIh8g01fJoopFb0uhhEqqcbWTp06/uv1saEDv4O3n3dV60RfP947Mm9/SQc0ICFQgzfc4CYZoTPAswgSJCCUJUnAAoRHOAUOcATwbmVLWdGoH//PB8mnKqScAhsD0kYP3j/Yt5LPQe2KvcXmGvRHcDnpxfL2zOYJ1mFwrryWTz0advv1Ut4CJgf5uhDuDj5eUcAUoahrdY/56ebRWeraTjMt/00Sh3UDtjgHtQNHwcRGOC98BJEAEymycmYcWwOprTgcB6VZ5JK5TAJ+fXGLBm3FDAmn6oPPjR4rKCAoJCal2eAiQp2x0vxTPB3ALO2CRkwmDy5WohzBDwSEFKRwPbknEggCPB/imwrycgxX2NzoMCHhPkDwqYMr9tRcP5qNrMZHkVnOjRMWwLCcr8ohBVb1OMjxLwGCvjTikrsBOiA6fNyCrm8V1rP93iVPpwaE+gO0SsWmPiXB+jikdf6SizrT5qKasx5j8ABbHpFTx+vFXp9EnYQmLx02h1QTTrl6eDqxLnGjporxl3NL3agEvXdT0WmEost648sQOYAeJS9Q7bfUVoMGnjo4AZdUMQku50McDcMWcBPvr0SzbTAFDfvJqwLzgxwATnCgnp4wDl6Aa+Ax283gghmj+vj7feE2KBBRMW3FzOpLOADl0Isb5587h/U4gGvkt5v60Z1VLG8BhYjbzRwyQZemwAd6cCR5/XFWLYZRIMpX39AR0tjaGGiGzLVyhse5C9RKC6ai42ppWPKiBagOvaYk8lO7DajerabOZP46Lby5wKjw1HCRx7p9sVMOWGzb/vA1hwiWc6jm3MvQDTogQkiqIhJV0nBQBTU+3okKCFDy9WwferkHjtxib7t3xIUQtHxnIwtx4mpg26/HfwVNVDb4oI9RHmx5WGelRVlrtiw43zboCLaxv46AZeB3IlTkwouebTr1y2NjSpHz68WNFjHvupy3q8TFn3Hos2IAk4Ju5dCo8B3wP7VPr/FGaKiG+T+v+TQqIrOqMTL1VdWV1DdmcbO8KXBz6esmYWYKPwDL5b5FA1a0hwapHiom0r/cKaoqr+27/XcrS5UwSMbQAAAABJRU5ErkJggg==)](https://deepwiki.com/onnx/ir-py)
  [![PyPI Downloads](https://static.pepy.tech/badge/onnx-ir/month)](https://pepy.tech/projects/onnx-ir)
 
  An in-memory IR that supports the full ONNX spec, designed for graph construction, analysis and transformation.

README.md
@@ -4,7 +4,6 @@
  [![PyPI - Python Version](https://img.shields.io/pypi/pyversions/onnx-ir.svg)](https://pypi.org/project/onnx-ir)
  [![Ruff](https://img.shields.io/endpoint?url=https://raw.githubusercontent.com/astral-sh/ruff/main/assets/badge/v2.json)](https://github.com/astral-sh/ruff)
  [![codecov](https://codecov.io/gh/onnx/ir-py/graph/badge.svg?token=SPQ3G9T78Z)](https://codecov.io/gh/onnx/ir-py)
- [![DeepWiki](https://img.shields.io/badge/DeepWiki-onnx%2Fir--py-blue.svg?logo=data:image/png;base64,iVBORw0KGgoAAAANSUhEUgAAACwAAAAyCAYAAAAnWDnqAAAAAXNSR0IArs4c6QAAA05JREFUaEPtmUtyEzEQhtWTQyQLHNak2AB7ZnyXZMEjXMGeK/AIi+QuHrMnbChYY7MIh8g01fJoopFb0uhhEqqcbWTp06/uv1saEDv4O3n3dV60RfP947Mm9/SQc0ICFQgzfc4CYZoTPAswgSJCCUJUnAAoRHOAUOcATwbmVLWdGoH//PB8mnKqScAhsD0kYP3j/Yt5LPQe2KvcXmGvRHcDnpxfL2zOYJ1mFwrryWTz0advv1Ut4CJgf5uhDuDj5eUcAUoahrdY/56ebRWeraTjMt/00Sh3UDtjgHtQNHwcRGOC98BJEAEymycmYcWwOprTgcB6VZ5JK5TAJ+fXGLBm3FDAmn6oPPjR4rKCAoJCal2eAiQp2x0vxTPB3ALO2CRkwmDy5WohzBDwSEFKRwPbknEggCPB/imwrycgxX2NzoMCHhPkDwqYMr9tRcP5qNrMZHkVnOjRMWwLCcr8ohBVb1OMjxLwGCvjTikrsBOiA6fNyCrm8V1rP93iVPpwaE+gO0SsWmPiXB+jikdf6SizrT5qKasx5j8ABbHpFTx+vFXp9EnYQmLx02h1QTTrl6eDqxLnGjporxl3NL3agEvXdT0WmEost648sQOYAeJS9Q7bfUVoMGnjo4AZdUMQku50McDcMWcBPvr0SzbTAFDfvJqwLzgxwATnCgnp4wDl6Aa+Ax283gghmj+vj7feE2KBBRMW3FzOpLOADl0Isb5587h/U4gGvkt5v60Z1VLG8BhYjbzRwyQZemwAd6cCR5/XFWLYZRIMpX39AR0tjaGGiGzLVyhse5C9RKC6ai42ppWPKiBagOvaYk8lO7DajerabOZP46Lby5wKjw1HCRx7p9sVMOWGzb/vA1hwiWc6jm3MvQDTogQkiqIhJV0nBQBTU+3okKCFDy9WwferkHjtxib7t3xIUQtHxnIwtx4mpg26/HfwVNVDb4oI9RHmx5WGelRVlrtiw43zboCLaxv46AZeB3IlTkwouebTr1y2NjSpHz68WNFjHvupy3q8TFn3Hos2IAk4Ju5dCo8B3wP7VPr/FGaKiG+T+v+TQqIrOqMTL1VdWV1DdmcbO8KXBz6esmYWYKPwDL5b5FA1a0hwapHiom0r/cKaoqr+27/XcrS5UwSMbQAAAABJRU5ErkJggg==)](https://deepwiki.com/onnx/ir-py)
  [![PyPI Downloads](https://static.pepy.tech/badge/onnx-ir/month)](https://pepy.tech/projects/onnx-ir)
 
  An in-memory IR that supports the full ONNX spec, designed for graph construction, analysis and transformation.

pyproject.toml
@@ -1,5 +1,5 @@
  [build-system]
- requires = ["setuptools>=70"]
+ requires = ["setuptools>=77"]
  build-backend = "setuptools.build_meta"
 
  [project]
@@ -11,7 +11,8 @@ authors = [
  ]
  readme = "README.md"
  requires-python = ">=3.9"
- license = {text = "Apache License v2.0"}
+ license = "Apache-2.0"
+ license-files = ["LICEN[CS]E*"]
  classifiers = [
      "Development Status :: 4 - Beta",
      "Programming Language :: Python :: 3.9",
@@ -19,7 +20,6 @@ classifiers = [
      "Programming Language :: Python :: 3.11",
      "Programming Language :: Python :: 3.12",
      "Programming Language :: Python :: 3.13",
-     "License :: OSI Approved :: Apache Software License",
  ]
  dependencies = ["numpy", "onnx>=1.16", "typing_extensions>=4.10", "ml_dtypes"]
 

src/onnx_ir/__init__.py
@@ -167,4 +167,4 @@ def __set_module() -> None:
 
 
  __set_module()
- __version__ = "0.1.5"
+ __version__ = "0.1.7"

src/onnx_ir/_convenience/__init__.py
@@ -58,44 +58,52 @@ def _infer_attribute_type(attr: SupportedAttrTypes) -> _enums.AttributeType:
          return _enums.AttributeType.STRING
      if isinstance(attr, _core.Attr):
          return attr.type
-     if isinstance(attr, Sequence) and all(isinstance(x, int) for x in attr):
-         return _enums.AttributeType.INTS
-     if isinstance(attr, Sequence) and all(isinstance(x, float) for x in attr):
-         return _enums.AttributeType.FLOATS
-     if isinstance(attr, Sequence) and all(isinstance(x, str) for x in attr):
-         return _enums.AttributeType.STRINGS
+     if isinstance(attr, (_core.Graph, onnx.GraphProto, _protocols.GraphProtocol)):
+         return _enums.AttributeType.GRAPH
      if isinstance(attr, (_core.TensorBase, onnx.TensorProto, _protocols.TensorProtocol)):
          # Be sure to check TensorProtocol last because isinstance checking on Protocols can be slower
          return _enums.AttributeType.TENSOR
-     if isinstance(attr, Sequence) and all(
-         isinstance(x, (_core.TensorBase, onnx.TensorProto, _protocols.TensorProtocol))
-         for x in attr
-     ):
-         return _enums.AttributeType.TENSORS
-     if isinstance(attr, (_core.Graph, onnx.GraphProto, _protocols.GraphProtocol)):
-         return _enums.AttributeType.GRAPH
-     if isinstance(attr, Sequence) and all(
-         isinstance(x, (_core.Graph, onnx.GraphProto, _protocols.GraphProtocol)) for x in attr
-     ):
-         return _enums.AttributeType.GRAPHS
      if isinstance(
          attr,
          (_core.TensorType, _core.SequenceType, _core.OptionalType, _protocols.TypeProtocol),
      ):
          return _enums.AttributeType.TYPE_PROTO
-     if isinstance(attr, Sequence) and all(
-         isinstance(
-             x,
-             (
-                 _core.TensorType,
-                 _core.SequenceType,
-                 _core.OptionalType,
-                 _protocols.TypeProtocol,
-             ),
-         )
-         for x in attr
-     ):
-         return _enums.AttributeType.TYPE_PROTOS
+     if isinstance(attr, Sequence):
+         if not attr:
+             logger.warning(
+                 "Attribute type is ambiguous because it is an empty sequence. "
+                 "Please create an Attr with an explicit type. Defaulted to INTS"
+             )
+             return _enums.AttributeType.INTS
+         if all(isinstance(x, int) for x in attr):
+             return _enums.AttributeType.INTS
+         if all(isinstance(x, float) for x in attr):
+             return _enums.AttributeType.FLOATS
+         if all(isinstance(x, str) for x in attr):
+             return _enums.AttributeType.STRINGS
+         if all(
+             isinstance(x, (_core.TensorBase, onnx.TensorProto, _protocols.TensorProtocol))
+             for x in attr
+         ):
+             return _enums.AttributeType.TENSORS
+         if all(
+             isinstance(x, (_core.Graph, onnx.GraphProto, _protocols.GraphProtocol))
+             for x in attr
+         ):
+             return _enums.AttributeType.GRAPHS
+         if all(
+             isinstance(
+                 x,
+                 (
+                     _core.TensorType,
+                     _core.SequenceType,
+                     _core.OptionalType,
+                     _protocols.TypeProtocol,
+                 ),
+             )
+             for x in attr
+         ):
+             return _enums.AttributeType.TYPE_PROTOS
      raise TypeError(f"Unsupported attribute type: '{type(attr)}'")
 
 
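Why the rewritten `_infer_attribute_type` now checks `if not attr:` before any of the per-element tests: `all()` over an empty sequence is vacuously true, so an empty attribute value would otherwise "match" whichever element-type check runs first. The snippet below is plain Python, independent of onnx-ir, illustrating the pitfall the new warning guards against:

```python
# all() over an empty sequence is vacuously True, so every element-type
# check would appear to succeed for an empty attribute value.
empty: list = []
print(all(isinstance(x, int) for x in empty))    # True
print(all(isinstance(x, float) for x in empty))  # True
print(all(isinstance(x, str) for x in empty))    # True

# Hence 0.1.7 special-cases the empty sequence, logs a warning, and defaults
# to INTS; callers that need a different type should build the Attr explicitly.
```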
@@ -218,7 +226,7 @@ def convert_attributes(
      ... "type_protos": [ir.TensorType(ir.DataType.FLOAT), ir.TensorType(ir.DataType.FLOAT)],
      ... }
      >>> convert_attributes(attrs)
-     [Attr('int', INT, 1), Attr('float', FLOAT, 1.0), Attr('str', STRING, 'hello'), Attr('ints', INTS, [1, 2, 3]), Attr('floats', FLOATS, [1.0, 2.0, 3.0]), Attr('strings', STRINGS, ['hello', 'world']), Attr('tensor', TENSOR, Tensor<DOUBLE,[3]>(array([1., 2., 3.]), name=None)), Attr('tensor_proto', TENSOR, TensorProtoTensor<FLOAT,[3]>(array([1., 2., 3.], dtype=float32), name='proto')), Attr('graph', INTS, Graph(
+     [Attr('int', INT, 1), Attr('float', FLOAT, 1.0), Attr('str', STRING, 'hello'), Attr('ints', INTS, [1, 2, 3]), Attr('floats', FLOATS, [1.0, 2.0, 3.0]), Attr('strings', STRINGS, ['hello', 'world']), Attr('tensor', TENSOR, Tensor<DOUBLE,[3]>(array([1., 2., 3.]), name=None)), Attr('tensor_proto', TENSOR, TensorProtoTensor<FLOAT,[3]>(array([1., 2., 3.], dtype=float32), name='proto')), Attr('graph', GRAPH, Graph(
      name='graph0',
      inputs=(
      <BLANKLINE>
@@ -247,11 +255,20 @@ def convert_attributes(
      len()=0
      )]), Attr('type_proto', TYPE_PROTO, Tensor(FLOAT)), Attr('type_protos', TYPE_PROTOS, [Tensor(FLOAT), Tensor(FLOAT)])]
 
+     .. important::
+         An empty sequence should be created with an explicit type by initializing
+         an Attr object with an attribute type to avoid type ambiguity. For example::
+ 
+             ir.Attr("empty", [], type=ir.AttributeType.INTS)
+ 
      Args:
          attrs: A dictionary of {<attribute name>: <python objects>} to convert.
 
      Returns:
-         A list of _core.Attr objects.
+         A list of :class:`_core.Attr` objects.
+ 
+     Raises:
+         TypeError: If an attribute type is not supported.
      """
      attributes: list[_core.Attr] = []
      for name, attr in attrs.items():

src/onnx_ir/_core.py
@@ -2564,14 +2564,23 @@ class Graph(_protocols.GraphProtocol, Sequence[Node], _display.PrettyPrintable):
 
          .. versionadded:: 0.1.2
          """
-         seen_graphs: set[Graph] = set()
-         for node in onnx_ir.traversal.RecursiveGraphIterator(self):
-             graph = node.graph
+         # Use a dict to preserve order
+         seen_graphs: dict[Graph, None] = {}
+ 
+         # Need to use the enter_graph callback so that empty subgraphs are collected
+         def enter_subgraph(graph) -> None:
              if graph is self:
-                 continue
-             if graph is not None and graph not in seen_graphs:
-                 seen_graphs.add(graph)
-                 yield graph
+                 return
+             if not isinstance(graph, Graph):
+                 raise TypeError(
+                     f"Expected a Graph, got {type(graph)}. The model may be invalid"
+                 )
+             if graph not in seen_graphs:
+                 seen_graphs[graph] = None
+ 
+         for _ in onnx_ir.traversal.RecursiveGraphIterator(self, enter_graph=enter_subgraph):
+             pass
+         yield from seen_graphs.keys()
 
      # Mutation methods
      def append(self, node: Node, /) -> None:
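The practical effect of this rewrite: `Graph.subgraphs()` used to discover subgraphs via `node.graph` of the nodes it iterated, so a subgraph containing no nodes (for example an empty branch of an `If`) was never yielded. Hooking `RecursiveGraphIterator`'s `enter_graph` callback collects those as well, and the dict preserves insertion order. A minimal usage sketch; the model path is hypothetical:

```python
import onnx_ir as ir

model = ir.load("model.onnx")  # hypothetical input path

# As of 0.1.7 this also yields subgraphs that contain no nodes.
for subgraph in model.graph.subgraphs():
    print(subgraph.name, "nodes:", len(subgraph))
```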
@@ -3180,6 +3189,21 @@ class Function(_protocols.FunctionProtocol, Sequence[Node], _display.PrettyPrint
      def attributes(self) -> _graph_containers.Attributes:
          return self._attributes
 
+     @property
+     def graph(self) -> Graph:
+         """The underlying Graph object that contains the nodes of this function.
+ 
+         Only use this graph for identity comparison::
+ 
+             if value.graph is function.graph:
+                 # Do something with the value that belongs to this function
+ 
+         Otherwise use the Function object directly to access the nodes and other properties.
+ 
+         .. versionadded:: 0.1.7
+         """
+         return self._graph
+ 
      @typing.overload
      def __getitem__(self, index: int) -> Node: ...
      @typing.overload
@@ -3240,14 +3264,22 @@ class Function(_protocols.FunctionProtocol, Sequence[Node], _display.PrettyPrint
 
          .. versionadded:: 0.1.2
          """
-         seen_graphs: set[Graph] = set()
-         for node in onnx_ir.traversal.RecursiveGraphIterator(self):
-             graph = node.graph
-             if graph is self._graph:
-                 continue
-             if graph is not None and graph not in seen_graphs:
-                 seen_graphs.add(graph)
-                 yield graph
+         seen_graphs: dict[Graph, None] = {}
+ 
+         # Need to use the enter_graph callback so that empty subgraphs are collected
+         def enter_subgraph(graph) -> None:
+             if graph is self:
+                 return
+             if not isinstance(graph, Graph):
+                 raise TypeError(
+                     f"Expected a Graph, got {type(graph)}. The model may be invalid"
+                 )
+             if graph not in seen_graphs:
+                 seen_graphs[graph] = None
+ 
+         for _ in onnx_ir.traversal.RecursiveGraphIterator(self, enter_graph=enter_subgraph):
+             pass
+         yield from seen_graphs.keys()
 
      # Mutation methods
      def append(self, node: Node, /) -> None:
@@ -3349,7 +3381,7 @@ class Attr(
  ):
      """Base class for ONNX attributes or references."""
 
-     __slots__ = ("_name", "_ref_attr_name", "_type", "_value", "doc_string")
+     __slots__ = ("_metadata", "_name", "_ref_attr_name", "_type", "_value", "doc_string")
 
      def __init__(
          self,
@@ -3365,6 +3397,7 @@
          self._value = value
          self._ref_attr_name = ref_attr_name
          self.doc_string = doc_string
+         self._metadata: _metadata.MetadataStore | None = None
 
      @property
      def name(self) -> str:
@@ -3386,6 +3419,17 @@
      def ref_attr_name(self) -> str | None:
          return self._ref_attr_name
 
+     @property
+     def meta(self) -> _metadata.MetadataStore:
+         """The metadata store for intermediate analysis.
+ 
+         Write to the :attr:`metadata_props` if you would like the metadata to be serialized
+         to the ONNX proto.
+         """
+         if self._metadata is None:
+             self._metadata = _metadata.MetadataStore()
+         return self._metadata
+ 
      def is_ref(self) -> bool:
          """Check if this attribute is a reference attribute."""
          return self.ref_attr_name is not None
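`Attr` thereby gains the same lazily created `meta` store that other IR objects expose: it lives only in memory and is intended for analysis passes, while `metadata_props` remains the place for anything that must survive serialization. An illustrative sketch only; the `Attr(name, type, value)` argument order is assumed from the `Attr('ints', INTS, [1, 2, 3])` reprs in the doctest above, and `meta` is treated as a mutable mapping:

```python
import onnx_ir as ir

# Argument order (name, type, value) is an assumption; see the note above.
attr = ir.Attr("alpha", ir.AttributeType.FLOAT, 0.2)

# In-memory only: useful for marking state during a pass, never serialized.
attr.meta["my_pass_visited"] = True

# Anything that must end up in the ONNX proto goes to metadata_props instead,
# as the new docstring points out.
```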

src/onnx_ir/passes/common/__init__.py
@@ -6,11 +6,13 @@ __all__ = [
      "CheckerPass",
      "ClearMetadataAndDocStringPass",
      "CommonSubexpressionEliminationPass",
+     "DeduplicateHashedInitializersPass",
      "DeduplicateInitializersPass",
      "IdentityEliminationPass",
      "InlinePass",
      "LiftConstantsToInitializersPass",
      "LiftSubgraphInitializersToMainGraphPass",
+     "NameFixPass",
      "RemoveInitializersFromInputsPass",
      "RemoveUnusedFunctionsPass",
      "RemoveUnusedNodesPass",
@@ -35,9 +37,11 @@ from onnx_ir.passes.common.identity_elimination import (
      IdentityEliminationPass,
  )
  from onnx_ir.passes.common.initializer_deduplication import (
+     DeduplicateHashedInitializersPass,
      DeduplicateInitializersPass,
  )
  from onnx_ir.passes.common.inliner import InlinePass
+ from onnx_ir.passes.common.naming import NameFixPass
  from onnx_ir.passes.common.onnx_checker import CheckerPass
  from onnx_ir.passes.common.shape_inference import ShapeInferencePass
  from onnx_ir.passes.common.topological_sort import TopologicalSortPass
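With these exports, both new passes are importable straight from `onnx_ir.passes.common`. A small usage sketch; the input path is hypothetical, and `NameFixPass` is shown only as an import because its constructor lives in the new `naming.py`, which this excerpt does not reproduce:

```python
import onnx_ir as ir
from onnx_ir.passes.common import DeduplicateHashedInitializersPass, NameFixPass

model = ir.load("model.onnx")  # hypothetical input path

# Passes are callable and return a PassResult; this one deduplicates
# initializers by hashing their contents (see the new module further below).
result = DeduplicateHashedInitializersPass()(model)
print("modified:", result.modified)
```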

src/onnx_ir/passes/common/constant_manipulation.py
@@ -148,6 +148,7 @@ class LiftSubgraphInitializersToMainGraphPass(ir.passes.InPlacePass):
              if graph is model.graph:
                  continue
              for name in tuple(graph.initializers):
+                 assert name is not None
                  initializer = graph.initializers[name]
                  if initializer.is_graph_input():
                      # Skip the ones that are also graph inputs
@@ -156,17 +157,24 @@ class LiftSubgraphInitializersToMainGraphPass(ir.passes.InPlacePass):
                          initializer.name,
                      )
                      continue
+                 if initializer.is_graph_output():
+                     logger.debug(
+                         "Initializer '%s' is used as output, so it can't be lifted",
+                         initializer.name,
+                     )
+                     continue
                  # Remove the initializer from the subgraph
                  graph.initializers.pop(name)
                  # To avoid name conflicts, we need to rename the initializer
                  # to a unique name in the main graph
-                 if name in registered_initializer_names:
-                     name_count = registered_initializer_names[name]
-                     initializer.name = f"{name}_{name_count}"
-                     registered_initializer_names[name] = name_count + 1
-                 else:
-                     assert initializer.name is not None
-                     registered_initializer_names[initializer.name] = 1
+                 new_name = name
+                 while new_name in model.graph.initializers:
+                     if name in registered_initializer_names:
+                         registered_initializer_names[name] += 1
+                     else:
+                         registered_initializer_names[name] = 1
+                     new_name = f"{name}_{registered_initializer_names[name]}"
+                 initializer.name = new_name
                  model.graph.register_initializer(initializer)
                  count += 1
          logger.debug(
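The renaming change above fixes a collision case: the old counter only tracked names the pass itself had already lifted, so a lifted initializer could still clash with a name that already existed in the main graph. The new loop keeps bumping the numeric suffix until the candidate name is actually free. A standalone sketch of that logic (a hypothetical helper, not part of the pass):

```python
def _unique_name(name: str, existing: set[str], counts: dict[str, int]) -> str:
    """Bump a numeric suffix until `name` no longer collides with `existing`."""
    new_name = name
    while new_name in existing:
        counts[name] = counts.get(name, 0) + 1
        new_name = f"{name}_{counts[name]}"
    return new_name

# With "w" and "w_1" already registered in the main graph, a lifted "w" becomes "w_2".
print(_unique_name("w", {"w", "w_1"}, {}))  # -> w_2
```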

src/onnx_ir/passes/common/identity_elimination.py
@@ -19,6 +19,7 @@ class IdentityEliminationPass(ir.passes.InPlacePass):
      """Pass for eliminating redundant Identity nodes.
 
      This pass removes Identity nodes according to the following rules:
+ 
      1. For any node of the form `y = Identity(x)`, where `y` is not an output
         of any graph, replace all uses of `y` with a use of `x`, and remove the node.
      2. If `y` is an output of a graph, and `x` is not an input of any graph,

onnx_ir-0.1.7/src/onnx_ir/passes/common/initializer_deduplication.py (new file)
@@ -0,0 +1,167 @@
+ # Copyright (c) ONNX Project Contributors
+ # SPDX-License-Identifier: Apache-2.0
+ """Pass for removing duplicated initializer tensors from a graph."""
+ 
+ from __future__ import annotations
+ 
+ __all__ = ["DeduplicateInitializersPass", "DeduplicateHashedInitializersPass"]
+ 
+ 
+ import hashlib
+ import logging
+ 
+ import onnx_ir as ir
+ 
+ logger = logging.getLogger(__name__)
+ 
+ 
+ def _should_skip_initializer(initializer: ir.Value, size_limit: int) -> bool:
+     """Check if the initializer should be skipped for deduplication."""
+     if initializer.is_graph_input() or initializer.is_graph_output():
+         # Skip graph inputs and outputs
+         logger.warning(
+             "Skipped deduplication of initializer '%s' as it is a graph input or output",
+             initializer.name,
+         )
+         return True
+ 
+     const_val = initializer.const_value
+     if const_val is None:
+         # Skip if initializer has no constant value
+         logger.warning(
+             "Skipped deduplication of initializer '%s' as it has no constant value. The model may contain invalid initializers",
+             initializer.name,
+         )
+         return True
+ 
+     if const_val.size > size_limit:
+         # Skip if the initializer is larger than the size limit
+         logger.debug(
+             "Skipped initializer '%s' as it exceeds the size limit of %d elements",
+             initializer.name,
+             size_limit,
+         )
+         return True
+ 
+     if const_val.dtype == ir.DataType.STRING:
+         # Skip string initializers as they don't have a bytes representation
+         logger.warning(
+             "Skipped deduplication of string initializer '%s' (unsupported yet)",
+             initializer.name,
+         )
+         return True
+     return False
+ 
+ 
+ class DeduplicateInitializersPass(ir.passes.InPlacePass):
+     """Remove duplicated initializer tensors from the main graph and all subgraphs.
+ 
+     This pass detects initializers with identical shape, dtype, and content,
+     and replaces all duplicate references with a canonical one.
+ 
+     Initializers are deduplicated within each graph. To deduplicate initializers
+     in the model globally (across graphs), use :class:`~onnx_ir.passes.common.LiftSubgraphInitializersToMainGraphPass`
+     to lift the initializers to the main graph first before running pass.
+ 
+     .. versionadded:: 0.1.3
+     .. versionchanged:: 0.1.7
+         This pass now deduplicates initializers in subgraphs as well.
+     """
+ 
+     def __init__(self, size_limit: int = 1024):
+         super().__init__()
+         self.size_limit = size_limit
+ 
+     def call(self, model: ir.Model) -> ir.passes.PassResult:
+         modified = False
+ 
+         for graph in model.graphs():
+             initializers: dict[tuple[ir.DataType, tuple[int, ...], bytes], ir.Value] = {}
+             for initializer in tuple(graph.initializers.values()):
+                 if _should_skip_initializer(initializer, self.size_limit):
+                     continue
+ 
+                 const_val = initializer.const_value
+                 assert const_val is not None
+ 
+                 key = (const_val.dtype, tuple(const_val.shape), const_val.tobytes())
+                 if key in initializers:
+                     modified = True
+                     initializer_to_keep = initializers[key]  # type: ignore[index]
+                     ir.convenience.replace_all_uses_with(initializer, initializer_to_keep)
+                     assert initializer.name is not None
+                     graph.initializers.pop(initializer.name)
+                     logger.info(
+                         "Replaced initializer '%s' with existing initializer '%s'",
+                         initializer.name,
+                         initializer_to_keep.name,
+                     )
+                 else:
+                     initializers[key] = initializer  # type: ignore[index]
+ 
+         return ir.passes.PassResult(model=model, modified=modified)
+ 
+ 
+ class DeduplicateHashedInitializersPass(ir.passes.InPlacePass):
+     """Remove duplicated initializer tensors (using a hashed method) from the graph.
+ 
+     This pass detects initializers with identical shape, dtype, and hashed content,
+     and replaces all duplicate references with a canonical one.
+ 
+     This pass should have a lower peak memory usage than :class:`DeduplicateInitializersPass`
+     as it does not store the full tensor data in memory, but instead uses a hash of the tensor data.
+ 
+     .. versionadded:: 0.1.7
+     """
+ 
+     def __init__(self, size_limit: int = 4 * 1024 * 1024 * 1024):
+         super().__init__()
+         # 4 GB default size limit for deduplication
+         self.size_limit = size_limit
+ 
+     def call(self, model: ir.Model) -> ir.passes.PassResult:
+         modified = False
+ 
+         for graph in model.graphs():
+             initializers: dict[tuple[ir.DataType, tuple[int, ...], str], ir.Value] = {}
+ 
+             for initializer in tuple(graph.initializers.values()):
+                 if _should_skip_initializer(initializer, self.size_limit):
+                     continue
+ 
+                 const_val = initializer.const_value
+                 assert const_val is not None
+ 
+                 # Hash tensor data to avoid storing large amounts of data in memory
+                 hashed = hashlib.sha512()
+                 tensor_data = const_val.numpy()
+                 hashed.update(tensor_data)
+                 tensor_digest = hashed.hexdigest()
+ 
+                 tensor_dims = tuple(const_val.shape.numpy())
+ 
+                 key = (const_val.dtype, tensor_dims, tensor_digest)
+ 
+                 if key in initializers:
+                     if initializers[key].const_value.tobytes() != const_val.tobytes():
+                         logger.warning(
+                             "Initializer deduplication failed: "
+                             "hashes match but values differ with values %s and %s",
+                             initializers[key],
+                             initializer,
+                         )
+                         continue
+                     modified = True
+                     initializer_to_keep = initializers[key]  # type: ignore[index]
+                     ir.convenience.replace_all_uses_with(initializer, initializer_to_keep)
+                     assert initializer.name is not None
+                     graph.initializers.pop(initializer.name)
+                     logger.info(
+                         "Replaced initializer '%s' with existing initializer '%s'",
+                         initializer.name,
+                         initializer_to_keep.name,
+                     )
+                 else:
+                     initializers[key] = initializer  # type: ignore[index]
+ 
+         return ir.passes.PassResult(model=model, modified=modified)
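As the `DeduplicateInitializersPass` docstring notes, deduplication happens per graph; to deduplicate across the whole model, lift subgraph initializers into the main graph first. A sketch of that combination, assuming the default constructors and the in-place pass behavior shown above:

```python
import onnx_ir as ir
from onnx_ir.passes.common import (
    DeduplicateInitializersPass,
    LiftSubgraphInitializersToMainGraphPass,
)

model = ir.load("model.onnx")  # hypothetical input path

# Lift first so duplicates living in different subgraphs end up in the same
# graph, then deduplicate them there. Both passes modify the model in place.
LiftSubgraphInitializersToMainGraphPass()(model)
result = DeduplicateInitializersPass()(model)

ir.save(model, "model_dedup.onnx")  # hypothetical output path
print("deduplicated:", result.modified)
```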