compressed-tensors 0.6.0 (tar.gz) → 0.7.1 (tar.gz)

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (68)
  1. {compressed-tensors-0.6.0 → compressed-tensors-0.7.1}/PKG-INFO +1 -1
  2. {compressed-tensors-0.6.0 → compressed-tensors-0.7.1}/src/compressed_tensors/__init__.py +1 -0
  3. {compressed-tensors-0.6.0 → compressed-tensors-0.7.1}/src/compressed_tensors/base.py +2 -0
  4. compressed-tensors-0.7.1/src/compressed_tensors/compressors/__init__.py +22 -0
  5. {compressed-tensors-0.6.0 → compressed-tensors-0.7.1}/src/compressed_tensors/compressors/base.py +38 -102
  6. {compressed-tensors-0.6.0 → compressed-tensors-0.7.1}/src/compressed_tensors/compressors/helpers.py +6 -6
  7. compressed-tensors-0.7.1/src/compressed_tensors/compressors/model_compressors/__init__.py +17 -0
  8. {compressed-tensors-0.6.0/src/compressed_tensors/compressors → compressed-tensors-0.7.1/src/compressed_tensors/compressors/model_compressors}/model_compressor.py +95 -106
  9. compressed-tensors-0.7.1/src/compressed_tensors/compressors/quantized_compressors/__init__.py +18 -0
  10. compressed-tensors-0.7.1/src/compressed_tensors/compressors/quantized_compressors/base.py +146 -0
  11. {compressed-tensors-0.6.0/src/compressed_tensors/compressors → compressed-tensors-0.7.1/src/compressed_tensors/compressors/quantized_compressors}/naive_quantized.py +11 -11
  12. {compressed-tensors-0.6.0/src/compressed_tensors/compressors → compressed-tensors-0.7.1/src/compressed_tensors/compressors/quantized_compressors}/pack_quantized.py +6 -3
  13. compressed-tensors-0.7.1/src/compressed_tensors/compressors/sparse_compressors/__init__.py +18 -0
  14. compressed-tensors-0.7.1/src/compressed_tensors/compressors/sparse_compressors/base.py +110 -0
  15. {compressed-tensors-0.6.0/src/compressed_tensors/compressors → compressed-tensors-0.7.1/src/compressed_tensors/compressors/sparse_compressors}/dense.py +3 -3
  16. {compressed-tensors-0.6.0/src/compressed_tensors/compressors → compressed-tensors-0.7.1/src/compressed_tensors/compressors/sparse_compressors}/sparse_bitmask.py +14 -59
  17. compressed-tensors-0.7.1/src/compressed_tensors/compressors/sparse_quantized_compressors/__init__.py +16 -0
  18. {compressed-tensors-0.6.0/src/compressed_tensors/compressors → compressed-tensors-0.7.1/src/compressed_tensors/compressors/sparse_quantized_compressors}/marlin_24.py +3 -3
  19. {compressed-tensors-0.6.0 → compressed-tensors-0.7.1}/src/compressed_tensors/linear/compressed_linear.py +2 -2
  20. {compressed-tensors-0.6.0 → compressed-tensors-0.7.1}/src/compressed_tensors/quantization/__init__.py +1 -0
  21. compressed-tensors-0.7.1/src/compressed_tensors/quantization/cache.py +201 -0
  22. {compressed-tensors-0.6.0 → compressed-tensors-0.7.1}/src/compressed_tensors/quantization/lifecycle/apply.py +19 -3
  23. {compressed-tensors-0.6.0 → compressed-tensors-0.7.1}/src/compressed_tensors/quantization/lifecycle/calibration.py +2 -3
  24. {compressed-tensors-0.6.0 → compressed-tensors-0.7.1}/src/compressed_tensors/quantization/lifecycle/forward.py +58 -7
  25. {compressed-tensors-0.6.0 → compressed-tensors-0.7.1}/src/compressed_tensors/quantization/lifecycle/frozen.py +6 -1
  26. compressed-tensors-0.7.1/src/compressed_tensors/quantization/lifecycle/helpers.py +33 -0
  27. compressed-tensors-0.7.1/src/compressed_tensors/quantization/lifecycle/initialize.py +240 -0
  28. {compressed-tensors-0.6.0 → compressed-tensors-0.7.1}/src/compressed_tensors/quantization/observers/__init__.py +0 -1
  29. {compressed-tensors-0.6.0 → compressed-tensors-0.7.1}/src/compressed_tensors/quantization/observers/helpers.py +40 -2
  30. {compressed-tensors-0.6.0 → compressed-tensors-0.7.1}/src/compressed_tensors/quantization/quant_args.py +34 -4
  31. {compressed-tensors-0.6.0 → compressed-tensors-0.7.1}/src/compressed_tensors/quantization/quant_config.py +14 -2
  32. {compressed-tensors-0.6.0 → compressed-tensors-0.7.1}/src/compressed_tensors/quantization/quant_scheme.py +8 -4
  33. {compressed-tensors-0.6.0 → compressed-tensors-0.7.1}/src/compressed_tensors/quantization/utils/helpers.py +43 -18
  34. {compressed-tensors-0.6.0 → compressed-tensors-0.7.1}/src/compressed_tensors/utils/helpers.py +17 -1
  35. {compressed-tensors-0.6.0 → compressed-tensors-0.7.1}/src/compressed_tensors/version.py +1 -1
  36. {compressed-tensors-0.6.0 → compressed-tensors-0.7.1}/src/compressed_tensors.egg-info/PKG-INFO +1 -1
  37. {compressed-tensors-0.6.0 → compressed-tensors-0.7.1}/src/compressed_tensors.egg-info/SOURCES.txt +13 -7
  38. compressed-tensors-0.6.0/src/compressed_tensors/compressors/__init__.py +0 -28
  39. compressed-tensors-0.6.0/src/compressed_tensors/quantization/lifecycle/helpers.py +0 -80
  40. compressed-tensors-0.6.0/src/compressed_tensors/quantization/lifecycle/initialize.py +0 -191
  41. compressed-tensors-0.6.0/src/compressed_tensors/quantization/observers/memoryless.py +0 -56
  42. {compressed-tensors-0.6.0 → compressed-tensors-0.7.1}/LICENSE +0 -0
  43. {compressed-tensors-0.6.0 → compressed-tensors-0.7.1}/README.md +0 -0
  44. {compressed-tensors-0.6.0 → compressed-tensors-0.7.1}/pyproject.toml +0 -0
  45. {compressed-tensors-0.6.0 → compressed-tensors-0.7.1}/setup.cfg +0 -0
  46. {compressed-tensors-0.6.0 → compressed-tensors-0.7.1}/setup.py +0 -0
  47. {compressed-tensors-0.6.0 → compressed-tensors-0.7.1}/src/compressed_tensors/config/__init__.py +0 -0
  48. {compressed-tensors-0.6.0 → compressed-tensors-0.7.1}/src/compressed_tensors/config/base.py +0 -0
  49. {compressed-tensors-0.6.0 → compressed-tensors-0.7.1}/src/compressed_tensors/config/dense.py +0 -0
  50. {compressed-tensors-0.6.0 → compressed-tensors-0.7.1}/src/compressed_tensors/config/sparse_bitmask.py +0 -0
  51. {compressed-tensors-0.6.0 → compressed-tensors-0.7.1}/src/compressed_tensors/linear/__init__.py +0 -0
  52. {compressed-tensors-0.6.0 → compressed-tensors-0.7.1}/src/compressed_tensors/quantization/lifecycle/__init__.py +0 -0
  53. {compressed-tensors-0.6.0 → compressed-tensors-0.7.1}/src/compressed_tensors/quantization/lifecycle/compressed.py +0 -0
  54. {compressed-tensors-0.6.0 → compressed-tensors-0.7.1}/src/compressed_tensors/quantization/observers/base.py +0 -0
  55. {compressed-tensors-0.6.0 → compressed-tensors-0.7.1}/src/compressed_tensors/quantization/observers/min_max.py +0 -0
  56. {compressed-tensors-0.6.0 → compressed-tensors-0.7.1}/src/compressed_tensors/quantization/observers/mse.py +0 -0
  57. {compressed-tensors-0.6.0 → compressed-tensors-0.7.1}/src/compressed_tensors/quantization/utils/__init__.py +0 -0
  58. {compressed-tensors-0.6.0 → compressed-tensors-0.7.1}/src/compressed_tensors/registry/__init__.py +0 -0
  59. {compressed-tensors-0.6.0 → compressed-tensors-0.7.1}/src/compressed_tensors/registry/registry.py +0 -0
  60. {compressed-tensors-0.6.0 → compressed-tensors-0.7.1}/src/compressed_tensors/utils/__init__.py +0 -0
  61. {compressed-tensors-0.6.0 → compressed-tensors-0.7.1}/src/compressed_tensors/utils/offload.py +0 -0
  62. {compressed-tensors-0.6.0 → compressed-tensors-0.7.1}/src/compressed_tensors/utils/permutations_24.py +0 -0
  63. {compressed-tensors-0.6.0 → compressed-tensors-0.7.1}/src/compressed_tensors/utils/permute.py +0 -0
  64. {compressed-tensors-0.6.0 → compressed-tensors-0.7.1}/src/compressed_tensors/utils/safetensors_load.py +0 -0
  65. {compressed-tensors-0.6.0 → compressed-tensors-0.7.1}/src/compressed_tensors/utils/semi_structured_conversions.py +0 -0
  66. {compressed-tensors-0.6.0 → compressed-tensors-0.7.1}/src/compressed_tensors.egg-info/dependency_links.txt +0 -0
  67. {compressed-tensors-0.6.0 → compressed-tensors-0.7.1}/src/compressed_tensors.egg-info/requires.txt +0 -0
  68. {compressed-tensors-0.6.0 → compressed-tensors-0.7.1}/src/compressed_tensors.egg-info/top_level.txt +0 -0
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.1
2
2
  Name: compressed-tensors
3
- Version: 0.6.0
3
+ Version: 0.7.1
4
4
  Summary: Library for utilization of compressed safetensors of neural network models
5
5
  Home-page: https://github.com/neuralmagic/compressed-tensors
6
6
  Author: Neuralmagic, Inc.
@@ -19,3 +19,4 @@ from .compressors import *
19
19
  from .config import *
20
20
  from .quantization import QuantizationConfig, QuantizationStatus
21
21
  from .utils import *
22
+ from .version import *
@@ -16,3 +16,5 @@ SPARSITY_CONFIG_NAME = "sparsity_config"
16
16
  QUANTIZATION_CONFIG_NAME = "quantization_config"
17
17
  COMPRESSION_CONFIG_NAME = "compression_config"
18
18
  KV_CACHE_SCHEME_NAME = "kv_cache_scheme"
19
+ COMPRESSION_VERSION_NAME = "version"
20
+ QUANTIZATION_METHOD_NAME = "quant_method"
@@ -0,0 +1,22 @@
1
+ # Copyright (c) 2021 - present / Neuralmagic, Inc. All Rights Reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing,
10
+ # software distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ # flake8: noqa
16
+
17
+ from .base import *
18
+ from .helpers import *
19
+ from .model_compressors import *
20
+ from .quantized_compressors import *
21
+ from .sparse_compressors import *
22
+ from .sparse_quantized_compressors import *
@@ -12,26 +12,21 @@
12
12
  # See the License for the specific language governing permissions and
13
13
  # limitations under the License.
14
14
 
15
- import logging
15
+ from abc import ABC, abstractmethod
16
16
  from typing import Dict, Generator, Optional, Tuple, Union
17
17
 
18
18
  import torch
19
19
  from compressed_tensors.config import SparsityCompressionConfig
20
20
  from compressed_tensors.quantization import QuantizationArgs, QuantizationConfig
21
21
  from compressed_tensors.registry import RegistryMixin
22
- from compressed_tensors.utils import get_nested_weight_mappings, merge_names
23
- from safetensors import safe_open
24
22
  from torch import Tensor
25
- from torch.nn.modules import Module
26
- from tqdm import tqdm
23
+ from torch.nn import Module
27
24
 
28
25
 
29
- _LOGGER: logging.Logger = logging.getLogger(__name__)
26
+ __all__ = ["BaseCompressor"]
30
27
 
31
- __all__ = ["Compressor"]
32
28
 
33
-
34
- class Compressor(RegistryMixin):
29
+ class BaseCompressor(RegistryMixin, ABC):
35
30
  """
36
31
  Base class representing a model compression algorithm. Each child class should
37
32
  implement compression_param_info, compress_weight and decompress_weight.
@@ -42,19 +37,18 @@ class Compressor(RegistryMixin):
42
37
  Model Load Lifecycle (run_compressed=False):
43
38
  - ModelCompressor.decompress()
44
39
  - apply_quantization_config()
45
- - Compressor.decompress()
46
- - Compressor.decompress_weight()
40
+ - BaseCompressor.decompress()
47
41
 
48
42
  Model Save Lifecycle:
49
43
  - ModelCompressor.compress()
50
- - Compressor.compress()
51
- - Compressor.compress_weight()
44
+ - BaseCompressor.compress()
45
+
52
46
 
53
47
  Module Lifecycle (run_compressed=True):
54
48
  - apply_quantization_config()
55
49
  - compressed_module = CompressedLinear(module)
56
50
  - initialize_module_for_quantization()
57
- - Compressor.compression_param_info()
51
+ - BaseCompressor.compression_param_info()
58
52
  - register_parameters()
59
53
  - compressed_module.forward()
60
54
  -compressed_module.decompress()
@@ -83,61 +77,27 @@ class Compressor(RegistryMixin):
83
77
  """
84
78
  raise NotImplementedError()
85
79
 
80
+ @abstractmethod
86
81
  def compress(
87
82
  self,
88
83
  model_state: Dict[str, Tensor],
89
- names_to_scheme: Dict[str, QuantizationArgs],
90
84
  **kwargs,
91
85
  ) -> Dict[str, Tensor]:
92
86
  """
93
87
  Compresses a dense state dict
94
88
 
95
89
  :param model_state: state dict of uncompressed model
96
- :param names_to_scheme: quantization args for each quantized weight, needed for
97
- quantize function to calculate bit depth
90
+ :param kwargs: additional arguments for compression
98
91
  :return: compressed state dict
99
92
  """
100
- compressed_dict = {}
101
- weight_suffix = ".weight"
102
- _LOGGER.debug(
103
- f"Compressing model with {len(model_state)} parameterized layers..."
104
- )
105
-
106
- for name, value in tqdm(model_state.items(), desc="Compressing model"):
107
- if name.endswith(weight_suffix):
108
- prefix = name[: -(len(weight_suffix))]
109
- scale = model_state.get(merge_names(prefix, "weight_scale"), None)
110
- zp = model_state.get(merge_names(prefix, "weight_zero_point"), None)
111
- g_idx = model_state.get(merge_names(prefix, "weight_g_idx"), None)
112
- if scale is not None:
113
- # weight is quantized, compress it
114
- quant_args = names_to_scheme[prefix]
115
- compressed_data = self.compress_weight(
116
- weight=value,
117
- scale=scale,
118
- zero_point=zp,
119
- g_idx=g_idx,
120
- quantization_args=quant_args,
121
- device="cpu",
122
- )
123
- for key, value in compressed_data.items():
124
- compressed_dict[merge_names(prefix, key)] = value
125
- else:
126
- compressed_dict[name] = value.to("cpu")
127
- elif name.endswith("zero_point") and torch.all(value == 0):
128
- continue
129
- elif name.endswith("g_idx") and torch.any(value <= -1):
130
- continue
131
- else:
132
- compressed_dict[name] = value.to("cpu")
133
-
134
- return compressed_dict
93
+ raise NotImplementedError()
135
94
 
95
+ @abstractmethod
136
96
  def decompress(
137
97
  self,
138
98
  path_to_model_or_tensors: str,
139
- names_to_scheme: Dict[str, QuantizationArgs],
140
99
  device: str = "cpu",
100
+ **kwargs,
141
101
  ) -> Generator[Tuple[str, Tensor], None, None]:
142
102
  """
143
103
  Reads a compressed state dict located at path_to_model_or_tensors
@@ -150,55 +110,6 @@ class Compressor(RegistryMixin):
150
110
  :param device: optional device to load intermediate weights into
151
111
  :return: compressed state dict
152
112
  """
153
- weight_mappings = get_nested_weight_mappings(
154
- path_to_model_or_tensors, self.COMPRESSION_PARAM_NAMES
155
- )
156
- for weight_name in weight_mappings.keys():
157
- weight_data = {}
158
- for param_name, safe_path in weight_mappings[weight_name].items():
159
- full_name = merge_names(weight_name, param_name)
160
- with safe_open(safe_path, framework="pt", device=device) as f:
161
- weight_data[param_name] = f.get_tensor(full_name)
162
-
163
- if "weight_scale" in weight_data:
164
- quant_args = names_to_scheme[weight_name]
165
- decompressed = self.decompress_weight(
166
- compressed_data=weight_data, quantization_args=quant_args
167
- )
168
- yield merge_names(weight_name, "weight"), decompressed
169
-
170
- def compress_weight(
171
- self,
172
- weight: Tensor,
173
- scale: Tensor,
174
- zero_point: Optional[Tensor] = None,
175
- g_idx: Optional[torch.Tensor] = None,
176
- quantization_args: Optional[QuantizationArgs] = None,
177
- ) -> Dict[str, torch.Tensor]:
178
- """
179
- Compresses a single uncompressed weight
180
-
181
- :param weight: uncompressed weight tensor
182
- :param scale: quantization scale for weight
183
- :param zero_point: quantization zero point for weight
184
- :param g_idx: optional mapping from column index to group index
185
- :param quantization_args: quantization parameters for weight
186
- :return: dictionary of compressed weight data
187
- """
188
- raise NotImplementedError()
189
-
190
- def decompress_weight(
191
- self,
192
- compressed_data: Dict[str, Tensor],
193
- quantization_args: Optional[QuantizationArgs] = None,
194
- ) -> torch.Tensor:
195
- """
196
- Decompresses a single compressed weight
197
-
198
- :param compressed_data: dictionary of data needed for decompression
199
- :param quantization_args: quantization parameters for the weight
200
- :return: tensor of the decompressed weight
201
- """
202
113
  raise NotImplementedError()
203
114
 
204
115
  def compress_module(self, module: Module) -> Optional[Dict[str, torch.Tensor]]:
@@ -228,6 +139,19 @@ class Compressor(RegistryMixin):
228
139
  quantization_args=quantization_args,
229
140
  )
230
141
 
142
+ def compress_weight(
143
+ self,
144
+ weight: Tensor,
145
+ **kwargs,
146
+ ) -> Dict[str, torch.Tensor]:
147
+ """
148
+ Compresses a single uncompressed weight
149
+
150
+ :param weight: uncompressed weight tensor
151
+ :param kwargs: additional arguments for compression
152
+ """
153
+ raise NotImplementedError()
154
+
231
155
  def decompress_module(self, module: Module):
232
156
  """
233
157
  Decompresses a single compressed leaf PyTorch module. If the module is not
@@ -250,3 +174,15 @@ class Compressor(RegistryMixin):
250
174
  return self.decompress_weight(
251
175
  compressed_data=compressed_data, quantization_args=quantization_args
252
176
  )
177
+
178
+ def decompress_weight(
179
+ self, compressed_data: Dict[str, Tensor], **kwargs
180
+ ) -> torch.Tensor:
181
+ """
182
+ Decompresses a single compressed weight
183
+
184
+ :param compressed_data: dictionary of data needed for decompression
185
+ :param kwargs: additional arguments for decompression
186
+ :return: tensor of the decompressed weight
187
+ """
188
+ raise NotImplementedError()
@@ -16,7 +16,7 @@ from pathlib import Path
16
16
  from typing import Dict, Generator, Optional, Tuple, Union
17
17
 
18
18
  import torch
19
- from compressed_tensors.compressors import Compressor
19
+ from compressed_tensors.compressors import BaseCompressor
20
20
  from compressed_tensors.config import CompressionFormat, SparsityCompressionConfig
21
21
  from compressed_tensors.utils.safetensors_load import get_weight_mappings
22
22
  from safetensors import safe_open
@@ -52,16 +52,16 @@ def save_compressed(
52
52
  compression_format = compression_format or CompressionFormat.dense.value
53
53
 
54
54
  if not (
55
- compression_format in Compressor.registered_names()
56
- or compression_format in Compressor.registered_aliases()
55
+ compression_format in BaseCompressor.registered_names()
56
+ or compression_format in BaseCompressor.registered_aliases()
57
57
  ):
58
58
  raise ValueError(
59
59
  f"Unknown compression format: {compression_format}. "
60
- f"Must be one of {set(Compressor.registered_names() + Compressor.registered_aliases())}" # noqa E501
60
+ f"Must be one of {set(BaseCompressor.registered_names() + BaseCompressor.registered_aliases())}" # noqa E501
61
61
  )
62
62
 
63
63
  # compress
64
- compressor = Compressor.load_from_registry(compression_format)
64
+ compressor = BaseCompressor.load_from_registry(compression_format)
65
65
  # save compressed tensors
66
66
  compressed_tensors = compressor.compress(tensors)
67
67
  save_file(compressed_tensors, save_path)
@@ -102,7 +102,7 @@ def load_compressed(
102
102
  else:
103
103
  # decompress tensors
104
104
  compression_format = compression_config.format
105
- compressor = Compressor.load_from_registry(
105
+ compressor = BaseCompressor.load_from_registry(
106
106
  compression_format, config=compression_config
107
107
  )
108
108
  yield from compressor.decompress(compressed_tensors, device=device)
@@ -0,0 +1,17 @@
1
+ # Copyright (c) 2021 - present / Neuralmagic, Inc. All Rights Reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing,
10
+ # software distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ # flake8: noqa
15
+
16
+
17
+ from .model_compressor import *
@@ -18,18 +18,22 @@ import operator
18
18
  import os
19
19
  import re
20
20
  from copy import deepcopy
21
- from typing import Any, Dict, Optional, Union
21
+ from typing import TYPE_CHECKING, Any, Dict, Optional, TypeVar, Union
22
22
 
23
+ import compressed_tensors
23
24
  import torch
24
25
  import transformers
25
26
  from compressed_tensors.base import (
26
27
  COMPRESSION_CONFIG_NAME,
28
+ COMPRESSION_VERSION_NAME,
27
29
  QUANTIZATION_CONFIG_NAME,
30
+ QUANTIZATION_METHOD_NAME,
28
31
  SPARSITY_CONFIG_NAME,
29
32
  )
30
- from compressed_tensors.compressors import Compressor
33
+ from compressed_tensors.compressors.base import BaseCompressor
31
34
  from compressed_tensors.config import CompressionFormat, SparsityCompressionConfig
32
35
  from compressed_tensors.quantization import (
36
+ DEFAULT_QUANTIZATION_METHOD,
33
37
  QuantizationConfig,
34
38
  QuantizationStatus,
35
39
  apply_quantization_config,
@@ -40,7 +44,10 @@ from compressed_tensors.quantization.utils import (
40
44
  iter_named_leaf_modules,
41
45
  )
42
46
  from compressed_tensors.utils import get_safetensors_folder, update_parameter_data
43
- from compressed_tensors.utils.helpers import fix_fsdp_module_name
47
+ from compressed_tensors.utils.helpers import (
48
+ fix_fsdp_module_name,
49
+ is_compressed_tensors_config,
50
+ )
44
51
  from torch import Tensor
45
52
  from torch.nn import Module
46
53
  from tqdm import tqdm
@@ -53,6 +60,11 @@ __all__ = ["ModelCompressor", "map_modules_to_quant_args"]
53
60
  _LOGGER: logging.Logger = logging.getLogger(__name__)
54
61
 
55
62
 
63
+ if TYPE_CHECKING:
64
+ # dummy type if not available from transformers
65
+ CompressedTensorsConfig = TypeVar("CompressedTensorsConfig")
66
+
67
+
56
68
  class ModelCompressor:
57
69
  """
58
70
  Handles compression and decompression of a model with a sparsity config and/or
@@ -88,45 +100,41 @@ class ModelCompressor:
88
100
  configs and load a ModelCompressor
89
101
 
90
102
  :param pretrained_model_name_or_path: path to model config on disk or HF hub
91
- :return: compressor for the extracted configs
103
+ :return: compressor for the configs, or None if model is not compressed
92
104
  """
93
105
  config = AutoConfig.from_pretrained(pretrained_model_name_or_path, **kwargs)
94
106
  compression_config = getattr(config, COMPRESSION_CONFIG_NAME, None)
95
107
  return cls.from_compression_config(compression_config)
96
108
 
97
109
  @classmethod
98
- def from_compression_config(cls, compression_config: Dict[str, Any]):
110
+ def from_compression_config(
111
+ cls, compression_config: Union[Dict[str, Any], "CompressedTensorsConfig"]
112
+ ):
99
113
  """
100
- :param compression_config: compression/quantization config dictionary
101
- found under key "quantization_config" in HF model config
102
- :return: compressor for the extracted configs
114
+ :param compression_config:
115
+ A compression or quantization config
116
+
117
+ The type is one of the following:
118
+ 1. A Dict found under either "quantization_config" or "compression_config"
119
+ keys in the config.json
120
+ 2. A CompressedTensorsConfig found under key "quantization_config" in HF
121
+ model config
122
+ :return: compressor for the configs, or None if model is not compressed
103
123
  """
104
124
  if compression_config is None:
105
125
  return None
106
126
 
107
- try:
108
- from transformers.utils.quantization_config import CompressedTensorsConfig
109
-
110
- if isinstance(compression_config, CompressedTensorsConfig):
111
- compression_config = compression_config.to_dict()
112
- except ImportError:
113
- pass
114
-
115
127
  sparsity_config = cls.parse_sparsity_config(compression_config)
116
128
  quantization_config = cls.parse_quantization_config(compression_config)
117
129
  if sparsity_config is None and quantization_config is None:
118
130
  return None
119
131
 
120
- if sparsity_config is not None and not isinstance(
121
- sparsity_config, SparsityCompressionConfig
122
- ):
132
+ if sparsity_config is not None:
123
133
  format = sparsity_config.get("format")
124
134
  sparsity_config = SparsityCompressionConfig.load_from_registry(
125
135
  format, **sparsity_config
126
136
  )
127
- if quantization_config is not None and not isinstance(
128
- quantization_config, QuantizationConfig
129
- ):
137
+ if quantization_config is not None:
130
138
  quantization_config = QuantizationConfig.parse_obj(quantization_config)
131
139
 
132
140
  return cls(
@@ -149,7 +157,7 @@ class ModelCompressor:
149
157
  to a sparsity compression algorithm
150
158
  :param quantization_format: string corresponding to a quantization compression
151
159
  algorithm
152
- :return: compressor for the extracted configs
160
+ :return: compressor for the configs, or None if model is not compressed
153
161
  """
154
162
  quantization_config = QuantizationConfig.from_pretrained(
155
163
  model, format=quantization_format
@@ -168,39 +176,60 @@ class ModelCompressor:
168
176
  )
169
177
 
170
178
  @staticmethod
171
- def parse_sparsity_config(compression_config: Dict) -> Union[Dict, None]:
179
+ def parse_sparsity_config(
180
+ compression_config: Union[Dict[str, Any], "CompressedTensorsConfig"]
181
+ ) -> Union[Dict[str, Any], None]:
182
+ """
183
+ Parse sparsity config from quantization/compression config. Sparsity
184
+ config is nested inside q/c config
185
+
186
+ :param compression_config: quantization/compression config
187
+ :return: sparsity config
188
+ """
172
189
  if compression_config is None:
173
190
  return None
174
- if SPARSITY_CONFIG_NAME not in compression_config:
175
- return None
176
- if hasattr(compression_config, SPARSITY_CONFIG_NAME):
177
- # for loaded HFQuantizer config
178
- return getattr(compression_config, SPARSITY_CONFIG_NAME)
179
- if SPARSITY_CONFIG_NAME in compression_config:
180
- # for loaded HFQuantizer config from dict
181
- return compression_config[SPARSITY_CONFIG_NAME]
182
-
183
- # SparseAutoModel format
191
+
192
+ if is_compressed_tensors_config(compression_config):
193
+ s_config = compression_config.sparsity_config
194
+ return s_config.dict() if s_config is not None else None
195
+
184
196
  return compression_config.get(SPARSITY_CONFIG_NAME, None)
185
197
 
186
198
  @staticmethod
187
- def parse_quantization_config(compression_config: Dict) -> Union[Dict, None]:
199
+ def parse_quantization_config(
200
+ compression_config: Union[Dict[str, Any], "CompressedTensorsConfig"]
201
+ ) -> Union[Dict[str, Any], None]:
202
+ """
203
+ Parse quantization config from quantization/compression config. The
204
+ quantization are all the fields that are not the sparsity config or
205
+ metadata fields
206
+
207
+ :param compression_config: quantization/compression config
208
+ :return: quantization config without sparsity config or metadata fields
209
+ """
188
210
  if compression_config is None:
189
211
  return None
190
212
 
191
- if hasattr(compression_config, QUANTIZATION_CONFIG_NAME):
192
- # for loaded HFQuantizer config
193
- return getattr(compression_config, QUANTIZATION_CONFIG_NAME)
194
-
195
- if QUANTIZATION_CONFIG_NAME in compression_config:
196
- # for loaded HFQuantizer config from dict
197
- return compression_config[QUANTIZATION_CONFIG_NAME]
213
+ if is_compressed_tensors_config(compression_config):
214
+ q_config = compression_config.quantization_config
215
+ return q_config.dict() if q_config is not None else None
198
216
 
199
- # SparseAutoModel format
200
217
  quantization_config = deepcopy(compression_config)
201
218
  quantization_config.pop(SPARSITY_CONFIG_NAME, None)
219
+
220
+ # some fields are required, even if a qconfig is not present
221
+ # pop them off and if nothing remains, then there is no qconfig
222
+ quant_method = quantization_config.pop(QUANTIZATION_METHOD_NAME, None)
223
+ _ = quantization_config.pop(COMPRESSION_VERSION_NAME, None)
224
+
202
225
  if len(quantization_config) == 0:
203
- quantization_config = None
226
+ return None
227
+
228
+ # replace popped off values
229
+ # note that version is discarded for now
230
+ if quant_method is not None:
231
+ quantization_config[QUANTIZATION_METHOD_NAME] = quant_method
232
+
204
233
  return quantization_config
205
234
 
206
235
  def __init__(
@@ -214,11 +243,11 @@ class ModelCompressor:
214
243
  self.quantization_compressor = None
215
244
 
216
245
  if sparsity_config is not None:
217
- self.sparsity_compressor = Compressor.load_from_registry(
246
+ self.sparsity_compressor = BaseCompressor.load_from_registry(
218
247
  sparsity_config.format, config=sparsity_config
219
248
  )
220
249
  if quantization_config is not None:
221
- self.quantization_compressor = Compressor.load_from_registry(
250
+ self.quantization_compressor = BaseCompressor.load_from_registry(
222
251
  quantization_config.format, config=quantization_config
223
252
  )
224
253
 
@@ -229,7 +258,7 @@ class ModelCompressor:
229
258
  Compresses a dense state dict or model with sparsity and/or quantization
230
259
 
231
260
  :param model: uncompressed model to compress
232
- :param model_state: optional uncompressed state_dict to insert into model
261
+ :param state_dict: optional uncompressed state_dict to insert into model
233
262
  :return: compressed state dict
234
263
  """
235
264
  if state_dict is None:
@@ -251,62 +280,6 @@ class ModelCompressor:
251
280
  compressed_state_dict
252
281
  )
253
282
 
254
- # HACK (mgoin): Post-process step for kv cache scales to take the
255
- # k/v_proj module `output_scale` parameters, and store them in the
256
- # parent attention module as `k_scale` and `v_scale`
257
- #
258
- # Example:
259
- # Replace `model.layers.0.self_attn.k_proj.output_scale`
260
- # with `model.layers.0.self_attn.k_scale`
261
- if (
262
- self.quantization_config is not None
263
- and self.quantization_config.kv_cache_scheme is not None
264
- ):
265
- # HACK (mgoin): We assume the quantized modules in question
266
- # will be k_proj and v_proj since those are the default targets.
267
- # We check that both of these modules have output activation
268
- # quantization, and additionally check that q_proj doesn't.
269
- q_proj_has_no_quant_output = 0
270
- k_proj_has_quant_output = 0
271
- v_proj_has_quant_output = 0
272
- for name, module in model.named_modules():
273
- if not hasattr(module, "quantization_scheme"):
274
- # We still want to count non-quantized q_proj
275
- if name.endswith(".q_proj"):
276
- q_proj_has_no_quant_output += 1
277
- continue
278
- out_act = module.quantization_scheme.output_activations
279
- if name.endswith(".q_proj") and out_act is None:
280
- q_proj_has_no_quant_output += 1
281
- elif name.endswith(".k_proj") and out_act is not None:
282
- k_proj_has_quant_output += 1
283
- elif name.endswith(".v_proj") and out_act is not None:
284
- v_proj_has_quant_output += 1
285
-
286
- assert (
287
- q_proj_has_no_quant_output > 0
288
- and k_proj_has_quant_output > 0
289
- and v_proj_has_quant_output > 0
290
- )
291
- assert (
292
- q_proj_has_no_quant_output
293
- == k_proj_has_quant_output
294
- == v_proj_has_quant_output
295
- )
296
-
297
- # Move all .k/v_proj.output_scale parameters to .k/v_scale
298
- working_state_dict = {}
299
- for key in compressed_state_dict.keys():
300
- if key.endswith(".k_proj.output_scale"):
301
- new_key = key.replace(".k_proj.output_scale", ".k_scale")
302
- working_state_dict[new_key] = compressed_state_dict[key]
303
- elif key.endswith(".v_proj.output_scale"):
304
- new_key = key.replace(".v_proj.output_scale", ".v_scale")
305
- working_state_dict[new_key] = compressed_state_dict[key]
306
- else:
307
- working_state_dict[key] = compressed_state_dict[key]
308
- compressed_state_dict = working_state_dict
309
-
310
283
  # HACK: Override the dtype_byte_size function in transformers to
311
284
  # support float8 types. Fix is posted upstream
312
285
  # https://github.com/huggingface/transformers/pull/30488
@@ -348,6 +321,9 @@ class ModelCompressor:
348
321
 
349
322
  :param save_directory: path to a folder containing a HF model config
350
323
  """
324
+ if self.quantization_config is None and self.sparsity_config is None:
325
+ return
326
+
351
327
  config_file_path = os.path.join(save_directory, CONFIG_NAME)
352
328
  if not os.path.exists(config_file_path):
353
329
  _LOGGER.warning(
@@ -359,13 +335,26 @@ class ModelCompressor:
359
335
  with open(config_file_path, "r") as config_file:
360
336
  config_data = json.load(config_file)
361
337
 
362
- config_data[COMPRESSION_CONFIG_NAME] = {}
338
+ # required metadata whenever a quantization or sparsity config is present
339
+ # overwrite previous config and version if already existing
340
+ config_data[QUANTIZATION_CONFIG_NAME] = {}
341
+ config_data[QUANTIZATION_CONFIG_NAME][
342
+ COMPRESSION_VERSION_NAME
343
+ ] = compressed_tensors.__version__
344
+ if self.quantization_config is not None:
345
+ self.quantization_config.quant_method = DEFAULT_QUANTIZATION_METHOD
346
+ else:
347
+ config_data[QUANTIZATION_CONFIG_NAME][
348
+ QUANTIZATION_METHOD_NAME
349
+ ] = DEFAULT_QUANTIZATION_METHOD
350
+
351
+ # quantization and sparsity configs
363
352
  if self.quantization_config is not None:
364
353
  quant_config_data = self.quantization_config.model_dump()
365
- config_data[COMPRESSION_CONFIG_NAME] = quant_config_data
354
+ config_data[QUANTIZATION_CONFIG_NAME] = quant_config_data
366
355
  if self.sparsity_config is not None:
367
356
  sparsity_config_data = self.sparsity_config.model_dump()
368
- config_data[COMPRESSION_CONFIG_NAME][
357
+ config_data[QUANTIZATION_CONFIG_NAME][
369
358
  SPARSITY_CONFIG_NAME
370
359
  ] = sparsity_config_data
371
360
 
@@ -0,0 +1,18 @@
1
+ # Copyright (c) 2021 - present / Neuralmagic, Inc. All Rights Reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing,
10
+ # software distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ # flake8: noqa
15
+
16
+ from .base import *
17
+ from .naive_quantized import *
18
+ from .pack_quantized import *