compressed-tensors 0.4.0__py3-none-any.whl → 0.5.0__py3-none-any.whl

This diff compares the contents of two publicly released versions of the package, as published to their public registry, and is provided for informational purposes only.
Files changed (36)
  1. compressed_tensors/base.py +1 -0
  2. compressed_tensors/compressors/__init__.py +5 -1
  3. compressed_tensors/compressors/base.py +1 -1
  4. compressed_tensors/compressors/dense.py +1 -1
  5. compressed_tensors/compressors/marlin_24.py +11 -10
  6. compressed_tensors/compressors/model_compressor.py +33 -12
  7. compressed_tensors/compressors/{int_quantized.py → naive_quantized.py} +33 -15
  8. compressed_tensors/compressors/pack_quantized.py +58 -51
  9. compressed_tensors/compressors/sparse_bitmask.py +1 -1
  10. compressed_tensors/config/base.py +2 -0
  11. compressed_tensors/quantization/lifecycle/__init__.py +1 -0
  12. compressed_tensors/quantization/lifecycle/apply.py +161 -39
  13. compressed_tensors/quantization/lifecycle/calibration.py +20 -1
  14. compressed_tensors/quantization/lifecycle/forward.py +70 -25
  15. compressed_tensors/quantization/lifecycle/helpers.py +53 -0
  16. compressed_tensors/quantization/lifecycle/initialize.py +30 -1
  17. compressed_tensors/quantization/observers/base.py +39 -0
  18. compressed_tensors/quantization/observers/helpers.py +64 -11
  19. compressed_tensors/quantization/quant_args.py +45 -1
  20. compressed_tensors/quantization/quant_config.py +35 -2
  21. compressed_tensors/quantization/quant_scheme.py +105 -4
  22. compressed_tensors/quantization/utils/helpers.py +67 -1
  23. compressed_tensors/utils/__init__.py +4 -0
  24. compressed_tensors/utils/helpers.py +31 -2
  25. compressed_tensors/utils/offload.py +104 -0
  26. compressed_tensors/version.py +1 -1
  27. {compressed_tensors-0.4.0.dist-info → compressed_tensors-0.5.0.dist-info}/METADATA +2 -1
  28. compressed_tensors-0.5.0.dist-info/RECORD +48 -0
  29. {compressed_tensors-0.4.0.dist-info → compressed_tensors-0.5.0.dist-info}/WHEEL +1 -1
  30. compressed_tensors/compressors/utils/__init__.py +0 -19
  31. compressed_tensors/compressors/utils/helpers.py +0 -43
  32. compressed_tensors-0.4.0.dist-info/RECORD +0 -48
  33. /compressed_tensors/{compressors/utils → utils}/permutations_24.py +0 -0
  34. /compressed_tensors/{compressors/utils → utils}/semi_structured_conversions.py +0 -0
  35. {compressed_tensors-0.4.0.dist-info → compressed_tensors-0.5.0.dist-info}/LICENSE +0 -0
  36. {compressed_tensors-0.4.0.dist-info → compressed_tensors-0.5.0.dist-info}/top_level.txt +0 -0

compressed_tensors/quantization/lifecycle/apply.py
@@ -15,7 +15,9 @@
  import logging
  import re
  from collections import OrderedDict
- from typing import Dict, Iterable, Optional
+ from typing import Dict, Iterable, List, Optional
+ from typing import OrderedDict as OrderedDictType
+ from typing import Union

  import torch
  from compressed_tensors.quantization.lifecycle.calibration import (
@@ -28,15 +30,20 @@ from compressed_tensors.quantization.lifecycle.frozen import freeze_module_quant
  from compressed_tensors.quantization.lifecycle.initialize import (
      initialize_module_for_quantization,
  )
+ from compressed_tensors.quantization.quant_args import QuantizationArgs
  from compressed_tensors.quantization.quant_config import (
      QuantizationConfig,
      QuantizationStatus,
  )
+ from compressed_tensors.quantization.quant_scheme import QuantizationScheme
  from compressed_tensors.quantization.utils import (
+     KV_CACHE_TARGETS,
      infer_quantization_status,
+     is_kv_cache_quant_scheme,
      iter_named_leaf_modules,
  )
  from compressed_tensors.utils.helpers import fix_fsdp_module_name
+ from compressed_tensors.utils.offload import update_parameter_data
  from compressed_tensors.utils.safetensors_load import get_safetensors_folder
  from torch.nn import Module

@@ -45,7 +52,7 @@ __all__ = [
      "load_pretrained_quantization",
      "apply_quantization_config",
      "apply_quantization_status",
-     "find_first_name_or_class_match",
+     "find_name_or_class_matches",
  ]

  from compressed_tensors.quantization.utils.helpers import is_module_quantized
@@ -96,7 +103,7 @@ def load_pretrained_quantization(model: Module, model_name_or_path: str):
              )


- def apply_quantization_config(model: Module, config: QuantizationConfig):
+ def apply_quantization_config(model: Module, config: QuantizationConfig) -> Dict:
      """
      Initializes the model for quantization in-place based on the given config

@@ -106,6 +113,8 @@ def apply_quantization_config(model: Module, config: QuantizationConfig):
      # build mapping of targets to schemes for easier matching
      # use ordered dict to preserve target ordering in config
      target_to_scheme = OrderedDict()
+     config = process_quantization_config(config)
+     names_to_scheme = OrderedDict()
      for scheme in config.config_groups.values():
          for target in scheme.targets:
              target_to_scheme[target] = scheme
@@ -116,13 +125,16 @@ def apply_quantization_config(model: Module, config: QuantizationConfig):
      for name, submodule in iter_named_leaf_modules(model):
          # potentially fix module name to remove FSDP wrapper prefix
          name = fix_fsdp_module_name(name)
-         if find_first_name_or_class_match(name, submodule, config.ignore):
+         if find_name_or_class_matches(name, submodule, config.ignore):
              ignored_submodules.append(name)
              continue  # layer matches ignore list, continue
-         target = find_first_name_or_class_match(name, submodule, target_to_scheme)
-         if target is not None:
+         targets = find_name_or_class_matches(name, submodule, target_to_scheme)
+         if targets:
              # target matched - add layer and scheme to target list
-             submodule.quantization_scheme = target_to_scheme[target]
+             submodule.quantization_scheme = _scheme_from_targets(
+                 target_to_scheme, targets, name
+             )
+             names_to_scheme[name] = submodule.quantization_scheme.weights

      if config.ignore is not None and ignored_submodules is not None:
          if set(config.ignore) - set(ignored_submodules):
@@ -132,7 +144,42 @@ def apply_quantization_config(model: Module, config: QuantizationConfig):
                  f"{set(config.ignore) - set(ignored_submodules)}"
              )
      # apply current quantization status across all targeted layers
+
      apply_quantization_status(model, config.quantization_status)
+     return names_to_scheme
+
+
+ def process_quantization_config(config: QuantizationConfig) -> QuantizationConfig:
+     """
+     Preprocess the raw QuantizationConfig
+
+     :param config: the raw QuantizationConfig
+     :return: the processed QuantizationConfig
+     """
+     if config.kv_cache_scheme is not None:
+         config = process_kv_cache_config(config)
+
+     return config
+
+
+ def process_kv_cache_config(
+     config: QuantizationConfig, targets: Union[List[str], str] = KV_CACHE_TARGETS
+ ) -> QuantizationConfig:
+     """
+     Reformulate the `config.kv_cache` as a `config_group`
+     and add it to the set of existing `config.groups`
+
+     :param config: the QuantizationConfig
+     :return: the QuantizationConfig with additional "kv_cache" group
+     """
+     kv_cache_dict = config.kv_cache_scheme.model_dump()
+     kv_cache_scheme = QuantizationScheme(
+         output_activations=QuantizationArgs(**kv_cache_dict),
+         targets=targets,
+     )
+     kv_cache_group = dict(kv_cache=kv_cache_scheme)
+     config.config_groups.update(kv_cache_group)
+     return config


  def apply_quantization_status(model: Module, status: QuantizationStatus):
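The new `process_quantization_config` / `process_kv_cache_config` pair folds a top-level `kv_cache_scheme` into `config_groups` as an output-activation scheme targeting `KV_CACHE_TARGETS`. A minimal sketch of that preprocessing step (illustrative only, not taken from the package; it assumes nothing beyond the fields shown in the hunk above):

```python
# Illustrative sketch only: fold a kv_cache_scheme into config_groups via the
# new preprocessing helpers shown in the hunk above.
from compressed_tensors.quantization.lifecycle.apply import process_quantization_config
from compressed_tensors.quantization.quant_args import QuantizationArgs
from compressed_tensors.quantization.quant_config import QuantizationConfig
from compressed_tensors.quantization.quant_scheme import QuantizationScheme

config = QuantizationConfig(
    config_groups={
        "group_0": QuantizationScheme(
            targets=["Linear"],
            weights=QuantizationArgs(num_bits=8, symmetric=True),
        )
    },
    kv_cache_scheme=QuantizationArgs(num_bits=8, symmetric=True),
)

config = process_quantization_config(config)
# The kv cache args now live in a regular "kv_cache" group whose
# output_activations carry the args and whose targets are KV_CACHE_TARGETS.
print(sorted(config.config_groups))  # expected: ['group_0', 'kv_cache']
```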
@@ -148,7 +195,14 @@ def apply_quantization_status(model: Module, status: QuantizationStatus):
          model.apply(initialize_module_for_quantization)

      if current_status < status >= QuantizationStatus.CALIBRATION > current_status:
-         model.apply(set_module_for_calibration)
+         # only quantize weights up front when our end goal state is calibration,
+         # weight quantization parameters are already loaded for frozen/compressed
+         quantize_weights_upfront = status == QuantizationStatus.CALIBRATION
+         model.apply(
+             lambda module: set_module_for_calibration(
+                 module, quantize_weights_upfront=quantize_weights_upfront
+             )
+         )
      if current_status < status >= QuantizationStatus.FROZEN > current_status:
          model.apply(freeze_module_quantization)

@@ -156,36 +210,45 @@ def apply_quantization_status(model: Module, status: QuantizationStatus):
          model.apply(compress_quantized_weights)


- def find_first_name_or_class_match(
+ def find_name_or_class_matches(
      name: str, module: Module, targets: Iterable[str], check_contains: bool = False
- ) -> Optional[str]:
-     # first element of targets that matches the given name
-     # if no name matches returns first target that matches the class name
-     # returns None otherwise
+ ) -> List[str]:
+     """
+     Returns all targets that match the given name or the class name.
+     Returns empty list otherwise.
+     The order of the output `matches` list matters.
+     The entries are sorted in the following order:
+     1. matches on exact strings
+     2. matches on regex patterns
+     3. matches on module names
+     """
+     targets = sorted(targets, key=lambda x: ("re:" in x, x))
      if isinstance(targets, Iterable):
-         return _find_first_match(name, targets) or _find_first_match(
+         matches = _find_matches(name, targets) + _find_matches(
              module.__class__.__name__, targets, check_contains
          )
+         matches = [match for match in matches if match is not None]
+         return matches


- def _find_first_match(
+ def _find_matches(
      value: str, targets: Iterable[str], check_contains: bool = False
- ) -> Optional[str]:
-     # returns first element of target that matches value either
+ ) -> List[str]:
+     # returns all the targets that match value either
      # exactly or as a regex after 're:'. if check_contains is set to True,
      # additionally checks if the target string is contained with value.
-
+     matches = []
      for target in targets:
          if target.startswith("re:"):
              pattern = target[3:]
              if re.match(pattern, value):
-                 return target
+                 matches.append(target)
          elif check_contains:
              if target.lower() in value.lower():
-                 return target
+                 matches.append(target)
          elif target == value:
-             return target
-     return None
+             matches.append(target)
+     return matches


  def _infer_status(model: Module) -> Optional[QuantizationStatus]:
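The matcher now returns every hit instead of only the first one, ordered exact strings, then `re:` patterns, then class-name matches. A quick illustrative call (the layer and target names below are invented):

```python
# Illustrative sketch only: the 0.5.0 matcher returns all hits, ordered
# exact string -> regex pattern -> class name (names below are invented).
import torch
from compressed_tensors.quantization.lifecycle.apply import find_name_or_class_matches

layer = torch.nn.Linear(16, 16)
matches = find_name_or_class_matches(
    name="model.layers.0.self_attn.q_proj",
    module=layer,
    targets=["model.layers.0.self_attn.q_proj", "re:.*q_proj", "Linear"],
)
print(matches)
# expected: ['model.layers.0.self_attn.q_proj', 're:.*q_proj', 'Linear']
```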
@@ -210,20 +273,79 @@ def _load_quant_args_from_state_dict(
      """
      scale_name = f"{base_name}_scale"
      zp_name = f"{base_name}_zero_point"
-     device = next(module.parameters()).device
-
-     scale = getattr(module, scale_name, None)
-     zp = getattr(module, zp_name, None)
-     if scale is not None:
-         state_dict_scale = state_dict.get(f"{module_name}.{scale_name}")
-         if state_dict_scale is not None:
-             scale.data = state_dict_scale.to(device).to(scale.dtype)
-         else:
-             scale.data = scale.data.to(device)
-
-     if zp is not None:
-         zp_from_state = state_dict.get(f"{module_name}.{zp_name}", None)
-         if zp_from_state is not None:  # load the non-zero zero points
-             zp.data = state_dict[f"{module_name}.{zp_name}"].to(device)
-         else:  # fill with zeros matching scale shape
-             zp.data = torch.zeros_like(scale, dtype=torch.int8).to(device)
+
+     state_dict_scale = state_dict.get(f"{module_name}.{scale_name}", None)
+     state_dict_zp = state_dict.get(f"{module_name}.{zp_name}", None)
+
+     if state_dict_scale is not None:
+         # module is quantized
+         update_parameter_data(module, state_dict_scale, scale_name)
+         if state_dict_zp is None:
+             # fill in zero point for symmetric quantization
+             state_dict_zp = torch.zeros_like(state_dict_scale, device="cpu")
+         update_parameter_data(module, state_dict_zp, zp_name)
+
+
+ def _scheme_from_targets(
+     target_to_scheme: OrderedDictType[str, QuantizationScheme],
+     targets: List[str],
+     name: str,
+ ) -> QuantizationScheme:
+     if len(targets) == 1:
+         # if `targets` iterable contains a single element
+         # use it as the key
+         return target_to_scheme[targets[0]]
+
+     # otherwise, we need to merge QuantizationSchemes corresponding
+     # to multiple targets. This is most likely because `name` module
+     # is being target both as an ordinary quantization target, as well
+     # as kv cache quantization target
+     schemes_to_merge = [target_to_scheme[target] for target in targets]
+     return _merge_schemes(schemes_to_merge, name)
+
+
+ def _merge_schemes(
+     schemes_to_merge: List[QuantizationScheme], name: str
+ ) -> QuantizationScheme:
+
+     kv_cache_quantization_scheme = [
+         scheme for scheme in schemes_to_merge if is_kv_cache_quant_scheme(scheme)
+     ]
+     if not kv_cache_quantization_scheme:
+         # if the schemes_to_merge do not contain any
+         # kv cache QuantizationScheme
+         # return the first scheme (the prioritized one,
+         # since the order of schemes_to_merge matters)
+         return schemes_to_merge[0]
+     else:
+         # fetch the kv cache QuantizationScheme and the highest
+         # priority non-kv cache QuantizationScheme and merge them
+         kv_cache_quantization_scheme = kv_cache_quantization_scheme[0]
+         quantization_scheme = [
+             scheme
+             for scheme in schemes_to_merge
+             if not is_kv_cache_quant_scheme(scheme)
+         ][0]
+         schemes_to_merge = [kv_cache_quantization_scheme, quantization_scheme]
+         merged_scheme = {}
+         for scheme in schemes_to_merge:
+             scheme_dict = {
+                 k: v for k, v in scheme.model_dump().items() if v is not None
+             }
+             # when merging multiple schemes, the final target will be
+             # the `name` argument - hence erase the original targets
+             del scheme_dict["targets"]
+             # make sure that schemes do not "clash" with each other
+             overlapping_keys = set(merged_scheme.keys()) & set(scheme_dict.keys())
+             if overlapping_keys:
+                 raise ValueError(
+                     f"The module: {name} is being modified by two clashing "
+                     f"quantization schemes, that jointly try to override "
+                     f"properties: {overlapping_keys}. Fix the quantization config "
+                     "so that it is not ambiguous."
+                 )
+             merged_scheme.update(scheme_dict)
+
+         merged_scheme.update(targets=[name])
+
+         return QuantizationScheme(**merged_scheme)
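Taken together, `apply_quantization_config` now hands back a name-to-weight-args mapping, and modules matched by several groups are resolved through `_scheme_from_targets` / `_merge_schemes`. A hypothetical end-to-end sketch (illustrative only; the toy model supplies the module names, and accelerate must be installed since `initialize.py` now imports it at module level):

```python
# Hypothetical end-to-end sketch: the config is applied in-place and the new
# return value maps leaf-module names to their weight QuantizationArgs.
import torch
from compressed_tensors.quantization.lifecycle.apply import apply_quantization_config
from compressed_tensors.quantization.quant_args import QuantizationArgs
from compressed_tensors.quantization.quant_config import QuantizationConfig
from compressed_tensors.quantization.quant_scheme import QuantizationScheme

model = torch.nn.Sequential(torch.nn.Linear(8, 8), torch.nn.Linear(8, 8))
config = QuantizationConfig(
    config_groups={
        "group_0": QuantizationScheme(
            targets=["Linear"],
            weights=QuantizationArgs(num_bits=8, symmetric=True),
        )
    }
)

names_to_scheme = apply_quantization_config(model, config)
print(names_to_scheme)  # expected: {'0': QuantizationArgs(...), '1': QuantizationArgs(...)}
```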

compressed_tensors/quantization/lifecycle/calibration.py
@@ -16,6 +16,7 @@
  import logging

  from compressed_tensors.quantization.quant_config import QuantizationStatus
+ from compressed_tensors.utils import is_module_offloaded, update_parameter_data
  from torch.nn import Module


@@ -27,7 +28,7 @@ __all__ = [
  _LOGGER = logging.getLogger(__name__)


- def set_module_for_calibration(module: Module):
+ def set_module_for_calibration(module: Module, quantize_weights_upfront: bool = True):
      """
      marks a layer as ready for calibration which activates observers
      to update scales and zero points on each forward pass
@@ -35,6 +36,8 @@ def set_module_for_calibration(module: Module):
      apply to full model with `model.apply(set_module_for_calibration)`

      :param module: module to set for calibration
+     :param quantize_weights_upfront: whether to automatically run weight quantization at the
+         start of calibration
      """
      if not getattr(module, "quantization_scheme", None):
          # no quantization scheme nothing to do
@@ -48,4 +51,20 @@ def set_module_for_calibration(module: Module):
              "to re-calibrate a frozen module"
          )

+     if quantize_weights_upfront and module.quantization_scheme.weights is not None:
+         # set weight scale and zero_point up front, calibration data doesn't affect it
+         observer = module.weight_observer
+
+         offloaded = False
+         if is_module_offloaded(module):
+             module._hf_hook.pre_forward(module)
+             offloaded = True
+
+         scale, zero_point = observer(module.weight)
+         update_parameter_data(module, scale, "weight_scale")
+         update_parameter_data(module, zero_point, "weight_zero_point")
+
+         if offloaded:
+             module._hf_hook.post_forward(module, None)
+
      module.quantization_status = QuantizationStatus.CALIBRATION
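In practice the new `quantize_weights_upfront` flag means weight scales are observed once when calibration starts, while activation parameters keep updating on every forward pass. A hypothetical single-layer walk-through (illustrative only; the scheme is invented and accelerate is assumed to be installed):

```python
# Hypothetical calibration walk-through for one Linear layer (not from the docs).
import torch
from compressed_tensors.quantization.lifecycle.calibration import set_module_for_calibration
from compressed_tensors.quantization.lifecycle.frozen import freeze_module_quantization
from compressed_tensors.quantization.lifecycle.initialize import (
    initialize_module_for_quantization,
)
from compressed_tensors.quantization.quant_args import QuantizationArgs
from compressed_tensors.quantization.quant_config import QuantizationStatus
from compressed_tensors.quantization.quant_scheme import QuantizationScheme

layer = torch.nn.Linear(64, 64)
scheme = QuantizationScheme(
    targets=["Linear"],
    weights=QuantizationArgs(num_bits=8, symmetric=True),
    input_activations=QuantizationArgs(num_bits=8, symmetric=False),
)
initialize_module_for_quantization(layer, scheme)

# weight scale/zero point are observed once here; input params update per forward
set_module_for_calibration(layer, quantize_weights_upfront=True)
for _ in range(4):
    layer(torch.randn(2, 64))  # calibration batches

freeze_module_quantization(layer)
assert layer.quantization_status == QuantizationStatus.FROZEN
```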

compressed_tensors/quantization/lifecycle/forward.py
@@ -17,12 +17,15 @@ from math import ceil
  from typing import Optional

  import torch
+ from compressed_tensors.quantization.observers.helpers import calculate_range
  from compressed_tensors.quantization.quant_args import (
      QuantizationArgs,
      QuantizationStrategy,
+     round_to_quantized_type,
  )
  from compressed_tensors.quantization.quant_config import QuantizationStatus
  from compressed_tensors.quantization.quant_scheme import QuantizationScheme
+ from compressed_tensors.utils import update_parameter_data
  from torch.nn import Module


@@ -80,8 +83,9 @@ def quantize(
  def dequantize(
      x_q: torch.Tensor,
      scale: torch.Tensor,
-     zero_point: torch.Tensor,
+     zero_point: torch.Tensor = None,
      args: QuantizationArgs = None,
+     dtype: Optional[torch.dtype] = None,
  ) -> torch.Tensor:
      """
      Dequantize a quantized input tensor x_q based on the strategy specified in args. If
@@ -91,6 +95,7 @@ def dequantize(
      :param scale: scale tensor
      :param zero_point: zero point tensor
      :param args: quantization args used to quantize x_q
+     :param dtype: optional dtype to cast the dequantized output to
      :return: dequantized float tensor
      """
      if args is None:
@@ -107,8 +112,12 @@ def dequantize(
      else:
          raise ValueError(
              f"Could not infer a quantization strategy from scale with {scale.ndim} "
-             "dimmensions. Expected 0-2 dimmensions."
+             "dimmensions. Expected 0 or 2 dimmensions."
          )
+
+     if dtype is None:
+         dtype = scale.dtype
+
      return _process_quantization(
          x=x_q,
          scale=scale,
@@ -116,6 +125,7 @@ def dequantize(
          args=args,
          do_quantize=False,
          do_dequantize=True,
+         dtype=dtype,
      )


@@ -159,19 +169,13 @@ def _process_quantization(
      do_quantize: bool = True,
      do_dequantize: bool = True,
  ) -> torch.Tensor:
-     bit_range = 2**args.num_bits
-     q_max = torch.tensor(bit_range / 2 - 1, device=x.device)
-     q_min = torch.tensor(-bit_range / 2, device=x.device)
+
+     q_min, q_max = calculate_range(args, x.device)
      group_size = args.group_size

      if args.strategy == QuantizationStrategy.GROUP:
-
-         if do_dequantize and not do_quantize:
-             # if dequantizing a quantized type infer the output type from the scale
-             output = torch.zeros_like(x, dtype=scale.dtype)
-         else:
-             output_dtype = dtype if dtype is not None else x.dtype
-             output = torch.zeros_like(x, dtype=output_dtype)
+         output_dtype = dtype if dtype is not None else x.dtype
+         output = torch.zeros_like(x).to(output_dtype)

          # TODO: vectorize the for loop
          # TODO: fix genetric assumption about the tensor size for computing group
@@ -181,7 +185,7 @@ def _process_quantization(
          while scale.ndim < 2:
              # pad scale and zero point dims for slicing
              scale = scale.unsqueeze(1)
-             zero_point = zero_point.unsqueeze(1)
+             zero_point = zero_point.unsqueeze(1) if zero_point is not None else None

          columns = x.shape[1]
          if columns >= group_size:
@@ -194,12 +198,18 @@ def _process_quantization(
              # scale.shape should be [nchan, ndim]
              # sc.shape should be [nchan, 1] after unsqueeze
              sc = scale[:, i].view(-1, 1)
-             zp = zero_point[:, i].view(-1, 1)
+             zp = zero_point[:, i].view(-1, 1) if zero_point is not None else None

              idx = i * group_size
              if do_quantize:
                  output[:, idx : (idx + group_size)] = _quantize(
-                     x[:, idx : (idx + group_size)], sc, zp, q_min, q_max, dtype=dtype
+                     x[:, idx : (idx + group_size)],
+                     sc,
+                     zp,
+                     q_min,
+                     q_max,
+                     args,
+                     dtype=dtype,
                  )
              if do_dequantize:
                  input = (
@@ -211,7 +221,15 @@ def _process_quantization(

      else:  # covers channel, token and tensor strategies
          if do_quantize:
-             output = _quantize(x, scale, zero_point, q_min, q_max, dtype=dtype)
+             output = _quantize(
+                 x,
+                 scale,
+                 zero_point,
+                 q_min,
+                 q_max,
+                 args,
+                 dtype=dtype,
+             )
          if do_dequantize:
              output = _dequantize(output if do_quantize else x, scale, zero_point)

@@ -228,6 +246,11 @@ def wrap_module_forward_quantized(module: Module, scheme: QuantizationScheme):

      @wraps(forward_func_orig)  # ensures docstring, names, etc are propagated
      def wrapped_forward(self, *args, **kwargs):
+         if not getattr(module, "quantization_enabled", True):
+             # quantization is disabled on forward passes, return baseline
+             # forward call
+             return forward_func_orig.__get__(module, module.__class__)(*args, **kwargs)
+
          input_ = args[0]

          if scheme.input_activations is not None:
@@ -276,6 +299,11 @@ def maybe_calibrate_or_quantize(
      }:
          return value

+     if value.numel() == 0:
+         # if the tensor is empty,
+         # skip quantization
+         return value
+
      if args.dynamic:
          # dynamic quantization - get scale and zero point directly from observer
          observer = getattr(module, f"{base_name}_observer")
@@ -285,16 +313,19 @@ def maybe_calibrate_or_quantize(
      scale = getattr(module, f"{base_name}_scale")
      zero_point = getattr(module, f"{base_name}_zero_point")

-     if module.quantization_status == QuantizationStatus.CALIBRATION:
+     if (
+         module.quantization_status == QuantizationStatus.CALIBRATION
+         and base_name != "weight"
+     ):
          # calibration mode - get new quant params from observer
          observer = getattr(module, f"{base_name}_observer")

          updated_scale, updated_zero_point = observer(value)

          # update scale and zero point
-         device = next(module.parameters()).device
-         scale.data = updated_scale.to(device)
-         zero_point.data = updated_zero_point.to(device)
+         update_parameter_data(module, updated_scale, f"{base_name}_scale")
+         update_parameter_data(module, updated_zero_point, f"{base_name}_zero_point")
+
      return fake_quantize(value, scale, zero_point, args)


@@ -305,14 +336,18 @@ def _quantize(
      zero_point: torch.Tensor,
      q_min: torch.Tensor,
      q_max: torch.Tensor,
+     args: QuantizationArgs,
      dtype: Optional[torch.dtype] = None,
  ) -> torch.Tensor:
-     quantized_value = torch.clamp(
-         torch.round(x / scale + zero_point),
+
+     scaled = x / scale + zero_point.to(x.dtype)
+     # clamp first because cast isn't guaranteed to be saturated (ie for fp8)
+     clamped_value = torch.clamp(
+         scaled,
          q_min,
          q_max,
      )
-
+     quantized_value = round_to_quantized_type(clamped_value, args)
      if dtype is not None:
          quantized_value = quantized_value.to(dtype)

@@ -323,6 +358,16 @@ def _quantize(
  def _dequantize(
      x_q: torch.Tensor,
      scale: torch.Tensor,
-     zero_point: torch.Tensor,
+     zero_point: torch.Tensor = None,
+     dtype: Optional[torch.dtype] = None,
  ) -> torch.Tensor:
-     return (x_q - zero_point) * scale
+
+     dequant_value = x_q
+     if zero_point is not None:
+         dequant_value = dequant_value - zero_point.to(scale.dtype)
+     dequant_value = dequant_value.to(scale.dtype) * scale
+
+     if dtype is not None:
+         dequant_value = dequant_value.to(dtype)
+
+     return dequant_value
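For orientation, a tiny int8 per-tensor round trip through the public helpers (an illustrative sketch, not taken from the diff):

```python
# Illustrative int8 per-tensor round trip: clamp, round to the quantized type,
# then dequantize (zero_point is now optional for symmetric schemes).
import torch
from compressed_tensors.quantization.lifecycle.forward import dequantize, quantize
from compressed_tensors.quantization.quant_args import QuantizationArgs

args = QuantizationArgs(num_bits=8, type="int", symmetric=True, strategy="tensor")
x = torch.randn(4, 8)
scale = x.abs().max() / 127.0                 # simple per-tensor scale
zero_point = torch.zeros(1, dtype=torch.int8)

x_q = quantize(x, scale, zero_point, args, dtype=torch.int8)
x_dq = dequantize(x_q, scale)                 # output dtype defaults to scale.dtype
print((x - x_dq).abs().max())                 # error roughly bounded by scale / 2
```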

compressed_tensors/quantization/lifecycle/helpers.py (new file)
@@ -0,0 +1,53 @@
+ # Copyright (c) 2021 - present / Neuralmagic, Inc. All Rights Reserved.
+ #
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ #    http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing,
+ # software distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+
+ """
+ Miscelaneous helpers for the quantization lifecycle
+ """
+
+
+ from torch.nn import Module
+
+
+ __all__ = [
+     "update_layer_weight_quant_params",
+     "enable_quantization",
+     "disable_quantization",
+ ]
+
+
+ def update_layer_weight_quant_params(layer: Module):
+     weight = getattr(layer, "weight", None)
+     scale = getattr(layer, "weight_scale", None)
+     zero_point = getattr(layer, "weight_zero_point", None)
+     observer = getattr(layer, "weight_observer", None)
+
+     if weight is None or observer is None or scale is None or zero_point is None:
+         # scale, zp, or observer not calibratable or weight not available
+         return
+
+     updated_scale, updated_zero_point = observer(weight)
+
+     # update scale and zero point
+     device = next(layer.parameters()).device
+     scale.data = updated_scale.to(device)
+     zero_point.data = updated_zero_point.to(device)
+
+
+ def enable_quantization(module: Module):
+     module.quantization_enabled = True
+
+
+ def disable_quantization(module: Module):
+     module.quantization_enabled = False
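The `quantization_enabled` flag checked by `wrapped_forward` (forward.py above) turns these helpers into a runtime kill switch. A hypothetical usage sketch (illustrative only; accelerate is assumed to be installed for initialization):

```python
# Hypothetical usage of the new enable/disable helpers (not from the docs).
import torch
from compressed_tensors.quantization.lifecycle.helpers import (
    disable_quantization,
    enable_quantization,
)
from compressed_tensors.quantization.lifecycle.initialize import (
    initialize_module_for_quantization,
)
from compressed_tensors.quantization.quant_args import QuantizationArgs
from compressed_tensors.quantization.quant_scheme import QuantizationScheme

layer = torch.nn.Linear(16, 16)
scheme = QuantizationScheme(
    targets=["Linear"], weights=QuantizationArgs(num_bits=8, symmetric=True)
)
initialize_module_for_quantization(layer, scheme)

disable_quantization(layer)   # wrapped forward falls back to the original forward
_ = layer(torch.randn(1, 16))
enable_quantization(layer)    # quantization-aware forward is active again
```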

compressed_tensors/quantization/lifecycle/initialize.py
@@ -17,6 +17,8 @@ import logging
  from typing import Optional

  import torch
+ from accelerate.hooks import add_hook_to_module, remove_hook_from_module
+ from accelerate.utils import PrefixedDataset
  from compressed_tensors.quantization.lifecycle.forward import (
      wrap_module_forward_quantized,
  )
@@ -26,6 +28,7 @@ from compressed_tensors.quantization.quant_args import (
  )
  from compressed_tensors.quantization.quant_config import QuantizationStatus
  from compressed_tensors.quantization.quant_scheme import QuantizationScheme
+ from compressed_tensors.utils import get_execution_device, is_module_offloaded
  from torch.nn import Module, Parameter


@@ -81,9 +84,32 @@ def initialize_module_for_quantization(
      module.quantization_scheme = scheme
      module.quantization_status = QuantizationStatus.INITIALIZED

+     offloaded = False
+     if is_module_offloaded(module):
+         offloaded = True
+         hook = module._hf_hook
+         prefix_dict = module._hf_hook.weights_map
+         new_prefix = {}
+
+         # recreate the prefix dict (since it is immutable)
+         # and add quantization parameters
+         for key, data in module.named_parameters():
+             if key not in prefix_dict:
+                 new_prefix[f"{prefix_dict.prefix}{key}"] = data
+             else:
+                 new_prefix[f"{prefix_dict.prefix}{key}"] = prefix_dict[key]
+         new_prefix_dict = PrefixedDataset(new_prefix, prefix_dict.prefix)
+         remove_hook_from_module(module)
+
      # wrap forward call of module to perform quantized actions based on calltime status
      wrap_module_forward_quantized(module, scheme)

+     if offloaded:
+         # we need to re-add the hook for offloading now that we've wrapped forward
+         add_hook_to_module(module, hook)
+         if prefix_dict is not None:
+             module._hf_hook.weights_map = new_prefix_dict
+


  def _initialize_scale_zero_point_observer(
      module: Module,
@@ -99,6 +125,8 @@ def _initialize_scale_zero_point_observer(
          return  # no need to register a scale and zero point for a dynamic observer

      device = next(module.parameters()).device
+     if is_module_offloaded(module):
+         device = get_execution_device(module)

      # infer expected scale/zero point shape
      expected_shape = 1  # per tensor
@@ -120,8 +148,9 @@ def _initialize_scale_zero_point_observer(
      )
      module.register_parameter(f"{base_name}_scale", init_scale)

+     zp_dtype = quantization_args.pytorch_dtype()
      init_zero_point = Parameter(
-         torch.empty(expected_shape, device=device, dtype=int),
+         torch.empty(expected_shape, device=device, dtype=zp_dtype),
          requires_grad=False,
      )
      module.register_parameter(f"{base_name}_zero_point", init_zero_point)
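One visible effect of the last two hunks: zero points are now allocated with `QuantizationArgs.pytorch_dtype()` instead of the Python `int` dtype. A hypothetical check (illustrative only; it assumes `pytorch_dtype()` maps 8-bit integer args to `torch.int8` and that accelerate is installed):

```python
# Hypothetical check of the new zero-point dtype (torch.int8 assumed for 8-bit int args).
import torch
from compressed_tensors.quantization.lifecycle.initialize import (
    initialize_module_for_quantization,
)
from compressed_tensors.quantization.quant_args import QuantizationArgs
from compressed_tensors.quantization.quant_scheme import QuantizationScheme

layer = torch.nn.Linear(32, 32)
scheme = QuantizationScheme(
    targets=["Linear"],
    weights=QuantizationArgs(num_bits=8, type="int", symmetric=True),
)
initialize_module_for_quantization(layer, scheme)

print(layer.weight_scale.dtype)       # follows the weight dtype (e.g. torch.float32)
print(layer.weight_zero_point.dtype)  # expected: torch.int8 rather than torch.int64
```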