compressed-tensors 0.4.0__py3-none-any.whl → 0.6.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (44)
  1. compressed_tensors/base.py +1 -0
  2. compressed_tensors/compressors/__init__.py +5 -1
  3. compressed_tensors/compressors/base.py +200 -8
  4. compressed_tensors/compressors/dense.py +1 -1
  5. compressed_tensors/compressors/marlin_24.py +11 -10
  6. compressed_tensors/compressors/model_compressor.py +101 -13
  7. compressed_tensors/compressors/naive_quantized.py +140 -0
  8. compressed_tensors/compressors/pack_quantized.py +128 -132
  9. compressed_tensors/compressors/sparse_bitmask.py +1 -1
  10. compressed_tensors/config/base.py +8 -1
  11. compressed_tensors/{compressors/utils → linear}/__init__.py +0 -6
  12. compressed_tensors/linear/compressed_linear.py +87 -0
  13. compressed_tensors/quantization/lifecycle/__init__.py +1 -0
  14. compressed_tensors/quantization/lifecycle/apply.py +204 -44
  15. compressed_tensors/quantization/lifecycle/calibration.py +22 -2
  16. compressed_tensors/quantization/lifecycle/compressed.py +3 -1
  17. compressed_tensors/quantization/lifecycle/forward.py +139 -61
  18. compressed_tensors/quantization/lifecycle/helpers.py +80 -0
  19. compressed_tensors/quantization/lifecycle/initialize.py +77 -13
  20. compressed_tensors/quantization/observers/__init__.py +1 -0
  21. compressed_tensors/quantization/observers/base.py +93 -14
  22. compressed_tensors/quantization/observers/helpers.py +64 -11
  23. compressed_tensors/quantization/observers/min_max.py +8 -0
  24. compressed_tensors/quantization/observers/mse.py +162 -0
  25. compressed_tensors/quantization/quant_args.py +139 -23
  26. compressed_tensors/quantization/quant_config.py +35 -2
  27. compressed_tensors/quantization/quant_scheme.py +112 -13
  28. compressed_tensors/quantization/utils/helpers.py +68 -2
  29. compressed_tensors/utils/__init__.py +5 -0
  30. compressed_tensors/utils/helpers.py +44 -2
  31. compressed_tensors/utils/offload.py +116 -0
  32. compressed_tensors/utils/permute.py +70 -0
  33. compressed_tensors/utils/safetensors_load.py +2 -0
  34. compressed_tensors/{compressors/utils → utils}/semi_structured_conversions.py +1 -0
  35. compressed_tensors/version.py +1 -1
  36. {compressed_tensors-0.4.0.dist-info → compressed_tensors-0.6.0.dist-info}/METADATA +35 -22
  37. compressed_tensors-0.6.0.dist-info/RECORD +52 -0
  38. {compressed_tensors-0.4.0.dist-info → compressed_tensors-0.6.0.dist-info}/WHEEL +1 -1
  39. compressed_tensors/compressors/int_quantized.py +0 -126
  40. compressed_tensors/compressors/utils/helpers.py +0 -43
  41. compressed_tensors-0.4.0.dist-info/RECORD +0 -48
  42. /compressed_tensors/{compressors/utils → utils}/permutations_24.py +0 -0
  43. {compressed_tensors-0.4.0.dist-info → compressed_tensors-0.6.0.dist-info}/LICENSE +0 -0
  44. {compressed_tensors-0.4.0.dist-info → compressed_tensors-0.6.0.dist-info}/top_level.txt +0 -0
compressed_tensors/quantization/lifecycle/apply.py
@@ -14,10 +14,14 @@

 import logging
 import re
-from collections import OrderedDict
-from typing import Dict, Iterable, Optional
+from collections import OrderedDict, defaultdict
+from copy import deepcopy
+from typing import Dict, Iterable, List, Optional
+from typing import OrderedDict as OrderedDictType
+from typing import Union

 import torch
+from compressed_tensors.config import CompressionFormat
 from compressed_tensors.quantization.lifecycle.calibration import (
     set_module_for_calibration,
 )
@@ -28,15 +32,20 @@ from compressed_tensors.quantization.lifecycle.frozen import freeze_module_quant
 from compressed_tensors.quantization.lifecycle.initialize import (
     initialize_module_for_quantization,
 )
+from compressed_tensors.quantization.quant_args import QuantizationArgs
 from compressed_tensors.quantization.quant_config import (
     QuantizationConfig,
     QuantizationStatus,
 )
+from compressed_tensors.quantization.quant_scheme import QuantizationScheme
 from compressed_tensors.quantization.utils import (
+    KV_CACHE_TARGETS,
     infer_quantization_status,
+    is_kv_cache_quant_scheme,
     iter_named_leaf_modules,
 )
-from compressed_tensors.utils.helpers import fix_fsdp_module_name
+from compressed_tensors.utils.helpers import fix_fsdp_module_name, replace_module
+from compressed_tensors.utils.offload import update_parameter_data
 from compressed_tensors.utils.safetensors_load import get_safetensors_folder
 from torch.nn import Module

@@ -45,7 +54,7 @@ __all__ = [
     "load_pretrained_quantization",
     "apply_quantization_config",
     "apply_quantization_status",
-    "find_first_name_or_class_match",
+    "find_name_or_class_matches",
 ]

 from compressed_tensors.quantization.utils.helpers import is_module_quantized
@@ -96,33 +105,64 @@ def load_pretrained_quantization(model: Module, model_name_or_path: str):
         )


-def apply_quantization_config(model: Module, config: QuantizationConfig):
+def apply_quantization_config(
+    model: Module, config: QuantizationConfig, run_compressed: bool = False
+) -> Dict:
     """
     Initializes the model for quantization in-place based on the given config

     :param model: model to apply quantization config to
     :param config: quantization config
+    :param run_compressed: Whether the model will be run in compressed mode or
+        decompressed fully on load
     """
+    # remove reference to the original `config`
+    # argument. This function can mutate it, and we'd
+    # like to keep the original `config` as it is.
+    config = deepcopy(config)
     # build mapping of targets to schemes for easier matching
     # use ordered dict to preserve target ordering in config
     target_to_scheme = OrderedDict()
+    config = process_quantization_config(config)
+    names_to_scheme = OrderedDict()
     for scheme in config.config_groups.values():
         for target in scheme.targets:
             target_to_scheme[target] = scheme

+    if run_compressed:
+        from compressed_tensors.linear.compressed_linear import CompressedLinear
+
     # list of submodules to ignore
-    ignored_submodules = []
+    ignored_submodules = defaultdict(list)
     # mark appropriate layers for quantization by setting their quantization schemes
     for name, submodule in iter_named_leaf_modules(model):
         # potentially fix module name to remove FSDP wrapper prefix
         name = fix_fsdp_module_name(name)
-        if find_first_name_or_class_match(name, submodule, config.ignore):
-            ignored_submodules.append(name)
+        if matches := find_name_or_class_matches(name, submodule, config.ignore):
+            for match in matches:
+                ignored_submodules[match].append(name)
             continue  # layer matches ignore list, continue
-        target = find_first_name_or_class_match(name, submodule, target_to_scheme)
-        if target is not None:
+        targets = find_name_or_class_matches(name, submodule, target_to_scheme)
+        if targets:
+            scheme = _scheme_from_targets(target_to_scheme, targets, name)
+            if run_compressed:
+                format = config.format
+                if format != CompressionFormat.dense.value:
+                    if isinstance(submodule, torch.nn.Linear):
+                        # TODO: expand to more module types
+                        compressed_linear = CompressedLinear.from_linear(
+                            submodule,
+                            quantization_scheme=scheme,
+                            quantization_format=format,
+                        )
+                        replace_module(model, name, compressed_linear)
+
             # target matched - add layer and scheme to target list
-            submodule.quantization_scheme = target_to_scheme[target]
+            submodule.quantization_scheme = _scheme_from_targets(
+                target_to_scheme, targets, name
+            )
+
+            names_to_scheme[name] = submodule.quantization_scheme.weights

     if config.ignore is not None and ignored_submodules is not None:
         if set(config.ignore) - set(ignored_submodules):
@@ -131,8 +171,43 @@ def apply_quantization_config(model: Module, config: QuantizationConfig):
                 "not found in the model: "
                 f"{set(config.ignore) - set(ignored_submodules)}"
             )
+
     # apply current quantization status across all targeted layers
     apply_quantization_status(model, config.quantization_status)
+    return names_to_scheme
+
+
+def process_quantization_config(config: QuantizationConfig) -> QuantizationConfig:
+    """
+    Preprocess the raw QuantizationConfig
+
+    :param config: the raw QuantizationConfig
+    :return: the processed QuantizationConfig
+    """
+    if config.kv_cache_scheme is not None:
+        config = process_kv_cache_config(config)
+
+    return config
+
+
+def process_kv_cache_config(
+    config: QuantizationConfig, targets: Union[List[str], str] = KV_CACHE_TARGETS
+) -> QuantizationConfig:
+    """
+    Reformulate the `config.kv_cache` as a `config_group`
+    and add it to the set of existing `config.groups`
+
+    :param config: the QuantizationConfig
+    :return: the QuantizationConfig with additional "kv_cache" group
+    """
+    kv_cache_dict = config.kv_cache_scheme.model_dump()
+    kv_cache_scheme = QuantizationScheme(
+        output_activations=QuantizationArgs(**kv_cache_dict),
+        targets=targets,
+    )
+    kv_cache_group = dict(kv_cache=kv_cache_scheme)
+    config.config_groups.update(kv_cache_group)
+    return config


 def apply_quantization_status(model: Module, status: QuantizationStatus):
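
A minimal usage sketch of the reworked entry point (field names follow the hunks above; the toy model, argument values, and remaining QuantizationConfig defaults are assumptions, not taken from the diff). apply_quantization_config now deep-copies the caller's config, folds an optional kv_cache_scheme into config_groups via process_quantization_config, and returns a mapping of module names to their weight quantization args:

import torch
from compressed_tensors.quantization import (
    QuantizationArgs,
    QuantizationConfig,
    QuantizationScheme,
    apply_quantization_config,
)

# toy model with a single Linear leaf module named "0"
model = torch.nn.Sequential(torch.nn.Linear(16, 16))

config = QuantizationConfig(
    config_groups={
        "group_0": QuantizationScheme(
            targets=["Linear"],
            weights=QuantizationArgs(num_bits=8, symmetric=True),
        )
    },
    # folded into config_groups["kv_cache"] by process_quantization_config
    kv_cache_scheme=QuantizationArgs(num_bits=8, symmetric=True),
)

names_to_scheme = apply_quantization_config(model, config, run_compressed=False)
assert "kv_cache" not in config.config_groups  # caller's config is left untouched
print(names_to_scheme)  # e.g. {"0": <weight QuantizationArgs for the Linear layer>}
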
@@ -145,10 +220,22 @@ def apply_quantization_status(model: Module, status: QuantizationStatus):
     current_status = infer_quantization_status(model)

     if status >= QuantizationStatus.INITIALIZED > current_status:
-        model.apply(initialize_module_for_quantization)
+        force_zero_point_init = status != QuantizationStatus.COMPRESSED
+        model.apply(
+            lambda module: initialize_module_for_quantization(
+                module, force_zero_point=force_zero_point_init
+            )
+        )

     if current_status < status >= QuantizationStatus.CALIBRATION > current_status:
-        model.apply(set_module_for_calibration)
+        # only quantize weights up front when our end goal state is calibration,
+        # weight quantization parameters are already loaded for frozen/compressed
+        quantize_weights_upfront = status == QuantizationStatus.CALIBRATION
+        model.apply(
+            lambda module: set_module_for_calibration(
+                module, quantize_weights_upfront=quantize_weights_upfront
+            )
+        )
     if current_status < status >= QuantizationStatus.FROZEN > current_status:
         model.apply(freeze_module_quantization)

@@ -156,36 +243,45 @@ def apply_quantization_status(model: Module, status: QuantizationStatus):
         model.apply(compress_quantized_weights)


-def find_first_name_or_class_match(
+def find_name_or_class_matches(
     name: str, module: Module, targets: Iterable[str], check_contains: bool = False
-) -> Optional[str]:
-    # first element of targets that matches the given name
-    # if no name matches returns first target that matches the class name
-    # returns None otherwise
+) -> List[str]:
+    """
+    Returns all targets that match the given name or the class name.
+    Returns empty list otherwise.
+    The order of the output `matches` list matters.
+    The entries are sorted in the following order:
+        1. matches on exact strings
+        2. matches on regex patterns
+        3. matches on module names
+    """
+    targets = sorted(targets, key=lambda x: ("re:" in x, x))
     if isinstance(targets, Iterable):
-        return _find_first_match(name, targets) or _find_first_match(
+        matches = _find_matches(name, targets) + _find_matches(
             module.__class__.__name__, targets, check_contains
         )
+        matches = [match for match in matches if match is not None]
+        return matches


-def _find_first_match(
+def _find_matches(
     value: str, targets: Iterable[str], check_contains: bool = False
-) -> Optional[str]:
-    # returns first element of target that matches value either
+) -> List[str]:
+    # returns all the targets that match value either
     # exactly or as a regex after 're:'. if check_contains is set to True,
     # additionally checks if the target string is contained with value.
-
+    matches = []
     for target in targets:
         if target.startswith("re:"):
             pattern = target[3:]
             if re.match(pattern, value):
-                return target
+                matches.append(target)
         elif check_contains:
             if target.lower() in value.lower():
-                return target
+                matches.append(target)
         elif target == value:
-            return target
-    return None
+            matches.append(target)
+    return matches


 def _infer_status(model: Module) -> Optional[QuantizationStatus]:
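
A small sketch of the new matching semantics (the import path follows the file listing above; the target strings are illustrative): exact-name targets are returned before "re:" patterns, with class-name matches appended last, exactly as the docstring states.

import torch
from compressed_tensors.quantization.lifecycle.apply import find_name_or_class_matches

layer = torch.nn.Linear(8, 8)
targets = ["re:.*proj", "model.layers.0.q_proj", "Linear"]

matches = find_name_or_class_matches("model.layers.0.q_proj", layer, targets)
# expected ordering per the docstring: exact string, then regex, then class name
# -> ["model.layers.0.q_proj", "re:.*proj", "Linear"]
print(matches)
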
@@ -210,20 +306,84 @@ def _load_quant_args_from_state_dict(
     """
     scale_name = f"{base_name}_scale"
     zp_name = f"{base_name}_zero_point"
-    device = next(module.parameters()).device
-
-    scale = getattr(module, scale_name, None)
-    zp = getattr(module, zp_name, None)
-    if scale is not None:
-        state_dict_scale = state_dict.get(f"{module_name}.{scale_name}")
-        if state_dict_scale is not None:
-            scale.data = state_dict_scale.to(device).to(scale.dtype)
-        else:
-            scale.data = scale.data.to(device)
-
-    if zp is not None:
-        zp_from_state = state_dict.get(f"{module_name}.{zp_name}", None)
-        if zp_from_state is not None:  # load the non-zero zero points
-            zp.data = state_dict[f"{module_name}.{zp_name}"].to(device)
-        else:  # fill with zeros matching scale shape
-            zp.data = torch.zeros_like(scale, dtype=torch.int8).to(device)
+    g_idx_name = f"{base_name}_g_idx"
+
+    state_dict_scale = state_dict.get(f"{module_name}.{scale_name}", None)
+    state_dict_zp = state_dict.get(f"{module_name}.{zp_name}", None)
+    state_dict_g_idx = state_dict.get(f"{module_name}.{g_idx_name}", None)
+
+    if state_dict_scale is not None:
+        # module is quantized
+        update_parameter_data(module, state_dict_scale, scale_name)
+        if state_dict_zp is None:
+            # fill in zero point for symmetric quantization
+            state_dict_zp = torch.zeros_like(state_dict_scale, device="cpu")
+        update_parameter_data(module, state_dict_zp, zp_name)
+
+    if state_dict_g_idx is not None:
+        update_parameter_data(module, state_dict_g_idx, g_idx_name)
+
+
+def _scheme_from_targets(
+    target_to_scheme: OrderedDictType[str, QuantizationScheme],
+    targets: List[str],
+    name: str,
+) -> QuantizationScheme:
+    if len(targets) == 1:
+        # if `targets` iterable contains a single element
+        # use it as the key
+        return target_to_scheme[targets[0]]
+
+    # otherwise, we need to merge QuantizationSchemes corresponding
+    # to multiple targets. This is most likely because `name` module
+    # is being target both as an ordinary quantization target, as well
+    # as kv cache quantization target
+    schemes_to_merge = [target_to_scheme[target] for target in targets]
+    return _merge_schemes(schemes_to_merge, name)
+
+
+def _merge_schemes(
+    schemes_to_merge: List[QuantizationScheme], name: str
+) -> QuantizationScheme:
+
+    kv_cache_quantization_scheme = [
+        scheme for scheme in schemes_to_merge if is_kv_cache_quant_scheme(scheme)
+    ]
+    if not kv_cache_quantization_scheme:
+        # if the schemes_to_merge do not contain any
+        # kv cache QuantizationScheme
+        # return the first scheme (the prioritized one,
+        # since the order of schemes_to_merge matters)
+        return schemes_to_merge[0]
+    else:
+        # fetch the kv cache QuantizationScheme and the highest
+        # priority non-kv cache QuantizationScheme and merge them
+        kv_cache_quantization_scheme = kv_cache_quantization_scheme[0]
+        quantization_scheme = [
+            scheme
+            for scheme in schemes_to_merge
+            if not is_kv_cache_quant_scheme(scheme)
+        ][0]
+        schemes_to_merge = [kv_cache_quantization_scheme, quantization_scheme]
+    merged_scheme = {}
+    for scheme in schemes_to_merge:
+        scheme_dict = {
+            k: v for k, v in scheme.model_dump().items() if v is not None
+        }
+        # when merging multiple schemes, the final target will be
+        # the `name` argument - hence erase the original targets
+        del scheme_dict["targets"]
+        # make sure that schemes do not "clash" with each other
+        overlapping_keys = set(merged_scheme.keys()) & set(scheme_dict.keys())
+        if overlapping_keys:
+            raise ValueError(
+                f"The module: {name} is being modified by two clashing "
+                f"quantization schemes, that jointly try to override "
+                f"properties: {overlapping_keys}. Fix the quantization config "
+                "so that it is not ambiguous."
+            )
+        merged_scheme.update(scheme_dict)
+
+    merged_scheme.update(targets=[name])
+
+    return QuantizationScheme(**merged_scheme)
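
The merge path above only triggers when a module is targeted both by an ordinary scheme and by the kv-cache scheme. Below is a sketch of what it produces, calling the private helper directly purely for illustration (the kv-cache scheme mirrors the "kv_cache" group built by process_kv_cache_config; the module name and argument values are assumptions):

from compressed_tensors.quantization import QuantizationArgs, QuantizationScheme
from compressed_tensors.quantization.lifecycle.apply import _merge_schemes
from compressed_tensors.quantization.utils import KV_CACHE_TARGETS

attn_scheme = QuantizationScheme(
    targets=["Linear"],
    weights=QuantizationArgs(num_bits=8, symmetric=True),
)
kv_scheme = QuantizationScheme(  # same shape as the "kv_cache" group above
    targets=KV_CACHE_TARGETS,
    output_activations=QuantizationArgs(num_bits=8, symmetric=True),
)

merged = _merge_schemes([kv_scheme, attn_scheme], name="model.layers.0.self_attn.k_proj")
# merged.targets == ["model.layers.0.self_attn.k_proj"]; the result carries both
# the weight args and the kv-cache output_activations args, and a ValueError is
# raised instead if the two schemes set overlapping fields
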
compressed_tensors/quantization/lifecycle/calibration.py
@@ -16,6 +16,7 @@
 import logging

 from compressed_tensors.quantization.quant_config import QuantizationStatus
+from compressed_tensors.utils import is_module_offloaded, update_parameter_data
 from torch.nn import Module


@@ -27,7 +28,7 @@ __all__ = [
 _LOGGER = logging.getLogger(__name__)


-def set_module_for_calibration(module: Module):
+def set_module_for_calibration(module: Module, quantize_weights_upfront: bool = True):
     """
     marks a layer as ready for calibration which activates observers
     to update scales and zero points on each forward pass
@@ -35,17 +36,36 @@ def set_module_for_calibration(module: Module):
     apply to full model with `model.apply(set_module_for_calibration)`

     :param module: module to set for calibration
+    :param quantize_weights_upfront: whether to automatically
+        run weight quantization at the start of calibration
     """
     if not getattr(module, "quantization_scheme", None):
         # no quantization scheme nothing to do
         return
     status = getattr(module, "quantization_status", None)
     if not status or status != QuantizationStatus.INITIALIZED:
-        raise _LOGGER.warning(
+        _LOGGER.warning(
             f"Attempting set module with status {status} to calibration mode. "
             f"but status is not {QuantizationStatus.INITIALIZED} - you may "
             "be calibrating an uninitialized module which may fail or attempting "
             "to re-calibrate a frozen module"
         )

+    if quantize_weights_upfront and module.quantization_scheme.weights is not None:
+        # set weight scale and zero_point up front, calibration data doesn't affect it
+        observer = module.weight_observer
+        g_idx = getattr(module, "weight_g_idx", None)
+
+        offloaded = False
+        if is_module_offloaded(module):
+            module._hf_hook.pre_forward(module)
+            offloaded = True
+
+        scale, zero_point = observer(module.weight, g_idx=g_idx)
+        update_parameter_data(module, scale, "weight_scale")
+        update_parameter_data(module, zero_point, "weight_zero_point")
+
+        if offloaded:
+            module._hf_hook.post_forward(module, None)
+
     module.quantization_status = QuantizationStatus.CALIBRATION
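
A sketch of the upfront weight observation this hunk adds, driven on a single layer (attribute names follow the hunk; the manual scheme attachment, the layer shape, and the observer defaults are assumptions standing in for what apply_quantization_config normally does):

import torch
from compressed_tensors.quantization import QuantizationArgs, QuantizationScheme
from compressed_tensors.quantization.lifecycle import (
    initialize_module_for_quantization,
    set_module_for_calibration,
)

layer = torch.nn.Linear(64, 64)
# attach a scheme by hand, as apply_quantization_config would
layer.quantization_scheme = QuantizationScheme(
    targets=["Linear"],
    weights=QuantizationArgs(num_bits=8, symmetric=True),
)

initialize_module_for_quantization(layer)
set_module_for_calibration(layer)  # weight scale/zero point are observed immediately
print(layer.weight_scale, layer.weight_zero_point)
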
compressed_tensors/quantization/lifecycle/compressed.py
@@ -49,8 +49,9 @@ def compress_quantized_weights(module: Module):
     weight = getattr(module, "weight", None)
     scale = getattr(module, "weight_scale", None)
     zero_point = getattr(module, "weight_zero_point", None)
+    g_idx = getattr(module, "weight_g_idx", None)

-    if weight is None or scale is None or zero_point is None:
+    if weight is None or scale is None:
         # no weight, scale, or ZP, nothing to do

         # mark as compressed here to maintain consistent status throughout the model
@@ -62,6 +63,7 @@ def compress_quantized_weights(module: Module):
         x=weight,
         scale=scale,
         zero_point=zero_point,
+        g_idx=g_idx,
         args=scheme.weights,
         dtype=torch.int8,
     )