onnx-ir 0.1.6__tar.gz → 0.1.8__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.


Files changed (53)
  1. {onnx_ir-0.1.6/src/onnx_ir.egg-info → onnx_ir-0.1.8}/PKG-INFO +7 -4
  2. {onnx_ir-0.1.6 → onnx_ir-0.1.8}/README.md +5 -1
  3. {onnx_ir-0.1.6 → onnx_ir-0.1.8}/pyproject.toml +3 -3
  4. {onnx_ir-0.1.6 → onnx_ir-0.1.8}/src/onnx_ir/__init__.py +1 -1
  5. {onnx_ir-0.1.6 → onnx_ir-0.1.8}/src/onnx_ir/_convenience/__init__.py +49 -32
  6. {onnx_ir-0.1.6 → onnx_ir-0.1.8}/src/onnx_ir/_core.py +65 -16
  7. {onnx_ir-0.1.6 → onnx_ir-0.1.8}/src/onnx_ir/_enums.py +146 -1
  8. {onnx_ir-0.1.6 → onnx_ir-0.1.8}/src/onnx_ir/passes/common/__init__.py +2 -0
  9. {onnx_ir-0.1.6 → onnx_ir-0.1.8}/src/onnx_ir/passes/common/constant_manipulation.py +15 -7
  10. {onnx_ir-0.1.6 → onnx_ir-0.1.8}/src/onnx_ir/passes/common/identity_elimination.py +1 -0
  11. onnx_ir-0.1.8/src/onnx_ir/passes/common/initializer_deduplication.py +179 -0
  12. {onnx_ir-0.1.6 → onnx_ir-0.1.8}/src/onnx_ir/passes/common/naming.py +1 -1
  13. {onnx_ir-0.1.6 → onnx_ir-0.1.8}/src/onnx_ir/serde.py +97 -36
  14. {onnx_ir-0.1.6 → onnx_ir-0.1.8/src/onnx_ir.egg-info}/PKG-INFO +7 -4
  15. onnx_ir-0.1.6/src/onnx_ir/passes/common/initializer_deduplication.py +0 -56
  16. {onnx_ir-0.1.6 → onnx_ir-0.1.8}/LICENSE +0 -0
  17. {onnx_ir-0.1.6 → onnx_ir-0.1.8}/MANIFEST.in +0 -0
  18. {onnx_ir-0.1.6 → onnx_ir-0.1.8}/setup.cfg +0 -0
  19. {onnx_ir-0.1.6 → onnx_ir-0.1.8}/src/onnx_ir/_convenience/_constructors.py +0 -0
  20. {onnx_ir-0.1.6 → onnx_ir-0.1.8}/src/onnx_ir/_display.py +0 -0
  21. {onnx_ir-0.1.6 → onnx_ir-0.1.8}/src/onnx_ir/_graph_comparison.py +0 -0
  22. {onnx_ir-0.1.6 → onnx_ir-0.1.8}/src/onnx_ir/_graph_containers.py +0 -0
  23. {onnx_ir-0.1.6 → onnx_ir-0.1.8}/src/onnx_ir/_io.py +0 -0
  24. {onnx_ir-0.1.6 → onnx_ir-0.1.8}/src/onnx_ir/_linked_list.py +0 -0
  25. {onnx_ir-0.1.6 → onnx_ir-0.1.8}/src/onnx_ir/_metadata.py +0 -0
  26. {onnx_ir-0.1.6 → onnx_ir-0.1.8}/src/onnx_ir/_name_authority.py +0 -0
  27. {onnx_ir-0.1.6 → onnx_ir-0.1.8}/src/onnx_ir/_polyfill.py +0 -0
  28. {onnx_ir-0.1.6 → onnx_ir-0.1.8}/src/onnx_ir/_protocols.py +0 -0
  29. {onnx_ir-0.1.6 → onnx_ir-0.1.8}/src/onnx_ir/_tape.py +0 -0
  30. {onnx_ir-0.1.6 → onnx_ir-0.1.8}/src/onnx_ir/_thirdparty/asciichartpy.py +0 -0
  31. {onnx_ir-0.1.6 → onnx_ir-0.1.8}/src/onnx_ir/_type_casting.py +0 -0
  32. {onnx_ir-0.1.6 → onnx_ir-0.1.8}/src/onnx_ir/_version_utils.py +0 -0
  33. {onnx_ir-0.1.6 → onnx_ir-0.1.8}/src/onnx_ir/convenience.py +0 -0
  34. {onnx_ir-0.1.6 → onnx_ir-0.1.8}/src/onnx_ir/external_data.py +0 -0
  35. {onnx_ir-0.1.6 → onnx_ir-0.1.8}/src/onnx_ir/passes/__init__.py +0 -0
  36. {onnx_ir-0.1.6 → onnx_ir-0.1.8}/src/onnx_ir/passes/_pass_infra.py +0 -0
  37. {onnx_ir-0.1.6 → onnx_ir-0.1.8}/src/onnx_ir/passes/common/_c_api_utils.py +0 -0
  38. {onnx_ir-0.1.6 → onnx_ir-0.1.8}/src/onnx_ir/passes/common/clear_metadata_and_docstring.py +0 -0
  39. {onnx_ir-0.1.6 → onnx_ir-0.1.8}/src/onnx_ir/passes/common/common_subexpression_elimination.py +0 -0
  40. {onnx_ir-0.1.6 → onnx_ir-0.1.8}/src/onnx_ir/passes/common/inliner.py +0 -0
  41. {onnx_ir-0.1.6 → onnx_ir-0.1.8}/src/onnx_ir/passes/common/onnx_checker.py +0 -0
  42. {onnx_ir-0.1.6 → onnx_ir-0.1.8}/src/onnx_ir/passes/common/shape_inference.py +0 -0
  43. {onnx_ir-0.1.6 → onnx_ir-0.1.8}/src/onnx_ir/passes/common/topological_sort.py +0 -0
  44. {onnx_ir-0.1.6 → onnx_ir-0.1.8}/src/onnx_ir/passes/common/unused_removal.py +0 -0
  45. {onnx_ir-0.1.6 → onnx_ir-0.1.8}/src/onnx_ir/py.typed +0 -0
  46. {onnx_ir-0.1.6 → onnx_ir-0.1.8}/src/onnx_ir/tape.py +0 -0
  47. {onnx_ir-0.1.6 → onnx_ir-0.1.8}/src/onnx_ir/tensor_adapters.py +0 -0
  48. {onnx_ir-0.1.6 → onnx_ir-0.1.8}/src/onnx_ir/testing.py +0 -0
  49. {onnx_ir-0.1.6 → onnx_ir-0.1.8}/src/onnx_ir/traversal.py +0 -0
  50. {onnx_ir-0.1.6 → onnx_ir-0.1.8}/src/onnx_ir.egg-info/SOURCES.txt +0 -0
  51. {onnx_ir-0.1.6 → onnx_ir-0.1.8}/src/onnx_ir.egg-info/dependency_links.txt +0 -0
  52. {onnx_ir-0.1.6 → onnx_ir-0.1.8}/src/onnx_ir.egg-info/requires.txt +0 -0
  53. {onnx_ir-0.1.6 → onnx_ir-0.1.8}/src/onnx_ir.egg-info/top_level.txt +0 -0
@@ -1,9 +1,9 @@
  Metadata-Version: 2.4
  Name: onnx-ir
- Version: 0.1.6
+ Version: 0.1.8
  Summary: Efficient in-memory representation for ONNX
  Author-email: ONNX Contributors <onnx-technical-discuss@lists.lfaidata.foundation>
- License: Apache License v2.0
+ License-Expression: Apache-2.0
  Project-URL: Homepage, https://onnx.ai/ir-py
  Project-URL: Issues, https://github.com/onnx/ir-py/issues
  Project-URL: Repository, https://github.com/onnx/ir-py
@@ -13,7 +13,6 @@ Classifier: Programming Language :: Python :: 3.10
  Classifier: Programming Language :: Python :: 3.11
  Classifier: Programming Language :: Python :: 3.12
  Classifier: Programming Language :: Python :: 3.13
- Classifier: License :: OSI Approved :: Apache Software License
  Requires-Python: >=3.9
  Description-Content-Type: text/markdown
  License-File: LICENSE
@@ -23,7 +22,7 @@ Requires-Dist: typing_extensions>=4.10
  Requires-Dist: ml_dtypes
  Dynamic: license-file

- # ONNX IR
+ # <img src="docs/_static/logo-light.png" alt="ONNX IR" width="250"/>

  [![PyPI - Version](https://img.shields.io/pypi/v/onnx-ir.svg)](https://pypi.org/project/onnx-ir)
  [![PyPI - Python Version](https://img.shields.io/pypi/pyversions/onnx-ir.svg)](https://pypi.org/project/onnx-ir)
@@ -61,6 +60,10 @@ pip install git+https://github.com/onnx/ir-py.git
  - Pythonic and familiar APIs: Classes define Pythonic apis and still map to ONNX protobuf concepts in an intuitive way.
  - No protobuf dependency: The IR does not require protobuf once the model is converted to the IR representation, decoupling from the serialization format.

+ ## Concept Diagram
+
+ ![Concept Diagram](docs/resource/onnx-ir-entities.svg)
+
  ## Code Organization 🗺️

  - [`_protocols.py`](src/onnx_ir/_protocols.py): Interfaces defined for all entities in the IR.
@@ -1,4 +1,4 @@
- # ONNX IR
+ # <img src="docs/_static/logo-light.png" alt="ONNX IR" width="250"/>

  [![PyPI - Version](https://img.shields.io/pypi/v/onnx-ir.svg)](https://pypi.org/project/onnx-ir)
  [![PyPI - Python Version](https://img.shields.io/pypi/pyversions/onnx-ir.svg)](https://pypi.org/project/onnx-ir)
@@ -36,6 +36,10 @@ pip install git+https://github.com/onnx/ir-py.git
  - Pythonic and familiar APIs: Classes define Pythonic apis and still map to ONNX protobuf concepts in an intuitive way.
  - No protobuf dependency: The IR does not require protobuf once the model is converted to the IR representation, decoupling from the serialization format.

+ ## Concept Diagram
+
+ ![Concept Diagram](docs/resource/onnx-ir-entities.svg)
+
  ## Code Organization 🗺️

  - [`_protocols.py`](src/onnx_ir/_protocols.py): Interfaces defined for all entities in the IR.
@@ -1,5 +1,5 @@
  [build-system]
- requires = ["setuptools>=70"]
+ requires = ["setuptools>=77"]
  build-backend = "setuptools.build_meta"

  [project]
@@ -11,7 +11,8 @@ authors = [
  ]
  readme = "README.md"
  requires-python = ">=3.9"
- license = {text = "Apache License v2.0"}
+ license = "Apache-2.0"
+ license-files = ["LICEN[CS]E*"]
  classifiers = [
      "Development Status :: 4 - Beta",
      "Programming Language :: Python :: 3.9",
@@ -19,7 +20,6 @@ classifiers = [
      "Programming Language :: Python :: 3.11",
      "Programming Language :: Python :: 3.12",
      "Programming Language :: Python :: 3.13",
-     "License :: OSI Approved :: Apache Software License",
  ]
  dependencies = ["numpy", "onnx>=1.16", "typing_extensions>=4.10", "ml_dtypes"]

@@ -167,4 +167,4 @@ def __set_module() -> None:


  __set_module()
- __version__ = "0.1.6"
+ __version__ = "0.1.8"
@@ -58,44 +58,52 @@ def _infer_attribute_type(attr: SupportedAttrTypes) -> _enums.AttributeType:
          return _enums.AttributeType.STRING
      if isinstance(attr, _core.Attr):
          return attr.type
-     if isinstance(attr, Sequence) and all(isinstance(x, int) for x in attr):
-         return _enums.AttributeType.INTS
-     if isinstance(attr, Sequence) and all(isinstance(x, float) for x in attr):
-         return _enums.AttributeType.FLOATS
-     if isinstance(attr, Sequence) and all(isinstance(x, str) for x in attr):
-         return _enums.AttributeType.STRINGS
+     if isinstance(attr, (_core.Graph, onnx.GraphProto, _protocols.GraphProtocol)):
+         return _enums.AttributeType.GRAPH
      if isinstance(attr, (_core.TensorBase, onnx.TensorProto, _protocols.TensorProtocol)):
          # Be sure to check TensorProtocol last because isinstance checking on Protocols can be slower
          return _enums.AttributeType.TENSOR
-     if isinstance(attr, Sequence) and all(
-         isinstance(x, (_core.TensorBase, onnx.TensorProto, _protocols.TensorProtocol))
-         for x in attr
-     ):
-         return _enums.AttributeType.TENSORS
-     if isinstance(attr, (_core.Graph, onnx.GraphProto, _protocols.GraphProtocol)):
-         return _enums.AttributeType.GRAPH
-     if isinstance(attr, Sequence) and all(
-         isinstance(x, (_core.Graph, onnx.GraphProto, _protocols.GraphProtocol)) for x in attr
-     ):
-         return _enums.AttributeType.GRAPHS
      if isinstance(
          attr,
          (_core.TensorType, _core.SequenceType, _core.OptionalType, _protocols.TypeProtocol),
      ):
          return _enums.AttributeType.TYPE_PROTO
-     if isinstance(attr, Sequence) and all(
-         isinstance(
-             x,
-             (
-                 _core.TensorType,
-                 _core.SequenceType,
-                 _core.OptionalType,
-                 _protocols.TypeProtocol,
-             ),
-         )
-         for x in attr
-     ):
-         return _enums.AttributeType.TYPE_PROTOS
+     if isinstance(attr, Sequence):
+         if not attr:
+             logger.warning(
+                 "Attribute type is ambiguous because it is an empty sequence. "
+                 "Please create an Attr with an explicit type. Defaulted to INTS"
+             )
+             return _enums.AttributeType.INTS
+         if all(isinstance(x, int) for x in attr):
+             return _enums.AttributeType.INTS
+         if all(isinstance(x, float) for x in attr):
+             return _enums.AttributeType.FLOATS
+         if all(isinstance(x, str) for x in attr):
+             return _enums.AttributeType.STRINGS
+         if all(
+             isinstance(x, (_core.TensorBase, onnx.TensorProto, _protocols.TensorProtocol))
+             for x in attr
+         ):
+             return _enums.AttributeType.TENSORS
+         if all(
+             isinstance(x, (_core.Graph, onnx.GraphProto, _protocols.GraphProtocol))
+             for x in attr
+         ):
+             return _enums.AttributeType.GRAPHS
+         if all(
+             isinstance(
+                 x,
+                 (
+                     _core.TensorType,
+                     _core.SequenceType,
+                     _core.OptionalType,
+                     _protocols.TypeProtocol,
+                 ),
+             )
+             for x in attr
+         ):
+             return _enums.AttributeType.TYPE_PROTOS
      raise TypeError(f"Unsupported attribute type: '{type(attr)}'")


@@ -218,7 +226,7 @@ def convert_attributes(
        ...     "type_protos": [ir.TensorType(ir.DataType.FLOAT), ir.TensorType(ir.DataType.FLOAT)],
        ... }
        >>> convert_attributes(attrs)
-       [Attr('int', INT, 1), Attr('float', FLOAT, 1.0), Attr('str', STRING, 'hello'), Attr('ints', INTS, [1, 2, 3]), Attr('floats', FLOATS, [1.0, 2.0, 3.0]), Attr('strings', STRINGS, ['hello', 'world']), Attr('tensor', TENSOR, Tensor<DOUBLE,[3]>(array([1., 2., 3.]), name=None)), Attr('tensor_proto', TENSOR, TensorProtoTensor<FLOAT,[3]>(array([1., 2., 3.], dtype=float32), name='proto')), Attr('graph', INTS, Graph(
+       [Attr('int', INT, 1), Attr('float', FLOAT, 1.0), Attr('str', STRING, 'hello'), Attr('ints', INTS, [1, 2, 3]), Attr('floats', FLOATS, [1.0, 2.0, 3.0]), Attr('strings', STRINGS, ['hello', 'world']), Attr('tensor', TENSOR, Tensor<DOUBLE,[3]>(array([1., 2., 3.]), name=None)), Attr('tensor_proto', TENSOR, TensorProtoTensor<FLOAT,[3]>(array([1., 2., 3.], dtype=float32), name='proto')), Attr('graph', GRAPH, Graph(
            name='graph0',
            inputs=(
        <BLANKLINE>
@@ -247,11 +255,20 @@ def convert_attributes(
            len()=0
        )]), Attr('type_proto', TYPE_PROTO, Tensor(FLOAT)), Attr('type_protos', TYPE_PROTOS, [Tensor(FLOAT), Tensor(FLOAT)])]

+     .. important::
+         An empty sequence should be created with an explicit type by initializing
+         an Attr object with an attribute type to avoid type ambiguity. For example::
+
+             ir.Attr("empty", [], type=ir.AttributeType.INTS)
+
      Args:
          attrs: A dictionary of {<attribute name>: <python objects>} to convert.

      Returns:
-         A list of _core.Attr objects.
+         A list of :class:`_core.Attr` objects.
+
+     Raises:
+         TypeError: If an attribute type is not supported.
      """
      attributes: list[_core.Attr] = []
      for name, attr in attrs.items():
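A minimal usage sketch of the disambiguation recommended above. The positional `Attr(name, type, value)` order is an assumption inferred from the `Attr('int', INT, 1)` reprs in the doctest; `convert_attributes` passes an existing `Attr` through unchanged:

```python
import onnx_ir as ir

# Hedged sketch: give an empty sequence an explicit AttributeType up front
# instead of relying on the INTS default (which now logs a warning).
empty_ints = ir.Attr("empty", ir.AttributeType.INTS, [])
attrs = ir.convenience.convert_attributes({"axes": [0, 1], "empty": empty_ints})
```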
@@ -836,6 +836,11 @@ class StringTensor(TensorBase, _protocols.TensorProtocol):  # pylint: disable=too-many-ancestors
          """The shape of the tensor. Immutable."""
          return self._shape

+     @property
+     def nbytes(self) -> int:
+         """The number of bytes in the tensor."""
+         return sum(len(string) for string in self.string_data())
+
      @property
      def raw(self) -> Sequence[bytes] | npt.NDArray[np.bytes_]:
          """Backing data of the tensor. Immutable."""
@@ -2564,14 +2569,23 @@ class Graph(_protocols.GraphProtocol, Sequence[Node], _display.PrettyPrintable):

          .. versionadded:: 0.1.2
          """
-         seen_graphs: set[Graph] = set()
-         for node in onnx_ir.traversal.RecursiveGraphIterator(self):
-             graph = node.graph
+         # Use a dict to preserve order
+         seen_graphs: dict[Graph, None] = {}
+
+         # Need to use the enter_graph callback so that empty subgraphs are collected
+         def enter_subgraph(graph) -> None:
              if graph is self:
-                 continue
-             if graph is not None and graph not in seen_graphs:
-                 seen_graphs.add(graph)
-                 yield graph
+                 return
+             if not isinstance(graph, Graph):
+                 raise TypeError(
+                     f"Expected a Graph, got {type(graph)}. The model may be invalid"
+                 )
+             if graph not in seen_graphs:
+                 seen_graphs[graph] = None
+
+         for _ in onnx_ir.traversal.RecursiveGraphIterator(self, enter_graph=enter_subgraph):
+             pass
+         yield from seen_graphs.keys()

      # Mutation methods
      def append(self, node: Node, /) -> None:
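A hedged usage sketch of the reworked iterator ("model.onnx" is a placeholder path): with the `enter_graph` callback above, subgraphs that contain no nodes are now reported as well.

```python
import onnx_ir as ir

# Enumerate every nested subgraph (e.g. If/Loop bodies) exactly once,
# in encounter order, including empty subgraphs.
model = ir.load("model.onnx")
for subgraph in model.graph.subgraphs():
    print(subgraph.name, "nodes:", len(subgraph))
```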
@@ -3180,6 +3194,21 @@ class Function(_protocols.FunctionProtocol, Sequence[Node], _display.PrettyPrintable):
      def attributes(self) -> _graph_containers.Attributes:
          return self._attributes

+     @property
+     def graph(self) -> Graph:
+         """The underlying Graph object that contains the nodes of this function.
+
+         Only use this graph for identity comparison::
+
+             if value.graph is function.graph:
+                 # Do something with the value that belongs to this function
+
+         Otherwise use the Function object directly to access the nodes and other properties.
+
+         .. versionadded:: 0.1.7
+         """
+         return self._graph
+
      @typing.overload
      def __getitem__(self, index: int) -> Node: ...
      @typing.overload
@@ -3240,14 +3269,22 @@ class Function(_protocols.FunctionProtocol, Sequence[Node], _display.PrettyPrintable):

          .. versionadded:: 0.1.2
          """
-         seen_graphs: set[Graph] = set()
-         for node in onnx_ir.traversal.RecursiveGraphIterator(self):
-             graph = node.graph
-             if graph is self._graph:
-                 continue
-             if graph is not None and graph not in seen_graphs:
-                 seen_graphs.add(graph)
-                 yield graph
+         seen_graphs: dict[Graph, None] = {}
+
+         # Need to use the enter_graph callback so that empty subgraphs are collected
+         def enter_subgraph(graph) -> None:
+             if graph is self:
+                 return
+             if not isinstance(graph, Graph):
+                 raise TypeError(
+                     f"Expected a Graph, got {type(graph)}. The model may be invalid"
+                 )
+             if graph not in seen_graphs:
+                 seen_graphs[graph] = None
+
+         for _ in onnx_ir.traversal.RecursiveGraphIterator(self, enter_graph=enter_subgraph):
+             pass
+         yield from seen_graphs.keys()

      # Mutation methods
      def append(self, node: Node, /) -> None:
@@ -3349,7 +3386,7 @@ class Attr(
  ):
      """Base class for ONNX attributes or references."""

-     __slots__ = ("_name", "_ref_attr_name", "_type", "_value", "doc_string")
+     __slots__ = ("_metadata", "_name", "_ref_attr_name", "_type", "_value", "doc_string")

      def __init__(
          self,
@@ -3365,6 +3402,7 @@ class Attr(
          self._value = value
          self._ref_attr_name = ref_attr_name
          self.doc_string = doc_string
+         self._metadata: _metadata.MetadataStore | None = None

      @property
      def name(self) -> str:
@@ -3386,6 +3424,17 @@ class Attr(
      def ref_attr_name(self) -> str | None:
          return self._ref_attr_name

+     @property
+     def meta(self) -> _metadata.MetadataStore:
+         """The metadata store for intermediate analysis.
+
+         Write to the :attr:`metadata_props` if you would like the metadata to be serialized
+         to the ONNX proto.
+         """
+         if self._metadata is None:
+             self._metadata = _metadata.MetadataStore()
+         return self._metadata
+
      def is_ref(self) -> bool:
          """Check if this attribute is a reference attribute."""
          return self.ref_attr_name is not None
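A hedged sketch of the new lazily-created store. Dict-style access on `MetadataStore` and the positional `Attr(name, type, value)` order are assumptions about the onnx-ir API, not confirmed by the diff:

```python
import onnx_ir as ir

# Stash analysis results on an attribute; per the docstring above this is
# in-memory only and is not serialized (use metadata_props for that).
attr = ir.Attr("alpha", ir.AttributeType.FLOAT, 0.5)
attr.meta["visited"] = True
assert attr.meta["visited"] is True
```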
@@ -5,6 +5,7 @@
  from __future__ import annotations

  import enum
+ from typing import Any

  import ml_dtypes
  import numpy as np
@@ -77,7 +78,7 @@ class DataType(enum.IntEnum):
          if dtype in _NP_TYPE_TO_DATA_TYPE:
              return cls(_NP_TYPE_TO_DATA_TYPE[dtype])

-         if np.issubdtype(dtype, np.str_):
+         if np.issubdtype(dtype, np.str_) or np.issubdtype(dtype, np.bytes_):
              return DataType.STRING

          # Special cases for handling custom dtypes defined in ONNX (as of onnx 1.18)
@@ -131,6 +132,146 @@ class DataType(enum.IntEnum):
              raise TypeError(f"Bitwidth not available for ONNX data type: {self}")
          return _BITWIDTH_MAP[self]

+     @property
+     def exponent_bitwidth(self) -> int:
+         """Returns the bit width of the exponent for floating-point types.
+
+         .. versionadded:: 0.1.8
+
+         Raises:
+             TypeError: If the data type is not supported.
+         """
+         if self.is_floating_point():
+             return ml_dtypes.finfo(self.numpy()).nexp
+
+         raise TypeError(f"Exponent not available for ONNX data type: {self}")
+
+     @property
+     def mantissa_bitwidth(self) -> int:
+         """Returns the bit width of the mantissa for floating-point types.
+
+         .. versionadded:: 0.1.8
+
+         Raises:
+             TypeError: If the data type is not supported.
+         """
+         if self.is_floating_point():
+             return ml_dtypes.finfo(self.numpy()).nmant
+
+         raise TypeError(f"Mantissa not available for ONNX data type: {self}")
+
+     @property
+     def eps(self) -> int | np.floating[Any]:
+         """Returns the difference between 1.0 and the next smallest representable float larger than 1.0 for the ONNX data type.
+
+         Returns 1 for integers.
+
+         .. versionadded:: 0.1.8
+
+         Raises:
+             TypeError: If the data type is not a numeric data type.
+         """
+         if self.is_integer():
+             return 1
+
+         if self.is_floating_point():
+             return ml_dtypes.finfo(self.numpy()).eps
+
+         raise TypeError(f"Eps not available for ONNX data type: {self}")
+
+     @property
+     def tiny(self) -> int | np.floating[Any]:
+         """Returns the smallest positive non-zero value for the ONNX data type.
+
+         Returns 1 for integers.
+
+         .. versionadded:: 0.1.8
+
+         Raises:
+             TypeError: If the data type is not a numeric data type.
+         """
+         if self.is_integer():
+             return 1
+
+         if self.is_floating_point():
+             return ml_dtypes.finfo(self.numpy()).tiny
+
+         raise TypeError(f"Tiny not available for ONNX data type: {self}")
+
+     @property
+     def min(self) -> int | np.floating[Any]:
+         """Returns the minimum representable value for the ONNX data type.
+
+         .. versionadded:: 0.1.8
+
+         Raises:
+             TypeError: If the data type is not a numeric data type.
+         """
+         if self.is_integer():
+             return ml_dtypes.iinfo(self.numpy()).min
+
+         if self.is_floating_point():
+             return ml_dtypes.finfo(self.numpy()).min
+
+         raise TypeError(f"Minimum not available for ONNX data type: {self}")
+
+     @property
+     def max(self) -> int | np.floating[Any]:
+         """Returns the maximum representable value for the ONNX data type.
+
+         .. versionadded:: 0.1.8
+
+         Raises:
+             TypeError: If the data type is not a numeric data type.
+         """
+         if self.is_integer():
+             return ml_dtypes.iinfo(self.numpy()).max
+
+         if self.is_floating_point():
+             return ml_dtypes.finfo(self.numpy()).max
+
+         raise TypeError(f"Maximum not available for ONNX data type: {self}")
+
+     @property
+     def precision(self) -> int:
+         """Returns the precision for the ONNX dtype if supported.
+
+         For floats returns the approximate number of decimal digits to which
+         this kind of float is precise. Returns 0 for integers.
+
+         .. versionadded:: 0.1.8
+
+         Raises:
+             TypeError: If the data type is not a numeric data type.
+         """
+         if self.is_integer():
+             return 0
+
+         if self.is_floating_point():
+             return ml_dtypes.finfo(self.numpy()).precision
+
+         raise TypeError(f"Precision not available for ONNX data type: {self}")
+
+     @property
+     def resolution(self) -> int | np.floating[Any]:
+         """Returns the resolution for the ONNX dtype if supported.
+
+         Returns the approximate decimal resolution of this type, i.e.,
+         10**-precision. Returns 1 for integers.
+
+         .. versionadded:: 0.1.8
+
+         Raises:
+             TypeError: If the data type is not a numeric data type.
+         """
+         if self.is_integer():
+             return 1
+
+         if self.is_floating_point():
+             return ml_dtypes.finfo(self.numpy()).resolution
+
+         raise TypeError(f"Resolution not available for ONNX data type: {self}")
+
      def numpy(self) -> np.dtype:
          """Returns the numpy dtype for the ONNX data type.

@@ -215,6 +356,10 @@ class DataType(enum.IntEnum):
          DataType.FLOAT8E8M0,
      }

+     def is_string(self) -> bool:
+         """Returns True if the data type is a string type."""
+         return self == DataType.STRING
+
      def __repr__(self) -> str:
          return self.name
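A short sketch of the new numeric-limit properties, which surface `ml_dtypes.finfo`/`iinfo` data; the printed values follow the IEEE float16 and int8 formats:

```python
import onnx_ir as ir

f16 = ir.DataType.FLOAT16
print(f16.exponent_bitwidth, f16.mantissa_bitwidth)  # 5 10
print(f16.max)                                       # 65504.0
print(ir.DataType.INT8.min, ir.DataType.INT8.max)    # -128 127
print(ir.DataType.INT8.eps)                          # 1 (integers return 1)
```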
@@ -6,6 +6,7 @@ __all__ = [
      "CheckerPass",
      "ClearMetadataAndDocStringPass",
      "CommonSubexpressionEliminationPass",
+     "DeduplicateHashedInitializersPass",
      "DeduplicateInitializersPass",
      "IdentityEliminationPass",
      "InlinePass",
@@ -36,6 +37,7 @@ from onnx_ir.passes.common.identity_elimination import (
      IdentityEliminationPass,
  )
  from onnx_ir.passes.common.initializer_deduplication import (
+     DeduplicateHashedInitializersPass,
      DeduplicateInitializersPass,
  )
  from onnx_ir.passes.common.inliner import InlinePass
@@ -148,6 +148,7 @@ class LiftSubgraphInitializersToMainGraphPass(ir.passes.InPlacePass):
          if graph is model.graph:
              continue
          for name in tuple(graph.initializers):
+             assert name is not None
              initializer = graph.initializers[name]
              if initializer.is_graph_input():
                  # Skip the ones that are also graph inputs
@@ -156,17 +157,24 @@ class LiftSubgraphInitializersToMainGraphPass(ir.passes.InPlacePass):
                      initializer.name,
                  )
                  continue
+             if initializer.is_graph_output():
+                 logger.debug(
+                     "Initializer '%s' is used as output, so it can't be lifted",
+                     initializer.name,
+                 )
+                 continue
              # Remove the initializer from the subgraph
              graph.initializers.pop(name)
              # To avoid name conflicts, we need to rename the initializer
              # to a unique name in the main graph
-             if name in registered_initializer_names:
-                 name_count = registered_initializer_names[name]
-                 initializer.name = f"{name}_{name_count}"
-                 registered_initializer_names[name] = name_count + 1
-             else:
-                 assert initializer.name is not None
-                 registered_initializer_names[initializer.name] = 1
+             new_name = name
+             while new_name in model.graph.initializers:
+                 if name in registered_initializer_names:
+                     registered_initializer_names[name] += 1
+                 else:
+                     registered_initializer_names[name] = 1
+                 new_name = f"{name}_{registered_initializer_names[name]}"
+             initializer.name = new_name
              model.graph.register_initializer(initializer)
              count += 1
          logger.debug(
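A hedged sketch of a common pipeline with this pass: lift subgraph initializers into the main graph (renaming on conflict, per the loop above), then deduplicate globally. `ir.passes.PassManager` and "model.onnx" are assumptions here:

```python
import onnx_ir as ir
from onnx_ir.passes.common import (
    DeduplicateInitializersPass,
    LiftSubgraphInitializersToMainGraphPass,
)

model = ir.load("model.onnx")  # placeholder path
result = ir.passes.PassManager(
    [LiftSubgraphInitializersToMainGraphPass(), DeduplicateInitializersPass()]
)(model)
print("modified:", result.modified)
```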
@@ -19,6 +19,7 @@ class IdentityEliminationPass(ir.passes.InPlacePass):
      """Pass for eliminating redundant Identity nodes.

      This pass removes Identity nodes according to the following rules:
+
      1. For any node of the form `y = Identity(x)`, where `y` is not an output
         of any graph, replace all uses of `y` with a use of `x`, and remove the node.
      2. If `y` is an output of a graph, and `x` is not an input of any graph,
@@ -0,0 +1,179 @@
+ # Copyright (c) ONNX Project Contributors
+ # SPDX-License-Identifier: Apache-2.0
+ """Pass for removing duplicated initializer tensors from a graph."""
+
+ from __future__ import annotations
+
+ __all__ = ["DeduplicateInitializersPass", "DeduplicateHashedInitializersPass"]
+
+
+ import hashlib
+ import logging
+
+ import numpy as np
+
+ import onnx_ir as ir
+
+ logger = logging.getLogger(__name__)
+
+
+ def _should_skip_initializer(initializer: ir.Value, size_limit: int) -> bool:
+     """Check if the initializer should be skipped for deduplication."""
+     if initializer.is_graph_input() or initializer.is_graph_output():
+         # Skip graph inputs and outputs
+         logger.warning(
+             "Skipped deduplication of initializer '%s' as it is a graph input or output",
+             initializer.name,
+         )
+         return True
+
+     const_val = initializer.const_value
+     if const_val is None:
+         # Skip if initializer has no constant value
+         logger.warning(
+             "Skipped deduplication of initializer '%s' as it has no constant value. The model may contain invalid initializers",
+             initializer.name,
+         )
+         return True
+
+     if const_val.size > size_limit:
+         # Skip if the initializer is larger than the size limit
+         logger.debug(
+             "Skipped initializer '%s' as it exceeds the size limit of %d elements",
+             initializer.name,
+             size_limit,
+         )
+         return True
+     return False
+
+
+ def _tobytes(val):
+     """StringTensor does not support tobytes. Use 'string_data' instead.
+
+     However, 'string_data' yields a list of bytes which cannot be hashed, i.e.,
+     cannot be used to index into a dict. To generate keys for identifying
+     tensors in initializer deduplication the following converts the list of
+     bytes to an array of fixed-length strings which can be flattened into a
+     bytes-string. This, together with the tensor shape, is sufficient for
+     identifying tensors for deduplication, but it differs from the
+     representation used for serializing tensors (that is string_data) by adding
+     padding bytes so that each string occupies the same number of consecutive
+     bytes in the flattened .tobytes representation.
+     """
+     if val.dtype.is_string():
+         return np.array(val.string_data()).tobytes()
+     return val.tobytes()
+
+
+ class DeduplicateInitializersPass(ir.passes.InPlacePass):
+     """Remove duplicated initializer tensors from the main graph and all subgraphs.
+
+     This pass detects initializers with identical shape, dtype, and content,
+     and replaces all duplicate references with a canonical one.
+
+     Initializers are deduplicated within each graph. To deduplicate initializers
+     in the model globally (across graphs), use :class:`~onnx_ir.passes.common.LiftSubgraphInitializersToMainGraphPass`
+     to lift the initializers to the main graph first before running pass.
+
+     .. versionadded:: 0.1.3
+     .. versionchanged:: 0.1.7
+         This pass now deduplicates initializers in subgraphs as well.
+     """
+
+     def __init__(self, size_limit: int = 1024):
+         super().__init__()
+         self.size_limit = size_limit
+
+     def call(self, model: ir.Model) -> ir.passes.PassResult:
+         modified = False
+
+         for graph in model.graphs():
+             initializers: dict[tuple[ir.DataType, tuple[int, ...], bytes], ir.Value] = {}
+             for initializer in tuple(graph.initializers.values()):
+                 if _should_skip_initializer(initializer, self.size_limit):
+                     continue
+
+                 const_val = initializer.const_value
+                 assert const_val is not None
+
+                 key = (const_val.dtype, tuple(const_val.shape), _tobytes(const_val))
+                 if key in initializers:
+                     modified = True
+                     initializer_to_keep = initializers[key]  # type: ignore[index]
+                     ir.convenience.replace_all_uses_with(initializer, initializer_to_keep)
+                     assert initializer.name is not None
+                     graph.initializers.pop(initializer.name)
+                     logger.info(
+                         "Replaced initializer '%s' with existing initializer '%s'",
+                         initializer.name,
+                         initializer_to_keep.name,
+                     )
+                 else:
+                     initializers[key] = initializer  # type: ignore[index]
+
+         return ir.passes.PassResult(model=model, modified=modified)
+
+
+ class DeduplicateHashedInitializersPass(ir.passes.InPlacePass):
+     """Remove duplicated initializer tensors (using a hashed method) from the graph.
+
+     This pass detects initializers with identical shape, dtype, and hashed content,
+     and replaces all duplicate references with a canonical one.
+
+     This pass should have a lower peak memory usage than :class:`DeduplicateInitializersPass`
+     as it does not store the full tensor data in memory, but instead uses a hash of the tensor data.
+
+     .. versionadded:: 0.1.7
+     """
+
+     def __init__(self, size_limit: int = 4 * 1024 * 1024 * 1024):
+         super().__init__()
+         # 4 GB default size limit for deduplication
+         self.size_limit = size_limit
+
+     def call(self, model: ir.Model) -> ir.passes.PassResult:
+         modified = False
+
+         for graph in model.graphs():
+             initializers: dict[tuple[ir.DataType, tuple[int, ...], str], ir.Value] = {}
+
+             for initializer in tuple(graph.initializers.values()):
+                 if _should_skip_initializer(initializer, self.size_limit):
+                     continue
+
+                 const_val = initializer.const_value
+                 assert const_val is not None
+
+                 # Hash tensor data to avoid storing large amounts of data in memory
+                 hashed = hashlib.sha512()
+                 tensor_data = const_val.numpy()
+                 hashed.update(tensor_data)
+                 tensor_digest = hashed.hexdigest()
+
+                 tensor_dims = tuple(const_val.shape.numpy())
+
+                 key = (const_val.dtype, tensor_dims, tensor_digest)
+
+                 if key in initializers:
+                     if _tobytes(initializers[key].const_value) != _tobytes(const_val):
+                         logger.warning(
+                             "Initializer deduplication failed: "
+                             "hashes match but values differ with values %s and %s",
+                             initializers[key],
+                             initializer,
+                         )
+                         continue
+                     modified = True
+                     initializer_to_keep = initializers[key]  # type: ignore[index]
+                     ir.convenience.replace_all_uses_with(initializer, initializer_to_keep)
+                     assert initializer.name is not None
+                     graph.initializers.pop(initializer.name)
+                     logger.info(
+                         "Replaced initializer '%s' with existing initializer '%s'",
+                         initializer.name,
+                         initializer_to_keep.name,
+                     )
+                 else:
+                     initializers[key] = initializer  # type: ignore[index]
+
+         return ir.passes.PassResult(model=model, modified=modified)
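A minimal usage sketch of the new hashed variant, which keys on (dtype, dims, sha512 digest) so peak memory stays low for large models ("model.onnx" is a placeholder path):

```python
import onnx_ir as ir
from onnx_ir.passes.common import DeduplicateHashedInitializersPass

model = ir.load("model.onnx")
result = DeduplicateHashedInitializersPass()(model)
print("modified:", result.modified)
```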
@@ -64,7 +64,7 @@ class NameFixPass(ir.passes.InPlacePass):
            def custom_value_name(value: ir.Value) -> str:
                return f"custom_value_{value.type}"

-           name_fix_pass = NameFixPass(nameGenerator=CustomNameGenerator())
+           name_fix_pass = NameFixPass(name_generator=CustomNameGenerator())

      .. versionadded:: 0.1.6
      """
@@ -682,8 +682,8 @@ def deserialize_graph(proto: onnx.GraphProto) -> _core.Graph:
      Returns:
          IR Graph.

-     .. versionadded:: 0.3
-         Support for *quantization_annotation* is added.
+     .. versionadded:: 0.1.3
+         Support for `quantization_annotation` is added.
      """
      return _deserialize_graph(proto, [])

@@ -760,6 +760,18 @@ def _deserialize_graph(
      # Build the value info dictionary to allow for quick lookup for this graph scope
      value_info = {info.name: info for info in proto.value_info}

+     # Declare values for all node outputs from this graph scope. This is necessary
+     # to handle the case where a node in a subgraph uses a value that is declared out
+     # of order in the outer graph. Declaring the values first allows us to find the
+     # values later when deserializing the nodes in subgraphs.
+     for node in proto.node:
+         _declare_node_outputs(
+             node,
+             values,
+             value_info=value_info,
+             quantization_annotations=quantization_annotations,
+         )
+
      # Deserialize nodes with all known values
      nodes = [
          _deserialize_node(node, scoped_values, value_info, quantization_annotations)
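A hedged sketch of the case this pre-declaration handles: a graph whose nodes are listed out of topological order, so `t` is consumed before its producer appears. The graph is built with standard `onnx.helper` calls for illustration:

```python
import onnx
from onnx import helper
import onnx_ir as ir

# "t" is used by Relu before Neg (its producer) appears in the node list;
# declaring node outputs first lets the deserializer resolve it instead of
# inventing a dangling input.
graph_proto = helper.make_graph(
    [
        helper.make_node("Relu", ["t"], ["y"]),
        helper.make_node("Neg", ["x"], ["t"]),
    ],
    "unsorted",
    [helper.make_tensor_value_info("x", onnx.TensorProto.FLOAT, [2])],
    [helper.make_tensor_value_info("y", onnx.TensorProto.FLOAT, [2])],
)
ir_graph = ir.serde.deserialize_graph(graph_proto)
print([node.op_type for node in ir_graph])  # ['Relu', 'Neg']
```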
@@ -798,6 +810,55 @@ def _deserialize_graph(
      )


+ def _declare_node_outputs(
+     proto: onnx.NodeProto,
+     current_value_scope: dict[str, _core.Value],
+     value_info: dict[str, onnx.ValueInfoProto],
+     quantization_annotations: dict[str, onnx.TensorAnnotation],
+ ) -> None:
+     """Declare outputs for a node in the current graph scope.
+
+     This is necessary to handle the case where a node in a subgraph uses a value that is declared
+     out of order in the outer graph. Declaring the values first allows us to find the values later
+     when deserializing the nodes in subgraphs.
+
+     Args:
+         proto: The ONNX NodeProto to declare outputs for.
+         current_value_scope: The current scope of values, mapping value names to their corresponding Value objects.
+         value_info: A dictionary mapping value names to their corresponding ValueInfoProto.
+         quantization_annotations: A dictionary mapping tensor names to their corresponding TensorAnnotation.
+
+     Raises:
+         ValueError: If an output name is redeclared in the current graph scope.
+     """
+     for output_name in proto.output:
+         if output_name == "":
+             continue
+         if output_name in current_value_scope:
+             raise ValueError(
+                 f"Output '{output_name}' is redeclared in the current graph scope. "
+                 f"Original declaration {current_value_scope[output_name]}. "
+                 f"New declaration: by operator '{proto.op_type}' of node '{proto.name}'. "
+                 "The model is invalid"
+             )
+
+         # Create the value and add it to the current scope.
+         value = _core.Value(name=output_name)
+         current_value_scope[output_name] = value
+         # Fill in shape/type information if they exist
+         if output_name in value_info:
+             deserialize_value_info_proto(value_info[output_name], value)
+         else:
+             logger.debug(
+                 "ValueInfoProto not found for output '%s' in node '%s' of type '%s'",
+                 output_name,
+                 proto.name,
+                 proto.op_type,
+             )
+         if output_name in quantization_annotations:
+             _deserialize_quantization_annotation(quantization_annotations[output_name], value)
+
+
  @_capture_errors(lambda proto: proto.name)
  def deserialize_function(proto: onnx.FunctionProto) -> _core.Function:
      """Deserialize an ONNX FunctionProto into an IR Function.
@@ -812,7 +873,14 @@ def deserialize_function(proto: onnx.FunctionProto) -> _core.Function:
      values: dict[str, _core.Value] = {v.name: v for v in inputs}  # type: ignore[misc]
      value_info = {info.name: info for info in getattr(proto, "value_info", [])}

-     # TODO(justinchuby): Handle unsorted nodes
+     for node in proto.node:
+         _declare_node_outputs(
+             node,
+             values,
+             value_info=value_info,
+             quantization_annotations={},
+         )
+
      nodes = [
          _deserialize_node(node, [values], value_info=value_info, quantization_annotations={})
          for node in proto.node
@@ -1137,8 +1205,15 @@ def deserialize_node(proto: onnx.NodeProto) -> _core.Node:
      Returns:
          An IR Node object representing the ONNX node.
      """
+     value_scope: dict[str, _core.Value] = {}
+     _declare_node_outputs(
+         proto,
+         value_scope,
+         value_info={},
+         quantization_annotations={},
+     )
      return _deserialize_node(
-         proto, scoped_values=[{}], value_info={}, quantization_annotations={}
+         proto, scoped_values=[value_scope], value_info={}, quantization_annotations={}
      )
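A short sketch of standalone node deserialization with the change above, where outputs are pre-declared in a fresh scope so the returned `Node` owns fully-formed output values (the `onnx.helper` call is illustrative):

```python
from onnx import helper
import onnx_ir as ir

node_proto = helper.make_node("Add", ["a", "b"], ["c"], name="add_node")
node = ir.serde.deserialize_node(node_proto)
print(node.op_type, [out.name for out in node.outputs])  # Add ['c']
```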
@@ -1161,18 +1236,18 @@ def _deserialize_node(
          for values in reversed(scoped_values):
              if input_name not in values:
                  continue
+
              node_inputs.append(values[input_name])
              found = True
              del values  # Remove the reference so it is not used by mistake
              break
          if not found:
-             # If the input is not found, we know the graph may be unsorted and
-             # the input may be a supposed-to-be initializer or an output of a node that comes later.
-             # Here we create the value with the name and add it to the current scope.
-             # Nodes need to check the value pool for potentially initialized outputs
+             # If the input is not found, we know the graph is invalid because the value
+             # is not declared. We will still create a new input for the node so that
+             # it can be fixed later.
              logger.warning(
-                 "Input '%s' of node '%s(%s::%s:%s)' not found in any scope. "
-                 "The graph may be unsorted. Creating a new input (current depth: %s) .",
+                 "Input '%s' of node '%s' (%s::%s:%s) cannot be found in any scope. "
+                 "The model is invalid but we will still create a new input for the node (current depth: %s)",
                  input_name,
                  proto.name,
                  proto.domain,
@@ -1208,35 +1283,22 @@ def _deserialize_node(
              node_outputs.append(_core.Value(name=""))
              continue

-         # 1. When the graph is unsorted, we may be able to find the output already created
+         # The outputs should already be declared in the current scope by _declare_node_outputs.
+         #
+         # When the graph is unsorted, we may be able to find the output already created
          # as an input to some other nodes in the current scope.
          # Note that a value is always owned by the producing node. Even though a value
          # can be created when parsing inputs of other nodes, the new node created here
          # that produces the value will assume ownership. It is then impossible to transfer
          # the ownership to any other node.
-
+         #
          # The output can only be found in the current scope. It is impossible for
          # a node to produce an output that is not in its own scope.
          current_scope = scoped_values[-1]
-         if output_name in current_scope:
-             value = current_scope[output_name]
-         else:
-             # 2. Common scenario: the graph is sorted and this is the first time we see the output.
-             # Create the value and add it to the current scope.
-             value = _core.Value(name=output_name)
-             current_scope[output_name] = value
-             # Fill in shape/type information if they exist
-             if output_name in value_info:
-                 deserialize_value_info_proto(value_info[output_name], value)
-             else:
-                 logger.debug(
-                     "ValueInfoProto not found for output '%s' in node '%s' of type '%s'",
-                     output_name,
-                     proto.name,
-                     proto.op_type,
-                 )
-             if output_name in quantization_annotations:
-                 _deserialize_quantization_annotation(quantization_annotations[output_name], value)
+         assert output_name in current_scope, (
+             f"Output '{output_name}' not found in the current scope. This is unexpected"
+         )
+         value = current_scope[output_name]
          node_outputs.append(value)
      return _core.Node(
          proto.domain,
@@ -1469,8 +1531,6 @@ def serialize_graph_into(
          serialize_value_into(graph_proto.input.add(), input_)
          if input_.name not in from_.initializers:
              # Annotations for initializers will be added below to avoid double adding
-             # TODO(justinchuby): We should add a method is_initializer() on Value when
-             # the initializer list is tracked
              _maybe_add_quantization_annotation(graph_proto, input_)
      input_names = {input_.name for input_ in from_.inputs}
      # TODO(justinchuby): Support sparse_initializer
@@ -1724,11 +1784,12 @@ def _fill_in_value_for_attribute(
  ) -> None:
      if type_ == _enums.AttributeType.INT:
          # value: int
-         attribute_proto.i = value
+         # Cast bool to int, for example
+         attribute_proto.i = int(value)
          attribute_proto.type = onnx.AttributeProto.INT
      elif type_ == _enums.AttributeType.FLOAT:
          # value: float
-         attribute_proto.f = value
+         attribute_proto.f = float(value)
          attribute_proto.type = onnx.AttributeProto.FLOAT
      elif type_ == _enums.AttributeType.STRING:
          # value: str
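A hedged sketch of what the `int(...)` cast buys: a Python bool stored in an INT attribute now serializes to `AttributeProto.i` as 0/1. `ir.serde.serialize_attribute` and the positional `Attr(name, type, value)` order are assumptions about the onnx-ir API:

```python
import onnx_ir as ir

attr = ir.Attr("transA", ir.AttributeType.INT, True)
attr_proto = ir.serde.serialize_attribute(attr)
print(attr_proto.i)  # 1
```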
@@ -1818,7 +1879,7 @@ def serialize_value(value: _protocols.ValueProtocol, *, name: str = "") -> onnx.ValueInfoProto:
      return value_info_proto


- @_capture_errors(lambda value_info_proto, from_: repr(from_))
+ @_capture_errors(lambda value_info_proto, from_, name="": repr(from_))
  def serialize_value_into(
      value_info_proto: onnx.ValueInfoProto,
      from_: _protocols.ValueProtocol,
@@ -1,9 +1,9 @@
  Metadata-Version: 2.4
  Name: onnx-ir
- Version: 0.1.6
+ Version: 0.1.8
  Summary: Efficient in-memory representation for ONNX
  Author-email: ONNX Contributors <onnx-technical-discuss@lists.lfaidata.foundation>
- License: Apache License v2.0
+ License-Expression: Apache-2.0
  Project-URL: Homepage, https://onnx.ai/ir-py
  Project-URL: Issues, https://github.com/onnx/ir-py/issues
  Project-URL: Repository, https://github.com/onnx/ir-py
@@ -13,7 +13,6 @@ Classifier: Programming Language :: Python :: 3.10
  Classifier: Programming Language :: Python :: 3.11
  Classifier: Programming Language :: Python :: 3.12
  Classifier: Programming Language :: Python :: 3.13
- Classifier: License :: OSI Approved :: Apache Software License
  Requires-Python: >=3.9
  Description-Content-Type: text/markdown
  License-File: LICENSE
@@ -23,7 +22,7 @@ Requires-Dist: typing_extensions>=4.10
  Requires-Dist: ml_dtypes
  Dynamic: license-file

- # ONNX IR
+ # <img src="docs/_static/logo-light.png" alt="ONNX IR" width="250"/>

  [![PyPI - Version](https://img.shields.io/pypi/v/onnx-ir.svg)](https://pypi.org/project/onnx-ir)
  [![PyPI - Python Version](https://img.shields.io/pypi/pyversions/onnx-ir.svg)](https://pypi.org/project/onnx-ir)
@@ -61,6 +60,10 @@ pip install git+https://github.com/onnx/ir-py.git
  - Pythonic and familiar APIs: Classes define Pythonic apis and still map to ONNX protobuf concepts in an intuitive way.
  - No protobuf dependency: The IR does not require protobuf once the model is converted to the IR representation, decoupling from the serialization format.

+ ## Concept Diagram
+
+ ![Concept Diagram](docs/resource/onnx-ir-entities.svg)
+
  ## Code Organization 🗺️

  - [`_protocols.py`](src/onnx_ir/_protocols.py): Interfaces defined for all entities in the IR.
@@ -1,56 +0,0 @@
- # Copyright (c) ONNX Project Contributors
- # SPDX-License-Identifier: Apache-2.0
- """Pass for removing duplicated initializer tensors from a graph."""
-
- from __future__ import annotations
-
- __all__ = [
-     "DeduplicateInitializersPass",
- ]
-
-
- import onnx_ir as ir
-
-
- class DeduplicateInitializersPass(ir.passes.InPlacePass):
-     """Remove duplicated initializer tensors from the graph.
-
-     This pass detects initializers with identical shape, dtype, and content,
-     and replaces all duplicate references with a canonical one.
-
-     To deduplicate initializers from subgraphs, use :class:`~onnx_ir.passes.common.LiftSubgraphInitializersToMainGraphPass`
-     to lift the initializers to the main graph first before running pass.
-
-     .. versionadded:: 0.1.3
-     """
-
-     def __init__(self, size_limit: int = 1024):
-         super().__init__()
-         self.size_limit = size_limit
-
-     def call(self, model: ir.Model) -> ir.passes.PassResult:
-         graph = model.graph
-         initializers: dict[tuple[ir.DataType, tuple[int, ...], bytes], ir.Value] = {}
-         modified = False
-
-         for initializer in tuple(graph.initializers.values()):
-             # TODO(justinchuby): Handle subgraphs as well. For now users can lift initializers
-             # out from the main graph before running this pass.
-             const_val = initializer.const_value
-             if const_val is None:
-                 # Skip if initializer has no constant value
-                 continue
-
-             if const_val.size > self.size_limit:
-                 continue
-
-             key = (const_val.dtype, tuple(const_val.shape), const_val.tobytes())
-             if key in initializers:
-                 modified = True
-                 ir.convenience.replace_all_uses_with(initializer, initializers[key])  # type: ignore[index]
-                 assert initializer.name is not None
-                 graph.initializers.pop(initializer.name)
-             else:
-                 initializers[key] = initializer  # type: ignore[index]
-
-         return ir.passes.PassResult(model=model, modified=modified)