compressed-tensors 0.8.0__py3-none-any.whl → 0.9.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- compressed_tensors/compressors/model_compressors/model_compressor.py +92 -18
- compressed_tensors/compressors/quantized_compressors/base.py +35 -5
- compressed_tensors/compressors/quantized_compressors/naive_quantized.py +6 -4
- compressed_tensors/compressors/quantized_compressors/pack_quantized.py +4 -2
- compressed_tensors/compressors/sparse_compressors/__init__.py +1 -0
- compressed_tensors/compressors/sparse_compressors/base.py +45 -7
- compressed_tensors/compressors/sparse_compressors/sparse_24_bitmask.py +238 -0
- compressed_tensors/compressors/sparse_compressors/sparse_bitmask.py +9 -40
- compressed_tensors/config/__init__.py +1 -0
- compressed_tensors/config/base.py +1 -0
- compressed_tensors/config/sparse_24_bitmask.py +40 -0
- compressed_tensors/linear/compressed_linear.py +3 -1
- compressed_tensors/quantization/lifecycle/apply.py +48 -2
- compressed_tensors/quantization/lifecycle/forward.py +2 -2
- compressed_tensors/quantization/lifecycle/initialize.py +21 -45
- compressed_tensors/quantization/quant_args.py +16 -3
- compressed_tensors/quantization/quant_config.py +3 -3
- compressed_tensors/quantization/quant_scheme.py +17 -24
- compressed_tensors/utils/helpers.py +206 -1
- compressed_tensors/utils/offload.py +332 -44
- compressed_tensors/utils/safetensors_load.py +83 -17
- compressed_tensors/version.py +1 -1
- {compressed_tensors-0.8.0.dist-info → compressed_tensors-0.9.0.dist-info}/METADATA +1 -1
- {compressed_tensors-0.8.0.dist-info → compressed_tensors-0.9.0.dist-info}/RECORD +27 -25
- {compressed_tensors-0.8.0.dist-info → compressed_tensors-0.9.0.dist-info}/WHEEL +1 -1
- {compressed_tensors-0.8.0.dist-info → compressed_tensors-0.9.0.dist-info}/LICENSE +0 -0
- {compressed_tensors-0.8.0.dist-info → compressed_tensors-0.9.0.dist-info}/top_level.txt +0 -0
compressed_tensors/quantization/quant_scheme.py

```diff
@@ -13,14 +13,14 @@
 # limitations under the License.
 
 from copy import deepcopy
-from typing import List, Optional
+from typing import Any, Dict, List, Optional
 
 from compressed_tensors.quantization.quant_args import (
     QuantizationArgs,
     QuantizationStrategy,
     QuantizationType,
 )
-from pydantic import BaseModel
+from pydantic import BaseModel, model_validator
 
 
 __all__ = [
```
```diff
@@ -36,7 +36,7 @@ class QuantizationScheme(BaseModel):
     of modules should be quantized
 
     :param targets: list of modules to apply the QuantizationArgs to, can be layer
-        names, layer types or a regular expression
+        names, layer types or a regular expression, typically ["Linear"]
     :param weights: quantization config for layer weights
     :param input_activations: quantization config for layer inputs
     :param output_activations: quantization config for layer outputs
```
```diff
@@ -47,27 +47,20 @@ class QuantizationScheme(BaseModel):
     input_activations: Optional[QuantizationArgs] = None
     output_activations: Optional[QuantizationArgs] = None
 
-    @
-    def
-
-
-
-
-
-
-
-
-
-
-
-
-        return cls(
-            targets=targets,
-            weights=weights,
-            input_activations=input_activations,
-            output_activations=output_activations,
-        )
+    @model_validator(mode="after")
+    def validate_model_after(model: "QuantizationArgs") -> Dict[str, Any]:
+        inputs = model.input_activations
+        outputs = model.output_activations
+
+        if inputs is not None:
+            if inputs.actorder is not None:
+                raise ValueError("Cannot apply actorder to input activations")
+
+        if outputs is not None:
+            if outputs.actorder is not None:
+                raise ValueError("Cannot apply actorder to output activations")
+
+        return model
 
 
 """
```
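The removed constructor helper gives way to a post-init validator: a scheme that attaches `actorder` to activation args now fails at construction time. A minimal sketch of the new behavior, assuming the existing `QuantizationArgs` fields `num_bits`, `group_size`, and `actorder` (actorder itself requires group strategy, hence `group_size`):

```python
from compressed_tensors.quantization.quant_args import QuantizationArgs
from compressed_tensors.quantization.quant_scheme import QuantizationScheme

# actorder stays valid for weights (e.g. GPTQ-style group quantization) ...
scheme = QuantizationScheme(
    targets=["Linear"],
    weights=QuantizationArgs(num_bits=4, group_size=128, actorder="group"),
)

# ... but the new validator rejects it on input (or output) activations.
try:
    QuantizationScheme(
        targets=["Linear"],
        input_activations=QuantizationArgs(num_bits=8, group_size=128, actorder="group"),
    )
except ValueError as err:
    print(err)  # Cannot apply actorder to input activations
```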
compressed_tensors/utils/helpers.py

```diff
@@ -12,8 +12,11 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
-
+import warnings
+from functools import wraps
+from typing import Any, Callable, Dict, List, Optional
 
+import numpy
 import torch
 from transformers import AutoConfig
 
```
```diff
@@ -24,6 +27,13 @@ __all__ = [
     "tensor_follows_mask_structure",
     "replace_module",
     "is_compressed_tensors_config",
+    "getattr_chain",
+    "deprecated",
+    "Aliasable",
+    "combine_shards",
+    "shard_tensor",
+    "pack_bitmasks",
+    "unpack_bitmasks",
 ]
 
 FSDP_WRAPPER_NAME = "_fsdp_wrapped_module"
```
```diff
@@ -119,3 +129,198 @@ def is_compressed_tensors_config(compression_config: Any) -> bool:
         return isinstance(compression_config, CompressedTensorsConfig)
     except ImportError:
         return False
+
+
+def getattr_chain(obj: Any, chain_str: str, *args, **kwargs) -> Any:
+    """
+    Chain multiple getattr calls, separated by `.`
+
+    :param obj: base object whose attributes are being retrieved
+    :param chain_str: attribute names separated by `.`
+    :param default: default value, throw error otherwise
+    """
+    if len(args) >= 1:
+        has_default = True
+        default = args[0]
+    elif "default" in kwargs:
+        has_default = True
+        default = kwargs["default"]
+    else:
+        has_default = False
+
+    attr_names = chain_str.split(".")
+
+    res = obj
+    for attr_name in attr_names:
+        if not hasattr(res, attr_name):
+            if has_default:
+                return default
+            else:
+                raise AttributeError(f"{res} object has no attribute {attr_name}")
+        res = getattr(res, attr_name)
+
+    return res
```
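A quick illustration of the traversal and default semantics; the nested `SimpleNamespace` object is a stand-in for a model hierarchy, not library code:

```python
from types import SimpleNamespace

# Hypothetical nested object standing in for a model/module hierarchy.
model = SimpleNamespace(layer=SimpleNamespace(weight=SimpleNamespace(scheme="W4A16")))

getattr_chain(model, "layer.weight.scheme")              # "W4A16"
getattr_chain(model, "layer.weight.missing", None)       # positional default -> None
getattr_chain(model, "layer.weight.missing", default=0)  # keyword default -> 0

try:
    getattr_chain(model, "layer.weight.missing")  # no default -> AttributeError
except AttributeError:
    pass
```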
```diff
+
+
+def deprecated(future_name: Optional[str] = None, message: Optional[str] = None):
+    """
+    Decorator to mark functions as deprecated
+
+    :param new_function: Function called in place of depreciated function
+    :param message: Depreciation message, replaces default depreciation message
+    """
+
+    def decorator(func: Callable[[Any], Any]):
+        nonlocal message
+
+        if message is None:
+            message = (
+                f"{func.__name__} is deprecated and will be removed in a future release"
+            )
+            if future_name is not None:
+                message += f". Please use {future_name} instead."
+
+        @wraps(func)
+        def wrapped(*args, **kwargs):
+            warnings.warn(message, DeprecationWarning, stacklevel=2)
+            return func(*args, **kwargs)
+
+        return wrapped
+
+    return decorator
```
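Usage is the standard decorator pattern; a small sketch with a hypothetical helper:

```python
import warnings

@deprecated(future_name="new_helper")
def old_helper(x):
    return x + 1

with warnings.catch_warnings(record=True) as caught:
    warnings.simplefilter("always")
    assert old_helper(1) == 2  # still works, but warns
    print(caught[0].message)
    # old_helper is deprecated and will be removed in a future release.
    # Please use new_helper instead.
```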
```diff
+
+
+class Aliasable:
+    """
+    A mixin for enums to allow aliasing of enum members
+
+    Example:
+        >>> class MyClass(Aliasable, int, Enum):
+        >>>     ...
+    """
+
+    @staticmethod
+    def get_aliases() -> Dict[str, str]:
+        raise NotImplementedError()
+
+    def __eq__(self, other):
+        if isinstance(other, self.__class__):
+            aliases = self.get_aliases()
+            return self.value == other.value or (
+                aliases.get(self.value, self.value)
+                == aliases.get(other.value, other.value)
+            )
+        else:
+            aliases = self.get_aliases()
+            self_value = aliases.get(self.value, self.value)
+            other_value = aliases.get(other, other)
+            return self_value == other_value
+
+    def __hash__(self):
+        canonical_value = self.aliases.get(self.value, self.value)
+        return hash(canonical_value)
```
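A sketch of how an enum opts in; this `Sparsity` enum is hypothetical, not part of the package. One caveat visible in the code above: `__hash__` reads `self.aliases` while `__eq__` calls `get_aliases()`, so a subclass whose members must be hashable also needs an `aliases` attribute:

```python
from enum import Enum

class Sparsity(Aliasable, str, Enum):
    # hypothetical members, for illustration only
    TWO_FOUR = "2:4"
    SPARSE_24 = "sparse-24"

    @staticmethod
    def get_aliases():
        # maps an alias value onto its canonical value
        return {"sparse-24": "2:4"}

assert Sparsity.TWO_FOUR == Sparsity.SPARSE_24  # both resolve to "2:4"
assert Sparsity.SPARSE_24 == "2:4"              # raw values also go through the alias map
```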
```diff
+
+
+def shard_tensor(
+    tensor: torch.Tensor, shard_sizes: List[int], dim: int = 0
+) -> List[torch.Tensor]:
+    """
+    Shards a tensor into a list of tensors along a given dimension.
+
+    raises: ValueError: If the sum of shard_sizes does not match the
+        size of the tensor along the given dimension.
+
+    :param tensor: The input tensor to shard.
+    :param shard_sizes : List of sizes for each shard along the specified dimension.
+    :param dim : The dimension along which to shard the tensor.
+    :returns: A list of tensors sharded along the specified dimension.
+    """
+    if sum(shard_sizes) != tensor.size(dim):
+        raise ValueError(
+            "Sum of shard_sizes must equal the size of the tensor "
+            "along the specified dimension."
+        )
+
+    shards = []
+    start_idx = 0
+
+    for size in shard_sizes:
+        end_idx = start_idx + size
+        shard = tensor.narrow(dim, start_idx, size)
+        shards.append(shard)
+        start_idx = end_idx
+
+    return shards
+
+
+def combine_shards(shards, dim=0):
+    """
+    Combine decompressed shards along a given dimension using `narrow`.
+
+    :param shards: List of decompressed shard tensors.
+    :param dim: Dimension to combine along (default: 0).
+    :return: Combined decompressed tensor.
+    """
+    if not shards:
+        raise ValueError("The list of shards is empty.")
+
+    # Assert that all shards have the same dtype
+    shard_dtypes = {shard.dtype for shard in shards}
+    if len(shard_dtypes) > 1:
+        raise ValueError("All shards must have the same dtype.")
+
+    # Determine the total shape of the combined tensor
+    total_shape = list(shards[0].shape)
+    total_shape[dim] = sum(shard.shape[dim] for shard in shards)
+
+    # Create the combined tensor
+    combined = torch.zeros(total_shape, dtype=shards[0].dtype, device=shards[0].device)
+
+    # Fill the combined tensor using narrow
+    shard_offset = 0
+    for shard in shards:
+        shard_size = shard.shape[dim]
+        combined.narrow(dim, shard_offset, shard_size).copy_(shard)
+        shard_offset += shard_size
+
+    return combined
```
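The two helpers are inverses, which is presumably what the sharded decompression paths rely on. A round-trip sketch:

```python
import torch

weight = torch.randn(10, 4)

# Split the 10 rows into shards of 3, 3, and 4 along dim 0.
shards = shard_tensor(weight, [3, 3, 4], dim=0)
assert [s.shape[0] for s in shards] == [3, 3, 4]

# narrow() returns views, so shards share storage with the original tensor.
assert shards[0].data_ptr() == weight.data_ptr()

# combine_shards concatenates back into freshly allocated memory.
restored = combine_shards(shards, dim=0)
assert torch.equal(restored, weight)
```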
```diff
+
+
+def pack_bitmasks(bytemasks: torch.Tensor) -> torch.Tensor:
+    """
+    Converts a bytemask tensor to a bitmask tensor to reduce memory. Shape RxC will be
+    compressed to R x ceil(C/8)
+
+    :param bytemasks: mask tensor where each byte corresponds to a weight
+    :return: mask tensor where each bit corresounds to a weight
+    """
+    packed_bits_numpy = numpy.packbits(bytemasks.numpy(), axis=-1, bitorder="little")
+    packed_bits_torch = torch.from_numpy(packed_bits_numpy)
+
+    return packed_bits_torch
+
+
+def unpack_bitmasks(
+    packed_bitmasks: torch.Tensor, original_shape: torch.Size
+) -> torch.Tensor:
+    """
+    Converts a bitmask tensor back to a bytemask tensor for use during decompression
+
+    :param packed_bitmasks: mask tensor where each bit corresponds to a weight
+    :param original_shape: dense shape to decompress to
+    :return: boolean mask of weights in the original dense shape
+    """
+    # Unpack the bits
+    unpacked_bits = numpy.unpackbits(
+        packed_bitmasks.cpu().numpy(),
+        axis=-1,
+        count=original_shape[-1],
+        bitorder="little",
+    )
+
+    # Reshape to match the original shape
+    unpacked_bitmasks_torch = torch.from_numpy(
+        unpacked_bits.reshape(original_shape).astype(bool)
+    )
+
+    return unpacked_bitmasks_torch
```