compressed-tensors 0.8.1__py3-none-any.whl → 0.9.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- compressed_tensors/compressors/model_compressors/model_compressor.py +76 -14
- compressed_tensors/compressors/quantized_compressors/base.py +35 -5
- compressed_tensors/compressors/quantized_compressors/naive_quantized.py +2 -2
- compressed_tensors/compressors/quantized_compressors/pack_quantized.py +2 -2
- compressed_tensors/compressors/sparse_compressors/__init__.py +1 -0
- compressed_tensors/compressors/sparse_compressors/base.py +45 -7
- compressed_tensors/compressors/sparse_compressors/sparse_24_bitmask.py +240 -0
- compressed_tensors/compressors/sparse_compressors/sparse_bitmask.py +9 -40
- compressed_tensors/config/__init__.py +1 -0
- compressed_tensors/config/base.py +1 -0
- compressed_tensors/config/sparse_24_bitmask.py +40 -0
- compressed_tensors/quantization/lifecycle/apply.py +46 -1
- compressed_tensors/quantization/lifecycle/forward.py +2 -2
- compressed_tensors/quantization/lifecycle/initialize.py +21 -45
- compressed_tensors/quantization/quant_config.py +1 -1
- compressed_tensors/utils/helpers.py +174 -1
- compressed_tensors/utils/offload.py +332 -44
- compressed_tensors/utils/safetensors_load.py +83 -17
- compressed_tensors/version.py +1 -1
- {compressed_tensors-0.8.1.dist-info → compressed_tensors-0.9.1.dist-info}/METADATA +1 -1
- {compressed_tensors-0.8.1.dist-info → compressed_tensors-0.9.1.dist-info}/RECORD +24 -22
- {compressed_tensors-0.8.1.dist-info → compressed_tensors-0.9.1.dist-info}/LICENSE +0 -0
- {compressed_tensors-0.8.1.dist-info → compressed_tensors-0.9.1.dist-info}/WHEEL +0 -0
- {compressed_tensors-0.8.1.dist-info → compressed_tensors-0.9.1.dist-info}/top_level.txt +0 -0
--- a/compressed_tensors/compressors/model_compressors/model_compressor.py
+++ b/compressed_tensors/compressors/model_compressors/model_compressor.py
@@ -17,8 +17,9 @@ import logging
 import operator
 import os
 import re
+from contextlib import contextmanager
 from copy import deepcopy
-from typing import TYPE_CHECKING, Any, Dict, Optional, TypeVar, Union
+from typing import TYPE_CHECKING, Any, Dict, Optional, Set, TypeVar, Union
 
 import compressed_tensors
 import torch
@@ -38,6 +39,7 @@ from compressed_tensors.quantization import (
     apply_quantization_config,
     load_pretrained_quantization,
 )
+from compressed_tensors.quantization.lifecycle import expand_sparse_target_names
 from compressed_tensors.quantization.quant_args import QuantizationArgs
 from compressed_tensors.quantization.utils import (
     is_module_quantized,
@@ -104,7 +106,6 @@ class ModelCompressor:
         """
         config = AutoConfig.from_pretrained(pretrained_model_name_or_path, **kwargs)
         compression_config = getattr(config, QUANTIZATION_CONFIG_NAME, None)
-
         return cls.from_compression_config(compression_config)
 
     @classmethod
@@ -137,7 +138,7 @@ class ModelCompressor:
                 format, **sparsity_config
             )
         if quantization_config is not None:
-            quantization_config = QuantizationConfig.parse_obj(quantization_config)
+            quantization_config = QuantizationConfig.model_validate(quantization_config)
 
         return cls(
             sparsity_config=sparsity_config, quantization_config=quantization_config
@@ -193,7 +194,7 @@ class ModelCompressor:
 
         if is_compressed_tensors_config(compression_config):
             s_config = compression_config.sparsity_config
-            return s_config.dict() if s_config is not None else None
+            return s_config.model_dump() if s_config is not None else None
 
         return compression_config.get(SPARSITY_CONFIG_NAME, None)
 
@@ -214,7 +215,7 @@ class ModelCompressor:
 
         if is_compressed_tensors_config(compression_config):
             q_config = compression_config.quantization_config
-            return q_config.dict() if q_config is not None else None
+            return q_config.model_dump() if q_config is not None else None
 
         quantization_config = deepcopy(compression_config)
         quantization_config.pop(SPARSITY_CONFIG_NAME, None)
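Note: this hunk and the two before it track the pydantic v2 API renames (`parse_obj` → `model_validate`, `.dict()` → `.model_dump()`). A minimal sketch of the renamed API using a hypothetical stand-in model, not the library's actual config classes:

```python
from pydantic import BaseModel


class ExampleConfig(BaseModel):  # hypothetical stand-in for QuantizationConfig
    format: str = "dense"


cfg = ExampleConfig.model_validate({"format": "int-quantized"})  # v1: parse_obj(...)
assert cfg.model_dump() == {"format": "int-quantized"}           # v1: dict()
```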
@@ -282,8 +283,14 @@ class ModelCompressor:
             )
 
         if self.sparsity_compressor is not None:
+            sparse_compression_targets: Set[str] = expand_sparse_target_names(
+                model=model,
+                targets=self.sparsity_config.targets,
+                ignore=self.sparsity_config.ignore,
+            )
             compressed_state_dict = self.sparsity_compressor.compress(
-                compressed_state_dict
+                compressed_state_dict,
+                compression_targets=sparse_compression_targets,
             )
 
         # HACK: Override the dtype_byte_size function in transformers to
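`expand_sparse_target_names` resolves the sparsity config's `targets`/`ignore` entries into a concrete set of module names before compression. A rough standalone sketch of what that expansion amounts to; the matching rules here (class name, exact name, or `re:`-prefixed regex) are simplified assumptions, not the library's exact implementation:

```python
import re
from typing import Iterable, Set

import torch.nn as nn


def expand_targets_sketch(
    model: nn.Module, targets: Iterable[str], ignore: Iterable[str]
) -> Set[str]:
    # Simplified matching: class name, exact module name, or "re:"-prefixed regex.
    def matches(name: str, module: nn.Module, patterns: Iterable[str]) -> bool:
        return any(
            pattern == module.__class__.__name__
            or pattern == name
            or (pattern.startswith("re:") and re.match(pattern[3:], name) is not None)
            for pattern in patterns
        )

    return {
        name
        for name, module in model.named_modules()
        if matches(name, module, targets) and not matches(name, module, ignore)
    }


model = nn.Sequential(nn.Linear(8, 8), nn.ReLU(), nn.Linear(8, 2))
print(expand_targets_sketch(model, targets=["Linear"], ignore=["2"]))  # {'0'}
```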
@@ -301,23 +308,44 @@ class ModelCompressor:
         :param model: pytorch model to load decompressed weights into
         """
         model_path = get_safetensors_folder(model_path)
-        if self.sparsity_compressor is not None:
+        sparse_decompressed = False
+
+        if (
+            self.sparsity_compressor is not None
+            and self.sparsity_config.format != CompressionFormat.dense.value
+        ):
+            # Sparse decompression is applied on the model_path
             dense_gen = self.sparsity_compressor.decompress(model_path)
             self._replace_weights(dense_gen, model)
             setattr(model, SPARSITY_CONFIG_NAME, self.sparsity_compressor.config)
+            sparse_decompressed = True
 
         if self.quantization_compressor is not None:
-            names_to_scheme = apply_quantization_config(model, self.quantization_config)
-            load_pretrained_quantization(model, model_path)
+            # Temporarily set quantization status to FROZEN to prevent
+            # quantization during apply_quantization_config. This ensures
+            # that the dtypes of the weights are not unintentionally updated.
+            # The status is restored after quantization params are loaded.
+            with override_quantization_status(
+                self.quantization_config, QuantizationStatus.FROZEN
+            ):
+                names_to_scheme = apply_quantization_config(
+                    model, self.quantization_config
+                )
+                load_pretrained_quantization(model, model_path)
+
+            model_path_or_state_dict = (
+                model.state_dict() if sparse_decompressed else model_path
+            )
+
             dense_gen = self.quantization_compressor.decompress(
-                model_path, names_to_scheme=names_to_scheme
+                model_path_or_state_dict, names_to_scheme=names_to_scheme
             )
             self._replace_weights(dense_gen, model)
 
-            def update_status(module):
+            def freeze_quantization_status(module):
                 module.quantization_status = QuantizationStatus.FROZEN
 
-            model.apply(update_status)
+            model.apply(freeze_quantization_status)
             setattr(model, QUANTIZATION_CONFIG_NAME, self.quantization_config)
 
     def update_config(self, save_directory: str):
@@ -367,12 +395,26 @@ class ModelCompressor:
         with open(config_file_path, "w") as config_file:
             json.dump(config_data, config_file, indent=2, sort_keys=True)
 
-    def _replace_weights(self, dense_weight_generator, model):
+    def _replace_weights(self, dense_weight_generator, model: Module):
+        """
+        Replace the weights of the model with the
+        provided dense weights.
+
+        This method iterates over the dense_weight_generator and
+        updates the corresponding weights in the model. If a parameter
+        name does not exist in the model, it will be skipped.
+
+        :param dense_weight_generator (generator): A generator that yields
+            tuples of (name, data), where 'name' is the parameter name and
+            'data' is the updated param data
+        :param model: The model whose weights are to be updated.
+        """
         for name, data in tqdm(dense_weight_generator, desc="Decompressing model"):
             split_name = name.split(".")
             prefix, param_name = ".".join(split_name[:-1]), split_name[-1]
             module = operator.attrgetter(prefix)(model)
-            update_parameter_data(module, data, param_name)
+            if hasattr(module, param_name):
+                update_parameter_data(module, data, param_name)
 
 
 def map_modules_to_quant_args(model: Module) -> Dict[str, QuantizationArgs]:
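`_replace_weights` splits each flat parameter name into a module prefix and a leaf name, then resolves the owning module with `operator.attrgetter`, which follows dotted paths; the new `hasattr` guard skips names absent from the model. A small standalone illustration with hypothetical classes:

```python
import operator


class Linear:  # hypothetical stand-ins, not library classes
    weight = "dense weight"


class Decoder:
    proj = Linear()


class Model:
    decoder = Decoder()


name = "decoder.proj.weight"
split_name = name.split(".")
prefix, param_name = ".".join(split_name[:-1]), split_name[-1]
module = operator.attrgetter(prefix)(Model())  # attrgetter follows dotted paths
assert hasattr(module, param_name) and module.weight == "dense weight"
```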
@@ -402,3 +444,23 @@ def new_dtype_byte_size(dtype):
         raise ValueError(f"`dtype` is not a valid dtype: {dtype}.")
     bit_size = int(bit_search.groups()[0])
     return bit_size // 8
+
+
+@contextmanager
+def override_quantization_status(
+    config: QuantizationConfig, status: QuantizationStatus
+):
+    """
+    Within this context, the quantization status will be set to the
+    supplied status. After the context exits, the original status
+    will be restored.
+
+    :param config: the quantization config to override
+    :param status: the status to temporarily set
+    """
+    original_status = config.quantization_status
+    config.quantization_status = status
+    try:
+        yield
+    finally:
+        config.quantization_status = original_status
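The try/finally guarantees the original status comes back even if loading raises. A minimal sketch of that property, reusing the context manager body from the hunk above with a hypothetical stand-in config (constructing a real `QuantizationConfig` is out of scope here):

```python
from contextlib import contextmanager


@contextmanager
def override_quantization_status(config, status):
    # Body copied from the hunk above
    original_status = config.quantization_status
    config.quantization_status = status
    try:
        yield
    finally:
        config.quantization_status = original_status


class SimpleConfig:  # hypothetical stand-in for QuantizationConfig
    quantization_status = "initialized"


cfg = SimpleConfig()
try:
    with override_quantization_status(cfg, "frozen"):
        assert cfg.quantization_status == "frozen"
        raise RuntimeError("load failed")
except RuntimeError:
    pass
assert cfg.quantization_status == "initialized"  # restored even on error
```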
--- a/compressed_tensors/compressors/quantized_compressors/base.py
+++ b/compressed_tensors/compressors/quantized_compressors/base.py
@@ -13,12 +13,17 @@
 # limitations under the License.
 
 import logging
-from typing import Dict, Generator, Tuple
+from pathlib import Path
+from typing import Any, Dict, Generator, Tuple, Union
 
 import torch
 from compressed_tensors.compressors.base import BaseCompressor
 from compressed_tensors.quantization import QuantizationArgs
-from compressed_tensors.utils import get_nested_weight_mappings, merge_names
+from compressed_tensors.utils import (
+    get_nested_mappings_from_state_dict,
+    get_nested_weight_mappings,
+    merge_names,
+)
 from safetensors import safe_open
 from torch import Tensor
 from tqdm import tqdm
@@ -113,7 +118,7 @@ class BaseQuantizationCompressor(BaseCompressor):
 
     def decompress(
         self,
-        path_to_model_or_tensors: str,
+        path_to_model_or_tensors: Union[str, Path, Dict[str, Any]],
         names_to_scheme: Dict[str, QuantizationArgs],
         device: str = "cpu",
     ) -> Generator[Tuple[str, Tensor], None, None]:
@@ -121,15 +126,25 @@ class BaseQuantizationCompressor(BaseCompressor):
         Reads a compressed state dict located at path_to_model_or_tensors
         and returns a generator for sequentially decompressing back to a
         dense state dict
-
         :param path_to_model_or_tensors: path to compressed safetensors model (directory
             with one or more safetensors files) or compressed tensors file
         :param names_to_scheme: quantization args for each quantized weight
         :param device: optional device to load intermediate weights into
         :return: compressed state dict
         """
+        if isinstance(path_to_model_or_tensors, (str, Path)):
+            yield from self._decompress_from_path(
+                path_to_model_or_tensors, names_to_scheme, device
+            )
+
+        else:
+            yield from self._decompress_from_state_dict(
+                path_to_model_or_tensors, names_to_scheme
+            )
+
+    def _decompress_from_path(self, path_to_model, names_to_scheme, device):
         weight_mappings = get_nested_weight_mappings(
-            path_to_model_or_tensors, self.COMPRESSION_PARAM_NAMES
+            path_to_model, self.COMPRESSION_PARAM_NAMES
         )
         for weight_name in weight_mappings.keys():
             weight_data = {}
@@ -137,6 +152,21 @@ class BaseQuantizationCompressor(BaseCompressor):
                 full_name = merge_names(weight_name, param_name)
                 with safe_open(safe_path, framework="pt", device=device) as f:
                     weight_data[param_name] = f.get_tensor(full_name)
+            if "weight_scale" in weight_data:
+                quant_args = names_to_scheme[weight_name]
+                decompressed = self.decompress_weight(
+                    compressed_data=weight_data, quantization_args=quant_args
+                )
+                yield merge_names(weight_name, "weight"), decompressed
+
+    def _decompress_from_state_dict(self, state_dict, names_to_scheme):
+        weight_mappings = get_nested_mappings_from_state_dict(
+            state_dict, self.COMPRESSION_PARAM_NAMES
+        )
+        for weight_name in weight_mappings.keys():
+            weight_data = {}
+            for param_name, param_value in weight_mappings[weight_name].items():
+                weight_data[param_name] = param_value
 
             if "weight_scale" in weight_data:
                 quant_args = names_to_scheme[weight_name]
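The new `decompress` entry point keeps the old on-disk route but also accepts an in-memory state dict, which `ModelCompressor.decompress` now passes in after sparse decompression. The routing reduces to a single isinstance check; a standalone sketch of the same pattern with hypothetical names:

```python
from pathlib import Path
from typing import Any, Dict, Generator, Tuple, Union


def decompress_sketch(
    source: Union[str, Path, Dict[str, Any]]
) -> Generator[Tuple[str, str], None, None]:
    # Strings and paths take the on-disk safetensors route; anything else is
    # treated as an in-memory state dict (e.g. after sparse decompression).
    if isinstance(source, (str, Path)):
        yield "route", "from_path"
    else:
        yield "route", "from_state_dict"


assert dict(decompress_sketch("model_dir")) == {"route": "from_path"}
assert dict(decompress_sketch({"w": 1})) == {"route": "from_state_dict"}
```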
--- a/compressed_tensors/compressors/quantized_compressors/naive_quantized.py
+++ b/compressed_tensors/compressors/quantized_compressors/naive_quantized.py
@@ -68,9 +68,9 @@ class NaiveQuantizationCompressor(BaseQuantizationCompressor):
         self,
         weight: Tensor,
         scale: Tensor,
+        quantization_args: QuantizationArgs,
         zero_point: Optional[Tensor] = None,
         g_idx: Optional[torch.Tensor] = None,
-        quantization_args: Optional[QuantizationArgs] = None,
         device: Optional[torch.device] = None,
     ) -> Dict[str, torch.Tensor]:
         """
@@ -78,9 +78,9 @@ class NaiveQuantizationCompressor(BaseQuantizationCompressor):
 
         :param weight: uncompressed weight tensor
         :param scale: quantization scale for weight
+        :param quantization_args: quantization parameters for weight
         :param zero_point: quantization zero point for weight
         :param g_idx: optional mapping from column index to group index
-        :param quantization_args: quantization parameters for weight
         :param device: optional device to move compressed output to
         :return: dictionary of compressed weight data
         """
--- a/compressed_tensors/compressors/quantized_compressors/pack_quantized.py
+++ b/compressed_tensors/compressors/quantized_compressors/pack_quantized.py
@@ -68,9 +68,9 @@ class PackedQuantizationCompressor(BaseQuantizationCompressor):
         self,
         weight: Tensor,
         scale: Tensor,
+        quantization_args: QuantizationArgs,
         zero_point: Optional[Tensor] = None,
         g_idx: Optional[torch.Tensor] = None,
-        quantization_args: Optional[QuantizationArgs] = None,
         device: Optional[torch.device] = None,
     ) -> Dict[str, torch.Tensor]:
         """
@@ -78,9 +78,9 @@ class PackedQuantizationCompressor(BaseQuantizationCompressor):
 
         :param weight: uncompressed weight tensor
         :param scale: quantization scale for weight
+        :param quantization_args: quantization parameters for weight
         :param zero_point: quantization zero point for weight
         :param g_idx: optional mapping from column index to group index
-        :param quantization_args: quantization parameters for weight
         :param device: optional device to move compressed output to
         :return: dictionary of compressed weight data
         """
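In both compressors, `quantization_args` moves from a trailing optional keyword to a required third parameter. Keyword call sites keep working, but positional call sites written against 0.8.1 would now bind their third argument to `quantization_args`. A minimal standalone sketch of that hazard (function names hypothetical):

```python
def compress_weight_v081(weight, scale, zero_point=None, g_idx=None,
                         quantization_args=None, device=None):
    return zero_point, quantization_args


def compress_weight_v091(weight, scale, quantization_args, zero_point=None,
                         g_idx=None, device=None):
    return zero_point, quantization_args


# Keyword call sites behave identically across both versions:
assert compress_weight_v081("w", "s", quantization_args="args") == (None, "args")
assert compress_weight_v091("w", "s", quantization_args="args") == (None, "args")
# A positional third argument meant zero_point in 0.8.1 but means
# quantization_args in 0.9.1:
assert compress_weight_v081("w", "s", "third") == ("third", None)
assert compress_weight_v091("w", "s", "third") == (None, "third")
```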
--- a/compressed_tensors/compressors/sparse_compressors/base.py
+++ b/compressed_tensors/compressors/sparse_compressors/base.py
@@ -13,7 +13,7 @@
 # limitations under the License.
 
 import logging
-from typing import Dict, Generator, Tuple
+from typing import Dict, Generator, Optional, Set, Tuple
 
 from compressed_tensors.compressors.base import BaseCompressor
 from compressed_tensors.utils import get_nested_weight_mappings, merge_names
@@ -30,7 +30,8 @@ _LOGGER: logging.Logger = logging.getLogger(__name__)
 class BaseSparseCompressor(BaseCompressor):
     """
     Base class representing a sparse compression algorithm. Each child class should
-    implement compression_param_info, compress_weight and decompress_weight
+    implement compression_param_info, compress_weight and decompress_weight; child
+    classes should also define COMPRESSION_PARAM_NAMES.
 
     Compressors support compressing/decompressing a full module state dict or a single
     quantized PyTorch leaf module.
@@ -59,11 +60,17 @@ class BaseSparseCompressor(BaseCompressor):
     :param config: config specifying compression parameters
     """
 
-    def compress(self, model_state: Dict[str, Tensor]) -> Dict[str, Tensor]:
+    def compress(
+        self,
+        model_state: Dict[str, Tensor],
+        compression_targets: Optional[Set[str]] = None,
+    ) -> Dict[str, Tensor]:
         """
         Compresses a dense state dict using bitmask compression
 
         :param model_state: state dict of uncompressed model
+        :param compression_targets: optional set of layer prefixes to compress,
+            otherwise compress all layers (for backwards compatibility)
         :return: compressed state dict
         """
         compressed_dict = {}
@@ -71,7 +78,14 @@ class BaseSparseCompressor(BaseCompressor):
             f"Compressing model with {len(model_state)} parameterized layers..."
         )
         for name, value in tqdm(model_state.items(), desc="Compressing model"):
-            compression_data = self.compress_weight(name, value)
+            if not self.should_compress(name, compression_targets):
+                compressed_dict[name] = value
+                continue
+            prefix = name
+            if prefix.endswith(".weight"):
+                prefix = prefix[: -(len(".weight"))]
+
+            compression_data = self.compress_weight(prefix, value)
             for key in compression_data.keys():
                 if key in compressed_dict:
                     _LOGGER.warn(
@@ -97,8 +111,10 @@ class BaseSparseCompressor(BaseCompressor):
         :param device: device to load decompressed weights onto
         :return: iterator for generating decompressed weights
         """
-        weight_mappings = get_nested_weight_mappings(
-            path_to_model_or_tensors, self.COMPRESSION_PARAM_NAMES
+        weight_mappings, ignored_params = get_nested_weight_mappings(
+            path_to_model_or_tensors,
+            self.COMPRESSION_PARAM_NAMES,
+            return_unmatched_params=True,
         )
         for weight_name in weight_mappings.keys():
             weight_data = {}
@@ -107,4 +123,26 @@ class BaseSparseCompressor(BaseCompressor):
                 with safe_open(safe_path, framework="pt", device=device) as f:
                     weight_data[param_name] = f.get_tensor(full_name)
             decompressed = self.decompress_weight(weight_data)
-            yield weight_name, decompressed
+            yield merge_names(weight_name, "weight"), decompressed
+
+        for ignored_param_name, safe_path in ignored_params.items():
+            with safe_open(safe_path, framework="pt", device=device) as f:
+                value = f.get_tensor(ignored_param_name)
+            yield ignored_param_name, value
+
+    @staticmethod
+    def should_compress(name: str, expanded_targets: Optional[Set[str]] = None) -> bool:
+        """
+        Check if a parameter should be compressed.
+        Currently, this only returns True for weight parameters.
+
+        :param name: name of the parameter
+        :param expanded_targets: set of layer prefixes to compress
+        :return: whether or not the parameter should be compressed
+        """
+        if expanded_targets is None:
+            return name.endswith(".weight")
+
+        return (
+            name.endswith(".weight") and name[: -(len(".weight"))] in expanded_targets
+        )
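Since `should_compress` is a pure static method, its behavior is easy to pin down. The body below is copied from the hunk above; the example parameter names are illustrative assumptions:

```python
from typing import Optional, Set


def should_compress(name: str, expanded_targets: Optional[Set[str]] = None) -> bool:
    # Copied from BaseSparseCompressor.should_compress above
    if expanded_targets is None:
        return name.endswith(".weight")
    return (
        name.endswith(".weight") and name[: -(len(".weight"))] in expanded_targets
    )


# With no targets, every weight is compressed; with targets, only listed prefixes.
assert should_compress("model.layers.0.mlp.down_proj.weight")
assert not should_compress("model.layers.0.mlp.down_proj.bias")
assert not should_compress("lm_head.weight", {"model.layers.0.mlp.down_proj"})
assert should_compress(
    "model.layers.0.mlp.down_proj.weight", {"model.layers.0.mlp.down_proj"}
)
```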
--- /dev/null
+++ b/compressed_tensors/compressors/sparse_compressors/sparse_24_bitmask.py
@@ -0,0 +1,240 @@
+# Copyright (c) 2021 - present / Neuralmagic, Inc. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#    http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from dataclasses import dataclass
+from typing import Dict, List, Tuple, Union
+
+import torch
+from compressed_tensors.compressors.base import BaseCompressor
+from compressed_tensors.compressors.sparse_compressors.base import BaseSparseCompressor
+from compressed_tensors.config import CompressionFormat, SparsityStructure
+from compressed_tensors.quantization import FP8_DTYPE
+from compressed_tensors.utils import merge_names, pack_bitmasks, unpack_bitmasks
+from torch import Tensor
+
+
+__all__ = [
+    "Sparse24BitMaskCompressor",
+    "Sparse24BitMaskTensor",
+    "sparse24_bitmask_compress",
+    "sparse24_bitmask_decompress",
+    "get_24_bytemasks",
+]
+
+
+@BaseCompressor.register(name=CompressionFormat.sparse_24_bitmask.value)
+class Sparse24BitMaskCompressor(BaseSparseCompressor):
+    """
+    Compression for sparse models using bitmasks. Non-zero weights are stored in a 2d
+    values tensor, with their locations stored in a 2d bitmask
+    """
+
+    COMPRESSION_PARAM_NAMES = [
+        "shape",
+        "compressed",
+        "bitmask",
+    ]
+
+    def compress_weight(self, name, value):
+        bitmask_tensor = Sparse24BitMaskTensor.from_dense(
+            value, self.config.sparsity_structure
+        )
+        bitmask_dict = bitmask_tensor.dict(name_prefix=name, device="cpu")
+        return bitmask_dict
+
+    def decompress_weight(self, weight_data):
+        data = Sparse24BitMaskTensor.from_compressed_data(**weight_data)
+        decompressed = data.decompress()
+        return decompressed
+
+
+@dataclass
+class Sparse24BitMaskTensor:
+    """
+    Owns compression and decompression for a single 2:4 sparse
+    bitmask compressed tensor.
+
+    :param shape: shape of dense tensor
+    :param compressed: 2d tensor of non-zero values
+    :param bitmask: 2d bitmask of non-zero values
+    """
+
+    shape: List[int]
+    compressed: Tensor
+    bitmask: Tensor
+
+    @staticmethod
+    def from_dense(
+        tensor: Tensor,
+        sparsity_structure: Union[SparsityStructure, str] = SparsityStructure.TWO_FOUR,
+    ) -> "Sparse24BitMaskTensor":
+        """
+        :param tensor: dense tensor to compress
+        :return: instantiated compressed tensor
+        """
+        shape = list(tensor.shape)
+        compressed, bitmask = sparse24_bitmask_compress(
+            tensor.cpu(), sparsity_structure=sparsity_structure
+        )
+        return Sparse24BitMaskTensor(
+            shape=shape,
+            compressed=compressed,
+            bitmask=bitmask,
+        )
+
+    @staticmethod
+    def from_compressed_data(
+        shape: Union[List[int], Tensor], compressed: Tensor, bitmask: Tensor
+    ) -> "Sparse24BitMaskTensor":
+        """
+        :param shape: shape of the dense tensor (can be a list or a tensor)
+        :param compressed: 2d tensor of non-zero values
+        :param bitmask: 2d bitmask of non-zero values
+        :return: instantiated Sparse24BitMaskTensor
+        """
+        if isinstance(shape, list):
+            shape = torch.tensor(shape)
+        if isinstance(shape, torch.Tensor):
+            shape = shape.flatten().tolist()
+        return Sparse24BitMaskTensor(
+            shape=shape, compressed=compressed, bitmask=bitmask
+        )
+
+    def decompress(self) -> Tensor:
+        """
+        :return: reconstructed dense tensor
+        """
+        return sparse24_bitmask_decompress(self.compressed, self.bitmask, self.shape)
+
+    def curr_memory_size_bytes(self) -> int:
+        """
+        :return: size in bytes required to store compressed tensor on disk
+        """
+
+        def sizeof_tensor(a: Tensor) -> int:
+            return a.element_size() * a.nelement()
+
+        return sizeof_tensor(self.compressed) + sizeof_tensor(self.bitmask)
+
+    def dict(self, name_prefix: str, device: str = "cpu") -> Dict[str, Tensor]:
+        """
+        :param name_prefix: name of original tensor to store compressed weight as
+        :return: dict of compressed data for the stored weight
+        """
+        if name_prefix.endswith(".weight"):
+            name_prefix = name_prefix[: -len(".weight")]
+        return {
+            merge_names(name_prefix, "shape"): torch.tensor(
+                self.shape, device=device
+            ).reshape(-1, 1),
+            merge_names(name_prefix, "compressed"): self.compressed.to(device),
+            merge_names(name_prefix, "bitmask"): self.bitmask.to(device),
+        }
+
+    def __repr__(self) -> str:
+        return f"BitMaskTensor(shape={self.shape}, compressed=True)"
+
+
+def sparse24_bitmask_compress(
+    tensor: Tensor,
+    sparsity_structure: Union[SparsityStructure, str] = SparsityStructure.TWO_FOUR,
+) -> Tuple[Tensor, Tensor]:
+    """
+    Compresses a dense tensor using bitmask compression
+
+    :param tensor: dense 2D tensor to compress
+    :param sparsity_structure: structure of sparsity in the tensor; defaults
+        to 2:4, which is currently the only supported structure
+    :return: tuple of compressed data representing tensor
+    """
+    assert len(tensor.shape) == 2, "Only 2D tensors are supported"
+    assert (
+        SparsityStructure(sparsity_structure) == SparsityStructure.TWO_FOUR
+    ), "Only 2:4 sparsity is supported"
+
+    bytemasks = get_24_bytemasks(tensor=tensor)
+
+    if tensor.dtype == FP8_DTYPE:
+        # access raw bytes of the tensor
+        tensor_view = tensor.view(torch.int8)
+        values = tensor_view[bytemasks]
+        values = values.view(FP8_DTYPE)
+    else:
+        values = tensor[bytemasks]
+
+    num_rows, num_cols = tensor.shape
+    compressed_values = values.reshape(num_rows, num_cols // 2)
+    bitmasks_packed = pack_bitmasks(bytemasks)
+    return compressed_values, bitmasks_packed
+
+
+def sparse24_bitmask_decompress(
+    values: Tensor, bitmasks: Tensor, original_shape: torch.Size
+) -> Tensor:
+    """
+    Reconstructs a dense tensor from a compressed one
+
+    :param values: 1d tensor of non-zero values
+    :param bitmasks: 2d int8 tensor flagging locations of non-zero values in the
+        tensor's original shape
+    :param original_shape: shape of the dense tensor
+    :return: decompressed dense tensor
+    """
+    bytemasks_unpacked = unpack_bitmasks(bitmasks, original_shape)
+
+    decompressed_tensor = torch.zeros(original_shape, dtype=values.dtype)
+    decompressed_tensor = decompressed_tensor.to(values.device)
+    values = values.flatten()
+    if decompressed_tensor.dtype == FP8_DTYPE:
+        decompressed_tensor[bytemasks_unpacked] = values
+        decompressed_tensor = decompressed_tensor.cuda()
+    else:
+        decompressed_tensor[bytemasks_unpacked] = values
+    return decompressed_tensor
+
+
+def get_24_bytemasks(tensor):
+    """
+    Generate a 2:4 sparsity mask for the given tensor.
+
+    This function creates a mask where exactly 2 out of every 4 elements are
+    preserved based on their magnitudes. The preserved elements are the ones
+    with the highest absolute values in each group of 4 elements.
+
+    :param tensor: The input tensor for which the 2:4 sparsity mask is to be created.
+        The tensor can be of any shape but its total number of elements
+        must be a multiple of 4.
+    :return: A boolean tensor of the same shape as the input tensor, where `True`
+        indicates the preserved elements and `False` indicates the pruned elements.
+    :raises ValueError: If the total number of elements in the tensor is not a
+        multiple of 4.
+    """
+    original_dtype = tensor.dtype
+    if tensor.dtype == FP8_DTYPE:
+        tensor = tensor.view(torch.int8)
+    original_shape = tensor.shape
+    num_elements = tensor.numel()
+
+    if num_elements % 4 != 0:
+        raise ValueError("Tensor size must be a multiple of 4 for TWO_FOUR sparsity")
+
+    reshaped_tensor = tensor.view(-1, 4)
+    abs_tensor = reshaped_tensor.abs()
+    topk_indices = abs_tensor.topk(2, dim=1).indices
+    mask = torch.zeros_like(reshaped_tensor, dtype=torch.bool)
+    mask.scatter_(1, topk_indices, True)
+    mask = mask.view(original_shape)
+    tensor = tensor.view(original_dtype)
+
+    return mask