compressed-tensors-nightly 0.3.3.20240514__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (42)
  1. compressed_tensors/__init__.py +21 -0
  2. compressed_tensors/base.py +17 -0
  3. compressed_tensors/compressors/__init__.py +22 -0
  4. compressed_tensors/compressors/base.py +59 -0
  5. compressed_tensors/compressors/dense.py +34 -0
  6. compressed_tensors/compressors/helpers.py +137 -0
  7. compressed_tensors/compressors/int_quantized.py +95 -0
  8. compressed_tensors/compressors/model_compressor.py +264 -0
  9. compressed_tensors/compressors/sparse_bitmask.py +239 -0
  10. compressed_tensors/config/__init__.py +18 -0
  11. compressed_tensors/config/base.py +43 -0
  12. compressed_tensors/config/dense.py +36 -0
  13. compressed_tensors/config/sparse_bitmask.py +36 -0
  14. compressed_tensors/quantization/__init__.py +21 -0
  15. compressed_tensors/quantization/lifecycle/__init__.py +23 -0
  16. compressed_tensors/quantization/lifecycle/apply.py +196 -0
  17. compressed_tensors/quantization/lifecycle/calibration.py +51 -0
  18. compressed_tensors/quantization/lifecycle/compressed.py +69 -0
  19. compressed_tensors/quantization/lifecycle/forward.py +333 -0
  20. compressed_tensors/quantization/lifecycle/frozen.py +50 -0
  21. compressed_tensors/quantization/lifecycle/initialize.py +99 -0
  22. compressed_tensors/quantization/observers/__init__.py +21 -0
  23. compressed_tensors/quantization/observers/base.py +130 -0
  24. compressed_tensors/quantization/observers/helpers.py +54 -0
  25. compressed_tensors/quantization/observers/memoryless.py +48 -0
  26. compressed_tensors/quantization/observers/min_max.py +80 -0
  27. compressed_tensors/quantization/quant_args.py +125 -0
  28. compressed_tensors/quantization/quant_config.py +210 -0
  29. compressed_tensors/quantization/quant_scheme.py +39 -0
  30. compressed_tensors/quantization/utils/__init__.py +16 -0
  31. compressed_tensors/quantization/utils/helpers.py +131 -0
  32. compressed_tensors/registry/__init__.py +17 -0
  33. compressed_tensors/registry/registry.py +360 -0
  34. compressed_tensors/utils/__init__.py +16 -0
  35. compressed_tensors/utils/helpers.py +45 -0
  36. compressed_tensors/utils/safetensors_load.py +237 -0
  37. compressed_tensors/version.py +50 -0
  38. compressed_tensors_nightly-0.3.3.20240514.dist-info/LICENSE +201 -0
  39. compressed_tensors_nightly-0.3.3.20240514.dist-info/METADATA +105 -0
  40. compressed_tensors_nightly-0.3.3.20240514.dist-info/RECORD +42 -0
  41. compressed_tensors_nightly-0.3.3.20240514.dist-info/WHEEL +5 -0
  42. compressed_tensors_nightly-0.3.3.20240514.dist-info/top_level.txt +1 -0
@@ -0,0 +1,21 @@
+ # Copyright (c) 2021 - present / Neuralmagic, Inc. All Rights Reserved.
+ #
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ #     http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing,
+ # software distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+
+ from .base import *
+
+ # flake8: noqa
+ from .compressors import *
+ from .config import *
+ from .quantization import QuantizationConfig, QuantizationStatus
+ from .utils import *
@@ -0,0 +1,17 @@
+ # Copyright (c) 2021 - present / Neuralmagic, Inc. All Rights Reserved.
+ #
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ #     http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing,
+ # software distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+
+ SPARSITY_CONFIG_NAME = "sparsity_config"
+ QUANTIZATION_CONFIG_NAME = "quantization_config"
+ COMPRESSION_CONFIG_NAME = "compression_config"
@@ -0,0 +1,22 @@
+ # Copyright (c) 2021 - present / Neuralmagic, Inc. All Rights Reserved.
+ #
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ #     http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing,
+ # software distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+
+ # flake8: noqa
+
+ from .base import Compressor
+ from .dense import DenseCompressor
+ from .helpers import load_compressed, save_compressed, save_compressed_model
+ from .int_quantized import IntQuantizationCompressor
+ from .model_compressor import ModelCompressor
+ from .sparse_bitmask import BitmaskCompressor, BitmaskTensor
@@ -0,0 +1,59 @@
+ # Copyright (c) 2021 - present / Neuralmagic, Inc. All Rights Reserved.
+ #
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ #     http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing,
+ # software distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+
+ from typing import Dict, Generator, Tuple, Union
+
+ from compressed_tensors.config import SparsityCompressionConfig
+ from compressed_tensors.quantization import QuantizationConfig
+ from compressed_tensors.registry import RegistryMixin
+ from torch import Tensor
+
+
+ __all__ = ["Compressor"]
+
+
+ class Compressor(RegistryMixin):
+     """
+     Base class representing a model compression algorithm
+
+     :param config: config specifying compression parameters
+     """
+
+     def __init__(
+         self, config: Union[SparsityCompressionConfig, QuantizationConfig, None] = None
+     ):
+         self.config = config
+
+     def compress(self, model_state: Dict[str, Tensor], **kwargs) -> Dict[str, Tensor]:
+         """
+         Compresses a dense state dict
+
+         :param model_state: state dict of uncompressed model
+         :return: compressed state dict
+         """
+         raise NotImplementedError()
+
+     def decompress(
+         self, path_to_model_or_tensors: str, device: str = "cpu"
+     ) -> Generator[Tuple[str, Tensor], None, None]:
+         """
+         Reads a compressed state dict located at path_to_model_or_tensors
+         and returns a generator for sequentially decompressing back to a
+         dense state dict
+
+         :param path_to_model_or_tensors: path to a compressed safetensors model
+             (directory with one or more safetensors files) or a compressed
+             tensors file
+         :param device: device to load the decompressed tensors onto
+         :return: generator yielding (parameter name, dense tensor) pairs
+         """
+         raise NotImplementedError()
@@ -0,0 +1,34 @@
+ # Copyright (c) 2021 - present / Neuralmagic, Inc. All Rights Reserved.
+ #
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ #     http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing,
+ # software distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+
+ from typing import Dict, Generator, Tuple
+
+ from compressed_tensors.compressors import Compressor
+ from compressed_tensors.config import CompressionFormat
+ from torch import Tensor
+
+
+ @Compressor.register(name=CompressionFormat.dense.value)
+ class DenseCompressor(Compressor):
+     """
+     Identity compressor for dense models, returns the original state_dict
+     """
+
+     def compress(self, model_state: Dict[str, Tensor], **kwargs) -> Dict[str, Tensor]:
+         return model_state
+
+     def decompress(
+         self, path_to_model_or_tensors: str, device: str = "cpu"
+     ) -> Generator[Tuple[str, Tensor], None, None]:
+         return iter([])
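
DenseCompressor also doubles as the template for adding a format: subclass Compressor, register it under a format name, and the registry constructs it from that string. A minimal sketch of the lookup side, using the `dense` format registered above (the `tensors` dict is illustrative):

    import torch
    from compressed_tensors.compressors import Compressor
    from compressed_tensors.config import CompressionFormat

    # the registry maps a format string onto its registered Compressor subclass
    compressor = Compressor.load_from_registry(CompressionFormat.dense.value)

    # DenseCompressor is an identity transform: the state dict passes through unchanged
    tensors = {"linear.weight": torch.randn(4, 4)}
    assert compressor.compress(tensors) is tensors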
@@ -0,0 +1,137 @@
+ # Copyright (c) 2021 - present / Neuralmagic, Inc. All Rights Reserved.
+ #
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ #     http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing,
+ # software distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+
+ from pathlib import Path
+ from typing import Dict, Generator, Optional, Tuple, Union
+
+ import torch
+ from compressed_tensors.compressors import Compressor
+ from compressed_tensors.config import CompressionFormat, SparsityCompressionConfig
+ from compressed_tensors.utils.safetensors_load import get_weight_mappings
+ from safetensors import safe_open
+ from safetensors.torch import save_file
+ from torch import Tensor
+
+
+ __all__ = [
+     "load_compressed",
+     "save_compressed",
+     "save_compressed_model",
+ ]
+
+
+ def save_compressed(
+     tensors: Dict[str, Tensor],
+     save_path: Union[str, Path],
+     compression_format: Optional[CompressionFormat] = None,
+ ):
+     """
+     Save compressed tensors to disk. If tensors are not compressed,
+     save them as is.
+
+     :param tensors: dictionary of tensors to compress
+     :param save_path: path to save compressed tensors
+     :param compression_format: compression format used for the tensors;
+         defaults to `dense` (no compression)
+     """
+     if tensors is None or len(tensors) == 0:
+         raise ValueError("No tensors or empty tensors provided to compress")
+
+     # if no compression_format specified, default to `dense`
+     compression_format = compression_format or CompressionFormat.dense.value
+     if isinstance(compression_format, CompressionFormat):
+         # accept enum members as well as their string values
+         compression_format = compression_format.value
+
+     if not (
+         compression_format in Compressor.registered_names()
+         or compression_format in Compressor.registered_aliases()
+     ):
+         raise ValueError(
+             f"Unknown compression format: {compression_format}. "
+             f"Must be one of {set(Compressor.registered_names() + Compressor.registered_aliases())}"  # noqa: E501
+         )
+
+     # compress, then save the compressed tensors
+     compressor = Compressor.load_from_registry(compression_format)
+     compressed_tensors = compressor.compress(tensors)
+     save_file(compressed_tensors, save_path)
+
+
+ def load_compressed(
+     compressed_tensors: Union[str, Path],
+     compression_config: Optional[SparsityCompressionConfig] = None,
+     device: Optional[str] = "cpu",
+ ) -> Generator[Tuple[str, Tensor], None, None]:
+     """
+     Load compressed tensors from disk.
+     If tensors are not compressed, load them as is.
+
+     :param compressed_tensors: path to compressed tensors.
+         This can be a path to a file or a directory containing
+         one or multiple safetensors files (if multiple, in the format
+         assumed by huggingface)
+     :param compression_config: compression config to use for decompressing tensors
+     :param device: device to move tensors to; defaults to "cpu"
+     :return: a generator that yields the name and tensor of each decompressed tensor
+     """
+     if compressed_tensors is None or not Path(compressed_tensors).exists():
+         raise ValueError("No compressed tensors provided to load")
+
+     if (
+         compression_config is None
+         or compression_config.format == CompressionFormat.dense.value
+     ):
+         # if no compression_config specified, or `dense` format specified,
+         # assume tensors are not compressed on disk
+         weight_mappings = get_weight_mappings(compressed_tensors)
+         for weight_name, file_with_weight_name in weight_mappings.items():
+             with safe_open(file_with_weight_name, framework="pt", device=device) as f:
+                 weight = f.get_tensor(weight_name)
+             yield weight_name, weight
+     else:
+         # decompress tensors
+         compression_format = compression_config.format
+         compressor = Compressor.load_from_registry(
+             compression_format, config=compression_config
+         )
+         yield from compressor.decompress(compressed_tensors, device=device)
+
+
+ def save_compressed_model(
+     model: torch.nn.Module,
+     filename: str,
+     compression_format: Optional[CompressionFormat] = None,
+     force_contiguous: bool = True,
+ ):
+     """
+     Wrapper around the safetensors `save_model` helper function, which allows for
+     saving a compressed model to disk.
+
+     Note: The model is assumed to have a state_dict with unique entries
+
+     :param model: model to save on disk
+     :param filename: filename location to save the file
+     :param compression_format: compression format used for the model
+     :param force_contiguous: force the state_dict to be saved as contiguous tensors
+     """
+     state_dict = model.state_dict()
+     if force_contiguous:
+         state_dict = {k: v.contiguous() for k, v in state_dict.items()}
+     try:
+         save_compressed(state_dict, filename, compression_format=compression_format)
+     except ValueError as e:
+         msg = str(e)
+         msg += " Or use save_compressed_model(..., force_contiguous=True), read the docs for potential caveats."  # noqa: E501
+         raise ValueError(msg)
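
Taken together, the two helpers give a simple save/load round trip. A minimal sketch of the default (dense, i.e. uncompressed) path; the tensor contents and the `model.safetensors` filename are illustrative:

    import torch
    from compressed_tensors.compressors import load_compressed, save_compressed

    tensors = {"embedding.weight": torch.randn(16, 8)}

    # with no compression_format given, save_compressed falls back to `dense`
    # and writes the tensors uncompressed via safetensors
    save_compressed(tensors, "model.safetensors")

    # with no compression_config given, load_compressed reads the file as-is,
    # yielding (name, tensor) pairs
    for name, tensor in load_compressed("model.safetensors"):
        print(name, tuple(tensor.shape))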
@@ -0,0 +1,95 @@
+ # Copyright (c) 2021 - present / Neuralmagic, Inc. All Rights Reserved.
+ #
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ #     http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing,
+ # software distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+
+ import logging
+ from typing import Dict, Generator, Tuple
+
+ import torch
+ from compressed_tensors.compressors import Compressor
+ from compressed_tensors.config import CompressionFormat
+ from compressed_tensors.quantization.lifecycle.forward import dequantize, quantize
+ from compressed_tensors.utils import get_nested_weight_mappings, merge_names
+ from safetensors import safe_open
+ from torch import Tensor
+ from tqdm import tqdm
+
+
+ __all__ = ["IntQuantizationCompressor"]
+
+ _LOGGER: logging.Logger = logging.getLogger(__name__)
+
+
+ @Compressor.register(name=CompressionFormat.int_quantized.value)
+ class IntQuantizationCompressor(Compressor):
+     """
+     Integer compression for quantized models. The weight of each quantized layer
+     is converted from its original float type to the format specified by the
+     layer's quantization scheme.
+     """
+
+     COMPRESSION_PARAM_NAMES = ["weight", "weight_scale", "weight_zero_point"]
+
+     def compress(self, model_state: Dict[str, Tensor], **kwargs) -> Dict[str, Tensor]:
+         model_quant_args = kwargs["model_quant_args"]
+         compressed_dict = {}
+         _LOGGER.debug(
+             f"Compressing model with {len(model_state)} parameterized layers..."
+         )
+
+         for name, value in tqdm(model_state.items(), desc="Compressing model"):
+             if name.endswith(".weight"):
+                 prefix = name.removesuffix(".weight")
+                 scale = model_state.get(merge_names(prefix, "weight_scale"), None)
+                 zp = model_state.get(merge_names(prefix, "weight_zero_point"), None)
+                 if scale is not None and zp is not None:
+                     # weight is quantized, compress it
+                     quant_args = model_quant_args[prefix]
+                     try:
+                         bit_depth = torch.finfo(value.dtype).bits
+                     except TypeError:
+                         bit_depth = torch.iinfo(value.dtype).bits
+                     if bit_depth > quant_args.num_bits:
+                         # only quantize if not already quantized
+                         value = quantize(
+                             x=value,
+                             scale=scale,
+                             zero_point=zp,
+                             args=quant_args,
+                             dtype=torch.int8,
+                         )
+
+             compressed_dict[name] = value.to("cpu")
+
+         return compressed_dict
+
+     def decompress(
+         self, path_to_model_or_tensors: str, device: str = "cpu"
+     ) -> Generator[Tuple[str, Tensor], None, None]:
+         weight_mappings = get_nested_weight_mappings(
+             path_to_model_or_tensors, self.COMPRESSION_PARAM_NAMES
+         )
+         for weight_name in weight_mappings.keys():
+             weight_data = {}
+             for param_name, safe_path in weight_mappings[weight_name].items():
+                 full_name = merge_names(weight_name, param_name)
+                 with safe_open(safe_path, framework="pt", device=device) as f:
+                     weight_data[param_name] = f.get_tensor(full_name)
+
+             if len(weight_data) == len(self.COMPRESSION_PARAM_NAMES):
+                 decompressed = dequantize(
+                     x_q=weight_data["weight"],
+                     scale=weight_data["weight_scale"],
+                     zero_point=weight_data["weight_zero_point"],
+                 )
+                 yield merge_names(weight_name, "weight"), decompressed
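
For orientation: the compress path expects the state dict to already hold a `weight_scale` and `weight_zero_point` beside each quantized weight, plus a `model_quant_args` mapping from layer prefix to its quantization arguments. A minimal sketch; the layer name, the per-tensor scale/zero-point values, and the `QuantizationArgs(num_bits=8)` construction are illustrative assumptions about `compressed_tensors.quantization.quant_args`:

    import torch
    from compressed_tensors.compressors import Compressor
    from compressed_tensors.config import CompressionFormat
    from compressed_tensors.quantization import QuantizationArgs  # assumed export

    state_dict = {
        "layer.weight": torch.randn(32, 32),
        "layer.weight_scale": torch.tensor(0.02),
        "layer.weight_zero_point": torch.tensor(0, dtype=torch.int8),
    }

    compressor = Compressor.load_from_registry(CompressionFormat.int_quantized.value)
    compressed = compressor.compress(
        state_dict, model_quant_args={"layer": QuantizationArgs(num_bits=8)}
    )

    # the fp32 weight is stored as int8; scale and zero point are carried alongside
    assert compressed["layer.weight"].dtype == torch.int8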
@@ -0,0 +1,264 @@
+ # Copyright (c) 2021 - present / Neuralmagic, Inc. All Rights Reserved.
+ #
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ #     http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing,
+ # software distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+
+ import json
+ import logging
+ import operator
+ import os
+ from typing import Dict, Optional, Union
+
+ from compressed_tensors.base import (
+     COMPRESSION_CONFIG_NAME,
+     QUANTIZATION_CONFIG_NAME,
+     SPARSITY_CONFIG_NAME,
+ )
+ from compressed_tensors.compressors import Compressor
+ from compressed_tensors.config import SparsityCompressionConfig
+ from compressed_tensors.quantization import (
+     QuantizationConfig,
+     QuantizationStatus,
+     apply_quantization_config,
+     load_pretrained_quantization,
+ )
+ from compressed_tensors.quantization.utils import (
+     is_module_quantized,
+     iter_named_leaf_modules,
+ )
+ from compressed_tensors.utils import get_safetensors_folder
+ from torch import Tensor
+ from torch.nn import Module, Parameter
+ from tqdm import tqdm
+ from transformers import AutoConfig
+ from transformers.file_utils import CONFIG_NAME
+
+
+ __all__ = ["ModelCompressor"]
+
+ _LOGGER: logging.Logger = logging.getLogger(__name__)
+
+
+ class ModelCompressor:
+     """
+     Handles compression and decompression of a model with a sparsity config and/or
+     quantization config.
+
+     Compression lifecycle:
+         - compressor = ModelCompressor.from_pretrained_model(model)
+         - compressed_state_dict = compressor.compress(model, state_dict)
+             - compressor.quantization_compressor.compress(model, state_dict)
+             - compressor.sparsity_compressor.compress(model, state_dict)
+         - model.save_pretrained(output_dir, state_dict=compressed_state_dict)
+         - compressor.update_config(output_dir)
+
+     Decompression lifecycle:
+         - compressor = ModelCompressor.from_pretrained(comp_model_path)
+         - model = AutoModel.from_pretrained(comp_model_path)
+         - compressor.decompress(comp_model_path, model)
+             - compressor.sparsity_compressor.decompress(comp_model_path, model)
+             - compressor.quantization_compressor.decompress(comp_model_path, model)
+
+     :param sparsity_config: config specifying sparsity compression parameters
+     :param quantization_config: config specifying quantization compression parameters
+     """
+
+     @classmethod
+     def from_pretrained(
+         cls,
+         pretrained_model_name_or_path: str,
+     ) -> Optional["ModelCompressor"]:
+         """
+         Given a path to a model config, extract the sparsity and/or quantization
+         configs and load a ModelCompressor
+
+         :param pretrained_model_name_or_path: path to model config on disk or HF hub
+         :return: compressor for the extracted configs
+         """
+         config = AutoConfig.from_pretrained(pretrained_model_name_or_path)
+         compression_config = getattr(config, COMPRESSION_CONFIG_NAME, None)
+         if compression_config is None:
+             return None
+
+         sparsity_config = compression_config.get(SPARSITY_CONFIG_NAME, None)
+         quantization_config = compression_config.get(QUANTIZATION_CONFIG_NAME, None)
+
+         if sparsity_config is None and quantization_config is None:
+             return None
+
+         if sparsity_config is not None:
+             format = sparsity_config.get("format")
+             sparsity_config = SparsityCompressionConfig.load_from_registry(
+                 format, **sparsity_config
+             )
+         if quantization_config is not None:
+             quantization_config = QuantizationConfig.model_validate(
+                 quantization_config
+             )
+
+         return cls(
+             sparsity_config=sparsity_config, quantization_config=quantization_config
+         )
+
+     @classmethod
+     def from_pretrained_model(
+         cls,
+         model: Module,
+         sparsity_config: Union[SparsityCompressionConfig, str, None] = None,
+         quantization_format: Optional[str] = None,
+     ) -> Optional["ModelCompressor"]:
+         """
+         Given a pytorch model and optional sparsity and/or quantization configs,
+         load the appropriate compressors
+
+         :param model: pytorch model to target for compression
+         :param sparsity_config: a filled-in sparsity config, or a string naming a
+             sparsity compression algorithm
+         :param quantization_format: string corresponding to a quantization
+             compression algorithm
+         :return: compressor for the extracted configs
+         """
+         quantization_config = QuantizationConfig.from_pretrained(
+             model, format=quantization_format
+         )
+
+         if isinstance(sparsity_config, str):  # we passed in a sparsity format
+             sparsity_config = SparsityCompressionConfig.load_from_registry(
+                 sparsity_config
+             )
+
+         if sparsity_config is None and quantization_config is None:
+             return None
+
+         return cls(
+             sparsity_config=sparsity_config, quantization_config=quantization_config
+         )
+
+     def __init__(
+         self,
+         sparsity_config: Optional[SparsityCompressionConfig] = None,
+         quantization_config: Optional[QuantizationConfig] = None,
+     ):
+         self.sparsity_config = sparsity_config
+         self.quantization_config = quantization_config
+         self.sparsity_compressor = None
+         self.quantization_compressor = None
+
+         if sparsity_config is not None:
+             self.sparsity_compressor = Compressor.load_from_registry(
+                 sparsity_config.format, config=sparsity_config
+             )
+         if quantization_config is not None:
+             self.quantization_compressor = Compressor.load_from_registry(
+                 quantization_config.format, config=quantization_config
+             )
+
+     def compress(
+         self, model: Module, state_dict: Optional[Dict[str, Tensor]] = None
+     ) -> Dict[str, Tensor]:
+         """
+         Compresses a dense state dict or model with sparsity and/or quantization
+
+         :param model: uncompressed model to compress
+         :param state_dict: optional uncompressed state_dict to compress; if not
+             provided, it is pulled from the model
+         :return: compressed state dict
+         """
+         if state_dict is None:
+             state_dict = model.state_dict()
+
+         compressed_state_dict = state_dict
+         quantized_modules_to_args = _get_weight_arg_mappings(model)
+         if self.quantization_compressor is not None:
+             compressed_state_dict = self.quantization_compressor.compress(
+                 state_dict, model_quant_args=quantized_modules_to_args
+             )
+
+         if self.sparsity_compressor is not None:
+             compressed_state_dict = self.sparsity_compressor.compress(
+                 compressed_state_dict
+             )
+
+         return compressed_state_dict
+
+     def decompress(self, model_path: str, model: Module):
+         """
+         Overwrites the weights in model with weights decompressed from model_path
+
+         :param model_path: path to compressed weights
+         :param model: pytorch model to load decompressed weights into
+         """
+         model_path = get_safetensors_folder(model_path)
+         if self.sparsity_compressor is not None:
+             dense_gen = self.sparsity_compressor.decompress(model_path)
+             self._replace_weights(dense_gen, model)
+             setattr(model, SPARSITY_CONFIG_NAME, self.sparsity_compressor.config)
+
+         if self.quantization_compressor is not None:
+             apply_quantization_config(model, self.quantization_config)
+             load_pretrained_quantization(model, model_path)
+             dense_gen = self.quantization_compressor.decompress(model_path)
+             self._replace_weights(dense_gen, model)
+
+             def update_status(module):
+                 module.quantization_status = QuantizationStatus.FROZEN
+
+             model.apply(update_status)
+             setattr(model, QUANTIZATION_CONFIG_NAME, self.quantization_config)
+
+     def update_config(self, save_directory: str):
+         """
+         Update the model config located at save_directory with compression configs
+         for sparsity and/or quantization
+
+         :param save_directory: path to a folder containing a HF model config
+         """
+         config_file_path = os.path.join(save_directory, CONFIG_NAME)
+         if not os.path.exists(config_file_path):
+             _LOGGER.warning(
+                 f"Could not find a valid model config file in "
+                 f"{save_directory}. Compression config will not be saved."
+             )
+             return
+
+         with open(config_file_path, "r") as config_file:
+             config_data = json.load(config_file)
+
+         config_data[COMPRESSION_CONFIG_NAME] = {}
+         if self.quantization_config is not None:
+             quant_config_data = self.quantization_config.model_dump()
+             config_data[COMPRESSION_CONFIG_NAME][
+                 QUANTIZATION_CONFIG_NAME
+             ] = quant_config_data
+         if self.sparsity_config is not None:
+             sparsity_config_data = self.sparsity_config.model_dump()
+             config_data[COMPRESSION_CONFIG_NAME][
+                 SPARSITY_CONFIG_NAME
+             ] = sparsity_config_data
+
+         with open(config_file_path, "w") as config_file:
+             json.dump(config_data, config_file, indent=2, sort_keys=True)
+
+     def _replace_weights(self, dense_weight_generator, model):
+         for name, data in tqdm(dense_weight_generator, desc="Decompressing model"):
+             # load the decompressed weight into the model, keeping its device
+             data_old = operator.attrgetter(name)(model)
+             data_new = Parameter(data.to(data_old.device))
+             data_old.data = data_new.data
+
+
+ def _get_weight_arg_mappings(model: Module) -> Dict:
+     quantized_modules_to_args = {}
+     for name, submodule in iter_named_leaf_modules(model):
+         if is_module_quantized(submodule):
+             if submodule.quantization_scheme.weights is not None:
+                 quantized_modules_to_args[name] = submodule.quantization_scheme.weights
+
+     return quantized_modules_to_args
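
The lifecycle described in the class docstring maps onto the following end-to-end sketch. A minimal example, assuming a transformers model that already carries quantization and/or sparsity metadata; the paths are placeholders:

    from transformers import AutoModel

    from compressed_tensors.compressors import ModelCompressor

    # compression: pack the state dict, save it, then record the configs
    model = AutoModel.from_pretrained("path/to/dense-model")
    compressor = ModelCompressor.from_pretrained_model(model)
    if compressor is not None:  # None means there is nothing to compress
        compressed_state_dict = compressor.compress(model)
        model.save_pretrained("path/to/compressed-model", state_dict=compressed_state_dict)
        compressor.update_config("path/to/compressed-model")

    # decompression: rebuild dense weights in place from the saved artifacts
    compressor = ModelCompressor.from_pretrained("path/to/compressed-model")
    model = AutoModel.from_pretrained("path/to/compressed-model")
    if compressor is not None:
        compressor.decompress("path/to/compressed-model", model)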