compressed-tensors 0.8.1__py3-none-any.whl → 0.9.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (24)
  1. compressed_tensors/compressors/model_compressors/model_compressor.py +76 -14
  2. compressed_tensors/compressors/quantized_compressors/base.py +35 -5
  3. compressed_tensors/compressors/quantized_compressors/naive_quantized.py +2 -2
  4. compressed_tensors/compressors/quantized_compressors/pack_quantized.py +2 -2
  5. compressed_tensors/compressors/sparse_compressors/__init__.py +1 -0
  6. compressed_tensors/compressors/sparse_compressors/base.py +45 -7
  7. compressed_tensors/compressors/sparse_compressors/sparse_24_bitmask.py +238 -0
  8. compressed_tensors/compressors/sparse_compressors/sparse_bitmask.py +9 -40
  9. compressed_tensors/config/__init__.py +1 -0
  10. compressed_tensors/config/base.py +1 -0
  11. compressed_tensors/config/sparse_24_bitmask.py +40 -0
  12. compressed_tensors/quantization/lifecycle/apply.py +46 -1
  13. compressed_tensors/quantization/lifecycle/forward.py +2 -2
  14. compressed_tensors/quantization/lifecycle/initialize.py +21 -45
  15. compressed_tensors/quantization/quant_config.py +1 -1
  16. compressed_tensors/utils/helpers.py +174 -1
  17. compressed_tensors/utils/offload.py +332 -44
  18. compressed_tensors/utils/safetensors_load.py +83 -17
  19. compressed_tensors/version.py +1 -1
  20. {compressed_tensors-0.8.1.dist-info → compressed_tensors-0.9.0.dist-info}/METADATA +1 -1
  21. {compressed_tensors-0.8.1.dist-info → compressed_tensors-0.9.0.dist-info}/RECORD +24 -22
  22. {compressed_tensors-0.8.1.dist-info → compressed_tensors-0.9.0.dist-info}/LICENSE +0 -0
  23. {compressed_tensors-0.8.1.dist-info → compressed_tensors-0.9.0.dist-info}/WHEEL +0 -0
  24. {compressed_tensors-0.8.1.dist-info → compressed_tensors-0.9.0.dist-info}/top_level.txt +0 -0
compressed_tensors/compressors/model_compressors/model_compressor.py

@@ -17,8 +17,9 @@ import logging
 import operator
 import os
 import re
+from contextlib import contextmanager
 from copy import deepcopy
-from typing import TYPE_CHECKING, Any, Dict, Optional, TypeVar, Union
+from typing import TYPE_CHECKING, Any, Dict, Optional, Set, TypeVar, Union
 
 import compressed_tensors
 import torch
@@ -38,6 +39,7 @@ from compressed_tensors.quantization import (
     apply_quantization_config,
     load_pretrained_quantization,
 )
+from compressed_tensors.quantization.lifecycle import expand_sparse_target_names
 from compressed_tensors.quantization.quant_args import QuantizationArgs
 from compressed_tensors.quantization.utils import (
     is_module_quantized,
@@ -104,7 +106,6 @@ class ModelCompressor:
         """
         config = AutoConfig.from_pretrained(pretrained_model_name_or_path, **kwargs)
         compression_config = getattr(config, QUANTIZATION_CONFIG_NAME, None)
-
         return cls.from_compression_config(compression_config)
 
     @classmethod
@@ -137,7 +138,7 @@ class ModelCompressor:
                 format, **sparsity_config
             )
         if quantization_config is not None:
-            quantization_config = QuantizationConfig.parse_obj(quantization_config)
+            quantization_config = QuantizationConfig.model_validate(quantization_config)
 
         return cls(
             sparsity_config=sparsity_config, quantization_config=quantization_config
@@ -193,7 +194,7 @@ class ModelCompressor:
 
         if is_compressed_tensors_config(compression_config):
             s_config = compression_config.sparsity_config
-            return s_config.dict() if s_config is not None else None
+            return s_config.model_dump() if s_config is not None else None
 
         return compression_config.get(SPARSITY_CONFIG_NAME, None)
 
@@ -214,7 +215,7 @@ class ModelCompressor:
 
         if is_compressed_tensors_config(compression_config):
             q_config = compression_config.quantization_config
-            return q_config.dict() if q_config is not None else None
+            return q_config.model_dump() if q_config is not None else None
 
         quantization_config = deepcopy(compression_config)
         quantization_config.pop(SPARSITY_CONFIG_NAME, None)
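These renames (`parse_obj` → `model_validate`, `.dict()` → `.model_dump()`) track the Pydantic v2 API; the behavior is unchanged. A minimal stand-in illustration, assuming Pydantic 2.x (the model below is illustrative, not the real config class):

    from pydantic import BaseModel

    class DummyConfig(BaseModel):  # stand-in for QuantizationConfig / the sparsity config
        format: str = "dense"

    cfg = DummyConfig.model_validate({"format": "sparse-24-bitmask"})  # was .parse_obj() in v1
    payload = cfg.model_dump()                                         # was .dict() in v1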
@@ -282,8 +283,14 @@ class ModelCompressor:
             )
 
         if self.sparsity_compressor is not None:
+            sparse_compression_targets: Set[str] = expand_sparse_target_names(
+                model=model,
+                targets=self.sparsity_config.targets,
+                ignore=self.sparsity_config.ignore,
+            )
             compressed_state_dict = self.sparsity_compressor.compress(
-                compressed_state_dict
+                compressed_state_dict,
+                compression_targets=sparse_compression_targets,
             )
 
         # HACK: Override the dtype_byte_size function in transformers to
@@ -301,23 +308,44 @@ class ModelCompressor:
         :param model: pytorch model to load decompressed weights into
         """
         model_path = get_safetensors_folder(model_path)
-        if self.sparsity_compressor is not None:
+        sparse_decompressed = False
+
+        if (
+            self.sparsity_compressor is not None
+            and self.sparsity_config.format != CompressionFormat.dense.value
+        ):
+            # Sparse decompression is applied on the model_path
             dense_gen = self.sparsity_compressor.decompress(model_path)
             self._replace_weights(dense_gen, model)
             setattr(model, SPARSITY_CONFIG_NAME, self.sparsity_compressor.config)
+            sparse_decompressed = True
 
         if self.quantization_compressor is not None:
-            names_to_scheme = apply_quantization_config(model, self.quantization_config)
-            load_pretrained_quantization(model, model_path)
+            # Temporarily set quantization status to FROZEN to prevent
+            # quantization during apply_quantization_config. This ensures
+            # that the dtypes of the weights are not unintentionally updated.
+            # The status is restored after quantization params are loaded.
+            with override_quantization_status(
+                self.quantization_config, QuantizationStatus.FROZEN
+            ):
+                names_to_scheme = apply_quantization_config(
+                    model, self.quantization_config
+                )
+                load_pretrained_quantization(model, model_path)
+
+            model_path_or_state_dict = (
+                model.state_dict() if sparse_decompressed else model_path
+            )
+
             dense_gen = self.quantization_compressor.decompress(
-                model_path, names_to_scheme=names_to_scheme
+                model_path_or_state_dict, names_to_scheme=names_to_scheme
             )
             self._replace_weights(dense_gen, model)
 
-            def update_status(module):
+            def freeze_quantization_status(module):
                 module.quantization_status = QuantizationStatus.FROZEN
 
-            model.apply(update_status)
+            model.apply(freeze_quantization_status)
             setattr(model, QUANTIZATION_CONFIG_NAME, self.quantization_config)
 
     def update_config(self, save_directory: str):
@@ -367,12 +395,26 @@ class ModelCompressor:
         with open(config_file_path, "w") as config_file:
             json.dump(config_data, config_file, indent=2, sort_keys=True)
 
-    def _replace_weights(self, dense_weight_generator, model):
+    def _replace_weights(self, dense_weight_generator, model: Module):
+        """
+        Replace the weights of the model with the
+        provided dense weights.
+
+        This method iterates over the dense_weight_generator and
+        updates the corresponding weights in the model. If a parameter
+        name does not exist in the model, it will be skipped.
+
+        :param dense_weight_generator (generator): A generator that yields
+            tuples of (name, data), where 'name' is the parameter name and
+            'data' is the updated param data
+        :param model: The model whose weights are to be updated.
+        """
         for name, data in tqdm(dense_weight_generator, desc="Decompressing model"):
             split_name = name.split(".")
             prefix, param_name = ".".join(split_name[:-1]), split_name[-1]
             module = operator.attrgetter(prefix)(model)
-            update_parameter_data(module, data, param_name)
+            if hasattr(module, param_name):
+                update_parameter_data(module, data, param_name)
 
 
 def map_modules_to_quant_args(model: Module) -> Dict[str, QuantizationArgs]:
@@ -402,3 +444,23 @@ def new_dtype_byte_size(dtype):
         raise ValueError(f"`dtype` is not a valid dtype: {dtype}.")
     bit_size = int(bit_search.groups()[0])
     return bit_size // 8
+
+
+@contextmanager
+def override_quantization_status(
+    config: QuantizationConfig, status: QuantizationStatus
+):
+    """
+    Within this context, the quantization status will be set to the
+    supplied status. After the context exits, the original status
+    will be restored.
+
+    :param config: the quantization config to override
+    :param status: the status to temporarily set
+    """
+    original_status = config.quantization_status
+    config.quantization_status = status
+    try:
+        yield
+    finally:
+        config.quantization_status = original_status
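The new context manager is module-level, so it can be exercised on its own. A minimal sketch of its behavior, using a stand-in object in place of a full `QuantizationConfig` and assuming `QuantizationStatus.INITIALIZED` as the starting value:

    from types import SimpleNamespace
    from compressed_tensors.compressors.model_compressors.model_compressor import (
        override_quantization_status,
    )
    from compressed_tensors.quantization import QuantizationStatus

    # Stand-in for a QuantizationConfig; only `quantization_status` is touched.
    cfg = SimpleNamespace(quantization_status=QuantizationStatus.INITIALIZED)

    with override_quantization_status(cfg, QuantizationStatus.FROZEN):
        # Inside the block, apply_quantization_config sees FROZEN, so weights
        # are not re-initialized or re-cast during decompression.
        assert cfg.quantization_status is QuantizationStatus.FROZEN

    # The original status is restored on exit, even if an exception was raised.
    assert cfg.quantization_status is QuantizationStatus.INITIALIZED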

compressed_tensors/compressors/quantized_compressors/base.py

@@ -13,12 +13,17 @@
 # limitations under the License.
 
 import logging
-from typing import Dict, Generator, Tuple
+from pathlib import Path
+from typing import Any, Dict, Generator, Tuple, Union
 
 import torch
 from compressed_tensors.compressors.base import BaseCompressor
 from compressed_tensors.quantization import QuantizationArgs
-from compressed_tensors.utils import get_nested_weight_mappings, merge_names
+from compressed_tensors.utils import (
+    get_nested_mappings_from_state_dict,
+    get_nested_weight_mappings,
+    merge_names,
+)
 from safetensors import safe_open
 from torch import Tensor
 from tqdm import tqdm
@@ -113,7 +118,7 @@ class BaseQuantizationCompressor(BaseCompressor):
 
     def decompress(
         self,
-        path_to_model_or_tensors: str,
+        path_to_model_or_tensors: Union[str, Path, Dict[str, Any]],
         names_to_scheme: Dict[str, QuantizationArgs],
         device: str = "cpu",
     ) -> Generator[Tuple[str, Tensor], None, None]:
@@ -121,15 +126,25 @@ class BaseQuantizationCompressor(BaseCompressor):
         Reads a compressed state dict located at path_to_model_or_tensors
         and returns a generator for sequentially decompressing back to a
         dense state dict
-
         :param path_to_model_or_tensors: path to compressed safetensors model (directory
             with one or more safetensors files) or compressed tensors file
         :param names_to_scheme: quantization args for each quantized weight
         :param device: optional device to load intermediate weights into
         :return: compressed state dict
         """
+        if isinstance(path_to_model_or_tensors, (str, Path)):
+            yield from self._decompress_from_path(
+                path_to_model_or_tensors, names_to_scheme, device
+            )
+
+        else:
+            yield from self._decompress_from_state_dict(
+                path_to_model_or_tensors, names_to_scheme
+            )
+
+    def _decompress_from_path(self, path_to_model, names_to_scheme, device):
         weight_mappings = get_nested_weight_mappings(
-            path_to_model_or_tensors, self.COMPRESSION_PARAM_NAMES
+            path_to_model, self.COMPRESSION_PARAM_NAMES
         )
         for weight_name in weight_mappings.keys():
             weight_data = {}
@@ -137,6 +152,21 @@ class BaseQuantizationCompressor(BaseCompressor):
                 full_name = merge_names(weight_name, param_name)
                 with safe_open(safe_path, framework="pt", device=device) as f:
                     weight_data[param_name] = f.get_tensor(full_name)
+            if "weight_scale" in weight_data:
+                quant_args = names_to_scheme[weight_name]
+                decompressed = self.decompress_weight(
+                    compressed_data=weight_data, quantization_args=quant_args
+                )
+                yield merge_names(weight_name, "weight"), decompressed
+
+    def _decompress_from_state_dict(self, state_dict, names_to_scheme):
+        weight_mappings = get_nested_mappings_from_state_dict(
+            state_dict, self.COMPRESSION_PARAM_NAMES
+        )
+        for weight_name in weight_mappings.keys():
+            weight_data = {}
+            for param_name, param_value in weight_mappings[weight_name].items():
+                weight_data[param_name] = param_value
 
             if "weight_scale" in weight_data:
                 quant_args = names_to_scheme[weight_name]
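The net effect is that `decompress` now dispatches on its first argument: a string or `Path` is treated as a safetensors model directory, anything else as an already-loaded state dict. A rough usage sketch (the `compressor`, `model`, and `names_to_scheme` objects are placeholders, not defined here):

    # `compressor` is a BaseQuantizationCompressor subclass instance; `names_to_scheme`
    # maps weight prefixes to their QuantizationArgs (both are placeholders in this sketch).

    # Path-based decompression, unchanged from 0.8.1:
    for name, tensor in compressor.decompress(
        "path/to/compressed-model", names_to_scheme=names_to_scheme
    ):
        print(name, tuple(tensor.shape))

    # New in 0.9.0: decompress directly from an in-memory state dict,
    # e.g. after sparse decompression has already been applied to the model.
    for name, tensor in compressor.decompress(
        model.state_dict(), names_to_scheme=names_to_scheme
    ):
        print(name, tuple(tensor.shape))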

compressed_tensors/compressors/quantized_compressors/naive_quantized.py

@@ -68,9 +68,9 @@ class NaiveQuantizationCompressor(BaseQuantizationCompressor):
         self,
         weight: Tensor,
         scale: Tensor,
+        quantization_args: QuantizationArgs,
         zero_point: Optional[Tensor] = None,
         g_idx: Optional[torch.Tensor] = None,
-        quantization_args: Optional[QuantizationArgs] = None,
         device: Optional[torch.device] = None,
     ) -> Dict[str, torch.Tensor]:
         """
@@ -78,9 +78,9 @@ class NaiveQuantizationCompressor(BaseQuantizationCompressor):
 
         :param weight: uncompressed weight tensor
         :param scale: quantization scale for weight
+        :param quantization_args: quantization parameters for weight
         :param zero_point: quantization zero point for weight
         :param g_idx: optional mapping from column index to group index
-        :param quantization_args: quantization parameters for weight
         :param device: optional device to move compressed output to
         :return: dictionary of compressed weight data
         """

compressed_tensors/compressors/quantized_compressors/pack_quantized.py

@@ -68,9 +68,9 @@ class PackedQuantizationCompressor(BaseQuantizationCompressor):
         self,
         weight: Tensor,
         scale: Tensor,
+        quantization_args: QuantizationArgs,
         zero_point: Optional[Tensor] = None,
         g_idx: Optional[torch.Tensor] = None,
-        quantization_args: Optional[QuantizationArgs] = None,
         device: Optional[torch.device] = None,
     ) -> Dict[str, torch.Tensor]:
         """
@@ -78,9 +78,9 @@ class PackedQuantizationCompressor(BaseQuantizationCompressor):
 
         :param weight: uncompressed weight tensor
         :param scale: quantization scale for weight
+        :param quantization_args: quantization parameters for weight
        :param zero_point: quantization zero point for weight
         :param g_idx: optional mapping from column index to group index
-        :param quantization_args: quantization parameters for weight
         :param device: optional device to move compressed output to
         :return: dictionary of compressed weight data
         """

compressed_tensors/compressors/sparse_compressors/__init__.py

@@ -15,4 +15,5 @@
 
 from .base import *
 from .dense import *
+from .sparse_24_bitmask import *
 from .sparse_bitmask import *

compressed_tensors/compressors/sparse_compressors/base.py

@@ -13,7 +13,7 @@
 # limitations under the License.
 
 import logging
-from typing import Dict, Generator, Tuple
+from typing import Dict, Generator, Optional, Set, Tuple
 
 from compressed_tensors.compressors.base import BaseCompressor
 from compressed_tensors.utils import get_nested_weight_mappings, merge_names
@@ -30,7 +30,8 @@ _LOGGER: logging.Logger = logging.getLogger(__name__)
 class BaseSparseCompressor(BaseCompressor):
     """
     Base class representing a sparse compression algorithm. Each child class should
-    implement compression_param_info, compress_weight and decompress_weight.
+    implement compression_param_info, compress_weight and decompress_weight; child
+    classes should also define COMPRESSION_PARAM_NAMES.
 
     Compressors support compressing/decompressing a full module state dict or a single
     quantized PyTorch leaf module.
@@ -59,11 +60,17 @@ class BaseSparseCompressor(BaseCompressor):
     :param config: config specifying compression parameters
     """
 
-    def compress(self, model_state: Dict[str, Tensor]) -> Dict[str, Tensor]:
+    def compress(
+        self,
+        model_state: Dict[str, Tensor],
+        compression_targets: Optional[Set[str]] = None,
+    ) -> Dict[str, Tensor]:
         """
         Compresses a dense state dict using bitmask compression
 
         :param model_state: state dict of uncompressed model
+        :param compression_targets: optional set of layer prefixes to compress,
+            otherwise compress all layers (for backwards compatibility)
         :return: compressed state dict
         """
         compressed_dict = {}
@@ -71,7 +78,14 @@ class BaseSparseCompressor(BaseCompressor):
             f"Compressing model with {len(model_state)} parameterized layers..."
         )
         for name, value in tqdm(model_state.items(), desc="Compressing model"):
-            compression_data = self.compress_weight(name, value)
+            if not self.should_compress(name, compression_targets):
+                compressed_dict[name] = value
+                continue
+            prefix = name
+            if prefix.endswith(".weight"):
+                prefix = prefix[: -(len(".weight"))]
+
+            compression_data = self.compress_weight(prefix, value)
             for key in compression_data.keys():
                 if key in compressed_dict:
                     _LOGGER.warn(
@@ -97,8 +111,10 @@ class BaseSparseCompressor(BaseCompressor):
         :param device: device to load decompressed weights onto
         :return: iterator for generating decompressed weights
         """
-        weight_mappings = get_nested_weight_mappings(
-            path_to_model_or_tensors, self.COMPRESSION_PARAM_NAMES
+        weight_mappings, ignored_params = get_nested_weight_mappings(
+            path_to_model_or_tensors,
+            self.COMPRESSION_PARAM_NAMES,
+            return_unmatched_params=True,
         )
         for weight_name in weight_mappings.keys():
             weight_data = {}
@@ -107,4 +123,26 @@ class BaseSparseCompressor(BaseCompressor):
                 with safe_open(safe_path, framework="pt", device=device) as f:
                     weight_data[param_name] = f.get_tensor(full_name)
             decompressed = self.decompress_weight(weight_data)
-            yield weight_name, decompressed
+            yield merge_names(weight_name, "weight"), decompressed
+
+        for ignored_param_name, safe_path in ignored_params.items():
+            with safe_open(safe_path, framework="pt", device=device) as f:
+                value = f.get_tensor(ignored_param_name)
+            yield ignored_param_name, value
+
+    @staticmethod
+    def should_compress(name: str, expanded_targets: Optional[Set[str]] = None) -> bool:
+        """
+        Check if a parameter should be compressed.
+        Currently, this only returns True for weight parameters.
+
+        :param name: name of the parameter
+        :param expanded_targets: set of layer prefixes to compress
+        :return: whether or not the parameter should be compressed
+        """
+        if expanded_targets is None:
+            return name.endswith(".weight")
+
+        return (
+            name.endswith(".weight") and name[: -(len(".weight"))] in expanded_targets
+        )
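Since `should_compress` is a staticmethod with exactly the semantics shown above, its behavior is easy to check directly; the layer names below are hypothetical:

    from compressed_tensors.compressors.sparse_compressors.base import (
        BaseSparseCompressor,
    )

    targets = {"model.layers.0.self_attn.q_proj"}  # hypothetical expanded target set

    # With explicit targets, only ".weight" params whose prefix is targeted qualify:
    assert BaseSparseCompressor.should_compress(
        "model.layers.0.self_attn.q_proj.weight", targets
    )
    assert not BaseSparseCompressor.should_compress(
        "model.layers.1.self_attn.q_proj.weight", targets
    )
    assert not BaseSparseCompressor.should_compress(
        "model.layers.0.self_attn.q_proj.bias", targets
    )

    # Without targets (the backwards-compatible path), every ".weight" param qualifies:
    assert BaseSparseCompressor.should_compress("model.layers.1.mlp.down_proj.weight")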

compressed_tensors/compressors/sparse_compressors/sparse_24_bitmask.py

@@ -0,0 +1,238 @@
+# Copyright (c) 2021 - present / Neuralmagic, Inc. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from dataclasses import dataclass
+from typing import Dict, List, Tuple, Union
+
+import torch
+from compressed_tensors.compressors.base import BaseCompressor
+from compressed_tensors.compressors.sparse_compressors.base import BaseSparseCompressor
+from compressed_tensors.config import CompressionFormat, SparsityStructure
+from compressed_tensors.quantization import FP8_DTYPE
+from compressed_tensors.utils import merge_names, pack_bitmasks, unpack_bitmasks
+from torch import Tensor
+
+
+__all__ = [
+    "Sparse24BitMaskCompressor",
+    "Sparse24BitMaskTensor",
+    "sparse24_bitmask_compress",
+    "sparse24_bitmask_decompress",
+    "get_24_bytemasks",
+]
+
+
+@BaseCompressor.register(name=CompressionFormat.sparse_24_bitmask.value)
+class Sparse24BitMaskCompressor(BaseSparseCompressor):
+    """
+    Compression for sparse models using bitmasks. Non-zero weights are stored in a 2d
+    values tensor, with their locations stored in a 2d bitmask
+    """
+
+    COMPRESSION_PARAM_NAMES = [
+        "shape",
+        "compressed",
+        "bitmask",
+    ]
+
+    def compress_weight(self, name, value):
+        bitmask_tensor = Sparse24BitMaskTensor.from_dense(
+            value, self.config.sparsity_structure
+        )
+        bitmask_dict = bitmask_tensor.dict(name_prefix=name, device="cpu")
+        return bitmask_dict
+
+    def decompress_weight(self, weight_data):
+        data = Sparse24BitMaskTensor.from_compressed_data(**weight_data)
+        decompressed = data.decompress()
+        return decompressed
+
+
+@dataclass
+class Sparse24BitMaskTensor:
+    """
+    Owns compressions and decompression for a single 2:4 sparse
+    bitmask compressed tensor.
+
+    :param shape: shape of dense tensor
+    :param compressed: 2d tensor of non-zero values
+    :param bitmask: 2d bitmask of non-zero values
+    """
+
+    shape: List[int]
+    compressed: Tensor
+    bitmask: Tensor
+
+    @staticmethod
+    def from_dense(
+        tensor: Tensor,
+        sparsity_structure: Union[SparsityStructure, str] = SparsityStructure.TWO_FOUR,
+    ) -> "Sparse24BitMaskTensor":
+        """
+        :param tensor: dense tensor to compress
+        :return: instantiated compressed tensor
+        """
+        shape = list(tensor.shape)
+        compressed, bitmask = sparse24_bitmask_compress(
+            tensor.cpu(), sparsity_structure=sparsity_structure
+        )
+        return Sparse24BitMaskTensor(
+            shape=shape,
+            compressed=compressed,
+            bitmask=bitmask,
+        )
+
+    @staticmethod
+    def from_compressed_data(
+        shape: Union[List[int], Tensor], compressed: Tensor, bitmask: Tensor
+    ) -> "Sparse24BitMaskTensor":
+        """
+        :param shape: shape of the dense tensor (can be a list or a tensor)
+        :param compressed: 2d tensor of non-zero values
+        :param bitmask: 2d bitmask of non-zero values
+        :return: instantiated Sparse24BitMaskTensor
+        """
+        if isinstance(shape, Tensor):
+            shape = shape.tolist()
+        return Sparse24BitMaskTensor(
+            shape=shape, compressed=compressed, bitmask=bitmask
+        )
+
+    def decompress(self) -> Tensor:
+        """
+        :return: reconstructed dense tensor
+        """
+        return sparse24_bitmask_decompress(self.compressed, self.bitmask, self.shape)
+
+    def curr_memory_size_bytes(self) -> int:
+        """
+        :return: size in bytes required to store compressed tensor on disk
+        """
+
+        def sizeof_tensor(a: Tensor) -> int:
+            return a.element_size() * a.nelement()
+
+        return sizeof_tensor(self.compressed) + sizeof_tensor(self.bitmask)
+
+    def dict(self, name_prefix: str, device: str = "cpu") -> Dict[str, Tensor]:
+        """
+        :param name_prefix: name of original tensor to store compressed weight as
+        :return: dict of compressed data for the stored weight
+        """
+        if name_prefix.endswith(".weight"):
+            name_prefix = name_prefix[: -len(".weight")]
+        return {
+            merge_names(name_prefix, "shape"): torch.tensor(
+                self.shape, device=device
+            ).reshape(-1, 1),
+            merge_names(name_prefix, "compressed"): self.compressed.to(device),
+            merge_names(name_prefix, "bitmask"): self.bitmask.to(device),
+        }
+
+    def __repr__(self) -> str:
+        return f"BitMaskTensor(shape={self.shape}, compressed=True)"
+
+
+def sparse24_bitmask_compress(
+    tensor: Tensor,
+    sparsity_structure: Union[SparsityStructure, str] = SparsityStructure.TWO_FOUR,
+) -> Tuple[Tensor, Tensor, Tensor]:
+    """
+    Compresses a dense tensor using bitmask compression
+
+    :param tensor: dense 2D tensor to compress
+    :param sparsity_structure: structure of sparsity in the tensor, defaults
+        to unstructured, can also be set to `2:4`
+    :return: tuple of compressed data representing tensor
+    """
+    assert len(tensor.shape) == 2, "Only 2D tensors are supported"
+    assert (
+        SparsityStructure(sparsity_structure) == SparsityStructure.TWO_FOUR
+    ), "Only 2:4 sparsity is supported"
+
+    bytemasks = get_24_bytemasks(tensor=tensor)
+
+    if tensor.dtype == FP8_DTYPE:
+        # acces raw bytes of the tensor
+        tensor_view = tensor.view(torch.int8)
+        values = tensor_view[bytemasks]
+        values = values.view(FP8_DTYPE)
+    else:
+        values = tensor[bytemasks]
+
+    num_rows, num_cols = tensor.shape
+    compressed_values = values.reshape(num_rows, num_cols // 2)
+    bitmasks_packed = pack_bitmasks(bytemasks)
+    return compressed_values, bitmasks_packed
+
+
+def sparse24_bitmask_decompress(
+    values: Tensor, bitmasks: Tensor, original_shape: torch.Size
+) -> Tensor:
+    """
+    Reconstructs a dense tensor from a compressed one
+
+    :param values: 1d tensor of non-zero values
+    :param bitmasks: 2d int8 tensor flagging locations of non-zero values in the
+        tensors original shape
+    :param original_shape: shape of the dense tensor
+    :return: decompressed dense tensor
+    """
+    bytemasks_unpacked = unpack_bitmasks(bitmasks, original_shape)
+
+    decompressed_tensor = torch.zeros(original_shape, dtype=values.dtype)
+    decompressed_tensor = decompressed_tensor.to(values.device)
+    values = values.flatten()
+    if decompressed_tensor.dtype == FP8_DTYPE:
+        decompressed_tensor[bytemasks_unpacked] = values
+        decompressed_tensor = decompressed_tensor.cuda()
+    else:
+        decompressed_tensor[bytemasks_unpacked] = values
+    return decompressed_tensor
+
+
+def get_24_bytemasks(tensor):
+    """
+    Generate a 2:4 sparsity mask for the given tensor.
+
+    This function creates a mask where exactly 2 out of every 4 elements are
+    preserved based on their magnitudes. The preserved elements are the ones
+    with the highest absolute values in each group of 4 elements.
+
+    :param tensor: The input tensor for which the 2:4 sparsity mask is to be created.
+        The tensor can be of any shape but its total number of elements
+        must be a multiple of 4.
+    :return: A boolean tensor of the same shape as the input tensor, where `True`
+        indicates the preserved elements and `False` indicates the pruned elements.
+    :raises ValueError: If the total number of elements in the tensor is not a
+        multiple of 4.
+    """
+    original_dtype = tensor.dtype
+    if tensor.dtype == FP8_DTYPE:
+        tensor = tensor.view(torch.int8)
+    original_shape = tensor.shape
+    num_elements = tensor.numel()
+
+    if num_elements % 4 != 0:
+        raise ValueError("Tensor size must be a multiple of 4 for TWO_FOUR sparsity")
+
+    reshaped_tensor = tensor.view(-1, 4)
+    abs_tensor = reshaped_tensor.abs()
+    topk_indices = abs_tensor.topk(2, dim=1).indices
+    mask = torch.zeros_like(reshaped_tensor, dtype=torch.bool)
+    mask.scatter_(1, topk_indices, True)
+    mask = mask.view(original_shape)
+    tensor = tensor.view(original_dtype)
+
+    return mask
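For a concrete feel for the new format, a small round trip through the helpers above; this sketch assumes compressed-tensors 0.9.0 is installed and only exercises the code shown in this file:

    import torch
    from compressed_tensors.compressors.sparse_compressors.sparse_24_bitmask import (
        Sparse24BitMaskTensor,
        get_24_bytemasks,
    )

    # A 4x8 tensor: every group of 4 consecutive elements keeps its 2 largest magnitudes.
    dense = torch.randn(4, 8)
    mask = get_24_bytemasks(dense)
    assert mask.sum().item() == dense.numel() // 2  # exactly 2 of every 4 survive

    # Compress to (values, bitmask) and reconstruct; positions outside the mask become zero.
    compressed = Sparse24BitMaskTensor.from_dense(dense)
    restored = compressed.decompress()
    assert restored.shape == dense.shape
    assert torch.equal(restored[mask], dense[mask])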