compressed-tensors 0.6.0__py3-none-any.whl → 0.7.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (38)
  1. compressed_tensors/__init__.py +1 -0
  2. compressed_tensors/base.py +2 -0
  3. compressed_tensors/compressors/__init__.py +6 -12
  4. compressed_tensors/compressors/base.py +38 -102
  5. compressed_tensors/compressors/helpers.py +6 -6
  6. compressed_tensors/compressors/model_compressors/__init__.py +17 -0
  7. compressed_tensors/compressors/{model_compressor.py → model_compressors/model_compressor.py} +95 -106
  8. compressed_tensors/compressors/quantized_compressors/__init__.py +18 -0
  9. compressed_tensors/compressors/quantized_compressors/base.py +146 -0
  10. compressed_tensors/compressors/{naive_quantized.py → quantized_compressors/naive_quantized.py} +11 -11
  11. compressed_tensors/compressors/{pack_quantized.py → quantized_compressors/pack_quantized.py} +6 -3
  12. compressed_tensors/compressors/sparse_compressors/__init__.py +18 -0
  13. compressed_tensors/compressors/sparse_compressors/base.py +110 -0
  14. compressed_tensors/compressors/{dense.py → sparse_compressors/dense.py} +3 -3
  15. compressed_tensors/compressors/{sparse_bitmask.py → sparse_compressors/sparse_bitmask.py} +14 -59
  16. compressed_tensors/compressors/sparse_quantized_compressors/__init__.py +16 -0
  17. compressed_tensors/compressors/{marlin_24.py → sparse_quantized_compressors/marlin_24.py} +3 -3
  18. compressed_tensors/linear/compressed_linear.py +2 -2
  19. compressed_tensors/quantization/__init__.py +1 -0
  20. compressed_tensors/quantization/cache.py +201 -0
  21. compressed_tensors/quantization/lifecycle/apply.py +19 -3
  22. compressed_tensors/quantization/lifecycle/calibration.py +2 -3
  23. compressed_tensors/quantization/lifecycle/forward.py +52 -3
  24. compressed_tensors/quantization/lifecycle/frozen.py +6 -1
  25. compressed_tensors/quantization/lifecycle/helpers.py +0 -47
  26. compressed_tensors/quantization/lifecycle/initialize.py +110 -62
  27. compressed_tensors/quantization/quant_args.py +6 -0
  28. compressed_tensors/quantization/quant_config.py +14 -2
  29. compressed_tensors/quantization/quant_scheme.py +5 -4
  30. compressed_tensors/quantization/utils/helpers.py +43 -18
  31. compressed_tensors/utils/helpers.py +17 -1
  32. compressed_tensors/version.py +1 -1
  33. {compressed_tensors-0.6.0.dist-info → compressed_tensors-0.7.0.dist-info}/METADATA +1 -1
  34. compressed_tensors-0.7.0.dist-info/RECORD +59 -0
  35. compressed_tensors-0.6.0.dist-info/RECORD +0 -52
  36. {compressed_tensors-0.6.0.dist-info → compressed_tensors-0.7.0.dist-info}/LICENSE +0 -0
  37. {compressed_tensors-0.6.0.dist-info → compressed_tensors-0.7.0.dist-info}/WHEEL +0 -0
  38. {compressed_tensors-0.6.0.dist-info → compressed_tensors-0.7.0.dist-info}/top_level.txt +0 -0
@@ -0,0 +1,201 @@
+ # Copyright (c) 2021 - present / Neuralmagic, Inc. All Rights Reserved.
+ #
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ #    http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing,
+ # software distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+
+
+ from enum import Enum
+ from typing import Any, Dict, List, Optional, Tuple
+
+ from compressed_tensors.quantization.observers import Observer
+ from compressed_tensors.quantization.quant_args import QuantizationArgs
+ from torch import Tensor
+ from transformers import DynamicCache as HFDynamicCache
+
+
+ class KVCacheScaleType(Enum):
+     KEY = "k_scale"
+     VALUE = "v_scale"
+
+
+ class QuantizedKVParameterCache(HFDynamicCache):
+     """
+     Quantized KV cache used in the forward call, based on HF's dynamic cache.
+     The quantization strategy (tensor, group, channel) is set from the
+     QuantizationArgs strategy.
+     Singleton, so that the same cache gets reused across all self_attn forward calls.
+     Each time forward is called, .update() is called, which in turn calls
+     ._quantize() and ._dequantize() as appropriate.
+     The size of the tensor is
+     `[batch_size, num_heads, seq_len - residual_length, head_dim]`.
+
+     Triggered by adding kv_cache_scheme to the recipe.
+
+     Example:
+
+     ```python3
+     recipe = '''
+     quant_stage:
+         quant_modifiers:
+             QuantizationModifier:
+                 kv_cache_scheme:
+                     num_bits: 8
+                     type: float
+                     strategy: tensor
+                     dynamic: false
+                     symmetric: true
+     '''
+     ```
+     """
+
+     _instance = None
+     _initialized = False
+
+     def __new__(cls, *args, **kwargs):
+         """Singleton"""
+         if cls._instance is None:
+             cls._instance = super(QuantizedKVParameterCache, cls).__new__(cls)
+         return cls._instance
+
+     def __init__(self, quantization_args: QuantizationArgs):
+         if not self._initialized:
+             super().__init__()
+
+             self.quantization_args = quantization_args
+
+             self.k_observers: List[Observer] = []
+             self.v_observers: List[Observer] = []
+
+             # each index corresponds to layer_idx of the attention layer
+             self.k_scales: List[Tensor] = []
+             self.v_scales: List[Tensor] = []
+
+             self.k_zps: List[Tensor] = []
+             self.v_zps: List[Tensor] = []
+
+             self._initialized = True
+
+     def update(
+         self,
+         key_states: Tensor,
+         value_states: Tensor,
+         layer_idx: int,
+         cache_kwargs: Optional[Dict[str, Any]] = None,
+     ) -> Tuple[Tensor, Tensor]:
+         """
+         Get the k_scale and v_scale and output the
+         fakequant-ed key_states and value_states
+         """
+
+         if len(self.k_observers) <= layer_idx:
+             k_observer = self.quantization_args.get_observer()
+             v_observer = self.quantization_args.get_observer()
+
+             self.k_observers.append(k_observer)
+             self.v_observers.append(v_observer)
+
+         q_key_states = self._quantize(
+             key_states.contiguous(), KVCacheScaleType.KEY, layer_idx
+         )
+         q_value_states = self._quantize(
+             value_states.contiguous(), KVCacheScaleType.VALUE, layer_idx
+         )
+
+         qdq_key_states = self._dequantize(q_key_states, KVCacheScaleType.KEY, layer_idx)
+         qdq_value_states = self._dequantize(
+             q_value_states, KVCacheScaleType.VALUE, layer_idx
+         )
+
+         keys_to_return, values_to_return = qdq_key_states, qdq_value_states
+
+         return keys_to_return, values_to_return
+
+     def get_seq_length(self, layer_idx: Optional[int] = 0) -> int:
+         """
+         Returns the sequence length of the cached states.
+         A layer index can be optionally passed.
+         """
+         if len(self.key_cache) <= layer_idx:
+             return 0
+         # since we cannot get the seq_length of each layer directly and
+         # rely on `_seen_tokens` which is updated every "layer_idx" == 0,
+         # this is a hack to get the actual seq_length for the given layer_idx
+         # this part of code otherwise fails when used to
+         # verify attn_weight shape in some models
+         return self._seen_tokens if layer_idx == 0 else self._seen_tokens - 1
+
+     def reset_states(self):
+         """reset the kv states (used in calibration)"""
+         self.key_cache: List[Tensor] = []
+         self.value_cache: List[Tensor] = []
+         # Used in `generate` to keep tally of how many tokens the cache has seen
+         self._seen_tokens = 0
+         self._quantized_key_cache: List[Tensor] = []
+         self._quantized_value_cache: List[Tensor] = []
+
+     def reset(self):
+         """
+         Reset the instantiation, create new instance on init
+         """
+         QuantizedKVParameterCache._instance = None
+         QuantizedKVParameterCache._initialized = False
+
+     def _quantize(self, tensor, kv_type, layer_idx):
+         """Quantizes a key/value using a defined quantization method."""
+         from compressed_tensors.quantization.lifecycle.forward import quantize
+
+         if kv_type == KVCacheScaleType.KEY:  # key type
+             observer = self.k_observers[layer_idx]
+             scales = self.k_scales
+             zps = self.k_zps
+         else:
+             assert kv_type == KVCacheScaleType.VALUE
+             observer = self.v_observers[layer_idx]
+             scales = self.v_scales
+             zps = self.v_zps
+
+         scale, zp = observer(tensor)
+         if len(scales) <= layer_idx:
+             scales.append(scale)
+             zps.append(zp)
+         else:
+             scales[layer_idx] = scale
+             zps[layer_idx] = zp
+
+         q_tensor = quantize(
+             x=tensor,
+             scale=scale,
+             zero_point=zp,
+             args=self.quantization_args,
+         )
+         return q_tensor
+
+     def _dequantize(self, qtensor, kv_type, layer_idx):
+         """Dequantizes back the tensor that was quantized by `self._quantize()`"""
+         from compressed_tensors.quantization.lifecycle.forward import dequantize
+
+         if kv_type == KVCacheScaleType.KEY:
+             scale = self.k_scales[layer_idx]
+             zp = self.k_zps[layer_idx]
+         else:
+             assert kv_type == KVCacheScaleType.VALUE
+             scale = self.v_scales[layer_idx]
+             zp = self.v_zps[layer_idx]
+
+         qdq_tensor = dequantize(
+             x_q=qtensor,
+             scale=scale,
+             zero_point=zp,
+             args=self.quantization_args,
+         )
+         return qdq_tensor
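To make the intent of the new cache concrete, here is a small, hedged usage sketch (not part of the diff). It assumes the default minmax observer and int8 tensor-strategy arguments; the shapes and values below are made up purely for illustration.

```python
# Illustrative only -- argument values and tensor shapes are invented.
import torch

from compressed_tensors.quantization.cache import QuantizedKVParameterCache
from compressed_tensors.quantization.quant_args import QuantizationArgs

args = QuantizationArgs(num_bits=8, type="int", strategy="tensor", symmetric=True)
cache = QuantizedKVParameterCache(args)

# fake key/value states shaped [batch_size, num_heads, seq_len, head_dim]
key_states = torch.randn(1, 2, 5, 8)
value_states = torch.randn(1, 2, 5, 8)

# update() observes, quantizes, and dequantizes -- returning fake-quantized states
qdq_keys, qdq_values = cache.update(key_states, value_states, layer_idx=0)

# per-layer scales that later get written into the module's k_scale / v_scale
print(cache.k_scales[0], cache.v_scales[0])

cache.reset()  # drop the singleton so a fresh cache is built on the next init
```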
@@ -43,6 +43,7 @@ from compressed_tensors.quantization.utils import (
      infer_quantization_status,
      is_kv_cache_quant_scheme,
      iter_named_leaf_modules,
+     iter_named_quantizable_modules,
  )
  from compressed_tensors.utils.helpers import fix_fsdp_module_name, replace_module
  from compressed_tensors.utils.offload import update_parameter_data
@@ -106,8 +107,8 @@ def load_pretrained_quantization(model: Module, model_name_or_path: str):
 
 
  def apply_quantization_config(
-     model: Module, config: QuantizationConfig, run_compressed: bool = False
- ) -> Dict:
+     model: Module, config: Union[QuantizationConfig, None], run_compressed: bool = False
+ ) -> OrderedDict:
      """
      Initializes the model for quantization in-place based on the given config
 
@@ -116,6 +117,10 @@
      :param run_compressed: Whether the model will be run in compressed mode or
          decompressed fully on load
      """
+     # Workaround for when HF Quantizer passes None, see PR #180
+     if config is None:
+         return OrderedDict()
+
      # remove reference to the original `config`
      # argument. This function can mutate it, and we'd
      # like to keep the original `config` as it is.
@@ -135,15 +140,23 @@
      # list of submodules to ignore
      ignored_submodules = defaultdict(list)
      # mark appropriate layers for quantization by setting their quantization schemes
-     for name, submodule in iter_named_leaf_modules(model):
+     for name, submodule in iter_named_quantizable_modules(
+         model,
+         include_children=True,
+         include_attn=True,
+     ):  # child modules and attention modules
          # potentially fix module name to remove FSDP wrapper prefix
          name = fix_fsdp_module_name(name)
          if matches := find_name_or_class_matches(name, submodule, config.ignore):
              for match in matches:
                  ignored_submodules[match].append(name)
              continue  # layer matches ignore list, continue
+
          targets = find_name_or_class_matches(name, submodule, target_to_scheme)
+
          if targets:
+             # mark modules to be quantized by adding
+             # quant scheme to the matching layers
              scheme = _scheme_from_targets(target_to_scheme, targets, name)
              if run_compressed:
                  format = config.format
@@ -200,6 +213,9 @@ def process_kv_cache_config(
      :param config: the QuantizationConfig
      :return: the QuantizationConfig with additional "kv_cache" group
      """
+     if targets == KV_CACHE_TARGETS:
+         _LOGGER.info(f"KV cache targets set to default value of: {KV_CACHE_TARGETS}")
+
      kv_cache_dict = config.kv_cache_scheme.model_dump()
      kv_cache_scheme = QuantizationScheme(
          output_activations=QuantizationArgs(**kv_cache_dict),
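As a quick illustration of the new `None` handling in `apply_quantization_config`, here is a minimal sketch (not part of the diff; the toy model is made up and the import path simply follows the file location shown above):

```python
# Minimal sketch: with the PR #180 workaround, a missing config is a no-op.
from collections import OrderedDict

import torch
from compressed_tensors.quantization.lifecycle.apply import apply_quantization_config

model = torch.nn.Sequential(torch.nn.Linear(8, 8))

# No config -> the model is left untouched and an empty mapping comes back
assert apply_quantization_config(model, None) == OrderedDict()
```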
@@ -56,10 +56,9 @@ def set_module_for_calibration(module: Module, quantize_weights_upfront: bool =
          observer = module.weight_observer
          g_idx = getattr(module, "weight_g_idx", None)
 
-         offloaded = False
-         if is_module_offloaded(module):
+         offloaded = is_module_offloaded(module)
+         if offloaded:
              module._hf_hook.pre_forward(module)
-             offloaded = True
 
          scale, zero_point = observer(module.weight, g_idx=g_idx)
          update_parameter_data(module, scale, "weight_scale")
@@ -14,9 +14,10 @@
 
  from functools import wraps
  from math import ceil
- from typing import Optional
+ from typing import Callable, Optional
 
  import torch
+ from compressed_tensors.quantization.cache import QuantizedKVParameterCache
  from compressed_tensors.quantization.observers.helpers import calculate_range
  from compressed_tensors.quantization.quant_args import (
      QuantizationArgs,
@@ -62,6 +63,7 @@ def quantize(
      :param g_idx: optional mapping from column index to group index
      :return: fake quantized tensor
      """
+
      return _process_quantization(
          x=x,
          scale=scale,
@@ -165,8 +167,8 @@ def _process_quantization(
      x: torch.Tensor,
      scale: torch.Tensor,
      zero_point: torch.Tensor,
-     g_idx: Optional[torch.Tensor],
      args: QuantizationArgs,
+     g_idx: Optional[torch.Tensor] = None,
      dtype: Optional[torch.dtype] = None,
      do_quantize: bool = True,
      do_dequantize: bool = True,
@@ -266,6 +268,7 @@ def wrap_module_forward_quantized(module: Module, scheme: QuantizationScheme):
              return forward_func_orig.__get__(module, module.__class__)(*args, **kwargs)
 
          input_ = args[0]
+
          compressed = module.quantization_status == QuantizationStatus.COMPRESSED
 
          if scheme.input_activations is not None:
@@ -285,9 +288,11 @@ def wrap_module_forward_quantized(module: Module, scheme: QuantizationScheme):
          output = forward_func_orig.__get__(module, module.__class__)(
              input_, *args[1:], **kwargs
          )
-
          if scheme.output_activations is not None:
+
              # calibrate and (fake) quantize output activations when applicable
+             # kv_cache scales updated on model self_attn forward call in
+             # wrap_module_forward_quantized_attn
              output = maybe_calibrate_or_quantize(
                  module, output, "output", scheme.output_activations
              )
@@ -304,6 +309,50 @@ def wrap_module_forward_quantized(module: Module, scheme: QuantizationScheme):
      setattr(module, "forward", bound_wrapped_forward)
 
 
+ def wrap_module_forward_quantized_attn(module: Module, scheme: QuantizationScheme):
+     # expects a module already initialized and injected with the parameters in
+     # initialize_module_for_quantization
+     if hasattr(module.forward, "__func__"):
+         forward_func_orig = module.forward.__func__
+     else:
+         forward_func_orig = module.forward.func
+
+     @wraps(forward_func_orig)  # ensures docstring, names, etc are propagated
+     def wrapped_forward(self, *args, **kwargs):
+
+         # kv cache stored under weights
+         if module.quantization_status == QuantizationStatus.CALIBRATION:
+             quantization_args: QuantizationArgs = scheme.output_activations
+             past_key_value: QuantizedKVParameterCache = quantization_args.get_kv_cache()
+             kwargs["past_key_value"] = past_key_value
+
+             # QuantizedKVParameterCache used for obtaining k_scale, v_scale only,
+             # does not store quantized_key_states and quantized_value_state
+             kwargs["use_cache"] = False
+
+             attn_forward: Callable = forward_func_orig.__get__(module, module.__class__)
+
+             past_key_value.reset_states()
+
+             rtn = attn_forward(*args, **kwargs)
+
+             update_parameter_data(
+                 module, past_key_value.k_scales[module.layer_idx], "k_scale"
+             )
+             update_parameter_data(
+                 module, past_key_value.v_scales[module.layer_idx], "v_scale"
+             )
+
+             return rtn
+
+         return forward_func_orig.__get__(module, module.__class__)(*args, **kwargs)
+
+     # bind wrapped forward to module class so reference to `self` is correct
+     bound_wrapped_forward = wrapped_forward.__get__(module, module.__class__)
+     # set forward to wrapped forward
+     setattr(module, "forward", bound_wrapped_forward)
+
+
  def maybe_calibrate_or_quantize(
      module: Module, value: torch.Tensor, base_name: str, args: "QuantizationArgs"
  ) -> torch.Tensor:
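For readers unfamiliar with the forward-wrapping trick used by `wrap_module_forward_quantized_attn`, here is a stripped-down, generic sketch of the same pattern (illustrative only; the pre/post-processing hooks and the `Linear` module are placeholders, not the library's actual behavior):

```python
from functools import wraps

import torch

module = torch.nn.Linear(4, 4)
forward_func_orig = module.forward.__func__  # the unbound Linear.forward


@wraps(forward_func_orig)  # keep the original name/docstring on the wrapper
def wrapped_forward(self, *args, **kwargs):
    # pre-processing would go here (e.g. injecting past_key_value / use_cache)
    out = forward_func_orig.__get__(self, self.__class__)(*args, **kwargs)
    # post-processing would go here (e.g. reading scales off the kv cache)
    return out


# bind the wrapper to the instance so `self` resolves to `module`
module.forward = wrapped_forward.__get__(module, module.__class__)

print(module.forward(torch.randn(2, 4)).shape)  # torch.Size([2, 4])
```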
@@ -14,6 +14,7 @@
 
 
  from compressed_tensors.quantization.quant_config import QuantizationStatus
+ from compressed_tensors.quantization.utils import is_kv_cache_quant_scheme
  from torch.nn import Module
 
 
@@ -44,7 +45,11 @@ def freeze_module_quantization(module: Module):
          delattr(module, "input_observer")
      if scheme.weights and not scheme.weights.dynamic:
          delattr(module, "weight_observer")
-     if scheme.output_activations and not scheme.output_activations.dynamic:
+     if (
+         scheme.output_activations
+         and not is_kv_cache_quant_scheme(scheme)
+         and not scheme.output_activations.dynamic
+     ):
          delattr(module, "output_observer")
 
      module.quantization_status = QuantizationStatus.FROZEN
@@ -16,62 +16,15 @@
  Miscelaneous helpers for the quantization lifecycle
  """
 
- from typing import Optional
-
- import torch
  from torch.nn import Module
 
 
  __all__ = [
-     "update_layer_weight_quant_params",
      "enable_quantization",
      "disable_quantization",
  ]
 
 
- def update_layer_weight_quant_params(
-     layer: Module,
-     weight: Optional[torch.Tensor] = None,
-     g_idx: Optional[torch.Tensor] = None,
-     reset_obs: bool = False,
- ):
-     """
-     Update quantization parameters on layer
-
-     :param layer: input layer
-     :param weight: weight to update quant params with, defaults to layer weight
-     :param g_idx: optional mapping from column index to group index
-     :param reset_obs: reset the observer before calculating quant params,
-         defaults to False
-     """
-     attached_weight = getattr(layer, "weight", None)
-
-     if weight is None:
-         weight = attached_weight
-     scale = getattr(layer, "weight_scale", None)
-     zero_point = getattr(layer, "weight_zero_point", None)
-     if g_idx is None:
-         g_idx = getattr(layer, "weight_g_idx", None)
-     observer = getattr(layer, "weight_observer", None)
-
-     if weight is None or observer is None or scale is None or zero_point is None:
-         # scale, zp, or observer not calibratable or weight not available
-         return
-
-     if reset_obs:
-         observer.reset()
-
-     if attached_weight is not None:
-         weight = weight.to(attached_weight.dtype)
-
-     updated_scale, updated_zero_point = observer(weight)
-
-     # update scale and zero point
-     device = next(layer.parameters()).device
-     scale.data = updated_scale.to(device)
-     zero_point.data = updated_zero_point.to(device)
-
-
  def enable_quantization(module: Module):
      module.quantization_enabled = True
 
@@ -17,8 +17,10 @@ import logging
  from typing import Optional
 
  import torch
+ from compressed_tensors.quantization.cache import KVCacheScaleType
  from compressed_tensors.quantization.lifecycle.forward import (
      wrap_module_forward_quantized,
+     wrap_module_forward_quantized_attn,
  )
  from compressed_tensors.quantization.quant_args import (
      ActivationOrdering,
@@ -27,6 +29,7 @@ from compressed_tensors.quantization.quant_args import (
  )
  from compressed_tensors.quantization.quant_config import QuantizationStatus
  from compressed_tensors.quantization.quant_scheme import QuantizationScheme
+ from compressed_tensors.quantization.utils import is_kv_cache_quant_scheme
  from compressed_tensors.utils import get_execution_device, is_module_offloaded
  from torch.nn import Module, Parameter
 
@@ -62,72 +65,85 @@ def initialize_module_for_quantization(
          # no scheme passed and layer not targeted for quantization - skip
          return
 
-     if scheme.input_activations is not None:
-         _initialize_scale_zero_point_observer(
-             module, "input", scheme.input_activations, force_zero_point=force_zero_point
-         )
-     if scheme.weights is not None:
-         if hasattr(module, "weight"):
-             weight_shape = module.weight.shape
+     if is_attention_module(module):
+         # wrap forward call of module to perform
+         # quantized actions based on calltime status
+         wrap_module_forward_quantized_attn(module, scheme)
+         _initialize_attn_scales(module)
+
+     else:
+
+         if scheme.input_activations is not None:
              _initialize_scale_zero_point_observer(
                  module,
-                 "weight",
-                 scheme.weights,
-                 weight_shape=weight_shape,
+                 "input",
+                 scheme.input_activations,
                  force_zero_point=force_zero_point,
              )
-         else:
-             _LOGGER.warning(
-                 f"module type {type(module)} targeted for weight quantization but "
-                 "has no attribute weight, skipping weight quantization "
-                 f"for {type(module)}"
-             )
-     if scheme.output_activations is not None:
-         _initialize_scale_zero_point_observer(
-             module,
-             "output",
-             scheme.output_activations,
-             force_zero_point=force_zero_point,
-         )
-
-     module.quantization_scheme = scheme
-     module.quantization_status = QuantizationStatus.INITIALIZED
-
-     offloaded = False
-     if is_module_offloaded(module):
-         try:
-             from accelerate.hooks import add_hook_to_module, remove_hook_from_module
-             from accelerate.utils import PrefixedDataset
-         except ModuleNotFoundError:
-             raise ModuleNotFoundError(
-                 "Offloaded model detected. To use CPU offloading with "
-                 "compressed-tensors the `accelerate` package must be installed, "
-                 "run `pip install compressed-tensors[accelerate]`"
-             )
-
-         offloaded = True
-         hook = module._hf_hook
-         prefix_dict = module._hf_hook.weights_map
-         new_prefix = {}
-
-         # recreate the prefix dict (since it is immutable)
-         # and add quantization parameters
-         for key, data in module.named_parameters():
-             if key not in prefix_dict:
-                 new_prefix[f"{prefix_dict.prefix}{key}"] = data
+         if scheme.weights is not None:
+             if hasattr(module, "weight"):
+                 weight_shape = None
+                 if isinstance(module, torch.nn.Linear):
+                     weight_shape = module.weight.shape
+                 _initialize_scale_zero_point_observer(
+                     module,
+                     "weight",
+                     scheme.weights,
+                     weight_shape=weight_shape,
+                     force_zero_point=force_zero_point,
+                 )
              else:
-                 new_prefix[f"{prefix_dict.prefix}{key}"] = prefix_dict[key]
-         new_prefix_dict = PrefixedDataset(new_prefix, prefix_dict.prefix)
-         remove_hook_from_module(module)
-
-     # wrap forward call of module to perform quantized actions based on calltime status
-     wrap_module_forward_quantized(module, scheme)
-
-     if offloaded:
-         # we need to re-add the hook for offloading now that we've wrapped forward
-         add_hook_to_module(module, hook)
-         if prefix_dict is not None:
-             module._hf_hook.weights_map = new_prefix_dict
+                 _LOGGER.warning(
+                     f"module type {type(module)} targeted for weight quantization but "
+                     "has no attribute weight, skipping weight quantization "
+                     f"for {type(module)}"
+                 )
+
+         if scheme.output_activations is not None:
+             if not is_kv_cache_quant_scheme(scheme):
+                 _initialize_scale_zero_point_observer(
+                     module, "output", scheme.output_activations
+                 )
+
+         module.quantization_scheme = scheme
+         module.quantization_status = QuantizationStatus.INITIALIZED
+
+         offloaded = False
+         if is_module_offloaded(module):
+             try:
+                 from accelerate.hooks import add_hook_to_module, remove_hook_from_module
+                 from accelerate.utils import PrefixedDataset
+             except ModuleNotFoundError:
+                 raise ModuleNotFoundError(
+                     "Offloaded model detected. To use CPU offloading with "
+                     "compressed-tensors the `accelerate` package must be installed, "
+                     "run `pip install compressed-tensors[accelerate]`"
+                 )
+
+             offloaded = True
+             hook = module._hf_hook
+             prefix_dict = module._hf_hook.weights_map
+             new_prefix = {}
+
+             # recreate the prefix dict (since it is immutable)
+             # and add quantization parameters
+             for key, data in module.named_parameters():
+                 if key not in prefix_dict:
+                     new_prefix[f"{prefix_dict.prefix}{key}"] = data
+                 else:
+                     new_prefix[f"{prefix_dict.prefix}{key}"] = prefix_dict[key]
+             new_prefix_dict = PrefixedDataset(new_prefix, prefix_dict.prefix)
+             remove_hook_from_module(module)
+
+         # wrap forward call of module to perform
+         # quantized actions based on calltime status
+         wrap_module_forward_quantized(module, scheme)
+
+         if offloaded:
+             # we need to re-add the hook for offloading now that we've wrapped forward
+             add_hook_to_module(module, hook)
+             if prefix_dict is not None:
+                 module._hf_hook.weights_map = new_prefix_dict
 
 
  def _initialize_scale_zero_point_observer(
@@ -156,9 +172,10 @@ def _initialize_scale_zero_point_observer(
          # (output_channels, 1)
          expected_shape = (weight_shape[0], 1)
      elif quantization_args.strategy == QuantizationStrategy.GROUP:
+         num_groups = weight_shape[1] // quantization_args.group_size
          expected_shape = (
              weight_shape[0],
-             weight_shape[1] // quantization_args.group_size,
+             max(num_groups, 1)
          )
 
      scale_dtype = module.weight.dtype
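To see what the `max(num_groups, 1)` guard changes in practice, here is a tiny illustrative calculation (the numbers are chosen arbitrarily and are not from the diff):

```python
# Group quantization scale/zero-point shape, before and after the guard.
out_features, in_features, group_size = 128, 64, 128  # group_size > in_features

num_groups = in_features // group_size               # 0
old_shape = (out_features, num_groups)               # (128, 0) -> empty parameter
new_shape = (out_features, max(num_groups, 1))       # (128, 1) -> at least one group
print(old_shape, new_shape)
```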
@@ -189,3 +206,34 @@
              requires_grad=False,
          )
          module.register_parameter(f"{base_name}_g_idx", init_g_idx)
+
+
+ def is_attention_module(module: Module):
+     return "attention" in module.__class__.__name__.lower() and (
+         hasattr(module, "k_proj")
+         or hasattr(module, "v_proj")
+         or hasattr(module, "qkv_proj")
+     )
+
+
+ def _initialize_attn_scales(module: Module) -> None:
+     """Initialize k_scale, v_scale for self_attn"""
+
+     expected_shape = 1  # per tensor
+
+     param = next(module.parameters())
+     scale_dtype = param.dtype
+     device = param.device
+
+     init_scale = Parameter(
+         torch.empty(expected_shape, dtype=scale_dtype, device=device),
+         requires_grad=False,
+     )
+
+     module.register_parameter(KVCacheScaleType.KEY.value, init_scale)
+
+     init_scale = Parameter(
+         torch.empty(expected_shape, dtype=scale_dtype, device=device),
+         requires_grad=False,
+     )
+     module.register_parameter(KVCacheScaleType.VALUE.value, init_scale)
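A quick, hedged sanity check of the `is_attention_module` heuristic; the toy class below is made up for illustration and only exercises the name-plus-projection test:

```python
# Illustrative only: a fake attention module that satisfies the name + k_proj check.
import torch
from compressed_tensors.quantization.lifecycle.initialize import is_attention_module


class ToySelfAttention(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.k_proj = torch.nn.Linear(8, 8)
        self.v_proj = torch.nn.Linear(8, 8)


assert is_attention_module(ToySelfAttention())         # "attention" in name and has k_proj
assert not is_attention_module(torch.nn.Linear(8, 8))  # plain Linear is not matched
```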
@@ -122,6 +122,12 @@ class QuantizationArgs(BaseModel, use_enum_values=True):
 
          return Observer.load_from_registry(self.observer, quantization_args=self)
 
+     def get_kv_cache(self):
+         """Get the singleton KV Cache"""
+         from compressed_tensors.quantization.cache import QuantizedKVParameterCache
+
+         return QuantizedKVParameterCache(self)
+
      @field_validator("type", mode="before")
      def validate_type(cls, value) -> QuantizationType:
          if isinstance(value, str):
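Because `QuantizedKVParameterCache` is a singleton, `get_kv_cache()` hands every attention layer the same object. A short illustrative check (not part of the diff; the argument values mirror the kv_cache_scheme recipe shown earlier):

```python
# Both calls return the identical cache instance shared across self_attn layers.
from compressed_tensors.quantization.quant_args import QuantizationArgs

kv_args = QuantizationArgs(
    num_bits=8, type="float", strategy="tensor", dynamic=False, symmetric=True
)
assert kv_args.get_kv_cache() is kv_args.get_kv_cache()
```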