compressed-tensors 0.4.0__py3-none-any.whl → 0.5.0__py3-none-any.whl
This diff compares the contents of two publicly released versions of the package from one of the supported registries. It is provided for informational purposes only and reflects the changes between the versions as they appear in their respective public registries.
- compressed_tensors/base.py +1 -0
- compressed_tensors/compressors/__init__.py +5 -1
- compressed_tensors/compressors/base.py +1 -1
- compressed_tensors/compressors/dense.py +1 -1
- compressed_tensors/compressors/marlin_24.py +11 -10
- compressed_tensors/compressors/model_compressor.py +33 -12
- compressed_tensors/compressors/{int_quantized.py → naive_quantized.py} +33 -15
- compressed_tensors/compressors/pack_quantized.py +58 -51
- compressed_tensors/compressors/sparse_bitmask.py +1 -1
- compressed_tensors/config/base.py +2 -0
- compressed_tensors/quantization/lifecycle/__init__.py +1 -0
- compressed_tensors/quantization/lifecycle/apply.py +161 -39
- compressed_tensors/quantization/lifecycle/calibration.py +20 -1
- compressed_tensors/quantization/lifecycle/forward.py +70 -25
- compressed_tensors/quantization/lifecycle/helpers.py +53 -0
- compressed_tensors/quantization/lifecycle/initialize.py +30 -1
- compressed_tensors/quantization/observers/base.py +39 -0
- compressed_tensors/quantization/observers/helpers.py +64 -11
- compressed_tensors/quantization/quant_args.py +45 -1
- compressed_tensors/quantization/quant_config.py +35 -2
- compressed_tensors/quantization/quant_scheme.py +105 -4
- compressed_tensors/quantization/utils/helpers.py +67 -1
- compressed_tensors/utils/__init__.py +4 -0
- compressed_tensors/utils/helpers.py +31 -2
- compressed_tensors/utils/offload.py +104 -0
- compressed_tensors/version.py +1 -1
- {compressed_tensors-0.4.0.dist-info → compressed_tensors-0.5.0.dist-info}/METADATA +2 -1
- compressed_tensors-0.5.0.dist-info/RECORD +48 -0
- {compressed_tensors-0.4.0.dist-info → compressed_tensors-0.5.0.dist-info}/WHEEL +1 -1
- compressed_tensors/compressors/utils/__init__.py +0 -19
- compressed_tensors/compressors/utils/helpers.py +0 -43
- compressed_tensors-0.4.0.dist-info/RECORD +0 -48
- /compressed_tensors/{compressors/utils → utils}/permutations_24.py +0 -0
- /compressed_tensors/{compressors/utils → utils}/semi_structured_conversions.py +0 -0
- {compressed_tensors-0.4.0.dist-info → compressed_tensors-0.5.0.dist-info}/LICENSE +0 -0
- {compressed_tensors-0.4.0.dist-info → compressed_tensors-0.5.0.dist-info}/top_level.txt +0 -0
compressed_tensors/quantization/lifecycle/apply.py

```diff
@@ -15,7 +15,9 @@
 import logging
 import re
 from collections import OrderedDict
-from typing import Dict, Iterable, Optional
+from typing import Dict, Iterable, List, Optional
+from typing import OrderedDict as OrderedDictType
+from typing import Union

 import torch
 from compressed_tensors.quantization.lifecycle.calibration import (
@@ -28,15 +30,20 @@ from compressed_tensors.quantization.lifecycle.frozen import freeze_module_quant
 from compressed_tensors.quantization.lifecycle.initialize import (
     initialize_module_for_quantization,
 )
+from compressed_tensors.quantization.quant_args import QuantizationArgs
 from compressed_tensors.quantization.quant_config import (
     QuantizationConfig,
     QuantizationStatus,
 )
+from compressed_tensors.quantization.quant_scheme import QuantizationScheme
 from compressed_tensors.quantization.utils import (
+    KV_CACHE_TARGETS,
     infer_quantization_status,
+    is_kv_cache_quant_scheme,
     iter_named_leaf_modules,
 )
 from compressed_tensors.utils.helpers import fix_fsdp_module_name
+from compressed_tensors.utils.offload import update_parameter_data
 from compressed_tensors.utils.safetensors_load import get_safetensors_folder
 from torch.nn import Module

@@ -45,7 +52,7 @@ __all__ = [
     "load_pretrained_quantization",
     "apply_quantization_config",
     "apply_quantization_status",
-    "
+    "find_name_or_class_matches",
 ]

 from compressed_tensors.quantization.utils.helpers import is_module_quantized
@@ -96,7 +103,7 @@ def load_pretrained_quantization(model: Module, model_name_or_path: str):
            )


-def apply_quantization_config(model: Module, config: QuantizationConfig):
+def apply_quantization_config(model: Module, config: QuantizationConfig) -> Dict:
     """
     Initializes the model for quantization in-place based on the given config

@@ -106,6 +113,8 @@ def apply_quantization_config(model: Module, config: QuantizationConfig):
     # build mapping of targets to schemes for easier matching
     # use ordered dict to preserve target ordering in config
     target_to_scheme = OrderedDict()
+    config = process_quantization_config(config)
+    names_to_scheme = OrderedDict()
     for scheme in config.config_groups.values():
         for target in scheme.targets:
             target_to_scheme[target] = scheme
@@ -116,13 +125,16 @@ def apply_quantization_config(model: Module, config: QuantizationConfig):
     for name, submodule in iter_named_leaf_modules(model):
         # potentially fix module name to remove FSDP wrapper prefix
         name = fix_fsdp_module_name(name)
-        if
+        if find_name_or_class_matches(name, submodule, config.ignore):
             ignored_submodules.append(name)
             continue  # layer matches ignore list, continue
-
-        if
+        targets = find_name_or_class_matches(name, submodule, target_to_scheme)
+        if targets:
             # target matched - add layer and scheme to target list
-            submodule.quantization_scheme =
+            submodule.quantization_scheme = _scheme_from_targets(
+                target_to_scheme, targets, name
+            )
+            names_to_scheme[name] = submodule.quantization_scheme.weights

     if config.ignore is not None and ignored_submodules is not None:
         if set(config.ignore) - set(ignored_submodules):
@@ -132,7 +144,42 @@ def apply_quantization_config(model: Module, config: QuantizationConfig):
                f"{set(config.ignore) - set(ignored_submodules)}"
            )
     # apply current quantization status across all targeted layers
+
     apply_quantization_status(model, config.quantization_status)
+    return names_to_scheme
+
+
+def process_quantization_config(config: QuantizationConfig) -> QuantizationConfig:
+    """
+    Preprocess the raw QuantizationConfig
+
+    :param config: the raw QuantizationConfig
+    :return: the processed QuantizationConfig
+    """
+    if config.kv_cache_scheme is not None:
+        config = process_kv_cache_config(config)
+
+    return config
+
+
+def process_kv_cache_config(
+    config: QuantizationConfig, targets: Union[List[str], str] = KV_CACHE_TARGETS
+) -> QuantizationConfig:
+    """
+    Reformulate the `config.kv_cache` as a `config_group`
+    and add it to the set of existing `config.groups`
+
+    :param config: the QuantizationConfig
+    :return: the QuantizationConfig with additional "kv_cache" group
+    """
+    kv_cache_dict = config.kv_cache_scheme.model_dump()
+    kv_cache_scheme = QuantizationScheme(
+        output_activations=QuantizationArgs(**kv_cache_dict),
+        targets=targets,
+    )
+    kv_cache_group = dict(kv_cache=kv_cache_scheme)
+    config.config_groups.update(kv_cache_group)
+    return config


 def apply_quantization_status(model: Module, status: QuantizationStatus):
@@ -148,7 +195,14 @@ def apply_quantization_status(model: Module, status: QuantizationStatus):
         model.apply(initialize_module_for_quantization)

     if current_status < status >= QuantizationStatus.CALIBRATION > current_status:
-
+        # only quantize weights up front when our end goal state is calibration,
+        # weight quantization parameters are already loaded for frozen/compressed
+        quantize_weights_upfront = status == QuantizationStatus.CALIBRATION
+        model.apply(
+            lambda module: set_module_for_calibration(
+                module, quantize_weights_upfront=quantize_weights_upfront
+            )
+        )
     if current_status < status >= QuantizationStatus.FROZEN > current_status:
         model.apply(freeze_module_quantization)

@@ -156,36 +210,45 @@ def apply_quantization_status(model: Module, status: QuantizationStatus):
         model.apply(compress_quantized_weights)


-def
+def find_name_or_class_matches(
     name: str, module: Module, targets: Iterable[str], check_contains: bool = False
-) ->
-
-
-
+) -> List[str]:
+    """
+    Returns all targets that match the given name or the class name.
+    Returns empty list otherwise.
+    The order of the output `matches` list matters.
+    The entries are sorted in the following order:
+    1. matches on exact strings
+    2. matches on regex patterns
+    3. matches on module names
+    """
+    targets = sorted(targets, key=lambda x: ("re:" in x, x))
     if isinstance(targets, Iterable):
-
+        matches = _find_matches(name, targets) + _find_matches(
            module.__class__.__name__, targets, check_contains
        )
+        matches = [match for match in matches if match is not None]
+        return matches


-def
+def _find_matches(
     value: str, targets: Iterable[str], check_contains: bool = False
-) ->
-    # returns
+) -> List[str]:
+    # returns all the targets that match value either
     # exactly or as a regex after 're:'. if check_contains is set to True,
     # additionally checks if the target string is contained with value.
-
+    matches = []
     for target in targets:
         if target.startswith("re:"):
             pattern = target[3:]
             if re.match(pattern, value):
-
+                matches.append(target)
         elif check_contains:
             if target.lower() in value.lower():
-
+                matches.append(target)
         elif target == value:
-
-    return
+            matches.append(target)
+    return matches


 def _infer_status(model: Module) -> Optional[QuantizationStatus]:
@@ -210,20 +273,79 @@ def _load_quant_args_from_state_dict(
     """
     scale_name = f"{base_name}_scale"
     zp_name = f"{base_name}_zero_point"
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
+
+    state_dict_scale = state_dict.get(f"{module_name}.{scale_name}", None)
+    state_dict_zp = state_dict.get(f"{module_name}.{zp_name}", None)
+
+    if state_dict_scale is not None:
+        # module is quantized
+        update_parameter_data(module, state_dict_scale, scale_name)
+        if state_dict_zp is None:
+            # fill in zero point for symmetric quantization
+            state_dict_zp = torch.zeros_like(state_dict_scale, device="cpu")
+        update_parameter_data(module, state_dict_zp, zp_name)
+
+
+def _scheme_from_targets(
+    target_to_scheme: OrderedDictType[str, QuantizationScheme],
+    targets: List[str],
+    name: str,
+) -> QuantizationScheme:
+    if len(targets) == 1:
+        # if `targets` iterable contains a single element
+        # use it as the key
+        return target_to_scheme[targets[0]]
+
+    # otherwise, we need to merge QuantizationSchemes corresponding
+    # to multiple targets. This is most likely because `name` module
+    # is being target both as an ordinary quantization target, as well
+    # as kv cache quantization target
+    schemes_to_merge = [target_to_scheme[target] for target in targets]
+    return _merge_schemes(schemes_to_merge, name)
+
+
+def _merge_schemes(
+    schemes_to_merge: List[QuantizationScheme], name: str
+) -> QuantizationScheme:
+
+    kv_cache_quantization_scheme = [
+        scheme for scheme in schemes_to_merge if is_kv_cache_quant_scheme(scheme)
+    ]
+    if not kv_cache_quantization_scheme:
+        # if the schemes_to_merge do not contain any
+        # kv cache QuantizationScheme
+        # return the first scheme (the prioritized one,
+        # since the order of schemes_to_merge matters)
+        return schemes_to_merge[0]
+    else:
+        # fetch the kv cache QuantizationScheme and the highest
+        # priority non-kv cache QuantizationScheme and merge them
+        kv_cache_quantization_scheme = kv_cache_quantization_scheme[0]
+        quantization_scheme = [
+            scheme
+            for scheme in schemes_to_merge
+            if not is_kv_cache_quant_scheme(scheme)
+        ][0]
+        schemes_to_merge = [kv_cache_quantization_scheme, quantization_scheme]
+    merged_scheme = {}
+    for scheme in schemes_to_merge:
+        scheme_dict = {
+            k: v for k, v in scheme.model_dump().items() if v is not None
+        }
+        # when merging multiple schemes, the final target will be
+        # the `name` argument - hence erase the original targets
+        del scheme_dict["targets"]
+        # make sure that schemes do not "clash" with each other
+        overlapping_keys = set(merged_scheme.keys()) & set(scheme_dict.keys())
+        if overlapping_keys:
+            raise ValueError(
+                f"The module: {name} is being modified by two clashing "
+                f"quantization schemes, that jointly try to override "
+                f"properties: {overlapping_keys}. Fix the quantization config "
+                "so that it is not ambiguous."
+            )
+        merged_scheme.update(scheme_dict)
+
+    merged_scheme.update(targets=[name])
+
+    return QuantizationScheme(**merged_scheme)
```
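For illustration, the matching behavior introduced by `find_name_or_class_matches`/`_find_matches` can be reproduced standalone. The sketch below re-implements the helper logic shown in the hunk above (it is not an import of the library function), and the layer name and target list are hypothetical:

```python
import re
from typing import Iterable, List


def _find_matches(value: str, targets: Iterable[str], check_contains: bool = False) -> List[str]:
    # mirror of the diff above: collect every target that matches `value`,
    # either exactly or as a regex after the "re:" prefix
    matches = []
    for target in targets:
        if target.startswith("re:"):
            if re.match(target[3:], value):
                matches.append(target)
        elif check_contains:
            if target.lower() in value.lower():
                matches.append(target)
        elif target == value:
            matches.append(target)
    return matches


# hypothetical module name and targets; "re:" entries sort after exact names,
# so an exact string match takes the first position in the result
targets = ["re:.*q_proj", "model.layers.0.self_attn.q_proj"]
targets = sorted(targets, key=lambda x: ("re:" in x, x))
print(_find_matches("model.layers.0.self_attn.q_proj", targets))
# ['model.layers.0.self_attn.q_proj', 're:.*q_proj']
```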
compressed_tensors/quantization/lifecycle/calibration.py

```diff
@@ -16,6 +16,7 @@
 import logging

 from compressed_tensors.quantization.quant_config import QuantizationStatus
+from compressed_tensors.utils import is_module_offloaded, update_parameter_data
 from torch.nn import Module


@@ -27,7 +28,7 @@ __all__ = [
 _LOGGER = logging.getLogger(__name__)


-def set_module_for_calibration(module: Module):
+def set_module_for_calibration(module: Module, quantize_weights_upfront: bool = True):
     """
     marks a layer as ready for calibration which activates observers
     to update scales and zero points on each forward pass
@@ -35,6 +36,8 @@ def set_module_for_calibration(module: Module):
     apply to full model with `model.apply(set_module_for_calibration)`

     :param module: module to set for calibration
+    :param quantize_weights_upfront: whether to automatically run weight quantization at the
+        start of calibration
     """
     if not getattr(module, "quantization_scheme", None):
         # no quantization scheme nothing to do
@@ -48,4 +51,20 @@ def set_module_for_calibration(module: Module):
            "to re-calibrate a frozen module"
        )

+    if quantize_weights_upfront and module.quantization_scheme.weights is not None:
+        # set weight scale and zero_point up front, calibration data doesn't affect it
+        observer = module.weight_observer
+
+        offloaded = False
+        if is_module_offloaded(module):
+            module._hf_hook.pre_forward(module)
+            offloaded = True
+
+        scale, zero_point = observer(module.weight)
+        update_parameter_data(module, scale, "weight_scale")
+        update_parameter_data(module, zero_point, "weight_zero_point")
+
+        if offloaded:
+            module._hf_hook.post_forward(module, None)
+
     module.quantization_status = QuantizationStatus.CALIBRATION
```
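As a hedged usage sketch (not a verified end-to-end script): the import paths and function names below come from the hunks above, while `model`, `config`, and `calibration_dataloader` are placeholders assumed to already exist. Moving a model into the CALIBRATION status now also quantizes weights up front via `set_module_for_calibration(..., quantize_weights_upfront=True)`.

```python
from compressed_tensors.quantization.lifecycle.apply import (
    apply_quantization_config,
    apply_quantization_status,
)
from compressed_tensors.quantization.quant_config import QuantizationStatus

# apply_quantization_config now returns a mapping of module names to weight
# quantization args (see the apply.py hunk above)
names_to_scheme = apply_quantization_config(model, config)

# entering CALIBRATION sets weight scales/zero points immediately and turns
# on observers for activations
apply_quantization_status(model, QuantizationStatus.CALIBRATION)

for batch in calibration_dataloader:  # placeholder calibration data
    model(**batch)

# freeze observers once calibration statistics have been collected
apply_quantization_status(model, QuantizationStatus.FROZEN)
```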
compressed_tensors/quantization/lifecycle/forward.py

```diff
@@ -17,12 +17,15 @@ from math import ceil
 from typing import Optional

 import torch
+from compressed_tensors.quantization.observers.helpers import calculate_range
 from compressed_tensors.quantization.quant_args import (
     QuantizationArgs,
     QuantizationStrategy,
+    round_to_quantized_type,
 )
 from compressed_tensors.quantization.quant_config import QuantizationStatus
 from compressed_tensors.quantization.quant_scheme import QuantizationScheme
+from compressed_tensors.utils import update_parameter_data
 from torch.nn import Module


@@ -80,8 +83,9 @@ def quantize(
 def dequantize(
     x_q: torch.Tensor,
     scale: torch.Tensor,
-    zero_point: torch.Tensor,
+    zero_point: torch.Tensor = None,
     args: QuantizationArgs = None,
+    dtype: Optional[torch.dtype] = None,
 ) -> torch.Tensor:
     """
     Dequantize a quantized input tensor x_q based on the strategy specified in args. If
@@ -91,6 +95,7 @@ def dequantize(
     :param scale: scale tensor
     :param zero_point: zero point tensor
     :param args: quantization args used to quantize x_q
+    :param dtype: optional dtype to cast the dequantized output to
     :return: dequantized float tensor
     """
     if args is None:
@@ -107,8 +112,12 @@ def dequantize(
     else:
         raise ValueError(
             f"Could not infer a quantization strategy from scale with {scale.ndim} "
-            "dimmensions. Expected 0
+            "dimmensions. Expected 0 or 2 dimmensions."
         )
+
+    if dtype is None:
+        dtype = scale.dtype
+
     return _process_quantization(
         x=x_q,
         scale=scale,
@@ -116,6 +125,7 @@
         args=args,
         do_quantize=False,
         do_dequantize=True,
+        dtype=dtype,
     )


@@ -159,19 +169,13 @@ def _process_quantization(
     do_quantize: bool = True,
     do_dequantize: bool = True,
 ) -> torch.Tensor:
-
-    q_max =
-    q_min = torch.tensor(-bit_range / 2, device=x.device)
+
+    q_min, q_max = calculate_range(args, x.device)
     group_size = args.group_size

     if args.strategy == QuantizationStrategy.GROUP:
-
-
-            # if dequantizing a quantized type infer the output type from the scale
-            output = torch.zeros_like(x, dtype=scale.dtype)
-        else:
-            output_dtype = dtype if dtype is not None else x.dtype
-            output = torch.zeros_like(x, dtype=output_dtype)
+        output_dtype = dtype if dtype is not None else x.dtype
+        output = torch.zeros_like(x).to(output_dtype)

         # TODO: vectorize the for loop
         # TODO: fix genetric assumption about the tensor size for computing group
@@ -181,7 +185,7 @@
         while scale.ndim < 2:
             # pad scale and zero point dims for slicing
             scale = scale.unsqueeze(1)
-            zero_point = zero_point.unsqueeze(1)
+            zero_point = zero_point.unsqueeze(1) if zero_point is not None else None

         columns = x.shape[1]
         if columns >= group_size:
@@ -194,12 +198,18 @@
                # scale.shape should be [nchan, ndim]
                # sc.shape should be [nchan, 1] after unsqueeze
                sc = scale[:, i].view(-1, 1)
-                zp = zero_point[:, i].view(-1, 1)
+                zp = zero_point[:, i].view(-1, 1) if zero_point is not None else None

                idx = i * group_size
                if do_quantize:
                    output[:, idx : (idx + group_size)] = _quantize(
-                        x[:, idx : (idx + group_size)],
+                        x[:, idx : (idx + group_size)],
+                        sc,
+                        zp,
+                        q_min,
+                        q_max,
+                        args,
+                        dtype=dtype,
                    )
                if do_dequantize:
                    input = (
@@ -211,7 +221,15 @@

     else:  # covers channel, token and tensor strategies
         if do_quantize:
-            output = _quantize(
+            output = _quantize(
+                x,
+                scale,
+                zero_point,
+                q_min,
+                q_max,
+                args,
+                dtype=dtype,
+            )
         if do_dequantize:
             output = _dequantize(output if do_quantize else x, scale, zero_point)

@@ -228,6 +246,11 @@ def wrap_module_forward_quantized(module: Module, scheme: QuantizationScheme):

     @wraps(forward_func_orig)  # ensures docstring, names, etc are propagated
     def wrapped_forward(self, *args, **kwargs):
+        if not getattr(module, "quantization_enabled", True):
+            # quantization is disabled on forward passes, return baseline
+            # forward call
+            return forward_func_orig.__get__(module, module.__class__)(*args, **kwargs)
+
         input_ = args[0]

         if scheme.input_activations is not None:
@@ -276,6 +299,11 @@ def maybe_calibrate_or_quantize(
     }:
         return value

+    if value.numel() == 0:
+        # if the tensor is empty,
+        # skip quantization
+        return value
+
     if args.dynamic:
         # dynamic quantization - get scale and zero point directly from observer
         observer = getattr(module, f"{base_name}_observer")
@@ -285,16 +313,19 @@
         scale = getattr(module, f"{base_name}_scale")
         zero_point = getattr(module, f"{base_name}_zero_point")

-    if
+    if (
+        module.quantization_status == QuantizationStatus.CALIBRATION
+        and base_name != "weight"
+    ):
         # calibration mode - get new quant params from observer
         observer = getattr(module, f"{base_name}_observer")

         updated_scale, updated_zero_point = observer(value)

         # update scale and zero point
-
-
-
+        update_parameter_data(module, updated_scale, f"{base_name}_scale")
+        update_parameter_data(module, updated_zero_point, f"{base_name}_zero_point")
+
     return fake_quantize(value, scale, zero_point, args)


@@ -305,14 +336,18 @@ def _quantize(
     zero_point: torch.Tensor,
     q_min: torch.Tensor,
     q_max: torch.Tensor,
+    args: QuantizationArgs,
     dtype: Optional[torch.dtype] = None,
 ) -> torch.Tensor:
-
-
+
+    scaled = x / scale + zero_point.to(x.dtype)
+    # clamp first because cast isn't guaranteed to be saturated (ie for fp8)
+    clamped_value = torch.clamp(
+        scaled,
         q_min,
         q_max,
     )
-
+    quantized_value = round_to_quantized_type(clamped_value, args)
     if dtype is not None:
         quantized_value = quantized_value.to(dtype)

@@ -323,6 +358,16 @@ def _quantize(
 def _dequantize(
     x_q: torch.Tensor,
     scale: torch.Tensor,
-    zero_point: torch.Tensor,
+    zero_point: torch.Tensor = None,
+    dtype: Optional[torch.dtype] = None,
 ) -> torch.Tensor:
-
+
+    dequant_value = x_q
+    if zero_point is not None:
+        dequant_value = dequant_value - zero_point.to(scale.dtype)
+    dequant_value = dequant_value.to(scale.dtype) * scale
+
+    if dtype is not None:
+        dequant_value = dequant_value.to(dtype)
+
+    return dequant_value
```
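The arithmetic in `_quantize`/`_dequantize` above reduces to `clamp(x / scale + zero_point)` followed by rounding on the way in, and `(x_q - zero_point) * scale` on the way out, with the zero point now optional for symmetric schemes. Below is a minimal standalone round-trip sketch in plain torch (these are toy functions, not the library's, and they assume an 8-bit integer range):

```python
import torch


def toy_quantize(x, scale, zero_point=None, q_min=-128, q_max=127):
    # mirror of _quantize above: scale, shift, clamp, then round onto the
    # quantized integer grid
    zp = zero_point if zero_point is not None else torch.zeros_like(scale)
    scaled = x / scale + zp
    return torch.round(torch.clamp(scaled, q_min, q_max))


def toy_dequantize(x_q, scale, zero_point=None):
    # mirror of _dequantize above: subtract the (optional) zero point and
    # rescale back to the floating point domain
    dq = x_q
    if zero_point is not None:
        dq = dq - zero_point
    return dq * scale


x = torch.tensor([0.05, -0.30, 0.72])
scale = torch.tensor(0.01)
x_q = toy_quantize(x, scale)        # tensor([  5., -30.,  72.])
x_hat = toy_dequantize(x_q, scale)  # tensor([ 0.0500, -0.3000,  0.7200])
print(x_q, x_hat)
```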
compressed_tensors/quantization/lifecycle/helpers.py (new file)

```diff
@@ -0,0 +1,53 @@
+# Copyright (c) 2021 - present / Neuralmagic, Inc. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#    http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""
+Miscelaneous helpers for the quantization lifecycle
+"""
+
+
+from torch.nn import Module
+
+
+__all__ = [
+    "update_layer_weight_quant_params",
+    "enable_quantization",
+    "disable_quantization",
+]
+
+
+def update_layer_weight_quant_params(layer: Module):
+    weight = getattr(layer, "weight", None)
+    scale = getattr(layer, "weight_scale", None)
+    zero_point = getattr(layer, "weight_zero_point", None)
+    observer = getattr(layer, "weight_observer", None)
+
+    if weight is None or observer is None or scale is None or zero_point is None:
+        # scale, zp, or observer not calibratable or weight not available
+        return
+
+    updated_scale, updated_zero_point = observer(weight)
+
+    # update scale and zero point
+    device = next(layer.parameters()).device
+    scale.data = updated_scale.to(device)
+    zero_point.data = updated_zero_point.to(device)
+
+
+def enable_quantization(module: Module):
+    module.quantization_enabled = True
+
+
+def disable_quantization(module: Module):
+    module.quantization_enabled = False
```
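The wrapped forward pass added in forward.py checks a `quantization_enabled` attribute, which these new helpers toggle. A hedged sketch of how they could be applied model-wide, assuming the module path shown in the new file and an already-initialized `model` and `sample_input` (both placeholders):

```python
from compressed_tensors.quantization.lifecycle.helpers import (
    disable_quantization,
    enable_quantization,
)

# temporarily bypass the quantized forward wrapper on every submodule,
# e.g. to compare against the unquantized baseline
model.apply(disable_quantization)
baseline_output = model(sample_input)  # placeholder input

# re-enable fake quantization afterwards
model.apply(enable_quantization)
quantized_output = model(sample_input)
```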
compressed_tensors/quantization/lifecycle/initialize.py

```diff
@@ -17,6 +17,8 @@ import logging
 from typing import Optional

 import torch
+from accelerate.hooks import add_hook_to_module, remove_hook_from_module
+from accelerate.utils import PrefixedDataset
 from compressed_tensors.quantization.lifecycle.forward import (
     wrap_module_forward_quantized,
 )
@@ -26,6 +28,7 @@ from compressed_tensors.quantization.quant_args import (
 )
 from compressed_tensors.quantization.quant_config import QuantizationStatus
 from compressed_tensors.quantization.quant_scheme import QuantizationScheme
+from compressed_tensors.utils import get_execution_device, is_module_offloaded
 from torch.nn import Module, Parameter


@@ -81,9 +84,32 @@ def initialize_module_for_quantization(
     module.quantization_scheme = scheme
     module.quantization_status = QuantizationStatus.INITIALIZED

+    offloaded = False
+    if is_module_offloaded(module):
+        offloaded = True
+        hook = module._hf_hook
+        prefix_dict = module._hf_hook.weights_map
+        new_prefix = {}
+
+        # recreate the prefix dict (since it is immutable)
+        # and add quantization parameters
+        for key, data in module.named_parameters():
+            if key not in prefix_dict:
+                new_prefix[f"{prefix_dict.prefix}{key}"] = data
+            else:
+                new_prefix[f"{prefix_dict.prefix}{key}"] = prefix_dict[key]
+        new_prefix_dict = PrefixedDataset(new_prefix, prefix_dict.prefix)
+        remove_hook_from_module(module)
+
     # wrap forward call of module to perform quantized actions based on calltime status
     wrap_module_forward_quantized(module, scheme)

+    if offloaded:
+        # we need to re-add the hook for offloading now that we've wrapped forward
+        add_hook_to_module(module, hook)
+        if prefix_dict is not None:
+            module._hf_hook.weights_map = new_prefix_dict
+

 def _initialize_scale_zero_point_observer(
     module: Module,
@@ -99,6 +125,8 @@ def _initialize_scale_zero_point_observer(
         return  # no need to register a scale and zero point for a dynamic observer

     device = next(module.parameters()).device
+    if is_module_offloaded(module):
+        device = get_execution_device(module)

     # infer expected scale/zero point shape
     expected_shape = 1  # per tensor
@@ -120,8 +148,9 @@ def _initialize_scale_zero_point_observer(
     )
     module.register_parameter(f"{base_name}_scale", init_scale)

+    zp_dtype = quantization_args.pytorch_dtype()
     init_zero_point = Parameter(
-        torch.empty(expected_shape, device=device, dtype=
+        torch.empty(expected_shape, device=device, dtype=zp_dtype),
         requires_grad=False,
     )
     module.register_parameter(f"{base_name}_zero_point", init_zero_point)
```