onnx-ir 0.1.10__py3-none-any.whl → 0.1.11__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

This version of onnx-ir has been flagged as potentially problematic.

onnx_ir/__init__.py CHANGED
@@ -168,4 +168,4 @@ def __set_module() -> None:
 
 
 __set_module()
-__version__ = "0.1.10"
+__version__ = "0.1.11"
onnx_ir/_core.py CHANGED
@@ -185,6 +185,19 @@ class TensorBase(abc.ABC, _protocols.TensorProtocol, _display.PrettyPrintable):
             self._metadata = _metadata.MetadataStore()
         return self._metadata
 
+    def tofile(self, file) -> None:
+        """Write the tensor to a binary file.
+
+        This method writes the raw bytes of the tensor to a file-like object.
+        The file-like object must have a ``write`` method that accepts bytes.
+
+        .. versionadded:: 0.1.11
+
+        Args:
+            file: A file-like object with a ``write`` method that accepts bytes.
+        """
+        file.write(self.tobytes())
+
     def display(self, *, page: bool = False) -> None:
         rich = _display.require_rich()
 
@@ -337,6 +350,38 @@ def _maybe_view_np_array_with_ml_dtypes(
     return array
 
 
+def _supports_fileno(file: Any) -> bool:
+    """Check if the file-like object supports fileno()."""
+    if not hasattr(file, "fileno"):
+        return False
+    try:
+        file.fileno()
+    except Exception:  # pylint: disable=broad-except
+        return False
+    return True
+
+
+def _create_np_array_for_byte_representation(tensor: Tensor) -> np.ndarray:
+    """Create a numpy array for the byte representation of the tensor.
+
+    This function is used for serializing the tensor to bytes. It handles the
+    special cases for 4-bit data types and endianness.
+    """
+    array = tensor.numpy()
+    if tensor.dtype in {
+        _enums.DataType.INT4,
+        _enums.DataType.UINT4,
+        _enums.DataType.FLOAT4E2M1,
+    }:
+        # Pack the array into int4
+        array = _type_casting.pack_4bitx2(array)
+    else:
+        assert tensor.dtype.itemsize == array.itemsize, "Bug: The itemsize should match"
+        if not _IS_LITTLE_ENDIAN:
+            array = array.astype(array.dtype.newbyteorder("<"))
+    return array
+
+
 class Tensor(TensorBase, _protocols.TensorProtocol, Generic[TArrayCompatible]):  # pylint: disable=too-many-ancestors
     """An immutable concrete tensor.
 
@@ -509,20 +554,24 @@ class Tensor(TensorBase, _protocols.TensorProtocol, Generic[TArrayCompatible]):
         value is not a numpy array.
         """
         # TODO(justinchuby): Support DLPack
-        array = self.numpy()
-        if self.dtype in {
-            _enums.DataType.INT4,
-            _enums.DataType.UINT4,
-            _enums.DataType.FLOAT4E2M1,
-        }:
-            # Pack the array into int4
-            array = _type_casting.pack_4bitx2(array)
-        else:
-            assert self.dtype.itemsize == array.itemsize, "Bug: The itemsize should match"
-            if not _IS_LITTLE_ENDIAN:
-                array = array.view(array.dtype.newbyteorder("<"))
+        array = _create_np_array_for_byte_representation(self)
         return array.tobytes()
 
+    def tofile(self, file) -> None:
+        """Write the tensor to a binary file.
+
+        .. versionadded:: 0.1.11
+
+        Args:
+            file: A file-like object with a ``write`` method that accepts bytes, or has an ``fileno()`` method.
+        """
+        if isinstance(self._raw, np.ndarray) and _supports_fileno(file):
+            # This is a duplication of tobytes() for handling special cases
+            array = _create_np_array_for_byte_representation(self)
+            array.tofile(file)
+        else:
+            file.write(self.tobytes())
+
 
 class ExternalTensor(TensorBase, _protocols.TensorProtocol):  # pylint: disable=too-many-ancestors
     """An immutable concrete tensor with its data store on disk.
@@ -535,7 +584,7 @@ class ExternalTensor(TensorBase, _protocols.TensorProtocol): # pylint: disable=
     the tensor is recommended if IO overhead and memory usage is a concern.
 
     To obtain an array, call :meth:`numpy`. To obtain the bytes,
-    call :meth:`tobytes`.
+    call :meth:`tobytes`. To write the data to a file, call :meth:`tofile`.
 
     The :attr:`location` must be a relative path conforming to the ONNX
     specification. Given the correct :attr:`base_dir`, the :attr:`path` is computed
@@ -590,7 +639,7 @@ class ExternalTensor(TensorBase, _protocols.TensorProtocol): # pylint: disable=
             length: The length of the data in bytes.
             dtype: The data type of the tensor.
             shape: The shape of the tensor.
-            name: The name of the tensor..
+            name: The name of the tensor.
             doc_string: The documentation string.
             metadata_props: The metadata properties.
             base_dir: The base directory for the external data. It is used to resolve relative paths.
@@ -746,6 +795,18 @@ class ExternalTensor(TensorBase, _protocols.TensorProtocol): # pylint: disable=
         length = self._length or self.nbytes
         return self.raw[offset : offset + length]
 
+    def tofile(self, file) -> None:
+        self._check_validity()
+        with open(self.path, "rb") as src:
+            if self._offset is not None:
+                src.seek(self._offset)
+            bytes_to_copy = self._length or self.nbytes
+            chunk_size = 1024 * 1024  # 1MB
+            while bytes_to_copy > 0:
+                chunk = src.read(min(chunk_size, bytes_to_copy))
+                file.write(chunk)
+                bytes_to_copy -= len(chunk)
+
     def valid(self) -> bool:
         """Check if the tensor is valid.
 
@@ -979,6 +1040,15 @@ class LazyTensor(TensorBase, _protocols.TensorProtocol): # pylint: disable=too-
         """Return the bytes of the tensor."""
         return self._evaluate().tobytes()
 
+    def tofile(self, file) -> None:
+        tensor = self._evaluate()
+        if hasattr(tensor, "tofile"):
+            # Some existing implementation of TensorProtocol
+            # may not have tofile() as it was introduced in v0.1.11
+            tensor.tofile(file)
+        else:
+            super().tofile(file)
+
 
 class PackedTensor(TensorBase, _protocols.TensorProtocol, Generic[TArrayCompatible]):  # pylint: disable=too-many-ancestors
     """A tensor that stores 4bit datatypes in packed format.
@@ -1110,9 +1180,26 @@ class PackedTensor(TensorBase, _protocols.TensorProtocol, Generic[TArrayCompatib
         """
         array = self.numpy_packed()
         if not _IS_LITTLE_ENDIAN:
-            array = array.view(array.dtype.newbyteorder("<"))
+            array = array.astype(array.dtype.newbyteorder("<"))
         return array.tobytes()
 
+    def tofile(self, file) -> None:
+        """Write the tensor to a binary file.
+
+        .. versionadded:: 0.1.11
+
+        Args:
+            file: A file-like object with a ``write`` method that accepts bytes, or has an ``fileno()`` method.
+        """
+        if _supports_fileno(file):
+            # This is a duplication of tobytes() for handling edge cases
+            array = self.numpy_packed()
+            if not _IS_LITTLE_ENDIAN:
+                array = array.astype(array.dtype.newbyteorder("<"))
+            array.tofile(file)
+        else:
+            file.write(self.tobytes())
+
 
 class SymbolicDim(_protocols.SymbolicDimProtocol, _display.PrettyPrintable):
     """Immutable symbolic dimension that can be shared across multiple shapes.
@@ -2214,23 +2301,38 @@ class Value(_protocols.ValueProtocol, _display.PrettyPrintable):
 
     @name.setter
     def name(self, value: str | None) -> None:
-        if self._const_value is not None:
-            self._const_value.name = value
-        old_name = self._name
-        self._name = value
-        if self.is_initializer():
+        if self._name == value:
+            return
+
+        # First check if renaming is valid. Do not change anything if it is invalid
+        # to prevent the value from being in an inconsistent state.
+        is_initializer = self.is_initializer()
+        if is_initializer:
             if value is None:
                 raise ValueError(
-                    "Initializer value cannot have name set to None. Please pop() the value from initializers first"
+                    "Initializer value cannot have name set to None. Please pop() the value from initializers first to do so."
                 )
-            # Rename the initializer entry in the graph
             graph = self._graph
             assert graph is not None
-            assert old_name is not None
             if value in graph.initializers and graph.initializers[value] is not self:
                 raise ValueError(
-                    f"Cannot rename initializer to '{value}': an initializer with that name already exists."
+                    f"Cannot rename initializer '{self}' to '{value}': an initializer with that name already exists."
                 )
+
+        # Rename the backing constant tensor
+        if self._const_value is not None:
+            self._const_value.name = value
+
+        # Rename self
+        old_name = self._name
+        self._name = value
+
+        if is_initializer:
+            # Rename the initializer entry in the graph
+            assert value is not None, "debug: Should be guarded above"
+            graph = self._graph
+            assert graph is not None
+            assert old_name is not None
             graph.initializers.pop(old_name)
             graph.initializers[value] = self
 
onnx_ir/external_data.py CHANGED
@@ -205,14 +205,20 @@ def _write_external_data(
             )
             current_offset = tensor_info.offset
             assert tensor is not None
-            raw_data = tensor.tobytes()
-            if isinstance(tensor, _core.ExternalTensor):
-                tensor.release()
             # Pad file to required offset if needed
             file_size = data_file.tell()
             if current_offset > file_size:
                 data_file.write(b"\0" * (current_offset - file_size))
-            data_file.write(raw_data)
+
+            if hasattr(tensor, "tofile"):
+                # Some existing implementation of TensorProtocol
+                # may not have tofile() as it was introduced in v0.1.11
+                tensor.tofile(data_file)
+            else:
+                raw_data = tensor.tobytes()
+                if isinstance(tensor, _core.ExternalTensor):
+                    tensor.release()
+                data_file.write(raw_data)
 
 
 def _create_external_tensor(
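The external-data writer now prefers `tofile()` when a tensor provides it and keeps the `tobytes()` path for older `TensorProtocol` implementations. A small sketch of the same duck-typed dispatch in isolation; the helper and the `LegacyTensor` stand-in are hypothetical names used only for illustration:

```python
from typing import IO


def write_tensor_payload(tensor, data_file: IO[bytes]) -> None:
    """Hypothetical helper mirroring the dispatch in _write_external_data above."""
    if hasattr(tensor, "tofile"):
        # tofile() (added in 0.1.11) can stream the payload in chunks instead of
        # materializing one large bytes object in memory.
        tensor.tofile(data_file)
    else:
        # Pre-0.1.11 tensor implementations only guarantee tobytes().
        data_file.write(tensor.tobytes())


class LegacyTensor:
    """Stand-in for a pre-0.1.11 tensor implementation without tofile()."""

    def __init__(self, payload: bytes) -> None:
        self._payload = payload

    def tobytes(self) -> bytes:
        return self._payload


with open("model.data", "wb") as data_file:
    write_tensor_payload(LegacyTensor(b"\x00" * 16), data_file)
```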
onnx_ir/passes/common/identity_elimination.py CHANGED
@@ -15,6 +15,29 @@ import onnx_ir as ir
 logger = logging.getLogger(__name__)
 
 
+def _merge_shapes(shape1: ir.Shape | None, shape2: ir.Shape | None) -> ir.Shape | None:
+    def merge_dims(dim1, dim2):
+        if dim1 == dim2:
+            return dim1
+        if not isinstance(dim1, ir.SymbolicDim):
+            return dim1  # Prefer int value over symbolic dim
+        if not isinstance(dim2, ir.SymbolicDim):
+            return dim2
+        if dim1.value is None:
+            return dim2
+        return dim1
+
+    if shape1 is None:
+        return shape2
+    if shape2 is None:
+        return shape1
+    if len(shape1) != len(shape2):
+        raise ValueError(
+            f"Shapes must have the same rank, got {len(shape1)} and {len(shape2)}."
+        )
+    return ir.Shape([merge_dims(dim1, dim2) for dim1, dim2 in zip(shape1, shape2)])
+
+
 class IdentityEliminationPass(ir.passes.InPlacePass):
     """Pass for eliminating redundant Identity nodes.
 
@@ -75,6 +98,11 @@ class IdentityEliminationPass(ir.passes.InPlacePass):
         if output_is_graph_output and input_is_graph_input:
             return False
 
+        # Copy over shape/type if the output has more complete information
+        input_value.shape = _merge_shapes(input_value.shape, output_value.shape)
+        if input_value.type is None:
+            input_value.type = output_value.type
+
         # Case 1 & 2 (merged): Eliminate the identity node
         # Replace all uses of output with input
         ir.convenience.replace_all_uses_with(output_value, input_value)
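To illustrate the merge rule: a concrete integer dimension wins over a symbolic one, and a named symbolic dimension wins over an anonymous one. A sketch that calls the private helper directly (internal API, shown only for illustration):

```python
import onnx_ir as ir
from onnx_ir.passes.common.identity_elimination import _merge_shapes

# "N" is a named symbolic dim, None is an anonymous one, integers are concrete.
shape_on_input = ir.Shape(["N", 3, None])
shape_on_output = ir.Shape([2, 3, "C"])

merged = _merge_shapes(shape_on_input, shape_on_output)
# Dim by dim: 2 beats "N", 3 == 3, and "C" beats the anonymous dim, so the
# merged shape is equivalent to ir.Shape([2, 3, "C"]).
print(merged)
```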
onnx_ir/tensor_adapters.py CHANGED
@@ -168,10 +168,8 @@ class TorchTensor(_core.Tensor):
             return self.numpy()
         return self.numpy().__array__(dtype)
 
-    def tobytes(self) -> bytes:
-        # Implement tobytes to support native PyTorch types so we can use types like bloat16
-        # Reading from memory directly is also more efficient because
-        # it avoids copying to a NumPy array
+    def _get_cbytes(self):
+        """Get a ctypes byte array pointing to the tensor data."""
         import torch._subclasses.fake_tensor
 
         with torch._subclasses.fake_tensor.unset_fake_temporarily():  # pylint: disable=protected-access
@@ -185,8 +183,18 @@ class TorchTensor(_core.Tensor):
                     "or save the model without initializers by setting include_initializers=False."
                 )
 
-            return bytes(
-                (ctypes.c_ubyte * tensor.element_size() * tensor.numel()).from_address(
-                    tensor.data_ptr()
-                )
+            # Return the tensor to ensure it is not garbage collected while the ctypes array is in use
+            return tensor, (ctypes.c_ubyte * tensor.element_size() * tensor.numel()).from_address(
+                tensor.data_ptr()
             )
+
+    def tobytes(self) -> bytes:
+        # Implement tobytes to support native PyTorch types so we can use types like bloat16
+        # Reading from memory directly is also more efficient because
+        # it avoids copying to a NumPy array
+        _, data = self._get_cbytes()
+        return bytes(data)
+
+    def tofile(self, file) -> None:
+        _, data = self._get_cbytes()
+        return file.write(data)
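With `_get_cbytes()` factored out, `TorchTensor.tofile()` hands the ctypes view straight to `file.write()`, so dtypes like bfloat16 are written without a NumPy round trip or an intermediate `bytes` copy. A usage sketch; the constructor call is assumed from the class shown above and may differ in detail:

```python
from pathlib import Path

import torch

from onnx_ir.tensor_adapters import TorchTensor

# bfloat16 has no native NumPy equivalent; TorchTensor keeps the torch dtype.
weight = TorchTensor(torch.arange(6, dtype=torch.bfloat16).reshape(2, 3), name="w")

with open("w.bin", "wb") as f:
    weight.tofile(f)  # writes the tensor's raw bytes directly from memory

assert Path("w.bin").read_bytes() == weight.tobytes()
```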
onnx_ir-0.1.10.dist-info/METADATA → onnx_ir-0.1.11.dist-info/METADATA CHANGED
@@ -1,6 +1,6 @@
 Metadata-Version: 2.4
 Name: onnx-ir
-Version: 0.1.10
+Version: 0.1.11
 Summary: Efficient in-memory representation for ONNX
 Author-email: ONNX Contributors <onnx-technical-discuss@lists.lfaidata.foundation>
 License-Expression: Apache-2.0
@@ -8,11 +8,6 @@ Project-URL: Homepage, https://onnx.ai/ir-py
 Project-URL: Issues, https://github.com/onnx/ir-py/issues
 Project-URL: Repository, https://github.com/onnx/ir-py
 Classifier: Development Status :: 4 - Beta
-Classifier: Programming Language :: Python :: 3.9
-Classifier: Programming Language :: Python :: 3.10
-Classifier: Programming Language :: Python :: 3.11
-Classifier: Programming Language :: Python :: 3.12
-Classifier: Programming Language :: Python :: 3.13
 Requires-Python: >=3.9
 Description-Content-Type: text/markdown
 License-File: LICENSE
onnx_ir-0.1.10.dist-info/RECORD → onnx_ir-0.1.11.dist-info/RECORD CHANGED
@@ -1,5 +1,5 @@
-onnx_ir/__init__.py,sha256=N-6BRNeRjjU17-iaezMOq1ErVq8irJltkubTpAbZ2BY,3441
-onnx_ir/_core.py,sha256=9AnsXuoiFzypp9U_al7yU46OU0iqrGq0rJzLw4FHmcI,143517
+onnx_ir/__init__.py,sha256=152XzM3zz66eHZYCpERRhPZhEExuXkOhdS_zYi_0wQE,3441
+onnx_ir/_core.py,sha256=32tCB7gZkRjI2kcrp5cLUGz-KH7-BAbaxVeuoWR2TZs,147035
 onnx_ir/_display.py,sha256=230bMN_hVy47Ug3HkA4o5Tf5Hr21AnBEoq5w0fxjyTs,1300
 onnx_ir/_enums.py,sha256=E7WQ7yQzulBeimamc9q_k4fEUoyH_2PWtaOMpwck_W0,13915
 onnx_ir/_graph_comparison.py,sha256=8_D1gu547eCDotEUqxfIJhUGU_Ufhfji7sfsSraOj3g,727
@@ -14,11 +14,11 @@ onnx_ir/_tape.py,sha256=nEGY6VZVKuB8FDyXeYr0MTq8j7E4HKOE2yN8qpz4ia0,7007
 onnx_ir/_type_casting.py,sha256=hbikTmgFEu0SEfnbgv2R1LbpuPQ2MCfqto3-oLWhcBc,1645
 onnx_ir/_version_utils.py,sha256=bZThuE7meVHFOY1DLsmss9WshVIp9iig7udGfDbVaK4,1333
 onnx_ir/convenience.py,sha256=0B1epuXZCSmY4FbW2vaYfR-t5ubxBZ1UruiytHs-zFw,917
-onnx_ir/external_data.py,sha256=rXHtRU-9tjAt10Iervhr5lsI6Dtv-EhR7J4brxppImA,18079
+onnx_ir/external_data.py,sha256=9XRXQ8blxM7KGMPrAXWiF_QIdv2O84VpHu5Dp_fAX6A,18334
 onnx_ir/py.typed,sha256=AbpHGcgLb-kRsJGnwFEktk7uzpZOCcBY74-YBdrKVGs,1
 onnx_ir/serde.py,sha256=S0zCZnfePs1UV927HDEr3VnXuue_B5PD1dVPqwuwrak,80636
 onnx_ir/tape.py,sha256=4FyfAHmVhQoMsfHMYnBwP2azi6UF6b6pj--ercObqZs,350
-onnx_ir/tensor_adapters.py,sha256=YffUeZDZi8thxm-4nF2cL6cNSJSVmLm4A3IbEzwY8QQ,7233
+onnx_ir/tensor_adapters.py,sha256=1YxOgYro9QV-Dw4XsiD9WQgs1_6h13nu6tfmA1mG7IA,7568
 onnx_ir/testing.py,sha256=WTrjf2joWizDWaYMJlV1KjZMQw7YmZ8NvuBTVn1uY6s,8803
 onnx_ir/traversal.py,sha256=Wy4XphwuapAvm94-5iaz6G8LjIoMFpY7qfPfXzYViEE,4488
 onnx_ir/_convenience/__init__.py,sha256=SO7kc8RXVKEUODGh0q2Y7WgmbUsOjYSixmKFx_A0DAQ,19752
@@ -31,7 +31,7 @@ onnx_ir/passes/common/_c_api_utils.py,sha256=g6riA6xNGVWaO5YjVHZ0krrfslWHmRlryRk
 onnx_ir/passes/common/clear_metadata_and_docstring.py,sha256=YwouLfsNFSaTuGd7uMOGjdvVwG9yHQTkSphUgDlM0ME,2365
 onnx_ir/passes/common/common_subexpression_elimination.py,sha256=wZ1zEPdCshYB_ifP9fCAVfzQkesE6uhCfzCuL2qO5fA,7948
 onnx_ir/passes/common/constant_manipulation.py,sha256=Ja_2uO59ni8YX600iF3wvCmOk4EvM5crl7csYU1s7rQ,9391
-onnx_ir/passes/common/identity_elimination.py,sha256=wN8g8uPGn6IIQ6Jf1lo6nGTXvpWyiSQtT_CfmtvZpwA,3664
+onnx_ir/passes/common/identity_elimination.py,sha256=yLoy0DEIIfhP1smlCx-2Bxsv89lij1uvFdFVh9XUfTY,4667
 onnx_ir/passes/common/initializer_deduplication.py,sha256=gKrXTMFAtCkMmiIm8zWzwPnwSbRdZxunJeAt_jFU-vY,7253
 onnx_ir/passes/common/inliner.py,sha256=wBoO6yXt6F1AObQjYZHMQ0wn3YH681N4HQQVyaMAYd4,13702
 onnx_ir/passes/common/naming.py,sha256=l_LrUiI3gAoSEVs8YeQ5kRNWp7aMOlBK-SfFUsKobZI,10687
@@ -39,8 +39,8 @@ onnx_ir/passes/common/onnx_checker.py,sha256=_sPmJ2ff9pDB1g9q7082BL6fyubomRaj6sv
 onnx_ir/passes/common/shape_inference.py,sha256=LVdvxjeKtcIEbPcb6mKisxoPJOOawzsm3tzk5j9xqeM,3992
 onnx_ir/passes/common/topological_sort.py,sha256=Vcu1YhBdfRX4LROr0NScjB1Pwz2DjBFD0Z_GxqaxPF8,999
 onnx_ir/passes/common/unused_removal.py,sha256=cBNqaqGnUVyCWxsD7hBzYk4qSglVPo3SmHAvkUo5-Oc,7613
-onnx_ir-0.1.10.dist-info/licenses/LICENSE,sha256=z8d0m5b2O9McPEK1xHG_dWgUBT6EfBDz6wA0F7xSPTA,11358
-onnx_ir-0.1.10.dist-info/METADATA,sha256=3ou4NcDwiEwn172tRdD57Tb7y7rCGUyfhFDqan7qM3s,3605
-onnx_ir-0.1.10.dist-info/WHEEL,sha256=_zCd3N1l69ArxyTb8rzEoP9TpbYXkqRFSNOD5OuxnTs,91
-onnx_ir-0.1.10.dist-info/top_level.txt,sha256=W5tROO93YjO0XRxIdjMy4wocp-5st5GiI2ukvW7UhDo,8
-onnx_ir-0.1.10.dist-info/RECORD,,
+onnx_ir-0.1.11.dist-info/licenses/LICENSE,sha256=z8d0m5b2O9McPEK1xHG_dWgUBT6EfBDz6wA0F7xSPTA,11358
+onnx_ir-0.1.11.dist-info/METADATA,sha256=Q5F7zhOGZspqkN7nI0NnAjSdmRgZqFK6zgWM-2azglA,3351
+onnx_ir-0.1.11.dist-info/WHEEL,sha256=_zCd3N1l69ArxyTb8rzEoP9TpbYXkqRFSNOD5OuxnTs,91
+onnx_ir-0.1.11.dist-info/top_level.txt,sha256=W5tROO93YjO0XRxIdjMy4wocp-5st5GiI2ukvW7UhDo,8
+onnx_ir-0.1.11.dist-info/RECORD,,