compressed-tensors 0.5.0__py3-none-any.whl → 0.7.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (51)
  1. compressed_tensors/__init__.py +1 -0
  2. compressed_tensors/base.py +2 -0
  3. compressed_tensors/compressors/__init__.py +6 -12
  4. compressed_tensors/compressors/base.py +137 -9
  5. compressed_tensors/compressors/helpers.py +6 -6
  6. compressed_tensors/compressors/model_compressors/__init__.py +17 -0
  7. compressed_tensors/compressors/{model_compressor.py → model_compressors/model_compressor.py} +99 -43
  8. compressed_tensors/compressors/quantized_compressors/__init__.py +18 -0
  9. compressed_tensors/compressors/{naive_quantized.py → quantized_compressors/base.py} +64 -62
  10. compressed_tensors/compressors/quantized_compressors/naive_quantized.py +140 -0
  11. compressed_tensors/compressors/quantized_compressors/pack_quantized.py +211 -0
  12. compressed_tensors/compressors/sparse_compressors/__init__.py +18 -0
  13. compressed_tensors/compressors/sparse_compressors/base.py +110 -0
  14. compressed_tensors/compressors/{dense.py → sparse_compressors/dense.py} +3 -3
  15. compressed_tensors/compressors/{sparse_bitmask.py → sparse_compressors/sparse_bitmask.py} +14 -59
  16. compressed_tensors/compressors/sparse_quantized_compressors/__init__.py +16 -0
  17. compressed_tensors/compressors/{marlin_24.py → sparse_quantized_compressors/marlin_24.py} +3 -3
  18. compressed_tensors/config/base.py +6 -1
  19. compressed_tensors/linear/__init__.py +13 -0
  20. compressed_tensors/linear/compressed_linear.py +87 -0
  21. compressed_tensors/quantization/__init__.py +1 -0
  22. compressed_tensors/quantization/cache.py +201 -0
  23. compressed_tensors/quantization/lifecycle/apply.py +63 -9
  24. compressed_tensors/quantization/lifecycle/calibration.py +7 -7
  25. compressed_tensors/quantization/lifecycle/compressed.py +3 -1
  26. compressed_tensors/quantization/lifecycle/forward.py +126 -44
  27. compressed_tensors/quantization/lifecycle/frozen.py +6 -1
  28. compressed_tensors/quantization/lifecycle/helpers.py +0 -20
  29. compressed_tensors/quantization/lifecycle/initialize.py +138 -55
  30. compressed_tensors/quantization/observers/__init__.py +1 -0
  31. compressed_tensors/quantization/observers/base.py +54 -14
  32. compressed_tensors/quantization/observers/min_max.py +8 -0
  33. compressed_tensors/quantization/observers/mse.py +162 -0
  34. compressed_tensors/quantization/quant_args.py +102 -24
  35. compressed_tensors/quantization/quant_config.py +14 -2
  36. compressed_tensors/quantization/quant_scheme.py +12 -13
  37. compressed_tensors/quantization/utils/helpers.py +44 -19
  38. compressed_tensors/utils/__init__.py +1 -0
  39. compressed_tensors/utils/helpers.py +30 -1
  40. compressed_tensors/utils/offload.py +14 -2
  41. compressed_tensors/utils/permute.py +70 -0
  42. compressed_tensors/utils/safetensors_load.py +2 -0
  43. compressed_tensors/utils/semi_structured_conversions.py +1 -0
  44. compressed_tensors/version.py +1 -1
  45. {compressed_tensors-0.5.0.dist-info → compressed_tensors-0.7.0.dist-info}/METADATA +35 -23
  46. compressed_tensors-0.7.0.dist-info/RECORD +59 -0
  47. {compressed_tensors-0.5.0.dist-info → compressed_tensors-0.7.0.dist-info}/WHEEL +1 -1
  48. compressed_tensors/compressors/pack_quantized.py +0 -219
  49. compressed_tensors-0.5.0.dist-info/RECORD +0 -48
  50. {compressed_tensors-0.5.0.dist-info → compressed_tensors-0.7.0.dist-info}/LICENSE +0 -0
  51. {compressed_tensors-0.5.0.dist-info → compressed_tensors-0.7.0.dist-info}/top_level.txt +0 -0
compressed_tensors/quantization/lifecycle/initialize.py
@@ -17,17 +17,19 @@ import logging
 from typing import Optional

 import torch
-from accelerate.hooks import add_hook_to_module, remove_hook_from_module
-from accelerate.utils import PrefixedDataset
+from compressed_tensors.quantization.cache import KVCacheScaleType
 from compressed_tensors.quantization.lifecycle.forward import (
     wrap_module_forward_quantized,
+    wrap_module_forward_quantized_attn,
 )
 from compressed_tensors.quantization.quant_args import (
+    ActivationOrdering,
     QuantizationArgs,
     QuantizationStrategy,
 )
 from compressed_tensors.quantization.quant_config import QuantizationStatus
 from compressed_tensors.quantization.quant_scheme import QuantizationScheme
+from compressed_tensors.quantization.utils import is_kv_cache_quant_scheme
 from compressed_tensors.utils import get_execution_device, is_module_offloaded
 from torch.nn import Module, Parameter

@@ -43,6 +45,7 @@ _LOGGER = logging.getLogger(__name__)
 def initialize_module_for_quantization(
     module: Module,
     scheme: Optional[QuantizationScheme] = None,
+    force_zero_point: bool = True,
 ):
     """
     attaches appropriate scales, zero points, and observers to a layer
@@ -54,61 +57,93 @@ def initialize_module_for_quantization(
     :param scheme: scheme to use for quantization. if None is provided,
         will attempt to use scheme stored in the module under `quantization_scheme`,
         if not provided, the layer will be skipped
+    :param force_zero_point: whether to force initialization of a zero point for
+        symmetric quantization
     """
     scheme = scheme or getattr(module, "quantization_scheme", None)
     if scheme is None:
         # no scheme passed and layer not targeted for quantization - skip
         return

-    if scheme.input_activations is not None:
-        _initialize_scale_zero_point_observer(module, "input", scheme.input_activations)
-    if scheme.weights is not None:
-        if hasattr(module, "weight"):
-            weight_shape = None
-            if isinstance(module, torch.nn.Linear):
-                weight_shape = module.weight.shape
-            _initialize_scale_zero_point_observer(
-                module, "weight", scheme.weights, weight_shape=weight_shape
-            )
-        else:
-            _LOGGER.warning(
-                f"module type {type(module)} targeted for weight quantization but "
-                "has no attribute weight, skipping weight quantization "
-                f"for {type(module)}"
-            )
-    if scheme.output_activations is not None:
-        _initialize_scale_zero_point_observer(
-            module, "output", scheme.output_activations
-        )
+    if is_attention_module(module):
+        # wrap forward call of module to perform
+        # quantized actions based on calltime status
+        wrap_module_forward_quantized_attn(module, scheme)
+        _initialize_attn_scales(module)

-    module.quantization_scheme = scheme
-    module.quantization_status = QuantizationStatus.INITIALIZED
+    else:

-    offloaded = False
-    if is_module_offloaded(module):
-        offloaded = True
-        hook = module._hf_hook
-        prefix_dict = module._hf_hook.weights_map
-        new_prefix = {}
-
-        # recreate the prefix dict (since it is immutable)
-        # and add quantization parameters
-        for key, data in module.named_parameters():
-            if key not in prefix_dict:
-                new_prefix[f"{prefix_dict.prefix}{key}"] = data
+        if scheme.input_activations is not None:
+            _initialize_scale_zero_point_observer(
+                module,
+                "input",
+                scheme.input_activations,
+                force_zero_point=force_zero_point,
+            )
+        if scheme.weights is not None:
+            if hasattr(module, "weight"):
+                weight_shape = None
+                if isinstance(module, torch.nn.Linear):
+                    weight_shape = module.weight.shape
+                _initialize_scale_zero_point_observer(
+                    module,
+                    "weight",
+                    scheme.weights,
+                    weight_shape=weight_shape,
+                    force_zero_point=force_zero_point,
+                )
             else:
-                new_prefix[f"{prefix_dict.prefix}{key}"] = prefix_dict[key]
-        new_prefix_dict = PrefixedDataset(new_prefix, prefix_dict.prefix)
-        remove_hook_from_module(module)
-
-    # wrap forward call of module to perform quantized actions based on calltime status
-    wrap_module_forward_quantized(module, scheme)
-
-    if offloaded:
-        # we need to re-add the hook for offloading now that we've wrapped forward
-        add_hook_to_module(module, hook)
-        if prefix_dict is not None:
-            module._hf_hook.weights_map = new_prefix_dict
+                _LOGGER.warning(
+                    f"module type {type(module)} targeted for weight quantization but "
+                    "has no attribute weight, skipping weight quantization "
+                    f"for {type(module)}"
+                )
+
+        if scheme.output_activations is not None:
+            if not is_kv_cache_quant_scheme(scheme):
+                _initialize_scale_zero_point_observer(
+                    module, "output", scheme.output_activations
+                )
+
+        module.quantization_scheme = scheme
+        module.quantization_status = QuantizationStatus.INITIALIZED
+
+        offloaded = False
+        if is_module_offloaded(module):
+            try:
+                from accelerate.hooks import add_hook_to_module, remove_hook_from_module
+                from accelerate.utils import PrefixedDataset
+            except ModuleNotFoundError:
+                raise ModuleNotFoundError(
+                    "Offloaded model detected. To use CPU offloading with "
+                    "compressed-tensors the `accelerate` package must be installed, "
+                    "run `pip install compressed-tensors[accelerate]`"
+                )
+
+            offloaded = True
+            hook = module._hf_hook
+            prefix_dict = module._hf_hook.weights_map
+            new_prefix = {}
+
+            # recreate the prefix dict (since it is immutable)
+            # and add quantization parameters
+            for key, data in module.named_parameters():
+                if key not in prefix_dict:
+                    new_prefix[f"{prefix_dict.prefix}{key}"] = data
+                else:
+                    new_prefix[f"{prefix_dict.prefix}{key}"] = prefix_dict[key]
+            new_prefix_dict = PrefixedDataset(new_prefix, prefix_dict.prefix)
+            remove_hook_from_module(module)
+
+        # wrap forward call of module to perform
+        # quantized actions based on calltime status
+        wrap_module_forward_quantized(module, scheme)
+
+        if offloaded:
+            # we need to re-add the hook for offloading now that we've wrapped forward
+            add_hook_to_module(module, hook)
+            if prefix_dict is not None:
+                module._hf_hook.weights_map = new_prefix_dict


 def _initialize_scale_zero_point_observer(
@@ -116,6 +151,7 @@ def _initialize_scale_zero_point_observer(
     base_name: str,
     quantization_args: QuantizationArgs,
     weight_shape: Optional[torch.Size] = None,
+    force_zero_point: bool = True,
 ):
     # initialize observer module and attach as submodule
     observer = quantization_args.get_observer()
@@ -136,21 +172,68 @@ def _initialize_scale_zero_point_observer(
             # (output_channels, 1)
             expected_shape = (weight_shape[0], 1)
         elif quantization_args.strategy == QuantizationStrategy.GROUP:
+            num_groups = weight_shape[1] // quantization_args.group_size
             expected_shape = (
                 weight_shape[0],
-                weight_shape[1] // quantization_args.group_size,
+                max(num_groups, 1)
             )

-    # initializes empty scale and zero point parameters for the module
+    scale_dtype = module.weight.dtype
+    if scale_dtype not in [torch.float16, torch.bfloat16, torch.float32]:
+        scale_dtype = torch.float16
+
+    # initializes empty scale, zero point, and g_idx parameters for the module
     init_scale = Parameter(
-        torch.empty(expected_shape, dtype=module.weight.dtype, device=device),
+        torch.empty(expected_shape, dtype=scale_dtype, device=device),
         requires_grad=False,
     )
     module.register_parameter(f"{base_name}_scale", init_scale)

-    zp_dtype = quantization_args.pytorch_dtype()
-    init_zero_point = Parameter(
-        torch.empty(expected_shape, device=device, dtype=zp_dtype),
+    if force_zero_point or not quantization_args.symmetric:
+        zp_dtype = quantization_args.pytorch_dtype()
+        init_zero_point = Parameter(
+            torch.zeros(expected_shape, device=device, dtype=zp_dtype),
+            requires_grad=False,
+        )
+        module.register_parameter(f"{base_name}_zero_point", init_zero_point)
+
+    # only grouped activation ordering has g_idx
+    if quantization_args.actorder == ActivationOrdering.GROUP:
+        g_idx_shape = (weight_shape[1],)
+        g_idx_dtype = torch.int
+        init_g_idx = Parameter(
+            torch.full(g_idx_shape, -1, device=device, dtype=g_idx_dtype),
+            requires_grad=False,
+        )
+        module.register_parameter(f"{base_name}_g_idx", init_g_idx)
+
+
+def is_attention_module(module: Module):
+    return "attention" in module.__class__.__name__.lower() and (
+        hasattr(module, "k_proj")
+        or hasattr(module, "v_proj")
+        or hasattr(module, "qkv_proj")
+    )
+
+
+def _initialize_attn_scales(module: Module) -> None:
+    """Initlaize k_scale, v_scale for self_attn"""
+
+    expected_shape = 1  # per tensor
+
+    param = next(module.parameters())
+    scale_dtype = param.dtype
+    device = param.device
+
+    init_scale = Parameter(
+        torch.empty(expected_shape, dtype=scale_dtype, device=device),
+        requires_grad=False,
+    )
+
+    module.register_parameter(KVCacheScaleType.KEY.value, init_scale)
+
+    init_scale = Parameter(
+        torch.empty(expected_shape, dtype=scale_dtype, device=device),
         requires_grad=False,
     )
-    module.register_parameter(f"{base_name}_zero_point", init_zero_point)
+    module.register_parameter(KVCacheScaleType.VALUE.value, init_scale)
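Note on the initialize.py changes above: zero points are now registered only when forced or when the scheme is asymmetric, g_idx appears only under grouped activation ordering, and attention modules get per-tensor k/v scales instead of observers. A minimal usage sketch (not part of this diff; it assumes QuantizationScheme/QuantizationArgs accept the fields shown, as in prior releases):

import torch
from compressed_tensors.quantization import (
    QuantizationArgs,
    QuantizationScheme,
    initialize_module_for_quantization,
)

linear = torch.nn.Linear(128, 64)
scheme = QuantizationScheme(
    targets=["Linear"],  # target selection assumed for this sketch
    weights=QuantizationArgs(num_bits=4, symmetric=True, strategy="group", group_size=32),
)

# with a symmetric scheme and force_zero_point=False, only weight_scale is registered;
# weight_zero_point is skipped by the new conditional, and weight_g_idx is only added
# when actorder is set to grouped activation ordering
initialize_module_for_quantization(linear, scheme, force_zero_point=False)
print(hasattr(linear, "weight_scale"), hasattr(linear, "weight_zero_point"))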
compressed_tensors/quantization/observers/__init__.py
@@ -19,3 +19,4 @@ from .helpers import *
 from .base import *
 from .memoryless import *
 from .min_max import *
+from .mse import *
compressed_tensors/quantization/observers/base.py
@@ -13,6 +13,7 @@
 # limitations under the License.

 import logging
+from math import ceil
 from typing import Any, Iterable, Optional, Tuple, Union

 import torch
@@ -21,6 +22,7 @@ from compressed_tensors.quantization.quant_args import (
     QuantizationStrategy,
 )
 from compressed_tensors.registry.registry import RegistryMixin
+from compressed_tensors.utils import safe_permute
 from torch import FloatTensor, IntTensor, Tensor
 from torch.nn import Module

@@ -46,15 +48,18 @@ class Observer(Module, RegistryMixin):
         self._num_observed_tokens = None

     @torch.no_grad()
-    def forward(self, observed: Tensor) -> Tuple[FloatTensor, IntTensor]:
+    def forward(
+        self, observed: Tensor, g_idx: Optional[Tensor] = None
+    ) -> Tuple[FloatTensor, IntTensor]:
         """
         maps directly to get_qparams
-        :param observed: optional observed tensor to calculate quantization parameters
-            from
+        :param observed: optional observed tensor from which to calculate
+            quantization parameters
+        :param g_idx: optional mapping from column index to group index
         :return: tuple of scale and zero point based on last observed value
         """
         self.record_observed_tokens(observed)
-        return self.get_qparams(observed=observed)
+        return self.get_qparams(observed=observed, g_idx=g_idx)

     def calculate_qparams(
         self,
@@ -77,7 +82,9 @@ class Observer(Module, RegistryMixin):
         ...

     def get_qparams(
-        self, observed: Optional[Tensor] = None
+        self,
+        observed: Optional[Tensor] = None,
+        g_idx: Optional[Tensor] = None,
     ) -> Tuple[FloatTensor, IntTensor]:
         """
         Convenience function to wrap overwritten calculate_qparams
@@ -86,6 +93,7 @@ class Observer(Module, RegistryMixin):

         :param observed: optional observed tensor to calculate quantization parameters
             from
+        :param g_idx: optional mapping from column index to group index
         :return: tuple of scale and zero point based on last observed value
         """
         if observed is not None:
@@ -97,20 +105,42 @@ class Observer(Module, RegistryMixin):
                 self._scale, self._zero_point = self.calculate_qparams(observed)

             elif self.quantization_args.strategy == QuantizationStrategy.GROUP:
+                rows = observed.shape[0]
                 columns = observed.shape[1]
-                scales, zero_points = [], []
-                group_idxs = range(0, columns, self.quantization_args.group_size)
-                for group_id, group_idx in enumerate(group_idxs):
+                num_groups = int(ceil(columns / group_size))
+                self._scale = torch.empty(
+                    (rows, num_groups), dtype=observed.dtype, device=observed.device
+                )
+                zp_dtype = self.quantization_args.pytorch_dtype()
+                self._zero_point = torch.empty(
+                    (rows, num_groups), dtype=zp_dtype, device=observed.device
+                )
+
+                # support column-order (default) quantization as well as other orderings
+                # such as activation ordering. Below checks if g_idx has initialized
+                is_column_order = g_idx is None or -1 in g_idx
+                if is_column_order:
+                    group_sizes = torch.full((num_groups,), group_size, dtype=torch.int)
+                else:
+                    group_indices, group_sizes = torch.unique(g_idx, return_counts=True)
+                    group_sizes = group_sizes[torch.argsort(group_indices)]
+
+                    perm = torch.argsort(g_idx)
+                    observed = safe_permute(observed, perm, dim=1)
+
+                # TODO: experiment with vectorizing for loop for performance
+                end = 0
+                for group_index, group_count in enumerate(group_sizes):
+                    start = end
+                    end = start + group_count
                     scale, zero_point = self.get_qparams_along_dim(
-                        observed[:, group_idx : (group_idx + group_size)],
+                        observed[:, start:end],
                         0,
-                        tensor_id=group_id,
+                        tensor_id=group_index,
                     )
-                    scales.append(scale)
-                    zero_points.append(zero_point)

-                self._scale = torch.cat(scales, dim=1, out=self._scale)
-                self._zero_point = torch.cat(zero_points, dim=1, out=self._zero_point)
+                    self._scale[:, group_index] = scale.squeeze(1)
+                    self._zero_point[:, group_index] = zero_point.squeeze(1)

             elif self.quantization_args.strategy == QuantizationStrategy.CHANNEL:
                 # assume observed is transposed, because its the output, hence use dim 0
@@ -132,6 +162,8 @@ class Observer(Module, RegistryMixin):
         dim: Union[int, Iterable[int]],
         tensor_id: Optional[Any] = None,
     ):
+        if isinstance(dim, int):
+            dim = [dim]
         dim = set(dim)

         reduce_dims = tuple(idx for idx in range(observed.ndim) if idx not in dim)
@@ -171,3 +203,11 @@ class Observer(Module, RegistryMixin):
         # observed_tokens (batch_size * sequence_length)
         observed_tokens, _ = batch_tensor.shape
         self._num_observed_tokens += observed_tokens
+
+    def reset(self):
+        """
+        Reset the state of the observer
+        """
+        self._num_observed_tokens = None
+        self._scale = None
+        self._zero_point = None
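The GROUP branch of Observer.get_qparams above now supports activation ordering: when a g_idx mapping is present, per-group column counts come from torch.unique and the columns are permuted so each group is contiguous before slicing. A standalone illustration of that partitioning (plain torch, not the package API; a simple min/max stands in for the per-group observer call):

import torch

observed = torch.randn(4, 8)                     # (rows, columns)
group_size = 4
g_idx = torch.tensor([0, 1, 0, 1, 1, 0, 0, 1])   # column -> group mapping

# per-group column counts, ordered by group index
group_indices, group_sizes = torch.unique(g_idx, return_counts=True)
group_sizes = group_sizes[torch.argsort(group_indices)]

# permute columns so each group's columns are contiguous, then slice [start:end) per group
observed = observed[:, torch.argsort(g_idx)]
end = 0
for group_index, group_count in enumerate(group_sizes.tolist()):
    start, end = end, end + group_count
    group_min = observed[:, start:end].amin(dim=1, keepdim=True)
    group_max = observed[:, start:end].amax(dim=1, keepdim=True)
    # scale/zero_point for column group `group_index` would be derived from (group_min, group_max)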
compressed_tensors/quantization/observers/min_max.py
@@ -94,3 +94,11 @@ class MovingAverageMinMaxObserver(Observer):
         return self.calculate_qparams(
             observed, reduce_dims=reduce_dims, tensor_id=tensor_id
         )
+
+    def reset(self):
+        """
+        Reset the state of the observer, including min and maximum values
+        """
+        super().reset()
+        self.min_val = {}
+        self.max_val = {}
compressed_tensors/quantization/observers/mse.py (new file)
@@ -0,0 +1,162 @@
+# Copyright (c) 2021 - present / Neuralmagic, Inc. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#    http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from typing import Any, Optional, Tuple
+
+import torch
+from compressed_tensors.quantization.observers.base import Observer
+from compressed_tensors.quantization.observers.helpers import calculate_qparams
+from compressed_tensors.quantization.quant_args import QuantizationArgs
+from torch import FloatTensor, IntTensor, Tensor
+
+
+__all__ = ["MovingAverageMSEObserver"]
+
+
+@Observer.register("mse")
+class MovingAverageMSEObserver(Observer):
+    """
+    Implements a dynamic quantization observer that sets the scale and
+    zero point based on a moving average of the mse-clipped min and max observed values
+    """
+
+    def __init__(
+        self,
+        quantization_args: QuantizationArgs,
+        averaging_constant: float = 0.01,
+        grid: float = 100.0,
+        maxshrink: float = 0.80,
+        norm: float = 2.4,
+    ):
+        super().__init__(quantization_args=quantization_args)
+
+        self.min_val = {}
+        self.max_val = {}
+        self.averaging_constant = averaging_constant
+        self.grid = grid
+        self.maxshrink = maxshrink
+        self.norm = norm
+
+    def calculate_mse_min_max(
+        self,
+        observed: Tensor,
+        reduce_dims: Optional[Tuple[int]] = None,
+    ):
+        """
+        Computes the mse-clipped min and max values of the observed tensor by
+        optimizing for quantization error
+
+        :param observed: observed tensor to calculate quantization parameters for
+        :param reduce_dims: optional tuple of dimensions to reduce along,
+            returned values will be shaped (1,) along the reduced dimensions
+        :return: tuple of min and max values derived from the observed tensor
+        """
+        from compressed_tensors.quantization.lifecycle import fake_quantize
+
+        if not reduce_dims:
+            absolute_min_val, absolute_max_val = torch.aminmax(observed)
+        else:
+            absolute_min_val = torch.amin(observed, dim=reduce_dims, keepdims=True)
+            absolute_max_val = torch.amax(observed, dim=reduce_dims, keepdims=True)
+
+        best = torch.full(absolute_min_val.shape, float("inf"))
+        min_val = torch.ones(absolute_min_val.shape)
+        max_val = torch.zeros(absolute_max_val.shape)
+        for i in range(int(self.maxshrink * self.grid)):
+            p = 1 - i / self.grid
+            shrinked_min_val = p * absolute_min_val
+            shrinked_max_val = p * absolute_max_val
+
+            candidate_scales, candidate_zero_points = calculate_qparams(
+                shrinked_min_val, shrinked_max_val, self.quantization_args
+            )
+            q = fake_quantize(
+                observed,
+                candidate_scales,
+                candidate_zero_points,
+                self.quantization_args,
+            )
+
+            q -= observed
+            q.abs_()
+            q.pow_(self.norm)
+            if not reduce_dims:
+                err = torch.sum(q)
+            else:
+                err = torch.sum(q, reduce_dims, keepdims=True)
+
+            tmp = err < best
+            if torch.any(tmp):
+                best[tmp] = err[tmp]
+                min_val[tmp] = shrinked_min_val[tmp]
+                max_val[tmp] = shrinked_max_val[tmp]
+        return min_val, max_val
+
+    def calculate_qparams(
+        self,
+        observed: Tensor,
+        reduce_dims: Optional[Tuple[int]] = None,
+        tensor_id: Optional[Any] = None,
+    ) -> Tuple[FloatTensor, IntTensor]:
+        """
+        Updates the mse-clipped min and max values of the observed tensor using
+        a moving average smoothed by the averaging_constant
+
+        :param observed: observed tensor to calculate quantization parameters for
+        :param reduce_dims: optional tuple of dimensions to reduce along,
+            returned scale and zero point will be shaped (1,) along the
+            reduced dimensions
+        :param tensor_id: Optional id if different ranges of observed tensors are
+            passed, useful for sharding tensors by group_size
+        :return: tuple of scale and zero point derived from the observed tensor
+        """
+        min_val, max_val = self.calculate_mse_min_max(observed, reduce_dims)
+
+        running_min_val = self.min_val.get(tensor_id, None)
+        running_max_val = self.max_val.get(tensor_id, None)
+
+        if running_min_val is None or running_max_val is None:
+            updated_min_val = min_val
+            updated_max_val = max_val
+        else:
+            updated_min_val = running_min_val + self.averaging_constant * (
+                min_val - running_min_val
+            )
+            updated_max_val = running_max_val + self.averaging_constant * (
+                max_val - running_max_val
+            )
+
+        tensor_id = tensor_id or "default"
+        self.min_val[tensor_id] = updated_min_val
+        self.max_val[tensor_id] = updated_max_val
+
+        return calculate_qparams(
+            updated_min_val, updated_max_val, self.quantization_args
+        )
+
+    def get_qparams_along_dim(
+        self, observed, dim: int, tensor_id: Optional[Any] = None
+    ):
+        reduce_dims = tuple(idx for idx in range(observed.ndim) if idx != dim)
+        return self.calculate_qparams(
+            observed, reduce_dims=reduce_dims, tensor_id=tensor_id
+        )
+
+    def reset(self):
+        """
+        Reset the state of the observer, including min and maximum values
+        """
+        super().reset()
+        self.min_val = {}
+        self.max_val = {}
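The new MovingAverageMSEObserver registers under the "mse" key, so it can be selected anywhere an observer name is accepted. A brief usage sketch (assumes QuantizationArgs exposes an observer field and that get_observer() resolves it through the registry, as used elsewhere in this package; not a confirmed snippet from the release):

import torch
from compressed_tensors.quantization.quant_args import QuantizationArgs

args = QuantizationArgs(num_bits=8, symmetric=True, observer="mse")
observer = args.get_observer()  # MovingAverageMSEObserver via @Observer.register("mse")

weights = torch.randn(256, 512)
# each call grid-searches shrunken min/max candidates against fake-quantize error,
# then folds the winners into the moving average kept in min_val/max_val
scale, zero_point = observer(weights)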