compressed-tensors 0.5.0__py3-none-any.whl → 0.7.0__py3-none-any.whl

This diff shows the changes between two publicly released versions of the package as they appear in their public registry. It is provided for informational purposes only.
Files changed (51)
  1. compressed_tensors/__init__.py +1 -0
  2. compressed_tensors/base.py +2 -0
  3. compressed_tensors/compressors/__init__.py +6 -12
  4. compressed_tensors/compressors/base.py +137 -9
  5. compressed_tensors/compressors/helpers.py +6 -6
  6. compressed_tensors/compressors/model_compressors/__init__.py +17 -0
  7. compressed_tensors/compressors/{model_compressor.py → model_compressors/model_compressor.py} +99 -43
  8. compressed_tensors/compressors/quantized_compressors/__init__.py +18 -0
  9. compressed_tensors/compressors/{naive_quantized.py → quantized_compressors/base.py} +64 -62
  10. compressed_tensors/compressors/quantized_compressors/naive_quantized.py +140 -0
  11. compressed_tensors/compressors/quantized_compressors/pack_quantized.py +211 -0
  12. compressed_tensors/compressors/sparse_compressors/__init__.py +18 -0
  13. compressed_tensors/compressors/sparse_compressors/base.py +110 -0
  14. compressed_tensors/compressors/{dense.py → sparse_compressors/dense.py} +3 -3
  15. compressed_tensors/compressors/{sparse_bitmask.py → sparse_compressors/sparse_bitmask.py} +14 -59
  16. compressed_tensors/compressors/sparse_quantized_compressors/__init__.py +16 -0
  17. compressed_tensors/compressors/{marlin_24.py → sparse_quantized_compressors/marlin_24.py} +3 -3
  18. compressed_tensors/config/base.py +6 -1
  19. compressed_tensors/linear/__init__.py +13 -0
  20. compressed_tensors/linear/compressed_linear.py +87 -0
  21. compressed_tensors/quantization/__init__.py +1 -0
  22. compressed_tensors/quantization/cache.py +201 -0
  23. compressed_tensors/quantization/lifecycle/apply.py +63 -9
  24. compressed_tensors/quantization/lifecycle/calibration.py +7 -7
  25. compressed_tensors/quantization/lifecycle/compressed.py +3 -1
  26. compressed_tensors/quantization/lifecycle/forward.py +126 -44
  27. compressed_tensors/quantization/lifecycle/frozen.py +6 -1
  28. compressed_tensors/quantization/lifecycle/helpers.py +0 -20
  29. compressed_tensors/quantization/lifecycle/initialize.py +138 -55
  30. compressed_tensors/quantization/observers/__init__.py +1 -0
  31. compressed_tensors/quantization/observers/base.py +54 -14
  32. compressed_tensors/quantization/observers/min_max.py +8 -0
  33. compressed_tensors/quantization/observers/mse.py +162 -0
  34. compressed_tensors/quantization/quant_args.py +102 -24
  35. compressed_tensors/quantization/quant_config.py +14 -2
  36. compressed_tensors/quantization/quant_scheme.py +12 -13
  37. compressed_tensors/quantization/utils/helpers.py +44 -19
  38. compressed_tensors/utils/__init__.py +1 -0
  39. compressed_tensors/utils/helpers.py +30 -1
  40. compressed_tensors/utils/offload.py +14 -2
  41. compressed_tensors/utils/permute.py +70 -0
  42. compressed_tensors/utils/safetensors_load.py +2 -0
  43. compressed_tensors/utils/semi_structured_conversions.py +1 -0
  44. compressed_tensors/version.py +1 -1
  45. {compressed_tensors-0.5.0.dist-info → compressed_tensors-0.7.0.dist-info}/METADATA +35 -23
  46. compressed_tensors-0.7.0.dist-info/RECORD +59 -0
  47. {compressed_tensors-0.5.0.dist-info → compressed_tensors-0.7.0.dist-info}/WHEEL +1 -1
  48. compressed_tensors/compressors/pack_quantized.py +0 -219
  49. compressed_tensors-0.5.0.dist-info/RECORD +0 -48
  50. {compressed_tensors-0.5.0.dist-info → compressed_tensors-0.7.0.dist-info}/LICENSE +0 -0
  51. {compressed_tensors-0.5.0.dist-info → compressed_tensors-0.7.0.dist-info}/top_level.txt +0 -0
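The most visible structural change is the split of compressed_tensors/compressors into per-family subpackages (model_compressors, quantized_compressors, sparse_compressors, sparse_quantized_compressors). A hedged sketch of what this likely looks like from user code, assuming the top-level re-exports are kept:

```python
# Assumption: the package keeps re-exporting its main entry points, so imports that
# worked against 0.5.0 should keep working even though the implementation files moved.
from compressed_tensors.compressors import ModelCompressor

# Per the file list above, the concrete modules now live at paths such as
#   compressed_tensors/compressors/model_compressors/model_compressor.py
#   compressed_tensors/compressors/quantized_compressors/pack_quantized.py
#   compressed_tensors/compressors/sparse_compressors/sparse_bitmask.py
compressor = ModelCompressor.from_pretrained("path/to/compressed-model")  # hypothetical path
```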
compressed_tensors/quantization/lifecycle/compressed.py
@@ -49,8 +49,9 @@ def compress_quantized_weights(module: Module):
     weight = getattr(module, "weight", None)
     scale = getattr(module, "weight_scale", None)
     zero_point = getattr(module, "weight_zero_point", None)
+    g_idx = getattr(module, "weight_g_idx", None)

-    if weight is None or scale is None or zero_point is None:
+    if weight is None or scale is None:
         # no weight, scale, or ZP, nothing to do

     # mark as compressed here to maintain consistent status throughout the model
@@ -62,6 +63,7 @@ def compress_quantized_weights(module: Module):
         x=weight,
         scale=scale,
         zero_point=zero_point,
+        g_idx=g_idx,
         args=scheme.weights,
         dtype=torch.int8,
     )
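For context, the weight_g_idx parameter read here is the per-column group index produced by GPTQ-style activation ordering; when it is absent, the column-order default applies. A small illustration with hypothetical values:

```python
import torch

# Hypothetical layer with 8 weight columns and group_size=4 (two groups).
# Column j is quantized with scale[:, g_idx[j]] and zero_point[:, g_idx[j]].
g_idx = torch.tensor([1, 0, 0, 1, 0, 1, 1, 0])
```

Note that the guard now also tolerates a missing weight_zero_point, presumably to cover symmetric schemes that store no zero point.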
compressed_tensors/quantization/lifecycle/forward.py
@@ -14,9 +14,10 @@

 from functools import wraps
 from math import ceil
-from typing import Optional
+from typing import Callable, Optional

 import torch
+from compressed_tensors.quantization.cache import QuantizedKVParameterCache
 from compressed_tensors.quantization.observers.helpers import calculate_range
 from compressed_tensors.quantization.quant_args import (
     QuantizationArgs,
@@ -25,7 +26,7 @@ from compressed_tensors.quantization.quant_args import (
 )
 from compressed_tensors.quantization.quant_config import QuantizationStatus
 from compressed_tensors.quantization.quant_scheme import QuantizationScheme
-from compressed_tensors.utils import update_parameter_data
+from compressed_tensors.utils import safe_permute, update_parameter_data
 from torch.nn import Module


@@ -45,6 +46,7 @@ def quantize(
     zero_point: torch.Tensor,
     args: QuantizationArgs,
     dtype: Optional[torch.dtype] = None,
+    g_idx: Optional[torch.Tensor] = None,
 ) -> torch.Tensor:
     """
     Quantize the input tensor x using the QuantizationStrategy specified in args.
@@ -58,15 +60,9 @@ def quantize(
     :param zero_point: zero point tensor
     :param args: quantization args dictating how to quantize x
     :param dtype: optional dtype to cast the quantized output to
+    :param g_idx: optional mapping from column index to group index
     :return: fake quantized tensor
     """
-    # ensure all tensors are on the same device
-    # assumes that the target device is the input
-    # tensor's device
-    if x.device != scale.device:
-        scale = scale.to(x.device)
-    if x.device != zero_point.device:
-        zero_point = zero_point.to(x.device)

     return _process_quantization(
         x=x,
@@ -76,6 +72,7 @@ def quantize(
         dtype=dtype,
         do_quantize=True,
         do_dequantize=False,
+        g_idx=g_idx,
     )


@@ -86,6 +83,7 @@ def dequantize(
     zero_point: torch.Tensor = None,
     args: QuantizationArgs = None,
     dtype: Optional[torch.dtype] = None,
+    g_idx: Optional[torch.Tensor] = None,
 ) -> torch.Tensor:
     """
     Dequantize a quantized input tensor x_q based on the strategy specified in args. If
@@ -96,6 +94,7 @@ def dequantize(
     :param zero_point: zero point tensor
     :param args: quantization args used to quantize x_q
     :param dtype: optional dtype to cast the dequantized output to
+    :param g_idx: optional mapping from column index to group index
     :return: dequantized float tensor
     """
     if args is None:
@@ -126,6 +125,7 @@ def dequantize(
         do_quantize=False,
         do_dequantize=True,
         dtype=dtype,
+        g_idx=g_idx,
     )


@@ -135,6 +135,7 @@ def fake_quantize(
     scale: torch.Tensor,
     zero_point: torch.Tensor,
     args: QuantizationArgs,
+    g_idx: Optional[torch.Tensor] = None,
 ) -> torch.Tensor:
     """
     Fake quantize the input tensor x by quantizing then dequantizing with
@@ -147,6 +148,7 @@ def fake_quantize(
     :param scale: scale tensor
     :param zero_point: zero point tensor
     :param args: quantization args dictating how to quantize x
+    :param g_idx: optional mapping from column index to group index
     :return: fake quantized tensor
     """
     return _process_quantization(
@@ -156,6 +158,7 @@ def fake_quantize(
         args=args,
         do_quantize=True,
         do_dequantize=True,
+        g_idx=g_idx,
     )

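All three public entry points (quantize, dequantize, fake_quantize) gain the same optional g_idx argument. A minimal usage sketch, assuming the import paths visible in this diff and a hypothetical 4x8 weight with group_size=4:

```python
import torch
from compressed_tensors.quantization import QuantizationArgs
from compressed_tensors.quantization.lifecycle.forward import fake_quantize

weight = torch.randn(4, 8)                        # hypothetical weight
scale = torch.rand(4, 2) + 0.1                    # one scale per row per group
zero_point = torch.zeros(4, 2, dtype=torch.int8)
g_idx = torch.tensor([1, 0, 0, 1, 0, 1, 1, 0])    # column -> group map (activation ordering)

args = QuantizationArgs(num_bits=4, symmetric=True, strategy="group", group_size=4)
w_fq = fake_quantize(weight, scale, zero_point, args, g_idx=g_idx)  # same shape as weight
```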
 
@@ -165,20 +168,18 @@ def _process_quantization(
     scale: torch.Tensor,
     zero_point: torch.Tensor,
     args: QuantizationArgs,
+    g_idx: Optional[torch.Tensor] = None,
     dtype: Optional[torch.dtype] = None,
     do_quantize: bool = True,
     do_dequantize: bool = True,
 ) -> torch.Tensor:
-
     q_min, q_max = calculate_range(args, x.device)
     group_size = args.group_size

     if args.strategy == QuantizationStrategy.GROUP:
         output_dtype = dtype if dtype is not None else x.dtype
         output = torch.zeros_like(x).to(output_dtype)
-
-        # TODO: vectorize the for loop
-        # TODO: fix genetric assumption about the tensor size for computing group
+        columns = output.shape[1]

         # TODO: make validation step for inputs

@@ -187,23 +188,38 @@
             scale = scale.unsqueeze(1)
             zero_point = zero_point.unsqueeze(1) if zero_point is not None else None

-        columns = x.shape[1]
         if columns >= group_size:
             if columns % group_size != 0:
                 raise ValueError(
-                    "tesnor column shape must be divisble "
+                    "tensor column shape must be divisble "
                     f"by the given group_size {group_size}"
                 )
-        for i in range(ceil(columns / group_size)):
-            # scale.shape should be [nchan, ndim]
-            # sc.shape should be [nchan, 1] after unsqueeze
-            sc = scale[:, i].view(-1, 1)
-            zp = zero_point[:, i].view(-1, 1) if zero_point is not None else None

-            idx = i * group_size
+        # support column-order (default) quantization as well as other orderings
+        # such as activation ordering. Below checks if g_idx has been initialized
+        is_column_order = g_idx is None or -1 in g_idx
+        if is_column_order:
+            num_groups = int(ceil(columns / group_size))
+            group_sizes = torch.full((num_groups,), group_size, dtype=torch.int)
+
+        else:
+            group_indices, group_sizes = torch.unique(g_idx, return_counts=True)
+            group_sizes = group_sizes[torch.argsort(group_indices)]
+
+            perm = torch.argsort(g_idx)
+            x = safe_permute(x, perm, dim=1)
+
+        # TODO: experiment with vectorizing for loop for performance
+        end = 0
+        for index, group_count in enumerate(group_sizes):
+            sc = scale[:, index].view(-1, 1)
+            zp = zero_point[:, index].view(-1, 1) if zero_point is not None else None
+
+            start = end
+            end = start + group_count
             if do_quantize:
-                output[:, idx : (idx + group_size)] = _quantize(
-                    x[:, idx : (idx + group_size)],
+                output[:, start:end] = _quantize(
+                    x[:, start:end],
                     sc,
                     zp,
                     q_min,
@@ -211,13 +227,13 @@
                     args,
                     dtype=dtype,
                 )
+
             if do_dequantize:
-                input = (
-                    output[:, idx : (idx + group_size)]
-                    if do_quantize
-                    else x[:, idx : (idx + group_size)]
-                )
-                output[:, idx : (idx + group_size)] = _dequantize(input, sc, zp)
+                input = output[:, start:end] if do_quantize else x[:, start:end]
+                output[:, start:end] = _dequantize(input, sc, zp)
+
+        if not is_column_order:
+            output = safe_permute(output, torch.argsort(perm), dim=1)

     else:  # covers channel, token and tensor strategies
         if do_quantize:
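The rewritten loop is easier to follow in isolation. Below is a self-contained sketch (not the library code; the int8 rounding and the helper name are illustrative) of how the permutation and the per-group sizes returned by torch.unique interact:

```python
import torch
from math import ceil

def group_fake_quantize_sketch(x, scale, zero_point, group_size, g_idx=None):
    # Toy per-group fake quantization following the new loop structure:
    # permute columns into group order, slice variable-sized groups, then
    # undo the permutation at the end.
    output = torch.zeros_like(x)
    columns = x.shape[1]

    if g_idx is None or -1 in g_idx:             # plain column order
        num_groups = int(ceil(columns / group_size))
        group_sizes = torch.full((num_groups,), group_size, dtype=torch.int)
        perm = None
    else:                                        # activation ordering
        group_indices, group_sizes = torch.unique(g_idx, return_counts=True)
        group_sizes = group_sizes[torch.argsort(group_indices)]
        perm = torch.argsort(g_idx)
        x = x[:, perm]

    end = 0
    for index, group_count in enumerate(group_sizes):
        sc = scale[:, index].view(-1, 1)
        zp = zero_point[:, index].view(-1, 1)
        start, end = end, end + group_count
        q = torch.clamp(torch.round(x[:, start:end] / sc + zp), -128, 127)
        output[:, start:end] = (q - zp) * sc

    if perm is not None:                         # restore original column order
        output = output[:, torch.argsort(perm)]
    return output

# Hypothetical shapes: 4 rows, 8 columns, group_size=4 -> 2 groups
x = torch.randn(4, 8)
scale = torch.rand(4, 2) + 0.05
zero_point = torch.zeros(4, 2)
g_idx = torch.tensor([1, 0, 0, 1, 0, 1, 1, 0])   # column -> group map from activation ordering
y = group_fake_quantize_sketch(x, scale, zero_point, group_size=4, g_idx=g_idx)
```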
@@ -253,13 +269,15 @@ def wrap_module_forward_quantized(module: Module, scheme: QuantizationScheme):

         input_ = args[0]

+        compressed = module.quantization_status == QuantizationStatus.COMPRESSED
+
         if scheme.input_activations is not None:
             # calibrate and (fake) quantize input activations when applicable
             input_ = maybe_calibrate_or_quantize(
                 module, input_, "input", scheme.input_activations
             )

-        if scheme.weights is not None:
+        if scheme.weights is not None and not compressed:
             # calibrate and (fake) quantize weights when applicable
             unquantized_weight = self.weight.data.clone()
             self.weight.data = maybe_calibrate_or_quantize(
@@ -270,15 +288,17 @@ def wrap_module_forward_quantized(module: Module, scheme: QuantizationScheme):
         output = forward_func_orig.__get__(module, module.__class__)(
             input_, *args[1:], **kwargs
         )
-
         if scheme.output_activations is not None:
+
             # calibrate and (fake) quantize output activations when applicable
+            # kv_cache scales updated on model self_attn forward call in
+            # wrap_module_forward_quantized_attn
             output = maybe_calibrate_or_quantize(
                 module, output, "output", scheme.output_activations
             )

         # restore back to unquantized_value
-        if scheme.weights is not None:
+        if scheme.weights is not None and not compressed:
             self.weight.data = unquantized_weight

         return output
@@ -289,14 +309,63 @@ def wrap_module_forward_quantized(module: Module, scheme: QuantizationScheme):
     setattr(module, "forward", bound_wrapped_forward)


+def wrap_module_forward_quantized_attn(module: Module, scheme: QuantizationScheme):
+    # expects a module already initialized and injected with the parameters in
+    # initialize_module_for_quantization
+    if hasattr(module.forward, "__func__"):
+        forward_func_orig = module.forward.__func__
+    else:
+        forward_func_orig = module.forward.func
+
+    @wraps(forward_func_orig)  # ensures docstring, names, etc are propagated
+    def wrapped_forward(self, *args, **kwargs):
+
+        # kv cache stored under weights
+        if module.quantization_status == QuantizationStatus.CALIBRATION:
+            quantization_args: QuantizationArgs = scheme.output_activations
+            past_key_value: QuantizedKVParameterCache = quantization_args.get_kv_cache()
+            kwargs["past_key_value"] = past_key_value
+
+            # QuantizedKVParameterCache used for obtaining k_scale, v_scale only,
+            # does not store quantized_key_states and quantized_value_state
+            kwargs["use_cache"] = False
+
+            attn_forward: Callable = forward_func_orig.__get__(module, module.__class__)
+
+            past_key_value.reset_states()
+
+            rtn = attn_forward(*args, **kwargs)
+
+            update_parameter_data(
+                module, past_key_value.k_scales[module.layer_idx], "k_scale"
+            )
+            update_parameter_data(
+                module, past_key_value.v_scales[module.layer_idx], "v_scale"
+            )
+
+            return rtn
+
+        return forward_func_orig.__get__(module, module.__class__)(*args, **kwargs)
+
+    # bind wrapped forward to module class so reference to `self` is correct
+    bound_wrapped_forward = wrapped_forward.__get__(module, module.__class__)
+    # set forward to wrapped forward
+    setattr(module, "forward", bound_wrapped_forward)
+
+
 def maybe_calibrate_or_quantize(
     module: Module, value: torch.Tensor, base_name: str, args: "QuantizationArgs"
 ) -> torch.Tensor:
-    # only run quantized for the included stages
-    if module.quantization_status not in {
-        QuantizationStatus.CALIBRATION,
-        QuantizationStatus.FROZEN,
-    }:
+    # don't run quantization if we haven't entered calibration mode
+    if module.quantization_status == QuantizationStatus.INITIALIZED:
+        return value
+
+    # in compressed mode, the weight is already compressed and quantized so we don't
+    # need to run fake quantization
+    if (
+        module.quantization_status == QuantizationStatus.COMPRESSED
+        and base_name == "weight"
+    ):
         return value

     if value.numel() == 0:
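Both wrap_module_forward_quantized and the new attention wrapper rely on the same descriptor trick: pull the underlying function out of the bound method, wrap it, and re-bind the wrapper to the instance via __get__ so self resolves correctly. A standalone toy illustration (not library code):

```python
from functools import wraps

class Toy:
    def forward(self, x):
        return x + 1

toy = Toy()
orig = toy.forward.__func__            # unbound function backing the method

@wraps(orig)                           # propagate name/docstring to the wrapper
def wrapped_forward(self, x):
    return orig(self, x) * 10          # call through to the original, then post-process

toy.forward = wrapped_forward.__get__(toy, Toy)   # re-bind so `self` is `toy`
assert toy.forward(1) == 20
```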
@@ -304,14 +373,16 @@ def maybe_calibrate_or_quantize(
         # skip quantization
         return value

+    g_idx = getattr(module, "weight_g_idx", None)
+
     if args.dynamic:
         # dynamic quantization - get scale and zero point directly from observer
         observer = getattr(module, f"{base_name}_observer")
-        scale, zero_point = observer(value)
+        scale, zero_point = observer(value, g_idx=g_idx)
     else:
         # static quantization - get previous scale and zero point from layer
         scale = getattr(module, f"{base_name}_scale")
-        zero_point = getattr(module, f"{base_name}_zero_point")
+        zero_point = getattr(module, f"{base_name}_zero_point", None)

     if (
         module.quantization_status == QuantizationStatus.CALIBRATION
@@ -320,13 +391,22 @@ def maybe_calibrate_or_quantize(
         # calibration mode - get new quant params from observer
         observer = getattr(module, f"{base_name}_observer")

-        updated_scale, updated_zero_point = observer(value)
+        updated_scale, updated_zero_point = observer(value, g_idx=g_idx)

         # update scale and zero point
         update_parameter_data(module, updated_scale, f"{base_name}_scale")
         update_parameter_data(module, updated_zero_point, f"{base_name}_zero_point")

-    return fake_quantize(value, scale, zero_point, args)
+        scale = updated_scale
+        zero_point = updated_zero_point
+
+    return fake_quantize(
+        x=value,
+        scale=scale,
+        zero_point=zero_point,
+        args=args,
+        g_idx=g_idx,
+    )


 @torch.no_grad()
@@ -340,7 +420,9 @@ def _quantize(
     dtype: Optional[torch.dtype] = None,
 ) -> torch.Tensor:

-    scaled = x / scale + zero_point.to(x.dtype)
+    scaled = x / scale
+    if zero_point is not None:
+        scaled += zero_point.to(x.dtype)
     # clamp first because cast isn't guaranteed to be saturated (ie for fp8)
     clamped_value = torch.clamp(
         scaled,
@@ -361,11 +443,11 @@ def _dequantize(
     zero_point: torch.Tensor = None,
     dtype: Optional[torch.dtype] = None,
 ) -> torch.Tensor:
+    dequant_value = x_q.to(scale.dtype)

-    dequant_value = x_q
     if zero_point is not None:
         dequant_value = dequant_value - zero_point.to(scale.dtype)
-    dequant_value = dequant_value.to(scale.dtype) * scale
+    dequant_value = dequant_value * scale

     if dtype is not None:
         dequant_value = dequant_value.to(dtype)
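Taken together, the _quantize/_dequantize hunks make the zero point optional and move the dtype cast up front. A condensed, runnable sketch of the resulting arithmetic (the library's rounding/cast helper is simplified to torch.round here):

```python
import torch

def _quantize_sketch(x, scale, q_min, q_max, zero_point=None):
    # mirrors the updated _quantize: the zero point is applied only when present
    scaled = x / scale
    if zero_point is not None:
        scaled += zero_point.to(x.dtype)
    # clamp before rounding/casting, per the "cast isn't guaranteed to be saturated" note
    return torch.round(torch.clamp(scaled, q_min, q_max))

def _dequantize_sketch(x_q, scale, zero_point=None):
    # mirrors the updated _dequantize: cast to the scale dtype up front, then shift and scale
    dq = x_q.to(scale.dtype)
    if zero_point is not None:
        dq = dq - zero_point.to(scale.dtype)
    return dq * scale

w = torch.randn(2, 4)
s = torch.full((2, 1), 0.1)
w_q = _quantize_sketch(w, s, q_min=-128, q_max=127)
w_dq = _dequantize_sketch(w_q, s)    # approximately recovers w, up to rounding error
```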
compressed_tensors/quantization/lifecycle/frozen.py
@@ -14,6 +14,7 @@


 from compressed_tensors.quantization.quant_config import QuantizationStatus
+from compressed_tensors.quantization.utils import is_kv_cache_quant_scheme
 from torch.nn import Module


@@ -44,7 +45,11 @@ def freeze_module_quantization(module: Module):
         delattr(module, "input_observer")
     if scheme.weights and not scheme.weights.dynamic:
         delattr(module, "weight_observer")
-    if scheme.output_activations and not scheme.output_activations.dynamic:
+    if (
+        scheme.output_activations
+        and not is_kv_cache_quant_scheme(scheme)
+        and not scheme.output_activations.dynamic
+    ):
         delattr(module, "output_observer")

     module.quantization_status = QuantizationStatus.FROZEN
compressed_tensors/quantization/lifecycle/helpers.py
@@ -16,35 +16,15 @@
 Miscelaneous helpers for the quantization lifecycle
 """

-
 from torch.nn import Module


 __all__ = [
-    "update_layer_weight_quant_params",
     "enable_quantization",
     "disable_quantization",
 ]


-def update_layer_weight_quant_params(layer: Module):
-    weight = getattr(layer, "weight", None)
-    scale = getattr(layer, "weight_scale", None)
-    zero_point = getattr(layer, "weight_zero_point", None)
-    observer = getattr(layer, "weight_observer", None)
-
-    if weight is None or observer is None or scale is None or zero_point is None:
-        # scale, zp, or observer not calibratable or weight not available
-        return
-
-    updated_scale, updated_zero_point = observer(weight)
-
-    # update scale and zero point
-    device = next(layer.parameters()).device
-    scale.data = updated_scale.to(device)
-    zero_point.data = updated_zero_point.to(device)
-
-
 def enable_quantization(module: Module):
     module.quantization_enabled = True
