onnx-ir 0.1.3__tar.gz → 0.1.4__tar.gz
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of onnx-ir might be problematic. Click here for more details.
- {onnx_ir-0.1.3/src/onnx_ir.egg-info → onnx_ir-0.1.4}/PKG-INFO +1 -1
- {onnx_ir-0.1.3 → onnx_ir-0.1.4}/src/onnx_ir/__init__.py +1 -1
- {onnx_ir-0.1.3 → onnx_ir-0.1.4}/src/onnx_ir/_core.py +11 -19
- {onnx_ir-0.1.3 → onnx_ir-0.1.4}/src/onnx_ir/_enums.py +42 -0
- onnx_ir-0.1.4/src/onnx_ir/_type_casting.py +50 -0
- onnx_ir-0.1.4/src/onnx_ir/py.typed +1 -0
- {onnx_ir-0.1.3 → onnx_ir-0.1.4}/src/onnx_ir/serde.py +72 -46
- {onnx_ir-0.1.3 → onnx_ir-0.1.4/src/onnx_ir.egg-info}/PKG-INFO +1 -1
- {onnx_ir-0.1.3 → onnx_ir-0.1.4}/src/onnx_ir.egg-info/SOURCES.txt +1 -0
- onnx_ir-0.1.3/src/onnx_ir/_type_casting.py +0 -107
- {onnx_ir-0.1.3 → onnx_ir-0.1.4}/LICENSE +0 -0
- {onnx_ir-0.1.3 → onnx_ir-0.1.4}/MANIFEST.in +0 -0
- {onnx_ir-0.1.3 → onnx_ir-0.1.4}/README.md +0 -0
- {onnx_ir-0.1.3 → onnx_ir-0.1.4}/pyproject.toml +0 -0
- {onnx_ir-0.1.3 → onnx_ir-0.1.4}/setup.cfg +0 -0
- {onnx_ir-0.1.3 → onnx_ir-0.1.4}/src/onnx_ir/_convenience/__init__.py +0 -0
- {onnx_ir-0.1.3 → onnx_ir-0.1.4}/src/onnx_ir/_convenience/_constructors.py +0 -0
- {onnx_ir-0.1.3 → onnx_ir-0.1.4}/src/onnx_ir/_display.py +0 -0
- {onnx_ir-0.1.3 → onnx_ir-0.1.4}/src/onnx_ir/_graph_comparison.py +0 -0
- {onnx_ir-0.1.3 → onnx_ir-0.1.4}/src/onnx_ir/_graph_containers.py +0 -0
- {onnx_ir-0.1.3 → onnx_ir-0.1.4}/src/onnx_ir/_io.py +0 -0
- {onnx_ir-0.1.3 → onnx_ir-0.1.4}/src/onnx_ir/_linked_list.py +0 -0
- {onnx_ir-0.1.3 → onnx_ir-0.1.4}/src/onnx_ir/_metadata.py +0 -0
- {onnx_ir-0.1.3 → onnx_ir-0.1.4}/src/onnx_ir/_name_authority.py +0 -0
- {onnx_ir-0.1.3 → onnx_ir-0.1.4}/src/onnx_ir/_polyfill.py +0 -0
- {onnx_ir-0.1.3 → onnx_ir-0.1.4}/src/onnx_ir/_protocols.py +0 -0
- {onnx_ir-0.1.3 → onnx_ir-0.1.4}/src/onnx_ir/_tape.py +0 -0
- {onnx_ir-0.1.3 → onnx_ir-0.1.4}/src/onnx_ir/_thirdparty/asciichartpy.py +0 -0
- {onnx_ir-0.1.3 → onnx_ir-0.1.4}/src/onnx_ir/_version_utils.py +0 -0
- {onnx_ir-0.1.3 → onnx_ir-0.1.4}/src/onnx_ir/convenience.py +0 -0
- {onnx_ir-0.1.3 → onnx_ir-0.1.4}/src/onnx_ir/external_data.py +0 -0
- {onnx_ir-0.1.3 → onnx_ir-0.1.4}/src/onnx_ir/passes/__init__.py +0 -0
- {onnx_ir-0.1.3 → onnx_ir-0.1.4}/src/onnx_ir/passes/_pass_infra.py +0 -0
- {onnx_ir-0.1.3 → onnx_ir-0.1.4}/src/onnx_ir/passes/common/__init__.py +0 -0
- {onnx_ir-0.1.3 → onnx_ir-0.1.4}/src/onnx_ir/passes/common/_c_api_utils.py +0 -0
- {onnx_ir-0.1.3 → onnx_ir-0.1.4}/src/onnx_ir/passes/common/clear_metadata_and_docstring.py +0 -0
- {onnx_ir-0.1.3 → onnx_ir-0.1.4}/src/onnx_ir/passes/common/common_subexpression_elimination.py +0 -0
- {onnx_ir-0.1.3 → onnx_ir-0.1.4}/src/onnx_ir/passes/common/constant_manipulation.py +0 -0
- {onnx_ir-0.1.3 → onnx_ir-0.1.4}/src/onnx_ir/passes/common/initializer_deduplication.py +0 -0
- {onnx_ir-0.1.3 → onnx_ir-0.1.4}/src/onnx_ir/passes/common/inliner.py +0 -0
- {onnx_ir-0.1.3 → onnx_ir-0.1.4}/src/onnx_ir/passes/common/onnx_checker.py +0 -0
- {onnx_ir-0.1.3 → onnx_ir-0.1.4}/src/onnx_ir/passes/common/shape_inference.py +0 -0
- {onnx_ir-0.1.3 → onnx_ir-0.1.4}/src/onnx_ir/passes/common/topological_sort.py +0 -0
- {onnx_ir-0.1.3 → onnx_ir-0.1.4}/src/onnx_ir/passes/common/unused_removal.py +0 -0
- {onnx_ir-0.1.3 → onnx_ir-0.1.4}/src/onnx_ir/tape.py +0 -0
- {onnx_ir-0.1.3 → onnx_ir-0.1.4}/src/onnx_ir/tensor_adapters.py +0 -0
- {onnx_ir-0.1.3 → onnx_ir-0.1.4}/src/onnx_ir/testing.py +0 -0
- {onnx_ir-0.1.3 → onnx_ir-0.1.4}/src/onnx_ir/traversal.py +0 -0
- {onnx_ir-0.1.3 → onnx_ir-0.1.4}/src/onnx_ir.egg-info/dependency_links.txt +0 -0
- {onnx_ir-0.1.3 → onnx_ir-0.1.4}/src/onnx_ir.egg-info/requires.txt +0 -0
- {onnx_ir-0.1.3 → onnx_ir-0.1.4}/src/onnx_ir.egg-info/top_level.txt +0 -0
|
@@ -657,15 +657,13 @@ class ExternalTensor(TensorBase, _protocols.TensorProtocol): # pylint: disable=
|
|
|
657
657
|
self._array = np.empty(self.shape.numpy(), dtype=self.dtype.numpy())
|
|
658
658
|
return
|
|
659
659
|
# Map the whole file into the memory
|
|
660
|
-
# TODO(justinchuby): Verify if this would exhaust the memory address space
|
|
661
660
|
with open(self.path, "rb") as f:
|
|
662
661
|
self.raw = mmap.mmap(
|
|
663
662
|
f.fileno(),
|
|
664
663
|
0,
|
|
665
664
|
access=mmap.ACCESS_READ,
|
|
666
665
|
)
|
|
667
|
-
|
|
668
|
-
dt = np.dtype(self.dtype.numpy()).newbyteorder("<")
|
|
666
|
+
|
|
669
667
|
if self.dtype in {
|
|
670
668
|
_enums.DataType.INT4,
|
|
671
669
|
_enums.DataType.UINT4,
|
|
@@ -675,16 +673,18 @@ class ExternalTensor(TensorBase, _protocols.TensorProtocol): # pylint: disable=
|
|
|
675
673
|
dt = np.dtype(np.uint8).newbyteorder("<")
|
|
676
674
|
count = self.size // 2 + self.size % 2
|
|
677
675
|
else:
|
|
676
|
+
# Handle the byte order correctly by always using little endian
|
|
677
|
+
dt = np.dtype(self.dtype.numpy()).newbyteorder("<")
|
|
678
678
|
count = self.size
|
|
679
|
+
|
|
679
680
|
self._array = np.frombuffer(self.raw, dtype=dt, offset=self.offset or 0, count=count)
|
|
680
681
|
shape = self.shape.numpy()
|
|
681
|
-
|
|
682
|
-
|
|
683
|
-
|
|
684
|
-
|
|
685
|
-
|
|
686
|
-
|
|
687
|
-
self._array = _type_casting.unpack_float4e2m1(self._array, shape)
|
|
682
|
+
|
|
683
|
+
if self.dtype.bitwidth == 4:
|
|
684
|
+
# Unpack the 4bit arrays
|
|
685
|
+
self._array = _type_casting.unpack_4bitx2(self._array, shape).view(
|
|
686
|
+
self.dtype.numpy()
|
|
687
|
+
)
|
|
688
688
|
else:
|
|
689
689
|
self._array = self._array.reshape(shape)
|
|
690
690
|
|
|
@@ -1071,15 +1071,7 @@ class PackedTensor(TensorBase, _protocols.TensorProtocol, Generic[TArrayCompatib
|
|
|
1071
1071
|
"""
|
|
1072
1072
|
array = self.numpy_packed()
|
|
1073
1073
|
# ONNX IR returns the unpacked arrays
|
|
1074
|
-
|
|
1075
|
-
return _type_casting.unpack_int4(array, self.shape.numpy())
|
|
1076
|
-
if self.dtype == _enums.DataType.UINT4:
|
|
1077
|
-
return _type_casting.unpack_uint4(array, self.shape.numpy())
|
|
1078
|
-
if self.dtype == _enums.DataType.FLOAT4E2M1:
|
|
1079
|
-
return _type_casting.unpack_float4e2m1(array, self.shape.numpy())
|
|
1080
|
-
raise TypeError(
|
|
1081
|
-
f"PackedTensor only supports INT4, UINT4, FLOAT4E2M1, but got {self.dtype}"
|
|
1082
|
-
)
|
|
1074
|
+
return _type_casting.unpack_4bitx2(array, self.shape.numpy()).view(self.dtype.numpy())
|
|
1083
1075
|
|
|
1084
1076
|
def numpy_packed(self) -> npt.NDArray[np.uint8]:
|
|
1085
1077
|
"""Return the tensor as a packed array."""
|
|
@@ -169,6 +169,48 @@ class DataType(enum.IntEnum):
|
|
|
169
169
|
DataType.FLOAT4E2M1,
|
|
170
170
|
}
|
|
171
171
|
|
|
172
|
+
def is_integer(self) -> bool:
|
|
173
|
+
"""Returns True if the data type is an integer.
|
|
174
|
+
|
|
175
|
+
.. versionadded:: 0.1.4
|
|
176
|
+
"""
|
|
177
|
+
return self in {
|
|
178
|
+
DataType.UINT8,
|
|
179
|
+
DataType.INT8,
|
|
180
|
+
DataType.UINT16,
|
|
181
|
+
DataType.INT16,
|
|
182
|
+
DataType.INT32,
|
|
183
|
+
DataType.INT64,
|
|
184
|
+
DataType.UINT32,
|
|
185
|
+
DataType.UINT64,
|
|
186
|
+
DataType.UINT4,
|
|
187
|
+
DataType.INT4,
|
|
188
|
+
}
|
|
189
|
+
|
|
190
|
+
def is_signed(self) -> bool:
|
|
191
|
+
"""Returns True if the data type is a signed type.
|
|
192
|
+
|
|
193
|
+
.. versionadded:: 0.1.4
|
|
194
|
+
"""
|
|
195
|
+
return self in {
|
|
196
|
+
DataType.FLOAT,
|
|
197
|
+
DataType.INT8,
|
|
198
|
+
DataType.INT16,
|
|
199
|
+
DataType.INT32,
|
|
200
|
+
DataType.INT64,
|
|
201
|
+
DataType.FLOAT16,
|
|
202
|
+
DataType.DOUBLE,
|
|
203
|
+
DataType.COMPLEX64,
|
|
204
|
+
DataType.COMPLEX128,
|
|
205
|
+
DataType.BFLOAT16,
|
|
206
|
+
DataType.FLOAT8E4M3FN,
|
|
207
|
+
DataType.FLOAT8E4M3FNUZ,
|
|
208
|
+
DataType.FLOAT8E5M2,
|
|
209
|
+
DataType.FLOAT8E5M2FNUZ,
|
|
210
|
+
DataType.INT4,
|
|
211
|
+
DataType.FLOAT4E2M1,
|
|
212
|
+
}
|
|
213
|
+
|
|
172
214
|
def __repr__(self) -> str:
|
|
173
215
|
return self.name
|
|
174
216
|
|
|
@@ -0,0 +1,50 @@
|
|
|
1
|
+
# Copyright (c) ONNX Project Contributors
|
|
2
|
+
# SPDX-License-Identifier: Apache-2.0
|
|
3
|
+
"""Numpy utilities for non-native type operation."""
|
|
4
|
+
|
|
5
|
+
from __future__ import annotations
|
|
6
|
+
|
|
7
|
+
import typing
|
|
8
|
+
from collections.abc import Sequence
|
|
9
|
+
|
|
10
|
+
import numpy as np
|
|
11
|
+
|
|
12
|
+
if typing.TYPE_CHECKING:
|
|
13
|
+
import numpy.typing as npt
|
|
14
|
+
|
|
15
|
+
|
|
16
|
+
def pack_4bitx2(array: np.ndarray) -> npt.NDArray[np.uint8]:
|
|
17
|
+
"""Convert a numpy array to flatten, packed int4/uint4. Elements must be in the correct range."""
|
|
18
|
+
# Create a 1D copy
|
|
19
|
+
array_flat = array.ravel().view(np.uint8).copy()
|
|
20
|
+
size = array.size
|
|
21
|
+
odd_sized = size % 2 == 1
|
|
22
|
+
if odd_sized:
|
|
23
|
+
array_flat.resize([size + 1], refcheck=False)
|
|
24
|
+
array_flat &= 0x0F
|
|
25
|
+
array_flat[1::2] <<= 4
|
|
26
|
+
return array_flat[0::2] | array_flat[1::2] # type: ignore[return-type]
|
|
27
|
+
|
|
28
|
+
|
|
29
|
+
def unpack_4bitx2(data: npt.NDArray[np.uint8], dims: Sequence[int]) -> npt.NDArray[np.uint8]:
|
|
30
|
+
"""Convert a packed uint4 array to unpacked uint4 array represented as uint8.
|
|
31
|
+
|
|
32
|
+
Args:
|
|
33
|
+
data: A numpy array.
|
|
34
|
+
dims: The dimensions are used to reshape the unpacked buffer.
|
|
35
|
+
|
|
36
|
+
Returns:
|
|
37
|
+
A numpy array of int8/uint8 reshaped to dims.
|
|
38
|
+
"""
|
|
39
|
+
assert data.dtype == np.uint8, "Input data must be of type uint8"
|
|
40
|
+
result = np.empty([data.size * 2], dtype=data.dtype)
|
|
41
|
+
array_low = data & np.uint8(0x0F)
|
|
42
|
+
array_high = data & np.uint8(0xF0)
|
|
43
|
+
array_high >>= np.uint8(4)
|
|
44
|
+
result[0::2] = array_low
|
|
45
|
+
result[1::2] = array_high
|
|
46
|
+
if result.size == np.prod(dims) + 1:
|
|
47
|
+
# handle single-element padding due to odd number of elements
|
|
48
|
+
result = result[:-1]
|
|
49
|
+
result.resize(dims, refcheck=False)
|
|
50
|
+
return result
|
|
@@ -0,0 +1 @@
|
|
|
1
|
+
|
|
@@ -74,7 +74,6 @@ from onnx_ir import _convenience, _core, _enums, _protocols, _type_casting
|
|
|
74
74
|
|
|
75
75
|
if typing.TYPE_CHECKING:
|
|
76
76
|
import google.protobuf.internal.containers as proto_containers
|
|
77
|
-
import numpy.typing as npt
|
|
78
77
|
|
|
79
78
|
logger = logging.getLogger(__name__)
|
|
80
79
|
|
|
@@ -117,13 +116,6 @@ def _little_endian_dtype(dtype) -> np.dtype:
|
|
|
117
116
|
return np.dtype(dtype).newbyteorder("<")
|
|
118
117
|
|
|
119
118
|
|
|
120
|
-
def _unflatten_complex(
|
|
121
|
-
array: npt.NDArray[np.float32 | np.float64],
|
|
122
|
-
) -> npt.NDArray[np.complex64 | np.complex128]:
|
|
123
|
-
"""Convert the real representation of a complex dtype to the complex dtype."""
|
|
124
|
-
return array[::2] + 1j * array[1::2]
|
|
125
|
-
|
|
126
|
-
|
|
127
119
|
@typing.overload
|
|
128
120
|
def from_proto(proto: onnx.ModelProto) -> _core.Model: ... # type: ignore[overload-overlap]
|
|
129
121
|
@typing.overload
|
|
@@ -391,54 +383,88 @@ class TensorProtoTensor(_core.TensorBase): # pylint: disable=too-many-ancestors
|
|
|
391
383
|
"Cannot convert external tensor to numpy array. Use ir.ExternalTensor instead."
|
|
392
384
|
)
|
|
393
385
|
|
|
386
|
+
shape = self._proto.dims
|
|
387
|
+
|
|
394
388
|
if self._proto.HasField("raw_data"):
|
|
395
|
-
|
|
396
|
-
|
|
397
|
-
|
|
398
|
-
|
|
399
|
-
|
|
400
|
-
|
|
401
|
-
|
|
402
|
-
|
|
403
|
-
|
|
404
|
-
|
|
389
|
+
if dtype.bitwidth == 4:
|
|
390
|
+
return _type_casting.unpack_4bitx2(
|
|
391
|
+
np.frombuffer(self._proto.raw_data, dtype=np.uint8), shape
|
|
392
|
+
).view(dtype.numpy())
|
|
393
|
+
return np.frombuffer(
|
|
394
|
+
self._proto.raw_data, dtype=dtype.numpy().newbyteorder("<")
|
|
395
|
+
).reshape(shape)
|
|
396
|
+
if dtype == _enums.DataType.STRING:
|
|
397
|
+
return np.array(self._proto.string_data).reshape(shape)
|
|
398
|
+
if self._proto.int32_data:
|
|
399
|
+
assert dtype in {
|
|
400
|
+
_enums.DataType.BFLOAT16,
|
|
401
|
+
_enums.DataType.BOOL,
|
|
402
|
+
_enums.DataType.FLOAT16,
|
|
403
|
+
_enums.DataType.FLOAT4E2M1,
|
|
405
404
|
_enums.DataType.FLOAT8E4M3FN,
|
|
406
405
|
_enums.DataType.FLOAT8E4M3FNUZ,
|
|
407
406
|
_enums.DataType.FLOAT8E5M2,
|
|
408
407
|
_enums.DataType.FLOAT8E5M2FNUZ,
|
|
409
|
-
|
|
410
|
-
|
|
411
|
-
|
|
412
|
-
|
|
413
|
-
|
|
408
|
+
_enums.DataType.INT16,
|
|
409
|
+
_enums.DataType.INT32,
|
|
410
|
+
_enums.DataType.INT4,
|
|
411
|
+
_enums.DataType.INT8,
|
|
412
|
+
_enums.DataType.UINT16,
|
|
413
|
+
_enums.DataType.UINT4,
|
|
414
|
+
_enums.DataType.UINT8,
|
|
415
|
+
}, f"Unsupported dtype {dtype} for int32_data"
|
|
416
|
+
array = np.array(self._proto.int32_data, dtype=_little_endian_dtype(np.int32))
|
|
417
|
+
if dtype.bitwidth == 32:
|
|
418
|
+
return array.reshape(shape)
|
|
419
|
+
if dtype.bitwidth == 16:
|
|
420
|
+
# Reinterpret the int32 as float16 or bfloat16
|
|
421
|
+
return array.astype(np.uint16).view(dtype.numpy()).reshape(shape)
|
|
422
|
+
if dtype.bitwidth == 8:
|
|
423
|
+
return array.astype(np.uint8).view(dtype.numpy()).reshape(shape)
|
|
424
|
+
if dtype.bitwidth == 4:
|
|
425
|
+
return _type_casting.unpack_4bitx2(array.astype(np.uint8), shape).view(
|
|
426
|
+
dtype.numpy()
|
|
427
|
+
)
|
|
428
|
+
raise ValueError(
|
|
429
|
+
f"Unsupported dtype {dtype} for int32_data with bitwidth {dtype.bitwidth}"
|
|
430
|
+
)
|
|
431
|
+
if self._proto.int64_data:
|
|
432
|
+
assert dtype in {
|
|
433
|
+
_enums.DataType.INT64,
|
|
434
|
+
}, f"Unsupported dtype {dtype} for int64_data"
|
|
435
|
+
return np.array(
|
|
436
|
+
self._proto.int64_data, dtype=_little_endian_dtype(np.int64)
|
|
437
|
+
).reshape(shape)
|
|
438
|
+
if self._proto.uint64_data:
|
|
439
|
+
assert dtype in {
|
|
440
|
+
_enums.DataType.UINT64,
|
|
441
|
+
_enums.DataType.UINT32,
|
|
442
|
+
}, f"Unsupported dtype {dtype} for uint64_data"
|
|
414
443
|
array = np.array(self._proto.uint64_data, dtype=_little_endian_dtype(np.uint64))
|
|
415
|
-
|
|
444
|
+
if dtype == _enums.DataType.UINT32:
|
|
445
|
+
return array.astype(np.uint32).reshape(shape)
|
|
446
|
+
return array.reshape(shape)
|
|
447
|
+
if self._proto.float_data:
|
|
448
|
+
assert dtype in {
|
|
449
|
+
_enums.DataType.FLOAT,
|
|
450
|
+
_enums.DataType.COMPLEX64,
|
|
451
|
+
}, f"Unsupported dtype {dtype} for float_data"
|
|
416
452
|
array = np.array(self._proto.float_data, dtype=_little_endian_dtype(np.float32))
|
|
417
453
|
if dtype == _enums.DataType.COMPLEX64:
|
|
418
|
-
array
|
|
419
|
-
|
|
454
|
+
return array.view(np.complex64).reshape(shape)
|
|
455
|
+
return array.reshape(shape)
|
|
456
|
+
if self._proto.double_data:
|
|
457
|
+
assert dtype in {
|
|
458
|
+
_enums.DataType.DOUBLE,
|
|
459
|
+
_enums.DataType.COMPLEX128,
|
|
460
|
+
}, f"Unsupported dtype {dtype} for double_data"
|
|
420
461
|
array = np.array(self._proto.double_data, dtype=_little_endian_dtype(np.float64))
|
|
421
462
|
if dtype == _enums.DataType.COMPLEX128:
|
|
422
|
-
array
|
|
423
|
-
|
|
424
|
-
|
|
425
|
-
|
|
426
|
-
|
|
427
|
-
return np.array([], dtype=dtype.numpy())
|
|
428
|
-
else:
|
|
429
|
-
# Otherwise we return a size 0 array with the correct shape
|
|
430
|
-
return np.zeros(self._proto.dims, dtype=dtype.numpy())
|
|
431
|
-
|
|
432
|
-
if dtype == _enums.DataType.INT4:
|
|
433
|
-
return _type_casting.unpack_int4(array.astype(np.uint8), self._proto.dims)
|
|
434
|
-
elif dtype == _enums.DataType.UINT4:
|
|
435
|
-
return _type_casting.unpack_uint4(array.astype(np.uint8), self._proto.dims)
|
|
436
|
-
elif dtype == _enums.DataType.FLOAT4E2M1:
|
|
437
|
-
return _type_casting.unpack_float4e2m1(array.astype(np.uint8), self._proto.dims)
|
|
438
|
-
else:
|
|
439
|
-
# Otherwise convert to the correct dtype and reshape
|
|
440
|
-
# Note we cannot use view() here because the storage dtype may not be the same size as the target
|
|
441
|
-
return array.astype(dtype.numpy()).reshape(self._proto.dims)
|
|
463
|
+
return array.view(np.complex128).reshape(shape)
|
|
464
|
+
return array.reshape(shape)
|
|
465
|
+
|
|
466
|
+
# Empty tensor. We return a size 0 array with the correct shape
|
|
467
|
+
return np.zeros(shape, dtype=dtype.numpy())
|
|
442
468
|
|
|
443
469
|
def tobytes(self) -> bytes:
|
|
444
470
|
"""Return the tensor as a byte string conformed to the ONNX specification, in little endian.
|
|
@@ -1,107 +0,0 @@
|
|
|
1
|
-
# Copyright (c) ONNX Project Contributors
|
|
2
|
-
# SPDX-License-Identifier: Apache-2.0
|
|
3
|
-
"""Numpy utilities for non-native type operation."""
|
|
4
|
-
# TODO(justinchuby): Upstream the logic to onnx
|
|
5
|
-
|
|
6
|
-
from __future__ import annotations
|
|
7
|
-
|
|
8
|
-
import typing
|
|
9
|
-
from collections.abc import Sequence
|
|
10
|
-
|
|
11
|
-
import ml_dtypes
|
|
12
|
-
import numpy as np
|
|
13
|
-
|
|
14
|
-
if typing.TYPE_CHECKING:
|
|
15
|
-
import numpy.typing as npt
|
|
16
|
-
|
|
17
|
-
|
|
18
|
-
def pack_4bitx2(array: np.ndarray) -> npt.NDArray[np.uint8]:
|
|
19
|
-
"""Convert a numpy array to flatten, packed int4/uint4. Elements must be in the correct range."""
|
|
20
|
-
# Create a 1D copy
|
|
21
|
-
array_flat = array.ravel().view(np.uint8).copy()
|
|
22
|
-
size = array.size
|
|
23
|
-
odd_sized = size % 2 == 1
|
|
24
|
-
if odd_sized:
|
|
25
|
-
array_flat.resize([size + 1], refcheck=False)
|
|
26
|
-
array_flat &= 0x0F
|
|
27
|
-
array_flat[1::2] <<= 4
|
|
28
|
-
return array_flat[0::2] | array_flat[1::2] # type: ignore[return-type]
|
|
29
|
-
|
|
30
|
-
|
|
31
|
-
def _unpack_uint4_as_uint8(
|
|
32
|
-
data: npt.NDArray[np.uint8], dims: Sequence[int]
|
|
33
|
-
) -> npt.NDArray[np.uint8]:
|
|
34
|
-
"""Convert a packed uint4 array to unpacked uint4 array represented as uint8.
|
|
35
|
-
|
|
36
|
-
Args:
|
|
37
|
-
data: A numpy array.
|
|
38
|
-
dims: The dimensions are used to reshape the unpacked buffer.
|
|
39
|
-
|
|
40
|
-
Returns:
|
|
41
|
-
A numpy array of int8/uint8 reshaped to dims.
|
|
42
|
-
"""
|
|
43
|
-
assert data.dtype == np.uint8, "Input data must be of type uint8"
|
|
44
|
-
result = np.empty([data.size * 2], dtype=data.dtype)
|
|
45
|
-
array_low = data & np.uint8(0x0F)
|
|
46
|
-
array_high = data & np.uint8(0xF0)
|
|
47
|
-
array_high >>= np.uint8(4)
|
|
48
|
-
result[0::2] = array_low
|
|
49
|
-
result[1::2] = array_high
|
|
50
|
-
if result.size == np.prod(dims) + 1:
|
|
51
|
-
# handle single-element padding due to odd number of elements
|
|
52
|
-
result = result[:-1]
|
|
53
|
-
result.resize(dims, refcheck=False)
|
|
54
|
-
return result
|
|
55
|
-
|
|
56
|
-
|
|
57
|
-
def unpack_uint4(
|
|
58
|
-
data: npt.NDArray[np.uint8], dims: Sequence[int]
|
|
59
|
-
) -> npt.NDArray[ml_dtypes.uint4]:
|
|
60
|
-
"""Convert a packed uint4 array to unpacked uint4 array represented as uint8.
|
|
61
|
-
|
|
62
|
-
Args:
|
|
63
|
-
data: A numpy array.
|
|
64
|
-
dims: The dimensions are used to reshape the unpacked buffer.
|
|
65
|
-
|
|
66
|
-
Returns:
|
|
67
|
-
A numpy array of int8/uint8 reshaped to dims.
|
|
68
|
-
"""
|
|
69
|
-
return _unpack_uint4_as_uint8(data, dims).view(ml_dtypes.uint4)
|
|
70
|
-
|
|
71
|
-
|
|
72
|
-
def _extend_int4_sign_bits(x: npt.NDArray[np.uint8]) -> npt.NDArray[np.int8]:
|
|
73
|
-
"""Extend 4-bit signed integer to 8-bit signed integer."""
|
|
74
|
-
return np.where((x >> 3) == 0, x, x | 0xF0).astype(np.int8)
|
|
75
|
-
|
|
76
|
-
|
|
77
|
-
def unpack_int4(
|
|
78
|
-
data: npt.NDArray[np.uint8], dims: Sequence[int]
|
|
79
|
-
) -> npt.NDArray[ml_dtypes.int4]:
|
|
80
|
-
"""Convert a packed (signed) int4 array to unpacked int4 array represented as int8.
|
|
81
|
-
|
|
82
|
-
The sign bit is extended to the most significant bit of the int8.
|
|
83
|
-
|
|
84
|
-
Args:
|
|
85
|
-
data: A numpy array.
|
|
86
|
-
dims: The dimensions are used to reshape the unpacked buffer.
|
|
87
|
-
|
|
88
|
-
Returns:
|
|
89
|
-
A numpy array of int8 reshaped to dims.
|
|
90
|
-
"""
|
|
91
|
-
unpacked = _unpack_uint4_as_uint8(data, dims)
|
|
92
|
-
return _extend_int4_sign_bits(unpacked).view(ml_dtypes.int4)
|
|
93
|
-
|
|
94
|
-
|
|
95
|
-
def unpack_float4e2m1(
|
|
96
|
-
data: npt.NDArray[np.uint8], dims: Sequence[int]
|
|
97
|
-
) -> npt.NDArray[ml_dtypes.float4_e2m1fn]:
|
|
98
|
-
"""Convert a packed float4e2m1 array to unpacked float4e2m1 array.
|
|
99
|
-
|
|
100
|
-
Args:
|
|
101
|
-
data: A numpy array.
|
|
102
|
-
dims: The dimensions are used to reshape the unpacked buffer.
|
|
103
|
-
|
|
104
|
-
Returns:
|
|
105
|
-
A numpy array of float32 reshaped to dims.
|
|
106
|
-
"""
|
|
107
|
-
return _unpack_uint4_as_uint8(data, dims).view(ml_dtypes.float4_e2m1fn)
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
{onnx_ir-0.1.3 → onnx_ir-0.1.4}/src/onnx_ir/passes/common/common_subexpression_elimination.py
RENAMED
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|