compressed-tensors 0.4.0__py3-none-any.whl → 0.6.0__py3-none-any.whl
This diff shows the contents of publicly released package versions as published to a supported registry. It is provided for informational purposes only and reflects the changes between the two versions as they appear in their public registries.
- compressed_tensors/base.py +1 -0
- compressed_tensors/compressors/__init__.py +5 -1
- compressed_tensors/compressors/base.py +200 -8
- compressed_tensors/compressors/dense.py +1 -1
- compressed_tensors/compressors/marlin_24.py +11 -10
- compressed_tensors/compressors/model_compressor.py +101 -13
- compressed_tensors/compressors/naive_quantized.py +140 -0
- compressed_tensors/compressors/pack_quantized.py +128 -132
- compressed_tensors/compressors/sparse_bitmask.py +1 -1
- compressed_tensors/config/base.py +8 -1
- compressed_tensors/{compressors/utils → linear}/__init__.py +0 -6
- compressed_tensors/linear/compressed_linear.py +87 -0
- compressed_tensors/quantization/lifecycle/__init__.py +1 -0
- compressed_tensors/quantization/lifecycle/apply.py +204 -44
- compressed_tensors/quantization/lifecycle/calibration.py +22 -2
- compressed_tensors/quantization/lifecycle/compressed.py +3 -1
- compressed_tensors/quantization/lifecycle/forward.py +139 -61
- compressed_tensors/quantization/lifecycle/helpers.py +80 -0
- compressed_tensors/quantization/lifecycle/initialize.py +77 -13
- compressed_tensors/quantization/observers/__init__.py +1 -0
- compressed_tensors/quantization/observers/base.py +93 -14
- compressed_tensors/quantization/observers/helpers.py +64 -11
- compressed_tensors/quantization/observers/min_max.py +8 -0
- compressed_tensors/quantization/observers/mse.py +162 -0
- compressed_tensors/quantization/quant_args.py +139 -23
- compressed_tensors/quantization/quant_config.py +35 -2
- compressed_tensors/quantization/quant_scheme.py +112 -13
- compressed_tensors/quantization/utils/helpers.py +68 -2
- compressed_tensors/utils/__init__.py +5 -0
- compressed_tensors/utils/helpers.py +44 -2
- compressed_tensors/utils/offload.py +116 -0
- compressed_tensors/utils/permute.py +70 -0
- compressed_tensors/utils/safetensors_load.py +2 -0
- compressed_tensors/{compressors/utils → utils}/semi_structured_conversions.py +1 -0
- compressed_tensors/version.py +1 -1
- {compressed_tensors-0.4.0.dist-info → compressed_tensors-0.6.0.dist-info}/METADATA +35 -22
- compressed_tensors-0.6.0.dist-info/RECORD +52 -0
- {compressed_tensors-0.4.0.dist-info → compressed_tensors-0.6.0.dist-info}/WHEEL +1 -1
- compressed_tensors/compressors/int_quantized.py +0 -126
- compressed_tensors/compressors/utils/helpers.py +0 -43
- compressed_tensors-0.4.0.dist-info/RECORD +0 -48
- /compressed_tensors/{compressors/utils → utils}/permutations_24.py +0 -0
- {compressed_tensors-0.4.0.dist-info → compressed_tensors-0.6.0.dist-info}/LICENSE +0 -0
- {compressed_tensors-0.4.0.dist-info → compressed_tensors-0.6.0.dist-info}/top_level.txt +0 -0
@@ -14,10 +14,14 @@
 
 import logging
 import re
-from collections import OrderedDict
-from
+from collections import OrderedDict, defaultdict
+from copy import deepcopy
+from typing import Dict, Iterable, List, Optional
+from typing import OrderedDict as OrderedDictType
+from typing import Union
 
 import torch
+from compressed_tensors.config import CompressionFormat
 from compressed_tensors.quantization.lifecycle.calibration import (
     set_module_for_calibration,
 )
@@ -28,15 +32,20 @@ from compressed_tensors.quantization.lifecycle.frozen import freeze_module_quant
 from compressed_tensors.quantization.lifecycle.initialize import (
     initialize_module_for_quantization,
 )
+from compressed_tensors.quantization.quant_args import QuantizationArgs
 from compressed_tensors.quantization.quant_config import (
     QuantizationConfig,
     QuantizationStatus,
 )
+from compressed_tensors.quantization.quant_scheme import QuantizationScheme
 from compressed_tensors.quantization.utils import (
+    KV_CACHE_TARGETS,
     infer_quantization_status,
+    is_kv_cache_quant_scheme,
     iter_named_leaf_modules,
 )
-from compressed_tensors.utils.helpers import fix_fsdp_module_name
+from compressed_tensors.utils.helpers import fix_fsdp_module_name, replace_module
+from compressed_tensors.utils.offload import update_parameter_data
 from compressed_tensors.utils.safetensors_load import get_safetensors_folder
 from torch.nn import Module
 
@@ -45,7 +54,7 @@ __all__ = [
     "load_pretrained_quantization",
     "apply_quantization_config",
     "apply_quantization_status",
-    "
+    "find_name_or_class_matches",
 ]
 
 from compressed_tensors.quantization.utils.helpers import is_module_quantized
@@ -96,33 +105,64 @@ def load_pretrained_quantization(model: Module, model_name_or_path: str):
         )
 
 
-def apply_quantization_config(model: Module, config: QuantizationConfig):
+def apply_quantization_config(
+    model: Module, config: QuantizationConfig, run_compressed: bool = False
+) -> Dict:
     """
     Initializes the model for quantization in-place based on the given config
 
     :param model: model to apply quantization config to
     :param config: quantization config
+    :param run_compressed: Whether the model will be run in compressed mode or
+        decompressed fully on load
     """
+    # remove reference to the original `config`
+    # argument. This function can mutate it, and we'd
+    # like to keep the original `config` as it is.
+    config = deepcopy(config)
     # build mapping of targets to schemes for easier matching
     # use ordered dict to preserve target ordering in config
     target_to_scheme = OrderedDict()
+    config = process_quantization_config(config)
+    names_to_scheme = OrderedDict()
     for scheme in config.config_groups.values():
         for target in scheme.targets:
             target_to_scheme[target] = scheme
 
+    if run_compressed:
+        from compressed_tensors.linear.compressed_linear import CompressedLinear
+
     # list of submodules to ignore
-    ignored_submodules =
+    ignored_submodules = defaultdict(list)
     # mark appropriate layers for quantization by setting their quantization schemes
     for name, submodule in iter_named_leaf_modules(model):
         # potentially fix module name to remove FSDP wrapper prefix
         name = fix_fsdp_module_name(name)
-        if
-
+        if matches := find_name_or_class_matches(name, submodule, config.ignore):
+            for match in matches:
+                ignored_submodules[match].append(name)
             continue # layer matches ignore list, continue
-
-        if
+        targets = find_name_or_class_matches(name, submodule, target_to_scheme)
+        if targets:
+            scheme = _scheme_from_targets(target_to_scheme, targets, name)
+            if run_compressed:
+                format = config.format
+                if format != CompressionFormat.dense.value:
+                    if isinstance(submodule, torch.nn.Linear):
+                        # TODO: expand to more module types
+                        compressed_linear = CompressedLinear.from_linear(
+                            submodule,
+                            quantization_scheme=scheme,
+                            quantization_format=format,
+                        )
+                        replace_module(model, name, compressed_linear)
+
             # target matched - add layer and scheme to target list
-            submodule.quantization_scheme =
+            submodule.quantization_scheme = _scheme_from_targets(
+                target_to_scheme, targets, name
+            )
+
+            names_to_scheme[name] = submodule.quantization_scheme.weights
 
     if config.ignore is not None and ignored_submodules is not None:
         if set(config.ignore) - set(ignored_submodules):
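The new `run_compressed` path above swaps each matching `torch.nn.Linear` for a `CompressedLinear` built via `CompressedLinear.from_linear` whenever the config format is not dense. A minimal usage sketch follows; it is a hedged illustration rather than documented API usage, assuming `apply_quantization_config` is importable from the module shown in the file list and that the caller already has a loaded model and a parsed config:

```python
from typing import Dict

from torch.nn import Module

from compressed_tensors.quantization.lifecycle.apply import apply_quantization_config
from compressed_tensors.quantization.quant_config import QuantizationConfig


def load_for_compressed_inference(model: Module, config: QuantizationConfig) -> Dict:
    # run_compressed=True keeps quantized Linear layers in their compressed form;
    # the returned dict maps module names to their weight quantization args
    return apply_quantization_config(model, config, run_compressed=True)
```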
@@ -131,8 +171,43 @@ def apply_quantization_config(model: Module, config: QuantizationConfig):
                 "not found in the model: "
                 f"{set(config.ignore) - set(ignored_submodules)}"
             )
+
     # apply current quantization status across all targeted layers
     apply_quantization_status(model, config.quantization_status)
+    return names_to_scheme
+
+
+def process_quantization_config(config: QuantizationConfig) -> QuantizationConfig:
+    """
+    Preprocess the raw QuantizationConfig
+
+    :param config: the raw QuantizationConfig
+    :return: the processed QuantizationConfig
+    """
+    if config.kv_cache_scheme is not None:
+        config = process_kv_cache_config(config)
+
+    return config
+
+
+def process_kv_cache_config(
+    config: QuantizationConfig, targets: Union[List[str], str] = KV_CACHE_TARGETS
+) -> QuantizationConfig:
+    """
+    Reformulate the `config.kv_cache` as a `config_group`
+    and add it to the set of existing `config.groups`
+
+    :param config: the QuantizationConfig
+    :return: the QuantizationConfig with additional "kv_cache" group
+    """
+    kv_cache_dict = config.kv_cache_scheme.model_dump()
+    kv_cache_scheme = QuantizationScheme(
+        output_activations=QuantizationArgs(**kv_cache_dict),
+        targets=targets,
+    )
+    kv_cache_group = dict(kv_cache=kv_cache_scheme)
+    config.config_groups.update(kv_cache_group)
+    return config
 
 
 def apply_quantization_status(model: Module, status: QuantizationStatus):
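`process_kv_cache_config` above rewrites `config.kv_cache_scheme` as an extra "kv_cache" entry in `config_groups`, targeting `KV_CACHE_TARGETS` with output-activation quantization. A standalone sketch of that reshaping with plain dicts (the target pattern and field values below are hypothetical, chosen only to illustrate the shape of the result):

```python
# stand-in for config.kv_cache_scheme.model_dump(); values are hypothetical
kv_cache_scheme = {"num_bits": 8, "type": "int", "symmetric": True}

# stand-in for the real KV_CACHE_TARGETS constant; the actual patterns may differ
KV_CACHE_TARGETS = ["re:.*self_attn$"]

# existing config groups (hypothetical weight-only group)
config_groups = {"group_0": {"targets": ["Linear"], "weights": {"num_bits": 4}}}

# mirror of process_kv_cache_config: the kv cache scheme becomes an
# output_activations scheme under a new "kv_cache" group
config_groups["kv_cache"] = {
    "targets": KV_CACHE_TARGETS,
    "output_activations": dict(kv_cache_scheme),
}
print(config_groups["kv_cache"])
```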
@@ -145,10 +220,22 @@ def apply_quantization_status(model: Module, status: QuantizationStatus):
     current_status = infer_quantization_status(model)
 
     if status >= QuantizationStatus.INITIALIZED > current_status:
-
+        force_zero_point_init = status != QuantizationStatus.COMPRESSED
+        model.apply(
+            lambda module: initialize_module_for_quantization(
+                module, force_zero_point=force_zero_point_init
+            )
+        )
 
     if current_status < status >= QuantizationStatus.CALIBRATION > current_status:
-
+        # only quantize weights up front when our end goal state is calibration,
+        # weight quantization parameters are already loaded for frozen/compressed
+        quantize_weights_upfront = status == QuantizationStatus.CALIBRATION
+        model.apply(
+            lambda module: set_module_for_calibration(
+                module, quantize_weights_upfront=quantize_weights_upfront
+            )
+        )
     if current_status < status >= QuantizationStatus.FROZEN > current_status:
         model.apply(freeze_module_quantization)
 
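`apply_quantization_status` gates each transition with chained comparisons such as `current_status < status >= QuantizationStatus.CALIBRATION > current_status`. A standalone sketch of how that chain evaluates, using an `IntEnum` stand-in for the status ordering (the real `QuantizationStatus` defines its own comparison operators):

```python
from enum import IntEnum


class Status(IntEnum):
    # hypothetical stand-in mirroring the lifecycle order implied above
    INITIALIZED = 1
    CALIBRATION = 2
    FROZEN = 3
    COMPRESSED = 4


current_status = Status.INITIALIZED   # where the model currently is
status = Status.CALIBRATION           # the status being applied

# True only when the requested status is at least CALIBRATION
# and the model has not yet reached CALIBRATION
print(current_status < status >= Status.CALIBRATION > current_status)  # True

# already calibrated: the same gate is False, so the step is skipped
current_status = Status.CALIBRATION
print(current_status < status >= Status.CALIBRATION > current_status)  # False
```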
@@ -156,36 +243,45 @@ def apply_quantization_status(model: Module, status: QuantizationStatus):
         model.apply(compress_quantized_weights)
 
 
-def
+def find_name_or_class_matches(
     name: str, module: Module, targets: Iterable[str], check_contains: bool = False
-) ->
-
-
-
+) -> List[str]:
+    """
+    Returns all targets that match the given name or the class name.
+    Returns empty list otherwise.
+    The order of the output `matches` list matters.
+    The entries are sorted in the following order:
+        1. matches on exact strings
+        2. matches on regex patterns
+        3. matches on module names
+    """
+    targets = sorted(targets, key=lambda x: ("re:" in x, x))
     if isinstance(targets, Iterable):
-
+        matches = _find_matches(name, targets) + _find_matches(
             module.__class__.__name__, targets, check_contains
         )
+        matches = [match for match in matches if match is not None]
+        return matches
 
 
-def
+def _find_matches(
     value: str, targets: Iterable[str], check_contains: bool = False
-) ->
-    # returns
+) -> List[str]:
+    # returns all the targets that match value either
     # exactly or as a regex after 're:'. if check_contains is set to True,
     # additionally checks if the target string is contained with value.
-
+    matches = []
     for target in targets:
         if target.startswith("re:"):
             pattern = target[3:]
             if re.match(pattern, value):
-
+                matches.append(target)
         elif check_contains:
             if target.lower() in value.lower():
-
+                matches.append(target)
         elif target == value:
-
-    return
+            matches.append(target)
+    return matches
 
 
 def _infer_status(model: Module) -> Optional[QuantizationStatus]:
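`find_name_or_class_matches` collects every target that hits either the module's name or its class name, with exact-string targets sorted ahead of "re:" regex targets. A standalone sketch mirroring `_find_matches` from the hunk above (not importing the package), with a small worked example:

```python
import re
from typing import Iterable, List


def _find_matches(value: str, targets: Iterable[str], check_contains: bool = False) -> List[str]:
    # mirror of the helper above: exact match, "re:" regex match, or (optionally) substring match
    matches = []
    for target in targets:
        if target.startswith("re:"):
            if re.match(target[3:], value):
                matches.append(target)
        elif check_contains:
            if target.lower() in value.lower():
                matches.append(target)
        elif target == value:
            matches.append(target)
    return matches


targets = ["re:.*q_proj", "Linear", "model.layers.0.self_attn.q_proj"]
# exact-string targets sort ahead of "re:" targets, matching the priority described above
targets = sorted(targets, key=lambda x: ("re:" in x, x))

name_matches = _find_matches("model.layers.0.self_attn.q_proj", targets)
class_matches = _find_matches("Linear", targets)  # class-name match for an nn.Linear submodule
print(name_matches + class_matches)
# ['model.layers.0.self_attn.q_proj', 're:.*q_proj', 'Linear']
```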
@@ -210,20 +306,84 @@ def _load_quant_args_from_state_dict(
     """
     scale_name = f"{base_name}_scale"
     zp_name = f"{base_name}_zero_point"
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
+    g_idx_name = f"{base_name}_g_idx"
+
+    state_dict_scale = state_dict.get(f"{module_name}.{scale_name}", None)
+    state_dict_zp = state_dict.get(f"{module_name}.{zp_name}", None)
+    state_dict_g_idx = state_dict.get(f"{module_name}.{g_idx_name}", None)
+
+    if state_dict_scale is not None:
+        # module is quantized
+        update_parameter_data(module, state_dict_scale, scale_name)
+        if state_dict_zp is None:
+            # fill in zero point for symmetric quantization
+            state_dict_zp = torch.zeros_like(state_dict_scale, device="cpu")
+        update_parameter_data(module, state_dict_zp, zp_name)
+
+    if state_dict_g_idx is not None:
+        update_parameter_data(module, state_dict_g_idx, g_idx_name)
+
+
+def _scheme_from_targets(
+    target_to_scheme: OrderedDictType[str, QuantizationScheme],
+    targets: List[str],
+    name: str,
+) -> QuantizationScheme:
+    if len(targets) == 1:
+        # if `targets` iterable contains a single element
+        # use it as the key
+        return target_to_scheme[targets[0]]
+
+    # otherwise, we need to merge QuantizationSchemes corresponding
+    # to multiple targets. This is most likely because `name` module
+    # is being target both as an ordinary quantization target, as well
+    # as kv cache quantization target
+    schemes_to_merge = [target_to_scheme[target] for target in targets]
+    return _merge_schemes(schemes_to_merge, name)
+
+
+def _merge_schemes(
+    schemes_to_merge: List[QuantizationScheme], name: str
+) -> QuantizationScheme:
+
+    kv_cache_quantization_scheme = [
+        scheme for scheme in schemes_to_merge if is_kv_cache_quant_scheme(scheme)
+    ]
+    if not kv_cache_quantization_scheme:
+        # if the schemes_to_merge do not contain any
+        # kv cache QuantizationScheme
+        # return the first scheme (the prioritized one,
+        # since the order of schemes_to_merge matters)
+        return schemes_to_merge[0]
+    else:
+        # fetch the kv cache QuantizationScheme and the highest
+        # priority non-kv cache QuantizationScheme and merge them
+        kv_cache_quantization_scheme = kv_cache_quantization_scheme[0]
+        quantization_scheme = [
+            scheme
+            for scheme in schemes_to_merge
+            if not is_kv_cache_quant_scheme(scheme)
+        ][0]
+        schemes_to_merge = [kv_cache_quantization_scheme, quantization_scheme]
+    merged_scheme = {}
+    for scheme in schemes_to_merge:
+        scheme_dict = {
+            k: v for k, v in scheme.model_dump().items() if v is not None
+        }
+        # when merging multiple schemes, the final target will be
+        # the `name` argument - hence erase the original targets
+        del scheme_dict["targets"]
+        # make sure that schemes do not "clash" with each other
+        overlapping_keys = set(merged_scheme.keys()) & set(scheme_dict.keys())
+        if overlapping_keys:
+            raise ValueError(
+                f"The module: {name} is being modified by two clashing "
+                f"quantization schemes, that jointly try to override "
+                f"properties: {overlapping_keys}. Fix the quantization config "
+                "so that it is not ambiguous."
+            )
+        merged_scheme.update(scheme_dict)
+
+    merged_scheme.update(targets=[name])
+
+    return QuantizationScheme(**merged_scheme)
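`_merge_schemes` combines a kv-cache scheme with the regular scheme that also targets the same module, refusing to merge when both set the same field. A standalone sketch of that merge and clash check using plain dicts in place of `QuantizationScheme` objects (field names and the module name below are illustrative):

```python
schemes_to_merge = [
    {"targets": ["re:.*self_attn$"], "output_activations": {"num_bits": 8}},  # kv cache scheme
    {"targets": ["Linear"], "weights": {"num_bits": 4}},                      # weight scheme
]

merged_scheme = {}
for scheme in schemes_to_merge:
    scheme_dict = {k: v for k, v in scheme.items() if v is not None}
    del scheme_dict["targets"]  # the merged scheme is re-targeted at the module name
    overlapping_keys = set(merged_scheme) & set(scheme_dict)
    if overlapping_keys:
        raise ValueError(f"clashing properties: {overlapping_keys}")
    merged_scheme.update(scheme_dict)

merged_scheme["targets"] = ["model.layers.0.self_attn"]  # hypothetical module name
print(merged_scheme)
# {'output_activations': {'num_bits': 8}, 'weights': {'num_bits': 4},
#  'targets': ['model.layers.0.self_attn']}
```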
@@ -16,6 +16,7 @@
 import logging
 
 from compressed_tensors.quantization.quant_config import QuantizationStatus
+from compressed_tensors.utils import is_module_offloaded, update_parameter_data
 from torch.nn import Module
 
 
@@ -27,7 +28,7 @@ __all__ = [
 _LOGGER = logging.getLogger(__name__)
 
 
-def set_module_for_calibration(module: Module):
+def set_module_for_calibration(module: Module, quantize_weights_upfront: bool = True):
     """
     marks a layer as ready for calibration which activates observers
     to update scales and zero points on each forward pass
@@ -35,17 +36,36 @@ def set_module_for_calibration(module: Module):
     apply to full model with `model.apply(set_module_for_calibration)`
 
     :param module: module to set for calibration
+    :param quantize_weights_upfront: whether to automatically
+        run weight quantization at the start of calibration
     """
     if not getattr(module, "quantization_scheme", None):
         # no quantization scheme nothing to do
         return
     status = getattr(module, "quantization_status", None)
     if not status or status != QuantizationStatus.INITIALIZED:
-
+        _LOGGER.warning(
             f"Attempting set module with status {status} to calibration mode. "
             f"but status is not {QuantizationStatus.INITIALIZED} - you may "
             "be calibrating an uninitialized module which may fail or attempting "
             "to re-calibrate a frozen module"
         )
 
+    if quantize_weights_upfront and module.quantization_scheme.weights is not None:
+        # set weight scale and zero_point up front, calibration data doesn't affect it
+        observer = module.weight_observer
+        g_idx = getattr(module, "weight_g_idx", None)
+
+        offloaded = False
+        if is_module_offloaded(module):
+            module._hf_hook.pre_forward(module)
+            offloaded = True
+
+        scale, zero_point = observer(module.weight, g_idx=g_idx)
+        update_parameter_data(module, scale, "weight_scale")
+        update_parameter_data(module, zero_point, "weight_zero_point")
+
+        if offloaded:
+            module._hf_hook.post_forward(module, None)
+
     module.quantization_status = QuantizationStatus.CALIBRATION
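With the new `quantize_weights_upfront` flag, weight scales and zero points are computed directly from the weights when calibration begins, so calibration data only drives the activation observers. A hedged usage sketch, mirroring the `model.apply(...)` call that `apply_quantization_status` issues above (import path taken from the file list):

```python
from torch.nn import Module

from compressed_tensors.quantization.lifecycle.calibration import (
    set_module_for_calibration,
)


def start_calibration(model: Module) -> None:
    # quantize_weights_upfront=True computes weight scale/zero_point immediately;
    # modules without a quantization_scheme are skipped by the helper itself
    model.apply(
        lambda module: set_module_for_calibration(module, quantize_weights_upfront=True)
    )
```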
@@ -49,8 +49,9 @@ def compress_quantized_weights(module: Module):
     weight = getattr(module, "weight", None)
     scale = getattr(module, "weight_scale", None)
     zero_point = getattr(module, "weight_zero_point", None)
+    g_idx = getattr(module, "weight_g_idx", None)
 
-    if weight is None or scale is None
+    if weight is None or scale is None:
         # no weight, scale, or ZP, nothing to do
 
         # mark as compressed here to maintain consistent status throughout the model
@@ -62,6 +63,7 @@ def compress_quantized_weights(module: Module):
         x=weight,
         scale=scale,
         zero_point=zero_point,
+        g_idx=g_idx,
         args=scheme.weights,
         dtype=torch.int8,
     )
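After calibration, the same lifecycle helpers freeze observers and compress the quantized weights; `compress_quantized_weights` now also passes `g_idx` through to the quantize call, as shown above. A hedged sketch of that final step (import paths inferred from the file list; `freeze_module_quantization` and `compress_quantized_weights` are the helpers referenced in the hunks above):

```python
from torch.nn import Module

from compressed_tensors.quantization.lifecycle.compressed import (
    compress_quantized_weights,
)
from compressed_tensors.quantization.lifecycle.frozen import freeze_module_quantization


def freeze_and_compress(model: Module) -> None:
    # freeze observers first, then convert weights to their quantized int8 representation
    model.apply(freeze_module_quantization)
    model.apply(compress_quantized_weights)
```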