compressed-tensors 0.8.0__tar.gz → 0.9.0__tar.gz
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {compressed-tensors-0.8.0 → compressed-tensors-0.9.0}/PKG-INFO +1 -1
- {compressed-tensors-0.8.0 → compressed-tensors-0.9.0}/setup.py +32 -6
- {compressed-tensors-0.8.0 → compressed-tensors-0.9.0}/src/compressed_tensors/compressors/model_compressors/model_compressor.py +92 -18
- {compressed-tensors-0.8.0 → compressed-tensors-0.9.0}/src/compressed_tensors/compressors/quantized_compressors/base.py +35 -5
- {compressed-tensors-0.8.0 → compressed-tensors-0.9.0}/src/compressed_tensors/compressors/quantized_compressors/naive_quantized.py +6 -4
- {compressed-tensors-0.8.0 → compressed-tensors-0.9.0}/src/compressed_tensors/compressors/quantized_compressors/pack_quantized.py +4 -2
- {compressed-tensors-0.8.0/src/compressed_tensors/config → compressed-tensors-0.9.0/src/compressed_tensors/compressors/sparse_compressors}/__init__.py +2 -1
- {compressed-tensors-0.8.0 → compressed-tensors-0.9.0}/src/compressed_tensors/compressors/sparse_compressors/base.py +45 -7
- compressed-tensors-0.9.0/src/compressed_tensors/compressors/sparse_compressors/sparse_24_bitmask.py +238 -0
- {compressed-tensors-0.8.0 → compressed-tensors-0.9.0}/src/compressed_tensors/compressors/sparse_compressors/sparse_bitmask.py +9 -40
- {compressed-tensors-0.8.0/src/compressed_tensors/compressors/sparse_compressors → compressed-tensors-0.9.0/src/compressed_tensors/config}/__init__.py +2 -1
- {compressed-tensors-0.8.0 → compressed-tensors-0.9.0}/src/compressed_tensors/config/base.py +1 -0
- compressed-tensors-0.9.0/src/compressed_tensors/config/sparse_24_bitmask.py +40 -0
- {compressed-tensors-0.8.0 → compressed-tensors-0.9.0}/src/compressed_tensors/linear/compressed_linear.py +3 -1
- {compressed-tensors-0.8.0 → compressed-tensors-0.9.0}/src/compressed_tensors/quantization/lifecycle/apply.py +48 -2
- {compressed-tensors-0.8.0 → compressed-tensors-0.9.0}/src/compressed_tensors/quantization/lifecycle/forward.py +2 -2
- {compressed-tensors-0.8.0 → compressed-tensors-0.9.0}/src/compressed_tensors/quantization/lifecycle/initialize.py +21 -45
- {compressed-tensors-0.8.0 → compressed-tensors-0.9.0}/src/compressed_tensors/quantization/quant_args.py +16 -3
- {compressed-tensors-0.8.0 → compressed-tensors-0.9.0}/src/compressed_tensors/quantization/quant_config.py +3 -3
- {compressed-tensors-0.8.0 → compressed-tensors-0.9.0}/src/compressed_tensors/quantization/quant_scheme.py +17 -24
- compressed-tensors-0.9.0/src/compressed_tensors/utils/helpers.py +326 -0
- compressed-tensors-0.9.0/src/compressed_tensors/utils/offload.py +404 -0
- {compressed-tensors-0.8.0 → compressed-tensors-0.9.0}/src/compressed_tensors/utils/safetensors_load.py +83 -17
- {compressed-tensors-0.8.0 → compressed-tensors-0.9.0}/src/compressed_tensors/version.py +1 -1
- {compressed-tensors-0.8.0 → compressed-tensors-0.9.0}/src/compressed_tensors.egg-info/PKG-INFO +1 -1
- {compressed-tensors-0.8.0 → compressed-tensors-0.9.0}/src/compressed_tensors.egg-info/SOURCES.txt +2 -0
- compressed-tensors-0.8.0/src/compressed_tensors/utils/helpers.py +0 -121
- compressed-tensors-0.8.0/src/compressed_tensors/utils/offload.py +0 -116
- {compressed-tensors-0.8.0 → compressed-tensors-0.9.0}/LICENSE +0 -0
- {compressed-tensors-0.8.0 → compressed-tensors-0.9.0}/README.md +0 -0
- {compressed-tensors-0.8.0 → compressed-tensors-0.9.0}/pyproject.toml +0 -0
- {compressed-tensors-0.8.0 → compressed-tensors-0.9.0}/setup.cfg +0 -0
- {compressed-tensors-0.8.0 → compressed-tensors-0.9.0}/src/compressed_tensors/__init__.py +0 -0
- {compressed-tensors-0.8.0 → compressed-tensors-0.9.0}/src/compressed_tensors/base.py +0 -0
- {compressed-tensors-0.8.0 → compressed-tensors-0.9.0}/src/compressed_tensors/compressors/__init__.py +0 -0
- {compressed-tensors-0.8.0 → compressed-tensors-0.9.0}/src/compressed_tensors/compressors/base.py +0 -0
- {compressed-tensors-0.8.0 → compressed-tensors-0.9.0}/src/compressed_tensors/compressors/helpers.py +0 -0
- {compressed-tensors-0.8.0 → compressed-tensors-0.9.0}/src/compressed_tensors/compressors/model_compressors/__init__.py +0 -0
- {compressed-tensors-0.8.0 → compressed-tensors-0.9.0}/src/compressed_tensors/compressors/quantized_compressors/__init__.py +0 -0
- {compressed-tensors-0.8.0 → compressed-tensors-0.9.0}/src/compressed_tensors/compressors/sparse_compressors/dense.py +0 -0
- {compressed-tensors-0.8.0 → compressed-tensors-0.9.0}/src/compressed_tensors/compressors/sparse_quantized_compressors/__init__.py +0 -0
- {compressed-tensors-0.8.0 → compressed-tensors-0.9.0}/src/compressed_tensors/compressors/sparse_quantized_compressors/marlin_24.py +0 -0
- {compressed-tensors-0.8.0 → compressed-tensors-0.9.0}/src/compressed_tensors/config/dense.py +0 -0
- {compressed-tensors-0.8.0 → compressed-tensors-0.9.0}/src/compressed_tensors/config/sparse_bitmask.py +0 -0
- {compressed-tensors-0.8.0 → compressed-tensors-0.9.0}/src/compressed_tensors/linear/__init__.py +0 -0
- {compressed-tensors-0.8.0 → compressed-tensors-0.9.0}/src/compressed_tensors/quantization/__init__.py +0 -0
- {compressed-tensors-0.8.0 → compressed-tensors-0.9.0}/src/compressed_tensors/quantization/lifecycle/__init__.py +0 -0
- {compressed-tensors-0.8.0 → compressed-tensors-0.9.0}/src/compressed_tensors/quantization/lifecycle/compressed.py +0 -0
- {compressed-tensors-0.8.0 → compressed-tensors-0.9.0}/src/compressed_tensors/quantization/lifecycle/helpers.py +0 -0
- {compressed-tensors-0.8.0 → compressed-tensors-0.9.0}/src/compressed_tensors/quantization/utils/__init__.py +0 -0
- {compressed-tensors-0.8.0 → compressed-tensors-0.9.0}/src/compressed_tensors/quantization/utils/helpers.py +0 -0
- {compressed-tensors-0.8.0 → compressed-tensors-0.9.0}/src/compressed_tensors/registry/__init__.py +0 -0
- {compressed-tensors-0.8.0 → compressed-tensors-0.9.0}/src/compressed_tensors/registry/registry.py +0 -0
- {compressed-tensors-0.8.0 → compressed-tensors-0.9.0}/src/compressed_tensors/utils/__init__.py +0 -0
- {compressed-tensors-0.8.0 → compressed-tensors-0.9.0}/src/compressed_tensors/utils/permutations_24.py +0 -0
- {compressed-tensors-0.8.0 → compressed-tensors-0.9.0}/src/compressed_tensors/utils/permute.py +0 -0
- {compressed-tensors-0.8.0 → compressed-tensors-0.9.0}/src/compressed_tensors/utils/semi_structured_conversions.py +0 -0
- {compressed-tensors-0.8.0 → compressed-tensors-0.9.0}/src/compressed_tensors.egg-info/dependency_links.txt +0 -0
- {compressed-tensors-0.8.0 → compressed-tensors-0.9.0}/src/compressed_tensors.egg-info/requires.txt +0 -0
- {compressed-tensors-0.8.0 → compressed-tensors-0.9.0}/src/compressed_tensors.egg-info/top_level.txt +0 -0
{compressed-tensors-0.8.0 → compressed-tensors-0.9.0}/setup.py

```diff
@@ -1,11 +1,11 @@
 # Copyright (c) 2021 - present / Neuralmagic, Inc. All Rights Reserved.
-#
+#
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
 # You may obtain a copy of the License at
-#
+#
 # http://www.apache.org/licenses/LICENSE-2.0
-#
+#
 # Unless required by applicable law or agreed to in writing,
 # software distributed under the License is distributed on an "AS IS" BASIS,
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
@@ -15,7 +15,33 @@
 import os
 from setuptools import setup, find_packages
 from typing import List, Dict, Tuple
-
+
+
+def get_release_and_version(package_path: str) -> Tuple[bool, bool, str, str, str, str]:
+    """
+    Load version and release info from compressed-tensors package
+    """
+    # compressed-tensors/src/compressed_tensors/version.py always exists, default source of truth
+    version_path = os.path.join(package_path, "version.py")
+
+    # exec() cannot set local variables so need to manually
+    locals_dict = {}
+    exec(open(version_path).read(), globals(), locals_dict)
+    is_release = locals_dict.get("is_release", False)
+    version = locals_dict.get("version", "unknown")
+    version_major = locals_dict.get("version_major", "unknown")
+    version_minor = locals_dict.get("version_minor", "unknown")
+    version_bug = locals_dict.get("version_bug", "unknown")
+
+    print(f"Loaded version {version} from {version_path}")
+
+    return (
+        is_release,
+        version,
+        version_major,
+        version_minor,
+        version_bug,
+    )
 
 
 package_path = os.path.join(
@@ -35,7 +61,7 @@ if is_release:
     _PACKAGE_NAME = "compressed-tensors"
 else:
     _PACKAGE_NAME = "compressed-tensors-nightly"
-
+
 
 def _setup_long_description() -> Tuple[str, str]:
     return open("README.md", "r", encoding="utf-8").read(), "text/markdown"
@@ -44,7 +70,7 @@ def _setup_packages() -> List:
     return find_packages(
         "src", include=["compressed_tensors", "compressed_tensors.*"], exclude=["*.__pycache__.*"]
    )
-
+
 def _setup_install_requires() -> List:
    return ["torch>=1.7.0", "transformers", "pydantic>=2.0"]
 
```
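Taken together, the setup.py hunks replace hard-coded version handling with a `get_release_and_version` helper whose `is_release` flag selects between the `compressed-tensors` and `compressed-tensors-nightly` package names. A minimal standalone sketch of the pattern (the `version_file` path follows the comment in the hunk; `package_name` is illustrative, not a variable from setup.py):

```python
# Sketch of the version-loading pattern added to setup.py (not the file itself).
version_file = "src/compressed_tensors/version.py"  # assumed location

# exec() cannot rebind true local variables, so results are collected in a dict
locals_dict = {}
exec(open(version_file).read(), globals(), locals_dict)

is_release = locals_dict.get("is_release", False)
version = locals_dict.get("version", "unknown")

# mirrors the _PACKAGE_NAME selection in the @@ -35,7 +61,7 @@ hunk
package_name = "compressed-tensors" if is_release else "compressed-tensors-nightly"
print(package_name, version)
```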
{compressed-tensors-0.8.0 → compressed-tensors-0.9.0}/src/compressed_tensors/compressors/model_compressors/model_compressor.py

```diff
@@ -17,14 +17,14 @@ import logging
 import operator
 import os
 import re
+from contextlib import contextmanager
 from copy import deepcopy
-from typing import TYPE_CHECKING, Any, Dict, Optional, TypeVar, Union
+from typing import TYPE_CHECKING, Any, Dict, Optional, Set, TypeVar, Union
 
 import compressed_tensors
 import torch
 import transformers
 from compressed_tensors.base import (
-    COMPRESSION_CONFIG_NAME,
     COMPRESSION_VERSION_NAME,
     QUANTIZATION_CONFIG_NAME,
     QUANTIZATION_METHOD_NAME,
@@ -39,6 +39,8 @@ from compressed_tensors.quantization import (
     apply_quantization_config,
     load_pretrained_quantization,
 )
+from compressed_tensors.quantization.lifecycle import expand_sparse_target_names
+from compressed_tensors.quantization.quant_args import QuantizationArgs
 from compressed_tensors.quantization.utils import (
     is_module_quantized,
     iter_named_leaf_modules,
@@ -103,12 +105,13 @@ class ModelCompressor:
         :return: compressor for the configs, or None if model is not compressed
         """
         config = AutoConfig.from_pretrained(pretrained_model_name_or_path, **kwargs)
-        compression_config = getattr(config, COMPRESSION_CONFIG_NAME, None)
+        compression_config = getattr(config, QUANTIZATION_CONFIG_NAME, None)
         return cls.from_compression_config(compression_config)
 
     @classmethod
     def from_compression_config(
-        cls, compression_config: Union[Dict[str, Any], "CompressedTensorsConfig"]
+        cls,
+        compression_config: Union[Dict[str, Any], "CompressedTensorsConfig"],
     ):
         """
         :param compression_config:
@@ -135,7 +138,7 @@ class ModelCompressor:
             format, **sparsity_config
         )
         if quantization_config is not None:
-            quantization_config = QuantizationConfig.parse_obj(quantization_config)
+            quantization_config = QuantizationConfig.model_validate(quantization_config)
 
         return cls(
             sparsity_config=sparsity_config, quantization_config=quantization_config
@@ -191,7 +194,7 @@ class ModelCompressor:
 
         if is_compressed_tensors_config(compression_config):
             s_config = compression_config.sparsity_config
-            return s_config.dict() if s_config is not None else None
+            return s_config.model_dump() if s_config is not None else None
 
         return compression_config.get(SPARSITY_CONFIG_NAME, None)
 
@@ -212,7 +215,7 @@ class ModelCompressor:
 
         if is_compressed_tensors_config(compression_config):
             q_config = compression_config.quantization_config
-            return q_config.dict() if q_config is not None else None
+            return q_config.model_dump() if q_config is not None else None
 
         quantization_config = deepcopy(compression_config)
         quantization_config.pop(SPARSITY_CONFIG_NAME, None)
```
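The last three hunks are a pydantic v2 migration: configs are now validated with `model_validate` and serialized with `model_dump`, the v2 replacements for v1's `parse_obj` and `dict`. A self-contained illustration with a stand-in model (the class and fields below are hypothetical, not from the codebase):

```python
from typing import List

from pydantic import BaseModel


class SparsityConfigSketch(BaseModel):  # hypothetical stand-in config class
    format: str = "dense"
    targets: List[str] = []


data = {"format": "sparse-bitmask", "targets": ["Linear"]}
config = SparsityConfigSketch.model_validate(data)   # v2 replacement for parse_obj()
assert config.model_dump()["targets"] == ["Linear"]  # v2 replacement for dict()
```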
{compressed-tensors-0.8.0 → compressed-tensors-0.9.0}/src/compressed_tensors/compressors/model_compressors/model_compressor.py (continued)

```diff
@@ -265,7 +268,11 @@ class ModelCompressor:
         state_dict = model.state_dict()
 
         compressed_state_dict = state_dict
-
+
+        quantized_modules_to_args: Dict[
+            str, QuantizationArgs
+        ] = map_modules_to_quant_args(model)
+
         if self.quantization_compressor is not None:
             compressed_state_dict = self.quantization_compressor.compress(
                 state_dict, names_to_scheme=quantized_modules_to_args
@@ -276,8 +283,14 @@ class ModelCompressor:
             )
 
         if self.sparsity_compressor is not None:
+            sparse_compression_targets: Set[str] = expand_sparse_target_names(
+                model=model,
+                targets=self.sparsity_config.targets,
+                ignore=self.sparsity_config.ignore,
+            )
             compressed_state_dict = self.sparsity_compressor.compress(
-                compressed_state_dict
+                compressed_state_dict,
+                compression_targets=sparse_compression_targets,
             )
 
         # HACK: Override the dtype_byte_size function in transformers to
@@ -295,23 +308,44 @@ class ModelCompressor:
         :param model: pytorch model to load decompressed weights into
         """
         model_path = get_safetensors_folder(model_path)
-        if self.sparsity_compressor is not None:
+        sparse_decompressed = False
+
+        if (
+            self.sparsity_compressor is not None
+            and self.sparsity_config.format != CompressionFormat.dense.value
+        ):
+            # Sparse decompression is applied on the model_path
             dense_gen = self.sparsity_compressor.decompress(model_path)
             self._replace_weights(dense_gen, model)
             setattr(model, SPARSITY_CONFIG_NAME, self.sparsity_compressor.config)
+            sparse_decompressed = True
 
         if self.quantization_compressor is not None:
-            names_to_scheme = apply_quantization_config(model, self.quantization_config)
-            load_pretrained_quantization(model, model_path)
+            # Temporarily set quantization status to FROZEN to prevent
+            # quantization during apply_quantization_config. This ensures
+            # that the dtypes of the weights are not unintentionally updated.
+            # The status is restored after quantization params are loaded.
+            with override_quantization_status(
+                self.quantization_config, QuantizationStatus.FROZEN
+            ):
+                names_to_scheme = apply_quantization_config(
+                    model, self.quantization_config
+                )
+                load_pretrained_quantization(model, model_path)
+
+            model_path_or_state_dict = (
+                model.state_dict() if sparse_decompressed else model_path
+            )
+
             dense_gen = self.quantization_compressor.decompress(
-                model_path, names_to_scheme=names_to_scheme
+                model_path_or_state_dict, names_to_scheme=names_to_scheme
             )
             self._replace_weights(dense_gen, model)
 
-            def update_status(module):
+            def freeze_quantization_status(module):
                 module.quantization_status = QuantizationStatus.FROZEN
 
-            model.apply(update_status)
+            model.apply(freeze_quantization_status)
             setattr(model, QUANTIZATION_CONFIG_NAME, self.quantization_config)
@@ -361,15 +395,35 @@ class ModelCompressor:
         with open(config_file_path, "w") as config_file:
             json.dump(config_data, config_file, indent=2, sort_keys=True)
 
-    def _replace_weights(self, dense_weight_generator, model):
+    def _replace_weights(self, dense_weight_generator, model: Module):
+        """
+        Replace the weights of the model with the
+        provided dense weights.
+
+        This method iterates over the dense_weight_generator and
+        updates the corresponding weights in the model. If a parameter
+        name does not exist in the model, it will be skipped.
+
+        :param dense_weight_generator (generator): A generator that yields
+            tuples of (name, data), where 'name' is the parameter name and
+            'data' is the updated param data
+        :param model: The model whose weights are to be updated.
+        """
         for name, data in tqdm(dense_weight_generator, desc="Decompressing model"):
             split_name = name.split(".")
             prefix, param_name = ".".join(split_name[:-1]), split_name[-1]
             module = operator.attrgetter(prefix)(model)
-            update_parameter_data(module, data, param_name)
+            if hasattr(module, param_name):
+                update_parameter_data(module, data, param_name)
 
 
-def map_modules_to_quant_args(model: Module) -> Dict:
+def map_modules_to_quant_args(model: Module) -> Dict[str, QuantizationArgs]:
+    """
+    Given a pytorch model, map out the submodule name (usually linear layers)
+    to the QuantizationArgs
+
+    :param model: pytorch model
+    """
     quantized_modules_to_args = {}
     for name, submodule in iter_named_leaf_modules(model):
         if is_module_quantized(submodule):
```
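`_replace_weights` resolves dotted parameter names with `operator.attrgetter` and, new in 0.9.0, skips names the target module does not actually own. The lookup can be seen in isolation below (the `Sequential` model and parameter name are illustrative):

```python
import operator

import torch

# illustrative model; real names look like "model.layers.0.self_attn.q_proj.weight"
model = torch.nn.Sequential(torch.nn.Linear(4, 4))
name = "0.weight"

split_name = name.split(".")
prefix, param_name = ".".join(split_name[:-1]), split_name[-1]
module = operator.attrgetter(prefix)(model)  # resolves "0" to the Linear layer

# the guard added in 0.9.0: only update parameters that exist on the module
if hasattr(module, param_name):
    print(f"would update {name}:", getattr(module, param_name).shape)
```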
{compressed-tensors-0.8.0 → compressed-tensors-0.9.0}/src/compressed_tensors/compressors/model_compressors/model_compressor.py (continued)

```diff
@@ -390,3 +444,23 @@ def new_dtype_byte_size(dtype):
         raise ValueError(f"`dtype` is not a valid dtype: {dtype}.")
     bit_size = int(bit_search.groups()[0])
     return bit_size // 8
+
+
+@contextmanager
+def override_quantization_status(
+    config: QuantizationConfig, status: QuantizationStatus
+):
+    """
+    Within this context, the quantization status will be set to the
+    supplied status. After the context exits, the original status
+    will be restored.
+
+    :param config: the quantization config to override
+    :param status: the status to temporarily set
+    """
+    original_status = config.quantization_status
+    config.quantization_status = status
+    try:
+        yield
+    finally:
+        config.quantization_status = original_status
```
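The `try`/`finally` in `override_quantization_status` guarantees the original status comes back even if the body raises. A runnable sketch of the same behavior against a stand-in config object (`DummyConfig` is hypothetical; the real argument is a `QuantizationConfig`):

```python
from contextlib import contextmanager


class DummyConfig:  # hypothetical stand-in for QuantizationConfig
    quantization_status = "initialized"


@contextmanager
def override_quantization_status(config, status):
    original_status = config.quantization_status
    config.quantization_status = status
    try:
        yield
    finally:
        config.quantization_status = original_status  # restored even on error


cfg = DummyConfig()
with override_quantization_status(cfg, "frozen"):
    assert cfg.quantization_status == "frozen"
assert cfg.quantization_status == "initialized"
```

This is what lets `ModelCompressor.decompress` run `apply_quantization_config` without permanently flipping the model into a quantizing state.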
{compressed-tensors-0.8.0 → compressed-tensors-0.9.0}/src/compressed_tensors/compressors/quantized_compressors/base.py

```diff
@@ -13,12 +13,17 @@
 # limitations under the License.
 
 import logging
-from typing import Dict, Generator, Tuple
+from pathlib import Path
+from typing import Any, Dict, Generator, Tuple, Union
 
 import torch
 from compressed_tensors.compressors.base import BaseCompressor
 from compressed_tensors.quantization import QuantizationArgs
-from compressed_tensors.utils import get_nested_weight_mappings, merge_names
+from compressed_tensors.utils import (
+    get_nested_mappings_from_state_dict,
+    get_nested_weight_mappings,
+    merge_names,
+)
 from safetensors import safe_open
 from torch import Tensor
 from tqdm import tqdm
@@ -113,7 +118,7 @@ class BaseQuantizationCompressor(BaseCompressor):
 
     def decompress(
         self,
-        path_to_model_or_tensors: str,
+        path_to_model_or_tensors: Union[str, Path, Dict[str, Any]],
         names_to_scheme: Dict[str, QuantizationArgs],
         device: str = "cpu",
     ) -> Generator[Tuple[str, Tensor], None, None]:
@@ -121,15 +126,25 @@ class BaseQuantizationCompressor(BaseCompressor):
         Reads a compressed state dict located at path_to_model_or_tensors
         and returns a generator for sequentially decompressing back to a
         dense state dict
-
         :param path_to_model_or_tensors: path to compressed safetensors model (directory
             with one or more safetensors files) or compressed tensors file
         :param names_to_scheme: quantization args for each quantized weight
         :param device: optional device to load intermediate weights into
         :return: compressed state dict
         """
+        if isinstance(path_to_model_or_tensors, (str, Path)):
+            yield from self._decompress_from_path(
+                path_to_model_or_tensors, names_to_scheme, device
+            )
+
+        else:
+            yield from self._decompress_from_state_dict(
+                path_to_model_or_tensors, names_to_scheme
+            )
+
+    def _decompress_from_path(self, path_to_model, names_to_scheme, device):
         weight_mappings = get_nested_weight_mappings(
-            path_to_model_or_tensors, self.COMPRESSION_PARAM_NAMES
+            path_to_model, self.COMPRESSION_PARAM_NAMES
         )
         for weight_name in weight_mappings.keys():
             weight_data = {}
@@ -137,6 +152,21 @@ class BaseQuantizationCompressor(BaseCompressor):
                 full_name = merge_names(weight_name, param_name)
                 with safe_open(safe_path, framework="pt", device=device) as f:
                     weight_data[param_name] = f.get_tensor(full_name)
+            if "weight_scale" in weight_data:
+                quant_args = names_to_scheme[weight_name]
+                decompressed = self.decompress_weight(
+                    compressed_data=weight_data, quantization_args=quant_args
+                )
+                yield merge_names(weight_name, "weight"), decompressed
+
+    def _decompress_from_state_dict(self, state_dict, names_to_scheme):
+        weight_mappings = get_nested_mappings_from_state_dict(
+            state_dict, self.COMPRESSION_PARAM_NAMES
+        )
+        for weight_name in weight_mappings.keys():
+            weight_data = {}
+            for param_name, param_value in weight_mappings[weight_name].items():
+                weight_data[param_name] = param_value
 
             if "weight_scale" in weight_data:
                 quant_args = names_to_scheme[weight_name]
```
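With these hunks, `decompress` accepts either a filesystem path or an in-memory state dict and dispatches on type; this is what allows `ModelCompressor.decompress` to feed the sparse-decompressed `model.state_dict()` straight into the quantization compressor. A self-contained sketch of the dispatch shape (the helper bodies here are placeholders, not the real implementations):

```python
from pathlib import Path
from typing import Any, Dict, Iterator, Union


def decompress(src: Union[str, Path, Dict[str, Any]]) -> Iterator[str]:
    # mirrors the isinstance dispatch added in 0.9.0
    if isinstance(src, (str, Path)):
        yield from _decompress_from_path(src)
    else:
        yield from _decompress_from_state_dict(src)


def _decompress_from_path(path) -> Iterator[str]:
    yield f"streamed from disk: {path}"  # placeholder body


def _decompress_from_state_dict(state_dict) -> Iterator[str]:
    for key in state_dict:
        yield f"read from memory: {key}"  # placeholder body


print(list(decompress("some/model/dir")))
print(list(decompress({"layer.weight_packed": 0, "layer.weight_scale": 1})))
```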
{compressed-tensors-0.8.0 → compressed-tensors-0.9.0}/src/compressed_tensors/compressors/quantized_compressors/naive_quantized.py

```diff
@@ -68,9 +68,9 @@ class NaiveQuantizationCompressor(BaseQuantizationCompressor):
         self,
         weight: Tensor,
         scale: Tensor,
+        quantization_args: QuantizationArgs,
         zero_point: Optional[Tensor] = None,
         g_idx: Optional[torch.Tensor] = None,
-        quantization_args: Optional[QuantizationArgs] = None,
         device: Optional[torch.device] = None,
     ) -> Dict[str, torch.Tensor]:
         """
@@ -78,9 +78,9 @@ class NaiveQuantizationCompressor(BaseQuantizationCompressor):
 
         :param weight: uncompressed weight tensor
         :param scale: quantization scale for weight
+        :param quantization_args: quantization parameters for weight
         :param zero_point: quantization zero point for weight
         :param g_idx: optional mapping from column index to group index
-        :param quantization_args: quantization parameters for weight
         :param device: optional device to move compressed output to
         :return: dictionary of compressed weight data
         """
@@ -93,9 +93,11 @@ class NaiveQuantizationCompressor(BaseQuantizationCompressor):
             args=quantization_args,
             dtype=quantization_args.pytorch_dtype(),
         )
+        else:
+            quantized_weight = weight
 
-        if device is not None:
-            quantized_weight = quantized_weight.to(device)
+        if device is not None:
+            quantized_weight = quantized_weight.to(device)
 
         return {"weight": quantized_weight}
 
```

{compressed-tensors-0.8.0 → compressed-tensors-0.9.0}/src/compressed_tensors/compressors/quantized_compressors/pack_quantized.py

```diff
@@ -68,9 +68,9 @@ class PackedQuantizationCompressor(BaseQuantizationCompressor):
         self,
         weight: Tensor,
         scale: Tensor,
+        quantization_args: QuantizationArgs,
         zero_point: Optional[Tensor] = None,
         g_idx: Optional[torch.Tensor] = None,
-        quantization_args: Optional[QuantizationArgs] = None,
         device: Optional[torch.device] = None,
     ) -> Dict[str, torch.Tensor]:
         """
@@ -78,9 +78,9 @@ class PackedQuantizationCompressor(BaseQuantizationCompressor):
 
         :param weight: uncompressed weight tensor
         :param scale: quantization scale for weight
+        :param quantization_args: quantization parameters for weight
         :param zero_point: quantization zero point for weight
         :param g_idx: optional mapping from column index to group index
-        :param quantization_args: quantization parameters for weight
         :param device: optional device to move compressed output to
         :return: dictionary of compressed weight data
         """
@@ -94,6 +94,8 @@ class PackedQuantizationCompressor(BaseQuantizationCompressor):
             args=quantization_args,
             dtype=torch.int8,
         )
+        else:
+            quantized_weight = weight
 
         packed_weight = pack_to_int32(quantized_weight, quantization_args.num_bits)
         weight_shape = torch.tensor(weight.shape)
```
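Two things change in both compressors: `quantization_args` becomes a required positional parameter (keyword call sites are unaffected, but positional callers must be updated), and the new `else` branch passes the weight through untouched when the quantize step is skipped; without it, Python would leave `quantized_weight` unbound. A toy illustration of the second point (`round` stands in for the real quantize call):

```python
# Why the added `else: quantized_weight = weight` matters: a name assigned only
# in a skipped `if` body is undefined afterwards, raising UnboundLocalError.
def compress_weight_sketch(weight: float, can_quantize: bool) -> dict:
    if can_quantize:                      # stand-in for the real eligibility check
        quantized_weight = round(weight)  # stand-in for the real quantize() call
    else:
        quantized_weight = weight         # 0.9.0: fall through unchanged
    return {"weight": quantized_weight}


print(compress_weight_sketch(1.7, can_quantize=True))   # {'weight': 2}
print(compress_weight_sketch(1.7, can_quantize=False))  # {'weight': 1.7}
```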
{compressed-tensors-0.8.0 → compressed-tensors-0.9.0}/src/compressed_tensors/compressors/sparse_compressors/__init__.py

```diff
@@ -11,8 +11,9 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
-
 # flake8: noqa
+
 from .base import *
 from .dense import *
+from .sparse_24_bitmask import *
 from .sparse_bitmask import *
```
{compressed-tensors-0.8.0 → compressed-tensors-0.9.0}/src/compressed_tensors/compressors/sparse_compressors/base.py

```diff
@@ -13,7 +13,7 @@
 # limitations under the License.
 
 import logging
-from typing import Dict, Generator, Tuple
+from typing import Dict, Generator, Optional, Set, Tuple
 
 from compressed_tensors.compressors.base import BaseCompressor
 from compressed_tensors.utils import get_nested_weight_mappings, merge_names
@@ -30,7 +30,8 @@ _LOGGER: logging.Logger = logging.getLogger(__name__)
 class BaseSparseCompressor(BaseCompressor):
     """
     Base class representing a sparse compression algorithm. Each child class should
-    implement compression_param_info, compress_weight and decompress_weight
+    implement compression_param_info, compress_weight and decompress_weight; child
+    classes should also define COMPRESSION_PARAM_NAMES.
 
     Compressors support compressing/decompressing a full module state dict or a single
     quantized PyTorch leaf module.
@@ -59,11 +60,17 @@ class BaseSparseCompressor(BaseCompressor):
     :param config: config specifying compression parameters
     """
 
-    def compress(self, model_state: Dict[str, Tensor]) -> Dict[str, Tensor]:
+    def compress(
+        self,
+        model_state: Dict[str, Tensor],
+        compression_targets: Optional[Set[str]] = None,
+    ) -> Dict[str, Tensor]:
         """
         Compresses a dense state dict using bitmask compression
 
         :param model_state: state dict of uncompressed model
+        :param compression_targets: optional set of layer prefixes to compress,
+            otherwise compress all layers (for backwards compatibility)
         :return: compressed state dict
         """
         compressed_dict = {}
@@ -71,7 +78,14 @@ class BaseSparseCompressor(BaseCompressor):
             f"Compressing model with {len(model_state)} parameterized layers..."
         )
         for name, value in tqdm(model_state.items(), desc="Compressing model"):
-            compression_data = self.compress_weight(name, value)
+            if not self.should_compress(name, compression_targets):
+                compressed_dict[name] = value
+                continue
+            prefix = name
+            if prefix.endswith(".weight"):
+                prefix = prefix[: -(len(".weight"))]
+
+            compression_data = self.compress_weight(prefix, value)
             for key in compression_data.keys():
                 if key in compressed_dict:
                     _LOGGER.warn(
@@ -97,8 +111,10 @@ class BaseSparseCompressor(BaseCompressor):
         :param device: device to load decompressed weights onto
         :return: iterator for generating decompressed weights
         """
-        weight_mappings = get_nested_weight_mappings(
-            path_to_model_or_tensors, self.COMPRESSION_PARAM_NAMES
+        weight_mappings, ignored_params = get_nested_weight_mappings(
+            path_to_model_or_tensors,
+            self.COMPRESSION_PARAM_NAMES,
+            return_unmatched_params=True,
         )
         for weight_name in weight_mappings.keys():
             weight_data = {}
@@ -107,4 +123,26 @@ class BaseSparseCompressor(BaseCompressor):
                 with safe_open(safe_path, framework="pt", device=device) as f:
                     weight_data[param_name] = f.get_tensor(full_name)
             decompressed = self.decompress_weight(weight_data)
-            yield weight_name, decompressed
+            yield merge_names(weight_name, "weight"), decompressed
+
+        for ignored_param_name, safe_path in ignored_params.items():
+            with safe_open(safe_path, framework="pt", device=device) as f:
+                value = f.get_tensor(ignored_param_name)
+            yield ignored_param_name, value
+
+    @staticmethod
+    def should_compress(name: str, expanded_targets: Optional[Set[str]] = None) -> bool:
+        """
+        Check if a parameter should be compressed.
+        Currently, this only returns True for weight parameters.
+
+        :param name: name of the parameter
+        :param expanded_targets: set of layer prefixes to compress
+        :return: whether or not the parameter should be compressed
+        """
+        if expanded_targets is None:
+            return name.endswith(".weight")
+
+        return (
+            name.endswith(".weight") and name[: -(len(".weight"))] in expanded_targets
+        )
```