compressed-tensors 0.3.3__py3-none-any.whl → 0.5.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (42)
  1. compressed_tensors/base.py +3 -1
  2. compressed_tensors/compressors/__init__.py +9 -1
  3. compressed_tensors/compressors/base.py +12 -55
  4. compressed_tensors/compressors/dense.py +5 -5
  5. compressed_tensors/compressors/helpers.py +12 -12
  6. compressed_tensors/compressors/marlin_24.py +251 -0
  7. compressed_tensors/compressors/model_compressor.py +336 -0
  8. compressed_tensors/compressors/naive_quantized.py +144 -0
  9. compressed_tensors/compressors/pack_quantized.py +219 -0
  10. compressed_tensors/compressors/sparse_bitmask.py +4 -4
  11. compressed_tensors/config/base.py +9 -4
  12. compressed_tensors/config/dense.py +4 -4
  13. compressed_tensors/config/sparse_bitmask.py +3 -3
  14. compressed_tensors/quantization/lifecycle/__init__.py +2 -0
  15. compressed_tensors/quantization/lifecycle/apply.py +204 -31
  16. compressed_tensors/quantization/lifecycle/calibration.py +20 -1
  17. compressed_tensors/quantization/lifecycle/compressed.py +69 -0
  18. compressed_tensors/quantization/lifecycle/forward.py +214 -62
  19. compressed_tensors/quantization/lifecycle/frozen.py +4 -0
  20. compressed_tensors/quantization/lifecycle/helpers.py +53 -0
  21. compressed_tensors/quantization/lifecycle/initialize.py +62 -5
  22. compressed_tensors/quantization/observers/base.py +66 -23
  23. compressed_tensors/quantization/observers/helpers.py +69 -11
  24. compressed_tensors/quantization/observers/memoryless.py +17 -9
  25. compressed_tensors/quantization/observers/min_max.py +44 -13
  26. compressed_tensors/quantization/quant_args.py +47 -3
  27. compressed_tensors/quantization/quant_config.py +104 -23
  28. compressed_tensors/quantization/quant_scheme.py +183 -2
  29. compressed_tensors/quantization/utils/helpers.py +142 -8
  30. compressed_tensors/utils/__init__.py +4 -0
  31. compressed_tensors/utils/helpers.py +54 -7
  32. compressed_tensors/utils/offload.py +104 -0
  33. compressed_tensors/utils/permutations_24.py +65 -0
  34. compressed_tensors/utils/safetensors_load.py +3 -2
  35. compressed_tensors/utils/semi_structured_conversions.py +341 -0
  36. compressed_tensors/version.py +53 -0
  37. {compressed_tensors-0.3.3.dist-info → compressed_tensors-0.5.0.dist-info}/METADATA +47 -8
  38. compressed_tensors-0.5.0.dist-info/RECORD +48 -0
  39. {compressed_tensors-0.3.3.dist-info → compressed_tensors-0.5.0.dist-info}/WHEEL +1 -1
  40. compressed_tensors-0.3.3.dist-info/RECORD +0 -38
  41. {compressed_tensors-0.3.3.dist-info → compressed_tensors-0.5.0.dist-info}/LICENSE +0 -0
  42. {compressed_tensors-0.3.3.dist-info → compressed_tensors-0.5.0.dist-info}/top_level.txt +0 -0
compressed_tensors/quantization/quant_config.py

@@ -13,25 +13,31 @@
 # limitations under the License.
 
 from enum import Enum
-from typing import Dict, List, Optional
+from typing import Dict, List, Optional, Union
 
-from compressed_tensors.base import QUANTIZATION_CONFIG_NAME
-from compressed_tensors.quantization.quant_scheme import QuantizationScheme
+from compressed_tensors.config import CompressionFormat
+from compressed_tensors.quantization.quant_args import QuantizationArgs
+from compressed_tensors.quantization.quant_scheme import (
+    QuantizationScheme,
+    preset_name_to_scheme,
+)
 from compressed_tensors.quantization.utils import (
     calculate_compression_ratio,
     is_module_quantized,
     iter_named_leaf_modules,
     module_type,
+    parse_out_kv_cache_args,
 )
 from pydantic import BaseModel, Field
 from torch.nn import Module
-from transformers import AutoConfig
 
 
 __all__ = [
     "QuantizationStatus",
     "QuantizationConfig",
     "LIFECYCLE_ORDER",
+    "DEFAULT_QUANTIZATION_METHOD",
+    "DEFAULT_QUANTIZATION_FORMAT",
 ]
 
 
@@ -62,10 +68,33 @@ class QuantizationStatus(str, Enum):
         return
 
     def __ge__(self, other):
+        if other is None:
+            return True
         if not isinstance(other, self.__class__):
             raise NotImplementedError
         return LIFECYCLE_ORDER.index(self) >= LIFECYCLE_ORDER.index(other)
 
+    def __gt__(self, other):
+        if other is None:
+            return True
+        if not isinstance(other, self.__class__):
+            raise NotImplementedError
+        return LIFECYCLE_ORDER.index(self) > LIFECYCLE_ORDER.index(other)
+
+    def __lt__(self, other):
+        if other is None:
+            return False
+        if not isinstance(other, self.__class__):
+            raise NotImplementedError
+        return LIFECYCLE_ORDER.index(self) < LIFECYCLE_ORDER.index(other)
+
+    def __le__(self, other):
+        if other is None:
+            return False
+        if not isinstance(other, self.__class__):
+            raise NotImplementedError
+        return LIFECYCLE_ORDER.index(self) <= LIFECYCLE_ORDER.index(other)
+
 
 LIFECYCLE_ORDER = [
     QuantizationStatus.INITIALIZED,
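The new rich comparisons let lifecycle checks read naturally: statuses order by their position in LIFECYCLE_ORDER, and None compares below every status. A quick illustrative sketch (not part of the diff):

from compressed_tensors.quantization.quant_config import QuantizationStatus

assert QuantizationStatus.COMPRESSED > QuantizationStatus.INITIALIZED
assert QuantizationStatus.INITIALIZED >= None       # None means "no status yet"
assert not (QuantizationStatus.INITIALIZED < None)  # __lt__/__le__ return False for None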
@@ -74,6 +103,9 @@ LIFECYCLE_ORDER = [
     QuantizationStatus.COMPRESSED,
 ]
 
+DEFAULT_QUANTIZATION_METHOD = "compressed-tensors"
+DEFAULT_QUANTIZATION_FORMAT = "fakequant"
+
 
 class QuantizationConfig(BaseModel):
     """
@@ -81,45 +113,62 @@ class QuantizationConfig(BaseModel):
     mapped to a QuantizationScheme in config_groups.
 
     :param config_groups: dict of QuantizationSchemes specifying the quantization
-        settings for each quantized layer
+        settings for each quantized layer. A group could also be a reference to
+        a predefined scheme name, mapped to a list of its target layers/classes
     :param quant_method: a constant used to differentiate sparseML quantization from
         other quantization configs
     :param format: specifies how the quantized model is stored on disk
     :quantization_status: specifies the current status of all quantized layers. It is
-        assumed all layers are in the same state.
+        assumed all layers are in the same state.
+    :param kv_cache_scheme: optional QuantizationArgs, that specify the
+        quantization of the kv cache. If None, kv cache is not quantized.
+        When applying kv cache quantization to transformer AutoModelForCausalLM,
+        the kv_cache_scheme gets converted into a QuantizationScheme that:
+            - targets the `q_proj` and `k_proj` modules of the model. The outputs
+              of those modules are the keys and values that might be cached
+            - quantizes the outputs of the aformentioned layers, so that
+              keys and values are compressed before storing them in the cache
+        There is an explicit assumption that the model contains modules with
+        `k_proj` and `v_proj` in their names. If this is not the case
+        and kv_cache_scheme != None, the quantization of kv cache will fail
     :global_compression_ratio: optional informational config to report the model
         compression ratio acheived by the quantization config
     :ignore: optional list of layers to ignore from config_groups. Layers in this list
        are not quantized even if they match up with a target in config_groups
     """
 
-    config_groups: Dict[str, QuantizationScheme]
-    quant_method: str = "sparseml"
-    format: str = "fakequant"
+    config_groups: Dict[str, Union[QuantizationScheme, List[str]]]
+    quant_method: str = DEFAULT_QUANTIZATION_METHOD
+    kv_cache_scheme: Optional[QuantizationArgs] = None
+    format: str = DEFAULT_QUANTIZATION_FORMAT
     quantization_status: QuantizationStatus = QuantizationStatus.INITIALIZED
     global_compression_ratio: Optional[float] = None
     ignore: Optional[List[str]] = Field(default_factory=list)
 
-    @staticmethod
-    def from_model_config(model_name_or_path) -> "QuantizationConfig":
+    def model_post_init(self, __context):
         """
-        Given a path to a model config, extract a quantization config if it exists
-
-        :param pretrained_model_name_or_path: path to model config on disk or HF hub
-        :return: instantiated QuantizationConfig if config contains a quant config
+        updates any quantization schemes defined as presets to be fully loaded
+        schemes
         """
-        config = AutoConfig.from_pretrained(model_name_or_path)
-        quantization_config = getattr(config, QUANTIZATION_CONFIG_NAME, None)
-        if quantization_config is None:
-            return None
-
-        return QuantizationConfig.parse_obj(quantization_config)
+        for group_name, targets_or_scheme in self.config_groups.items():
+            if isinstance(targets_or_scheme, QuantizationScheme):
+                continue  # scheme already defined
+            self.config_groups[group_name] = preset_name_to_scheme(
+                name=group_name,
+                targets=targets_or_scheme,
+            )
+
+    def to_dict(self):
+        # for compatibility with HFQuantizer
+        return self.dict()
 
     @staticmethod
-    def from_pretrained(model: Module) -> "QuantizationConfig":
+    def from_pretrained(
+        model: Module, format: Optional[str] = None
+    ) -> Optional["QuantizationConfig"]:
         """
         Converts a model into its associated QuantizationConfig based on the
-        QuantizationScheme attached to each quanitzed module
+        QuantizationScheme attached to each quantized module
 
         :param model: model to calculate quantization scheme of
         :return: filled out QuantizationScheme for the input model
@@ -147,6 +196,9 @@ class QuantizationConfig(BaseModel):
             if not match_found:
                 quant_scheme_to_layers.append(scheme)
 
+        if len(quant_scheme_to_layers) == 0:  # No quantized layers
+            return None
+
         # clean up ignore list, we can leave out layers types if none of the
         # instances are quantized
         consolidated_ignore = []
@@ -157,15 +209,44 @@ class QuantizationConfig(BaseModel):
             # else we leave it off the ignore list, doesn't fall under any of the
             # existing quantization schemes so it won't be quantized
 
+        kv_cache_args, quant_scheme_to_layers = parse_out_kv_cache_args(
+            quant_scheme_to_layers
+        )
+        kv_cache_scheme = (
+            kv_cache_args.model_dump() if kv_cache_args is not None else kv_cache_args
+        )
+
         config_groups = {}
         for idx, scheme in enumerate(quant_scheme_to_layers):
             group_name = "group_" + str(idx)
             config_groups[group_name] = scheme
 
+        # TODO: this is incorrect in compressed mode, since we are overwriting the
+        # original weight we lose the uncompressed bit_depth indo
         compression_ratio = calculate_compression_ratio(model)
+
+        if format is None:
+            if quantization_status == QuantizationStatus.COMPRESSED:
+                format = CompressionFormat.int_quantized.value
+            else:
+                format = CompressionFormat.dense.value
+
         return QuantizationConfig(
             config_groups=config_groups,
             quantization_status=quantization_status,
+            kv_cache_scheme=kv_cache_scheme,
             global_compression_ratio=compression_ratio,
+            format=format,
            ignore=consolidated_ignore,
         )
+
+    def requires_calibration_data(self):
+        for _, scheme in self.config_groups.items():
+            if scheme.input_activations is not None:
+                if not scheme.input_activations.dynamic:
+                    return True
+            if scheme.output_activations is not None:
+                if not scheme.output_activations.dynamic:
+                    return True
+
+        return False
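With the preset handling above, a config group can now be written as a preset name mapped to its target list, and model_post_init expands it into a full QuantizationScheme. A minimal usage sketch (not part of the diff; the group name and targets are illustrative):

from compressed_tensors.quantization.quant_config import QuantizationConfig

# "W4A16" is a preset name, so its value may be a plain list of targets
config = QuantizationConfig(
    config_groups={"W4A16": ["Linear"]},
    ignore=["lm_head"],
)

# model_post_init replaced the target list with a fully loaded scheme
scheme = config.config_groups["W4A16"]
assert scheme.weights.num_bits == 4 and scheme.weights.group_size == 128

# W4A16 is weight-only, so no calibration data is required
assert config.requires_calibration_data() is False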
compressed_tensors/quantization/quant_scheme.py

@@ -12,13 +12,22 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
+from copy import deepcopy
 from typing import List, Optional
 
-from compressed_tensors.quantization.quant_args import QuantizationArgs
+from compressed_tensors.quantization.quant_args import (
+    QuantizationArgs,
+    QuantizationStrategy,
+    QuantizationType,
+)
 from pydantic import BaseModel
 
 
-__all__ = ["QuantizationScheme"]
+__all__ = [
+    "QuantizationScheme",
+    "preset_name_to_scheme",
+    "is_preset_scheme",
+]
 
 
 class QuantizationScheme(BaseModel):
@@ -37,3 +46,175 @@ class QuantizationScheme(BaseModel):
     weights: Optional[QuantizationArgs] = None
     input_activations: Optional[QuantizationArgs] = None
     output_activations: Optional[QuantizationArgs] = None
+
+    @classmethod
+    def default_scheme(
+        cls,
+        targets: Optional[List[str]] = None,
+    ):
+
+        if targets is None:
+            # default to quantizing all Linear layers
+            targets = ["Linear"]
+
+        # default to 8 bit integer symmetric quantization
+        # for weights
+        weights = QuantizationArgs(num_bits=8, symmetric=True)
+
+        # default to 8 bit integer asymmetric quantization
+        input_activations = QuantizationArgs(num_bits=8, symmetric=True)
+
+        # Do not quantize the output activations
+        # by default
+        output_activations = None
+
+        return cls(
+            targets=targets,
+            weights=weights,
+            input_activations=input_activations,
+            output_activations=output_activations,
+        )
+
+
+"""
+Pre-Set Quantization Scheme Args
+"""
+
+
+def preset_name_to_scheme(name: str, targets: List[str]) -> QuantizationScheme:
+    """
+    :param name: preset quantization settings name. must exist in upper case in
+        PRESET_SCHEMES
+    :param targets: list of quantization targets to be passed to the Scheme
+    :return: new QuantizationScheme for a given name with the given targets
+    """
+    name = name.upper()
+
+    if name not in PRESET_SCHEMES:
+        raise KeyError(
+            f"Unknown preset scheme name {name}, "
+            f"available names: {list(PRESET_SCHEMES.keys())}"
+        )
+
+    scheme_args = deepcopy(PRESET_SCHEMES[name])  # deepcopy to avoid args references
+    return QuantizationScheme(
+        targets=targets,
+        **scheme_args,
+    )
+
+
+def is_preset_scheme(name: str) -> bool:
+    """
+    :param name: preset quantization settings name
+    :return: True if the name is a preset scheme name
+    """
+    return name.upper() in PRESET_SCHEMES
+
+
+# 8 bit integer weights and 8 bit activations quantization
+W8A8 = dict(
+    weights=QuantizationArgs(
+        num_bits=8,
+        type=QuantizationType.INT,
+        strategy=QuantizationStrategy.CHANNEL,
+        symmetric=True,
+        dynamic=False,
+    ),
+    input_activations=QuantizationArgs(
+        num_bits=8,
+        type=QuantizationType.INT,
+        strategy=QuantizationStrategy.TOKEN,
+        symmetric=True,
+        dynamic=True,
+    ),
+)
+
+# 8 bit integer weights only quantization
+W8A16 = dict(
+    weights=QuantizationArgs(
+        num_bits=8,
+        type=QuantizationType.INT,
+        strategy=QuantizationStrategy.CHANNEL,
+        symmetric=True,
+        dynamic=False,
+    ),
+)
+
+# 4 bit integer weights only quantization
+W4A16 = dict(
+    weights=QuantizationArgs(
+        num_bits=4,
+        type=QuantizationType.INT,
+        strategy=QuantizationStrategy.GROUP,
+        group_size=128,
+        symmetric=True,
+        dynamic=False,
+    ),
+)
+
+# 4 bit integer weights and 8 bit activations quantization
+W4A8 = dict(
+    weights=QuantizationArgs(
+        num_bits=4,
+        type=QuantizationType.INT,
+        group_size=128,
+        strategy=QuantizationStrategy.GROUP,
+        symmetric=True,
+        dynamic=False,
+    ),
+    input_activations=QuantizationArgs(
+        num_bits=8,
+        type=QuantizationType.INT,
+        strategy=QuantizationStrategy.TOKEN,
+        symmetric=True,
+        dynamic=True,
+    ),
+)
+
+# FP8 weights and FP8 activations quantization
+FP8 = dict(
+    weights=QuantizationArgs(
+        num_bits=8,
+        type=QuantizationType.FLOAT,
+        strategy=QuantizationStrategy.TENSOR,
+        symmetric=True,
+        dynamic=False,
+    ),
+    input_activations=QuantizationArgs(
+        num_bits=8,
+        type=QuantizationType.FLOAT,
+        strategy=QuantizationStrategy.TENSOR,
+        symmetric=True,
+        dynamic=False,
+    ),
+)
+
+# FP8 weights and FP8 dynamic activations quantization
+FP8_DYNAMIC = dict(
+    weights=QuantizationArgs(
+        num_bits=8,
+        type=QuantizationType.FLOAT,
+        strategy=QuantizationStrategy.CHANNEL,
+        symmetric=True,
+        dynamic=False,
+    ),
+    input_activations=QuantizationArgs(
+        num_bits=8,
+        type=QuantizationType.FLOAT,
+        strategy=QuantizationStrategy.TOKEN,
+        symmetric=True,
+        dynamic=True,
+    ),
+)
+
+PRESET_SCHEMES = {
+    # Integer weight only schemes
+    "W8A16": W8A16,
+    "W4A16": W4A16,
+    # Integer weight and activation schemes
+    "W8A8": W8A8,
+    "W4A8": W4A8,
+    # Float weight and activation schemes
+    "FP8": FP8,
+    "FP8_DYNAMIC": FP8_DYNAMIC,
+}
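For reference, a short sketch of how the new preset helpers are meant to be used (illustrative targets, not part of the diff):

from compressed_tensors.quantization.quant_scheme import (
    QuantizationScheme,
    is_preset_scheme,
    preset_name_to_scheme,
)

assert is_preset_scheme("w4a16")  # lookup is case-insensitive

scheme = preset_name_to_scheme("W4A16", targets=["Linear"])
assert scheme.weights.num_bits == 4
assert scheme.input_activations is None  # weight-only preset

# default_scheme() falls back to 8-bit weights and input activations on Linear layers
default = QuantizationScheme.default_scheme()
assert default.targets == ["Linear"]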
compressed_tensors/quantization/utils/helpers.py

@@ -12,21 +12,50 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
-from typing import Tuple
+import logging
+import re
+from typing import List, Optional, Tuple
 
 import torch
+from compressed_tensors.quantization.observers.base import Observer
+from compressed_tensors.quantization.quant_args import QuantizationArgs
+from compressed_tensors.quantization.quant_scheme import QuantizationScheme
 from torch.nn import Module
 from tqdm import tqdm
 
 
 __all__ = [
+    "infer_quantization_status",
     "is_module_quantized",
     "is_model_quantized",
     "iter_named_leaf_modules",
     "module_type",
     "calculate_compression_ratio",
+    "get_torch_bit_depth",
+    "can_quantize",
+    "parse_out_kv_cache_args",
+    "KV_CACHE_TARGETS",
+    "is_kv_cache_quant_scheme",
 ]
 
+KV_CACHE_TARGETS = ["re:.*k_proj", "re:.*v_proj"]
+_LOGGER: logging.Logger = logging.getLogger(__name__)
+
+
+def infer_quantization_status(model: Module) -> Optional["QuantizationStatus"]:  # noqa
+    """
+    Checks the quantization status of a model. Assumes all modules in the model have
+    the same status, so only the first quantized model is checked.
+
+    :param model: model to check quantization status for
+    :return: quantization status if the model is quantized, otherwise None
+    """
+    for module in model.modules():
+        status = getattr(module, "quantization_status", None)
+        if status is not None:
+            return status
+    return None
+
 
 def is_module_quantized(module: Module) -> bool:
     """
@@ -78,11 +107,60 @@ def module_type(module: Module) -> str:
 
 
 def iter_named_leaf_modules(model: Module) -> Tuple[str, Module]:
-    # yields modules that do not have any submodules
-    # TODO: potentially expand to add list of allowed submodules such as observers
+    """
+    Yields modules that do not have any submodules except observers. The observers
+    themselves are not yielded
+
+    :param model: model to get leaf modules of
+    :returns: generator tuple of (name, leaf_submodule)
+    """
     for name, submodule in model.named_modules():
-        if len(list(submodule.children())) == 0:
+        children = list(submodule.children())
+        if len(children) == 0 and not isinstance(submodule, Observer):
             yield name, submodule
+        else:
+            has_non_observer_children = False
+            for child in children:
+                if not isinstance(child, Observer):
+                    has_non_observer_children = True
+
+            if not has_non_observer_children:
+                yield name, submodule
+
+
+def get_torch_bit_depth(value: torch.Tensor) -> int:
+    """
+    Determine the number of bits used to represent the dtype of a tensor
+
+    :param value: tensor to check bit depth of
+    :return: bit depth of each element in the value tensor
+    """
+    try:
+        bit_depth = torch.finfo(value.dtype).bits
+    except TypeError:
+        bit_depth = torch.iinfo(value.dtype).bits
+
+    return bit_depth
+
+
+def can_quantize(value: torch.Tensor, quant_args: "QuantizationArgs") -> bool:  # noqa
+    """
+    Checks if value can be quantized by quant_args.
+
+    :param value: tensor to check for quantization
+    :param quant_args: QuantizationArgs to use for quantization
+    :return: False if value is already quantized to quant_args or value is incompatible
+        with quant_args, True if value can be quantized with quant_args
+    """
+    bit_depth = get_torch_bit_depth(value)
+    requested_depth = quant_args.num_bits
+    if bit_depth < quant_args.num_bits:
+        _LOGGER.warn(
+            f"Can't quantize tensor with bit depth {bit_depth} to {requested_depth}."
+            "The QuantizationArgs provided are not compatible with the input tensor."
+        )
+
+    return bit_depth > quant_args.num_bits
 
 
 def calculate_compression_ratio(model: Module) -> float:
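A small sketch of the new bit-depth helpers (dtypes chosen for illustration; QuantizationArgs construction mirrors the usage shown elsewhere in this diff):

import torch
from compressed_tensors.quantization.quant_args import QuantizationArgs
from compressed_tensors.quantization.utils.helpers import can_quantize, get_torch_bit_depth

assert get_torch_bit_depth(torch.zeros(4, dtype=torch.float16)) == 16  # finfo path
assert get_torch_bit_depth(torch.zeros(4, dtype=torch.int8)) == 8      # iinfo path

# a float16 weight has room for 4-bit quantization, an int8 one does not
assert can_quantize(torch.zeros(4, dtype=torch.float16), QuantizationArgs(num_bits=4, symmetric=True))
assert not can_quantize(torch.zeros(4, dtype=torch.int8), QuantizationArgs(num_bits=8, symmetric=True))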
@@ -101,10 +179,7 @@ def calculate_compression_ratio(model: Module) -> float:
         desc="Calculating quantization compression ratio",
     ):
         for parameter in model.parameters():
-            try:
-                uncompressed_bits = torch.finfo(parameter.dtype).bits
-            except TypeError:
-                uncompressed_bits = torch.iinfo(parameter.dtype).bits
+            uncompressed_bits = get_torch_bit_depth(parameter)
             compressed_bits = uncompressed_bits
             if is_module_quantized(submodule):
                 compressed_bits = submodule.quantization_scheme.weights.num_bits
@@ -114,3 +189,62 @@ def calculate_compression_ratio(model: Module) -> float:
             total_uncompressed += uncompressed_bits * num_weights
 
     return total_uncompressed / total_compressed
+
+
+def is_kv_cache_quant_scheme(scheme: QuantizationScheme) -> bool:
+    """
+    Check whether the QuantizationScheme targets the kv cache.
+    It does if all the following criteria are met:
+    - the scheme targets either exactly match the KV_CACHE_TARGETS
+      or the match KV_CACHE_TARGETS regex pattern
+    - the scheme quantizes output_activations (we want to quantize the
+      outputs from the KV_CACHE_TARGETS, as their correspond to the
+      keys and values that are to be saved in the cache)
+
+    :param scheme: The QuantizationScheme to investigate
+    :return: boolean flag
+    """
+    if len(scheme.targets) == 1:
+        # match on the KV_CACHE_TARGETS regex pattern
+        # if there is only one target
+        is_match_targets = any(
+            [re.match(pattern[3:], scheme.targets[0]) for pattern in KV_CACHE_TARGETS]
+        )
+    else:
+        # match on the exact KV_CACHE_TARGETS
+        # if there are multiple targets
+        is_match_targets = set(KV_CACHE_TARGETS) == set(scheme.targets)
+
+    is_match_output_activations = scheme.output_activations is not None
+    return is_match_targets and is_match_output_activations
+
+
+def parse_out_kv_cache_args(
+    quant_scheme_to_layers: List[QuantizationScheme],
+) -> Tuple[Optional[QuantizationArgs], List[QuantizationScheme]]:
+    """
+    If possible, parse out the kv cache specific QuantizationArgs
+    from the list of the QuantizationSchemes. If no kv cache
+    specific QuantizationArgs available, this function acts
+    as an identity function
+
+    :param quant_scheme_to_layers: list of QuantizationSchemes
+    :return: kv_cache_args (optional) and the (remaining or original)
+        list of the QuantizationSchemes
+    """
+    kv_cache_quant_scheme_to_layers = [
+        scheme for scheme in quant_scheme_to_layers if is_kv_cache_quant_scheme(scheme)
+    ]
+    quant_scheme_to_layers = [
+        scheme
+        for scheme in quant_scheme_to_layers
+        if not is_kv_cache_quant_scheme(scheme)
+    ]
+
+    if kv_cache_quant_scheme_to_layers:
+        kv_cache_quant_scheme_to_layers = kv_cache_quant_scheme_to_layers[0]
+        kv_cache_args = kv_cache_quant_scheme_to_layers.output_activations
+    else:
+        kv_cache_args = None
+
+    return kv_cache_args, quant_scheme_to_layers
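To show how the kv cache helpers fit together, a usage sketch (the targets mirror KV_CACHE_TARGETS above; not part of the diff):

from compressed_tensors.quantization.quant_args import QuantizationArgs
from compressed_tensors.quantization.quant_scheme import QuantizationScheme
from compressed_tensors.quantization.utils.helpers import (
    is_kv_cache_quant_scheme,
    parse_out_kv_cache_args,
)

kv_scheme = QuantizationScheme(
    targets=["re:.*k_proj", "re:.*v_proj"],
    output_activations=QuantizationArgs(num_bits=8, symmetric=True),
)
linear_scheme = QuantizationScheme(
    targets=["Linear"],
    weights=QuantizationArgs(num_bits=8, symmetric=True),
)

assert is_kv_cache_quant_scheme(kv_scheme)
assert not is_kv_cache_quant_scheme(linear_scheme)

kv_args, remaining = parse_out_kv_cache_args([kv_scheme, linear_scheme])
assert kv_args.num_bits == 8          # kv cache QuantizationArgs were parsed out
assert remaining == [linear_scheme]   # only non kv cache schemes remain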
compressed_tensors/utils/__init__.py

@@ -13,4 +13,8 @@
 # limitations under the License.
 # flake8: noqa
 
+from .helpers import *
+from .offload import *
+from .permutations_24 import *
 from .safetensors_load import *
+from .semi_structured_conversions import *