onnx-ir 0.1.9__tar.gz → 0.1.11__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of onnx-ir might be problematic. See the package registry page for more details.

Files changed (52)
  1. {onnx_ir-0.1.9/src/onnx_ir.egg-info → onnx_ir-0.1.11}/PKG-INFO +1 -6
  2. {onnx_ir-0.1.9 → onnx_ir-0.1.11}/pyproject.toml +0 -5
  3. {onnx_ir-0.1.9 → onnx_ir-0.1.11}/src/onnx_ir/__init__.py +1 -1
  4. {onnx_ir-0.1.9 → onnx_ir-0.1.11}/src/onnx_ir/_convenience/_constructors.py +4 -1
  5. {onnx_ir-0.1.9 → onnx_ir-0.1.11}/src/onnx_ir/_core.py +207 -18
  6. {onnx_ir-0.1.9 → onnx_ir-0.1.11}/src/onnx_ir/external_data.py +10 -4
  7. {onnx_ir-0.1.9 → onnx_ir-0.1.11}/src/onnx_ir/passes/common/constant_manipulation.py +5 -0
  8. {onnx_ir-0.1.9 → onnx_ir-0.1.11}/src/onnx_ir/passes/common/identity_elimination.py +28 -0
  9. {onnx_ir-0.1.9 → onnx_ir-0.1.11}/src/onnx_ir/passes/common/naming.py +1 -7
  10. {onnx_ir-0.1.9 → onnx_ir-0.1.11}/src/onnx_ir/serde.py +33 -7
  11. {onnx_ir-0.1.9 → onnx_ir-0.1.11}/src/onnx_ir/tensor_adapters.py +16 -8
  12. {onnx_ir-0.1.9 → onnx_ir-0.1.11/src/onnx_ir.egg-info}/PKG-INFO +1 -6
  13. {onnx_ir-0.1.9 → onnx_ir-0.1.11}/LICENSE +0 -0
  14. {onnx_ir-0.1.9 → onnx_ir-0.1.11}/MANIFEST.in +0 -0
  15. {onnx_ir-0.1.9 → onnx_ir-0.1.11}/README.md +0 -0
  16. {onnx_ir-0.1.9 → onnx_ir-0.1.11}/setup.cfg +0 -0
  17. {onnx_ir-0.1.9 → onnx_ir-0.1.11}/src/onnx_ir/_convenience/__init__.py +0 -0
  18. {onnx_ir-0.1.9 → onnx_ir-0.1.11}/src/onnx_ir/_display.py +0 -0
  19. {onnx_ir-0.1.9 → onnx_ir-0.1.11}/src/onnx_ir/_enums.py +0 -0
  20. {onnx_ir-0.1.9 → onnx_ir-0.1.11}/src/onnx_ir/_graph_comparison.py +0 -0
  21. {onnx_ir-0.1.9 → onnx_ir-0.1.11}/src/onnx_ir/_graph_containers.py +0 -0
  22. {onnx_ir-0.1.9 → onnx_ir-0.1.11}/src/onnx_ir/_io.py +0 -0
  23. {onnx_ir-0.1.9 → onnx_ir-0.1.11}/src/onnx_ir/_linked_list.py +0 -0
  24. {onnx_ir-0.1.9 → onnx_ir-0.1.11}/src/onnx_ir/_metadata.py +0 -0
  25. {onnx_ir-0.1.9 → onnx_ir-0.1.11}/src/onnx_ir/_name_authority.py +0 -0
  26. {onnx_ir-0.1.9 → onnx_ir-0.1.11}/src/onnx_ir/_polyfill.py +0 -0
  27. {onnx_ir-0.1.9 → onnx_ir-0.1.11}/src/onnx_ir/_protocols.py +0 -0
  28. {onnx_ir-0.1.9 → onnx_ir-0.1.11}/src/onnx_ir/_tape.py +0 -0
  29. {onnx_ir-0.1.9 → onnx_ir-0.1.11}/src/onnx_ir/_thirdparty/asciichartpy.py +0 -0
  30. {onnx_ir-0.1.9 → onnx_ir-0.1.11}/src/onnx_ir/_type_casting.py +0 -0
  31. {onnx_ir-0.1.9 → onnx_ir-0.1.11}/src/onnx_ir/_version_utils.py +0 -0
  32. {onnx_ir-0.1.9 → onnx_ir-0.1.11}/src/onnx_ir/convenience.py +0 -0
  33. {onnx_ir-0.1.9 → onnx_ir-0.1.11}/src/onnx_ir/passes/__init__.py +0 -0
  34. {onnx_ir-0.1.9 → onnx_ir-0.1.11}/src/onnx_ir/passes/_pass_infra.py +0 -0
  35. {onnx_ir-0.1.9 → onnx_ir-0.1.11}/src/onnx_ir/passes/common/__init__.py +0 -0
  36. {onnx_ir-0.1.9 → onnx_ir-0.1.11}/src/onnx_ir/passes/common/_c_api_utils.py +0 -0
  37. {onnx_ir-0.1.9 → onnx_ir-0.1.11}/src/onnx_ir/passes/common/clear_metadata_and_docstring.py +0 -0
  38. {onnx_ir-0.1.9 → onnx_ir-0.1.11}/src/onnx_ir/passes/common/common_subexpression_elimination.py +0 -0
  39. {onnx_ir-0.1.9 → onnx_ir-0.1.11}/src/onnx_ir/passes/common/initializer_deduplication.py +0 -0
  40. {onnx_ir-0.1.9 → onnx_ir-0.1.11}/src/onnx_ir/passes/common/inliner.py +0 -0
  41. {onnx_ir-0.1.9 → onnx_ir-0.1.11}/src/onnx_ir/passes/common/onnx_checker.py +0 -0
  42. {onnx_ir-0.1.9 → onnx_ir-0.1.11}/src/onnx_ir/passes/common/shape_inference.py +0 -0
  43. {onnx_ir-0.1.9 → onnx_ir-0.1.11}/src/onnx_ir/passes/common/topological_sort.py +0 -0
  44. {onnx_ir-0.1.9 → onnx_ir-0.1.11}/src/onnx_ir/passes/common/unused_removal.py +0 -0
  45. {onnx_ir-0.1.9 → onnx_ir-0.1.11}/src/onnx_ir/py.typed +0 -0
  46. {onnx_ir-0.1.9 → onnx_ir-0.1.11}/src/onnx_ir/tape.py +0 -0
  47. {onnx_ir-0.1.9 → onnx_ir-0.1.11}/src/onnx_ir/testing.py +0 -0
  48. {onnx_ir-0.1.9 → onnx_ir-0.1.11}/src/onnx_ir/traversal.py +0 -0
  49. {onnx_ir-0.1.9 → onnx_ir-0.1.11}/src/onnx_ir.egg-info/SOURCES.txt +0 -0
  50. {onnx_ir-0.1.9 → onnx_ir-0.1.11}/src/onnx_ir.egg-info/dependency_links.txt +0 -0
  51. {onnx_ir-0.1.9 → onnx_ir-0.1.11}/src/onnx_ir.egg-info/requires.txt +0 -0
  52. {onnx_ir-0.1.9 → onnx_ir-0.1.11}/src/onnx_ir.egg-info/top_level.txt +0 -0
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.4
2
2
  Name: onnx-ir
3
- Version: 0.1.9
3
+ Version: 0.1.11
4
4
  Summary: Efficient in-memory representation for ONNX
5
5
  Author-email: ONNX Contributors <onnx-technical-discuss@lists.lfaidata.foundation>
6
6
  License-Expression: Apache-2.0
@@ -8,11 +8,6 @@ Project-URL: Homepage, https://onnx.ai/ir-py
8
8
  Project-URL: Issues, https://github.com/onnx/ir-py/issues
9
9
  Project-URL: Repository, https://github.com/onnx/ir-py
10
10
  Classifier: Development Status :: 4 - Beta
11
- Classifier: Programming Language :: Python :: 3.9
12
- Classifier: Programming Language :: Python :: 3.10
13
- Classifier: Programming Language :: Python :: 3.11
14
- Classifier: Programming Language :: Python :: 3.12
15
- Classifier: Programming Language :: Python :: 3.13
16
11
  Requires-Python: >=3.9
17
12
  Description-Content-Type: text/markdown
18
13
  License-File: LICENSE
@@ -15,11 +15,6 @@ license = "Apache-2.0"
15
15
  license-files = ["LICEN[CS]E*"]
16
16
  classifiers = [
17
17
  "Development Status :: 4 - Beta",
18
- "Programming Language :: Python :: 3.9",
19
- "Programming Language :: Python :: 3.10",
20
- "Programming Language :: Python :: 3.11",
21
- "Programming Language :: Python :: 3.12",
22
- "Programming Language :: Python :: 3.13",
23
18
  ]
24
19
  dependencies = ["numpy", "onnx>=1.16", "typing_extensions>=4.10", "ml_dtypes>=0.5.0"]
25
20
 
@@ -168,4 +168,4 @@ def __set_module() -> None:
168
168
 
169
169
 
170
170
  __set_module()
171
- __version__ = "0.1.9"
171
+ __version__ = "0.1.11"
@@ -224,6 +224,7 @@ def val(
224
224
  *,
225
225
  type: ir.TypeProtocol | None = None,
226
226
  const_value: ir.TensorProtocol | None = None,
227
+ metadata_props: dict[str, str] | None = None,
227
228
  ) -> ir.Value:
228
229
  """Create a :class:`~onnx_ir.Value` with the given name and type.
229
230
 
@@ -253,6 +254,7 @@ def val(
253
254
  type: The type of the value. Only one of dtype and type can be specified.
254
255
  const_value: The constant tensor that initializes the value. Supply this argument
255
256
  when you want to create an initializer. The type and shape can be obtained from the tensor.
257
+ metadata_props: The metadata properties that will be serialized to the ONNX proto.
256
258
 
257
259
  Returns:
258
260
  A Value object.
@@ -279,10 +281,11 @@ def val(
279
281
  type=const_tensor_type,
280
282
  shape=_core.Shape(const_value.shape), # type: ignore
281
283
  const_value=const_value,
284
+ metadata_props=metadata_props,
282
285
  )
283
286
 
284
287
  if type is None and dtype is not None:
285
288
  type = _core.TensorType(dtype)
286
289
  if shape is not None and not isinstance(shape, _core.Shape):
287
290
  shape = _core.Shape(shape)
288
- return _core.Value(name=name, type=type, shape=shape)
291
+ return _core.Value(name=name, type=type, shape=shape, metadata_props=metadata_props)
@@ -165,6 +165,11 @@ class TensorBase(abc.ABC, _protocols.TensorProtocol, _display.PrettyPrintable):
165
165
 
166
166
  @property
167
167
  def metadata_props(self) -> dict[str, str]:
168
+ """The metadata properties of the tensor.
169
+
170
+ The metadata properties are used to store additional information about the tensor.
171
+ Unlike ``meta``, this property is serialized to the ONNX proto.
172
+ """
168
173
  if self._metadata_props is None:
169
174
  self._metadata_props = {}
170
175
  return self._metadata_props
@@ -180,6 +185,19 @@ class TensorBase(abc.ABC, _protocols.TensorProtocol, _display.PrettyPrintable):
180
185
  self._metadata = _metadata.MetadataStore()
181
186
  return self._metadata
182
187
 
188
+ def tofile(self, file) -> None:
189
+ """Write the tensor to a binary file.
190
+
191
+ This method writes the raw bytes of the tensor to a file-like object.
192
+ The file-like object must have a ``write`` method that accepts bytes.
193
+
194
+ .. versionadded:: 0.1.11
195
+
196
+ Args:
197
+ file: A file-like object with a ``write`` method that accepts bytes.
198
+ """
199
+ file.write(self.tobytes())
200
+
183
201
  def display(self, *, page: bool = False) -> None:
184
202
  rich = _display.require_rich()
185
203
 
@@ -332,6 +350,38 @@ def _maybe_view_np_array_with_ml_dtypes(
332
350
  return array
333
351
 
334
352
 
353
+ def _supports_fileno(file: Any) -> bool:
354
+ """Check if the file-like object supports fileno()."""
355
+ if not hasattr(file, "fileno"):
356
+ return False
357
+ try:
358
+ file.fileno()
359
+ except Exception: # pylint: disable=broad-except
360
+ return False
361
+ return True
362
+
363
+
364
+ def _create_np_array_for_byte_representation(tensor: Tensor) -> np.ndarray:
365
+ """Create a numpy array for the byte representation of the tensor.
366
+
367
+ This function is used for serializing the tensor to bytes. It handles the
368
+ special cases for 4-bit data types and endianness.
369
+ """
370
+ array = tensor.numpy()
371
+ if tensor.dtype in {
372
+ _enums.DataType.INT4,
373
+ _enums.DataType.UINT4,
374
+ _enums.DataType.FLOAT4E2M1,
375
+ }:
376
+ # Pack the array into int4
377
+ array = _type_casting.pack_4bitx2(array)
378
+ else:
379
+ assert tensor.dtype.itemsize == array.itemsize, "Bug: The itemsize should match"
380
+ if not _IS_LITTLE_ENDIAN:
381
+ array = array.astype(array.dtype.newbyteorder("<"))
382
+ return array
383
+
384
+
335
385
  class Tensor(TensorBase, _protocols.TensorProtocol, Generic[TArrayCompatible]): # pylint: disable=too-many-ancestors
336
386
  """An immutable concrete tensor.
337
387
 
@@ -504,20 +554,24 @@ class Tensor(TensorBase, _protocols.TensorProtocol, Generic[TArrayCompatible]):
504
554
  value is not a numpy array.
505
555
  """
506
556
  # TODO(justinchuby): Support DLPack
507
- array = self.numpy()
508
- if self.dtype in {
509
- _enums.DataType.INT4,
510
- _enums.DataType.UINT4,
511
- _enums.DataType.FLOAT4E2M1,
512
- }:
513
- # Pack the array into int4
514
- array = _type_casting.pack_4bitx2(array)
515
- else:
516
- assert self.dtype.itemsize == array.itemsize, "Bug: The itemsize should match"
517
- if not _IS_LITTLE_ENDIAN:
518
- array = array.view(array.dtype.newbyteorder("<"))
557
+ array = _create_np_array_for_byte_representation(self)
519
558
  return array.tobytes()
520
559
 
560
+ def tofile(self, file) -> None:
561
+ """Write the tensor to a binary file.
562
+
563
+ .. versionadded:: 0.1.11
564
+
565
+ Args:
566
+ file: A file-like object with a ``write`` method that accepts bytes, or has an ``fileno()`` method.
567
+ """
568
+ if isinstance(self._raw, np.ndarray) and _supports_fileno(file):
569
+ # This is a duplication of tobytes() for handling special cases
570
+ array = _create_np_array_for_byte_representation(self)
571
+ array.tofile(file)
572
+ else:
573
+ file.write(self.tobytes())
574
+
521
575
 
522
576
  class ExternalTensor(TensorBase, _protocols.TensorProtocol): # pylint: disable=too-many-ancestors
523
577
  """An immutable concrete tensor with its data store on disk.
@@ -530,7 +584,7 @@ class ExternalTensor(TensorBase, _protocols.TensorProtocol): # pylint: disable=
530
584
  the tensor is recommended if IO overhead and memory usage is a concern.
531
585
 
532
586
  To obtain an array, call :meth:`numpy`. To obtain the bytes,
533
- call :meth:`tobytes`.
587
+ call :meth:`tobytes`. To write the data to a file, call :meth:`tofile`.
534
588
 
535
589
  The :attr:`location` must be a relative path conforming to the ONNX
536
590
  specification. Given the correct :attr:`base_dir`, the :attr:`path` is computed
@@ -585,7 +639,7 @@ class ExternalTensor(TensorBase, _protocols.TensorProtocol): # pylint: disable=
585
639
  length: The length of the data in bytes.
586
640
  dtype: The data type of the tensor.
587
641
  shape: The shape of the tensor.
588
- name: The name of the tensor..
642
+ name: The name of the tensor.
589
643
  doc_string: The documentation string.
590
644
  metadata_props: The metadata properties.
591
645
  base_dir: The base directory for the external data. It is used to resolve relative paths.
@@ -741,6 +795,18 @@ class ExternalTensor(TensorBase, _protocols.TensorProtocol): # pylint: disable=
741
795
  length = self._length or self.nbytes
742
796
  return self.raw[offset : offset + length]
743
797
 
798
+ def tofile(self, file) -> None:
799
+ self._check_validity()
800
+ with open(self.path, "rb") as src:
801
+ if self._offset is not None:
802
+ src.seek(self._offset)
803
+ bytes_to_copy = self._length or self.nbytes
804
+ chunk_size = 1024 * 1024 # 1MB
805
+ while bytes_to_copy > 0:
806
+ chunk = src.read(min(chunk_size, bytes_to_copy))
807
+ file.write(chunk)
808
+ bytes_to_copy -= len(chunk)
809
+
744
810
  def valid(self) -> bool:
745
811
  """Check if the tensor is valid.
746
812
 
@@ -974,6 +1040,15 @@ class LazyTensor(TensorBase, _protocols.TensorProtocol): # pylint: disable=too-
974
1040
  """Return the bytes of the tensor."""
975
1041
  return self._evaluate().tobytes()
976
1042
 
1043
+ def tofile(self, file) -> None:
1044
+ tensor = self._evaluate()
1045
+ if hasattr(tensor, "tofile"):
1046
+ # Some existing implementation of TensorProtocol
1047
+ # may not have tofile() as it was introduced in v0.1.11
1048
+ tensor.tofile(file)
1049
+ else:
1050
+ super().tofile(file)
1051
+
977
1052
 
978
1053
  class PackedTensor(TensorBase, _protocols.TensorProtocol, Generic[TArrayCompatible]): # pylint: disable=too-many-ancestors
979
1054
  """A tensor that stores 4bit datatypes in packed format.
@@ -1105,9 +1180,26 @@ class PackedTensor(TensorBase, _protocols.TensorProtocol, Generic[TArrayCompatib
1105
1180
  """
1106
1181
  array = self.numpy_packed()
1107
1182
  if not _IS_LITTLE_ENDIAN:
1108
- array = array.view(array.dtype.newbyteorder("<"))
1183
+ array = array.astype(array.dtype.newbyteorder("<"))
1109
1184
  return array.tobytes()
1110
1185
 
1186
+ def tofile(self, file) -> None:
1187
+ """Write the tensor to a binary file.
1188
+
1189
+ .. versionadded:: 0.1.11
1190
+
1191
+ Args:
1192
+ file: A file-like object with a ``write`` method that accepts bytes, or has an ``fileno()`` method.
1193
+ """
1194
+ if _supports_fileno(file):
1195
+ # This is a duplication of tobytes() for handling edge cases
1196
+ array = self.numpy_packed()
1197
+ if not _IS_LITTLE_ENDIAN:
1198
+ array = array.astype(array.dtype.newbyteorder("<"))
1199
+ array.tofile(file)
1200
+ else:
1201
+ file.write(self.tobytes())
1202
+
1111
1203
 
1112
1204
  class SymbolicDim(_protocols.SymbolicDimProtocol, _display.PrettyPrintable):
1113
1205
  """Immutable symbolic dimension that can be shared across multiple shapes.
@@ -1214,6 +1306,12 @@ class Shape(_protocols.ShapeProtocol, _display.PrettyPrintable):
1214
1306
 
1215
1307
  Use :meth:`get_denotation` and :meth:`set_denotation` to access and modify the denotations.
1216
1308
 
1309
+ .. note::
1310
+ Two shapes can be compared for equality. Be careful when comparing shapes with
1311
+ unknown dimensions (``None``), as they may not be considered semantically equal
1312
+ even if all dimensions are the same. You can use :meth:`has_unknown_dim` to
1313
+ check if a shape has any unknown dimensions.
1314
+
1217
1315
  Example::
1218
1316
 
1219
1317
  >>> import onnx_ir as ir
@@ -1422,6 +1520,29 @@ class Shape(_protocols.ShapeProtocol, _display.PrettyPrintable):
1422
1520
  return not self.is_static()
1423
1521
  return not self.is_static(dim)
1424
1522
 
1523
+ def is_unknown_dim(self, dim: int) -> bool:
1524
+ """Return True if the dimension is unknown (None).
1525
+
1526
+ A dynamic dimension without a symbolic name is considered unknown.
1527
+
1528
+ .. versionadded:: 0.1.10
1529
+
1530
+ Args:
1531
+ dim: The index of the dimension.
1532
+ """
1533
+ dim_obj = self._dims[dim]
1534
+ return isinstance(dim_obj, SymbolicDim) and dim_obj.value is None
1535
+
1536
+ def has_unknown_dim(self) -> bool:
1537
+ """Return True if any dimension is unknown (None).
1538
+
1539
+ You can use :meth:`is_unknown_dim` to check if a specific dimension is unknown.
1540
+
1541
+ .. versionadded:: 0.1.10
1542
+ """
1543
+ # We can use "in" directly because SymbolicDim implements __eq__ with None
1544
+ return None in self._dims
1545
+
1425
1546
 
1426
1547
  def _quoted(string: str) -> str:
1427
1548
  """Return a quoted string.
@@ -2022,9 +2143,18 @@ class Value(_protocols.ValueProtocol, _display.PrettyPrintable):
2022
2143
  type: _protocols.TypeProtocol | None = None,
2023
2144
  doc_string: str | None = None,
2024
2145
  const_value: _protocols.TensorProtocol | None = None,
2146
+ metadata_props: dict[str, str] | None = None,
2025
2147
  ) -> None:
2026
2148
  """Initialize a value.
2027
2149
 
2150
+ When assigning a name to the value, the name of the backing `const_value` (Tensor)
2151
+ will also be updated. If the value is an initializer of a graph, the initializers
2152
+ dictionary of the graph will also be updated.
2153
+
2154
+ .. versionchanged:: 0.1.10
2155
+ Assigning a name to the value will also update the graph initializer entry
2156
+ if the value is an initializer of a graph.
2157
+
2028
2158
  Args:
2029
2159
  producer: The node that produces the value.
2030
2160
  It can be ``None`` when the value is initialized first than its producer.
@@ -2034,11 +2164,12 @@ class Value(_protocols.ValueProtocol, _display.PrettyPrintable):
2034
2164
  type: The type of the value.
2035
2165
  doc_string: The documentation string.
2036
2166
  const_value: The constant tensor if the value is constant.
2167
+ metadata_props: Metadata that will be serialized to the ONNX file.
2037
2168
  """
2038
2169
  self._producer: Node | None = producer
2039
2170
  self._index: int | None = index
2040
2171
  self._metadata: _metadata.MetadataStore | None = None
2041
- self._metadata_props: dict[str, str] | None = None
2172
+ self._metadata_props: dict[str, str] | None = metadata_props
2042
2173
 
2043
2174
  self._name: str | None = name
2044
2175
  self._shape: Shape | None = shape
@@ -2170,10 +2301,41 @@ class Value(_protocols.ValueProtocol, _display.PrettyPrintable):
2170
2301
 
2171
2302
  @name.setter
2172
2303
  def name(self, value: str | None) -> None:
2304
+ if self._name == value:
2305
+ return
2306
+
2307
+ # First check if renaming is valid. Do not change anything if it is invalid
2308
+ # to prevent the value from being in an inconsistent state.
2309
+ is_initializer = self.is_initializer()
2310
+ if is_initializer:
2311
+ if value is None:
2312
+ raise ValueError(
2313
+ "Initializer value cannot have name set to None. Please pop() the value from initializers first to do so."
2314
+ )
2315
+ graph = self._graph
2316
+ assert graph is not None
2317
+ if value in graph.initializers and graph.initializers[value] is not self:
2318
+ raise ValueError(
2319
+ f"Cannot rename initializer '{self}' to '{value}': an initializer with that name already exists."
2320
+ )
2321
+
2322
+ # Rename the backing constant tensor
2173
2323
  if self._const_value is not None:
2174
2324
  self._const_value.name = value
2325
+
2326
+ # Rename self
2327
+ old_name = self._name
2175
2328
  self._name = value
2176
2329
 
2330
+ if is_initializer:
2331
+ # Rename the initializer entry in the graph
2332
+ assert value is not None, "debug: Should be guarded above"
2333
+ graph = self._graph
2334
+ assert graph is not None
2335
+ assert old_name is not None
2336
+ graph.initializers.pop(old_name)
2337
+ graph.initializers[value] = self
2338
+
2177
2339
  @property
2178
2340
  def type(self) -> _protocols.TypeProtocol | None:
2179
2341
  """The type of the tensor.
@@ -2226,9 +2388,16 @@ class Value(_protocols.ValueProtocol, _display.PrettyPrintable):
2226
2388
  def const_value(
2227
2389
  self,
2228
2390
  ) -> _protocols.TensorProtocol | None:
2229
- """A concrete value.
2391
+ """The backing constant tensor for the value.
2392
+
2393
+ If the ``Value`` has a ``const_value`` and is part of a graph initializers
2394
+ dictionary, the value is an initialized value. Its ``const_value``
2395
+ will appear as an ``initializer`` in the GraphProto when serialized.
2396
+
2397
+ If the ``Value`` is not part of a graph initializers dictionary, the ``const_value``
2398
+ field will be ignored during serialization.
2230
2399
 
2231
- The value can be backed by different raw data types, such as numpy arrays.
2400
+ ``const_value`` can be backed by different raw data types, such as numpy arrays.
2232
2401
  The only guarantee is that it conforms TensorProtocol.
2233
2402
  """
2234
2403
  return self._const_value
@@ -2258,6 +2427,11 @@ class Value(_protocols.ValueProtocol, _display.PrettyPrintable):
2258
2427
 
2259
2428
  @property
2260
2429
  def metadata_props(self) -> dict[str, str]:
2430
+ """The metadata properties of the value.
2431
+
2432
+ The metadata properties are used to store additional information about the value.
2433
+ Unlike ``meta``, this property is serialized to the ONNX proto.
2434
+ """
2261
2435
  if self._metadata_props is None:
2262
2436
  self._metadata_props = {}
2263
2437
  return self._metadata_props
@@ -2805,6 +2979,11 @@ class Graph(_protocols.GraphProtocol, Sequence[Node], _display.PrettyPrintable):
2805
2979
 
2806
2980
  @property
2807
2981
  def metadata_props(self) -> dict[str, str]:
2982
+ """The metadata properties of the graph.
2983
+
2984
+ The metadata properties are used to store additional information about the graph.
2985
+ Unlike ``meta``, this property is serialized to the ONNX proto.
2986
+ """
2808
2987
  if self._metadata_props is None:
2809
2988
  self._metadata_props = {}
2810
2989
  return self._metadata_props
@@ -3057,6 +3236,11 @@ class Model(_protocols.ModelProtocol, _display.PrettyPrintable):
3057
3236
 
3058
3237
  @property
3059
3238
  def metadata_props(self) -> dict[str, str]:
3239
+ """The metadata properties of the model.
3240
+
3241
+ The metadata properties are used to store additional information about the model.
3242
+ Unlike ``meta``, this property is serialized to the ONNX proto.
3243
+ """
3060
3244
  if self._metadata_props is None:
3061
3245
  self._metadata_props = {}
3062
3246
  return self._metadata_props
@@ -3250,6 +3434,11 @@ class Function(_protocols.FunctionProtocol, Sequence[Node], _display.PrettyPrint
3250
3434
 
3251
3435
  @property
3252
3436
  def metadata_props(self) -> dict[str, str]:
3437
+ """The metadata properties of the function.
3438
+
3439
+ The metadata properties are used to store additional information about the function.
3440
+ Unlike ``meta``, this property is serialized to the ONNX proto.
3441
+ """
3253
3442
  return self._graph.metadata_props
3254
3443
 
3255
3444
  def all_nodes(self) -> Iterator[Node]:
@@ -205,14 +205,20 @@ def _write_external_data(
205
205
  )
206
206
  current_offset = tensor_info.offset
207
207
  assert tensor is not None
208
- raw_data = tensor.tobytes()
209
- if isinstance(tensor, _core.ExternalTensor):
210
- tensor.release()
211
208
  # Pad file to required offset if needed
212
209
  file_size = data_file.tell()
213
210
  if current_offset > file_size:
214
211
  data_file.write(b"\0" * (current_offset - file_size))
215
- data_file.write(raw_data)
212
+
213
+ if hasattr(tensor, "tofile"):
214
+ # Some existing implementation of TensorProtocol
215
+ # may not have tofile() as it was introduced in v0.1.11
216
+ tensor.tofile(data_file)
217
+ else:
218
+ raw_data = tensor.tobytes()
219
+ if isinstance(tensor, _core.ExternalTensor):
220
+ tensor.release()
221
+ data_file.write(raw_data)
216
222
 
217
223
 
218
224
  def _create_external_tensor(
@@ -69,7 +69,12 @@ class LiftConstantsToInitializersPass(ir.passes.InPlacePass):
69
69
  shape=tensor.shape, # type: ignore[arg-type]
70
70
  type=ir.TensorType(tensor.dtype),
71
71
  const_value=tensor,
72
+ # Preserve metadata from Constant value into the onnx model
73
+ metadata_props=node.outputs[0].metadata_props.copy(),
72
74
  )
75
+ # Preserve value meta from the Constant output for intermediate analysis
76
+ initializer.meta.update(node.outputs[0].meta)
77
+
73
78
  assert node.graph is not None
74
79
  node.graph.register_initializer(initializer)
75
80
  # Replace the constant node with the initializer
@@ -15,6 +15,29 @@ import onnx_ir as ir
15
15
  logger = logging.getLogger(__name__)
16
16
 
17
17
 
18
+ def _merge_shapes(shape1: ir.Shape | None, shape2: ir.Shape | None) -> ir.Shape | None:
19
+ def merge_dims(dim1, dim2):
20
+ if dim1 == dim2:
21
+ return dim1
22
+ if not isinstance(dim1, ir.SymbolicDim):
23
+ return dim1 # Prefer int value over symbolic dim
24
+ if not isinstance(dim2, ir.SymbolicDim):
25
+ return dim2
26
+ if dim1.value is None:
27
+ return dim2
28
+ return dim1
29
+
30
+ if shape1 is None:
31
+ return shape2
32
+ if shape2 is None:
33
+ return shape1
34
+ if len(shape1) != len(shape2):
35
+ raise ValueError(
36
+ f"Shapes must have the same rank, got {len(shape1)} and {len(shape2)}."
37
+ )
38
+ return ir.Shape([merge_dims(dim1, dim2) for dim1, dim2 in zip(shape1, shape2)])
39
+
40
+
18
41
  class IdentityEliminationPass(ir.passes.InPlacePass):
19
42
  """Pass for eliminating redundant Identity nodes.
20
43
 
@@ -75,6 +98,11 @@ class IdentityEliminationPass(ir.passes.InPlacePass):
75
98
  if output_is_graph_output and input_is_graph_input:
76
99
  return False
77
100
 
101
+ # Copy over shape/type if the output has more complete information
102
+ input_value.shape = _merge_shapes(input_value.shape, output_value.shape)
103
+ if input_value.type is None:
104
+ input_value.type = output_value.type
105
+
78
106
  # Case 1 & 2 (merged): Eliminate the identity node
79
107
  # Replace all uses of output with input
80
108
  ir.convenience.replace_all_uses_with(output_value, input_value)
@@ -193,14 +193,8 @@ class NameFixPass(ir.passes.InPlacePass):
193
193
  if not value.name:
194
194
  modified = self._assign_value_name(value, used_value_names, value_counter)
195
195
  else:
196
- old_name = value.name
197
196
  modified = self._fix_duplicate_value_name(value, used_value_names, value_counter)
198
- if modified:
199
- assert value.graph is not None
200
- if value.is_initializer():
201
- value.graph.initializers.pop(old_name)
202
- # Add the initializer back with the new name
203
- value.graph.initializers.add(value)
197
+ # initializers dictionary is updated automatically when the Value is renamed
204
198
 
205
199
  # Record the final name for this value
206
200
  assert value.name is not None
@@ -709,8 +709,7 @@ def _deserialize_graph(
709
709
  annotation.tensor_name: annotation for annotation in proto.quantization_annotation
710
710
  }
711
711
 
712
- # Create values for initializers and inputs
713
- initializer_tensors = [deserialize_tensor(tensor) for tensor in proto.initializer]
712
+ # Create values for inputs
714
713
  inputs = [_core.Value(name=info.name) for info in proto.input]
715
714
  for info, value in zip(proto.input, inputs):
716
715
  deserialize_value_info_proto(info, value)
@@ -725,6 +724,11 @@ def _deserialize_graph(
725
724
  # Enter the graph scope by pushing the values for this scope to the stack
726
725
  scoped_values.append(values)
727
726
 
727
+ # Build the value info dictionary to allow for quick lookup for this graph scope
728
+ value_info = {info.name: info for info in proto.value_info}
729
+
730
+ # Create values for initializers
731
+ initializer_tensors = [deserialize_tensor(tensor) for tensor in proto.initializer]
728
732
  initializer_values = []
729
733
  for i, tensor in enumerate(initializer_tensors):
730
734
  initializer_name = tensor.name
@@ -750,6 +754,8 @@ def _deserialize_graph(
750
754
  shape=tensor.shape, # type: ignore[arg-type]
751
755
  const_value=tensor,
752
756
  )
757
+ if initializer_name in value_info:
758
+ deserialize_value_info_proto(value_info[initializer_name], initializer_value)
753
759
  if initializer_value.name in quantization_annotations:
754
760
  _deserialize_quantization_annotation(
755
761
  quantization_annotations[initializer_value.name], initializer_value
@@ -757,9 +763,6 @@ def _deserialize_graph(
757
763
  values[initializer_name] = initializer_value
758
764
  initializer_values.append(initializer_value)
759
765
 
760
- # Build the value info dictionary to allow for quick lookup for this graph scope
761
- value_info = {info.name: info for info in proto.value_info}
762
-
763
766
  # Declare values for all node outputs from this graph scope. This is necessary
764
767
  # to handle the case where a node in a subgraph uses a value that is declared out
765
768
  # of order in the outer graph. Declaring the values first allows us to find the
@@ -1390,7 +1393,12 @@ def _should_create_value_info_for_value(value: _protocols.ValueProtocol) -> bool
1390
1393
  True if value info should be created for the value.
1391
1394
  """
1392
1395
  # No need to serialize value info if it is not set
1393
- if value.shape is None and value.type is None:
1396
+ if (
1397
+ value.shape is None
1398
+ and value.type is None
1399
+ and not value.metadata_props
1400
+ and not value.doc_string
1401
+ ):
1394
1402
  return False
1395
1403
  if not value.name:
1396
1404
  logger.debug("Did not serialize '%s' because its name is empty", value)
@@ -1967,11 +1975,26 @@ def serialize_type(type_protocol: _protocols.TypeProtocol) -> onnx.TypeProto:
1967
1975
  @_capture_errors(lambda type_proto, from_: repr(from_))
1968
1976
  def serialize_shape_into(type_proto: onnx.TypeProto, from_: _protocols.ShapeProtocol) -> None:
1969
1977
  value_field = type_proto.WhichOneof("value")
1978
+ if value_field is None:
1979
+ # We cannot write the shape because we do not know where to write it
1980
+ logger.warning(
1981
+ # TODO(justinchuby): Show more context about the value when move everything to an object
1982
+ "The value type for shape %s is not known. Please set type for the value. Skipping serialization",
1983
+ from_,
1984
+ )
1985
+ return
1970
1986
  tensor_type = getattr(type_proto, value_field)
1971
1987
  while not isinstance(tensor_type.elem_type, int):
1972
1988
  # Find the leaf type that has the shape field
1973
1989
  type_proto = tensor_type.elem_type
1974
1990
  value_field = type_proto.WhichOneof("value")
1991
+ if value_field is None:
1992
+ logger.warning(
1993
+ # TODO(justinchuby): Show more context about the value when move everything to an object
1994
+ "The value type for shape %s is not known. Please set type for the value. Skipping serialization",
1995
+ from_,
1996
+ )
1997
+ return
1975
1998
  tensor_type = getattr(type_proto, value_field)
1976
1999
  # When from is empty, we still need to set the shape field to an empty list by touching it
1977
2000
  tensor_type.shape.ClearField("dim")
@@ -1992,5 +2015,8 @@ def serialize_dimension_into(
1992
2015
  dim_proto.dim_value = dim
1993
2016
  elif isinstance(dim, (_core.SymbolicDim, _protocols.SymbolicDimProtocol)):
1994
2017
  if dim.value is not None:
1995
- # TODO(justinchuby): None is probably not a valid value for dim_param
1996
2018
  dim_proto.dim_param = str(dim.value)
2019
+ # NOTE: None is a valid value for symbolic dimension:
2020
+ # A dimension MAY have neither dim_value nor dim_param set. Such a dimension
2021
+ # represents an unknown dimension unrelated to other unknown dimensions.
2022
+ # Here we will just leave the dim_proto empty.
@@ -168,10 +168,8 @@ class TorchTensor(_core.Tensor):
168
168
  return self.numpy()
169
169
  return self.numpy().__array__(dtype)
170
170
 
171
- def tobytes(self) -> bytes:
172
- # Implement tobytes to support native PyTorch types so we can use types like bloat16
173
- # Reading from memory directly is also more efficient because
174
- # it avoids copying to a NumPy array
171
+ def _get_cbytes(self):
172
+ """Get a ctypes byte array pointing to the tensor data."""
175
173
  import torch._subclasses.fake_tensor
176
174
 
177
175
  with torch._subclasses.fake_tensor.unset_fake_temporarily(): # pylint: disable=protected-access
@@ -185,8 +183,18 @@ class TorchTensor(_core.Tensor):
185
183
  "or save the model without initializers by setting include_initializers=False."
186
184
  )
187
185
 
188
- return bytes(
189
- (ctypes.c_ubyte * tensor.element_size() * tensor.numel()).from_address(
190
- tensor.data_ptr()
191
- )
186
+ # Return the tensor to ensure it is not garbage collected while the ctypes array is in use
187
+ return tensor, (ctypes.c_ubyte * tensor.element_size() * tensor.numel()).from_address(
188
+ tensor.data_ptr()
192
189
  )
190
+
191
+ def tobytes(self) -> bytes:
192
+ # Implement tobytes to support native PyTorch types so we can use types like bloat16
193
+ # Reading from memory directly is also more efficient because
194
+ # it avoids copying to a NumPy array
195
+ _, data = self._get_cbytes()
196
+ return bytes(data)
197
+
198
+ def tofile(self, file) -> None:
199
+ _, data = self._get_cbytes()
200
+ return file.write(data)
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.4
2
2
  Name: onnx-ir
3
- Version: 0.1.9
3
+ Version: 0.1.11
4
4
  Summary: Efficient in-memory representation for ONNX
5
5
  Author-email: ONNX Contributors <onnx-technical-discuss@lists.lfaidata.foundation>
6
6
  License-Expression: Apache-2.0
@@ -8,11 +8,6 @@ Project-URL: Homepage, https://onnx.ai/ir-py
8
8
  Project-URL: Issues, https://github.com/onnx/ir-py/issues
9
9
  Project-URL: Repository, https://github.com/onnx/ir-py
10
10
  Classifier: Development Status :: 4 - Beta
11
- Classifier: Programming Language :: Python :: 3.9
12
- Classifier: Programming Language :: Python :: 3.10
13
- Classifier: Programming Language :: Python :: 3.11
14
- Classifier: Programming Language :: Python :: 3.12
15
- Classifier: Programming Language :: Python :: 3.13
16
11
  Requires-Python: >=3.9
17
12
  Description-Content-Type: text/markdown
18
13
  License-File: LICENSE
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes