compressed-tensors 0.3.2__py3-none-any.whl → 0.4.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (40)
  1. compressed_tensors/base.py +2 -1
  2. compressed_tensors/compressors/__init__.py +5 -1
  3. compressed_tensors/compressors/base.py +11 -54
  4. compressed_tensors/compressors/dense.py +4 -4
  5. compressed_tensors/compressors/helpers.py +12 -12
  6. compressed_tensors/compressors/int_quantized.py +126 -0
  7. compressed_tensors/compressors/marlin_24.py +250 -0
  8. compressed_tensors/compressors/model_compressor.py +315 -0
  9. compressed_tensors/compressors/pack_quantized.py +212 -0
  10. compressed_tensors/compressors/sparse_bitmask.py +4 -4
  11. compressed_tensors/compressors/utils/__init__.py +19 -0
  12. compressed_tensors/compressors/utils/helpers.py +43 -0
  13. compressed_tensors/compressors/utils/permutations_24.py +65 -0
  14. compressed_tensors/compressors/utils/semi_structured_conversions.py +341 -0
  15. compressed_tensors/config/base.py +7 -4
  16. compressed_tensors/config/dense.py +4 -4
  17. compressed_tensors/config/sparse_bitmask.py +3 -3
  18. compressed_tensors/quantization/lifecycle/__init__.py +1 -0
  19. compressed_tensors/quantization/lifecycle/apply.py +75 -19
  20. compressed_tensors/quantization/lifecycle/compressed.py +69 -0
  21. compressed_tensors/quantization/lifecycle/forward.py +208 -22
  22. compressed_tensors/quantization/lifecycle/frozen.py +4 -0
  23. compressed_tensors/quantization/lifecycle/initialize.py +33 -5
  24. compressed_tensors/quantization/observers/base.py +70 -5
  25. compressed_tensors/quantization/observers/helpers.py +6 -1
  26. compressed_tensors/quantization/observers/memoryless.py +17 -9
  27. compressed_tensors/quantization/observers/min_max.py +44 -13
  28. compressed_tensors/quantization/quant_args.py +33 -4
  29. compressed_tensors/quantization/quant_config.py +69 -21
  30. compressed_tensors/quantization/quant_scheme.py +81 -1
  31. compressed_tensors/quantization/utils/helpers.py +77 -8
  32. compressed_tensors/utils/helpers.py +26 -122
  33. compressed_tensors/utils/safetensors_load.py +3 -2
  34. compressed_tensors/version.py +53 -0
  35. {compressed_tensors-0.3.2.dist-info → compressed_tensors-0.4.0.dist-info}/METADATA +46 -9
  36. compressed_tensors-0.4.0.dist-info/RECORD +48 -0
  37. compressed_tensors-0.3.2.dist-info/RECORD +0 -38
  38. {compressed_tensors-0.3.2.dist-info → compressed_tensors-0.4.0.dist-info}/LICENSE +0 -0
  39. {compressed_tensors-0.3.2.dist-info → compressed_tensors-0.4.0.dist-info}/WHEEL +0 -0
  40. {compressed_tensors-0.3.2.dist-info → compressed_tensors-0.4.0.dist-info}/top_level.txt +0 -0
compressed_tensors/quantization/lifecycle/forward.py
@@ -13,15 +13,26 @@
 # limitations under the License.
 
 from functools import wraps
+from math import ceil
+from typing import Optional
 
 import torch
-from compressed_tensors.quantization.quant_args import QuantizationArgs
+from compressed_tensors.quantization.quant_args import (
+    QuantizationArgs,
+    QuantizationStrategy,
+)
 from compressed_tensors.quantization.quant_config import QuantizationStatus
 from compressed_tensors.quantization.quant_scheme import QuantizationScheme
 from torch.nn import Module
 
 
-__all__ = ["wrap_module_forward_quantized"]
+__all__ = [
+    "quantize",
+    "dequantize",
+    "fake_quantize",
+    "wrap_module_forward_quantized",
+    "maybe_calibrate_or_quantize",
+]
 
 
 @torch.no_grad()
@@ -29,15 +40,39 @@ def quantize(
     x: torch.Tensor,
     scale: torch.Tensor,
     zero_point: torch.Tensor,
-    q_min: torch.Tensor,
-    q_max: torch.Tensor,
+    args: QuantizationArgs,
+    dtype: Optional[torch.dtype] = None,
 ) -> torch.Tensor:
-    return torch.clamp(
-        torch.round(
-            x / scale + zero_point,
-        ),
-        q_min,
-        q_max,
+    """
+    Quantize the input tensor x using the QuantizationStrategy specified in args.
+    Quantization can be done per tensor, channel, token or group. For group
+    quantization, the column size must be divisible by group_size. The input scale
+    and zero_points are reshaped to support vectorization (assumes 1 is the
+    channel dimension)
+
+    :param x: Input tensor
+    :param scale: scale tensor
+    :param zero_point: zero point tensor
+    :param args: quantization args dictating how to quantize x
+    :param dtype: optional dtype to cast the quantized output to
+    :return: quantized tensor
+    """
+    # ensure all tensors are on the same device
+    # assumes that the target device is the input
+    # tensor's device
+    if x.device != scale.device:
+        scale = scale.to(x.device)
+    if x.device != zero_point.device:
+        zero_point = zero_point.to(x.device)
+
+    return _process_quantization(
+        x=x,
+        scale=scale,
+        zero_point=zero_point,
+        args=args,
+        dtype=dtype,
+        do_quantize=True,
+        do_dequantize=False,
     )
 
 
@@ -46,8 +81,42 @@ def dequantize(
     x_q: torch.Tensor,
     scale: torch.Tensor,
     zero_point: torch.Tensor,
+    args: QuantizationArgs = None,
 ) -> torch.Tensor:
-    return (x_q - zero_point) * scale
+    """
+    Dequantize a quantized input tensor x_q based on the strategy specified in args.
+    If args is not provided, the strategy will be inferred.
+
+    :param x_q: quantized input tensor
+    :param scale: scale tensor
+    :param zero_point: zero point tensor
+    :param args: quantization args used to quantize x_q
+    :return: dequantized float tensor
+    """
+    if args is None:
+        if scale.ndim == 0 or scale.ndim == 1:
+            args = QuantizationArgs(strategy=QuantizationStrategy.TENSOR)
+        elif scale.ndim == 2:
+            if scale.shape[1] == 1:
+                args = QuantizationArgs(strategy=QuantizationStrategy.CHANNEL)
+            else:
+                group_size = int(x_q.shape[1] / scale.shape[1])
+                args = QuantizationArgs(
+                    strategy=QuantizationStrategy.GROUP, group_size=group_size
+                )
+        else:
+            raise ValueError(
+                f"Could not infer a quantization strategy from scale with {scale.ndim} "
+                "dimensions. Expected 0-2 dimensions."
+            )
+    return _process_quantization(
+        x=x_q,
+        scale=scale,
+        zero_point=zero_point,
+        args=args,
+        do_quantize=False,
+        do_dequantize=True,
+    )
 
 
 @torch.no_grad()
@@ -56,19 +125,106 @@ def fake_quantize(
     scale: torch.Tensor,
     zero_point: torch.Tensor,
     args: QuantizationArgs,
+) -> torch.Tensor:
+    """
+    Fake quantize the input tensor x by quantizing then dequantizing with
+    the QuantizationStrategy specified in args. Quantization can be done per tensor,
+    channel, token or group. For group quantization, the column size must be
+    divisible by group_size. The input scale and zero_points are reshaped to support
+    vectorization (assumes 1 is the channel dimension)
+
+    :param x: Input tensor
+    :param scale: scale tensor
+    :param zero_point: zero point tensor
+    :param args: quantization args dictating how to quantize x
+    :return: fake quantized tensor
+    """
+    return _process_quantization(
+        x=x,
+        scale=scale,
+        zero_point=zero_point,
+        args=args,
+        do_quantize=True,
+        do_dequantize=True,
+    )
+
+
+@torch.no_grad()
+def _process_quantization(
+    x: torch.Tensor,
+    scale: torch.Tensor,
+    zero_point: torch.Tensor,
+    args: QuantizationArgs,
+    dtype: Optional[torch.dtype] = None,
+    do_quantize: bool = True,
+    do_dequantize: bool = True,
 ) -> torch.Tensor:
     bit_range = 2**args.num_bits
-    max_q = torch.tensor(bit_range / 2 - 1, device=x.device)
-    min_q = torch.tensor(-bit_range / 2, device=x.device)
-    Q = torch.zeros_like(x)
-    Q = quantize(x, scale, zero_point, min_q, max_q)
-    return dequantize(Q, scale, zero_point)
+    q_max = torch.tensor(bit_range / 2 - 1, device=x.device)
+    q_min = torch.tensor(-bit_range / 2, device=x.device)
+    group_size = args.group_size
+
+    if args.strategy == QuantizationStrategy.GROUP:
+
+        if do_dequantize and not do_quantize:
+            # if dequantizing a quantized type infer the output type from the scale
+            output = torch.zeros_like(x, dtype=scale.dtype)
+        else:
+            output_dtype = dtype if dtype is not None else x.dtype
+            output = torch.zeros_like(x, dtype=output_dtype)
+
+        # TODO: vectorize the for loop
+        # TODO: fix generic assumption about the tensor size for computing group
+
+        # TODO: make validation step for inputs
+
+        while scale.ndim < 2:
+            # pad scale and zero point dims for slicing
+            scale = scale.unsqueeze(1)
+            zero_point = zero_point.unsqueeze(1)
+
+        columns = x.shape[1]
+        if columns >= group_size:
+            if columns % group_size != 0:
+                raise ValueError(
+                    "tensor column shape must be divisible "
+                    f"by the given group_size {group_size}"
+                )
+        for i in range(ceil(columns / group_size)):
+            # scale.shape should be [nchan, ngroups]
+            # sc.shape should be [nchan, 1] after the view
+            sc = scale[:, i].view(-1, 1)
+            zp = zero_point[:, i].view(-1, 1)
+
+            idx = i * group_size
+            if do_quantize:
+                output[:, idx : (idx + group_size)] = _quantize(
+                    x[:, idx : (idx + group_size)], sc, zp, q_min, q_max, dtype=dtype
+                )
+            if do_dequantize:
+                input = (
+                    output[:, idx : (idx + group_size)]
+                    if do_quantize
+                    else x[:, idx : (idx + group_size)]
+                )
+                output[:, idx : (idx + group_size)] = _dequantize(input, sc, zp)
+
+    else:  # covers channel, token and tensor strategies
+        if do_quantize:
+            output = _quantize(x, scale, zero_point, q_min, q_max, dtype=dtype)
+        if do_dequantize:
+            output = _dequantize(output if do_quantize else x, scale, zero_point)
+
+    return output
 
 
 def wrap_module_forward_quantized(module: Module, scheme: QuantizationScheme):
     # expects a module already initialized and injected with the parameters in
     # initialize_module_for_quantization
-    forward_func_orig = module.forward.__func__
+    if hasattr(module.forward, "__func__"):
+        forward_func_orig = module.forward.__func__
+    else:
+        forward_func_orig = module.forward.func
 
     @wraps(forward_func_orig)  # ensures docstring, names, etc are propagated
     def wrapped_forward(self, *args, **kwargs):
@@ -76,14 +232,14 @@ def wrap_module_forward_quantized(module: Module, scheme: QuantizationScheme):
 
         if scheme.input_activations is not None:
             # calibrate and (fake) quantize input activations when applicable
-            input_ = _maybe_calibrate_or_quantize(
+            input_ = maybe_calibrate_or_quantize(
                 module, input_, "input", scheme.input_activations
             )
 
         if scheme.weights is not None:
             # calibrate and (fake) quantize weights when applicable
             unquantized_weight = self.weight.data.clone()
-            self.weight.data = _maybe_calibrate_or_quantize(
+            self.weight.data = maybe_calibrate_or_quantize(
                 module, self.weight, "weight", scheme.weights
             )
 
@@ -94,7 +250,7 @@ def wrap_module_forward_quantized(module: Module, scheme: QuantizationScheme):
 
         if scheme.output_activations is not None:
             # calibrate and (fake) quantize output activations when applicable
-            output = _maybe_calibrate_or_quantize(
+            output = maybe_calibrate_or_quantize(
                 module, output, "output", scheme.output_activations
             )
 
@@ -110,7 +266,7 @@ def wrap_module_forward_quantized(module: Module, scheme: QuantizationScheme):
     setattr(module, "forward", bound_wrapped_forward)
 
 
-def _maybe_calibrate_or_quantize(
+def maybe_calibrate_or_quantize(
     module: Module, value: torch.Tensor, base_name: str, args: "QuantizationArgs"
 ) -> torch.Tensor:
     # only run quantized for the included stages
@@ -132,11 +288,41 @@ def _maybe_calibrate_or_quantize(
     if module.quantization_status == QuantizationStatus.CALIBRATION:
         # calibration mode - get new quant params from observer
         observer = getattr(module, f"{base_name}_observer")
+
         updated_scale, updated_zero_point = observer(value)
 
         # update scale and zero point
         device = next(module.parameters()).device
         scale.data = updated_scale.to(device)
         zero_point.data = updated_zero_point.to(device)
-
     return fake_quantize(value, scale, zero_point, args)
+
+
+@torch.no_grad()
+def _quantize(
+    x: torch.Tensor,
+    scale: torch.Tensor,
+    zero_point: torch.Tensor,
+    q_min: torch.Tensor,
+    q_max: torch.Tensor,
+    dtype: Optional[torch.dtype] = None,
+) -> torch.Tensor:
+    quantized_value = torch.clamp(
+        torch.round(x / scale + zero_point),
+        q_min,
+        q_max,
+    )
+
+    if dtype is not None:
+        quantized_value = quantized_value.to(dtype)
+
+    return quantized_value
+
+
+@torch.no_grad()
+def _dequantize(
+    x_q: torch.Tensor,
+    scale: torch.Tensor,
+    zero_point: torch.Tensor,
+) -> torch.Tensor:
+    return (x_q - zero_point) * scale
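Taken together, these forward.py changes replace the old `(x, scale, zero_point, q_min, q_max)` signature with a `QuantizationArgs`-driven API and export `quantize`, `dequantize`, and `fake_quantize` publicly. Below is a minimal usage sketch against the module paths shown in this diff; the `num_bits` keyword and the exact `QuantizationArgs` constructor fields are assumptions beyond what the diff itself shows.

```python
import torch

from compressed_tensors.quantization.lifecycle.forward import (
    dequantize,
    fake_quantize,
    quantize,
)
from compressed_tensors.quantization.quant_args import (
    QuantizationArgs,
    QuantizationStrategy,
)

weight = torch.randn(64, 256)
# hypothetical args; num_bits=8 and this keyword form are assumed, not taken from the diff
args = QuantizationArgs(num_bits=8, strategy=QuantizationStrategy.GROUP, group_size=128)

# per-group symmetric scale/zero point: one column per group -> shape (64, 256 // 128)
scale = weight.abs().reshape(64, 2, 128).amax(dim=-1) / 127.0
zero_point = torch.zeros_like(scale, dtype=torch.int8)

w_q = quantize(weight, scale, zero_point, args, dtype=torch.int8)  # true int8 values
w_dq = dequantize(w_q, scale, zero_point)  # strategy inferred from the 2-D scale shape
w_fq = fake_quantize(weight, scale, zero_point, args)  # quantize + dequantize in one call
```

Here `dequantize` is called without `args` to exercise the new inference path: a (64, 2) scale maps to the GROUP strategy with group_size 256 // 2 = 128.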
compressed_tensors/quantization/lifecycle/frozen.py
@@ -35,6 +35,10 @@ def freeze_module_quantization(module: Module):
         # no quantization scheme nothing to do
         return
 
+    if module.quantization_status == QuantizationStatus.FROZEN:
+        # nothing to do, already frozen
+        return
+
     # delete observers from module if not dynamic
     if scheme.input_activations and not scheme.input_activations.dynamic:
         delattr(module, "input_observer")
compressed_tensors/quantization/lifecycle/initialize.py
@@ -20,7 +20,10 @@ import torch
 from compressed_tensors.quantization.lifecycle.forward import (
     wrap_module_forward_quantized,
 )
-from compressed_tensors.quantization.quant_args import QuantizationArgs
+from compressed_tensors.quantization.quant_args import (
+    QuantizationArgs,
+    QuantizationStrategy,
+)
 from compressed_tensors.quantization.quant_config import QuantizationStatus
 from compressed_tensors.quantization.quant_scheme import QuantizationScheme
 from torch.nn import Module, Parameter
@@ -58,7 +61,12 @@
         _initialize_scale_zero_point_observer(module, "input", scheme.input_activations)
     if scheme.weights is not None:
         if hasattr(module, "weight"):
-            _initialize_scale_zero_point_observer(module, "weight", scheme.weights)
+            weight_shape = None
+            if isinstance(module, torch.nn.Linear):
+                weight_shape = module.weight.shape
+            _initialize_scale_zero_point_observer(
+                module, "weight", scheme.weights, weight_shape=weight_shape
+            )
         else:
             _LOGGER.warning(
                 f"module type {type(module)} targeted for weight quantization but "
@@ -78,7 +86,10 @@
 
 
 def _initialize_scale_zero_point_observer(
-    module: Module, base_name: str, quantization_args: QuantizationArgs
+    module: Module,
+    base_name: str,
+    quantization_args: QuantizationArgs,
+    weight_shape: Optional[torch.Size] = None,
 ):
     # initialize observer module and attach as submodule
     observer = quantization_args.get_observer()
@@ -89,11 +100,28 @@
 
     device = next(module.parameters()).device
 
+    # infer expected scale/zero point shape
+    expected_shape = 1  # per tensor
+
+    if base_name == "weight" and weight_shape is not None:
+        if quantization_args.strategy == QuantizationStrategy.CHANNEL:
+            # (output_channels, 1)
+            expected_shape = (weight_shape[0], 1)
+        elif quantization_args.strategy == QuantizationStrategy.GROUP:
+            expected_shape = (
+                weight_shape[0],
+                weight_shape[1] // quantization_args.group_size,
+            )
+
     # initializes empty scale and zero point parameters for the module
-    init_scale = Parameter(torch.empty(0, device=device), requires_grad=False)
+    init_scale = Parameter(
+        torch.empty(expected_shape, dtype=module.weight.dtype, device=device),
+        requires_grad=False,
+    )
     module.register_parameter(f"{base_name}_scale", init_scale)
 
     init_zero_point = Parameter(
-        torch.empty(0, device=device, dtype=int), requires_grad=False
+        torch.empty(expected_shape, device=device, dtype=int),
+        requires_grad=False,
     )
     module.register_parameter(f"{base_name}_zero_point", init_zero_point)
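Instead of registering zero-element placeholders, initialize.py now pre-sizes the scale and zero point to match the chosen strategy. A small sketch of the shape arithmetic for a hypothetical Linear weight of shape (out_features, in_features), mirroring the branches above:

```python
import torch

weight_shape = torch.Size((64, 256))  # (output_channels, input_channels), example values
group_size = 128

expected_shape_tensor = 1                       # per-tensor: a single scale / zero point
expected_shape_channel = (weight_shape[0], 1)   # CHANNEL: (64, 1), one pair per output channel
expected_shape_group = (
    weight_shape[0],
    weight_shape[1] // group_size,              # GROUP: (64, 2), one pair per column group
)

scale = torch.empty(expected_shape_group, dtype=torch.float32)
zero_point = torch.empty(expected_shape_group, dtype=int)
print(scale.shape, zero_point.shape)  # torch.Size([64, 2]) torch.Size([64, 2])
```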
compressed_tensors/quantization/observers/base.py
@@ -12,9 +12,13 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
-from typing import Optional, Tuple
+from typing import Any, Iterable, Optional, Tuple, Union
 
-from compressed_tensors.quantization.quant_args import QuantizationArgs
+import torch
+from compressed_tensors.quantization.quant_args import (
+    QuantizationArgs,
+    QuantizationStrategy,
+)
 from compressed_tensors.registry.registry import RegistryMixin
 from torch import FloatTensor, IntTensor, Tensor
 from torch.nn import Module
@@ -36,6 +40,7 @@ class Observer(Module, RegistryMixin):
         self._scale = None
         self._zero_point = None
 
+    @torch.no_grad()
     def forward(self, observed: Tensor) -> Tuple[FloatTensor, IntTensor]:
         """
         maps directly to get_qparams
@@ -45,13 +50,26 @@ class Observer(Module, RegistryMixin):
         """
         return self.get_qparams(observed=observed)
 
-    def calculate_qparams(self, observed: Tensor) -> Tuple[FloatTensor, IntTensor]:
+    def calculate_qparams(
+        self,
+        observed: Tensor,
+        reduce_dims: Optional[Tuple[int]] = None,
+    ) -> Tuple[FloatTensor, IntTensor]:
         """
         :param observed: observed tensor to calculate quantization parameters for
+        :param reduce_dims: optional tuple of dimensions to reduce along,
+            returned scale and zero point will be shaped (1,) along the
+            reduced dimensions
         :return: tuple of scale and zero point derived from the observed tensor
         """
         raise NotImplementedError(f"{self.__class__} must implement calculate_qparams")
 
+    def post_calculate_qparams(self) -> None:
+        """
+        Run any logic specific to this observer after calculate_qparams has run
+        """
+        ...
+
     def get_qparams(
         self, observed: Optional[Tensor] = None
     ) -> Tuple[FloatTensor, IntTensor]:
@@ -59,11 +77,58 @@ class Observer(Module, RegistryMixin):
         Convenience function to wrap overwritten calculate_qparams
         adds support to make observed tensor optional and support for tracking latest
         calculated scale and zero point
+
         :param observed: optional observed tensor to calculate quantization parameters
             from
         :return: tuple of scale and zero point based on last observed value
         """
         if observed is not None:
-            # re-calcualte scale and zero point, update the stored value
-            self._scale, self._zero_point = self.calculate_qparams(observed)
+            group_size = self.quantization_args.group_size
+
+            if self.quantization_args.strategy == QuantizationStrategy.TENSOR:
+
+                # re-calculate scale and zero point, update the stored value
+                self._scale, self._zero_point = self.calculate_qparams(observed)
+
+            elif self.quantization_args.strategy == QuantizationStrategy.GROUP:
+                columns = observed.shape[1]
+                scales, zero_points = [], []
+                group_idxs = range(0, columns, self.quantization_args.group_size)
+                for group_id, group_idx in enumerate(group_idxs):
+                    scale, zero_point = self.get_qparams_along_dim(
+                        observed[:, group_idx : (group_idx + group_size)],
+                        0,
+                        tensor_id=group_id,
+                    )
+                    scales.append(scale)
+                    zero_points.append(zero_point)
+
+                self._scale = torch.cat(scales, dim=1, out=self._scale)
+                self._zero_point = torch.cat(zero_points, dim=1, out=self._zero_point)
+
+            elif self.quantization_args.strategy == QuantizationStrategy.CHANNEL:
+                # assume observed is transposed, because it's the output, hence use dim 0
+                self._scale, self._zero_point = self.get_qparams_along_dim(observed, 0)
+
+            elif self.quantization_args.strategy == QuantizationStrategy.TOKEN:
+                # use dim 1, assume the observed.shape = [batch, token, hidden]
+                # should be batch, token
+                self._scale, self._zero_point = self.get_qparams_along_dim(
+                    observed,
+                    dim={0, 1},
+                )
+
         return self._scale, self._zero_point
+
+    def get_qparams_along_dim(
+        self,
+        observed,
+        dim: Union[int, Iterable[int]],
+        tensor_id: Optional[Any] = None,
+    ):
+        dim = set(dim)
+
+        reduce_dims = tuple(idx for idx in range(observed.ndim) if idx not in dim)
+        return self.calculate_qparams(
+            observed, reduce_dims=reduce_dims, tensor_id=tensor_id
+        )
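The GROUP branch above slices the observed tensor one column group at a time and concatenates the per-group results along dim 1, which is what makes the observer output line up with the `(output_channels, num_groups)` parameters initialized earlier. A standalone PyTorch sketch of that layout; the min-max math here is a simplified stand-in, not the library's `calculate_qparams`:

```python
import torch

observed = torch.randn(64, 256)  # e.g. a Linear weight
group_size = 128

scales, zero_points = [], []
for group_idx in range(0, observed.shape[1], group_size):
    group = observed[:, group_idx : group_idx + group_size]
    # reduce over columns only, keeping one value per output channel
    max_abs = group.abs().amax(dim=1, keepdim=True)                  # (64, 1)
    scales.append(max_abs / 127.0)                                   # simplified symmetric 8-bit scale
    zero_points.append(torch.zeros_like(max_abs, dtype=torch.int8))

scale = torch.cat(scales, dim=1)            # (64, num_groups) == (64, 2)
zero_point = torch.cat(zero_points, dim=1)  # matches the GROUP expected_shape above
```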
compressed_tensors/quantization/observers/helpers.py
@@ -35,19 +35,24 @@ def calculate_qparams(
     """
     min_vals = torch.min(min_vals, torch.zeros_like(min_vals))
     max_vals = torch.max(max_vals, torch.zeros_like(max_vals))
+    device = min_vals.device
 
     bit_range = 2**quantization_args.num_bits - 1
     bit_min = -(bit_range + 1) / 2
     bit_max = bit_min + bit_range
     if quantization_args.symmetric:
-        zero_points = torch.tensor(0).to(torch.int8)
         max_val_pos = torch.max(-min_vals, max_vals)
         scales = max_val_pos / (float(bit_range) / 2)
         scales = torch.clamp(scales, min=torch.finfo(torch.float32).eps)
+        zero_points = torch.zeros(scales.shape, device=device, dtype=torch.int8)
     else:
         scales = (max_vals - min_vals) / float(bit_range)
         scales = torch.clamp(scales, min=torch.finfo(torch.float32).eps)
         zero_points = bit_min - torch.round(min_vals / scales)
         zero_points = torch.clamp(zero_points, bit_min, bit_max).to(torch.int8)
 
+    if scales.ndim == 0:
+        scales = scales.reshape(1)
+        zero_points = zero_points.reshape(1)
+
     return scales, zero_points
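The symmetric branch now emits zero points shaped like the scales (instead of a single scalar) and both outputs are promoted to at least one dimension. A worked example of the same arithmetic in plain PyTorch for an assumed 8-bit case:

```python
import torch

num_bits = 8
bit_range = 2**num_bits - 1                                  # 255
min_vals = torch.tensor([-1.5, 0.0])
max_vals = torch.tensor([3.0, 0.5])

# symmetric: scale from the larger magnitude per element, zero points all zero
max_val_pos = torch.max(-min_vals, max_vals)                 # tensor([3.0, 0.5])
scales = max_val_pos / (float(bit_range) / 2)                # tensor([0.0235, 0.0039])
scales = torch.clamp(scales, min=torch.finfo(torch.float32).eps)
zero_points = torch.zeros(scales.shape, dtype=torch.int8)    # shaped like scales, not a 0-dim scalar

# 0-dim results from per-tensor observers get promoted to shape (1,)
scalar_scale = torch.tensor(0.02)
if scalar_scale.ndim == 0:
    scalar_scale = scalar_scale.reshape(1)
```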
compressed_tensors/quantization/observers/memoryless.py
@@ -12,7 +12,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
-from typing import Tuple
+from typing import Any, Optional, Tuple
 
 import torch
 from compressed_tensors.quantization.observers.base import Observer
@@ -30,19 +30,27 @@ class MemorylessObserver(Observer):
     zero point based on the latest observed value without tracking state
     """
 
-    def calculate_qparams(self, observed: Tensor) -> Tuple[FloatTensor, IntTensor]:
+    def calculate_qparams(
+        self,
+        observed: Tensor,
+        tensor_id: Optional[Any] = None,
+        reduce_dims: Optional[Tuple[int]] = None,
+    ) -> Tuple[FloatTensor, IntTensor]:
         """
-        Returns the min and max values of observed
+        Calculate quantization parameters from the observed tensor's min and max values
 
         :param observed: observed tensor to calculate quantization parameters for
+        :param tensor_id: optional id for tensor; not used for memoryless
+        :param reduce_dims: optional tuple of dimensions to reduce along,
+            returned scale and zero point will be shaped (1,) along the
+            reduced dimensions
         :return: tuple of scale and zero point derived from the observed tensor
         """
-        # TODO: Add support for full range of quantization Args, only supports 8bit
-        #       per tensor
-        min_val, max_val = torch.aminmax(observed)
 
-        # ensure zero is in the range
-        min_val = torch.min(min_val, torch.zeros_like(min_val))
-        max_val = torch.max(max_val, torch.zeros_like(max_val))
+        if not reduce_dims:
+            min_val, max_val = torch.aminmax(observed)
+        else:
+            min_val = torch.amin(observed, dim=reduce_dims, keepdims=True)
+            max_val = torch.amax(observed, dim=reduce_dims, keepdims=True)
 
         return calculate_qparams(min_val, max_val, self.quantization_args)
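With `reduce_dims`, the memoryless observer can keep one (min, max) pair per non-reduced dimension instead of a single global pair. A small plain-PyTorch sketch of the difference:

```python
import torch

observed = torch.randn(64, 128)

# previous behavior: one global (min, max) for the whole tensor -> 0-dim results
global_min, global_max = torch.aminmax(observed)

# with reduce_dims=(1,): reduce over columns only -> per-row (64, 1) results
reduce_dims = (1,)
row_min = torch.amin(observed, dim=reduce_dims, keepdim=True)
row_max = torch.amax(observed, dim=reduce_dims, keepdim=True)
print(global_min.shape, row_min.shape)  # torch.Size([]) torch.Size([64, 1])
```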
compressed_tensors/quantization/observers/min_max.py
@@ -12,7 +12,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
-from typing import Tuple
+from typing import Any, Optional, Tuple
 
 import torch
 from compressed_tensors.quantization.observers.base import Observer
@@ -36,30 +36,61 @@ class MovingAverageMinMaxObserver(Observer):
     ):
         super().__init__(quantization_args=quantization_args)
 
-        self.min_val = float("inf")
-        self.max_val = -float("inf")
+        self.min_val = {}
+        self.max_val = {}
         self.averaging_constant = averaging_constant
 
-    def calculate_qparams(self, observed: Tensor) -> Tuple[FloatTensor, IntTensor]:
+    def calculate_qparams(
+        self,
+        observed: Tensor,
+        reduce_dims: Optional[Tuple[int]] = None,
+        tensor_id: Optional[Any] = None,
+    ) -> Tuple[FloatTensor, IntTensor]:
         """
         Updates the observed min and max using a moving average smoothed by the
         averaging_constant
 
         :param observed: observed tensor to calculate quantization parameters for
+        :param reduce_dims: optional tuple of dimensions to reduce along,
+            returned scale and zero point will be shaped (1,) along the
+            reduced dimensions
+        :param tensor_id: optional id if different ranges of observed tensors are
+            passed, useful for sharding tensors by group_size
         :return: tuple of scale and zero point derived from the observed tensor
         """
+        tensor_id = tensor_id or "default"
 
-        min_val, max_val = torch.aminmax(observed)
+        if not reduce_dims:
+            min_val, max_val = torch.aminmax(observed)
+        else:
+            min_val = torch.amin(observed, dim=reduce_dims, keepdims=True)
+            max_val = torch.amax(observed, dim=reduce_dims, keepdims=True)
+
+        running_min_val = self.min_val.get(tensor_id, None)
+        running_max_val = self.max_val.get(tensor_id, None)
 
-        if self.min_val == float("inf") and self.max_val == float("-inf"):
-            self.min_val = min_val
-            self.max_val = max_val
+        if running_min_val is None or running_max_val is None:
+            updated_min_val = min_val
+            updated_max_val = max_val
         else:
-            self.min_val = self.min_val + self.averaging_constant * (
-                min_val - self.min_val
+            updated_min_val = running_min_val + self.averaging_constant * (
+                min_val - running_min_val
             )
-            self.max_val = self.max_val + self.averaging_constant * (
-                max_val - self.max_val
+            updated_max_val = running_max_val + self.averaging_constant * (
+                max_val - running_max_val
             )
 
-        return calculate_qparams(self.min_val, self.max_val, self.quantization_args)
+        self.min_val[tensor_id] = updated_min_val
+        self.max_val[tensor_id] = updated_max_val
+
+        return calculate_qparams(
+            updated_min_val, updated_max_val, self.quantization_args
+        )
+
+    def get_qparams_along_dim(
+        self, observed, dim: int, tensor_id: Optional[Any] = None
+    ):
+        reduce_dims = tuple(idx for idx in range(observed.ndim) if idx != dim)
+        return self.calculate_qparams(
+            observed, reduce_dims=reduce_dims, tensor_id=tensor_id
+        )