compressed-tensors-nightly 0.8.1.20250110__tar.gz → 0.8.1.20250112__tar.gz

Sign up to get free protection for your applications and to get access to all the features.
Files changed (58) hide show
  1. {compressed-tensors-nightly-0.8.1.20250110/src/compressed_tensors_nightly.egg-info → compressed-tensors-nightly-0.8.1.20250112}/PKG-INFO +1 -1
  2. {compressed-tensors-nightly-0.8.1.20250110 → compressed-tensors-nightly-0.8.1.20250112}/src/compressed_tensors/compressors/model_compressors/model_compressor.py +4 -1
  3. {compressed-tensors-nightly-0.8.1.20250110/src/compressed_tensors/config → compressed-tensors-nightly-0.8.1.20250112/src/compressed_tensors/compressors/sparse_compressors}/__init__.py +2 -1
  4. compressed-tensors-nightly-0.8.1.20250112/src/compressed_tensors/compressors/sparse_compressors/sparse_24_bitmask.py +238 -0
  5. {compressed-tensors-nightly-0.8.1.20250110 → compressed-tensors-nightly-0.8.1.20250112}/src/compressed_tensors/compressors/sparse_compressors/sparse_bitmask.py +1 -38
  6. {compressed-tensors-nightly-0.8.1.20250110/src/compressed_tensors/compressors/sparse_compressors → compressed-tensors-nightly-0.8.1.20250112/src/compressed_tensors/config}/__init__.py +2 -1
  7. {compressed-tensors-nightly-0.8.1.20250110 → compressed-tensors-nightly-0.8.1.20250112}/src/compressed_tensors/config/base.py +1 -0
  8. compressed-tensors-nightly-0.8.1.20250112/src/compressed_tensors/config/sparse_24_bitmask.py +40 -0
  9. {compressed-tensors-nightly-0.8.1.20250110 → compressed-tensors-nightly-0.8.1.20250112}/src/compressed_tensors/utils/helpers.py +111 -1
  10. {compressed-tensors-nightly-0.8.1.20250110 → compressed-tensors-nightly-0.8.1.20250112/src/compressed_tensors_nightly.egg-info}/PKG-INFO +1 -1
  11. {compressed-tensors-nightly-0.8.1.20250110 → compressed-tensors-nightly-0.8.1.20250112}/src/compressed_tensors_nightly.egg-info/SOURCES.txt +2 -0
  12. {compressed-tensors-nightly-0.8.1.20250110 → compressed-tensors-nightly-0.8.1.20250112}/LICENSE +0 -0
  13. {compressed-tensors-nightly-0.8.1.20250110 → compressed-tensors-nightly-0.8.1.20250112}/README.md +0 -0
  14. {compressed-tensors-nightly-0.8.1.20250110 → compressed-tensors-nightly-0.8.1.20250112}/pyproject.toml +0 -0
  15. {compressed-tensors-nightly-0.8.1.20250110 → compressed-tensors-nightly-0.8.1.20250112}/setup.cfg +0 -0
  16. {compressed-tensors-nightly-0.8.1.20250110 → compressed-tensors-nightly-0.8.1.20250112}/setup.py +0 -0
  17. {compressed-tensors-nightly-0.8.1.20250110 → compressed-tensors-nightly-0.8.1.20250112}/src/compressed_tensors/__init__.py +0 -0
  18. {compressed-tensors-nightly-0.8.1.20250110 → compressed-tensors-nightly-0.8.1.20250112}/src/compressed_tensors/base.py +0 -0
  19. {compressed-tensors-nightly-0.8.1.20250110 → compressed-tensors-nightly-0.8.1.20250112}/src/compressed_tensors/compressors/__init__.py +0 -0
  20. {compressed-tensors-nightly-0.8.1.20250110 → compressed-tensors-nightly-0.8.1.20250112}/src/compressed_tensors/compressors/base.py +0 -0
  21. {compressed-tensors-nightly-0.8.1.20250110 → compressed-tensors-nightly-0.8.1.20250112}/src/compressed_tensors/compressors/helpers.py +0 -0
  22. {compressed-tensors-nightly-0.8.1.20250110 → compressed-tensors-nightly-0.8.1.20250112}/src/compressed_tensors/compressors/model_compressors/__init__.py +0 -0
  23. {compressed-tensors-nightly-0.8.1.20250110 → compressed-tensors-nightly-0.8.1.20250112}/src/compressed_tensors/compressors/quantized_compressors/__init__.py +0 -0
  24. {compressed-tensors-nightly-0.8.1.20250110 → compressed-tensors-nightly-0.8.1.20250112}/src/compressed_tensors/compressors/quantized_compressors/base.py +0 -0
  25. {compressed-tensors-nightly-0.8.1.20250110 → compressed-tensors-nightly-0.8.1.20250112}/src/compressed_tensors/compressors/quantized_compressors/naive_quantized.py +0 -0
  26. {compressed-tensors-nightly-0.8.1.20250110 → compressed-tensors-nightly-0.8.1.20250112}/src/compressed_tensors/compressors/quantized_compressors/pack_quantized.py +0 -0
  27. {compressed-tensors-nightly-0.8.1.20250110 → compressed-tensors-nightly-0.8.1.20250112}/src/compressed_tensors/compressors/sparse_compressors/base.py +0 -0
  28. {compressed-tensors-nightly-0.8.1.20250110 → compressed-tensors-nightly-0.8.1.20250112}/src/compressed_tensors/compressors/sparse_compressors/dense.py +0 -0
  29. {compressed-tensors-nightly-0.8.1.20250110 → compressed-tensors-nightly-0.8.1.20250112}/src/compressed_tensors/compressors/sparse_quantized_compressors/__init__.py +0 -0
  30. {compressed-tensors-nightly-0.8.1.20250110 → compressed-tensors-nightly-0.8.1.20250112}/src/compressed_tensors/compressors/sparse_quantized_compressors/marlin_24.py +0 -0
  31. {compressed-tensors-nightly-0.8.1.20250110 → compressed-tensors-nightly-0.8.1.20250112}/src/compressed_tensors/config/dense.py +0 -0
  32. {compressed-tensors-nightly-0.8.1.20250110 → compressed-tensors-nightly-0.8.1.20250112}/src/compressed_tensors/config/sparse_bitmask.py +0 -0
  33. {compressed-tensors-nightly-0.8.1.20250110 → compressed-tensors-nightly-0.8.1.20250112}/src/compressed_tensors/linear/__init__.py +0 -0
  34. {compressed-tensors-nightly-0.8.1.20250110 → compressed-tensors-nightly-0.8.1.20250112}/src/compressed_tensors/linear/compressed_linear.py +0 -0
  35. {compressed-tensors-nightly-0.8.1.20250110 → compressed-tensors-nightly-0.8.1.20250112}/src/compressed_tensors/quantization/__init__.py +0 -0
  36. {compressed-tensors-nightly-0.8.1.20250110 → compressed-tensors-nightly-0.8.1.20250112}/src/compressed_tensors/quantization/lifecycle/__init__.py +0 -0
  37. {compressed-tensors-nightly-0.8.1.20250110 → compressed-tensors-nightly-0.8.1.20250112}/src/compressed_tensors/quantization/lifecycle/apply.py +0 -0
  38. {compressed-tensors-nightly-0.8.1.20250110 → compressed-tensors-nightly-0.8.1.20250112}/src/compressed_tensors/quantization/lifecycle/compressed.py +0 -0
  39. {compressed-tensors-nightly-0.8.1.20250110 → compressed-tensors-nightly-0.8.1.20250112}/src/compressed_tensors/quantization/lifecycle/forward.py +0 -0
  40. {compressed-tensors-nightly-0.8.1.20250110 → compressed-tensors-nightly-0.8.1.20250112}/src/compressed_tensors/quantization/lifecycle/helpers.py +0 -0
  41. {compressed-tensors-nightly-0.8.1.20250110 → compressed-tensors-nightly-0.8.1.20250112}/src/compressed_tensors/quantization/lifecycle/initialize.py +0 -0
  42. {compressed-tensors-nightly-0.8.1.20250110 → compressed-tensors-nightly-0.8.1.20250112}/src/compressed_tensors/quantization/quant_args.py +0 -0
  43. {compressed-tensors-nightly-0.8.1.20250110 → compressed-tensors-nightly-0.8.1.20250112}/src/compressed_tensors/quantization/quant_config.py +0 -0
  44. {compressed-tensors-nightly-0.8.1.20250110 → compressed-tensors-nightly-0.8.1.20250112}/src/compressed_tensors/quantization/quant_scheme.py +0 -0
  45. {compressed-tensors-nightly-0.8.1.20250110 → compressed-tensors-nightly-0.8.1.20250112}/src/compressed_tensors/quantization/utils/__init__.py +0 -0
  46. {compressed-tensors-nightly-0.8.1.20250110 → compressed-tensors-nightly-0.8.1.20250112}/src/compressed_tensors/quantization/utils/helpers.py +0 -0
  47. {compressed-tensors-nightly-0.8.1.20250110 → compressed-tensors-nightly-0.8.1.20250112}/src/compressed_tensors/registry/__init__.py +0 -0
  48. {compressed-tensors-nightly-0.8.1.20250110 → compressed-tensors-nightly-0.8.1.20250112}/src/compressed_tensors/registry/registry.py +0 -0
  49. {compressed-tensors-nightly-0.8.1.20250110 → compressed-tensors-nightly-0.8.1.20250112}/src/compressed_tensors/utils/__init__.py +0 -0
  50. {compressed-tensors-nightly-0.8.1.20250110 → compressed-tensors-nightly-0.8.1.20250112}/src/compressed_tensors/utils/offload.py +0 -0
  51. {compressed-tensors-nightly-0.8.1.20250110 → compressed-tensors-nightly-0.8.1.20250112}/src/compressed_tensors/utils/permutations_24.py +0 -0
  52. {compressed-tensors-nightly-0.8.1.20250110 → compressed-tensors-nightly-0.8.1.20250112}/src/compressed_tensors/utils/permute.py +0 -0
  53. {compressed-tensors-nightly-0.8.1.20250110 → compressed-tensors-nightly-0.8.1.20250112}/src/compressed_tensors/utils/safetensors_load.py +0 -0
  54. {compressed-tensors-nightly-0.8.1.20250110 → compressed-tensors-nightly-0.8.1.20250112}/src/compressed_tensors/utils/semi_structured_conversions.py +0 -0
  55. {compressed-tensors-nightly-0.8.1.20250110 → compressed-tensors-nightly-0.8.1.20250112}/src/compressed_tensors/version.py +0 -0
  56. {compressed-tensors-nightly-0.8.1.20250110 → compressed-tensors-nightly-0.8.1.20250112}/src/compressed_tensors_nightly.egg-info/dependency_links.txt +0 -0
  57. {compressed-tensors-nightly-0.8.1.20250110 → compressed-tensors-nightly-0.8.1.20250112}/src/compressed_tensors_nightly.egg-info/requires.txt +0 -0
  58. {compressed-tensors-nightly-0.8.1.20250110 → compressed-tensors-nightly-0.8.1.20250112}/src/compressed_tensors_nightly.egg-info/top_level.txt +0 -0
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.1
2
2
  Name: compressed-tensors-nightly
3
- Version: 0.8.1.20250110
3
+ Version: 0.8.1.20250112
4
4
  Summary: Library for utilization of compressed safetensors of neural network models
5
5
  Home-page: https://github.com/neuralmagic/compressed-tensors
6
6
  Author: Neuralmagic, Inc.
@@ -310,7 +310,10 @@ class ModelCompressor:
310
310
  model_path = get_safetensors_folder(model_path)
311
311
  sparse_decompressed = False
312
312
 
313
- if self.sparsity_compressor is not None:
313
+ if (
314
+ self.sparsity_compressor is not None
315
+ and self.sparsity_config.format != CompressionFormat.dense.value
316
+ ):
314
317
  # Sparse decompression is applied on the model_path
315
318
  dense_gen = self.sparsity_compressor.decompress(model_path)
316
319
  self._replace_weights(dense_gen, model)
@@ -11,8 +11,9 @@
11
11
  # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
12
  # See the License for the specific language governing permissions and
13
13
  # limitations under the License.
14
-
15
14
  # flake8: noqa
15
+
16
16
  from .base import *
17
17
  from .dense import *
18
+ from .sparse_24_bitmask import *
18
19
  from .sparse_bitmask import *
@@ -0,0 +1,238 @@
1
+ # Copyright (c) 2021 - present / Neuralmagic, Inc. All Rights Reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing,
10
+ # software distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ from dataclasses import dataclass
16
+ from typing import Dict, List, Tuple, Union
17
+
18
+ import torch
19
+ from compressed_tensors.compressors.base import BaseCompressor
20
+ from compressed_tensors.compressors.sparse_compressors.base import BaseSparseCompressor
21
+ from compressed_tensors.config import CompressionFormat, SparsityStructure
22
+ from compressed_tensors.quantization import FP8_DTYPE
23
+ from compressed_tensors.utils import merge_names, pack_bitmasks, unpack_bitmasks
24
+ from torch import Tensor
25
+
26
+
27
+ __all__ = [
28
+ "Sparse24BitMaskCompressor",
29
+ "Sparse24BitMaskTensor",
30
+ "sparse24_bitmask_compress",
31
+ "sparse24_bitmask_decompress",
32
+ "get_24_bytemasks",
33
+ ]
34
+
35
+
36
+ @BaseCompressor.register(name=CompressionFormat.sparse_24_bitmask.value)
37
+ class Sparse24BitMaskCompressor(BaseSparseCompressor):
38
+ """
39
+ Compression for sparse models using bitmasks. Non-zero weights are stored in a 2d
40
+ values tensor, with their locations stored in a 2d bitmask
41
+ """
42
+
43
+ COMPRESSION_PARAM_NAMES = [
44
+ "shape",
45
+ "compressed",
46
+ "bitmask",
47
+ ]
48
+
49
+ def compress_weight(self, name, value):
50
+ bitmask_tensor = Sparse24BitMaskTensor.from_dense(
51
+ value, self.config.sparsity_structure
52
+ )
53
+ bitmask_dict = bitmask_tensor.dict(name_prefix=name, device="cpu")
54
+ return bitmask_dict
55
+
56
+ def decompress_weight(self, weight_data):
57
+ data = Sparse24BitMaskTensor.from_compressed_data(**weight_data)
58
+ decompressed = data.decompress()
59
+ return decompressed
60
+
61
+
62
+ @dataclass
63
+ class Sparse24BitMaskTensor:
64
+ """
65
+ Owns compression and decompression for a single 2:4 sparse
66
+ bitmask compressed tensor.
67
+
68
+ :param shape: shape of dense tensor
69
+ :param compressed: 2d tensor of non-zero values
70
+ :param bitmask: 2d bitmask of non-zero values
71
+ """
72
+
73
+ shape: List[int]
74
+ compressed: Tensor
75
+ bitmask: Tensor
76
+
77
+ @staticmethod
78
+ def from_dense(
79
+ tensor: Tensor,
80
+ sparsity_structure: Union[SparsityStructure, str] = SparsityStructure.TWO_FOUR,
81
+ ) -> "Sparse24BitMaskTensor":
82
+ """
83
+ :param tensor: dense tensor to compress
84
+ :return: instantiated compressed tensor
85
+ """
86
+ shape = list(tensor.shape)
87
+ compressed, bitmask = sparse24_bitmask_compress(
88
+ tensor.cpu(), sparsity_structure=sparsity_structure
89
+ )
90
+ return Sparse24BitMaskTensor(
91
+ shape=shape,
92
+ compressed=compressed,
93
+ bitmask=bitmask,
94
+ )
95
+
96
+ @staticmethod
97
+ def from_compressed_data(
98
+ shape: Union[List[int], Tensor], compressed: Tensor, bitmask: Tensor
99
+ ) -> "Sparse24BitMaskTensor":
100
+ """
101
+ :param shape: shape of the dense tensor (can be a list or a tensor)
102
+ :param compressed: 2d tensor of non-zero values
103
+ :param bitmask: 2d bitmask of non-zero values
104
+ :return: instantiated Sparse24BitMaskTensor
105
+ """
106
+ if isinstance(shape, Tensor):
107
+ shape = shape.tolist()
108
+ return Sparse24BitMaskTensor(
109
+ shape=shape, compressed=compressed, bitmask=bitmask
110
+ )
111
+
112
+ def decompress(self) -> Tensor:
113
+ """
114
+ :return: reconstructed dense tensor
115
+ """
116
+ return sparse24_bitmask_decompress(self.compressed, self.bitmask, self.shape)
117
+
118
+ def curr_memory_size_bytes(self) -> int:
119
+ """
120
+ :return: size in bytes required to store compressed tensor on disk
121
+ """
122
+
123
+ def sizeof_tensor(a: Tensor) -> int:
124
+ return a.element_size() * a.nelement()
125
+
126
+ return sizeof_tensor(self.compressed) + sizeof_tensor(self.bitmask)
127
+
128
+ def dict(self, name_prefix: str, device: str = "cpu") -> Dict[str, Tensor]:
129
+ """
130
+ :param name_prefix: name of original tensor to store compressed weight as
131
+ :return: dict of compressed data for the stored weight
132
+ """
133
+ if name_prefix.endswith(".weight"):
134
+ name_prefix = name_prefix[: -len(".weight")]
135
+ return {
136
+ merge_names(name_prefix, "shape"): torch.tensor(
137
+ self.shape, device=device
138
+ ).reshape(-1, 1),
139
+ merge_names(name_prefix, "compressed"): self.compressed.to(device),
140
+ merge_names(name_prefix, "bitmask"): self.bitmask.to(device),
141
+ }
142
+
143
+ def __repr__(self) -> str:
144
+ return f"BitMaskTensor(shape={self.shape}, compressed=True)"
145
+
146
+
147
+ def sparse24_bitmask_compress(
148
+ tensor: Tensor,
149
+ sparsity_structure: Union[SparsityStructure, str] = SparsityStructure.TWO_FOUR,
150
+ ) -> Tuple[Tensor, Tensor, Tensor]:
151
+ """
152
+ Compresses a dense tensor using bitmask compression
153
+
154
+ :param tensor: dense 2D tensor to compress
155
+ :param sparsity_structure: structure of sparsity in the tensor; only
156
+ `2:4` (the default) is supported
157
+ :return: tuple of compressed data representing tensor
158
+ """
159
+ assert len(tensor.shape) == 2, "Only 2D tensors are supported"
160
+ assert (
161
+ SparsityStructure(sparsity_structure) == SparsityStructure.TWO_FOUR
162
+ ), "Only 2:4 sparsity is supported"
163
+
164
+ bytemasks = get_24_bytemasks(tensor=tensor)
165
+
166
+ if tensor.dtype == FP8_DTYPE:
167
+ # access raw bytes of the tensor
168
+ tensor_view = tensor.view(torch.int8)
169
+ values = tensor_view[bytemasks]
170
+ values = values.view(FP8_DTYPE)
171
+ else:
172
+ values = tensor[bytemasks]
173
+
174
+ num_rows, num_cols = tensor.shape
175
+ compressed_values = values.reshape(num_rows, num_cols // 2)
176
+ bitmasks_packed = pack_bitmasks(bytemasks)
177
+ return compressed_values, bitmasks_packed
178
+
179
+
180
+ def sparse24_bitmask_decompress(
181
+ values: Tensor, bitmasks: Tensor, original_shape: torch.Size
182
+ ) -> Tensor:
183
+ """
184
+ Reconstructs a dense tensor from a compressed one
185
+
186
+ :param values: 2d tensor of non-zero values (flattened before use)
187
+ :param bitmasks: 2d int8 tensor flagging locations of non-zero values in the
188
+ tensors original shape
189
+ :param original_shape: shape of the dense tensor
190
+ :return: decompressed dense tensor
191
+ """
192
+ bytemasks_unpacked = unpack_bitmasks(bitmasks, original_shape)
193
+
194
+ decompressed_tensor = torch.zeros(original_shape, dtype=values.dtype)
195
+ decompressed_tensor = decompressed_tensor.to(values.device)
196
+ values = values.flatten()
197
+ if decompressed_tensor.dtype == FP8_DTYPE:
198
+ decompressed_tensor[bytemasks_unpacked] = values
199
+ decompressed_tensor = decompressed_tensor.cuda()
200
+ else:
201
+ decompressed_tensor[bytemasks_unpacked] = values
202
+ return decompressed_tensor
203
+
204
+
205
+ def get_24_bytemasks(tensor):
206
+ """
207
+ Generate a 2:4 sparsity mask for the given tensor.
208
+
209
+ This function creates a mask where exactly 2 out of every 4 elements are
210
+ preserved based on their magnitudes. The preserved elements are the ones
211
+ with the highest absolute values in each group of 4 elements.
212
+
213
+ :param tensor: The input tensor for which the 2:4 sparsity mask is to be created.
214
+ The tensor can be of any shape but its total number of elements
215
+ must be a multiple of 4.
216
+ :return: A boolean tensor of the same shape as the input tensor, where `True`
217
+ indicates the preserved elements and `False` indicates the pruned elements.
218
+ :raises ValueError: If the total number of elements in the tensor is not a
219
+ multiple of 4.
220
+ """
221
+ original_dtype = tensor.dtype
222
+ if tensor.dtype == FP8_DTYPE:
223
+ tensor = tensor.view(torch.int8)
224
+ original_shape = tensor.shape
225
+ num_elements = tensor.numel()
226
+
227
+ if num_elements % 4 != 0:
228
+ raise ValueError("Tensor size must be a multiple of 4 for TWO_FOUR sparsity")
229
+
230
+ reshaped_tensor = tensor.view(-1, 4)
231
+ abs_tensor = reshaped_tensor.abs()
232
+ topk_indices = abs_tensor.topk(2, dim=1).indices
233
+ mask = torch.zeros_like(reshaped_tensor, dtype=torch.bool)
234
+ mask.scatter_(1, topk_indices, True)
235
+ mask = mask.view(original_shape)
236
+ tensor = tensor.view(original_dtype)
237
+
238
+ return mask
@@ -14,13 +14,12 @@
14
14
 
15
15
  from typing import Dict, List, Tuple, Union
16
16
 
17
- import numpy
18
17
  import torch
19
18
  from compressed_tensors.compressors.base import BaseCompressor
20
19
  from compressed_tensors.compressors.sparse_compressors.base import BaseSparseCompressor
21
20
  from compressed_tensors.config import CompressionFormat
22
21
  from compressed_tensors.quantization import FP8_DTYPE
23
- from compressed_tensors.utils import merge_names
22
+ from compressed_tensors.utils import merge_names, pack_bitmasks, unpack_bitmasks
24
23
  from torch import Tensor
25
24
 
26
25
 
@@ -29,8 +28,6 @@ __all__ = [
29
28
  "BitmaskTensor",
30
29
  "bitmask_compress",
31
30
  "bitmask_decompress",
32
- "pack_bitmasks",
33
- "unpack_bitmasks",
34
31
  ]
35
32
 
36
33
 
@@ -164,37 +161,3 @@ def bitmask_decompress(
164
161
  decompressed_tensor[bytemasks_unpacked] = values
165
162
 
166
163
  return decompressed_tensor
167
-
168
-
169
- def pack_bitmasks(bytemasks: Tensor) -> Tensor:
170
- """
171
- Converts a bytemask tensor to a bitmask tensor to reduce memory. Shape RxC will be
172
- compressed to R x ceil(C/8)
173
- :param bytemasks: mask tensor where each byte corresponds to a weight
174
- :return: mask tensor where each bit corresounds to a weight
175
- """
176
- packed_bits_numpy = numpy.packbits(bytemasks.numpy(), axis=-1, bitorder="little")
177
- packed_bits_torch = torch.from_numpy(packed_bits_numpy)
178
-
179
- return packed_bits_torch
180
-
181
-
182
- def unpack_bitmasks(packed_bitmasks: Tensor, original_shape: torch.Size) -> Tensor:
183
- """
184
- Converts a bitmask tensor back to a bytemask tensor for use during decompression
185
-
186
- :param packed_bitmasks: mask tensor where each bit corresponds to a weight
187
- :param original_shape: dense shape to decompress to
188
- :return: boolean mask of weights in the original dense shape
189
- """
190
- # Unpack the bits
191
- unpacked_bits = numpy.unpackbits(
192
- packed_bitmasks.numpy(), axis=-1, count=original_shape[-1], bitorder="little"
193
- )
194
-
195
- # Reshape to match the original shape
196
- unpacked_bitmasks_torch = torch.from_numpy(
197
- unpacked_bits.reshape(original_shape).astype(bool)
198
- )
199
-
200
- return unpacked_bitmasks_torch
@@ -11,8 +11,9 @@
11
11
  # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
12
  # See the License for the specific language governing permissions and
13
13
  # limitations under the License.
14
- # flake8: noqa
15
14
 
15
+ # flake8: noqa
16
16
  from .base import *
17
17
  from .dense import *
18
+ from .sparse_24_bitmask import *
18
19
  from .sparse_bitmask import *
@@ -26,6 +26,7 @@ __all__ = ["SparsityCompressionConfig", "CompressionFormat", "SparsityStructure"
26
26
  class CompressionFormat(Enum):
27
27
  dense = "dense"
28
28
  sparse_bitmask = "sparse-bitmask"
29
+ sparse_24_bitmask = "sparse-24-bitmask"
29
30
  int_quantized = "int-quantized"
30
31
  float_quantized = "float-quantized"
31
32
  naive_quantized = "naive-quantized"
@@ -0,0 +1,40 @@
1
+ # Copyright (c) 2021 - present / Neuralmagic, Inc. All Rights Reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing,
10
+ # software distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ from typing import Optional
16
+
17
+ from compressed_tensors.config import (
18
+ CompressionFormat,
19
+ SparsityCompressionConfig,
20
+ SparsityStructure,
21
+ )
22
+
23
+
24
+ __all__ = ["Sparse24BitMaskConfig"]
25
+
26
+
27
+ @SparsityCompressionConfig.register(name=CompressionFormat.sparse_24_bitmask.value)
28
+ class Sparse24BitMaskConfig(SparsityCompressionConfig):
29
+ """
30
+ Configuration for storing a 2:4 sparse model using
31
+ bitmask compression
32
+
33
+ :param global_sparsity: average sparsity of the entire model
34
+ :param sparsity_structure: structure of the sparsity, should always be
35
+ "2:4" for this compression format
36
+ """
37
+
38
+ format: str = CompressionFormat.sparse_24_bitmask.value
39
+ global_sparsity: Optional[float] = 0.0
40
+ sparsity_structure: Optional[str] = SparsityStructure.TWO_FOUR.value
@@ -14,8 +14,9 @@
14
14
 
15
15
  import warnings
16
16
  from functools import wraps
17
- from typing import Any, Callable, Dict, Optional
17
+ from typing import Any, Callable, Dict, List, Optional
18
18
 
19
+ import numpy
19
20
  import torch
20
21
  from transformers import AutoConfig
21
22
 
@@ -29,6 +30,10 @@ __all__ = [
29
30
  "getattr_chain",
30
31
  "deprecated",
31
32
  "Aliasable",
33
+ "combine_shards",
34
+ "shard_tensor",
35
+ "pack_bitmasks",
36
+ "unpack_bitmasks",
32
37
  ]
33
38
 
34
39
  FSDP_WRAPPER_NAME = "_fsdp_wrapped_module"
@@ -214,3 +219,108 @@ class Aliasable:
214
219
  def __hash__(self):
215
220
  canonical_value = self.aliases.get(self.value, self.value)
216
221
  return hash(canonical_value)
222
+
223
+
224
+ def shard_tensor(
225
+ tensor: torch.Tensor, shard_sizes: List[int], dim: int = 0
226
+ ) -> List[torch.Tensor]:
227
+ """
228
+ Shards a tensor into a list of tensors along a given dimension.
229
+
230
+ :raises ValueError: If the sum of shard_sizes does not match the
231
+ size of the tensor along the given dimension.
232
+
233
+ :param tensor: The input tensor to shard.
234
+ :param shard_sizes: List of sizes for each shard along the specified dimension.
235
+ :param dim: The dimension along which to shard the tensor.
236
+ :returns: A list of tensors sharded along the specified dimension.
237
+ """
238
+ if sum(shard_sizes) != tensor.size(dim):
239
+ raise ValueError(
240
+ "Sum of shard_sizes must equal the size of the tensor "
241
+ "along the specified dimension."
242
+ )
243
+
244
+ shards = []
245
+ start_idx = 0
246
+
247
+ for size in shard_sizes:
248
+ end_idx = start_idx + size
249
+ shard = tensor.narrow(dim, start_idx, size)
250
+ shards.append(shard)
251
+ start_idx = end_idx
252
+
253
+ return shards
254
+
255
+
256
+ def combine_shards(shards, dim=0):
257
+ """
258
+ Combine decompressed shards along a given dimension using `narrow`.
259
+
260
+ :param shards: List of decompressed shard tensors.
261
+ :param dim: Dimension to combine along (default: 0).
262
+ :return: Combined decompressed tensor.
263
+ """
264
+ if not shards:
265
+ raise ValueError("The list of shards is empty.")
266
+
267
+ # Assert that all shards have the same dtype
268
+ shard_dtypes = {shard.dtype for shard in shards}
269
+ if len(shard_dtypes) > 1:
270
+ raise ValueError("All shards must have the same dtype.")
271
+
272
+ # Determine the total shape of the combined tensor
273
+ total_shape = list(shards[0].shape)
274
+ total_shape[dim] = sum(shard.shape[dim] for shard in shards)
275
+
276
+ # Create the combined tensor
277
+ combined = torch.zeros(total_shape, dtype=shards[0].dtype, device=shards[0].device)
278
+
279
+ # Fill the combined tensor using narrow
280
+ shard_offset = 0
281
+ for shard in shards:
282
+ shard_size = shard.shape[dim]
283
+ combined.narrow(dim, shard_offset, shard_size).copy_(shard)
284
+ shard_offset += shard_size
285
+
286
+ return combined
287
+
288
+
289
+ def pack_bitmasks(bytemasks: torch.Tensor) -> torch.Tensor:
290
+ """
291
+ Converts a bytemask tensor to a bitmask tensor to reduce memory. Shape RxC will be
292
+ compressed to R x ceil(C/8)
293
+
294
+ :param bytemasks: mask tensor where each byte corresponds to a weight
295
+ :return: mask tensor where each bit corresponds to a weight
296
+ """
297
+ packed_bits_numpy = numpy.packbits(bytemasks.numpy(), axis=-1, bitorder="little")
298
+ packed_bits_torch = torch.from_numpy(packed_bits_numpy)
299
+
300
+ return packed_bits_torch
301
+
302
+
303
+ def unpack_bitmasks(
304
+ packed_bitmasks: torch.Tensor, original_shape: torch.Size
305
+ ) -> torch.Tensor:
306
+ """
307
+ Converts a bitmask tensor back to a bytemask tensor for use during decompression
308
+
309
+ :param packed_bitmasks: mask tensor where each bit corresponds to a weight
310
+ :param original_shape: dense shape to decompress to
311
+ :return: boolean mask of weights in the original dense shape
312
+ """
313
+ # Unpack the bits
314
+ unpacked_bits = numpy.unpackbits(
315
+ packed_bitmasks.cpu().numpy(),
316
+ axis=-1,
317
+ count=original_shape[-1],
318
+ bitorder="little",
319
+ )
320
+
321
+ # Reshape to match the original shape
322
+ unpacked_bitmasks_torch = torch.from_numpy(
323
+ unpacked_bits.reshape(original_shape).astype(bool)
324
+ )
325
+
326
+ return unpacked_bitmasks_torch
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.1
2
2
  Name: compressed-tensors-nightly
3
- Version: 0.8.1.20250110
3
+ Version: 0.8.1.20250112
4
4
  Summary: Library for utilization of compressed safetensors of neural network models
5
5
  Home-page: https://github.com/neuralmagic/compressed-tensors
6
6
  Author: Neuralmagic, Inc.
@@ -18,12 +18,14 @@ src/compressed_tensors/compressors/quantized_compressors/pack_quantized.py
18
18
  src/compressed_tensors/compressors/sparse_compressors/__init__.py
19
19
  src/compressed_tensors/compressors/sparse_compressors/base.py
20
20
  src/compressed_tensors/compressors/sparse_compressors/dense.py
21
+ src/compressed_tensors/compressors/sparse_compressors/sparse_24_bitmask.py
21
22
  src/compressed_tensors/compressors/sparse_compressors/sparse_bitmask.py
22
23
  src/compressed_tensors/compressors/sparse_quantized_compressors/__init__.py
23
24
  src/compressed_tensors/compressors/sparse_quantized_compressors/marlin_24.py
24
25
  src/compressed_tensors/config/__init__.py
25
26
  src/compressed_tensors/config/base.py
26
27
  src/compressed_tensors/config/dense.py
28
+ src/compressed_tensors/config/sparse_24_bitmask.py
27
29
  src/compressed_tensors/config/sparse_bitmask.py
28
30
  src/compressed_tensors/linear/__init__.py
29
31
  src/compressed_tensors/linear/compressed_linear.py