compressed-tensors 0.12.3a20251013__py3-none-any.whl → 0.12.3a20251028__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- compressed_tensors/compressors/quantized_compressors/__init__.py +1 -1
- compressed_tensors/compressors/quantized_compressors/{nvfp4_quantized.py → fp4_quantized.py} +9 -0
- compressed_tensors/compressors/sparse_compressors/sparse_24_bitmask.py +4 -4
- compressed_tensors/compressors/sparse_compressors/sparse_bitmask.py +3 -3
- compressed_tensors/config/base.py +1 -0
- compressed_tensors/modeling/__init__.py +18 -0
- compressed_tensors/modeling/attention.py +147 -0
- compressed_tensors/modeling/kvcache.py +183 -0
- compressed_tensors/quantization/lifecycle/apply.py +48 -103
- compressed_tensors/quantization/lifecycle/initialize.py +83 -28
- compressed_tensors/quantization/quant_args.py +8 -7
- compressed_tensors/quantization/quant_config.py +59 -45
- compressed_tensors/quantization/quant_scheme.py +2 -0
- compressed_tensors/quantization/utils/__init__.py +1 -0
- compressed_tensors/quantization/utils/helpers.py +2 -33
- compressed_tensors/quantization/utils/mxfp4_utils.py +97 -0
- compressed_tensors/utils/helpers.py +63 -1
- compressed_tensors/utils/match.py +29 -0
- compressed_tensors/version.py +1 -1
- {compressed_tensors-0.12.3a20251013.dist-info → compressed_tensors-0.12.3a20251028.dist-info}/METADATA +5 -5
- {compressed_tensors-0.12.3a20251013.dist-info → compressed_tensors-0.12.3a20251028.dist-info}/RECORD +24 -20
- {compressed_tensors-0.12.3a20251013.dist-info → compressed_tensors-0.12.3a20251028.dist-info}/WHEEL +0 -0
- {compressed_tensors-0.12.3a20251013.dist-info → compressed_tensors-0.12.3a20251028.dist-info}/licenses/LICENSE +0 -0
- {compressed_tensors-0.12.3a20251013.dist-info → compressed_tensors-0.12.3a20251028.dist-info}/top_level.txt +0 -0
compressed_tensors/compressors/quantized_compressors/{nvfp4_quantized.py → fp4_quantized.py}
RENAMED
@@ -123,6 +123,15 @@ class NVFP4PackedCompressor(BaseQuantizationCompressor):
         return decompressed_weight
 
 
+@BaseCompressor.register(name=CompressionFormat.mxfp4_pack_quantized.value)
+class MXFP4PackedCompressor(NVFP4PackedCompressor):
+    """
+    Alias for mxfp4 quantized models
+    """
+
+    pass
+
+
 @torch.compile(fullgraph=True, dynamic=True)
 def pack_fp4_to_uint8(x: torch.Tensor) -> torch.Tensor:
     """
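A hedged sketch (not part of the diff): with the alias registered, the `mxfp4_pack_quantized` format name should resolve through the same compressor registry as the nvfp4 format. `load_from_registry` is the registry loader that compressed-tensors compressors inherit; the exact call and default construction are assumptions.

```python
from compressed_tensors.compressors import BaseCompressor
from compressed_tensors.config import CompressionFormat

# hypothetical lookup by format name; construction with no arguments is assumed
compressor = BaseCompressor.load_from_registry(
    CompressionFormat.mxfp4_pack_quantized.value
)
print(type(compressor).__name__)  # expected: MXFP4PackedCompressor
```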

compressed_tensors/compressors/sparse_compressors/sparse_24_bitmask.py
CHANGED
@@ -19,7 +19,7 @@ import torch
 from compressed_tensors.compressors.base import BaseCompressor
 from compressed_tensors.compressors.sparse_compressors.base import BaseSparseCompressor
 from compressed_tensors.config import CompressionFormat, SparsityStructure
-from compressed_tensors.quantization import
+from compressed_tensors.quantization import FP8_E4M3_DATA
 from compressed_tensors.utils import merge_names, pack_bitmasks, unpack_bitmasks
 from torch import Tensor
 
@@ -189,11 +189,11 @@ def sparse24_bitmask_compress(
 
     bytemasks = get_24_bytemasks(tensor=tensor)
 
-    if tensor.dtype ==
+    if tensor.dtype == FP8_E4M3_DATA.dtype:
         # acces raw bytes of the tensor
         tensor_view = tensor.view(torch.int8)
         values = tensor_view[bytemasks]
-        values = values.view(
+        values = values.view(FP8_E4M3_DATA.dtype)
     else:
         values = tensor[bytemasks]
 
@@ -241,7 +241,7 @@ def get_24_bytemasks(tensor):
         multiple of 4.
     """
     original_dtype = tensor.dtype
-    if tensor.dtype ==
+    if tensor.dtype == FP8_E4M3_DATA.dtype:
         tensor = tensor.view(torch.int8)
     original_shape = tensor.shape
     num_elements = tensor.numel()

compressed_tensors/compressors/sparse_compressors/sparse_bitmask.py
CHANGED
@@ -18,7 +18,7 @@ import torch
 from compressed_tensors.compressors.base import BaseCompressor
 from compressed_tensors.compressors.sparse_compressors.base import BaseSparseCompressor
 from compressed_tensors.config import CompressionFormat
-from compressed_tensors.quantization import
+from compressed_tensors.quantization import FP8_E4M3_DATA
 from compressed_tensors.utils import merge_names, pack_bitmasks, unpack_bitmasks
 from torch import Tensor
 
@@ -138,11 +138,11 @@ def bitmask_compress(tensor: Tensor) -> Tuple[Tensor, Tensor, Tensor]:
     bytemasks = tensor != 0
     row_counts = bytemasks.sum(dim=-1)
     row_offsets = torch.cumsum(row_counts, 0) - row_counts
-    if tensor.dtype ==
+    if tensor.dtype == FP8_E4M3_DATA.dtype:
         # acces raw bytes of the tensor
         tensor_view = tensor.view(torch.int8)
         values = tensor_view[bytemasks]
-        values = values.view(
+        values = values.view(FP8_E4M3_DATA.dtype)
     else:
         values = tensor[bytemasks]
     bitmasks_packed = pack_bitmasks(bytemasks)
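Both bitmask compressors gained the same `FP8_E4M3_DATA.dtype` branch. A minimal sketch of the underlying trick (not part of the diff): boolean-mask indexing may not be supported for float8 tensors in a given PyTorch build, so the raw bytes are gathered through an int8 view and then reinterpreted back to float8.

```python
import torch

t = torch.randn(4, 4).to(torch.float8_e4m3fn)  # same dtype as FP8_E4M3_DATA.dtype
mask = torch.rand(4, 4) > 0.5                  # stand-in for the computed bytemask
# gather through the int8 view, then reinterpret the packed bytes as float8
values = t.view(torch.int8)[mask].view(torch.float8_e4m3fn)
assert values.dtype == torch.float8_e4m3fn
```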

compressed_tensors/modeling/__init__.py
ADDED
@@ -0,0 +1,18 @@
+# Copyright (c) 2021 - present / Neuralmagic, Inc. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#    http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+# flake8: noqa
+# isort: off
+from .kvcache import *
+from .attention import *

compressed_tensors/modeling/attention.py
ADDED
@@ -0,0 +1,147 @@
+# Copyright (c) 2021 - present / Neuralmagic, Inc. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#    http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import inspect
+from typing import Callable, Optional
+
+from compressed_tensors.modeling.kvcache import initialize_hooked_kv_cache
+from compressed_tensors.quantization.lifecycle.forward import forward_quantize
+from compressed_tensors.utils import getattr_chain
+from compressed_tensors.utils.internal import InternalModule
+from torch import Tensor
+from torch.nn import Module
+from torch.utils.hooks import RemovableHandle
+from transformers import PretrainedConfig, PreTrainedModel
+from transformers.masking_utils import ALL_MASK_ATTENTION_FUNCTIONS
+from transformers.modeling_utils import ALL_ATTENTION_FUNCTIONS
+
+
+__all__ = [
+    "QuantizedAttentionImpl",
+    "initialize_hooked_attention",
+    "register_query_hook",
+    "IMPL_ATTR",
+]
+
+
+IMPL_ATTR = "impl"
+HOOKED_ATTENTION_NAME = "ct_hooked_attention"
+
+
+class QuantizedAttentionImpl(InternalModule):
+    """
+    QuantizedAttentionImpl module which wraps the functionality of the original
+    attention implementation. Unlike the original attention function, this
+    implementation is a `torch.nn.Module` which can be hooked to trigger
+    transforms and calibration hooks.
+
+    This module works by being registered as a submodule to attention modules via
+    `initialize_hooked_attention`, registering a new attention implementation function
+    which calls this module, then setting the model attention implementation to the new
+    function. After triggering hooks and quantization, this module calls the original
+    attention implementation function.
+    """
+
+    _original_impl = "eager"
+
+    def __init__(self, config: PretrainedConfig):
+        super().__init__()
+        self.config = config
+
+    def forward(
+        self,
+        module: Module,
+        query: Tensor,
+        key: Tensor,
+        value: Tensor,
+        *args,
+        **kwargs,
+    ):
+        # quantization
+        quant_args_attr = "quantization_scheme.input_activations"
+        quant_args = getattr_chain(module, quant_args_attr, None)
+        quant_enabled = getattr(module, "quantization_enabled", True)
+        if quant_args is not None and quant_enabled:
+            query = forward_quantize(module, query, "q", quant_args)
+
+        # original attention
+        return ALL_ATTENTION_FUNCTIONS[QuantizedAttentionImpl._original_impl](
+            module,
+            query,
+            key,
+            value,
+            *args,
+            **kwargs,
+        )
+
+
+# ----- initialize ----- #
+
+
+def _hooked_attention(module: Module, *args, **kwargs):
+    assert hasattr(module, IMPL_ATTR), (
+        f"Using {HOOKED_ATTENTION_NAME} attention implementation, "
+        f"but attention module does not have {IMPL_ATTR} submodule."
+    )
+
+    return getattr(module, IMPL_ATTR)(module, *args, **kwargs)
+
+
+def initialize_hooked_attention(model: PreTrainedModel, module: Module):
+    """
+    Initialize `QuantizedAttentionImpl` and `QuantizedKVCache` instances
+    attached to attention. Assumes that only one model is hooked at a time.
+
+    :param model: parent model of attention module
+    :param module: attention module to initialize with
+    """
+    if not hasattr(module, IMPL_ATTR):
+        module.register_module(IMPL_ATTR, QuantizedAttentionImpl(model.config))
+
+    if model.config._attn_implementation != HOOKED_ATTENTION_NAME:
+        QuantizedAttentionImpl._original_impl = model.config._attn_implementation
+        original_mask = ALL_MASK_ATTENTION_FUNCTIONS[model.config._attn_implementation]
+
+        ALL_ATTENTION_FUNCTIONS.register(HOOKED_ATTENTION_NAME, _hooked_attention)
+        ALL_MASK_ATTENTION_FUNCTIONS.register(HOOKED_ATTENTION_NAME, original_mask)
+        model.set_attn_implementation(HOOKED_ATTENTION_NAME)
+        assert model.config._attn_implementation == HOOKED_ATTENTION_NAME
+
+    initialize_hooked_kv_cache(model, module)
+
+
+# ----- hooks ----- #
+
+
+def register_query_hook(
+    module: Module, hook: Callable[[Module, Tensor], Optional[Tensor]]
+) -> RemovableHandle:
+    """
+    Register a hook which takes post-rope query states as an argument and
+    returns the modified query states or `None`
+
+    :param module: attention module to add hook to
+    :param hook: query hook function
+    """
+    impl: QuantizedAttentionImpl = getattr(module, IMPL_ATTR)
+
+    def _hook(impl: QuantizedAttentionImpl, args, kwargs):
+        bound = inspect.signature(impl.forward).bind(*args, **kwargs)
+        value = hook(module, bound.arguments["query"])
+        if value is not None:
+            bound.arguments["query"] = value
+
+        return bound.args, bound.kwargs
+
+    return impl.register_forward_pre_hook(_hook, with_kwargs=True)
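A hedged usage sketch (not part of the diff) of the hooks defined above, assuming a transformers causal-LM whose attention submodules are named `*.self_attn`; the module naming and model path are assumptions, while `initialize_hooked_attention` and `register_query_hook` come from this file.

```python
from transformers import AutoModelForCausalLM

from compressed_tensors.modeling import (
    initialize_hooked_attention,
    register_query_hook,
)

model = AutoModelForCausalLM.from_pretrained("path/to/model")  # placeholder path

def log_query(module, query):
    # observe post-rope query states; returning None leaves them unchanged
    print(type(module).__name__, tuple(query.shape))
    return None

for name, submodule in model.named_modules():
    if name.endswith("self_attn"):  # assumed attention module naming
        initialize_hooked_attention(model, submodule)
        register_query_hook(submodule, log_query)
```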

compressed_tensors/modeling/kvcache.py
ADDED
@@ -0,0 +1,183 @@
+# Copyright (c) 2021 - present / Neuralmagic, Inc. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#    http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import inspect
+from typing import Any, Callable, Dict, List, Optional, Tuple
+from weakref import ReferenceType, ref
+
+from compressed_tensors.quantization.lifecycle.forward import forward_quantize
+from compressed_tensors.utils import getattr_chain
+from compressed_tensors.utils.internal import InternalModule
+from torch import Tensor
+from torch.nn import Module
+from torch.utils.hooks import RemovableHandle
+from transformers import Cache, PretrainedConfig, PreTrainedModel
+
+
+__all__ = [
+    "QuantizedKVCache",
+    "initialize_hooked_kv_cache",
+    "register_key_hook",
+    "register_value_hook",
+    "KV_CACHE_ATTR",
+]
+
+
+KV_CACHE_ATTR = "kv_cache"
+
+
+class QuantizedKVCache(InternalModule):
+    """
+    QuantizedKVCache module which wraps the functionality of any existing kvcache args.
+    Unlike transform Cache instances, this cache is a `torch.nn.Module` which can be
+    hooked to trigger transforms and calibration hooks.
+
+    This module works by being registered as a submodule to attention modules via
+    `initialize_hooked_kv_cache`, then adding a hook which replaces `past_key_values`
+    kwargs with this module. This module adopts the functionality of the replaced cache,
+    preserving caching functionality such as sliding window attention, ect.
+
+    :param attn_module: parent attention module
+    """
+
+    def __init__(self, config: PretrainedConfig, attn_module: Module):
+        super().__init__()
+        self.config = config
+        self.attn_module = ref(attn_module)  # avoid circular reference
+        self.past_key_values: Optional[ReferenceType[Cache]] = None
+
+    def update(self, *args, **kwargs) -> Tuple[Tensor, Tensor]:
+        return self(*args, **kwargs)
+
+    def forward(
+        self,
+        key_states: Tensor,
+        value_states: Tensor,
+        *args,
+        **kwargs,
+    ) -> Tuple[Tensor, Tensor]:
+        # quantization
+        module = self.attn_module()
+        quant_args_attr = "quantization_scheme.input_activations"
+        quant_args = getattr_chain(module, quant_args_attr, None)
+        quant_enabled = getattr(module, "quantization_enabled", True)
+        if quant_args is not None and quant_enabled:
+            key_states = forward_quantize(module, key_states, "k", quant_args)
+            value_states = forward_quantize(module, value_states, "v", quant_args)
+
+        # original cache
+        if self.past_key_values is not None:
+            ret = self.past_key_values().update(
+                key_states, value_states, *args, **kwargs
+            )
+        else:
+            ret = (key_states, value_states)
+        self.past_key_values = None
+
+        return ret
+
+    def add_past_key_values(self, past_key_values: Optional[Cache]):
+        if past_key_values is not None:
+            self.past_key_values = ref(past_key_values)
+        else:
+            self.past_key_values = None
+
+
+# ----- initialize ----- #
+
+
+def _kv_cache_attention_hook(
+    module: Module, args: List[Any], kwargs: Dict[str, Any]
+) -> Tuple[List[Any], Dict[str, Any]]:
+    """
+    Hook which should be called before each quantized attention forward pass.
+    This hook dynamically replaces the `past_key_values` kwarg to the attention
+    forward function.
+
+    The original kvcache object is assigned to QuantizedKVCache().past_key_values
+    as a weakref to maintain original cache functionality and compute savings
+    """
+    _past_kv_name = (
+        "past_key_values"  # transformers#39956
+        if "past_key_values" in inspect.signature(module.forward).parameters
+        else "past_key_value"
+    )
+    past_key_values: Optional[Cache] = kwargs.get(_past_kv_name, None)
+
+    cache: QuantizedKVCache = getattr(module, KV_CACHE_ATTR)
+    cache.add_past_key_values(past_key_values)
+    kwargs[_past_kv_name] = cache
+
+    return args, kwargs
+
+
+def initialize_hooked_kv_cache(model: PreTrainedModel, module: Module):
+    """
+    Initialize a `QuantizedKVCache` instance attached to attention
+
+    :param model: parent model of attention module
+    :param module: attention module to initialize with
+    """
+    if not hasattr(module, KV_CACHE_ATTR):
+        module.register_module(KV_CACHE_ATTR, QuantizedKVCache(model.config, module))
+        module.register_forward_pre_hook(_kv_cache_attention_hook, with_kwargs=True)
+
+
+# ----- hooks ----- #
+
+
+def register_key_hook(
+    module: Module, hook: Callable[[Module, Tensor], Optional[Tensor]]
+) -> RemovableHandle:
+    """
+    Register a hook which takes post-rope key states as an argument and
+    returns the modified key states or `None`
+
+    :param module: attention module to add hook to
+    :param hook: key hook function
+    """
+    kv_cache: QuantizedKVCache = getattr(module, KV_CACHE_ATTR)
+
+    def _hook(cache: QuantizedKVCache, args, kwargs):
+        bound = inspect.signature(cache.forward).bind(*args, **kwargs)
+        value = hook(module, bound.arguments["key_states"])
+        if value is not None:
+            bound.arguments["key_states"] = value
+
+        return bound.args, bound.kwargs
+
+    return kv_cache.register_forward_pre_hook(_hook, with_kwargs=True)
+
+
+def register_value_hook(
+    module: Module, hook: Callable[[Module, Tensor], Optional[Tensor]]
+) -> RemovableHandle:
+    """
+    Register a hook which takes value states as an argument and
+    returns the modified value states or `None`
+
+    :param module: attention module to add hook to
+    :param hook: value hook function
+    """
+    kv_cache: QuantizedKVCache = getattr(module, KV_CACHE_ATTR)
+
+    def _hook(cache: QuantizedKVCache, args, kwargs):
+        bound = inspect.signature(cache.forward).bind(*args, **kwargs)
+        value = hook(module, bound.arguments["value_states"])
+        if value is not None:
+            bound.arguments["value_states"] = value
+
+        return bound.args, bound.kwargs
+
+    return kv_cache.register_forward_pre_hook(_hook, with_kwargs=True)
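A hedged continuation of the attention sketch above (not part of the diff): attaching key and value hooks to the same attention `submodule`. Note that `initialize_hooked_attention` already calls `initialize_hooked_kv_cache`, so the explicit call below is shown only for clarity.

```python
from compressed_tensors.modeling import (
    initialize_hooked_kv_cache,
    register_key_hook,
    register_value_hook,
)

# `model` and `submodule` continue from the attention sketch above
initialize_hooked_kv_cache(model, submodule)

def clip_keys(module, key_states):
    # returning a tensor replaces the key states seen by the cache
    return key_states.clamp(-10, 10)

def observe_values(module, value_states):
    # returning None keeps the value states unchanged
    print(tuple(value_states.shape))

key_handle = register_key_hook(submodule, clip_keys)
value_handle = register_value_hook(submodule, observe_values)
```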

compressed_tensors/quantization/lifecycle/apply.py
CHANGED
@@ -12,7 +12,6 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
-import logging
 from collections import OrderedDict
 from copy import deepcopy
 from typing import Dict, List, Optional
@@ -21,8 +20,13 @@ from typing import Union
 
 import torch
 from compressed_tensors.config import CompressionFormat
+from compressed_tensors.modeling import (
+    initialize_hooked_attention,
+    initialize_hooked_kv_cache,
+)
 from compressed_tensors.quantization.lifecycle.initialize import (
     initialize_module_for_quantization,
+    is_attention_module,
 )
 from compressed_tensors.quantization.quant_args import QuantizationArgs
 from compressed_tensors.quantization.quant_config import (
@@ -30,14 +34,15 @@ from compressed_tensors.quantization.quant_config import (
     QuantizationStatus,
 )
 from compressed_tensors.quantization.quant_scheme import QuantizationScheme
-from compressed_tensors.quantization.utils import (
-    KV_CACHE_TARGETS,
-    is_kv_cache_quant_scheme,
-)
 from compressed_tensors.utils.helpers import replace_module
-from compressed_tensors.utils.match import
+from compressed_tensors.utils.match import (
+    is_narrow_match,
+    match_named_modules,
+    match_targets,
+)
 from compressed_tensors.utils.offload import update_parameter_data
 from compressed_tensors.utils.safetensors_load import get_safetensors_folder
+from loguru import logger
 from safetensors import safe_open
 from torch.nn import Module
 
@@ -53,9 +58,6 @@ from compressed_tensors.utils.safetensors_load import (
 )
 
 
-_LOGGER = logging.getLogger(__name__)
-
-
 def load_pretrained_quantization_parameters(
     model: Module,
     model_name_or_path: Optional[str] = None,
@@ -125,8 +127,14 @@ def apply_quantization_config(
     if config is None:  # see PR #180
         return dict()
 
-    #
-
+    # force zero points during initialization
+    force_zero_point = config.quantization_status != QuantizationStatus.COMPRESSED
+
+    # apply and initialize kv cache quantization
+    if config.kv_cache_scheme is not None:
+        _apply_kv_cache_scheme(
+            model, config.kv_cache_scheme, config.quantization_status
+        )
 
     # build mapping of targets to schemes for easier matching
     # use ordered dict to preserve target ordering in config
@@ -162,49 +170,40 @@
             replace_module(model, name, compressed_linear)
 
         else:
+            if is_attention_module(submodule) and is_narrow_match(
+                model, scheme.targets, name
+            ):
+                initialize_hooked_attention(model, submodule)
+
             initialize_module_for_quantization(
                 submodule,
-                force_zero_point=
-                != QuantizationStatus.COMPRESSED,
+                force_zero_point=force_zero_point,
             )
 
         submodule.quantization_status = config.quantization_status
 
 
-def
-
-
-
-
-
-
-
-
-
-
-
-
-
-    config: QuantizationConfig, targets: Union[List[str], str] = KV_CACHE_TARGETS
-) -> QuantizationConfig:
-    """
-    Reformulate the `config.kv_cache` as a `config_group`
-    and add it to the set of existing `config.groups`
-
-    :param config: the QuantizationConfig
-    :return: the QuantizationConfig with additional "kv_cache" group
-    """
-    if targets == KV_CACHE_TARGETS:
-        _LOGGER.info(f"KV cache targets set to default value of: {KV_CACHE_TARGETS}")
-
-    kv_cache_dict = config.kv_cache_scheme.model_dump()
-    kv_cache_scheme = QuantizationScheme(
-        output_activations=QuantizationArgs(**kv_cache_dict),
-        targets=targets,
+def _apply_kv_cache_scheme(
+    model: torch.nn.Module,
+    kv_cache_scheme: QuantizationArgs,
+    status: QuantizationStatus,
+):
+    if not kv_cache_scheme.symmetric:
+        raise logger.warning("vLLM does not support asymmetric kv cache quantization")
+
+    # applies and initializes kv cache quantization
+    # this step cannot come after attention apply/initialize
+    # otherwise it will override the attention qparams
+    scheme = QuantizationScheme(
+        targets=[".*self_attn$"],  # is never read in practice
+        input_activations=kv_cache_scheme,
     )
-
-
-
+    for submodule in model.modules():
+        if is_attention_module(submodule):
+            submodule.quantization_scheme = scheme
+            initialize_hooked_kv_cache(model, submodule)
+            initialize_module_for_quantization(submodule, force_zero_point=False)
+            submodule.quantization_status = status
 
 
 def _load_quant_args_from_mapping(
@@ -256,60 +255,6 @@ def _scheme_from_targets(
     targets: List[str],
     name: str,
 ) -> QuantizationScheme:
-
-
-
-        return target_to_scheme[targets[0]]
-
-    # otherwise, we need to merge QuantizationSchemes corresponding
-    # to multiple targets. This is most likely because `name` module
-    # is being target both as an ordinary quantization target, as well
-    # as kv cache quantization target
-    schemes_to_merge = [target_to_scheme[target] for target in targets]
-    return _merge_schemes(schemes_to_merge, name)
-
-
-def _merge_schemes(
-    schemes_to_merge: List[QuantizationScheme], name: str
-) -> QuantizationScheme:
-    kv_cache_quantization_scheme = [
-        scheme for scheme in schemes_to_merge if is_kv_cache_quant_scheme(scheme)
-    ]
-    if not kv_cache_quantization_scheme:
-        # if the schemes_to_merge do not contain any
-        # kv cache QuantizationScheme
-        # return the first scheme (the prioritized one,
-        # since the order of schemes_to_merge matters)
-        return schemes_to_merge[0]
-    else:
-        # fetch the kv cache QuantizationScheme and the highest
-        # priority non-kv cache QuantizationScheme and merge them
-        kv_cache_quantization_scheme = kv_cache_quantization_scheme[0]
-        quantization_scheme = [
-            scheme
-            for scheme in schemes_to_merge
-            if not is_kv_cache_quant_scheme(scheme)
-        ][0]
-        schemes_to_merge = [kv_cache_quantization_scheme, quantization_scheme]
-    merged_scheme = {}
-    for scheme in schemes_to_merge:
-        scheme_dict = {
-            k: v for k, v in scheme.model_dump().items() if v is not None
-        }
-        # when merging multiple schemes, the final target will be
-        # the `name` argument - hence erase the original targets
-        del scheme_dict["targets"]
-        # make sure that schemes do not "clash" with each other
-        overlapping_keys = set(merged_scheme.keys()) & set(scheme_dict.keys())
-        if overlapping_keys:
-            raise ValueError(
-                f"The module: {name} is being modified by two clashing "
-                f"quantization schemes, that jointly try to override "
-                f"properties: {overlapping_keys}. Fix the quantization config "
-                "so that it is not ambiguous."
-            )
-        merged_scheme.update(scheme_dict)
-
-    merged_scheme.update(targets=[name])
-
-    return QuantizationScheme(**merged_scheme)
+    # return the first scheme (the prioritized one,
+    # since the order of target_to_scheme matters)
+    return target_to_scheme[targets[0]]
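A hedged sketch (not part of the diff) of a config that exercises the new `_apply_kv_cache_scheme` path: any config with a non-null `kv_cache_scheme` now initializes a `QuantizedKVCache` on every attention module before the per-layer schemes are applied. Field values below are illustrative assumptions.

```python
from compressed_tensors.quantization import (
    QuantizationArgs,
    QuantizationConfig,
    QuantizationScheme,
    apply_quantization_config,
)

config = QuantizationConfig(
    config_groups={
        "group_0": QuantizationScheme(
            targets=["Linear"],
            weights=QuantizationArgs(num_bits=8, type="float", symmetric=True),
        )
    },
    # symmetric kv cache scheme; asymmetric schemes hit the vLLM warning above
    kv_cache_scheme=QuantizationArgs(num_bits=8, type="float", symmetric=True),
)
apply_quantization_config(model, config)  # `model` is an already-loaded torch model
```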