compressed-tensors-nightly 0.4.0.20240623__py3-none-any.whl → 0.4.0.20240627__py3-none-any.whl

This diff compares the contents of two publicly released versions of this package, as published to their public registry. It is provided for informational purposes only.
@@ -15,3 +15,4 @@
 SPARSITY_CONFIG_NAME = "sparsity_config"
 QUANTIZATION_CONFIG_NAME = "quantization_config"
 COMPRESSION_CONFIG_NAME = "compression_config"
+KV_CACHE_SCHEME_NAME = "kv_cache_scheme"
@@ -45,7 +45,7 @@ class Compressor(RegistryMixin):
         raise NotImplementedError()
 
     def decompress(
-        self, path_to_model_or_tensors: str, device: str = "cpu"
+        self, path_to_model_or_tensors: str, device: str = "cpu", **kwargs
     ) -> Generator[Tuple[str, Tensor], None, None]:
         """
         Reads a compressed state dict located at path_to_model_or_tensors
@@ -29,6 +29,6 @@ class DenseCompressor(Compressor):
         return model_state
 
     def decompress(
-        self, path_to_model_or_tensors: str, device: str = "cpu"
+        self, path_to_model_or_tensors: str, device: str = "cpu", **kwargs
     ) -> Generator[Tuple[str, Tensor], None, None]:
         return iter([])
@@ -107,7 +107,7 @@ class Marlin24Compressor(Compressor):
     def compress(
         self,
         model_state: Dict[str, Tensor],
-        model_quant_args: Dict[str, QuantizationArgs],
+        names_to_scheme: Dict[str, QuantizationArgs],
         **kwargs,
     ) -> Dict[str, Tensor]:
         """
@@ -115,11 +115,11 @@ class Marlin24Compressor(Compressor):
         with the Marlin24 kernel
 
         :param model_state: state dict of uncompressed model
-        :param model_quant_args: quantization args for each quantized weight, needed for
+        :param names_to_scheme: quantization args for each quantized weight, needed for
             quantize function to calculate bit depth
         :return: compressed state dict
         """
-        self.validate_quant_compatability(model_quant_args)
+        self.validate_quant_compatability(names_to_scheme)
 
         compressed_dict = {}
         weight_suffix = ".weight"
@@ -139,7 +139,7 @@ class Marlin24Compressor(Compressor):
                     value = value.to(torch.float16)
 
                     # quantize weight, keeping it as a float16 for now
-                    quant_args = model_quant_args[prefix]
+                    quant_args = names_to_scheme[prefix]
                     value = quantize(
                         x=value, scale=scale, zero_point=zp, args=quant_args
                     )
@@ -175,7 +175,7 @@ class Marlin24Compressor(Compressor):
         return compressed_dict
 
     def decompress(
-        self, path_to_model_or_tensors: str, device: str = "cpu"
+        self, path_to_model_or_tensors: str, device: str = "cpu", **kwargs
     ) -> Generator[Tuple[str, Tensor], None, None]:
         raise NotImplementedError(
             "Decompression is not implemented for the Marlin24 Compressor."
@@ -231,7 +231,7 @@ class ModelCompressor:
         quantized_modules_to_args = map_modules_to_quant_args(model)
         if self.quantization_compressor is not None:
             compressed_state_dict = self.quantization_compressor.compress(
-                state_dict, model_quant_args=quantized_modules_to_args
+                state_dict, names_to_scheme=quantized_modules_to_args
             )
 
         if self.sparsity_compressor is not None:
@@ -260,9 +260,11 @@ class ModelCompressor:
             setattr(model, SPARSITY_CONFIG_NAME, self.sparsity_compressor.config)
 
         if self.quantization_compressor is not None:
-            apply_quantization_config(model, self.quantization_config)
+            names_to_scheme = apply_quantization_config(model, self.quantization_config)
             load_pretrained_quantization(model, model_path)
-            dense_gen = self.quantization_compressor.decompress(model_path)
+            dense_gen = self.quantization_compressor.decompress(
+                model_path, names_to_scheme=names_to_scheme
+            )
             self._replace_weights(dense_gen, model)
 
         def update_status(module):
@@ -49,14 +49,14 @@ class QuantizationCompressor(Compressor):
     def compress(
         self,
         model_state: Dict[str, Tensor],
-        model_quant_args: Dict[str, QuantizationArgs],
+        names_to_scheme: Dict[str, QuantizationArgs],
         **kwargs,
     ) -> Dict[str, Tensor]:
         """
         Compresses a dense state dict
 
         :param model_state: state dict of uncompressed model
-        :param model_quant_args: quantization args for each quantized weight, needed for
+        :param names_to_scheme: quantization args for each quantized weight, needed for
             quantize function to calculate bit depth
         :return: compressed state dict
         """
@@ -73,7 +73,7 @@ class QuantizationCompressor(Compressor):
                 zp = model_state.get(merge_names(prefix, "weight_zero_point"), None)
                 if scale is not None and zp is not None:
                     # weight is quantized, compress it
-                    quant_args = model_quant_args[prefix]
+                    quant_args = names_to_scheme[prefix]
                     if can_quantize(value, quant_args):
                         # only quantize if not already quantized
                         value = quantize(
@@ -93,7 +93,7 @@ class QuantizationCompressor(Compressor):
         return compressed_dict
 
     def decompress(
-        self, path_to_model_or_tensors: str, device: str = "cpu"
+        self, path_to_model_or_tensors: str, device: str = "cpu", **kwargs
     ) -> Generator[Tuple[str, Tensor], None, None]:
         """
         Reads a compressed state dict located at path_to_model_or_tensors
@@ -29,7 +29,13 @@ from torch import Tensor
 from tqdm import tqdm
 
 
-__all__ = ["PackedQuantizationCompressor", "pack_4bit_ints", "unpack_4bit_ints"]
+__all__ = [
+    "PackedQuantizationCompressor",
+    "pack_4bit_ints",
+    "pack_8bit_ints",
+    "unpack_4bit_ints",
+    "unpack_8bit_ints",
+]
 
 _LOGGER: logging.Logger = logging.getLogger(__name__)
 
@@ -50,14 +56,14 @@ class PackedQuantizationCompressor(Compressor):
     def compress(
         self,
         model_state: Dict[str, Tensor],
-        model_quant_args: Dict[str, QuantizationArgs],
+        names_to_scheme: Dict[str, QuantizationArgs],
         **kwargs,
     ) -> Dict[str, Tensor]:
         """
         Compresses a dense state dict
 
         :param model_state: state dict of uncompressed model
-        :param model_quant_args: quantization args for each quantized weight, needed for
+        :param names_to_scheme: quantization args for each quantized weight, needed for
             quantize function to calculate bit depth
         :return: compressed state dict
         """
@@ -75,7 +81,7 @@ class PackedQuantizationCompressor(Compressor):
                 shape = torch.tensor(value.shape)
                 if scale is not None and zp is not None:
                     # weight is quantized, compress it
-                    quant_args = model_quant_args[prefix]
+                    quant_args = names_to_scheme[prefix]
                     if can_quantize(value, quant_args):
                         # convert weight to an int if not already compressed
                         value = quantize(
@@ -85,7 +91,11 @@ class PackedQuantizationCompressor(Compressor):
                             args=quant_args,
                             dtype=torch.int8,
                         )
-                    value = pack_4bit_ints(value.cpu())
+
+                    if quant_args.num_bits == 8:
+                        value = pack_8bit_ints(value.cpu())
+                    else:
+                        value = pack_4bit_ints(value.cpu())
                     compressed_dict[merge_names(prefix, "weight_shape")] = shape
                     compressed_dict[merge_names(prefix, "weight_packed")] = value
                     continue
@@ -101,7 +111,10 @@ class PackedQuantizationCompressor(Compressor):
         return compressed_dict
 
     def decompress(
-        self, path_to_model_or_tensors: str, device: str = "cpu"
+        self,
+        path_to_model_or_tensors: str,
+        names_to_scheme: Dict[str, QuantizationArgs],
+        device: str = "cpu",
     ) -> Generator[Tuple[str, Tensor], None, None]:
         """
         Reads a compressed state dict located at path_to_model_or_tensors
@@ -119,6 +132,7 @@ class PackedQuantizationCompressor(Compressor):
         for weight_name in weight_mappings.keys():
             weight_data = {}
             for param_name, safe_path in weight_mappings[weight_name].items():
+                weight_data["num_bits"] = names_to_scheme.get(weight_name).num_bits
                 full_name = merge_names(weight_name, param_name)
                 with safe_open(safe_path, framework="pt", device=device) as f:
                     weight_data[param_name] = f.get_tensor(full_name)
@@ -127,8 +141,12 @@ class PackedQuantizationCompressor(Compressor):
                 zero_point = weight_data.get("weight_zero_point", None)
                 scale = weight_data["weight_scale"]
                 weight = weight_data["weight_packed"]
+                num_bits = weight_data["num_bits"]
                 original_shape = torch.Size(weight_data["weight_shape"])
-                unpacked = unpack_4bit_ints(weight, original_shape)
+                if num_bits == 4:
+                    unpacked = unpack_4bit_ints(weight, original_shape)
+                else:
+                    unpacked = unpack_8bit_ints(weight, original_shape)
                 decompressed = dequantize(
                     x_q=unpacked,
                     scale=scale,
@@ -137,6 +155,19 @@ class PackedQuantizationCompressor(Compressor):
                 yield merge_names(weight_name, "weight"), decompressed
 
 
+def pack_8bit_ints(value: torch.Tensor) -> torch.Tensor:
+    """
+    Packs a tensor of int8 into int32s with padding
+
+    :param value: tensor to pack
+    :returns: packed int32 tensor
+    """
+    # need to convert to unsigned 8bit to use numpy's pack/unpack
+    value_uint = (value - 128).to(torch.uint8)
+    bits = np.unpackbits(value_uint, axis=-1, bitorder="little")
+    return _pack_bits(bits_to_pack=bits)
+
+
 def pack_4bit_ints(value: torch.Tensor) -> torch.Tensor:
     """
     Packs a tensor of int4 weights stored in int8 into int32s with padding
@@ -152,22 +183,31 @@ def pack_4bit_ints(value: torch.Tensor) -> torch.Tensor:
     bits = np.unpackbits(temp.numpy(), axis=-1, bitorder="little")
     ranges = np.array([range(x, x + 4) for x in range(0, bits.shape[1], 8)]).flatten()
     only_4_bits = bits[:, ranges]  # top 4 bits are 0 because we're really uint4
+    return _pack_bits(bits_to_pack=only_4_bits)
 
-    # pad each row to fill a full 32bit int
-    pack_depth = 32
-    padding = (
-        math.ceil(only_4_bits.shape[1] / pack_depth) * pack_depth - only_4_bits.shape[1]
-    )
-    padded_bits = np.pad(
-        only_4_bits, pad_width=[(0, 0), (0, padding)], constant_values=0
-    )
 
-    # after packbits each uint8 is two packed uint4s
-    # then we keep the bit pattern the same but convert to int32
-    compressed = np.packbits(padded_bits, axis=-1, bitorder="little")
-    compressed = np.ascontiguousarray(compressed).view(np.int32)
+def unpack_8bit_ints(value: torch.Tensor, shape: torch.Size) -> torch.Tensor:
+    """
+    Unpacks a tensor packed int8 weights in int32
 
-    return torch.from_numpy(compressed)
+    :param value: tensor to upack
+    :param shape: shape to unpack into, used to remove padding
+    :returns: unpacked int8 tensor
+    """
+    if value.dtype is not torch.int32:
+        raise ValueError(
+            f"Expected {torch.int32} but got {value.dtype}, Aborting unpack."
+        )
+
+    # unpack bits and undo padding to nearest int32 bits
+    individual_depth = 8
+    as_uint8 = value.numpy().view(np.uint8)
+    bits = np.unpackbits(as_uint8, axis=-1, bitorder="little")
+    original_row_size = int(shape[1] * individual_depth)
+    bits = bits[:, :original_row_size]
+    bits = np.packbits(bits, axis=-1, bitorder="little")
+    final = (bits - 128).astype(np.int8)
+    return torch.from_numpy(final)
 
 
 def unpack_4bit_ints(value: torch.Tensor, shape: torch.Size) -> torch.Tensor:
@@ -206,3 +246,27 @@ def unpack_4bit_ints(value: torch.Tensor, shape: torch.Size) -> torch.Tensor:
     final = repacked.astype(np.int8) - 8
 
     return torch.from_numpy(final)
+
+
+def _pack_bits(bits_to_pack: torch.Tensor) -> torch.Tensor:
+    """
+    Pack a tensor of bits to int32.
+
+    :param bits_to_pack: tensor of bits to pack
+    """
+    # pad each row to fill a full 32bit int
+    pack_depth = 32
+    padding = (
+        math.ceil(bits_to_pack.shape[1] / pack_depth) * pack_depth
+        - bits_to_pack.shape[1]
+    )
+    padded_bits = np.pad(
+        bits_to_pack, pad_width=[(0, 0), (0, padding)], constant_values=0
+    )
+
+    # after packbits each uint8 is two packed uint4s
+    # then we keep the bit pattern the same but convert to int32
+    compressed = np.packbits(padded_bits, axis=-1, bitorder="little")
+    compressed = np.ascontiguousarray(compressed).view(np.int32)
+
+    return torch.from_numpy(compressed)
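Note: the 8-bit path added above mirrors the existing 4-bit path, with the shared padding-and-packing step factored into `_pack_bits`. A minimal round-trip sketch of how `pack_8bit_ints` / `unpack_8bit_ints` could be exercised (the tensor values are made up for illustration and are not part of the diff):

    import torch
    from compressed_tensors.compressors.pack_quantized import (
        pack_8bit_ints,
        unpack_8bit_ints,
    )

    # fake int8 "quantized" weights with a 2D shape, as produced by quantize(...)
    weights = torch.randint(-8, 8, (4, 64), dtype=torch.int8)

    packed = pack_8bit_ints(weights)                    # int32, rows padded to 32-bit words
    restored = unpack_8bit_ints(packed, weights.shape)  # original shape strips the padding
    assert torch.equal(restored, weights)               # round-trips back to the int8 values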
@@ -72,7 +72,7 @@ class BitmaskCompressor(Compressor):
         return compressed_dict
 
     def decompress(
-        self, path_to_model_or_tensors: str, device: str = "cpu"
+        self, path_to_model_or_tensors: str, device: str = "cpu", **kwargs
     ) -> Generator[Tuple[str, Tensor], None, None]:
         """
         Reads a bitmask compressed state dict located
@@ -15,7 +15,9 @@
 import logging
 import re
 from collections import OrderedDict
-from typing import Dict, Iterable, Optional
+from typing import Dict, Iterable, List, Optional
+from typing import OrderedDict as OrderedDictType
+from typing import Union
 
 import torch
 from compressed_tensors.quantization.lifecycle.calibration import (
@@ -28,12 +30,16 @@ from compressed_tensors.quantization.lifecycle.frozen import freeze_module_quant
 from compressed_tensors.quantization.lifecycle.initialize import (
     initialize_module_for_quantization,
 )
+from compressed_tensors.quantization.quant_args import QuantizationArgs
 from compressed_tensors.quantization.quant_config import (
     QuantizationConfig,
     QuantizationStatus,
 )
+from compressed_tensors.quantization.quant_scheme import QuantizationScheme
 from compressed_tensors.quantization.utils import (
+    KV_CACHE_TARGETS,
     infer_quantization_status,
+    is_kv_cache_quant_scheme,
     iter_named_leaf_modules,
 )
 from compressed_tensors.utils.helpers import fix_fsdp_module_name
@@ -45,7 +51,7 @@ __all__ = [
     "load_pretrained_quantization",
     "apply_quantization_config",
     "apply_quantization_status",
-    "find_first_name_or_class_match",
+    "find_name_or_class_matches",
 ]
 
 from compressed_tensors.quantization.utils.helpers import is_module_quantized
@@ -96,7 +102,7 @@ def load_pretrained_quantization(model: Module, model_name_or_path: str):
             )
 
 
-def apply_quantization_config(model: Module, config: QuantizationConfig):
+def apply_quantization_config(model: Module, config: QuantizationConfig) -> Dict:
    """
    Initializes the model for quantization in-place based on the given config
 
@@ -106,6 +112,8 @@ def apply_quantization_config(model: Module, config: QuantizationConfig):
     # build mapping of targets to schemes for easier matching
     # use ordered dict to preserve target ordering in config
     target_to_scheme = OrderedDict()
+    config = process_quantization_config(config)
+    names_to_scheme = OrderedDict()
     for scheme in config.config_groups.values():
         for target in scheme.targets:
             target_to_scheme[target] = scheme
@@ -116,13 +124,16 @@ def apply_quantization_config(model: Module, config: QuantizationConfig):
     for name, submodule in iter_named_leaf_modules(model):
         # potentially fix module name to remove FSDP wrapper prefix
         name = fix_fsdp_module_name(name)
-        if find_first_name_or_class_match(name, submodule, config.ignore):
+        if find_name_or_class_matches(name, submodule, config.ignore):
             ignored_submodules.append(name)
             continue  # layer matches ignore list, continue
-        target = find_first_name_or_class_match(name, submodule, target_to_scheme)
-        if target is not None:
+        targets = find_name_or_class_matches(name, submodule, target_to_scheme)
+        if targets:
             # target matched - add layer and scheme to target list
-            submodule.quantization_scheme = target_to_scheme[target]
+            submodule.quantization_scheme = _scheme_from_targets(
+                target_to_scheme, targets, name
+            )
+            names_to_scheme[name] = submodule.quantization_scheme.weights
 
     if config.ignore is not None and ignored_submodules is not None:
         if set(config.ignore) - set(ignored_submodules):
@@ -132,7 +143,42 @@ def apply_quantization_config(model: Module, config: QuantizationConfig):
                 f"{set(config.ignore) - set(ignored_submodules)}"
             )
     # apply current quantization status across all targeted layers
+
     apply_quantization_status(model, config.quantization_status)
+    return names_to_scheme
+
+
+def process_quantization_config(config: QuantizationConfig) -> QuantizationConfig:
+    """
+    Preprocess the raw QuantizationConfig
+
+    :param config: the raw QuantizationConfig
+    :return: the processed QuantizationConfig
+    """
+    if config.kv_cache_scheme is not None:
+        config = process_kv_cache_config(config)
+
+    return config
+
+
+def process_kv_cache_config(
+    config: QuantizationConfig, targets: Union[List[str], str] = KV_CACHE_TARGETS
+) -> QuantizationConfig:
+    """
+    Reformulate the `config.kv_cache` as a `config_group`
+    and add it to the set of existing `config.groups`
+
+    :param config: the QuantizationConfig
+    :return: the QuantizationConfig with additional "kv_cache" group
+    """
+    kv_cache_dict = config.kv_cache_scheme.model_dump()
+    kv_cache_scheme = QuantizationScheme(
+        output_activations=QuantizationArgs(**kv_cache_dict),
+        targets=targets,
+    )
+    kv_cache_group = dict(kv_cache=kv_cache_scheme)
+    config.config_groups.update(kv_cache_group)
+    return config
 
 
 def apply_quantization_status(model: Module, status: QuantizationStatus):
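For context, `process_kv_cache_config` rewrites a top-level `kv_cache_scheme` into an ordinary "kv_cache" config group whose scheme targets `KV_CACHE_TARGETS` and quantizes output activations. A minimal sketch (values are hypothetical; an empty `config_groups` is used only to keep the example short):

    from compressed_tensors.quantization.lifecycle.apply import process_quantization_config
    from compressed_tensors.quantization.quant_args import QuantizationArgs
    from compressed_tensors.quantization.quant_config import QuantizationConfig

    config = QuantizationConfig(
        config_groups={},
        kv_cache_scheme=QuantizationArgs(num_bits=8),
    )
    config = process_quantization_config(config)

    kv_group = config.config_groups["kv_cache"]
    print(kv_group.targets)                      # ["re:.*k_proj", "re:.*v_proj"]
    print(kv_group.output_activations.num_bits)  # 8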
@@ -156,36 +202,45 @@ def apply_quantization_status(model: Module, status: QuantizationStatus):
         model.apply(compress_quantized_weights)
 
 
-def find_first_name_or_class_match(
+def find_name_or_class_matches(
     name: str, module: Module, targets: Iterable[str], check_contains: bool = False
-) -> Optional[str]:
-    # first element of targets that matches the given name
-    # if no name matches returns first target that matches the class name
-    # returns None otherwise
+) -> List[str]:
+    """
+    Returns all targets that match the given name or the class name.
+    Returns empty list otherwise.
+    The order of the output `matches` list matters.
+    The entries are sorted in the following order:
+        1. matches on exact strings
+        2. matches on regex patterns
+        3. matches on module names
+    """
+    targets = sorted(targets, key=lambda x: ("re:" in x, x))
     if isinstance(targets, Iterable):
-        return _find_first_match(name, targets) or _find_first_match(
+        matches = _find_matches(name, targets) + _find_matches(
             module.__class__.__name__, targets, check_contains
         )
+        matches = [match for match in matches if match is not None]
+        return matches
 
 
-def _find_first_match(
+def _find_matches(
     value: str, targets: Iterable[str], check_contains: bool = False
-) -> Optional[str]:
-    # returns first element of target that matches value either
+) -> List[str]:
+    # returns all the targets that match value either
     # exactly or as a regex after 're:'. if check_contains is set to True,
     # additionally checks if the target string is contained with value.
-
+    matches = []
     for target in targets:
         if target.startswith("re:"):
             pattern = target[3:]
             if re.match(pattern, value):
-                return target
+                matches.append(target)
         elif check_contains:
            if target.lower() in value.lower():
-                return target
+                matches.append(target)
        elif target == value:
-            return target
-    return None
+            matches.append(target)
+    return matches
 
 
 def _infer_status(model: Module) -> Optional[QuantizationStatus]:
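The replacement helper returns every matching target rather than only the first one, with exact-string matches ordered ahead of `re:` patterns and class-name matches last. An illustrative sketch (module and target names are hypothetical):

    import torch
    from compressed_tensors.quantization.lifecycle.apply import find_name_or_class_matches

    module = torch.nn.Linear(8, 8)
    targets = ["re:.*k_proj", "model.layers.0.self_attn.k_proj", "Linear"]

    matches = find_name_or_class_matches(
        "model.layers.0.self_attn.k_proj", module, targets
    )
    # exact name match, then the regex match, then the class-name match
    print(matches)  # ["model.layers.0.self_attn.k_proj", "re:.*k_proj", "Linear"]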
@@ -223,3 +278,68 @@ def _load_quant_args_from_state_dict(
             zp.data = zp_from_state.to(device).to(zp.dtype)
         else:  # fill with zeros matching scale shape
             zp.data = torch.zeros_like(scale, dtype=zp.dtype).to(device)
+
+
+def _scheme_from_targets(
+    target_to_scheme: OrderedDictType[str, QuantizationScheme],
+    targets: List[str],
+    name: str,
+) -> QuantizationScheme:
+    if len(targets) == 1:
+        # if `targets` iterable contains a single element
+        # use it as the key
+        return target_to_scheme[targets[0]]
+
+    # otherwise, we need to merge QuantizationSchemes corresponding
+    # to multiple targets. This is most likely because `name` module
+    # is being target both as an ordinary quantization target, as well
+    # as kv cache quantization target
+    schemes_to_merge = [target_to_scheme[target] for target in targets]
+    return _merge_schemes(schemes_to_merge, name)
+
+
+def _merge_schemes(
+    schemes_to_merge: List[QuantizationScheme], name: str
+) -> QuantizationScheme:
+
+    kv_cache_quantization_scheme = [
+        scheme for scheme in schemes_to_merge if is_kv_cache_quant_scheme(scheme)
+    ]
+    if not kv_cache_quantization_scheme:
+        # if the schemes_to_merge do not contain any
+        # kv cache QuantizationScheme
+        # return the first scheme (the prioritized one,
+        # since the order of schemes_to_merge matters)
+        return schemes_to_merge[0]
+    else:
+        # fetch the kv cache QuantizationScheme and the highest
+        # priority non-kv cache QuantizationScheme and merge them
+        kv_cache_quantization_scheme = kv_cache_quantization_scheme[0]
+        quantization_scheme = [
+            scheme
+            for scheme in schemes_to_merge
+            if not is_kv_cache_quant_scheme(scheme)
+        ][0]
+        schemes_to_merge = [kv_cache_quantization_scheme, quantization_scheme]
+        merged_scheme = {}
+        for scheme in schemes_to_merge:
+            scheme_dict = {
+                k: v for k, v in scheme.model_dump().items() if v is not None
+            }
+            # when merging multiple schemes, the final target will be
+            # the `name` argument - hence erase the original targets
+            del scheme_dict["targets"]
+            # make sure that schemes do not "clash" with each other
+            overlapping_keys = set(merged_scheme.keys()) & set(scheme_dict.keys())
+            if overlapping_keys:
+                raise ValueError(
+                    f"The module: {name} is being modified by two clashing "
+                    f"quantization schemes, that jointly try to override "
+                    f"properties: {overlapping_keys}. Fix the quantization config "
+                    "so that it is not ambiguous."
+                )
+            merged_scheme.update(scheme_dict)
+
+        merged_scheme.update(targets=[name])
+
+        return QuantizationScheme(**merged_scheme)
@@ -16,6 +16,7 @@ from enum import Enum
 from typing import Dict, List, Optional, Union
 
 from compressed_tensors.config import CompressionFormat
+from compressed_tensors.quantization.quant_args import QuantizationArgs
 from compressed_tensors.quantization.quant_scheme import (
     QuantizationScheme,
     preset_name_to_scheme,
@@ -25,6 +26,7 @@ from compressed_tensors.quantization.utils import (
     is_module_quantized,
     iter_named_leaf_modules,
     module_type,
+    parse_out_kv_cache_args,
 )
 from pydantic import BaseModel, Field
 from torch.nn import Module
@@ -117,7 +119,18 @@ class QuantizationConfig(BaseModel):
         other quantization configs
     :param format: specifies how the quantized model is stored on disk
     :quantization_status: specifies the current status of all quantized layers. It is
-        assumed all layers are in the same state.
+        assumed all layers are in the same state.
+    :param kv_cache_scheme: optional QuantizationArgs, that specify the
+        quantization of the kv cache. If None, kv cache is not quantized.
+        When applying kv cache quantization to transformer AutoModelForCausalLM,
+        the kv_cache_scheme gets converted into a QuantizationScheme that:
+            - targets the `q_proj` and `k_proj` modules of the model. The outputs
+              of those modules are the keys and values that might be cached
+            - quantizes the outputs of the aformentioned layers, so that
+              keys and values are compressed before storing them in the cache
+        There is an explicit assumption that the model contains modules with
+        `k_proj` and `v_proj` in their names. If this is not the case
+        and kv_cache_scheme != None, the quantization of kv cache will fail
     :global_compression_ratio: optional informational config to report the model
         compression ratio acheived by the quantization config
     :ignore: optional list of layers to ignore from config_groups. Layers in this list
@@ -126,6 +139,7 @@ class QuantizationConfig(BaseModel):
 
     config_groups: Dict[str, Union[QuantizationScheme, List[str]]]
     quant_method: str = DEFAULT_QUANTIZATION_METHOD
+    kv_cache_scheme: Optional[QuantizationArgs] = None
     format: str = DEFAULT_QUANTIZATION_FORMAT
     quantization_status: QuantizationStatus = QuantizationStatus.INITIALIZED
     global_compression_ratio: Optional[float] = None
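With the new field, a config can carry the kv cache scheme alongside its regular config groups, and it is serialized with the rest of the quantization config. A small sketch (field values are illustrative only):

    from compressed_tensors.quantization.quant_args import QuantizationArgs
    from compressed_tensors.quantization.quant_config import QuantizationConfig
    from compressed_tensors.quantization.quant_scheme import QuantizationScheme

    config = QuantizationConfig(
        config_groups={
            "group_0": QuantizationScheme(
                targets=["Linear"], weights=QuantizationArgs(num_bits=4)
            )
        },
        kv_cache_scheme=QuantizationArgs(num_bits=8),
    )
    print(config.model_dump()["kv_cache_scheme"]["num_bits"])  # 8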
@@ -154,7 +168,7 @@ class QuantizationConfig(BaseModel):
     ) -> Optional["QuantizationConfig"]:
         """
         Converts a model into its associated QuantizationConfig based on the
-        QuantizationScheme attached to each quanitzed module
+        QuantizationScheme attached to each quantized module
 
         :param model: model to calculate quantization scheme of
         :return: filled out QuantizationScheme for the input model
@@ -195,6 +209,13 @@ class QuantizationConfig(BaseModel):
             # else we leave it off the ignore list, doesn't fall under any of the
             # existing quantization schemes so it won't be quantized
 
+        kv_cache_args, quant_scheme_to_layers = parse_out_kv_cache_args(
+            quant_scheme_to_layers
+        )
+        kv_cache_scheme = (
+            kv_cache_args.model_dump() if kv_cache_args is not None else kv_cache_args
+        )
+
         config_groups = {}
         for idx, scheme in enumerate(quant_scheme_to_layers):
             group_name = "group_" + str(idx)
@@ -213,6 +234,7 @@ class QuantizationConfig(BaseModel):
         return QuantizationConfig(
             config_groups=config_groups,
             quantization_status=quantization_status,
+            kv_cache_scheme=kv_cache_scheme,
             global_compression_ratio=compression_ratio,
             format=format,
             ignore=consolidated_ignore,
@@ -17,7 +17,6 @@ from typing import List, Optional
 
 from compressed_tensors.quantization.quant_args import (
     QuantizationArgs,
-    QuantizationStrategy,
     QuantizationType,
 )
 from pydantic import BaseModel
@@ -13,10 +13,13 @@
 # limitations under the License.
 
 import logging
-from typing import Optional, Tuple
+import re
+from typing import List, Optional, Tuple
 
 import torch
 from compressed_tensors.quantization.observers.base import Observer
+from compressed_tensors.quantization.quant_args import QuantizationArgs
+from compressed_tensors.quantization.quant_scheme import QuantizationScheme
 from torch.nn import Module
 from tqdm import tqdm
 
@@ -30,8 +33,12 @@ __all__ = [
     "calculate_compression_ratio",
     "get_torch_bit_depth",
     "can_quantize",
+    "parse_out_kv_cache_args",
+    "KV_CACHE_TARGETS",
+    "is_kv_cache_quant_scheme",
 ]
 
+KV_CACHE_TARGETS = ["re:.*k_proj", "re:.*v_proj"]
 _LOGGER: logging.Logger = logging.getLogger(__name__)
 
 
@@ -182,3 +189,62 @@ def calculate_compression_ratio(model: Module) -> float:
         total_uncompressed += uncompressed_bits * num_weights
 
     return total_uncompressed / total_compressed
+
+
+def is_kv_cache_quant_scheme(scheme: QuantizationScheme) -> bool:
+    """
+    Check whether the QuantizationScheme targets the kv cache.
+    It does if all the following criteria are met:
+    - the scheme targets either exactly match the KV_CACHE_TARGETS
+      or the match KV_CACHE_TARGETS regex pattern
+    - the scheme quantizes output_activations (we want to quantize the
+      outputs from the KV_CACHE_TARGETS, as their correspond to the
+      keys and values that are to be saved in the cache)
+
+    :param scheme: The QuantizationScheme to investigate
+    :return: boolean flag
+    """
+    if len(scheme.targets) == 1:
+        # match on the KV_CACHE_TARGETS regex pattern
+        # if there is only one target
+        is_match_targets = any(
+            [re.match(pattern[3:], scheme.targets[0]) for pattern in KV_CACHE_TARGETS]
+        )
+    else:
+        # match on the exact KV_CACHE_TARGETS
+        # if there are multiple targets
+        is_match_targets = set(KV_CACHE_TARGETS) == set(scheme.targets)
+
+    is_match_output_activations = scheme.output_activations is not None
+    return is_match_targets and is_match_output_activations
+
+
+def parse_out_kv_cache_args(
+    quant_scheme_to_layers: List[QuantizationScheme],
+) -> Tuple[Optional[QuantizationArgs], List[QuantizationScheme]]:
+    """
+    If possible, parse out the kv cache specific QuantizationArgs
+    from the list of the QuantizationSchemes. If no kv cache
+    specific QuantizationArgs available, this function acts
+    as an identity function
+
+    :param quant_scheme_to_layers: list of QuantizationSchemes
+    :return: kv_cache_args (optional) and the (remaining or original)
+        list of the QuantizationSchemes
+    """
+    kv_cache_quant_scheme_to_layers = [
+        scheme for scheme in quant_scheme_to_layers if is_kv_cache_quant_scheme(scheme)
+    ]
+    quant_scheme_to_layers = [
+        scheme
+        for scheme in quant_scheme_to_layers
+        if not is_kv_cache_quant_scheme(scheme)
+    ]
+
+    if kv_cache_quant_scheme_to_layers:
+        kv_cache_quant_scheme_to_layers = kv_cache_quant_scheme_to_layers[0]
+        kv_cache_args = kv_cache_quant_scheme_to_layers.output_activations
+    else:
+        kv_cache_args = None
+
+    return kv_cache_args, quant_scheme_to_layers
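Taken together, `is_kv_cache_quant_scheme` recognizes a scheme that targets `KV_CACHE_TARGETS` and quantizes output activations, and `parse_out_kv_cache_args` splits such a scheme out of a list of schemes. A short sketch (the schemes are hypothetical):

    from compressed_tensors.quantization.quant_args import QuantizationArgs
    from compressed_tensors.quantization.quant_scheme import QuantizationScheme
    from compressed_tensors.quantization.utils import (
        is_kv_cache_quant_scheme,
        parse_out_kv_cache_args,
    )

    kv_scheme = QuantizationScheme(
        targets=["re:.*k_proj", "re:.*v_proj"],
        output_activations=QuantizationArgs(num_bits=8),
    )
    weight_scheme = QuantizationScheme(
        targets=["Linear"], weights=QuantizationArgs(num_bits=4)
    )

    print(is_kv_cache_quant_scheme(kv_scheme))      # True
    print(is_kv_cache_quant_scheme(weight_scheme))  # False

    kv_cache_args, remaining = parse_out_kv_cache_args([kv_scheme, weight_scheme])
    print(kv_cache_args.num_bits)  # 8
    print(len(remaining))          # 1, only the weight scheme is left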
@@ -12,7 +12,6 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
-
 from typing import Optional
 
 from transformers import AutoConfig
@@ -1,6 +1,6 @@
 Metadata-Version: 2.1
 Name: compressed-tensors-nightly
-Version: 0.4.0.20240623
+Version: 0.4.0.20240627
 Summary: Library for utilization of compressed safetensors of neural network models
 Home-page: https://github.com/neuralmagic/compressed-tensors
 Author: Neuralmagic, Inc.
@@ -1,15 +1,15 @@
 compressed_tensors/__init__.py,sha256=SV1csvHUVCd8kHXz6UDZim1HZ_fAVG3vfk-j_4Bb6hY,789
-compressed_tensors/base.py,sha256=OA2TOLP1gP3LSH7gp508eqr2ZtDQ-pqRHElCp-aB0vs,755
+compressed_tensors/base.py,sha256=Mq4mfVQcJhNpha-BXzpOfpmFIdl01o09BJE7D2oQ_00,796
 compressed_tensors/version.py,sha256=cJJf0y0NnXErTtQtVQjOvrq9hMIkhXIfBwuu4Tuxl24,1586
 compressed_tensors/compressors/__init__.py,sha256=wmX4VnkUTS63xBwK5-6w8FP78bNZpcdcqvf2KOEC5E4,1133
-compressed_tensors/compressors/base.py,sha256=LWEgbpgTxzmoqQ7Xhq2OQszUgWoDtFuGCiV1Y8nlBGw,2134
-compressed_tensors/compressors/dense.py,sha256=G_XHbvuENyupIKlXSITOQgvPkNkcMEOLcLWQr70V9EE,1257
+compressed_tensors/compressors/base.py,sha256=-rqT2h9G2iwDkwrVj0d0jxxn9h0dccJA1mqOzVEkwGM,2144
+compressed_tensors/compressors/dense.py,sha256=xcWECjcRY4INN6jC7vHx5wvUX3NmnKlxA9SVE1A6m2Q,1267
 compressed_tensors/compressors/helpers.py,sha256=k9avlkmeYj6vkOAvl-MgcixtP7ib24SCfhzZ-RusXfw,5403
-compressed_tensors/compressors/marlin_24.py,sha256=X_BjtFB3Mn0hqiLz56UM3jGX2eNmGLnvEIPfbg7di6U,9444
-compressed_tensors/compressors/model_compressor.py,sha256=83AWAhlrR3QTNelfMGCh_10G-VfMIRXRTvV0ZZinCU8,13338
-compressed_tensors/compressors/naive_quantized.py,sha256=N3y5LxsCaTUJHT30sqEhnviZsyoz1v2eUaayE7-f8Xs,5562
-compressed_tensors/compressors/pack_quantized.py,sha256=ODb03_WaBQ1l99Gmp49olAUZ2TB_67z9qNZbc56X7NU,8275
-compressed_tensors/compressors/sparse_bitmask.py,sha256=H9oZSTYI1oRCzAMbd4zThUnZd1h2rfs8DmA3tPcvuNE,8637
+compressed_tensors/compressors/marlin_24.py,sha256=PULMP1fp1sNWz-xOxvM0JXhOrUbq6sPwOTscYSifgDw,9450
+compressed_tensors/compressors/model_compressor.py,sha256=t4dH7Yh637JV53VPyys-gkoMPJHGf_tlWWufLRyIdUM,13418
+compressed_tensors/compressors/naive_quantized.py,sha256=6_1wuTF96-lw-UzzrsiEX_ipciKiQQJoZ8uotVwtbyQ,5569
+compressed_tensors/compressors/pack_quantized.py,sha256=ZRqqBVPB6B-nZQOSdu7WhKrKWIm2-ZVrUQHATxO2Boc,10297
+compressed_tensors/compressors/sparse_bitmask.py,sha256=kiDwBlFV0sJGLcIdDYxIiuF64ccgwDfqq1hWRQThYDc,8647
 compressed_tensors/compressors/utils/__init__.py,sha256=-mbGDZh1hd9T6u62Ht_iBIK255UmMg0f5bLkSs1f9Cc,731
 compressed_tensors/compressors/utils/helpers.py,sha256=4fq7KclSIK__jemCG9pwYlgWLrQjsaAMxhIrhjdw0BQ,1506
 compressed_tensors/compressors/utils/permutations_24.py,sha256=kx6fsfDHebx94zsSzhXGyCyuC9sVyah6BUUir_StT28,2530
@@ -20,10 +20,10 @@ compressed_tensors/config/dense.py,sha256=NgSxnFCnckU9-iunxEaqiFwqgdO7YYxlWKR74j
 compressed_tensors/config/sparse_bitmask.py,sha256=pZUboRNZTu6NajGOQEFExoPknak5ynVAUeiiYpS1Gt8,1308
 compressed_tensors/quantization/__init__.py,sha256=83J5bPB7PavN2TfCoW7_vEDhfYpm4TDrqYO9vdSQ5bk,760
 compressed_tensors/quantization/quant_args.py,sha256=Vc_tWSTcbZZsMJlACpLq4JEPvGx87izc8VEx-mcXjoM,5621
-compressed_tensors/quantization/quant_config.py,sha256=hL42sXp1wAZxyrkHarw7tAMRcwSVEr0MT3wmrmL3NhE,8285
-compressed_tensors/quantization/quant_scheme.py,sha256=Yhaj3QJn4lifGMoQ8mlXXOdLDZA6iGMthb_0hlAzvVk,3811
+compressed_tensors/quantization/quant_config.py,sha256=PU3BchHm09ks6_yAderrHoIZI07zBlU9ejC87v3A-54,9568
+compressed_tensors/quantization/quant_scheme.py,sha256=TU9W3bOWCY2l5Vrha0ufRtW1ac4gew1uwW8N3JGbZvg,3785
 compressed_tensors/quantization/lifecycle/__init__.py,sha256=ggRGWRqhCxCaTTDWRcgTVX3axnS2xV6rc5YvdzK7fSg,798
-compressed_tensors/quantization/lifecycle/apply.py,sha256=eQfuIGcX6KBKeMta1svviXXRpKO3og2CRrxhKlGcE_k,8756
+compressed_tensors/quantization/lifecycle/apply.py,sha256=fyv5ujZC0__oG1ESOTmMyMsKK7DGAxG7uQI7_sxT7Mw,13308
 compressed_tensors/quantization/lifecycle/calibration.py,sha256=mLns4jlaWmBwOW8Jtlm5bMX-JET1AiZYUBO7qa-XuxI,1776
 compressed_tensors/quantization/lifecycle/compressed.py,sha256=VreB10xPwgSLQQlTu20UCrFpRS--cA7-lx5s7nrPPrg,2247
 compressed_tensors/quantization/lifecycle/forward.py,sha256=tcjL_qyE3ODourNprt2bndF7_ALlUEGY2_Yag4exLoE,11908
@@ -35,14 +35,14 @@ compressed_tensors/quantization/observers/helpers.py,sha256=DSNGNJpZyT2Lyu0c82dH
 compressed_tensors/quantization/observers/memoryless.py,sha256=jH_c6K3gxf4W3VNXQ7tbnP-J_86QTrEfjBn6Kh1C-H8,2165
 compressed_tensors/quantization/observers/min_max.py,sha256=UK7zCMzxv9GGn6BflBxdajV20RiWaCY2RHcvZodCP1w,3669
 compressed_tensors/quantization/utils/__init__.py,sha256=VdtEmP0bvuND_IGQnyqUPc5lnFp-1_yD7StKSX4x80w,656
-compressed_tensors/quantization/utils/helpers.py,sha256=NzAH18Cn_-mTAR87y6IlcQU5gC393XSjgNKC9CRkr78,6017
+compressed_tensors/quantization/utils/helpers.py,sha256=YjXABJQUnelof-z7qcwck6fnrFLh4uMSrOmPiqNp_RY,8591
 compressed_tensors/registry/__init__.py,sha256=FwLSNYqfIrb5JD_6OK_MT4_svvKTN_nEhpgQlQvGbjI,658
 compressed_tensors/registry/registry.py,sha256=fxjOjh2wklCvJhQxwofdy-zV8q7MkQ85SLG77nml2iA,11890
 compressed_tensors/utils/__init__.py,sha256=5DrYjoZbaEvSkJcC-GRSbM_RBHVF4tG9gMd3zsJnjLw,665
-compressed_tensors/utils/helpers.py,sha256=5ull5yFT31M2zVxKeFvpvvlvX5f1Sk1LGuj_wrfZWCY,2267
+compressed_tensors/utils/helpers.py,sha256=dt4uxSIeqvqDmeJBJ6UUVHEOnMI7EtMSzEDv6PRUu14,2266
 compressed_tensors/utils/safetensors_load.py,sha256=0MheXwx1jeY12PeISppiSIZHs6rmN2YddwPpFb9V67I,8527
-compressed_tensors_nightly-0.4.0.20240623.dist-info/LICENSE,sha256=xx0jnfkXJvxRnG63LTGOxlggYnIysveWIZ6H3PNdCrQ,11357
-compressed_tensors_nightly-0.4.0.20240623.dist-info/METADATA,sha256=TKdmWA3qynRUK6FyOoPODvDpc8DB0sKjjiX2hN3uU7A,5668
-compressed_tensors_nightly-0.4.0.20240623.dist-info/WHEEL,sha256=GJ7t_kWBFywbagK5eo9IoUwLW6oyOeTKmQ-9iHFVNxQ,92
-compressed_tensors_nightly-0.4.0.20240623.dist-info/top_level.txt,sha256=w2i-GyPs2s1UwVxvutSvN_lM22SXC2hQFBmoMcPnV7Y,19
-compressed_tensors_nightly-0.4.0.20240623.dist-info/RECORD,,
+compressed_tensors_nightly-0.4.0.20240627.dist-info/LICENSE,sha256=xx0jnfkXJvxRnG63LTGOxlggYnIysveWIZ6H3PNdCrQ,11357
+compressed_tensors_nightly-0.4.0.20240627.dist-info/METADATA,sha256=pRkLnBBttymxaUP8mHpKe_NQ4Mfa6gV3TMoBj6o3NCU,5668
+compressed_tensors_nightly-0.4.0.20240627.dist-info/WHEEL,sha256=GJ7t_kWBFywbagK5eo9IoUwLW6oyOeTKmQ-9iHFVNxQ,92
+compressed_tensors_nightly-0.4.0.20240627.dist-info/top_level.txt,sha256=w2i-GyPs2s1UwVxvutSvN_lM22SXC2hQFBmoMcPnV7Y,19
+compressed_tensors_nightly-0.4.0.20240627.dist-info/RECORD,,