compressed-tensors 0.8.0__py3-none-any.whl → 0.9.0__py3-none-any.whl
This diff shows the changes between publicly released versions of the package as they appear in their respective public registries. It is provided for informational purposes only.
- compressed_tensors/compressors/model_compressors/model_compressor.py +92 -18
- compressed_tensors/compressors/quantized_compressors/base.py +35 -5
- compressed_tensors/compressors/quantized_compressors/naive_quantized.py +6 -4
- compressed_tensors/compressors/quantized_compressors/pack_quantized.py +4 -2
- compressed_tensors/compressors/sparse_compressors/__init__.py +1 -0
- compressed_tensors/compressors/sparse_compressors/base.py +45 -7
- compressed_tensors/compressors/sparse_compressors/sparse_24_bitmask.py +238 -0
- compressed_tensors/compressors/sparse_compressors/sparse_bitmask.py +9 -40
- compressed_tensors/config/__init__.py +1 -0
- compressed_tensors/config/base.py +1 -0
- compressed_tensors/config/sparse_24_bitmask.py +40 -0
- compressed_tensors/linear/compressed_linear.py +3 -1
- compressed_tensors/quantization/lifecycle/apply.py +48 -2
- compressed_tensors/quantization/lifecycle/forward.py +2 -2
- compressed_tensors/quantization/lifecycle/initialize.py +21 -45
- compressed_tensors/quantization/quant_args.py +16 -3
- compressed_tensors/quantization/quant_config.py +3 -3
- compressed_tensors/quantization/quant_scheme.py +17 -24
- compressed_tensors/utils/helpers.py +206 -1
- compressed_tensors/utils/offload.py +332 -44
- compressed_tensors/utils/safetensors_load.py +83 -17
- compressed_tensors/version.py +1 -1
- {compressed_tensors-0.8.0.dist-info → compressed_tensors-0.9.0.dist-info}/METADATA +1 -1
- {compressed_tensors-0.8.0.dist-info → compressed_tensors-0.9.0.dist-info}/RECORD +27 -25
- {compressed_tensors-0.8.0.dist-info → compressed_tensors-0.9.0.dist-info}/WHEEL +1 -1
- {compressed_tensors-0.8.0.dist-info → compressed_tensors-0.9.0.dist-info}/LICENSE +0 -0
- {compressed_tensors-0.8.0.dist-info → compressed_tensors-0.9.0.dist-info}/top_level.txt +0 -0
compressed_tensors/compressors/sparse_compressors/sparse_24_bitmask.py
@@ -0,0 +1,238 @@
+# Copyright (c) 2021 - present / Neuralmagic, Inc. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#    http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from dataclasses import dataclass
+from typing import Dict, List, Tuple, Union
+
+import torch
+from compressed_tensors.compressors.base import BaseCompressor
+from compressed_tensors.compressors.sparse_compressors.base import BaseSparseCompressor
+from compressed_tensors.config import CompressionFormat, SparsityStructure
+from compressed_tensors.quantization import FP8_DTYPE
+from compressed_tensors.utils import merge_names, pack_bitmasks, unpack_bitmasks
+from torch import Tensor
+
+
+__all__ = [
+    "Sparse24BitMaskCompressor",
+    "Sparse24BitMaskTensor",
+    "sparse24_bitmask_compress",
+    "sparse24_bitmask_decompress",
+    "get_24_bytemasks",
+]
+
+
+@BaseCompressor.register(name=CompressionFormat.sparse_24_bitmask.value)
+class Sparse24BitMaskCompressor(BaseSparseCompressor):
+    """
+    Compression for sparse models using bitmasks. Non-zero weights are stored in a 2d
+    values tensor, with their locations stored in a 2d bitmask
+    """
+
+    COMPRESSION_PARAM_NAMES = [
+        "shape",
+        "compressed",
+        "bitmask",
+    ]
+
+    def compress_weight(self, name, value):
+        bitmask_tensor = Sparse24BitMaskTensor.from_dense(
+            value, self.config.sparsity_structure
+        )
+        bitmask_dict = bitmask_tensor.dict(name_prefix=name, device="cpu")
+        return bitmask_dict
+
+    def decompress_weight(self, weight_data):
+        data = Sparse24BitMaskTensor.from_compressed_data(**weight_data)
+        decompressed = data.decompress()
+        return decompressed
+
+
+@dataclass
+class Sparse24BitMaskTensor:
+    """
+    Owns compression and decompression for a single 2:4 sparse
+    bitmask compressed tensor.
+
+    :param shape: shape of dense tensor
+    :param compressed: 2d tensor of non-zero values
+    :param bitmask: 2d bitmask of non-zero values
+    """
+
+    shape: List[int]
+    compressed: Tensor
+    bitmask: Tensor
+
+    @staticmethod
+    def from_dense(
+        tensor: Tensor,
+        sparsity_structure: Union[SparsityStructure, str] = SparsityStructure.TWO_FOUR,
+    ) -> "Sparse24BitMaskTensor":
+        """
+        :param tensor: dense tensor to compress
+        :return: instantiated compressed tensor
+        """
+        shape = list(tensor.shape)
+        compressed, bitmask = sparse24_bitmask_compress(
+            tensor.cpu(), sparsity_structure=sparsity_structure
+        )
+        return Sparse24BitMaskTensor(
+            shape=shape,
+            compressed=compressed,
+            bitmask=bitmask,
+        )
+
+    @staticmethod
+    def from_compressed_data(
+        shape: Union[List[int], Tensor], compressed: Tensor, bitmask: Tensor
+    ) -> "Sparse24BitMaskTensor":
+        """
+        :param shape: shape of the dense tensor (can be a list or a tensor)
+        :param compressed: 2d tensor of non-zero values
+        :param bitmask: 2d bitmask of non-zero values
+        :return: instantiated Sparse24BitMaskTensor
+        """
+        if isinstance(shape, Tensor):
+            shape = shape.tolist()
+        return Sparse24BitMaskTensor(
+            shape=shape, compressed=compressed, bitmask=bitmask
+        )
+
+    def decompress(self) -> Tensor:
+        """
+        :return: reconstructed dense tensor
+        """
+        return sparse24_bitmask_decompress(self.compressed, self.bitmask, self.shape)
+
+    def curr_memory_size_bytes(self) -> int:
+        """
+        :return: size in bytes required to store compressed tensor on disk
+        """
+
+        def sizeof_tensor(a: Tensor) -> int:
+            return a.element_size() * a.nelement()
+
+        return sizeof_tensor(self.compressed) + sizeof_tensor(self.bitmask)
+
+    def dict(self, name_prefix: str, device: str = "cpu") -> Dict[str, Tensor]:
+        """
+        :param name_prefix: name of original tensor to store compressed weight as
+        :return: dict of compressed data for the stored weight
+        """
+        if name_prefix.endswith(".weight"):
+            name_prefix = name_prefix[: -len(".weight")]
+        return {
+            merge_names(name_prefix, "shape"): torch.tensor(
+                self.shape, device=device
+            ).reshape(-1, 1),
+            merge_names(name_prefix, "compressed"): self.compressed.to(device),
+            merge_names(name_prefix, "bitmask"): self.bitmask.to(device),
+        }
+
+    def __repr__(self) -> str:
+        return f"BitMaskTensor(shape={self.shape}, compressed=True)"
+
+
+def sparse24_bitmask_compress(
+    tensor: Tensor,
+    sparsity_structure: Union[SparsityStructure, str] = SparsityStructure.TWO_FOUR,
+) -> Tuple[Tensor, Tensor]:
+    """
+    Compresses a dense tensor using bitmask compression
+
+    :param tensor: dense 2D tensor to compress
+    :param sparsity_structure: structure of sparsity in the tensor; only
+        `2:4` is currently supported
+    :return: tuple of compressed data representing tensor
+    """
+    assert len(tensor.shape) == 2, "Only 2D tensors are supported"
+    assert (
+        SparsityStructure(sparsity_structure) == SparsityStructure.TWO_FOUR
+    ), "Only 2:4 sparsity is supported"
+
+    bytemasks = get_24_bytemasks(tensor=tensor)
+
+    if tensor.dtype == FP8_DTYPE:
+        # access raw bytes of the tensor
+        tensor_view = tensor.view(torch.int8)
+        values = tensor_view[bytemasks]
+        values = values.view(FP8_DTYPE)
+    else:
+        values = tensor[bytemasks]
+
+    num_rows, num_cols = tensor.shape
+    compressed_values = values.reshape(num_rows, num_cols // 2)
+    bitmasks_packed = pack_bitmasks(bytemasks)
+    return compressed_values, bitmasks_packed
+
+
+def sparse24_bitmask_decompress(
+    values: Tensor, bitmasks: Tensor, original_shape: torch.Size
+) -> Tensor:
+    """
+    Reconstructs a dense tensor from a compressed one
+
+    :param values: 1d tensor of non-zero values
+    :param bitmasks: 2d int8 tensor flagging locations of non-zero values in the
+        tensors original shape
+    :param original_shape: shape of the dense tensor
+    :return: decompressed dense tensor
+    """
+    bytemasks_unpacked = unpack_bitmasks(bitmasks, original_shape)
+
+    decompressed_tensor = torch.zeros(original_shape, dtype=values.dtype)
+    decompressed_tensor = decompressed_tensor.to(values.device)
+    values = values.flatten()
+    if decompressed_tensor.dtype == FP8_DTYPE:
+        decompressed_tensor[bytemasks_unpacked] = values
+        decompressed_tensor = decompressed_tensor.cuda()
+    else:
+        decompressed_tensor[bytemasks_unpacked] = values
+    return decompressed_tensor


+def get_24_bytemasks(tensor):
+    """
+    Generate a 2:4 sparsity mask for the given tensor.
+
+    This function creates a mask where exactly 2 out of every 4 elements are
+    preserved based on their magnitudes. The preserved elements are the ones
+    with the highest absolute values in each group of 4 elements.
+
+    :param tensor: The input tensor for which the 2:4 sparsity mask is to be created.
+        The tensor can be of any shape but its total number of elements
+        must be a multiple of 4.
+    :return: A boolean tensor of the same shape as the input tensor, where `True`
+        indicates the preserved elements and `False` indicates the pruned elements.
+    :raises ValueError: If the total number of elements in the tensor is not a
+        multiple of 4.
+    """
+    original_dtype = tensor.dtype
+    if tensor.dtype == FP8_DTYPE:
+        tensor = tensor.view(torch.int8)
+    original_shape = tensor.shape
+    num_elements = tensor.numel()
+
+    if num_elements % 4 != 0:
+        raise ValueError("Tensor size must be a multiple of 4 for TWO_FOUR sparsity")
+
+    reshaped_tensor = tensor.view(-1, 4)
+    abs_tensor = reshaped_tensor.abs()
+    topk_indices = abs_tensor.topk(2, dim=1).indices
+    mask = torch.zeros_like(reshaped_tensor, dtype=torch.bool)
+    mask.scatter_(1, topk_indices, True)
+    mask = mask.view(original_shape)
+    tensor = tensor.view(original_dtype)
+
+    return mask
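
Note: the new format keeps exactly 2 of every 4 weights (chosen by magnitude in get_24_bytemasks), stores them as an R x C/2 values tensor, and packs their positions 8-per-byte via pack_bitmasks. A minimal round-trip sketch, assuming a compressed-tensors 0.9.0 install; the 8x8 tensor is illustrative:

```python
import torch
from compressed_tensors.compressors.sparse_compressors.sparse_24_bitmask import (
    Sparse24BitMaskTensor,
)

dense = torch.randn(8, 8)
compressed = Sparse24BitMaskTensor.from_dense(dense)

# exactly half the values survive: 2 out of every contiguous group of 4
assert compressed.compressed.shape == (8, 4)
# positions pack 8 bool flags per byte: 8 columns -> 1 byte per row
assert compressed.bitmask.shape == (8, 1)

restored = compressed.decompress()
assert restored.shape == dense.shape
# with random floats, ties at exactly 0.0 are vanishingly unlikely,
# so the restored tensor has dense.numel() // 2 non-zeros
assert (restored != 0).sum() == dense.numel() // 2
```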
compressed_tensors/compressors/sparse_compressors/sparse_bitmask.py
@@ -14,12 +14,12 @@
 
 from typing import Dict, List, Tuple, Union
 
-import numpy
 import torch
 from compressed_tensors.compressors.base import BaseCompressor
 from compressed_tensors.compressors.sparse_compressors.base import BaseSparseCompressor
 from compressed_tensors.config import CompressionFormat
-from compressed_tensors.
+from compressed_tensors.quantization import FP8_DTYPE
+from compressed_tensors.utils import merge_names, pack_bitmasks, unpack_bitmasks
 from torch import Tensor
 
 
@@ -28,8 +28,6 @@ __all__ = [
     "BitmaskTensor",
     "bitmask_compress",
     "bitmask_decompress",
-    "pack_bitmasks",
-    "unpack_bitmasks",
 ]
 
 
@@ -134,9 +132,14 @@ def bitmask_compress(tensor: Tensor) -> Tuple[Tensor, Tensor, Tensor]:
     bytemasks = tensor != 0
     row_counts = bytemasks.sum(dim=-1)
     row_offsets = torch.cumsum(row_counts, 0) - row_counts
-
+    if tensor.dtype == FP8_DTYPE:
+        # access raw bytes of the tensor
+        tensor_view = tensor.view(torch.int8)
+        values = tensor_view[bytemasks]
+        values = values.view(FP8_DTYPE)
+    else:
+        values = tensor[bytemasks]
     bitmasks_packed = pack_bitmasks(bytemasks)
-
     return values, bitmasks_packed, row_offsets
 
 
@@ -158,37 +161,3 @@ def bitmask_decompress(
     decompressed_tensor[bytemasks_unpacked] = values
 
     return decompressed_tensor
-
-
-def pack_bitmasks(bytemasks: Tensor) -> Tensor:
-    """
-    Converts a bytemask tensor to a bitmask tensor to reduce memory. Shape RxC will be
-    compressed to R x ceil(C/8)
-    :param bytemasks: mask tensor where each byte corresponds to a weight
-    :return: mask tensor where each bit corresounds to a weight
-    """
-    packed_bits_numpy = numpy.packbits(bytemasks.numpy(), axis=-1, bitorder="little")
-    packed_bits_torch = torch.from_numpy(packed_bits_numpy)
-
-    return packed_bits_torch
-
-
-def unpack_bitmasks(packed_bitmasks: Tensor, original_shape: torch.Size) -> Tensor:
-    """
-    Converts a bitmask tensor back to a bytemask tensor for use during decompression
-
-    :param packed_bitmasks: mask tensor where each bit corresponds to a weight
-    :param original_shape: dense shape to decompress to
-    :return: boolean mask of weights in the original dense shape
-    """
-    # Unpack the bits
-    unpacked_bits = numpy.unpackbits(
-        packed_bitmasks.numpy(), axis=-1, count=original_shape[-1], bitorder="little"
-    )
-
-    # Reshape to match the original shape
-    unpacked_bitmasks_torch = torch.from_numpy(
-        unpacked_bits.reshape(original_shape).astype(bool)
-    )
-
-    return unpacked_bitmasks_torch
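
Note: pack_bitmasks and unpack_bitmasks were moved into compressed_tensors.utils (see the import change above), not deleted. A standalone sketch of the little-endian bit packing they perform, using only numpy and torch:

```python
import numpy
import torch

# one row of 9 boolean flags; each row packs into ceil(9 / 8) = 2 bytes
bytemask = torch.tensor(
    [[True, False, True, True, False, False, True, False, True]]
)
packed = numpy.packbits(bytemask.numpy(), axis=-1, bitorder="little")
assert packed.shape == (1, 2)

# count= trims the padding bits introduced by packing
unpacked = numpy.unpackbits(
    packed, axis=-1, count=bytemask.shape[-1], bitorder="little"
)
assert bool((torch.from_numpy(unpacked.astype(bool)) == bytemask).all())
```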
compressed_tensors/config/base.py
@@ -26,6 +26,7 @@ __all__ = ["SparsityCompressionConfig", "CompressionFormat", "SparsityStructure"
 class CompressionFormat(Enum):
     dense = "dense"
     sparse_bitmask = "sparse-bitmask"
+    sparse_24_bitmask = "sparse-24-bitmask"
     int_quantized = "int-quantized"
     float_quantized = "float-quantized"
     naive_quantized = "naive-quantized"
compressed_tensors/config/sparse_24_bitmask.py
@@ -0,0 +1,40 @@
+# Copyright (c) 2021 - present / Neuralmagic, Inc. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#    http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from typing import Optional
+
+from compressed_tensors.config import (
+    CompressionFormat,
+    SparsityCompressionConfig,
+    SparsityStructure,
+)
+
+
+__all__ = ["Sparse24BitMaskConfig"]
+
+
+@SparsityCompressionConfig.register(name=CompressionFormat.sparse_24_bitmask.value)
+class Sparse24BitMaskConfig(SparsityCompressionConfig):
+    """
+    Configuration for storing a 2:4 sparse model using
+    bytemask compression
+
+    :param global_sparsity: average sparsity of the entire model
+    :param sparsity_structure: structure of the sparsity, should always be
+        "2:4" for this compression format
+    """
+
+    format: str = CompressionFormat.sparse_24_bitmask.value
+    global_sparsity: Optional[float] = 0.0
+    sparsity_structure: Optional[str] = SparsityStructure.TWO_FOUR.value
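
Note: because the class is registered under CompressionFormat.sparse_24_bitmask, serialized sparsity configs with format="sparse-24-bitmask" resolve to it. A minimal construction sketch; the global_sparsity value is illustrative:

```python
from compressed_tensors.config.sparse_24_bitmask import Sparse24BitMaskConfig

config = Sparse24BitMaskConfig(global_sparsity=0.5)
assert config.format == "sparse-24-bitmask"
assert config.sparsity_structure == "2:4"  # SparsityStructure.TWO_FOUR.value
```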
compressed_tensors/linear/compressed_linear.py
@@ -12,6 +12,8 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
+from typing import Dict, Tuple
+
 import torch
 from compressed_tensors.compressors.base import BaseCompressor
 from compressed_tensors.quantization import (
@@ -53,7 +55,7 @@ class CompressedLinear(Linear):
         )
 
         # get the shape and dtype of compressed parameters
-        compression_params = module.compressor.compression_param_info(
+        compression_params: Dict[str, Tuple] = module.compressor.compression_param_info(
            module.weight.shape, quantization_scheme.weights
         )
 
compressed_tensors/quantization/lifecycle/apply.py
@@ -18,7 +18,7 @@ from collections import OrderedDict, defaultdict
 from copy import deepcopy
 from typing import Dict, Iterable, List, Optional
 from typing import OrderedDict as OrderedDictType
-from typing import Union
+from typing import Set, Union
 
 import torch
 from compressed_tensors.config import CompressionFormat
@@ -52,6 +52,8 @@ __all__ = [
     "apply_quantization_config",
     "apply_quantization_status",
     "find_name_or_class_matches",
+    "expand_sparse_target_names",
+    "is_sparse_target",
 ]
 
 from compressed_tensors.quantization.utils.helpers import is_module_quantized
@@ -106,7 +108,8 @@ def apply_quantization_config(
     model: Module, config: Union[QuantizationConfig, None], run_compressed: bool = False
 ) -> OrderedDict:
     """
-    Initializes the model for quantization in-place based on the given config
+    Initializes the model for quantization in-place based on the given config.
+    Optionally converts quantizable modules to compressed_linear modules
 
     :param model: model to apply quantization config to
     :param config: quantization config
@@ -244,6 +247,49 @@ def apply_quantization_status(model: Module, status: QuantizationStatus):
     model.apply(compress_quantized_weights)
 
 
+def expand_sparse_target_names(
+    model: Module, targets: Iterable[str], ignore: Iterable[str]
+) -> Set[str]:
+    """
+    Finds all unique module names in the model that match the given
+    targets and ignore lists.
+
+    Note: Targets must be regexes, layer types, or full layer names.
+
+    :param model: model to search for targets in
+    :param targets: list of targets to search for
+    :param ignore: list of targets to ignore
+    :return: set of all targets that match the given targets and should
+        not be ignored
+    """
+    return {
+        name
+        for name, module in iter_named_leaf_modules(model)
+        if is_sparse_target(name, module, targets, ignore)
+    }
+
+
+def is_sparse_target(
+    name: str, module: Module, targets: Iterable[str], ignore: Iterable[str]
+) -> bool:
+    """
+    Determines if a module should be included in the targets based on the
+    targets and ignore lists.
+
+    Note: Targets must be regexes, layer types, or full layer names.
+
+    :param name: name of the module
+    :param module: the module itself
+    :param targets: list of targets to search for
+    :param ignore: list of targets to ignore
+    :return: True if the module is a target and not ignored, False otherwise
+    """
+    return bool(
+        find_name_or_class_matches(name, module, targets)
+        and not find_name_or_class_matches(name, module, ignore or [])
+    )
+
+
 def find_name_or_class_matches(
     name: str, module: Module, targets: Iterable[str], check_contains: bool = False
 ) -> List[str]:
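
Note: a hedged sketch of the new helpers on a toy model. It assumes find_name_or_class_matches accepts bare class names and "re:"-prefixed regexes, as elsewhere in this module; the module names "0" and "1" come from nn.Sequential:

```python
import torch
from compressed_tensors.quantization.lifecycle.apply import (
    expand_sparse_target_names,
)

model = torch.nn.Sequential(
    torch.nn.Linear(4, 4),
    torch.nn.Linear(4, 2),
)
# target every Linear, then drop any module whose name contains "1"
names = expand_sparse_target_names(
    model, targets=["Linear"], ignore=["re:.*1.*"]
)
print(sorted(names))  # expected: ["0"]
```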
compressed_tensors/quantization/lifecycle/forward.py
@@ -82,8 +82,8 @@ def quantize(
 def dequantize(
     x_q: torch.Tensor,
     scale: torch.Tensor,
-    zero_point: torch.Tensor = None,
-    args: QuantizationArgs = None,
+    zero_point: Optional[torch.Tensor] = None,
+    args: Optional[QuantizationArgs] = None,
     dtype: Optional[torch.dtype] = None,
     g_idx: Optional[torch.Tensor] = None,
 ) -> torch.Tensor:
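
Note: both arguments already defaulted to None; the annotations now say so. A minimal call relying on those defaults (illustrative values; with zero_point omitted the dequantization is symmetric, x_q * scale):

```python
import torch
from compressed_tensors.quantization.lifecycle.forward import dequantize

x_q = torch.tensor([[-2, 0, 3]], dtype=torch.int8)
scale = torch.tensor([[0.5]])
x = dequantize(x_q, scale)  # expected: tensor([[-1.0, 0.0, 1.5]])
```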
compressed_tensors/quantization/lifecycle/initialize.py
@@ -29,7 +29,11 @@ from compressed_tensors.quantization.quant_args import (
 from compressed_tensors.quantization.quant_config import QuantizationStatus
 from compressed_tensors.quantization.quant_scheme import QuantizationScheme
 from compressed_tensors.quantization.utils import is_kv_cache_quant_scheme
-from compressed_tensors.utils import
+from compressed_tensors.utils import (
+    disable_hf_hook,
+    has_offloaded_params,
+    register_offload_parameter,
+)
 from torch.nn import Module, Parameter
 
 
@@ -112,43 +116,10 @@ def initialize_module_for_quantization(
     module.quantization_scheme = scheme
     module.quantization_status = QuantizationStatus.INITIALIZED
 
-
-
-
-
-            from accelerate.hooks import add_hook_to_module, remove_hook_from_module
-            from accelerate.utils import PrefixedDataset
-        except ModuleNotFoundError:
-            raise ModuleNotFoundError(
-                "Offloaded model detected. To use CPU offloading with "
-                "compressed-tensors the `accelerate` package must be installed, "
-                "run `pip install compressed-tensors[accelerate]`"
-            )
-
-        offloaded = True
-        hook = module._hf_hook
-        prefix_dict = module._hf_hook.weights_map
-        new_prefix = {}
-
-        # recreate the prefix dict (since it is immutable)
-        # and add quantization parameters
-        for key, data in module.named_parameters():
-            if key not in prefix_dict:
-                new_prefix[f"{prefix_dict.prefix}{key}"] = data
-            else:
-                new_prefix[f"{prefix_dict.prefix}{key}"] = prefix_dict[key]
-        new_prefix_dict = PrefixedDataset(new_prefix, prefix_dict.prefix)
-        remove_hook_from_module(module)
-
-    # wrap forward call of module to perform
-    # quantized actions based on calltime status
-    wrap_module_forward_quantized(module, scheme)
-
-    if offloaded:
-        # we need to re-add the hook for offloading now that we've wrapped forward
-        add_hook_to_module(module, hook)
-        if prefix_dict is not None:
-            module._hf_hook.weights_map = new_prefix_dict
+    with disable_hf_hook(module):
+        # wrap forward call of module to perform
+        # quantized actions based on calltime status
+        wrap_module_forward_quantized(module, scheme)
 
 
 def is_attention_module(module: Module):
@@ -169,12 +140,17 @@ def _initialize_scale_zero_point(
     if quantization_args.dynamic:
         return
 
-    device
-
-
+    # begin on the same device as other parameters or cpu if offloaded.
+    # in the offloaded case, there's no point moving tensors to the execution device
+    # if they're going to be immediately offloaded by `register_offload_parameter`
+    params_device = next(module.parameters()).device
+    device = "cpu" if has_offloaded_params(module) else params_device
 
     # infer expected scale/zero point shape
-
+    if quantization_args.strategy == QuantizationStrategy.TOKEN:
+        expected_shape = (1, 1)
+    else:
+        expected_shape = 1
 
     if base_name == "weight" and weight_shape is not None:
         if quantization_args.strategy == QuantizationStrategy.CHANNEL:
@@ -193,7 +169,7 @@ def _initialize_scale_zero_point(
         torch.empty(expected_shape, dtype=scale_dtype, device=device),
         requires_grad=False,
     )
-    module
+    register_offload_parameter(module, f"{base_name}_scale", init_scale)
 
     if force_zero_point or not quantization_args.symmetric:
         zp_dtype = quantization_args.pytorch_dtype()
@@ -201,7 +177,7 @@ def _initialize_scale_zero_point(
         torch.zeros(expected_shape, device=device, dtype=zp_dtype),
         requires_grad=False,
     )
-    module
+    register_offload_parameter(module, f"{base_name}_zero_point", init_zero_point)
 
     # only grouped activation ordering has g_idx
     if quantization_args.actorder == ActivationOrdering.GROUP:
@@ -211,7 +187,7 @@ def _initialize_scale_zero_point(
         torch.full(g_idx_shape, -1, device=device, dtype=g_idx_dtype),
         requires_grad=False,
     )
-    module
+    register_offload_parameter(module, f"{base_name}_g_idx", init_g_idx)
 
 
 def _initialize_attn_scales(module: Module) -> None:
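
Note: the refactor replaces the manual accelerate hook juggling with two utilities: disable_hf_hook temporarily detaches any accelerate hook while the forward pass is wrapped, and register_offload_parameter routes new parameters into the offload map when the module is offloaded. A sketch on a plain, non-offloaded module, mirroring how the diff uses these calls:

```python
import torch
from compressed_tensors.utils import has_offloaded_params, register_offload_parameter

linear = torch.nn.Linear(4, 4)
assert not has_offloaded_params(linear)  # no accelerate hook attached

# mirrors the init_scale registration above; on a non-offloaded module
# this behaves like a plain register_parameter
scale = torch.nn.Parameter(torch.empty(1), requires_grad=False)
register_offload_parameter(linear, "weight_scale", scale)
assert hasattr(linear, "weight_scale")
```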
compressed_tensors/quantization/quant_args.py
@@ -17,6 +17,7 @@ from enum import Enum
 from typing import Any, Dict, Optional, Union
 
 import torch
+from compressed_tensors.utils import Aliasable
 from pydantic import BaseModel, Field, field_validator, model_validator
 
 
@@ -53,17 +54,29 @@ class QuantizationStrategy(str, Enum):
     TOKEN = "token"
 
 
-class ActivationOrdering(str, Enum):
+class ActivationOrdering(Aliasable, str, Enum):
     """
     Enum storing strategies for activation ordering
 
     Group: reorder groups and weight\n
-    Weight: only reorder weight, not groups. Slightly lower
-
+    Weight: only reorder weight, not groups. Slightly lower accuracy but also lower
+        latency when compared to group actorder\n
+    Dynamic: alias for Group\n
+    Static: alias for Weight\n
     """
 
     GROUP = "group"
     WEIGHT = "weight"
+    # aliases
+    DYNAMIC = "dynamic"
+    STATIC = "static"
+
+    @staticmethod
+    def get_aliases() -> Dict[str, str]:
+        return {
+            "dynamic": "group",
+            "static": "weight",
+        }
 
 
 class QuantizationArgs(BaseModel, use_enum_values=True):
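
Note: the point of the Aliasable mixin is alias-aware comparison, so the new DYNAMIC/STATIC members should compare equal to their canonical counterparts. A sketch of the intended equivalence (exact semantics come from compressed_tensors.utils.Aliasable):

```python
from compressed_tensors.quantization.quant_args import ActivationOrdering

assert ActivationOrdering("group") == ActivationOrdering.GROUP
# aliases resolve through get_aliases(): dynamic -> group, static -> weight
assert ActivationOrdering.DYNAMIC == ActivationOrdering.GROUP
assert ActivationOrdering.STATIC == ActivationOrdering.WEIGHT
```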
compressed_tensors/quantization/quant_config.py
@@ -132,9 +132,9 @@ class QuantizationConfig(BaseModel):
         `k_proj` and `v_proj` in their names. If this is not the case
         and kv_cache_scheme != None, the quantization of kv cache will fail
     :global_compression_ratio: optional informational config to report the model
-
+        compression ratio achieved by the quantization config
     :ignore: optional list of layers to ignore from config_groups. Layers in this list
-
+        are not quantized even if they match up with a target in config_groups
     """
 
     config_groups: Dict[str, Union[QuantizationScheme, List[str]]]
@@ -160,7 +160,7 @@ class QuantizationConfig(BaseModel):
 
     def to_dict(self):
         # for compatibility with HFQuantizer
-        return self.
+        return self.model_dump()
 
     @staticmethod
     def from_pretrained(
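
Note: to_dict() now delegates to pydantic v2's model_dump(). A minimal round-trip sketch; the empty config_groups and the ignore list are illustrative:

```python
from compressed_tensors.quantization.quant_config import QuantizationConfig

config = QuantizationConfig(config_groups={}, ignore=["lm_head"])
as_dict = config.to_dict()  # plain dict, consumable by HFQuantizer
assert as_dict["ignore"] == ["lm_head"]
```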
|