compressed-tensors 0.3.3__py3-none-any.whl → 0.4.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (40)
  1. compressed_tensors/base.py +2 -1
  2. compressed_tensors/compressors/__init__.py +5 -1
  3. compressed_tensors/compressors/base.py +11 -54
  4. compressed_tensors/compressors/dense.py +4 -4
  5. compressed_tensors/compressors/helpers.py +12 -12
  6. compressed_tensors/compressors/int_quantized.py +126 -0
  7. compressed_tensors/compressors/marlin_24.py +250 -0
  8. compressed_tensors/compressors/model_compressor.py +315 -0
  9. compressed_tensors/compressors/pack_quantized.py +212 -0
  10. compressed_tensors/compressors/sparse_bitmask.py +3 -3
  11. compressed_tensors/compressors/utils/__init__.py +19 -0
  12. compressed_tensors/compressors/utils/helpers.py +43 -0
  13. compressed_tensors/compressors/utils/permutations_24.py +65 -0
  14. compressed_tensors/compressors/utils/semi_structured_conversions.py +341 -0
  15. compressed_tensors/config/base.py +7 -4
  16. compressed_tensors/config/dense.py +4 -4
  17. compressed_tensors/config/sparse_bitmask.py +3 -3
  18. compressed_tensors/quantization/lifecycle/__init__.py +1 -0
  19. compressed_tensors/quantization/lifecycle/apply.py +62 -11
  20. compressed_tensors/quantization/lifecycle/compressed.py +69 -0
  21. compressed_tensors/quantization/lifecycle/forward.py +161 -54
  22. compressed_tensors/quantization/lifecycle/frozen.py +4 -0
  23. compressed_tensors/quantization/lifecycle/initialize.py +33 -5
  24. compressed_tensors/quantization/observers/base.py +31 -27
  25. compressed_tensors/quantization/observers/helpers.py +6 -1
  26. compressed_tensors/quantization/observers/memoryless.py +17 -9
  27. compressed_tensors/quantization/observers/min_max.py +44 -13
  28. compressed_tensors/quantization/quant_args.py +2 -2
  29. compressed_tensors/quantization/quant_config.py +69 -21
  30. compressed_tensors/quantization/quant_scheme.py +81 -1
  31. compressed_tensors/quantization/utils/helpers.py +76 -8
  32. compressed_tensors/utils/helpers.py +24 -6
  33. compressed_tensors/utils/safetensors_load.py +3 -2
  34. compressed_tensors/version.py +53 -0
  35. {compressed_tensors-0.3.3.dist-info → compressed_tensors-0.4.0.dist-info}/METADATA +46 -8
  36. compressed_tensors-0.4.0.dist-info/RECORD +48 -0
  37. compressed_tensors-0.3.3.dist-info/RECORD +0 -38
  38. {compressed_tensors-0.3.3.dist-info → compressed_tensors-0.4.0.dist-info}/LICENSE +0 -0
  39. {compressed_tensors-0.3.3.dist-info → compressed_tensors-0.4.0.dist-info}/WHEEL +0 -0
  40. {compressed_tensors-0.3.3.dist-info → compressed_tensors-0.4.0.dist-info}/top_level.txt +0 -0
compressed_tensors/quantization/lifecycle/forward.py
@@ -14,6 +14,7 @@
 
 from functools import wraps
 from math import ceil
+from typing import Optional
 
 import torch
 from compressed_tensors.quantization.quant_args import (
@@ -25,7 +26,13 @@ from compressed_tensors.quantization.quant_scheme import QuantizationScheme
 from torch.nn import Module
 
 
-__all__ = ["wrap_module_forward_quantized", "maybe_calibrate_or_quantize"]
+__all__ = [
+    "quantize",
+    "dequantize",
+    "fake_quantize",
+    "wrap_module_forward_quantized",
+    "maybe_calibrate_or_quantize",
+]
 
 
 @torch.no_grad()
@@ -33,14 +40,39 @@ def quantize(
     x: torch.Tensor,
     scale: torch.Tensor,
     zero_point: torch.Tensor,
-    q_min: torch.Tensor,
-    q_max: torch.Tensor,
+    args: QuantizationArgs,
+    dtype: Optional[torch.dtype] = None,
 ) -> torch.Tensor:
+    """
+    Quantize the input tensor x using the QuantizationStrategy specified in args.
+    Quantization can be done per tensor, channel, token or group. For group
+    quantization, the group_size must be divisible by the column size. The input scale
+    and zero_points are reshaped to support vectorization (Assumes 1 is the
+    channel dimension)
 
-    return torch.clamp(
-        torch.round(x / scale + zero_point),
-        q_min,
-        q_max,
+    :param x: Input tensor
+    :param scale: scale tensor
+    :param zero_point: zero point tensor
+    :param args: quantization args dictating how to quantize x
+    :param dtype: optional dtype to cast the quantized output to
+    :return: fake quantized tensor
+    """
+    # ensure all tensors are on the same device
+    # assumes that the target device is the input
+    # tensor's device
+    if x.device != scale.device:
+        scale = scale.to(x.device)
+    if x.device != zero_point.device:
+        zero_point = zero_point.to(x.device)
+
+    return _process_quantization(
+        x=x,
+        scale=scale,
+        zero_point=zero_point,
+        args=args,
+        dtype=dtype,
+        do_quantize=True,
+        do_dequantize=False,
     )
 
 
@@ -49,8 +81,42 @@ def dequantize(
     x_q: torch.Tensor,
     scale: torch.Tensor,
     zero_point: torch.Tensor,
+    args: QuantizationArgs = None,
 ) -> torch.Tensor:
-    return (x_q - zero_point) * scale
+    """
+    Dequantize a quantized input tensor x_q based on the strategy specified in args. If
+    args is not provided, the strategy will be inferred.
+
+    :param x: quantized input tensor
+    :param scale: scale tensor
+    :param zero_point: zero point tensor
+    :param args: quantization args used to quantize x_q
+    :return: dequantized float tensor
+    """
+    if args is None:
+        if scale.ndim == 0 or scale.ndim == 1:
+            args = QuantizationArgs(strategy=QuantizationStrategy.TENSOR)
+        elif scale.ndim == 2:
+            if scale.shape[1] == 1:
+                args = QuantizationArgs(strategy=QuantizationStrategy.CHANNEL)
+            else:
+                group_size = int(x_q.shape[1] / scale.shape[1])
+                args = QuantizationArgs(
+                    strategy=QuantizationStrategy.GROUP, group_size=group_size
+                )
+        else:
+            raise ValueError(
+                f"Could not infer a quantization strategy from scale with {scale.ndim} "
+                "dimmensions. Expected 0-2 dimmensions."
+            )
+    return _process_quantization(
+        x=x_q,
+        scale=scale,
+        zero_point=zero_point,
+        args=args,
+        do_quantize=False,
+        do_dequantize=True,
+    )
 
 
 @torch.no_grad()
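These hunks, from compressed_tensors/quantization/lifecycle/forward.py, rework the public API: quantize and dequantize now take a QuantizationArgs object instead of explicit q_min/q_max bounds, and dequantize can infer the strategy from the shape of the scale tensor when args is omitted. Below is a minimal standalone sketch of that inference rule in plain torch so it runs without the package; the helper name infer_strategy is ours, not part of the library.

```python
import torch


def infer_strategy(x_q: torch.Tensor, scale: torch.Tensor) -> str:
    # mirrors the branching in dequantize: 0- or 1-dim scale -> tensor,
    # an (n, 1) scale -> channel, any other 2-dim scale -> group
    if scale.ndim <= 1:
        return "tensor"
    if scale.ndim == 2:
        if scale.shape[1] == 1:
            return "channel"
        return f"group(group_size={x_q.shape[1] // scale.shape[1]})"
    raise ValueError(f"cannot infer a strategy from a {scale.ndim}-dim scale")


print(infer_strategy(torch.empty(16, 128), torch.tensor(0.1)))  # tensor
print(infer_strategy(torch.empty(16, 128), torch.ones(16, 1)))  # channel
print(infer_strategy(torch.empty(16, 128), torch.ones(16, 4)))  # group(group_size=32)
```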
@@ -61,30 +127,51 @@ def fake_quantize(
     args: QuantizationArgs,
 ) -> torch.Tensor:
     """
-    Fake quantize the input tensor x depending on the group_size.
-    if group_size is greater than 0, then q/dq by groups. The groups
-    must be divisible by the column size
-    if group_size is -1, then channel wise q/dq. THe input scale and
-    zero_points are reshaped to support vectorization (Assumes 1 is
-    the channel dimension)
+    Fake quantize the input tensor x by quantizing then dequantizing with
+    the QuantizationStrategy specified in args. Quantization can be done per tensor,
+    channel, token or group. For group quantization, the group_size must be divisible
+    by the column size. The input scale and zero_points are reshaped to support
+    vectorization (Assumes 1 is the channel dimension)
 
     :param x: Input tensor
     :param scale: scale tensor
     :param zero_point: zero point tensor
-    :param args: quantization args that contain group_size info
+    :param args: quantization args dictating how to quantize x
     :return: fake quantized tensor
-
     """
-    bit_range = 2**args.num_bits
-    max_q = torch.tensor(bit_range / 2 - 1, device=x.device)
-    min_q = torch.tensor(-bit_range / 2, device=x.device)
+    return _process_quantization(
+        x=x,
+        scale=scale,
+        zero_point=zero_point,
+        args=args,
+        do_quantize=True,
+        do_dequantize=True,
+    )
+
 
+@torch.no_grad()
+def _process_quantization(
+    x: torch.Tensor,
+    scale: torch.Tensor,
+    zero_point: torch.Tensor,
+    args: QuantizationArgs,
+    dtype: Optional[torch.dtype] = None,
+    do_quantize: bool = True,
+    do_dequantize: bool = True,
+) -> torch.Tensor:
+    bit_range = 2**args.num_bits
+    q_max = torch.tensor(bit_range / 2 - 1, device=x.device)
+    q_min = torch.tensor(-bit_range / 2, device=x.device)
     group_size = args.group_size
 
-    # group
     if args.strategy == QuantizationStrategy.GROUP:
 
-        DQ = torch.zeros_like(x)
+        if do_dequantize and not do_quantize:
+            # if dequantizing a quantized type infer the output type from the scale
+            output = torch.zeros_like(x, dtype=scale.dtype)
+        else:
+            output_dtype = dtype if dtype is not None else x.dtype
+            output = torch.zeros_like(x, dtype=output_dtype)
 
         # TODO: vectorize the for loop
         # TODO: fix genetric assumption about the tensor size for computing group
@@ -106,48 +193,38 @@ def fake_quantize(
         for i in range(ceil(columns / group_size)):
             # scale.shape should be [nchan, ndim]
             # sc.shape should be [nchan, 1] after unsqueeze
-
-            sc = scale[:, i].unsqueeze(1)
-            zp = zero_point[:, i].unsqueeze(1)
+            sc = scale[:, i].view(-1, 1)
+            zp = zero_point[:, i].view(-1, 1)
 
             idx = i * group_size
-            Q = quantize(x[:, idx : (idx + group_size)], sc, zp, min_q, max_q)
-            DQ[:, idx : (idx + group_size)] = dequantize(Q, sc, zp)
-
-    # channel-wise
-    elif args.strategy == QuantizationStrategy.CHANNEL:  # group_size == -1
-        # before: scale shape = [channel_size]
-        # after: scale shape = [1, channel_size]
-        scale = scale.unsqueeze(0)
-        zero_point = zero_point.unsqueeze(0)
-
-        Q = quantize(x, scale, zero_point, min_q, max_q)
-        DQ = dequantize(Q, scale, zero_point)
-
-    # per-token
-    elif args.strategy == QuantizationStrategy.TOKEN:
-        # before: scale shape = [num_tokens]
-        # after: scale shape = [num_tokens, 1]
-        # x.shape = 1, num_tokens, 1]
-        # scale gets broadcasted as expected withput having [1, num_tokens, 1] shape
-
-        scale = scale.unsqueeze(1)
-        zero_point = zero_point.unsqueeze(1)
-
-        Q = quantize(x, scale, zero_point, min_q, max_q)
-        DQ = dequantize(Q, scale, zero_point)
+            if do_quantize:
+                output[:, idx : (idx + group_size)] = _quantize(
+                    x[:, idx : (idx + group_size)], sc, zp, q_min, q_max, dtype=dtype
+                )
+            if do_dequantize:
+                input = (
+                    output[:, idx : (idx + group_size)]
+                    if do_quantize
+                    else x[:, idx : (idx + group_size)]
+                )
+                output[:, idx : (idx + group_size)] = _dequantize(input, sc, zp)
 
-    else:
-        Q = quantize(x, scale, zero_point, min_q, max_q)
-        DQ = dequantize(Q, scale, zero_point)
+    else:  # covers channel, token and tensor strategies
+        if do_quantize:
+            output = _quantize(x, scale, zero_point, q_min, q_max, dtype=dtype)
+        if do_dequantize:
+            output = _dequantize(output if do_quantize else x, scale, zero_point)
 
-    return DQ
+    return output
 
 
 def wrap_module_forward_quantized(module: Module, scheme: QuantizationScheme):
     # expects a module already initialized and injected with the parameters in
     # initialize_module_for_quantization
-    forward_func_orig = module.forward.__func__
+    if hasattr(module.forward, "__func__"):
+        forward_func_orig = module.forward.__func__
+    else:
+        forward_func_orig = module.forward.func
 
     @wraps(forward_func_orig)  # ensures docstring, names, etc are propagated
     def wrapped_forward(self, *args, **kwargs):
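A standalone illustration of the group branch above: the scale carries one column per group of input columns, and each slice of group_size columns is quantized and dequantized with its own (nchan, 1) view of that column. The shapes, the 8-bit clamp bounds, and the values here are our own example, not library code.

```python
from math import ceil

import torch

x = torch.randn(4, 8)                      # 4 rows, 8 columns
group_size = 4                             # -> 2 column groups
scale = torch.full((4, 2), 0.02)           # one scale column per group
zero_point = torch.zeros(4, 2)

output = torch.zeros_like(x)
for i in range(ceil(x.shape[1] / group_size)):
    sc = scale[:, i].view(-1, 1)           # (4, 1), broadcasts over the group
    zp = zero_point[:, i].view(-1, 1)
    idx = i * group_size
    q = torch.clamp(torch.round(x[:, idx : idx + group_size] / sc + zp), -128, 127)
    output[:, idx : idx + group_size] = (q - zp) * sc   # fake-quantized block

print(output.shape)  # torch.Size([4, 8])
```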
@@ -219,3 +296,33 @@ def maybe_calibrate_or_quantize(
         scale.data = updated_scale.to(device)
         zero_point.data = updated_zero_point.to(device)
     return fake_quantize(value, scale, zero_point, args)
+
+
+@torch.no_grad()
+def _quantize(
+    x: torch.Tensor,
+    scale: torch.Tensor,
+    zero_point: torch.Tensor,
+    q_min: torch.Tensor,
+    q_max: torch.Tensor,
+    dtype: Optional[torch.dtype] = None,
+) -> torch.Tensor:
+    quantized_value = torch.clamp(
+        torch.round(x / scale + zero_point),
+        q_min,
+        q_max,
+    )
+
+    if dtype is not None:
+        quantized_value = quantized_value.to(dtype)
+
+    return quantized_value
+
+
+@torch.no_grad()
+def _dequantize(
+    x_q: torch.Tensor,
+    scale: torch.Tensor,
+    zero_point: torch.Tensor,
+) -> torch.Tensor:
+    return (x_q - zero_point) * scale
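The new private helpers keep the original affine round trip: q = clamp(round(x / scale + zero_point), q_min, q_max) on the way in and (q - zero_point) * scale on the way out. A quick numeric check of that formula with plain torch, assuming the 8-bit bounds that _process_quantization derives from num_bits=8:

```python
import torch

x = torch.tensor([0.05, -0.20, 1.30])
scale, zero_point = torch.tensor(0.01), torch.tensor(0)
q_min, q_max = -128, 127  # bit_range = 2**8 -> q_max = 127, q_min = -128

q = torch.clamp(torch.round(x / scale + zero_point), q_min, q_max)
x_hat = (q - zero_point) * scale

print(q)      # tensor([  5., -20., 127.])  -> 1.30 saturates at q_max
print(x_hat)  # tensor([ 0.0500, -0.2000,  1.2700])
```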
compressed_tensors/quantization/lifecycle/frozen.py
@@ -35,6 +35,10 @@ def freeze_module_quantization(module: Module):
         # no quantization scheme nothing to do
         return
 
+    if module.quantization_status == QuantizationStatus.FROZEN:
+        # nothing to do, already frozen
+        return
+
     # delete observers from module if not dynamic
     if scheme.input_activations and not scheme.input_activations.dynamic:
         delattr(module, "input_observer")
compressed_tensors/quantization/lifecycle/initialize.py
@@ -20,7 +20,10 @@ import torch
 from compressed_tensors.quantization.lifecycle.forward import (
     wrap_module_forward_quantized,
 )
-from compressed_tensors.quantization.quant_args import QuantizationArgs
+from compressed_tensors.quantization.quant_args import (
+    QuantizationArgs,
+    QuantizationStrategy,
+)
 from compressed_tensors.quantization.quant_config import QuantizationStatus
 from compressed_tensors.quantization.quant_scheme import QuantizationScheme
 from torch.nn import Module, Parameter
@@ -58,7 +61,12 @@ def initialize_module_for_quantization(
         _initialize_scale_zero_point_observer(module, "input", scheme.input_activations)
     if scheme.weights is not None:
         if hasattr(module, "weight"):
-            _initialize_scale_zero_point_observer(module, "weight", scheme.weights)
+            weight_shape = None
+            if isinstance(module, torch.nn.Linear):
+                weight_shape = module.weight.shape
+            _initialize_scale_zero_point_observer(
+                module, "weight", scheme.weights, weight_shape=weight_shape
+            )
         else:
             _LOGGER.warning(
                 f"module type {type(module)} targeted for weight quantization but "
@@ -78,7 +86,10 @@ def initialize_module_for_quantization(
 
 
 def _initialize_scale_zero_point_observer(
-    module: Module, base_name: str, quantization_args: QuantizationArgs
+    module: Module,
+    base_name: str,
+    quantization_args: QuantizationArgs,
+    weight_shape: Optional[torch.Size] = None,
 ):
     # initialize observer module and attach as submodule
     observer = quantization_args.get_observer()
@@ -89,11 +100,28 @@ def _initialize_scale_zero_point_observer(
 
     device = next(module.parameters()).device
 
+    # infer expected scale/zero point shape
+    expected_shape = 1  # per tensor
+
+    if base_name == "weight" and weight_shape is not None:
+        if quantization_args.strategy == QuantizationStrategy.CHANNEL:
+            # (output_channels, 1)
+            expected_shape = (weight_shape[0], 1)
+        elif quantization_args.strategy == QuantizationStrategy.GROUP:
+            expected_shape = (
+                weight_shape[0],
+                weight_shape[1] // quantization_args.group_size,
+            )
+
     # initializes empty scale and zero point parameters for the module
-    init_scale = Parameter(torch.empty(0, device=device), requires_grad=False)
+    init_scale = Parameter(
+        torch.empty(expected_shape, dtype=module.weight.dtype, device=device),
+        requires_grad=False,
+    )
     module.register_parameter(f"{base_name}_scale", init_scale)
 
     init_zero_point = Parameter(
-        torch.empty(0, device=device, dtype=int), requires_grad=False
+        torch.empty(expected_shape, device=device, dtype=int),
+        requires_grad=False,
    )
     module.register_parameter(f"{base_name}_zero_point", init_zero_point)
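With these initialize.py changes, scale and zero point are pre-allocated at their final shape (instead of as empty 0-element tensors) using the weight shape and the quantization strategy. A sketch of the shape rule for a hypothetical Linear weight; the 4096x11008 size and group_size of 128 are our own illustration:

```python
import torch

weight_shape = torch.Size([4096, 11008])  # (output_channels, input_channels)
group_size = 128

tensor_shape = 1                                                 # per-tensor: scalar
channel_shape = (weight_shape[0], 1)                             # (4096, 1)
group_shape = (weight_shape[0], weight_shape[1] // group_size)   # (4096, 86)

print(tensor_shape, channel_shape, group_shape)
```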
compressed_tensors/quantization/observers/base.py
@@ -12,7 +12,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
-from typing import Optional, Tuple
+from typing import Any, Iterable, Optional, Tuple, Union
 
 import torch
 from compressed_tensors.quantization.quant_args import (
@@ -40,6 +40,7 @@ class Observer(Module, RegistryMixin):
         self._scale = None
         self._zero_point = None
 
+    @torch.no_grad()
     def forward(self, observed: Tensor) -> Tuple[FloatTensor, IntTensor]:
         """
         maps directly to get_qparams
@@ -49,9 +50,16 @@ class Observer(Module, RegistryMixin):
         """
         return self.get_qparams(observed=observed)
 
-    def calculate_qparams(self, observed: Tensor) -> Tuple[FloatTensor, IntTensor]:
+    def calculate_qparams(
+        self,
+        observed: Tensor,
+        reduce_dims: Optional[Tuple[int]] = None,
+    ) -> Tuple[FloatTensor, IntTensor]:
         """
         :param observed: observed tensor to calculate quantization parameters for
+        :param reduce_dims: optional tuple of dimensions to reduce along,
+            returned scale and zero point will be shaped (1,) along the
+            reduced dimensions
         :return: tuple of scale and zero point derived from the observed tensor
         """
         raise NotImplementedError(f"{self.__class__} must implement calculate_qparams")
@@ -69,6 +77,7 @@ class Observer(Module, RegistryMixin):
         Convenience function to wrap overwritten calculate_qparams
         adds support to make observed tensor optional and support for tracking latest
         calculated scale and zero point
+
         :param observed: optional observed tensor to calculate quantization parameters
             from
         :return: tuple of scale and zero point based on last observed value
@@ -84,47 +93,42 @@ class Observer(Module, RegistryMixin):
         elif self.quantization_args.strategy == QuantizationStrategy.GROUP:
             columns = observed.shape[1]
             scales, zero_points = [], []
-            for i in range(0, columns, self.quantization_args.group_size):
+            group_idxs = range(0, columns, self.quantization_args.group_size)
+            for group_id, group_idx in enumerate(group_idxs):
                 scale, zero_point = self.get_qparams_along_dim(
-                    observed[:, i : (i + group_size)],
+                    observed[:, group_idx : (group_idx + group_size)],
                     0,
+                    tensor_id=group_id,
                 )
                 scales.append(scale)
                 zero_points.append(zero_point)
 
-            self._scale = torch.stack(scales, dim=1)
-            self._zero_point = torch.stack(zero_points, dim=1)
+            self._scale = torch.cat(scales, dim=1, out=self._scale)
+            self._zero_point = torch.cat(zero_points, dim=1, out=self._zero_point)
 
         elif self.quantization_args.strategy == QuantizationStrategy.CHANNEL:
             # assume observed is transposed, because its the output, hence use dim 0
             self._scale, self._zero_point = self.get_qparams_along_dim(observed, 0)
 
         elif self.quantization_args.strategy == QuantizationStrategy.TOKEN:
-
             # use dim 1, assume the obsersed.shape = [batch, token, hidden]
             # should be batch, token
-
             self._scale, self._zero_point = self.get_qparams_along_dim(
-                observed, dim=1
+                observed,
+                dim={0, 1},
             )
 
         return self._scale, self._zero_point
 
-    def get_qparams_along_dim(self, observed, dim: int):
-        # TODO: add documentation that specifies the shape must
-        # be padded with 1-dims so the scales are along the right channel
-        # TODO: generalize the logic for reduce_dims
-        scales, zero_points = [], []
-
-        # TODO: make a more generic way to get the channel
-        num_dims = observed.shape[dim]
-
-        for dim_idx in range(num_dims):
-            scale, zero_point = self.calculate_qparams(
-                observed.select(dim=dim, index=dim_idx)
-            )
-
-            scales.append(scale)
-            zero_points.append(zero_point)
-        # breakpoint()
-        return torch.stack(scales), torch.stack(zero_points)
+    def get_qparams_along_dim(
+        self,
+        observed,
+        dim: Union[int, Iterable[int]],
+        tensor_id: Optional[Any] = None,
+    ):
+        dim = set(dim)
+
+        reduce_dims = tuple(idx for idx in range(observed.ndim) if idx not in dim)
+        return self.calculate_qparams(
+            observed, reduce_dims=reduce_dims, tensor_id=tensor_id
+        )
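get_qparams_along_dim now accepts one or more dims to keep, turns every other axis into a reduce_dims tuple, and hands that to calculate_qparams rather than looping with tensor.select. A plain-torch sketch of the same reduction for the token strategy (the [batch, token, hidden] shape is an assumed example):

```python
import torch

observed = torch.randn(2, 10, 64)  # [batch, token, hidden]
dim = {0, 1}                       # keep batch and token, as in the TOKEN branch

reduce_dims = tuple(i for i in range(observed.ndim) if i not in dim)  # (2,)
min_val = torch.amin(observed, dim=reduce_dims, keepdim=True)
max_val = torch.amax(observed, dim=reduce_dims, keepdim=True)

print(min_val.shape)  # torch.Size([2, 10, 1]) -> one range (and scale) per token
```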
compressed_tensors/quantization/observers/helpers.py
@@ -35,19 +35,24 @@ def calculate_qparams(
     """
     min_vals = torch.min(min_vals, torch.zeros_like(min_vals))
     max_vals = torch.max(max_vals, torch.zeros_like(max_vals))
+    device = min_vals.device
 
     bit_range = 2**quantization_args.num_bits - 1
     bit_min = -(bit_range + 1) / 2
     bit_max = bit_min + bit_range
     if quantization_args.symmetric:
-        zero_points = torch.tensor(0).to(torch.int8)
         max_val_pos = torch.max(-min_vals, max_vals)
         scales = max_val_pos / (float(bit_range) / 2)
         scales = torch.clamp(scales, min=torch.finfo(torch.float32).eps)
+        zero_points = torch.zeros(scales.shape, device=device, dtype=torch.int8)
     else:
         scales = (max_vals - min_vals) / float(bit_range)
         scales = torch.clamp(scales, min=torch.finfo(torch.float32).eps)
         zero_points = bit_min - torch.round(min_vals / scales)
         zero_points = torch.clamp(zero_points, bit_min, bit_max).to(torch.int8)
 
+    if scales.ndim == 0:
+        scales = scales.reshape(1)
+        zero_points = zero_points.reshape(1)
+
     return scales, zero_points
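In the symmetric branch above, the scale comes from the largest absolute observed value and the zero point is now a zeros tensor that matches the scale's shape instead of a single scalar. A worked 8-bit example of the same arithmetic (the min/max values are made up):

```python
import torch

min_vals = torch.tensor([-0.50, -0.10])
max_vals = torch.tensor([0.25, 0.40])
bit_range = 2**8 - 1  # 255 for num_bits=8

max_val_pos = torch.max(-min_vals, max_vals)          # tensor([0.5000, 0.4000])
scales = max_val_pos / (float(bit_range) / 2)         # divide by 127.5
zero_points = torch.zeros(scales.shape, dtype=torch.int8)

print(scales)       # tensor([0.0039, 0.0031])
print(zero_points)  # tensor([0, 0], dtype=torch.int8)
```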
compressed_tensors/quantization/observers/memoryless.py
@@ -12,7 +12,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
-from typing import Tuple
+from typing import Any, Optional, Tuple
 
 import torch
 from compressed_tensors.quantization.observers.base import Observer
@@ -30,19 +30,27 @@ class MemorylessObserver(Observer):
     zero point based on the latest observed value without tracking state
     """
 
-    def calculate_qparams(self, observed: Tensor) -> Tuple[FloatTensor, IntTensor]:
+    def calculate_qparams(
+        self,
+        observed: Tensor,
+        tensor_id: Optional[Any] = None,
+        reduce_dims: Optional[Tuple[int]] = None,
+    ) -> Tuple[FloatTensor, IntTensor]:
         """
-        Returns the min and max values of observed
+        Returns the min and max values of observed tensor
 
         :param observed: observed tensor to calculate quantization parameters for
+        :param tensor_id: optional id for tensor; not used for memoryless
+        :param reduce_dims: optional tuple of dimensions to reduce along,
+            returned scale and zero point will be shaped (1,) along the
+            reduced dimensions
         :return: tuple of scale and zero point derived from the observed tensor
         """
-        # TODO: Add support for full range of quantization Args, only supports 8bit
-        # per tensor
-        min_val, max_val = torch.aminmax(observed)
 
-        # ensure zero is in the range
-        min_val = torch.min(min_val, torch.zeros_like(min_val))
-        max_val = torch.max(max_val, torch.zeros_like(max_val))
+        if not reduce_dims:
+            min_val, max_val = torch.aminmax(observed)
+        else:
+            min_val = torch.amin(observed, dim=reduce_dims, keepdims=True)
+            max_val = torch.amax(observed, dim=reduce_dims, keepdims=True)
 
         return calculate_qparams(min_val, max_val, self.quantization_args)
compressed_tensors/quantization/observers/min_max.py
@@ -12,7 +12,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
-from typing import Tuple
+from typing import Any, Optional, Tuple
 
 import torch
 from compressed_tensors.quantization.observers.base import Observer
@@ -36,30 +36,61 @@ class MovingAverageMinMaxObserver(Observer):
     ):
         super().__init__(quantization_args=quantization_args)
 
-        self.min_val = float("inf")
-        self.max_val = -float("inf")
+        self.min_val = {}
+        self.max_val = {}
         self.averaging_constant = averaging_constant
 
-    def calculate_qparams(self, observed: Tensor) -> Tuple[FloatTensor, IntTensor]:
+    def calculate_qparams(
+        self,
+        observed: Tensor,
+        reduce_dims: Optional[Tuple[int]] = None,
+        tensor_id: Optional[Any] = None,
+    ) -> Tuple[FloatTensor, IntTensor]:
         """
         Updates the observed min and max using a moving average smoothed by the
         averaging_constant
 
         :param observed: observed tensor to calculate quantization parameters for
+        :param reduce_dims: optional tuple of dimensions to reduce along,
+            returned scale and zero point will be shaped (1,) along the
+            reduced dimensions
+        :param tensor_id: Optional id if different ranges of observed tensors are
+            passed, useful for sharding tensors by group_size
         :return: tuple of scale and zero point derived from the observed tensor
         """
+        tensor_id = tensor_id or "default"
 
-        min_val, max_val = torch.aminmax(observed)
+        if not reduce_dims:
+            min_val, max_val = torch.aminmax(observed)
+        else:
+            min_val = torch.amin(observed, dim=reduce_dims, keepdims=True)
+            max_val = torch.amax(observed, dim=reduce_dims, keepdims=True)
+
+        running_min_val = self.min_val.get(tensor_id, None)
+        running_max_val = self.max_val.get(tensor_id, None)
 
-        if self.min_val == float("inf") and self.max_val == float("-inf"):
-            self.min_val = min_val
-            self.max_val = max_val
+        if running_min_val is None or running_max_val is None:
+            updated_min_val = min_val
+            updated_max_val = max_val
         else:
-            self.min_val = self.min_val + self.averaging_constant * (
-                min_val - self.min_val
+            updated_min_val = running_min_val + self.averaging_constant * (
+                min_val - running_min_val
            )
-            self.max_val = self.max_val + self.averaging_constant * (
-                max_val - self.max_val
+            updated_max_val = running_max_val + self.averaging_constant * (
+                max_val - running_max_val
            )
 
-        return calculate_qparams(self.min_val, self.max_val, self.quantization_args)
+        self.min_val[tensor_id] = updated_min_val
+        self.max_val[tensor_id] = updated_max_val
+
+        return calculate_qparams(
+            updated_min_val, updated_max_val, self.quantization_args
+        )
+
+    def get_qparams_along_dim(
+        self, observed, dim: int, tensor_id: Optional[Any] = None
+    ):
+        reduce_dims = tuple(idx for idx in range(observed.ndim) if idx != dim)
+        return self.calculate_qparams(
+            observed, reduce_dims=reduce_dims, tensor_id=tensor_id
+        )
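The observer now keeps one running range per tensor_id (useful when each weight group is observed separately, per the docstring above) and smooths it with an exponential moving average. A standalone sketch of the update rule; the 0.01 constant and the values are our own example:

```python
averaging_constant = 0.01
min_vals, max_vals = {}, {}


def update(tensor_id, new_min, new_max):
    running_min, running_max = min_vals.get(tensor_id), max_vals.get(tensor_id)
    if running_min is None or running_max is None:
        # first observation: take the range as-is
        min_vals[tensor_id], max_vals[tensor_id] = new_min, new_max
    else:
        # afterwards: move a small step toward the newly observed range
        min_vals[tensor_id] = running_min + averaging_constant * (new_min - running_min)
        max_vals[tensor_id] = running_max + averaging_constant * (new_max - running_max)


update("weight", -1.0, 1.0)
update("weight", -3.0, 2.0)
print(min_vals["weight"], max_vals["weight"])  # -1.02 1.01
```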
compressed_tensors/quantization/quant_args.py
@@ -42,7 +42,7 @@ class QuantizationStrategy(str, Enum):
     TOKEN = "token"
 
 
-class QuantizationArgs(BaseModel):
+class QuantizationArgs(BaseModel, use_enum_values=True):
     """
     User facing arguments used to define a quantization config for weights or
     activations
@@ -62,7 +62,7 @@ class QuantizationArgs(BaseModel):
     """
 
     num_bits: int = 8
-    type: QuantizationType = QuantizationType.INT
+    type: QuantizationType = QuantizationType.INT.value
     symmetric: bool = True
     group_size: Optional[int] = None
     strategy: Optional[QuantizationStrategy] = None
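use_enum_values=True is pydantic model configuration passed as a class keyword: enum-typed fields such as type and strategy are stored as their plain string values after validation, which keeps dumped configs JSON-friendly. A toy pydantic model (ours, not the library class) showing the effect; exact behavior can vary slightly across pydantic versions:

```python
from enum import Enum

from pydantic import BaseModel


class Strategy(str, Enum):
    TENSOR = "tensor"
    GROUP = "group"


class Args(BaseModel, use_enum_values=True):
    num_bits: int = 8
    strategy: Strategy = Strategy.TENSOR


args = Args(strategy=Strategy.GROUP)
print(args.strategy, type(args.strategy))  # group <class 'str'>
```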