compressed-tensors 0.8.1__py3-none-any.whl → 0.9.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (24)
  1. compressed_tensors/compressors/model_compressors/model_compressor.py +76 -14
  2. compressed_tensors/compressors/quantized_compressors/base.py +35 -5
  3. compressed_tensors/compressors/quantized_compressors/naive_quantized.py +2 -2
  4. compressed_tensors/compressors/quantized_compressors/pack_quantized.py +2 -2
  5. compressed_tensors/compressors/sparse_compressors/__init__.py +1 -0
  6. compressed_tensors/compressors/sparse_compressors/base.py +45 -7
  7. compressed_tensors/compressors/sparse_compressors/sparse_24_bitmask.py +240 -0
  8. compressed_tensors/compressors/sparse_compressors/sparse_bitmask.py +9 -40
  9. compressed_tensors/config/__init__.py +1 -0
  10. compressed_tensors/config/base.py +1 -0
  11. compressed_tensors/config/sparse_24_bitmask.py +40 -0
  12. compressed_tensors/quantization/lifecycle/apply.py +46 -1
  13. compressed_tensors/quantization/lifecycle/forward.py +2 -2
  14. compressed_tensors/quantization/lifecycle/initialize.py +21 -45
  15. compressed_tensors/quantization/quant_config.py +1 -1
  16. compressed_tensors/utils/helpers.py +174 -1
  17. compressed_tensors/utils/offload.py +332 -44
  18. compressed_tensors/utils/safetensors_load.py +83 -17
  19. compressed_tensors/version.py +1 -1
  20. {compressed_tensors-0.8.1.dist-info → compressed_tensors-0.9.1.dist-info}/METADATA +1 -1
  21. {compressed_tensors-0.8.1.dist-info → compressed_tensors-0.9.1.dist-info}/RECORD +24 -22
  22. {compressed_tensors-0.8.1.dist-info → compressed_tensors-0.9.1.dist-info}/LICENSE +0 -0
  23. {compressed_tensors-0.8.1.dist-info → compressed_tensors-0.9.1.dist-info}/WHEEL +0 -0
  24. {compressed_tensors-0.8.1.dist-info → compressed_tensors-0.9.1.dist-info}/top_level.txt +0 -0
compressed_tensors/utils/offload.py

@@ -11,9 +11,48 @@
  # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  # See the License for the specific language governing permissions and
  # limitations under the License.
+ """
+ Utilities associated with offloading functionality provided by `accelerate`.
+
+ | ------------------------------------------------------------------------------------------------------ | # noqa: E501
+ | Operation | Without offloading support              | With offloading support                          | # noqa: E501
+ | --------- | --------------------------------------- | ------------------------------------------------ | # noqa: E501
+ | Add       | module.register_parameter(name, param)  | register_offload_parameter(module, name, param)  | # noqa: E501
+ | Check     | N/A                                     | has_offloaded_params(module)                     | # noqa: E501
+ | Onload    | N/A                                     | with align_module_device(module)                 | # noqa: E501
+ | Update    | module.name.data.copy_(new_data)        | update_offload_parameter(module, name, new_data) | # noqa: E501
+ | Delete    | del module.name                         | delete_offload_parameter(module, name)           | # noqa: E501
+ | ------------------------------------------------------------------------------------------------------ | # noqa: E501
+ """
+
+ import contextlib
+ from functools import wraps
+ from typing import Any, Callable, Dict, Literal, Optional, Union
 
  import torch
- from torch.nn import Module
+
+
+ try:
+     from accelerate.hooks import (
+         AlignDevicesHook,
+         add_hook_to_module,
+         remove_hook_from_module,
+     )
+     from accelerate.utils import (
+         OffloadedWeightsLoader,
+         PrefixedDataset,
+         set_module_tensor_to_device,
+     )
+
+     _has_accelerate = True
+ except ImportError:
+     _has_accelerate = False
+     AlignDevicesHook = None
+     add_hook_to_module = None
+     remove_hook_from_module = None
+     OffloadedWeightsLoader = None
+     PrefixedDataset = None
+     set_module_tensor_to_device = None
 
 
  __all__ = [
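The table in the new module docstring pairs each plain torch idiom with its offloading-aware replacement. A hedged sketch of the intended call pattern, assuming the compressed_tensors.utils.offload import path and a toy module (nothing here is taken verbatim from library tests):

import torch
from compressed_tensors.utils.offload import (
    delete_offload_parameter,
    has_offloaded_params,
    register_offload_parameter,
    update_offload_parameter,
)

linear = torch.nn.Linear(16, 16)  # toy module; works with or without accelerate hooks

# Add: registers on the module and, if offloaded, also writes the weights map
register_offload_parameter(linear, "weight_scale", torch.nn.Parameter(torch.ones(16)))

# Check: True only when an AlignDevicesHook with offloading enabled is attached
print(has_offloaded_params(linear))  # False for this plain module

# Update: copies into the onloaded tensor and refreshes any offloaded copy
update_offload_parameter(linear, "weight_scale", torch.full((16,), 2.0))

# Delete: removes the parameter and its offloaded entry together
delete_offload_parameter(linear, "weight_scale")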
@@ -22,23 +61,44 @@ __all__ = [
      "get_offloaded_device",
      "update_prefix_dict",
      "update_parameter_data",
+     "register_offload_parameter",
+     "update_offload_parameter",
+     "delete_offload_parameter",
+     "has_offloaded_params",
+     "disable_hf_hook",
+     "align_module_device",
  ]
 
 
- def is_module_offloaded(module: Module) -> bool:
-     """
-     :param module: layer to check
-     :return: True if layer is offloaded from GPU, False otherwise
-     """
-     return hasattr(module, "_hf_hook") and module._hf_hook.offload
+ def check_accelerate(fallback: Any):
+     def decorator(func: Callable[[Any], Any]):
+         if not _has_accelerate:
+
+             @wraps(func)
+             def fallback_fn(*args, **kwargs):
+                 return fallback
+
+             return fallback_fn
+
+         return func
 
+     return decorator
 
- def get_execution_device(module: Module) -> torch.device:
+
+ """ Candidates for Deprecation """
+
+
+ @check_accelerate(fallback=False)
+ def is_module_offloaded(module: torch.nn.Module) -> bool:
+     return has_offloaded_params(module)
+
+
+ def get_execution_device(module: torch.nn.Module) -> torch.device:
      """
-     :param module: layer to check
-     :return: device layer is loaded onto during forward pass
+     :param module: module to check
+     :return: device module is loaded onto during forward pass
      """
-     if is_module_offloaded(module):
+     if has_offloaded_params(module):
          return module._hf_hook.execution_device
      device = next(module.parameters()).device
 
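check_accelerate gives every helper in this module a graceful degradation path: when accelerate is missing, the decorated function is replaced at import time by a stub that returns a fixed fallback value instead of raising. A standalone sketch of the pattern, mirroring the decorator added above:

from functools import wraps
from typing import Any, Callable

_has_accelerate = False  # simulate an environment without accelerate


def check_accelerate(fallback: Any):
    def decorator(func: Callable[[Any], Any]):
        if not _has_accelerate:

            @wraps(func)
            def fallback_fn(*args, **kwargs):
                return fallback

            return fallback_fn
        return func

    return decorator


@check_accelerate(fallback=False)
def is_module_offloaded(module) -> bool:
    return module._hf_hook.offload  # never reached when accelerate is absent


print(is_module_offloaded(object()))  # prints False: the fallback is returned

Because the substitution happens when the decorator runs, the fallback adds no per-call overhead.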
@@ -49,68 +109,296 @@ def get_execution_device(module: Module) -> torch.device:
      return device
 
 
- def get_offloaded_device(module: Module) -> torch.device:
+ def get_offloaded_device(module: torch.nn.Module) -> torch.device:
      """
-     :param module: layer to check
-     :return: device layer is offloaded to onto after forward pass
+     :param module: module to check
+     :return: device module is offloaded onto after forward pass
      """
-     if is_module_offloaded(module):
+     if has_offloaded_params(module):
          first_key = list(module._hf_hook.weights_map.keys())[0]
          prefix_dataset = module._hf_hook.weights_map.dataset
          return prefix_dataset[first_key].device
      return next(module.parameters()).device
 
 
- def update_prefix_dict(module: Module, key: str, data: torch.Tensor):
+ @check_accelerate(fallback=None)
+ def update_prefix_dict(module: torch.nn.Module, key: str, data: torch.Tensor):
      """
      Updates the offloaded state dict for a given module. Parameter named key is replaced
      by data. This is necessary because parameter updates for offloaded modules do not
      persist automatically between loads. This function only affects the offloaded
      state dict and not the current state of the loaded module.
 
-     :param module: layer containing the parameter to update
+     :param module: module containing the parameter to update
      :param key: name of parameter to update
      :param data: tensor to update parameter with in the offloaded state dict
      """
-     if not is_module_offloaded(module):
+     if not has_offloaded_params(module):
          raise ValueError("Prefix dict is only applicable to offloaded modules")
-     prefix_dict = module._hf_hook.weights_map
-     prefix_dict.dataset[f"{prefix_dict.prefix}{key}"] = data
+
+     weights_map = module._hf_hook.weights_map
+     offload_to_weights_map(weights_map, key, data)
 
 
  def update_parameter_data(
-     module: Module, new_param_data: torch.Tensor, param_name: str
+     module: torch.nn.Module, new_param_data: torch.Tensor, param_name: str
  ):
      """
-     Updates the paramter value named param_name for a given module. This function
-     updates both the current loaded module state and the offloaded state dict if
-     the module is offloaded. This is neccesary because parameter updates for offloaded
-     modules do not persist automatically between loads.
+     Update the data of an existing parameter and its offload dict. Supports both
+     parameters of offloaded modules and non-offloaded modules
 
-     :param module: layer containing the parameter to update
+     :param module: module containing the parameter to update
      :param new_param_data: tensor to update parameter with
-     :param param_name: name of layer parameter to update
+     :param param_name: name of module parameter to update
      """
-     if not hasattr(module, param_name):
-         return
+     update_offload_parameter(module, param_name, new_param_data)
+
+
+ """ Candidates for Upstreaming """
+
+
+ def register_offload_parameter(
+     module: torch.nn.Module,
+     name: str,
+     parameter: torch.nn.Parameter,
+     offload_device: Optional[Union[torch.device, Literal["disk"]]] = None,
+ ):
+     """
+     Register a parameter to the given module which may be offloaded
+
+     :param module: maybe offloaded module
+     :param name: name of newly registered parameter
+     :param parameter: parameter being registered
+     :param offload_device: device on which weight will be offloaded to. If None is
+         provided, then infer device from parameters on module
+     """
+     has_onload = any(p.device != torch.device("meta") for p in module.parameters())
+     module.register_parameter(name, parameter)
+
+     if has_offloaded_params(module):
+         weights_map = module._hf_hook.weights_map
+         offload_to_weights_map(weights_map, name, parameter.data, offload_device)
+         if not has_onload:
+             set_module_tensor_to_device(module, name, "meta")
+
+
+ def update_offload_parameter(
+     module: torch.nn.Module,
+     name: str,
+     data: Optional[torch.Tensor],
+     offload_device: Optional[Union[torch.device, Literal["disk"]]] = None,
+ ):
+     """
+     Update the data of an existing parameter and its offload dict. Supports both
+     parameters of offloaded modules and non-offloaded modules
+
+     :param module: module containing the parameter to update
+     :param name: name of module parameter to update
+     :param data: tensor to update parameter with
+     :param offload_device: device on which weight will be offloaded to. If None is
+         provided, then infer device from parameters on module
+     """
+     param = getattr(module, name)
+     data = data.to(param.dtype)
+
+     # copy data into onloaded parameter if applicable
+     if param.device != "meta":
+         param.data.copy_(data)
+
+     # update offload dict
+     if has_offloaded_params(module):
+         weights_map = module._hf_hook.weights_map
+         offload_to_weights_map(weights_map, name, data, offload_device)
+
+
+ def delete_offload_parameter(module: torch.nn.Module, name: str):
+     """
+     Delete a parameter from a module which may be offloaded
+
+     :param module: maybe offloaded module
+     :param name: name of parameter being deleted
+     """
+     delattr(module, name)
+
+     if has_offloaded_params(module):
+         weights_map = module._hf_hook.weights_map
+         delete_from_weights_map(weights_map, name)
 
-     device = next(module.parameters()).device
 
-     offloaded = False
-     if is_module_offloaded(module):
-         offload_device = get_offloaded_device(module)
-         offloaded = True
+ @check_accelerate(fallback=contextlib.nullcontext())
+ @contextlib.contextmanager
+ def disable_hf_hook(module: torch.nn.Module):
+     hooks = {}
 
-     parameter = getattr(module, param_name, None)
-     if parameter is None:
-         raise ValueError("Attempted to update uninitialized parameter")
+     def collect_hooks(module):
+         nonlocal hooks
+         if hasattr(module, "_hf_hook"):
+             hooks[module] = module._hf_hook
+             remove_hook_from_module(module)
 
-     dtype = parameter.dtype
-     parameter.data = new_param_data.to(device).to(dtype)
+     module.apply(collect_hooks)
 
-     if offloaded:
-         prefix_dict = module._hf_hook.weights_map.dataset
-         prefix = module._hf_hook.weights_map.prefix
-         prefix_dict[f"{prefix}{param_name}"] = new_param_data.to(offload_device).to(
-             dtype
+     yield
+
+     for submodule, hook in hooks.items():
+         add_hook_to_module(submodule, hook)
+
+
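disable_hf_hook temporarily detaches accelerate hooks from a module and its submodules, then restores them when the context exits; this is useful when module surgery must run without device-alignment side effects. A hedged usage sketch (the module here is a plain Linear; hooks are only present if accelerate attached them):

import torch
from compressed_tensors.utils.offload import disable_hf_hook

module = torch.nn.Linear(8, 8)

with disable_hf_hook(module):
    # any _hf_hook on module or its children has been removed here,
    # so the forward pass runs without onload/offload bookkeeping
    _ = module(torch.randn(1, 8))
# collected hooks are re-attached on exit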
+ @check_accelerate(fallback=None)
+ def offload_to_weights_map(
+     weights_map: Union[PrefixedDataset, Dict, OffloadedWeightsLoader],
+     key: str,
+     value: torch.Tensor,
+     offload_device: Optional[Union[torch.device, Literal["disk"]]] = None,
+ ):
+     """
+     Helper function which implements offloaded item assignment for PrefixedDataset,
+     OffloadedWeightsLoader, and Dict types.
+
+     :param weights_map: weight map to be updated with offload information
+     :param key: key used to identify weight location
+     :param value: weight being offloaded
+     :param offload_device: device on which weight will be offloaded to. If None is
+         provided, then infer device from parameters in weights_map
+     """
+     if isinstance(weights_map, PrefixedDataset):
+         if offload_device == "disk":
+             raise ValueError(f"Cannot offload to disk with type {type(weights_map)}")
+
+         dataset = weights_map.dataset
+         key = f"{weights_map.prefix}{key}"
+         offload_to_weights_map(dataset, key, value, offload_device)
+
+     elif isinstance(weights_map, OffloadedWeightsLoader):
+         if key not in weights_map.all_keys:
+             weights_map.all_keys.append(key)
+
+         if len(weights_map.index) <= 0 and offload_device != "disk":
+             offload_to_weights_map(weights_map.state_dict, key, value, offload_device)
+
+         else:
+             raise NotImplementedError(
+                 "Updating weights_map with disk offloading is not implemented yet"
+             )
+
+     elif isinstance(weights_map, dict):
+         if offload_device == "disk":
+             raise ValueError(f"Cannot offload to disk with type {type(weights_map)}")
+
+         # infer offload device
+         if offload_device is None:
+             if key in weights_map:
+                 offload_device = weights_map[key].device
+             else:
+                 tens = next(iter(weights_map.values()), None)
+                 if tens is None:
+                     raise ValueError(
+                         "Cannot infer offload device from empty weights_map"
+                     )
+                 offload_device = tens.device
+
+         weights_map[key] = value.to(device=offload_device)
+
+
+     else:
+         raise NotImplementedError(
+             "Updating offload data not implemented for weights_map of type "
+             f"{type(weights_map)}"
+         )
+
+
+ @check_accelerate(fallback=None)
+ def delete_from_weights_map(
+     weights_map: Union[PrefixedDataset, Dict, OffloadedWeightsLoader],
+     key: str,
+ ):
+     if isinstance(weights_map, PrefixedDataset):
+         dataset = weights_map.dataset
+         key = f"{weights_map.prefix}{key}"
+         delete_from_weights_map(dataset, key)
+
+     elif isinstance(weights_map, OffloadedWeightsLoader):
+         if len(weights_map.index) <= 0:
+             delete_from_weights_map(weights_map.state_dict, key)
+
+         else:
+             raise NotImplementedError(
+                 "Delete from weights_map with disk offloading is not implemented yet"
+             )
+
+     elif isinstance(weights_map, dict):
+         del weights_map[key]
+
+     else:
+         raise NotImplementedError(
+             "Updating offload data not implemented for weights_map of type "
+             f"{type(weights_map)}"
          )
+
+
+ """ Upstreamed Functions """
+
+
+ # introduced in accelerate v1.1.0
+ @check_accelerate(fallback=False)
+ def has_offloaded_params(module: torch.nn.Module) -> bool:
+     """
+     Checks if a module has offloaded parameters by checking if the given module has
+     an AlignDevicesHook attached with offloading enabled
+
+     Args:
+         module (`torch.nn.Module`): The module to check for an offload hook.
+
+     Returns:
+         bool: `True` if the module has an offload hook and offloading is enabled,
+         `False` otherwise.
+     """
+     return (
+         hasattr(module, "_hf_hook")
+         and isinstance(module._hf_hook, AlignDevicesHook)
+         and module._hf_hook.offload
+     )
+
+
+ # introduced in accelerate v1.1.0
+ @check_accelerate(fallback=contextlib.nullcontext())
+ @contextlib.contextmanager
+ def align_module_device(
+     module: torch.nn.Module, execution_device: Optional[torch.device] = None
+ ):
+     """
+     Context manager that moves a module's parameters to the specified execution device.
+
+     Args:
+         module (`torch.nn.Module`):
+             Module with parameters to align.
+         execution_device (`torch.device`, *optional*):
+             If provided, overrides the module's execution device within the context.
+             Otherwise, uses the hook execution device or passes through.
+     """
+     if has_offloaded_params(module):
+         if execution_device is not None:
+             original_device = module._hf_hook.execution_device
+             module._hf_hook.execution_device = execution_device
+
+         try:
+             module._hf_hook.pre_forward(module)
+             yield
+         finally:
+             module._hf_hook.post_forward(module, None)
+             if execution_device is not None:
+                 module._hf_hook.execution_device = original_device
+
+     elif execution_device is not None:
+         devices = {
+             name: param.device for name, param in module.named_parameters(recurse=False)
+         }
+         try:
+             for name in devices:
+                 set_module_tensor_to_device(module, name, execution_device)
+             yield
+         finally:
+             for name, device in devices.items():
+                 set_module_tensor_to_device(module, name, device)
+
+     else:
+         yield
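align_module_device mirrors the context manager upstreamed to accelerate v1.1.0: for offloaded modules it onloads parameters via the hook's pre_forward and offloads them again on exit; for regular modules it is a no-op unless an execution_device override is given. A hedged sketch (the module and device here are placeholders):

import torch
from compressed_tensors.utils.offload import align_module_device

module = torch.nn.Linear(32, 32)  # placeholder; may have been dispatched by accelerate

with align_module_device(module):
    # inside the block, parameters are materialized on the execution device,
    # so they can be read or cloned safely
    weight_copy = module.weight.clone()
# on exit, offloaded parameters (if any) return to their offload device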
compressed_tensors/utils/safetensors_load.py

@@ -16,7 +16,7 @@ import json
  import os
  import re
  import struct
- from typing import Dict, List, Optional
+ from typing import Dict, List, Optional, Tuple, Union
 
  from safetensors import safe_open
  from torch import Tensor
@@ -30,10 +30,14 @@ __all__ = [
      "merge_names",
      "get_weight_mappings",
      "get_nested_weight_mappings",
+     "get_nested_mappings_from_state_dict",
      "get_quantization_state_dict",
      "is_quantization_param",
  ]
 
+ WeightMappingType = Dict[str, str]
+ NestedWeightMappingType = Dict[str, WeightMappingType]
+
 
  def get_safetensors_folder(
      pretrained_model_name_or_path: str, cache_dir: Optional[str] = None
@@ -92,7 +96,7 @@ def get_safetensors_header(safetensors_path: str) -> Dict[str, str]:
      return header
 
 
- def match_param_name(full_name: str, param_name: str) -> str:
+ def match_param_name(full_name: str, param_name: str) -> Optional[str]:
      """
      Helper function extracting the uncompressed parameterized layer name from a
      compressed name. Assumes the compressed name was merged using merge_names.
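The tightened return annotation makes the miss case explicit: the helper yields the dense layer name on a match and None otherwise. A hedged illustration of the naming convention it decodes (key names invented; behavior inferred from the docstring and its callers below):

from compressed_tensors.utils.safetensors_load import match_param_name

# a key merged from ("model.layer1", "bitmask") should decode back to the layer name
assert match_param_name("model.layer1.bitmask", "bitmask") == "model.layer1"

# keys that do not end in the requested parameter name produce no match
assert match_param_name("model.layer1.weight", "bitmask") is None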
@@ -176,38 +180,100 @@ def get_weight_mappings(path_to_model_or_tensors: str) -> Dict[str, str]:
 
 
  def get_nested_weight_mappings(
-     model_path: str, params_to_nest: List[str]
- ) -> Dict[str, Dict[str, str]]:
+     model_path: str, params_to_nest: List[str], return_unmatched_params: bool = False
+ ) -> Union[NestedWeightMappingType, Tuple[NestedWeightMappingType, WeightMappingType]]:
      """
      Takes a path to a state dict saved in safetensors format and returns a nested
-     mapping from uncompressed parameterized layer names to the file locations of each
-     of the layers compression parameters.
+     mapping from uncompressed parameterized layer names to the file locations of
+     each layer's compression parameters.
 
-     layer.weight: {
+     Example of the nested mapping:
+     layer: {
          bitmask: file_location,
          row_offsets: file_location,
          shape: file_location,
          compressed: file_location
      }
 
-     This generalizes to cases where the model is split into multiple safetensors files
+     If other parameters are found that do not match the nested parameters, they will
+     be returned in a separate dictionary only if return_unmatched_params is True.
+     This dictionary may be needed for cases where compressors are stacked (e.g.,
+     quantization compression followed by sparse compression).
+
+     Example of the unmatched params mapping:
+     {
+         layer.weight_scale: file_location,
+         layer.input_scale: file_location
+     }
 
-     :param model_path: path to safetensors state dict, must contain either a single
-         safetensors file or multiple files with an index
-     :return: nested mapping of parameterized layer name to file location
+     This generalizes to cases where the model is split into multiple safetensors
+     files.
+
+     :param model_path: Path to the safetensors state dict, must contain either a
+         single safetensors file or multiple files with an index.
+     :param params_to_nest: List of parameter names to nest.
+     :param return_unmatched_params: If True, return a second dictionary containing
+         the remaining parameters that were not matched to the params_to_nest.
+     :return:
+         - If return_unmatched_params is False:
+             NestedWeightMappingType: A nested mapping of parameterized layer names to
+             file locations of each layer's compression parameters.
+         - If return_unmatched_params is True:
+             Tuple[NestedWeightMappingType, WeightMappingType]: A tuple containing:
+                 - NestedWeightMappingType: A nested mapping of parameterized layer
+                     names to file locations of each layer's compression parameters.
+                 - WeightMappingType: A mapping of the remaining parameter names to
+                     their file locations that were not matched to the params_to_nest.
      """
      weight_mappings = get_weight_mappings(model_path)
-
      nested_weight_mappings = {}
-     for key in weight_mappings.keys():
+     unmatched_params = {}
+
+     for key, file_location in weight_mappings.items():
+         matched = False
          for param_name in params_to_nest:
-             maybe_match = match_param_name(key, param_name)
-             if maybe_match is not None:
-                 dense_param = maybe_match
+             dense_param = match_param_name(key, param_name)
+             if dense_param:
                  if dense_param not in nested_weight_mappings:
                      nested_weight_mappings[dense_param] = {}
-                 nested_weight_mappings[dense_param][param_name] = weight_mappings[key]
+                 nested_weight_mappings[dense_param][param_name] = file_location
+                 matched = True
+         if return_unmatched_params and not matched:
+             unmatched_params[key] = file_location
+
+     if return_unmatched_params:
+         return nested_weight_mappings, unmatched_params
+     return nested_weight_mappings
 
+
+ def get_nested_mappings_from_state_dict(
+     state_dict, params_to_nest
+ ) -> NestedWeightMappingType:
+     """
+     Takes a state dict and returns a nested mapping from uncompressed
+     parameterized layer names to the value of each layer's compression parameters.
+
+     Example of the nested mapping:
+     layer: {
+         weight_scale: ...,
+         weight: ...,
+         zero_point: ...,
+     }
+
+     :param state_dict: state dict of the model
+     :param params_to_nest: List of parameter names to nest.
+     :return: Nested mapping of parameterized layer names to the value of
+         each layer's compression parameters.
+     """
+     nested_weight_mappings = {}
+     for key in state_dict.keys():
+         for param_name in params_to_nest:
+             dense_param = match_param_name(key, param_name)
+             if dense_param:
+                 if dense_param not in nested_weight_mappings:
+                     nested_weight_mappings[dense_param] = {}
+                 nested_weight_mappings[dense_param][param_name] = state_dict[key]
      return nested_weight_mappings
 
 
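The two helpers expose the same nesting for the on-disk and in-memory cases. A hedged sketch using the in-memory variant on a toy state dict (tensor names are illustrative):

import torch
from compressed_tensors.utils.safetensors_load import get_nested_mappings_from_state_dict

state_dict = {
    "layer.weight": torch.zeros(4, 4),
    "layer.weight_scale": torch.ones(1),
    "layer.zero_point": torch.zeros(1),
}

nested = get_nested_mappings_from_state_dict(
    state_dict, params_to_nest=["weight", "weight_scale", "zero_point"]
)
print(sorted(nested["layer"]))  # ['weight', 'weight_scale', 'zero_point']

# the on-disk variant returns file locations instead of tensors, and can also
# report parameters that matched nothing (useful when compressors are stacked):
# nested, unmatched = get_nested_weight_mappings(
#     "/path/to/model", ["bitmask", "row_offsets"], return_unmatched_params=True
# )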
compressed_tensors/version.py

@@ -17,7 +17,7 @@ Functionality for storing and setting the version info for SparseML
  """
 
 
- version_base = "0.8.1"
+ version_base = "0.9.1"
  is_release = True  # change to True to set the generated version as a release version
 
{compressed_tensors-0.8.1.dist-info → compressed_tensors-0.9.1.dist-info}/METADATA

@@ -1,6 +1,6 @@
  Metadata-Version: 2.1
  Name: compressed-tensors
- Version: 0.8.1
+ Version: 0.9.1
  Summary: Library for utilization of compressed safetensors of neural network models
  Home-page: https://github.com/neuralmagic/compressed-tensors
  Author: Neuralmagic, Inc.