compressed-tensors 0.7.1__tar.gz → 0.8.0__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (64)
  1. {compressed-tensors-0.7.1 → compressed-tensors-0.8.0}/PKG-INFO +1 -1
  2. {compressed-tensors-0.7.1 → compressed-tensors-0.8.0}/src/compressed_tensors/compressors/sparse_quantized_compressors/marlin_24.py +1 -1
  3. {compressed-tensors-0.7.1 → compressed-tensors-0.8.0}/src/compressed_tensors/config/base.py +60 -2
  4. {compressed-tensors-0.7.1 → compressed-tensors-0.8.0}/src/compressed_tensors/quantization/__init__.py +0 -1
  5. {compressed-tensors-0.7.1 → compressed-tensors-0.8.0}/src/compressed_tensors/quantization/lifecycle/__init__.py +0 -2
  6. {compressed-tensors-0.7.1 → compressed-tensors-0.8.0}/src/compressed_tensors/quantization/lifecycle/apply.py +1 -16
  7. {compressed-tensors-0.7.1 → compressed-tensors-0.8.0}/src/compressed_tensors/quantization/lifecycle/forward.py +24 -87
  8. {compressed-tensors-0.7.1 → compressed-tensors-0.8.0}/src/compressed_tensors/quantization/lifecycle/initialize.py +21 -24
  9. {compressed-tensors-0.7.1 → compressed-tensors-0.8.0}/src/compressed_tensors/quantization/quant_args.py +11 -22
  10. {compressed-tensors-0.7.1 → compressed-tensors-0.8.0}/src/compressed_tensors/quantization/utils/helpers.py +125 -8
  11. {compressed-tensors-0.7.1 → compressed-tensors-0.8.0}/src/compressed_tensors/registry/registry.py +1 -1
  12. {compressed-tensors-0.7.1 → compressed-tensors-0.8.0}/src/compressed_tensors/version.py +1 -1
  13. {compressed-tensors-0.7.1 → compressed-tensors-0.8.0}/src/compressed_tensors.egg-info/PKG-INFO +1 -1
  14. {compressed-tensors-0.7.1 → compressed-tensors-0.8.0}/src/compressed_tensors.egg-info/SOURCES.txt +0 -8
  15. compressed-tensors-0.7.1/src/compressed_tensors/quantization/cache.py +0 -201
  16. compressed-tensors-0.7.1/src/compressed_tensors/quantization/lifecycle/calibration.py +0 -70
  17. compressed-tensors-0.7.1/src/compressed_tensors/quantization/lifecycle/frozen.py +0 -55
  18. compressed-tensors-0.7.1/src/compressed_tensors/quantization/observers/__init__.py +0 -21
  19. compressed-tensors-0.7.1/src/compressed_tensors/quantization/observers/base.py +0 -213
  20. compressed-tensors-0.7.1/src/compressed_tensors/quantization/observers/helpers.py +0 -149
  21. compressed-tensors-0.7.1/src/compressed_tensors/quantization/observers/min_max.py +0 -104
  22. compressed-tensors-0.7.1/src/compressed_tensors/quantization/observers/mse.py +0 -162
  23. {compressed-tensors-0.7.1 → compressed-tensors-0.8.0}/LICENSE +0 -0
  24. {compressed-tensors-0.7.1 → compressed-tensors-0.8.0}/README.md +0 -0
  25. {compressed-tensors-0.7.1 → compressed-tensors-0.8.0}/pyproject.toml +0 -0
  26. {compressed-tensors-0.7.1 → compressed-tensors-0.8.0}/setup.cfg +0 -0
  27. {compressed-tensors-0.7.1 → compressed-tensors-0.8.0}/setup.py +0 -0
  28. {compressed-tensors-0.7.1 → compressed-tensors-0.8.0}/src/compressed_tensors/__init__.py +0 -0
  29. {compressed-tensors-0.7.1 → compressed-tensors-0.8.0}/src/compressed_tensors/base.py +0 -0
  30. {compressed-tensors-0.7.1 → compressed-tensors-0.8.0}/src/compressed_tensors/compressors/__init__.py +0 -0
  31. {compressed-tensors-0.7.1 → compressed-tensors-0.8.0}/src/compressed_tensors/compressors/base.py +0 -0
  32. {compressed-tensors-0.7.1 → compressed-tensors-0.8.0}/src/compressed_tensors/compressors/helpers.py +0 -0
  33. {compressed-tensors-0.7.1 → compressed-tensors-0.8.0}/src/compressed_tensors/compressors/model_compressors/__init__.py +0 -0
  34. {compressed-tensors-0.7.1 → compressed-tensors-0.8.0}/src/compressed_tensors/compressors/model_compressors/model_compressor.py +0 -0
  35. {compressed-tensors-0.7.1 → compressed-tensors-0.8.0}/src/compressed_tensors/compressors/quantized_compressors/__init__.py +0 -0
  36. {compressed-tensors-0.7.1 → compressed-tensors-0.8.0}/src/compressed_tensors/compressors/quantized_compressors/base.py +0 -0
  37. {compressed-tensors-0.7.1 → compressed-tensors-0.8.0}/src/compressed_tensors/compressors/quantized_compressors/naive_quantized.py +0 -0
  38. {compressed-tensors-0.7.1 → compressed-tensors-0.8.0}/src/compressed_tensors/compressors/quantized_compressors/pack_quantized.py +0 -0
  39. {compressed-tensors-0.7.1 → compressed-tensors-0.8.0}/src/compressed_tensors/compressors/sparse_compressors/__init__.py +0 -0
  40. {compressed-tensors-0.7.1 → compressed-tensors-0.8.0}/src/compressed_tensors/compressors/sparse_compressors/base.py +0 -0
  41. {compressed-tensors-0.7.1 → compressed-tensors-0.8.0}/src/compressed_tensors/compressors/sparse_compressors/dense.py +0 -0
  42. {compressed-tensors-0.7.1 → compressed-tensors-0.8.0}/src/compressed_tensors/compressors/sparse_compressors/sparse_bitmask.py +0 -0
  43. {compressed-tensors-0.7.1 → compressed-tensors-0.8.0}/src/compressed_tensors/compressors/sparse_quantized_compressors/__init__.py +0 -0
  44. {compressed-tensors-0.7.1 → compressed-tensors-0.8.0}/src/compressed_tensors/config/__init__.py +0 -0
  45. {compressed-tensors-0.7.1 → compressed-tensors-0.8.0}/src/compressed_tensors/config/dense.py +0 -0
  46. {compressed-tensors-0.7.1 → compressed-tensors-0.8.0}/src/compressed_tensors/config/sparse_bitmask.py +0 -0
  47. {compressed-tensors-0.7.1 → compressed-tensors-0.8.0}/src/compressed_tensors/linear/__init__.py +0 -0
  48. {compressed-tensors-0.7.1 → compressed-tensors-0.8.0}/src/compressed_tensors/linear/compressed_linear.py +0 -0
  49. {compressed-tensors-0.7.1 → compressed-tensors-0.8.0}/src/compressed_tensors/quantization/lifecycle/compressed.py +0 -0
  50. {compressed-tensors-0.7.1 → compressed-tensors-0.8.0}/src/compressed_tensors/quantization/lifecycle/helpers.py +0 -0
  51. {compressed-tensors-0.7.1 → compressed-tensors-0.8.0}/src/compressed_tensors/quantization/quant_config.py +0 -0
  52. {compressed-tensors-0.7.1 → compressed-tensors-0.8.0}/src/compressed_tensors/quantization/quant_scheme.py +0 -0
  53. {compressed-tensors-0.7.1 → compressed-tensors-0.8.0}/src/compressed_tensors/quantization/utils/__init__.py +0 -0
  54. {compressed-tensors-0.7.1 → compressed-tensors-0.8.0}/src/compressed_tensors/registry/__init__.py +0 -0
  55. {compressed-tensors-0.7.1 → compressed-tensors-0.8.0}/src/compressed_tensors/utils/__init__.py +0 -0
  56. {compressed-tensors-0.7.1 → compressed-tensors-0.8.0}/src/compressed_tensors/utils/helpers.py +0 -0
  57. {compressed-tensors-0.7.1 → compressed-tensors-0.8.0}/src/compressed_tensors/utils/offload.py +0 -0
  58. {compressed-tensors-0.7.1 → compressed-tensors-0.8.0}/src/compressed_tensors/utils/permutations_24.py +0 -0
  59. {compressed-tensors-0.7.1 → compressed-tensors-0.8.0}/src/compressed_tensors/utils/permute.py +0 -0
  60. {compressed-tensors-0.7.1 → compressed-tensors-0.8.0}/src/compressed_tensors/utils/safetensors_load.py +0 -0
  61. {compressed-tensors-0.7.1 → compressed-tensors-0.8.0}/src/compressed_tensors/utils/semi_structured_conversions.py +0 -0
  62. {compressed-tensors-0.7.1 → compressed-tensors-0.8.0}/src/compressed_tensors.egg-info/dependency_links.txt +0 -0
  63. {compressed-tensors-0.7.1 → compressed-tensors-0.8.0}/src/compressed_tensors.egg-info/requires.txt +0 -0
  64. {compressed-tensors-0.7.1 → compressed-tensors-0.8.0}/src/compressed_tensors.egg-info/top_level.txt +0 -0
@@ -1,6 +1,6 @@
 Metadata-Version: 2.1
 Name: compressed-tensors
-Version: 0.7.1
+Version: 0.8.0
 Summary: Library for utilization of compressed safetensors of neural network models
 Home-page: https://github.com/neuralmagic/compressed-tensors
 Author: Neuralmagic, Inc.
@@ -238,7 +238,7 @@ def pack_scales_24(scales, quantization_args, w_shape):
     _, scale_perm_2_4, scale_perm_single_2_4 = get_permutations_24(num_bits)
 
     if (
-        quantization_args.strategy is QuantizationStrategy.GROUP
+        quantization_args.strategy == QuantizationStrategy.GROUP
         and quantization_args.group_size < size_k
     ):
         scales = scales.reshape((-1, len(scale_perm_2_4)))[:, scale_perm_2_4]
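
A note on the pack_scales_24 change above: QuantizationArgs is a pydantic model declared with use_enum_values=True (see the quant_args.py hunks below), so strategy is stored as the enum's plain string value; identity comparison against the enum member therefore never matches, while equality does because the strategy enum is string-backed. A minimal standalone sketch with a hypothetical stand-in enum:

    from enum import Enum

    class QuantizationStrategy(str, Enum):  # str-backed, mirroring quant_args.py
        GROUP = "group"

    strategy = "group"  # what pydantic stores when use_enum_values=True
    print(strategy is QuantizationStrategy.GROUP)  # False: plain str vs enum member
    print(strategy == QuantizationStrategy.GROUP)  # True: str-enum compares by value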
@@ -12,16 +12,17 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
-from enum import Enum
+from enum import Enum, unique
 from typing import List, Optional
 
 from compressed_tensors.registry import RegistryMixin
 from pydantic import BaseModel
 
 
-__all__ = ["SparsityCompressionConfig", "CompressionFormat"]
+__all__ = ["SparsityCompressionConfig", "CompressionFormat", "SparsityStructure"]
 
 
+@unique
 class CompressionFormat(Enum):
     dense = "dense"
     sparse_bitmask = "sparse-bitmask"
@@ -32,6 +33,63 @@ class CompressionFormat(Enum):
     marlin_24 = "marlin-24"
 
 
+@unique
+class SparsityStructure(Enum):
+    """
+    An enumeration to represent different sparsity structures.
+
+    Attributes
+    ----------
+    TWO_FOUR : str
+        Represents a 2:4 sparsity structure.
+    ZERO_ZERO : str
+        Represents a 0:0 sparsity structure.
+    UNSTRUCTURED : str
+        Represents an unstructured sparsity structure.
+
+    Examples
+    --------
+    >>> SparsityStructure('2:4')
+    <SparsityStructure.TWO_FOUR: '2:4'>
+
+    >>> SparsityStructure('unstructured')
+    <SparsityStructure.UNSTRUCTURED: 'unstructured'>
+
+    >>> SparsityStructure('2:4') == SparsityStructure.TWO_FOUR
+    True
+
+    >>> SparsityStructure('UNSTRUCTURED') == SparsityStructure.UNSTRUCTURED
+    True
+
+    >>> SparsityStructure(None) == SparsityStructure.UNSTRUCTURED
+    True
+
+    >>> SparsityStructure('invalid')
+    Traceback (most recent call last):
+        ...
+    ValueError: invalid is not a valid SparsityStructure
+    """
+
+    TWO_FOUR = "2:4"
+    UNSTRUCTURED = "unstructured"
+    ZERO_ZERO = "0:0"
+
+    def __new__(cls, value):
+        obj = object.__new__(cls)
+        obj._value_ = value.lower() if value is not None else value
+        return obj
+
+    @classmethod
+    def _missing_(cls, value):
+        # Handle None and case-insensitive values
+        if value is None:
+            return cls.UNSTRUCTURED
+        for member in cls:
+            if member.value == value.lower():
+                return member
+        raise ValueError(f"{value} is not a valid {cls.__name__}")
+
+
 class SparsityCompressionConfig(RegistryMixin, BaseModel):
     """
     Base data class for storing sparsity compression parameters
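
The new SparsityStructure enum gives a single, case-insensitive way to parse the sparsity_structure strings found in sparsity configs, with None mapping to unstructured. A minimal usage sketch, assuming the class is re-exported from compressed_tensors.config like the other config symbols:

    from compressed_tensors.config import SparsityStructure

    assert SparsityStructure("2:4") is SparsityStructure.TWO_FOUR            # exact value
    assert SparsityStructure("UNSTRUCTURED") is SparsityStructure.UNSTRUCTURED  # case-insensitive
    assert SparsityStructure(None) is SparsityStructure.UNSTRUCTURED         # None -> unstructured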
@@ -19,4 +19,3 @@ from .quant_args import *
 from .quant_config import *
 from .quant_scheme import *
 from .lifecycle import *
-from .cache import QuantizedKVParameterCache
@@ -15,9 +15,7 @@
 # flake8: noqa
 # isort: skip_file
 
-from .calibration import *
 from .forward import *
-from .frozen import *
 from .initialize import *
 from .compressed import *
 from .apply import *
@@ -22,13 +22,9 @@ from typing import Union
 
 import torch
 from compressed_tensors.config import CompressionFormat
-from compressed_tensors.quantization.lifecycle.calibration import (
-    set_module_for_calibration,
-)
 from compressed_tensors.quantization.lifecycle.compressed import (
     compress_quantized_weights,
 )
-from compressed_tensors.quantization.lifecycle.frozen import freeze_module_quantization
 from compressed_tensors.quantization.lifecycle.initialize import (
     initialize_module_for_quantization,
 )
@@ -233,6 +229,7 @@ def apply_quantization_status(model: Module, status: QuantizationStatus):
     :param model: model to apply quantization to
     :param status: status to update the module to
     """
+
     current_status = infer_quantization_status(model)
 
     if status >= QuantizationStatus.INITIALIZED > current_status:
@@ -243,18 +240,6 @@ def apply_quantization_status(model: Module, status: QuantizationStatus):
             )
         )
 
-    if current_status < status >= QuantizationStatus.CALIBRATION > current_status:
-        # only quantize weights up front when our end goal state is calibration,
-        # weight quantization parameters are already loaded for frozen/compressed
-        quantize_weights_upfront = status == QuantizationStatus.CALIBRATION
-        model.apply(
-            lambda module: set_module_for_calibration(
-                module, quantize_weights_upfront=quantize_weights_upfront
-            )
-        )
-    if current_status < status >= QuantizationStatus.FROZEN > current_status:
-        model.apply(freeze_module_quantization)
-
     if current_status < status >= QuantizationStatus.COMPRESSED > current_status:
         model.apply(compress_quantized_weights)
 
@@ -14,14 +14,9 @@
 
 from functools import wraps
 from math import ceil
-from typing import Callable, Optional
+from typing import Optional
 
 import torch
-from compressed_tensors.quantization.cache import QuantizedKVParameterCache
-from compressed_tensors.quantization.observers.helpers import (
-    calculate_range,
-    compute_dynamic_scales_and_zp,
-)
 from compressed_tensors.quantization.quant_args import (
     QuantizationArgs,
     QuantizationStrategy,
@@ -29,7 +24,11 @@ from compressed_tensors.quantization.quant_args import (
 )
 from compressed_tensors.quantization.quant_config import QuantizationStatus
 from compressed_tensors.quantization.quant_scheme import QuantizationScheme
-from compressed_tensors.utils import safe_permute, update_parameter_data
+from compressed_tensors.quantization.utils import (
+    calculate_range,
+    compute_dynamic_scales_and_zp,
+)
+from compressed_tensors.utils import safe_permute
 from torch.nn import Module
 
 
@@ -38,7 +37,7 @@ __all__ = [
     "dequantize",
     "fake_quantize",
     "wrap_module_forward_quantized",
-    "maybe_calibrate_or_quantize",
+    "forward_quantize",
 ]
 
 
@@ -275,15 +274,13 @@ def wrap_module_forward_quantized(module: Module, scheme: QuantizationScheme):
         compressed = module.quantization_status == QuantizationStatus.COMPRESSED
 
         if scheme.input_activations is not None:
-            # calibrate and (fake) quantize input activations when applicable
-            input_ = maybe_calibrate_or_quantize(
-                module, input_, "input", scheme.input_activations
-            )
+            # prehook should calibrate activations before forward call
+            input_ = forward_quantize(module, input_, "input", scheme.input_activations)
 
         if scheme.weights is not None and not compressed:
             # calibrate and (fake) quantize weights when applicable
             unquantized_weight = self.weight.data.clone()
-            self.weight.data = maybe_calibrate_or_quantize(
+            self.weight.data = forward_quantize(
                 module, self.weight, "weight", scheme.weights
             )
 
@@ -291,64 +288,23 @@ def wrap_module_forward_quantized(module: Module, scheme: QuantizationScheme):
         output = forward_func_orig.__get__(module, module.__class__)(
             input_, *args[1:], **kwargs
         )
-        if scheme.output_activations is not None:
-
-            # calibrate and (fake) quantize output activations when applicable
-            # kv_cache scales updated on model self_attn forward call in
-            # wrap_module_forward_quantized_attn
-            output = maybe_calibrate_or_quantize(
-                module, output, "output", scheme.output_activations
-            )
 
         # restore back to unquantized_value
         if scheme.weights is not None and not compressed:
             self.weight.data = unquantized_weight
 
-        return output
-
-    # bind wrapped forward to module class so reference to `self` is correct
-    bound_wrapped_forward = wrapped_forward.__get__(module, module.__class__)
-    # set forward to wrapped forward
-    setattr(module, "forward", bound_wrapped_forward)
-
-
-def wrap_module_forward_quantized_attn(module: Module, scheme: QuantizationScheme):
-    # expects a module already initialized and injected with the parameters in
-    # initialize_module_for_quantization
-    if hasattr(module.forward, "__func__"):
-        forward_func_orig = module.forward.__func__
-    else:
-        forward_func_orig = module.forward.func
-
-    @wraps(forward_func_orig)  # ensures docstring, names, etc are propagated
-    def wrapped_forward(self, *args, **kwargs):
-
-        # kv cache stored under weights
-        if module.quantization_status == QuantizationStatus.CALIBRATION:
-            quantization_args: QuantizationArgs = scheme.output_activations
-            past_key_value: QuantizedKVParameterCache = quantization_args.get_kv_cache()
-            kwargs["past_key_value"] = past_key_value
-
-            # QuantizedKVParameterCache used for obtaining k_scale, v_scale only,
-            # does not store quantized_key_states and quantized_value_state
-            kwargs["use_cache"] = False
-
-            attn_forward: Callable = forward_func_orig.__get__(module, module.__class__)
-
-            past_key_value.reset_states()
-
-            rtn = attn_forward(*args, **kwargs)
-
-            update_parameter_data(
-                module, past_key_value.k_scales[module.layer_idx], "k_scale"
-            )
-            update_parameter_data(
-                module, past_key_value.v_scales[module.layer_idx], "v_scale"
+        if scheme.output_activations is not None:
+            # forward-hook should calibrate/forward_quantize
+            if (
+                module.quantization_status == QuantizationStatus.CALIBRATION
+                and not scheme.output_activations.dynamic
+            ):
+                return output
+
+            output = forward_quantize(
+                module, output, "output", scheme.output_activations
             )
-
-            return rtn
-
-        return forward_func_orig.__get__(module, module.__class__)(*args, **kwargs)
+        return output
 
     # bind wrapped forward to module class so reference to `self` is correct
     bound_wrapped_forward = wrapped_forward.__get__(module, module.__class__)
@@ -356,12 +312,9 @@ def wrap_module_forward_quantized_attn(module: Module, scheme: QuantizationSchem
     setattr(module, "forward", bound_wrapped_forward)
 
 
-def maybe_calibrate_or_quantize(
+def forward_quantize(
     module: Module, value: torch.Tensor, base_name: str, args: "QuantizationArgs"
 ) -> torch.Tensor:
-    # don't run quantization if we haven't entered calibration mode
-    if module.quantization_status == QuantizationStatus.INITIALIZED:
-        return value
 
     # in compressed mode, the weight is already compressed and quantized so we don't
     # need to run fake quantization
@@ -379,29 +332,13 @@ def maybe_calibrate_or_quantize(
     g_idx = getattr(module, "weight_g_idx", None)
 
     if args.dynamic:
-        # dynamic quantization - no need to invoke observer
+        # dynamic quantization - determine the scale/zp on the fly
         scale, zero_point = compute_dynamic_scales_and_zp(value=value, args=args)
     else:
-        # static quantization - get previous scale and zero point from layer
+        # static quantization - get scale and zero point from layer
         scale = getattr(module, f"{base_name}_scale")
         zero_point = getattr(module, f"{base_name}_zero_point", None)
 
-    if (
-        module.quantization_status == QuantizationStatus.CALIBRATION
-        and base_name != "weight"
-    ):
-        # calibration mode - get new quant params from observer
-        observer = getattr(module, f"{base_name}_observer")
-
-        updated_scale, updated_zero_point = observer(value, g_idx=g_idx)
-
-        # update scale and zero point
-        update_parameter_data(module, updated_scale, f"{base_name}_scale")
-        update_parameter_data(module, updated_zero_point, f"{base_name}_zero_point")
-
-        scale = updated_scale
-        zero_point = updated_zero_point
-
     return fake_quantize(
         x=value,
         scale=scale,
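
The net effect of the forward.py changes: maybe_calibrate_or_quantize is renamed to forward_quantize and no longer runs observers or writes scales back to the module; scales and zero points now come either from compute_dynamic_scales_and_zp (dynamic) or from the module's existing parameters (static), with calibration expected to happen in externally registered hooks. Downstream callers only need the rename, roughly:

    # 0.7.1:
    # from compressed_tensors.quantization.lifecycle.forward import maybe_calibrate_or_quantize
    # 0.8.0:
    from compressed_tensors.quantization.lifecycle.forward import forward_quantize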
@@ -14,13 +14,12 @@
 
 
 import logging
+from enum import Enum
 from typing import Optional
 
 import torch
-from compressed_tensors.quantization.cache import KVCacheScaleType
 from compressed_tensors.quantization.lifecycle.forward import (
     wrap_module_forward_quantized,
-    wrap_module_forward_quantized_attn,
 )
 from compressed_tensors.quantization.quant_args import (
     ActivationOrdering,
@@ -36,12 +35,19 @@ from torch.nn import Module, Parameter
 
 __all__ = [
     "initialize_module_for_quantization",
+    "is_attention_module",
+    "KVCacheScaleType",
 ]
 
 
 _LOGGER = logging.getLogger(__name__)
 
 
+class KVCacheScaleType(Enum):
+    KEY = "k_scale"
+    VALUE = "v_scale"
+
+
 def initialize_module_for_quantization(
     module: Module,
     scheme: Optional[QuantizationScheme] = None,
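
KVCacheScaleType previously lived in compressed_tensors/quantization/cache.py, which this release removes; it is now defined and exported here. A sketch of the import-path change for downstream code:

    # 0.7.1:
    # from compressed_tensors.quantization.cache import KVCacheScaleType
    # 0.8.0:
    from compressed_tensors.quantization.lifecycle.initialize import KVCacheScaleType

    assert KVCacheScaleType.KEY.value == "k_scale"
    assert KVCacheScaleType.VALUE.value == "v_scale"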
@@ -66,15 +72,13 @@ def initialize_module_for_quantization(
         return
 
     if is_attention_module(module):
-        # wrap forward call of module to perform
         # quantized actions based on calltime status
-        wrap_module_forward_quantized_attn(module, scheme)
         _initialize_attn_scales(module)
 
     else:
 
         if scheme.input_activations is not None:
-            _initialize_scale_zero_point_observer(
+            _initialize_scale_zero_point(
                 module,
                 "input",
                 scheme.input_activations,
@@ -85,7 +89,7 @@ def initialize_module_for_quantization(
             weight_shape = None
             if isinstance(module, torch.nn.Linear):
                 weight_shape = module.weight.shape
-            _initialize_scale_zero_point_observer(
+            _initialize_scale_zero_point(
                 module,
                 "weight",
                 scheme.weights,
@@ -101,7 +105,7 @@ def initialize_module_for_quantization(
 
         if scheme.output_activations is not None:
             if not is_kv_cache_quant_scheme(scheme):
-                _initialize_scale_zero_point_observer(
+                _initialize_scale_zero_point(
                     module, "output", scheme.output_activations
                 )
 
@@ -109,6 +113,7 @@ def initialize_module_for_quantization(
     module.quantization_status = QuantizationStatus.INITIALIZED
 
     offloaded = False
+    # What is this doing/why isn't this in the attn case?
     if is_module_offloaded(module):
         try:
             from accelerate.hooks import add_hook_to_module, remove_hook_from_module
@@ -146,21 +151,21 @@ def initialize_module_for_quantization(
         module._hf_hook.weights_map = new_prefix_dict
 
 
-def _initialize_scale_zero_point_observer(
+def is_attention_module(module: Module):
+    return "attention" in module.__class__.__name__.lower() and (
+        hasattr(module, "k_proj")
+        or hasattr(module, "v_proj")
+        or hasattr(module, "qkv_proj")
+    )
+
+
+def _initialize_scale_zero_point(
     module: Module,
     base_name: str,
     quantization_args: QuantizationArgs,
     weight_shape: Optional[torch.Size] = None,
     force_zero_point: bool = True,
 ):
-
-    # initialize observer module and attach as submodule
-    observer = quantization_args.get_observer()
-    # no need to register an observer for dynamic quantization
-    if observer:
-        module.register_module(f"{base_name}_observer", observer)
-
-    # no need to register a scale and zero point for a dynamic quantization
     if quantization_args.dynamic:
         return
 
@@ -209,14 +214,6 @@ def _initialize_scale_zero_point_observer(
         module.register_parameter(f"{base_name}_g_idx", init_g_idx)
 
 
-def is_attention_module(module: Module):
-    return "attention" in module.__class__.__name__.lower() and (
-        hasattr(module, "k_proj")
-        or hasattr(module, "v_proj")
-        or hasattr(module, "qkv_proj")
-    )
-
-
 def _initialize_attn_scales(module: Module) -> None:
     """Initlaize k_scale, v_scale for self_attn"""
 
@@ -114,20 +114,7 @@ class QuantizationArgs(BaseModel, use_enum_values=True):
         """
         :return: torch quantization FakeQuantize built based on these QuantizationArgs
         """
-        from compressed_tensors.quantization.observers.base import Observer
-
-        # No observer required for the dynamic case
-        if self.dynamic:
-            self.observer = None
-            return self.observer
-
-        return Observer.load_from_registry(self.observer, quantization_args=self)
-
-    def get_kv_cache(self):
-        """Get the singleton KV Cache"""
-        from compressed_tensors.quantization.cache import QuantizedKVParameterCache
-
-        return QuantizedKVParameterCache(self)
+        return self.observer
 
     @field_validator("type", mode="before")
     def validate_type(cls, value) -> QuantizationType:
@@ -210,6 +197,7 @@ class QuantizationArgs(BaseModel, use_enum_values=True):
                 "activation ordering"
             )
 
+        # infer observer w.r.t. dynamic
         if dynamic:
             if strategy not in (
                 QuantizationStrategy.TOKEN,
@@ -221,18 +209,19 @@ class QuantizationArgs(BaseModel, use_enum_values=True):
                     "quantization",
                 )
             if observer is not None:
-                warnings.warn(
-                    "No observer is used for dynamic quantization, setting to None"
-                )
-                model.observer = None
+                if observer != "memoryless":  # avoid annoying users with old configs
+                    warnings.warn(
+                        "No observer is used for dynamic quantization, setting to None"
+                    )
+                observer = None
 
-        # if we have not set an observer and we
-        # are running static quantization, use minmax
-        if not observer and not dynamic:
-            model.observer = "minmax"
+        elif observer is None:
+            # default to minmax for non-dynamic cases
+            observer = "minmax"
 
         # write back modified values
         model.strategy = strategy
+        model.observer = observer
         return model
 
     def pytorch_dtype(self) -> torch.dtype:
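
The validator now resolves the observer field instead of constructing observer objects: dynamic quantization forces observer to None (warning only when the configured value is not the legacy "memoryless"), and static quantization defaults it to "minmax" when unset. A minimal sketch of the resulting defaults, assuming a token-strategy dynamic config and the default static constructor:

    from compressed_tensors.quantization import QuantizationArgs

    # dynamic: any configured observer is dropped ("memoryless" silently, others with a warning)
    dynamic_args = QuantizationArgs(num_bits=8, dynamic=True, strategy="token", observer="memoryless")
    assert dynamic_args.observer is None

    # static: observer defaults to "minmax" when left unset
    static_args = QuantizationArgs(num_bits=8)
    assert static_args.observer == "minmax"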