compressed-tensors 0.6.0__py3-none-any.whl → 0.7.1__py3-none-any.whl
This diff compares the contents of two publicly released versions of the package as published to a supported registry. It is provided for informational purposes only and reflects the changes between the versions as they appear in the public registry.
- compressed_tensors/__init__.py +1 -0
- compressed_tensors/base.py +2 -0
- compressed_tensors/compressors/__init__.py +6 -12
- compressed_tensors/compressors/base.py +38 -102
- compressed_tensors/compressors/helpers.py +6 -6
- compressed_tensors/compressors/model_compressors/__init__.py +17 -0
- compressed_tensors/compressors/{model_compressor.py → model_compressors/model_compressor.py} +95 -106
- compressed_tensors/compressors/quantized_compressors/__init__.py +18 -0
- compressed_tensors/compressors/quantized_compressors/base.py +146 -0
- compressed_tensors/compressors/{naive_quantized.py → quantized_compressors/naive_quantized.py} +11 -11
- compressed_tensors/compressors/{pack_quantized.py → quantized_compressors/pack_quantized.py} +6 -3
- compressed_tensors/compressors/sparse_compressors/__init__.py +18 -0
- compressed_tensors/compressors/sparse_compressors/base.py +110 -0
- compressed_tensors/compressors/{dense.py → sparse_compressors/dense.py} +3 -3
- compressed_tensors/compressors/{sparse_bitmask.py → sparse_compressors/sparse_bitmask.py} +14 -59
- compressed_tensors/compressors/sparse_quantized_compressors/__init__.py +16 -0
- compressed_tensors/compressors/{marlin_24.py → sparse_quantized_compressors/marlin_24.py} +3 -3
- compressed_tensors/linear/compressed_linear.py +2 -2
- compressed_tensors/quantization/__init__.py +1 -0
- compressed_tensors/quantization/cache.py +201 -0
- compressed_tensors/quantization/lifecycle/apply.py +19 -3
- compressed_tensors/quantization/lifecycle/calibration.py +2 -3
- compressed_tensors/quantization/lifecycle/forward.py +58 -7
- compressed_tensors/quantization/lifecycle/frozen.py +6 -1
- compressed_tensors/quantization/lifecycle/helpers.py +0 -47
- compressed_tensors/quantization/lifecycle/initialize.py +116 -67
- compressed_tensors/quantization/observers/__init__.py +0 -1
- compressed_tensors/quantization/observers/helpers.py +40 -2
- compressed_tensors/quantization/quant_args.py +34 -4
- compressed_tensors/quantization/quant_config.py +14 -2
- compressed_tensors/quantization/quant_scheme.py +8 -4
- compressed_tensors/quantization/utils/helpers.py +43 -18
- compressed_tensors/utils/helpers.py +17 -1
- compressed_tensors/version.py +1 -1
- {compressed_tensors-0.6.0.dist-info → compressed_tensors-0.7.1.dist-info}/METADATA +1 -1
- compressed_tensors-0.7.1.dist-info/RECORD +58 -0
- compressed_tensors/quantization/observers/memoryless.py +0 -56
- compressed_tensors-0.6.0.dist-info/RECORD +0 -52
- {compressed_tensors-0.6.0.dist-info → compressed_tensors-0.7.1.dist-info}/LICENSE +0 -0
- {compressed_tensors-0.6.0.dist-info → compressed_tensors-0.7.1.dist-info}/WHEEL +0 -0
- {compressed_tensors-0.6.0.dist-info → compressed_tensors-0.7.1.dist-info}/top_level.txt +0 -0
compressed_tensors/quantization/lifecycle/initialize.py
CHANGED
@@ -17,8 +17,10 @@ import logging
 from typing import Optional
 
 import torch
+from compressed_tensors.quantization.cache import KVCacheScaleType
 from compressed_tensors.quantization.lifecycle.forward import (
     wrap_module_forward_quantized,
+    wrap_module_forward_quantized_attn,
 )
 from compressed_tensors.quantization.quant_args import (
     ActivationOrdering,
@@ -27,6 +29,7 @@ from compressed_tensors.quantization.quant_args import (
 )
 from compressed_tensors.quantization.quant_config import QuantizationStatus
 from compressed_tensors.quantization.quant_scheme import QuantizationScheme
+from compressed_tensors.quantization.utils import is_kv_cache_quant_scheme
 from compressed_tensors.utils import get_execution_device, is_module_offloaded
 from torch.nn import Module, Parameter
 
@@ -62,72 +65,85 @@ def initialize_module_for_quantization(
         # no scheme passed and layer not targeted for quantization - skip
         return
 
+    if is_attention_module(module):
+        # wrap forward call of module to perform
+        # quantized actions based on calltime status
+        wrap_module_forward_quantized_attn(module, scheme)
+        _initialize_attn_scales(module)
+
+    else:
+
+        if scheme.input_activations is not None:
             _initialize_scale_zero_point_observer(
                 module,
-                "weight",
-                scheme.weights,
-                weight_shape=weight_shape,
+                "input",
+                scheme.input_activations,
                 force_zero_point=force_zero_point,
             )
-
-    module.quantization_scheme = scheme
-    module.quantization_status = QuantizationStatus.INITIALIZED
-
-    offloaded = False
-    if is_module_offloaded(module):
-        try:
-            from accelerate.hooks import add_hook_to_module, remove_hook_from_module
-            from accelerate.utils import PrefixedDataset
-        except ModuleNotFoundError:
-            raise ModuleNotFoundError(
-                "Offloaded model detected. To use CPU offloading with "
-                "compressed-tensors the `accelerate` package must be installed, "
-                "run `pip install compressed-tensors[accelerate]`"
-            )
-
-        offloaded = True
-        hook = module._hf_hook
-        prefix_dict = module._hf_hook.weights_map
-        new_prefix = {}
-
-        # recreate the prefix dict (since it is immutable)
-        # and add quantization parameters
-        for key, data in module.named_parameters():
-            if key not in prefix_dict:
-                new_prefix[f"{prefix_dict.prefix}{key}"] = data
+        if scheme.weights is not None:
+            if hasattr(module, "weight"):
+                weight_shape = None
+                if isinstance(module, torch.nn.Linear):
+                    weight_shape = module.weight.shape
+                _initialize_scale_zero_point_observer(
+                    module,
+                    "weight",
+                    scheme.weights,
+                    weight_shape=weight_shape,
+                    force_zero_point=force_zero_point,
+                )
             else:
+                _LOGGER.warning(
+                    f"module type {type(module)} targeted for weight quantization but "
+                    "has no attribute weight, skipping weight quantization "
+                    f"for {type(module)}"
+                )
+
+        if scheme.output_activations is not None:
+            if not is_kv_cache_quant_scheme(scheme):
+                _initialize_scale_zero_point_observer(
+                    module, "output", scheme.output_activations
+                )
+
+        module.quantization_scheme = scheme
+        module.quantization_status = QuantizationStatus.INITIALIZED
+
+        offloaded = False
+        if is_module_offloaded(module):
+            try:
+                from accelerate.hooks import add_hook_to_module, remove_hook_from_module
+                from accelerate.utils import PrefixedDataset
+            except ModuleNotFoundError:
+                raise ModuleNotFoundError(
+                    "Offloaded model detected. To use CPU offloading with "
+                    "compressed-tensors the `accelerate` package must be installed, "
+                    "run `pip install compressed-tensors[accelerate]`"
+                )
+
+            offloaded = True
+            hook = module._hf_hook
+            prefix_dict = module._hf_hook.weights_map
+            new_prefix = {}
+
+            # recreate the prefix dict (since it is immutable)
+            # and add quantization parameters
+            for key, data in module.named_parameters():
+                if key not in prefix_dict:
+                    new_prefix[f"{prefix_dict.prefix}{key}"] = data
+                else:
+                    new_prefix[f"{prefix_dict.prefix}{key}"] = prefix_dict[key]
+            new_prefix_dict = PrefixedDataset(new_prefix, prefix_dict.prefix)
+            remove_hook_from_module(module)
+
+        # wrap forward call of module to perform
+        # quantized actions based on calltime status
+        wrap_module_forward_quantized(module, scheme)
+
+        if offloaded:
+            # we need to re-add the hook for offloading now that we've wrapped forward
+            add_hook_to_module(module, hook)
+            if prefix_dict is not None:
+                module._hf_hook.weights_map = new_prefix_dict
 
 
 def _initialize_scale_zero_point_observer(
@@ -137,12 +153,16 @@ def _initialize_scale_zero_point_observer(
     weight_shape: Optional[torch.Size] = None,
     force_zero_point: bool = True,
 ):
+
     # initialize observer module and attach as submodule
     observer = quantization_args.get_observer()
-    module.register_module(f"{base_name}_observer", observer)
+    # no need to register an observer for dynamic quantization
+    if observer:
+        module.register_module(f"{base_name}_observer", observer)
 
+    # no need to register a scale and zero point for a dynamic quantization
     if quantization_args.dynamic:
         return
 
     device = next(module.parameters()).device
     if is_module_offloaded(module):
@@ -156,10 +176,8 @@ def _initialize_scale_zero_point_observer(
         # (output_channels, 1)
         expected_shape = (weight_shape[0], 1)
     elif quantization_args.strategy == QuantizationStrategy.GROUP:
-        expected_shape = (
-            weight_shape[0],
-            weight_shape[1] // quantization_args.group_size,
-        )
+        num_groups = weight_shape[1] // quantization_args.group_size
+        expected_shape = (weight_shape[0], max(num_groups, 1))
 
     scale_dtype = module.weight.dtype
     if scale_dtype not in [torch.float16, torch.bfloat16, torch.float32]:
@@ -189,3 +207,34 @@ def _initialize_scale_zero_point_observer(
         requires_grad=False,
     )
     module.register_parameter(f"{base_name}_g_idx", init_g_idx)
+
+
+def is_attention_module(module: Module):
+    return "attention" in module.__class__.__name__.lower() and (
+        hasattr(module, "k_proj")
+        or hasattr(module, "v_proj")
+        or hasattr(module, "qkv_proj")
+    )
+
+
+def _initialize_attn_scales(module: Module) -> None:
+    """Initlaize k_scale, v_scale for self_attn"""
+
+    expected_shape = 1  # per tensor
+
+    param = next(module.parameters())
+    scale_dtype = param.dtype
+    device = param.device
+
+    init_scale = Parameter(
+        torch.empty(expected_shape, dtype=scale_dtype, device=device),
+        requires_grad=False,
+    )
+
+    module.register_parameter(KVCacheScaleType.KEY.value, init_scale)
+
+    init_scale = Parameter(
+        torch.empty(expected_shape, dtype=scale_dtype, device=device),
+        requires_grad=False,
+    )
+    module.register_parameter(KVCacheScaleType.VALUE.value, init_scale)
compressed_tensors/quantization/observers/helpers.py
CHANGED
@@ -13,18 +13,56 @@
 # limitations under the License.
 
 from collections import Counter
-from typing import Tuple
+from typing import Optional, Tuple
 
 import torch
 from compressed_tensors.quantization.quant_args import (
     FP8_DTYPE,
     QuantizationArgs,
+    QuantizationStrategy,
     QuantizationType,
 )
 from torch import FloatTensor, IntTensor, Tensor
 
 
-__all__ = ["calculate_qparams", "get_observer_token_count", "calculate_range"]
+__all__ = [
+    "calculate_qparams",
+    "get_observer_token_count",
+    "calculate_range",
+    "compute_dynamic_scales_and_zp",
+]
+
+
+def compute_dynamic_scales_and_zp(value: Tensor, args: QuantizationArgs):
+    """
+    Returns the computed scales and zero points for dynamic activation
+    qunatization.
+
+    :param value: tensor to calculate quantization parameters for
+    :param args: quantization args
+    :param reduce_dims: optional tuple of dimensions to reduce along,
+        returned scale and zero point will be shaped (1,) along the
+        reduced dimensions
+    :return: tuple of scale and zero point derived from the observed tensor
+    """
+    if args.strategy == QuantizationStrategy.TOKEN:
+        dim = {1, 2}
+        reduce_dims = tuple(idx for idx in range(value.ndim) if idx not in dim)
+    elif args.strategy == QuantizationStrategy.TENSOR:
+        reduce_dims = None
+    else:
+        raise ValueError(
+            f"One of {QuantizationStrategy.TOKEN} or {QuantizationStrategy.TENSOR} ",
+            "must be used for dynamic quantization",
+        )
+
+    if not reduce_dims:
+        min_val, max_val = torch.aminmax(value)
+    else:
+        min_val = torch.amin(value, dim=reduce_dims, keepdims=True)
+        max_val = torch.amax(value, dim=reduce_dims, keepdims=True)
+
+    return calculate_qparams(min_val, max_val, args)
 
 
 def get_observer_token_count(module: torch.nn.Module) -> Counter:
compressed_tensors/quantization/quant_args.py
CHANGED
@@ -12,6 +12,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
+import warnings
 from enum import Enum
 from typing import Any, Dict, Optional, Union
 
@@ -94,7 +95,7 @@ class QuantizationArgs(BaseModel, use_enum_values=True):
     block_structure: Optional[str] = None
     dynamic: bool = False
     actorder: Union[ActivationOrdering, bool, None] = None
-    observer: str = Field(
+    observer: Optional[str] = Field(
         default="minmax",
         description=(
             "The class to use to compute the quantization param - "
@@ -115,13 +116,19 @@ class QuantizationArgs(BaseModel, use_enum_values=True):
         """
         from compressed_tensors.quantization.observers.base import Observer
 
+        # No observer required for the dynamic case
        if self.dynamic:
-            self.observer = "memoryless"
+            self.observer = None
+            return self.observer
 
         return Observer.load_from_registry(self.observer, quantization_args=self)
 
+    def get_kv_cache(self):
+        """Get the singleton KV Cache"""
+        from compressed_tensors.quantization.cache import QuantizedKVParameterCache
+
+        return QuantizedKVParameterCache(self)
+
     @field_validator("type", mode="before")
     def validate_type(cls, value) -> QuantizationType:
         if isinstance(value, str):
@@ -165,6 +172,8 @@ class QuantizationArgs(BaseModel, use_enum_values=True):
         strategy = model.strategy
         group_size = model.group_size
         actorder = model.actorder
+        dynamic = model.dynamic
+        observer = model.observer
 
         # infer strategy
         if strategy is None:
@@ -201,6 +210,27 @@ class QuantizationArgs(BaseModel, use_enum_values=True):
                 "activation ordering"
             )
 
+        if dynamic:
+            if strategy not in (
+                QuantizationStrategy.TOKEN,
+                QuantizationStrategy.TENSOR,
+            ):
+                raise ValueError(
+                    f"One of {QuantizationStrategy.TOKEN} or "
+                    f"{QuantizationStrategy.TENSOR} must be used for dynamic ",
+                    "quantization",
+                )
+            if observer is not None:
+                warnings.warn(
+                    "No observer is used for dynamic quantization, setting to None"
+                )
+                model.observer = None
+
+        # if we have not set an observer and we
+        # are running static quantization, use minmax
+        if not observer and not dynamic:
+            model.observer = "minmax"
+
         # write back modified values
         model.strategy = strategy
         return model
compressed_tensors/quantization/quant_config.py
CHANGED
@@ -24,7 +24,7 @@ from compressed_tensors.quantization.quant_scheme import (
 from compressed_tensors.quantization.utils import (
     calculate_compression_ratio,
     is_module_quantized,
-    iter_named_leaf_modules,
+    iter_named_quantizable_modules,
     module_type,
     parse_out_kv_cache_args,
 )
@@ -177,7 +177,9 @@ class QuantizationConfig(BaseModel):
         quantization_status = None
         ignore = {}
         quantization_type_names = set()
-        for name, submodule in iter_named_leaf_modules(model):
+        for name, submodule in iter_named_quantizable_modules(
+            model, include_children=True, include_attn=True
+        ):
             layer_type = module_type(submodule)
             if not is_module_quantized(submodule):
                 if layer_type not in ignore:
@@ -199,6 +201,13 @@ class QuantizationConfig(BaseModel):
         if len(quant_scheme_to_layers) == 0:  # No quantized layers
             return None
 
+        # kv-cache only, no weight/activation quantization
+        if (
+            len(quantization_type_names) == 1
+            and "attention" in list(quantization_type_names)[0].lower()
+        ):
+            quantization_type_names.add("Linear")
+
         # clean up ignore list, we can leave out layers types if none of the
         # instances are quantized
         consolidated_ignore = []
@@ -241,6 +250,9 @@ class QuantizationConfig(BaseModel):
         )
 
     def requires_calibration_data(self):
+        if self.kv_cache_scheme is not None:
+            return True
+
         for _, scheme in self.config_groups.items():
             if scheme.input_activations is not None:
                 if not scheme.input_activations.dynamic:
compressed_tensors/quantization/quant_scheme.py
CHANGED
@@ -108,7 +108,7 @@ def is_preset_scheme(name: str) -> bool:
 UNQUANTIZED = dict()
 
 # 8 bit integer weights and 8 bit activations quantization
-W8A8 = dict(
+INT8_W8A8 = dict(
     weights=QuantizationArgs(
         num_bits=8,
         type=QuantizationType.INT,
@@ -122,6 +122,7 @@ W8A8 = dict(
         strategy=QuantizationStrategy.TOKEN,
         symmetric=True,
         dynamic=True,
+        observer=None,
     ),
 )
 
@@ -149,7 +150,7 @@ W4A16 = dict(
 )
 
 # 4 bit integer weights and 8 bit activations quantization
-W4A8 = dict(
+INT8_W4A8 = dict(
     weights=QuantizationArgs(
         num_bits=4,
         type=QuantizationType.INT,
@@ -164,6 +165,7 @@ W4A8 = dict(
         strategy=QuantizationStrategy.TOKEN,
         symmetric=True,
         dynamic=True,
+        observer=None,
     ),
 )
 
@@ -200,6 +202,7 @@ FP8_DYNAMIC = dict(
         strategy=QuantizationStrategy.TOKEN,
         symmetric=True,
         dynamic=True,
+        observer=None,
     ),
 )
 
@@ -210,8 +213,9 @@ PRESET_SCHEMES = {
     "W8A16": W8A16,
     "W4A16": W4A16,
     # Integer weight and activation schemes
-    "W8A8": W8A8,
-    "W4A8": W4A8,
+    "W8A8": INT8_W8A8,
+    "INT8": INT8_W8A8,  # alias for W8A8
+    "W4A8": INT8_W4A8,
     # Float weight and activation schemes
     "FP8": FP8,
     "FP8_DYNAMIC": FP8_DYNAMIC,
compressed_tensors/quantization/utils/helpers.py
CHANGED
@@ -13,8 +13,7 @@
 # limitations under the License.
 
 import logging
-import re
-from typing import List, Optional, Tuple
+from typing import Generator, List, Optional, Tuple
 
 import torch
 from compressed_tensors.quantization.observers.base import Observer
@@ -28,7 +27,6 @@ __all__ = [
     "infer_quantization_status",
     "is_module_quantized",
     "is_model_quantized",
-    "iter_named_leaf_modules",
     "module_type",
     "calculate_compression_ratio",
     "get_torch_bit_depth",
@@ -36,9 +34,14 @@ __all__ = [
     "parse_out_kv_cache_args",
     "KV_CACHE_TARGETS",
    "is_kv_cache_quant_scheme",
+    "iter_named_leaf_modules",
+    "iter_named_quantizable_modules",
 ]
 
+# target the self_attn layer
+# QuantizedKVParameterCache is responsible for obtaining the k_scale and v_scale
+KV_CACHE_TARGETS = ["re:.*self_attn$"]
+
 _LOGGER: logging.Logger = logging.getLogger(__name__)
 
 
@@ -106,11 +109,10 @@ def module_type(module: Module) -> str:
     return type(module).__name__
 
 
-def iter_named_leaf_modules(model: Module) -> Tuple[str, Module]:
+def iter_named_leaf_modules(model: Module) -> Generator[Tuple[str, Module], None, None]:
     """
     Yields modules that do not have any submodules except observers. The observers
     themselves are not yielded
-
     :param model: model to get leaf modules of
     :returns: generator tuple of (name, leaf_submodule)
     """
@@ -128,6 +130,37 @@ def iter_named_leaf_modules(model: Module) -> Tuple[str, Module]:
             yield name, submodule
 
 
+def iter_named_quantizable_modules(
+    model: Module, include_children: bool = True, include_attn: bool = False
+) -> Generator[Tuple[str, Module], None, None]:
+    """
+    Yield name and submodule of
+    - leaf modules, set by include_children
+    - attention modyles, set by include_attn
+
+    :param model: model to get leaf modules of
+    :param include_children: flag to get the leaf modules
+    :param inlcude_attn: flag to get the attention modules
+    :returns: generator tuple of (name, submodule)
+    """
+    for name, submodule in model.named_modules():
+        if include_children:
+            children = list(submodule.children())
+            if len(children) == 0 and not isinstance(submodule, Observer):
+                yield name, submodule
+            else:
+                has_non_observer_children = False
+                for child in children:
+                    if not isinstance(child, Observer):
+                        has_non_observer_children = True
+
+                if not has_non_observer_children:
+                    yield name, submodule
+        if include_attn:
+            if name.endswith("self_attn"):
+                yield name, submodule
+
+
 def get_torch_bit_depth(value: torch.Tensor) -> int:
     """
     Determine the number of bits used to represent the dtype of a tensor
@@ -204,19 +237,11 @@ def is_kv_cache_quant_scheme(scheme: QuantizationScheme) -> bool:
     :param scheme: The QuantizationScheme to investigate
     :return: boolean flag
     """
-        is_match_targets = any(
-            [re.match(pattern[3:], scheme.targets[0]) for pattern in KV_CACHE_TARGETS]
-        )
-    else:
-        # match on the exact KV_CACHE_TARGETS
-        # if there are multiple targets
-        is_match_targets = set(KV_CACHE_TARGETS) == set(scheme.targets)
+    for target in scheme.targets:
+        if target in KV_CACHE_TARGETS:
+            return True
 
-    return is_match_targets and is_match_output_activations
+    return False
 
 
 def parse_out_kv_cache_args(
compressed_tensors/utils/helpers.py
CHANGED
@@ -12,7 +12,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
-from typing import Optional
+from typing import Any, Optional
 
 import torch
 from transformers import AutoConfig
@@ -23,6 +23,7 @@ __all__ = [
     "fix_fsdp_module_name",
     "tensor_follows_mask_structure",
     "replace_module",
+    "is_compressed_tensors_config",
 ]
 
 FSDP_WRAPPER_NAME = "_fsdp_wrapped_module"
@@ -103,3 +104,18 @@ def replace_module(model: torch.nn.Module, name: str, new_module: torch.nn.Module):
         parent = model
         child_name = name
     setattr(parent, child_name, new_module)
+
+
+def is_compressed_tensors_config(compression_config: Any) -> bool:
+    """
+    Returns True if CompressedTensorsConfig is available from transformers and
+    compression_config is an instance of CompressedTensorsConfig
+
+    See: https://github.com/huggingface/transformers/pull/31704
+    """
+    try:
+        from transformers.utils.quantization_config import CompressedTensorsConfig
+
+        return isinstance(compression_config, CompressedTensorsConfig)
+    except ImportError:
+        return False
compressed_tensors/version.py
CHANGED