compressed-tensors 0.7.0__py3-none-any.whl → 0.8.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (26)
  1. compressed_tensors/compressors/sparse_quantized_compressors/marlin_24.py +1 -1
  2. compressed_tensors/config/base.py +60 -2
  3. compressed_tensors/quantization/__init__.py +0 -1
  4. compressed_tensors/quantization/lifecycle/__init__.py +0 -2
  5. compressed_tensors/quantization/lifecycle/apply.py +1 -16
  6. compressed_tensors/quantization/lifecycle/forward.py +25 -86
  7. compressed_tensors/quantization/lifecycle/initialize.py +23 -25
  8. compressed_tensors/quantization/quant_args.py +28 -15
  9. compressed_tensors/quantization/quant_scheme.py +3 -0
  10. compressed_tensors/quantization/utils/helpers.py +125 -8
  11. compressed_tensors/registry/registry.py +1 -1
  12. compressed_tensors/version.py +1 -1
  13. {compressed_tensors-0.7.0.dist-info → compressed_tensors-0.8.0.dist-info}/METADATA +1 -1
  14. {compressed_tensors-0.7.0.dist-info → compressed_tensors-0.8.0.dist-info}/RECORD +17 -26
  15. {compressed_tensors-0.7.0.dist-info → compressed_tensors-0.8.0.dist-info}/WHEEL +1 -1
  16. compressed_tensors/quantization/cache.py +0 -201
  17. compressed_tensors/quantization/lifecycle/calibration.py +0 -70
  18. compressed_tensors/quantization/lifecycle/frozen.py +0 -55
  19. compressed_tensors/quantization/observers/__init__.py +0 -22
  20. compressed_tensors/quantization/observers/base.py +0 -213
  21. compressed_tensors/quantization/observers/helpers.py +0 -111
  22. compressed_tensors/quantization/observers/memoryless.py +0 -56
  23. compressed_tensors/quantization/observers/min_max.py +0 -104
  24. compressed_tensors/quantization/observers/mse.py +0 -162
  25. {compressed_tensors-0.7.0.dist-info → compressed_tensors-0.8.0.dist-info}/LICENSE +0 -0
  26. {compressed_tensors-0.7.0.dist-info → compressed_tensors-0.8.0.dist-info}/top_level.txt +0 -0
compressed_tensors/quantization/observers/base.py
@@ -1,213 +0,0 @@
- # Copyright (c) 2021 - present / Neuralmagic, Inc. All Rights Reserved.
- #
- # Licensed under the Apache License, Version 2.0 (the "License");
- # you may not use this file except in compliance with the License.
- # You may obtain a copy of the License at
- #
- #    http://www.apache.org/licenses/LICENSE-2.0
- #
- # Unless required by applicable law or agreed to in writing,
- # software distributed under the License is distributed on an "AS IS" BASIS,
- # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- # See the License for the specific language governing permissions and
- # limitations under the License.
-
- import logging
- from math import ceil
- from typing import Any, Iterable, Optional, Tuple, Union
-
- import torch
- from compressed_tensors.quantization.quant_args import (
-     QuantizationArgs,
-     QuantizationStrategy,
- )
- from compressed_tensors.registry.registry import RegistryMixin
- from compressed_tensors.utils import safe_permute
- from torch import FloatTensor, IntTensor, Tensor
- from torch.nn import Module
-
-
- _LOGGER = logging.getLogger(__name__)
-
-
- __all__ = ["Observer"]
-
-
- class Observer(Module, RegistryMixin):
-     """
-     Base Observer class to be subclassed for specific implementations.
-     Subclasses should override `calculate_qparams` to return a scale, zero_point
-     pair
-     """
-
-     def __init__(self, quantization_args: QuantizationArgs):
-         self.quantization_args: QuantizationArgs = quantization_args
-         super().__init__()
-         self._scale = None
-         self._zero_point = None
-         self._num_observed_tokens = None
-
-     @torch.no_grad()
-     def forward(
-         self, observed: Tensor, g_idx: Optional[Tensor] = None
-     ) -> Tuple[FloatTensor, IntTensor]:
-         """
-         Maps directly to get_qparams.
-
-         :param observed: observed tensor from which to calculate
-             quantization parameters
-         :param g_idx: optional mapping from column index to group index
-         :return: tuple of scale and zero point based on last observed value
-         """
-         self.record_observed_tokens(observed)
-         return self.get_qparams(observed=observed, g_idx=g_idx)
-
-     def calculate_qparams(
-         self,
-         observed: Tensor,
-         reduce_dims: Optional[Tuple[int]] = None,
-     ) -> Tuple[FloatTensor, IntTensor]:
-         """
-         :param observed: observed tensor to calculate quantization parameters for
-         :param reduce_dims: optional tuple of dimensions to reduce along,
-             returned scale and zero point will be shaped (1,) along the
-             reduced dimensions
-         :return: tuple of scale and zero point derived from the observed tensor
-         """
-         raise NotImplementedError(f"{self.__class__} must implement calculate_qparams")
-
-     def post_calculate_qparams(self) -> None:
-         """
-         Run any observer-specific logic after calculate_qparams has run
-         """
-         ...
-
-     def get_qparams(
-         self,
-         observed: Optional[Tensor] = None,
-         g_idx: Optional[Tensor] = None,
-     ) -> Tuple[FloatTensor, IntTensor]:
-         """
-         Convenience function that wraps the overridden calculate_qparams,
-         makes the observed tensor optional, and tracks the latest
-         calculated scale and zero point.
-
-         :param observed: optional observed tensor to calculate quantization
-             parameters from
-         :param g_idx: optional mapping from column index to group index
-         :return: tuple of scale and zero point based on last observed value
-         """
-         if observed is not None:
-             group_size = self.quantization_args.group_size
-
-             if self.quantization_args.strategy == QuantizationStrategy.TENSOR:
-
-                 # re-calculate scale and zero point, update the stored value
-                 self._scale, self._zero_point = self.calculate_qparams(observed)
-
-             elif self.quantization_args.strategy == QuantizationStrategy.GROUP:
-                 rows = observed.shape[0]
-                 columns = observed.shape[1]
-                 num_groups = int(ceil(columns / group_size))
-                 self._scale = torch.empty(
-                     (rows, num_groups), dtype=observed.dtype, device=observed.device
-                 )
-                 zp_dtype = self.quantization_args.pytorch_dtype()
-                 self._zero_point = torch.empty(
-                     (rows, num_groups), dtype=zp_dtype, device=observed.device
-                 )
-
-                 # support column-order (default) quantization as well as other orderings
-                 # such as activation ordering. Below checks whether g_idx has been initialized
-                 is_column_order = g_idx is None or -1 in g_idx
-                 if is_column_order:
-                     group_sizes = torch.full((num_groups,), group_size, dtype=torch.int)
-                 else:
-                     group_indices, group_sizes = torch.unique(g_idx, return_counts=True)
-                     group_sizes = group_sizes[torch.argsort(group_indices)]
-
-                     perm = torch.argsort(g_idx)
-                     observed = safe_permute(observed, perm, dim=1)
-
-                 # TODO: experiment with vectorizing for loop for performance
-                 end = 0
-                 for group_index, group_count in enumerate(group_sizes):
-                     start = end
-                     end = start + group_count
-                     scale, zero_point = self.get_qparams_along_dim(
-                         observed[:, start:end],
-                         0,
-                         tensor_id=group_index,
-                     )
-
-                     self._scale[:, group_index] = scale.squeeze(1)
-                     self._zero_point[:, group_index] = zero_point.squeeze(1)
-
-             elif self.quantization_args.strategy == QuantizationStrategy.CHANNEL:
-                 # assume observed is transposed, because it is the output, hence use dim 0
-                 self._scale, self._zero_point = self.get_qparams_along_dim(observed, 0)
-
-             elif self.quantization_args.strategy == QuantizationStrategy.TOKEN:
-                 # use dims {0, 1}, assuming observed.shape = [batch, token, hidden];
-                 # the returned shape should be [batch, token]
-                 self._scale, self._zero_point = self.get_qparams_along_dim(
-                     observed,
-                     dim={0, 1},
-                 )
-
-         return self._scale, self._zero_point
-
-     def get_qparams_along_dim(
-         self,
-         observed,
-         dim: Union[int, Iterable[int]],
-         tensor_id: Optional[Any] = None,
-     ):
-         if isinstance(dim, int):
-             dim = [dim]
-         dim = set(dim)
-
-         reduce_dims = tuple(idx for idx in range(observed.ndim) if idx not in dim)
-         return self.calculate_qparams(
-             observed, reduce_dims=reduce_dims, tensor_id=tensor_id
-         )
-
-     def record_observed_tokens(self, batch_tensor: Tensor):
-         """
-         Counts the number of tokens observed during the
-         forward passes. The count is aggregated in the
-         _num_observed_tokens attribute of the class.
-
-         Note: The batch_tensor is expected to have two dimensions
-         (batch_size * sequence_length, num_features). This is the
-         general shape expected by the forward pass of the expert
-         layers in a MOE model. If the input tensor does not have
-         two dimensions, the _num_observed_tokens attribute will be set
-         to None.
-         """
-         if not isinstance(batch_tensor, Tensor):
-             raise ValueError(f"Expected value to be a tensor, got {type(batch_tensor)}")
-
-         if batch_tensor.ndim != 2:
-             _LOGGER.debug(
-                 "The input tensor is expected to have two dimensions "
-                 "(batch_size * sequence_length, num_features). "
-                 f"The input tensor has {batch_tensor.ndim} dimensions."
-             )
-             return
-
-         if self._num_observed_tokens is None:
-             # initialize the count
-             self._num_observed_tokens = 0
-
-         # batch_tensor (batch_size * sequence_length, num_features)
-         # observed_tokens (batch_size * sequence_length)
-         observed_tokens, _ = batch_tensor.shape
-         self._num_observed_tokens += observed_tokens
-
-     def reset(self):
-         """
-         Reset the state of the observer
-         """
-         self._num_observed_tokens = None
-         self._scale = None
-         self._zero_point = None
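
For context on the removed API: in 0.7.0 an observer was registered by name via RegistryMixin, subclassed with a calculate_qparams override, and called like a regular torch module. The sketch below is illustrative only; the AbsMaxObserver class and the "absmax" registry name are hypothetical, and it assumes QuantizationArgs defaults to the tensor strategy.

import torch
from compressed_tensors.quantization.observers.base import Observer
from compressed_tensors.quantization.observers.helpers import calculate_qparams
from compressed_tensors.quantization.quant_args import QuantizationArgs


@Observer.register("absmax")  # hypothetical registry name, for illustration only
class AbsMaxObserver(Observer):
    def calculate_qparams(self, observed, reduce_dims=None, tensor_id=None):
        # a subclass only has to map an observed tensor to (scale, zero_point)
        if not reduce_dims:
            min_val, max_val = torch.aminmax(observed)
        else:
            min_val = torch.amin(observed, dim=reduce_dims, keepdims=True)
            max_val = torch.amax(observed, dim=reduce_dims, keepdims=True)
        return calculate_qparams(min_val, max_val, self.quantization_args)


args = QuantizationArgs(num_bits=8, symmetric=True)  # assumes default tensor strategy
observer = Observer.load_from_registry("absmax", quantization_args=args)
scale, zero_point = observer(torch.randn(16, 64))  # forward() -> get_qparams()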
compressed_tensors/quantization/observers/helpers.py
@@ -1,111 +0,0 @@
- # Copyright (c) 2021 - present / Neuralmagic, Inc. All Rights Reserved.
- #
- # Licensed under the Apache License, Version 2.0 (the "License");
- # you may not use this file except in compliance with the License.
- # You may obtain a copy of the License at
- #
- #    http://www.apache.org/licenses/LICENSE-2.0
- #
- # Unless required by applicable law or agreed to in writing,
- # software distributed under the License is distributed on an "AS IS" BASIS,
- # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- # See the License for the specific language governing permissions and
- # limitations under the License.
-
- from collections import Counter
- from typing import Tuple
-
- import torch
- from compressed_tensors.quantization.quant_args import (
-     FP8_DTYPE,
-     QuantizationArgs,
-     QuantizationType,
- )
- from torch import FloatTensor, IntTensor, Tensor
-
-
- __all__ = ["calculate_qparams", "get_observer_token_count", "calculate_range"]
-
-
- def get_observer_token_count(module: torch.nn.Module) -> Counter:
-     """
-     Parse the module and return the number of tokens observed by
-     each module's observer.
-
-     :param module: module to parse
-     :return: counter with the number of tokens observed by each observer
-     """
-     token_counts = Counter()
-     for name, module in module.named_modules():
-         if name.endswith(".input_observer"):
-             token_counts[
-                 name.replace(".input_observer", "")
-             ] = module._num_observed_tokens
-     return token_counts
-
-
- def calculate_qparams(
-     min_vals: Tensor, max_vals: Tensor, quantization_args: QuantizationArgs
- ) -> Tuple[FloatTensor, IntTensor]:
-     """
-     :param min_vals: tensor of min value(s) to calculate scale(s) and zero point(s)
-         from
-     :param max_vals: tensor of max value(s) to calculate scale(s) and zero point(s)
-         from
-     :param quantization_args: settings for quantization
-     :return: tuple of the calculated scale(s) and zero point(s)
-     """
-     min_vals = torch.min(min_vals, torch.zeros_like(min_vals))
-     max_vals = torch.max(max_vals, torch.zeros_like(max_vals))
-     device = min_vals.device
-
-     bit_min, bit_max = calculate_range(quantization_args, device)
-     bit_range = bit_max - bit_min
-     zp_dtype = quantization_args.pytorch_dtype()
-
-     if quantization_args.symmetric:
-         max_val_pos = torch.max(torch.abs(min_vals), torch.abs(max_vals))
-         scales = max_val_pos / (float(bit_range) / 2)
-         scales = torch.clamp(scales, min=torch.finfo(torch.float32).eps)
-         zero_points = torch.zeros(scales.shape, device=device, dtype=min_vals.dtype)
-     else:
-         scales = (max_vals - min_vals) / float(bit_range)
-         scales = torch.clamp(scales, min=torch.finfo(torch.float32).eps)
-         zero_points = bit_min - (min_vals / scales)
-         zero_points = torch.clamp(zero_points, bit_min, bit_max)
-
-     # match zero-points to quantized type
-     zero_points = zero_points.to(zp_dtype)
-
-     if scales.ndim == 0:
-         scales = scales.reshape(1)
-         zero_points = zero_points.reshape(1)
-
-     return scales, zero_points
-
-
- def calculate_range(quantization_args: QuantizationArgs, device: str) -> Tuple:
-     """
-     Calculates the effective quantization range for the given quantization args
-
-     :param quantization_args: quantization args to get the range of
-     :param device: device to store the range on
-     :return: tuple of endpoints for the given quantization range
-     """
-     if quantization_args.type == QuantizationType.INT:
-         bit_range = 2**quantization_args.num_bits
-         q_max = torch.tensor(bit_range / 2 - 1, device=device)
-         q_min = torch.tensor(-bit_range / 2, device=device)
-     elif quantization_args.type == QuantizationType.FLOAT:
-         if quantization_args.num_bits != 8:
-             raise ValueError(
-                 "Floating point quantization is only supported for 8 bits, "
-                 f"got {quantization_args.num_bits}"
-             )
-         fp_range_info = torch.finfo(FP8_DTYPE)
-         q_max = torch.tensor(fp_range_info.max, device=device)
-         q_min = torch.tensor(fp_range_info.min, device=device)
-     else:
-         raise ValueError(f"Invalid quantization type {quantization_args.type}")
-
-     return q_min, q_max
1
- # Copyright (c) 2021 - present / Neuralmagic, Inc. All Rights Reserved.
2
- #
3
- # Licensed under the Apache License, Version 2.0 (the "License");
4
- # you may not use this file except in compliance with the License.
5
- # You may obtain a copy of the License at
6
- #
7
- # http://www.apache.org/licenses/LICENSE-2.0
8
- #
9
- # Unless required by applicable law or agreed to in writing,
10
- # software distributed under the License is distributed on an "AS IS" BASIS,
11
- # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
- # See the License for the specific language governing permissions and
13
- # limitations under the License.
14
-
15
- from typing import Any, Optional, Tuple
16
-
17
- import torch
18
- from compressed_tensors.quantization.observers.base import Observer
19
- from compressed_tensors.quantization.observers.helpers import calculate_qparams
20
- from torch import FloatTensor, IntTensor, Tensor
21
-
22
-
23
- __all__ = ["MemorylessObserver"]
24
-
25
-
26
- @Observer.register("memoryless", alias=["dynamic"])
27
- class MemorylessObserver(Observer):
28
- """
29
- Implements a quantization observer that sets the scale and
30
- zero point based on the latest observed value without tracking state
31
- """
32
-
33
- def calculate_qparams(
34
- self,
35
- observed: Tensor,
36
- tensor_id: Optional[Any] = None,
37
- reduce_dims: Optional[Tuple[int]] = None,
38
- ) -> Tuple[FloatTensor, IntTensor]:
39
- """
40
- Returns the min and max values of observed tensor
41
-
42
- :param observed: observed tensor to calculate quantization parameters for
43
- :param tensor_id: optional id for tensor; not used for memoryless
44
- :param reduce_dims: optional tuple of dimensions to reduce along,
45
- returned scale and zero point will be shaped (1,) along the
46
- reduced dimensions
47
- :return: tuple of scale and zero point derived from the observed tensor
48
- """
49
-
50
- if not reduce_dims:
51
- min_val, max_val = torch.aminmax(observed)
52
- else:
53
- min_val = torch.amin(observed, dim=reduce_dims, keepdims=True)
54
- max_val = torch.amax(observed, dim=reduce_dims, keepdims=True)
55
-
56
- return calculate_qparams(min_val, max_val, self.quantization_args)
@@ -1,104 +0,0 @@
1
- # Copyright (c) 2021 - present / Neuralmagic, Inc. All Rights Reserved.
2
- #
3
- # Licensed under the Apache License, Version 2.0 (the "License");
4
- # you may not use this file except in compliance with the License.
5
- # You may obtain a copy of the License at
6
- #
7
- # http://www.apache.org/licenses/LICENSE-2.0
8
- #
9
- # Unless required by applicable law or agreed to in writing,
10
- # software distributed under the License is distributed on an "AS IS" BASIS,
11
- # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
- # See the License for the specific language governing permissions and
13
- # limitations under the License.
14
-
15
- from typing import Any, Optional, Tuple
16
-
17
- import torch
18
- from compressed_tensors.quantization.observers.base import Observer
19
- from compressed_tensors.quantization.observers.helpers import calculate_qparams
20
- from compressed_tensors.quantization.quant_args import QuantizationArgs
21
- from torch import FloatTensor, IntTensor, Tensor
22
-
23
-
24
- __all__ = ["MovingAverageMinMaxObserver"]
25
-
26
-
27
- @Observer.register("minmax")
28
- class MovingAverageMinMaxObserver(Observer):
29
- """
30
- Implements a dynamic quantization observer that sets the scale and
31
- zero point based on a moving average of the overall min and max observed values
32
- """
33
-
34
- def __init__(
35
- self, quantization_args: QuantizationArgs, averaging_constant: float = 0.01
36
- ):
37
- super().__init__(quantization_args=quantization_args)
38
-
39
- self.min_val = {}
40
- self.max_val = {}
41
- self.averaging_constant = averaging_constant
42
-
43
- def calculate_qparams(
44
- self,
45
- observed: Tensor,
46
- reduce_dims: Optional[Tuple[int]] = None,
47
- tensor_id: Optional[Any] = None,
48
- ) -> Tuple[FloatTensor, IntTensor]:
49
- """
50
- Updates the observed min and max using a moving average smoothed by the
51
- averaging_constant
52
-
53
- :param observed: observed tensor to calculate quantization parameters for
54
- :param reduce_dims: optional tuple of dimensions to reduce along,
55
- returned scale and zero point will be shaped (1,) along the
56
- reduced dimensions
57
- :param tensor_id: Optional id if different ranges of observed tensors are
58
- passed, useful for sharding tensors by group_size
59
- :return: tuple of scale and zero point derived from the observed tensor
60
- """
61
- tensor_id = tensor_id or "default"
62
-
63
- if not reduce_dims:
64
- min_val, max_val = torch.aminmax(observed)
65
- else:
66
- min_val = torch.amin(observed, dim=reduce_dims, keepdims=True)
67
- max_val = torch.amax(observed, dim=reduce_dims, keepdims=True)
68
-
69
- running_min_val = self.min_val.get(tensor_id, None)
70
- running_max_val = self.max_val.get(tensor_id, None)
71
-
72
- if running_min_val is None or running_max_val is None:
73
- updated_min_val = min_val
74
- updated_max_val = max_val
75
- else:
76
- updated_min_val = running_min_val + self.averaging_constant * (
77
- min_val - running_min_val
78
- )
79
- updated_max_val = running_max_val + self.averaging_constant * (
80
- max_val - running_max_val
81
- )
82
-
83
- self.min_val[tensor_id] = updated_min_val
84
- self.max_val[tensor_id] = updated_max_val
85
-
86
- return calculate_qparams(
87
- updated_min_val, updated_max_val, self.quantization_args
88
- )
89
-
90
- def get_qparams_along_dim(
91
- self, observed, dim: int, tensor_id: Optional[Any] = None
92
- ):
93
- reduce_dims = tuple(idx for idx in range(observed.ndim) if idx != dim)
94
- return self.calculate_qparams(
95
- observed, reduce_dims=reduce_dims, tensor_id=tensor_id
96
- )
97
-
98
- def reset(self):
99
- """
100
- Reset the state of the observer, including min and maximum values
101
- """
102
- super().reset()
103
- self.min_val = {}
104
- self.max_val = {}
@@ -1,162 +0,0 @@
1
- # Copyright (c) 2021 - present / Neuralmagic, Inc. All Rights Reserved.
2
- #
3
- # Licensed under the Apache License, Version 2.0 (the "License");
4
- # you may not use this file except in compliance with the License.
5
- # You may obtain a copy of the License at
6
- #
7
- # http://www.apache.org/licenses/LICENSE-2.0
8
- #
9
- # Unless required by applicable law or agreed to in writing,
10
- # software distributed under the License is distributed on an "AS IS" BASIS,
11
- # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
- # See the License for the specific language governing permissions and
13
- # limitations under the License.
14
-
15
- from typing import Any, Optional, Tuple
16
-
17
- import torch
18
- from compressed_tensors.quantization.observers.base import Observer
19
- from compressed_tensors.quantization.observers.helpers import calculate_qparams
20
- from compressed_tensors.quantization.quant_args import QuantizationArgs
21
- from torch import FloatTensor, IntTensor, Tensor
22
-
23
-
24
- __all__ = ["MovingAverageMSEObserver"]
25
-
26
-
27
- @Observer.register("mse")
28
- class MovingAverageMSEObserver(Observer):
29
- """
30
- Implements a dynamic quantization observer that sets the scale and
31
- zero point based on a moving average of the mse-clipped min and max observed values
32
- """
33
-
34
- def __init__(
35
- self,
36
- quantization_args: QuantizationArgs,
37
- averaging_constant: float = 0.01,
38
- grid: float = 100.0,
39
- maxshrink: float = 0.80,
40
- norm: float = 2.4,
41
- ):
42
- super().__init__(quantization_args=quantization_args)
43
-
44
- self.min_val = {}
45
- self.max_val = {}
46
- self.averaging_constant = averaging_constant
47
- self.grid = grid
48
- self.maxshrink = maxshrink
49
- self.norm = norm
50
-
51
- def calculate_mse_min_max(
52
- self,
53
- observed: Tensor,
54
- reduce_dims: Optional[Tuple[int]] = None,
55
- ):
56
- """
57
- Computes the mse-clipped min and max values of the observed tensor by
58
- optimizing for quantization error
59
-
60
- :param observed: observed tensor to calculate quantization parameters for
61
- :param reduce_dims: optional tuple of dimensions to reduce along,
62
- returned values will be shaped (1,) along the reduced dimensions
63
- :return: tuple of min and max values derived from the observed tensor
64
- """
65
- from compressed_tensors.quantization.lifecycle import fake_quantize
66
-
67
- if not reduce_dims:
68
- absolute_min_val, absolute_max_val = torch.aminmax(observed)
69
- else:
70
- absolute_min_val = torch.amin(observed, dim=reduce_dims, keepdims=True)
71
- absolute_max_val = torch.amax(observed, dim=reduce_dims, keepdims=True)
72
-
73
- best = torch.full(absolute_min_val.shape, float("inf"))
74
- min_val = torch.ones(absolute_min_val.shape)
75
- max_val = torch.zeros(absolute_max_val.shape)
76
- for i in range(int(self.maxshrink * self.grid)):
77
- p = 1 - i / self.grid
78
- shrinked_min_val = p * absolute_min_val
79
- shrinked_max_val = p * absolute_max_val
80
-
81
- candidate_scales, candidate_zero_points = calculate_qparams(
82
- shrinked_min_val, shrinked_max_val, self.quantization_args
83
- )
84
- q = fake_quantize(
85
- observed,
86
- candidate_scales,
87
- candidate_zero_points,
88
- self.quantization_args,
89
- )
90
-
91
- q -= observed
92
- q.abs_()
93
- q.pow_(self.norm)
94
- if not reduce_dims:
95
- err = torch.sum(q)
96
- else:
97
- err = torch.sum(q, reduce_dims, keepdims=True)
98
-
99
- tmp = err < best
100
- if torch.any(tmp):
101
- best[tmp] = err[tmp]
102
- min_val[tmp] = shrinked_min_val[tmp]
103
- max_val[tmp] = shrinked_max_val[tmp]
104
- return min_val, max_val
105
-
106
- def calculate_qparams(
107
- self,
108
- observed: Tensor,
109
- reduce_dims: Optional[Tuple[int]] = None,
110
- tensor_id: Optional[Any] = None,
111
- ) -> Tuple[FloatTensor, IntTensor]:
112
- """
113
- Updates the mse-clipped min and max values of the observed tensor using
114
- a moving average smoothed by the averaging_constant
115
-
116
- :param observed: observed tensor to calculate quantization parameters for
117
- :param reduce_dims: optional tuple of dimensions to reduce along,
118
- returned scale and zero point will be shaped (1,) along the
119
- reduced dimensions
120
- :param tensor_id: Optional id if different ranges of observed tensors are
121
- passed, useful for sharding tensors by group_size
122
- :return: tuple of scale and zero point derived from the observed tensor
123
- """
124
- min_val, max_val = self.calculate_mse_min_max(observed, reduce_dims)
125
-
126
- running_min_val = self.min_val.get(tensor_id, None)
127
- running_max_val = self.max_val.get(tensor_id, None)
128
-
129
- if running_min_val is None or running_max_val is None:
130
- updated_min_val = min_val
131
- updated_max_val = max_val
132
- else:
133
- updated_min_val = running_min_val + self.averaging_constant * (
134
- min_val - running_min_val
135
- )
136
- updated_max_val = running_max_val + self.averaging_constant * (
137
- max_val - running_max_val
138
- )
139
-
140
- tensor_id = tensor_id or "default"
141
- self.min_val[tensor_id] = updated_min_val
142
- self.max_val[tensor_id] = updated_max_val
143
-
144
- return calculate_qparams(
145
- updated_min_val, updated_max_val, self.quantization_args
146
- )
147
-
148
- def get_qparams_along_dim(
149
- self, observed, dim: int, tensor_id: Optional[Any] = None
150
- ):
151
- reduce_dims = tuple(idx for idx in range(observed.ndim) if idx != dim)
152
- return self.calculate_qparams(
153
- observed, reduce_dims=reduce_dims, tensor_id=tensor_id
154
- )
155
-
156
- def reset(self):
157
- """
158
- Reset the state of the observer, including min and maximum values
159
- """
160
- super().reset()
161
- self.min_val = {}
162
- self.max_val = {}
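
The search in calculate_mse_min_max sweeps shrink factors p = 1 - i/grid from 1.0 down toward 1 - maxshrink and keeps the clipped range with the lowest L^norm fake-quantization error. Below is a self-contained sketch of the same idea, restricted to symmetric int8 for clarity; mse_clip_absmax is an illustrative helper, not part of the library, which additionally handles asymmetric ranges, per-dimension reductions, and moving-average smoothing.

import torch


def mse_clip_absmax(x, grid=100, maxshrink=0.80, norm=2.4):
    # sweep candidate clipping ranges and keep the scale that minimizes
    # the L^norm reconstruction error of round-trip quantization
    absmax = x.abs().max()
    best_err, best_scale = float("inf"), absmax / 127
    for i in range(int(maxshrink * grid)):
        p = 1 - i / grid  # shrink factor: 1.00, 0.99, ..., 0.21
        scale = (p * absmax) / 127
        q = torch.clamp(torch.round(x / scale), -128, 127) * scale  # fake quantize
        err = (q - x).abs().pow(norm).sum()
        if err < best_err:
            best_err, best_scale = err, scale
    return best_scale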