compressed-tensors 0.9.5a20250509__py3-none-any.whl → 0.9.5a20250512__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -24,6 +24,8 @@ from pydantic import BaseModel, Field, field_validator, model_validator
24
24
 
25
25
  __all__ = [
26
26
  "FP8_DTYPE",
27
+ "FP8_E4M3_DATA",
28
+ "FP4_E2M1_DATA",
27
29
  "QuantizationType",
28
30
  "QuantizationStrategy",
29
31
  "QuantizationArgs",
@@ -31,6 +33,48 @@ __all__ = [
31
33
  "ActivationOrdering",
32
34
  ]
33
35
 
36
+
37
+ class FloatArgs:
38
+ exponent: int
39
+ mantissa: int
40
+ bits: int
41
+ max: float
42
+ min: float
43
+ dtype: Optional[torch.dtype] = None
44
+
45
+
46
+ class FP4_E2M1_DATA(FloatArgs):
47
+ exponent = 2
48
+ mantissa = 1
49
+ bits = 4
50
+ max = 6.0
51
+ min = -6.0
52
+
53
+ @staticmethod
54
+ def cast_to_fp4(x):
55
+ sign = torch.sign(x)
56
+ x = torch.abs(x)
57
+ x[(x >= 0.0) & (x <= 0.25)] = 0.0
58
+ x[(x > 0.25) & (x < 0.75)] = 0.5
59
+ x[(x >= 0.75) & (x <= 1.25)] = 1.0
60
+ x[(x > 1.25) & (x < 1.75)] = 1.5
61
+ x[(x >= 1.75) & (x <= 2.5)] = 2.0
62
+ x[(x > 2.5) & (x < 3.5)] = 3.0
63
+ x[(x >= 3.5) & (x <= 5.0)] = 4.0
64
+ x[x > 5.0] = 6.0
65
+ return x * sign
66
+
67
+
68
+ class FP8_E4M3_DATA(FloatArgs):
69
+ exponent = 4
70
+ mantissa = 3
71
+ bits = 8
72
+ max = torch.finfo(torch.float8_e4m3fn).max
73
+ min = torch.finfo(torch.float8_e4m3fn).min
74
+ dtype = torch.float8_e4m3fn
75
+
76
+
77
+ # TODO: Remove soon in favour of a more descriptive FloatArgs
34
78
  FP8_DTYPE = torch.float8_e4m3fn
35
79
 
36
80
 
@@ -234,7 +278,10 @@ class QuantizationArgs(BaseModel, use_enum_values=True):
234
278
 
235
279
  def pytorch_dtype(self) -> torch.dtype:
236
280
  if self.type == QuantizationType.FLOAT:
237
- return FP8_DTYPE
281
+ if self.num_bits == 8:
282
+ return FP8_E4M3_DATA.dtype
283
+ else:
284
+ raise NotImplementedError("Only num_bits in (8) are supported")
238
285
  elif self.type == QuantizationType.INT:
239
286
  if self.num_bits <= 8:
240
287
  return torch.int8
@@ -263,7 +310,12 @@ def round_to_quantized_type(
263
310
  """
264
311
  original_dtype = tensor.dtype
265
312
  if args.type == QuantizationType.FLOAT:
266
- rounded = tensor.to(FP8_DTYPE)
313
+ if args.num_bits == 8:
314
+ rounded = tensor.to(FP8_E4M3_DATA.dtype)
315
+ elif args.num_bits == 4:
316
+ rounded = FP4_E2M1_DATA.cast_to_fp4(tensor)
317
+ else:
318
+ raise NotImplementedError("Only num_bits in (4, 8) are supported")
267
319
  elif args.type == QuantizationType.INT:
268
320
  rounded = torch.round(tensor)
269
321
  else:
@@ -100,6 +100,17 @@ def is_preset_scheme(name: str) -> bool:
100
100
 
101
101
  UNQUANTIZED = dict()
102
102
 
103
+ NVFP4A16 = dict(
104
+ weights=QuantizationArgs(
105
+ num_bits=4,
106
+ type=QuantizationType.FLOAT,
107
+ strategy=QuantizationStrategy.GROUP,
108
+ symmetric=True,
109
+ dynamic=False,
110
+ group_size=16,
111
+ )
112
+ )
113
+
103
114
  # 8 bit integer weights and 8 bit activations quantization
104
115
  INT8_W8A8 = dict(
105
116
  weights=QuantizationArgs(
@@ -225,4 +236,5 @@ PRESET_SCHEMES = {
225
236
  # Float weight and activation schemes
226
237
  "FP8": FP8,
227
238
  "FP8_DYNAMIC": FP8_DYNAMIC,
239
+ "NVFP4A16": NVFP4A16,
228
240
  }
@@ -17,5 +17,5 @@ __version__: str
17
17
  __version_tuple__: VERSION_TUPLE
18
18
  version_tuple: VERSION_TUPLE
19
19
 
20
- __version__ = version = '0.9.5.a20250509'
20
+ __version__ = version = '0.9.5.a20250512'
21
21
  __version_tuple__ = version_tuple = (0, 9, 5)
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.4
2
2
  Name: compressed-tensors
3
- Version: 0.9.5a20250509
3
+ Version: 0.9.5a20250512
4
4
  Summary: Library for utilization of compressed safetensors of neural network models
5
5
  Home-page: https://github.com/neuralmagic/compressed-tensors
6
6
  Author: Neuralmagic, Inc.
@@ -1,6 +1,6 @@
1
1
  compressed_tensors/__init__.py,sha256=UtKmifNeBCSE2TZSAfduVNNzHY-3V7bLjZ7n7RuXLOE,812
2
2
  compressed_tensors/base.py,sha256=73HYH7HY7O2roC89yG_piPFnZwrBfn_i7HmKl90SKc0,875
3
- compressed_tensors/version.py,sha256=MJNvDWAetuDB0OYncH2iIZtFhxnO8XXV5IHnTcj9e6k,521
3
+ compressed_tensors/version.py,sha256=Yd5MPtXGvm9XWCPM_O99KCK_9DwKmcl2OenuJ4MxlUI,521
4
4
  compressed_tensors/compressors/__init__.py,sha256=smSygTSfcfuujRrAXDc6uZm4L_ccV1tWZewqVnOb4lM,825
5
5
  compressed_tensors/compressors/base.py,sha256=nvWsv4xEw1Tkxkxth6TmHplDYXfBeP22xWxOsZERyDY,7204
6
6
  compressed_tensors/compressors/helpers.py,sha256=OK6qxX9j3bHwF9JfIYSGMgBJe2PWjlTA3byXKCJaTIQ,5431
@@ -26,9 +26,9 @@ compressed_tensors/config/sparse_bitmask.py,sha256=pZUboRNZTu6NajGOQEFExoPknak5y
26
26
  compressed_tensors/linear/__init__.py,sha256=fH6rjBYAxuwrTzBTlTjTgCYNyh6TCvCqajCz4Im4YrA,617
27
27
  compressed_tensors/linear/compressed_linear.py,sha256=_m6XpNcI53eeSHO8VdiuAM6UBTdpDhn5Ivd8iRMwEKc,3980
28
28
  compressed_tensors/quantization/__init__.py,sha256=83J5bPB7PavN2TfCoW7_vEDhfYpm4TDrqYO9vdSQ5bk,760
29
- compressed_tensors/quantization/quant_args.py,sha256=sKpb8DcNObidjXjNol1Tn_Iih3ZXBycSp-fyz68TGhY,9117
29
+ compressed_tensors/quantization/quant_args.py,sha256=2m4WJWBnNjkU-3rVR_f2a6p_BZMvGfMrPyOui8JUwWk,10487
30
30
  compressed_tensors/quantization/quant_config.py,sha256=MxSUcb5dOqMN6LFyD5K2h8X0TvEtcWIAoiUJqD2dHGE,10159
31
- compressed_tensors/quantization/quant_scheme.py,sha256=yz0oMbbwp7QZXXd2k5KIJu-Q6aTqg2929VdUzZ7vysM,6324
31
+ compressed_tensors/quantization/quant_scheme.py,sha256=0FpN3R7bVn8rQ18Vp0NuDVpoilTZ7X8vk9zp_8AndwY,6578
32
32
  compressed_tensors/quantization/lifecycle/__init__.py,sha256=_uItzFWusyV74Zco_pHLOTdE9a83cL-R-ZdyQrBkIyw,772
33
33
  compressed_tensors/quantization/lifecycle/apply.py,sha256=DOoxH4jM8r0270GGGUFOpRrgwaisiJi7TV-Q6E8qM8E,18067
34
34
  compressed_tensors/quantization/lifecycle/compressed.py,sha256=Fj9n66IN0EWsOAkBHg3O0GlOQpxstqjCcs0ttzMXrJ0,2296
@@ -46,8 +46,8 @@ compressed_tensors/utils/permutations_24.py,sha256=kx6fsfDHebx94zsSzhXGyCyuC9sVy
46
46
  compressed_tensors/utils/permute.py,sha256=V6tJLKo3Syccj-viv4F7ZKZgJeCB-hl-dK8RKI_kBwI,2355
47
47
  compressed_tensors/utils/safetensors_load.py,sha256=kkkUDmS1H40MFy6FDP-DFGiAYbtqke6bKE7YrAtORtA,11499
48
48
  compressed_tensors/utils/semi_structured_conversions.py,sha256=XKNffPum54kPASgqKzgKvyeqWPAkair2XEQXjkp7ho8,13489
49
- compressed_tensors-0.9.5a20250509.dist-info/licenses/LICENSE,sha256=xx0jnfkXJvxRnG63LTGOxlggYnIysveWIZ6H3PNdCrQ,11357
50
- compressed_tensors-0.9.5a20250509.dist-info/METADATA,sha256=2BVor0VJtcZHud4QBjYS5OhLXob0nF9dmeP08RUSA5k,7004
51
- compressed_tensors-0.9.5a20250509.dist-info/WHEEL,sha256=DnLRTWE75wApRYVsjgc6wsVswC54sMSJhAEd4xhDpBk,91
52
- compressed_tensors-0.9.5a20250509.dist-info/top_level.txt,sha256=w2i-GyPs2s1UwVxvutSvN_lM22SXC2hQFBmoMcPnV7Y,19
53
- compressed_tensors-0.9.5a20250509.dist-info/RECORD,,
49
+ compressed_tensors-0.9.5a20250512.dist-info/licenses/LICENSE,sha256=xx0jnfkXJvxRnG63LTGOxlggYnIysveWIZ6H3PNdCrQ,11357
50
+ compressed_tensors-0.9.5a20250512.dist-info/METADATA,sha256=gArKa7gy0jdBGF5PbuNLVh_ZmnXq4CdEp1-7grxpjnw,7004
51
+ compressed_tensors-0.9.5a20250512.dist-info/WHEEL,sha256=DnLRTWE75wApRYVsjgc6wsVswC54sMSJhAEd4xhDpBk,91
52
+ compressed_tensors-0.9.5a20250512.dist-info/top_level.txt,sha256=w2i-GyPs2s1UwVxvutSvN_lM22SXC2hQFBmoMcPnV7Y,19
53
+ compressed_tensors-0.9.5a20250512.dist-info/RECORD,,