compressed-tensors 0.4.0__py3-none-any.whl → 0.6.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (44)
  1. compressed_tensors/base.py +1 -0
  2. compressed_tensors/compressors/__init__.py +5 -1
  3. compressed_tensors/compressors/base.py +200 -8
  4. compressed_tensors/compressors/dense.py +1 -1
  5. compressed_tensors/compressors/marlin_24.py +11 -10
  6. compressed_tensors/compressors/model_compressor.py +101 -13
  7. compressed_tensors/compressors/naive_quantized.py +140 -0
  8. compressed_tensors/compressors/pack_quantized.py +128 -132
  9. compressed_tensors/compressors/sparse_bitmask.py +1 -1
  10. compressed_tensors/config/base.py +8 -1
  11. compressed_tensors/{compressors/utils → linear}/__init__.py +0 -6
  12. compressed_tensors/linear/compressed_linear.py +87 -0
  13. compressed_tensors/quantization/lifecycle/__init__.py +1 -0
  14. compressed_tensors/quantization/lifecycle/apply.py +204 -44
  15. compressed_tensors/quantization/lifecycle/calibration.py +22 -2
  16. compressed_tensors/quantization/lifecycle/compressed.py +3 -1
  17. compressed_tensors/quantization/lifecycle/forward.py +139 -61
  18. compressed_tensors/quantization/lifecycle/helpers.py +80 -0
  19. compressed_tensors/quantization/lifecycle/initialize.py +77 -13
  20. compressed_tensors/quantization/observers/__init__.py +1 -0
  21. compressed_tensors/quantization/observers/base.py +93 -14
  22. compressed_tensors/quantization/observers/helpers.py +64 -11
  23. compressed_tensors/quantization/observers/min_max.py +8 -0
  24. compressed_tensors/quantization/observers/mse.py +162 -0
  25. compressed_tensors/quantization/quant_args.py +139 -23
  26. compressed_tensors/quantization/quant_config.py +35 -2
  27. compressed_tensors/quantization/quant_scheme.py +112 -13
  28. compressed_tensors/quantization/utils/helpers.py +68 -2
  29. compressed_tensors/utils/__init__.py +5 -0
  30. compressed_tensors/utils/helpers.py +44 -2
  31. compressed_tensors/utils/offload.py +116 -0
  32. compressed_tensors/utils/permute.py +70 -0
  33. compressed_tensors/utils/safetensors_load.py +2 -0
  34. compressed_tensors/{compressors/utils → utils}/semi_structured_conversions.py +1 -0
  35. compressed_tensors/version.py +1 -1
  36. {compressed_tensors-0.4.0.dist-info → compressed_tensors-0.6.0.dist-info}/METADATA +35 -22
  37. compressed_tensors-0.6.0.dist-info/RECORD +52 -0
  38. {compressed_tensors-0.4.0.dist-info → compressed_tensors-0.6.0.dist-info}/WHEEL +1 -1
  39. compressed_tensors/compressors/int_quantized.py +0 -126
  40. compressed_tensors/compressors/utils/helpers.py +0 -43
  41. compressed_tensors-0.4.0.dist-info/RECORD +0 -48
  42. /compressed_tensors/{compressors/utils → utils}/permutations_24.py +0 -0
  43. {compressed_tensors-0.4.0.dist-info → compressed_tensors-0.6.0.dist-info}/LICENSE +0 -0
  44. {compressed_tensors-0.4.0.dist-info → compressed_tensors-0.6.0.dist-info}/top_level.txt +0 -0

compressed_tensors/quantization/observers/base.py

@@ -12,6 +12,8 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.

+import logging
+from math import ceil
 from typing import Any, Iterable, Optional, Tuple, Union

 import torch
@@ -20,10 +22,14 @@ from compressed_tensors.quantization.quant_args import (
     QuantizationStrategy,
 )
 from compressed_tensors.registry.registry import RegistryMixin
+from compressed_tensors.utils import safe_permute
 from torch import FloatTensor, IntTensor, Tensor
 from torch.nn import Module


+_LOGGER = logging.getLogger(__name__)
+
+
 __all__ = ["Observer"]


@@ -39,16 +45,21 @@ class Observer(Module, RegistryMixin):
         super().__init__()
         self._scale = None
         self._zero_point = None
+        self._num_observed_tokens = None

     @torch.no_grad()
-    def forward(self, observed: Tensor) -> Tuple[FloatTensor, IntTensor]:
+    def forward(
+        self, observed: Tensor, g_idx: Optional[Tensor] = None
+    ) -> Tuple[FloatTensor, IntTensor]:
         """
         maps directly to get_qparams
-        :param observed: optional observed tensor to calculate quantization parameters
-            from
+        :param observed: optional observed tensor from which to calculate
+            quantization parameters
+        :param g_idx: optional mapping from column index to group index
         :return: tuple of scale and zero point based on last observed value
         """
-        return self.get_qparams(observed=observed)
+        self.record_observed_tokens(observed)
+        return self.get_qparams(observed=observed, g_idx=g_idx)

     def calculate_qparams(
         self,
@@ -71,7 +82,9 @@ class Observer(Module, RegistryMixin):
         ...

     def get_qparams(
-        self, observed: Optional[Tensor] = None
+        self,
+        observed: Optional[Tensor] = None,
+        g_idx: Optional[Tensor] = None,
     ) -> Tuple[FloatTensor, IntTensor]:
         """
         Convenience function to wrap overwritten calculate_qparams
@@ -80,6 +93,7 @@

         :param observed: optional observed tensor to calculate quantization parameters
             from
+        :param g_idx: optional mapping from column index to group index
         :return: tuple of scale and zero point based on last observed value
         """
         if observed is not None:
@@ -91,20 +105,42 @@
                 self._scale, self._zero_point = self.calculate_qparams(observed)

             elif self.quantization_args.strategy == QuantizationStrategy.GROUP:
+                rows = observed.shape[0]
                 columns = observed.shape[1]
-                scales, zero_points = [], []
-                group_idxs = range(0, columns, self.quantization_args.group_size)
-                for group_id, group_idx in enumerate(group_idxs):
+                num_groups = int(ceil(columns / group_size))
+                self._scale = torch.empty(
+                    (rows, num_groups), dtype=observed.dtype, device=observed.device
+                )
+                zp_dtype = self.quantization_args.pytorch_dtype()
+                self._zero_point = torch.empty(
+                    (rows, num_groups), dtype=zp_dtype, device=observed.device
+                )
+
+                # support column-order (default) quantization as well as other orderings
+                # such as activation ordering. Below checks if g_idx has initialized
+                is_column_order = g_idx is None or -1 in g_idx
+                if is_column_order:
+                    group_sizes = torch.full((num_groups,), group_size, dtype=torch.int)
+                else:
+                    group_indices, group_sizes = torch.unique(g_idx, return_counts=True)
+                    group_sizes = group_sizes[torch.argsort(group_indices)]
+
+                    perm = torch.argsort(g_idx)
+                    observed = safe_permute(observed, perm, dim=1)
+
+                # TODO: experiment with vectorizing for loop for performance
+                end = 0
+                for group_index, group_count in enumerate(group_sizes):
+                    start = end
+                    end = start + group_count
                     scale, zero_point = self.get_qparams_along_dim(
-                        observed[:, group_idx : (group_idx + group_size)],
+                        observed[:, start:end],
                         0,
-                        tensor_id=group_id,
+                        tensor_id=group_index,
                     )
-                    scales.append(scale)
-                    zero_points.append(zero_point)

-                self._scale = torch.cat(scales, dim=1, out=self._scale)
-                self._zero_point = torch.cat(zero_points, dim=1, out=self._zero_point)
+                    self._scale[:, group_index] = scale.squeeze(1)
+                    self._zero_point[:, group_index] = zero_point.squeeze(1)

             elif self.quantization_args.strategy == QuantizationStrategy.CHANNEL:
                 # assume observed is transposed, because its the output, hence use dim 0
@@ -126,9 +162,52 @@
         dim: Union[int, Iterable[int]],
         tensor_id: Optional[Any] = None,
     ):
+        if isinstance(dim, int):
+            dim = [dim]
         dim = set(dim)

         reduce_dims = tuple(idx for idx in range(observed.ndim) if idx not in dim)
         return self.calculate_qparams(
             observed, reduce_dims=reduce_dims, tensor_id=tensor_id
         )
+
+    def record_observed_tokens(self, batch_tensor: Tensor):
+        """
+        Counts the number of tokens observed during the
+        forward passes. The count is aggregated in the
+        _num_observed_tokens attribute of the class.
+
+        Note: The batch_tensor is expected to have two dimensions
+        (batch_size * sequence_length, num_features). This is the
+        general shape expected by the forward pass of the expert
+        layers in a MOE model. If the input tensor does not have
+        two dimensions, the _num_observed_tokens attribute will be set
+        to None.
+        """
+        if not isinstance(batch_tensor, Tensor):
+            raise ValueError(f"Expected value to be a tensor, got {type(batch_tensor)}")
+
+        if batch_tensor.ndim != 2:
+            _LOGGER.debug(
+                "The input tensor is expected to have two dimensions "
+                "(batch_size * sequence_length, num_features). "
+                f"The input tensor has {batch_tensor.ndim} dimensions."
+            )
+            return
+
+        if self._num_observed_tokens is None:
+            # initialize the count
+            self._num_observed_tokens = 0
+
+        # batch_tensor (batch_size * sequence_length, num_features)
+        # observed_tokens (batch_size * sequence_length)
+        observed_tokens, _ = batch_tensor.shape
+        self._num_observed_tokens += observed_tokens
+
+    def reset(self):
+        """
+        Reset the state of the observer
+        """
+        self._num_observed_tokens = None
+        self._scale = None
+        self._zero_point = None
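
For reference, a minimal sketch of the per-group layout produced by the new Observer.get_qparams code above: scales are stored as (rows, num_groups) and each column group is reduced along dim 0. The group_scales helper below is hypothetical (plain torch, symmetric int8), not part of the package:

import torch
from math import ceil

# Hypothetical helper: per-row, per-group symmetric int8 scales laid out as
# (rows, num_groups), mirroring the grouped loop added to Observer.get_qparams.
def group_scales(weight: torch.Tensor, group_size: int) -> torch.Tensor:
    rows, columns = weight.shape
    num_groups = int(ceil(columns / group_size))
    scales = torch.empty((rows, num_groups), dtype=weight.dtype)
    end = 0
    for group_index in range(num_groups):
        start, end = end, min(end + group_size, columns)
        # max absolute value per row within this column group
        max_abs = weight[:, start:end].abs().amax(dim=1)
        scales[:, group_index] = max_abs / 127.0  # int8 symmetric range
    return scales

print(group_scales(torch.randn(4, 10), group_size=4).shape)  # torch.Size([4, 3])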

compressed_tensors/quantization/observers/helpers.py

@@ -12,23 +12,45 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.

+from collections import Counter
 from typing import Tuple

 import torch
-from compressed_tensors.quantization.quant_args import QuantizationArgs
+from compressed_tensors.quantization.quant_args import (
+    FP8_DTYPE,
+    QuantizationArgs,
+    QuantizationType,
+)
 from torch import FloatTensor, IntTensor, Tensor


-__all__ = ["calculate_qparams"]
+__all__ = ["calculate_qparams", "get_observer_token_count", "calculate_range"]
+
+
+def get_observer_token_count(module: torch.nn.Module) -> Counter:
+    """
+    Parse the module and return the number of tokens observed by
+    each module's observer.
+
+    :param module: module to parse
+    :return: counter with the number of tokens observed by each observer
+    """
+    token_counts = Counter()
+    for name, module in module.named_modules():
+        if name.endswith(".input_observer"):
+            token_counts[
+                name.replace(".input_observer", "")
+            ] = module._num_observed_tokens
+    return token_counts


 def calculate_qparams(
     min_vals: Tensor, max_vals: Tensor, quantization_args: QuantizationArgs
 ) -> Tuple[FloatTensor, IntTensor]:
     """
-    :param min_vals: tensor of min value(s) to caluclate scale(s) and zero point(s)
+    :param min_vals: tensor of min value(s) to calculate scale(s) and zero point(s)
         from
-    :param max_vals: tensor of max value(s) to caluclate scale(s) and zero point(s)
+    :param max_vals: tensor of max value(s) to calculate scale(s) and zero point(s)
         from
     :param quantization_args: settings to quantization
     :return: tuple of the calculated scale(s) and zero point(s)
@@ -37,22 +59,53 @@ def calculate_qparams(
     max_vals = torch.max(max_vals, torch.zeros_like(max_vals))
     device = min_vals.device

-    bit_range = 2**quantization_args.num_bits - 1
-    bit_min = -(bit_range + 1) / 2
-    bit_max = bit_min + bit_range
+    bit_min, bit_max = calculate_range(quantization_args, device)
+    bit_range = bit_max - bit_min
+    zp_dtype = quantization_args.pytorch_dtype()
+
     if quantization_args.symmetric:
-        max_val_pos = torch.max(-min_vals, max_vals)
+        max_val_pos = torch.max(torch.abs(min_vals), torch.abs(max_vals))
         scales = max_val_pos / (float(bit_range) / 2)
         scales = torch.clamp(scales, min=torch.finfo(torch.float32).eps)
-        zero_points = torch.zeros(scales.shape, device=device, dtype=torch.int8)
+        zero_points = torch.zeros(scales.shape, device=device, dtype=min_vals.dtype)
     else:
         scales = (max_vals - min_vals) / float(bit_range)
         scales = torch.clamp(scales, min=torch.finfo(torch.float32).eps)
-        zero_points = bit_min - torch.round(min_vals / scales)
-        zero_points = torch.clamp(zero_points, bit_min, bit_max).to(torch.int8)
+        zero_points = bit_min - (min_vals / scales)
+        zero_points = torch.clamp(zero_points, bit_min, bit_max)
+
+    # match zero-points to quantized type
+    zero_points = zero_points.to(zp_dtype)

     if scales.ndim == 0:
         scales = scales.reshape(1)
         zero_points = zero_points.reshape(1)

     return scales, zero_points
+
+
+def calculate_range(quantization_args: QuantizationArgs, device: str) -> Tuple:
+    """
+    Calculated the effective quantization range for the given Quantization Args
+
+    :param quantization_args: quantization args to get range of
+    :param device: device to store the range to
+    :return: tuple endpoints for the given quantization range
+    """
+    if quantization_args.type == QuantizationType.INT:
+        bit_range = 2**quantization_args.num_bits
+        q_max = torch.tensor(bit_range / 2 - 1, device=device)
+        q_min = torch.tensor(-bit_range / 2, device=device)
+    elif quantization_args.type == QuantizationType.FLOAT:
+        if quantization_args.num_bits != 8:
+            raise ValueError(
+                "Floating point quantization is only supported for 8 bits,"
+                f"got {quantization_args.num_bits}"
+            )
+        fp_range_info = torch.finfo(FP8_DTYPE)
+        q_max = torch.tensor(fp_range_info.max, device=device)
+        q_min = torch.tensor(fp_range_info.min, device=device)
+    else:
+        raise ValueError(f"Invalid quantization type {quantization_args.type}")
+
+    return q_min, q_max
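
As a rough check of the endpoints the new calculate_range is meant to return, computed here directly with torch rather than through the library: int8 spans [-128, 127], and the 8-bit float type FP8_DTYPE (torch.float8_e4m3fn) spans [finfo.min, finfo.max]. This assumes a torch build that exposes float8_e4m3fn (roughly torch >= 2.1):

import torch

# Direct computation of the ranges described above (not a library call).
num_bits = 8
bit_range = 2 ** num_bits
print(-bit_range / 2, bit_range / 2 - 1)   # -128.0 127.0

fp8_info = torch.finfo(torch.float8_e4m3fn)
print(fp8_info.min, fp8_info.max)          # -448.0 448.0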

compressed_tensors/quantization/observers/min_max.py

@@ -94,3 +94,11 @@ class MovingAverageMinMaxObserver(Observer):
         return self.calculate_qparams(
             observed, reduce_dims=reduce_dims, tensor_id=tensor_id
         )
+
+    def reset(self):
+        """
+        Reset the state of the observer, including min and maximum values
+        """
+        super().reset()
+        self.min_val = {}
+        self.max_val = {}

compressed_tensors/quantization/observers/mse.py

@@ -0,0 +1,162 @@
+# Copyright (c) 2021 - present / Neuralmagic, Inc. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from typing import Any, Optional, Tuple
+
+import torch
+from compressed_tensors.quantization.observers.base import Observer
+from compressed_tensors.quantization.observers.helpers import calculate_qparams
+from compressed_tensors.quantization.quant_args import QuantizationArgs
+from torch import FloatTensor, IntTensor, Tensor
+
+
+__all__ = ["MovingAverageMSEObserver"]
+
+
+@Observer.register("mse")
+class MovingAverageMSEObserver(Observer):
+    """
+    Implements a dynamic quantization observer that sets the scale and
+    zero point based on a moving average of the mse-clipped min and max observed values
+    """
+
+    def __init__(
+        self,
+        quantization_args: QuantizationArgs,
+        averaging_constant: float = 0.01,
+        grid: float = 100.0,
+        maxshrink: float = 0.80,
+        norm: float = 2.4,
+    ):
+        super().__init__(quantization_args=quantization_args)
+
+        self.min_val = {}
+        self.max_val = {}
+        self.averaging_constant = averaging_constant
+        self.grid = grid
+        self.maxshrink = maxshrink
+        self.norm = norm
+
+    def calculate_mse_min_max(
+        self,
+        observed: Tensor,
+        reduce_dims: Optional[Tuple[int]] = None,
+    ):
+        """
+        Computes the mse-clipped min and max values of the observed tensor by
+        optimizing for quantization error
+
+        :param observed: observed tensor to calculate quantization parameters for
+        :param reduce_dims: optional tuple of dimensions to reduce along,
+            returned values will be shaped (1,) along the reduced dimensions
+        :return: tuple of min and max values derived from the observed tensor
+        """
+        from compressed_tensors.quantization.lifecycle import fake_quantize
+
+        if not reduce_dims:
+            absolute_min_val, absolute_max_val = torch.aminmax(observed)
+        else:
+            absolute_min_val = torch.amin(observed, dim=reduce_dims, keepdims=True)
+            absolute_max_val = torch.amax(observed, dim=reduce_dims, keepdims=True)
+
+        best = torch.full(absolute_min_val.shape, float("inf"))
+        min_val = torch.ones(absolute_min_val.shape)
+        max_val = torch.zeros(absolute_max_val.shape)
+        for i in range(int(self.maxshrink * self.grid)):
+            p = 1 - i / self.grid
+            shrinked_min_val = p * absolute_min_val
+            shrinked_max_val = p * absolute_max_val
+
+            candidate_scales, candidate_zero_points = calculate_qparams(
+                shrinked_min_val, shrinked_max_val, self.quantization_args
+            )
+            q = fake_quantize(
+                observed,
+                candidate_scales,
+                candidate_zero_points,
+                self.quantization_args,
+            )
+
+            q -= observed
+            q.abs_()
+            q.pow_(self.norm)
+            if not reduce_dims:
+                err = torch.sum(q)
+            else:
+                err = torch.sum(q, reduce_dims, keepdims=True)
+
+            tmp = err < best
+            if torch.any(tmp):
+                best[tmp] = err[tmp]
+                min_val[tmp] = shrinked_min_val[tmp]
+                max_val[tmp] = shrinked_max_val[tmp]
+        return min_val, max_val
+
+    def calculate_qparams(
+        self,
+        observed: Tensor,
+        reduce_dims: Optional[Tuple[int]] = None,
+        tensor_id: Optional[Any] = None,
+    ) -> Tuple[FloatTensor, IntTensor]:
+        """
+        Updates the mse-clipped min and max values of the observed tensor using
+        a moving average smoothed by the averaging_constant
+
+        :param observed: observed tensor to calculate quantization parameters for
+        :param reduce_dims: optional tuple of dimensions to reduce along,
+            returned scale and zero point will be shaped (1,) along the
+            reduced dimensions
+        :param tensor_id: Optional id if different ranges of observed tensors are
+            passed, useful for sharding tensors by group_size
+        :return: tuple of scale and zero point derived from the observed tensor
+        """
+        min_val, max_val = self.calculate_mse_min_max(observed, reduce_dims)
+
+        running_min_val = self.min_val.get(tensor_id, None)
+        running_max_val = self.max_val.get(tensor_id, None)
+
+        if running_min_val is None or running_max_val is None:
+            updated_min_val = min_val
+            updated_max_val = max_val
+        else:
+            updated_min_val = running_min_val + self.averaging_constant * (
+                min_val - running_min_val
+            )
+            updated_max_val = running_max_val + self.averaging_constant * (
+                max_val - running_max_val
+            )
+
+        tensor_id = tensor_id or "default"
+        self.min_val[tensor_id] = updated_min_val
+        self.max_val[tensor_id] = updated_max_val
+
+        return calculate_qparams(
+            updated_min_val, updated_max_val, self.quantization_args
+        )
+
+    def get_qparams_along_dim(
+        self, observed, dim: int, tensor_id: Optional[Any] = None
+    ):
+        reduce_dims = tuple(idx for idx in range(observed.ndim) if idx != dim)
+        return self.calculate_qparams(
+            observed, reduce_dims=reduce_dims, tensor_id=tensor_id
+        )
+
+    def reset(self):
+        """
+        Reset the state of the observer, including min and maximum values
+        """
+        super().reset()
+        self.min_val = {}
+        self.max_val = {}
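
A usage sketch for the new observer, assuming the "mse" registry name shown above and the Observer.load_from_registry call used elsewhere in this diff; exact import paths may differ between releases:

import torch
from compressed_tensors.quantization.observers.base import Observer
from compressed_tensors.quantization.quant_args import QuantizationArgs

# Build an 8-bit symmetric config and resolve the "mse" observer from the registry.
args = QuantizationArgs(num_bits=8, symmetric=True, observer="mse")
observer = Observer.load_from_registry("mse", quantization_args=args)

# Tensor-strategy qparams from a fake (tokens, features) activation batch.
scale, zero_point = observer(torch.randn(16, 64))
print(scale.shape, zero_point.shape)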

compressed_tensors/quantization/quant_args.py

@@ -13,12 +13,22 @@
 # limitations under the License.

 from enum import Enum
-from typing import Any, Dict, Optional
+from typing import Any, Dict, Optional, Union

-from pydantic import BaseModel, Field, validator
+import torch
+from pydantic import BaseModel, Field, field_validator, model_validator


-__all__ = ["QuantizationType", "QuantizationStrategy", "QuantizationArgs"]
+__all__ = [
+    "FP8_DTYPE",
+    "QuantizationType",
+    "QuantizationStrategy",
+    "QuantizationArgs",
+    "round_to_quantized_type",
+    "ActivationOrdering",
+]
+
+FP8_DTYPE = torch.float8_e4m3fn


 class QuantizationType(str, Enum):
@@ -42,6 +52,19 @@ class QuantizationStrategy(str, Enum):
     TOKEN = "token"


+class ActivationOrdering(str, Enum):
+    """
+    Enum storing strategies for activation ordering
+
+    Group: reorder groups and weight\n
+    Weight: only reorder weight, not groups. Slightly lower latency and
+        accuracy compared to group actorder\n
+    """
+
+    GROUP = "group"
+    WEIGHT = "weight"
+
+
 class QuantizationArgs(BaseModel, use_enum_values=True):
     """
     User facing arguments used to define a quantization config for weights or
@@ -59,15 +82,18 @@ class QuantizationArgs(BaseModel, use_enum_values=True):
         ranges will be observed with every sample. Defaults to False for static
         quantization. Note that enabling dynamic quantization will change the default
         observer to a memoryless one
+    :param actorder: whether to apply group quantization in decreasing order of
+        activation. Defaults to None for arbitrary ordering
     """

     num_bits: int = 8
-    type: QuantizationType = QuantizationType.INT.value
+    type: QuantizationType = QuantizationType.INT
     symmetric: bool = True
     group_size: Optional[int] = None
     strategy: Optional[QuantizationStrategy] = None
     block_structure: Optional[str] = None
     dynamic: bool = False
+    actorder: Union[ActivationOrdering, bool, None] = None
     observer: str = Field(
         default="minmax",
         description=(
@@ -89,37 +115,127 @@ class QuantizationArgs(BaseModel, use_enum_values=True):
         """
         from compressed_tensors.quantization.observers.base import Observer

-        if self.observer == "minmax" and self.dynamic:
+        if self.dynamic:
             # override defualt observer for dynamic, you never want minmax which
             # keeps state across samples for dynamic
             self.observer = "memoryless"

         return Observer.load_from_registry(self.observer, quantization_args=self)

-    @validator("strategy", pre=True, always=True)
-    def validate_strategy(cls, value, values):
-        group_size = values.get("group_size")
+    @field_validator("type", mode="before")
+    def validate_type(cls, value) -> QuantizationType:
+        if isinstance(value, str):
+            return QuantizationType(value.lower())

-        # use group_size to determinine strategy if not given explicity
-        if group_size is not None and value is None:
-            if group_size > 0:
-                return QuantizationStrategy.GROUP
+        return value

-            elif group_size == -1:
-                return QuantizationStrategy.CHANNEL
+    @field_validator("group_size", mode="before")
+    def validate_group(cls, value) -> Union[int, None]:
+        if value is None:
+            return value
+
+        if value < -1:
+            raise ValueError(
+                f"Invalid group size {value}. Use group_size > 0 for "
+                "strategy='group' and group_size = -1 for 'channel'"
+            )
+
+        return value
+
+    @field_validator("strategy", mode="before")
+    def validate_strategy(cls, value) -> Union[QuantizationStrategy, None]:
+        if isinstance(value, str):
+            return QuantizationStrategy(value.lower())
+
+        return value

+    @field_validator("actorder", mode="before")
+    def validate_actorder(cls, value) -> Optional[ActivationOrdering]:
+        if isinstance(value, bool):
+            return ActivationOrdering.GROUP if value else None
+
+        if isinstance(value, str):
+            return ActivationOrdering(value.lower())
+
+        return value
+
+    @model_validator(mode="after")
+    def validate_model_after(model: "QuantizationArgs") -> Dict[str, Any]:
+        # extract user-passed values from dictionary
+        strategy = model.strategy
+        group_size = model.group_size
+        actorder = model.actorder
+
+        # infer strategy
+        if strategy is None:
+            if group_size is None:
+                strategy = QuantizationStrategy.TENSOR
+            elif group_size > 0:
+                strategy = QuantizationStrategy.GROUP
+            elif group_size == -1:
+                strategy = QuantizationStrategy.CHANNEL
             else:
                 raise ValueError(
-                    f"group_size={group_size} with strategy {value} is invald. "
-                    "group_size > 0 for strategy='group' and "
-                    "group_size = -1 for 'channel'"
+                    f"Invalid group size {group_size}. Use group_size > 0 for "
+                    "strategy='group' and group_size = -1 for 'channel'"
                 )

-        if value == QuantizationStrategy.GROUP:
-            if group_size is None:
-                raise ValueError(f"strategy {value} requires group_size to be set.")
+        # validate strategy and group
+        if strategy == QuantizationStrategy.GROUP:
+            if group_size is None or group_size <= 0:
+                raise ValueError(
+                    f"strategy {strategy} requires group_size to be "
+                    "set to a positive value"
+                )
+        if (
+            group_size is not None
+            and group_size > 0
+            and strategy != QuantizationStrategy.GROUP
+        ):
+            raise ValueError("group_size requires strategy to be set to 'group'")
+
+        # validate activation ordering and strategy
+        if actorder is not None and strategy != QuantizationStrategy.GROUP:
+            raise ValueError(
+                "Must use group quantization strategy in order to apply "
+                "activation ordering"
+            )
+
+        # write back modified values
+        model.strategy = strategy
+        return model
+
+    def pytorch_dtype(self) -> torch.dtype:
+        if self.type == QuantizationType.FLOAT:
+            return FP8_DTYPE
+        elif self.type == QuantizationType.INT:
+            if self.num_bits <= 8:
+                return torch.int8
+            elif self.num_bits <= 16:
+                return torch.int16
+            else:
+                return torch.int32
+        else:
+            raise ValueError(f"Invalid quantization type {self.type}")

-        if value is None:
-            return QuantizationStrategy.TENSOR

-        return value
+def round_to_quantized_type(
+    tensor: torch.Tensor, args: QuantizationArgs
+) -> torch.Tensor:
+    """
+    Rounds each element of the input tensor to the nearest quantized representation,
+    keeping to original dtype
+
+    :param tensor: tensor to round
+    :param args: QuantizationArgs to pull appropriate dtype from
+    :return: rounded tensor
+    """
+    original_dtype = tensor.dtype
+    if args.type == QuantizationType.FLOAT:
+        rounded = tensor.to(FP8_DTYPE)
+    elif args.type == QuantizationType.INT:
+        rounded = torch.round(tensor)
+    else:
+        raise ValueError(f"Invalid quantization type {args.type}")
+
+    return rounded.to(original_dtype)
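
A short sketch of the validator behavior added above, based only on the logic shown in this hunk: with group_size > 0 and no explicit strategy, validate_model_after infers the group strategy, and actorder=True is normalized to ActivationOrdering.GROUP by validate_actorder:

from compressed_tensors.quantization.quant_args import (
    ActivationOrdering,
    QuantizationArgs,
    QuantizationStrategy,
)

# strategy is inferred from group_size; actorder=True becomes the GROUP ordering
args = QuantizationArgs(num_bits=4, group_size=128, actorder=True)
assert args.strategy == QuantizationStrategy.GROUP
assert args.actorder == ActivationOrdering.GROUP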