compressed-tensors 0.9.5a20250519__py3-none-any.whl → 0.9.5a20250521__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

compressed_tensors/compressors/quantized_compressors/base.py

@@ -99,6 +99,7 @@ class BaseQuantizationCompressor(BaseCompressor):
  scale = model_state.get(prefix + "weight_scale", None)
  g_idx = model_state.get(prefix + "weight_g_idx", None)
  zp = model_state.get(prefix + "weight_zero_point", None)
+ global_scale = model_state.get(prefix + "weight_global_scale", None)

  # if scale does not exist, then weight cannot be compressed
  if scale is None:

@@ -112,6 +113,7 @@ class BaseQuantizationCompressor(BaseCompressor):
  weight=value,
  scale=scale,
  zero_point=zp,
+ global_scale=global_scale,
  g_idx=g_idx,
  quantization_args=quant_args,
  device="cpu",

compressed_tensors/compressors/quantized_compressors/naive_quantized.py

@@ -78,6 +78,7 @@ class NaiveQuantizationCompressor(BaseQuantizationCompressor):
  zero_point: Optional[Tensor] = None,
  g_idx: Optional[torch.Tensor] = None,
  device: Optional[torch.device] = None,
+ global_scale: Optional[torch.Tensor] = None,
  ) -> Dict[str, torch.Tensor]:
  """
  Compresses a single uncompressed weight

@@ -90,6 +91,11 @@ class NaiveQuantizationCompressor(BaseQuantizationCompressor):
  :param device: optional device to move compressed output to
  :return: dictionary of compressed weight data
  """
+ if global_scale is not None:
+ raise ValueError(
+ "global_scale is not supported for the NaiveQuantizationCompressor"
+ )
+
  if can_quantize(weight, quantization_args):
  quantized_weight = quantize(
  x=weight,

compressed_tensors/compressors/quantized_compressors/pack_quantized.py

@@ -94,6 +94,7 @@ class PackedQuantizationCompressor(BaseQuantizationCompressor):
  zero_point: Optional[Tensor] = None,
  g_idx: Optional[torch.Tensor] = None,
  device: Optional[torch.device] = None,
+ global_scale: Optional[torch.Tensor] = None,
  ) -> Dict[str, torch.Tensor]:
  """
  Compresses a single uncompressed weight

@@ -106,6 +107,11 @@ class PackedQuantizationCompressor(BaseQuantizationCompressor):
  :param device: optional device to move compressed output to
  :return: dictionary of compressed weight data
  """
+ if global_scale is not None:
+ raise ValueError(
+ "global_scale is not supported for the PackedQuantizationCompressor"
+ )
+
  compressed_dict = {}
  if can_quantize(weight, quantization_args):
  quantized_weight = quantize(

compressed_tensors/quantization/lifecycle/apply.py

@@ -27,8 +27,14 @@ from compressed_tensors.quantization.lifecycle.compressed import (
  )
  from compressed_tensors.quantization.lifecycle.initialize import (
  initialize_module_for_quantization,
+ update_fused_layer_weight_global_scales,
+ )
+ from compressed_tensors.quantization.quant_args import (
+ FP4_E2M1_DATA,
+ FP8_E4M3_DATA,
+ QuantizationArgs,
+ QuantizationType,
  )
- from compressed_tensors.quantization.quant_args import QuantizationArgs
  from compressed_tensors.quantization.quant_config import (
  QuantizationConfig,
  QuantizationStatus,

@@ -266,6 +272,9 @@ def apply_quantization_status(model: Module, status: QuantizationStatus):
  )
  )

+ if status == QuantizationStatus.INITIALIZED:
+ update_fused_layer_weight_global_scales(model)
+
  if current_status < status >= QuantizationStatus.COMPRESSED > current_status:
  model.apply(compress_quantized_weights)

compressed_tensors/quantization/lifecycle/forward.py

@@ -20,6 +20,7 @@ import torch
  from compressed_tensors.quantization.quant_args import (
  QuantizationArgs,
  QuantizationStrategy,
+ QuantizationType,
  round_to_quantized_type,
  )
  from compressed_tensors.quantization.quant_config import QuantizationStatus

@@ -49,6 +50,7 @@ def quantize(
  args: QuantizationArgs,
  dtype: Optional[torch.dtype] = None,
  g_idx: Optional[torch.Tensor] = None,
+ global_scale: Optional[torch.Tensor] = None,
  ) -> torch.Tensor:
  """
  Quantize the input tensor x using the QuantizationStrategy specified in args.

@@ -63,6 +65,7 @@ def quantize(
  :param args: quantization args dictating how to quantize x
  :param dtype: optional dtype to cast the quantized output to
  :param g_idx: optional mapping from column index to group index
+ :param global_scale: optional constant to scale the quantization scale during QDQ
  :return: fake quantized tensor
  """

@@ -75,6 +78,7 @@ def quantize(
  do_quantize=True,
  do_dequantize=False,
  g_idx=g_idx,
+ global_scale=global_scale,
  )

@@ -86,6 +90,7 @@ def dequantize(
  args: Optional[QuantizationArgs] = None,
  dtype: Optional[torch.dtype] = None,
  g_idx: Optional[torch.Tensor] = None,
+ global_scale: Optional[torch.Tensor] = None,
  ) -> torch.Tensor:
  """
  Dequantize a quantized input tensor x_q based on the strategy specified in args. If

@@ -97,6 +102,7 @@ def dequantize(
  :param args: quantization args used to quantize x_q
  :param dtype: optional dtype to cast the dequantized output to
  :param g_idx: optional mapping from column index to group index
+ :param global_scale: optional constant to scale the quantization scale during QDQ
  :return: dequantized float tensor
  """
  if args is None:

@@ -128,6 +134,7 @@ def dequantize(
  do_dequantize=True,
  dtype=dtype,
  g_idx=g_idx,
+ global_scale=global_scale,
  )

@@ -138,6 +145,7 @@ def fake_quantize(
  zero_point: torch.Tensor,
  args: QuantizationArgs,
  g_idx: Optional[torch.Tensor] = None,
+ global_scale: Optional[torch.Tensor] = None,
  ) -> torch.Tensor:
  """
  Fake quantize the input tensor x by quantizing then dequantizing with

@@ -151,6 +159,7 @@ def fake_quantize(
  :param zero_point: zero point tensor
  :param args: quantization args dictating how to quantize x
  :param g_idx: optional mapping from column index to group index
+ :param global_scale: optional constant to scale the quantization scale during QDQ
  :return: fake quantized tensor
  """
  return _process_quantization(

@@ -161,6 +170,7 @@ def fake_quantize(
  do_quantize=True,
  do_dequantize=True,
  g_idx=g_idx,
+ global_scale=global_scale,
  )

@@ -174,6 +184,7 @@ def _process_quantization(
  dtype: Optional[torch.dtype] = None,
  do_quantize: bool = True,
  do_dequantize: bool = True,
+ global_scale: Optional[torch.Tensor] = None,
  ) -> torch.Tensor:
  q_min, q_max = calculate_range(args, x.device)
  group_size = args.group_size

@@ -221,18 +232,21 @@ def _process_quantization(
  end = start + group_count
  if do_quantize:
  output[:, start:end] = _quantize(
- x[:, start:end],
- sc,
- zp,
- q_min,
- q_max,
- args,
+ x=x[:, start:end],
+ scale=sc,
+ zero_point=zp,
+ q_min=q_min,
+ q_max=q_max,
+ args=args,
  dtype=dtype,
+ global_scale=global_scale,
  )

  if do_dequantize:
  input = output[:, start:end] if do_quantize else x[:, start:end]
- output[:, start:end] = _dequantize(input, sc, zp)
+ output[:, start:end] = _dequantize(
+ x_q=input, scale=sc, zero_point=zp, global_scale=global_scale
+ )

  if not is_column_order:
  output = safe_permute(output, torch.argsort(perm), dim=1)

@@ -240,16 +254,22 @@ def _process_quantization(
  else: # covers channel, token and tensor strategies
  if do_quantize:
  output = _quantize(
- x,
- scale,
- zero_point,
- q_min,
- q_max,
- args,
+ x=x,
+ scale=scale,
+ zero_point=zero_point,
+ q_min=q_min,
+ q_max=q_max,
+ args=args,
  dtype=dtype,
+ global_scale=global_scale,
  )
  if do_dequantize:
- output = _dequantize(output if do_quantize else x, scale, zero_point)
+ output = _dequantize(
+ output if do_quantize else x,
+ scale=scale,
+ zero_point=zero_point,
+ global_scale=global_scale,
+ )

  return output

@@ -330,6 +350,7 @@ def forward_quantize(
  return value

  g_idx = getattr(module, "weight_g_idx", None)
+ global_scale = getattr(module, f"{base_name}_global_scale", None)

  if args.dynamic:
  # dynamic quantization - determine the scale/zp on the fly

@@ -345,6 +366,7 @@ def forward_quantize(
  zero_point=zero_point,
  args=args,
  g_idx=g_idx,
+ global_scale=global_scale,
  )

@@ -357,11 +379,18 @@ def _quantize(
  q_max: torch.Tensor,
  args: QuantizationArgs,
  dtype: Optional[torch.dtype] = None,
+ global_scale: Optional[torch.Tensor] = None,
  ) -> torch.Tensor:

+ # if a global scale is optionally provided, use it
+ # to further scale the local `scale` parameter
+ if global_scale:
+ scale = scale.to(global_scale.dtype) / global_scale
+
  scaled = x / scale
  if zero_point is not None:
  scaled += zero_point.to(x.dtype)
+
  # clamp first because cast isn't guaranteed to be saturated (ie for fp8)
  clamped_value = torch.clamp(
  scaled,

@@ -381,7 +410,14 @@ def _dequantize(
  scale: torch.Tensor,
  zero_point: torch.Tensor = None,
  dtype: Optional[torch.dtype] = None,
+ global_scale: Optional[torch.Tensor] = None,
  ) -> torch.Tensor:
+
+ # if a global scale is optionally provided, use it
+ # to further scale the local `scale` parameter
+ if global_scale:
+ scale = scale.to(global_scale.dtype) / global_scale
+
  dequant_value = x_q.to(scale.dtype)

  if zero_point is not None:
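
Outside the library, the effect of the new argument is easy to reproduce: the stored local scale was pre-multiplied by the global scale when it was calculated, so dividing it back out at QDQ time recovers the true per-group scale. A self-contained sketch in plain torch (E2M1 max of 6.0 assumed, and simple `torch.round` used where the library rounds to the target quantized dtype):

    import torch

    FP4_MAX = 6.0  # assumed E2M1 max, matching FP4_E2M1_DATA

    def qdq_with_global_scale(x, scale, global_scale, q_min, q_max):
        # Mirror of the hunks above: the stored scale is divided by the global
        # scale before quantize/dequantize (zero point omitted; FP4 is symmetric).
        scale = scale.to(global_scale.dtype) / global_scale
        q = torch.clamp(torch.round(x / scale), q_min, q_max)  # simplified rounding
        return q * scale

    x = torch.randn(4, 16)
    global_scale = torch.tensor(100.0)  # illustrative value
    # Stored scale as calculate_qparams would produce it: pre-multiplied by global_scale.
    local_scale = global_scale * (x.abs().amax(dim=1, keepdim=True) / FP4_MAX)

    out = qdq_with_global_scale(x, local_scale, global_scale, q_min=-FP4_MAX, q_max=FP4_MAX)
    print(torch.mean((x - out) ** 2))  # small reconstruction error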

compressed_tensors/quantization/lifecycle/initialize.py

@@ -16,24 +16,33 @@
  import logging
  import math
  from enum import Enum
- from typing import Optional
+ from typing import List, Optional

  import torch
  from compressed_tensors.quantization.lifecycle.forward import (
  wrap_module_forward_quantized,
  )
  from compressed_tensors.quantization.quant_args import (
+ FP4_E2M1_DATA,
+ FP8_E4M3_DATA,
  ActivationOrdering,
  QuantizationArgs,
  QuantizationStrategy,
+ QuantizationType,
  )
  from compressed_tensors.quantization.quant_config import QuantizationStatus
  from compressed_tensors.quantization.quant_scheme import QuantizationScheme
- from compressed_tensors.quantization.utils import is_kv_cache_quant_scheme
+ from compressed_tensors.quantization.utils import (
+ generate_global_scale,
+ is_fp4,
+ is_kv_cache_quant_scheme,
+ iter_named_quantizable_modules,
+ )
  from compressed_tensors.utils import (
  disable_hf_hook,
  get_execution_device,
  register_offload_parameter,
+ update_parameter_data,
  )
  from torch.nn import Module, Parameter

@@ -42,6 +51,7 @@ __all__ = [
  "initialize_module_for_quantization",
  "is_attention_module",
  "KVCacheScaleType",
+ "update_fused_layer_weight_global_scales",
  ]

@@ -170,7 +180,24 @@ def _initialize_scale_zero_point(
  # TODO: consider erroring out in the future if the dtype is not one of these,
  # as there is likely a bug

- if scale_dtype not in [torch.float16, torch.bfloat16, torch.float32]:
+ if is_fp4(quantization_args=quantization_args) and base_name == "weight":
+ scale_dtype = FP8_E4M3_DATA.dtype
+ # When applying weight-only FP4 quantization, generate a global_scale
+ # This scale is applied during runtime to ensure that the generated
+ # local scale falls properly within the FP8 range (i.e max value is FP8_max)
+ # which is the expected dtype of NVFP4A16 scales
+ value = generate_global_scale(input_tensor=module.weight)
+ value = value.to(device)
+ init_global_scale = Parameter(value, requires_grad=False)
+ register_offload_parameter(
+ module, f"{base_name}_global_scale", init_global_scale
+ )
+
+ if scale_dtype not in [
+ torch.float16,
+ torch.bfloat16,
+ torch.float32,
+ ] and not is_fp4(quantization_args=quantization_args):
  scale_dtype = torch.float16

  # initializes empty scale, zero point, and g_idx parameters for the module
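
For an FP4-quantized weight, the new branch switches the local scale dtype to FP8 (E4M3) and registers one tensor-wide global scale computed from the weight's amax. A rough standalone sketch of that initialization step (constants written out literally; `register_offload_parameter` replaced by a plain `Parameter` for illustration):

    import torch

    FP8_E4M3_MAX = torch.finfo(torch.float8_e4m3fn).max  # 448.0
    FP4_E2M1_MAX = 6.0                                   # no torch FP4 dtype; value assumed

    weight = torch.randn(256, 512)

    # Weight-only FP4: local scales will be stored as FP8, and a single global
    # scale maps the tensor-wide amax onto that representable FP8 range.
    scale_dtype = torch.float8_e4m3fn
    global_scale = (FP8_E4M3_MAX * FP4_E2M1_MAX / weight.abs().max()).to(torch.float32)

    weight_global_scale = torch.nn.Parameter(global_scale, requires_grad=False)
    print(scale_dtype, weight_global_scale.item())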

@@ -181,7 +208,11 @@ def _initialize_scale_zero_point(
  register_offload_parameter(module, f"{base_name}_scale", init_scale)

  if force_zero_point or not quantization_args.symmetric:
- zp_dtype = quantization_args.pytorch_dtype()
+ if is_fp4(quantization_args=quantization_args):
+ zp_dtype = FP8_E4M3_DATA.dtype
+ else:
+ zp_dtype = quantization_args.pytorch_dtype()
+
  init_zero_point = Parameter(
  torch.zeros(expected_shape, device=device, dtype=zp_dtype),
  requires_grad=False,

@@ -219,3 +250,91 @@ def _initialize_attn_scales(module: Module) -> None:
  requires_grad=False,
  )
  register_offload_parameter(module, KVCacheScaleType.VALUE.value, init_scale)
+
+
+ # TODO: Potentially introduce an argument to turn this off
+ # Only relevant for NVFP4A16 currently
+ def update_fused_layer_weight_global_scales(model: torch.nn.Module):
+ """
+ When running NVFP4A16 quantization, update the global scale
+ such that q,k,v layers are treated as one tensor with the same
+ global_scale and gate_proj/up_proj layers are treated as one tensor
+ with the same global scale. This is a requirement currently set
+ by vLLM and may be removed, or potentially made optional,
+ in the future.
+
+ :param model: model to quantize
+ """
+
+ def _is_attention_module(module: Module):
+ return "attention" in module.__class__.__name__.lower() and (
+ hasattr(module, "k_proj")
+ or hasattr(module, "v_proj")
+ or hasattr(module, "qkv_proj")
+ )
+
+ def _is_mlp_module(module: Module):
+ return "mlp" in module.__class__.__name__.lower() and (
+ hasattr(module, "gate_proj") or hasattr(module, "up_proj")
+ )
+
+ def _valid_fp4_quant(layer_list: List[torch.nn.Linear]):
+ """
+ Return True if all the linear layers in the layer_list are
+ NVFP4A16 quantized.
+ """
+ for layer in layer_list:
+ scheme = getattr(layer, "quantization_scheme", None)
+ if scheme is None:
+ return False
+
+ weight_quant_args = scheme.weights
+
+ if weight_quant_args is None:
+ return False
+
+ if not is_fp4(quantization_args=weight_quant_args):
+ return False
+ return True
+
+ for name, submodule in iter_named_quantizable_modules(
+ model,
+ include_attn=True,
+ include_mlp=True,
+ ):
+
+ if _is_attention_module(submodule):
+ # already fused/treated as one layer
+ if hasattr(submodule, "qkv_proj"):
+ continue
+
+ if not _valid_fp4_quant(
+ [submodule.q_proj, submodule.v_proj, submodule.k_proj]
+ ):
+ continue
+
+ q_weight = submodule.q_proj.weight.data
+ v_weight = submodule.v_proj.weight.data
+ k_weight = submodule.k_proj.weight.data
+
+ value = generate_global_scale(
+ input_tensor=torch.cat((q_weight, v_weight, k_weight), dim=0)
+ )
+
+ update_parameter_data(submodule.q_proj, value, "weight_global_scale")
+ update_parameter_data(submodule.k_proj, value, "weight_global_scale")
+ update_parameter_data(submodule.v_proj, value, "weight_global_scale")
+
+ if _is_mlp_module(submodule):
+ if not _valid_fp4_quant([submodule.gate_proj, submodule.up_proj]):
+ continue
+
+ gate_data = submodule.gate_proj.weight.data
+ up_data = submodule.up_proj.weight.data
+
+ value = generate_global_scale(
+ input_tensor=torch.cat((gate_data, up_data), dim=0)
+ )
+
+ update_parameter_data(submodule.gate_proj, value, "weight_global_scale")
+ update_parameter_data(submodule.up_proj, value, "weight_global_scale")
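
Stripped of the module iteration, the fusing step is just one amax over the concatenated member weights, turned into a single shared global scale. A standalone sketch with made-up shapes, reusing the 448 * 6 / amax formula implied by `generate_global_scale`:

    import torch

    def shared_global_scale(*weights, scale_max=448.0, quant_max=6.0):
        # One tensor-wide amax across all fused members -> one shared global scale.
        amax = torch.cat(weights, dim=0).abs().max().to(torch.float32)
        return (scale_max * quant_max / amax).to(torch.float32)

    q_w = torch.randn(128, 64)
    k_w = torch.randn(32, 64)   # GQA-style smaller k/v projections, purely illustrative
    v_w = torch.randn(32, 64)

    gs = shared_global_scale(q_w, k_w, v_w)
    # In the library this value is written back to q_proj/k_proj/v_proj via
    # update_parameter_data(..., "weight_global_scale"); here it is just printed.
    print(gs)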

compressed_tensors/quantization/quant_args.py

@@ -26,6 +26,7 @@ __all__ = [
  "FP8_DTYPE",
  "FP8_E4M3_DATA",
  "FP4_E2M1_DATA",
+ "FloatArgs",
  "QuantizationType",
  "QuantizationStrategy",
  "QuantizationArgs",

@@ -268,8 +269,8 @@ class QuantizationArgs(BaseModel, use_enum_values=True):
  observer = None

  elif observer is None:
- # default to minmax for non-dynamic cases
- observer = "minmax"
+ # default to mse for non-dynamic cases
+ observer = "mse"

  # write back modified values
  model.strategy = strategy
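
The observer default only matters for static (non-dynamic) schemes, where a previously unset observer now resolves to "mse" instead of "minmax". Assuming the public constructor and import path (both outside this diff), the change is visible as:

    from compressed_tensors.quantization import QuantizationArgs  # import path assumed

    args = QuantizationArgs(num_bits=4, type="float", strategy="group", group_size=16)
    # The validator fills in the observer for static schemes:
    print(args.observer)  # "minmax" before this release, "mse" from this release on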

compressed_tensors/quantization/utils/helpers.py

@@ -17,7 +17,9 @@ from typing import Generator, List, Optional, Tuple

  import torch
  from compressed_tensors.quantization.quant_args import (
- FP8_DTYPE,
+ FP4_E2M1_DATA,
+ FP8_E4M3_DATA,
+ FloatArgs,
  QuantizationArgs,
  QuantizationStrategy,
  QuantizationType,

@@ -44,6 +46,8 @@ __all__ = [
  "compute_dynamic_scales_and_zp",
  "calculate_range",
  "calculate_qparams",
+ "generate_global_scale",
+ "is_fp4",
  ]

  # target the self_attn layer

@@ -53,8 +57,18 @@ KV_CACHE_TARGETS = ["re:.*self_attn$"]
  _LOGGER: logging.Logger = logging.getLogger(__name__)


+ def is_fp4(quantization_args: QuantizationArgs):
+ return (
+ quantization_args.num_bits == 4
+ and quantization_args.type == QuantizationType.FLOAT
+ )
+
+
  def calculate_qparams(
- min_vals: Tensor, max_vals: Tensor, quantization_args: QuantizationArgs
+ min_vals: Tensor,
+ max_vals: Tensor,
+ quantization_args: QuantizationArgs,
+ global_scale: Optional[Tensor] = None,
  ) -> Tuple[FloatTensor, IntTensor]:
  """
  :param min_vals: tensor of min value(s) to calculate scale(s) and zero point(s)

@@ -62,7 +76,11 @@ def calculate_qparams(
  :param max_vals: tensor of max value(s) to calculate scale(s) and zero point(s)
  from
  :param quantization_args: settings for quantization
- :return: tuple of the calculated scale(s) and zero point(s)
+ :param global_scale: additional global scale to scale the locally generated scale;
+ currently only applied/supported for FP4
+
+ :return: tuple of the calculated scale(s) and zero point(s). For FP4, the calculated
+ scale is of dtype FP8
  """
  # based on the implementations for consuming quantized values,
  # 0.0 must always be representable within the quantized range

@@ -73,14 +91,40 @@ def calculate_qparams(

  bit_min, bit_max = calculate_range(quantization_args, device)
  bit_range = bit_max - bit_min
- zp_dtype = quantization_args.pytorch_dtype()
+
+ if is_fp4(quantization_args=quantization_args):
+ zp_dtype = FP8_E4M3_DATA.dtype
+ else:
+ zp_dtype = quantization_args.pytorch_dtype()

  if quantization_args.symmetric:
  max_val_pos = torch.max(torch.abs(min_vals), torch.abs(max_vals))
- scales = max_val_pos / (float(bit_range) / 2)
- scales = torch.clamp(scales, min=torch.finfo(torch.float32).eps)
+
+ if is_fp4(quantization_args=quantization_args) and global_scale is not None:
+ # Conditionally scale the generated local scale by a global_scale
+ scales = global_scale * (max_val_pos / FP4_E2M1_DATA.max)
+ scales = scales.to(FP8_E4M3_DATA.dtype)
+ else:
+ scales = max_val_pos / (float(bit_range) / 2)
+
+ if scales.dtype == FP8_E4M3_DATA.dtype:
+ # torch.clamp not supported for FP8
+ # use the next largest fp8 value from 0
+ scales = torch.where(
+ scales == 0,
+ torch.tensor(0.125, dtype=FP8_E4M3_DATA.dtype, device=device),
+ scales,
+ )
+ else:
+ scales = torch.clamp(scales, min=torch.finfo(torch.float32).eps)
+
  zero_points = torch.zeros(scales.shape, device=device, dtype=min_vals.dtype)
  else:
+ if is_fp4(quantization_args=quantization_args):
+ raise NotImplementedError(
+ "Asymmetric Quantization is not supported for FP4"
+ )
+
  scales = (max_vals - min_vals) / float(bit_range)
  scales = torch.clamp(scales, min=torch.finfo(torch.float32).eps)
  zero_points = bit_min - (min_vals / scales)
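
Numerically, the FP4 branch produces FP8-typed local scales and cannot rely on `torch.clamp`, hence the `torch.where` fallback. A small standalone sketch of that computation (E2M1 max of 6.0 and the 0.125 zero-replacement value taken from the hunk; the inputs are chosen to land on exactly representable E4M3 values):

    import torch

    FP4_MAX = 6.0
    group_amax = torch.tensor([3.0, 0.0, 1.5])      # per-group |max| values, illustrative
    global_scale = torch.tensor(448.0)

    scales = global_scale * (group_amax / FP4_MAX)  # -> [224., 0., 112.]
    scales = scales.to(torch.float8_e4m3fn)

    # torch.clamp is not implemented for FP8, so exact zeros are replaced with the
    # smallest positive value used by the hunk (0.125) via torch.where instead.
    scales = torch.where(
        scales == 0,
        torch.tensor(0.125, dtype=torch.float8_e4m3fn),
        scales,
    )
    print(scales.to(torch.float32))                 # [224., 0.125, 112.]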

@@ -144,14 +188,16 @@ def calculate_range(quantization_args: QuantizationArgs, device: str) -> Tuple:
  q_max = torch.tensor(bit_range / 2 - 1, device=device)
  q_min = torch.tensor(-bit_range / 2, device=device)
  elif quantization_args.type == QuantizationType.FLOAT:
- if quantization_args.num_bits != 8:
- raise ValueError(
- "Floating point quantization is only supported for 8 bits,"
- f"got {quantization_args.num_bits}"
+ if quantization_args.num_bits == 8:
+ q_max = torch.tensor(FP8_E4M3_DATA.max, device=device)
+ q_min = torch.tensor(FP8_E4M3_DATA.min, device=device)
+ elif quantization_args.num_bits == 4:
+ q_max = torch.tensor(FP4_E2M1_DATA.max, device=device)
+ q_min = torch.tensor(FP4_E2M1_DATA.min, device=device)
+ else:
+ raise NotImplementedError(
+ "Range calculation only supported for 4 and 8 bits"
  )
- fp_range_info = torch.finfo(FP8_DTYPE)
- q_max = torch.tensor(fp_range_info.max, device=device)
- q_min = torch.tensor(fp_range_info.min, device=device)
  else:
  raise ValueError(f"Invalid quantization type {quantization_args.type}")

@@ -249,7 +295,10 @@ def iter_named_leaf_modules(model: Module) -> Generator[Tuple[str, Module], None


  def iter_named_quantizable_modules(
- model: Module, include_children: bool = True, include_attn: bool = False
+ model: Module,
+ include_children: bool = True,
+ include_attn: bool = False,
+ include_mlp: bool = False,
  ) -> Generator[Tuple[str, Module], None, None]:
  """
  Yield name and submodule of

@@ -282,6 +331,9 @@ def iter_named_quantizable_modules(
  if include_attn:
  if name.endswith("self_attn"):
  yield name, submodule
+ if include_mlp:
+ if name.endswith("mlp"):
+ yield name, submodule


  def get_torch_bit_depth(value: torch.Tensor) -> int:

@@ -396,3 +448,24 @@ def parse_out_kv_cache_args(
  kv_cache_args = None

  return kv_cache_args, quant_scheme_to_layers
+
+
+ def generate_global_scale(
+ input_tensor: torch.Tensor,
+ scale_data: Optional[FloatArgs] = FP8_E4M3_DATA,
+ quant_data: Optional[FloatArgs] = FP4_E2M1_DATA,
+ dtype: Optional[torch.dtype] = torch.float32,
+ ):
+ """
+ Generate a global scale for an entire tensor (input_tensor).
+ Goal of the scale is to ensure that the quantization (local) scale
+ falls into the appropriate dtype range.
+
+ E.g. for NVFP4, group (local) scales are in dtype FP8. The global_scale
+ attempts to use the entire FP8 dtype range while mapping a per-group max
+ to the FP4 max.
+ """
+ scale_dtype = scale_data.dtype
+ tensor_amax = torch.abs(input_tensor.data).max().to(dtype)
+ global_scale = scale_data.max * quant_data.max / tensor_amax
+ return global_scale.to(dtype)
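
Worked through with concrete numbers (FP8 E4M3 max 448, FP4 E2M1 max 6, matching the defaults above), the global scale is simply the factor that maps the tensor-wide amax onto the largest representable local scale:

    import torch

    def global_scale_for(weight: torch.Tensor) -> torch.Tensor:
        # Same formula as generate_global_scale, with the default constants written out.
        tensor_amax = weight.abs().max().to(torch.float32)
        return (448.0 * 6.0 / tensor_amax).to(torch.float32)

    w = torch.full((8, 8), 0.5)
    w[0, 0] = 10.0                      # tensor-wide amax
    print(global_scale_for(w))          # 448 * 6 / 10 = 268.8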

compressed_tensors/registry/registry.py

@@ -19,7 +19,7 @@ of neuralmagic utilities

  import importlib
  from collections import defaultdict
- from typing import Any, Dict, List, Optional, Type, Union
+ from typing import Any, Dict, List, Optional, TypeVar, Union


  __all__ = [

@@ -32,8 +32,9 @@ __all__ = [
  ]


- _ALIAS_REGISTRY: Dict[Type, Dict[str, str]] = defaultdict(dict)
- _REGISTRY: Dict[Type, Dict[str, Any]] = defaultdict(dict)
+ _ALIAS_REGISTRY: Dict[type, Dict[str, str]] = defaultdict(dict)
+ _REGISTRY: Dict[type, Dict[str, Any]] = defaultdict(dict)
+ T = TypeVar("T", bound="RegistryMixin")


  def standardize_lookup_name(name: str) -> str:
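
Replacing the bare `object` return type with a `TypeVar` bound to `RegistryMixin` mainly benefits static checkers: subclasses get their own type back from the registry helpers. A toy illustration of the pattern (not the library's code):

    from typing import TypeVar

    T = TypeVar("T", bound="Base")

    class Base:
        @classmethod
        def load(cls: type[T]) -> T:
            # Returning T instead of object lets mypy/pyright infer that
            # Child.load() produces a Child rather than a plain object.
            return cls()

    class Child(Base):
        def run(self) -> str:
            return "ok"

    child = Child.load()     # inferred type: Child
    print(child.run())       # attribute access type-checks without a cast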

@@ -159,7 +160,7 @@ class RegistryMixin:
  )

  @classmethod
- def load_from_registry(cls, name: str, **constructor_kwargs) -> object:
+ def load_from_registry(cls: type[T], name: str, **constructor_kwargs) -> T:
  """
  :param name: name of registered class to load
  :param constructor_kwargs: arguments to pass to the constructor retrieved

@@ -172,7 +173,7 @@ class RegistryMixin:
  return constructor(**constructor_kwargs)

  @classmethod
- def get_value_from_registry(cls, name: str):
+ def get_value_from_registry(cls: type[T], name: str) -> T:
  """
  :param name: name to retrieve from the registry
  :return: value retrieved from the registry for the given name, raises

@@ -200,7 +201,7 @@ class RegistryMixin:


  def register(
- parent_class: Type,
+ parent_class: type,
  value: Any,
  name: Optional[str] = None,
  alias: Union[List[str], str, None] = None,

@@ -240,7 +241,7 @@ def register(


  def get_from_registry(
- parent_class: Type, name: str, require_subclass: bool = False
+ parent_class: type, name: str, require_subclass: bool = False
  ) -> Any:
  """
  :param parent_class: class that the name is registered under

@@ -276,7 +277,7 @@ def get_from_registry(
  return retrieved_value


- def registered_names(parent_class: Type) -> List[str]:
+ def registered_names(parent_class: type) -> List[str]:
  """
  :param parent_class: class to look up the registry of
  :return: all names registered to the given class

@@ -284,7 +285,7 @@ def registered_names(parent_class: Type) -> List[str]:
  return list(_REGISTRY[parent_class].keys())


- def registered_aliases(parent_class: Type) -> List[str]:
+ def registered_aliases(parent_class: type) -> List[str]:
  """
  :param parent_class: class to look up the registry of
  :return: all aliases registered to the given class

@@ -297,7 +298,7 @@ def registered_aliases(parent_class: Type) -> List[str]:


  def register_alias(
- name: str, parent_class: Type, alias: Union[str, List[str], None] = None
+ name: str, parent_class: type, alias: Union[str, List[str], None] = None
  ):
  """
  Updates the mapping from the alias(es) to the given name.

@@ -352,7 +353,7 @@ def _import_and_get_value_from_module(module_path: str, value_name: str) -> Any:
  return value


- def _validate_subclass(parent_class: Type, child_class: Type):
+ def _validate_subclass(parent_class: type, child_class: type):
  if not issubclass(child_class, parent_class):
  raise ValueError(
  f"class {child_class} is not a subclass of the class it is "

compressed_tensors/version.py

@@ -17,5 +17,5 @@ __version__: str
  __version_tuple__: VERSION_TUPLE
  version_tuple: VERSION_TUPLE

- __version__ = version = '0.9.5.a20250519'
+ __version__ = version = '0.9.5.a20250521'
  __version_tuple__ = version_tuple = (0, 9, 5)

METADATA (dist-info)

@@ -1,6 +1,6 @@
  Metadata-Version: 2.4
  Name: compressed-tensors
- Version: 0.9.5a20250519
+ Version: 0.9.5a20250521
  Summary: Library for utilization of compressed safetensors of neural network models
  Home-page: https://github.com/neuralmagic/compressed-tensors
  Author: Neuralmagic, Inc.

RECORD (dist-info)

@@ -1,16 +1,16 @@
  compressed_tensors/__init__.py,sha256=UtKmifNeBCSE2TZSAfduVNNzHY-3V7bLjZ7n7RuXLOE,812
  compressed_tensors/base.py,sha256=73HYH7HY7O2roC89yG_piPFnZwrBfn_i7HmKl90SKc0,875
- compressed_tensors/version.py,sha256=nXAnufttJXt-FtZQ-qInj1Xx7rNF_ERhtqkUZcqWiEc,521
+ compressed_tensors/version.py,sha256=FJ5OPohL511E88TFF_Jipl_3ikvZ6NgmdrYxPbi2vo8,521
  compressed_tensors/compressors/__init__.py,sha256=smSygTSfcfuujRrAXDc6uZm4L_ccV1tWZewqVnOb4lM,825
  compressed_tensors/compressors/base.py,sha256=nvWsv4xEw1Tkxkxth6TmHplDYXfBeP22xWxOsZERyDY,7204
  compressed_tensors/compressors/helpers.py,sha256=OK6qxX9j3bHwF9JfIYSGMgBJe2PWjlTA3byXKCJaTIQ,5431
  compressed_tensors/compressors/model_compressors/__init__.py,sha256=5RGGPFu4YqEt_aOdFSQYFYFDjcZFJN0CsMqRtDZz3Js,666
  compressed_tensors/compressors/model_compressors/model_compressor.py,sha256=BBJd3Ei6FtqVQLBkOm80G6pSJ11IMTGuTA-FL4n6_5g,32704
  compressed_tensors/compressors/quantized_compressors/__init__.py,sha256=KvaFBL_Q84LxRGJOV035M8OBoCkAx8kOkfphswgkKWk,745
- compressed_tensors/compressors/quantized_compressors/base.py,sha256=4YWT95GIhHETI7glsk_ITrnUzzN1MhEypt-0z9eKqOI,9134
- compressed_tensors/compressors/quantized_compressors/naive_quantized.py,sha256=fd0KlkSx6bvZ3xwIkK3jEUdPSUPs56Eua4dEDOtzKW0,5150
+ compressed_tensors/compressors/quantized_compressors/base.py,sha256=n_sVSzySHUBgXt-nkLggM1DtB0aEgQmiKhTzcnQU9Dc,9266
+ compressed_tensors/compressors/quantized_compressors/naive_quantized.py,sha256=0ANDcuD8aXPqTYNPY6GnX9iS6eXJw6P0TzNV_rYS2l8,5369
  compressed_tensors/compressors/quantized_compressors/nvfp4_quantized.py,sha256=Gw-lVzk5jrKUlM5UTCiJBmhM5gHzB9mn8r298MVUbDI,6395
- compressed_tensors/compressors/quantized_compressors/pack_quantized.py,sha256=SPIHlk8ewip2LcjgkCw02K21EkfUSFSd9qQqL0Pt5eM,11162
+ compressed_tensors/compressors/quantized_compressors/pack_quantized.py,sha256=_66tQ8bxslDUdas-ULORXblPw9kdNNn1UJJU9-ZOGPY,11380
  compressed_tensors/compressors/sparse_compressors/__init__.py,sha256=Atuz-OdEgn8OCUhx7Ovd6gXdyImAI186uCR-uR0t_Nk,737
  compressed_tensors/compressors/sparse_compressors/base.py,sha256=YNZWcHjDleAlqbgRZQ6oJf44MQb_UDNvJGOqhl26uFA,8098
  compressed_tensors/compressors/sparse_compressors/dense.py,sha256=rPaxbP7P52prWNs4lGaiBbpNvsQLElFMwOrq1oBP2Yg,1733

@@ -26,19 +26,19 @@ compressed_tensors/config/sparse_bitmask.py,sha256=pZUboRNZTu6NajGOQEFExoPknak5y
  compressed_tensors/linear/__init__.py,sha256=fH6rjBYAxuwrTzBTlTjTgCYNyh6TCvCqajCz4Im4YrA,617
  compressed_tensors/linear/compressed_linear.py,sha256=1yo9RyjA0aQ--iuIknFfcSorJn43Mn4CoV-q4JlTJ_o,4052
  compressed_tensors/quantization/__init__.py,sha256=83J5bPB7PavN2TfCoW7_vEDhfYpm4TDrqYO9vdSQ5bk,760
- compressed_tensors/quantization/quant_args.py,sha256=CepGBAURFGxBzTyFXxHwsUs6wYEJ46_jPbEvJYMG0Tw,10491
+ compressed_tensors/quantization/quant_args.py,sha256=5-mq43RmbI81z9Xl9pYNv4bqIP5AIT65FgT--4ERsE8,10502
  compressed_tensors/quantization/quant_config.py,sha256=MxSUcb5dOqMN6LFyD5K2h8X0TvEtcWIAoiUJqD2dHGE,10159
  compressed_tensors/quantization/quant_scheme.py,sha256=Fx7Ma4bDlFB6OWkHKhOB6_0AOVIOPRgNE_qTwmDLSbc,6586
  compressed_tensors/quantization/lifecycle/__init__.py,sha256=_uItzFWusyV74Zco_pHLOTdE9a83cL-R-ZdyQrBkIyw,772
- compressed_tensors/quantization/lifecycle/apply.py,sha256=DOoxH4jM8r0270GGGUFOpRrgwaisiJi7TV-Q6E8qM8E,18067
+ compressed_tensors/quantization/lifecycle/apply.py,sha256=-OKZ-FFFfIIoeGTrho8lXx6HVWZQp3Xkn3Q-G0hU-CM,18294
  compressed_tensors/quantization/lifecycle/compressed.py,sha256=Fj9n66IN0EWsOAkBHg3O0GlOQpxstqjCcs0ttzMXrJ0,2296
- compressed_tensors/quantization/lifecycle/forward.py,sha256=DOWouUqfaLA4Qhg-ojVVBdhhSAlgZqFC26vZARxE0ko,12961
+ compressed_tensors/quantization/lifecycle/forward.py,sha256=WY-HY5kXY2Zs9HMpaq44bpolQUAQ1ELrNZC7GM5C4jw,14494
  compressed_tensors/quantization/lifecycle/helpers.py,sha256=C0mhy2vJ0fCjVeN4kFNhw8Eq1wkteBGHiZ36RVLThRY,944
- compressed_tensors/quantization/lifecycle/initialize.py,sha256=PaOs3WqlWZFBq9Zc2W_WImdyzSCdZIkqCP5r2jnmokw,7789
+ compressed_tensors/quantization/lifecycle/initialize.py,sha256=dWXxjYLemjmtrSnb8vyuvNoNTSm8ywmUswze3soKY4o,12041
  compressed_tensors/quantization/utils/__init__.py,sha256=VdtEmP0bvuND_IGQnyqUPc5lnFp-1_yD7StKSX4x80w,656
- compressed_tensors/quantization/utils/helpers.py,sha256=-wX0H7zVysJ67jRRCGbx6BfxbMU_1sqffTf5YUIpPiU,14391
+ compressed_tensors/quantization/utils/helpers.py,sha256=w3Ucpdog88b0MnZdJ37VzgtYi1fqrwJafYdfWPc0hTk,16852
  compressed_tensors/registry/__init__.py,sha256=FwLSNYqfIrb5JD_6OK_MT4_svvKTN_nEhpgQlQvGbjI,658
- compressed_tensors/registry/registry.py,sha256=vRcjVB1ITfSbfYUaGndBBmqhip_5vsS62weorVg0iXo,11896
+ compressed_tensors/registry/registry.py,sha256=0s15BxdGgzBv8RL4kUJCYcuDOFUh_KZYvNvLEeRqWTc,11956
  compressed_tensors/utils/__init__.py,sha256=gS4gSU2pwcAbsKj-6YMaqhm25udFy6ISYaWBf-myRSM,808
  compressed_tensors/utils/helpers.py,sha256=RrNvzD08naEjEiXdU-FdZjQVda1nQywu1hA_GCDj0vg,10415
  compressed_tensors/utils/offload.py,sha256=JNQ66_6vhSsizhlUaMgyEdBuFolYxbgUuT1mAZrCfKY,15436

@@ -46,8 +46,8 @@ compressed_tensors/utils/permutations_24.py,sha256=kx6fsfDHebx94zsSzhXGyCyuC9sVy
  compressed_tensors/utils/permute.py,sha256=V6tJLKo3Syccj-viv4F7ZKZgJeCB-hl-dK8RKI_kBwI,2355
  compressed_tensors/utils/safetensors_load.py,sha256=DMfZBuUbA6qp_BG_zIWT3ckiEE33K9ob34s-OgzReO4,12057
  compressed_tensors/utils/semi_structured_conversions.py,sha256=XKNffPum54kPASgqKzgKvyeqWPAkair2XEQXjkp7ho8,13489
- compressed_tensors-0.9.5a20250519.dist-info/licenses/LICENSE,sha256=xx0jnfkXJvxRnG63LTGOxlggYnIysveWIZ6H3PNdCrQ,11357
- compressed_tensors-0.9.5a20250519.dist-info/METADATA,sha256=9A6h2qW5-4_2UfY2lCyQSWJuu0RMUsGzvI8YteN27Dg,7004
- compressed_tensors-0.9.5a20250519.dist-info/WHEEL,sha256=Nw36Djuh_5VDukK0H78QzOX-_FQEo6V37m3nkm96gtU,91
- compressed_tensors-0.9.5a20250519.dist-info/top_level.txt,sha256=w2i-GyPs2s1UwVxvutSvN_lM22SXC2hQFBmoMcPnV7Y,19
- compressed_tensors-0.9.5a20250519.dist-info/RECORD,,
+ compressed_tensors-0.9.5a20250521.dist-info/licenses/LICENSE,sha256=xx0jnfkXJvxRnG63LTGOxlggYnIysveWIZ6H3PNdCrQ,11357
+ compressed_tensors-0.9.5a20250521.dist-info/METADATA,sha256=Xl6EbYwMlKhFyy6VXtxD2x0TsiTDG36YszGdub5wLqM,7004
+ compressed_tensors-0.9.5a20250521.dist-info/WHEEL,sha256=zaaOINJESkSfm_4HQVc5ssNzHCPXhJm0kEUakpsEHaU,91
+ compressed_tensors-0.9.5a20250521.dist-info/top_level.txt,sha256=w2i-GyPs2s1UwVxvutSvN_lM22SXC2hQFBmoMcPnV7Y,19
+ compressed_tensors-0.9.5a20250521.dist-info/RECORD,,

WHEEL (dist-info)

@@ -1,5 +1,5 @@
  Wheel-Version: 1.0
- Generator: setuptools (80.7.1)
+ Generator: setuptools (80.8.0)
  Root-Is-Purelib: true
  Tag: py3-none-any