compressed-tensors 0.7.1__py3-none-any.whl → 0.8.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (31)
  1. compressed_tensors/compressors/model_compressors/model_compressor.py +17 -5
  2. compressed_tensors/compressors/quantized_compressors/naive_quantized.py +4 -2
  3. compressed_tensors/compressors/quantized_compressors/pack_quantized.py +2 -0
  4. compressed_tensors/compressors/sparse_quantized_compressors/marlin_24.py +1 -1
  5. compressed_tensors/config/base.py +60 -2
  6. compressed_tensors/linear/compressed_linear.py +3 -1
  7. compressed_tensors/quantization/__init__.py +0 -1
  8. compressed_tensors/quantization/lifecycle/__init__.py +0 -2
  9. compressed_tensors/quantization/lifecycle/apply.py +3 -17
  10. compressed_tensors/quantization/lifecycle/forward.py +24 -87
  11. compressed_tensors/quantization/lifecycle/initialize.py +21 -24
  12. compressed_tensors/quantization/quant_args.py +27 -25
  13. compressed_tensors/quantization/quant_config.py +2 -2
  14. compressed_tensors/quantization/quant_scheme.py +17 -24
  15. compressed_tensors/quantization/utils/helpers.py +125 -8
  16. compressed_tensors/registry/registry.py +1 -1
  17. compressed_tensors/utils/helpers.py +33 -1
  18. compressed_tensors/version.py +1 -1
  19. {compressed_tensors-0.7.1.dist-info → compressed_tensors-0.8.1.dist-info}/METADATA +1 -1
  20. {compressed_tensors-0.7.1.dist-info → compressed_tensors-0.8.1.dist-info}/RECORD +23 -31
  21. {compressed_tensors-0.7.1.dist-info → compressed_tensors-0.8.1.dist-info}/WHEEL +1 -1
  22. compressed_tensors/quantization/cache.py +0 -201
  23. compressed_tensors/quantization/lifecycle/calibration.py +0 -70
  24. compressed_tensors/quantization/lifecycle/frozen.py +0 -55
  25. compressed_tensors/quantization/observers/__init__.py +0 -21
  26. compressed_tensors/quantization/observers/base.py +0 -213
  27. compressed_tensors/quantization/observers/helpers.py +0 -149
  28. compressed_tensors/quantization/observers/min_max.py +0 -104
  29. compressed_tensors/quantization/observers/mse.py +0 -162
  30. {compressed_tensors-0.7.1.dist-info → compressed_tensors-0.8.1.dist-info}/LICENSE +0 -0
  31. {compressed_tensors-0.7.1.dist-info → compressed_tensors-0.8.1.dist-info}/top_level.txt +0 -0
@@ -24,7 +24,6 @@ import compressed_tensors
 import torch
 import transformers
 from compressed_tensors.base import (
-    COMPRESSION_CONFIG_NAME,
     COMPRESSION_VERSION_NAME,
     QUANTIZATION_CONFIG_NAME,
     QUANTIZATION_METHOD_NAME,
@@ -39,6 +38,7 @@ from compressed_tensors.quantization import (
     apply_quantization_config,
     load_pretrained_quantization,
 )
+from compressed_tensors.quantization.quant_args import QuantizationArgs
 from compressed_tensors.quantization.utils import (
     is_module_quantized,
     iter_named_leaf_modules,
@@ -103,12 +103,14 @@ class ModelCompressor:
         :return: compressor for the configs, or None if model is not compressed
         """
         config = AutoConfig.from_pretrained(pretrained_model_name_or_path, **kwargs)
-        compression_config = getattr(config, COMPRESSION_CONFIG_NAME, None)
+        compression_config = getattr(config, QUANTIZATION_CONFIG_NAME, None)
+
         return cls.from_compression_config(compression_config)

     @classmethod
     def from_compression_config(
-        cls, compression_config: Union[Dict[str, Any], "CompressedTensorsConfig"]
+        cls,
+        compression_config: Union[Dict[str, Any], "CompressedTensorsConfig"],
     ):
         """
         :param compression_config:
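
Note: the compressor lookup now reuses the quantization-config attribute of the Hugging Face model config instead of a separate compression key. A minimal sketch of the pattern, assuming `QUANTIZATION_CONFIG_NAME` resolves to the standard `quantization_config` attribute (an assumption, not shown in this diff):

    from transformers import AutoConfig

    def read_compression_config(model_id: str):
        # Load only the config; "quantization_config" stands in for QUANTIZATION_CONFIG_NAME
        config = AutoConfig.from_pretrained(model_id)
        # Returns None for models that carry no quantization/compression metadata
        return getattr(config, "quantization_config", None)
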
@@ -265,7 +267,11 @@ class ModelCompressor:
         state_dict = model.state_dict()

         compressed_state_dict = state_dict
-        quantized_modules_to_args = map_modules_to_quant_args(model)
+
+        quantized_modules_to_args: Dict[
+            str, QuantizationArgs
+        ] = map_modules_to_quant_args(model)
+
         if self.quantization_compressor is not None:
             compressed_state_dict = self.quantization_compressor.compress(
                 state_dict, names_to_scheme=quantized_modules_to_args
@@ -369,7 +375,13 @@ class ModelCompressor:
             update_parameter_data(module, data, param_name)


-def map_modules_to_quant_args(model: Module) -> Dict:
+def map_modules_to_quant_args(model: Module) -> Dict[str, QuantizationArgs]:
+    """
+    Given a pytorch model, map out the submodule name (usually linear layers)
+    to the QuantizationArgs
+
+    :param model: pytorch model
+    """
     quantized_modules_to_args = {}
     for name, submodule in iter_named_leaf_modules(model):
         if is_module_quantized(submodule):
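
Note: the hunk only shows the start of the loop body. A hypothetical stand-in that mirrors the documented behaviour with plain PyTorch (the internals of `iter_named_leaf_modules` and `is_module_quantized` are not part of this diff):

    from typing import Dict

    from torch.nn import Module


    def map_linear_names_to_args(model: Module) -> Dict[str, object]:
        """Illustrative stand-in: map each quantized leaf module's name to the
        weight QuantizationArgs attached to its quantization_scheme."""
        mapping = {}
        for name, submodule in model.named_modules():
            if len(list(submodule.children())) > 0:
                continue  # only leaf modules, mirroring iter_named_leaf_modules
            scheme = getattr(submodule, "quantization_scheme", None)
            if scheme is not None and getattr(scheme, "weights", None) is not None:
                mapping[name] = scheme.weights
        return mapping
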
@@ -93,9 +93,11 @@ class NaiveQuantizationCompressor(BaseQuantizationCompressor):
                 args=quantization_args,
                 dtype=quantization_args.pytorch_dtype(),
             )
+        else:
+            quantized_weight = weight

-            if device is not None:
-                quantized_weight = quantized_weight.to(device)
+        if device is not None:
+            quantized_weight = quantized_weight.to(device)

         return {"weight": quantized_weight}

@@ -94,6 +94,8 @@ class PackedQuantizationCompressor(BaseQuantizationCompressor):
                 args=quantization_args,
                 dtype=torch.int8,
             )
+        else:
+            quantized_weight = weight

         packed_weight = pack_to_int32(quantized_weight, quantization_args.num_bits)
         weight_shape = torch.tensor(weight.shape)
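
Note: in both compressors the new `else` branch guarantees that `quantized_weight` is always bound, so the later device move and return no longer depend on the quantization branch having run. A minimal sketch of the control flow with illustrative names (`already_quantized` stands in for the real guard):

    import torch


    def compress_weight(weight: torch.Tensor, already_quantized: bool, device=None):
        # Sketch of the 0.8.1 control flow: quantized_weight is bound on every path.
        if not already_quantized:
            quantized_weight = weight.round().to(torch.int8)  # stand-in for quantize(...)
        else:
            quantized_weight = weight  # new fallback: weight is already quantized

        if device is not None:
            quantized_weight = quantized_weight.to(device)

        return {"weight": quantized_weight}
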
@@ -238,7 +238,7 @@ def pack_scales_24(scales, quantization_args, w_shape):
     _, scale_perm_2_4, scale_perm_single_2_4 = get_permutations_24(num_bits)

     if (
-        quantization_args.strategy is QuantizationStrategy.GROUP
+        quantization_args.strategy == QuantizationStrategy.GROUP
         and quantization_args.group_size < size_k
     ):
         scales = scales.reshape((-1, len(scale_perm_2_4)))[:, scale_perm_2_4]
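
Note: `QuantizationStrategy` is a `str`-backed enum and `QuantizationArgs` is declared with `use_enum_values=True`, so the `strategy` field can hold a plain string; an identity check (`is`) then fails even when the values match, while `==` still works. A small stand-alone illustration (stand-in enum, not the library's):

    from enum import Enum


    class Strategy(str, Enum):
        GROUP = "group"
        CHANNEL = "channel"


    strategy = "group"                     # e.g. deserialized from a config
    print(strategy is Strategy.GROUP)      # False: different objects
    print(strategy == Strategy.GROUP)      # True: str-backed enum compares by value
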
@@ -12,16 +12,17 @@
12
12
  # See the License for the specific language governing permissions and
13
13
  # limitations under the License.
14
14
 
15
- from enum import Enum
15
+ from enum import Enum, unique
16
16
  from typing import List, Optional
17
17
 
18
18
  from compressed_tensors.registry import RegistryMixin
19
19
  from pydantic import BaseModel
20
20
 
21
21
 
22
- __all__ = ["SparsityCompressionConfig", "CompressionFormat"]
22
+ __all__ = ["SparsityCompressionConfig", "CompressionFormat", "SparsityStructure"]
23
23
 
24
24
 
25
+ @unique
25
26
  class CompressionFormat(Enum):
26
27
  dense = "dense"
27
28
  sparse_bitmask = "sparse-bitmask"
@@ -32,6 +33,63 @@ class CompressionFormat(Enum):
     marlin_24 = "marlin-24"


+@unique
+class SparsityStructure(Enum):
+    """
+    An enumeration to represent different sparsity structures.
+
+    Attributes
+    ----------
+    TWO_FOUR : str
+        Represents a 2:4 sparsity structure.
+    ZERO_ZERO : str
+        Represents a 0:0 sparsity structure.
+    UNSTRUCTURED : str
+        Represents an unstructured sparsity structure.
+
+    Examples
+    --------
+    >>> SparsityStructure('2:4')
+    <SparsityStructure.TWO_FOUR: '2:4'>
+
+    >>> SparsityStructure('unstructured')
+    <SparsityStructure.UNSTRUCTURED: 'unstructured'>
+
+    >>> SparsityStructure('2:4') == SparsityStructure.TWO_FOUR
+    True
+
+    >>> SparsityStructure('UNSTRUCTURED') == SparsityStructure.UNSTRUCTURED
+    True
+
+    >>> SparsityStructure(None) == SparsityStructure.UNSTRUCTURED
+    True
+
+    >>> SparsityStructure('invalid')
+    Traceback (most recent call last):
+        ...
+    ValueError: invalid is not a valid SparsityStructure
+    """
+
+    TWO_FOUR = "2:4"
+    UNSTRUCTURED = "unstructured"
+    ZERO_ZERO = "0:0"
+
+    def __new__(cls, value):
+        obj = object.__new__(cls)
+        obj._value_ = value.lower() if value is not None else value
+        return obj
+
+    @classmethod
+    def _missing_(cls, value):
+        # Handle None and case-insensitive values
+        if value is None:
+            return cls.UNSTRUCTURED
+        for member in cls:
+            if member.value == value.lower():
+                return member
+        raise ValueError(f"{value} is not a valid {cls.__name__}")
+
+
 class SparsityCompressionConfig(RegistryMixin, BaseModel):
     """
     Base data class for storing sparsity compression parameters
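
Note: the doctests above double as usage examples; `__new__` lowercases stored values and `_missing_` maps `None` and mixed-case inputs, so raw config values can be normalized directly. A short sketch (assumes compressed-tensors 0.8.1 is installed):

    # Normalize raw sparsity-structure strings through the new enum
    from compressed_tensors.config.base import SparsityStructure

    for raw in ("2:4", "UNSTRUCTURED", None):
        print(raw, "->", SparsityStructure(raw))
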
@@ -12,6 +12,8 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.

+from typing import Dict, Tuple
+
 import torch
 from compressed_tensors.compressors.base import BaseCompressor
 from compressed_tensors.quantization import (
@@ -53,7 +55,7 @@ class CompressedLinear(Linear):
         )

         # get the shape and dtype of compressed parameters
-        compression_params = module.compressor.compression_param_info(
+        compression_params: Dict[str, Tuple] = module.compressor.compression_param_info(
             module.weight.shape, quantization_scheme.weights
         )

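
Note: the new annotation documents that `compression_param_info` reports, per compressed parameter name, the metadata needed to allocate placeholder tensors. A hedged sketch of how such a mapping could be consumed, assuming a `(shape, dtype)` tuple layout (the diff only pins `Dict[str, Tuple]`):

    from typing import Dict, Tuple

    import torch
    from torch.nn import Module, Parameter


    def register_compressed_params(
        module: Module, compression_params: Dict[str, Tuple[torch.Size, torch.dtype]]
    ) -> None:
        # Assumed layout: name -> (shape, dtype); adjust if the compressor reports more.
        for name, (shape, dtype) in compression_params.items():
            param = Parameter(torch.empty(shape, dtype=dtype), requires_grad=False)
            module.register_parameter(name, param)
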
@@ -19,4 +19,3 @@ from .quant_args import *
 from .quant_config import *
 from .quant_scheme import *
 from .lifecycle import *
-from .cache import QuantizedKVParameterCache
@@ -15,9 +15,7 @@
 # flake8: noqa
 # isort: skip_file

-from .calibration import *
 from .forward import *
-from .frozen import *
 from .initialize import *
 from .compressed import *
 from .apply import *
@@ -22,13 +22,9 @@ from typing import Union

 import torch
 from compressed_tensors.config import CompressionFormat
-from compressed_tensors.quantization.lifecycle.calibration import (
-    set_module_for_calibration,
-)
 from compressed_tensors.quantization.lifecycle.compressed import (
     compress_quantized_weights,
 )
-from compressed_tensors.quantization.lifecycle.frozen import freeze_module_quantization
 from compressed_tensors.quantization.lifecycle.initialize import (
     initialize_module_for_quantization,
 )
@@ -110,7 +106,8 @@ def apply_quantization_config(
     model: Module, config: Union[QuantizationConfig, None], run_compressed: bool = False
 ) -> OrderedDict:
     """
-    Initializes the model for quantization in-place based on the given config
+    Initializes the model for quantization in-place based on the given config.
+    Optionally coverts quantizable modules to compressed_linear modules

     :param model: model to apply quantization config to
     :param config: quantization config
@@ -233,6 +230,7 @@ def apply_quantization_status(model: Module, status: QuantizationStatus):
     :param model: model to apply quantization to
     :param status: status to update the module to
     """
+
     current_status = infer_quantization_status(model)

     if status >= QuantizationStatus.INITIALIZED > current_status:
@@ -243,18 +241,6 @@ def apply_quantization_status(model: Module, status: QuantizationStatus):
             )
         )

-    if current_status < status >= QuantizationStatus.CALIBRATION > current_status:
-        # only quantize weights up front when our end goal state is calibration,
-        # weight quantization parameters are already loaded for frozen/compressed
-        quantize_weights_upfront = status == QuantizationStatus.CALIBRATION
-        model.apply(
-            lambda module: set_module_for_calibration(
-                module, quantize_weights_upfront=quantize_weights_upfront
-            )
-        )
-    if current_status < status >= QuantizationStatus.FROZEN > current_status:
-        model.apply(freeze_module_quantization)
-
     if current_status < status >= QuantizationStatus.COMPRESSED > current_status:
         model.apply(compress_quantized_weights)

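
Note: the calibration and frozen transitions are removed from this lifecycle; only the compression step remains, guarded by a Python chained comparison that reads as "upgrading to at least COMPRESSED and not there yet". A tiny illustration with integers standing in for the ordered status values:

    # Chained comparison used in apply_quantization_status, with integers standing
    # in for the ordered QuantizationStatus values (INITIALIZED=1, COMPRESSED=3).
    INITIALIZED, COMPRESSED = 1, 3

    def should_compress(current_status: int, target_status: int) -> bool:
        # Equivalent to: current_status < target_status
        #            and target_status >= COMPRESSED
        #            and COMPRESSED > current_status
        return current_status < target_status >= COMPRESSED > current_status

    print(should_compress(INITIALIZED, COMPRESSED))  # True: upgrade to compressed
    print(should_compress(COMPRESSED, COMPRESSED))   # False: already compressed
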
@@ -14,14 +14,9 @@

 from functools import wraps
 from math import ceil
-from typing import Callable, Optional
+from typing import Optional

 import torch
-from compressed_tensors.quantization.cache import QuantizedKVParameterCache
-from compressed_tensors.quantization.observers.helpers import (
-    calculate_range,
-    compute_dynamic_scales_and_zp,
-)
 from compressed_tensors.quantization.quant_args import (
     QuantizationArgs,
     QuantizationStrategy,
@@ -29,7 +24,11 @@ from compressed_tensors.quantization.quant_args import (
 )
 from compressed_tensors.quantization.quant_config import QuantizationStatus
 from compressed_tensors.quantization.quant_scheme import QuantizationScheme
-from compressed_tensors.utils import safe_permute, update_parameter_data
+from compressed_tensors.quantization.utils import (
+    calculate_range,
+    compute_dynamic_scales_and_zp,
+)
+from compressed_tensors.utils import safe_permute
 from torch.nn import Module


@@ -38,7 +37,7 @@ __all__ = [
     "dequantize",
     "fake_quantize",
     "wrap_module_forward_quantized",
-    "maybe_calibrate_or_quantize",
+    "forward_quantize",
 ]


@@ -275,15 +274,13 @@ def wrap_module_forward_quantized(module: Module, scheme: QuantizationScheme):
         compressed = module.quantization_status == QuantizationStatus.COMPRESSED

         if scheme.input_activations is not None:
-            # calibrate and (fake) quantize input activations when applicable
-            input_ = maybe_calibrate_or_quantize(
-                module, input_, "input", scheme.input_activations
-            )
+            # prehook should calibrate activations before forward call
+            input_ = forward_quantize(module, input_, "input", scheme.input_activations)

         if scheme.weights is not None and not compressed:
             # calibrate and (fake) quantize weights when applicable
             unquantized_weight = self.weight.data.clone()
-            self.weight.data = maybe_calibrate_or_quantize(
+            self.weight.data = forward_quantize(
                 module, self.weight, "weight", scheme.weights
             )

@@ -291,64 +288,23 @@ def wrap_module_forward_quantized(module: Module, scheme: QuantizationScheme):
         output = forward_func_orig.__get__(module, module.__class__)(
             input_, *args[1:], **kwargs
         )
-        if scheme.output_activations is not None:
-
-            # calibrate and (fake) quantize output activations when applicable
-            # kv_cache scales updated on model self_attn forward call in
-            # wrap_module_forward_quantized_attn
-            output = maybe_calibrate_or_quantize(
-                module, output, "output", scheme.output_activations
-            )

         # restore back to unquantized_value
         if scheme.weights is not None and not compressed:
             self.weight.data = unquantized_weight

-        return output
-
-    # bind wrapped forward to module class so reference to `self` is correct
-    bound_wrapped_forward = wrapped_forward.__get__(module, module.__class__)
-    # set forward to wrapped forward
-    setattr(module, "forward", bound_wrapped_forward)
-
-
-def wrap_module_forward_quantized_attn(module: Module, scheme: QuantizationScheme):
-    # expects a module already initialized and injected with the parameters in
-    # initialize_module_for_quantization
-    if hasattr(module.forward, "__func__"):
-        forward_func_orig = module.forward.__func__
-    else:
-        forward_func_orig = module.forward.func
-
-    @wraps(forward_func_orig)  # ensures docstring, names, etc are propagated
-    def wrapped_forward(self, *args, **kwargs):
-
-        # kv cache stored under weights
-        if module.quantization_status == QuantizationStatus.CALIBRATION:
-            quantization_args: QuantizationArgs = scheme.output_activations
-            past_key_value: QuantizedKVParameterCache = quantization_args.get_kv_cache()
-            kwargs["past_key_value"] = past_key_value
-
-            # QuantizedKVParameterCache used for obtaining k_scale, v_scale only,
-            # does not store quantized_key_states and quantized_value_state
-            kwargs["use_cache"] = False
-
-            attn_forward: Callable = forward_func_orig.__get__(module, module.__class__)
-
-            past_key_value.reset_states()
-
-            rtn = attn_forward(*args, **kwargs)
-
-            update_parameter_data(
-                module, past_key_value.k_scales[module.layer_idx], "k_scale"
-            )
-            update_parameter_data(
-                module, past_key_value.v_scales[module.layer_idx], "v_scale"
+        if scheme.output_activations is not None:
+            # forward-hook should calibrate/forward_quantize
+            if (
+                module.quantization_status == QuantizationStatus.CALIBRATION
+                and not scheme.output_activations.dynamic
+            ):
+                return output
+
+            output = forward_quantize(
+                module, output, "output", scheme.output_activations
             )
-
-            return rtn
-
-        return forward_func_orig.__get__(module, module.__class__)(*args, **kwargs)
+        return output

     # bind wrapped forward to module class so reference to `self` is correct
     bound_wrapped_forward = wrapped_forward.__get__(module, module.__class__)
@@ -356,12 +312,9 @@ def wrap_module_forward_quantized_attn(module: Module, scheme: QuantizationSchem
     setattr(module, "forward", bound_wrapped_forward)


-def maybe_calibrate_or_quantize(
+def forward_quantize(
     module: Module, value: torch.Tensor, base_name: str, args: "QuantizationArgs"
 ) -> torch.Tensor:
-    # don't run quantization if we haven't entered calibration mode
-    if module.quantization_status == QuantizationStatus.INITIALIZED:
-        return value

     # in compressed mode, the weight is already compressed and quantized so we don't
     # need to run fake quantization
@@ -379,29 +332,13 @@
     g_idx = getattr(module, "weight_g_idx", None)

     if args.dynamic:
-        # dynamic quantization - no need to invoke observer
+        # dynamic quantization - determine the scale/zp on the fly
         scale, zero_point = compute_dynamic_scales_and_zp(value=value, args=args)
     else:
-        # static quantization - get previous scale and zero point from layer
+        # static quantization - get scale and zero point from layer
         scale = getattr(module, f"{base_name}_scale")
         zero_point = getattr(module, f"{base_name}_zero_point", None)

-    if (
-        module.quantization_status == QuantizationStatus.CALIBRATION
-        and base_name != "weight"
-    ):
-        # calibration mode - get new quant params from observer
-        observer = getattr(module, f"{base_name}_observer")
-
-        updated_scale, updated_zero_point = observer(value, g_idx=g_idx)
-
-        # update scale and zero point
-        update_parameter_data(module, updated_scale, f"{base_name}_scale")
-        update_parameter_data(module, updated_zero_point, f"{base_name}_zero_point")
-
-        scale = updated_scale
-        zero_point = updated_zero_point
-
     return fake_quantize(
         x=value,
         scale=scale,
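
Note: with observers removed from this package, `forward_quantize` no longer updates scales itself; it fake-quantizes with whatever parameters the module already carries, or derives them on the fly for dynamic schemes, and calibration is expected to happen in externally registered hooks. A simplified, symmetric-int8 sketch of that flow (illustrative only, not the library's exact code):

    import torch


    def forward_quantize_sketch(module, value: torch.Tensor, base_name: str, dynamic: bool):
        """Simplified view of the new forward_quantize flow (illustrative only)."""
        if dynamic:
            # dynamic quantization: derive scale/zero-point from the tensor on the fly
            scale = value.abs().amax() / 127.0
            zero_point = torch.tensor(0)
        else:
            # static quantization: reuse parameters already attached to the module
            scale = getattr(module, f"{base_name}_scale")
            zero_point = getattr(module, f"{base_name}_zero_point", torch.tensor(0))

        # fake quantization: round to the integer grid, then map back to floats
        quantized = torch.clamp(torch.round(value / scale + zero_point), -128, 127)
        return (quantized - zero_point) * scale
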
@@ -14,13 +14,12 @@


 import logging
+from enum import Enum
 from typing import Optional

 import torch
-from compressed_tensors.quantization.cache import KVCacheScaleType
 from compressed_tensors.quantization.lifecycle.forward import (
     wrap_module_forward_quantized,
-    wrap_module_forward_quantized_attn,
 )
 from compressed_tensors.quantization.quant_args import (
     ActivationOrdering,
@@ -36,12 +35,19 @@ from torch.nn import Module, Parameter

 __all__ = [
     "initialize_module_for_quantization",
+    "is_attention_module",
+    "KVCacheScaleType",
 ]


 _LOGGER = logging.getLogger(__name__)


+class KVCacheScaleType(Enum):
+    KEY = "k_scale"
+    VALUE = "v_scale"
+
+
 def initialize_module_for_quantization(
     module: Module,
     scheme: Optional[QuantizationScheme] = None,
@@ -66,15 +72,13 @@ initialize_module_for_quantization(
         return

     if is_attention_module(module):
-        # wrap forward call of module to perform
         # quantized actions based on calltime status
-        wrap_module_forward_quantized_attn(module, scheme)
         _initialize_attn_scales(module)

     else:

         if scheme.input_activations is not None:
-            _initialize_scale_zero_point_observer(
+            _initialize_scale_zero_point(
                 module,
                 "input",
                 scheme.input_activations,
@@ -85,7 +89,7 @@
             weight_shape = None
             if isinstance(module, torch.nn.Linear):
                 weight_shape = module.weight.shape
-            _initialize_scale_zero_point_observer(
+            _initialize_scale_zero_point(
                 module,
                 "weight",
                 scheme.weights,
@@ -101,7 +105,7 @@

         if scheme.output_activations is not None:
             if not is_kv_cache_quant_scheme(scheme):
-                _initialize_scale_zero_point_observer(
+                _initialize_scale_zero_point(
                     module, "output", scheme.output_activations
                 )

@@ -109,6 +113,7 @@
     module.quantization_status = QuantizationStatus.INITIALIZED

     offloaded = False
+    # What is this doing/why isn't this in the attn case?
     if is_module_offloaded(module):
         try:
             from accelerate.hooks import add_hook_to_module, remove_hook_from_module
@@ -146,21 +151,21 @@
         module._hf_hook.weights_map = new_prefix_dict


-def _initialize_scale_zero_point_observer(
+def is_attention_module(module: Module):
+    return "attention" in module.__class__.__name__.lower() and (
+        hasattr(module, "k_proj")
+        or hasattr(module, "v_proj")
+        or hasattr(module, "qkv_proj")
+    )
+
+
+def _initialize_scale_zero_point(
     module: Module,
     base_name: str,
     quantization_args: QuantizationArgs,
     weight_shape: Optional[torch.Size] = None,
     force_zero_point: bool = True,
 ):
-
-    # initialize observer module and attach as submodule
-    observer = quantization_args.get_observer()
-    # no need to register an observer for dynamic quantization
-    if observer:
-        module.register_module(f"{base_name}_observer", observer)
-
-    # no need to register a scale and zero point for a dynamic quantization
     if quantization_args.dynamic:
         return

@@ -209,14 +214,6 @@
         module.register_parameter(f"{base_name}_g_idx", init_g_idx)


-def is_attention_module(module: Module):
-    return "attention" in module.__class__.__name__.lower() and (
-        hasattr(module, "k_proj")
-        or hasattr(module, "v_proj")
-        or hasattr(module, "qkv_proj")
-    )
-
-
 def _initialize_attn_scales(module: Module) -> None:
     """Initlaize k_scale, v_scale for self_attn"""

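
Note: the attention-specific forward wrapper is gone; attention modules are now only detected by the heuristic above and given raw `k_scale`/`v_scale` parameters for external calibration to fill. A quick check of the heuristic against a toy module (the toy class is made up for illustration):

    from torch.nn import Linear, Module


    class TinySelfAttention(Module):
        def __init__(self):
            super().__init__()
            self.k_proj = Linear(8, 8)
            self.v_proj = Linear(8, 8)


    def is_attention_module(module: Module) -> bool:
        # Same heuristic as the diff: class name mentions "attention" and the
        # module exposes one of the usual projection layers.
        return "attention" in module.__class__.__name__.lower() and (
            hasattr(module, "k_proj")
            or hasattr(module, "v_proj")
            or hasattr(module, "qkv_proj")
        )


    print(is_attention_module(TinySelfAttention()))  # True
    print(is_attention_module(Linear(8, 8)))         # False
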
@@ -17,6 +17,7 @@ from enum import Enum
 from typing import Any, Dict, Optional, Union

 import torch
+from compressed_tensors.utils import Aliasable
 from pydantic import BaseModel, Field, field_validator, model_validator


@@ -53,17 +54,29 @@ class QuantizationStrategy(str, Enum):
     TOKEN = "token"


-class ActivationOrdering(str, Enum):
+class ActivationOrdering(Aliasable, str, Enum):
     """
     Enum storing strategies for activation ordering

     Group: reorder groups and weight\n
-    Weight: only reorder weight, not groups. Slightly lower latency and
-    accuracy compared to group actorder\n
+    Weight: only reorder weight, not groups. Slightly lower accuracy but also lower
+    latency when compared to group actorder\n
+    Dynamic: alias for Group\n
+    Static: alias for Weight\n
     """

     GROUP = "group"
     WEIGHT = "weight"
+    # aliases
+    DYNAMIC = "dynamic"
+    STATIC = "static"
+
+    @staticmethod
+    def get_aliases() -> Dict[str, str]:
+        return {
+            "dynamic": "group",
+            "static": "weight",
+        }


 class QuantizationArgs(BaseModel, use_enum_values=True):
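
Note: the legacy "dynamic"/"static" names keep working as aliases of "group"/"weight" through the `Aliasable` mixin. The diff only shows the alias table, not how `Aliasable` resolves it, so the sketch below uses a hypothetical stand-in mixin to show the intended behaviour:

    from enum import Enum
    from typing import Dict


    class AliasableSketch:
        """Hypothetical stand-in for compressed_tensors.utils.Aliasable: treat
        alias members as equal to the member they point at."""

        @staticmethod
        def get_aliases() -> Dict[str, str]:
            raise NotImplementedError

        def __eq__(self, other):
            aliases = self.get_aliases()
            own = aliases.get(self.value, self.value)
            other_value = getattr(other, "value", other)
            return own == aliases.get(other_value, other_value)

        def __hash__(self):
            return hash(self.get_aliases().get(self.value, self.value))


    class ActOrder(AliasableSketch, str, Enum):
        GROUP = "group"
        WEIGHT = "weight"
        DYNAMIC = "dynamic"  # alias for GROUP
        STATIC = "static"    # alias for WEIGHT

        @staticmethod
        def get_aliases() -> Dict[str, str]:
            return {"dynamic": "group", "static": "weight"}


    print(ActOrder.DYNAMIC == ActOrder.GROUP)  # True: resolved through the alias table
    print(ActOrder.STATIC == ActOrder.WEIGHT)  # True
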
@@ -114,20 +127,7 @@ class QuantizationArgs(BaseModel, use_enum_values=True):
         """
         :return: torch quantization FakeQuantize built based on these QuantizationArgs
         """
-        from compressed_tensors.quantization.observers.base import Observer
-
-        # No observer required for the dynamic case
-        if self.dynamic:
-            self.observer = None
-            return self.observer
-
-        return Observer.load_from_registry(self.observer, quantization_args=self)
-
-    def get_kv_cache(self):
-        """Get the singleton KV Cache"""
-        from compressed_tensors.quantization.cache import QuantizedKVParameterCache
-
-        return QuantizedKVParameterCache(self)
+        return self.observer

     @field_validator("type", mode="before")
     def validate_type(cls, value) -> QuantizationType:
@@ -210,6 +210,7 @@ class QuantizationArgs(BaseModel, use_enum_values=True):
                 "activation ordering"
             )

+        # infer observer w.r.t. dynamic
         if dynamic:
             if strategy not in (
                 QuantizationStrategy.TOKEN,
@@ -221,18 +222,19 @@ class QuantizationArgs(BaseModel, use_enum_values=True):
                     "quantization",
                 )
             if observer is not None:
-                warnings.warn(
-                    "No observer is used for dynamic quantization, setting to None"
-                )
-                model.observer = None
+                if observer != "memoryless":  # avoid annoying users with old configs
+                    warnings.warn(
+                        "No observer is used for dynamic quantization, setting to None"
+                    )
+                observer = None

-        # if we have not set an observer and we
-        # are running static quantization, use minmax
-        if not observer and not dynamic:
-            model.observer = "minmax"
+        elif observer is None:
+            # default to minmax for non-dynamic cases
+            observer = "minmax"

         # write back modified values
         model.strategy = strategy
+        model.observer = observer
         return model

     def pytorch_dtype(self) -> torch.dtype:
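
Note: observer selection is now decided in the validator and written back onto the model: dynamic schemes drop the observer (warning only for values other than the old "memoryless" default), while static schemes fall back to "minmax". A small sketch that mirrors those branches (not the pydantic code itself):

    def resolve_observer(dynamic: bool, observer):
        # Mirrors the validator branches shown in the hunk above.
        if dynamic:
            if observer is not None and observer != "memoryless":
                print("warning: no observer is used for dynamic quantization")
            return None              # dynamic quantization computes scales on the fly
        if observer is None:
            return "minmax"          # static quantization defaults to the minmax observer
        return observer

    print(resolve_observer(dynamic=True, observer="minmax"))      # None (with warning)
    print(resolve_observer(dynamic=True, observer="memoryless"))  # None, no warning
    print(resolve_observer(dynamic=False, observer=None))         # "minmax"
    print(resolve_observer(dynamic=False, observer="mse"))        # "mse"
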
@@ -132,9 +132,9 @@ class QuantizationConfig(BaseModel):
         `k_proj` and `v_proj` in their names. If this is not the case
         and kv_cache_scheme != None, the quantization of kv cache will fail
     :global_compression_ratio: optional informational config to report the model
-    compression ratio acheived by the quantization config
+        compression ratio acheived by the quantization config
     :ignore: optional list of layers to ignore from config_groups. Layers in this list
-    are not quantized even if they match up with a target in config_groups
+        are not quantized even if they match up with a target in config_groups
     """

     config_groups: Dict[str, Union[QuantizationScheme, List[str]]]