compressed-tensors-nightly 0.8.1.20250106__py3-none-any.whl → 0.8.1.20250108__py3-none-any.whl

Sign up to get free protection for your applications and to get access to all the features.
@@ -17,8 +17,9 @@ import logging
17
17
  import operator
18
18
  import os
19
19
  import re
20
+ from contextlib import contextmanager
20
21
  from copy import deepcopy
21
- from typing import TYPE_CHECKING, Any, Dict, Optional, TypeVar, Union
22
+ from typing import TYPE_CHECKING, Any, Dict, Optional, Set, TypeVar, Union
22
23
 
23
24
  import compressed_tensors
24
25
  import torch
@@ -38,6 +39,7 @@ from compressed_tensors.quantization import (
38
39
  apply_quantization_config,
39
40
  load_pretrained_quantization,
40
41
  )
42
+ from compressed_tensors.quantization.lifecycle import expand_sparse_target_names
41
43
  from compressed_tensors.quantization.quant_args import QuantizationArgs
42
44
  from compressed_tensors.quantization.utils import (
43
45
  is_module_quantized,
@@ -104,7 +106,6 @@ class ModelCompressor:
104
106
  """
105
107
  config = AutoConfig.from_pretrained(pretrained_model_name_or_path, **kwargs)
106
108
  compression_config = getattr(config, QUANTIZATION_CONFIG_NAME, None)
107
-
108
109
  return cls.from_compression_config(compression_config)
109
110
 
110
111
  @classmethod
@@ -282,8 +283,14 @@ class ModelCompressor:
282
283
  )
283
284
 
284
285
  if self.sparsity_compressor is not None:
286
+ sparse_compression_targets: Set[str] = expand_sparse_target_names(
287
+ model=model,
288
+ targets=self.sparsity_config.targets,
289
+ ignore=self.sparsity_config.ignore,
290
+ )
285
291
  compressed_state_dict = self.sparsity_compressor.compress(
286
- compressed_state_dict
292
+ compressed_state_dict,
293
+ compression_targets=sparse_compression_targets,
287
294
  )
288
295
 
289
296
  # HACK: Override the dtype_byte_size function in transformers to
@@ -301,23 +308,41 @@ class ModelCompressor:
301
308
  :param model: pytorch model to load decompressed weights into
302
309
  """
303
310
  model_path = get_safetensors_folder(model_path)
311
+ sparse_decompressed = False
312
+
304
313
  if self.sparsity_compressor is not None:
314
+ # Sparse decompression is applied on the model_path
305
315
  dense_gen = self.sparsity_compressor.decompress(model_path)
306
316
  self._replace_weights(dense_gen, model)
307
317
  setattr(model, SPARSITY_CONFIG_NAME, self.sparsity_compressor.config)
318
+ sparse_decompressed = True
308
319
 
309
320
  if self.quantization_compressor is not None:
310
- names_to_scheme = apply_quantization_config(model, self.quantization_config)
311
- load_pretrained_quantization(model, model_path)
321
+ # Temporarily set quantization status to FROZEN to prevent
322
+ # quantization during apply_quantization_config. This ensures
323
+ # that the dtypes of the weights are not unintentionally updated.
324
+ # The status is restored after quantization params are loaded.
325
+ with override_quantization_status(
326
+ self.quantization_config, QuantizationStatus.FROZEN
327
+ ):
328
+ names_to_scheme = apply_quantization_config(
329
+ model, self.quantization_config
330
+ )
331
+ load_pretrained_quantization(model, model_path)
332
+
333
+ model_path_or_state_dict = (
334
+ model.state_dict() if sparse_decompressed else model_path
335
+ )
336
+
312
337
  dense_gen = self.quantization_compressor.decompress(
313
- model_path, names_to_scheme=names_to_scheme
338
+ model_path_or_state_dict, names_to_scheme=names_to_scheme
314
339
  )
315
340
  self._replace_weights(dense_gen, model)
316
341
 
317
- def update_status(module):
342
+ def freeze_quantization_status(module):
318
343
  module.quantization_status = QuantizationStatus.FROZEN
319
344
 
320
- model.apply(update_status)
345
+ model.apply(freeze_quantization_status)
321
346
  setattr(model, QUANTIZATION_CONFIG_NAME, self.quantization_config)
322
347
 
323
348
  def update_config(self, save_directory: str):
@@ -367,12 +392,26 @@ class ModelCompressor:
367
392
  with open(config_file_path, "w") as config_file:
368
393
  json.dump(config_data, config_file, indent=2, sort_keys=True)
369
394
 
370
def _replace_weights(self, dense_weight_generator, model: Module):
    """
    Replace the weights of the model with the provided dense weights.

    Iterates over dense_weight_generator and updates the matching parameter
    in the model. An entry is skipped when its final attribute name does not
    exist on the owning module. Names containing no "." refer to attributes
    of the model itself.

    :param dense_weight_generator: iterable yielding (name, data) tuples, where
        name is a dotted parameter name (e.g. "layer.weight") and data is the
        updated parameter data
    :param model: the model whose weights are to be updated
    :raises AttributeError: if the module path preceding the final attribute
        name does not exist in the model
    """
    for name, data in tqdm(dense_weight_generator, desc="Decompressing model"):
        # rpartition gives an empty prefix for undotted names; attrgetter("")
        # would raise, so such names are resolved against the model directly
        prefix, _, param_name = name.rpartition(".")
        module = operator.attrgetter(prefix)(model) if prefix else model
        if hasattr(module, param_name):
            update_parameter_data(module, data, param_name)
376
415
 
377
416
 
378
417
  def map_modules_to_quant_args(model: Module) -> Dict[str, QuantizationArgs]:
@@ -402,3 +441,23 @@ def new_dtype_byte_size(dtype):
402
441
  raise ValueError(f"`dtype` is not a valid dtype: {dtype}.")
403
442
  bit_size = int(bit_search.groups()[0])
404
443
  return bit_size // 8
444
+
445
+
446
@contextmanager
def override_quantization_status(
    config: QuantizationConfig, status: QuantizationStatus
):
    """
    Temporarily replace ``config.quantization_status`` with ``status``.

    The previous value is captured on entry. It is written back when the
    context exits, including when the body raises.

    :param config: quantization config whose status is overridden in place
    :param status: status to apply for the duration of the context
    """
    previous, config.quantization_status = config.quantization_status, status
    try:
        yield
    finally:
        # restore unconditionally so the caller's config is never left modified
        config.quantization_status = previous
@@ -13,12 +13,17 @@
13
13
  # limitations under the License.
14
14
 
15
15
  import logging
16
- from typing import Dict, Generator, Tuple
16
+ from pathlib import Path
17
+ from typing import Any, Dict, Generator, Tuple, Union
17
18
 
18
19
  import torch
19
20
  from compressed_tensors.compressors.base import BaseCompressor
20
21
  from compressed_tensors.quantization import QuantizationArgs
21
- from compressed_tensors.utils import get_nested_weight_mappings, merge_names
22
+ from compressed_tensors.utils import (
23
+ get_nested_mappings_from_state_dict,
24
+ get_nested_weight_mappings,
25
+ merge_names,
26
+ )
22
27
  from safetensors import safe_open
23
28
  from torch import Tensor
24
29
  from tqdm import tqdm
@@ -113,7 +118,7 @@ class BaseQuantizationCompressor(BaseCompressor):
113
118
 
114
119
  def decompress(
115
120
  self,
116
- path_to_model_or_tensors: str,
121
+ path_to_model_or_tensors: Union[str, Path, Dict[str, Any]],
117
122
  names_to_scheme: Dict[str, QuantizationArgs],
118
123
  device: str = "cpu",
119
124
  ) -> Generator[Tuple[str, Tensor], None, None]:
@@ -121,15 +126,25 @@ class BaseQuantizationCompressor(BaseCompressor):
121
126
  Reads a compressed state dict located at path_to_model_or_tensors
122
127
  and returns a generator for sequentially decompressing back to a
123
128
  dense state dict
124
-
125
129
  :param path_to_model_or_tensors: path to compressed safetensors model (directory
126
130
  with one or more safetensors files) or compressed tensors file
127
131
  :param names_to_scheme: quantization args for each quantized weight
128
132
  :param device: optional device to load intermediate weights into
129
133
  :return: compressed state dict
130
134
  """
135
+ if isinstance(path_to_model_or_tensors, (str, Path)):
136
+ yield from self._decompress_from_path(
137
+ path_to_model_or_tensors, names_to_scheme, device
138
+ )
139
+
140
+ else:
141
+ yield from self._decompress_from_state_dict(
142
+ path_to_model_or_tensors, names_to_scheme
143
+ )
144
+
145
+ def _decompress_from_path(self, path_to_model, names_to_scheme, device):
131
146
  weight_mappings = get_nested_weight_mappings(
132
- path_to_model_or_tensors, self.COMPRESSION_PARAM_NAMES
147
+ path_to_model, self.COMPRESSION_PARAM_NAMES
133
148
  )
134
149
  for weight_name in weight_mappings.keys():
135
150
  weight_data = {}
@@ -137,6 +152,21 @@ class BaseQuantizationCompressor(BaseCompressor):
137
152
  full_name = merge_names(weight_name, param_name)
138
153
  with safe_open(safe_path, framework="pt", device=device) as f:
139
154
  weight_data[param_name] = f.get_tensor(full_name)
155
+ if "weight_scale" in weight_data:
156
+ quant_args = names_to_scheme[weight_name]
157
+ decompressed = self.decompress_weight(
158
+ compressed_data=weight_data, quantization_args=quant_args
159
+ )
160
+ yield merge_names(weight_name, "weight"), decompressed
161
+
162
+ def _decompress_from_state_dict(self, state_dict, names_to_scheme):
163
+ weight_mappings = get_nested_mappings_from_state_dict(
164
+ state_dict, self.COMPRESSION_PARAM_NAMES
165
+ )
166
+ for weight_name in weight_mappings.keys():
167
+ weight_data = {}
168
+ for param_name, param_value in weight_mappings[weight_name].items():
169
+ weight_data[param_name] = param_value
140
170
 
141
171
  if "weight_scale" in weight_data:
142
172
  quant_args = names_to_scheme[weight_name]
@@ -68,9 +68,9 @@ class NaiveQuantizationCompressor(BaseQuantizationCompressor):
68
68
  self,
69
69
  weight: Tensor,
70
70
  scale: Tensor,
71
+ quantization_args: QuantizationArgs,
71
72
  zero_point: Optional[Tensor] = None,
72
73
  g_idx: Optional[torch.Tensor] = None,
73
- quantization_args: Optional[QuantizationArgs] = None,
74
74
  device: Optional[torch.device] = None,
75
75
  ) -> Dict[str, torch.Tensor]:
76
76
  """
@@ -78,9 +78,9 @@ class NaiveQuantizationCompressor(BaseQuantizationCompressor):
78
78
 
79
79
  :param weight: uncompressed weight tensor
80
80
  :param scale: quantization scale for weight
81
+ :param quantization_args: quantization parameters for weight
81
82
  :param zero_point: quantization zero point for weight
82
83
  :param g_idx: optional mapping from column index to group index
83
- :param quantization_args: quantization parameters for weight
84
84
  :param device: optional device to move compressed output to
85
85
  :return: dictionary of compressed weight data
86
86
  """
@@ -68,9 +68,9 @@ class PackedQuantizationCompressor(BaseQuantizationCompressor):
68
68
  self,
69
69
  weight: Tensor,
70
70
  scale: Tensor,
71
+ quantization_args: QuantizationArgs,
71
72
  zero_point: Optional[Tensor] = None,
72
73
  g_idx: Optional[torch.Tensor] = None,
73
- quantization_args: Optional[QuantizationArgs] = None,
74
74
  device: Optional[torch.device] = None,
75
75
  ) -> Dict[str, torch.Tensor]:
76
76
  """
@@ -78,9 +78,9 @@ class PackedQuantizationCompressor(BaseQuantizationCompressor):
78
78
 
79
79
  :param weight: uncompressed weight tensor
80
80
  :param scale: quantization scale for weight
81
+ :param quantization_args: quantization parameters for weight
81
82
  :param zero_point: quantization zero point for weight
82
83
  :param g_idx: optional mapping from column index to group index
83
- :param quantization_args: quantization parameters for weight
84
84
  :param device: optional device to move compressed output to
85
85
  :return: dictionary of compressed weight data
86
86
  """
@@ -13,7 +13,7 @@
13
13
  # limitations under the License.
14
14
 
15
15
  import logging
16
- from typing import Dict, Generator, Tuple
16
+ from typing import Dict, Generator, Optional, Set, Tuple
17
17
 
18
18
  from compressed_tensors.compressors.base import BaseCompressor
19
19
  from compressed_tensors.utils import get_nested_weight_mappings, merge_names
@@ -30,7 +30,8 @@ _LOGGER: logging.Logger = logging.getLogger(__name__)
30
30
  class BaseSparseCompressor(BaseCompressor):
31
31
  """
32
32
  Base class representing a sparse compression algorithm. Each child class should
33
- implement compression_param_info, compress_weight and decompress_weight.
33
+ implement compression_param_info, compress_weight and decompress_weight; child
34
+ classes should also define COMPRESSION_PARAM_NAMES.
34
35
 
35
36
  Compressors support compressing/decompressing a full module state dict or a single
36
37
  quantized PyTorch leaf module.
@@ -59,11 +60,17 @@ class BaseSparseCompressor(BaseCompressor):
59
60
  :param config: config specifying compression parameters
60
61
  """
61
62
 
62
- def compress(self, model_state: Dict[str, Tensor]) -> Dict[str, Tensor]:
63
+ def compress(
64
+ self,
65
+ model_state: Dict[str, Tensor],
66
+ compression_targets: Optional[Set[str]] = None,
67
+ ) -> Dict[str, Tensor]:
63
68
  """
64
69
  Compresses a dense state dict using bitmask compression
65
70
 
66
71
  :param model_state: state dict of uncompressed model
72
+ :param compression_targets: optional set of layer prefixes to compress,
73
+ otherwise compress all layers (for backwards compatibility)
67
74
  :return: compressed state dict
68
75
  """
69
76
  compressed_dict = {}
@@ -71,7 +78,14 @@ class BaseSparseCompressor(BaseCompressor):
71
78
  f"Compressing model with {len(model_state)} parameterized layers..."
72
79
  )
73
80
  for name, value in tqdm(model_state.items(), desc="Compressing model"):
74
- compression_data = self.compress_weight(name, value)
81
+ if not self.should_compress(name, compression_targets):
82
+ compressed_dict[name] = value
83
+ continue
84
+ prefix = name
85
+ if prefix.endswith(".weight"):
86
+ prefix = prefix[: -(len(".weight"))]
87
+
88
+ compression_data = self.compress_weight(prefix, value)
75
89
  for key in compression_data.keys():
76
90
  if key in compressed_dict:
77
91
  _LOGGER.warn(
@@ -97,8 +111,10 @@ class BaseSparseCompressor(BaseCompressor):
97
111
  :param device: device to load decompressed weights onto
98
112
  :return: iterator for generating decompressed weights
99
113
  """
100
- weight_mappings = get_nested_weight_mappings(
101
- path_to_model_or_tensors, self.COMPRESSION_PARAM_NAMES
114
+ weight_mappings, ignored_params = get_nested_weight_mappings(
115
+ path_to_model_or_tensors,
116
+ self.COMPRESSION_PARAM_NAMES,
117
+ return_unmatched_params=True,
102
118
  )
103
119
  for weight_name in weight_mappings.keys():
104
120
  weight_data = {}
@@ -107,4 +123,26 @@ class BaseSparseCompressor(BaseCompressor):
107
123
  with safe_open(safe_path, framework="pt", device=device) as f:
108
124
  weight_data[param_name] = f.get_tensor(full_name)
109
125
  decompressed = self.decompress_weight(weight_data)
110
- yield weight_name, decompressed
126
+ yield merge_names(weight_name, "weight"), decompressed
127
+
128
+ for ignored_param_name, safe_path in ignored_params.items():
129
+ with safe_open(safe_path, framework="pt", device=device) as f:
130
+ value = f.get_tensor(ignored_param_name)
131
+ yield ignored_param_name, value
132
+
133
@staticmethod
def should_compress(name: str, expanded_targets: Optional[Set[str]] = None) -> bool:
    """
    Decide whether a state-dict entry is eligible for sparse compression.

    Only entries whose name ends in ".weight" are eligible. When
    expanded_targets is given, the owning layer name must also be a member
    of it. The owning layer name is the entry name without the ".weight"
    suffix.

    :param name: name of the parameter
    :param expanded_targets: optional set of layer prefixes to compress; None
        means every weight parameter is eligible
    :return: True if the parameter should be compressed, False otherwise
    """
    suffix = ".weight"
    if not name.endswith(suffix):
        return False
    if expanded_targets is None:
        return True
    layer_name = name[: -len(suffix)]
    return layer_name in expanded_targets
@@ -19,6 +19,7 @@ import torch
19
19
  from compressed_tensors.compressors.base import BaseCompressor
20
20
  from compressed_tensors.compressors.sparse_compressors.base import BaseSparseCompressor
21
21
  from compressed_tensors.config import CompressionFormat
22
+ from compressed_tensors.quantization import FP8_DTYPE
22
23
  from compressed_tensors.utils import merge_names
23
24
  from torch import Tensor
24
25
 
@@ -134,9 +135,14 @@ def bitmask_compress(tensor: Tensor) -> Tuple[Tensor, Tensor, Tensor]:
134
135
  bytemasks = tensor != 0
135
136
  row_counts = bytemasks.sum(dim=-1)
136
137
  row_offsets = torch.cumsum(row_counts, 0) - row_counts
137
- values = tensor[bytemasks]
138
+ if tensor.dtype == FP8_DTYPE:
139
+ # acces raw bytes of the tensor
140
+ tensor_view = tensor.view(torch.int8)
141
+ values = tensor_view[bytemasks]
142
+ values = values.view(FP8_DTYPE)
143
+ else:
144
+ values = tensor[bytemasks]
138
145
  bitmasks_packed = pack_bitmasks(bytemasks)
139
-
140
146
  return values, bitmasks_packed, row_offsets
141
147
 
142
148
 
@@ -18,7 +18,7 @@ from collections import OrderedDict, defaultdict
18
18
  from copy import deepcopy
19
19
  from typing import Dict, Iterable, List, Optional
20
20
  from typing import OrderedDict as OrderedDictType
21
- from typing import Union
21
+ from typing import Set, Union
22
22
 
23
23
  import torch
24
24
  from compressed_tensors.config import CompressionFormat
@@ -52,6 +52,8 @@ __all__ = [
52
52
  "apply_quantization_config",
53
53
  "apply_quantization_status",
54
54
  "find_name_or_class_matches",
55
+ "expand_sparse_target_names",
56
+ "is_sparse_target",
55
57
  ]
56
58
 
57
59
  from compressed_tensors.quantization.utils.helpers import is_module_quantized
@@ -245,6 +247,49 @@ def apply_quantization_status(model: Module, status: QuantizationStatus):
245
247
  model.apply(compress_quantized_weights)
246
248
 
247
249
 
250
def expand_sparse_target_names(
    model: Module, targets: Iterable[str], ignore: Iterable[str]
) -> Set[str]:
    """
    Collect the names of every leaf module in ``model`` that is selected by
    ``targets`` and not excluded by ``ignore``.

    Note: Targets must be regexes, layer types, or full layer names.

    :param model: model whose leaf modules are searched
    :param targets: target patterns to match
    :param ignore: patterns whose matches are excluded
    :return: set of matching, non-ignored leaf module names
    """
    selected: Set[str] = set()
    for module_name, submodule in iter_named_leaf_modules(model):
        if is_sparse_target(module_name, submodule, targets, ignore):
            selected.add(module_name)
    return selected
270
+
271
+
272
def is_sparse_target(
    name: str, module: Module, targets: Iterable[str], ignore: Iterable[str]
) -> bool:
    """
    Report whether a module is selected by ``targets`` and not excluded by
    ``ignore``.

    Note: Targets must be regexes, layer types, or full layer names.

    :param name: name of the module
    :param module: the module itself
    :param targets: target patterns to match
    :param ignore: patterns whose matches are excluded; a falsy value (e.g.
        None) is treated as an empty list
    :return: True if the module is a target and not ignored, False otherwise
    """
    if not find_name_or_class_matches(name, module, targets):
        return False
    ignore_patterns = ignore or []
    return not find_name_or_class_matches(name, module, ignore_patterns)
291
+
292
+
248
293
  def find_name_or_class_matches(
249
294
  name: str, module: Module, targets: Iterable[str], check_contains: bool = False
250
295
  ) -> List[str]:
@@ -82,8 +82,8 @@ def quantize(
82
82
  def dequantize(
83
83
  x_q: torch.Tensor,
84
84
  scale: torch.Tensor,
85
- zero_point: torch.Tensor = None,
86
- args: QuantizationArgs = None,
85
+ zero_point: Optional[torch.Tensor] = None,
86
+ args: Optional[QuantizationArgs] = None,
87
87
  dtype: Optional[torch.dtype] = None,
88
88
  g_idx: Optional[torch.Tensor] = None,
89
89
  ) -> torch.Tensor:
@@ -16,7 +16,7 @@ import json
16
16
  import os
17
17
  import re
18
18
  import struct
19
- from typing import Dict, List, Optional
19
+ from typing import Dict, List, Optional, Tuple, Union
20
20
 
21
21
  from safetensors import safe_open
22
22
  from torch import Tensor
@@ -30,10 +30,14 @@ __all__ = [
30
30
  "merge_names",
31
31
  "get_weight_mappings",
32
32
  "get_nested_weight_mappings",
33
+ "get_nested_mappings_from_state_dict",
33
34
  "get_quantization_state_dict",
34
35
  "is_quantization_param",
35
36
  ]
36
37
 
38
+ WeightMappingType = Dict[str, str]
39
+ NestedWeightMappingType = Dict[str, WeightMappingType]
40
+
37
41
 
38
42
  def get_safetensors_folder(
39
43
  pretrained_model_name_or_path: str, cache_dir: Optional[str] = None
@@ -92,7 +96,7 @@ def get_safetensors_header(safetensors_path: str) -> Dict[str, str]:
92
96
  return header
93
97
 
94
98
 
95
- def match_param_name(full_name: str, param_name: str) -> str:
99
+ def match_param_name(full_name: str, param_name: str) -> Optional[str]:
96
100
  """
97
101
  Helper function extracting the uncompressed parameterized layer name from a
98
102
  compressed name. Assumes the compressed name was merged using merge_names.
@@ -176,38 +180,100 @@ def get_weight_mappings(path_to_model_or_tensors: str) -> Dict[str, str]:
176
180
 
177
181
 
178
182
def get_nested_weight_mappings(
    model_path: str, params_to_nest: List[str], return_unmatched_params: bool = False
) -> Union[NestedWeightMappingType, Tuple[NestedWeightMappingType, WeightMappingType]]:
    """
    Build a nested mapping from a safetensors state dict. The mapping goes
    from uncompressed parameterized layer names to the file locations of each
    layer's compression parameters.

    Example of the nested mapping:
    layer: {
        bitmask: file_location,
        row_offsets: file_location,
        shape: file_location,
        compressed: file_location
    }

    Parameters that match none of params_to_nest can also be returned, in a
    separate dictionary, when return_unmatched_params is True. That
    dictionary may be needed when compressors are stacked (e.g. quantization
    compression followed by sparse compression).

    Example of the unmatched params mapping:
    {
        layer.weight_scale: file_location,
        layer.input_scale: file_location
    }

    This generalizes to models split into multiple safetensors files.

    :param model_path: path to the safetensors state dict; must contain either a
        single safetensors file or multiple files with an index
    :param params_to_nest: parameter names to nest under their dense layer name
    :param return_unmatched_params: if True, also return a mapping of every
        parameter that matched none of params_to_nest to its file location
    :return: the nested mapping alone, or a (nested mapping, unmatched mapping)
        tuple when return_unmatched_params is True
    """
    weight_mappings = get_weight_mappings(model_path)
    nested = {}
    unmatched = {}

    for key, file_location in weight_mappings.items():
        matched_any = False
        for param_name in params_to_nest:
            dense_param = match_param_name(key, param_name)
            if not dense_param:
                continue
            nested.setdefault(dense_param, {})[param_name] = file_location
            matched_any = True
        if not matched_any and return_unmatched_params:
            unmatched[key] = file_location

    if not return_unmatched_params:
        return nested
    return nested, unmatched
210
247
 
248
+
249
def get_nested_mappings_from_state_dict(
    state_dict: Dict[str, Tensor], params_to_nest: List[str]
) -> Dict[str, Dict[str, Tensor]]:
    """
    Build a nested mapping from an in-memory state dict. The mapping goes
    from uncompressed parameterized layer names to the values of each layer's
    compression parameters.

    Unlike get_nested_weight_mappings, the leaves are the state dict values
    themselves, not file locations. The return type is therefore annotated
    with tensor leaves rather than NestedWeightMappingType, whose leaves are
    strings.

    Example of the nested mapping:
    layer: {
        weight_scale: ...,
        weight: ...,
        zero_point: ...,
    }

    :param state_dict: state dict of the model
    :param params_to_nest: parameter names to nest under their dense layer name
    :return: nested mapping of parameterized layer names to the value of each
        layer's compression parameters
    """
    nested_weight_mappings: Dict[str, Dict[str, Tensor]] = {}
    for key, value in state_dict.items():
        for param_name in params_to_nest:
            dense_param = match_param_name(key, param_name)
            if dense_param:
                nested_weight_mappings.setdefault(dense_param, {})[param_name] = value
    return nested_weight_mappings
212
278
 
213
279
 
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.1
2
2
  Name: compressed-tensors-nightly
3
- Version: 0.8.1.20250106
3
+ Version: 0.8.1.20250108
4
4
  Summary: Library for utilization of compressed safetensors of neural network models
5
5
  Home-page: https://github.com/neuralmagic/compressed-tensors
6
6
  Author: Neuralmagic, Inc.
@@ -5,15 +5,15 @@ compressed_tensors/compressors/__init__.py,sha256=smSygTSfcfuujRrAXDc6uZm4L_ccV1
5
5
  compressed_tensors/compressors/base.py,sha256=D9TNwQcjanDiAHODPbg8JUqc66e3j50rctY7A708NEs,6743
6
6
  compressed_tensors/compressors/helpers.py,sha256=OK6qxX9j3bHwF9JfIYSGMgBJe2PWjlTA3byXKCJaTIQ,5431
7
7
  compressed_tensors/compressors/model_compressors/__init__.py,sha256=5RGGPFu4YqEt_aOdFSQYFYFDjcZFJN0CsMqRtDZz3Js,666
8
- compressed_tensors/compressors/model_compressors/model_compressor.py,sha256=WmjhfGma7gswMHXLaRriyDNrefO5lCmi6rW35dCcLJM,15903
8
+ compressed_tensors/compressors/model_compressors/model_compressor.py,sha256=nsMKqjdzEttvkabpp_7Qt4mhWcmjwRYnwjQzeN2a2E4,18295
9
9
  compressed_tensors/compressors/quantized_compressors/__init__.py,sha256=09UJq68Pht6Bf-4iP9xYl3tetKsncNPHD8IAGbePsr4,714
10
- compressed_tensors/compressors/quantized_compressors/base.py,sha256=K1KOnS6Y8nUA1-HN7VhyfsDc01nilW0WfXMUhuD-l8w,5954
11
- compressed_tensors/compressors/quantized_compressors/naive_quantized.py,sha256=MMUya3Iwarm0BkeYXqKTUnEDPiBw98GKF09QiNST45k,4960
12
- compressed_tensors/compressors/quantized_compressors/pack_quantized.py,sha256=1CLwvBlu4AtGkuo3IisD1-rQzwLiA6hE1bCc-pF_XGo,7758
10
+ compressed_tensors/compressors/quantized_compressors/base.py,sha256=LVqSSqSjGi8LB-X13zC_0AFHc8BobGQVC0zjInDhOWE,7217
11
+ compressed_tensors/compressors/quantized_compressors/naive_quantized.py,sha256=fahmPJFz49rVS7q705uQwZ0kUtdP46GuXR7nPr6uIqI,4943
12
+ compressed_tensors/compressors/quantized_compressors/pack_quantized.py,sha256=OO5dceCfNVuY8A23kBg6z2wk-zGUVqR_MyLvObvT7pk,7741
13
13
  compressed_tensors/compressors/sparse_compressors/__init__.py,sha256=i2TESH27l7KXeOhJ6hShIoI904XX96l-cRQiMR6MAaU,704
14
- compressed_tensors/compressors/sparse_compressors/base.py,sha256=Ua4rUSGyucEs-YJI5z3oIUF-zqQLrFsQ9f-qKasEdUM,4410
14
+ compressed_tensors/compressors/sparse_compressors/base.py,sha256=9e841MQWr0j8m33ejDw_jP5_BIpQ5099x9_pvuZ-Nr0,5944
15
15
  compressed_tensors/compressors/sparse_compressors/dense.py,sha256=lSKNWRx6H7aUqaJj1j4qbXk8Gkm1UohbnvW1Rvq6Ra4,1284
16
- compressed_tensors/compressors/sparse_compressors/sparse_bitmask.py,sha256=4fKwCG7ZM8mUtSnjPvubzEHl-mTnxMzwjmcs7L43WLY,6622
16
+ compressed_tensors/compressors/sparse_compressors/sparse_bitmask.py,sha256=Z9qMJ2JyUaBNQe-CXBJLuWacnHdFArrJYZEZDmW5x8o,6889
17
17
  compressed_tensors/compressors/sparse_quantized_compressors/__init__.py,sha256=4f_cwcKXB1nVVMoiKgTFAc8jAPjPLElo-Df_EDm1_xw,675
18
18
  compressed_tensors/compressors/sparse_quantized_compressors/marlin_24.py,sha256=BMIQWTLlnUvxy14iEJegtiP75WHJeOVojey9mKOK1hE,9427
19
19
  compressed_tensors/config/__init__.py,sha256=ZBqWn3r6ku1qfmlHHYp0mQueY0i7Pwhr9rbQk9dDlMc,704
@@ -27,9 +27,9 @@ compressed_tensors/quantization/quant_args.py,sha256=jwC__lSmuiJ2qSJYYZGgWgQNbZu
27
27
  compressed_tensors/quantization/quant_config.py,sha256=vx06wBo91p4LCb3Vzd-2eCTUeIf_Sz2ZXRP263eQyjQ,10385
28
28
  compressed_tensors/quantization/quant_scheme.py,sha256=eQ0JrRZ80GX69fpwW87VzPzzhajhk4mUaJScjk82OY4,6010
29
29
  compressed_tensors/quantization/lifecycle/__init__.py,sha256=_uItzFWusyV74Zco_pHLOTdE9a83cL-R-ZdyQrBkIyw,772
30
- compressed_tensors/quantization/lifecycle/apply.py,sha256=jCUSgeOBtagE5IhgIbyYMZ4kv8Rm20VGJ4IxXZ5HAnw,15066
30
+ compressed_tensors/quantization/lifecycle/apply.py,sha256=XS4M6N1opKBybhkuQsS338QVb_CKMhUM5TUKrqoNQ0k,16517
31
31
  compressed_tensors/quantization/lifecycle/compressed.py,sha256=Fj9n66IN0EWsOAkBHg3O0GlOQpxstqjCcs0ttzMXrJ0,2296
32
- compressed_tensors/quantization/lifecycle/forward.py,sha256=QPL6-vKOFuKdKIEsVqMhsw4x552Jpm2sqO0oeChbnrM,12941
32
+ compressed_tensors/quantization/lifecycle/forward.py,sha256=DOWouUqfaLA4Qhg-ojVVBdhhSAlgZqFC26vZARxE0ko,12961
33
33
  compressed_tensors/quantization/lifecycle/helpers.py,sha256=C0mhy2vJ0fCjVeN4kFNhw8Eq1wkteBGHiZ36RVLThRY,944
34
34
  compressed_tensors/quantization/lifecycle/initialize.py,sha256=hymYtayTSumm8KCYAYPY267aWmlsJpt8oQFiRblk8qE,7452
35
35
  compressed_tensors/quantization/utils/__init__.py,sha256=VdtEmP0bvuND_IGQnyqUPc5lnFp-1_yD7StKSX4x80w,656
@@ -41,10 +41,10 @@ compressed_tensors/utils/helpers.py,sha256=XF36-SLkXnAHh0VzbvUlAdh6a88aCQvS_WeYs
41
41
  compressed_tensors/utils/offload.py,sha256=cMmzd9IdlNbs29CReHj1PPSLUM6OWaT5YumlLT5eP3w,13845
42
42
  compressed_tensors/utils/permutations_24.py,sha256=kx6fsfDHebx94zsSzhXGyCyuC9sVyah6BUUir_StT28,2530
43
43
  compressed_tensors/utils/permute.py,sha256=V6tJLKo3Syccj-viv4F7ZKZgJeCB-hl-dK8RKI_kBwI,2355
44
- compressed_tensors/utils/safetensors_load.py,sha256=m08ANVuTBxQdoa6LufDgcNJ7wCLDJolyZljB8VEybAU,8578
44
+ compressed_tensors/utils/safetensors_load.py,sha256=fBuoHVPoBt1mkvqFJ60zQIASX_4nhl0-6QfFS27NY8I,11430
45
45
  compressed_tensors/utils/semi_structured_conversions.py,sha256=XKNffPum54kPASgqKzgKvyeqWPAkair2XEQXjkp7ho8,13489
46
- compressed_tensors_nightly-0.8.1.20250106.dist-info/LICENSE,sha256=xx0jnfkXJvxRnG63LTGOxlggYnIysveWIZ6H3PNdCrQ,11357
47
- compressed_tensors_nightly-0.8.1.20250106.dist-info/METADATA,sha256=6P9p-XL3qm0C8UbupyfkuD6H0uzZ52V0EZyCxNhlSiM,6799
48
- compressed_tensors_nightly-0.8.1.20250106.dist-info/WHEEL,sha256=tZoeGjtWxWRfdplE7E3d45VPlLNQnvbKiYnx7gwAy8A,92
49
- compressed_tensors_nightly-0.8.1.20250106.dist-info/top_level.txt,sha256=w2i-GyPs2s1UwVxvutSvN_lM22SXC2hQFBmoMcPnV7Y,19
50
- compressed_tensors_nightly-0.8.1.20250106.dist-info/RECORD,,
46
+ compressed_tensors_nightly-0.8.1.20250108.dist-info/LICENSE,sha256=xx0jnfkXJvxRnG63LTGOxlggYnIysveWIZ6H3PNdCrQ,11357
47
+ compressed_tensors_nightly-0.8.1.20250108.dist-info/METADATA,sha256=BOjjhhge9r8x_Xos4fkWnGccdA9uQOeRjQ296Kcw2ZU,6799
48
+ compressed_tensors_nightly-0.8.1.20250108.dist-info/WHEEL,sha256=tZoeGjtWxWRfdplE7E3d45VPlLNQnvbKiYnx7gwAy8A,92
49
+ compressed_tensors_nightly-0.8.1.20250108.dist-info/top_level.txt,sha256=w2i-GyPs2s1UwVxvutSvN_lM22SXC2hQFBmoMcPnV7Y,19
50
+ compressed_tensors_nightly-0.8.1.20250108.dist-info/RECORD,,