onnx-ir 0.1.3__py3-none-any.whl → 0.1.4__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of onnx-ir might be problematic. Click here for more details.

onnx_ir/__init__.py CHANGED
@@ -167,4 +167,4 @@ def __set_module() -> None:
167
167
 
168
168
 
169
169
  __set_module()
170
- __version__ = "0.1.3"
170
+ __version__ = "0.1.4"
onnx_ir/_core.py CHANGED
@@ -657,15 +657,13 @@ class ExternalTensor(TensorBase, _protocols.TensorProtocol): # pylint: disable=
657
657
  self._array = np.empty(self.shape.numpy(), dtype=self.dtype.numpy())
658
658
  return
659
659
  # Map the whole file into the memory
660
- # TODO(justinchuby): Verify if this would exhaust the memory address space
661
660
  with open(self.path, "rb") as f:
662
661
  self.raw = mmap.mmap(
663
662
  f.fileno(),
664
663
  0,
665
664
  access=mmap.ACCESS_READ,
666
665
  )
667
- # Handle the byte order correctly by always using little endian
668
- dt = np.dtype(self.dtype.numpy()).newbyteorder("<")
666
+
669
667
  if self.dtype in {
670
668
  _enums.DataType.INT4,
671
669
  _enums.DataType.UINT4,
@@ -675,16 +673,18 @@ class ExternalTensor(TensorBase, _protocols.TensorProtocol): # pylint: disable=
675
673
  dt = np.dtype(np.uint8).newbyteorder("<")
676
674
  count = self.size // 2 + self.size % 2
677
675
  else:
676
+ # Handle the byte order correctly by always using little endian
677
+ dt = np.dtype(self.dtype.numpy()).newbyteorder("<")
678
678
  count = self.size
679
+
679
680
  self._array = np.frombuffer(self.raw, dtype=dt, offset=self.offset or 0, count=count)
680
681
  shape = self.shape.numpy()
681
- if self.dtype == _enums.DataType.INT4:
682
- # Unpack the int4 arrays
683
- self._array = _type_casting.unpack_int4(self._array, shape)
684
- elif self.dtype == _enums.DataType.UINT4:
685
- self._array = _type_casting.unpack_uint4(self._array, shape)
686
- elif self.dtype == _enums.DataType.FLOAT4E2M1:
687
- self._array = _type_casting.unpack_float4e2m1(self._array, shape)
682
+
683
+ if self.dtype.bitwidth == 4:
684
+ # Unpack the 4bit arrays
685
+ self._array = _type_casting.unpack_4bitx2(self._array, shape).view(
686
+ self.dtype.numpy()
687
+ )
688
688
  else:
689
689
  self._array = self._array.reshape(shape)
690
690
 
@@ -1071,15 +1071,7 @@ class PackedTensor(TensorBase, _protocols.TensorProtocol, Generic[TArrayCompatib
1071
1071
  """
1072
1072
  array = self.numpy_packed()
1073
1073
  # ONNX IR returns the unpacked arrays
1074
- if self.dtype == _enums.DataType.INT4:
1075
- return _type_casting.unpack_int4(array, self.shape.numpy())
1076
- if self.dtype == _enums.DataType.UINT4:
1077
- return _type_casting.unpack_uint4(array, self.shape.numpy())
1078
- if self.dtype == _enums.DataType.FLOAT4E2M1:
1079
- return _type_casting.unpack_float4e2m1(array, self.shape.numpy())
1080
- raise TypeError(
1081
- f"PackedTensor only supports INT4, UINT4, FLOAT4E2M1, but got {self.dtype}"
1082
- )
1074
+ return _type_casting.unpack_4bitx2(array, self.shape.numpy()).view(self.dtype.numpy())
1083
1075
 
1084
1076
  def numpy_packed(self) -> npt.NDArray[np.uint8]:
1085
1077
  """Return the tensor as a packed array."""
onnx_ir/_enums.py CHANGED
@@ -169,6 +169,48 @@ class DataType(enum.IntEnum):
169
169
  DataType.FLOAT4E2M1,
170
170
  }
171
171
 
172
+ def is_integer(self) -> bool:
173
+ """Returns True if the data type is an integer.
174
+
175
+ .. versionadded:: 0.1.4
176
+ """
177
+ return self in {
178
+ DataType.UINT8,
179
+ DataType.INT8,
180
+ DataType.UINT16,
181
+ DataType.INT16,
182
+ DataType.INT32,
183
+ DataType.INT64,
184
+ DataType.UINT32,
185
+ DataType.UINT64,
186
+ DataType.UINT4,
187
+ DataType.INT4,
188
+ }
189
+
190
+ def is_signed(self) -> bool:
191
+ """Returns True if the data type is a signed type.
192
+
193
+ .. versionadded:: 0.1.4
194
+ """
195
+ return self in {
196
+ DataType.FLOAT,
197
+ DataType.INT8,
198
+ DataType.INT16,
199
+ DataType.INT32,
200
+ DataType.INT64,
201
+ DataType.FLOAT16,
202
+ DataType.DOUBLE,
203
+ DataType.COMPLEX64,
204
+ DataType.COMPLEX128,
205
+ DataType.BFLOAT16,
206
+ DataType.FLOAT8E4M3FN,
207
+ DataType.FLOAT8E4M3FNUZ,
208
+ DataType.FLOAT8E5M2,
209
+ DataType.FLOAT8E5M2FNUZ,
210
+ DataType.INT4,
211
+ DataType.FLOAT4E2M1,
212
+ }
213
+
172
214
  def __repr__(self) -> str:
173
215
  return self.name
174
216
 
onnx_ir/_type_casting.py CHANGED
@@ -1,14 +1,12 @@
1
1
  # Copyright (c) ONNX Project Contributors
2
2
  # SPDX-License-Identifier: Apache-2.0
3
3
  """Numpy utilities for non-native type operation."""
4
- # TODO(justinchuby): Upstream the logic to onnx
5
4
 
6
5
  from __future__ import annotations
7
6
 
8
7
  import typing
9
8
  from collections.abc import Sequence
10
9
 
11
- import ml_dtypes
12
10
  import numpy as np
13
11
 
14
12
  if typing.TYPE_CHECKING:
@@ -28,9 +26,7 @@ def pack_4bitx2(array: np.ndarray) -> npt.NDArray[np.uint8]:
28
26
  return array_flat[0::2] | array_flat[1::2] # type: ignore[return-type]
29
27
 
30
28
 
31
- def _unpack_uint4_as_uint8(
32
- data: npt.NDArray[np.uint8], dims: Sequence[int]
33
- ) -> npt.NDArray[np.uint8]:
29
+ def unpack_4bitx2(data: npt.NDArray[np.uint8], dims: Sequence[int]) -> npt.NDArray[np.uint8]:
34
30
  """Convert a packed uint4 array to unpacked uint4 array represented as uint8.
35
31
 
36
32
  Args:
@@ -52,56 +48,3 @@ def _unpack_uint4_as_uint8(
52
48
  result = result[:-1]
53
49
  result.resize(dims, refcheck=False)
54
50
  return result
55
-
56
-
57
- def unpack_uint4(
58
- data: npt.NDArray[np.uint8], dims: Sequence[int]
59
- ) -> npt.NDArray[ml_dtypes.uint4]:
60
- """Convert a packed uint4 array to unpacked uint4 array represented as uint8.
61
-
62
- Args:
63
- data: A numpy array.
64
- dims: The dimensions are used to reshape the unpacked buffer.
65
-
66
- Returns:
67
- A numpy array of int8/uint8 reshaped to dims.
68
- """
69
- return _unpack_uint4_as_uint8(data, dims).view(ml_dtypes.uint4)
70
-
71
-
72
- def _extend_int4_sign_bits(x: npt.NDArray[np.uint8]) -> npt.NDArray[np.int8]:
73
- """Extend 4-bit signed integer to 8-bit signed integer."""
74
- return np.where((x >> 3) == 0, x, x | 0xF0).astype(np.int8)
75
-
76
-
77
- def unpack_int4(
78
- data: npt.NDArray[np.uint8], dims: Sequence[int]
79
- ) -> npt.NDArray[ml_dtypes.int4]:
80
- """Convert a packed (signed) int4 array to unpacked int4 array represented as int8.
81
-
82
- The sign bit is extended to the most significant bit of the int8.
83
-
84
- Args:
85
- data: A numpy array.
86
- dims: The dimensions are used to reshape the unpacked buffer.
87
-
88
- Returns:
89
- A numpy array of int8 reshaped to dims.
90
- """
91
- unpacked = _unpack_uint4_as_uint8(data, dims)
92
- return _extend_int4_sign_bits(unpacked).view(ml_dtypes.int4)
93
-
94
-
95
- def unpack_float4e2m1(
96
- data: npt.NDArray[np.uint8], dims: Sequence[int]
97
- ) -> npt.NDArray[ml_dtypes.float4_e2m1fn]:
98
- """Convert a packed float4e2m1 array to unpacked float4e2m1 array.
99
-
100
- Args:
101
- data: A numpy array.
102
- dims: The dimensions are used to reshape the unpacked buffer.
103
-
104
- Returns:
105
- A numpy array of float32 reshaped to dims.
106
- """
107
- return _unpack_uint4_as_uint8(data, dims).view(ml_dtypes.float4_e2m1fn)
onnx_ir/py.typed ADDED
@@ -0,0 +1 @@
1
+
onnx_ir/serde.py CHANGED
@@ -74,7 +74,6 @@ from onnx_ir import _convenience, _core, _enums, _protocols, _type_casting
74
74
 
75
75
  if typing.TYPE_CHECKING:
76
76
  import google.protobuf.internal.containers as proto_containers
77
- import numpy.typing as npt
78
77
 
79
78
  logger = logging.getLogger(__name__)
80
79
 
@@ -117,13 +116,6 @@ def _little_endian_dtype(dtype) -> np.dtype:
117
116
  return np.dtype(dtype).newbyteorder("<")
118
117
 
119
118
 
120
- def _unflatten_complex(
121
- array: npt.NDArray[np.float32 | np.float64],
122
- ) -> npt.NDArray[np.complex64 | np.complex128]:
123
- """Convert the real representation of a complex dtype to the complex dtype."""
124
- return array[::2] + 1j * array[1::2]
125
-
126
-
127
119
  @typing.overload
128
120
  def from_proto(proto: onnx.ModelProto) -> _core.Model: ... # type: ignore[overload-overlap]
129
121
  @typing.overload
@@ -391,54 +383,88 @@ class TensorProtoTensor(_core.TensorBase): # pylint: disable=too-many-ancestors
391
383
  "Cannot convert external tensor to numpy array. Use ir.ExternalTensor instead."
392
384
  )
393
385
 
386
+ shape = self._proto.dims
387
+
394
388
  if self._proto.HasField("raw_data"):
395
- array = np.frombuffer(self._proto.raw_data, dtype=dtype.numpy().newbyteorder("<"))
396
- # Cannot return now, because we may need to unpack 4bit tensors
397
- elif dtype == _enums.DataType.STRING:
398
- return np.array(self._proto.string_data).reshape(self._proto.dims)
399
- elif self._proto.int32_data:
400
- array = np.array(self._proto.int32_data, dtype=_little_endian_dtype(np.int32))
401
- if dtype in {_enums.DataType.FLOAT16, _enums.DataType.BFLOAT16}:
402
- # Reinterpret the int32 as float16 or bfloat16
403
- array = array.astype(np.uint16).view(dtype.numpy())
404
- elif dtype in {
389
+ if dtype.bitwidth == 4:
390
+ return _type_casting.unpack_4bitx2(
391
+ np.frombuffer(self._proto.raw_data, dtype=np.uint8), shape
392
+ ).view(dtype.numpy())
393
+ return np.frombuffer(
394
+ self._proto.raw_data, dtype=dtype.numpy().newbyteorder("<")
395
+ ).reshape(shape)
396
+ if dtype == _enums.DataType.STRING:
397
+ return np.array(self._proto.string_data).reshape(shape)
398
+ if self._proto.int32_data:
399
+ assert dtype in {
400
+ _enums.DataType.BFLOAT16,
401
+ _enums.DataType.BOOL,
402
+ _enums.DataType.FLOAT16,
403
+ _enums.DataType.FLOAT4E2M1,
405
404
  _enums.DataType.FLOAT8E4M3FN,
406
405
  _enums.DataType.FLOAT8E4M3FNUZ,
407
406
  _enums.DataType.FLOAT8E5M2,
408
407
  _enums.DataType.FLOAT8E5M2FNUZ,
409
- }:
410
- array = array.astype(np.uint8).view(dtype.numpy())
411
- elif self._proto.int64_data:
412
- array = np.array(self._proto.int64_data, dtype=_little_endian_dtype(np.int64))
413
- elif self._proto.uint64_data:
408
+ _enums.DataType.INT16,
409
+ _enums.DataType.INT32,
410
+ _enums.DataType.INT4,
411
+ _enums.DataType.INT8,
412
+ _enums.DataType.UINT16,
413
+ _enums.DataType.UINT4,
414
+ _enums.DataType.UINT8,
415
+ }, f"Unsupported dtype {dtype} for int32_data"
416
+ array = np.array(self._proto.int32_data, dtype=_little_endian_dtype(np.int32))
417
+ if dtype.bitwidth == 32:
418
+ return array.reshape(shape)
419
+ if dtype.bitwidth == 16:
420
+ # Reinterpret the int32 as float16 or bfloat16
421
+ return array.astype(np.uint16).view(dtype.numpy()).reshape(shape)
422
+ if dtype.bitwidth == 8:
423
+ return array.astype(np.uint8).view(dtype.numpy()).reshape(shape)
424
+ if dtype.bitwidth == 4:
425
+ return _type_casting.unpack_4bitx2(array.astype(np.uint8), shape).view(
426
+ dtype.numpy()
427
+ )
428
+ raise ValueError(
429
+ f"Unsupported dtype {dtype} for int32_data with bitwidth {dtype.bitwidth}"
430
+ )
431
+ if self._proto.int64_data:
432
+ assert dtype in {
433
+ _enums.DataType.INT64,
434
+ }, f"Unsupported dtype {dtype} for int64_data"
435
+ return np.array(
436
+ self._proto.int64_data, dtype=_little_endian_dtype(np.int64)
437
+ ).reshape(shape)
438
+ if self._proto.uint64_data:
439
+ assert dtype in {
440
+ _enums.DataType.UINT64,
441
+ _enums.DataType.UINT32,
442
+ }, f"Unsupported dtype {dtype} for uint64_data"
414
443
  array = np.array(self._proto.uint64_data, dtype=_little_endian_dtype(np.uint64))
415
- elif self._proto.float_data:
444
+ if dtype == _enums.DataType.UINT32:
445
+ return array.astype(np.uint32).reshape(shape)
446
+ return array.reshape(shape)
447
+ if self._proto.float_data:
448
+ assert dtype in {
449
+ _enums.DataType.FLOAT,
450
+ _enums.DataType.COMPLEX64,
451
+ }, f"Unsupported dtype {dtype} for float_data"
416
452
  array = np.array(self._proto.float_data, dtype=_little_endian_dtype(np.float32))
417
453
  if dtype == _enums.DataType.COMPLEX64:
418
- array = _unflatten_complex(array)
419
- elif self._proto.double_data:
454
+ return array.view(np.complex64).reshape(shape)
455
+ return array.reshape(shape)
456
+ if self._proto.double_data:
457
+ assert dtype in {
458
+ _enums.DataType.DOUBLE,
459
+ _enums.DataType.COMPLEX128,
460
+ }, f"Unsupported dtype {dtype} for double_data"
420
461
  array = np.array(self._proto.double_data, dtype=_little_endian_dtype(np.float64))
421
462
  if dtype == _enums.DataType.COMPLEX128:
422
- array = _unflatten_complex(array)
423
- else:
424
- # Empty tensor
425
- if not self._proto.dims:
426
- # When dims not precent and there is no data, we return an empty array
427
- return np.array([], dtype=dtype.numpy())
428
- else:
429
- # Otherwise we return a size 0 array with the correct shape
430
- return np.zeros(self._proto.dims, dtype=dtype.numpy())
431
-
432
- if dtype == _enums.DataType.INT4:
433
- return _type_casting.unpack_int4(array.astype(np.uint8), self._proto.dims)
434
- elif dtype == _enums.DataType.UINT4:
435
- return _type_casting.unpack_uint4(array.astype(np.uint8), self._proto.dims)
436
- elif dtype == _enums.DataType.FLOAT4E2M1:
437
- return _type_casting.unpack_float4e2m1(array.astype(np.uint8), self._proto.dims)
438
- else:
439
- # Otherwise convert to the correct dtype and reshape
440
- # Note we cannot use view() here because the storage dtype may not be the same size as the target
441
- return array.astype(dtype.numpy()).reshape(self._proto.dims)
463
+ return array.view(np.complex128).reshape(shape)
464
+ return array.reshape(shape)
465
+
466
+ # Empty tensor. We return a size 0 array with the correct shape
467
+ return np.zeros(shape, dtype=dtype.numpy())
442
468
 
443
469
  def tobytes(self) -> bytes:
444
470
  """Return the tensor as a byte string conformed to the ONNX specification, in little endian.
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.4
2
2
  Name: onnx-ir
3
- Version: 0.1.3
3
+ Version: 0.1.4
4
4
  Summary: Efficient in-memory representation for ONNX
5
5
  Author-email: ONNX Contributors <onnx-technical-discuss@lists.lfaidata.foundation>
6
6
  License: Apache License v2.0
@@ -1,7 +1,7 @@
1
- onnx_ir/__init__.py,sha256=5KP1Ngl2qyWiqb5S0Ol5owYsbU0geo4LFwGwN8EXTIk,3424
2
- onnx_ir/_core.py,sha256=-9BpVTZHuHQ9jsms33wqu4NjMEaDF_M57sIuVxYcM1I,137964
1
+ onnx_ir/__init__.py,sha256=2w65_FPhzimLAzacA-s_IZWCwOWtt6-AMH_Lop7WOcI,3424
2
+ onnx_ir/_core.py,sha256=Y-RJSIgwxKWRTRHc_fWOEH_vjTxtMz2qbZ37hLxzdTI,137284
3
3
  onnx_ir/_display.py,sha256=230bMN_hVy47Ug3HkA4o5Tf5Hr21AnBEoq5w0fxjyTs,1300
4
- onnx_ir/_enums.py,sha256=4lmm_DFKEtz6PqNw6gt6GcqrBYHisctgKMsUbQCm5N8,8252
4
+ onnx_ir/_enums.py,sha256=oWJywcMCMczo1xYz25_eXFOhhJYciDNJTvzgYaKy1E0,9373
5
5
  onnx_ir/_graph_comparison.py,sha256=8_D1gu547eCDotEUqxfIJhUGU_Ufhfji7sfsSraOj3g,727
6
6
  onnx_ir/_graph_containers.py,sha256=PRKrshRZ5rzWCgRs1TefzJq9n8wyo7OqeKy3XxMhyys,14265
7
7
  onnx_ir/_io.py,sha256=GWwA4XOZ-ZX1cgibgaYD0K0O5d9LX21ZwcBN02Wrh04,5205
@@ -11,11 +11,12 @@ onnx_ir/_name_authority.py,sha256=PnoV9TRgMLussZNufWavJXosDWx5avPfldVjMWEEz18,30
11
11
  onnx_ir/_polyfill.py,sha256=LzAGBKQbVDlURC0tgQgaxgkYU4rESgCYnqVs-u-Vsx8,887
12
12
  onnx_ir/_protocols.py,sha256=M29sIOAvtdlis3QtBvCQPH4pnvSwhJCQNCvs3IrN9FY,21276
13
13
  onnx_ir/_tape.py,sha256=nEGY6VZVKuB8FDyXeYr0MTq8j7E4HKOE2yN8qpz4ia0,7007
14
- onnx_ir/_type_casting.py,sha256=8iZDVrNAx_FwRVt48G4tkzIOFu3I6AsETpH3fdxcyEI,3387
14
+ onnx_ir/_type_casting.py,sha256=hbikTmgFEu0SEfnbgv2R1LbpuPQ2MCfqto3-oLWhcBc,1645
15
15
  onnx_ir/_version_utils.py,sha256=bZThuE7meVHFOY1DLsmss9WshVIp9iig7udGfDbVaK4,1333
16
16
  onnx_ir/convenience.py,sha256=0B1epuXZCSmY4FbW2vaYfR-t5ubxBZ1UruiytHs-zFw,917
17
17
  onnx_ir/external_data.py,sha256=rXHtRU-9tjAt10Iervhr5lsI6Dtv-EhR7J4brxppImA,18079
18
- onnx_ir/serde.py,sha256=YkbYfQMwn0YAzTd3tVDSWJ-NBiSVsG-74T6xk3e5iTU,75073
18
+ onnx_ir/py.typed,sha256=AbpHGcgLb-kRsJGnwFEktk7uzpZOCcBY74-YBdrKVGs,1
19
+ onnx_ir/serde.py,sha256=8D9eBVFcCvwRHyW7Y4CJNuAU0iBI3Mjk0A-w2QidHK4,75892
19
20
  onnx_ir/tape.py,sha256=4FyfAHmVhQoMsfHMYnBwP2azi6UF6b6pj--ercObqZs,350
20
21
  onnx_ir/tensor_adapters.py,sha256=dXuapwfFcpLhjKC6AOqCXbtY3WvDaEHoCNPwjnUK7_o,6565
21
22
  onnx_ir/testing.py,sha256=WTrjf2joWizDWaYMJlV1KjZMQw7YmZ8NvuBTVn1uY6s,8803
@@ -36,8 +37,8 @@ onnx_ir/passes/common/onnx_checker.py,sha256=_sPmJ2ff9pDB1g9q7082BL6fyubomRaj6sv
36
37
  onnx_ir/passes/common/shape_inference.py,sha256=LVdvxjeKtcIEbPcb6mKisxoPJOOawzsm3tzk5j9xqeM,3992
37
38
  onnx_ir/passes/common/topological_sort.py,sha256=Vcu1YhBdfRX4LROr0NScjB1Pwz2DjBFD0Z_GxqaxPF8,999
38
39
  onnx_ir/passes/common/unused_removal.py,sha256=cBNqaqGnUVyCWxsD7hBzYk4qSglVPo3SmHAvkUo5-Oc,7613
39
- onnx_ir-0.1.3.dist-info/licenses/LICENSE,sha256=z8d0m5b2O9McPEK1xHG_dWgUBT6EfBDz6wA0F7xSPTA,11358
40
- onnx_ir-0.1.3.dist-info/METADATA,sha256=vKG8o_nAUJfjM05rahv0g-FCeHkHXIwCAcuYzSY6PH8,4782
41
- onnx_ir-0.1.3.dist-info/WHEEL,sha256=_zCd3N1l69ArxyTb8rzEoP9TpbYXkqRFSNOD5OuxnTs,91
42
- onnx_ir-0.1.3.dist-info/top_level.txt,sha256=W5tROO93YjO0XRxIdjMy4wocp-5st5GiI2ukvW7UhDo,8
43
- onnx_ir-0.1.3.dist-info/RECORD,,
40
+ onnx_ir-0.1.4.dist-info/licenses/LICENSE,sha256=z8d0m5b2O9McPEK1xHG_dWgUBT6EfBDz6wA0F7xSPTA,11358
41
+ onnx_ir-0.1.4.dist-info/METADATA,sha256=Oay3Vxf4jfSY50vyCfTYaH0Pbxifv47jd3yimr8CDW8,4782
42
+ onnx_ir-0.1.4.dist-info/WHEEL,sha256=_zCd3N1l69ArxyTb8rzEoP9TpbYXkqRFSNOD5OuxnTs,91
43
+ onnx_ir-0.1.4.dist-info/top_level.txt,sha256=W5tROO93YjO0XRxIdjMy4wocp-5st5GiI2ukvW7UhDo,8
44
+ onnx_ir-0.1.4.dist-info/RECORD,,