compressed-tensors 0.8.1__py3-none-any.whl → 0.9.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- compressed_tensors/compressors/model_compressors/model_compressor.py +76 -14
- compressed_tensors/compressors/quantized_compressors/base.py +35 -5
- compressed_tensors/compressors/quantized_compressors/naive_quantized.py +2 -2
- compressed_tensors/compressors/quantized_compressors/pack_quantized.py +2 -2
- compressed_tensors/compressors/sparse_compressors/__init__.py +1 -0
- compressed_tensors/compressors/sparse_compressors/base.py +45 -7
- compressed_tensors/compressors/sparse_compressors/sparse_24_bitmask.py +238 -0
- compressed_tensors/compressors/sparse_compressors/sparse_bitmask.py +9 -40
- compressed_tensors/config/__init__.py +1 -0
- compressed_tensors/config/base.py +1 -0
- compressed_tensors/config/sparse_24_bitmask.py +40 -0
- compressed_tensors/quantization/lifecycle/apply.py +46 -1
- compressed_tensors/quantization/lifecycle/forward.py +2 -2
- compressed_tensors/quantization/lifecycle/initialize.py +21 -45
- compressed_tensors/quantization/quant_config.py +1 -1
- compressed_tensors/utils/helpers.py +174 -1
- compressed_tensors/utils/offload.py +332 -44
- compressed_tensors/utils/safetensors_load.py +83 -17
- compressed_tensors/version.py +1 -1
- {compressed_tensors-0.8.1.dist-info → compressed_tensors-0.9.0.dist-info}/METADATA +1 -1
- {compressed_tensors-0.8.1.dist-info → compressed_tensors-0.9.0.dist-info}/RECORD +24 -22
- {compressed_tensors-0.8.1.dist-info → compressed_tensors-0.9.0.dist-info}/LICENSE +0 -0
- {compressed_tensors-0.8.1.dist-info → compressed_tensors-0.9.0.dist-info}/WHEEL +0 -0
- {compressed_tensors-0.8.1.dist-info → compressed_tensors-0.9.0.dist-info}/top_level.txt +0 -0
compressed_tensors/compressors/sparse_compressors/sparse_bitmask.py (+9 -40)

```diff
@@ -14,12 +14,12 @@
 
 from typing import Dict, List, Tuple, Union
 
-import numpy
 import torch
 from compressed_tensors.compressors.base import BaseCompressor
 from compressed_tensors.compressors.sparse_compressors.base import BaseSparseCompressor
 from compressed_tensors.config import CompressionFormat
-from compressed_tensors.
+from compressed_tensors.quantization import FP8_DTYPE
+from compressed_tensors.utils import merge_names, pack_bitmasks, unpack_bitmasks
 from torch import Tensor
 
 
@@ -28,8 +28,6 @@ __all__ = [
     "BitmaskTensor",
     "bitmask_compress",
     "bitmask_decompress",
-    "pack_bitmasks",
-    "unpack_bitmasks",
 ]
 
 
@@ -134,9 +132,14 @@ def bitmask_compress(tensor: Tensor) -> Tuple[Tensor, Tensor, Tensor]:
     bytemasks = tensor != 0
     row_counts = bytemasks.sum(dim=-1)
     row_offsets = torch.cumsum(row_counts, 0) - row_counts
-    values = tensor[bytemasks]
+    if tensor.dtype == FP8_DTYPE:
+        # acces raw bytes of the tensor
+        tensor_view = tensor.view(torch.int8)
+        values = tensor_view[bytemasks]
+        values = values.view(FP8_DTYPE)
+    else:
+        values = tensor[bytemasks]
     bitmasks_packed = pack_bitmasks(bytemasks)
-
     return values, bitmasks_packed, row_offsets
 
 
@@ -158,37 +161,3 @@ def bitmask_decompress(
     decompressed_tensor[bytemasks_unpacked] = values
 
     return decompressed_tensor
-
-
-def pack_bitmasks(bytemasks: Tensor) -> Tensor:
-    """
-    Converts a bytemask tensor to a bitmask tensor to reduce memory. Shape RxC will be
-    compressed to R x ceil(C/8)
-    :param bytemasks: mask tensor where each byte corresponds to a weight
-    :return: mask tensor where each bit corresounds to a weight
-    """
-    packed_bits_numpy = numpy.packbits(bytemasks.numpy(), axis=-1, bitorder="little")
-    packed_bits_torch = torch.from_numpy(packed_bits_numpy)
-
-    return packed_bits_torch
-
-
-def unpack_bitmasks(packed_bitmasks: Tensor, original_shape: torch.Size) -> Tensor:
-    """
-    Converts a bitmask tensor back to a bytemask tensor for use during decompression
-
-    :param packed_bitmasks: mask tensor where each bit corresponds to a weight
-    :param original_shape: dense shape to decompress to
-    :return: boolean mask of weights in the original dense shape
-    """
-    # Unpack the bits
-    unpacked_bits = numpy.unpackbits(
-        packed_bitmasks.numpy(), axis=-1, count=original_shape[-1], bitorder="little"
-    )
-
-    # Reshape to match the original shape
-    unpacked_bitmasks_torch = torch.from_numpy(
-        unpacked_bits.reshape(original_shape).astype(bool)
-    )
-
-    return unpacked_bitmasks_torch
```
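With these changes, `pack_bitmasks`/`unpack_bitmasks` move from this module into `compressed_tensors.utils` (see the helpers.py hunks further down), and `bitmask_compress` gains an FP8 branch that indexes the tensor through an int8 view before restoring the FP8 dtype. A minimal round-trip sketch of the relocated mask helpers, assuming only the signatures shown in this diff:

```python
import torch
from compressed_tensors.utils import pack_bitmasks, unpack_bitmasks  # new home in 0.9.0

# boolean "bytemask" marking the non-zero entries of a small dense weight
dense = torch.randn(4, 16)
dense[dense.abs() < 0.5] = 0.0
bytemask = dense != 0                       # (4, 16), one byte per element

packed = pack_bitmasks(bytemask)            # (4, 2) uint8, ceil(16 / 8) bytes per row
restored = unpack_bitmasks(packed, dense.shape)

assert torch.equal(bytemask, restored)      # lossless round trip, mask is ~8x smaller
```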
compressed_tensors/config/base.py (+1 -0)

```diff
@@ -26,6 +26,7 @@ __all__ = ["SparsityCompressionConfig", "CompressionFormat", "SparsityStructure"
 class CompressionFormat(Enum):
     dense = "dense"
     sparse_bitmask = "sparse-bitmask"
+    sparse_24_bitmask = "sparse-24-bitmask"
     int_quantized = "int-quantized"
     float_quantized = "float-quantized"
     naive_quantized = "naive-quantized"
```
compressed_tensors/config/sparse_24_bitmask.py (new file, +40 -0)

```diff
@@ -0,0 +1,40 @@
+# Copyright (c) 2021 - present / Neuralmagic, Inc. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from typing import Optional
+
+from compressed_tensors.config import (
+    CompressionFormat,
+    SparsityCompressionConfig,
+    SparsityStructure,
+)
+
+
+__all__ = ["Sparse24BitMaskConfig"]
+
+
+@SparsityCompressionConfig.register(name=CompressionFormat.sparse_24_bitmask.value)
+class Sparse24BitMaskConfig(SparsityCompressionConfig):
+    """
+    Configuration for storing a 24 sparse model using
+    bytemask compression
+
+    :param global_sparsity: average sparsity of the entire model
+    :param sparsity_structure: structure of the sparsity, should always be
+        "2:4" for this compression format
+    """
+
+    format: str = CompressionFormat.sparse_24_bitmask.value
+    global_sparsity: Optional[float] = 0.0
+    sparsity_structure: Optional[str] = SparsityStructure.TWO_FOUR.value
```
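As an illustration (not taken from the package's docs), the new config can be constructed with the defaults defined above; the `0.5` global sparsity value here is arbitrary:

```python
from compressed_tensors.config import CompressionFormat
from compressed_tensors.config.sparse_24_bitmask import Sparse24BitMaskConfig

config = Sparse24BitMaskConfig(global_sparsity=0.5)

# defaults come straight from the class definition above
print(config.format)              # "sparse-24-bitmask"
print(config.sparsity_structure)  # SparsityStructure.TWO_FOUR.value, i.e. "2:4"

# the format string round-trips through the enum value added in config/base.py
assert CompressionFormat(config.format) is CompressionFormat.sparse_24_bitmask
```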
compressed_tensors/quantization/lifecycle/apply.py (+46 -1)

```diff
@@ -18,7 +18,7 @@ from collections import OrderedDict, defaultdict
 from copy import deepcopy
 from typing import Dict, Iterable, List, Optional
 from typing import OrderedDict as OrderedDictType
-from typing import Union
+from typing import Set, Union
 
 import torch
 from compressed_tensors.config import CompressionFormat
@@ -52,6 +52,8 @@ __all__ = [
     "apply_quantization_config",
     "apply_quantization_status",
     "find_name_or_class_matches",
+    "expand_sparse_target_names",
+    "is_sparse_target",
 ]
 
 from compressed_tensors.quantization.utils.helpers import is_module_quantized
@@ -245,6 +247,49 @@ def apply_quantization_status(model: Module, status: QuantizationStatus):
         model.apply(compress_quantized_weights)
 
 
+def expand_sparse_target_names(
+    model: Module, targets: Iterable[str], ignore: Iterable[str]
+) -> Set[str]:
+    """
+    Finds all unique module names in the model that match the given
+    targets and ignore lists.
+
+    Note: Targets must be regexes, layer types, or full layer names.
+
+    :param model: model to search for targets in
+    :param targets: list of targets to search for
+    :param ignore: list of targets to ignore
+    :return: set of all targets that match the given targets and should
+        not be ignored
+    """
+    return {
+        name
+        for name, module in iter_named_leaf_modules(model)
+        if is_sparse_target(name, module, targets, ignore)
+    }
+
+
+def is_sparse_target(
+    name: str, module: Module, targets: Iterable[str], ignore: Iterable[str]
+) -> bool:
+    """
+    Determines if a module should be included in the targets based on the
+    targets and ignore lists.
+
+    Note: Targets must be regexes, layer types, or full layer names.
+
+    :param name: name of the module
+    :param module: the module itself
+    :param targets: list of targets to search for
+    :param ignore: list of targets to ignore
+    :return: True if the module is a target and not ignored, False otherwise
+    """
+    return bool(
+        find_name_or_class_matches(name, module, targets)
+        and not find_name_or_class_matches(name, module, ignore or [])
+    )
+
+
 def find_name_or_class_matches(
     name: str, module: Module, targets: Iterable[str], check_contains: bool = False
 ) -> List[str]:
```
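A small usage sketch for the two new helpers. The docstrings above state that targets may be regexes, layer types, or full layer names; the toy model and target/ignore lists below are hypothetical:

```python
import torch
from compressed_tensors.quantization.lifecycle.apply import expand_sparse_target_names

model = torch.nn.Sequential(
    torch.nn.Linear(16, 16),  # leaf module named "0"
    torch.nn.ReLU(),          # leaf module named "1"
    torch.nn.Linear(16, 4),   # leaf module named "2"
)

# match every Linear layer by type, but ignore the final projection by name
targets = expand_sparse_target_names(model, targets=["Linear"], ignore=["2"])
print(targets)  # expected: {"0"}
```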
compressed_tensors/quantization/lifecycle/forward.py (+2 -2)

```diff
@@ -82,8 +82,8 @@ def quantize(
 def dequantize(
     x_q: torch.Tensor,
     scale: torch.Tensor,
-    zero_point: torch.Tensor = None,
-    args: QuantizationArgs = None,
+    zero_point: Optional[torch.Tensor] = None,
+    args: Optional[QuantizationArgs] = None,
     dtype: Optional[torch.dtype] = None,
     g_idx: Optional[torch.Tensor] = None,
 ) -> torch.Tensor:
```
compressed_tensors/quantization/lifecycle/initialize.py (+21 -45)

```diff
@@ -29,7 +29,11 @@ from compressed_tensors.quantization.quant_args import (
 from compressed_tensors.quantization.quant_config import QuantizationStatus
 from compressed_tensors.quantization.quant_scheme import QuantizationScheme
 from compressed_tensors.quantization.utils import is_kv_cache_quant_scheme
-from compressed_tensors.utils import
+from compressed_tensors.utils import (
+    disable_hf_hook,
+    has_offloaded_params,
+    register_offload_parameter,
+)
 from torch.nn import Module, Parameter
 
 
@@ -112,43 +116,10 @@ def initialize_module_for_quantization(
     module.quantization_scheme = scheme
     module.quantization_status = QuantizationStatus.INITIALIZED
 
-
-
-
-
-            from accelerate.hooks import add_hook_to_module, remove_hook_from_module
-            from accelerate.utils import PrefixedDataset
-        except ModuleNotFoundError:
-            raise ModuleNotFoundError(
-                "Offloaded model detected. To use CPU offloading with "
-                "compressed-tensors the `accelerate` package must be installed, "
-                "run `pip install compressed-tensors[accelerate]`"
-            )
-
-        offloaded = True
-        hook = module._hf_hook
-        prefix_dict = module._hf_hook.weights_map
-        new_prefix = {}
-
-        # recreate the prefix dict (since it is immutable)
-        # and add quantization parameters
-        for key, data in module.named_parameters():
-            if key not in prefix_dict:
-                new_prefix[f"{prefix_dict.prefix}{key}"] = data
-            else:
-                new_prefix[f"{prefix_dict.prefix}{key}"] = prefix_dict[key]
-        new_prefix_dict = PrefixedDataset(new_prefix, prefix_dict.prefix)
-        remove_hook_from_module(module)
-
-    # wrap forward call of module to perform
-    # quantized actions based on calltime status
-    wrap_module_forward_quantized(module, scheme)
-
-    if offloaded:
-        # we need to re-add the hook for offloading now that we've wrapped forward
-        add_hook_to_module(module, hook)
-        if prefix_dict is not None:
-            module._hf_hook.weights_map = new_prefix_dict
+    with disable_hf_hook(module):
+        # wrap forward call of module to perform
+        # quantized actions based on calltime status
+        wrap_module_forward_quantized(module, scheme)
 
 
 def is_attention_module(module: Module):
@@ -169,12 +140,17 @@ def _initialize_scale_zero_point(
     if quantization_args.dynamic:
         return
 
-    device
-
-
+    # begin on the same device as other parameters or cpu if offloaded.
+    # in the offloaded case, there's no point moving tensors to the execution device
+    # if they're going to be immediately offloaded by `register_offload_parameter`
+    params_device = next(module.parameters()).device
+    device = "cpu" if has_offloaded_params(module) else params_device
 
     # infer expected scale/zero point shape
-
+    if quantization_args.strategy == QuantizationStrategy.TOKEN:
+        expected_shape = (1, 1)
+    else:
+        expected_shape = 1
 
     if base_name == "weight" and weight_shape is not None:
         if quantization_args.strategy == QuantizationStrategy.CHANNEL:
@@ -193,7 +169,7 @@ def _initialize_scale_zero_point(
         torch.empty(expected_shape, dtype=scale_dtype, device=device),
         requires_grad=False,
     )
-    module
+    register_offload_parameter(module, f"{base_name}_scale", init_scale)
 
     if force_zero_point or not quantization_args.symmetric:
         zp_dtype = quantization_args.pytorch_dtype()
@@ -201,7 +177,7 @@ def _initialize_scale_zero_point(
             torch.zeros(expected_shape, device=device, dtype=zp_dtype),
             requires_grad=False,
         )
-        module
+        register_offload_parameter(module, f"{base_name}_zero_point", init_zero_point)
 
     # only grouped activation ordering has g_idx
     if quantization_args.actorder == ActivationOrdering.GROUP:
@@ -211,7 +187,7 @@ def _initialize_scale_zero_point(
             torch.full(g_idx_shape, -1, device=device, dtype=g_idx_dtype),
             requires_grad=False,
        )
-        module
+        register_offload_parameter(module, f"{base_name}_g_idx", init_g_idx)
 
 
 def _initialize_attn_scales(module: Module) -> None:
```
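The refactor above swaps the hand-rolled `accelerate` hook bookkeeping for three helpers from `compressed_tensors.utils` (their implementations live in the expanded `utils/offload.py`, whose hunks are not shown in this excerpt). A rough sketch of the calling pattern, using only the call shapes visible in this diff; the module and the `weight_scale` parameter name are made up:

```python
import torch
from compressed_tensors.utils import (
    disable_hf_hook,
    has_offloaded_params,
    register_offload_parameter,
)

module = torch.nn.Linear(8, 8)

# new quantization parameters start on CPU when the module is offloaded,
# otherwise on the same device as the existing parameters
device = "cpu" if has_offloaded_params(module) else next(module.parameters()).device
scale = torch.nn.Parameter(torch.zeros(1, device=device), requires_grad=False)
register_offload_parameter(module, "weight_scale", scale)

# temporarily detach any accelerate hook while the forward pass gets wrapped
with disable_hf_hook(module):
    pass  # e.g. wrap_module_forward_quantized(module, scheme)
```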
compressed_tensors/utils/helpers.py (+174 -1)

```diff
@@ -12,8 +12,11 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
-
+import warnings
+from functools import wraps
+from typing import Any, Callable, Dict, List, Optional
 
+import numpy
 import torch
 from transformers import AutoConfig
 
@@ -24,7 +27,13 @@ __all__ = [
     "tensor_follows_mask_structure",
     "replace_module",
     "is_compressed_tensors_config",
+    "getattr_chain",
+    "deprecated",
     "Aliasable",
+    "combine_shards",
+    "shard_tensor",
+    "pack_bitmasks",
+    "unpack_bitmasks",
 ]
 
 FSDP_WRAPPER_NAME = "_fsdp_wrapped_module"
@@ -122,6 +131,65 @@ def is_compressed_tensors_config(compression_config: Any) -> bool:
         return False
 
 
+def getattr_chain(obj: Any, chain_str: str, *args, **kwargs) -> Any:
+    """
+    Chain multiple getattr calls, separated by `.`
+
+    :param obj: base object whose attributes are being retrieved
+    :param chain_str: attribute names separated by `.`
+    :param default: default value, throw error otherwise
+    """
+    if len(args) >= 1:
+        has_default = True
+        default = args[0]
+    elif "default" in kwargs:
+        has_default = True
+        default = kwargs["default"]
+    else:
+        has_default = False
+
+    attr_names = chain_str.split(".")
+
+    res = obj
+    for attr_name in attr_names:
+        if not hasattr(res, attr_name):
+            if has_default:
+                return default
+            else:
+                raise AttributeError(f"{res} object has no attribute {attr_name}")
+        res = getattr(res, attr_name)
+
+    return res
+
+
+def deprecated(future_name: Optional[str] = None, message: Optional[str] = None):
+    """
+    Decorator to mark functions as deprecated
+
+    :param new_function: Function called in place of depreciated function
+    :param message: Depreciation message, replaces default depreciation message
+    """
+
+    def decorator(func: Callable[[Any], Any]):
+        nonlocal message
+
+        if message is None:
+            message = (
+                f"{func.__name__} is deprecated and will be removed in a future release"
+            )
+            if future_name is not None:
+                message += f". Please use {future_name} instead."
+
+        @wraps(func)
+        def wrapped(*args, **kwargs):
+            warnings.warn(message, DeprecationWarning, stacklevel=2)
+            return func(*args, **kwargs)
+
+        return wrapped
+
+    return decorator
+
+
 class Aliasable:
     """
     A mixin for enums to allow aliasing of enum members
```
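Brief usage sketches for the two helpers added here; the attribute chain and function names are illustrative only:

```python
from types import SimpleNamespace

from compressed_tensors.utils import deprecated, getattr_chain

# getattr_chain walks a dotted attribute path and can fall back to a default
cfg = SimpleNamespace(model=SimpleNamespace(hidden_size=768))
print(getattr_chain(cfg, "model.hidden_size"))       # 768
print(getattr_chain(cfg, "model.num_layers", None))  # None (default, no AttributeError)


@deprecated(future_name="new_helper")
def old_helper():
    return 42


old_helper()  # emits a DeprecationWarning pointing the caller at new_helper
```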
compressed_tensors/utils/helpers.py (continued)

```diff
@@ -151,3 +219,108 @@ class Aliasable:
     def __hash__(self):
         canonical_value = self.aliases.get(self.value, self.value)
         return hash(canonical_value)
+
+
+def shard_tensor(
+    tensor: torch.Tensor, shard_sizes: List[int], dim: int = 0
+) -> List[torch.Tensor]:
+    """
+    Shards a tensor into a list of tensors along a given dimension.
+
+    raises: ValueError: If the sum of shard_sizes does not match the
+        size of the tensor along the given dimension.
+
+    :param tensor: The input tensor to shard.
+    :param shard_sizes : List of sizes for each shard along the specified dimension.
+    :param dim : The dimension along which to shard the tensor.
+    :returns: A list of tensors sharded along the specified dimension.
+    """
+    if sum(shard_sizes) != tensor.size(dim):
+        raise ValueError(
+            "Sum of shard_sizes must equal the size of the tensor "
+            "along the specified dimension."
+        )
+
+    shards = []
+    start_idx = 0
+
+    for size in shard_sizes:
+        end_idx = start_idx + size
+        shard = tensor.narrow(dim, start_idx, size)
+        shards.append(shard)
+        start_idx = end_idx
+
+    return shards
+
+
+def combine_shards(shards, dim=0):
+    """
+    Combine decompressed shards along a given dimension using `narrow`.
+
+    :param shards: List of decompressed shard tensors.
+    :param dim: Dimension to combine along (default: 0).
+    :return: Combined decompressed tensor.
+    """
+    if not shards:
+        raise ValueError("The list of shards is empty.")
+
+    # Assert that all shards have the same dtype
+    shard_dtypes = {shard.dtype for shard in shards}
+    if len(shard_dtypes) > 1:
+        raise ValueError("All shards must have the same dtype.")
+
+    # Determine the total shape of the combined tensor
+    total_shape = list(shards[0].shape)
+    total_shape[dim] = sum(shard.shape[dim] for shard in shards)
+
+    # Create the combined tensor
+    combined = torch.zeros(total_shape, dtype=shards[0].dtype, device=shards[0].device)
+
+    # Fill the combined tensor using narrow
+    shard_offset = 0
+    for shard in shards:
+        shard_size = shard.shape[dim]
+        combined.narrow(dim, shard_offset, shard_size).copy_(shard)
+        shard_offset += shard_size
+
+    return combined
+
+
+def pack_bitmasks(bytemasks: torch.Tensor) -> torch.Tensor:
+    """
+    Converts a bytemask tensor to a bitmask tensor to reduce memory. Shape RxC will be
+    compressed to R x ceil(C/8)
+
+    :param bytemasks: mask tensor where each byte corresponds to a weight
+    :return: mask tensor where each bit corresounds to a weight
+    """
+    packed_bits_numpy = numpy.packbits(bytemasks.numpy(), axis=-1, bitorder="little")
+    packed_bits_torch = torch.from_numpy(packed_bits_numpy)
+
+    return packed_bits_torch
+
+
+def unpack_bitmasks(
+    packed_bitmasks: torch.Tensor, original_shape: torch.Size
+) -> torch.Tensor:
+    """
+    Converts a bitmask tensor back to a bytemask tensor for use during decompression
+
+    :param packed_bitmasks: mask tensor where each bit corresponds to a weight
+    :param original_shape: dense shape to decompress to
+    :return: boolean mask of weights in the original dense shape
+    """
+    # Unpack the bits
+    unpacked_bits = numpy.unpackbits(
+        packed_bitmasks.cpu().numpy(),
+        axis=-1,
+        count=original_shape[-1],
+        bitorder="little",
+    )
+
+    # Reshape to match the original shape
+    unpacked_bitmasks_torch = torch.from_numpy(
+        unpacked_bits.reshape(original_shape).astype(bool)
+    )
+
+    return unpacked_bitmasks_torch
```