compressed-tensors-nightly 0.5.0.20240814__py3-none-any.whl → 0.5.0.20240830__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (29)
  1. compressed_tensors/compressors/base.py +198 -8
  2. compressed_tensors/compressors/model_compressor.py +65 -1
  3. compressed_tensors/compressors/naive_quantized.py +71 -75
  4. compressed_tensors/compressors/pack_quantized.py +83 -94
  5. compressed_tensors/linear/__init__.py +13 -0
  6. compressed_tensors/linear/compressed_linear.py +87 -0
  7. compressed_tensors/quantization/lifecycle/apply.py +36 -4
  8. compressed_tensors/quantization/lifecycle/calibration.py +3 -2
  9. compressed_tensors/quantization/lifecycle/compressed.py +1 -1
  10. compressed_tensors/quantization/lifecycle/forward.py +67 -43
  11. compressed_tensors/quantization/lifecycle/helpers.py +29 -2
  12. compressed_tensors/quantization/lifecycle/initialize.py +50 -16
  13. compressed_tensors/quantization/observers/__init__.py +1 -0
  14. compressed_tensors/quantization/observers/base.py +54 -14
  15. compressed_tensors/quantization/observers/min_max.py +8 -0
  16. compressed_tensors/quantization/observers/mse.py +162 -0
  17. compressed_tensors/quantization/quant_args.py +48 -20
  18. compressed_tensors/utils/__init__.py +1 -0
  19. compressed_tensors/utils/helpers.py +13 -0
  20. compressed_tensors/utils/offload.py +7 -1
  21. compressed_tensors/utils/permute.py +70 -0
  22. compressed_tensors/utils/safetensors_load.py +2 -0
  23. compressed_tensors/utils/semi_structured_conversions.py +1 -0
  24. {compressed_tensors_nightly-0.5.0.20240814.dist-info → compressed_tensors_nightly-0.5.0.20240830.dist-info}/METADATA +3 -2
  25. compressed_tensors_nightly-0.5.0.20240830.dist-info/RECORD +52 -0
  26. compressed_tensors_nightly-0.5.0.20240814.dist-info/RECORD +0 -48
  27. {compressed_tensors_nightly-0.5.0.20240814.dist-info → compressed_tensors_nightly-0.5.0.20240830.dist-info}/LICENSE +0 -0
  28. {compressed_tensors_nightly-0.5.0.20240814.dist-info → compressed_tensors_nightly-0.5.0.20240830.dist-info}/WHEEL +0 -0
  29. {compressed_tensors_nightly-0.5.0.20240814.dist-info → compressed_tensors_nightly-0.5.0.20240830.dist-info}/top_level.txt +0 -0
compressed_tensors/compressors/pack_quantized.py
@@ -11,10 +11,8 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
-
-import logging
 import math
-from typing import Dict, Generator, Tuple
+from typing import Dict, Optional, Tuple
 
 import numpy as np
 import torch
@@ -23,16 +21,11 @@ from compressed_tensors.config import CompressionFormat
 from compressed_tensors.quantization import QuantizationArgs
 from compressed_tensors.quantization.lifecycle.forward import dequantize, quantize
 from compressed_tensors.quantization.utils import can_quantize
-from compressed_tensors.utils import get_nested_weight_mappings, merge_names
-from safetensors import safe_open
 from torch import Tensor
-from tqdm import tqdm
 
 
 __all__ = ["PackedQuantizationCompressor", "pack_to_int32", "unpack_from_int32"]
 
-_LOGGER: logging.Logger = logging.getLogger(__name__)
-
 
 @Compressor.register(name=CompressionFormat.pack_quantized.value)
 class PackedQuantizationCompressor(Compressor):
@@ -44,102 +37,96 @@ class PackedQuantizationCompressor(Compressor):
         "weight_packed",
         "weight_scale",
         "weight_zero_point",
+        "weight_g_idx",
         "weight_shape",
     ]
 
-    def compress(
+    def compression_param_info(
         self,
-        model_state: Dict[str, Tensor],
-        names_to_scheme: Dict[str, QuantizationArgs],
-        **kwargs,
-    ) -> Dict[str, Tensor]:
+        weight_shape: torch.Size,
+        quantization_args: Optional[QuantizationArgs] = None,
+    ) -> Dict[str, Tuple[torch.Size, torch.dtype]]:
         """
-        Compresses a dense state dict
+        Creates a dictionary of expected shapes and dtypes for each compression
+        parameter used by the compressor
 
-        :param model_state: state dict of uncompressed model
-        :param names_to_scheme: quantization args for each quantized weight, needed for
-            quantize function to calculate bit depth
-        :return: compressed state dict
+        :param weight_shape: uncompressed weight shape
+        :param quantization_args: quantization parameters for the weight
+        :return: dictionary mapping compressed parameter names to shape and dtype
+        """
+        pack_factor = 32 // quantization_args.num_bits
+        packed_size = math.ceil(weight_shape[1] / pack_factor)
+        return {
+            "weight_packed": (torch.Size((weight_shape[0], packed_size)), torch.int32),
+            "weight_shape": (torch.Size((2,)), torch.int32),
+        }
+
+    def compress_weight(
+        self,
+        weight: Tensor,
+        scale: Tensor,
+        zero_point: Optional[Tensor] = None,
+        g_idx: Optional[torch.Tensor] = None,
+        quantization_args: Optional[QuantizationArgs] = None,
+        device: Optional[torch.device] = None,
+    ) -> Dict[str, torch.Tensor]:
+        """
+        Compresses a single uncompressed weight
+
+        :param weight: uncompressed weight tensor
+        :param scale: quantization scale for weight
+        :param zero_point: quantization zero point for weight
+        :param g_idx: optional mapping from column index to group index
+        :param quantization_args: quantization parameters for weight
+        :param device: optional device to move compressed output to
+        :return: dictionary of compressed weight data
         """
         compressed_dict = {}
-        weight_suffix = ".weight"
-        _LOGGER.debug(
-            f"Compressing model with {len(model_state)} parameterized layers..."
-        )
-
-        for name, value in tqdm(model_state.items(), desc="Compressing model"):
-            if name.endswith(weight_suffix):
-                prefix = name[: -(len(weight_suffix))]
-                scale = model_state.get(merge_names(prefix, "weight_scale"), None)
-                zp = model_state.get(merge_names(prefix, "weight_zero_point"), None)
-                shape = torch.tensor(value.shape)
-                if scale is not None and zp is not None:
-                    # weight is quantized, compress it
-                    quant_args = names_to_scheme[prefix]
-                    if can_quantize(value, quant_args):
-                        # convert weight to an int if not already compressed
-                        value = quantize(
-                            x=value,
-                            scale=scale,
-                            zero_point=zp,
-                            args=quant_args,
-                            dtype=torch.int8,
-                        )
-                    value = pack_to_int32(value.cpu(), quant_args.num_bits)
-                    compressed_dict[merge_names(prefix, "weight_shape")] = shape
-                    compressed_dict[merge_names(prefix, "weight_packed")] = value
-                    continue
-
-            elif name.endswith("zero_point"):
-                if torch.all(value == 0):
-                    # all zero_points are 0, no need to include in
-                    # compressed state_dict
-                    continue
-
-            compressed_dict[name] = value.to("cpu")
+        if can_quantize(weight, quantization_args):
+            quantized_weight = quantize(
+                x=weight,
+                scale=scale,
+                zero_point=zero_point,
+                g_idx=g_idx,
+                args=quantization_args,
+                dtype=torch.int8,
+            )
+
+        packed_weight = pack_to_int32(quantized_weight, quantization_args.num_bits)
+        weight_shape = torch.tensor(weight.shape)
+        if device is not None:
+            packed_weight = packed_weight.to(device)
+            weight_shape = weight_shape.to(device)
+
+        compressed_dict["weight_shape"] = weight_shape
+        compressed_dict["weight_packed"] = packed_weight
 
         return compressed_dict
 
-    def decompress(
+    def decompress_weight(
         self,
-        path_to_model_or_tensors: str,
-        names_to_scheme: Dict[str, QuantizationArgs],
-        device: str = "cpu",
-    ) -> Generator[Tuple[str, Tensor], None, None]:
+        compressed_data: Dict[str, Tensor],
+        quantization_args: Optional[QuantizationArgs] = None,
+    ) -> torch.Tensor:
         """
-        Reads a compressed state dict located at path_to_model_or_tensors
-        and returns a generator for sequentially decompressing back to a
-        dense state dict
-
-        :param model_path: path to compressed safetensors model (directory with
-            one or more safetensors files) or compressed tensors file
-        :param device: optional device to load intermediate weights into
-        :return: compressed state dict
+        Decompresses a single compressed weight
+
+        :param compressed_data: dictionary of data needed for decompression
+        :param quantization_args: quantization parameters for the weight
+        :return: tensor of the decompressed weight
         """
-        weight_mappings = get_nested_weight_mappings(
-            path_to_model_or_tensors, self.COMPRESSION_PARAM_NAMES
+        weight = compressed_data["weight_packed"]
+        scale = compressed_data["weight_scale"]
+        zero_point = compressed_data.get("weight_zero_point", None)
+        g_idx = compressed_data.get("weight_g_idx", None)
+        original_shape = torch.Size(compressed_data["weight_shape"])
+        num_bits = quantization_args.num_bits
+        unpacked = unpack_from_int32(weight, num_bits, original_shape)
+        decompressed_weight = dequantize(
+            x_q=unpacked, scale=scale, zero_point=zero_point, g_idx=g_idx
         )
-        for weight_name in weight_mappings.keys():
-            weight_data = {}
-            for param_name, safe_path in weight_mappings[weight_name].items():
-                weight_data["num_bits"] = names_to_scheme.get(weight_name).num_bits
-                full_name = merge_names(weight_name, param_name)
-                with safe_open(safe_path, framework="pt", device=device) as f:
-                    weight_data[param_name] = f.get_tensor(full_name)
-
-            if "weight_scale" in weight_data:
-                zero_point = weight_data.get("weight_zero_point", None)
-                scale = weight_data["weight_scale"]
-                weight = weight_data["weight_packed"]
-                num_bits = weight_data["num_bits"]
-                original_shape = torch.Size(weight_data["weight_shape"])
-                unpacked = unpack_from_int32(weight, num_bits, original_shape)
-                decompressed = dequantize(
-                    x_q=unpacked,
-                    scale=scale,
-                    zero_point=zero_point,
-                )
-                yield merge_names(weight_name, "weight"), decompressed
+
+        return decompressed_weight
 
 
 def pack_to_int32(value: torch.Tensor, num_bits: int) -> torch.Tensor:
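As a concrete reading of compression_param_info (numbers are illustrative): with num_bits=4 the pack factor is 32 // 4 = 8, so a (4096, 11008) weight maps to a weight_packed tensor of shape (4096, ceil(11008 / 8)) = (4096, 1376) with dtype torch.int32, plus a two-element int32 weight_shape tensor recording the original dimensions.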
@@ -197,13 +184,15 @@ def unpack_from_int32(
     if num_bits > 8:
         raise ValueError("Unpacking is only supported for less than 8 bits")
 
-    # convert packed input to unsigned numpy
-    value = value.numpy().view(np.uint32)
     pack_factor = 32 // num_bits
 
     # unpack
     mask = pow(2, num_bits) - 1
-    unpacked = np.zeros((value.shape[0], value.shape[1] * pack_factor))
+    unpacked = torch.zeros(
+        (value.shape[0], value.shape[1] * pack_factor),
+        device=value.device,
+        dtype=torch.int32,
+    )
     for i in range(pack_factor):
         unpacked[:, i::pack_factor] = (value >> (num_bits * i)) & mask
 
@@ -214,6 +203,6 @@ def unpack_from_int32(
     # bits are packed in unsigned format, reformat to signed
     # update the value range from unsigned to signed
     offset = pow(2, num_bits) // 2
-    unpacked = (unpacked.astype(np.int16) - offset).astype(np.int8)
+    unpacked = (unpacked - offset).to(torch.int8)
 
-    return torch.from_numpy(unpacked)
+    return unpacked
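Note: pack_to_int32 and unpack_from_int32 store 32 // num_bits quantized values per 32-bit word by offsetting the signed values into unsigned range and writing them to strided bit positions; the rewritten unpack path keeps the loop in torch so the data can stay on the tensor's device instead of round-tripping through numpy. A minimal, self-contained round trip of the same bit layout in NumPy (the helper names pack_rows/unpack_rows and the 4-bit toy data are illustrative only, not part of the package API):

import numpy as np

def pack_rows(values: np.ndarray, num_bits: int = 4) -> np.ndarray:
    # values: signed num_bits integers (int8) of shape (rows, cols)
    pack_factor = 32 // num_bits                 # e.g. 8 nibbles per 32-bit word
    offset = 1 << (num_bits - 1)                 # shift signed range to unsigned
    unsigned = (values.astype(np.int32) + offset).astype(np.uint32)
    rows, cols = unsigned.shape
    assert cols % pack_factor == 0, "pad columns to a multiple of pack_factor first"
    packed = np.zeros((rows, cols // pack_factor), dtype=np.uint32)
    for i in range(pack_factor):
        packed |= unsigned[:, i::pack_factor] << np.uint32(num_bits * i)
    return packed

def unpack_rows(packed: np.ndarray, num_bits: int, original_cols: int) -> np.ndarray:
    # inverse of pack_rows: mask out each bit field, then shift back to signed
    pack_factor = 32 // num_bits
    mask = (1 << num_bits) - 1
    offset = 1 << (num_bits - 1)
    unpacked = np.zeros((packed.shape[0], packed.shape[1] * pack_factor), dtype=np.int32)
    for i in range(pack_factor):
        unpacked[:, i::pack_factor] = (packed >> np.uint32(num_bits * i)) & mask
    return (unpacked[:, :original_cols] - offset).astype(np.int8)

q = np.random.randint(-8, 8, size=(2, 16), dtype=np.int8)  # fake 4-bit quantized weight
assert np.array_equal(unpack_rows(pack_rows(q, 4), 4, q.shape[1]), q)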
compressed_tensors/linear/__init__.py
@@ -0,0 +1,13 @@
+# Copyright (c) 2021 - present / Neuralmagic, Inc. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#    http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
compressed_tensors/linear/compressed_linear.py
@@ -0,0 +1,87 @@
+# Copyright (c) 2021 - present / Neuralmagic, Inc. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#    http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import torch
+from compressed_tensors.compressors.base import Compressor
+from compressed_tensors.quantization import (
+    QuantizationScheme,
+    QuantizationStatus,
+    initialize_module_for_quantization,
+)
+from torch import Tensor
+from torch.nn import Parameter
+from torch.nn.functional import linear
+from torch.nn.modules import Linear
+
+
+class CompressedLinear(Linear):
+    """
+    Wrapper module for running a compressed forward pass of a quantized Linear module.
+    The wrapped layer will be decompressed on each forward call.
+
+    :param module: dense linear module to replace
+    :param quantization_scheme: quantization config for the module to wrap
+    :param quantization_format: compression format module is stored as
+    """
+
+    @classmethod
+    @torch.no_grad()
+    def from_linear(
+        cls,
+        module: Linear,
+        quantization_scheme: QuantizationScheme,
+        quantization_format: str,
+    ):
+        module.__class__ = CompressedLinear
+        module.compressor = Compressor.load_from_registry(quantization_format)
+        device = next(module.parameters()).device
+
+        # this will initialize all the scales and zero points
+        initialize_module_for_quantization(
+            module, quantization_scheme, force_zero_point=False
+        )
+
+        # get the shape and dtype of compressed parameters
+        compression_params = module.compressor.compression_param_info(
+            module.weight.shape, quantization_scheme.weights
+        )
+
+        # no need for this once quantization is initialized, will be replaced
+        # with the compressed parameter
+        delattr(module, "weight")
+
+        # populate compressed weights and quantization parameters
+        for name, (shape, dtype) in compression_params.items():
+            param = Parameter(
+                torch.empty(shape, device=device, dtype=dtype), requires_grad=False
+            )
+            module.register_parameter(name, param)
+
+        # mark module as compressed
+        module.quantization_status = QuantizationStatus.COMPRESSED
+
+        # handles case where forward is wrapped in new_forward by accelerate hooks
+        if hasattr(module, "_old_forward"):
+            module._old_forward = CompressedLinear.forward.__get__(
+                module, CompressedLinear
+            )
+
+        return module
+
+    def forward(self, input: Tensor) -> Tensor:
+        """
+        Decompresses the weight, then runs the wrapped forward pass
+        """
+        uncompressed_weight = self.compressor.decompress_module(self)
+        return linear(input, uncompressed_weight, self.bias)
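For orientation, a rough usage sketch of the new CompressedLinear.from_linear conversion. The hand-built 4-bit scheme, the layer sizes, and the literal "pack-quantized" format string are illustrative assumptions; in practice the scheme and format come from the checkpoint's QuantizationConfig, and the conversion is driven by apply_quantization_config (see the apply.py changes below) rather than called by hand:

import torch
from compressed_tensors.linear.compressed_linear import CompressedLinear
from compressed_tensors.quantization import QuantizationArgs, QuantizationScheme

# illustrative 4-bit weight-only scheme targeting Linear layers
scheme = QuantizationScheme(
    targets=["Linear"],
    weights=QuantizationArgs(num_bits=4, symmetric=True),
)

layer = torch.nn.Linear(512, 256)
compressed = CompressedLinear.from_linear(
    layer,
    quantization_scheme=scheme,
    quantization_format="pack-quantized",
)

# the module now exposes packed/quantization parameters in place of `weight`
# and decompresses them on every forward call
print([name for name, _ in compressed.named_parameters()])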
compressed_tensors/quantization/lifecycle/apply.py
@@ -21,6 +21,7 @@ from typing import OrderedDict as OrderedDictType
 from typing import Union
 
 import torch
+from compressed_tensors.config import CompressionFormat
 from compressed_tensors.quantization.lifecycle.calibration import (
     set_module_for_calibration,
 )
@@ -43,7 +44,7 @@ from compressed_tensors.quantization.utils import (
     is_kv_cache_quant_scheme,
     iter_named_leaf_modules,
 )
-from compressed_tensors.utils.helpers import fix_fsdp_module_name
+from compressed_tensors.utils.helpers import fix_fsdp_module_name, replace_module
 from compressed_tensors.utils.offload import update_parameter_data
 from compressed_tensors.utils.safetensors_load import get_safetensors_folder
 from torch.nn import Module
@@ -104,12 +105,16 @@ def load_pretrained_quantization(model: Module, model_name_or_path: str):
         )
 
 
-def apply_quantization_config(model: Module, config: QuantizationConfig) -> Dict:
+def apply_quantization_config(
+    model: Module, config: QuantizationConfig, run_compressed: bool = False
+) -> Dict:
     """
     Initializes the model for quantization in-place based on the given config
 
     :param model: model to apply quantization config to
     :param config: quantization config
+    :param run_compressed: Whether the model will be run in compressed mode or
+        decompressed fully on load
     """
     # remove reference to the original `config`
     # argument. This function can mutate it, and we'd
@@ -124,6 +129,9 @@ def apply_quantization_config(model: Module, config: QuantizationConfig) -> Dict
         for target in scheme.targets:
             target_to_scheme[target] = scheme
 
+    if run_compressed:
+        from compressed_tensors.linear.compressed_linear import CompressedLinear
+
     # list of submodules to ignore
     ignored_submodules = defaultdict(list)
     # mark appropriate layers for quantization by setting their quantization schemes
@@ -136,10 +144,24 @@ def apply_quantization_config(model: Module, config: QuantizationConfig) -> Dict
             continue  # layer matches ignore list, continue
         targets = find_name_or_class_matches(name, submodule, target_to_scheme)
         if targets:
+            scheme = _scheme_from_targets(target_to_scheme, targets, name)
+            if run_compressed:
+                format = config.format
+                if format != CompressionFormat.dense.value:
+                    if isinstance(submodule, torch.nn.Linear):
+                        # TODO: expand to more module types
+                        compressed_linear = CompressedLinear.from_linear(
+                            submodule,
+                            quantization_scheme=scheme,
+                            quantization_format=format,
+                        )
+                        replace_module(model, name, compressed_linear)
+
             # target matched - add layer and scheme to target list
             submodule.quantization_scheme = _scheme_from_targets(
                 target_to_scheme, targets, name
             )
+
            names_to_scheme[name] = submodule.quantization_scheme.weights
 
     if config.ignore is not None and ignored_submodules is not None:
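Taken with the CompressedLinear module above, the new run_compressed flag keeps matching Linear layers packed in memory instead of decompressing the whole model at load time. A minimal sketch, assuming a toy two-layer model and a hand-written 4-bit weight-only config (a real config is read from the checkpoint's quantization metadata rather than built from literals like these):

import torch
from compressed_tensors.quantization import (
    QuantizationArgs,
    QuantizationConfig,
    QuantizationScheme,
    apply_quantization_config,
)

model = torch.nn.Sequential(torch.nn.Linear(64, 64), torch.nn.Linear(64, 16))

config = QuantizationConfig(
    config_groups={
        "group_0": QuantizationScheme(
            targets=["Linear"],
            weights=QuantizationArgs(num_bits=4, symmetric=True),
        )
    },
    format="pack-quantized",
    quantization_status="compressed",
)

# run_compressed=True swaps each matching torch.nn.Linear for CompressedLinear;
# run_compressed=False (the default) initializes ordinary quantization parameters
# for a fully decompressed load
names_to_scheme = apply_quantization_config(model, config, run_compressed=True)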
@@ -149,8 +171,8 @@ def apply_quantization_config(model: Module, config: QuantizationConfig) -> Dict
                 "not found in the model: "
                 f"{set(config.ignore) - set(ignored_submodules)}"
             )
-    # apply current quantization status across all targeted layers
 
+    # apply current quantization status across all targeted layers
     apply_quantization_status(model, config.quantization_status)
     return names_to_scheme
 
@@ -198,7 +220,12 @@ def apply_quantization_status(model: Module, status: QuantizationStatus):
     current_status = infer_quantization_status(model)
 
     if status >= QuantizationStatus.INITIALIZED > current_status:
-        model.apply(initialize_module_for_quantization)
+        force_zero_point_init = status != QuantizationStatus.COMPRESSED
+        model.apply(
+            lambda module: initialize_module_for_quantization(
+                module, force_zero_point=force_zero_point_init
+            )
+        )
 
     if current_status < status >= QuantizationStatus.CALIBRATION > current_status:
         # only quantize weights up front when our end goal state is calibration,
@@ -279,9 +306,11 @@ def _load_quant_args_from_state_dict(
     """
     scale_name = f"{base_name}_scale"
     zp_name = f"{base_name}_zero_point"
+    g_idx_name = f"{base_name}_g_idx"
 
     state_dict_scale = state_dict.get(f"{module_name}.{scale_name}", None)
     state_dict_zp = state_dict.get(f"{module_name}.{zp_name}", None)
+    state_dict_g_idx = state_dict.get(f"{module_name}.{g_idx_name}", None)
 
     if state_dict_scale is not None:
         # module is quantized
@@ -291,6 +320,9 @@ def _load_quant_args_from_state_dict(
             state_dict_zp = torch.zeros_like(state_dict_scale, device="cpu")
         update_parameter_data(module, state_dict_zp, zp_name)
 
+    if state_dict_g_idx is not None:
+        update_parameter_data(module, state_dict_g_idx, g_idx_name)
+
 
 def _scheme_from_targets(
compressed_tensors/quantization/lifecycle/calibration.py
@@ -44,7 +44,7 @@ def set_module_for_calibration(module: Module, quantize_weights_upfront: bool =
         return
     status = getattr(module, "quantization_status", None)
     if not status or status != QuantizationStatus.INITIALIZED:
-        raise _LOGGER.warning(
+        _LOGGER.warning(
            f"Attempting set module with status {status} to calibration mode. "
            f"but status is not {QuantizationStatus.INITIALIZED} - you may "
            "be calibrating an uninitialized module which may fail or attempting "
@@ -54,13 +54,14 @@ def set_module_for_calibration(module: Module, quantize_weights_upfront: bool =
     if quantize_weights_upfront and module.quantization_scheme.weights is not None:
         # set weight scale and zero_point up front, calibration data doesn't affect it
         observer = module.weight_observer
+        g_idx = getattr(module, "weight_g_idx", None)
 
         offloaded = False
         if is_module_offloaded(module):
             module._hf_hook.pre_forward(module)
             offloaded = True
 
-        scale, zero_point = observer(module.weight)
+        scale, zero_point = observer(module.weight, g_idx=g_idx)
         update_parameter_data(module, scale, "weight_scale")
         update_parameter_data(module, zero_point, "weight_zero_point")
 
compressed_tensors/quantization/lifecycle/compressed.py
@@ -50,7 +50,7 @@ def compress_quantized_weights(module: Module):
     scale = getattr(module, "weight_scale", None)
     zero_point = getattr(module, "weight_zero_point", None)
 
-    if weight is None or scale is None or zero_point is None:
+    if weight is None or scale is None:
         # no weight, scale, or ZP, nothing to do
 
         # mark as compressed here to maintain consistent status throughout the model