onnx-ir 0.1.10__tar.gz → 0.1.12__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Files changed (52)
  1. {onnx_ir-0.1.10/src/onnx_ir.egg-info → onnx_ir-0.1.12}/PKG-INFO +1 -7
  2. {onnx_ir-0.1.10 → onnx_ir-0.1.12}/README.md +0 -1
  3. {onnx_ir-0.1.10 → onnx_ir-0.1.12}/pyproject.toml +0 -5
  4. {onnx_ir-0.1.10 → onnx_ir-0.1.12}/src/onnx_ir/__init__.py +1 -1
  5. {onnx_ir-0.1.10 → onnx_ir-0.1.12}/src/onnx_ir/_convenience/__init__.py +25 -14
  6. {onnx_ir-0.1.10 → onnx_ir-0.1.12}/src/onnx_ir/_core.py +175 -24
  7. {onnx_ir-0.1.10 → onnx_ir-0.1.12}/src/onnx_ir/_protocols.py +11 -0
  8. {onnx_ir-0.1.10 → onnx_ir-0.1.12}/src/onnx_ir/external_data.py +10 -4
  9. {onnx_ir-0.1.10 → onnx_ir-0.1.12}/src/onnx_ir/passes/common/common_subexpression_elimination.py +3 -2
  10. {onnx_ir-0.1.10 → onnx_ir-0.1.12}/src/onnx_ir/passes/common/constant_manipulation.py +1 -1
  11. {onnx_ir-0.1.10 → onnx_ir-0.1.12}/src/onnx_ir/passes/common/identity_elimination.py +32 -10
  12. {onnx_ir-0.1.10 → onnx_ir-0.1.12}/src/onnx_ir/passes/common/initializer_deduplication.py +2 -2
  13. {onnx_ir-0.1.10 → onnx_ir-0.1.12}/src/onnx_ir/tensor_adapters.py +16 -8
  14. {onnx_ir-0.1.10 → onnx_ir-0.1.12/src/onnx_ir.egg-info}/PKG-INFO +1 -7
  15. {onnx_ir-0.1.10 → onnx_ir-0.1.12}/LICENSE +0 -0
  16. {onnx_ir-0.1.10 → onnx_ir-0.1.12}/MANIFEST.in +0 -0
  17. {onnx_ir-0.1.10 → onnx_ir-0.1.12}/setup.cfg +0 -0
  18. {onnx_ir-0.1.10 → onnx_ir-0.1.12}/src/onnx_ir/_convenience/_constructors.py +0 -0
  19. {onnx_ir-0.1.10 → onnx_ir-0.1.12}/src/onnx_ir/_display.py +0 -0
  20. {onnx_ir-0.1.10 → onnx_ir-0.1.12}/src/onnx_ir/_enums.py +0 -0
  21. {onnx_ir-0.1.10 → onnx_ir-0.1.12}/src/onnx_ir/_graph_comparison.py +0 -0
  22. {onnx_ir-0.1.10 → onnx_ir-0.1.12}/src/onnx_ir/_graph_containers.py +0 -0
  23. {onnx_ir-0.1.10 → onnx_ir-0.1.12}/src/onnx_ir/_io.py +0 -0
  24. {onnx_ir-0.1.10 → onnx_ir-0.1.12}/src/onnx_ir/_linked_list.py +0 -0
  25. {onnx_ir-0.1.10 → onnx_ir-0.1.12}/src/onnx_ir/_metadata.py +0 -0
  26. {onnx_ir-0.1.10 → onnx_ir-0.1.12}/src/onnx_ir/_name_authority.py +0 -0
  27. {onnx_ir-0.1.10 → onnx_ir-0.1.12}/src/onnx_ir/_polyfill.py +0 -0
  28. {onnx_ir-0.1.10 → onnx_ir-0.1.12}/src/onnx_ir/_tape.py +0 -0
  29. {onnx_ir-0.1.10 → onnx_ir-0.1.12}/src/onnx_ir/_thirdparty/asciichartpy.py +0 -0
  30. {onnx_ir-0.1.10 → onnx_ir-0.1.12}/src/onnx_ir/_type_casting.py +0 -0
  31. {onnx_ir-0.1.10 → onnx_ir-0.1.12}/src/onnx_ir/_version_utils.py +0 -0
  32. {onnx_ir-0.1.10 → onnx_ir-0.1.12}/src/onnx_ir/convenience.py +0 -0
  33. {onnx_ir-0.1.10 → onnx_ir-0.1.12}/src/onnx_ir/passes/__init__.py +0 -0
  34. {onnx_ir-0.1.10 → onnx_ir-0.1.12}/src/onnx_ir/passes/_pass_infra.py +0 -0
  35. {onnx_ir-0.1.10 → onnx_ir-0.1.12}/src/onnx_ir/passes/common/__init__.py +0 -0
  36. {onnx_ir-0.1.10 → onnx_ir-0.1.12}/src/onnx_ir/passes/common/_c_api_utils.py +0 -0
  37. {onnx_ir-0.1.10 → onnx_ir-0.1.12}/src/onnx_ir/passes/common/clear_metadata_and_docstring.py +0 -0
  38. {onnx_ir-0.1.10 → onnx_ir-0.1.12}/src/onnx_ir/passes/common/inliner.py +0 -0
  39. {onnx_ir-0.1.10 → onnx_ir-0.1.12}/src/onnx_ir/passes/common/naming.py +0 -0
  40. {onnx_ir-0.1.10 → onnx_ir-0.1.12}/src/onnx_ir/passes/common/onnx_checker.py +0 -0
  41. {onnx_ir-0.1.10 → onnx_ir-0.1.12}/src/onnx_ir/passes/common/shape_inference.py +0 -0
  42. {onnx_ir-0.1.10 → onnx_ir-0.1.12}/src/onnx_ir/passes/common/topological_sort.py +0 -0
  43. {onnx_ir-0.1.10 → onnx_ir-0.1.12}/src/onnx_ir/passes/common/unused_removal.py +0 -0
  44. {onnx_ir-0.1.10 → onnx_ir-0.1.12}/src/onnx_ir/py.typed +0 -0
  45. {onnx_ir-0.1.10 → onnx_ir-0.1.12}/src/onnx_ir/serde.py +0 -0
  46. {onnx_ir-0.1.10 → onnx_ir-0.1.12}/src/onnx_ir/tape.py +0 -0
  47. {onnx_ir-0.1.10 → onnx_ir-0.1.12}/src/onnx_ir/testing.py +0 -0
  48. {onnx_ir-0.1.10 → onnx_ir-0.1.12}/src/onnx_ir/traversal.py +0 -0
  49. {onnx_ir-0.1.10 → onnx_ir-0.1.12}/src/onnx_ir.egg-info/SOURCES.txt +0 -0
  50. {onnx_ir-0.1.10 → onnx_ir-0.1.12}/src/onnx_ir.egg-info/dependency_links.txt +0 -0
  51. {onnx_ir-0.1.10 → onnx_ir-0.1.12}/src/onnx_ir.egg-info/requires.txt +0 -0
  52. {onnx_ir-0.1.10 → onnx_ir-0.1.12}/src/onnx_ir.egg-info/top_level.txt +0 -0
{onnx_ir-0.1.10/src/onnx_ir.egg-info → onnx_ir-0.1.12}/PKG-INFO
@@ -1,6 +1,6 @@
  Metadata-Version: 2.4
  Name: onnx-ir
- Version: 0.1.10
+ Version: 0.1.12
  Summary: Efficient in-memory representation for ONNX
  Author-email: ONNX Contributors <onnx-technical-discuss@lists.lfaidata.foundation>
  License-Expression: Apache-2.0
@@ -8,11 +8,6 @@ Project-URL: Homepage, https://onnx.ai/ir-py
  Project-URL: Issues, https://github.com/onnx/ir-py/issues
  Project-URL: Repository, https://github.com/onnx/ir-py
  Classifier: Development Status :: 4 - Beta
- Classifier: Programming Language :: Python :: 3.9
- Classifier: Programming Language :: Python :: 3.10
- Classifier: Programming Language :: Python :: 3.11
- Classifier: Programming Language :: Python :: 3.12
- Classifier: Programming Language :: Python :: 3.13
  Requires-Python: >=3.9
  Description-Content-Type: text/markdown
  License-File: LICENSE
@@ -25,7 +20,6 @@ Dynamic: license-file
  # <img src="docs/_static/logo-light.png" alt="ONNX IR" width="250"/>

  [![PyPI - Version](https://img.shields.io/pypi/v/onnx-ir.svg)](https://pypi.org/project/onnx-ir)
- [![PyPI - Python Version](https://img.shields.io/pypi/pyversions/onnx-ir.svg)](https://pypi.org/project/onnx-ir)
  [![Ruff](https://img.shields.io/endpoint?url=https://raw.githubusercontent.com/astral-sh/ruff/main/assets/badge/v2.json)](https://github.com/astral-sh/ruff)
  [![codecov](https://codecov.io/gh/onnx/ir-py/graph/badge.svg?token=SPQ3G9T78Z)](https://codecov.io/gh/onnx/ir-py)
  [![PyPI Downloads](https://static.pepy.tech/badge/onnx-ir/month)](https://pepy.tech/projects/onnx-ir)

{onnx_ir-0.1.10 → onnx_ir-0.1.12}/README.md
@@ -1,7 +1,6 @@
  # <img src="docs/_static/logo-light.png" alt="ONNX IR" width="250"/>

  [![PyPI - Version](https://img.shields.io/pypi/v/onnx-ir.svg)](https://pypi.org/project/onnx-ir)
- [![PyPI - Python Version](https://img.shields.io/pypi/pyversions/onnx-ir.svg)](https://pypi.org/project/onnx-ir)
  [![Ruff](https://img.shields.io/endpoint?url=https://raw.githubusercontent.com/astral-sh/ruff/main/assets/badge/v2.json)](https://github.com/astral-sh/ruff)
  [![codecov](https://codecov.io/gh/onnx/ir-py/graph/badge.svg?token=SPQ3G9T78Z)](https://codecov.io/gh/onnx/ir-py)
  [![PyPI Downloads](https://static.pepy.tech/badge/onnx-ir/month)](https://pepy.tech/projects/onnx-ir)

{onnx_ir-0.1.10 → onnx_ir-0.1.12}/pyproject.toml
@@ -15,11 +15,6 @@ license = "Apache-2.0"
  license-files = ["LICEN[CS]E*"]
  classifiers = [
      "Development Status :: 4 - Beta",
-     "Programming Language :: Python :: 3.9",
-     "Programming Language :: Python :: 3.10",
-     "Programming Language :: Python :: 3.11",
-     "Programming Language :: Python :: 3.12",
-     "Programming Language :: Python :: 3.13",
  ]
  dependencies = ["numpy", "onnx>=1.16", "typing_extensions>=4.10", "ml_dtypes>=0.5.0"]

{onnx_ir-0.1.10 → onnx_ir-0.1.12}/src/onnx_ir/__init__.py
@@ -168,4 +168,4 @@ def __set_module() -> None:


  __set_module()
- __version__ = "0.1.10"
+ __version__ = "0.1.12"

{onnx_ir-0.1.10 → onnx_ir-0.1.12}/src/onnx_ir/_convenience/__init__.py
@@ -280,6 +280,7 @@ def convert_attributes(
  def replace_all_uses_with(
      values: _protocols.ValueProtocol | Sequence[_protocols.ValueProtocol],
      replacements: _protocols.ValueProtocol | Sequence[_protocols.ValueProtocol],
+     replace_graph_outputs: bool = False,
  ) -> None:
      """Replace all uses of the given values with the replacements.

@@ -318,9 +319,22 @@ def replace_all_uses_with(
      replaced are part of the graph outputs. Be sure to remove the old nodes
      from the graph using ``graph.remove()`` if they are no longer needed.

+     .. versionadded:: 0.1.12
+         The ``replace_graph_outputs`` parameter is added.
+
+     .. versionadded:: 0.1.12
+         ValueError is raised when ``replace_graph_outputs`` is False && when the value to
+         replace is a graph output.
+
      Args:
          values: The value or values to be replaced.
          replacements: The new value or values to use as inputs.
+         replace_graph_outputs: If True, graph outputs that reference the values
+             being replaced will also be updated to reference the replacements.
+
+     Raises:
+         ValueError: When ``replace_graph_outputs`` is False && when the value to
+             replace is a graph output.
      """
      if not isinstance(values, Sequence):
          values = (values,)
@@ -329,8 +343,7 @@ def replace_all_uses_with(
      if len(values) != len(replacements):
          raise ValueError("The number of values and replacements must match.")
      for value, replacement in zip(values, replacements):
-         for user_node, index in tuple(value.uses()):
-             user_node.replace_input_with(index, replacement)
+         value.replace_all_uses_with(replacement, replace_graph_outputs=replace_graph_outputs)


  def create_value_mapping(graph: _core.Graph) -> dict[str, _core.Value]:
@@ -397,20 +410,18 @@ def replace_nodes_and_values(
      """
      for old_value, new_value in zip(old_values, new_values):
          # Propagate relevant info from old value to new value
-         # TODO(Rama): Perhaps this should be a separate utility function. Also, consider
-         # merging old and new type/shape info.
-         new_value.type = old_value.type
-         new_value.shape = old_value.shape
-         new_value.const_value = old_value.const_value
-         new_value.name = old_value.name
+         # TODO(Rama): Perhaps this should be a separate utility function.
+         new_value.type = old_value.type if old_value.type is not None else new_value.type
+         new_value.shape = old_value.shape if old_value.shape is not None else new_value.shape
+         new_value.const_value = (
+             old_value.const_value
+             if old_value.const_value is not None
+             else new_value.const_value
+         )
+         new_value.name = old_value.name if old_value.name is not None else new_value.name

      # Reconnect the users of the deleted values to use the new values
-     replace_all_uses_with(old_values, new_values)
-     # Update graph/function outputs if the node generates output
-     replacement_mapping = dict(zip(old_values, new_values))
-     for idx, graph_or_function_output in enumerate(graph_or_function.outputs):
-         if graph_or_function_output in replacement_mapping:
-             graph_or_function.outputs[idx] = replacement_mapping[graph_or_function_output]
+     replace_all_uses_with(old_values, new_values, replace_graph_outputs=True)

      # insert new nodes after the index node
      graph_or_function.insert_after(insertion_point, new_nodes)
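
The hunks above add a ``replace_graph_outputs`` flag to ``onnx_ir.convenience.replace_all_uses_with`` and route the per-value work through the new ``Value.replace_all_uses_with`` method. A minimal sketch of how the flag behaves; the tiny graph built here is illustrative only, assuming the usual ``ir.Value``/``ir.Node``/``ir.Graph`` constructors rather than anything introduced in this diff:

    import onnx_ir as ir

    x = ir.Value(name="x", type=ir.TensorType(ir.DataType.FLOAT), shape=ir.Shape([2]))
    relu = ir.Node("", "Relu", inputs=[x], num_outputs=1)
    y = relu.outputs[0]
    y.name = "y"
    graph = ir.Graph(inputs=[x], outputs=[y], nodes=[relu], name="demo")

    # Suppose a rewrite produced a replacement value for ``y``.
    identity = ir.Node("", "Identity", inputs=[x], num_outputs=1)
    graph.append(identity)
    new_y = identity.outputs[0]

    # In 0.1.12, replacing a value that is a graph output raises ValueError unless
    # replace_graph_outputs=True, in which case the output list is rewired too.
    ir.convenience.replace_all_uses_with(y, new_y, replace_graph_outputs=True)
    assert graph.outputs[0] is new_y
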
{onnx_ir-0.1.10 → onnx_ir-0.1.12}/src/onnx_ir/_core.py
@@ -185,6 +185,19 @@ class TensorBase(abc.ABC, _protocols.TensorProtocol, _display.PrettyPrintable):
          self._metadata = _metadata.MetadataStore()
          return self._metadata

+     def tofile(self, file) -> None:
+         """Write the tensor to a binary file.
+
+         This method writes the raw bytes of the tensor to a file-like object.
+         The file-like object must have a ``write`` method that accepts bytes.
+
+         .. versionadded:: 0.1.11
+
+         Args:
+             file: A file-like object with a ``write`` method that accepts bytes.
+         """
+         file.write(self.tobytes())
+
      def display(self, *, page: bool = False) -> None:
          rich = _display.require_rich()

@@ -337,6 +350,38 @@ def _maybe_view_np_array_with_ml_dtypes(
      return array


+ def _supports_fileno(file: Any) -> bool:
+     """Check if the file-like object supports fileno()."""
+     if not hasattr(file, "fileno"):
+         return False
+     try:
+         file.fileno()
+     except Exception:  # pylint: disable=broad-except
+         return False
+     return True
+
+
+ def _create_np_array_for_byte_representation(tensor: Tensor) -> np.ndarray:
+     """Create a numpy array for the byte representation of the tensor.
+
+     This function is used for serializing the tensor to bytes. It handles the
+     special cases for 4-bit data types and endianness.
+     """
+     array = tensor.numpy()
+     if tensor.dtype in {
+         _enums.DataType.INT4,
+         _enums.DataType.UINT4,
+         _enums.DataType.FLOAT4E2M1,
+     }:
+         # Pack the array into int4
+         array = _type_casting.pack_4bitx2(array)
+     else:
+         assert tensor.dtype.itemsize == array.itemsize, "Bug: The itemsize should match"
+     if not _IS_LITTLE_ENDIAN:
+         array = array.astype(array.dtype.newbyteorder("<"))
+     return array
+
+
  class Tensor(TensorBase, _protocols.TensorProtocol, Generic[TArrayCompatible]):  # pylint: disable=too-many-ancestors
      """An immutable concrete tensor.

@@ -509,20 +554,24 @@ class Tensor(TensorBase, _protocols.TensorProtocol, Generic[TArrayCompatible]):
          value is not a numpy array.
          """
          # TODO(justinchuby): Support DLPack
-         array = self.numpy()
-         if self.dtype in {
-             _enums.DataType.INT4,
-             _enums.DataType.UINT4,
-             _enums.DataType.FLOAT4E2M1,
-         }:
-             # Pack the array into int4
-             array = _type_casting.pack_4bitx2(array)
-         else:
-             assert self.dtype.itemsize == array.itemsize, "Bug: The itemsize should match"
-         if not _IS_LITTLE_ENDIAN:
-             array = array.view(array.dtype.newbyteorder("<"))
+         array = _create_np_array_for_byte_representation(self)
          return array.tobytes()

+     def tofile(self, file) -> None:
+         """Write the tensor to a binary file.
+
+         .. versionadded:: 0.1.11
+
+         Args:
+             file: A file-like object with a ``write`` method that accepts bytes, or has an ``fileno()`` method.
+         """
+         if isinstance(self._raw, np.ndarray) and _supports_fileno(file):
+             # This is a duplication of tobytes() for handling special cases
+             array = _create_np_array_for_byte_representation(self)
+             array.tofile(file)
+         else:
+             file.write(self.tobytes())
+

  class ExternalTensor(TensorBase, _protocols.TensorProtocol):  # pylint: disable=too-many-ancestors
      """An immutable concrete tensor with its data store on disk.
@@ -535,7 +584,7 @@ class ExternalTensor(TensorBase, _protocols.TensorProtocol):  # pylint: disable=
      the tensor is recommended if IO overhead and memory usage is a concern.

      To obtain an array, call :meth:`numpy`. To obtain the bytes,
-     call :meth:`tobytes`.
+     call :meth:`tobytes`. To write the data to a file, call :meth:`tofile`.

      The :attr:`location` must be a relative path conforming to the ONNX
      specification. Given the correct :attr:`base_dir`, the :attr:`path` is computed
@@ -590,7 +639,7 @@ class ExternalTensor(TensorBase, _protocols.TensorProtocol):  # pylint: disable=
          length: The length of the data in bytes.
          dtype: The data type of the tensor.
          shape: The shape of the tensor.
-         name: The name of the tensor..
+         name: The name of the tensor.
          doc_string: The documentation string.
          metadata_props: The metadata properties.
          base_dir: The base directory for the external data. It is used to resolve relative paths.
@@ -746,6 +795,18 @@ class ExternalTensor(TensorBase, _protocols.TensorProtocol):  # pylint: disable=
          length = self._length or self.nbytes
          return self.raw[offset : offset + length]

+     def tofile(self, file) -> None:
+         self._check_validity()
+         with open(self.path, "rb") as src:
+             if self._offset is not None:
+                 src.seek(self._offset)
+             bytes_to_copy = self._length or self.nbytes
+             chunk_size = 1024 * 1024  # 1MB
+             while bytes_to_copy > 0:
+                 chunk = src.read(min(chunk_size, bytes_to_copy))
+                 file.write(chunk)
+                 bytes_to_copy -= len(chunk)
+
      def valid(self) -> bool:
          """Check if the tensor is valid.
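
For ``ExternalTensor`` the new ``tofile`` streams straight from the backing file in 1 MB chunks, so consolidating external data never materializes the whole tensor in memory. A hedged sketch; the keyword arguments mirror the parameter names documented above, and the paths, offsets, and lengths are placeholders you would adapt to your own data file:

    import onnx_ir as ir

    weights = ir.ExternalTensor(
        location="weights.bin",   # relative to base_dir, per the ONNX spec
        offset=0,
        length=4 * 1024,
        dtype=ir.DataType.FLOAT,
        shape=ir.Shape([1024]),
        name="w",
        base_dir="model_dir",
    )

    with open("consolidated.bin", "wb") as dst:
        weights.tofile(dst)       # chunked disk-to-disk copy, no full tobytes()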
 
@@ -979,6 +1040,15 @@ class LazyTensor(TensorBase, _protocols.TensorProtocol):  # pylint: disable=too-
          """Return the bytes of the tensor."""
          return self._evaluate().tobytes()

+     def tofile(self, file) -> None:
+         tensor = self._evaluate()
+         if hasattr(tensor, "tofile"):
+             # Some existing implementation of TensorProtocol
+             # may not have tofile() as it was introduced in v0.1.11
+             tensor.tofile(file)
+         else:
+             super().tofile(file)
+

  class PackedTensor(TensorBase, _protocols.TensorProtocol, Generic[TArrayCompatible]):  # pylint: disable=too-many-ancestors
      """A tensor that stores 4bit datatypes in packed format.
@@ -1110,9 +1180,26 @@ class PackedTensor(TensorBase, _protocols.TensorProtocol, Generic[TArrayCompatib
          """
          array = self.numpy_packed()
          if not _IS_LITTLE_ENDIAN:
-             array = array.view(array.dtype.newbyteorder("<"))
+             array = array.astype(array.dtype.newbyteorder("<"))
          return array.tobytes()

+     def tofile(self, file) -> None:
+         """Write the tensor to a binary file.
+
+         .. versionadded:: 0.1.11
+
+         Args:
+             file: A file-like object with a ``write`` method that accepts bytes, or has an ``fileno()`` method.
+         """
+         if _supports_fileno(file):
+             # This is a duplication of tobytes() for handling edge cases
+             array = self.numpy_packed()
+             if not _IS_LITTLE_ENDIAN:
+                 array = array.astype(array.dtype.newbyteorder("<"))
+             array.tofile(file)
+         else:
+             file.write(self.tobytes())
+

  class SymbolicDim(_protocols.SymbolicDimProtocol, _display.PrettyPrintable):
      """Immutable symbolic dimension that can be shared across multiple shapes.
@@ -2214,23 +2301,38 @@ class Value(_protocols.ValueProtocol, _display.PrettyPrintable):

      @name.setter
      def name(self, value: str | None) -> None:
-         if self._const_value is not None:
-             self._const_value.name = value
-         old_name = self._name
-         self._name = value
-         if self.is_initializer():
+         if self._name == value:
+             return
+
+         # First check if renaming is valid. Do not change anything if it is invalid
+         # to prevent the value from being in an inconsistent state.
+         is_initializer = self.is_initializer()
+         if is_initializer:
              if value is None:
                  raise ValueError(
-                     "Initializer value cannot have name set to None. Please pop() the value from initializers first"
+                     "Initializer value cannot have name set to None. Please pop() the value from initializers first to do so."
                  )
-             # Rename the initializer entry in the graph
              graph = self._graph
              assert graph is not None
-             assert old_name is not None
              if value in graph.initializers and graph.initializers[value] is not self:
                  raise ValueError(
-                     f"Cannot rename initializer to '{value}': an initializer with that name already exists."
+                     f"Cannot rename initializer '{self}' to '{value}': an initializer with that name already exists."
                  )
+
+         # Rename the backing constant tensor
+         if self._const_value is not None:
+             self._const_value.name = value
+
+         # Rename self
+         old_name = self._name
+         self._name = value
+
+         if is_initializer:
+             # Rename the initializer entry in the graph
+             assert value is not None, "debug: Should be guarded above"
+             graph = self._graph
+             assert graph is not None
+             assert old_name is not None
              graph.initializers.pop(old_name)
              graph.initializers[value] = self
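
The reworked setter validates the rename before mutating anything, so a rejected rename no longer leaves the value, its ``const_value``, and the graph's initializer table out of sync. A sketch of the behavior, assuming initializers are registered with ``Graph.register_initializer`` and that ``ir.Value`` accepts ``const_value`` in its constructor:

    import numpy as np
    import onnx_ir as ir

    graph = ir.Graph(inputs=[], outputs=[], nodes=[], name="demo")
    a = ir.Value(name="a", const_value=ir.Tensor(np.zeros(2, dtype=np.float32), name="a"))
    b = ir.Value(name="b", const_value=ir.Tensor(np.ones(2, dtype=np.float32), name="b"))
    graph.register_initializer(a)
    graph.register_initializer(b)

    try:
        b.name = "a"  # collides with an existing initializer -> ValueError
    except ValueError:
        pass

    # In 0.1.12 nothing was mutated before the check failed.
    assert b.name == "b" and b.const_value.name == "b"
    assert set(graph.initializers) == {"a", "b"}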
 
@@ -2346,6 +2448,55 @@ class Value(_protocols.ValueProtocol, _display.PrettyPrintable):
          """Whether the value is an initializer of a graph."""
          return self._is_initializer

+     def replace_all_uses_with(
+         self, replacement: Value, /, replace_graph_outputs: bool = False
+     ) -> None:
+         """Replace all uses of this value with another value.
+
+         If the value is an output of a graph and ``replace_graph_outputs`` is ``True``,
+         the graph output will also be replaced. Be careful when a value appears multiple times
+         in the graph outputs - this is invalid. An identity node will need to be added on each
+         duplicated outputs to ensure a valid ONNX graph.
+
+         You may also want to assign the name of this value to the replacement value
+         to maintain the name when it is a graph output.
+
+         To replace usage of a sequence of values with another sequence of values, consider using
+         :func:`onnx_ir.convenience.replace_all_uses_with`.
+
+         .. versionadded:: 0.1.12
+
+         Args:
+             replacement: The value to replace all uses with.
+             replace_graph_outputs: If True, graph outputs that reference this value
+                 will also be updated to reference the replacement.
+
+         Raises:
+             ValueError: When ``replace_graph_outputs`` is False && when the value to
+                 replace is a graph output.
+         """
+         # NOTE: Why we don't replace the value name when the value is an output:
+         # When the replacement value is already an output of the graph, renaming it
+         # to the name of this value will cause name conflicts. It is better to let
+         # the user handle the renaming explicitly and insert identity nodes if needed.
+         if self.is_graph_output():
+             graph = self.graph
+             assert graph is not None
+
+             if not replace_graph_outputs:
+                 raise ValueError(
+                     f"{self!r} is an output of graph {graph.name!r}. "
+                     "Set replace_graph_outputs=True or replace the graph output frist before "
+                     "calling replace_all_uses_with."
+                 )
+
+             for i, output in enumerate(graph.outputs):
+                 if output is self:
+                     graph.outputs[i] = replacement
+
+         for user_node, index in self.uses():
+             user_node.replace_input_with(index, replacement)
+

      @deprecated("Input is deprecated since 0.1.9. Use ir.val(...) instead.")
      def Input(
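
The new ``Value.replace_all_uses_with`` is the single-value counterpart of the convenience helper. A small sketch that bypasses a node; the graph-free construction here is illustrative only:

    import onnx_ir as ir

    x = ir.Value(name="x", type=ir.TensorType(ir.DataType.FLOAT), shape=ir.Shape(["N"]))
    neg = ir.Node("", "Neg", inputs=[x], num_outputs=1)
    relu = ir.Node("", "Relu", inputs=[neg.outputs[0]], num_outputs=1)

    # Make every consumer of Neg's output read ``x`` directly instead.
    neg.outputs[0].replace_all_uses_with(x)
    assert relu.inputs[0] is x
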
{onnx_ir-0.1.10 → onnx_ir-0.1.12}/src/onnx_ir/_protocols.py
@@ -203,6 +203,17 @@ class ValueProtocol(Protocol):
          """Whether this value is an output of a graph."""
          ...

+     def replace_all_uses_with(
+         self, new_value: ValueProtocol | None, replace_graph_outputs: bool = False
+     ) -> None:
+         """Replace all uses of this value with the given new value.
+
+         Args:
+             new_value: The new value to replace this value with.
+             replace_graph_outputs: Whether to replace graph outputs that use this value.
+         """
+         ...
+

  @typing.runtime_checkable
  class NodeProtocol(Protocol):

{onnx_ir-0.1.10 → onnx_ir-0.1.12}/src/onnx_ir/external_data.py
@@ -205,14 +205,20 @@ def _write_external_data(
          )
          current_offset = tensor_info.offset
          assert tensor is not None
-         raw_data = tensor.tobytes()
-         if isinstance(tensor, _core.ExternalTensor):
-             tensor.release()
          # Pad file to required offset if needed
          file_size = data_file.tell()
          if current_offset > file_size:
              data_file.write(b"\0" * (current_offset - file_size))
-         data_file.write(raw_data)
+
+         if hasattr(tensor, "tofile"):
+             # Some existing implementation of TensorProtocol
+             # may not have tofile() as it was introduced in v0.1.11
+             tensor.tofile(data_file)
+         else:
+             raw_data = tensor.tobytes()
+             if isinstance(tensor, _core.ExternalTensor):
+                 tensor.release()
+             data_file.write(raw_data)


  def _create_external_tensor(
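
The external-data writer now prefers ``tofile`` but keeps a ``tobytes`` fallback for third-party ``TensorProtocol`` implementations written before 0.1.11. The same duck-typing pattern, shown in isolation as a hypothetical helper (not the package's function):

    from typing import BinaryIO

    def write_tensor_data(tensor, data_file: BinaryIO) -> None:
        """Hypothetical illustration of the fallback used by the external-data writer."""
        if hasattr(tensor, "tofile"):
            tensor.tofile(data_file)            # streamed; avoids one large bytes object
        else:
            data_file.write(tensor.tobytes())   # tensors predating tofile()
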
{onnx_ir-0.1.10 → onnx_ir-0.1.12}/src/onnx_ir/passes/common/common_subexpression_elimination.py
@@ -150,8 +150,6 @@ def _remove_node_and_replace_values(
          remove_values: The values to replace.
          new_values: The values to replace with.
      """
-     # Reconnect the users of the deleted values to use the new values
-     ir.convenience.replace_all_uses_with(remove_values, new_values)
      # Update graph/function outputs if the node generates output
      if any(remove_value.is_graph_output() for remove_value in remove_values):
          replacement_mapping = dict(zip(remove_values, new_values))
@@ -185,6 +183,9 @@ def _remove_node_and_replace_values(
                  new_value.name = graph_output.name
              graph.outputs[idx] = new_value

+     # Reconnect the users of the deleted values to use the new values
+     ir.convenience.replace_all_uses_with(remove_values, new_values)
+
      graph.remove(remove_node, safe=True)

{onnx_ir-0.1.10 → onnx_ir-0.1.12}/src/onnx_ir/passes/common/constant_manipulation.py
@@ -78,7 +78,7 @@ class LiftConstantsToInitializersPass(ir.passes.InPlacePass):
              assert node.graph is not None
              node.graph.register_initializer(initializer)
              # Replace the constant node with the initializer
-             ir.convenience.replace_all_uses_with(node.outputs[0], initializer)
+             node.outputs[0].replace_all_uses_with(initializer)
              node.graph.remove(node, safe=True)
              count += 1
              logger.debug(

{onnx_ir-0.1.10 → onnx_ir-0.1.12}/src/onnx_ir/passes/common/identity_elimination.py
@@ -15,6 +15,29 @@ import onnx_ir as ir
  logger = logging.getLogger(__name__)


+ def _merge_shapes(shape1: ir.Shape | None, shape2: ir.Shape | None) -> ir.Shape | None:
+     def merge_dims(dim1, dim2):
+         if dim1 == dim2:
+             return dim1
+         if not isinstance(dim1, ir.SymbolicDim):
+             return dim1  # Prefer int value over symbolic dim
+         if not isinstance(dim2, ir.SymbolicDim):
+             return dim2
+         if dim1.value is None:
+             return dim2
+         return dim1
+
+     if shape1 is None:
+         return shape2
+     if shape2 is None:
+         return shape1
+     if len(shape1) != len(shape2):
+         raise ValueError(
+             f"Shapes must have the same rank, got {len(shape1)} and {len(shape2)}."
+         )
+     return ir.Shape([merge_dims(dim1, dim2) for dim1, dim2 in zip(shape1, shape2)])
+
+
  class IdentityEliminationPass(ir.passes.InPlacePass):
      """Pass for eliminating redundant Identity nodes.
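
``_merge_shapes`` is a private helper of this pass; the sketch below just illustrates the dimension preference it encodes (a concrete int beats a symbolic dim, a named symbolic dim beats an anonymous ``None`` one, otherwise the first shape's dim is kept):

    import onnx_ir as ir
    from onnx_ir.passes.common.identity_elimination import _merge_shapes

    merged = _merge_shapes(ir.Shape(["batch", 3, None]), ir.Shape([8, 3, "width"]))
    assert merged == ir.Shape([8, 3, "width"])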
 
@@ -75,22 +98,21 @@ class IdentityEliminationPass(ir.passes.InPlacePass):
          if output_is_graph_output and input_is_graph_input:
              return False

+         # Copy over shape/type if the output has more complete information
+         input_value.shape = _merge_shapes(input_value.shape, output_value.shape)
+         if input_value.type is None:
+             input_value.type = output_value.type
+
          # Case 1 & 2 (merged): Eliminate the identity node
          # Replace all uses of output with input
-         ir.convenience.replace_all_uses_with(output_value, input_value)
+         ir.convenience.replace_all_uses_with(
+             output_value, input_value, replace_graph_outputs=True
+         )

          # If output is a graph output, we need to rename input and update graph outputs
          if output_is_graph_output:
-             # Store the original output name
-             original_output_name = output_value.name
-
              # Update the input value to have the output's name
-             input_value.name = original_output_name
-
-             # Update graph outputs to point to the input value
-             for idx, graph_output in enumerate(graph_like.outputs):
-                 if graph_output is output_value:
-                     graph_like.outputs[idx] = input_value
+             input_value.name = output_value.name

          # Remove the identity node
          graph_like.remove(node, safe=True)

{onnx_ir-0.1.10 → onnx_ir-0.1.12}/src/onnx_ir/passes/common/initializer_deduplication.py
@@ -100,7 +100,7 @@ class DeduplicateInitializersPass(ir.passes.InPlacePass):
              if key in initializers:
                  modified = True
                  initializer_to_keep = initializers[key]  # type: ignore[index]
-                 ir.convenience.replace_all_uses_with(initializer, initializer_to_keep)
+                 initializer.replace_all_uses_with(initializer_to_keep)
                  assert initializer.name is not None
                  graph.initializers.pop(initializer.name)
                  logger.info(
@@ -165,7 +165,7 @@ class DeduplicateHashedInitializersPass(ir.passes.InPlacePass):
                  continue
              modified = True
              initializer_to_keep = initializers[key]  # type: ignore[index]
-             ir.convenience.replace_all_uses_with(initializer, initializer_to_keep)
+             initializer.replace_all_uses_with(initializer_to_keep)
              assert initializer.name is not None
              graph.initializers.pop(initializer.name)
              logger.info(

{onnx_ir-0.1.10 → onnx_ir-0.1.12}/src/onnx_ir/tensor_adapters.py
@@ -168,10 +168,8 @@ class TorchTensor(_core.Tensor):
              return self.numpy()
          return self.numpy().__array__(dtype)

-     def tobytes(self) -> bytes:
-         # Implement tobytes to support native PyTorch types so we can use types like bloat16
-         # Reading from memory directly is also more efficient because
-         # it avoids copying to a NumPy array
+     def _get_cbytes(self):
+         """Get a ctypes byte array pointing to the tensor data."""
          import torch._subclasses.fake_tensor

          with torch._subclasses.fake_tensor.unset_fake_temporarily():  # pylint: disable=protected-access
@@ -185,8 +183,18 @@ class TorchTensor(_core.Tensor):
                      "or save the model without initializers by setting include_initializers=False."
                  )

-             return bytes(
-                 (ctypes.c_ubyte * tensor.element_size() * tensor.numel()).from_address(
-                     tensor.data_ptr()
-                 )
+             # Return the tensor to ensure it is not garbage collected while the ctypes array is in use
+             return tensor, (ctypes.c_ubyte * tensor.element_size() * tensor.numel()).from_address(
+                 tensor.data_ptr()
              )
+
+     def tobytes(self) -> bytes:
+         # Implement tobytes to support native PyTorch types so we can use types like bloat16
+         # Reading from memory directly is also more efficient because
+         # it avoids copying to a NumPy array
+         _, data = self._get_cbytes()
+         return bytes(data)
+
+     def tofile(self, file) -> None:
+         _, data = self._get_cbytes()
+         return file.write(data)
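
With ``_get_cbytes`` factored out, ``TorchTensor.tofile`` writes the ctypes view of the tensor's memory directly, so PyTorch-native dtypes such as bfloat16 are streamed without the extra ``bytes`` copy that ``tobytes`` makes. A sketch, assuming ``TorchTensor`` is constructed from a ``torch.Tensor`` plus an optional name:

    import io
    import torch
    from onnx_ir.tensor_adapters import TorchTensor

    t = TorchTensor(torch.ones(2, 2, dtype=torch.bfloat16), name="w")

    buffer = io.BytesIO()
    t.tofile(buffer)                         # writes the ctypes buffer directly
    assert buffer.getvalue() == t.tobytes()
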
{onnx_ir-0.1.10 → onnx_ir-0.1.12/src/onnx_ir.egg-info}/PKG-INFO
@@ -1,6 +1,6 @@
  Metadata-Version: 2.4
  Name: onnx-ir
- Version: 0.1.10
+ Version: 0.1.12
  Summary: Efficient in-memory representation for ONNX
  Author-email: ONNX Contributors <onnx-technical-discuss@lists.lfaidata.foundation>
  License-Expression: Apache-2.0
@@ -8,11 +8,6 @@ Project-URL: Homepage, https://onnx.ai/ir-py
  Project-URL: Issues, https://github.com/onnx/ir-py/issues
  Project-URL: Repository, https://github.com/onnx/ir-py
  Classifier: Development Status :: 4 - Beta
- Classifier: Programming Language :: Python :: 3.9
- Classifier: Programming Language :: Python :: 3.10
- Classifier: Programming Language :: Python :: 3.11
- Classifier: Programming Language :: Python :: 3.12
- Classifier: Programming Language :: Python :: 3.13
  Requires-Python: >=3.9
  Description-Content-Type: text/markdown
  License-File: LICENSE
@@ -25,7 +20,6 @@ Dynamic: license-file
  # <img src="docs/_static/logo-light.png" alt="ONNX IR" width="250"/>

  [![PyPI - Version](https://img.shields.io/pypi/v/onnx-ir.svg)](https://pypi.org/project/onnx-ir)
- [![PyPI - Python Version](https://img.shields.io/pypi/pyversions/onnx-ir.svg)](https://pypi.org/project/onnx-ir)
  [![Ruff](https://img.shields.io/endpoint?url=https://raw.githubusercontent.com/astral-sh/ruff/main/assets/badge/v2.json)](https://github.com/astral-sh/ruff)
  [![codecov](https://codecov.io/gh/onnx/ir-py/graph/badge.svg?token=SPQ3G9T78Z)](https://codecov.io/gh/onnx/ir-py)
  [![PyPI Downloads](https://static.pepy.tech/badge/onnx-ir/month)](https://pepy.tech/projects/onnx-ir)