compressed-tensors 0.8.0__py3-none-any.whl → 0.9.0__py3-none-any.whl
This diff compares the contents of two publicly released versions of the package, as published to a supported registry. It is provided for informational purposes only and reflects the packages exactly as they appear in the public registry.
- compressed_tensors/compressors/model_compressors/model_compressor.py +92 -18
- compressed_tensors/compressors/quantized_compressors/base.py +35 -5
- compressed_tensors/compressors/quantized_compressors/naive_quantized.py +6 -4
- compressed_tensors/compressors/quantized_compressors/pack_quantized.py +4 -2
- compressed_tensors/compressors/sparse_compressors/__init__.py +1 -0
- compressed_tensors/compressors/sparse_compressors/base.py +45 -7
- compressed_tensors/compressors/sparse_compressors/sparse_24_bitmask.py +238 -0
- compressed_tensors/compressors/sparse_compressors/sparse_bitmask.py +9 -40
- compressed_tensors/config/__init__.py +1 -0
- compressed_tensors/config/base.py +1 -0
- compressed_tensors/config/sparse_24_bitmask.py +40 -0
- compressed_tensors/linear/compressed_linear.py +3 -1
- compressed_tensors/quantization/lifecycle/apply.py +48 -2
- compressed_tensors/quantization/lifecycle/forward.py +2 -2
- compressed_tensors/quantization/lifecycle/initialize.py +21 -45
- compressed_tensors/quantization/quant_args.py +16 -3
- compressed_tensors/quantization/quant_config.py +3 -3
- compressed_tensors/quantization/quant_scheme.py +17 -24
- compressed_tensors/utils/helpers.py +206 -1
- compressed_tensors/utils/offload.py +332 -44
- compressed_tensors/utils/safetensors_load.py +83 -17
- compressed_tensors/version.py +1 -1
- {compressed_tensors-0.8.0.dist-info → compressed_tensors-0.9.0.dist-info}/METADATA +1 -1
- {compressed_tensors-0.8.0.dist-info → compressed_tensors-0.9.0.dist-info}/RECORD +27 -25
- {compressed_tensors-0.8.0.dist-info → compressed_tensors-0.9.0.dist-info}/WHEEL +1 -1
- {compressed_tensors-0.8.0.dist-info → compressed_tensors-0.9.0.dist-info}/LICENSE +0 -0
- {compressed_tensors-0.8.0.dist-info → compressed_tensors-0.9.0.dist-info}/top_level.txt +0 -0
compressed_tensors/utils/offload.py
CHANGED

@@ -11,9 +11,48 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
+"""
+Utilities associated with offloading functionality provided by `accelerate`.
+
+| ----------------------------------------------------------------------------------------------------- | # noqa: E501
+| Operation | Without offloading support             | With offloading support                          | # noqa: E501
+| --------- | -------------------------------------- | ------------------------------------------------ | # noqa: E501
+| Add       | module.register_parameter(name, param) | register_offload_parameter(module, name, param)  | # noqa: E501
+| Check     | N/A                                    | has_offloaded_params(module)                     | # noqa: E501
+| Onload    | N/A                                    | with align_module_device(module)                 | # noqa: E501
+| Update    | module.name.data.copy_(new_data)       | update_offload_parameter(module, name, new_data) | # noqa: E501
+| Delete    | del module.name                        | delete_offload_parameter(module, name)           | # noqa: E501
+| ----------------------------------------------------------------------------------------------------- | # noqa: E501
+"""
+
+import contextlib
+from functools import wraps
+from typing import Any, Callable, Dict, Literal, Optional, Union
 
 import torch
-from torch.nn import Module
+
+
+try:
+    from accelerate.hooks import (
+        AlignDevicesHook,
+        add_hook_to_module,
+        remove_hook_from_module,
+    )
+    from accelerate.utils import (
+        OffloadedWeightsLoader,
+        PrefixedDataset,
+        set_module_tensor_to_device,
+    )
+
+    _has_accelerate = True
+except ImportError:
+    _has_accelerate = False
+    AlignDevicesHook = None
+    add_hook_to_module = None
+    remove_hook_from_module = None
+    OffloadedWeightsLoader = None
+    PrefixedDataset = None
+    set_module_tensor_to_device = None
 
 
 __all__ = [
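The docstring table added above is the quickest map of the new API: each plain `torch.nn.Module` operation in the middle column gains an offloading-aware counterpart in the right column. For a module with no accelerate hook attached, the counterparts degrade gracefully to the plain operations, so callers can use them unconditionally. A minimal sketch, assuming the helpers are re-exported from `compressed_tensors.utils` as the package layout suggests:

```python
import torch

from compressed_tensors.utils import (
    delete_offload_parameter,
    has_offloaded_params,
    register_offload_parameter,
    update_offload_parameter,
)

module = torch.nn.Linear(4, 4)
print(has_offloaded_params(module))  # False: no AlignDevicesHook attached

# Add: behaves like module.register_parameter(...) for a non-offloaded module
register_offload_parameter(module, "weight_scale", torch.nn.Parameter(torch.ones(1)))

# Update: behaves like module.weight_scale.data.copy_(...)
update_offload_parameter(module, "weight_scale", torch.full((1,), 0.5))
assert module.weight_scale.item() == 0.5

# Delete: behaves like del module.weight_scale
delete_offload_parameter(module, "weight_scale")
```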
@@ -22,23 +61,44 @@ __all__ = [
     "get_offloaded_device",
     "update_prefix_dict",
     "update_parameter_data",
+    "register_offload_parameter",
+    "update_offload_parameter",
+    "delete_offload_parameter",
+    "has_offloaded_params",
+    "disable_hf_hook",
+    "align_module_device",
 ]
 
 
-def is_module_offloaded(module: Module) -> bool:
-    """
-    :param module: layer to check
-    :return: True if the module is offloaded, False otherwise
-    """
-    return hasattr(module, "_hf_hook") and module._hf_hook.offload
+def check_accelerate(fallback: Any):
+    def decorator(func: Callable[[Any], Any]):
+        if not _has_accelerate:
+
+            @wraps(func)
+            def fallback_fn(*args, **kwargs):
+                return fallback
+
+            return fallback_fn
+
+        return func
 
+    return decorator
 
-def get_execution_device(module: Module) -> torch.device:
+
+""" Candidates for Depreciation """
+
+
+@check_accelerate(fallback=False)
+def is_module_offloaded(module: torch.nn.Module) -> bool:
+    return has_offloaded_params(module)
+
+
+def get_execution_device(module: torch.nn.Module) -> torch.device:
     """
-    :param module: layer to check
-    :return: device layer is loaded onto during forward pass
+    :param module: module to check
+    :return: device module is loaded onto during forward pass
     """
-    if is_module_offloaded(module):
+    if has_offloaded_params(module):
         return module._hf_hook.execution_device
     device = next(module.parameters()).device
 
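`check_accelerate` is the compatibility shim that makes the new try/except import block workable: when `accelerate` is missing, the decorated function is replaced by a stub that returns a fixed fallback value instead of raising. A standalone sketch of the same pattern, with `_has_accelerate` hard-coded to simulate the missing dependency:

```python
from functools import wraps
from typing import Any, Callable

_has_accelerate = False  # simulate an environment without accelerate


def check_accelerate(fallback: Any):
    def decorator(func: Callable[[Any], Any]):
        if not _has_accelerate:

            @wraps(func)
            def fallback_fn(*args, **kwargs):
                return fallback  # the real body never runs

            return fallback_fn
        return func

    return decorator


@check_accelerate(fallback=False)
def is_module_offloaded(module) -> bool:
    raise RuntimeError("unreachable without accelerate")


print(is_module_offloaded(object()))  # prints False, the fallback value
```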
@@ -49,68 +109,296 @@ def get_execution_device(module: Module) -> torch.device:
     return device
 
 
-def get_offloaded_device(module: Module) -> torch.device:
+def get_offloaded_device(module: torch.nn.Module) -> torch.device:
     """
-    :param module: layer to check
-    :return: device layer is offloaded to after forward pass
+    :param module: module to check
+    :return: device module is offloaded to onto after forward pass
     """
-    if is_module_offloaded(module):
+    if has_offloaded_params(module):
         first_key = list(module._hf_hook.weights_map.keys())[0]
         prefix_dataset = module._hf_hook.weights_map.dataset
         return prefix_dataset[first_key].device
     return next(module.parameters()).device
 
 
-def update_prefix_dict(module: Module, key: str, data: torch.Tensor):
+@check_accelerate(fallback=None)
+def update_prefix_dict(module: torch.nn.Module, key: str, data: torch.Tensor):
     """
     Updates the offloaded state dict for a given module. Parameter named key is replaced
     by data. This is neccesary because parameter updates for offloaded modules do not
     persist automatically between loads. This function only affects the offloaded
     state dict and not the current state of the loaded module.
 
-    :param module: layer containing the parameter to update
+    :param module: module containing the parameter to update
     :param key: name of parameter to update
     :param data: tensor to update parameter with in the offloaded state dict
     """
-    if not is_module_offloaded(module):
+    if not has_offloaded_params(module):
         raise ValueError("Prefix dict is only applicable to offloaded modules")
-    prefix_dict = module._hf_hook.weights_map
-    prefix_dict.dataset[f"{prefix_dict.prefix}{key}"] = data
+
+    weights_map = module._hf_hook.weights_map
+    offload_to_weights_map(weights_map, key, data)
 
 
 def update_parameter_data(
-    module: Module, new_param_data: torch.Tensor, param_name: str
+    module: torch.nn.Module, new_param_data: torch.Tensor, param_name: str
 ):
     """
-    Updates the parameter value named param_name for a given module. This function
-    updates both the current loaded module state and the offloaded state dict if
-    the module is offloaded. This is neccesary because parameter updates for offloaded
-    modules do not persist automatically between loads.
+    Update the data of an existing parameter and its offload dict. Supports both
+    parameters of offloaded modules and non-offloaded modules
 
-    :param module: layer containing the parameter to update
+    :param module: module containing the parameter to update
     :param new_param_data: tensor to update parameter with
-    :param param_name: name of layer parameter to update
+    :param param_name: name of module parameter to update
     """
-    if not hasattr(module, param_name):
-        return
+    update_offload_parameter(module, param_name, new_param_data)
+
+
+""" Candidates for Upstreaming """
+
+
+def register_offload_parameter(
+    module: torch.nn.Module,
+    name: str,
+    parameter: torch.nn.Parameter,
+    offload_device: Optional[Union[torch.device, Literal["disk"]]] = None,
+):
+    """
+    Register a parameter to the given module which may be offloaded
+
+    :param module: maybe offloaded module
+    :param name: name of newly registered parameter
+    :param parameter: parameter being registered
+    :param offload_device: device on which weight will be offloaded to. If None is
+        provided, then infer device from parameters on module
+    """
+    has_onload = any(p.device != torch.device("meta") for p in module.parameters())
+    module.register_parameter(name, parameter)
+
+    if has_offloaded_params(module):
+        weights_map = module._hf_hook.weights_map
+        offload_to_weights_map(weights_map, name, parameter.data, offload_device)
+        if not has_onload:
+            set_module_tensor_to_device(module, name, "meta")
+
+
+def update_offload_parameter(
+    module: torch.nn.Module,
+    name: str,
+    data: Optional[torch.Tensor],
+    offload_device: Optional[Union[torch.device, Literal["disk"]]] = None,
+):
+    """
+    Update the data of an existing parameter and its offload dict. Supports both
+    parameters of offloaded modules and non-offloaded modules
+
+    :param module: module containing the parameter to update
+    :param name: name of module parameter to update
+    :param data: tensor to update parameter with
+    :param offload_device: device on which weight will be offloaded to. If None is
+        provided, then infer device from parameters on module
+    """
+    param = getattr(module, name)
+    data = data.to(param.dtype)
+
+    # copy data into onloaded parameter if applicable
+    if param.device != "meta":
+        param.data.copy_(data)
+
+    # update offload dict
+    if has_offloaded_params(module):
+        weights_map = module._hf_hook.weights_map
+        offload_to_weights_map(weights_map, name, data, offload_device)
+
+
+def delete_offload_parameter(module: torch.nn.Module, name: str):
+    """
+    Delete a parameter from a module which may be offloaded
+
+    :param module: maybe offloaded module
+    :param name: name of parameter being deleted
+    """
+    delattr(module, name)
+
+    if has_offloaded_params(module):
+        weights_map = module._hf_hook.weights_map
+        delete_from_weights_map(weights_map, name)
 
-    device = next(module.parameters()).device
 
-    offloaded = False
-    if is_module_offloaded(module):
-        offload_device = get_offloaded_device(module)
-        offloaded = True
+@check_accelerate(fallback=contextlib.nullcontext())
+@contextlib.contextmanager
+def disable_hf_hook(module: torch.nn.Module):
+    hooks = {}
 
-    parameter = getattr(module, param_name, None)
-    if parameter is None:
-        raise ValueError("Attempted to update uninitialized parameter")
+    def collect_hooks(module):
+        nonlocal hooks
+        if hasattr(module, "_hf_hook"):
+            hooks[module] = module._hf_hook
+            remove_hook_from_module(module)
 
-    dtype = parameter.dtype
-    parameter.data = new_param_data.to(device).to(dtype)
+    module.apply(collect_hooks)
 
-    if offloaded:
-        prefix_dict = module._hf_hook.weights_map.dataset
-        prefix = module._hf_hook.weights_map.prefix
-        prefix_dict[f"{prefix}{param_name}"] = new_param_data.to(offload_device).to(
-            dtype
+    yield
+
+    for submodule, hook in hooks.items():
+        add_hook_to_module(submodule, hook)
+
+
+@check_accelerate(fallback=None)
+def offload_to_weights_map(
+    weights_map: Union[PrefixedDataset, Dict, OffloadedWeightsLoader],
+    key: str,
+    value: torch.Tensor,
+    offload_device: Optional[Union[torch.device, Literal["disk"]]] = None,
+):
+    """
+    Helper function which implements offloaded item assignment for PrefixedDataset,
+    OffloadedWeightsLoader, and Dict types.
+
+    :param weights_map: weight map to be updated with offload information
+    :param key: key used to identify weight location
+    :param value: weight being offloaded
+    :param offload_device: device on which weight will be offloaded to. If None is
+        provided, then infer device from parameters in weights_map
+    """
+    if isinstance(weights_map, PrefixedDataset):
+        if offload_device == "disk":
+            raise ValueError(f"Cannot offload to disk with type {type(weights_map)}")
+
+        dataset = weights_map.dataset
+        key = f"{weights_map.prefix}{key}"
+        offload_to_weights_map(dataset, key, value, offload_device)
+
+    elif isinstance(weights_map, OffloadedWeightsLoader):
+        if key not in weights_map.all_keys:
+            weights_map.all_keys.append(key)
+
+        if len(weights_map.index) <= 0 and offload_device != "disk":
+            offload_to_weights_map(weights_map.state_dict, key, value, offload_device)
+
+        else:
+            raise NotImplementedError(
+                "Updating weights_map with disk offloading is not implemented yet"
+            )
+
+    elif isinstance(weights_map, dict):
+        if offload_device == "disk":
+            raise ValueError(f"Cannot offload to disk with type {type(weights_map)}")
+
+        # infer offload device
+        if offload_device is None:
+            if key in weights_map:
+                offload_device = weights_map[key].device
+            else:
+                tens = next(iter(weights_map.values()), None)
+                if tens is None:
+                    raise ValueError(
+                        "Cannot infer offload device from empty weights_map"
+                    )
+                offload_device = tens.device
+
+        weights_map[key] = value.to(device=offload_device)
+
+    else:
+        raise NotImplementedError(
+            "Updating offload data not implemented for weights_map of type "
+            f"{type(weights_map)}"
+        )
+
+
+@check_accelerate(fallback=None)
+def delete_from_weights_map(
+    weights_map: Union[PrefixedDataset, Dict, OffloadedWeightsLoader],
+    key: str,
+):
+    if isinstance(weights_map, PrefixedDataset):
+        dataset = weights_map.dataset
+        key = f"{weights_map.prefix}{key}"
+        delete_from_weights_map(dataset, key)
+
+    elif isinstance(weights_map, OffloadedWeightsLoader):
+        if len(weights_map.index) <= 0:
+            delete_from_weights_map(weights_map.state_dict, key)
+
+        else:
+            raise NotImplementedError(
+                "Delete from weights_map with disk offloading is not implemented yet"
+            )
+
+    elif isinstance(weights_map, dict):
+        del weights_map[key]
+
+    else:
+        raise NotImplementedError(
+            "Updating offload data not implemented for weights_map of type "
+            f"{type(weights_map)}"
         )
+
+
+""" Upstreamed Functions """
+
+
+# introduced in accelerate v1.1.0
+@check_accelerate(fallback=False)
+def has_offloaded_params(module: torch.nn.Module) -> bool:
+    """
+    Checks if a module has offloaded parameters by checking if the given module has a
+    AlignDevicesHook attached with offloading enabled
+
+    Args:
+        module (`torch.nn.Module`): The module to check for an offload hook.
+
+    Returns:
+        bool: `True` if the module has an offload hook and offloading is enabled,
+        `False` otherwise.
+    """
+    return (
+        hasattr(module, "_hf_hook")
+        and isinstance(module._hf_hook, AlignDevicesHook)
+        and module._hf_hook.offload
+    )
+
+
+# introduced in accelerate v1.1.0
+@check_accelerate(fallback=contextlib.nullcontext())
+@contextlib.contextmanager
+def align_module_device(
+    module: torch.nn.Module, execution_device: Optional[torch.device] = None
+):
+    """
+    Context manager that moves a module's parameters to the specified execution device.
+
+    Args:
+        module (`torch.nn.Module`):
+            Module with parameters to align.
+        execution_device (`torch.device`, *optional*):
+            If provided, overrides the module's execution device within the context.
+            Otherwise, use hook execution device or pass
+    """
+    if has_offloaded_params(module):
+        if execution_device is not None:
+            original_device = module._hf_hook.execution_device
+            module._hf_hook.execution_device = execution_device
+
+        try:
+            module._hf_hook.pre_forward(module)
+            yield
+        finally:
+            module._hf_hook.post_forward(module, None)
+            if execution_device is not None:
+                module._hf_hook.execution_device = original_device
+
+    elif execution_device is not None:
+        devices = {
+            name: param.device for name, param in module.named_parameters(recurse=False)
+        }
+        try:
+            for name in devices:
+                set_module_tensor_to_device(module, name, execution_device)
+            yield
+        finally:
+            for name, device in devices.items():
+                set_module_tensor_to_device(module, name, device)
+
+    else:
+        yield
compressed_tensors/utils/safetensors_load.py
CHANGED

@@ -16,7 +16,7 @@ import json
 import os
 import re
 import struct
-from typing import Dict, List, Optional
+from typing import Dict, List, Optional, Tuple, Union
 
 from safetensors import safe_open
 from torch import Tensor
@@ -30,10 +30,14 @@ __all__ = [
     "merge_names",
     "get_weight_mappings",
     "get_nested_weight_mappings",
+    "get_nested_mappings_from_state_dict",
     "get_quantization_state_dict",
     "is_quantization_param",
 ]
 
+WeightMappingType = Dict[str, str]
+NestedWeightMappingType = Dict[str, WeightMappingType]
+
 
 def get_safetensors_folder(
     pretrained_model_name_or_path: str, cache_dir: Optional[str] = None
@@ -92,7 +96,7 @@ def get_safetensors_header(safetensors_path: str) -> Dict[str, str]:
     return header
 
 
-def match_param_name(full_name: str, param_name: str) -> str:
+def match_param_name(full_name: str, param_name: str) -> Optional[str]:
     """
     Helper function extracting the uncompressed parameterized layer name from a
     compressed name. Assumes the compressed name was merged using merge_names.
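The annotation change documents that `match_param_name` returns `None` when a tensor name does not end with the requested parameter. The function body is not part of this hunk, so the regex below is an assumption about the implied behavior, not the library's actual implementation:

```python
import re
from typing import Optional


def match_param_name(full_name: str, param_name: str) -> Optional[str]:
    # Assumed sketch: strip a trailing ".{param_name}" and return the layer
    # prefix, or None when the name does not end with that parameter
    match = re.match(rf"^(.+)\.{param_name}$", full_name)
    return match.group(1) if match else None


assert match_param_name("model.layers.0.weight", "weight") == "model.layers.0"
assert match_param_name("model.layers.0.weight_scale", "weight") is None
```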
@@ -176,38 +180,100 @@ def get_weight_mappings(path_to_model_or_tensors: str) -> Dict[str, str]:
 
 
 def get_nested_weight_mappings(
-    model_path: str, params_to_nest: List[str]
-) -> Dict[str, Dict[str, str]]:
+    model_path: str, params_to_nest: List[str], return_unmatched_params: bool = False
+) -> Union[NestedWeightMappingType, Tuple[NestedWeightMappingType, WeightMappingType]]:
     """
     Takes a path to a state dict saved in safetensors format and returns a nested
-    mapping from uncompressed parameterized layer names to the file locations of
-    each layer's compression parameters
+    mapping from uncompressed parameterized layer names to the file locations of
+    each layer's compression parameters.
 
-    layer.weight: {
+    Example of the nested mapping:
+    layer: {
         bitmask: file_location,
         row_offsets: file_location,
         shape: file_location,
         compressed: file_location
     }
 
-    This generalizes to cases where the model is split into multiple safetensors files
+    If other parameters are found that do not match the nested parameters, they will
+    be returned in a separate dictionary only if return_unmatched_params is True.
+    This dictionary may be needed for cases where compressors are stacked (e.g.,
+    quantization compression followed by sparse compression).
+
+    Example of the unmatched params mapping:
+    {
+        layer.weight_scale: file_location,
+        layer.input_scale: file_location
+    }
 
-    :param model_path: path to safetensors state dict, must contain either a single
-    safetensors file or multiple files with an index
-    :return: nested mapping of parameterized layer name to file location
+    This generalizes to cases where the model is split into multiple safetensors
+    files.
+
+    :param model_path: Path to the safetensors state dict, must contain either a
+        single safetensors file or multiple files with an index.
+    :param params_to_nest: List of parameter names to nest.
+    :param return_unmatched_params: If True, return a second dictionary containing
+        the remaining parameters that were not matched to the params_to_nest.
+    :return:
+        - If return_unmatched_params is False:
+            NestedWeightMappingType: A nested mapping of parameterized layer names to
+            file locations of each layer's compression parameters.
+        - If return_unmatched_params is True:
+            Tuple[NestedWeightMappingType, WeightMappingType]: A tuple containing:
+                - NestedWeightMappingType: A nested mapping of parameterized layer
+                    names to file locations of each layer's compression parameters.
+                - WeightMappingType: A mapping of the remaining parameter names to
+                    their file locations that were not matched to the params_to_nest.
     """
     weight_mappings = get_weight_mappings(model_path)
-
     nested_weight_mappings = {}
-    for key in weight_mappings.keys():
+    unmatched_params = {}
+
+    for key, file_location in weight_mappings.items():
+        matched = False
         for param_name in params_to_nest:
-            maybe_match = match_param_name(key, param_name)
-            if maybe_match is not None:
-                dense_param = maybe_match
+            dense_param = match_param_name(key, param_name)
+            if dense_param:
                 if dense_param not in nested_weight_mappings:
                     nested_weight_mappings[dense_param] = {}
-                nested_weight_mappings[dense_param][param_name] = weight_mappings[key]
+                nested_weight_mappings[dense_param][param_name] = file_location
+                matched = True
+        if return_unmatched_params and not matched:
+            unmatched_params[key] = file_location
+
+    if return_unmatched_params:
+        return nested_weight_mappings, unmatched_params
+    return nested_weight_mappings
 
+
+def get_nested_mappings_from_state_dict(
+    state_dict, params_to_nest
+) -> NestedWeightMappingType:
+    """
+    Takes a state dict and returns a nested mapping from uncompressed
+    parameterized layer names to the value of
+    each layer's compression parameters.
+
+    Example of the nested mapping:
+    layer: {
+        weight_scale: ...,
+        weight: ...,
+        zero_point: ...,
+    }
+
+    :param state_dict: state dict of the model
+    :param params_to_nest: List of parameter names to nest.
+    :return: Nested mapping of parameterized layer names to the value of
+        each layer's compression parameters.
+    """
+    nested_weight_mappings = {}
+    for key in state_dict.keys():
+        for param_name in params_to_nest:
+            dense_param = match_param_name(key, param_name)
+            if dense_param:
+                if dense_param not in nested_weight_mappings:
+                    nested_weight_mappings[dense_param] = {}
+                nested_weight_mappings[dense_param][param_name] = state_dict[key]
     return nested_weight_mappings
 
 
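The new state-dict variant mirrors the file-based mapping but nests tensor values instead of file locations. A small sketch, using hypothetical checkpoint keys and assuming the grouping behavior of `match_param_name` described above:

```python
import torch

from compressed_tensors.utils.safetensors_load import (
    get_nested_mappings_from_state_dict,
)

# Hypothetical checkpoint fragment with per-layer quantization parameters
state_dict = {
    "model.layers.0.weight": torch.zeros(4, 4),
    "model.layers.0.weight_scale": torch.ones(1),
    "model.layers.1.weight": torch.zeros(4, 4),
    "model.layers.1.weight_scale": torch.ones(1),
}

nested = get_nested_mappings_from_state_dict(
    state_dict, params_to_nest=["weight", "weight_scale"]
)
print(sorted(nested.keys()))             # ['model.layers.0', 'model.layers.1']
print(sorted(nested["model.layers.0"]))  # ['weight', 'weight_scale']
```

The file-based `get_nested_weight_mappings` behaves the same way over safetensors index files, and with `return_unmatched_params=True` it additionally returns the leftover entries, which is what lets stacked compressors (quantization followed by sparsity) pick up parameters such as scales that the first pass did not claim.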
compressed_tensors/version.py
CHANGED