compressed-tensors 0.4.0__py3-none-any.whl → 0.6.0__py3-none-any.whl
This diff shows the changes between two publicly released versions of the package as they appear in their public registries. It is provided for informational purposes only.
- compressed_tensors/base.py +1 -0
- compressed_tensors/compressors/__init__.py +5 -1
- compressed_tensors/compressors/base.py +200 -8
- compressed_tensors/compressors/dense.py +1 -1
- compressed_tensors/compressors/marlin_24.py +11 -10
- compressed_tensors/compressors/model_compressor.py +101 -13
- compressed_tensors/compressors/naive_quantized.py +140 -0
- compressed_tensors/compressors/pack_quantized.py +128 -132
- compressed_tensors/compressors/sparse_bitmask.py +1 -1
- compressed_tensors/config/base.py +8 -1
- compressed_tensors/{compressors/utils → linear}/__init__.py +0 -6
- compressed_tensors/linear/compressed_linear.py +87 -0
- compressed_tensors/quantization/lifecycle/__init__.py +1 -0
- compressed_tensors/quantization/lifecycle/apply.py +204 -44
- compressed_tensors/quantization/lifecycle/calibration.py +22 -2
- compressed_tensors/quantization/lifecycle/compressed.py +3 -1
- compressed_tensors/quantization/lifecycle/forward.py +139 -61
- compressed_tensors/quantization/lifecycle/helpers.py +80 -0
- compressed_tensors/quantization/lifecycle/initialize.py +77 -13
- compressed_tensors/quantization/observers/__init__.py +1 -0
- compressed_tensors/quantization/observers/base.py +93 -14
- compressed_tensors/quantization/observers/helpers.py +64 -11
- compressed_tensors/quantization/observers/min_max.py +8 -0
- compressed_tensors/quantization/observers/mse.py +162 -0
- compressed_tensors/quantization/quant_args.py +139 -23
- compressed_tensors/quantization/quant_config.py +35 -2
- compressed_tensors/quantization/quant_scheme.py +112 -13
- compressed_tensors/quantization/utils/helpers.py +68 -2
- compressed_tensors/utils/__init__.py +5 -0
- compressed_tensors/utils/helpers.py +44 -2
- compressed_tensors/utils/offload.py +116 -0
- compressed_tensors/utils/permute.py +70 -0
- compressed_tensors/utils/safetensors_load.py +2 -0
- compressed_tensors/{compressors/utils → utils}/semi_structured_conversions.py +1 -0
- compressed_tensors/version.py +1 -1
- {compressed_tensors-0.4.0.dist-info → compressed_tensors-0.6.0.dist-info}/METADATA +35 -22
- compressed_tensors-0.6.0.dist-info/RECORD +52 -0
- {compressed_tensors-0.4.0.dist-info → compressed_tensors-0.6.0.dist-info}/WHEEL +1 -1
- compressed_tensors/compressors/int_quantized.py +0 -126
- compressed_tensors/compressors/utils/helpers.py +0 -43
- compressed_tensors-0.4.0.dist-info/RECORD +0 -48
- /compressed_tensors/{compressors/utils → utils}/permutations_24.py +0 -0
- {compressed_tensors-0.4.0.dist-info → compressed_tensors-0.6.0.dist-info}/LICENSE +0 -0
- {compressed_tensors-0.4.0.dist-info → compressed_tensors-0.6.0.dist-info}/top_level.txt +0 -0
compressed_tensors/base.py CHANGED

compressed_tensors/compressors/__init__.py CHANGED
@@ -17,8 +17,12 @@
 from .base import Compressor
 from .dense import DenseCompressor
 from .helpers import load_compressed, save_compressed, save_compressed_model
-from .int_quantized import IntQuantizationCompressor
 from .marlin_24 import Marlin24Compressor
 from .model_compressor import ModelCompressor, map_modules_to_quant_args
+from .naive_quantized import (
+    FloatQuantizationCompressor,
+    IntQuantizationCompressor,
+    QuantizationCompressor,
+)
 from .pack_quantized import PackedQuantizationCompressor
 from .sparse_bitmask import BitmaskCompressor, BitmaskTensor
compressed_tensors/compressors/base.py CHANGED
@@ -12,20 +12,53 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.

-
+import logging
+from typing import Dict, Generator, Optional, Tuple, Union

+import torch
 from compressed_tensors.config import SparsityCompressionConfig
-from compressed_tensors.quantization import QuantizationConfig
+from compressed_tensors.quantization import QuantizationArgs, QuantizationConfig
 from compressed_tensors.registry import RegistryMixin
+from compressed_tensors.utils import get_nested_weight_mappings, merge_names
+from safetensors import safe_open
 from torch import Tensor
+from torch.nn.modules import Module
+from tqdm import tqdm


+_LOGGER: logging.Logger = logging.getLogger(__name__)
+
 __all__ = ["Compressor"]


 class Compressor(RegistryMixin):
     """
-    Base class representing a model compression algorithm
+    Base class representing a model compression algorithm. Each child class should
+    implement compression_param_info, compress_weight and decompress_weight.
+
+    Compressors support compressing/decompressing a full module state dict or a single
+    quantized PyTorch leaf module.
+
+    Model Load Lifecycle (run_compressed=False):
+        - ModelCompressor.decompress()
+        - apply_quantization_config()
+        - Compressor.decompress()
+        - Compressor.decompress_weight()
+
+    Model Save Lifecycle:
+        - ModelCompressor.compress()
+        - Compressor.compress()
+        - Compressor.compress_weight()
+
+    Module Lifecycle (run_compressed=True):
+        - apply_quantization_config()
+        - compressed_module = CompressedLinear(module)
+        - initialize_module_for_quantization()
+        - Compressor.compression_param_info()
+        - register_parameters()
+        - compressed_module.forward()
+        - compressed_module.decompress()
+

     :param config: config specifying compression parameters
     """
@@ -35,26 +68,185 @@ class Compressor(RegistryMixin):
     ):
         self.config = config

-    def
+    def compression_param_info(
+        self,
+        weight_shape: torch.Size,
+        quantization_args: Optional[QuantizationArgs] = None,
+    ) -> Dict[str, Tuple[torch.Size, torch.dtype]]:
+        """
+        Creates a dictionary of expected shapes and dtypes for each compression
+        parameter used by the compressor
+
+        :param weight_shape: uncompressed weight shape
+        :param quantization_args: quantization parameters for the weight
+        :return: dictionary mapping compressed parameter names to shape and dtype
+        """
+        raise NotImplementedError()
+
+    def compress(
+        self,
+        model_state: Dict[str, Tensor],
+        names_to_scheme: Dict[str, QuantizationArgs],
+        **kwargs,
+    ) -> Dict[str, Tensor]:
         """
         Compresses a dense state dict

         :param model_state: state dict of uncompressed model
+        :param names_to_scheme: quantization args for each quantized weight, needed for
+            quantize function to calculate bit depth
         :return: compressed state dict
         """
-
+        compressed_dict = {}
+        weight_suffix = ".weight"
+        _LOGGER.debug(
+            f"Compressing model with {len(model_state)} parameterized layers..."
+        )
+
+        for name, value in tqdm(model_state.items(), desc="Compressing model"):
+            if name.endswith(weight_suffix):
+                prefix = name[: -(len(weight_suffix))]
+                scale = model_state.get(merge_names(prefix, "weight_scale"), None)
+                zp = model_state.get(merge_names(prefix, "weight_zero_point"), None)
+                g_idx = model_state.get(merge_names(prefix, "weight_g_idx"), None)
+                if scale is not None:
+                    # weight is quantized, compress it
+                    quant_args = names_to_scheme[prefix]
+                    compressed_data = self.compress_weight(
+                        weight=value,
+                        scale=scale,
+                        zero_point=zp,
+                        g_idx=g_idx,
+                        quantization_args=quant_args,
+                        device="cpu",
+                    )
+                    for key, value in compressed_data.items():
+                        compressed_dict[merge_names(prefix, key)] = value
+                else:
+                    compressed_dict[name] = value.to("cpu")
+            elif name.endswith("zero_point") and torch.all(value == 0):
+                continue
+            elif name.endswith("g_idx") and torch.any(value <= -1):
+                continue
+            else:
+                compressed_dict[name] = value.to("cpu")
+
+        return compressed_dict

     def decompress(
-        self,
+        self,
+        path_to_model_or_tensors: str,
+        names_to_scheme: Dict[str, QuantizationArgs],
+        device: str = "cpu",
     ) -> Generator[Tuple[str, Tensor], None, None]:
         """
         Reads a compressed state dict located at path_to_model_or_tensors
         and returns a generator for sequentially decompressing back to a
         dense state dict

-        :param
-            one or more safetensors files) or compressed tensors file
+        :param path_to_model_or_tensors: path to compressed safetensors model (directory
+            with one or more safetensors files) or compressed tensors file
+        :param names_to_scheme: quantization args for each quantized weight
         :param device: optional device to load intermediate weights into
         :return: compressed state dict
         """
+        weight_mappings = get_nested_weight_mappings(
+            path_to_model_or_tensors, self.COMPRESSION_PARAM_NAMES
+        )
+        for weight_name in weight_mappings.keys():
+            weight_data = {}
+            for param_name, safe_path in weight_mappings[weight_name].items():
+                full_name = merge_names(weight_name, param_name)
+                with safe_open(safe_path, framework="pt", device=device) as f:
+                    weight_data[param_name] = f.get_tensor(full_name)
+
+            if "weight_scale" in weight_data:
+                quant_args = names_to_scheme[weight_name]
+                decompressed = self.decompress_weight(
+                    compressed_data=weight_data, quantization_args=quant_args
+                )
+                yield merge_names(weight_name, "weight"), decompressed
+
+    def compress_weight(
+        self,
+        weight: Tensor,
+        scale: Tensor,
+        zero_point: Optional[Tensor] = None,
+        g_idx: Optional[torch.Tensor] = None,
+        quantization_args: Optional[QuantizationArgs] = None,
+    ) -> Dict[str, torch.Tensor]:
+        """
+        Compresses a single uncompressed weight
+
+        :param weight: uncompressed weight tensor
+        :param scale: quantization scale for weight
+        :param zero_point: quantization zero point for weight
+        :param g_idx: optional mapping from column index to group index
+        :param quantization_args: quantization parameters for weight
+        :return: dictionary of compressed weight data
+        """
         raise NotImplementedError()
+
+    def decompress_weight(
+        self,
+        compressed_data: Dict[str, Tensor],
+        quantization_args: Optional[QuantizationArgs] = None,
+    ) -> torch.Tensor:
+        """
+        Decompresses a single compressed weight
+
+        :param compressed_data: dictionary of data needed for decompression
+        :param quantization_args: quantization parameters for the weight
+        :return: tensor of the decompressed weight
+        """
+        raise NotImplementedError()
+
+    def compress_module(self, module: Module) -> Optional[Dict[str, torch.Tensor]]:
+        """
+        Compresses a single quantized leaf PyTorch module. If the module is not
+        quantized, this function has no effect.
+
+        :param module: PyTorch module to compress
+        :return: dictionary of compressed weight data, or None if module is not
+            quantized
+        """
+        if not hasattr(module, "quantization_scheme"):
+            return None  # module is not quantized
+        quantization_scheme = module.quantization_scheme
+        if not hasattr(quantization_scheme, "weights"):
+            return None  # weights are not quantized
+
+        quantization_args = quantization_scheme.weights
+        weight = getattr(module, "weight", None)
+        weight_scale = getattr(module, "weight_scale", None)
+        weight_zero_point = getattr(module, "weight_zero_point", None)
+
+        return self.compress_weight(
+            weight=weight,
+            scale=weight_scale,
+            zero_point=weight_zero_point,
+            quantization_args=quantization_args,
+        )
+
+    def decompress_module(self, module: Module):
+        """
+        Decompresses a single compressed leaf PyTorch module. If the module is not
+        quantized, this function has no effect.
+
+        :param module: PyTorch module to decompress
+        :return: tensor of the decompressed weight, or None if module is not quantized
+        """
+        if not hasattr(module, "quantization_scheme"):
+            return None  # module is not quantized
+        quantization_scheme = module.quantization_scheme
+        if not hasattr(quantization_scheme, "weights"):
+            return None  # weights are not quantized
+
+        quantization_args = quantization_scheme.weights
+        compressed_data = {}
+        for name, parameter in module.named_parameters():
+            compressed_data[name] = parameter
+
+        return self.decompress_weight(
+            compressed_data=compressed_data, quantization_args=quantization_args
+        )
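The reworked base class leaves only per-weight hooks abstract: a subclass describes its storage layout via compression_param_info and converts a single weight in each direction, while the inherited compress()/decompress() handle the state-dict loop and safetensors reads. A minimal sketch of a custom subclass against this interface follows; the CustomCompressor name and its simplistic int8 cast are illustrative only and are not part of the package.

from typing import Dict, Optional, Tuple

import torch
from torch import Tensor

from compressed_tensors.compressors import Compressor
from compressed_tensors.quantization import QuantizationArgs


# Hypothetical example subclass; not shipped with compressed-tensors.
class CustomCompressor(Compressor):
    # Parameter names the inherited decompress() will look up on disk
    COMPRESSION_PARAM_NAMES = ["weight", "weight_scale", "weight_zero_point"]

    def compression_param_info(
        self,
        weight_shape: torch.Size,
        quantization_args: Optional[QuantizationArgs] = None,
    ) -> Dict[str, Tuple[torch.Size, torch.dtype]]:
        # Report the shape/dtype of every parameter this compressor stores
        return {"weight": (weight_shape, torch.int8)}

    def compress_weight(self, weight: Tensor, scale: Tensor, **kwargs) -> Dict[str, Tensor]:
        # Real compressors quantize/pack here; this sketch just rounds and casts.
        # The scale parameter itself stays in the state dict and is reused on load.
        return {"weight": (weight / scale).round().clamp(-128, 127).to(torch.int8)}

    def decompress_weight(
        self,
        compressed_data: Dict[str, Tensor],
        quantization_args: Optional[QuantizationArgs] = None,
    ) -> Tensor:
        # Invert compress_weight using the stored scale
        return compressed_data["weight"].to(torch.float32) * compressed_data["weight_scale"]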
compressed_tensors/compressors/dense.py CHANGED
@@ -29,6 +29,6 @@ class DenseCompressor(Compressor):
         return model_state

     def decompress(
-        self, path_to_model_or_tensors: str, device: str = "cpu"
+        self, path_to_model_or_tensors: str, device: str = "cpu", **kwargs
     ) -> Generator[Tuple[str, Tensor], None, None]:
         return iter([])
compressed_tensors/compressors/marlin_24.py CHANGED
@@ -18,15 +18,16 @@ from typing import Dict, Generator, Tuple
 import numpy as np
 import torch
 from compressed_tensors.compressors import Compressor
-from compressed_tensors.
+from compressed_tensors.config import CompressionFormat
+from compressed_tensors.quantization import QuantizationArgs, QuantizationStrategy
+from compressed_tensors.quantization.lifecycle.forward import quantize
+from compressed_tensors.utils import (
     get_permutations_24,
+    is_quantization_param,
+    merge_names,
     sparse_semi_structured_from_dense_cutlass,
     tensor_follows_mask_structure,
 )
-from compressed_tensors.config import CompressionFormat
-from compressed_tensors.quantization import QuantizationArgs, QuantizationStrategy
-from compressed_tensors.quantization.lifecycle.forward import quantize
-from compressed_tensors.utils import is_quantization_param, merge_names
 from torch import Tensor
 from tqdm import tqdm

@@ -107,7 +108,7 @@ class Marlin24Compressor(Compressor):
     def compress(
         self,
         model_state: Dict[str, Tensor],
-
+        names_to_scheme: Dict[str, QuantizationArgs],
         **kwargs,
     ) -> Dict[str, Tensor]:
         """
@@ -115,11 +116,11 @@ class Marlin24Compressor(Compressor):
         with the Marlin24 kernel

         :param model_state: state dict of uncompressed model
-        :param
+        :param names_to_scheme: quantization args for each quantized weight, needed for
             quantize function to calculate bit depth
         :return: compressed state dict
         """
-        self.validate_quant_compatability(
+        self.validate_quant_compatability(names_to_scheme)

         compressed_dict = {}
         weight_suffix = ".weight"
@@ -139,7 +140,7 @@ class Marlin24Compressor(Compressor):
                 value = value.to(torch.float16)

                 # quantize weight, keeping it as a float16 for now
-                quant_args =
+                quant_args = names_to_scheme[prefix]
                 value = quantize(
                     x=value, scale=scale, zero_point=zp, args=quant_args
                 )
@@ -175,7 +176,7 @@ class Marlin24Compressor(Compressor):
         return compressed_dict

     def decompress(
-        self, path_to_model_or_tensors: str, device: str = "cpu"
+        self, path_to_model_or_tensors: str, device: str = "cpu", **kwargs
     ) -> Generator[Tuple[str, Tensor], None, None]:
         raise NotImplementedError(
             "Decompression is not implemented for the Marlin24 Compressor."
compressed_tensors/compressors/model_compressor.py CHANGED
@@ -16,16 +16,19 @@ import json
 import logging
 import operator
 import os
+import re
 from copy import deepcopy
 from typing import Any, Dict, Optional, Union

+import torch
+import transformers
 from compressed_tensors.base import (
     COMPRESSION_CONFIG_NAME,
     QUANTIZATION_CONFIG_NAME,
     SPARSITY_CONFIG_NAME,
 )
 from compressed_tensors.compressors import Compressor
-from compressed_tensors.config import SparsityCompressionConfig
+from compressed_tensors.config import CompressionFormat, SparsityCompressionConfig
 from compressed_tensors.quantization import (
     QuantizationConfig,
     QuantizationStatus,
@@ -36,10 +39,10 @@ from compressed_tensors.quantization.utils import (
     is_module_quantized,
     iter_named_leaf_modules,
 )
-from compressed_tensors.utils import get_safetensors_folder
+from compressed_tensors.utils import get_safetensors_folder, update_parameter_data
 from compressed_tensors.utils.helpers import fix_fsdp_module_name
 from torch import Tensor
-from torch.nn import Module
+from torch.nn import Module
 from tqdm import tqdm
 from transformers import AutoConfig
 from transformers.file_utils import CONFIG_NAME
@@ -78,6 +81,7 @@ class ModelCompressor:
     def from_pretrained(
         cls,
         pretrained_model_name_or_path: str,
+        **kwargs,
     ) -> Optional["ModelCompressor"]:
         """
         Given a path to a model config, extract the sparsity and/or quantization
@@ -86,7 +90,7 @@ class ModelCompressor:
         :param pretrained_model_name_or_path: path to model config on disk or HF hub
         :return: compressor for the extracted configs
         """
-        config = AutoConfig.from_pretrained(pretrained_model_name_or_path)
+        config = AutoConfig.from_pretrained(pretrained_model_name_or_path, **kwargs)
         compression_config = getattr(config, COMPRESSION_CONFIG_NAME, None)
         return cls.from_compression_config(compression_config)

@@ -172,6 +176,9 @@ class ModelCompressor:
         if hasattr(compression_config, SPARSITY_CONFIG_NAME):
             # for loaded HFQuantizer config
             return getattr(compression_config, SPARSITY_CONFIG_NAME)
+        if SPARSITY_CONFIG_NAME in compression_config:
+            # for loaded HFQuantizer config from dict
+            return compression_config[SPARSITY_CONFIG_NAME]

         # SparseAutoModel format
         return compression_config.get(SPARSITY_CONFIG_NAME, None)
@@ -185,6 +192,10 @@ class ModelCompressor:
             # for loaded HFQuantizer config
             return getattr(compression_config, QUANTIZATION_CONFIG_NAME)

+        if QUANTIZATION_CONFIG_NAME in compression_config:
+            # for loaded HFQuantizer config from dict
+            return compression_config[QUANTIZATION_CONFIG_NAME]
+
         # SparseAutoModel format
         quantization_config = deepcopy(compression_config)
         quantization_config.pop(SPARSITY_CONFIG_NAME, None)
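Both config parsers above now accept either an attribute-style config object (the HFQuantizer path) or a plain dict. The same dual lookup, restated as a small standalone helper; parse_section is a hypothetical name used only for illustration, not a compressed-tensors function.

from typing import Any, Optional


def parse_section(compression_config: Any, key: str) -> Optional[Any]:
    # Mirrors the lookups above: attribute access for HFQuantizer config objects,
    # dict access for configs that were deserialized as plain dicts.
    if compression_config is None:
        return None
    if hasattr(compression_config, key):
        return getattr(compression_config, key)
    if isinstance(compression_config, dict):
        return compression_config.get(key, None)
    return None


# Both forms resolve to the same nested section
print(parse_section({"sparsity_config": {"format": "sparse-bitmask"}}, "sparsity_config"))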
@@ -228,14 +239,79 @@ class ModelCompressor:
         quantized_modules_to_args = map_modules_to_quant_args(model)
         if self.quantization_compressor is not None:
             compressed_state_dict = self.quantization_compressor.compress(
-                state_dict,
+                state_dict, names_to_scheme=quantized_modules_to_args
             )
+            if self.quantization_config.format != CompressionFormat.dense.value:
+                self.quantization_config.quantization_status = (
+                    QuantizationStatus.COMPRESSED
+                )

         if self.sparsity_compressor is not None:
             compressed_state_dict = self.sparsity_compressor.compress(
                 compressed_state_dict
             )

+        # HACK (mgoin): Post-process step for kv cache scales to take the
+        # k/v_proj module `output_scale` parameters, and store them in the
+        # parent attention module as `k_scale` and `v_scale`
+        #
+        # Example:
+        # Replace `model.layers.0.self_attn.k_proj.output_scale`
+        # with `model.layers.0.self_attn.k_scale`
+        if (
+            self.quantization_config is not None
+            and self.quantization_config.kv_cache_scheme is not None
+        ):
+            # HACK (mgoin): We assume the quantized modules in question
+            # will be k_proj and v_proj since those are the default targets.
+            # We check that both of these modules have output activation
+            # quantization, and additionally check that q_proj doesn't.
+            q_proj_has_no_quant_output = 0
+            k_proj_has_quant_output = 0
+            v_proj_has_quant_output = 0
+            for name, module in model.named_modules():
+                if not hasattr(module, "quantization_scheme"):
+                    # We still want to count non-quantized q_proj
+                    if name.endswith(".q_proj"):
+                        q_proj_has_no_quant_output += 1
+                    continue
+                out_act = module.quantization_scheme.output_activations
+                if name.endswith(".q_proj") and out_act is None:
+                    q_proj_has_no_quant_output += 1
+                elif name.endswith(".k_proj") and out_act is not None:
+                    k_proj_has_quant_output += 1
+                elif name.endswith(".v_proj") and out_act is not None:
+                    v_proj_has_quant_output += 1
+
+            assert (
+                q_proj_has_no_quant_output > 0
+                and k_proj_has_quant_output > 0
+                and v_proj_has_quant_output > 0
+            )
+            assert (
+                q_proj_has_no_quant_output
+                == k_proj_has_quant_output
+                == v_proj_has_quant_output
+            )
+
+            # Move all .k/v_proj.output_scale parameters to .k/v_scale
+            working_state_dict = {}
+            for key in compressed_state_dict.keys():
+                if key.endswith(".k_proj.output_scale"):
+                    new_key = key.replace(".k_proj.output_scale", ".k_scale")
+                    working_state_dict[new_key] = compressed_state_dict[key]
+                elif key.endswith(".v_proj.output_scale"):
+                    new_key = key.replace(".v_proj.output_scale", ".v_scale")
+                    working_state_dict[new_key] = compressed_state_dict[key]
+                else:
+                    working_state_dict[key] = compressed_state_dict[key]
+            compressed_state_dict = working_state_dict
+
+        # HACK: Override the dtype_byte_size function in transformers to
+        # support float8 types. Fix is posted upstream
+        # https://github.com/huggingface/transformers/pull/30488
+        transformers.modeling_utils.dtype_byte_size = new_dtype_byte_size
+
         return compressed_state_dict

     def decompress(self, model_path: str, model: Module):
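The kv-cache post-processing added above is, at its core, a key rename over the compressed state dict. A toy illustration of the effect; the layer names and tensor values are made up.

import torch

# Hypothetical input: output_scale entries produced on k_proj/v_proj submodules
compressed_state_dict = {
    "model.layers.0.self_attn.k_proj.weight": torch.zeros(8, 8),
    "model.layers.0.self_attn.k_proj.output_scale": torch.tensor(0.02),
    "model.layers.0.self_attn.v_proj.output_scale": torch.tensor(0.03),
}

# Same rename performed at the end of ModelCompressor.compress()
remapped = {}
for key, value in compressed_state_dict.items():
    if key.endswith(".k_proj.output_scale"):
        remapped[key.replace(".k_proj.output_scale", ".k_scale")] = value
    elif key.endswith(".v_proj.output_scale"):
        remapped[key.replace(".v_proj.output_scale", ".v_scale")] = value
    else:
        remapped[key] = value

# The scales now live on the parent attention module:
# model.layers.0.self_attn.k_scale and model.layers.0.self_attn.v_scale
print(sorted(remapped.keys()))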
@@ -252,9 +328,11 @@ class ModelCompressor:
             setattr(model, SPARSITY_CONFIG_NAME, self.sparsity_compressor.config)

         if self.quantization_compressor is not None:
-            apply_quantization_config(model, self.quantization_config)
+            names_to_scheme = apply_quantization_config(model, self.quantization_config)
             load_pretrained_quantization(model, model_path)
-            dense_gen = self.quantization_compressor.decompress(
+            dense_gen = self.quantization_compressor.decompress(
+                model_path, names_to_scheme=names_to_scheme
+            )
             self._replace_weights(dense_gen, model)

         def update_status(module):
@@ -296,12 +374,10 @@ class ModelCompressor:

     def _replace_weights(self, dense_weight_generator, model):
         for name, data in tqdm(dense_weight_generator, desc="Decompressing model"):
-
-
-
-
-            data_new = Parameter(data.to(model_device).to(data_dtype))
-            data_old.data = data_new.data
+            split_name = name.split(".")
+            prefix, param_name = ".".join(split_name[:-1]), split_name[-1]
+            module = operator.attrgetter(prefix)(model)
+            update_parameter_data(module, data, param_name)


 def map_modules_to_quant_args(model: Module) -> Dict:
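The rewritten _replace_weights resolves each dotted parameter name to its owning module and hands the tensor to update_parameter_data. A self-contained sketch of that same name-splitting and module lookup on a plain nn.Module, substituting a direct in-place copy for update_parameter_data:

import operator

import torch
from torch import nn

model = nn.Sequential()
model.add_module("layer0", nn.Linear(4, 4, bias=False))

# One (name, tensor) pair as yielded by a compressor's decompress() generator
name, data = "layer0.weight", torch.ones(4, 4)

# Split off the parameter name, resolve the module by its dotted prefix,
# then overwrite the existing parameter in place
split_name = name.split(".")
prefix, param_name = ".".join(split_name[:-1]), split_name[-1]
module = operator.attrgetter(prefix)(model)
getattr(module, param_name).data.copy_(data)

assert torch.equal(model.layer0.weight, torch.ones(4, 4))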
@@ -313,3 +389,15 @@ def map_modules_to_quant_args(model: Module) -> Dict:
             quantized_modules_to_args[name] = submodule.quantization_scheme.weights

     return quantized_modules_to_args
+
+
+# HACK: Override the dtype_byte_size function in transformers to support float8 types
+# Fix is posted upstream https://github.com/huggingface/transformers/pull/30488
+def new_dtype_byte_size(dtype):
+    if dtype == torch.bool:
+        return 1 / 8
+    bit_search = re.search(r"[^\d](\d+)_?", str(dtype))
+    if bit_search is None:
+        raise ValueError(f"`dtype` is not a valid dtype: {dtype}.")
+    bit_size = int(bit_search.groups()[0])
+    return bit_size // 8
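The override reads the bit width straight out of the dtype's string form, so float8 dtypes resolve to one byte instead of raising inside transformers. A quick check of what the regex yields for a few dtypes (assumes a torch build that exposes float8_e4m3fn):

import re

import torch


def new_dtype_byte_size(dtype):
    # Same logic as the override above
    if dtype == torch.bool:
        return 1 / 8
    bit_search = re.search(r"[^\d](\d+)_?", str(dtype))
    if bit_search is None:
        raise ValueError(f"`dtype` is not a valid dtype: {dtype}.")
    return int(bit_search.groups()[0]) // 8


print(new_dtype_byte_size(torch.bfloat16))       # 2
print(new_dtype_byte_size(torch.int8))           # 1
print(new_dtype_byte_size(torch.float8_e4m3fn))  # 1 -- str() is "torch.float8_e4m3fn", regex grabs "8"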
compressed_tensors/compressors/naive_quantized.py ADDED
@@ -0,0 +1,140 @@
+# Copyright (c) 2021 - present / Neuralmagic, Inc. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#    http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import logging
+from typing import Dict, Optional, Tuple
+
+import torch
+from compressed_tensors.compressors import Compressor
+from compressed_tensors.config import CompressionFormat
+from compressed_tensors.quantization import QuantizationArgs
+from compressed_tensors.quantization.lifecycle.forward import dequantize, quantize
+from compressed_tensors.quantization.utils import can_quantize
+from torch import Tensor
+
+
+__all__ = [
+    "QuantizationCompressor",
+    "IntQuantizationCompressor",
+    "FloatQuantizationCompressor",
+]
+
+_LOGGER: logging.Logger = logging.getLogger(__name__)
+
+
+@Compressor.register(name=CompressionFormat.naive_quantized.value)
+class QuantizationCompressor(Compressor):
+    """
+    Implements naive compression for quantized models. Weight of each
+    quantized layer is converted from its original float type to the closest Pytorch
+    type to the type specified by the layer's QuantizationArgs.
+    """
+
+    COMPRESSION_PARAM_NAMES = [
+        "weight",
+        "weight_scale",
+        "weight_zero_point",
+        "weight_g_idx",
+    ]
+
+    def compression_param_info(
+        self,
+        weight_shape: torch.Size,
+        quantization_args: Optional[QuantizationArgs] = None,
+    ) -> Dict[str, Tuple[torch.Size, torch.dtype]]:
+        """
+        Creates a dictionary of expected shapes and dtypes for each compression
+        parameter used by the compressor
+
+        :param weight_shape: uncompressed weight shape
+        :param quantization_args: quantization parameters for the weight
+        :return: dictionary mapping compressed parameter names to shape and dtype
+        """
+        dtype = quantization_args.pytorch_dtype()
+        return {"weight": (weight_shape, dtype)}
+
+    def compress_weight(
+        self,
+        weight: Tensor,
+        scale: Tensor,
+        zero_point: Optional[Tensor] = None,
+        g_idx: Optional[torch.Tensor] = None,
+        quantization_args: Optional[QuantizationArgs] = None,
+        device: Optional[torch.device] = None,
+    ) -> Dict[str, torch.Tensor]:
+        """
+        Compresses a single uncompressed weight
+
+        :param weight: uncompressed weight tensor
+        :param scale: quantization scale for weight
+        :param zero_point: quantization zero point for weight
+        :param g_idx: optional mapping from column index to group index
+        :param quantization_args: quantization parameters for weight
+        :param device: optional device to move compressed output to
+        :return: dictionary of compressed weight data
+        """
+        if can_quantize(weight, quantization_args):
+            quantized_weight = quantize(
+                x=weight,
+                scale=scale,
+                zero_point=zero_point,
+                g_idx=g_idx,
+                args=quantization_args,
+                dtype=quantization_args.pytorch_dtype(),
+            )
+
+        if device is not None:
+            quantized_weight = quantized_weight.to(device)
+
+        return {"weight": quantized_weight}
+
+    def decompress_weight(
+        self,
+        compressed_data: Dict[str, Tensor],
+        quantization_args: Optional[QuantizationArgs] = None,
+    ) -> torch.Tensor:
+        """
+        Decompresses a single compressed weight
+
+        :param compressed_data: dictionary of data needed for decompression
+        :param quantization_args: quantization parameters for the weight
+        :return: tensor of the decompressed weight
+        """
+        weight = compressed_data["weight"]
+        scale = compressed_data["weight_scale"]
+        zero_point = compressed_data.get("weight_zero_point", None)
+        g_idx = compressed_data.get("weight_g_idx", None)
+        decompressed_weight = dequantize(
+            x_q=weight, scale=scale, zero_point=zero_point, g_idx=g_idx
+        )
+
+        return decompressed_weight
+
+
+@Compressor.register(name=CompressionFormat.int_quantized.value)
+class IntQuantizationCompressor(QuantizationCompressor):
+    """
+    Alias for integer quantized models
+    """
+
+    pass
+
+
+@Compressor.register(name=CompressionFormat.float_quantized.value)
+class FloatQuantizationCompressor(QuantizationCompressor):
+    """
+    Alias for fp quantized models
+    """
+
+    pass