compressed-tensors 0.4.0__py3-none-any.whl → 0.6.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (44)
  1. compressed_tensors/base.py +1 -0
  2. compressed_tensors/compressors/__init__.py +5 -1
  3. compressed_tensors/compressors/base.py +200 -8
  4. compressed_tensors/compressors/dense.py +1 -1
  5. compressed_tensors/compressors/marlin_24.py +11 -10
  6. compressed_tensors/compressors/model_compressor.py +101 -13
  7. compressed_tensors/compressors/naive_quantized.py +140 -0
  8. compressed_tensors/compressors/pack_quantized.py +128 -132
  9. compressed_tensors/compressors/sparse_bitmask.py +1 -1
  10. compressed_tensors/config/base.py +8 -1
  11. compressed_tensors/{compressors/utils → linear}/__init__.py +0 -6
  12. compressed_tensors/linear/compressed_linear.py +87 -0
  13. compressed_tensors/quantization/lifecycle/__init__.py +1 -0
  14. compressed_tensors/quantization/lifecycle/apply.py +204 -44
  15. compressed_tensors/quantization/lifecycle/calibration.py +22 -2
  16. compressed_tensors/quantization/lifecycle/compressed.py +3 -1
  17. compressed_tensors/quantization/lifecycle/forward.py +139 -61
  18. compressed_tensors/quantization/lifecycle/helpers.py +80 -0
  19. compressed_tensors/quantization/lifecycle/initialize.py +77 -13
  20. compressed_tensors/quantization/observers/__init__.py +1 -0
  21. compressed_tensors/quantization/observers/base.py +93 -14
  22. compressed_tensors/quantization/observers/helpers.py +64 -11
  23. compressed_tensors/quantization/observers/min_max.py +8 -0
  24. compressed_tensors/quantization/observers/mse.py +162 -0
  25. compressed_tensors/quantization/quant_args.py +139 -23
  26. compressed_tensors/quantization/quant_config.py +35 -2
  27. compressed_tensors/quantization/quant_scheme.py +112 -13
  28. compressed_tensors/quantization/utils/helpers.py +68 -2
  29. compressed_tensors/utils/__init__.py +5 -0
  30. compressed_tensors/utils/helpers.py +44 -2
  31. compressed_tensors/utils/offload.py +116 -0
  32. compressed_tensors/utils/permute.py +70 -0
  33. compressed_tensors/utils/safetensors_load.py +2 -0
  34. compressed_tensors/{compressors/utils → utils}/semi_structured_conversions.py +1 -0
  35. compressed_tensors/version.py +1 -1
  36. {compressed_tensors-0.4.0.dist-info → compressed_tensors-0.6.0.dist-info}/METADATA +35 -22
  37. compressed_tensors-0.6.0.dist-info/RECORD +52 -0
  38. {compressed_tensors-0.4.0.dist-info → compressed_tensors-0.6.0.dist-info}/WHEEL +1 -1
  39. compressed_tensors/compressors/int_quantized.py +0 -126
  40. compressed_tensors/compressors/utils/helpers.py +0 -43
  41. compressed_tensors-0.4.0.dist-info/RECORD +0 -48
  42. /compressed_tensors/{compressors/utils → utils}/permutations_24.py +0 -0
  43. {compressed_tensors-0.4.0.dist-info → compressed_tensors-0.6.0.dist-info}/LICENSE +0 -0
  44. {compressed_tensors-0.4.0.dist-info → compressed_tensors-0.6.0.dist-info}/top_level.txt +0 -0
compressed_tensors/quantization/lifecycle/forward.py
@@ -17,12 +17,15 @@ from math import ceil
  from typing import Optional
 
  import torch
+ from compressed_tensors.quantization.observers.helpers import calculate_range
  from compressed_tensors.quantization.quant_args import (
      QuantizationArgs,
      QuantizationStrategy,
+     round_to_quantized_type,
  )
  from compressed_tensors.quantization.quant_config import QuantizationStatus
  from compressed_tensors.quantization.quant_scheme import QuantizationScheme
+ from compressed_tensors.utils import safe_permute, update_parameter_data
  from torch.nn import Module
 
 
@@ -42,6 +45,7 @@ def quantize(
      zero_point: torch.Tensor,
      args: QuantizationArgs,
      dtype: Optional[torch.dtype] = None,
+     g_idx: Optional[torch.Tensor] = None,
  ) -> torch.Tensor:
      """
      Quantize the input tensor x using the QuantizationStrategy specified in args.
@@ -55,16 +59,9 @@ def quantize(
      :param zero_point: zero point tensor
      :param args: quantization args dictating how to quantize x
      :param dtype: optional dtype to cast the quantized output to
+     :param g_idx: optional mapping from column index to group index
      :return: fake quantized tensor
      """
-     # ensure all tensors are on the same device
-     # assumes that the target device is the input
-     # tensor's device
-     if x.device != scale.device:
-         scale = scale.to(x.device)
-     if x.device != zero_point.device:
-         zero_point = zero_point.to(x.device)
-
      return _process_quantization(
          x=x,
          scale=scale,
@@ -73,6 +70,7 @@ def quantize(
          dtype=dtype,
          do_quantize=True,
          do_dequantize=False,
+         g_idx=g_idx,
      )
 
 
@@ -80,8 +78,10 @@
  def dequantize(
      x_q: torch.Tensor,
      scale: torch.Tensor,
-     zero_point: torch.Tensor,
+     zero_point: torch.Tensor = None,
      args: QuantizationArgs = None,
+     dtype: Optional[torch.dtype] = None,
+     g_idx: Optional[torch.Tensor] = None,
  ) -> torch.Tensor:
      """
      Dequantize a quantized input tensor x_q based on the strategy specified in args. If
@@ -91,6 +91,8 @@ def dequantize(
      :param scale: scale tensor
      :param zero_point: zero point tensor
      :param args: quantization args used to quantize x_q
+     :param dtype: optional dtype to cast the dequantized output to
+     :param g_idx: optional mapping from column index to group index
      :return: dequantized float tensor
      """
      if args is None:
@@ -107,8 +109,12 @@ def dequantize(
          else:
              raise ValueError(
                  f"Could not infer a quantization strategy from scale with {scale.ndim} "
-                 "dimmensions. Expected 0-2 dimmensions."
+                 "dimmensions. Expected 0 or 2 dimmensions."
              )
+
+     if dtype is None:
+         dtype = scale.dtype
+
      return _process_quantization(
          x=x_q,
          scale=scale,
@@ -116,6 +122,8 @@ def dequantize(
          args=args,
          do_quantize=False,
          do_dequantize=True,
+         dtype=dtype,
+         g_idx=g_idx,
      )
 
 
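Aside (not part of the diff): a minimal sketch of the relaxed dequantize signature, where zero_point may now be omitted and the output dtype defaults to scale.dtype. The shapes and values below are made up for illustration.

    import torch
    from compressed_tensors.quantization.lifecycle.forward import dequantize
    from compressed_tensors.quantization.quant_args import (
        QuantizationArgs,
        QuantizationStrategy,
    )

    w_q = torch.randint(-8, 8, (16, 8), dtype=torch.int8)   # pretend INT4 values
    scale = torch.rand(16, 1) + 0.1                          # one scale per output row

    # zero_point is omitted (symmetric case); with no explicit dtype the result
    # is returned in scale.dtype
    args = QuantizationArgs(num_bits=4, strategy=QuantizationStrategy.CHANNEL)
    w = dequantize(w_q, scale, args=args)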
@@ -125,6 +133,7 @@ def fake_quantize(
      scale: torch.Tensor,
      zero_point: torch.Tensor,
      args: QuantizationArgs,
+     g_idx: Optional[torch.Tensor] = None,
  ) -> torch.Tensor:
      """
      Fake quantize the input tensor x by quantizing then dequantizing with
@@ -137,6 +146,7 @@ def fake_quantize(
      :param scale: scale tensor
      :param zero_point: zero point tensor
      :param args: quantization args dictating how to quantize x
+     :param g_idx: optional mapping from column index to group index
      :return: fake quantized tensor
      """
      return _process_quantization(
@@ -146,6 +156,7 @@ def fake_quantize(
          args=args,
          do_quantize=True,
          do_dequantize=True,
+         g_idx=g_idx,
      )
 
 
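Aside (not part of the diff): a small sketch of how the g_idx-aware API might be called for a group-quantized 2-D weight. The shapes, values, and group assignment here are illustrative only.

    import torch
    from compressed_tensors.quantization.lifecycle.forward import fake_quantize
    from compressed_tensors.quantization.quant_args import (
        QuantizationArgs,
        QuantizationStrategy,
    )

    # 8 columns with group_size=4 -> 2 groups, so 2 scales/zero points per row
    weight = torch.randn(16, 8)
    args = QuantizationArgs(num_bits=4, strategy=QuantizationStrategy.GROUP, group_size=4)
    scale = torch.rand(16, 2) + 0.1
    zero_point = torch.zeros(16, 2, dtype=torch.int8)

    # g_idx maps each column to its group; None (or all -1) keeps plain column order
    g_idx = torch.tensor([0, 0, 1, 1, 0, 0, 1, 1], dtype=torch.int)
    w_fq = fake_quantize(weight, scale, zero_point, args, g_idx=g_idx)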
@@ -154,64 +165,85 @@ def _process_quantization(
      x: torch.Tensor,
      scale: torch.Tensor,
      zero_point: torch.Tensor,
+     g_idx: Optional[torch.Tensor],
      args: QuantizationArgs,
      dtype: Optional[torch.dtype] = None,
      do_quantize: bool = True,
      do_dequantize: bool = True,
  ) -> torch.Tensor:
-     bit_range = 2**args.num_bits
-     q_max = torch.tensor(bit_range / 2 - 1, device=x.device)
-     q_min = torch.tensor(-bit_range / 2, device=x.device)
+     q_min, q_max = calculate_range(args, x.device)
      group_size = args.group_size
 
      if args.strategy == QuantizationStrategy.GROUP:
-
-         if do_dequantize and not do_quantize:
-             # if dequantizing a quantized type infer the output type from the scale
-             output = torch.zeros_like(x, dtype=scale.dtype)
-         else:
-             output_dtype = dtype if dtype is not None else x.dtype
-             output = torch.zeros_like(x, dtype=output_dtype)
-
-         # TODO: vectorize the for loop
-         # TODO: fix genetric assumption about the tensor size for computing group
+         output_dtype = dtype if dtype is not None else x.dtype
+         output = torch.zeros_like(x).to(output_dtype)
+         columns = output.shape[1]
 
          # TODO: make validation step for inputs
 
          while scale.ndim < 2:
              # pad scale and zero point dims for slicing
              scale = scale.unsqueeze(1)
-             zero_point = zero_point.unsqueeze(1)
+             zero_point = zero_point.unsqueeze(1) if zero_point is not None else None
 
-         columns = x.shape[1]
          if columns >= group_size:
              if columns % group_size != 0:
                  raise ValueError(
-                     "tesnor column shape must be divisble "
+                     "tensor column shape must be divisble "
                      f"by the given group_size {group_size}"
                  )
-         for i in range(ceil(columns / group_size)):
-             # scale.shape should be [nchan, ndim]
-             # sc.shape should be [nchan, 1] after unsqueeze
-             sc = scale[:, i].view(-1, 1)
-             zp = zero_point[:, i].view(-1, 1)
 
-             idx = i * group_size
+         # support column-order (default) quantization as well as other orderings
+         # such as activation ordering. Below checks if g_idx has been initialized
+         is_column_order = g_idx is None or -1 in g_idx
+         if is_column_order:
+             num_groups = int(ceil(columns / group_size))
+             group_sizes = torch.full((num_groups,), group_size, dtype=torch.int)
+
+         else:
+             group_indices, group_sizes = torch.unique(g_idx, return_counts=True)
+             group_sizes = group_sizes[torch.argsort(group_indices)]
+
+             perm = torch.argsort(g_idx)
+             x = safe_permute(x, perm, dim=1)
+
+         # TODO: experiment with vectorizing for loop for performance
+         end = 0
+         for index, group_count in enumerate(group_sizes):
+             sc = scale[:, index].view(-1, 1)
+             zp = zero_point[:, index].view(-1, 1) if zero_point is not None else None
+
+             start = end
+             end = start + group_count
              if do_quantize:
-                 output[:, idx : (idx + group_size)] = _quantize(
-                     x[:, idx : (idx + group_size)], sc, zp, q_min, q_max, dtype=dtype
+                 output[:, start:end] = _quantize(
+                     x[:, start:end],
+                     sc,
+                     zp,
+                     q_min,
+                     q_max,
+                     args,
+                     dtype=dtype,
                  )
+
              if do_dequantize:
-                 input = (
-                     output[:, idx : (idx + group_size)]
-                     if do_quantize
-                     else x[:, idx : (idx + group_size)]
-                 )
-                 output[:, idx : (idx + group_size)] = _dequantize(input, sc, zp)
+                 input = output[:, start:end] if do_quantize else x[:, start:end]
+                 output[:, start:end] = _dequantize(input, sc, zp)
+
+         if not is_column_order:
+             output = safe_permute(output, torch.argsort(perm), dim=1)
 
      else: # covers channel, token and tensor strategies
          if do_quantize:
-             output = _quantize(x, scale, zero_point, q_min, q_max, dtype=dtype)
+             output = _quantize(
+                 x,
+                 scale,
+                 zero_point,
+                 q_min,
+                 q_max,
+                 args,
+                 dtype=dtype,
+             )
          if do_dequantize:
              output = _dequantize(output if do_quantize else x, scale, zero_point)
 
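Aside (not part of the diff): the activation-ordering bookkeeping above can be reproduced in a few lines of plain PyTorch. This sketch shows how a hypothetical g_idx yields per-group column counts, a permutation that makes each group contiguous, and its inverse.

    import torch

    g_idx = torch.tensor([1, 0, 0, 1, 0, 1, 1, 0])   # hypothetical column-to-group map

    # per-group column counts, ordered by group index (mirrors torch.unique + argsort above)
    group_indices, group_sizes = torch.unique(g_idx, return_counts=True)
    group_sizes = group_sizes[torch.argsort(group_indices)]   # tensor([4, 4])

    # sorting columns by group makes each group contiguous for the per-group loop;
    # argsort of the permutation restores the original column order afterwards
    perm = torch.argsort(g_idx)
    restore = torch.argsort(perm)
    assert torch.equal(perm[restore], torch.arange(8))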
@@ -228,7 +260,13 @@ def wrap_module_forward_quantized(module: Module, scheme: QuantizationScheme):
 
      @wraps(forward_func_orig) # ensures docstring, names, etc are propagated
      def wrapped_forward(self, *args, **kwargs):
+         if not getattr(module, "quantization_enabled", True):
+             # quantization is disabled on forward passes, return baseline
+             # forward call
+             return forward_func_orig.__get__(module, module.__class__)(*args, **kwargs)
+
          input_ = args[0]
+         compressed = module.quantization_status == QuantizationStatus.COMPRESSED
 
          if scheme.input_activations is not None:
              # calibrate and (fake) quantize input activations when applicable
@@ -236,7 +274,7 @@ def wrap_module_forward_quantized(module: Module, scheme: QuantizationScheme):
                  module, input_, "input", scheme.input_activations
              )
 
-         if scheme.weights is not None:
+         if scheme.weights is not None and not compressed:
              # calibrate and (fake) quantize weights when applicable
              unquantized_weight = self.weight.data.clone()
              self.weight.data = maybe_calibrate_or_quantize(
@@ -255,7 +293,7 @@ def wrap_module_forward_quantized(module: Module, scheme: QuantizationScheme):
          )
 
          # restore back to unquantized_value
-         if scheme.weights is not None:
+         if scheme.weights is not None and not compressed:
              self.weight.data = unquantized_weight
 
          return output
@@ -269,33 +307,57 @@ def wrap_module_forward_quantized(module: Module, scheme: QuantizationScheme):
  def maybe_calibrate_or_quantize(
      module: Module, value: torch.Tensor, base_name: str, args: "QuantizationArgs"
  ) -> torch.Tensor:
-     # only run quantized for the included stages
-     if module.quantization_status not in {
-         QuantizationStatus.CALIBRATION,
-         QuantizationStatus.FROZEN,
-     }:
+     # don't run quantization if we haven't entered calibration mode
+     if module.quantization_status == QuantizationStatus.INITIALIZED:
          return value
 
+     # in compressed mode, the weight is already compressed and quantized so we don't
+     # need to run fake quantization
+     if (
+         module.quantization_status == QuantizationStatus.COMPRESSED
+         and base_name == "weight"
+     ):
+         return value
+
+     if value.numel() == 0:
+         # if the tensor is empty,
+         # skip quantization
+         return value
+
+     g_idx = getattr(module, "weight_g_idx", None)
+
      if args.dynamic:
          # dynamic quantization - get scale and zero point directly from observer
          observer = getattr(module, f"{base_name}_observer")
-         scale, zero_point = observer(value)
+         scale, zero_point = observer(value, g_idx=g_idx)
      else:
          # static quantization - get previous scale and zero point from layer
          scale = getattr(module, f"{base_name}_scale")
-         zero_point = getattr(module, f"{base_name}_zero_point")
+         zero_point = getattr(module, f"{base_name}_zero_point", None)
 
-     if module.quantization_status == QuantizationStatus.CALIBRATION:
+     if (
+         module.quantization_status == QuantizationStatus.CALIBRATION
+         and base_name != "weight"
+     ):
          # calibration mode - get new quant params from observer
          observer = getattr(module, f"{base_name}_observer")
 
-         updated_scale, updated_zero_point = observer(value)
+         updated_scale, updated_zero_point = observer(value, g_idx=g_idx)
 
          # update scale and zero point
-         device = next(module.parameters()).device
-         scale.data = updated_scale.to(device)
-         zero_point.data = updated_zero_point.to(device)
-     return fake_quantize(value, scale, zero_point, args)
+         update_parameter_data(module, updated_scale, f"{base_name}_scale")
+         update_parameter_data(module, updated_zero_point, f"{base_name}_zero_point")
+
+         scale = updated_scale
+         zero_point = updated_zero_point
+
+     return fake_quantize(
+         x=value,
+         scale=scale,
+         zero_point=zero_point,
+         args=args,
+         g_idx=g_idx,
+     )
 
 
  @torch.no_grad()
@@ -305,14 +367,20 @@ def _quantize(
      zero_point: torch.Tensor,
      q_min: torch.Tensor,
      q_max: torch.Tensor,
+     args: QuantizationArgs,
      dtype: Optional[torch.dtype] = None,
  ) -> torch.Tensor:
-     quantized_value = torch.clamp(
-         torch.round(x / scale + zero_point),
+
+     scaled = x / scale
+     if zero_point is not None:
+         scaled += zero_point.to(x.dtype)
+     # clamp first because cast isn't guaranteed to be saturated (ie for fp8)
+     clamped_value = torch.clamp(
+         scaled,
          q_min,
          q_max,
      )
-
+     quantized_value = round_to_quantized_type(clamped_value, args)
      if dtype is not None:
          quantized_value = quantized_value.to(dtype)
 
@@ -323,6 +391,16 @@
  def _dequantize(
      x_q: torch.Tensor,
      scale: torch.Tensor,
-     zero_point: torch.Tensor,
+     zero_point: torch.Tensor = None,
+     dtype: Optional[torch.dtype] = None,
  ) -> torch.Tensor:
-     return (x_q - zero_point) * scale
+     dequant_value = x_q.to(scale.dtype)
+
+     if zero_point is not None:
+         dequant_value = dequant_value - zero_point.to(scale.dtype)
+     dequant_value = dequant_value * scale
+
+     if dtype is not None:
+         dequant_value = dequant_value.to(dtype)
+
+     return dequant_value
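Aside (not part of the diff): the arithmetic in _quantize/_dequantize is the usual affine scheme, now with clamping done before the final rounding/cast so non-saturating types such as fp8 stay in range. A worked example with made-up numbers:

    import torch

    x = torch.tensor([0.9, -1.3, 2.4])
    scale, zero_point = torch.tensor(0.1), torch.tensor(3)
    q_min, q_max = -8, 7                     # e.g. a signed 4-bit integer range

    # quantize: scale, shift by the zero point, clamp to the range, then round
    x_q = torch.round(torch.clamp(x / scale + zero_point, q_min, q_max))
    # -> tensor([ 7., -8.,  7.])

    # dequantize: undo the shift and the scale (saturated values stay clipped)
    x_dq = (x_q - zero_point) * scale
    # -> tensor([ 0.4000, -1.1000,  0.4000])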
compressed_tensors/quantization/lifecycle/helpers.py (new file)
@@ -0,0 +1,80 @@
+ # Copyright (c) 2021 - present / Neuralmagic, Inc. All Rights Reserved.
+ #
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ # http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing,
+ # software distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+
+ """
+ Miscelaneous helpers for the quantization lifecycle
+ """
+
+ from typing import Optional
+
+ import torch
+ from torch.nn import Module
+
+
+ __all__ = [
+     "update_layer_weight_quant_params",
+     "enable_quantization",
+     "disable_quantization",
+ ]
+
+
+ def update_layer_weight_quant_params(
+     layer: Module,
+     weight: Optional[torch.Tensor] = None,
+     g_idx: Optional[torch.Tensor] = None,
+     reset_obs: bool = False,
+ ):
+     """
+     Update quantization parameters on layer
+
+     :param layer: input layer
+     :param weight: weight to update quant params with, defaults to layer weight
+     :param g_idx: optional mapping from column index to group index
+     :param reset_obs: reset the observer before calculating quant params,
+         defaults to False
+     """
+     attached_weight = getattr(layer, "weight", None)
+
+     if weight is None:
+         weight = attached_weight
+     scale = getattr(layer, "weight_scale", None)
+     zero_point = getattr(layer, "weight_zero_point", None)
+     if g_idx is None:
+         g_idx = getattr(layer, "weight_g_idx", None)
+     observer = getattr(layer, "weight_observer", None)
+
+     if weight is None or observer is None or scale is None or zero_point is None:
+         # scale, zp, or observer not calibratable or weight not available
+         return
+
+     if reset_obs:
+         observer.reset()
+
+     if attached_weight is not None:
+         weight = weight.to(attached_weight.dtype)
+
+     updated_scale, updated_zero_point = observer(weight)
+
+     # update scale and zero point
+     device = next(layer.parameters()).device
+     scale.data = updated_scale.to(device)
+     zero_point.data = updated_zero_point.to(device)
+
+
+ def enable_quantization(module: Module):
+     module.quantization_enabled = True
+
+
+ def disable_quantization(module: Module):
+     module.quantization_enabled = False
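Aside (not part of the diff): a possible use of the new toggles, assuming the layer has already had its forward wrapped for quantization; the module below is illustrative.

    import torch
    from compressed_tensors.quantization.lifecycle.helpers import (
        disable_quantization,
        enable_quantization,
    )

    layer = torch.nn.Linear(64, 64)       # assume already wrapped by wrap_module_forward_quantized

    disable_quantization(layer)           # wrapped forward sees quantization_enabled=False
    baseline = layer(torch.randn(1, 64))  # runs the original, unquantized forward

    enable_quantization(layer)            # restore quantized behavior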
compressed_tensors/quantization/lifecycle/initialize.py
@@ -21,11 +21,13 @@ from compressed_tensors.quantization.lifecycle.forward import (
      wrap_module_forward_quantized,
  )
  from compressed_tensors.quantization.quant_args import (
+     ActivationOrdering,
      QuantizationArgs,
      QuantizationStrategy,
  )
  from compressed_tensors.quantization.quant_config import QuantizationStatus
  from compressed_tensors.quantization.quant_scheme import QuantizationScheme
+ from compressed_tensors.utils import get_execution_device, is_module_offloaded
  from torch.nn import Module, Parameter
 
 
@@ -40,6 +42,7 @@ _LOGGER = logging.getLogger(__name__)
  def initialize_module_for_quantization(
      module: Module,
      scheme: Optional[QuantizationScheme] = None,
+     force_zero_point: bool = True,
  ):
      """
      attaches appropriate scales, zero points, and observers to a layer
@@ -51,6 +54,8 @@ def initialize_module_for_quantization(
      :param scheme: scheme to use for quantization. if None is provided,
          will attempt to use scheme stored in the module under `quantization_scheme`,
          if not provided, the layer will be skipped
+     :param force_zero_point: whether to force initialization of a zero point for
+         symmetric quantization
      """
      scheme = scheme or getattr(module, "quantization_scheme", None)
      if scheme is None:
@@ -58,14 +63,18 @@ def initialize_module_for_quantization(
          return
 
      if scheme.input_activations is not None:
-         _initialize_scale_zero_point_observer(module, "input", scheme.input_activations)
+         _initialize_scale_zero_point_observer(
+             module, "input", scheme.input_activations, force_zero_point=force_zero_point
+         )
      if scheme.weights is not None:
          if hasattr(module, "weight"):
-             weight_shape = None
-             if isinstance(module, torch.nn.Linear):
-                 weight_shape = module.weight.shape
+             weight_shape = module.weight.shape
              _initialize_scale_zero_point_observer(
-                 module, "weight", scheme.weights, weight_shape=weight_shape
+                 module,
+                 "weight",
+                 scheme.weights,
+                 weight_shape=weight_shape,
+                 force_zero_point=force_zero_point,
              )
          else:
              _LOGGER.warning(
@@ -75,21 +84,58 @@ def initialize_module_for_quantization(
              )
      if scheme.output_activations is not None:
          _initialize_scale_zero_point_observer(
-             module, "output", scheme.output_activations
+             module,
+             "output",
+             scheme.output_activations,
+             force_zero_point=force_zero_point,
          )
 
      module.quantization_scheme = scheme
      module.quantization_status = QuantizationStatus.INITIALIZED
 
+     offloaded = False
+     if is_module_offloaded(module):
+         try:
+             from accelerate.hooks import add_hook_to_module, remove_hook_from_module
+             from accelerate.utils import PrefixedDataset
+         except ModuleNotFoundError:
+             raise ModuleNotFoundError(
+                 "Offloaded model detected. To use CPU offloading with "
+                 "compressed-tensors the `accelerate` package must be installed, "
+                 "run `pip install compressed-tensors[accelerate]`"
+             )
+
+         offloaded = True
+         hook = module._hf_hook
+         prefix_dict = module._hf_hook.weights_map
+         new_prefix = {}
+
+         # recreate the prefix dict (since it is immutable)
+         # and add quantization parameters
+         for key, data in module.named_parameters():
+             if key not in prefix_dict:
+                 new_prefix[f"{prefix_dict.prefix}{key}"] = data
+             else:
+                 new_prefix[f"{prefix_dict.prefix}{key}"] = prefix_dict[key]
+         new_prefix_dict = PrefixedDataset(new_prefix, prefix_dict.prefix)
+         remove_hook_from_module(module)
+
      # wrap forward call of module to perform quantized actions based on calltime status
      wrap_module_forward_quantized(module, scheme)
 
+     if offloaded:
+         # we need to re-add the hook for offloading now that we've wrapped forward
+         add_hook_to_module(module, hook)
+         if prefix_dict is not None:
+             module._hf_hook.weights_map = new_prefix_dict
+
 
 
  def _initialize_scale_zero_point_observer(
      module: Module,
      base_name: str,
      quantization_args: QuantizationArgs,
      weight_shape: Optional[torch.Size] = None,
+     force_zero_point: bool = True,
  ):
      # initialize observer module and attach as submodule
      observer = quantization_args.get_observer()
@@ -99,6 +145,8 @@
          return # no need to register a scale and zero point for a dynamic observer
 
      device = next(module.parameters()).device
+     if is_module_offloaded(module):
+         device = get_execution_device(module)
 
      # infer expected scale/zero point shape
      expected_shape = 1 # per tensor
@@ -113,15 +161,31 @@
                  weight_shape[1] // quantization_args.group_size,
              )
 
-     # initializes empty scale and zero point parameters for the module
+     scale_dtype = module.weight.dtype
+     if scale_dtype not in [torch.float16, torch.bfloat16, torch.float32]:
+         scale_dtype = torch.float16
+
+     # initializes empty scale, zero point, and g_idx parameters for the module
      init_scale = Parameter(
-         torch.empty(expected_shape, dtype=module.weight.dtype, device=device),
+         torch.empty(expected_shape, dtype=scale_dtype, device=device),
          requires_grad=False,
      )
      module.register_parameter(f"{base_name}_scale", init_scale)
 
-     init_zero_point = Parameter(
-         torch.empty(expected_shape, device=device, dtype=int),
-         requires_grad=False,
-     )
-     module.register_parameter(f"{base_name}_zero_point", init_zero_point)
+     if force_zero_point or not quantization_args.symmetric:
+         zp_dtype = quantization_args.pytorch_dtype()
+         init_zero_point = Parameter(
+             torch.zeros(expected_shape, device=device, dtype=zp_dtype),
+             requires_grad=False,
+         )
+         module.register_parameter(f"{base_name}_zero_point", init_zero_point)
+
+     # only grouped activation ordering has g_idx
+     if quantization_args.actorder == ActivationOrdering.GROUP:
+         g_idx_shape = (weight_shape[1],)
+         g_idx_dtype = torch.int
+         init_g_idx = Parameter(
+             torch.full(g_idx_shape, -1, device=device, dtype=g_idx_dtype),
+             requires_grad=False,
+         )
+         module.register_parameter(f"{base_name}_g_idx", init_g_idx)
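Aside (not part of the diff): a sketch of what the updated initializer registers for a group-quantized Linear layer, assuming the 0.6.0 API shown above; the layer dimensions and scheme are made up for illustration.

    import torch
    from compressed_tensors.quantization.lifecycle.initialize import (
        initialize_module_for_quantization,
    )
    from compressed_tensors.quantization.quant_args import (
        ActivationOrdering,
        QuantizationArgs,
        QuantizationStrategy,
    )
    from compressed_tensors.quantization.quant_scheme import QuantizationScheme

    layer = torch.nn.Linear(128, 256)   # 128 columns, group_size=32 -> 4 groups per row
    scheme = QuantizationScheme(
        targets=["Linear"],
        weights=QuantizationArgs(
            num_bits=4,
            strategy=QuantizationStrategy.GROUP,
            group_size=32,
            actorder=ActivationOrdering.GROUP,
        ),
    )
    initialize_module_for_quantization(layer, scheme)

    print(layer.weight_scale.shape)       # torch.Size([256, 4])
    print(layer.weight_zero_point.shape)  # torch.Size([256, 4])
    print(layer.weight_g_idx.shape)       # torch.Size([128]), filled with -1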
compressed_tensors/quantization/observers/__init__.py
@@ -19,3 +19,4 @@ from .helpers import *
  from .base import *
  from .memoryless import *
  from .min_max import *
+ from .mse import *