compressed-tensors 0.12.3a20251013__py3-none-any.whl → 0.12.3a20251028__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (24)
  1. compressed_tensors/compressors/quantized_compressors/__init__.py +1 -1
  2. compressed_tensors/compressors/quantized_compressors/{nvfp4_quantized.py → fp4_quantized.py} +9 -0
  3. compressed_tensors/compressors/sparse_compressors/sparse_24_bitmask.py +4 -4
  4. compressed_tensors/compressors/sparse_compressors/sparse_bitmask.py +3 -3
  5. compressed_tensors/config/base.py +1 -0
  6. compressed_tensors/modeling/__init__.py +18 -0
  7. compressed_tensors/modeling/attention.py +147 -0
  8. compressed_tensors/modeling/kvcache.py +183 -0
  9. compressed_tensors/quantization/lifecycle/apply.py +48 -103
  10. compressed_tensors/quantization/lifecycle/initialize.py +83 -28
  11. compressed_tensors/quantization/quant_args.py +8 -7
  12. compressed_tensors/quantization/quant_config.py +59 -45
  13. compressed_tensors/quantization/quant_scheme.py +2 -0
  14. compressed_tensors/quantization/utils/__init__.py +1 -0
  15. compressed_tensors/quantization/utils/helpers.py +2 -33
  16. compressed_tensors/quantization/utils/mxfp4_utils.py +97 -0
  17. compressed_tensors/utils/helpers.py +63 -1
  18. compressed_tensors/utils/match.py +29 -0
  19. compressed_tensors/version.py +1 -1
  20. {compressed_tensors-0.12.3a20251013.dist-info → compressed_tensors-0.12.3a20251028.dist-info}/METADATA +5 -5
  21. {compressed_tensors-0.12.3a20251013.dist-info → compressed_tensors-0.12.3a20251028.dist-info}/RECORD +24 -20
  22. {compressed_tensors-0.12.3a20251013.dist-info → compressed_tensors-0.12.3a20251028.dist-info}/WHEEL +0 -0
  23. {compressed_tensors-0.12.3a20251013.dist-info → compressed_tensors-0.12.3a20251028.dist-info}/licenses/LICENSE +0 -0
  24. {compressed_tensors-0.12.3a20251013.dist-info → compressed_tensors-0.12.3a20251028.dist-info}/top_level.txt +0 -0
compressed_tensors/compressors/quantized_compressors/__init__.py
@@ -14,6 +14,6 @@
 # flake8: noqa
 
 from .base import *
+from .fp4_quantized import *
 from .naive_quantized import *
-from .nvfp4_quantized import *
 from .pack_quantized import *
compressed_tensors/compressors/quantized_compressors/{nvfp4_quantized.py → fp4_quantized.py}
@@ -123,6 +123,15 @@ class NVFP4PackedCompressor(BaseQuantizationCompressor):
         return decompressed_weight
 
 
+@BaseCompressor.register(name=CompressionFormat.mxfp4_pack_quantized.value)
+class MXFP4PackedCompressor(NVFP4PackedCompressor):
+    """
+    Alias for mxfp4 quantized models
+    """
+
+    pass
+
+
 @torch.compile(fullgraph=True, dynamic=True)
 def pack_fp4_to_uint8(x: torch.Tensor) -> torch.Tensor:
     """
compressed_tensors/compressors/sparse_compressors/sparse_24_bitmask.py
@@ -19,7 +19,7 @@ import torch
 from compressed_tensors.compressors.base import BaseCompressor
 from compressed_tensors.compressors.sparse_compressors.base import BaseSparseCompressor
 from compressed_tensors.config import CompressionFormat, SparsityStructure
-from compressed_tensors.quantization import FP8_DTYPE
+from compressed_tensors.quantization import FP8_E4M3_DATA
 from compressed_tensors.utils import merge_names, pack_bitmasks, unpack_bitmasks
 from torch import Tensor
 
@@ -189,11 +189,11 @@ def sparse24_bitmask_compress(
 
     bytemasks = get_24_bytemasks(tensor=tensor)
 
-    if tensor.dtype == FP8_DTYPE:
+    if tensor.dtype == FP8_E4M3_DATA.dtype:
         # acces raw bytes of the tensor
         tensor_view = tensor.view(torch.int8)
         values = tensor_view[bytemasks]
-        values = values.view(FP8_DTYPE)
+        values = values.view(FP8_E4M3_DATA.dtype)
     else:
         values = tensor[bytemasks]
 
@@ -241,7 +241,7 @@ def get_24_bytemasks(tensor):
     multiple of 4.
     """
     original_dtype = tensor.dtype
-    if tensor.dtype == FP8_DTYPE:
+    if tensor.dtype == FP8_E4M3_DATA.dtype:
         tensor = tensor.view(torch.int8)
     original_shape = tensor.shape
     num_elements = tensor.numel()
compressed_tensors/compressors/sparse_compressors/sparse_bitmask.py
@@ -18,7 +18,7 @@ import torch
 from compressed_tensors.compressors.base import BaseCompressor
 from compressed_tensors.compressors.sparse_compressors.base import BaseSparseCompressor
 from compressed_tensors.config import CompressionFormat
-from compressed_tensors.quantization import FP8_DTYPE
+from compressed_tensors.quantization import FP8_E4M3_DATA
 from compressed_tensors.utils import merge_names, pack_bitmasks, unpack_bitmasks
 from torch import Tensor
 
@@ -138,11 +138,11 @@ def bitmask_compress(tensor: Tensor) -> Tuple[Tensor, Tensor, Tensor]:
     bytemasks = tensor != 0
     row_counts = bytemasks.sum(dim=-1)
     row_offsets = torch.cumsum(row_counts, 0) - row_counts
-    if tensor.dtype == FP8_DTYPE:
+    if tensor.dtype == FP8_E4M3_DATA.dtype:
         # acces raw bytes of the tensor
         tensor_view = tensor.view(torch.int8)
         values = tensor_view[bytemasks]
-        values = values.view(FP8_DTYPE)
+        values = values.view(FP8_E4M3_DATA.dtype)
     else:
         values = tensor[bytemasks]
     bitmasks_packed = pack_bitmasks(bytemasks)
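
These hunks only swap the removed `FP8_DTYPE` constant for `FP8_E4M3_DATA.dtype`; the compression trick itself is unchanged: float8 values are gathered through an int8 view of the same storage and then reinterpreted as float8, since boolean-mask indexing has historically not been supported for float8 tensors. A standalone sketch of that reinterpretation in plain PyTorch (not using the package's helpers):

    import torch

    # Gather fp8 values through the raw byte view, then reinterpret them as fp8.
    tensor = torch.randn(4, 8).to(torch.float8_e4m3fn)
    bytemasks = torch.rand(4, 8) > 0.5

    values = tensor.view(torch.int8)[bytemasks]  # index the int8 byte view
    values = values.view(torch.float8_e4m3fn)    # reinterpret the bytes as fp8

    assert values.dtype == torch.float8_e4m3fn
    assert values.numel() == int(bytemasks.sum())
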
compressed_tensors/config/base.py
@@ -34,6 +34,7 @@ class CompressionFormat(Enum):
     marlin_24 = "marlin-24"
     mixed_precision = "mixed-precision"
     nvfp4_pack_quantized = "nvfp4-pack-quantized"
+    mxfp4_pack_quantized = "mxfp4-pack-quantized"
 
 
 @unique
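
With the new enum member, configs can name the format by its string value. A quick round-trip check:

    from compressed_tensors.config import CompressionFormat

    # The new format string resolves to the new enum member and back.
    fmt = CompressionFormat("mxfp4-pack-quantized")
    assert fmt is CompressionFormat.mxfp4_pack_quantized
    assert fmt.value == "mxfp4-pack-quantized"
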
compressed_tensors/modeling/__init__.py
@@ -0,0 +1,18 @@
+# Copyright (c) 2021 - present / Neuralmagic, Inc. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#    http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+# flake8: noqa
+# isort: off
+from .kvcache import *
+from .attention import *
compressed_tensors/modeling/attention.py
@@ -0,0 +1,147 @@
+# Copyright (c) 2021 - present / Neuralmagic, Inc. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#    http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import inspect
+from typing import Callable, Optional
+
+from compressed_tensors.modeling.kvcache import initialize_hooked_kv_cache
+from compressed_tensors.quantization.lifecycle.forward import forward_quantize
+from compressed_tensors.utils import getattr_chain
+from compressed_tensors.utils.internal import InternalModule
+from torch import Tensor
+from torch.nn import Module
+from torch.utils.hooks import RemovableHandle
+from transformers import PretrainedConfig, PreTrainedModel
+from transformers.masking_utils import ALL_MASK_ATTENTION_FUNCTIONS
+from transformers.modeling_utils import ALL_ATTENTION_FUNCTIONS
+
+
+__all__ = [
+    "QuantizedAttentionImpl",
+    "initialize_hooked_attention",
+    "register_query_hook",
+    "IMPL_ATTR",
+]
+
+
+IMPL_ATTR = "impl"
+HOOKED_ATTENTION_NAME = "ct_hooked_attention"
+
+
+class QuantizedAttentionImpl(InternalModule):
+    """
+    QuantizedAttentionImpl module which wraps the functionality of the original
+    attention implementation. Unlike the original attention function, this
+    implementation is a `torch.nn.Module` which can be hooked to trigger
+    transforms and calibration hooks.
+
+    This module works by being registered as a submodule to attention modules via
+    `initialize_hooked_attention`, registering a new attention implementation function
+    which calls this module, then setting the model attention implementation to the new
+    function. After triggering hooks and quantization, this module calls the original
+    attention implementation function.
+    """
+
+    _original_impl = "eager"
+
+    def __init__(self, config: PretrainedConfig):
+        super().__init__()
+        self.config = config
+
+    def forward(
+        self,
+        module: Module,
+        query: Tensor,
+        key: Tensor,
+        value: Tensor,
+        *args,
+        **kwargs,
+    ):
+        # quantization
+        quant_args_attr = "quantization_scheme.input_activations"
+        quant_args = getattr_chain(module, quant_args_attr, None)
+        quant_enabled = getattr(module, "quantization_enabled", True)
+        if quant_args is not None and quant_enabled:
+            query = forward_quantize(module, query, "q", quant_args)
+
+        # original attention
+        return ALL_ATTENTION_FUNCTIONS[QuantizedAttentionImpl._original_impl](
+            module,
+            query,
+            key,
+            value,
+            *args,
+            **kwargs,
+        )
+
+
+# ----- initialize ----- #
+
+
+def _hooked_attention(module: Module, *args, **kwargs):
+    assert hasattr(module, IMPL_ATTR), (
+        f"Using {HOOKED_ATTENTION_NAME} attention implementation, "
+        f"but attention module does not have {IMPL_ATTR} submodule."
+    )
+
+    return getattr(module, IMPL_ATTR)(module, *args, **kwargs)
+
+
+def initialize_hooked_attention(model: PreTrainedModel, module: Module):
+    """
+    Initialize `QuantizedAttentionImpl` and `QuantizedKVCache` instances
+    attached to attention. Assumes that only one model is hooked at a time.
+
+    :param model: parent model of attention module
+    :param module: attention module to initialize with
+    """
+    if not hasattr(module, IMPL_ATTR):
+        module.register_module(IMPL_ATTR, QuantizedAttentionImpl(model.config))
+
+    if model.config._attn_implementation != HOOKED_ATTENTION_NAME:
+        QuantizedAttentionImpl._original_impl = model.config._attn_implementation
+        original_mask = ALL_MASK_ATTENTION_FUNCTIONS[model.config._attn_implementation]
+
+        ALL_ATTENTION_FUNCTIONS.register(HOOKED_ATTENTION_NAME, _hooked_attention)
+        ALL_MASK_ATTENTION_FUNCTIONS.register(HOOKED_ATTENTION_NAME, original_mask)
+        model.set_attn_implementation(HOOKED_ATTENTION_NAME)
+        assert model.config._attn_implementation == HOOKED_ATTENTION_NAME
+
+    initialize_hooked_kv_cache(model, module)
+
+
+# ----- hooks ----- #
+
+
+def register_query_hook(
+    module: Module, hook: Callable[[Module, Tensor], Optional[Tensor]]
+) -> RemovableHandle:
+    """
+    Register a hook which takes post-rope query states as an argument and
+    returns the modified query states or `None`
+
+    :param module: attention module to add hook to
+    :param hook: query hook function
+    """
+    impl: QuantizedAttentionImpl = getattr(module, IMPL_ATTR)
+
+    def _hook(impl: QuantizedAttentionImpl, args, kwargs):
+        bound = inspect.signature(impl.forward).bind(*args, **kwargs)
+        value = hook(module, bound.arguments["query"])
+        if value is not None:
+            bound.arguments["query"] = value
+
+        return bound.args, bound.kwargs
+
+    return impl.register_forward_pre_hook(_hook, with_kwargs=True)
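
The new `attention.py` module reroutes a model's attention calls through a hookable `QuantizedAttentionImpl` submodule, so query states can be observed or rewritten without patching each model class. A hedged usage sketch (the checkpoint name, the `self_attn` suffix match, and the logging hook are illustrative assumptions, not taken from this diff):

    # Hedged sketch: hook attention modules and observe post-rope query states.
    from compressed_tensors.modeling import (
        initialize_hooked_attention,
        register_query_hook,
    )
    from transformers import AutoModelForCausalLM

    model = AutoModelForCausalLM.from_pretrained("meta-llama/Llama-3.2-1B")

    def log_query(attn_module, query):
        # returning None leaves the query states unchanged
        print(f"{tuple(query.shape)} min={query.amin().item():.3f} max={query.amax().item():.3f}")
        return None

    for name, module in model.named_modules():
        if name.endswith("self_attn"):
            # registers the `impl` submodule and swaps in `ct_hooked_attention`
            initialize_hooked_attention(model, module)
            register_query_hook(module, log_query)
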
compressed_tensors/modeling/kvcache.py
@@ -0,0 +1,183 @@
+# Copyright (c) 2021 - present / Neuralmagic, Inc. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#    http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import inspect
+from typing import Any, Callable, Dict, List, Optional, Tuple
+from weakref import ReferenceType, ref
+
+from compressed_tensors.quantization.lifecycle.forward import forward_quantize
+from compressed_tensors.utils import getattr_chain
+from compressed_tensors.utils.internal import InternalModule
+from torch import Tensor
+from torch.nn import Module
+from torch.utils.hooks import RemovableHandle
+from transformers import Cache, PretrainedConfig, PreTrainedModel
+
+
+__all__ = [
+    "QuantizedKVCache",
+    "initialize_hooked_kv_cache",
+    "register_key_hook",
+    "register_value_hook",
+    "KV_CACHE_ATTR",
+]
+
+
+KV_CACHE_ATTR = "kv_cache"
+
+
+class QuantizedKVCache(InternalModule):
+    """
+    QuantizedKVCache module which wraps the functionality of any existing kvcache args.
+    Unlike transform Cache instances, this cache is a `torch.nn.Module` which can be
+    hooked to trigger transforms and calibration hooks.
+
+    This module works by being registered as a submodule to attention modules via
+    `initialize_hooked_kv_cache`, then adding a hook which replaces `past_key_values`
+    kwargs with this module. This module adopts the functionality of the replaced cache,
+    preserving caching functionality such as sliding window attention, ect.
+
+    :param attn_module: parent attention module
+    """
+
+    def __init__(self, config: PretrainedConfig, attn_module: Module):
+        super().__init__()
+        self.config = config
+        self.attn_module = ref(attn_module)  # avoid circular reference
+        self.past_key_values: Optional[ReferenceType[Cache]] = None
+
+    def update(self, *args, **kwargs) -> Tuple[Tensor, Tensor]:
+        return self(*args, **kwargs)
+
+    def forward(
+        self,
+        key_states: Tensor,
+        value_states: Tensor,
+        *args,
+        **kwargs,
+    ) -> Tuple[Tensor, Tensor]:
+        # quantization
+        module = self.attn_module()
+        quant_args_attr = "quantization_scheme.input_activations"
+        quant_args = getattr_chain(module, quant_args_attr, None)
+        quant_enabled = getattr(module, "quantization_enabled", True)
+        if quant_args is not None and quant_enabled:
+            key_states = forward_quantize(module, key_states, "k", quant_args)
+            value_states = forward_quantize(module, value_states, "v", quant_args)
+
+        # original cache
+        if self.past_key_values is not None:
+            ret = self.past_key_values().update(
+                key_states, value_states, *args, **kwargs
+            )
+        else:
+            ret = (key_states, value_states)
+        self.past_key_values = None
+
+        return ret
+
+    def add_past_key_values(self, past_key_values: Optional[Cache]):
+        if past_key_values is not None:
+            self.past_key_values = ref(past_key_values)
+        else:
+            self.past_key_values = None
+
+
+# ----- initialize ----- #
+
+
+def _kv_cache_attention_hook(
+    module: Module, args: List[Any], kwargs: Dict[str, Any]
+) -> Tuple[List[Any], Dict[str, Any]]:
+    """
+    Hook which should be called before each quantized attention forward pass.
+    This hook dynamically replaces the `past_key_values` kwarg to the attention
+    forward function.
+
+    The original kvcache object is assigned to QuantizedKVCache().past_key_values
+    as a weakref to maintain original cache functionality and compute savings
+    """
+    _past_kv_name = (
+        "past_key_values"  # transformers#39956
+        if "past_key_values" in inspect.signature(module.forward).parameters
+        else "past_key_value"
+    )
+    past_key_values: Optional[Cache] = kwargs.get(_past_kv_name, None)
+
+    cache: QuantizedKVCache = getattr(module, KV_CACHE_ATTR)
+    cache.add_past_key_values(past_key_values)
+    kwargs[_past_kv_name] = cache
+
+    return args, kwargs
+
+
+def initialize_hooked_kv_cache(model: PreTrainedModel, module: Module):
+    """
+    Initialize a `QuantizedKVCache` instance attached to attention
+
+    :param model: parent model of attention module
+    :param module: attention module to initialize with
+    """
+    if not hasattr(module, KV_CACHE_ATTR):
+        module.register_module(KV_CACHE_ATTR, QuantizedKVCache(model.config, module))
+        module.register_forward_pre_hook(_kv_cache_attention_hook, with_kwargs=True)
+
+
+# ----- hooks ----- #
+
+
+def register_key_hook(
+    module: Module, hook: Callable[[Module, Tensor], Optional[Tensor]]
+) -> RemovableHandle:
+    """
+    Register a hook which takes post-rope key states as an argument and
+    returns the modified key states or `None`
+
+    :param module: attention module to add hook to
+    :param hook: key hook function
+    """
+    kv_cache: QuantizedKVCache = getattr(module, KV_CACHE_ATTR)
+
+    def _hook(cache: QuantizedKVCache, args, kwargs):
+        bound = inspect.signature(cache.forward).bind(*args, **kwargs)
+        value = hook(module, bound.arguments["key_states"])
+        if value is not None:
+            bound.arguments["key_states"] = value
+
+        return bound.args, bound.kwargs
+
+    return kv_cache.register_forward_pre_hook(_hook, with_kwargs=True)
+
+
+def register_value_hook(
+    module: Module, hook: Callable[[Module, Tensor], Optional[Tensor]]
+) -> RemovableHandle:
+    """
+    Register a hook which takes value states as an argument and
+    returns the modified value states or `None`
+
+    :param module: attention module to add hook to
+    :param hook: value hook function
+    """
+    kv_cache: QuantizedKVCache = getattr(module, KV_CACHE_ATTR)
+
+    def _hook(cache: QuantizedKVCache, args, kwargs):
+        bound = inspect.signature(cache.forward).bind(*args, **kwargs)
+        value = hook(module, bound.arguments["value_states"])
+        if value is not None:
+            bound.arguments["value_states"] = value
+
+        return bound.args, bound.kwargs
+
+    return kv_cache.register_forward_pre_hook(_hook, with_kwargs=True)
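
`kvcache.py` applies the same pattern to the kv cache: a forward pre-hook swaps the incoming `past_key_values` for a `QuantizedKVCache` that quantizes key/value states (when a scheme is attached) before delegating to the original cache via a weakref. A hedged calibration-style sketch that only observes the states (the range-tracking dict and the module matching are illustrative assumptions):

    # Hedged sketch: track per-layer key/value ranges through kv cache hooks.
    from compressed_tensors.modeling import (
        initialize_hooked_kv_cache,
        register_key_hook,
        register_value_hook,
    )

    ranges = {}

    def make_tracker(name, kind):
        def _track(attn_module, states):
            lo, hi = states.amin().item(), states.amax().item()
            prev_lo, prev_hi = ranges.get((name, kind), (lo, hi))
            ranges[(name, kind)] = (min(prev_lo, lo), max(prev_hi, hi))
            return None  # observation only; states are not modified
        return _track

    for name, module in model.named_modules():  # `model` from the sketch above
        if name.endswith("self_attn"):
            initialize_hooked_kv_cache(model, module)
            register_key_hook(module, make_tracker(name, "k"))
            register_value_hook(module, make_tracker(name, "v"))
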
compressed_tensors/quantization/lifecycle/apply.py
@@ -12,7 +12,6 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
-import logging
 from collections import OrderedDict
 from copy import deepcopy
 from typing import Dict, List, Optional
@@ -21,8 +20,13 @@ from typing import Union
 
 import torch
 from compressed_tensors.config import CompressionFormat
+from compressed_tensors.modeling import (
+    initialize_hooked_attention,
+    initialize_hooked_kv_cache,
+)
 from compressed_tensors.quantization.lifecycle.initialize import (
     initialize_module_for_quantization,
+    is_attention_module,
 )
 from compressed_tensors.quantization.quant_args import QuantizationArgs
 from compressed_tensors.quantization.quant_config import (
@@ -30,14 +34,15 @@ from compressed_tensors.quantization.quant_config import (
     QuantizationStatus,
 )
 from compressed_tensors.quantization.quant_scheme import QuantizationScheme
-from compressed_tensors.quantization.utils import (
-    KV_CACHE_TARGETS,
-    is_kv_cache_quant_scheme,
-)
 from compressed_tensors.utils.helpers import replace_module
-from compressed_tensors.utils.match import match_named_modules, match_targets
+from compressed_tensors.utils.match import (
+    is_narrow_match,
+    match_named_modules,
+    match_targets,
+)
 from compressed_tensors.utils.offload import update_parameter_data
 from compressed_tensors.utils.safetensors_load import get_safetensors_folder
+from loguru import logger
 from safetensors import safe_open
 from torch.nn import Module
 
@@ -53,9 +58,6 @@ from compressed_tensors.utils.safetensors_load import (
 )
 
 
-_LOGGER = logging.getLogger(__name__)
-
-
 def load_pretrained_quantization_parameters(
     model: Module,
     model_name_or_path: Optional[str] = None,
@@ -125,8 +127,14 @@ def apply_quantization_config(
     if config is None:  # see PR #180
         return dict()
 
-    # preprocess to support kv cache scheme
-    config = process_quantization_config(config)
+    # force zero points during initialization
+    force_zero_point = config.quantization_status != QuantizationStatus.COMPRESSED
+
+    # apply and initialize kv cache quantization
+    if config.kv_cache_scheme is not None:
+        _apply_kv_cache_scheme(
+            model, config.kv_cache_scheme, config.quantization_status
+        )
 
     # build mapping of targets to schemes for easier matching
     # use ordered dict to preserve target ordering in config
@@ -162,49 +170,40 @@
             replace_module(model, name, compressed_linear)
 
         else:
+            if is_attention_module(submodule) and is_narrow_match(
+                model, scheme.targets, name
+            ):
+                initialize_hooked_attention(model, submodule)
+
             initialize_module_for_quantization(
                 submodule,
-                force_zero_point=config.quantization_status
-                != QuantizationStatus.COMPRESSED,
+                force_zero_point=force_zero_point,
            )
 
        submodule.quantization_status = config.quantization_status
 
 
-def process_quantization_config(config: QuantizationConfig) -> QuantizationConfig:
-    """
-    Preprocess the raw QuantizationConfig
-
-    :param config: the raw QuantizationConfig
-    :return: the processed QuantizationConfig
-    """
-    if config.kv_cache_scheme is not None:
-        config = process_kv_cache_config(config)
-
-    return config
-
-
-def process_kv_cache_config(
-    config: QuantizationConfig, targets: Union[List[str], str] = KV_CACHE_TARGETS
-) -> QuantizationConfig:
-    """
-    Reformulate the `config.kv_cache` as a `config_group`
-    and add it to the set of existing `config.groups`
-
-    :param config: the QuantizationConfig
-    :return: the QuantizationConfig with additional "kv_cache" group
-    """
-    if targets == KV_CACHE_TARGETS:
-        _LOGGER.info(f"KV cache targets set to default value of: {KV_CACHE_TARGETS}")
-
-    kv_cache_dict = config.kv_cache_scheme.model_dump()
-    kv_cache_scheme = QuantizationScheme(
-        output_activations=QuantizationArgs(**kv_cache_dict),
-        targets=targets,
+def _apply_kv_cache_scheme(
+    model: torch.nn.Module,
+    kv_cache_scheme: QuantizationArgs,
+    status: QuantizationStatus,
+):
+    if not kv_cache_scheme.symmetric:
+        raise logger.warning("vLLM does not support asymmetric kv cache quantization")
+
+    # applies and initializes kv cache quantization
+    # this step cannot come after attention apply/initialize
+    # otherwise it will override the attention qparams
+    scheme = QuantizationScheme(
+        targets=[".*self_attn$"],  # is never read in practice
+        input_activations=kv_cache_scheme,
     )
-    kv_cache_group = dict(kv_cache=kv_cache_scheme)
-    config.config_groups.update(kv_cache_group)
-    return config
+    for submodule in model.modules():
+        if is_attention_module(submodule):
+            submodule.quantization_scheme = scheme
+            initialize_hooked_kv_cache(model, submodule)
+            initialize_module_for_quantization(submodule, force_zero_point=False)
+            submodule.quantization_status = status
 
 
 def _load_quant_args_from_mapping(
@@ -256,60 +255,6 @@ def _scheme_from_targets(
     targets: List[str],
     name: str,
 ) -> QuantizationScheme:
-    if len(targets) == 1:
-        # if `targets` iterable contains a single element
-        # use it as the key
-        return target_to_scheme[targets[0]]
-
-    # otherwise, we need to merge QuantizationSchemes corresponding
-    # to multiple targets. This is most likely because `name` module
-    # is being target both as an ordinary quantization target, as well
-    # as kv cache quantization target
-    schemes_to_merge = [target_to_scheme[target] for target in targets]
-    return _merge_schemes(schemes_to_merge, name)
-
-
-def _merge_schemes(
-    schemes_to_merge: List[QuantizationScheme], name: str
-) -> QuantizationScheme:
-    kv_cache_quantization_scheme = [
-        scheme for scheme in schemes_to_merge if is_kv_cache_quant_scheme(scheme)
-    ]
-    if not kv_cache_quantization_scheme:
-        # if the schemes_to_merge do not contain any
-        # kv cache QuantizationScheme
-        # return the first scheme (the prioritized one,
-        # since the order of schemes_to_merge matters)
-        return schemes_to_merge[0]
-    else:
-        # fetch the kv cache QuantizationScheme and the highest
-        # priority non-kv cache QuantizationScheme and merge them
-        kv_cache_quantization_scheme = kv_cache_quantization_scheme[0]
-        quantization_scheme = [
-            scheme
-            for scheme in schemes_to_merge
-            if not is_kv_cache_quant_scheme(scheme)
-        ][0]
-        schemes_to_merge = [kv_cache_quantization_scheme, quantization_scheme]
-    merged_scheme = {}
-    for scheme in schemes_to_merge:
-        scheme_dict = {
-            k: v for k, v in scheme.model_dump().items() if v is not None
-        }
-        # when merging multiple schemes, the final target will be
-        # the `name` argument - hence erase the original targets
-        del scheme_dict["targets"]
-        # make sure that schemes do not "clash" with each other
-        overlapping_keys = set(merged_scheme.keys()) & set(scheme_dict.keys())
-        if overlapping_keys:
-            raise ValueError(
-                f"The module: {name} is being modified by two clashing "
-                f"quantization schemes, that jointly try to override "
-                f"properties: {overlapping_keys}. Fix the quantization config "
-                "so that it is not ambiguous."
-            )
-        merged_scheme.update(scheme_dict)
-
-    merged_scheme.update(targets=[name])
-
-    return QuantizationScheme(**merged_scheme)
+    # return the first scheme (the prioritized one,
+    # since the order of target_to_scheme matters)
+    return target_to_scheme[targets[0]]
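
With these changes, a `kv_cache_scheme` is no longer rewritten into a synthetic "kv_cache" config group; `apply_quantization_config` attaches the scheme and a hooked `QuantizedKVCache` directly to every attention module, which is why the scheme-merging helpers above were removed. A hedged sketch of exercising that path (the FP8 argument values and the weight scheme are illustrative, not taken from this diff):

    # Hedged sketch: a config with kv_cache_scheme now initializes hooked kv caches
    # on attention modules rather than adding a "kv_cache" config group.
    from compressed_tensors.quantization import (
        QuantizationArgs,
        QuantizationConfig,
        QuantizationScheme,
        apply_quantization_config,
    )

    config = QuantizationConfig(
        config_groups={
            "group_0": QuantizationScheme(
                targets=["Linear"],
                weights=QuantizationArgs(num_bits=8, type="float", symmetric=True),
            ),
        },
        kv_cache_scheme=QuantizationArgs(num_bits=8, type="float", symmetric=True),
    )

    apply_quantization_config(model, config)  # `model` from the earlier sketches
    # each attention module now carries a `kv_cache` submodule (QuantizedKVCache)
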