compressed-tensors-nightly 0.8.1.20241220__py3-none-any.whl → 0.8.1.20241225__py3-none-any.whl
Sign up to get free protection for your applications and to get access to all the features.
- compressed_tensors/quantization/lifecycle/initialize.py +17 -44
- compressed_tensors/utils/helpers.py +64 -1
- compressed_tensors/utils/offload.py +332 -44
- {compressed_tensors_nightly-0.8.1.20241220.dist-info → compressed_tensors_nightly-0.8.1.20241225.dist-info}/METADATA +1 -1
- {compressed_tensors_nightly-0.8.1.20241220.dist-info → compressed_tensors_nightly-0.8.1.20241225.dist-info}/RECORD +8 -8
- {compressed_tensors_nightly-0.8.1.20241220.dist-info → compressed_tensors_nightly-0.8.1.20241225.dist-info}/LICENSE +0 -0
- {compressed_tensors_nightly-0.8.1.20241220.dist-info → compressed_tensors_nightly-0.8.1.20241225.dist-info}/WHEEL +0 -0
- {compressed_tensors_nightly-0.8.1.20241220.dist-info → compressed_tensors_nightly-0.8.1.20241225.dist-info}/top_level.txt +0 -0
@@ -29,7 +29,11 @@ from compressed_tensors.quantization.quant_args import (
|
|
29
29
|
from compressed_tensors.quantization.quant_config import QuantizationStatus
|
30
30
|
from compressed_tensors.quantization.quant_scheme import QuantizationScheme
|
31
31
|
from compressed_tensors.quantization.utils import is_kv_cache_quant_scheme
|
32
|
-
from compressed_tensors.utils import
|
32
|
+
from compressed_tensors.utils import (
|
33
|
+
disable_hf_hook,
|
34
|
+
has_offloaded_params,
|
35
|
+
register_offload_parameter,
|
36
|
+
)
|
33
37
|
from torch.nn import Module, Parameter
|
34
38
|
|
35
39
|
|
@@ -112,43 +116,10 @@ def initialize_module_for_quantization(
|
|
112
116
|
module.quantization_scheme = scheme
|
113
117
|
module.quantization_status = QuantizationStatus.INITIALIZED
|
114
118
|
|
115
|
-
|
116
|
-
|
117
|
-
|
118
|
-
|
119
|
-
from accelerate.hooks import add_hook_to_module, remove_hook_from_module
|
120
|
-
from accelerate.utils import PrefixedDataset
|
121
|
-
except ModuleNotFoundError:
|
122
|
-
raise ModuleNotFoundError(
|
123
|
-
"Offloaded model detected. To use CPU offloading with "
|
124
|
-
"compressed-tensors the `accelerate` package must be installed, "
|
125
|
-
"run `pip install compressed-tensors[accelerate]`"
|
126
|
-
)
|
127
|
-
|
128
|
-
offloaded = True
|
129
|
-
hook = module._hf_hook
|
130
|
-
prefix_dict = module._hf_hook.weights_map
|
131
|
-
new_prefix = {}
|
132
|
-
|
133
|
-
# recreate the prefix dict (since it is immutable)
|
134
|
-
# and add quantization parameters
|
135
|
-
for key, data in module.named_parameters():
|
136
|
-
if key not in prefix_dict:
|
137
|
-
new_prefix[f"{prefix_dict.prefix}{key}"] = data
|
138
|
-
else:
|
139
|
-
new_prefix[f"{prefix_dict.prefix}{key}"] = prefix_dict[key]
|
140
|
-
new_prefix_dict = PrefixedDataset(new_prefix, prefix_dict.prefix)
|
141
|
-
remove_hook_from_module(module)
|
142
|
-
|
143
|
-
# wrap forward call of module to perform
|
144
|
-
# quantized actions based on calltime status
|
145
|
-
wrap_module_forward_quantized(module, scheme)
|
146
|
-
|
147
|
-
if offloaded:
|
148
|
-
# we need to re-add the hook for offloading now that we've wrapped forward
|
149
|
-
add_hook_to_module(module, hook)
|
150
|
-
if prefix_dict is not None:
|
151
|
-
module._hf_hook.weights_map = new_prefix_dict
|
119
|
+
with disable_hf_hook(module):
|
120
|
+
# wrap forward call of module to perform
|
121
|
+
# quantized actions based on calltime status
|
122
|
+
wrap_module_forward_quantized(module, scheme)
|
152
123
|
|
153
124
|
|
154
125
|
def is_attention_module(module: Module):
|
@@ -169,9 +140,11 @@ def _initialize_scale_zero_point(
|
|
169
140
|
if quantization_args.dynamic:
|
170
141
|
return
|
171
142
|
|
172
|
-
device
|
173
|
-
|
174
|
-
|
143
|
+
# begin on the same device as other parameters or cpu if offloaded.
|
144
|
+
# in the offloaded case, there's no point moving tensors to the execution device
|
145
|
+
# if they're going to be immediately offloaded by `register_offload_parameter`
|
146
|
+
params_device = next(module.parameters()).device
|
147
|
+
device = "cpu" if has_offloaded_params(module) else params_device
|
175
148
|
|
176
149
|
# infer expected scale/zero point shape
|
177
150
|
if quantization_args.strategy == QuantizationStrategy.TOKEN:
|
@@ -196,7 +169,7 @@ def _initialize_scale_zero_point(
|
|
196
169
|
torch.empty(expected_shape, dtype=scale_dtype, device=device),
|
197
170
|
requires_grad=False,
|
198
171
|
)
|
199
|
-
module
|
172
|
+
register_offload_parameter(module, f"{base_name}_scale", init_scale)
|
200
173
|
|
201
174
|
if force_zero_point or not quantization_args.symmetric:
|
202
175
|
zp_dtype = quantization_args.pytorch_dtype()
|
@@ -204,7 +177,7 @@ def _initialize_scale_zero_point(
|
|
204
177
|
torch.zeros(expected_shape, device=device, dtype=zp_dtype),
|
205
178
|
requires_grad=False,
|
206
179
|
)
|
207
|
-
module
|
180
|
+
register_offload_parameter(module, f"{base_name}_zero_point", init_zero_point)
|
208
181
|
|
209
182
|
# only grouped activation ordering has g_idx
|
210
183
|
if quantization_args.actorder == ActivationOrdering.GROUP:
|
@@ -214,7 +187,7 @@ def _initialize_scale_zero_point(
|
|
214
187
|
torch.full(g_idx_shape, -1, device=device, dtype=g_idx_dtype),
|
215
188
|
requires_grad=False,
|
216
189
|
)
|
217
|
-
module
|
190
|
+
register_offload_parameter(module, f"{base_name}_g_idx", init_g_idx)
|
218
191
|
|
219
192
|
|
220
193
|
def _initialize_attn_scales(module: Module) -> None:
|
@@ -12,7 +12,9 @@
|
|
12
12
|
# See the License for the specific language governing permissions and
|
13
13
|
# limitations under the License.
|
14
14
|
|
15
|
-
|
15
|
+
import warnings
|
16
|
+
from functools import wraps
|
17
|
+
from typing import Any, Callable, Dict, Optional
|
16
18
|
|
17
19
|
import torch
|
18
20
|
from transformers import AutoConfig
|
@@ -24,6 +26,8 @@ __all__ = [
|
|
24
26
|
"tensor_follows_mask_structure",
|
25
27
|
"replace_module",
|
26
28
|
"is_compressed_tensors_config",
|
29
|
+
"getattr_chain",
|
30
|
+
"deprecated",
|
27
31
|
"Aliasable",
|
28
32
|
]
|
29
33
|
|
@@ -122,6 +126,65 @@ def is_compressed_tensors_config(compression_config: Any) -> bool:
|
|
122
126
|
return False
|
123
127
|
|
124
128
|
|
129
|
+
def getattr_chain(obj: Any, chain_str: str, *args, **kwargs) -> Any:
    """
    Chain multiple getattr calls, separated by `.`

    For example, ``getattr_chain(module, "_hf_hook.weights_map")`` resolves
    ``module._hf_hook.weights_map``.

    :param obj: base object whose attributes are being retrieved
    :param chain_str: attribute names separated by `.`
    :param default: optional default value, passed either as the first extra
        positional argument or as the ``default`` keyword. It is returned as soon as
        any attribute along the chain is missing. If no default is given, an
        AttributeError is raised instead
    :return: the value of the last attribute in the chain, or the default
    :raises AttributeError: if an attribute is missing and no default was given
    """
    # private sentinel so that ``None`` remains a valid, explicit default
    missing = object()

    if len(args) >= 1:
        default = args[0]
    else:
        default = kwargs.get("default", missing)

    res = obj
    for attr_name in chain_str.split("."):
        # single lookup per attribute: ``hasattr`` + ``getattr`` would evaluate
        # properties (and any side effects they have) twice
        next_res = getattr(res, attr_name, missing)
        if next_res is missing:
            if default is not missing:
                return default
            raise AttributeError(f"{res} object has no attribute {attr_name}")
        res = next_res

    return res
|
158
|
+
|
159
|
+
|
160
|
+
def deprecated(future_name: Optional[str] = None, message: Optional[str] = None):
    """
    Decorator to mark functions as deprecated

    Every call to the decorated function emits a ``DeprecationWarning`` (attributed
    to the caller via ``stacklevel=2``) and then forwards to the original function.

    :param future_name: name of the replacement function; only used to extend the
        default deprecation message
    :param message: deprecation message, replaces the default deprecation message
    """

    def decorator(func: Callable[[Any], Any]):
        # build the message per decorated function instead of rebinding the
        # enclosing ``message``; otherwise reusing one ``deprecated(...)`` result
        # on several functions would repeat the first function's name
        if message is not None:
            warning_message = message
        else:
            warning_message = (
                f"{func.__name__} is deprecated and will be removed in a future release"
            )
            if future_name is not None:
                warning_message += f". Please use {future_name} instead."

        @wraps(func)
        def wrapped(*args, **kwargs):
            warnings.warn(warning_message, DeprecationWarning, stacklevel=2)
            return func(*args, **kwargs)

        return wrapped

    return decorator
|
186
|
+
|
187
|
+
|
125
188
|
class Aliasable:
|
126
189
|
"""
|
127
190
|
A mixin for enums to allow aliasing of enum members
|
@@ -11,9 +11,48 @@
|
|
11
11
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12
12
|
# See the License for the specific language governing permissions and
|
13
13
|
# limitations under the License.
|
14
|
+
"""
|
15
|
+
Utilities associated with offloading functionality provided by `accelerate`.
|
16
|
+
|
17
|
+
| ----------------------------------------------------------------------------------------------------- | # noqa: E501
|
18
|
+
| Operation | Without offloading support | With offloading support | # noqa: E501
|
19
|
+
| --------- | -------------------------------------- | ------------------------------------------------ | # noqa: E501
|
20
|
+
| Add | module.register_parameter(name, param) | register_offload_parameter(module, name, param) | # noqa: E501
|
21
|
+
| Check | N/A | has_offloaded_params(module) | # noqa: E501
|
22
|
+
| Onload | N/A | with align_module_device(module) | # noqa: E501
|
23
|
+
| Update | module.name.data.copy_(new_data) | update_offload_parameter(module, name, new_data) | # noqa: E501
|
24
|
+
| Delete | del module.name | delete_offload_parameter(module, name) | # noqa: E501
|
25
|
+
| ----------------------------------------------------------------------------------------------------- | # noqa: E501
|
26
|
+
"""
|
27
|
+
|
28
|
+
import contextlib
|
29
|
+
from functools import wraps
|
30
|
+
from typing import Any, Callable, Dict, Literal, Optional, Union
|
14
31
|
|
15
32
|
import torch
|
16
|
-
|
33
|
+
|
34
|
+
|
35
|
+
try:
|
36
|
+
from accelerate.hooks import (
|
37
|
+
AlignDevicesHook,
|
38
|
+
add_hook_to_module,
|
39
|
+
remove_hook_from_module,
|
40
|
+
)
|
41
|
+
from accelerate.utils import (
|
42
|
+
OffloadedWeightsLoader,
|
43
|
+
PrefixedDataset,
|
44
|
+
set_module_tensor_to_device,
|
45
|
+
)
|
46
|
+
|
47
|
+
_has_accelerate = True
|
48
|
+
except ImportError:
|
49
|
+
_has_accelerate = False
|
50
|
+
AlignDevicesHook = None
|
51
|
+
add_hook_to_module = None
|
52
|
+
remove_hook_from_module = None
|
53
|
+
OffloadedWeightsLoader = None
|
54
|
+
PrefixedDataset = None
|
55
|
+
set_module_tensor_to_device = None
|
17
56
|
|
18
57
|
|
19
58
|
__all__ = [
|
@@ -22,23 +61,44 @@ __all__ = [
|
|
22
61
|
"get_offloaded_device",
|
23
62
|
"update_prefix_dict",
|
24
63
|
"update_parameter_data",
|
64
|
+
"register_offload_parameter",
|
65
|
+
"update_offload_parameter",
|
66
|
+
"delete_offload_parameter",
|
67
|
+
"has_offloaded_params",
|
68
|
+
"disable_hf_hook",
|
69
|
+
"align_module_device",
|
25
70
|
]
|
26
71
|
|
27
72
|
|
28
|
-
def
|
29
|
-
|
30
|
-
|
31
|
-
|
32
|
-
|
33
|
-
|
73
|
+
def check_accelerate(fallback: Any):
    """
    Decorator factory for functions that depend on the optional `accelerate` package.

    If `accelerate` was importable (module flag `_has_accelerate` is truthy when the
    decorator is applied), the function is returned untouched. Otherwise it is
    swapped for a stub that keeps the original's metadata, ignores all arguments, and
    returns `fallback`.

    :param fallback: value returned by the stub; the same object is returned on
        every call
    """

    def decorator(func: Callable[[Any], Any]):
        if _has_accelerate:
            return func

        @wraps(func)
        def fallback_fn(*args, **kwargs):
            return fallback

        return fallback_fn

    return decorator
|
35
86
|
|
36
|
-
|
87
|
+
|
88
|
+
""" Candidates for Depreciation """
|
89
|
+
|
90
|
+
|
91
|
+
@check_accelerate(fallback=False)
def is_module_offloaded(module: torch.nn.Module) -> bool:
    """
    Candidate for deprecation: thin alias of `has_offloaded_params`.

    :param module: module to check
    :return: whatever `has_offloaded_params` reports for `module`; always False when
        `accelerate` is not installed
    """
    offloaded = has_offloaded_params(module)
    return offloaded
|
94
|
+
|
95
|
+
|
96
|
+
def get_execution_device(module: torch.nn.Module) -> torch.device:
|
37
97
|
"""
|
38
|
-
:param module:
|
39
|
-
:return: device
|
98
|
+
:param module: module to check
|
99
|
+
:return: device module is loaded onto during forward pass
|
40
100
|
"""
|
41
|
-
if
|
101
|
+
if has_offloaded_params(module):
|
42
102
|
return module._hf_hook.execution_device
|
43
103
|
device = next(module.parameters()).device
|
44
104
|
|
@@ -49,68 +109,296 @@ def get_execution_device(module: Module) -> torch.device:
|
|
49
109
|
return device
|
50
110
|
|
51
111
|
|
52
|
-
def get_offloaded_device(module: Module) -> torch.device:
|
112
|
+
def get_offloaded_device(module: torch.nn.Module) -> torch.device:
    """
    Find where a module's weights are kept between forward passes.

    :param module: module to check
    :return: for an offloaded module, the device of the first tensor referenced by
        its hook's weights map; otherwise the device of the module's first parameter
    """
    if not has_offloaded_params(module):
        return next(module.parameters()).device

    weights_map = module._hf_hook.weights_map
    first_key = list(weights_map.keys())[0]
    return weights_map.dataset[first_key].device
|
62
122
|
|
63
123
|
|
64
|
-
|
124
|
+
@check_accelerate(fallback=None)
def update_prefix_dict(module: torch.nn.Module, key: str, data: torch.Tensor):
    """
    Overwrite the offloaded copy of parameter `key` with `data`.

    Parameter updates made to an offloaded module do not persist automatically
    between loads, so the offloaded state dict has to be written explicitly. Only
    the offloaded state dict is touched; the currently loaded module is left as is.

    :param module: module containing the parameter to update
    :param key: name of parameter to update
    :param data: tensor to update parameter with in the offloaded state dict
    :raises ValueError: if `module` is not offloaded
    """
    if has_offloaded_params(module):
        offload_to_weights_map(module._hf_hook.weights_map, key, data)
        return

    raise ValueError("Prefix dict is only applicable to offloaded modules")
|
79
141
|
|
80
142
|
|
81
143
|
def update_parameter_data(
    module: torch.nn.Module, new_param_data: torch.Tensor, param_name: str
):
    """
    Legacy entry point that takes its arguments in ``(module, data, name)`` order
    and delegates to `update_offload_parameter`. It works for parameters of both
    offloaded and non-offloaded modules.

    :param module: module containing the parameter to update
    :param new_param_data: tensor to update parameter with
    :param param_name: name of module parameter to update
    """
    update_offload_parameter(module=module, name=param_name, data=new_param_data)
|
155
|
+
|
156
|
+
|
157
|
+
""" Candidates for Upstreaming """
|
158
|
+
|
159
|
+
|
160
|
+
def register_offload_parameter(
    module: torch.nn.Module,
    name: str,
    parameter: torch.nn.Parameter,
    offload_device: Optional[Union[torch.device, Literal["disk"]]] = None,
):
    """
    Add a new parameter to a module that may be offloaded.

    The parameter is always registered on `module`. If the module is offloaded, the
    parameter's data is also written to the hook's weights map. When none of the
    module's pre-existing parameters lives off the meta device, the new parameter is
    moved to meta as well, so it matches the rest of the module.

    :param module: maybe offloaded module
    :param name: name of newly registered parameter
    :param parameter: parameter being registered
    :param offload_device: device on which weight will be offloaded to. If None is
        provided, then infer device from parameters on module
    """
    # inspect existing parameters *before* registering, so the new one
    # does not influence the decision
    meta = torch.device("meta")
    onloaded = not all(p.device == meta for p in module.parameters())

    module.register_parameter(name, parameter)

    if not has_offloaded_params(module):
        return

    offload_to_weights_map(
        module._hf_hook.weights_map, name, parameter.data, offload_device
    )
    if not onloaded:
        set_module_tensor_to_device(module, name, "meta")
|
183
|
+
|
184
|
+
|
185
|
+
def update_offload_parameter(
    module: torch.nn.Module,
    name: str,
    data: Optional[torch.Tensor],
    offload_device: Optional[Union[torch.device, Literal["disk"]]] = None,
):
    """
    Update the data of an existing parameter and its offload dict. Supports both
    parameters of offloaded modules and non-offloaded modules

    :param module: module containing the parameter to update
    :param name: name of module parameter to update
    :param data: tensor to update parameter with; it is cast to the existing
        parameter's dtype before being written anywhere
    :param offload_device: device on which weight will be offloaded to. If None is
        provided, then infer device from parameters on module
    """
    param = getattr(module, name)
    data = data.to(param.dtype)

    # copy data into onloaded parameter if applicable. Compare against a
    # `torch.device`: a `torch.device` never compares equal to the plain string
    # "meta", so a string comparison would always attempt to copy into meta tensors
    if param.device != torch.device("meta"):
        param.data.copy_(data)

    # update offload dict
    if has_offloaded_params(module):
        weights_map = module._hf_hook.weights_map
        offload_to_weights_map(weights_map, name, data, offload_device)
|
212
|
+
|
213
|
+
|
214
|
+
def delete_offload_parameter(module: torch.nn.Module, name: str):
    """
    Remove a parameter from a module that may be offloaded.

    The attribute is always deleted from `module`. For offloaded modules, the
    matching entry is also removed from the hook's weights map.

    :param module: maybe offloaded module
    :param name: name of parameter being deleted
    """
    delattr(module, name)

    if not has_offloaded_params(module):
        return

    delete_from_weights_map(module._hf_hook.weights_map, name)
|
96
226
|
|
97
|
-
device = next(module.parameters()).device
|
98
227
|
|
99
|
-
|
100
|
-
|
101
|
-
|
102
|
-
|
228
|
+
@check_accelerate(fallback=contextlib.nullcontext())
@contextlib.contextmanager
def disable_hf_hook(module: torch.nn.Module):
    """
    Context manager that temporarily detaches every accelerate hook (`_hf_hook`)
    from `module` and all of its submodules, and re-attaches the same hook objects
    on exit.

    Hooks are restored in a `finally` clause. An exception raised inside the
    `with` body, or while hooks are being collected, therefore cannot leave the
    model permanently un-hooked.

    :param module: root module whose hooks (and those of its submodules) are removed
    """
    hooks = {}

    def collect_hooks(submodule: torch.nn.Module):
        if hasattr(submodule, "_hf_hook"):
            hooks[submodule] = submodule._hf_hook
            remove_hook_from_module(submodule)

    try:
        module.apply(collect_hooks)
        yield
    finally:
        for submodule, hook in hooks.items():
            add_hook_to_module(submodule, hook)
|
245
|
+
|
246
|
+
|
247
|
+
@check_accelerate(fallback=None)
def offload_to_weights_map(
    weights_map: Union[PrefixedDataset, Dict, OffloadedWeightsLoader],
    key: str,
    value: torch.Tensor,
    offload_device: Optional[Union[torch.device, Literal["disk"]]] = None,
):
    """
    Helper function which implements offloaded item assignment for PrefixedDataset,
    OffloadedWeightsLoader, and Dict types.

    :param weights_map: weight map to be updated with offload information
    :param key: key used to identify weight location
    :param value: weight being offloaded
    :param offload_device: device on which weight will be offloaded to. If None is
        provided, then infer device from parameters in weights_map
    :raises ValueError: if "disk" is requested for a PrefixedDataset or dict, or if
        the device must be inferred from an empty dict
    :raises NotImplementedError: for disk-backed OffloadedWeightsLoaders or any
        other weights_map type
    """
    if isinstance(weights_map, PrefixedDataset):
        if offload_device == "disk":
            raise ValueError(f"Cannot offload to disk with type {type(weights_map)}")

        # write through to the underlying dataset under the fully-prefixed key
        dataset = weights_map.dataset
        key = f"{weights_map.prefix}{key}"
        offload_to_weights_map(dataset, key, value, offload_device)

    elif isinstance(weights_map, OffloadedWeightsLoader):
        if len(weights_map.index) > 0 or offload_device == "disk":
            raise NotImplementedError(
                "Updating weights_map with disk offloading is not implemented yet"
            )

        offload_to_weights_map(weights_map.state_dict, key, value, offload_device)
        # only register the key once the value has actually been stored, so a
        # failed write does not leave a key without a backing tensor
        if key not in weights_map.all_keys:
            weights_map.all_keys.append(key)

    elif isinstance(weights_map, dict):
        if offload_device == "disk":
            raise ValueError(f"Cannot offload to disk with type {type(weights_map)}")

        # infer offload device
        if offload_device is None:
            if key in weights_map:
                offload_device = weights_map[key].device
            else:
                tens = next(iter(weights_map.values()), None)
                if tens is None:
                    raise ValueError(
                        "Cannot infer offload device from empty weights_map"
                    )
                offload_device = tens.device

        weights_map[key] = value.to(device=offload_device)

    else:
        raise NotImplementedError(
            "Updating offload data not implemented for weights_map of type "
            f"{type(weights_map)}"
        )
|
307
|
+
|
308
|
+
|
309
|
+
@check_accelerate(fallback=None)
def delete_from_weights_map(
    weights_map: Union[PrefixedDataset, Dict, OffloadedWeightsLoader],
    key: str,
):
    """
    Helper function which implements offloaded item deletion for PrefixedDataset,
    OffloadedWeightsLoader, and Dict types.

    :param weights_map: weight map to remove the entry from
    :param key: key used to identify weight location
    :raises KeyError: if `key` is not present in the underlying dict
    :raises NotImplementedError: for disk-backed OffloadedWeightsLoaders or any
        other weights_map type
    """
    if isinstance(weights_map, PrefixedDataset):
        # delete from the underlying dataset under the fully-prefixed key
        dataset = weights_map.dataset
        key = f"{weights_map.prefix}{key}"
        delete_from_weights_map(dataset, key)

    elif isinstance(weights_map, OffloadedWeightsLoader):
        if len(weights_map.index) > 0:
            raise NotImplementedError(
                "Delete from weights_map with disk offloading is not implemented yet"
            )

        delete_from_weights_map(weights_map.state_dict, key)
        # keep `all_keys` consistent with the stored tensors: offload_to_weights_map
        # appends keys to it, so deletions must remove them again
        if key in weights_map.all_keys:
            weights_map.all_keys.remove(key)

    elif isinstance(weights_map, dict):
        del weights_map[key]

    else:
        raise NotImplementedError(
            "Deleting offload data not implemented for weights_map of type "
            f"{type(weights_map)}"
        )
|
336
|
+
|
337
|
+
|
338
|
+
""" Upstreamed Functions """
|
339
|
+
|
340
|
+
|
341
|
+
# introduced in accelerate v1.1.0
|
342
|
+
@check_accelerate(fallback=False)
def has_offloaded_params(module: torch.nn.Module) -> bool:
    """
    Report whether `module` carries an accelerate `AlignDevicesHook` with
    offloading enabled. Mirrors the helper of the same name from accelerate v1.1.0.

    Args:
        module (`torch.nn.Module`): The module to check for an offload hook.

    Returns:
        bool: the hook's `offload` flag if an `AlignDevicesHook` is attached,
            `False` otherwise (and `False` when accelerate is unavailable).
    """
    hook = getattr(module, "_hf_hook", None)
    if not isinstance(hook, AlignDevicesHook):
        return False
    return hook.offload
|
360
|
+
|
361
|
+
|
362
|
+
# introduced in accelerate v1.1.0
|
363
|
+
@check_accelerate(fallback=contextlib.nullcontext())
@contextlib.contextmanager
def align_module_device(
    module: torch.nn.Module, execution_device: Optional[torch.device] = None
):
    """
    Context manager that places a module's parameters on its execution device for
    the duration of the `with` block.

    Args:
        module (`torch.nn.Module`):
            Module with parameters to align.
        execution_device (`torch.device`, *optional*):
            If provided, overrides the module's execution device within the context.
            Otherwise, use hook execution device or pass

    Offloaded modules are onloaded via their hook's `pre_forward` and released via
    `post_forward`. Non-offloaded modules have their direct parameters moved to
    `execution_device` (if given) and moved back afterwards.
    """
    if not has_offloaded_params(module):
        if execution_device is None:
            # nothing to align
            yield
            return

        # remember where each direct parameter lives so it can be restored
        original_devices = {}
        for param_name, param in module.named_parameters(recurse=False):
            original_devices[param_name] = param.device

        try:
            for param_name in original_devices:
                set_module_tensor_to_device(module, param_name, execution_device)
            yield
        finally:
            for param_name, original in original_devices.items():
                set_module_tensor_to_device(module, param_name, original)
        return

    override = execution_device is not None
    if override:
        saved_device = module._hf_hook.execution_device
        module._hf_hook.execution_device = execution_device

    try:
        module._hf_hook.pre_forward(module)
        yield
    finally:
        module._hf_hook.post_forward(module, None)
        if override:
            module._hf_hook.execution_device = saved_device
|
@@ -1,6 +1,6 @@
|
|
1
1
|
Metadata-Version: 2.1
|
2
2
|
Name: compressed-tensors-nightly
|
3
|
-
Version: 0.8.1.
|
3
|
+
Version: 0.8.1.20241225
|
4
4
|
Summary: Library for utilization of compressed safetensors of neural network models
|
5
5
|
Home-page: https://github.com/neuralmagic/compressed-tensors
|
6
6
|
Author: Neuralmagic, Inc.
|
@@ -31,20 +31,20 @@ compressed_tensors/quantization/lifecycle/apply.py,sha256=jCUSgeOBtagE5IhgIbyYMZ
|
|
31
31
|
compressed_tensors/quantization/lifecycle/compressed.py,sha256=Fj9n66IN0EWsOAkBHg3O0GlOQpxstqjCcs0ttzMXrJ0,2296
|
32
32
|
compressed_tensors/quantization/lifecycle/forward.py,sha256=QPL6-vKOFuKdKIEsVqMhsw4x552Jpm2sqO0oeChbnrM,12941
|
33
33
|
compressed_tensors/quantization/lifecycle/helpers.py,sha256=C0mhy2vJ0fCjVeN4kFNhw8Eq1wkteBGHiZ36RVLThRY,944
|
34
|
-
compressed_tensors/quantization/lifecycle/initialize.py,sha256=
|
34
|
+
compressed_tensors/quantization/lifecycle/initialize.py,sha256=hymYtayTSumm8KCYAYPY267aWmlsJpt8oQFiRblk8qE,7452
|
35
35
|
compressed_tensors/quantization/utils/__init__.py,sha256=VdtEmP0bvuND_IGQnyqUPc5lnFp-1_yD7StKSX4x80w,656
|
36
36
|
compressed_tensors/quantization/utils/helpers.py,sha256=DBP-sGRpGAY01K0LFE7qqonNj4hkTYL_mXrMs2LtAD8,14100
|
37
37
|
compressed_tensors/registry/__init__.py,sha256=FwLSNYqfIrb5JD_6OK_MT4_svvKTN_nEhpgQlQvGbjI,658
|
38
38
|
compressed_tensors/registry/registry.py,sha256=vRcjVB1ITfSbfYUaGndBBmqhip_5vsS62weorVg0iXo,11896
|
39
39
|
compressed_tensors/utils/__init__.py,sha256=gS4gSU2pwcAbsKj-6YMaqhm25udFy6ISYaWBf-myRSM,808
|
40
|
-
compressed_tensors/utils/helpers.py,sha256=
|
41
|
-
compressed_tensors/utils/offload.py,sha256=
|
40
|
+
compressed_tensors/utils/helpers.py,sha256=XF36-SLkXnAHh0VzbvUlAdh6a88aCQvS_WeYs9Lfio8,6827
|
41
|
+
compressed_tensors/utils/offload.py,sha256=cMmzd9IdlNbs29CReHj1PPSLUM6OWaT5YumlLT5eP3w,13845
|
42
42
|
compressed_tensors/utils/permutations_24.py,sha256=kx6fsfDHebx94zsSzhXGyCyuC9sVyah6BUUir_StT28,2530
|
43
43
|
compressed_tensors/utils/permute.py,sha256=V6tJLKo3Syccj-viv4F7ZKZgJeCB-hl-dK8RKI_kBwI,2355
|
44
44
|
compressed_tensors/utils/safetensors_load.py,sha256=m08ANVuTBxQdoa6LufDgcNJ7wCLDJolyZljB8VEybAU,8578
|
45
45
|
compressed_tensors/utils/semi_structured_conversions.py,sha256=XKNffPum54kPASgqKzgKvyeqWPAkair2XEQXjkp7ho8,13489
|
46
|
-
compressed_tensors_nightly-0.8.1.
|
47
|
-
compressed_tensors_nightly-0.8.1.
|
48
|
-
compressed_tensors_nightly-0.8.1.
|
49
|
-
compressed_tensors_nightly-0.8.1.
|
50
|
-
compressed_tensors_nightly-0.8.1.
|
46
|
+
compressed_tensors_nightly-0.8.1.20241225.dist-info/LICENSE,sha256=xx0jnfkXJvxRnG63LTGOxlggYnIysveWIZ6H3PNdCrQ,11357
|
47
|
+
compressed_tensors_nightly-0.8.1.20241225.dist-info/METADATA,sha256=_Dh-8bZ7fT6iBH4JWMABCkMExAOjA4p6N0OJ_vyDwps,6799
|
48
|
+
compressed_tensors_nightly-0.8.1.20241225.dist-info/WHEEL,sha256=tZoeGjtWxWRfdplE7E3d45VPlLNQnvbKiYnx7gwAy8A,92
|
49
|
+
compressed_tensors_nightly-0.8.1.20241225.dist-info/top_level.txt,sha256=w2i-GyPs2s1UwVxvutSvN_lM22SXC2hQFBmoMcPnV7Y,19
|
50
|
+
compressed_tensors_nightly-0.8.1.20241225.dist-info/RECORD,,
|
File without changes
|
File without changes
|