compressed-tensors-nightly 0.5.0.20240829__py3-none-any.whl → 0.5.0.20240831__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- compressed_tensors/compressors/base.py +198 -8
- compressed_tensors/compressors/model_compressor.py +12 -1
- compressed_tensors/compressors/naive_quantized.py +65 -78
- compressed_tensors/compressors/pack_quantized.py +82 -98
- compressed_tensors/linear/__init__.py +13 -0
- compressed_tensors/linear/compressed_linear.py +87 -0
- compressed_tensors/quantization/lifecycle/apply.py +31 -4
- compressed_tensors/quantization/lifecycle/compressed.py +1 -1
- compressed_tensors/quantization/lifecycle/forward.py +19 -19
- compressed_tensors/quantization/lifecycle/initialize.py +29 -14
- compressed_tensors/utils/helpers.py +13 -0
- compressed_tensors/utils/offload.py +3 -0
- {compressed_tensors_nightly-0.5.0.20240829.dist-info → compressed_tensors_nightly-0.5.0.20240831.dist-info}/METADATA +1 -1
- {compressed_tensors_nightly-0.5.0.20240829.dist-info → compressed_tensors_nightly-0.5.0.20240831.dist-info}/RECORD +17 -15
- {compressed_tensors_nightly-0.5.0.20240829.dist-info → compressed_tensors_nightly-0.5.0.20240831.dist-info}/LICENSE +0 -0
- {compressed_tensors_nightly-0.5.0.20240829.dist-info → compressed_tensors_nightly-0.5.0.20240831.dist-info}/WHEEL +0 -0
- {compressed_tensors_nightly-0.5.0.20240829.dist-info → compressed_tensors_nightly-0.5.0.20240831.dist-info}/top_level.txt +0 -0
compressed_tensors/compressors/base.py

@@ -12,20 +12,53 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
-
+import logging
+from typing import Dict, Generator, Optional, Tuple, Union
 
+import torch
 from compressed_tensors.config import SparsityCompressionConfig
-from compressed_tensors.quantization import QuantizationConfig
+from compressed_tensors.quantization import QuantizationArgs, QuantizationConfig
 from compressed_tensors.registry import RegistryMixin
+from compressed_tensors.utils import get_nested_weight_mappings, merge_names
+from safetensors import safe_open
 from torch import Tensor
+from torch.nn.modules import Module
+from tqdm import tqdm
 
 
+_LOGGER: logging.Logger = logging.getLogger(__name__)
+
 __all__ = ["Compressor"]
 
 
 class Compressor(RegistryMixin):
     """
-    Base class representing a model compression algorithm
+    Base class representing a model compression algorithm. Each child class should
+    implement compression_param_info, compress_weight and decompress_weight.
+
+    Compressors support compressing/decompressing a full module state dict or a single
+    quantized PyTorch leaf module.
+
+    Model Load Lifecycle (run_compressed=False):
+        - ModelCompressor.decompress()
+            - apply_quantization_config()
+            - Compressor.decompress()
+                - Compressor.decompress_weight()
+
+    Model Save Lifecycle:
+        - ModelCompressor.compress()
+            - Compressor.compress()
+                - Compressor.compress_weight()
+
+    Module Lifecycle (run_compressed=True):
+        - apply_quantization_config()
+        - compressed_module = CompressedLinear(module)
+            - initialize_module_for_quantization()
+            - Compressor.compression_param_info()
+            - register_parameters()
+        - compressed_module.forward()
+            - compressed_module.decompress()
+
 
     :param config: config specifying compression parameters
     """
@@ -35,26 +68,183 @@ class Compressor(RegistryMixin):
     ):
         self.config = config
 
-    def
+    def compression_param_info(
+        self,
+        weight_shape: torch.Size,
+        quantization_args: Optional[QuantizationArgs] = None,
+    ) -> Dict[str, Tuple[torch.Size, torch.dtype]]:
+        """
+        Creates a dictionary of expected shapes and dtypes for each compression
+        parameter used by the compressor
+
+        :param weight_shape: uncompressed weight shape
+        :param quantization_args: quantization parameters for the weight
+        :return: dictionary mapping compressed parameter names to shape and dtype
+        """
+        raise NotImplementedError()
+
+    def compress(
+        self,
+        model_state: Dict[str, Tensor],
+        names_to_scheme: Dict[str, QuantizationArgs],
+        **kwargs,
+    ) -> Dict[str, Tensor]:
         """
         Compresses a dense state dict
 
         :param model_state: state dict of uncompressed model
+        :param names_to_scheme: quantization args for each quantized weight, needed for
+            quantize function to calculate bit depth
         :return: compressed state dict
         """
-
+        compressed_dict = {}
+        weight_suffix = ".weight"
+        _LOGGER.debug(
+            f"Compressing model with {len(model_state)} parameterized layers..."
+        )
+
+        for name, value in tqdm(model_state.items(), desc="Compressing model"):
+            if name.endswith(weight_suffix):
+                prefix = name[: -(len(weight_suffix))]
+                scale = model_state.get(merge_names(prefix, "weight_scale"), None)
+                zp = model_state.get(merge_names(prefix, "weight_zero_point"), None)
+                if scale is not None:
+                    # weight is quantized, compress it
+                    quant_args = names_to_scheme[prefix]
+                    compressed_data = self.compress_weight(
+                        weight=value,
+                        scale=scale,
+                        zero_point=zp,
+                        quantization_args=quant_args,
+                        device="cpu",
+                    )
+                    for key, value in compressed_data.items():
+                        compressed_dict[merge_names(prefix, key)] = value
+                else:
+                    compressed_dict[name] = value.to("cpu")
+            elif name.endswith("zero_point") and torch.all(value == 0):
+                # all zero_points are 0, no need to include in
+                # compressed state_dict
+                continue
+            else:
+                compressed_dict[name] = value.to("cpu")
+
+        return compressed_dict
 
     def decompress(
-        self,
+        self,
+        path_to_model_or_tensors: str,
+        names_to_scheme: Dict[str, QuantizationArgs],
+        device: str = "cpu",
     ) -> Generator[Tuple[str, Tensor], None, None]:
         """
         Reads a compressed state dict located at path_to_model_or_tensors
         and returns a generator for sequentially decompressing back to a
         dense state dict
 
-        :param
-            one or more safetensors files) or compressed tensors file
+        :param path_to_model_or_tensors: path to compressed safetensors model (directory
+            with one or more safetensors files) or compressed tensors file
+        :param names_to_scheme: quantization args for each quantized weight
         :param device: optional device to load intermediate weights into
         :return: compressed state dict
         """
+        weight_mappings = get_nested_weight_mappings(
+            path_to_model_or_tensors, self.COMPRESSION_PARAM_NAMES
+        )
+        for weight_name in weight_mappings.keys():
+            weight_data = {}
+            for param_name, safe_path in weight_mappings[weight_name].items():
+                full_name = merge_names(weight_name, param_name)
+                with safe_open(safe_path, framework="pt", device=device) as f:
+                    weight_data[param_name] = f.get_tensor(full_name)
+
+            if "weight_scale" in weight_data:
+                quant_args = names_to_scheme[weight_name]
+                decompressed = self.decompress_weight(
+                    compressed_data=weight_data, quantization_args=quant_args
+                )
+                yield merge_names(weight_name, "weight"), decompressed
+
+    def compress_weight(
+        self,
+        weight: Tensor,
+        scale: Tensor,
+        zero_point: Optional[Tensor] = None,
+        g_idx: Optional[torch.Tensor] = None,
+        quantization_args: Optional[QuantizationArgs] = None,
+    ) -> Dict[str, torch.Tensor]:
+        """
+        Compresses a single uncompressed weight
+
+        :param weight: uncompressed weight tensor
+        :param scale: quantization scale for weight
+        :param zero_point: quantization zero point for weight
+        :param g_idx: optional mapping from column index to group index
+        :param quantization_args: quantization parameters for weight
+        :return: dictionary of compressed weight data
+        """
         raise NotImplementedError()
+
+    def decompress_weight(
+        self,
+        compressed_data: Dict[str, Tensor],
+        quantization_args: Optional[QuantizationArgs] = None,
+    ) -> torch.Tensor:
+        """
+        Decompresses a single compressed weight
+
+        :param compressed_data: dictionary of data needed for decompression
+        :param quantization_args: quantization parameters for the weight
+        :return: tensor of the decompressed weight
+        """
+        raise NotImplementedError()
+
+    def compress_module(self, module: Module) -> Optional[Dict[str, torch.Tensor]]:
+        """
+        Compresses a single quantized leaf PyTorch module. If the module is not
+        quantized, this function has no effect.
+
+        :param module: PyTorch module to compress
+        :return: dictionary of compressed weight data, or None if module is not
+            quantized
+        """
+        if not hasattr(module, "quantization_scheme"):
+            return None  # module is not quantized
+        quantization_scheme = module.quantization_scheme
+        if not hasattr(quantization_scheme, "weights"):
+            return None  # weights are not quantized
+
+        quantization_args = quantization_scheme.weights
+        weight = getattr(module, "weight", None)
+        weight_scale = getattr(module, "weight_scale", None)
+        weight_zero_point = getattr(module, "weight_zero_point", None)
+
+        return self.compress_weight(
+            weight=weight,
+            scale=weight_scale,
+            zero_point=weight_zero_point,
+            quantization_args=quantization_args,
+        )
+
+    def decompress_module(self, module: Module):
+        """
+        Decompresses a single compressed leaf PyTorch module. If the module is not
+        quantized, this function has no effect.
+
+        :param module: PyTorch module to decompress
+        :return: tensor of the decompressed weight, or None if module is not quantized
+        """
+        if not hasattr(module, "quantization_scheme"):
+            return None  # module is not quantized
+        quantization_scheme = module.quantization_scheme
+        if not hasattr(quantization_scheme, "weights"):
+            return None  # weights are not quantized
+
+        quantization_args = quantization_scheme.weights
+        compressed_data = {}
+        for name, parameter in module.named_parameters():
+            compressed_data[name] = parameter
+
+        return self.decompress_weight(
+            compressed_data=compressed_data, quantization_args=quantization_args
+        )
compressed_tensors/compressors/model_compressor.py

@@ -28,7 +28,7 @@ from compressed_tensors.base import (
     SPARSITY_CONFIG_NAME,
 )
 from compressed_tensors.compressors import Compressor
-from compressed_tensors.config import SparsityCompressionConfig
+from compressed_tensors.config import CompressionFormat, SparsityCompressionConfig
 from compressed_tensors.quantization import (
     QuantizationConfig,
     QuantizationStatus,
@@ -176,6 +176,9 @@ class ModelCompressor:
         if hasattr(compression_config, SPARSITY_CONFIG_NAME):
             # for loaded HFQuantizer config
             return getattr(compression_config, SPARSITY_CONFIG_NAME)
+        if SPARSITY_CONFIG_NAME in compression_config:
+            # for loaded HFQuantizer config from dict
+            return compression_config[SPARSITY_CONFIG_NAME]
 
         # SparseAutoModel format
         return compression_config.get(SPARSITY_CONFIG_NAME, None)
@@ -189,6 +192,10 @@ class ModelCompressor:
             # for loaded HFQuantizer config
             return getattr(compression_config, QUANTIZATION_CONFIG_NAME)
 
+        if QUANTIZATION_CONFIG_NAME in compression_config:
+            # for loaded HFQuantizer config from dict
+            return compression_config[QUANTIZATION_CONFIG_NAME]
+
         # SparseAutoModel format
         quantization_config = deepcopy(compression_config)
         quantization_config.pop(SPARSITY_CONFIG_NAME, None)
@@ -234,6 +241,10 @@ class ModelCompressor:
             compressed_state_dict = self.quantization_compressor.compress(
                 state_dict, names_to_scheme=quantized_modules_to_args
             )
+            if self.quantization_config.format != CompressionFormat.dense.value:
+                self.quantization_config.quantization_status = (
+                    QuantizationStatus.COMPRESSED
+                )
 
         if self.sparsity_compressor is not None:
             compressed_state_dict = self.sparsity_compressor.compress(
compressed_tensors/compressors/naive_quantized.py

@@ -13,7 +13,7 @@
 # limitations under the License.
 
 import logging
-from typing import Dict,
+from typing import Dict, Optional, Tuple
 
 import torch
 from compressed_tensors.compressors import Compressor
@@ -21,10 +21,7 @@ from compressed_tensors.config import CompressionFormat
 from compressed_tensors.quantization import QuantizationArgs
 from compressed_tensors.quantization.lifecycle.forward import dequantize, quantize
 from compressed_tensors.quantization.utils import can_quantize
-from compressed_tensors.utils import get_nested_weight_mappings, merge_names
-from safetensors import safe_open
 from torch import Tensor
-from tqdm import tqdm
 
 
 __all__ = [
@@ -51,88 +48,78 @@ class QuantizationCompressor(Compressor):
         "weight_g_idx",
     ]
 
-    def
+    def compression_param_info(
         self,
-
-
-
-    ) -> Dict[str, Tensor]:
+        weight_shape: torch.Size,
+        quantization_args: Optional[QuantizationArgs] = None,
+    ) -> Dict[str, Tuple[torch.Size, torch.dtype]]:
         """
-
+        Creates a dictionary of expected shapes and dtypes for each compression
+        parameter used by the compressor
 
-        :param
-        :param
-
-        :return: compressed state dict
+        :param weight_shape: uncompressed weight shape
+        :param quantization_args: quantization parameters for the weight
+        :return: dictionary mapping compressed parameter names to shape and dtype
         """
-
-
-        _LOGGER.debug(
-            f"Compressing model with {len(model_state)} parameterized layers..."
-        )
+        dtype = quantization_args.pytorch_dtype()
+        return {"weight": (weight_shape, dtype)}
 
-
-
-
-
-
-
-
-
-
-                if can_quantize(value, quant_args):
-                    # only quantize if not already quantized
-                    value = quantize(
-                        x=value,
-                        scale=scale,
-                        zero_point=zp,
-                        args=quant_args,
-                        dtype=quant_args.pytorch_dtype(),
-                        g_idx=g_idx,
-                    )
-            elif name.endswith("zero_point"):
-                if torch.all(value == 0):
-                    # all zero_points are 0, no need to include in
-                    # compressed state_dict
-                    continue
-            compressed_dict[name] = value.to("cpu")
-
-        return compressed_dict
-
-    def decompress(
-        self, path_to_model_or_tensors: str, device: str = "cpu", **kwargs
-    ) -> Generator[Tuple[str, Tensor], None, None]:
+    def compress_weight(
+        self,
+        weight: Tensor,
+        scale: Tensor,
+        zero_point: Optional[Tensor] = None,
+        g_idx: Optional[torch.Tensor] = None,
+        quantization_args: Optional[QuantizationArgs] = None,
+        device: Optional[torch.device] = None,
+    ) -> Dict[str, torch.Tensor]:
         """
-
-
-
-
-        :param
-
-        :param
-        :
+        Compresses a single uncompressed weight
+
+        :param weight: uncompressed weight tensor
+        :param scale: quantization scale for weight
+        :param zero_point: quantization zero point for weight
+        :param g_idx: optional mapping from column index to group index
+        :param quantization_args: quantization parameters for weight
+        :param device: optional device to move compressed output to
+        :return: dictionary of compressed weight data
+        """
+        if can_quantize(weight, quantization_args):
+            quantized_weight = quantize(
+                x=weight,
+                scale=scale,
+                zero_point=zero_point,
+                g_idx=g_idx,
+                args=quantization_args,
+                dtype=quantization_args.pytorch_dtype(),
+            )
+
+        if device is not None:
+            quantized_weight = quantized_weight.to(device)
+
+        return {"weight": quantized_weight}
+
+    def decompress_weight(
+        self,
+        compressed_data: Dict[str, Tensor],
+        quantization_args: Optional[QuantizationArgs] = None,
+    ) -> torch.Tensor:
         """
-
-
+        Decompresses a single compressed weight
+
+        :param compressed_data: dictionary of data needed for decompression
+        :param quantization_args: quantization parameters for the weight
+        :return: tensor of the decompressed weight
+        """
+        weight = compressed_data["weight"]
+        scale = compressed_data["weight_scale"]
+        zero_point = compressed_data.get("weight_zero_point", None)
+        g_idx = compressed_data.get("weight_g_idx", None)
+        decompressed_weight = dequantize(
+            x_q=weight, scale=scale, zero_point=zero_point, g_idx=g_idx
         )
-
-
-            for param_name, safe_path in weight_mappings[weight_name].items():
-                full_name = merge_names(weight_name, param_name)
-                with safe_open(safe_path, framework="pt", device=device) as f:
-                    weight_data[param_name] = f.get_tensor(full_name)
-
-            if "weight_scale" in weight_data:
-                scale = weight_data["weight_scale"]
-                zero_point = weight_data.get("weight_zero_point", None)
-                g_idx = weight_data.get("weight_g_idx", None)
-                decompressed = dequantize(
-                    x_q=weight_data["weight"],
-                    scale=scale,
-                    zero_point=zero_point,
-                    g_idx=g_idx,
-                )
-                yield merge_names(weight_name, "weight"), decompressed
+
+        return decompressed_weight
 
 
 @Compressor.register(name=CompressionFormat.int_quantized.value)
compressed_tensors/compressors/pack_quantized.py

@@ -11,10 +11,8 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
-
-import logging
 import math
-from typing import Dict,
+from typing import Dict, Optional, Tuple
 
 import numpy as np
 import torch
@@ -23,16 +21,11 @@ from compressed_tensors.config import CompressionFormat
 from compressed_tensors.quantization import QuantizationArgs
 from compressed_tensors.quantization.lifecycle.forward import dequantize, quantize
 from compressed_tensors.quantization.utils import can_quantize
-from compressed_tensors.utils import get_nested_weight_mappings, merge_names
-from safetensors import safe_open
 from torch import Tensor
-from tqdm import tqdm
 
 
 __all__ = ["PackedQuantizationCompressor", "pack_to_int32", "unpack_from_int32"]
 
-_LOGGER: logging.Logger = logging.getLogger(__name__)
-
 
 @Compressor.register(name=CompressionFormat.pack_quantized.value)
 class PackedQuantizationCompressor(Compressor):
@@ -48,103 +41,92 @@ class PackedQuantizationCompressor(Compressor):
         "weight_shape",
     ]
 
-    def
+    def compression_param_info(
         self,
-
-
-
-    ) -> Dict[str, Tensor]:
+        weight_shape: torch.Size,
+        quantization_args: Optional[QuantizationArgs] = None,
+    ) -> Dict[str, Tuple[torch.Size, torch.dtype]]:
         """
-
+        Creates a dictionary of expected shapes and dtypes for each compression
+        parameter used by the compressor
 
-        :param
-        :param
-
-
+        :param weight_shape: uncompressed weight shape
+        :param quantization_args: quantization parameters for the weight
+        :return: dictionary mapping compressed parameter names to shape and dtype
+        """
+        pack_factor = 32 // quantization_args.num_bits
+        packed_size = math.ceil(weight_shape[1] / pack_factor)
+        return {
+            "weight_packed": (torch.Size((weight_shape[0], packed_size)), torch.int32),
+            "weight_shape": (torch.Size((2,)), torch.int32),
+        }
+
+    def compress_weight(
+        self,
+        weight: Tensor,
+        scale: Tensor,
+        zero_point: Optional[Tensor] = None,
+        g_idx: Optional[torch.Tensor] = None,
+        quantization_args: Optional[QuantizationArgs] = None,
+        device: Optional[torch.device] = None,
+    ) -> Dict[str, torch.Tensor]:
+        """
+        Compresses a single uncompressed weight
+
+        :param weight: uncompressed weight tensor
+        :param scale: quantization scale for weight
+        :param zero_point: quantization zero point for weight
+        :param g_idx: optional mapping from column index to group index
+        :param quantization_args: quantization parameters for weight
+        :param device: optional device to move compressed output to
+        :return: dictionary of compressed weight data
         """
         compressed_dict = {}
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-                x=value,
-                scale=scale,
-                zero_point=zp,
-                g_idx=g_idx,
-                args=quant_args,
-                dtype=torch.int8,
-            )
-            value = pack_to_int32(value.cpu(), quant_args.num_bits)
-            compressed_dict[merge_names(prefix, "weight_shape")] = shape
-            compressed_dict[merge_names(prefix, "weight_packed")] = value
-            continue
-
-        elif name.endswith("zero_point"):
-            if torch.all(value == 0):
-                # all zero_points are 0, no need to include in
-                # compressed state_dict
-                continue
-
-        compressed_dict[name] = value.to("cpu")
+        if can_quantize(weight, quantization_args):
+            quantized_weight = quantize(
+                x=weight,
+                scale=scale,
+                zero_point=zero_point,
+                g_idx=g_idx,
+                args=quantization_args,
+                dtype=torch.int8,
+            )
+
+        packed_weight = pack_to_int32(quantized_weight, quantization_args.num_bits)
+        weight_shape = torch.tensor(weight.shape)
+        if device is not None:
+            packed_weight = packed_weight.to(device)
+            weight_shape = weight_shape.to(device)
+
+        compressed_dict["weight_shape"] = weight_shape
+        compressed_dict["weight_packed"] = packed_weight
 
         return compressed_dict
 
-    def
+    def decompress_weight(
         self,
-
-
-
-    ) -> Generator[Tuple[str, Tensor], None, None]:
+        compressed_data: Dict[str, Tensor],
+        quantization_args: Optional[QuantizationArgs] = None,
+    ) -> torch.Tensor:
         """
-
-
-
-
-        :
-            one or more safetensors files) or compressed tensors file
-        :param device: optional device to load intermediate weights into
-        :return: compressed state dict
+        Decompresses a single compressed weight
+
+        :param compressed_data: dictionary of data needed for decompression
+        :param quantization_args: quantization parameters for the weight
+        :return: tensor of the decompressed weight
         """
-
-
+        weight = compressed_data["weight_packed"]
+        scale = compressed_data["weight_scale"]
+        zero_point = compressed_data.get("weight_zero_point", None)
+        g_idx = compressed_data.get("weight_g_idx", None)
+        original_shape = torch.Size(compressed_data["weight_shape"])
+        num_bits = quantization_args.num_bits
+        unpacked = unpack_from_int32(weight, num_bits, original_shape)
+        decompressed_weight = dequantize(
+            x_q=unpacked, scale=scale, zero_point=zero_point, g_idx=g_idx
        )
-
-
-            for param_name, safe_path in weight_mappings[weight_name].items():
-                weight_data["num_bits"] = names_to_scheme.get(weight_name).num_bits
-                full_name = merge_names(weight_name, param_name)
-                with safe_open(safe_path, framework="pt", device=device) as f:
-                    weight_data[param_name] = f.get_tensor(full_name)
-
-            if "weight_scale" in weight_data:
-                weight = weight_data["weight_packed"]
-                scale = weight_data["weight_scale"]
-                zero_point = weight_data.get("weight_zero_point", None)
-                g_idx = weight_data.get("weight_g_idx", None)
-                num_bits = weight_data["num_bits"]
-                original_shape = torch.Size(weight_data["weight_shape"])
-                unpacked = unpack_from_int32(weight, num_bits, original_shape)
-                decompressed = dequantize(
-                    x_q=unpacked,
-                    scale=scale,
-                    zero_point=zero_point,
-                    g_idx=g_idx,
-                )
-                yield merge_names(weight_name, "weight"), decompressed
+
+        return decompressed_weight
 
 
 def pack_to_int32(value: torch.Tensor, num_bits: int) -> torch.Tensor:
@@ -202,13 +184,15 @@ def unpack_from_int32(
     if num_bits > 8:
         raise ValueError("Unpacking is only supported for less than 8 bits")
 
-    # convert packed input to unsigned numpy
-    value = value.numpy().view(np.uint32)
     pack_factor = 32 // num_bits
 
     # unpack
     mask = pow(2, num_bits) - 1
-    unpacked =
+    unpacked = torch.zeros(
+        (value.shape[0], value.shape[1] * pack_factor),
+        device=value.device,
+        dtype=torch.int32,
+    )
     for i in range(pack_factor):
         unpacked[:, i::pack_factor] = (value >> (num_bits * i)) & mask
 
@@ -219,6 +203,6 @@ def unpack_from_int32(
     # bits are packed in unsigned format, reformat to signed
     # update the value range from unsigned to signed
     offset = pow(2, num_bits) // 2
-    unpacked = (unpacked
+    unpacked = (unpacked - offset).to(torch.int8)
 
-    return
+    return unpacked
compressed_tensors/linear/__init__.py

@@ -0,0 +1,13 @@
+# Copyright (c) 2021 - present / Neuralmagic, Inc. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
compressed_tensors/linear/compressed_linear.py

@@ -0,0 +1,87 @@
+# Copyright (c) 2021 - present / Neuralmagic, Inc. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import torch
+from compressed_tensors.compressors.base import Compressor
+from compressed_tensors.quantization import (
+    QuantizationScheme,
+    QuantizationStatus,
+    initialize_module_for_quantization,
+)
+from torch import Tensor
+from torch.nn import Parameter
+from torch.nn.functional import linear
+from torch.nn.modules import Linear
+
+
+class CompressedLinear(Linear):
+    """
+    Wrapper module for running a compressed forward pass of a quantized Linear module.
+    The wrapped layer will decompressed on each forward call.
+
+    :param module: dense linear module to replace
+    :param quantization_scheme: quantization config for the module to wrap
+    :param quantization_format: compression format module is stored as
+    """
+
+    @classmethod
+    @torch.no_grad()
+    def from_linear(
+        cls,
+        module: Linear,
+        quantization_scheme: QuantizationScheme,
+        quantization_format: str,
+    ):
+        module.__class__ = CompressedLinear
+        module.compressor = Compressor.load_from_registry(quantization_format)
+        device = next(module.parameters()).device
+
+        # this will initialize all the scales and zero points
+        initialize_module_for_quantization(
+            module, quantization_scheme, force_zero_point=False
+        )
+
+        # get the shape and dtype of compressed parameters
+        compression_params = module.compressor.compression_param_info(
+            module.weight.shape, quantization_scheme.weights
+        )
+
+        # no need for this once quantization is initialized, will be replaced
+        # with the compressed parameter
+        delattr(module, "weight")
+
+        # populate compressed weights and quantization parameters
+        for name, (shape, dtype) in compression_params.items():
+            param = Parameter(
+                torch.empty(shape, device=device, dtype=dtype), requires_grad=False
+            )
+            module.register_parameter(name, param)
+
+        # mark module as compressed
+        module.quantization_status = QuantizationStatus.COMPRESSED
+
+        # handles case where forward is wrapped in new_forward by accelerate hooks
+        if hasattr(module, "_old_forward"):
+            module._old_forward = CompressedLinear.forward.__get__(
+                module, CompressedLinear
+            )
+
+        return module
+
+    def forward(self, input: Tensor) -> Tensor:
+        """
+        Decompresses the weight, then runs the wrapped forward pass
+        """
+        uncompressed_weight = self.compressor.decompress_module(self)
+        return linear(input, uncompressed_weight, self.bias)
compressed_tensors/quantization/lifecycle/apply.py

@@ -21,6 +21,7 @@ from typing import OrderedDict as OrderedDictType
 from typing import Union
 
 import torch
+from compressed_tensors.config import CompressionFormat
 from compressed_tensors.quantization.lifecycle.calibration import (
     set_module_for_calibration,
 )
@@ -43,7 +44,7 @@ from compressed_tensors.quantization.utils import (
     is_kv_cache_quant_scheme,
     iter_named_leaf_modules,
 )
-from compressed_tensors.utils.helpers import fix_fsdp_module_name
+from compressed_tensors.utils.helpers import fix_fsdp_module_name, replace_module
 from compressed_tensors.utils.offload import update_parameter_data
 from compressed_tensors.utils.safetensors_load import get_safetensors_folder
 from torch.nn import Module
@@ -104,12 +105,16 @@ def load_pretrained_quantization(model: Module, model_name_or_path: str):
     )
 
 
-def apply_quantization_config(
+def apply_quantization_config(
+    model: Module, config: QuantizationConfig, run_compressed: bool = False
+) -> Dict:
     """
     Initializes the model for quantization in-place based on the given config
 
     :param model: model to apply quantization config to
     :param config: quantization config
+    :param run_compressed: Whether the model will be run in compressed mode or
+        decompressed fully on load
     """
     # remove reference to the original `config`
     # argument. This function can mutate it, and we'd
@@ -124,6 +129,9 @@ def apply_quantization_config(model: Module, config: QuantizationConfig) -> Dict
         for target in scheme.targets:
             target_to_scheme[target] = scheme
 
+    if run_compressed:
+        from compressed_tensors.linear.compressed_linear import CompressedLinear
+
     # list of submodules to ignore
     ignored_submodules = defaultdict(list)
     # mark appropriate layers for quantization by setting their quantization schemes
@@ -136,10 +144,24 @@ def apply_quantization_config(model: Module, config: QuantizationConfig) -> Dict
             continue  # layer matches ignore list, continue
         targets = find_name_or_class_matches(name, submodule, target_to_scheme)
         if targets:
+            scheme = _scheme_from_targets(target_to_scheme, targets, name)
+            if run_compressed:
+                format = config.format
+                if format != CompressionFormat.dense.value:
+                    if isinstance(submodule, torch.nn.Linear):
+                        # TODO: expand to more module types
+                        compressed_linear = CompressedLinear.from_linear(
+                            submodule,
+                            quantization_scheme=scheme,
+                            quantization_format=format,
+                        )
+                        replace_module(model, name, compressed_linear)
+
             # target matched - add layer and scheme to target list
             submodule.quantization_scheme = _scheme_from_targets(
                 target_to_scheme, targets, name
             )
+
             names_to_scheme[name] = submodule.quantization_scheme.weights
 
     if config.ignore is not None and ignored_submodules is not None:
@@ -149,8 +171,8 @@ def apply_quantization_config(model: Module, config: QuantizationConfig) -> Dict
             "not found in the model: "
             f"{set(config.ignore) - set(ignored_submodules)}"
         )
-    # apply current quantization status across all targeted layers
 
+    # apply current quantization status across all targeted layers
     apply_quantization_status(model, config.quantization_status)
     return names_to_scheme
 
@@ -198,7 +220,12 @@ def apply_quantization_status(model: Module, status: QuantizationStatus):
     current_status = infer_quantization_status(model)
 
     if status >= QuantizationStatus.INITIALIZED > current_status:
-
+        force_zero_point_init = status != QuantizationStatus.COMPRESSED
+        model.apply(
+            lambda module: initialize_module_for_quantization(
+                module, force_zero_point=force_zero_point_init
+            )
+        )
 
     if current_status < status >= QuantizationStatus.CALIBRATION > current_status:
         # only quantize weights up front when our end goal state is calibration,
compressed_tensors/quantization/lifecycle/compressed.py

@@ -50,7 +50,7 @@ def compress_quantized_weights(module: Module):
     scale = getattr(module, "weight_scale", None)
     zero_point = getattr(module, "weight_zero_point", None)
 
-    if weight is None or scale is None
+    if weight is None or scale is None:
         # no weight, scale, or ZP, nothing to do
 
         # mark as compressed here to maintain consistent status throughout the model
compressed_tensors/quantization/lifecycle/forward.py

@@ -62,14 +62,6 @@ def quantize(
     :param g_idx: optional mapping from column index to group index
     :return: fake quantized tensor
     """
-    # ensure all tensors are on the same device
-    # assumes that the target device is the input
-    # tensor's device
-    if x.device != scale.device:
-        scale = scale.to(x.device)
-    if x.device != zero_point.device:
-        zero_point = zero_point.to(x.device)
-
     return _process_quantization(
         x=x,
         scale=scale,
@@ -274,6 +266,7 @@ def wrap_module_forward_quantized(module: Module, scheme: QuantizationScheme):
             return forward_func_orig.__get__(module, module.__class__)(*args, **kwargs)
 
         input_ = args[0]
+        compressed = module.quantization_status == QuantizationStatus.COMPRESSED
 
         if scheme.input_activations is not None:
             # calibrate and (fake) quantize input activations when applicable
@@ -281,7 +274,7 @@ def wrap_module_forward_quantized(module: Module, scheme: QuantizationScheme):
                 module, input_, "input", scheme.input_activations
             )
 
-        if scheme.weights is not None:
+        if scheme.weights is not None and not compressed:
             # calibrate and (fake) quantize weights when applicable
             unquantized_weight = self.weight.data.clone()
             self.weight.data = maybe_calibrate_or_quantize(
@@ -300,7 +293,7 @@ def wrap_module_forward_quantized(module: Module, scheme: QuantizationScheme):
         )
 
         # restore back to unquantized_value
-        if scheme.weights is not None:
+        if scheme.weights is not None and not compressed:
             self.weight.data = unquantized_weight
 
         return output
@@ -314,11 +307,16 @@ def wrap_module_forward_quantized(module: Module, scheme: QuantizationScheme):
 def maybe_calibrate_or_quantize(
     module: Module, value: torch.Tensor, base_name: str, args: "QuantizationArgs"
 ) -> torch.Tensor:
-    #
-    if module.quantization_status
-
-
-
+    # don't run quantization if we haven't entered calibration mode
+    if module.quantization_status == QuantizationStatus.INITIALIZED:
+        return value
+
+    # in compressed mode, the weight is already compressed and quantized so we don't
+    # need to run fake quantization
+    if (
+        module.quantization_status == QuantizationStatus.COMPRESSED
+        and base_name == "weight"
+    ):
         return value
 
     if value.numel() == 0:
@@ -335,7 +333,7 @@ def maybe_calibrate_or_quantize(
     else:
         # static quantization - get previous scale and zero point from layer
         scale = getattr(module, f"{base_name}_scale")
-        zero_point = getattr(module, f"{base_name}_zero_point")
+        zero_point = getattr(module, f"{base_name}_zero_point", None)
 
     if (
         module.quantization_status == QuantizationStatus.CALIBRATION
@@ -364,7 +362,9 @@ def _quantize(
     dtype: Optional[torch.dtype] = None,
 ) -> torch.Tensor:
 
-    scaled = x / scale
+    scaled = x / scale
+    if zero_point is not None:
+        scaled += zero_point.to(x.dtype)
     # clamp first because cast isn't guaranteed to be saturated (ie for fp8)
     clamped_value = torch.clamp(
         scaled,
@@ -385,11 +385,11 @@ def _dequantize(
     zero_point: torch.Tensor = None,
     dtype: Optional[torch.dtype] = None,
 ) -> torch.Tensor:
+    dequant_value = x_q.to(scale.dtype)
 
-    dequant_value = x_q
     if zero_point is not None:
         dequant_value = dequant_value - zero_point.to(scale.dtype)
-    dequant_value = dequant_value
+    dequant_value = dequant_value * scale
 
     if dtype is not None:
         dequant_value = dequant_value.to(dtype)
compressed_tensors/quantization/lifecycle/initialize.py

@@ -41,6 +41,7 @@ _LOGGER = logging.getLogger(__name__)
 def initialize_module_for_quantization(
     module: Module,
     scheme: Optional[QuantizationScheme] = None,
+    force_zero_point: bool = True,
 ):
     """
     attaches appropriate scales, zero points, and observers to a layer
@@ -52,6 +53,8 @@ def initialize_module_for_quantization(
     :param scheme: scheme to use for quantization. if None is provided,
         will attempt to use scheme stored in the module under `quantization_scheme`,
         if not provided, the layer will be skipped
+    :param force_zero_point: whether to force initialization of a zero point for
+        symmetric quantization
     """
     scheme = scheme or getattr(module, "quantization_scheme", None)
     if scheme is None:
@@ -59,14 +62,18 @@ def initialize_module_for_quantization(
         return
 
     if scheme.input_activations is not None:
-        _initialize_scale_zero_point_observer(
+        _initialize_scale_zero_point_observer(
+            module, "input", scheme.input_activations, force_zero_point=force_zero_point
+        )
     if scheme.weights is not None:
         if hasattr(module, "weight"):
-            weight_shape =
-            if isinstance(module, torch.nn.Linear):
-                weight_shape = module.weight.shape
+            weight_shape = module.weight.shape
             _initialize_scale_zero_point_observer(
-                module,
+                module,
+                "weight",
+                scheme.weights,
+                weight_shape=weight_shape,
+                force_zero_point=force_zero_point,
             )
         else:
             _LOGGER.warning(
@@ -76,7 +83,10 @@ def initialize_module_for_quantization(
             )
     if scheme.output_activations is not None:
         _initialize_scale_zero_point_observer(
-            module,
+            module,
+            "output",
+            scheme.output_activations,
+            force_zero_point=force_zero_point,
         )
 
     module.quantization_scheme = scheme
@@ -124,6 +134,7 @@ def _initialize_scale_zero_point_observer(
     base_name: str,
     quantization_args: QuantizationArgs,
     weight_shape: Optional[torch.Size] = None,
+    force_zero_point: bool = True,
 ):
     # initialize observer module and attach as submodule
     observer = quantization_args.get_observer()
@@ -149,20 +160,24 @@ def _initialize_scale_zero_point_observer(
             weight_shape[1] // quantization_args.group_size,
         )
 
+    scale_dtype = module.weight.dtype
+    if scale_dtype not in [torch.float16, torch.bfloat16, torch.float32]:
+        scale_dtype = torch.float16
+
     # initializes empty scale, zero point, and g_idx parameters for the module
     init_scale = Parameter(
-        torch.empty(expected_shape, dtype=
+        torch.empty(expected_shape, dtype=scale_dtype, device=device),
         requires_grad=False,
     )
     module.register_parameter(f"{base_name}_scale", init_scale)
 
-
-
-
-
-
-
-
+    if force_zero_point or not quantization_args.symmetric:
+        zp_dtype = quantization_args.pytorch_dtype()
+        init_zero_point = Parameter(
+            torch.zeros(expected_shape, device=device, dtype=zp_dtype),
+            requires_grad=False,
+        )
+        module.register_parameter(f"{base_name}_zero_point", init_zero_point)
 
     # initialize with empty for actorder, to be populated by GPTQ or state_dict
     if quantization_args.actorder:
compressed_tensors/utils/helpers.py

@@ -22,6 +22,7 @@ __all__ = [
     "infer_compressor_from_model_config",
     "fix_fsdp_module_name",
     "tensor_follows_mask_structure",
+    "replace_module",
 ]
 
 FSDP_WRAPPER_NAME = "_fsdp_wrapped_module"
@@ -90,3 +91,15 @@ def tensor_follows_mask_structure(tensor, mask: str = "2:4") -> bool:
             raise ValueError()
 
     return True
+
+
+def replace_module(model: torch.nn.Module, name: str, new_module: torch.nn.Module):
+    if "." in name:
+        parent_name = name.rsplit(".", 1)[0]
+        child_name = name[len(parent_name) + 1 :]
+        parent = model.get_submodule(parent_name)
+    else:
+        parent_name = ""
+        parent = model
+        child_name = name
+    setattr(parent, child_name, new_module)
compressed_tensors/utils/offload.py

@@ -91,6 +91,9 @@ def update_parameter_data(
     :param new_param_data: tensor to update parameter with
     :param param_name: name of layer parameter to update
     """
+    if not hasattr(module, param_name):
+        return
+
     device = next(module.parameters()).device
 
     offloaded = False
compressed_tensors_nightly-0.5.0.20240831.dist-info/METADATA

@@ -1,6 +1,6 @@
 Metadata-Version: 2.1
 Name: compressed-tensors-nightly
-Version: 0.5.0.20240829
+Version: 0.5.0.20240831
 Summary: Library for utilization of compressed safetensors of neural network models
 Home-page: https://github.com/neuralmagic/compressed-tensors
 Author: Neuralmagic, Inc.
compressed_tensors_nightly-0.5.0.20240831.dist-info/RECORD

@@ -2,30 +2,32 @@ compressed_tensors/__init__.py,sha256=SV1csvHUVCd8kHXz6UDZim1HZ_fAVG3vfk-j_4Bb6h
 compressed_tensors/base.py,sha256=Mq4mfVQcJhNpha-BXzpOfpmFIdl01o09BJE7D2oQ_00,796
 compressed_tensors/version.py,sha256=DdMT4o5D6_t26gTuvhF1Q9HPeXY6vV5g7XMprWuHLdI,1586
 compressed_tensors/compressors/__init__.py,sha256=wmX4VnkUTS63xBwK5-6w8FP78bNZpcdcqvf2KOEC5E4,1133
-compressed_tensors/compressors/base.py,sha256
+compressed_tensors/compressors/base.py,sha256=4BO07h28Epbl2ED43lORnPGmBZ3pMdaoLYym_LJTpPQ,9846
 compressed_tensors/compressors/dense.py,sha256=xcWECjcRY4INN6jC7vHx5wvUX3NmnKlxA9SVE1A6m2Q,1267
 compressed_tensors/compressors/helpers.py,sha256=k9avlkmeYj6vkOAvl-MgcixtP7ib24SCfhzZ-RusXfw,5403
 compressed_tensors/compressors/marlin_24.py,sha256=e7fGUyZbjUpA5VUMCPxqcYPGNiwoDKupHJaXWCoVKRw,9410
-compressed_tensors/compressors/model_compressor.py,sha256=
-compressed_tensors/compressors/naive_quantized.py,sha256=
-compressed_tensors/compressors/pack_quantized.py,sha256=
+compressed_tensors/compressors/model_compressor.py,sha256=Yv2V8Ey6AFDg2Tmvwc7-E_AnMFkeIy_HVu62ct650AI,16507
+compressed_tensors/compressors/naive_quantized.py,sha256=z3h3ca5xKCN69mahutxcbzdv-OysiaxaM8P-Qum6zUQ,4823
+compressed_tensors/compressors/pack_quantized.py,sha256=27RVmJ2wg2dvCoawj407HSmKT3VPGJ6ujAMHlT26WlI,7571
 compressed_tensors/compressors/sparse_bitmask.py,sha256=kiDwBlFV0sJGLcIdDYxIiuF64ccgwDfqq1hWRQThYDc,8647
 compressed_tensors/config/__init__.py,sha256=ZBqWn3r6ku1qfmlHHYp0mQueY0i7Pwhr9rbQk9dDlMc,704
 compressed_tensors/config/base.py,sha256=caSZ7xZ_kgcHRMXZ5hM1i6TKbgY__CkiSjZ93imHZQ0,1562
 compressed_tensors/config/dense.py,sha256=NgSxnFCnckU9-iunxEaqiFwqgdO7YYxlWKR74jNbjks,1317
 compressed_tensors/config/sparse_bitmask.py,sha256=pZUboRNZTu6NajGOQEFExoPknak5ynVAUeiiYpS1Gt8,1308
+compressed_tensors/linear/__init__.py,sha256=fH6rjBYAxuwrTzBTlTjTgCYNyh6TCvCqajCz4Im4YrA,617
+compressed_tensors/linear/compressed_linear.py,sha256=G0gEFfxLAUsgRcnfSV-PKz1ZBNTVokOauOoup7SE1mw,3210
 compressed_tensors/quantization/__init__.py,sha256=83J5bPB7PavN2TfCoW7_vEDhfYpm4TDrqYO9vdSQ5bk,760
 compressed_tensors/quantization/quant_args.py,sha256=wSC2ve1P-XRwZUpqEaqvQpj1Xe0EGgmmPEjPk9YEnyg,6797
 compressed_tensors/quantization/quant_config.py,sha256=NpVu8YJ4Xw2pIQW_PGaNaml8kx1bUnxkvb0jBYWbKdE,9971
 compressed_tensors/quantization/quant_scheme.py,sha256=_RKOFJI0T5xJVBLX63UeYkSY4EFAecsBnqzUIVBjeU0,6014
 compressed_tensors/quantization/lifecycle/__init__.py,sha256=MXE2E7GfIfRRfhrdGy2Og3AZOz5N59B0ZGFcsD89y6c,821
-compressed_tensors/quantization/lifecycle/apply.py,sha256=
+compressed_tensors/quantization/lifecycle/apply.py,sha256=uftWFunr_CpCZM_qWfo2O1USXKB2qSYD1pBJsO8BuCU,15285
 compressed_tensors/quantization/lifecycle/calibration.py,sha256=PlS_EqCOPqJD3QKuLPXO9AOtDzXtQWvEBTynFv-FFVw,2698
-compressed_tensors/quantization/lifecycle/compressed.py,sha256=
-compressed_tensors/quantization/lifecycle/forward.py,sha256=
+compressed_tensors/quantization/lifecycle/compressed.py,sha256=laNDwvhk4S925qWTPHCufo4uDdMo24NDV1qhsAkf5Iw,2225
+compressed_tensors/quantization/lifecycle/forward.py,sha256=fZMSrUXX2NnkQiappEpT5SO-6JxbX5wiw9hyjfKNIZo,13538
 compressed_tensors/quantization/lifecycle/frozen.py,sha256=h1XYt89MouBTf3jTYLG_6OdFxIu5q2N8tPjsy6J4E6Y,1726
 compressed_tensors/quantization/lifecycle/helpers.py,sha256=TmLY_G5VP_Fg2Ywio_dxoHRTxOKZdT7_aG5S9WtD4zI,2424
-compressed_tensors/quantization/lifecycle/initialize.py,sha256=
+compressed_tensors/quantization/lifecycle/initialize.py,sha256=r8GNYIUYVHJ-539mHKnhhGysCluaOG6VieH6CQD4eeo,7112
 compressed_tensors/quantization/observers/__init__.py,sha256=4Sa7rqi5RB_S5bPO8KmncETiqDsoMBhwP37arlQym8s,764
 compressed_tensors/quantization/observers/base.py,sha256=5ovQicWPYHjIxr6-EkQ4lgOX0PpI9g23iSzKpxjM1Zg,8420
 compressed_tensors/quantization/observers/helpers.py,sha256=s_A23Qa_BLfOdHJCN5bm-qPWkhjjj_RIVrhSp1Y9Dtk,4211
@@ -37,14 +39,14 @@ compressed_tensors/quantization/utils/helpers.py,sha256=YjXABJQUnelof-z7qcwck6fn
 compressed_tensors/registry/__init__.py,sha256=FwLSNYqfIrb5JD_6OK_MT4_svvKTN_nEhpgQlQvGbjI,658
 compressed_tensors/registry/registry.py,sha256=fxjOjh2wklCvJhQxwofdy-zV8q7MkQ85SLG77nml2iA,11890
 compressed_tensors/utils/__init__.py,sha256=gS4gSU2pwcAbsKj-6YMaqhm25udFy6ISYaWBf-myRSM,808
-compressed_tensors/utils/helpers.py,sha256=
-compressed_tensors/utils/offload.py,sha256=
+compressed_tensors/utils/helpers.py,sha256=bh4G8mj_YCRf8Bo2FQ9FkIIZXY8xqqPjckNnVYB0gBA,3557
+compressed_tensors/utils/offload.py,sha256=d9q8LNe8HyF8tOjgjA7QGLD3HRysmNp0d8eBbdqBgIM,4089
 compressed_tensors/utils/permutations_24.py,sha256=kx6fsfDHebx94zsSzhXGyCyuC9sVyah6BUUir_StT28,2530
 compressed_tensors/utils/permute.py,sha256=V6tJLKo3Syccj-viv4F7ZKZgJeCB-hl-dK8RKI_kBwI,2355
 compressed_tensors/utils/safetensors_load.py,sha256=m08ANVuTBxQdoa6LufDgcNJ7wCLDJolyZljB8VEybAU,8578
 compressed_tensors/utils/semi_structured_conversions.py,sha256=XKNffPum54kPASgqKzgKvyeqWPAkair2XEQXjkp7ho8,13489
-compressed_tensors_nightly-0.5.0.
-compressed_tensors_nightly-0.5.0.
-compressed_tensors_nightly-0.5.0.
-compressed_tensors_nightly-0.5.0.
-compressed_tensors_nightly-0.5.0.
+compressed_tensors_nightly-0.5.0.20240831.dist-info/LICENSE,sha256=xx0jnfkXJvxRnG63LTGOxlggYnIysveWIZ6H3PNdCrQ,11357
+compressed_tensors_nightly-0.5.0.20240831.dist-info/METADATA,sha256=EC7NXHFCAhZV33MtR51mgvmE9VItsDlDub4cDiwo3ag,6799
+compressed_tensors_nightly-0.5.0.20240831.dist-info/WHEEL,sha256=eOLhNAGa2EW3wWl_TU484h7q1UNgy0JXjjoqKoxAAQc,92
+compressed_tensors_nightly-0.5.0.20240831.dist-info/top_level.txt,sha256=w2i-GyPs2s1UwVxvutSvN_lM22SXC2hQFBmoMcPnV7Y,19
+compressed_tensors_nightly-0.5.0.20240831.dist-info/RECORD,,
File without changes
|
File without changes
|