compressed-tensors 0.9.5a20250507__py3-none-any.whl → 0.9.5a20250509__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -15,4 +15,5 @@
15
15
 
16
16
  from .base import *
17
17
  from .naive_quantized import *
18
+ from .nvfp4_quantized import *
18
19
  from .pack_quantized import *
@@ -0,0 +1,190 @@
1
+ # Copyright (c) 2021 - present / Neuralmagic, Inc. All Rights Reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing,
10
+ # software distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+
16
+ from typing import Dict, Optional, Tuple
17
+
18
+ import numpy
19
+ import torch
20
+ from compressed_tensors.compressors.base import BaseCompressor
21
+ from compressed_tensors.compressors.quantized_compressors.base import (
22
+ BaseQuantizationCompressor,
23
+ )
24
+ from compressed_tensors.config import CompressionFormat
25
+ from compressed_tensors.quantization import QuantizationArgs
26
+ from compressed_tensors.quantization.lifecycle.forward import dequantize, quantize
27
+ from torch import Tensor
28
+
29
+
30
# Public API of this module. The compressor class is listed alongside the
# packing helpers so that the package-level `from .nvfp4_quantized import *`
# re-exports it too.
__all__ = ["NVFP4PackedCompressor", "pack_fp4_to_uint8", "unpack_fp4_from_uint8"]

# Non-negative magnitudes representable in FP4 (E2M1), ordered by their 3-bit
# magnitude code: position i holds the value encoded by magnitude code i.
# The sign is carried separately in bit 3 of each 4-bit code.
FLOAT_TO_E2M1 = [
    0.0,
    0.5,
    1.0,
    1.5,
    2.0,
    3.0,
    4.0,
    6.0,
]
42
+
43
+
44
@BaseCompressor.register(name=CompressionFormat.nvfp4_pack_quantized.value)
class NVFP4PackedCompressor(BaseQuantizationCompressor):
    """
    Implements compression of FP4 (E2M1) values. Weights of each quantized
    layer are quantized to FP4 and packed two values per byte into uint8.
    Only supports symmetric weight compression for now.
    """

    @property
    def compression_param_names(self) -> Tuple[str, ...]:
        """
        Returns a tuple of compression parameter names introduced by
        the compressor during compression
        """
        return (
            "weight_packed",
            "weight_scale",
            "weight_zero_point",
            "weight_global_scale",
        )

    def compress_weight(
        self,
        weight: Tensor,
        scale: Tensor,
        global_scale: Tensor,
        quantization_args: QuantizationArgs,
        device: Optional[torch.device] = None,
        zero_point: Optional[torch.Tensor] = None,
        g_idx: Optional[torch.Tensor] = None,
    ) -> Dict[str, torch.Tensor]:
        """
        Quantize ``weight`` to FP4 and pack the result into uint8.

        :param weight: dense weight to compress. The quantized result is
            handed to pack_fp4_to_uint8, which needs a 2D tensor with an even
            number of columns (assumes quantize preserves the shape of
            ``weight`` -- TODO confirm)
        :param scale: quantization scale forwarded to ``quantize``
        :param global_scale: global scale forwarded to ``quantize``
        :param quantization_args: quantization arguments forwarded to ``quantize``
        :param device: if given, the packed tensor is moved to this device
        :param zero_point: zero point forwarded to ``quantize``
        :param g_idx: accepted for interface compatibility; not used here
        :return: dict with a single entry, ``"weight_packed"``, holding the
            packed uint8 tensor
        """
        quantized_weight = quantize(
            x=weight,
            scale=scale,
            global_scale=global_scale,
            zero_point=zero_point,
            args=quantization_args,
        )
        weight_packed = pack_fp4_to_uint8(quantized_weight)
        if device is not None:
            weight_packed = weight_packed.to(device)
        return {"weight_packed": weight_packed}

    def decompress_weight(
        self,
        compressed_data: Dict[str, Tensor],
        quantization_args: Optional[QuantizationArgs] = None,
    ) -> torch.Tensor:
        """
        Unpack and dequantize a weight produced by ``compress_weight``.

        :param compressed_data: must contain ``"weight_packed"``,
            ``"weight_scale"`` and ``"weight_global_scale"``
        :param quantization_args: currently unused
        :return: dequantized weight with twice as many columns as the packed
            tensor, in unpack_fp4_from_uint8's default dtype
        """
        packed = compressed_data["weight_packed"]
        scale = compressed_data["weight_scale"]
        global_scale = compressed_data["weight_global_scale"]
        # Each uint8 holds two FP4 values, so the dense width is twice the
        # packed width.
        rows, packed_cols = packed.shape
        # TODO: use a user provided dequant dtype
        unpacked = unpack_fp4_from_uint8(packed, rows, packed_cols * 2)
        decompressed_weight = dequantize(
            x_q=unpacked, scale=scale, global_scale=global_scale, dtype=unpacked.dtype
        )

        return decompressed_weight
106
+
107
+
108
def pack_fp4_to_uint8(x: torch.Tensor) -> torch.Tensor:
    """
    Packs a 2D tensor with values in the fp4 range into uint8.

    There are 16 fp4 (E2M1) codes, so two fp4 values fit in one uint8.
    Each value is mapped to a 4-bit code:
    - bits 0-2 hold the index of its magnitude in FLOAT_TO_E2M1
      (e.g. 0.5 -> 1, 6.0 -> 7);
    - bit 3 holds the sign bit.

    Consecutive pairs along each row are packed into one byte, with the
    first value in the low nibble and the second in the high nibble.

    Magnitudes that are not close (per ``torch.isclose`` defaults) to any
    entry of FLOAT_TO_E2M1 are silently encoded as magnitude code 0, so
    ``x`` is expected to already hold fp4-representable values.

    :param x: tensor of shape (m, n) to pack; n must be even
    :return: packed uint8 tensor of shape (m, n // 2)
    :raises ValueError: if ``x`` is not 2D or its last dimension is odd
    """
    if x.dim() != 2:
        raise ValueError(
            f"Expected a 2D tensor to pack, got shape {tuple(x.shape)}"
        )
    m, n = x.shape
    if n % 2 != 0:
        # Two values share one byte. An odd row length would force pairs to
        # straddle rows, and the (m, n // 2) result could not be formed.
        raise ValueError(
            f"Last dimension must be even to pack fp4 values pairwise, got {n}"
        )
    device = x.device

    # Lookup table of representable magnitudes; position == 3-bit code
    kE2M1 = torch.tensor(FLOAT_TO_E2M1, device=device, dtype=x.dtype)

    # Find the matching magnitude code for each element
    abs_x = torch.abs(x)
    abs_indices = torch.zeros_like(abs_x, dtype=torch.long)
    for i, val in enumerate(kE2M1):
        abs_indices = torch.where(torch.isclose(abs_x, val), i, abs_indices)

    # Set the sign bit (bit 3). signbit also marks -0.0 as negative (code 8).
    indices = abs_indices | (torch.signbit(x).to(torch.long) << 3)

    # n is even, so consecutive row-major elements pair up within each row
    pairs = indices.reshape(-1, 2)

    # First value -> low nibble, second value -> high nibble
    packed = (pairs[:, 0] | (pairs[:, 1] << 4)).to(torch.uint8)

    return packed.reshape(m, n // 2)
151
+
152
+
153
# Magnitudes of the FP4 (E2M1) codes, indexed by the 3-bit magnitude code
# (bits 0-2 of each nibble); the sign lives in bit 3.
kE2M1ToFloat = torch.tensor(
    [0.0, 0.5, 1.0, 1.5, 2.0, 3.0, 4.0, 6.0], dtype=torch.float32
)


# reference: https://github.com/vllm-project/vllm/pull/16362
def unpack_fp4_from_uint8(
    a: torch.Tensor, m: int, n: int, dtype: Optional[torch.dtype] = torch.bfloat16
) -> torch.Tensor:
    """
    Unpacks uint8 values into fp4. Each uint8 holds two fp4 values.
    The low four bits encode the first value, and the high four bits encode
    the consecutive second value. Within each nibble, bits 0-2 index a
    magnitude in kE2M1ToFloat and bit 3 is the sign bit.

    :param a: uint8 tensor to unpack
    :param m: original dim 0 size of the unpacked tensor
    :param n: original dim 1 size of the unpacked tensor; ``m * n`` must
        equal ``2 * a.numel()``
    :param dtype: dense dtype to cast the unpacked tensor to; ``None`` keeps
        the float32 lookup result
    :return: unpacked tensor of shape (m, n)
    :raises ValueError: if ``a`` is not a uint8 tensor
    """
    # Explicit check instead of assert, which is stripped under `python -O`
    if a.dtype != torch.uint8:
        raise ValueError(f"Expected a uint8 tensor to unpack, got {a.dtype}")

    # Split every byte into its two nibbles
    flat = a.flatten()
    low = flat & 0x0F
    high = (flat & 0xF0) >> 4

    # Interleave so each byte expands to (low, high), matching the pack order
    codes = torch.stack((low, high), dim=1).flatten()

    # Separate sign bit and magnitude code
    is_negative = (codes & 0x08).to(torch.bool)
    magnitude_idx = (codes & 0x07).to(torch.long)

    # Device-aware lookup, then apply the sign (negative zero stays -0.0)
    table = kE2M1ToFloat.to(device=a.device)
    magnitudes = table[magnitude_idx]
    values = torch.where(is_negative, -magnitudes, magnitudes).reshape(m, n)

    if dtype is not None:
        values = values.to(dtype=dtype)
    return values
@@ -32,6 +32,7 @@ class CompressionFormat(Enum):
32
32
  naive_quantized = "naive-quantized"
33
33
  pack_quantized = "pack-quantized"
34
34
  marlin_24 = "marlin-24"
35
+ nvfp4_pack_quantized = "nvfp4-pack-quantized"
35
36
 
36
37
 
37
38
  @unique
@@ -17,5 +17,5 @@ __version__: str
17
17
  __version_tuple__: VERSION_TUPLE
18
18
  version_tuple: VERSION_TUPLE
19
19
 
20
- __version__ = version = '0.9.5.a20250507'
20
+ __version__ = version = '0.9.5.a20250509'
21
21
  __version_tuple__ = version_tuple = (0, 9, 5)
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.4
2
2
  Name: compressed-tensors
3
- Version: 0.9.5a20250507
3
+ Version: 0.9.5a20250509
4
4
  Summary: Library for utilization of compressed safetensors of neural network models
5
5
  Home-page: https://github.com/neuralmagic/compressed-tensors
6
6
  Author: Neuralmagic, Inc.
@@ -1,14 +1,15 @@
1
1
  compressed_tensors/__init__.py,sha256=UtKmifNeBCSE2TZSAfduVNNzHY-3V7bLjZ7n7RuXLOE,812
2
2
  compressed_tensors/base.py,sha256=73HYH7HY7O2roC89yG_piPFnZwrBfn_i7HmKl90SKc0,875
3
- compressed_tensors/version.py,sha256=9uyHdJcjtC13iNJagGPGRejL7_ESUhf48c_dyUk-TGw,521
3
+ compressed_tensors/version.py,sha256=MJNvDWAetuDB0OYncH2iIZtFhxnO8XXV5IHnTcj9e6k,521
4
4
  compressed_tensors/compressors/__init__.py,sha256=smSygTSfcfuujRrAXDc6uZm4L_ccV1tWZewqVnOb4lM,825
5
5
  compressed_tensors/compressors/base.py,sha256=nvWsv4xEw1Tkxkxth6TmHplDYXfBeP22xWxOsZERyDY,7204
6
6
  compressed_tensors/compressors/helpers.py,sha256=OK6qxX9j3bHwF9JfIYSGMgBJe2PWjlTA3byXKCJaTIQ,5431
7
7
  compressed_tensors/compressors/model_compressors/__init__.py,sha256=5RGGPFu4YqEt_aOdFSQYFYFDjcZFJN0CsMqRtDZz3Js,666
8
8
  compressed_tensors/compressors/model_compressors/model_compressor.py,sha256=uh3Rbyqhjvt8o8On6ioOn6utBKv2siRRmAvgM1lDrxU,26555
9
- compressed_tensors/compressors/quantized_compressors/__init__.py,sha256=09UJq68Pht6Bf-4iP9xYl3tetKsncNPHD8IAGbePsr4,714
9
+ compressed_tensors/compressors/quantized_compressors/__init__.py,sha256=KvaFBL_Q84LxRGJOV035M8OBoCkAx8kOkfphswgkKWk,745
10
10
  compressed_tensors/compressors/quantized_compressors/base.py,sha256=n0L2QH2_Y1vWtLeQ0uV78y2lV4bviFEAtUKODl8L_nw,8828
11
11
  compressed_tensors/compressors/quantized_compressors/naive_quantized.py,sha256=fd0KlkSx6bvZ3xwIkK3jEUdPSUPs56Eua4dEDOtzKW0,5150
12
+ compressed_tensors/compressors/quantized_compressors/nvfp4_quantized.py,sha256=Gw-lVzk5jrKUlM5UTCiJBmhM5gHzB9mn8r298MVUbDI,6395
12
13
  compressed_tensors/compressors/quantized_compressors/pack_quantized.py,sha256=SPIHlk8ewip2LcjgkCw02K21EkfUSFSd9qQqL0Pt5eM,11162
13
14
  compressed_tensors/compressors/sparse_compressors/__init__.py,sha256=Atuz-OdEgn8OCUhx7Ovd6gXdyImAI186uCR-uR0t_Nk,737
14
15
  compressed_tensors/compressors/sparse_compressors/base.py,sha256=PMiWIaW2XSF_esYJlQ12RVW7opeAzavdbkRFtelMFX0,6655
@@ -18,7 +19,7 @@ compressed_tensors/compressors/sparse_compressors/sparse_bitmask.py,sha256=S8vW0
18
19
  compressed_tensors/compressors/sparse_quantized_compressors/__init__.py,sha256=4f_cwcKXB1nVVMoiKgTFAc8jAPjPLElo-Df_EDm1_xw,675
19
20
  compressed_tensors/compressors/sparse_quantized_compressors/marlin_24.py,sha256=7VRLmtUTg1iJl3mXiOzLPi1RgIOhMISPAwzVi8v2QF0,9951
20
21
  compressed_tensors/config/__init__.py,sha256=8sOoZ6xvYSC79mBvEtO8l6xk4PC80d29AnnJiGMrY2M,737
21
- compressed_tensors/config/base.py,sha256=R3iUmFf1MslEjin5LgwQbmfJHIsS7Uw0UIxfn780uqY,3479
22
+ compressed_tensors/config/base.py,sha256=p3glQHvC2fjodf_SvlelVrTWSIjGXgGC86t8oVOlMng,3529
22
23
  compressed_tensors/config/dense.py,sha256=NgSxnFCnckU9-iunxEaqiFwqgdO7YYxlWKR74jNbjks,1317
23
24
  compressed_tensors/config/sparse_24_bitmask.py,sha256=Lhj39zT2V1hxftprvxvneyhv45ShlXOKd75DBbDTyTE,1401
24
25
  compressed_tensors/config/sparse_bitmask.py,sha256=pZUboRNZTu6NajGOQEFExoPknak5ynVAUeiiYpS1Gt8,1308
@@ -45,8 +46,8 @@ compressed_tensors/utils/permutations_24.py,sha256=kx6fsfDHebx94zsSzhXGyCyuC9sVy
45
46
  compressed_tensors/utils/permute.py,sha256=V6tJLKo3Syccj-viv4F7ZKZgJeCB-hl-dK8RKI_kBwI,2355
46
47
  compressed_tensors/utils/safetensors_load.py,sha256=kkkUDmS1H40MFy6FDP-DFGiAYbtqke6bKE7YrAtORtA,11499
47
48
  compressed_tensors/utils/semi_structured_conversions.py,sha256=XKNffPum54kPASgqKzgKvyeqWPAkair2XEQXjkp7ho8,13489
48
- compressed_tensors-0.9.5a20250507.dist-info/licenses/LICENSE,sha256=xx0jnfkXJvxRnG63LTGOxlggYnIysveWIZ6H3PNdCrQ,11357
49
- compressed_tensors-0.9.5a20250507.dist-info/METADATA,sha256=DUcVYkCy5Fa5ayrcaz_7mJ1XjvIMOlAFkVkOQMaClrE,7004
50
- compressed_tensors-0.9.5a20250507.dist-info/WHEEL,sha256=0CuiUZ_p9E4cD6NyLD6UG80LBXYyiSYZOKDm5lp32xk,91
51
- compressed_tensors-0.9.5a20250507.dist-info/top_level.txt,sha256=w2i-GyPs2s1UwVxvutSvN_lM22SXC2hQFBmoMcPnV7Y,19
52
- compressed_tensors-0.9.5a20250507.dist-info/RECORD,,
49
+ compressed_tensors-0.9.5a20250509.dist-info/licenses/LICENSE,sha256=xx0jnfkXJvxRnG63LTGOxlggYnIysveWIZ6H3PNdCrQ,11357
50
+ compressed_tensors-0.9.5a20250509.dist-info/METADATA,sha256=2BVor0VJtcZHud4QBjYS5OhLXob0nF9dmeP08RUSA5k,7004
51
+ compressed_tensors-0.9.5a20250509.dist-info/WHEEL,sha256=DnLRTWE75wApRYVsjgc6wsVswC54sMSJhAEd4xhDpBk,91
52
+ compressed_tensors-0.9.5a20250509.dist-info/top_level.txt,sha256=w2i-GyPs2s1UwVxvutSvN_lM22SXC2hQFBmoMcPnV7Y,19
53
+ compressed_tensors-0.9.5a20250509.dist-info/RECORD,,
@@ -1,5 +1,5 @@
1
1
  Wheel-Version: 1.0
2
- Generator: setuptools (80.3.1)
2
+ Generator: setuptools (80.4.0)
3
3
  Root-Is-Purelib: true
4
4
  Tag: py3-none-any
5
5