compressed-tensors 0.6.0__py3-none-any.whl → 0.7.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- compressed_tensors/__init__.py +1 -0
- compressed_tensors/base.py +2 -0
- compressed_tensors/compressors/__init__.py +6 -12
- compressed_tensors/compressors/base.py +38 -102
- compressed_tensors/compressors/helpers.py +6 -6
- compressed_tensors/compressors/model_compressors/__init__.py +17 -0
- compressed_tensors/compressors/{model_compressor.py → model_compressors/model_compressor.py} +95 -106
- compressed_tensors/compressors/quantized_compressors/__init__.py +18 -0
- compressed_tensors/compressors/quantized_compressors/base.py +146 -0
- compressed_tensors/compressors/{naive_quantized.py → quantized_compressors/naive_quantized.py} +11 -11
- compressed_tensors/compressors/{pack_quantized.py → quantized_compressors/pack_quantized.py} +6 -3
- compressed_tensors/compressors/sparse_compressors/__init__.py +18 -0
- compressed_tensors/compressors/sparse_compressors/base.py +110 -0
- compressed_tensors/compressors/{dense.py → sparse_compressors/dense.py} +3 -3
- compressed_tensors/compressors/{sparse_bitmask.py → sparse_compressors/sparse_bitmask.py} +14 -59
- compressed_tensors/compressors/sparse_quantized_compressors/__init__.py +16 -0
- compressed_tensors/compressors/{marlin_24.py → sparse_quantized_compressors/marlin_24.py} +3 -3
- compressed_tensors/linear/compressed_linear.py +2 -2
- compressed_tensors/quantization/__init__.py +1 -0
- compressed_tensors/quantization/cache.py +201 -0
- compressed_tensors/quantization/lifecycle/apply.py +19 -3
- compressed_tensors/quantization/lifecycle/calibration.py +2 -3
- compressed_tensors/quantization/lifecycle/forward.py +52 -3
- compressed_tensors/quantization/lifecycle/frozen.py +6 -1
- compressed_tensors/quantization/lifecycle/helpers.py +0 -47
- compressed_tensors/quantization/lifecycle/initialize.py +110 -62
- compressed_tensors/quantization/quant_args.py +6 -0
- compressed_tensors/quantization/quant_config.py +14 -2
- compressed_tensors/quantization/quant_scheme.py +5 -4
- compressed_tensors/quantization/utils/helpers.py +43 -18
- compressed_tensors/utils/helpers.py +17 -1
- compressed_tensors/version.py +1 -1
- {compressed_tensors-0.6.0.dist-info → compressed_tensors-0.7.0.dist-info}/METADATA +1 -1
- compressed_tensors-0.7.0.dist-info/RECORD +59 -0
- compressed_tensors-0.6.0.dist-info/RECORD +0 -52
- {compressed_tensors-0.6.0.dist-info → compressed_tensors-0.7.0.dist-info}/LICENSE +0 -0
- {compressed_tensors-0.6.0.dist-info → compressed_tensors-0.7.0.dist-info}/WHEEL +0 -0
- {compressed_tensors-0.6.0.dist-info → compressed_tensors-0.7.0.dist-info}/top_level.txt +0 -0
@@ -0,0 +1,201 @@
|
|
1
|
+
# Copyright (c) 2021 - present / Neuralmagic, Inc. All Rights Reserved.
|
2
|
+
#
|
3
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
4
|
+
# you may not use this file except in compliance with the License.
|
5
|
+
# You may obtain a copy of the License at
|
6
|
+
#
|
7
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
8
|
+
#
|
9
|
+
# Unless required by applicable law or agreed to in writing,
|
10
|
+
# software distributed under the License is distributed on an "AS IS" BASIS,
|
11
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12
|
+
# See the License for the specific language governing permissions and
|
13
|
+
# limitations under the License.
|
14
|
+
|
15
|
+
|
16
|
+
from enum import Enum
|
17
|
+
from typing import Any, Dict, List, Optional, Tuple
|
18
|
+
|
19
|
+
from compressed_tensors.quantization.observers import Observer
|
20
|
+
from compressed_tensors.quantization.quant_args import QuantizationArgs
|
21
|
+
from torch import Tensor
|
22
|
+
from transformers import DynamicCache as HFDyanmicCache
|
23
|
+
|
24
|
+
|
25
|
+
class KVCacheScaleType(Enum):
    """
    Identifies which half of the KV cache a scale belongs to.

    Each member's value doubles as the name of the parameter that holds the
    per-layer scale on an attention module.
    """

    KEY = "k_scale"  # scale applied to key states
    VALUE = "v_scale"  # scale applied to value states
|
28
|
+
|
29
|
+
|
30
|
+
class QuantizedKVParameterCache(HFDyanmicCache):
    """
    Quantized KV cache used in the forward call, based on HF's dynamic cache.

    The quantization strategy (tensor, group, channel) is taken from the
    ``strategy`` of the QuantizationArgs passed at construction.

    This class is a singleton: every construction returns the same instance,
    so one cache is reused by the forward calls of all self_attn modules.
    Constructor arguments passed after the first construction are ignored
    until :meth:`reset` is called.

    Each time forward is called, :meth:`update` is called, and
    ``_quantize()`` / ``_dequantize()`` are applied to the key and value
    states, recording per-layer scales and zero points.

    The size of tensor is
    `[batch_size, num_heads, seq_len - residual_length, head_dim]`.

    Triggered by adding kv_cache_scheme in the recipe.

    Example:

    ```python3
    recipe = '''
    quant_stage:
        quant_modifiers:
            QuantizationModifier:
                kv_cache_scheme:
                    num_bits: 8
                    type: float
                    strategy: tensor
                    dynamic: false
                    symmetric: true
    '''
    ```
    """

    _instance = None
    _initialized = False

    def __new__(cls, *args, **kwargs):
        """Singleton: return the shared instance, creating it on first use."""
        if cls._instance is None:
            cls._instance = super(QuantizedKVParameterCache, cls).__new__(cls)
        return cls._instance

    def __init__(self, quantization_args: QuantizationArgs):
        # __init__ runs on every construction of the singleton; only the first
        # call (or the first one after reset()) initializes the state.
        if not self._initialized:
            super().__init__()

            self.quantization_args = quantization_args

            # one observer per attention layer, indexed by layer_idx
            self.k_observers: List[Observer] = []
            self.v_observers: List[Observer] = []

            # each index corresponds to layer_idx of the attention layer;
            # None is a placeholder for a layer index not observed yet
            self.k_scales: List[Optional[Tensor]] = []
            self.v_scales: List[Optional[Tensor]] = []

            self.k_zps: List[Optional[Tensor]] = []
            self.v_zps: List[Optional[Tensor]] = []

            self._initialized = True

    def update(
        self,
        key_states: Tensor,
        value_states: Tensor,
        layer_idx: int,
        cache_kwargs: Optional[Dict[str, Any]] = None,
    ) -> Tuple[Tensor, Tensor]:
        """
        Get the k_scale and v_scale and output the
        fakequant-ed key_states and value_states

        :param key_states: key states of the attention layer
        :param value_states: value states of the attention layer
        :param layer_idx: index of the attention layer calling the cache
        :param cache_kwargs: accepted for HF Cache API compatibility; unused
        :return: (fake-quantized keys, fake-quantized values)
        """
        # Lazily create observers. Layer indices may arrive non-contiguously
        # (e.g. an attention layer excluded from quantization is never
        # wrapped), so pad up to layer_idx rather than appending only once.
        while len(self.k_observers) <= layer_idx:
            self.k_observers.append(self.quantization_args.get_observer())
            self.v_observers.append(self.quantization_args.get_observer())

        q_key_states = self._quantize(
            key_states.contiguous(), KVCacheScaleType.KEY, layer_idx
        )
        q_value_states = self._quantize(
            value_states.contiguous(), KVCacheScaleType.VALUE, layer_idx
        )

        qdq_key_states = self._dequantize(q_key_states, KVCacheScaleType.KEY, layer_idx)
        qdq_value_states = self._dequantize(
            q_value_states, KVCacheScaleType.VALUE, layer_idx
        )

        return qdq_key_states, qdq_value_states

    def get_seq_length(self, layer_idx: Optional[int] = 0) -> int:
        """
        Returns the sequence length of the cached states.
        A layer index can be optionally passed; ``None`` is treated as 0.
        """
        if layer_idx is None:
            layer_idx = 0
        if len(self.key_cache) <= layer_idx:
            return 0
        # since we cannot get the seq_length of each layer directly and
        # rely on `_seen_tokens` which is updated every "layer_idx" == 0,
        # this is a hack to get the actual seq_length for the given layer_idx
        # this part of code otherwise fails when used to
        # verify attn_weight shape in some models
        return self._seen_tokens if layer_idx == 0 else self._seen_tokens - 1

    def reset_states(self):
        """reset the kv states (used in calibration)"""
        self.key_cache: List[Tensor] = []
        self.value_cache: List[Tensor] = []
        # Used in `generate` to keep tally of how many tokens the cache has seen
        self._seen_tokens = 0
        self._quantized_key_cache: List[Tensor] = []
        self._quantized_value_cache: List[Tensor] = []

    def reset(self):
        """
        Reset the instantiation, create new instance on init
        """
        QuantizedKVParameterCache._instance = None
        QuantizedKVParameterCache._initialized = False

    def _per_layer_state(self, kv_type):
        """Return the (observers, scales, zero_points) lists for ``kv_type``."""
        if kv_type == KVCacheScaleType.KEY:
            return self.k_observers, self.k_scales, self.k_zps
        if kv_type == KVCacheScaleType.VALUE:
            return self.v_observers, self.v_scales, self.v_zps
        raise ValueError(f"Unsupported KV cache scale type: {kv_type}")

    def _quantize(self, tensor, kv_type, layer_idx):
        """Quantizes a key/value using a defined quantization method."""
        from compressed_tensors.quantization.lifecycle.forward import quantize

        observers, scales, zps = self._per_layer_state(kv_type)
        observer = observers[layer_idx]

        scale, zp = observer(tensor)

        # pad placeholders for any skipped layer indices, then record this
        # layer's params so _dequantize can read them back
        while len(scales) <= layer_idx:
            scales.append(None)
            zps.append(None)
        scales[layer_idx] = scale
        zps[layer_idx] = zp  # the zero point, not the scale

        q_tensor = quantize(
            x=tensor,
            scale=scale,
            zero_point=zp,
            args=self.quantization_args,
        )
        return q_tensor

    def _dequantize(self, qtensor, kv_type, layer_idx):
        """Dequantizes back the tensor that was quantized by `self._quantize()`"""
        from compressed_tensors.quantization.lifecycle.forward import dequantize

        _, scales, zps = self._per_layer_state(kv_type)
        scale = scales[layer_idx]
        zp = zps[layer_idx]

        qdq_tensor = dequantize(
            x_q=qtensor,
            scale=scale,
            zero_point=zp,
            args=self.quantization_args,
        )
        return qdq_tensor
|
@@ -43,6 +43,7 @@ from compressed_tensors.quantization.utils import (
|
|
43
43
|
infer_quantization_status,
|
44
44
|
is_kv_cache_quant_scheme,
|
45
45
|
iter_named_leaf_modules,
|
46
|
+
iter_named_quantizable_modules,
|
46
47
|
)
|
47
48
|
from compressed_tensors.utils.helpers import fix_fsdp_module_name, replace_module
|
48
49
|
from compressed_tensors.utils.offload import update_parameter_data
|
@@ -106,8 +107,8 @@ def load_pretrained_quantization(model: Module, model_name_or_path: str):
|
|
106
107
|
|
107
108
|
|
108
109
|
def apply_quantization_config(
|
109
|
-
model: Module, config: QuantizationConfig, run_compressed: bool = False
|
110
|
-
) ->
|
110
|
+
model: Module, config: Union[QuantizationConfig, None], run_compressed: bool = False
|
111
|
+
) -> OrderedDict:
|
111
112
|
"""
|
112
113
|
Initializes the model for quantization in-place based on the given config
|
113
114
|
|
@@ -116,6 +117,10 @@ def apply_quantization_config(
|
|
116
117
|
:param run_compressed: Whether the model will be run in compressed mode or
|
117
118
|
decompressed fully on load
|
118
119
|
"""
|
120
|
+
# Workaround for when HF Quantizer passes None, see PR #180
|
121
|
+
if config is None:
|
122
|
+
return OrderedDict()
|
123
|
+
|
119
124
|
# remove reference to the original `config`
|
120
125
|
# argument. This function can mutate it, and we'd
|
121
126
|
# like to keep the original `config` as it is.
|
@@ -135,15 +140,23 @@ def apply_quantization_config(
|
|
135
140
|
# list of submodules to ignore
|
136
141
|
ignored_submodules = defaultdict(list)
|
137
142
|
# mark appropriate layers for quantization by setting their quantization schemes
|
138
|
-
for name, submodule in
|
143
|
+
for name, submodule in iter_named_quantizable_modules(
|
144
|
+
model,
|
145
|
+
include_children=True,
|
146
|
+
include_attn=True,
|
147
|
+
): # child modules and attention modules
|
139
148
|
# potentially fix module name to remove FSDP wrapper prefix
|
140
149
|
name = fix_fsdp_module_name(name)
|
141
150
|
if matches := find_name_or_class_matches(name, submodule, config.ignore):
|
142
151
|
for match in matches:
|
143
152
|
ignored_submodules[match].append(name)
|
144
153
|
continue # layer matches ignore list, continue
|
154
|
+
|
145
155
|
targets = find_name_or_class_matches(name, submodule, target_to_scheme)
|
156
|
+
|
146
157
|
if targets:
|
158
|
+
# mark modules to be quantized by adding
|
159
|
+
# quant scheme to the matching layers
|
147
160
|
scheme = _scheme_from_targets(target_to_scheme, targets, name)
|
148
161
|
if run_compressed:
|
149
162
|
format = config.format
|
@@ -200,6 +213,9 @@ def process_kv_cache_config(
|
|
200
213
|
:param config: the QuantizationConfig
|
201
214
|
:return: the QuantizationConfig with additional "kv_cache" group
|
202
215
|
"""
|
216
|
+
if targets == KV_CACHE_TARGETS:
|
217
|
+
_LOGGER.info(f"KV cache targets set to default value of: {KV_CACHE_TARGETS}")
|
218
|
+
|
203
219
|
kv_cache_dict = config.kv_cache_scheme.model_dump()
|
204
220
|
kv_cache_scheme = QuantizationScheme(
|
205
221
|
output_activations=QuantizationArgs(**kv_cache_dict),
|
@@ -56,10 +56,9 @@ def set_module_for_calibration(module: Module, quantize_weights_upfront: bool =
|
|
56
56
|
observer = module.weight_observer
|
57
57
|
g_idx = getattr(module, "weight_g_idx", None)
|
58
58
|
|
59
|
-
offloaded =
|
60
|
-
if
|
59
|
+
offloaded = is_module_offloaded(module)
|
60
|
+
if offloaded:
|
61
61
|
module._hf_hook.pre_forward(module)
|
62
|
-
offloaded = True
|
63
62
|
|
64
63
|
scale, zero_point = observer(module.weight, g_idx=g_idx)
|
65
64
|
update_parameter_data(module, scale, "weight_scale")
|
@@ -14,9 +14,10 @@
|
|
14
14
|
|
15
15
|
from functools import wraps
|
16
16
|
from math import ceil
|
17
|
-
from typing import Optional
|
17
|
+
from typing import Callable, Optional
|
18
18
|
|
19
19
|
import torch
|
20
|
+
from compressed_tensors.quantization.cache import QuantizedKVParameterCache
|
20
21
|
from compressed_tensors.quantization.observers.helpers import calculate_range
|
21
22
|
from compressed_tensors.quantization.quant_args import (
|
22
23
|
QuantizationArgs,
|
@@ -62,6 +63,7 @@ def quantize(
|
|
62
63
|
:param g_idx: optional mapping from column index to group index
|
63
64
|
:return: fake quantized tensor
|
64
65
|
"""
|
66
|
+
|
65
67
|
return _process_quantization(
|
66
68
|
x=x,
|
67
69
|
scale=scale,
|
@@ -165,8 +167,8 @@ def _process_quantization(
|
|
165
167
|
x: torch.Tensor,
|
166
168
|
scale: torch.Tensor,
|
167
169
|
zero_point: torch.Tensor,
|
168
|
-
g_idx: Optional[torch.Tensor],
|
169
170
|
args: QuantizationArgs,
|
171
|
+
g_idx: Optional[torch.Tensor] = None,
|
170
172
|
dtype: Optional[torch.dtype] = None,
|
171
173
|
do_quantize: bool = True,
|
172
174
|
do_dequantize: bool = True,
|
@@ -266,6 +268,7 @@ def wrap_module_forward_quantized(module: Module, scheme: QuantizationScheme):
|
|
266
268
|
return forward_func_orig.__get__(module, module.__class__)(*args, **kwargs)
|
267
269
|
|
268
270
|
input_ = args[0]
|
271
|
+
|
269
272
|
compressed = module.quantization_status == QuantizationStatus.COMPRESSED
|
270
273
|
|
271
274
|
if scheme.input_activations is not None:
|
@@ -285,9 +288,11 @@ def wrap_module_forward_quantized(module: Module, scheme: QuantizationScheme):
|
|
285
288
|
output = forward_func_orig.__get__(module, module.__class__)(
|
286
289
|
input_, *args[1:], **kwargs
|
287
290
|
)
|
288
|
-
|
289
291
|
if scheme.output_activations is not None:
|
292
|
+
|
290
293
|
# calibrate and (fake) quantize output activations when applicable
|
294
|
+
# kv_cache scales updated on model self_attn forward call in
|
295
|
+
# wrap_module_forward_quantized_attn
|
291
296
|
output = maybe_calibrate_or_quantize(
|
292
297
|
module, output, "output", scheme.output_activations
|
293
298
|
)
|
@@ -304,6 +309,50 @@ def wrap_module_forward_quantized(module: Module, scheme: QuantizationScheme):
|
|
304
309
|
setattr(module, "forward", bound_wrapped_forward)
|
305
310
|
|
306
311
|
|
312
|
+
def wrap_module_forward_quantized_attn(module: Module, scheme: QuantizationScheme):
    """
    Replace ``module.forward`` in-place with a wrapper for attention modules.

    While the module's quantization status is CALIBRATION, the wrapper passes
    the KV cache from ``scheme.output_activations.get_kv_cache()`` as
    ``past_key_value`` (with ``use_cache=False``), runs the original forward,
    and then writes that cache's per-layer key/value scales for
    ``module.layer_idx`` into the module's ``k_scale`` / ``v_scale``
    parameters. In any other status the original forward is called unchanged.

    Expects a module already initialized and injected with the parameters in
    initialize_module_for_quantization.

    :param module: attention module whose forward is wrapped
    :param scheme: quantization scheme whose ``output_activations`` args
        configure the KV cache
    """
    # plain function behind the current forward: bound method (__func__)
    # or functools.partial-like object (.func)
    original_forward = (
        module.forward.__func__
        if hasattr(module.forward, "__func__")
        else module.forward.func
    )

    @wraps(original_forward)  # ensures docstring, names, etc are propagated
    def wrapped_forward(self, *args, **kwargs):
        bound_forward: Callable = original_forward.__get__(module, module.__class__)

        calibrating = module.quantization_status == QuantizationStatus.CALIBRATION
        if not calibrating:
            return bound_forward(*args, **kwargs)

        # kv cache scales end up stored on the module like weight params
        kv_args: QuantizationArgs = scheme.output_activations
        kv_cache: QuantizedKVParameterCache = kv_args.get_kv_cache()
        kwargs["past_key_value"] = kv_cache

        # the cache is used only to obtain k_scale / v_scale; it does not
        # keep quantized key/value states between calls
        kwargs["use_cache"] = False

        kv_cache.reset_states()
        output = bound_forward(*args, **kwargs)

        layer = module.layer_idx
        update_parameter_data(module, kv_cache.k_scales[layer], "k_scale")
        update_parameter_data(module, kv_cache.v_scales[layer], "v_scale")

        return output

    # bind to the module so `self` resolves correctly, then install it
    setattr(module, "forward", wrapped_forward.__get__(module, module.__class__))
|
354
|
+
|
355
|
+
|
307
356
|
def maybe_calibrate_or_quantize(
|
308
357
|
module: Module, value: torch.Tensor, base_name: str, args: "QuantizationArgs"
|
309
358
|
) -> torch.Tensor:
|
@@ -14,6 +14,7 @@
|
|
14
14
|
|
15
15
|
|
16
16
|
from compressed_tensors.quantization.quant_config import QuantizationStatus
|
17
|
+
from compressed_tensors.quantization.utils import is_kv_cache_quant_scheme
|
17
18
|
from torch.nn import Module
|
18
19
|
|
19
20
|
|
@@ -44,7 +45,11 @@ def freeze_module_quantization(module: Module):
|
|
44
45
|
delattr(module, "input_observer")
|
45
46
|
if scheme.weights and not scheme.weights.dynamic:
|
46
47
|
delattr(module, "weight_observer")
|
47
|
-
if
|
48
|
+
if (
|
49
|
+
scheme.output_activations
|
50
|
+
and not is_kv_cache_quant_scheme(scheme)
|
51
|
+
and not scheme.output_activations.dynamic
|
52
|
+
):
|
48
53
|
delattr(module, "output_observer")
|
49
54
|
|
50
55
|
module.quantization_status = QuantizationStatus.FROZEN
|
@@ -16,62 +16,15 @@
|
|
16
16
|
Miscelaneous helpers for the quantization lifecycle
|
17
17
|
"""
|
18
18
|
|
19
|
-
from typing import Optional
|
20
|
-
|
21
|
-
import torch
|
22
19
|
from torch.nn import Module
|
23
20
|
|
24
21
|
|
25
22
|
__all__ = [
|
26
|
-
"update_layer_weight_quant_params",
|
27
23
|
"enable_quantization",
|
28
24
|
"disable_quantization",
|
29
25
|
]
|
30
26
|
|
31
27
|
|
32
|
-
def update_layer_weight_quant_params(
|
33
|
-
layer: Module,
|
34
|
-
weight: Optional[torch.Tensor] = None,
|
35
|
-
g_idx: Optional[torch.Tensor] = None,
|
36
|
-
reset_obs: bool = False,
|
37
|
-
):
|
38
|
-
"""
|
39
|
-
Update quantization parameters on layer
|
40
|
-
|
41
|
-
:param layer: input layer
|
42
|
-
:param weight: weight to update quant params with, defaults to layer weight
|
43
|
-
:param g_idx: optional mapping from column index to group index
|
44
|
-
:param reset_obs: reset the observer before calculating quant params,
|
45
|
-
defaults to False
|
46
|
-
"""
|
47
|
-
attached_weight = getattr(layer, "weight", None)
|
48
|
-
|
49
|
-
if weight is None:
|
50
|
-
weight = attached_weight
|
51
|
-
scale = getattr(layer, "weight_scale", None)
|
52
|
-
zero_point = getattr(layer, "weight_zero_point", None)
|
53
|
-
if g_idx is None:
|
54
|
-
g_idx = getattr(layer, "weight_g_idx", None)
|
55
|
-
observer = getattr(layer, "weight_observer", None)
|
56
|
-
|
57
|
-
if weight is None or observer is None or scale is None or zero_point is None:
|
58
|
-
# scale, zp, or observer not calibratable or weight not available
|
59
|
-
return
|
60
|
-
|
61
|
-
if reset_obs:
|
62
|
-
observer.reset()
|
63
|
-
|
64
|
-
if attached_weight is not None:
|
65
|
-
weight = weight.to(attached_weight.dtype)
|
66
|
-
|
67
|
-
updated_scale, updated_zero_point = observer(weight)
|
68
|
-
|
69
|
-
# update scale and zero point
|
70
|
-
device = next(layer.parameters()).device
|
71
|
-
scale.data = updated_scale.to(device)
|
72
|
-
zero_point.data = updated_zero_point.to(device)
|
73
|
-
|
74
|
-
|
75
28
|
def enable_quantization(module: Module):
|
76
29
|
module.quantization_enabled = True
|
77
30
|
|
@@ -17,8 +17,10 @@ import logging
|
|
17
17
|
from typing import Optional
|
18
18
|
|
19
19
|
import torch
|
20
|
+
from compressed_tensors.quantization.cache import KVCacheScaleType
|
20
21
|
from compressed_tensors.quantization.lifecycle.forward import (
|
21
22
|
wrap_module_forward_quantized,
|
23
|
+
wrap_module_forward_quantized_attn,
|
22
24
|
)
|
23
25
|
from compressed_tensors.quantization.quant_args import (
|
24
26
|
ActivationOrdering,
|
@@ -27,6 +29,7 @@ from compressed_tensors.quantization.quant_args import (
|
|
27
29
|
)
|
28
30
|
from compressed_tensors.quantization.quant_config import QuantizationStatus
|
29
31
|
from compressed_tensors.quantization.quant_scheme import QuantizationScheme
|
32
|
+
from compressed_tensors.quantization.utils import is_kv_cache_quant_scheme
|
30
33
|
from compressed_tensors.utils import get_execution_device, is_module_offloaded
|
31
34
|
from torch.nn import Module, Parameter
|
32
35
|
|
@@ -62,72 +65,85 @@ def initialize_module_for_quantization(
|
|
62
65
|
# no scheme passed and layer not targeted for quantization - skip
|
63
66
|
return
|
64
67
|
|
65
|
-
if
|
66
|
-
|
67
|
-
|
68
|
-
)
|
69
|
-
|
70
|
-
|
71
|
-
|
68
|
+
if is_attention_module(module):
|
69
|
+
# wrap forward call of module to perform
|
70
|
+
# quantized actions based on calltime status
|
71
|
+
wrap_module_forward_quantized_attn(module, scheme)
|
72
|
+
_initialize_attn_scales(module)
|
73
|
+
|
74
|
+
else:
|
75
|
+
|
76
|
+
if scheme.input_activations is not None:
|
72
77
|
_initialize_scale_zero_point_observer(
|
73
78
|
module,
|
74
|
-
"
|
75
|
-
scheme.
|
76
|
-
weight_shape=weight_shape,
|
79
|
+
"input",
|
80
|
+
scheme.input_activations,
|
77
81
|
force_zero_point=force_zero_point,
|
78
82
|
)
|
79
|
-
|
80
|
-
|
81
|
-
|
82
|
-
|
83
|
-
|
84
|
-
|
85
|
-
|
86
|
-
|
87
|
-
|
88
|
-
|
89
|
-
|
90
|
-
|
91
|
-
)
|
92
|
-
|
93
|
-
module.quantization_scheme = scheme
|
94
|
-
module.quantization_status = QuantizationStatus.INITIALIZED
|
95
|
-
|
96
|
-
offloaded = False
|
97
|
-
if is_module_offloaded(module):
|
98
|
-
try:
|
99
|
-
from accelerate.hooks import add_hook_to_module, remove_hook_from_module
|
100
|
-
from accelerate.utils import PrefixedDataset
|
101
|
-
except ModuleNotFoundError:
|
102
|
-
raise ModuleNotFoundError(
|
103
|
-
"Offloaded model detected. To use CPU offloading with "
|
104
|
-
"compressed-tensors the `accelerate` package must be installed, "
|
105
|
-
"run `pip install compressed-tensors[accelerate]`"
|
106
|
-
)
|
107
|
-
|
108
|
-
offloaded = True
|
109
|
-
hook = module._hf_hook
|
110
|
-
prefix_dict = module._hf_hook.weights_map
|
111
|
-
new_prefix = {}
|
112
|
-
|
113
|
-
# recreate the prefix dict (since it is immutable)
|
114
|
-
# and add quantization parameters
|
115
|
-
for key, data in module.named_parameters():
|
116
|
-
if key not in prefix_dict:
|
117
|
-
new_prefix[f"{prefix_dict.prefix}{key}"] = data
|
83
|
+
if scheme.weights is not None:
|
84
|
+
if hasattr(module, "weight"):
|
85
|
+
weight_shape = None
|
86
|
+
if isinstance(module, torch.nn.Linear):
|
87
|
+
weight_shape = module.weight.shape
|
88
|
+
_initialize_scale_zero_point_observer(
|
89
|
+
module,
|
90
|
+
"weight",
|
91
|
+
scheme.weights,
|
92
|
+
weight_shape=weight_shape,
|
93
|
+
force_zero_point=force_zero_point,
|
94
|
+
)
|
118
95
|
else:
|
119
|
-
|
120
|
-
|
121
|
-
|
122
|
-
|
123
|
-
|
124
|
-
|
125
|
-
|
126
|
-
|
127
|
-
|
128
|
-
|
129
|
-
|
130
|
-
|
96
|
+
_LOGGER.warning(
|
97
|
+
f"module type {type(module)} targeted for weight quantization but "
|
98
|
+
"has no attribute weight, skipping weight quantization "
|
99
|
+
f"for {type(module)}"
|
100
|
+
)
|
101
|
+
|
102
|
+
if scheme.output_activations is not None:
|
103
|
+
if not is_kv_cache_quant_scheme(scheme):
|
104
|
+
_initialize_scale_zero_point_observer(
|
105
|
+
module, "output", scheme.output_activations
|
106
|
+
)
|
107
|
+
|
108
|
+
module.quantization_scheme = scheme
|
109
|
+
module.quantization_status = QuantizationStatus.INITIALIZED
|
110
|
+
|
111
|
+
offloaded = False
|
112
|
+
if is_module_offloaded(module):
|
113
|
+
try:
|
114
|
+
from accelerate.hooks import add_hook_to_module, remove_hook_from_module
|
115
|
+
from accelerate.utils import PrefixedDataset
|
116
|
+
except ModuleNotFoundError:
|
117
|
+
raise ModuleNotFoundError(
|
118
|
+
"Offloaded model detected. To use CPU offloading with "
|
119
|
+
"compressed-tensors the `accelerate` package must be installed, "
|
120
|
+
"run `pip install compressed-tensors[accelerate]`"
|
121
|
+
)
|
122
|
+
|
123
|
+
offloaded = True
|
124
|
+
hook = module._hf_hook
|
125
|
+
prefix_dict = module._hf_hook.weights_map
|
126
|
+
new_prefix = {}
|
127
|
+
|
128
|
+
# recreate the prefix dict (since it is immutable)
|
129
|
+
# and add quantization parameters
|
130
|
+
for key, data in module.named_parameters():
|
131
|
+
if key not in prefix_dict:
|
132
|
+
new_prefix[f"{prefix_dict.prefix}{key}"] = data
|
133
|
+
else:
|
134
|
+
new_prefix[f"{prefix_dict.prefix}{key}"] = prefix_dict[key]
|
135
|
+
new_prefix_dict = PrefixedDataset(new_prefix, prefix_dict.prefix)
|
136
|
+
remove_hook_from_module(module)
|
137
|
+
|
138
|
+
# wrap forward call of module to perform
|
139
|
+
# quantized actions based on calltime status
|
140
|
+
wrap_module_forward_quantized(module, scheme)
|
141
|
+
|
142
|
+
if offloaded:
|
143
|
+
# we need to re-add the hook for offloading now that we've wrapped forward
|
144
|
+
add_hook_to_module(module, hook)
|
145
|
+
if prefix_dict is not None:
|
146
|
+
module._hf_hook.weights_map = new_prefix_dict
|
131
147
|
|
132
148
|
|
133
149
|
def _initialize_scale_zero_point_observer(
|
@@ -156,9 +172,10 @@ def _initialize_scale_zero_point_observer(
|
|
156
172
|
# (output_channels, 1)
|
157
173
|
expected_shape = (weight_shape[0], 1)
|
158
174
|
elif quantization_args.strategy == QuantizationStrategy.GROUP:
|
175
|
+
num_groups = weight_shape[1] // quantization_args.group_size
|
159
176
|
expected_shape = (
|
160
177
|
weight_shape[0],
|
161
|
-
|
178
|
+
max(num_groups, 1)
|
162
179
|
)
|
163
180
|
|
164
181
|
scale_dtype = module.weight.dtype
|
@@ -189,3 +206,34 @@ def _initialize_scale_zero_point_observer(
|
|
189
206
|
requires_grad=False,
|
190
207
|
)
|
191
208
|
module.register_parameter(f"{base_name}_g_idx", init_g_idx)
|
209
|
+
|
210
|
+
|
211
|
+
def is_attention_module(module: Module):
    """
    Heuristically decide whether ``module`` is an attention block.

    True when the class name contains "attention" (case-insensitive) and the
    module exposes at least one of ``k_proj``, ``v_proj`` or ``qkv_proj``.
    """
    class_name = module.__class__.__name__.lower()
    if "attention" not in class_name:
        return False
    projection_names = ("k_proj", "v_proj", "qkv_proj")
    return any(hasattr(module, proj) for proj in projection_names)
|
217
|
+
|
218
|
+
|
219
|
+
def _initialize_attn_scales(module: Module) -> None:
    """
    Initialize k_scale and v_scale parameters for self_attn.

    Registers one single-element (per-tensor) scale parameter per member of
    KVCacheScaleType (KEY then VALUE), named after the member's value. Each
    uses the dtype and device of the module's first parameter. The values are
    left uninitialized (``torch.empty``) and are not trainable.
    """
    reference_param = next(module.parameters())
    dtype, device = reference_param.dtype, reference_param.device

    for scale_type in (KVCacheScaleType.KEY, KVCacheScaleType.VALUE):
        placeholder = Parameter(
            torch.empty(1, dtype=dtype, device=device),  # per tensor
            requires_grad=False,
        )
        module.register_parameter(scale_type.value, placeholder)
|
@@ -122,6 +122,12 @@ class QuantizationArgs(BaseModel, use_enum_values=True):
|
|
122
122
|
|
123
123
|
return Observer.load_from_registry(self.observer, quantization_args=self)
|
124
124
|
|
125
|
+
def get_kv_cache(self):
    """
    Return the singleton QuantizedKVParameterCache, constructed with these
    quantization args (the args only take effect on first construction).
    """
    # imported lazily inside the method rather than at module level
    from compressed_tensors.quantization.cache import QuantizedKVParameterCache

    kv_cache = QuantizedKVParameterCache(self)
    return kv_cache
|
130
|
+
|
125
131
|
@field_validator("type", mode="before")
|
126
132
|
def validate_type(cls, value) -> QuantizationType:
|
127
133
|
if isinstance(value, str):
|