compressed-tensors 0.5.0__py3-none-any.whl → 0.7.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (51)
  1. compressed_tensors/__init__.py +1 -0
  2. compressed_tensors/base.py +2 -0
  3. compressed_tensors/compressors/__init__.py +6 -12
  4. compressed_tensors/compressors/base.py +137 -9
  5. compressed_tensors/compressors/helpers.py +6 -6
  6. compressed_tensors/compressors/model_compressors/__init__.py +17 -0
  7. compressed_tensors/compressors/{model_compressor.py → model_compressors/model_compressor.py} +99 -43
  8. compressed_tensors/compressors/quantized_compressors/__init__.py +18 -0
  9. compressed_tensors/compressors/{naive_quantized.py → quantized_compressors/base.py} +64 -62
  10. compressed_tensors/compressors/quantized_compressors/naive_quantized.py +140 -0
  11. compressed_tensors/compressors/quantized_compressors/pack_quantized.py +211 -0
  12. compressed_tensors/compressors/sparse_compressors/__init__.py +18 -0
  13. compressed_tensors/compressors/sparse_compressors/base.py +110 -0
  14. compressed_tensors/compressors/{dense.py → sparse_compressors/dense.py} +3 -3
  15. compressed_tensors/compressors/{sparse_bitmask.py → sparse_compressors/sparse_bitmask.py} +14 -59
  16. compressed_tensors/compressors/sparse_quantized_compressors/__init__.py +16 -0
  17. compressed_tensors/compressors/{marlin_24.py → sparse_quantized_compressors/marlin_24.py} +3 -3
  18. compressed_tensors/config/base.py +6 -1
  19. compressed_tensors/linear/__init__.py +13 -0
  20. compressed_tensors/linear/compressed_linear.py +87 -0
  21. compressed_tensors/quantization/__init__.py +1 -0
  22. compressed_tensors/quantization/cache.py +201 -0
  23. compressed_tensors/quantization/lifecycle/apply.py +63 -9
  24. compressed_tensors/quantization/lifecycle/calibration.py +7 -7
  25. compressed_tensors/quantization/lifecycle/compressed.py +3 -1
  26. compressed_tensors/quantization/lifecycle/forward.py +126 -44
  27. compressed_tensors/quantization/lifecycle/frozen.py +6 -1
  28. compressed_tensors/quantization/lifecycle/helpers.py +0 -20
  29. compressed_tensors/quantization/lifecycle/initialize.py +138 -55
  30. compressed_tensors/quantization/observers/__init__.py +1 -0
  31. compressed_tensors/quantization/observers/base.py +54 -14
  32. compressed_tensors/quantization/observers/min_max.py +8 -0
  33. compressed_tensors/quantization/observers/mse.py +162 -0
  34. compressed_tensors/quantization/quant_args.py +102 -24
  35. compressed_tensors/quantization/quant_config.py +14 -2
  36. compressed_tensors/quantization/quant_scheme.py +12 -13
  37. compressed_tensors/quantization/utils/helpers.py +44 -19
  38. compressed_tensors/utils/__init__.py +1 -0
  39. compressed_tensors/utils/helpers.py +30 -1
  40. compressed_tensors/utils/offload.py +14 -2
  41. compressed_tensors/utils/permute.py +70 -0
  42. compressed_tensors/utils/safetensors_load.py +2 -0
  43. compressed_tensors/utils/semi_structured_conversions.py +1 -0
  44. compressed_tensors/version.py +1 -1
  45. {compressed_tensors-0.5.0.dist-info → compressed_tensors-0.7.0.dist-info}/METADATA +35 -23
  46. compressed_tensors-0.7.0.dist-info/RECORD +59 -0
  47. {compressed_tensors-0.5.0.dist-info → compressed_tensors-0.7.0.dist-info}/WHEEL +1 -1
  48. compressed_tensors/compressors/pack_quantized.py +0 -219
  49. compressed_tensors-0.5.0.dist-info/RECORD +0 -48
  50. {compressed_tensors-0.5.0.dist-info → compressed_tensors-0.7.0.dist-info}/LICENSE +0 -0
  51. {compressed_tensors-0.5.0.dist-info → compressed_tensors-0.7.0.dist-info}/top_level.txt +0 -0
@@ -12,17 +12,15 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.

-import logging
-from typing import Dict, Generator, List, Tuple, Union
+from typing import Dict, List, Tuple, Union

 import numpy
 import torch
-from compressed_tensors.compressors import Compressor
+from compressed_tensors.compressors.base import BaseCompressor
+from compressed_tensors.compressors.sparse_compressors.base import BaseSparseCompressor
 from compressed_tensors.config import CompressionFormat
-from compressed_tensors.utils import get_nested_weight_mappings, merge_names
-from safetensors import safe_open
+from compressed_tensors.utils import merge_names
 from torch import Tensor
-from tqdm import tqdm


 __all__ = [
@@ -34,11 +32,9 @@ __all__ = [
     "unpack_bitmasks",
 ]

-_LOGGER: logging.Logger = logging.getLogger(__name__)

-
-@Compressor.register(name=CompressionFormat.sparse_bitmask.value)
-class BitmaskCompressor(Compressor):
+@BaseCompressor.register(name=CompressionFormat.sparse_bitmask.value)
+class BitmaskCompressor(BaseSparseCompressor):
     """
     Compression for sparse models using bitmasks. Non-zero weights are stored in a 1d
     values tensor, with their locations stored in a 2d bitmask
@@ -46,56 +42,15 @@ class BitmaskCompressor(Compressor):

     COMPRESSION_PARAM_NAMES = ["shape", "compressed", "bitmask", "row_offsets"]

-    def compress(self, model_state: Dict[str, Tensor]) -> Dict[str, Tensor]:
-        """
-        Compresses a dense state dict using bitmask compression
+    def compress_weight(self, name, value):
+        bitmask_tensor = BitmaskTensor.from_dense(value)
+        bitmask_dict = bitmask_tensor.dict(name_prefix=name, device="cpu")
+        return bitmask_dict

-        :param model_state: state dict of uncompressed model
-        :return: compressed state dict
-        """
-        compressed_dict = {}
-        _LOGGER.debug(
-            f"Compressing model with {len(model_state)} parameterized layers..."
-        )
-        for name, value in tqdm(model_state.items(), desc="Compressing model"):
-            bitmask_tensor = BitmaskTensor.from_dense(value)
-            bitmask_dict = bitmask_tensor.dict(name_prefix=name, device="cpu")
-            for key in bitmask_dict.keys():
-                if key in compressed_dict:
-                    _LOGGER.warn(
-                        f"Expected all compressed state_dict keys to be unique, but "
-                        f"found an existing entry for {key}. The existing entry will "
-                        "be replaced."
-                    )
-            compressed_dict.update(bitmask_dict)
-
-        return compressed_dict
-
-    def decompress(
-        self, path_to_model_or_tensors: str, device: str = "cpu", **kwargs
-    ) -> Generator[Tuple[str, Tensor], None, None]:
-        """
-        Reads a bitmask compressed state dict located
-        at path_to_model_or_tensors and returns a generator
-        for sequentially decompressing back to a dense state dict
-
-        :param model_path: path to compressed safetensors model (directory with
-            one or more safetensors files) or compressed tensors file
-        :param device: device to load decompressed weights onto
-        :return: iterator for generating decompressed weights
-        """
-        weight_mappings = get_nested_weight_mappings(
-            path_to_model_or_tensors, self.COMPRESSION_PARAM_NAMES
-        )
-        for weight_name in weight_mappings.keys():
-            weight_data = {}
-            for param_name, safe_path in weight_mappings[weight_name].items():
-                full_name = merge_names(weight_name, param_name)
-                with safe_open(safe_path, framework="pt", device=device) as f:
-                    weight_data[param_name] = f.get_tensor(full_name)
-            data = BitmaskTensor(**weight_data)
-            decompressed = data.decompress()
-            yield weight_name, decompressed
+    def decompress_weight(self, weight_data):
+        data = BitmaskTensor(**weight_data)
+        decompressed = data.decompress()
+        return decompressed


 class BitmaskTensor:
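
The per-weight hooks above are all BitmaskCompressor has to provide now; the state-dict iteration, duplicate-key logging, and safetensors loading that previously lived in compress()/decompress() moved into the shared BaseSparseCompressor added in this release (compressed_tensors/compressors/sparse_compressors/base.py). A minimal sketch of that division of labor (illustrative only, with a hypothetical class name, not the actual base-class code):

    from typing import Dict
    from torch import Tensor

    class SparseCompressorSketch:
        """Illustration: the base class owns the loop, subclasses own the per-weight work."""

        def compress_weight(self, name: str, value: Tensor) -> Dict[str, Tensor]:
            # provided by subclasses, e.g. BitmaskCompressor returns the
            # shape/compressed/bitmask/row_offsets parameters for one weight
            raise NotImplementedError

        def compress(self, model_state: Dict[str, Tensor]) -> Dict[str, Tensor]:
            compressed = {}
            for name, value in model_state.items():
                compressed.update(self.compress_weight(name, value))
            return compressed
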
@@ -0,0 +1,16 @@
+# Copyright (c) 2021 - present / Neuralmagic, Inc. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# flake8: noqa
+
+from .marlin_24 import Marlin24Compressor
@@ -17,7 +17,7 @@ from typing import Dict, Generator, Tuple

 import numpy as np
 import torch
-from compressed_tensors.compressors import Compressor
+from compressed_tensors.compressors.base import BaseCompressor
 from compressed_tensors.config import CompressionFormat
 from compressed_tensors.quantization import QuantizationArgs, QuantizationStrategy
 from compressed_tensors.quantization.lifecycle.forward import quantize
@@ -35,8 +35,8 @@ from tqdm import tqdm
 _LOGGER: logging.Logger = logging.getLogger(__name__)


-@Compressor.register(name=CompressionFormat.marlin_24.value)
-class Marlin24Compressor(Compressor):
+@BaseCompressor.register(name=CompressionFormat.marlin_24.value)
+class Marlin24Compressor(BaseCompressor):
     """
     Compresses a quantized model with 2:4 sparsity structure for inference with the
     Marlin24 kernel. Decompression is not implemented for this compressor.
@@ -13,7 +13,7 @@
 # limitations under the License.

 from enum import Enum
-from typing import Optional
+from typing import List, Optional

 from compressed_tensors.registry import RegistryMixin
 from pydantic import BaseModel
@@ -37,11 +37,16 @@ class SparsityCompressionConfig(RegistryMixin, BaseModel):
     Base data class for storing sparsity compression parameters

     :param format: name of compression format
+    :param targets: List of layer names or layer types that aren't sparse and should
+        be ignored during compression. By default, assume all layers are targeted
+    :param ignore: List of layer names (unique) to ignore from targets. Defaults to None
     :param global_sparsity: average sparsity of the entire model
     :param sparsity_structure: structure of the sparsity, such as
        "unstructured", "2:4", "8:16" etc
     """

     format: str
+    targets: Optional[List[str]] = None
+    ignore: Optional[List[str]] = None
     global_sparsity: Optional[float] = 0.0
     sparsity_structure: Optional[str] = "unstructured"
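
For illustration, a sketch of constructing a sparsity config with the new targets/ignore fields; the field names come from the class above, while the format string and layer names are assumed values, not taken from this diff:

    from compressed_tensors.config.base import SparsityCompressionConfig

    config = SparsityCompressionConfig(
        format="sparse-bitmask",      # assumed format name for illustration
        targets=["Linear"],           # layer types/names considered during compression
        ignore=["lm_head"],           # layer names excluded from the targets above
        global_sparsity=0.5,
        sparsity_structure="2:4",
    )
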
@@ -0,0 +1,13 @@
+# Copyright (c) 2021 - present / Neuralmagic, Inc. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
@@ -0,0 +1,87 @@
+# Copyright (c) 2021 - present / Neuralmagic, Inc. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import torch
+from compressed_tensors.compressors.base import BaseCompressor
+from compressed_tensors.quantization import (
+    QuantizationScheme,
+    QuantizationStatus,
+    initialize_module_for_quantization,
+)
+from torch import Tensor
+from torch.nn import Parameter
+from torch.nn.functional import linear
+from torch.nn.modules import Linear
+
+
+class CompressedLinear(Linear):
+    """
+    Wrapper module for running a compressed forward pass of a quantized Linear module.
+    The wrapped layer will decompressed on each forward call.
+
+    :param module: dense linear module to replace
+    :param quantization_scheme: quantization config for the module to wrap
+    :param quantization_format: compression format module is stored as
+    """
+
+    @classmethod
+    @torch.no_grad()
+    def from_linear(
+        cls,
+        module: Linear,
+        quantization_scheme: QuantizationScheme,
+        quantization_format: str,
+    ):
+        module.__class__ = CompressedLinear
+        module.compressor = BaseCompressor.load_from_registry(quantization_format)
+        device = next(module.parameters()).device
+
+        # this will initialize all the scales and zero points
+        initialize_module_for_quantization(
+            module, quantization_scheme, force_zero_point=False
+        )
+
+        # get the shape and dtype of compressed parameters
+        compression_params = module.compressor.compression_param_info(
+            module.weight.shape, quantization_scheme.weights
+        )
+
+        # no need for this once quantization is initialized, will be replaced
+        # with the compressed parameter
+        delattr(module, "weight")
+
+        # populate compressed weights and quantization parameters
+        for name, (shape, dtype) in compression_params.items():
+            param = Parameter(
+                torch.empty(shape, device=device, dtype=dtype), requires_grad=False
+            )
+            module.register_parameter(name, param)
+
+        # mark module as compressed
+        module.quantization_status = QuantizationStatus.COMPRESSED
+
+        # handles case where forward is wrapped in new_forward by accelerate hooks
+        if hasattr(module, "_old_forward"):
+            module._old_forward = CompressedLinear.forward.__get__(
+                module, CompressedLinear
+            )
+
+        return module
+
+    def forward(self, input: Tensor) -> Tensor:
+        """
+        Decompresses the weight, then runs the wrapped forward pass
+        """
+        uncompressed_weight = self.compressor.decompress_module(self)
+        return linear(input, uncompressed_weight, self.bias)
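
A hedged usage sketch for the new module; only CompressedLinear.from_linear and its keyword names come from the code above, while the scheme and format values are assumptions chosen for illustration:

    import torch
    from torch.nn import Linear
    from compressed_tensors.linear.compressed_linear import CompressedLinear
    from compressed_tensors.quantization import QuantizationArgs, QuantizationScheme

    # Assumed scheme: 8-bit symmetric weight quantization targeting Linear layers.
    scheme = QuantizationScheme(targets=["Linear"], weights=QuantizationArgs(num_bits=8))

    layer = Linear(256, 256)
    compressed = CompressedLinear.from_linear(
        layer,
        quantization_scheme=scheme,
        quantization_format="int-quantized",  # assumed CompressionFormat value
    )
    out = compressed(torch.randn(1, 256))  # decompresses the weight, then calls F.linear
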
@@ -19,3 +19,4 @@ from .quant_args import *
 from .quant_config import *
 from .quant_scheme import *
 from .lifecycle import *
+from .cache import QuantizedKVParameterCache
@@ -0,0 +1,201 @@
+# Copyright (c) 2021 - present / Neuralmagic, Inc. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+
+from enum import Enum
+from typing import Any, Dict, List, Optional, Tuple
+
+from compressed_tensors.quantization.observers import Observer
+from compressed_tensors.quantization.quant_args import QuantizationArgs
+from torch import Tensor
+from transformers import DynamicCache as HFDyanmicCache
+
+
+class KVCacheScaleType(Enum):
+    KEY = "k_scale"
+    VALUE = "v_scale"
+
+
+class QuantizedKVParameterCache(HFDyanmicCache):
+
+    """
+    Quantized KV cache used in the forward call based on HF's dynamic cache.
+    Quantization strategy (tensor, group, channel) set from Quantization arg's strategy
+    Singleton, so that the same cache gets reused in all forward call of self_attn.
+    Each time forward is called, .update() is called, and ._quantize(), ._dequantize()
+    gets called appropriately.
+    The size of tensor is
+    `[batch_size, num_heads, seq_len - residual_length, head_dim]`.
+
+
+    Triggered by adding kv_cache_scheme in the recipe.
+
+    Example:
+
+    ```python3
+    recipe = '''
+    quant_stage:
+        quant_modifiers:
+            QuantizationModifier:
+                kv_cache_scheme:
+                    num_bits: 8
+                    type: float
+                    strategy: tensor
+                    dynamic: false
+                    symmetric: true
+    '''
+
+    """
+
+    _instance = None
+    _initialized = False
+
+    def __new__(cls, *args, **kwargs):
+        """Singleton"""
+        if cls._instance is None:
+            cls._instance = super(QuantizedKVParameterCache, cls).__new__(cls)
+        return cls._instance
+
+    def __init__(self, quantization_args: QuantizationArgs):
+        if not self._initialized:
+            super().__init__()
+
+            self.quantization_args = quantization_args
+
+            self.k_observers: List[Observer] = []
+            self.v_observers: List[Observer] = []
+
+            # each index corresponds to layer_idx of the attention layer
+            self.k_scales: List[Tensor] = []
+            self.v_scales: List[Tensor] = []
+
+            self.k_zps: List[Tensor] = []
+            self.v_zps: List[Tensor] = []
+
+            self._initialized = True
+
+    def update(
+        self,
+        key_states: Tensor,
+        value_states: Tensor,
+        layer_idx: int,
+        cache_kwargs: Optional[Dict[str, Any]] = None,
+    ) -> Tuple[Tensor, Tensor]:
+        """
+        Get the k_scale and v_scale and output the
+        fakequant-ed key_states and value_states
+        """
+
+        if len(self.k_observers) <= layer_idx:
+            k_observer = self.quantization_args.get_observer()
+            v_observer = self.quantization_args.get_observer()
+
+            self.k_observers.append(k_observer)
+            self.v_observers.append(v_observer)
+
+        q_key_states = self._quantize(
+            key_states.contiguous(), KVCacheScaleType.KEY, layer_idx
+        )
+        q_value_states = self._quantize(
+            value_states.contiguous(), KVCacheScaleType.VALUE, layer_idx
+        )
+
+        qdq_key_states = self._dequantize(q_key_states, KVCacheScaleType.KEY, layer_idx)
+        qdq_value_states = self._dequantize(
+            q_value_states, KVCacheScaleType.VALUE, layer_idx
+        )
+
+        keys_to_return, values_to_return = qdq_key_states, qdq_value_states
+
+        return keys_to_return, values_to_return
+
+    def get_seq_length(self, layer_idx: Optional[int] = 0) -> int:
+        """
+        Returns the sequence length of the cached states.
+        A layer index can be optionally passed.
+        """
+        if len(self.key_cache) <= layer_idx:
+            return 0
+        # since we cannot get the seq_length of each layer directly and
+        # rely on `_seen_tokens` which is updated every "layer_idx" == 0,
+        # this is a hack to get the actual seq_length for the given layer_idx
+        # this part of code otherwise fails when used to
+        # verify attn_weight shape in some models
+        return self._seen_tokens if layer_idx == 0 else self._seen_tokens - 1
+
+    def reset_states(self):
+        """reset the kv states (used in calibration)"""
+        self.key_cache: List[Tensor] = []
+        self.value_cache: List[Tensor] = []
+        # Used in `generate` to keep tally of how many tokens the cache has seen
+        self._seen_tokens = 0
+        self._quantized_key_cache: List[Tensor] = []
+        self._quantized_value_cache: List[Tensor] = []
+
+    def reset(self):
+        """
+        Reset the instantiation, create new instance on init
+        """
+        QuantizedKVParameterCache._instance = None
+        QuantizedKVParameterCache._initialized = False
+
+    def _quantize(self, tensor, kv_type, layer_idx):
+        """Quantizes a key/value using a defined quantization method."""
+        from compressed_tensors.quantization.lifecycle.forward import quantize
+
+        if kv_type == KVCacheScaleType.KEY:  # key type
+            observer = self.k_observers[layer_idx]
+            scales = self.k_scales
+            zps = self.k_zps
+        else:
+            assert kv_type == KVCacheScaleType.VALUE
+            observer = self.v_observers[layer_idx]
+            scales = self.v_scales
+            zps = self.v_zps
+
+        scale, zp = observer(tensor)
+        if len(scales) <= layer_idx:
+            scales.append(scale)
+            zps.append(zp)
+        else:
+            scales[layer_idx] = scale
+            zps[layer_idx] = scale
+
+        q_tensor = quantize(
+            x=tensor,
+            scale=scale,
+            zero_point=zp,
+            args=self.quantization_args,
+        )
+        return q_tensor
+
+    def _dequantize(self, qtensor, kv_type, layer_idx):
+        """Dequantizes back the tensor that was quantized by `self._quantize()`"""
+        from compressed_tensors.quantization.lifecycle.forward import dequantize
+
+        if kv_type == KVCacheScaleType.KEY:
+            scale = self.k_scales[layer_idx]
+            zp = self.k_zps[layer_idx]
+        else:
+            assert kv_type == KVCacheScaleType.VALUE
+            scale = self.v_scales[layer_idx]
+            zp = self.v_zps[layer_idx]
+
+        qdq_tensor = dequantize(
+            x_q=qtensor,
+            scale=scale,
+            zero_point=zp,
+            args=self.quantization_args,
+        )
+        return qdq_tensor
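
A short sketch of the singleton behavior described in the docstring above; the QuantizationArgs values simply mirror the kv_cache_scheme recipe shown there:

    from compressed_tensors.quantization import QuantizationArgs
    from compressed_tensors.quantization.cache import QuantizedKVParameterCache

    args = QuantizationArgs(num_bits=8, type="float", strategy="tensor", symmetric=True)

    cache_a = QuantizedKVParameterCache(args)
    cache_b = QuantizedKVParameterCache(args)
    assert cache_a is cache_b  # __new__ always hands back the shared instance

    cache_a.reset()  # clears the singleton so the next construction starts fresh
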
@@ -14,12 +14,14 @@

 import logging
 import re
-from collections import OrderedDict
+from collections import OrderedDict, defaultdict
+from copy import deepcopy
 from typing import Dict, Iterable, List, Optional
 from typing import OrderedDict as OrderedDictType
 from typing import Union

 import torch
+from compressed_tensors.config import CompressionFormat
 from compressed_tensors.quantization.lifecycle.calibration import (
     set_module_for_calibration,
 )
@@ -41,8 +43,9 @@ from compressed_tensors.quantization.utils import (
     infer_quantization_status,
     is_kv_cache_quant_scheme,
     iter_named_leaf_modules,
+    iter_named_quantizable_modules,
 )
-from compressed_tensors.utils.helpers import fix_fsdp_module_name
+from compressed_tensors.utils.helpers import fix_fsdp_module_name, replace_module
 from compressed_tensors.utils.offload import update_parameter_data
 from compressed_tensors.utils.safetensors_load import get_safetensors_folder
 from torch.nn import Module
@@ -103,13 +106,25 @@ def load_pretrained_quantization(model: Module, model_name_or_path: str):
            )


-def apply_quantization_config(model: Module, config: QuantizationConfig) -> Dict:
+def apply_quantization_config(
+    model: Module, config: Union[QuantizationConfig, None], run_compressed: bool = False
+) -> OrderedDict:
     """
     Initializes the model for quantization in-place based on the given config

     :param model: model to apply quantization config to
     :param config: quantization config
+    :param run_compressed: Whether the model will be run in compressed mode or
+        decompressed fully on load
     """
+    # Workaround for when HF Quantizer passes None, see PR #180
+    if config is None:
+        return OrderedDict()
+
+    # remove reference to the original `config`
+    # argument. This function can mutate it, and we'd
+    # like to keep the original `config` as it is.
+    config = deepcopy(config)
     # build mapping of targets to schemes for easier matching
     # use ordered dict to preserve target ordering in config
     target_to_scheme = OrderedDict()
@@ -119,21 +134,47 @@ def apply_quantization_config(model: Module, config: QuantizationConfig) -> Dict
         for target in scheme.targets:
             target_to_scheme[target] = scheme

+    if run_compressed:
+        from compressed_tensors.linear.compressed_linear import CompressedLinear
+
     # list of submodules to ignore
-    ignored_submodules = []
+    ignored_submodules = defaultdict(list)
     # mark appropriate layers for quantization by setting their quantization schemes
-    for name, submodule in iter_named_leaf_modules(model):
+    for name, submodule in iter_named_quantizable_modules(
+        model,
+        include_children=True,
+        include_attn=True,
+    ):  # child modules and attention modules
         # potentially fix module name to remove FSDP wrapper prefix
         name = fix_fsdp_module_name(name)
-        if find_name_or_class_matches(name, submodule, config.ignore):
-            ignored_submodules.append(name)
+        if matches := find_name_or_class_matches(name, submodule, config.ignore):
+            for match in matches:
+                ignored_submodules[match].append(name)
            continue  # layer matches ignore list, continue
+
         targets = find_name_or_class_matches(name, submodule, target_to_scheme)
+
         if targets:
+            # mark modules to be quantized by adding
+            # quant scheme to the matching layers
+            scheme = _scheme_from_targets(target_to_scheme, targets, name)
+            if run_compressed:
+                format = config.format
+                if format != CompressionFormat.dense.value:
+                    if isinstance(submodule, torch.nn.Linear):
+                        # TODO: expand to more module types
+                        compressed_linear = CompressedLinear.from_linear(
+                            submodule,
+                            quantization_scheme=scheme,
+                            quantization_format=format,
+                        )
+                        replace_module(model, name, compressed_linear)
+
             # target matched - add layer and scheme to target list
             submodule.quantization_scheme = _scheme_from_targets(
                 target_to_scheme, targets, name
             )
+
             names_to_scheme[name] = submodule.quantization_scheme.weights

     if config.ignore is not None and ignored_submodules is not None:
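
A hedged sketch of calling the extended entry point once the full function (continued in the next hunk) is in place; the model and the minimal config below are assumptions for illustration, and only the run_compressed keyword comes from the signature above:

    from compressed_tensors.quantization import (
        QuantizationArgs,
        QuantizationConfig,
        QuantizationScheme,
        apply_quantization_config,
    )
    from transformers import AutoModelForCausalLM

    model = AutoModelForCausalLM.from_pretrained("facebook/opt-125m")  # any small HF model

    # Assumed minimal config: 8-bit weights for every Linear layer except the head.
    config = QuantizationConfig(
        config_groups={
            "group_0": QuantizationScheme(
                targets=["Linear"], weights=QuantizationArgs(num_bits=8)
            )
        },
        ignore=["lm_head"],
        format="int-quantized",
    )

    # run_compressed=True swaps matching torch.nn.Linear modules for CompressedLinear.
    names_to_scheme = apply_quantization_config(model, config, run_compressed=True)
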
@@ -143,8 +184,8 @@ def apply_quantization_config(model: Module, config: QuantizationConfig) -> Dict
                "not found in the model: "
                f"{set(config.ignore) - set(ignored_submodules)}"
            )
-    # apply current quantization status across all targeted layers

+    # apply current quantization status across all targeted layers
     apply_quantization_status(model, config.quantization_status)
     return names_to_scheme

@@ -172,6 +213,9 @@ def process_kv_cache_config(
     :param config: the QuantizationConfig
     :return: the QuantizationConfig with additional "kv_cache" group
     """
+    if targets == KV_CACHE_TARGETS:
+        _LOGGER.info(f"KV cache targets set to default value of: {KV_CACHE_TARGETS}")
+
     kv_cache_dict = config.kv_cache_scheme.model_dump()
     kv_cache_scheme = QuantizationScheme(
         output_activations=QuantizationArgs(**kv_cache_dict),
@@ -192,7 +236,12 @@ def apply_quantization_status(model: Module, status: QuantizationStatus):
     current_status = infer_quantization_status(model)

     if status >= QuantizationStatus.INITIALIZED > current_status:
-        model.apply(initialize_module_for_quantization)
+        force_zero_point_init = status != QuantizationStatus.COMPRESSED
+        model.apply(
+            lambda module: initialize_module_for_quantization(
+                module, force_zero_point=force_zero_point_init
+            )
+        )

     if current_status < status >= QuantizationStatus.CALIBRATION > current_status:
         # only quantize weights up front when our end goal state is calibration,
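
The lambda is needed because nn.Module.apply only passes the module itself to the callback; a generic sketch of the idiom (the init function here is a stand-in, not library code):

    import torch.nn as nn

    def init_module(module: nn.Module, force_zero_point: bool = True) -> None:
        ...  # stand-in for initialize_module_for_quantization

    model = nn.Sequential(nn.Linear(4, 4), nn.ReLU())

    # apply(fn) calls fn(submodule) with no extra arguments, so additional keyword
    # arguments have to be captured in a closure such as this lambda.
    model.apply(lambda module: init_module(module, force_zero_point=False))
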
@@ -273,9 +322,11 @@ def _load_quant_args_from_state_dict(
     """
     scale_name = f"{base_name}_scale"
     zp_name = f"{base_name}_zero_point"
+    g_idx_name = f"{base_name}_g_idx"

     state_dict_scale = state_dict.get(f"{module_name}.{scale_name}", None)
     state_dict_zp = state_dict.get(f"{module_name}.{zp_name}", None)
+    state_dict_g_idx = state_dict.get(f"{module_name}.{g_idx_name}", None)

     if state_dict_scale is not None:
         # module is quantized
@@ -285,6 +336,9 @@
             state_dict_zp = torch.zeros_like(state_dict_scale, device="cpu")
         update_parameter_data(module, state_dict_zp, zp_name)

+    if state_dict_g_idx is not None:
+        update_parameter_data(module, state_dict_g_idx, g_idx_name)
+

 def _scheme_from_targets(
@@ -36,15 +36,15 @@ def set_module_for_calibration(module: Module, quantize_weights_upfront: bool =
     apply to full model with `model.apply(set_module_for_calibration)`

     :param module: module to set for calibration
-    :param quantize_weights_upfront: whether to automatically run weight quantization at the
-        start of calibration
+    :param quantize_weights_upfront: whether to automatically
+        run weight quantization at the start of calibration
     """
     if not getattr(module, "quantization_scheme", None):
         # no quantization scheme nothing to do
         return
     status = getattr(module, "quantization_status", None)
     if not status or status != QuantizationStatus.INITIALIZED:
-        raise _LOGGER.warning(
+        _LOGGER.warning(
             f"Attempting set module with status {status} to calibration mode. "
             f"but status is not {QuantizationStatus.INITIALIZED} - you may "
             "be calibrating an uninitialized module which may fail or attempting "
@@ -54,13 +54,13 @@ def set_module_for_calibration(module: Module, quantize_weights_upfront: bool =
     if quantize_weights_upfront and module.quantization_scheme.weights is not None:
         # set weight scale and zero_point up front, calibration data doesn't affect it
         observer = module.weight_observer
+        g_idx = getattr(module, "weight_g_idx", None)

-        offloaded = False
-        if is_module_offloaded(module):
+        offloaded = is_module_offloaded(module)
+        if offloaded:
             module._hf_hook.pre_forward(module)
-            offloaded = True

-        scale, zero_point = observer(module.weight)
+        scale, zero_point = observer(module.weight, g_idx=g_idx)
         update_parameter_data(module, scale, "weight_scale")
         update_parameter_data(module, zero_point, "weight_zero_point")
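
As the docstring earlier in this file notes, the calibration hook is usually applied model-wide with model.apply; a minimal sketch (the plain nn.Linear below has no quantization_scheme, so the hook simply returns early for it):

    import torch.nn as nn
    from compressed_tensors.quantization.lifecycle.calibration import (
        set_module_for_calibration,
    )

    model = nn.Sequential(nn.Linear(8, 8))  # stand-in; normally an initialized quantized model

    # Modules with a quantization_scheme and INITIALIZED status get their weight
    # scale/zero_point computed up front by their weight observer; others are skipped.
    model.apply(set_module_for_calibration)
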