compressed-tensors 0.3.3__py3-none-any.whl → 0.5.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- compressed_tensors/base.py +3 -1
- compressed_tensors/compressors/__init__.py +9 -1
- compressed_tensors/compressors/base.py +12 -55
- compressed_tensors/compressors/dense.py +5 -5
- compressed_tensors/compressors/helpers.py +12 -12
- compressed_tensors/compressors/marlin_24.py +251 -0
- compressed_tensors/compressors/model_compressor.py +336 -0
- compressed_tensors/compressors/naive_quantized.py +144 -0
- compressed_tensors/compressors/pack_quantized.py +219 -0
- compressed_tensors/compressors/sparse_bitmask.py +4 -4
- compressed_tensors/config/base.py +9 -4
- compressed_tensors/config/dense.py +4 -4
- compressed_tensors/config/sparse_bitmask.py +3 -3
- compressed_tensors/quantization/lifecycle/__init__.py +2 -0
- compressed_tensors/quantization/lifecycle/apply.py +204 -31
- compressed_tensors/quantization/lifecycle/calibration.py +20 -1
- compressed_tensors/quantization/lifecycle/compressed.py +69 -0
- compressed_tensors/quantization/lifecycle/forward.py +214 -62
- compressed_tensors/quantization/lifecycle/frozen.py +4 -0
- compressed_tensors/quantization/lifecycle/helpers.py +53 -0
- compressed_tensors/quantization/lifecycle/initialize.py +62 -5
- compressed_tensors/quantization/observers/base.py +66 -23
- compressed_tensors/quantization/observers/helpers.py +69 -11
- compressed_tensors/quantization/observers/memoryless.py +17 -9
- compressed_tensors/quantization/observers/min_max.py +44 -13
- compressed_tensors/quantization/quant_args.py +47 -3
- compressed_tensors/quantization/quant_config.py +104 -23
- compressed_tensors/quantization/quant_scheme.py +183 -2
- compressed_tensors/quantization/utils/helpers.py +142 -8
- compressed_tensors/utils/__init__.py +4 -0
- compressed_tensors/utils/helpers.py +54 -7
- compressed_tensors/utils/offload.py +104 -0
- compressed_tensors/utils/permutations_24.py +65 -0
- compressed_tensors/utils/safetensors_load.py +3 -2
- compressed_tensors/utils/semi_structured_conversions.py +341 -0
- compressed_tensors/version.py +53 -0
- {compressed_tensors-0.3.3.dist-info → compressed_tensors-0.5.0.dist-info}/METADATA +47 -8
- compressed_tensors-0.5.0.dist-info/RECORD +48 -0
- {compressed_tensors-0.3.3.dist-info → compressed_tensors-0.5.0.dist-info}/WHEEL +1 -1
- compressed_tensors-0.3.3.dist-info/RECORD +0 -38
- {compressed_tensors-0.3.3.dist-info → compressed_tensors-0.5.0.dist-info}/LICENSE +0 -0
- {compressed_tensors-0.3.3.dist-info → compressed_tensors-0.5.0.dist-info}/top_level.txt +0 -0

compressed_tensors/quantization/quant_config.py

```diff
@@ -13,25 +13,31 @@
 # limitations under the License.
 
 from enum import Enum
-from typing import Dict, List, Optional
+from typing import Dict, List, Optional, Union
 
-from compressed_tensors.
-from compressed_tensors.quantization.
+from compressed_tensors.config import CompressionFormat
+from compressed_tensors.quantization.quant_args import QuantizationArgs
+from compressed_tensors.quantization.quant_scheme import (
+    QuantizationScheme,
+    preset_name_to_scheme,
+)
 from compressed_tensors.quantization.utils import (
     calculate_compression_ratio,
     is_module_quantized,
     iter_named_leaf_modules,
     module_type,
+    parse_out_kv_cache_args,
 )
 from pydantic import BaseModel, Field
 from torch.nn import Module
-from transformers import AutoConfig
 
 
 __all__ = [
     "QuantizationStatus",
     "QuantizationConfig",
     "LIFECYCLE_ORDER",
+    "DEFAULT_QUANTIZATION_METHOD",
+    "DEFAULT_QUANTIZATION_FORMAT",
 ]
 
 
```
```diff
@@ -62,10 +68,33 @@ class QuantizationStatus(str, Enum):
         return
 
     def __ge__(self, other):
+        if other is None:
+            return True
         if not isinstance(other, self.__class__):
             raise NotImplementedError
         return LIFECYCLE_ORDER.index(self) >= LIFECYCLE_ORDER.index(other)
 
+    def __gt__(self, other):
+        if other is None:
+            return True
+        if not isinstance(other, self.__class__):
+            raise NotImplementedError
+        return LIFECYCLE_ORDER.index(self) > LIFECYCLE_ORDER.index(other)
+
+    def __lt__(self, other):
+        if other is None:
+            return False
+        if not isinstance(other, self.__class__):
+            raise NotImplementedError
+        return LIFECYCLE_ORDER.index(self) < LIFECYCLE_ORDER.index(other)
+
+    def __le__(self, other):
+        if other is None:
+            return False
+        if not isinstance(other, self.__class__):
+            raise NotImplementedError
+        return LIFECYCLE_ORDER.index(self) <= LIFECYCLE_ORDER.index(other)
+
 
 LIFECYCLE_ORDER = [
     QuantizationStatus.INITIALIZED,
```
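For context, a minimal sketch (not part of the diff) of how these new rich comparisons behave: statuses are ordered by their position in `LIFECYCLE_ORDER`, and `None` is treated as sorting before every status.

```python
# Sketch only; assumes QuantizationStatus is importable from the module shown above.
from compressed_tensors.quantization.quant_config import QuantizationStatus

# Ordered by LIFECYCLE_ORDER: INITIALIZED comes before COMPRESSED
assert QuantizationStatus.COMPRESSED > QuantizationStatus.INITIALIZED
assert QuantizationStatus.INITIALIZED <= QuantizationStatus.COMPRESSED

# None sorts before every status per the new operators
assert QuantizationStatus.INITIALIZED > None
assert not (QuantizationStatus.INITIALIZED < None)
```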
```diff
@@ -74,6 +103,9 @@ LIFECYCLE_ORDER = [
     QuantizationStatus.COMPRESSED,
 ]
 
+DEFAULT_QUANTIZATION_METHOD = "compressed-tensors"
+DEFAULT_QUANTIZATION_FORMAT = "fakequant"
+
 
 class QuantizationConfig(BaseModel):
     """
```
```diff
@@ -81,45 +113,62 @@ class QuantizationConfig(BaseModel):
     mapped to a QuantizationScheme in config_groups.
 
     :param config_groups: dict of QuantizationSchemes specifying the quantization
-        settings for each quantized layer
+        settings for each quantized layer. A group could also be a reference to
+        a predefined scheme name, mapped to a list of its target layers/classes
     :param quant_method: a constant used to differentiate sparseML quantization from
         other quantization configs
     :param format: specifies how the quantized model is stored on disk
     :quantization_status: specifies the current status of all quantized layers. It is
-
+        assumed all layers are in the same state.
+    :param kv_cache_scheme: optional QuantizationArgs, that specify the
+        quantization of the kv cache. If None, kv cache is not quantized.
+        When applying kv cache quantization to transformer AutoModelForCausalLM,
+        the kv_cache_scheme gets converted into a QuantizationScheme that:
+        - targets the `q_proj` and `k_proj` modules of the model. The outputs
+          of those modules are the keys and values that might be cached
+        - quantizes the outputs of the aformentioned layers, so that
+          keys and values are compressed before storing them in the cache
+        There is an explicit assumption that the model contains modules with
+        `k_proj` and `v_proj` in their names. If this is not the case
+        and kv_cache_scheme != None, the quantization of kv cache will fail
     :global_compression_ratio: optional informational config to report the model
         compression ratio acheived by the quantization config
     :ignore: optional list of layers to ignore from config_groups. Layers in this list
         are not quantized even if they match up with a target in config_groups
     """
 
-    config_groups: Dict[str, QuantizationScheme]
-    quant_method: str =
-
+    config_groups: Dict[str, Union[QuantizationScheme, List[str]]]
+    quant_method: str = DEFAULT_QUANTIZATION_METHOD
+    kv_cache_scheme: Optional[QuantizationArgs] = None
+    format: str = DEFAULT_QUANTIZATION_FORMAT
     quantization_status: QuantizationStatus = QuantizationStatus.INITIALIZED
     global_compression_ratio: Optional[float] = None
     ignore: Optional[List[str]] = Field(default_factory=list)
 
-
-    def from_model_config(model_name_or_path) -> "QuantizationConfig":
+    def model_post_init(self, __context):
         """
-
-
-        :param pretrained_model_name_or_path: path to model config on disk or HF hub
-        :return: instantiated QuantizationConfig if config contains a quant config
+        updates any quantization schemes defined as presets to be fully loaded
+        schemes
         """
-
-
-
-
-
-
+        for group_name, targets_or_scheme in self.config_groups.items():
+            if isinstance(targets_or_scheme, QuantizationScheme):
+                continue  # scheme already defined
+            self.config_groups[group_name] = preset_name_to_scheme(
+                name=group_name,
+                targets=targets_or_scheme,
+            )
+
+    def to_dict(self):
+        # for compatibility with HFQuantizer
+        return self.dict()
 
     @staticmethod
-    def from_pretrained(
+    def from_pretrained(
+        model: Module, format: Optional[str] = None
+    ) -> Optional["QuantizationConfig"]:
         """
         Converts a model into its associated QuantizationConfig based on the
-        QuantizationScheme attached to each
+        QuantizationScheme attached to each quantized module
 
         :param model: model to calculate quantization scheme of
         :return: filled out QuantizationScheme for the input model
```
```diff
@@ -147,6 +196,9 @@ class QuantizationConfig(BaseModel):
             if not match_found:
                 quant_scheme_to_layers.append(scheme)
 
+        if len(quant_scheme_to_layers) == 0:  # No quantized layers
+            return None
+
         # clean up ignore list, we can leave out layers types if none of the
         # instances are quantized
         consolidated_ignore = []
```
```diff
@@ -157,15 +209,44 @@
             # else we leave it off the ignore list, doesn't fall under any of the
             # existing quantization schemes so it won't be quantized
 
+        kv_cache_args, quant_scheme_to_layers = parse_out_kv_cache_args(
+            quant_scheme_to_layers
+        )
+        kv_cache_scheme = (
+            kv_cache_args.model_dump() if kv_cache_args is not None else kv_cache_args
+        )
+
         config_groups = {}
         for idx, scheme in enumerate(quant_scheme_to_layers):
             group_name = "group_" + str(idx)
             config_groups[group_name] = scheme
 
+        # TODO: this is incorrect in compressed mode, since we are overwriting the
+        # original weight we lose the uncompressed bit_depth indo
         compression_ratio = calculate_compression_ratio(model)
+
+        if format is None:
+            if quantization_status == QuantizationStatus.COMPRESSED:
+                format = CompressionFormat.int_quantized.value
+            else:
+                format = CompressionFormat.dense.value
+
         return QuantizationConfig(
             config_groups=config_groups,
             quantization_status=quantization_status,
+            kv_cache_scheme=kv_cache_scheme,
             global_compression_ratio=compression_ratio,
+            format=format,
             ignore=consolidated_ignore,
         )
+
+    def requires_calibration_data(self):
+        for _, scheme in self.config_groups.items():
+            if scheme.input_activations is not None:
+                if not scheme.input_activations.dynamic:
+                    return True
+            if scheme.output_activations is not None:
+                if not scheme.output_activations.dynamic:
+                    return True
+
+        return False
```
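Taken together, a hedged sketch of how the updated config can now be built from a preset group name (the `W4A16` preset is defined in quant_scheme.py, shown further below); it assumes pydantic invokes `model_post_init` as a v2 post-init hook and that the import path matches the module in this diff.

```python
# Sketch only; not part of the diff.
from compressed_tensors.quantization.quant_config import QuantizationConfig

# A config group may now map a preset name to a list of targets;
# model_post_init expands it into a full QuantizationScheme.
config = QuantizationConfig(config_groups={"W4A16": ["Linear"]})

assert config.quant_method == "compressed-tensors"          # DEFAULT_QUANTIZATION_METHOD
assert config.config_groups["W4A16"].weights.num_bits == 4  # expanded W4A16 preset
assert config.requires_calibration_data() is False          # weight-only, no static activations
```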

compressed_tensors/quantization/quant_scheme.py

```diff
@@ -12,13 +12,22 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
+from copy import deepcopy
 from typing import List, Optional
 
-from compressed_tensors.quantization.quant_args import
+from compressed_tensors.quantization.quant_args import (
+    QuantizationArgs,
+    QuantizationStrategy,
+    QuantizationType,
+)
 from pydantic import BaseModel
 
 
-__all__ = [
+__all__ = [
+    "QuantizationScheme",
+    "preset_name_to_scheme",
+    "is_preset_scheme",
+]
 
 
 class QuantizationScheme(BaseModel):
```
```diff
@@ -37,3 +46,175 @@ class QuantizationScheme(BaseModel):
     weights: Optional[QuantizationArgs] = None
     input_activations: Optional[QuantizationArgs] = None
     output_activations: Optional[QuantizationArgs] = None
+
+    @classmethod
+    def default_scheme(
+        cls,
+        targets: Optional[List[str]] = None,
+    ):
+
+        if targets is None:
+            # default to quantizing all Linear layers
+            targets = ["Linear"]
+
+        # default to 8 bit integer symmetric quantization
+        # for weights
+        weights = QuantizationArgs(num_bits=8, symmetric=True)
+
+        # default to 8 bit integer asymmetric quantization
+        input_activations = QuantizationArgs(num_bits=8, symmetric=True)
+
+        # Do not quantize the output activations
+        # by default
+        output_activations = None
+
+        return cls(
+            targets=targets,
+            weights=weights,
+            input_activations=input_activations,
+            output_activations=output_activations,
+        )
+
+
+"""
+Pre-Set Quantization Scheme Args
+"""
+
+
+def preset_name_to_scheme(name: str, targets: List[str]) -> QuantizationScheme:
+    """
+    :param name: preset quantization settings name. must exist in upper case in
+        PRESET_SCHEMES
+    :param targets: list of quantization targets to be passed to the Scheme
+    :return: new QuantizationScheme for a given name with the given targets
+    """
+    name = name.upper()
+
+    if name not in PRESET_SCHEMES:
+        raise KeyError(
+            f"Unknown preset scheme name {name}, "
+            f"available names: {list(PRESET_SCHEMES.keys())}"
+        )
+
+    scheme_args = deepcopy(PRESET_SCHEMES[name])  # deepcopy to avoid args references
+    return QuantizationScheme(
+        targets=targets,
+        **scheme_args,
+    )
+
+
+def is_preset_scheme(name: str) -> bool:
+    """
+    :param name: preset quantization settings name
+    :return: True if the name is a preset scheme name
+    """
+    return name.upper() in PRESET_SCHEMES
+
+
+# 8 bit integer weights and 8 bit activations quantization
+W8A8 = dict(
+    weights=QuantizationArgs(
+        num_bits=8,
+        type=QuantizationType.INT,
+        strategy=QuantizationStrategy.CHANNEL,
+        symmetric=True,
+        dynamic=False,
+    ),
+    input_activations=QuantizationArgs(
+        num_bits=8,
+        type=QuantizationType.INT,
+        strategy=QuantizationStrategy.TOKEN,
+        symmetric=True,
+        dynamic=True,
+    ),
+)
+
+# 8 bit integer weights only quantization
+W8A16 = dict(
+    weights=QuantizationArgs(
+        num_bits=8,
+        type=QuantizationType.INT,
+        strategy=QuantizationStrategy.CHANNEL,
+        symmetric=True,
+        dynamic=False,
+    ),
+)
+
+# 4 bit integer weights only quantization
+W4A16 = dict(
+    weights=QuantizationArgs(
+        num_bits=4,
+        type=QuantizationType.INT,
+        strategy=QuantizationStrategy.GROUP,
+        group_size=128,
+        symmetric=True,
+        dynamic=False,
+    ),
+)
+
+# 4 bit integer weights and 8 bit activations quantization
+W4A8 = dict(
+    weights=QuantizationArgs(
+        num_bits=4,
+        type=QuantizationType.INT,
+        group_size=128,
+        strategy=QuantizationStrategy.GROUP,
+        symmetric=True,
+        dynamic=False,
+    ),
+    input_activations=QuantizationArgs(
+        num_bits=8,
+        type=QuantizationType.INT,
+        strategy=QuantizationStrategy.TOKEN,
+        symmetric=True,
+        dynamic=True,
+    ),
+)
+
+# FP8 weights and FP8 activations quantization
+FP8 = dict(
+    weights=QuantizationArgs(
+        num_bits=8,
+        type=QuantizationType.FLOAT,
+        strategy=QuantizationStrategy.TENSOR,
+        symmetric=True,
+        dynamic=False,
+    ),
+    input_activations=QuantizationArgs(
+        num_bits=8,
+        type=QuantizationType.FLOAT,
+        strategy=QuantizationStrategy.TENSOR,
+        symmetric=True,
+        dynamic=False,
+    ),
+)
+
+# FP8 weights and FP8 dynamic activations quantization
+FP8_DYNAMIC = dict(
+    weights=QuantizationArgs(
+        num_bits=8,
+        type=QuantizationType.FLOAT,
+        strategy=QuantizationStrategy.CHANNEL,
+        symmetric=True,
+        dynamic=False,
+    ),
+    input_activations=QuantizationArgs(
+        num_bits=8,
+        type=QuantizationType.FLOAT,
+        strategy=QuantizationStrategy.TOKEN,
+        symmetric=True,
+        dynamic=True,
+    ),
+)
+
+PRESET_SCHEMES = {
+    # Integer weight only schemes
+    "W8A16": W8A16,
+    "W4A16": W4A16,
+    # Integer weight and activation schemes
+    "W8A8": W8A8,
+    "W4A8": W4A8,
+    # Float weight and activation schemes
+    "FP8": FP8,
+    "FP8_DYNAMIC": FP8_DYNAMIC,
+}
```
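A short usage sketch of the new preset helpers (behaviour as defined by the code added above):

```python
# Sketch only; uses the helpers introduced in quant_scheme.py above.
from compressed_tensors.quantization.quant_scheme import (
    QuantizationScheme,
    is_preset_scheme,
    preset_name_to_scheme,
)

assert is_preset_scheme("fp8_dynamic")                 # lookup is case-insensitive

scheme = preset_name_to_scheme("W8A8", targets=["Linear"])
assert scheme.weights.num_bits == 8
assert scheme.input_activations.dynamic is True        # per-token dynamic activations

default = QuantizationScheme.default_scheme()          # 8-bit weights + inputs on Linear
assert default.targets == ["Linear"]
```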

compressed_tensors/quantization/utils/helpers.py

```diff
@@ -12,21 +12,50 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
-
+import logging
+import re
+from typing import List, Optional, Tuple
 
 import torch
+from compressed_tensors.quantization.observers.base import Observer
+from compressed_tensors.quantization.quant_args import QuantizationArgs
+from compressed_tensors.quantization.quant_scheme import QuantizationScheme
 from torch.nn import Module
 from tqdm import tqdm
 
 
 __all__ = [
+    "infer_quantization_status",
     "is_module_quantized",
    "is_model_quantized",
    "iter_named_leaf_modules",
    "module_type",
    "calculate_compression_ratio",
+    "get_torch_bit_depth",
+    "can_quantize",
+    "parse_out_kv_cache_args",
+    "KV_CACHE_TARGETS",
+    "is_kv_cache_quant_scheme",
 ]
 
+KV_CACHE_TARGETS = ["re:.*k_proj", "re:.*v_proj"]
+_LOGGER: logging.Logger = logging.getLogger(__name__)
+
+
+def infer_quantization_status(model: Module) -> Optional["QuantizationStatus"]:  # noqa
+    """
+    Checks the quantization status of a model. Assumes all modules in the model have
+    the same status, so only the first quantized model is checked.
+
+    :param model: model to check quantization status for
+    :return: quantization status if the model is quantized, otherwise None
+    """
+    for module in model.modules():
+        status = getattr(module, "quantization_status", None)
+        if status is not None:
+            return status
+    return None
+
 
 def is_module_quantized(module: Module) -> bool:
     """
```
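A minimal sketch of the new `infer_quantization_status` helper, assuming the utils package re-exports it as listed in `__all__`:

```python
# Sketch only; a plain module has no quantization_status attribute, so None is returned.
import torch
from compressed_tensors.quantization.utils import infer_quantization_status

model = torch.nn.Linear(8, 8)
assert infer_quantization_status(model) is None
```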
```diff
@@ -78,11 +107,60 @@ def module_type(module: Module) -> str:
 
 
 def iter_named_leaf_modules(model: Module) -> Tuple[str, Module]:
-
-
+    """
+    Yields modules that do not have any submodules except observers. The observers
+    themselves are not yielded
+
+    :param model: model to get leaf modules of
+    :returns: generator tuple of (name, leaf_submodule)
+    """
     for name, submodule in model.named_modules():
-
+        children = list(submodule.children())
+        if len(children) == 0 and not isinstance(submodule, Observer):
             yield name, submodule
+        else:
+            has_non_observer_children = False
+            for child in children:
+                if not isinstance(child, Observer):
+                    has_non_observer_children = True
+
+            if not has_non_observer_children:
+                yield name, submodule
+
+
+def get_torch_bit_depth(value: torch.Tensor) -> int:
+    """
+    Determine the number of bits used to represent the dtype of a tensor
+
+    :param value: tensor to check bit depth of
+    :return: bit depth of each element in the value tensor
+    """
+    try:
+        bit_depth = torch.finfo(value.dtype).bits
+    except TypeError:
+        bit_depth = torch.iinfo(value.dtype).bits
+
+    return bit_depth
+
+
+def can_quantize(value: torch.Tensor, quant_args: "QuantizationArgs") -> bool:  # noqa
+    """
+    Checks if value can be quantized by quant_args.
+
+    :param value: tensor to check for quantization
+    :param quant_args: QuantizationArgs to use for quantization
+    :return: False if value is already quantized to quant_args or value is incompatible
+        with quant_args, True if value can be quantized with quant_args
+    """
+    bit_depth = get_torch_bit_depth(value)
+    requested_depth = quant_args.num_bits
+    if bit_depth < quant_args.num_bits:
+        _LOGGER.warn(
+            f"Can't quantize tensor with bit depth {bit_depth} to {requested_depth}."
+            "The QuantizationArgs provided are not compatible with the input tensor."
+        )
+
+    return bit_depth > quant_args.num_bits
 
 
 def calculate_compression_ratio(model: Module) -> float:
```
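A hedged sketch of the new bit-depth helpers, assuming the utils package re-exports them as listed in `__all__`:

```python
# Sketch only; behaviour follows the implementations added above.
import torch
from compressed_tensors.quantization.quant_args import QuantizationArgs
from compressed_tensors.quantization.utils import can_quantize, get_torch_bit_depth

weight = torch.randn(16, 16, dtype=torch.float16)
assert get_torch_bit_depth(weight) == 16                    # via torch.finfo

int4_args = QuantizationArgs(num_bits=4, symmetric=True)
assert can_quantize(weight, int4_args)                      # 16 -> 4 bits has headroom

already_int8 = torch.zeros(4, 4, dtype=torch.int8)
assert not can_quantize(already_int8, QuantizationArgs(num_bits=8))  # no headroom left
```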
```diff
@@ -101,10 +179,7 @@ def calculate_compression_ratio(model: Module) -> float:
         desc="Calculating quantization compression ratio",
     ):
         for parameter in model.parameters():
-
-                uncompressed_bits = torch.finfo(parameter.dtype).bits
-            except TypeError:
-                uncompressed_bits = torch.iinfo(parameter.dtype).bits
+            uncompressed_bits = get_torch_bit_depth(parameter)
             compressed_bits = uncompressed_bits
             if is_module_quantized(submodule):
                 compressed_bits = submodule.quantization_scheme.weights.num_bits
```
```diff
@@ -114,3 +189,62 @@
         total_uncompressed += uncompressed_bits * num_weights
 
     return total_uncompressed / total_compressed
+
+
+def is_kv_cache_quant_scheme(scheme: QuantizationScheme) -> bool:
+    """
+    Check whether the QuantizationScheme targets the kv cache.
+    It does if all the following criteria are met:
+    - the scheme targets either exactly match the KV_CACHE_TARGETS
+        or the match KV_CACHE_TARGETS regex pattern
+    - the scheme quantizes output_activations (we want to quantize the
+        outputs from the KV_CACHE_TARGETS, as their correspond to the
+        keys and values that are to be saved in the cache)
+
+    :param scheme: The QuantizationScheme to investigate
+    :return: boolean flag
+    """
+    if len(scheme.targets) == 1:
+        # match on the KV_CACHE_TARGETS regex pattern
+        # if there is only one target
+        is_match_targets = any(
+            [re.match(pattern[3:], scheme.targets[0]) for pattern in KV_CACHE_TARGETS]
+        )
+    else:
+        # match on the exact KV_CACHE_TARGETS
+        # if there are multiple targets
+        is_match_targets = set(KV_CACHE_TARGETS) == set(scheme.targets)
+
+    is_match_output_activations = scheme.output_activations is not None
+    return is_match_targets and is_match_output_activations
+
+
+def parse_out_kv_cache_args(
+    quant_scheme_to_layers: List[QuantizationScheme],
+) -> Tuple[Optional[QuantizationArgs], List[QuantizationScheme]]:
+    """
+    If possible, parse out the kv cache specific QuantizationArgs
+    from the list of the QuantizationSchemes. If no kv cache
+    specific QuantizationArgs available, this function acts
+    as an identity function
+
+    :param quant_scheme_to_layers: list of QuantizationSchemes
+    :return: kv_cache_args (optional) and the (remaining or original)
+        list of the QuantizationSchemes
+    """
+    kv_cache_quant_scheme_to_layers = [
+        scheme for scheme in quant_scheme_to_layers if is_kv_cache_quant_scheme(scheme)
+    ]
+    quant_scheme_to_layers = [
+        scheme
+        for scheme in quant_scheme_to_layers
+        if not is_kv_cache_quant_scheme(scheme)
+    ]
+
+    if kv_cache_quant_scheme_to_layers:
+        kv_cache_quant_scheme_to_layers = kv_cache_quant_scheme_to_layers[0]
+        kv_cache_args = kv_cache_quant_scheme_to_layers.output_activations
+    else:
+        kv_cache_args = None
+
+    return kv_cache_args, quant_scheme_to_layers
```
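Finally, a hedged sketch of the kv-cache parsing flow these helpers enable (mirroring how `QuantizationConfig.from_pretrained` uses them earlier in this diff); constructor arguments follow the models defined above.

```python
# Sketch only; not part of the diff.
from compressed_tensors.quantization.quant_args import QuantizationArgs
from compressed_tensors.quantization.quant_scheme import QuantizationScheme
from compressed_tensors.quantization.utils import (
    KV_CACHE_TARGETS,
    is_kv_cache_quant_scheme,
    parse_out_kv_cache_args,
)

kv_scheme = QuantizationScheme(
    targets=KV_CACHE_TARGETS,                     # ["re:.*k_proj", "re:.*v_proj"]
    output_activations=QuantizationArgs(num_bits=8, symmetric=True),
)
linear_scheme = QuantizationScheme(
    targets=["Linear"],
    weights=QuantizationArgs(num_bits=8, symmetric=True),
)

assert is_kv_cache_quant_scheme(kv_scheme)
assert not is_kv_cache_quant_scheme(linear_scheme)

kv_args, remaining = parse_out_kv_cache_args([kv_scheme, linear_scheme])
assert kv_args == kv_scheme.output_activations    # kv-cache args pulled out of the list
assert remaining == [linear_scheme]
```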