compressed-tensors 0.8.0__py3-none-any.whl → 0.9.0__py3-none-any.whl

This diff shows the changes between two publicly released versions of this package, as they appear in their respective public registries. It is provided for informational purposes only.
Files changed (27)
  1. compressed_tensors/compressors/model_compressors/model_compressor.py +92 -18
  2. compressed_tensors/compressors/quantized_compressors/base.py +35 -5
  3. compressed_tensors/compressors/quantized_compressors/naive_quantized.py +6 -4
  4. compressed_tensors/compressors/quantized_compressors/pack_quantized.py +4 -2
  5. compressed_tensors/compressors/sparse_compressors/__init__.py +1 -0
  6. compressed_tensors/compressors/sparse_compressors/base.py +45 -7
  7. compressed_tensors/compressors/sparse_compressors/sparse_24_bitmask.py +238 -0
  8. compressed_tensors/compressors/sparse_compressors/sparse_bitmask.py +9 -40
  9. compressed_tensors/config/__init__.py +1 -0
  10. compressed_tensors/config/base.py +1 -0
  11. compressed_tensors/config/sparse_24_bitmask.py +40 -0
  12. compressed_tensors/linear/compressed_linear.py +3 -1
  13. compressed_tensors/quantization/lifecycle/apply.py +48 -2
  14. compressed_tensors/quantization/lifecycle/forward.py +2 -2
  15. compressed_tensors/quantization/lifecycle/initialize.py +21 -45
  16. compressed_tensors/quantization/quant_args.py +16 -3
  17. compressed_tensors/quantization/quant_config.py +3 -3
  18. compressed_tensors/quantization/quant_scheme.py +17 -24
  19. compressed_tensors/utils/helpers.py +206 -1
  20. compressed_tensors/utils/offload.py +332 -44
  21. compressed_tensors/utils/safetensors_load.py +83 -17
  22. compressed_tensors/version.py +1 -1
  23. {compressed_tensors-0.8.0.dist-info → compressed_tensors-0.9.0.dist-info}/METADATA +1 -1
  24. {compressed_tensors-0.8.0.dist-info → compressed_tensors-0.9.0.dist-info}/RECORD +27 -25
  25. {compressed_tensors-0.8.0.dist-info → compressed_tensors-0.9.0.dist-info}/WHEEL +1 -1
  26. {compressed_tensors-0.8.0.dist-info → compressed_tensors-0.9.0.dist-info}/LICENSE +0 -0
  27. {compressed_tensors-0.8.0.dist-info → compressed_tensors-0.9.0.dist-info}/top_level.txt +0 -0
compressed_tensors/quantization/quant_scheme.py
@@ -13,14 +13,14 @@
 # limitations under the License.
 
 from copy import deepcopy
-from typing import List, Optional
+from typing import Any, Dict, List, Optional
 
 from compressed_tensors.quantization.quant_args import (
     QuantizationArgs,
     QuantizationStrategy,
     QuantizationType,
 )
-from pydantic import BaseModel
+from pydantic import BaseModel, model_validator
 
 
 __all__ = [
@@ -36,7 +36,7 @@ class QuantizationScheme(BaseModel):
     of modules should be quantized
 
     :param targets: list of modules to apply the QuantizationArgs to, can be layer
-        names, layer types or a regular expression
+        names, layer types or a regular expression, typically ["Linear"]
     :param weights: quantization config for layer weights
     :param input_activations: quantization config for layer inputs
     :param output_activations: quantization config for layer outputs
@@ -47,27 +47,20 @@ class QuantizationScheme(BaseModel):
     input_activations: Optional[QuantizationArgs] = None
     output_activations: Optional[QuantizationArgs] = None
 
-    @classmethod
-    def default_scheme(
-        cls,
-        targets: Optional[List[str]] = None,
-    ):
-
-        if targets is None:
-            # default to quantizing all Linear layers
-            targets = ["Linear"]
-
-        # by default, activations and weights are left unquantized
-        weights = None
-        input_activations = None
-        output_activations = None
-
-        return cls(
-            targets=targets,
-            weights=weights,
-            input_activations=input_activations,
-            output_activations=output_activations,
-        )
+    @model_validator(mode="after")
+    def validate_model_after(model: "QuantizationScheme") -> "QuantizationScheme":
+        inputs = model.input_activations
+        outputs = model.output_activations
+
+        if inputs is not None:
+            if inputs.actorder is not None:
+                raise ValueError("Cannot apply actorder to input activations")
+
+        if outputs is not None:
+            if outputs.actorder is not None:
+                raise ValueError("Cannot apply actorder to output activations")
+
+        return model
 
 
 """
compressed_tensors/utils/helpers.py
@@ -12,8 +12,11 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
-from typing import Any, Optional
+import warnings
+from functools import wraps
+from typing import Any, Callable, Dict, List, Optional
 
+import numpy
 import torch
 from transformers import AutoConfig
 
@@ -24,6 +27,13 @@ __all__ = [
     "tensor_follows_mask_structure",
     "replace_module",
     "is_compressed_tensors_config",
+    "getattr_chain",
+    "deprecated",
+    "Aliasable",
+    "combine_shards",
+    "shard_tensor",
+    "pack_bitmasks",
+    "unpack_bitmasks",
 ]
 
 FSDP_WRAPPER_NAME = "_fsdp_wrapped_module"
@@ -119,3 +129,198 @@ def is_compressed_tensors_config(compression_config: Any) -> bool:
         return isinstance(compression_config, CompressedTensorsConfig)
     except ImportError:
         return False
+
+
+def getattr_chain(obj: Any, chain_str: str, *args, **kwargs) -> Any:
+    """
+    Chain multiple getattr calls, separated by `.`
+
+    :param obj: base object whose attributes are being retrieved
+    :param chain_str: attribute names separated by `.`
+    :param default: value returned if an attribute is missing; if omitted, AttributeError is raised
+    """
+    if len(args) >= 1:
+        has_default = True
+        default = args[0]
+    elif "default" in kwargs:
+        has_default = True
+        default = kwargs["default"]
+    else:
+        has_default = False
+
+    attr_names = chain_str.split(".")
+
+    res = obj
+    for attr_name in attr_names:
+        if not hasattr(res, attr_name):
+            if has_default:
+                return default
+            else:
+                raise AttributeError(f"{res} object has no attribute {attr_name}")
+        res = getattr(res, attr_name)
+
+    return res
+
+
+def deprecated(future_name: Optional[str] = None, message: Optional[str] = None):
+    """
+    Decorator to mark functions as deprecated
+
+    :param future_name: name of the function to use in place of the deprecated one
+    :param message: deprecation message; replaces the default deprecation message
+    """
+
+    def decorator(func: Callable[[Any], Any]):
+        nonlocal message
+
+        if message is None:
+            message = (
+                f"{func.__name__} is deprecated and will be removed in a future release"
+            )
+            if future_name is not None:
+                message += f". Please use {future_name} instead."
+
+        @wraps(func)
+        def wrapped(*args, **kwargs):
+            warnings.warn(message, DeprecationWarning, stacklevel=2)
+            return func(*args, **kwargs)
+
+        return wrapped
+
+    return decorator
+
+
193
+ class Aliasable:
194
+ """
195
+ A mixin for enums to allow aliasing of enum members
196
+
197
+ Example:
198
+ >>> class MyClass(Aliasable, int, Enum):
199
+ >>> ...
200
+ """
201
+
202
+ @staticmethod
203
+ def get_aliases() -> Dict[str, str]:
204
+ raise NotImplementedError()
205
+
206
+ def __eq__(self, other):
207
+ if isinstance(other, self.__class__):
208
+ aliases = self.get_aliases()
209
+ return self.value == other.value or (
210
+ aliases.get(self.value, self.value)
211
+ == aliases.get(other.value, other.value)
212
+ )
213
+ else:
214
+ aliases = self.get_aliases()
215
+ self_value = aliases.get(self.value, self.value)
216
+ other_value = aliases.get(other, other)
217
+ return self_value == other_value
218
+
219
+ def __hash__(self):
220
+ canonical_value = self.aliases.get(self.value, self.value)
221
+ return hash(canonical_value)
222
+
223
+
+def shard_tensor(
+    tensor: torch.Tensor, shard_sizes: List[int], dim: int = 0
+) -> List[torch.Tensor]:
+    """
+    Shards a tensor into a list of tensors along a given dimension.
+
+    :param tensor: the input tensor to shard
+    :param shard_sizes: list of sizes for each shard along the specified dimension
+    :param dim: the dimension along which to shard the tensor
+    :returns: a list of tensors sharded along the specified dimension
+
+    :raises ValueError: if the sum of shard_sizes does not match the
+        size of the tensor along the given dimension
+    """
+    if sum(shard_sizes) != tensor.size(dim):
+        raise ValueError(
+            "Sum of shard_sizes must equal the size of the tensor "
+            "along the specified dimension."
+        )
+
+    shards = []
+    start_idx = 0
+
+    for size in shard_sizes:
+        end_idx = start_idx + size
+        shard = tensor.narrow(dim, start_idx, size)
+        shards.append(shard)
+        start_idx = end_idx
+
+    return shards
+
+
+def combine_shards(shards, dim=0):
+    """
+    Combine decompressed shards along a given dimension using `narrow`.
+
+    :param shards: list of decompressed shard tensors
+    :param dim: dimension to combine along (default: 0)
+    :return: combined decompressed tensor
+    """
+    if not shards:
+        raise ValueError("The list of shards is empty.")
+
+    # Ensure that all shards have the same dtype
+    shard_dtypes = {shard.dtype for shard in shards}
+    if len(shard_dtypes) > 1:
+        raise ValueError("All shards must have the same dtype.")
+
+    # Determine the total shape of the combined tensor
+    total_shape = list(shards[0].shape)
+    total_shape[dim] = sum(shard.shape[dim] for shard in shards)
+
+    # Create the combined tensor
+    combined = torch.zeros(total_shape, dtype=shards[0].dtype, device=shards[0].device)
+
+    # Fill the combined tensor using narrow
+    shard_offset = 0
+    for shard in shards:
+        shard_size = shard.shape[dim]
+        combined.narrow(dim, shard_offset, shard_size).copy_(shard)
+        shard_offset += shard_size
+
+    return combined
+
+
+def pack_bitmasks(bytemasks: torch.Tensor) -> torch.Tensor:
+    """
+    Converts a bytemask tensor to a bitmask tensor to reduce memory. Shape RxC will be
+    compressed to R x ceil(C/8)
+
+    :param bytemasks: mask tensor where each byte corresponds to a weight
+    :return: mask tensor where each bit corresponds to a weight
+    """
+    packed_bits_numpy = numpy.packbits(bytemasks.cpu().numpy(), axis=-1, bitorder="little")
+    packed_bits_torch = torch.from_numpy(packed_bits_numpy)
+
+    return packed_bits_torch
+
+
+def unpack_bitmasks(
+    packed_bitmasks: torch.Tensor, original_shape: torch.Size
+) -> torch.Tensor:
+    """
+    Converts a bitmask tensor back to a bytemask tensor for use during decompression
+
+    :param packed_bitmasks: mask tensor where each bit corresponds to a weight
+    :param original_shape: dense shape to decompress to
+    :return: boolean mask of weights in the original dense shape
+    """
+    # Unpack the bits
+    unpacked_bits = numpy.unpackbits(
+        packed_bitmasks.cpu().numpy(),
+        axis=-1,
+        count=original_shape[-1],
+        bitorder="little",
+    )
+
+    # Reshape to match the original shape
+    unpacked_bitmasks_torch = torch.from_numpy(
+        unpacked_bits.reshape(original_shape).astype(bool)
+    )
+
+    return unpacked_bitmasks_torch
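
For orientation, a sketch of how `getattr_chain` and `deprecated` behave once re-exported from `compressed_tensors.utils`; the object and function names in the example are illustrative:

```python
from types import SimpleNamespace

from compressed_tensors.utils import deprecated, getattr_chain

# getattr_chain walks dotted attribute paths, honoring an optional default
model = SimpleNamespace(config=SimpleNamespace(quantization=None))
print(getattr_chain(model, "config.quantization"))         # None
print(getattr_chain(model, "config.missing", "fallback"))  # fallback


# deprecated() emits a DeprecationWarning each time the function is called
@deprecated(future_name="new_helper")
def old_helper():
    return 42

old_helper()  # DeprecationWarning: old_helper is deprecated ... Please use new_helper instead.
```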
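
`shard_tensor` and `combine_shards` are inverses, presumably added in support of sharded (de)compression in the new sparse compressors; a quick round-trip check with arbitrary sizes:

```python
import torch

from compressed_tensors.utils import combine_shards, shard_tensor

weight = torch.randn(8, 4)

# Split the rows into shards of 5 and 3, then reassemble
shards = shard_tensor(weight, shard_sizes=[5, 3], dim=0)
assert [s.shape[0] for s in shards] == [5, 3]

restored = combine_shards(shards, dim=0)
assert torch.equal(restored, weight)
```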
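
And `pack_bitmasks`/`unpack_bitmasks` round-trip a boolean sparsity mask through a bit-packed form, cutting mask memory roughly 8x along the last dimension; a small sketch:

```python
import torch

from compressed_tensors.utils import pack_bitmasks, unpack_bitmasks

# Boolean sparsity mask: one byte per element before packing
mask = torch.rand(4, 10) > 0.5

packed = pack_bitmasks(mask)
assert packed.shape == (4, 2)  # ceil(10 / 8) == 2 packed bytes per row

unpacked = unpack_bitmasks(packed, original_shape=mask.shape)
assert torch.equal(unpacked, mask)
```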