compressed-tensors-nightly 0.6.0.20240929__py3-none-any.whl → 0.6.0.20241004__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (27):
  1. compressed_tensors/base.py +1 -0
  2. compressed_tensors/compressors/__init__.py +6 -12
  3. compressed_tensors/compressors/base.py +38 -102
  4. compressed_tensors/compressors/helpers.py +6 -6
  5. compressed_tensors/compressors/model_compressors/__init__.py +17 -0
  6. compressed_tensors/compressors/{model_compressor.py → model_compressors/model_compressor.py} +91 -53
  7. compressed_tensors/compressors/quantized_compressors/__init__.py +18 -0
  8. compressed_tensors/compressors/quantized_compressors/base.py +146 -0
  9. compressed_tensors/compressors/{naive_quantized.py → quantized_compressors/naive_quantized.py} +11 -11
  10. compressed_tensors/compressors/{pack_quantized.py → quantized_compressors/pack_quantized.py} +6 -3
  11. compressed_tensors/compressors/sparse_compressors/__init__.py +18 -0
  12. compressed_tensors/compressors/sparse_compressors/base.py +110 -0
  13. compressed_tensors/compressors/{dense.py → sparse_compressors/dense.py} +3 -3
  14. compressed_tensors/compressors/{sparse_bitmask.py → sparse_compressors/sparse_bitmask.py} +14 -59
  15. compressed_tensors/compressors/sparse_quantized_compressors/__init__.py +16 -0
  16. compressed_tensors/compressors/{marlin_24.py → sparse_quantized_compressors/marlin_24.py} +3 -3
  17. compressed_tensors/linear/compressed_linear.py +2 -2
  18. compressed_tensors/quantization/lifecycle/calibration.py +2 -3
  19. compressed_tensors/quantization/lifecycle/initialize.py +2 -1
  20. compressed_tensors/quantization/quant_config.py +7 -0
  21. compressed_tensors/quantization/quant_scheme.py +1 -1
  22. compressed_tensors/utils/helpers.py +17 -1
  23. {compressed_tensors_nightly-0.6.0.20240929.dist-info → compressed_tensors_nightly-0.6.0.20241004.dist-info}/METADATA +1 -1
  24. {compressed_tensors_nightly-0.6.0.20240929.dist-info → compressed_tensors_nightly-0.6.0.20241004.dist-info}/RECORD +27 -21
  25. {compressed_tensors_nightly-0.6.0.20240929.dist-info → compressed_tensors_nightly-0.6.0.20241004.dist-info}/LICENSE +0 -0
  26. {compressed_tensors_nightly-0.6.0.20240929.dist-info → compressed_tensors_nightly-0.6.0.20241004.dist-info}/WHEEL +0 -0
  27. {compressed_tensors_nightly-0.6.0.20240929.dist-info → compressed_tensors_nightly-0.6.0.20241004.dist-info}/top_level.txt +0 -0
compressed_tensors/base.py
@@ -17,3 +17,4 @@ QUANTIZATION_CONFIG_NAME = "quantization_config"
 COMPRESSION_CONFIG_NAME = "compression_config"
 KV_CACHE_SCHEME_NAME = "kv_cache_scheme"
 COMPRESSION_VERSION_NAME = "version"
+QUANTIZATION_METHOD_NAME = "quant_method"
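The added `QUANTIZATION_METHOD_NAME` constant names the `quant_method` key that now accompanies the `quantization_config` block in a saved config.json. A minimal sketch of reading it back, assuming a locally saved model directory (the "my_model" path is hypothetical):

    import json

    from compressed_tensors.base import (
        QUANTIZATION_CONFIG_NAME,
        QUANTIZATION_METHOD_NAME,
    )

    # "my_model" stands in for any directory written by ModelCompressor.update_config
    with open("my_model/config.json") as f:
        config = json.load(f)

    quant_section = config.get(QUANTIZATION_CONFIG_NAME, {})
    print(quant_section.get(QUANTIZATION_METHOD_NAME))  # e.g. "compressed-tensors"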
compressed_tensors/compressors/__init__.py
@@ -14,15 +14,9 @@
 
 # flake8: noqa
 
-from .base import Compressor
-from .dense import DenseCompressor
-from .helpers import load_compressed, save_compressed, save_compressed_model
-from .marlin_24 import Marlin24Compressor
-from .model_compressor import ModelCompressor, map_modules_to_quant_args
-from .naive_quantized import (
-    FloatQuantizationCompressor,
-    IntQuantizationCompressor,
-    QuantizationCompressor,
-)
-from .pack_quantized import PackedQuantizationCompressor
-from .sparse_bitmask import BitmaskCompressor, BitmaskTensor
+from .base import *
+from .helpers import *
+from .model_compressors import *
+from .quantized_compressors import *
+from .sparse_compressors import *
+from .sparse_quantized_compressors import *
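The explicit imports give way to star imports from the new subpackages, so public names keep resolving from the package root; each submodule now controls its surface through `__all__` (e.g. `__all__ = ["BaseCompressor"]` in the rewritten base module below). A quick smoke test, assuming these classes remain exported:

    # downstream imports keep working after the reorganization
    from compressed_tensors.compressors import (
        BaseCompressor,  # renamed from Compressor in this release
        ModelCompressor,
        PackedQuantizationCompressor,
    )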
compressed_tensors/compressors/base.py
@@ -12,26 +12,21 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
-import logging
+from abc import ABC, abstractmethod
 from typing import Dict, Generator, Optional, Tuple, Union
 
 import torch
 from compressed_tensors.config import SparsityCompressionConfig
 from compressed_tensors.quantization import QuantizationArgs, QuantizationConfig
 from compressed_tensors.registry import RegistryMixin
-from compressed_tensors.utils import get_nested_weight_mappings, merge_names
-from safetensors import safe_open
 from torch import Tensor
-from torch.nn.modules import Module
-from tqdm import tqdm
+from torch.nn import Module
 
 
-_LOGGER: logging.Logger = logging.getLogger(__name__)
+__all__ = ["BaseCompressor"]
 
-__all__ = ["Compressor"]
 
-
-class Compressor(RegistryMixin):
+class BaseCompressor(RegistryMixin, ABC):
     """
     Base class representing a model compression algorithm. Each child class should
     implement compression_param_info, compress_weight and decompress_weight.
@@ -42,19 +37,18 @@ class Compressor(RegistryMixin):
     Model Load Lifecycle (run_compressed=False):
         - ModelCompressor.decompress()
             - apply_quantization_config()
-            - Compressor.decompress()
-                - Compressor.decompress_weight()
+            - BaseCompressor.decompress()
 
     Model Save Lifecycle:
         - ModelCompressor.compress()
-            - Compressor.compress()
-                - Compressor.compress_weight()
+            - BaseCompressor.compress()
+
 
     Module Lifecycle (run_compressed=True):
         - apply_quantization_config()
         - compressed_module = CompressedLinear(module)
         - initialize_module_for_quantization()
-        - Compressor.compression_param_info()
+        - BaseCompressor.compression_param_info()
         - register_parameters()
        - compressed_module.forward()
        -compressed_module.decompress()
@@ -83,61 +77,27 @@ class Compressor(RegistryMixin):
         """
         raise NotImplementedError()
 
+    @abstractmethod
     def compress(
         self,
         model_state: Dict[str, Tensor],
-        names_to_scheme: Dict[str, QuantizationArgs],
         **kwargs,
     ) -> Dict[str, Tensor]:
         """
         Compresses a dense state dict
 
         :param model_state: state dict of uncompressed model
-        :param names_to_scheme: quantization args for each quantized weight, needed for
-            quantize function to calculate bit depth
+        :param kwargs: additional arguments for compression
         :return: compressed state dict
         """
-        compressed_dict = {}
-        weight_suffix = ".weight"
-        _LOGGER.debug(
-            f"Compressing model with {len(model_state)} parameterized layers..."
-        )
-
-        for name, value in tqdm(model_state.items(), desc="Compressing model"):
-            if name.endswith(weight_suffix):
-                prefix = name[: -(len(weight_suffix))]
-                scale = model_state.get(merge_names(prefix, "weight_scale"), None)
-                zp = model_state.get(merge_names(prefix, "weight_zero_point"), None)
-                g_idx = model_state.get(merge_names(prefix, "weight_g_idx"), None)
-                if scale is not None:
-                    # weight is quantized, compress it
-                    quant_args = names_to_scheme[prefix]
-                    compressed_data = self.compress_weight(
-                        weight=value,
-                        scale=scale,
-                        zero_point=zp,
-                        g_idx=g_idx,
-                        quantization_args=quant_args,
-                        device="cpu",
-                    )
-                    for key, value in compressed_data.items():
-                        compressed_dict[merge_names(prefix, key)] = value
-                else:
-                    compressed_dict[name] = value.to("cpu")
-            elif name.endswith("zero_point") and torch.all(value == 0):
-                continue
-            elif name.endswith("g_idx") and torch.any(value <= -1):
-                continue
-            else:
-                compressed_dict[name] = value.to("cpu")
-
-        return compressed_dict
+        raise NotImplementedError()
 
+    @abstractmethod
     def decompress(
         self,
         path_to_model_or_tensors: str,
-        names_to_scheme: Dict[str, QuantizationArgs],
         device: str = "cpu",
+        **kwargs,
     ) -> Generator[Tuple[str, Tensor], None, None]:
         """
         Reads a compressed state dict located at path_to_model_or_tensors
@@ -150,55 +110,6 @@ class Compressor(RegistryMixin):
         :param device: optional device to load intermediate weights into
         :return: compressed state dict
         """
-        weight_mappings = get_nested_weight_mappings(
-            path_to_model_or_tensors, self.COMPRESSION_PARAM_NAMES
-        )
-        for weight_name in weight_mappings.keys():
-            weight_data = {}
-            for param_name, safe_path in weight_mappings[weight_name].items():
-                full_name = merge_names(weight_name, param_name)
-                with safe_open(safe_path, framework="pt", device=device) as f:
-                    weight_data[param_name] = f.get_tensor(full_name)
-
-            if "weight_scale" in weight_data:
-                quant_args = names_to_scheme[weight_name]
-                decompressed = self.decompress_weight(
-                    compressed_data=weight_data, quantization_args=quant_args
-                )
-                yield merge_names(weight_name, "weight"), decompressed
-
-    def compress_weight(
-        self,
-        weight: Tensor,
-        scale: Tensor,
-        zero_point: Optional[Tensor] = None,
-        g_idx: Optional[torch.Tensor] = None,
-        quantization_args: Optional[QuantizationArgs] = None,
-    ) -> Dict[str, torch.Tensor]:
-        """
-        Compresses a single uncompressed weight
-
-        :param weight: uncompressed weight tensor
-        :param scale: quantization scale for weight
-        :param zero_point: quantization zero point for weight
-        :param g_idx: optional mapping from column index to group index
-        :param quantization_args: quantization parameters for weight
-        :return: dictionary of compressed weight data
-        """
-        raise NotImplementedError()
-
-    def decompress_weight(
-        self,
-        compressed_data: Dict[str, Tensor],
-        quantization_args: Optional[QuantizationArgs] = None,
-    ) -> torch.Tensor:
-        """
-        Decompresses a single compressed weight
-
-        :param compressed_data: dictionary of data needed for decompression
-        :param quantization_args: quantization parameters for the weight
-        :return: tensor of the decompressed weight
-        """
         raise NotImplementedError()
 
     def compress_module(self, module: Module) -> Optional[Dict[str, torch.Tensor]]:
@@ -228,6 +139,19 @@ class Compressor(RegistryMixin):
             quantization_args=quantization_args,
         )
 
+    def compress_weight(
+        self,
+        weight: Tensor,
+        **kwargs,
+    ) -> Dict[str, torch.Tensor]:
+        """
+        Compresses a single uncompressed weight
+
+        :param weight: uncompressed weight tensor
+        :param kwargs: additional arguments for compression
+        """
+        raise NotImplementedError()
+
     def decompress_module(self, module: Module):
         """
         Decompresses a single compressed leaf PyTorch module. If the module is not
@@ -250,3 +174,15 @@ class Compressor(RegistryMixin):
         return self.decompress_weight(
             compressed_data=compressed_data, quantization_args=quantization_args
         )
+
+    def decompress_weight(
+        self, compressed_data: Dict[str, Tensor], **kwargs
+    ) -> torch.Tensor:
+        """
+        Decompresses a single compressed weight
+
+        :param compressed_data: dictionary of data needed for decompression
+        :param kwargs: additional arguments for decompression
+        :return: tensor of the decompressed weight
+        """
+        raise NotImplementedError()
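With `compress` and `decompress` now marked `@abstractmethod`, a subclass must implement both before it can be instantiated; `compress_weight`/`decompress_weight` remain optional overrides. A minimal sketch of a custom compressor against the new interface; the "passthrough" registry name and the no-op behavior are illustrative, not part of the library:

    from typing import Dict, Generator, Tuple

    from compressed_tensors.compressors import BaseCompressor
    from torch import Tensor


    @BaseCompressor.register(name="passthrough")  # registration via RegistryMixin
    class PassthroughCompressor(BaseCompressor):
        def compress(
            self, model_state: Dict[str, Tensor], **kwargs
        ) -> Dict[str, Tensor]:
            # required override: move tensors to CPU, compress nothing
            return {name: value.to("cpu") for name, value in model_state.items()}

        def decompress(
            self, path_to_model_or_tensors: str, device: str = "cpu", **kwargs
        ) -> Generator[Tuple[str, Tensor], None, None]:
            # required override: nothing was compressed, so nothing to yield
            yield from ()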
compressed_tensors/compressors/helpers.py
@@ -16,7 +16,7 @@ from pathlib import Path
 from typing import Dict, Generator, Optional, Tuple, Union
 
 import torch
-from compressed_tensors.compressors import Compressor
+from compressed_tensors.compressors import BaseCompressor
 from compressed_tensors.config import CompressionFormat, SparsityCompressionConfig
 from compressed_tensors.utils.safetensors_load import get_weight_mappings
 from safetensors import safe_open
@@ -52,16 +52,16 @@ def save_compressed(
     compression_format = compression_format or CompressionFormat.dense.value
 
     if not (
-        compression_format in Compressor.registered_names()
-        or compression_format in Compressor.registered_aliases()
+        compression_format in BaseCompressor.registered_names()
+        or compression_format in BaseCompressor.registered_aliases()
     ):
         raise ValueError(
             f"Unknown compression format: {compression_format}. "
-            f"Must be one of {set(Compressor.registered_names() + Compressor.registered_aliases())}"  # noqa E501
+            f"Must be one of {set(BaseCompressor.registered_names() + BaseCompressor.registered_aliases())}"  # noqa E501
         )
 
     # compress
-    compressor = Compressor.load_from_registry(compression_format)
+    compressor = BaseCompressor.load_from_registry(compression_format)
     # save compressed tensors
     compressed_tensors = compressor.compress(tensors)
     save_file(compressed_tensors, save_path)
@@ -102,7 +102,7 @@ def load_compressed(
     else:
         # decompress tensors
         compression_format = compression_config.format
-        compressor = Compressor.load_from_registry(
+        compressor = BaseCompressor.load_from_registry(
            compression_format, config=compression_config
        )
        yield from compressor.decompress(compressed_tensors, device=device)
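`save_compressed` and `load_compressed` are otherwise unchanged; only the registry lookups now go through `BaseCompressor`. A usage sketch with made-up tensors and file name, using the built-in dense format:

    import torch

    from compressed_tensors.compressors import load_compressed, save_compressed

    tensors = {"layer.weight": torch.randn(8, 8)}
    save_compressed(tensors, "sample.safetensors", compression_format="dense")

    # dense-saved tensors round-trip without a compression config
    for name, tensor in load_compressed("sample.safetensors"):
        print(name, tuple(tensor.shape))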
compressed_tensors/compressors/model_compressors/__init__.py (new file)
@@ -0,0 +1,17 @@
+# Copyright (c) 2021 - present / Neuralmagic, Inc. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#    http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# flake8: noqa
+
+
+from .model_compressor import *
compressed_tensors/compressors/{model_compressor.py → model_compressors/model_compressor.py}
@@ -18,20 +18,22 @@ import operator
 import os
 import re
 from copy import deepcopy
-from typing import Any, Dict, Optional, Union
+from typing import TYPE_CHECKING, Any, Dict, Optional, TypeVar, Union
 
+import compressed_tensors
 import torch
 import transformers
-import compressed_tensors
 from compressed_tensors.base import (
     COMPRESSION_CONFIG_NAME,
     COMPRESSION_VERSION_NAME,
     QUANTIZATION_CONFIG_NAME,
+    QUANTIZATION_METHOD_NAME,
     SPARSITY_CONFIG_NAME,
 )
-from compressed_tensors.compressors import Compressor
+from compressed_tensors.compressors.base import BaseCompressor
 from compressed_tensors.config import CompressionFormat, SparsityCompressionConfig
 from compressed_tensors.quantization import (
+    DEFAULT_QUANTIZATION_METHOD,
     QuantizationConfig,
     QuantizationStatus,
     apply_quantization_config,
@@ -42,7 +44,10 @@ from compressed_tensors.quantization.utils import (
     iter_named_leaf_modules,
 )
 from compressed_tensors.utils import get_safetensors_folder, update_parameter_data
-from compressed_tensors.utils.helpers import fix_fsdp_module_name
+from compressed_tensors.utils.helpers import (
+    fix_fsdp_module_name,
+    is_compressed_tensors_config,
+)
 from torch import Tensor
 from torch.nn import Module
 from tqdm import tqdm
@@ -55,6 +60,11 @@ __all__ = ["ModelCompressor", "map_modules_to_quant_args"]
 _LOGGER: logging.Logger = logging.getLogger(__name__)
 
 
+if TYPE_CHECKING:
+    # dummy type if not available from transformers
+    CompressedTensorsConfig = TypeVar("CompressedTensorsConfig")
+
+
 class ModelCompressor:
     """
     Handles compression and decompression of a model with a sparsity config and/or
@@ -90,45 +100,41 @@ class ModelCompressor:
         configs and load a ModelCompressor
 
         :param pretrained_model_name_or_path: path to model config on disk or HF hub
-        :return: compressor for the extracted configs
+        :return: compressor for the configs, or None if model is not compressed
         """
         config = AutoConfig.from_pretrained(pretrained_model_name_or_path, **kwargs)
         compression_config = getattr(config, COMPRESSION_CONFIG_NAME, None)
         return cls.from_compression_config(compression_config)
 
     @classmethod
-    def from_compression_config(cls, compression_config: Dict[str, Any]):
+    def from_compression_config(
+        cls, compression_config: Union[Dict[str, Any], "CompressedTensorsConfig"]
+    ):
         """
-        :param compression_config: compression/quantization config dictionary
-            found under key "quantization_config" in HF model config
-        :return: compressor for the extracted configs
+        :param compression_config:
+            A compression or quantization config
+
+            The type is one of the following:
+            1. A Dict found under either "quantization_config" or "compression_config"
+                keys in the config.json
+            2. A CompressedTensorsConfig found under key "quantization_config" in HF
+                model config
+        :return: compressor for the configs, or None if model is not compressed
         """
         if compression_config is None:
             return None
 
-        try:
-            from transformers.utils.quantization_config import CompressedTensorsConfig
-
-            if isinstance(compression_config, CompressedTensorsConfig):
-                compression_config = compression_config.to_dict()
-        except ImportError:
-            pass
-
         sparsity_config = cls.parse_sparsity_config(compression_config)
         quantization_config = cls.parse_quantization_config(compression_config)
         if sparsity_config is None and quantization_config is None:
             return None
 
-        if sparsity_config is not None and not isinstance(
-            sparsity_config, SparsityCompressionConfig
-        ):
+        if sparsity_config is not None:
             format = sparsity_config.get("format")
             sparsity_config = SparsityCompressionConfig.load_from_registry(
                 format, **sparsity_config
             )
-        if quantization_config is not None and not isinstance(
-            quantization_config, QuantizationConfig
-        ):
+        if quantization_config is not None:
             quantization_config = QuantizationConfig.parse_obj(quantization_config)
 
         return cls(
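`from_compression_config` now accepts either a plain dict or a `transformers` `CompressedTensorsConfig`; the type dispatch moves into the parse helpers (via `is_compressed_tensors_config`) instead of an inline try/except import. A sketch with a deliberately minimal dict; the scheme details are illustrative, not taken from a real checkpoint:

    from compressed_tensors.compressors import ModelCompressor

    config = {
        "quant_method": "compressed-tensors",
        "format": "dense",
        "config_groups": {
            "group_0": {"targets": ["Linear"], "weights": {"num_bits": 8}},
        },
    }
    compressor = ModelCompressor.from_compression_config(config)
    print(compressor.quantization_compressor)  # resolved via the registry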
@@ -151,7 +157,7 @@ class ModelCompressor:
             to a sparsity compression algorithm
         :param quantization_format: string corresponding to a quantization compression
             algorithm
-        :return: compressor for the extracted configs
+        :return: compressor for the configs, or None if model is not compressed
         """
         quantization_config = QuantizationConfig.from_pretrained(
             model, format=quantization_format
@@ -170,40 +176,60 @@ class ModelCompressor:
         )
 
     @staticmethod
-    def parse_sparsity_config(compression_config: Dict) -> Union[Dict, None]:
+    def parse_sparsity_config(
+        compression_config: Union[Dict[str, Any], "CompressedTensorsConfig"]
+    ) -> Union[Dict[str, Any], None]:
+        """
+        Parse sparsity config from quantization/compression config. Sparsity
+        config is nested inside q/c config
+
+        :param compression_config: quantization/compression config
+        :return: sparsity config
+        """
         if compression_config is None:
             return None
-        if SPARSITY_CONFIG_NAME not in compression_config:
-            return None
-        if hasattr(compression_config, SPARSITY_CONFIG_NAME):
-            # for loaded HFQuantizer config
-            return getattr(compression_config, SPARSITY_CONFIG_NAME)
-        if SPARSITY_CONFIG_NAME in compression_config:
-            # for loaded HFQuantizer config from dict
-            return compression_config[SPARSITY_CONFIG_NAME]
-
-        # SparseAutoModel format
+
+        if is_compressed_tensors_config(compression_config):
+            s_config = compression_config.sparsity_config
+            return s_config.dict() if s_config is not None else None
+
         return compression_config.get(SPARSITY_CONFIG_NAME, None)
 
     @staticmethod
-    def parse_quantization_config(compression_config: Dict) -> Union[Dict, None]:
+    def parse_quantization_config(
+        compression_config: Union[Dict[str, Any], "CompressedTensorsConfig"]
+    ) -> Union[Dict[str, Any], None]:
+        """
+        Parse quantization config from quantization/compression config. The
+        quantization are all the fields that are not the sparsity config or
+        metadata fields
+
+        :param compression_config: quantization/compression config
+        :return: quantization config without sparsity config or metadata fields
+        """
         if compression_config is None:
             return None
 
-        if hasattr(compression_config, QUANTIZATION_CONFIG_NAME):
-            # for loaded HFQuantizer config
-            return getattr(compression_config, QUANTIZATION_CONFIG_NAME)
+        if is_compressed_tensors_config(compression_config):
+            q_config = compression_config.quantization_config
+            return q_config.dict() if q_config is not None else None
 
-        if QUANTIZATION_CONFIG_NAME in compression_config:
-            # for loaded HFQuantizer config from dict
-            return compression_config[QUANTIZATION_CONFIG_NAME]
-
-        # SparseAutoModel format
         quantization_config = deepcopy(compression_config)
         quantization_config.pop(SPARSITY_CONFIG_NAME, None)
-        quantization_config.pop(COMPRESSION_VERSION_NAME, None)
+
+        # some fields are required, even if a qconfig is not present
+        # pop them off and if nothing remains, then there is no qconfig
+        quant_method = quantization_config.pop(QUANTIZATION_METHOD_NAME, None)
+        _ = quantization_config.pop(COMPRESSION_VERSION_NAME, None)
+
         if len(quantization_config) == 0:
-            quantization_config = None
+            return None
+
+        # replace popped off values
+        # note that version is discarded for now
+        if quant_method is not None:
+            quantization_config[QUANTIZATION_METHOD_NAME] = quant_method
+
         return quantization_config
 
     def __init__(
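A behavioral consequence of the new metadata handling: a `quantization_config` block that carries only the bookkeeping keys (`quant_method`, `version`), as written for sparsity-only checkpoints, now parses to `None` rather than to an empty quantization config. A sketch, with made-up dict values:

    from compressed_tensors.compressors import ModelCompressor

    # metadata-only block, e.g. from a sparsity-only checkpoint
    metadata_only = {"quant_method": "compressed-tensors", "version": "0.6.0"}
    assert ModelCompressor.parse_quantization_config(metadata_only) is None

    # any remaining field keeps the config, with quant_method restored
    real_config = {"quant_method": "compressed-tensors", "format": "pack-quantized"}
    parsed = ModelCompressor.parse_quantization_config(real_config)
    assert parsed["quant_method"] == "compressed-tensors"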
@@ -216,17 +242,16 @@ class ModelCompressor:
         self.sparsity_compressor = None
         self.quantization_compressor = None
 
-
         if sparsity_config and sparsity_config.format == CompressionFormat.dense.value:
             # ignore dense sparsity config
             self.sparsity_config = None
 
         if sparsity_config is not None:
-            self.sparsity_compressor = Compressor.load_from_registry(
+            self.sparsity_compressor = BaseCompressor.load_from_registry(
                 sparsity_config.format, config=sparsity_config
             )
         if quantization_config is not None:
-            self.quantization_compressor = Compressor.load_from_registry(
+            self.quantization_compressor = BaseCompressor.load_from_registry(
                 quantization_config.format, config=quantization_config
             )
 
@@ -237,7 +262,7 @@
         Compresses a dense state dict or model with sparsity and/or quantization
 
         :param model: uncompressed model to compress
-        :param model_state: optional uncompressed state_dict to insert into model
+        :param state_dict: optional uncompressed state_dict to insert into model
         :return: compressed state dict
         """
         if state_dict is None:
@@ -300,6 +325,9 @@
 
         :param save_directory: path to a folder containing a HF model config
         """
+        if self.quantization_config is None and self.sparsity_config is None:
+            return
+
         config_file_path = os.path.join(save_directory, CONFIG_NAME)
         if not os.path.exists(config_file_path):
             _LOGGER.warning(
@@ -311,7 +339,20 @@
         with open(config_file_path, "r") as config_file:
             config_data = json.load(config_file)
 
+        # required metadata whenever a quantization or sparsity config is present
+        # overwrite previous config and version if already existing
         config_data[QUANTIZATION_CONFIG_NAME] = {}
+        config_data[QUANTIZATION_CONFIG_NAME][
+            COMPRESSION_VERSION_NAME
+        ] = compressed_tensors.__version__
+        if self.quantization_config is not None:
+            self.quantization_config.quant_method = DEFAULT_QUANTIZATION_METHOD
+        else:
+            config_data[QUANTIZATION_CONFIG_NAME][
+                QUANTIZATION_METHOD_NAME
+            ] = DEFAULT_QUANTIZATION_METHOD
+
+        # quantization and sparsity configs
         if self.quantization_config is not None:
             quant_config_data = self.quantization_config.model_dump()
             config_data[QUANTIZATION_CONFIG_NAME] = quant_config_data
@@ -320,9 +361,6 @@
             config_data[QUANTIZATION_CONFIG_NAME][
                 SPARSITY_CONFIG_NAME
             ] = sparsity_config_data
-        config_data[QUANTIZATION_CONFIG_NAME][
-            COMPRESSION_VERSION_NAME
-        ] = compressed_tensors.__version__
 
         with open(config_file_path, "w") as config_file:
             json.dump(config_data, config_file, indent=2, sort_keys=True)
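Net effect of the `update_config` changes: the written `quantization_config` block always carries `quant_method` (and, for sparsity-only models, the `version` stamp), so loaders can identify a compressed checkpoint even when no quantization config exists. Roughly the shape written for a sparsity-only model; the sparsity fields are abbreviated and illustrative:

    # illustrative only; real sparsity configs carry more fields
    expected_block = {
        "quantization_config": {
            "quant_method": "compressed-tensors",
            "version": "0.6.0.20241004",
            "sparsity_config": {"format": "sparse-bitmask"},
        }
    }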
compressed_tensors/compressors/quantized_compressors/__init__.py (new file)
@@ -0,0 +1,18 @@
+# Copyright (c) 2021 - present / Neuralmagic, Inc. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#    http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# flake8: noqa
+
+from .base import *
+from .naive_quantized import *
+from .pack_quantized import *