compressed-tensors 0.5.0__py3-none-any.whl → 0.7.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (51) hide show
  1. compressed_tensors/__init__.py +1 -0
  2. compressed_tensors/base.py +2 -0
  3. compressed_tensors/compressors/__init__.py +6 -12
  4. compressed_tensors/compressors/base.py +137 -9
  5. compressed_tensors/compressors/helpers.py +6 -6
  6. compressed_tensors/compressors/model_compressors/__init__.py +17 -0
  7. compressed_tensors/compressors/{model_compressor.py → model_compressors/model_compressor.py} +99 -43
  8. compressed_tensors/compressors/quantized_compressors/__init__.py +18 -0
  9. compressed_tensors/compressors/{naive_quantized.py → quantized_compressors/base.py} +64 -62
  10. compressed_tensors/compressors/quantized_compressors/naive_quantized.py +140 -0
  11. compressed_tensors/compressors/quantized_compressors/pack_quantized.py +211 -0
  12. compressed_tensors/compressors/sparse_compressors/__init__.py +18 -0
  13. compressed_tensors/compressors/sparse_compressors/base.py +110 -0
  14. compressed_tensors/compressors/{dense.py → sparse_compressors/dense.py} +3 -3
  15. compressed_tensors/compressors/{sparse_bitmask.py → sparse_compressors/sparse_bitmask.py} +14 -59
  16. compressed_tensors/compressors/sparse_quantized_compressors/__init__.py +16 -0
  17. compressed_tensors/compressors/{marlin_24.py → sparse_quantized_compressors/marlin_24.py} +3 -3
  18. compressed_tensors/config/base.py +6 -1
  19. compressed_tensors/linear/__init__.py +13 -0
  20. compressed_tensors/linear/compressed_linear.py +87 -0
  21. compressed_tensors/quantization/__init__.py +1 -0
  22. compressed_tensors/quantization/cache.py +201 -0
  23. compressed_tensors/quantization/lifecycle/apply.py +63 -9
  24. compressed_tensors/quantization/lifecycle/calibration.py +7 -7
  25. compressed_tensors/quantization/lifecycle/compressed.py +3 -1
  26. compressed_tensors/quantization/lifecycle/forward.py +126 -44
  27. compressed_tensors/quantization/lifecycle/frozen.py +6 -1
  28. compressed_tensors/quantization/lifecycle/helpers.py +0 -20
  29. compressed_tensors/quantization/lifecycle/initialize.py +138 -55
  30. compressed_tensors/quantization/observers/__init__.py +1 -0
  31. compressed_tensors/quantization/observers/base.py +54 -14
  32. compressed_tensors/quantization/observers/min_max.py +8 -0
  33. compressed_tensors/quantization/observers/mse.py +162 -0
  34. compressed_tensors/quantization/quant_args.py +102 -24
  35. compressed_tensors/quantization/quant_config.py +14 -2
  36. compressed_tensors/quantization/quant_scheme.py +12 -13
  37. compressed_tensors/quantization/utils/helpers.py +44 -19
  38. compressed_tensors/utils/__init__.py +1 -0
  39. compressed_tensors/utils/helpers.py +30 -1
  40. compressed_tensors/utils/offload.py +14 -2
  41. compressed_tensors/utils/permute.py +70 -0
  42. compressed_tensors/utils/safetensors_load.py +2 -0
  43. compressed_tensors/utils/semi_structured_conversions.py +1 -0
  44. compressed_tensors/version.py +1 -1
  45. {compressed_tensors-0.5.0.dist-info → compressed_tensors-0.7.0.dist-info}/METADATA +35 -23
  46. compressed_tensors-0.7.0.dist-info/RECORD +59 -0
  47. {compressed_tensors-0.5.0.dist-info → compressed_tensors-0.7.0.dist-info}/WHEEL +1 -1
  48. compressed_tensors/compressors/pack_quantized.py +0 -219
  49. compressed_tensors-0.5.0.dist-info/RECORD +0 -48
  50. {compressed_tensors-0.5.0.dist-info → compressed_tensors-0.7.0.dist-info}/LICENSE +0 -0
  51. {compressed_tensors-0.5.0.dist-info → compressed_tensors-0.7.0.dist-info}/top_level.txt +0 -0
@@ -19,3 +19,4 @@ from .compressors import *
19
19
  from .config import *
20
20
  from .quantization import QuantizationConfig, QuantizationStatus
21
21
  from .utils import *
22
+ from .version import *
@@ -16,3 +16,5 @@ SPARSITY_CONFIG_NAME = "sparsity_config"
16
16
  QUANTIZATION_CONFIG_NAME = "quantization_config"
17
17
  COMPRESSION_CONFIG_NAME = "compression_config"
18
18
  KV_CACHE_SCHEME_NAME = "kv_cache_scheme"
19
+ COMPRESSION_VERSION_NAME = "version"
20
+ QUANTIZATION_METHOD_NAME = "quant_method"
@@ -14,15 +14,9 @@
14
14
 
15
15
  # flake8: noqa
16
16
 
17
- from .base import Compressor
18
- from .dense import DenseCompressor
19
- from .helpers import load_compressed, save_compressed, save_compressed_model
20
- from .marlin_24 import Marlin24Compressor
21
- from .model_compressor import ModelCompressor, map_modules_to_quant_args
22
- from .naive_quantized import (
23
- FloatQuantizationCompressor,
24
- IntQuantizationCompressor,
25
- QuantizationCompressor,
26
- )
27
- from .pack_quantized import PackedQuantizationCompressor
28
- from .sparse_bitmask import BitmaskCompressor, BitmaskTensor
17
+ from .base import *
18
+ from .helpers import *
19
+ from .model_compressors import *
20
+ from .quantized_compressors import *
21
+ from .sparse_compressors import *
22
+ from .sparse_quantized_compressors import *
@@ -12,20 +12,47 @@
12
12
  # See the License for the specific language governing permissions and
13
13
  # limitations under the License.
14
14
 
15
- from typing import Dict, Generator, Tuple, Union
15
+ from abc import ABC, abstractmethod
16
+ from typing import Dict, Generator, Optional, Tuple, Union
16
17
 
18
+ import torch
17
19
  from compressed_tensors.config import SparsityCompressionConfig
18
- from compressed_tensors.quantization import QuantizationConfig
20
+ from compressed_tensors.quantization import QuantizationArgs, QuantizationConfig
19
21
  from compressed_tensors.registry import RegistryMixin
20
22
  from torch import Tensor
23
+ from torch.nn import Module
21
24
 
22
25
 
23
- __all__ = ["Compressor"]
26
+ __all__ = ["BaseCompressor"]
24
27
 
25
28
 
26
- class Compressor(RegistryMixin):
29
+ class BaseCompressor(RegistryMixin, ABC):
27
30
  """
28
- Base class representing a model compression algorithm
31
+ Base class representing a model compression algorithm. Each child class should
32
+ implement compression_param_info, compress_weight and decompress_weight.
33
+
34
+ Compressors support compressing/decompressing a full module state dict or a single
35
+ quantized PyTorch leaf module.
36
+
37
+ Model Load Lifecycle (run_compressed=False):
38
+ - ModelCompressor.decompress()
39
+ - apply_quantization_config()
40
+ - BaseCompressor.decompress()
41
+
42
+ Model Save Lifecycle:
43
+ - ModelCompressor.compress()
44
+ - BaseCompressor.compress()
45
+
46
+
47
+ Module Lifecycle (run_compressed=True):
48
+ - apply_quantization_config()
49
+ - compressed_module = CompressedLinear(module)
50
+ - initialize_module_for_quantization()
51
+ - BaseCompressor.compression_param_info()
52
+ - register_parameters()
53
+ - compressed_module.forward()
54
+ -compressed_module.decompress()
55
+
29
56
 
30
57
  :param config: config specifying compression parameters
31
58
  """
@@ -35,26 +62,127 @@ class Compressor(RegistryMixin):
35
62
  ):
36
63
  self.config = config
37
64
 
38
- def compress(self, model_state: Dict[str, Tensor], **kwargs) -> Dict[str, Tensor]:
65
+ def compression_param_info(
66
+ self,
67
+ weight_shape: torch.Size,
68
+ quantization_args: Optional[QuantizationArgs] = None,
69
+ ) -> Dict[str, Tuple[torch.Size, torch.dtype]]:
70
+ """
71
+ Creates a dictionary of expected shapes and dtypes for each compression
72
+ parameter used by the compressor
73
+
74
+ :param weight_shape: uncompressed weight shape
75
+ :param quantization_args: quantization parameters for the weight
76
+ :return: dictionary mapping compressed parameter names to shape and dtype
77
+ """
78
+ raise NotImplementedError()
79
+
80
+ @abstractmethod
81
+ def compress(
82
+ self,
83
+ model_state: Dict[str, Tensor],
84
+ **kwargs,
85
+ ) -> Dict[str, Tensor]:
39
86
  """
40
87
  Compresses a dense state dict
41
88
 
42
89
  :param model_state: state dict of uncompressed model
90
+ :param kwargs: additional arguments for compression
43
91
  :return: compressed state dict
44
92
  """
45
93
  raise NotImplementedError()
46
94
 
95
+ @abstractmethod
47
96
  def decompress(
48
- self, path_to_model_or_tensors: str, device: str = "cpu", **kwargs
97
+ self,
98
+ path_to_model_or_tensors: str,
99
+ device: str = "cpu",
100
+ **kwargs,
49
101
  ) -> Generator[Tuple[str, Tensor], None, None]:
50
102
  """
51
103
  Reads a compressed state dict located at path_to_model_or_tensors
52
104
  and returns a generator for sequentially decompressing back to a
53
105
  dense state dict
54
106
 
55
- :param model_path: path to compressed safetensors model (directory with
56
- one or more safetensors files) or compressed tensors file
107
+ :param path_to_model_or_tensors: path to compressed safetensors model (directory
108
+ with one or more safetensors files) or compressed tensors file
109
+ :param names_to_scheme: quantization args for each quantized weight
57
110
  :param device: optional device to load intermediate weights into
58
111
  :return: compressed state dict
59
112
  """
60
113
  raise NotImplementedError()
114
+
115
+ def compress_module(self, module: Module) -> Optional[Dict[str, torch.Tensor]]:
116
+ """
117
+ Compresses a single quantized leaf PyTorch module. If the module is not
118
+ quantized, this function has no effect.
119
+
120
+ :param module: PyTorch module to compress
121
+ :return: dictionary of compressed weight data, or None if module is not
122
+ quantized
123
+ """
124
+ if not hasattr(module, "quantization_scheme"):
125
+ return None # module is not quantized
126
+ quantization_scheme = module.quantization_scheme
127
+ if not hasattr(quantization_scheme, "weights"):
128
+ return None # weights are not quantized
129
+
130
+ quantization_args = quantization_scheme.weights
131
+ weight = getattr(module, "weight", None)
132
+ weight_scale = getattr(module, "weight_scale", None)
133
+ weight_zero_point = getattr(module, "weight_zero_point", None)
134
+
135
+ return self.compress_weight(
136
+ weight=weight,
137
+ scale=weight_scale,
138
+ zero_point=weight_zero_point,
139
+ quantization_args=quantization_args,
140
+ )
141
+
142
+ def compress_weight(
143
+ self,
144
+ weight: Tensor,
145
+ **kwargs,
146
+ ) -> Dict[str, torch.Tensor]:
147
+ """
148
+ Compresses a single uncompressed weight
149
+
150
+ :param weight: uncompressed weight tensor
151
+ :param kwargs: additional arguments for compression
152
+ """
153
+ raise NotImplementedError()
154
+
155
+ def decompress_module(self, module: Module):
156
+ """
157
+ Decompresses a single compressed leaf PyTorch module. If the module is not
158
+ quantized, this function has no effect.
159
+
160
+ :param module: PyTorch module to decompress
161
+ :return: tensor of the decompressed weight, or None if module is not quantized
162
+ """
163
+ if not hasattr(module, "quantization_scheme"):
164
+ return None # module is not quantized
165
+ quantization_scheme = module.quantization_scheme
166
+ if not hasattr(quantization_scheme, "weights"):
167
+ return None # weights are not quantized
168
+
169
+ quantization_args = quantization_scheme.weights
170
+ compressed_data = {}
171
+ for name, parameter in module.named_parameters():
172
+ compressed_data[name] = parameter
173
+
174
+ return self.decompress_weight(
175
+ compressed_data=compressed_data, quantization_args=quantization_args
176
+ )
177
+
178
+ def decompress_weight(
179
+ self, compressed_data: Dict[str, Tensor], **kwargs
180
+ ) -> torch.Tensor:
181
+ """
182
+ Decompresses a single compressed weight
183
+
184
+ :param compressed_data: dictionary of data needed for decompression
185
+ :param kwargs: additional arguments for decompression
186
+ :return: tensor of the decompressed weight
187
+ """
188
+ raise NotImplementedError()
@@ -16,7 +16,7 @@ from pathlib import Path
16
16
  from typing import Dict, Generator, Optional, Tuple, Union
17
17
 
18
18
  import torch
19
- from compressed_tensors.compressors import Compressor
19
+ from compressed_tensors.compressors import BaseCompressor
20
20
  from compressed_tensors.config import CompressionFormat, SparsityCompressionConfig
21
21
  from compressed_tensors.utils.safetensors_load import get_weight_mappings
22
22
  from safetensors import safe_open
@@ -52,16 +52,16 @@ def save_compressed(
52
52
  compression_format = compression_format or CompressionFormat.dense.value
53
53
 
54
54
  if not (
55
- compression_format in Compressor.registered_names()
56
- or compression_format in Compressor.registered_aliases()
55
+ compression_format in BaseCompressor.registered_names()
56
+ or compression_format in BaseCompressor.registered_aliases()
57
57
  ):
58
58
  raise ValueError(
59
59
  f"Unknown compression format: {compression_format}. "
60
- f"Must be one of {set(Compressor.registered_names() + Compressor.registered_aliases())}" # noqa E501
60
+ f"Must be one of {set(BaseCompressor.registered_names() + BaseCompressor.registered_aliases())}" # noqa E501
61
61
  )
62
62
 
63
63
  # compress
64
- compressor = Compressor.load_from_registry(compression_format)
64
+ compressor = BaseCompressor.load_from_registry(compression_format)
65
65
  # save compressed tensors
66
66
  compressed_tensors = compressor.compress(tensors)
67
67
  save_file(compressed_tensors, save_path)
@@ -102,7 +102,7 @@ def load_compressed(
102
102
  else:
103
103
  # decompress tensors
104
104
  compression_format = compression_config.format
105
- compressor = Compressor.load_from_registry(
105
+ compressor = BaseCompressor.load_from_registry(
106
106
  compression_format, config=compression_config
107
107
  )
108
108
  yield from compressor.decompress(compressed_tensors, device=device)
@@ -0,0 +1,17 @@
1
+ # Copyright (c) 2021 - present / Neuralmagic, Inc. All Rights Reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing,
10
+ # software distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ # flake8: noqa
15
+
16
+
17
+ from .model_compressor import *
@@ -18,18 +18,22 @@ import operator
18
18
  import os
19
19
  import re
20
20
  from copy import deepcopy
21
- from typing import Any, Dict, Optional, Union
21
+ from typing import TYPE_CHECKING, Any, Dict, Optional, TypeVar, Union
22
22
 
23
+ import compressed_tensors
23
24
  import torch
24
25
  import transformers
25
26
  from compressed_tensors.base import (
26
27
  COMPRESSION_CONFIG_NAME,
28
+ COMPRESSION_VERSION_NAME,
27
29
  QUANTIZATION_CONFIG_NAME,
30
+ QUANTIZATION_METHOD_NAME,
28
31
  SPARSITY_CONFIG_NAME,
29
32
  )
30
- from compressed_tensors.compressors import Compressor
31
- from compressed_tensors.config import SparsityCompressionConfig
33
+ from compressed_tensors.compressors.base import BaseCompressor
34
+ from compressed_tensors.config import CompressionFormat, SparsityCompressionConfig
32
35
  from compressed_tensors.quantization import (
36
+ DEFAULT_QUANTIZATION_METHOD,
33
37
  QuantizationConfig,
34
38
  QuantizationStatus,
35
39
  apply_quantization_config,
@@ -40,7 +44,10 @@ from compressed_tensors.quantization.utils import (
40
44
  iter_named_leaf_modules,
41
45
  )
42
46
  from compressed_tensors.utils import get_safetensors_folder, update_parameter_data
43
- from compressed_tensors.utils.helpers import fix_fsdp_module_name
47
+ from compressed_tensors.utils.helpers import (
48
+ fix_fsdp_module_name,
49
+ is_compressed_tensors_config,
50
+ )
44
51
  from torch import Tensor
45
52
  from torch.nn import Module
46
53
  from tqdm import tqdm
@@ -53,6 +60,11 @@ __all__ = ["ModelCompressor", "map_modules_to_quant_args"]
53
60
  _LOGGER: logging.Logger = logging.getLogger(__name__)
54
61
 
55
62
 
63
+ if TYPE_CHECKING:
64
+ # dummy type if not available from transformers
65
+ CompressedTensorsConfig = TypeVar("CompressedTensorsConfig")
66
+
67
+
56
68
  class ModelCompressor:
57
69
  """
58
70
  Handles compression and decompression of a model with a sparsity config and/or
@@ -88,45 +100,41 @@ class ModelCompressor:
88
100
  configs and load a ModelCompressor
89
101
 
90
102
  :param pretrained_model_name_or_path: path to model config on disk or HF hub
91
- :return: compressor for the extracted configs
103
+ :return: compressor for the configs, or None if model is not compressed
92
104
  """
93
105
  config = AutoConfig.from_pretrained(pretrained_model_name_or_path, **kwargs)
94
106
  compression_config = getattr(config, COMPRESSION_CONFIG_NAME, None)
95
107
  return cls.from_compression_config(compression_config)
96
108
 
97
109
  @classmethod
98
- def from_compression_config(cls, compression_config: Dict[str, Any]):
110
+ def from_compression_config(
111
+ cls, compression_config: Union[Dict[str, Any], "CompressedTensorsConfig"]
112
+ ):
99
113
  """
100
- :param compression_config: compression/quantization config dictionary
101
- found under key "quantization_config" in HF model config
102
- :return: compressor for the extracted configs
114
+ :param compression_config:
115
+ A compression or quantization config
116
+
117
+ The type is one of the following:
118
+ 1. A Dict found under either "quantization_config" or "compression_config"
119
+ keys in the config.json
120
+ 2. A CompressedTensorsConfig found under key "quantization_config" in HF
121
+ model config
122
+ :return: compressor for the configs, or None if model is not compressed
103
123
  """
104
124
  if compression_config is None:
105
125
  return None
106
126
 
107
- try:
108
- from transformers.utils.quantization_config import CompressedTensorsConfig
109
-
110
- if isinstance(compression_config, CompressedTensorsConfig):
111
- compression_config = compression_config.to_dict()
112
- except ImportError:
113
- pass
114
-
115
127
  sparsity_config = cls.parse_sparsity_config(compression_config)
116
128
  quantization_config = cls.parse_quantization_config(compression_config)
117
129
  if sparsity_config is None and quantization_config is None:
118
130
  return None
119
131
 
120
- if sparsity_config is not None and not isinstance(
121
- sparsity_config, SparsityCompressionConfig
122
- ):
132
+ if sparsity_config is not None:
123
133
  format = sparsity_config.get("format")
124
134
  sparsity_config = SparsityCompressionConfig.load_from_registry(
125
135
  format, **sparsity_config
126
136
  )
127
- if quantization_config is not None and not isinstance(
128
- quantization_config, QuantizationConfig
129
- ):
137
+ if quantization_config is not None:
130
138
  quantization_config = QuantizationConfig.parse_obj(quantization_config)
131
139
 
132
140
  return cls(
@@ -149,7 +157,7 @@ class ModelCompressor:
149
157
  to a sparsity compression algorithm
150
158
  :param quantization_format: string corresponding to a quantization compression
151
159
  algorithm
152
- :return: compressor for the extracted configs
160
+ :return: compressor for the configs, or None if model is not compressed
153
161
  """
154
162
  quantization_config = QuantizationConfig.from_pretrained(
155
163
  model, format=quantization_format
@@ -168,32 +176,60 @@ class ModelCompressor:
168
176
  )
169
177
 
170
178
  @staticmethod
171
- def parse_sparsity_config(compression_config: Dict) -> Union[Dict, None]:
179
+ def parse_sparsity_config(
180
+ compression_config: Union[Dict[str, Any], "CompressedTensorsConfig"]
181
+ ) -> Union[Dict[str, Any], None]:
182
+ """
183
+ Parse sparsity config from quantization/compression config. Sparsity
184
+ config is nested inside q/c config
185
+
186
+ :param compression_config: quantization/compression config
187
+ :return: sparsity config
188
+ """
172
189
  if compression_config is None:
173
190
  return None
174
- if SPARSITY_CONFIG_NAME not in compression_config:
175
- return None
176
- if hasattr(compression_config, SPARSITY_CONFIG_NAME):
177
- # for loaded HFQuantizer config
178
- return getattr(compression_config, SPARSITY_CONFIG_NAME)
179
191
 
180
- # SparseAutoModel format
192
+ if is_compressed_tensors_config(compression_config):
193
+ s_config = compression_config.sparsity_config
194
+ return s_config.dict() if s_config is not None else None
195
+
181
196
  return compression_config.get(SPARSITY_CONFIG_NAME, None)
182
197
 
183
198
  @staticmethod
184
- def parse_quantization_config(compression_config: Dict) -> Union[Dict, None]:
199
+ def parse_quantization_config(
200
+ compression_config: Union[Dict[str, Any], "CompressedTensorsConfig"]
201
+ ) -> Union[Dict[str, Any], None]:
202
+ """
203
+ Parse quantization config from quantization/compression config. The
204
+ quantization are all the fields that are not the sparsity config or
205
+ metadata fields
206
+
207
+ :param compression_config: quantization/compression config
208
+ :return: quantization config without sparsity config or metadata fields
209
+ """
185
210
  if compression_config is None:
186
211
  return None
187
212
 
188
- if hasattr(compression_config, QUANTIZATION_CONFIG_NAME):
189
- # for loaded HFQuantizer config
190
- return getattr(compression_config, QUANTIZATION_CONFIG_NAME)
213
+ if is_compressed_tensors_config(compression_config):
214
+ q_config = compression_config.quantization_config
215
+ return q_config.dict() if q_config is not None else None
191
216
 
192
- # SparseAutoModel format
193
217
  quantization_config = deepcopy(compression_config)
194
218
  quantization_config.pop(SPARSITY_CONFIG_NAME, None)
219
+
220
+ # some fields are required, even if a qconfig is not present
221
+ # pop them off and if nothing remains, then there is no qconfig
222
+ quant_method = quantization_config.pop(QUANTIZATION_METHOD_NAME, None)
223
+ _ = quantization_config.pop(COMPRESSION_VERSION_NAME, None)
224
+
195
225
  if len(quantization_config) == 0:
196
- quantization_config = None
226
+ return None
227
+
228
+ # replace popped off values
229
+ # note that version is discarded for now
230
+ if quant_method is not None:
231
+ quantization_config[QUANTIZATION_METHOD_NAME] = quant_method
232
+
197
233
  return quantization_config
198
234
 
199
235
  def __init__(
@@ -207,11 +243,11 @@ class ModelCompressor:
207
243
  self.quantization_compressor = None
208
244
 
209
245
  if sparsity_config is not None:
210
- self.sparsity_compressor = Compressor.load_from_registry(
246
+ self.sparsity_compressor = BaseCompressor.load_from_registry(
211
247
  sparsity_config.format, config=sparsity_config
212
248
  )
213
249
  if quantization_config is not None:
214
- self.quantization_compressor = Compressor.load_from_registry(
250
+ self.quantization_compressor = BaseCompressor.load_from_registry(
215
251
  quantization_config.format, config=quantization_config
216
252
  )
217
253
 
@@ -222,7 +258,7 @@ class ModelCompressor:
222
258
  Compresses a dense state dict or model with sparsity and/or quantization
223
259
 
224
260
  :param model: uncompressed model to compress
225
- :param model_state: optional uncompressed state_dict to insert into model
261
+ :param state_dict: optional uncompressed state_dict to insert into model
226
262
  :return: compressed state dict
227
263
  """
228
264
  if state_dict is None:
@@ -234,6 +270,10 @@ class ModelCompressor:
234
270
  compressed_state_dict = self.quantization_compressor.compress(
235
271
  state_dict, names_to_scheme=quantized_modules_to_args
236
272
  )
273
+ if self.quantization_config.format != CompressionFormat.dense.value:
274
+ self.quantization_config.quantization_status = (
275
+ QuantizationStatus.COMPRESSED
276
+ )
237
277
 
238
278
  if self.sparsity_compressor is not None:
239
279
  compressed_state_dict = self.sparsity_compressor.compress(
@@ -281,6 +321,9 @@ class ModelCompressor:
281
321
 
282
322
  :param save_directory: path to a folder containing a HF model config
283
323
  """
324
+ if self.quantization_config is None and self.sparsity_config is None:
325
+ return
326
+
284
327
  config_file_path = os.path.join(save_directory, CONFIG_NAME)
285
328
  if not os.path.exists(config_file_path):
286
329
  _LOGGER.warning(
@@ -292,13 +335,26 @@ class ModelCompressor:
292
335
  with open(config_file_path, "r") as config_file:
293
336
  config_data = json.load(config_file)
294
337
 
295
- config_data[COMPRESSION_CONFIG_NAME] = {}
338
+ # required metadata whenever a quantization or sparsity config is present
339
+ # overwrite previous config and version if already existing
340
+ config_data[QUANTIZATION_CONFIG_NAME] = {}
341
+ config_data[QUANTIZATION_CONFIG_NAME][
342
+ COMPRESSION_VERSION_NAME
343
+ ] = compressed_tensors.__version__
344
+ if self.quantization_config is not None:
345
+ self.quantization_config.quant_method = DEFAULT_QUANTIZATION_METHOD
346
+ else:
347
+ config_data[QUANTIZATION_CONFIG_NAME][
348
+ QUANTIZATION_METHOD_NAME
349
+ ] = DEFAULT_QUANTIZATION_METHOD
350
+
351
+ # quantization and sparsity configs
296
352
  if self.quantization_config is not None:
297
353
  quant_config_data = self.quantization_config.model_dump()
298
- config_data[COMPRESSION_CONFIG_NAME] = quant_config_data
354
+ config_data[QUANTIZATION_CONFIG_NAME] = quant_config_data
299
355
  if self.sparsity_config is not None:
300
356
  sparsity_config_data = self.sparsity_config.model_dump()
301
- config_data[COMPRESSION_CONFIG_NAME][
357
+ config_data[QUANTIZATION_CONFIG_NAME][
302
358
  SPARSITY_CONFIG_NAME
303
359
  ] = sparsity_config_data
304
360
 
@@ -0,0 +1,18 @@
1
+ # Copyright (c) 2021 - present / Neuralmagic, Inc. All Rights Reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing,
10
+ # software distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ # flake8: noqa
15
+
16
+ from .base import *
17
+ from .naive_quantized import *
18
+ from .pack_quantized import *