compressed-tensors 0.6.0__py3-none-any.whl → 0.7.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (41)
  1. compressed_tensors/__init__.py +1 -0
  2. compressed_tensors/base.py +2 -0
  3. compressed_tensors/compressors/__init__.py +6 -12
  4. compressed_tensors/compressors/base.py +38 -102
  5. compressed_tensors/compressors/helpers.py +6 -6
  6. compressed_tensors/compressors/model_compressors/__init__.py +17 -0
  7. compressed_tensors/compressors/{model_compressor.py → model_compressors/model_compressor.py} +95 -106
  8. compressed_tensors/compressors/quantized_compressors/__init__.py +18 -0
  9. compressed_tensors/compressors/quantized_compressors/base.py +146 -0
  10. compressed_tensors/compressors/{naive_quantized.py → quantized_compressors/naive_quantized.py} +11 -11
  11. compressed_tensors/compressors/{pack_quantized.py → quantized_compressors/pack_quantized.py} +6 -3
  12. compressed_tensors/compressors/sparse_compressors/__init__.py +18 -0
  13. compressed_tensors/compressors/sparse_compressors/base.py +110 -0
  14. compressed_tensors/compressors/{dense.py → sparse_compressors/dense.py} +3 -3
  15. compressed_tensors/compressors/{sparse_bitmask.py → sparse_compressors/sparse_bitmask.py} +14 -59
  16. compressed_tensors/compressors/sparse_quantized_compressors/__init__.py +16 -0
  17. compressed_tensors/compressors/{marlin_24.py → sparse_quantized_compressors/marlin_24.py} +3 -3
  18. compressed_tensors/linear/compressed_linear.py +2 -2
  19. compressed_tensors/quantization/__init__.py +1 -0
  20. compressed_tensors/quantization/cache.py +201 -0
  21. compressed_tensors/quantization/lifecycle/apply.py +19 -3
  22. compressed_tensors/quantization/lifecycle/calibration.py +2 -3
  23. compressed_tensors/quantization/lifecycle/forward.py +58 -7
  24. compressed_tensors/quantization/lifecycle/frozen.py +6 -1
  25. compressed_tensors/quantization/lifecycle/helpers.py +0 -47
  26. compressed_tensors/quantization/lifecycle/initialize.py +116 -67
  27. compressed_tensors/quantization/observers/__init__.py +0 -1
  28. compressed_tensors/quantization/observers/helpers.py +40 -2
  29. compressed_tensors/quantization/quant_args.py +34 -4
  30. compressed_tensors/quantization/quant_config.py +14 -2
  31. compressed_tensors/quantization/quant_scheme.py +8 -4
  32. compressed_tensors/quantization/utils/helpers.py +43 -18
  33. compressed_tensors/utils/helpers.py +17 -1
  34. compressed_tensors/version.py +1 -1
  35. {compressed_tensors-0.6.0.dist-info → compressed_tensors-0.7.1.dist-info}/METADATA +1 -1
  36. compressed_tensors-0.7.1.dist-info/RECORD +58 -0
  37. compressed_tensors/quantization/observers/memoryless.py +0 -56
  38. compressed_tensors-0.6.0.dist-info/RECORD +0 -52
  39. {compressed_tensors-0.6.0.dist-info → compressed_tensors-0.7.1.dist-info}/LICENSE +0 -0
  40. {compressed_tensors-0.6.0.dist-info → compressed_tensors-0.7.1.dist-info}/WHEEL +0 -0
  41. {compressed_tensors-0.6.0.dist-info → compressed_tensors-0.7.1.dist-info}/top_level.txt +0 -0
compressed_tensors/quantization/lifecycle/initialize.py

```diff
@@ -17,8 +17,10 @@ import logging
 from typing import Optional
 
 import torch
+from compressed_tensors.quantization.cache import KVCacheScaleType
 from compressed_tensors.quantization.lifecycle.forward import (
     wrap_module_forward_quantized,
+    wrap_module_forward_quantized_attn,
 )
 from compressed_tensors.quantization.quant_args import (
     ActivationOrdering,
@@ -27,6 +29,7 @@ from compressed_tensors.quantization.quant_args import (
 )
 from compressed_tensors.quantization.quant_config import QuantizationStatus
 from compressed_tensors.quantization.quant_scheme import QuantizationScheme
+from compressed_tensors.quantization.utils import is_kv_cache_quant_scheme
 from compressed_tensors.utils import get_execution_device, is_module_offloaded
 from torch.nn import Module, Parameter
 
@@ -62,72 +65,85 @@ def initialize_module_for_quantization(
         # no scheme passed and layer not targeted for quantization - skip
         return
 
-    if scheme.input_activations is not None:
-        _initialize_scale_zero_point_observer(
-            module, "input", scheme.input_activations, force_zero_point=force_zero_point
-        )
-    if scheme.weights is not None:
-        if hasattr(module, "weight"):
-            weight_shape = module.weight.shape
+    if is_attention_module(module):
+        # wrap forward call of module to perform
+        # quantized actions based on calltime status
+        wrap_module_forward_quantized_attn(module, scheme)
+        _initialize_attn_scales(module)
+
+    else:
+
+        if scheme.input_activations is not None:
             _initialize_scale_zero_point_observer(
                 module,
-                "weight",
-                scheme.weights,
-                weight_shape=weight_shape,
+                "input",
+                scheme.input_activations,
                 force_zero_point=force_zero_point,
             )
-        else:
-            _LOGGER.warning(
-                f"module type {type(module)} targeted for weight quantization but "
-                "has no attribute weight, skipping weight quantization "
-                f"for {type(module)}"
-            )
-    if scheme.output_activations is not None:
-        _initialize_scale_zero_point_observer(
-            module,
-            "output",
-            scheme.output_activations,
-            force_zero_point=force_zero_point,
-        )
-
-    module.quantization_scheme = scheme
-    module.quantization_status = QuantizationStatus.INITIALIZED
-
-    offloaded = False
-    if is_module_offloaded(module):
-        try:
-            from accelerate.hooks import add_hook_to_module, remove_hook_from_module
-            from accelerate.utils import PrefixedDataset
-        except ModuleNotFoundError:
-            raise ModuleNotFoundError(
-                "Offloaded model detected. To use CPU offloading with "
-                "compressed-tensors the `accelerate` package must be installed, "
-                "run `pip install compressed-tensors[accelerate]`"
-            )
-
-        offloaded = True
-        hook = module._hf_hook
-        prefix_dict = module._hf_hook.weights_map
-        new_prefix = {}
-
-        # recreate the prefix dict (since it is immutable)
-        # and add quantization parameters
-        for key, data in module.named_parameters():
-            if key not in prefix_dict:
-                new_prefix[f"{prefix_dict.prefix}{key}"] = data
+        if scheme.weights is not None:
+            if hasattr(module, "weight"):
+                weight_shape = None
+                if isinstance(module, torch.nn.Linear):
+                    weight_shape = module.weight.shape
+                _initialize_scale_zero_point_observer(
+                    module,
+                    "weight",
+                    scheme.weights,
+                    weight_shape=weight_shape,
+                    force_zero_point=force_zero_point,
+                )
             else:
-                new_prefix[f"{prefix_dict.prefix}{key}"] = prefix_dict[key]
-        new_prefix_dict = PrefixedDataset(new_prefix, prefix_dict.prefix)
-        remove_hook_from_module(module)
-
-    # wrap forward call of module to perform quantized actions based on calltime status
-    wrap_module_forward_quantized(module, scheme)
-
-    if offloaded:
-        # we need to re-add the hook for offloading now that we've wrapped forward
-        add_hook_to_module(module, hook)
-        if prefix_dict is not None:
-            module._hf_hook.weights_map = new_prefix_dict
+                _LOGGER.warning(
+                    f"module type {type(module)} targeted for weight quantization but "
+                    "has no attribute weight, skipping weight quantization "
+                    f"for {type(module)}"
+                )
+
+        if scheme.output_activations is not None:
+            if not is_kv_cache_quant_scheme(scheme):
+                _initialize_scale_zero_point_observer(
+                    module, "output", scheme.output_activations
+                )
+
+        module.quantization_scheme = scheme
+        module.quantization_status = QuantizationStatus.INITIALIZED
+
+        offloaded = False
+        if is_module_offloaded(module):
+            try:
+                from accelerate.hooks import add_hook_to_module, remove_hook_from_module
+                from accelerate.utils import PrefixedDataset
+            except ModuleNotFoundError:
+                raise ModuleNotFoundError(
+                    "Offloaded model detected. To use CPU offloading with "
+                    "compressed-tensors the `accelerate` package must be installed, "
+                    "run `pip install compressed-tensors[accelerate]`"
+                )
+
+            offloaded = True
+            hook = module._hf_hook
+            prefix_dict = module._hf_hook.weights_map
+            new_prefix = {}
+
+            # recreate the prefix dict (since it is immutable)
+            # and add quantization parameters
+            for key, data in module.named_parameters():
+                if key not in prefix_dict:
+                    new_prefix[f"{prefix_dict.prefix}{key}"] = data
+                else:
+                    new_prefix[f"{prefix_dict.prefix}{key}"] = prefix_dict[key]
+            new_prefix_dict = PrefixedDataset(new_prefix, prefix_dict.prefix)
+            remove_hook_from_module(module)
+
+        # wrap forward call of module to perform
+        # quantized actions based on calltime status
+        wrap_module_forward_quantized(module, scheme)
+
+        if offloaded:
+            # we need to re-add the hook for offloading now that we've wrapped forward
+            add_hook_to_module(module, hook)
+            if prefix_dict is not None:
+                module._hf_hook.weights_map = new_prefix_dict
 
 
 def _initialize_scale_zero_point_observer(
@@ -137,12 +153,16 @@ def _initialize_scale_zero_point_observer(
     weight_shape: Optional[torch.Size] = None,
     force_zero_point: bool = True,
 ):
+
     # initialize observer module and attach as submodule
     observer = quantization_args.get_observer()
-    module.register_module(f"{base_name}_observer", observer)
+    # no need to register an observer for dynamic quantization
+    if observer:
+        module.register_module(f"{base_name}_observer", observer)
 
+    # no need to register a scale and zero point for a dynamic quantization
     if quantization_args.dynamic:
-        return  # no need to register a scale and zero point for a dynamic observer
+        return
 
     device = next(module.parameters()).device
     if is_module_offloaded(module):
@@ -156,10 +176,8 @@
             # (output_channels, 1)
             expected_shape = (weight_shape[0], 1)
         elif quantization_args.strategy == QuantizationStrategy.GROUP:
-            expected_shape = (
-                weight_shape[0],
-                weight_shape[1] // quantization_args.group_size,
-            )
+            num_groups = weight_shape[1] // quantization_args.group_size
+            expected_shape = (weight_shape[0], max(num_groups, 1))
 
     scale_dtype = module.weight.dtype
     if scale_dtype not in [torch.float16, torch.bfloat16, torch.float32]:
@@ -189,3 +207,34 @@
             requires_grad=False,
         )
         module.register_parameter(f"{base_name}_g_idx", init_g_idx)
+
+
+def is_attention_module(module: Module):
+    return "attention" in module.__class__.__name__.lower() and (
+        hasattr(module, "k_proj")
+        or hasattr(module, "v_proj")
+        or hasattr(module, "qkv_proj")
+    )
+
+
+def _initialize_attn_scales(module: Module) -> None:
+    """Initlaize k_scale, v_scale for self_attn"""
+
+    expected_shape = 1  # per tensor
+
+    param = next(module.parameters())
+    scale_dtype = param.dtype
+    device = param.device
+
+    init_scale = Parameter(
+        torch.empty(expected_shape, dtype=scale_dtype, device=device),
+        requires_grad=False,
+    )
+
+    module.register_parameter(KVCacheScaleType.KEY.value, init_scale)
+
+    init_scale = Parameter(
+        torch.empty(expected_shape, dtype=scale_dtype, device=device),
+        requires_grad=False,
+    )
+    module.register_parameter(KVCacheScaleType.VALUE.value, init_scale)
```
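The initialize.py changes above add a dedicated path for attention modules: instead of registering observers, scales, and zero points, the forward call is wrapped with wrap_module_forward_quantized_attn and a single per-tensor k_scale / v_scale pair is registered. A minimal sketch of that outcome, not part of the diff; the toy module below is purely illustrative, and "k_scale" / "v_scale" are assumed to be the string values of KVCacheScaleType.KEY / KVCacheScaleType.VALUE:

```python
import torch
from torch.nn import Linear, Module, Parameter


class ToyAttention(Module):
    """Toy stand-in that satisfies is_attention_module(): the class name
    contains "attention" and the module exposes k_proj / v_proj."""

    def __init__(self, hidden: int = 8):
        super().__init__()
        self.k_proj = Linear(hidden, hidden)
        self.v_proj = Linear(hidden, hidden)


def init_attn_scales(module: Module) -> None:
    """Mirrors _initialize_attn_scales: one per-tensor, non-trainable scale
    each for keys and values, registered directly on the attention module."""
    param = next(module.parameters())
    for name in ("k_scale", "v_scale"):  # assumed KVCacheScaleType values
        module.register_parameter(
            name,
            Parameter(
                torch.empty(1, dtype=param.dtype, device=param.device),
                requires_grad=False,
            ),
        )


attn = ToyAttention()
init_attn_scales(attn)
print(sorted(n for n, _ in attn.named_parameters() if n.endswith("_scale")))
# ['k_scale', 'v_scale']
```

Per the comment added to KV_CACHE_TARGETS further down, the QuantizedKVParameterCache introduced in quantization/cache.py is what actually fills these scales during calibration.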
compressed_tensors/quantization/observers/__init__.py

```diff
@@ -17,6 +17,5 @@
 
 from .helpers import *
 from .base import *
-from .memoryless import *
 from .min_max import *
 from .mse import *
```
compressed_tensors/quantization/observers/helpers.py

```diff
@@ -13,18 +13,56 @@
 # limitations under the License.
 
 from collections import Counter
-from typing import Tuple
+from typing import Optional, Tuple
 
 import torch
 from compressed_tensors.quantization.quant_args import (
     FP8_DTYPE,
     QuantizationArgs,
+    QuantizationStrategy,
     QuantizationType,
 )
 from torch import FloatTensor, IntTensor, Tensor
 
 
-__all__ = ["calculate_qparams", "get_observer_token_count", "calculate_range"]
+__all__ = [
+    "calculate_qparams",
+    "get_observer_token_count",
+    "calculate_range",
+    "compute_dynamic_scales_and_zp",
+]
+
+
+def compute_dynamic_scales_and_zp(value: Tensor, args: QuantizationArgs):
+    """
+    Returns the computed scales and zero points for dynamic activation
+    qunatization.
+
+    :param value: tensor to calculate quantization parameters for
+    :param args: quantization args
+    :param reduce_dims: optional tuple of dimensions to reduce along,
+        returned scale and zero point will be shaped (1,) along the
+        reduced dimensions
+    :return: tuple of scale and zero point derived from the observed tensor
+    """
+    if args.strategy == QuantizationStrategy.TOKEN:
+        dim = {1, 2}
+        reduce_dims = tuple(idx for idx in range(value.ndim) if idx not in dim)
+    elif args.strategy == QuantizationStrategy.TENSOR:
+        reduce_dims = None
+    else:
+        raise ValueError(
+            f"One of {QuantizationStrategy.TOKEN} or {QuantizationStrategy.TENSOR} ",
+            "must be used for dynamic quantization",
+        )
+
+    if not reduce_dims:
+        min_val, max_val = torch.aminmax(value)
+    else:
+        min_val = torch.amin(value, dim=reduce_dims, keepdims=True)
+        max_val = torch.amax(value, dim=reduce_dims, keepdims=True)
+
+    return calculate_qparams(min_val, max_val, args)
 
 
 def get_observer_token_count(module: torch.nn.Module) -> Counter:
```
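The new compute_dynamic_scales_and_zp helper only decides which dimensions to reduce over before handing min/max off to calculate_qparams. A small sketch of that reduction, with purely illustrative shapes:

```python
import torch

value = torch.randn(2, 16, 64)  # e.g. (batch, sequence, hidden)

# TENSOR strategy: a single min/max over the whole tensor -> scalar qparams
min_val, max_val = torch.aminmax(value)
print(min_val.shape)  # torch.Size([])

# TOKEN strategy: dims {1, 2} are kept, everything else is reduced
keep = {1, 2}
reduce_dims = tuple(i for i in range(value.ndim) if i not in keep)
min_val = torch.amin(value, dim=reduce_dims, keepdim=True)
max_val = torch.amax(value, dim=reduce_dims, keepdim=True)
print(min_val.shape)  # torch.Size([1, 16, 64])
```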
compressed_tensors/quantization/quant_args.py

```diff
@@ -12,6 +12,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
+import warnings
 from enum import Enum
 from typing import Any, Dict, Optional, Union
 
@@ -94,7 +95,7 @@ class QuantizationArgs(BaseModel, use_enum_values=True):
     block_structure: Optional[str] = None
     dynamic: bool = False
     actorder: Union[ActivationOrdering, bool, None] = None
-    observer: str = Field(
+    observer: Optional[str] = Field(
         default="minmax",
         description=(
             "The class to use to compute the quantization param - "
@@ -115,13 +116,19 @@ class QuantizationArgs(BaseModel, use_enum_values=True):
         """
         from compressed_tensors.quantization.observers.base import Observer
 
+        # No observer required for the dynamic case
        if self.dynamic:
-            # override defualt observer for dynamic, you never want minmax which
-            # keeps state across samples for dynamic
-            self.observer = "memoryless"
+            self.observer = None
+            return self.observer
 
         return Observer.load_from_registry(self.observer, quantization_args=self)
 
+    def get_kv_cache(self):
+        """Get the singleton KV Cache"""
+        from compressed_tensors.quantization.cache import QuantizedKVParameterCache
+
+        return QuantizedKVParameterCache(self)
+
     @field_validator("type", mode="before")
     def validate_type(cls, value) -> QuantizationType:
         if isinstance(value, str):
@@ -165,6 +172,8 @@ class QuantizationArgs(BaseModel, use_enum_values=True):
         strategy = model.strategy
         group_size = model.group_size
         actorder = model.actorder
+        dynamic = model.dynamic
+        observer = model.observer
 
         # infer strategy
         if strategy is None:
@@ -201,6 +210,27 @@ class QuantizationArgs(BaseModel, use_enum_values=True):
                 "activation ordering"
             )
 
+        if dynamic:
+            if strategy not in (
+                QuantizationStrategy.TOKEN,
+                QuantizationStrategy.TENSOR,
+            ):
+                raise ValueError(
+                    f"One of {QuantizationStrategy.TOKEN} or "
+                    f"{QuantizationStrategy.TENSOR} must be used for dynamic ",
+                    "quantization",
+                )
+            if observer is not None:
+                warnings.warn(
+                    "No observer is used for dynamic quantization, setting to None"
+                )
+                model.observer = None
+
+        # if we have not set an observer and we
+        # are running static quantization, use minmax
+        if not observer and not dynamic:
+            model.observer = "minmax"
+
         # write back modified values
         model.strategy = strategy
         return model
```
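The quant_args.py changes make observer optional and move the dynamic-quantization rules into the model validator. A usage sketch of the resulting behavior, with illustrative argument values:

```python
from compressed_tensors.quantization import QuantizationArgs

# dynamic + token strategy is allowed; the observer is cleared to None
# (a warning fires because the field still defaults to "minmax")
dyn = QuantizationArgs(num_bits=8, strategy="token", dynamic=True)
assert dyn.observer is None

# dynamic quantization with a group strategy is rejected by the validator
try:
    QuantizationArgs(num_bits=4, strategy="group", group_size=128, dynamic=True)
except ValueError as err:
    print("rejected:", type(err).__name__)

# static args keep (or fall back to) the "minmax" observer
static = QuantizationArgs(num_bits=8, strategy="tensor")
assert static.observer == "minmax"
```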
compressed_tensors/quantization/quant_config.py

```diff
@@ -24,7 +24,7 @@ from compressed_tensors.quantization.quant_scheme import (
 from compressed_tensors.quantization.utils import (
     calculate_compression_ratio,
     is_module_quantized,
-    iter_named_leaf_modules,
+    iter_named_quantizable_modules,
     module_type,
     parse_out_kv_cache_args,
 )
@@ -177,7 +177,9 @@ class QuantizationConfig(BaseModel):
         quantization_status = None
         ignore = {}
         quantization_type_names = set()
-        for name, submodule in iter_named_leaf_modules(model):
+        for name, submodule in iter_named_quantizable_modules(
+            model, include_children=True, include_attn=True
+        ):
             layer_type = module_type(submodule)
             if not is_module_quantized(submodule):
                 if layer_type not in ignore:
@@ -199,6 +201,13 @@ class QuantizationConfig(BaseModel):
         if len(quant_scheme_to_layers) == 0:  # No quantized layers
             return None
 
+        # kv-cache only, no weight/activation quantization
+        if (
+            len(quantization_type_names) == 1
+            and "attention" in list(quantization_type_names)[0].lower()
+        ):
+            quantization_type_names.add("Linear")
+
         # clean up ignore list, we can leave out layers types if none of the
         # instances are quantized
         consolidated_ignore = []
@@ -241,6 +250,9 @@ class QuantizationConfig(BaseModel):
         )
 
     def requires_calibration_data(self):
+        if self.kv_cache_scheme is not None:
+            return True
+
         for _, scheme in self.config_groups.items():
             if scheme.input_activations is not None:
                 if not scheme.input_activations.dynamic:
```
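With the quant_config.py changes, a config that quantizes only the kv cache is treated as requiring calibration data, since k_scale / v_scale must be collected from real forward passes. A short sketch; the kv_cache_scheme arguments below are illustrative:

```python
from compressed_tensors.quantization import QuantizationArgs, QuantizationConfig

config = QuantizationConfig(
    config_groups={},
    kv_cache_scheme=QuantizationArgs(num_bits=8, type="float", strategy="tensor"),
)
assert config.requires_calibration_data()
```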
compressed_tensors/quantization/quant_scheme.py

```diff
@@ -108,7 +108,7 @@ def is_preset_scheme(name: str) -> bool:
 UNQUANTIZED = dict()
 
 # 8 bit integer weights and 8 bit activations quantization
-W8A8 = dict(
+INT8_W8A8 = dict(
     weights=QuantizationArgs(
         num_bits=8,
         type=QuantizationType.INT,
@@ -122,6 +122,7 @@ W8A8 = dict(
         strategy=QuantizationStrategy.TOKEN,
         symmetric=True,
         dynamic=True,
+        observer=None,
     ),
 )
 
@@ -149,7 +150,7 @@ W4A16 = dict(
 )
 
 # 4 bit integer weights and 8 bit activations quantization
-W4A8 = dict(
+INT8_W4A8 = dict(
     weights=QuantizationArgs(
         num_bits=4,
         type=QuantizationType.INT,
@@ -164,6 +165,7 @@ W4A8 = dict(
         strategy=QuantizationStrategy.TOKEN,
         symmetric=True,
         dynamic=True,
+        observer=None,
     ),
 )
 
@@ -200,6 +202,7 @@ FP8_DYNAMIC = dict(
         strategy=QuantizationStrategy.TOKEN,
         symmetric=True,
         dynamic=True,
+        observer=None,
     ),
 )
 
@@ -210,8 +213,9 @@ PRESET_SCHEMES = {
     "W8A16": W8A16,
     "W4A16": W4A16,
     # Integer weight and activation schemes
-    "W8A8": W8A8,
-    "W4A8": W4A8,
+    "W8A8": INT8_W8A8,
+    "INT8": INT8_W8A8,  # alias for W8A8
+    "W4A8": INT8_W4A8,
     # Float weight and activation schemes
     "FP8": FP8,
     "FP8_DYNAMIC": FP8_DYNAMIC,
```
compressed_tensors/quantization/utils/helpers.py

```diff
@@ -13,8 +13,7 @@
 # limitations under the License.
 
 import logging
-import re
-from typing import List, Optional, Tuple
+from typing import Generator, List, Optional, Tuple
 
 import torch
 from compressed_tensors.quantization.observers.base import Observer
@@ -28,7 +27,6 @@ __all__ = [
     "infer_quantization_status",
     "is_module_quantized",
     "is_model_quantized",
-    "iter_named_leaf_modules",
     "module_type",
     "calculate_compression_ratio",
     "get_torch_bit_depth",
@@ -36,9 +34,14 @@ __all__ = [
     "parse_out_kv_cache_args",
     "KV_CACHE_TARGETS",
     "is_kv_cache_quant_scheme",
+    "iter_named_leaf_modules",
+    "iter_named_quantizable_modules",
 ]
 
-KV_CACHE_TARGETS = ["re:.*k_proj", "re:.*v_proj"]
+# target the self_attn layer
+# QuantizedKVParameterCache is responsible for obtaining the k_scale and v_scale
+KV_CACHE_TARGETS = ["re:.*self_attn$"]
+
 
 _LOGGER: logging.Logger = logging.getLogger(__name__)
 
@@ -106,11 +109,10 @@ def module_type(module: Module) -> str:
     return type(module).__name__
 
 
-def iter_named_leaf_modules(model: Module) -> Tuple[str, Module]:
+def iter_named_leaf_modules(model: Module) -> Generator[Tuple[str, Module], None, None]:
     """
     Yields modules that do not have any submodules except observers. The observers
     themselves are not yielded
-
     :param model: model to get leaf modules of
     :returns: generator tuple of (name, leaf_submodule)
     """
@@ -128,6 +130,37 @@ def iter_named_leaf_modules(model: Module) -> Tuple[str, Module]:
                 yield name, submodule
 
 
+def iter_named_quantizable_modules(
+    model: Module, include_children: bool = True, include_attn: bool = False
+) -> Generator[Tuple[str, Module], None, None]:
+    """
+    Yield name and submodule of
+    - leaf modules, set by include_children
+    - attention modyles, set by include_attn
+
+    :param model: model to get leaf modules of
+    :param include_children: flag to get the leaf modules
+    :param inlcude_attn: flag to get the attention modules
+    :returns: generator tuple of (name, submodule)
+    """
+    for name, submodule in model.named_modules():
+        if include_children:
+            children = list(submodule.children())
+            if len(children) == 0 and not isinstance(submodule, Observer):
+                yield name, submodule
+            else:
+                has_non_observer_children = False
+                for child in children:
+                    if not isinstance(child, Observer):
+                        has_non_observer_children = True
+
+                if not has_non_observer_children:
+                    yield name, submodule
+        if include_attn:
+            if name.endswith("self_attn"):
+                yield name, submodule
+
+
 def get_torch_bit_depth(value: torch.Tensor) -> int:
     """
     Determine the number of bits used to represent the dtype of a tensor
@@ -204,19 +237,11 @@ def is_kv_cache_quant_scheme(scheme: QuantizationScheme) -> bool:
     :param scheme: The QuantizationScheme to investigate
     :return: boolean flag
    """
-    if len(scheme.targets) == 1:
-        # match on the KV_CACHE_TARGETS regex pattern
-        # if there is only one target
-        is_match_targets = any(
-            [re.match(pattern[3:], scheme.targets[0]) for pattern in KV_CACHE_TARGETS]
-        )
-    else:
-        # match on the exact KV_CACHE_TARGETS
-        # if there are multiple targets
-        is_match_targets = set(KV_CACHE_TARGETS) == set(scheme.targets)
+    for target in scheme.targets:
+        if target in KV_CACHE_TARGETS:
+            return True
 
-    is_match_output_activations = scheme.output_activations is not None
-    return is_match_targets and is_match_output_activations
+    return False
 
 
 def parse_out_kv_cache_args(
```
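After these helper changes, kv-cache detection is a plain membership test against KV_CACHE_TARGETS (now ["re:.*self_attn$"]) rather than regex matching combined with an output-activations check. A sketch; the scheme arguments are illustrative:

```python
from compressed_tensors.quantization import QuantizationArgs, QuantizationScheme
from compressed_tensors.quantization.utils import (
    KV_CACHE_TARGETS,  # ["re:.*self_attn$"]
    is_kv_cache_quant_scheme,
)

kv_scheme = QuantizationScheme(
    targets=KV_CACHE_TARGETS,
    output_activations=QuantizationArgs(num_bits=8, type="float"),
)
print(is_kv_cache_quant_scheme(kv_scheme))  # True

linear_scheme = QuantizationScheme(targets=["Linear"])
print(is_kv_cache_quant_scheme(linear_scheme))  # False
```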
compressed_tensors/utils/helpers.py

```diff
@@ -12,7 +12,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
-from typing import Optional
+from typing import Any, Optional
 
 import torch
 from transformers import AutoConfig
@@ -23,6 +23,7 @@ __all__ = [
     "fix_fsdp_module_name",
     "tensor_follows_mask_structure",
     "replace_module",
+    "is_compressed_tensors_config",
 ]
 
 FSDP_WRAPPER_NAME = "_fsdp_wrapped_module"
@@ -103,3 +104,18 @@ def replace_module(model: torch.nn.Module, name: str, new_module: torch.nn.Modul
         parent = model
         child_name = name
     setattr(parent, child_name, new_module)
+
+
+def is_compressed_tensors_config(compression_config: Any) -> bool:
+    """
+    Returns True if CompressedTensorsConfig is available from transformers and
+    compression_config is an instance of CompressedTensorsConfig
+
+    See: https://github.com/huggingface/transformers/pull/31704
+    """
+    try:
+        from transformers.utils.quantization_config import CompressedTensorsConfig
+
+        return isinstance(compression_config, CompressedTensorsConfig)
+    except ImportError:
+        return False
```
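The new is_compressed_tensors_config helper degrades gracefully when the installed transformers version does not ship CompressedTensorsConfig. A usage sketch; the no-argument CompressedTensorsConfig() call below is an assumption about transformers defaults:

```python
from compressed_tensors.utils import is_compressed_tensors_config

# anything that is not a CompressedTensorsConfig instance reports False
print(is_compressed_tensors_config({"format": "dense"}))  # False

try:
    from transformers.utils.quantization_config import CompressedTensorsConfig

    print(is_compressed_tensors_config(CompressedTensorsConfig()))  # True
except ImportError:
    # older transformers: the helper itself would also just return False
    pass
```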
compressed_tensors/version.py

```diff
@@ -17,7 +17,7 @@ Functionality for storing and setting the version info for SparseML
 """
 
 
-version_base = "0.6.0"
+version_base = "0.7.1"
 is_release = True  # change to True to set the generated version as a release version
 
 
```
{compressed_tensors-0.6.0.dist-info → compressed_tensors-0.7.1.dist-info}/METADATA

```diff
@@ -1,6 +1,6 @@
 Metadata-Version: 2.1
 Name: compressed-tensors
-Version: 0.6.0
+Version: 0.7.1
 Summary: Library for utilization of compressed safetensors of neural network models
 Home-page: https://github.com/neuralmagic/compressed-tensors
 Author: Neuralmagic, Inc.
```