compressed-tensors 0.9.5a20250507__py3-none-any.whl → 0.9.5a20250512__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
--- a/compressed_tensors/compressors/quantized_compressors/__init__.py
+++ b/compressed_tensors/compressors/quantized_compressors/__init__.py
@@ -15,4 +15,5 @@
 
 from .base import *
 from .naive_quantized import *
+from .nvfp4_quantized import *
 from .pack_quantized import *
--- /dev/null
+++ b/compressed_tensors/compressors/quantized_compressors/nvfp4_quantized.py
@@ -0,0 +1,190 @@
+# Copyright (c) 2021 - present / Neuralmagic, Inc. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#    http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+
+from typing import Dict, Optional, Tuple
+
+import numpy
+import torch
+from compressed_tensors.compressors.base import BaseCompressor
+from compressed_tensors.compressors.quantized_compressors.base import (
+    BaseQuantizationCompressor,
+)
+from compressed_tensors.config import CompressionFormat
+from compressed_tensors.quantization import QuantizationArgs
+from compressed_tensors.quantization.lifecycle.forward import dequantize, quantize
+from torch import Tensor
+
+
+__all__ = ["pack_fp4_to_uint8", "unpack_fp4_from_uint8"]
+
+FLOAT_TO_E2M1 = [
+    0.0,
+    0.5,
+    1.0,
+    1.5,
+    2.0,
+    3.0,
+    4.0,
+    6.0,
+]
+
+
+@BaseCompressor.register(name=CompressionFormat.nvfp4_pack_quantized.value)
+class NVFP4PackedCompressor(BaseQuantizationCompressor):
+    """
+    Implements compression of FP4 values. Weights of each quantized layer
+    are packed into uint8. Only supports symmetric weight compression for now.
+    """
+
+    @property
+    def compression_param_names(self) -> Tuple[str]:
+        """
+        Returns a tuple of compression parameter names introduced by
+        the compressor during compression
+        """
+        return (
+            "weight_packed",
+            "weight_scale",
+            "weight_zero_point",
+            "weight_global_scale",
+        )
+
+    def compress_weight(
+        self,
+        weight: Tensor,
+        scale: Tensor,
+        global_scale: Tensor,
+        quantization_args: QuantizationArgs,
+        device: Optional[torch.device] = None,
+        zero_point: Optional[torch.Tensor] = None,
+        g_idx: Optional[torch.Tensor] = None,
+    ) -> Dict[str, torch.Tensor]:
+
+        quantized_weight = quantize(
+            x=weight,
+            scale=scale,
+            global_scale=global_scale,
+            zero_point=zero_point,
+            args=quantization_args,
+        )
+        compressed_dict = {}
+        weight_packed = pack_fp4_to_uint8(quantized_weight)
+        if device is not None:
+            weight_packed = weight_packed.to(device)
+        compressed_dict["weight_packed"] = weight_packed
+        return compressed_dict
+
+    def decompress_weight(
+        self,
+        compressed_data: Dict[str, Tensor],
+        quantization_args: Optional[QuantizationArgs] = None,
+    ) -> torch.Tensor:
+
+        weight = compressed_data["weight_packed"]
+        scale = compressed_data["weight_scale"]
+        global_scale = compressed_data["weight_global_scale"]
+        m, n = weight.shape
+        # TODO: use a user provided dequant dtype
+        unpacked = unpack_fp4_from_uint8(weight, m, n * 2)
+        decompressed_weight = dequantize(
+            x_q=unpacked, scale=scale, global_scale=global_scale, dtype=unpacked.dtype
+        )
+
+        return decompressed_weight
+
+
+def pack_fp4_to_uint8(x: torch.Tensor) -> torch.Tensor:
+    """
+    Packs a tensor with values in the fp4 range into uint8.
+    As there are 16 valid fp4 values, two fp4 values can be
+    packed into one uint8. Each fp4 value is mapped to its
+    particular index (e.g. 0.5 is mapped to index 1, 6.0 is mapped
+    to index 7) which is then represented using 4 bits. Consecutive
+    pairs of 4 bits are then packed into a uint8.
+
+    :param x: tensor to pack
+    :return: a packed tensor in uint8
+    """
+
+    m, n = x.shape
+    device = x.device
+
+    # Create lookup table for FP4 values to indices
+    # Map the absolute values to 0-7 indices
+    kE2M1 = torch.tensor(FLOAT_TO_E2M1, device=device, dtype=x.dtype)
+
+    # Find closest valid FP4 value index for each element
+    abs_x = torch.abs(x)
+    abs_indices = torch.zeros_like(abs_x, dtype=torch.long)
+    for i, val in enumerate(kE2M1):
+        abs_indices = torch.where(torch.isclose(abs_x, val), i, abs_indices)
+
+    # Apply sign bit (bit 3) to get final 4-bit representation
+    indices = abs_indices + (torch.signbit(x) << 3).to(torch.long)
+
+    # Reshape to prepare for packing pairs of values
+    indices = indices.reshape(-1)
+
+    # Handle odd length by padding if necessary
+    if indices.numel() % 2 != 0:
+        indices = torch.cat([indices, torch.zeros(1, dtype=torch.long, device=device)])
+
+    # Reshape to pair consecutive elements
+    indices = indices.reshape(-1, 2)
+
+    # Pack pairs of 4-bit values into 8-bit values
+    packed = (indices[:, 0] | (indices[:, 1] << 4)).to(torch.uint8)
+
+    return packed.reshape(m, n // 2)
+
+
+kE2M1ToFloat = torch.tensor(
+    [0.0, 0.5, 1.0, 1.5, 2.0, 3.0, 4.0, 6.0], dtype=torch.float32
+)
+
+
+# reference: https://github.com/vllm-project/vllm/pull/16362
+def unpack_fp4_from_uint8(
+    a: torch.Tensor, m: int, n: int, dtype: Optional[torch.dtype] = torch.bfloat16
+) -> torch.Tensor:
+    """
+    Unpacks uint8 values into fp4. Each uint8 consists of two fp4 values
+    (i.e. the first four bits correspond to one fp4 value, the last four
+    to a consecutive fp4 value). The bits represent an index, which is mapped to an fp4 value.
+
+    :param a: tensor to unpack
+    :param m: original dim 0 size of the unpacked tensor
+    :param n: original dim 1 size of the unpacked tensor
+    :param dtype: dense dtype to cast the unpacked tensor to
+    """
+    assert a.dtype == torch.uint8
+
+    # Vectorized nibble processing
+    a_flat = a.flatten()
+    high = (a_flat & 0xF0) >> 4  # Upper nibbles
+    low = a_flat & 0x0F  # Lower nibbles
+
+    # Combine nibbles for batch processing
+    combined = torch.stack((low, high), dim=1).flatten()
+
+    # Vectorized sign and magnitude extraction
+    signs = (combined & 0x08).to(torch.bool)  # Sign bits
+    abs_vals = (combined & 0x07).to(torch.long)  # Magnitude indices
+
+    # Device-aware lookup and sign application
+    kE2M1 = kE2M1ToFloat.to(device=a.device)
+    values = kE2M1[abs_vals] * torch.where(signs, -1.0, 1.0)
+
+    # Reshape to final form
+    return values.reshape(m, n).to(dtype=dtype)
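The two packing helpers above are pure functions over plain tensors, so the layout is easy to verify in isolation. A minimal round-trip sketch, not part of the diff; only pack_fp4_to_uint8 and unpack_fp4_from_uint8 come from the new module, the rest is plain torch:

    import torch
    from compressed_tensors.compressors.quantized_compressors.nvfp4_quantized import (
        pack_fp4_to_uint8,
        unpack_fp4_from_uint8,
    )

    # Values already on the E2M1 grid, as quantize() would produce them
    x = torch.tensor([[0.5, -6.0, 1.0, 0.0], [2.0, 3.0, -1.5, 4.0]])
    packed = pack_fp4_to_uint8(x)  # dtype uint8, shape (2, 2): two values per byte
    restored = unpack_fp4_from_uint8(packed, 2, 4, dtype=torch.float32)
    assert torch.equal(restored, x)  # round trip is lossless on the FP4 grid

This is also why decompress_weight calls unpack_fp4_from_uint8(weight, m, n * 2): packing halves dim 1, so the original column count is twice the packed one.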
--- a/compressed_tensors/config/base.py
+++ b/compressed_tensors/config/base.py
@@ -32,6 +32,7 @@ class CompressionFormat(Enum):
     naive_quantized = "naive-quantized"
     pack_quantized = "pack-quantized"
     marlin_24 = "marlin-24"
+    nvfp4_pack_quantized = "nvfp4-pack-quantized"
 
 
 @unique
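With the new enum member in place, the format string can be resolved back to the compressor class. A sketch, assuming the load_from_registry helper that BaseCompressor inherits from the project's RegistryMixin; the decorator in nvfp4_quantized.py registers under this exact string:

    from compressed_tensors.compressors.base import BaseCompressor
    from compressed_tensors.config import CompressionFormat

    fmt = CompressionFormat.nvfp4_pack_quantized.value  # "nvfp4-pack-quantized"
    compressor = BaseCompressor.load_from_registry(fmt)
    print(type(compressor).__name__)  # NVFP4PackedCompressor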
--- a/compressed_tensors/quantization/quant_args.py
+++ b/compressed_tensors/quantization/quant_args.py
@@ -24,6 +24,8 @@ from pydantic import BaseModel, Field, field_validator, model_validator
 
 __all__ = [
     "FP8_DTYPE",
+    "FP8_E4M3_DATA",
+    "FP4_E2M1_DATA",
     "QuantizationType",
     "QuantizationStrategy",
     "QuantizationArgs",
@@ -31,6 +33,48 @@ __all__ = [
     "ActivationOrdering",
 ]
 
+
+class FloatArgs:
+    exponent: int
+    mantissa: int
+    bits: int
+    max: float
+    min: float
+    dtype: Optional[torch.dtype] = None
+
+
+class FP4_E2M1_DATA(FloatArgs):
+    exponent = 2
+    mantissa = 1
+    bits = 4
+    max = 6.0
+    min = -6.0
+
+    @staticmethod
+    def cast_to_fp4(x):
+        sign = torch.sign(x)
+        x = torch.abs(x)
+        x[(x >= 0.0) & (x <= 0.25)] = 0.0
+        x[(x > 0.25) & (x < 0.75)] = 0.5
+        x[(x >= 0.75) & (x <= 1.25)] = 1.0
+        x[(x > 1.25) & (x < 1.75)] = 1.5
+        x[(x >= 1.75) & (x <= 2.5)] = 2.0
+        x[(x > 2.5) & (x < 3.5)] = 3.0
+        x[(x >= 3.5) & (x <= 5.0)] = 4.0
+        x[x > 5.0] = 6.0
+        return x * sign
+
+
+class FP8_E4M3_DATA(FloatArgs):
+    exponent = 4
+    mantissa = 3
+    bits = 8
+    max = torch.finfo(torch.float8_e4m3fn).max
+    min = torch.finfo(torch.float8_e4m3fn).min
+    dtype = torch.float8_e4m3fn
+
+
+# TODO: Remove soon in favour of a more descriptive FloatArgs
 FP8_DTYPE = torch.float8_e4m3fn
 
 
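The chain of masked assignments in cast_to_fp4 is nearest-value rounding onto the E2M1 grid {0, 0.5, 1, 1.5, 2, 3, 4, 6}, with ties at the midpoints resolved toward the value whose 4-bit encoding is even (consistent with round-half-to-even) and magnitudes above 5.0 clamped to the FP4 max of 6.0. Note that torch.abs returns a new tensor, so the in-place masked writes do not touch the caller's input. A small check:

    import torch
    from compressed_tensors.quantization.quant_args import FP4_E2M1_DATA

    # Midpoints between adjacent E2M1 values, plus one clamped and one negative value
    x = torch.tensor([0.25, 0.75, 1.25, 1.75, 2.5, 3.5, 5.0, -5.0, 7.0])
    print(FP4_E2M1_DATA.cast_to_fp4(x))
    # tensor([ 0.,  1.,  1.,  2.,  2.,  4.,  4., -4.,  6.])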
@@ -234,7 +278,10 @@ class QuantizationArgs(BaseModel, use_enum_values=True):
 
     def pytorch_dtype(self) -> torch.dtype:
         if self.type == QuantizationType.FLOAT:
-            return FP8_DTYPE
+            if self.num_bits == 8:
+                return FP8_E4M3_DATA.dtype
+            else:
+                raise NotImplementedError("Only num_bits in (8) are supported")
         elif self.type == QuantizationType.INT:
             if self.num_bits <= 8:
                 return torch.int8
@@ -263,7 +310,12 @@ def round_to_quantized_type(
     """
     original_dtype = tensor.dtype
    if args.type == QuantizationType.FLOAT:
-        rounded = tensor.to(FP8_DTYPE)
+        if args.num_bits == 8:
+            rounded = tensor.to(FP8_E4M3_DATA.dtype)
+        elif args.num_bits == 4:
+            rounded = FP4_E2M1_DATA.cast_to_fp4(tensor)
+        else:
+            raise NotImplementedError("Only num_bits in (4, 8) are supported")
     elif args.type == QuantizationType.INT:
         rounded = torch.round(tensor)
     else:
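With the new branch, 4-bit float args are snapped onto the E2M1 grid via cast_to_fp4 rather than cast to a torch dtype (there is no native 4-bit torch dtype). A sketch of the behavior, importing round_to_quantized_type from its defining module since it is not in __all__, and mirroring the argument shape the NVFP4A16 preset uses:

    import torch
    from compressed_tensors.quantization.quant_args import (
        QuantizationArgs,
        QuantizationStrategy,
        QuantizationType,
        round_to_quantized_type,
    )

    args = QuantizationArgs(
        num_bits=4,
        type=QuantizationType.FLOAT,
        strategy=QuantizationStrategy.GROUP,
        group_size=16,
    )
    w = torch.tensor([0.3, 1.6, -2.4, 5.5])
    print(round_to_quantized_type(w, args))  # tensor([ 0.5000,  1.5000, -2.0000,  6.0000])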
--- a/compressed_tensors/quantization/quant_scheme.py
+++ b/compressed_tensors/quantization/quant_scheme.py
@@ -100,6 +100,17 @@ def is_preset_scheme(name: str) -> bool:
 
 UNQUANTIZED = dict()
 
+NVFP4A16 = dict(
+    weights=QuantizationArgs(
+        num_bits=4,
+        type=QuantizationType.FLOAT,
+        strategy=QuantizationStrategy.GROUP,
+        symmetric=True,
+        dynamic=False,
+        group_size=16,
+    )
+)
+
 # 8 bit integer weights and 8 bit activations quantization
 INT8_W8A8 = dict(
     weights=QuantizationArgs(
@@ -225,4 +236,5 @@ PRESET_SCHEMES = {
     # Float weight and activation schemes
     "FP8": FP8,
     "FP8_DYNAMIC": FP8_DYNAMIC,
+    "NVFP4A16": NVFP4A16,
 }
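The new preset is weight-only: 4-bit floating point weights quantized symmetrically in groups of 16, with activations left in high precision (hence A16). A usage sketch, assuming the preset_name_to_scheme helper defined alongside PRESET_SCHEMES in quant_scheme.py:

    from compressed_tensors.quantization import preset_name_to_scheme

    scheme = preset_name_to_scheme("NVFP4A16", targets=["Linear"])
    print(scheme.weights.num_bits, scheme.weights.group_size)  # 4 16
    print(scheme.input_activations)  # None: activations are not quantized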
--- a/compressed_tensors/version.py
+++ b/compressed_tensors/version.py
@@ -17,5 +17,5 @@ __version__: str
 __version_tuple__: VERSION_TUPLE
 version_tuple: VERSION_TUPLE
 
-__version__ = version = '0.9.5.a20250507'
+__version__ = version = '0.9.5.a20250512'
 __version_tuple__ = version_tuple = (0, 9, 5)
--- a/compressed_tensors-0.9.5a20250507.dist-info/METADATA
+++ b/compressed_tensors-0.9.5a20250512.dist-info/METADATA
@@ -1,6 +1,6 @@
 Metadata-Version: 2.4
 Name: compressed-tensors
-Version: 0.9.5a20250507
+Version: 0.9.5a20250512
 Summary: Library for utilization of compressed safetensors of neural network models
 Home-page: https://github.com/neuralmagic/compressed-tensors
 Author: Neuralmagic, Inc.
--- a/compressed_tensors-0.9.5a20250507.dist-info/RECORD
+++ b/compressed_tensors-0.9.5a20250512.dist-info/RECORD
@@ -1,14 +1,15 @@
 compressed_tensors/__init__.py,sha256=UtKmifNeBCSE2TZSAfduVNNzHY-3V7bLjZ7n7RuXLOE,812
 compressed_tensors/base.py,sha256=73HYH7HY7O2roC89yG_piPFnZwrBfn_i7HmKl90SKc0,875
-compressed_tensors/version.py,sha256=9uyHdJcjtC13iNJagGPGRejL7_ESUhf48c_dyUk-TGw,521
+compressed_tensors/version.py,sha256=Yd5MPtXGvm9XWCPM_O99KCK_9DwKmcl2OenuJ4MxlUI,521
 compressed_tensors/compressors/__init__.py,sha256=smSygTSfcfuujRrAXDc6uZm4L_ccV1tWZewqVnOb4lM,825
 compressed_tensors/compressors/base.py,sha256=nvWsv4xEw1Tkxkxth6TmHplDYXfBeP22xWxOsZERyDY,7204
 compressed_tensors/compressors/helpers.py,sha256=OK6qxX9j3bHwF9JfIYSGMgBJe2PWjlTA3byXKCJaTIQ,5431
 compressed_tensors/compressors/model_compressors/__init__.py,sha256=5RGGPFu4YqEt_aOdFSQYFYFDjcZFJN0CsMqRtDZz3Js,666
 compressed_tensors/compressors/model_compressors/model_compressor.py,sha256=uh3Rbyqhjvt8o8On6ioOn6utBKv2siRRmAvgM1lDrxU,26555
-compressed_tensors/compressors/quantized_compressors/__init__.py,sha256=09UJq68Pht6Bf-4iP9xYl3tetKsncNPHD8IAGbePsr4,714
+compressed_tensors/compressors/quantized_compressors/__init__.py,sha256=KvaFBL_Q84LxRGJOV035M8OBoCkAx8kOkfphswgkKWk,745
 compressed_tensors/compressors/quantized_compressors/base.py,sha256=n0L2QH2_Y1vWtLeQ0uV78y2lV4bviFEAtUKODl8L_nw,8828
 compressed_tensors/compressors/quantized_compressors/naive_quantized.py,sha256=fd0KlkSx6bvZ3xwIkK3jEUdPSUPs56Eua4dEDOtzKW0,5150
+compressed_tensors/compressors/quantized_compressors/nvfp4_quantized.py,sha256=Gw-lVzk5jrKUlM5UTCiJBmhM5gHzB9mn8r298MVUbDI,6395
 compressed_tensors/compressors/quantized_compressors/pack_quantized.py,sha256=SPIHlk8ewip2LcjgkCw02K21EkfUSFSd9qQqL0Pt5eM,11162
 compressed_tensors/compressors/sparse_compressors/__init__.py,sha256=Atuz-OdEgn8OCUhx7Ovd6gXdyImAI186uCR-uR0t_Nk,737
 compressed_tensors/compressors/sparse_compressors/base.py,sha256=PMiWIaW2XSF_esYJlQ12RVW7opeAzavdbkRFtelMFX0,6655
@@ -18,16 +19,16 @@ compressed_tensors/compressors/sparse_compressors/sparse_bitmask.py,sha256=S8vW0
 compressed_tensors/compressors/sparse_quantized_compressors/__init__.py,sha256=4f_cwcKXB1nVVMoiKgTFAc8jAPjPLElo-Df_EDm1_xw,675
 compressed_tensors/compressors/sparse_quantized_compressors/marlin_24.py,sha256=7VRLmtUTg1iJl3mXiOzLPi1RgIOhMISPAwzVi8v2QF0,9951
 compressed_tensors/config/__init__.py,sha256=8sOoZ6xvYSC79mBvEtO8l6xk4PC80d29AnnJiGMrY2M,737
-compressed_tensors/config/base.py,sha256=R3iUmFf1MslEjin5LgwQbmfJHIsS7Uw0UIxfn780uqY,3479
+compressed_tensors/config/base.py,sha256=p3glQHvC2fjodf_SvlelVrTWSIjGXgGC86t8oVOlMng,3529
 compressed_tensors/config/dense.py,sha256=NgSxnFCnckU9-iunxEaqiFwqgdO7YYxlWKR74jNbjks,1317
 compressed_tensors/config/sparse_24_bitmask.py,sha256=Lhj39zT2V1hxftprvxvneyhv45ShlXOKd75DBbDTyTE,1401
 compressed_tensors/config/sparse_bitmask.py,sha256=pZUboRNZTu6NajGOQEFExoPknak5ynVAUeiiYpS1Gt8,1308
 compressed_tensors/linear/__init__.py,sha256=fH6rjBYAxuwrTzBTlTjTgCYNyh6TCvCqajCz4Im4YrA,617
 compressed_tensors/linear/compressed_linear.py,sha256=_m6XpNcI53eeSHO8VdiuAM6UBTdpDhn5Ivd8iRMwEKc,3980
 compressed_tensors/quantization/__init__.py,sha256=83J5bPB7PavN2TfCoW7_vEDhfYpm4TDrqYO9vdSQ5bk,760
-compressed_tensors/quantization/quant_args.py,sha256=sKpb8DcNObidjXjNol1Tn_Iih3ZXBycSp-fyz68TGhY,9117
+compressed_tensors/quantization/quant_args.py,sha256=2m4WJWBnNjkU-3rVR_f2a6p_BZMvGfMrPyOui8JUwWk,10487
 compressed_tensors/quantization/quant_config.py,sha256=MxSUcb5dOqMN6LFyD5K2h8X0TvEtcWIAoiUJqD2dHGE,10159
-compressed_tensors/quantization/quant_scheme.py,sha256=yz0oMbbwp7QZXXd2k5KIJu-Q6aTqg2929VdUzZ7vysM,6324
+compressed_tensors/quantization/quant_scheme.py,sha256=0FpN3R7bVn8rQ18Vp0NuDVpoilTZ7X8vk9zp_8AndwY,6578
 compressed_tensors/quantization/lifecycle/__init__.py,sha256=_uItzFWusyV74Zco_pHLOTdE9a83cL-R-ZdyQrBkIyw,772
 compressed_tensors/quantization/lifecycle/apply.py,sha256=DOoxH4jM8r0270GGGUFOpRrgwaisiJi7TV-Q6E8qM8E,18067
 compressed_tensors/quantization/lifecycle/compressed.py,sha256=Fj9n66IN0EWsOAkBHg3O0GlOQpxstqjCcs0ttzMXrJ0,2296
@@ -45,8 +46,8 @@ compressed_tensors/utils/permutations_24.py,sha256=kx6fsfDHebx94zsSzhXGyCyuC9sVy
 compressed_tensors/utils/permute.py,sha256=V6tJLKo3Syccj-viv4F7ZKZgJeCB-hl-dK8RKI_kBwI,2355
 compressed_tensors/utils/safetensors_load.py,sha256=kkkUDmS1H40MFy6FDP-DFGiAYbtqke6bKE7YrAtORtA,11499
 compressed_tensors/utils/semi_structured_conversions.py,sha256=XKNffPum54kPASgqKzgKvyeqWPAkair2XEQXjkp7ho8,13489
-compressed_tensors-0.9.5a20250507.dist-info/licenses/LICENSE,sha256=xx0jnfkXJvxRnG63LTGOxlggYnIysveWIZ6H3PNdCrQ,11357
-compressed_tensors-0.9.5a20250507.dist-info/METADATA,sha256=DUcVYkCy5Fa5ayrcaz_7mJ1XjvIMOlAFkVkOQMaClrE,7004
-compressed_tensors-0.9.5a20250507.dist-info/WHEEL,sha256=0CuiUZ_p9E4cD6NyLD6UG80LBXYyiSYZOKDm5lp32xk,91
-compressed_tensors-0.9.5a20250507.dist-info/top_level.txt,sha256=w2i-GyPs2s1UwVxvutSvN_lM22SXC2hQFBmoMcPnV7Y,19
-compressed_tensors-0.9.5a20250507.dist-info/RECORD,,
+compressed_tensors-0.9.5a20250512.dist-info/licenses/LICENSE,sha256=xx0jnfkXJvxRnG63LTGOxlggYnIysveWIZ6H3PNdCrQ,11357
+compressed_tensors-0.9.5a20250512.dist-info/METADATA,sha256=gArKa7gy0jdBGF5PbuNLVh_ZmnXq4CdEp1-7grxpjnw,7004
+compressed_tensors-0.9.5a20250512.dist-info/WHEEL,sha256=DnLRTWE75wApRYVsjgc6wsVswC54sMSJhAEd4xhDpBk,91
+compressed_tensors-0.9.5a20250512.dist-info/top_level.txt,sha256=w2i-GyPs2s1UwVxvutSvN_lM22SXC2hQFBmoMcPnV7Y,19
+compressed_tensors-0.9.5a20250512.dist-info/RECORD,,
@@ -1,5 +1,5 @@
1
1
  Wheel-Version: 1.0
2
- Generator: setuptools (80.3.1)
2
+ Generator: setuptools (80.4.0)
3
3
  Root-Is-Purelib: true
4
4
  Tag: py3-none-any
5
5