compressed-tensors 0.5.0__py3-none-any.whl → 0.6.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (33)
  1. compressed_tensors/compressors/base.py +200 -8
  2. compressed_tensors/compressors/model_compressor.py +68 -1
  3. compressed_tensors/compressors/naive_quantized.py +71 -75
  4. compressed_tensors/compressors/pack_quantized.py +83 -94
  5. compressed_tensors/config/base.py +6 -1
  6. compressed_tensors/linear/__init__.py +13 -0
  7. compressed_tensors/linear/compressed_linear.py +87 -0
  8. compressed_tensors/quantization/lifecycle/apply.py +46 -8
  9. compressed_tensors/quantization/lifecycle/calibration.py +5 -4
  10. compressed_tensors/quantization/lifecycle/compressed.py +3 -1
  11. compressed_tensors/quantization/lifecycle/forward.py +76 -43
  12. compressed_tensors/quantization/lifecycle/helpers.py +29 -2
  13. compressed_tensors/quantization/lifecycle/initialize.py +51 -16
  14. compressed_tensors/quantization/observers/__init__.py +1 -0
  15. compressed_tensors/quantization/observers/base.py +54 -14
  16. compressed_tensors/quantization/observers/min_max.py +8 -0
  17. compressed_tensors/quantization/observers/mse.py +162 -0
  18. compressed_tensors/quantization/quant_args.py +96 -24
  19. compressed_tensors/quantization/quant_scheme.py +7 -9
  20. compressed_tensors/quantization/utils/helpers.py +1 -1
  21. compressed_tensors/utils/__init__.py +1 -0
  22. compressed_tensors/utils/helpers.py +13 -0
  23. compressed_tensors/utils/offload.py +14 -2
  24. compressed_tensors/utils/permute.py +70 -0
  25. compressed_tensors/utils/safetensors_load.py +2 -0
  26. compressed_tensors/utils/semi_structured_conversions.py +1 -0
  27. compressed_tensors/version.py +1 -1
  28. {compressed_tensors-0.5.0.dist-info → compressed_tensors-0.6.0.dist-info}/METADATA +35 -23
  29. compressed_tensors-0.6.0.dist-info/RECORD +52 -0
  30. {compressed_tensors-0.5.0.dist-info → compressed_tensors-0.6.0.dist-info}/WHEEL +1 -1
  31. compressed_tensors-0.5.0.dist-info/RECORD +0 -48
  32. {compressed_tensors-0.5.0.dist-info → compressed_tensors-0.6.0.dist-info}/LICENSE +0 -0
  33. {compressed_tensors-0.5.0.dist-info → compressed_tensors-0.6.0.dist-info}/top_level.txt +0 -0
compressed_tensors/compressors/pack_quantized.py
@@ -11,10 +11,8 @@
  # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  # See the License for the specific language governing permissions and
  # limitations under the License.
-
- import logging
  import math
- from typing import Dict, Generator, Tuple
+ from typing import Dict, Optional, Tuple

  import numpy as np
  import torch
@@ -23,16 +21,11 @@ from compressed_tensors.config import CompressionFormat
  from compressed_tensors.quantization import QuantizationArgs
  from compressed_tensors.quantization.lifecycle.forward import dequantize, quantize
  from compressed_tensors.quantization.utils import can_quantize
- from compressed_tensors.utils import get_nested_weight_mappings, merge_names
- from safetensors import safe_open
  from torch import Tensor
- from tqdm import tqdm


  __all__ = ["PackedQuantizationCompressor", "pack_to_int32", "unpack_from_int32"]

- _LOGGER: logging.Logger = logging.getLogger(__name__)
-

  @Compressor.register(name=CompressionFormat.pack_quantized.value)
  class PackedQuantizationCompressor(Compressor):
@@ -44,102 +37,96 @@ class PackedQuantizationCompressor(Compressor):
  "weight_packed",
  "weight_scale",
  "weight_zero_point",
+ "weight_g_idx",
  "weight_shape",
  ]

- def compress(
+ def compression_param_info(
  self,
- model_state: Dict[str, Tensor],
- names_to_scheme: Dict[str, QuantizationArgs],
- **kwargs,
- ) -> Dict[str, Tensor]:
+ weight_shape: torch.Size,
+ quantization_args: Optional[QuantizationArgs] = None,
+ ) -> Dict[str, Tuple[torch.Size, torch.dtype]]:
  """
- Compresses a dense state dict
+ Creates a dictionary of expected shapes and dtypes for each compression
+ parameter used by the compressor

- :param model_state: state dict of uncompressed model
- :param names_to_scheme: quantization args for each quantized weight, needed for
- quantize function to calculate bit depth
- :return: compressed state dict
+ :param weight_shape: uncompressed weight shape
+ :param quantization_args: quantization parameters for the weight
+ :return: dictionary mapping compressed parameter names to shape and dtype
+ """
+ pack_factor = 32 // quantization_args.num_bits
+ packed_size = math.ceil(weight_shape[1] / pack_factor)
+ return {
+ "weight_packed": (torch.Size((weight_shape[0], packed_size)), torch.int32),
+ "weight_shape": (torch.Size((2,)), torch.int32),
+ }
+
+ def compress_weight(
+ self,
+ weight: Tensor,
+ scale: Tensor,
+ zero_point: Optional[Tensor] = None,
+ g_idx: Optional[torch.Tensor] = None,
+ quantization_args: Optional[QuantizationArgs] = None,
+ device: Optional[torch.device] = None,
+ ) -> Dict[str, torch.Tensor]:
+ """
+ Compresses a single uncompressed weight
+
+ :param weight: uncompressed weight tensor
+ :param scale: quantization scale for weight
+ :param zero_point: quantization zero point for weight
+ :param g_idx: optional mapping from column index to group index
+ :param quantization_args: quantization parameters for weight
+ :param device: optional device to move compressed output to
+ :return: dictionary of compressed weight data
  """
  compressed_dict = {}
- weight_suffix = ".weight"
- _LOGGER.debug(
- f"Compressing model with {len(model_state)} parameterized layers..."
- )
-
- for name, value in tqdm(model_state.items(), desc="Compressing model"):
- if name.endswith(weight_suffix):
- prefix = name[: -(len(weight_suffix))]
- scale = model_state.get(merge_names(prefix, "weight_scale"), None)
- zp = model_state.get(merge_names(prefix, "weight_zero_point"), None)
- shape = torch.tensor(value.shape)
- if scale is not None and zp is not None:
- # weight is quantized, compress it
- quant_args = names_to_scheme[prefix]
- if can_quantize(value, quant_args):
- # convert weight to an int if not already compressed
- value = quantize(
- x=value,
- scale=scale,
- zero_point=zp,
- args=quant_args,
- dtype=torch.int8,
- )
- value = pack_to_int32(value.cpu(), quant_args.num_bits)
- compressed_dict[merge_names(prefix, "weight_shape")] = shape
- compressed_dict[merge_names(prefix, "weight_packed")] = value
- continue
-
- elif name.endswith("zero_point"):
- if torch.all(value == 0):
- # all zero_points are 0, no need to include in
- # compressed state_dict
- continue
-
- compressed_dict[name] = value.to("cpu")
+ if can_quantize(weight, quantization_args):
+ quantized_weight = quantize(
+ x=weight,
+ scale=scale,
+ zero_point=zero_point,
+ g_idx=g_idx,
+ args=quantization_args,
+ dtype=torch.int8,
+ )
+
+ packed_weight = pack_to_int32(quantized_weight, quantization_args.num_bits)
+ weight_shape = torch.tensor(weight.shape)
+ if device is not None:
+ packed_weight = packed_weight.to(device)
+ weight_shape = weight_shape.to(device)
+
+ compressed_dict["weight_shape"] = weight_shape
+ compressed_dict["weight_packed"] = packed_weight

  return compressed_dict

- def decompress(
+ def decompress_weight(
  self,
- path_to_model_or_tensors: str,
- names_to_scheme: Dict[str, QuantizationArgs],
- device: str = "cpu",
- ) -> Generator[Tuple[str, Tensor], None, None]:
+ compressed_data: Dict[str, Tensor],
+ quantization_args: Optional[QuantizationArgs] = None,
+ ) -> torch.Tensor:
  """
- Reads a compressed state dict located at path_to_model_or_tensors
- and returns a generator for sequentially decompressing back to a
- dense state dict
-
- :param model_path: path to compressed safetensors model (directory with
- one or more safetensors files) or compressed tensors file
- :param device: optional device to load intermediate weights into
- :return: compressed state dict
+ Decompresses a single compressed weight
+
+ :param compressed_data: dictionary of data needed for decompression
+ :param quantization_args: quantization parameters for the weight
+ :return: tensor of the decompressed weight
  """
- weight_mappings = get_nested_weight_mappings(
- path_to_model_or_tensors, self.COMPRESSION_PARAM_NAMES
+ weight = compressed_data["weight_packed"]
+ scale = compressed_data["weight_scale"]
+ zero_point = compressed_data.get("weight_zero_point", None)
+ g_idx = compressed_data.get("weight_g_idx", None)
+ original_shape = torch.Size(compressed_data["weight_shape"])
+ num_bits = quantization_args.num_bits
+ unpacked = unpack_from_int32(weight, num_bits, original_shape)
+ decompressed_weight = dequantize(
+ x_q=unpacked, scale=scale, zero_point=zero_point, g_idx=g_idx
  )
- for weight_name in weight_mappings.keys():
- weight_data = {}
- for param_name, safe_path in weight_mappings[weight_name].items():
- weight_data["num_bits"] = names_to_scheme.get(weight_name).num_bits
- full_name = merge_names(weight_name, param_name)
- with safe_open(safe_path, framework="pt", device=device) as f:
- weight_data[param_name] = f.get_tensor(full_name)
-
- if "weight_scale" in weight_data:
- zero_point = weight_data.get("weight_zero_point", None)
- scale = weight_data["weight_scale"]
- weight = weight_data["weight_packed"]
- num_bits = weight_data["num_bits"]
- original_shape = torch.Size(weight_data["weight_shape"])
- unpacked = unpack_from_int32(weight, num_bits, original_shape)
- decompressed = dequantize(
- x_q=unpacked,
- scale=scale,
- zero_point=zero_point,
- )
- yield merge_names(weight_name, "weight"), decompressed
+
+ return decompressed_weight


  def pack_to_int32(value: torch.Tensor, num_bits: int) -> torch.Tensor:
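As a quick check on the shape arithmetic in compression_param_info above, here is a standalone sketch; the 4-bit, 4096x4096 numbers are illustrative assumptions, not values from the package.

import math
import torch

# Shape arithmetic from compression_param_info, reproduced standalone.
# Assumed example: a 4096x4096 weight quantized to 4 bits.
num_bits = 4
weight_shape = torch.Size((4096, 4096))

pack_factor = 32 // num_bits                            # 8 quantized values per int32
packed_cols = math.ceil(weight_shape[1] / pack_factor)  # 512 packed columns

print(pack_factor, packed_cols)
# weight_packed would be stored as an int32 tensor of shape (4096, 512);
# weight_shape as a 2-element int32 tensor holding the original dimensions.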
@@ -197,13 +184,15 @@ def unpack_from_int32(
  if num_bits > 8:
  raise ValueError("Unpacking is only supported for less than 8 bits")

- # convert packed input to unsigned numpy
- value = value.numpy().view(np.uint32)
  pack_factor = 32 // num_bits

  # unpack
  mask = pow(2, num_bits) - 1
- unpacked = np.zeros((value.shape[0], value.shape[1] * pack_factor))
+ unpacked = torch.zeros(
+ (value.shape[0], value.shape[1] * pack_factor),
+ device=value.device,
+ dtype=torch.int32,
+ )
  for i in range(pack_factor):
  unpacked[:, i::pack_factor] = (value >> (num_bits * i)) & mask

@@ -214,6 +203,6 @@
  # bits are packed in unsigned format, reformat to signed
  # update the value range from unsigned to signed
  offset = pow(2, num_bits) // 2
- unpacked = (unpacked.astype(np.int16) - offset).astype(np.int8)
+ unpacked = (unpacked - offset).to(torch.int8)

- return torch.from_numpy(unpacked)
+ return unpacked
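The unpack path above is now pure torch rather than a numpy round trip. For readers who want the packing layout in isolation, the following is a minimal standalone sketch of the same shift-and-mask scheme; it mirrors the mask/shift loop shown above but is not the package's exact pack_to_int32/unpack_from_int32 code, and the helper names are made up.

import torch

def pack_rows(values: torch.Tensor, num_bits: int) -> torch.Tensor:
    # offset signed values to unsigned, e.g. int4 range [-8, 7] -> [0, 15]
    offset = 1 << (num_bits - 1)
    unsigned = values.to(torch.int32) + offset
    pack_factor = 32 // num_bits
    rows, cols = unsigned.shape
    packed = torch.zeros((rows, cols // pack_factor), dtype=torch.int32)
    for i in range(pack_factor):
        # place every pack_factor-th column into bit slot i of each int32 word
        packed |= unsigned[:, i::pack_factor] << (num_bits * i)
    return packed

def unpack_rows(packed: torch.Tensor, num_bits: int) -> torch.Tensor:
    pack_factor = 32 // num_bits
    mask = (1 << num_bits) - 1
    unpacked = torch.zeros(
        (packed.shape[0], packed.shape[1] * pack_factor), dtype=torch.int32
    )
    for i in range(pack_factor):
        unpacked[:, i::pack_factor] = (packed >> (num_bits * i)) & mask
    # shift back from unsigned to signed
    return (unpacked - (1 << (num_bits - 1))).to(torch.int8)

q = torch.randint(-8, 8, (2, 16), dtype=torch.int8)  # fake 4-bit quantized weights
assert torch.equal(unpack_rows(pack_rows(q, 4), 4), q)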
compressed_tensors/config/base.py
@@ -13,7 +13,7 @@
  # limitations under the License.

  from enum import Enum
- from typing import Optional
+ from typing import List, Optional

  from compressed_tensors.registry import RegistryMixin
  from pydantic import BaseModel
@@ -37,11 +37,16 @@ class SparsityCompressionConfig(RegistryMixin, BaseModel):
  Base data class for storing sparsity compression parameters

  :param format: name of compression format
+ :param targets: List of layer names or layer types that aren't sparse and should
+ be ignored during compression. By default, assume all layers are targeted
+ :param ignore: List of layer names (unique) to ignore from targets. Defaults to None
  :param global_sparsity: average sparsity of the entire model
  :param sparsity_structure: structure of the sparsity, such as
  "unstructured", "2:4", "8:16" etc
  """

  format: str
+ targets: Optional[List[str]] = None
+ ignore: Optional[List[str]] = None
  global_sparsity: Optional[float] = 0.0
  sparsity_structure: Optional[str] = "unstructured"
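The two new fields slot into the existing pydantic model. A hedged instantiation sketch follows; only the field names come from this diff, while the concrete format, layer names, and sparsity values are illustrative assumptions.

from compressed_tensors.config import SparsityCompressionConfig

# Illustrative values only; targets/ignore are the fields added above.
sparsity_config = SparsityCompressionConfig(
    format="sparse-bitmask",
    targets=["Linear"],   # layer names/types considered during compression
    ignore=["lm_head"],   # layers excluded from the targets above
    global_sparsity=0.5,
    sparsity_structure="2:4",
)
print(sparsity_config.targets, sparsity_config.ignore)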
compressed_tensors/linear/__init__.py
@@ -0,0 +1,13 @@
+ # Copyright (c) 2021 - present / Neuralmagic, Inc. All Rights Reserved.
+ #
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ # http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing,
+ # software distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
compressed_tensors/linear/compressed_linear.py
@@ -0,0 +1,87 @@
+ # Copyright (c) 2021 - present / Neuralmagic, Inc. All Rights Reserved.
+ #
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ # http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing,
+ # software distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+
+ import torch
+ from compressed_tensors.compressors.base import Compressor
+ from compressed_tensors.quantization import (
+ QuantizationScheme,
+ QuantizationStatus,
+ initialize_module_for_quantization,
+ )
+ from torch import Tensor
+ from torch.nn import Parameter
+ from torch.nn.functional import linear
+ from torch.nn.modules import Linear
+
+
+ class CompressedLinear(Linear):
+ """
+ Wrapper module for running a compressed forward pass of a quantized Linear module.
+ The wrapped layer will decompressed on each forward call.
+
+ :param module: dense linear module to replace
+ :param quantization_scheme: quantization config for the module to wrap
+ :param quantization_format: compression format module is stored as
+ """
+
+ @classmethod
+ @torch.no_grad()
+ def from_linear(
+ cls,
+ module: Linear,
+ quantization_scheme: QuantizationScheme,
+ quantization_format: str,
+ ):
+ module.__class__ = CompressedLinear
+ module.compressor = Compressor.load_from_registry(quantization_format)
+ device = next(module.parameters()).device
+
+ # this will initialize all the scales and zero points
+ initialize_module_for_quantization(
+ module, quantization_scheme, force_zero_point=False
+ )
+
+ # get the shape and dtype of compressed parameters
+ compression_params = module.compressor.compression_param_info(
+ module.weight.shape, quantization_scheme.weights
+ )
+
+ # no need for this once quantization is initialized, will be replaced
+ # with the compressed parameter
+ delattr(module, "weight")
+
+ # populate compressed weights and quantization parameters
+ for name, (shape, dtype) in compression_params.items():
+ param = Parameter(
+ torch.empty(shape, device=device, dtype=dtype), requires_grad=False
+ )
+ module.register_parameter(name, param)
+
+ # mark module as compressed
+ module.quantization_status = QuantizationStatus.COMPRESSED
+
+ # handles case where forward is wrapped in new_forward by accelerate hooks
+ if hasattr(module, "_old_forward"):
+ module._old_forward = CompressedLinear.forward.__get__(
+ module, CompressedLinear
+ )
+
+ return module
+
+ def forward(self, input: Tensor) -> Tensor:
+ """
+ Decompresses the weight, then runs the wrapped forward pass
+ """
+ uncompressed_weight = self.compressor.decompress_module(self)
+ return linear(input, uncompressed_weight, self.bias)
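A rough usage sketch for the new module follows; the scheme, layer sizes, and "pack-quantized" format string are illustrative assumptions (pack-quantized is the CompressionFormat value referenced elsewhere in this diff), not an example shipped with the package.

import torch
from compressed_tensors.linear.compressed_linear import CompressedLinear
from compressed_tensors.quantization import QuantizationArgs, QuantizationScheme

# Assumed example scheme: 4-bit symmetric group quantization of the weight.
scheme = QuantizationScheme(
    targets=["Linear"],
    weights=QuantizationArgs(num_bits=4, symmetric=True, strategy="group", group_size=128),
)

layer = torch.nn.Linear(4096, 4096, bias=False)
layer = CompressedLinear.from_linear(
    layer,
    quantization_scheme=scheme,
    quantization_format="pack-quantized",
)
# The dense weight is gone; compressed parameters (weight_packed, weight_shape,
# weight_scale, ...) are registered instead and decompressed on each forward call.
print(sorted(name for name, _ in layer.named_parameters()))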
compressed_tensors/quantization/lifecycle/apply.py
@@ -14,12 +14,14 @@

  import logging
  import re
- from collections import OrderedDict
+ from collections import OrderedDict, defaultdict
+ from copy import deepcopy
  from typing import Dict, Iterable, List, Optional
  from typing import OrderedDict as OrderedDictType
  from typing import Union

  import torch
+ from compressed_tensors.config import CompressionFormat
  from compressed_tensors.quantization.lifecycle.calibration import (
  set_module_for_calibration,
  )
@@ -42,7 +44,7 @@ from compressed_tensors.quantization.utils import (
  is_kv_cache_quant_scheme,
  iter_named_leaf_modules,
  )
- from compressed_tensors.utils.helpers import fix_fsdp_module_name
+ from compressed_tensors.utils.helpers import fix_fsdp_module_name, replace_module
  from compressed_tensors.utils.offload import update_parameter_data
  from compressed_tensors.utils.safetensors_load import get_safetensors_folder
  from torch.nn import Module
@@ -103,13 +105,21 @@ def load_pretrained_quantization(model: Module, model_name_or_path: str):
  )


- def apply_quantization_config(model: Module, config: QuantizationConfig) -> Dict:
+ def apply_quantization_config(
+ model: Module, config: QuantizationConfig, run_compressed: bool = False
+ ) -> Dict:
  """
  Initializes the model for quantization in-place based on the given config

  :param model: model to apply quantization config to
  :param config: quantization config
+ :param run_compressed: Whether the model will be run in compressed mode or
+ decompressed fully on load
  """
+ # remove reference to the original `config`
+ # argument. This function can mutate it, and we'd
+ # like to keep the original `config` as it is.
+ config = deepcopy(config)
  # build mapping of targets to schemes for easier matching
  # use ordered dict to preserve target ordering in config
  target_to_scheme = OrderedDict()
@@ -119,21 +129,39 @@ def apply_quantization_config(model: Module, config: QuantizationConfig) -> Dict
  for target in scheme.targets:
  target_to_scheme[target] = scheme

+ if run_compressed:
+ from compressed_tensors.linear.compressed_linear import CompressedLinear
+
  # list of submodules to ignore
- ignored_submodules = []
+ ignored_submodules = defaultdict(list)
  # mark appropriate layers for quantization by setting their quantization schemes
  for name, submodule in iter_named_leaf_modules(model):
  # potentially fix module name to remove FSDP wrapper prefix
  name = fix_fsdp_module_name(name)
- if find_name_or_class_matches(name, submodule, config.ignore):
- ignored_submodules.append(name)
+ if matches := find_name_or_class_matches(name, submodule, config.ignore):
+ for match in matches:
+ ignored_submodules[match].append(name)
  continue # layer matches ignore list, continue
  targets = find_name_or_class_matches(name, submodule, target_to_scheme)
  if targets:
+ scheme = _scheme_from_targets(target_to_scheme, targets, name)
+ if run_compressed:
+ format = config.format
+ if format != CompressionFormat.dense.value:
+ if isinstance(submodule, torch.nn.Linear):
+ # TODO: expand to more module types
+ compressed_linear = CompressedLinear.from_linear(
+ submodule,
+ quantization_scheme=scheme,
+ quantization_format=format,
+ )
+ replace_module(model, name, compressed_linear)
+
  # target matched - add layer and scheme to target list
  submodule.quantization_scheme = _scheme_from_targets(
  target_to_scheme, targets, name
  )
+
  names_to_scheme[name] = submodule.quantization_scheme.weights

  if config.ignore is not None and ignored_submodules is not None:
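A hedged end-to-end sketch of the new run_compressed path follows; only the apply_quantization_config signature and the Linear-to-CompressedLinear swap come from this diff, while the tiny model, scheme, and config values are assumptions for illustration.

import torch
from compressed_tensors.quantization import (
    QuantizationArgs,
    QuantizationConfig,
    QuantizationScheme,
    QuantizationStatus,
    apply_quantization_config,
)

# Assumed toy model and config: a non-dense format plus run_compressed=True
# should leave matching Linear layers as CompressedLinear instead of
# decompressing them on load.
model = torch.nn.Sequential(torch.nn.Linear(256, 256, bias=False))
config = QuantizationConfig(
    config_groups={
        "group_0": QuantizationScheme(
            targets=["Linear"],
            weights=QuantizationArgs(num_bits=4, symmetric=True, strategy="channel"),
        )
    },
    format="pack-quantized",
    quantization_status=QuantizationStatus.COMPRESSED,
)
names_to_scheme = apply_quantization_config(model, config, run_compressed=True)
print(type(model[0]).__name__)  # expected: CompressedLinear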
@@ -143,8 +171,8 @@ def apply_quantization_config(model: Module, config: QuantizationConfig) -> Dict
  "not found in the model: "
  f"{set(config.ignore) - set(ignored_submodules)}"
  )
- # apply current quantization status across all targeted layers

+ # apply current quantization status across all targeted layers
  apply_quantization_status(model, config.quantization_status)
  return names_to_scheme

@@ -192,7 +220,12 @@ def apply_quantization_status(model: Module, status: QuantizationStatus):
  current_status = infer_quantization_status(model)

  if status >= QuantizationStatus.INITIALIZED > current_status:
- model.apply(initialize_module_for_quantization)
+ force_zero_point_init = status != QuantizationStatus.COMPRESSED
+ model.apply(
+ lambda module: initialize_module_for_quantization(
+ module, force_zero_point=force_zero_point_init
+ )
+ )

  if current_status < status >= QuantizationStatus.CALIBRATION > current_status:
  # only quantize weights up front when our end goal state is calibration,
@@ -273,9 +306,11 @@ def _load_quant_args_from_state_dict(
  """
  scale_name = f"{base_name}_scale"
  zp_name = f"{base_name}_zero_point"
+ g_idx_name = f"{base_name}_g_idx"

  state_dict_scale = state_dict.get(f"{module_name}.{scale_name}", None)
  state_dict_zp = state_dict.get(f"{module_name}.{zp_name}", None)
+ state_dict_g_idx = state_dict.get(f"{module_name}.{g_idx_name}", None)

  if state_dict_scale is not None:
  # module is quantized
@@ -285,6 +320,9 @@
  state_dict_zp = torch.zeros_like(state_dict_scale, device="cpu")
  update_parameter_data(module, state_dict_zp, zp_name)

+ if state_dict_g_idx is not None:
+ update_parameter_data(module, state_dict_g_idx, g_idx_name)
+

  def _scheme_from_targets(
  target_to_scheme: OrderedDictType[str, QuantizationScheme],
compressed_tensors/quantization/lifecycle/calibration.py
@@ -36,15 +36,15 @@ def set_module_for_calibration(module: Module, quantize_weights_upfront: bool =
  apply to full model with `model.apply(set_module_for_calibration)`

  :param module: module to set for calibration
- :param quantize_weights_upfront: whether to automatically run weight quantization at the
- start of calibration
+ :param quantize_weights_upfront: whether to automatically
+ run weight quantization at the start of calibration
  """
  if not getattr(module, "quantization_scheme", None):
  # no quantization scheme nothing to do
  return
  status = getattr(module, "quantization_status", None)
  if not status or status != QuantizationStatus.INITIALIZED:
- raise _LOGGER.warning(
+ _LOGGER.warning(
  f"Attempting set module with status {status} to calibration mode. "
  f"but status is not {QuantizationStatus.INITIALIZED} - you may "
  "be calibrating an uninitialized module which may fail or attempting "
@@ -54,13 +54,14 @@ def set_module_for_calibration(module: Module, quantize_weights_upfront: bool =
  if quantize_weights_upfront and module.quantization_scheme.weights is not None:
  # set weight scale and zero_point up front, calibration data doesn't affect it
  observer = module.weight_observer
+ g_idx = getattr(module, "weight_g_idx", None)

  offloaded = False
  if is_module_offloaded(module):
  module._hf_hook.pre_forward(module)
  offloaded = True

- scale, zero_point = observer(module.weight)
+ scale, zero_point = observer(module.weight, g_idx=g_idx)
  update_parameter_data(module, scale, "weight_scale")
  update_parameter_data(module, zero_point, "weight_zero_point")
compressed_tensors/quantization/lifecycle/compressed.py
@@ -49,8 +49,9 @@ def compress_quantized_weights(module: Module):
  weight = getattr(module, "weight", None)
  scale = getattr(module, "weight_scale", None)
  zero_point = getattr(module, "weight_zero_point", None)
+ g_idx = getattr(module, "weight_g_idx", None)

- if weight is None or scale is None or zero_point is None:
+ if weight is None or scale is None:
  # no weight, scale, or ZP, nothing to do

  # mark as compressed here to maintain consistent status throughout the model
@@ -62,6 +63,7 @@
  x=weight,
  scale=scale,
  zero_point=zero_point,
+ g_idx=g_idx,
  args=scheme.weights,
  dtype=torch.int8,
  )
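Across these hunks, g_idx (stored as weight_g_idx) is described as an optional mapping from column index to group index, so each weight column picks up the scale of its assigned quantization group. A minimal sketch of that indexing for the symmetric case; shapes, values, and helper names are illustrative, not the package's quantize/dequantize implementation.

import torch

# Illustrative shapes: 8 output channels, 16 input columns, groups of 4.
out_features, in_features, group_size = 8, 16, 4
num_groups = in_features // group_size

w_q = torch.randint(-8, 8, (out_features, in_features), dtype=torch.int8)
scale = torch.rand(out_features, num_groups) * 0.1

# column j belongs to group g_idx[j]; a permuted assignment mimics
# activation-ordered columns
g_idx = torch.randperm(in_features) // group_size

# symmetric dequantization: each column is scaled by its group's scale
w_dq = w_q.to(torch.float32) * scale[:, g_idx]
print(w_dq.shape)  # torch.Size([8, 16])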