compressed-tensors 0.3.3__py3-none-any.whl → 0.5.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (42)
  1. compressed_tensors/base.py +3 -1
  2. compressed_tensors/compressors/__init__.py +9 -1
  3. compressed_tensors/compressors/base.py +12 -55
  4. compressed_tensors/compressors/dense.py +5 -5
  5. compressed_tensors/compressors/helpers.py +12 -12
  6. compressed_tensors/compressors/marlin_24.py +251 -0
  7. compressed_tensors/compressors/model_compressor.py +336 -0
  8. compressed_tensors/compressors/naive_quantized.py +144 -0
  9. compressed_tensors/compressors/pack_quantized.py +219 -0
  10. compressed_tensors/compressors/sparse_bitmask.py +4 -4
  11. compressed_tensors/config/base.py +9 -4
  12. compressed_tensors/config/dense.py +4 -4
  13. compressed_tensors/config/sparse_bitmask.py +3 -3
  14. compressed_tensors/quantization/lifecycle/__init__.py +2 -0
  15. compressed_tensors/quantization/lifecycle/apply.py +204 -31
  16. compressed_tensors/quantization/lifecycle/calibration.py +20 -1
  17. compressed_tensors/quantization/lifecycle/compressed.py +69 -0
  18. compressed_tensors/quantization/lifecycle/forward.py +214 -62
  19. compressed_tensors/quantization/lifecycle/frozen.py +4 -0
  20. compressed_tensors/quantization/lifecycle/helpers.py +53 -0
  21. compressed_tensors/quantization/lifecycle/initialize.py +62 -5
  22. compressed_tensors/quantization/observers/base.py +66 -23
  23. compressed_tensors/quantization/observers/helpers.py +69 -11
  24. compressed_tensors/quantization/observers/memoryless.py +17 -9
  25. compressed_tensors/quantization/observers/min_max.py +44 -13
  26. compressed_tensors/quantization/quant_args.py +47 -3
  27. compressed_tensors/quantization/quant_config.py +104 -23
  28. compressed_tensors/quantization/quant_scheme.py +183 -2
  29. compressed_tensors/quantization/utils/helpers.py +142 -8
  30. compressed_tensors/utils/__init__.py +4 -0
  31. compressed_tensors/utils/helpers.py +54 -7
  32. compressed_tensors/utils/offload.py +104 -0
  33. compressed_tensors/utils/permutations_24.py +65 -0
  34. compressed_tensors/utils/safetensors_load.py +3 -2
  35. compressed_tensors/utils/semi_structured_conversions.py +341 -0
  36. compressed_tensors/version.py +53 -0
  37. {compressed_tensors-0.3.3.dist-info → compressed_tensors-0.5.0.dist-info}/METADATA +47 -8
  38. compressed_tensors-0.5.0.dist-info/RECORD +48 -0
  39. {compressed_tensors-0.3.3.dist-info → compressed_tensors-0.5.0.dist-info}/WHEEL +1 -1
  40. compressed_tensors-0.3.3.dist-info/RECORD +0 -38
  41. {compressed_tensors-0.3.3.dist-info → compressed_tensors-0.5.0.dist-info}/LICENSE +0 -0
  42. {compressed_tensors-0.3.3.dist-info → compressed_tensors-0.5.0.dist-info}/top_level.txt +0 -0
@@ -0,0 +1,336 @@
+ # Copyright (c) 2021 - present / Neuralmagic, Inc. All Rights Reserved.
+ #
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ #     http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing,
+ # software distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+
+ import json
+ import logging
+ import operator
+ import os
+ import re
+ from copy import deepcopy
+ from typing import Any, Dict, Optional, Union
+
+ import torch
+ import transformers
+ from compressed_tensors.base import (
+     COMPRESSION_CONFIG_NAME,
+     QUANTIZATION_CONFIG_NAME,
+     SPARSITY_CONFIG_NAME,
+ )
+ from compressed_tensors.compressors import Compressor
+ from compressed_tensors.config import SparsityCompressionConfig
+ from compressed_tensors.quantization import (
+     QuantizationConfig,
+     QuantizationStatus,
+     apply_quantization_config,
+     load_pretrained_quantization,
+ )
+ from compressed_tensors.quantization.utils import (
+     is_module_quantized,
+     iter_named_leaf_modules,
+ )
+ from compressed_tensors.utils import get_safetensors_folder, update_parameter_data
+ from compressed_tensors.utils.helpers import fix_fsdp_module_name
+ from torch import Tensor
+ from torch.nn import Module
+ from tqdm import tqdm
+ from transformers import AutoConfig
+ from transformers.file_utils import CONFIG_NAME
+
+
+ __all__ = ["ModelCompressor", "map_modules_to_quant_args"]
+
+ _LOGGER: logging.Logger = logging.getLogger(__name__)
+
+
+ class ModelCompressor:
+     """
+     Handles compression and decompression of a model with a sparsity config and/or
+     quantization config.
+
+     Compression LifeCycle
+         - compressor = ModelCompressor.from_pretrained_model(model)
+         - compressed_state_dict = compressor.compress(model, state_dict)
+             - compressor.quantization_compressor.compress(model, state_dict)
+             - compressor.sparsity_compressor.compress(model, state_dict)
+         - model.save_pretrained(output_dir, state_dict=compressed_state_dict)
+         - compressor.update_config(output_dir)
+
+     Decompression LifeCycle
+         - compressor = ModelCompressor.from_pretrained(comp_model_path)
+         - model = AutoModel.from_pretrained(comp_model_path)
+         - compressor.decompress(comp_model_path, model)
+             - compressor.sparsity_compressor.decompress(comp_model_path, model)
+             - compressor.quantization_compressor.decompress(comp_model_path, model)
+
+     :param sparsity_config: config specifying sparsity compression parameters
+     :param quantization_config: config specifying quantization compression parameters
+     """
+
+     @classmethod
+     def from_pretrained(
+         cls,
+         pretrained_model_name_or_path: str,
+         **kwargs,
+     ) -> Optional["ModelCompressor"]:
+         """
+         Given a path to a model config, extract the sparsity and/or quantization
+         configs and load a ModelCompressor
+
+         :param pretrained_model_name_or_path: path to model config on disk or HF hub
+         :return: compressor for the extracted configs
+         """
+         config = AutoConfig.from_pretrained(pretrained_model_name_or_path, **kwargs)
+         compression_config = getattr(config, COMPRESSION_CONFIG_NAME, None)
+         return cls.from_compression_config(compression_config)
+
+     @classmethod
+     def from_compression_config(cls, compression_config: Dict[str, Any]):
+         """
+         :param compression_config: compression/quantization config dictionary
+             found under key "quantization_config" in HF model config
+         :return: compressor for the extracted configs
+         """
+         if compression_config is None:
+             return None
+
+         try:
+             from transformers.utils.quantization_config import CompressedTensorsConfig
+
+             if isinstance(compression_config, CompressedTensorsConfig):
+                 compression_config = compression_config.to_dict()
+         except ImportError:
+             pass
+
+         sparsity_config = cls.parse_sparsity_config(compression_config)
+         quantization_config = cls.parse_quantization_config(compression_config)
+         if sparsity_config is None and quantization_config is None:
+             return None
+
+         if sparsity_config is not None and not isinstance(
+             sparsity_config, SparsityCompressionConfig
+         ):
+             format = sparsity_config.get("format")
+             sparsity_config = SparsityCompressionConfig.load_from_registry(
+                 format, **sparsity_config
+             )
+         if quantization_config is not None and not isinstance(
+             quantization_config, QuantizationConfig
+         ):
+             quantization_config = QuantizationConfig.parse_obj(quantization_config)
+
+         return cls(
+             sparsity_config=sparsity_config, quantization_config=quantization_config
+         )
+
+     @classmethod
+     def from_pretrained_model(
+         cls,
+         model: Module,
+         sparsity_config: Union[SparsityCompressionConfig, str, None] = None,
+         quantization_format: Optional[str] = None,
+     ) -> Optional["ModelCompressor"]:
+         """
+         Given a pytorch model and optional sparsity and/or quantization configs,
+         load the appropriate compressors
+
+         :param model: pytorch model to target for compression
+         :param sparsity_config: a filled in sparsity config or string corresponding
+             to a sparsity compression algorithm
+         :param quantization_format: string corresponding to a quantization compression
+             algorithm
+         :return: compressor for the extracted configs
+         """
+         quantization_config = QuantizationConfig.from_pretrained(
+             model, format=quantization_format
+         )
+
+         if isinstance(sparsity_config, str):  # we passed in a sparsity format
+             sparsity_config = SparsityCompressionConfig.load_from_registry(
+                 sparsity_config
+             )
+
+         if sparsity_config is None and quantization_config is None:
+             return None
+
+         return cls(
+             sparsity_config=sparsity_config, quantization_config=quantization_config
+         )
+
+     @staticmethod
+     def parse_sparsity_config(compression_config: Dict) -> Union[Dict, None]:
+         if compression_config is None:
+             return None
+         if SPARSITY_CONFIG_NAME not in compression_config:
+             return None
+         if hasattr(compression_config, SPARSITY_CONFIG_NAME):
+             # for loaded HFQuantizer config
+             return getattr(compression_config, SPARSITY_CONFIG_NAME)
+
+         # SparseAutoModel format
+         return compression_config.get(SPARSITY_CONFIG_NAME, None)
+
+     @staticmethod
+     def parse_quantization_config(compression_config: Dict) -> Union[Dict, None]:
+         if compression_config is None:
+             return None
+
+         if hasattr(compression_config, QUANTIZATION_CONFIG_NAME):
+             # for loaded HFQuantizer config
+             return getattr(compression_config, QUANTIZATION_CONFIG_NAME)
+
+         # SparseAutoModel format
+         quantization_config = deepcopy(compression_config)
+         quantization_config.pop(SPARSITY_CONFIG_NAME, None)
+         if len(quantization_config) == 0:
+             quantization_config = None
+         return quantization_config
+
+     def __init__(
+         self,
+         sparsity_config: Optional[SparsityCompressionConfig] = None,
+         quantization_config: Optional[QuantizationConfig] = None,
+     ):
+         self.sparsity_config = sparsity_config
+         self.quantization_config = quantization_config
+         self.sparsity_compressor = None
+         self.quantization_compressor = None
+
+         if sparsity_config is not None:
+             self.sparsity_compressor = Compressor.load_from_registry(
+                 sparsity_config.format, config=sparsity_config
+             )
+         if quantization_config is not None:
+             self.quantization_compressor = Compressor.load_from_registry(
+                 quantization_config.format, config=quantization_config
+             )
+
+     def compress(
+         self, model: Module, state_dict: Optional[Dict[str, Tensor]] = None
+     ) -> Dict[str, Tensor]:
+         """
+         Compresses a dense state dict or model with sparsity and/or quantization
+
+         :param model: uncompressed model to compress
+         :param state_dict: optional uncompressed state_dict to insert into model
+         :return: compressed state dict
+         """
+         if state_dict is None:
+             state_dict = model.state_dict()
+
+         compressed_state_dict = state_dict
+         quantized_modules_to_args = map_modules_to_quant_args(model)
+         if self.quantization_compressor is not None:
+             compressed_state_dict = self.quantization_compressor.compress(
+                 state_dict, names_to_scheme=quantized_modules_to_args
+             )
+
+         if self.sparsity_compressor is not None:
+             compressed_state_dict = self.sparsity_compressor.compress(
+                 compressed_state_dict
+             )
+
+         # HACK: Override the dtype_byte_size function in transformers to
+         # support float8 types. Fix is posted upstream
+         # https://github.com/huggingface/transformers/pull/30488
+         transformers.modeling_utils.dtype_byte_size = new_dtype_byte_size
+
+         return compressed_state_dict
+
+     def decompress(self, model_path: str, model: Module):
+         """
+         Overwrites the weights in model with weights decompressed from model_path
+
+         :param model_path: path to compressed weights
+         :param model: pytorch model to load decompressed weights into
+         """
+         model_path = get_safetensors_folder(model_path)
+         if self.sparsity_compressor is not None:
+             dense_gen = self.sparsity_compressor.decompress(model_path)
+             self._replace_weights(dense_gen, model)
+             setattr(model, SPARSITY_CONFIG_NAME, self.sparsity_compressor.config)
+
+         if self.quantization_compressor is not None:
+             names_to_scheme = apply_quantization_config(model, self.quantization_config)
+             load_pretrained_quantization(model, model_path)
+             dense_gen = self.quantization_compressor.decompress(
+                 model_path, names_to_scheme=names_to_scheme
+             )
+             self._replace_weights(dense_gen, model)
+
+             def update_status(module):
+                 module.quantization_status = QuantizationStatus.FROZEN
+
+             model.apply(update_status)
+             setattr(model, QUANTIZATION_CONFIG_NAME, self.quantization_config)
+
+     def update_config(self, save_directory: str):
+         """
+         Update the model config located at save_directory with compression configs
+         for sparsity and/or quantization
+
+         :param save_directory: path to a folder containing a HF model config
+         """
+         config_file_path = os.path.join(save_directory, CONFIG_NAME)
+         if not os.path.exists(config_file_path):
+             _LOGGER.warning(
+                 f"Could not find a valid model config file in "
+                 f"{save_directory}. Compression config will not be saved."
+             )
+             return
+
+         with open(config_file_path, "r") as config_file:
+             config_data = json.load(config_file)
+
+         config_data[COMPRESSION_CONFIG_NAME] = {}
+         if self.quantization_config is not None:
+             quant_config_data = self.quantization_config.model_dump()
+             config_data[COMPRESSION_CONFIG_NAME] = quant_config_data
+         if self.sparsity_config is not None:
+             sparsity_config_data = self.sparsity_config.model_dump()
+             config_data[COMPRESSION_CONFIG_NAME][
+                 SPARSITY_CONFIG_NAME
+             ] = sparsity_config_data
+
+         with open(config_file_path, "w") as config_file:
+             json.dump(config_data, config_file, indent=2, sort_keys=True)
+
+     def _replace_weights(self, dense_weight_generator, model):
+         for name, data in tqdm(dense_weight_generator, desc="Decompressing model"):
+             split_name = name.split(".")
+             prefix, param_name = ".".join(split_name[:-1]), split_name[-1]
+             module = operator.attrgetter(prefix)(model)
+             update_parameter_data(module, data, param_name)
+
+
+ def map_modules_to_quant_args(model: Module) -> Dict:
+     quantized_modules_to_args = {}
+     for name, submodule in iter_named_leaf_modules(model):
+         if is_module_quantized(submodule):
+             if submodule.quantization_scheme.weights is not None:
+                 name = fix_fsdp_module_name(name)
+                 quantized_modules_to_args[name] = submodule.quantization_scheme.weights
+
+     return quantized_modules_to_args
+
+
+ # HACK: Override the dtype_byte_size function in transformers to support float8 types
+ # Fix is posted upstream https://github.com/huggingface/transformers/pull/30488
+ def new_dtype_byte_size(dtype):
+     if dtype == torch.bool:
+         return 1 / 8
+     bit_search = re.search(r"[^\d](\d+)_?", str(dtype))
+     if bit_search is None:
+         raise ValueError(f"`dtype` is not a valid dtype: {dtype}.")
+     bit_size = int(bit_search.groups()[0])
+     return bit_size // 8
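The hunk above is the new ModelCompressor module (+336 lines in compressed_tensors/compressors/model_compressor.py). Its docstring spells out the compress/save/decompress lifecycle; the sketch below is a minimal, illustrative walk through that flow. It is not taken from the package: `model` and `output_dir` are placeholders (a transformers model with save_pretrained and a target directory), and pack-quantized is just one of the registered CompressionFormat values.

    # Illustrative only: compress a quantized model, record the compression
    # config in its config.json, then rebuild dense weights later.
    from compressed_tensors.compressors.model_compressor import ModelCompressor
    from compressed_tensors.config import CompressionFormat

    compressor = ModelCompressor.from_pretrained_model(
        model, quantization_format=CompressionFormat.pack_quantized.value
    )
    if compressor is not None:  # None when no sparsity or quantization is found
        compressed_state_dict = compressor.compress(model)
        model.save_pretrained(output_dir, state_dict=compressed_state_dict)
        compressor.update_config(output_dir)  # writes the compression config to config.json

    # Decompression side: recover a compressor from the saved config and
    # overwrite the reloaded model's weights with the dense values.
    compressor = ModelCompressor.from_pretrained(output_dir)
    if compressor is not None:
        compressor.decompress(output_dir, model)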
@@ -0,0 +1,144 @@
+ # Copyright (c) 2021 - present / Neuralmagic, Inc. All Rights Reserved.
+ #
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ #     http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing,
+ # software distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+
+ import logging
+ from typing import Dict, Generator, Tuple
+
+ import torch
+ from compressed_tensors.compressors import Compressor
+ from compressed_tensors.config import CompressionFormat
+ from compressed_tensors.quantization import QuantizationArgs
+ from compressed_tensors.quantization.lifecycle.forward import dequantize, quantize
+ from compressed_tensors.quantization.utils import can_quantize
+ from compressed_tensors.utils import get_nested_weight_mappings, merge_names
+ from safetensors import safe_open
+ from torch import Tensor
+ from tqdm import tqdm
+
+
+ __all__ = [
+     "QuantizationCompressor",
+     "IntQuantizationCompressor",
+     "FloatQuantizationCompressor",
+ ]
+
+ _LOGGER: logging.Logger = logging.getLogger(__name__)
+
+
+ @Compressor.register(name=CompressionFormat.naive_quantized.value)
+ class QuantizationCompressor(Compressor):
+     """
+     Implements naive compression for quantized models. The weight of each
+     quantized layer is converted from its original float type to the closest
+     PyTorch type specified by the layer's QuantizationArgs.
+     """
+
+     COMPRESSION_PARAM_NAMES = ["weight", "weight_scale", "weight_zero_point"]
+
+     def compress(
+         self,
+         model_state: Dict[str, Tensor],
+         names_to_scheme: Dict[str, QuantizationArgs],
+         **kwargs,
+     ) -> Dict[str, Tensor]:
+         """
+         Compresses a dense state dict
+
+         :param model_state: state dict of uncompressed model
+         :param names_to_scheme: quantization args for each quantized weight, needed for
+             quantize function to calculate bit depth
+         :return: compressed state dict
+         """
+         compressed_dict = {}
+         weight_suffix = ".weight"
+         _LOGGER.debug(
+             f"Compressing model with {len(model_state)} parameterized layers..."
+         )
+
+         for name, value in tqdm(model_state.items(), desc="Compressing model"):
+             if name.endswith(weight_suffix):
+                 prefix = name[: -(len(weight_suffix))]
+                 scale = model_state.get(merge_names(prefix, "weight_scale"), None)
+                 zp = model_state.get(merge_names(prefix, "weight_zero_point"), None)
+                 if scale is not None and zp is not None:
+                     # weight is quantized, compress it
+                     quant_args = names_to_scheme[prefix]
+                     if can_quantize(value, quant_args):
+                         # only quantize if not already quantized
+                         value = quantize(
+                             x=value,
+                             scale=scale,
+                             zero_point=zp,
+                             args=quant_args,
+                             dtype=quant_args.pytorch_dtype(),
+                         )
+             elif name.endswith("zero_point"):
+                 if torch.all(value == 0):
+                     # all zero_points are 0, no need to include in
+                     # compressed state_dict
+                     continue
+             compressed_dict[name] = value.to("cpu")
+
+         return compressed_dict
+
+     def decompress(
+         self, path_to_model_or_tensors: str, device: str = "cpu", **kwargs
+     ) -> Generator[Tuple[str, Tensor], None, None]:
+         """
+         Reads a compressed state dict located at path_to_model_or_tensors
+         and returns a generator for sequentially decompressing back to a
+         dense state dict
+
+         :param path_to_model_or_tensors: path to compressed safetensors model
+             (directory with one or more safetensors files) or compressed tensors file
+         :param device: optional device to load intermediate weights into
+         :return: generator of (name, tensor) pairs for the decompressed weights
+         """
+         weight_mappings = get_nested_weight_mappings(
+             path_to_model_or_tensors, self.COMPRESSION_PARAM_NAMES
+         )
+         for weight_name in weight_mappings.keys():
+             weight_data = {}
+             for param_name, safe_path in weight_mappings[weight_name].items():
+                 full_name = merge_names(weight_name, param_name)
+                 with safe_open(safe_path, framework="pt", device=device) as f:
+                     weight_data[param_name] = f.get_tensor(full_name)
+
+             if "weight_scale" in weight_data:
+                 zero_point = weight_data.get("weight_zero_point", None)
+                 scale = weight_data["weight_scale"]
+                 decompressed = dequantize(
+                     x_q=weight_data["weight"],
+                     scale=scale,
+                     zero_point=zero_point,
+                 )
+                 yield merge_names(weight_name, "weight"), decompressed
+
+
+ @Compressor.register(name=CompressionFormat.int_quantized.value)
+ class IntQuantizationCompressor(QuantizationCompressor):
+     """
+     Alias for integer quantized models
+     """
+
+     pass
+
+
+ @Compressor.register(name=CompressionFormat.float_quantized.value)
+ class FloatQuantizationCompressor(QuantizationCompressor):
+     """
+     Alias for fp quantized models
+     """
+
+     pass
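The hunk above adds the naive quantized compressor (compressed_tensors/compressors/naive_quantized.py). As a rough illustration of what compress() does to a single layer, the sketch below builds a toy state dict by hand. It is not from the package: the layer name, tensor values, per-tensor scale/zero-point shapes, and the QuantizationArgs settings are assumptions, and the no-argument Compressor constructor is likewise assumed.

    # Illustrative only: one linear layer with per-tensor int8 quantization params.
    import torch

    from compressed_tensors.compressors.naive_quantized import QuantizationCompressor
    from compressed_tensors.quantization import QuantizationArgs

    state_dict = {
        "linear.weight": torch.randn(64, 64),
        "linear.weight_scale": torch.tensor([0.02]),
        "linear.weight_zero_point": torch.tensor([0], dtype=torch.int8),
    }
    compressor = QuantizationCompressor()
    compressed = compressor.compress(
        state_dict,
        names_to_scheme={"linear": QuantizationArgs(num_bits=8, strategy="tensor")},
    )
    # The weight is stored in the quantized dtype, and the all-zero
    # zero_point entry is dropped from the compressed state dict.
    print(compressed["linear.weight"].dtype)           # expected: torch.int8
    print("linear.weight_zero_point" in compressed)    # expected: False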
@@ -0,0 +1,219 @@
+ # Copyright (c) 2021 - present / Neuralmagic, Inc. All Rights Reserved.
+ #
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ #     http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing,
+ # software distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+
+ import logging
+ import math
+ from typing import Dict, Generator, Tuple
+
+ import numpy as np
+ import torch
+ from compressed_tensors.compressors import Compressor
+ from compressed_tensors.config import CompressionFormat
+ from compressed_tensors.quantization import QuantizationArgs
+ from compressed_tensors.quantization.lifecycle.forward import dequantize, quantize
+ from compressed_tensors.quantization.utils import can_quantize
+ from compressed_tensors.utils import get_nested_weight_mappings, merge_names
+ from safetensors import safe_open
+ from torch import Tensor
+ from tqdm import tqdm
+
+
+ __all__ = ["PackedQuantizationCompressor", "pack_to_int32", "unpack_from_int32"]
+
+ _LOGGER: logging.Logger = logging.getLogger(__name__)
+
+
+ @Compressor.register(name=CompressionFormat.pack_quantized.value)
+ class PackedQuantizationCompressor(Compressor):
+     """
+     Compresses a quantized model by packing every eight 4-bit weights into an int32
+     """
+
+     COMPRESSION_PARAM_NAMES = [
+         "weight_packed",
+         "weight_scale",
+         "weight_zero_point",
+         "weight_shape",
+     ]
+
+     def compress(
+         self,
+         model_state: Dict[str, Tensor],
+         names_to_scheme: Dict[str, QuantizationArgs],
+         **kwargs,
+     ) -> Dict[str, Tensor]:
+         """
+         Compresses a dense state dict
+
+         :param model_state: state dict of uncompressed model
+         :param names_to_scheme: quantization args for each quantized weight, needed for
+             quantize function to calculate bit depth
+         :return: compressed state dict
+         """
+         compressed_dict = {}
+         weight_suffix = ".weight"
+         _LOGGER.debug(
+             f"Compressing model with {len(model_state)} parameterized layers..."
+         )
+
+         for name, value in tqdm(model_state.items(), desc="Compressing model"):
+             if name.endswith(weight_suffix):
+                 prefix = name[: -(len(weight_suffix))]
+                 scale = model_state.get(merge_names(prefix, "weight_scale"), None)
+                 zp = model_state.get(merge_names(prefix, "weight_zero_point"), None)
+                 shape = torch.tensor(value.shape)
+                 if scale is not None and zp is not None:
+                     # weight is quantized, compress it
+                     quant_args = names_to_scheme[prefix]
+                     if can_quantize(value, quant_args):
+                         # convert weight to an int if not already compressed
+                         value = quantize(
+                             x=value,
+                             scale=scale,
+                             zero_point=zp,
+                             args=quant_args,
+                             dtype=torch.int8,
+                         )
+                     value = pack_to_int32(value.cpu(), quant_args.num_bits)
+                     compressed_dict[merge_names(prefix, "weight_shape")] = shape
+                     compressed_dict[merge_names(prefix, "weight_packed")] = value
+                     continue
+
+             elif name.endswith("zero_point"):
+                 if torch.all(value == 0):
+                     # all zero_points are 0, no need to include in
+                     # compressed state_dict
+                     continue
+
+             compressed_dict[name] = value.to("cpu")
+
+         return compressed_dict
+
+     def decompress(
+         self,
+         path_to_model_or_tensors: str,
+         names_to_scheme: Dict[str, QuantizationArgs],
+         device: str = "cpu",
+     ) -> Generator[Tuple[str, Tensor], None, None]:
+         """
+         Reads a compressed state dict located at path_to_model_or_tensors
+         and returns a generator for sequentially decompressing back to a
+         dense state dict
+
+         :param path_to_model_or_tensors: path to compressed safetensors model
+             (directory with one or more safetensors files) or compressed tensors file
+         :param names_to_scheme: quantization args for each quantized weight
+         :param device: optional device to load intermediate weights into
+         :return: generator of (name, tensor) pairs for the decompressed weights
+         """
+         weight_mappings = get_nested_weight_mappings(
+             path_to_model_or_tensors, self.COMPRESSION_PARAM_NAMES
+         )
+         for weight_name in weight_mappings.keys():
+             weight_data = {}
+             for param_name, safe_path in weight_mappings[weight_name].items():
+                 weight_data["num_bits"] = names_to_scheme.get(weight_name).num_bits
+                 full_name = merge_names(weight_name, param_name)
+                 with safe_open(safe_path, framework="pt", device=device) as f:
+                     weight_data[param_name] = f.get_tensor(full_name)
+
+             if "weight_scale" in weight_data:
+                 zero_point = weight_data.get("weight_zero_point", None)
+                 scale = weight_data["weight_scale"]
+                 weight = weight_data["weight_packed"]
+                 num_bits = weight_data["num_bits"]
+                 original_shape = torch.Size(weight_data["weight_shape"])
+                 unpacked = unpack_from_int32(weight, num_bits, original_shape)
+                 decompressed = dequantize(
+                     x_q=unpacked,
+                     scale=scale,
+                     zero_point=zero_point,
+                 )
+                 yield merge_names(weight_name, "weight"), decompressed
+
+
+ def pack_to_int32(value: torch.Tensor, num_bits: int) -> torch.Tensor:
+     """
+     Packs a tensor of quantized weights stored in int8 into int32s with padding
+
+     :param value: tensor to pack
+     :param num_bits: number of bits used to store underlying data
+     :returns: packed int32 tensor
+     """
+     if value.dtype is not torch.int8:
+         raise ValueError("Tensor must be quantized to torch.int8 before packing")
+
+     if num_bits > 8:
+         raise ValueError("Packing is only supported for 8 or fewer bits")
+
+     # convert to unsigned for packing
+     offset = pow(2, num_bits) // 2
+     value = (value + offset).to(torch.uint8)
+     value = value.cpu().numpy().astype(np.uint32)
+     pack_factor = 32 // num_bits
+
+     # pad input tensor and initialize packed output
+     packed_size = math.ceil(value.shape[1] / pack_factor)
+     packed = np.zeros((value.shape[0], packed_size), dtype=np.uint32)
+     padding = packed.shape[1] * pack_factor - value.shape[1]
+     value = np.pad(value, pad_width=[(0, 0), (0, padding)], constant_values=0)
+
+     # pack values
+     for i in range(pack_factor):
+         packed |= value[:, i::pack_factor] << num_bits * i
+
+     # convert back to signed and torch
+     packed = np.ascontiguousarray(packed).view(np.int32)
+     return torch.from_numpy(packed)
+
+
+ def unpack_from_int32(
+     value: torch.Tensor, num_bits: int, shape: torch.Size
+ ) -> torch.Tensor:
+     """
+     Unpacks a tensor of packed int32 weights into individual int8s, maintaining
+     their original bit range
+
+     :param value: tensor to unpack
+     :param num_bits: number of bits to unpack each data point into
+     :param shape: shape to unpack into, used to remove padding
+     :returns: unpacked int8 tensor
+     """
+     if value.dtype is not torch.int32:
+         raise ValueError(
+             f"Expected {torch.int32} but got {value.dtype}, Aborting unpack."
+         )
+
+     if num_bits > 8:
+         raise ValueError("Unpacking is only supported for 8 or fewer bits")
+
+     # convert packed input to unsigned numpy
+     value = value.numpy().view(np.uint32)
+     pack_factor = 32 // num_bits
+
+     # unpack
+     mask = pow(2, num_bits) - 1
+     unpacked = np.zeros((value.shape[0], value.shape[1] * pack_factor))
+     for i in range(pack_factor):
+         unpacked[:, i::pack_factor] = (value >> (num_bits * i)) & mask
+
+     # remove padding
+     original_row_size = int(shape[1])
+     unpacked = unpacked[:, :original_row_size]
+
+     # bits are packed in unsigned format, reformat to signed
+     # update the value range from unsigned to signed
+     offset = pow(2, num_bits) // 2
+     unpacked = (unpacked.astype(np.int16) - offset).astype(np.int8)
+
+     return torch.from_numpy(unpacked)
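The hunk above adds the packed quantization compressor (compressed_tensors/compressors/pack_quantized.py). Its pack_to_int32/unpack_from_int32 helpers are self-contained, so their behavior can be checked with a small round trip; the shapes and values in the sketch below are arbitrary and chosen only for illustration.

    # Illustrative only: eight 4-bit values fit in each int32, and the padding
    # added during packing is stripped on unpack using the original shape.
    import torch

    from compressed_tensors.compressors.pack_quantized import (
        pack_to_int32,
        unpack_from_int32,
    )

    weights = torch.randint(-8, 8, (4, 10), dtype=torch.int8)  # 4-bit range [-8, 7]
    packed = pack_to_int32(weights, num_bits=4)                # shape (4, 2), dtype int32
    restored = unpack_from_int32(packed, num_bits=4, shape=weights.shape)
    assert torch.equal(restored, weights)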