compressed-tensors-nightly 0.8.1.20241220__tar.gz → 0.8.1.20241223__tar.gz

Sign up to get free protection for your applications and to get access to all the features.
Files changed (57) hide show
  1. {compressed-tensors-nightly-0.8.1.20241220/src/compressed_tensors_nightly.egg-info → compressed-tensors-nightly-0.8.1.20241223}/PKG-INFO +1 -1
  2. {compressed-tensors-nightly-0.8.1.20241220 → compressed-tensors-nightly-0.8.1.20241223}/src/compressed_tensors/quantization/lifecycle/initialize.py +17 -44
  3. {compressed-tensors-nightly-0.8.1.20241220 → compressed-tensors-nightly-0.8.1.20241223}/src/compressed_tensors/utils/helpers.py +64 -1
  4. compressed-tensors-nightly-0.8.1.20241223/src/compressed_tensors/utils/offload.py +404 -0
  5. {compressed-tensors-nightly-0.8.1.20241220 → compressed-tensors-nightly-0.8.1.20241223/src/compressed_tensors_nightly.egg-info}/PKG-INFO +1 -1
  6. compressed-tensors-nightly-0.8.1.20241220/src/compressed_tensors/utils/offload.py +0 -116
  7. {compressed-tensors-nightly-0.8.1.20241220 → compressed-tensors-nightly-0.8.1.20241223}/LICENSE +0 -0
  8. {compressed-tensors-nightly-0.8.1.20241220 → compressed-tensors-nightly-0.8.1.20241223}/README.md +0 -0
  9. {compressed-tensors-nightly-0.8.1.20241220 → compressed-tensors-nightly-0.8.1.20241223}/pyproject.toml +0 -0
  10. {compressed-tensors-nightly-0.8.1.20241220 → compressed-tensors-nightly-0.8.1.20241223}/setup.cfg +0 -0
  11. {compressed-tensors-nightly-0.8.1.20241220 → compressed-tensors-nightly-0.8.1.20241223}/setup.py +0 -0
  12. {compressed-tensors-nightly-0.8.1.20241220 → compressed-tensors-nightly-0.8.1.20241223}/src/compressed_tensors/__init__.py +0 -0
  13. {compressed-tensors-nightly-0.8.1.20241220 → compressed-tensors-nightly-0.8.1.20241223}/src/compressed_tensors/base.py +0 -0
  14. {compressed-tensors-nightly-0.8.1.20241220 → compressed-tensors-nightly-0.8.1.20241223}/src/compressed_tensors/compressors/__init__.py +0 -0
  15. {compressed-tensors-nightly-0.8.1.20241220 → compressed-tensors-nightly-0.8.1.20241223}/src/compressed_tensors/compressors/base.py +0 -0
  16. {compressed-tensors-nightly-0.8.1.20241220 → compressed-tensors-nightly-0.8.1.20241223}/src/compressed_tensors/compressors/helpers.py +0 -0
  17. {compressed-tensors-nightly-0.8.1.20241220 → compressed-tensors-nightly-0.8.1.20241223}/src/compressed_tensors/compressors/model_compressors/__init__.py +0 -0
  18. {compressed-tensors-nightly-0.8.1.20241220 → compressed-tensors-nightly-0.8.1.20241223}/src/compressed_tensors/compressors/model_compressors/model_compressor.py +0 -0
  19. {compressed-tensors-nightly-0.8.1.20241220 → compressed-tensors-nightly-0.8.1.20241223}/src/compressed_tensors/compressors/quantized_compressors/__init__.py +0 -0
  20. {compressed-tensors-nightly-0.8.1.20241220 → compressed-tensors-nightly-0.8.1.20241223}/src/compressed_tensors/compressors/quantized_compressors/base.py +0 -0
  21. {compressed-tensors-nightly-0.8.1.20241220 → compressed-tensors-nightly-0.8.1.20241223}/src/compressed_tensors/compressors/quantized_compressors/naive_quantized.py +0 -0
  22. {compressed-tensors-nightly-0.8.1.20241220 → compressed-tensors-nightly-0.8.1.20241223}/src/compressed_tensors/compressors/quantized_compressors/pack_quantized.py +0 -0
  23. {compressed-tensors-nightly-0.8.1.20241220 → compressed-tensors-nightly-0.8.1.20241223}/src/compressed_tensors/compressors/sparse_compressors/__init__.py +0 -0
  24. {compressed-tensors-nightly-0.8.1.20241220 → compressed-tensors-nightly-0.8.1.20241223}/src/compressed_tensors/compressors/sparse_compressors/base.py +0 -0
  25. {compressed-tensors-nightly-0.8.1.20241220 → compressed-tensors-nightly-0.8.1.20241223}/src/compressed_tensors/compressors/sparse_compressors/dense.py +0 -0
  26. {compressed-tensors-nightly-0.8.1.20241220 → compressed-tensors-nightly-0.8.1.20241223}/src/compressed_tensors/compressors/sparse_compressors/sparse_bitmask.py +0 -0
  27. {compressed-tensors-nightly-0.8.1.20241220 → compressed-tensors-nightly-0.8.1.20241223}/src/compressed_tensors/compressors/sparse_quantized_compressors/__init__.py +0 -0
  28. {compressed-tensors-nightly-0.8.1.20241220 → compressed-tensors-nightly-0.8.1.20241223}/src/compressed_tensors/compressors/sparse_quantized_compressors/marlin_24.py +0 -0
  29. {compressed-tensors-nightly-0.8.1.20241220 → compressed-tensors-nightly-0.8.1.20241223}/src/compressed_tensors/config/__init__.py +0 -0
  30. {compressed-tensors-nightly-0.8.1.20241220 → compressed-tensors-nightly-0.8.1.20241223}/src/compressed_tensors/config/base.py +0 -0
  31. {compressed-tensors-nightly-0.8.1.20241220 → compressed-tensors-nightly-0.8.1.20241223}/src/compressed_tensors/config/dense.py +0 -0
  32. {compressed-tensors-nightly-0.8.1.20241220 → compressed-tensors-nightly-0.8.1.20241223}/src/compressed_tensors/config/sparse_bitmask.py +0 -0
  33. {compressed-tensors-nightly-0.8.1.20241220 → compressed-tensors-nightly-0.8.1.20241223}/src/compressed_tensors/linear/__init__.py +0 -0
  34. {compressed-tensors-nightly-0.8.1.20241220 → compressed-tensors-nightly-0.8.1.20241223}/src/compressed_tensors/linear/compressed_linear.py +0 -0
  35. {compressed-tensors-nightly-0.8.1.20241220 → compressed-tensors-nightly-0.8.1.20241223}/src/compressed_tensors/quantization/__init__.py +0 -0
  36. {compressed-tensors-nightly-0.8.1.20241220 → compressed-tensors-nightly-0.8.1.20241223}/src/compressed_tensors/quantization/lifecycle/__init__.py +0 -0
  37. {compressed-tensors-nightly-0.8.1.20241220 → compressed-tensors-nightly-0.8.1.20241223}/src/compressed_tensors/quantization/lifecycle/apply.py +0 -0
  38. {compressed-tensors-nightly-0.8.1.20241220 → compressed-tensors-nightly-0.8.1.20241223}/src/compressed_tensors/quantization/lifecycle/compressed.py +0 -0
  39. {compressed-tensors-nightly-0.8.1.20241220 → compressed-tensors-nightly-0.8.1.20241223}/src/compressed_tensors/quantization/lifecycle/forward.py +0 -0
  40. {compressed-tensors-nightly-0.8.1.20241220 → compressed-tensors-nightly-0.8.1.20241223}/src/compressed_tensors/quantization/lifecycle/helpers.py +0 -0
  41. {compressed-tensors-nightly-0.8.1.20241220 → compressed-tensors-nightly-0.8.1.20241223}/src/compressed_tensors/quantization/quant_args.py +0 -0
  42. {compressed-tensors-nightly-0.8.1.20241220 → compressed-tensors-nightly-0.8.1.20241223}/src/compressed_tensors/quantization/quant_config.py +0 -0
  43. {compressed-tensors-nightly-0.8.1.20241220 → compressed-tensors-nightly-0.8.1.20241223}/src/compressed_tensors/quantization/quant_scheme.py +0 -0
  44. {compressed-tensors-nightly-0.8.1.20241220 → compressed-tensors-nightly-0.8.1.20241223}/src/compressed_tensors/quantization/utils/__init__.py +0 -0
  45. {compressed-tensors-nightly-0.8.1.20241220 → compressed-tensors-nightly-0.8.1.20241223}/src/compressed_tensors/quantization/utils/helpers.py +0 -0
  46. {compressed-tensors-nightly-0.8.1.20241220 → compressed-tensors-nightly-0.8.1.20241223}/src/compressed_tensors/registry/__init__.py +0 -0
  47. {compressed-tensors-nightly-0.8.1.20241220 → compressed-tensors-nightly-0.8.1.20241223}/src/compressed_tensors/registry/registry.py +0 -0
  48. {compressed-tensors-nightly-0.8.1.20241220 → compressed-tensors-nightly-0.8.1.20241223}/src/compressed_tensors/utils/__init__.py +0 -0
  49. {compressed-tensors-nightly-0.8.1.20241220 → compressed-tensors-nightly-0.8.1.20241223}/src/compressed_tensors/utils/permutations_24.py +0 -0
  50. {compressed-tensors-nightly-0.8.1.20241220 → compressed-tensors-nightly-0.8.1.20241223}/src/compressed_tensors/utils/permute.py +0 -0
  51. {compressed-tensors-nightly-0.8.1.20241220 → compressed-tensors-nightly-0.8.1.20241223}/src/compressed_tensors/utils/safetensors_load.py +0 -0
  52. {compressed-tensors-nightly-0.8.1.20241220 → compressed-tensors-nightly-0.8.1.20241223}/src/compressed_tensors/utils/semi_structured_conversions.py +0 -0
  53. {compressed-tensors-nightly-0.8.1.20241220 → compressed-tensors-nightly-0.8.1.20241223}/src/compressed_tensors/version.py +0 -0
  54. {compressed-tensors-nightly-0.8.1.20241220 → compressed-tensors-nightly-0.8.1.20241223}/src/compressed_tensors_nightly.egg-info/SOURCES.txt +0 -0
  55. {compressed-tensors-nightly-0.8.1.20241220 → compressed-tensors-nightly-0.8.1.20241223}/src/compressed_tensors_nightly.egg-info/dependency_links.txt +0 -0
  56. {compressed-tensors-nightly-0.8.1.20241220 → compressed-tensors-nightly-0.8.1.20241223}/src/compressed_tensors_nightly.egg-info/requires.txt +0 -0
  57. {compressed-tensors-nightly-0.8.1.20241220 → compressed-tensors-nightly-0.8.1.20241223}/src/compressed_tensors_nightly.egg-info/top_level.txt +0 -0
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.1
2
2
  Name: compressed-tensors-nightly
3
- Version: 0.8.1.20241220
3
+ Version: 0.8.1.20241223
4
4
  Summary: Library for utilization of compressed safetensors of neural network models
5
5
  Home-page: https://github.com/neuralmagic/compressed-tensors
6
6
  Author: Neuralmagic, Inc.
@@ -29,7 +29,11 @@ from compressed_tensors.quantization.quant_args import (
29
29
  from compressed_tensors.quantization.quant_config import QuantizationStatus
30
30
  from compressed_tensors.quantization.quant_scheme import QuantizationScheme
31
31
  from compressed_tensors.quantization.utils import is_kv_cache_quant_scheme
32
- from compressed_tensors.utils import get_execution_device, is_module_offloaded
32
+ from compressed_tensors.utils import (
33
+ disable_hf_hook,
34
+ has_offloaded_params,
35
+ register_offload_parameter,
36
+ )
33
37
  from torch.nn import Module, Parameter
34
38
 
35
39
 
@@ -112,43 +116,10 @@ def initialize_module_for_quantization(
112
116
  module.quantization_scheme = scheme
113
117
  module.quantization_status = QuantizationStatus.INITIALIZED
114
118
 
115
- offloaded = False
116
- # What is this doing/why isn't this in the attn case?
117
- if is_module_offloaded(module):
118
- try:
119
- from accelerate.hooks import add_hook_to_module, remove_hook_from_module
120
- from accelerate.utils import PrefixedDataset
121
- except ModuleNotFoundError:
122
- raise ModuleNotFoundError(
123
- "Offloaded model detected. To use CPU offloading with "
124
- "compressed-tensors the `accelerate` package must be installed, "
125
- "run `pip install compressed-tensors[accelerate]`"
126
- )
127
-
128
- offloaded = True
129
- hook = module._hf_hook
130
- prefix_dict = module._hf_hook.weights_map
131
- new_prefix = {}
132
-
133
- # recreate the prefix dict (since it is immutable)
134
- # and add quantization parameters
135
- for key, data in module.named_parameters():
136
- if key not in prefix_dict:
137
- new_prefix[f"{prefix_dict.prefix}{key}"] = data
138
- else:
139
- new_prefix[f"{prefix_dict.prefix}{key}"] = prefix_dict[key]
140
- new_prefix_dict = PrefixedDataset(new_prefix, prefix_dict.prefix)
141
- remove_hook_from_module(module)
142
-
143
- # wrap forward call of module to perform
144
- # quantized actions based on calltime status
145
- wrap_module_forward_quantized(module, scheme)
146
-
147
- if offloaded:
148
- # we need to re-add the hook for offloading now that we've wrapped forward
149
- add_hook_to_module(module, hook)
150
- if prefix_dict is not None:
151
- module._hf_hook.weights_map = new_prefix_dict
119
+ with disable_hf_hook(module):
120
+ # wrap forward call of module to perform
121
+ # quantized actions based on calltime status
122
+ wrap_module_forward_quantized(module, scheme)
152
123
 
153
124
 
154
125
  def is_attention_module(module: Module):
@@ -169,9 +140,11 @@ def _initialize_scale_zero_point(
169
140
  if quantization_args.dynamic:
170
141
  return
171
142
 
172
- device = next(module.parameters()).device
173
- if is_module_offloaded(module):
174
- device = get_execution_device(module)
143
+ # begin on the same device as other parameters or cpu if offloaded.
144
+ # in the offloaded case, there's no point moving tensors to the execution device
145
+ # if they're going to be immediately offloaded by `register_offload_parameter`
146
+ params_device = next(module.parameters()).device
147
+ device = "cpu" if has_offloaded_params(module) else params_device
175
148
 
176
149
  # infer expected scale/zero point shape
177
150
  if quantization_args.strategy == QuantizationStrategy.TOKEN:
@@ -196,7 +169,7 @@ def _initialize_scale_zero_point(
196
169
  torch.empty(expected_shape, dtype=scale_dtype, device=device),
197
170
  requires_grad=False,
198
171
  )
199
- module.register_parameter(f"{base_name}_scale", init_scale)
172
+ register_offload_parameter(module, f"{base_name}_scale", init_scale)
200
173
 
201
174
  if force_zero_point or not quantization_args.symmetric:
202
175
  zp_dtype = quantization_args.pytorch_dtype()
@@ -204,7 +177,7 @@ def _initialize_scale_zero_point(
204
177
  torch.zeros(expected_shape, device=device, dtype=zp_dtype),
205
178
  requires_grad=False,
206
179
  )
207
- module.register_parameter(f"{base_name}_zero_point", init_zero_point)
180
+ register_offload_parameter(module, f"{base_name}_zero_point", init_zero_point)
208
181
 
209
182
  # only grouped activation ordering has g_idx
210
183
  if quantization_args.actorder == ActivationOrdering.GROUP:
@@ -214,7 +187,7 @@ def _initialize_scale_zero_point(
214
187
  torch.full(g_idx_shape, -1, device=device, dtype=g_idx_dtype),
215
188
  requires_grad=False,
216
189
  )
217
- module.register_parameter(f"{base_name}_g_idx", init_g_idx)
190
+ register_offload_parameter(module, f"{base_name}_g_idx", init_g_idx)
218
191
 
219
192
 
220
193
  def _initialize_attn_scales(module: Module) -> None:
@@ -12,7 +12,9 @@
12
12
  # See the License for the specific language governing permissions and
13
13
  # limitations under the License.
14
14
 
15
- from typing import Any, Dict, Optional
15
+ import warnings
16
+ from functools import wraps
17
+ from typing import Any, Callable, Dict, Optional
16
18
 
17
19
  import torch
18
20
  from transformers import AutoConfig
@@ -24,6 +26,8 @@ __all__ = [
24
26
  "tensor_follows_mask_structure",
25
27
  "replace_module",
26
28
  "is_compressed_tensors_config",
29
+ "getattr_chain",
30
+ "deprecated",
27
31
  "Aliasable",
28
32
  ]
29
33
 
@@ -122,6 +126,65 @@ def is_compressed_tensors_config(compression_config: Any) -> bool:
122
126
  return False
123
127
 
124
128
 
129
+ def getattr_chain(obj: Any, chain_str: str, *args, **kwargs) -> Any:
130
+ """
131
+ Chain multiple getattr calls, separated by `.`
132
+
133
+ :param obj: base object whose attributes are being retrieved
134
+ :param chain_str: attribute names separated by `.`
135
+ :param default: default value, throw error otherwise
136
+ """
137
+ if len(args) >= 1:
138
+ has_default = True
139
+ default = args[0]
140
+ elif "default" in kwargs:
141
+ has_default = True
142
+ default = kwargs["default"]
143
+ else:
144
+ has_default = False
145
+
146
+ attr_names = chain_str.split(".")
147
+
148
+ res = obj
149
+ for attr_name in attr_names:
150
+ if not hasattr(res, attr_name):
151
+ if has_default:
152
+ return default
153
+ else:
154
+ raise AttributeError(f"{res} object has no attribute {attr_name}")
155
+ res = getattr(res, attr_name)
156
+
157
+ return res
158
+
159
+
160
+ def deprecated(future_name: Optional[str] = None, message: Optional[str] = None):
161
+ """
162
+ Decorator to mark functions as deprecated
163
+
164
+ :param new_function: Function called in place of depreciated function
165
+ :param message: Depreciation message, replaces default depreciation message
166
+ """
167
+
168
+ def decorator(func: Callable[[Any], Any]):
169
+ nonlocal message
170
+
171
+ if message is None:
172
+ message = (
173
+ f"{func.__name__} is deprecated and will be removed in a future release"
174
+ )
175
+ if future_name is not None:
176
+ message += f". Please use {future_name} instead."
177
+
178
+ @wraps(func)
179
+ def wrapped(*args, **kwargs):
180
+ warnings.warn(message, DeprecationWarning, stacklevel=2)
181
+ return func(*args, **kwargs)
182
+
183
+ return wrapped
184
+
185
+ return decorator
186
+
187
+
125
188
  class Aliasable:
126
189
  """
127
190
  A mixin for enums to allow aliasing of enum members
@@ -0,0 +1,404 @@
1
+ # Copyright (c) 2021 - present / Neuralmagic, Inc. All Rights Reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing,
10
+ # software distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ """
15
+ Utilities associated with offloading functionality provided by `accelerate`.
16
+
17
+ | ----------------------------------------------------------------------------------------------------- | # noqa: E501
18
+ | Operation | Without offloading support | With offloading support | # noqa: E501
19
+ | --------- | -------------------------------------- | ------------------------------------------------ | # noqa: E501
20
+ | Add | module.register_parameter(name, param) | register_offload_parameter(module, name, param) | # noqa: E501
21
+ | Check | N/A | has_offloaded_params(module) | # noqa: E501
22
+ | Onload | N/A | with align_module_device(module) | # noqa: E501
23
+ | Update | module.name.data.copy_(new_data) | update_offload_parameter(module, name, new_data) | # noqa: E501
24
+ | Delete | del module.name | delete_offload_parameter(module, name) | # noqa: E501
25
+ | ----------------------------------------------------------------------------------------------------- | # noqa: E501
26
+ """
27
+
28
+ import contextlib
29
+ from functools import wraps
30
+ from typing import Any, Callable, Dict, Literal, Optional, Union
31
+
32
+ import torch
33
+
34
+
35
+ try:
36
+ from accelerate.hooks import (
37
+ AlignDevicesHook,
38
+ add_hook_to_module,
39
+ remove_hook_from_module,
40
+ )
41
+ from accelerate.utils import (
42
+ OffloadedWeightsLoader,
43
+ PrefixedDataset,
44
+ set_module_tensor_to_device,
45
+ )
46
+
47
+ _has_accelerate = True
48
+ except ImportError:
49
+ _has_accelerate = False
50
+ AlignDevicesHook = None
51
+ add_hook_to_module = None
52
+ remove_hook_from_module = None
53
+ OffloadedWeightsLoader = None
54
+ PrefixedDataset = None
55
+ set_module_tensor_to_device = None
56
+
57
+
58
+ __all__ = [
59
+ "is_module_offloaded",
60
+ "get_execution_device",
61
+ "get_offloaded_device",
62
+ "update_prefix_dict",
63
+ "update_parameter_data",
64
+ "register_offload_parameter",
65
+ "update_offload_parameter",
66
+ "delete_offload_parameter",
67
+ "has_offloaded_params",
68
+ "disable_hf_hook",
69
+ "align_module_device",
70
+ ]
71
+
72
+
73
+ def check_accelerate(fallback: Any):
74
+ def decorator(func: Callable[[Any], Any]):
75
+ if not _has_accelerate:
76
+
77
+ @wraps(func)
78
+ def fallback_fn(*args, **kwargs):
79
+ return fallback
80
+
81
+ return fallback_fn
82
+
83
+ return func
84
+
85
+ return decorator
86
+
87
+
88
+ """ Candidates for Depreciation """
89
+
90
+
91
+ @check_accelerate(fallback=False)
92
+ def is_module_offloaded(module: torch.nn.Module) -> bool:
93
+ return has_offloaded_params(module)
94
+
95
+
96
+ def get_execution_device(module: torch.nn.Module) -> torch.device:
97
+ """
98
+ :param module: module to check
99
+ :return: device module is loaded onto during forward pass
100
+ """
101
+ if has_offloaded_params(module):
102
+ return module._hf_hook.execution_device
103
+ device = next(module.parameters()).device
104
+
105
+ # offload only gets set for leaf modules, fallback to checking for device type
106
+ if device.type == "meta":
107
+ return module._hf_hook.execution_device
108
+
109
+ return device
110
+
111
+
112
+ def get_offloaded_device(module: torch.nn.Module) -> torch.device:
113
+ """
114
+ :param module: module to check
115
+ :return: device module is offloaded to onto after forward pass
116
+ """
117
+ if has_offloaded_params(module):
118
+ first_key = list(module._hf_hook.weights_map.keys())[0]
119
+ prefix_dataset = module._hf_hook.weights_map.dataset
120
+ return prefix_dataset[first_key].device
121
+ return next(module.parameters()).device
122
+
123
+
124
+ @check_accelerate(fallback=None)
125
+ def update_prefix_dict(module: torch.nn.Module, key: str, data: torch.Tensor):
126
+ """
127
+ Updates the offloaded state dict for a given module. Parameter named key is replaced
128
+ by data. This is neccesary because parameter updates for offloaded modules do not
129
+ persist automatically between loads. This function only affects the offloaded
130
+ state dict and not the current state of the loaded module.
131
+
132
+ :param module: module containing the parameter to update
133
+ :param key: name of parameter to update
134
+ :param data: tensor to update parameter with in the offloaded state dict
135
+ """
136
+ if not has_offloaded_params(module):
137
+ raise ValueError("Prefix dict is only applicable to offloaded modules")
138
+
139
+ weights_map = module._hf_hook.weights_map
140
+ offload_to_weights_map(weights_map, key, data)
141
+
142
+
143
+ def update_parameter_data(
144
+ module: torch.nn.Module, new_param_data: torch.Tensor, param_name: str
145
+ ):
146
+ """
147
+ Update the data of an existing parameter and its offload dict. Supports both
148
+ parameters of offloaded modules and non-offloaded modules
149
+
150
+ :param module: module containing the parameter to update
151
+ :param new_param_data: tensor to update parameter with
152
+ :param param_name: name of module parameter to update
153
+ """
154
+ update_offload_parameter(module, param_name, new_param_data)
155
+
156
+
157
+ """ Candidates for Upstreaming """
158
+
159
+
160
+ def register_offload_parameter(
161
+ module: torch.nn.Module,
162
+ name: str,
163
+ parameter: torch.nn.Parameter,
164
+ offload_device: Optional[Union[torch.device, Literal["disk"]]] = None,
165
+ ):
166
+ """
167
+ Register a parameter to the given module which may be offloaded
168
+
169
+ :param module: maybe offloaded module
170
+ :param name: name of newly registered parameter
171
+ :param parameter: parameter being registered
172
+ :param offload_device: device on which weight will be offloaded to. If None is
173
+ provided, then infer device from parameters on module
174
+ """
175
+ has_onload = any(p.device != torch.device("meta") for p in module.parameters())
176
+ module.register_parameter(name, parameter)
177
+
178
+ if has_offloaded_params(module):
179
+ weights_map = module._hf_hook.weights_map
180
+ offload_to_weights_map(weights_map, name, parameter.data, offload_device)
181
+ if not has_onload:
182
+ set_module_tensor_to_device(module, name, "meta")
183
+
184
+
185
+ def update_offload_parameter(
186
+ module: torch.nn.Module,
187
+ name: str,
188
+ data: Optional[torch.Tensor],
189
+ offload_device: Optional[Union[torch.device, Literal["disk"]]] = None,
190
+ ):
191
+ """
192
+ Update the data of an existing parameter and its offload dict. Supports both
193
+ parameters of offloaded modules and non-offloaded modules
194
+
195
+ :param module: module containing the parameter to update
196
+ :param name: name of module parameter to update
197
+ :param data: tensor to update parameter with
198
+ :param offload_device: device on which weight will be offloaded to. If None is
199
+ provided, then infer device from parameters on module
200
+ """
201
+ param = getattr(module, name)
202
+ data = data.to(param.dtype)
203
+
204
+ # copy data into onloaded parameter if applicable
205
+ if param.device != "meta":
206
+ param.data.copy_(data)
207
+
208
+ # update offload dict
209
+ if has_offloaded_params(module):
210
+ weights_map = module._hf_hook.weights_map
211
+ offload_to_weights_map(weights_map, name, data, offload_device)
212
+
213
+
214
+ def delete_offload_parameter(module: torch.nn.Module, name: str):
215
+ """
216
+ Delete a parameter from a module which may be offloaded
217
+
218
+ :param module: maybe offloaded module
219
+ :param name: name of parameter being deleted
220
+ """
221
+ delattr(module, name)
222
+
223
+ if has_offloaded_params(module):
224
+ weights_map = module._hf_hook.weights_map
225
+ delete_from_weights_map(weights_map, name)
226
+
227
+
228
+ @check_accelerate(fallback=contextlib.nullcontext())
229
+ @contextlib.contextmanager
230
+ def disable_hf_hook(module: torch.nn.Module):
231
+ hooks = {}
232
+
233
+ def collect_hooks(module):
234
+ nonlocal hooks
235
+ if hasattr(module, "_hf_hook"):
236
+ hooks[module] = module._hf_hook
237
+ remove_hook_from_module(module)
238
+
239
+ module.apply(collect_hooks)
240
+
241
+ yield
242
+
243
+ for submodule, hook in hooks.items():
244
+ add_hook_to_module(submodule, hook)
245
+
246
+
247
+ @check_accelerate(fallback=None)
248
+ def offload_to_weights_map(
249
+ weights_map: Union[PrefixedDataset, Dict, OffloadedWeightsLoader],
250
+ key: str,
251
+ value: torch.Tensor,
252
+ offload_device: Optional[Union[torch.device, Literal["disk"]]] = None,
253
+ ):
254
+ """
255
+ Helper function which implements offloaded item assignment for PrefixedDataset,
256
+ OffloadedWeightsLoader, and Dict types.
257
+
258
+ :param weights_map: weight map to be updated with offload information
259
+ :param key: key used to identify weight location
260
+ :param value: weight being offloaded
261
+ :param offload_device: device on which weight will be offloaded to. If None is
262
+ provided, then infer device from parameters in weights_map
263
+ """
264
+ if isinstance(weights_map, PrefixedDataset):
265
+ if offload_device == "disk":
266
+ raise ValueError(f"Cannot offload to disk with type {type(weights_map)}")
267
+
268
+ dataset = weights_map.dataset
269
+ key = f"{weights_map.prefix}{key}"
270
+ offload_to_weights_map(dataset, key, value, offload_device)
271
+
272
+ elif isinstance(weights_map, OffloadedWeightsLoader):
273
+ if key not in weights_map.all_keys:
274
+ weights_map.all_keys.append(key)
275
+
276
+ if len(weights_map.index) <= 0 and offload_device != "disk":
277
+ offload_to_weights_map(weights_map.state_dict, key, value, offload_device)
278
+
279
+ else:
280
+ raise NotImplementedError(
281
+ "Updating weights_map with disk offloading is not implemented yet"
282
+ )
283
+
284
+ elif isinstance(weights_map, dict):
285
+ if offload_device == "disk":
286
+ raise ValueError(f"Cannot offload to disk with type {type(weights_map)}")
287
+
288
+ # infer offload device
289
+ if offload_device is None:
290
+ if key in weights_map:
291
+ offload_device = weights_map[key].device
292
+ else:
293
+ tens = next(iter(weights_map.values()), None)
294
+ if tens is None:
295
+ raise ValueError(
296
+ "Cannot infer offload device from empty weights_map"
297
+ )
298
+ offload_device = tens.device
299
+
300
+ weights_map[key] = value.to(device=offload_device)
301
+
302
+ else:
303
+ raise NotImplementedError(
304
+ "Updating offload data not implemented for weights_map of type "
305
+ f"{type(weights_map)}"
306
+ )
307
+
308
+
309
+ @check_accelerate(fallback=None)
310
+ def delete_from_weights_map(
311
+ weights_map: Union[PrefixedDataset, Dict, OffloadedWeightsLoader],
312
+ key: str,
313
+ ):
314
+ if isinstance(weights_map, PrefixedDataset):
315
+ dataset = weights_map.dataset
316
+ key = f"{weights_map.prefix}{key}"
317
+ delete_from_weights_map(dataset, key)
318
+
319
+ elif isinstance(weights_map, OffloadedWeightsLoader):
320
+ if len(weights_map.index) <= 0:
321
+ delete_from_weights_map(weights_map.state_dict, key)
322
+
323
+ else:
324
+ raise NotImplementedError(
325
+ "Delete from weights_map with disk offloading is not implemented yet"
326
+ )
327
+
328
+ elif isinstance(weights_map, dict):
329
+ del weights_map[key]
330
+
331
+ else:
332
+ raise NotImplementedError(
333
+ "Updating offload data not implemented for weights_map of type "
334
+ f"{type(weights_map)}"
335
+ )
336
+
337
+
338
+ """ Upstreamed Functions """
339
+
340
+
341
+ # introduced in accelerate v1.1.0
342
+ @check_accelerate(fallback=False)
343
+ def has_offloaded_params(module: torch.nn.Module) -> bool:
344
+ """
345
+ Checks if a module has offloaded parameters by checking if the given module has a
346
+ AlignDevicesHook attached with offloading enabled
347
+
348
+ Args:
349
+ module (`torch.nn.Module`): The module to check for an offload hook.
350
+
351
+ Returns:
352
+ bool: `True` if the module has an offload hook and offloading is enabled,
353
+ `False` otherwise.
354
+ """
355
+ return (
356
+ hasattr(module, "_hf_hook")
357
+ and isinstance(module._hf_hook, AlignDevicesHook)
358
+ and module._hf_hook.offload
359
+ )
360
+
361
+
362
+ # introduced in accelerate v1.1.0
363
+ @check_accelerate(fallback=contextlib.nullcontext())
364
+ @contextlib.contextmanager
365
+ def align_module_device(
366
+ module: torch.nn.Module, execution_device: Optional[torch.device] = None
367
+ ):
368
+ """
369
+ Context manager that moves a module's parameters to the specified execution device.
370
+
371
+ Args:
372
+ module (`torch.nn.Module`):
373
+ Module with parameters to align.
374
+ execution_device (`torch.device`, *optional*):
375
+ If provided, overrides the module's execution device within the context.
376
+ Otherwise, use hook execution device or pass
377
+ """
378
+ if has_offloaded_params(module):
379
+ if execution_device is not None:
380
+ original_device = module._hf_hook.execution_device
381
+ module._hf_hook.execution_device = execution_device
382
+
383
+ try:
384
+ module._hf_hook.pre_forward(module)
385
+ yield
386
+ finally:
387
+ module._hf_hook.post_forward(module, None)
388
+ if execution_device is not None:
389
+ module._hf_hook.execution_device = original_device
390
+
391
+ elif execution_device is not None:
392
+ devices = {
393
+ name: param.device for name, param in module.named_parameters(recurse=False)
394
+ }
395
+ try:
396
+ for name in devices:
397
+ set_module_tensor_to_device(module, name, execution_device)
398
+ yield
399
+ finally:
400
+ for name, device in devices.items():
401
+ set_module_tensor_to_device(module, name, device)
402
+
403
+ else:
404
+ yield
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.1
2
2
  Name: compressed-tensors-nightly
3
- Version: 0.8.1.20241220
3
+ Version: 0.8.1.20241223
4
4
  Summary: Library for utilization of compressed safetensors of neural network models
5
5
  Home-page: https://github.com/neuralmagic/compressed-tensors
6
6
  Author: Neuralmagic, Inc.
@@ -1,116 +0,0 @@
1
- # Copyright (c) 2021 - present / Neuralmagic, Inc. All Rights Reserved.
2
- #
3
- # Licensed under the Apache License, Version 2.0 (the "License");
4
- # you may not use this file except in compliance with the License.
5
- # You may obtain a copy of the License at
6
- #
7
- # http://www.apache.org/licenses/LICENSE-2.0
8
- #
9
- # Unless required by applicable law or agreed to in writing,
10
- # software distributed under the License is distributed on an "AS IS" BASIS,
11
- # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
- # See the License for the specific language governing permissions and
13
- # limitations under the License.
14
-
15
- import torch
16
- from torch.nn import Module
17
-
18
-
19
- __all__ = [
20
- "is_module_offloaded",
21
- "get_execution_device",
22
- "get_offloaded_device",
23
- "update_prefix_dict",
24
- "update_parameter_data",
25
- ]
26
-
27
-
28
- def is_module_offloaded(module: Module) -> bool:
29
- """
30
- :param module: layer to check
31
- :return: True if layer is offloaded from GPU, False otherwise
32
- """
33
- return hasattr(module, "_hf_hook") and module._hf_hook.offload
34
-
35
-
36
- def get_execution_device(module: Module) -> torch.device:
37
- """
38
- :param module: layer to check
39
- :return: device layer is loaded onto during forward pass
40
- """
41
- if is_module_offloaded(module):
42
- return module._hf_hook.execution_device
43
- device = next(module.parameters()).device
44
-
45
- # offload only gets set for leaf modules, fallback to checking for device type
46
- if device.type == "meta":
47
- return module._hf_hook.execution_device
48
-
49
- return device
50
-
51
-
52
- def get_offloaded_device(module: Module) -> torch.device:
53
- """
54
- :param module: layer to check
55
- :return: device layer is offloaded to onto after forward pass
56
- """
57
- if is_module_offloaded(module):
58
- first_key = list(module._hf_hook.weights_map.keys())[0]
59
- prefix_dataset = module._hf_hook.weights_map.dataset
60
- return prefix_dataset[first_key].device
61
- return next(module.parameters()).device
62
-
63
-
64
- def update_prefix_dict(module: Module, key: str, data: torch.Tensor):
65
- """
66
- Updates the offloaded state dict for a given module. Parameter named key is replaced
67
- by data. This is neccesary because parameter updates for offloaded modules do not
68
- persist automatically between loads. This function only affects the offloaded
69
- state dict and not the current state of the loaded module.
70
-
71
- :param module: layer containing the parameter to update
72
- :param key: name of parameter to update
73
- :param data: tensor to update parameter with in the offloaded state dict
74
- """
75
- if not is_module_offloaded(module):
76
- raise ValueError("Prefix dict is only applicable to offloaded modules")
77
- prefix_dict = module._hf_hook.weights_map
78
- prefix_dict.dataset[f"{prefix_dict.prefix}{key}"] = data
79
-
80
-
81
- def update_parameter_data(
82
- module: Module, new_param_data: torch.Tensor, param_name: str
83
- ):
84
- """
85
- Updates the paramter value named param_name for a given module. This function
86
- updates both the current loaded module state and the offloaded state dict if
87
- the module is offloaded. This is neccesary because parameter updates for offloaded
88
- modules do not persist automatically between loads.
89
-
90
- :param module: layer containing the parameter to update
91
- :param new_param_data: tensor to update parameter with
92
- :param param_name: name of layer parameter to update
93
- """
94
- if not hasattr(module, param_name):
95
- return
96
-
97
- device = next(module.parameters()).device
98
-
99
- offloaded = False
100
- if is_module_offloaded(module):
101
- offload_device = get_offloaded_device(module)
102
- offloaded = True
103
-
104
- parameter = getattr(module, param_name, None)
105
- if parameter is None:
106
- raise ValueError("Attempted to update uninitialized parameter")
107
-
108
- dtype = parameter.dtype
109
- parameter.data = new_param_data.to(device).to(dtype)
110
-
111
- if offloaded:
112
- prefix_dict = module._hf_hook.weights_map.dataset
113
- prefix = module._hf_hook.weights_map.prefix
114
- prefix_dict[f"{prefix}{param_name}"] = new_param_data.to(offload_device).to(
115
- dtype
116
- )