compressed-tensors-nightly 0.6.0.20240925__py3-none-any.whl → 0.6.0.20240928__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
compressed_tensors/base.py

@@ -16,3 +16,4 @@ SPARSITY_CONFIG_NAME = "sparsity_config"
  QUANTIZATION_CONFIG_NAME = "quantization_config"
  COMPRESSION_CONFIG_NAME = "compression_config"
  KV_CACHE_SCHEME_NAME = "kv_cache_scheme"
+ COMPRESSION_VERSION_NAME = "version"
compressed_tensors/compressors/model_compressor.py

@@ -25,6 +25,7 @@ import transformers
  import compressed_tensors
  from compressed_tensors.base import (
      COMPRESSION_CONFIG_NAME,
+     COMPRESSION_VERSION_NAME,
      QUANTIZATION_CONFIG_NAME,
      SPARSITY_CONFIG_NAME,
  )
@@ -200,6 +201,7 @@ class ModelCompressor:
          # SparseAutoModel format
          quantization_config = deepcopy(compression_config)
          quantization_config.pop(SPARSITY_CONFIG_NAME, None)
+         quantization_config.pop(COMPRESSION_VERSION_NAME, None)
          if len(quantization_config) == 0:
              quantization_config = None
          return quantization_config
@@ -214,6 +216,11 @@ class ModelCompressor:
          self.sparsity_compressor = None
          self.quantization_compressor = None

+
+         if sparsity_config and sparsity_config.format == CompressionFormat.dense.value:
+             # ignore dense sparsity config
+             self.sparsity_config = None
+
          if sparsity_config is not None:
              self.sparsity_compressor = Compressor.load_from_registry(
                  sparsity_config.format, config=sparsity_config
@@ -252,62 +259,6 @@ class ModelCompressor:
                  compressed_state_dict
              )

-         # HACK (mgoin): Post-process step for kv cache scales to take the
-         # k/v_proj module `output_scale` parameters, and store them in the
-         # parent attention module as `k_scale` and `v_scale`
-         #
-         # Example:
-         #     Replace `model.layers.0.self_attn.k_proj.output_scale`
-         #     with `model.layers.0.self_attn.k_scale`
-         if (
-             self.quantization_config is not None
-             and self.quantization_config.kv_cache_scheme is not None
-         ):
-             # HACK (mgoin): We assume the quantized modules in question
-             # will be k_proj and v_proj since those are the default targets.
-             # We check that both of these modules have output activation
-             # quantization, and additionally check that q_proj doesn't.
-             q_proj_has_no_quant_output = 0
-             k_proj_has_quant_output = 0
-             v_proj_has_quant_output = 0
-             for name, module in model.named_modules():
-                 if not hasattr(module, "quantization_scheme"):
-                     # We still want to count non-quantized q_proj
-                     if name.endswith(".q_proj"):
-                         q_proj_has_no_quant_output += 1
-                     continue
-                 out_act = module.quantization_scheme.output_activations
-                 if name.endswith(".q_proj") and out_act is None:
-                     q_proj_has_no_quant_output += 1
-                 elif name.endswith(".k_proj") and out_act is not None:
-                     k_proj_has_quant_output += 1
-                 elif name.endswith(".v_proj") and out_act is not None:
-                     v_proj_has_quant_output += 1
-
-             assert (
-                 q_proj_has_no_quant_output > 0
-                 and k_proj_has_quant_output > 0
-                 and v_proj_has_quant_output > 0
-             )
-             assert (
-                 q_proj_has_no_quant_output
-                 == k_proj_has_quant_output
-                 == v_proj_has_quant_output
-             )
-
-             # Move all .k/v_proj.output_scale parameters to .k/v_scale
-             working_state_dict = {}
-             for key in compressed_state_dict.keys():
-                 if key.endswith(".k_proj.output_scale"):
-                     new_key = key.replace(".k_proj.output_scale", ".k_scale")
-                     working_state_dict[new_key] = compressed_state_dict[key]
-                 elif key.endswith(".v_proj.output_scale"):
-                     new_key = key.replace(".v_proj.output_scale", ".v_scale")
-                     working_state_dict[new_key] = compressed_state_dict[key]
-                 else:
-                     working_state_dict[key] = compressed_state_dict[key]
-             compressed_state_dict = working_state_dict
-
          # HACK: Override the dtype_byte_size function in transformers to
          # support float8 types. Fix is posted upstream
          # https://github.com/huggingface/transformers/pull/30488
@@ -360,16 +311,18 @@ class ModelCompressor:
          with open(config_file_path, "r") as config_file:
              config_data = json.load(config_file)

-         config_data[COMPRESSION_CONFIG_NAME] = {}
+         config_data[QUANTIZATION_CONFIG_NAME] = {}
          if self.quantization_config is not None:
              quant_config_data = self.quantization_config.model_dump()
-             config_data[COMPRESSION_CONFIG_NAME] = quant_config_data
+             config_data[QUANTIZATION_CONFIG_NAME] = quant_config_data
          if self.sparsity_config is not None:
              sparsity_config_data = self.sparsity_config.model_dump()
-             config_data[COMPRESSION_CONFIG_NAME][
+             config_data[QUANTIZATION_CONFIG_NAME][
                  SPARSITY_CONFIG_NAME
              ] = sparsity_config_data
-         config_data[COMPRESSION_CONFIG_NAME]["version"] = compressed_tensors.__version__
+         config_data[QUANTIZATION_CONFIG_NAME][
+             COMPRESSION_VERSION_NAME
+         ] = compressed_tensors.__version__

          with open(config_file_path, "w") as config_file:
              json.dump(config_data, config_file, indent=2, sort_keys=True)
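With this change, `ModelCompressor.update_config()` writes its metadata under the `quantization_config` key (with `sparsity_config` and the library `version` nested inside it) rather than under a top-level `compression_config` key. A rough sketch of the resulting `config.json` layout follows; the quantization fields shown are illustrative placeholders, not actual `model_dump()` output.

```python
# Hypothetical layout written by update_config() after this change; the values
# inside "config_groups" and "kv_cache_scheme" are placeholders for whatever
# QuantizationConfig.model_dump() actually produces for a given model.
import json

config_data = {
    "architectures": ["LlamaForCausalLM"],  # pre-existing model config keys stay put
    "quantization_config": {                # QUANTIZATION_CONFIG_NAME
        "config_groups": {"group_0": "..."},
        "kv_cache_scheme": {"num_bits": 8, "type": "float"},
        "sparsity_config": {"format": "sparse-bitmask"},  # nested under SPARSITY_CONFIG_NAME
        "version": "0.6.0.20240928",        # nested under COMPRESSION_VERSION_NAME
    },
}
print(json.dumps(config_data, indent=2, sort_keys=True))
```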
compressed_tensors/quantization/__init__.py

@@ -19,3 +19,4 @@ from .quant_args import *
  from .quant_config import *
  from .quant_scheme import *
  from .lifecycle import *
+ from .cache import QuantizedKVParameterCache
compressed_tensors/quantization/cache.py (new file)

@@ -0,0 +1,201 @@
+ # Copyright (c) 2021 - present / Neuralmagic, Inc. All Rights Reserved.
+ #
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ #    http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing,
+ # software distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+
+
+ from enum import Enum
+ from typing import Any, Dict, List, Optional, Tuple
+
+ from compressed_tensors.quantization.observers import Observer
+ from compressed_tensors.quantization.quant_args import QuantizationArgs
+ from torch import Tensor
+ from transformers import DynamicCache as HFDyanmicCache
+
+
+ class KVCacheScaleType(Enum):
+     KEY = "k_scale"
+     VALUE = "v_scale"
+
+
+ class QuantizedKVParameterCache(HFDyanmicCache):
+
+     """
+     Quantized KV cache used in the forward call based on HF's dynamic cache.
+     Quantization strategy (tensor, group, channel) set from Quantization arg's strategy
+     Singleton, so that the same cache gets reused in all forward call of self_attn.
+     Each time forward is called, .update() is called, and ._quantize(), ._dequantize()
+     gets called appropriately.
+     The size of tensor is
+     `[batch_size, num_heads, seq_len - residual_length, head_dim]`.
+
+
+     Triggered by adding kv_cache_scheme in the recipe.
+
+     Example:
+
+     ```python3
+     recipe = '''
+     quant_stage:
+         quant_modifiers:
+             QuantizationModifier:
+                 kv_cache_scheme:
+                     num_bits: 8
+                     type: float
+                     strategy: tensor
+                     dynamic: false
+                     symmetric: true
+     '''
+
+     """
+
+     _instance = None
+     _initialized = False
+
+     def __new__(cls, *args, **kwargs):
+         """Singleton"""
+         if cls._instance is None:
+             cls._instance = super(QuantizedKVParameterCache, cls).__new__(cls)
+         return cls._instance
+
+     def __init__(self, quantization_args: QuantizationArgs):
+         if not self._initialized:
+             super().__init__()
+
+             self.quantization_args = quantization_args
+
+             self.k_observers: List[Observer] = []
+             self.v_observers: List[Observer] = []
+
+             # each index corresponds to layer_idx of the attention layer
+             self.k_scales: List[Tensor] = []
+             self.v_scales: List[Tensor] = []
+
+             self.k_zps: List[Tensor] = []
+             self.v_zps: List[Tensor] = []
+
+             self._initialized = True
+
+     def update(
+         self,
+         key_states: Tensor,
+         value_states: Tensor,
+         layer_idx: int,
+         cache_kwargs: Optional[Dict[str, Any]] = None,
+     ) -> Tuple[Tensor, Tensor]:
+         """
+         Get the k_scale and v_scale and output the
+         fakequant-ed key_states and value_states
+         """
+
+         if len(self.k_observers) <= layer_idx:
+             k_observer = self.quantization_args.get_observer()
+             v_observer = self.quantization_args.get_observer()
+
+             self.k_observers.append(k_observer)
+             self.v_observers.append(v_observer)
+
+         q_key_states = self._quantize(
+             key_states.contiguous(), KVCacheScaleType.KEY, layer_idx
+         )
+         q_value_states = self._quantize(
+             value_states.contiguous(), KVCacheScaleType.VALUE, layer_idx
+         )
+
+         qdq_key_states = self._dequantize(q_key_states, KVCacheScaleType.KEY, layer_idx)
+         qdq_value_states = self._dequantize(
+             q_value_states, KVCacheScaleType.VALUE, layer_idx
+         )
+
+         keys_to_return, values_to_return = qdq_key_states, qdq_value_states
+
+         return keys_to_return, values_to_return
+
+     def get_seq_length(self, layer_idx: Optional[int] = 0) -> int:
+         """
+         Returns the sequence length of the cached states.
+         A layer index can be optionally passed.
+         """
+         if len(self.key_cache) <= layer_idx:
+             return 0
+         # since we cannot get the seq_length of each layer directly and
+         # rely on `_seen_tokens` which is updated every "layer_idx" == 0,
+         # this is a hack to get the actual seq_length for the given layer_idx
+         # this part of code otherwise fails when used to
+         # verify attn_weight shape in some models
+         return self._seen_tokens if layer_idx == 0 else self._seen_tokens - 1
+
+     def reset_states(self):
+         """reset the kv states (used in calibration)"""
+         self.key_cache: List[Tensor] = []
+         self.value_cache: List[Tensor] = []
+         # Used in `generate` to keep tally of how many tokens the cache has seen
+         self._seen_tokens = 0
+         self._quantized_key_cache: List[Tensor] = []
+         self._quantized_value_cache: List[Tensor] = []
+
+     def reset(self):
+         """
+         Reset the instantiation, create new instance on init
+         """
+         QuantizedKVParameterCache._instance = None
+         QuantizedKVParameterCache._initialized = False
+
+     def _quantize(self, tensor, kv_type, layer_idx):
+         """Quantizes a key/value using a defined quantization method."""
+         from compressed_tensors.quantization.lifecycle.forward import quantize
+
+         if kv_type == KVCacheScaleType.KEY:  # key type
+             observer = self.k_observers[layer_idx]
+             scales = self.k_scales
+             zps = self.k_zps
+         else:
+             assert kv_type == KVCacheScaleType.VALUE
+             observer = self.v_observers[layer_idx]
+             scales = self.v_scales
+             zps = self.v_zps
+
+         scale, zp = observer(tensor)
+         if len(scales) <= layer_idx:
+             scales.append(scale)
+             zps.append(zp)
+         else:
+             scales[layer_idx] = scale
+             zps[layer_idx] = scale
+
+         q_tensor = quantize(
+             x=tensor,
+             scale=scale,
+             zero_point=zp,
+             args=self.quantization_args,
+         )
+         return q_tensor
+
+     def _dequantize(self, qtensor, kv_type, layer_idx):
+         """Dequantizes back the tensor that was quantized by `self._quantize()`"""
+         from compressed_tensors.quantization.lifecycle.forward import dequantize
+
+         if kv_type == KVCacheScaleType.KEY:
+             scale = self.k_scales[layer_idx]
+             zp = self.k_zps[layer_idx]
+         else:
+             assert kv_type == KVCacheScaleType.VALUE
+             scale = self.v_scales[layer_idx]
+             zp = self.v_zps[layer_idx]
+
+         qdq_tensor = dequantize(
+             x_q=qtensor,
+             scale=scale,
+             zero_point=zp,
+             args=self.quantization_args,
+         )
+         return qdq_tensor
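The new `QuantizedKVParameterCache` can also be exercised directly. A minimal sketch, assuming an int8 per-tensor scheme as a stand-in for the float8 recipe in the docstring above; the tensor shapes are arbitrary:

```python
import torch
from compressed_tensors.quantization import QuantizationArgs

# int8 tensor-wise args as an illustrative stand-in for a kv_cache_scheme
args = QuantizationArgs(num_bits=8, type="int", strategy="tensor", symmetric=True)
cache = args.get_kv_cache()  # singleton QuantizedKVParameterCache bound to these args

# fake key/value states: [batch_size, num_heads, seq_len, head_dim]
key_states = torch.randn(1, 4, 16, 64)
value_states = torch.randn(1, 4, 16, 64)

qdq_key, qdq_value = cache.update(key_states, value_states, layer_idx=0)

print(cache.k_scales[0], cache.v_scales[0])  # scales observed for layer 0
assert args.get_kv_cache() is cache          # the same instance is reused everywhere
cache.reset()                                # drop the singleton between runs
```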
compressed_tensors/quantization/lifecycle/apply.py

@@ -43,6 +43,7 @@ from compressed_tensors.quantization.utils import (
      infer_quantization_status,
      is_kv_cache_quant_scheme,
      iter_named_leaf_modules,
+     iter_named_quantizable_modules,
  )
  from compressed_tensors.utils.helpers import fix_fsdp_module_name, replace_module
  from compressed_tensors.utils.offload import update_parameter_data
@@ -135,15 +136,23 @@ def apply_quantization_config(
      # list of submodules to ignore
      ignored_submodules = defaultdict(list)
      # mark appropriate layers for quantization by setting their quantization schemes
-     for name, submodule in iter_named_leaf_modules(model):
+     for name, submodule in iter_named_quantizable_modules(
+         model,
+         include_children=True,
+         include_attn=True,
+     ):  # child modules and attention modules
          # potentially fix module name to remove FSDP wrapper prefix
          name = fix_fsdp_module_name(name)
          if matches := find_name_or_class_matches(name, submodule, config.ignore):
              for match in matches:
                  ignored_submodules[match].append(name)
              continue  # layer matches ignore list, continue
+
          targets = find_name_or_class_matches(name, submodule, target_to_scheme)
+
          if targets:
+             # mark modules to be quantized by adding
+             # quant scheme to the matching layers
              scheme = _scheme_from_targets(target_to_scheme, targets, name)
              if run_compressed:
                  format = config.format
@@ -200,6 +209,9 @@ def process_kv_cache_config(
      :param config: the QuantizationConfig
      :return: the QuantizationConfig with additional "kv_cache" group
      """
+     if targets == KV_CACHE_TARGETS:
+         _LOGGER.info(f"KV cache targets set to default value of: {KV_CACHE_TARGETS}")
+
      kv_cache_dict = config.kv_cache_scheme.model_dump()
      kv_cache_scheme = QuantizationScheme(
          output_activations=QuantizationArgs(**kv_cache_dict),
compressed_tensors/quantization/lifecycle/forward.py

@@ -14,9 +14,10 @@

  from functools import wraps
  from math import ceil
- from typing import Optional
+ from typing import Callable, Optional

  import torch
+ from compressed_tensors.quantization.cache import QuantizedKVParameterCache
  from compressed_tensors.quantization.observers.helpers import calculate_range
  from compressed_tensors.quantization.quant_args import (
      QuantizationArgs,
@@ -62,6 +63,7 @@ def quantize(
      :param g_idx: optional mapping from column index to group index
      :return: fake quantized tensor
      """
+
      return _process_quantization(
          x=x,
          scale=scale,
@@ -165,8 +167,8 @@ def _process_quantization(
      x: torch.Tensor,
      scale: torch.Tensor,
      zero_point: torch.Tensor,
-     g_idx: Optional[torch.Tensor],
      args: QuantizationArgs,
+     g_idx: Optional[torch.Tensor] = None,
      dtype: Optional[torch.dtype] = None,
      do_quantize: bool = True,
      do_dequantize: bool = True,
@@ -266,6 +268,7 @@ def wrap_module_forward_quantized(module: Module, scheme: QuantizationScheme):
              return forward_func_orig.__get__(module, module.__class__)(*args, **kwargs)

          input_ = args[0]
+
          compressed = module.quantization_status == QuantizationStatus.COMPRESSED

          if scheme.input_activations is not None:
@@ -285,9 +288,11 @@ def wrap_module_forward_quantized(module: Module, scheme: QuantizationScheme):
          output = forward_func_orig.__get__(module, module.__class__)(
              input_, *args[1:], **kwargs
          )
-
          if scheme.output_activations is not None:
+
              # calibrate and (fake) quantize output activations when applicable
+             # kv_cache scales updated on model self_attn forward call in
+             # wrap_module_forward_quantized_attn
              output = maybe_calibrate_or_quantize(
                  module, output, "output", scheme.output_activations
              )
@@ -304,6 +309,50 @@ def wrap_module_forward_quantized(module: Module, scheme: QuantizationScheme):
      setattr(module, "forward", bound_wrapped_forward)


+ def wrap_module_forward_quantized_attn(module: Module, scheme: QuantizationScheme):
+     # expects a module already initialized and injected with the parameters in
+     # initialize_module_for_quantization
+     if hasattr(module.forward, "__func__"):
+         forward_func_orig = module.forward.__func__
+     else:
+         forward_func_orig = module.forward.func
+
+     @wraps(forward_func_orig)  # ensures docstring, names, etc are propagated
+     def wrapped_forward(self, *args, **kwargs):
+
+         # kv cache stored under weights
+         if module.quantization_status == QuantizationStatus.CALIBRATION:
+             quantization_args: QuantizationArgs = scheme.output_activations
+             past_key_value: QuantizedKVParameterCache = quantization_args.get_kv_cache()
+             kwargs["past_key_value"] = past_key_value
+
+             # QuantizedKVParameterCache used for obtaining k_scale, v_scale only,
+             # does not store quantized_key_states and quantized_value_state
+             kwargs["use_cache"] = False
+
+             attn_forward: Callable = forward_func_orig.__get__(module, module.__class__)
+
+             past_key_value.reset_states()
+
+             rtn = attn_forward(*args, **kwargs)
+
+             update_parameter_data(
+                 module, past_key_value.k_scales[module.layer_idx], "k_scale"
+             )
+             update_parameter_data(
+                 module, past_key_value.v_scales[module.layer_idx], "v_scale"
+             )
+
+             return rtn
+
+         return forward_func_orig.__get__(module, module.__class__)(*args, **kwargs)
+
+     # bind wrapped forward to module class so reference to `self` is correct
+     bound_wrapped_forward = wrapped_forward.__get__(module, module.__class__)
+     # set forward to wrapped forward
+     setattr(module, "forward", bound_wrapped_forward)
+
+
  def maybe_calibrate_or_quantize(
      module: Module, value: torch.Tensor, base_name: str, args: "QuantizationArgs"
  ) -> torch.Tensor:
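After calibration, the per-layer scales produced by `wrap_module_forward_quantized_attn` live directly on each attention module as `k_scale`/`v_scale` parameters, replacing the removed `k/v_proj.output_scale` post-processing in `ModelCompressor`. A hedged helper sketch for collecting them; `model` is assumed to be a Hugging Face module tree that has already gone through this lifecycle:

```python
def collect_kv_scales(model):
    """Gather {attention module name: (k_scale, v_scale)} from wrapped modules."""
    scales = {}
    for name, module in model.named_modules():
        if hasattr(module, "k_scale") and hasattr(module, "v_scale"):
            scales[name] = (module.k_scale.item(), module.v_scale.item())
    return scales

# e.g. {"model.layers.0.self_attn": (0.023, 0.017), ...}  -- values purely illustrative
```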
compressed_tensors/quantization/lifecycle/frozen.py

@@ -14,6 +14,7 @@


  from compressed_tensors.quantization.quant_config import QuantizationStatus
+ from compressed_tensors.quantization.utils import is_kv_cache_quant_scheme
  from torch.nn import Module


@@ -44,7 +45,11 @@ def freeze_module_quantization(module: Module):
          delattr(module, "input_observer")
      if scheme.weights and not scheme.weights.dynamic:
          delattr(module, "weight_observer")
-     if scheme.output_activations and not scheme.output_activations.dynamic:
+     if (
+         scheme.output_activations
+         and not is_kv_cache_quant_scheme(scheme)
+         and not scheme.output_activations.dynamic
+     ):
          delattr(module, "output_observer")

      module.quantization_status = QuantizationStatus.FROZEN
compressed_tensors/quantization/lifecycle/initialize.py

@@ -17,8 +17,10 @@ import logging
  from typing import Optional

  import torch
+ from compressed_tensors.quantization.cache import KVCacheScaleType
  from compressed_tensors.quantization.lifecycle.forward import (
      wrap_module_forward_quantized,
+     wrap_module_forward_quantized_attn,
  )
  from compressed_tensors.quantization.quant_args import (
      ActivationOrdering,
@@ -27,6 +29,7 @@ from compressed_tensors.quantization.quant_args import (
  )
  from compressed_tensors.quantization.quant_config import QuantizationStatus
  from compressed_tensors.quantization.quant_scheme import QuantizationScheme
+ from compressed_tensors.quantization.utils import is_kv_cache_quant_scheme
  from compressed_tensors.utils import get_execution_device, is_module_offloaded
  from torch.nn import Module, Parameter

@@ -62,72 +65,85 @@ def initialize_module_for_quantization(
          # no scheme passed and layer not targeted for quantization - skip
          return

-     if scheme.input_activations is not None:
-         _initialize_scale_zero_point_observer(
-             module, "input", scheme.input_activations, force_zero_point=force_zero_point
-         )
-     if scheme.weights is not None:
-         if hasattr(module, "weight"):
-             weight_shape = module.weight.shape
+     if is_attention_module(module):
+         # wrap forward call of module to perform
+         # quantized actions based on calltime status
+         wrap_module_forward_quantized_attn(module, scheme)
+         _initialize_attn_scales(module)
+
+     else:
+
+         if scheme.input_activations is not None:
              _initialize_scale_zero_point_observer(
                  module,
-                 "weight",
-                 scheme.weights,
-                 weight_shape=weight_shape,
+                 "input",
+                 scheme.input_activations,
                  force_zero_point=force_zero_point,
              )
-         else:
-             _LOGGER.warning(
-                 f"module type {type(module)} targeted for weight quantization but "
-                 "has no attribute weight, skipping weight quantization "
-                 f"for {type(module)}"
-             )
-     if scheme.output_activations is not None:
-         _initialize_scale_zero_point_observer(
-             module,
-             "output",
-             scheme.output_activations,
-             force_zero_point=force_zero_point,
-         )
-
-     module.quantization_scheme = scheme
-     module.quantization_status = QuantizationStatus.INITIALIZED
-
-     offloaded = False
-     if is_module_offloaded(module):
-         try:
-             from accelerate.hooks import add_hook_to_module, remove_hook_from_module
-             from accelerate.utils import PrefixedDataset
-         except ModuleNotFoundError:
-             raise ModuleNotFoundError(
-                 "Offloaded model detected. To use CPU offloading with "
-                 "compressed-tensors the `accelerate` package must be installed, "
-                 "run `pip install compressed-tensors[accelerate]`"
-             )
-
-         offloaded = True
-         hook = module._hf_hook
-         prefix_dict = module._hf_hook.weights_map
-         new_prefix = {}
-
-         # recreate the prefix dict (since it is immutable)
-         # and add quantization parameters
-         for key, data in module.named_parameters():
-             if key not in prefix_dict:
-                 new_prefix[f"{prefix_dict.prefix}{key}"] = data
+         if scheme.weights is not None:
+             if hasattr(module, "weight"):
+                 weight_shape = None
+                 if isinstance(module, torch.nn.Linear):
+                     weight_shape = module.weight.shape
+                 _initialize_scale_zero_point_observer(
+                     module,
+                     "weight",
+                     scheme.weights,
+                     weight_shape=weight_shape,
+                     force_zero_point=force_zero_point,
+                 )
              else:
-                 new_prefix[f"{prefix_dict.prefix}{key}"] = prefix_dict[key]
-         new_prefix_dict = PrefixedDataset(new_prefix, prefix_dict.prefix)
-         remove_hook_from_module(module)
-
-     # wrap forward call of module to perform quantized actions based on calltime status
-     wrap_module_forward_quantized(module, scheme)
-
-     if offloaded:
-         # we need to re-add the hook for offloading now that we've wrapped forward
-         add_hook_to_module(module, hook)
-         if prefix_dict is not None:
-             module._hf_hook.weights_map = new_prefix_dict
+                 _LOGGER.warning(
+                     f"module type {type(module)} targeted for weight quantization but "
+                     "has no attribute weight, skipping weight quantization "
+                     f"for {type(module)}"
+                 )
+
+         if scheme.output_activations is not None:
+             if not is_kv_cache_quant_scheme(scheme):
+                 _initialize_scale_zero_point_observer(
+                     module, "output", scheme.output_activations
+                 )
+
+         module.quantization_scheme = scheme
+         module.quantization_status = QuantizationStatus.INITIALIZED
+
+         offloaded = False
+         if is_module_offloaded(module):
+             try:
+                 from accelerate.hooks import add_hook_to_module, remove_hook_from_module
+                 from accelerate.utils import PrefixedDataset
+             except ModuleNotFoundError:
+                 raise ModuleNotFoundError(
+                     "Offloaded model detected. To use CPU offloading with "
+                     "compressed-tensors the `accelerate` package must be installed, "
+                     "run `pip install compressed-tensors[accelerate]`"
+                 )
+
+             offloaded = True
+             hook = module._hf_hook
+             prefix_dict = module._hf_hook.weights_map
+             new_prefix = {}
+
+             # recreate the prefix dict (since it is immutable)
+             # and add quantization parameters
+             for key, data in module.named_parameters():
+                 if key not in prefix_dict:
+                     new_prefix[f"{prefix_dict.prefix}{key}"] = data
+                 else:
+                     new_prefix[f"{prefix_dict.prefix}{key}"] = prefix_dict[key]
+             new_prefix_dict = PrefixedDataset(new_prefix, prefix_dict.prefix)
+             remove_hook_from_module(module)
+
+         # wrap forward call of module to perform
+         # quantized actions based on calltime status
+         wrap_module_forward_quantized(module, scheme)
+
+         if offloaded:
+             # we need to re-add the hook for offloading now that we've wrapped forward
+             add_hook_to_module(module, hook)
+             if prefix_dict is not None:
+                 module._hf_hook.weights_map = new_prefix_dict


@@ -189,3 +205,34 @@ def _initialize_scale_zero_point_observer(
              requires_grad=False,
          )
          module.register_parameter(f"{base_name}_g_idx", init_g_idx)
+
+
+ def is_attention_module(module: Module):
+     return "attention" in module.__class__.__name__.lower() and (
+         hasattr(module, "k_proj")
+         or hasattr(module, "v_proj")
+         or hasattr(module, "qkv_proj")
+     )
+
+
+ def _initialize_attn_scales(module: Module) -> None:
+     """Initlaize k_scale, v_scale for self_attn"""
+
+     expected_shape = 1  # per tensor
+
+     param = next(module.parameters())
+     scale_dtype = param.dtype
+     device = param.device
+
+     init_scale = Parameter(
+         torch.empty(expected_shape, dtype=scale_dtype, device=device),
+         requires_grad=False,
+     )
+
+     module.register_parameter(KVCacheScaleType.KEY.value, init_scale)
+
+     init_scale = Parameter(
+         torch.empty(expected_shape, dtype=scale_dtype, device=device),
+         requires_grad=False,
+     )
+     module.register_parameter(KVCacheScaleType.VALUE.value, init_scale)
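A toy check of the two new helpers (the module names here are hypothetical): a module only counts as an attention module if its class name contains "attention" and it exposes a `k_proj`/`v_proj`/`qkv_proj` attribute, and `_initialize_attn_scales` registers per-tensor `k_scale`/`v_scale` placeholders on it.

```python
from torch.nn import Linear, Module

from compressed_tensors.quantization.lifecycle.initialize import (
    _initialize_attn_scales,
    is_attention_module,
)


class ToyAttention(Module):
    def __init__(self):
        super().__init__()
        self.k_proj = Linear(8, 8)
        self.v_proj = Linear(8, 8)


attn = ToyAttention()
assert is_attention_module(attn)
assert not is_attention_module(Linear(8, 8))  # plain Linear keeps the regular path

_initialize_attn_scales(attn)
print(attn.k_scale.shape, attn.v_scale.shape)  # torch.Size([1]) torch.Size([1])
```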
compressed_tensors/quantization/quant_args.py

@@ -122,6 +122,12 @@ class QuantizationArgs(BaseModel, use_enum_values=True):

          return Observer.load_from_registry(self.observer, quantization_args=self)

+     def get_kv_cache(self):
+         """Get the singleton KV Cache"""
+         from compressed_tensors.quantization.cache import QuantizedKVParameterCache
+
+         return QuantizedKVParameterCache(self)
+
      @field_validator("type", mode="before")
      def validate_type(cls, value) -> QuantizationType:
          if isinstance(value, str):
compressed_tensors/quantization/quant_config.py

@@ -24,7 +24,7 @@ from compressed_tensors.quantization.quant_scheme import (
  from compressed_tensors.quantization.utils import (
      calculate_compression_ratio,
      is_module_quantized,
-     iter_named_leaf_modules,
+     iter_named_quantizable_modules,
      module_type,
      parse_out_kv_cache_args,
  )
@@ -177,7 +177,9 @@ class QuantizationConfig(BaseModel):
          quantization_status = None
          ignore = {}
          quantization_type_names = set()
-         for name, submodule in iter_named_leaf_modules(model):
+         for name, submodule in iter_named_quantizable_modules(
+             model, include_children=True, include_attn=True
+         ):
              layer_type = module_type(submodule)
              if not is_module_quantized(submodule):
                  if layer_type not in ignore:
@@ -241,6 +243,9 @@
          )

      def requires_calibration_data(self):
+         if self.kv_cache_scheme is not None:
+             return True
+
          for _, scheme in self.config_groups.items():
              if scheme.input_activations is not None:
                  if not scheme.input_activations.dynamic:
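A quick consequence of the `requires_calibration_data` change: a config whose only quantization is a `kv_cache_scheme` now reports that calibration data is needed. The field values below are illustrative.

```python
from compressed_tensors.quantization import QuantizationArgs, QuantizationConfig

config = QuantizationConfig(
    config_groups={},
    kv_cache_scheme=QuantizationArgs(num_bits=8, type="float"),
)
assert config.requires_calibration_data()  # True solely because kv_cache_scheme is set
```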
compressed_tensors/quantization/utils/helpers.py

@@ -13,8 +13,7 @@
  # limitations under the License.

  import logging
- import re
- from typing import List, Optional, Tuple
+ from typing import Generator, List, Optional, Tuple

  import torch
  from compressed_tensors.quantization.observers.base import Observer
@@ -28,7 +27,6 @@ __all__ = [
      "infer_quantization_status",
      "is_module_quantized",
      "is_model_quantized",
-     "iter_named_leaf_modules",
      "module_type",
      "calculate_compression_ratio",
      "get_torch_bit_depth",
@@ -36,9 +34,14 @@ __all__ = [
      "parse_out_kv_cache_args",
      "KV_CACHE_TARGETS",
      "is_kv_cache_quant_scheme",
+     "iter_named_leaf_modules",
+     "iter_named_quantizable_modules",
  ]

- KV_CACHE_TARGETS = ["re:.*k_proj", "re:.*v_proj"]
+ # target the self_attn layer
+ # QuantizedKVParameterCache is responsible for obtaining the k_scale and v_scale
+ KV_CACHE_TARGETS = ["re:.*self_attn$"]
+
  _LOGGER: logging.Logger = logging.getLogger(__name__)


@@ -106,11 +109,10 @@ def module_type(module: Module) -> str:
      return type(module).__name__


- def iter_named_leaf_modules(model: Module) -> Tuple[str, Module]:
+ def iter_named_leaf_modules(model: Module) -> Generator[Tuple[str, Module], None, None]:
      """
      Yields modules that do not have any submodules except observers. The observers
      themselves are not yielded
-
      :param model: model to get leaf modules of
      :returns: generator tuple of (name, leaf_submodule)
      """
@@ -128,6 +130,37 @@ def iter_named_leaf_modules(model: Module) -> Tuple[str, Module]:
          yield name, submodule


+ def iter_named_quantizable_modules(
+     model: Module, include_children: bool = True, include_attn: bool = False
+ ) -> Generator[Tuple[str, Module], None, None]:
+     """
+     Yield name and submodule of
+     - leaf modules, set by include_children
+     - attention modyles, set by include_attn
+
+     :param model: model to get leaf modules of
+     :param include_children: flag to get the leaf modules
+     :param inlcude_attn: flag to get the attention modules
+     :returns: generator tuple of (name, submodule)
+     """
+     for name, submodule in model.named_modules():
+         if include_children:
+             children = list(submodule.children())
+             if len(children) == 0 and not isinstance(submodule, Observer):
+                 yield name, submodule
+             else:
+                 has_non_observer_children = False
+                 for child in children:
+                     if not isinstance(child, Observer):
+                         has_non_observer_children = True
+
+                 if not has_non_observer_children:
+                     yield name, submodule
+         if include_attn:
+             if name.endswith("self_attn"):
+                 yield name, submodule
+
+
  def get_torch_bit_depth(value: torch.Tensor) -> int:
      """
      Determine the number of bits used to represent the dtype of a tensor
@@ -204,19 +237,11 @@ def is_kv_cache_quant_scheme(scheme: QuantizationScheme) -> bool:
      :param scheme: The QuantizationScheme to investigate
      :return: boolean flag
      """
-     if len(scheme.targets) == 1:
-         # match on the KV_CACHE_TARGETS regex pattern
-         # if there is only one target
-         is_match_targets = any(
-             [re.match(pattern[3:], scheme.targets[0]) for pattern in KV_CACHE_TARGETS]
-         )
-     else:
-         # match on the exact KV_CACHE_TARGETS
-         # if there are multiple targets
-         is_match_targets = set(KV_CACHE_TARGETS) == set(scheme.targets)
+     for target in scheme.targets:
+         if target in KV_CACHE_TARGETS:
+             return True

-     is_match_output_activations = scheme.output_activations is not None
-     return is_match_targets and is_match_output_activations
+     return False


  def parse_out_kv_cache_args(
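A small sketch of how the reworked iteration and targets fit together (the toy module layout is hypothetical): with `include_attn=True` the iterator also yields non-leaf modules whose name ends in `self_attn`, which is what the new `KV_CACHE_TARGETS` pattern `re:.*self_attn$` is meant to match.

```python
import torch.nn as nn

from compressed_tensors.quantization.utils import iter_named_quantizable_modules

model = nn.ModuleDict(
    {"self_attn": nn.ModuleDict({"k_proj": nn.Linear(8, 8), "v_proj": nn.Linear(8, 8)})}
)

names = [name for name, _ in iter_named_quantizable_modules(model, include_attn=True)]
# leaf Linear layers are still yielded, plus the parent attention block itself
assert "self_attn.k_proj" in names and "self_attn" in names
```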
compressed_tensors_nightly-0.6.0.20240925.dist-info/METADATA → compressed_tensors_nightly-0.6.0.20240928.dist-info/METADATA

@@ -1,6 +1,6 @@
  Metadata-Version: 2.1
  Name: compressed-tensors-nightly
- Version: 0.6.0.20240925
+ Version: 0.6.0.20240928
  Summary: Library for utilization of compressed safetensors of neural network models
  Home-page: https://github.com/neuralmagic/compressed-tensors
  Author: Neuralmagic, Inc.
compressed_tensors_nightly-0.6.0.20240925.dist-info/RECORD → compressed_tensors_nightly-0.6.0.20240928.dist-info/RECORD

@@ -1,12 +1,12 @@
  compressed_tensors/__init__.py,sha256=UtKmifNeBCSE2TZSAfduVNNzHY-3V7bLjZ7n7RuXLOE,812
- compressed_tensors/base.py,sha256=Mq4mfVQcJhNpha-BXzpOfpmFIdl01o09BJE7D2oQ_00,796
+ compressed_tensors/base.py,sha256=7fdFGo8lxjLvrsbBEn0KqceGzcI4RdMSTh8mR6J1Hws,833
  compressed_tensors/version.py,sha256=83tBdwNu2sUhiLPvv6tRNh4Y7u70sZ1TFy3ydWctVL8,1586
  compressed_tensors/compressors/__init__.py,sha256=wmX4VnkUTS63xBwK5-6w8FP78bNZpcdcqvf2KOEC5E4,1133
  compressed_tensors/compressors/base.py,sha256=NfVkhq6PRiq2cvAXaUXLoqC_nVYWdSrkE12c9AXYSMo,9956
  compressed_tensors/compressors/dense.py,sha256=xcWECjcRY4INN6jC7vHx5wvUX3NmnKlxA9SVE1A6m2Q,1267
  compressed_tensors/compressors/helpers.py,sha256=k9avlkmeYj6vkOAvl-MgcixtP7ib24SCfhzZ-RusXfw,5403
  compressed_tensors/compressors/marlin_24.py,sha256=e7fGUyZbjUpA5VUMCPxqcYPGNiwoDKupHJaXWCoVKRw,9410
- compressed_tensors/compressors/model_compressor.py,sha256=Wq-NbjtaVOEElDpcjEYun6QFvAIZee8ZAw_wbifuTDA,16793
+ compressed_tensors/compressors/model_compressor.py,sha256=3pMfGTTb8bN8PRNCFuH5k0RbP38r8GS_-cPgCkzL9vk,14355
  compressed_tensors/compressors/naive_quantized.py,sha256=z3h3ca5xKCN69mahutxcbzdv-OysiaxaM8P-Qum6zUQ,4823
  compressed_tensors/compressors/pack_quantized.py,sha256=27RVmJ2wg2dvCoawj407HSmKT3VPGJ6ujAMHlT26WlI,7571
  compressed_tensors/compressors/sparse_bitmask.py,sha256=kiDwBlFV0sJGLcIdDYxIiuF64ccgwDfqq1hWRQThYDc,8647
@@ -16,18 +16,19 @@ compressed_tensors/config/dense.py,sha256=NgSxnFCnckU9-iunxEaqiFwqgdO7YYxlWKR74j
  compressed_tensors/config/sparse_bitmask.py,sha256=pZUboRNZTu6NajGOQEFExoPknak5ynVAUeiiYpS1Gt8,1308
  compressed_tensors/linear/__init__.py,sha256=fH6rjBYAxuwrTzBTlTjTgCYNyh6TCvCqajCz4Im4YrA,617
  compressed_tensors/linear/compressed_linear.py,sha256=G0gEFfxLAUsgRcnfSV-PKz1ZBNTVokOauOoup7SE1mw,3210
- compressed_tensors/quantization/__init__.py,sha256=83J5bPB7PavN2TfCoW7_vEDhfYpm4TDrqYO9vdSQ5bk,760
- compressed_tensors/quantization/quant_args.py,sha256=CmyVtjJeHlqCW-7R5Z7tIw6lXUrzCX6Y9bwgmMxEudY,8069
- compressed_tensors/quantization/quant_config.py,sha256=NpVu8YJ4Xw2pIQW_PGaNaml8kx1bUnxkvb0jBYWbKdE,9971
+ compressed_tensors/quantization/__init__.py,sha256=nWP_fsl6Nn0ksEgZPzerGiETdvF-ZfNwPnwGlRiR5pY,805
+ compressed_tensors/quantization/cache.py,sha256=vnBB5zasO_XpHomZvzUPVVbzyCz2VgebsHePm0kANzY,6831
+ compressed_tensors/quantization/quant_args.py,sha256=73KevZXHyrkMCT_3CxbYHz70fI3i-wcF8NvN0wsBPK4,8271
+ compressed_tensors/quantization/quant_config.py,sha256=xcCLkPomAOfjB1X8PmQTw1Bmqs8_JF52dSQ9W07VQZc,10119
  compressed_tensors/quantization/quant_scheme.py,sha256=2ITawuNf76E1CDYBWrfpMP8tyZFykzwU99-eD-WggsM,5930
  compressed_tensors/quantization/lifecycle/__init__.py,sha256=MXE2E7GfIfRRfhrdGy2Og3AZOz5N59B0ZGFcsD89y6c,821
- compressed_tensors/quantization/lifecycle/apply.py,sha256=uftWFunr_CpCZM_qWfo2O1USXKB2qSYD1pBJsO8BuCU,15285
+ compressed_tensors/quantization/lifecycle/apply.py,sha256=_rd56GZZkhbu0HWiq6iYzgcnkMsX3GCs-e8DvtmWmbQ,15668
  compressed_tensors/quantization/lifecycle/calibration.py,sha256=PlS_EqCOPqJD3QKuLPXO9AOtDzXtQWvEBTynFv-FFVw,2698
  compressed_tensors/quantization/lifecycle/compressed.py,sha256=Fj9n66IN0EWsOAkBHg3O0GlOQpxstqjCcs0ttzMXrJ0,2296
- compressed_tensors/quantization/lifecycle/forward.py,sha256=PljD9pzATILEOiC3ZdHUTsfSbZdAa6iSIxWmvAHLG9I,13688
- compressed_tensors/quantization/lifecycle/frozen.py,sha256=h1XYt89MouBTf3jTYLG_6OdFxIu5q2N8tPjsy6J4E6Y,1726
+ compressed_tensors/quantization/lifecycle/forward.py,sha256=eLup6QDRUUp_Ozcas7RDRLIXBWjFbxn5gWbcAIJEGlw,15715
+ compressed_tensors/quantization/lifecycle/frozen.py,sha256=NiJw7NP7pcT6idWFa8vksgiLoT8oQ975e57S4QfD2QQ,1874
  compressed_tensors/quantization/lifecycle/helpers.py,sha256=TmLY_G5VP_Fg2Ywio_dxoHRTxOKZdT7_aG5S9WtD4zI,2424
- compressed_tensors/quantization/lifecycle/initialize.py,sha256=S5Kwy16Da8WUIIpa1xVKc72MijJ5C_rqM6JjanZ7MGk,7133
+ compressed_tensors/quantization/lifecycle/initialize.py,sha256=HAtSm7vKOZ3kGZuWe2B8LsmfC5B5vIKlc0V8C4rAF4Y,8819
  compressed_tensors/quantization/observers/__init__.py,sha256=4Sa7rqi5RB_S5bPO8KmncETiqDsoMBhwP37arlQym8s,764
  compressed_tensors/quantization/observers/base.py,sha256=5ovQicWPYHjIxr6-EkQ4lgOX0PpI9g23iSzKpxjM1Zg,8420
  compressed_tensors/quantization/observers/helpers.py,sha256=s_A23Qa_BLfOdHJCN5bm-qPWkhjjj_RIVrhSp1Y9Dtk,4211
@@ -35,7 +36,7 @@ compressed_tensors/quantization/observers/memoryless.py,sha256=jH_c6K3gxf4W3VNXQ
  compressed_tensors/quantization/observers/min_max.py,sha256=sQXqU3z-voxIDfR_9mQzwQUflZj2sASm_G8CYaXntFw,3865
  compressed_tensors/quantization/observers/mse.py,sha256=Aeh-253Vbab1F8cYuBiGNn4OXWJ67wXQ_JVfl3mu2a8,6034
  compressed_tensors/quantization/utils/__init__.py,sha256=VdtEmP0bvuND_IGQnyqUPc5lnFp-1_yD7StKSX4x80w,656
- compressed_tensors/quantization/utils/helpers.py,sha256=pwvU613XRvMDtI5b39II5jukBl5OUCqoX0ofVRpOFRY,8633
+ compressed_tensors/quantization/utils/helpers.py,sha256=y4LEyC2oUd876ZMdALWKGH3Ct5EgBJZV4id_NUjTGH8,9531
  compressed_tensors/registry/__init__.py,sha256=FwLSNYqfIrb5JD_6OK_MT4_svvKTN_nEhpgQlQvGbjI,658
  compressed_tensors/registry/registry.py,sha256=fxjOjh2wklCvJhQxwofdy-zV8q7MkQ85SLG77nml2iA,11890
  compressed_tensors/utils/__init__.py,sha256=gS4gSU2pwcAbsKj-6YMaqhm25udFy6ISYaWBf-myRSM,808
@@ -45,8 +46,8 @@ compressed_tensors/utils/permutations_24.py,sha256=kx6fsfDHebx94zsSzhXGyCyuC9sVy
  compressed_tensors/utils/permute.py,sha256=V6tJLKo3Syccj-viv4F7ZKZgJeCB-hl-dK8RKI_kBwI,2355
  compressed_tensors/utils/safetensors_load.py,sha256=m08ANVuTBxQdoa6LufDgcNJ7wCLDJolyZljB8VEybAU,8578
  compressed_tensors/utils/semi_structured_conversions.py,sha256=XKNffPum54kPASgqKzgKvyeqWPAkair2XEQXjkp7ho8,13489
- compressed_tensors_nightly-0.6.0.20240925.dist-info/LICENSE,sha256=xx0jnfkXJvxRnG63LTGOxlggYnIysveWIZ6H3PNdCrQ,11357
- compressed_tensors_nightly-0.6.0.20240925.dist-info/METADATA,sha256=AHeC-ko08CtK8_xQUnuNlWNQIhmDcKzDpihAiMBHjR8,6799
- compressed_tensors_nightly-0.6.0.20240925.dist-info/WHEEL,sha256=eOLhNAGa2EW3wWl_TU484h7q1UNgy0JXjjoqKoxAAQc,92
- compressed_tensors_nightly-0.6.0.20240925.dist-info/top_level.txt,sha256=w2i-GyPs2s1UwVxvutSvN_lM22SXC2hQFBmoMcPnV7Y,19
- compressed_tensors_nightly-0.6.0.20240925.dist-info/RECORD,,
+ compressed_tensors_nightly-0.6.0.20240928.dist-info/LICENSE,sha256=xx0jnfkXJvxRnG63LTGOxlggYnIysveWIZ6H3PNdCrQ,11357
+ compressed_tensors_nightly-0.6.0.20240928.dist-info/METADATA,sha256=vndAZXPsHUGFnoR1oLqalmP1tnMaAUx7QgXHPVrwarE,6799
+ compressed_tensors_nightly-0.6.0.20240928.dist-info/WHEEL,sha256=eOLhNAGa2EW3wWl_TU484h7q1UNgy0JXjjoqKoxAAQc,92
+ compressed_tensors_nightly-0.6.0.20240928.dist-info/top_level.txt,sha256=w2i-GyPs2s1UwVxvutSvN_lM22SXC2hQFBmoMcPnV7Y,19
+ compressed_tensors_nightly-0.6.0.20240928.dist-info/RECORD,,