onnx-ir 0.1.10__tar.gz → 0.1.12__tar.gz
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {onnx_ir-0.1.10/src/onnx_ir.egg-info → onnx_ir-0.1.12}/PKG-INFO +1 -7
- {onnx_ir-0.1.10 → onnx_ir-0.1.12}/README.md +0 -1
- {onnx_ir-0.1.10 → onnx_ir-0.1.12}/pyproject.toml +0 -5
- {onnx_ir-0.1.10 → onnx_ir-0.1.12}/src/onnx_ir/__init__.py +1 -1
- {onnx_ir-0.1.10 → onnx_ir-0.1.12}/src/onnx_ir/_convenience/__init__.py +25 -14
- {onnx_ir-0.1.10 → onnx_ir-0.1.12}/src/onnx_ir/_core.py +175 -24
- {onnx_ir-0.1.10 → onnx_ir-0.1.12}/src/onnx_ir/_protocols.py +11 -0
- {onnx_ir-0.1.10 → onnx_ir-0.1.12}/src/onnx_ir/external_data.py +10 -4
- {onnx_ir-0.1.10 → onnx_ir-0.1.12}/src/onnx_ir/passes/common/common_subexpression_elimination.py +3 -2
- {onnx_ir-0.1.10 → onnx_ir-0.1.12}/src/onnx_ir/passes/common/constant_manipulation.py +1 -1
- {onnx_ir-0.1.10 → onnx_ir-0.1.12}/src/onnx_ir/passes/common/identity_elimination.py +32 -10
- {onnx_ir-0.1.10 → onnx_ir-0.1.12}/src/onnx_ir/passes/common/initializer_deduplication.py +2 -2
- {onnx_ir-0.1.10 → onnx_ir-0.1.12}/src/onnx_ir/tensor_adapters.py +16 -8
- {onnx_ir-0.1.10 → onnx_ir-0.1.12/src/onnx_ir.egg-info}/PKG-INFO +1 -7
- {onnx_ir-0.1.10 → onnx_ir-0.1.12}/LICENSE +0 -0
- {onnx_ir-0.1.10 → onnx_ir-0.1.12}/MANIFEST.in +0 -0
- {onnx_ir-0.1.10 → onnx_ir-0.1.12}/setup.cfg +0 -0
- {onnx_ir-0.1.10 → onnx_ir-0.1.12}/src/onnx_ir/_convenience/_constructors.py +0 -0
- {onnx_ir-0.1.10 → onnx_ir-0.1.12}/src/onnx_ir/_display.py +0 -0
- {onnx_ir-0.1.10 → onnx_ir-0.1.12}/src/onnx_ir/_enums.py +0 -0
- {onnx_ir-0.1.10 → onnx_ir-0.1.12}/src/onnx_ir/_graph_comparison.py +0 -0
- {onnx_ir-0.1.10 → onnx_ir-0.1.12}/src/onnx_ir/_graph_containers.py +0 -0
- {onnx_ir-0.1.10 → onnx_ir-0.1.12}/src/onnx_ir/_io.py +0 -0
- {onnx_ir-0.1.10 → onnx_ir-0.1.12}/src/onnx_ir/_linked_list.py +0 -0
- {onnx_ir-0.1.10 → onnx_ir-0.1.12}/src/onnx_ir/_metadata.py +0 -0
- {onnx_ir-0.1.10 → onnx_ir-0.1.12}/src/onnx_ir/_name_authority.py +0 -0
- {onnx_ir-0.1.10 → onnx_ir-0.1.12}/src/onnx_ir/_polyfill.py +0 -0
- {onnx_ir-0.1.10 → onnx_ir-0.1.12}/src/onnx_ir/_tape.py +0 -0
- {onnx_ir-0.1.10 → onnx_ir-0.1.12}/src/onnx_ir/_thirdparty/asciichartpy.py +0 -0
- {onnx_ir-0.1.10 → onnx_ir-0.1.12}/src/onnx_ir/_type_casting.py +0 -0
- {onnx_ir-0.1.10 → onnx_ir-0.1.12}/src/onnx_ir/_version_utils.py +0 -0
- {onnx_ir-0.1.10 → onnx_ir-0.1.12}/src/onnx_ir/convenience.py +0 -0
- {onnx_ir-0.1.10 → onnx_ir-0.1.12}/src/onnx_ir/passes/__init__.py +0 -0
- {onnx_ir-0.1.10 → onnx_ir-0.1.12}/src/onnx_ir/passes/_pass_infra.py +0 -0
- {onnx_ir-0.1.10 → onnx_ir-0.1.12}/src/onnx_ir/passes/common/__init__.py +0 -0
- {onnx_ir-0.1.10 → onnx_ir-0.1.12}/src/onnx_ir/passes/common/_c_api_utils.py +0 -0
- {onnx_ir-0.1.10 → onnx_ir-0.1.12}/src/onnx_ir/passes/common/clear_metadata_and_docstring.py +0 -0
- {onnx_ir-0.1.10 → onnx_ir-0.1.12}/src/onnx_ir/passes/common/inliner.py +0 -0
- {onnx_ir-0.1.10 → onnx_ir-0.1.12}/src/onnx_ir/passes/common/naming.py +0 -0
- {onnx_ir-0.1.10 → onnx_ir-0.1.12}/src/onnx_ir/passes/common/onnx_checker.py +0 -0
- {onnx_ir-0.1.10 → onnx_ir-0.1.12}/src/onnx_ir/passes/common/shape_inference.py +0 -0
- {onnx_ir-0.1.10 → onnx_ir-0.1.12}/src/onnx_ir/passes/common/topological_sort.py +0 -0
- {onnx_ir-0.1.10 → onnx_ir-0.1.12}/src/onnx_ir/passes/common/unused_removal.py +0 -0
- {onnx_ir-0.1.10 → onnx_ir-0.1.12}/src/onnx_ir/py.typed +0 -0
- {onnx_ir-0.1.10 → onnx_ir-0.1.12}/src/onnx_ir/serde.py +0 -0
- {onnx_ir-0.1.10 → onnx_ir-0.1.12}/src/onnx_ir/tape.py +0 -0
- {onnx_ir-0.1.10 → onnx_ir-0.1.12}/src/onnx_ir/testing.py +0 -0
- {onnx_ir-0.1.10 → onnx_ir-0.1.12}/src/onnx_ir/traversal.py +0 -0
- {onnx_ir-0.1.10 → onnx_ir-0.1.12}/src/onnx_ir.egg-info/SOURCES.txt +0 -0
- {onnx_ir-0.1.10 → onnx_ir-0.1.12}/src/onnx_ir.egg-info/dependency_links.txt +0 -0
- {onnx_ir-0.1.10 → onnx_ir-0.1.12}/src/onnx_ir.egg-info/requires.txt +0 -0
- {onnx_ir-0.1.10 → onnx_ir-0.1.12}/src/onnx_ir.egg-info/top_level.txt +0 -0
{onnx_ir-0.1.10/src/onnx_ir.egg-info → onnx_ir-0.1.12}/PKG-INFO
RENAMED
@@ -1,6 +1,6 @@
 Metadata-Version: 2.4
 Name: onnx-ir
-Version: 0.1.10
+Version: 0.1.12
 Summary: Efficient in-memory representation for ONNX
 Author-email: ONNX Contributors <onnx-technical-discuss@lists.lfaidata.foundation>
 License-Expression: Apache-2.0
@@ -8,11 +8,6 @@ Project-URL: Homepage, https://onnx.ai/ir-py
 Project-URL: Issues, https://github.com/onnx/ir-py/issues
 Project-URL: Repository, https://github.com/onnx/ir-py
 Classifier: Development Status :: 4 - Beta
-Classifier: Programming Language :: Python :: 3.9
-Classifier: Programming Language :: Python :: 3.10
-Classifier: Programming Language :: Python :: 3.11
-Classifier: Programming Language :: Python :: 3.12
-Classifier: Programming Language :: Python :: 3.13
 Requires-Python: >=3.9
 Description-Content-Type: text/markdown
 License-File: LICENSE
@@ -25,7 +20,6 @@ Dynamic: license-file
 # <img src="docs/_static/logo-light.png" alt="ONNX IR" width="250"/>

 [](https://pypi.org/project/onnx-ir)
-[](https://pypi.org/project/onnx-ir)
 [](https://github.com/astral-sh/ruff)
 [](https://codecov.io/gh/onnx/ir-py)
 [](https://pepy.tech/projects/onnx-ir)
{onnx_ir-0.1.10 → onnx_ir-0.1.12}/README.md
RENAMED
@@ -1,7 +1,6 @@
 # <img src="docs/_static/logo-light.png" alt="ONNX IR" width="250"/>

 [](https://pypi.org/project/onnx-ir)
-[](https://pypi.org/project/onnx-ir)
 [](https://github.com/astral-sh/ruff)
 [](https://codecov.io/gh/onnx/ir-py)
 [](https://pepy.tech/projects/onnx-ir)
{onnx_ir-0.1.10 → onnx_ir-0.1.12}/pyproject.toml
RENAMED
@@ -15,11 +15,6 @@ license = "Apache-2.0"
 license-files = ["LICEN[CS]E*"]
 classifiers = [
     "Development Status :: 4 - Beta",
-    "Programming Language :: Python :: 3.9",
-    "Programming Language :: Python :: 3.10",
-    "Programming Language :: Python :: 3.11",
-    "Programming Language :: Python :: 3.12",
-    "Programming Language :: Python :: 3.13",
 ]
 dependencies = ["numpy", "onnx>=1.16", "typing_extensions>=4.10", "ml_dtypes>=0.5.0"]

{onnx_ir-0.1.10 → onnx_ir-0.1.12}/src/onnx_ir/_convenience/__init__.py
RENAMED
@@ -280,6 +280,7 @@ def convert_attributes(
 def replace_all_uses_with(
     values: _protocols.ValueProtocol | Sequence[_protocols.ValueProtocol],
     replacements: _protocols.ValueProtocol | Sequence[_protocols.ValueProtocol],
+    replace_graph_outputs: bool = False,
 ) -> None:
     """Replace all uses of the given values with the replacements.

@@ -318,9 +319,22 @@ def replace_all_uses_with(
     replaced are part of the graph outputs. Be sure to remove the old nodes
     from the graph using ``graph.remove()`` if they are no longer needed.

+    .. versionadded:: 0.1.12
+        The ``replace_graph_outputs`` parameter is added.
+
+    .. versionadded:: 0.1.12
+        ValueError is raised when ``replace_graph_outputs`` is False && when the value to
+        replace is a graph output.
+
     Args:
         values: The value or values to be replaced.
         replacements: The new value or values to use as inputs.
+        replace_graph_outputs: If True, graph outputs that reference the values
+            being replaced will also be updated to reference the replacements.
+
+    Raises:
+        ValueError: When ``replace_graph_outputs`` is False && when the value to
+            replace is a graph output.
     """
     if not isinstance(values, Sequence):
         values = (values,)
@@ -329,8 +343,7 @@ def replace_all_uses_with(
     if len(values) != len(replacements):
         raise ValueError("The number of values and replacements must match.")
     for value, replacement in zip(values, replacements):
-
-        user_node.replace_input_with(index, replacement)
+        value.replace_all_uses_with(replacement, replace_graph_outputs=replace_graph_outputs)


 def create_value_mapping(graph: _core.Graph) -> dict[str, _core.Value]:
@@ -397,20 +410,18 @@ def replace_nodes_and_values(
     """
     for old_value, new_value in zip(old_values, new_values):
         # Propagate relevant info from old value to new value
-        # TODO(Rama): Perhaps this should be a separate utility function.
-
-        new_value.
-        new_value.
-
-
+        # TODO(Rama): Perhaps this should be a separate utility function.
+        new_value.type = old_value.type if old_value.type is not None else new_value.type
+        new_value.shape = old_value.shape if old_value.shape is not None else new_value.shape
+        new_value.const_value = (
+            old_value.const_value
+            if old_value.const_value is not None
+            else new_value.const_value
+        )
+        new_value.name = old_value.name if old_value.name is not None else new_value.name

     # Reconnect the users of the deleted values to use the new values
-    replace_all_uses_with(old_values, new_values)
-    # Update graph/function outputs if the node generates output
-    replacement_mapping = dict(zip(old_values, new_values))
-    for idx, graph_or_function_output in enumerate(graph_or_function.outputs):
-        if graph_or_function_output in replacement_mapping:
-            graph_or_function.outputs[idx] = replacement_mapping[graph_or_function_output]
+    replace_all_uses_with(old_values, new_values, replace_graph_outputs=True)

     # insert new nodes after the index node
     graph_or_function.insert_after(insertion_point, new_nodes)
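A hedged usage sketch of the updated convenience function; the ir.val and ir.Node constructor signatures below are assumed from the library's public API, and all names are illustrative:

import onnx_ir as ir

# Two illustrative values feeding an Add node.
a, b = ir.val("a"), ir.val("b")
a_new, b_new = ir.val("a_new"), ir.val("b_new")
add = ir.Node("", "Add", inputs=[a, b])

# Reroute every use of a/b to a_new/b_new; with replace_graph_outputs=True
# (added in 0.1.12) graph outputs referencing a/b would be rewritten as well.
ir.convenience.replace_all_uses_with([a, b], [a_new, b_new], replace_graph_outputs=True)
assert list(add.inputs) == [a_new, b_new]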
{onnx_ir-0.1.10 → onnx_ir-0.1.12}/src/onnx_ir/_core.py
RENAMED
@@ -185,6 +185,19 @@ class TensorBase(abc.ABC, _protocols.TensorProtocol, _display.PrettyPrintable):
         self._metadata = _metadata.MetadataStore()
         return self._metadata

+    def tofile(self, file) -> None:
+        """Write the tensor to a binary file.
+
+        This method writes the raw bytes of the tensor to a file-like object.
+        The file-like object must have a ``write`` method that accepts bytes.
+
+        .. versionadded:: 0.1.11
+
+        Args:
+            file: A file-like object with a ``write`` method that accepts bytes.
+        """
+        file.write(self.tobytes())
+
     def display(self, *, page: bool = False) -> None:
         rich = _display.require_rich()

@@ -337,6 +350,38 @@ def _maybe_view_np_array_with_ml_dtypes(
     return array


+def _supports_fileno(file: Any) -> bool:
+    """Check if the file-like object supports fileno()."""
+    if not hasattr(file, "fileno"):
+        return False
+    try:
+        file.fileno()
+    except Exception:  # pylint: disable=broad-except
+        return False
+    return True
+
+
+def _create_np_array_for_byte_representation(tensor: Tensor) -> np.ndarray:
+    """Create a numpy array for the byte representation of the tensor.
+
+    This function is used for serializing the tensor to bytes. It handles the
+    special cases for 4-bit data types and endianness.
+    """
+    array = tensor.numpy()
+    if tensor.dtype in {
+        _enums.DataType.INT4,
+        _enums.DataType.UINT4,
+        _enums.DataType.FLOAT4E2M1,
+    }:
+        # Pack the array into int4
+        array = _type_casting.pack_4bitx2(array)
+    else:
+        assert tensor.dtype.itemsize == array.itemsize, "Bug: The itemsize should match"
+    if not _IS_LITTLE_ENDIAN:
+        array = array.astype(array.dtype.newbyteorder("<"))
+    return array
+
+
 class Tensor(TensorBase, _protocols.TensorProtocol, Generic[TArrayCompatible]):  # pylint: disable=too-many-ancestors
     """An immutable concrete tensor.

@@ -509,20 +554,24 @@ class Tensor(TensorBase, _protocols.TensorProtocol, Generic[TArrayCompatible]):
         value is not a numpy array.
         """
         # TODO(justinchuby): Support DLPack
-        array = self.numpy()
-        if self.dtype in {
-            _enums.DataType.INT4,
-            _enums.DataType.UINT4,
-            _enums.DataType.FLOAT4E2M1,
-        }:
-            # Pack the array into int4
-            array = _type_casting.pack_4bitx2(array)
-        else:
-            assert self.dtype.itemsize == array.itemsize, "Bug: The itemsize should match"
-        if not _IS_LITTLE_ENDIAN:
-            array = array.view(array.dtype.newbyteorder("<"))
+        array = _create_np_array_for_byte_representation(self)
         return array.tobytes()

+    def tofile(self, file) -> None:
+        """Write the tensor to a binary file.
+
+        .. versionadded:: 0.1.11
+
+        Args:
+            file: A file-like object with a ``write`` method that accepts bytes, or has an ``fileno()`` method.
+        """
+        if isinstance(self._raw, np.ndarray) and _supports_fileno(file):
+            # This is a duplication of tobytes() for handling special cases
+            array = _create_np_array_for_byte_representation(self)
+            array.tofile(file)
+        else:
+            file.write(self.tobytes())
+

 class ExternalTensor(TensorBase, _protocols.TensorProtocol):  # pylint: disable=too-many-ancestors
     """An immutable concrete tensor with its data store on disk.
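A hedged sketch of the new Tensor.tofile API shown above; the ir.Tensor constructor keywords are assumed, and io.BytesIO has no usable fileno(), so this exercises the plain write() fallback:

import io

import numpy as np
import onnx_ir as ir

# Serialize a small tensor into an in-memory buffer.
tensor = ir.Tensor(np.arange(6, dtype=np.float32), name="weight")
buffer = io.BytesIO()
tensor.tofile(buffer)
assert buffer.getvalue() == tensor.tobytes()

With a real on-disk file object, which does support fileno(), the same call takes the numpy.ndarray.tofile fast path instead.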
@@ -535,7 +584,7 @@ class ExternalTensor(TensorBase, _protocols.TensorProtocol):  # pylint: disable=too-many-ancestors
     the tensor is recommended if IO overhead and memory usage is a concern.

     To obtain an array, call :meth:`numpy`. To obtain the bytes,
-    call :meth:`tobytes`.
+    call :meth:`tobytes`. To write the data to a file, call :meth:`tofile`.

     The :attr:`location` must be a relative path conforming to the ONNX
     specification. Given the correct :attr:`base_dir`, the :attr:`path` is computed
@@ -590,7 +639,7 @@ class ExternalTensor(TensorBase, _protocols.TensorProtocol):  # pylint: disable=too-many-ancestors
         length: The length of the data in bytes.
         dtype: The data type of the tensor.
         shape: The shape of the tensor.
-        name: The name of the tensor
+        name: The name of the tensor.
         doc_string: The documentation string.
         metadata_props: The metadata properties.
         base_dir: The base directory for the external data. It is used to resolve relative paths.
@@ -746,6 +795,18 @@ class ExternalTensor(TensorBase, _protocols.TensorProtocol):  # pylint: disable=too-many-ancestors
         length = self._length or self.nbytes
         return self.raw[offset : offset + length]

+    def tofile(self, file) -> None:
+        self._check_validity()
+        with open(self.path, "rb") as src:
+            if self._offset is not None:
+                src.seek(self._offset)
+            bytes_to_copy = self._length or self.nbytes
+            chunk_size = 1024 * 1024  # 1MB
+            while bytes_to_copy > 0:
+                chunk = src.read(min(chunk_size, bytes_to_copy))
+                file.write(chunk)
+                bytes_to_copy -= len(chunk)
+
     def valid(self) -> bool:
         """Check if the tensor is valid.

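The ExternalTensor.tofile implementation above streams the on-disk payload in 1 MB chunks instead of materializing it. A standalone restatement of that pattern, with all names illustrative:

def copy_range(src_path: str, dst, offset: int, length: int, chunk_size: int = 1024 * 1024) -> None:
    # Stream `length` bytes starting at `offset` from src_path into dst.
    with open(src_path, "rb") as src:
        src.seek(offset)
        remaining = length
        while remaining > 0:
            chunk = src.read(min(chunk_size, remaining))
            if not chunk:  # source shorter than expected; stop rather than loop forever
                break
            dst.write(chunk)
            remaining -= len(chunk)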
@@ -979,6 +1040,15 @@ class LazyTensor(TensorBase, _protocols.TensorProtocol):  # pylint: disable=too-many-ancestors
         """Return the bytes of the tensor."""
         return self._evaluate().tobytes()

+    def tofile(self, file) -> None:
+        tensor = self._evaluate()
+        if hasattr(tensor, "tofile"):
+            # Some existing implementation of TensorProtocol
+            # may not have tofile() as it was introduced in v0.1.11
+            tensor.tofile(file)
+        else:
+            super().tofile(file)
+

 class PackedTensor(TensorBase, _protocols.TensorProtocol, Generic[TArrayCompatible]):  # pylint: disable=too-many-ancestors
     """A tensor that stores 4bit datatypes in packed format.
@@ -1110,9 +1180,26 @@ class PackedTensor(TensorBase, _protocols.TensorProtocol, Generic[TArrayCompatible]):
         """
         array = self.numpy_packed()
         if not _IS_LITTLE_ENDIAN:
-            array = array.
+            array = array.astype(array.dtype.newbyteorder("<"))
         return array.tobytes()

+    def tofile(self, file) -> None:
+        """Write the tensor to a binary file.
+
+        .. versionadded:: 0.1.11
+
+        Args:
+            file: A file-like object with a ``write`` method that accepts bytes, or has an ``fileno()`` method.
+        """
+        if _supports_fileno(file):
+            # This is a duplication of tobytes() for handling edge cases
+            array = self.numpy_packed()
+            if not _IS_LITTLE_ENDIAN:
+                array = array.astype(array.dtype.newbyteorder("<"))
+            array.tofile(file)
+        else:
+            file.write(self.tobytes())
+

 class SymbolicDim(_protocols.SymbolicDimProtocol, _display.PrettyPrintable):
     """Immutable symbolic dimension that can be shared across multiple shapes.
@@ -2214,23 +2301,38 @@ class Value(_protocols.ValueProtocol, _display.PrettyPrintable):

     @name.setter
     def name(self, value: str | None) -> None:
-        if self.
-
-
-
-
+        if self._name == value:
+            return
+
+        # First check if renaming is valid. Do not change anything if it is invalid
+        # to prevent the value from being in an inconsistent state.
+        is_initializer = self.is_initializer()
+        if is_initializer:
             if value is None:
                 raise ValueError(
-                    "Initializer value cannot have name set to None. Please pop() the value from initializers first"
+                    "Initializer value cannot have name set to None. Please pop() the value from initializers first to do so."
                 )
-            # Rename the initializer entry in the graph
             graph = self._graph
             assert graph is not None
-            assert old_name is not None
             if value in graph.initializers and graph.initializers[value] is not self:
                 raise ValueError(
-                    f"Cannot rename initializer to '{value}': an initializer with that name already exists."
+                    f"Cannot rename initializer '{self}' to '{value}': an initializer with that name already exists."
                 )
+
+        # Rename the backing constant tensor
+        if self._const_value is not None:
+            self._const_value.name = value
+
+        # Rename self
+        old_name = self._name
+        self._name = value
+
+        if is_initializer:
+            # Rename the initializer entry in the graph
+            assert value is not None, "debug: Should be guarded above"
+            graph = self._graph
+            assert graph is not None
+            assert old_name is not None
             graph.initializers.pop(old_name)
             graph.initializers[value] = self

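A hedged sketch of the renaming behavior above; the ir.val, ir.Tensor, and ir.Graph constructor signatures are assumed from the library's public API:

import numpy as np
import onnx_ir as ir

w = ir.val("w")
w.const_value = ir.Tensor(np.zeros(2, dtype=np.float32), name="w")
graph = ir.Graph(inputs=[], outputs=[], nodes=[], initializers=[w])

w.name = "w_renamed"  # also renames the initializer entry and the backing tensor
assert "w_renamed" in graph.initializers
assert w.const_value.name == "w_renamed"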
@@ -2346,6 +2448,55 @@ class Value(_protocols.ValueProtocol, _display.PrettyPrintable):
         """Whether the value is an initializer of a graph."""
         return self._is_initializer

+    def replace_all_uses_with(
+        self, replacement: Value, /, replace_graph_outputs: bool = False
+    ) -> None:
+        """Replace all uses of this value with another value.
+
+        If the value is an output of a graph and ``replace_graph_outputs`` is ``True``,
+        the graph output will also be replaced. Be careful when a value appears multiple times
+        in the graph outputs - this is invalid. An identity node will need to be added on each
+        duplicated outputs to ensure a valid ONNX graph.
+
+        You may also want to assign the name of this value to the replacement value
+        to maintain the name when it is a graph output.
+
+        To replace usage of a sequence of values with another sequence of values, consider using
+        :func:`onnx_ir.convenience.replace_all_uses_with`.
+
+        .. versionadded:: 0.1.12
+
+        Args:
+            replacement: The value to replace all uses with.
+            replace_graph_outputs: If True, graph outputs that reference this value
+                will also be updated to reference the replacement.
+
+        Raises:
+            ValueError: When ``replace_graph_outputs`` is False && when the value to
+                replace is a graph output.
+        """
+        # NOTE: Why we don't replace the value name when the value is an output:
+        # When the replacement value is already an output of the graph, renaming it
+        # to the name of this value will cause name conflicts. It is better to let
+        # the user handle the renaming explicitly and insert identity nodes if needed.
+        if self.is_graph_output():
+            graph = self.graph
+            assert graph is not None
+
+            if not replace_graph_outputs:
+                raise ValueError(
+                    f"{self!r} is an output of graph {graph.name!r}. "
+                    "Set replace_graph_outputs=True or replace the graph output frist before "
+                    "calling replace_all_uses_with."
+                )
+
+            for i, output in enumerate(graph.outputs):
+                if output is self:
+                    graph.outputs[i] = replacement
+
+        for user_node, index in self.uses():
+            user_node.replace_input_with(index, replacement)
+

 @deprecated("Input is deprecated since 0.1.9. Use ir.val(...) instead.")
 def Input(
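A hedged sketch of the new Value-level API; ir.val and ir.Node signatures are assumed, and the names are illustrative:

import onnx_ir as ir

old = ir.val("old_out")
new = ir.val("new_out")
relu = ir.Node("", "Relu", inputs=[old])

# `old` is not a graph output here, so no flag is needed; if it were, the call
# would raise ValueError unless replace_graph_outputs=True is passed.
old.replace_all_uses_with(new)
assert relu.inputs[0] is new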
{onnx_ir-0.1.10 → onnx_ir-0.1.12}/src/onnx_ir/_protocols.py
RENAMED
@@ -203,6 +203,17 @@ class ValueProtocol(Protocol):
         """Whether this value is an output of a graph."""
         ...

+    def replace_all_uses_with(
+        self, new_value: ValueProtocol | None, replace_graph_outputs: bool = False
+    ) -> None:
+        """Replace all uses of this value with the given new value.
+
+        Args:
+            new_value: The new value to replace this value with.
+            replace_graph_outputs: Whether to replace graph outputs that use this value.
+        """
+        ...
+

 @typing.runtime_checkable
 class NodeProtocol(Protocol):
{onnx_ir-0.1.10 → onnx_ir-0.1.12}/src/onnx_ir/external_data.py
RENAMED
@@ -205,14 +205,20 @@ def _write_external_data(
         )
         current_offset = tensor_info.offset
         assert tensor is not None
-        raw_data = tensor.tobytes()
-        if isinstance(tensor, _core.ExternalTensor):
-            tensor.release()
         # Pad file to required offset if needed
         file_size = data_file.tell()
         if current_offset > file_size:
             data_file.write(b"\0" * (current_offset - file_size))
-        data_file.write(raw_data)
+
+        if hasattr(tensor, "tofile"):
+            # Some existing implementation of TensorProtocol
+            # may not have tofile() as it was introduced in v0.1.11
+            tensor.tofile(data_file)
+        else:
+            raw_data = tensor.tobytes()
+            if isinstance(tensor, _core.ExternalTensor):
+                tensor.release()
+            data_file.write(raw_data)


 def _create_external_tensor(
{onnx_ir-0.1.10 → onnx_ir-0.1.12}/src/onnx_ir/passes/common/common_subexpression_elimination.py
RENAMED
@@ -150,8 +150,6 @@ def _remove_node_and_replace_values(
         remove_values: The values to replace.
         new_values: The values to replace with.
     """
-    # Reconnect the users of the deleted values to use the new values
-    ir.convenience.replace_all_uses_with(remove_values, new_values)
     # Update graph/function outputs if the node generates output
     if any(remove_value.is_graph_output() for remove_value in remove_values):
         replacement_mapping = dict(zip(remove_values, new_values))
@@ -185,6 +183,9 @@ def _remove_node_and_replace_values(
             new_value.name = graph_output.name
             graph.outputs[idx] = new_value

+    # Reconnect the users of the deleted values to use the new values
+    ir.convenience.replace_all_uses_with(remove_values, new_values)
+
     graph.remove(remove_node, safe=True)

{onnx_ir-0.1.10 → onnx_ir-0.1.12}/src/onnx_ir/passes/common/constant_manipulation.py
RENAMED
@@ -78,7 +78,7 @@ class LiftConstantsToInitializersPass(ir.passes.InPlacePass):
                 assert node.graph is not None
                 node.graph.register_initializer(initializer)
                 # Replace the constant node with the initializer
-
+                node.outputs[0].replace_all_uses_with(initializer)
                 node.graph.remove(node, safe=True)
                 count += 1
                 logger.debug(
{onnx_ir-0.1.10 → onnx_ir-0.1.12}/src/onnx_ir/passes/common/identity_elimination.py
RENAMED
@@ -15,6 +15,29 @@ import onnx_ir as ir
 logger = logging.getLogger(__name__)


+def _merge_shapes(shape1: ir.Shape | None, shape2: ir.Shape | None) -> ir.Shape | None:
+    def merge_dims(dim1, dim2):
+        if dim1 == dim2:
+            return dim1
+        if not isinstance(dim1, ir.SymbolicDim):
+            return dim1  # Prefer int value over symbolic dim
+        if not isinstance(dim2, ir.SymbolicDim):
+            return dim2
+        if dim1.value is None:
+            return dim2
+        return dim1
+
+    if shape1 is None:
+        return shape2
+    if shape2 is None:
+        return shape1
+    if len(shape1) != len(shape2):
+        raise ValueError(
+            f"Shapes must have the same rank, got {len(shape1)} and {len(shape2)}."
+        )
+    return ir.Shape([merge_dims(dim1, dim2) for dim1, dim2 in zip(shape1, shape2)])
+
+
 class IdentityEliminationPass(ir.passes.InPlacePass):
     """Pass for eliminating redundant Identity nodes.

@@ -75,22 +98,21 @@ class IdentityEliminationPass(ir.passes.InPlacePass):
         if output_is_graph_output and input_is_graph_input:
             return False

+        # Copy over shape/type if the output has more complete information
+        input_value.shape = _merge_shapes(input_value.shape, output_value.shape)
+        if input_value.type is None:
+            input_value.type = output_value.type
+
         # Case 1 & 2 (merged): Eliminate the identity node
         # Replace all uses of output with input
-        ir.convenience.replace_all_uses_with(output_value, input_value)
+        ir.convenience.replace_all_uses_with(
+            output_value, input_value, replace_graph_outputs=True
+        )

         # If output is a graph output, we need to rename input and update graph outputs
         if output_is_graph_output:
-            # Store the original output name
-            original_output_name = output_value.name
-
             # Update the input value to have the output's name
-            input_value.name = original_output_name
-
-            # Update graph outputs to point to the input value
-            for idx, graph_output in enumerate(graph_like.outputs):
-                if graph_output is output_value:
-                    graph_like.outputs[idx] = input_value
+            input_value.name = output_value.name

         # Remove the identity node
         graph_like.remove(node, safe=True)
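A hedged usage sketch of the pass; the PassResult attribute and ir.load call are taken from the library's public API, and "model.onnx" is an illustrative path:

import onnx_ir as ir
from onnx_ir.passes.common import identity_elimination

model = ir.load("model.onnx")
# Removes redundant Identity nodes in place, merging shape/type information
# between each Identity's input and output as implemented in _merge_shapes.
result = identity_elimination.IdentityEliminationPass()(model)
print("modified:", result.modified)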
{onnx_ir-0.1.10 → onnx_ir-0.1.12}/src/onnx_ir/passes/common/initializer_deduplication.py
RENAMED
@@ -100,7 +100,7 @@ class DeduplicateInitializersPass(ir.passes.InPlacePass):
                 if key in initializers:
                     modified = True
                     initializer_to_keep = initializers[key]  # type: ignore[index]
-
+                    initializer.replace_all_uses_with(initializer_to_keep)
                     assert initializer.name is not None
                     graph.initializers.pop(initializer.name)
                     logger.info(
@@ -165,7 +165,7 @@ class DeduplicateHashedInitializersPass(ir.passes.InPlacePass):
                 continue
             modified = True
             initializer_to_keep = initializers[key]  # type: ignore[index]
-
+            initializer.replace_all_uses_with(initializer_to_keep)
             assert initializer.name is not None
             graph.initializers.pop(initializer.name)
             logger.info(
{onnx_ir-0.1.10 → onnx_ir-0.1.12}/src/onnx_ir/tensor_adapters.py
RENAMED
@@ -168,10 +168,8 @@ class TorchTensor(_core.Tensor):
             return self.numpy()
         return self.numpy().__array__(dtype)

-    def tobytes(self) -> bytes:
-        # Implement tobytes to support native PyTorch types so we can use types like bloat16
-        # Reading from memory directly is also more efficient because
-        # it avoids copying to a NumPy array
+    def _get_cbytes(self):
+        """Get a ctypes byte array pointing to the tensor data."""
         import torch._subclasses.fake_tensor

         with torch._subclasses.fake_tensor.unset_fake_temporarily():  # pylint: disable=protected-access
@@ -185,8 +183,18 @@ class TorchTensor(_core.Tensor):
                 "or save the model without initializers by setting include_initializers=False."
             )

-
-
-
-                )
+            # Return the tensor to ensure it is not garbage collected while the ctypes array is in use
+            return tensor, (ctypes.c_ubyte * tensor.element_size() * tensor.numel()).from_address(
+                tensor.data_ptr()
             )
+
+    def tobytes(self) -> bytes:
+        # Implement tobytes to support native PyTorch types so we can use types like bloat16
+        # Reading from memory directly is also more efficient because
+        # it avoids copying to a NumPy array
+        _, data = self._get_cbytes()
+        return bytes(data)
+
+    def tofile(self, file) -> None:
+        _, data = self._get_cbytes()
+        return file.write(data)
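A hedged restatement of the zero-copy idea behind _get_cbytes, using NumPy instead of torch so it runs without PyTorch installed; "weights.bin" is an illustrative path:

import ctypes

import numpy as np

arr = np.arange(4, dtype=np.float32)
# Wrap the buffer's address in a ctypes byte array and hand it straight to
# write(), skipping an intermediate bytes copy.
view = (ctypes.c_ubyte * arr.nbytes).from_address(arr.ctypes.data)
with open("weights.bin", "wb") as f:
    f.write(view)  # writes the 16 raw bytes backing `arr`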
{onnx_ir-0.1.10 → onnx_ir-0.1.12/src/onnx_ir.egg-info}/PKG-INFO
RENAMED
@@ -1,6 +1,6 @@
 Metadata-Version: 2.4
 Name: onnx-ir
-Version: 0.1.10
+Version: 0.1.12
 Summary: Efficient in-memory representation for ONNX
 Author-email: ONNX Contributors <onnx-technical-discuss@lists.lfaidata.foundation>
 License-Expression: Apache-2.0
@@ -8,11 +8,6 @@ Project-URL: Homepage, https://onnx.ai/ir-py
 Project-URL: Issues, https://github.com/onnx/ir-py/issues
 Project-URL: Repository, https://github.com/onnx/ir-py
 Classifier: Development Status :: 4 - Beta
-Classifier: Programming Language :: Python :: 3.9
-Classifier: Programming Language :: Python :: 3.10
-Classifier: Programming Language :: Python :: 3.11
-Classifier: Programming Language :: Python :: 3.12
-Classifier: Programming Language :: Python :: 3.13
 Requires-Python: >=3.9
 Description-Content-Type: text/markdown
 License-File: LICENSE
@@ -25,7 +20,6 @@ Dynamic: license-file
 # <img src="docs/_static/logo-light.png" alt="ONNX IR" width="250"/>

 [](https://pypi.org/project/onnx-ir)
-[](https://pypi.org/project/onnx-ir)
 [](https://github.com/astral-sh/ruff)
 [](https://codecov.io/gh/onnx/ir-py)
 [](https://pepy.tech/projects/onnx-ir)