compressed-tensors 0.4.0__py3-none-any.whl → 0.5.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- compressed_tensors/base.py +1 -0
- compressed_tensors/compressors/__init__.py +5 -1
- compressed_tensors/compressors/base.py +1 -1
- compressed_tensors/compressors/dense.py +1 -1
- compressed_tensors/compressors/marlin_24.py +11 -10
- compressed_tensors/compressors/model_compressor.py +33 -12
- compressed_tensors/compressors/{int_quantized.py → naive_quantized.py} +33 -15
- compressed_tensors/compressors/pack_quantized.py +58 -51
- compressed_tensors/compressors/sparse_bitmask.py +1 -1
- compressed_tensors/config/base.py +2 -0
- compressed_tensors/quantization/lifecycle/__init__.py +1 -0
- compressed_tensors/quantization/lifecycle/apply.py +161 -39
- compressed_tensors/quantization/lifecycle/calibration.py +20 -1
- compressed_tensors/quantization/lifecycle/forward.py +70 -25
- compressed_tensors/quantization/lifecycle/helpers.py +53 -0
- compressed_tensors/quantization/lifecycle/initialize.py +30 -1
- compressed_tensors/quantization/observers/base.py +39 -0
- compressed_tensors/quantization/observers/helpers.py +64 -11
- compressed_tensors/quantization/quant_args.py +45 -1
- compressed_tensors/quantization/quant_config.py +35 -2
- compressed_tensors/quantization/quant_scheme.py +105 -4
- compressed_tensors/quantization/utils/helpers.py +67 -1
- compressed_tensors/utils/__init__.py +4 -0
- compressed_tensors/utils/helpers.py +31 -2
- compressed_tensors/utils/offload.py +104 -0
- compressed_tensors/version.py +1 -1
- {compressed_tensors-0.4.0.dist-info → compressed_tensors-0.5.0.dist-info}/METADATA +2 -1
- compressed_tensors-0.5.0.dist-info/RECORD +48 -0
- {compressed_tensors-0.4.0.dist-info → compressed_tensors-0.5.0.dist-info}/WHEEL +1 -1
- compressed_tensors/compressors/utils/__init__.py +0 -19
- compressed_tensors/compressors/utils/helpers.py +0 -43
- compressed_tensors-0.4.0.dist-info/RECORD +0 -48
- /compressed_tensors/{compressors/utils → utils}/permutations_24.py +0 -0
- /compressed_tensors/{compressors/utils → utils}/semi_structured_conversions.py +0 -0
- {compressed_tensors-0.4.0.dist-info → compressed_tensors-0.5.0.dist-info}/LICENSE +0 -0
- {compressed_tensors-0.4.0.dist-info → compressed_tensors-0.5.0.dist-info}/top_level.txt +0 -0
compressed_tensors/quantization/observers/base.py

```diff
@@ -12,6 +12,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
+import logging
 from typing import Any, Iterable, Optional, Tuple, Union
 
 import torch
@@ -24,6 +25,9 @@ from torch import FloatTensor, IntTensor, Tensor
 from torch.nn import Module
 
 
+_LOGGER = logging.getLogger(__name__)
+
+
 __all__ = ["Observer"]
 
 
@@ -39,6 +43,7 @@ class Observer(Module, RegistryMixin):
         super().__init__()
         self._scale = None
         self._zero_point = None
+        self._num_observed_tokens = None
 
     @torch.no_grad()
     def forward(self, observed: Tensor) -> Tuple[FloatTensor, IntTensor]:
@@ -48,6 +53,7 @@ class Observer(Module, RegistryMixin):
            from
        :return: tuple of scale and zero point based on last observed value
        """
+        self.record_observed_tokens(observed)
        return self.get_qparams(observed=observed)
 
    def calculate_qparams(
@@ -132,3 +138,36 @@ class Observer(Module, RegistryMixin):
        return self.calculate_qparams(
            observed, reduce_dims=reduce_dims, tensor_id=tensor_id
        )
+
+    def record_observed_tokens(self, batch_tensor: Tensor):
+        """
+        Counts the number of tokens observed during the
+        forward passes. The count is aggregated in the
+        _num_observed_tokens attribute of the class.
+
+        Note: The batch_tensor is expected to have two dimensions
+            (batch_size * sequence_length, num_features). This is the
+            general shape expected by the forward pass of the expert
+            layers in a MOE model. If the input tensor does not have
+            two dimensions, the _num_observed_tokens attribute will be set
+            to None.
+        """
+        if not isinstance(batch_tensor, Tensor):
+            raise ValueError(f"Expected value to be a tensor, got {type(batch_tensor)}")
+
+        if batch_tensor.ndim != 2:
+            _LOGGER.debug(
+                "The input tensor is expected to have two dimensions "
+                "(batch_size * sequence_length, num_features). "
+                f"The input tensor has {batch_tensor.ndim} dimensions."
+            )
+            return
+
+        if self._num_observed_tokens is None:
+            # initialize the count
+            self._num_observed_tokens = 0
+
+        # batch_tensor (batch_size * sequence_length, num_features)
+        # observed_tokens (batch_size * sequence_length)
+        observed_tokens, _ = batch_tensor.shape
+        self._num_observed_tokens += observed_tokens
```
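The new token accounting only fires for 2-D inputs of shape (batch_size * sequence_length, num_features) and accumulates across forward passes. A minimal standalone sketch of that behavior, using only `torch`; the `ToyObserver` class is an illustrative stand-in, not part of the package:

```python
import torch
from torch import Tensor


class ToyObserver(torch.nn.Module):
    """Stand-in mirroring the token accounting added to Observer in 0.5.0."""

    def __init__(self):
        super().__init__()
        self._num_observed_tokens = None

    def record_observed_tokens(self, batch_tensor: Tensor):
        # only 2-D (batch_size * sequence_length, num_features) inputs are counted
        if batch_tensor.ndim != 2:
            return
        if self._num_observed_tokens is None:
            self._num_observed_tokens = 0
        self._num_observed_tokens += batch_tensor.shape[0]


observer = ToyObserver()
observer.record_observed_tokens(torch.randn(4, 16))   # 4 tokens
observer.record_observed_tokens(torch.randn(6, 16))   # 6 more tokens
assert observer._num_observed_tokens == 10

observer.record_observed_tokens(torch.randn(2, 3, 16))  # 3-D input: ignored, not counted
assert observer._num_observed_tokens == 10
```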
compressed_tensors/quantization/observers/helpers.py

```diff
@@ -12,23 +12,45 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
+from collections import Counter
 from typing import Tuple
 
 import torch
-from compressed_tensors.quantization.quant_args import
+from compressed_tensors.quantization.quant_args import (
+    FP8_DTYPE,
+    QuantizationArgs,
+    QuantizationType,
+)
 from torch import FloatTensor, IntTensor, Tensor
 
 
-__all__ = ["calculate_qparams"]
+__all__ = ["calculate_qparams", "get_observer_token_count", "calculate_range"]
+
+
+def get_observer_token_count(module: torch.nn.Module) -> Counter:
+    """
+    Parse the module and return the number of tokens observed by
+    each module's observer.
+
+    :param module: module to parse
+    :return: counter with the number of tokens observed by each observer
+    """
+    token_counts = Counter()
+    for name, module in module.named_modules():
+        if name.endswith(".input_observer"):
+            token_counts[
+                name.replace(".input_observer", "")
+            ] = module._num_observed_tokens
+    return token_counts
 
 
 def calculate_qparams(
     min_vals: Tensor, max_vals: Tensor, quantization_args: QuantizationArgs
 ) -> Tuple[FloatTensor, IntTensor]:
     """
-    :param min_vals: tensor of min value(s) to
+    :param min_vals: tensor of min value(s) to calculate scale(s) and zero point(s)
        from
-    :param max_vals: tensor of max value(s) to
+    :param max_vals: tensor of max value(s) to calculate scale(s) and zero point(s)
        from
     :param quantization_args: settings to quantization
     :return: tuple of the calculated scale(s) and zero point(s)
@@ -37,22 +59,53 @@ def calculate_qparams(
     max_vals = torch.max(max_vals, torch.zeros_like(max_vals))
     device = min_vals.device
 
-
-
-
+    bit_min, bit_max = calculate_range(quantization_args, device)
+    bit_range = bit_max - bit_min
+    zp_dtype = quantization_args.pytorch_dtype()
+
     if quantization_args.symmetric:
-        max_val_pos = torch.max(
+        max_val_pos = torch.max(torch.abs(min_vals), torch.abs(max_vals))
         scales = max_val_pos / (float(bit_range) / 2)
         scales = torch.clamp(scales, min=torch.finfo(torch.float32).eps)
-        zero_points = torch.zeros(scales.shape, device=device, dtype=
+        zero_points = torch.zeros(scales.shape, device=device, dtype=min_vals.dtype)
     else:
         scales = (max_vals - min_vals) / float(bit_range)
         scales = torch.clamp(scales, min=torch.finfo(torch.float32).eps)
-        zero_points = bit_min -
-        zero_points = torch.clamp(zero_points, bit_min, bit_max)
+        zero_points = bit_min - (min_vals / scales)
+        zero_points = torch.clamp(zero_points, bit_min, bit_max)
+
+    # match zero-points to quantized type
+    zero_points = zero_points.to(zp_dtype)
 
     if scales.ndim == 0:
         scales = scales.reshape(1)
         zero_points = zero_points.reshape(1)
 
     return scales, zero_points
+
+
+def calculate_range(quantization_args: QuantizationArgs, device: str) -> Tuple:
+    """
+    Calculated the effective quantization range for the given Quantization Args
+
+    :param quantization_args: quantization args to get range of
+    :param device: device to store the range to
+    :return: tuple endpoints for the given quantization range
+    """
+    if quantization_args.type == QuantizationType.INT:
+        bit_range = 2**quantization_args.num_bits
+        q_max = torch.tensor(bit_range / 2 - 1, device=device)
+        q_min = torch.tensor(-bit_range / 2, device=device)
+    elif quantization_args.type == QuantizationType.FLOAT:
+        if quantization_args.num_bits != 8:
+            raise ValueError(
+                "Floating point quantization is only supported for 8 bits,"
+                f"got {quantization_args.num_bits}"
+            )
+        fp_range_info = torch.finfo(FP8_DTYPE)
+        q_max = torch.tensor(fp_range_info.max, device=device)
+        q_min = torch.tensor(fp_range_info.min, device=device)
+    else:
+        raise ValueError(f"Invalid quantization type {quantization_args.type}")
+
+    return q_min, q_max
```
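Taken together, `calculate_range` and the reworked `calculate_qparams` replace the previous inline bit-range math. A small worked example, assuming compressed-tensors 0.5.0 and a recent PyTorch are installed; the input values are illustrative:

```python
import torch
from compressed_tensors.quantization.observers.helpers import (
    calculate_qparams,
    calculate_range,
)
from compressed_tensors.quantization.quant_args import QuantizationArgs, QuantizationType

args = QuantizationArgs(num_bits=8, type=QuantizationType.INT, symmetric=True)

# int8 quantization spans [-128, 127]
q_min, q_max = calculate_range(args, device="cpu")
print(q_min.item(), q_max.item())  # -128.0 127.0

# symmetric scale covers the larger magnitude; the zero point collapses to 0
scales, zero_points = calculate_qparams(
    min_vals=torch.tensor([-0.5]),
    max_vals=torch.tensor([1.0]),
    quantization_args=args,
)
print(scales)       # ~1.0 / 127.5
print(zero_points)  # tensor([0], dtype=torch.int8), matched to the quantized type
```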
compressed_tensors/quantization/quant_args.py

```diff
@@ -15,10 +15,19 @@
 from enum import Enum
 from typing import Any, Dict, Optional
 
+import torch
 from pydantic import BaseModel, Field, validator
 
 
-__all__ = [
+__all__ = [
+    "FP8_DTYPE",
+    "QuantizationType",
+    "QuantizationStrategy",
+    "QuantizationArgs",
+    "round_to_quantized_type",
+]
+
+FP8_DTYPE = torch.float8_e4m3fn
 
 
 class QuantizationType(str, Enum):
@@ -123,3 +132,38 @@ class QuantizationArgs(BaseModel, use_enum_values=True):
            return QuantizationStrategy.TENSOR
 
        return value
+
+    def pytorch_dtype(self) -> torch.dtype:
+        if self.type == QuantizationType.FLOAT:
+            return FP8_DTYPE
+        elif self.type == QuantizationType.INT:
+            if self.num_bits <= 8:
+                return torch.int8
+            elif self.num_bits <= 16:
+                return torch.int16
+            else:
+                return torch.int32
+        else:
+            raise ValueError(f"Invalid quantization type {self.type}")
+
+
+def round_to_quantized_type(
+    tensor: torch.Tensor, args: QuantizationArgs
+) -> torch.Tensor:
+    """
+    Rounds each element of the input tensor to the nearest quantized representation,
+    keeping to original dtype
+
+    :param tensor: tensor to round
+    :param args: QuantizationArgs to pull appropriate dtype from
+    :return: rounded tensor
+    """
+    original_dtype = tensor.dtype
+    if args.type == QuantizationType.FLOAT:
+        rounded = tensor.to(FP8_DTYPE)
+    elif args.type == QuantizationType.INT:
+        rounded = torch.round(tensor)
+    else:
+        raise ValueError(f"Invalid quantization type {args.type}")
+
+    return rounded.to(original_dtype)
```
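The new `pytorch_dtype()` and `round_to_quantized_type()` helpers pick the storage dtype and rounding behavior from the args instead of hard-coding int8. A short sketch, assuming the 0.5.0 API above and a PyTorch build that exposes `torch.float8_e4m3fn`:

```python
import torch
from compressed_tensors.quantization.quant_args import (
    QuantizationArgs,
    QuantizationType,
    round_to_quantized_type,
)

int8_args = QuantizationArgs(num_bits=8, type=QuantizationType.INT)
fp8_args = QuantizationArgs(num_bits=8, type=QuantizationType.FLOAT)

# storage dtype now comes from the args
print(int8_args.pytorch_dtype())  # torch.int8
print(fp8_args.pytorch_dtype())   # torch.float8_e4m3fn

x = torch.tensor([0.1234, -1.5678, 3.25])
# integer rounding, returned in the original float dtype
print(round_to_quantized_type(x, int8_args))  # tensor([ 0., -2.,  3.])
# float8 rounding snaps each value to the nearest e4m3-representable number
print(round_to_quantized_type(x, fp8_args))
```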
compressed_tensors/quantization/quant_config.py

```diff
@@ -16,6 +16,7 @@ from enum import Enum
 from typing import Dict, List, Optional, Union
 
 from compressed_tensors.config import CompressionFormat
+from compressed_tensors.quantization.quant_args import QuantizationArgs
 from compressed_tensors.quantization.quant_scheme import (
     QuantizationScheme,
     preset_name_to_scheme,
@@ -25,6 +26,7 @@ from compressed_tensors.quantization.utils import (
     is_module_quantized,
     iter_named_leaf_modules,
     module_type,
+    parse_out_kv_cache_args,
 )
 from pydantic import BaseModel, Field
 from torch.nn import Module
@@ -117,7 +119,18 @@ class QuantizationConfig(BaseModel):
        other quantization configs
    :param format: specifies how the quantized model is stored on disk
    :quantization_status: specifies the current status of all quantized layers. It is
-
+        assumed all layers are in the same state.
+    :param kv_cache_scheme: optional QuantizationArgs, that specify the
+        quantization of the kv cache. If None, kv cache is not quantized.
+        When applying kv cache quantization to transformer AutoModelForCausalLM,
+        the kv_cache_scheme gets converted into a QuantizationScheme that:
+            - targets the `q_proj` and `k_proj` modules of the model. The outputs
+              of those modules are the keys and values that might be cached
+            - quantizes the outputs of the aformentioned layers, so that
+              keys and values are compressed before storing them in the cache
+        There is an explicit assumption that the model contains modules with
+        `k_proj` and `v_proj` in their names. If this is not the case
+        and kv_cache_scheme != None, the quantization of kv cache will fail
    :global_compression_ratio: optional informational config to report the model
        compression ratio acheived by the quantization config
    :ignore: optional list of layers to ignore from config_groups. Layers in this list
@@ -126,6 +139,7 @@ class QuantizationConfig(BaseModel):
 
     config_groups: Dict[str, Union[QuantizationScheme, List[str]]]
     quant_method: str = DEFAULT_QUANTIZATION_METHOD
+    kv_cache_scheme: Optional[QuantizationArgs] = None
     format: str = DEFAULT_QUANTIZATION_FORMAT
     quantization_status: QuantizationStatus = QuantizationStatus.INITIALIZED
     global_compression_ratio: Optional[float] = None
@@ -154,7 +168,7 @@ class QuantizationConfig(BaseModel):
    ) -> Optional["QuantizationConfig"]:
        """
        Converts a model into its associated QuantizationConfig based on the
-        QuantizationScheme attached to each
+        QuantizationScheme attached to each quantized module
 
        :param model: model to calculate quantization scheme of
        :return: filled out QuantizationScheme for the input model
@@ -195,6 +209,13 @@ class QuantizationConfig(BaseModel):
            # else we leave it off the ignore list, doesn't fall under any of the
            # existing quantization schemes so it won't be quantized
 
+        kv_cache_args, quant_scheme_to_layers = parse_out_kv_cache_args(
+            quant_scheme_to_layers
+        )
+        kv_cache_scheme = (
+            kv_cache_args.model_dump() if kv_cache_args is not None else kv_cache_args
+        )
+
        config_groups = {}
        for idx, scheme in enumerate(quant_scheme_to_layers):
            group_name = "group_" + str(idx)
@@ -213,7 +234,19 @@ class QuantizationConfig(BaseModel):
        return QuantizationConfig(
            config_groups=config_groups,
            quantization_status=quantization_status,
+            kv_cache_scheme=kv_cache_scheme,
            global_compression_ratio=compression_ratio,
            format=format,
            ignore=consolidated_ignore,
        )
+
+    def requires_calibration_data(self):
+        for _, scheme in self.config_groups.items():
+            if scheme.input_activations is not None:
+                if not scheme.input_activations.dynamic:
+                    return True
+            if scheme.output_activations is not None:
+                if not scheme.output_activations.dynamic:
+                    return True
+
+        return False
```
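With the new field, a config can carry a kv cache scheme alongside the regular config groups, and `requires_calibration_data()` reports whether any static activation quantization is present. A sketch assuming the 0.5.0 modules shown in this diff; the `preset_name_to_scheme(name, targets)` call signature, the `Linear` target, and the `lm_head` ignore entry are illustrative assumptions:

```python
from compressed_tensors.quantization.quant_args import QuantizationArgs
from compressed_tensors.quantization.quant_config import QuantizationConfig
from compressed_tensors.quantization.quant_scheme import preset_name_to_scheme

config = QuantizationConfig(
    # weight-only W8A16 applied to all Linear modules (illustrative target)
    config_groups={"group_0": preset_name_to_scheme("W8A16", targets=["Linear"])},
    # int8 quantization of the kv cache outputs (k_proj / v_proj)
    kv_cache_scheme=QuantizationArgs(num_bits=8, symmetric=True),
    ignore=["lm_head"],
)

# weight-only schemes need no calibration data; static activation schemes would
print(config.requires_calibration_data())  # False
```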
compressed_tensors/quantization/quant_scheme.py

```diff
@@ -15,7 +15,11 @@
 from copy import deepcopy
 from typing import List, Optional
 
-from compressed_tensors.quantization.quant_args import
+from compressed_tensors.quantization.quant_args import (
+    QuantizationArgs,
+    QuantizationStrategy,
+    QuantizationType,
+)
 from pydantic import BaseModel
 
 
@@ -107,13 +111,110 @@ def is_preset_scheme(name: str) -> bool:
     return name.upper() in PRESET_SCHEMES
 
 
+# 8 bit integer weights and 8 bit activations quantization
 W8A8 = dict(
-    weights=QuantizationArgs(
+    weights=QuantizationArgs(
+        num_bits=8,
+        type=QuantizationType.INT,
+        strategy=QuantizationStrategy.CHANNEL,
+        symmetric=True,
+        dynamic=False,
+    ),
+    input_activations=QuantizationArgs(
+        num_bits=8,
+        type=QuantizationType.INT,
+        strategy=QuantizationStrategy.TOKEN,
+        symmetric=True,
+        dynamic=True,
+    ),
+)
+
+# 8 bit integer weights only quantization
+W8A16 = dict(
+    weights=QuantizationArgs(
+        num_bits=8,
+        type=QuantizationType.INT,
+        strategy=QuantizationStrategy.CHANNEL,
+        symmetric=True,
+        dynamic=False,
+    ),
+)
+
+# 4 bit integer weights only quantization
+W4A16 = dict(
+    weights=QuantizationArgs(
+        num_bits=4,
+        type=QuantizationType.INT,
+        strategy=QuantizationStrategy.GROUP,
+        group_size=128,
+        symmetric=True,
+        dynamic=False,
+    ),
+)
+
+# 4 bit integer weights and 8 bit activations quantization
+W4A8 = dict(
+    weights=QuantizationArgs(
+        num_bits=4,
+        type=QuantizationType.INT,
+        group_size=128,
+        strategy=QuantizationStrategy.GROUP,
+        symmetric=True,
+        dynamic=False,
+    ),
+    input_activations=QuantizationArgs(
+        num_bits=8,
+        type=QuantizationType.INT,
+        strategy=QuantizationStrategy.TOKEN,
+        symmetric=True,
+        dynamic=True,
+    ),
+)
+
+# FP8 weights and FP8 activations quantization
+FP8 = dict(
+    weights=QuantizationArgs(
+        num_bits=8,
+        type=QuantizationType.FLOAT,
+        strategy=QuantizationStrategy.TENSOR,
+        symmetric=True,
+        dynamic=False,
+    ),
+    input_activations=QuantizationArgs(
+        num_bits=8,
+        type=QuantizationType.FLOAT,
+        strategy=QuantizationStrategy.TENSOR,
+        symmetric=True,
+        dynamic=False,
+    ),
 )
 
-
+# FP8 weights and FP8 dynamic activations quantization
+FP8_DYNAMIC = dict(
+    weights=QuantizationArgs(
+        num_bits=8,
+        type=QuantizationType.FLOAT,
+        strategy=QuantizationStrategy.CHANNEL,
+        symmetric=True,
+        dynamic=False,
+    ),
+    input_activations=QuantizationArgs(
+        num_bits=8,
+        type=QuantizationType.FLOAT,
+        strategy=QuantizationStrategy.TOKEN,
+        symmetric=True,
+        dynamic=True,
+    ),
+)
 
 PRESET_SCHEMES = {
-
+    # Integer weight only schemes
+    "W8A16": W8A16,
     "W4A16": W4A16,
+    # Integer weight and activation schemes
+    "W8A8": W8A8,
+    "W4A8": W4A8,
+    # Float weight and activation schemes
+    "FP8": FP8,
+    "FP8_DYNAMIC": FP8_DYNAMIC,
 }
```
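The presets are plain dicts of `QuantizationArgs`, so they can be inspected directly. A quick look, assuming 0.5.0 is installed:

```python
from compressed_tensors.quantization.quant_scheme import PRESET_SCHEMES, W4A16

# 0.5.0 ships six presets: W8A16, W4A16, W8A8, W4A8, FP8, FP8_DYNAMIC
print(list(PRESET_SCHEMES))

# each preset maps argument names to QuantizationArgs; W4A16 is group-wise weight-only
print(W4A16["weights"].num_bits, W4A16["weights"].group_size)  # 4 128
print("input_activations" in W4A16)                            # False: weight-only
```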
compressed_tensors/quantization/utils/helpers.py

```diff
@@ -13,10 +13,13 @@
 # limitations under the License.
 
 import logging
-
+import re
+from typing import List, Optional, Tuple
 
 import torch
 from compressed_tensors.quantization.observers.base import Observer
+from compressed_tensors.quantization.quant_args import QuantizationArgs
+from compressed_tensors.quantization.quant_scheme import QuantizationScheme
 from torch.nn import Module
 from tqdm import tqdm
 
@@ -30,8 +33,12 @@ __all__ = [
     "calculate_compression_ratio",
     "get_torch_bit_depth",
     "can_quantize",
+    "parse_out_kv_cache_args",
+    "KV_CACHE_TARGETS",
+    "is_kv_cache_quant_scheme",
 ]
 
+KV_CACHE_TARGETS = ["re:.*k_proj", "re:.*v_proj"]
 _LOGGER: logging.Logger = logging.getLogger(__name__)
 
 
@@ -182,3 +189,62 @@ def calculate_compression_ratio(model: Module) -> float:
        total_uncompressed += uncompressed_bits * num_weights
 
    return total_uncompressed / total_compressed
+
+
+def is_kv_cache_quant_scheme(scheme: QuantizationScheme) -> bool:
+    """
+    Check whether the QuantizationScheme targets the kv cache.
+    It does if all the following criteria are met:
+    - the scheme targets either exactly match the KV_CACHE_TARGETS
+        or the match KV_CACHE_TARGETS regex pattern
+    - the scheme quantizes output_activations (we want to quantize the
+        outputs from the KV_CACHE_TARGETS, as their correspond to the
+        keys and values that are to be saved in the cache)
+
+    :param scheme: The QuantizationScheme to investigate
+    :return: boolean flag
+    """
+    if len(scheme.targets) == 1:
+        # match on the KV_CACHE_TARGETS regex pattern
+        # if there is only one target
+        is_match_targets = any(
+            [re.match(pattern[3:], scheme.targets[0]) for pattern in KV_CACHE_TARGETS]
+        )
+    else:
+        # match on the exact KV_CACHE_TARGETS
+        # if there are multiple targets
+        is_match_targets = set(KV_CACHE_TARGETS) == set(scheme.targets)
+
+    is_match_output_activations = scheme.output_activations is not None
+    return is_match_targets and is_match_output_activations
+
+
+def parse_out_kv_cache_args(
+    quant_scheme_to_layers: List[QuantizationScheme],
+) -> Tuple[Optional[QuantizationArgs], List[QuantizationScheme]]:
+    """
+    If possible, parse out the kv cache specific QuantizationArgs
+    from the list of the QuantizationSchemes. If no kv cache
+    specific QuantizationArgs available, this function acts
+    as an identity function
+
+    :param quant_scheme_to_layers: list of QuantizationSchemes
+    :return: kv_cache_args (optional) and the (remaining or original)
+        list of the QuantizationSchemes
+    """
+    kv_cache_quant_scheme_to_layers = [
+        scheme for scheme in quant_scheme_to_layers if is_kv_cache_quant_scheme(scheme)
+    ]
+    quant_scheme_to_layers = [
+        scheme
+        for scheme in quant_scheme_to_layers
+        if not is_kv_cache_quant_scheme(scheme)
+    ]
+
+    if kv_cache_quant_scheme_to_layers:
+        kv_cache_quant_scheme_to_layers = kv_cache_quant_scheme_to_layers[0]
+        kv_cache_args = kv_cache_quant_scheme_to_layers.output_activations
+    else:
+        kv_cache_args = None
+
+    return kv_cache_args, quant_scheme_to_layers
```
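How the two helpers interact, sketched under the assumption that `QuantizationScheme` accepts the `targets`, `weights`, and `output_activations` fields referenced in this diff:

```python
from compressed_tensors.quantization.quant_args import QuantizationArgs
from compressed_tensors.quantization.quant_scheme import QuantizationScheme
from compressed_tensors.quantization.utils import (
    is_kv_cache_quant_scheme,
    parse_out_kv_cache_args,
)

# output-activation quantization of k_proj/v_proj is recognized as a kv cache scheme
kv_scheme = QuantizationScheme(
    targets=["re:.*k_proj", "re:.*v_proj"],
    output_activations=QuantizationArgs(num_bits=8, symmetric=True),
)
# a plain weight quantization scheme is not
weight_scheme = QuantizationScheme(
    targets=["Linear"],
    weights=QuantizationArgs(num_bits=8, symmetric=True),
)

print(is_kv_cache_quant_scheme(kv_scheme))      # True
print(is_kv_cache_quant_scheme(weight_scheme))  # False

# the kv cache args are pulled out; the remaining schemes pass through unchanged
kv_cache_args, remaining = parse_out_kv_cache_args([kv_scheme, weight_scheme])
print(kv_cache_args.num_bits)  # 8
print(len(remaining))          # 1
```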
compressed_tensors/utils/helpers.py

```diff
@@ -12,13 +12,17 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
-
 from typing import Optional
 
+import torch
 from transformers import AutoConfig
 
 
-__all__ = [
+__all__ = [
+    "infer_compressor_from_model_config",
+    "fix_fsdp_module_name",
+    "tensor_follows_mask_structure",
+]
 
 FSDP_WRAPPER_NAME = "_fsdp_wrapped_module"
 
@@ -61,3 +65,28 @@ def fix_fsdp_module_name(name: str) -> str:
    return name.replace(FSDP_WRAPPER_NAME + ".", "").replace(
        "." + FSDP_WRAPPER_NAME, ""
    )
+
+
+def tensor_follows_mask_structure(tensor, mask: str = "2:4") -> bool:
+    """
+    :param tensor: tensor to check
+    :param mask: mask structure to check for, in the format "n:m"
+    :return: True if the tensor follows the mask structure, False otherwise.
+        Note, some weights can incidentally be zero, so we check for
+        atleast n zeros in each chunk of size m
+    """
+
+    n, m = tuple(map(int, mask.split(":")))
+    # Reshape the tensor into chunks of size m
+    tensor = tensor.view(-1, m)
+
+    # Count the number of zeros in each chunk
+    zero_counts = (tensor == 0).sum(dim=1)
+
+    # Check if the number of zeros in each chunk atleast n
+    # Greater than sign is needed as some weights can incidentally
+    # be zero
+    if not torch.all(zero_counts >= n).item():
+        raise ValueError()
+
+    return True
```
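A quick check of the new 2:4 helper; note that, as written in 0.5.0, a violation raises `ValueError` rather than returning `False`. Sketch assuming the package is installed; the tensors are illustrative:

```python
import torch
from compressed_tensors.utils.helpers import tensor_follows_mask_structure

# at least two zeros in every chunk of four satisfies the 2:4 structure
sparse = torch.tensor([[0.0, 1.3, 0.0, -0.7, 0.5, 0.0, 0.0, 2.1]])
print(tensor_follows_mask_structure(sparse, mask="2:4"))  # True

# a dense row violates the structure; the helper raises instead of returning False
dense = torch.tensor([[1.0, 2.0, 3.0, 4.0]])
try:
    tensor_follows_mask_structure(dense, mask="2:4")
except ValueError:
    print("not 2:4 structured")
```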