compressed-tensors 0.6.0__py3-none-any.whl → 0.7.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (41)
  1. compressed_tensors/__init__.py +1 -0
  2. compressed_tensors/base.py +2 -0
  3. compressed_tensors/compressors/__init__.py +6 -12
  4. compressed_tensors/compressors/base.py +38 -102
  5. compressed_tensors/compressors/helpers.py +6 -6
  6. compressed_tensors/compressors/model_compressors/__init__.py +17 -0
  7. compressed_tensors/compressors/{model_compressor.py → model_compressors/model_compressor.py} +95 -106
  8. compressed_tensors/compressors/quantized_compressors/__init__.py +18 -0
  9. compressed_tensors/compressors/quantized_compressors/base.py +146 -0
  10. compressed_tensors/compressors/{naive_quantized.py → quantized_compressors/naive_quantized.py} +11 -11
  11. compressed_tensors/compressors/{pack_quantized.py → quantized_compressors/pack_quantized.py} +6 -3
  12. compressed_tensors/compressors/sparse_compressors/__init__.py +18 -0
  13. compressed_tensors/compressors/sparse_compressors/base.py +110 -0
  14. compressed_tensors/compressors/{dense.py → sparse_compressors/dense.py} +3 -3
  15. compressed_tensors/compressors/{sparse_bitmask.py → sparse_compressors/sparse_bitmask.py} +14 -59
  16. compressed_tensors/compressors/sparse_quantized_compressors/__init__.py +16 -0
  17. compressed_tensors/compressors/{marlin_24.py → sparse_quantized_compressors/marlin_24.py} +3 -3
  18. compressed_tensors/linear/compressed_linear.py +2 -2
  19. compressed_tensors/quantization/__init__.py +1 -0
  20. compressed_tensors/quantization/cache.py +201 -0
  21. compressed_tensors/quantization/lifecycle/apply.py +19 -3
  22. compressed_tensors/quantization/lifecycle/calibration.py +2 -3
  23. compressed_tensors/quantization/lifecycle/forward.py +58 -7
  24. compressed_tensors/quantization/lifecycle/frozen.py +6 -1
  25. compressed_tensors/quantization/lifecycle/helpers.py +0 -47
  26. compressed_tensors/quantization/lifecycle/initialize.py +116 -67
  27. compressed_tensors/quantization/observers/__init__.py +0 -1
  28. compressed_tensors/quantization/observers/helpers.py +40 -2
  29. compressed_tensors/quantization/quant_args.py +34 -4
  30. compressed_tensors/quantization/quant_config.py +14 -2
  31. compressed_tensors/quantization/quant_scheme.py +8 -4
  32. compressed_tensors/quantization/utils/helpers.py +43 -18
  33. compressed_tensors/utils/helpers.py +17 -1
  34. compressed_tensors/version.py +1 -1
  35. {compressed_tensors-0.6.0.dist-info → compressed_tensors-0.7.1.dist-info}/METADATA +1 -1
  36. compressed_tensors-0.7.1.dist-info/RECORD +58 -0
  37. compressed_tensors/quantization/observers/memoryless.py +0 -56
  38. compressed_tensors-0.6.0.dist-info/RECORD +0 -52
  39. {compressed_tensors-0.6.0.dist-info → compressed_tensors-0.7.1.dist-info}/LICENSE +0 -0
  40. {compressed_tensors-0.6.0.dist-info → compressed_tensors-0.7.1.dist-info}/WHEEL +0 -0
  41. {compressed_tensors-0.6.0.dist-info → compressed_tensors-0.7.1.dist-info}/top_level.txt +0 -0
compressed_tensors/quantization/cache.py
@@ -0,0 +1,201 @@
+# Copyright (c) 2021 - present / Neuralmagic, Inc. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#    http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+
+from enum import Enum
+from typing import Any, Dict, List, Optional, Tuple
+
+from compressed_tensors.quantization.observers import Observer
+from compressed_tensors.quantization.quant_args import QuantizationArgs
+from torch import Tensor
+from transformers import DynamicCache as HFDyanmicCache
+
+
+class KVCacheScaleType(Enum):
+    KEY = "k_scale"
+    VALUE = "v_scale"
+
+
+class QuantizedKVParameterCache(HFDyanmicCache):
+
+    """
+    Quantized KV cache used in the forward call based on HF's dynamic cache.
+    Quantization strategy (tensor, group, channel) set from Quantization arg's strategy
+    Singleton, so that the same cache gets reused in all forward call of self_attn.
+    Each time forward is called, .update() is called, and ._quantize(), ._dequantize()
+    gets called appropriately.
+    The size of tensor is
+    `[batch_size, num_heads, seq_len - residual_length, head_dim]`.
+
+
+    Triggered by adding kv_cache_scheme in the recipe.
+
+    Example:
+
+    ```python3
+    recipe = '''
+    quant_stage:
+        quant_modifiers:
+            QuantizationModifier:
+                kv_cache_scheme:
+                    num_bits: 8
+                    type: float
+                    strategy: tensor
+                    dynamic: false
+                    symmetric: true
+    '''
+
+    """
+
+    _instance = None
+    _initialized = False
+
+    def __new__(cls, *args, **kwargs):
+        """Singleton"""
+        if cls._instance is None:
+            cls._instance = super(QuantizedKVParameterCache, cls).__new__(cls)
+        return cls._instance
+
+    def __init__(self, quantization_args: QuantizationArgs):
+        if not self._initialized:
+            super().__init__()
+
+            self.quantization_args = quantization_args
+
+            self.k_observers: List[Observer] = []
+            self.v_observers: List[Observer] = []
+
+            # each index corresponds to layer_idx of the attention layer
+            self.k_scales: List[Tensor] = []
+            self.v_scales: List[Tensor] = []
+
+            self.k_zps: List[Tensor] = []
+            self.v_zps: List[Tensor] = []
+
+            self._initialized = True
+
+    def update(
+        self,
+        key_states: Tensor,
+        value_states: Tensor,
+        layer_idx: int,
+        cache_kwargs: Optional[Dict[str, Any]] = None,
+    ) -> Tuple[Tensor, Tensor]:
+        """
+        Get the k_scale and v_scale and output the
+        fakequant-ed key_states and value_states
+        """
+
+        if len(self.k_observers) <= layer_idx:
+            k_observer = self.quantization_args.get_observer()
+            v_observer = self.quantization_args.get_observer()
+
+            self.k_observers.append(k_observer)
+            self.v_observers.append(v_observer)
+
+        q_key_states = self._quantize(
+            key_states.contiguous(), KVCacheScaleType.KEY, layer_idx
+        )
+        q_value_states = self._quantize(
+            value_states.contiguous(), KVCacheScaleType.VALUE, layer_idx
+        )
+
+        qdq_key_states = self._dequantize(q_key_states, KVCacheScaleType.KEY, layer_idx)
+        qdq_value_states = self._dequantize(
+            q_value_states, KVCacheScaleType.VALUE, layer_idx
+        )
+
+        keys_to_return, values_to_return = qdq_key_states, qdq_value_states
+
+        return keys_to_return, values_to_return
+
+    def get_seq_length(self, layer_idx: Optional[int] = 0) -> int:
+        """
+        Returns the sequence length of the cached states.
+        A layer index can be optionally passed.
+        """
+        if len(self.key_cache) <= layer_idx:
+            return 0
+        # since we cannot get the seq_length of each layer directly and
+        # rely on `_seen_tokens` which is updated every "layer_idx" == 0,
+        # this is a hack to get the actual seq_length for the given layer_idx
+        # this part of code otherwise fails when used to
+        # verify attn_weight shape in some models
+        return self._seen_tokens if layer_idx == 0 else self._seen_tokens - 1
+
+    def reset_states(self):
+        """reset the kv states (used in calibration)"""
+        self.key_cache: List[Tensor] = []
+        self.value_cache: List[Tensor] = []
+        # Used in `generate` to keep tally of how many tokens the cache has seen
+        self._seen_tokens = 0
+        self._quantized_key_cache: List[Tensor] = []
+        self._quantized_value_cache: List[Tensor] = []
+
+    def reset(self):
+        """
+        Reset the instantiation, create new instance on init
+        """
+        QuantizedKVParameterCache._instance = None
+        QuantizedKVParameterCache._initialized = False
+
+    def _quantize(self, tensor, kv_type, layer_idx):
+        """Quantizes a key/value using a defined quantization method."""
+        from compressed_tensors.quantization.lifecycle.forward import quantize
+
+        if kv_type == KVCacheScaleType.KEY:  # key type
+            observer = self.k_observers[layer_idx]
+            scales = self.k_scales
+            zps = self.k_zps
+        else:
+            assert kv_type == KVCacheScaleType.VALUE
+            observer = self.v_observers[layer_idx]
+            scales = self.v_scales
+            zps = self.v_zps
+
+        scale, zp = observer(tensor)
+        if len(scales) <= layer_idx:
+            scales.append(scale)
+            zps.append(zp)
+        else:
+            scales[layer_idx] = scale
+            zps[layer_idx] = scale
+
+        q_tensor = quantize(
+            x=tensor,
+            scale=scale,
+            zero_point=zp,
+            args=self.quantization_args,
+        )
+        return q_tensor
+
+    def _dequantize(self, qtensor, kv_type, layer_idx):
+        """Dequantizes back the tensor that was quantized by `self._quantize()`"""
+        from compressed_tensors.quantization.lifecycle.forward import dequantize
+
+        if kv_type == KVCacheScaleType.KEY:
+            scale = self.k_scales[layer_idx]
+            zp = self.k_zps[layer_idx]
+        else:
+            assert kv_type == KVCacheScaleType.VALUE
+            scale = self.v_scales[layer_idx]
+            zp = self.v_zps[layer_idx]
+
+        qdq_tensor = dequantize(
+            x_q=qtensor,
+            scale=scale,
+            zero_point=zp,
+            args=self.quantization_args,
+        )
+        return qdq_tensor
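The new `QuantizedKVParameterCache` above can be exercised on its own. The sketch below is illustrative only and not part of the diff: it assumes a `QuantizationArgs` configured like the recipe in the class docstring, feeds in random key/value tensors, and checks that the singleton returns fake-quantized states of the same shape while recording one `k_scale`/`v_scale` per layer.

```python
# Illustrative sketch (not part of the package): exercising QuantizedKVParameterCache
# standalone with a tensor-strategy FP8 config matching the docstring recipe.
import torch

from compressed_tensors.quantization.cache import QuantizedKVParameterCache
from compressed_tensors.quantization.quant_args import QuantizationArgs

args = QuantizationArgs(
    num_bits=8, type="float", strategy="tensor", dynamic=False, symmetric=True
)

cache = QuantizedKVParameterCache(args)
assert cache is QuantizedKVParameterCache(args)  # singleton: same instance is reused

# [batch_size, num_heads, seq_len, head_dim]
key_states = torch.randn(1, 8, 16, 64)
value_states = torch.randn(1, 8, 16, 64)

qdq_keys, qdq_values = cache.update(key_states, value_states, layer_idx=0)
assert qdq_keys.shape == key_states.shape  # fake-quantized, same shape
assert len(cache.k_scales) == len(cache.v_scales) == 1  # one scale per layer so far

cache.reset()  # drop the singleton so a fresh cache is created on the next init
```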
compressed_tensors/quantization/lifecycle/apply.py
@@ -43,6 +43,7 @@ from compressed_tensors.quantization.utils import (
     infer_quantization_status,
     is_kv_cache_quant_scheme,
     iter_named_leaf_modules,
+    iter_named_quantizable_modules,
 )
 from compressed_tensors.utils.helpers import fix_fsdp_module_name, replace_module
 from compressed_tensors.utils.offload import update_parameter_data
@@ -106,8 +107,8 @@ def load_pretrained_quantization(model: Module, model_name_or_path: str):
 
 
 def apply_quantization_config(
-    model: Module, config: QuantizationConfig, run_compressed: bool = False
-) -> Dict:
+    model: Module, config: Union[QuantizationConfig, None], run_compressed: bool = False
+) -> OrderedDict:
     """
     Initializes the model for quantization in-place based on the given config
 
@@ -116,6 +117,10 @@ def apply_quantization_config(
     :param run_compressed: Whether the model will be run in compressed mode or
         decompressed fully on load
     """
+    # Workaround for when HF Quantizer passes None, see PR #180
+    if config is None:
+        return OrderedDict()
+
     # remove reference to the original `config`
     # argument. This function can mutate it, and we'd
     # like to keep the original `config` as it is.
@@ -135,15 +140,23 @@
     # list of submodules to ignore
     ignored_submodules = defaultdict(list)
     # mark appropriate layers for quantization by setting their quantization schemes
-    for name, submodule in iter_named_leaf_modules(model):
+    for name, submodule in iter_named_quantizable_modules(
+        model,
+        include_children=True,
+        include_attn=True,
+    ):  # child modules and attention modules
         # potentially fix module name to remove FSDP wrapper prefix
         name = fix_fsdp_module_name(name)
         if matches := find_name_or_class_matches(name, submodule, config.ignore):
             for match in matches:
                 ignored_submodules[match].append(name)
             continue  # layer matches ignore list, continue
+
         targets = find_name_or_class_matches(name, submodule, target_to_scheme)
+
         if targets:
+            # mark modules to be quantized by adding
+            # quant scheme to the matching layers
            scheme = _scheme_from_targets(target_to_scheme, targets, name)
            if run_compressed:
                format = config.format
@@ -200,6 +213,9 @@ def process_kv_cache_config(
     :param config: the QuantizationConfig
     :return: the QuantizationConfig with additional "kv_cache" group
     """
+    if targets == KV_CACHE_TARGETS:
+        _LOGGER.info(f"KV cache targets set to default value of: {KV_CACHE_TARGETS}")
+
     kv_cache_dict = config.kv_cache_scheme.model_dump()
     kv_cache_scheme = QuantizationScheme(
         output_activations=QuantizationArgs(**kv_cache_dict),
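For reference, `process_kv_cache_config` turns the `kv_cache_scheme` block into an ordinary output-activation scheme attached to a "kv_cache" group. A hand-built equivalent is sketched below; it is illustrative only, the real default targets come from `KV_CACHE_TARGETS` in apply.py, and the regex pattern used here is an assumption.

```python
from compressed_tensors.quantization.quant_args import QuantizationArgs
from compressed_tensors.quantization.quant_scheme import QuantizationScheme

# roughly what config.kv_cache_scheme.model_dump() yields for the docstring recipe
kv_cache_dict = {
    "num_bits": 8,
    "type": "float",
    "strategy": "tensor",
    "dynamic": False,
    "symmetric": True,
}

# the kv cache is modeled as quantized output activations on the attention modules;
# the targets pattern below is an assumed placeholder for KV_CACHE_TARGETS
kv_cache_scheme = QuantizationScheme(
    targets=["re:.*self_attn$"],
    output_activations=QuantizationArgs(**kv_cache_dict),
)
```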
compressed_tensors/quantization/lifecycle/calibration.py
@@ -56,10 +56,9 @@ def set_module_for_calibration(module: Module, quantize_weights_upfront: bool =
         observer = module.weight_observer
         g_idx = getattr(module, "weight_g_idx", None)
 
-        offloaded = False
-        if is_module_offloaded(module):
+        offloaded = is_module_offloaded(module)
+        if offloaded:
             module._hf_hook.pre_forward(module)
-            offloaded = True
 
         scale, zero_point = observer(module.weight, g_idx=g_idx)
         update_parameter_data(module, scale, "weight_scale")
compressed_tensors/quantization/lifecycle/forward.py
@@ -14,10 +14,14 @@
 
 from functools import wraps
 from math import ceil
-from typing import Optional
+from typing import Callable, Optional
 
 import torch
-from compressed_tensors.quantization.observers.helpers import calculate_range
+from compressed_tensors.quantization.cache import QuantizedKVParameterCache
+from compressed_tensors.quantization.observers.helpers import (
+    calculate_range,
+    compute_dynamic_scales_and_zp,
+)
 from compressed_tensors.quantization.quant_args import (
     QuantizationArgs,
     QuantizationStrategy,
@@ -62,6 +66,7 @@ def quantize(
     :param g_idx: optional mapping from column index to group index
     :return: fake quantized tensor
     """
+
     return _process_quantization(
         x=x,
         scale=scale,
@@ -165,8 +170,8 @@
     x: torch.Tensor,
     scale: torch.Tensor,
     zero_point: torch.Tensor,
-    g_idx: Optional[torch.Tensor],
     args: QuantizationArgs,
+    g_idx: Optional[torch.Tensor] = None,
     dtype: Optional[torch.dtype] = None,
     do_quantize: bool = True,
     do_dequantize: bool = True,
@@ -266,6 +271,7 @@ def wrap_module_forward_quantized(module: Module, scheme: QuantizationScheme):
             return forward_func_orig.__get__(module, module.__class__)(*args, **kwargs)
 
         input_ = args[0]
+
         compressed = module.quantization_status == QuantizationStatus.COMPRESSED
 
         if scheme.input_activations is not None:
@@ -285,9 +291,11 @@ def wrap_module_forward_quantized(module: Module, scheme: QuantizationScheme):
         output = forward_func_orig.__get__(module, module.__class__)(
             input_, *args[1:], **kwargs
         )
-
         if scheme.output_activations is not None:
+
             # calibrate and (fake) quantize output activations when applicable
+            # kv_cache scales updated on model self_attn forward call in
+            # wrap_module_forward_quantized_attn
             output = maybe_calibrate_or_quantize(
                 module, output, "output", scheme.output_activations
             )
@@ -304,6 +312,50 @@ def wrap_module_forward_quantized(module: Module, scheme: QuantizationScheme):
     setattr(module, "forward", bound_wrapped_forward)
 
 
+def wrap_module_forward_quantized_attn(module: Module, scheme: QuantizationScheme):
+    # expects a module already initialized and injected with the parameters in
+    # initialize_module_for_quantization
+    if hasattr(module.forward, "__func__"):
+        forward_func_orig = module.forward.__func__
+    else:
+        forward_func_orig = module.forward.func
+
+    @wraps(forward_func_orig)  # ensures docstring, names, etc are propagated
+    def wrapped_forward(self, *args, **kwargs):
+
+        # kv cache stored under weights
+        if module.quantization_status == QuantizationStatus.CALIBRATION:
+            quantization_args: QuantizationArgs = scheme.output_activations
+            past_key_value: QuantizedKVParameterCache = quantization_args.get_kv_cache()
+            kwargs["past_key_value"] = past_key_value
+
+            # QuantizedKVParameterCache used for obtaining k_scale, v_scale only,
+            # does not store quantized_key_states and quantized_value_state
+            kwargs["use_cache"] = False
+
+            attn_forward: Callable = forward_func_orig.__get__(module, module.__class__)
+
+            past_key_value.reset_states()
+
+            rtn = attn_forward(*args, **kwargs)
+
+            update_parameter_data(
+                module, past_key_value.k_scales[module.layer_idx], "k_scale"
+            )
+            update_parameter_data(
+                module, past_key_value.v_scales[module.layer_idx], "v_scale"
+            )
+
+            return rtn
+
+        return forward_func_orig.__get__(module, module.__class__)(*args, **kwargs)
+
+    # bind wrapped forward to module class so reference to `self` is correct
+    bound_wrapped_forward = wrapped_forward.__get__(module, module.__class__)
+    # set forward to wrapped forward
+    setattr(module, "forward", bound_wrapped_forward)
+
+
 def maybe_calibrate_or_quantize(
     module: Module, value: torch.Tensor, base_name: str, args: "QuantizationArgs"
 ) -> torch.Tensor:
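The new `wrap_module_forward_quantized_attn` relies on the same instance-level forward-binding trick as the existing `wrap_module_forward_quantized`: take the underlying function off the bound method, wrap it, then re-bind the wrapper to the instance so `self` still resolves correctly. A minimal standalone illustration of that pattern follows (a toy class unrelated to the package):

```python
class Toy:
    def forward(self, x: int) -> int:
        return x + 1


def wrap_forward(module: Toy) -> None:
    # unwrap the bound method back to the plain function
    forward_func_orig = module.forward.__func__

    def wrapped_forward(self, x: int) -> int:
        # pre/post processing would go here (e.g. injecting a kv cache)
        out = forward_func_orig.__get__(self, self.__class__)(x)
        return out * 2

    # bind the wrapper to this instance so `self` is the module, then shadow forward
    module.forward = wrapped_forward.__get__(module, module.__class__)


toy = Toy()
wrap_forward(toy)
print(toy.forward(1))  # 4: original forward returns 2, wrapper doubles it
```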
compressed_tensors/quantization/lifecycle/forward.py
@@ -327,9 +379,8 @@
         g_idx = getattr(module, "weight_g_idx", None)
 
     if args.dynamic:
-        # dynamic quantization - get scale and zero point directly from observer
-        observer = getattr(module, f"{base_name}_observer")
-        scale, zero_point = observer(value, g_idx=g_idx)
+        # dynamic quantization - no need to invoke observer
+        scale, zero_point = compute_dynamic_scales_and_zp(value=value, args=args)
     else:
         # static quantization - get previous scale and zero point from layer
         scale = getattr(module, f"{base_name}_scale")
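With this change, dynamic quantization derives scale and zero point directly from the current activation instead of routing through a per-module observer. The helper itself lives in `observers/helpers.py`; the snippet below is only a rough, hypothetical approximation of per-tensor min/max qparam computation, not the library's actual `compute_dynamic_scales_and_zp`.

```python
import torch


def dynamic_tensor_qparams(value: torch.Tensor, q_min: float, q_max: float):
    # derive qparams from the observed min/max of this activation only;
    # nothing is stored on the module, so no observer state is needed
    min_val = torch.clamp(value.min(), max=0.0)
    max_val = torch.clamp(value.max(), min=0.0)
    scale = (max_val - min_val) / float(q_max - q_min)
    scale = torch.clamp(scale, min=torch.finfo(torch.float32).eps)
    zero_point = torch.clamp(torch.round(q_min - min_val / scale), q_min, q_max)
    return scale, zero_point


activation = torch.randn(4, 16)
scale, zero_point = dynamic_tensor_qparams(activation, q_min=-128, q_max=127)
```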
compressed_tensors/quantization/lifecycle/frozen.py
@@ -14,6 +14,7 @@
 
 
 from compressed_tensors.quantization.quant_config import QuantizationStatus
+from compressed_tensors.quantization.utils import is_kv_cache_quant_scheme
 from torch.nn import Module
 
 
@@ -44,7 +45,11 @@ def freeze_module_quantization(module: Module):
         delattr(module, "input_observer")
     if scheme.weights and not scheme.weights.dynamic:
         delattr(module, "weight_observer")
-    if scheme.output_activations and not scheme.output_activations.dynamic:
+    if (
+        scheme.output_activations
+        and not is_kv_cache_quant_scheme(scheme)
+        and not scheme.output_activations.dynamic
+    ):
         delattr(module, "output_observer")
 
     module.quantization_status = QuantizationStatus.FROZEN
compressed_tensors/quantization/lifecycle/helpers.py
@@ -16,62 +16,15 @@
 Miscelaneous helpers for the quantization lifecycle
 """
 
-from typing import Optional
-
-import torch
 from torch.nn import Module
 
 
 __all__ = [
-    "update_layer_weight_quant_params",
     "enable_quantization",
     "disable_quantization",
 ]
 
 
-def update_layer_weight_quant_params(
-    layer: Module,
-    weight: Optional[torch.Tensor] = None,
-    g_idx: Optional[torch.Tensor] = None,
-    reset_obs: bool = False,
-):
-    """
-    Update quantization parameters on layer
-
-    :param layer: input layer
-    :param weight: weight to update quant params with, defaults to layer weight
-    :param g_idx: optional mapping from column index to group index
-    :param reset_obs: reset the observer before calculating quant params,
-        defaults to False
-    """
-    attached_weight = getattr(layer, "weight", None)
-
-    if weight is None:
-        weight = attached_weight
-    scale = getattr(layer, "weight_scale", None)
-    zero_point = getattr(layer, "weight_zero_point", None)
-    if g_idx is None:
-        g_idx = getattr(layer, "weight_g_idx", None)
-    observer = getattr(layer, "weight_observer", None)
-
-    if weight is None or observer is None or scale is None or zero_point is None:
-        # scale, zp, or observer not calibratable or weight not available
-        return
-
-    if reset_obs:
-        observer.reset()
-
-    if attached_weight is not None:
-        weight = weight.to(attached_weight.dtype)
-
-    updated_scale, updated_zero_point = observer(weight)
-
-    # update scale and zero point
-    device = next(layer.parameters()).device
-    scale.data = updated_scale.to(device)
-    zero_point.data = updated_zero_point.to(device)
-
-
 def enable_quantization(module: Module):
     module.quantization_enabled = True