compressed-tensors 0.3.3__py3-none-any.whl → 0.5.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- compressed_tensors/base.py +3 -1
- compressed_tensors/compressors/__init__.py +9 -1
- compressed_tensors/compressors/base.py +12 -55
- compressed_tensors/compressors/dense.py +5 -5
- compressed_tensors/compressors/helpers.py +12 -12
- compressed_tensors/compressors/marlin_24.py +251 -0
- compressed_tensors/compressors/model_compressor.py +336 -0
- compressed_tensors/compressors/naive_quantized.py +144 -0
- compressed_tensors/compressors/pack_quantized.py +219 -0
- compressed_tensors/compressors/sparse_bitmask.py +4 -4
- compressed_tensors/config/base.py +9 -4
- compressed_tensors/config/dense.py +4 -4
- compressed_tensors/config/sparse_bitmask.py +3 -3
- compressed_tensors/quantization/lifecycle/__init__.py +2 -0
- compressed_tensors/quantization/lifecycle/apply.py +204 -31
- compressed_tensors/quantization/lifecycle/calibration.py +20 -1
- compressed_tensors/quantization/lifecycle/compressed.py +69 -0
- compressed_tensors/quantization/lifecycle/forward.py +214 -62
- compressed_tensors/quantization/lifecycle/frozen.py +4 -0
- compressed_tensors/quantization/lifecycle/helpers.py +53 -0
- compressed_tensors/quantization/lifecycle/initialize.py +62 -5
- compressed_tensors/quantization/observers/base.py +66 -23
- compressed_tensors/quantization/observers/helpers.py +69 -11
- compressed_tensors/quantization/observers/memoryless.py +17 -9
- compressed_tensors/quantization/observers/min_max.py +44 -13
- compressed_tensors/quantization/quant_args.py +47 -3
- compressed_tensors/quantization/quant_config.py +104 -23
- compressed_tensors/quantization/quant_scheme.py +183 -2
- compressed_tensors/quantization/utils/helpers.py +142 -8
- compressed_tensors/utils/__init__.py +4 -0
- compressed_tensors/utils/helpers.py +54 -7
- compressed_tensors/utils/offload.py +104 -0
- compressed_tensors/utils/permutations_24.py +65 -0
- compressed_tensors/utils/safetensors_load.py +3 -2
- compressed_tensors/utils/semi_structured_conversions.py +341 -0
- compressed_tensors/version.py +53 -0
- {compressed_tensors-0.3.3.dist-info → compressed_tensors-0.5.0.dist-info}/METADATA +47 -8
- compressed_tensors-0.5.0.dist-info/RECORD +48 -0
- {compressed_tensors-0.3.3.dist-info → compressed_tensors-0.5.0.dist-info}/WHEEL +1 -1
- compressed_tensors-0.3.3.dist-info/RECORD +0 -38
- {compressed_tensors-0.3.3.dist-info → compressed_tensors-0.5.0.dist-info}/LICENSE +0 -0
- {compressed_tensors-0.3.3.dist-info → compressed_tensors-0.5.0.dist-info}/top_level.txt +0 -0
compressed_tensors/compressors/sparse_bitmask.py

@@ -17,7 +17,7 @@ from typing import Dict, Generator, List, Tuple, Union
 
 import numpy
 import torch
-from compressed_tensors.compressors import ModelCompressor
+from compressed_tensors.compressors import Compressor
 from compressed_tensors.config import CompressionFormat
 from compressed_tensors.utils import get_nested_weight_mappings, merge_names
 from safetensors import safe_open
@@ -37,8 +37,8 @@ __all__ = [
 _LOGGER: logging.Logger = logging.getLogger(__name__)
 
 
-@ModelCompressor.register(name=CompressionFormat.sparse_bitmask.value)
-class BitmaskCompressor(ModelCompressor):
+@Compressor.register(name=CompressionFormat.sparse_bitmask.value)
+class BitmaskCompressor(Compressor):
     """
     Compression for sparse models using bitmasks. Non-zero weights are stored in a 1d
     values tensor, with their locations stored in a 2d bitmask
@@ -72,7 +72,7 @@ class BitmaskCompressor(ModelCompressor):
         return compressed_dict
 
     def decompress(
-        self, path_to_model_or_tensors: str, device: str = "cpu"
+        self, path_to_model_or_tensors: str, device: str = "cpu", **kwargs
     ) -> Generator[Tuple[str, Tensor], None, None]:
         """
         Reads a bitmask compressed state dict located
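The rename from ModelCompressor to Compressor matters because compressors are
looked up by their format string through the RegistryMixin pattern shown
above. A minimal sketch of that lookup, assuming RegistryMixin's
load_from_registry(name, **kwargs) factory:

    from compressed_tensors.compressors import Compressor
    from compressed_tensors.config import CompressionFormat

    # "sparse-bitmask" resolves to BitmaskCompressor because the
    # @Compressor.register(...) decorator added it to the registry at import time
    compressor = Compressor.load_from_registry(CompressionFormat.sparse_bitmask.value)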
compressed_tensors/config/base.py

@@ -19,17 +19,22 @@ from compressed_tensors.registry import RegistryMixin
 from pydantic import BaseModel
 
 
-__all__ = ["CompressionConfig", "CompressionFormat"]
+__all__ = ["SparsityCompressionConfig", "CompressionFormat"]
 
 
 class CompressionFormat(Enum):
-    dense_sparsity = "dense-sparsity"
+    dense = "dense"
     sparse_bitmask = "sparse-bitmask"
+    int_quantized = "int-quantized"
+    float_quantized = "float-quantized"
+    naive_quantized = "naive-quantized"
+    pack_quantized = "pack-quantized"
+    marlin_24 = "marlin-24"
 
 
-class CompressionConfig(RegistryMixin, BaseModel):
+class SparsityCompressionConfig(RegistryMixin, BaseModel):
     """
-    Base data class for storing compression parameters
+    Base data class for storing sparsity compression parameters
 
     :param format: name of compression format
     :param global_sparsity: average sparsity of the entire model
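The enum now covers the quantized formats introduced in this release (see
naive_quantized.py, pack_quantized.py, and marlin_24.py in the file list
above). Each member's value is the string serialized into a model's
compression config; enumerating them:

    from compressed_tensors.config import CompressionFormat

    for fmt in CompressionFormat:
        print(fmt.name, "->", fmt.value)
    # dense -> dense
    # sparse_bitmask -> sparse-bitmask
    # int_quantized -> int-quantized
    # float_quantized -> float-quantized
    # naive_quantized -> naive-quantized
    # pack_quantized -> pack-quantized
    # marlin_24 -> marlin-24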
compressed_tensors/config/dense.py

@@ -14,14 +14,14 @@
 
 from typing import Optional
 
-from compressed_tensors.config import CompressionConfig, CompressionFormat
+from compressed_tensors.config import CompressionFormat, SparsityCompressionConfig
 
 
 __all__ = ["DenseSparsityConfig"]
 
 
-@CompressionConfig.register(name=CompressionFormat.dense_sparsity.value)
-class DenseSparsityConfig(CompressionConfig):
+@SparsityCompressionConfig.register(name=CompressionFormat.dense.value)
+class DenseSparsityConfig(SparsityCompressionConfig):
     """
     Identity configuration for storing a sparse model in
     an uncompressed dense format
@@ -31,6 +31,6 @@ class DenseSparsityConfig(CompressionConfig):
     "unstructured", "2:4", "8:16" etc
     """
 
-    format: str = CompressionFormat.dense_sparsity.value
+    format: str = CompressionFormat.dense.value
     global_sparsity: Optional[float] = 0.0
     sparsity_structure: Optional[str] = "unstructured"
compressed_tensors/config/sparse_bitmask.py

@@ -14,14 +14,14 @@
 
 from typing import Optional
 
-from compressed_tensors.config import CompressionConfig, CompressionFormat
+from compressed_tensors.config import CompressionFormat, SparsityCompressionConfig
 
 
 __all__ = ["BitmaskConfig"]
 
 
-@CompressionConfig.register(name=CompressionFormat.sparse_bitmask.value)
-class BitmaskConfig(CompressionConfig):
+@SparsityCompressionConfig.register(name=CompressionFormat.sparse_bitmask.value)
+class BitmaskConfig(SparsityCompressionConfig):
     """
     Configuration for storing a sparse model using
     bitmask compression
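As with the compressors, each sparsity config class registers itself under its
format string, so a config can be built by name. A sketch under the same
load_from_registry assumption as above:

    from compressed_tensors.config import SparsityCompressionConfig

    # resolves to BitmaskConfig through the decorator above
    config = SparsityCompressionConfig.load_from_registry(
        "sparse-bitmask", global_sparsity=0.5, sparsity_structure="unstructured"
    )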
compressed_tensors/quantization/lifecycle/apply.py

@@ -12,22 +12,38 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
+import logging
 import re
 from collections import OrderedDict
-from typing import Dict, Iterable, Optional
+from typing import Dict, Iterable, List, Optional
+from typing import OrderedDict as OrderedDictType
+from typing import Union
 
+import torch
 from compressed_tensors.quantization.lifecycle.calibration import (
     set_module_for_calibration,
 )
+from compressed_tensors.quantization.lifecycle.compressed import (
+    compress_quantized_weights,
+)
 from compressed_tensors.quantization.lifecycle.frozen import freeze_module_quantization
 from compressed_tensors.quantization.lifecycle.initialize import (
     initialize_module_for_quantization,
 )
+from compressed_tensors.quantization.quant_args import QuantizationArgs
 from compressed_tensors.quantization.quant_config import (
     QuantizationConfig,
     QuantizationStatus,
 )
-from compressed_tensors.quantization.utils import iter_named_leaf_modules
+from compressed_tensors.quantization.quant_scheme import QuantizationScheme
+from compressed_tensors.quantization.utils import (
+    KV_CACHE_TARGETS,
+    infer_quantization_status,
+    is_kv_cache_quant_scheme,
+    iter_named_leaf_modules,
+)
+from compressed_tensors.utils.helpers import fix_fsdp_module_name
+from compressed_tensors.utils.offload import update_parameter_data
 from compressed_tensors.utils.safetensors_load import get_safetensors_folder
 from torch.nn import Module
 
@@ -36,13 +52,16 @@ __all__ = [
     "load_pretrained_quantization",
     "apply_quantization_config",
     "apply_quantization_status",
-    "find_first_name_or_class_match",
+    "find_name_or_class_matches",
 ]
 
 from compressed_tensors.quantization.utils.helpers import is_module_quantized
 from compressed_tensors.utils.safetensors_load import get_quantization_state_dict
 
 
+_LOGGER = logging.getLogger(__name__)
+
+
 def load_pretrained_quantization(model: Module, model_name_or_path: str):
     """
     Loads the quantization parameters (scale and zero point) from model_name_or_path to
@@ -84,7 +103,7 @@ def load_pretrained_quantization(model: Module, model_name_or_path: str):
     )
 
 
-def apply_quantization_config(model: Module, config: QuantizationConfig):
+def apply_quantization_config(model: Module, config: QuantizationConfig) -> Dict:
     """
     Initializes the model for quantization in-place based on the given config
 
@@ -94,21 +113,73 @@ def apply_quantization_config(model: Module, config: QuantizationConfig):
     # build mapping of targets to schemes for easier matching
     # use ordered dict to preserve target ordering in config
     target_to_scheme = OrderedDict()
+    config = process_quantization_config(config)
+    names_to_scheme = OrderedDict()
     for scheme in config.config_groups.values():
         for target in scheme.targets:
             target_to_scheme[target] = scheme
 
+    # list of submodules to ignore
+    ignored_submodules = []
     # mark appropriate layers for quantization by setting their quantization schemes
     for name, submodule in iter_named_leaf_modules(model):
-        if find_first_name_or_class_match(name, submodule, config.ignore):
+        # potentially fix module name to remove FSDP wrapper prefix
+        name = fix_fsdp_module_name(name)
+        if find_name_or_class_matches(name, submodule, config.ignore):
+            ignored_submodules.append(name)
             continue  # layer matches ignore list, continue
-        target = find_first_name_or_class_match(name, submodule, target_to_scheme)
-        if target is not None:
+        targets = find_name_or_class_matches(name, submodule, target_to_scheme)
+        if targets:
             # target matched - add layer and scheme to target list
-            submodule.quantization_scheme = target_to_scheme[target]
+            submodule.quantization_scheme = _scheme_from_targets(
+                target_to_scheme, targets, name
+            )
+            names_to_scheme[name] = submodule.quantization_scheme.weights
 
+
+    if config.ignore is not None and ignored_submodules is not None:
+        if set(config.ignore) - set(ignored_submodules):
+            _LOGGER.warning(
+                "Some layers that were to be ignored were "
+                "not found in the model: "
+                f"{set(config.ignore) - set(ignored_submodules)}"
+            )
     # apply current quantization status across all targeted layers
+
     apply_quantization_status(model, config.quantization_status)
+    return names_to_scheme
+
+
+def process_quantization_config(config: QuantizationConfig) -> QuantizationConfig:
+    """
+    Preprocess the raw QuantizationConfig
+
+    :param config: the raw QuantizationConfig
+    :return: the processed QuantizationConfig
+    """
+    if config.kv_cache_scheme is not None:
+        config = process_kv_cache_config(config)
+
+    return config
+
+
+def process_kv_cache_config(
+    config: QuantizationConfig, targets: Union[List[str], str] = KV_CACHE_TARGETS
+) -> QuantizationConfig:
+    """
+    Reformulate the `config.kv_cache` as a `config_group`
+    and add it to the set of existing `config.groups`
+
+    :param config: the QuantizationConfig
+    :return: the QuantizationConfig with additional "kv_cache" group
+    """
+    kv_cache_dict = config.kv_cache_scheme.model_dump()
+    kv_cache_scheme = QuantizationScheme(
+        output_activations=QuantizationArgs(**kv_cache_dict),
+        targets=targets,
+    )
+    kv_cache_group = dict(kv_cache=kv_cache_scheme)
+    config.config_groups.update(kv_cache_group)
+    return config
 
 
 def apply_quantization_status(model: Module, status: QuantizationStatus):
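A hedged sketch of what process_kv_cache_config accomplishes: the standalone
kv_cache_scheme is rewritten as an output-activation scheme targeting the
attention projections named by KV_CACHE_TARGETS and appended as a "kv_cache"
config group. The concrete argument values below are hypothetical:

    from compressed_tensors.quantization import (
        QuantizationArgs,
        QuantizationConfig,
        QuantizationScheme,
    )
    from compressed_tensors.quantization.lifecycle.apply import (
        process_quantization_config,
    )

    config = QuantizationConfig(
        config_groups={
            "group_0": QuantizationScheme(
                targets=["Linear"], weights=QuantizationArgs(num_bits=8)
            )
        },
        kv_cache_scheme=QuantizationArgs(num_bits=8),
    )
    config = process_quantization_config(config)
    assert "kv_cache" in config.config_groups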
@@ -118,41 +189,73 @@ def apply_quantization_status(model: Module, status: QuantizationStatus):
     :param model: model to apply quantization to
     :param status: status to update the module to
     """
-    if status >= QuantizationStatus.INITIALIZED:
+    current_status = infer_quantization_status(model)
+
+    if status >= QuantizationStatus.INITIALIZED > current_status:
         model.apply(initialize_module_for_quantization)
-    if status >= QuantizationStatus.CALIBRATION:
-        model.apply(set_module_for_calibration)
-    if status >= QuantizationStatus.FROZEN:
+
+    if current_status < status >= QuantizationStatus.CALIBRATION > current_status:
+        # only quantize weights up front when our end goal state is calibration,
+        # weight quantization parameters are already loaded for frozen/compressed
+        quantize_weights_upfront = status == QuantizationStatus.CALIBRATION
+        model.apply(
+            lambda module: set_module_for_calibration(
+                module, quantize_weights_upfront=quantize_weights_upfront
+            )
+        )
+    if current_status < status >= QuantizationStatus.FROZEN > current_status:
         model.apply(freeze_module_quantization)
 
+    if current_status < status >= QuantizationStatus.COMPRESSED > current_status:
+        model.apply(compress_quantized_weights)
 
-def find_first_name_or_class_match(
+
+def find_name_or_class_matches(
     name: str, module: Module, targets: Iterable[str], check_contains: bool = False
-) -> Optional[str]:
-    # returns first element of targets that matches the given name
-    # if no name matches, returns first target that matches the class name
-    # returns None otherwise
-    return _find_first_match(name, targets) or _find_first_match(
-        module.__class__.__name__, targets, check_contains
-    )
+) -> List[str]:
+    """
+    Returns all targets that match the given name or the class name.
+    Returns empty list otherwise.
+    The order of the output `matches` list matters.
+    The entries are sorted in the following order:
+        1. matches on exact strings
+        2. matches on regex patterns
+        3. matches on module names
+    """
+    targets = sorted(targets, key=lambda x: ("re:" in x, x))
+    if isinstance(targets, Iterable):
+        matches = _find_matches(name, targets) + _find_matches(
+            module.__class__.__name__, targets, check_contains
+        )
+        matches = [match for match in matches if match is not None]
+        return matches
 
 
-def _find_first_match(
+def _find_matches(
     value: str, targets: Iterable[str], check_contains: bool = False
-) -> Optional[str]:
-    # returns first target that matches the given value either
+) -> List[str]:
+    # returns all the targets that match value either
     # exactly or as a regex after 're:'. if check_contains is set to True,
     # additionally checks if the target string is contained with value.
+    matches = []
     for target in targets:
         if target.startswith("re:"):
             pattern = target[3:]
             if re.match(pattern, value):
-                return target
+                matches.append(target)
         elif check_contains:
             if target.lower() in value.lower():
-                return target
+                matches.append(target)
         elif target == value:
-            return target
+            matches.append(target)
+    return matches
+
+
+def _infer_status(model: Module) -> Optional[QuantizationStatus]:
+    for module in model.modules():
+        status = getattr(module, "quantization_status", None)
+        if status is not None:
+            return status
     return None
 
 
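Two details in this hunk are easy to misread. First, the chained comparisons:
in Python, "status >= QuantizationStatus.INITIALIZED > current_status" is
shorthand for "status >= INITIALIZED and INITIALIZED > current_status", i.e.
the requested status is at least INITIALIZED and the model has not reached it
yet, so each lifecycle step runs at most once. Second, the sort key
("re:" in x, x) in find_name_or_class_matches orders exact names before regex
patterns, which is what lets _scheme_from_targets treat the first match as the
highest-priority scheme. A standalone illustration in plain Python,
independent of the library:

    import re

    targets = ["re:.*q_proj", "Linear", "model.layers.0.q_proj"]
    # exact strings sort before regex patterns (False sorts before True)
    ordered = sorted(targets, key=lambda x: ("re:" in x, x))
    print(ordered)  # ['Linear', 'model.layers.0.q_proj', 're:.*q_proj']

    # name matching mirrors _find_matches: exact match first, then regex
    name = "model.layers.0.q_proj"
    matches = [
        t for t in ordered
        if t == name or (t.startswith("re:") and re.match(t[3:], name))
    ]
    print(matches)  # ['model.layers.0.q_proj', 're:.*q_proj']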
@@ -170,9 +273,79 @@ def _load_quant_args_from_state_dict(
     """
     scale_name = f"{base_name}_scale"
     zp_name = f"{base_name}_zero_point"
-    device = next(module.parameters()).device
 
-    scale = getattr(module, scale_name)
-    zp = getattr(module, zp_name)
-    scale.data = state_dict[f"{module_name}.{scale_name}"].to(device)
-    zp.data = state_dict[f"{module_name}.{zp_name}"].to(device)
+    state_dict_scale = state_dict.get(f"{module_name}.{scale_name}", None)
+    state_dict_zp = state_dict.get(f"{module_name}.{zp_name}", None)
+
+    if state_dict_scale is not None:
+        # module is quantized
+        update_parameter_data(module, state_dict_scale, scale_name)
+        if state_dict_zp is None:
+            # fill in zero point for symmetric quantization
+            state_dict_zp = torch.zeros_like(state_dict_scale, device="cpu")
+        update_parameter_data(module, state_dict_zp, zp_name)
+
+
+def _scheme_from_targets(
+    target_to_scheme: OrderedDictType[str, QuantizationScheme],
+    targets: List[str],
+    name: str,
+) -> QuantizationScheme:
+    if len(targets) == 1:
+        # if `targets` iterable contains a single element
+        # use it as the key
+        return target_to_scheme[targets[0]]
+
+    # otherwise, we need to merge QuantizationSchemes corresponding
+    # to multiple targets. This is most likely because `name` module
+    # is being target both as an ordinary quantization target, as well
+    # as kv cache quantization target
+    schemes_to_merge = [target_to_scheme[target] for target in targets]
+    return _merge_schemes(schemes_to_merge, name)
+
+
+def _merge_schemes(
+    schemes_to_merge: List[QuantizationScheme], name: str
+) -> QuantizationScheme:
+
+    kv_cache_quantization_scheme = [
+        scheme for scheme in schemes_to_merge if is_kv_cache_quant_scheme(scheme)
+    ]
+    if not kv_cache_quantization_scheme:
+        # if the schemes_to_merge do not contain any
+        # kv cache QuantizationScheme
+        # return the first scheme (the prioritized one,
+        # since the order of schemes_to_merge matters)
+        return schemes_to_merge[0]
+    else:
+        # fetch the kv cache QuantizationScheme and the highest
+        # priority non-kv cache QuantizationScheme and merge them
+        kv_cache_quantization_scheme = kv_cache_quantization_scheme[0]
+        quantization_scheme = [
+            scheme
+            for scheme in schemes_to_merge
+            if not is_kv_cache_quant_scheme(scheme)
+        ][0]
+        schemes_to_merge = [kv_cache_quantization_scheme, quantization_scheme]
+    merged_scheme = {}
+    for scheme in schemes_to_merge:
+        scheme_dict = {
+            k: v for k, v in scheme.model_dump().items() if v is not None
+        }
+        # when merging multiple schemes, the final target will be
+        # the `name` argument - hence erase the original targets
+        del scheme_dict["targets"]
+        # make sure that schemes do not "clash" with each other
+        overlapping_keys = set(merged_scheme.keys()) & set(scheme_dict.keys())
+        if overlapping_keys:
+            raise ValueError(
+                f"The module: {name} is being modified by two clashing "
+                f"quantization schemes, that jointly try to override "
+                f"properties: {overlapping_keys}. Fix the quantization config "
+                "so that it is not ambiguous."
+            )
+        merged_scheme.update(scheme_dict)
+
+    merged_scheme.update(targets=[name])
+
+    return QuantizationScheme(**merged_scheme)
compressed_tensors/quantization/lifecycle/calibration.py

@@ -16,6 +16,7 @@
 import logging
 
 from compressed_tensors.quantization.quant_config import QuantizationStatus
+from compressed_tensors.utils import is_module_offloaded, update_parameter_data
 from torch.nn import Module
 
 
@@ -27,7 +28,7 @@ __all__ = [
 _LOGGER = logging.getLogger(__name__)
 
 
-def set_module_for_calibration(module: Module):
+def set_module_for_calibration(module: Module, quantize_weights_upfront: bool = True):
     """
     marks a layer as ready for calibration which activates observers
     to update scales and zero points on each forward pass
@@ -35,6 +36,8 @@ def set_module_for_calibration(module: Module):
     apply to full model with `model.apply(set_module_for_calibration)`
 
     :param module: module to set for calibration
+    :param quantize_weights_upfront: whether to automatically run weight quantization at the
+        start of calibration
     """
     if not getattr(module, "quantization_scheme", None):
         # no quantization scheme nothing to do
@@ -48,4 +51,20 @@ def set_module_for_calibration(module: Module):
             "to re-calibrate a frozen module"
         )
 
+    if quantize_weights_upfront and module.quantization_scheme.weights is not None:
+        # set weight scale and zero_point up front, calibration data doesn't affect it
+        observer = module.weight_observer
+
+        offloaded = False
+        if is_module_offloaded(module):
+            module._hf_hook.pre_forward(module)
+            offloaded = True
+
+        scale, zero_point = observer(module.weight)
+        update_parameter_data(module, scale, "weight_scale")
+        update_parameter_data(module, zero_point, "weight_zero_point")
+
+        if offloaded:
+            module._hf_hook.post_forward(module, None)
+
     module.quantization_status = QuantizationStatus.CALIBRATION
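A hedged end-to-end sketch of the calibration flow above, using a minimal
hypothetical config (field names as they appear in this diff; defaults may
differ in the released package):

    import torch
    from compressed_tensors.quantization import (
        QuantizationArgs,
        QuantizationConfig,
        QuantizationScheme,
        QuantizationStatus,
        apply_quantization_config,
    )

    model = torch.nn.Sequential(torch.nn.Linear(64, 64))
    config = QuantizationConfig(
        config_groups={
            "group_0": QuantizationScheme(
                targets=["Linear"],
                weights=QuantizationArgs(num_bits=8, symmetric=True),
            )
        },
        quantization_status=QuantizationStatus.CALIBRATION,
    )
    # stepping the model to CALIBRATION invokes set_module_for_calibration,
    # which (with quantize_weights_upfront=True) runs the weight observer once
    apply_quantization_config(model, config)
    with torch.no_grad():
        model(torch.randn(8, 64))  # calibration forward passes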
compressed_tensors/quantization/lifecycle/compressed.py (new file)

@@ -0,0 +1,69 @@
+# Copyright (c) 2021 - present / Neuralmagic, Inc. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#    http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+
+import logging
+
+import torch
+from compressed_tensors.quantization.lifecycle.forward import quantize
+from compressed_tensors.quantization.quant_config import QuantizationStatus
+from torch.nn import Module
+
+
+__all__ = [
+    "compress_quantized_weights",
+]
+
+
+_LOGGER = logging.getLogger(__name__)
+
+
+def compress_quantized_weights(module: Module):
+    """
+    Quantizes the module weight representation to use fewer bits in memory
+
+    apply to full model with `model.apply(compress_quantized_weights)`
+
+    :param module: module to compress to quantized representation
+    """
+    scheme = getattr(module, "quantization_scheme", None)
+    if not scheme or not scheme.weights:
+        # no quantization scheme or weights not quantized, nothing to do
+        return
+
+    if scheme is QuantizationStatus.COMPRESSED:
+        # module is already compressed, nothing to do
+        return
+
+    weight = getattr(module, "weight", None)
+    scale = getattr(module, "weight_scale", None)
+    zero_point = getattr(module, "weight_zero_point", None)
+
+    if weight is None or scale is None or zero_point is None:
+        # no weight, scale, or ZP, nothing to do
+
+        # mark as compressed here to maintain consistent status throughout the model
+        module.quantization_status = QuantizationStatus.COMPRESSED
+        return
+
+    module.weight.requires_grad = False  # cannot use auto grad after compression
+    module.weight.data = quantize(
+        x=weight,
+        scale=scale,
+        zero_point=zero_point,
+        args=scheme.weights,
+        dtype=torch.int8,
+    )
+
+    module.quantization_status = QuantizationStatus.COMPRESSED