compressed-tensors 0.3.3__py3-none-any.whl → 0.5.0__py3-none-any.whl

This diff compares the contents of two publicly released versions of the package as published to their public registry. It is provided for informational purposes only.
Files changed (42)
  1. compressed_tensors/base.py +3 -1
  2. compressed_tensors/compressors/__init__.py +9 -1
  3. compressed_tensors/compressors/base.py +12 -55
  4. compressed_tensors/compressors/dense.py +5 -5
  5. compressed_tensors/compressors/helpers.py +12 -12
  6. compressed_tensors/compressors/marlin_24.py +251 -0
  7. compressed_tensors/compressors/model_compressor.py +336 -0
  8. compressed_tensors/compressors/naive_quantized.py +144 -0
  9. compressed_tensors/compressors/pack_quantized.py +219 -0
  10. compressed_tensors/compressors/sparse_bitmask.py +4 -4
  11. compressed_tensors/config/base.py +9 -4
  12. compressed_tensors/config/dense.py +4 -4
  13. compressed_tensors/config/sparse_bitmask.py +3 -3
  14. compressed_tensors/quantization/lifecycle/__init__.py +2 -0
  15. compressed_tensors/quantization/lifecycle/apply.py +204 -31
  16. compressed_tensors/quantization/lifecycle/calibration.py +20 -1
  17. compressed_tensors/quantization/lifecycle/compressed.py +69 -0
  18. compressed_tensors/quantization/lifecycle/forward.py +214 -62
  19. compressed_tensors/quantization/lifecycle/frozen.py +4 -0
  20. compressed_tensors/quantization/lifecycle/helpers.py +53 -0
  21. compressed_tensors/quantization/lifecycle/initialize.py +62 -5
  22. compressed_tensors/quantization/observers/base.py +66 -23
  23. compressed_tensors/quantization/observers/helpers.py +69 -11
  24. compressed_tensors/quantization/observers/memoryless.py +17 -9
  25. compressed_tensors/quantization/observers/min_max.py +44 -13
  26. compressed_tensors/quantization/quant_args.py +47 -3
  27. compressed_tensors/quantization/quant_config.py +104 -23
  28. compressed_tensors/quantization/quant_scheme.py +183 -2
  29. compressed_tensors/quantization/utils/helpers.py +142 -8
  30. compressed_tensors/utils/__init__.py +4 -0
  31. compressed_tensors/utils/helpers.py +54 -7
  32. compressed_tensors/utils/offload.py +104 -0
  33. compressed_tensors/utils/permutations_24.py +65 -0
  34. compressed_tensors/utils/safetensors_load.py +3 -2
  35. compressed_tensors/utils/semi_structured_conversions.py +341 -0
  36. compressed_tensors/version.py +53 -0
  37. {compressed_tensors-0.3.3.dist-info → compressed_tensors-0.5.0.dist-info}/METADATA +47 -8
  38. compressed_tensors-0.5.0.dist-info/RECORD +48 -0
  39. {compressed_tensors-0.3.3.dist-info → compressed_tensors-0.5.0.dist-info}/WHEEL +1 -1
  40. compressed_tensors-0.3.3.dist-info/RECORD +0 -38
  41. {compressed_tensors-0.3.3.dist-info → compressed_tensors-0.5.0.dist-info}/LICENSE +0 -0
  42. {compressed_tensors-0.3.3.dist-info → compressed_tensors-0.5.0.dist-info}/top_level.txt +0 -0

compressed_tensors/compressors/sparse_bitmask.py
@@ -17,7 +17,7 @@ from typing import Dict, Generator, List, Tuple, Union
 
 import numpy
 import torch
-from compressed_tensors.compressors import ModelCompressor
+from compressed_tensors.compressors import Compressor
 from compressed_tensors.config import CompressionFormat
 from compressed_tensors.utils import get_nested_weight_mappings, merge_names
 from safetensors import safe_open
@@ -37,8 +37,8 @@ __all__ = [
 _LOGGER: logging.Logger = logging.getLogger(__name__)
 
 
-@ModelCompressor.register(name=CompressionFormat.sparse_bitmask.value)
-class BitmaskCompressor(ModelCompressor):
+@Compressor.register(name=CompressionFormat.sparse_bitmask.value)
+class BitmaskCompressor(Compressor):
     """
     Compression for sparse models using bitmasks. Non-zero weights are stored in a 1d
     values tensor, with their locations stored in a 2d bitmask
@@ -72,7 +72,7 @@ class BitmaskCompressor(ModelCompressor):
         return compressed_dict
 
     def decompress(
-        self, path_to_model_or_tensors: str, device: str = "cpu"
+        self, path_to_model_or_tensors: str, device: str = "cpu", **kwargs
    ) -> Generator[Tuple[str, Tensor], None, None]:
         """
         Reads a bitmask compressed state dict located
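
For context, a minimal usage sketch of the updated decompress() signature (it assumes BitmaskCompressor is exported from compressed_tensors.compressors and can be constructed without arguments, and that "path/to/model" points at a sparse-bitmask checkpoint; none of that is shown in this diff):

    from compressed_tensors.compressors import BitmaskCompressor

    compressor = BitmaskCompressor()  # assumption: default construction is allowed
    state_dict = {}
    # decompress() yields (parameter name, densified tensor) pairs
    for name, tensor in compressor.decompress("path/to/model", device="cpu"):
        state_dict[name] = tensor
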

compressed_tensors/config/base.py
@@ -19,17 +19,22 @@ from compressed_tensors.registry import RegistryMixin
 from pydantic import BaseModel
 
 
-__all__ = ["CompressionConfig", "CompressionFormat"]
+__all__ = ["SparsityCompressionConfig", "CompressionFormat"]
 
 
 class CompressionFormat(Enum):
-    dense_sparsity = "dense-sparsity"
+    dense = "dense"
     sparse_bitmask = "sparse-bitmask"
+    int_quantized = "int-quantized"
+    float_quantized = "float-quantized"
+    naive_quantized = "naive-quantized"
+    pack_quantized = "pack-quantized"
+    marlin_24 = "marlin-24"
 
 
-class CompressionConfig(RegistryMixin, BaseModel):
+class SparsityCompressionConfig(RegistryMixin, BaseModel):
     """
-    Base data class for storing compression parameters
+    Base data class for storing sparsity compression parameters
 
     :param format: name of compression format
     :param global_sparsity: average sparsity of the entire model
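
The enum above now carries identifiers for the quantized formats as well; a quick sketch of reading the new string values (taken directly from the hunk):

    from compressed_tensors.config import CompressionFormat

    print(CompressionFormat.pack_quantized.value)  # "pack-quantized"
    print(CompressionFormat.marlin_24.value)       # "marlin-24"
    print(CompressionFormat.dense.value)           # "dense" (was "dense-sparsity")
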

compressed_tensors/config/dense.py
@@ -14,14 +14,14 @@
 
 from typing import Optional
 
-from compressed_tensors.config import CompressionConfig, CompressionFormat
+from compressed_tensors.config import CompressionFormat, SparsityCompressionConfig
 
 
 __all__ = ["DenseSparsityConfig"]
 
 
-@CompressionConfig.register(name=CompressionFormat.dense_sparsity.value)
-class DenseSparsityConfig(CompressionConfig):
+@SparsityCompressionConfig.register(name=CompressionFormat.dense.value)
+class DenseSparsityConfig(SparsityCompressionConfig):
     """
     Identity configuration for storing a sparse model in
     an uncompressed dense format
@@ -31,6 +31,6 @@ class DenseSparsityConfig(CompressionConfig):
         "unstructured", "2:4", "8:16" etc
     """
 
-    format: str = CompressionFormat.dense_sparsity.value
+    format: str = CompressionFormat.dense.value
     global_sparsity: Optional[float] = 0.0
     sparsity_structure: Optional[str] = "unstructured"
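
A short sketch of constructing the renamed config with the fields shown above (the field values are illustrative only):

    from compressed_tensors.config import CompressionFormat, DenseSparsityConfig

    config = DenseSparsityConfig(global_sparsity=0.5, sparsity_structure="2:4")
    assert config.format == CompressionFormat.dense.value  # defaults to the new "dense" name
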

compressed_tensors/config/sparse_bitmask.py
@@ -14,14 +14,14 @@
 
 from typing import Optional
 
-from compressed_tensors.config import CompressionConfig, CompressionFormat
+from compressed_tensors.config import CompressionFormat, SparsityCompressionConfig
 
 
 __all__ = ["BitmaskConfig"]
 
 
-@CompressionConfig.register(name=CompressionFormat.sparse_bitmask.value)
-class BitmaskConfig(CompressionConfig):
+@SparsityCompressionConfig.register(name=CompressionFormat.sparse_bitmask.value)
+class BitmaskConfig(SparsityCompressionConfig):
     """
     Configuration for storing a sparse model using
     bitmask compression
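
Since the registration decorator now hangs off SparsityCompressionConfig, a registry lookup might look like the sketch below (it assumes RegistryMixin exposes a load_from_registry classmethod, which is not shown in this diff):

    from compressed_tensors.config import CompressionFormat, SparsityCompressionConfig

    # assumption: RegistryMixin provides load_from_registry(name, **constructor_kwargs)
    config = SparsityCompressionConfig.load_from_registry(
        CompressionFormat.sparse_bitmask.value, global_sparsity=0.5
    )
    print(type(config).__name__)  # expected: BitmaskConfig
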

compressed_tensors/quantization/lifecycle/__init__.py
@@ -19,4 +19,6 @@ from .calibration import *
 from .forward import *
 from .frozen import *
 from .initialize import *
+from .compressed import *
 from .apply import *
+from .helpers import *

compressed_tensors/quantization/lifecycle/apply.py
@@ -12,22 +12,38 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
+import logging
 import re
 from collections import OrderedDict
-from typing import Dict, Iterable, Optional
+from typing import Dict, Iterable, List, Optional
+from typing import OrderedDict as OrderedDictType
+from typing import Union
 
+import torch
 from compressed_tensors.quantization.lifecycle.calibration import (
     set_module_for_calibration,
 )
+from compressed_tensors.quantization.lifecycle.compressed import (
+    compress_quantized_weights,
+)
 from compressed_tensors.quantization.lifecycle.frozen import freeze_module_quantization
 from compressed_tensors.quantization.lifecycle.initialize import (
     initialize_module_for_quantization,
 )
+from compressed_tensors.quantization.quant_args import QuantizationArgs
 from compressed_tensors.quantization.quant_config import (
     QuantizationConfig,
     QuantizationStatus,
 )
-from compressed_tensors.quantization.utils import iter_named_leaf_modules
+from compressed_tensors.quantization.quant_scheme import QuantizationScheme
+from compressed_tensors.quantization.utils import (
+    KV_CACHE_TARGETS,
+    infer_quantization_status,
+    is_kv_cache_quant_scheme,
+    iter_named_leaf_modules,
+)
+from compressed_tensors.utils.helpers import fix_fsdp_module_name
+from compressed_tensors.utils.offload import update_parameter_data
 from compressed_tensors.utils.safetensors_load import get_safetensors_folder
 from torch.nn import Module
 
@@ -36,13 +52,16 @@ __all__ = [
     "load_pretrained_quantization",
     "apply_quantization_config",
     "apply_quantization_status",
-    "find_first_name_or_class_match",
+    "find_name_or_class_matches",
 ]
 
 from compressed_tensors.quantization.utils.helpers import is_module_quantized
 from compressed_tensors.utils.safetensors_load import get_quantization_state_dict
 
 
+_LOGGER = logging.getLogger(__name__)
+
+
 def load_pretrained_quantization(model: Module, model_name_or_path: str):
     """
     Loads the quantization parameters (scale and zero point) from model_name_or_path to
@@ -84,7 +103,7 @@ def load_pretrained_quantization(model: Module, model_name_or_path: str):
            )
 
 
-def apply_quantization_config(model: Module, config: QuantizationConfig):
+def apply_quantization_config(model: Module, config: QuantizationConfig) -> Dict:
     """
     Initializes the model for quantization in-place based on the given config
 
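
A hedged, self-contained sketch of the new return value (the config below, 8-bit weights targeting Linear layers, is an illustrative assumption and not taken from this diff):

    import torch.nn as nn
    from compressed_tensors.quantization import (
        QuantizationArgs,
        QuantizationConfig,
        QuantizationScheme,
        apply_quantization_config,
    )

    # hypothetical minimal config: 8-bit weight quantization for every Linear layer
    config = QuantizationConfig(
        config_groups={
            "group_0": QuantizationScheme(
                targets=["Linear"], weights=QuantizationArgs(num_bits=8)
            )
        },
    )
    model = nn.Sequential(nn.Linear(16, 16), nn.ReLU(), nn.Linear(16, 4))

    # apply_quantization_config now returns a mapping of quantized submodule
    # names to the QuantizationArgs used for their weights
    names_to_scheme = apply_quantization_config(model, config)
    print(list(names_to_scheme.keys()))
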
@@ -94,21 +113,73 @@ def apply_quantization_config(model: Module, config: QuantizationConfig):
     # build mapping of targets to schemes for easier matching
     # use ordered dict to preserve target ordering in config
     target_to_scheme = OrderedDict()
+    config = process_quantization_config(config)
+    names_to_scheme = OrderedDict()
     for scheme in config.config_groups.values():
         for target in scheme.targets:
             target_to_scheme[target] = scheme
 
+    # list of submodules to ignore
+    ignored_submodules = []
     # mark appropriate layers for quantization by setting their quantization schemes
     for name, submodule in iter_named_leaf_modules(model):
-        if find_first_name_or_class_match(name, submodule, config.ignore):
+        # potentially fix module name to remove FSDP wrapper prefix
+        name = fix_fsdp_module_name(name)
+        if find_name_or_class_matches(name, submodule, config.ignore):
+            ignored_submodules.append(name)
             continue  # layer matches ignore list, continue
-        target = find_first_name_or_class_match(name, submodule, target_to_scheme)
-        if target is not None:
+        targets = find_name_or_class_matches(name, submodule, target_to_scheme)
+        if targets:
             # target matched - add layer and scheme to target list
-            submodule.quantization_scheme = target_to_scheme[target]
+            submodule.quantization_scheme = _scheme_from_targets(
+                target_to_scheme, targets, name
+            )
+            names_to_scheme[name] = submodule.quantization_scheme.weights
 
+    if config.ignore is not None and ignored_submodules is not None:
+        if set(config.ignore) - set(ignored_submodules):
+            _LOGGER.warning(
+                "Some layers that were to be ignored were "
+                "not found in the model: "
+                f"{set(config.ignore) - set(ignored_submodules)}"
+            )
     # apply current quantization status across all targeted layers
+
     apply_quantization_status(model, config.quantization_status)
+    return names_to_scheme
+
+
+def process_quantization_config(config: QuantizationConfig) -> QuantizationConfig:
+    """
+    Preprocess the raw QuantizationConfig
+
+    :param config: the raw QuantizationConfig
+    :return: the processed QuantizationConfig
+    """
+    if config.kv_cache_scheme is not None:
+        config = process_kv_cache_config(config)
+
+    return config
+
+
+def process_kv_cache_config(
+    config: QuantizationConfig, targets: Union[List[str], str] = KV_CACHE_TARGETS
+) -> QuantizationConfig:
+    """
+    Reformulate the `config.kv_cache` as a `config_group`
+    and add it to the set of existing `config.groups`
+
+    :param config: the QuantizationConfig
+    :return: the QuantizationConfig with additional "kv_cache" group
+    """
+    kv_cache_dict = config.kv_cache_scheme.model_dump()
+    kv_cache_scheme = QuantizationScheme(
+        output_activations=QuantizationArgs(**kv_cache_dict),
+        targets=targets,
+    )
+    kv_cache_group = dict(kv_cache=kv_cache_scheme)
+    config.config_groups.update(kv_cache_group)
+    return config
 
 
 def apply_quantization_status(model: Module, status: QuantizationStatus):
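
A hedged sketch of what the kv-cache handling above does to a config (the constructor arguments are assumptions; the attribute names come straight from the hunk):

    from compressed_tensors.quantization import QuantizationArgs, QuantizationConfig
    from compressed_tensors.quantization.lifecycle.apply import process_quantization_config

    # hypothetical config whose kv_cache_scheme is set
    config = QuantizationConfig(
        config_groups={},
        kv_cache_scheme=QuantizationArgs(num_bits=8),
    )
    processed = process_quantization_config(config)

    # the kv cache scheme is folded into config_groups under the "kv_cache" key,
    # carried as output_activations with targets defaulting to KV_CACHE_TARGETS
    kv_scheme = processed.config_groups["kv_cache"]
    print(kv_scheme.targets, kv_scheme.output_activations)
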
@@ -118,41 +189,73 @@ def apply_quantization_status(model: Module, status: QuantizationStatus):
     :param model: model to apply quantization to
     :param status: status to update the module to
     """
-    if status >= QuantizationStatus.INITIALIZED:
+    current_status = infer_quantization_status(model)
+
+    if status >= QuantizationStatus.INITIALIZED > current_status:
         model.apply(initialize_module_for_quantization)
-    if status >= QuantizationStatus.CALIBRATION:
-        model.apply(set_module_for_calibration)
-    if status >= QuantizationStatus.FROZEN:
+
+    if current_status < status >= QuantizationStatus.CALIBRATION > current_status:
+        # only quantize weights up front when our end goal state is calibration,
+        # weight quantization parameters are already loaded for frozen/compressed
+        quantize_weights_upfront = status == QuantizationStatus.CALIBRATION
+        model.apply(
+            lambda module: set_module_for_calibration(
+                module, quantize_weights_upfront=quantize_weights_upfront
+            )
+        )
+    if current_status < status >= QuantizationStatus.FROZEN > current_status:
         model.apply(freeze_module_quantization)
 
+    if current_status < status >= QuantizationStatus.COMPRESSED > current_status:
+        model.apply(compress_quantized_weights)
 
-def find_first_name_or_class_match(
+
+def find_name_or_class_matches(
     name: str, module: Module, targets: Iterable[str], check_contains: bool = False
-) -> Optional[str]:
-    # first element of targets that matches the given name
-    # if no name matches returns first target that matches the class name
-    # returns None otherwise
-    return _find_first_match(name, targets) or _find_first_match(
-        module.__class__.__name__, targets, check_contains
-    )
+) -> List[str]:
+    """
+    Returns all targets that match the given name or the class name.
+    Returns empty list otherwise.
+    The order of the output `matches` list matters.
+    The entries are sorted in the following order:
+    1. matches on exact strings
+    2. matches on regex patterns
+    3. matches on module names
+    """
+    targets = sorted(targets, key=lambda x: ("re:" in x, x))
+    if isinstance(targets, Iterable):
+        matches = _find_matches(name, targets) + _find_matches(
+            module.__class__.__name__, targets, check_contains
+        )
+        matches = [match for match in matches if match is not None]
+        return matches
 
 
-def _find_first_match(
+def _find_matches(
     value: str, targets: Iterable[str], check_contains: bool = False
-) -> Optional[str]:
-    # returns first element of target that matches value either
+) -> List[str]:
+    # returns all the targets that match value either
     # exactly or as a regex after 're:'. if check_contains is set to True,
     # additionally checks if the target string is contained with value.
+    matches = []
     for target in targets:
         if target.startswith("re:"):
             pattern = target[3:]
             if re.match(pattern, value):
-                return target
+                matches.append(target)
         elif check_contains:
             if target.lower() in value.lower():
-                return target
+                matches.append(target)
         elif target == value:
-            return target
+            matches.append(target)
+    return matches
+
+
+def _infer_status(model: Module) -> Optional[QuantizationStatus]:
+    for module in model.modules():
+        status = getattr(module, "quantization_status", None)
+        if status is not None:
+            return status
     return None
 
 
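
A self-contained sketch of the new matching behaviour (the target strings are illustrative; the ordering follows the docstring above):

    import torch.nn as nn
    from compressed_tensors.quantization.lifecycle.apply import find_name_or_class_matches

    layer = nn.Linear(8, 8)
    targets = ["Linear", "re:model.layers.*", "model.layers.0.self_attn.q_proj"]
    matches = find_name_or_class_matches(
        "model.layers.0.self_attn.q_proj", layer, targets
    )
    # expected: the exact-name match first, then the regex target, then the
    # class-name ("Linear") match
    print(matches)
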
@@ -170,9 +273,79 @@ def _load_quant_args_from_state_dict(
     """
     scale_name = f"{base_name}_scale"
     zp_name = f"{base_name}_zero_point"
-    device = next(module.parameters()).device
 
-    scale = getattr(module, scale_name)
-    zp = getattr(module, zp_name)
-    scale.data = state_dict[f"{module_name}.{scale_name}"].to(device)
-    zp.data = state_dict[f"{module_name}.{zp_name}"].to(device)
+    state_dict_scale = state_dict.get(f"{module_name}.{scale_name}", None)
+    state_dict_zp = state_dict.get(f"{module_name}.{zp_name}", None)
+
+    if state_dict_scale is not None:
+        # module is quantized
+        update_parameter_data(module, state_dict_scale, scale_name)
+        if state_dict_zp is None:
+            # fill in zero point for symmetric quantization
+            state_dict_zp = torch.zeros_like(state_dict_scale, device="cpu")
+        update_parameter_data(module, state_dict_zp, zp_name)
+
+
+def _scheme_from_targets(
+    target_to_scheme: OrderedDictType[str, QuantizationScheme],
+    targets: List[str],
+    name: str,
+) -> QuantizationScheme:
+    if len(targets) == 1:
+        # if `targets` iterable contains a single element
+        # use it as the key
+        return target_to_scheme[targets[0]]
+
+    # otherwise, we need to merge QuantizationSchemes corresponding
+    # to multiple targets. This is most likely because `name` module
+    # is being target both as an ordinary quantization target, as well
+    # as kv cache quantization target
+    schemes_to_merge = [target_to_scheme[target] for target in targets]
+    return _merge_schemes(schemes_to_merge, name)
+
+
+def _merge_schemes(
+    schemes_to_merge: List[QuantizationScheme], name: str
+) -> QuantizationScheme:
+
+    kv_cache_quantization_scheme = [
+        scheme for scheme in schemes_to_merge if is_kv_cache_quant_scheme(scheme)
+    ]
+    if not kv_cache_quantization_scheme:
+        # if the schemes_to_merge do not contain any
+        # kv cache QuantizationScheme
+        # return the first scheme (the prioritized one,
+        # since the order of schemes_to_merge matters)
+        return schemes_to_merge[0]
+    else:
+        # fetch the kv cache QuantizationScheme and the highest
+        # priority non-kv cache QuantizationScheme and merge them
+        kv_cache_quantization_scheme = kv_cache_quantization_scheme[0]
+        quantization_scheme = [
+            scheme
+            for scheme in schemes_to_merge
+            if not is_kv_cache_quant_scheme(scheme)
+        ][0]
+        schemes_to_merge = [kv_cache_quantization_scheme, quantization_scheme]
+    merged_scheme = {}
+    for scheme in schemes_to_merge:
+        scheme_dict = {
+            k: v for k, v in scheme.model_dump().items() if v is not None
+        }
+        # when merging multiple schemes, the final target will be
+        # the `name` argument - hence erase the original targets
+        del scheme_dict["targets"]
+        # make sure that schemes do not "clash" with each other
+        overlapping_keys = set(merged_scheme.keys()) & set(scheme_dict.keys())
+        if overlapping_keys:
+            raise ValueError(
+                f"The module: {name} is being modified by two clashing "
+                f"quantization schemes, that jointly try to override "
+                f"properties: {overlapping_keys}. Fix the quantization config "
+                "so that it is not ambiguous."
+            )
+        merged_scheme.update(scheme_dict)
+
+    merged_scheme.update(targets=[name])
+
+    return QuantizationScheme(**merged_scheme)
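
The clash check in _merge_schemes boils down to plain set arithmetic over the non-None scheme fields; a self-contained illustration (the dictionary contents are made up):

    # two schemes may only be merged when they configure disjoint properties
    merged_scheme = {"output_activations": {"num_bits": 8}}
    incoming = {"weights": {"num_bits": 4}}

    overlapping_keys = set(merged_scheme.keys()) & set(incoming.keys())
    if overlapping_keys:
        raise ValueError(f"clashing properties: {overlapping_keys}")
    merged_scheme.update(incoming)  # no overlap, so the merge is unambiguous
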

compressed_tensors/quantization/lifecycle/calibration.py
@@ -16,6 +16,7 @@
 import logging
 
 from compressed_tensors.quantization.quant_config import QuantizationStatus
+from compressed_tensors.utils import is_module_offloaded, update_parameter_data
 from torch.nn import Module
 
 
@@ -27,7 +28,7 @@ __all__ = [
 _LOGGER = logging.getLogger(__name__)
 
 
-def set_module_for_calibration(module: Module):
+def set_module_for_calibration(module: Module, quantize_weights_upfront: bool = True):
     """
     marks a layer as ready for calibration which activates observers
     to update scales and zero points on each forward pass
@@ -35,6 +36,8 @@ def set_module_for_calibration(module: Module):
     apply to full model with `model.apply(set_module_for_calibration)`
 
     :param module: module to set for calibration
+    :param quantize_weights_upfront: whether to automatically run weight quantization at the
+        start of calibration
     """
     if not getattr(module, "quantization_scheme", None):
         # no quantization scheme nothing to do
@@ -48,4 +51,20 @@ def set_module_for_calibration(module: Module):
             "to re-calibrate a frozen module"
         )
 
+    if quantize_weights_upfront and module.quantization_scheme.weights is not None:
+        # set weight scale and zero_point up front, calibration data doesn't affect it
+        observer = module.weight_observer
+
+        offloaded = False
+        if is_module_offloaded(module):
+            module._hf_hook.pre_forward(module)
+            offloaded = True
+
+        scale, zero_point = observer(module.weight)
+        update_parameter_data(module, scale, "weight_scale")
+        update_parameter_data(module, zero_point, "weight_zero_point")
+
+        if offloaded:
+            module._hf_hook.post_forward(module, None)
+
     module.quantization_status = QuantizationStatus.CALIBRATION
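
A hedged sketch of driving the new keyword across a model, mirroring how apply_quantization_status invokes it in apply.py (plain torch modules without a quantization_scheme simply hit the early return above):

    import torch.nn as nn
    from compressed_tensors.quantization.lifecycle.calibration import (
        set_module_for_calibration,
    )

    model = nn.Sequential(nn.Linear(16, 16), nn.ReLU())
    model.apply(
        lambda module: set_module_for_calibration(module, quantize_weights_upfront=True)
    )
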

compressed_tensors/quantization/lifecycle/compressed.py (new file)
@@ -0,0 +1,69 @@
+# Copyright (c) 2021 - present / Neuralmagic, Inc. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#    http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+
+import logging
+
+import torch
+from compressed_tensors.quantization.lifecycle.forward import quantize
+from compressed_tensors.quantization.quant_config import QuantizationStatus
+from torch.nn import Module
+
+
+__all__ = [
+    "compress_quantized_weights",
+]
+
+
+_LOGGER = logging.getLogger(__name__)
+
+
+def compress_quantized_weights(module: Module):
+    """
+    Quantizes the module weight representation to use fewer bits in memory
+
+    apply to full model with `model.apply(compress_quantized_weights)`
+
+    :param module: module to compress to quantized representation
+    """
+    scheme = getattr(module, "quantization_scheme", None)
+    if not scheme or not scheme.weights:
+        # no quantization scheme or weights not quantized, nothing to do
+        return
+
+    if scheme is QuantizationStatus.COMPRESSED:
+        # module is already compressed, nothing to do
+        return
+
+    weight = getattr(module, "weight", None)
+    scale = getattr(module, "weight_scale", None)
+    zero_point = getattr(module, "weight_zero_point", None)
+
+    if weight is None or scale is None or zero_point is None:
+        # no weight, scale, or ZP, nothing to do
+
+        # mark as compressed here to maintain consistent status throughout the model
+        module.quantization_status = QuantizationStatus.COMPRESSED
+        return
+
+    module.weight.requires_grad = False  # cannot use auto grad after compression
+    module.weight.data = quantize(
+        x=weight,
+        scale=scale,
+        zero_point=zero_point,
+        args=scheme.weights,
+        dtype=torch.int8,
+    )
+
+    module.quantization_status = QuantizationStatus.COMPRESSED
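
The docstring above names `model.apply` as the intended entry point; a hedged, self-contained sketch (a plain Linear has no quantization_scheme, so it simply hits the early return):

    import torch.nn as nn
    from compressed_tensors.quantization.lifecycle.compressed import (
        compress_quantized_weights,
    )

    model = nn.Sequential(nn.Linear(16, 16))
    # on quantized modules this casts the weight data to int8 in place and
    # disables gradients; unquantized modules are left untouched
    model.apply(compress_quantized_weights)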