compressed-tensors 0.3.3__py3-none-any.whl → 0.5.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- compressed_tensors/base.py +3 -1
- compressed_tensors/compressors/__init__.py +9 -1
- compressed_tensors/compressors/base.py +12 -55
- compressed_tensors/compressors/dense.py +5 -5
- compressed_tensors/compressors/helpers.py +12 -12
- compressed_tensors/compressors/marlin_24.py +251 -0
- compressed_tensors/compressors/model_compressor.py +336 -0
- compressed_tensors/compressors/naive_quantized.py +144 -0
- compressed_tensors/compressors/pack_quantized.py +219 -0
- compressed_tensors/compressors/sparse_bitmask.py +4 -4
- compressed_tensors/config/base.py +9 -4
- compressed_tensors/config/dense.py +4 -4
- compressed_tensors/config/sparse_bitmask.py +3 -3
- compressed_tensors/quantization/lifecycle/__init__.py +2 -0
- compressed_tensors/quantization/lifecycle/apply.py +204 -31
- compressed_tensors/quantization/lifecycle/calibration.py +20 -1
- compressed_tensors/quantization/lifecycle/compressed.py +69 -0
- compressed_tensors/quantization/lifecycle/forward.py +214 -62
- compressed_tensors/quantization/lifecycle/frozen.py +4 -0
- compressed_tensors/quantization/lifecycle/helpers.py +53 -0
- compressed_tensors/quantization/lifecycle/initialize.py +62 -5
- compressed_tensors/quantization/observers/base.py +66 -23
- compressed_tensors/quantization/observers/helpers.py +69 -11
- compressed_tensors/quantization/observers/memoryless.py +17 -9
- compressed_tensors/quantization/observers/min_max.py +44 -13
- compressed_tensors/quantization/quant_args.py +47 -3
- compressed_tensors/quantization/quant_config.py +104 -23
- compressed_tensors/quantization/quant_scheme.py +183 -2
- compressed_tensors/quantization/utils/helpers.py +142 -8
- compressed_tensors/utils/__init__.py +4 -0
- compressed_tensors/utils/helpers.py +54 -7
- compressed_tensors/utils/offload.py +104 -0
- compressed_tensors/utils/permutations_24.py +65 -0
- compressed_tensors/utils/safetensors_load.py +3 -2
- compressed_tensors/utils/semi_structured_conversions.py +341 -0
- compressed_tensors/version.py +53 -0
- {compressed_tensors-0.3.3.dist-info → compressed_tensors-0.5.0.dist-info}/METADATA +47 -8
- compressed_tensors-0.5.0.dist-info/RECORD +48 -0
- {compressed_tensors-0.3.3.dist-info → compressed_tensors-0.5.0.dist-info}/WHEEL +1 -1
- compressed_tensors-0.3.3.dist-info/RECORD +0 -38
- {compressed_tensors-0.3.3.dist-info → compressed_tensors-0.5.0.dist-info}/LICENSE +0 -0
- {compressed_tensors-0.3.3.dist-info → compressed_tensors-0.5.0.dist-info}/top_level.txt +0 -0
compressed_tensors/compressors/model_compressor.py (new file)
@@ -0,0 +1,336 @@
# Copyright (c) 2021 - present / Neuralmagic, Inc. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#    http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import json
import logging
import operator
import os
import re
from copy import deepcopy
from typing import Any, Dict, Optional, Union

import torch
import transformers
from compressed_tensors.base import (
    COMPRESSION_CONFIG_NAME,
    QUANTIZATION_CONFIG_NAME,
    SPARSITY_CONFIG_NAME,
)
from compressed_tensors.compressors import Compressor
from compressed_tensors.config import SparsityCompressionConfig
from compressed_tensors.quantization import (
    QuantizationConfig,
    QuantizationStatus,
    apply_quantization_config,
    load_pretrained_quantization,
)
from compressed_tensors.quantization.utils import (
    is_module_quantized,
    iter_named_leaf_modules,
)
from compressed_tensors.utils import get_safetensors_folder, update_parameter_data
from compressed_tensors.utils.helpers import fix_fsdp_module_name
from torch import Tensor
from torch.nn import Module
from tqdm import tqdm
from transformers import AutoConfig
from transformers.file_utils import CONFIG_NAME


__all__ = ["ModelCompressor", "map_modules_to_quant_args"]

_LOGGER: logging.Logger = logging.getLogger(__name__)


class ModelCompressor:
    """
    Handles compression and decompression of a model with a sparsity config and/or
    quantization config.

    Compression LifeCycle
        - compressor = ModelCompressor.from_pretrained_model(model)
        - compressed_state_dict = compressor.compress(model, state_dict)
            - compressor.quantization_compressor.compress(model, state_dict)
            - compressor.sparsity_compressor.compress(model, state_dict)
        - model.save_pretrained(output_dir, state_dict=compressed_state_dict)
        - compressor.update_config(output_dir)

    Decompression LifeCycle
        - compressor = ModelCompressor.from_pretrained(comp_model_path)
        - model = AutoModel.from_pretrained(comp_model_path)
        - compressor.decompress(comp_model_path, model)
            - compressor.sparsity_compressor.decompress(comp_model_path, model)
            - compressor.quantization_compressor.decompress(comp_model_path, model)

    :param sparsity_config: config specifying sparsity compression parameters
    :param quantization_config: config specifying quantization compression parameters
    """

    @classmethod
    def from_pretrained(
        cls,
        pretrained_model_name_or_path: str,
        **kwargs,
    ) -> Optional["ModelCompressor"]:
        """
        Given a path to a model config, extract the sparsity and/or quantization
        configs and load a ModelCompressor

        :param pretrained_model_name_or_path: path to model config on disk or HF hub
        :return: compressor for the extracted configs
        """
        config = AutoConfig.from_pretrained(pretrained_model_name_or_path, **kwargs)
        compression_config = getattr(config, COMPRESSION_CONFIG_NAME, None)
        return cls.from_compression_config(compression_config)

    @classmethod
    def from_compression_config(cls, compression_config: Dict[str, Any]):
        """
        :param compression_config: compression/quantization config dictionary
            found under key "quantization_config" in HF model config
        :return: compressor for the extracted configs
        """
        if compression_config is None:
            return None

        try:
            from transformers.utils.quantization_config import CompressedTensorsConfig

            if isinstance(compression_config, CompressedTensorsConfig):
                compression_config = compression_config.to_dict()
        except ImportError:
            pass

        sparsity_config = cls.parse_sparsity_config(compression_config)
        quantization_config = cls.parse_quantization_config(compression_config)
        if sparsity_config is None and quantization_config is None:
            return None

        if sparsity_config is not None and not isinstance(
            sparsity_config, SparsityCompressionConfig
        ):
            format = sparsity_config.get("format")
            sparsity_config = SparsityCompressionConfig.load_from_registry(
                format, **sparsity_config
            )
        if quantization_config is not None and not isinstance(
            quantization_config, QuantizationConfig
        ):
            quantization_config = QuantizationConfig.parse_obj(quantization_config)

        return cls(
            sparsity_config=sparsity_config, quantization_config=quantization_config
        )

    @classmethod
    def from_pretrained_model(
        cls,
        model: Module,
        sparsity_config: Union[SparsityCompressionConfig, str, None] = None,
        quantization_format: Optional[str] = None,
    ) -> Optional["ModelCompressor"]:
        """
        Given a pytorch model and optional sparsity and/or quantization configs,
        load the appropriate compressors

        :param model: pytorch model to target for compression
        :param sparsity_config: a filled in sparsity config or string corresponding
            to a sparsity compression algorithm
        :param quantization_format: string corresponding to a quantization compression
            algorithm
        :return: compressor for the extracted configs
        """
        quantization_config = QuantizationConfig.from_pretrained(
            model, format=quantization_format
        )

        if isinstance(sparsity_config, str):  # we passed in a sparsity format
            sparsity_config = SparsityCompressionConfig.load_from_registry(
                sparsity_config
            )

        if sparsity_config is None and quantization_config is None:
            return None

        return cls(
            sparsity_config=sparsity_config, quantization_config=quantization_config
        )

    @staticmethod
    def parse_sparsity_config(compression_config: Dict) -> Union[Dict, None]:
        if compression_config is None:
            return None
        if SPARSITY_CONFIG_NAME not in compression_config:
            return None
        if hasattr(compression_config, SPARSITY_CONFIG_NAME):
            # for loaded HFQuantizer config
            return getattr(compression_config, SPARSITY_CONFIG_NAME)

        # SparseAutoModel format
        return compression_config.get(SPARSITY_CONFIG_NAME, None)

    @staticmethod
    def parse_quantization_config(compression_config: Dict) -> Union[Dict, None]:
        if compression_config is None:
            return None

        if hasattr(compression_config, QUANTIZATION_CONFIG_NAME):
            # for loaded HFQuantizer config
            return getattr(compression_config, QUANTIZATION_CONFIG_NAME)

        # SparseAutoModel format
        quantization_config = deepcopy(compression_config)
        quantization_config.pop(SPARSITY_CONFIG_NAME, None)
        if len(quantization_config) == 0:
            quantization_config = None
        return quantization_config

    def __init__(
        self,
        sparsity_config: Optional[SparsityCompressionConfig] = None,
        quantization_config: Optional[QuantizationConfig] = None,
    ):
        self.sparsity_config = sparsity_config
        self.quantization_config = quantization_config
        self.sparsity_compressor = None
        self.quantization_compressor = None

        if sparsity_config is not None:
            self.sparsity_compressor = Compressor.load_from_registry(
                sparsity_config.format, config=sparsity_config
            )
        if quantization_config is not None:
            self.quantization_compressor = Compressor.load_from_registry(
                quantization_config.format, config=quantization_config
            )

    def compress(
        self, model: Module, state_dict: Optional[Dict[str, Tensor]] = None
    ) -> Dict[str, Tensor]:
        """
        Compresses a dense state dict or model with sparsity and/or quantization

        :param model: uncompressed model to compress
        :param state_dict: optional uncompressed state_dict to compress; if not
            provided, model.state_dict() is used
        :return: compressed state dict
        """
        if state_dict is None:
            state_dict = model.state_dict()

        compressed_state_dict = state_dict
        quantized_modules_to_args = map_modules_to_quant_args(model)
        if self.quantization_compressor is not None:
            compressed_state_dict = self.quantization_compressor.compress(
                state_dict, names_to_scheme=quantized_modules_to_args
            )

        if self.sparsity_compressor is not None:
            compressed_state_dict = self.sparsity_compressor.compress(
                compressed_state_dict
            )

        # HACK: Override the dtype_byte_size function in transformers to
        # support float8 types. Fix is posted upstream
        # https://github.com/huggingface/transformers/pull/30488
        transformers.modeling_utils.dtype_byte_size = new_dtype_byte_size

        return compressed_state_dict

    def decompress(self, model_path: str, model: Module):
        """
        Overwrites the weights in model with weights decompressed from model_path

        :param model_path: path to compressed weights
        :param model: pytorch model to load decompressed weights into
        """
        model_path = get_safetensors_folder(model_path)
        if self.sparsity_compressor is not None:
            dense_gen = self.sparsity_compressor.decompress(model_path)
            self._replace_weights(dense_gen, model)
            setattr(model, SPARSITY_CONFIG_NAME, self.sparsity_compressor.config)

        if self.quantization_compressor is not None:
            names_to_scheme = apply_quantization_config(model, self.quantization_config)
            load_pretrained_quantization(model, model_path)
            dense_gen = self.quantization_compressor.decompress(
                model_path, names_to_scheme=names_to_scheme
            )
            self._replace_weights(dense_gen, model)

            def update_status(module):
                module.quantization_status = QuantizationStatus.FROZEN

            model.apply(update_status)
            setattr(model, QUANTIZATION_CONFIG_NAME, self.quantization_config)

    def update_config(self, save_directory: str):
        """
        Update the model config located at save_directory with compression configs
        for sparsity and/or quantization

        :param save_directory: path to a folder containing a HF model config
        """
        config_file_path = os.path.join(save_directory, CONFIG_NAME)
        if not os.path.exists(config_file_path):
            _LOGGER.warning(
                f"Could not find a valid model config file in "
                f"{save_directory}. Compression config will not be saved."
            )
            return

        with open(config_file_path, "r") as config_file:
            config_data = json.load(config_file)

        config_data[COMPRESSION_CONFIG_NAME] = {}
        if self.quantization_config is not None:
            quant_config_data = self.quantization_config.model_dump()
            config_data[COMPRESSION_CONFIG_NAME] = quant_config_data
        if self.sparsity_config is not None:
            sparsity_config_data = self.sparsity_config.model_dump()
            config_data[COMPRESSION_CONFIG_NAME][
                SPARSITY_CONFIG_NAME
            ] = sparsity_config_data

        with open(config_file_path, "w") as config_file:
            json.dump(config_data, config_file, indent=2, sort_keys=True)

    def _replace_weights(self, dense_weight_generator, model):
        for name, data in tqdm(dense_weight_generator, desc="Decompressing model"):
            split_name = name.split(".")
            prefix, param_name = ".".join(split_name[:-1]), split_name[-1]
            module = operator.attrgetter(prefix)(model)
            update_parameter_data(module, data, param_name)


def map_modules_to_quant_args(model: Module) -> Dict:
    quantized_modules_to_args = {}
    for name, submodule in iter_named_leaf_modules(model):
        if is_module_quantized(submodule):
            if submodule.quantization_scheme.weights is not None:
                name = fix_fsdp_module_name(name)
                quantized_modules_to_args[name] = submodule.quantization_scheme.weights

    return quantized_modules_to_args


# HACK: Override the dtype_byte_size function in transformers to support float8 types
# Fix is posted upstream https://github.com/huggingface/transformers/pull/30488
def new_dtype_byte_size(dtype):
    if dtype == torch.bool:
        return 1 / 8
    bit_search = re.search(r"[^\d](\d+)_?", str(dtype))
    if bit_search is None:
        raise ValueError(f"`dtype` is not a valid dtype: {dtype}.")
    bit_size = int(bit_search.groups()[0])
    return bit_size // 8
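For orientation, here is a minimal usage sketch of the lifecycle described in the ModelCompressor docstring above. This is a sketch, not documented package API: it assumes ModelCompressor is re-exported from compressed_tensors.compressors, that the loaded model is a Hugging Face PreTrainedModel already carrying a compressed-tensors quantization and/or sparsity config, and that the model name and output directory are placeholders.

# Sketch only; "org/quantized-model" and "./compressed-out" are placeholders.
from compressed_tensors.compressors import ModelCompressor
from transformers import AutoModelForCausalLM

model = AutoModelForCausalLM.from_pretrained("org/quantized-model")

# Compression: build a compressor from the model's configs, compress, save, update config
compressor = ModelCompressor.from_pretrained_model(model)
if compressor is not None:  # None when neither a sparsity nor a quantization config is found
    compressed_state_dict = compressor.compress(model)
    model.save_pretrained("./compressed-out", state_dict=compressed_state_dict)
    compressor.update_config("./compressed-out")

# Decompression: rebuild the compressor from the saved config and restore dense weights
compressor = ModelCompressor.from_pretrained("./compressed-out")
if compressor is not None:
    compressor.decompress("./compressed-out", model)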
compressed_tensors/compressors/naive_quantized.py (new file)
@@ -0,0 +1,144 @@
# Copyright (c) 2021 - present / Neuralmagic, Inc. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#    http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import logging
from typing import Dict, Generator, Tuple

import torch
from compressed_tensors.compressors import Compressor
from compressed_tensors.config import CompressionFormat
from compressed_tensors.quantization import QuantizationArgs
from compressed_tensors.quantization.lifecycle.forward import dequantize, quantize
from compressed_tensors.quantization.utils import can_quantize
from compressed_tensors.utils import get_nested_weight_mappings, merge_names
from safetensors import safe_open
from torch import Tensor
from tqdm import tqdm


__all__ = [
    "QuantizationCompressor",
    "IntQuantizationCompressor",
    "FloatQuantizationCompressor",
]

_LOGGER: logging.Logger = logging.getLogger(__name__)


@Compressor.register(name=CompressionFormat.naive_quantized.value)
class QuantizationCompressor(Compressor):
    """
    Implements naive compression for quantized models. The weight of each
    quantized layer is converted from its original float type to the closest PyTorch
    type to the type specified by the layer's QuantizationArgs.
    """

    COMPRESSION_PARAM_NAMES = ["weight", "weight_scale", "weight_zero_point"]

    def compress(
        self,
        model_state: Dict[str, Tensor],
        names_to_scheme: Dict[str, QuantizationArgs],
        **kwargs,
    ) -> Dict[str, Tensor]:
        """
        Compresses a dense state dict

        :param model_state: state dict of uncompressed model
        :param names_to_scheme: quantization args for each quantized weight, needed for
            quantize function to calculate bit depth
        :return: compressed state dict
        """
        compressed_dict = {}
        weight_suffix = ".weight"
        _LOGGER.debug(
            f"Compressing model with {len(model_state)} parameterized layers..."
        )

        for name, value in tqdm(model_state.items(), desc="Compressing model"):
            if name.endswith(weight_suffix):
                prefix = name[: -(len(weight_suffix))]
                scale = model_state.get(merge_names(prefix, "weight_scale"), None)
                zp = model_state.get(merge_names(prefix, "weight_zero_point"), None)
                if scale is not None and zp is not None:
                    # weight is quantized, compress it
                    quant_args = names_to_scheme[prefix]
                    if can_quantize(value, quant_args):
                        # only quantize if not already quantized
                        value = quantize(
                            x=value,
                            scale=scale,
                            zero_point=zp,
                            args=quant_args,
                            dtype=quant_args.pytorch_dtype(),
                        )
            elif name.endswith("zero_point"):
                if torch.all(value == 0):
                    # all zero_points are 0, no need to include in
                    # compressed state_dict
                    continue
            compressed_dict[name] = value.to("cpu")

        return compressed_dict

    def decompress(
        self, path_to_model_or_tensors: str, device: str = "cpu", **kwargs
    ) -> Generator[Tuple[str, Tensor], None, None]:
        """
        Reads a compressed state dict located at path_to_model_or_tensors
        and returns a generator for sequentially decompressing back to a
        dense state dict

        :param path_to_model_or_tensors: path to compressed safetensors model
            (directory with one or more safetensors files) or compressed tensors file
        :param device: optional device to load intermediate weights into
        :return: generator of decompressed (parameter name, tensor) pairs
        """
        weight_mappings = get_nested_weight_mappings(
            path_to_model_or_tensors, self.COMPRESSION_PARAM_NAMES
        )
        for weight_name in weight_mappings.keys():
            weight_data = {}
            for param_name, safe_path in weight_mappings[weight_name].items():
                full_name = merge_names(weight_name, param_name)
                with safe_open(safe_path, framework="pt", device=device) as f:
                    weight_data[param_name] = f.get_tensor(full_name)

            if "weight_scale" in weight_data:
                zero_point = weight_data.get("weight_zero_point", None)
                scale = weight_data["weight_scale"]
                decompressed = dequantize(
                    x_q=weight_data["weight"],
                    scale=scale,
                    zero_point=zero_point,
                )
                yield merge_names(weight_name, "weight"), decompressed


@Compressor.register(name=CompressionFormat.int_quantized.value)
class IntQuantizationCompressor(QuantizationCompressor):
    """
    Alias for integer quantized models
    """

    pass


@Compressor.register(name=CompressionFormat.float_quantized.value)
class FloatQuantizationCompressor(QuantizationCompressor):
    """
    Alias for fp quantized models
    """

    pass
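As a small illustration (not taken from the package) of the state-dict layout that QuantizationCompressor.compress iterates over, the snippet below builds the three entries the compressor looks for on one quantized layer; the layer name and values are made up, and the comments restate what the loop above does.

import torch

# Hypothetical dense entries for a single quantized Linear layer, following the
# "weight" / "weight_scale" / "weight_zero_point" naming used above.
dense_state = {
    "model.layers.0.mlp.down_proj.weight": torch.randn(16, 16),
    "model.layers.0.mlp.down_proj.weight_scale": torch.tensor(0.02),
    "model.layers.0.mlp.down_proj.weight_zero_point": torch.zeros(1, dtype=torch.int8),
}
# compress() would quantize the "weight" entry to quant_args.pytorch_dtype(), keep the
# scale as-is, and omit the zero_point entry entirely because it is all zeros.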
compressed_tensors/compressors/pack_quantized.py (new file)
@@ -0,0 +1,219 @@
# Copyright (c) 2021 - present / Neuralmagic, Inc. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#    http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import logging
import math
from typing import Dict, Generator, Tuple

import numpy as np
import torch
from compressed_tensors.compressors import Compressor
from compressed_tensors.config import CompressionFormat
from compressed_tensors.quantization import QuantizationArgs
from compressed_tensors.quantization.lifecycle.forward import dequantize, quantize
from compressed_tensors.quantization.utils import can_quantize
from compressed_tensors.utils import get_nested_weight_mappings, merge_names
from safetensors import safe_open
from torch import Tensor
from tqdm import tqdm


__all__ = ["PackedQuantizationCompressor", "pack_to_int32", "unpack_from_int32"]

_LOGGER: logging.Logger = logging.getLogger(__name__)


@Compressor.register(name=CompressionFormat.pack_quantized.value)
class PackedQuantizationCompressor(Compressor):
    """
    Compresses a quantized model by packing every eight 4-bit weights into an int32
    """

    COMPRESSION_PARAM_NAMES = [
        "weight_packed",
        "weight_scale",
        "weight_zero_point",
        "weight_shape",
    ]

    def compress(
        self,
        model_state: Dict[str, Tensor],
        names_to_scheme: Dict[str, QuantizationArgs],
        **kwargs,
    ) -> Dict[str, Tensor]:
        """
        Compresses a dense state dict

        :param model_state: state dict of uncompressed model
        :param names_to_scheme: quantization args for each quantized weight, needed for
            quantize function to calculate bit depth
        :return: compressed state dict
        """
        compressed_dict = {}
        weight_suffix = ".weight"
        _LOGGER.debug(
            f"Compressing model with {len(model_state)} parameterized layers..."
        )

        for name, value in tqdm(model_state.items(), desc="Compressing model"):
            if name.endswith(weight_suffix):
                prefix = name[: -(len(weight_suffix))]
                scale = model_state.get(merge_names(prefix, "weight_scale"), None)
                zp = model_state.get(merge_names(prefix, "weight_zero_point"), None)
                shape = torch.tensor(value.shape)
                if scale is not None and zp is not None:
                    # weight is quantized, compress it
                    quant_args = names_to_scheme[prefix]
                    if can_quantize(value, quant_args):
                        # convert weight to an int if not already compressed
                        value = quantize(
                            x=value,
                            scale=scale,
                            zero_point=zp,
                            args=quant_args,
                            dtype=torch.int8,
                        )
                    value = pack_to_int32(value.cpu(), quant_args.num_bits)
                    compressed_dict[merge_names(prefix, "weight_shape")] = shape
                    compressed_dict[merge_names(prefix, "weight_packed")] = value
                    continue

            elif name.endswith("zero_point"):
                if torch.all(value == 0):
                    # all zero_points are 0, no need to include in
                    # compressed state_dict
                    continue

            compressed_dict[name] = value.to("cpu")

        return compressed_dict

    def decompress(
        self,
        path_to_model_or_tensors: str,
        names_to_scheme: Dict[str, QuantizationArgs],
        device: str = "cpu",
    ) -> Generator[Tuple[str, Tensor], None, None]:
        """
        Reads a compressed state dict located at path_to_model_or_tensors
        and returns a generator for sequentially decompressing back to a
        dense state dict

        :param path_to_model_or_tensors: path to compressed safetensors model
            (directory with one or more safetensors files) or compressed tensors file
        :param names_to_scheme: quantization args for each quantized weight
        :param device: optional device to load intermediate weights into
        :return: generator of decompressed (parameter name, tensor) pairs
        """
        weight_mappings = get_nested_weight_mappings(
            path_to_model_or_tensors, self.COMPRESSION_PARAM_NAMES
        )
        for weight_name in weight_mappings.keys():
            weight_data = {}
            for param_name, safe_path in weight_mappings[weight_name].items():
                weight_data["num_bits"] = names_to_scheme.get(weight_name).num_bits
                full_name = merge_names(weight_name, param_name)
                with safe_open(safe_path, framework="pt", device=device) as f:
                    weight_data[param_name] = f.get_tensor(full_name)

            if "weight_scale" in weight_data:
                zero_point = weight_data.get("weight_zero_point", None)
                scale = weight_data["weight_scale"]
                weight = weight_data["weight_packed"]
                num_bits = weight_data["num_bits"]
                original_shape = torch.Size(weight_data["weight_shape"])
                unpacked = unpack_from_int32(weight, num_bits, original_shape)
                decompressed = dequantize(
                    x_q=unpacked,
                    scale=scale,
                    zero_point=zero_point,
                )
                yield merge_names(weight_name, "weight"), decompressed


def pack_to_int32(value: torch.Tensor, num_bits: int) -> torch.Tensor:
    """
    Packs a tensor of quantized weights stored in int8 into int32s with padding

    :param value: tensor to pack
    :param num_bits: number of bits used to store underlying data
    :returns: packed int32 tensor
    """
    if value.dtype is not torch.int8:
        raise ValueError("Tensor must be quantized to torch.int8 before packing")

    if num_bits > 8:
        raise ValueError("Packing is only supported for less than 8 bits")

    # convert to unsigned for packing
    offset = pow(2, num_bits) // 2
    value = (value + offset).to(torch.uint8)
    value = value.cpu().numpy().astype(np.uint32)
    pack_factor = 32 // num_bits

    # pad input tensor and initialize packed output
    packed_size = math.ceil(value.shape[1] / pack_factor)
    packed = np.zeros((value.shape[0], packed_size), dtype=np.uint32)
    padding = packed.shape[1] * pack_factor - value.shape[1]
    value = np.pad(value, pad_width=[(0, 0), (0, padding)], constant_values=0)

    # pack values
    for i in range(pack_factor):
        packed |= value[:, i::pack_factor] << num_bits * i

    # convert back to signed and torch
    packed = np.ascontiguousarray(packed).view(np.int32)
    return torch.from_numpy(packed)


def unpack_from_int32(
    value: torch.Tensor, num_bits: int, shape: torch.Size
) -> torch.Tensor:
    """
    Unpacks a tensor of packed int32 weights into individual int8s, maintaining
    their original bit range

    :param value: tensor to unpack
    :param num_bits: number of bits to unpack each data point into
    :param shape: shape to unpack into, used to remove padding
    :returns: unpacked int8 tensor
    """
    if value.dtype is not torch.int32:
        raise ValueError(
            f"Expected {torch.int32} but got {value.dtype}, Aborting unpack."
        )

    if num_bits > 8:
        raise ValueError("Unpacking is only supported for less than 8 bits")

    # convert packed input to unsigned numpy
    value = value.numpy().view(np.uint32)
    pack_factor = 32 // num_bits

    # unpack
    mask = pow(2, num_bits) - 1
    unpacked = np.zeros((value.shape[0], value.shape[1] * pack_factor))
    for i in range(pack_factor):
        unpacked[:, i::pack_factor] = (value >> (num_bits * i)) & mask

    # remove padding
    original_row_size = int(shape[1])
    unpacked = unpacked[:, :original_row_size]

    # bits are packed in unsigned format, reformat to signed
    # update the value range from unsigned to signed
    offset = pow(2, num_bits) // 2
    unpacked = (unpacked.astype(np.int16) - offset).astype(np.int8)

    return torch.from_numpy(unpacked)
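The packing helpers above are self-contained, so their round-trip behavior can be checked directly. A short sketch, assuming the module path compressed_tensors.compressors.pack_quantized matches the file location:

import torch
from compressed_tensors.compressors.pack_quantized import pack_to_int32, unpack_from_int32

num_bits = 4
weights = torch.randint(-8, 8, (2, 11), dtype=torch.int8)     # 4-bit range; odd row length forces padding
packed = pack_to_int32(weights, num_bits)                      # (2, 2) int32, eight 4-bit values per int32
restored = unpack_from_int32(packed, num_bits, weights.shape)  # padding stripped, signed range restored
assert torch.equal(weights, restored)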