compressed-tensors 0.3.2__py3-none-any.whl → 0.4.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (40)
  1. compressed_tensors/base.py +2 -1
  2. compressed_tensors/compressors/__init__.py +5 -1
  3. compressed_tensors/compressors/base.py +11 -54
  4. compressed_tensors/compressors/dense.py +4 -4
  5. compressed_tensors/compressors/helpers.py +12 -12
  6. compressed_tensors/compressors/int_quantized.py +126 -0
  7. compressed_tensors/compressors/marlin_24.py +250 -0
  8. compressed_tensors/compressors/model_compressor.py +315 -0
  9. compressed_tensors/compressors/pack_quantized.py +212 -0
  10. compressed_tensors/compressors/sparse_bitmask.py +4 -4
  11. compressed_tensors/compressors/utils/__init__.py +19 -0
  12. compressed_tensors/compressors/utils/helpers.py +43 -0
  13. compressed_tensors/compressors/utils/permutations_24.py +65 -0
  14. compressed_tensors/compressors/utils/semi_structured_conversions.py +341 -0
  15. compressed_tensors/config/base.py +7 -4
  16. compressed_tensors/config/dense.py +4 -4
  17. compressed_tensors/config/sparse_bitmask.py +3 -3
  18. compressed_tensors/quantization/lifecycle/__init__.py +1 -0
  19. compressed_tensors/quantization/lifecycle/apply.py +75 -19
  20. compressed_tensors/quantization/lifecycle/compressed.py +69 -0
  21. compressed_tensors/quantization/lifecycle/forward.py +208 -22
  22. compressed_tensors/quantization/lifecycle/frozen.py +4 -0
  23. compressed_tensors/quantization/lifecycle/initialize.py +33 -5
  24. compressed_tensors/quantization/observers/base.py +70 -5
  25. compressed_tensors/quantization/observers/helpers.py +6 -1
  26. compressed_tensors/quantization/observers/memoryless.py +17 -9
  27. compressed_tensors/quantization/observers/min_max.py +44 -13
  28. compressed_tensors/quantization/quant_args.py +33 -4
  29. compressed_tensors/quantization/quant_config.py +69 -21
  30. compressed_tensors/quantization/quant_scheme.py +81 -1
  31. compressed_tensors/quantization/utils/helpers.py +77 -8
  32. compressed_tensors/utils/helpers.py +26 -122
  33. compressed_tensors/utils/safetensors_load.py +3 -2
  34. compressed_tensors/version.py +53 -0
  35. {compressed_tensors-0.3.2.dist-info → compressed_tensors-0.4.0.dist-info}/METADATA +46 -9
  36. compressed_tensors-0.4.0.dist-info/RECORD +48 -0
  37. compressed_tensors-0.3.2.dist-info/RECORD +0 -38
  38. {compressed_tensors-0.3.2.dist-info → compressed_tensors-0.4.0.dist-info}/LICENSE +0 -0
  39. {compressed_tensors-0.3.2.dist-info → compressed_tensors-0.4.0.dist-info}/WHEEL +0 -0
  40. {compressed_tensors-0.3.2.dist-info → compressed_tensors-0.4.0.dist-info}/top_level.txt +0 -0
compressed_tensors/quantization/quant_args.py
@@ -15,7 +15,7 @@
 from enum import Enum
 from typing import Any, Dict, Optional
 
-from pydantic import BaseModel, Field
+from pydantic import BaseModel, Field, validator
 
 
 __all__ = ["QuantizationType", "QuantizationStrategy", "QuantizationArgs"]
@@ -39,9 +39,10 @@ class QuantizationStrategy(str, Enum):
     CHANNEL = "channel"
     GROUP = "group"
     BLOCK = "block"
+    TOKEN = "token"
 
 
-class QuantizationArgs(BaseModel):
+class QuantizationArgs(BaseModel, use_enum_values=True):
     """
     User facing arguments used to define a quantization config for weights or
     activations
@@ -61,10 +62,10 @@ class QuantizationArgs(BaseModel):
     """
 
     num_bits: int = 8
-    type: QuantizationType = QuantizationType.INT
+    type: QuantizationType = QuantizationType.INT.value
     symmetric: bool = True
-    strategy: QuantizationStrategy = QuantizationStrategy.TENSOR
     group_size: Optional[int] = None
+    strategy: Optional[QuantizationStrategy] = None
     block_structure: Optional[str] = None
     dynamic: bool = False
     observer: str = Field(
@@ -94,3 +95,31 @@
             self.observer = "memoryless"
 
         return Observer.load_from_registry(self.observer, quantization_args=self)
+
+    @validator("strategy", pre=True, always=True)
+    def validate_strategy(cls, value, values):
+        group_size = values.get("group_size")
+
+        # use group_size to determinine strategy if not given explicity
+        if group_size is not None and value is None:
+            if group_size > 0:
+                return QuantizationStrategy.GROUP
+
+            elif group_size == -1:
+                return QuantizationStrategy.CHANNEL
+
+            else:
+                raise ValueError(
+                    f"group_size={group_size} with strategy {value} is invald. "
+                    "group_size > 0 for strategy='group' and "
+                    "group_size = -1 for 'channel'"
+                )
+
+        if value == QuantizationStrategy.GROUP:
+            if group_size is None:
+                raise ValueError(f"strategy {value} requires group_size to be set.")
+
+        if value is None:
+            return QuantizationStrategy.TENSOR
+
+        return value
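The new validator means `strategy` is now inferred from `group_size` instead of always defaulting to tensor-wise quantization. A minimal sketch of the resulting behavior (illustration only, not part of the diff):

    from compressed_tensors.quantization.quant_args import QuantizationArgs

    # group_size > 0 with no explicit strategy infers group-wise quantization
    args = QuantizationArgs(num_bits=4, group_size=128)
    print(args.strategy)  # "group" (stored as a plain string because of use_enum_values)

    # group_size == -1 infers channel-wise quantization
    print(QuantizationArgs(group_size=-1).strategy)  # "channel"

    # with neither strategy nor group_size given, the default is still tensor-wise
    print(QuantizationArgs().strategy)  # "tensor"

    # an explicit group strategy without a group_size is rejected
    QuantizationArgs(strategy="group")  # raises a validation error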
compressed_tensors/quantization/quant_config.py
@@ -13,10 +13,13 @@
 # limitations under the License.
 
 from enum import Enum
-from typing import Dict, List, Optional
+from typing import Dict, List, Optional, Union
 
-from compressed_tensors.base import QUANTIZATION_CONFIG_NAME
-from compressed_tensors.quantization.quant_scheme import QuantizationScheme
+from compressed_tensors.config import CompressionFormat
+from compressed_tensors.quantization.quant_scheme import (
+    QuantizationScheme,
+    preset_name_to_scheme,
+)
 from compressed_tensors.quantization.utils import (
     calculate_compression_ratio,
     is_module_quantized,
@@ -25,13 +28,14 @@ from compressed_tensors.quantization.utils import (
 )
 from pydantic import BaseModel, Field
 from torch.nn import Module
-from transformers import AutoConfig
 
 
 __all__ = [
     "QuantizationStatus",
     "QuantizationConfig",
     "LIFECYCLE_ORDER",
+    "DEFAULT_QUANTIZATION_METHOD",
+    "DEFAULT_QUANTIZATION_FORMAT",
 ]
 
 
@@ -62,10 +66,33 @@ class QuantizationStatus(str, Enum):
         return
 
     def __ge__(self, other):
+        if other is None:
+            return True
         if not isinstance(other, self.__class__):
             raise NotImplementedError
         return LIFECYCLE_ORDER.index(self) >= LIFECYCLE_ORDER.index(other)
 
+    def __gt__(self, other):
+        if other is None:
+            return True
+        if not isinstance(other, self.__class__):
+            raise NotImplementedError
+        return LIFECYCLE_ORDER.index(self) > LIFECYCLE_ORDER.index(other)
+
+    def __lt__(self, other):
+        if other is None:
+            return False
+        if not isinstance(other, self.__class__):
+            raise NotImplementedError
+        return LIFECYCLE_ORDER.index(self) < LIFECYCLE_ORDER.index(other)
+
+    def __le__(self, other):
+        if other is None:
+            return False
+        if not isinstance(other, self.__class__):
+            raise NotImplementedError
+        return LIFECYCLE_ORDER.index(self) <= LIFECYCLE_ORDER.index(other)
+
 
 LIFECYCLE_ORDER = [
     QuantizationStatus.INITIALIZED,
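With the full set of rich comparisons, lifecycle statuses can be ordered directly, and None (no status yet) sorts before every status. A rough illustration (not part of the diff):

    from compressed_tensors.quantization.quant_config import QuantizationStatus

    assert QuantizationStatus.COMPRESSED >= QuantizationStatus.INITIALIZED
    assert QuantizationStatus.INITIALIZED < QuantizationStatus.COMPRESSED
    assert QuantizationStatus.COMPRESSED > None       # any status is "later" than no status
    assert not (QuantizationStatus.INITIALIZED <= None)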
@@ -74,6 +101,9 @@ LIFECYCLE_ORDER = [
     QuantizationStatus.COMPRESSED,
 ]
 
+DEFAULT_QUANTIZATION_METHOD = "compressed-tensors"
+DEFAULT_QUANTIZATION_FORMAT = "fakequant"
+
 
 class QuantizationConfig(BaseModel):
     """
@@ -81,7 +111,8 @@ class QuantizationConfig(BaseModel):
     mapped to a QuantizationScheme in config_groups.
 
     :param config_groups: dict of QuantizationSchemes specifying the quantization
-        settings for each quantized layer
+        settings for each quantized layer. A group could also be a reference to
+        a predefined scheme name, mapped to a list of its target layers/classes
     :param quant_method: a constant used to differentiate sparseML quantization from
         other quantization configs
     :param format: specifies how the quantized model is stored on disk
@@ -93,30 +124,34 @@ class QuantizationConfig(BaseModel):
         are not quantized even if they match up with a target in config_groups
     """
 
-    config_groups: Dict[str, QuantizationScheme]
-    quant_method: str = "sparseml"
-    format: str = "fakequant"
+    config_groups: Dict[str, Union[QuantizationScheme, List[str]]]
+    quant_method: str = DEFAULT_QUANTIZATION_METHOD
+    format: str = DEFAULT_QUANTIZATION_FORMAT
     quantization_status: QuantizationStatus = QuantizationStatus.INITIALIZED
     global_compression_ratio: Optional[float] = None
     ignore: Optional[List[str]] = Field(default_factory=list)
 
-    @staticmethod
-    def from_model_config(model_name_or_path) -> "QuantizationConfig":
+    def model_post_init(self, __context):
         """
-        Given a path to a model config, extract a quantization config if it exists
-
-        :param pretrained_model_name_or_path: path to model config on disk or HF hub
-        :return: instantiated QuantizationConfig if config contains a quant config
+        updates any quantization schemes defined as presets to be fully loaded
+        schemes
         """
-        config = AutoConfig.from_pretrained(model_name_or_path)
-        quantization_config = getattr(config, QUANTIZATION_CONFIG_NAME, None)
-        if quantization_config is None:
-            return None
-
-        return QuantizationConfig.parse_obj(quantization_config)
+        for group_name, targets_or_scheme in self.config_groups.items():
+            if isinstance(targets_or_scheme, QuantizationScheme):
+                continue  # scheme already defined
+            self.config_groups[group_name] = preset_name_to_scheme(
+                name=group_name,
+                targets=targets_or_scheme,
+            )
+
+    def to_dict(self):
+        # for compatibility with HFQuantizer
+        return self.dict()
 
     @staticmethod
-    def from_pretrained(model: Module) -> "QuantizationConfig":
+    def from_pretrained(
+        model: Module, format: Optional[str] = None
+    ) -> Optional["QuantizationConfig"]:
         """
         Converts a model into its associated QuantizationConfig based on the
         QuantizationScheme attached to each quanitzed module
@@ -147,6 +182,9 @@ class QuantizationConfig(BaseModel):
                 if not match_found:
                     quant_scheme_to_layers.append(scheme)
 
+        if len(quant_scheme_to_layers) == 0:  # No quantized layers
+            return None
+
         # clean up ignore list, we can leave out layers types if none of the
         # instances are quantized
         consolidated_ignore = []
@@ -162,10 +200,20 @@ class QuantizationConfig(BaseModel):
             group_name = "group_" + str(idx)
             config_groups[group_name] = scheme
 
+        # TODO: this is incorrect in compressed mode, since we are overwriting the
+        # original weight we lose the uncompressed bit_depth indo
         compression_ratio = calculate_compression_ratio(model)
+
+        if format is None:
+            if quantization_status == QuantizationStatus.COMPRESSED:
+                format = CompressionFormat.int_quantized.value
+            else:
+                format = CompressionFormat.dense.value
+
         return QuantizationConfig(
             config_groups=config_groups,
             quantization_status=quantization_status,
             global_compression_ratio=compression_ratio,
+            format=format,
             ignore=consolidated_ignore,
         )
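Taken together, a 0.4.0 config can reference a preset scheme by using its name as the group key and can serialize itself for the transformers HFQuantizer integration. A minimal sketch (illustration only; assumes a pydantic v2 runtime so that model_post_init fires, and the W4A16 preset defined in quant_scheme.py below):

    from compressed_tensors.quantization.quant_config import QuantizationConfig

    # a group keyed by a preset name, mapped to its target layer types;
    # model_post_init expands it into a full QuantizationScheme
    config = QuantizationConfig(config_groups={"W4A16": ["Linear"]})

    scheme = config.config_groups["W4A16"]
    print(scheme.weights.num_bits)  # 4
    print(config.quant_method)      # "compressed-tensors"
    print(config.to_dict())         # plain dict, as consumed by HFQuantizer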
compressed_tensors/quantization/quant_scheme.py
@@ -12,13 +12,18 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
+from copy import deepcopy
 from typing import List, Optional
 
 from compressed_tensors.quantization.quant_args import QuantizationArgs
 from pydantic import BaseModel
 
 
-__all__ = ["QuantizationScheme"]
+__all__ = [
+    "QuantizationScheme",
+    "preset_name_to_scheme",
+    "is_preset_scheme",
+]
 
 
 class QuantizationScheme(BaseModel):
@@ -37,3 +42,78 @@ class QuantizationScheme(BaseModel):
     weights: Optional[QuantizationArgs] = None
     input_activations: Optional[QuantizationArgs] = None
     output_activations: Optional[QuantizationArgs] = None
+
+    @classmethod
+    def default_scheme(
+        cls,
+        targets: Optional[List[str]] = None,
+    ):
+
+        if targets is None:
+            # default to quantizing all Linear layers
+            targets = ["Linear"]
+
+        # default to 8 bit integer symmetric quantization
+        # for weights
+        weights = QuantizationArgs(num_bits=8, symmetric=True)
+
+        # default to 8 bit integer asymmetric quantization
+        input_activations = QuantizationArgs(num_bits=8, symmetric=True)
+
+        # Do not quantize the output activations
+        # by default
+        output_activations = None
+
+        return cls(
+            targets=targets,
+            weights=weights,
+            input_activations=input_activations,
+            output_activations=output_activations,
+        )
+
+
+"""
+Pre-Set Quantization Scheme Args
+"""
+
+
+def preset_name_to_scheme(name: str, targets: List[str]) -> QuantizationScheme:
+    """
+    :param name: preset quantization settings name. must exist in upper case in
+        PRESET_SCHEMES
+    :param targets: list of quantization targets to be passed to the Scheme
+    :return: new QuantizationScheme for a given name with the given targets
+    """
+    name = name.upper()
+
+    if name not in PRESET_SCHEMES:
+        raise KeyError(
+            f"Unknown preset scheme name {name}, "
+            f"available names: {list(PRESET_SCHEMES.keys())}"
+        )
+
+    scheme_args = deepcopy(PRESET_SCHEMES[name])  # deepcopy to avoid args references
+    return QuantizationScheme(
+        targets=targets,
+        **scheme_args,
+    )
+
+
+def is_preset_scheme(name: str) -> bool:
+    """
+    :param name: preset quantization settings name
+    :return: True if the name is a preset scheme name
+    """
+    return name.upper() in PRESET_SCHEMES
+
+
+W8A8 = dict(
+    weights=QuantizationArgs(), input_activations=QuantizationArgs(symmetric=True)
+)
+
+W4A16 = dict(weights=QuantizationArgs(num_bits=4, group_size=128))
+
+PRESET_SCHEMES = {
+    "W8A8": W8A8,
+    "W4A16": W4A16,
+}
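The two new helpers make preset lookup straightforward; a short usage sketch (not part of the diff):

    from compressed_tensors.quantization.quant_scheme import (
        is_preset_scheme,
        preset_name_to_scheme,
    )

    assert is_preset_scheme("w4a16")  # lookup is case-insensitive

    scheme = preset_name_to_scheme("W4A16", targets=["Linear"])
    print(scheme.targets)             # ["Linear"]
    print(scheme.weights.num_bits)    # 4
    print(scheme.weights.group_size)  # 128

    preset_name_to_scheme("W2A2", targets=["Linear"])  # raises KeyError: unknown preset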
compressed_tensors/quantization/utils/helpers.py
@@ -12,21 +12,43 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
-from typing import Tuple
+import logging
+from typing import Optional, Tuple
 
 import torch
+from compressed_tensors.quantization.observers.base import Observer
 from torch.nn import Module
 from tqdm import tqdm
 
 
 __all__ = [
+    "infer_quantization_status",
     "is_module_quantized",
     "is_model_quantized",
    "iter_named_leaf_modules",
     "module_type",
     "calculate_compression_ratio",
+    "get_torch_bit_depth",
+    "can_quantize",
 ]
 
+_LOGGER: logging.Logger = logging.getLogger(__name__)
+
+
+def infer_quantization_status(model: Module) -> Optional["QuantizationStatus"]:  # noqa
+    """
+    Checks the quantization status of a model. Assumes all modules in the model have
+    the same status, so only the first quantized model is checked.
+
+    :param model: model to check quantization status for
+    :return: quantization status if the model is quantized, otherwise None
+    """
+    for module in model.modules():
+        status = getattr(module, "quantization_status", None)
+        if status is not None:
+            return status
+    return None
+
 
 def is_module_quantized(module: Module) -> bool:
     """
@@ -78,11 +100,60 @@ def module_type(module: Module) -> str:
 
 
 def iter_named_leaf_modules(model: Module) -> Tuple[str, Module]:
-    # yields modules that do not have any submodules
-    # TODO: potentially expand to add list of allowed submodules such as observers
+    """
+    Yields modules that do not have any submodules except observers. The observers
+    themselves are not yielded
+
+    :param model: model to get leaf modules of
+    :returns: generator tuple of (name, leaf_submodule)
+    """
     for name, submodule in model.named_modules():
-        if len(list(submodule.children())) == 0:
+        children = list(submodule.children())
+        if len(children) == 0 and not isinstance(submodule, Observer):
             yield name, submodule
+        else:
+            has_non_observer_children = False
+            for child in children:
+                if not isinstance(child, Observer):
+                    has_non_observer_children = True
+
+            if not has_non_observer_children:
+                yield name, submodule
+
+
+def get_torch_bit_depth(value: torch.Tensor) -> int:
+    """
+    Determine the number of bits used to represent the dtype of a tensor
+
+    :param value: tensor to check bit depth of
+    :return: bit depth of each element in the value tensor
+    """
+    try:
+        bit_depth = torch.finfo(value.dtype).bits
+    except TypeError:
+        bit_depth = torch.iinfo(value.dtype).bits
+
+    return bit_depth
+
+
+def can_quantize(value: torch.Tensor, quant_args: "QuantizationArgs") -> bool:  # noqa
+    """
+    Checks if value can be quantized by quant_args.
+
+    :param value: tensor to check for quantization
+    :param quant_args: QuantizationArgs to use for quantization
+    :return: False if value is already quantized to quant_args or value is incompatible
+        with quant_args, True if value can be quantized with quant_args
+    """
+    bit_depth = get_torch_bit_depth(value)
+    requested_depth = quant_args.num_bits
+    if bit_depth < quant_args.num_bits:
+        _LOGGER.warn(
+            f"Can't quantize tensor with bit depth {bit_depth} to {requested_depth}."
+            "The QuantizationArgs provided are not compatible with the input tensor."
+        )
+
+    return bit_depth > quant_args.num_bits
 
 
 def calculate_compression_ratio(model: Module) -> float:
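get_torch_bit_depth and can_quantize gate whether a tensor still has headroom for the requested precision. A small sketch (not part of the diff):

    import torch
    from compressed_tensors.quantization.quant_args import QuantizationArgs
    from compressed_tensors.quantization.utils import can_quantize, get_torch_bit_depth

    weight = torch.zeros(16, 16, dtype=torch.float16)
    print(get_torch_bit_depth(weight))                            # 16 (finfo path for float dtypes)
    print(get_torch_bit_depth(torch.zeros(4, dtype=torch.int8)))  # 8 (iinfo path)

    print(can_quantize(weight, QuantizationArgs(num_bits=4)))   # True: 16 > 4
    print(can_quantize(weight, QuantizationArgs(num_bits=16)))  # False: no headroom left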
@@ -101,13 +172,11 @@ def calculate_compression_ratio(model: Module) -> float:
         desc="Calculating quantization compression ratio",
     ):
         for parameter in model.parameters():
-            try:
-                uncompressed_bits = torch.finfo(parameter.dtype).bits
-            except TypeError:
-                uncompressed_bits = torch.iinfo(parameter.dtype).bits
+            uncompressed_bits = get_torch_bit_depth(parameter)
             compressed_bits = uncompressed_bits
             if is_module_quantized(submodule):
                 compressed_bits = submodule.quantization_scheme.weights.num_bits
+
             num_weights = parameter.numel()
             total_compressed += compressed_bits * num_weights
             total_uncompressed += uncompressed_bits * num_weights
compressed_tensors/utils/helpers.py
@@ -12,47 +12,20 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
-from pathlib import Path
-from typing import Dict, Optional, Union
 
-import torch
-from compressed_tensors.base import SPARSITY_CONFIG_NAME
-from compressed_tensors.compressors import ModelCompressor
-from compressed_tensors.config import (
-    CompressionConfig,
-    CompressionFormat,
-    DenseSparsityConfig,
-)
-from safetensors.torch import save_file
-from torch import Tensor
+from typing import Optional
+
 from transformers import AutoConfig
 
 
-__all__ = [
-    "infer_compressor_from_model_config",
-    "infer_compression_config_from_model_config",
-    "load_compressed",
-    "save_compressed",
-    "save_compressed_model",
-]
+__all__ = ["infer_compressor_from_model_config", "fix_fsdp_module_name"]
 
-def infer_compressor_from_model_config(
-    pretrained_model_name_or_path: str,
-) -> Optional[CompressionConfig]:
-    """
-    Given a path to a model config, extract a sparsity config if it exists and return
-    the associated CompressionConfig
-
-    :param pretrained_model_name_or_path: path to model config on disk or HF hub
-    :return: matching compression config if config contains a sparsity config
-    """
-    config = AutoConfig.from_pretrained(pretrained_model_name_or_path)
-    return getattr(config, SPARSITY_CONFIG_NAME, None)
+FSDP_WRAPPER_NAME = "_fsdp_wrapped_module"
 
 
 def infer_compressor_from_model_config(
     pretrained_model_name_or_path: str,
-) -> Optional[ModelCompressor]:
+) -> Optional["ModelCompressor"]:  # noqa: F821
     """
     Given a path to a model config, extract a sparsity config if it exists and return
     the associated ModelCompressor
@@ -60,100 +33,31 @@ def infer_compressor_from_model_config(
     :param pretrained_model_name_or_path: path to model config on disk or HF hub
     :return: matching compressor if config contains a sparsity config
     """
-    sparsity_config = infer_compressor_from_model_config(pretrained_model_name_or_path)
-    compressor = ModelCompressor.load_from_registry(sparsity_config.format, config=sparsity_config)
-    return compressor
-
-
-def save_compressed(
-    tensors: Dict[str, Tensor],
-    save_path: Union[str, Path],
-    compression_format: Optional[CompressionFormat] = None,
-):
-    """
-    Save compressed tensors to disk. If tensors are not compressed,
-    save them as is.
-
-    :param tensors: dictionary of tensors to compress
-    :param save_path: path to save compressed tensors
-    :param compression_format: compression format used for the tensors
-    :return: compression config, if tensors were compressed - None otherwise
-    """
-    if tensors is None or len(tensors) == 0:
-        raise ValueError("No tensors or empty tensors provided to compress")
+    from compressed_tensors.compressors import ModelCompressor
+    from compressed_tensors.config import CompressionConfig
 
-    # if no compression_format specified, default to `dense_sparsity`
-    compression_format = compression_format or CompressionFormat.dense_sparsity.value
-
-    if not (
-        compression_format in ModelCompressor.registered_names()
-        or compression_format in ModelCompressor.registered_aliases()
-    ):
-        raise ValueError(
-            f"Unknown compression format: {compression_format}. "
-            f"Must be one of {set(ModelCompressor.registered_names() + ModelCompressor.registered_aliases())}"  # noqa E501
-        )
+    config = AutoConfig.from_pretrained(pretrained_model_name_or_path)
+    sparsity_config = ModelCompressor.parse_sparsity_config(config)
+    if sparsity_config is None:
+        return None
 
-    # compress
-    compressor = ModelCompressor.load_from_registry(compression_format)
-    # save compressed tensors
-    compressed_tensors = compressor.compress(tensors)
-    save_file(compressed_tensors, save_path)
+    format = sparsity_config.get("format")
+    sparsity_config = CompressionConfig.load_from_registry(format, **sparsity_config)
+    compressor = ModelCompressor.load_from_registry(format, config=sparsity_config)
+    return compressor
 
 
-def load_compressed(
-    compressed_tensors: Union[str, Path],
-    compression_config: CompressionConfig = None,
-    device: Optional[str] = "cpu",
-) -> Dict[str, Tensor]:
+# TODO: There is already the same function in
+# SparseML, should be moved to a shared location
+# in the future
+def fix_fsdp_module_name(name: str) -> str:
     """
-    Load compressed tensors from disk. If tensors are not compressed,
-    load them as is.
-
-    :param compressed_tensors: path to compressed tensors
-    :param compression_config: compression config to use for decompressing tensors.
-    :param device: device to move tensors to. If None, tensors are loaded on CPU.
-    :return decompressed tensors
+    Remove FSDP wrapper prefixes from a module name
+    Accounts for scenario where FSDP_WRAPPER_NAME is
+    at the end of the name, as well as in the middle.
+    :param name: name to strip
+    :return: stripped name
     """
-
-    if compressed_tensors is None or not Path(compressed_tensors).exists():
-        raise ValueError("No compressed tensors provided to load")
-
-    # if no compression_config specified, default to `dense_sparsity`
-    compression_config = compression_config or DenseSparsityConfig()
-
-    # decompress
-    compression_format = compression_config.format
-    compressor = ModelCompressor.load_from_registry(
-        compression_format, config=compression_config
+    return name.replace(FSDP_WRAPPER_NAME + ".", "").replace(
+        "." + FSDP_WRAPPER_NAME, ""
     )
-    return dict(compressor.decompress(compressed_tensors, device=device))
-
-
-def save_compressed_model(
-    model: torch.nn.Module,
-    filename: str,
-    compression_format: Optional[CompressionFormat] = None,
-    force_contiguous: bool = True,
-):
-    """
-    Wrapper around safetensors `save_model` helper function, which allows for
-    saving compressed model to disk.
-
-    Note: The model is assumed to have a
-    state_dict with unique entries
-
-    :param model: model to save on disk
-    :param filename: filename location to save the file
-    :param compression_format: compression format used for the model
-    :param force_contiguous: forcing the state_dict to be saved as contiguous tensors
-    """
-    state_dict = model.state_dict()
-    if force_contiguous:
-        state_dict = {k: v.contiguous() for k, v in state_dict.items()}
-    try:
-        save_compressed(state_dict, filename, compression_format=compression_format)
-    except ValueError as e:
-        msg = str(e)
-        msg += " Or use save_compressed_model(..., force_contiguous=True), read the docs for potential caveats."  # noqa E501
-        raise ValueError(msg)
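fix_fsdp_module_name is a pure string helper, so its behavior is easy to show (illustration only, not part of the diff):

    from compressed_tensors.utils.helpers import fix_fsdp_module_name

    # wrapper segment in the middle of a module path
    print(fix_fsdp_module_name("model.layers.0._fsdp_wrapped_module.self_attn.q_proj"))
    # -> "model.layers.0.self_attn.q_proj"

    # wrapper segment at the end of the name
    print(fix_fsdp_module_name("model._fsdp_wrapped_module"))
    # -> "model"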
compressed_tensors/utils/safetensors_load.py
@@ -31,6 +31,7 @@ __all__ = [
     "get_weight_mappings",
     "get_nested_weight_mappings",
     "get_quantization_state_dict",
+    "is_quantization_param",
 ]
 
 
@@ -214,7 +215,7 @@ def get_quantization_state_dict(model_path: str) -> Dict[str, Tensor]:
     weight_mappings = get_weight_mappings(model_path)
     state_dict = {}
     for weight_name, safe_path in weight_mappings.items():
-        if not _is_quantization_weight(weight_name):
+        if not is_quantization_param(weight_name):
             continue
         with safe_open(safe_path, framework="pt", device="cpu") as f:
             state_dict[weight_name] = f.get_tensor(weight_name)
@@ -222,7 +223,7 @@ def get_quantization_state_dict(model_path: str) -> Dict[str, Tensor]:
     return state_dict
 
 
-def _is_quantization_weight(name: str) -> bool:
+def is_quantization_param(name: str) -> bool:
     """
     Checks is a parameter name is associated with a quantization parameter