compressed-tensors-nightly 0.5.0.20240814__py3-none-any.whl → 0.5.0.20240830__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (29)
  1. compressed_tensors/compressors/base.py +198 -8
  2. compressed_tensors/compressors/model_compressor.py +65 -1
  3. compressed_tensors/compressors/naive_quantized.py +71 -75
  4. compressed_tensors/compressors/pack_quantized.py +83 -94
  5. compressed_tensors/linear/__init__.py +13 -0
  6. compressed_tensors/linear/compressed_linear.py +87 -0
  7. compressed_tensors/quantization/lifecycle/apply.py +36 -4
  8. compressed_tensors/quantization/lifecycle/calibration.py +3 -2
  9. compressed_tensors/quantization/lifecycle/compressed.py +1 -1
  10. compressed_tensors/quantization/lifecycle/forward.py +67 -43
  11. compressed_tensors/quantization/lifecycle/helpers.py +29 -2
  12. compressed_tensors/quantization/lifecycle/initialize.py +50 -16
  13. compressed_tensors/quantization/observers/__init__.py +1 -0
  14. compressed_tensors/quantization/observers/base.py +54 -14
  15. compressed_tensors/quantization/observers/min_max.py +8 -0
  16. compressed_tensors/quantization/observers/mse.py +162 -0
  17. compressed_tensors/quantization/quant_args.py +48 -20
  18. compressed_tensors/utils/__init__.py +1 -0
  19. compressed_tensors/utils/helpers.py +13 -0
  20. compressed_tensors/utils/offload.py +7 -1
  21. compressed_tensors/utils/permute.py +70 -0
  22. compressed_tensors/utils/safetensors_load.py +2 -0
  23. compressed_tensors/utils/semi_structured_conversions.py +1 -0
  24. {compressed_tensors_nightly-0.5.0.20240814.dist-info → compressed_tensors_nightly-0.5.0.20240830.dist-info}/METADATA +3 -2
  25. compressed_tensors_nightly-0.5.0.20240830.dist-info/RECORD +52 -0
  26. compressed_tensors_nightly-0.5.0.20240814.dist-info/RECORD +0 -48
  27. {compressed_tensors_nightly-0.5.0.20240814.dist-info → compressed_tensors_nightly-0.5.0.20240830.dist-info}/LICENSE +0 -0
  28. {compressed_tensors_nightly-0.5.0.20240814.dist-info → compressed_tensors_nightly-0.5.0.20240830.dist-info}/WHEEL +0 -0
  29. {compressed_tensors_nightly-0.5.0.20240814.dist-info → compressed_tensors_nightly-0.5.0.20240830.dist-info}/top_level.txt +0 -0

compressed_tensors/quantization/lifecycle/forward.py
@@ -25,7 +25,7 @@ from compressed_tensors.quantization.quant_args import (
 )
 from compressed_tensors.quantization.quant_config import QuantizationStatus
 from compressed_tensors.quantization.quant_scheme import QuantizationScheme
-from compressed_tensors.utils import update_parameter_data
+from compressed_tensors.utils import safe_permute, update_parameter_data
 from torch.nn import Module
 
 
@@ -45,6 +45,7 @@ def quantize(
     zero_point: torch.Tensor,
     args: QuantizationArgs,
     dtype: Optional[torch.dtype] = None,
+    g_idx: Optional[torch.Tensor] = None,
 ) -> torch.Tensor:
     """
     Quantize the input tensor x using the QuantizationStrategy specified in args.
@@ -58,16 +59,9 @@ def quantize(
     :param zero_point: zero point tensor
     :param args: quantization args dictating how to quantize x
     :param dtype: optional dtype to cast the quantized output to
+    :param g_idx: optional mapping from column index to group index
     :return: fake quantized tensor
     """
-    # ensure all tensors are on the same device
-    # assumes that the target device is the input
-    # tensor's device
-    if x.device != scale.device:
-        scale = scale.to(x.device)
-    if x.device != zero_point.device:
-        zero_point = zero_point.to(x.device)
-
     return _process_quantization(
         x=x,
         scale=scale,
@@ -76,6 +70,7 @@ def quantize(
         dtype=dtype,
         do_quantize=True,
         do_dequantize=False,
+        g_idx=g_idx,
     )
 
 
@@ -86,6 +81,7 @@ def dequantize(
     zero_point: torch.Tensor = None,
     args: QuantizationArgs = None,
     dtype: Optional[torch.dtype] = None,
+    g_idx: Optional[torch.Tensor] = None,
 ) -> torch.Tensor:
     """
     Dequantize a quantized input tensor x_q based on the strategy specified in args. If
@@ -96,6 +92,7 @@ def dequantize(
     :param zero_point: zero point tensor
     :param args: quantization args used to quantize x_q
     :param dtype: optional dtype to cast the dequantized output to
+    :param g_idx: optional mapping from column index to group index
     :return: dequantized float tensor
     """
     if args is None:
@@ -126,6 +123,7 @@ def dequantize(
         do_quantize=False,
         do_dequantize=True,
         dtype=dtype,
+        g_idx=g_idx,
     )
 
 
@@ -135,6 +133,7 @@ def fake_quantize(
     scale: torch.Tensor,
     zero_point: torch.Tensor,
     args: QuantizationArgs,
+    g_idx: Optional[torch.Tensor] = None,
 ) -> torch.Tensor:
     """
     Fake quantize the input tensor x by quantizing then dequantizing with
@@ -147,6 +146,7 @@ def fake_quantize(
     :param scale: scale tensor
     :param zero_point: zero point tensor
     :param args: quantization args dictating how to quantize x
+    :param g_idx: optional mapping from column index to group index
     :return: fake quantized tensor
     """
     return _process_quantization(
@@ -156,6 +156,7 @@ def fake_quantize(
         args=args,
         do_quantize=True,
         do_dequantize=True,
+        g_idx=g_idx,
     )
 
 
@@ -164,21 +165,19 @@ def _process_quantization(
     x: torch.Tensor,
     scale: torch.Tensor,
     zero_point: torch.Tensor,
+    g_idx: Optional[torch.Tensor],
     args: QuantizationArgs,
     dtype: Optional[torch.dtype] = None,
     do_quantize: bool = True,
     do_dequantize: bool = True,
 ) -> torch.Tensor:
-
     q_min, q_max = calculate_range(args, x.device)
     group_size = args.group_size
 
     if args.strategy == QuantizationStrategy.GROUP:
         output_dtype = dtype if dtype is not None else x.dtype
         output = torch.zeros_like(x).to(output_dtype)
-
-        # TODO: vectorize the for loop
-        # TODO: fix genetric assumption about the tensor size for computing group
+        columns = output.shape[1]
 
         # TODO: make validation step for inputs
 
@@ -187,23 +186,38 @@ def _process_quantization(
             scale = scale.unsqueeze(1)
             zero_point = zero_point.unsqueeze(1) if zero_point is not None else None
 
-        columns = x.shape[1]
         if columns >= group_size:
             if columns % group_size != 0:
                 raise ValueError(
-                    "tesnor column shape must be divisble "
+                    "tensor column shape must be divisble "
                     f"by the given group_size {group_size}"
                 )
-        for i in range(ceil(columns / group_size)):
-            # scale.shape should be [nchan, ndim]
-            # sc.shape should be [nchan, 1] after unsqueeze
-            sc = scale[:, i].view(-1, 1)
-            zp = zero_point[:, i].view(-1, 1) if zero_point is not None else None
 
-            idx = i * group_size
+        # support column-order (default) quantization as well as other orderings
+        # such as activation ordering. Below checks if g_idx has been initialized
+        is_column_order = g_idx is None or -1 in g_idx
+        if is_column_order:
+            num_groups = int(ceil(columns / group_size))
+            group_sizes = torch.full((num_groups,), group_size, dtype=torch.int)
+
+        else:
+            group_indices, group_sizes = torch.unique(g_idx, return_counts=True)
+            group_sizes = group_sizes[torch.argsort(group_indices)]
+
+            perm = torch.argsort(g_idx)
+            x = safe_permute(x, perm, dim=1)
+
+        # TODO: experiment with vectorizing for loop for performance
+        end = 0
+        for index, group_count in enumerate(group_sizes):
+            sc = scale[:, index].view(-1, 1)
+            zp = zero_point[:, index].view(-1, 1) if zero_point is not None else None
+
+            start = end
+            end = start + group_count
             if do_quantize:
-                output[:, idx : (idx + group_size)] = _quantize(
-                    x[:, idx : (idx + group_size)],
+                output[:, start:end] = _quantize(
+                    x[:, start:end],
                     sc,
                     zp,
                     q_min,
@@ -211,13 +225,13 @@
                     args,
                     dtype=dtype,
                 )
+
             if do_dequantize:
-                input = (
-                    output[:, idx : (idx + group_size)]
-                    if do_quantize
-                    else x[:, idx : (idx + group_size)]
-                )
-                output[:, idx : (idx + group_size)] = _dequantize(input, sc, zp)
+                input = output[:, start:end] if do_quantize else x[:, start:end]
+                output[:, start:end] = _dequantize(input, sc, zp)
+
+        if not is_column_order:
+            output = safe_permute(output, torch.argsort(perm), dim=1)
 
     else:  # covers channel, token and tensor strategies
         if do_quantize:
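
Note (not part of the diff): the rewritten loop drops the fixed-stride indexing (idx = i * group_size). When g_idx has been populated (no -1 sentinel left from initialization), per-group sizes come from torch.unique counts, the columns are permuted into group order with safe_permute, processed group by group, and permuted back at the end. A minimal end-to-end sketch; shapes and values are illustrative, and scale/zero_point would normally come from an observer:

    import torch
    from compressed_tensors.quantization import QuantizationArgs
    from compressed_tensors.quantization.lifecycle.forward import fake_quantize

    weight = torch.randn(64, 256)
    args = QuantizationArgs(num_bits=4, strategy="group", group_size=128)

    # (rows, num_groups) quantization parameters for two groups of 128 columns
    scale = torch.rand(64, 2) * 0.1
    zero_point = torch.zeros(64, 2, dtype=torch.int8)

    # column-to-group mapping, e.g. produced by GPTQ activation ordering;
    # columns belonging to the same group need not be contiguous
    g_idx = torch.randperm(256) // 128

    w_fq = fake_quantize(weight, scale, zero_point, args, g_idx=g_idx)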
@@ -252,6 +266,7 @@ def wrap_module_forward_quantized(module: Module, scheme: QuantizationScheme):
             return forward_func_orig.__get__(module, module.__class__)(*args, **kwargs)
 
         input_ = args[0]
+        compressed = module.quantization_status == QuantizationStatus.COMPRESSED
 
         if scheme.input_activations is not None:
             # calibrate and (fake) quantize input activations when applicable
@@ -259,7 +274,7 @@ def wrap_module_forward_quantized(module: Module, scheme: QuantizationScheme):
                 module, input_, "input", scheme.input_activations
             )
 
-        if scheme.weights is not None:
+        if scheme.weights is not None and not compressed:
             # calibrate and (fake) quantize weights when applicable
             unquantized_weight = self.weight.data.clone()
             self.weight.data = maybe_calibrate_or_quantize(
@@ -278,7 +293,7 @@ def wrap_module_forward_quantized(module: Module, scheme: QuantizationScheme):
             )
 
         # restore back to unquantized_value
-        if scheme.weights is not None:
+        if scheme.weights is not None and not compressed:
             self.weight.data = unquantized_weight
 
         return output
@@ -292,11 +307,16 @@ def wrap_module_forward_quantized(module: Module, scheme: QuantizationScheme):
 def maybe_calibrate_or_quantize(
     module: Module, value: torch.Tensor, base_name: str, args: "QuantizationArgs"
 ) -> torch.Tensor:
-    # only run quantized for the included stages
-    if module.quantization_status not in {
-        QuantizationStatus.CALIBRATION,
-        QuantizationStatus.FROZEN,
-    }:
+    # don't run quantization if we haven't entered calibration mode
+    if module.quantization_status == QuantizationStatus.INITIALIZED:
+        return value
+
+    # in compressed mode, the weight is already compressed and quantized so we don't
+    # need to run fake quantization
+    if (
+        module.quantization_status == QuantizationStatus.COMPRESSED
+        and base_name == "weight"
+    ):
         return value
 
     if value.numel() == 0:
@@ -304,14 +324,16 @@ def maybe_calibrate_or_quantize(
         # skip quantization
        return value
 
+    g_idx = getattr(module, "weight_g_idx", None)
+
     if args.dynamic:
         # dynamic quantization - get scale and zero point directly from observer
         observer = getattr(module, f"{base_name}_observer")
-        scale, zero_point = observer(value)
+        scale, zero_point = observer(value, g_idx=g_idx)
     else:
         # static quantization - get previous scale and zero point from layer
         scale = getattr(module, f"{base_name}_scale")
-        zero_point = getattr(module, f"{base_name}_zero_point")
+        zero_point = getattr(module, f"{base_name}_zero_point", None)
 
     if (
         module.quantization_status == QuantizationStatus.CALIBRATION
@@ -320,13 +342,13 @@ def maybe_calibrate_or_quantize(
         # calibration mode - get new quant params from observer
         observer = getattr(module, f"{base_name}_observer")
 
-        updated_scale, updated_zero_point = observer(value)
+        updated_scale, updated_zero_point = observer(value, g_idx=g_idx)
 
         # update scale and zero point
         update_parameter_data(module, updated_scale, f"{base_name}_scale")
         update_parameter_data(module, updated_zero_point, f"{base_name}_zero_point")
 
-    return fake_quantize(value, scale, zero_point, args)
+    return fake_quantize(value, scale, zero_point, args, g_idx=g_idx)
 
 
 @torch.no_grad()
@@ -340,7 +362,9 @@ def _quantize(
     dtype: Optional[torch.dtype] = None,
 ) -> torch.Tensor:
 
-    scaled = x / scale + zero_point.to(x.dtype)
+    scaled = x / scale
+    if zero_point is not None:
+        scaled += zero_point.to(x.dtype)
     # clamp first because cast isn't guaranteed to be saturated (ie for fp8)
     clamped_value = torch.clamp(
         scaled,
@@ -361,11 +385,11 @@ def _dequantize(
     zero_point: torch.Tensor = None,
     dtype: Optional[torch.dtype] = None,
 ) -> torch.Tensor:
+    dequant_value = x_q.to(scale.dtype)
 
-    dequant_value = x_q
     if zero_point is not None:
         dequant_value = dequant_value - zero_point.to(scale.dtype)
-    dequant_value = dequant_value.to(scale.dtype) * scale
+    dequant_value = dequant_value * scale
 
     if dtype is not None:
         dequant_value = dequant_value.to(dtype)
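
With these changes _quantize tolerates a missing zero point and _dequantize casts to the scale dtype before subtracting it. A worked int4 round trip with illustrative numbers, following the clamp-before-round note in the context above (not part of the diff):

    import torch

    x = torch.tensor([0.50, -1.25, 2.00])
    scale = torch.tensor(0.25)
    zp = torch.tensor(3, dtype=torch.int8)

    # quantize: clamp(x / scale + zp, q_min, q_max), then round and cast
    q = torch.round(torch.clamp(x / scale + zp, -8, 7)).to(torch.int8)
    # q -> [5, -2, 7]   (2.0 / 0.25 + 3 = 11 clamps to the int4 max of 7)

    # dequantize: cast to the scale dtype, remove the zero point, rescale
    x_hat = (q.to(scale.dtype) - zp.to(scale.dtype)) * scale
    # x_hat -> [0.50, -1.25, 1.00]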

compressed_tensors/quantization/lifecycle/helpers.py
@@ -16,7 +16,9 @@
 Miscelaneous helpers for the quantization lifecycle
 """
 
+from typing import Optional
 
+import torch
 from torch.nn import Module
 
 
@@ -27,16 +29,41 @@ __all__ = [
 ]
 
 
-def update_layer_weight_quant_params(layer: Module):
-    weight = getattr(layer, "weight", None)
+def update_layer_weight_quant_params(
+    layer: Module,
+    weight: Optional[torch.Tensor] = None,
+    g_idx: Optional[torch.Tensor] = None,
+    reset_obs: bool = False,
+):
+    """
+    Update quantization parameters on layer
+
+    :param layer: input layer
+    :param weight: weight to update quant params with, defaults to layer weight
+    :param g_idx: optional mapping from column index to group index
+    :param reset_obs: reset the observer before calculating quant params,
+        defaults to False
+    """
+    attached_weight = getattr(layer, "weight", None)
+
+    if weight is None:
+        weight = attached_weight
     scale = getattr(layer, "weight_scale", None)
     zero_point = getattr(layer, "weight_zero_point", None)
+    if g_idx is None:
+        g_idx = getattr(layer, "weight_g_idx", None)
     observer = getattr(layer, "weight_observer", None)
 
     if weight is None or observer is None or scale is None or zero_point is None:
         # scale, zp, or observer not calibratable or weight not available
         return
 
+    if reset_obs:
+        observer.reset()
+
+    if attached_weight is not None:
+        weight = weight.to(attached_weight.dtype)
+
     updated_scale, updated_zero_point = observer(weight)
 
     # update scale and zero point
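
update_layer_weight_quant_params can now recompute quantization parameters from an externally supplied weight and reset the observer beforehand. A hedged usage sketch (not part of the diff; scheme and shapes are illustrative):

    import torch
    from compressed_tensors.quantization import QuantizationArgs, QuantizationScheme
    from compressed_tensors.quantization.lifecycle import (
        initialize_module_for_quantization,
    )
    from compressed_tensors.quantization.lifecycle.helpers import (
        update_layer_weight_quant_params,
    )

    layer = torch.nn.Linear(128, 64)
    scheme = QuantizationScheme(
        targets=["Linear"], weights=QuantizationArgs(num_bits=8, symmetric=True)
    )
    initialize_module_for_quantization(layer, scheme)  # registers scale/zp/observer

    # recompute weight quant params from a candidate weight, resetting the observer
    candidate = layer.weight.data * 0.9
    update_layer_weight_quant_params(layer, weight=candidate, reset_obs=True)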

compressed_tensors/quantization/lifecycle/initialize.py
@@ -17,8 +17,6 @@ import logging
 from typing import Optional
 
 import torch
-from accelerate.hooks import add_hook_to_module, remove_hook_from_module
-from accelerate.utils import PrefixedDataset
 from compressed_tensors.quantization.lifecycle.forward import (
     wrap_module_forward_quantized,
 )
@@ -43,6 +41,7 @@ _LOGGER = logging.getLogger(__name__)
 def initialize_module_for_quantization(
     module: Module,
     scheme: Optional[QuantizationScheme] = None,
+    force_zero_point: bool = True,
 ):
     """
     attaches appropriate scales, zero points, and observers to a layer
@@ -54,6 +53,8 @@ def initialize_module_for_quantization(
     :param scheme: scheme to use for quantization. if None is provided,
         will attempt to use scheme stored in the module under `quantization_scheme`,
         if not provided, the layer will be skipped
+    :param force_zero_point: whether to force initialization of a zero point for
+        symmetric quantization
     """
     scheme = scheme or getattr(module, "quantization_scheme", None)
     if scheme is None:
@@ -61,14 +62,18 @@ def initialize_module_for_quantization(
         return
 
     if scheme.input_activations is not None:
-        _initialize_scale_zero_point_observer(module, "input", scheme.input_activations)
+        _initialize_scale_zero_point_observer(
+            module, "input", scheme.input_activations, force_zero_point=force_zero_point
+        )
     if scheme.weights is not None:
         if hasattr(module, "weight"):
-            weight_shape = None
-            if isinstance(module, torch.nn.Linear):
-                weight_shape = module.weight.shape
+            weight_shape = module.weight.shape
             _initialize_scale_zero_point_observer(
-                module, "weight", scheme.weights, weight_shape=weight_shape
+                module,
+                "weight",
+                scheme.weights,
+                weight_shape=weight_shape,
+                force_zero_point=force_zero_point,
             )
         else:
             _LOGGER.warning(
@@ -78,7 +83,10 @@ def initialize_module_for_quantization(
             )
     if scheme.output_activations is not None:
         _initialize_scale_zero_point_observer(
-            module, "output", scheme.output_activations
+            module,
+            "output",
+            scheme.output_activations,
+            force_zero_point=force_zero_point,
         )
 
     module.quantization_scheme = scheme
@@ -86,6 +94,16 @@
 
     offloaded = False
     if is_module_offloaded(module):
+        try:
+            from accelerate.hooks import add_hook_to_module, remove_hook_from_module
+            from accelerate.utils import PrefixedDataset
+        except ModuleNotFoundError:
+            raise ModuleNotFoundError(
+                "Offloaded model detected. To use CPU offloading with "
+                "compressed-tensors the `accelerate` package must be installed, "
+                "run `pip install compressed-tensors[accelerate]`"
+            )
+
         offloaded = True
         hook = module._hf_hook
         prefix_dict = module._hf_hook.weights_map
@@ -116,6 +134,7 @@ def _initialize_scale_zero_point_observer(
     base_name: str,
     quantization_args: QuantizationArgs,
     weight_shape: Optional[torch.Size] = None,
+    force_zero_point: bool = True,
 ):
     # initialize observer module and attach as submodule
     observer = quantization_args.get_observer()
@@ -141,16 +160,31 @@ def _initialize_scale_zero_point_observer(
                 weight_shape[1] // quantization_args.group_size,
             )
 
-    # initializes empty scale and zero point parameters for the module
+    scale_dtype = module.weight.dtype
+    if scale_dtype not in [torch.float16, torch.bfloat16, torch.float32]:
+        scale_dtype = torch.float16
+
+    # initializes empty scale, zero point, and g_idx parameters for the module
     init_scale = Parameter(
-        torch.empty(expected_shape, dtype=module.weight.dtype, device=device),
+        torch.empty(expected_shape, dtype=scale_dtype, device=device),
         requires_grad=False,
     )
     module.register_parameter(f"{base_name}_scale", init_scale)
 
-    zp_dtype = quantization_args.pytorch_dtype()
-    init_zero_point = Parameter(
-        torch.empty(expected_shape, device=device, dtype=zp_dtype),
-        requires_grad=False,
-    )
-    module.register_parameter(f"{base_name}_zero_point", init_zero_point)
+    if force_zero_point or not quantization_args.symmetric:
+        zp_dtype = quantization_args.pytorch_dtype()
+        init_zero_point = Parameter(
+            torch.zeros(expected_shape, device=device, dtype=zp_dtype),
+            requires_grad=False,
+        )
+        module.register_parameter(f"{base_name}_zero_point", init_zero_point)
+
+    # initialize with empty for actorder, to be populated by GPTQ or state_dict
+    if quantization_args.actorder:
+        g_idx_shape = (weight_shape[1],)
+        g_idx_dtype = torch.int
+        init_g_idx = Parameter(
+            torch.full(g_idx_shape, -1, device=device, dtype=g_idx_dtype),
+            requires_grad=False,
+        )
+        module.register_parameter(f"{base_name}_g_idx", init_g_idx)
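
Initialization now keeps the scale in a floating-point dtype, registers a zero point only when forced or when quantization is asymmetric, and registers a weight_g_idx parameter filled with -1 when activation ordering is requested. A sketch of the resulting parameters (not part of the diff), assuming QuantizationArgs accepts the actorder flag added in quant_args.py:

    import torch
    from compressed_tensors.quantization import QuantizationArgs, QuantizationScheme
    from compressed_tensors.quantization.lifecycle import (
        initialize_module_for_quantization,
    )

    layer = torch.nn.Linear(256, 64)
    scheme = QuantizationScheme(
        targets=["Linear"],
        weights=QuantizationArgs(
            num_bits=4, strategy="group", group_size=128, actorder=True
        ),
    )
    initialize_module_for_quantization(layer, scheme, force_zero_point=False)

    print(hasattr(layer, "weight_scale"))       # True, floating-point dtype
    print(hasattr(layer, "weight_zero_point"))  # False: symmetric, not forced
    print(layer.weight_g_idx.shape)             # torch.Size([256]), all -1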

compressed_tensors/quantization/observers/__init__.py
@@ -19,3 +19,4 @@ from .helpers import *
 from .base import *
 from .memoryless import *
 from .min_max import *
+from .mse import *
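
The new MSE observer (compressed_tensors/quantization/observers/mse.py in the file list) is exported alongside the existing observers. Assuming it registers under the name "mse" in the same way "minmax" does, it would be selected through the observer field of QuantizationArgs:

    from compressed_tensors.quantization import QuantizationArgs

    # assumption: "mse" is the registry name used by the new observer
    args = QuantizationArgs(num_bits=8, symmetric=True, observer="mse")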

compressed_tensors/quantization/observers/base.py
@@ -13,6 +13,7 @@
 # limitations under the License.
 
 import logging
+from math import ceil
 from typing import Any, Iterable, Optional, Tuple, Union
 
 import torch
@@ -21,6 +22,7 @@ from compressed_tensors.quantization.quant_args import (
     QuantizationStrategy,
 )
 from compressed_tensors.registry.registry import RegistryMixin
+from compressed_tensors.utils import safe_permute
 from torch import FloatTensor, IntTensor, Tensor
 from torch.nn import Module
 
@@ -46,15 +48,18 @@ class Observer(Module, RegistryMixin):
         self._num_observed_tokens = None
 
     @torch.no_grad()
-    def forward(self, observed: Tensor) -> Tuple[FloatTensor, IntTensor]:
+    def forward(
+        self, observed: Tensor, g_idx: Optional[Tensor] = None
+    ) -> Tuple[FloatTensor, IntTensor]:
         """
         maps directly to get_qparams
-        :param observed: optional observed tensor to calculate quantization parameters
-            from
+        :param observed: optional observed tensor from which to calculate
+            quantization parameters
+        :param g_idx: optional mapping from column index to group index
         :return: tuple of scale and zero point based on last observed value
         """
         self.record_observed_tokens(observed)
-        return self.get_qparams(observed=observed)
+        return self.get_qparams(observed=observed, g_idx=g_idx)
 
     def calculate_qparams(
         self,
@@ -77,7 +82,9 @@ class Observer(Module, RegistryMixin):
         ...
 
     def get_qparams(
-        self, observed: Optional[Tensor] = None
+        self,
+        observed: Optional[Tensor] = None,
+        g_idx: Optional[Tensor] = None,
     ) -> Tuple[FloatTensor, IntTensor]:
         """
         Convenience function to wrap overwritten calculate_qparams
@@ -86,6 +93,7 @@ class Observer(Module, RegistryMixin):
 
         :param observed: optional observed tensor to calculate quantization parameters
             from
+        :param g_idx: optional mapping from column index to group index
         :return: tuple of scale and zero point based on last observed value
         """
         if observed is not None:
@@ -97,20 +105,42 @@ class Observer(Module, RegistryMixin):
                 self._scale, self._zero_point = self.calculate_qparams(observed)
 
             elif self.quantization_args.strategy == QuantizationStrategy.GROUP:
+                rows = observed.shape[0]
                 columns = observed.shape[1]
-                scales, zero_points = [], []
-                group_idxs = range(0, columns, self.quantization_args.group_size)
-                for group_id, group_idx in enumerate(group_idxs):
+                num_groups = int(ceil(columns / group_size))
+                self._scale = torch.empty(
+                    (rows, num_groups), dtype=observed.dtype, device=observed.device
+                )
+                zp_dtype = self.quantization_args.pytorch_dtype()
+                self._zero_point = torch.empty(
+                    (rows, num_groups), dtype=zp_dtype, device=observed.device
+                )
+
+                # support column-order (default) quantization as well as other orderings
+                # such as activation ordering. Below checks if g_idx has initialized
+                is_column_order = g_idx is None or -1 in g_idx
+                if is_column_order:
+                    group_sizes = torch.full((num_groups,), group_size, dtype=torch.int)
+                else:
+                    group_indices, group_sizes = torch.unique(g_idx, return_counts=True)
+                    group_sizes = group_sizes[torch.argsort(group_indices)]
+
+                    perm = torch.argsort(g_idx)
+                    observed = safe_permute(observed, perm, dim=1)
+
+                # TODO: experiment with vectorizing for loop for performance
+                end = 0
+                for group_index, group_count in enumerate(group_sizes):
+                    start = end
+                    end = start + group_count
                     scale, zero_point = self.get_qparams_along_dim(
-                        observed[:, group_idx : (group_idx + group_size)],
+                        observed[:, start:end],
                         0,
-                        tensor_id=group_id,
+                        tensor_id=group_index,
                     )
-                    scales.append(scale)
-                    zero_points.append(zero_point)
 
-                self._scale = torch.cat(scales, dim=1, out=self._scale)
-                self._zero_point = torch.cat(zero_points, dim=1, out=self._zero_point)
+                    self._scale[:, group_index] = scale.squeeze(1)
+                    self._zero_point[:, group_index] = zero_point.squeeze(1)
 
             elif self.quantization_args.strategy == QuantizationStrategy.CHANNEL:
                 # assume observed is transposed, because its the output, hence use dim 0
@@ -132,6 +162,8 @@ class Observer(Module, RegistryMixin):
         dim: Union[int, Iterable[int]],
         tensor_id: Optional[Any] = None,
     ):
+        if isinstance(dim, int):
+            dim = [dim]
         dim = set(dim)
 
         reduce_dims = tuple(idx for idx in range(observed.ndim) if idx not in dim)
@@ -171,3 +203,11 @@ class Observer(Module, RegistryMixin):
         # observed_tokens (batch_size * sequence_length)
         observed_tokens, _ = batch_tensor.shape
         self._num_observed_tokens += observed_tokens
+
+    def reset(self):
+        """
+        Reset the state of the observer
+        """
+        self._num_observed_tokens = None
+        self._scale = None
+        self._zero_point = None

compressed_tensors/quantization/observers/min_max.py
@@ -94,3 +94,11 @@ class MovingAverageMinMaxObserver(Observer):
         return self.calculate_qparams(
             observed, reduce_dims=reduce_dims, tensor_id=tensor_id
         )
+
+    def reset(self):
+        """
+        Reset the state of the observer, including min and maximum values
+        """
+        super().reset()
+        self.min_val = {}
+        self.max_val = {}
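
Both observers now expose a reset() hook so accumulated state (scale, zero point, token counts, and the min/max tracking of the min-max observer) can be cleared between calibration passes. A minimal sketch for a layer initialized as in the sketches above:

    scale, zero_point = layer.weight_observer(layer.weight)  # accumulate state
    layer.weight_observer.reset()  # start the next calibration pass fresh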