compressed-tensors-nightly 0.8.1.20241220__py3-none-any.whl → 0.8.1.20241225__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- compressed_tensors/quantization/lifecycle/initialize.py +17 -44
- compressed_tensors/utils/helpers.py +64 -1
- compressed_tensors/utils/offload.py +332 -44
- {compressed_tensors_nightly-0.8.1.20241220.dist-info → compressed_tensors_nightly-0.8.1.20241225.dist-info}/METADATA +1 -1
- {compressed_tensors_nightly-0.8.1.20241220.dist-info → compressed_tensors_nightly-0.8.1.20241225.dist-info}/RECORD +8 -8
- {compressed_tensors_nightly-0.8.1.20241220.dist-info → compressed_tensors_nightly-0.8.1.20241225.dist-info}/LICENSE +0 -0
- {compressed_tensors_nightly-0.8.1.20241220.dist-info → compressed_tensors_nightly-0.8.1.20241225.dist-info}/WHEEL +0 -0
- {compressed_tensors_nightly-0.8.1.20241220.dist-info → compressed_tensors_nightly-0.8.1.20241225.dist-info}/top_level.txt +0 -0
compressed_tensors/quantization/lifecycle/initialize.py

@@ -29,7 +29,11 @@ from compressed_tensors.quantization.quant_args import (
 from compressed_tensors.quantization.quant_config import QuantizationStatus
 from compressed_tensors.quantization.quant_scheme import QuantizationScheme
 from compressed_tensors.quantization.utils import is_kv_cache_quant_scheme
-from compressed_tensors.utils import
+from compressed_tensors.utils import (
+    disable_hf_hook,
+    has_offloaded_params,
+    register_offload_parameter,
+)
 from torch.nn import Module, Parameter
@@ -112,43 +116,10 @@ def initialize_module_for_quantization(
     module.quantization_scheme = scheme
     module.quantization_status = QuantizationStatus.INITIALIZED

-    offloaded = False
-    if is_module_offloaded(module):
-        try:
-            from accelerate.hooks import add_hook_to_module, remove_hook_from_module
-            from accelerate.utils import PrefixedDataset
-        except ModuleNotFoundError:
-            raise ModuleNotFoundError(
-                "Offloaded model detected. To use CPU offloading with "
-                "compressed-tensors the `accelerate` package must be installed, "
-                "run `pip install compressed-tensors[accelerate]`"
-            )
-
-        offloaded = True
-        hook = module._hf_hook
-        prefix_dict = module._hf_hook.weights_map
-        new_prefix = {}
-
-        # recreate the prefix dict (since it is immutable)
-        # and add quantization parameters
-        for key, data in module.named_parameters():
-            if key not in prefix_dict:
-                new_prefix[f"{prefix_dict.prefix}{key}"] = data
-            else:
-                new_prefix[f"{prefix_dict.prefix}{key}"] = prefix_dict[key]
-        new_prefix_dict = PrefixedDataset(new_prefix, prefix_dict.prefix)
-        remove_hook_from_module(module)
-
-    # wrap forward call of module to perform
-    # quantized actions based on calltime status
-    wrap_module_forward_quantized(module, scheme)
-
-    if offloaded:
-        # we need to re-add the hook for offloading now that we've wrapped forward
-        add_hook_to_module(module, hook)
-        if prefix_dict is not None:
-            module._hf_hook.weights_map = new_prefix_dict
+    with disable_hf_hook(module):
+        # wrap forward call of module to perform
+        # quantized actions based on calltime status
+        wrap_module_forward_quantized(module, scheme)


 def is_attention_module(module: Module):
@@ -169,9 +140,11 @@ def _initialize_scale_zero_point(
     if quantization_args.dynamic:
         return

-    device = next(module.parameters()).device
-    if is_module_offloaded(module):
-        device = get_execution_device(module)
+    # begin on the same device as other parameters or cpu if offloaded.
+    # in the offloaded case, there's no point moving tensors to the execution device
+    # if they're going to be immediately offloaded by `register_offload_parameter`
+    params_device = next(module.parameters()).device
+    device = "cpu" if has_offloaded_params(module) else params_device

     # infer expected scale/zero point shape
     if quantization_args.strategy == QuantizationStrategy.TOKEN:
@@ -196,7 +169,7 @@ def _initialize_scale_zero_point(
         torch.empty(expected_shape, dtype=scale_dtype, device=device),
         requires_grad=False,
     )
-    module.register_parameter(f"{base_name}_scale", init_scale)
+    register_offload_parameter(module, f"{base_name}_scale", init_scale)

     if force_zero_point or not quantization_args.symmetric:
         zp_dtype = quantization_args.pytorch_dtype()
@@ -204,7 +177,7 @@ def _initialize_scale_zero_point(
         torch.zeros(expected_shape, device=device, dtype=zp_dtype),
         requires_grad=False,
     )
-    module.register_parameter(f"{base_name}_zero_point", init_zero_point)
+    register_offload_parameter(module, f"{base_name}_zero_point", init_zero_point)

     # only grouped activation ordering has g_idx
     if quantization_args.actorder == ActivationOrdering.GROUP:
@@ -214,7 +187,7 @@ def _initialize_scale_zero_point(
         torch.full(g_idx_shape, -1, device=device, dtype=g_idx_dtype),
         requires_grad=False,
     )
-    module.register_parameter(f"{base_name}_g_idx", init_g_idx)
+    register_offload_parameter(module, f"{base_name}_g_idx", init_g_idx)


 def _initialize_attn_scales(module: Module) -> None:
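The net effect of the initialize.py changes is that quantization parameters are created through the shared offload helpers instead of hand-rolled hook and `PrefixedDataset` bookkeeping. A minimal sketch of the new pattern, with a hypothetical `Linear` stand-in and scale shape (and assuming the helpers are re-exported from `compressed_tensors.utils`, as the new import block suggests):

```python
import torch
from torch.nn import Linear, Parameter

from compressed_tensors.utils import (
    disable_hf_hook,
    has_offloaded_params,
    register_offload_parameter,
)

module = Linear(64, 64)  # stand-in for a layer being quantized

# Mirrors the new _initialize_scale_zero_point logic: start on CPU when the
# module is offloaded, since register_offload_parameter would immediately
# offload the tensor again anyway.
params_device = next(module.parameters()).device
device = "cpu" if has_offloaded_params(module) else params_device

init_scale = Parameter(
    torch.empty(1, dtype=torch.float16, device=device), requires_grad=False
)
register_offload_parameter(module, "weight_scale", init_scale)

# Hooks are detached only while forward is being wrapped, then restored.
with disable_hf_hook(module):
    module.forward = module.forward  # stand-in for wrap_module_forward_quantized
```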
compressed_tensors/utils/helpers.py

@@ -12,7 +12,9 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.

-from typing import Any, Optional
+import warnings
+from functools import wraps
+from typing import Any, Callable, Dict, Optional

 import torch
 from transformers import AutoConfig
@@ -24,6 +26,8 @@ __all__ = [
     "tensor_follows_mask_structure",
     "replace_module",
     "is_compressed_tensors_config",
+    "getattr_chain",
+    "deprecated",
     "Aliasable",
 ]

@@ -122,6 +126,65 @@ def is_compressed_tensors_config(compression_config: Any) -> bool:
     return False


+def getattr_chain(obj: Any, chain_str: str, *args, **kwargs) -> Any:
+    """
+    Chain multiple getattr calls, separated by `.`
+
+    :param obj: base object whose attributes are being retrieved
+    :param chain_str: attribute names separated by `.`
+    :param default: default value to return if an attribute is missing;
+        if no default is given, an AttributeError is raised instead
+    """
+    if len(args) >= 1:
+        has_default = True
+        default = args[0]
+    elif "default" in kwargs:
+        has_default = True
+        default = kwargs["default"]
+    else:
+        has_default = False
+
+    attr_names = chain_str.split(".")
+
+    res = obj
+    for attr_name in attr_names:
+        if not hasattr(res, attr_name):
+            if has_default:
+                return default
+            else:
+                raise AttributeError(f"{res} object has no attribute {attr_name}")
+        res = getattr(res, attr_name)
+
+    return res
+
+
+def deprecated(future_name: Optional[str] = None, message: Optional[str] = None):
+    """
+    Decorator to mark functions as deprecated
+
+    :param future_name: name of the function to use in place of the deprecated one
+    :param message: deprecation message, replaces the default deprecation message
+    """
+
+    def decorator(func: Callable[[Any], Any]):
+        nonlocal message
+
+        if message is None:
+            message = (
+                f"{func.__name__} is deprecated and will be removed in a future release"
+            )
+            if future_name is not None:
+                message += f". Please use {future_name} instead."
+
+        @wraps(func)
+        def wrapped(*args, **kwargs):
+            warnings.warn(message, DeprecationWarning, stacklevel=2)
+            return func(*args, **kwargs)
+
+        return wrapped
+
+    return decorator
+
+
 class Aliasable:
     """
     A mixin for enums to allow aliasing of enum members
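A short usage sketch for the two new helpers (assuming they are re-exported from `compressed_tensors.utils`, as the `__all__` additions suggest; the object and function names are illustrative):

```python
from types import SimpleNamespace

from compressed_tensors.utils import deprecated, getattr_chain

# getattr_chain walks dotted attribute paths, with an optional default
quant = SimpleNamespace(scheme=SimpleNamespace(bits=4))
model = SimpleNamespace(quantization=quant)

assert getattr_chain(model, "quantization.scheme.bits") == 4
assert getattr_chain(model, "quantization.missing.bits", None) is None  # default

# deprecated wraps a function so that each call emits a DeprecationWarning
@deprecated(future_name="new_helper")
def old_helper() -> int:
    return 42

old_helper()  # warns: "old_helper is deprecated ... Please use new_helper instead."
```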
compressed_tensors/utils/offload.py

@@ -11,9 +11,48 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
+"""
+Utilities associated with offloading functionality provided by `accelerate`.
+
+| ----------------------------------------------------------------------------------------------------- | # noqa: E501
+| Operation | Without offloading support             | With offloading support                          | # noqa: E501
+| --------- | -------------------------------------- | ------------------------------------------------ | # noqa: E501
+| Add       | module.register_parameter(name, param) | register_offload_parameter(module, name, param)  | # noqa: E501
+| Check     | N/A                                    | has_offloaded_params(module)                     | # noqa: E501
+| Onload    | N/A                                    | with align_module_device(module)                 | # noqa: E501
+| Update    | module.name.data.copy_(new_data)       | update_offload_parameter(module, name, new_data) | # noqa: E501
+| Delete    | del module.name                        | delete_offload_parameter(module, name)           | # noqa: E501
+| ----------------------------------------------------------------------------------------------------- | # noqa: E501
+"""
+
+import contextlib
+from functools import wraps
+from typing import Any, Callable, Dict, Literal, Optional, Union

 import torch
-from torch.nn import Module
+
+
+try:
+    from accelerate.hooks import (
+        AlignDevicesHook,
+        add_hook_to_module,
+        remove_hook_from_module,
+    )
+    from accelerate.utils import (
+        OffloadedWeightsLoader,
+        PrefixedDataset,
+        set_module_tensor_to_device,
+    )
+
+    _has_accelerate = True
+except ImportError:
+    _has_accelerate = False
+    AlignDevicesHook = None
+    add_hook_to_module = None
+    remove_hook_from_module = None
+    OffloadedWeightsLoader = None
+    PrefixedDataset = None
+    set_module_tensor_to_device = None


 __all__ = [
@@ -22,23 +61,44 @@ __all__ = [
     "get_offloaded_device",
     "update_prefix_dict",
     "update_parameter_data",
+    "register_offload_parameter",
+    "update_offload_parameter",
+    "delete_offload_parameter",
+    "has_offloaded_params",
+    "disable_hf_hook",
+    "align_module_device",
 ]


-def is_module_offloaded(module: Module) -> bool:
-
-
-
-
-
+def check_accelerate(fallback: Any):
+    def decorator(func: Callable[[Any], Any]):
+        if not _has_accelerate:
+
+            @wraps(func)
+            def fallback_fn(*args, **kwargs):
+                return fallback
+
+            return fallback_fn
+
+        return func
+
+    return decorator
+
+
+""" Candidates for Deprecation """
+
+
+@check_accelerate(fallback=False)
+def is_module_offloaded(module: torch.nn.Module) -> bool:
+    return has_offloaded_params(module)
+
+
 def get_execution_device(module: torch.nn.Module) -> torch.device:
     """
-    :param module:
-    :return: device
+    :param module: module to check
+    :return: device module is loaded onto during forward pass
     """
-    if is_module_offloaded(module):
+    if has_offloaded_params(module):
         return module._hf_hook.execution_device
     device = next(module.parameters()).device

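`check_accelerate` is what lets the rest of this module degrade gracefully when `accelerate` is not installed: at import time, the decorated function is swapped for a stub that returns the given fallback value. A sketch with a hypothetical function, importing from the module path directly since `check_accelerate` itself is not in `__all__`:

```python
from compressed_tensors.utils.offload import check_accelerate, has_offloaded_params

@check_accelerate(fallback=False)
def module_is_offloaded(module) -> bool:
    # Only reachable when accelerate is installed; otherwise the decorator
    # has replaced this function with one that returns False unconditionally.
    return has_offloaded_params(module)
```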
@@ -49,68 +109,296 @@ def get_execution_device(module: Module) -> torch.device:
     return device


-def get_offloaded_device(module: Module) -> torch.device:
+def get_offloaded_device(module: torch.nn.Module) -> torch.device:
     """
-    :param module:
-    :return: device
+    :param module: module to check
+    :return: device module is offloaded to after forward pass
     """
-    if is_module_offloaded(module):
+    if has_offloaded_params(module):
         first_key = list(module._hf_hook.weights_map.keys())[0]
         prefix_dataset = module._hf_hook.weights_map.dataset
         return prefix_dataset[first_key].device
     return next(module.parameters()).device


-def update_prefix_dict(module: Module, key: str, data: torch.Tensor):
+@check_accelerate(fallback=None)
+def update_prefix_dict(module: torch.nn.Module, key: str, data: torch.Tensor):
     """
     Updates the offloaded state dict for a given module. Parameter named key is replaced
     by data. This is necessary because parameter updates for offloaded modules do not
     persist automatically between loads. This function only affects the offloaded
     state dict and not the current state of the loaded module.

-    :param module:
+    :param module: module containing the parameter to update
     :param key: name of parameter to update
     :param data: tensor to update parameter with in the offloaded state dict
     """
-    if not is_module_offloaded(module):
+    if not has_offloaded_params(module):
         raise ValueError("Prefix dict is only applicable to offloaded modules")
-
-
+
+    weights_map = module._hf_hook.weights_map
+    offload_to_weights_map(weights_map, key, data)


 def update_parameter_data(
-    module: Module, new_param_data: torch.Tensor, param_name: str
+    module: torch.nn.Module, new_param_data: torch.Tensor, param_name: str
 ):
     """
-
-
-    the module is offloaded. This is neccesary because parameter updates for offloaded
-    modules do not persist automatically between loads.
+    Update the data of an existing parameter and its offload dict. Supports both
+    parameters of offloaded modules and non-offloaded modules

-    :param module:
+    :param module: module containing the parameter to update
     :param new_param_data: tensor to update parameter with
-    :param param_name: name of
+    :param param_name: name of module parameter to update
     """
-
-
+    update_offload_parameter(module, param_name, new_param_data)
+
+
+""" Candidates for Upstreaming """
+
+
+def register_offload_parameter(
+    module: torch.nn.Module,
+    name: str,
+    parameter: torch.nn.Parameter,
+    offload_device: Optional[Union[torch.device, Literal["disk"]]] = None,
+):
+    """
+    Register a parameter to the given module which may be offloaded
+
+    :param module: maybe offloaded module
+    :param name: name of newly registered parameter
+    :param parameter: parameter being registered
+    :param offload_device: device on which weight will be offloaded to. If None is
+        provided, then infer device from parameters on module
+    """
+    has_onload = any(p.device != torch.device("meta") for p in module.parameters())
+    module.register_parameter(name, parameter)
+
+    if has_offloaded_params(module):
+        weights_map = module._hf_hook.weights_map
+        offload_to_weights_map(weights_map, name, parameter.data, offload_device)
+        if not has_onload:
+            set_module_tensor_to_device(module, name, "meta")
+
+
+def update_offload_parameter(
+    module: torch.nn.Module,
+    name: str,
+    data: Optional[torch.Tensor],
+    offload_device: Optional[Union[torch.device, Literal["disk"]]] = None,
+):
+    """
+    Update the data of an existing parameter and its offload dict. Supports both
+    parameters of offloaded modules and non-offloaded modules
+
+    :param module: module containing the parameter to update
+    :param name: name of module parameter to update
+    :param data: tensor to update parameter with
+    :param offload_device: device on which weight will be offloaded to. If None is
+        provided, then infer device from parameters on module
+    """
+    param = getattr(module, name)
+    data = data.to(param.dtype)
+
+    # copy data into onloaded parameter if applicable
+    if param.device != "meta":
+        param.data.copy_(data)
+
+    # update offload dict
+    if has_offloaded_params(module):
+        weights_map = module._hf_hook.weights_map
+        offload_to_weights_map(weights_map, name, data, offload_device)
+
+
+def delete_offload_parameter(module: torch.nn.Module, name: str):
+    """
+    Delete a parameter from a module which may be offloaded
+
+    :param module: maybe offloaded module
+    :param name: name of parameter being deleted
+    """
+    delattr(module, name)
+
+    if has_offloaded_params(module):
+        weights_map = module._hf_hook.weights_map
+        delete_from_weights_map(weights_map, name)

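Together these three functions give the offload-aware parameter lifecycle promised in the module docstring's table. A sketch of the add/update/delete flow, with a hypothetical module and parameter name; the behavior is the same whether or not an accelerate hook is attached (and works without `accelerate` installed, since the offload branches simply don't trigger):

```python
import torch
from torch.nn import Linear, Parameter

from compressed_tensors.utils import (
    delete_offload_parameter,
    register_offload_parameter,
    update_offload_parameter,
)

module = Linear(4, 4)

# Add: registered on the module, and mirrored into the offload map if one exists
scale = Parameter(torch.zeros(4), requires_grad=False)
register_offload_parameter(module, "weight_scale", scale)

# Update: copies into the onloaded tensor and refreshes any offloaded copy
update_offload_parameter(module, "weight_scale", torch.ones(4))

# Delete: removes both the module attribute and its offloaded entry
delete_offload_parameter(module, "weight_scale")
```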
-    device = next(module.parameters()).device

-
-
-
-
+@check_accelerate(fallback=contextlib.nullcontext())
+@contextlib.contextmanager
+def disable_hf_hook(module: torch.nn.Module):
+    hooks = {}

-
-
-
+    def collect_hooks(module):
+        nonlocal hooks
+        if hasattr(module, "_hf_hook"):
+            hooks[module] = module._hf_hook
+            remove_hook_from_module(module)

-
-    parameter.data = new_param_data.to(device).to(dtype)
+    module.apply(collect_hooks)

-
-
-
-
-
+    yield
+
+    for submodule, hook in hooks.items():
+        add_hook_to_module(submodule, hook)
+
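`disable_hf_hook` removes any accelerate hooks on the module and its submodules, yields, then re-attaches them; this is what lets initialize.py wrap `forward` without the hook intercepting the patched method. A minimal sketch (without `accelerate` it falls back to a null context):

```python
import torch
from compressed_tensors.utils import disable_hf_hook

module = torch.nn.Linear(4, 4)

with disable_hf_hook(module):
    # No _hf_hook is attached inside this block, so forward can be patched
    # safely; on exit, any collected hooks are re-attached as they were.
    original_forward = module.forward
```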
+@check_accelerate(fallback=None)
+def offload_to_weights_map(
+    weights_map: Union[PrefixedDataset, Dict, OffloadedWeightsLoader],
+    key: str,
+    value: torch.Tensor,
+    offload_device: Optional[Union[torch.device, Literal["disk"]]] = None,
+):
+    """
+    Helper function which implements offloaded item assignment for PrefixedDataset,
+    OffloadedWeightsLoader, and Dict types.
+
+    :param weights_map: weight map to be updated with offload information
+    :param key: key used to identify weight location
+    :param value: weight being offloaded
+    :param offload_device: device on which weight will be offloaded to. If None is
+        provided, then infer device from parameters in weights_map
+    """
+    if isinstance(weights_map, PrefixedDataset):
+        if offload_device == "disk":
+            raise ValueError(f"Cannot offload to disk with type {type(weights_map)}")
+
+        dataset = weights_map.dataset
+        key = f"{weights_map.prefix}{key}"
+        offload_to_weights_map(dataset, key, value, offload_device)
+
+    elif isinstance(weights_map, OffloadedWeightsLoader):
+        if key not in weights_map.all_keys:
+            weights_map.all_keys.append(key)
+
+        if len(weights_map.index) <= 0 and offload_device != "disk":
+            offload_to_weights_map(weights_map.state_dict, key, value, offload_device)
+
+        else:
+            raise NotImplementedError(
+                "Updating weights_map with disk offloading is not implemented yet"
+            )
+
+    elif isinstance(weights_map, dict):
+        if offload_device == "disk":
+            raise ValueError(f"Cannot offload to disk with type {type(weights_map)}")
+
+        # infer offload device
+        if offload_device is None:
+            if key in weights_map:
+                offload_device = weights_map[key].device
+            else:
+                tens = next(iter(weights_map.values()), None)
+                if tens is None:
+                    raise ValueError(
+                        "Cannot infer offload device from empty weights_map"
+                    )
+                offload_device = tens.device
+
+        weights_map[key] = value.to(device=offload_device)
+
+    else:
+        raise NotImplementedError(
+            "Updating offload data not implemented for weights_map of type "
+            f"{type(weights_map)}"
+        )
+
+
+@check_accelerate(fallback=None)
+def delete_from_weights_map(
+    weights_map: Union[PrefixedDataset, Dict, OffloadedWeightsLoader],
+    key: str,
+):
+    if isinstance(weights_map, PrefixedDataset):
+        dataset = weights_map.dataset
+        key = f"{weights_map.prefix}{key}"
+        delete_from_weights_map(dataset, key)
+
+    elif isinstance(weights_map, OffloadedWeightsLoader):
+        if len(weights_map.index) <= 0:
+            delete_from_weights_map(weights_map.state_dict, key)
+
+        else:
+            raise NotImplementedError(
+                "Delete from weights_map with disk offloading is not implemented yet"
+            )
+
+    elif isinstance(weights_map, dict):
+        del weights_map[key]
+
+    else:
+        raise NotImplementedError(
+            "Updating offload data not implemented for weights_map of type "
+            f"{type(weights_map)}"
         )
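For the plain-`dict` case, `offload_to_weights_map` infers the offload device from entries already in the map. A small sketch, importing from the module path since these helpers are not re-exported, and assuming `accelerate` is installed (without it, both are replaced by no-op fallbacks returning `None`):

```python
import torch
from compressed_tensors.utils.offload import (
    delete_from_weights_map,
    offload_to_weights_map,
)

weights_map = {"weight": torch.zeros(4, device="cpu")}

# Offload device inferred from the existing "weight" entry
offload_to_weights_map(weights_map, "weight_scale", torch.ones(1))
assert weights_map["weight_scale"].device.type == "cpu"

delete_from_weights_map(weights_map, "weight_scale")
assert "weight_scale" not in weights_map
```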
+
+
+""" Upstreamed Functions """
+
+
+# introduced in accelerate v1.1.0
+@check_accelerate(fallback=False)
+def has_offloaded_params(module: torch.nn.Module) -> bool:
+    """
+    Checks if a module has offloaded parameters by checking if the given module has an
+    AlignDevicesHook attached with offloading enabled
+
+    Args:
+        module (`torch.nn.Module`): The module to check for an offload hook.
+
+    Returns:
+        bool: `True` if the module has an offload hook and offloading is enabled,
+        `False` otherwise.
+    """
+    return (
+        hasattr(module, "_hf_hook")
+        and isinstance(module._hf_hook, AlignDevicesHook)
+        and module._hf_hook.offload
+    )
+
+
+# introduced in accelerate v1.1.0
+@check_accelerate(fallback=contextlib.nullcontext())
+@contextlib.contextmanager
+def align_module_device(
+    module: torch.nn.Module, execution_device: Optional[torch.device] = None
+):
+    """
+    Context manager that moves a module's parameters to the specified execution device.
+
+    Args:
+        module (`torch.nn.Module`):
+            Module with parameters to align.
+        execution_device (`torch.device`, *optional*):
+            If provided, overrides the module's execution device within the context.
+            Otherwise, the hook's execution device is used.
+    """
+    if has_offloaded_params(module):
+        if execution_device is not None:
+            original_device = module._hf_hook.execution_device
+            module._hf_hook.execution_device = execution_device
+
+        try:
+            module._hf_hook.pre_forward(module)
+            yield
+        finally:
+            module._hf_hook.post_forward(module, None)
+            if execution_device is not None:
+                module._hf_hook.execution_device = original_device
+
+    elif execution_device is not None:
+        devices = {
+            name: param.device for name, param in module.named_parameters(recurse=False)
+        }
+        try:
+            for name in devices:
+                set_module_tensor_to_device(module, name, execution_device)
+            yield
+        finally:
+            for name, device in devices.items():
+                set_module_tensor_to_device(module, name, device)
+
+    else:
+        yield
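`align_module_device` is the "Onload" entry in the docstring table: inside the context, the parameters of an offloaded module are materialized on the execution device, then restored (or re-offloaded) on exit. A sketch with a hypothetical module, assuming `accelerate` is installed (without it the context falls back to a no-op):

```python
import torch
from compressed_tensors.utils import align_module_device

module = torch.nn.Linear(4, 4)

# For offloaded modules this onloads weights via the hook's pre_forward; for a
# plain module it moves the direct parameters and restores them afterwards.
with align_module_device(module, execution_device=torch.device("cpu")):
    output = module(torch.randn(1, 4))
```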
compressed_tensors_nightly-0.8.1.20241220.dist-info/METADATA → compressed_tensors_nightly-0.8.1.20241225.dist-info/METADATA

@@ -1,6 +1,6 @@
 Metadata-Version: 2.1
 Name: compressed-tensors-nightly
-Version: 0.8.1.20241220
+Version: 0.8.1.20241225
 Summary: Library for utilization of compressed safetensors of neural network models
 Home-page: https://github.com/neuralmagic/compressed-tensors
 Author: Neuralmagic, Inc.
compressed_tensors_nightly-0.8.1.20241220.dist-info/RECORD → compressed_tensors_nightly-0.8.1.20241225.dist-info/RECORD

@@ -31,20 +31,20 @@ compressed_tensors/quantization/lifecycle/apply.py,sha256=jCUSgeOBtagE5IhgIbyYMZ
 compressed_tensors/quantization/lifecycle/compressed.py,sha256=Fj9n66IN0EWsOAkBHg3O0GlOQpxstqjCcs0ttzMXrJ0,2296
 compressed_tensors/quantization/lifecycle/forward.py,sha256=QPL6-vKOFuKdKIEsVqMhsw4x552Jpm2sqO0oeChbnrM,12941
 compressed_tensors/quantization/lifecycle/helpers.py,sha256=C0mhy2vJ0fCjVeN4kFNhw8Eq1wkteBGHiZ36RVLThRY,944
-compressed_tensors/quantization/lifecycle/initialize.py,sha256=
+compressed_tensors/quantization/lifecycle/initialize.py,sha256=hymYtayTSumm8KCYAYPY267aWmlsJpt8oQFiRblk8qE,7452
 compressed_tensors/quantization/utils/__init__.py,sha256=VdtEmP0bvuND_IGQnyqUPc5lnFp-1_yD7StKSX4x80w,656
 compressed_tensors/quantization/utils/helpers.py,sha256=DBP-sGRpGAY01K0LFE7qqonNj4hkTYL_mXrMs2LtAD8,14100
 compressed_tensors/registry/__init__.py,sha256=FwLSNYqfIrb5JD_6OK_MT4_svvKTN_nEhpgQlQvGbjI,658
 compressed_tensors/registry/registry.py,sha256=vRcjVB1ITfSbfYUaGndBBmqhip_5vsS62weorVg0iXo,11896
 compressed_tensors/utils/__init__.py,sha256=gS4gSU2pwcAbsKj-6YMaqhm25udFy6ISYaWBf-myRSM,808
-compressed_tensors/utils/helpers.py,sha256=
-compressed_tensors/utils/offload.py,sha256=
+compressed_tensors/utils/helpers.py,sha256=XF36-SLkXnAHh0VzbvUlAdh6a88aCQvS_WeYs9Lfio8,6827
+compressed_tensors/utils/offload.py,sha256=cMmzd9IdlNbs29CReHj1PPSLUM6OWaT5YumlLT5eP3w,13845
 compressed_tensors/utils/permutations_24.py,sha256=kx6fsfDHebx94zsSzhXGyCyuC9sVyah6BUUir_StT28,2530
 compressed_tensors/utils/permute.py,sha256=V6tJLKo3Syccj-viv4F7ZKZgJeCB-hl-dK8RKI_kBwI,2355
 compressed_tensors/utils/safetensors_load.py,sha256=m08ANVuTBxQdoa6LufDgcNJ7wCLDJolyZljB8VEybAU,8578
 compressed_tensors/utils/semi_structured_conversions.py,sha256=XKNffPum54kPASgqKzgKvyeqWPAkair2XEQXjkp7ho8,13489
-compressed_tensors_nightly-0.8.1.20241220.dist-info/LICENSE,sha256=xx0jnfkXJvxRnG63LTGOxlggYnIysveWIZ6H3PNdCrQ,11357
-compressed_tensors_nightly-0.8.1.20241220.dist-info/METADATA,sha256=
-compressed_tensors_nightly-0.8.1.20241220.dist-info/WHEEL,sha256=tZoeGjtWxWRfdplE7E3d45VPlLNQnvbKiYnx7gwAy8A,92
-compressed_tensors_nightly-0.8.1.20241220.dist-info/top_level.txt,sha256=w2i-GyPs2s1UwVxvutSvN_lM22SXC2hQFBmoMcPnV7Y,19
-compressed_tensors_nightly-0.8.1.20241220.dist-info/RECORD,,
+compressed_tensors_nightly-0.8.1.20241225.dist-info/LICENSE,sha256=xx0jnfkXJvxRnG63LTGOxlggYnIysveWIZ6H3PNdCrQ,11357
+compressed_tensors_nightly-0.8.1.20241225.dist-info/METADATA,sha256=_Dh-8bZ7fT6iBH4JWMABCkMExAOjA4p6N0OJ_vyDwps,6799
+compressed_tensors_nightly-0.8.1.20241225.dist-info/WHEEL,sha256=tZoeGjtWxWRfdplE7E3d45VPlLNQnvbKiYnx7gwAy8A,92
+compressed_tensors_nightly-0.8.1.20241225.dist-info/top_level.txt,sha256=w2i-GyPs2s1UwVxvutSvN_lM22SXC2hQFBmoMcPnV7Y,19
+compressed_tensors_nightly-0.8.1.20241225.dist-info/RECORD,,