compressed-tensors 0.4.0__py3-none-any.whl → 0.5.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (36)
  1. compressed_tensors/base.py +1 -0
  2. compressed_tensors/compressors/__init__.py +5 -1
  3. compressed_tensors/compressors/base.py +1 -1
  4. compressed_tensors/compressors/dense.py +1 -1
  5. compressed_tensors/compressors/marlin_24.py +11 -10
  6. compressed_tensors/compressors/model_compressor.py +33 -12
  7. compressed_tensors/compressors/{int_quantized.py → naive_quantized.py} +33 -15
  8. compressed_tensors/compressors/pack_quantized.py +58 -51
  9. compressed_tensors/compressors/sparse_bitmask.py +1 -1
  10. compressed_tensors/config/base.py +2 -0
  11. compressed_tensors/quantization/lifecycle/__init__.py +1 -0
  12. compressed_tensors/quantization/lifecycle/apply.py +161 -39
  13. compressed_tensors/quantization/lifecycle/calibration.py +20 -1
  14. compressed_tensors/quantization/lifecycle/forward.py +70 -25
  15. compressed_tensors/quantization/lifecycle/helpers.py +53 -0
  16. compressed_tensors/quantization/lifecycle/initialize.py +30 -1
  17. compressed_tensors/quantization/observers/base.py +39 -0
  18. compressed_tensors/quantization/observers/helpers.py +64 -11
  19. compressed_tensors/quantization/quant_args.py +45 -1
  20. compressed_tensors/quantization/quant_config.py +35 -2
  21. compressed_tensors/quantization/quant_scheme.py +105 -4
  22. compressed_tensors/quantization/utils/helpers.py +67 -1
  23. compressed_tensors/utils/__init__.py +4 -0
  24. compressed_tensors/utils/helpers.py +31 -2
  25. compressed_tensors/utils/offload.py +104 -0
  26. compressed_tensors/version.py +1 -1
  27. {compressed_tensors-0.4.0.dist-info → compressed_tensors-0.5.0.dist-info}/METADATA +2 -1
  28. compressed_tensors-0.5.0.dist-info/RECORD +48 -0
  29. {compressed_tensors-0.4.0.dist-info → compressed_tensors-0.5.0.dist-info}/WHEEL +1 -1
  30. compressed_tensors/compressors/utils/__init__.py +0 -19
  31. compressed_tensors/compressors/utils/helpers.py +0 -43
  32. compressed_tensors-0.4.0.dist-info/RECORD +0 -48
  33. /compressed_tensors/{compressors/utils → utils}/permutations_24.py +0 -0
  34. /compressed_tensors/{compressors/utils → utils}/semi_structured_conversions.py +0 -0
  35. {compressed_tensors-0.4.0.dist-info → compressed_tensors-0.5.0.dist-info}/LICENSE +0 -0
  36. {compressed_tensors-0.4.0.dist-info → compressed_tensors-0.5.0.dist-info}/top_level.txt +0 -0
compressed_tensors/quantization/observers/base.py
@@ -12,6 +12,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.

+import logging
 from typing import Any, Iterable, Optional, Tuple, Union

 import torch
@@ -24,6 +25,9 @@ from torch import FloatTensor, IntTensor, Tensor
 from torch.nn import Module


+_LOGGER = logging.getLogger(__name__)
+
+
 __all__ = ["Observer"]


@@ -39,6 +43,7 @@ class Observer(Module, RegistryMixin):
         super().__init__()
         self._scale = None
         self._zero_point = None
+        self._num_observed_tokens = None

     @torch.no_grad()
     def forward(self, observed: Tensor) -> Tuple[FloatTensor, IntTensor]:
@@ -48,6 +53,7 @@ class Observer(Module, RegistryMixin):
             from
         :return: tuple of scale and zero point based on last observed value
         """
+        self.record_observed_tokens(observed)
         return self.get_qparams(observed=observed)

     def calculate_qparams(
@@ -132,3 +138,36 @@ class Observer(Module, RegistryMixin):
         return self.calculate_qparams(
             observed, reduce_dims=reduce_dims, tensor_id=tensor_id
         )
+
+    def record_observed_tokens(self, batch_tensor: Tensor):
+        """
+        Counts the number of tokens observed during the
+        forward passes. The count is aggregated in the
+        _num_observed_tokens attribute of the class.
+
+        Note: The batch_tensor is expected to have two dimensions
+        (batch_size * sequence_length, num_features). This is the
+        general shape expected by the forward pass of the expert
+        layers in a MOE model. If the input tensor does not have
+        two dimensions, the _num_observed_tokens attribute will be set
+        to None.
+        """
+        if not isinstance(batch_tensor, Tensor):
+            raise ValueError(f"Expected value to be a tensor, got {type(batch_tensor)}")
+
+        if batch_tensor.ndim != 2:
+            _LOGGER.debug(
+                "The input tensor is expected to have two dimensions "
+                "(batch_size * sequence_length, num_features). "
+                f"The input tensor has {batch_tensor.ndim} dimensions."
+            )
+            return
+
+        if self._num_observed_tokens is None:
+            # initialize the count
+            self._num_observed_tokens = 0
+
+        # batch_tensor (batch_size * sequence_length, num_features)
+        # observed_tokens (batch_size * sequence_length)
+        observed_tokens, _ = batch_tensor.shape
+        self._num_observed_tokens += observed_tokens
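The token counting added to Observer only inspects the tensor shape. A minimal standalone sketch of the rule (plain torch, no Observer instance required): the first dimension of a 2-D (batch_size * sequence_length, num_features) activation is taken as the number of observed tokens and accumulated across forward passes.

    import torch

    # flattened activations, as an expert layer in a MoE model would see them
    batch_size, seq_len, hidden = 2, 128, 4096
    flattened = torch.randn(batch_size * seq_len, hidden)

    # same rule as record_observed_tokens: dim 0 of the 2-D tensor is the token count
    observed_tokens, _ = flattened.shape
    assert observed_tokens == 256  # added to _num_observed_tokens on every forward pass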
compressed_tensors/quantization/observers/helpers.py
@@ -12,23 +12,45 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.

+from collections import Counter
 from typing import Tuple

 import torch
-from compressed_tensors.quantization.quant_args import QuantizationArgs
+from compressed_tensors.quantization.quant_args import (
+    FP8_DTYPE,
+    QuantizationArgs,
+    QuantizationType,
+)
 from torch import FloatTensor, IntTensor, Tensor


-__all__ = ["calculate_qparams"]
+__all__ = ["calculate_qparams", "get_observer_token_count", "calculate_range"]
+
+
+def get_observer_token_count(module: torch.nn.Module) -> Counter:
+    """
+    Parse the module and return the number of tokens observed by
+    each module's observer.
+
+    :param module: module to parse
+    :return: counter with the number of tokens observed by each observer
+    """
+    token_counts = Counter()
+    for name, module in module.named_modules():
+        if name.endswith(".input_observer"):
+            token_counts[
+                name.replace(".input_observer", "")
+            ] = module._num_observed_tokens
+    return token_counts


 def calculate_qparams(
     min_vals: Tensor, max_vals: Tensor, quantization_args: QuantizationArgs
 ) -> Tuple[FloatTensor, IntTensor]:
     """
-    :param min_vals: tensor of min value(s) to caluclate scale(s) and zero point(s)
+    :param min_vals: tensor of min value(s) to calculate scale(s) and zero point(s)
         from
-    :param max_vals: tensor of max value(s) to caluclate scale(s) and zero point(s)
+    :param max_vals: tensor of max value(s) to calculate scale(s) and zero point(s)
         from
     :param quantization_args: settings to quantization
     :return: tuple of the calculated scale(s) and zero point(s)
@@ -37,22 +59,53 @@ def calculate_qparams(
     max_vals = torch.max(max_vals, torch.zeros_like(max_vals))
     device = min_vals.device

-    bit_range = 2**quantization_args.num_bits - 1
-    bit_min = -(bit_range + 1) / 2
-    bit_max = bit_min + bit_range
+    bit_min, bit_max = calculate_range(quantization_args, device)
+    bit_range = bit_max - bit_min
+    zp_dtype = quantization_args.pytorch_dtype()
+
     if quantization_args.symmetric:
-        max_val_pos = torch.max(-min_vals, max_vals)
+        max_val_pos = torch.max(torch.abs(min_vals), torch.abs(max_vals))
         scales = max_val_pos / (float(bit_range) / 2)
         scales = torch.clamp(scales, min=torch.finfo(torch.float32).eps)
-        zero_points = torch.zeros(scales.shape, device=device, dtype=torch.int8)
+        zero_points = torch.zeros(scales.shape, device=device, dtype=min_vals.dtype)
     else:
         scales = (max_vals - min_vals) / float(bit_range)
         scales = torch.clamp(scales, min=torch.finfo(torch.float32).eps)
-        zero_points = bit_min - torch.round(min_vals / scales)
-        zero_points = torch.clamp(zero_points, bit_min, bit_max).to(torch.int8)
+        zero_points = bit_min - (min_vals / scales)
+        zero_points = torch.clamp(zero_points, bit_min, bit_max)
+
+    # match zero-points to quantized type
+    zero_points = zero_points.to(zp_dtype)

     if scales.ndim == 0:
         scales = scales.reshape(1)
         zero_points = zero_points.reshape(1)

     return scales, zero_points
+
+
+def calculate_range(quantization_args: QuantizationArgs, device: str) -> Tuple:
+    """
+    Calculated the effective quantization range for the given Quantization Args
+
+    :param quantization_args: quantization args to get range of
+    :param device: device to store the range to
+    :return: tuple endpoints for the given quantization range
+    """
+    if quantization_args.type == QuantizationType.INT:
+        bit_range = 2**quantization_args.num_bits
+        q_max = torch.tensor(bit_range / 2 - 1, device=device)
+        q_min = torch.tensor(-bit_range / 2, device=device)
+    elif quantization_args.type == QuantizationType.FLOAT:
+        if quantization_args.num_bits != 8:
+            raise ValueError(
+                "Floating point quantization is only supported for 8 bits,"
+                f"got {quantization_args.num_bits}"
+            )
+        fp_range_info = torch.finfo(FP8_DTYPE)
+        q_max = torch.tensor(fp_range_info.max, device=device)
+        q_min = torch.tensor(fp_range_info.min, device=device)
+    else:
+        raise ValueError(f"Invalid quantization type {quantization_args.type}")
+
+    return q_min, q_max
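For reference, a short sketch of the endpoints calculate_range produces (import paths assumed from the file list above): an 8-bit integer config yields [-128, 127] from 2**8, and an 8-bit float config yields the float8_e4m3fn limits [-448.0, 448.0].

    from compressed_tensors.quantization.observers.helpers import calculate_range
    from compressed_tensors.quantization.quant_args import QuantizationArgs, QuantizationType

    int8_args = QuantizationArgs(num_bits=8, type=QuantizationType.INT)
    q_min, q_max = calculate_range(int8_args, device="cpu")          # tensors -128.0 and 127.0

    fp8_args = QuantizationArgs(num_bits=8, type=QuantizationType.FLOAT)
    q_min_fp8, q_max_fp8 = calculate_range(fp8_args, device="cpu")   # tensors -448.0 and 448.0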
compressed_tensors/quantization/quant_args.py
@@ -15,10 +15,19 @@
 from enum import Enum
 from typing import Any, Dict, Optional

+import torch
 from pydantic import BaseModel, Field, validator


-__all__ = ["QuantizationType", "QuantizationStrategy", "QuantizationArgs"]
+__all__ = [
+    "FP8_DTYPE",
+    "QuantizationType",
+    "QuantizationStrategy",
+    "QuantizationArgs",
+    "round_to_quantized_type",
+]
+
+FP8_DTYPE = torch.float8_e4m3fn


 class QuantizationType(str, Enum):
@@ -123,3 +132,38 @@ class QuantizationArgs(BaseModel, use_enum_values=True):
             return QuantizationStrategy.TENSOR

         return value
+
+    def pytorch_dtype(self) -> torch.dtype:
+        if self.type == QuantizationType.FLOAT:
+            return FP8_DTYPE
+        elif self.type == QuantizationType.INT:
+            if self.num_bits <= 8:
+                return torch.int8
+            elif self.num_bits <= 16:
+                return torch.int16
+            else:
+                return torch.int32
+        else:
+            raise ValueError(f"Invalid quantization type {self.type}")
+
+
+def round_to_quantized_type(
+    tensor: torch.Tensor, args: QuantizationArgs
+) -> torch.Tensor:
+    """
+    Rounds each element of the input tensor to the nearest quantized representation,
+    keeping to original dtype
+
+    :param tensor: tensor to round
+    :param args: QuantizationArgs to pull appropriate dtype from
+    :return: rounded tensor
+    """
+    original_dtype = tensor.dtype
+    if args.type == QuantizationType.FLOAT:
+        rounded = tensor.to(FP8_DTYPE)
+    elif args.type == QuantizationType.INT:
+        rounded = torch.round(tensor)
+    else:
+        raise ValueError(f"Invalid quantization type {args.type}")
+
+    return rounded.to(original_dtype)
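A minimal sketch of the two new helpers (keyword arguments mirror the presets later in this diff; treat the exact constructor usage as an assumption): pytorch_dtype() selects the storage dtype for quantized values, and round_to_quantized_type() snaps a tensor onto the quantized grid while keeping its original dtype.

    import torch
    from compressed_tensors.quantization.quant_args import (
        QuantizationArgs,
        QuantizationType,
        round_to_quantized_type,
    )

    int4_args = QuantizationArgs(num_bits=4, type=QuantizationType.INT)
    assert int4_args.pytorch_dtype() == torch.int8        # anything <= 8 bits is stored as int8

    fp8_args = QuantizationArgs(num_bits=8, type=QuantizationType.FLOAT)
    x = torch.tensor([0.1234, -1.5678])
    rounded = round_to_quantized_type(x, fp8_args)        # cast through FP8_DTYPE, returned as float32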
compressed_tensors/quantization/quant_config.py
@@ -16,6 +16,7 @@ from enum import Enum
 from typing import Dict, List, Optional, Union

 from compressed_tensors.config import CompressionFormat
+from compressed_tensors.quantization.quant_args import QuantizationArgs
 from compressed_tensors.quantization.quant_scheme import (
     QuantizationScheme,
     preset_name_to_scheme,
@@ -25,6 +26,7 @@ from compressed_tensors.quantization.utils import (
     is_module_quantized,
     iter_named_leaf_modules,
     module_type,
+    parse_out_kv_cache_args,
 )
 from pydantic import BaseModel, Field
 from torch.nn import Module
@@ -117,7 +119,18 @@ class QuantizationConfig(BaseModel):
         other quantization configs
     :param format: specifies how the quantized model is stored on disk
     :quantization_status: specifies the current status of all quantized layers. It is
-        assumed all layers are in the same state.
+        assumed all layers are in the same state.
+    :param kv_cache_scheme: optional QuantizationArgs, that specify the
+        quantization of the kv cache. If None, kv cache is not quantized.
+        When applying kv cache quantization to transformer AutoModelForCausalLM,
+        the kv_cache_scheme gets converted into a QuantizationScheme that:
+            - targets the `q_proj` and `k_proj` modules of the model. The outputs
+              of those modules are the keys and values that might be cached
+            - quantizes the outputs of the aformentioned layers, so that
+              keys and values are compressed before storing them in the cache
+        There is an explicit assumption that the model contains modules with
+        `k_proj` and `v_proj` in their names. If this is not the case
+        and kv_cache_scheme != None, the quantization of kv cache will fail
     :global_compression_ratio: optional informational config to report the model
         compression ratio acheived by the quantization config
     :ignore: optional list of layers to ignore from config_groups. Layers in this list
@@ -126,6 +139,7 @@ class QuantizationConfig(BaseModel):

     config_groups: Dict[str, Union[QuantizationScheme, List[str]]]
     quant_method: str = DEFAULT_QUANTIZATION_METHOD
+    kv_cache_scheme: Optional[QuantizationArgs] = None
     format: str = DEFAULT_QUANTIZATION_FORMAT
     quantization_status: QuantizationStatus = QuantizationStatus.INITIALIZED
     global_compression_ratio: Optional[float] = None
@@ -154,7 +168,7 @@ class QuantizationConfig(BaseModel):
     ) -> Optional["QuantizationConfig"]:
         """
         Converts a model into its associated QuantizationConfig based on the
-        QuantizationScheme attached to each quanitzed module
+        QuantizationScheme attached to each quantized module

         :param model: model to calculate quantization scheme of
         :return: filled out QuantizationScheme for the input model
@@ -195,6 +209,13 @@ class QuantizationConfig(BaseModel):
             # else we leave it off the ignore list, doesn't fall under any of the
            # existing quantization schemes so it won't be quantized

+        kv_cache_args, quant_scheme_to_layers = parse_out_kv_cache_args(
+            quant_scheme_to_layers
+        )
+        kv_cache_scheme = (
+            kv_cache_args.model_dump() if kv_cache_args is not None else kv_cache_args
+        )
+
         config_groups = {}
         for idx, scheme in enumerate(quant_scheme_to_layers):
             group_name = "group_" + str(idx)
@@ -213,7 +234,19 @@ class QuantizationConfig(BaseModel):
         return QuantizationConfig(
             config_groups=config_groups,
             quantization_status=quantization_status,
+            kv_cache_scheme=kv_cache_scheme,
             global_compression_ratio=compression_ratio,
             format=format,
             ignore=consolidated_ignore,
         )
+
+    def requires_calibration_data(self):
+        for _, scheme in self.config_groups.items():
+            if scheme.input_activations is not None:
+                if not scheme.input_activations.dynamic:
+                    return True
+            if scheme.output_activations is not None:
+                if not scheme.output_activations.dynamic:
+                    return True
+
+        return False
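A hedged sketch of the new fields in use (the compressed_tensors.quantization import path and the QuantizationScheme constructor arguments are assumptions based on the modules shown above): a weight-only config reports that no calibration data is needed, and kv_cache_scheme is carried as plain QuantizationArgs.

    from compressed_tensors.quantization import (
        QuantizationArgs,
        QuantizationConfig,
        QuantizationScheme,
        QuantizationType,
    )

    scheme = QuantizationScheme(
        targets=["Linear"],
        weights=QuantizationArgs(num_bits=8, symmetric=True, dynamic=False),
    )
    config = QuantizationConfig(
        config_groups={"group_0": scheme},
        kv_cache_scheme=QuantizationArgs(num_bits=8, type=QuantizationType.FLOAT),
    )

    # only static (non-dynamic) activation schemes require calibration data
    assert config.requires_calibration_data() is False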
compressed_tensors/quantization/quant_scheme.py
@@ -15,7 +15,11 @@
 from copy import deepcopy
 from typing import List, Optional

-from compressed_tensors.quantization.quant_args import QuantizationArgs
+from compressed_tensors.quantization.quant_args import (
+    QuantizationArgs,
+    QuantizationStrategy,
+    QuantizationType,
+)
 from pydantic import BaseModel


@@ -107,13 +111,110 @@ def is_preset_scheme(name: str) -> bool:
     return name.upper() in PRESET_SCHEMES


+# 8 bit integer weights and 8 bit activations quantization
 W8A8 = dict(
-    weights=QuantizationArgs(), input_activations=QuantizationArgs(symmetric=True)
+    weights=QuantizationArgs(
+        num_bits=8,
+        type=QuantizationType.INT,
+        strategy=QuantizationStrategy.CHANNEL,
+        symmetric=True,
+        dynamic=False,
+    ),
+    input_activations=QuantizationArgs(
+        num_bits=8,
+        type=QuantizationType.INT,
+        strategy=QuantizationStrategy.TOKEN,
+        symmetric=True,
+        dynamic=True,
+    ),
+)
+
+# 8 bit integer weights only quantization
+W8A16 = dict(
+    weights=QuantizationArgs(
+        num_bits=8,
+        type=QuantizationType.INT,
+        strategy=QuantizationStrategy.CHANNEL,
+        symmetric=True,
+        dynamic=False,
+    ),
+)
+
+# 4 bit integer weights only quantization
+W4A16 = dict(
+    weights=QuantizationArgs(
+        num_bits=4,
+        type=QuantizationType.INT,
+        strategy=QuantizationStrategy.GROUP,
+        group_size=128,
+        symmetric=True,
+        dynamic=False,
+    ),
+)
+
+# 4 bit integer weights and 8 bit activations quantization
+W4A8 = dict(
+    weights=QuantizationArgs(
+        num_bits=4,
+        type=QuantizationType.INT,
+        group_size=128,
+        strategy=QuantizationStrategy.GROUP,
+        symmetric=True,
+        dynamic=False,
+    ),
+    input_activations=QuantizationArgs(
+        num_bits=8,
+        type=QuantizationType.INT,
+        strategy=QuantizationStrategy.TOKEN,
+        symmetric=True,
+        dynamic=True,
+    ),
+)
+
+# FP8 weights and FP8 activations quantization
+FP8 = dict(
+    weights=QuantizationArgs(
+        num_bits=8,
+        type=QuantizationType.FLOAT,
+        strategy=QuantizationStrategy.TENSOR,
+        symmetric=True,
+        dynamic=False,
+    ),
+    input_activations=QuantizationArgs(
+        num_bits=8,
+        type=QuantizationType.FLOAT,
+        strategy=QuantizationStrategy.TENSOR,
+        symmetric=True,
+        dynamic=False,
+    ),
 )

-W4A16 = dict(weights=QuantizationArgs(num_bits=4, group_size=128))
+# FP8 weights and FP8 dynamic activations quantization
+FP8_DYNAMIC = dict(
+    weights=QuantizationArgs(
+        num_bits=8,
+        type=QuantizationType.FLOAT,
+        strategy=QuantizationStrategy.CHANNEL,
+        symmetric=True,
+        dynamic=False,
+    ),
+    input_activations=QuantizationArgs(
+        num_bits=8,
+        type=QuantizationType.FLOAT,
+        strategy=QuantizationStrategy.TOKEN,
+        symmetric=True,
+        dynamic=True,
+    ),
+)

 PRESET_SCHEMES = {
-    "W8A8": W8A8,
+    # Integer weight only schemes
+    "W8A16": W8A16,
     "W4A16": W4A16,
+    # Integer weight and activation schemes
+    "W8A8": W8A8,
+    "W4A8": W4A8,
+    # Float weight and activation schemes
+    "FP8": FP8,
+    "FP8_DYNAMIC": FP8_DYNAMIC,
 }
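A small usage sketch for the expanded preset list, assuming preset_name_to_scheme keeps its (name, targets) signature: the preset name is expanded into a full QuantizationScheme for the requested target modules.

    from compressed_tensors.quantization import preset_name_to_scheme

    # "FP8" -> static fp8 weights and activations; "FP8_DYNAMIC" -> dynamic per-token activations
    fp8_scheme = preset_name_to_scheme("FP8", targets=["Linear"])
    print(fp8_scheme.weights.num_bits)            # 8
    print(fp8_scheme.input_activations.dynamic)   # False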
compressed_tensors/quantization/utils/helpers.py
@@ -13,10 +13,13 @@
 # limitations under the License.

 import logging
-from typing import Optional, Tuple
+import re
+from typing import List, Optional, Tuple

 import torch
 from compressed_tensors.quantization.observers.base import Observer
+from compressed_tensors.quantization.quant_args import QuantizationArgs
+from compressed_tensors.quantization.quant_scheme import QuantizationScheme
 from torch.nn import Module
 from tqdm import tqdm

@@ -30,8 +33,12 @@ __all__ = [
     "calculate_compression_ratio",
     "get_torch_bit_depth",
     "can_quantize",
+    "parse_out_kv_cache_args",
+    "KV_CACHE_TARGETS",
+    "is_kv_cache_quant_scheme",
 ]

+KV_CACHE_TARGETS = ["re:.*k_proj", "re:.*v_proj"]
 _LOGGER: logging.Logger = logging.getLogger(__name__)


@@ -182,3 +189,62 @@ def calculate_compression_ratio(model: Module) -> float:
         total_uncompressed += uncompressed_bits * num_weights

     return total_uncompressed / total_compressed
+
+
+def is_kv_cache_quant_scheme(scheme: QuantizationScheme) -> bool:
+    """
+    Check whether the QuantizationScheme targets the kv cache.
+    It does if all the following criteria are met:
+    - the scheme targets either exactly match the KV_CACHE_TARGETS
+      or the match KV_CACHE_TARGETS regex pattern
+    - the scheme quantizes output_activations (we want to quantize the
+      outputs from the KV_CACHE_TARGETS, as their correspond to the
+      keys and values that are to be saved in the cache)
+
+    :param scheme: The QuantizationScheme to investigate
+    :return: boolean flag
+    """
+    if len(scheme.targets) == 1:
+        # match on the KV_CACHE_TARGETS regex pattern
+        # if there is only one target
+        is_match_targets = any(
+            [re.match(pattern[3:], scheme.targets[0]) for pattern in KV_CACHE_TARGETS]
+        )
+    else:
+        # match on the exact KV_CACHE_TARGETS
+        # if there are multiple targets
+        is_match_targets = set(KV_CACHE_TARGETS) == set(scheme.targets)
+
+    is_match_output_activations = scheme.output_activations is not None
+    return is_match_targets and is_match_output_activations
+
+
+def parse_out_kv_cache_args(
+    quant_scheme_to_layers: List[QuantizationScheme],
+) -> Tuple[Optional[QuantizationArgs], List[QuantizationScheme]]:
+    """
+    If possible, parse out the kv cache specific QuantizationArgs
+    from the list of the QuantizationSchemes. If no kv cache
+    specific QuantizationArgs available, this function acts
+    as an identity function
+
+    :param quant_scheme_to_layers: list of QuantizationSchemes
+    :return: kv_cache_args (optional) and the (remaining or original)
+        list of the QuantizationSchemes
+    """
+    kv_cache_quant_scheme_to_layers = [
+        scheme for scheme in quant_scheme_to_layers if is_kv_cache_quant_scheme(scheme)
+    ]
+    quant_scheme_to_layers = [
+        scheme
+        for scheme in quant_scheme_to_layers
+        if not is_kv_cache_quant_scheme(scheme)
+    ]
+
+    if kv_cache_quant_scheme_to_layers:
+        kv_cache_quant_scheme_to_layers = kv_cache_quant_scheme_to_layers[0]
+        kv_cache_args = kv_cache_quant_scheme_to_layers.output_activations
+    else:
+        kv_cache_args = None
+
+    return kv_cache_args, quant_scheme_to_layers
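A minimal sketch of the kv cache detection path, assuming these helpers are re-exported from compressed_tensors.quantization.utils as the updated __all__ suggests: a scheme that quantizes the outputs of k_proj/v_proj is recognized and parsed out of the scheme list.

    from compressed_tensors.quantization import QuantizationArgs, QuantizationScheme
    from compressed_tensors.quantization.utils import (
        is_kv_cache_quant_scheme,
        parse_out_kv_cache_args,
    )

    # outputs of k_proj/v_proj are the keys and values headed for the cache
    kv_scheme = QuantizationScheme(
        targets=["re:.*k_proj", "re:.*v_proj"],
        output_activations=QuantizationArgs(num_bits=8),
    )
    assert is_kv_cache_quant_scheme(kv_scheme)

    kv_cache_args, remaining_schemes = parse_out_kv_cache_args([kv_scheme])
    assert kv_cache_args is not None and remaining_schemes == []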
compressed_tensors/utils/__init__.py
@@ -13,4 +13,8 @@
 # limitations under the License.
 # flake8: noqa

+from .helpers import *
+from .offload import *
+from .permutations_24 import *
 from .safetensors_load import *
+from .semi_structured_conversions import *
compressed_tensors/utils/helpers.py
@@ -12,13 +12,17 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.

-
 from typing import Optional

+import torch
 from transformers import AutoConfig


-__all__ = ["infer_compressor_from_model_config", "fix_fsdp_module_name"]
+__all__ = [
+    "infer_compressor_from_model_config",
+    "fix_fsdp_module_name",
+    "tensor_follows_mask_structure",
+]

 FSDP_WRAPPER_NAME = "_fsdp_wrapped_module"

@@ -61,3 +65,28 @@ def fix_fsdp_module_name(name: str) -> str:
     return name.replace(FSDP_WRAPPER_NAME + ".", "").replace(
         "." + FSDP_WRAPPER_NAME, ""
     )
+
+
+def tensor_follows_mask_structure(tensor, mask: str = "2:4") -> bool:
+    """
+    :param tensor: tensor to check
+    :param mask: mask structure to check for, in the format "n:m"
+    :return: True if the tensor follows the mask structure, False otherwise.
+        Note, some weights can incidentally be zero, so we check for
+        atleast n zeros in each chunk of size m
+    """
+
+    n, m = tuple(map(int, mask.split(":")))
+    # Reshape the tensor into chunks of size m
+    tensor = tensor.view(-1, m)
+
+    # Count the number of zeros in each chunk
+    zero_counts = (tensor == 0).sum(dim=1)
+
+    # Check if the number of zeros in each chunk atleast n
+    # Greater than sign is needed as some weights can incidentally
+    # be zero
+    if not torch.all(zero_counts >= n).item():
+        raise ValueError()
+
+    return True
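A quick sketch of the 2:4 check; note that, as written, a tensor violating the structure raises ValueError rather than returning False as the docstring states.

    import torch
    from compressed_tensors.utils.helpers import tensor_follows_mask_structure

    # every chunk of 4 values holds at least 2 zeros -> satisfies the default "2:4" mask
    sparse_24 = torch.tensor([[0.0, 1.5, 0.0, -2.0],
                              [3.1, 0.0, 0.0, 0.7]])
    assert tensor_follows_mask_structure(sparse_24)

    dense = torch.randn(4, 4)
    # tensor_follows_mask_structure(dense)  # raises ValueError for a dense weight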