compressed-tensors 0.8.0__py3-none-any.whl → 0.9.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (27)
  1. compressed_tensors/compressors/model_compressors/model_compressor.py +92 -18
  2. compressed_tensors/compressors/quantized_compressors/base.py +35 -5
  3. compressed_tensors/compressors/quantized_compressors/naive_quantized.py +6 -4
  4. compressed_tensors/compressors/quantized_compressors/pack_quantized.py +4 -2
  5. compressed_tensors/compressors/sparse_compressors/__init__.py +1 -0
  6. compressed_tensors/compressors/sparse_compressors/base.py +45 -7
  7. compressed_tensors/compressors/sparse_compressors/sparse_24_bitmask.py +238 -0
  8. compressed_tensors/compressors/sparse_compressors/sparse_bitmask.py +9 -40
  9. compressed_tensors/config/__init__.py +1 -0
  10. compressed_tensors/config/base.py +1 -0
  11. compressed_tensors/config/sparse_24_bitmask.py +40 -0
  12. compressed_tensors/linear/compressed_linear.py +3 -1
  13. compressed_tensors/quantization/lifecycle/apply.py +48 -2
  14. compressed_tensors/quantization/lifecycle/forward.py +2 -2
  15. compressed_tensors/quantization/lifecycle/initialize.py +21 -45
  16. compressed_tensors/quantization/quant_args.py +16 -3
  17. compressed_tensors/quantization/quant_config.py +3 -3
  18. compressed_tensors/quantization/quant_scheme.py +17 -24
  19. compressed_tensors/utils/helpers.py +206 -1
  20. compressed_tensors/utils/offload.py +332 -44
  21. compressed_tensors/utils/safetensors_load.py +83 -17
  22. compressed_tensors/version.py +1 -1
  23. {compressed_tensors-0.8.0.dist-info → compressed_tensors-0.9.0.dist-info}/METADATA +1 -1
  24. {compressed_tensors-0.8.0.dist-info → compressed_tensors-0.9.0.dist-info}/RECORD +27 -25
  25. {compressed_tensors-0.8.0.dist-info → compressed_tensors-0.9.0.dist-info}/WHEEL +1 -1
  26. {compressed_tensors-0.8.0.dist-info → compressed_tensors-0.9.0.dist-info}/LICENSE +0 -0
  27. {compressed_tensors-0.8.0.dist-info → compressed_tensors-0.9.0.dist-info}/top_level.txt +0 -0
compressed_tensors/compressors/sparse_compressors/sparse_24_bitmask.py
@@ -0,0 +1,238 @@
+ # Copyright (c) 2021 - present / Neuralmagic, Inc. All Rights Reserved.
+ #
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ #    http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing,
+ # software distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+
+ from dataclasses import dataclass
+ from typing import Dict, List, Tuple, Union
+
+ import torch
+ from compressed_tensors.compressors.base import BaseCompressor
+ from compressed_tensors.compressors.sparse_compressors.base import BaseSparseCompressor
+ from compressed_tensors.config import CompressionFormat, SparsityStructure
+ from compressed_tensors.quantization import FP8_DTYPE
+ from compressed_tensors.utils import merge_names, pack_bitmasks, unpack_bitmasks
+ from torch import Tensor
+
+
+ __all__ = [
+     "Sparse24BitMaskCompressor",
+     "Sparse24BitMaskTensor",
+     "sparse24_bitmask_compress",
+     "sparse24_bitmask_decompress",
+     "get_24_bytemasks",
+ ]
+
+
+ @BaseCompressor.register(name=CompressionFormat.sparse_24_bitmask.value)
+ class Sparse24BitMaskCompressor(BaseSparseCompressor):
+     """
+     Compression for sparse models using bitmasks. Non-zero weights are stored in a 2d
+     values tensor, with their locations stored in a 2d bitmask
+     """
+
+     COMPRESSION_PARAM_NAMES = [
+         "shape",
+         "compressed",
+         "bitmask",
+     ]
+
+     def compress_weight(self, name, value):
+         bitmask_tensor = Sparse24BitMaskTensor.from_dense(
+             value, self.config.sparsity_structure
+         )
+         bitmask_dict = bitmask_tensor.dict(name_prefix=name, device="cpu")
+         return bitmask_dict
+
+     def decompress_weight(self, weight_data):
+         data = Sparse24BitMaskTensor.from_compressed_data(**weight_data)
+         decompressed = data.decompress()
+         return decompressed
+
+
+ @dataclass
+ class Sparse24BitMaskTensor:
+     """
+     Owns compressions and decompression for a single 2:4 sparse
+     bitmask compressed tensor.
+
+     :param shape: shape of dense tensor
+     :param compressed: 2d tensor of non-zero values
+     :param bitmask: 2d bitmask of non-zero values
+     """
+
+     shape: List[int]
+     compressed: Tensor
+     bitmask: Tensor
+
+     @staticmethod
+     def from_dense(
+         tensor: Tensor,
+         sparsity_structure: Union[SparsityStructure, str] = SparsityStructure.TWO_FOUR,
+     ) -> "Sparse24BitMaskTensor":
+         """
+         :param tensor: dense tensor to compress
+         :return: instantiated compressed tensor
+         """
+         shape = list(tensor.shape)
+         compressed, bitmask = sparse24_bitmask_compress(
+             tensor.cpu(), sparsity_structure=sparsity_structure
+         )
+         return Sparse24BitMaskTensor(
+             shape=shape,
+             compressed=compressed,
+             bitmask=bitmask,
+         )
+
+     @staticmethod
+     def from_compressed_data(
+         shape: Union[List[int], Tensor], compressed: Tensor, bitmask: Tensor
+     ) -> "Sparse24BitMaskTensor":
+         """
+         :param shape: shape of the dense tensor (can be a list or a tensor)
+         :param compressed: 2d tensor of non-zero values
+         :param bitmask: 2d bitmask of non-zero values
+         :return: instantiated Sparse24BitMaskTensor
+         """
+         if isinstance(shape, Tensor):
+             shape = shape.tolist()
+         return Sparse24BitMaskTensor(
+             shape=shape, compressed=compressed, bitmask=bitmask
+         )
+
+     def decompress(self) -> Tensor:
+         """
+         :return: reconstructed dense tensor
+         """
+         return sparse24_bitmask_decompress(self.compressed, self.bitmask, self.shape)
+
+     def curr_memory_size_bytes(self) -> int:
+         """
+         :return: size in bytes required to store compressed tensor on disk
+         """
+
+         def sizeof_tensor(a: Tensor) -> int:
+             return a.element_size() * a.nelement()
+
+         return sizeof_tensor(self.compressed) + sizeof_tensor(self.bitmask)
+
+     def dict(self, name_prefix: str, device: str = "cpu") -> Dict[str, Tensor]:
+         """
+         :param name_prefix: name of original tensor to store compressed weight as
+         :return: dict of compressed data for the stored weight
+         """
+         if name_prefix.endswith(".weight"):
+             name_prefix = name_prefix[: -len(".weight")]
+         return {
+             merge_names(name_prefix, "shape"): torch.tensor(
+                 self.shape, device=device
+             ).reshape(-1, 1),
+             merge_names(name_prefix, "compressed"): self.compressed.to(device),
+             merge_names(name_prefix, "bitmask"): self.bitmask.to(device),
+         }
+
+     def __repr__(self) -> str:
+         return f"BitMaskTensor(shape={self.shape}, compressed=True)"
+
+
+ def sparse24_bitmask_compress(
+     tensor: Tensor,
+     sparsity_structure: Union[SparsityStructure, str] = SparsityStructure.TWO_FOUR,
+ ) -> Tuple[Tensor, Tensor, Tensor]:
+     """
+     Compresses a dense tensor using bitmask compression
+
+     :param tensor: dense 2D tensor to compress
+     :param sparsity_structure: structure of sparsity in the tensor, defaults
+         to unstructured, can also be set to `2:4`
+     :return: tuple of compressed data representing tensor
+     """
+     assert len(tensor.shape) == 2, "Only 2D tensors are supported"
+     assert (
+         SparsityStructure(sparsity_structure) == SparsityStructure.TWO_FOUR
+     ), "Only 2:4 sparsity is supported"
+
+     bytemasks = get_24_bytemasks(tensor=tensor)
+
+     if tensor.dtype == FP8_DTYPE:
+         # acces raw bytes of the tensor
+         tensor_view = tensor.view(torch.int8)
+         values = tensor_view[bytemasks]
+         values = values.view(FP8_DTYPE)
+     else:
+         values = tensor[bytemasks]
+
+     num_rows, num_cols = tensor.shape
+     compressed_values = values.reshape(num_rows, num_cols // 2)
+     bitmasks_packed = pack_bitmasks(bytemasks)
+     return compressed_values, bitmasks_packed
+
+
+ def sparse24_bitmask_decompress(
+     values: Tensor, bitmasks: Tensor, original_shape: torch.Size
+ ) -> Tensor:
+     """
+     Reconstructs a dense tensor from a compressed one
+
+     :param values: 1d tensor of non-zero values
+     :param bitmasks: 2d int8 tensor flagging locations of non-zero values in the
+         tensors original shape
+     :param original_shape: shape of the dense tensor
+     :return: decompressed dense tensor
+     """
+     bytemasks_unpacked = unpack_bitmasks(bitmasks, original_shape)
+
+     decompressed_tensor = torch.zeros(original_shape, dtype=values.dtype)
+     decompressed_tensor = decompressed_tensor.to(values.device)
+     values = values.flatten()
+     if decompressed_tensor.dtype == FP8_DTYPE:
+         decompressed_tensor[bytemasks_unpacked] = values
+         decompressed_tensor = decompressed_tensor.cuda()
+     else:
+         decompressed_tensor[bytemasks_unpacked] = values
+     return decompressed_tensor
+
+
+ def get_24_bytemasks(tensor):
+     """
+     Generate a 2:4 sparsity mask for the given tensor.
+
+     This function creates a mask where exactly 2 out of every 4 elements are
+     preserved based on their magnitudes. The preserved elements are the ones
+     with the highest absolute values in each group of 4 elements.
+
+     :param tensor: The input tensor for which the 2:4 sparsity mask is to be created.
+         The tensor can be of any shape but its total number of elements
+         must be a multiple of 4.
+     :return: A boolean tensor of the same shape as the input tensor, where `True`
+         indicates the preserved elements and `False` indicates the pruned elements.
+     :raises ValueError: If the total number of elements in the tensor is not a
+         multiple of 4.
+     """
+     original_dtype = tensor.dtype
+     if tensor.dtype == FP8_DTYPE:
+         tensor = tensor.view(torch.int8)
+     original_shape = tensor.shape
+     num_elements = tensor.numel()
+
+     if num_elements % 4 != 0:
+         raise ValueError("Tensor size must be a multiple of 4 for TWO_FOUR sparsity")
+
+     reshaped_tensor = tensor.view(-1, 4)
+     abs_tensor = reshaped_tensor.abs()
+     topk_indices = abs_tensor.topk(2, dim=1).indices
+     mask = torch.zeros_like(reshaped_tensor, dtype=torch.bool)
+     mask.scatter_(1, topk_indices, True)
+     mask = mask.view(original_shape)
+     tensor = tensor.view(original_dtype)
+
+     return mask
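
For orientation, here is a minimal round-trip sketch of the new module (not part of the diff): it builds a small FP16 weight that already satisfies 2:4 sparsity, compresses it with `Sparse24BitMaskTensor.from_dense`, and reconstructs it. The import path follows the new file location above; the shape comments assume a 4x8 input.

import torch
from compressed_tensors.compressors.sparse_compressors.sparse_24_bitmask import (
    Sparse24BitMaskTensor,
)

# build a 4x8 fp16 weight that already satisfies 2:4 sparsity
dense = torch.randn(4, 8, dtype=torch.float16)
groups = dense.reshape(-1, 4)
keep = groups.abs().topk(2, dim=1).indices
mask = torch.zeros_like(groups, dtype=torch.bool).scatter_(1, keep, True)
dense = (groups * mask).reshape(4, 8)

compressed = Sparse24BitMaskTensor.from_dense(dense)
print(compressed.compressed.shape)                   # torch.Size([4, 4]): two values kept per group of four
print(compressed.bitmask.shape)                      # torch.Size([4, 1]): eight mask bits packed per byte
print(torch.equal(compressed.decompress(), dense))   # True

Because `get_24_bytemasks` keeps the top-2 magnitudes in every group of four, compressing a weight that is not already 2:4 sparse is lossy: `decompress()` returns the pruned version rather than the original.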
compressed_tensors/compressors/sparse_compressors/sparse_bitmask.py
@@ -14,12 +14,12 @@

  from typing import Dict, List, Tuple, Union

- import numpy
  import torch
  from compressed_tensors.compressors.base import BaseCompressor
  from compressed_tensors.compressors.sparse_compressors.base import BaseSparseCompressor
  from compressed_tensors.config import CompressionFormat
- from compressed_tensors.utils import merge_names
+ from compressed_tensors.quantization import FP8_DTYPE
+ from compressed_tensors.utils import merge_names, pack_bitmasks, unpack_bitmasks
  from torch import Tensor


@@ -28,8 +28,6 @@ __all__ = [
      "BitmaskTensor",
      "bitmask_compress",
      "bitmask_decompress",
-     "pack_bitmasks",
-     "unpack_bitmasks",
  ]


@@ -134,9 +132,14 @@ def bitmask_compress(tensor: Tensor) -> Tuple[Tensor, Tensor, Tensor]:
      bytemasks = tensor != 0
      row_counts = bytemasks.sum(dim=-1)
      row_offsets = torch.cumsum(row_counts, 0) - row_counts
-     values = tensor[bytemasks]
+     if tensor.dtype == FP8_DTYPE:
+         # acces raw bytes of the tensor
+         tensor_view = tensor.view(torch.int8)
+         values = tensor_view[bytemasks]
+         values = values.view(FP8_DTYPE)
+     else:
+         values = tensor[bytemasks]
      bitmasks_packed = pack_bitmasks(bytemasks)
-
      return values, bitmasks_packed, row_offsets


@@ -158,37 +161,3 @@ def bitmask_decompress(
      decompressed_tensor[bytemasks_unpacked] = values

      return decompressed_tensor
-
-
- def pack_bitmasks(bytemasks: Tensor) -> Tensor:
-     """
-     Converts a bytemask tensor to a bitmask tensor to reduce memory. Shape RxC will be
-     compressed to R x ceil(C/8)
-     :param bytemasks: mask tensor where each byte corresponds to a weight
-     :return: mask tensor where each bit corresounds to a weight
-     """
-     packed_bits_numpy = numpy.packbits(bytemasks.numpy(), axis=-1, bitorder="little")
-     packed_bits_torch = torch.from_numpy(packed_bits_numpy)
-
-     return packed_bits_torch
-
-
- def unpack_bitmasks(packed_bitmasks: Tensor, original_shape: torch.Size) -> Tensor:
-     """
-     Converts a bitmask tensor back to a bytemask tensor for use during decompression
-
-     :param packed_bitmasks: mask tensor where each bit corresponds to a weight
-     :param original_shape: dense shape to decompress to
-     :return: boolean mask of weights in the original dense shape
-     """
-     # Unpack the bits
-     unpacked_bits = numpy.unpackbits(
-         packed_bitmasks.numpy(), axis=-1, count=original_shape[-1], bitorder="little"
-     )
-
-     # Reshape to match the original shape
-     unpacked_bitmasks_torch = torch.from_numpy(
-         unpacked_bits.reshape(original_shape).astype(bool)
-     )
-
-     return unpacked_bitmasks_torch
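
The `pack_bitmasks`/`unpack_bitmasks` helpers removed here now live in `compressed_tensors.utils` (they are imported from there at the top of this file). A quick round-trip sketch, not part of the diff:

import torch
from compressed_tensors.utils import pack_bitmasks, unpack_bitmasks

mask = torch.tensor([[True, False, True, False, True, False, False, True]])
packed = pack_bitmasks(mask)                                   # shape (1, 1): 8 mask bits per uint8 byte
restored = unpack_bitmasks(packed, original_shape=mask.shape)  # back to the (1, 8) boolean mask
print(torch.equal(mask, restored))                             # True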
compressed_tensors/compressors/sparse_compressors/__init__.py
@@ -15,4 +15,5 @@
  # flake8: noqa
  from .base import *
  from .dense import *
+ from .sparse_24_bitmask import *
  from .sparse_bitmask import *
compressed_tensors/config/base.py
@@ -26,6 +26,7 @@ __all__ = ["SparsityCompressionConfig", "CompressionFormat", "SparsityStructure"
  class CompressionFormat(Enum):
      dense = "dense"
      sparse_bitmask = "sparse-bitmask"
+     sparse_24_bitmask = "sparse-24-bitmask"
      int_quantized = "int-quantized"
      float_quantized = "float-quantized"
      naive_quantized = "naive-quantized"
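
This enum value is the sparsity `format` string written into serialized configs, and it is the registry name under which `Sparse24BitMaskCompressor` and `Sparse24BitMaskConfig` register themselves. Illustrative only:

from compressed_tensors.config import CompressionFormat

fmt = CompressionFormat("sparse-24-bitmask")
print(fmt is CompressionFormat.sparse_24_bitmask)  # True: the config string resolves to the new member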
compressed_tensors/config/sparse_24_bitmask.py
@@ -0,0 +1,40 @@
+ # Copyright (c) 2021 - present / Neuralmagic, Inc. All Rights Reserved.
+ #
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ #    http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing,
+ # software distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+
+ from typing import Optional
+
+ from compressed_tensors.config import (
+     CompressionFormat,
+     SparsityCompressionConfig,
+     SparsityStructure,
+ )
+
+
+ __all__ = ["Sparse24BitMaskConfig"]
+
+
+ @SparsityCompressionConfig.register(name=CompressionFormat.sparse_24_bitmask.value)
+ class Sparse24BitMaskConfig(SparsityCompressionConfig):
+     """
+     Configuration for storing a 24 sparse model using
+     bytemask compression
+
+     :param global_sparsity: average sparsity of the entire model
+     :param sparsity_structure: structure of the sparsity, should always be
+         "2:4" for this compression format
+     """
+
+     format: str = CompressionFormat.sparse_24_bitmask.value
+     global_sparsity: Optional[float] = 0.0
+     sparsity_structure: Optional[str] = SparsityStructure.TWO_FOUR.value
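
A small usage sketch (not from the diff), assuming the new config class is re-exported from `compressed_tensors.config` via the one-line `config/__init__.py` change listed above; otherwise import it from `compressed_tensors.config.sparse_24_bitmask`:

from compressed_tensors.config import Sparse24BitMaskConfig

cfg = Sparse24BitMaskConfig(global_sparsity=0.5)
print(cfg.format)              # "sparse-24-bitmask"
print(cfg.sparsity_structure)  # expected "2:4" (SparsityStructure.TWO_FOUR.value)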
compressed_tensors/linear/compressed_linear.py
@@ -12,6 +12,8 @@
  # See the License for the specific language governing permissions and
  # limitations under the License.

+ from typing import Dict, Tuple
+
  import torch
  from compressed_tensors.compressors.base import BaseCompressor
  from compressed_tensors.quantization import (
@@ -53,7 +55,7 @@ class CompressedLinear(Linear):
          )

          # get the shape and dtype of compressed parameters
-         compression_params = module.compressor.compression_param_info(
+         compression_params: Dict[str, Tuple] = module.compressor.compression_param_info(
              module.weight.shape, quantization_scheme.weights
          )

compressed_tensors/quantization/lifecycle/apply.py
@@ -18,7 +18,7 @@ from collections import OrderedDict, defaultdict
  from copy import deepcopy
  from typing import Dict, Iterable, List, Optional
  from typing import OrderedDict as OrderedDictType
- from typing import Union
+ from typing import Set, Union

  import torch
  from compressed_tensors.config import CompressionFormat
@@ -52,6 +52,8 @@ __all__ = [
      "apply_quantization_config",
      "apply_quantization_status",
      "find_name_or_class_matches",
+     "expand_sparse_target_names",
+     "is_sparse_target",
  ]

  from compressed_tensors.quantization.utils.helpers import is_module_quantized
@@ -106,7 +108,8 @@
      model: Module, config: Union[QuantizationConfig, None], run_compressed: bool = False
  ) -> OrderedDict:
      """
-     Initializes the model for quantization in-place based on the given config
+     Initializes the model for quantization in-place based on the given config.
+     Optionally coverts quantizable modules to compressed_linear modules

      :param model: model to apply quantization config to
      :param config: quantization config
@@ -244,6 +247,49 @@ def apply_quantization_status(model: Module, status: QuantizationStatus):
      model.apply(compress_quantized_weights)


+ def expand_sparse_target_names(
+     model: Module, targets: Iterable[str], ignore: Iterable[str]
+ ) -> Set[str]:
+     """
+     Finds all unique module names in the model that match the given
+     targets and ignore lists.
+
+     Note: Targets must be regexes, layer types, or full layer names.
+
+     :param model: model to search for targets in
+     :param targets: list of targets to search for
+     :param ignore: list of targets to ignore
+     :return: set of all targets that match the given targets and should
+         not be ignored
+     """
+     return {
+         name
+         for name, module in iter_named_leaf_modules(model)
+         if is_sparse_target(name, module, targets, ignore)
+     }
+
+
+ def is_sparse_target(
+     name: str, module: Module, targets: Iterable[str], ignore: Iterable[str]
+ ) -> bool:
+     """
+     Determines if a module should be included in the targets based on the
+     targets and ignore lists.
+
+     Note: Targets must be regexes, layer types, or full layer names.
+
+     :param name: name of the module
+     :param module: the module itself
+     :param targets: list of targets to search for
+     :param ignore: list of targets to ignore
+     :return: True if the module is a target and not ignored, False otherwise
+     """
+     return bool(
+         find_name_or_class_matches(name, module, targets)
+         and not find_name_or_class_matches(name, module, ignore or [])
+     )
+
+
  def find_name_or_class_matches(
      name: str, module: Module, targets: Iterable[str], check_contains: bool = False
  ) -> List[str]:
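
A hedged sketch of how the two new helpers are intended to be used (the toy model and target strings below are illustrative, not from the diff; per the docstrings, targets may be regexes, layer types, or full layer names):

import torch.nn as nn
from compressed_tensors.quantization.lifecycle.apply import (
    expand_sparse_target_names,
    is_sparse_target,
)

class TinyModel(nn.Module):
    def __init__(self):
        super().__init__()
        self.q_proj = nn.Linear(16, 16)
        self.k_proj = nn.Linear(16, 16)
        self.lm_head = nn.Linear(16, 32)

model = TinyModel()
# match every Linear layer by type, but leave lm_head dense
sparse_names = expand_sparse_target_names(model, targets=["Linear"], ignore=["lm_head"])
print(sparse_names)  # expected {'q_proj', 'k_proj'} (set order may vary)
print(is_sparse_target("lm_head", model.lm_head, ["Linear"], ["lm_head"]))  # False: ignored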
compressed_tensors/quantization/lifecycle/forward.py
@@ -82,8 +82,8 @@ def quantize(
  def dequantize(
      x_q: torch.Tensor,
      scale: torch.Tensor,
-     zero_point: torch.Tensor = None,
-     args: QuantizationArgs = None,
+     zero_point: Optional[torch.Tensor] = None,
+     args: Optional[QuantizationArgs] = None,
      dtype: Optional[torch.dtype] = None,
      g_idx: Optional[torch.Tensor] = None,
  ) -> torch.Tensor:
compressed_tensors/quantization/lifecycle/initialize.py
@@ -29,7 +29,11 @@ from compressed_tensors.quantization.quant_args import (
  from compressed_tensors.quantization.quant_config import QuantizationStatus
  from compressed_tensors.quantization.quant_scheme import QuantizationScheme
  from compressed_tensors.quantization.utils import is_kv_cache_quant_scheme
- from compressed_tensors.utils import get_execution_device, is_module_offloaded
+ from compressed_tensors.utils import (
+     disable_hf_hook,
+     has_offloaded_params,
+     register_offload_parameter,
+ )
  from torch.nn import Module, Parameter


@@ -112,43 +116,10 @@ def initialize_module_for_quantization(
      module.quantization_scheme = scheme
      module.quantization_status = QuantizationStatus.INITIALIZED

-     offloaded = False
-     # What is this doing/why isn't this in the attn case?
-     if is_module_offloaded(module):
-         try:
-             from accelerate.hooks import add_hook_to_module, remove_hook_from_module
-             from accelerate.utils import PrefixedDataset
-         except ModuleNotFoundError:
-             raise ModuleNotFoundError(
-                 "Offloaded model detected. To use CPU offloading with "
-                 "compressed-tensors the `accelerate` package must be installed, "
-                 "run `pip install compressed-tensors[accelerate]`"
-             )
-
-         offloaded = True
-         hook = module._hf_hook
-         prefix_dict = module._hf_hook.weights_map
-         new_prefix = {}
-
-         # recreate the prefix dict (since it is immutable)
-         # and add quantization parameters
-         for key, data in module.named_parameters():
-             if key not in prefix_dict:
-                 new_prefix[f"{prefix_dict.prefix}{key}"] = data
-             else:
-                 new_prefix[f"{prefix_dict.prefix}{key}"] = prefix_dict[key]
-         new_prefix_dict = PrefixedDataset(new_prefix, prefix_dict.prefix)
-         remove_hook_from_module(module)
-
-     # wrap forward call of module to perform
-     # quantized actions based on calltime status
-     wrap_module_forward_quantized(module, scheme)
-
-     if offloaded:
-         # we need to re-add the hook for offloading now that we've wrapped forward
-         add_hook_to_module(module, hook)
-         if prefix_dict is not None:
-             module._hf_hook.weights_map = new_prefix_dict
+     with disable_hf_hook(module):
+         # wrap forward call of module to perform
+         # quantized actions based on calltime status
+         wrap_module_forward_quantized(module, scheme)


  def is_attention_module(module: Module):
@@ -169,12 +140,17 @@ def _initialize_scale_zero_point(
      if quantization_args.dynamic:
          return

-     device = next(module.parameters()).device
-     if is_module_offloaded(module):
-         device = get_execution_device(module)
+     # begin on the same device as other parameters or cpu if offloaded.
+     # in the offloaded case, there's no point moving tensors to the execution device
+     # if they're going to be immediately offloaded by `register_offload_parameter`
+     params_device = next(module.parameters()).device
+     device = "cpu" if has_offloaded_params(module) else params_device

      # infer expected scale/zero point shape
-     expected_shape = 1 # per tensor
+     if quantization_args.strategy == QuantizationStrategy.TOKEN:
+         expected_shape = (1, 1)
+     else:
+         expected_shape = 1

      if base_name == "weight" and weight_shape is not None:
          if quantization_args.strategy == QuantizationStrategy.CHANNEL:
@@ -193,7 +169,7 @@
          torch.empty(expected_shape, dtype=scale_dtype, device=device),
          requires_grad=False,
      )
-     module.register_parameter(f"{base_name}_scale", init_scale)
+     register_offload_parameter(module, f"{base_name}_scale", init_scale)

      if force_zero_point or not quantization_args.symmetric:
          zp_dtype = quantization_args.pytorch_dtype()
@@ -201,7 +177,7 @@
              torch.zeros(expected_shape, device=device, dtype=zp_dtype),
              requires_grad=False,
          )
-         module.register_parameter(f"{base_name}_zero_point", init_zero_point)
+         register_offload_parameter(module, f"{base_name}_zero_point", init_zero_point)

      # only grouped activation ordering has g_idx
      if quantization_args.actorder == ActivationOrdering.GROUP:
@@ -211,7 +187,7 @@
              torch.full(g_idx_shape, -1, device=device, dtype=g_idx_dtype),
              requires_grad=False,
          )
-         module.register_parameter(f"{base_name}_g_idx", init_g_idx)
+         register_offload_parameter(module, f"{base_name}_g_idx", init_g_idx)


  def _initialize_attn_scales(module: Module) -> None:
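
The offload-aware helpers used above come from `compressed_tensors.utils` (see the expanded utils/offload.py in this release). A minimal sketch of the registration path, not from the diff: the assumption here is that on a plain, non-offloaded module `register_offload_parameter` reduces to an ordinary `register_parameter`, while for accelerate-offloaded modules the new tensor is also placed into the module's offload map.

import torch
from torch.nn import Linear, Parameter
from compressed_tensors.utils import has_offloaded_params, register_offload_parameter

layer = Linear(16, 16)
scale = Parameter(torch.empty(1, dtype=torch.float32), requires_grad=False)

print(has_offloaded_params(layer))               # False: no accelerate hook on this module
register_offload_parameter(layer, "weight_scale", scale)  # assumed equivalent to register_parameter here
print(layer.weight_scale.shape)                  # torch.Size([1])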
compressed_tensors/quantization/quant_args.py
@@ -17,6 +17,7 @@ from enum import Enum
  from typing import Any, Dict, Optional, Union

  import torch
+ from compressed_tensors.utils import Aliasable
  from pydantic import BaseModel, Field, field_validator, model_validator


@@ -53,17 +54,29 @@ class QuantizationStrategy(str, Enum):
      TOKEN = "token"


- class ActivationOrdering(str, Enum):
+ class ActivationOrdering(Aliasable, str, Enum):
      """
      Enum storing strategies for activation ordering

      Group: reorder groups and weight\n
-     Weight: only reorder weight, not groups. Slightly lower latency and
-         accuracy compared to group actorder\n
+     Weight: only reorder weight, not groups. Slightly lower accuracy but also lower
+         latency when compared to group actorder\n
+     Dynamic: alias for Group\n
+     Static: alias for Weight\n
      """

      GROUP = "group"
      WEIGHT = "weight"
+     # aliases
+     DYNAMIC = "dynamic"
+     STATIC = "static"
+
+     @staticmethod
+     def get_aliases() -> Dict[str, str]:
+         return {
+             "dynamic": "group",
+             "static": "weight",
+         }


  class QuantizationArgs(BaseModel, use_enum_values=True):
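
The `Aliasable` mixin is added in compressed_tensors/utils/helpers.py elsewhere in this release (not shown in this section). Assuming it makes alias members compare equal to their canonical counterparts, as the docstring above implies, configs that spell the GPTQ-style options "dynamic"/"static" resolve as follows (illustrative sketch):

from compressed_tensors.quantization.quant_args import ActivationOrdering

print(ActivationOrdering("dynamic") == ActivationOrdering.GROUP)   # expected True: "dynamic" aliases "group"
print(ActivationOrdering("static") == ActivationOrdering.WEIGHT)   # expected True: "static" aliases "weight"
print(ActivationOrdering("weight") == ActivationOrdering.WEIGHT)   # True: canonical values resolve directly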
compressed_tensors/quantization/quant_config.py
@@ -132,9 +132,9 @@ class QuantizationConfig(BaseModel):
      `k_proj` and `v_proj` in their names. If this is not the case
      and kv_cache_scheme != None, the quantization of kv cache will fail
      :global_compression_ratio: optional informational config to report the model
-     compression ratio acheived by the quantization config
+         compression ratio acheived by the quantization config
      :ignore: optional list of layers to ignore from config_groups. Layers in this list
-     are not quantized even if they match up with a target in config_groups
+         are not quantized even if they match up with a target in config_groups
      """

      config_groups: Dict[str, Union[QuantizationScheme, List[str]]]
@@ -160,7 +160,7 @@

      def to_dict(self):
          # for compatibility with HFQuantizer
-         return self.dict()
+         return self.model_dump()

      @staticmethod
      def from_pretrained(