compressed-tensors 0.8.1__py3-none-any.whl → 0.9.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (24)
  1. compressed_tensors/compressors/model_compressors/model_compressor.py +76 -14
  2. compressed_tensors/compressors/quantized_compressors/base.py +35 -5
  3. compressed_tensors/compressors/quantized_compressors/naive_quantized.py +2 -2
  4. compressed_tensors/compressors/quantized_compressors/pack_quantized.py +2 -2
  5. compressed_tensors/compressors/sparse_compressors/__init__.py +1 -0
  6. compressed_tensors/compressors/sparse_compressors/base.py +45 -7
  7. compressed_tensors/compressors/sparse_compressors/sparse_24_bitmask.py +238 -0
  8. compressed_tensors/compressors/sparse_compressors/sparse_bitmask.py +9 -40
  9. compressed_tensors/config/__init__.py +1 -0
  10. compressed_tensors/config/base.py +1 -0
  11. compressed_tensors/config/sparse_24_bitmask.py +40 -0
  12. compressed_tensors/quantization/lifecycle/apply.py +46 -1
  13. compressed_tensors/quantization/lifecycle/forward.py +2 -2
  14. compressed_tensors/quantization/lifecycle/initialize.py +21 -45
  15. compressed_tensors/quantization/quant_config.py +1 -1
  16. compressed_tensors/utils/helpers.py +174 -1
  17. compressed_tensors/utils/offload.py +332 -44
  18. compressed_tensors/utils/safetensors_load.py +83 -17
  19. compressed_tensors/version.py +1 -1
  20. {compressed_tensors-0.8.1.dist-info → compressed_tensors-0.9.0.dist-info}/METADATA +1 -1
  21. {compressed_tensors-0.8.1.dist-info → compressed_tensors-0.9.0.dist-info}/RECORD +24 -22
  22. {compressed_tensors-0.8.1.dist-info → compressed_tensors-0.9.0.dist-info}/LICENSE +0 -0
  23. {compressed_tensors-0.8.1.dist-info → compressed_tensors-0.9.0.dist-info}/WHEEL +0 -0
  24. {compressed_tensors-0.8.1.dist-info → compressed_tensors-0.9.0.dist-info}/top_level.txt +0 -0
@@ -14,12 +14,12 @@

  from typing import Dict, List, Tuple, Union

- import numpy
  import torch
  from compressed_tensors.compressors.base import BaseCompressor
  from compressed_tensors.compressors.sparse_compressors.base import BaseSparseCompressor
  from compressed_tensors.config import CompressionFormat
- from compressed_tensors.utils import merge_names
+ from compressed_tensors.quantization import FP8_DTYPE
+ from compressed_tensors.utils import merge_names, pack_bitmasks, unpack_bitmasks
  from torch import Tensor


@@ -28,8 +28,6 @@ __all__ = [
  "BitmaskTensor",
  "bitmask_compress",
  "bitmask_decompress",
- "pack_bitmasks",
- "unpack_bitmasks",
  ]


@@ -134,9 +132,14 @@ def bitmask_compress(tensor: Tensor) -> Tuple[Tensor, Tensor, Tensor]:
  bytemasks = tensor != 0
  row_counts = bytemasks.sum(dim=-1)
  row_offsets = torch.cumsum(row_counts, 0) - row_counts
- values = tensor[bytemasks]
+ if tensor.dtype == FP8_DTYPE:
+     # access raw bytes of the tensor
+     tensor_view = tensor.view(torch.int8)
+     values = tensor_view[bytemasks]
+     values = values.view(FP8_DTYPE)
+ else:
+     values = tensor[bytemasks]
  bitmasks_packed = pack_bitmasks(bytemasks)
-
  return values, bitmasks_packed, row_offsets
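The new FP8 branch works around limited float8 kernel coverage in PyTorch: instead of fancy-indexing the FP8 tensor directly, the compressor indexes the raw bytes through an `int8` view and reinterprets the result. A minimal standalone sketch of the same trick, assuming `FP8_DTYPE` is `torch.float8_e4m3fn`:

```python
import torch

FP8_DTYPE = torch.float8_e4m3fn  # assumption: matches compressed_tensors.quantization.FP8_DTYPE

dense = torch.tensor([[0.5, 0.0], [0.0, -1.0]]).to(FP8_DTYPE)
mask = dense != 0  # comparison works on FP8 tensors; boolean fancy indexing may not

# index through the raw int8 bytes, then reinterpret the kept bytes as FP8
values = dense.view(torch.int8)[mask].view(FP8_DTYPE)
print(values.dtype, values.shape)  # torch.float8_e4m3fn torch.Size([2])
```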
@@ -158,37 +161,3 @@ def bitmask_decompress(
  decompressed_tensor[bytemasks_unpacked] = values

  return decompressed_tensor
-
-
- def pack_bitmasks(bytemasks: Tensor) -> Tensor:
-     """
-     Converts a bytemask tensor to a bitmask tensor to reduce memory. Shape RxC will be
-     compressed to R x ceil(C/8)
-     :param bytemasks: mask tensor where each byte corresponds to a weight
-     :return: mask tensor where each bit corresounds to a weight
-     """
-     packed_bits_numpy = numpy.packbits(bytemasks.numpy(), axis=-1, bitorder="little")
-     packed_bits_torch = torch.from_numpy(packed_bits_numpy)
-
-     return packed_bits_torch
-
-
- def unpack_bitmasks(packed_bitmasks: Tensor, original_shape: torch.Size) -> Tensor:
-     """
-     Converts a bitmask tensor back to a bytemask tensor for use during decompression
-
-     :param packed_bitmasks: mask tensor where each bit corresponds to a weight
-     :param original_shape: dense shape to decompress to
-     :return: boolean mask of weights in the original dense shape
-     """
-     # Unpack the bits
-     unpacked_bits = numpy.unpackbits(
-         packed_bitmasks.numpy(), axis=-1, count=original_shape[-1], bitorder="little"
-     )
-
-     # Reshape to match the original shape
-     unpacked_bitmasks_torch = torch.from_numpy(
-         unpacked_bits.reshape(original_shape).astype(bool)
-     )
-
-     return unpacked_bitmasks_torch
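Taken together, the hunks above (apparently from `sparse_bitmask.py`, item 8 in the file list) drop the local `numpy` import and the module-private `pack_bitmasks`/`unpack_bitmasks`, importing the shared versions from `compressed_tensors.utils` instead, and teach `bitmask_compress` to handle FP8 weights. A hedged round-trip sketch; the `bitmask_decompress` argument order (values, packed bitmasks, original shape) is assumed from the surrounding code rather than stated in the diff:

```python
import torch
from compressed_tensors.compressors.sparse_compressors.sparse_bitmask import (
    bitmask_compress,
    bitmask_decompress,
)

dense = torch.randn(4, 8)
dense[dense.abs() < 0.5] = 0.0  # introduce sparsity

values, bitmasks, row_offsets = bitmask_compress(dense)
restored = bitmask_decompress(values, bitmasks, dense.shape)  # assumed argument order

assert torch.equal(restored, dense)
```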
@@ -15,4 +15,5 @@
  # flake8: noqa
  from .base import *
  from .dense import *
+ from .sparse_24_bitmask import *
  from .sparse_bitmask import *
@@ -26,6 +26,7 @@ __all__ = ["SparsityCompressionConfig", "CompressionFormat", "SparsityStructure"
  class CompressionFormat(Enum):
      dense = "dense"
      sparse_bitmask = "sparse-bitmask"
+     sparse_24_bitmask = "sparse-24-bitmask"
      int_quantized = "int-quantized"
      float_quantized = "float-quantized"
      naive_quantized = "naive-quantized"
@@ -0,0 +1,40 @@
+ # Copyright (c) 2021 - present / Neuralmagic, Inc. All Rights Reserved.
+ #
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ # http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing,
+ # software distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+
+ from typing import Optional
+
+ from compressed_tensors.config import (
+     CompressionFormat,
+     SparsityCompressionConfig,
+     SparsityStructure,
+ )
+
+
+ __all__ = ["Sparse24BitMaskConfig"]
+
+
+ @SparsityCompressionConfig.register(name=CompressionFormat.sparse_24_bitmask.value)
+ class Sparse24BitMaskConfig(SparsityCompressionConfig):
+     """
+     Configuration for storing a 2:4 sparse model using
+     bitmask compression
+
+     :param global_sparsity: average sparsity of the entire model
+     :param sparsity_structure: structure of the sparsity, should always be
+         "2:4" for this compression format
+     """
+
+     format: str = CompressionFormat.sparse_24_bitmask.value
+     global_sparsity: Optional[float] = 0.0
+     sparsity_structure: Optional[str] = SparsityStructure.TWO_FOUR.value
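The new `Sparse24BitMaskConfig` (file 11 in the list) registers the `sparse-24-bitmask` format that the `CompressionFormat` enum and the sparse-compressor `__init__` hunks wire up. A hedged construction sketch, assuming pydantic-v2 behavior from the `SparsityCompressionConfig` base and that the one-line change to `config/__init__.py` (file 9) re-exports the class:

```python
from compressed_tensors.config import Sparse24BitMaskConfig

cfg = Sparse24BitMaskConfig(global_sparsity=0.5)
print(cfg.format)              # "sparse-24-bitmask"
print(cfg.sparsity_structure)  # "2:4"
print(cfg.model_dump())        # plain dict, ready to embed in a model's config.json
```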
@@ -18,7 +18,7 @@ from collections import OrderedDict, defaultdict
  from copy import deepcopy
  from typing import Dict, Iterable, List, Optional
  from typing import OrderedDict as OrderedDictType
- from typing import Union
+ from typing import Set, Union

  import torch
  from compressed_tensors.config import CompressionFormat
@@ -52,6 +52,8 @@ __all__ = [
  "apply_quantization_config",
  "apply_quantization_status",
  "find_name_or_class_matches",
+ "expand_sparse_target_names",
+ "is_sparse_target",
  ]

  from compressed_tensors.quantization.utils.helpers import is_module_quantized
@@ -245,6 +247,49 @@ def apply_quantization_status(model: Module, status: QuantizationStatus):
  model.apply(compress_quantized_weights)


+ def expand_sparse_target_names(
+     model: Module, targets: Iterable[str], ignore: Iterable[str]
+ ) -> Set[str]:
+     """
+     Finds all unique module names in the model that match the given
+     targets and ignore lists.
+
+     Note: Targets must be regexes, layer types, or full layer names.
+
+     :param model: model to search for targets in
+     :param targets: list of targets to search for
+     :param ignore: list of targets to ignore
+     :return: set of all targets that match the given targets and should
+         not be ignored
+     """
+     return {
+         name
+         for name, module in iter_named_leaf_modules(model)
+         if is_sparse_target(name, module, targets, ignore)
+     }
+
+
+ def is_sparse_target(
+     name: str, module: Module, targets: Iterable[str], ignore: Iterable[str]
+ ) -> bool:
+     """
+     Determines if a module should be included in the targets based on the
+     targets and ignore lists.
+
+     Note: Targets must be regexes, layer types, or full layer names.
+
+     :param name: name of the module
+     :param module: the module itself
+     :param targets: list of targets to search for
+     :param ignore: list of targets to ignore
+     :return: True if the module is a target and not ignored, False otherwise
+     """
+     return bool(
+         find_name_or_class_matches(name, module, targets)
+         and not find_name_or_class_matches(name, module, ignore or [])
+     )
+
+
  def find_name_or_class_matches(
      name: str, module: Module, targets: Iterable[str], check_contains: bool = False
  ) -> List[str]:
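`expand_sparse_target_names` and `is_sparse_target` give `apply.py` a way to resolve which leaf modules a sparsity config covers. A hedged usage sketch on a toy model (the target and ignore strings are illustrative):

```python
import torch.nn as nn
from compressed_tensors.quantization.lifecycle.apply import expand_sparse_target_names

model = nn.Sequential(nn.Linear(16, 16), nn.ReLU(), nn.Linear(16, 4))

# match every Linear layer by class name, but skip the final projection by full name
sparse_targets = expand_sparse_target_names(model, targets=["Linear"], ignore=["2"])
print(sparse_targets)  # {"0"}: leaf-module names, with ignored and non-matching layers dropped
```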
@@ -82,8 +82,8 @@ def quantize(
  def dequantize(
      x_q: torch.Tensor,
      scale: torch.Tensor,
-     zero_point: torch.Tensor = None,
-     args: QuantizationArgs = None,
+     zero_point: Optional[torch.Tensor] = None,
+     args: Optional[QuantizationArgs] = None,
      dtype: Optional[torch.dtype] = None,
      g_idx: Optional[torch.Tensor] = None,
  ) -> torch.Tensor:
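This hunk only tightens the type hints (`zero_point` and `args` already defaulted to `None`), so existing call sites are unaffected. A hedged sketch of the common symmetric per-tensor call, assuming the usual behavior that a scalar scale with `args` omitted is treated as per-tensor quantization:

```python
import torch
from compressed_tensors.quantization.lifecycle.forward import dequantize

x_q = torch.tensor([[-12, 0, 7]], dtype=torch.int8)
scale = torch.tensor(0.05)

x = dequantize(x_q, scale)  # zero_point/args omitted; both are Optional
print(x)  # approximately [[-0.60, 0.00, 0.35]]
```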
@@ -29,7 +29,11 @@ from compressed_tensors.quantization.quant_args import (
  from compressed_tensors.quantization.quant_config import QuantizationStatus
  from compressed_tensors.quantization.quant_scheme import QuantizationScheme
  from compressed_tensors.quantization.utils import is_kv_cache_quant_scheme
- from compressed_tensors.utils import get_execution_device, is_module_offloaded
+ from compressed_tensors.utils import (
+     disable_hf_hook,
+     has_offloaded_params,
+     register_offload_parameter,
+ )
  from torch.nn import Module, Parameter


@@ -112,43 +116,10 @@ def initialize_module_for_quantization(
  module.quantization_scheme = scheme
  module.quantization_status = QuantizationStatus.INITIALIZED

- offloaded = False
- # What is this doing/why isn't this in the attn case?
- if is_module_offloaded(module):
-     try:
-         from accelerate.hooks import add_hook_to_module, remove_hook_from_module
-         from accelerate.utils import PrefixedDataset
-     except ModuleNotFoundError:
-         raise ModuleNotFoundError(
-             "Offloaded model detected. To use CPU offloading with "
-             "compressed-tensors the `accelerate` package must be installed, "
-             "run `pip install compressed-tensors[accelerate]`"
-         )
-
-     offloaded = True
-     hook = module._hf_hook
-     prefix_dict = module._hf_hook.weights_map
-     new_prefix = {}
-
-     # recreate the prefix dict (since it is immutable)
-     # and add quantization parameters
-     for key, data in module.named_parameters():
-         if key not in prefix_dict:
-             new_prefix[f"{prefix_dict.prefix}{key}"] = data
-         else:
-             new_prefix[f"{prefix_dict.prefix}{key}"] = prefix_dict[key]
-     new_prefix_dict = PrefixedDataset(new_prefix, prefix_dict.prefix)
-     remove_hook_from_module(module)
-
- # wrap forward call of module to perform
- # quantized actions based on calltime status
- wrap_module_forward_quantized(module, scheme)
-
- if offloaded:
-     # we need to re-add the hook for offloading now that we've wrapped forward
-     add_hook_to_module(module, hook)
-     if prefix_dict is not None:
-         module._hf_hook.weights_map = new_prefix_dict
+ with disable_hf_hook(module):
+     # wrap forward call of module to perform
+     # quantized actions based on calltime status
+     wrap_module_forward_quantized(module, scheme)


  def is_attention_module(module: Module):
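Roughly forty lines of manual hook bookkeeping collapse into a `disable_hf_hook` context manager. The removed code shows what it has to accomplish: detach the accelerate hook, let the caller rewrap `forward`, then re-attach. A hedged sketch of that pattern (not the library's actual implementation), using the same `accelerate` calls the old code imported:

```python
from contextlib import contextmanager
from accelerate.hooks import add_hook_to_module, remove_hook_from_module

@contextmanager
def disable_hf_hook_sketch(module):
    # detach the accelerate offloading hook, if present, so forward() can be rewrapped safely
    hook = getattr(module, "_hf_hook", None)
    if hook is not None:
        remove_hook_from_module(module)
    try:
        yield
    finally:
        if hook is not None:
            # re-attach the original hook once the caller has finished mutating forward()
            add_hook_to_module(module, hook)
```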
@@ -169,12 +140,17 @@ def _initialize_scale_zero_point(
  if quantization_args.dynamic:
      return

- device = next(module.parameters()).device
- if is_module_offloaded(module):
-     device = get_execution_device(module)
+ # begin on the same device as other parameters or cpu if offloaded.
+ # in the offloaded case, there's no point moving tensors to the execution device
+ # if they're going to be immediately offloaded by `register_offload_parameter`
+ params_device = next(module.parameters()).device
+ device = "cpu" if has_offloaded_params(module) else params_device

  # infer expected scale/zero point shape
- expected_shape = 1  # per tensor
+ if quantization_args.strategy == QuantizationStrategy.TOKEN:
+     expected_shape = (1, 1)
+ else:
+     expected_shape = 1

  if base_name == "weight" and weight_shape is not None:
      if quantization_args.strategy == QuantizationStrategy.CHANNEL:
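`has_offloaded_params` checks whether a module's weights live behind an accelerate offloading hook rather than on-device, which is why freshly created scale/zero-point tensors can now start on `"cpu"`. A hedged sketch of the check on a plain, non-offloaded module:

```python
import torch.nn as nn
from compressed_tensors.utils import has_offloaded_params

layer = nn.Linear(8, 8)
print(has_offloaded_params(layer))  # False: no accelerate hook is attached

# once accelerate offloads the layer (e.g. via cpu_offload), the check returns True
# and the initializer above deliberately creates new parameters on "cpu"
```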
@@ -193,7 +169,7 @@
      torch.empty(expected_shape, dtype=scale_dtype, device=device),
      requires_grad=False,
  )
- module.register_parameter(f"{base_name}_scale", init_scale)
+ register_offload_parameter(module, f"{base_name}_scale", init_scale)

  if force_zero_point or not quantization_args.symmetric:
      zp_dtype = quantization_args.pytorch_dtype()
@@ -201,7 +177,7 @@
      torch.zeros(expected_shape, device=device, dtype=zp_dtype),
      requires_grad=False,
  )
- module.register_parameter(f"{base_name}_zero_point", init_zero_point)
+ register_offload_parameter(module, f"{base_name}_zero_point", init_zero_point)

  # only grouped activation ordering has g_idx
  if quantization_args.actorder == ActivationOrdering.GROUP:
@@ -211,7 +187,7 @@
      torch.full(g_idx_shape, -1, device=device, dtype=g_idx_dtype),
      requires_grad=False,
  )
- module.register_parameter(f"{base_name}_g_idx", init_g_idx)
+ register_offload_parameter(module, f"{base_name}_g_idx", init_g_idx)


  def _initialize_attn_scales(module: Module) -> None:
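All three parameter registrations now go through `register_offload_parameter`, which behaves like `Module.register_parameter` for ordinary modules and additionally keeps the accelerate offload map in sync when the module is offloaded. A hedged usage sketch on a non-offloaded layer (the parameter name is illustrative):

```python
import torch
import torch.nn as nn
from compressed_tensors.utils import register_offload_parameter

layer = nn.Linear(16, 16)
scale = nn.Parameter(torch.ones(1), requires_grad=False)

# equivalent to layer.register_parameter("weight_scale", scale) here; for an
# offloaded module the value would also be written into the hook's weights map
register_offload_parameter(layer, "weight_scale", scale)
print(layer.weight_scale.shape)  # torch.Size([1])
```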
@@ -160,7 +160,7 @@ class QuantizationConfig(BaseModel):

  def to_dict(self):
      # for compatibility with HFQuantizer
-     return self.dict()
+     return self.model_dump()

  @staticmethod
  def from_pretrained(
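`self.dict()` is the pydantic v1 serializer and is deprecated under pydantic 2.x; `model_dump()` is the v2 replacement and returns the same mapping, so HFQuantizer sees no behavioral change. A minimal illustration with a stand-in model:

```python
from pydantic import BaseModel

class ExampleConfig(BaseModel):  # stand-in for QuantizationConfig
    format: str = "float-quantized"

cfg = ExampleConfig()
print(cfg.model_dump())  # {'format': 'float-quantized'}, the pydantic v2 API
# cfg.dict() returns the same dict but emits a deprecation warning on pydantic 2.x
```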
@@ -12,8 +12,11 @@
  # See the License for the specific language governing permissions and
  # limitations under the License.

- from typing import Any, Dict, Optional
+ import warnings
+ from functools import wraps
+ from typing import Any, Callable, Dict, List, Optional

+ import numpy
  import torch
  from transformers import AutoConfig

@@ -24,7 +27,13 @@ __all__ = [
  "tensor_follows_mask_structure",
  "replace_module",
  "is_compressed_tensors_config",
+ "getattr_chain",
+ "deprecated",
  "Aliasable",
+ "combine_shards",
+ "shard_tensor",
+ "pack_bitmasks",
+ "unpack_bitmasks",
  ]

  FSDP_WRAPPER_NAME = "_fsdp_wrapped_module"
@@ -122,6 +131,65 @@ def is_compressed_tensors_config(compression_config: Any) -> bool:
  return False


+ def getattr_chain(obj: Any, chain_str: str, *args, **kwargs) -> Any:
+     """
+     Chain multiple getattr calls, separated by `.`
+
+     :param obj: base object whose attributes are being retrieved
+     :param chain_str: attribute names separated by `.`
+     :param default: default value, throw error otherwise
+     """
+     if len(args) >= 1:
+         has_default = True
+         default = args[0]
+     elif "default" in kwargs:
+         has_default = True
+         default = kwargs["default"]
+     else:
+         has_default = False
+
+     attr_names = chain_str.split(".")
+
+     res = obj
+     for attr_name in attr_names:
+         if not hasattr(res, attr_name):
+             if has_default:
+                 return default
+             else:
+                 raise AttributeError(f"{res} object has no attribute {attr_name}")
+         res = getattr(res, attr_name)
+
+     return res
+
+
+ def deprecated(future_name: Optional[str] = None, message: Optional[str] = None):
+     """
+     Decorator to mark functions as deprecated
+
+     :param future_name: name of the function to use in place of the deprecated one
+     :param message: deprecation message, replaces the default deprecation message
+     """
+
+     def decorator(func: Callable[[Any], Any]):
+         nonlocal message
+
+         if message is None:
+             message = (
+                 f"{func.__name__} is deprecated and will be removed in a future release"
+             )
+             if future_name is not None:
+                 message += f". Please use {future_name} instead."
+
+         @wraps(func)
+         def wrapped(*args, **kwargs):
+             warnings.warn(message, DeprecationWarning, stacklevel=2)
+             return func(*args, **kwargs)
+
+         return wrapped
+
+     return decorator
+
+
  class Aliasable:
      """
      A mixin for enums to allow aliasing of enum members
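Two general-purpose helpers are added here: `getattr_chain` for safely walking dotted attribute paths, and a `deprecated` decorator. A hedged usage sketch (the objects and names are illustrative):

```python
from types import SimpleNamespace
from compressed_tensors.utils import deprecated, getattr_chain

model = SimpleNamespace(config=SimpleNamespace(quantization_config=None))

# walks model.config.quantization_config.format; falls back to the default
# instead of raising once a link in the chain is missing
fmt = getattr_chain(model, "config.quantization_config.format", "dense")
print(fmt)  # "dense"

@deprecated(future_name="new_helper")
def old_helper():
    return 42

old_helper()  # emits a DeprecationWarning pointing callers at new_helper
```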
@@ -151,3 +219,108 @@ class Aliasable:
  def __hash__(self):
      canonical_value = self.aliases.get(self.value, self.value)
      return hash(canonical_value)
+
+
+ def shard_tensor(
+     tensor: torch.Tensor, shard_sizes: List[int], dim: int = 0
+ ) -> List[torch.Tensor]:
+     """
+     Shards a tensor into a list of tensors along a given dimension.
+
+     raises: ValueError: If the sum of shard_sizes does not match the
+         size of the tensor along the given dimension.
+
+     :param tensor: The input tensor to shard.
+     :param shard_sizes: List of sizes for each shard along the specified dimension.
+     :param dim: The dimension along which to shard the tensor.
+     :returns: A list of tensors sharded along the specified dimension.
+     """
+     if sum(shard_sizes) != tensor.size(dim):
+         raise ValueError(
+             "Sum of shard_sizes must equal the size of the tensor "
+             "along the specified dimension."
+         )
+
+     shards = []
+     start_idx = 0
+
+     for size in shard_sizes:
+         end_idx = start_idx + size
+         shard = tensor.narrow(dim, start_idx, size)
+         shards.append(shard)
+         start_idx = end_idx
+
+     return shards
+
+
+ def combine_shards(shards, dim=0):
+     """
+     Combine decompressed shards along a given dimension using `narrow`.
+
+     :param shards: List of decompressed shard tensors.
+     :param dim: Dimension to combine along (default: 0).
+     :return: Combined decompressed tensor.
+     """
+     if not shards:
+         raise ValueError("The list of shards is empty.")
+
+     # Assert that all shards have the same dtype
+     shard_dtypes = {shard.dtype for shard in shards}
+     if len(shard_dtypes) > 1:
+         raise ValueError("All shards must have the same dtype.")
+
+     # Determine the total shape of the combined tensor
+     total_shape = list(shards[0].shape)
+     total_shape[dim] = sum(shard.shape[dim] for shard in shards)
+
+     # Create the combined tensor
+     combined = torch.zeros(total_shape, dtype=shards[0].dtype, device=shards[0].device)
+
+     # Fill the combined tensor using narrow
+     shard_offset = 0
+     for shard in shards:
+         shard_size = shard.shape[dim]
+         combined.narrow(dim, shard_offset, shard_size).copy_(shard)
+         shard_offset += shard_size
+
+     return combined
+
+
+ def pack_bitmasks(bytemasks: torch.Tensor) -> torch.Tensor:
+     """
+     Converts a bytemask tensor to a bitmask tensor to reduce memory. Shape RxC will be
+     compressed to R x ceil(C/8)
+
+     :param bytemasks: mask tensor where each byte corresponds to a weight
+     :return: mask tensor where each bit corresponds to a weight
+     """
+     packed_bits_numpy = numpy.packbits(bytemasks.numpy(), axis=-1, bitorder="little")
+     packed_bits_torch = torch.from_numpy(packed_bits_numpy)
+
+     return packed_bits_torch
+
+
+ def unpack_bitmasks(
+     packed_bitmasks: torch.Tensor, original_shape: torch.Size
+ ) -> torch.Tensor:
+     """
+     Converts a bitmask tensor back to a bytemask tensor for use during decompression
+
+     :param packed_bitmasks: mask tensor where each bit corresponds to a weight
+     :param original_shape: dense shape to decompress to
+     :return: boolean mask of weights in the original dense shape
+     """
+     # Unpack the bits
+     unpacked_bits = numpy.unpackbits(
+         packed_bitmasks.cpu().numpy(),
+         axis=-1,
+         count=original_shape[-1],
+         bitorder="little",
+     )
+
+     # Reshape to match the original shape
+     unpacked_bitmasks_torch = torch.from_numpy(
+         unpacked_bits.reshape(original_shape).astype(bool)
+     )
+
+     return unpacked_bitmasks_torch
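These utilities round out the refactor: sharding helpers for fused tensors, plus the shared bitmask pack/unpack pair that `sparse_bitmask.py` (and the new 2:4 compressor) now import from `compressed_tensors.utils`. A hedged round-trip sketch:

```python
import torch
from compressed_tensors.utils import (
    combine_shards,
    pack_bitmasks,
    shard_tensor,
    unpack_bitmasks,
)

weight = torch.randn(6, 8)

# split along dim 0 (e.g. a fused gate/up projection) and stitch it back together
gate, up = shard_tensor(weight, shard_sizes=[3, 3], dim=0)
assert torch.equal(combine_shards([gate, up]), weight)

# bytemask -> packed bitmask -> bytemask round trip
mask = weight > 0
packed = pack_bitmasks(mask)  # 8 mask columns pack into one byte per row
assert packed.shape == (6, 1)
assert torch.equal(unpack_bitmasks(packed, weight.shape), mask)
```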