compressed-tensors-nightly 0.8.1.20250105__py3-none-any.whl → 0.8.1.20250107__py3-none-any.whl
- compressed_tensors/compressors/model_compressors/model_compressor.py +69 -10
- compressed_tensors/compressors/quantized_compressors/base.py +35 -5
- compressed_tensors/compressors/quantized_compressors/naive_quantized.py +2 -2
- compressed_tensors/compressors/quantized_compressors/pack_quantized.py +2 -2
- compressed_tensors/compressors/sparse_compressors/base.py +45 -7
- compressed_tensors/compressors/sparse_compressors/sparse_bitmask.py +8 -2
- compressed_tensors/quantization/lifecycle/apply.py +46 -1
- compressed_tensors/quantization/lifecycle/forward.py +2 -2
- compressed_tensors/utils/safetensors_load.py +83 -17
- {compressed_tensors_nightly-0.8.1.20250105.dist-info → compressed_tensors_nightly-0.8.1.20250107.dist-info}/METADATA +1 -1
- {compressed_tensors_nightly-0.8.1.20250105.dist-info → compressed_tensors_nightly-0.8.1.20250107.dist-info}/RECORD +14 -14
- {compressed_tensors_nightly-0.8.1.20250105.dist-info → compressed_tensors_nightly-0.8.1.20250107.dist-info}/LICENSE +0 -0
- {compressed_tensors_nightly-0.8.1.20250105.dist-info → compressed_tensors_nightly-0.8.1.20250107.dist-info}/WHEEL +0 -0
- {compressed_tensors_nightly-0.8.1.20250105.dist-info → compressed_tensors_nightly-0.8.1.20250107.dist-info}/top_level.txt +0 -0
compressed_tensors/compressors/model_compressors/model_compressor.py
@@ -17,8 +17,9 @@ import logging
 import operator
 import os
 import re
+from contextlib import contextmanager
 from copy import deepcopy
-from typing import TYPE_CHECKING, Any, Dict, Optional, TypeVar, Union
+from typing import TYPE_CHECKING, Any, Dict, Optional, Set, TypeVar, Union
 
 import compressed_tensors
 import torch
@@ -38,6 +39,7 @@ from compressed_tensors.quantization import (
     apply_quantization_config,
     load_pretrained_quantization,
 )
+from compressed_tensors.quantization.lifecycle import expand_sparse_target_names
 from compressed_tensors.quantization.quant_args import QuantizationArgs
 from compressed_tensors.quantization.utils import (
     is_module_quantized,
@@ -104,7 +106,6 @@ class ModelCompressor:
         """
         config = AutoConfig.from_pretrained(pretrained_model_name_or_path, **kwargs)
         compression_config = getattr(config, QUANTIZATION_CONFIG_NAME, None)
-
         return cls.from_compression_config(compression_config)
 
     @classmethod
@@ -282,8 +283,14 @@ class ModelCompressor:
         )
 
         if self.sparsity_compressor is not None:
+            sparse_compression_targets: Set[str] = expand_sparse_target_names(
+                model=model,
+                targets=self.sparsity_config.targets,
+                ignore=self.sparsity_config.ignore,
+            )
             compressed_state_dict = self.sparsity_compressor.compress(
-                compressed_state_dict
+                compressed_state_dict,
+                compression_targets=sparse_compression_targets,
             )
 
         # HACK: Override the dtype_byte_size function in transformers to
@@ -301,23 +308,41 @@ class ModelCompressor:
         :param model: pytorch model to load decompressed weights into
         """
         model_path = get_safetensors_folder(model_path)
+        sparse_decompressed = False
+
         if self.sparsity_compressor is not None:
+            # Sparse decompression is applied on the model_path
            dense_gen = self.sparsity_compressor.decompress(model_path)
            self._replace_weights(dense_gen, model)
            setattr(model, SPARSITY_CONFIG_NAME, self.sparsity_compressor.config)
+            sparse_decompressed = True
 
         if self.quantization_compressor is not None:
-            names_to_scheme = apply_quantization_config(model, self.quantization_config)
-            load_pretrained_quantization(model, model_path)
+            # Temporarily set quantization status to FROZEN to prevent
+            # quantization during apply_quantization_config. This ensures
+            # that the dtypes of the weights are not unintentionally updated.
+            # The status is restored after quantization params are loaded.
+            with override_quantization_status(
+                self.quantization_config, QuantizationStatus.FROZEN
+            ):
+                names_to_scheme = apply_quantization_config(
+                    model, self.quantization_config
+                )
+                load_pretrained_quantization(model, model_path)
+
+            model_path_or_state_dict = (
+                model.state_dict() if sparse_decompressed else model_path
+            )
+
             dense_gen = self.quantization_compressor.decompress(
-                model_path, names_to_scheme=names_to_scheme
+                model_path_or_state_dict, names_to_scheme=names_to_scheme
             )
             self._replace_weights(dense_gen, model)
 
-            def update_status(module):
+            def freeze_quantization_status(module):
                 module.quantization_status = QuantizationStatus.FROZEN
 
-            model.apply(update_status)
+            model.apply(freeze_quantization_status)
             setattr(model, QUANTIZATION_CONFIG_NAME, self.quantization_config)
 
     def update_config(self, save_directory: str):
@@ -367,12 +392,26 @@ class ModelCompressor:
         with open(config_file_path, "w") as config_file:
             json.dump(config_data, config_file, indent=2, sort_keys=True)
 
-    def _replace_weights(self, dense_weight_generator, model):
+    def _replace_weights(self, dense_weight_generator, model: Module):
+        """
+        Replace the weights of the model with the
+        provided dense weights.
+
+        This method iterates over the dense_weight_generator and
+        updates the corresponding weights in the model. If a parameter
+        name does not exist in the model, it will be skipped.
+
+        :param dense_weight_generator (generator): A generator that yields
+            tuples of (name, data), where 'name' is the parameter name and
+            'data' is the updated param data
+        :param model: The model whose weights are to be updated.
+        """
         for name, data in tqdm(dense_weight_generator, desc="Decompressing model"):
             split_name = name.split(".")
             prefix, param_name = ".".join(split_name[:-1]), split_name[-1]
             module = operator.attrgetter(prefix)(model)
-            update_parameter_data(module, data, param_name)
+            if hasattr(module, param_name):
+                update_parameter_data(module, data, param_name)
 
 
 def map_modules_to_quant_args(model: Module) -> Dict[str, QuantizationArgs]:
@@ -402,3 +441,23 @@ def new_dtype_byte_size(dtype):
         raise ValueError(f"`dtype` is not a valid dtype: {dtype}.")
     bit_size = int(bit_search.groups()[0])
     return bit_size // 8
+
+
+@contextmanager
+def override_quantization_status(
+    config: QuantizationConfig, status: QuantizationStatus
+):
+    """
+    Within this context, the quantization status will be set to the
+    supplied status. After the context exits, the original status
+    will be restored.
+
+    :param config: the quantization config to override
+    :param status: the status to temporarily set
+    """
+    original_status = config.quantization_status
+    config.quantization_status = status
+    try:
+        yield
+    finally:
+        config.quantization_status = original_status
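As a standalone illustration of the context-manager pattern added above: the sketch below substitutes a dummy dataclass for QuantizationConfig so the restore-on-exit behavior can be run outside the library. DummyConfig and the status strings are illustrative, not library names.

from contextlib import contextmanager
from dataclasses import dataclass


@dataclass
class DummyConfig:  # stand-in for QuantizationConfig
    quantization_status: str = "initialized"


@contextmanager
def override_quantization_status(config, status):
    original_status = config.quantization_status
    config.quantization_status = status
    try:
        yield
    finally:
        # restored even if the body raises
        config.quantization_status = original_status


config = DummyConfig()
with override_quantization_status(config, "frozen"):
    print(config.quantization_status)  # frozen
print(config.quantization_status)  # initialized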
compressed_tensors/compressors/quantized_compressors/base.py
@@ -13,12 +13,17 @@
 # limitations under the License.
 
 import logging
-from typing import Dict, Generator, Tuple
+from pathlib import Path
+from typing import Any, Dict, Generator, Tuple, Union
 
 import torch
 from compressed_tensors.compressors.base import BaseCompressor
 from compressed_tensors.quantization import QuantizationArgs
-from compressed_tensors.utils import get_nested_weight_mappings, merge_names
+from compressed_tensors.utils import (
+    get_nested_mappings_from_state_dict,
+    get_nested_weight_mappings,
+    merge_names,
+)
 from safetensors import safe_open
 from torch import Tensor
 from tqdm import tqdm
@@ -113,7 +118,7 @@ class BaseQuantizationCompressor(BaseCompressor):
 
     def decompress(
         self,
-        path_to_model_or_tensors: str,
+        path_to_model_or_tensors: Union[str, Path, Dict[str, Any]],
         names_to_scheme: Dict[str, QuantizationArgs],
         device: str = "cpu",
     ) -> Generator[Tuple[str, Tensor], None, None]:
@@ -121,15 +126,25 @@ class BaseQuantizationCompressor(BaseCompressor):
         Reads a compressed state dict located at path_to_model_or_tensors
         and returns a generator for sequentially decompressing back to a
         dense state dict
-
         :param path_to_model_or_tensors: path to compressed safetensors model (directory
             with one or more safetensors files) or compressed tensors file
         :param names_to_scheme: quantization args for each quantized weight
         :param device: optional device to load intermediate weights into
         :return: compressed state dict
         """
+        if isinstance(path_to_model_or_tensors, (str, Path)):
+            yield from self._decompress_from_path(
+                path_to_model_or_tensors, names_to_scheme, device
+            )
+
+        else:
+            yield from self._decompress_from_state_dict(
+                path_to_model_or_tensors, names_to_scheme
+            )
+
+    def _decompress_from_path(self, path_to_model, names_to_scheme, device):
         weight_mappings = get_nested_weight_mappings(
-            path_to_model_or_tensors, self.COMPRESSION_PARAM_NAMES
+            path_to_model, self.COMPRESSION_PARAM_NAMES
         )
         for weight_name in weight_mappings.keys():
             weight_data = {}
@@ -137,6 +152,21 @@ class BaseQuantizationCompressor(BaseCompressor):
                 full_name = merge_names(weight_name, param_name)
                 with safe_open(safe_path, framework="pt", device=device) as f:
                     weight_data[param_name] = f.get_tensor(full_name)
+            if "weight_scale" in weight_data:
+                quant_args = names_to_scheme[weight_name]
+                decompressed = self.decompress_weight(
+                    compressed_data=weight_data, quantization_args=quant_args
+                )
+                yield merge_names(weight_name, "weight"), decompressed
+
+    def _decompress_from_state_dict(self, state_dict, names_to_scheme):
+        weight_mappings = get_nested_mappings_from_state_dict(
+            state_dict, self.COMPRESSION_PARAM_NAMES
+        )
+        for weight_name in weight_mappings.keys():
+            weight_data = {}
+            for param_name, param_value in weight_mappings[weight_name].items():
+                weight_data[param_name] = param_value
 
             if "weight_scale" in weight_data:
                 quant_args = names_to_scheme[weight_name]
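The decompress() change above routes on input type: filesystem locations go through _decompress_from_path, anything else is treated as an in-memory state dict. A minimal standalone sketch of that dispatch, with illustrative stand-in helpers (not the library implementations):

from pathlib import Path
from typing import Any, Dict, Union


def decompress(src: Union[str, Path, Dict[str, Any]]):
    # route on input type, mirroring the isinstance check in the diff above
    if isinstance(src, (str, Path)):
        yield from _decompress_from_path(src)
    else:
        yield from _decompress_from_state_dict(src)


def _decompress_from_path(path):
    yield f"loaded from {path}"  # stand-in for reading safetensors files


def _decompress_from_state_dict(state_dict):
    yield from state_dict.items()  # stand-in for decompressing loaded tensors


print(list(decompress("model_dir")))          # ['loaded from model_dir']
print(list(decompress({"layer.weight": 0})))  # [('layer.weight', 0)]

The state-dict path is what lets ModelCompressor.decompress feed a sparse-decompressed model.state_dict() directly into the quantization compressor, as shown in the model_compressor.py diff above.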
compressed_tensors/compressors/quantized_compressors/naive_quantized.py
@@ -68,9 +68,9 @@ class NaiveQuantizationCompressor(BaseQuantizationCompressor):
         self,
         weight: Tensor,
         scale: Tensor,
+        quantization_args: QuantizationArgs,
         zero_point: Optional[Tensor] = None,
         g_idx: Optional[torch.Tensor] = None,
-        quantization_args: Optional[QuantizationArgs] = None,
         device: Optional[torch.device] = None,
     ) -> Dict[str, torch.Tensor]:
         """
@@ -78,9 +78,9 @@ class NaiveQuantizationCompressor(BaseQuantizationCompressor):
 
         :param weight: uncompressed weight tensor
         :param scale: quantization scale for weight
+        :param quantization_args: quantization parameters for weight
         :param zero_point: quantization zero point for weight
         :param g_idx: optional mapping from column index to group index
-        :param quantization_args: quantization parameters for weight
         :param device: optional device to move compressed output to
         :return: dictionary of compressed weight data
         """
compressed_tensors/compressors/quantized_compressors/pack_quantized.py
@@ -68,9 +68,9 @@ class PackedQuantizationCompressor(BaseQuantizationCompressor):
         self,
         weight: Tensor,
         scale: Tensor,
+        quantization_args: QuantizationArgs,
         zero_point: Optional[Tensor] = None,
         g_idx: Optional[torch.Tensor] = None,
-        quantization_args: Optional[QuantizationArgs] = None,
         device: Optional[torch.device] = None,
     ) -> Dict[str, torch.Tensor]:
         """
@@ -78,9 +78,9 @@ class PackedQuantizationCompressor(BaseQuantizationCompressor):
 
         :param weight: uncompressed weight tensor
         :param scale: quantization scale for weight
+        :param quantization_args: quantization parameters for weight
         :param zero_point: quantization zero point for weight
         :param g_idx: optional mapping from column index to group index
-        :param quantization_args: quantization parameters for weight
         :param device: optional device to move compressed output to
         :return: dictionary of compressed weight data
         """
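Note that in both compressors above, quantization_args moves from a trailing optional to the third positional parameter, so call sites that passed zero_point or g_idx positionally would break. A toy stand-in (not the library API) showing why keyword calls are robust to this kind of reorder:

def compress_weight(weight, scale, quantization_args, zero_point=None,
                    g_idx=None, device=None):
    # signature shape mirrors the diffs above; body is illustrative only
    return {"weight": weight, "scale": scale, "args": quantization_args}


# keyword arguments keep working regardless of parameter order:
out = compress_weight(weight=[1.0], scale=0.5,
                      quantization_args={"num_bits": 8}, zero_point=0)
print(out["args"])  # {'num_bits': 8}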
compressed_tensors/compressors/sparse_compressors/base.py
@@ -13,7 +13,7 @@
 # limitations under the License.
 
 import logging
-from typing import Dict, Generator, Tuple
+from typing import Dict, Generator, Optional, Set, Tuple
 
 from compressed_tensors.compressors.base import BaseCompressor
 from compressed_tensors.utils import get_nested_weight_mappings, merge_names
@@ -30,7 +30,8 @@ _LOGGER: logging.Logger = logging.getLogger(__name__)
 class BaseSparseCompressor(BaseCompressor):
     """
     Base class representing a sparse compression algorithm. Each child class should
-    implement compression_param_info, compress_weight and decompress_weight
+    implement compression_param_info, compress_weight and decompress_weight; child
+    classes should also define COMPRESSION_PARAM_NAMES.
 
     Compressors support compressing/decompressing a full module state dict or a single
     quantized PyTorch leaf module.
@@ -59,11 +60,17 @@ class BaseSparseCompressor(BaseCompressor):
     :param config: config specifying compression parameters
     """
 
-    def compress(self, model_state: Dict[str, Tensor]) -> Dict[str, Tensor]:
+    def compress(
+        self,
+        model_state: Dict[str, Tensor],
+        compression_targets: Optional[Set[str]] = None,
+    ) -> Dict[str, Tensor]:
         """
         Compresses a dense state dict using bitmask compression
 
         :param model_state: state dict of uncompressed model
+        :param compression_targets: optional set of layer prefixes to compress,
+            otherwise compress all layers (for backwards compatibility)
         :return: compressed state dict
         """
         compressed_dict = {}
@@ -71,7 +78,14 @@ class BaseSparseCompressor(BaseCompressor):
             f"Compressing model with {len(model_state)} parameterized layers..."
         )
         for name, value in tqdm(model_state.items(), desc="Compressing model"):
-            compression_data = self.compress_weight(name, value)
+            if not self.should_compress(name, compression_targets):
+                compressed_dict[name] = value
+                continue
+            prefix = name
+            if prefix.endswith(".weight"):
+                prefix = prefix[: -(len(".weight"))]
+
+            compression_data = self.compress_weight(prefix, value)
             for key in compression_data.keys():
                 if key in compressed_dict:
                     _LOGGER.warn(
@@ -97,8 +111,10 @@ class BaseSparseCompressor(BaseCompressor):
         :param device: device to load decompressed weights onto
         :return: iterator for generating decompressed weights
         """
-        weight_mappings = get_nested_weight_mappings(
-            path_to_model_or_tensors, self.COMPRESSION_PARAM_NAMES
+        weight_mappings, ignored_params = get_nested_weight_mappings(
+            path_to_model_or_tensors,
+            self.COMPRESSION_PARAM_NAMES,
+            return_unmatched_params=True,
         )
         for weight_name in weight_mappings.keys():
             weight_data = {}
@@ -107,4 +123,26 @@ class BaseSparseCompressor(BaseCompressor):
                 with safe_open(safe_path, framework="pt", device=device) as f:
                     weight_data[param_name] = f.get_tensor(full_name)
             decompressed = self.decompress_weight(weight_data)
-            yield weight_name, decompressed
+            yield merge_names(weight_name, "weight"), decompressed
+
+        for ignored_param_name, safe_path in ignored_params.items():
+            with safe_open(safe_path, framework="pt", device=device) as f:
+                value = f.get_tensor(ignored_param_name)
+            yield ignored_param_name, value
+
+    @staticmethod
+    def should_compress(name: str, expanded_targets: Optional[Set[str]] = None) -> bool:
+        """
+        Check if a parameter should be compressed.
+        Currently, this only returns True for weight parameters.
+
+        :param name: name of the parameter
+        :param expanded_targets: set of layer prefixes to compress
+        :return: whether or not the parameter should be compressed
+        """
+        if expanded_targets is None:
+            return name.endswith(".weight")
+
+        return (
+            name.endswith(".weight") and name[: -(len(".weight"))] in expanded_targets
+        )
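The compression-target filter above is self-contained enough to run directly; here is a copy of the should_compress logic with a few example checks (the layer names are hypothetical):

from typing import Optional, Set


def should_compress(name: str, expanded_targets: Optional[Set[str]] = None) -> bool:
    # no targets given: compress every weight parameter (backwards compatible)
    if expanded_targets is None:
        return name.endswith(".weight")
    # targets given: compress only weights whose layer prefix was targeted
    return name.endswith(".weight") and name[: -len(".weight")] in expanded_targets


print(should_compress("layers.0.weight"))                # True
print(should_compress("layers.0.bias"))                  # False
print(should_compress("layers.0.weight", {"layers.1"}))  # False
print(should_compress("layers.1.weight", {"layers.1"}))  # True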
compressed_tensors/compressors/sparse_compressors/sparse_bitmask.py
@@ -19,6 +19,7 @@ import torch
 from compressed_tensors.compressors.base import BaseCompressor
 from compressed_tensors.compressors.sparse_compressors.base import BaseSparseCompressor
 from compressed_tensors.config import CompressionFormat
+from compressed_tensors.quantization import FP8_DTYPE
 from compressed_tensors.utils import merge_names
 from torch import Tensor
 
@@ -134,9 +135,14 @@ def bitmask_compress(tensor: Tensor) -> Tuple[Tensor, Tensor, Tensor]:
     bytemasks = tensor != 0
     row_counts = bytemasks.sum(dim=-1)
     row_offsets = torch.cumsum(row_counts, 0) - row_counts
-    values = tensor[bytemasks]
+    if tensor.dtype == FP8_DTYPE:
+        # access raw bytes of the tensor
+        tensor_view = tensor.view(torch.int8)
+        values = tensor_view[bytemasks]
+        values = values.view(FP8_DTYPE)
+    else:
+        values = tensor[bytemasks]
     bitmasks_packed = pack_bitmasks(bytemasks)
-
     return values, bitmasks_packed, row_offsets
 
 
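The FP8 branch above exists because boolean-mask indexing is not implemented for float8 tensors, so the bytes are reinterpreted as int8, gathered, and reinterpreted back. A standalone sketch of the trick, assuming a torch build with float8 dtypes (float8_e4m3fn stands in for the library's FP8_DTYPE constant here):

import torch

FP8_DTYPE = torch.float8_e4m3fn  # assumed stand-in for the library constant

tensor = torch.tensor([0.0, 1.5, 0.0, -2.0]).to(FP8_DTYPE)
bytemasks = tensor.float() != 0  # compare in float32 for portability

# view as int8 (same itemsize), mask, then view the gathered bytes as FP8 again
values = tensor.view(torch.int8)[bytemasks].view(FP8_DTYPE)
print(values.float())  # tensor([ 1.5000, -2.0000])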
compressed_tensors/quantization/lifecycle/apply.py
@@ -18,7 +18,7 @@ from collections import OrderedDict, defaultdict
 from copy import deepcopy
 from typing import Dict, Iterable, List, Optional
 from typing import OrderedDict as OrderedDictType
-from typing import Union
+from typing import Set, Union
 
 import torch
 from compressed_tensors.config import CompressionFormat
@@ -52,6 +52,8 @@ __all__ = [
     "apply_quantization_config",
     "apply_quantization_status",
     "find_name_or_class_matches",
+    "expand_sparse_target_names",
+    "is_sparse_target",
 ]
 
 from compressed_tensors.quantization.utils.helpers import is_module_quantized
@@ -245,6 +247,49 @@ def apply_quantization_status(model: Module, status: QuantizationStatus):
     model.apply(compress_quantized_weights)
 
 
+def expand_sparse_target_names(
+    model: Module, targets: Iterable[str], ignore: Iterable[str]
+) -> Set[str]:
+    """
+    Finds all unique module names in the model that match the given
+    targets and ignore lists.
+
+    Note: Targets must be regexes, layer types, or full layer names.
+
+    :param model: model to search for targets in
+    :param targets: list of targets to search for
+    :param ignore: list of targets to ignore
+    :return: set of all targets that match the given targets and should
+        not be ignored
+    """
+    return {
+        name
+        for name, module in iter_named_leaf_modules(model)
+        if is_sparse_target(name, module, targets, ignore)
+    }
+
+
+def is_sparse_target(
+    name: str, module: Module, targets: Iterable[str], ignore: Iterable[str]
+) -> bool:
+    """
+    Determines if a module should be included in the targets based on the
+    targets and ignore lists.
+
+    Note: Targets must be regexes, layer types, or full layer names.
+
+    :param name: name of the module
+    :param module: the module itself
+    :param targets: list of targets to search for
+    :param ignore: list of targets to ignore
+    :return: True if the module is a target and not ignored, False otherwise
+    """
+    return bool(
+        find_name_or_class_matches(name, module, targets)
+        and not find_name_or_class_matches(name, module, ignore or [])
+    )
+
+
 def find_name_or_class_matches(
     name: str, module: Module, targets: Iterable[str], check_contains: bool = False
 ) -> List[str]:
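expand_sparse_target_names above delegates the matching rules to find_name_or_class_matches. A simplified standalone sketch of the semantics the docstrings describe (regexes, layer types, or full layer names); the "re:" prefix convention and the matches() helper here are assumptions for illustration, not the library implementation:

import re

import torch.nn as nn


def matches(name: str, module: nn.Module, target: str) -> bool:
    # assumed convention: "re:" marks a regex over the module name;
    # otherwise match the class name or the exact module name
    if target.startswith("re:"):
        return re.match(target[3:], name) is not None
    return target in (module.__class__.__name__, name)


model = nn.Sequential(nn.Linear(4, 4), nn.ReLU(), nn.Linear(4, 2))
targets, ignore = ["Linear"], ["2"]  # ignore the last layer by name
sparse_targets = {
    name
    for name, module in model.named_modules()
    if name
    and any(matches(name, module, t) for t in targets)
    and not any(matches(name, module, t) for t in ignore)
}
print(sparse_targets)  # {'0'}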
compressed_tensors/quantization/lifecycle/forward.py
@@ -82,8 +82,8 @@ def quantize(
 def dequantize(
     x_q: torch.Tensor,
     scale: torch.Tensor,
-    zero_point: torch.Tensor = None,
-    args: QuantizationArgs = None,
+    zero_point: Optional[torch.Tensor] = None,
+    args: Optional[QuantizationArgs] = None,
     dtype: Optional[torch.dtype] = None,
     g_idx: Optional[torch.Tensor] = None,
 ) -> torch.Tensor:
compressed_tensors/utils/safetensors_load.py
@@ -16,7 +16,7 @@ import json
 import os
 import re
 import struct
-from typing import Dict, List, Optional
+from typing import Dict, List, Optional, Tuple, Union
 
 from safetensors import safe_open
 from torch import Tensor
@@ -30,10 +30,14 @@ __all__ = [
     "merge_names",
     "get_weight_mappings",
     "get_nested_weight_mappings",
+    "get_nested_mappings_from_state_dict",
     "get_quantization_state_dict",
     "is_quantization_param",
 ]
 
+WeightMappingType = Dict[str, str]
+NestedWeightMappingType = Dict[str, WeightMappingType]
+
 
 def get_safetensors_folder(
     pretrained_model_name_or_path: str, cache_dir: Optional[str] = None
@@ -92,7 +96,7 @@ def get_safetensors_header(safetensors_path: str) -> Dict[str, str]:
     return header
 
 
-def match_param_name(full_name: str, param_name: str) -> str:
+def match_param_name(full_name: str, param_name: str) -> Optional[str]:
     """
     Helper function extracting the uncompressed parameterized layer name from a
     compressed name. Assumes the compressed name was merged using merge_names.
@@ -176,38 +180,100 @@ def get_weight_mappings(path_to_model_or_tensors: str) -> Dict[str, str]:
 
 
 def get_nested_weight_mappings(
-    model_path: str, params_to_nest: List[str]
-) -> Dict[str, Dict[str, str]]:
+    model_path: str, params_to_nest: List[str], return_unmatched_params: bool = False
+) -> Union[NestedWeightMappingType, Tuple[NestedWeightMappingType, WeightMappingType]]:
     """
     Takes a path to a state dict saved in safetensors format and returns a nested
     mapping from uncompressed parameterized layer names to the file locations of
     each layer's compression parameters.
 
-    layer.weight: {
+    Example of the nested mapping:
+        layer: {
             bitmask: file_location,
             row_offsets: file_location,
             shape: file_location,
             compressed: file_location
         }
 
-    This generalizes to cases where the model is split into multiple safetensors files
+    If other parameters are found that do not match the nested parameters, they will
+    be returned in a separate dictionary only if return_unmatched_params is True.
+    This dictionary may be needed for cases where compressors are stacked (e.g.,
+    quantization compression followed by sparse compression).
+
+    Example of the unmatched params mapping:
+        {
+            layer.weight_scale: file_location,
+            layer.input_scale: file_location
+        }
 
-    :param model_path: path to safetensors state dict, must contain either a single
-    safetensors file or multiple files with an index
-    :return: nested mapping of parameterized layer name to file location
+    This generalizes to cases where the model is split into multiple safetensors
+    files.
+
+    :param model_path: Path to the safetensors state dict, must contain either a
+        single safetensors file or multiple files with an index.
+    :param params_to_nest: List of parameter names to nest.
+    :param return_unmatched_params: If True, return a second dictionary containing
+        the remaining parameters that were not matched to the params_to_nest.
+    :return:
+        - If return_unmatched_params is False:
+            NestedWeightMappingType: A nested mapping of parameterized layer names to
+                file locations of each layer's compression parameters.
+        - If return_unmatched_params is True:
+            Tuple[NestedWeightMappingType, WeightMappingType]: A tuple containing:
+                - NestedWeightMappingType: A nested mapping of parameterized layer
+                    names to file locations of each layer's compression parameters.
+                - WeightMappingType: A mapping of the remaining parameter names to
+                    their file locations that were not matched to the params_to_nest.
     """
     weight_mappings = get_weight_mappings(model_path)
-
     nested_weight_mappings = {}
-    for key in weight_mappings.keys():
+    unmatched_params = {}
+
+    for key, file_location in weight_mappings.items():
+        matched = False
         for param_name in params_to_nest:
-            maybe_match = match_param_name(key, param_name)
-            if maybe_match is not None:
-                dense_param = maybe_match
+            dense_param = match_param_name(key, param_name)
+            if dense_param:
                 if dense_param not in nested_weight_mappings:
                     nested_weight_mappings[dense_param] = {}
-                nested_weight_mappings[dense_param][param_name] = weight_mappings[key]
+                nested_weight_mappings[dense_param][param_name] = file_location
+                matched = True
+        if return_unmatched_params and not matched:
+            unmatched_params[key] = file_location
+
+    if return_unmatched_params:
+        return nested_weight_mappings, unmatched_params
+    return nested_weight_mappings
 
+
+def get_nested_mappings_from_state_dict(
+    state_dict, params_to_nest
+) -> NestedWeightMappingType:
+    """
+    Takes a state dict and returns a nested mapping from uncompressed
+    parameterized layer names to the value of
+    each layer's compression parameters.
+
+    Example of the nested mapping:
+        layer: {
+            weight_scale: ...,
+            weight: ...,
+            zero_point: ...,
+        }
+
+    :param state_dict: state dict of the model
+    :param params_to_nest: List of parameter names to nest.
+    :return: Nested mapping of parameterized layer names to the value of
+        each layer's compression parameters.
+    """
+    nested_weight_mappings = {}
+    for key in state_dict.keys():
+        for param_name in params_to_nest:
+            dense_param = match_param_name(key, param_name)
+            if dense_param:
+                if dense_param not in nested_weight_mappings:
+                    nested_weight_mappings[dense_param] = {}
+                nested_weight_mappings[dense_param][param_name] = state_dict[key]
     return nested_weight_mappings
 
 
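Both helpers above group compression parameters under their layer prefix. A standalone sketch of that grouping; match_param_name here is a simplified stand-in for the library helper, and the parameter names are hypothetical:

import re
from typing import Dict, List, Optional


def match_param_name(full_name: str, param_name: str) -> Optional[str]:
    match = re.match(rf"^(.*)\.{param_name}$", full_name)
    return match.group(1) if match else None


def nest(keys: List[str], params_to_nest: List[str]) -> Dict[str, Dict[str, str]]:
    nested: Dict[str, Dict[str, str]] = {}
    for key in keys:
        for param_name in params_to_nest:
            layer = match_param_name(key, param_name)
            if layer:
                nested.setdefault(layer, {})[param_name] = key
    return nested


keys = ["model.layers.0.weight", "model.layers.0.weight_scale", "model.layers.0.bias"]
print(nest(keys, ["weight", "weight_scale"]))
# {'model.layers.0': {'weight': 'model.layers.0.weight',
#                     'weight_scale': 'model.layers.0.weight_scale'}}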
compressed_tensors_nightly-0.8.1.20250107.dist-info/METADATA
@@ -1,6 +1,6 @@
 Metadata-Version: 2.1
 Name: compressed-tensors-nightly
-Version: 0.8.1.20250105
+Version: 0.8.1.20250107
 Summary: Library for utilization of compressed safetensors of neural network models
 Home-page: https://github.com/neuralmagic/compressed-tensors
 Author: Neuralmagic, Inc.
compressed_tensors_nightly-0.8.1.20250107.dist-info/RECORD
@@ -5,15 +5,15 @@ compressed_tensors/compressors/__init__.py,sha256=smSygTSfcfuujRrAXDc6uZm4L_ccV1
 compressed_tensors/compressors/base.py,sha256=D9TNwQcjanDiAHODPbg8JUqc66e3j50rctY7A708NEs,6743
 compressed_tensors/compressors/helpers.py,sha256=OK6qxX9j3bHwF9JfIYSGMgBJe2PWjlTA3byXKCJaTIQ,5431
 compressed_tensors/compressors/model_compressors/__init__.py,sha256=5RGGPFu4YqEt_aOdFSQYFYFDjcZFJN0CsMqRtDZz3Js,666
-compressed_tensors/compressors/model_compressors/model_compressor.py,sha256=
+compressed_tensors/compressors/model_compressors/model_compressor.py,sha256=nsMKqjdzEttvkabpp_7Qt4mhWcmjwRYnwjQzeN2a2E4,18295
 compressed_tensors/compressors/quantized_compressors/__init__.py,sha256=09UJq68Pht6Bf-4iP9xYl3tetKsncNPHD8IAGbePsr4,714
-compressed_tensors/compressors/quantized_compressors/base.py,sha256=
-compressed_tensors/compressors/quantized_compressors/naive_quantized.py,sha256=
-compressed_tensors/compressors/quantized_compressors/pack_quantized.py,sha256=
+compressed_tensors/compressors/quantized_compressors/base.py,sha256=LVqSSqSjGi8LB-X13zC_0AFHc8BobGQVC0zjInDhOWE,7217
+compressed_tensors/compressors/quantized_compressors/naive_quantized.py,sha256=fahmPJFz49rVS7q705uQwZ0kUtdP46GuXR7nPr6uIqI,4943
+compressed_tensors/compressors/quantized_compressors/pack_quantized.py,sha256=OO5dceCfNVuY8A23kBg6z2wk-zGUVqR_MyLvObvT7pk,7741
 compressed_tensors/compressors/sparse_compressors/__init__.py,sha256=i2TESH27l7KXeOhJ6hShIoI904XX96l-cRQiMR6MAaU,704
-compressed_tensors/compressors/sparse_compressors/base.py,sha256=
+compressed_tensors/compressors/sparse_compressors/base.py,sha256=9e841MQWr0j8m33ejDw_jP5_BIpQ5099x9_pvuZ-Nr0,5944
 compressed_tensors/compressors/sparse_compressors/dense.py,sha256=lSKNWRx6H7aUqaJj1j4qbXk8Gkm1UohbnvW1Rvq6Ra4,1284
-compressed_tensors/compressors/sparse_compressors/sparse_bitmask.py,sha256=
+compressed_tensors/compressors/sparse_compressors/sparse_bitmask.py,sha256=Z9qMJ2JyUaBNQe-CXBJLuWacnHdFArrJYZEZDmW5x8o,6889
 compressed_tensors/compressors/sparse_quantized_compressors/__init__.py,sha256=4f_cwcKXB1nVVMoiKgTFAc8jAPjPLElo-Df_EDm1_xw,675
 compressed_tensors/compressors/sparse_quantized_compressors/marlin_24.py,sha256=BMIQWTLlnUvxy14iEJegtiP75WHJeOVojey9mKOK1hE,9427
 compressed_tensors/config/__init__.py,sha256=ZBqWn3r6ku1qfmlHHYp0mQueY0i7Pwhr9rbQk9dDlMc,704
@@ -27,9 +27,9 @@ compressed_tensors/quantization/quant_args.py,sha256=jwC__lSmuiJ2qSJYYZGgWgQNbZu
 compressed_tensors/quantization/quant_config.py,sha256=vx06wBo91p4LCb3Vzd-2eCTUeIf_Sz2ZXRP263eQyjQ,10385
 compressed_tensors/quantization/quant_scheme.py,sha256=eQ0JrRZ80GX69fpwW87VzPzzhajhk4mUaJScjk82OY4,6010
 compressed_tensors/quantization/lifecycle/__init__.py,sha256=_uItzFWusyV74Zco_pHLOTdE9a83cL-R-ZdyQrBkIyw,772
-compressed_tensors/quantization/lifecycle/apply.py,sha256=
+compressed_tensors/quantization/lifecycle/apply.py,sha256=XS4M6N1opKBybhkuQsS338QVb_CKMhUM5TUKrqoNQ0k,16517
 compressed_tensors/quantization/lifecycle/compressed.py,sha256=Fj9n66IN0EWsOAkBHg3O0GlOQpxstqjCcs0ttzMXrJ0,2296
-compressed_tensors/quantization/lifecycle/forward.py,sha256=
+compressed_tensors/quantization/lifecycle/forward.py,sha256=DOWouUqfaLA4Qhg-ojVVBdhhSAlgZqFC26vZARxE0ko,12961
 compressed_tensors/quantization/lifecycle/helpers.py,sha256=C0mhy2vJ0fCjVeN4kFNhw8Eq1wkteBGHiZ36RVLThRY,944
 compressed_tensors/quantization/lifecycle/initialize.py,sha256=hymYtayTSumm8KCYAYPY267aWmlsJpt8oQFiRblk8qE,7452
 compressed_tensors/quantization/utils/__init__.py,sha256=VdtEmP0bvuND_IGQnyqUPc5lnFp-1_yD7StKSX4x80w,656
@@ -41,10 +41,10 @@ compressed_tensors/utils/helpers.py,sha256=XF36-SLkXnAHh0VzbvUlAdh6a88aCQvS_WeYs
 compressed_tensors/utils/offload.py,sha256=cMmzd9IdlNbs29CReHj1PPSLUM6OWaT5YumlLT5eP3w,13845
 compressed_tensors/utils/permutations_24.py,sha256=kx6fsfDHebx94zsSzhXGyCyuC9sVyah6BUUir_StT28,2530
 compressed_tensors/utils/permute.py,sha256=V6tJLKo3Syccj-viv4F7ZKZgJeCB-hl-dK8RKI_kBwI,2355
-compressed_tensors/utils/safetensors_load.py,sha256=
+compressed_tensors/utils/safetensors_load.py,sha256=fBuoHVPoBt1mkvqFJ60zQIASX_4nhl0-6QfFS27NY8I,11430
 compressed_tensors/utils/semi_structured_conversions.py,sha256=XKNffPum54kPASgqKzgKvyeqWPAkair2XEQXjkp7ho8,13489
-compressed_tensors_nightly-0.8.1.20250105.dist-info/LICENSE,sha256=xx0jnfkXJvxRnG63LTGOxlggYnIysveWIZ6H3PNdCrQ,11357
-compressed_tensors_nightly-0.8.1.20250105.dist-info/METADATA,sha256=
-compressed_tensors_nightly-0.8.1.20250105.dist-info/WHEEL,sha256=tZoeGjtWxWRfdplE7E3d45VPlLNQnvbKiYnx7gwAy8A,92
-compressed_tensors_nightly-0.8.1.20250105.dist-info/top_level.txt,sha256=w2i-GyPs2s1UwVxvutSvN_lM22SXC2hQFBmoMcPnV7Y,19
-compressed_tensors_nightly-0.8.1.20250105.dist-info/RECORD,,
+compressed_tensors_nightly-0.8.1.20250107.dist-info/LICENSE,sha256=xx0jnfkXJvxRnG63LTGOxlggYnIysveWIZ6H3PNdCrQ,11357
+compressed_tensors_nightly-0.8.1.20250107.dist-info/METADATA,sha256=X847Gn6LonNqO0XBA0k2lHz_BGmYt7oeHJ8-k90LVsA,6799
+compressed_tensors_nightly-0.8.1.20250107.dist-info/WHEEL,sha256=tZoeGjtWxWRfdplE7E3d45VPlLNQnvbKiYnx7gwAy8A,92
+compressed_tensors_nightly-0.8.1.20250107.dist-info/top_level.txt,sha256=w2i-GyPs2s1UwVxvutSvN_lM22SXC2hQFBmoMcPnV7Y,19
+compressed_tensors_nightly-0.8.1.20250107.dist-info/RECORD,,