compressed-tensors 0.4.0__py3-none-any.whl → 0.6.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (44)
  1. compressed_tensors/base.py +1 -0
  2. compressed_tensors/compressors/__init__.py +5 -1
  3. compressed_tensors/compressors/base.py +200 -8
  4. compressed_tensors/compressors/dense.py +1 -1
  5. compressed_tensors/compressors/marlin_24.py +11 -10
  6. compressed_tensors/compressors/model_compressor.py +101 -13
  7. compressed_tensors/compressors/naive_quantized.py +140 -0
  8. compressed_tensors/compressors/pack_quantized.py +128 -132
  9. compressed_tensors/compressors/sparse_bitmask.py +1 -1
  10. compressed_tensors/config/base.py +8 -1
  11. compressed_tensors/{compressors/utils → linear}/__init__.py +0 -6
  12. compressed_tensors/linear/compressed_linear.py +87 -0
  13. compressed_tensors/quantization/lifecycle/__init__.py +1 -0
  14. compressed_tensors/quantization/lifecycle/apply.py +204 -44
  15. compressed_tensors/quantization/lifecycle/calibration.py +22 -2
  16. compressed_tensors/quantization/lifecycle/compressed.py +3 -1
  17. compressed_tensors/quantization/lifecycle/forward.py +139 -61
  18. compressed_tensors/quantization/lifecycle/helpers.py +80 -0
  19. compressed_tensors/quantization/lifecycle/initialize.py +77 -13
  20. compressed_tensors/quantization/observers/__init__.py +1 -0
  21. compressed_tensors/quantization/observers/base.py +93 -14
  22. compressed_tensors/quantization/observers/helpers.py +64 -11
  23. compressed_tensors/quantization/observers/min_max.py +8 -0
  24. compressed_tensors/quantization/observers/mse.py +162 -0
  25. compressed_tensors/quantization/quant_args.py +139 -23
  26. compressed_tensors/quantization/quant_config.py +35 -2
  27. compressed_tensors/quantization/quant_scheme.py +112 -13
  28. compressed_tensors/quantization/utils/helpers.py +68 -2
  29. compressed_tensors/utils/__init__.py +5 -0
  30. compressed_tensors/utils/helpers.py +44 -2
  31. compressed_tensors/utils/offload.py +116 -0
  32. compressed_tensors/utils/permute.py +70 -0
  33. compressed_tensors/utils/safetensors_load.py +2 -0
  34. compressed_tensors/{compressors/utils → utils}/semi_structured_conversions.py +1 -0
  35. compressed_tensors/version.py +1 -1
  36. {compressed_tensors-0.4.0.dist-info → compressed_tensors-0.6.0.dist-info}/METADATA +35 -22
  37. compressed_tensors-0.6.0.dist-info/RECORD +52 -0
  38. {compressed_tensors-0.4.0.dist-info → compressed_tensors-0.6.0.dist-info}/WHEEL +1 -1
  39. compressed_tensors/compressors/int_quantized.py +0 -126
  40. compressed_tensors/compressors/utils/helpers.py +0 -43
  41. compressed_tensors-0.4.0.dist-info/RECORD +0 -48
  42. /compressed_tensors/{compressors/utils → utils}/permutations_24.py +0 -0
  43. {compressed_tensors-0.4.0.dist-info → compressed_tensors-0.6.0.dist-info}/LICENSE +0 -0
  44. {compressed_tensors-0.4.0.dist-info → compressed_tensors-0.6.0.dist-info}/top_level.txt +0 -0
compressed_tensors/base.py
@@ -15,3 +15,4 @@
  SPARSITY_CONFIG_NAME = "sparsity_config"
  QUANTIZATION_CONFIG_NAME = "quantization_config"
  COMPRESSION_CONFIG_NAME = "compression_config"
+ KV_CACHE_SCHEME_NAME = "kv_cache_scheme"
compressed_tensors/compressors/__init__.py
@@ -17,8 +17,12 @@
  from .base import Compressor
  from .dense import DenseCompressor
  from .helpers import load_compressed, save_compressed, save_compressed_model
- from .int_quantized import IntQuantizationCompressor
  from .marlin_24 import Marlin24Compressor
  from .model_compressor import ModelCompressor, map_modules_to_quant_args
+ from .naive_quantized import (
+     FloatQuantizationCompressor,
+     IntQuantizationCompressor,
+     QuantizationCompressor,
+ )
  from .pack_quantized import PackedQuantizationCompressor
  from .sparse_bitmask import BitmaskCompressor, BitmaskTensor
compressed_tensors/compressors/base.py
@@ -12,20 +12,53 @@
  # See the License for the specific language governing permissions and
  # limitations under the License.

- from typing import Dict, Generator, Tuple, Union
+ import logging
+ from typing import Dict, Generator, Optional, Tuple, Union

+ import torch
  from compressed_tensors.config import SparsityCompressionConfig
- from compressed_tensors.quantization import QuantizationConfig
+ from compressed_tensors.quantization import QuantizationArgs, QuantizationConfig
  from compressed_tensors.registry import RegistryMixin
+ from compressed_tensors.utils import get_nested_weight_mappings, merge_names
+ from safetensors import safe_open
  from torch import Tensor
+ from torch.nn.modules import Module
+ from tqdm import tqdm


+ _LOGGER: logging.Logger = logging.getLogger(__name__)
+
  __all__ = ["Compressor"]


  class Compressor(RegistryMixin):
      """
-     Base class representing a model compression algorithm
+     Base class representing a model compression algorithm. Each child class should
+     implement compression_param_info, compress_weight and decompress_weight.
+
+     Compressors support compressing/decompressing a full module state dict or a single
+     quantized PyTorch leaf module.
+
+     Model Load Lifecycle (run_compressed=False):
+         - ModelCompressor.decompress()
+             - apply_quantization_config()
+             - Compressor.decompress()
+                 - Compressor.decompress_weight()
+
+     Model Save Lifecycle:
+         - ModelCompressor.compress()
+             - Compressor.compress()
+                 - Compressor.compress_weight()
+
+     Module Lifecycle (run_compressed=True):
+         - apply_quantization_config()
+         - compressed_module = CompressedLinear(module)
+             - initialize_module_for_quantization()
+             - Compressor.compression_param_info()
+             - register_parameters()
+         - compressed_module.forward()
+             - compressed_module.decompress()
+

      :param config: config specifying compression parameters
      """
@@ -35,26 +68,185 @@ class Compressor(RegistryMixin):
      ):
          self.config = config

-     def compress(self, model_state: Dict[str, Tensor], **kwargs) -> Dict[str, Tensor]:
+     def compression_param_info(
+         self,
+         weight_shape: torch.Size,
+         quantization_args: Optional[QuantizationArgs] = None,
+     ) -> Dict[str, Tuple[torch.Size, torch.dtype]]:
+         """
+         Creates a dictionary of expected shapes and dtypes for each compression
+         parameter used by the compressor
+
+         :param weight_shape: uncompressed weight shape
+         :param quantization_args: quantization parameters for the weight
+         :return: dictionary mapping compressed parameter names to shape and dtype
+         """
+         raise NotImplementedError()
+
+     def compress(
+         self,
+         model_state: Dict[str, Tensor],
+         names_to_scheme: Dict[str, QuantizationArgs],
+         **kwargs,
+     ) -> Dict[str, Tensor]:
          """
          Compresses a dense state dict

          :param model_state: state dict of uncompressed model
+         :param names_to_scheme: quantization args for each quantized weight, needed for
+             quantize function to calculate bit depth
          :return: compressed state dict
          """
-         raise NotImplementedError()
+         compressed_dict = {}
+         weight_suffix = ".weight"
+         _LOGGER.debug(
+             f"Compressing model with {len(model_state)} parameterized layers..."
+         )
+
+         for name, value in tqdm(model_state.items(), desc="Compressing model"):
+             if name.endswith(weight_suffix):
+                 prefix = name[: -(len(weight_suffix))]
+                 scale = model_state.get(merge_names(prefix, "weight_scale"), None)
+                 zp = model_state.get(merge_names(prefix, "weight_zero_point"), None)
+                 g_idx = model_state.get(merge_names(prefix, "weight_g_idx"), None)
+                 if scale is not None:
+                     # weight is quantized, compress it
+                     quant_args = names_to_scheme[prefix]
+                     compressed_data = self.compress_weight(
+                         weight=value,
+                         scale=scale,
+                         zero_point=zp,
+                         g_idx=g_idx,
+                         quantization_args=quant_args,
+                         device="cpu",
+                     )
+                     for key, value in compressed_data.items():
+                         compressed_dict[merge_names(prefix, key)] = value
+                 else:
+                     compressed_dict[name] = value.to("cpu")
+             elif name.endswith("zero_point") and torch.all(value == 0):
+                 continue
+             elif name.endswith("g_idx") and torch.any(value <= -1):
+                 continue
+             else:
+                 compressed_dict[name] = value.to("cpu")
+
+         return compressed_dict

      def decompress(
-         self, path_to_model_or_tensors: str, device: str = "cpu"
+         self,
+         path_to_model_or_tensors: str,
+         names_to_scheme: Dict[str, QuantizationArgs],
+         device: str = "cpu",
      ) -> Generator[Tuple[str, Tensor], None, None]:
          """
          Reads a compressed state dict located at path_to_model_or_tensors
          and returns a generator for sequentially decompressing back to a
          dense state dict

-         :param model_path: path to compressed safetensors model (directory with
-             one or more safetensors files) or compressed tensors file
+         :param path_to_model_or_tensors: path to compressed safetensors model (directory
+             with one or more safetensors files) or compressed tensors file
+         :param names_to_scheme: quantization args for each quantized weight
          :param device: optional device to load intermediate weights into
          :return: compressed state dict
          """
+         weight_mappings = get_nested_weight_mappings(
+             path_to_model_or_tensors, self.COMPRESSION_PARAM_NAMES
+         )
+         for weight_name in weight_mappings.keys():
+             weight_data = {}
+             for param_name, safe_path in weight_mappings[weight_name].items():
+                 full_name = merge_names(weight_name, param_name)
+                 with safe_open(safe_path, framework="pt", device=device) as f:
+                     weight_data[param_name] = f.get_tensor(full_name)
+
+             if "weight_scale" in weight_data:
+                 quant_args = names_to_scheme[weight_name]
+                 decompressed = self.decompress_weight(
+                     compressed_data=weight_data, quantization_args=quant_args
+                 )
+                 yield merge_names(weight_name, "weight"), decompressed
+
+     def compress_weight(
+         self,
+         weight: Tensor,
+         scale: Tensor,
+         zero_point: Optional[Tensor] = None,
+         g_idx: Optional[torch.Tensor] = None,
+         quantization_args: Optional[QuantizationArgs] = None,
+     ) -> Dict[str, torch.Tensor]:
+         """
+         Compresses a single uncompressed weight
+
+         :param weight: uncompressed weight tensor
+         :param scale: quantization scale for weight
+         :param zero_point: quantization zero point for weight
+         :param g_idx: optional mapping from column index to group index
+         :param quantization_args: quantization parameters for weight
+         :return: dictionary of compressed weight data
+         """
          raise NotImplementedError()
+
+     def decompress_weight(
+         self,
+         compressed_data: Dict[str, Tensor],
+         quantization_args: Optional[QuantizationArgs] = None,
+     ) -> torch.Tensor:
+         """
+         Decompresses a single compressed weight
+
+         :param compressed_data: dictionary of data needed for decompression
+         :param quantization_args: quantization parameters for the weight
+         :return: tensor of the decompressed weight
+         """
+         raise NotImplementedError()
+
+     def compress_module(self, module: Module) -> Optional[Dict[str, torch.Tensor]]:
+         """
+         Compresses a single quantized leaf PyTorch module. If the module is not
+         quantized, this function has no effect.
+
+         :param module: PyTorch module to compress
+         :return: dictionary of compressed weight data, or None if module is not
+             quantized
+         """
+         if not hasattr(module, "quantization_scheme"):
+             return None  # module is not quantized
+         quantization_scheme = module.quantization_scheme
+         if not hasattr(quantization_scheme, "weights"):
+             return None  # weights are not quantized
+
+         quantization_args = quantization_scheme.weights
+         weight = getattr(module, "weight", None)
+         weight_scale = getattr(module, "weight_scale", None)
+         weight_zero_point = getattr(module, "weight_zero_point", None)
+
+         return self.compress_weight(
+             weight=weight,
+             scale=weight_scale,
+             zero_point=weight_zero_point,
+             quantization_args=quantization_args,
+         )
+
+     def decompress_module(self, module: Module):
+         """
+         Decompresses a single compressed leaf PyTorch module. If the module is not
+         quantized, this function has no effect.
+
+         :param module: PyTorch module to decompress
+         :return: tensor of the decompressed weight, or None if module is not quantized
+         """
+         if not hasattr(module, "quantization_scheme"):
+             return None  # module is not quantized
+         quantization_scheme = module.quantization_scheme
+         if not hasattr(quantization_scheme, "weights"):
+             return None  # weights are not quantized
+
+         quantization_args = quantization_scheme.weights
+         compressed_data = {}
+         for name, parameter in module.named_parameters():
+             compressed_data[name] = parameter
+
+         return self.decompress_weight(
+             compressed_data=compressed_data, quantization_args=quantization_args
+         )
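
Note: the following standalone sketch is not part of the diff. It illustrates the three hooks the new base-class docstring asks subclasses to provide (compression_param_info, compress_weight, decompress_weight); the class name and the int8 rounding scheme here are hypothetical, and the real implementations live in naive_quantized.py and pack_quantized.py.

# Standalone sketch of a Compressor subclass implementing the three hooks named
# in the new docstring. The int8 scheme is illustrative only.
from typing import Dict, Optional, Tuple

import torch
from torch import Tensor


class ToyInt8Compressor:  # in the package this would subclass Compressor
    COMPRESSION_PARAM_NAMES = ["weight", "weight_scale", "weight_zero_point"]

    def compression_param_info(
        self, weight_shape: torch.Size, quantization_args=None
    ) -> Dict[str, Tuple[torch.Size, torch.dtype]]:
        # the compressed weight keeps its shape but is stored as int8
        return {"weight": (weight_shape, torch.int8)}

    def compress_weight(
        self,
        weight: Tensor,
        scale: Tensor,
        zero_point: Optional[Tensor] = None,
        g_idx: Optional[Tensor] = None,
        quantization_args=None,
        device: Optional[torch.device] = None,
    ) -> Dict[str, Tensor]:
        # round-to-nearest symmetric int8 "compression"
        q = torch.clamp(torch.round(weight / scale), -128, 127).to(torch.int8)
        return {"weight": q.to(device) if device is not None else q}

    def decompress_weight(
        self, compressed_data: Dict[str, Tensor], quantization_args=None
    ) -> Tensor:
        # weight_scale is carried through the state dict by the base compress()
        scale = compressed_data["weight_scale"]
        return compressed_data["weight"].to(scale.dtype) * scale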
compressed_tensors/compressors/dense.py
@@ -29,6 +29,6 @@ class DenseCompressor(Compressor):
          return model_state

      def decompress(
-         self, path_to_model_or_tensors: str, device: str = "cpu"
+         self, path_to_model_or_tensors: str, device: str = "cpu", **kwargs
      ) -> Generator[Tuple[str, Tensor], None, None]:
          return iter([])
compressed_tensors/compressors/marlin_24.py
@@ -18,15 +18,16 @@ from typing import Dict, Generator, Tuple
  import numpy as np
  import torch
  from compressed_tensors.compressors import Compressor
- from compressed_tensors.compressors.utils import (
+ from compressed_tensors.config import CompressionFormat
+ from compressed_tensors.quantization import QuantizationArgs, QuantizationStrategy
+ from compressed_tensors.quantization.lifecycle.forward import quantize
+ from compressed_tensors.utils import (
      get_permutations_24,
+     is_quantization_param,
+     merge_names,
      sparse_semi_structured_from_dense_cutlass,
      tensor_follows_mask_structure,
  )
- from compressed_tensors.config import CompressionFormat
- from compressed_tensors.quantization import QuantizationArgs, QuantizationStrategy
- from compressed_tensors.quantization.lifecycle.forward import quantize
- from compressed_tensors.utils import is_quantization_param, merge_names
  from torch import Tensor
  from tqdm import tqdm

@@ -107,7 +108,7 @@ class Marlin24Compressor(Compressor):
      def compress(
          self,
          model_state: Dict[str, Tensor],
-         model_quant_args: Dict[str, QuantizationArgs],
+         names_to_scheme: Dict[str, QuantizationArgs],
          **kwargs,
      ) -> Dict[str, Tensor]:
          """
@@ -115,11 +116,11 @@ class Marlin24Compressor(Compressor):
          with the Marlin24 kernel

          :param model_state: state dict of uncompressed model
-         :param model_quant_args: quantization args for each quantized weight, needed for
+         :param names_to_scheme: quantization args for each quantized weight, needed for
              quantize function to calculate bit depth
          :return: compressed state dict
          """
-         self.validate_quant_compatability(model_quant_args)
+         self.validate_quant_compatability(names_to_scheme)

          compressed_dict = {}
          weight_suffix = ".weight"
@@ -139,7 +140,7 @@ class Marlin24Compressor(Compressor):
                  value = value.to(torch.float16)

                  # quantize weight, keeping it as a float16 for now
-                 quant_args = model_quant_args[prefix]
+                 quant_args = names_to_scheme[prefix]
                  value = quantize(
                      x=value, scale=scale, zero_point=zp, args=quant_args
                  )
@@ -175,7 +176,7 @@ class Marlin24Compressor(Compressor):
          return compressed_dict

      def decompress(
-         self, path_to_model_or_tensors: str, device: str = "cpu"
+         self, path_to_model_or_tensors: str, device: str = "cpu", **kwargs
      ) -> Generator[Tuple[str, Tensor], None, None]:
          raise NotImplementedError(
              "Decompression is not implemented for the Marlin24 Compressor."
compressed_tensors/compressors/model_compressor.py
@@ -16,16 +16,19 @@ import json
  import logging
  import operator
  import os
+ import re
  from copy import deepcopy
  from typing import Any, Dict, Optional, Union

+ import torch
+ import transformers
  from compressed_tensors.base import (
      COMPRESSION_CONFIG_NAME,
      QUANTIZATION_CONFIG_NAME,
      SPARSITY_CONFIG_NAME,
  )
  from compressed_tensors.compressors import Compressor
- from compressed_tensors.config import SparsityCompressionConfig
+ from compressed_tensors.config import CompressionFormat, SparsityCompressionConfig
  from compressed_tensors.quantization import (
      QuantizationConfig,
      QuantizationStatus,
@@ -36,10 +39,10 @@ from compressed_tensors.quantization.utils import (
      is_module_quantized,
      iter_named_leaf_modules,
  )
- from compressed_tensors.utils import get_safetensors_folder
+ from compressed_tensors.utils import get_safetensors_folder, update_parameter_data
  from compressed_tensors.utils.helpers import fix_fsdp_module_name
  from torch import Tensor
- from torch.nn import Module, Parameter
+ from torch.nn import Module
  from tqdm import tqdm
  from transformers import AutoConfig
  from transformers.file_utils import CONFIG_NAME
@@ -78,6 +81,7 @@ class ModelCompressor:
      def from_pretrained(
          cls,
          pretrained_model_name_or_path: str,
+         **kwargs,
      ) -> Optional["ModelCompressor"]:
          """
          Given a path to a model config, extract the sparsity and/or quantization
@@ -86,7 +90,7 @@ class ModelCompressor:
          :param pretrained_model_name_or_path: path to model config on disk or HF hub
          :return: compressor for the extracted configs
          """
-         config = AutoConfig.from_pretrained(pretrained_model_name_or_path)
+         config = AutoConfig.from_pretrained(pretrained_model_name_or_path, **kwargs)
          compression_config = getattr(config, COMPRESSION_CONFIG_NAME, None)
          return cls.from_compression_config(compression_config)

@@ -172,6 +176,9 @@ class ModelCompressor:
          if hasattr(compression_config, SPARSITY_CONFIG_NAME):
              # for loaded HFQuantizer config
              return getattr(compression_config, SPARSITY_CONFIG_NAME)
+         if SPARSITY_CONFIG_NAME in compression_config:
+             # for loaded HFQuantizer config from dict
+             return compression_config[SPARSITY_CONFIG_NAME]

          # SparseAutoModel format
          return compression_config.get(SPARSITY_CONFIG_NAME, None)
@@ -185,6 +192,10 @@ class ModelCompressor:
              # for loaded HFQuantizer config
              return getattr(compression_config, QUANTIZATION_CONFIG_NAME)

+         if QUANTIZATION_CONFIG_NAME in compression_config:
+             # for loaded HFQuantizer config from dict
+             return compression_config[QUANTIZATION_CONFIG_NAME]
+
          # SparseAutoModel format
          quantization_config = deepcopy(compression_config)
          quantization_config.pop(SPARSITY_CONFIG_NAME, None)
@@ -228,14 +239,79 @@ class ModelCompressor:
          quantized_modules_to_args = map_modules_to_quant_args(model)
          if self.quantization_compressor is not None:
              compressed_state_dict = self.quantization_compressor.compress(
-                 state_dict, model_quant_args=quantized_modules_to_args
+                 state_dict, names_to_scheme=quantized_modules_to_args
              )
+             if self.quantization_config.format != CompressionFormat.dense.value:
+                 self.quantization_config.quantization_status = (
+                     QuantizationStatus.COMPRESSED
+                 )

          if self.sparsity_compressor is not None:
              compressed_state_dict = self.sparsity_compressor.compress(
                  compressed_state_dict
              )

+         # HACK (mgoin): Post-process step for kv cache scales to take the
+         # k/v_proj module `output_scale` parameters, and store them in the
+         # parent attention module as `k_scale` and `v_scale`
+         #
+         # Example:
+         #     Replace `model.layers.0.self_attn.k_proj.output_scale`
+         #     with `model.layers.0.self_attn.k_scale`
+         if (
+             self.quantization_config is not None
+             and self.quantization_config.kv_cache_scheme is not None
+         ):
+             # HACK (mgoin): We assume the quantized modules in question
+             # will be k_proj and v_proj since those are the default targets.
+             # We check that both of these modules have output activation
+             # quantization, and additionally check that q_proj doesn't.
+             q_proj_has_no_quant_output = 0
+             k_proj_has_quant_output = 0
+             v_proj_has_quant_output = 0
+             for name, module in model.named_modules():
+                 if not hasattr(module, "quantization_scheme"):
+                     # We still want to count non-quantized q_proj
+                     if name.endswith(".q_proj"):
+                         q_proj_has_no_quant_output += 1
+                     continue
+                 out_act = module.quantization_scheme.output_activations
+                 if name.endswith(".q_proj") and out_act is None:
+                     q_proj_has_no_quant_output += 1
+                 elif name.endswith(".k_proj") and out_act is not None:
+                     k_proj_has_quant_output += 1
+                 elif name.endswith(".v_proj") and out_act is not None:
+                     v_proj_has_quant_output += 1
+
+             assert (
+                 q_proj_has_no_quant_output > 0
+                 and k_proj_has_quant_output > 0
+                 and v_proj_has_quant_output > 0
+             )
+             assert (
+                 q_proj_has_no_quant_output
+                 == k_proj_has_quant_output
+                 == v_proj_has_quant_output
+             )
+
+             # Move all .k/v_proj.output_scale parameters to .k/v_scale
+             working_state_dict = {}
+             for key in compressed_state_dict.keys():
+                 if key.endswith(".k_proj.output_scale"):
+                     new_key = key.replace(".k_proj.output_scale", ".k_scale")
+                     working_state_dict[new_key] = compressed_state_dict[key]
+                 elif key.endswith(".v_proj.output_scale"):
+                     new_key = key.replace(".v_proj.output_scale", ".v_scale")
+                     working_state_dict[new_key] = compressed_state_dict[key]
+                 else:
+                     working_state_dict[key] = compressed_state_dict[key]
+             compressed_state_dict = working_state_dict
+
+         # HACK: Override the dtype_byte_size function in transformers to
+         # support float8 types. Fix is posted upstream
+         # https://github.com/huggingface/transformers/pull/30488
+         transformers.modeling_utils.dtype_byte_size = new_dtype_byte_size
+
          return compressed_state_dict

      def decompress(self, model_path: str, model: Module):
@@ -252,9 +328,11 @@ class ModelCompressor:
              setattr(model, SPARSITY_CONFIG_NAME, self.sparsity_compressor.config)

          if self.quantization_compressor is not None:
-             apply_quantization_config(model, self.quantization_config)
+             names_to_scheme = apply_quantization_config(model, self.quantization_config)
              load_pretrained_quantization(model, model_path)
-             dense_gen = self.quantization_compressor.decompress(model_path)
+             dense_gen = self.quantization_compressor.decompress(
+                 model_path, names_to_scheme=names_to_scheme
+             )
              self._replace_weights(dense_gen, model)

              def update_status(module):
@@ -296,12 +374,10 @@ class ModelCompressor:

      def _replace_weights(self, dense_weight_generator, model):
          for name, data in tqdm(dense_weight_generator, desc="Decompressing model"):
-             # loading the decompressed weights into the model
-             model_device = operator.attrgetter(name)(model).device
-             data_old = operator.attrgetter(name)(model)
-             data_dtype = data_old.dtype
-             data_new = Parameter(data.to(model_device).to(data_dtype))
-             data_old.data = data_new.data
+             split_name = name.split(".")
+             prefix, param_name = ".".join(split_name[:-1]), split_name[-1]
+             module = operator.attrgetter(prefix)(model)
+             update_parameter_data(module, data, param_name)


  def map_modules_to_quant_args(model: Module) -> Dict:
@@ -313,3 +389,15 @@ def map_modules_to_quant_args(model: Module) -> Dict:
              quantized_modules_to_args[name] = submodule.quantization_scheme.weights

      return quantized_modules_to_args
+
+
+ # HACK: Override the dtype_byte_size function in transformers to support float8 types
+ # Fix is posted upstream https://github.com/huggingface/transformers/pull/30488
+ def new_dtype_byte_size(dtype):
+     if dtype == torch.bool:
+         return 1 / 8
+     bit_search = re.search(r"[^\d](\d+)_?", str(dtype))
+     if bit_search is None:
+         raise ValueError(f"`dtype` is not a valid dtype: {dtype}.")
+     bit_size = int(bit_search.groups()[0])
+     return bit_size // 8
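
Note: not part of the diff. The new_dtype_byte_size override above exists because transformers' bundled helper could not parse float8 dtype names when computing shard sizes. Running the same regex logic standalone shows why the override works:

# Same logic as new_dtype_byte_size above, copied standalone: the regex pulls the
# first bit width out of str(dtype), which also covers float8 variants.
import re

import torch


def dtype_byte_size(dtype):
    if dtype == torch.bool:
        return 1 / 8
    bit_search = re.search(r"[^\d](\d+)_?", str(dtype))
    if bit_search is None:
        raise ValueError(f"`dtype` is not a valid dtype: {dtype}.")
    return int(bit_search.groups()[0]) // 8


print(dtype_byte_size(torch.float16))  # 2
print(dtype_byte_size(torch.int8))     # 1
# on torch builds with float8 support, the case the override adds:
print(dtype_byte_size(torch.float8_e4m3fn))  # 1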
compressed_tensors/compressors/naive_quantized.py
@@ -0,0 +1,140 @@
+ # Copyright (c) 2021 - present / Neuralmagic, Inc. All Rights Reserved.
+ #
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ #    http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing,
+ # software distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+
+ import logging
+ from typing import Dict, Optional, Tuple
+
+ import torch
+ from compressed_tensors.compressors import Compressor
+ from compressed_tensors.config import CompressionFormat
+ from compressed_tensors.quantization import QuantizationArgs
+ from compressed_tensors.quantization.lifecycle.forward import dequantize, quantize
+ from compressed_tensors.quantization.utils import can_quantize
+ from torch import Tensor
+
+
+ __all__ = [
+     "QuantizationCompressor",
+     "IntQuantizationCompressor",
+     "FloatQuantizationCompressor",
+ ]
+
+ _LOGGER: logging.Logger = logging.getLogger(__name__)
+
+
+ @Compressor.register(name=CompressionFormat.naive_quantized.value)
+ class QuantizationCompressor(Compressor):
+     """
+     Implements naive compression for quantized models. Weight of each
+     quantized layer is converted from its original float type to the closest Pytorch
+     type to the type specified by the layer's QuantizationArgs.
+     """
+
+     COMPRESSION_PARAM_NAMES = [
+         "weight",
+         "weight_scale",
+         "weight_zero_point",
+         "weight_g_idx",
+     ]
+
+     def compression_param_info(
+         self,
+         weight_shape: torch.Size,
+         quantization_args: Optional[QuantizationArgs] = None,
+     ) -> Dict[str, Tuple[torch.Size, torch.dtype]]:
+         """
+         Creates a dictionary of expected shapes and dtypes for each compression
+         parameter used by the compressor
+
+         :param weight_shape: uncompressed weight shape
+         :param quantization_args: quantization parameters for the weight
+         :return: dictionary mapping compressed parameter names to shape and dtype
+         """
+         dtype = quantization_args.pytorch_dtype()
+         return {"weight": (weight_shape, dtype)}
+
+     def compress_weight(
+         self,
+         weight: Tensor,
+         scale: Tensor,
+         zero_point: Optional[Tensor] = None,
+         g_idx: Optional[torch.Tensor] = None,
+         quantization_args: Optional[QuantizationArgs] = None,
+         device: Optional[torch.device] = None,
+     ) -> Dict[str, torch.Tensor]:
+         """
+         Compresses a single uncompressed weight
+
+         :param weight: uncompressed weight tensor
+         :param scale: quantization scale for weight
+         :param zero_point: quantization zero point for weight
+         :param g_idx: optional mapping from column index to group index
+         :param quantization_args: quantization parameters for weight
+         :param device: optional device to move compressed output to
+         :return: dictionary of compressed weight data
+         """
+         if can_quantize(weight, quantization_args):
+             quantized_weight = quantize(
+                 x=weight,
+                 scale=scale,
+                 zero_point=zero_point,
+                 g_idx=g_idx,
+                 args=quantization_args,
+                 dtype=quantization_args.pytorch_dtype(),
+             )
+
+         if device is not None:
+             quantized_weight = quantized_weight.to(device)
+
+         return {"weight": quantized_weight}
+
+     def decompress_weight(
+         self,
+         compressed_data: Dict[str, Tensor],
+         quantization_args: Optional[QuantizationArgs] = None,
+     ) -> torch.Tensor:
+         """
+         Decompresses a single compressed weight
+
+         :param compressed_data: dictionary of data needed for decompression
+         :param quantization_args: quantization parameters for the weight
+         :return: tensor of the decompressed weight
+         """
+         weight = compressed_data["weight"]
+         scale = compressed_data["weight_scale"]
+         zero_point = compressed_data.get("weight_zero_point", None)
+         g_idx = compressed_data.get("weight_g_idx", None)
+         decompressed_weight = dequantize(
+             x_q=weight, scale=scale, zero_point=zero_point, g_idx=g_idx
+         )
+
+         return decompressed_weight
+
+
+ @Compressor.register(name=CompressionFormat.int_quantized.value)
+ class IntQuantizationCompressor(QuantizationCompressor):
+     """
+     Alias for integer quantized models
+     """
+
+     pass
+
+
+ @Compressor.register(name=CompressionFormat.float_quantized.value)
+ class FloatQuantizationCompressor(QuantizationCompressor):
+     """
+     Alias for fp quantized models
+     """
+
+     pass
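
Note: not part of the diff. The naive format added here stores each quantized weight directly in its target dtype next to its scale and zero point, and decompression is a plain dequantize. The plain-torch sketch below mirrors that round trip for a symmetric int8 scheme without calling the package:

# Conceptual round trip for the naive format: store the weight in its quantized
# dtype plus scale/zero point, recover it with a plain dequantize. Plain torch only.
import torch

weight = torch.randn(4, 8)
scale = weight.abs().amax(dim=1, keepdim=True) / 127  # per-channel symmetric scale

# roughly what compress_weight() produces for an int8 scheme
weight_q = torch.clamp(torch.round(weight / scale), -128, 127).to(torch.int8)

# roughly what decompress_weight() recovers via dequantize()
weight_dq = weight_q.to(torch.float32) * scale

print(weight_q.dtype)                                   # torch.int8
print((weight - weight_dq).abs().max() <= scale.max())  # error within one quant step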